Add Experimental Agent Component (#1705)

* Update langflow base prompts API utils and add ToolCallingAgentComponent

* Update return type annotations in AzureOpenAIModel.py and ChatLiteLLMModel.py

* Update langchainhub package version to 0.1.15

* Update langflow base prompts API utils and add ToolCallingAgentComponent

* Add AgentComponent to langflow experimental components

* Update prompt variable name to user_prompt in ToolCallingAgentComponent.py

* Update prompt variable name to system_message in AgentComponent.py

* Update system_message variable name in XMLAgentComponent and ToolCallingAgentComponent

* Update prompt variable name to user_prompt in ToolCallingAgentComponent.py
This commit is contained in:
Gabriel Luiz Freitas Almeida 2024-04-16 13:57:05 -03:00 committed by GitHub
commit 42e88b7a23
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
13 changed files with 695 additions and 191 deletions

View file

@ -1,13 +1,19 @@
from typing import List, Optional, Union, cast
from langchain.agents import AgentExecutor, BaseMultiActionAgent, BaseSingleActionAgent
from langchain_core.messages import BaseMessage
from langchain_core.runnables import Runnable
from langflow.base.agents.utils import get_agents_list, records_to_messages
from langflow.custom import CustomComponent
from langflow.field_typing import BaseMemory, Text, Tool
from langflow.field_typing import Text, Tool
from langflow.schema.schema import Record
class LCAgentComponent(CustomComponent):
def get_agents_list(self):
return get_agents_list()
def build_config(self):
return {
"lc": {
@ -42,9 +48,8 @@ class LCAgentComponent(CustomComponent):
self,
agent: Union[Runnable, BaseSingleActionAgent, BaseMultiActionAgent, AgentExecutor],
inputs: str,
input_variables: list[str],
tools: List[Tool],
memory: Optional[BaseMemory] = None,
message_history: Optional[List[Record]] = None,
handle_parsing_errors: bool = True,
output_key: str = "output",
) -> Text:
@ -55,13 +60,11 @@ class LCAgentComponent(CustomComponent):
agent=agent, # type: ignore
tools=tools,
verbose=True,
memory=memory,
handle_parsing_errors=handle_parsing_errors,
)
input_dict = {"input": inputs}
for var in input_variables:
if var not in ["agent_scratchpad", "input"]:
input_dict[var] = ""
input_dict: dict[str, str | list[BaseMessage]] = {"input": inputs}
if message_history:
input_dict["chat_history"] = records_to_messages(message_history)
result = await runnable.ainvoke(input_dict)
self.status = result
if output_key in result:

View file

@ -0,0 +1,23 @@
XML_AGENT_PROMPT = """You are a helpful assistant. Help the user answer any questions.
You have access to the following tools:
{tools}
In order to use a tool, you can use <tool></tool> and <tool_input></tool_input> tags. You will then get back a response in the form <observation></observation>
For example, if you have a tool called 'search' that could run a google search, in order to search for the weather in SF you would respond:
<tool>search</tool><tool_input>weather in SF</tool_input>
<observation>64 degrees</observation>
When you are done, respond with a final answer between <final_answer></final_answer>. For example:
<final_answer>The weather in SF is 64 degrees</final_answer>
Begin!
Previous Conversation:
{chat_history}
Question: {input}
{agent_scratchpad}"""

View file

@ -0,0 +1,143 @@
from typing import Any, Callable, Dict, List, Optional, Sequence, Union
from langchain.agents import (
create_json_chat_agent,
create_openai_tools_agent,
create_tool_calling_agent,
create_xml_agent,
)
from langchain.agents.xml.base import render_text_description
from langchain_core.language_models import BaseLanguageModel
from langchain_core.messages import BaseMessage
from langchain_core.prompts import BasePromptTemplate, ChatPromptTemplate
from langchain_core.tools import BaseTool
from pydantic import BaseModel
from langflow.schema.schema import Record
from .default_prompts import XML_AGENT_PROMPT
class AgentSpec(BaseModel):
func: Callable[
[
BaseLanguageModel,
Sequence[BaseTool],
BasePromptTemplate | ChatPromptTemplate,
Optional[Callable[[List[BaseTool]], str]],
Optional[Union[bool, List[str]]],
],
Any,
]
prompt: Optional[Any] = None
fields: List[str]
hub_repo: Optional[str] = None
def records_to_messages(records: List[Record]) -> List[BaseMessage]:
"""
Convert a list of records to a list of messages.
Args:
records (List[Record]): The records to convert.
Returns:
List[Message]: The records as messages.
"""
return [record.to_lc_message() for record in records]
def validate_and_create_xml_agent(
llm: BaseLanguageModel,
tools: Sequence[BaseTool],
prompt: BasePromptTemplate,
tools_renderer: Callable[[List[BaseTool]], str] = render_text_description,
*,
stop_sequence: Union[bool, List[str]] = True,
):
return create_xml_agent(
llm=llm,
tools=tools,
prompt=prompt,
tools_renderer=tools_renderer,
stop_sequence=stop_sequence,
)
def validate_and_create_openai_tools_agent(
llm: BaseLanguageModel,
tools: Sequence[BaseTool],
prompt: ChatPromptTemplate,
tools_renderer: Callable[[List[BaseTool]], str] = render_text_description,
*,
stop_sequence: Union[bool, List[str]] = True,
):
return create_openai_tools_agent(
llm=llm,
tools=tools,
prompt=prompt,
)
def validate_and_create_tool_calling_agent(
llm: BaseLanguageModel,
tools: Sequence[BaseTool],
prompt: ChatPromptTemplate,
tools_renderer: Callable[[List[BaseTool]], str] = render_text_description,
*,
stop_sequence: Union[bool, List[str]] = True,
):
return create_tool_calling_agent(
llm=llm,
tools=tools,
prompt=prompt,
)
def validate_and_create_json_chat_agent(
llm: BaseLanguageModel,
tools: Sequence[BaseTool],
prompt: ChatPromptTemplate,
tools_renderer: Callable[[List[BaseTool]], str] = render_text_description,
*,
stop_sequence: Union[bool, List[str]] = True,
):
return create_json_chat_agent(
llm=llm,
tools=tools,
prompt=prompt,
tools_renderer=tools_renderer,
stop_sequence=stop_sequence,
)
AGENTS: Dict[str, AgentSpec] = {
"Tool Calling Agent": AgentSpec(
func=validate_and_create_tool_calling_agent,
prompt=None,
fields=["llm", "tools", "prompt"],
hub_repo=None,
),
"XML Agent": AgentSpec(
func=validate_and_create_xml_agent,
prompt=XML_AGENT_PROMPT, # Ensure XML_AGENT_PROMPT is properly defined and typed.
fields=["llm", "tools", "prompt", "tools_renderer", "stop_sequence"],
hub_repo="hwchase17/xml-agent-convo",
),
"OpenAI Tools Agent": AgentSpec(
func=validate_and_create_openai_tools_agent,
prompt=None,
fields=["llm", "tools", "prompt"],
hub_repo=None,
),
"JSON Chat Agent": AgentSpec(
func=validate_and_create_json_chat_agent,
prompt=None,
fields=["llm", "tools", "prompt", "tools_renderer", "stop_sequence"],
hub_repo="hwchase17/react-chat-json",
),
}
def get_agents_list():
return list(AGENTS.keys())

View file

@ -0,0 +1,64 @@
from typing import List, Optional
from langchain.agents.tool_calling_agent.base import create_tool_calling_agent
from langchain_core.prompts import ChatPromptTemplate
from langflow.base.agents.agent import LCAgentComponent
from langflow.field_typing import BaseLanguageModel, Text, Tool
from langflow.schema.schema import Record
class ToolCallingAgentComponent(LCAgentComponent):
display_name: str = "Tool Calling Agent"
description: str = "Agent that uses tools. Only models that are compatible with function calling are supported."
def build_config(self):
return {
"llm": {"display_name": "LLM"},
"tools": {"display_name": "Tools"},
"user_prompt": {
"display_name": "Prompt",
"multiline": True,
"info": "This prompt must contain 'input' key.",
},
"handle_parsing_errors": {
"display_name": "Handle Parsing Errors",
"info": "If True, the agent will handle parsing errors. If False, the agent will raise an error.",
"advanced": True,
},
"memory": {
"display_name": "Memory",
"info": "Memory to use for the agent.",
},
"input_value": {
"display_name": "Inputs",
"info": "Input text to pass to the agent.",
},
}
async def build(
self,
input_value: str,
llm: BaseLanguageModel,
tools: List[Tool],
user_prompt: str = "{input}",
message_history: Optional[List[Record]] = None,
system_message: str = "You are a helpful assistant",
handle_parsing_errors: bool = True,
) -> Text:
if "input" not in user_prompt:
raise ValueError("Prompt must contain 'input' key.")
messages = [
("system", system_message),
(
"placeholder",
"{chat_history}",
),
("human", user_prompt),
("placeholder", "{agent_scratchpad}"),
]
prompt = ChatPromptTemplate.from_messages(messages)
agent = create_tool_calling_agent(llm, tools, prompt)
result = await self.run_agent(agent, input_value, tools, message_history, handle_parsing_errors)
self.status = result
return result

View file

@ -1,10 +1,12 @@
from typing import List, Optional
from langchain.agents import create_xml_agent
from langchain_core.prompts import PromptTemplate
from langchain_core.prompts import ChatPromptTemplate
from langflow.base.agents.agent import LCAgentComponent
from langflow.field_typing import BaseLanguageModel, BaseMemory, Text, Tool
from langflow.field_typing import BaseLanguageModel, Text, Tool
from langflow.schema.schema import Record
class XMLAgentComponent(LCAgentComponent):
@ -15,7 +17,7 @@ class XMLAgentComponent(LCAgentComponent):
return {
"llm": {"display_name": "LLM"},
"tools": {"display_name": "Tools"},
"prompt": {
"user_prompt": {
"display_name": "Prompt",
"multiline": True,
"info": "This prompt must contain 'tools' and 'agent_scratchpad' keys.",
@ -43,6 +45,11 @@ class XMLAgentComponent(LCAgentComponent):
Question: {input}
{agent_scratchpad}""",
},
"system_message": {
"display_name": "System Message",
"info": "System message to be passed to the LLM.",
"advanced": True,
},
"tool_template": {
"display_name": "Tool Template",
"info": "Template for rendering tools in the prompt. Tools have 'name' and 'description' keys.",
@ -53,9 +60,9 @@ class XMLAgentComponent(LCAgentComponent):
"info": "If True, the agent will handle parsing errors. If False, the agent will raise an error.",
"advanced": True,
},
"memory": {
"display_name": "Memory",
"info": "Memory to use for the agent.",
"message_history": {
"display_name": "Message History",
"info": "Message history to pass to the agent.",
},
"input_value": {
"display_name": "Inputs",
@ -68,12 +75,13 @@ class XMLAgentComponent(LCAgentComponent):
input_value: str,
llm: BaseLanguageModel,
tools: List[Tool],
prompt: str,
memory: Optional[BaseMemory] = None,
user_prompt: str = "{input}",
system_message: str = "You are a helpful assistant",
message_history: Optional[List[Record]] = None,
tool_template: str = "{name}: {description}",
handle_parsing_errors: bool = True,
) -> Text:
if "input" not in prompt:
if "input" not in user_prompt:
raise ValueError("Prompt must contain 'input' key.")
def render_tool_description(tools):
@ -81,9 +89,23 @@ class XMLAgentComponent(LCAgentComponent):
[tool_template.format(name=tool.name, description=tool.description, args=tool.args) for tool in tools]
)
prompt_template = PromptTemplate.from_template(prompt)
input_variables = prompt_template.input_variables
agent = create_xml_agent(llm, tools, prompt_template, tools_renderer=render_tool_description)
result = await self.run_agent(agent, input_value, input_variables, tools, memory, handle_parsing_errors)
messages = [
("system", system_message),
(
"placeholder",
"{chat_history}",
),
("human", user_prompt),
("placeholder", "{agent_scratchpad}"),
]
prompt = ChatPromptTemplate.from_messages(messages)
agent = create_xml_agent(llm, tools, prompt, tools_renderer=render_tool_description)
result = await self.run_agent(
agent=agent,
inputs=input_value,
tools=tools,
message_history=message_history,
handle_parsing_errors=handle_parsing_errors,
)
self.status = result
return result

View file

@ -0,0 +1,185 @@
from typing import Any, List, Optional, cast
from langchain_core.prompts import ChatPromptTemplate
from langchain_core.prompts.chat import HumanMessagePromptTemplate, SystemMessagePromptTemplate
from langflow.base.agents.agent import LCAgentComponent
from langflow.base.agents.utils import AGENTS, AgentSpec, get_agents_list
from langflow.field_typing import BaseLanguageModel, Text, Tool
from langflow.schema.dotdict import dotdict
from langflow.schema.schema import Record
class AgentComponent(LCAgentComponent):
display_name = "Agent"
description = "Run any LangChain agent using a simplified interface."
field_order = [
"agent_name",
"llm",
"tools",
"prompt",
"tool_template",
"handle_parsing_errors",
"memory",
"input_value",
]
def build_config(self):
return {
"agent_name": {
"display_name": "Agent",
"info": "The agent to use.",
"refresh_button": True,
"real_time_refresh": True,
"options": get_agents_list(),
},
"llm": {"display_name": "LLM"},
"tools": {"display_name": "Tools"},
"user_prompt": {
"display_name": "Prompt",
"multiline": True,
"info": "This prompt must contain 'tools' and 'agent_scratchpad' keys.",
},
"system_message": {
"display_name": "System Message",
"info": "System message to be passed to the LLM.",
"advanced": True,
},
"tool_template": {
"display_name": "Tool Template",
"info": "Template for rendering tools in the prompt. Tools have 'name' and 'description' keys.",
"advanced": True,
},
"handle_parsing_errors": {
"display_name": "Handle Parsing Errors",
"info": "If True, the agent will handle parsing errors. If False, the agent will raise an error.",
"advanced": True,
},
"message_history": {
"display_name": "Message History",
"info": "Message history to pass to the agent.",
},
"input_value": {
"display_name": "Input",
"info": "Input text to pass to the agent.",
},
"langchain_hub_api_key": {
"display_name": "LangChain Hub API Key",
"info": "API key to use for LangChain Hub. If provided, prompts will be fetched from LangChain Hub.",
"advanced": True,
},
}
def get_system_and_user_message_from_prompt(self, prompt: Any):
"""
Extracts the system message and user prompt from a given prompt object.
Args:
prompt (Any): The prompt object from which to extract the system message and user prompt.
Returns:
Tuple[Optional[str], Optional[str]]: A tuple containing the system message and user prompt.
If the prompt object does not have any messages, both values will be None.
"""
if hasattr(prompt, "messages"):
system_message = None
user_prompt = None
for message in prompt.messages:
if isinstance(message, SystemMessagePromptTemplate):
s_prompt = message.prompt
if isinstance(s_prompt, list):
s_template = " ".join([cast(str, s.template) for s in s_prompt if hasattr(s, "template")])
elif hasattr(s_prompt, "template"):
s_template = s_prompt.template
system_message = s_template
elif isinstance(message, HumanMessagePromptTemplate):
h_prompt = message.prompt
if isinstance(h_prompt, list):
h_template = " ".join([cast(str, h.template) for h in h_prompt if hasattr(h, "template")])
elif hasattr(h_prompt, "template"):
h_template = h_prompt.template
user_prompt = h_template
return system_message, user_prompt
return None, None
def update_build_config(self, build_config: dotdict, field_value: Any, field_name: Text | None = None):
"""
Updates the build configuration based on the provided field value and field name.
Args:
build_config (dotdict): The build configuration to be updated.
field_value (Any): The value of the field being updated.
field_name (Text | None, optional): The name of the field being updated. Defaults to None.
Returns:
dotdict: The updated build configuration.
"""
if field_name == "agent":
build_config["agent"]["options"] = get_agents_list()
if field_value in AGENTS:
# if langchain_hub_api_key is provided, fetch the prompt from LangChain Hub
if build_config["langchain_hub_api_key"]["value"] and AGENTS[field_value].hub_repo:
from langchain import hub
hub_repo: str | None = AGENTS[field_value].hub_repo
if hub_repo:
hub_api_key: str = build_config["langchain_hub_api_key"]["value"]
prompt = hub.pull(hub_repo, api_key=hub_api_key)
system_message, user_prompt = self.get_system_and_user_message_from_prompt(prompt)
if system_message:
build_config["system_message"]["value"] = system_message
if user_prompt:
build_config["user_prompt"]["value"] = user_prompt
if AGENTS[field_value].prompt:
build_config["user_prompt"]["value"] = AGENTS[field_value].prompt
else:
build_config["user_prompt"]["value"] = "{input}"
fields = AGENTS[field_value].fields
for field in ["llm", "tools", "prompt", "tools_renderer"]:
if field not in fields:
build_config[field]["show"] = False
return build_config
async def build(
self,
agent_name: str,
input_value: str,
llm: BaseLanguageModel,
tools: List[Tool],
system_message: str = "You are a helpful assistant. Help the user answer any questions.",
user_prompt: str = "{input}",
message_history: Optional[List[Record]] = None,
tool_template: str = "{name}: {description}",
handle_parsing_errors: bool = True,
) -> Text:
agent_spec: Optional[AgentSpec] = AGENTS.get(agent_name)
if agent_spec is None:
raise ValueError(f"{agent_name} not found.")
def render_tool_description(tools):
return "\n".join(
[tool_template.format(name=tool.name, description=tool.description, args=tool.args) for tool in tools]
)
messages = [
("system", system_message),
(
"placeholder",
"{chat_history}",
),
("human", user_prompt),
("placeholder", "{agent_scratchpad}"),
]
prompt = ChatPromptTemplate.from_messages(messages)
agent_func = agent_spec.func
agent = agent_func(llm, tools, prompt, render_tool_description, True)
result = await self.run_agent(
agent=agent,
inputs=input_value,
tools=tools,
message_history=message_history,
handle_parsing_errors=handle_parsing_errors,
)
self.status = result
return result

View file

@ -10,8 +10,10 @@ from .RunFlow import RunFlowComponent
from .RunnableExecutor import RunnableExecComponent
from .SQLExecutor import SQLExecutorComponent
from .SubFlow import SubFlowComponent
from .AgentComponent import AgentComponent
__all__ = [
"AgentComponent",
"ClearMessageHistoryComponent",
"ExtractKeyFromRecordComponent",
"FlowToolComponent",

View file

@ -105,7 +105,7 @@ class AzureChatOpenAIComponent(LCModelComponent):
system_message: Optional[str] = None,
max_tokens: Optional[int] = 1000,
stream: bool = False,
) -> BaseLanguageModel:
) -> Text:
if api_key:
secret_api_key = SecretStr(api_key)
else:

View file

@ -142,7 +142,7 @@ class ChatLiteLLMModelComponent(LCModelComponent):
max_retries: int = 6,
verbose: bool = False,
system_message: Optional[str] = None,
) -> BaseLanguageModel:
) -> Text:
try:
import litellm # type: ignore

View file

@ -4,6 +4,7 @@ from typing import Literal, Optional
from langchain_core.documents import Document
from langchain_core.messages import BaseMessage
from pydantic import BaseModel, model_validator
from langchain_core.messages import HumanMessage, AIMessage, BaseMessage
class Record(BaseModel):
@ -101,6 +102,26 @@ class Record(BaseModel):
text = self.data.pop(self.text_key, self.default_value)
return Document(page_content=text, metadata=self.data)
def to_lc_message(self) -> BaseMessage:
"""
Converts the Record to a BaseMessage.
Returns:
BaseMessage: The converted BaseMessage.
"""
# The idea of this function is to be a helper to convert a Record to a BaseMessage
# It will use the "sender" key to determine if the message is Human or AI
# If the key is not present, it will default to AI
# But first we check if all required keys are present in the data dictionary
# they are: "text", "sender"
if not all(key in self.data for key in ["text", "sender"]):
raise ValueError(f"Missing required keys ('text', 'sender') in Record: {self.data}")
sender = self.data.get("sender", "Machine")
text = self.data.get("text", "")
if sender == "User":
return HumanMessage(content=text)
return AIMessage(content=text)
def __getattr__(self, key):
"""
Allows attribute-like access to the data dictionary.

View file

@ -1144,6 +1144,21 @@ langchain-core = ">=0.1.28,<0.2.0"
[package.extras]
extended-testing = ["lxml (>=5.1.0,<6.0.0)"]
[[package]]
name = "langchainhub"
version = "0.1.15"
description = "The LangChain Hub API client"
optional = false
python-versions = ">=3.8.1,<4.0"
files = [
{file = "langchainhub-0.1.15-py3-none-any.whl", hash = "sha256:89a0951abd1db255e91c6d545d092a598fc255aa865d1ffc3ce8f93bbeae60e7"},
{file = "langchainhub-0.1.15.tar.gz", hash = "sha256:fa3ff81a31946860f84c119f1e2f6b7c7707e2bd7ed2394a7313b286d59f3bda"},
]
[package.dependencies]
requests = ">=2,<3"
types-requests = ">=2.31.0.2,<3.0.0.0"
[[package]]
name = "langsmith"
version = "0.1.47"
@ -2568,6 +2583,20 @@ rich = ">=10.11.0"
shellingham = ">=1.3.0"
typing-extensions = ">=3.7.4.3"
[[package]]
name = "types-requests"
version = "2.31.0.20240406"
description = "Typing stubs for requests"
optional = false
python-versions = ">=3.8"
files = [
{file = "types-requests-2.31.0.20240406.tar.gz", hash = "sha256:4428df33c5503945c74b3f42e82b181e86ec7b724620419a2966e2de604ce1a1"},
{file = "types_requests-2.31.0.20240406-py3-none-any.whl", hash = "sha256:6216cdac377c6b9a040ac1c0404f7284bd13199c0e1bb235f4324627e8898cf5"},
]
[package.dependencies]
urllib3 = ">=2"
[[package]]
name = "typing-extensions"
version = "4.11.0"
@ -2861,4 +2890,4 @@ local = []
[metadata]
lock-version = "2.0"
python-versions = ">=3.10,<3.12"
content-hash = "11b7861d29ff2ca23defcb03faf670c409c639c7b4ee81455c0c1fea50ea54e9"
content-hash = "4baea3d7c34ad33205fbb00a9da20b4c696b9487a2d62dada9babf1d21ed2dba"

View file

@ -31,6 +31,7 @@ httpx = "*"
uvicorn = "^0.29.0"
gunicorn = "^21.2.0"
langchain = "~0.1.16"
langchainhub = "~0.1.15"
sqlmodel = "^0.0.16"
loguru = "^0.7.1"
rich = "^13.7.0"