refactor: update CustomComponent constructor (#3114)

* refactor: update code references to use _code instead of code

* refactor: add backwards compatible attributes to Component class

* refactor: update Component constructor to pass config params with underscore

Refactored the `Component` class in `component.py` to handle inputs and outputs. Added a new method `map_outputs` to map a list of outputs to the component. Also updated the `__init__` method to properly initialize the inputs, outputs, and other attributes. This change improves the flexibility and extensibility of the `Component` class.

Co-authored-by: Gabriel Luiz Freitas Almeida <gabriel@langflow.org>

* refactor: change attribute to use underscore

* refactor: update CustomComponent initialization parameters

Refactored the `instantiate_class` function in `loading.py` to update the initialization parameters for the `CustomComponent` class. Changed the parameter names from `user_id`, `parameters`, `vertex`, and `tracing_service` to `_user_id`, `_parameters`, `_vertex`, and `_tracing_service` respectively. This change ensures consistency and improves code readability.

Co-authored-by: Gabriel Luiz Freitas Almeida <gabriel@langflow.org>

* refactor: update BaseComponent to accept UUID for _user_id

Updated the `BaseComponent` class in `base_component.py` to accept a `UUID` type for the `_user_id` attribute. This change improves the type safety and ensures consistency with the usage of `_user_id` throughout the codebase.

* refactor: import nanoid with type annotation

The `nanoid` import in `component.py` has been updated to include a type annotation `# type: ignore`. This change ensures that the type checker ignores any errors related to the `nanoid` import.

* fix(custom_component.py): convert _user_id to string before passing to functions to ensure compatibility with function signatures

* refactor(utils.py): refactor code to use _user_id instead of user_id for consistency and clarity

perf(utils.py): optimize code by reusing cc_instance instead of calling get_component_instance multiple times

* [autofix.ci] apply automated fixes

* refactor: update BaseComponent to use get_template_config method

Refactored the `BaseComponent` class in `base_component.py` to use the `get_template_config` method instead of duplicating the code. This change improves code readability and reduces redundancy.

* refactor: update build_custom_component_template to use add_name instead of keep_name

Refactor the `build_custom_component_template` function in `utils.py` to use the `add_name` parameter instead of the deprecated `keep_name` parameter. This change ensures consistency with the updated method signature and improves code clarity.

* feat(component.py): add method to set output types based on method return type to improve type checking and validation in custom components (#3115)

* feat(component.py): add method to set output types based on method return type to improve type checking and validation in custom components

* refactor: extract method to get method return type in CustomComponent

* refactor: update _extract_return_type method in CustomComponent to accept Any type

The _extract_return_type method in CustomComponent has been updated to accept the Any type as the return_type parameter. This change improves the flexibility and compatibility of the method, allowing it to handle a wider range of return types.

* refactor: add _template_config property to BaseComponent

Add a new `_template_config` property to the `BaseComponent` class in `base_component.py`. This property is used to store the template configuration for the custom component. If the `_template_config` property is empty, it is populated by calling the `build_template_config` method. This change improves the efficiency of accessing the template configuration and ensures that it is only built when needed.

* refactor: add type checking for Output types in add_types method

Improve type checking in the `add_types` method of the `Output` class in `base.py`. Check if the `type_` already exists in the `types` list before adding it. This change ensures that duplicate types are not added to the list.

* update starter projects

* refactor: optimize imports in base.py

Optimize imports in the `base.py` file by removing unused imports and organizing the remaining imports. This change improves code readability and reduces unnecessary clutter.

* fix(base.py): fix condition to check if self.types is not None before checking if type_ is in self.types

---------

Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com>
This commit is contained in:
Gabriel Luiz Freitas Almeida 2024-07-31 19:28:42 -03:00 committed by GitHub
commit 62191d92ae
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
14 changed files with 180 additions and 101 deletions

View file

@ -51,8 +51,8 @@ class BaseCrewComponent(Component):
self,
) -> Callable:
def task_callback(task_output: TaskOutput):
if self.vertex:
vertex_id = self.vertex.id
if self._vertex:
vertex_id = self._vertex.id
else:
vertex_id = self.display_name or self.__class__.__name__
self.log(task_output.model_dump(), name=f"Task (Agent: {task_output.agent}) - {vertex_id}")
@ -63,7 +63,7 @@ class BaseCrewComponent(Component):
self,
) -> Callable:
def step_callback(agent_output: Union[AgentFinish, List[Tuple[AgentAction, str]]]):
_id = self.vertex.id if self.vertex else self.display_name
_id = self._vertex.id if self._vertex else self.display_name
if isinstance(agent_output, AgentFinish):
messages = agent_output.messages
self.log(cast(dict, messages[0].to_json()), name=f"Finish (Agent: {_id})")

