fix: Update ToolCallingAgentComponent

This commit is contained in:
Rodrigo 2024-06-23 00:51:13 -03:00
commit 2f7d24fb11

View file

@ -1,38 +1,40 @@
from typing import List, cast
from typing import List
from langchain.agents import AgentExecutor, BaseSingleActionAgent
from langchain.agents.tool_calling_agent.base import create_tool_calling_agent
from langchain_core.messages import BaseMessage, HumanMessage
from langchain_core.prompts import ChatPromptTemplate
from langflow.custom import Component
from langflow.io import BoolInput, HandleInput, Output, TextInput
from langflow.schema import Data
from langchain.agents import AgentExecutor
from langchain_core.messages import BaseMessage
from langflow.schema.message import Message
from langflow.custom import Component
from langflow.io import HandleInput, TextInput, BoolInput, Output
from langflow.schema import Data
class ToolCallingAgentComponent(Component):
display_name: str = "Tool Calling Agent"
description: str = "Agent that uses tools. Only models that are compatible with function calling are supported."
icon = "Agent"
description: str = (
"Agent that uses tools. Only models that are compatible with function calling are supported."
)
icon = "LangChain"
inputs = [
HandleInput(
name="llm",
display_name="LLM",
input_types=["LanguageModel"],
TextInput(
name="system_prompt",
display_name="System Prompt",
info="System prompt for the agent.",
value="You are a helpful assistant",
),
HandleInput(
name="tools",
display_name="Tools",
input_types=["Tool"],
is_list=True,
TextInput(
name="input_value",
display_name="Inputs",
info="Input text to pass to the agent.",
),
TextInput(
name="user_prompt",
display_name="Prompt",
info="This prompt must contain 'input' key.",
value="{input}",
advanced=True,
),
BoolInput(
name="handle_parsing_errors",
@ -47,10 +49,16 @@ class ToolCallingAgentComponent(Component):
input_types=["Data"],
info="Memory to use for the agent.",
),
TextInput(
name="input_value",
display_name="Inputs",
info="Input text to pass to the agent.",
HandleInput(
name="tools",
display_name="Tools",
input_types=["Tool"],
is_list=True,
),
HandleInput(
name="llm",
display_name="LLM",
input_types=["LanguageModel"],
),
]
@ -62,7 +70,7 @@ class ToolCallingAgentComponent(Component):
if "input" not in self.user_prompt:
raise ValueError("Prompt must contain 'input' key.")
messages = [
("system", "You are a helpful assistant"),
("system", self.system_prompt),
(
"placeholder",
"{chat_history}",
@ -74,7 +82,7 @@ class ToolCallingAgentComponent(Component):
agent = create_tool_calling_agent(self.llm, self.tools, prompt)
runnable = AgentExecutor.from_agent_and_tools(
agent=cast(BaseSingleActionAgent, agent),
agent=agent,
tools=self.tools,
verbose=True,
handle_parsing_errors=self.handle_parsing_errors,
@ -92,16 +100,9 @@ class ToolCallingAgentComponent(Component):
return Message(text=result_string)
def convert_chat_history(self, chat_history: List[Message | Data]) -> List[BaseMessage]:
def convert_chat_history(self, chat_history: List[Data]) -> List[Dict[str, str]]:
messages = []
for item in chat_history:
if isinstance(item, (Message, Data)):
messages.append(item.to_lc_message())
elif isinstance(item, str):
messages.append(HumanMessage(content=item))
elif isinstance(item, dict) and "sender" in item and "text" in item:
messages.append(HumanMessage(content=item["text"], sender=item["sender"]))
else:
raise ValueError(f"Invalid item ({type(item)}) in chat history: {item}")
role = "user" if item.sender == "User" else "assistant"
messages.append({"role": role, "content": item.text})
return messages