Merge branch 'zustand/io/migration' of github.com:logspace-ai/langflow into zustand/io/migration
This commit is contained in:
commit
1676d98d1d
30 changed files with 367 additions and 73 deletions
26
poetry.lock
generated
26
poetry.lock
generated
|
|
@ -6040,13 +6040,13 @@ xmp = ["defusedxml"]
|
|||
|
||||
[[package]]
|
||||
name = "pinecone-client"
|
||||
version = "3.2.1"
|
||||
version = "3.2.2"
|
||||
description = "Pinecone client and SDK"
|
||||
optional = false
|
||||
python-versions = "<4.0,>=3.8"
|
||||
files = [
|
||||
{file = "pinecone_client-3.2.1-py3-none-any.whl", hash = "sha256:e3e7983762509235250b9bcd543ec6283b7dffaed2e899f1631327f2b77859e3"},
|
||||
{file = "pinecone_client-3.2.1.tar.gz", hash = "sha256:8560ffafb13b9c45a92eb9eb77a2db32d5a1fa7903a1db17f7af58ee1058bb60"},
|
||||
{file = "pinecone_client-3.2.2-py3-none-any.whl", hash = "sha256:7e492fdda23c73726bc0cb94c689bb950d06fb94e82b701a0c610c2e830db327"},
|
||||
{file = "pinecone_client-3.2.2.tar.gz", hash = "sha256:887a12405f90ac11c396490f605fc479f31cf282361034d1ae0fccc02ac75bee"},
|
||||
]
|
||||
|
||||
[package.dependencies]
|
||||
|
|
@ -7100,6 +7100,24 @@ tomli = {version = ">=1", markers = "python_version < \"3.11\""}
|
|||
[package.extras]
|
||||
testing = ["argcomplete", "attrs (>=19.2)", "hypothesis (>=3.56)", "mock", "pygments (>=2.7.2)", "requests", "setuptools", "xmlschema"]
|
||||
|
||||
[[package]]
|
||||
name = "pytest-asyncio"
|
||||
version = "0.23.6"
|
||||
description = "Pytest support for asyncio"
|
||||
optional = false
|
||||
python-versions = ">=3.8"
|
||||
files = [
|
||||
{file = "pytest-asyncio-0.23.6.tar.gz", hash = "sha256:ffe523a89c1c222598c76856e76852b787504ddb72dd5d9b6617ffa8aa2cde5f"},
|
||||
{file = "pytest_asyncio-0.23.6-py3-none-any.whl", hash = "sha256:68516fdd1018ac57b846c9846b954f0393b26f094764a28c955eabb0536a4e8a"},
|
||||
]
|
||||
|
||||
[package.dependencies]
|
||||
pytest = ">=7.0.0,<9"
|
||||
|
||||
[package.extras]
|
||||
docs = ["sphinx (>=5.3)", "sphinx-rtd-theme (>=1.0)"]
|
||||
testing = ["coverage (>=6.2)", "hypothesis (>=5.7.1)"]
|
||||
|
||||
[[package]]
|
||||
name = "pytest-cov"
|
||||
version = "4.1.0"
|
||||
|
|
@ -10222,4 +10240,4 @@ local = ["ctransformers", "llama-cpp-python", "sentence-transformers"]
|
|||
[metadata]
|
||||
lock-version = "2.0"
|
||||
python-versions = ">=3.10,<3.12"
|
||||
content-hash = "c259c7ed43f5f6c5e211780da445a10cf1863979b87fb2c6991b1090fbb1fb4a"
|
||||
content-hash = "b66acb0ed04e62c9f311828307ac1503bc7a19912753c217d4ea6237f474543a"
|
||||
|
|
|
|||
|
|
@ -105,6 +105,7 @@ types-pywin32 = "^306.0.0.4"
|
|||
types-google-cloud-ndb = "^2.2.0.0"
|
||||
pytest-sugar = "^1.0.0"
|
||||
pytest-instafail = "^0.5.0"
|
||||
pytest-asyncio = "^0.23.0"
|
||||
respx = "^0.20.2"
|
||||
|
||||
|
||||
|
|
|
|||
|
|
@ -45,7 +45,7 @@ def record_to_string(record: Record) -> str:
|
|||
Returns:
|
||||
str: The record as a string.
|
||||
"""
|
||||
return record.text
|
||||
return record.get_text()
|
||||
|
||||
|
||||
def document_to_string(document: Document) -> str:
|
||||
|
|
|
|||
|
|
@ -7,12 +7,8 @@ from langflow.interface.custom.custom_component import CustomComponent
|
|||
|
||||
|
||||
class AmazonBedrockEmeddingsComponent(CustomComponent):
|
||||
"""
|
||||
A custom component for implementing an Embeddings Model using Amazon Bedrock.
|
||||
"""
|
||||
|
||||
display_name: str = "Amazon Bedrock Embeddings"
|
||||
description: str = "Embeddings model from Amazon Bedrock."
|
||||
description: str = "Generate embeddings using Amazon Bedrock models."
|
||||
documentation = "https://python.langchain.com/docs/modules/data_connection/text_embedding/integrations/bedrock"
|
||||
|
||||
def build_config(self):
|
||||
|
|
|
|||
|
|
@ -5,8 +5,8 @@ from langflow.interface.custom.custom_component import CustomComponent
|
|||
|
||||
|
||||
class AzureOpenAIEmbeddingsComponent(CustomComponent):
|
||||
display_name: str = "AzureOpenAIEmbeddings"
|
||||
description: str = "Embeddings model from Azure OpenAI."
|
||||
display_name: str = "Azure OpenAI Embeddings"
|
||||
description: str = "Generate embeddings using Azure OpenAI models."
|
||||
documentation: str = "https://python.langchain.com/docs/integrations/text_embedding/azureopenai"
|
||||
beta = False
|
||||
icon = "Azure"
|
||||
|
|
|
|||
|
|
@ -6,8 +6,8 @@ from langflow.custom import CustomComponent
|
|||
|
||||
|
||||
class CohereEmbeddingsComponent(CustomComponent):
|
||||
display_name = "CohereEmbeddings"
|
||||
description = "Cohere embedding models."
|
||||
display_name = "Cohere Embeddings"
|
||||
description = "Generate embeddings using Cohere models."
|
||||
|
||||
def build_config(self):
|
||||
return {
|
||||
|
|
|
|||
|
|
@ -6,8 +6,8 @@ from langflow.interface.custom.custom_component import CustomComponent
|
|||
|
||||
|
||||
class HuggingFaceEmbeddingsComponent(CustomComponent):
|
||||
display_name = "HuggingFaceEmbeddings"
|
||||
description = "HuggingFace sentence_transformers embedding models."
|
||||
display_name = "Hugging Face Embeddings"
|
||||
description = "Generate embeddings using HuggingFace models."
|
||||
documentation = (
|
||||
"https://python.langchain.com/docs/modules/data_connection/text_embedding/integrations/sentence_transformers"
|
||||
)
|
||||
|
|
|
|||
|
|
@ -7,8 +7,8 @@ from langflow.interface.custom.custom_component import CustomComponent
|
|||
|
||||
|
||||
class HuggingFaceInferenceAPIEmbeddingsComponent(CustomComponent):
|
||||
display_name = "HuggingFaceInferenceAPIEmbeddings"
|
||||
description = "HuggingFace sentence_transformers embedding models, API version."
|
||||
display_name = "Hugging Face API Embeddings"
|
||||
description = "Generate embeddings using Hugging Face Inference API models."
|
||||
documentation = "https://github.com/huggingface/text-embeddings-inference"
|
||||
icon = "HuggingFace"
|
||||
|
||||
|
|
|
|||
|
|
@ -7,12 +7,8 @@ from langflow.interface.custom.custom_component import CustomComponent
|
|||
|
||||
|
||||
class OllamaEmbeddingsComponent(CustomComponent):
|
||||
"""
|
||||
A custom component for implementing an Embeddings Model using Ollama.
|
||||
"""
|
||||
|
||||
display_name: str = "Ollama Embeddings"
|
||||
description: str = "Embeddings model from Ollama."
|
||||
description: str = "Generate embeddings using Ollama models."
|
||||
documentation = "https://python.langchain.com/docs/integrations/text_embedding/ollama"
|
||||
|
||||
def build_config(self):
|
||||
|
|
|
|||
|
|
@ -7,8 +7,8 @@ from langflow.interface.custom.custom_component import CustomComponent
|
|||
|
||||
|
||||
class OpenAIEmbeddingsComponent(CustomComponent):
|
||||
display_name = "OpenAIEmbeddings"
|
||||
description = "OpenAI embedding models"
|
||||
display_name = "OpenAI Embeddings"
|
||||
description = "Generate embeddings using OpenAI models."
|
||||
|
||||
def build_config(self):
|
||||
return {
|
||||
|
|
|
|||
|
|
@ -6,8 +6,8 @@ from langflow.interface.custom.custom_component import CustomComponent
|
|||
|
||||
|
||||
class VertexAIEmbeddingsComponent(CustomComponent):
|
||||
display_name = "VertexAIEmbeddings"
|
||||
description = "Google Cloud VertexAI embedding models."
|
||||
display_name = "VertexAI Embeddings"
|
||||
description = "Generate embeddings using Google Cloud VertexAI models."
|
||||
|
||||
def build_config(self):
|
||||
return {
|
||||
|
|
|
|||
|
|
@ -1,8 +1,8 @@
|
|||
from .CreateRecord import CreateRecordComponent
|
||||
from .CustomComponent import Component
|
||||
from .DocumentToRecord import DocumentToRecordComponent
|
||||
from .IDGenerator import UUIDGeneratorComponent
|
||||
from .MessageHistory import MessageHistoryComponent
|
||||
from .TextToRecord import TextToRecordComponent
|
||||
from .UpdateRecord import UpdateRecordComponent
|
||||
|
||||
__all__ = [
|
||||
|
|
@ -12,6 +12,6 @@ __all__ = [
|
|||
"UUIDGeneratorComponent",
|
||||
"PythonFunctionComponent",
|
||||
"RecordsToTextComponent",
|
||||
"TextToRecordComponent",
|
||||
"CreateRecordComponent",
|
||||
"MessageHistoryComponent",
|
||||
]
|
||||
|
|
|
|||
|
|
@ -6,7 +6,7 @@ from langflow.interface.custom.custom_component import CustomComponent
|
|||
|
||||
class PromptComponent(CustomComponent):
|
||||
display_name: str = "Prompt"
|
||||
description: str = "Create prompt templates with dynamic variables. Prompts can help guide the behavior of a Language Model."
|
||||
description: str = "Create a prompt template with dynamic variables."
|
||||
icon = "terminal-square"
|
||||
|
||||
def build_config(self):
|
||||
|
|
|
|||
|
|
@ -12,6 +12,17 @@ class AnthropicLLM(LCModelComponent):
|
|||
description: str = "Generate text using Anthropic Chat&Completion LLMs."
|
||||
icon = "Anthropic"
|
||||
|
||||
field_order = [
|
||||
"model",
|
||||
"anthropic_api_key",
|
||||
"max_tokens",
|
||||
"temperature",
|
||||
"anthropic_api_url",
|
||||
"input_value",
|
||||
"system_message",
|
||||
"stream",
|
||||
]
|
||||
|
||||
def build_config(self):
|
||||
return {
|
||||
"model": {
|
||||
|
|
|
|||
|
|
@ -14,6 +14,19 @@ class AzureChatOpenAIComponent(LCModelComponent):
|
|||
beta = False
|
||||
icon = "Azure"
|
||||
|
||||
field_order = [
|
||||
"model",
|
||||
"azure_endpoint",
|
||||
"azure_deployment",
|
||||
"api_version",
|
||||
"api_key",
|
||||
"temperature",
|
||||
"max_tokens",
|
||||
"input_value",
|
||||
"system_message",
|
||||
"stream",
|
||||
]
|
||||
|
||||
AZURE_OPENAI_MODELS = [
|
||||
"gpt-35-turbo",
|
||||
"gpt-35-turbo-16k",
|
||||
|
|
|
|||
|
|
@ -13,6 +13,19 @@ class QianfanChatEndpointComponent(LCModelComponent):
|
|||
documentation: str = "https://python.langchain.com/docs/integrations/chat/baidu_qianfan_endpoint."
|
||||
icon = "BaiduQianfan"
|
||||
|
||||
field_order = [
|
||||
"model",
|
||||
"qianfan_ak",
|
||||
"qianfan_sk",
|
||||
"top_p",
|
||||
"temperature",
|
||||
"penalty_score",
|
||||
"endpoint",
|
||||
"input_value",
|
||||
"system_message",
|
||||
"stream",
|
||||
]
|
||||
|
||||
def build_config(self):
|
||||
return {
|
||||
"model": {
|
||||
|
|
|
|||
|
|
@ -1,6 +1,5 @@
|
|||
from typing import Optional
|
||||
|
||||
|
||||
from langchain_community.chat_models.cohere import ChatCohere
|
||||
from pydantic.v1 import SecretStr
|
||||
|
||||
|
|
@ -15,6 +14,15 @@ class CohereComponent(LCModelComponent):
|
|||
|
||||
icon = "Cohere"
|
||||
|
||||
field_order = [
|
||||
"cohere_api_key",
|
||||
"max_tokens",
|
||||
"temperature",
|
||||
"input_value",
|
||||
"system_message",
|
||||
"stream",
|
||||
]
|
||||
|
||||
def build_config(self):
|
||||
return {
|
||||
"cohere_api_key": {
|
||||
|
|
|
|||
|
|
@ -11,7 +11,19 @@ class GoogleGenerativeAIComponent(LCModelComponent):
|
|||
display_name: str = "Google Generative AI"
|
||||
description: str = "Generate text using Google Generative AI."
|
||||
icon = "GoogleGenerativeAI"
|
||||
icon = "Google"
|
||||
|
||||
field_order = [
|
||||
"google_api_key",
|
||||
"model",
|
||||
"max_output_tokens",
|
||||
"temperature",
|
||||
"top_k",
|
||||
"top_p",
|
||||
"n",
|
||||
"input_value",
|
||||
"system_message",
|
||||
"stream",
|
||||
]
|
||||
|
||||
def build_config(self):
|
||||
return {
|
||||
|
|
|
|||
|
|
@ -12,6 +12,16 @@ class HuggingFaceEndpointsComponent(LCModelComponent):
|
|||
description: str = "Generate text using Hugging Face Inference APIs."
|
||||
icon = "HuggingFace"
|
||||
|
||||
field_order = [
|
||||
"endpoint_url",
|
||||
"task",
|
||||
"huggingfacehub_api_token",
|
||||
"model_kwargs",
|
||||
"input_value",
|
||||
"system_message",
|
||||
"stream",
|
||||
]
|
||||
|
||||
def build_config(self):
|
||||
return {
|
||||
"endpoint_url": {"display_name": "Endpoint URL", "password": True},
|
||||
|
|
|
|||
|
|
@ -17,6 +17,37 @@ class ChatOllamaComponent(LCModelComponent):
|
|||
description = "Generate text using Ollama Local LLMs."
|
||||
icon = "Ollama"
|
||||
|
||||
field_order = [
|
||||
"base_url",
|
||||
"model",
|
||||
"temperature",
|
||||
"cache",
|
||||
"callback_manager",
|
||||
"callbacks",
|
||||
"format",
|
||||
"metadata",
|
||||
"mirostat",
|
||||
"mirostat_eta",
|
||||
"mirostat_tau",
|
||||
"num_ctx",
|
||||
"num_gpu",
|
||||
"num_thread",
|
||||
"repeat_last_n",
|
||||
"repeat_penalty",
|
||||
"tfs_z",
|
||||
"timeout",
|
||||
"top_k",
|
||||
"top_p",
|
||||
"verbose",
|
||||
"tags",
|
||||
"stop",
|
||||
"system",
|
||||
"template",
|
||||
"input_value",
|
||||
"system_message",
|
||||
"stream",
|
||||
]
|
||||
|
||||
def build_config(self) -> dict:
|
||||
return {
|
||||
"base_url": {
|
||||
|
|
|
|||
|
|
@ -11,6 +11,18 @@ class OpenAIModelComponent(LCModelComponent):
|
|||
description = "Generates text using OpenAI LLMs."
|
||||
icon = "OpenAI"
|
||||
|
||||
field_order = [
|
||||
"max_tokens",
|
||||
"model_kwargs",
|
||||
"model_name",
|
||||
"openai_api_base",
|
||||
"openai_api_key",
|
||||
"temperature",
|
||||
"input_value",
|
||||
"system_message",
|
||||
"stream",
|
||||
]
|
||||
|
||||
def build_config(self):
|
||||
return {
|
||||
"input_value": {"display_name": "Input"},
|
||||
|
|
|
|||
|
|
@ -11,6 +11,22 @@ class ChatVertexAIComponent(LCModelComponent):
|
|||
description = "Generate text using Vertex AI LLMs."
|
||||
icon = "VertexAI"
|
||||
|
||||
field_order = [
|
||||
"credentials",
|
||||
"project",
|
||||
"examples",
|
||||
"location",
|
||||
"max_output_tokens",
|
||||
"model_name",
|
||||
"temperature",
|
||||
"top_k",
|
||||
"top_p",
|
||||
"verbose",
|
||||
"input_value",
|
||||
"system_message",
|
||||
"stream",
|
||||
]
|
||||
|
||||
def build_config(self):
|
||||
return {
|
||||
"credentials": {
|
||||
|
|
|
|||
|
|
@ -9,10 +9,6 @@ from langflow.schema import Record
|
|||
|
||||
|
||||
class ChromaSearchComponent(LCVectorStoreComponent):
|
||||
"""
|
||||
A custom component for implementing a Vector Store using Chroma.
|
||||
"""
|
||||
|
||||
display_name: str = "Chroma Search"
|
||||
description: str = "Search a Chroma collection for similar documents."
|
||||
icon = "Chroma"
|
||||
|
|
|
|||
|
|
@ -9,10 +9,6 @@ from langflow.schema import Record
|
|||
|
||||
|
||||
class PGVectorSearchComponent(PGVectorComponent, LCVectorStoreComponent):
|
||||
"""
|
||||
A custom component for implementing a Vector Store using PostgreSQL.
|
||||
"""
|
||||
|
||||
display_name: str = "PGVector Search"
|
||||
description: str = "Search a PGVector Store for similar documents."
|
||||
documentation = "https://python.langchain.com/docs/integrations/vectorstores/pgvector"
|
||||
|
|
|
|||
|
|
@ -125,7 +125,7 @@ class DocumentLoaderVertex(Vertex):
|
|||
# show how many documents are in the list?
|
||||
|
||||
if not isinstance(self._built_object, UnbuiltObject):
|
||||
avg_length = sum(len(record.text) for record in self._built_object if hasattr(record, "text")) / len(
|
||||
avg_length = sum(len(record.get_text()) for record in self._built_object if hasattr(record, "text")) / len(
|
||||
self._built_object
|
||||
)
|
||||
return f"""{self.display_name}({len(self._built_object)} records)
|
||||
|
|
|
|||
|
|
@ -14,11 +14,10 @@ def extract_inner_type(return_type: str) -> str:
|
|||
|
||||
def extract_inner_type_from_generic_alias(return_type: GenericAlias) -> Any:
|
||||
"""
|
||||
Extracts the inner type from a type hint that is a list.
|
||||
Extracts the inner type from a type hint that is a list or a Optional.
|
||||
"""
|
||||
if return_type.__origin__ == list:
|
||||
return list(return_type.__args__)
|
||||
|
||||
return return_type
|
||||
|
||||
|
||||
|
|
@ -36,4 +35,12 @@ def extract_union_types_from_generic_alias(return_type: GenericAlias) -> list:
|
|||
"""
|
||||
Extracts the inner type from a type hint that is a Union.
|
||||
"""
|
||||
if isinstance(return_type, list):
|
||||
return [
|
||||
_inner_arg
|
||||
for _type in return_type
|
||||
for _inner_arg in _type.__args__
|
||||
if _inner_arg not in set((Any, type(None), type(Any)))
|
||||
]
|
||||
|
||||
return list(return_type.__args__)
|
||||
|
|
|
|||
|
|
@ -34,10 +34,6 @@ class Component:
|
|||
if key == "user_id":
|
||||
setattr(self, "_user_id", value)
|
||||
else:
|
||||
if key == "code" and "from langflow import CustomComponent" in value:
|
||||
value = value.replace(
|
||||
"from langflow import CustomComponent", "from langflow.custom import CustomComponent"
|
||||
)
|
||||
setattr(self, key, value)
|
||||
|
||||
def __setattr__(self, key, value):
|
||||
|
|
|
|||
|
|
@ -249,7 +249,7 @@ class CustomComponent(Component):
|
|||
return ""
|
||||
markdown_string = "---\n"
|
||||
for record in records:
|
||||
markdown_string += f"- Text: {record.text}"
|
||||
markdown_string += f"- Text: {record.get_text()}"
|
||||
if include_data:
|
||||
markdown_string += f" Data: {record.data}"
|
||||
markdown_string += "\n"
|
||||
|
|
@ -318,9 +318,10 @@ class CustomComponent(Component):
|
|||
return_type = extract_inner_type_from_generic_alias(return_type)
|
||||
|
||||
# If the return type is not a Union, then we just return it as a list
|
||||
if not hasattr(return_type, "__origin__") or return_type.__origin__ != Union:
|
||||
inner_type = return_type[0] if isinstance(return_type, list) else return_type
|
||||
if not hasattr(inner_type, "__origin__") or inner_type.__origin__ != Union:
|
||||
return return_type if isinstance(return_type, list) else [return_type]
|
||||
# If the return type is a Union, then we need to parse itx
|
||||
# If the return type is a Union, then we need to parse it
|
||||
return_type = extract_union_types_from_generic_alias(return_type)
|
||||
return return_type
|
||||
|
||||
|
|
|
|||
|
|
@ -1,4 +1,5 @@
|
|||
import copy
|
||||
from typing import Optional
|
||||
|
||||
from langchain_core.documents import Document
|
||||
from pydantic import BaseModel, model_validator
|
||||
|
|
@ -12,8 +13,9 @@ class Record(BaseModel):
|
|||
data (dict, optional): Additional data associated with the record.
|
||||
"""
|
||||
|
||||
text_key: Optional[str] = "text"
|
||||
data: dict = {}
|
||||
_default_value: str = ""
|
||||
default_value: Optional[str] = ""
|
||||
|
||||
@model_validator(mode="before")
|
||||
def validate_data(cls, values):
|
||||
|
|
@ -21,10 +23,22 @@ class Record(BaseModel):
|
|||
values["data"] = {}
|
||||
# Any other keyword should be added to the data dictionary
|
||||
for key in values:
|
||||
if key not in values["data"] and key != "data":
|
||||
if key not in values["data"] and key not in {"text_key", "data", "default_value"}:
|
||||
values["data"][key] = values[key]
|
||||
return values
|
||||
|
||||
def get_text(self):
|
||||
"""
|
||||
Retrieves the text value from the data dictionary.
|
||||
|
||||
If the text key is present in the data dictionary, the corresponding value is returned.
|
||||
Otherwise, the default value is returned.
|
||||
|
||||
Returns:
|
||||
The text value from the data dictionary or the default value.
|
||||
"""
|
||||
return self.data.get(self.text_key, self.default_value)
|
||||
|
||||
@classmethod
|
||||
def from_document(cls, document: Document) -> "Record":
|
||||
"""
|
||||
|
|
@ -38,19 +52,27 @@ class Record(BaseModel):
|
|||
"""
|
||||
data = document.metadata
|
||||
data["text"] = document.page_content
|
||||
return cls(data=data)
|
||||
return cls(data=data, text_key="text")
|
||||
|
||||
def __add__(self, other: "Record") -> "Record":
|
||||
"""
|
||||
Concatenates the text of two records and combines their data.
|
||||
|
||||
Args:
|
||||
other (Record): The other record to concatenate with.
|
||||
|
||||
Returns:
|
||||
Record: The concatenated record.
|
||||
Combines the data of two records by attempting to add values for overlapping keys
|
||||
for all types that support the addition operation. Falls back to the value from 'other'
|
||||
record when addition is not supported.
|
||||
"""
|
||||
combined_data = {**self.data, **other.data}
|
||||
combined_data = self.data.copy()
|
||||
for key, value in other.data.items():
|
||||
# If the key exists in both records and both values support the addition operation
|
||||
if key in combined_data:
|
||||
try:
|
||||
combined_data[key] += value
|
||||
except TypeError:
|
||||
# Fallback: Use the value from 'other' record if addition is not supported
|
||||
combined_data[key] = value
|
||||
else:
|
||||
# If the key is not in the first record, simply add it
|
||||
combined_data[key] = value
|
||||
|
||||
return Record(data=combined_data)
|
||||
|
||||
def to_lc_document(self) -> Document:
|
||||
|
|
@ -60,17 +82,20 @@ class Record(BaseModel):
|
|||
Returns:
|
||||
Document: The converted Document.
|
||||
"""
|
||||
return Document(page_content=self.text, metadata=self.data)
|
||||
text = self.data.pop(self.text_key, self.default_value)
|
||||
return Document(page_content=text, metadata=self.data)
|
||||
|
||||
def __getattr__(self, key):
|
||||
"""
|
||||
Allows attribute-like access to the data dictionary.
|
||||
"""
|
||||
try:
|
||||
if key == "data" or key.startswith("_"):
|
||||
if key.startswith("__"):
|
||||
return self.__getattribute__(key)
|
||||
if key in {"data", "text_key"} or key.startswith("_"):
|
||||
return super().__getattr__(key)
|
||||
|
||||
return self.data.get(key, self._default_value)
|
||||
return self.data.get(key, self.default_value)
|
||||
except KeyError:
|
||||
# Fallback to default behavior to raise AttributeError for undefined attributes
|
||||
raise AttributeError(f"'{type(self).__name__}' object has no attribute '{key}'")
|
||||
|
|
@ -80,7 +105,7 @@ class Record(BaseModel):
|
|||
Allows attribute-like setting of values in the data dictionary,
|
||||
while still allowing direct assignment to class attributes.
|
||||
"""
|
||||
if key == "data" or key.startswith("_"):
|
||||
if key in {"data", "text_key"} or key.startswith("_"):
|
||||
super().__setattr__(key, value)
|
||||
else:
|
||||
self.data[key] = value
|
||||
|
|
@ -89,7 +114,7 @@ class Record(BaseModel):
|
|||
"""
|
||||
Allows attribute-like deletion from the data dictionary.
|
||||
"""
|
||||
if key == "data" or key.startswith("_"):
|
||||
if key in {"data", "text_key"} or key.startswith("_"):
|
||||
super().__delattr__(key)
|
||||
else:
|
||||
del self.data[key]
|
||||
|
|
@ -98,12 +123,8 @@ class Record(BaseModel):
|
|||
"""
|
||||
Custom deepcopy implementation to handle copying of the Record object.
|
||||
"""
|
||||
cls = self.__class__
|
||||
result = cls.__new__(cls)
|
||||
memo[id(self)] = result
|
||||
for k, v in self.__dict__.items():
|
||||
setattr(result, k, copy.deepcopy(v, memo))
|
||||
return result
|
||||
# Create a new Record object with a deep copy of the data dictionary
|
||||
return Record(data=copy.deepcopy(self.data, memo), text_key=self.text_key, default_value=self.default_value)
|
||||
|
||||
def __str__(self) -> str:
|
||||
"""
|
||||
|
|
@ -114,7 +135,8 @@ class Record(BaseModel):
|
|||
# build the string considering all keys in the data dictionary
|
||||
prefix = "Record("
|
||||
suffix = ")"
|
||||
text = ", ".join([f"{k}={v}" for k, v in self.data.items()])
|
||||
text = f"text_key={self.text_key}, "
|
||||
text += ", ".join([f"{k}={v}" for k, v in self.data.items()])
|
||||
return prefix + text + suffix
|
||||
|
||||
# check which attributes the Record has by checking the keys in the data dictionary
|
||||
|
|
|
|||
139
tests/test_record.py
Normal file
139
tests/test_record.py
Normal file
|
|
@ -0,0 +1,139 @@
|
|||
from langchain_core.documents import Document
|
||||
|
||||
from langflow.schema import Record
|
||||
|
||||
|
||||
def test_record_initialization():
|
||||
record = Record(text_key="msg", data={"msg": "Hello, World!", "extra": "value"})
|
||||
assert record.msg == "Hello, World!"
|
||||
assert record.extra == "value"
|
||||
|
||||
|
||||
def test_validate_data_with_extra_keys():
|
||||
record = Record(dummy_key="dummy", data={"key": "value"})
|
||||
assert record.data["dummy_key"] == "dummy"
|
||||
assert "dummy_key" in record.data
|
||||
assert record.key == "value"
|
||||
|
||||
|
||||
def test_conversion_to_document():
|
||||
record = Record(data={"text": "Sample text", "meta": "data"})
|
||||
document = record.to_lc_document()
|
||||
assert document.page_content == "Sample text"
|
||||
assert document.metadata == {"meta": "data"}
|
||||
|
||||
|
||||
def test_conversion_from_document():
|
||||
document = Document(page_content="Doc content", metadata={"meta": "info"})
|
||||
record = Record.from_document(document)
|
||||
assert record.text == "Doc content"
|
||||
assert record.meta == "info"
|
||||
|
||||
|
||||
def test_add_method_for_strings():
|
||||
record1 = Record(data={"text": "Hello"})
|
||||
record2 = Record(data={"text": " World"})
|
||||
combined = record1 + record2
|
||||
assert combined.text == "Hello World"
|
||||
|
||||
|
||||
def test_add_method_for_integers():
|
||||
record1 = Record(data={"number": 5})
|
||||
record2 = Record(data={"number": 10})
|
||||
combined = record1 + record2
|
||||
assert combined.number == 15
|
||||
|
||||
|
||||
def test_add_method_with_non_overlapping_keys():
|
||||
record1 = Record(data={"text": "Hello"})
|
||||
record2 = Record(data={"number": 10})
|
||||
combined = record1 + record2
|
||||
assert combined.text == "Hello"
|
||||
assert combined.number == 10
|
||||
|
||||
|
||||
def test_custom_attribute_get_set_del():
|
||||
record = Record()
|
||||
record.custom_attr = "custom_value"
|
||||
assert record.custom_attr == "custom_value"
|
||||
del record.custom_attr
|
||||
assert record.custom_attr == record.default_value
|
||||
|
||||
|
||||
def test_deep_copy():
|
||||
import copy
|
||||
|
||||
record1 = Record(data={"text": "Hello", "number": 10})
|
||||
record2 = copy.deepcopy(record1)
|
||||
assert record2.text == "Hello"
|
||||
assert record2.number == 10
|
||||
record2.text = "World"
|
||||
assert record1.text == "Hello" # Ensure original is unchanged
|
||||
|
||||
|
||||
def test_custom_attribute_setting_and_getting():
|
||||
record = Record()
|
||||
record.dynamic_attribute = "Dynamic Value"
|
||||
assert record.dynamic_attribute == "Dynamic Value"
|
||||
|
||||
|
||||
def test_str_and_dir_methods():
|
||||
record = Record(text_key="text", data={"text": "Test Text", "key": "value"})
|
||||
assert "Test Text" in str(record)
|
||||
assert "key" in dir(record)
|
||||
assert "data" in dir(record)
|
||||
|
||||
|
||||
def test_dir_includes_data_keys():
|
||||
record = Record(data={"text": "Hello", "new_attr": "value"})
|
||||
dir_output = dir(record)
|
||||
|
||||
# Check for standard attributes
|
||||
assert "data" in dir_output
|
||||
assert "text_key" in dir_output
|
||||
assert "__add__" in dir_output # Checking for a method
|
||||
|
||||
# Check for dynamic attributes from data
|
||||
assert "text" in dir_output
|
||||
assert "new_attr" in dir_output
|
||||
|
||||
# Optionally, verify that dynamically added attributes are listed
|
||||
record.dynamic_attr = "dynamic"
|
||||
assert "dynamic_attr" in dir_output or "dynamic_attr" in dir(record) # To account for the change
|
||||
|
||||
|
||||
def test_dir_reflects_attribute_deletion():
|
||||
record = Record(data={"removable": "I can be removed"})
|
||||
assert "removable" in dir(record)
|
||||
|
||||
# Delete the attribute and check again
|
||||
del record.removable
|
||||
assert "removable" not in dir(record)
|
||||
|
||||
|
||||
def test_get_text_with_text_key():
|
||||
data = {"text": "Hello, World!"}
|
||||
schema = Record(data=data, text_key="text", default_value="default")
|
||||
result = schema.get_text()
|
||||
assert result == "Hello, World!"
|
||||
|
||||
|
||||
def test_get_text_without_text_key():
|
||||
data = {"other_key": "Hello, World!"}
|
||||
schema = Record(data=data, text_key="text", default_value="default")
|
||||
result = schema.get_text()
|
||||
assert result == "default"
|
||||
|
||||
|
||||
def test_get_text_with_empty_data():
|
||||
data = {}
|
||||
schema = Record(data=data, text_key="text", default_value="default")
|
||||
result = schema.get_text()
|
||||
assert result == "default"
|
||||
|
||||
|
||||
def test_get_text_with_none_data():
|
||||
data = None
|
||||
schema = Record(data=data, text_key="text", default_value="default")
|
||||
result = schema.get_text()
|
||||
assert result == "default"
|
||||
Loading…
Add table
Add a link
Reference in a new issue