From d93413ee5671c2f391d0abd33d590c8100d8b96a Mon Sep 17 00:00:00 2001 From: Gabriel Almeida Date: Fri, 31 Mar 2023 14:07:02 -0300 Subject: [PATCH] feat: JsonAgent implementation --- .../langflow/interface/agents/custom.py | 33 +++++++++++-------- 1 file changed, 19 insertions(+), 14 deletions(-) diff --git a/src/backend/langflow/interface/agents/custom.py b/src/backend/langflow/interface/agents/custom.py index 96c2f5a09..cc998fc12 100644 --- a/src/backend/langflow/interface/agents/custom.py +++ b/src/backend/langflow/interface/agents/custom.py @@ -1,3 +1,5 @@ +from typing import Optional + from langchain import LLMChain from langchain.agents import AgentExecutor, ZeroShotAgent from langchain.agents.agent_toolkits.json.prompt import JSON_PREFIX, JSON_SUFFIX @@ -7,16 +9,19 @@ from langchain.schema import BaseLanguageModel from pydantic import BaseModel -class JsonAgent(BaseModel): +class JsonAgent(AgentExecutor): """Json agent""" - toolkit: JsonToolkit - llm: BaseLanguageModel + @classmethod + def initialize(cls, *args, **kwargs): + return cls.from_toolkit_and_llm(*args, **kwargs) - def __init__(self, toolkit: JsonToolkit, llm: BaseLanguageModel): - super().__init__(toolkit=toolkit, llm=llm) - self.toolkit = toolkit - tools = self.toolkit.get_tools() + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + + @classmethod + def from_toolkit_and_llm(cls, toolkit: JsonToolkit, llm: BaseLanguageModel): + tools = toolkit.get_tools() tool_names = [tool.name for tool in tools] prompt = ZeroShotAgent.create_prompt( tools, @@ -30,12 +35,12 @@ class JsonAgent(BaseModel): prompt=prompt, ) agent = ZeroShotAgent(llm_chain=llm_chain, allowed_tools=tool_names) - self.agent_executor = AgentExecutor.from_agent_and_tools( - agent=agent, tools=tools, verbose=True - ) - - def __call__(self, *args, **kwargs): - return self.agent_executor(*args, **kwargs) + return cls.from_agent_and_tools(agent=agent, tools=tools, verbose=True) def run(self, *args, **kwargs): - return self.agent_executor.run(*args, **kwargs) + return super().run(*args, **kwargs) + + +CUSTOM_AGENTS = { + "JsonAgent": JsonAgent, +}