diff --git a/src/backend/langflow/components/chains/SQLGenerator.py b/src/backend/langflow/components/chains/SQLGenerator.py index 5efb0f738..ea22a6de0 100644 --- a/src/backend/langflow/components/chains/SQLGenerator.py +++ b/src/backend/langflow/components/chains/SQLGenerator.py @@ -32,21 +32,39 @@ class SQLGeneratorComponent(CustomComponent): db: SQLDatabase, llm: BaseLanguageModel, top_k: int = 5, - prompt: Optional[PromptTemplate] = None, + prompt: Optional[Text] = None, ) -> Text: + if prompt: + prompt_template = PromptTemplate.from_template(template=prompt) + else: + prompt_template = None + if top_k > 0: kwargs = { "k": top_k, } - if not prompt: + if not prompt_template: sql_query_chain = create_sql_query_chain(llm=llm, db=db, **kwargs) else: - template = prompt.template if hasattr(prompt, "template") else prompt + template = ( + prompt_template.template + if hasattr(prompt, "template") + else prompt_template + ) # Check if {question} is in the prompt - if "{question}" not in template or "question" not in template.input_variables: - raise ValueError("Prompt must contain `{question}` to be used with Natural Language to SQL.") - sql_query_chain = create_sql_query_chain(llm=llm, db=db, prompt=prompt, **kwargs) - query_writer = sql_query_chain | {"query": lambda x: x.replace("SQLQuery:", "").strip()} + if ( + "{question}" not in template + or "question" not in template.input_variables + ): + raise ValueError( + "Prompt must contain `{question}` to be used with Natural Language to SQL." + ) + sql_query_chain = create_sql_query_chain( + llm=llm, db=db, prompt=prompt_template, **kwargs + ) + query_writer = sql_query_chain | { + "query": lambda x: x.replace("SQLQuery:", "").strip() + } response = query_writer.invoke({"question": inputs}) query = response.get("query") self.status = query