From 9dadfd45b16f2ae8e97575f67e7350b4720d6ca6 Mon Sep 17 00:00:00 2001 From: Jordan Frazier <122494242+jordanrfrazier@users.noreply.github.com> Date: Mon, 22 Jul 2024 11:06:37 -0700 Subject: [PATCH] ref: refactors the agents around and fixes a few bugs (#2847) --- .../base/langflow/base/agents/agent.py | 82 ++++++++++--------- .../base/langflow/base/agents/crewai/crew.py | 1 + .../langflow/components/agents/CSVAgent.py | 2 +- .../components/agents/HierarchicalCrew.py | 2 + .../langflow/components/agents/JsonAgent.py | 2 +- .../components/agents/OpenAIToolsAgent.py | 9 +- .../components/agents/OpenAPIAgent.py | 5 +- .../langflow/components/agents/SQLAgent.py | 11 ++- .../components/agents/SequentialCrew.py | 6 +- .../components/agents/ToolCallingAgent.py | 4 +- .../langflow/components/agents/XMLAgent.py | 4 +- 11 files changed, 76 insertions(+), 52 deletions(-) diff --git a/src/backend/base/langflow/base/agents/agent.py b/src/backend/base/langflow/base/agents/agent.py index eeb48d60c..69fc0f4c9 100644 --- a/src/backend/base/langflow/base/agents/agent.py +++ b/src/backend/base/langflow/base/agents/agent.py @@ -1,5 +1,5 @@ from abc import abstractmethod -from typing import List, Optional, Union, cast +from typing import List, Union, cast from langchain.agents import AgentExecutor, BaseMultiActionAgent, BaseSingleActionAgent from langchain.agents.agent import RunnableAgent @@ -9,10 +9,9 @@ from langchain_core.runnables import Runnable from langflow.base.agents.callback import AgentAsyncHandler from langflow.base.agents.utils import data_to_messages from langflow.custom import Component -from langflow.field_typing import Text, Tool +from langflow.field_typing import Text from langflow.inputs.inputs import DataInput, InputTypes from langflow.io import BoolInput, HandleInput, IntInput, MessageTextInput -from langflow.schema import Data from langflow.schema.message import Message from langflow.template import Output @@ -39,6 +38,7 @@ class LCAgentComponent(Component): value=15, advanced=True, ), + DataInput(name="chat_history", display_name="Chat History", is_list=True, advanced=True), ] outputs = [ @@ -46,15 +46,16 @@ class LCAgentComponent(Component): Output(display_name="Response", name="response", method="message_response"), ] + @abstractmethod + def build_agent(self) -> AgentExecutor: + """Create the agent.""" + pass + async def message_response(self) -> Message: + """Run the agent and return the response.""" agent = self.build_agent() - result = await self.run_agent( - agent=agent, - inputs=self.input_value, - tools=self.tools, - message_history=self.chat_history, - handle_parsing_errors=self.handle_parsing_errors, - ) + result = await self.run_agent(agent=agent) + if isinstance(result, list): result = "\n".join([result_dict["text"] for result_dict in result]) message = Message(text=result, sender="Machine") @@ -87,27 +88,11 @@ class LCAgentComponent(Component): } return {**base, "agent_executor_kwargs": agent_kwargs} - async def run_agent( - self, - agent: Union[Runnable, BaseSingleActionAgent, BaseMultiActionAgent, AgentExecutor], - inputs: str, - tools: List[Tool], - message_history: Optional[List[Data]] = None, - handle_parsing_errors: bool = True, - ) -> Text: - if isinstance(agent, AgentExecutor): - runnable = agent - else: - runnable = AgentExecutor.from_agent_and_tools( - agent=agent, # type: ignore - tools=tools, - verbose=True, - handle_parsing_errors=handle_parsing_errors, - ) - input_dict: dict[str, str | list[BaseMessage]] = {"input": inputs} - if message_history: - input_dict["chat_history"] = data_to_messages(message_history) - result = await runnable.ainvoke(input_dict, config={"callbacks": [AgentAsyncHandler(self.log)]}) + async def run_agent(self, agent: AgentExecutor) -> Text: + input_dict: dict[str, str | list[BaseMessage]] = {"input": self.input_value} + if self.chat_history: + input_dict["chat_history"] = data_to_messages(self.chat_history) + result = await agent.ainvoke(input_dict, config={"callbacks": [AgentAsyncHandler(self.log)]}) self.status = result if "output" not in result: raise ValueError("Output key not found in result. Tried 'output'.") @@ -123,24 +108,41 @@ class LCToolsAgentComponent(LCAgentComponent): input_types=["Tool", "BaseTool"], is_list=True, ), - HandleInput( - name="llm", - display_name="Language Model", - input_types=["LanguageModel", "ToolEnabledLanguageModel"], - required=True, - ), - DataInput(name="chat_history", display_name="Chat History", is_list=True), ] def build_agent(self) -> AgentExecutor: - agent = self.creat_agent_runnable() + agent = self.create_agent_runnable() return AgentExecutor.from_agent_and_tools( agent=RunnableAgent(runnable=agent, input_keys_arg=["input"], return_keys_arg=["output"]), tools=self.tools, **self.get_agent_kwargs(flatten=True), ) + async def run_agent( + self, + agent: Union[Runnable, BaseSingleActionAgent, BaseMultiActionAgent, AgentExecutor], + ) -> Text: + if isinstance(agent, AgentExecutor): + runnable = agent + else: + runnable = AgentExecutor.from_agent_and_tools( + agent=agent, # type: ignore + tools=self.tools, + handle_parsing_errors=self.handle_parsing_errors, + verbose=self.verbose, + max_iterations=self.max_iterations, + ) + input_dict: dict[str, str | list[BaseMessage]] = {"input": self.input_value} + if self.chat_history: + input_dict["chat_history"] = data_to_messages(self.chat_history) + result = await runnable.ainvoke(input_dict, config={"callbacks": [AgentAsyncHandler(self.log)]}) + self.status = result + if "output" not in result: + raise ValueError("Output key not found in result. Tried 'output'.") + + return cast(str, result.get("output")) + @abstractmethod - def creat_agent_runnable(self) -> Runnable: + def create_agent_runnable(self) -> Runnable: """Create the agent.""" pass diff --git a/src/backend/base/langflow/base/agents/crewai/crew.py b/src/backend/base/langflow/base/agents/crewai/crew.py index 70137518c..4f4a7a567 100644 --- a/src/backend/base/langflow/base/agents/crewai/crew.py +++ b/src/backend/base/langflow/base/agents/crewai/crew.py @@ -27,6 +27,7 @@ class BaseCrewComponent(Component): name="function_calling_llm", display_name="Function Calling LLM", input_types=["LanguageModel"], + info="Turns the ReAct CrewAI agent into a function-calling agent", required=False, advanced=True, ), diff --git a/src/backend/base/langflow/components/agents/CSVAgent.py b/src/backend/base/langflow/components/agents/CSVAgent.py index c88685521..f28691ac3 100644 --- a/src/backend/base/langflow/components/agents/CSVAgent.py +++ b/src/backend/base/langflow/components/agents/CSVAgent.py @@ -12,8 +12,8 @@ class CSVAgentComponent(LCAgentComponent): name = "CSVAgent" inputs = LCAgentComponent._base_inputs + [ - FileInput(name="path", display_name="File Path", file_types=["csv"], required=True), HandleInput(name="llm", display_name="Language Model", input_types=["LanguageModel"], required=True), + FileInput(name="path", display_name="File Path", file_types=["csv"], required=True), DropdownInput( name="agent_type", display_name="Agent Type", diff --git a/src/backend/base/langflow/components/agents/HierarchicalCrew.py b/src/backend/base/langflow/components/agents/HierarchicalCrew.py index b5160b2e5..a69a12554 100644 --- a/src/backend/base/langflow/components/agents/HierarchicalCrew.py +++ b/src/backend/base/langflow/components/agents/HierarchicalCrew.py @@ -9,6 +9,8 @@ class HierarchicalCrewComponent(BaseCrewComponent): description: str = ( "Represents a group of agents, defining how they should collaborate and the tasks they should perform." ) + documentation: str = "https://docs.crewai.com/how-to/Hierarchical/" + icon = "CrewAI" inputs = BaseCrewComponent._base_inputs + [ HandleInput(name="agents", display_name="Agents", input_types=["Agent"], is_list=True), diff --git a/src/backend/base/langflow/components/agents/JsonAgent.py b/src/backend/base/langflow/components/agents/JsonAgent.py index f79d2dba5..56bdacf14 100644 --- a/src/backend/base/langflow/components/agents/JsonAgent.py +++ b/src/backend/base/langflow/components/agents/JsonAgent.py @@ -16,8 +16,8 @@ class JsonAgentComponent(LCAgentComponent): name = "JsonAgent" inputs = LCAgentComponent._base_inputs + [ - FileInput(name="path", display_name="File Path", file_types=["json", "yaml", "yml"], required=True), HandleInput(name="llm", display_name="Language Model", input_types=["LanguageModel"], required=True), + FileInput(name="path", display_name="File Path", file_types=["json", "yaml", "yml"], required=True), ] def build_agent(self) -> AgentExecutor: diff --git a/src/backend/base/langflow/components/agents/OpenAIToolsAgent.py b/src/backend/base/langflow/components/agents/OpenAIToolsAgent.py index 45eafc430..0676821da 100644 --- a/src/backend/base/langflow/components/agents/OpenAIToolsAgent.py +++ b/src/backend/base/langflow/components/agents/OpenAIToolsAgent.py @@ -3,6 +3,7 @@ from langchain_core.prompts import ChatPromptTemplate, PromptTemplate, HumanMess from langflow.base.agents.agent import LCToolsAgentComponent from langflow.inputs import MultilineInput +from langflow.inputs.inputs import HandleInput class OpenAIToolsAgentComponent(LCToolsAgentComponent): @@ -13,6 +14,12 @@ class OpenAIToolsAgentComponent(LCToolsAgentComponent): name = "OpenAIToolsAgent" inputs = LCToolsAgentComponent._base_inputs + [ + HandleInput( + name="llm", + display_name="Language Model", + input_types=["LanguageModel", "ToolEnabledLanguageModel"], + required=True, + ), MultilineInput( name="system_prompt", display_name="System Prompt", @@ -24,7 +31,7 @@ class OpenAIToolsAgentComponent(LCToolsAgentComponent): ), ] - def creat_agent_runnable(self): + def create_agent_runnable(self): if "input" not in self.user_prompt: raise ValueError("Prompt must contain 'input' key.") messages = [ diff --git a/src/backend/base/langflow/components/agents/OpenAPIAgent.py b/src/backend/base/langflow/components/agents/OpenAPIAgent.py index f453ada49..798057fdb 100644 --- a/src/backend/base/langflow/components/agents/OpenAPIAgent.py +++ b/src/backend/base/langflow/components/agents/OpenAPIAgent.py @@ -17,8 +17,8 @@ class OpenAPIAgentComponent(LCAgentComponent): name = "OpenAPIAgent" inputs = LCAgentComponent._base_inputs + [ - FileInput(name="path", display_name="File Path", file_types=["json", "yaml", "yml"], required=True), HandleInput(name="llm", display_name="Language Model", input_types=["LanguageModel"], required=True), + FileInput(name="path", display_name="File Path", file_types=["json", "yaml", "yml"], required=True), BoolInput(name="allow_dangerous_requests", display_name="Allow Dangerous Requests", value=False, required=True), ] @@ -37,6 +37,9 @@ class OpenAPIAgentComponent(LCAgentComponent): ) agent_args = self.get_agent_kwargs() + + # This is bit weird - generally other create_*_agent functions have max_iterations in the + # `agent_executor_kwargs`, but openai has this parameter passed directly. agent_args["max_iterations"] = agent_args["agent_executor_kwargs"]["max_iterations"] del agent_args["agent_executor_kwargs"]["max_iterations"] return create_openapi_agent(llm=self.llm, toolkit=toolkit, **agent_args) diff --git a/src/backend/base/langflow/components/agents/SQLAgent.py b/src/backend/base/langflow/components/agents/SQLAgent.py index 9db94deff..6653fbdfa 100644 --- a/src/backend/base/langflow/components/agents/SQLAgent.py +++ b/src/backend/base/langflow/components/agents/SQLAgent.py @@ -13,8 +13,15 @@ class SQLAgentComponent(LCAgentComponent): name = "SQLAgent" inputs = LCAgentComponent._base_inputs + [ - MessageTextInput(name="database_uri", display_name="Database URI", required=True), HandleInput(name="llm", display_name="Language Model", input_types=["LanguageModel"], required=True), + MessageTextInput(name="database_uri", display_name="Database URI", required=True), + HandleInput( + name="extra_tools", + display_name="Extra Tools", + input_types=["Tool", "BaseTool"], + is_list=True, + advanced=True, + ), ] def build_agent(self) -> AgentExecutor: @@ -23,4 +30,4 @@ class SQLAgentComponent(LCAgentComponent): agent_args = self.get_agent_kwargs() agent_args["max_iterations"] = agent_args["agent_executor_kwargs"]["max_iterations"] del agent_args["agent_executor_kwargs"]["max_iterations"] - return create_sql_agent(llm=self.llm, toolkit=toolkit, **agent_args) + return create_sql_agent(llm=self.llm, toolkit=toolkit, extra_tools=self.extra_tools or [], **agent_args) diff --git a/src/backend/base/langflow/components/agents/SequentialCrew.py b/src/backend/base/langflow/components/agents/SequentialCrew.py index 600d66883..858be7e82 100644 --- a/src/backend/base/langflow/components/agents/SequentialCrew.py +++ b/src/backend/base/langflow/components/agents/SequentialCrew.py @@ -7,10 +7,8 @@ from langflow.schema.message import Message class SequentialCrewComponent(BaseCrewComponent): display_name: str = "Sequential Crew" - description: str = ( - "Represents a group of agents, defining how they should collaborate and the tasks they should perform." - ) - documentation: str = "https://docs.crewai.com/how-to/LLM-Connections/" + description: str = "Represents a group of agents with tasks that are executed sequentially." + documentation: str = "https://docs.crewai.com/how-to/Sequential/" icon = "CrewAI" inputs = BaseCrewComponent._base_inputs + [ diff --git a/src/backend/base/langflow/components/agents/ToolCallingAgent.py b/src/backend/base/langflow/components/agents/ToolCallingAgent.py index f77d4ac69..aa1c4dcd7 100644 --- a/src/backend/base/langflow/components/agents/ToolCallingAgent.py +++ b/src/backend/base/langflow/components/agents/ToolCallingAgent.py @@ -2,6 +2,7 @@ from langchain.agents import create_tool_calling_agent from langchain_core.prompts import ChatPromptTemplate, PromptTemplate, HumanMessagePromptTemplate from langflow.base.agents.agent import LCToolsAgentComponent from langflow.inputs import MultilineInput +from langflow.inputs.inputs import HandleInput class ToolCallingAgentComponent(LCToolsAgentComponent): @@ -12,6 +13,7 @@ class ToolCallingAgentComponent(LCToolsAgentComponent): name = "ToolCallingAgent" inputs = LCToolsAgentComponent._base_inputs + [ + HandleInput(name="llm", display_name="Language Model", input_types=["LanguageModel"], required=True), MultilineInput( name="system_prompt", display_name="System Prompt", @@ -23,7 +25,7 @@ class ToolCallingAgentComponent(LCToolsAgentComponent): ), ] - def creat_agent_runnable(self): + def create_agent_runnable(self): if "input" not in self.user_prompt: raise ValueError("Prompt must contain 'input' key.") messages = [ diff --git a/src/backend/base/langflow/components/agents/XMLAgent.py b/src/backend/base/langflow/components/agents/XMLAgent.py index cc8a50e49..98464b627 100644 --- a/src/backend/base/langflow/components/agents/XMLAgent.py +++ b/src/backend/base/langflow/components/agents/XMLAgent.py @@ -3,6 +3,7 @@ from langchain_core.prompts import ChatPromptTemplate, PromptTemplate, HumanMess from langflow.base.agents.agent import LCToolsAgentComponent from langflow.inputs import MultilineInput +from langflow.inputs.inputs import HandleInput class XMLAgentComponent(LCToolsAgentComponent): @@ -13,6 +14,7 @@ class XMLAgentComponent(LCToolsAgentComponent): name = "XMLAgent" inputs = LCToolsAgentComponent._base_inputs + [ + HandleInput(name="llm", display_name="Language Model", input_types=["LanguageModel"], required=True), MultilineInput( name="user_prompt", display_name="Prompt", @@ -44,7 +46,7 @@ Question: {input} ), ] - def creat_agent_runnable(self): + def create_agent_runnable(self): messages = [ HumanMessagePromptTemplate(prompt=PromptTemplate(input_variables=["input"], template=self.user_prompt)) ]