feat: Enhance backend with context management, error handling, and refactored code (#4286)

* Add cycle detection and management for graph vertices in run manager

* Refactor: Move AIMLEmbeddingsImpl to a new module path

* Add AIMLEmbeddingsImpl class for document and query embeddings using AIML API

* Add agent components for action routing, decision-making, execution, and context management

- Introduced `AgentActionRouter` to route agent flow based on action type.
- Added `DecideActionComponent` for determining actions from context and prompts.
- Implemented `ExecuteActionComponent` to execute actions using available tools.
- Created `GenerateThoughtComponent` for generating thoughts based on context.
- Developed `ProvideFinalAnswerComponent` to generate final answers from context.
- Built `AgentContextBuilder` for constructing `AgentContext` instances.
- Added `ObserveResultComponent` to process and observe action results.
- Implemented `CheckTerminationComponent` to determine if the agent should continue or terminate.

* Add AgentContext class for managing agent state and context serialization

- Introduced `AgentContext` class in `context.py` to handle agent state, including tools, language model, and context history.
- Implemented serialization methods for converting agent context to JSON-compatible format.
- Added validation for language model instances to ensure compatibility.
- Provided methods for updating and retrieving full context, including context history management.

* Add new agent components to the langflow module's init file

* Update `apply_on_outputs` to use `_outputs_map` in vertex base class

* Add _pre_run_setup method to custom component for pre-execution setup

* Handle non-list action types in decide_action method

* Enhance AgentActionRouter with iteration control and context routing logic

* Fix incorrect variable usage in tool call result message formatting

* Add AgentActionRouter to module exports in agents package

* Refactor cycle detection logic in graph base class

* Add test for complex agent flow with cyclic graph validation

* Enhance readiness checks in tracing service methods

* Add context management to Graph class with dotdict support

* Add context management methods to custom component class

- Introduced a `_ctx` attribute to store context data.
- Added `ctx` property to access the graph's context, raising an error if the graph is not built.
- Implemented `add_to_ctx` method to add key-value pairs to the context with an optional overwrite flag.
- Implemented `update_ctx` method to update the context with a dictionary of values, ensuring the graph is built and the input is a dictionary.

* Add customizable Agent component with input/output handling and action routing

* Handle non-list 'tools' attribute in 'build_context' method

* Convert `get_response` method to asynchronous and update graph processing to use async iteration.

* Add async test for Agent component in graph cycle tests

* Refactor Agent Flow JSON: Simplify input types and update agent component structure

- Removed "BaseTool" from input types for "ToolCallingAgent" to streamline tool handling.
- Updated agent component to a more modular structure with new prompts and input configurations.
- Replaced deprecated methods and fields with updated implementations for improved functionality.
- Adjusted metadata and configuration settings for better clarity and usability.

* [autofix.ci] apply automated fixes

* Add Agent import to init, improve error handling, and clean up imports

- Added `Agent` import to `__init__.py` for better module accessibility.
- Improved error handling in `aiml_embeddings.py` by raising a `ValueError` when the expected embedding count is not met.
- Cleaned up redundant imports in `test_cycles.py` to enhance code readability.

* Refactor agent component imports for improved modularity and organization

* Remove agent components and update `__init__.py` exports

* Add iteration control and default route options to ConditionalRouter component

* Refactor graph tests to include new components and update iteration logic

- Replaced complex agent flow with a simplified guessing game using OpenAI components and conditional routing.
- Introduced `TextInputComponent` and updated `ChatInput` initialization.
- Added new test `test_conditional_router_max_iterations` to validate conditional routing with max iterations.
- Updated graph cyclicity assertions and snapshot checks for improved test coverage.
- Removed deprecated agent components and related logic.

* Refactor conditional router to return message consistently and use iterate_and_stop_once method

* Add return type annotations to methods in langsmith.py

* Remove unnecessary `@override` decorator and add `# noqa: ARG002` comments for unused arguments

* Move ChatInput import inside flow_component fixture in conftest.py

* Update test to use _outputs_map for cycle outputs retrieval

* Refactor `iterate_and_stop_once` to remove redundant `_id` variable usage

* Add default route to ConditionalRouterComponent in cycle test

* Implement synchronous graph execution using threading and queues

- Removed `nest_asyncio` dependency and replaced it with a new threading-based approach for synchronous graph execution.
- Introduced a `queue.Queue` to handle results and exceptions between threads.
- Added a new thread to run asynchronous code, ensuring proper event loop management and task completion.
- Updated methods to return sorted lists of runnable vertices for consistency.

* Update import path for ModelConstants in test_model_constants.py

* [autofix.ci] apply automated fixes

* fix: add property decorator

---------

Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com>
Co-authored-by: italojohnny <italojohnnydosanjos@gmail.com>
This commit is contained in:
Gabriel Luiz Freitas Almeida 2024-11-08 17:28:30 -03:00 committed by GitHub
commit 8681c56cdc
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
11 changed files with 438 additions and 56 deletions

View file

@ -0,0 +1,109 @@
from datetime import datetime, timezone
from typing import Any
from langchain_core.language_models import BaseLanguageModel, BaseLLM
from langchain_core.language_models.chat_models import BaseChatModel
from pydantic import BaseModel, Field, field_validator, model_serializer
from langflow.field_typing import LanguageModel
from langflow.schema.data import Data
class AgentContext(BaseModel):
tools: dict[str, Any]
llm: Any
context: str = ""
iteration: int = 0
max_iterations: int = 5
thought: str = ""
last_action: Any = None
last_action_result: Any = None
final_answer: Any = ""
context_history: list[tuple[str, str, str]] = Field(default_factory=list)
@model_serializer(mode="plain")
def serialize_agent_context(self):
serliazed_llm = self.llm.to_json() if hasattr(self.llm, "to_json") else str(self.llm)
serliazed_tools = {k: v.to_json() if hasattr(v, "to_json") else str(v) for k, v in self.tools.items()}
return {
"tools": serliazed_tools,
"llm": serliazed_llm,
"context": self.context,
"iteration": self.iteration,
"max_iterations": self.max_iterations,
"thought": self.thought,
"last_action": self.last_action.to_json()
if hasattr(self.last_action, "to_json")
else str(self.last_action),
"action_result": self.last_action_result.to_json()
if hasattr(self.last_action_result, "to_json")
else str(self.last_action_result),
"final_answer": self.final_answer,
"context_history": self.context_history,
}
@field_validator("llm", mode="before")
@classmethod
def validate_llm(cls, v) -> LanguageModel:
if not isinstance(v, BaseLLM | BaseChatModel | BaseLanguageModel):
msg = "llm must be an instance of LanguageModel"
raise TypeError(msg)
return v
def to_data_repr(self):
data_objs = []
for name, val, time_str in self.context_history:
content = val.content if hasattr(val, "content") else val
data_objs.append(Data(name=name, value=content, timestamp=time_str))
sorted_data_objs = sorted(data_objs, key=lambda x: datetime.fromisoformat(x.timestamp), reverse=True)
sorted_data_objs.append(
Data(
name="Formatted Context",
value=self.get_full_context(),
)
)
return sorted_data_objs
def _build_tools_context(self):
tool_context = ""
for tool_name, tool_obj in self.tools.items():
tool_context += f"{tool_name}: {tool_obj.description}\n"
return tool_context
def _build_init_context(self):
return f"""
{self.context}
"""
def model_post_init(self, _context: Any) -> None:
if hasattr(self.llm, "bind_tools"):
self.llm = self.llm.bind_tools(self.tools.values())
if self.context:
self.update_context("Initial Context", self.context)
def update_context(self, key: str, value: str):
self.context_history.insert(0, (key, value, datetime.now(tz=timezone.utc).astimezone().isoformat()))
def _serialize_context_history_tuple(self, context_history_tuple: tuple[str, str, str]) -> str:
name, value, _ = context_history_tuple
if hasattr(value, "content"):
value = value.content
elif hasattr(value, "log"):
value = value.log
return f"{name}: {value}"
def get_full_context(self) -> str:
context_history_reversed = self.context_history[::-1]
context_formatted = "\n".join(
[
self._serialize_context_history_tuple(context_history_tuple)
for context_history_tuple in context_history_reversed
]
)
return f"""
Context:
{context_formatted}
"""

View file

@ -30,7 +30,7 @@ class AIMLEmbeddingsImpl(BaseModel, Embeddings):
try:
result_data = future.result()
if len(result_data["data"]) != 1:
msg = "Expected one embedding"
msg = f"Expected one embedding, got {len(result_data['data'])}"
raise ValueError(msg)
embeddings[index] = result_data["data"][0]["embedding"]
except (
@ -38,6 +38,7 @@ class AIMLEmbeddingsImpl(BaseModel, Embeddings):
httpx.RequestError,
json.JSONDecodeError,
KeyError,
ValueError,
):
logger.exception("Error occurred")
raise

View file

@ -1,6 +1,6 @@
from langflow.base.embeddings.aiml_embeddings import AIMLEmbeddingsImpl
from langflow.base.embeddings.model import LCEmbeddingsModel
from langflow.base.models.aiml_constants import AIML_EMBEDDING_MODELS
from langflow.components.embeddings.util import AIMLEmbeddingsImpl
from langflow.field_typing import Embeddings
from langflow.inputs.inputs import DropdownInput
from langflow.io import SecretStrInput

View file

@ -1,10 +0,0 @@
import warnings
from langchain_core._api.deprecation import LangChainDeprecationWarning
with warnings.catch_warnings():
warnings.simplefilter("ignore", LangChainDeprecationWarning)
from .aiml import AIMLEmbeddingsImpl
__all__ = ["AIMLEmbeddingsImpl"]

View file

@ -1,5 +1,5 @@
from langflow.custom import Component
from langflow.io import BoolInput, DropdownInput, MessageInput, MessageTextInput, Output
from langflow.io import BoolInput, DropdownInput, IntInput, MessageInput, MessageTextInput, Output
from langflow.schema.message import Message
@ -9,6 +9,10 @@ class ConditionalRouterComponent(Component):
icon = "equal"
name = "ConditionalRouter"
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self.__iteration_updated = False
inputs = [
MessageTextInput(
name="input_text",
@ -40,6 +44,20 @@ class ConditionalRouterComponent(Component):
display_name="Message",
info="The message to pass through either route.",
),
IntInput(
name="max_iterations",
display_name="Max Iterations",
info="The maximum number of iterations for the conditional router.",
value=10,
),
DropdownInput(
name="default_route",
display_name="Default Route",
options=["true_result", "false_result"],
info="The default route to take when max iterations are reached.",
value="false_result",
advanced=True,
),
]
outputs = [
@ -47,6 +65,9 @@ class ConditionalRouterComponent(Component):
Output(display_name="False Route", name="false_result", method="false_response"),
]
def _pre_run_setup(self):
self.__iteration_updated = False
def evaluate_condition(self, input_text: str, match_text: str, operator: str, *, case_sensitive: bool) -> bool:
if not case_sensitive:
input_text = input_text.lower()
@ -64,15 +85,25 @@ class ConditionalRouterComponent(Component):
return input_text.endswith(match_text)
return False
def iterate_and_stop_once(self, route_to_stop: str):
if not self.__iteration_updated:
self.update_ctx({f"{self._id}_iteration": self.ctx.get(f"{self._id}_iteration", 0) + 1})
self.__iteration_updated = True
if self.ctx.get(f"{self._id}_iteration", 0) >= self.max_iterations and route_to_stop == self.default_route:
# We need to stop the other route
route_to_stop = "true_result" if route_to_stop == "false_result" else "false_result"
self.stop(route_to_stop)
def true_response(self) -> Message:
result = self.evaluate_condition(
self.input_text, self.match_text, self.operator, case_sensitive=self.case_sensitive
)
if result:
self.status = self.message
self.iterate_and_stop_once("false_result")
return self.message
self.stop("true_result")
return None # type: ignore[return-value]
self.iterate_and_stop_once("true_result")
return self.message
def false_response(self) -> Message:
result = self.evaluate_condition(
@ -80,6 +111,7 @@ class ConditionalRouterComponent(Component):
)
if not result:
self.status = self.message
self.iterate_and_stop_once("true_result")
return self.message
self.stop("false_result")
return None # type: ignore[return-value]
self.iterate_and_stop_once("false_result")
return self.message

View file

@ -66,6 +66,7 @@ class Component(CustomComponent):
_output_logs: dict[str, list[Log]] = {}
_current_output: str = ""
_metadata: dict = {}
_ctx: dict = {}
def __init__(self, **kwargs) -> None:
# if key starts with _ it is a config
@ -111,6 +112,53 @@ class Component(CustomComponent):
self.set_class_code()
self._set_output_required_inputs()
@property
def ctx(self):
if not hasattr(self, "graph") or self.graph is None:
msg = "Graph not found. Please build the graph first."
raise ValueError(msg)
return self.graph.context
def add_to_ctx(self, key: str, value: Any, *, overwrite: bool = False) -> None:
"""Add a key-value pair to the context.
Args:
key (str): The key to add.
value (Any): The value to associate with the key.
overwrite (bool, optional): Whether to overwrite the existing value. Defaults to False.
Raises:
ValueError: If the graph is not built.
"""
if not hasattr(self, "graph") or self.graph is None:
msg = "Graph not found. Please build the graph first."
raise ValueError(msg)
if key in self.graph.context and not overwrite:
msg = f"Key {key} already exists in context. Set overwrite=True to overwrite."
raise ValueError(msg)
self.graph.context.update({key: value})
def update_ctx(self, value_dict: dict[str, Any]) -> None:
"""Update the context with a dictionary of values.
Args:
value_dict (dict[str, Any]): The dictionary of values to update.
Raises:
ValueError: If the graph is not built.
"""
if not hasattr(self, "graph") or self.graph is None:
msg = "Graph not found. Please build the graph first."
raise ValueError(msg)
if not isinstance(value_dict, dict):
msg = "Value dict must be a dictionary"
raise TypeError(msg)
self.graph.context.update(value_dict)
def _pre_run_setup(self):
pass
def set_event_manager(self, event_manager: EventManager | None = None) -> None:
self._event_manager = event_manager
@ -768,7 +816,8 @@ class Component(CustomComponent):
async def _build_results(self) -> tuple[dict, dict]:
_results = {}
_artifacts = {}
if hasattr(self, "_pre_run_setup"):
self._pre_run_setup()
if hasattr(self, "outputs"):
if any(getattr(_input, "tool_mode", False) for _input in self.inputs):
self._append_tool_to_outputs_map()

View file

@ -4,6 +4,8 @@ import asyncio
import contextlib
import copy
import json
import queue
import threading
import uuid
from collections import defaultdict, deque
from collections.abc import Generator, Iterable
@ -12,7 +14,6 @@ from functools import partial
from itertools import chain
from typing import TYPE_CHECKING, Any, cast
import nest_asyncio
from loguru import logger
from langflow.exceptions.component import ComponentBuildError
@ -26,7 +27,6 @@ from langflow.graph.graph.utils import (
find_all_cycle_edges,
find_cycle_vertices,
find_start_component_id,
has_cycle,
process_flow,
should_continue,
sort_up_to_vertex,
@ -36,6 +36,7 @@ from langflow.graph.vertex.base import Vertex, VertexStates
from langflow.graph.vertex.schema import NodeData, NodeTypeEnum
from langflow.graph.vertex.types import ComponentVertex, InterfaceVertex, StateVertex
from langflow.logging.logger import LogConfig, configure
from langflow.schema.dotdict import dotdict
from langflow.schema.schema import INPUT_FIELD_NAME, InputType
from langflow.services.cache.utils import CacheMiss
from langflow.services.deps import get_chat_service, get_tracing_service
@ -63,6 +64,7 @@ class Graph:
description: str | None = None,
user_id: str | None = None,
log_config: LogConfig | None = None,
context: dict[str, Any] | None = None,
) -> None:
"""Initializes a new instance of the Graph class.
@ -74,9 +76,11 @@ class Graph:
description: The graph description.
user_id: The user ID.
log_config: The log configuration.
context: Additional context for the graph. Defaults to None.
"""
if log_config:
configure(**log_config)
self._start = start
self._state_model = None
self._end = end
@ -107,6 +111,7 @@ class Graph:
self.state_manager = GraphStateManager()
self._vertices: list[NodeData] = []
self._edges: list[EdgeData] = []
self.top_level_vertices: list[str] = []
self.vertex_map: dict[str, Vertex] = {}
self.predecessor_map: dict[str, list[str]] = defaultdict(list)
@ -123,6 +128,11 @@ class Graph:
self._call_order: list[str] = []
self._snapshots: list[dict[str, Any]] = []
self._end_trace_tasks: set[asyncio.Task] = set()
if context and not isinstance(context, dict):
msg = "Context must be a dictionary"
raise TypeError(msg)
self._context = dotdict(context or {})
try:
self.tracing_service: TracingService | None = get_tracing_service()
except Exception: # noqa: BLE001
@ -135,6 +145,21 @@ class Graph:
msg = "You must provide both input and output components"
raise ValueError(msg)
@property
def context(self) -> dotdict:
if isinstance(self._context, dotdict):
return self._context
return dotdict(self._context)
@context.setter
def context(self, value: dict[str, Any]):
if not isinstance(value, dict):
msg = "Context must be a dictionary"
raise TypeError(msg)
if isinstance(value, dict):
value = dotdict(value)
self._context = value
@property
def session_id(self):
return self._session_id
@ -217,6 +242,8 @@ class Graph:
for vertex in self._vertices:
if vertex_id := vertex.get("id"):
self.top_level_vertices.append(vertex_id)
if vertex_id in self.cycle_vertices:
self.run_manager.add_to_cycle_vertices(vertex_id)
self._graph_data = process_flow(self.raw_graph_data)
self._vertices = self._graph_data["nodes"]
@ -360,26 +387,81 @@ class Graph:
config: StartConfigDict | None = None,
event_manager: EventManager | None = None,
) -> Generator:
"""Starts the graph execution synchronously by creating a new event loop in a separate thread.
Args:
inputs: Optional list of input dictionaries
max_iterations: Optional maximum number of iterations
config: Optional configuration dictionary
event_manager: Optional event manager
Returns:
Generator yielding results from graph execution
"""
if self.is_cyclic and max_iterations is None:
msg = "You must specify a max_iterations if the graph is cyclic"
raise ValueError(msg)
if config is not None:
self.__apply_config(config)
# ! Change this ASAP
nest_asyncio.apply()
loop = asyncio.get_event_loop()
async_gen = self.async_start(inputs, max_iterations, event_manager)
async_gen_task = asyncio.ensure_future(anext(async_gen))
while True:
# Create a queue for passing results and errors between threads
result_queue: queue.Queue[VertexBuildResult | Exception | None] = queue.Queue()
# Function to run async code in separate thread
def run_async_code():
# Create new event loop for this thread
loop = asyncio.new_event_loop()
asyncio.set_event_loop(loop)
try:
result = loop.run_until_complete(async_gen_task)
yield result
if isinstance(result, Finish):
return
async_gen_task = asyncio.ensure_future(anext(async_gen))
except StopAsyncIteration:
# Run the async generator
async_gen = self.async_start(inputs, max_iterations, event_manager)
while True:
try:
# Get next result from async generator
result = loop.run_until_complete(anext(async_gen))
result_queue.put(result)
if isinstance(result, Finish):
break
except StopAsyncIteration:
break
except ValueError as e:
# Put the exception in the queue
result_queue.put(e)
break
finally:
# Ensure all pending tasks are completed
pending = asyncio.all_tasks(loop)
if pending:
# Create a future to gather all pending tasks
cleanup_future = asyncio.gather(*pending, return_exceptions=True)
loop.run_until_complete(cleanup_future)
# Close the loop
loop.close()
# Signal completion
result_queue.put(None)
# Start thread for async execution
thread = threading.Thread(target=run_async_code)
thread.start()
# Yield results from queue
while True:
result = result_queue.get()
if result is None:
break
if isinstance(result, Exception):
raise result
yield result
# Wait for thread to complete
thread.join()
def _add_edge(self, edge: EdgeData) -> None:
self.add_edge(edge)
@ -533,12 +615,7 @@ class Graph:
bool: True if the graph has any cycles, False otherwise.
"""
if self._is_cyclic is None:
vertices = [vertex.id for vertex in self.vertices]
try:
edges = [(e["data"]["sourceHandle"]["id"], e["data"]["targetHandle"]["id"]) for e in self._edges]
except KeyError:
edges = [(e["source"], e["target"]) for e in self._edges]
self._is_cyclic = has_cycle(vertices, edges)
self._is_cyclic = bool(self.cycle_vertices)
return self._is_cyclic
@property
@ -1136,6 +1213,9 @@ class Graph:
self._build_vertex_params()
self._instantiate_components_in_vertices()
self._set_cache_to_vertices_in_cycle()
for vertex in self.vertices:
if vertex.id in self.cycle_vertices:
self.run_manager.add_to_cycle_vertices(vertex.id)
def _get_edges_as_list_of_tuples(self) -> list[tuple[str, str]]:
"""Returns the edges of the graph as a list of tuples."""
@ -1455,7 +1535,7 @@ class Graph:
else:
next_runnable_vertices.add(v_id)
return list(next_runnable_vertices)
return sorted(next_runnable_vertices)
async def get_next_runnable_vertices(self, lock: asyncio.Lock, vertex: Vertex, *, cache: bool = True) -> list[str]:
v_id = vertex.id
@ -1717,6 +1797,8 @@ class Graph:
for vertex_id in first_layer:
self.run_manager.add_to_vertices_being_run(vertex_id)
if vertex_id in self.cycle_vertices:
self.run_manager.add_to_cycle_vertices(vertex_id)
self._first_layer = sorted(first_layer)
self._run_queue = deque(self._first_layer)
self._prepared = True
@ -1993,7 +2075,7 @@ class Graph:
for successor_id in self.run_manager.run_map.get(vertex_id, []):
runnable_vertices.extend(self.find_runnable_predecessors_for_successor(successor_id))
return runnable_vertices
return sorted(runnable_vertices)
def find_runnable_predecessors_for_successor(self, vertex_id: str) -> list[str]:
runnable_vertices = []

View file

@ -2,11 +2,12 @@ from collections import defaultdict
class RunnableVerticesManager:
def __init__(self) -> None:
def __init__(self):
self.run_map: dict[str, list[str]] = defaultdict(list) # Tracks successors of each vertex
self.run_predecessors: dict[str, set[str]] = defaultdict(set) # Tracks predecessors for each vertex
self.vertices_to_run: set[str] = set() # Set of vertices that are ready to run
self.vertices_being_run: set[str] = set() # Set of vertices that are currently running
self.cycle_vertices: set[str] = set() # Set of vertices that are in a cycle
def to_dict(self) -> dict:
return {
@ -55,7 +56,7 @@ class RunnableVerticesManager:
return False
if vertex_id not in self.vertices_to_run:
return False
return self.are_all_predecessors_fulfilled(vertex_id)
return self.are_all_predecessors_fulfilled(vertex_id) or vertex_id in self.cycle_vertices
def are_all_predecessors_fulfilled(self, vertex_id: str) -> bool:
return not any(self.run_predecessors.get(vertex_id, []))
@ -89,3 +90,6 @@ class RunnableVerticesManager:
def add_to_vertices_being_run(self, v_id) -> None:
self.vertices_being_run.add(v_id)
def add_to_cycle_vertices(self, v_id):
self.cycle_vertices.add(v_id)

View file

@ -871,4 +871,4 @@ class Vertex:
if not self.custom_component or not self.custom_component.outputs:
return
# Apply the function to each output
[func(output) for output in self.custom_component.outputs]
[func(output) for output in self.custom_component._outputs_map.values()]

View file

@ -7,7 +7,6 @@ from datetime import datetime, timezone
from typing import TYPE_CHECKING, Any
from loguru import logger
from typing_extensions import override
from langflow.schema.data import Data
from langflow.services.tracing.base import BaseTracer
@ -63,17 +62,16 @@ class LangSmithTracer(BaseTracer):
os.environ["LANGCHAIN_TRACING_V2"] = "true"
return True
@override
def add_trace(
self,
trace_id: str,
trace_id: str, # noqa: ARG002
trace_name: str,
trace_type: str,
inputs: dict[str, Any],
metadata: dict[str, Any] | None = None,
vertex: Vertex | None = None,
vertex: Vertex | None = None, # noqa: ARG002
) -> None:
if not self._ready:
if not self._ready or not self._run_tree:
return
processed_inputs = {}
if inputs:
@ -117,16 +115,15 @@ class LangSmithTracer(BaseTracer):
value = str(value)
return value
@override
def end_trace(
self,
trace_id: str,
trace_id: str, # noqa: ARG002
trace_name: str,
outputs: dict[str, Any] | None = None,
error: Exception | None = None,
logs: Sequence[Log | dict] = (),
) -> None:
if not self._ready:
):
if not self._ready or trace_name not in self._children:
return
child = self._children[trace_name]
raw_outputs = {}
@ -159,7 +156,7 @@ class LangSmithTracer(BaseTracer):
error: Exception | None = None,
metadata: dict[str, Any] | None = None,
) -> None:
if not self._ready:
if not self._ready or not self._run_tree:
return
self._run_tree.add_metadata({"inputs": inputs})
if metadata:

View file

@ -2,6 +2,7 @@ import os
import pytest
from langflow.components.inputs import ChatInput
from langflow.components.inputs.text import TextInputComponent
from langflow.components.models import OpenAIModelComponent
from langflow.components.outputs import ChatOutput, TextOutputComponent
from langflow.components.prompts import PromptComponent
@ -31,7 +32,7 @@ class Concatenate(Component):
@pytest.mark.skip(reason="Temporarily disabled")
def test_cycle_in_graph():
chat_input = ChatInput(_id="chat_input")
router = ConditionalRouterComponent(_id="router")
router = ConditionalRouterComponent(_id="router", default_route="true_result")
chat_input.set(input_value=router.false_response)
concat_component = Concatenate(_id="concatenate")
concat_component.set(text=chat_input.message_response)
@ -59,7 +60,6 @@ def test_cycle_in_graph():
snapshots.append(graph._snapshot())
results.append(result)
results_ids = [result.vertex.id for result in results if hasattr(result, "vertex")]
assert results_ids[-2:] == ["text_output", "chat_output"]
assert len(results_ids) > len(graph.vertices), snapshots
# Check that chat_output and text_output are the last vertices in the results
assert results_ids == [
@ -127,7 +127,9 @@ def test_that_outputs_cache_is_set_to_false_in_cycle():
graph = Graph(chat_input, chat_output)
cycle_vertices = find_cycle_vertices(graph._get_edges_as_list_of_tuples())
cycle_outputs_lists = [graph.vertex_map[vertex_id].custom_component.outputs for vertex_id in cycle_vertices]
cycle_outputs_lists = [
graph.vertex_map[vertex_id].custom_component._outputs_map.values() for vertex_id in cycle_vertices
]
cycle_outputs = [output for outputs in cycle_outputs_lists for output in outputs]
for output in cycle_outputs:
assert output.cache is False
@ -206,3 +208,119 @@ def test_updated_graph_with_prompts():
# Extract the vertex IDs for analysis
results_ids = [result.vertex.id for result in results if hasattr(result, "vertex")]
assert "chat_output_1" in results_ids, f"Expected outputs not in results: {results_ids}"
@pytest.mark.api_key_required
def test_updated_graph_with_max_iterations():
# Chat input initialization
chat_input = ChatInput(_id="chat_input").set(input_value="bacon")
# First prompt: Guessing game with hints
prompt_component_1 = PromptComponent(_id="prompt_component_1").set(
template="Try to guess a word. I will give you hints if you get it wrong.\n"
"Hint: {hint}\n"
"Last try: {last_try}\n"
"Answer:",
)
# First OpenAI LLM component (Processes the guessing prompt)
openai_component_1 = OpenAIModelComponent(_id="openai_1").set(
input_value=prompt_component_1.build_prompt, api_key=os.getenv("OPENAI_API_KEY")
)
# Conditional router based on agent response
router = ConditionalRouterComponent(_id="router").set(
input_text=openai_component_1.text_response,
match_text=chat_input.message_response,
operator="contains",
message=openai_component_1.text_response,
)
# Second prompt: After the last try, provide a new hint
prompt_component_2 = PromptComponent(_id="prompt_component_2")
prompt_component_2.set(
template="Given the following word and the following last try. Give the guesser a new hint.\n"
"Last try: {last_try}\n"
"Word: {word}\n"
"Hint:",
word=chat_input.message_response,
last_try=router.false_response,
)
# Second OpenAI component (handles the router's response)
openai_component_2 = OpenAIModelComponent(_id="openai_2")
openai_component_2.set(input_value=prompt_component_2.build_prompt, api_key=os.getenv("OPENAI_API_KEY"))
prompt_component_1.set(hint=openai_component_2.text_response, last_try=router.false_response)
# chat output for the final OpenAI response
chat_output_1 = ChatOutput(_id="chat_output_1")
chat_output_1.set(input_value=router.true_response)
# Build the graph without concatenate
graph = Graph(chat_input, chat_output_1)
# Assertions for graph cyclicity and correctness
assert graph.is_cyclic is True, "Graph should contain cycles."
# Run and validate the execution of the graph
results = []
max_iterations = 20
snapshots = [graph.get_snapshot()]
for result in graph.start(max_iterations=max_iterations, config={"output": {"cache": False}}):
snapshots.append(graph.get_snapshot())
results.append(result)
assert len(snapshots) > 2, "Graph should have more than one snapshot"
# Extract the vertex IDs for analysis
results_ids = [result.vertex.id for result in results if hasattr(result, "vertex")]
assert "chat_output_1" in results_ids, f"Expected outputs not in results: {results_ids}"
def test_conditional_router_max_iterations():
# Chat input initialization
text_input = TextInputComponent(_id="text_input")
# Conditional router setup with a condition that will never match
router = ConditionalRouterComponent(_id="router").set(
input_text=text_input.text_response,
match_text="bacon",
operator="equals",
message="This message should not be routed to true_result",
max_iterations=5,
default_route="true_result",
)
# Chat output for the true route
text_input.set(input_value=router.false_response)
# Chat output for the false route
chat_output_false = ChatOutput(_id="chat_output_false")
chat_output_false.set(input_value=router.true_response)
# Build the graph
graph = Graph(text_input, chat_output_false)
# Assertions for graph cyclicity and correctness
assert graph.is_cyclic is True, "Graph should contain cycles."
# Run and validate the execution of the graph
results = []
snapshots = [graph.get_snapshot()]
previous_iteration = graph.context.get("router_iteration", 0)
for result in graph.start(max_iterations=20, config={"output": {"cache": False}}):
snapshots.append(graph.get_snapshot())
results.append(result)
if hasattr(result, "vertex") and result.vertex.id == "router":
current_iteration = graph.context.get("router_iteration", 0)
assert current_iteration == previous_iteration + 1, "Iteration should increment by 1"
previous_iteration = current_iteration
# Check if the max_iterations logic is working
router_id = router._id.lower()
assert graph.context.get(f"{router_id}_iteration", 0) == 5, "Router should stop after max_iterations"
# Extract the vertex IDs for analysis
results_ids = [result.vertex.id for result in results if hasattr(result, "vertex")]
assert "chat_output_false" in results_ids, f"Expected outputs not in results: {results_ids}"