refactor: Implement unified serialization function (#6044)
* feat: Implement serialization functions for various data types and add a unified serialize method * feat: Enhance serialization by adding support for primitive types, enums, and generic types * fix: Update Pinecone integration to use VectorStore and handle import errors gracefully * test: Add hypothesis-based tests for serialization functions across various data types * refactor: Replace custom serialization logic with unified serialize function for consistency and maintainability * refactor: Replace recursive serialization function with unified serialize method for improved clarity and maintainability * refactor: Replace custom serialization logic with unified serialize function for improved consistency and clarity * refactor: Enhance serialization logic by adding instance handling and streamlining type checks * refactor: Remove custom dictionary serialization from ResultDataResponse for streamlined handling * refactor: Enhance serialization in ResultDataResponse by adding max_items_length for improved handling of outputs, logs, messages, and artifacts * refactor: Move MAX_ITEMS_LENGTH and MAX_TEXT_LENGTH constants to serialization module for better organization * refactor: Simplify message serialization in Log model by utilizing unified serialize function * refactor: Remove unnecessary pytest marker from TestSerializationHypothesis class * optimize _serialize_bytes Co-authored-by: codeflash-ai[bot] <148906541+codeflash-ai[bot]@users.noreply.github.com> * feat: Add support for numpy integer type serialization * feat: Enhance serialization with support for pandas and numpy types * test: Add comprehensive serialization tests for numpy and pandas types * fix: Update _serialize_dispatcher to return string representation for unsupported types * fix: Update _serialize_dispatcher to return the object directly instead of its string representation * optimize conditional Co-authored-by: codeflash-ai[bot] <148906541+codeflash-ai[bot]@users.noreply.github.com> * optimize 
length check Co-authored-by: codeflash-ai[bot] <148906541+codeflash-ai[bot]@users.noreply.github.com> * fix: Update string and list truncation to include ellipsis for clarity * fix: Update _serialize_primitive to exclude string type from primitive handling * feat: Enhance serialization to handle numpy types and introduce unserializable sentinel * fix: Update test cases for serialization of numpy boolean values for consistency --------- Co-authored-by: codeflash-ai[bot] <148906541+codeflash-ai[bot]@users.noreply.github.com>
This commit is contained in:
parent
5bcf4d001f
commit
c73070cd52
20 changed files with 696 additions and 186 deletions
|
|
@ -1,5 +1,4 @@
|
|||
from datetime import datetime, timezone
|
||||
from decimal import Decimal
|
||||
from enum import Enum
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
|
|
@ -11,13 +10,14 @@ from langflow.graph.schema import RunOutputs
|
|||
from langflow.schema import dotdict
|
||||
from langflow.schema.graph import Tweaks
|
||||
from langflow.schema.schema import InputType, OutputType, OutputValue
|
||||
from langflow.serialization.constants import MAX_ITEMS_LENGTH, MAX_TEXT_LENGTH
|
||||
from langflow.serialization.serialization import serialize
|
||||
from langflow.services.database.models.api_key.model import ApiKeyRead
|
||||
from langflow.services.database.models.base import orjson_dumps
|
||||
from langflow.services.database.models.flow import FlowCreate, FlowRead
|
||||
from langflow.services.database.models.user import UserRead
|
||||
from langflow.services.settings.feature_flags import FeatureFlags
|
||||
from langflow.services.tracing.schema import Log
|
||||
from langflow.utils.constants import MAX_ITEMS_LENGTH, MAX_TEXT_LENGTH
|
||||
from langflow.utils.util_strings import truncate_long_strings
|
||||
|
||||
|
||||
|
|
@ -270,65 +270,17 @@ class ResultDataResponse(BaseModel):
|
|||
@classmethod
|
||||
def serialize_results(cls, v):
|
||||
"""Serialize results with custom handling for special types and truncation."""
|
||||
if isinstance(v, dict):
|
||||
return {key: cls._serialize_and_truncate(val, max_length=MAX_TEXT_LENGTH) for key, val in v.items()}
|
||||
return cls._serialize_and_truncate(v, max_length=MAX_TEXT_LENGTH)
|
||||
|
||||
@staticmethod
|
||||
def _serialize_and_truncate(obj: Any, max_length: int = MAX_TEXT_LENGTH) -> Any:
|
||||
"""Helper method to serialize and truncate values."""
|
||||
if isinstance(obj, bytes):
|
||||
obj = obj.decode("utf-8", errors="ignore")
|
||||
if len(obj) > max_length:
|
||||
return f"{obj[:max_length]}... [truncated]"
|
||||
return obj
|
||||
if isinstance(obj, str):
|
||||
if len(obj) > max_length:
|
||||
return f"{obj[:max_length]}... [truncated]"
|
||||
return obj
|
||||
if isinstance(obj, datetime):
|
||||
return obj.replace(tzinfo=timezone.utc).isoformat()
|
||||
if isinstance(obj, Decimal):
|
||||
return float(obj)
|
||||
if isinstance(obj, UUID):
|
||||
return str(obj)
|
||||
if isinstance(obj, OutputValue | Log):
|
||||
# First serialize the model
|
||||
serialized = obj.model_dump()
|
||||
# Then recursively truncate all values in the serialized dict
|
||||
for key, value in serialized.items():
|
||||
# Handle string values directly to ensure proper truncation
|
||||
if isinstance(value, str) and len(value) > max_length:
|
||||
serialized[key] = f"{value[:max_length]}... [truncated]"
|
||||
else:
|
||||
serialized[key] = ResultDataResponse._serialize_and_truncate(value, max_length=max_length)
|
||||
return serialized
|
||||
if isinstance(obj, BaseModel):
|
||||
# For other BaseModel instances, serialize all fields
|
||||
serialized = obj.model_dump()
|
||||
return {
|
||||
k: ResultDataResponse._serialize_and_truncate(v, max_length=max_length) for k, v in serialized.items()
|
||||
}
|
||||
if isinstance(obj, dict):
|
||||
return {k: ResultDataResponse._serialize_and_truncate(v, max_length=max_length) for k, v in obj.items()}
|
||||
if isinstance(obj, list | tuple):
|
||||
# If list is too long, truncate it
|
||||
if len(obj) > MAX_ITEMS_LENGTH:
|
||||
truncated_list = list(obj)[:MAX_ITEMS_LENGTH]
|
||||
truncated_list.append(f"... [truncated {len(obj) - MAX_ITEMS_LENGTH} items]")
|
||||
obj = truncated_list
|
||||
return [ResultDataResponse._serialize_and_truncate(item, max_length=max_length) for item in obj]
|
||||
return obj
|
||||
return serialize(v, max_length=MAX_TEXT_LENGTH, max_items=MAX_ITEMS_LENGTH)
|
||||
|
||||
@model_serializer(mode="plain")
|
||||
def serialize_model(self) -> dict:
|
||||
"""Custom serializer for the entire model."""
|
||||
return {
|
||||
"results": self.serialize_results(self.results),
|
||||
"outputs": self._serialize_and_truncate(self.outputs, max_length=MAX_TEXT_LENGTH),
|
||||
"logs": self._serialize_and_truncate(self.logs, max_length=MAX_TEXT_LENGTH),
|
||||
"message": self._serialize_and_truncate(self.message, max_length=MAX_TEXT_LENGTH),
|
||||
"artifacts": self._serialize_and_truncate(self.artifacts, max_length=MAX_TEXT_LENGTH),
|
||||
"outputs": serialize(self.outputs, max_length=MAX_TEXT_LENGTH, max_items=MAX_ITEMS_LENGTH),
|
||||
"logs": serialize(self.logs, max_length=MAX_TEXT_LENGTH, max_items=MAX_ITEMS_LENGTH),
|
||||
"message": serialize(self.message, max_length=MAX_TEXT_LENGTH, max_items=MAX_ITEMS_LENGTH),
|
||||
"artifacts": serialize(self.artifacts, max_length=MAX_TEXT_LENGTH, max_items=MAX_ITEMS_LENGTH),
|
||||
"timedelta": self.timedelta,
|
||||
"duration": self.duration,
|
||||
"used_frozen_result": self.used_frozen_result,
|
||||
|
|
|
|||
|
|
@ -1,5 +1,5 @@
|
|||
import numpy as np
|
||||
from langchain_pinecone import Pinecone
|
||||
from langchain_core.vectorstores import VectorStore
|
||||
|
||||
from langflow.base.vectorstores.model import LCVectorStoreComponent, check_cached_vector_store
|
||||
from langflow.helpers.data import docs_to_data
|
||||
|
|
@ -42,8 +42,14 @@ class PineconeVectorStoreComponent(LCVectorStoreComponent):
|
|||
]
|
||||
|
||||
@check_cached_vector_store
|
||||
def build_vector_store(self) -> Pinecone:
|
||||
def build_vector_store(self) -> VectorStore:
|
||||
"""Build and return a Pinecone vector store instance."""
|
||||
try:
|
||||
from langchain_pinecone import PineconeVectorStore
|
||||
except ImportError as e:
|
||||
msg = "langchain-pinecone is not installed. Please install it with `pip install langchain-pinecone`."
|
||||
raise ValueError(msg) from e
|
||||
|
||||
try:
|
||||
from langchain_pinecone._utilities import DistanceStrategy
|
||||
|
||||
|
|
@ -55,7 +61,7 @@ class PineconeVectorStoreComponent(LCVectorStoreComponent):
|
|||
distance_strategy = DistanceStrategy[distance_strategy]
|
||||
|
||||
# Initialize Pinecone instance with wrapped embeddings
|
||||
pinecone = Pinecone(
|
||||
pinecone = PineconeVectorStore(
|
||||
index_name=self.index_name,
|
||||
embedding=wrapped_embeddings, # Use wrapped embeddings
|
||||
text_key=self.text_key,
|
||||
|
|
|
|||
|
|
@ -3,8 +3,8 @@ from typing import Any
|
|||
|
||||
from pydantic import BaseModel, Field, field_serializer, model_validator
|
||||
|
||||
from langflow.graph.utils import serialize_field
|
||||
from langflow.schema.schema import OutputValue, StreamURL
|
||||
from langflow.serialization import serialize
|
||||
from langflow.utils.schemas import ChatOutputResponse, ContainsEnumMeta
|
||||
|
||||
|
||||
|
|
@ -23,8 +23,8 @@ class ResultData(BaseModel):
|
|||
@field_serializer("results")
|
||||
def serialize_results(self, value):
|
||||
if isinstance(value, dict):
|
||||
return {key: serialize_field(val) for key, val in value.items()}
|
||||
return serialize_field(value)
|
||||
return {key: serialize(val) for key, val in value.items()}
|
||||
return serialize(value)
|
||||
|
||||
@model_validator(mode="before")
|
||||
@classmethod
|
||||
|
|
|
|||
|
|
@ -6,14 +6,12 @@ from enum import Enum
|
|||
from typing import TYPE_CHECKING, Any
|
||||
from uuid import UUID
|
||||
|
||||
from langchain_core.documents import Document
|
||||
from loguru import logger
|
||||
from pydantic import BaseModel
|
||||
from pydantic.v1 import BaseModel as V1BaseModel
|
||||
|
||||
from langflow.interface.utils import extract_input_variables_from_prompt
|
||||
from langflow.schema.data import Data
|
||||
from langflow.schema.message import Message
|
||||
from langflow.serialization import serialize
|
||||
from langflow.services.database.models.transactions.crud import log_transaction as crud_log_transaction
|
||||
from langflow.services.database.models.transactions.model import TransactionBase
|
||||
from langflow.services.database.models.vertex_builds.crud import log_vertex_build as crud_log_vertex_build
|
||||
|
|
@ -68,30 +66,6 @@ def flatten_list(list_of_lists: list[list | Any]) -> list:
|
|||
return new_list
|
||||
|
||||
|
||||
def serialize_field(value):
|
||||
"""Serialize field.
|
||||
|
||||
Unified serialization function for handling both BaseModel and Document types,
|
||||
including handling lists of these types.
|
||||
"""
|
||||
if isinstance(value, list | tuple):
|
||||
return [serialize_field(v) for v in value]
|
||||
if isinstance(value, Document):
|
||||
return value.to_json()
|
||||
if isinstance(value, BaseModel):
|
||||
return serialize_field(value.model_dump())
|
||||
if isinstance(value, dict):
|
||||
return {k: serialize_field(v) for k, v in value.items()}
|
||||
if isinstance(value, V1BaseModel):
|
||||
if hasattr(value, "to_json"):
|
||||
return value.to_json()
|
||||
return value.dict()
|
||||
# Handle datetime objects
|
||||
if hasattr(value, "isoformat"):
|
||||
return value.isoformat()
|
||||
return str(value)
|
||||
|
||||
|
||||
def get_artifact_type(value, build_result) -> str:
|
||||
result = ArtifactType.UNKNOWN
|
||||
match value:
|
||||
|
|
@ -186,9 +160,9 @@ async def log_vertex_build(
|
|||
valid=valid,
|
||||
params=str(params) if params else None,
|
||||
# Serialize data using our custom serializer
|
||||
data=serialize_field(data),
|
||||
data=serialize(data),
|
||||
# Serialize artifacts using our custom serializer
|
||||
artifacts=serialize_field(artifacts) if artifacts else None,
|
||||
artifacts=serialize(artifacts) if artifacts else None,
|
||||
)
|
||||
async with session_getter(get_db_service()) as session:
|
||||
inserted = await crud_log_vertex_build(session, vertex_build)
|
||||
|
|
|
|||
|
|
@ -10,13 +10,14 @@ from langchain_core.messages import AIMessage, AIMessageChunk
|
|||
from loguru import logger
|
||||
|
||||
from langflow.graph.schema import CHAT_COMPONENTS, RECORDS_COMPONENTS, InterfaceComponentTypes, ResultData
|
||||
from langflow.graph.utils import UnbuiltObject, log_vertex_build, rewrite_file_path, serialize_field
|
||||
from langflow.graph.utils import UnbuiltObject, log_vertex_build, rewrite_file_path
|
||||
from langflow.graph.vertex.base import Vertex
|
||||
from langflow.graph.vertex.exceptions import NoComponentInstanceError
|
||||
from langflow.schema import Data
|
||||
from langflow.schema.artifact import ArtifactType
|
||||
from langflow.schema.message import Message
|
||||
from langflow.schema.schema import INPUT_FIELD_NAME
|
||||
from langflow.serialization import serialize
|
||||
from langflow.template.field.base import UNDEFINED, Output
|
||||
from langflow.utils.schemas import ChatOutputResponse, DataOutputResponse
|
||||
from langflow.utils.util import unescape_string
|
||||
|
|
@ -478,6 +479,6 @@ class StateVertex(ComponentVertex):
|
|||
|
||||
|
||||
def dict_to_codeblock(d: dict) -> str:
|
||||
serialized = {key: serialize_field(val) for key, val in d.items()}
|
||||
serialized = {key: serialize(val) for key, val in d.items()}
|
||||
json_str = json.dumps(serialized, indent=4)
|
||||
return f"```json\n{json_str}\n```"
|
||||
|
|
|
|||
|
|
@ -9,7 +9,7 @@ from langflow.schema.data import Data
|
|||
from langflow.schema.dataframe import DataFrame
|
||||
from langflow.schema.encoders import CUSTOM_ENCODERS
|
||||
from langflow.schema.message import Message
|
||||
from langflow.schema.serialize import recursive_serialize_or_str
|
||||
from langflow.serialization.serialization import serialize
|
||||
|
||||
|
||||
class ArtifactType(str, Enum):
|
||||
|
|
@ -56,7 +56,7 @@ def _to_list_of_dicts(raw):
|
|||
raw_ = []
|
||||
for item in raw:
|
||||
if hasattr(item, "dict") or hasattr(item, "model_dump"):
|
||||
raw_.append(recursive_serialize_or_str(item))
|
||||
raw_.append(serialize(item))
|
||||
else:
|
||||
raw_.append(str(item))
|
||||
return raw_
|
||||
|
|
|
|||
|
|
@ -8,7 +8,7 @@ from typing_extensions import TypedDict
|
|||
from langflow.schema.data import Data
|
||||
from langflow.schema.dataframe import DataFrame
|
||||
from langflow.schema.message import Message
|
||||
from langflow.schema.serialize import recursive_serialize_or_str
|
||||
from langflow.serialization.serialization import serialize
|
||||
|
||||
INPUT_FIELD_NAME = "input_value"
|
||||
|
||||
|
|
@ -110,7 +110,7 @@ def build_output_logs(vertex, result) -> dict:
|
|||
case LogType.ARRAY:
|
||||
if isinstance(message, DataFrame):
|
||||
message = message.to_dict(orient="records")
|
||||
message = [recursive_serialize_or_str(item) for item in message]
|
||||
message = [serialize(item) for item in message]
|
||||
name = output.get("name", f"output_{index}")
|
||||
outputs |= {name: OutputValue(message=message, type=type_).model_dump()}
|
||||
|
||||
|
|
|
|||
|
|
@ -1,11 +1,7 @@
|
|||
from collections.abc import AsyncIterator, Generator, Iterator
|
||||
from datetime import datetime
|
||||
from typing import Annotated
|
||||
from uuid import UUID
|
||||
|
||||
from loguru import logger
|
||||
from pydantic import BaseModel, BeforeValidator
|
||||
from pydantic.v1 import BaseModel as BaseModelV1
|
||||
from pydantic import BeforeValidator
|
||||
|
||||
|
||||
def str_to_uuid(v: str | UUID) -> UUID:
|
||||
|
|
@ -15,40 +11,3 @@ def str_to_uuid(v: str | UUID) -> UUID:
|
|||
|
||||
|
||||
UUIDstr = Annotated[UUID, BeforeValidator(str_to_uuid)]
|
||||
|
||||
|
||||
def recursive_serialize_or_str(obj):
|
||||
try:
|
||||
if isinstance(obj, type) and issubclass(obj, BaseModel | BaseModelV1):
|
||||
# This a type BaseModel and not an instance of it
|
||||
return repr(obj)
|
||||
if isinstance(obj, str):
|
||||
return obj
|
||||
if isinstance(obj, datetime):
|
||||
return obj.isoformat()
|
||||
if isinstance(obj, dict):
|
||||
return {k: recursive_serialize_or_str(v) for k, v in obj.items()}
|
||||
if isinstance(obj, list):
|
||||
return [recursive_serialize_or_str(v) for v in obj]
|
||||
if isinstance(obj, BaseModel | BaseModelV1):
|
||||
if hasattr(obj, "model_dump"):
|
||||
obj_dict = obj.model_dump()
|
||||
elif hasattr(obj, "dict"):
|
||||
obj_dict = obj.dict()
|
||||
return {k: recursive_serialize_or_str(v) for k, v in obj_dict.items()}
|
||||
|
||||
if isinstance(obj, AsyncIterator | Generator | Iterator):
|
||||
# contain memory addresses
|
||||
# without consuming the iterator
|
||||
# return list(obj) consumes the iterator
|
||||
# return f"{obj}" this generates '<generator object BaseChatModel.stream at 0x33e9ec770>'
|
||||
# it is not useful
|
||||
return "Unconsumed Stream"
|
||||
if hasattr(obj, "dict") and not isinstance(obj, type):
|
||||
return {k: recursive_serialize_or_str(v) for k, v in obj.dict().items()}
|
||||
if hasattr(obj, "model_dump") and not isinstance(obj, type):
|
||||
return {k: recursive_serialize_or_str(v) for k, v in obj.model_dump().items()}
|
||||
return str(obj)
|
||||
except Exception: # noqa: BLE001
|
||||
logger.debug(f"Cannot serialize object {obj}")
|
||||
return str(obj)
|
||||
|
|
|
|||
3
src/backend/base/langflow/serialization/__init__.py
Normal file
3
src/backend/base/langflow/serialization/__init__.py
Normal file
|
|
@ -0,0 +1,3 @@
|
|||
from .serialization import serialize
|
||||
|
||||
__all__ = ["serialize"]
|
||||
2
src/backend/base/langflow/serialization/constants.py
Normal file
2
src/backend/base/langflow/serialization/constants.py
Normal file
|
|
@ -0,0 +1,2 @@
|
|||
MAX_TEXT_LENGTH = 20000
|
||||
MAX_ITEMS_LENGTH = 1000
|
||||
286
src/backend/base/langflow/serialization/serialization.py
Normal file
286
src/backend/base/langflow/serialization/serialization.py
Normal file
|
|
@ -0,0 +1,286 @@
|
|||
from collections.abc import AsyncIterator, Generator, Iterator
|
||||
from datetime import datetime, timezone
|
||||
from decimal import Decimal
|
||||
from typing import Any, cast
|
||||
from uuid import UUID
|
||||
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
from langchain_core.documents import Document
|
||||
from loguru import logger
|
||||
from pydantic import BaseModel
|
||||
from pydantic.v1 import BaseModel as BaseModelV1
|
||||
|
||||
from langflow.serialization.constants import MAX_ITEMS_LENGTH, MAX_TEXT_LENGTH
|
||||
|
||||
|
||||
# Sentinel variable to signal a failed serialization.
|
||||
# Using a helper class ensures that the sentinel is a unique object,
|
||||
# while its __repr__ displays the desired message.
|
||||
class _UnserializableSentinel:
    """Type of the module-level marker that signals a failed serialization.

    A dedicated class gives the marker its own identity, so callers can
    compare with ``is`` instead of relying on a value such as ``None`` that
    may also be a legitimate serialization result.
    """

    def __repr__(self) -> str:
        """Return the placeholder text displayed for the marker."""
        return "[Unserializable Object]"
|
||||
|
||||
|
||||
UNSERIALIZABLE_SENTINEL = _UnserializableSentinel()
|
||||
|
||||
|
||||
def _serialize_str(obj: str, max_length: int | None, _) -> str:
|
||||
"""Truncate long strings with ellipsis if max_length provided."""
|
||||
if max_length is None or len(obj) <= max_length:
|
||||
return obj
|
||||
return obj[:max_length] + "..."
|
||||
|
||||
|
||||
def _serialize_bytes(obj: bytes, max_length: int | None, _) -> str:
|
||||
"""Decode bytes to string and truncate if max_length provided."""
|
||||
if max_length is not None:
|
||||
return (
|
||||
obj[:max_length].decode("utf-8", errors="ignore") + "..."
|
||||
if len(obj) > max_length
|
||||
else obj.decode("utf-8", errors="ignore")
|
||||
)
|
||||
return obj.decode("utf-8", errors="ignore")
|
||||
|
||||
|
||||
def _serialize_datetime(obj: datetime, *_) -> str:
|
||||
"""Convert datetime to UTC ISO format."""
|
||||
return obj.replace(tzinfo=timezone.utc).isoformat()
|
||||
|
||||
|
||||
def _serialize_decimal(obj: Decimal, *_) -> float:
|
||||
"""Convert Decimal to float."""
|
||||
return float(obj)
|
||||
|
||||
|
||||
def _serialize_uuid(obj: UUID, *_) -> str:
|
||||
"""Convert UUID to string."""
|
||||
return str(obj)
|
||||
|
||||
|
||||
def _serialize_document(obj: Document, max_length: int | None, max_items: int | None) -> Any:
    """Serialize a Langchain ``Document`` through its ``to_json()`` output.

    Args:
        obj: The document to serialize.
        max_length: String limit forwarded to ``serialize``.
        max_items: Item limit forwarded to ``serialize``.

    Returns:
        Whatever ``serialize`` produces for the document's ``to_json()`` value.
    """
    as_json = obj.to_json()
    return serialize(as_json, max_length, max_items)
|
||||
|
||||
|
||||
def _serialize_iterator(_: AsyncIterator | Generator | Iterator, *__) -> str:
|
||||
"""Handle unconsumed iterators uniformly."""
|
||||
return "Unconsumed Stream"
|
||||
|
||||
|
||||
def _serialize_pydantic(obj: BaseModel, max_length: int | None, max_items: int | None) -> Any:
    """Serialize a Pydantic model by dumping it and serializing each field value.

    Args:
        obj: The model instance.
        max_length: String limit forwarded to ``serialize``.
        max_items: Item limit forwarded to ``serialize``.

    Returns:
        A dict keyed like ``obj.model_dump()`` whose values went through ``serialize``.
    """
    dumped = obj.model_dump()
    result = {}
    for field_name, field_value in dumped.items():
        result[field_name] = serialize(field_value, max_length, max_items)
    return result
|
||||
|
||||
|
||||
def _serialize_pydantic_v1(obj: BaseModelV1, max_length: int | None, max_items: int | None) -> Any:
    """Serialize a Pydantic v1 model, preferring ``to_json()`` over ``dict()``.

    Args:
        obj: The v1 model instance.
        max_length: String limit forwarded to ``serialize``.
        max_items: Item limit forwarded to ``serialize``.

    Returns:
        The result of ``serialize`` applied to ``obj.to_json()`` when that
        method exists, otherwise to ``obj.dict()``.
    """
    payload = obj.to_json() if hasattr(obj, "to_json") else obj.dict()
    return serialize(payload, max_length, max_items)
|
||||
|
||||
|
||||
def _serialize_dict(obj: dict, max_length: int | None, max_items: int | None) -> dict:
|
||||
"""Recursively process dictionary values."""
|
||||
return {k: serialize(v, max_length, max_items) for k, v in obj.items()}
|
||||
|
||||
|
||||
def _serialize_list_tuple(obj: list | tuple, max_length: int | None, max_items: int | None) -> list:
|
||||
"""Truncate long lists and process items recursively."""
|
||||
if max_items is not None and len(obj) > max_items:
|
||||
truncated = list(obj)[:max_items]
|
||||
truncated.append(f"... [truncated {len(obj) - max_items} items]")
|
||||
obj = truncated
|
||||
return [serialize(item, max_length, max_items) for item in obj]
|
||||
|
||||
|
||||
def _serialize_primitive(obj: Any, *_) -> Any:
|
||||
"""Handle primitive types without conversion."""
|
||||
if obj is None or isinstance(obj, int | float | bool | complex):
|
||||
return obj
|
||||
return UNSERIALIZABLE_SENTINEL
|
||||
|
||||
|
||||
def _serialize_instance(obj: Any, *_) -> str:
|
||||
"""Handle regular class instances by converting to string."""
|
||||
return str(obj)
|
||||
|
||||
|
||||
def _truncate_value(value: Any, max_length: int | None, max_items: int | None) -> Any:
|
||||
"""Truncate value based on its type and provided limits."""
|
||||
if isinstance(value, str) and max_length is not None and len(value) > max_length:
|
||||
return value[:max_length]
|
||||
if isinstance(value, list | tuple) and max_items is not None and len(value) > max_items:
|
||||
return value[:max_items]
|
||||
return value
|
||||
|
||||
|
||||
def _serialize_dataframe(obj: pd.DataFrame, max_length: int | None, max_items: int | None) -> list[dict]:
    """Serialize a pandas DataFrame to a list of row records.

    At most ``max_items`` rows are kept (via ``head``), and every cell is
    passed through ``_truncate_value`` before conversion.

    Returns:
        ``DataFrame.to_dict(orient="records")`` of the limited frame.
    """
    over_limit = max_items is not None and len(obj) > max_items
    frame = obj.head(max_items) if over_limit else obj

    def _clip_column(column):
        # Apply the per-cell limits to one column.
        return column.apply(lambda cell: _truncate_value(cell, max_length, max_items))

    clipped = frame.apply(_clip_column)
    return clipped.to_dict(orient="records")
|
||||
|
||||
|
||||
def _serialize_series(obj: pd.Series, max_length: int | None, max_items: int | None) -> dict:
    """Serialize a pandas Series to a dict via ``Series.to_dict()``.

    At most ``max_items`` entries are kept (via ``head``), and every value is
    passed through ``_truncate_value`` before conversion.
    """
    over_limit = max_items is not None and len(obj) > max_items
    series = obj.head(max_items) if over_limit else obj

    def _clip(cell):
        # Apply the per-value limits.
        return _truncate_value(cell, max_length, max_items)

    return series.apply(_clip).to_dict()
|
||||
|
||||
|
||||
def _is_numpy_type(obj: Any) -> bool:
|
||||
"""Check if an object is a numpy type by checking its type's module name."""
|
||||
return hasattr(type(obj), "__module__") and type(obj).__module__ == np.__name__
|
||||
|
||||
|
||||
def _serialize_numpy_type(obj: Any, max_length: int | None, max_items: int | None) -> Any:
|
||||
"""Serialize numpy types."""
|
||||
if np.issubdtype(obj.dtype, np.number) and hasattr(obj, "item"):
|
||||
return obj.item()
|
||||
if np.issubdtype(obj.dtype, np.bool_):
|
||||
return bool(obj)
|
||||
if np.issubdtype(obj.dtype, np.complexfloating):
|
||||
return complex(cast(complex, obj))
|
||||
if np.issubdtype(obj.dtype, np.str_):
|
||||
return _serialize_str(str(obj), max_length, max_items)
|
||||
if np.issubdtype(obj.dtype, np.bytes_) and hasattr(obj, "tobytes"):
|
||||
return _serialize_bytes(obj.tobytes(), max_length, max_items)
|
||||
if np.issubdtype(obj.dtype, np.object_) and hasattr(obj, "item"):
|
||||
return _serialize_instance(obj.item(), max_length, max_items)
|
||||
return UNSERIALIZABLE_SENTINEL
|
||||
|
||||
|
||||
def _serialize_dispatcher(obj: Any, max_length: int | None, max_items: int | None) -> Any | _UnserializableSentinel:
    """Dispatch ``obj`` to the serializer matching its type.

    Plain values (``None``, int, float, bool, complex) are returned first;
    everything else goes through a ``match`` statement whose cases are tried
    top to bottom, so the order of the cases decides which handler wins when
    an object satisfies several of them.

    Args:
        obj: The object to serialize.
        max_length: String limit forwarded to the chosen serializer.
        max_items: Item limit forwarded to the chosen serializer.

    Returns:
        The serialized value, or ``UNSERIALIZABLE_SENTINEL`` when no case
        produced a result.
    """
    # Handle primitive types first
    if obj is None:
        return obj
    primitive = _serialize_primitive(obj, max_length, max_items)
    if primitive is not UNSERIALIZABLE_SENTINEL:
        return primitive

    match obj:
        case str():
            return _serialize_str(obj, max_length, max_items)
        case bytes():
            return _serialize_bytes(obj, max_length, max_items)
        case datetime():
            return _serialize_datetime(obj, max_length, max_items)
        case Decimal():
            return _serialize_decimal(obj, max_length, max_items)
        case UUID():
            return _serialize_uuid(obj, max_length, max_items)
        # Document is tested before the generic Pydantic cases below.
        # NOTE(review): presumably because Document is itself a Pydantic model — confirm.
        case Document():
            return _serialize_document(obj, max_length, max_items)
        case AsyncIterator() | Generator() | Iterator():
            return _serialize_iterator(obj, max_length, max_items)
        case BaseModel():
            return _serialize_pydantic(obj, max_length, max_items)
        case BaseModelV1():
            return _serialize_pydantic_v1(obj, max_length, max_items)
        case dict():
            return _serialize_dict(obj, max_length, max_items)
        case pd.DataFrame():
            return _serialize_dataframe(obj, max_length, max_items)
        case pd.Series():
            return _serialize_series(obj, max_length, max_items)
        case list() | tuple():
            return _serialize_list_tuple(obj, max_length, max_items)
        case object() if _is_numpy_type(obj):
            return _serialize_numpy_type(obj, max_length, max_items)
        # This guard accepts every remaining object that is not a class, so
        # the cases after it only ever see class objects.
        case object() if not isinstance(obj, type):  # Match any instance that's not a class
            return _serialize_instance(obj, max_length, max_items)
        case object() if hasattr(obj, "_name_"):  # Enum case
            return f"{obj.__class__.__name__}.{obj._name_}"
        case object() if hasattr(obj, "__name__") and hasattr(obj, "__bound__"):  # TypeVar case
            return repr(obj)
        case object() if hasattr(obj, "__origin__") or hasattr(obj, "__parameters__"):  # Type alias/generic case
            return repr(obj)
        case _:
            # Handle numpy numeric types (int, float, bool, complex)
            # NOTE(review): instances were already consumed by the guards
            # above, so this branch looks reachable only for class objects
            # exposing ``dtype`` — confirm whether it is still needed.
            if hasattr(obj, "dtype"):
                if np.issubdtype(obj.dtype, np.number) and hasattr(obj, "item"):
                    return obj.item()
                if np.issubdtype(obj.dtype, np.bool_):
                    return bool(obj)
                if np.issubdtype(obj.dtype, np.complexfloating):
                    return complex(cast(complex, obj))
                if np.issubdtype(obj.dtype, np.str_):
                    return str(obj)
                if np.issubdtype(obj.dtype, np.bytes_) and hasattr(obj, "tobytes"):
                    return obj.tobytes().decode("utf-8", errors="ignore")
                if np.issubdtype(obj.dtype, np.object_) and hasattr(obj, "item"):
                    # This call does not forward max_length/max_items, so
                    # serialize() falls back to its default limits here.
                    return serialize(obj.item())
            return UNSERIALIZABLE_SENTINEL
|
||||
|
||||
|
||||
def serialize(
    obj: Any,
    max_length: int | None = MAX_TEXT_LENGTH,
    max_items: int | None = MAX_ITEMS_LENGTH,
    *,
    to_str: bool = False,
) -> Any:
    """Serialize ``obj`` with optional truncation of strings and sequences.

    The type dispatcher is tried first. If it reports that it cannot handle
    the object, a series of fallbacks runs: classes are rendered with
    ``repr``/``str``, generic aliases with ``repr``, then objects exposing
    ``model_dump()`` or ``dict()`` are serialized through those methods.

    Args:
        obj: Object to serialize.
        max_length: Maximum length for string values, None for no truncation.
        max_items: Maximum items in list-like structures, None for no truncation.
        to_str: If True, return ``str(obj)`` when no other strategy applied.

    Returns:
        The serialized value; ``"[Unserializable Object]"`` if any step raised;
        or ``obj`` itself when nothing applied and ``to_str`` is False.
    """
    if obj is None:
        return None

    try:
        # Type-specific serialization comes first. The sentinel (rather than
        # None) marks "not handled", since None is a valid result.
        dispatched = _serialize_dispatcher(obj, max_length, max_items)
        if dispatched is not UNSERIALIZABLE_SENTINEL:
            return dispatched

        # Class objects: Pydantic model classes via repr, other classes via str.
        if isinstance(obj, type):
            return repr(obj) if issubclass(obj, BaseModel | BaseModelV1) else str(obj)

        # Type aliases and generic types.
        if hasattr(obj, "__origin__") or hasattr(obj, "__parameters__"):
            try:
                return repr(obj)
            except Exception as e:  # noqa: BLE001
                logger.debug(f"Cannot serialize object {obj}: {e!s}")

        # Common "dump to dict" conventions.
        if hasattr(obj, "model_dump"):
            return serialize(obj.model_dump(), max_length, max_items)
        if hasattr(obj, "dict") and not isinstance(obj, type):
            return serialize(obj.dict(), max_length, max_items)

        # String conversion only when explicitly requested.
        if to_str:
            return str(obj)
    except Exception as e:  # noqa: BLE001
        logger.debug(f"Cannot serialize object {obj}: {e!s}")
        return "[Unserializable Object]"
    else:
        # Nothing applied and nothing failed: hand the object back unchanged.
        return obj
|
||||
|
||||
|
||||
def serialize_or_str(
    obj: Any, max_length: int | None = MAX_TEXT_LENGTH, max_items: int | None = MAX_ITEMS_LENGTH
) -> Any:
    """Call ``serialize`` with ``to_str=True``.

    With that flag, ``serialize`` returns ``str(obj)`` instead of the raw
    object when none of its serialization strategies applies.

    Args:
        obj: Object to serialize.
        max_length: Maximum length for string values, None for no truncation.
        max_items: Maximum items in list-like structures, None for no truncation.
    """
    return serialize(obj, max_length, max_items, to_str=True)
|
||||
|
|
@ -10,7 +10,7 @@ from loguru import logger
|
|||
from sqlmodel import text
|
||||
from sqlmodel.ext.asyncio.session import AsyncSession
|
||||
|
||||
from langflow.utils import constants
|
||||
from langflow.serialization import constants
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from langflow.services.database.service import DatabaseService
|
||||
|
|
|
|||
|
|
@ -1,11 +1,10 @@
|
|||
import logging
|
||||
|
||||
from fastapi.encoders import jsonable_encoder
|
||||
from pydantic import BaseModel, field_serializer
|
||||
from pydantic.v1 import BaseModel as V1BaseModel
|
||||
from pydantic_core import PydanticSerializationError
|
||||
|
||||
from langflow.schema.log import LoggableType
|
||||
from langflow.serialization.serialization import serialize
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
|
@ -18,22 +17,8 @@ class Log(BaseModel):
|
|||
@field_serializer("message")
|
||||
def serialize_message(self, value):
|
||||
try:
|
||||
# We need to make sure everything inside the message has been serialized
|
||||
if isinstance(value, dict):
|
||||
return {key: self.serialize_message(value[key]) for key in value}
|
||||
if isinstance(value, list):
|
||||
return [self.serialize_message(item) for item in value]
|
||||
# To json is for LangChain Serializable objects
|
||||
if hasattr(value, "dict") and isinstance(value, V1BaseModel):
|
||||
# This is for Pydantic V1 models
|
||||
return value.dict()
|
||||
if hasattr(value, "to_json"):
|
||||
return value.to_json()
|
||||
if isinstance(value, BaseModel):
|
||||
return value.model_dump(exclude_none=True)
|
||||
value = jsonable_encoder(value)
|
||||
return serialize(value)
|
||||
except UnicodeDecodeError:
|
||||
return str(value) # Fallback to string representation
|
||||
except PydanticSerializationError:
|
||||
return str(value)
|
||||
return value
|
||||
return str(value) # Fallback to string for Pydantic errors
|
||||
|
|
|
|||
|
|
@ -173,6 +173,3 @@ MESSAGE_SENDER_AI = "Machine"
|
|||
MESSAGE_SENDER_USER = "User"
|
||||
MESSAGE_SENDER_NAME_AI = "AI"
|
||||
MESSAGE_SENDER_NAME_USER = "User"
|
||||
|
||||
MAX_TEXT_LENGTH = 20000
|
||||
MAX_ITEMS_LENGTH = 1000
|
||||
|
|
|
|||
|
|
@ -1,6 +1,6 @@
|
|||
from sqlalchemy.engine import make_url
|
||||
|
||||
from langflow.utils import constants
|
||||
from langflow.serialization import constants
|
||||
|
||||
|
||||
def truncate_long_strings(data, max_length=None):
|
||||
|
|
|
|||
|
|
@ -4,6 +4,7 @@ from hypothesis import HealthCheck, example, given, settings
|
|||
from hypothesis import strategies as st
|
||||
from langflow.api.v1.schemas import ResultDataResponse, VertexBuildResponse
|
||||
from langflow.schema.schema import OutputValue
|
||||
from langflow.serialization import serialize
|
||||
from langflow.services.tracing.schema import Log
|
||||
from pydantic import BaseModel
|
||||
|
||||
|
|
@ -26,9 +27,9 @@ def test_result_data_response_truncation(long_string):
|
|||
)
|
||||
|
||||
response.serialize_model()
|
||||
truncated = response._serialize_and_truncate(long_string, max_length=TEST_TEXT_LENGTH)
|
||||
assert len(truncated) <= TEST_TEXT_LENGTH + len("... [truncated]")
|
||||
assert "... [truncated]" in truncated
|
||||
truncated = serialize(long_string, max_length=TEST_TEXT_LENGTH)
|
||||
assert len(truncated) <= TEST_TEXT_LENGTH + len("...")
|
||||
assert "..." in truncated
|
||||
|
||||
|
||||
@given(
|
||||
|
|
@ -77,20 +78,20 @@ def test_result_data_response_nested_structures(long_list, long_dict):
|
|||
"dict": long_dict,
|
||||
}
|
||||
|
||||
response = ResultDataResponse(results=nested_data)
|
||||
serialized = response._serialize_and_truncate(nested_data, max_length=TEST_TEXT_LENGTH)
|
||||
ResultDataResponse(results=nested_data)
|
||||
serialized = serialize(nested_data, max_length=TEST_TEXT_LENGTH)
|
||||
|
||||
# Check list items
|
||||
for item in serialized["list"]:
|
||||
assert len(item) <= TEST_TEXT_LENGTH + len("... [truncated]")
|
||||
assert len(item) <= TEST_TEXT_LENGTH + len("...")
|
||||
if len(item) > TEST_TEXT_LENGTH:
|
||||
assert "... [truncated]" in item
|
||||
assert "..." in item
|
||||
|
||||
# Check dict values
|
||||
for val in serialized["dict"].values():
|
||||
assert len(val) <= TEST_TEXT_LENGTH + len("... [truncated]")
|
||||
assert len(val) <= TEST_TEXT_LENGTH + len("...")
|
||||
if len(val) > TEST_TEXT_LENGTH:
|
||||
assert "... [truncated]" in val
|
||||
assert "..." in val
|
||||
|
||||
|
||||
@given(
|
||||
|
|
@ -114,7 +115,7 @@ def test_result_data_response_outputs(outputs_dict):
|
|||
outputs = {key: OutputValue(type="text", message=value) for key, value in outputs_dict.items()}
|
||||
|
||||
response = ResultDataResponse(outputs=outputs)
|
||||
serialized = ResultDataResponse._serialize_and_truncate(response, max_length=TEST_TEXT_LENGTH)
|
||||
serialized = serialize(response, max_length=TEST_TEXT_LENGTH)
|
||||
|
||||
# Check outputs are properly serialized and truncated
|
||||
for key, value in outputs_dict.items():
|
||||
|
|
@ -124,9 +125,9 @@ def test_result_data_response_outputs(outputs_dict):
|
|||
|
||||
# Check message truncation
|
||||
message = serialized_output["message"]
|
||||
assert len(message) <= TEST_TEXT_LENGTH + len("... [truncated]"), f"Message length: {len(message)}"
|
||||
assert len(message) <= TEST_TEXT_LENGTH + len("..."), f"Message length: {len(message)}"
|
||||
if len(value) > TEST_TEXT_LENGTH:
|
||||
assert "... [truncated]" in message
|
||||
assert "..." in message
|
||||
assert message.startswith(value[:TEST_TEXT_LENGTH])
|
||||
else:
|
||||
assert message == value
|
||||
|
|
@ -158,7 +159,7 @@ def test_result_data_response_logs(log_messages):
|
|||
}
|
||||
|
||||
response = ResultDataResponse(logs=logs)
|
||||
serialized = ResultDataResponse._serialize_and_truncate(response, max_length=TEST_TEXT_LENGTH)
|
||||
serialized = serialize(response, max_length=TEST_TEXT_LENGTH)
|
||||
|
||||
# Check logs are properly serialized and truncated
|
||||
assert "test_node" in serialized["logs"]
|
||||
|
|
@ -171,9 +172,9 @@ def test_result_data_response_logs(log_messages):
|
|||
|
||||
# Check message truncation
|
||||
message = serialized_log["message"]
|
||||
assert len(message) <= TEST_TEXT_LENGTH + len("... [truncated]")
|
||||
assert len(message) <= TEST_TEXT_LENGTH + len("...")
|
||||
if len(log_msg) > TEST_TEXT_LENGTH:
|
||||
assert "... [truncated]" in message
|
||||
assert "..." in message
|
||||
assert message.startswith(log_msg[:TEST_TEXT_LENGTH])
|
||||
else:
|
||||
assert message == log_msg
|
||||
|
|
@ -225,7 +226,7 @@ def test_result_data_response_combined_fields(outputs_dict, log_messages):
|
|||
message={"text": "test"},
|
||||
artifacts={"file": "test.txt"},
|
||||
)
|
||||
serialized = ResultDataResponse._serialize_and_truncate(response, max_length=TEST_TEXT_LENGTH)
|
||||
serialized = serialize(response, max_length=TEST_TEXT_LENGTH)
|
||||
|
||||
# Check all fields are present
|
||||
assert "outputs" in serialized
|
||||
|
|
@ -243,8 +244,8 @@ def test_result_data_response_combined_fields(outputs_dict, log_messages):
|
|||
# Check message truncation
|
||||
message = serialized_output["message"]
|
||||
if len(value) > TEST_TEXT_LENGTH:
|
||||
assert len(message) <= TEST_TEXT_LENGTH + len("... [truncated]")
|
||||
assert "... [truncated]" in message
|
||||
assert len(message) <= TEST_TEXT_LENGTH + len("...")
|
||||
assert "..." in message
|
||||
else:
|
||||
assert message == value
|
||||
|
||||
|
|
@ -260,8 +261,8 @@ def test_result_data_response_combined_fields(outputs_dict, log_messages):
|
|||
# Check message truncation
|
||||
message = serialized_log["message"]
|
||||
if len(log_msg) > TEST_TEXT_LENGTH:
|
||||
assert len(message) <= TEST_TEXT_LENGTH + len("... [truncated]")
|
||||
assert "... [truncated]" in message
|
||||
assert len(message) <= TEST_TEXT_LENGTH + len("...")
|
||||
assert "..." in message
|
||||
else:
|
||||
assert message == log_msg
|
||||
|
||||
|
|
@ -311,6 +312,6 @@ def test_vertex_build_response_with_long_data(long_string):
|
|||
)
|
||||
|
||||
response.model_dump()
|
||||
truncated = result_data._serialize_and_truncate(long_string, max_length=TEST_TEXT_LENGTH)
|
||||
assert len(truncated) <= TEST_TEXT_LENGTH + len("... [truncated]")
|
||||
assert "... [truncated]" in truncated
|
||||
truncated = serialize(long_string, max_length=TEST_TEXT_LENGTH)
|
||||
assert len(truncated) <= TEST_TEXT_LENGTH + len("...")
|
||||
assert "..." in truncated
|
||||
|
|
|
|||
0
src/backend/tests/unit/serialization/__init__.py
Normal file
0
src/backend/tests/unit/serialization/__init__.py
Normal file
344
src/backend/tests/unit/serialization/test_serialization.py
Normal file
344
src/backend/tests/unit/serialization/test_serialization.py
Normal file
|
|
@ -0,0 +1,344 @@
|
|||
import math
|
||||
from datetime import datetime, timezone
|
||||
from typing import Any
|
||||
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
from hypothesis import given, settings
|
||||
from hypothesis import strategies as st
|
||||
from langchain_core.documents import Document
|
||||
from langflow.serialization.constants import MAX_ITEMS_LENGTH, MAX_TEXT_LENGTH
|
||||
from langflow.serialization.serialization import serialize, serialize_or_str
|
||||
from pydantic import BaseModel as PydanticBaseModel
|
||||
from pydantic.v1 import BaseModel as PydanticV1BaseModel
|
||||
|
||||
# Hypothesis strategies shared by the property tests in this module.
# Text, bytes and list sizes go up to three times the serializer limits so
# that inputs on both sides of each limit are generated.
_TEXT_UPPER_BOUND = MAX_TEXT_LENGTH * 3
_ITEMS_UPPER_BOUND = MAX_ITEMS_LENGTH * 3

