From c8d04af53386b90ac63fe0367f176fabe4a481a8 Mon Sep 17 00:00:00 2001 From: Gabriel Luiz Freitas Almeida Date: Mon, 15 Jul 2024 15:37:10 -0300 Subject: [PATCH] refactor(CustomComponent): make initialization separate from constructor (#2704) --- .../custom/custom_component/component.py | 4 +++- .../custom/custom_component/custom_component.py | 17 ++++++++++++++++- .../langflow/interface/initialize/loading.py | 7 +++---- src/backend/base/langflow/utils/validate.py | 8 ++++---- 4 files changed, 26 insertions(+), 10 deletions(-) diff --git a/src/backend/base/langflow/custom/custom_component/component.py b/src/backend/base/langflow/custom/custom_component/component.py index c34891984..aeddc6093 100644 --- a/src/backend/base/langflow/custom/custom_component/component.py +++ b/src/backend/base/langflow/custom/custom_component/component.py @@ -40,14 +40,16 @@ class Component(CustomComponent): code_class_base_inheritance: ClassVar[str] = "Component" def __init__(self, **data): - super().__init__(**data) self._inputs: dict[str, InputTypes] = {} self._results: dict[str, Any] = {} self._attributes: dict[str, Any] = {} + self._parameters: dict[str, Any] = {} + super().__init__(**data) if not hasattr(self, "trace_type"): self.trace_type = "chain" if self.inputs is not None: self.map_inputs(self.inputs) + self.set_attributes(self._parameters) def __getattr__(self, name: str) -> Any: if "_attributes" in self.__dict__ and name in self.__dict__["_attributes"]: diff --git a/src/backend/base/langflow/custom/custom_component/custom_component.py b/src/backend/base/langflow/custom/custom_component/custom_component.py index eca85c1c8..a3b4c57c7 100644 --- a/src/backend/base/langflow/custom/custom_component/custom_component.py +++ b/src/backend/base/langflow/custom/custom_component/custom_component.py @@ -14,7 +14,7 @@ from langflow.schema.artifact import get_artifact_type from langflow.schema.dotdict import dotdict from langflow.schema.log import LoggableType from langflow.schema.schema import OutputLog -from langflow.services.deps import get_storage_service, get_variable_service, session_scope +from langflow.services.deps import get_storage_service, get_tracing_service, get_variable_service, session_scope from langflow.services.storage.service import StorageService from langflow.services.tracing.schema import Log from langflow.template.utils import update_frontend_node_with_template_values @@ -88,6 +88,21 @@ class CustomComponent(BaseComponent): _logs: List[Log] = [] tracing_service: Optional["TracingService"] = None + def set_attributes(self, parameters: dict): + pass + + def set_parameters(self, parameters: dict): + self._parameters = parameters + self.set_attributes(self._parameters) + + @classmethod + def initialize(cls, **kwargs): + user_id = kwargs.pop("user_id", None) + vertex = kwargs.pop("vertex", None) + tracing_service = kwargs.pop("tracing_service", get_tracing_service()) + params_copy = kwargs.copy() + return cls(user_id=user_id, _parameters=params_copy, vertex=vertex, tracing_service=tracing_service) + @property def trace_name(self): return f"{self.display_name} ({self.vertex.id})" diff --git a/src/backend/base/langflow/interface/initialize/loading.py b/src/backend/base/langflow/interface/initialize/loading.py index 3e6a26962..ec67649a2 100644 --- a/src/backend/base/langflow/interface/initialize/loading.py +++ b/src/backend/base/langflow/interface/initialize/loading.py @@ -57,7 +57,7 @@ async def build_component_and_get_results( params_copy = params.copy() # Remove code from params class_object: Type["CustomComponent" | "Component"] = eval_custom_component_code(params_copy.pop("code")) - custom_component: "CustomComponent" | "Component" = class_object( + custom_component: "CustomComponent" | "Component" = class_object.initialize( user_id=user_id, parameters=params_copy, vertex=vertex, @@ -66,12 +66,13 @@ async def build_component_and_get_results( params_copy = update_params_with_load_from_db_fields( custom_component, params_copy, vertex.load_from_db_fields, fallback_to_env_vars ) + custom_component.set_parameters(params_copy) with warnings.catch_warnings(): warnings.filterwarnings("ignore", category=PydanticDeprecatedSince20) if base_type == "custom_components" and isinstance(custom_component, CustomComponent): return await build_custom_component(params=params_copy, custom_component=custom_component) elif base_type == "component" and isinstance(custom_component, Component): - return await build_component(params=params_copy, custom_component=custom_component) + return await build_component(custom_component=custom_component) else: raise ValueError(f"Base type {base_type} not found.") @@ -146,11 +147,9 @@ def update_params_with_load_from_db_fields( async def build_component( - params: dict, custom_component: "Component", ): # Now set the params as attributes of the custom_component - custom_component.set_attributes(params) build_results, artifacts = await custom_component.build_results() return custom_component, build_results, artifacts diff --git a/src/backend/base/langflow/utils/validate.py b/src/backend/base/langflow/utils/validate.py index d3edef561..30358eb48 100644 --- a/src/backend/base/langflow/utils/validate.py +++ b/src/backend/base/langflow/utils/validate.py @@ -250,17 +250,17 @@ def build_class_constructor(compiled_class, exec_globals, class_name): exec_globals[class_name] = locals()[class_name] # Return a function that imports necessary modules and creates an instance of the target class - def build_custom_class(*args, **kwargs): + def build_custom_class(): for module_name, module in exec_globals.items(): if isinstance(module, type(importlib)): globals()[module_name] = module - instance = exec_globals[class_name](*args, **kwargs) + exec_globals[class_name] - return instance + return exec_globals[class_name] build_custom_class.__globals__.update(exec_globals) - return build_custom_class + return build_custom_class() def get_default_imports(code_string):