Rename agent - initialize_agent to AgentInitializer

This commit is contained in:
gustavoschaedler 2023-06-13 18:51:33 +01:00
commit 14757ca402
5 changed files with 23 additions and 13 deletions

View file

@ -3,7 +3,7 @@ agents:
- ZeroShotAgent
- JsonAgent
- CSVAgent
- initialize_agent
- AgentInitializer
- VectorStoreAgent
- VectorStoreRouterAgent
- SQLAgent

View file

@ -10,7 +10,7 @@ CUSTOM_NODES = {
"agents": {
"JsonAgent": frontend_node.agents.JsonAgentNode(),
"CSVAgent": frontend_node.agents.CSVAgentNode(),
"initialize_agent": frontend_node.agents.InitializeAgentNode(),
"AgentInitializer": frontend_node.agents.InitializeAgentNode(),
"VectorStoreAgent": frontend_node.agents.VectorStoreAgentNode(),
"VectorStoreRouterAgent": frontend_node.agents.VectorStoreRouterAgentNode(),
"SQLAgent": frontend_node.agents.SQLAgentNode(),

View file

@ -82,7 +82,9 @@ class JsonAgent(CustomAgentExecutor):
llm=llm,
prompt=prompt,
)
agent = ZeroShotAgent(llm_chain=llm_chain, allowed_tools=tool_names) # type: ignore
agent = ZeroShotAgent(
llm_chain=llm_chain, allowed_tools=tool_names
) # type: ignore
return cls.from_agent_and_tools(agent=agent, tools=tools, verbose=True)
def run(self, *args, **kwargs):
@ -129,7 +131,9 @@ class CSVAgent(CustomAgentExecutor):
prompt=partial_prompt,
)
tool_names = {tool.name for tool in tools}
agent = ZeroShotAgent(llm_chain=llm_chain, allowed_tools=tool_names, **kwargs) # type: ignore
agent = ZeroShotAgent(
llm_chain=llm_chain, allowed_tools=tool_names, **kwargs
) # type: ignore
return cls.from_agent_and_tools(agent=agent, tools=tools, verbose=True)
@ -166,7 +170,9 @@ class VectorStoreAgent(CustomAgentExecutor):
prompt=prompt,
)
tool_names = {tool.name for tool in tools}
agent = ZeroShotAgent(llm_chain=llm_chain, allowed_tools=tool_names, **kwargs) # type: ignore
agent = ZeroShotAgent(
llm_chain=llm_chain, allowed_tools=tool_names, **kwargs
) # type: ignore
return AgentExecutor.from_agent_and_tools(
agent=agent, tools=tools, verbose=True
)
@ -234,7 +240,9 @@ class SQLAgent(CustomAgentExecutor):
prompt=prompt,
)
tool_names = {tool.name for tool in tools} # type: ignore
agent = ZeroShotAgent(llm_chain=llm_chain, allowed_tools=tool_names, **kwargs) # type: ignore
agent = ZeroShotAgent(
llm_chain=llm_chain, allowed_tools=tool_names, **kwargs
) # type: ignore
return AgentExecutor.from_agent_and_tools(
agent=agent,
tools=tools, # type: ignore
@ -277,7 +285,9 @@ class VectorStoreRouterAgent(CustomAgentExecutor):
prompt=prompt,
)
tool_names = {tool.name for tool in tools}
agent = ZeroShotAgent(llm_chain=llm_chain, allowed_tools=tool_names, **kwargs) # type: ignore
agent = ZeroShotAgent(
llm_chain=llm_chain, allowed_tools=tool_names, **kwargs
) # type: ignore
return AgentExecutor.from_agent_and_tools(
agent=agent, tools=tools, verbose=True
)
@ -287,11 +297,11 @@ class VectorStoreRouterAgent(CustomAgentExecutor):
class InitializeAgent(CustomAgentExecutor):
"""Implementation of initialize_agent function"""
"""Implementation of AgentInitializer function"""
@staticmethod
def function_name():
return "initialize_agent"
return "AgentInitializer"
@classmethod
def initialize(
@ -320,7 +330,7 @@ class InitializeAgent(CustomAgentExecutor):
CUSTOM_AGENTS = {
"JsonAgent": JsonAgent,
"CSVAgent": CSVAgent,
"initialize_agent": InitializeAgent,
"AgentInitializer": InitializeAgent,
"VectorStoreAgent": VectorStoreAgent,
"VectorStoreRouterAgent": VectorStoreRouterAgent,
"SQLAgent": SQLAgent,

View file

@ -154,9 +154,9 @@ class CSVAgentNode(FrontendNode):
class InitializeAgentNode(FrontendNode):
name: str = "initialize_agent"
name: str = "AgentInitializer"
template: Template = Template(
type_name="initailize_agent",
type_name="AgentInitializer",
fields=[
TemplateField(
field_type="str",

View file

@ -131,7 +131,7 @@ def test_initialize_agent(client: TestClient):
json_response = response.json()
agents = json_response["agents"]
initialize_agent = agents["initialize_agent"]
initialize_agent = agents["AgentInitializer"]
assert initialize_agent["base_classes"] == ["AgentExecutor", "function"]
template = initialize_agent["template"]