From 04de488edee64d4eb2784ab9622049aca2215563 Mon Sep 17 00:00:00 2001 From: Gabriel Luiz Freitas Almeida Date: Tue, 27 Feb 2024 11:37:44 -0300 Subject: [PATCH] Refactor vertex class and update build process --- src/backend/langflow/graph/vertex/base.py | 59 ++++++++++++++++++---- src/backend/langflow/graph/vertex/types.py | 15 +----- 2 files changed, 51 insertions(+), 23 deletions(-) diff --git a/src/backend/langflow/graph/vertex/base.py b/src/backend/langflow/graph/vertex/base.py index 874d4bd21..e87389c7b 100644 --- a/src/backend/langflow/graph/vertex/base.py +++ b/src/backend/langflow/graph/vertex/base.py @@ -6,7 +6,12 @@ from typing import TYPE_CHECKING, Any, Callable, Coroutine, Dict, List, Optional from loguru import logger -from langflow.graph.schema import InterfaceComponentTypes +from langflow.graph.schema import ( + INPUT_COMPONENTS, + OUTPUT_COMPONENTS, + InterfaceComponentTypes, + ResultData, +) from langflow.graph.utils import UnbuiltObject, UnbuiltResult from langflow.graph.vertex.utils import generate_result from langflow.interface.initialize import loading @@ -16,7 +21,6 @@ from langflow.utils.constants import DIRECT_TYPES from langflow.utils.util import sync_to_async if TYPE_CHECKING: - from langflow.api.v1.schemas import ResultData from langflow.graph.edge.base import ContractEdge from langflow.graph.graph.base import Graph @@ -40,11 +44,19 @@ class Vertex: ) -> None: # is_external means that the Vertex send or receives data from # an external source (e.g the chat) + + self.id: str = data["id"] + self.is_input = any( + input_component_name in self.id for input_component_name in INPUT_COMPONENTS + ) + self.is_output = any( + output_component_name in self.id + for output_component_name in OUTPUT_COMPONENTS + ) self._custom_component = None self.has_external_input = False self.has_external_output = False self.graph = graph - self.id: str = data["id"] self._data = data self.base_type: Optional[str] = base_type self._parse_data() @@ -61,7 +73,7 @@ class Vertex: self.parent_is_top_level = False self.layer = None self.should_run = True - self.result: Optional["ResultData"] = None + self.result: Optional[ResultData] = None try: self.is_interface_component = InterfaceComponentTypes(self.vertex_type) except ValueError: @@ -116,7 +128,7 @@ class Vertex: ) return edge_results - def set_result(self, result: "ResultData") -> None: + def set_result(self, result: ResultData) -> None: self.result = result def get_built_result(self): @@ -203,6 +215,8 @@ class Vertex: self.display_name = self.data["node"]["display_name"] self.pinned = self.data["node"].get("pinned", False) self.selected_output_type = self.data["node"].get("selected_output_type") + self.is_input = self.data["node"].get("is_input") or self.is_input + self.is_output = self.data["node"].get("is_output") or self.is_output template_dicts = { key: value for key, value in self.data["node"]["template"].items() @@ -359,6 +373,21 @@ class Vertex: self.params = params self._raw_params = params.copy() + def update_raw_params(self, new_params: Dict[str, str]): + """ + Update the raw parameters of the vertex with the given new parameters. + + Args: + new_params (Dict[str, Any]): The new parameters to update. + + Raises: + ValueError: If any key in new_params is not found in self._raw_params. + """ + for key in new_params: + if key not in self._raw_params: + raise ValueError(f"Key {key} not found in raw params") + self._raw_params.update(new_params) + async def _build(self, user_id=None): """ Initiate the build process. @@ -370,6 +399,18 @@ class Vertex: self._built = True + def _finalize_build(self): + result_dict = self.get_built_result() + # We need to set the artifacts to pass information + # to the frontend + self.set_artifacts() + artifacts = self.artifacts + result_dict = ResultData( + results=result_dict, + artifacts=artifacts, + ) + self.set_result(result_dict) + async def _run( self, user_id: str, @@ -501,17 +542,13 @@ class Vertex: if self.base_type is None: raise ValueError(f"Base type for node {self.display_name} not found") try: - outgoing_edges = self.graph.get_vertex_edges( - self.id, is_source=True, is_target=False - ) result = await loading.instantiate_class( node_type=self.vertex_type, base_type=self.base_type, params=self.params, user_id=user_id, - outgoing_edges=outgoing_edges, - selected_output_type=self.selected_output_type, + vertex=self, ) self._update_built_object_and_artifacts(result) except Exception as exc: @@ -584,6 +621,8 @@ class Vertex: step(user_id=user_id, **kwargs) self.steps_ran.append(step) + self._finalize_build() + return await self.get_requester_result(requester) async def get_requester_result(self, requester: Optional["Vertex"]): diff --git a/src/backend/langflow/graph/vertex/types.py b/src/backend/langflow/graph/vertex/types.py index b7746b6a1..45bcb9ccd 100644 --- a/src/backend/langflow/graph/vertex/types.py +++ b/src/backend/langflow/graph/vertex/types.py @@ -1,10 +1,10 @@ import ast import json -from typing import AsyncIterator, Callable, Dict, Iterator, List, Optional, Union +from typing import (AsyncIterator, Callable, Dict, Iterator, List, Optional, + Union) import yaml from langchain_core.messages import AIMessage -from loguru import logger from langflow.graph.utils import UnbuiltObject, flatten_list from langflow.graph.vertex.base import StatefulVertex, StatelessVertex @@ -344,17 +344,6 @@ class ChatVertex(StatelessVertex): def build_stream_url(self): return f"/api/v1/build/{self.graph.flow_id}/{self.id}/stream" - async def _build(self, user_id=None): - """ - Initiate the build process. - """ - logger.debug(f"Building {self.vertex_type}") - await self._build_each_node_in_params_dict(user_id) - await self._get_and_instantiate_class(user_id) - self._validate_built_object() - - self._built = True - def _built_object_repr(self): if self.task_id and self.is_task: if task := self.get_task():