refactor(agents/custom.py): create CustomAgentExecutor abstract class and make all agents inherit from it
fix(agents/custom.py): fix import error for NotEnoughElementsException refactor(run.py): remove unused import statement for loading module
This commit is contained in:
parent
fc20775fdb
commit
4046a7b1e5
2 changed files with 26 additions and 8 deletions
|
|
@ -1,3 +1,4 @@
|
|||
from abc import ABC
|
||||
from typing import Any, List, Optional
|
||||
|
||||
from langchain import LLMChain
|
||||
|
|
@ -33,7 +34,25 @@ from langchain.tools.python.tool import PythonAstREPLTool
|
|||
from langchain.tools.sql_database.prompt import QUERY_CHECKER
|
||||
|
||||
|
||||
class JsonAgent(AgentExecutor):
|
||||
class CustomAgentExecutor(AgentExecutor, ABC):
|
||||
"""Custom agent executor"""
|
||||
|
||||
@staticmethod
|
||||
def function_name():
|
||||
return "CustomAgentExecutor"
|
||||
|
||||
@classmethod
|
||||
def initialize(cls, *args, **kwargs):
|
||||
pass
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
super().__init__(*args, **kwargs)
|
||||
|
||||
def run(self, *args, **kwargs):
|
||||
return super().run(*args, **kwargs)
|
||||
|
||||
|
||||
class JsonAgent(CustomAgentExecutor):
|
||||
"""Json agent"""
|
||||
|
||||
@staticmethod
|
||||
|
|
@ -69,7 +88,7 @@ class JsonAgent(AgentExecutor):
|
|||
return super().run(*args, **kwargs)
|
||||
|
||||
|
||||
class CSVAgent(AgentExecutor):
|
||||
class CSVAgent(CustomAgentExecutor):
|
||||
"""CSV agent"""
|
||||
|
||||
@staticmethod
|
||||
|
|
@ -117,7 +136,7 @@ class CSVAgent(AgentExecutor):
|
|||
return super().run(*args, **kwargs)
|
||||
|
||||
|
||||
class VectorStoreAgent(AgentExecutor):
|
||||
class VectorStoreAgent(CustomAgentExecutor):
|
||||
"""Vector Store agent"""
|
||||
|
||||
@staticmethod
|
||||
|
|
@ -155,7 +174,7 @@ class VectorStoreAgent(AgentExecutor):
|
|||
return super().run(*args, **kwargs)
|
||||
|
||||
|
||||
class SQLAgent(AgentExecutor):
|
||||
class SQLAgent(CustomAgentExecutor):
|
||||
"""SQL agent"""
|
||||
|
||||
@staticmethod
|
||||
|
|
@ -227,7 +246,7 @@ class SQLAgent(AgentExecutor):
|
|||
return super().run(*args, **kwargs)
|
||||
|
||||
|
||||
class VectorStoreRouterAgent(AgentExecutor):
|
||||
class VectorStoreRouterAgent(CustomAgentExecutor):
|
||||
"""Vector Store Router Agent"""
|
||||
|
||||
@staticmethod
|
||||
|
|
@ -266,7 +285,7 @@ class VectorStoreRouterAgent(AgentExecutor):
|
|||
return super().run(*args, **kwargs)
|
||||
|
||||
|
||||
class InitializeAgent(AgentExecutor):
|
||||
class InitializeAgent(CustomAgentExecutor):
|
||||
"""Implementation of initialize_agent function"""
|
||||
|
||||
@staticmethod
|
||||
|
|
|
|||
|
|
@ -2,12 +2,11 @@ import contextlib
|
|||
import io
|
||||
from typing import Any, Dict
|
||||
|
||||
from chromadb.errors import NotEnoughElementsException
|
||||
from chromadb.errors import NotEnoughElementsException # type: ignore
|
||||
|
||||
from langflow.api.callback import AsyncStreamingLLMCallbackHandler, StreamingLLMCallbackHandler # type: ignore
|
||||
from langflow.cache.base import compute_dict_hash, load_cache, memoize_dict
|
||||
from langflow.graph.graph import Graph
|
||||
from langflow.interface import loading
|
||||
from langflow.utils.logger import logger
|
||||
|
||||
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue