Add SQLDatabaseChain and SQLGenerator components

This commit is contained in:
Gabriel Luiz Freitas Almeida 2024-02-26 13:23:13 -03:00
commit d6cdb3eb96
4 changed files with 140 additions and 25 deletions

View file

@ -1,25 +0,0 @@
from langflow import CustomComponent
from typing import Callable, Union
from langflow.field_typing import BasePromptTemplate, BaseLanguageModel, Chain
from langchain_community.utilities.sql_database import SQLDatabase
from langchain_experimental.sql.base import SQLDatabaseChain
class SQLDatabaseChainComponent(CustomComponent):
display_name = "SQLDatabaseChain"
description = ""
def build_config(self):
return {
"db": {"display_name": "Database"},
"llm": {"display_name": "LLM"},
"prompt": {"display_name": "Prompt"},
}
def build(
self,
db: SQLDatabase,
llm: BaseLanguageModel,
prompt: BasePromptTemplate,
) -> Union[Chain, Callable, SQLDatabaseChain]:
return SQLDatabaseChain.from_llm(llm=llm, db=db, prompt=prompt)

View file

@ -0,0 +1,62 @@
from typing import Optional
from langchain.chains import create_sql_query_chain
from langchain_community.utilities.sql_database import SQLDatabase
from langchain_core.prompts import PromptTemplate
from langflow import CustomComponent
from langflow.field_typing import BaseLanguageModel, Text
class SQLGeneratorComponent(CustomComponent):
display_name = "Natural Language to SQL"
description = "Generate SQL from natural language."
def build_config(self):
return {
"db": {"display_name": "Database"},
"llm": {"display_name": "LLM"},
"prompt": {
"display_name": "Prompt",
"info": "The prompt must contain `{question}`.",
},
"top_k": {
"display_name": "Top K",
"info": "The number of results per select statement to return. If 0, no limit.",
},
}
def build(
self,
inputs: Text,
db: SQLDatabase,
llm: BaseLanguageModel,
top_k: int = 5,
prompt: Optional[PromptTemplate] = None,
) -> Text:
if top_k > 0:
kwargs = {
"k": top_k,
}
if not prompt:
sql_query_chain = create_sql_query_chain(llm=llm, db=db, **kwargs)
else:
template = prompt.template if hasattr(prompt, "template") else prompt
# Check if {question} is in the prompt
if (
"{question}" not in template
or "question" not in template.input_variables
):
raise ValueError(
"Prompt must contain `{question}` to be used with Natural Language to SQL."
)
sql_query_chain = create_sql_query_chain(
llm=llm, db=db, prompt=prompt, **kwargs
)
query_writer = sql_query_chain | {
"query": lambda x: x.replace("SQLQuery:", "").strip()
}
response = query_writer.invoke({"question": inputs})
query = response.get("query")
self.status = query
return query

View file

@ -0,0 +1,22 @@
from langchain_experimental.sql.base import SQLDatabase
from langflow import CustomComponent
class SQLDatabaseComponent(CustomComponent):
display_name = "SQLDatabase"
description = "SQL Database"
def build_config(self):
return {
"uri": {"display_name": "URI", "info": "URI to the database."},
}
def clean_up_uri(self, uri: str) -> str:
if uri.startswith("postgresql://"):
uri = uri.replace("postgresql://", "postgres://")
return uri.strip()
def build(self, uri: str) -> SQLDatabase:
uri = self.clean_up_uri(uri)
return SQLDatabase.from_uri(uri)

View file

@ -0,0 +1,56 @@
from langchain_community.tools.sql_database.tool import QuerySQLDataBaseTool
from langchain_experimental.sql.base import SQLDatabase
from langflow import CustomComponent
from langflow.field_typing import Text
class SQLExecutorComponent(CustomComponent):
display_name = "SQL Executor"
description = "Execute SQL query."
def build_config(self):
return {
"database": {"display_name": "Database"},
"include_columns": {
"display_name": "Include Columns",
"info": "Include columns in the result.",
},
"passthrough": {
"display_name": "Passthrough",
"info": "If an error occurs, return the query instead of raising an exception.",
},
"add_error": {
"display_name": "Add Error",
"info": "Add the error to the result.",
},
}
def build(
self,
query: str,
database: SQLDatabase,
include_columns: bool = False,
passthrough: bool = False,
add_error: bool = False,
) -> Text:
error = None
try:
tool = QuerySQLDataBaseTool(db=database)
result = tool.run(query, include_columns=include_columns)
self.status = result
except Exception as e:
result = str(e)
self.status = result
if not passthrough:
raise e
error = repr(e)
if add_error and error is not None:
result = f"{result}\n\nError: {error}\n\nQuery: {query}"
elif error is not None:
# Then we won't add the error to the result
# but since we are in passthrough mode, we will return the query
result = query
return result