Add SQLDatabaseChain and SQLGenerator components
This commit is contained in:
parent
e3c5112899
commit
d6cdb3eb96
4 changed files with 140 additions and 25 deletions
|
|
@ -1,25 +0,0 @@
|
|||
from langflow import CustomComponent
|
||||
from typing import Callable, Union
|
||||
from langflow.field_typing import BasePromptTemplate, BaseLanguageModel, Chain
|
||||
from langchain_community.utilities.sql_database import SQLDatabase
|
||||
from langchain_experimental.sql.base import SQLDatabaseChain
|
||||
|
||||
|
||||
class SQLDatabaseChainComponent(CustomComponent):
|
||||
display_name = "SQLDatabaseChain"
|
||||
description = ""
|
||||
|
||||
def build_config(self):
|
||||
return {
|
||||
"db": {"display_name": "Database"},
|
||||
"llm": {"display_name": "LLM"},
|
||||
"prompt": {"display_name": "Prompt"},
|
||||
}
|
||||
|
||||
def build(
|
||||
self,
|
||||
db: SQLDatabase,
|
||||
llm: BaseLanguageModel,
|
||||
prompt: BasePromptTemplate,
|
||||
) -> Union[Chain, Callable, SQLDatabaseChain]:
|
||||
return SQLDatabaseChain.from_llm(llm=llm, db=db, prompt=prompt)
|
||||
62
src/backend/langflow/components/chains/SQLGenerator.py
Normal file
62
src/backend/langflow/components/chains/SQLGenerator.py
Normal file
|
|
@ -0,0 +1,62 @@
|
|||
from typing import Optional
|
||||
|
||||
from langchain.chains import create_sql_query_chain
|
||||
from langchain_community.utilities.sql_database import SQLDatabase
|
||||
from langchain_core.prompts import PromptTemplate
|
||||
|
||||
from langflow import CustomComponent
|
||||
from langflow.field_typing import BaseLanguageModel, Text
|
||||
|
||||
|
||||
class SQLGeneratorComponent(CustomComponent):
|
||||
display_name = "Natural Language to SQL"
|
||||
description = "Generate SQL from natural language."
|
||||
|
||||
def build_config(self):
|
||||
return {
|
||||
"db": {"display_name": "Database"},
|
||||
"llm": {"display_name": "LLM"},
|
||||
"prompt": {
|
||||
"display_name": "Prompt",
|
||||
"info": "The prompt must contain `{question}`.",
|
||||
},
|
||||
"top_k": {
|
||||
"display_name": "Top K",
|
||||
"info": "The number of results per select statement to return. If 0, no limit.",
|
||||
},
|
||||
}
|
||||
|
||||
def build(
|
||||
self,
|
||||
inputs: Text,
|
||||
db: SQLDatabase,
|
||||
llm: BaseLanguageModel,
|
||||
top_k: int = 5,
|
||||
prompt: Optional[PromptTemplate] = None,
|
||||
) -> Text:
|
||||
if top_k > 0:
|
||||
kwargs = {
|
||||
"k": top_k,
|
||||
}
|
||||
if not prompt:
|
||||
sql_query_chain = create_sql_query_chain(llm=llm, db=db, **kwargs)
|
||||
else:
|
||||
template = prompt.template if hasattr(prompt, "template") else prompt
|
||||
# Check if {question} is in the prompt
|
||||
if (
|
||||
"{question}" not in template
|
||||
or "question" not in template.input_variables
|
||||
):
|
||||
raise ValueError(
|
||||
"Prompt must contain `{question}` to be used with Natural Language to SQL."
|
||||
)
|
||||
sql_query_chain = create_sql_query_chain(
|
||||
llm=llm, db=db, prompt=prompt, **kwargs
|
||||
)
|
||||
query_writer = sql_query_chain | {
|
||||
"query": lambda x: x.replace("SQLQuery:", "").strip()
|
||||
}
|
||||
response = query_writer.invoke({"question": inputs})
|
||||
query = response.get("query")
|
||||
self.status = query
|
||||
return query
|
||||
22
src/backend/langflow/components/utilities/SQLDatabase.py
Normal file
22
src/backend/langflow/components/utilities/SQLDatabase.py
Normal file
|
|
@ -0,0 +1,22 @@
|
|||
from langchain_experimental.sql.base import SQLDatabase
|
||||
|
||||
from langflow import CustomComponent
|
||||
|
||||
|
||||
class SQLDatabaseComponent(CustomComponent):
|
||||
display_name = "SQLDatabase"
|
||||
description = "SQL Database"
|
||||
|
||||
def build_config(self):
|
||||
return {
|
||||
"uri": {"display_name": "URI", "info": "URI to the database."},
|
||||
}
|
||||
|
||||
def clean_up_uri(self, uri: str) -> str:
|
||||
if uri.startswith("postgresql://"):
|
||||
uri = uri.replace("postgresql://", "postgres://")
|
||||
return uri.strip()
|
||||
|
||||
def build(self, uri: str) -> SQLDatabase:
|
||||
uri = self.clean_up_uri(uri)
|
||||
return SQLDatabase.from_uri(uri)
|
||||
56
src/backend/langflow/components/utilities/SQLExecutor.py
Normal file
56
src/backend/langflow/components/utilities/SQLExecutor.py
Normal file
|
|
@ -0,0 +1,56 @@
|
|||
from langchain_community.tools.sql_database.tool import QuerySQLDataBaseTool
|
||||
from langchain_experimental.sql.base import SQLDatabase
|
||||
|
||||
from langflow import CustomComponent
|
||||
from langflow.field_typing import Text
|
||||
|
||||
|
||||
class SQLExecutorComponent(CustomComponent):
|
||||
display_name = "SQL Executor"
|
||||
description = "Execute SQL query."
|
||||
|
||||
def build_config(self):
|
||||
return {
|
||||
"database": {"display_name": "Database"},
|
||||
"include_columns": {
|
||||
"display_name": "Include Columns",
|
||||
"info": "Include columns in the result.",
|
||||
},
|
||||
"passthrough": {
|
||||
"display_name": "Passthrough",
|
||||
"info": "If an error occurs, return the query instead of raising an exception.",
|
||||
},
|
||||
"add_error": {
|
||||
"display_name": "Add Error",
|
||||
"info": "Add the error to the result.",
|
||||
},
|
||||
}
|
||||
|
||||
def build(
|
||||
self,
|
||||
query: str,
|
||||
database: SQLDatabase,
|
||||
include_columns: bool = False,
|
||||
passthrough: bool = False,
|
||||
add_error: bool = False,
|
||||
) -> Text:
|
||||
error = None
|
||||
try:
|
||||
tool = QuerySQLDataBaseTool(db=database)
|
||||
result = tool.run(query, include_columns=include_columns)
|
||||
self.status = result
|
||||
except Exception as e:
|
||||
result = str(e)
|
||||
self.status = result
|
||||
if not passthrough:
|
||||
raise e
|
||||
error = repr(e)
|
||||
|
||||
if add_error and error is not None:
|
||||
result = f"{result}\n\nError: {error}\n\nQuery: {query}"
|
||||
elif error is not None:
|
||||
# Then we won't add the error to the result
|
||||
# but since we are in passthrough mode, we will return the query
|
||||
result = query
|
||||
|
||||
return result
|
||||
Loading…
Add table
Add a link
Reference in a new issue