From decb741e1319bd0497b210494748254c163a3121 Mon Sep 17 00:00:00 2001 From: ogabrielluiz Date: Tue, 18 Jun 2024 20:47:58 -0300 Subject: [PATCH] =?UTF-8?q?=F0=9F=90=9B=20(types.py):=20Fix=20get=5Fedge?= =?UTF-8?q?=5Fwith=5Ftarget=20method=20to=20return=20a=20Generator=20inste?= =?UTF-8?q?ad=20of=20a=20single=20ContractEdge=20instance?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit 🐛 (types.py): Fix _get_result method to handle multiple edges with the same target_id 📝 (types.py): Add comments to clarify the purpose of setting the result in the vertex of origin --- .../base/langflow/graph/vertex/types.py | 36 +++++++++++++------ 1 file changed, 25 insertions(+), 11 deletions(-) diff --git a/src/backend/base/langflow/graph/vertex/types.py b/src/backend/base/langflow/graph/vertex/types.py index d6de9158f..4145a69e5 100644 --- a/src/backend/base/langflow/graph/vertex/types.py +++ b/src/backend/base/langflow/graph/vertex/types.py @@ -1,5 +1,5 @@ import json -from typing import TYPE_CHECKING, Any, AsyncIterator, Dict, Iterator, List +from typing import TYPE_CHECKING, Any, AsyncIterator, Dict, Generator, Iterator, List import yaml from langchain_core.messages import AIMessage, AIMessageChunk @@ -12,6 +12,7 @@ from langflow.schema import Data from langflow.schema.artifact import ArtifactType from langflow.schema.schema import INPUT_FIELD_NAME from langflow.services.monitor.utils import log_transaction, log_vertex_build +from langflow.template.field.base import UNDEFINED from langflow.utils.schemas import ChatOutputResponse, DataOutputResponse from langflow.utils.util import unescape_string @@ -54,7 +55,7 @@ class ComponentVertex(Vertex): for key, value in self._built_object.items(): self.add_result(key, value) - def get_edge_with_target(self, target_id: str) -> "ContractEdge": + def get_edge_with_target(self, target_id: str) -> Generator["ContractEdge", None, None]: """ Get the edge with the target id. @@ -66,8 +67,7 @@ class ComponentVertex(Vertex): """ for edge in self.edges: if edge.target_id == target_id: - return edge - return None + yield edge async def _get_result(self, requester: "Vertex") -> Any: """ @@ -85,13 +85,20 @@ class ComponentVertex(Vertex): if requester is None: raise ValueError("Requester Vertex is None") - edge = self.get_edge_with_target(requester.id) - if edge is None: - raise ValueError(f"Edge not found between {self.display_name} and {requester.display_name}") - if edge.source_handle.name not in self.results: - raise ValueError(f"Result not found for {edge.source_handle.name}. Results: {self.results}") - result = self.results[edge.source_handle.name] - + edges = self.get_edge_with_target(requester.id) + result = UNDEFINED + edge = None + for edge in edges: + if edge is not None and edge.source_handle.name in self.results: + result = self.results[edge.source_handle.name] + break + if result is UNDEFINED: + if edge is None: + raise ValueError(f"Edge not found between {self.display_name} and {requester.display_name}") + elif edge.source_handle.name not in self.results: + raise ValueError(f"Result not found for {edge.source_handle.name}. Results: {self.results}") + else: + raise ValueError(f"Result not found for {edge.source_handle.name}") log_transaction(source=self, target=requester, flow_id=self.graph.flow_id, status="success") return result @@ -341,6 +348,13 @@ class InterfaceVertex(ComponentVertex): # and remove the stream_url self._finalize_build() logger.debug(f"Streamed message: {complete_message}") + # Set the result in the vertex of origin + edges = self.get_edge_with_target(self.id) + for edge in edges: + origin_vertex = self.graph.get_vertex(edge.source_id) + for key, value in origin_vertex.results.items(): + if isinstance(value, (AsyncIterator, Iterator)): + origin_vertex.results[key] = complete_message await log_vertex_build( flow_id=self.graph.flow_id,