ref: refactors the agents around and fixes a few bugs (#2847)
This commit is contained in:
parent
a4d6f4ff5b
commit
9dadfd45b1
11 changed files with 76 additions and 52 deletions
|
|
@ -1,5 +1,5 @@
|
|||
from abc import abstractmethod
|
||||
from typing import List, Optional, Union, cast
|
||||
from typing import List, Union, cast
|
||||
|
||||
from langchain.agents import AgentExecutor, BaseMultiActionAgent, BaseSingleActionAgent
|
||||
from langchain.agents.agent import RunnableAgent
|
||||
|
|
@ -9,10 +9,9 @@ from langchain_core.runnables import Runnable
|
|||
from langflow.base.agents.callback import AgentAsyncHandler
|
||||
from langflow.base.agents.utils import data_to_messages
|
||||
from langflow.custom import Component
|
||||
from langflow.field_typing import Text, Tool
|
||||
from langflow.field_typing import Text
|
||||
from langflow.inputs.inputs import DataInput, InputTypes
|
||||
from langflow.io import BoolInput, HandleInput, IntInput, MessageTextInput
|
||||
from langflow.schema import Data
|
||||
from langflow.schema.message import Message
|
||||
from langflow.template import Output
|
||||
|
||||
|
|
@ -39,6 +38,7 @@ class LCAgentComponent(Component):
|
|||
value=15,
|
||||
advanced=True,
|
||||
),
|
||||
DataInput(name="chat_history", display_name="Chat History", is_list=True, advanced=True),
|
||||
]
|
||||
|
||||
outputs = [
|
||||
|
|
@ -46,15 +46,16 @@ class LCAgentComponent(Component):
|
|||
Output(display_name="Response", name="response", method="message_response"),
|
||||
]
|
||||
|
||||
@abstractmethod
|
||||
def build_agent(self) -> AgentExecutor:
|
||||
"""Create the agent."""
|
||||
pass
|
||||
|
||||
async def message_response(self) -> Message:
|
||||
"""Run the agent and return the response."""
|
||||
agent = self.build_agent()
|
||||
result = await self.run_agent(
|
||||
agent=agent,
|
||||
inputs=self.input_value,
|
||||
tools=self.tools,
|
||||
message_history=self.chat_history,
|
||||
handle_parsing_errors=self.handle_parsing_errors,
|
||||
)
|
||||
result = await self.run_agent(agent=agent)
|
||||
|
||||
if isinstance(result, list):
|
||||
result = "\n".join([result_dict["text"] for result_dict in result])
|
||||
message = Message(text=result, sender="Machine")
|
||||
|
|
@ -87,27 +88,11 @@ class LCAgentComponent(Component):
|
|||
}
|
||||
return {**base, "agent_executor_kwargs": agent_kwargs}
|
||||
|
||||
async def run_agent(
|
||||
self,
|
||||
agent: Union[Runnable, BaseSingleActionAgent, BaseMultiActionAgent, AgentExecutor],
|
||||
inputs: str,
|
||||
tools: List[Tool],
|
||||
message_history: Optional[List[Data]] = None,
|
||||
handle_parsing_errors: bool = True,
|
||||
) -> Text:
|
||||
if isinstance(agent, AgentExecutor):
|
||||
runnable = agent
|
||||
else:
|
||||
runnable = AgentExecutor.from_agent_and_tools(
|
||||
agent=agent, # type: ignore
|
||||
tools=tools,
|
||||
verbose=True,
|
||||
handle_parsing_errors=handle_parsing_errors,
|
||||
)
|
||||
input_dict: dict[str, str | list[BaseMessage]] = {"input": inputs}
|
||||
if message_history:
|
||||
input_dict["chat_history"] = data_to_messages(message_history)
|
||||
result = await runnable.ainvoke(input_dict, config={"callbacks": [AgentAsyncHandler(self.log)]})
|
||||
async def run_agent(self, agent: AgentExecutor) -> Text:
|
||||
input_dict: dict[str, str | list[BaseMessage]] = {"input": self.input_value}
|
||||
if self.chat_history:
|
||||
input_dict["chat_history"] = data_to_messages(self.chat_history)
|
||||
result = await agent.ainvoke(input_dict, config={"callbacks": [AgentAsyncHandler(self.log)]})
|
||||
self.status = result
|
||||
if "output" not in result:
|
||||
raise ValueError("Output key not found in result. Tried 'output'.")
|
||||
|
|
@ -123,24 +108,41 @@ class LCToolsAgentComponent(LCAgentComponent):
|
|||
input_types=["Tool", "BaseTool"],
|
||||
is_list=True,
|
||||
),
|
||||
HandleInput(
|
||||
name="llm",
|
||||
display_name="Language Model",
|
||||
input_types=["LanguageModel", "ToolEnabledLanguageModel"],
|
||||
required=True,
|
||||
),
|
||||
DataInput(name="chat_history", display_name="Chat History", is_list=True),
|
||||
]
|
||||
|
||||
def build_agent(self) -> AgentExecutor:
|
||||
agent = self.creat_agent_runnable()
|
||||
agent = self.create_agent_runnable()
|
||||
return AgentExecutor.from_agent_and_tools(
|
||||
agent=RunnableAgent(runnable=agent, input_keys_arg=["input"], return_keys_arg=["output"]),
|
||||
tools=self.tools,
|
||||
**self.get_agent_kwargs(flatten=True),
|
||||
)
|
||||
|
||||
async def run_agent(
|
||||
self,
|
||||
agent: Union[Runnable, BaseSingleActionAgent, BaseMultiActionAgent, AgentExecutor],
|
||||
) -> Text:
|
||||
if isinstance(agent, AgentExecutor):
|
||||
runnable = agent
|
||||
else:
|
||||
runnable = AgentExecutor.from_agent_and_tools(
|
||||
agent=agent, # type: ignore
|
||||
tools=self.tools,
|
||||
handle_parsing_errors=self.handle_parsing_errors,
|
||||
verbose=self.verbose,
|
||||
max_iterations=self.max_iterations,
|
||||
)
|
||||
input_dict: dict[str, str | list[BaseMessage]] = {"input": self.input_value}
|
||||
if self.chat_history:
|
||||
input_dict["chat_history"] = data_to_messages(self.chat_history)
|
||||
result = await runnable.ainvoke(input_dict, config={"callbacks": [AgentAsyncHandler(self.log)]})
|
||||
self.status = result
|
||||
if "output" not in result:
|
||||
raise ValueError("Output key not found in result. Tried 'output'.")
|
||||
|
||||
return cast(str, result.get("output"))
|
||||
|
||||
@abstractmethod
|
||||
def creat_agent_runnable(self) -> Runnable:
|
||||
def create_agent_runnable(self) -> Runnable:
|
||||
"""Create the agent."""
|
||||
pass
|
||||
|
|
|
|||
|
|
@ -27,6 +27,7 @@ class BaseCrewComponent(Component):
|
|||
name="function_calling_llm",
|
||||
display_name="Function Calling LLM",
|
||||
input_types=["LanguageModel"],
|
||||
info="Turns the ReAct CrewAI agent into a function-calling agent",
|
||||
required=False,
|
||||
advanced=True,
|
||||
),
|
||||
|
|
|
|||
|
|
@ -12,8 +12,8 @@ class CSVAgentComponent(LCAgentComponent):
|
|||
name = "CSVAgent"
|
||||
|
||||
inputs = LCAgentComponent._base_inputs + [
|
||||
FileInput(name="path", display_name="File Path", file_types=["csv"], required=True),
|
||||
HandleInput(name="llm", display_name="Language Model", input_types=["LanguageModel"], required=True),
|
||||
FileInput(name="path", display_name="File Path", file_types=["csv"], required=True),
|
||||
DropdownInput(
|
||||
name="agent_type",
|
||||
display_name="Agent Type",
|
||||
|
|
|
|||
|
|
@ -9,6 +9,8 @@ class HierarchicalCrewComponent(BaseCrewComponent):
|
|||
description: str = (
|
||||
"Represents a group of agents, defining how they should collaborate and the tasks they should perform."
|
||||
)
|
||||
documentation: str = "https://docs.crewai.com/how-to/Hierarchical/"
|
||||
icon = "CrewAI"
|
||||
|
||||
inputs = BaseCrewComponent._base_inputs + [
|
||||
HandleInput(name="agents", display_name="Agents", input_types=["Agent"], is_list=True),
|
||||
|
|
|
|||
|
|
@ -16,8 +16,8 @@ class JsonAgentComponent(LCAgentComponent):
|
|||
name = "JsonAgent"
|
||||
|
||||
inputs = LCAgentComponent._base_inputs + [
|
||||
FileInput(name="path", display_name="File Path", file_types=["json", "yaml", "yml"], required=True),
|
||||
HandleInput(name="llm", display_name="Language Model", input_types=["LanguageModel"], required=True),
|
||||
FileInput(name="path", display_name="File Path", file_types=["json", "yaml", "yml"], required=True),
|
||||
]
|
||||
|
||||
def build_agent(self) -> AgentExecutor:
|
||||
|
|
|
|||
|
|
@ -3,6 +3,7 @@ from langchain_core.prompts import ChatPromptTemplate, PromptTemplate, HumanMess
|
|||
|
||||
from langflow.base.agents.agent import LCToolsAgentComponent
|
||||
from langflow.inputs import MultilineInput
|
||||
from langflow.inputs.inputs import HandleInput
|
||||
|
||||
|
||||
class OpenAIToolsAgentComponent(LCToolsAgentComponent):
|
||||
|
|
@ -13,6 +14,12 @@ class OpenAIToolsAgentComponent(LCToolsAgentComponent):
|
|||
name = "OpenAIToolsAgent"
|
||||
|
||||
inputs = LCToolsAgentComponent._base_inputs + [
|
||||
HandleInput(
|
||||
name="llm",
|
||||
display_name="Language Model",
|
||||
input_types=["LanguageModel", "ToolEnabledLanguageModel"],
|
||||
required=True,
|
||||
),
|
||||
MultilineInput(
|
||||
name="system_prompt",
|
||||
display_name="System Prompt",
|
||||
|
|
@ -24,7 +31,7 @@ class OpenAIToolsAgentComponent(LCToolsAgentComponent):
|
|||
),
|
||||
]
|
||||
|
||||
def creat_agent_runnable(self):
|
||||
def create_agent_runnable(self):
|
||||
if "input" not in self.user_prompt:
|
||||
raise ValueError("Prompt must contain 'input' key.")
|
||||
messages = [
|
||||
|
|
|
|||
|
|
@ -17,8 +17,8 @@ class OpenAPIAgentComponent(LCAgentComponent):
|
|||
name = "OpenAPIAgent"
|
||||
|
||||
inputs = LCAgentComponent._base_inputs + [
|
||||
FileInput(name="path", display_name="File Path", file_types=["json", "yaml", "yml"], required=True),
|
||||
HandleInput(name="llm", display_name="Language Model", input_types=["LanguageModel"], required=True),
|
||||
FileInput(name="path", display_name="File Path", file_types=["json", "yaml", "yml"], required=True),
|
||||
BoolInput(name="allow_dangerous_requests", display_name="Allow Dangerous Requests", value=False, required=True),
|
||||
]
|
||||
|
||||
|
|
@ -37,6 +37,9 @@ class OpenAPIAgentComponent(LCAgentComponent):
|
|||
)
|
||||
|
||||
agent_args = self.get_agent_kwargs()
|
||||
|
||||
# This is bit weird - generally other create_*_agent functions have max_iterations in the
|
||||
# `agent_executor_kwargs`, but openai has this parameter passed directly.
|
||||
agent_args["max_iterations"] = agent_args["agent_executor_kwargs"]["max_iterations"]
|
||||
del agent_args["agent_executor_kwargs"]["max_iterations"]
|
||||
return create_openapi_agent(llm=self.llm, toolkit=toolkit, **agent_args)
|
||||
|
|
|
|||
|
|
@ -13,8 +13,15 @@ class SQLAgentComponent(LCAgentComponent):
|
|||
name = "SQLAgent"
|
||||
|
||||
inputs = LCAgentComponent._base_inputs + [
|
||||
MessageTextInput(name="database_uri", display_name="Database URI", required=True),
|
||||
HandleInput(name="llm", display_name="Language Model", input_types=["LanguageModel"], required=True),
|
||||
MessageTextInput(name="database_uri", display_name="Database URI", required=True),
|
||||
HandleInput(
|
||||
name="extra_tools",
|
||||
display_name="Extra Tools",
|
||||
input_types=["Tool", "BaseTool"],
|
||||
is_list=True,
|
||||
advanced=True,
|
||||
),
|
||||
]
|
||||
|
||||
def build_agent(self) -> AgentExecutor:
|
||||
|
|
@ -23,4 +30,4 @@ class SQLAgentComponent(LCAgentComponent):
|
|||
agent_args = self.get_agent_kwargs()
|
||||
agent_args["max_iterations"] = agent_args["agent_executor_kwargs"]["max_iterations"]
|
||||
del agent_args["agent_executor_kwargs"]["max_iterations"]
|
||||
return create_sql_agent(llm=self.llm, toolkit=toolkit, **agent_args)
|
||||
return create_sql_agent(llm=self.llm, toolkit=toolkit, extra_tools=self.extra_tools or [], **agent_args)
|
||||
|
|
|
|||
|
|
@ -7,10 +7,8 @@ from langflow.schema.message import Message
|
|||
|
||||
class SequentialCrewComponent(BaseCrewComponent):
|
||||
display_name: str = "Sequential Crew"
|
||||
description: str = (
|
||||
"Represents a group of agents, defining how they should collaborate and the tasks they should perform."
|
||||
)
|
||||
documentation: str = "https://docs.crewai.com/how-to/LLM-Connections/"
|
||||
description: str = "Represents a group of agents with tasks that are executed sequentially."
|
||||
documentation: str = "https://docs.crewai.com/how-to/Sequential/"
|
||||
icon = "CrewAI"
|
||||
|
||||
inputs = BaseCrewComponent._base_inputs + [
|
||||
|
|
|
|||
|
|
@ -2,6 +2,7 @@ from langchain.agents import create_tool_calling_agent
|
|||
from langchain_core.prompts import ChatPromptTemplate, PromptTemplate, HumanMessagePromptTemplate
|
||||
from langflow.base.agents.agent import LCToolsAgentComponent
|
||||
from langflow.inputs import MultilineInput
|
||||
from langflow.inputs.inputs import HandleInput
|
||||
|
||||
|
||||
class ToolCallingAgentComponent(LCToolsAgentComponent):
|
||||
|
|
@ -12,6 +13,7 @@ class ToolCallingAgentComponent(LCToolsAgentComponent):
|
|||
name = "ToolCallingAgent"
|
||||
|
||||
inputs = LCToolsAgentComponent._base_inputs + [
|
||||
HandleInput(name="llm", display_name="Language Model", input_types=["LanguageModel"], required=True),
|
||||
MultilineInput(
|
||||
name="system_prompt",
|
||||
display_name="System Prompt",
|
||||
|
|
@ -23,7 +25,7 @@ class ToolCallingAgentComponent(LCToolsAgentComponent):
|
|||
),
|
||||
]
|
||||
|
||||
def creat_agent_runnable(self):
|
||||
def create_agent_runnable(self):
|
||||
if "input" not in self.user_prompt:
|
||||
raise ValueError("Prompt must contain 'input' key.")
|
||||
messages = [
|
||||
|
|
|
|||
|
|
@ -3,6 +3,7 @@ from langchain_core.prompts import ChatPromptTemplate, PromptTemplate, HumanMess
|
|||
|
||||
from langflow.base.agents.agent import LCToolsAgentComponent
|
||||
from langflow.inputs import MultilineInput
|
||||
from langflow.inputs.inputs import HandleInput
|
||||
|
||||
|
||||
class XMLAgentComponent(LCToolsAgentComponent):
|
||||
|
|
@ -13,6 +14,7 @@ class XMLAgentComponent(LCToolsAgentComponent):
|
|||
name = "XMLAgent"
|
||||
|
||||
inputs = LCToolsAgentComponent._base_inputs + [
|
||||
HandleInput(name="llm", display_name="Language Model", input_types=["LanguageModel"], required=True),
|
||||
MultilineInput(
|
||||
name="user_prompt",
|
||||
display_name="Prompt",
|
||||
|
|
@ -44,7 +46,7 @@ Question: {input}
|
|||
),
|
||||
]
|
||||
|
||||
def creat_agent_runnable(self):
|
||||
def create_agent_runnable(self):
|
||||
messages = [
|
||||
HumanMessagePromptTemplate(prompt=PromptTemplate(input_variables=["input"], template=self.user_prompt))
|
||||
]
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue