feat: Enhance backend with context management, error handling, and refactored code (#4286)
* Add cycle detection and management for graph vertices in run manager
* Refactor: Move AIMLEmbeddingsImpl to a new module path
* Add AIMLEmbeddingsImpl class for document and query embeddings using AIML API
* Add agent components for action routing, decision-making, execution, and context management
  - Introduced `AgentActionRouter` to route agent flow based on action type.
  - Added `DecideActionComponent` for determining actions from context and prompts.
  - Implemented `ExecuteActionComponent` to execute actions using available tools.
  - Created `GenerateThoughtComponent` for generating thoughts based on context.
  - Developed `ProvideFinalAnswerComponent` to generate final answers from context.
  - Built `AgentContextBuilder` for constructing `AgentContext` instances.
  - Added `ObserveResultComponent` to process and observe action results.
  - Implemented `CheckTerminationComponent` to determine if the agent should continue or terminate.
* Add AgentContext class for managing agent state and context serialization
  - Introduced `AgentContext` class in `context.py` to handle agent state, including tools, language model, and context history.
  - Implemented serialization methods for converting agent context to a JSON-compatible format.
  - Added validation for language model instances to ensure compatibility.
  - Provided methods for updating and retrieving the full context, including context history management.
* Add new agent components to the langflow module's init file
* Update `apply_on_outputs` to use `_outputs_map` in vertex base class
* Add `_pre_run_setup` method to custom component for pre-execution setup
* Handle non-list action types in `decide_action` method
* Enhance AgentActionRouter with iteration control and context routing logic
* Fix incorrect variable usage in tool call result message formatting
* Add AgentActionRouter to module exports in agents package
* Refactor cycle detection logic in graph base class
* Add test for complex agent flow with cyclic graph validation
* Enhance readiness checks in tracing service methods
* Add context management to Graph class with dotdict support
* Add context management methods to custom component class
  - Introduced a `_ctx` attribute to store context data.
  - Added `ctx` property to access the graph's context, raising an error if the graph is not built.
  - Implemented `add_to_ctx` method to add key-value pairs to the context with an optional overwrite flag.
  - Implemented `update_ctx` method to update the context with a dictionary of values, ensuring the graph is built and the input is a dictionary.
* Add customizable Agent component with input/output handling and action routing
* Handle non-list 'tools' attribute in `build_context` method
* Convert `get_response` method to asynchronous and update graph processing to use async iteration
* Add async test for Agent component in graph cycle tests
* Refactor Agent Flow JSON: Simplify input types and update agent component structure
  - Removed "BaseTool" from input types for "ToolCallingAgent" to streamline tool handling.
  - Updated agent component to a more modular structure with new prompts and input configurations.
  - Replaced deprecated methods and fields with updated implementations for improved functionality.
  - Adjusted metadata and configuration settings for better clarity and usability.
* [autofix.ci] apply automated fixes
* Add Agent import to init, improve error handling, and clean up imports
  - Added `Agent` import to `__init__.py` for better module accessibility.
  - Improved error handling in `aiml_embeddings.py` by raising a `ValueError` when the expected embedding count is not met.
  - Cleaned up redundant imports in `test_cycles.py` to enhance code readability.
* Refactor agent component imports for improved modularity and organization
* Remove agent components and update `__init__.py` exports
* Add iteration control and default route options to ConditionalRouter component
* Refactor graph tests to include new components and update iteration logic
  - Replaced complex agent flow with a simplified guessing game using OpenAI components and conditional routing.
  - Introduced `TextInputComponent` and updated `ChatInput` initialization.
  - Added new test `test_conditional_router_max_iterations` to validate conditional routing with max iterations.
  - Updated graph cyclicity assertions and snapshot checks for improved test coverage.
  - Removed deprecated agent components and related logic.
* Refactor conditional router to return message consistently and use `iterate_and_stop_once` method
* Add return type annotations to methods in langsmith.py
* Remove unnecessary `@override` decorator and add `# noqa: ARG002` comments for unused arguments
* Move ChatInput import inside `flow_component` fixture in conftest.py
* Update test to use `_outputs_map` for cycle outputs retrieval
* Refactor `iterate_and_stop_once` to remove redundant `_id` variable usage
* Add default route to ConditionalRouterComponent in cycle test
* Implement synchronous graph execution using threading and queues
  - Removed `nest_asyncio` dependency and replaced it with a new threading-based approach for synchronous graph execution.
  - Introduced a `queue.Queue` to handle results and exceptions between threads.
  - Added a new thread to run asynchronous code, ensuring proper event loop management and task completion.
  - Updated methods to return sorted lists of runnable vertices for consistency.
* Update import path for ModelConstants in test_model_constants.py
* [autofix.ci] apply automated fixes
* fix: add property decorator

---------

Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com>
Co-authored-by: italojohnny <italojohnnydosanjos@gmail.com>
This commit is contained in:
parent b898d8a652
commit 8681c56cdc
11 changed files with 438 additions and 56 deletions
109  src/backend/base/langflow/base/agents/context.py  Normal file
@@ -0,0 +1,109 @@
from datetime import datetime, timezone
from typing import Any

from langchain_core.language_models import BaseLanguageModel, BaseLLM
from langchain_core.language_models.chat_models import BaseChatModel
from pydantic import BaseModel, Field, field_validator, model_serializer

from langflow.field_typing import LanguageModel
from langflow.schema.data import Data


class AgentContext(BaseModel):
    tools: dict[str, Any]
    llm: Any
    context: str = ""
    iteration: int = 0
    max_iterations: int = 5
    thought: str = ""
    last_action: Any = None
    last_action_result: Any = None
    final_answer: Any = ""
    context_history: list[tuple[str, str, str]] = Field(default_factory=list)

    @model_serializer(mode="plain")
    def serialize_agent_context(self):
        serialized_llm = self.llm.to_json() if hasattr(self.llm, "to_json") else str(self.llm)
        serialized_tools = {k: v.to_json() if hasattr(v, "to_json") else str(v) for k, v in self.tools.items()}
        return {
            "tools": serialized_tools,
            "llm": serialized_llm,
            "context": self.context,
            "iteration": self.iteration,
            "max_iterations": self.max_iterations,
            "thought": self.thought,
            "last_action": self.last_action.to_json()
            if hasattr(self.last_action, "to_json")
            else str(self.last_action),
            "action_result": self.last_action_result.to_json()
            if hasattr(self.last_action_result, "to_json")
            else str(self.last_action_result),
            "final_answer": self.final_answer,
            "context_history": self.context_history,
        }

    @field_validator("llm", mode="before")
    @classmethod
    def validate_llm(cls, v) -> LanguageModel:
        if not isinstance(v, BaseLLM | BaseChatModel | BaseLanguageModel):
            msg = "llm must be an instance of LanguageModel"
            raise TypeError(msg)
        return v

    def to_data_repr(self):
        data_objs = []
        for name, val, time_str in self.context_history:
            content = val.content if hasattr(val, "content") else val
            data_objs.append(Data(name=name, value=content, timestamp=time_str))

        sorted_data_objs = sorted(data_objs, key=lambda x: datetime.fromisoformat(x.timestamp), reverse=True)

        sorted_data_objs.append(
            Data(
                name="Formatted Context",
                value=self.get_full_context(),
            )
        )
        return sorted_data_objs

    def _build_tools_context(self):
        tool_context = ""
        for tool_name, tool_obj in self.tools.items():
            tool_context += f"{tool_name}: {tool_obj.description}\n"
        return tool_context

    def _build_init_context(self):
        return f"""
{self.context}

"""

    def model_post_init(self, _context: Any) -> None:
        if hasattr(self.llm, "bind_tools"):
            self.llm = self.llm.bind_tools(self.tools.values())
        if self.context:
            self.update_context("Initial Context", self.context)

    def update_context(self, key: str, value: str):
        self.context_history.insert(0, (key, value, datetime.now(tz=timezone.utc).astimezone().isoformat()))

    def _serialize_context_history_tuple(self, context_history_tuple: tuple[str, str, str]) -> str:
        name, value, _ = context_history_tuple
        if hasattr(value, "content"):
            value = value.content
        elif hasattr(value, "log"):
            value = value.log
        return f"{name}: {value}"

    def get_full_context(self) -> str:
        context_history_reversed = self.context_history[::-1]
        context_formatted = "\n".join(
            [
                self._serialize_context_history_tuple(context_history_tuple)
                for context_history_tuple in context_history_reversed
            ]
        )
        return f"""
Context:
{context_formatted}
"""
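The history bookkeeping in `AgentContext` above — newest entries inserted at index 0, chronological order restored for display — can be exercised in isolation. A stdlib-only sketch of just that pattern (illustrative, not the Langflow class):

```python
from datetime import datetime, timezone


class ContextHistory:
    """Sketch of AgentContext's history bookkeeping: newest entries go to
    the front of the list, while the formatted view reads oldest-first."""

    def __init__(self) -> None:
        self.entries: list[tuple[str, str, str]] = []

    def update_context(self, key: str, value: str) -> None:
        # Prepend with an ISO timestamp, matching the tuple shape above.
        timestamp = datetime.now(tz=timezone.utc).astimezone().isoformat()
        self.entries.insert(0, (key, value, timestamp))

    def get_full_context(self) -> str:
        # Reverse so the oldest entry is printed first.
        lines = [f"{name}: {value}" for name, value, _ in self.entries[::-1]]
        return "Context:\n" + "\n".join(lines)


history = ContextHistory()
history.update_context("Initial Context", "You are a helpful agent.")
history.update_context("Thought", "I should look up the answer.")
```

Prepending keeps "most recent" lookups O(1) at index 0, at the cost of a reversal when rendering the chronological transcript for the prompt.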
@@ -30,7 +30,7 @@ class AIMLEmbeddingsImpl(BaseModel, Embeddings):
             try:
                 result_data = future.result()
                 if len(result_data["data"]) != 1:
-                    msg = "Expected one embedding"
+                    msg = f"Expected one embedding, got {len(result_data['data'])}"
                     raise ValueError(msg)
                 embeddings[index] = result_data["data"][0]["embedding"]
             except (
@@ -38,6 +38,7 @@ class AIMLEmbeddingsImpl(BaseModel, Embeddings):
                 httpx.RequestError,
                 json.JSONDecodeError,
                 KeyError,
+                ValueError,
             ):
                 logger.exception("Error occurred")
                 raise
@@ -1,6 +1,6 @@
-from langflow.base.embeddings.aiml_embeddings import AIMLEmbeddingsImpl
 from langflow.base.embeddings.model import LCEmbeddingsModel
 from langflow.base.models.aiml_constants import AIML_EMBEDDING_MODELS
+from langflow.components.embeddings.util import AIMLEmbeddingsImpl
 from langflow.field_typing import Embeddings
 from langflow.inputs.inputs import DropdownInput
 from langflow.io import SecretStrInput
@@ -1,10 +0,0 @@
-import warnings
-
-from langchain_core._api.deprecation import LangChainDeprecationWarning
-
-with warnings.catch_warnings():
-    warnings.simplefilter("ignore", LangChainDeprecationWarning)
-    from .aiml import AIMLEmbeddingsImpl
-
-__all__ = ["AIMLEmbeddingsImpl"]
@@ -1,5 +1,5 @@
 from langflow.custom import Component
-from langflow.io import BoolInput, DropdownInput, MessageInput, MessageTextInput, Output
+from langflow.io import BoolInput, DropdownInput, IntInput, MessageInput, MessageTextInput, Output
 from langflow.schema.message import Message


@@ -9,6 +9,10 @@ class ConditionalRouterComponent(Component):
     icon = "equal"
     name = "ConditionalRouter"

+    def __init__(self, *args, **kwargs):
+        super().__init__(*args, **kwargs)
+        self.__iteration_updated = False
+
     inputs = [
         MessageTextInput(
             name="input_text",
@@ -40,6 +44,20 @@ class ConditionalRouterComponent(Component):
             display_name="Message",
             info="The message to pass through either route.",
         ),
+        IntInput(
+            name="max_iterations",
+            display_name="Max Iterations",
+            info="The maximum number of iterations for the conditional router.",
+            value=10,
+        ),
+        DropdownInput(
+            name="default_route",
+            display_name="Default Route",
+            options=["true_result", "false_result"],
+            info="The default route to take when max iterations are reached.",
+            value="false_result",
+            advanced=True,
+        ),
     ]

     outputs = [
@@ -47,6 +65,9 @@ class ConditionalRouterComponent(Component):
         Output(display_name="False Route", name="false_result", method="false_response"),
     ]

+    def _pre_run_setup(self):
+        self.__iteration_updated = False
+
     def evaluate_condition(self, input_text: str, match_text: str, operator: str, *, case_sensitive: bool) -> bool:
         if not case_sensitive:
             input_text = input_text.lower()
@@ -64,15 +85,25 @@ class ConditionalRouterComponent(Component):
             return input_text.endswith(match_text)
         return False

+    def iterate_and_stop_once(self, route_to_stop: str):
+        if not self.__iteration_updated:
+            self.update_ctx({f"{self._id}_iteration": self.ctx.get(f"{self._id}_iteration", 0) + 1})
+            self.__iteration_updated = True
+        if self.ctx.get(f"{self._id}_iteration", 0) >= self.max_iterations and route_to_stop == self.default_route:
+            # We need to stop the other route
+            route_to_stop = "true_result" if route_to_stop == "false_result" else "false_result"
+        self.stop(route_to_stop)
+
     def true_response(self) -> Message:
         result = self.evaluate_condition(
             self.input_text, self.match_text, self.operator, case_sensitive=self.case_sensitive
         )
         if result:
             self.status = self.message
+            self.iterate_and_stop_once("false_result")
             return self.message
-        self.stop("true_result")
-        return None  # type: ignore[return-value]
+        self.iterate_and_stop_once("true_result")
+        return self.message

     def false_response(self) -> Message:
         result = self.evaluate_condition(
@@ -80,6 +111,7 @@ class ConditionalRouterComponent(Component):
         )
         if not result:
             self.status = self.message
+            self.iterate_and_stop_once("true_result")
             return self.message
-        self.stop("false_result")
-        return None  # type: ignore[return-value]
+        self.iterate_and_stop_once("false_result")
+        return self.message
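The `iterate_and_stop_once` change above implements a loop budget: each pass through the cyclic router bumps a counter in shared graph context, and once the budget is spent the router is forced onto its default route. The core idea, stripped of Langflow's component machinery, looks roughly like this (a hypothetical `LoopGuard` class for illustration):

```python
class LoopGuard:
    """Sketch of the ConditionalRouter max-iterations pattern: count passes
    through a cyclic node in shared context and force the default route once
    the budget is exhausted. Illustrative only, not the Langflow component."""

    def __init__(self, node_id: str, ctx: dict, max_iterations: int = 10,
                 default_route: str = "false_result") -> None:
        self.node_id = node_id
        self.ctx = ctx  # shared, graph-level context dict
        self.max_iterations = max_iterations
        self.default_route = default_route
        self._iteration_updated = False

    def pre_run_setup(self) -> None:
        # Reset once per graph run so each pass counts exactly once.
        self._iteration_updated = False

    def route(self, condition: bool) -> str:
        key = f"{self.node_id}_iteration"
        if not self._iteration_updated:
            self.ctx[key] = self.ctx.get(key, 0) + 1
            self._iteration_updated = True
        if self.ctx.get(key, 0) >= self.max_iterations:
            return self.default_route  # budget exhausted: break the cycle
        return "true_result" if condition else "false_result"


ctx: dict = {}
guard = LoopGuard("router", ctx, max_iterations=3)
routes = []
for _ in range(5):
    guard.pre_run_setup()
    routes.append(guard.route(condition=True))
```

The counter lives in the shared context rather than on the instance because, in a cyclic graph, the component may be re-instantiated per pass; the `_iteration_updated` flag guards against double-counting when both route methods run in one pass.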
@@ -66,6 +66,7 @@ class Component(CustomComponent):
     _output_logs: dict[str, list[Log]] = {}
     _current_output: str = ""
     _metadata: dict = {}
+    _ctx: dict = {}

     def __init__(self, **kwargs) -> None:
         # if key starts with _ it is a config
@@ -111,6 +112,53 @@ class Component(CustomComponent):
         self.set_class_code()
         self._set_output_required_inputs()

+    @property
+    def ctx(self):
+        if not hasattr(self, "graph") or self.graph is None:
+            msg = "Graph not found. Please build the graph first."
+            raise ValueError(msg)
+        return self.graph.context
+
+    def add_to_ctx(self, key: str, value: Any, *, overwrite: bool = False) -> None:
+        """Add a key-value pair to the context.
+
+        Args:
+            key (str): The key to add.
+            value (Any): The value to associate with the key.
+            overwrite (bool, optional): Whether to overwrite the existing value. Defaults to False.
+
+        Raises:
+            ValueError: If the graph is not built.
+        """
+        if not hasattr(self, "graph") or self.graph is None:
+            msg = "Graph not found. Please build the graph first."
+            raise ValueError(msg)
+        if key in self.graph.context and not overwrite:
+            msg = f"Key {key} already exists in context. Set overwrite=True to overwrite."
+            raise ValueError(msg)
+        self.graph.context.update({key: value})
+
+    def update_ctx(self, value_dict: dict[str, Any]) -> None:
+        """Update the context with a dictionary of values.
+
+        Args:
+            value_dict (dict[str, Any]): The dictionary of values to update.
+
+        Raises:
+            ValueError: If the graph is not built.
+        """
+        if not hasattr(self, "graph") or self.graph is None:
+            msg = "Graph not found. Please build the graph first."
+            raise ValueError(msg)
+        if not isinstance(value_dict, dict):
+            msg = "Value dict must be a dictionary"
+            raise TypeError(msg)
+
+        self.graph.context.update(value_dict)
+
+    def _pre_run_setup(self):
+        pass
+
     def set_event_manager(self, event_manager: EventManager | None = None) -> None:
         self._event_manager = event_manager
@@ -768,7 +816,8 @@ class Component(CustomComponent):
     async def _build_results(self) -> tuple[dict, dict]:
         _results = {}
         _artifacts = {}
-
+        if hasattr(self, "_pre_run_setup"):
+            self._pre_run_setup()
         if hasattr(self, "outputs"):
             if any(getattr(_input, "tool_mode", False) for _input in self.inputs):
                 self._append_tool_to_outputs_map()
@@ -4,6 +4,8 @@ import asyncio
 import contextlib
 import copy
 import json
+import queue
+import threading
 import uuid
 from collections import defaultdict, deque
 from collections.abc import Generator, Iterable
@@ -12,7 +14,6 @@ from functools import partial
 from itertools import chain
 from typing import TYPE_CHECKING, Any, cast

-import nest_asyncio
 from loguru import logger

 from langflow.exceptions.component import ComponentBuildError
@@ -26,7 +27,6 @@ from langflow.graph.graph.utils import (
     find_all_cycle_edges,
     find_cycle_vertices,
     find_start_component_id,
-    has_cycle,
     process_flow,
     should_continue,
     sort_up_to_vertex,
@@ -36,6 +36,7 @@ from langflow.graph.vertex.base import Vertex, VertexStates
 from langflow.graph.vertex.schema import NodeData, NodeTypeEnum
 from langflow.graph.vertex.types import ComponentVertex, InterfaceVertex, StateVertex
 from langflow.logging.logger import LogConfig, configure
+from langflow.schema.dotdict import dotdict
 from langflow.schema.schema import INPUT_FIELD_NAME, InputType
 from langflow.services.cache.utils import CacheMiss
 from langflow.services.deps import get_chat_service, get_tracing_service
@@ -63,6 +64,7 @@ class Graph:
         description: str | None = None,
         user_id: str | None = None,
         log_config: LogConfig | None = None,
+        context: dict[str, Any] | None = None,
     ) -> None:
         """Initializes a new instance of the Graph class.

@@ -74,9 +76,11 @@ class Graph:
             description: The graph description.
             user_id: The user ID.
             log_config: The log configuration.
+            context: Additional context for the graph. Defaults to None.
         """
         if log_config:
             configure(**log_config)

         self._start = start
+        self._state_model = None
         self._end = end
@@ -107,6 +111,7 @@ class Graph:
         self.state_manager = GraphStateManager()
         self._vertices: list[NodeData] = []
         self._edges: list[EdgeData] = []
+        self.top_level_vertices: list[str] = []
         self.vertex_map: dict[str, Vertex] = {}
         self.predecessor_map: dict[str, list[str]] = defaultdict(list)
@@ -123,6 +128,11 @@ class Graph:
         self._call_order: list[str] = []
         self._snapshots: list[dict[str, Any]] = []
         self._end_trace_tasks: set[asyncio.Task] = set()

+        if context and not isinstance(context, dict):
+            msg = "Context must be a dictionary"
+            raise TypeError(msg)
+        self._context = dotdict(context or {})
         try:
             self.tracing_service: TracingService | None = get_tracing_service()
         except Exception:  # noqa: BLE001
@@ -135,6 +145,21 @@ class Graph:
             msg = "You must provide both input and output components"
             raise ValueError(msg)

+    @property
+    def context(self) -> dotdict:
+        if isinstance(self._context, dotdict):
+            return self._context
+        return dotdict(self._context)
+
+    @context.setter
+    def context(self, value: dict[str, Any]):
+        if not isinstance(value, dict):
+            msg = "Context must be a dictionary"
+            raise TypeError(msg)
+        if isinstance(value, dict):
+            value = dotdict(value)
+        self._context = value
+
     @property
     def session_id(self):
         return self._session_id
@@ -217,6 +242,8 @@ class Graph:
         for vertex in self._vertices:
             if vertex_id := vertex.get("id"):
                 self.top_level_vertices.append(vertex_id)
+                if vertex_id in self.cycle_vertices:
+                    self.run_manager.add_to_cycle_vertices(vertex_id)
         self._graph_data = process_flow(self.raw_graph_data)

         self._vertices = self._graph_data["nodes"]
@@ -360,26 +387,81 @@ class Graph:
         config: StartConfigDict | None = None,
         event_manager: EventManager | None = None,
     ) -> Generator:
+        """Starts the graph execution synchronously by creating a new event loop in a separate thread.
+
+        Args:
+            inputs: Optional list of input dictionaries
+            max_iterations: Optional maximum number of iterations
+            config: Optional configuration dictionary
+            event_manager: Optional event manager
+
+        Returns:
+            Generator yielding results from graph execution
+        """
         if self.is_cyclic and max_iterations is None:
             msg = "You must specify a max_iterations if the graph is cyclic"
             raise ValueError(msg)

         if config is not None:
             self.__apply_config(config)
-        # ! Change this ASAP
-        nest_asyncio.apply()
-        loop = asyncio.get_event_loop()
-        async_gen = self.async_start(inputs, max_iterations, event_manager)
-        async_gen_task = asyncio.ensure_future(anext(async_gen))

-        while True:
-            try:
-                result = loop.run_until_complete(async_gen_task)
-                yield result
-                if isinstance(result, Finish):
-                    return
-                async_gen_task = asyncio.ensure_future(anext(async_gen))
-            except StopAsyncIteration:
-                break
+        # Create a queue for passing results and errors between threads
+        result_queue: queue.Queue[VertexBuildResult | Exception | None] = queue.Queue()
+
+        # Function to run async code in separate thread
+        def run_async_code():
+            # Create new event loop for this thread
+            loop = asyncio.new_event_loop()
+            asyncio.set_event_loop(loop)
+
+            try:
+                # Run the async generator
+                async_gen = self.async_start(inputs, max_iterations, event_manager)
+
+                while True:
+                    try:
+                        # Get next result from async generator
+                        result = loop.run_until_complete(anext(async_gen))
+                        result_queue.put(result)
+
+                        if isinstance(result, Finish):
+                            break
+
+                    except StopAsyncIteration:
+                        break
+                    except ValueError as e:
+                        # Put the exception in the queue
+                        result_queue.put(e)
+                        break
+
+            finally:
+                # Ensure all pending tasks are completed
+                pending = asyncio.all_tasks(loop)
+                if pending:
+                    # Create a future to gather all pending tasks
+                    cleanup_future = asyncio.gather(*pending, return_exceptions=True)
+                    loop.run_until_complete(cleanup_future)
+
+                # Close the loop
+                loop.close()
+                # Signal completion
+                result_queue.put(None)
+
+        # Start thread for async execution
+        thread = threading.Thread(target=run_async_code)
+        thread.start()
+
+        # Yield results from queue
+        while True:
+            result = result_queue.get()
+            if result is None:
+                break
+            if isinstance(result, Exception):
+                raise result
+            yield result
+
+        # Wait for thread to complete
+        thread.join()

     def _add_edge(self, edge: EdgeData) -> None:
         self.add_edge(edge)
@@ -533,12 +615,7 @@ class Graph:
             bool: True if the graph has any cycles, False otherwise.
         """
         if self._is_cyclic is None:
-            vertices = [vertex.id for vertex in self.vertices]
-            try:
-                edges = [(e["data"]["sourceHandle"]["id"], e["data"]["targetHandle"]["id"]) for e in self._edges]
-            except KeyError:
-                edges = [(e["source"], e["target"]) for e in self._edges]
-            self._is_cyclic = has_cycle(vertices, edges)
+            self._is_cyclic = bool(self.cycle_vertices)
         return self._is_cyclic

     @property
@@ -1136,6 +1213,9 @@ class Graph:
         self._build_vertex_params()
         self._instantiate_components_in_vertices()
         self._set_cache_to_vertices_in_cycle()
+        for vertex in self.vertices:
+            if vertex.id in self.cycle_vertices:
+                self.run_manager.add_to_cycle_vertices(vertex.id)

     def _get_edges_as_list_of_tuples(self) -> list[tuple[str, str]]:
         """Returns the edges of the graph as a list of tuples."""
@@ -1455,7 +1535,7 @@ class Graph:
             else:
                 next_runnable_vertices.add(v_id)

-        return list(next_runnable_vertices)
+        return sorted(next_runnable_vertices)

     async def get_next_runnable_vertices(self, lock: asyncio.Lock, vertex: Vertex, *, cache: bool = True) -> list[str]:
         v_id = vertex.id
@@ -1717,6 +1797,8 @@ class Graph:
         for vertex_id in first_layer:
             self.run_manager.add_to_vertices_being_run(vertex_id)
+            if vertex_id in self.cycle_vertices:
+                self.run_manager.add_to_cycle_vertices(vertex_id)
         self._first_layer = sorted(first_layer)
         self._run_queue = deque(self._first_layer)
         self._prepared = True
@@ -1993,7 +2075,7 @@ class Graph:
         for successor_id in self.run_manager.run_map.get(vertex_id, []):
             runnable_vertices.extend(self.find_runnable_predecessors_for_successor(successor_id))

-        return runnable_vertices
+        return sorted(runnable_vertices)

     def find_runnable_predecessors_for_successor(self, vertex_id: str) -> list[str]:
         runnable_vertices = []
@@ -2,11 +2,12 @@ from collections import defaultdict


 class RunnableVerticesManager:
-    def __init__(self) -> None:
+    def __init__(self):
         self.run_map: dict[str, list[str]] = defaultdict(list)  # Tracks successors of each vertex
         self.run_predecessors: dict[str, set[str]] = defaultdict(set)  # Tracks predecessors for each vertex
         self.vertices_to_run: set[str] = set()  # Set of vertices that are ready to run
         self.vertices_being_run: set[str] = set()  # Set of vertices that are currently running
+        self.cycle_vertices: set[str] = set()  # Set of vertices that are in a cycle

     def to_dict(self) -> dict:
         return {
@@ -55,7 +56,7 @@ class RunnableVerticesManager:
             return False
         if vertex_id not in self.vertices_to_run:
             return False
-        return self.are_all_predecessors_fulfilled(vertex_id)
+        return self.are_all_predecessors_fulfilled(vertex_id) or vertex_id in self.cycle_vertices

     def are_all_predecessors_fulfilled(self, vertex_id: str) -> bool:
         return not any(self.run_predecessors.get(vertex_id, []))
@@ -89,3 +90,6 @@ class RunnableVerticesManager:

     def add_to_vertices_being_run(self, v_id) -> None:
         self.vertices_being_run.add(v_id)
+
+    def add_to_cycle_vertices(self, v_id):
+        self.cycle_vertices.add(v_id)
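The readiness change above relaxes the scheduling rule for cyclic graphs: a vertex becomes runnable when all its predecessors are fulfilled, or when it sits inside a cycle (otherwise vertices in a cycle would wait on each other forever). A minimal stdlib-only sketch of that rule (hypothetical `ReadinessTracker`, not the Langflow manager):

```python
from collections import defaultdict


class ReadinessTracker:
    """Sketch of the cycle-aware readiness rule: runnable when all
    predecessors are fulfilled OR the vertex belongs to a known cycle."""

    def __init__(self) -> None:
        self.run_predecessors: dict[str, set[str]] = defaultdict(set)
        self.cycle_vertices: set[str] = set()

    def add_edge(self, source: str, target: str) -> None:
        self.run_predecessors[target].add(source)

    def mark_done(self, vertex_id: str) -> None:
        # A finished vertex no longer blocks any successor.
        for preds in self.run_predecessors.values():
            preds.discard(vertex_id)

    def is_runnable(self, vertex_id: str) -> bool:
        fulfilled = not self.run_predecessors.get(vertex_id)
        return fulfilled or vertex_id in self.cycle_vertices


tracker = ReadinessTracker()
tracker.add_edge("a", "b")
tracker.add_edge("b", "a")  # a <-> b form a cycle
tracker.cycle_vertices.update({"a", "b"})
tracker.add_edge("b", "c")  # c depends on b but is outside the cycle
```

Without the cycle exemption, neither "a" nor "b" could ever start, since each permanently waits on the other.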
@@ -871,4 +871,4 @@ class Vertex:
         if not self.custom_component or not self.custom_component.outputs:
             return
         # Apply the function to each output
-        [func(output) for output in self.custom_component.outputs]
+        [func(output) for output in self.custom_component._outputs_map.values()]
@@ -7,7 +7,6 @@ from datetime import datetime, timezone
 from typing import TYPE_CHECKING, Any

 from loguru import logger
-from typing_extensions import override

 from langflow.schema.data import Data
 from langflow.services.tracing.base import BaseTracer
@@ -63,17 +62,16 @@ class LangSmithTracer(BaseTracer):
         os.environ["LANGCHAIN_TRACING_V2"] = "true"
         return True

-    @override
     def add_trace(
         self,
-        trace_id: str,
+        trace_id: str,  # noqa: ARG002
         trace_name: str,
         trace_type: str,
         inputs: dict[str, Any],
         metadata: dict[str, Any] | None = None,
-        vertex: Vertex | None = None,
+        vertex: Vertex | None = None,  # noqa: ARG002
     ) -> None:
-        if not self._ready:
+        if not self._ready or not self._run_tree:
             return
         processed_inputs = {}
         if inputs:
@@ -117,16 +115,15 @@ class LangSmithTracer(BaseTracer):
             value = str(value)
         return value

-    @override
     def end_trace(
         self,
-        trace_id: str,
+        trace_id: str,  # noqa: ARG002
         trace_name: str,
         outputs: dict[str, Any] | None = None,
         error: Exception | None = None,
         logs: Sequence[Log | dict] = (),
-    ) -> None:
-        if not self._ready:
+    ):
+        if not self._ready or trace_name not in self._children:
             return
         child = self._children[trace_name]
         raw_outputs = {}
@@ -159,7 +156,7 @@ class LangSmithTracer(BaseTracer):
         error: Exception | None = None,
         metadata: dict[str, Any] | None = None,
     ) -> None:
-        if not self._ready:
+        if not self._ready or not self._run_tree:
             return
         self._run_tree.add_metadata({"inputs": inputs})
         if metadata:
@ -2,6 +2,7 @@ import os
|
|||
|
||||
import pytest
|
||||
from langflow.components.inputs import ChatInput
|
||||
from langflow.components.inputs.text import TextInputComponent
|
||||
from langflow.components.models import OpenAIModelComponent
|
||||
from langflow.components.outputs import ChatOutput, TextOutputComponent
|
||||
from langflow.components.prompts import PromptComponent
|
||||
|
|
@ -31,7 +32,7 @@ class Concatenate(Component):
|
|||
@pytest.mark.skip(reason="Temporarily disabled")
|
||||
def test_cycle_in_graph():
|
||||
chat_input = ChatInput(_id="chat_input")
|
||||
router = ConditionalRouterComponent(_id="router")
|
||||
router = ConditionalRouterComponent(_id="router", default_route="true_result")
|
||||
chat_input.set(input_value=router.false_response)
|
||||
concat_component = Concatenate(_id="concatenate")
|
||||
concat_component.set(text=chat_input.message_response)
|
||||
|
|
@ -59,7 +60,6 @@ def test_cycle_in_graph():
|
|||
snapshots.append(graph._snapshot())
|
||||
results.append(result)
|
||||
results_ids = [result.vertex.id for result in results if hasattr(result, "vertex")]
|
||||
assert results_ids[-2:] == ["text_output", "chat_output"]
|
||||
assert len(results_ids) > len(graph.vertices), snapshots
|
||||
# Check that chat_output and text_output are the last vertices in the results
|
||||
assert results_ids == [
|
||||
|
|
@@ -127,7 +127,9 @@ def test_that_outputs_cache_is_set_to_false_in_cycle():
 
     graph = Graph(chat_input, chat_output)
     cycle_vertices = find_cycle_vertices(graph._get_edges_as_list_of_tuples())
-    cycle_outputs_lists = [graph.vertex_map[vertex_id].custom_component.outputs for vertex_id in cycle_vertices]
+    cycle_outputs_lists = [
+        graph.vertex_map[vertex_id].custom_component._outputs_map.values() for vertex_id in cycle_vertices
+    ]
     cycle_outputs = [output for outputs in cycle_outputs_lists for output in outputs]
     for output in cycle_outputs:
         assert output.cache is False

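The test above leans on `find_cycle_vertices` to identify which vertices sit on a cycle so their output caching can be disabled. A minimal sketch of that idea (not Langflow's actual implementation) is to compute strongly connected components over the edge list: a vertex is cyclic exactly when its component has more than one member, or it carries a self-loop.

```python
from collections import defaultdict


def find_cycle_vertices(edges: list[tuple[str, str]]) -> set[str]:
    """Return every vertex that lies on some cycle (Tarjan's SCC algorithm)."""
    graph: dict[str, list[str]] = defaultdict(list)
    vertices: set[str] = set()
    for src, dst in edges:
        graph[src].append(dst)
        vertices.update((src, dst))

    index: dict[str, int] = {}
    lowlink: dict[str, int] = {}
    on_stack: set[str] = set()
    stack: list[str] = []
    counter = 0
    cyclic: set[str] = set()

    def strongconnect(v: str) -> None:
        nonlocal counter
        index[v] = lowlink[v] = counter
        counter += 1
        stack.append(v)
        on_stack.add(v)
        for w in graph[v]:
            if w not in index:
                strongconnect(w)
                lowlink[v] = min(lowlink[v], lowlink[w])
            elif w in on_stack:
                lowlink[v] = min(lowlink[v], index[w])
        if lowlink[v] == index[v]:
            # v is the root of a strongly connected component.
            component = []
            while True:
                w = stack.pop()
                on_stack.discard(w)
                component.append(w)
                if w == v:
                    break
            # Multi-vertex SCCs, and single vertices with a self-loop, are cyclic.
            if len(component) > 1 or v in graph[v]:
                cyclic.update(component)

    for v in vertices:
        if v not in index:
            strongconnect(v)
    return cyclic
```

For the edge list `[("a", "b"), ("b", "c"), ("c", "a"), ("c", "d")]` this returns `{"a", "b", "c"}`: the dangling `"d"` is reachable from the cycle but not on it, so its outputs would keep caching enabled.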
@@ -206,3 +208,119 @@ def test_updated_graph_with_prompts():
     # Extract the vertex IDs for analysis
     results_ids = [result.vertex.id for result in results if hasattr(result, "vertex")]
     assert "chat_output_1" in results_ids, f"Expected outputs not in results: {results_ids}"
+
+
+@pytest.mark.api_key_required
+def test_updated_graph_with_max_iterations():
+    # Chat input initialization
+    chat_input = ChatInput(_id="chat_input").set(input_value="bacon")
+
+    # First prompt: guessing game with hints
+    prompt_component_1 = PromptComponent(_id="prompt_component_1").set(
+        template="Try to guess a word. I will give you hints if you get it wrong.\n"
+        "Hint: {hint}\n"
+        "Last try: {last_try}\n"
+        "Answer:",
+    )
+
+    # First OpenAI LLM component (processes the guessing prompt)
+    openai_component_1 = OpenAIModelComponent(_id="openai_1").set(
+        input_value=prompt_component_1.build_prompt, api_key=os.getenv("OPENAI_API_KEY")
+    )
+
+    # Conditional router based on the model's response
+    router = ConditionalRouterComponent(_id="router").set(
+        input_text=openai_component_1.text_response,
+        match_text=chat_input.message_response,
+        operator="contains",
+        message=openai_component_1.text_response,
+    )
+
+    # Second prompt: after a wrong guess, produce a new hint
+    prompt_component_2 = PromptComponent(_id="prompt_component_2")
+    prompt_component_2.set(
+        template="Given the following word and the following last try. Give the guesser a new hint.\n"
+        "Last try: {last_try}\n"
+        "Word: {word}\n"
+        "Hint:",
+        word=chat_input.message_response,
+        last_try=router.false_response,
+    )
+
+    # Second OpenAI component (handles the router's false route)
+    openai_component_2 = OpenAIModelComponent(_id="openai_2")
+    openai_component_2.set(input_value=prompt_component_2.build_prompt, api_key=os.getenv("OPENAI_API_KEY"))
+
+    prompt_component_1.set(hint=openai_component_2.text_response, last_try=router.false_response)
+
+    # Chat output for the final OpenAI response
+    chat_output_1 = ChatOutput(_id="chat_output_1")
+    chat_output_1.set(input_value=router.true_response)
+
+    # Build the graph without concatenate
+    graph = Graph(chat_input, chat_output_1)
+
+    # Assertions for graph cyclicity and correctness
+    assert graph.is_cyclic is True, "Graph should contain cycles."
+
+    # Run and validate the execution of the graph
+    results = []
+    max_iterations = 20
+    snapshots = [graph.get_snapshot()]
+
+    for result in graph.start(max_iterations=max_iterations, config={"output": {"cache": False}}):
+        snapshots.append(graph.get_snapshot())
+        results.append(result)
+
+    assert len(snapshots) > 2, "Graph should have more than one snapshot"
+    # Extract the vertex IDs for analysis
+    results_ids = [result.vertex.id for result in results if hasattr(result, "vertex")]
+    assert "chat_output_1" in results_ids, f"Expected outputs not in results: {results_ids}"
+
+
+def test_conditional_router_max_iterations():
+    # Text input initialization
+    text_input = TextInputComponent(_id="text_input")
+
+    # Conditional router setup with a condition that will never match
+    router = ConditionalRouterComponent(_id="router").set(
+        input_text=text_input.text_response,
+        match_text="bacon",
+        operator="equals",
+        message="This message should not be routed to true_result",
+        max_iterations=5,
+        default_route="true_result",
+    )
+
+    # Feed the false route back into the text input to form the cycle
+    text_input.set(input_value=router.false_response)
+
+    # Chat output wired to the router's default (true) route
+    chat_output_false = ChatOutput(_id="chat_output_false")
+    chat_output_false.set(input_value=router.true_response)
+
+    # Build the graph
+    graph = Graph(text_input, chat_output_false)
+
+    # Assertions for graph cyclicity and correctness
+    assert graph.is_cyclic is True, "Graph should contain cycles."
+
+    # Run and validate the execution of the graph
+    results = []
+    snapshots = [graph.get_snapshot()]
+    previous_iteration = graph.context.get("router_iteration", 0)
+    for result in graph.start(max_iterations=20, config={"output": {"cache": False}}):
+        snapshots.append(graph.get_snapshot())
+        results.append(result)
+        if hasattr(result, "vertex") and result.vertex.id == "router":
+            current_iteration = graph.context.get("router_iteration", 0)
+            assert current_iteration == previous_iteration + 1, "Iteration should increment by 1"
+            previous_iteration = current_iteration
+
+    # Check that the max_iterations logic is working
+    router_id = router._id.lower()
+    assert graph.context.get(f"{router_id}_iteration", 0) == 5, "Router should stop after max_iterations"
+
+    # Extract the vertex IDs for analysis
+    results_ids = [result.vertex.id for result in results if hasattr(result, "vertex")]
+    assert "chat_output_false" in results_ids, f"Expected outputs not in results: {results_ids}"
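The `{router_id}_iteration` counter asserted in the test above can be sketched as a small, hypothetical router that counts its activations in a shared context dict and falls back to its default route once `max_iterations` is reached. This is a simplification of what the test exercises, not Langflow's actual `ConditionalRouterComponent`; the class and method names here are illustrative.

```python
class MiniRouter:
    """Hypothetical conditional router with an iteration cap."""

    def __init__(self, router_id: str, match_text: str,
                 max_iterations: int, default_route: str) -> None:
        self._id = router_id
        self.match_text = match_text
        self.max_iterations = max_iterations
        self.default_route = default_route

    def route(self, ctx: dict, input_text: str) -> str:
        # Count activations in the shared graph context under
        # "<router_id>_iteration", mirroring the key the test checks.
        key = f"{self._id.lower()}_iteration"
        ctx[key] = ctx.get(key, 0) + 1
        if ctx[key] >= self.max_iterations:
            # Cap reached: break the loop by taking the default route.
            return self.default_route
        return "true_result" if input_text == self.match_text else "false_result"
```

With `match_text="bacon"` and input that never matches, the first four activations take the false route (continuing the cycle) and the fifth is forced onto `default_route`, which is exactly the termination behavior `test_conditional_router_max_iterations` verifies via `graph.context`.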