diff --git a/src/backend/langflow/interface/agents/custom.py b/src/backend/langflow/interface/agents/custom.py index a0a5c243f..f86028985 100644 --- a/src/backend/langflow/interface/agents/custom.py +++ b/src/backend/langflow/interface/agents/custom.py @@ -32,10 +32,10 @@ from langchain.memory.chat_memory import BaseChatMemory from langchain.sql_database import SQLDatabase from langchain.tools.python.tool import PythonAstREPLTool from langchain.tools.sql_database.prompt import QUERY_CHECKER -from langflow.interface.base import CustomChain +from langflow.interface.base import CustomAgentExecutor -class JsonAgent(CustomChain): +class JsonAgent(CustomAgentExecutor): """Json agent""" @staticmethod @@ -71,7 +71,7 @@ class JsonAgent(CustomChain): return super().run(*args, **kwargs) -class CSVAgent(CustomChain): +class CSVAgent(CustomAgentExecutor): """CSV agent""" @staticmethod @@ -119,7 +119,7 @@ class CSVAgent(CustomChain): return super().run(*args, **kwargs) -class VectorStoreAgent(CustomChain): +class VectorStoreAgent(CustomAgentExecutor): """Vector Store agent""" @staticmethod @@ -157,7 +157,7 @@ class VectorStoreAgent(CustomChain): return super().run(*args, **kwargs) -class SQLAgent(CustomChain): +class SQLAgent(CustomAgentExecutor): """SQL agent""" @staticmethod @@ -229,7 +229,7 @@ class SQLAgent(CustomChain): return super().run(*args, **kwargs) -class VectorStoreRouterAgent(CustomChain): +class VectorStoreRouterAgent(CustomAgentExecutor): """Vector Store Router Agent""" @staticmethod @@ -268,7 +268,7 @@ class VectorStoreRouterAgent(CustomChain): return super().run(*args, **kwargs) -class InitializeAgent(CustomChain): +class InitializeAgent(CustomAgentExecutor): """Implementation of initialize_agent function""" @staticmethod diff --git a/src/backend/langflow/interface/base.py b/src/backend/langflow/interface/base.py index df03950af..3670bb8ae 100644 --- a/src/backend/langflow/interface/base.py +++ b/src/backend/langflow/interface/base.py @@ -1,7 +1,7 @@ from abc import ABC, abstractmethod from typing import Any, Dict, List, Optional, Type, Union from langchain.chains.base import Chain - +from langchain.agents import AgentExecutor from pydantic import BaseModel from langflow.template.field.base import TemplateField @@ -103,3 +103,21 @@ class CustomChain(Chain, ABC): def run(self, *args, **kwargs): return super().run(*args, **kwargs) + + +class CustomAgentExecutor(AgentExecutor, ABC): + """Custom chain""" + + @staticmethod + def function_name(): + return "CustomChain" + + @classmethod + def initialize(cls, *args, **kwargs): + pass + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + + def run(self, *args, **kwargs): + return super().run(*args, **kwargs) diff --git a/src/backend/langflow/interface/custom_lists.py b/src/backend/langflow/interface/custom_lists.py index 8bde1565c..fbdba0a9c 100644 --- a/src/backend/langflow/interface/custom_lists.py +++ b/src/backend/langflow/interface/custom_lists.py @@ -71,4 +71,4 @@ textsplitter_type_to_cls_dict: dict[str, Any] = dict( ) # merge CUSTOM_AGENTS and CUSTOM_CHAINS -CUSTOM_NODES = {**CUSTOM_AGENTS, **CUSTOM_CHAINS} +CUSTOM_NODES = {**CUSTOM_AGENTS, **CUSTOM_CHAINS} # type: ignore