From 04bd0f43fb0e621679b5bb42582455c09eedce56 Mon Sep 17 00:00:00 2001 From: Gabriel Luiz Freitas Almeida Date: Wed, 7 Jun 2023 21:14:40 -0300 Subject: [PATCH] =?UTF-8?q?=F0=9F=94=A8=20refactor(custom.py):=20rename=20?= =?UTF-8?q?CustomChain=20to=20CustomAgentExecutor=20=F0=9F=94=A8=20refacto?= =?UTF-8?q?r(base.py):=20add=20CustomAgentExecutor=20class=20and=20move=20?= =?UTF-8?q?CustomChain=20to=20base.py=20=F0=9F=94=A8=20refactor(custom=5Fl?= =?UTF-8?q?ists.py):=20update=20reference=20to=20CustomChain=20to=20Custom?= =?UTF-8?q?AgentExecutor=20The=20CustomChain=20class=20has=20been=20rename?= =?UTF-8?q?d=20to=20CustomAgentExecutor=20to=20better=20reflect=20its=20pu?= =?UTF-8?q?rpose.=20The=20class=20has=20been=20moved=20to=20base.py=20and?= =?UTF-8?q?=20a=20new=20CustomAgentExecutor=20class=20has=20been=20added?= =?UTF-8?q?=20to=20custom.py.=20The=20reference=20to=20CustomChain=20in=20?= =?UTF-8?q?custom=5Flists.py=20has=20been=20updated=20to=20CustomAgentExec?= =?UTF-8?q?utor.=20These=20changes=20improve=20the=20semantics=20of=20the?= =?UTF-8?q?=20code=20and=20make=20it=20easier=20to=20understand=20the=20pu?= =?UTF-8?q?rpose=20of=20the=20classes.?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- .../langflow/interface/agents/custom.py | 14 ++++++------- src/backend/langflow/interface/base.py | 20 ++++++++++++++++++- .../langflow/interface/custom_lists.py | 2 +- 3 files changed, 27 insertions(+), 9 deletions(-) diff --git a/src/backend/langflow/interface/agents/custom.py b/src/backend/langflow/interface/agents/custom.py index a0a5c243f..f86028985 100644 --- a/src/backend/langflow/interface/agents/custom.py +++ b/src/backend/langflow/interface/agents/custom.py @@ -32,10 +32,10 @@ from langchain.memory.chat_memory import BaseChatMemory from langchain.sql_database import SQLDatabase from langchain.tools.python.tool import PythonAstREPLTool from langchain.tools.sql_database.prompt import QUERY_CHECKER -from langflow.interface.base import CustomChain +from langflow.interface.base import CustomAgentExecutor -class JsonAgent(CustomChain): +class JsonAgent(CustomAgentExecutor): """Json agent""" @staticmethod @@ -71,7 +71,7 @@ class JsonAgent(CustomChain): return super().run(*args, **kwargs) -class CSVAgent(CustomChain): +class CSVAgent(CustomAgentExecutor): """CSV agent""" @staticmethod @@ -119,7 +119,7 @@ class CSVAgent(CustomChain): return super().run(*args, **kwargs) -class VectorStoreAgent(CustomChain): +class VectorStoreAgent(CustomAgentExecutor): """Vector Store agent""" @staticmethod @@ -157,7 +157,7 @@ class VectorStoreAgent(CustomChain): return super().run(*args, **kwargs) -class SQLAgent(CustomChain): +class SQLAgent(CustomAgentExecutor): """SQL agent""" @staticmethod @@ -229,7 +229,7 @@ class SQLAgent(CustomChain): return super().run(*args, **kwargs) -class VectorStoreRouterAgent(CustomChain): +class VectorStoreRouterAgent(CustomAgentExecutor): """Vector Store Router Agent""" @staticmethod @@ -268,7 +268,7 @@ class VectorStoreRouterAgent(CustomChain): return super().run(*args, **kwargs) -class InitializeAgent(CustomChain): +class InitializeAgent(CustomAgentExecutor): """Implementation of initialize_agent function""" @staticmethod diff --git a/src/backend/langflow/interface/base.py b/src/backend/langflow/interface/base.py index df03950af..3670bb8ae 100644 --- a/src/backend/langflow/interface/base.py +++ b/src/backend/langflow/interface/base.py @@ -1,7 +1,7 @@ from abc import ABC, abstractmethod from typing import Any, Dict, List, Optional, Type, Union from langchain.chains.base import Chain - +from langchain.agents import AgentExecutor from pydantic import BaseModel from langflow.template.field.base import TemplateField @@ -103,3 +103,21 @@ class CustomChain(Chain, ABC): def run(self, *args, **kwargs): return super().run(*args, **kwargs) + + +class CustomAgentExecutor(AgentExecutor, ABC): + """Custom chain""" + + @staticmethod + def function_name(): + return "CustomChain" + + @classmethod + def initialize(cls, *args, **kwargs): + pass + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + + def run(self, *args, **kwargs): + return super().run(*args, **kwargs) diff --git a/src/backend/langflow/interface/custom_lists.py b/src/backend/langflow/interface/custom_lists.py index 8bde1565c..fbdba0a9c 100644 --- a/src/backend/langflow/interface/custom_lists.py +++ b/src/backend/langflow/interface/custom_lists.py @@ -71,4 +71,4 @@ textsplitter_type_to_cls_dict: dict[str, Any] = dict( ) # merge CUSTOM_AGENTS and CUSTOM_CHAINS -CUSTOM_NODES = {**CUSTOM_AGENTS, **CUSTOM_CHAINS} +CUSTOM_NODES = {**CUSTOM_AGENTS, **CUSTOM_CHAINS} # type: ignore