feat: add SQL agent
This commit is contained in:
parent
885d0ffc22
commit
ec0c6c79c1
8 changed files with 143 additions and 5 deletions
|
|
@ -3,7 +3,7 @@ FROM python:3.10-slim
|
|||
WORKDIR /app
|
||||
|
||||
# Install Poetry
|
||||
RUN apt-get update && apt-get install gcc g++ curl build-essential -y
|
||||
RUN apt-get update && apt-get install gcc g++ curl build-essential postgresql-server-dev-all -y
|
||||
RUN curl -sSL https://install.python-poetry.org | python3 -
|
||||
# # Add Poetry to PATH
|
||||
ENV PATH="${PATH}:/root/.local/bin"
|
||||
|
|
|
|||
25
poetry.lock
generated
25
poetry.lock
generated
|
|
@ -2823,6 +2823,29 @@ files = [
|
|||
[package.extras]
|
||||
test = ["enum34", "ipaddress", "mock", "pywin32", "wmi"]
|
||||
|
||||
[[package]]
|
||||
name = "psycopg2"
|
||||
version = "2.9.6"
|
||||
description = "psycopg2 - Python-PostgreSQL Database Adapter"
|
||||
category = "main"
|
||||
optional = false
|
||||
python-versions = ">=3.6"
|
||||
files = [
|
||||
{file = "psycopg2-2.9.6-cp310-cp310-win32.whl", hash = "sha256:f7a7a5ee78ba7dc74265ba69e010ae89dae635eea0e97b055fb641a01a31d2b1"},
|
||||
{file = "psycopg2-2.9.6-cp310-cp310-win_amd64.whl", hash = "sha256:f75001a1cbbe523e00b0ef896a5a1ada2da93ccd752b7636db5a99bc57c44494"},
|
||||
{file = "psycopg2-2.9.6-cp311-cp311-win32.whl", hash = "sha256:53f4ad0a3988f983e9b49a5d9765d663bbe84f508ed655affdb810af9d0972ad"},
|
||||
{file = "psycopg2-2.9.6-cp311-cp311-win_amd64.whl", hash = "sha256:b81fcb9ecfc584f661b71c889edeae70bae30d3ef74fa0ca388ecda50b1222b7"},
|
||||
{file = "psycopg2-2.9.6-cp36-cp36m-win32.whl", hash = "sha256:11aca705ec888e4f4cea97289a0bf0f22a067a32614f6ef64fcf7b8bfbc53744"},
|
||||
{file = "psycopg2-2.9.6-cp36-cp36m-win_amd64.whl", hash = "sha256:36c941a767341d11549c0fbdbb2bf5be2eda4caf87f65dfcd7d146828bd27f39"},
|
||||
{file = "psycopg2-2.9.6-cp37-cp37m-win32.whl", hash = "sha256:869776630c04f335d4124f120b7fb377fe44b0a7645ab3c34b4ba42516951889"},
|
||||
{file = "psycopg2-2.9.6-cp37-cp37m-win_amd64.whl", hash = "sha256:a8ad4a47f42aa6aec8d061fdae21eaed8d864d4bb0f0cade5ad32ca16fcd6258"},
|
||||
{file = "psycopg2-2.9.6-cp38-cp38-win32.whl", hash = "sha256:2362ee4d07ac85ff0ad93e22c693d0f37ff63e28f0615a16b6635a645f4b9214"},
|
||||
{file = "psycopg2-2.9.6-cp38-cp38-win_amd64.whl", hash = "sha256:d24ead3716a7d093b90b27b3d73459fe8cd90fd7065cf43b3c40966221d8c394"},
|
||||
{file = "psycopg2-2.9.6-cp39-cp39-win32.whl", hash = "sha256:1861a53a6a0fd248e42ea37c957d36950da00266378746588eab4f4b5649e95f"},
|
||||
{file = "psycopg2-2.9.6-cp39-cp39-win_amd64.whl", hash = "sha256:ded2faa2e6dfb430af7713d87ab4abbfc764d8d7fb73eafe96a24155f906ebf5"},
|
||||
{file = "psycopg2-2.9.6.tar.gz", hash = "sha256:f15158418fd826831b28585e2ab48ed8df2d0d98f502a2b4fe619e7d5ca29011"},
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "ptyprocess"
|
||||
version = "0.7.0"
|
||||
|
|
@ -4712,4 +4735,4 @@ cffi = ["cffi (>=1.11)"]
|
|||
[metadata]
|
||||
lock-version = "2.0"
|
||||
python-versions = "^3.9"
|
||||
content-hash = "313af2197643e0cbac777d28d0a996e1f61f7c1f809a375b5196cc5942c5cc4b"
|
||||
content-hash = "9cc95816c966ef64f89ff6a825618c0051113e9465d76f74059f757df8d26c07"
|
||||
|
|
|
|||
|
|
@ -45,6 +45,7 @@ lxml = "^4.9.2"
|
|||
pysrt = "^1.1.2"
|
||||
fake-useragent = "^1.1.3"
|
||||
docstring-parser = "^0.15"
|
||||
psycopg2 = "^2.9.6"
|
||||
|
||||
[tool.poetry.group.dev.dependencies]
|
||||
black = "^23.1.0"
|
||||
|
|
|
|||
|
|
@ -14,6 +14,7 @@ agents:
|
|||
- initialize_agent
|
||||
- VectorStoreAgent
|
||||
- VectorStoreRouterAgent
|
||||
- SQLAgent
|
||||
|
||||
prompts:
|
||||
- PromptTemplate
|
||||
|
|
|
|||
|
|
@ -10,6 +10,7 @@ CUSTOM_NODES = {
|
|||
"initialize_agent": nodes.InitializeAgentNode(),
|
||||
"VectorStoreAgent": nodes.VectorStoreAgentNode(),
|
||||
"VectorStoreRouterAgent": nodes.VectorStoreRouterAgentNode(),
|
||||
"SQLAgent": nodes.SQLAgentNode(),
|
||||
},
|
||||
}
|
||||
|
||||
|
|
|
|||
|
|
@ -33,8 +33,8 @@ class AgentNode(Node):
|
|||
|
||||
self._build()
|
||||
|
||||
#! Cannot deepcopy VectorStore
|
||||
if self.node_type in ["VectorStoreAgent", "VectorStoreRouterAgent"]:
|
||||
#! Cannot deepcopy VectorStore, VectorStoreRouter, or SQL agents
|
||||
if self.node_type in ["VectorStoreAgent", "VectorStoreRouterAgent", "SQLAgent"]:
|
||||
return self._built_object
|
||||
return deepcopy(self._built_object)
|
||||
|
||||
|
|
|
|||
|
|
@ -1,8 +1,14 @@
|
|||
from typing import Any, List, Optional
|
||||
|
||||
from langchain import LLMChain
|
||||
from langchain.agents import AgentExecutor, Tool, ZeroShotAgent, initialize_agent
|
||||
from langchain.agents import (
|
||||
AgentExecutor,
|
||||
Tool,
|
||||
ZeroShotAgent,
|
||||
initialize_agent,
|
||||
)
|
||||
from langchain.agents.agent_toolkits import (
|
||||
SQLDatabaseToolkit,
|
||||
VectorStoreInfo,
|
||||
VectorStoreRouterToolkit,
|
||||
VectorStoreToolkit,
|
||||
|
|
@ -11,6 +17,7 @@ from langchain.agents.agent_toolkits.json.prompt import JSON_PREFIX, JSON_SUFFIX
|
|||
from langchain.agents.agent_toolkits.json.toolkit import JsonToolkit
|
||||
from langchain.agents.agent_toolkits.pandas.prompt import PREFIX as PANDAS_PREFIX
|
||||
from langchain.agents.agent_toolkits.pandas.prompt import SUFFIX as PANDAS_SUFFIX
|
||||
from langchain.agents.agent_toolkits.sql.prompt import SQL_PREFIX, SQL_SUFFIX
|
||||
from langchain.agents.agent_toolkits.vectorstore.prompt import (
|
||||
PREFIX as VECTORSTORE_PREFIX,
|
||||
)
|
||||
|
|
@ -18,10 +25,13 @@ from langchain.agents.agent_toolkits.vectorstore.prompt import (
|
|||
ROUTER_PREFIX as VECTORSTORE_ROUTER_PREFIX,
|
||||
)
|
||||
from langchain.agents.mrkl.prompt import FORMAT_INSTRUCTIONS
|
||||
from langchain.agents.mrkl.prompt import FORMAT_INSTRUCTIONS as SQL_FORMAT_INSTRUCTIONS
|
||||
from langchain.llms.base import BaseLLM
|
||||
from langchain.memory.chat_memory import BaseChatMemory
|
||||
from langchain.schema import BaseLanguageModel
|
||||
from langchain.sql_database import SQLDatabase
|
||||
from langchain.tools.python.tool import PythonAstREPLTool
|
||||
from langchain.tools.sql_database.prompt import QUERY_CHECKER
|
||||
|
||||
|
||||
class JsonAgent(AgentExecutor):
|
||||
|
|
@ -146,6 +156,76 @@ class VectorStoreAgent(AgentExecutor):
|
|||
return super().run(*args, **kwargs)
|
||||
|
||||
|
||||
class SQLAgent(AgentExecutor):
|
||||
"""SQL agent"""
|
||||
|
||||
@staticmethod
|
||||
def function_name():
|
||||
return "SQLAgent"
|
||||
|
||||
@classmethod
|
||||
def initialize(cls, *args, **kwargs):
|
||||
return cls.from_toolkit_and_llm(*args, **kwargs)
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
super().__init__(*args, **kwargs)
|
||||
|
||||
@classmethod
|
||||
def from_toolkit_and_llm(cls, llm: BaseLLM, database_uri: str, **kwargs: Any):
|
||||
"""Construct a sql agent from an LLM and tools."""
|
||||
db = SQLDatabase.from_uri(database_uri)
|
||||
toolkit = SQLDatabaseToolkit(db=db)
|
||||
|
||||
# The right code should be this, but there is a problem with tools = toolkit.get_tools()
|
||||
# related to `OPENAI_API_KEY`
|
||||
# return create_sql_agent(llm=llm, toolkit=toolkit, verbose=True)
|
||||
from langchain.prompts import PromptTemplate
|
||||
from langchain.tools.sql_database.tool import (
|
||||
InfoSQLDatabaseTool,
|
||||
ListSQLDatabaseTool,
|
||||
QueryCheckerTool,
|
||||
QuerySQLDataBaseTool,
|
||||
)
|
||||
|
||||
llmchain = LLMChain(
|
||||
llm=llm,
|
||||
prompt=PromptTemplate(
|
||||
template=QUERY_CHECKER, input_variables=["query", "dialect"]
|
||||
),
|
||||
)
|
||||
|
||||
tools = [
|
||||
QuerySQLDataBaseTool(db=db), # type: ignore
|
||||
InfoSQLDatabaseTool(db=db), # type: ignore
|
||||
ListSQLDatabaseTool(db=db), # type: ignore
|
||||
QueryCheckerTool(db=db, llm_chain=llmchain), # type: ignore
|
||||
]
|
||||
|
||||
prefix = SQL_PREFIX.format(dialect=toolkit.dialect, top_k=10)
|
||||
prompt = ZeroShotAgent.create_prompt(
|
||||
tools=tools, # type: ignore
|
||||
prefix=prefix,
|
||||
suffix=SQL_SUFFIX,
|
||||
format_instructions=SQL_FORMAT_INSTRUCTIONS,
|
||||
)
|
||||
llm_chain = LLMChain(
|
||||
llm=llm,
|
||||
prompt=prompt,
|
||||
)
|
||||
tool_names = [tool.name for tool in tools] # type: ignore
|
||||
agent = ZeroShotAgent(llm_chain=llm_chain, allowed_tools=tool_names, **kwargs)
|
||||
return AgentExecutor.from_agent_and_tools(
|
||||
agent=agent,
|
||||
tools=tools, # type: ignore
|
||||
verbose=True,
|
||||
max_iterations=15,
|
||||
early_stopping_method="force",
|
||||
)
|
||||
|
||||
def run(self, *args, **kwargs):
|
||||
return super().run(*args, **kwargs)
|
||||
|
||||
|
||||
class VectorStoreRouterAgent(AgentExecutor):
|
||||
"""Vector Store Router Agent"""
|
||||
|
||||
|
|
@ -218,4 +298,5 @@ CUSTOM_AGENTS = {
|
|||
"initialize_agent": InitializeAgent,
|
||||
"VectorStoreAgent": VectorStoreAgent,
|
||||
"VectorStoreRouterAgent": VectorStoreRouterAgent,
|
||||
"SQLAgent": SQLAgent,
|
||||
}
|
||||
|
|
|
|||
|
|
@ -304,6 +304,37 @@ class VectorStoreRouterAgentNode(FrontendNode):
|
|||
return super().to_dict()
|
||||
|
||||
|
||||
class SQLAgentNode(FrontendNode):
|
||||
name: str = "SQLAgent"
|
||||
template: Template = Template(
|
||||
type_name="sql_agent",
|
||||
fields=[
|
||||
TemplateField(
|
||||
field_type="str",
|
||||
required=True,
|
||||
placeholder="",
|
||||
is_list=False,
|
||||
show=True,
|
||||
multiline=False,
|
||||
value="",
|
||||
name="database_uri",
|
||||
),
|
||||
TemplateField(
|
||||
field_type="BaseLanguageModel",
|
||||
required=True,
|
||||
show=True,
|
||||
name="llm",
|
||||
display_name="LLM",
|
||||
),
|
||||
],
|
||||
)
|
||||
description: str = """Construct an agent from a Vector Store Router."""
|
||||
base_classes: list[str] = ["AgentExecutor"]
|
||||
|
||||
def to_dict(self):
|
||||
return super().to_dict()
|
||||
|
||||
|
||||
class PromptFrontendNode(FrontendNode):
|
||||
@staticmethod
|
||||
def format_field(field: TemplateField, name: Optional[str] = None) -> None:
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue