refactor: update CustomComponent constructor (#3114)
* refactor: update code references to use _code instead of code * refactor: add backwards compatible attributes to Component class * refactor: update Component constructor to pass config params with underscore Refactored the `Component` class in `component.py` to handle inputs and outputs. Added a new method `map_outputs` to map a list of outputs to the component. Also updated the `__init__` method to properly initialize the inputs, outputs, and other attributes. This change improves the flexibility and extensibility of the `Component` class. Co-authored-by: Gabriel Luiz Freitas Almeida <gabriel@langflow.org> * refactor: change attribute to use underscore * refactor: update CustomComponent initialization parameters Refactored the `instantiate_class` function in `loading.py` to update the initialization parameters for the `CustomComponent` class. Changed the parameter names from `user_id`, `parameters`, `vertex`, and `tracing_service` to `_user_id`, `_parameters`, `_vertex`, and `_tracing_service` respectively. This change ensures consistency and improves code readability. Co-authored-by: Gabriel Luiz Freitas Almeida <gabriel@langflow.org> * refactor: update BaseComponent to accept UUID for _user_id Updated the `BaseComponent` class in `base_component.py` to accept a `UUID` type for the `_user_id` attribute. This change improves the type safety and ensures consistency with the usage of `_user_id` throughout the codebase. * refactor: import nanoid with type annotation The `nanoid` import in `component.py` has been updated to include a type annotation `# type: ignore`. This change ensures that the type checker ignores any errors related to the `nanoid` import. * fix(custom_component.py): convert _user_id to string before passing to functions to ensure compatibility with function signatures * refactor(utils.py): refactor code to use _user_id instead of user_id for consistency and clarity perf(utils.py): optimize code by reusing cc_instance instead of calling get_component_instance multiple times * [autofix.ci] apply automated fixes * refactor: update BaseComponent to use get_template_config method Refactored the `BaseComponent` class in `base_component.py` to use the `get_template_config` method instead of duplicating the code. This change improves code readability and reduces redundancy. * refactor: update build_custom_component_template to use add_name instead of keep_name Refactor the `build_custom_component_template` function in `utils.py` to use the `add_name` parameter instead of the deprecated `keep_name` parameter. This change ensures consistency with the updated method signature and improves code clarity. * feat(component.py): add method to set output types based on method return type to improve type checking and validation in custom components (#3115) * feat(component.py): add method to set output types based on method return type to improve type checking and validation in custom components * refactor: extract method to get method return type in CustomComponent * refactor: update _extract_return_type method in CustomComponent to accept Any type The _extract_return_type method in CustomComponent has been updated to accept the Any type as the return_type parameter. This change improves the flexibility and compatibility of the method, allowing it to handle a wider range of return types. * refactor: add _template_config property to BaseComponent Add a new `_template_config` property to the `BaseComponent` class in `base_component.py`. This property is used to store the template configuration for the custom component. If the `_template_config` property is empty, it is populated by calling the `build_template_config` method. This change improves the efficiency of accessing the template configuration and ensures that it is only built when needed. * refactor: add type checking for Output types in add_types method Improve type checking in the `add_types` method of the `Output` class in `base.py`. Check if the `type_` already exists in the `types` list before adding it. This change ensures that duplicate types are not added to the list. * update starter projects * refactor: optimize imports in base.py Optimize imports in the `base.py` file by removing unused imports and organizing the remaining imports. This change improves code readability and reduces unnecessary clutter. * fix(base.py): fix condition to check if self.types is not None before checking if type_ is in self.types --------- Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com>
This commit is contained in:
parent
1c3ee13a85
commit
62191d92ae
14 changed files with 180 additions and 101 deletions
|
|
@ -51,8 +51,8 @@ class BaseCrewComponent(Component):
|
|||
self,
|
||||
) -> Callable:
|
||||
def task_callback(task_output: TaskOutput):
|
||||
if self.vertex:
|
||||
vertex_id = self.vertex.id
|
||||
if self._vertex:
|
||||
vertex_id = self._vertex.id
|
||||
else:
|
||||
vertex_id = self.display_name or self.__class__.__name__
|
||||
self.log(task_output.model_dump(), name=f"Task (Agent: {task_output.agent}) - {vertex_id}")
|
||||
|
|
@ -63,7 +63,7 @@ class BaseCrewComponent(Component):
|
|||
self,
|
||||
) -> Callable:
|
||||
def step_callback(agent_output: Union[AgentFinish, List[Tuple[AgentAction, str]]]):
|
||||
_id = self.vertex.id if self.vertex else self.display_name
|
||||
_id = self._vertex.id if self._vertex else self.display_name
|
||||
if isinstance(agent_output, AgentFinish):
|
||||
messages = agent_output.messages
|
||||
self.log(cast(dict, messages[0].to_json()), name=f"Finish (Agent: {_id})")
|
||||
|
|
|
|||
|
|
@ -1,7 +1,7 @@
|
|||
import json
|
||||
import warnings
|
||||
from abc import abstractmethod
|
||||
from typing import Optional, Union, List
|
||||
from typing import List, Optional, Union
|
||||
|
||||
from langchain_core.language_models.llms import LLM
|
||||
from langchain_core.messages import AIMessage, BaseMessage, HumanMessage, SystemMessage
|
||||
|
|
@ -10,7 +10,7 @@ from langflow.base.constants import STREAM_INFO_TEXT
|
|||
from langflow.custom import Component
|
||||
from langflow.field_typing import LanguageModel
|
||||
from langflow.inputs import MessageInput, MessageTextInput
|
||||
from langflow.inputs.inputs import InputTypes, BoolInput
|
||||
from langflow.inputs.inputs import BoolInput, InputTypes
|
||||
from langflow.schema.message import Message
|
||||
from langflow.template.field.base import Output
|
||||
|
||||
|
|
@ -164,7 +164,7 @@ class LCModelComponent(Component):
|
|||
inputs: Union[list, dict] = messages or {}
|
||||
try:
|
||||
runnable = runnable.with_config( # type: ignore
|
||||
{"run_name": self.display_name, "project_name": self.tracing_service.project_name} # type: ignore
|
||||
{"run_name": self.display_name, "project_name": self._tracing_service.project_name} # type: ignore
|
||||
)
|
||||
if stream:
|
||||
return runnable.stream(inputs) # type: ignore
|
||||
|
|
|
|||
|
|
@ -23,6 +23,6 @@ class ListenComponent(CustomComponent):
|
|||
return state
|
||||
|
||||
def _set_successors_ids(self):
|
||||
self.vertex.is_state = True
|
||||
successors = self.vertex.graph.successor_map.get(self.vertex.id, [])
|
||||
return successors + self.vertex.graph.activated_vertices
|
||||
self._vertex.is_state = True
|
||||
successors = self._vertex.graph.successor_map.get(self._vertex.id, [])
|
||||
return successors + self._vertex.graph.activated_vertices
|
||||
|
|
|
|||
|
|
@ -43,6 +43,6 @@ class NotifyComponent(CustomComponent):
|
|||
return data
|
||||
|
||||
def _set_successors_ids(self):
|
||||
self.vertex.is_state = True
|
||||
successors = self.vertex.graph.successor_map.get(self.vertex.id, [])
|
||||
return successors + self.vertex.graph.activated_vertices
|
||||
self._vertex.is_state = True
|
||||
successors = self._vertex.graph.successor_map.get(self._vertex.id, [])
|
||||
return successors + self._vertex.graph.activated_vertices
|
||||
|
|
|
|||
|
|
@ -1,6 +1,7 @@
|
|||
import operator
|
||||
import warnings
|
||||
from typing import Any, ClassVar, Optional
|
||||
from uuid import UUID
|
||||
|
||||
from cachetools import TTLCache, cachedmethod
|
||||
from fastapi import HTTPException
|
||||
|
|
@ -27,7 +28,8 @@ class BaseComponent:
|
|||
"""The code of the component. Defaults to None."""
|
||||
_function_entrypoint_name: str = "build"
|
||||
field_config: dict = {}
|
||||
_user_id: Optional[str]
|
||||
_user_id: Optional[str | UUID] = None
|
||||
_template_config: dict = {}
|
||||
|
||||
def __init__(self, **data):
|
||||
self.cache = TTLCache(maxsize=1024, ttl=60)
|
||||
|
|
@ -38,7 +40,7 @@ class BaseComponent:
|
|||
setattr(self, key, value)
|
||||
|
||||
def __setattr__(self, key, value):
|
||||
if key == "_user_id" and hasattr(self, "_user_id"):
|
||||
if key == "_user_id" and hasattr(self, "_user_id") and getattr(self, "_user_id") is not None:
|
||||
warnings.warn("user_id is immutable and cannot be changed.")
|
||||
super().__setattr__(key, value)
|
||||
|
||||
|
|
@ -65,6 +67,25 @@ class BaseComponent:
|
|||
|
||||
return validate.create_function(self._code, self._function_entrypoint_name)
|
||||
|
||||
@staticmethod
|
||||
def get_template_config(component):
|
||||
"""
|
||||
Gets the template configuration for the custom component itself.
|
||||
"""
|
||||
template_config = {}
|
||||
|
||||
for attribute, func in ATTR_FUNC_MAPPING.items():
|
||||
if hasattr(component, attribute):
|
||||
value = getattr(component, attribute)
|
||||
if value is not None:
|
||||
template_config[attribute] = func(value=value)
|
||||
|
||||
for key in template_config.copy():
|
||||
if key not in ATTR_FUNC_MAPPING.keys():
|
||||
template_config.pop(key, None)
|
||||
|
||||
return template_config
|
||||
|
||||
def build_template_config(self) -> dict:
|
||||
"""
|
||||
Builds the template configuration for the custom component.
|
||||
|
|
@ -77,18 +98,7 @@ class BaseComponent:
|
|||
|
||||
cc_class = eval_custom_component_code(self._code)
|
||||
component_instance = cc_class()
|
||||
template_config = {}
|
||||
|
||||
for attribute, func in ATTR_FUNC_MAPPING.items():
|
||||
if hasattr(component_instance, attribute):
|
||||
value = getattr(component_instance, attribute)
|
||||
if value is not None:
|
||||
template_config[attribute] = func(value=value)
|
||||
|
||||
for key in template_config.copy():
|
||||
if key not in ATTR_FUNC_MAPPING.keys():
|
||||
template_config.pop(key, None)
|
||||
|
||||
template_config = self.get_template_config(component_instance)
|
||||
return template_config
|
||||
|
||||
def build(self, *args: Any, **kwargs: Any) -> Any:
|
||||
|
|
|
|||
|
|
@ -1,10 +1,12 @@
|
|||
import inspect
|
||||
from typing import Any, Callable, ClassVar, List, Optional, Union
|
||||
from typing import Any, Callable, ClassVar, List, Optional, Union, get_type_hints
|
||||
from uuid import UUID
|
||||
|
||||
import nanoid # type: ignore
|
||||
import yaml
|
||||
from pydantic import BaseModel
|
||||
|
||||
from langflow.helpers.custom import format_type
|
||||
from langflow.inputs.inputs import InputTypes
|
||||
from langflow.schema.artifact import get_artifact_type, post_process_raw
|
||||
from langflow.schema.data import Data
|
||||
|
|
@ -14,6 +16,8 @@ from langflow.template.field.base import UNDEFINED, Output
|
|||
|
||||
from .custom_component import CustomComponent
|
||||
|
||||
BACKWARDS_COMPATIBLE_ATTRIBUTES = ["user_id", "vertex", "tracing_service"]
|
||||
|
||||
|
||||
class Component(CustomComponent):
|
||||
inputs: List[InputTypes] = []
|
||||
|
|
@ -21,24 +25,45 @@ class Component(CustomComponent):
|
|||
code_class_base_inheritance: ClassVar[str] = "Component"
|
||||
_output_logs: dict[str, Log] = {}
|
||||
|
||||
def __init__(self, **data):
|
||||
def __init__(self, **kwargs):
|
||||
# if key starts with _ it is a config
|
||||
# else it is an input
|
||||
inputs = {}
|
||||
config = {}
|
||||
for key, value in kwargs.items():
|
||||
if key.startswith("_"):
|
||||
config[key] = value
|
||||
else:
|
||||
inputs[key] = value
|
||||
self._inputs: dict[str, InputTypes] = {}
|
||||
self._outputs: dict[str, Output] = {}
|
||||
self._results: dict[str, Any] = {}
|
||||
self._attributes: dict[str, Any] = {}
|
||||
self._parameters: dict[str, Any] = {}
|
||||
self._parameters = inputs or {}
|
||||
self._components: list[Component] = []
|
||||
self.set_attributes(self._parameters)
|
||||
self._output_logs = {}
|
||||
super().__init__(**data)
|
||||
config = config or {}
|
||||
if "_id" not in config:
|
||||
config |= {"_id": f"{self.__class__.__name__}-{nanoid.generate(size=5)}"}
|
||||
super().__init__(**config)
|
||||
if hasattr(self, "_trace_type"):
|
||||
self.trace_type = self._trace_type
|
||||
if not hasattr(self, "trace_type"):
|
||||
self.trace_type = "chain"
|
||||
if self.inputs is not None:
|
||||
self.map_inputs(self.inputs)
|
||||
self.set_attributes(self._parameters)
|
||||
if self.outputs is not None:
|
||||
self.map_outputs(self.outputs)
|
||||
self._set_output_types()
|
||||
|
||||
def __getattr__(self, name: str) -> Any:
|
||||
if "_attributes" in self.__dict__ and name in self.__dict__["_attributes"]:
|
||||
return self.__dict__["_attributes"][name]
|
||||
if "_inputs" in self.__dict__ and name in self.__dict__["_inputs"]:
|
||||
return self.__dict__["_inputs"][name].value
|
||||
if name in BACKWARDS_COMPATIBLE_ATTRIBUTES:
|
||||
return self.__dict__[f"_{name}"]
|
||||
raise AttributeError(f"{name} not found in {self.__class__.__name__}")
|
||||
|
||||
def map_inputs(self, inputs: List[InputTypes]):
|
||||
|
|
@ -48,10 +73,47 @@ class Component(CustomComponent):
|
|||
raise ValueError("Input name cannot be None.")
|
||||
self._inputs[input_.name] = input_
|
||||
|
||||
def map_outputs(self, outputs: List[Output]):
|
||||
"""
|
||||
Maps the given list of outputs to the component.
|
||||
Args:
|
||||
outputs (List[Output]): The list of outputs to be mapped.
|
||||
Raises:
|
||||
ValueError: If the output name is None.
|
||||
Returns:
|
||||
None
|
||||
"""
|
||||
self.outputs = outputs
|
||||
for output in outputs:
|
||||
if output.name is None:
|
||||
raise ValueError("Output name cannot be None.")
|
||||
self._outputs[output.name] = output
|
||||
|
||||
def validate(self, params: dict):
|
||||
self._validate_inputs(params)
|
||||
self._validate_outputs()
|
||||
|
||||
def _set_output_types(self):
|
||||
for output in self.outputs:
|
||||
return_types = self._get_method_return_type(output.method)
|
||||
output.add_types(return_types)
|
||||
output.set_selected()
|
||||
|
||||
def _get_method_return_type(self, method_name: str) -> List[str]:
|
||||
method = getattr(self, method_name)
|
||||
return_type = get_type_hints(method)["return"]
|
||||
extracted_return_types = self._extract_return_type(return_type)
|
||||
return [format_type(extracted_return_type) for extracted_return_type in extracted_return_types]
|
||||
|
||||
def _get_output_by_method(self, method: Callable):
|
||||
# method is a callable and output.method is a string
|
||||
# we need to find the output that has the same method
|
||||
output = next((output for output in self.outputs if output.method == method.__name__), None)
|
||||
if output is None:
|
||||
method_name = method.__name__ if hasattr(method, "__name__") else str(method)
|
||||
raise ValueError(f"Output with method {method_name} not found")
|
||||
return output
|
||||
|
||||
def _validate_outputs(self):
|
||||
# Raise Error if some rule isn't met
|
||||
pass
|
||||
|
|
@ -106,9 +168,9 @@ class Component(CustomComponent):
|
|||
async def _build_with_tracing(self):
|
||||
inputs = self.get_trace_as_inputs()
|
||||
metadata = self.get_trace_as_metadata()
|
||||
async with self.tracing_service.trace_context(self, self.trace_name, inputs, metadata):
|
||||
async with self._tracing_service.trace_context(self, self.trace_name, inputs, metadata):
|
||||
_results, _artifacts = await self._build_results()
|
||||
self.tracing_service.set_outputs(self.trace_name, _results)
|
||||
self._tracing_service.set_outputs(self.trace_name, _results)
|
||||
|
||||
return _results, _artifacts
|
||||
|
||||
|
|
@ -116,7 +178,7 @@ class Component(CustomComponent):
|
|||
return await self._build_results()
|
||||
|
||||
async def build_results(self):
|
||||
if self.tracing_service:
|
||||
if self._tracing_service:
|
||||
return await self._build_with_tracing()
|
||||
return await self._build_without_tracing()
|
||||
|
||||
|
|
@ -124,11 +186,11 @@ class Component(CustomComponent):
|
|||
_results = {}
|
||||
_artifacts = {}
|
||||
if hasattr(self, "outputs"):
|
||||
self._set_outputs(self.vertex.outputs)
|
||||
self._set_outputs(self._vertex.outputs)
|
||||
for output in self.outputs:
|
||||
# Build the output if it's connected to some other vertex
|
||||
# or if it's not connected to any vertex
|
||||
if not self.vertex.outgoing_edges or output.name in self.vertex.edges_source_names:
|
||||
if not self._vertex.outgoing_edges or output.name in self._vertex.edges_source_names:
|
||||
if output.method is None:
|
||||
raise ValueError(f"Output {output.name} does not have a method defined.")
|
||||
method: Callable = getattr(self, output.method)
|
||||
|
|
@ -142,9 +204,9 @@ class Component(CustomComponent):
|
|||
if (
|
||||
isinstance(result, Message)
|
||||
and result.flow_id is None
|
||||
and self.vertex.graph.flow_id is not None
|
||||
and self._vertex.graph.flow_id is not None
|
||||
):
|
||||
result.set_flow_id(self.vertex.graph.flow_id)
|
||||
result.set_flow_id(self._vertex.graph.flow_id)
|
||||
_results[output.name] = result
|
||||
output.value = result
|
||||
custom_repr = self.custom_repr()
|
||||
|
|
@ -176,8 +238,8 @@ class Component(CustomComponent):
|
|||
self._logs = []
|
||||
self._artifacts = _artifacts
|
||||
self._results = _results
|
||||
if self.tracing_service:
|
||||
self.tracing_service.set_outputs(self.trace_name, _results)
|
||||
if self._tracing_service:
|
||||
self._tracing_service.set_outputs(self.trace_name, _results)
|
||||
return _results, _artifacts
|
||||
|
||||
def custom_repr(self):
|
||||
|
|
|
|||
|
|
@ -1,6 +1,5 @@
|
|||
from pathlib import Path
|
||||
from typing import TYPE_CHECKING, Any, Callable, ClassVar, List, Optional, Sequence, Union
|
||||
from uuid import UUID
|
||||
|
||||
import yaml
|
||||
from cachetools import TTLCache
|
||||
|
|
@ -14,7 +13,7 @@ from langflow.schema.artifact import get_artifact_type
|
|||
from langflow.schema.dotdict import dotdict
|
||||
from langflow.schema.log import LoggableType
|
||||
from langflow.schema.schema import OutputValue
|
||||
from langflow.services.deps import get_storage_service, get_tracing_service, get_variable_service, session_scope
|
||||
from langflow.services.deps import get_storage_service, get_variable_service, session_scope
|
||||
from langflow.services.storage.service import StorageService
|
||||
from langflow.services.tracing.schema import Log
|
||||
from langflow.template.utils import update_frontend_node_with_template_values
|
||||
|
|
@ -72,20 +71,19 @@ class CustomComponent(BaseComponent):
|
|||
"""The default frozen state of the component. Defaults to False."""
|
||||
build_parameters: Optional[dict] = None
|
||||
"""The build parameters of the component. Defaults to None."""
|
||||
vertex: Optional["Vertex"] = None
|
||||
_vertex: Optional["Vertex"] = None
|
||||
"""The edge target parameter of the component. Defaults to None."""
|
||||
code_class_base_inheritance: ClassVar[str] = "CustomComponent"
|
||||
function_entrypoint_name: ClassVar[str] = "build"
|
||||
function: Optional[Callable] = None
|
||||
repr_value: Optional[Any] = ""
|
||||
user_id: Optional[Union[UUID, str]] = None
|
||||
status: Optional[Any] = None
|
||||
"""The status of the component. This is displayed on the frontend. Defaults to None."""
|
||||
_flows_data: Optional[List[Data]] = None
|
||||
_outputs: List[OutputValue] = []
|
||||
_logs: List[Log] = []
|
||||
_output_logs: dict[str, Log] = {}
|
||||
tracing_service: Optional["TracingService"] = None
|
||||
_tracing_service: Optional["TracingService"] = None
|
||||
|
||||
def set_attributes(self, parameters: dict):
|
||||
pass
|
||||
|
|
@ -94,51 +92,43 @@ class CustomComponent(BaseComponent):
|
|||
self._parameters = parameters
|
||||
self.set_attributes(self._parameters)
|
||||
|
||||
@classmethod
|
||||
def initialize(cls, **kwargs):
|
||||
user_id = kwargs.pop("user_id", None)
|
||||
vertex = kwargs.pop("vertex", None)
|
||||
tracing_service = kwargs.pop("tracing_service", get_tracing_service())
|
||||
params_copy = kwargs.copy()
|
||||
return cls(user_id=user_id, _parameters=params_copy, vertex=vertex, tracing_service=tracing_service)
|
||||
|
||||
@property
|
||||
def trace_name(self):
|
||||
return f"{self.display_name} ({self.vertex.id})"
|
||||
return f"{self.display_name} ({self._vertex.id})"
|
||||
|
||||
def update_state(self, name: str, value: Any):
|
||||
if not self.vertex:
|
||||
if not self._vertex:
|
||||
raise ValueError("Vertex is not set")
|
||||
try:
|
||||
self.vertex.graph.update_state(name=name, record=value, caller=self.vertex.id)
|
||||
self._vertex.graph.update_state(name=name, record=value, caller=self._vertex.id)
|
||||
except Exception as e:
|
||||
raise ValueError(f"Error updating state: {e}")
|
||||
|
||||
def stop(self, output_name: str | None = None):
|
||||
if not output_name and self.vertex and len(self.vertex.outputs) == 1:
|
||||
output_name = self.vertex.outputs[0]["name"]
|
||||
if not output_name and self._vertex and len(self._vertex.outputs) == 1:
|
||||
output_name = self._vertex.outputs[0]["name"]
|
||||
elif not output_name:
|
||||
raise ValueError("You must specify an output name to call stop")
|
||||
if not self.vertex:
|
||||
if not self._vertex:
|
||||
raise ValueError("Vertex is not set")
|
||||
try:
|
||||
self.graph.mark_branch(vertex_id=self.vertex.id, output_name=output_name, state="INACTIVE")
|
||||
self.graph.mark_branch(vertex_id=self._vertex.id, output_name=output_name, state="INACTIVE")
|
||||
except Exception as e:
|
||||
raise ValueError(f"Error stopping {self.display_name}: {e}")
|
||||
|
||||
def append_state(self, name: str, value: Any):
|
||||
if not self.vertex:
|
||||
if not self._vertex:
|
||||
raise ValueError("Vertex is not set")
|
||||
try:
|
||||
self.vertex.graph.append_state(name=name, record=value, caller=self.vertex.id)
|
||||
self._vertex.graph.append_state(name=name, record=value, caller=self._vertex.id)
|
||||
except Exception as e:
|
||||
raise ValueError(f"Error appending state: {e}")
|
||||
|
||||
def get_state(self, name: str):
|
||||
if not self.vertex:
|
||||
if not self._vertex:
|
||||
raise ValueError("Vertex is not set")
|
||||
try:
|
||||
return self.vertex.graph.get_state(name=name)
|
||||
return self._vertex.graph.get_state(name=name)
|
||||
except Exception as e:
|
||||
raise ValueError(f"Error getting state: {e}")
|
||||
|
||||
|
|
@ -176,7 +166,7 @@ class CustomComponent(BaseComponent):
|
|||
|
||||
@property
|
||||
def graph(self):
|
||||
return self.vertex.graph
|
||||
return self._vertex.graph
|
||||
|
||||
def _get_field_order(self):
|
||||
return self.field_order or list(self.field_config.keys())
|
||||
|
|
@ -277,6 +267,14 @@ class CustomComponent(BaseComponent):
|
|||
|
||||
return data_objects
|
||||
|
||||
def get_method_return_type(self, method_name: str):
|
||||
build_method = self.get_method(method_name)
|
||||
if not build_method or not build_method.get("has_return"):
|
||||
return []
|
||||
return_type = build_method["return_type"]
|
||||
|
||||
return self._extract_return_type(return_type)
|
||||
|
||||
def create_references_from_data(self, data: List[Data], include_data: bool = False) -> str:
|
||||
"""
|
||||
Create references from a list of data.
|
||||
|
|
@ -349,12 +347,7 @@ class CustomComponent(BaseComponent):
|
|||
"""
|
||||
return self.get_method_return_type(self.function_entrypoint_name)
|
||||
|
||||
def get_method_return_type(self, method_name: str):
|
||||
build_method = self.get_method(method_name)
|
||||
if not build_method or not build_method.get("has_return"):
|
||||
return []
|
||||
return_type = build_method["return_type"]
|
||||
|
||||
def _extract_return_type(self, return_type: Any):
|
||||
if hasattr(return_type, "__origin__") and return_type.__origin__ in [
|
||||
list,
|
||||
List,
|
||||
|
|
@ -401,7 +394,9 @@ class CustomComponent(BaseComponent):
|
|||
Returns:
|
||||
dict: The template configuration for the custom component.
|
||||
"""
|
||||
return self.build_template_config()
|
||||
if not self._template_config:
|
||||
self._template_config = self.build_template_config()
|
||||
return self._template_config
|
||||
|
||||
@property
|
||||
def variables(self):
|
||||
|
|
@ -471,7 +466,7 @@ class CustomComponent(BaseComponent):
|
|||
async def load_flow(self, flow_id: str, tweaks: Optional[dict] = None) -> "Graph":
|
||||
if not self._user_id:
|
||||
raise ValueError("Session is invalid")
|
||||
return await load_flow(user_id=self._user_id, flow_id=flow_id, tweaks=tweaks)
|
||||
return await load_flow(user_id=str(self._user_id), flow_id=flow_id, tweaks=tweaks)
|
||||
|
||||
async def run_flow(
|
||||
self,
|
||||
|
|
@ -487,14 +482,14 @@ class CustomComponent(BaseComponent):
|
|||
flow_id=flow_id,
|
||||
flow_name=flow_name,
|
||||
tweaks=tweaks,
|
||||
user_id=self._user_id,
|
||||
user_id=str(self._user_id),
|
||||
)
|
||||
|
||||
def list_flows(self) -> List[Data]:
|
||||
if not self._user_id:
|
||||
raise ValueError("Session is invalid")
|
||||
try:
|
||||
return list_flows(user_id=self._user_id)
|
||||
return list_flows(user_id=str(self._user_id))
|
||||
except Exception as e:
|
||||
raise ValueError(f"Error listing flows: {e}")
|
||||
|
||||
|
|
@ -522,8 +517,8 @@ class CustomComponent(BaseComponent):
|
|||
name = f"Log {len(self._logs) + 1}"
|
||||
log = Log(message=message, type=get_artifact_type(message), name=name)
|
||||
self._logs.append(log)
|
||||
if self.tracing_service and self.vertex:
|
||||
self.tracing_service.add_log(trace_name=self.trace_name, log=log)
|
||||
if self._tracing_service and self._vertex:
|
||||
self._tracing_service.add_log(trace_name=self.trace_name, log=log)
|
||||
|
||||
def post_code_processing(self, new_frontend_node: dict, current_frontend_node: dict):
|
||||
"""
|
||||
|
|
|
|||
|
|
@ -283,7 +283,7 @@ def get_component_instance(custom_component: CustomComponent, user_id: Optional[
|
|||
) from exc
|
||||
|
||||
try:
|
||||
custom_instance = custom_class(user_id=user_id)
|
||||
custom_instance = custom_class(_user_id=user_id)
|
||||
return custom_instance
|
||||
except Exception as exc:
|
||||
logger.error(f"Error while instantiating custom component: {str(exc)}")
|
||||
|
|
@ -317,7 +317,7 @@ def run_build_config(
|
|||
) from exc
|
||||
|
||||
try:
|
||||
custom_instance = custom_class(user_id=user_id)
|
||||
custom_instance = custom_class(_user_id=user_id)
|
||||
build_config: Dict = custom_instance.build_config()
|
||||
|
||||
for field_name, field in build_config.copy().items():
|
||||
|
|
@ -361,14 +361,15 @@ def build_custom_component_template_from_inputs(
|
|||
custom_component: Union[Component, CustomComponent], user_id: Optional[Union[str, UUID]] = None
|
||||
):
|
||||
# The List of Inputs fills the role of the build_config and the entrypoint_args
|
||||
field_config = custom_component.template_config
|
||||
cc_instance = get_component_instance(custom_component, user_id=user_id)
|
||||
field_config = cc_instance.get_template_config(cc_instance)
|
||||
frontend_node = ComponentFrontendNode.from_inputs(**field_config)
|
||||
frontend_node = add_code_field(frontend_node, custom_component._code, field_config.get("code", {}))
|
||||
# But we now need to calculate the return_type of the methods in the outputs
|
||||
for output in frontend_node.outputs:
|
||||
if output.types:
|
||||
continue
|
||||
return_types = custom_component.get_method_return_type(output.method)
|
||||
return_types = cc_instance.get_method_return_type(output.method)
|
||||
return_types = [format_type(return_type) for return_type in return_types]
|
||||
output.add_types(return_types)
|
||||
output.set_selected()
|
||||
|
|
@ -376,8 +377,8 @@ def build_custom_component_template_from_inputs(
|
|||
frontend_node.validate_component()
|
||||
# ! This should be removed when we have a better way to handle this
|
||||
frontend_node.set_base_classes_from_outputs()
|
||||
reorder_fields(frontend_node, custom_component._get_field_order())
|
||||
cc_instance = get_component_instance(custom_component, user_id=user_id)
|
||||
reorder_fields(frontend_node, cc_instance._get_field_order())
|
||||
|
||||
return frontend_node.to_dict(add_name=False), cc_instance
|
||||
|
||||
|
||||
|
|
|
|||
|
|
@ -4037,7 +4037,8 @@
|
|||
"name": "api_run_model",
|
||||
"selected": "Data",
|
||||
"types": [
|
||||
"Data"
|
||||
"Data",
|
||||
"list"
|
||||
],
|
||||
"value": "__UNDEFINED__"
|
||||
},
|
||||
|
|
@ -4048,7 +4049,8 @@
|
|||
"name": "api_build_tool",
|
||||
"selected": "Tool",
|
||||
"types": [
|
||||
"Tool"
|
||||
"Tool",
|
||||
"Sequence"
|
||||
],
|
||||
"value": "__UNDEFINED__"
|
||||
}
|
||||
|
|
|
|||
|
|
@ -2615,7 +2615,8 @@
|
|||
"name": "api_run_model",
|
||||
"selected": "Data",
|
||||
"types": [
|
||||
"Data"
|
||||
"Data",
|
||||
"list"
|
||||
],
|
||||
"value": "__UNDEFINED__"
|
||||
},
|
||||
|
|
@ -2626,7 +2627,8 @@
|
|||
"name": "api_build_tool",
|
||||
"selected": "Tool",
|
||||
"types": [
|
||||
"Tool"
|
||||
"Tool",
|
||||
"Sequence"
|
||||
],
|
||||
"value": "__UNDEFINED__"
|
||||
}
|
||||
|
|
|
|||
|
|
@ -2953,7 +2953,8 @@
|
|||
"name": "api_run_model",
|
||||
"selected": "Data",
|
||||
"types": [
|
||||
"Data"
|
||||
"Data",
|
||||
"list"
|
||||
],
|
||||
"value": "__UNDEFINED__"
|
||||
},
|
||||
|
|
@ -2964,7 +2965,8 @@
|
|||
"name": "api_build_tool",
|
||||
"selected": "Tool",
|
||||
"types": [
|
||||
"Tool"
|
||||
"Tool",
|
||||
"Sequence"
|
||||
],
|
||||
"value": "__UNDEFINED__"
|
||||
}
|
||||
|
|
|
|||
|
|
@ -33,11 +33,11 @@ async def instantiate_class(
|
|||
custom_params = get_params(vertex.params)
|
||||
code = custom_params.pop("code")
|
||||
class_object: Type["CustomComponent" | "Component"] = eval_custom_component_code(code)
|
||||
custom_component: "CustomComponent" | "Component" = class_object.initialize(
|
||||
user_id=user_id,
|
||||
parameters=custom_params,
|
||||
vertex=vertex,
|
||||
tracing_service=get_tracing_service(),
|
||||
custom_component: "CustomComponent" | "Component" = class_object(
|
||||
_user_id=user_id,
|
||||
_parameters=custom_params,
|
||||
_vertex=vertex,
|
||||
_tracing_service=get_tracing_service(),
|
||||
)
|
||||
return custom_component, custom_params
|
||||
|
||||
|
|
@ -186,9 +186,9 @@ async def build_custom_component(params: dict, custom_component: "CustomComponen
|
|||
raw = post_process_raw(raw, artifact_type)
|
||||
artifact = {"repr": custom_repr, "raw": raw, "type": artifact_type}
|
||||
|
||||
if custom_component.vertex is not None:
|
||||
custom_component._artifacts = {custom_component.vertex.outputs[0].get("name"): artifact}
|
||||
custom_component._results = {custom_component.vertex.outputs[0].get("name"): build_result}
|
||||
if custom_component._vertex is not None:
|
||||
custom_component._artifacts = {custom_component._vertex.outputs[0].get("name"): artifact}
|
||||
custom_component._results = {custom_component._vertex.outputs[0].get("name"): build_result}
|
||||
return custom_component, build_result, artifact
|
||||
|
||||
raise ValueError("Custom component does not have a vertex")
|
||||
|
|
|
|||
|
|
@ -194,8 +194,8 @@ class TracingService(Service):
|
|||
metadata: Optional[Dict[str, Any]] = None,
|
||||
):
|
||||
trace_id = trace_name
|
||||
if component.vertex:
|
||||
trace_id = component.vertex.id
|
||||
if component._vertex:
|
||||
trace_id = component._vertex.id
|
||||
trace_type = component.trace_type
|
||||
self._start_traces(
|
||||
trace_id,
|
||||
|
|
@ -203,7 +203,7 @@ class TracingService(Service):
|
|||
trace_type,
|
||||
self._cleanup_inputs(inputs),
|
||||
metadata,
|
||||
component.vertex,
|
||||
component._vertex,
|
||||
)
|
||||
try:
|
||||
yield self
|
||||
|
|
|
|||
|
|
@ -1,5 +1,8 @@
|
|||
from enum import Enum
|
||||
from typing import Any, Callable, GenericAlias, Optional, Union, _GenericAlias, _UnionGenericAlias # type: ignore
|
||||
from typing import GenericAlias # type: ignore
|
||||
from typing import _GenericAlias # type: ignore
|
||||
from typing import _UnionGenericAlias # type: ignore
|
||||
from typing import Any, Callable, Optional, Union
|
||||
|
||||
from pydantic import BaseModel, ConfigDict, Field, field_serializer, field_validator, model_serializer, model_validator
|
||||
|
||||
|
|
@ -182,6 +185,8 @@ class Output(BaseModel):
|
|||
|
||||
def add_types(self, _type: list[Any]):
|
||||
for type_ in _type:
|
||||
if self.types and type_ in self.types:
|
||||
continue
|
||||
if self.types is None:
|
||||
self.types = []
|
||||
self.types.append(type_)
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue