refactor: vertex instantiation (#2703)
* style: handle whitespaces around colons * refactor: split a func into two * refactor: update code to use newly created funcs * refactor: merge code of one func into another * refactor: rename func * refactor: extract code segment to parent func * refactor: extract code segment to parent func * refactor: rename func * refactor: rename object * refactor: extract code segment into a new func * feat: add condition to determine how the vertex is built * fix: modify component initialization call
This commit is contained in:
parent
0ac67a9e89
commit
e318694366
2 changed files with 44 additions and 41 deletions
|
|
@ -432,7 +432,18 @@ class Vertex:
|
|||
"""
|
||||
logger.debug(f"Building {self.display_name}")
|
||||
await self._build_each_vertex_in_params_dict(user_id)
|
||||
await self._get_and_instantiate_class(user_id, fallback_to_env_vars)
|
||||
|
||||
if self.base_type is None:
|
||||
raise ValueError(f"Base type for vertex {self.display_name} not found")
|
||||
|
||||
if not self._custom_component:
|
||||
custom_component, custom_params = await loading.instantiate_class(user_id=user_id, vertex=self)
|
||||
else:
|
||||
custom_component = self._custom_component
|
||||
custom_params = loading.get_params(self.params)
|
||||
|
||||
await self._build_results(custom_component, custom_params, fallback_to_env_vars)
|
||||
|
||||
self._validate_built_object()
|
||||
|
||||
self._built = True
|
||||
|
|
@ -617,7 +628,7 @@ class Vertex:
|
|||
logger.exception(e)
|
||||
raise ValueError(
|
||||
f"Params {key} ({self.params[key]}) is not a list and cannot be extended with {result}"
|
||||
f"Error building Component {self.display_name}:\n\n{str(e)}"
|
||||
f"Error building Component {self.display_name}: \n\n{str(e)}"
|
||||
) from e
|
||||
|
||||
def _handle_func(self, key, result):
|
||||
|
|
@ -642,25 +653,23 @@ class Vertex:
|
|||
if isinstance(self.params[key], list):
|
||||
self.params[key].extend(result)
|
||||
|
||||
async def _get_and_instantiate_class(self, user_id=None, fallback_to_env_vars=False):
|
||||
"""
|
||||
Gets the class from a dictionary and instantiates it with the params.
|
||||
"""
|
||||
if self.base_type is None:
|
||||
raise ValueError(f"Base type for vertex {self.display_name} not found")
|
||||
async def _build_results(self, custom_component, custom_params, fallback_to_env_vars=False):
|
||||
try:
|
||||
result = await loading.instantiate_class(
|
||||
user_id=user_id,
|
||||
fallback_to_env_vars=fallback_to_env_vars,
|
||||
result = await loading.get_instance_results(
|
||||
custom_component=custom_component,
|
||||
custom_params=custom_params,
|
||||
vertex=self,
|
||||
fallback_to_env_vars=fallback_to_env_vars,
|
||||
base_type=self.base_type,
|
||||
)
|
||||
|
||||
self.outputs_logs = build_output_logs(self, result)
|
||||
|
||||
self._update_built_object_and_artifacts(result)
|
||||
except Exception as exc:
|
||||
tb = traceback.format_exc()
|
||||
logger.exception(exc)
|
||||
raise ComponentBuildException(f"Error building Component {self.display_name}:\n\n{exc}", tb) from exc
|
||||
raise ComponentBuildException(f"Error building Component {self.display_name}: \n\n{exc}", tb) from exc
|
||||
|
||||
def _update_built_object_and_artifacts(self, result: Any | tuple[Any, dict] | tuple["Component", Any, dict]):
|
||||
"""
|
||||
|
|
|
|||
|
|
@ -15,68 +15,60 @@ from langflow.services.deps import get_tracing_service
|
|||
|
||||
if TYPE_CHECKING:
|
||||
from langflow.graph.vertex.base import Vertex
|
||||
from langflow.services.tracing.service import TracingService
|
||||
|
||||
|
||||
async def instantiate_class(
|
||||
vertex: "Vertex",
|
||||
fallback_to_env_vars,
|
||||
user_id=None,
|
||||
) -> Any:
|
||||
"""Instantiate class from module type and key, and params"""
|
||||
|
||||
vertex_type = vertex.vertex_type
|
||||
base_type = vertex.base_type
|
||||
params = vertex.params
|
||||
params = convert_params_to_sets(params)
|
||||
params = convert_kwargs(params)
|
||||
logger.debug(f"Instantiating {vertex_type} of type {base_type}")
|
||||
|
||||
if not base_type:
|
||||
raise ValueError("No base type provided for vertex")
|
||||
|
||||
custom_component, build_results, artifacts = await build_component_and_get_results(
|
||||
params=params,
|
||||
vertex=vertex,
|
||||
custom_params = get_params(vertex.params)
|
||||
code = custom_params.pop("code")
|
||||
class_object: Type["CustomComponent" | "Component"] = eval_custom_component_code(code)
|
||||
custom_component: "CustomComponent" | "Component" = class_object.initialize(
|
||||
user_id=user_id,
|
||||
parameters=custom_params,
|
||||
vertex=vertex,
|
||||
tracing_service=get_tracing_service(),
|
||||
fallback_to_env_vars=fallback_to_env_vars,
|
||||
base_type=base_type,
|
||||
)
|
||||
return custom_component, build_results, artifacts
|
||||
return custom_component, custom_params
|
||||
|
||||
|
||||
async def build_component_and_get_results(
|
||||
params: dict,
|
||||
async def get_instance_results(
|
||||
custom_component,
|
||||
custom_params: dict,
|
||||
vertex: "Vertex",
|
||||
user_id: str,
|
||||
tracing_service: "TracingService",
|
||||
fallback_to_env_vars: bool = False,
|
||||
base_type: str = "component",
|
||||
):
|
||||
params_copy = params.copy()
|
||||
# Remove code from params
|
||||
class_object: Type["CustomComponent" | "Component"] = eval_custom_component_code(params_copy.pop("code"))
|
||||
custom_component: "CustomComponent" | "Component" = class_object.initialize(
|
||||
user_id=user_id,
|
||||
parameters=params_copy,
|
||||
vertex=vertex,
|
||||
tracing_service=tracing_service,
|
||||
custom_params = update_params_with_load_from_db_fields(
|
||||
custom_component, custom_params, vertex.load_from_db_fields, fallback_to_env_vars
|
||||
)
|
||||
params_copy = update_params_with_load_from_db_fields(
|
||||
custom_component, params_copy, vertex.load_from_db_fields, fallback_to_env_vars
|
||||
)
|
||||
custom_component.set_parameters(params_copy)
|
||||
with warnings.catch_warnings():
|
||||
warnings.filterwarnings("ignore", category=PydanticDeprecatedSince20)
|
||||
if base_type == "custom_components" and isinstance(custom_component, CustomComponent):
|
||||
return await build_custom_component(params=params_copy, custom_component=custom_component)
|
||||
return await build_custom_component(params=custom_params, custom_component=custom_component)
|
||||
elif base_type == "component" and isinstance(custom_component, Component):
|
||||
return await build_component(custom_component=custom_component)
|
||||
return await build_component(params=custom_params, custom_component=custom_component)
|
||||
else:
|
||||
raise ValueError(f"Base type {base_type} not found.")
|
||||
|
||||
|
||||
def get_params(vertex_params):
|
||||
params = vertex_params
|
||||
params = convert_params_to_sets(params)
|
||||
params = convert_kwargs(params)
|
||||
return params.copy()
|
||||
|
||||
|
||||
def convert_params_to_sets(params):
|
||||
"""Convert certain params to sets"""
|
||||
if "allowed_special" in params:
|
||||
|
|
@ -147,9 +139,11 @@ def update_params_with_load_from_db_fields(
|
|||
|
||||
|
||||
async def build_component(
|
||||
params: dict,
|
||||
custom_component: "Component",
|
||||
):
|
||||
# Now set the params as attributes of the custom_component
|
||||
custom_component.set_attributes(params)
|
||||
build_results, artifacts = await custom_component.build_results()
|
||||
|
||||
return custom_component, build_results, artifacts
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue