feat: get result from output if possible (#3338)

* refactor: Extract method to set params from normal edge in Vertex class.

* feat: Add method to retrieve value from template dict in Vertex class.

* feat: Add handling for cycle and contract edge targets in ComponentVertex build method.

* refactor: Update result retrieval logic in ComponentVertex class.

* refactor: Add condition to check flow_id before creating log transactions.

* refactor: Add missing `Edge` import and import `cast` in types.py for better typing.
This commit is contained in:
Gabriel Luiz Freitas Almeida 2024-08-15 13:44:06 -03:00 committed by GitHub
commit 2baee5fef1
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
5 changed files with 115 additions and 47 deletions

View file

@ -102,6 +102,8 @@ class Graph:
self.raw_graph_data: GraphData = {"nodes": [], "edges": []}
self._is_cyclic: Optional[bool] = None
self._cycles: Optional[List[tuple[str, str]]] = None
self._call_order: List[str] = []
self._snapshots: List[Dict[str, Any]] = []
try:
self.tracing_service: "TracingService" | None = get_tracing_service()
except Exception as exc:
@ -1134,6 +1136,14 @@ class Graph:
return vertex
raise ValueError(f"Vertex {vertex_id} is not a top level vertex or no root vertex found")
def get_next_in_queue(self):
    """Pop and return the entry at the front of the run queue.

    Returns:
        The left-most entry of ``self._run_queue``, or ``None`` when the
        queue is empty.
    """
    pending = self._run_queue
    return pending.popleft() if pending else None
def extend_run_queue(self, vertices: List[str]):
    """Append every entry of ``vertices`` to the right end of the run queue.

    Args:
        vertices: Entries to enqueue, in the order they should be appended.
    """
    pending = self._run_queue
    pending.extend(vertices)
async def astep(
self,
inputs: Optional["InputValueRequest"] = None,
@ -1145,7 +1155,7 @@ class Graph:
if not self._run_queue:
asyncio.create_task(self.end_all_traces())
return Finish()
vertex_id = self._run_queue.popleft()
vertex_id = self.get_next_in_queue()
chat_service = get_chat_service()
vertex_build_result = await self.build_vertex(
vertex_id=vertex_id,
@ -1161,13 +1171,31 @@ class Graph:
)
if self.stop_vertex and self.stop_vertex in next_runnable_vertices:
next_runnable_vertices = [self.stop_vertex]
self._run_queue.extend(next_runnable_vertices)
self.extend_run_queue(next_runnable_vertices)
self.reset_inactivated_vertices()
self.reset_activated_vertices()
await chat_service.set_cache(str(self.flow_id or self._run_id), self)
self._record_snapshot(vertex_id)
return vertex_build_result
def get_snapshot(self):
    """Return a deep copy of the graph's current scheduling state.

    The snapshot is a dict with the serialized run manager
    (``run_manager.to_dict()``), the run queue, the vertex layers, the first
    layer, and the inactive/activated vertex collections. Because the whole
    dict is deep-copied, later changes to the graph do not alter a snapshot
    that was already taken.
    """
    state = {
        "run_manager": self.run_manager.to_dict(),
        "run_queue": self._run_queue,
        "vertices_layers": self.vertices_layers,
        "first_layer": self.first_layer,
        "inactive_vertices": self.inactive_vertices,
        "activated_vertices": self.activated_vertices,
    }
    return copy.deepcopy(state)
def _record_snapshot(self, vertex_id: str | None = None, start: bool = False):
self._snapshots.append(self.get_snapshot())
if vertex_id:
self._call_order.append(vertex_id)
def step(
self,
inputs: Optional["InputValueRequest"] = None,
@ -1347,14 +1375,14 @@ class Graph:
return self
def find_next_runnable_vertices(self, vertex_id: str, vertex_successors_ids: List[str]) -> List[str]:
    """Collect the ids that can run next, given a vertex's successors.

    A successor that is runnable is scheduled directly. For a successor that
    is not runnable, the ids returned by
    ``find_runnable_predecessors_for_successor`` are scheduled instead.

    Args:
        vertex_id: Id of the vertex whose successors are inspected. It is
            not read by this method.
        vertex_successors_ids: Ids of the successors to examine, in order.

    Returns:
        The ids to run next, without duplicates, in first-seen order.
    """
    # A dict serves as an insertion-ordered set: duplicates collapse, but the
    # result order stays deterministic instead of following string hashing.
    next_runnable_vertices: dict[str, None] = {}
    for v_id in vertex_successors_ids:
        if self.is_vertex_runnable(v_id):
            next_runnable_vertices[v_id] = None
        else:
            predecessors = self.find_runnable_predecessors_for_successor(v_id)
            next_runnable_vertices.update(dict.fromkeys(predecessors))
    return list(next_runnable_vertices)
async def get_next_runnable_vertices(self, lock: asyncio.Lock, vertex: "Vertex", cache: bool = True) -> List[str]:
v_id = vertex.id
@ -1592,6 +1620,7 @@ class Graph:
self._first_layer = first_layer
self._run_queue = deque(first_layer)
self._prepared = True
self._record_snapshot()
return self
def get_children_by_vertex_type(self, vertex: Vertex, vertex_type: str) -> List[Vertex]:

View file

@ -28,7 +28,7 @@ from langflow.utils.util import sync_to_async, unescape_string
if TYPE_CHECKING:
from langflow.custom import Component
from langflow.graph.edge.base import CycleEdge
from langflow.graph.edge.base import CycleEdge, Edge
from langflow.graph.graph.base import Graph
@ -250,6 +250,12 @@ class Vertex:
self.base_type = base_type
break
def get_value_from_template_dict(self, key: str):
    """Return the ``"value"`` stored for ``key`` in this vertex's node template.

    The template is read from ``self.data["node"]["template"]``; a missing
    ``"node"`` or ``"template"`` level is treated as an empty template.

    Raises:
        ValueError: If ``key`` is not a field of the template.
    """
    node = self.data.get("node", {})
    template = node.get("template", {})
    try:
        field = template[key]
    except KeyError:
        raise ValueError(f"Key {key} not found in template dict") from None
    return field.get("value")
def get_task(self):
# using the task_id, get the task from celery
# and return it
@ -257,6 +263,31 @@ class Vertex:
return AsyncResult(self.task_id)
def _set_params_from_normal_edge(self, params: dict, edge: "Edge", template_dict: dict):
param_key = edge.target_param
# If the param_key is in the template_dict and the edge.target_id is the current node
# We check this to make sure params with the same name but different target_id
# don't get overwritten
if param_key in template_dict and edge.target_id == self.id:
if template_dict[param_key].get("list"):
if param_key not in params:
params[param_key] = []
params[param_key].append(self.graph.get_vertex(edge.source_id))
elif edge.target_id == self.id:
if isinstance(template_dict[param_key].get("value"), dict):
# we don't know the key of the dict but we need to set the value
# to the vertex that is the source of the edge
param_dict = template_dict[param_key]["value"]
if not param_dict or len(param_dict) != 1:
params[param_key] = self.graph.get_vertex(edge.source_id)
else:
params[param_key] = {key: self.graph.get_vertex(edge.source_id) for key in param_dict.keys()}
else:
params[param_key] = self.graph.get_vertex(edge.source_id)
return params
def _build_params(self):
# sourcery skip: merge-list-append, remove-redundant-if
# Some params are required, some are optional
@ -287,30 +318,7 @@ class Vertex:
for edge in self.edges:
if not hasattr(edge, "target_param"):
continue
param_key = edge.target_param
# If the param_key is in the template_dict and the edge.target_id is the current node
# We check this to make sure params with the same name but different target_id
# don't get overwritten
if param_key in template_dict and edge.target_id == self.id:
if template_dict[param_key].get("list"):
if param_key not in params:
params[param_key] = []
params[param_key].append(self.graph.get_vertex(edge.source_id))
elif edge.target_id == self.id:
if isinstance(template_dict[param_key].get("value"), dict):
# we don't know the key of the dict but we need to set the value
# to the vertex that is the source of the edge
param_dict = template_dict[param_key]["value"]
if not param_dict or len(param_dict) != 1:
params[param_key] = self.graph.get_vertex(edge.source_id)
else:
params[param_key] = {
key: self.graph.get_vertex(edge.source_id) for key in param_dict.keys()
}
else:
params[param_key] = self.graph.get_vertex(edge.source_id)
params = self._set_params_from_normal_edge(params, edge, template_dict)
load_from_db_fields = []
for field_name, field in template_dict.items():
@ -598,11 +606,13 @@ class Vertex:
"""
flow_id = self.graph.flow_id
if not self._built:
asyncio.create_task(log_transaction(str(flow_id), source=self, target=requester, status="error"))
if flow_id:
asyncio.create_task(log_transaction(str(flow_id), source=self, target=requester, status="error"))
raise ValueError(f"Component {self.display_name} has not been built yet")
result = self._built_result if self.use_result else self._built_object
asyncio.create_task(log_transaction(str(flow_id), source=self, target=requester, status="success"))
if flow_id:
asyncio.create_task(log_transaction(str(flow_id), source=self, target=requester, status="success"))
return result
async def _build_vertex_and_update_params(self, key, vertex: "Vertex"):

View file

@ -0,0 +1,4 @@
class NoComponentInstance(ValueError):
    """Raised when a vertex has no component instance to operate on.

    It derives from ``ValueError`` so that handlers written for a plain
    ``ValueError`` carrying this same message keep catching it, while new
    code can catch this specific type.

    Attributes are documented inline in ``__init__``.
    """

    def __init__(self, vertex_id: str):
        # Keep the id available so handlers need not parse the message.
        self.vertex_id = vertex_id
        super().__init__(f"Vertex {vertex_id} does not have a component instance.")

View file

@ -1,6 +1,6 @@
import asyncio
import json
from typing import TYPE_CHECKING, Any, AsyncIterator, Dict, Generator, Iterator, List
from typing import TYPE_CHECKING, Any, AsyncIterator, Dict, Generator, Iterator, List, cast
import yaml
from langchain_core.messages import AIMessage, AIMessageChunk
@ -9,6 +9,7 @@ from loguru import logger
from langflow.graph.schema import CHAT_COMPONENTS, RECORDS_COMPONENTS, InterfaceComponentTypes, ResultData
from langflow.graph.utils import UnbuiltObject, log_transaction, log_vertex_build, serialize_field
from langflow.graph.vertex.base import Vertex
from langflow.graph.vertex.exceptions import NoComponentInstance
from langflow.graph.vertex.schema import NodeData
from langflow.inputs.inputs import InputTypes
from langflow.schema import Data
@ -43,7 +44,7 @@ class ComponentVertex(Vertex):
def get_output(self, name: str) -> Output:
    """Return the output called ``name`` from this vertex's component.

    Args:
        name: Name of the output to look up on the component.

    Raises:
        NoComponentInstance: If ``self._custom_component`` is ``None``.
    """
    component = self._custom_component
    if component is None:
        # Raise the dedicated exception (not a bare ValueError) so callers
        # that catch NoComponentInstance can fall back to stored results.
        raise NoComponentInstance(self.id)
    return component.get_output(name)
def _built_object_repr(self):
@ -92,10 +93,19 @@ class ComponentVertex(Vertex):
Returns:
The built result if use_result is True, else the built object.
"""
flow_id = self.graph.flow_id
if not self._built:
asyncio.create_task(
log_transaction(source=self, target=requester, flow_id=str(self.graph.flow_id), status="error")
)
if flow_id:
asyncio.create_task(
log_transaction(source=self, target=requester, flow_id=str(flow_id), status="error")
)
for edge in self.get_edge_with_target(requester.id):
# We need to check if the edge is a normal edge
# or a contract edge
if edge.is_cycle and edge.target_param:
return requester.get_value_from_template_dict(edge.target_param)
raise ValueError(f"Component {self.display_name} has not been built yet")
if requester is None:
@ -103,10 +113,18 @@ class ComponentVertex(Vertex):
edges = self.get_edge_with_target(requester.id)
result = UNDEFINED
edge = None
for edge in edges:
if edge is not None and edge.source_handle.name in self.results:
result = self.results[edge.source_handle.name]
# Get the result from the output instead of the results dict
try:
output = self.get_output(edge.source_handle.name)
if output.value is UNDEFINED:
result = self.results[edge.source_handle.name]
else:
result = cast(Any, output.value)
except NoComponentInstance:
result = self.results[edge.source_handle.name]
break
if result is UNDEFINED:
if edge is None:
@ -114,10 +132,9 @@ class ComponentVertex(Vertex):
elif edge.source_handle.name not in self.results:
raise ValueError(f"Result not found for {edge.source_handle.name}. Results: {self.results}")
else:
raise ValueError(f"Result not found for {edge.source_handle.name}")
asyncio.create_task(
log_transaction(source=self, target=requester, flow_id=str(self.graph.flow_id), status="success")
)
raise ValueError(f"Result not found for {edge.source_handle.name} in {edge}")
if flow_id:
asyncio.create_task(log_transaction(source=self, target=requester, flow_id=str(flow_id), status="success"))
return result
def extract_messages_from_artifacts(self, artifacts: Dict[str, Any]) -> List[dict]:

View file

@ -1,4 +1,5 @@
import copy
from collections import Counter, defaultdict
from textwrap import dedent
import pytest
@ -14,6 +15,7 @@ from langflow.components.prompts.Prompt import PromptComponent
from langflow.components.vectorstores.AstraDB import AstraVectorStoreComponent
from langflow.graph.graph.base import Graph
from langflow.graph.graph.constants import Finish
from langflow.graph.graph.schema import VertexBuildResult
from langflow.schema.data import Data
@ -96,7 +98,7 @@ def rag_graph():
return graph
def test_vector_store_rag(ingestion_graph, rag_graph):
def test_vector_store_rag(ingestion_graph: Graph, rag_graph: Graph):
assert ingestion_graph is not None
ingestion_ids = [
"file-123",
@ -115,11 +117,17 @@ def test_vector_store_rag(ingestion_graph, rag_graph):
"openai-embeddings-124",
]
for ids, graph, len_results in zip([ingestion_ids, rag_ids], [ingestion_graph, rag_graph], [5, 8]):
results = []
results: list[VertexBuildResult] = []
ids_count = Counter(ids)
results_id_count: dict[str, int] = defaultdict(int)
for result in graph.start(config={"output": {"cache": True}}):
results.append(result)
if hasattr(result, "vertex"):
results_id_count[result.vertex.id] += 1
assert len(results) == len_results
assert (
len(results) == len_results
), f"Counts: {ids_count} != {results_id_count}, Diff: {set(ids_count.keys()) - set(results_id_count.keys())}"
vids = [result.vertex.id for result in results if hasattr(result, "vertex")]
assert all(vid in ids for vid in vids), f"Diff: {set(vids) - set(ids)}"
assert results[-1] == Finish()