🔨 refactor(custom.py): rename CustomChain to CustomAgentExecutor
🔨 refactor(base.py): add CustomAgentExecutor class and move CustomChain to base.py 🔨 refactor(custom_lists.py): update reference to CustomChain to CustomAgentExecutor The CustomChain class has been renamed to CustomAgentExecutor to better reflect its purpose. The class has been moved to base.py and a new CustomAgentExecutor class has been added to custom.py. The reference to CustomChain in custom_lists.py has been updated to CustomAgentExecutor. These changes improve the semantics of the code and make it easier to understand the purpose of the classes.
This commit is contained in:
parent
f0975ddf63
commit
04bd0f43fb
3 changed files with 27 additions and 9 deletions
|
|
@ -32,10 +32,10 @@ from langchain.memory.chat_memory import BaseChatMemory
|
|||
from langchain.sql_database import SQLDatabase
|
||||
from langchain.tools.python.tool import PythonAstREPLTool
|
||||
from langchain.tools.sql_database.prompt import QUERY_CHECKER
|
||||
from langflow.interface.base import CustomChain
|
||||
from langflow.interface.base import CustomAgentExecutor
|
||||
|
||||
|
||||
class JsonAgent(CustomChain):
|
||||
class JsonAgent(CustomAgentExecutor):
|
||||
"""Json agent"""
|
||||
|
||||
@staticmethod
|
||||
|
|
@ -71,7 +71,7 @@ class JsonAgent(CustomChain):
|
|||
return super().run(*args, **kwargs)
|
||||
|
||||
|
||||
class CSVAgent(CustomChain):
|
||||
class CSVAgent(CustomAgentExecutor):
|
||||
"""CSV agent"""
|
||||
|
||||
@staticmethod
|
||||
|
|
@ -119,7 +119,7 @@ class CSVAgent(CustomChain):
|
|||
return super().run(*args, **kwargs)
|
||||
|
||||
|
||||
class VectorStoreAgent(CustomChain):
|
||||
class VectorStoreAgent(CustomAgentExecutor):
|
||||
"""Vector Store agent"""
|
||||
|
||||
@staticmethod
|
||||
|
|
@ -157,7 +157,7 @@ class VectorStoreAgent(CustomChain):
|
|||
return super().run(*args, **kwargs)
|
||||
|
||||
|
||||
class SQLAgent(CustomChain):
|
||||
class SQLAgent(CustomAgentExecutor):
|
||||
"""SQL agent"""
|
||||
|
||||
@staticmethod
|
||||
|
|
@ -229,7 +229,7 @@ class SQLAgent(CustomChain):
|
|||
return super().run(*args, **kwargs)
|
||||
|
||||
|
||||
class VectorStoreRouterAgent(CustomChain):
|
||||
class VectorStoreRouterAgent(CustomAgentExecutor):
|
||||
"""Vector Store Router Agent"""
|
||||
|
||||
@staticmethod
|
||||
|
|
@ -268,7 +268,7 @@ class VectorStoreRouterAgent(CustomChain):
|
|||
return super().run(*args, **kwargs)
|
||||
|
||||
|
||||
class InitializeAgent(CustomChain):
|
||||
class InitializeAgent(CustomAgentExecutor):
|
||||
"""Implementation of initialize_agent function"""
|
||||
|
||||
@staticmethod
|
||||
|
|
|
|||
|
|
@ -1,7 +1,7 @@
|
|||
from abc import ABC, abstractmethod
|
||||
from typing import Any, Dict, List, Optional, Type, Union
|
||||
from langchain.chains.base import Chain
|
||||
|
||||
from langchain.agents import AgentExecutor
|
||||
from pydantic import BaseModel
|
||||
|
||||
from langflow.template.field.base import TemplateField
|
||||
|
|
@ -103,3 +103,21 @@ class CustomChain(Chain, ABC):
|
|||
|
||||
def run(self, *args, **kwargs):
|
||||
return super().run(*args, **kwargs)
|
||||
|
||||
|
||||
class CustomAgentExecutor(AgentExecutor, ABC):
|
||||
"""Custom chain"""
|
||||
|
||||
@staticmethod
|
||||
def function_name():
|
||||
return "CustomChain"
|
||||
|
||||
@classmethod
|
||||
def initialize(cls, *args, **kwargs):
|
||||
pass
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
super().__init__(*args, **kwargs)
|
||||
|
||||
def run(self, *args, **kwargs):
|
||||
return super().run(*args, **kwargs)
|
||||
|
|
|
|||
|
|
@ -71,4 +71,4 @@ textsplitter_type_to_cls_dict: dict[str, Any] = dict(
|
|||
)
|
||||
|
||||
# merge CUSTOM_AGENTS and CUSTOM_CHAINS
|
||||
CUSTOM_NODES = {**CUSTOM_AGENTS, **CUSTOM_CHAINS}
|
||||
CUSTOM_NODES = {**CUSTOM_AGENTS, **CUSTOM_CHAINS} # type: ignore
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue