Refactor SQLGeneratorComponent to handle prompt template
This commit is contained in:
parent
55c595eaea
commit
67aca6dd36
1 changed files with 25 additions and 7 deletions
|
|
@ -32,21 +32,39 @@ class SQLGeneratorComponent(CustomComponent):
|
|||
db: SQLDatabase,
|
||||
llm: BaseLanguageModel,
|
||||
top_k: int = 5,
|
||||
prompt: Optional[PromptTemplate] = None,
|
||||
prompt: Optional[Text] = None,
|
||||
) -> Text:
|
||||
if prompt:
|
||||
prompt_template = PromptTemplate.from_template(template=prompt)
|
||||
else:
|
||||
prompt_template = None
|
||||
|
||||
if top_k > 0:
|
||||
kwargs = {
|
||||
"k": top_k,
|
||||
}
|
||||
if not prompt:
|
||||
if not prompt_template:
|
||||
sql_query_chain = create_sql_query_chain(llm=llm, db=db, **kwargs)
|
||||
else:
|
||||
template = prompt.template if hasattr(prompt, "template") else prompt
|
||||
template = (
|
||||
prompt_template.template
|
||||
if hasattr(prompt, "template")
|
||||
else prompt_template
|
||||
)
|
||||
# Check if {question} is in the prompt
|
||||
if "{question}" not in template or "question" not in template.input_variables:
|
||||
raise ValueError("Prompt must contain `{question}` to be used with Natural Language to SQL.")
|
||||
sql_query_chain = create_sql_query_chain(llm=llm, db=db, prompt=prompt, **kwargs)
|
||||
query_writer = sql_query_chain | {"query": lambda x: x.replace("SQLQuery:", "").strip()}
|
||||
if (
|
||||
"{question}" not in template
|
||||
or "question" not in template.input_variables
|
||||
):
|
||||
raise ValueError(
|
||||
"Prompt must contain `{question}` to be used with Natural Language to SQL."
|
||||
)
|
||||
sql_query_chain = create_sql_query_chain(
|
||||
llm=llm, db=db, prompt=prompt_template, **kwargs
|
||||
)
|
||||
query_writer = sql_query_chain | {
|
||||
"query": lambda x: x.replace("SQLQuery:", "").strip()
|
||||
}
|
||||
response = query_writer.invoke({"question": inputs})
|
||||
query = response.get("query")
|
||||
self.status = query
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue