feat: Add set_attributes method to CustomComponent

This commit adds the set_attributes method to the CustomComponent class in the custom_component.py file. The set_attributes method allows for setting attributes of the CustomComponent instance based on a dictionary of key-value pairs. This change enhances the flexibility and configurability of the CustomComponent class.
This commit is contained in:
ogabrielluiz 2024-05-30 23:05:41 -03:00
commit b0432c3802
3 changed files with 36 additions and 4 deletions

View file

@ -100,6 +100,12 @@ class CustomComponent(Component):
build_config = {_input.name: _input.model_dump(by_alias=True, exclude_none=True) for _input in self.inputs}
return build_config
def set_attributes(self, params: dict):
for key, value in params.items():
if key in self.__dict__:
raise ValueError(f"Key {key} already exists in {self.__class__.__name__}")
setattr(self, key, value)
def update_state(self, name: str, value: Any):
if not self.vertex:
raise ValueError("Vertex is not set")

View file

@ -73,6 +73,7 @@ class Vertex:
self.parent_is_top_level = False
self.layer = None
self.result: Optional[ResultData] = None
self.results: Dict[str, Any] = {}
try:
self.is_interface_component = self.vertex_type in InterfaceComponentTypes
except ValueError:
@ -82,6 +83,9 @@ class Vertex:
self.build_times: List[float] = []
self.state = VertexStates.ACTIVE
def add_result(self, name: str, result: Any):
self.results[name] = result
def update_graph_state(self, key, new_state, append: bool):
if append:
self.graph.append_state(key, new_state, caller=self.id)
@ -196,7 +200,7 @@ class Vertex:
def _parse_data(self) -> None:
self.data = self._data["data"]
self.output = self.data["node"]["base_classes"]
self.outputs = self.data["node"]["outputs"]
self.display_name = self.data["node"].get("display_name", self.id.split("-")[0])
self.description = self.data["node"].get("description", "")
@ -224,7 +228,8 @@ class Vertex:
template_dict = self.data["node"]["template"]
self.vertex_type = (
self.data["type"]
if "Tool" not in self.output or template_dict["_type"].islower()
if "Tool" not in [type_ for out in self.outputs for type_ in out["types"]]
or template_dict["_type"].islower()
else template_dict["_type"]
)

View file

@ -1,7 +1,7 @@
import inspect
import json
import os
from typing import TYPE_CHECKING, Any, Type
from typing import TYPE_CHECKING, Any, Awaitable, Callable, Type
import orjson
from loguru import logger
@ -94,7 +94,9 @@ def update_params_with_load_from_db_fields(
return params
async def instantiate_custom_component(params, user_id, vertex, fallback_to_env_vars: bool = False):
async def instantiate_custom_component(
params: dict, user_id: str, vertex: "Vertex", fallback_to_env_vars: bool = False
):
params_copy = params.copy()
class_object: Type["CustomComponent"] = eval_custom_component_code(params_copy.pop("code"))
custom_component: "CustomComponent" = class_object(
@ -107,12 +109,31 @@ async def instantiate_custom_component(params, user_id, vertex, fallback_to_env_
custom_component, params_copy, vertex.load_from_db_fields, fallback_to_env_vars
)
# Now set the params as attributes of the custom_component
custom_component.set_attributes(params_copy)
if "retriever" in params_copy and hasattr(params_copy["retriever"], "as_retriever"):
params_copy["retriever"] = params_copy["retriever"].as_retriever()
# Determine if the build method is asynchronous
is_async = inspect.iscoroutinefunction(custom_component.build)
# New feature: the component has a list of outputs and we have
# to check the vertex.edges to see which is connected (coulb be multiple)
# and then we'll get the output which has the name of the method we should call.
# the methods don't require any params because they are already set in the custom_component
# so we can just call them
if hasattr(custom_component, "outputs"):
for output in custom_component.outputs:
if output.name in vertex.edges:
method: Callable | Awaitable = getattr(custom_component, output.method)
result = method()
# If the method is asynchronous, we need to await it
if inspect.iscoroutinefunction(method):
result = await result
vertex.add_result(output.name, result)
if is_async:
# Await the build method directly if it's async
build_result = await custom_component.build(**params_copy)