fix: add tests for cycles in Graph and improve error handling (#3628)

* Add cycle detection and handling in graph edge building process

- Introduced `cycles` property to detect cycles in the graph.
- Modified `_build_edges` and `build_edge` methods to differentiate between `CycleEdge` and `Edge`.
- Updated imports and type hints to support new functionality.

* Add cycle detection and handling in graph processing

- Introduced `is_cyclic` property to check for cycles in the graph.
- Added `_snapshot` method for capturing the current state of the graph.
- Modified `layered_topological_sort` to handle cyclic graphs by starting from a specified start component.
- Updated imports and type hints for better code clarity and functionality.

* Refactor tests and components for improved caching and data handling

- Updated `test_vector_store_rag.py` to use `set_on_output` with `cache=True` and simplified assertions.
- Enhanced `test_memory_chatbot.py` with additional assertions for graph structure and caching.
- Simplified `to_data` method in `base.py` to directly return `_data` without JSON serialization.

* Add unit tests for detecting cycles in graph

- Introduce `test_cycle_in_graph` to verify cyclic behavior in the graph.
- Add `test_cycle_in_graph_max_iterations` to ensure max iterations limit is respected.
- Implement `Concatenate` component for testing purposes.

* Disable output cache in graph tests to allow loops to work

* Refactor: Update VertexStates enum values to uppercase and optimize imports in base.py

* Refactor type hints and improve error handling in `Vertex` class

- Replace `ValueError` with `NoComponentInstance` exception for missing component instances.
- Add `target_handle_name` parameter to `_get_result` method for better result retrieval.
- Refactor type hints to use `collections.abc` for `AsyncIterator`, `Generator`, and `Iterator`.
- Update type hints for `extract_messages_from_artifacts` and `successors_ids` methods to use generic `dict` and `list`.
This commit is contained in:
Gabriel Luiz Freitas Almeida 2024-09-02 08:45:37 -03:00 committed by GitHub
commit bc6e918f49
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
6 changed files with 222 additions and 55 deletions

View file

@ -4,24 +4,31 @@ import json
import uuid
import warnings
from collections import defaultdict, deque
from collections.abc import Generator, Iterable
from datetime import datetime, timezone
from functools import partial
from itertools import chain
from typing import TYPE_CHECKING, Any, Optional
from collections.abc import Generator
from typing import TYPE_CHECKING, Any, Optional, cast
import nest_asyncio
from loguru import logger
from langflow.exceptions.component import ComponentBuildException
from langflow.graph.edge.base import CycleEdge
from langflow.graph.edge.base import CycleEdge, Edge
from langflow.graph.edge.schema import EdgeData
from langflow.graph.graph.constants import Finish, lazy_load_vertex_dict
from langflow.graph.graph.runnable_vertices_manager import RunnableVerticesManager
from langflow.graph.graph.schema import GraphData, GraphDump, StartConfigDict, VertexBuildResult
from langflow.graph.graph.state_manager import GraphStateManager
from langflow.graph.graph.state_model import create_state_model_from_graph
from langflow.graph.graph.utils import find_start_component_id, process_flow, should_continue, sort_up_to_vertex
from langflow.graph.graph.utils import (
find_all_cycle_edges,
find_start_component_id,
has_cycle,
process_flow,
should_continue,
sort_up_to_vertex,
)
from langflow.graph.schema import InterfaceComponentTypes, RunOutputs
from langflow.graph.vertex.base import Vertex, VertexStates
from langflow.graph.vertex.schema import NodeData
@ -279,6 +286,15 @@ class Graph:
raise ValueError("Max iterations reached")
def _snapshot(self):
return {
"_run_queue": self._run_queue.copy(),
"_first_layer": self._first_layer.copy(),
"vertices_layers": copy.deepcopy(self.vertices_layers),
"vertices_to_run": copy.deepcopy(self.vertices_to_run),
"run_manager": copy.deepcopy(self.run_manager.to_dict()),
}
def __apply_config(self, config: StartConfigDict):
for vertex in self.vertices:
if vertex._custom_component is None:
@ -460,6 +476,23 @@ class Graph:
raise ValueError("Graph not prepared. Call prepare() first.")
return self._first_layer
@property
def is_cyclic(self):
"""
Check if the graph has any cycles.
Returns:
bool: True if the graph has any cycles, False otherwise.
"""
if self._is_cyclic is None:
vertices = [vertex.id for vertex in self.vertices]
try:
edges = [(e["data"]["sourceHandle"]["id"], e["data"]["targetHandle"]["id"]) for e in self._edges]
except KeyError:
edges = [(e["source"], e["target"]) for e in self._edges]
self._is_cyclic = has_cycle(vertices, edges)
return self._is_cyclic
@property
def run_id(self):
"""
@ -1380,7 +1413,7 @@ class Graph:
def find_next_runnable_vertices(self, vertex_id: str, vertex_successors_ids: list[str]) -> list[str]:
next_runnable_vertices = set()
for v_id in vertex_successors_ids:
for v_id in sorted(vertex_successors_ids):
if not self.is_vertex_runnable(v_id):
next_runnable_vertices.update(self.find_runnable_predecessors_for_successor(v_id))
else:
@ -1536,21 +1569,31 @@ class Graph:
neighbors[neighbor] += 1
return neighbors
@property
def cycles(self):
if self._cycles is None:
if self._start is None:
self._cycles = []
else:
entry_vertex = self._start._id
edges = [(e["data"]["sourceHandle"]["id"], e["data"]["targetHandle"]["id"]) for e in self._edges]
self._cycles = find_all_cycle_edges(entry_vertex, edges)
return self._cycles
def _build_edges(self) -> list[CycleEdge]:
"""Builds the edges of the graph."""
# Edge takes two vertices as arguments, so we need to build the vertices first
# and then build the edges
# if we can't find a vertex, we raise an error
edges: set[CycleEdge] = set()
edges: set[CycleEdge | Edge] = set()
for edge in self._edges:
new_edge = self.build_edge(edge)
edges.add(new_edge)
if self.vertices and not edges:
warnings.warn("Graph has vertices but no edges")
return list(edges)
return list(cast(Iterable[CycleEdge], edges))
def build_edge(self, edge: EdgeData) -> CycleEdge:
def build_edge(self, edge: EdgeData) -> CycleEdge | Edge:
source = self.get_vertex(edge["source"])
target = self.get_vertex(edge["target"])
@ -1558,7 +1601,10 @@ class Graph:
raise ValueError(f"Source vertex {edge['source']} not found")
if target is None:
raise ValueError(f"Target vertex {edge['target']} not found")
new_edge = CycleEdge(source, target, edge)
if (source.id, target.id) in self.cycles:
new_edge: CycleEdge | Edge = CycleEdge(source, target, edge)
else:
new_edge = Edge(source, target, edge)
return new_edge
def _get_vertex_class(self, node_type: str, node_base_type: str, node_id: str) -> type["Vertex"]:
@ -1608,7 +1654,6 @@ class Graph:
if stop_component_id and start_component_id:
raise ValueError("You can only provide one of stop_component_id or start_component_id")
self.validate_stream()
self.edges = self._build_edges()
if stop_component_id or start_component_id:
try:
@ -1658,12 +1703,25 @@ class Graph:
"""Performs a layered topological sort of the vertices in the graph."""
vertices_ids = {vertex.id for vertex in vertices}
# Queue for vertices with no incoming edges
queue = deque(
vertex.id
for vertex in vertices
# if filter_graphs then only vertex.is_input will be considered
if self.in_degree_map[vertex.id] == 0 and (not filter_graphs or vertex.is_input)
)
in_degree_map = self.in_degree_map.copy()
if self.is_cyclic and all(in_degree_map.values()):
# This means we have a cycle because all vertex have in_degree_map > 0
# because of this we set the queue to start on the ._start if it exists
if self._start is not None:
queue = deque([self._start._id])
else:
# Find the chat input component
chat_input = find_start_component_id(vertices_ids)
if chat_input is None:
raise ValueError("No input component found and no start component provided")
queue = deque([chat_input])
else:
queue = deque(
vertex.id
for vertex in vertices
# if filter_graphs then only vertex.is_input will be considered
if in_degree_map[vertex.id] == 0 and (not filter_graphs or vertex.is_input)
)
layers: list[list[str]] = []
visited = set(queue)
@ -1684,13 +1742,13 @@ class Graph:
if neighbor not in vertices_ids:
continue
self.in_degree_map[neighbor] -= 1 # 'remove' edge
if self.in_degree_map[neighbor] == 0 and neighbor not in visited:
in_degree_map[neighbor] -= 1 # 'remove' edge
if in_degree_map[neighbor] == 0 and neighbor not in visited:
queue.append(neighbor)
# if > 0 it might mean not all predecessors have added to the queue
# so we should process the neighbors predecessors
elif self.in_degree_map[neighbor] > 0:
elif in_degree_map[neighbor] > 0:
for predecessor in self.predecessor_map[neighbor]:
if predecessor not in queue and predecessor not in visited:
queue.append(predecessor)

View file

@ -4,10 +4,9 @@ import inspect
import os
import traceback
import types
import json
from collections.abc import AsyncIterator, Callable, Iterator, Mapping
from enum import Enum
from typing import TYPE_CHECKING, Any, Optional
from collections.abc import AsyncIterator, Callable, Iterator, Mapping
import pandas as pd
from loguru import logger
@ -37,9 +36,9 @@ if TYPE_CHECKING:
class VertexStates(str, Enum):
"""Vertex are related to it being active, inactive, or in an error state."""
ACTIVE = "active"
INACTIVE = "inactive"
ERROR = "error"
ACTIVE = "ACTIVE"
INACTIVE = "INACTIVE"
ERROR = "ERROR"
class Vertex:
@ -105,12 +104,7 @@ class Vertex:
self._custom_component._set_input_value(name, value)
def to_data(self):
try:
data = json.loads(json.dumps(self._data, default=str))
except TypeError:
data = self._data
return data
return self._data
def add_component_instance(self, component_instance: "Component"):
component_instance.set_vertex(self)

View file

@ -102,8 +102,6 @@ class ComponentVertex(Vertex):
)
for edge in self.get_edge_with_target(requester.id):
# We need to check if the edge is a normal edge
# or a contract edge
if edge.is_cycle and edge.target_param:
return requester.get_value_from_template_dict(edge.target_param)

View file

@ -0,0 +1,111 @@
import pytest
from langflow.components.inputs.ChatInput import ChatInput
from langflow.components.outputs.ChatOutput import ChatOutput
from langflow.components.outputs.TextOutput import TextOutputComponent
from langflow.components.prototypes.ConditionalRouter import ConditionalRouterComponent
from langflow.custom.custom_component.component import Component
from langflow.graph.graph.base import Graph
from langflow.io import MessageTextInput, Output
from langflow.schema.message import Message
@pytest.fixture
def client():
pass
class Concatenate(Component):
display_name = "Concatenate"
description = "Concatenates two strings"
inputs = [
MessageTextInput(name="text", display_name="Text", required=True),
]
outputs = [
Output(display_name="Text", name="some_text", method="concatenate"),
]
def concatenate(self) -> Message:
return Message(text=f"{self.text}{self.text}" or "test")
def test_cycle_in_graph():
chat_input = ChatInput(_id="chat_input")
router = ConditionalRouterComponent(_id="router")
chat_input.set(input_value=router.false_response)
concat_component = Concatenate(_id="concatenate")
concat_component.set(text=chat_input.message_response)
router.set(
input_text=chat_input.message_response,
match_text="testtesttesttest",
operator="equals",
message=concat_component.concatenate,
)
text_output = TextOutputComponent(_id="text_output")
text_output.set(input_value=router.true_response)
chat_output = ChatOutput(_id="chat_output")
chat_output.set(input_value=text_output.text_response)
graph = Graph(chat_input, chat_output)
assert graph.is_cyclic is True
# Run queue should contain chat_input and not router
assert "chat_input" in graph._run_queue
assert "router" not in graph._run_queue
results = []
max_iterations = 20
snapshots = [graph._snapshot()]
for result in graph.start(max_iterations=max_iterations, config={"output": {"cache": False}}):
snapshots.append(graph._snapshot())
results.append(result)
results_ids = [result.vertex.id for result in results if hasattr(result, "vertex")]
assert results_ids[-2:] == ["text_output", "chat_output"]
assert len(results_ids) > len(graph.vertices), snapshots
# Check that chat_output and text_output are the last vertices in the results
assert results_ids == [
"chat_input",
"concatenate",
"router",
"chat_input",
"concatenate",
"router",
"chat_input",
"concatenate",
"router",
"chat_input",
"concatenate",
"router",
"text_output",
"chat_output",
], f"Results: {results_ids}"
def test_cycle_in_graph_max_iterations():
chat_input = ChatInput(_id="chat_input")
router = ConditionalRouterComponent(_id="router")
chat_input.set(input_value=router.false_response)
concat_component = Concatenate(_id="concatenate")
concat_component.set(text=chat_input.message_response)
router.set(
input_text=chat_input.message_response,
match_text="testtesttesttest",
operator="equals",
message=concat_component.concatenate,
)
text_output = TextOutputComponent(_id="text_output")
text_output.set(input_value=router.true_response)
chat_output = ChatOutput(_id="chat_output")
chat_output.set(input_value=text_output.text_response)
graph = Graph(chat_input, chat_output)
assert graph.is_cyclic is True
# Run queue should contain chat_input and not router
assert "chat_input" in graph._run_queue
assert "router" not in graph._run_queue
results = []
with pytest.raises(ValueError, match="Max iterations reached"):
for result in graph.start(max_iterations=2, config={"output": {"cache": False}}):
results.append(result)

View file

@ -30,23 +30,35 @@ AI: """
openai_component.set(
input_value=prompt_component.build_prompt, max_tokens=100, temperature=0.1, api_key="test_api_key"
)
openai_component.get_output("text_output").value = "Mock response"
openai_component.set_on_output(name="text_output", value="Mock response", cache=True)
chat_output = ChatOutput(_id="chat_output")
chat_output.set(input_value=openai_component.text_response)
graph = Graph(chat_input, chat_output)
assert graph.in_degree_map == {"chat_output": 1, "prompt": 2, "openai": 1, "chat_input": 0, "chat_memory": 0}
return graph
def test_memory_chatbot(memory_chatbot_graph):
# Now we run step by step
expected_order = deque(["chat_input", "chat_memory", "prompt", "openai", "chat_output"])
assert memory_chatbot_graph.in_degree_map == {
"chat_output": 1,
"prompt": 2,
"openai": 1,
"chat_input": 0,
"chat_memory": 0,
}
assert memory_chatbot_graph.vertices_layers == [["prompt"], ["openai"], ["chat_output"]]
assert memory_chatbot_graph.first_layer == ["chat_input", "chat_memory"]
for step in expected_order:
result = memory_chatbot_graph.step()
if isinstance(result, Finish):
break
assert step == result.vertex.id
assert step == result.vertex.id, (memory_chatbot_graph.in_degree_map, memory_chatbot_graph.vertices_layers)
def test_memory_chatbot_dump_structure(memory_chatbot_graph: Graph):

View file

@ -1,5 +1,4 @@
import copy
from collections import Counter, defaultdict
from textwrap import dedent
import pytest
@ -15,7 +14,6 @@ from langflow.components.prompts.Prompt import PromptComponent
from langflow.components.vectorstores.AstraDB import AstraVectorStoreComponent
from langflow.graph.graph.base import Graph
from langflow.graph.graph.constants import Finish
from langflow.graph.graph.schema import VertexBuildResult
from langflow.schema.data import Data
@ -29,7 +27,7 @@ def ingestion_graph():
# Ingestion Graph
file_component = FileComponent(_id="file-123")
file_component.set(path="test.txt")
file_component.set_on_output("data", value=Data(text="This is a test file."))
file_component.set_on_output(name="data", value=Data(text="This is a test file."), cache=True)
text_splitter = SplitTextComponent(_id="text-splitter-123")
text_splitter.set(data_inputs=file_component.load_file)
openai_embeddings = OpenAIEmbeddingsComponent(_id="openai-embeddings-123")
@ -43,9 +41,10 @@ def ingestion_graph():
api_endpoint="https://astra.example.com",
token="token",
)
vector_store.set_on_output("vector_store", value="mock_vector_store")
vector_store.set_on_output("base_retriever", value="mock_retriever")
vector_store.set_on_output("search_results", value=[Data(text="This is a test file.")])
vector_store.set_on_output(name="vector_store", value="mock_vector_store", cache=True)
vector_store.set_on_output(name="base_retriever", value="mock_retriever", cache=True)
vector_store.set_on_output(name="search_results", value=[Data(text="This is a test file.")], cache=True)
ingestion_graph = Graph(file_component, vector_store)
return ingestion_graph
@ -65,14 +64,15 @@ def rag_graph():
)
# Mock search_documents
rag_vector_store.set_on_output(
"search_results",
name="search_results",
value=[
Data(data={"text": "Hello, world!"}),
Data(data={"text": "Goodbye, world!"}),
],
cache=True,
)
rag_vector_store.set_on_output("base_retriever", value="mock_retriever")
rag_vector_store.set_on_output("vector_store", value="mock_vector_store")
rag_vector_store.set_on_output(name="vector_store", value="mock_vector_store", cache=True)
rag_vector_store.set_on_output(name="base_retriever", value="mock_retriever", cache=True)
parse_data = ParseDataComponent(_id="parse-data-123")
parse_data.set(data=rag_vector_store.search_documents)
prompt_component = PromptComponent(_id="prompt-123")
@ -88,7 +88,7 @@ def rag_graph():
openai_component = OpenAIModelComponent(_id="openai-123")
openai_component.set(api_key="sk-123", openai_api_base="https://api.openai.com/v1")
openai_component.set_on_output("text_output", value="Hello, world!")
openai_component.set_on_output(name="text_output", value="Hello, world!", cache=True)
openai_component.set(input_value=prompt_component.build_prompt)
chat_output = ChatOutput(_id="chatoutput-123")
@ -98,7 +98,7 @@ def rag_graph():
return graph
def test_vector_store_rag(ingestion_graph: Graph, rag_graph: Graph):
def test_vector_store_rag(ingestion_graph, rag_graph):
assert ingestion_graph is not None
ingestion_ids = [
"file-123",
@ -117,17 +117,11 @@ def test_vector_store_rag(ingestion_graph: Graph, rag_graph: Graph):
"openai-embeddings-124",
]
for ids, graph, len_results in zip([ingestion_ids, rag_ids], [ingestion_graph, rag_graph], [5, 8]):
results: list[VertexBuildResult] = []
ids_count = Counter(ids)
results_id_count: dict[str, int] = defaultdict(int)
for result in graph.start(config={"output": {"cache": True}}):
results = []
for result in graph.start():
results.append(result)
if hasattr(result, "vertex"):
results_id_count[result.vertex.id] += 1
assert (
len(results) == len_results
), f"Counts: {ids_count} != {results_id_count}, Diff: {set(ids_count.keys()) - set(results_id_count.keys())}"
assert len(results) == len_results
vids = [result.vertex.id for result in results if hasattr(result, "vertex")]
assert all(vid in ids for vid in vids), f"Diff: {set(vids) - set(ids)}"
assert results[-1] == Finish()