text_strategy = st.text(min_size=0, max_size=_TEXT_UPPER_BOUND)
bytes_strategy = st.binary(min_size=0, max_size=_TEXT_UPPER_BOUND)
datetime_strategy = st.datetimes(
    min_value=datetime.min,
    max_value=datetime.max,
    # Both timezone-aware (UTC) and naive datetimes are generated.
    timezones=st.sampled_from([timezone.utc, None]),
)
decimal_strategy = st.decimals(
    min_value=-1e6,
    max_value=1e6,
    allow_nan=False,
    allow_infinity=False,
    places=10,
)
uuid_strategy = st.uuids()
list_strategy = st.lists(
    st.one_of(st.integers(), st.text(), st.floats()),
    min_size=0,
    max_size=_ITEMS_UPPER_BOUND,
)
dict_strategy = st.dictionaries(
    keys=st.text(min_size=1),
    values=st.one_of(st.integers(), st.floats(), st.text(), st.booleans(), st.none()),
    min_size=0,
    max_size=MAX_ITEMS_LENGTH,
)


def _extend_with_containers(children):
    """Wrap ``children`` in either a list or a text-keyed dictionary."""
    return st.one_of(st.lists(children), st.dictionaries(st.text(), children))


# Arbitrarily nested lists/dicts whose leaves are scalars.
nested_strategy = st.recursive(
    st.one_of(st.integers(), st.floats(), st.text(), st.booleans()),
    _extend_with_containers,
    max_leaves=10,
)
|
||||
|
||||
|
||||
# Pydantic models for testing
|
||||
class ModernModel(PydanticBaseModel):
    """Two-field fixture model built on ``pydantic.BaseModel``."""

    name: str
    value: int
