feat: working implementation of initialize_agent

This commit is contained in:
Gabriel Almeida 2023-04-02 09:33:07 -03:00
commit f032795fe4
7 changed files with 72 additions and 34 deletions

View file

@ -8,7 +8,7 @@ agents:
- ZeroShotAgent
- JsonAgent
- CSVAgent
- InitializeAgent
- initialize_agent
prompts:
- PromptTemplate
@ -34,6 +34,9 @@ toolkits:
- OpenAPIToolkit
- JsonToolkit
memories:
- ConversationBufferMemory
embeddings: []

View file

@ -6,7 +6,7 @@ CUSTOM_NODES = {
"agents": {
"JsonAgent": nodes.JsonAgentNode(),
"CSVAgent": nodes.CSVAgentNode(),
"InitializeAgent": nodes.InitializeAgentNode(),
"initialize_agent": nodes.InitializeAgentNode(),
},
}

View file

@ -121,10 +121,10 @@ class Node:
f"Required input {key} for module {self.node_type} not found"
)
elif value["list"]:
if key in params:
if key not in params:
params[key] = []
if edge is not None:
params[key].append(edge.source)
else:
params[key] = [edge.source]
elif value["required"] or edge is not None:
params[key] = edge.source
elif value["required"] or value.get("value"):

View file

@ -1,4 +1,5 @@
from typing import Dict, List
import contextlib
from typing import Dict, Iterable
from langchain.agents import loading
@ -31,12 +32,16 @@ class AgentCreator(LangChainTypeCreator):
except ValueError as exc:
raise ValueError("Agent not found") from exc
def to_list(self) -> List[str]:
return [
agent.__name__
for agent in self.type_to_loader_dict.values()
if agent.__name__ in settings.agents or settings.dev
]
# Now this is a generator
def to_list(self) -> Iterable:
for name, agent in self.type_to_loader_dict.items():
agent_name = (
agent.function_name()
if hasattr(agent, "function_name")
else agent.__name__
)
if agent_name in settings.agents or settings.dev:
yield agent_name
agent_creator = AgentCreator()

View file

@ -11,11 +11,16 @@ from langchain.schema import BaseLanguageModel
from langchain.llms.base import BaseLLM
from langchain.tools.python.tool import PythonAstREPLTool
from langchain.agents import initialize_agent, Tool
from langchain.memory.chat_memory import BaseChatMemory
class JsonAgent(AgentExecutor):
"""Json agent"""
@staticmethod
def function_name():
return "JsonAgent"
@classmethod
def initialize(cls, *args, **kwargs):
return cls.from_toolkit_and_llm(*args, **kwargs)
@ -48,6 +53,10 @@ class JsonAgent(AgentExecutor):
class CSVAgent(AgentExecutor):
"""CSV agent"""
@staticmethod
def function_name():
return "CSVAgent"
@classmethod
def initialize(cls, *args, **kwargs):
return cls.from_toolkit_and_llm(*args, **kwargs)
@ -90,15 +99,17 @@ class CSVAgent(AgentExecutor):
class InitializeAgent(AgentExecutor):
"""Initialize agent"""
"""Implementation of initialize_agent function"""
@staticmethod
def function_name():
return "initialize_agent"
@classmethod
def initialize(cls, llm: BaseLLM, tools: List[Tool], agent: str):
return initialize_agent(
tools=tools,
llm=llm,
agent=agent,
)
def initialize(
cls, llm: BaseLLM, tools: List[Tool], agent: str, memory: BaseChatMemory
):
return initialize_agent(tools=tools, llm=llm, agent=agent, memory=memory)
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
@ -110,5 +121,5 @@ class InitializeAgent(AgentExecutor):
CUSTOM_AGENTS = {
"JsonAgent": JsonAgent,
"CSVAgent": CSVAgent,
"InitializeAgent": InitializeAgent,
"initialize_agent": InitializeAgent,
}

View file

@ -57,7 +57,14 @@ def get_result_and_thought_using_graph(loaded_langchain, message: str):
loaded_langchain.verbose = True
try:
with io.StringIO() as output_buffer, contextlib.redirect_stdout(output_buffer):
result = loaded_langchain(message)
chat_input = {}
for key in loaded_langchain.input_keys:
if key != "chat_history":
chat_input[key] = message
break
if hasattr(loaded_langchain, "run"):
loaded_langchain = loaded_langchain.run
result = loaded_langchain
result = (
result.get(loaded_langchain.output_keys[0])

View file

@ -153,22 +153,10 @@ class JsonAgentNode(FrontendNode):
class InitializeAgentNode(FrontendNode):
name: str = "InitializeAgent"
name: str = "initialize_agent"
template: Template = Template(
type_name="initailize_agent",
fields=[
TemplateField(
field_type="Tool",
required=True,
show=True,
name="tools",
),
TemplateField(
field_type="BaseLanguageModel",
required=True,
show=True,
name="llm",
),
TemplateField(
field_type="str",
required=True,
@ -178,6 +166,25 @@ class InitializeAgentNode(FrontendNode):
options=list(loading.AGENT_TO_CLASS.keys()),
name="agent",
),
TemplateField(
field_type="BaseChatMemory",
required=False,
show=True,
name="memory",
),
TemplateField(
field_type="Tool",
required=False,
show=True,
name="tools",
is_list=True,
),
TemplateField(
field_type="BaseLanguageModel",
required=True,
show=True,
name="llm",
),
],
)
description: str = """Construct a json agent from an LLM and tools."""
@ -186,6 +193,11 @@ class InitializeAgentNode(FrontendNode):
def to_dict(self):
return super().to_dict()
@staticmethod
def format_field(field: TemplateField, name: str):
# do nothing and don't return anything
pass
class CSVAgentNode(FrontendNode):
name: str = "CSVAgent"