fix: makes outputs be correctly retrieved from edge (#3392)

* feat: Add optional target handle name in get_result method.

* fix: Improve logic to consider target handle name in ComponentVertex.

Fixes #3380
This commit is contained in:
Gabriel Luiz Freitas Almeida 2024-08-16 16:11:59 -03:00 committed by GitHub
commit 16afd44295
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
2 changed files with 12 additions and 8 deletions

View file

@ -568,7 +568,7 @@ class Vertex:
if not self._is_vertex(value):
self.params[key][sub_key] = value
else:
result = await value.get_result(self)
result = await value.get_result(self, target_handle_name=key)
self.params[key][sub_key] = result
def _is_vertex(self, value):
@ -583,7 +583,7 @@ class Vertex:
"""
return all(self._is_vertex(vertex) for vertex in value)
async def get_result(self, requester: "Vertex") -> Any:
async def get_result(self, requester: "Vertex", target_handle_name: Optional[str] = None) -> Any:
"""
Retrieves the result of the vertex.
@ -593,9 +593,9 @@ class Vertex:
The result of the vertex.
"""
async with self._lock:
return await self._get_result(requester)
return await self._get_result(requester, target_handle_name)
async def _get_result(self, requester: "Vertex") -> Any:
async def _get_result(self, requester: "Vertex", target_handle_name: Optional[str] = None) -> Any:
"""
Retrieves the result of the built component.
@ -620,7 +620,7 @@ class Vertex:
Builds a given vertex and updates the params dictionary accordingly.
"""
result = await vertex.get_result(self)
result = await vertex.get_result(self, target_handle_name=key)
self._handle_func(key, result)
if isinstance(result, list):
self._extend_params_list_with_result(key, result)
@ -636,7 +636,7 @@ class Vertex:
"""
self.params[key] = []
for vertex in vertices:
result = await vertex.get_result(self)
result = await vertex.get_result(self, target_handle_name=key)
# Weird check to see if the params[key] is a list
# because sometimes it is a Data and breaks the code
if not isinstance(self.params[key], list):

View file

@ -84,7 +84,7 @@ class ComponentVertex(Vertex):
if edge.target_id == target_id:
yield edge
async def _get_result(self, requester: "Vertex") -> Any:
async def _get_result(self, requester: "Vertex", target_handle_name: str | None = None) -> Any:
"""
Retrieves the result of the built component.
@ -114,7 +114,11 @@ class ComponentVertex(Vertex):
edges = self.get_edge_with_target(requester.id)
result = UNDEFINED
for edge in edges:
if edge is not None and edge.source_handle.name in self.results:
if (
edge is not None
and edge.source_handle.name in self.results
and edge.target_handle.field_name == target_handle_name
):
# Get the result from the output instead of the results dict
try:
output = self.get_output(edge.source_handle.name)