From d6cdb3eb969df202b2aed87bfd4511e3c6ae5707 Mon Sep 17 00:00:00 2001 From: Gabriel Luiz Freitas Almeida Date: Mon, 26 Feb 2024 13:23:13 -0300 Subject: [PATCH] Add SQLDatabaseChain and SQLGenerator components --- .../components/chains/SQLDatabaseChain.py | 25 -------- .../components/chains/SQLGenerator.py | 62 +++++++++++++++++++ .../components/utilities/SQLDatabase.py | 22 +++++++ .../components/utilities/SQLExecutor.py | 56 +++++++++++++++++ 4 files changed, 140 insertions(+), 25 deletions(-) delete mode 100644 src/backend/langflow/components/chains/SQLDatabaseChain.py create mode 100644 src/backend/langflow/components/chains/SQLGenerator.py create mode 100644 src/backend/langflow/components/utilities/SQLDatabase.py create mode 100644 src/backend/langflow/components/utilities/SQLExecutor.py diff --git a/src/backend/langflow/components/chains/SQLDatabaseChain.py b/src/backend/langflow/components/chains/SQLDatabaseChain.py deleted file mode 100644 index 56bd433ba..000000000 --- a/src/backend/langflow/components/chains/SQLDatabaseChain.py +++ /dev/null @@ -1,25 +0,0 @@ -from langflow import CustomComponent -from typing import Callable, Union -from langflow.field_typing import BasePromptTemplate, BaseLanguageModel, Chain -from langchain_community.utilities.sql_database import SQLDatabase -from langchain_experimental.sql.base import SQLDatabaseChain - - -class SQLDatabaseChainComponent(CustomComponent): - display_name = "SQLDatabaseChain" - description = "" - - def build_config(self): - return { - "db": {"display_name": "Database"}, - "llm": {"display_name": "LLM"}, - "prompt": {"display_name": "Prompt"}, - } - - def build( - self, - db: SQLDatabase, - llm: BaseLanguageModel, - prompt: BasePromptTemplate, - ) -> Union[Chain, Callable, SQLDatabaseChain]: - return SQLDatabaseChain.from_llm(llm=llm, db=db, prompt=prompt) diff --git a/src/backend/langflow/components/chains/SQLGenerator.py b/src/backend/langflow/components/chains/SQLGenerator.py new file mode 100644 index 000000000..3b6347ca4 --- /dev/null +++ b/src/backend/langflow/components/chains/SQLGenerator.py @@ -0,0 +1,62 @@ +from typing import Optional + +from langchain.chains import create_sql_query_chain +from langchain_community.utilities.sql_database import SQLDatabase +from langchain_core.prompts import PromptTemplate + +from langflow import CustomComponent +from langflow.field_typing import BaseLanguageModel, Text + + +class SQLGeneratorComponent(CustomComponent): + display_name = "Natural Language to SQL" + description = "Generate SQL from natural language." + + def build_config(self): + return { + "db": {"display_name": "Database"}, + "llm": {"display_name": "LLM"}, + "prompt": { + "display_name": "Prompt", + "info": "The prompt must contain `{question}`.", + }, + "top_k": { + "display_name": "Top K", + "info": "The number of results per select statement to return. If 0, no limit.", + }, + } + + def build( + self, + inputs: Text, + db: SQLDatabase, + llm: BaseLanguageModel, + top_k: int = 5, + prompt: Optional[PromptTemplate] = None, + ) -> Text: + if top_k > 0: + kwargs = { + "k": top_k, + } + if not prompt: + sql_query_chain = create_sql_query_chain(llm=llm, db=db, **kwargs) + else: + template = prompt.template if hasattr(prompt, "template") else prompt + # Check if {question} is in the prompt + if ( + "{question}" not in template + or "question" not in template.input_variables + ): + raise ValueError( + "Prompt must contain `{question}` to be used with Natural Language to SQL." + ) + sql_query_chain = create_sql_query_chain( + llm=llm, db=db, prompt=prompt, **kwargs + ) + query_writer = sql_query_chain | { + "query": lambda x: x.replace("SQLQuery:", "").strip() + } + response = query_writer.invoke({"question": inputs}) + query = response.get("query") + self.status = query + return query diff --git a/src/backend/langflow/components/utilities/SQLDatabase.py b/src/backend/langflow/components/utilities/SQLDatabase.py new file mode 100644 index 000000000..ddd1be318 --- /dev/null +++ b/src/backend/langflow/components/utilities/SQLDatabase.py @@ -0,0 +1,22 @@ +from langchain_experimental.sql.base import SQLDatabase + +from langflow import CustomComponent + + +class SQLDatabaseComponent(CustomComponent): + display_name = "SQLDatabase" + description = "SQL Database" + + def build_config(self): + return { + "uri": {"display_name": "URI", "info": "URI to the database."}, + } + + def clean_up_uri(self, uri: str) -> str: + if uri.startswith("postgresql://"): + uri = uri.replace("postgresql://", "postgres://") + return uri.strip() + + def build(self, uri: str) -> SQLDatabase: + uri = self.clean_up_uri(uri) + return SQLDatabase.from_uri(uri) diff --git a/src/backend/langflow/components/utilities/SQLExecutor.py b/src/backend/langflow/components/utilities/SQLExecutor.py new file mode 100644 index 000000000..da6fe98fc --- /dev/null +++ b/src/backend/langflow/components/utilities/SQLExecutor.py @@ -0,0 +1,56 @@ +from langchain_community.tools.sql_database.tool import QuerySQLDataBaseTool +from langchain_experimental.sql.base import SQLDatabase + +from langflow import CustomComponent +from langflow.field_typing import Text + + +class SQLExecutorComponent(CustomComponent): + display_name = "SQL Executor" + description = "Execute SQL query." + + def build_config(self): + return { + "database": {"display_name": "Database"}, + "include_columns": { + "display_name": "Include Columns", + "info": "Include columns in the result.", + }, + "passthrough": { + "display_name": "Passthrough", + "info": "If an error occurs, return the query instead of raising an exception.", + }, + "add_error": { + "display_name": "Add Error", + "info": "Add the error to the result.", + }, + } + + def build( + self, + query: str, + database: SQLDatabase, + include_columns: bool = False, + passthrough: bool = False, + add_error: bool = False, + ) -> Text: + error = None + try: + tool = QuerySQLDataBaseTool(db=database) + result = tool.run(query, include_columns=include_columns) + self.status = result + except Exception as e: + result = str(e) + self.status = result + if not passthrough: + raise e + error = repr(e) + + if add_error and error is not None: + result = f"{result}\n\nError: {error}\n\nQuery: {query}" + elif error is not None: + # Then we won't add the error to the result + # but since we are in passthrough mode, we will return the query + result = query + + return result