feat: add SQL agent

This commit is contained in:
Ibis Prevedello 2023-04-13 22:33:46 -03:00
commit ec0c6c79c1
8 changed files with 143 additions and 5 deletions

View file

@ -3,7 +3,7 @@ FROM python:3.10-slim
WORKDIR /app
# Install Poetry
RUN apt-get update && apt-get install gcc g++ curl build-essential -y
RUN apt-get update && apt-get install gcc g++ curl build-essential postgresql-server-dev-all -y
RUN curl -sSL https://install.python-poetry.org | python3 -
# # Add Poetry to PATH
ENV PATH="${PATH}:/root/.local/bin"

25
poetry.lock generated
View file

@ -2823,6 +2823,29 @@ files = [
[package.extras]
test = ["enum34", "ipaddress", "mock", "pywin32", "wmi"]
[[package]]
name = "psycopg2"
version = "2.9.6"
description = "psycopg2 - Python-PostgreSQL Database Adapter"
category = "main"
optional = false
python-versions = ">=3.6"
files = [
{file = "psycopg2-2.9.6-cp310-cp310-win32.whl", hash = "sha256:f7a7a5ee78ba7dc74265ba69e010ae89dae635eea0e97b055fb641a01a31d2b1"},
{file = "psycopg2-2.9.6-cp310-cp310-win_amd64.whl", hash = "sha256:f75001a1cbbe523e00b0ef896a5a1ada2da93ccd752b7636db5a99bc57c44494"},
{file = "psycopg2-2.9.6-cp311-cp311-win32.whl", hash = "sha256:53f4ad0a3988f983e9b49a5d9765d663bbe84f508ed655affdb810af9d0972ad"},
{file = "psycopg2-2.9.6-cp311-cp311-win_amd64.whl", hash = "sha256:b81fcb9ecfc584f661b71c889edeae70bae30d3ef74fa0ca388ecda50b1222b7"},
{file = "psycopg2-2.9.6-cp36-cp36m-win32.whl", hash = "sha256:11aca705ec888e4f4cea97289a0bf0f22a067a32614f6ef64fcf7b8bfbc53744"},
{file = "psycopg2-2.9.6-cp36-cp36m-win_amd64.whl", hash = "sha256:36c941a767341d11549c0fbdbb2bf5be2eda4caf87f65dfcd7d146828bd27f39"},
{file = "psycopg2-2.9.6-cp37-cp37m-win32.whl", hash = "sha256:869776630c04f335d4124f120b7fb377fe44b0a7645ab3c34b4ba42516951889"},
{file = "psycopg2-2.9.6-cp37-cp37m-win_amd64.whl", hash = "sha256:a8ad4a47f42aa6aec8d061fdae21eaed8d864d4bb0f0cade5ad32ca16fcd6258"},
{file = "psycopg2-2.9.6-cp38-cp38-win32.whl", hash = "sha256:2362ee4d07ac85ff0ad93e22c693d0f37ff63e28f0615a16b6635a645f4b9214"},
{file = "psycopg2-2.9.6-cp38-cp38-win_amd64.whl", hash = "sha256:d24ead3716a7d093b90b27b3d73459fe8cd90fd7065cf43b3c40966221d8c394"},
{file = "psycopg2-2.9.6-cp39-cp39-win32.whl", hash = "sha256:1861a53a6a0fd248e42ea37c957d36950da00266378746588eab4f4b5649e95f"},
{file = "psycopg2-2.9.6-cp39-cp39-win_amd64.whl", hash = "sha256:ded2faa2e6dfb430af7713d87ab4abbfc764d8d7fb73eafe96a24155f906ebf5"},
{file = "psycopg2-2.9.6.tar.gz", hash = "sha256:f15158418fd826831b28585e2ab48ed8df2d0d98f502a2b4fe619e7d5ca29011"},
]
[[package]]
name = "ptyprocess"
version = "0.7.0"
@ -4712,4 +4735,4 @@ cffi = ["cffi (>=1.11)"]
[metadata]
lock-version = "2.0"
python-versions = "^3.9"
content-hash = "313af2197643e0cbac777d28d0a996e1f61f7c1f809a375b5196cc5942c5cc4b"
content-hash = "9cc95816c966ef64f89ff6a825618c0051113e9465d76f74059f757df8d26c07"

View file

@ -45,6 +45,7 @@ lxml = "^4.9.2"
pysrt = "^1.1.2"
fake-useragent = "^1.1.3"
docstring-parser = "^0.15"
psycopg2 = "^2.9.6"
[tool.poetry.group.dev.dependencies]
black = "^23.1.0"

View file

@ -14,6 +14,7 @@ agents:
- initialize_agent
- VectorStoreAgent
- VectorStoreRouterAgent
- SQLAgent
prompts:
- PromptTemplate

View file

@ -10,6 +10,7 @@ CUSTOM_NODES = {
"initialize_agent": nodes.InitializeAgentNode(),
"VectorStoreAgent": nodes.VectorStoreAgentNode(),
"VectorStoreRouterAgent": nodes.VectorStoreRouterAgentNode(),
"SQLAgent": nodes.SQLAgentNode(),
},
}

View file

@ -33,8 +33,8 @@ class AgentNode(Node):
self._build()
#! Cannot deepcopy VectorStore
if self.node_type in ["VectorStoreAgent", "VectorStoreRouterAgent"]:
#! Cannot deepcopy VectorStore, VectorStoreRouter, or SQL agents
if self.node_type in ["VectorStoreAgent", "VectorStoreRouterAgent", "SQLAgent"]:
return self._built_object
return deepcopy(self._built_object)

View file

@ -1,8 +1,14 @@
from typing import Any, List, Optional
from langchain import LLMChain
from langchain.agents import AgentExecutor, Tool, ZeroShotAgent, initialize_agent
from langchain.agents import (
AgentExecutor,
Tool,
ZeroShotAgent,
initialize_agent,
)
from langchain.agents.agent_toolkits import (
SQLDatabaseToolkit,
VectorStoreInfo,
VectorStoreRouterToolkit,
VectorStoreToolkit,
@ -11,6 +17,7 @@ from langchain.agents.agent_toolkits.json.prompt import JSON_PREFIX, JSON_SUFFIX
from langchain.agents.agent_toolkits.json.toolkit import JsonToolkit
from langchain.agents.agent_toolkits.pandas.prompt import PREFIX as PANDAS_PREFIX
from langchain.agents.agent_toolkits.pandas.prompt import SUFFIX as PANDAS_SUFFIX
from langchain.agents.agent_toolkits.sql.prompt import SQL_PREFIX, SQL_SUFFIX
from langchain.agents.agent_toolkits.vectorstore.prompt import (
PREFIX as VECTORSTORE_PREFIX,
)
@ -18,10 +25,13 @@ from langchain.agents.agent_toolkits.vectorstore.prompt import (
ROUTER_PREFIX as VECTORSTORE_ROUTER_PREFIX,
)
from langchain.agents.mrkl.prompt import FORMAT_INSTRUCTIONS
from langchain.agents.mrkl.prompt import FORMAT_INSTRUCTIONS as SQL_FORMAT_INSTRUCTIONS
from langchain.llms.base import BaseLLM
from langchain.memory.chat_memory import BaseChatMemory
from langchain.schema import BaseLanguageModel
from langchain.sql_database import SQLDatabase
from langchain.tools.python.tool import PythonAstREPLTool
from langchain.tools.sql_database.prompt import QUERY_CHECKER
class JsonAgent(AgentExecutor):
@ -146,6 +156,76 @@ class VectorStoreAgent(AgentExecutor):
return super().run(*args, **kwargs)
class SQLAgent(AgentExecutor):
"""SQL agent"""
@staticmethod
def function_name():
return "SQLAgent"
@classmethod
def initialize(cls, *args, **kwargs):
return cls.from_toolkit_and_llm(*args, **kwargs)
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
@classmethod
def from_toolkit_and_llm(cls, llm: BaseLLM, database_uri: str, **kwargs: Any):
"""Construct a sql agent from an LLM and tools."""
db = SQLDatabase.from_uri(database_uri)
toolkit = SQLDatabaseToolkit(db=db)
# The right code should be this, but there is a problem with tools = toolkit.get_tools()
# related to `OPENAI_API_KEY`
# return create_sql_agent(llm=llm, toolkit=toolkit, verbose=True)
from langchain.prompts import PromptTemplate
from langchain.tools.sql_database.tool import (
InfoSQLDatabaseTool,
ListSQLDatabaseTool,
QueryCheckerTool,
QuerySQLDataBaseTool,
)
llmchain = LLMChain(
llm=llm,
prompt=PromptTemplate(
template=QUERY_CHECKER, input_variables=["query", "dialect"]
),
)
tools = [
QuerySQLDataBaseTool(db=db), # type: ignore
InfoSQLDatabaseTool(db=db), # type: ignore
ListSQLDatabaseTool(db=db), # type: ignore
QueryCheckerTool(db=db, llm_chain=llmchain), # type: ignore
]
prefix = SQL_PREFIX.format(dialect=toolkit.dialect, top_k=10)
prompt = ZeroShotAgent.create_prompt(
tools=tools, # type: ignore
prefix=prefix,
suffix=SQL_SUFFIX,
format_instructions=SQL_FORMAT_INSTRUCTIONS,
)
llm_chain = LLMChain(
llm=llm,
prompt=prompt,
)
tool_names = [tool.name for tool in tools] # type: ignore
agent = ZeroShotAgent(llm_chain=llm_chain, allowed_tools=tool_names, **kwargs)
return AgentExecutor.from_agent_and_tools(
agent=agent,
tools=tools, # type: ignore
verbose=True,
max_iterations=15,
early_stopping_method="force",
)
def run(self, *args, **kwargs):
return super().run(*args, **kwargs)
class VectorStoreRouterAgent(AgentExecutor):
"""Vector Store Router Agent"""
@ -218,4 +298,5 @@ CUSTOM_AGENTS = {
"initialize_agent": InitializeAgent,
"VectorStoreAgent": VectorStoreAgent,
"VectorStoreRouterAgent": VectorStoreRouterAgent,
"SQLAgent": SQLAgent,
}

View file

@ -304,6 +304,37 @@ class VectorStoreRouterAgentNode(FrontendNode):
return super().to_dict()
class SQLAgentNode(FrontendNode):
name: str = "SQLAgent"
template: Template = Template(
type_name="sql_agent",
fields=[
TemplateField(
field_type="str",
required=True,
placeholder="",
is_list=False,
show=True,
multiline=False,
value="",
name="database_uri",
),
TemplateField(
field_type="BaseLanguageModel",
required=True,
show=True,
name="llm",
display_name="LLM",
),
],
)
description: str = """Construct an agent from a Vector Store Router."""
base_classes: list[str] = ["AgentExecutor"]
def to_dict(self):
return super().to_dict()
class PromptFrontendNode(FrontendNode):
@staticmethod
def format_field(field: TemplateField, name: Optional[str] = None) -> None: