Refactor SQLGeneratorComponent to handle prompt template

This commit is contained in:
Gabriel Luiz Freitas Almeida 2024-02-26 20:29:12 -03:00
commit 67aca6dd36

View file

@ -32,21 +32,39 @@ class SQLGeneratorComponent(CustomComponent):
db: SQLDatabase,
llm: BaseLanguageModel,
top_k: int = 5,
prompt: Optional[PromptTemplate] = None,
prompt: Optional[Text] = None,
) -> Text:
if prompt:
prompt_template = PromptTemplate.from_template(template=prompt)
else:
prompt_template = None
if top_k > 0:
kwargs = {
"k": top_k,
}
if not prompt:
if not prompt_template:
sql_query_chain = create_sql_query_chain(llm=llm, db=db, **kwargs)
else:
template = prompt.template if hasattr(prompt, "template") else prompt
template = (
prompt_template.template
if hasattr(prompt, "template")
else prompt_template
)
# Check if {question} is in the prompt
if "{question}" not in template or "question" not in template.input_variables:
raise ValueError("Prompt must contain `{question}` to be used with Natural Language to SQL.")
sql_query_chain = create_sql_query_chain(llm=llm, db=db, prompt=prompt, **kwargs)
query_writer = sql_query_chain | {"query": lambda x: x.replace("SQLQuery:", "").strip()}
if (
"{question}" not in template
or "question" not in template.input_variables
):
raise ValueError(
"Prompt must contain `{question}` to be used with Natural Language to SQL."
)
sql_query_chain = create_sql_query_chain(
llm=llm, db=db, prompt=prompt_template, **kwargs
)
query_writer = sql_query_chain | {
"query": lambda x: x.replace("SQLQuery:", "").strip()
}
response = query_writer.invoke({"question": inputs})
query = response.get("query")
self.status = query