ref: refactors the agents around and fixes a few bugs (#2847)

This commit is contained in:
Jordan Frazier 2024-07-22 11:06:37 -07:00 committed by GitHub
commit 9dadfd45b1
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
11 changed files with 76 additions and 52 deletions

View file

@@ -1,5 +1,5 @@
from abc import abstractmethod
from typing import List, Optional, Union, cast
from typing import List, Union, cast
from langchain.agents import AgentExecutor, BaseMultiActionAgent, BaseSingleActionAgent
from langchain.agents.agent import RunnableAgent
@@ -9,10 +9,9 @@ from langchain_core.runnables import Runnable
from langflow.base.agents.callback import AgentAsyncHandler
from langflow.base.agents.utils import data_to_messages
from langflow.custom import Component
from langflow.field_typing import Text, Tool
from langflow.field_typing import Text
from langflow.inputs.inputs import DataInput, InputTypes
from langflow.io import BoolInput, HandleInput, IntInput, MessageTextInput
from langflow.schema import Data
from langflow.schema.message import Message
from langflow.template import Output
@@ -39,6 +38,7 @@ class LCAgentComponent(Component):
value=15,
advanced=True,
),
DataInput(name="chat_history", display_name="Chat History", is_list=True, advanced=True),
]
outputs = [
@@ -46,15 +46,16 @@ class LCAgentComponent(Component):
Output(display_name="Response", name="response", method="message_response"),
]
@abstractmethod
def build_agent(self) -> AgentExecutor:
"""Create the agent."""
pass
async def message_response(self) -> Message:
"""Run the agent and return the response."""
agent = self.build_agent()
result = await self.run_agent(
agent=agent,
inputs=self.input_value,
tools=self.tools,
message_history=self.chat_history,
handle_parsing_errors=self.handle_parsing_errors,
)
result = await self.run_agent(agent=agent)
if isinstance(result, list):
result = "\n".join([result_dict["text"] for result_dict in result])
message = Message(text=result, sender="Machine")
@@ -87,27 +88,11 @@ class LCAgentComponent(Component):
}
return {**base, "agent_executor_kwargs": agent_kwargs}
async def run_agent(
self,
agent: Union[Runnable, BaseSingleActionAgent, BaseMultiActionAgent, AgentExecutor],
inputs: str,
tools: List[Tool],
message_history: Optional[List[Data]] = None,
handle_parsing_errors: bool = True,
) -> Text:
if isinstance(agent, AgentExecutor):
runnable = agent
else:
runnable = AgentExecutor.from_agent_and_tools(
agent=agent, # type: ignore
tools=tools,
verbose=True,
handle_parsing_errors=handle_parsing_errors,
)
input_dict: dict[str, str | list[BaseMessage]] = {"input": inputs}
if message_history:
input_dict["chat_history"] = data_to_messages(message_history)
result = await runnable.ainvoke(input_dict, config={"callbacks": [AgentAsyncHandler(self.log)]})
async def run_agent(self, agent: AgentExecutor) -> Text:
input_dict: dict[str, str | list[BaseMessage]] = {"input": self.input_value}
if self.chat_history:
input_dict["chat_history"] = data_to_messages(self.chat_history)
result = await agent.ainvoke(input_dict, config={"callbacks": [AgentAsyncHandler(self.log)]})
self.status = result
if "output" not in result:
raise ValueError("Output key not found in result. Tried 'output'.")
@@ -123,24 +108,41 @@ class LCToolsAgentComponent(LCAgentComponent):
input_types=["Tool", "BaseTool"],
is_list=True,
),
HandleInput(
name="llm",
display_name="Language Model",
input_types=["LanguageModel", "ToolEnabledLanguageModel"],
required=True,
),
DataInput(name="chat_history", display_name="Chat History", is_list=True),
]
def build_agent(self) -> AgentExecutor:
agent = self.creat_agent_runnable()
agent = self.create_agent_runnable()
return AgentExecutor.from_agent_and_tools(
agent=RunnableAgent(runnable=agent, input_keys_arg=["input"], return_keys_arg=["output"]),
tools=self.tools,
**self.get_agent_kwargs(flatten=True),
)
async def run_agent(
self,
agent: Union[Runnable, BaseSingleActionAgent, BaseMultiActionAgent, AgentExecutor],
) -> Text:
if isinstance(agent, AgentExecutor):
runnable = agent
else:
runnable = AgentExecutor.from_agent_and_tools(
agent=agent, # type: ignore
tools=self.tools,
handle_parsing_errors=self.handle_parsing_errors,
verbose=self.verbose,
max_iterations=self.max_iterations,
)
input_dict: dict[str, str | list[BaseMessage]] = {"input": self.input_value}
if self.chat_history:
input_dict["chat_history"] = data_to_messages(self.chat_history)
result = await runnable.ainvoke(input_dict, config={"callbacks": [AgentAsyncHandler(self.log)]})
self.status = result
if "output" not in result:
raise ValueError("Output key not found in result. Tried 'output'.")
return cast(str, result.get("output"))
@abstractmethod
def creat_agent_runnable(self) -> Runnable:
def create_agent_runnable(self) -> Runnable:
"""Create the agent."""
pass

View file

@@ -27,6 +27,7 @@ class BaseCrewComponent(Component):
name="function_calling_llm",
display_name="Function Calling LLM",
input_types=["LanguageModel"],
info="Turns the ReAct CrewAI agent into a function-calling agent",
required=False,
advanced=True,
),

View file

@@ -12,8 +12,8 @@ class CSVAgentComponent(LCAgentComponent):
name = "CSVAgent"
inputs = LCAgentComponent._base_inputs + [
FileInput(name="path", display_name="File Path", file_types=["csv"], required=True),
HandleInput(name="llm", display_name="Language Model", input_types=["LanguageModel"], required=True),
FileInput(name="path", display_name="File Path", file_types=["csv"], required=True),
DropdownInput(
name="agent_type",
display_name="Agent Type",

View file

@@ -9,6 +9,8 @@ class HierarchicalCrewComponent(BaseCrewComponent):
description: str = (
"Represents a group of agents, defining how they should collaborate and the tasks they should perform."
)
documentation: str = "https://docs.crewai.com/how-to/Hierarchical/"
icon = "CrewAI"
inputs = BaseCrewComponent._base_inputs + [
HandleInput(name="agents", display_name="Agents", input_types=["Agent"], is_list=True),

View file

@@ -16,8 +16,8 @@ class JsonAgentComponent(LCAgentComponent):
name = "JsonAgent"
inputs = LCAgentComponent._base_inputs + [
FileInput(name="path", display_name="File Path", file_types=["json", "yaml", "yml"], required=True),
HandleInput(name="llm", display_name="Language Model", input_types=["LanguageModel"], required=True),
FileInput(name="path", display_name="File Path", file_types=["json", "yaml", "yml"], required=True),
]
def build_agent(self) -> AgentExecutor:

View file

@@ -3,6 +3,7 @@ from langchain_core.prompts import ChatPromptTemplate, PromptTemplate, HumanMess
from langflow.base.agents.agent import LCToolsAgentComponent
from langflow.inputs import MultilineInput
from langflow.inputs.inputs import HandleInput
class OpenAIToolsAgentComponent(LCToolsAgentComponent):
@@ -13,6 +14,12 @@ class OpenAIToolsAgentComponent(LCToolsAgentComponent):
name = "OpenAIToolsAgent"
inputs = LCToolsAgentComponent._base_inputs + [
HandleInput(
name="llm",
display_name="Language Model",
input_types=["LanguageModel", "ToolEnabledLanguageModel"],
required=True,
),
MultilineInput(
name="system_prompt",
display_name="System Prompt",
@@ -24,7 +31,7 @@ class OpenAIToolsAgentComponent(LCToolsAgentComponent):
),
]
def creat_agent_runnable(self):
def create_agent_runnable(self):
if "input" not in self.user_prompt:
raise ValueError("Prompt must contain 'input' key.")
messages = [

View file

@@ -17,8 +17,8 @@ class OpenAPIAgentComponent(LCAgentComponent):
name = "OpenAPIAgent"
inputs = LCAgentComponent._base_inputs + [
FileInput(name="path", display_name="File Path", file_types=["json", "yaml", "yml"], required=True),
HandleInput(name="llm", display_name="Language Model", input_types=["LanguageModel"], required=True),
FileInput(name="path", display_name="File Path", file_types=["json", "yaml", "yml"], required=True),
BoolInput(name="allow_dangerous_requests", display_name="Allow Dangerous Requests", value=False, required=True),
]
@@ -37,6 +37,9 @@
)
agent_args = self.get_agent_kwargs()
# This is bit weird - generally other create_*_agent functions have max_iterations in the
# `agent_executor_kwargs`, but openai has this parameter passed directly.
agent_args["max_iterations"] = agent_args["agent_executor_kwargs"]["max_iterations"]
del agent_args["agent_executor_kwargs"]["max_iterations"]
return create_openapi_agent(llm=self.llm, toolkit=toolkit, **agent_args)

View file

@@ -13,8 +13,15 @@ class SQLAgentComponent(LCAgentComponent):
name = "SQLAgent"
inputs = LCAgentComponent._base_inputs + [
MessageTextInput(name="database_uri", display_name="Database URI", required=True),
HandleInput(name="llm", display_name="Language Model", input_types=["LanguageModel"], required=True),
MessageTextInput(name="database_uri", display_name="Database URI", required=True),
HandleInput(
name="extra_tools",
display_name="Extra Tools",
input_types=["Tool", "BaseTool"],
is_list=True,
advanced=True,
),
]
def build_agent(self) -> AgentExecutor:
@@ -23,4 +30,4 @@
agent_args = self.get_agent_kwargs()
agent_args["max_iterations"] = agent_args["agent_executor_kwargs"]["max_iterations"]
del agent_args["agent_executor_kwargs"]["max_iterations"]
return create_sql_agent(llm=self.llm, toolkit=toolkit, **agent_args)
return create_sql_agent(llm=self.llm, toolkit=toolkit, extra_tools=self.extra_tools or [], **agent_args)

View file

@@ -7,10 +7,8 @@ from langflow.schema.message import Message
class SequentialCrewComponent(BaseCrewComponent):
display_name: str = "Sequential Crew"
description: str = (
"Represents a group of agents, defining how they should collaborate and the tasks they should perform."
)
documentation: str = "https://docs.crewai.com/how-to/LLM-Connections/"
description: str = "Represents a group of agents with tasks that are executed sequentially."
documentation: str = "https://docs.crewai.com/how-to/Sequential/"
icon = "CrewAI"
inputs = BaseCrewComponent._base_inputs + [

View file

@@ -2,6 +2,7 @@ from langchain.agents import create_tool_calling_agent
from langchain_core.prompts import ChatPromptTemplate, PromptTemplate, HumanMessagePromptTemplate
from langflow.base.agents.agent import LCToolsAgentComponent
from langflow.inputs import MultilineInput
from langflow.inputs.inputs import HandleInput
class ToolCallingAgentComponent(LCToolsAgentComponent):
@@ -12,6 +13,7 @@ class ToolCallingAgentComponent(LCToolsAgentComponent):
name = "ToolCallingAgent"
inputs = LCToolsAgentComponent._base_inputs + [
HandleInput(name="llm", display_name="Language Model", input_types=["LanguageModel"], required=True),
MultilineInput(
name="system_prompt",
display_name="System Prompt",
@@ -23,7 +25,7 @@
),
]
def creat_agent_runnable(self):
def create_agent_runnable(self):
if "input" not in self.user_prompt:
raise ValueError("Prompt must contain 'input' key.")
messages = [

View file

@@ -3,6 +3,7 @@ from langchain_core.prompts import ChatPromptTemplate, PromptTemplate, HumanMess
from langflow.base.agents.agent import LCToolsAgentComponent
from langflow.inputs import MultilineInput
from langflow.inputs.inputs import HandleInput
class XMLAgentComponent(LCToolsAgentComponent):
@@ -13,6 +14,7 @@ class XMLAgentComponent(LCToolsAgentComponent):
name = "XMLAgent"
inputs = LCToolsAgentComponent._base_inputs + [
HandleInput(name="llm", display_name="Language Model", input_types=["LanguageModel"], required=True),
MultilineInput(
name="user_prompt",
display_name="Prompt",
@@ -44,7 +46,7 @@ Question: {input}
),
]
def creat_agent_runnable(self):
def create_agent_runnable(self):
messages = [
HumanMessagePromptTemplate(prompt=PromptTemplate(input_variables=["input"], template=self.user_prompt))
]