feat: working implementation of initialize_agent
This commit is contained in:
parent
24bdfaa941
commit
f032795fe4
7 changed files with 72 additions and 34 deletions
|
|
@ -8,7 +8,7 @@ agents:
|
|||
- ZeroShotAgent
|
||||
- JsonAgent
|
||||
- CSVAgent
|
||||
- InitializeAgent
|
||||
- initialize_agent
|
||||
|
||||
prompts:
|
||||
- PromptTemplate
|
||||
|
|
@ -34,6 +34,9 @@ toolkits:
|
|||
- OpenAPIToolkit
|
||||
- JsonToolkit
|
||||
|
||||
memories:
|
||||
- ConversationBufferMemory
|
||||
|
||||
embeddings: []
|
||||
|
||||
|
||||
|
|
|
|||
|
|
@ -6,7 +6,7 @@ CUSTOM_NODES = {
|
|||
"agents": {
|
||||
"JsonAgent": nodes.JsonAgentNode(),
|
||||
"CSVAgent": nodes.CSVAgentNode(),
|
||||
"InitializeAgent": nodes.InitializeAgentNode(),
|
||||
"initialize_agent": nodes.InitializeAgentNode(),
|
||||
},
|
||||
}
|
||||
|
||||
|
|
|
|||
|
|
@ -121,10 +121,10 @@ class Node:
|
|||
f"Required input {key} for module {self.node_type} not found"
|
||||
)
|
||||
elif value["list"]:
|
||||
if key in params:
|
||||
if key not in params:
|
||||
params[key] = []
|
||||
if edge is not None:
|
||||
params[key].append(edge.source)
|
||||
else:
|
||||
params[key] = [edge.source]
|
||||
elif value["required"] or edge is not None:
|
||||
params[key] = edge.source
|
||||
elif value["required"] or value.get("value"):
|
||||
|
|
|
|||
|
|
@ -1,4 +1,5 @@
|
|||
from typing import Dict, List
|
||||
import contextlib
|
||||
from typing import Dict, Iterable
|
||||
|
||||
from langchain.agents import loading
|
||||
|
||||
|
|
@ -31,12 +32,16 @@ class AgentCreator(LangChainTypeCreator):
|
|||
except ValueError as exc:
|
||||
raise ValueError("Agent not found") from exc
|
||||
|
||||
def to_list(self) -> List[str]:
|
||||
return [
|
||||
agent.__name__
|
||||
for agent in self.type_to_loader_dict.values()
|
||||
if agent.__name__ in settings.agents or settings.dev
|
||||
]
|
||||
# Now this is a generator
|
||||
def to_list(self) -> Iterable:
|
||||
for name, agent in self.type_to_loader_dict.items():
|
||||
agent_name = (
|
||||
agent.function_name()
|
||||
if hasattr(agent, "function_name")
|
||||
else agent.__name__
|
||||
)
|
||||
if agent_name in settings.agents or settings.dev:
|
||||
yield agent_name
|
||||
|
||||
|
||||
agent_creator = AgentCreator()
|
||||
|
|
|
|||
|
|
@ -11,11 +11,16 @@ from langchain.schema import BaseLanguageModel
|
|||
from langchain.llms.base import BaseLLM
|
||||
from langchain.tools.python.tool import PythonAstREPLTool
|
||||
from langchain.agents import initialize_agent, Tool
|
||||
from langchain.memory.chat_memory import BaseChatMemory
|
||||
|
||||
|
||||
class JsonAgent(AgentExecutor):
|
||||
"""Json agent"""
|
||||
|
||||
@staticmethod
|
||||
def function_name():
|
||||
return "JsonAgent"
|
||||
|
||||
@classmethod
|
||||
def initialize(cls, *args, **kwargs):
|
||||
return cls.from_toolkit_and_llm(*args, **kwargs)
|
||||
|
|
@ -48,6 +53,10 @@ class JsonAgent(AgentExecutor):
|
|||
class CSVAgent(AgentExecutor):
|
||||
"""CSV agent"""
|
||||
|
||||
@staticmethod
|
||||
def function_name():
|
||||
return "CSVAgent"
|
||||
|
||||
@classmethod
|
||||
def initialize(cls, *args, **kwargs):
|
||||
return cls.from_toolkit_and_llm(*args, **kwargs)
|
||||
|
|
@ -90,15 +99,17 @@ class CSVAgent(AgentExecutor):
|
|||
|
||||
|
||||
class InitializeAgent(AgentExecutor):
|
||||
"""Initialize agent"""
|
||||
"""Implementation of initialize_agent function"""
|
||||
|
||||
@staticmethod
|
||||
def function_name():
|
||||
return "initialize_agent"
|
||||
|
||||
@classmethod
|
||||
def initialize(cls, llm: BaseLLM, tools: List[Tool], agent: str):
|
||||
return initialize_agent(
|
||||
tools=tools,
|
||||
llm=llm,
|
||||
agent=agent,
|
||||
)
|
||||
def initialize(
|
||||
cls, llm: BaseLLM, tools: List[Tool], agent: str, memory: BaseChatMemory
|
||||
):
|
||||
return initialize_agent(tools=tools, llm=llm, agent=agent, memory=memory)
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
super().__init__(*args, **kwargs)
|
||||
|
|
@ -110,5 +121,5 @@ class InitializeAgent(AgentExecutor):
|
|||
CUSTOM_AGENTS = {
|
||||
"JsonAgent": JsonAgent,
|
||||
"CSVAgent": CSVAgent,
|
||||
"InitializeAgent": InitializeAgent,
|
||||
"initialize_agent": InitializeAgent,
|
||||
}
|
||||
|
|
|
|||
|
|
@ -57,7 +57,14 @@ def get_result_and_thought_using_graph(loaded_langchain, message: str):
|
|||
loaded_langchain.verbose = True
|
||||
try:
|
||||
with io.StringIO() as output_buffer, contextlib.redirect_stdout(output_buffer):
|
||||
result = loaded_langchain(message)
|
||||
chat_input = {}
|
||||
for key in loaded_langchain.input_keys:
|
||||
if key != "chat_history":
|
||||
chat_input[key] = message
|
||||
break
|
||||
if hasattr(loaded_langchain, "run"):
|
||||
loaded_langchain = loaded_langchain.run
|
||||
result = loaded_langchain
|
||||
|
||||
result = (
|
||||
result.get(loaded_langchain.output_keys[0])
|
||||
|
|
|
|||
|
|
@ -153,22 +153,10 @@ class JsonAgentNode(FrontendNode):
|
|||
|
||||
|
||||
class InitializeAgentNode(FrontendNode):
|
||||
name: str = "InitializeAgent"
|
||||
name: str = "initialize_agent"
|
||||
template: Template = Template(
|
||||
type_name="initailize_agent",
|
||||
fields=[
|
||||
TemplateField(
|
||||
field_type="Tool",
|
||||
required=True,
|
||||
show=True,
|
||||
name="tools",
|
||||
),
|
||||
TemplateField(
|
||||
field_type="BaseLanguageModel",
|
||||
required=True,
|
||||
show=True,
|
||||
name="llm",
|
||||
),
|
||||
TemplateField(
|
||||
field_type="str",
|
||||
required=True,
|
||||
|
|
@ -178,6 +166,25 @@ class InitializeAgentNode(FrontendNode):
|
|||
options=list(loading.AGENT_TO_CLASS.keys()),
|
||||
name="agent",
|
||||
),
|
||||
TemplateField(
|
||||
field_type="BaseChatMemory",
|
||||
required=False,
|
||||
show=True,
|
||||
name="memory",
|
||||
),
|
||||
TemplateField(
|
||||
field_type="Tool",
|
||||
required=False,
|
||||
show=True,
|
||||
name="tools",
|
||||
is_list=True,
|
||||
),
|
||||
TemplateField(
|
||||
field_type="BaseLanguageModel",
|
||||
required=True,
|
||||
show=True,
|
||||
name="llm",
|
||||
),
|
||||
],
|
||||
)
|
||||
description: str = """Construct a json agent from an LLM and tools."""
|
||||
|
|
@ -186,6 +193,11 @@ class InitializeAgentNode(FrontendNode):
|
|||
def to_dict(self):
|
||||
return super().to_dict()
|
||||
|
||||
@staticmethod
|
||||
def format_field(field: TemplateField, name: str):
|
||||
# do nothing and don't return anything
|
||||
pass
|
||||
|
||||
|
||||
class CSVAgentNode(FrontendNode):
|
||||
name: str = "CSVAgent"
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue