From 20d9e51208ffdbee344f26471275c5c543cb3041 Mon Sep 17 00:00:00 2001 From: Gabriel Luiz Freitas Almeida Date: Fri, 26 Jan 2024 17:06:16 -0300 Subject: [PATCH] Change honor method to be asynchronous in ContractEdge --- src/backend/langflow/graph/edge/base.py | 6 ++--- src/backend/langflow/graph/graph/base.py | 28 ++++++++++++---------- src/backend/langflow/graph/vertex/types.py | 27 +++++++++++++++------ 3 files changed, 38 insertions(+), 23 deletions(-) diff --git a/src/backend/langflow/graph/edge/base.py b/src/backend/langflow/graph/edge/base.py index 74dd7c04d..5066f4214 100644 --- a/src/backend/langflow/graph/edge/base.py +++ b/src/backend/langflow/graph/edge/base.py @@ -105,7 +105,7 @@ class ContractEdge(Edge): self.is_fulfilled = False # Whether the contract has been fulfilled. self.result: Any = None - def honor(self, source: "Vertex", target: "Vertex") -> None: + async def honor(self, source: "Vertex", target: "Vertex") -> None: """ Fulfills the contract by setting the result of the source vertex to the target vertex's parameter. If the edge is runnable, the source vertex is run with the message text and the target vertex's @@ -117,7 +117,7 @@ class ContractEdge(Edge): return if not source._built: - source.build() + await source.build() if self.matched_type == "Text": self.result = source._built_result @@ -144,7 +144,7 @@ class ContractEdge(Edge): async def get_result(self, source: "Vertex", target: "Vertex"): # Fulfill the contract if it has not been fulfilled. if not self.is_fulfilled: - self.honor(source, target) + await self.honor(source, target) log_transaction(self, source, target, "success") # If the target vertex is a power component we log messages diff --git a/src/backend/langflow/graph/graph/base.py b/src/backend/langflow/graph/graph/base.py index 17c402c51..288d77d0f 100644 --- a/src/backend/langflow/graph/graph/base.py +++ b/src/backend/langflow/graph/graph/base.py @@ -4,7 +4,7 @@ from typing import Dict, Generator, List, Type, Union from langchain.chains.base import Chain from loguru import logger -from langflow.graph.edge.base import ContractEdge, Edge +from langflow.graph.edge.base import ContractEdge from langflow.graph.graph.constants import lazy_load_vertex_dict from langflow.graph.graph.utils import process_flow from langflow.graph.vertex.base import Vertex @@ -33,6 +33,7 @@ class Graph: self._vertices = self._graph_data["nodes"] self._edges = self._graph_data["edges"] + self._build_graph() def __getstate__(self): @@ -111,7 +112,7 @@ class Graph: """Returns a vertex by id.""" return self.vertex_map.get(vertex_id) - def get_vertex_edges(self, vertex_id: str) -> List[Union[Edge, ContractEdge]]: + def get_vertex_edges(self, vertex_id: str) -> List[ContractEdge]: """Returns a list of edges for a given vertex.""" return [edge for edge in self.edges if edge.source_id == vertex_id or edge.target_id == vertex_id] @@ -210,18 +211,19 @@ class Graph: edges.append(ContractEdge(source, target, edge)) return edges - def _get_vertex_class(self, vertex_type: str, vertex_base_type: str) -> Type[Vertex]: - """Returns the vertex class based on the vertex type.""" - if vertex_type in FILE_TOOLS: - return FileToolVertex - if vertex_base_type == "CustomComponent": - return lazy_load_vertex_dict.get_custom_component_vertex_type() + def _get_vertex_class(self, node_type: str, node_lc_type: str, node_id: str) -> Type[Vertex]: + """Returns the node class based on the node type.""" + node_name = node_id.split("-")[0] + if node_name in lazy_load_vertex_dict.VERTEX_TYPE_MAP: + return lazy_load_vertex_dict.VERTEX_TYPE_MAP[node_name] - if vertex_base_type in lazy_load_vertex_dict.VERTEX_TYPE_MAP: - return lazy_load_vertex_dict.VERTEX_TYPE_MAP[vertex_base_type] + if node_type in FILE_TOOLS: + return FileToolVertex + if node_type in lazy_load_vertex_dict.VERTEX_TYPE_MAP: + return lazy_load_vertex_dict.VERTEX_TYPE_MAP[node_type] return ( - lazy_load_vertex_dict.VERTEX_TYPE_MAP[vertex_type] - if vertex_type in lazy_load_vertex_dict.VERTEX_TYPE_MAP + lazy_load_vertex_dict.VERTEX_TYPE_MAP[node_lc_type] + if node_lc_type in lazy_load_vertex_dict.VERTEX_TYPE_MAP else Vertex ) @@ -233,7 +235,7 @@ class Graph: vertex_type: str = vertex_data["type"] # type: ignore vertex_base_type: str = vertex_data["node"]["template"]["_type"] # type: ignore - VertexClass = self._get_vertex_class(vertex_type, vertex_base_type) + VertexClass = self._get_vertex_class(vertex_type, vertex_base_type, vertex_data["id"]) vertex_instance = VertexClass(vertex, graph=self) vertex_instance.set_top_level(self.top_level_vertices) vertices.append(vertex_instance) diff --git a/src/backend/langflow/graph/vertex/types.py b/src/backend/langflow/graph/vertex/types.py index 178d460bd..f829e5db9 100644 --- a/src/backend/langflow/graph/vertex/types.py +++ b/src/backend/langflow/graph/vertex/types.py @@ -1,6 +1,8 @@ import ast from typing import Callable, Dict, List, Optional, Union +from langchain_core.messages import AIMessage + from langflow.graph.utils import UnbuiltObject, flatten_list from langflow.graph.vertex.base import StatefulVertex, StatelessVertex from langflow.interface.utils import extract_input_variables_from_prompt @@ -318,17 +320,28 @@ class ChatVertex(StatelessVertex): if self.artifacts and "repr" in self.artifacts: return self.artifacts["repr"] or super()._built_object_repr() - def _run(self, *args, **kwargs): + async def _run(self, *args, **kwargs): if self.is_power_component: if self.vertex_type == "ChatOutput": + artifacts = None sender = self.params.get("sender", None) sender_name = self.params.get("sender_name", None) - self.artifacts = ChatOutputResponse( - message=str(self._built_object), - sender=sender, - sender_name=sender_name, - ).model_dump() + message = "" + if isinstance(self._built_object, AIMessage): + artifacts = ChatOutputResponse.from_message( + self._built_object, + sender=sender, + sender_name=sender_name, + ) + elif not isinstance(self._built_object, UnbuiltObject): + artifacts = ChatOutputResponse( + message=message, + sender=sender, + sender_name=sender_name, + ) + if artifacts: + self.artifacts = artifacts.model_dump() self._built_result = self._built_object else: - super()._run(*args, **kwargs) + await super()._run(*args, **kwargs)