feat: jsonagent
This commit is contained in:
parent
0d10c7ba05
commit
22c3c83d6a
7 changed files with 80 additions and 4 deletions
|
|
@ -4,6 +4,7 @@ from langflow.template import nodes
|
|||
CUSTOM_NODES = {
|
||||
"prompts": {**nodes.ZeroShotPromptNode().to_dict()},
|
||||
"tools": {**nodes.PythonFunctionNode().to_dict(), **nodes.ToolNode().to_dict()},
|
||||
"agents": {**nodes.JsonAgentNode().to_dict()},
|
||||
}
|
||||
|
||||
|
||||
|
|
|
|||
|
|
@ -59,6 +59,7 @@ toolkit_type_to_cls_dict: dict[str, Any] = {
|
|||
if not toolkit_name.islower()
|
||||
}
|
||||
|
||||
|
||||
wrapper_type_to_cls_dict: dict[str, Any] = {
|
||||
wrapper.__name__: wrapper for wrapper in [requests.RequestsWrapper]
|
||||
}
|
||||
|
|
|
|||
|
|
@ -1,6 +1,12 @@
|
|||
from typing import Callable, Optional
|
||||
from langchain import LLMChain, PromptTemplate
|
||||
from langchain.agents import AgentExecutor, ZeroShotAgent
|
||||
from langflow.utils import validate
|
||||
from pydantic import BaseModel, validator
|
||||
from langchain.agents.agent_toolkits.json.prompt import JSON_PREFIX, JSON_SUFFIX
|
||||
from langchain.agents.mrkl.prompt import FORMAT_INSTRUCTIONS
|
||||
from langchain.agents.agent_toolkits.json.toolkit import JsonToolkit
|
||||
from langchain.schema import BaseLanguageModel
|
||||
|
||||
|
||||
class Function(BaseModel):
|
||||
|
|
@ -33,3 +39,31 @@ class PythonFunction(Function):
|
|||
"""Python function"""
|
||||
|
||||
code: str
|
||||
|
||||
|
||||
class JsonAgent(BaseModel):
|
||||
"""Json agent"""
|
||||
|
||||
toolkit: JsonToolkit
|
||||
llm: BaseLanguageModel
|
||||
|
||||
def __init__(self, toolkit: JsonToolkit, llm: BaseLanguageModel):
|
||||
super().__init__(toolkit=toolkit, llm=llm)
|
||||
self.toolkit = toolkit
|
||||
tools = self.toolkit.get_tools()
|
||||
tool_names = [tool.name for tool in tools]
|
||||
prompt = ZeroShotAgent.create_prompt(
|
||||
tools,
|
||||
prefix=JSON_PREFIX,
|
||||
suffix=JSON_SUFFIX,
|
||||
format_instructions=FORMAT_INSTRUCTIONS,
|
||||
input_variables=None,
|
||||
)
|
||||
llm_chain = LLMChain(
|
||||
llm=llm,
|
||||
prompt=prompt,
|
||||
)
|
||||
agent = ZeroShotAgent(llm_chain=llm_chain, allowed_tools=tool_names)
|
||||
return AgentExecutor.from_agent_and_tools(
|
||||
agent=agent, tools=tools, verbose=True
|
||||
)
|
||||
|
|
|
|||
|
|
@ -10,7 +10,7 @@ from langflow.settings import settings
|
|||
from langflow.utils import util
|
||||
from langchain.agents.load_tools import get_all_tool_names
|
||||
from langchain.agents import Tool
|
||||
from langflow.interface.custom_types import PythonFunction
|
||||
from langflow.interface.custom_types import JsonAgent, PythonFunction
|
||||
from langchain.tools.json.tool import JsonSpec
|
||||
|
||||
OTHER_TOOLS = {"JsonSpec": JsonSpec}
|
||||
|
|
@ -50,7 +50,7 @@ def list_agents():
|
|||
agent.__name__
|
||||
for agent in agents.loading.AGENT_TO_CLASS.values()
|
||||
if agent.__name__ in settings.agents or settings.dev
|
||||
]
|
||||
] + [JsonAgent.__name__]
|
||||
|
||||
|
||||
def list_toolkis():
|
||||
|
|
|
|||
|
|
@ -40,7 +40,7 @@ def get_toolkit_signature(name: str):
|
|||
"""Get the signature of a toolkit."""
|
||||
try:
|
||||
if name.islower():
|
||||
pass
|
||||
...
|
||||
# return util.build_template_from_function(
|
||||
# name, toolkit_type_to_loader_dict, add_function=True
|
||||
# )
|
||||
|
|
@ -77,6 +77,8 @@ def get_chain_signature(name: str):
|
|||
def get_agent_signature(name: str):
|
||||
"""Get the signature of an agent."""
|
||||
try:
|
||||
if name in customs.get_custom_nodes("agents").keys():
|
||||
return customs.get_custom_nodes("agents")[name]
|
||||
return util.build_template_from_class(
|
||||
name, agents.loading.AGENT_TO_CLASS, add_function=True
|
||||
)
|
||||
|
|
|
|||
|
|
@ -112,3 +112,35 @@ class ToolNode(FrontendNode):
|
|||
|
||||
def to_dict(self):
|
||||
return super().to_dict()
|
||||
|
||||
|
||||
class JsonAgentNode(FrontendNode):
|
||||
name: str = "JsonAgent"
|
||||
template: Template = Template(
|
||||
type_name="json_agent",
|
||||
fields=[
|
||||
Field(
|
||||
field_type="BaseToolkit",
|
||||
required=True,
|
||||
placeholder="",
|
||||
is_list=False,
|
||||
show=True,
|
||||
value="",
|
||||
name="toolkit",
|
||||
),
|
||||
Field(
|
||||
field_type="BaseLanguageModel",
|
||||
required=True,
|
||||
placeholder="",
|
||||
is_list=False,
|
||||
show=True,
|
||||
value="",
|
||||
name="LLM",
|
||||
),
|
||||
],
|
||||
)
|
||||
description: str = """Construct a json agent from an LLM and tools."""
|
||||
base_classes: list[str] = ["BaseAgent"]
|
||||
|
||||
def to_dict(self):
|
||||
return super().to_dict()
|
||||
|
|
|
|||
|
|
@ -406,6 +406,7 @@ def format_dict(d, name: Optional[str] = None):
|
|||
"examples",
|
||||
"temperature",
|
||||
"model_name",
|
||||
"headers",
|
||||
]
|
||||
or "api_key" in key
|
||||
)
|
||||
|
|
@ -427,7 +428,7 @@ def format_dict(d, name: Optional[str] = None):
|
|||
|
||||
# Replace dict type with str
|
||||
if "dict" in value["type"].lower():
|
||||
value["type"] = "str"
|
||||
value["type"] = "code"
|
||||
|
||||
value["file"] = key in ["dict_"]
|
||||
|
||||
|
|
@ -436,6 +437,11 @@ def format_dict(d, name: Optional[str] = None):
|
|||
value["value"] = value["default"]
|
||||
value.pop("default")
|
||||
|
||||
if key == "headers":
|
||||
value[
|
||||
"value"
|
||||
] = """{'Authorization':
|
||||
'Bearer <token>'}"""
|
||||
# Add options to openai
|
||||
if name == "OpenAI" and key == "model_name":
|
||||
value["options"] = constants.OPENAI_MODELS
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue