feat: Add set_attributes method to CustomComponent
This commit adds the set_attributes method to the CustomComponent class in the custom_component.py file. The set_attributes method allows for setting attributes of the CustomComponent instance based on a dictionary of key-value pairs. This change enhances the flexibility and configurability of the CustomComponent class.
This commit is contained in:
parent
98a09174d9
commit
b0432c3802
3 changed files with 36 additions and 4 deletions
|
|
@ -100,6 +100,12 @@ class CustomComponent(Component):
|
|||
build_config = {_input.name: _input.model_dump(by_alias=True, exclude_none=True) for _input in self.inputs}
|
||||
return build_config
|
||||
|
||||
def set_attributes(self, params: dict):
|
||||
for key, value in params.items():
|
||||
if key in self.__dict__:
|
||||
raise ValueError(f"Key {key} already exists in {self.__class__.__name__}")
|
||||
setattr(self, key, value)
|
||||
|
||||
def update_state(self, name: str, value: Any):
|
||||
if not self.vertex:
|
||||
raise ValueError("Vertex is not set")
|
||||
|
|
|
|||
|
|
@ -73,6 +73,7 @@ class Vertex:
|
|||
self.parent_is_top_level = False
|
||||
self.layer = None
|
||||
self.result: Optional[ResultData] = None
|
||||
self.results: Dict[str, Any] = {}
|
||||
try:
|
||||
self.is_interface_component = self.vertex_type in InterfaceComponentTypes
|
||||
except ValueError:
|
||||
|
|
@ -82,6 +83,9 @@ class Vertex:
|
|||
self.build_times: List[float] = []
|
||||
self.state = VertexStates.ACTIVE
|
||||
|
||||
def add_result(self, name: str, result: Any):
|
||||
self.results[name] = result
|
||||
|
||||
def update_graph_state(self, key, new_state, append: bool):
|
||||
if append:
|
||||
self.graph.append_state(key, new_state, caller=self.id)
|
||||
|
|
@ -196,7 +200,7 @@ class Vertex:
|
|||
|
||||
def _parse_data(self) -> None:
|
||||
self.data = self._data["data"]
|
||||
self.output = self.data["node"]["base_classes"]
|
||||
self.outputs = self.data["node"]["outputs"]
|
||||
self.display_name = self.data["node"].get("display_name", self.id.split("-")[0])
|
||||
|
||||
self.description = self.data["node"].get("description", "")
|
||||
|
|
@ -224,7 +228,8 @@ class Vertex:
|
|||
template_dict = self.data["node"]["template"]
|
||||
self.vertex_type = (
|
||||
self.data["type"]
|
||||
if "Tool" not in self.output or template_dict["_type"].islower()
|
||||
if "Tool" not in [type_ for out in self.outputs for type_ in out["types"]]
|
||||
or template_dict["_type"].islower()
|
||||
else template_dict["_type"]
|
||||
)
|
||||
|
||||
|
|
|
|||
|
|
@ -1,7 +1,7 @@
|
|||
import inspect
|
||||
import json
|
||||
import os
|
||||
from typing import TYPE_CHECKING, Any, Type
|
||||
from typing import TYPE_CHECKING, Any, Awaitable, Callable, Type
|
||||
|
||||
import orjson
|
||||
from loguru import logger
|
||||
|
|
@ -94,7 +94,9 @@ def update_params_with_load_from_db_fields(
|
|||
return params
|
||||
|
||||
|
||||
async def instantiate_custom_component(params, user_id, vertex, fallback_to_env_vars: bool = False):
|
||||
async def instantiate_custom_component(
|
||||
params: dict, user_id: str, vertex: "Vertex", fallback_to_env_vars: bool = False
|
||||
):
|
||||
params_copy = params.copy()
|
||||
class_object: Type["CustomComponent"] = eval_custom_component_code(params_copy.pop("code"))
|
||||
custom_component: "CustomComponent" = class_object(
|
||||
|
|
@ -107,12 +109,31 @@ async def instantiate_custom_component(params, user_id, vertex, fallback_to_env_
|
|||
custom_component, params_copy, vertex.load_from_db_fields, fallback_to_env_vars
|
||||
)
|
||||
|
||||
# Now set the params as attributes of the custom_component
|
||||
custom_component.set_attributes(params_copy)
|
||||
|
||||
if "retriever" in params_copy and hasattr(params_copy["retriever"], "as_retriever"):
|
||||
params_copy["retriever"] = params_copy["retriever"].as_retriever()
|
||||
|
||||
# Determine if the build method is asynchronous
|
||||
is_async = inspect.iscoroutinefunction(custom_component.build)
|
||||
|
||||
# New feature: the component has a list of outputs and we have
|
||||
# to check the vertex.edges to see which is connected (coulb be multiple)
|
||||
# and then we'll get the output which has the name of the method we should call.
|
||||
# the methods don't require any params because they are already set in the custom_component
|
||||
# so we can just call them
|
||||
|
||||
if hasattr(custom_component, "outputs"):
|
||||
for output in custom_component.outputs:
|
||||
if output.name in vertex.edges:
|
||||
method: Callable | Awaitable = getattr(custom_component, output.method)
|
||||
result = method()
|
||||
# If the method is asynchronous, we need to await it
|
||||
if inspect.iscoroutinefunction(method):
|
||||
result = await result
|
||||
vertex.add_result(output.name, result)
|
||||
|
||||
if is_async:
|
||||
# Await the build method directly if it's async
|
||||
build_result = await custom_component.build(**params_copy)
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue