From f032795fe4ef4688b1cfa520dcd40300a5f0495c Mon Sep 17 00:00:00 2001 From: Gabriel Almeida Date: Sun, 2 Apr 2023 09:33:07 -0300 Subject: [PATCH] feat: working implementation of initialize_agent --- src/backend/langflow/config.yaml | 5 ++- src/backend/langflow/custom/customs.py | 2 +- src/backend/langflow/graph/base.py | 6 +-- src/backend/langflow/interface/agents/base.py | 19 ++++++---- .../langflow/interface/agents/custom.py | 27 +++++++++---- src/backend/langflow/interface/run.py | 9 ++++- src/backend/langflow/template/nodes.py | 38 ++++++++++++------- 7 files changed, 72 insertions(+), 34 deletions(-) diff --git a/src/backend/langflow/config.yaml b/src/backend/langflow/config.yaml index f45c45fbb..09fd9ca35 100644 --- a/src/backend/langflow/config.yaml +++ b/src/backend/langflow/config.yaml @@ -8,7 +8,7 @@ agents: - ZeroShotAgent - JsonAgent - CSVAgent - - InitializeAgent + - initialize_agent prompts: - PromptTemplate @@ -34,6 +34,9 @@ toolkits: - OpenAPIToolkit - JsonToolkit +memories: + - ConversationBufferMemory + embeddings: [] diff --git a/src/backend/langflow/custom/customs.py b/src/backend/langflow/custom/customs.py index 0467bac41..112b8db26 100644 --- a/src/backend/langflow/custom/customs.py +++ b/src/backend/langflow/custom/customs.py @@ -6,7 +6,7 @@ CUSTOM_NODES = { "agents": { "JsonAgent": nodes.JsonAgentNode(), "CSVAgent": nodes.CSVAgentNode(), - "InitializeAgent": nodes.InitializeAgentNode(), + "initialize_agent": nodes.InitializeAgentNode(), }, } diff --git a/src/backend/langflow/graph/base.py b/src/backend/langflow/graph/base.py index 08ad786ac..feadf155c 100644 --- a/src/backend/langflow/graph/base.py +++ b/src/backend/langflow/graph/base.py @@ -121,10 +121,10 @@ class Node: f"Required input {key} for module {self.node_type} not found" ) elif value["list"]: - if key in params: + if key not in params: + params[key] = [] + if edge is not None: params[key].append(edge.source) - else: - params[key] = [edge.source] elif value["required"] or edge is not None: params[key] = edge.source elif value["required"] or value.get("value"): diff --git a/src/backend/langflow/interface/agents/base.py b/src/backend/langflow/interface/agents/base.py index 35df3bee6..41ea749e9 100644 --- a/src/backend/langflow/interface/agents/base.py +++ b/src/backend/langflow/interface/agents/base.py @@ -1,4 +1,5 @@ -from typing import Dict, List +import contextlib +from typing import Dict, Iterable from langchain.agents import loading @@ -31,12 +32,16 @@ class AgentCreator(LangChainTypeCreator): except ValueError as exc: raise ValueError("Agent not found") from exc - def to_list(self) -> List[str]: - return [ - agent.__name__ - for agent in self.type_to_loader_dict.values() - if agent.__name__ in settings.agents or settings.dev - ] + # Now this is a generator + def to_list(self) -> Iterable: + for name, agent in self.type_to_loader_dict.items(): + agent_name = ( + agent.function_name() + if hasattr(agent, "function_name") + else agent.__name__ + ) + if agent_name in settings.agents or settings.dev: + yield agent_name agent_creator = AgentCreator() diff --git a/src/backend/langflow/interface/agents/custom.py b/src/backend/langflow/interface/agents/custom.py index c74b9a450..e5ae77743 100644 --- a/src/backend/langflow/interface/agents/custom.py +++ b/src/backend/langflow/interface/agents/custom.py @@ -11,11 +11,16 @@ from langchain.schema import BaseLanguageModel from langchain.llms.base import BaseLLM from langchain.tools.python.tool import PythonAstREPLTool from langchain.agents import initialize_agent, Tool +from langchain.memory.chat_memory import BaseChatMemory class JsonAgent(AgentExecutor): """Json agent""" + @staticmethod + def function_name(): + return "JsonAgent" + @classmethod def initialize(cls, *args, **kwargs): return cls.from_toolkit_and_llm(*args, **kwargs) @@ -48,6 +53,10 @@ class JsonAgent(AgentExecutor): class CSVAgent(AgentExecutor): """CSV agent""" + @staticmethod + def function_name(): + return "CSVAgent" + @classmethod def initialize(cls, *args, **kwargs): return cls.from_toolkit_and_llm(*args, **kwargs) @@ -90,15 +99,17 @@ class CSVAgent(AgentExecutor): class InitializeAgent(AgentExecutor): - """Initialize agent""" + """Implementation of initialize_agent function""" + + @staticmethod + def function_name(): + return "initialize_agent" @classmethod - def initialize(cls, llm: BaseLLM, tools: List[Tool], agent: str): - return initialize_agent( - tools=tools, - llm=llm, - agent=agent, - ) + def initialize( + cls, llm: BaseLLM, tools: List[Tool], agent: str, memory: BaseChatMemory + ): + return initialize_agent(tools=tools, llm=llm, agent=agent, memory=memory) def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) @@ -110,5 +121,5 @@ class InitializeAgent(AgentExecutor): CUSTOM_AGENTS = { "JsonAgent": JsonAgent, "CSVAgent": CSVAgent, - "InitializeAgent": InitializeAgent, + "initialize_agent": InitializeAgent, } diff --git a/src/backend/langflow/interface/run.py b/src/backend/langflow/interface/run.py index cb86789ae..aee7e83ee 100644 --- a/src/backend/langflow/interface/run.py +++ b/src/backend/langflow/interface/run.py @@ -57,7 +57,14 @@ def get_result_and_thought_using_graph(loaded_langchain, message: str): loaded_langchain.verbose = True try: with io.StringIO() as output_buffer, contextlib.redirect_stdout(output_buffer): - result = loaded_langchain(message) + chat_input = {} + for key in loaded_langchain.input_keys: + if key != "chat_history": + chat_input[key] = message + break + if hasattr(loaded_langchain, "run"): + loaded_langchain = loaded_langchain.run + result = loaded_langchain result = ( result.get(loaded_langchain.output_keys[0]) diff --git a/src/backend/langflow/template/nodes.py b/src/backend/langflow/template/nodes.py index 6b855e5c4..ea72ffd89 100644 --- a/src/backend/langflow/template/nodes.py +++ b/src/backend/langflow/template/nodes.py @@ -153,22 +153,10 @@ class JsonAgentNode(FrontendNode): class InitializeAgentNode(FrontendNode): - name: str = "InitializeAgent" + name: str = "initialize_agent" template: Template = Template( type_name="initailize_agent", fields=[ - TemplateField( - field_type="Tool", - required=True, - show=True, - name="tools", - ), - TemplateField( - field_type="BaseLanguageModel", - required=True, - show=True, - name="llm", - ), TemplateField( field_type="str", required=True, @@ -178,6 +166,25 @@ class InitializeAgentNode(FrontendNode): options=list(loading.AGENT_TO_CLASS.keys()), name="agent", ), + TemplateField( + field_type="BaseChatMemory", + required=False, + show=True, + name="memory", + ), + TemplateField( + field_type="Tool", + required=False, + show=True, + name="tools", + is_list=True, + ), + TemplateField( + field_type="BaseLanguageModel", + required=True, + show=True, + name="llm", + ), ], ) description: str = """Construct a json agent from an LLM and tools.""" @@ -186,6 +193,11 @@ class InitializeAgentNode(FrontendNode): def to_dict(self): return super().to_dict() + @staticmethod + def format_field(field: TemplateField, name: str): + # do nothing and don't return anything + pass + class CSVAgentNode(FrontendNode): name: str = "CSVAgent"