View file

@ -1,7 +1,7 @@
import json
import warnings
from abc import abstractmethod
from typing import Optional, Union, List
from typing import List, Optional, Union
from langchain_core.language_models.llms import LLM
from langchain_core.messages import AIMessage, BaseMessage, HumanMessage, SystemMessage
@ -10,7 +10,7 @@ from langflow.base.constants import STREAM_INFO_TEXT
from langflow.custom import Component
from langflow.field_typing import LanguageModel
from langflow.inputs import MessageInput, MessageTextInput
from langflow.inputs.inputs import InputTypes, BoolInput
from langflow.inputs.inputs import BoolInput, InputTypes
from langflow.schema.message import Message
from langflow.template.field.base import Output
@ -164,7 +164,7 @@ class LCModelComponent(Component):
inputs: Union[list, dict] = messages or {}
try:
runnable = runnable.with_config( # type: ignore
{"run_name": self.display_name, "project_name": self.tracing_service.project_name} # type: ignore
{"run_name": self.display_name, "project_name": self._tracing_service.project_name} # type: ignore
)
if stream:
return runnable.stream(inputs) # type: ignore

View file

@ -23,6 +23,6 @@ class ListenComponent(CustomComponent):
return state
def _set_successors_ids(self):
self.vertex.is_state = True
successors = self.vertex.graph.successor_map.get(self.vertex.id, [])
return successors + self.vertex.graph.activated_vertices
self._vertex.is_state = True
successors = self._vertex.graph.successor_map.get(self._vertex.id, [])
return successors + self._vertex.graph.activated_vertices

View file

@ -43,6 +43,6 @@ class NotifyComponent(CustomComponent):
return data
def _set_successors_ids(self):
self.vertex.is_state = True
successors = self.vertex.graph.successor_map.get(self.vertex.id, [])
return successors + self.vertex.graph.activated_vertices
self._vertex.is_state = True
successors = self._vertex.graph.successor_map.get(self._vertex.id, [])
return successors + self._vertex.graph.activated_vertices

View file

