From b0432c380238b89dce709368417832b77e7b01d5 Mon Sep 17 00:00:00 2001 From: ogabrielluiz Date: Thu, 30 May 2024 23:05:41 -0300 Subject: [PATCH] feat: Add set_attributes method to CustomComponent This commit adds the set_attributes method to the CustomComponent class in the custom_component.py file. The set_attributes method allows for setting attributes of the CustomComponent instance based on a dictionary of key-value pairs. This change enhances the flexibility and configurability of the CustomComponent class. --- .../custom_component/custom_component.py | 6 +++++ .../base/langflow/graph/vertex/base.py | 9 +++++-- .../langflow/interface/initialize/loading.py | 25 +++++++++++++++++-- 3 files changed, 36 insertions(+), 4 deletions(-) diff --git a/src/backend/base/langflow/custom/custom_component/custom_component.py b/src/backend/base/langflow/custom/custom_component/custom_component.py index 535bf0cf3..d23a0b71f 100644 --- a/src/backend/base/langflow/custom/custom_component/custom_component.py +++ b/src/backend/base/langflow/custom/custom_component/custom_component.py @@ -100,6 +100,12 @@ class CustomComponent(Component): build_config = {_input.name: _input.model_dump(by_alias=True, exclude_none=True) for _input in self.inputs} return build_config + def set_attributes(self, params: dict): + for key, value in params.items(): + if key in self.__dict__: + raise ValueError(f"Key {key} already exists in {self.__class__.__name__}") + setattr(self, key, value) + def update_state(self, name: str, value: Any): if not self.vertex: raise ValueError("Vertex is not set") diff --git a/src/backend/base/langflow/graph/vertex/base.py b/src/backend/base/langflow/graph/vertex/base.py index 77b4e7192..cae5ac976 100644 --- a/src/backend/base/langflow/graph/vertex/base.py +++ b/src/backend/base/langflow/graph/vertex/base.py @@ -73,6 +73,7 @@ class Vertex: self.parent_is_top_level = False self.layer = None self.result: Optional[ResultData] = None + self.results: Dict[str, Any] = {} try: self.is_interface_component = self.vertex_type in InterfaceComponentTypes except ValueError: @@ -82,6 +83,9 @@ class Vertex: self.build_times: List[float] = [] self.state = VertexStates.ACTIVE + def add_result(self, name: str, result: Any): + self.results[name] = result + def update_graph_state(self, key, new_state, append: bool): if append: self.graph.append_state(key, new_state, caller=self.id) @@ -196,7 +200,7 @@ class Vertex: def _parse_data(self) -> None: self.data = self._data["data"] - self.output = self.data["node"]["base_classes"] + self.outputs = self.data["node"]["outputs"] self.display_name = self.data["node"].get("display_name", self.id.split("-")[0]) self.description = self.data["node"].get("description", "") @@ -224,7 +228,8 @@ class Vertex: template_dict = self.data["node"]["template"] self.vertex_type = ( self.data["type"] - if "Tool" not in self.output or template_dict["_type"].islower() + if "Tool" not in [type_ for out in self.outputs for type_ in out["types"]] + or template_dict["_type"].islower() else template_dict["_type"] ) diff --git a/src/backend/base/langflow/interface/initialize/loading.py b/src/backend/base/langflow/interface/initialize/loading.py index 03de827b3..098ff8a33 100644 --- a/src/backend/base/langflow/interface/initialize/loading.py +++ b/src/backend/base/langflow/interface/initialize/loading.py @@ -1,7 +1,7 @@ import inspect import json import os -from typing import TYPE_CHECKING, Any, Type +from typing import TYPE_CHECKING, Any, Awaitable, Callable, Type import orjson from loguru import logger @@ -94,7 +94,9 @@ def update_params_with_load_from_db_fields( return params -async def instantiate_custom_component(params, user_id, vertex, fallback_to_env_vars: bool = False): +async def instantiate_custom_component( + params: dict, user_id: str, vertex: "Vertex", fallback_to_env_vars: bool = False +): params_copy = params.copy() class_object: Type["CustomComponent"] = eval_custom_component_code(params_copy.pop("code")) custom_component: "CustomComponent" = class_object( @@ -107,12 +109,31 @@ async def instantiate_custom_component(params, user_id, vertex, fallback_to_env_ custom_component, params_copy, vertex.load_from_db_fields, fallback_to_env_vars ) + # Now set the params as attributes of the custom_component + custom_component.set_attributes(params_copy) + if "retriever" in params_copy and hasattr(params_copy["retriever"], "as_retriever"): params_copy["retriever"] = params_copy["retriever"].as_retriever() # Determine if the build method is asynchronous is_async = inspect.iscoroutinefunction(custom_component.build) + # New feature: the component has a list of outputs and we have + # to check the vertex.edges to see which is connected (coulb be multiple) + # and then we'll get the output which has the name of the method we should call. + # the methods don't require any params because they are already set in the custom_component + # so we can just call them + + if hasattr(custom_component, "outputs"): + for output in custom_component.outputs: + if output.name in vertex.edges: + method: Callable | Awaitable = getattr(custom_component, output.method) + result = method() + # If the method is asynchronous, we need to await it + if inspect.iscoroutinefunction(method): + result = await result + vertex.add_result(output.name, result) + if is_async: # Await the build method directly if it's async build_result = await custom_component.build(**params_copy)