|
||||
|
||||
|
||||
class LegacyModel(PydanticV1BaseModel):
    """Two-field fixture model built on ``pydantic.v1.BaseModel``."""

    name: str
    value: int
|
||||
|
||||
|
||||
class TestSerializationHypothesis:
    """Property-based and example-based checks for ``serialize``."""

    @settings(max_examples=100)
    @given(text=text_strategy)
    def test_string_serialization(self, text: str) -> None:
        """Text above ``MAX_TEXT_LENGTH`` is cut and suffixed; shorter text is unchanged."""
        serialized: str = serialize(text)
        if len(text) <= MAX_TEXT_LENGTH:
            assert serialized == text
            return
        assert serialized == text[:MAX_TEXT_LENGTH] + "..."

    @settings(max_examples=100)
    @given(data=bytes_strategy)
    def test_bytes_serialization(self, data: bytes) -> None:
        """Bytes are compared against their lenient UTF-8 decoding, truncated like text."""
        serialized: str = serialize(data)
        as_text: str = data.decode("utf-8", errors="ignore")
        if len(as_text) <= MAX_TEXT_LENGTH:
            assert serialized == as_text
            return
        assert serialized == as_text[:MAX_TEXT_LENGTH] + "..."

    @settings(max_examples=100)
    @given(dt=datetime_strategy)
    def test_datetime_serialization(self, dt: datetime) -> None:
        """Datetimes are expected as the ISO string of the value with tzinfo set to UTC."""
        serialized: str = serialize(dt)
        expected = dt.replace(tzinfo=timezone.utc).isoformat()
        assert serialized == expected

    @settings(max_examples=100)
    @given(dec=decimal_strategy)
    def test_decimal_serialization(self, dec) -> None:
        """Decimals are expected to equal their ``float`` conversion."""
        serialized: float = serialize(dec)
        assert serialized == float(dec)

    @settings(max_examples=100)
    @given(uid=uuid_strategy)
    def test_uuid_serialization(self, uid) -> None:
        """UUIDs are expected to equal their ``str`` form."""
        serialized: str = serialize(uid)
        assert serialized == str(uid)

    @settings(max_examples=100)
    @given(lst=list_strategy)
    def test_list_truncation(self, lst: list) -> None:
        """Over-long lists gain one trailing marker entry; shorter lists are unchanged."""
        serialized: list = serialize(lst)
        if len(lst) <= MAX_ITEMS_LENGTH:
            assert serialized == lst
            return
        assert len(serialized) == MAX_ITEMS_LENGTH + 1
        assert f"... [truncated {len(lst) - MAX_ITEMS_LENGTH} items]" in serialized

    @settings(max_examples=100)
    @given(dct=dict_strategy)
    def test_dict_serialization(self, dct: dict) -> None:
        """Dict output keeps string keys and scalar (or ``None``) values."""
        serialized: dict = serialize(dct)
        assert isinstance(serialized, dict)
        scalar_types = (int, float, str, bool, type(None))
        for key, item in serialized.items():
            assert isinstance(key, str)
            assert isinstance(item, scalar_types)

    @settings(max_examples=100)
    @given(value=st.integers())
    def test_pydantic_modern_model(self, value: int) -> None:
        """A ``ModernModel`` instance is expected as a plain dict of its fields."""
        instance: ModernModel = ModernModel(name="test", value=value)
        serialized: dict = serialize(instance)
        assert serialized == {"name": "test", "value": value}

    @settings(max_examples=100)
    @given(value=st.integers())
    def test_pydantic_v1_model(self, value: int) -> None:
        """A ``LegacyModel`` instance is expected as a plain dict of its fields."""
        instance: LegacyModel = LegacyModel(name="test", value=value)
        serialized: dict = serialize(instance)
        assert serialized == {"name": "test", "value": value}

    def test_async_iterator_handling(self) -> None:
        """An un-iterated async generator is expected as a fixed placeholder string."""

        async def stream():
            yield 1
            yield 2

        pending = stream()
        serialized: str = serialize(pending)
        assert serialized == "Unconsumed Stream"

    @settings(max_examples=100)
    @given(data=st.one_of(st.integers(), st.floats(allow_nan=True), st.booleans(), st.none()))
    def test_primitive_types(self, data: float | bool | None) -> None:
        """Primitives round-trip unchanged; NaN is compared with ``math.isnan``."""
        serialized: int | float | bool | None = serialize(data)
        input_is_nan = isinstance(data, float) and math.isnan(data)
        if input_is_nan and isinstance(serialized, float):
            # NaN never compares equal to itself, so check it explicitly.
            assert math.isnan(serialized)
        else:
            assert serialized == data

    @settings(max_examples=100)
    @given(nested=nested_strategy)
    def test_nested_structures(self, nested: Any) -> None:
        """Nested input yields a list, dict or scalar."""
        serialized: list | dict | int | float | str | bool = serialize(nested)
        assert isinstance(serialized, (list, dict, int, float, str, bool))

    @settings(max_examples=100)
    @given(text=text_strategy)
    def test_max_length_none(self, text: str) -> None:
        """``max_length=None`` leaves text of any generated size unchanged."""
        serialized: str = serialize(text, max_length=None)
        assert serialized == text

    @settings(max_examples=100)
    @given(lst=list_strategy)
    def test_max_items_none(self, lst: list) -> None:
        """``max_items=None`` leaves lists of any generated size unchanged."""
        serialized: list = serialize(lst, max_items=None)
        assert serialized == lst

    @settings(max_examples=100)
    @given(obj=st.builds(object))
    def test_fallback_serialization(self, obj: object) -> None:
        """``serialize_or_str`` returns a string containing ``str(obj)`` for bare objects."""
        rendered: str = serialize_or_str(obj)
        assert isinstance(rendered, str)
        assert str(obj) in rendered

    def test_document_serialization(self) -> None:
        """A LangChain ``Document`` is expected as a dict carrying its data under ``kwargs``."""
        doc: Document = Document(page_content="test", metadata={"source": "test"})
        serialized: dict = serialize(doc)
        assert isinstance(serialized, dict)
        assert "kwargs" in serialized
        payload = serialized["kwargs"]
        assert "page_content" in payload
        assert payload["page_content"] == "test"
        assert "metadata" in payload
        assert payload["metadata"] == {"source": "test"}

    def test_class_serialization(self) -> None:
        """A plain class object is expected as its ``str`` form."""

        class Holder:
            def __init__(self, value: Any) -> None:
                self.value = value

        serialized: str = serialize(Holder)
        assert serialized == str(Holder)

    def test_instance_serialization(self) -> None:
        """An instance of a plain class is expected as its ``str`` form."""

        class Holder:
            def __init__(self, value: int) -> None:
                self.value = value

        holder: Holder = Holder(42)
        serialized: str = serialize(holder)
        assert serialized == str(holder)

    def test_pydantic_class_serialization(self) -> None:
        """A Pydantic model class (not an instance) is expected as its ``repr``."""
        serialized: str = serialize(ModernModel)
        assert serialized == repr(ModernModel)

    def test_builtin_type_serialization(self) -> None:
        """A builtin type object is expected as its ``repr``."""
        serialized: str = serialize(int)
        assert serialized == repr(int)

    def test_none_serialization(self) -> None:
        """``None`` is expected back as ``None``."""
        serialized: None = serialize(None)
        assert serialized is None

    def test_custom_type_serialization(self) -> None:
        """A ``TypeVar`` is expected as its ``repr``."""
        from typing import TypeVar

        type_var = TypeVar("T")
        serialized: str = serialize(type_var)
        assert serialized == repr(type_var)

    def test_nested_class_serialization(self) -> None:
        """A class defined inside another class is expected as its ``str`` form."""

        class Outer:
            class Inner:
                pass

        inner_cls = Outer.Inner
        serialized: str = serialize(inner_cls)
        assert serialized == str(inner_cls)

    def test_enum_serialization(self) -> None:
        """An enum member is expected as ``"<EnumName>.<MEMBER>"``."""
        from enum import Enum

        # The class name is part of the expected string below; do not rename it.
        class TestEnum(Enum):
            A = 1
            B = 2

        serialized: str = serialize(TestEnum.A)
        assert serialized == "TestEnum.A"

    def test_type_alias_serialization(self) -> None:
        """A parameterized builtin generic is expected as its ``repr``."""
        int_list = list[int]
        serialized: str = serialize(int_list)
        assert serialized == repr(int_list)

    def test_generic_type_serialization(self) -> None:
        """A parameterized user-defined generic is expected as its ``repr``."""
        from typing import Generic, TypeVar

        T = TypeVar("T")  # noqa: N806

        class Box(Generic[T]):
            pass

        boxed_int = Box[int]
        serialized: str = serialize(boxed_int)
        assert serialized == repr(boxed_int)

    def test_numpy_int64_serialization(self) -> None:
        """Test serialization of numpy.int64 values."""
        serialized = serialize(np.int64(42))
        assert serialized == 42
        assert isinstance(serialized, int)

    def test_numpy_numeric_serialization(self) -> None:
        """Test serialization of various numpy numeric types."""

        def expect(raw: Any, expected: Any, expected_type: type) -> None:
            # Two separate calls: one checks the value, one checks the type.
            assert serialize(raw) == expected
            assert isinstance(serialize(raw), expected_type)

        # Signed and unsigned integers
        expect(np.int64(42), 42, int)
        expect(np.uint64(42), 42, int)

        # Double-precision floats
        expect(np.float64(3.14), 3.14, float)

        # float32 cannot represent 3.14 exactly, so compare with a tolerance
        single_precision = serialize(np.float32(3.14))
        assert isinstance(single_precision, float)
        assert abs(single_precision - 3.14) < 1e-6

        # Booleans must come back as the builtin singleton
        assert serialize(np.bool_(True)) is True  # noqa: FBT003
        assert isinstance(serialize(np.bool_(True)), bool)  # noqa: FBT003

        # Complex numbers
        as_complex = serialize(np.complex64(1 + 2j))
        assert isinstance(as_complex, complex)
        assert abs(as_complex - (1 + 2j)) < 1e-6

        # Strings
        expect(np.str_("hello"), "hello", str)

        # Bytes are expected back as text
        expect(np.bytes_(b"world"), "world", str)

        # Unicode
        expect(np.str_("unicode"), "unicode", str)

        # Elements taken from an object-dtype array
        obj_array = np.array([1, "two", 3.0], dtype=object)
        expectations = ((1, int), ("two", str), (3.0, float))
        for index, (expected, expected_type) in enumerate(expectations):
            element = serialize(obj_array[index])
            assert element == expected
            assert isinstance(element, expected_type)

    def test_pandas_serialization(self) -> None:
        """Test serialization of pandas DataFrame."""
        # A small frame is expected as one dict per row
        frame = pd.DataFrame({"A": [1, 2, 3], "B": ["a", "b", "c"], "C": [1.1, 2.2, 3.3]})
        records = serialize(frame)
        assert isinstance(records, list)  # DataFrame is serialized to list of records
        assert len(records) == 3
        assert all(isinstance(row, dict) for row in records)
        assert all({"A", "B", "C"} <= row.keys() for row in records)
        assert records[0] == {"A": 1, "B": "a", "C": 1.1}

        # A frame longer than the limit is cut to exactly ``max_items`` rows
        long_frame = pd.DataFrame({"A": range(MAX_ITEMS_LENGTH + 100)})
        records = serialize(long_frame, max_items=MAX_ITEMS_LENGTH)
        assert isinstance(records, list)
        assert len(records) == MAX_ITEMS_LENGTH
        assert all("A" in row for row in records)

    def test_series_serialization(self) -> None:
        """Test serialization of pandas Series."""
        series = pd.Series([1, 2, 3], name="test")
        mapping = serialize(series)
        assert isinstance(mapping, dict)
        assert len(mapping) == 3
        assert all(isinstance(item, int) for item in mapping.values())

    def test_series_truncation(self) -> None:
        """Test truncation of pandas Series."""
        long_series = pd.Series(range(MAX_ITEMS_LENGTH + 100), name="test_long")
        mapping = serialize(long_series, max_items=MAX_ITEMS_LENGTH)
        assert isinstance(mapping, dict)
        assert len(mapping) == MAX_ITEMS_LENGTH
        assert all(isinstance(item, int) for item in mapping.values())
|
||||
|
|
@ -1,7 +1,7 @@
|
|||
import math
|
||||
|
||||
import pytest
|
||||
from langflow.utils.constants import MAX_TEXT_LENGTH
|
||||
from langflow.serialization.constants import MAX_TEXT_LENGTH
|
||||
from langflow.utils.util_strings import truncate_long_strings
|
||||
|
||||
|
||||
|
|
|
|||
|
|
@ -1,5 +1,5 @@
|
|||
import pytest
|
||||
from langflow.utils.constants import MAX_TEXT_LENGTH
|
||||
from langflow.serialization.constants import MAX_TEXT_LENGTH
|
||||
from langflow.utils.util_strings import truncate_long_strings
|
||||
|
||||
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue