fix: makes outputs be correctly retrieved from edge (#3392)
* feat: Add optional target handle name in get_result method. * fix: Improve logic to consider target handle name in ComponentVertex. Fixes #3380
This commit is contained in:
parent
c00e687ec1
commit
16afd44295
2 changed files with 12 additions and 8 deletions
|
|
@ -568,7 +568,7 @@ class Vertex:
|
|||
if not self._is_vertex(value):
|
||||
self.params[key][sub_key] = value
|
||||
else:
|
||||
result = await value.get_result(self)
|
||||
result = await value.get_result(self, target_handle_name=key)
|
||||
self.params[key][sub_key] = result
|
||||
|
||||
def _is_vertex(self, value):
|
||||
|
|
@ -583,7 +583,7 @@ class Vertex:
|
|||
"""
|
||||
return all(self._is_vertex(vertex) for vertex in value)
|
||||
|
||||
async def get_result(self, requester: "Vertex") -> Any:
|
||||
async def get_result(self, requester: "Vertex", target_handle_name: Optional[str] = None) -> Any:
|
||||
"""
|
||||
Retrieves the result of the vertex.
|
||||
|
||||
|
|
@ -593,9 +593,9 @@ class Vertex:
|
|||
The result of the vertex.
|
||||
"""
|
||||
async with self._lock:
|
||||
return await self._get_result(requester)
|
||||
return await self._get_result(requester, target_handle_name)
|
||||
|
||||
async def _get_result(self, requester: "Vertex") -> Any:
|
||||
async def _get_result(self, requester: "Vertex", target_handle_name: Optional[str] = None) -> Any:
|
||||
"""
|
||||
Retrieves the result of the built component.
|
||||
|
||||
|
|
@ -620,7 +620,7 @@ class Vertex:
|
|||
Builds a given vertex and updates the params dictionary accordingly.
|
||||
"""
|
||||
|
||||
result = await vertex.get_result(self)
|
||||
result = await vertex.get_result(self, target_handle_name=key)
|
||||
self._handle_func(key, result)
|
||||
if isinstance(result, list):
|
||||
self._extend_params_list_with_result(key, result)
|
||||
|
|
@ -636,7 +636,7 @@ class Vertex:
|
|||
"""
|
||||
self.params[key] = []
|
||||
for vertex in vertices:
|
||||
result = await vertex.get_result(self)
|
||||
result = await vertex.get_result(self, target_handle_name=key)
|
||||
# Weird check to see if the params[key] is a list
|
||||
# because sometimes it is a Data and breaks the code
|
||||
if not isinstance(self.params[key], list):
|
||||
|
|
|
|||
|
|
@ -84,7 +84,7 @@ class ComponentVertex(Vertex):
|
|||
if edge.target_id == target_id:
|
||||
yield edge
|
||||
|
||||
async def _get_result(self, requester: "Vertex") -> Any:
|
||||
async def _get_result(self, requester: "Vertex", target_handle_name: str | None = None) -> Any:
|
||||
"""
|
||||
Retrieves the result of the built component.
|
||||
|
||||
|
|
@ -114,7 +114,11 @@ class ComponentVertex(Vertex):
|
|||
edges = self.get_edge_with_target(requester.id)
|
||||
result = UNDEFINED
|
||||
for edge in edges:
|
||||
if edge is not None and edge.source_handle.name in self.results:
|
||||
if (
|
||||
edge is not None
|
||||
and edge.source_handle.name in self.results
|
||||
and edge.target_handle.field_name == target_handle_name
|
||||
):
|
||||
# Get the result from the output instead of the results dict
|
||||
try:
|
||||
output = self.get_output(edge.source_handle.name)
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue