refactor: vertex instantiation (#2703)

* style: handle whitespaces around colons

* refactor: split a func into two

* refactor: update code to use newly created funcs

* refactor: merge code of one func into another

* refactor: rename func

* refactor: extract code segment to parent func

* refactor: extract code segment to parent func

* refactor: rename func

* refactor: rename object

* refactor: extract code segment into a new func

* feat: add condition to determine how the vertex is built

* fix: modify component initialization call
This commit is contained in:
Ítalo Johnny 2024-07-24 15:28:53 -03:00 committed by GitHub
commit e318694366
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
2 changed files with 44 additions and 41 deletions

View file

@ -432,7 +432,18 @@ class Vertex:
"""
logger.debug(f"Building {self.display_name}")
await self._build_each_vertex_in_params_dict(user_id)
await self._get_and_instantiate_class(user_id, fallback_to_env_vars)
if self.base_type is None:
raise ValueError(f"Base type for vertex {self.display_name} not found")
if not self._custom_component:
custom_component, custom_params = await loading.instantiate_class(user_id=user_id, vertex=self)
else:
custom_component = self._custom_component
custom_params = loading.get_params(self.params)
await self._build_results(custom_component, custom_params, fallback_to_env_vars)
self._validate_built_object()
self._built = True
@ -617,7 +628,7 @@ class Vertex:
logger.exception(e)
raise ValueError(
f"Params {key} ({self.params[key]}) is not a list and cannot be extended with {result}"
f"Error building Component {self.display_name}:\n\n{str(e)}"
f"Error building Component {self.display_name}: \n\n{str(e)}"
) from e
def _handle_func(self, key, result):
@ -642,25 +653,23 @@ class Vertex:
if isinstance(self.params[key], list):
self.params[key].extend(result)
async def _get_and_instantiate_class(self, user_id=None, fallback_to_env_vars=False):
"""
Gets the class from a dictionary and instantiates it with the params.
"""
if self.base_type is None:
raise ValueError(f"Base type for vertex {self.display_name} not found")
async def _build_results(self, custom_component, custom_params, fallback_to_env_vars=False):
try:
result = await loading.instantiate_class(
user_id=user_id,
fallback_to_env_vars=fallback_to_env_vars,
result = await loading.get_instance_results(
custom_component=custom_component,
custom_params=custom_params,
vertex=self,
fallback_to_env_vars=fallback_to_env_vars,
base_type=self.base_type,
)
self.outputs_logs = build_output_logs(self, result)
self._update_built_object_and_artifacts(result)
except Exception as exc:
tb = traceback.format_exc()
logger.exception(exc)
raise ComponentBuildException(f"Error building Component {self.display_name}:\n\n{exc}", tb) from exc
raise ComponentBuildException(f"Error building Component {self.display_name}: \n\n{exc}", tb) from exc
def _update_built_object_and_artifacts(self, result: Any | tuple[Any, dict] | tuple["Component", Any, dict]):
"""

View file

@ -15,68 +15,60 @@ from langflow.services.deps import get_tracing_service
if TYPE_CHECKING:
from langflow.graph.vertex.base import Vertex
from langflow.services.tracing.service import TracingService
async def instantiate_class(
vertex: "Vertex",
fallback_to_env_vars,
user_id=None,
) -> Any:
"""Instantiate class from module type and key, and params"""
vertex_type = vertex.vertex_type
base_type = vertex.base_type
params = vertex.params
params = convert_params_to_sets(params)
params = convert_kwargs(params)
logger.debug(f"Instantiating {vertex_type} of type {base_type}")
if not base_type:
raise ValueError("No base type provided for vertex")
custom_component, build_results, artifacts = await build_component_and_get_results(
params=params,
vertex=vertex,
custom_params = get_params(vertex.params)
code = custom_params.pop("code")
class_object: Type["CustomComponent" | "Component"] = eval_custom_component_code(code)
custom_component: "CustomComponent" | "Component" = class_object.initialize(
user_id=user_id,
parameters=custom_params,
vertex=vertex,
tracing_service=get_tracing_service(),
fallback_to_env_vars=fallback_to_env_vars,
base_type=base_type,
)
return custom_component, build_results, artifacts
return custom_component, custom_params
async def build_component_and_get_results(
params: dict,
async def get_instance_results(
custom_component,
custom_params: dict,
vertex: "Vertex",
user_id: str,
tracing_service: "TracingService",
fallback_to_env_vars: bool = False,
base_type: str = "component",
):
params_copy = params.copy()
# Remove code from params
class_object: Type["CustomComponent" | "Component"] = eval_custom_component_code(params_copy.pop("code"))
custom_component: "CustomComponent" | "Component" = class_object.initialize(
user_id=user_id,
parameters=params_copy,
vertex=vertex,
tracing_service=tracing_service,
custom_params = update_params_with_load_from_db_fields(
custom_component, custom_params, vertex.load_from_db_fields, fallback_to_env_vars
)
params_copy = update_params_with_load_from_db_fields(
custom_component, params_copy, vertex.load_from_db_fields, fallback_to_env_vars
)
custom_component.set_parameters(params_copy)
with warnings.catch_warnings():
warnings.filterwarnings("ignore", category=PydanticDeprecatedSince20)
if base_type == "custom_components" and isinstance(custom_component, CustomComponent):
return await build_custom_component(params=params_copy, custom_component=custom_component)
return await build_custom_component(params=custom_params, custom_component=custom_component)
elif base_type == "component" and isinstance(custom_component, Component):
return await build_component(custom_component=custom_component)
return await build_component(params=custom_params, custom_component=custom_component)
else:
raise ValueError(f"Base type {base_type} not found.")
def get_params(vertex_params):
params = vertex_params
params = convert_params_to_sets(params)
params = convert_kwargs(params)
return params.copy()
def convert_params_to_sets(params):
"""Convert certain params to sets"""
if "allowed_special" in params:
@ -147,9 +139,11 @@ def update_params_with_load_from_db_fields(
async def build_component(
params: dict,
custom_component: "Component",
):
# Now set the params as attributes of the custom_component
custom_component.set_attributes(params)
build_results, artifacts = await custom_component.build_results()
return custom_component, build_results, artifacts