@ -1,6 +1,7 @@
import operator
import warnings
from typing import Any, ClassVar, Optional
from uuid import UUID
from cachetools import TTLCache, cachedmethod
from fastapi import HTTPException
@ -27,7 +28,8 @@ class BaseComponent:
"""The code of the component. Defaults to None."""
_function_entrypoint_name: str = "build"
field_config: dict = {}
_user_id: Optional[str]
_user_id: Optional[str | UUID] = None
_template_config: dict = {}
def __init__(self, **data):
self.cache = TTLCache(maxsize=1024, ttl=60)
@ -38,7 +40,7 @@ class BaseComponent:
setattr(self, key, value)
def __setattr__(self, key, value):
if key == "_user_id" and hasattr(self, "_user_id"):
if key == "_user_id" and hasattr(self, "_user_id") and getattr(self, "_user_id") is not None:
warnings.warn("user_id is immutable and cannot be changed.")
super().__setattr__(key, value)
@ -65,6 +67,25 @@ class BaseComponent:
return validate.create_function(self._code, self._function_entrypoint_name)
@staticmethod
def get_template_config(component):
"""
Gets the template configuration for the custom component itself.
"""
template_config = {}
for attribute, func in ATTR_FUNC_MAPPING.items():
if hasattr(component, attribute):
value = getattr(component, attribute)
if value is not None:
template_config[attribute] = func(value=value)
for key in template_config.copy():
if key not in ATTR_FUNC_MAPPING.keys():
template_config.pop(key, None)
return template_config
def build_template_config(self) -> dict:
"""
Builds the template configuration for the custom component.
@ -77,18 +98,7 @@ class BaseComponent:
cc_class = eval_custom_component_code(self._code)
component_instance = cc_class()
template_config = {}
for attribute, func in ATTR_FUNC_MAPPING.items():
if hasattr(component_instance, attribute):
value = getattr(component_instance, attribute)
if value is not None:
template_config[attribute] = func(value=value)
for key in template_config.copy():
if key not in ATTR_FUNC_MAPPING.keys():
template_config.pop(key, None)
template_config = self.get_template_config(component_instance)
return template_config
def build(self, *args: Any, **kwargs: Any) -> Any:

View file

@ -1,10 +1,12 @@
import inspect
from typing import Any, Callable, ClassVar, List, Optional, Union
from typing import Any, Callable, ClassVar, List, Optional, Union, get_type_hints
from uuid import UUID
import nanoid # type: ignore
import yaml
from pydantic import BaseModel
from langflow.helpers.custom import format_type
from langflow.inputs.inputs import InputTypes
from langflow.schema.artifact import get_artifact_type, post_process_raw
from langflow.schema.data import Data
@ -14,6 +16,8 @@ from langflow.template.field.base import UNDEFINED, Output
from .custom_component import CustomComponent
BACKWARDS_COMPATIBLE_ATTRIBUTES = ["user_id", "vertex", "tracing_service"]
class Component(CustomComponent):
inputs: List[InputTypes] = []
@ -21,24 +25,45 @@ class Component(CustomComponent):
code_class_base_inheritance: ClassVar[str] = "Component"
_output_logs: dict[str, Log] = {}
def __init__(self, **data):
def __init__(self, **kwargs):
# if key starts with _ it is a config
# else it is an input
inputs = {}
config = {}
for key, value in kwargs.items():
if key.startswith("_"):
config[key] = value
else:
inputs[key] = value
self._inputs: dict[str, InputTypes] = {}
self._outputs: dict[str, Output] = {}
self._results: dict[str, Any] = {}
self._attributes: dict[str, Any] = {}
self._parameters: dict[str, Any] = {}
self._parameters = inputs or {}
self._components: list[Component] = []
self.set_attributes(self._parameters)
self._output_logs = {}
super().__init__(**data)
config = config or {}
if "_id" not in config:
config |= {"_id": f"{self.__class__.__name__}-{nanoid.generate(size=5)}"}
super().__init__(**config)
if hasattr(self, "_trace_type"):
self.trace_type = self._trace_type
if not hasattr(self, "trace_type"):
self.trace_type = "chain"
if self.inputs is not None:
self.map_inputs(self.inputs)
self.set_attributes(self._parameters)
if self.outputs is not None:
self.map_outputs(self.outputs)
self._set_output_types()
def __getattr__(self, name: str) -> Any:
if "_attributes" in self.__dict__ and name in self.__dict__["_attributes"]:
return self.__dict__["_attributes"][name]
if "_inputs" in self.__dict__ and name in self.__dict__["_inputs"]:
return self.__dict__["_inputs"][name].value
if name in BACKWARDS_COMPATIBLE_ATTRIBUTES:
return self.__dict__[f"_{name}"]
raise AttributeError(f"{name} not found in {self.__class__.__name__}")
def map_inputs(self, inputs: List[InputTypes]):
@ -48,10 +73,47 @@ class Component(CustomComponent):
raise ValueError("Input name cannot be None.")
self._inputs[input_.name] = input_
def map_outputs(self, outputs: List[Output]):
"""
Maps the given list of outputs to the component.
Args:
outputs (List[Output]): The list of outputs to be mapped.
Raises:
ValueError: If the output name is None.
Returns:
None
"""
self.outputs = outputs
for output in outputs:
if output.name is None:
raise ValueError("Output name cannot be None.")
self._outputs[output.name] = output
def validate(self, params: dict):
self._validate_inputs(params)
self._validate_outputs()
def _set_output_types(self):
for output in self.outputs:
return_types = self._get_method_return_type(output.method)
output.add_types(return_types)
output.set_selected()
def _get_method_return_type(self, method_name: str) -> List[str]:
method = getattr(self, method_name)
return_type = get_type_hints(method)["return"]
extracted_return_types = self._extract_return_type(return_type)
return [format_type(extracted_return_type) for extracted_return_type in extracted_return_types]
def _get_output_by_method(self, method: Callable):
# method is a callable and output.method is a string
# we need to find the output that has the same method
output = next((output for output in self.outputs if output.method == method.__name__), None)
if output is None:
method_name = method.__name__ if hasattr(method, "__name__") else str(method)
raise ValueError(f"Output with method {method_name} not found")
return output
def _validate_outputs(self):
# Raise Error if some rule isn't met
pass
@ -106,9 +168,9 @@ class Component(CustomComponent):
async def _build_with_tracing(self):
inputs = self.get_trace_as_inputs()
metadata = self.get_trace_as_metadata()
async with self.tracing_service.trace_context(self, self.trace_name, inputs, metadata):
async with self._tracing_service.trace_context(self, self.trace_name, inputs, metadata):
_results, _artifacts = await self._build_results()
self.tracing_service.set_outputs(self.trace_name, _results)
self._tracing_service.set_outputs(self.trace_name, _results)
return _results, _artifacts
@ -116,7 +178,7 @@ class Component(CustomComponent):
return await self._build_results()
async def build_results(self):
if self.tracing_service:
if self._tracing_service:
return await self._build_with_tracing()
return await self._build_without_tracing()
@ -124,11 +186,11 @@ class Component(CustomComponent):
_results = {}
_artifacts = {}
if hasattr(self, "outputs"):
self._set_outputs(self.vertex.outputs)
self._set_outputs(self._vertex.outputs)
for output in self.outputs:
# Build the output if it's connected to some other vertex
# or if it's not connected to any vertex
if not self.vertex.outgoing_edges or output.name in self.vertex.edges_source_names:
if not self._vertex.outgoing_edges or output.name in self._vertex.edges_source_names:
if output.method is None:
raise ValueError(f"Output {output.name} does not have a method defined.")
method: Callable = getattr(self, output.method)
@ -142,9 +204,9 @@ class Component(CustomComponent):
if (
isinstance(result, Message)
and result.flow_id is None
and self.vertex.graph.flow_id is not None
and self._vertex.graph.flow_id is not None
):
result.set_flow_id(self.vertex.graph.flow_id)
result.set_flow_id(self._vertex.graph.flow_id)
_results[output.name] = result
output.value = result
custom_repr = self.custom_repr()
@ -176,8 +238,8 @@ class Component(CustomComponent):
self._logs = []
self._artifacts = _artifacts
self._results = _results
if self.tracing_service:
self.tracing_service.set_outputs(self.trace_name, _results)
if self._tracing_service:
self._tracing_service.set_outputs(self.trace_name, _results)
return _results, _artifacts
def custom_repr(self):

View file

@ -1,6 +1,5 @@
from pathlib import Path
from typing import TYPE_CHECKING, Any, Callable, ClassVar, List, Optional, Sequence, Union
from uuid import UUID
import yaml
from cachetools import TTLCache
@ -14,7 +13,7 @@ from langflow.schema.artifact import get_artifact_type
from langflow.schema.dotdict import dotdict
from langflow.schema.log import LoggableType
from langflow.schema.schema import OutputValue
from langflow.services.deps import get_storage_service, get_tracing_service, get_variable_service, session_scope
from langflow.services.deps import get_storage_service, get_variable_service, session_scope
from langflow.services.storage.service import StorageService
from langflow.services.tracing.schema import Log
from langflow.template.utils import update_frontend_node_with_template_values
@ -72,20 +71,19 @@ class CustomComponent(BaseComponent):
"""The default frozen state of the component. Defaults to False."""
build_parameters: Optional[dict] = None
"""The build parameters of the component. Defaults to None."""
vertex: Optional["Vertex"] = None
_vertex: Optional["Vertex"] = None
"""The edge target parameter of the component. Defaults to None."""
code_class_base_inheritance: ClassVar[str] = "CustomComponent"
function_entrypoint_name: ClassVar[str] = "build"
function: Optional[Callable] = None
repr_value: Optional[Any] = ""
user_id: Optional[Union[UUID, str]] = None
status: Optional[Any] = None
"""The status of the component. This is displayed on the frontend. Defaults to None."""
_flows_data: Optional[List[Data]] = None
_outputs: List[OutputValue] = []
_logs: List[Log] = []
_output_logs: dict[str, Log] = {}
tracing_service: Optional["TracingService"] = None
_tracing_service: Optional["TracingService"] = None
def set_attributes(self, parameters: dict):
pass
@ -94,51 +92,43 @@ class CustomComponent(BaseComponent):
self._parameters = parameters
self.set_attributes(self._parameters)
@classmethod
def initialize(cls, **kwargs):
user_id = kwargs.pop("user_id", None)
vertex = kwargs.pop("vertex", None)
tracing_service = kwargs.pop("tracing_service", get_tracing_service())
params_copy = kwargs.copy()
return cls(user_id=user_id, _parameters=params_copy, vertex=vertex, tracing_service=tracing_service)
@property
def trace_name(self):
return f"{self.display_name} ({self.vertex.id})"
return f"{self.display_name} ({self._vertex.id})"
def update_state(self, name: str, value: Any):
if not self.vertex:
if not self._vertex:
raise ValueError("Vertex is not set")
try:
self.vertex.graph.update_state(name=name, record=value, caller=self.vertex.id)
self._vertex.graph.update_state(name=name, record=value, caller=self._vertex.id)
except Exception as e:
raise ValueError(f"Error updating state: {e}")
def stop(self, output_name: str | None = None):
if not output_name and self.vertex and len(self.vertex.outputs) == 1:
output_name = self.vertex.outputs[0]["name"]
if not output_name and self._vertex and len(self._vertex.outputs) == 1:
output_name = self._vertex.outputs[0]["name"]
elif not output_name:
raise ValueError("You must specify an output name to call stop")
if not self.vertex:
if not self._vertex:
raise ValueError("Vertex is not set")
try:
self.graph.mark_branch(vertex_id=self.vertex.id, output_name=output_name, state="INACTIVE")
self.graph.mark_branch(vertex_id=self._vertex.id, output_name=output_name, state="INACTIVE")
except Exception as e:
raise ValueError(f"Error stopping {self.display_name}: {e}")
def append_state(self, name: str, value: Any):
if not self.vertex:
if not self._vertex:
raise ValueError("Vertex is not set")
try:
self.vertex.graph.append_state(name=name, record=value, caller=self.vertex.id)
self._vertex.graph.append_state(name=name, record=value, caller=self._vertex.id)
except Exception as e:
raise ValueError(f"Error appending state: {e}")
def get_state(self, name: str):
if not self.vertex:
if not self._vertex:
raise ValueError("Vertex is not set")
try:
return self.vertex.graph.get_state(name=name)
return self._vertex.graph.get_state(name=name)
except Exception as e:
raise ValueError(f"Error getting state: {e}")
@ -176,7 +166,7 @@ class CustomComponent(BaseComponent):
@property
def graph(self):
return self.vertex.graph
return self._vertex.graph
def _get_field_order(self):
return self.field_order or list(self.field_config.keys())
@ -277,6 +267,14 @@ class CustomComponent(BaseComponent):
return data_objects
def get_method_return_type(self, method_name: str):
build_method = self.get_method(method_name)
if not build_method or not build_method.get("has_return"):
return []
return_type = build_method["return_type"]
return self._extract_return_type(return_type)
def create_references_from_data(self, data: List[Data], include_data: bool = False) -> str:
"""
Create references from a list of data.
@ -349,12 +347,7 @@ class CustomComponent(BaseComponent):
"""
return self.get_method_return_type(self.function_entrypoint_name)
def get_method_return_type(self, method_name: str):
build_method = self.get_method(method_name)
if not build_method or not build_method.get("has_return"):
return []
return_type = build_method["return_type"]
def _extract_return_type(self, return_type: Any):
if hasattr(return_type, "__origin__") and return_type.__origin__ in [
list,
List,
@ -401,7 +394,9 @@ class CustomComponent(BaseComponent):
Returns:
dict: The template configuration for the custom component.
"""
return self.build_template_config()
if not self._template_config:
self._template_config = self.build_template_config()
return self._template_config
@property
def variables(self):
@ -471,7 +466,7 @@ class CustomComponent(BaseComponent):
async def load_flow(self, flow_id: str, tweaks: Optional[dict] = None) -> "Graph":
if not self._user_id:
raise ValueError("Session is invalid")
return await load_flow(user_id=self._user_id, flow_id=flow_id, tweaks=tweaks)
return await load_flow(user_id=str(self._user_id), flow_id=flow_id, tweaks=tweaks)
async def run_flow(
self,
@ -487,14 +482,14 @@ class CustomComponent(BaseComponent):
flow_id=flow_id,
flow_name=flow_name,
tweaks=tweaks,
user_id=self._user_id,
user_id=str(self._user_id),
)
def list_flows(self) -> List[Data]:
if not self._user_id:
raise ValueError("Session is invalid")
try:
return list_flows(user_id=self._user_id)
return list_flows(user_id=str(self._user_id))
except Exception as e:
raise ValueError(f"Error listing flows: {e}")
@ -522,8 +517,8 @@ class CustomComponent(BaseComponent):
name = f"Log {len(self._logs) + 1}"
log = Log(message=message, type=get_artifact_type(message), name=name)
self._logs.append(log)
if self.tracing_service and self.vertex:
self.tracing_service.add_log(trace_name=self.trace_name, log=log)
if self._tracing_service and self._vertex:
self._tracing_service.add_log(trace_name=self.trace_name, log=log)
def post_code_processing(self, new_frontend_node: dict, current_frontend_node: dict):
"""

View file

@ -283,7 +283,7 @@ def get_component_instance(custom_component: CustomComponent, user_id: Optional[
) from exc
try:
custom_instance = custom_class(user_id=user_id)
custom_instance = custom_class(_user_id=user_id)
return custom_instance
except Exception as exc:
logger.error(f"Error while instantiating custom component: {str(exc)}")
@ -317,7 +317,7 @@ def run_build_config(
) from exc
try:
custom_instance = custom_class(user_id=user_id)
custom_instance = custom_class(_user_id=user_id)
build_config: Dict = custom_instance.build_config()
for field_name, field in build_config.copy().items():
@ -361,14 +361,15 @@ def build_custom_component_template_from_inputs(
custom_component: Union[Component, CustomComponent], user_id: Optional[Union[str, UUID]] = None
):
# The List of Inputs fills the role of the build_config and the entrypoint_args
field_config = custom_component.template_config
cc_instance = get_component_instance(custom_component, user_id=user_id)
field_config = cc_instance.get_template_config(cc_instance)
frontend_node = ComponentFrontendNode.from_inputs(**field_config)
frontend_node = add_code_field(frontend_node, custom_component._code, field_config.get("code", {}))
# But we now need to calculate the return_type of the methods in the outputs
for output in frontend_node.outputs:
if output.types:
continue
return_types = custom_component.get_method_return_type(output.method)
return_types = cc_instance.get_method_return_type(output.method)
return_types = [format_type(return_type) for return_type in return_types]
output.add_types(return_types)
output.set_selected()
@ -376,8 +377,8 @@ def build_custom_component_template_from_inputs(
frontend_node.validate_component()
# ! This should be removed when we have a better way to handle this
frontend_node.set_base_classes_from_outputs()
reorder_fields(frontend_node, custom_component._get_field_order())
cc_instance = get_component_instance(custom_component, user_id=user_id)
reorder_fields(frontend_node, cc_instance._get_field_order())
return frontend_node.to_dict(add_name=False), cc_instance

View file

@ -4037,7 +4037,8 @@
"name": "api_run_model",
"selected": "Data",
"types": [
"Data"
"Data",
"list"
],
"value": "__UNDEFINED__"
},
@ -4048,7 +4049,8 @@
"name": "api_build_tool",
"selected": "Tool",
"types": [
"Tool"
"Tool",
"Sequence"
],
"value": "__UNDEFINED__"
}

View file

@ -2615,7 +2615,8 @@
"name": "api_run_model",
"selected": "Data",
"types": [
"Data"
"Data",
"list"
],
"value": "__UNDEFINED__"
},
@ -2626,7 +2627,8 @@
"name": "api_build_tool",
"selected": "Tool",
"types": [
"Tool"
"Tool",
"Sequence"
],
"value": "__UNDEFINED__"
}

View file

@ -2953,7 +2953,8 @@
"name": "api_run_model",
"selected": "Data",
"types": [
"Data"
"Data",
"list"
],
"value": "__UNDEFINED__"
},
@ -2964,7 +2965,8 @@
"name": "api_build_tool",
"selected": "Tool",
"types": [
"Tool"
"Tool",
"Sequence"
],
"value": "__UNDEFINED__"
}

View file

@ -33,11 +33,11 @@ async def instantiate_class(
custom_params = get_params(vertex.params)
code = custom_params.pop("code")
class_object: Type["CustomComponent" | "Component"] = eval_custom_component_code(code)
custom_component: "CustomComponent" | "Component" = class_object.initialize(
user_id=user_id,
parameters=custom_params,
vertex=vertex,
tracing_service=get_tracing_service(),
custom_component: "CustomComponent" | "Component" = class_object(
_user_id=user_id,
_parameters=custom_params,
_vertex=vertex,
_tracing_service=get_tracing_service(),
)
return custom_component, custom_params
@ -186,9 +186,9 @@ async def build_custom_component(params: dict, custom_component: "CustomComponen
raw = post_process_raw(raw, artifact_type)
artifact = {"repr": custom_repr, "raw": raw, "type": artifact_type}
if custom_component.vertex is not None:
custom_component._artifacts = {custom_component.vertex.outputs[0].get("name"): artifact}
custom_component._results = {custom_component.vertex.outputs[0].get("name"): build_result}
if custom_component._vertex is not None:
custom_component._artifacts = {custom_component._vertex.outputs[0].get("name"): artifact}
custom_component._results = {custom_component._vertex.outputs[0].get("name"): build_result}
return custom_component, build_result, artifact
raise ValueError("Custom component does not have a vertex")

View file

@ -194,8 +194,8 @@ class TracingService(Service):
metadata: Optional[Dict[str, Any]] = None,
):
trace_id = trace_name
if component.vertex:
trace_id = component.vertex.id
if component._vertex:
trace_id = component._vertex.id
trace_type = component.trace_type
self._start_traces(
trace_id,
@ -203,7 +203,7 @@ class TracingService(Service):
trace_type,
self._cleanup_inputs(inputs),
metadata,
component.vertex,
component._vertex,
)
try:
yield self

View file

@ -1,5 +1,8 @@
from enum import Enum
from typing import Any, Callable, GenericAlias, Optional, Union, _GenericAlias, _UnionGenericAlias # type: ignore
from typing import GenericAlias # type: ignore
from typing import _GenericAlias # type: ignore
from typing import _UnionGenericAlias # type: ignore
from typing import Any, Callable, Optional, Union
from pydantic import BaseModel, ConfigDict, Field, field_serializer, field_validator, model_serializer, model_validator
@ -182,6 +185,8 @@ class Output(BaseModel):
def add_types(self, _type: list[Any]):
for type_ in _type:
if self.types and type_ in self.types:
continue
if self.types is None:
self.types = []
self.types.append(type_)