refactor: enhance CustomComponent class and updates tests (#3201)

* fix: update CustomComponent to use properties for user_id and flow_id

Refactored user_id and flow_id in CustomComponent to use properties for better encapsulation and code clarity

* refactor: update CustomComponent initialization and remove unused imports

Refactored the CustomComponent class to streamline initialization and removed unnecessary import for BaseCallbackHandler

* refactor: update build_custom_component_template to use cc_instance for field order to improve consistency and clarity

* refactor: update user_id parameter in FlowToolComponent to use self.user_id for consistency

* refactor: remove unused _tree attribute and clean up imports in CustomComponent for better code clarity

* refactor: rename CustomComponent to Component for consistency in directory_reader.py import and usage

* refactor: enhance timestamp handling in _timestamp_to_str for better validation and error reporting in message.py

* refactor: preserve async get_file_content_dicts method for backwards compatibility in message.py

* refactor: update function_entrypoint_name to _function_entrypoint_name for consistency in test_custom_component.py

* feat: add client fixture for improved test structure in test_data_components.py

* feat: add unit tests for PromptComponent including template processing and custom fields in test_prompt_component.py

* feat: add dev dependencies for improved testing and development tools in pyproject.toml
This commit is contained in:
Gabriel Luiz Freitas Almeida 2024-08-05 16:55:29 -03:00 committed by GitHub
commit f706b05438
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
12 changed files with 1823 additions and 89 deletions

View file

@ -489,7 +489,6 @@ async def build_vertex(
artifacts = vertex_build_result.artifacts
next_runnable_vertices = await graph.get_next_runnable_vertices(lock, vertex=vertex, cache=False)
top_level_vertices = graph.get_top_level_vertices(next_runnable_vertices)
result_data_response = ResultDataResponse.model_validate(result_dict, from_attributes=True)
except Exception as exc:
if isinstance(exc, ComponentBuildException):

View file

@ -85,7 +85,7 @@ class FlowToolComponent(LCToolComponent):
return_direct=self.return_direct,
inputs=inputs,
flow_id=str(flow_data.id),
user_id=str(self._user_id),
user_id=str(self.user_id),
)
description_repr = repr(tool.description).strip("'")
args_str = "\n".join([f"- {arg_name}: {arg_data['description']}" for arg_name, arg_data in tool.args.items()])

View file

@ -1,7 +1,7 @@
import operator
import warnings
from typing import Any, ClassVar, Optional
from uuid import UUID
import warnings
from cachetools import TTLCache, cachedmethod
from fastapi import HTTPException

View file

@ -24,11 +24,12 @@ from langflow.type_extraction.type_extraction import (
from langflow.utils import validate
if TYPE_CHECKING:
from langchain.callbacks.base import BaseCallbackHandler
from langflow.graph.graph.base import Graph
from langflow.graph.vertex.base import Vertex
from langflow.services.storage.service import StorageService
from langflow.services.tracing.service import TracingService
from langchain.callbacks.base import BaseCallbackHandler
class CustomComponent(BaseComponent):
@ -74,7 +75,7 @@ class CustomComponent(BaseComponent):
"""The build parameters of the component. Defaults to None."""
_vertex: Optional["Vertex"] = None
"""The edge target parameter of the component. Defaults to None."""
code_class_base_inheritance: ClassVar[str] = "CustomComponent"
_code_class_base_inheritance: ClassVar[str] = "CustomComponent"
function_entrypoint_name: ClassVar[str] = "build"
function: Optional[Callable] = None
repr_value: Optional[Any] = ""
@ -85,6 +86,20 @@ class CustomComponent(BaseComponent):
_logs: List[Log] = []
_output_logs: dict[str, Log] = {}
_tracing_service: Optional["TracingService"] = None
_tree: Optional[dict] = None
def __init__(self, **data):
"""
Initializes a new instance of the CustomComponent class.
Args:
**data: Additional keyword arguments to initialize the custom component.
"""
self.cache = TTLCache(maxsize=1024, ttl=60)
self._logs = []
self._results = {}
self._artifacts = {}
super().__init__(**data)
def set_attributes(self, parameters: dict):
pass
@ -133,19 +148,6 @@ class CustomComponent(BaseComponent):
except Exception as e:
raise ValueError(f"Error getting state: {e}")
_tree: Optional[dict] = None
def __init__(self, **data):
"""
Initializes a new instance of the CustomComponent class.
Args:
**data: Additional keyword arguments to initialize the custom component.
"""
self.cache = TTLCache(maxsize=1024, ttl=60)
self._logs = []
super().__init__(**data)
@staticmethod
def resolve_path(path: str) -> str:
"""Resolves the path to an absolute path."""
@ -169,6 +171,20 @@ class CustomComponent(BaseComponent):
def graph(self):
return self._vertex.graph
@property
def user_id(self):
if hasattr(self, "_user_id"):
return self._user_id
return self.graph.user_id
@property
def flow_id(self):
return self.graph.flow_id
@property
def flow_name(self):
return self.graph.flow_name
def _get_field_order(self):
return self.field_order or list(self.field_config.keys())
@ -305,7 +321,7 @@ class CustomComponent(BaseComponent):
Returns:
list: The arguments of the function entrypoint.
"""
build_method = self.get_method(self.function_entrypoint_name)
build_method = self.get_method(self._function_entrypoint_name)
if not build_method:
return []
@ -346,9 +362,9 @@ class CustomComponent(BaseComponent):
Returns:
List[Any]: The return type of the function entrypoint.
"""
return self.get_method_return_type(self.function_entrypoint_name)
return self.get_method_return_type(self._function_entrypoint_name)
def _extract_return_type(self, return_type: Any):
def _extract_return_type(self, return_type: Any) -> List[Any]:
if hasattr(return_type, "__origin__") and return_type.__origin__ in [
list,
List,
@ -374,8 +390,8 @@ class CustomComponent(BaseComponent):
if not self._code:
return ""
base_name = self.code_class_base_inheritance
method_name = self.function_entrypoint_name
base_name = self._code_class_base_inheritance
method_name = self._function_entrypoint_name
classes = []
for item in self.tree.get("classes", []):
@ -412,12 +428,12 @@ class CustomComponent(BaseComponent):
"""
def get_variable(name: str, field: str):
if hasattr(self, "_user_id") and not self._user_id:
if hasattr(self, "_user_id") and not self.user_id:
raise ValueError(f"User id is not set for {self.__class__.__name__}")
variable_service = get_variable_service() # Get service instance
# Retrieve and decrypt the variable by name for the current user
with session_scope() as session:
user_id = self._user_id or ""
user_id = self.user_id or ""
return variable_service.get_variable(user_id=user_id, name=name, field=field, session=session)
return get_variable
@ -432,12 +448,12 @@ class CustomComponent(BaseComponent):
Returns:
List[str]: The names of the variables for the current user.
"""
if hasattr(self, "_user_id") and not self._user_id:
if hasattr(self, "_user_id") and not self.user_id:
raise ValueError(f"User id is not set for {self.__class__.__name__}")
variable_service = get_variable_service()
with session_scope() as session:
return variable_service.list_variables(user_id=self._user_id, session=session)
return variable_service.list_variables(user_id=self.user_id, session=session)
def index(self, value: int = 0):
"""
@ -462,10 +478,10 @@ class CustomComponent(BaseComponent):
Returns:
Callable: The function associated with the custom component.
"""
return validate.create_function(self._code, self.function_entrypoint_name)
return validate.create_function(self._code, self._function_entrypoint_name)
async def load_flow(self, flow_id: str, tweaks: Optional[dict] = None) -> "Graph":
if not self._user_id:
if not self.user_id:
raise ValueError("Session is invalid")
return await load_flow(user_id=str(self._user_id), flow_id=flow_id, tweaks=tweaks)
@ -487,7 +503,7 @@ class CustomComponent(BaseComponent):
)
def list_flows(self) -> List[Data]:
if not self._user_id:
if not self.user_id:
raise ValueError("Session is invalid")
try:
return list_flows(user_id=str(self._user_id))

View file

@ -6,7 +6,7 @@ from pathlib import Path
from loguru import logger
from langflow.custom import CustomComponent
from langflow.custom import Component
class CustomComponentPathValueError(ValueError):
@ -373,7 +373,7 @@ class DirectoryReader:
"""
Get the output types from the code.
"""
custom_component = CustomComponent(_code=code)
custom_component = Component(_code=code)
types_list = custom_component.get_function_entrypoint_return_type
# Get the name of types classes

View file

@ -377,8 +377,8 @@ def build_custom_component_template_from_inputs(
frontend_node.validate_component()
# ! This should be removed when we have a better way to handle this
frontend_node.set_base_classes_from_outputs()
reorder_fields(frontend_node, custom_component._get_field_order())
cc_instance = get_component_instance(custom_component, user_id=user_id)
reorder_fields(frontend_node, cc_instance._get_field_order())
return frontend_node.to_dict(keep_name=False), cc_instance

View file

@ -24,9 +24,14 @@ from langflow.utils.constants import (
def _timestamp_to_str(timestamp: datetime | str) -> str:
if isinstance(timestamp, datetime):
return timestamp.strftime("%Y-%m-%d %H:%M:%S")
return timestamp
if isinstance(timestamp, str):
# Just check if the string is a valid datetime
try:
datetime.strptime(timestamp, "%Y-%m-%d %H:%M:%S")
return timestamp
except ValueError:
raise ValueError(f"Invalid timestamp: {timestamp}")
return timestamp.strftime("%Y-%m-%d %H:%M:%S")
class Message(Data):
@ -163,6 +168,7 @@ class Message(Data):
loop = asyncio.get_event_loop()
return loop.run_until_complete(coro)
# Keep this async method for backwards compatibility
async def get_file_content_dicts(self):
content_dicts = []
files = await get_file_paths(self.files)

File diff suppressed because it is too large Load diff

View file

@ -83,8 +83,38 @@ local = ["llama-cpp-python", "sentence-transformers", "ctransformers"]
all = ["deploy", "local"]
[tool.poetry.group.dev.dependencies]
types-redis = "^4.6.0.5"
ipykernel = "^6.29.0"
mypy = "^1.11.0"
ruff = "^0.4.5"
httpx = "*"
pytest = "^8.2.0"
types-requests = "^2.32.0"
requests = "^2.32.0"
pytest-cov = "^5.0.0"
pandas-stubs = "^2.1.4.231227"
types-pillow = "^10.2.0.20240213"
types-pyyaml = "^6.0.12.8"
types-python-jose = "^3.3.4.8"
types-passlib = "^1.7.7.13"
locust = "^2.23.1"
pytest-mock = "^3.14.0"
pytest-xdist = "^3.6.0"
types-pywin32 = "^306.0.0.4"
types-google-cloud-ndb = "^2.2.0.0"
pytest-sugar = "^1.0.0"
respx = "^0.21.1"
pytest-instafail = "^0.5.0"
pytest-asyncio = "^0.23.0"
pytest-profiling = "^1.7.0"
pre-commit = "^3.7.0"
vulture = "^2.11"
dictdiffer = "^0.9.0"
pytest-split = "^0.9.0"
devtools = "^0.12.2"
[tool.pytest.ini_options]
minversion = "6.0"

View file

@ -0,0 +1,19 @@
import pytest
from langflow.components.prompts.Prompt import PromptComponent # type: ignore
@pytest.fixture
def client():
pass
class TestPromptComponent:
def test_post_code_processing(self):
component = PromptComponent(template="Hello {name}!", name="John")
frontend_node = component.to_frontend_node()
node_data = frontend_node["data"]["node"]
assert node_data["template"]["template"]["value"] == "Hello {name}!"
assert "name" in node_data["custom_fields"]["template"]
assert "name" in node_data["template"]
assert node_data["template"]["name"]["value"] == "John"

View file

@ -72,16 +72,16 @@ def test_component_init():
"""
Test the initialization of the Component class.
"""
component = BaseComponent(_code=code_default, function_entrypoint_name="build")
component = BaseComponent(_code=code_default, _function_entrypoint_name="build")
assert component._code == code_default
assert component.function_entrypoint_name == "build"
assert component._function_entrypoint_name == "build"
def test_component_get_code_tree():
"""
Test the get_code_tree method of the Component class.
"""
component = BaseComponent(_code=code_default, function_entrypoint_name="build")
component = BaseComponent(_code=code_default, _function_entrypoint_name="build")
tree = component.get_code_tree(component._code)
assert "imports" in tree
@ -91,7 +91,7 @@ def test_component_code_null_error():
Test the get_function method raises the
ComponentCodeNullError when the code is empty.
"""
component = BaseComponent(_code="", function_entrypoint_name="")
component = BaseComponent(_code="", _function_entrypoint_name="")
with pytest.raises(ComponentCodeNullError):
component.get_function()
@ -102,16 +102,16 @@ def test_custom_component_init():
"""
function_entrypoint_name = "build"
custom_component = CustomComponent(_code=code_default, function_entrypoint_name=function_entrypoint_name)
custom_component = CustomComponent(_code=code_default, _function_entrypoint_name=function_entrypoint_name)
assert custom_component._code == code_default
assert custom_component.function_entrypoint_name == function_entrypoint_name
assert custom_component._function_entrypoint_name == function_entrypoint_name
def test_custom_component_build_template_config():
"""
Test the build_template_config property of the CustomComponent class.
"""
custom_component = CustomComponent(_code=code_default, function_entrypoint_name="build")
custom_component = CustomComponent(_code=code_default, _function_entrypoint_name="build")
config = custom_component.build_template_config()
assert isinstance(config, dict)
@ -120,7 +120,7 @@ def test_custom_component_get_function():
"""
Test the get_function property of the CustomComponent class.
"""
custom_component = CustomComponent(_code="def build(): pass", function_entrypoint_name="build")
custom_component = CustomComponent(_code="def build(): pass", _function_entrypoint_name="build")
my_function = custom_component.get_function()
assert isinstance(my_function, types.FunctionType)
@ -195,7 +195,7 @@ def test_component_get_function_valid():
Test the get_function method of the Component
class with valid code and function_entrypoint_name.
"""
component = BaseComponent(_code="def build(): pass", function_entrypoint_name="build")
component = BaseComponent(_code="def build(): pass", _function_entrypoint_name="build")
my_function = component.get_function()
assert callable(my_function)
@ -205,7 +205,7 @@ def test_custom_component_get_function_entrypoint_args():
Test the get_function_entrypoint_args
property of the CustomComponent class.
"""
custom_component = CustomComponent(_code=code_default, function_entrypoint_name="build")
custom_component = CustomComponent(_code=code_default, _function_entrypoint_name="build")
args = custom_component.get_function_entrypoint_args
assert len(args) == 3
assert args[0]["name"] == "self"
@ -219,7 +219,7 @@ def test_custom_component_get_function_entrypoint_return_type():
property of the CustomComponent class.
"""
custom_component = CustomComponent(_code=code_default, function_entrypoint_name="build")
custom_component = CustomComponent(_code=code_default, _function_entrypoint_name="build")
return_type = custom_component.get_function_entrypoint_return_type
assert return_type == [Document]
@ -228,7 +228,7 @@ def test_custom_component_get_main_class_name():
"""
Test the get_main_class_name property of the CustomComponent class.
"""
custom_component = CustomComponent(_code=code_default, function_entrypoint_name="build")
custom_component = CustomComponent(_code=code_default, _function_entrypoint_name="build")
class_name = custom_component.get_main_class_name
assert class_name == "YourComponent"
@ -238,7 +238,7 @@ def test_custom_component_get_function_valid():
Test the get_function property of the CustomComponent
class with valid code and function_entrypoint_name.
"""
custom_component = CustomComponent(_code="def build(): pass", function_entrypoint_name="build")
custom_component = CustomComponent(_code="def build(): pass", _function_entrypoint_name="build")
my_function = custom_component.get_function
assert callable(my_function)
@ -352,7 +352,7 @@ def test_component_get_code_tree_syntax_error():
Test the get_code_tree method of the Component class
raises the CodeSyntaxError when given incorrect syntax.
"""
component = BaseComponent(_code="import os as", function_entrypoint_name="build")
component = BaseComponent(_code="import os as", _function_entrypoint_name="build")
with pytest.raises(CodeSyntaxError):
component.get_code_tree(component._code)
@ -362,7 +362,7 @@ def test_custom_component_class_template_validation_no_code():
Test the _class_template_validation method of the CustomComponent class
raises the HTTPException when the code is None.
"""
custom_component = CustomComponent(_code=None, function_entrypoint_name="build")
custom_component = CustomComponent(_code=None, _function_entrypoint_name="build")
with pytest.raises(TypeError):
custom_component.get_function()
@ -372,7 +372,7 @@ def test_custom_component_get_code_tree_syntax_error():
Test the get_code_tree method of the CustomComponent class
raises the CodeSyntaxError when given incorrect syntax.
"""
custom_component = CustomComponent(_code="import os as", function_entrypoint_name="build")
custom_component = CustomComponent(_code="import os as", _function_entrypoint_name="build")
with pytest.raises(CodeSyntaxError):
custom_component.get_code_tree(custom_component._code)
@ -387,7 +387,7 @@ class MyMainClass(CustomComponent):
def build():
pass"""
custom_component = CustomComponent(_code=my_code, function_entrypoint_name="build")
custom_component = CustomComponent(_code=my_code, _function_entrypoint_name="build")
args = custom_component.get_function_entrypoint_args
assert len(args) == 0
@ -402,7 +402,7 @@ class MyClass(CustomComponent):
def build():
pass"""
custom_component = CustomComponent(_code=my_code, function_entrypoint_name="build")
custom_component = CustomComponent(_code=my_code, _function_entrypoint_name="build")
return_type = custom_component.get_function_entrypoint_return_type
assert return_type == []
@ -416,7 +416,7 @@ def test_custom_component_get_main_class_name_no_main_class():
def build():
pass"""
custom_component = CustomComponent(_code=my_code, function_entrypoint_name="build")
custom_component = CustomComponent(_code=my_code, _function_entrypoint_name="build")
class_name = custom_component.get_main_class_name
assert class_name == ""
@ -426,7 +426,7 @@ def test_custom_component_build_not_implemented():
Test the build method of the CustomComponent
class raises the NotImplementedError.
"""
custom_component = CustomComponent(_code="def build(): pass", function_entrypoint_name="build")
custom_component = CustomComponent(_code="def build(): pass", _function_entrypoint_name="build")
with pytest.raises(NotImplementedError):
custom_component.build()

View file

@ -12,6 +12,11 @@ from httpx import Response
from langflow.components import data
@pytest.fixture
def client():
pass
@pytest.fixture
def api_request():
# This fixture provides an instance of APIRequest for each test case