From 33129c74669f67856e1a65729b200528d6ce69a6 Mon Sep 17 00:00:00 2001 From: Cezar Vasconcelos Date: Fri, 21 Jun 2024 01:54:04 +0000 Subject: [PATCH] refactor: adjust ToolCallingAgent to use new IO and ChatHistory, added helper functions --- .../components/agents/ToolCallingAgent.py | 120 ++++++++++++------ 1 file changed, 78 insertions(+), 42 deletions(-) diff --git a/src/backend/base/langflow/components/agents/ToolCallingAgent.py b/src/backend/base/langflow/components/agents/ToolCallingAgent.py index 3783607dc..a2c0746c8 100644 --- a/src/backend/base/langflow/components/agents/ToolCallingAgent.py +++ b/src/backend/base/langflow/components/agents/ToolCallingAgent.py @@ -2,63 +2,99 @@ from typing import List, Optional from langchain.agents.tool_calling_agent.base import create_tool_calling_agent from langchain_core.prompts import ChatPromptTemplate - -from langflow.base.agents.agent import LCAgentComponent +from langchain.agents import AgentExecutor +from langchain_core.messages import BaseMessage, HumanMessage, AIMessage +from langflow.schema.message import Message +from langflow.custom import Component +from langflow.io import HandleInput, TextInput, BoolInput, Output from langflow.field_typing import LanguageModel, Text, Tool from langflow.schema import Data -class ToolCallingAgentComponent(LCAgentComponent): +class ToolCallingAgentComponent(Component): display_name: str = "Tool Calling Agent" description: str = "Agent that uses tools. Only models that are compatible with function calling are supported." + icon = "Agent" - def build_config(self): - return { - "llm": {"display_name": "LLM"}, - "tools": {"display_name": "Tools"}, - "user_prompt": { - "display_name": "Prompt", - "multiline": True, - "info": "This prompt must contain 'input' key.", - }, - "handle_parsing_errors": { - "display_name": "Handle Parsing Errors", - "info": "If True, the agent will handle parsing errors. If False, the agent will raise an error.", - "advanced": True, - }, - "memory": { - "display_name": "Memory", - "info": "Memory to use for the agent.", - }, - "input_value": { - "display_name": "Inputs", - "info": "Input text to pass to the agent.", - }, - } + inputs = [ + HandleInput( + name="llm", + display_name="LLM", + input_types=["LanguageModel"], + ), + HandleInput( + name="tools", + display_name="Tools", + input_types=["Tool"], + is_list=True, + ), + TextInput( + name="user_prompt", + display_name="Prompt", + info="This prompt must contain 'input' key.", + value="{input}", + ), + BoolInput( + name="handle_parsing_errors", + display_name="Handle Parsing Errors", + info="If True, the agent will handle parsing errors. If False, the agent will raise an error.", + advanced=True, + value=True, + ), + HandleInput( + name="memory", + display_name="Memory", + input_types=["Data"], + info="Memory to use for the agent.", + ), + TextInput( + name="input_value", + display_name="Inputs", + info="Input text to pass to the agent.", + ), + ] - async def build( - self, - input_value: str, - llm: LanguageModel, - tools: List[Tool], - user_prompt: str = "{input}", - message_history: Optional[List[Data]] = None, - system_message: str = "You are a helpful assistant", - handle_parsing_errors: bool = True, - ) -> Text: - if "input" not in user_prompt: + outputs = [ + Output(display_name="Text", name="text_output", method="run_agent"), + ] + + async def run_agent(self) -> Message: + if "input" not in self.user_prompt: raise ValueError("Prompt must contain 'input' key.") messages = [ - ("system", system_message), + ("system", "You are a helpful assistant"), ( "placeholder", "{chat_history}", ), - ("human", user_prompt), + ("human", self.user_prompt), ("placeholder", "{agent_scratchpad}"), ] prompt = ChatPromptTemplate.from_messages(messages) - agent = create_tool_calling_agent(llm, tools, prompt) - result = await self.run_agent(agent, input_value, tools, message_history, handle_parsing_errors) + agent = create_tool_calling_agent(self.llm, self.tools, prompt) + + runnable = AgentExecutor.from_agent_and_tools( + agent=agent, + tools=self.tools, + verbose=True, + handle_parsing_errors=self.handle_parsing_errors, + ) + input_dict: dict[str, str | list[BaseMessage]] = {"input": self.input_value} + if hasattr(self, "memory") and self.memory: + input_dict["chat_history"] = self.convert_chat_history(self.memory) + result = await runnable.ainvoke(input_dict) self.status = result - return result + + if "output" not in result: + raise ValueError("Output key not found in result. Tried 'output'.") + + result_string = result["output"] + + return Message(text=result_string) + + def convert_chat_history(self, chat_history: List[Data]) -> List[Dict[str, str]]: + messages = [] + for item in chat_history: + role = "user" if item.sender == "User" else "assistant" + messages.append({"role": role, "content": item.text}) + return messages