refactor(CustomComponent): make initialization separate from constructor (#2704)
This commit is contained in:
parent
d93382e90a
commit
c8d04af533
4 changed files with 26 additions and 10 deletions
|
|
@ -40,14 +40,16 @@ class Component(CustomComponent):
|
|||
code_class_base_inheritance: ClassVar[str] = "Component"
|
||||
|
||||
def __init__(self, **data):
|
||||
super().__init__(**data)
|
||||
self._inputs: dict[str, InputTypes] = {}
|
||||
self._results: dict[str, Any] = {}
|
||||
self._attributes: dict[str, Any] = {}
|
||||
self._parameters: dict[str, Any] = {}
|
||||
super().__init__(**data)
|
||||
if not hasattr(self, "trace_type"):
|
||||
self.trace_type = "chain"
|
||||
if self.inputs is not None:
|
||||
self.map_inputs(self.inputs)
|
||||
self.set_attributes(self._parameters)
|
||||
|
||||
def __getattr__(self, name: str) -> Any:
|
||||
if "_attributes" in self.__dict__ and name in self.__dict__["_attributes"]:
|
||||
|
|
|
|||
|
|
@ -14,7 +14,7 @@ from langflow.schema.artifact import get_artifact_type
|
|||
from langflow.schema.dotdict import dotdict
|
||||
from langflow.schema.log import LoggableType
|
||||
from langflow.schema.schema import OutputLog
|
||||
from langflow.services.deps import get_storage_service, get_variable_service, session_scope
|
||||
from langflow.services.deps import get_storage_service, get_tracing_service, get_variable_service, session_scope
|
||||
from langflow.services.storage.service import StorageService
|
||||
from langflow.services.tracing.schema import Log
|
||||
from langflow.template.utils import update_frontend_node_with_template_values
|
||||
|
|
@ -88,6 +88,21 @@ class CustomComponent(BaseComponent):
|
|||
_logs: List[Log] = []
|
||||
tracing_service: Optional["TracingService"] = None
|
||||
|
||||
def set_attributes(self, parameters: dict):
|
||||
pass
|
||||
|
||||
def set_parameters(self, parameters: dict):
|
||||
self._parameters = parameters
|
||||
self.set_attributes(self._parameters)
|
||||
|
||||
@classmethod
|
||||
def initialize(cls, **kwargs):
|
||||
user_id = kwargs.pop("user_id", None)
|
||||
vertex = kwargs.pop("vertex", None)
|
||||
tracing_service = kwargs.pop("tracing_service", get_tracing_service())
|
||||
params_copy = kwargs.copy()
|
||||
return cls(user_id=user_id, _parameters=params_copy, vertex=vertex, tracing_service=tracing_service)
|
||||
|
||||
@property
|
||||
def trace_name(self):
|
||||
return f"{self.display_name} ({self.vertex.id})"
|
||||
|
|
|
|||
|
|
@ -57,7 +57,7 @@ async def build_component_and_get_results(
|
|||
params_copy = params.copy()
|
||||
# Remove code from params
|
||||
class_object: Type["CustomComponent" | "Component"] = eval_custom_component_code(params_copy.pop("code"))
|
||||
custom_component: "CustomComponent" | "Component" = class_object(
|
||||
custom_component: "CustomComponent" | "Component" = class_object.initialize(
|
||||
user_id=user_id,
|
||||
parameters=params_copy,
|
||||
vertex=vertex,
|
||||
|
|
@ -66,12 +66,13 @@ async def build_component_and_get_results(
|
|||
params_copy = update_params_with_load_from_db_fields(
|
||||
custom_component, params_copy, vertex.load_from_db_fields, fallback_to_env_vars
|
||||
)
|
||||
custom_component.set_parameters(params_copy)
|
||||
with warnings.catch_warnings():
|
||||
warnings.filterwarnings("ignore", category=PydanticDeprecatedSince20)
|
||||
if base_type == "custom_components" and isinstance(custom_component, CustomComponent):
|
||||
return await build_custom_component(params=params_copy, custom_component=custom_component)
|
||||
elif base_type == "component" and isinstance(custom_component, Component):
|
||||
return await build_component(params=params_copy, custom_component=custom_component)
|
||||
return await build_component(custom_component=custom_component)
|
||||
else:
|
||||
raise ValueError(f"Base type {base_type} not found.")
|
||||
|
||||
|
|
@ -146,11 +147,9 @@ def update_params_with_load_from_db_fields(
|
|||
|
||||
|
||||
async def build_component(
|
||||
params: dict,
|
||||
custom_component: "Component",
|
||||
):
|
||||
# Now set the params as attributes of the custom_component
|
||||
custom_component.set_attributes(params)
|
||||
build_results, artifacts = await custom_component.build_results()
|
||||
|
||||
return custom_component, build_results, artifacts
|
||||
|
|
|
|||
|
|
@ -250,17 +250,17 @@ def build_class_constructor(compiled_class, exec_globals, class_name):
|
|||
exec_globals[class_name] = locals()[class_name]
|
||||
|
||||
# Return a function that imports necessary modules and creates an instance of the target class
|
||||
def build_custom_class(*args, **kwargs):
|
||||
def build_custom_class():
|
||||
for module_name, module in exec_globals.items():
|
||||
if isinstance(module, type(importlib)):
|
||||
globals()[module_name] = module
|
||||
|
||||
instance = exec_globals[class_name](*args, **kwargs)
|
||||
exec_globals[class_name]
|
||||
|
||||
return instance
|
||||
return exec_globals[class_name]
|
||||
|
||||
build_custom_class.__globals__.update(exec_globals)
|
||||
return build_custom_class
|
||||
return build_custom_class()
|
||||
|
||||
|
||||
def get_default_imports(code_string):
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue