refactor(CustomComponent): make initialization separate from constructor (#2704)

This commit is contained in:
Gabriel Luiz Freitas Almeida 2024-07-15 15:37:10 -03:00 committed by GitHub
commit c8d04af533
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
4 changed files with 26 additions and 10 deletions

View file

@ -40,14 +40,16 @@ class Component(CustomComponent):
code_class_base_inheritance: ClassVar[str] = "Component"
def __init__(self, **data):
super().__init__(**data)
self._inputs: dict[str, InputTypes] = {}
self._results: dict[str, Any] = {}
self._attributes: dict[str, Any] = {}
self._parameters: dict[str, Any] = {}
super().__init__(**data)
if not hasattr(self, "trace_type"):
self.trace_type = "chain"
if self.inputs is not None:
self.map_inputs(self.inputs)
self.set_attributes(self._parameters)
def __getattr__(self, name: str) -> Any:
if "_attributes" in self.__dict__ and name in self.__dict__["_attributes"]:

View file

@ -14,7 +14,7 @@ from langflow.schema.artifact import get_artifact_type
from langflow.schema.dotdict import dotdict
from langflow.schema.log import LoggableType
from langflow.schema.schema import OutputLog
from langflow.services.deps import get_storage_service, get_variable_service, session_scope
from langflow.services.deps import get_storage_service, get_tracing_service, get_variable_service, session_scope
from langflow.services.storage.service import StorageService
from langflow.services.tracing.schema import Log
from langflow.template.utils import update_frontend_node_with_template_values
@ -88,6 +88,21 @@ class CustomComponent(BaseComponent):
_logs: List[Log] = []
tracing_service: Optional["TracingService"] = None
def set_attributes(self, parameters: dict):
pass
def set_parameters(self, parameters: dict):
self._parameters = parameters
self.set_attributes(self._parameters)
@classmethod
def initialize(cls, **kwargs):
user_id = kwargs.pop("user_id", None)
vertex = kwargs.pop("vertex", None)
tracing_service = kwargs.pop("tracing_service", get_tracing_service())
params_copy = kwargs.copy()
return cls(user_id=user_id, _parameters=params_copy, vertex=vertex, tracing_service=tracing_service)
@property
def trace_name(self):
return f"{self.display_name} ({self.vertex.id})"

View file

@ -57,7 +57,7 @@ async def build_component_and_get_results(
params_copy = params.copy()
# Remove code from params
class_object: Type["CustomComponent" | "Component"] = eval_custom_component_code(params_copy.pop("code"))
custom_component: "CustomComponent" | "Component" = class_object(
custom_component: "CustomComponent" | "Component" = class_object.initialize(
user_id=user_id,
parameters=params_copy,
vertex=vertex,
@ -66,12 +66,13 @@ async def build_component_and_get_results(
params_copy = update_params_with_load_from_db_fields(
custom_component, params_copy, vertex.load_from_db_fields, fallback_to_env_vars
)
custom_component.set_parameters(params_copy)
with warnings.catch_warnings():
warnings.filterwarnings("ignore", category=PydanticDeprecatedSince20)
if base_type == "custom_components" and isinstance(custom_component, CustomComponent):
return await build_custom_component(params=params_copy, custom_component=custom_component)
elif base_type == "component" and isinstance(custom_component, Component):
return await build_component(params=params_copy, custom_component=custom_component)
return await build_component(custom_component=custom_component)
else:
raise ValueError(f"Base type {base_type} not found.")
@ -146,11 +147,9 @@ def update_params_with_load_from_db_fields(
async def build_component(
params: dict,
custom_component: "Component",
):
# Now set the params as attributes of the custom_component
custom_component.set_attributes(params)
build_results, artifacts = await custom_component.build_results()
return custom_component, build_results, artifacts

View file

@ -250,17 +250,17 @@ def build_class_constructor(compiled_class, exec_globals, class_name):
exec_globals[class_name] = locals()[class_name]
# Return a function that imports necessary modules and creates an instance of the target class
def build_custom_class(*args, **kwargs):
def build_custom_class():
for module_name, module in exec_globals.items():
if isinstance(module, type(importlib)):
globals()[module_name] = module
instance = exec_globals[class_name](*args, **kwargs)
exec_globals[class_name]
return instance
return exec_globals[class_name]
build_custom_class.__globals__.update(exec_globals)
return build_custom_class
return build_custom_class()
def get_default_imports(code_string):