🔨 refactor(custom.py): rename CustomChain to CustomAgentExecutor

🔨 refactor(base.py): add CustomAgentExecutor class and move CustomChain to base.py
🔨 refactor(custom_lists.py): update reference to CustomChain to CustomAgentExecutor
The CustomChain class has been renamed to CustomAgentExecutor to better reflect its purpose. The class has been moved to base.py and a new CustomAgentExecutor class has been added to custom.py. The reference to CustomChain in custom_lists.py has been updated to CustomAgentExecutor. These changes improve the semantics of the code and make it easier to understand the purpose of the classes.
This commit is contained in:
Gabriel Luiz Freitas Almeida 2023-06-07 21:14:40 -03:00
commit 04bd0f43fb
3 changed files with 27 additions and 9 deletions

View file

@ -32,10 +32,10 @@ from langchain.memory.chat_memory import BaseChatMemory
from langchain.sql_database import SQLDatabase
from langchain.tools.python.tool import PythonAstREPLTool
from langchain.tools.sql_database.prompt import QUERY_CHECKER
from langflow.interface.base import CustomChain
from langflow.interface.base import CustomAgentExecutor
class JsonAgent(CustomChain):
class JsonAgent(CustomAgentExecutor):
"""Json agent"""
@staticmethod
@ -71,7 +71,7 @@ class JsonAgent(CustomChain):
return super().run(*args, **kwargs)
class CSVAgent(CustomChain):
class CSVAgent(CustomAgentExecutor):
"""CSV agent"""
@staticmethod
@ -119,7 +119,7 @@ class CSVAgent(CustomChain):
return super().run(*args, **kwargs)
class VectorStoreAgent(CustomChain):
class VectorStoreAgent(CustomAgentExecutor):
"""Vector Store agent"""
@staticmethod
@ -157,7 +157,7 @@ class VectorStoreAgent(CustomChain):
return super().run(*args, **kwargs)
class SQLAgent(CustomChain):
class SQLAgent(CustomAgentExecutor):
"""SQL agent"""
@staticmethod
@ -229,7 +229,7 @@ class SQLAgent(CustomChain):
return super().run(*args, **kwargs)
class VectorStoreRouterAgent(CustomChain):
class VectorStoreRouterAgent(CustomAgentExecutor):
"""Vector Store Router Agent"""
@staticmethod
@ -268,7 +268,7 @@ class VectorStoreRouterAgent(CustomChain):
return super().run(*args, **kwargs)
class InitializeAgent(CustomChain):
class InitializeAgent(CustomAgentExecutor):
"""Implementation of initialize_agent function"""
@staticmethod

View file

@ -1,7 +1,7 @@
from abc import ABC, abstractmethod
from typing import Any, Dict, List, Optional, Type, Union
from langchain.chains.base import Chain
from langchain.agents import AgentExecutor
from pydantic import BaseModel
from langflow.template.field.base import TemplateField
@ -103,3 +103,21 @@ class CustomChain(Chain, ABC):
def run(self, *args, **kwargs):
return super().run(*args, **kwargs)
class CustomAgentExecutor(AgentExecutor, ABC):
"""Custom chain"""
@staticmethod
def function_name():
return "CustomChain"
@classmethod
def initialize(cls, *args, **kwargs):
pass
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
def run(self, *args, **kwargs):
return super().run(*args, **kwargs)

View file

@ -71,4 +71,4 @@ textsplitter_type_to_cls_dict: dict[str, Any] = dict(
)
# merge CUSTOM_AGENTS and CUSTOM_CHAINS
CUSTOM_NODES = {**CUSTOM_AGENTS, **CUSTOM_CHAINS}
CUSTOM_NODES = {**CUSTOM_AGENTS, **CUSTOM_CHAINS} # type: ignore