refactor: update inputs in ToolCallingAgentComponent and add astream_events setup (#4240)
* Refactor `ToolCallingAgentComponent` to enhance input handling and update descriptions * Add async event processing for agent execution in agent.py - Introduced `process_agent_events` method to handle asynchronous event processing for agent execution. - Added individual async handlers for different event types: `handle_chain_start`, `handle_chain_end`, `handle_chat_model_stream`, `handle_tool_start`, and `handle_tool_end`. - Updated agent execution to use `astream_events` for improved event handling. - Enhanced logging for better traceability of agent and tool execution events. * Refactor agent component to use MessageTextInput and improve prompt handling * Refactor tool calling agent component: update prompt handling and input naming conventions * Remove input key validation from create_agent_runnable method in tool_calling.py * Enhance agent logging by removing redundant checks and adding scratchpad logging * Refactor `process_agent_events` to a standalone function with enhanced logging capabilities * Refactor log callback messages to improve emoji placement consistency * Use `jsonable_encoder` to encode agent scratchpad messages before logging * style: simplify double string on the same line --------- Co-authored-by: Gabriel Luiz Freitas Almeida <gabriel@langflow.org> Co-authored-by: italojohnny <italojohnnydosanjos@gmail.com>
This commit is contained in:
parent
b39ee0ae2d
commit
6c3c32851b
2 changed files with 105 additions and 24 deletions
|
|
@ -1,6 +1,8 @@
|
|||
from abc import abstractmethod
|
||||
from typing import TYPE_CHECKING, cast
|
||||
from collections.abc import AsyncIterator
|
||||
from typing import TYPE_CHECKING, Any, cast
|
||||
|
||||
from fastapi.encoders import jsonable_encoder
|
||||
from langchain.agents import AgentExecutor, BaseMultiActionAgent, BaseSingleActionAgent
|
||||
from langchain.agents.agent import RunnableAgent
|
||||
from langchain_core.runnables import Runnable
|
||||
|
|
@ -12,6 +14,7 @@ from langflow.field_typing import Text
|
|||
from langflow.inputs.inputs import InputTypes
|
||||
from langflow.io import BoolInput, HandleInput, IntInput, MessageTextInput
|
||||
from langflow.schema import Data
|
||||
from langflow.schema.log import LogFunctionType
|
||||
from langflow.schema.message import Message
|
||||
from langflow.template import Output
|
||||
from langflow.utils.constants import MESSAGE_SENDER_AI
|
||||
|
|
@ -109,7 +112,26 @@ class LCAgentComponent(Component):
|
|||
msg = "Output key not found in result. Tried 'output'."
|
||||
raise ValueError(msg)
|
||||
|
||||
return cast(str, result.get("output"))
|
||||
return cast(str, result)
|
||||
|
||||
async def handle_chain_start(self, event: dict[str, Any]) -> None:
|
||||
if event["name"] == "Agent":
|
||||
self.log(f"Starting agent: {event['name']} with input: {event['data'].get('input')}")
|
||||
|
||||
async def handle_chain_end(self, event: dict[str, Any]) -> None:
|
||||
if event["name"] == "Agent":
|
||||
self.log(f"Done agent: {event['name']} with output: {event['data'].get('output', {}).get('output', '')}")
|
||||
|
||||
async def handle_tool_start(self, event: dict[str, Any]) -> None:
|
||||
self.log(f"Starting tool: {event['name']} with inputs: {event['data'].get('input')}")
|
||||
|
||||
async def handle_tool_end(self, event: dict[str, Any]) -> None:
|
||||
self.log(f"Done tool: {event['name']}")
|
||||
self.log(f"Tool output was: {event['data'].get('output')}")
|
||||
|
||||
@abstractmethod
|
||||
def create_agent_runnable(self) -> Runnable:
|
||||
"""Create the agent."""
|
||||
|
||||
|
||||
class LCToolsAgentComponent(LCAgentComponent):
|
||||
|
|
@ -146,16 +168,76 @@ class LCToolsAgentComponent(LCAgentComponent):
|
|||
if self.chat_history:
|
||||
input_dict["chat_history"] = data_to_messages(self.chat_history)
|
||||
|
||||
result = runnable.invoke(
|
||||
input_dict, config={"callbacks": [AgentAsyncHandler(self.log), *self.get_langchain_callbacks()]}
|
||||
result = await process_agent_events(
|
||||
runnable.astream_events(
|
||||
input_dict,
|
||||
config={"callbacks": [AgentAsyncHandler(self.log), *self.get_langchain_callbacks()]},
|
||||
version="v2",
|
||||
),
|
||||
self.log,
|
||||
)
|
||||
self.status = result
|
||||
if "output" not in result:
|
||||
msg = "Output key not found in result. Tried 'output'."
|
||||
raise ValueError(msg)
|
||||
|
||||
return cast(str, result.get("output"))
|
||||
self.status = result
|
||||
return cast(str, result)
|
||||
|
||||
@abstractmethod
|
||||
def create_agent_runnable(self) -> Runnable:
|
||||
"""Create the agent."""
|
||||
|
||||
|
||||
# Add this function near the top of the file, after the imports
|
||||
|
||||
|
||||
async def process_agent_events(agent_executor: AsyncIterator[dict[str, Any]], log_callback: LogFunctionType) -> str:
|
||||
"""Process agent events and return the final output.
|
||||
|
||||
Args:
|
||||
agent_executor: An async iterator of agent events
|
||||
log_callback: A callable function for logging messages
|
||||
|
||||
Returns:
|
||||
str: The final output from the agent
|
||||
"""
|
||||
final_output = ""
|
||||
async for event in agent_executor:
|
||||
match event["event"]:
|
||||
case "on_chain_start":
|
||||
if event["data"].get("input"):
|
||||
log_callback(f"Agent initiated with input: {event['data'].get('input')}", name="🚀 Agent Start")
|
||||
|
||||
case "on_chain_end":
|
||||
data_output = event["data"].get("output", {})
|
||||
if data_output and "output" in data_output:
|
||||
final_output = data_output["output"]
|
||||
log_callback(f"{final_output}", name="✅ Agent End")
|
||||
elif data_output and "agent_scratchpad" in data_output and data_output["agent_scratchpad"]:
|
||||
agent_scratchpad_messages = data_output["agent_scratchpad"]
|
||||
json_encoded_messages = jsonable_encoder(agent_scratchpad_messages)
|
||||
log_callback(json_encoded_messages, name="🔍 Agent Scratchpad")
|
||||
|
||||
case "on_tool_start":
|
||||
log_callback(
|
||||
f"Initiating tool: '{event['name']}' with inputs: {event['data'].get('input')}",
|
||||
name="🔧 Tool Start",
|
||||
)
|
||||
|
||||
case "on_tool_end":
|
||||
log_callback(f"Tool '{event['name']}' execution completed", name="🏁 Tool End")
|
||||
log_callback(f"{event['data'].get('output')}", name="📊 Tool Output")
|
||||
|
||||
case "on_tool_error":
|
||||
tool_name = event.get("name", "Unknown tool")
|
||||
error_message = event["data"].get("error", "Unknown error")
|
||||
log_callback(f"Tool '{tool_name}' failed with error: {error_message}", name="❌ Tool Error")
|
||||
|
||||
if "stack_trace" in event["data"]:
|
||||
log_callback(f"{event['data']['stack_trace']}", name="🔍 Tool Error")
|
||||
|
||||
if "recovery_attempt" in event["data"]:
|
||||
log_callback(f"{event['data']['recovery_attempt']}", name="🔄 Tool Error")
|
||||
|
||||
case _:
|
||||
# Handle any other event types or ignore them
|
||||
pass
|
||||
|
||||
return final_output
|
||||
|
|
|
|||
|
|
@ -1,15 +1,15 @@
|
|||
from langchain.agents import create_tool_calling_agent
|
||||
from langchain_core.prompts import ChatPromptTemplate, HumanMessagePromptTemplate, PromptTemplate
|
||||
from langchain_core.prompts import ChatPromptTemplate
|
||||
|
||||
from langflow.base.agents.agent import LCToolsAgentComponent
|
||||
from langflow.inputs import MultilineInput
|
||||
from langflow.inputs import MessageTextInput
|
||||
from langflow.inputs.inputs import DataInput, HandleInput
|
||||
from langflow.schema import Data
|
||||
|
||||
|
||||
class ToolCallingAgentComponent(LCToolsAgentComponent):
|
||||
display_name: str = "Tool Calling Agent"
|
||||
description: str = "Agent that uses tools"
|
||||
description: str = "An agent designed to utilize various tools seamlessly within workflows."
|
||||
icon = "LangChain"
|
||||
beta = True
|
||||
name = "ToolCallingAgent"
|
||||
|
|
@ -17,34 +17,33 @@ class ToolCallingAgentComponent(LCToolsAgentComponent):
|
|||
inputs = [
|
||||
*LCToolsAgentComponent._base_inputs,
|
||||
HandleInput(name="llm", display_name="Language Model", input_types=["LanguageModel"], required=True),
|
||||
MultilineInput(
|
||||
MessageTextInput(
|
||||
name="system_prompt",
|
||||
display_name="System Prompt",
|
||||
info="System prompt for the agent.",
|
||||
value="You are a helpful assistant",
|
||||
info="Initial instructions and context provided to guide the agent's behavior.",
|
||||
value="You are a helpful assistant that can use tools to answer questions and perform tasks.",
|
||||
),
|
||||
MultilineInput(
|
||||
name="user_prompt", display_name="Prompt", info="This prompt must contain 'input' key.", value="{input}"
|
||||
MessageTextInput(
|
||||
name="input_value",
|
||||
display_name="Input",
|
||||
info="The input provided by the user for the agent to process.",
|
||||
),
|
||||
DataInput(name="chat_history", display_name="Chat History", is_list=True, advanced=True),
|
||||
DataInput(name="chat_history", display_name="Chat Memory", is_list=True, advanced=True),
|
||||
]
|
||||
|
||||
def get_chat_history_data(self) -> list[Data] | None:
|
||||
return self.chat_history
|
||||
|
||||
def create_agent_runnable(self):
|
||||
if "input" not in self.user_prompt:
|
||||
msg = "Prompt must contain 'input' key."
|
||||
raise ValueError(msg)
|
||||
messages = [
|
||||
("system", self.system_prompt),
|
||||
("placeholder", "{chat_history}"),
|
||||
HumanMessagePromptTemplate(prompt=PromptTemplate(input_variables=["input"], template=self.user_prompt)),
|
||||
("human", self.input_value),
|
||||
("placeholder", "{agent_scratchpad}"),
|
||||
]
|
||||
prompt = ChatPromptTemplate.from_messages(messages)
|
||||
try:
|
||||
return create_tool_calling_agent(self.llm, self.tools, prompt)
|
||||
return create_tool_calling_agent(self.llm, self.tools or [], prompt)
|
||||
except NotImplementedError as e:
|
||||
message = f"{self.display_name} does not support tool calling." "Please try using a compatible model."
|
||||
message = f"{self.display_name} does not support tool calling. Please try using a compatible model."
|
||||
raise NotImplementedError(message) from e
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue