diff --git a/dev.Dockerfile b/dev.Dockerfile index be9688f75..b38929db2 100644 --- a/dev.Dockerfile +++ b/dev.Dockerfile @@ -3,7 +3,7 @@ FROM python:3.10-slim WORKDIR /app # Install Poetry -RUN apt-get update && apt-get install gcc g++ curl build-essential -y +RUN apt-get update && apt-get install gcc g++ curl build-essential postgresql-server-dev-all -y RUN curl -sSL https://install.python-poetry.org | python3 - # # Add Poetry to PATH ENV PATH="${PATH}:/root/.local/bin" diff --git a/poetry.lock b/poetry.lock index 27ade8589..5361880ee 100644 --- a/poetry.lock +++ b/poetry.lock @@ -2823,6 +2823,29 @@ files = [ [package.extras] test = ["enum34", "ipaddress", "mock", "pywin32", "wmi"] +[[package]] +name = "psycopg2" +version = "2.9.6" +description = "psycopg2 - Python-PostgreSQL Database Adapter" +category = "main" +optional = false +python-versions = ">=3.6" +files = [ + {file = "psycopg2-2.9.6-cp310-cp310-win32.whl", hash = "sha256:f7a7a5ee78ba7dc74265ba69e010ae89dae635eea0e97b055fb641a01a31d2b1"}, + {file = "psycopg2-2.9.6-cp310-cp310-win_amd64.whl", hash = "sha256:f75001a1cbbe523e00b0ef896a5a1ada2da93ccd752b7636db5a99bc57c44494"}, + {file = "psycopg2-2.9.6-cp311-cp311-win32.whl", hash = "sha256:53f4ad0a3988f983e9b49a5d9765d663bbe84f508ed655affdb810af9d0972ad"}, + {file = "psycopg2-2.9.6-cp311-cp311-win_amd64.whl", hash = "sha256:b81fcb9ecfc584f661b71c889edeae70bae30d3ef74fa0ca388ecda50b1222b7"}, + {file = "psycopg2-2.9.6-cp36-cp36m-win32.whl", hash = "sha256:11aca705ec888e4f4cea97289a0bf0f22a067a32614f6ef64fcf7b8bfbc53744"}, + {file = "psycopg2-2.9.6-cp36-cp36m-win_amd64.whl", hash = "sha256:36c941a767341d11549c0fbdbb2bf5be2eda4caf87f65dfcd7d146828bd27f39"}, + {file = "psycopg2-2.9.6-cp37-cp37m-win32.whl", hash = "sha256:869776630c04f335d4124f120b7fb377fe44b0a7645ab3c34b4ba42516951889"}, + {file = "psycopg2-2.9.6-cp37-cp37m-win_amd64.whl", hash = "sha256:a8ad4a47f42aa6aec8d061fdae21eaed8d864d4bb0f0cade5ad32ca16fcd6258"}, + {file = "psycopg2-2.9.6-cp38-cp38-win32.whl", hash = "sha256:2362ee4d07ac85ff0ad93e22c693d0f37ff63e28f0615a16b6635a645f4b9214"}, + {file = "psycopg2-2.9.6-cp38-cp38-win_amd64.whl", hash = "sha256:d24ead3716a7d093b90b27b3d73459fe8cd90fd7065cf43b3c40966221d8c394"}, + {file = "psycopg2-2.9.6-cp39-cp39-win32.whl", hash = "sha256:1861a53a6a0fd248e42ea37c957d36950da00266378746588eab4f4b5649e95f"}, + {file = "psycopg2-2.9.6-cp39-cp39-win_amd64.whl", hash = "sha256:ded2faa2e6dfb430af7713d87ab4abbfc764d8d7fb73eafe96a24155f906ebf5"}, + {file = "psycopg2-2.9.6.tar.gz", hash = "sha256:f15158418fd826831b28585e2ab48ed8df2d0d98f502a2b4fe619e7d5ca29011"}, +] + [[package]] name = "ptyprocess" version = "0.7.0" @@ -4712,4 +4735,4 @@ cffi = ["cffi (>=1.11)"] [metadata] lock-version = "2.0" python-versions = "^3.9" -content-hash = "313af2197643e0cbac777d28d0a996e1f61f7c1f809a375b5196cc5942c5cc4b" +content-hash = "9cc95816c966ef64f89ff6a825618c0051113e9465d76f74059f757df8d26c07" diff --git a/pyproject.toml b/pyproject.toml index b069afc3a..a8959cf0b 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -45,6 +45,7 @@ lxml = "^4.9.2" pysrt = "^1.1.2" fake-useragent = "^1.1.3" docstring-parser = "^0.15" +psycopg2 = "^2.9.6" [tool.poetry.group.dev.dependencies] black = "^23.1.0" diff --git a/src/backend/langflow/config.yaml b/src/backend/langflow/config.yaml index b3ee373e7..7f00167b4 100644 --- a/src/backend/langflow/config.yaml +++ b/src/backend/langflow/config.yaml @@ -14,6 +14,7 @@ agents: - initialize_agent - VectorStoreAgent - VectorStoreRouterAgent + - SQLAgent prompts: - PromptTemplate diff --git a/src/backend/langflow/custom/customs.py b/src/backend/langflow/custom/customs.py index 000bf890b..e77b81ec6 100644 --- a/src/backend/langflow/custom/customs.py +++ b/src/backend/langflow/custom/customs.py @@ -10,6 +10,7 @@ CUSTOM_NODES = { "initialize_agent": nodes.InitializeAgentNode(), "VectorStoreAgent": nodes.VectorStoreAgentNode(), "VectorStoreRouterAgent": nodes.VectorStoreRouterAgentNode(), + "SQLAgent": nodes.SQLAgentNode(), }, } diff --git a/src/backend/langflow/graph/nodes.py b/src/backend/langflow/graph/nodes.py index 5c762b540..7296a0c0d 100644 --- a/src/backend/langflow/graph/nodes.py +++ b/src/backend/langflow/graph/nodes.py @@ -33,8 +33,8 @@ class AgentNode(Node): self._build() - #! Cannot deepcopy VectorStore - if self.node_type in ["VectorStoreAgent", "VectorStoreRouterAgent"]: + #! Cannot deepcopy VectorStore, VectorStoreRouter, or SQL agents + if self.node_type in ["VectorStoreAgent", "VectorStoreRouterAgent", "SQLAgent"]: return self._built_object return deepcopy(self._built_object) diff --git a/src/backend/langflow/interface/agents/custom.py b/src/backend/langflow/interface/agents/custom.py index 9f6d15257..851bd2af8 100644 --- a/src/backend/langflow/interface/agents/custom.py +++ b/src/backend/langflow/interface/agents/custom.py @@ -1,8 +1,14 @@ from typing import Any, List, Optional from langchain import LLMChain -from langchain.agents import AgentExecutor, Tool, ZeroShotAgent, initialize_agent +from langchain.agents import ( + AgentExecutor, + Tool, + ZeroShotAgent, + initialize_agent, +) from langchain.agents.agent_toolkits import ( + SQLDatabaseToolkit, VectorStoreInfo, VectorStoreRouterToolkit, VectorStoreToolkit, @@ -11,6 +17,7 @@ from langchain.agents.agent_toolkits.json.prompt import JSON_PREFIX, JSON_SUFFIX from langchain.agents.agent_toolkits.json.toolkit import JsonToolkit from langchain.agents.agent_toolkits.pandas.prompt import PREFIX as PANDAS_PREFIX from langchain.agents.agent_toolkits.pandas.prompt import SUFFIX as PANDAS_SUFFIX +from langchain.agents.agent_toolkits.sql.prompt import SQL_PREFIX, SQL_SUFFIX from langchain.agents.agent_toolkits.vectorstore.prompt import ( PREFIX as VECTORSTORE_PREFIX, ) @@ -18,10 +25,13 @@ from langchain.agents.agent_toolkits.vectorstore.prompt import ( ROUTER_PREFIX as VECTORSTORE_ROUTER_PREFIX, ) from langchain.agents.mrkl.prompt import FORMAT_INSTRUCTIONS +from langchain.agents.mrkl.prompt import FORMAT_INSTRUCTIONS as SQL_FORMAT_INSTRUCTIONS from langchain.llms.base import BaseLLM from langchain.memory.chat_memory import BaseChatMemory from langchain.schema import BaseLanguageModel +from langchain.sql_database import SQLDatabase from langchain.tools.python.tool import PythonAstREPLTool +from langchain.tools.sql_database.prompt import QUERY_CHECKER class JsonAgent(AgentExecutor): @@ -146,6 +156,76 @@ class VectorStoreAgent(AgentExecutor): return super().run(*args, **kwargs) +class SQLAgent(AgentExecutor): + """SQL agent""" + + @staticmethod + def function_name(): + return "SQLAgent" + + @classmethod + def initialize(cls, *args, **kwargs): + return cls.from_toolkit_and_llm(*args, **kwargs) + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + + @classmethod + def from_toolkit_and_llm(cls, llm: BaseLLM, database_uri: str, **kwargs: Any): + """Construct a sql agent from an LLM and tools.""" + db = SQLDatabase.from_uri(database_uri) + toolkit = SQLDatabaseToolkit(db=db) + + # The right code should be this, but there is a problem with tools = toolkit.get_tools() + # related to `OPENAI_API_KEY` + # return create_sql_agent(llm=llm, toolkit=toolkit, verbose=True) + from langchain.prompts import PromptTemplate + from langchain.tools.sql_database.tool import ( + InfoSQLDatabaseTool, + ListSQLDatabaseTool, + QueryCheckerTool, + QuerySQLDataBaseTool, + ) + + llmchain = LLMChain( + llm=llm, + prompt=PromptTemplate( + template=QUERY_CHECKER, input_variables=["query", "dialect"] + ), + ) + + tools = [ + QuerySQLDataBaseTool(db=db), # type: ignore + InfoSQLDatabaseTool(db=db), # type: ignore + ListSQLDatabaseTool(db=db), # type: ignore + QueryCheckerTool(db=db, llm_chain=llmchain), # type: ignore + ] + + prefix = SQL_PREFIX.format(dialect=toolkit.dialect, top_k=10) + prompt = ZeroShotAgent.create_prompt( + tools=tools, # type: ignore + prefix=prefix, + suffix=SQL_SUFFIX, + format_instructions=SQL_FORMAT_INSTRUCTIONS, + ) + llm_chain = LLMChain( + llm=llm, + prompt=prompt, + ) + tool_names = [tool.name for tool in tools] # type: ignore + agent = ZeroShotAgent(llm_chain=llm_chain, allowed_tools=tool_names, **kwargs) + return AgentExecutor.from_agent_and_tools( + agent=agent, + tools=tools, # type: ignore + verbose=True, + max_iterations=15, + early_stopping_method="force", + ) + + def run(self, *args, **kwargs): + return super().run(*args, **kwargs) + + class VectorStoreRouterAgent(AgentExecutor): """Vector Store Router Agent""" @@ -218,4 +298,5 @@ CUSTOM_AGENTS = { "initialize_agent": InitializeAgent, "VectorStoreAgent": VectorStoreAgent, "VectorStoreRouterAgent": VectorStoreRouterAgent, + "SQLAgent": SQLAgent, } diff --git a/src/backend/langflow/template/nodes.py b/src/backend/langflow/template/nodes.py index e45f16531..c31f3c40f 100644 --- a/src/backend/langflow/template/nodes.py +++ b/src/backend/langflow/template/nodes.py @@ -304,6 +304,37 @@ class VectorStoreRouterAgentNode(FrontendNode): return super().to_dict() +class SQLAgentNode(FrontendNode): + name: str = "SQLAgent" + template: Template = Template( + type_name="sql_agent", + fields=[ + TemplateField( + field_type="str", + required=True, + placeholder="", + is_list=False, + show=True, + multiline=False, + value="", + name="database_uri", + ), + TemplateField( + field_type="BaseLanguageModel", + required=True, + show=True, + name="llm", + display_name="LLM", + ), + ], + ) + description: str = """Construct an agent from a Vector Store Router.""" + base_classes: list[str] = ["AgentExecutor"] + + def to_dict(self): + return super().to_dict() + + class PromptFrontendNode(FrontendNode): @staticmethod def format_field(field: TemplateField, name: Optional[str] = None) -> None: