diff --git a/src/backend/base/langflow/base/agents/context.py b/src/backend/base/langflow/base/agents/context.py new file mode 100644 index 000000000..8e4961ecc --- /dev/null +++ b/src/backend/base/langflow/base/agents/context.py @@ -0,0 +1,109 @@ +from datetime import datetime, timezone +from typing import Any + +from langchain_core.language_models import BaseLanguageModel, BaseLLM +from langchain_core.language_models.chat_models import BaseChatModel +from pydantic import BaseModel, Field, field_validator, model_serializer + +from langflow.field_typing import LanguageModel +from langflow.schema.data import Data + + +class AgentContext(BaseModel): + tools: dict[str, Any] + llm: Any + context: str = "" + iteration: int = 0 + max_iterations: int = 5 + thought: str = "" + last_action: Any = None + last_action_result: Any = None + final_answer: Any = "" + context_history: list[tuple[str, str, str]] = Field(default_factory=list) + + @model_serializer(mode="plain") + def serialize_agent_context(self): + serliazed_llm = self.llm.to_json() if hasattr(self.llm, "to_json") else str(self.llm) + serliazed_tools = {k: v.to_json() if hasattr(v, "to_json") else str(v) for k, v in self.tools.items()} + return { + "tools": serliazed_tools, + "llm": serliazed_llm, + "context": self.context, + "iteration": self.iteration, + "max_iterations": self.max_iterations, + "thought": self.thought, + "last_action": self.last_action.to_json() + if hasattr(self.last_action, "to_json") + else str(self.last_action), + "action_result": self.last_action_result.to_json() + if hasattr(self.last_action_result, "to_json") + else str(self.last_action_result), + "final_answer": self.final_answer, + "context_history": self.context_history, + } + + @field_validator("llm", mode="before") + @classmethod + def validate_llm(cls, v) -> LanguageModel: + if not isinstance(v, BaseLLM | BaseChatModel | BaseLanguageModel): + msg = "llm must be an instance of LanguageModel" + raise TypeError(msg) + return v + + def to_data_repr(self): + data_objs = [] + for name, val, time_str in self.context_history: + content = val.content if hasattr(val, "content") else val + data_objs.append(Data(name=name, value=content, timestamp=time_str)) + + sorted_data_objs = sorted(data_objs, key=lambda x: datetime.fromisoformat(x.timestamp), reverse=True) + + sorted_data_objs.append( + Data( + name="Formatted Context", + value=self.get_full_context(), + ) + ) + return sorted_data_objs + + def _build_tools_context(self): + tool_context = "" + for tool_name, tool_obj in self.tools.items(): + tool_context += f"{tool_name}: {tool_obj.description}\n" + return tool_context + + def _build_init_context(self): + return f""" +{self.context} + +""" + + def model_post_init(self, _context: Any) -> None: + if hasattr(self.llm, "bind_tools"): + self.llm = self.llm.bind_tools(self.tools.values()) + if self.context: + self.update_context("Initial Context", self.context) + + def update_context(self, key: str, value: str): + self.context_history.insert(0, (key, value, datetime.now(tz=timezone.utc).astimezone().isoformat())) + + def _serialize_context_history_tuple(self, context_history_tuple: tuple[str, str, str]) -> str: + name, value, _ = context_history_tuple + if hasattr(value, "content"): + value = value.content + elif hasattr(value, "log"): + value = value.log + return f"{name}: {value}" + + def get_full_context(self) -> str: + context_history_reversed = self.context_history[::-1] + context_formatted = "\n".join( + [ + self._serialize_context_history_tuple(context_history_tuple) + for context_history_tuple in context_history_reversed + ] + ) + return f""" +Context: +{context_formatted} +""" diff --git a/src/backend/base/langflow/components/embeddings/util/aiml.py b/src/backend/base/langflow/base/embeddings/aiml_embeddings.py similarity index 94% rename from src/backend/base/langflow/components/embeddings/util/aiml.py rename to src/backend/base/langflow/base/embeddings/aiml_embeddings.py index 38273d518..694e05c03 100644 --- a/src/backend/base/langflow/components/embeddings/util/aiml.py +++ b/src/backend/base/langflow/base/embeddings/aiml_embeddings.py @@ -30,7 +30,7 @@ class AIMLEmbeddingsImpl(BaseModel, Embeddings): try: result_data = future.result() if len(result_data["data"]) != 1: - msg = "Expected one embedding" + msg = f"Expected one embedding, got {len(result_data['data'])}" raise ValueError(msg) embeddings[index] = result_data["data"][0]["embedding"] except ( @@ -38,6 +38,7 @@ class AIMLEmbeddingsImpl(BaseModel, Embeddings): httpx.RequestError, json.JSONDecodeError, KeyError, + ValueError, ): logger.exception("Error occurred") raise diff --git a/src/backend/base/langflow/components/embeddings/aiml.py b/src/backend/base/langflow/components/embeddings/aiml.py index a5c4e2835..fcdab9aae 100644 --- a/src/backend/base/langflow/components/embeddings/aiml.py +++ b/src/backend/base/langflow/components/embeddings/aiml.py @@ -1,6 +1,6 @@ +from langflow.base.embeddings.aiml_embeddings import AIMLEmbeddingsImpl from langflow.base.embeddings.model import LCEmbeddingsModel from langflow.base.models.aiml_constants import AIML_EMBEDDING_MODELS -from langflow.components.embeddings.util import AIMLEmbeddingsImpl from langflow.field_typing import Embeddings from langflow.inputs.inputs import DropdownInput from langflow.io import SecretStrInput diff --git a/src/backend/base/langflow/components/embeddings/util/__init__.py b/src/backend/base/langflow/components/embeddings/util/__init__.py deleted file mode 100644 index 5f630621e..000000000 --- a/src/backend/base/langflow/components/embeddings/util/__init__.py +++ /dev/null @@ -1,10 +0,0 @@ -import warnings - -from langchain_core._api.deprecation import LangChainDeprecationWarning - -with warnings.catch_warnings(): - warnings.simplefilter("ignore", LangChainDeprecationWarning) - from .aiml import AIMLEmbeddingsImpl - - -__all__ = ["AIMLEmbeddingsImpl"] diff --git a/src/backend/base/langflow/components/prototypes/conditional_router.py b/src/backend/base/langflow/components/prototypes/conditional_router.py index dd10c231b..ced08e3bd 100644 --- a/src/backend/base/langflow/components/prototypes/conditional_router.py +++ b/src/backend/base/langflow/components/prototypes/conditional_router.py @@ -1,5 +1,5 @@ from langflow.custom import Component -from langflow.io import BoolInput, DropdownInput, MessageInput, MessageTextInput, Output +from langflow.io import BoolInput, DropdownInput, IntInput, MessageInput, MessageTextInput, Output from langflow.schema.message import Message @@ -9,6 +9,10 @@ class ConditionalRouterComponent(Component): icon = "equal" name = "ConditionalRouter" + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + self.__iteration_updated = False + inputs = [ MessageTextInput( name="input_text", @@ -40,6 +44,20 @@ class ConditionalRouterComponent(Component): display_name="Message", info="The message to pass through either route.", ), + IntInput( + name="max_iterations", + display_name="Max Iterations", + info="The maximum number of iterations for the conditional router.", + value=10, + ), + DropdownInput( + name="default_route", + display_name="Default Route", + options=["true_result", "false_result"], + info="The default route to take when max iterations are reached.", + value="false_result", + advanced=True, + ), ] outputs = [ @@ -47,6 +65,9 @@ class ConditionalRouterComponent(Component): Output(display_name="False Route", name="false_result", method="false_response"), ] + def _pre_run_setup(self): + self.__iteration_updated = False + def evaluate_condition(self, input_text: str, match_text: str, operator: str, *, case_sensitive: bool) -> bool: if not case_sensitive: input_text = input_text.lower() @@ -64,15 +85,25 @@ class ConditionalRouterComponent(Component): return input_text.endswith(match_text) return False + def iterate_and_stop_once(self, route_to_stop: str): + if not self.__iteration_updated: + self.update_ctx({f"{self._id}_iteration": self.ctx.get(f"{self._id}_iteration", 0) + 1}) + self.__iteration_updated = True + if self.ctx.get(f"{self._id}_iteration", 0) >= self.max_iterations and route_to_stop == self.default_route: + # We need to stop the other route + route_to_stop = "true_result" if route_to_stop == "false_result" else "false_result" + self.stop(route_to_stop) + def true_response(self) -> Message: result = self.evaluate_condition( self.input_text, self.match_text, self.operator, case_sensitive=self.case_sensitive ) if result: self.status = self.message + self.iterate_and_stop_once("false_result") return self.message - self.stop("true_result") - return None # type: ignore[return-value] + self.iterate_and_stop_once("true_result") + return self.message def false_response(self) -> Message: result = self.evaluate_condition( @@ -80,6 +111,7 @@ class ConditionalRouterComponent(Component): ) if not result: self.status = self.message + self.iterate_and_stop_once("true_result") return self.message - self.stop("false_result") - return None # type: ignore[return-value] + self.iterate_and_stop_once("false_result") + return self.message diff --git a/src/backend/base/langflow/custom/custom_component/component.py b/src/backend/base/langflow/custom/custom_component/component.py index 16bbe7b58..e8ed9a45d 100644 --- a/src/backend/base/langflow/custom/custom_component/component.py +++ b/src/backend/base/langflow/custom/custom_component/component.py @@ -66,6 +66,7 @@ class Component(CustomComponent): _output_logs: dict[str, list[Log]] = {} _current_output: str = "" _metadata: dict = {} + _ctx: dict = {} def __init__(self, **kwargs) -> None: # if key starts with _ it is a config @@ -111,6 +112,53 @@ class Component(CustomComponent): self.set_class_code() self._set_output_required_inputs() + @property + def ctx(self): + if not hasattr(self, "graph") or self.graph is None: + msg = "Graph not found. Please build the graph first." + raise ValueError(msg) + return self.graph.context + + def add_to_ctx(self, key: str, value: Any, *, overwrite: bool = False) -> None: + """Add a key-value pair to the context. + + Args: + key (str): The key to add. + value (Any): The value to associate with the key. + overwrite (bool, optional): Whether to overwrite the existing value. Defaults to False. + + Raises: + ValueError: If the graph is not built. + """ + if not hasattr(self, "graph") or self.graph is None: + msg = "Graph not found. Please build the graph first." + raise ValueError(msg) + if key in self.graph.context and not overwrite: + msg = f"Key {key} already exists in context. Set overwrite=True to overwrite." + raise ValueError(msg) + self.graph.context.update({key: value}) + + def update_ctx(self, value_dict: dict[str, Any]) -> None: + """Update the context with a dictionary of values. + + Args: + value_dict (dict[str, Any]): The dictionary of values to update. + + Raises: + ValueError: If the graph is not built. + """ + if not hasattr(self, "graph") or self.graph is None: + msg = "Graph not found. Please build the graph first." + raise ValueError(msg) + if not isinstance(value_dict, dict): + msg = "Value dict must be a dictionary" + raise TypeError(msg) + + self.graph.context.update(value_dict) + + def _pre_run_setup(self): + pass + def set_event_manager(self, event_manager: EventManager | None = None) -> None: self._event_manager = event_manager @@ -768,7 +816,8 @@ class Component(CustomComponent): async def _build_results(self) -> tuple[dict, dict]: _results = {} _artifacts = {} - + if hasattr(self, "_pre_run_setup"): + self._pre_run_setup() if hasattr(self, "outputs"): if any(getattr(_input, "tool_mode", False) for _input in self.inputs): self._append_tool_to_outputs_map() diff --git a/src/backend/base/langflow/graph/graph/base.py b/src/backend/base/langflow/graph/graph/base.py index af06751d2..6384ce400 100644 --- a/src/backend/base/langflow/graph/graph/base.py +++ b/src/backend/base/langflow/graph/graph/base.py @@ -4,6 +4,8 @@ import asyncio import contextlib import copy import json +import queue +import threading import uuid from collections import defaultdict, deque from collections.abc import Generator, Iterable @@ -12,7 +14,6 @@ from functools import partial from itertools import chain from typing import TYPE_CHECKING, Any, cast -import nest_asyncio from loguru import logger from langflow.exceptions.component import ComponentBuildError @@ -26,7 +27,6 @@ from langflow.graph.graph.utils import ( find_all_cycle_edges, find_cycle_vertices, find_start_component_id, - has_cycle, process_flow, should_continue, sort_up_to_vertex, @@ -36,6 +36,7 @@ from langflow.graph.vertex.base import Vertex, VertexStates from langflow.graph.vertex.schema import NodeData, NodeTypeEnum from langflow.graph.vertex.types import ComponentVertex, InterfaceVertex, StateVertex from langflow.logging.logger import LogConfig, configure +from langflow.schema.dotdict import dotdict from langflow.schema.schema import INPUT_FIELD_NAME, InputType from langflow.services.cache.utils import CacheMiss from langflow.services.deps import get_chat_service, get_tracing_service @@ -63,6 +64,7 @@ class Graph: description: str | None = None, user_id: str | None = None, log_config: LogConfig | None = None, + context: dict[str, Any] | None = None, ) -> None: """Initializes a new instance of the Graph class. @@ -74,9 +76,11 @@ class Graph: description: The graph description. user_id: The user ID. log_config: The log configuration. + context: Additional context for the graph. Defaults to None. """ if log_config: configure(**log_config) + self._start = start self._state_model = None self._end = end @@ -107,6 +111,7 @@ class Graph: self.state_manager = GraphStateManager() self._vertices: list[NodeData] = [] self._edges: list[EdgeData] = [] + self.top_level_vertices: list[str] = [] self.vertex_map: dict[str, Vertex] = {} self.predecessor_map: dict[str, list[str]] = defaultdict(list) @@ -123,6 +128,11 @@ class Graph: self._call_order: list[str] = [] self._snapshots: list[dict[str, Any]] = [] self._end_trace_tasks: set[asyncio.Task] = set() + + if context and not isinstance(context, dict): + msg = "Context must be a dictionary" + raise TypeError(msg) + self._context = dotdict(context or {}) try: self.tracing_service: TracingService | None = get_tracing_service() except Exception: # noqa: BLE001 @@ -135,6 +145,21 @@ class Graph: msg = "You must provide both input and output components" raise ValueError(msg) + @property + def context(self) -> dotdict: + if isinstance(self._context, dotdict): + return self._context + return dotdict(self._context) + + @context.setter + def context(self, value: dict[str, Any]): + if not isinstance(value, dict): + msg = "Context must be a dictionary" + raise TypeError(msg) + if isinstance(value, dict): + value = dotdict(value) + self._context = value + @property def session_id(self): return self._session_id @@ -217,6 +242,8 @@ class Graph: for vertex in self._vertices: if vertex_id := vertex.get("id"): self.top_level_vertices.append(vertex_id) + if vertex_id in self.cycle_vertices: + self.run_manager.add_to_cycle_vertices(vertex_id) self._graph_data = process_flow(self.raw_graph_data) self._vertices = self._graph_data["nodes"] @@ -360,26 +387,81 @@ class Graph: config: StartConfigDict | None = None, event_manager: EventManager | None = None, ) -> Generator: + """Starts the graph execution synchronously by creating a new event loop in a separate thread. + + Args: + inputs: Optional list of input dictionaries + max_iterations: Optional maximum number of iterations + config: Optional configuration dictionary + event_manager: Optional event manager + + Returns: + Generator yielding results from graph execution + """ if self.is_cyclic and max_iterations is None: msg = "You must specify a max_iterations if the graph is cyclic" raise ValueError(msg) + if config is not None: self.__apply_config(config) - # ! Change this ASAP - nest_asyncio.apply() - loop = asyncio.get_event_loop() - async_gen = self.async_start(inputs, max_iterations, event_manager) - async_gen_task = asyncio.ensure_future(anext(async_gen)) - while True: + # Create a queue for passing results and errors between threads + result_queue: queue.Queue[VertexBuildResult | Exception | None] = queue.Queue() + + # Function to run async code in separate thread + def run_async_code(): + # Create new event loop for this thread + loop = asyncio.new_event_loop() + asyncio.set_event_loop(loop) + try: - result = loop.run_until_complete(async_gen_task) - yield result - if isinstance(result, Finish): - return - async_gen_task = asyncio.ensure_future(anext(async_gen)) - except StopAsyncIteration: + # Run the async generator + async_gen = self.async_start(inputs, max_iterations, event_manager) + + while True: + try: + # Get next result from async generator + result = loop.run_until_complete(anext(async_gen)) + result_queue.put(result) + + if isinstance(result, Finish): + break + + except StopAsyncIteration: + break + except ValueError as e: + # Put the exception in the queue + result_queue.put(e) + break + + finally: + # Ensure all pending tasks are completed + pending = asyncio.all_tasks(loop) + if pending: + # Create a future to gather all pending tasks + cleanup_future = asyncio.gather(*pending, return_exceptions=True) + loop.run_until_complete(cleanup_future) + + # Close the loop + loop.close() + # Signal completion + result_queue.put(None) + + # Start thread for async execution + thread = threading.Thread(target=run_async_code) + thread.start() + + # Yield results from queue + while True: + result = result_queue.get() + if result is None: break + if isinstance(result, Exception): + raise result + yield result + + # Wait for thread to complete + thread.join() def _add_edge(self, edge: EdgeData) -> None: self.add_edge(edge) @@ -533,12 +615,7 @@ class Graph: bool: True if the graph has any cycles, False otherwise. """ if self._is_cyclic is None: - vertices = [vertex.id for vertex in self.vertices] - try: - edges = [(e["data"]["sourceHandle"]["id"], e["data"]["targetHandle"]["id"]) for e in self._edges] - except KeyError: - edges = [(e["source"], e["target"]) for e in self._edges] - self._is_cyclic = has_cycle(vertices, edges) + self._is_cyclic = bool(self.cycle_vertices) return self._is_cyclic @property @@ -1136,6 +1213,9 @@ class Graph: self._build_vertex_params() self._instantiate_components_in_vertices() self._set_cache_to_vertices_in_cycle() + for vertex in self.vertices: + if vertex.id in self.cycle_vertices: + self.run_manager.add_to_cycle_vertices(vertex.id) def _get_edges_as_list_of_tuples(self) -> list[tuple[str, str]]: """Returns the edges of the graph as a list of tuples.""" @@ -1455,7 +1535,7 @@ class Graph: else: next_runnable_vertices.add(v_id) - return list(next_runnable_vertices) + return sorted(next_runnable_vertices) async def get_next_runnable_vertices(self, lock: asyncio.Lock, vertex: Vertex, *, cache: bool = True) -> list[str]: v_id = vertex.id @@ -1717,6 +1797,8 @@ class Graph: for vertex_id in first_layer: self.run_manager.add_to_vertices_being_run(vertex_id) + if vertex_id in self.cycle_vertices: + self.run_manager.add_to_cycle_vertices(vertex_id) self._first_layer = sorted(first_layer) self._run_queue = deque(self._first_layer) self._prepared = True @@ -1993,7 +2075,7 @@ class Graph: for successor_id in self.run_manager.run_map.get(vertex_id, []): runnable_vertices.extend(self.find_runnable_predecessors_for_successor(successor_id)) - return runnable_vertices + return sorted(runnable_vertices) def find_runnable_predecessors_for_successor(self, vertex_id: str) -> list[str]: runnable_vertices = [] diff --git a/src/backend/base/langflow/graph/graph/runnable_vertices_manager.py b/src/backend/base/langflow/graph/graph/runnable_vertices_manager.py index fece73cd6..9abc227e9 100644 --- a/src/backend/base/langflow/graph/graph/runnable_vertices_manager.py +++ b/src/backend/base/langflow/graph/graph/runnable_vertices_manager.py @@ -2,11 +2,12 @@ from collections import defaultdict class RunnableVerticesManager: - def __init__(self) -> None: + def __init__(self): self.run_map: dict[str, list[str]] = defaultdict(list) # Tracks successors of each vertex self.run_predecessors: dict[str, set[str]] = defaultdict(set) # Tracks predecessors for each vertex self.vertices_to_run: set[str] = set() # Set of vertices that are ready to run self.vertices_being_run: set[str] = set() # Set of vertices that are currently running + self.cycle_vertices: set[str] = set() # Set of vertices that are in a cycle def to_dict(self) -> dict: return { @@ -55,7 +56,7 @@ class RunnableVerticesManager: return False if vertex_id not in self.vertices_to_run: return False - return self.are_all_predecessors_fulfilled(vertex_id) + return self.are_all_predecessors_fulfilled(vertex_id) or vertex_id in self.cycle_vertices def are_all_predecessors_fulfilled(self, vertex_id: str) -> bool: return not any(self.run_predecessors.get(vertex_id, [])) @@ -89,3 +90,6 @@ class RunnableVerticesManager: def add_to_vertices_being_run(self, v_id) -> None: self.vertices_being_run.add(v_id) + + def add_to_cycle_vertices(self, v_id): + self.cycle_vertices.add(v_id) diff --git a/src/backend/base/langflow/graph/vertex/base.py b/src/backend/base/langflow/graph/vertex/base.py index cbcdfcf2a..7e53e8424 100644 --- a/src/backend/base/langflow/graph/vertex/base.py +++ b/src/backend/base/langflow/graph/vertex/base.py @@ -871,4 +871,4 @@ class Vertex: if not self.custom_component or not self.custom_component.outputs: return # Apply the function to each output - [func(output) for output in self.custom_component.outputs] + [func(output) for output in self.custom_component._outputs_map.values()] diff --git a/src/backend/base/langflow/services/tracing/langsmith.py b/src/backend/base/langflow/services/tracing/langsmith.py index 4c5712449..1137c2ed1 100644 --- a/src/backend/base/langflow/services/tracing/langsmith.py +++ b/src/backend/base/langflow/services/tracing/langsmith.py @@ -7,7 +7,6 @@ from datetime import datetime, timezone from typing import TYPE_CHECKING, Any from loguru import logger -from typing_extensions import override from langflow.schema.data import Data from langflow.services.tracing.base import BaseTracer @@ -63,17 +62,16 @@ class LangSmithTracer(BaseTracer): os.environ["LANGCHAIN_TRACING_V2"] = "true" return True - @override def add_trace( self, - trace_id: str, + trace_id: str, # noqa: ARG002 trace_name: str, trace_type: str, inputs: dict[str, Any], metadata: dict[str, Any] | None = None, - vertex: Vertex | None = None, + vertex: Vertex | None = None, # noqa: ARG002 ) -> None: - if not self._ready: + if not self._ready or not self._run_tree: return processed_inputs = {} if inputs: @@ -117,16 +115,15 @@ class LangSmithTracer(BaseTracer): value = str(value) return value - @override def end_trace( self, - trace_id: str, + trace_id: str, # noqa: ARG002 trace_name: str, outputs: dict[str, Any] | None = None, error: Exception | None = None, logs: Sequence[Log | dict] = (), - ) -> None: - if not self._ready: + ): + if not self._ready or trace_name not in self._children: return child = self._children[trace_name] raw_outputs = {} @@ -159,7 +156,7 @@ class LangSmithTracer(BaseTracer): error: Exception | None = None, metadata: dict[str, Any] | None = None, ) -> None: - if not self._ready: + if not self._ready or not self._run_tree: return self._run_tree.add_metadata({"inputs": inputs}) if metadata: diff --git a/src/backend/tests/unit/graph/graph/test_cycles.py b/src/backend/tests/unit/graph/graph/test_cycles.py index 56b3edcb7..467aa5096 100644 --- a/src/backend/tests/unit/graph/graph/test_cycles.py +++ b/src/backend/tests/unit/graph/graph/test_cycles.py @@ -2,6 +2,7 @@ import os import pytest from langflow.components.inputs import ChatInput +from langflow.components.inputs.text import TextInputComponent from langflow.components.models import OpenAIModelComponent from langflow.components.outputs import ChatOutput, TextOutputComponent from langflow.components.prompts import PromptComponent @@ -31,7 +32,7 @@ class Concatenate(Component): @pytest.mark.skip(reason="Temporarily disabled") def test_cycle_in_graph(): chat_input = ChatInput(_id="chat_input") - router = ConditionalRouterComponent(_id="router") + router = ConditionalRouterComponent(_id="router", default_route="true_result") chat_input.set(input_value=router.false_response) concat_component = Concatenate(_id="concatenate") concat_component.set(text=chat_input.message_response) @@ -59,7 +60,6 @@ def test_cycle_in_graph(): snapshots.append(graph._snapshot()) results.append(result) results_ids = [result.vertex.id for result in results if hasattr(result, "vertex")] - assert results_ids[-2:] == ["text_output", "chat_output"] assert len(results_ids) > len(graph.vertices), snapshots # Check that chat_output and text_output are the last vertices in the results assert results_ids == [ @@ -127,7 +127,9 @@ def test_that_outputs_cache_is_set_to_false_in_cycle(): graph = Graph(chat_input, chat_output) cycle_vertices = find_cycle_vertices(graph._get_edges_as_list_of_tuples()) - cycle_outputs_lists = [graph.vertex_map[vertex_id].custom_component.outputs for vertex_id in cycle_vertices] + cycle_outputs_lists = [ + graph.vertex_map[vertex_id].custom_component._outputs_map.values() for vertex_id in cycle_vertices + ] cycle_outputs = [output for outputs in cycle_outputs_lists for output in outputs] for output in cycle_outputs: assert output.cache is False @@ -206,3 +208,119 @@ def test_updated_graph_with_prompts(): # Extract the vertex IDs for analysis results_ids = [result.vertex.id for result in results if hasattr(result, "vertex")] assert "chat_output_1" in results_ids, f"Expected outputs not in results: {results_ids}" + + +@pytest.mark.api_key_required +def test_updated_graph_with_max_iterations(): + # Chat input initialization + chat_input = ChatInput(_id="chat_input").set(input_value="bacon") + + # First prompt: Guessing game with hints + prompt_component_1 = PromptComponent(_id="prompt_component_1").set( + template="Try to guess a word. I will give you hints if you get it wrong.\n" + "Hint: {hint}\n" + "Last try: {last_try}\n" + "Answer:", + ) + + # First OpenAI LLM component (Processes the guessing prompt) + openai_component_1 = OpenAIModelComponent(_id="openai_1").set( + input_value=prompt_component_1.build_prompt, api_key=os.getenv("OPENAI_API_KEY") + ) + + # Conditional router based on agent response + router = ConditionalRouterComponent(_id="router").set( + input_text=openai_component_1.text_response, + match_text=chat_input.message_response, + operator="contains", + message=openai_component_1.text_response, + ) + + # Second prompt: After the last try, provide a new hint + prompt_component_2 = PromptComponent(_id="prompt_component_2") + prompt_component_2.set( + template="Given the following word and the following last try. Give the guesser a new hint.\n" + "Last try: {last_try}\n" + "Word: {word}\n" + "Hint:", + word=chat_input.message_response, + last_try=router.false_response, + ) + + # Second OpenAI component (handles the router's response) + openai_component_2 = OpenAIModelComponent(_id="openai_2") + openai_component_2.set(input_value=prompt_component_2.build_prompt, api_key=os.getenv("OPENAI_API_KEY")) + + prompt_component_1.set(hint=openai_component_2.text_response, last_try=router.false_response) + + # chat output for the final OpenAI response + chat_output_1 = ChatOutput(_id="chat_output_1") + chat_output_1.set(input_value=router.true_response) + + # Build the graph without concatenate + graph = Graph(chat_input, chat_output_1) + + # Assertions for graph cyclicity and correctness + assert graph.is_cyclic is True, "Graph should contain cycles." + + # Run and validate the execution of the graph + results = [] + max_iterations = 20 + snapshots = [graph.get_snapshot()] + + for result in graph.start(max_iterations=max_iterations, config={"output": {"cache": False}}): + snapshots.append(graph.get_snapshot()) + results.append(result) + + assert len(snapshots) > 2, "Graph should have more than one snapshot" + # Extract the vertex IDs for analysis + results_ids = [result.vertex.id for result in results if hasattr(result, "vertex")] + assert "chat_output_1" in results_ids, f"Expected outputs not in results: {results_ids}" + + +def test_conditional_router_max_iterations(): + # Chat input initialization + text_input = TextInputComponent(_id="text_input") + + # Conditional router setup with a condition that will never match + router = ConditionalRouterComponent(_id="router").set( + input_text=text_input.text_response, + match_text="bacon", + operator="equals", + message="This message should not be routed to true_result", + max_iterations=5, + default_route="true_result", + ) + + # Chat output for the true route + text_input.set(input_value=router.false_response) + + # Chat output for the false route + chat_output_false = ChatOutput(_id="chat_output_false") + chat_output_false.set(input_value=router.true_response) + + # Build the graph + graph = Graph(text_input, chat_output_false) + + # Assertions for graph cyclicity and correctness + assert graph.is_cyclic is True, "Graph should contain cycles." + + # Run and validate the execution of the graph + results = [] + snapshots = [graph.get_snapshot()] + previous_iteration = graph.context.get("router_iteration", 0) + for result in graph.start(max_iterations=20, config={"output": {"cache": False}}): + snapshots.append(graph.get_snapshot()) + results.append(result) + if hasattr(result, "vertex") and result.vertex.id == "router": + current_iteration = graph.context.get("router_iteration", 0) + assert current_iteration == previous_iteration + 1, "Iteration should increment by 1" + previous_iteration = current_iteration + + # Check if the max_iterations logic is working + router_id = router._id.lower() + assert graph.context.get(f"{router_id}_iteration", 0) == 5, "Router should stop after max_iterations" + + # Extract the vertex IDs for analysis + results_ids = [result.vertex.id for result in results if hasattr(result, "vertex")] + assert "chat_output_false" in results_ids, f"Expected outputs not in results: {results_ids}"