Rename agent - initialize_agent to AgentInitializer
This commit is contained in:
parent
09d4c89136
commit
14757ca402
5 changed files with 23 additions and 13 deletions
|
|
@ -3,7 +3,7 @@ agents:
|
|||
- ZeroShotAgent
|
||||
- JsonAgent
|
||||
- CSVAgent
|
||||
- initialize_agent
|
||||
- AgentInitializer
|
||||
- VectorStoreAgent
|
||||
- VectorStoreRouterAgent
|
||||
- SQLAgent
|
||||
|
|
|
|||
|
|
@ -10,7 +10,7 @@ CUSTOM_NODES = {
|
|||
"agents": {
|
||||
"JsonAgent": frontend_node.agents.JsonAgentNode(),
|
||||
"CSVAgent": frontend_node.agents.CSVAgentNode(),
|
||||
"initialize_agent": frontend_node.agents.InitializeAgentNode(),
|
||||
"AgentInitializer": frontend_node.agents.InitializeAgentNode(),
|
||||
"VectorStoreAgent": frontend_node.agents.VectorStoreAgentNode(),
|
||||
"VectorStoreRouterAgent": frontend_node.agents.VectorStoreRouterAgentNode(),
|
||||
"SQLAgent": frontend_node.agents.SQLAgentNode(),
|
||||
|
|
|
|||
|
|
@ -82,7 +82,9 @@ class JsonAgent(CustomAgentExecutor):
|
|||
llm=llm,
|
||||
prompt=prompt,
|
||||
)
|
||||
agent = ZeroShotAgent(llm_chain=llm_chain, allowed_tools=tool_names) # type: ignore
|
||||
agent = ZeroShotAgent(
|
||||
llm_chain=llm_chain, allowed_tools=tool_names
|
||||
) # type: ignore
|
||||
return cls.from_agent_and_tools(agent=agent, tools=tools, verbose=True)
|
||||
|
||||
def run(self, *args, **kwargs):
|
||||
|
|
@ -129,7 +131,9 @@ class CSVAgent(CustomAgentExecutor):
|
|||
prompt=partial_prompt,
|
||||
)
|
||||
tool_names = {tool.name for tool in tools}
|
||||
agent = ZeroShotAgent(llm_chain=llm_chain, allowed_tools=tool_names, **kwargs) # type: ignore
|
||||
agent = ZeroShotAgent(
|
||||
llm_chain=llm_chain, allowed_tools=tool_names, **kwargs
|
||||
) # type: ignore
|
||||
|
||||
return cls.from_agent_and_tools(agent=agent, tools=tools, verbose=True)
|
||||
|
||||
|
|
@ -166,7 +170,9 @@ class VectorStoreAgent(CustomAgentExecutor):
|
|||
prompt=prompt,
|
||||
)
|
||||
tool_names = {tool.name for tool in tools}
|
||||
agent = ZeroShotAgent(llm_chain=llm_chain, allowed_tools=tool_names, **kwargs) # type: ignore
|
||||
agent = ZeroShotAgent(
|
||||
llm_chain=llm_chain, allowed_tools=tool_names, **kwargs
|
||||
) # type: ignore
|
||||
return AgentExecutor.from_agent_and_tools(
|
||||
agent=agent, tools=tools, verbose=True
|
||||
)
|
||||
|
|
@ -234,7 +240,9 @@ class SQLAgent(CustomAgentExecutor):
|
|||
prompt=prompt,
|
||||
)
|
||||
tool_names = {tool.name for tool in tools} # type: ignore
|
||||
agent = ZeroShotAgent(llm_chain=llm_chain, allowed_tools=tool_names, **kwargs) # type: ignore
|
||||
agent = ZeroShotAgent(
|
||||
llm_chain=llm_chain, allowed_tools=tool_names, **kwargs
|
||||
) # type: ignore
|
||||
return AgentExecutor.from_agent_and_tools(
|
||||
agent=agent,
|
||||
tools=tools, # type: ignore
|
||||
|
|
@ -277,7 +285,9 @@ class VectorStoreRouterAgent(CustomAgentExecutor):
|
|||
prompt=prompt,
|
||||
)
|
||||
tool_names = {tool.name for tool in tools}
|
||||
agent = ZeroShotAgent(llm_chain=llm_chain, allowed_tools=tool_names, **kwargs) # type: ignore
|
||||
agent = ZeroShotAgent(
|
||||
llm_chain=llm_chain, allowed_tools=tool_names, **kwargs
|
||||
) # type: ignore
|
||||
return AgentExecutor.from_agent_and_tools(
|
||||
agent=agent, tools=tools, verbose=True
|
||||
)
|
||||
|
|
@ -287,11 +297,11 @@ class VectorStoreRouterAgent(CustomAgentExecutor):
|
|||
|
||||
|
||||
class InitializeAgent(CustomAgentExecutor):
|
||||
"""Implementation of initialize_agent function"""
|
||||
"""Implementation of AgentInitializer function"""
|
||||
|
||||
@staticmethod
|
||||
def function_name():
|
||||
return "initialize_agent"
|
||||
return "AgentInitializer"
|
||||
|
||||
@classmethod
|
||||
def initialize(
|
||||
|
|
@ -320,7 +330,7 @@ class InitializeAgent(CustomAgentExecutor):
|
|||
CUSTOM_AGENTS = {
|
||||
"JsonAgent": JsonAgent,
|
||||
"CSVAgent": CSVAgent,
|
||||
"initialize_agent": InitializeAgent,
|
||||
"AgentInitializer": InitializeAgent,
|
||||
"VectorStoreAgent": VectorStoreAgent,
|
||||
"VectorStoreRouterAgent": VectorStoreRouterAgent,
|
||||
"SQLAgent": SQLAgent,
|
||||
|
|
|
|||
|
|
@ -154,9 +154,9 @@ class CSVAgentNode(FrontendNode):
|
|||
|
||||
|
||||
class InitializeAgentNode(FrontendNode):
|
||||
name: str = "initialize_agent"
|
||||
name: str = "AgentInitializer"
|
||||
template: Template = Template(
|
||||
type_name="initailize_agent",
|
||||
type_name="AgentInitializer",
|
||||
fields=[
|
||||
TemplateField(
|
||||
field_type="str",
|
||||
|
|
|
|||
|
|
@ -131,7 +131,7 @@ def test_initialize_agent(client: TestClient):
|
|||
json_response = response.json()
|
||||
agents = json_response["agents"]
|
||||
|
||||
initialize_agent = agents["initialize_agent"]
|
||||
initialize_agent = agents["AgentInitializer"]
|
||||
assert initialize_agent["base_classes"] == ["AgentExecutor", "function"]
|
||||
template = initialize_agent["template"]
|
||||
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue