diff --git a/src/backend/langflow/config.yaml b/src/backend/langflow/config.yaml index 56f441ff3..ad2bed77a 100644 --- a/src/backend/langflow/config.yaml +++ b/src/backend/langflow/config.yaml @@ -3,7 +3,7 @@ agents: - ZeroShotAgent - JsonAgent - CSVAgent - - initialize_agent + - AgentInitializer - VectorStoreAgent - VectorStoreRouterAgent - SQLAgent diff --git a/src/backend/langflow/custom/customs.py b/src/backend/langflow/custom/customs.py index 8759fa4a2..c39fb4b2a 100644 --- a/src/backend/langflow/custom/customs.py +++ b/src/backend/langflow/custom/customs.py @@ -10,7 +10,7 @@ CUSTOM_NODES = { "agents": { "JsonAgent": frontend_node.agents.JsonAgentNode(), "CSVAgent": frontend_node.agents.CSVAgentNode(), - "initialize_agent": frontend_node.agents.InitializeAgentNode(), + "AgentInitializer": frontend_node.agents.InitializeAgentNode(), "VectorStoreAgent": frontend_node.agents.VectorStoreAgentNode(), "VectorStoreRouterAgent": frontend_node.agents.VectorStoreRouterAgentNode(), "SQLAgent": frontend_node.agents.SQLAgentNode(), diff --git a/src/backend/langflow/interface/agents/custom.py b/src/backend/langflow/interface/agents/custom.py index 4654ef7cb..8ff61c62b 100644 --- a/src/backend/langflow/interface/agents/custom.py +++ b/src/backend/langflow/interface/agents/custom.py @@ -82,7 +82,9 @@ class JsonAgent(CustomAgentExecutor): llm=llm, prompt=prompt, ) - agent = ZeroShotAgent(llm_chain=llm_chain, allowed_tools=tool_names) # type: ignore + agent = ZeroShotAgent( + llm_chain=llm_chain, allowed_tools=tool_names + ) # type: ignore return cls.from_agent_and_tools(agent=agent, tools=tools, verbose=True) def run(self, *args, **kwargs): @@ -129,7 +131,9 @@ class CSVAgent(CustomAgentExecutor): prompt=partial_prompt, ) tool_names = {tool.name for tool in tools} - agent = ZeroShotAgent(llm_chain=llm_chain, allowed_tools=tool_names, **kwargs) # type: ignore + agent = ZeroShotAgent( + llm_chain=llm_chain, allowed_tools=tool_names, **kwargs + ) # type: ignore return cls.from_agent_and_tools(agent=agent, tools=tools, verbose=True) @@ -166,7 +170,9 @@ class VectorStoreAgent(CustomAgentExecutor): prompt=prompt, ) tool_names = {tool.name for tool in tools} - agent = ZeroShotAgent(llm_chain=llm_chain, allowed_tools=tool_names, **kwargs) # type: ignore + agent = ZeroShotAgent( + llm_chain=llm_chain, allowed_tools=tool_names, **kwargs + ) # type: ignore return AgentExecutor.from_agent_and_tools( agent=agent, tools=tools, verbose=True ) @@ -234,7 +240,9 @@ class SQLAgent(CustomAgentExecutor): prompt=prompt, ) tool_names = {tool.name for tool in tools} # type: ignore - agent = ZeroShotAgent(llm_chain=llm_chain, allowed_tools=tool_names, **kwargs) # type: ignore + agent = ZeroShotAgent( + llm_chain=llm_chain, allowed_tools=tool_names, **kwargs + ) # type: ignore return AgentExecutor.from_agent_and_tools( agent=agent, tools=tools, # type: ignore @@ -277,7 +285,9 @@ class VectorStoreRouterAgent(CustomAgentExecutor): prompt=prompt, ) tool_names = {tool.name for tool in tools} - agent = ZeroShotAgent(llm_chain=llm_chain, allowed_tools=tool_names, **kwargs) # type: ignore + agent = ZeroShotAgent( + llm_chain=llm_chain, allowed_tools=tool_names, **kwargs + ) # type: ignore return AgentExecutor.from_agent_and_tools( agent=agent, tools=tools, verbose=True ) @@ -287,11 +297,11 @@ class VectorStoreRouterAgent(CustomAgentExecutor): class InitializeAgent(CustomAgentExecutor): - """Implementation of initialize_agent function""" + """Implementation of AgentInitializer function""" @staticmethod def function_name(): - return "initialize_agent" + return "AgentInitializer" @classmethod def initialize( @@ -320,7 +330,7 @@ class InitializeAgent(CustomAgentExecutor): CUSTOM_AGENTS = { "JsonAgent": JsonAgent, "CSVAgent": CSVAgent, - "initialize_agent": InitializeAgent, + "AgentInitializer": InitializeAgent, "VectorStoreAgent": VectorStoreAgent, "VectorStoreRouterAgent": VectorStoreRouterAgent, "SQLAgent": SQLAgent, diff --git a/src/backend/langflow/template/frontend_node/agents.py b/src/backend/langflow/template/frontend_node/agents.py index e4fe40187..16a319959 100644 --- a/src/backend/langflow/template/frontend_node/agents.py +++ b/src/backend/langflow/template/frontend_node/agents.py @@ -154,9 +154,9 @@ class CSVAgentNode(FrontendNode): class InitializeAgentNode(FrontendNode): - name: str = "initialize_agent" + name: str = "AgentInitializer" template: Template = Template( - type_name="initailize_agent", + type_name="AgentInitializer", fields=[ TemplateField( field_type="str", diff --git a/tests/test_agents_template.py b/tests/test_agents_template.py index 7aa8de176..84be3e5f3 100644 --- a/tests/test_agents_template.py +++ b/tests/test_agents_template.py @@ -131,7 +131,7 @@ def test_initialize_agent(client: TestClient): json_response = response.json() agents = json_response["agents"] - initialize_agent = agents["initialize_agent"] + initialize_agent = agents["AgentInitializer"] assert initialize_agent["base_classes"] == ["AgentExecutor", "function"] template = initialize_agent["template"]