From 22c3c83d6a98187e8b2c598b6da697d0771c9ec2 Mon Sep 17 00:00:00 2001 From: Gabriel Almeida Date: Thu, 30 Mar 2023 08:48:41 -0300 Subject: [PATCH] feat: jsonagent --- src/backend/langflow/custom/customs.py | 1 + .../langflow/interface/custom_lists.py | 1 + .../langflow/interface/custom_types.py | 34 +++++++++++++++++++ src/backend/langflow/interface/listing.py | 4 +-- src/backend/langflow/interface/signature.py | 4 ++- src/backend/langflow/template/nodes.py | 32 +++++++++++++++++ src/backend/langflow/utils/util.py | 8 ++++- 7 files changed, 80 insertions(+), 4 deletions(-) diff --git a/src/backend/langflow/custom/customs.py b/src/backend/langflow/custom/customs.py index 877a06387..c856996a3 100644 --- a/src/backend/langflow/custom/customs.py +++ b/src/backend/langflow/custom/customs.py @@ -4,6 +4,7 @@ from langflow.template import nodes CUSTOM_NODES = { "prompts": {**nodes.ZeroShotPromptNode().to_dict()}, "tools": {**nodes.PythonFunctionNode().to_dict(), **nodes.ToolNode().to_dict()}, + "agents": {**nodes.JsonAgentNode().to_dict()}, } diff --git a/src/backend/langflow/interface/custom_lists.py b/src/backend/langflow/interface/custom_lists.py index 2cf57cc39..9875a8b9d 100644 --- a/src/backend/langflow/interface/custom_lists.py +++ b/src/backend/langflow/interface/custom_lists.py @@ -59,6 +59,7 @@ toolkit_type_to_cls_dict: dict[str, Any] = { if not toolkit_name.islower() } + wrapper_type_to_cls_dict: dict[str, Any] = { wrapper.__name__: wrapper for wrapper in [requests.RequestsWrapper] } diff --git a/src/backend/langflow/interface/custom_types.py b/src/backend/langflow/interface/custom_types.py index 05d77fd1d..7943f99f0 100644 --- a/src/backend/langflow/interface/custom_types.py +++ b/src/backend/langflow/interface/custom_types.py @@ -1,6 +1,12 @@ from typing import Callable, Optional +from langchain import LLMChain, PromptTemplate +from langchain.agents import AgentExecutor, ZeroShotAgent from langflow.utils import validate from pydantic import BaseModel, validator +from langchain.agents.agent_toolkits.json.prompt import JSON_PREFIX, JSON_SUFFIX +from langchain.agents.mrkl.prompt import FORMAT_INSTRUCTIONS +from langchain.agents.agent_toolkits.json.toolkit import JsonToolkit +from langchain.schema import BaseLanguageModel class Function(BaseModel): @@ -33,3 +39,31 @@ class PythonFunction(Function): """Python function""" code: str + + +class JsonAgent(BaseModel): + """Json agent""" + + toolkit: JsonToolkit + llm: BaseLanguageModel + + def __init__(self, toolkit: JsonToolkit, llm: BaseLanguageModel): + super().__init__(toolkit=toolkit, llm=llm) + self.toolkit = toolkit + tools = self.toolkit.get_tools() + tool_names = [tool.name for tool in tools] + prompt = ZeroShotAgent.create_prompt( + tools, + prefix=JSON_PREFIX, + suffix=JSON_SUFFIX, + format_instructions=FORMAT_INSTRUCTIONS, + input_variables=None, + ) + llm_chain = LLMChain( + llm=llm, + prompt=prompt, + ) + agent = ZeroShotAgent(llm_chain=llm_chain, allowed_tools=tool_names) + return AgentExecutor.from_agent_and_tools( + agent=agent, tools=tools, verbose=True + ) diff --git a/src/backend/langflow/interface/listing.py b/src/backend/langflow/interface/listing.py index fcd44c6b0..c67b3ab83 100644 --- a/src/backend/langflow/interface/listing.py +++ b/src/backend/langflow/interface/listing.py @@ -10,7 +10,7 @@ from langflow.settings import settings from langflow.utils import util from langchain.agents.load_tools import get_all_tool_names from langchain.agents import Tool -from langflow.interface.custom_types import PythonFunction +from langflow.interface.custom_types import JsonAgent, PythonFunction from langchain.tools.json.tool import JsonSpec OTHER_TOOLS = {"JsonSpec": JsonSpec} @@ -50,7 +50,7 @@ def list_agents(): agent.__name__ for agent in agents.loading.AGENT_TO_CLASS.values() if agent.__name__ in settings.agents or settings.dev - ] + ] + [JsonAgent.__name__] def list_toolkis(): diff --git a/src/backend/langflow/interface/signature.py b/src/backend/langflow/interface/signature.py index 1426af4e1..80a544d31 100644 --- a/src/backend/langflow/interface/signature.py +++ b/src/backend/langflow/interface/signature.py @@ -40,7 +40,7 @@ def get_toolkit_signature(name: str): """Get the signature of a toolkit.""" try: if name.islower(): - pass + ... # return util.build_template_from_function( # name, toolkit_type_to_loader_dict, add_function=True # ) @@ -77,6 +77,8 @@ def get_chain_signature(name: str): def get_agent_signature(name: str): """Get the signature of an agent.""" try: + if name in customs.get_custom_nodes("agents").keys(): + return customs.get_custom_nodes("agents")[name] return util.build_template_from_class( name, agents.loading.AGENT_TO_CLASS, add_function=True ) diff --git a/src/backend/langflow/template/nodes.py b/src/backend/langflow/template/nodes.py index 642d0c237..f9826d89f 100644 --- a/src/backend/langflow/template/nodes.py +++ b/src/backend/langflow/template/nodes.py @@ -112,3 +112,35 @@ class ToolNode(FrontendNode): def to_dict(self): return super().to_dict() + + +class JsonAgentNode(FrontendNode): + name: str = "JsonAgent" + template: Template = Template( + type_name="json_agent", + fields=[ + Field( + field_type="BaseToolkit", + required=True, + placeholder="", + is_list=False, + show=True, + value="", + name="toolkit", + ), + Field( + field_type="BaseLanguageModel", + required=True, + placeholder="", + is_list=False, + show=True, + value="", + name="LLM", + ), + ], + ) + description: str = """Construct a json agent from an LLM and tools.""" + base_classes: list[str] = ["BaseAgent"] + + def to_dict(self): + return super().to_dict() diff --git a/src/backend/langflow/utils/util.py b/src/backend/langflow/utils/util.py index 652291992..c7dfc230e 100644 --- a/src/backend/langflow/utils/util.py +++ b/src/backend/langflow/utils/util.py @@ -406,6 +406,7 @@ def format_dict(d, name: Optional[str] = None): "examples", "temperature", "model_name", + "headers", ] or "api_key" in key ) @@ -427,7 +428,7 @@ def format_dict(d, name: Optional[str] = None): # Replace dict type with str if "dict" in value["type"].lower(): - value["type"] = "str" + value["type"] = "code" value["file"] = key in ["dict_"] @@ -436,6 +437,11 @@ def format_dict(d, name: Optional[str] = None): value["value"] = value["default"] value.pop("default") + if key == "headers": + value[ + "value" + ] = """{'Authorization': + 'Bearer '}""" # Add options to openai if name == "OpenAI" and key == "model_name": value["options"] = constants.OPENAI_MODELS