feat: get result from output if possible (#3338)
* refactor: Extract method to set params from normal edge in Vertex class. * feat: Add method to retrieve value from template dict in Vertex class. * feat: Add handling for cycle and contract edge targets in ComponentVertex build method. * refactor: Update result retrieval logic in ComponentVertex class. * refactor: Add condition to check flow_id before creating log transactions. * refactor: Add missing `Edge` import and cast `cast` in types.py for better typing.
This commit is contained in:
parent
450ebb723a
commit
2baee5fef1
5 changed files with 115 additions and 47 deletions
|
|
@ -102,6 +102,8 @@ class Graph:
|
|||
self.raw_graph_data: GraphData = {"nodes": [], "edges": []}
|
||||
self._is_cyclic: Optional[bool] = None
|
||||
self._cycles: Optional[List[tuple[str, str]]] = None
|
||||
self._call_order: List[str] = []
|
||||
self._snapshots: List[Dict[str, Any]] = []
|
||||
try:
|
||||
self.tracing_service: "TracingService" | None = get_tracing_service()
|
||||
except Exception as exc:
|
||||
|
|
@ -1134,6 +1136,14 @@ class Graph:
|
|||
return vertex
|
||||
raise ValueError(f"Vertex {vertex_id} is not a top level vertex or no root vertex found")
|
||||
|
||||
def get_next_in_queue(self):
|
||||
if not self._run_queue:
|
||||
return None
|
||||
return self._run_queue.popleft()
|
||||
|
||||
def extend_run_queue(self, vertices: List[str]):
|
||||
self._run_queue.extend(vertices)
|
||||
|
||||
async def astep(
|
||||
self,
|
||||
inputs: Optional["InputValueRequest"] = None,
|
||||
|
|
@ -1145,7 +1155,7 @@ class Graph:
|
|||
if not self._run_queue:
|
||||
asyncio.create_task(self.end_all_traces())
|
||||
return Finish()
|
||||
vertex_id = self._run_queue.popleft()
|
||||
vertex_id = self.get_next_in_queue()
|
||||
chat_service = get_chat_service()
|
||||
vertex_build_result = await self.build_vertex(
|
||||
vertex_id=vertex_id,
|
||||
|
|
@ -1161,13 +1171,31 @@ class Graph:
|
|||
)
|
||||
if self.stop_vertex and self.stop_vertex in next_runnable_vertices:
|
||||
next_runnable_vertices = [self.stop_vertex]
|
||||
self._run_queue.extend(next_runnable_vertices)
|
||||
self.extend_run_queue(next_runnable_vertices)
|
||||
self.reset_inactivated_vertices()
|
||||
self.reset_activated_vertices()
|
||||
|
||||
await chat_service.set_cache(str(self.flow_id or self._run_id), self)
|
||||
self._record_snapshot(vertex_id)
|
||||
return vertex_build_result
|
||||
|
||||
def get_snapshot(self):
|
||||
return copy.deepcopy(
|
||||
{
|
||||
"run_manager": self.run_manager.to_dict(),
|
||||
"run_queue": self._run_queue,
|
||||
"vertices_layers": self.vertices_layers,
|
||||
"first_layer": self.first_layer,
|
||||
"inactive_vertices": self.inactive_vertices,
|
||||
"activated_vertices": self.activated_vertices,
|
||||
}
|
||||
)
|
||||
|
||||
def _record_snapshot(self, vertex_id: str | None = None, start: bool = False):
|
||||
self._snapshots.append(self.get_snapshot())
|
||||
if vertex_id:
|
||||
self._call_order.append(vertex_id)
|
||||
|
||||
def step(
|
||||
self,
|
||||
inputs: Optional["InputValueRequest"] = None,
|
||||
|
|
@ -1347,14 +1375,14 @@ class Graph:
|
|||
return self
|
||||
|
||||
def find_next_runnable_vertices(self, vertex_id: str, vertex_successors_ids: List[str]) -> List[str]:
|
||||
next_runnable_vertices = []
|
||||
next_runnable_vertices = set()
|
||||
for v_id in vertex_successors_ids:
|
||||
if not self.is_vertex_runnable(v_id):
|
||||
next_runnable_vertices.extend(self.find_runnable_predecessors_for_successor(v_id))
|
||||
next_runnable_vertices.update(self.find_runnable_predecessors_for_successor(v_id))
|
||||
else:
|
||||
next_runnable_vertices.append(v_id)
|
||||
next_runnable_vertices.add(v_id)
|
||||
|
||||
return next_runnable_vertices
|
||||
return list(next_runnable_vertices)
|
||||
|
||||
async def get_next_runnable_vertices(self, lock: asyncio.Lock, vertex: "Vertex", cache: bool = True) -> List[str]:
|
||||
v_id = vertex.id
|
||||
|
|
@ -1592,6 +1620,7 @@ class Graph:
|
|||
self._first_layer = first_layer
|
||||
self._run_queue = deque(first_layer)
|
||||
self._prepared = True
|
||||
self._record_snapshot()
|
||||
return self
|
||||
|
||||
def get_children_by_vertex_type(self, vertex: Vertex, vertex_type: str) -> List[Vertex]:
|
||||
|
|
|
|||
|
|
@ -28,7 +28,7 @@ from langflow.utils.util import sync_to_async, unescape_string
|
|||
|
||||
if TYPE_CHECKING:
|
||||
from langflow.custom import Component
|
||||
from langflow.graph.edge.base import CycleEdge
|
||||
from langflow.graph.edge.base import CycleEdge, Edge
|
||||
from langflow.graph.graph.base import Graph
|
||||
|
||||
|
||||
|
|
@ -250,6 +250,12 @@ class Vertex:
|
|||
self.base_type = base_type
|
||||
break
|
||||
|
||||
def get_value_from_template_dict(self, key: str):
|
||||
template_dict = self.data.get("node", {}).get("template", {})
|
||||
if key not in template_dict:
|
||||
raise ValueError(f"Key {key} not found in template dict")
|
||||
return template_dict.get(key, {}).get("value")
|
||||
|
||||
def get_task(self):
|
||||
# using the task_id, get the task from celery
|
||||
# and return it
|
||||
|
|
@ -257,6 +263,31 @@ class Vertex:
|
|||
|
||||
return AsyncResult(self.task_id)
|
||||
|
||||
def _set_params_from_normal_edge(self, params: dict, edge: "Edge", template_dict: dict):
|
||||
param_key = edge.target_param
|
||||
|
||||
# If the param_key is in the template_dict and the edge.target_id is the current node
|
||||
# We check this to make sure params with the same name but different target_id
|
||||
# don't get overwritten
|
||||
if param_key in template_dict and edge.target_id == self.id:
|
||||
if template_dict[param_key].get("list"):
|
||||
if param_key not in params:
|
||||
params[param_key] = []
|
||||
params[param_key].append(self.graph.get_vertex(edge.source_id))
|
||||
elif edge.target_id == self.id:
|
||||
if isinstance(template_dict[param_key].get("value"), dict):
|
||||
# we don't know the key of the dict but we need to set the value
|
||||
# to the vertex that is the source of the edge
|
||||
param_dict = template_dict[param_key]["value"]
|
||||
if not param_dict or len(param_dict) != 1:
|
||||
params[param_key] = self.graph.get_vertex(edge.source_id)
|
||||
else:
|
||||
params[param_key] = {key: self.graph.get_vertex(edge.source_id) for key in param_dict.keys()}
|
||||
|
||||
else:
|
||||
params[param_key] = self.graph.get_vertex(edge.source_id)
|
||||
return params
|
||||
|
||||
def _build_params(self):
|
||||
# sourcery skip: merge-list-append, remove-redundant-if
|
||||
# Some params are required, some are optional
|
||||
|
|
@ -287,30 +318,7 @@ class Vertex:
|
|||
for edge in self.edges:
|
||||
if not hasattr(edge, "target_param"):
|
||||
continue
|
||||
param_key = edge.target_param
|
||||
|
||||
# If the param_key is in the template_dict and the edge.target_id is the current node
|
||||
# We check this to make sure params with the same name but different target_id
|
||||
# don't get overwritten
|
||||
if param_key in template_dict and edge.target_id == self.id:
|
||||
if template_dict[param_key].get("list"):
|
||||
if param_key not in params:
|
||||
params[param_key] = []
|
||||
params[param_key].append(self.graph.get_vertex(edge.source_id))
|
||||
elif edge.target_id == self.id:
|
||||
if isinstance(template_dict[param_key].get("value"), dict):
|
||||
# we don't know the key of the dict but we need to set the value
|
||||
# to the vertex that is the source of the edge
|
||||
param_dict = template_dict[param_key]["value"]
|
||||
if not param_dict or len(param_dict) != 1:
|
||||
params[param_key] = self.graph.get_vertex(edge.source_id)
|
||||
else:
|
||||
params[param_key] = {
|
||||
key: self.graph.get_vertex(edge.source_id) for key in param_dict.keys()
|
||||
}
|
||||
|
||||
else:
|
||||
params[param_key] = self.graph.get_vertex(edge.source_id)
|
||||
params = self._set_params_from_normal_edge(params, edge, template_dict)
|
||||
|
||||
load_from_db_fields = []
|
||||
for field_name, field in template_dict.items():
|
||||
|
|
@ -598,11 +606,13 @@ class Vertex:
|
|||
"""
|
||||
flow_id = self.graph.flow_id
|
||||
if not self._built:
|
||||
asyncio.create_task(log_transaction(str(flow_id), source=self, target=requester, status="error"))
|
||||
if flow_id:
|
||||
asyncio.create_task(log_transaction(str(flow_id), source=self, target=requester, status="error"))
|
||||
raise ValueError(f"Component {self.display_name} has not been built yet")
|
||||
|
||||
result = self._built_result if self.use_result else self._built_object
|
||||
asyncio.create_task(log_transaction(str(flow_id), source=self, target=requester, status="success"))
|
||||
if flow_id:
|
||||
asyncio.create_task(log_transaction(str(flow_id), source=self, target=requester, status="success"))
|
||||
return result
|
||||
|
||||
async def _build_vertex_and_update_params(self, key, vertex: "Vertex"):
|
||||
|
|
|
|||
4
src/backend/base/langflow/graph/vertex/exceptions.py
Normal file
4
src/backend/base/langflow/graph/vertex/exceptions.py
Normal file
|
|
@ -0,0 +1,4 @@
|
|||
class NoComponentInstance(Exception):
|
||||
def __init__(self, vertex_id: str):
|
||||
message = f"Vertex {vertex_id} does not have a component instance."
|
||||
super().__init__(message)
|
||||
|
|
@ -1,6 +1,6 @@
|
|||
import asyncio
|
||||
import json
|
||||
from typing import TYPE_CHECKING, Any, AsyncIterator, Dict, Generator, Iterator, List
|
||||
from typing import TYPE_CHECKING, Any, AsyncIterator, Dict, Generator, Iterator, List, cast
|
||||
|
||||
import yaml
|
||||
from langchain_core.messages import AIMessage, AIMessageChunk
|
||||
|
|
@ -9,6 +9,7 @@ from loguru import logger
|
|||
from langflow.graph.schema import CHAT_COMPONENTS, RECORDS_COMPONENTS, InterfaceComponentTypes, ResultData
|
||||
from langflow.graph.utils import UnbuiltObject, log_transaction, log_vertex_build, serialize_field
|
||||
from langflow.graph.vertex.base import Vertex
|
||||
from langflow.graph.vertex.exceptions import NoComponentInstance
|
||||
from langflow.graph.vertex.schema import NodeData
|
||||
from langflow.inputs.inputs import InputTypes
|
||||
from langflow.schema import Data
|
||||
|
|
@ -43,7 +44,7 @@ class ComponentVertex(Vertex):
|
|||
|
||||
def get_output(self, name: str) -> Output:
|
||||
if self._custom_component is None:
|
||||
raise ValueError(f"Vertex {self.id} does not have a component instance.")
|
||||
raise NoComponentInstance(self.id)
|
||||
return self._custom_component.get_output(name)
|
||||
|
||||
def _built_object_repr(self):
|
||||
|
|
@ -92,10 +93,19 @@ class ComponentVertex(Vertex):
|
|||
Returns:
|
||||
The built result if use_result is True, else the built object.
|
||||
"""
|
||||
flow_id = self.graph.flow_id
|
||||
if not self._built:
|
||||
asyncio.create_task(
|
||||
log_transaction(source=self, target=requester, flow_id=str(self.graph.flow_id), status="error")
|
||||
)
|
||||
if flow_id:
|
||||
asyncio.create_task(
|
||||
log_transaction(source=self, target=requester, flow_id=str(flow_id), status="error")
|
||||
)
|
||||
for edge in self.get_edge_with_target(requester.id):
|
||||
# We need to check if the edge is a normal edge
|
||||
# or a contract edge
|
||||
|
||||
if edge.is_cycle and edge.target_param:
|
||||
return requester.get_value_from_template_dict(edge.target_param)
|
||||
|
||||
raise ValueError(f"Component {self.display_name} has not been built yet")
|
||||
|
||||
if requester is None:
|
||||
|
|
@ -103,10 +113,18 @@ class ComponentVertex(Vertex):
|
|||
|
||||
edges = self.get_edge_with_target(requester.id)
|
||||
result = UNDEFINED
|
||||
edge = None
|
||||
for edge in edges:
|
||||
if edge is not None and edge.source_handle.name in self.results:
|
||||
result = self.results[edge.source_handle.name]
|
||||
# Get the result from the output instead of the results dict
|
||||
try:
|
||||
output = self.get_output(edge.source_handle.name)
|
||||
|
||||
if output.value is UNDEFINED:
|
||||
result = self.results[edge.source_handle.name]
|
||||
else:
|
||||
result = cast(Any, output.value)
|
||||
except NoComponentInstance:
|
||||
result = self.results[edge.source_handle.name]
|
||||
break
|
||||
if result is UNDEFINED:
|
||||
if edge is None:
|
||||
|
|
@ -114,10 +132,9 @@ class ComponentVertex(Vertex):
|
|||
elif edge.source_handle.name not in self.results:
|
||||
raise ValueError(f"Result not found for {edge.source_handle.name}. Results: {self.results}")
|
||||
else:
|
||||
raise ValueError(f"Result not found for {edge.source_handle.name}")
|
||||
asyncio.create_task(
|
||||
log_transaction(source=self, target=requester, flow_id=str(self.graph.flow_id), status="success")
|
||||
)
|
||||
raise ValueError(f"Result not found for {edge.source_handle.name} in {edge}")
|
||||
if flow_id:
|
||||
asyncio.create_task(log_transaction(source=self, target=requester, flow_id=str(flow_id), status="success"))
|
||||
return result
|
||||
|
||||
def extract_messages_from_artifacts(self, artifacts: Dict[str, Any]) -> List[dict]:
|
||||
|
|
|
|||
|
|
@ -1,4 +1,5 @@
|
|||
import copy
|
||||
from collections import Counter, defaultdict
|
||||
from textwrap import dedent
|
||||
|
||||
import pytest
|
||||
|
|
@ -14,6 +15,7 @@ from langflow.components.prompts.Prompt import PromptComponent
|
|||
from langflow.components.vectorstores.AstraDB import AstraVectorStoreComponent
|
||||
from langflow.graph.graph.base import Graph
|
||||
from langflow.graph.graph.constants import Finish
|
||||
from langflow.graph.graph.schema import VertexBuildResult
|
||||
from langflow.schema.data import Data
|
||||
|
||||
|
||||
|
|
@ -96,7 +98,7 @@ def rag_graph():
|
|||
return graph
|
||||
|
||||
|
||||
def test_vector_store_rag(ingestion_graph, rag_graph):
|
||||
def test_vector_store_rag(ingestion_graph: Graph, rag_graph: Graph):
|
||||
assert ingestion_graph is not None
|
||||
ingestion_ids = [
|
||||
"file-123",
|
||||
|
|
@ -115,11 +117,17 @@ def test_vector_store_rag(ingestion_graph, rag_graph):
|
|||
"openai-embeddings-124",
|
||||
]
|
||||
for ids, graph, len_results in zip([ingestion_ids, rag_ids], [ingestion_graph, rag_graph], [5, 8]):
|
||||
results = []
|
||||
results: list[VertexBuildResult] = []
|
||||
ids_count = Counter(ids)
|
||||
results_id_count: dict[str, int] = defaultdict(int)
|
||||
for result in graph.start(config={"output": {"cache": True}}):
|
||||
results.append(result)
|
||||
if hasattr(result, "vertex"):
|
||||
results_id_count[result.vertex.id] += 1
|
||||
|
||||
assert len(results) == len_results
|
||||
assert (
|
||||
len(results) == len_results
|
||||
), f"Counts: {ids_count} != {results_id_count}, Diff: {set(ids_count.keys()) - set(results_id_count.keys())}"
|
||||
vids = [result.vertex.id for result in results if hasattr(result, "vertex")]
|
||||
assert all(vid in ids for vid in vids), f"Diff: {set(vids) - set(ids)}"
|
||||
assert results[-1] == Finish()
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue