Refactor prompt validation and variable handling
This commit is contained in:
parent
716b6cf4b7
commit
224f5b436e
3 changed files with 126 additions and 86 deletions
|
|
@ -1,9 +1,7 @@
|
|||
from typing import Optional
|
||||
|
||||
from langchain.prompts import PromptTemplate
|
||||
from pydantic import BaseModel, field_validator, model_serializer
|
||||
|
||||
from langflow.interface.utils import extract_input_variables_from_prompt
|
||||
from langflow.template.frontend_node.base import FrontendNode
|
||||
|
||||
|
||||
|
|
@ -80,22 +78,6 @@ INVALID_NAMES = {
|
|||
}
|
||||
|
||||
|
||||
def validate_prompt(template: str):
|
||||
input_variables = extract_input_variables_from_prompt(template)
|
||||
|
||||
# Check if there are invalid characters in the input_variables
|
||||
input_variables = check_input_variables(input_variables)
|
||||
if any(var in INVALID_NAMES for var in input_variables):
|
||||
raise ValueError(f"Invalid input variables. None of the variables can be named {', '.join(input_variables)}. ")
|
||||
|
||||
try:
|
||||
PromptTemplate(template=template, input_variables=input_variables)
|
||||
except Exception as exc:
|
||||
raise ValueError(f"Invalid prompt: {exc}") from exc
|
||||
|
||||
return input_variables
|
||||
|
||||
|
||||
def is_json_like(var):
|
||||
if var.startswith("{{") and var.endswith("}}"):
|
||||
# If it is a double brance variable
|
||||
|
|
@ -121,7 +103,9 @@ def fix_variable(var, invalid_chars, wrong_variables):
|
|||
# Handle variables starting with a number
|
||||
if var[0].isdigit():
|
||||
invalid_chars.append(var[0])
|
||||
new_var, invalid_chars, wrong_variables = fix_variable(var[1:], invalid_chars, wrong_variables)
|
||||
new_var, invalid_chars, wrong_variables = fix_variable(
|
||||
var[1:], invalid_chars, wrong_variables
|
||||
)
|
||||
|
||||
# Temporarily replace {{ and }} to avoid treating them as invalid
|
||||
new_var = new_var.replace("{{", "ᴛᴇᴍᴘᴏᴘᴇɴ").replace("}}", "ᴛᴇᴍᴘᴄʟᴏsᴇ")
|
||||
|
|
@ -148,7 +132,9 @@ def check_variable(var, invalid_chars, wrong_variables, empty_variables):
|
|||
return wrong_variables, empty_variables
|
||||
|
||||
|
||||
def check_for_errors(input_variables, fixed_variables, wrong_variables, empty_variables):
|
||||
def check_for_errors(
|
||||
input_variables, fixed_variables, wrong_variables, empty_variables
|
||||
):
|
||||
if any(var for var in input_variables if var not in fixed_variables):
|
||||
error_message = (
|
||||
f"Error: Input variables contain invalid characters or formats. \n"
|
||||
|
|
@ -173,11 +159,17 @@ def check_input_variables(input_variables):
|
|||
if is_json_like(var):
|
||||
continue
|
||||
|
||||
new_var, wrong_variables, empty_variables = fix_variable(var, invalid_chars, wrong_variables)
|
||||
wrong_variables, empty_variables = check_variable(var, INVALID_CHARACTERS, wrong_variables, empty_variables)
|
||||
new_var, wrong_variables, empty_variables = fix_variable(
|
||||
var, invalid_chars, wrong_variables
|
||||
)
|
||||
wrong_variables, empty_variables = check_variable(
|
||||
var, INVALID_CHARACTERS, wrong_variables, empty_variables
|
||||
)
|
||||
fixed_variables.append(new_var)
|
||||
variables_to_check.append(var)
|
||||
|
||||
check_for_errors(variables_to_check, fixed_variables, wrong_variables, empty_variables)
|
||||
check_for_errors(
|
||||
variables_to_check, fixed_variables, wrong_variables, empty_variables
|
||||
)
|
||||
|
||||
return fixed_variables
|
||||
|
|
|
|||
|
|
@ -6,9 +6,14 @@ from langflow.api.v1.base import (
|
|||
CodeValidationResponse,
|
||||
PromptValidationResponse,
|
||||
ValidatePromptRequest,
|
||||
)
|
||||
from langflow.base.prompts.utils import (
|
||||
add_new_variables_to_template,
|
||||
get_old_custom_fields,
|
||||
remove_old_variables_from_template,
|
||||
update_input_variables_field,
|
||||
validate_prompt,
|
||||
)
|
||||
from langflow.template.field.prompt import DefaultPromptField
|
||||
from langflow.utils.validate import validate_code
|
||||
|
||||
# build router
|
||||
|
|
@ -37,13 +42,28 @@ def post_validate_prompt(prompt_request: ValidatePromptRequest):
|
|||
input_variables=input_variables,
|
||||
frontend_node=None,
|
||||
)
|
||||
old_custom_fields = get_old_custom_fields(prompt_request)
|
||||
old_custom_fields = get_old_custom_fields(
|
||||
prompt_request.custom_fields, prompt_request.name
|
||||
)
|
||||
|
||||
add_new_variables_to_template(input_variables, prompt_request)
|
||||
add_new_variables_to_template(
|
||||
input_variables,
|
||||
prompt_request.custom_fields,
|
||||
prompt_request.frontend_node.template,
|
||||
prompt_request.name,
|
||||
)
|
||||
|
||||
remove_old_variables_from_template(old_custom_fields, input_variables, prompt_request)
|
||||
remove_old_variables_from_template(
|
||||
old_custom_fields,
|
||||
input_variables,
|
||||
prompt_request.custom_fields,
|
||||
prompt_request.frontend_node.template,
|
||||
prompt_request.name,
|
||||
)
|
||||
|
||||
update_input_variables_field(input_variables, prompt_request)
|
||||
update_input_variables_field(
|
||||
input_variables, prompt_request.frontend_node.template
|
||||
)
|
||||
|
||||
return PromptValidationResponse(
|
||||
input_variables=input_variables,
|
||||
|
|
@ -52,61 +72,3 @@ def post_validate_prompt(prompt_request: ValidatePromptRequest):
|
|||
except Exception as e:
|
||||
logger.exception(e)
|
||||
raise HTTPException(status_code=500, detail=str(e)) from e
|
||||
|
||||
|
||||
def get_old_custom_fields(prompt_request):
|
||||
try:
|
||||
if len(prompt_request.frontend_node.custom_fields) == 1 and prompt_request.name == "":
|
||||
# If there is only one custom field and the name is empty string
|
||||
# then we are dealing with the first prompt request after the node was created
|
||||
prompt_request.name = list(prompt_request.frontend_node.custom_fields.keys())[0]
|
||||
|
||||
old_custom_fields = prompt_request.frontend_node.custom_fields[prompt_request.name]
|
||||
if old_custom_fields is None:
|
||||
old_custom_fields = []
|
||||
|
||||
old_custom_fields = old_custom_fields.copy()
|
||||
except KeyError:
|
||||
old_custom_fields = []
|
||||
prompt_request.frontend_node.custom_fields[prompt_request.name] = []
|
||||
return old_custom_fields
|
||||
|
||||
|
||||
def add_new_variables_to_template(input_variables, prompt_request):
|
||||
for variable in input_variables:
|
||||
try:
|
||||
template_field = DefaultPromptField(name=variable, display_name=variable)
|
||||
if variable in prompt_request.frontend_node.template:
|
||||
# Set the new field with the old value
|
||||
template_field.value = prompt_request.frontend_node.template[variable]["value"]
|
||||
|
||||
prompt_request.frontend_node.template[variable] = template_field.to_dict()
|
||||
|
||||
# Check if variable is not already in the list before appending
|
||||
if variable not in prompt_request.frontend_node.custom_fields[prompt_request.name]:
|
||||
prompt_request.frontend_node.custom_fields[prompt_request.name].append(variable)
|
||||
|
||||
except Exception as exc:
|
||||
logger.exception(exc)
|
||||
raise HTTPException(status_code=500, detail=str(exc)) from exc
|
||||
|
||||
|
||||
def remove_old_variables_from_template(old_custom_fields, input_variables, prompt_request):
|
||||
for variable in old_custom_fields:
|
||||
if variable not in input_variables:
|
||||
try:
|
||||
# Remove the variable from custom_fields associated with the given name
|
||||
if variable in prompt_request.frontend_node.custom_fields[prompt_request.name]:
|
||||
prompt_request.frontend_node.custom_fields[prompt_request.name].remove(variable)
|
||||
|
||||
# Remove the variable from the template
|
||||
prompt_request.frontend_node.template.pop(variable, None)
|
||||
|
||||
except Exception as exc:
|
||||
logger.exception(exc)
|
||||
raise HTTPException(status_code=500, detail=str(exc)) from exc
|
||||
|
||||
|
||||
def update_input_variables_field(input_variables, prompt_request):
|
||||
if "input_variables" in prompt_request.frontend_node.template:
|
||||
prompt_request.frontend_node.template["input_variables"]["value"] = input_variables
|
||||
|
|
|
|||
|
|
@ -1,6 +1,12 @@
|
|||
from fastapi import HTTPException
|
||||
from langchain.prompts import PromptTemplate
|
||||
from langchain_core.documents import Document
|
||||
from loguru import logger
|
||||
|
||||
from langflow.api.v1.base import INVALID_NAMES, check_input_variables
|
||||
from langflow.interface.utils import extract_input_variables_from_prompt
|
||||
from langflow.schema import Record
|
||||
from langflow.template.field.prompt import DefaultPromptField
|
||||
|
||||
|
||||
def dict_values_to_string(d: dict) -> dict:
|
||||
|
|
@ -53,3 +59,83 @@ def document_to_string(document: Document) -> str:
|
|||
str: The document as a string.
|
||||
"""
|
||||
return document.page_content
|
||||
|
||||
|
||||
def validate_prompt(prompt_template: str, silent_errors: bool = False) -> list[str]:
|
||||
input_variables = extract_input_variables_from_prompt(prompt_template)
|
||||
|
||||
# Check if there are invalid characters in the input_variables
|
||||
input_variables = check_input_variables(input_variables)
|
||||
if any(var in INVALID_NAMES for var in input_variables):
|
||||
raise ValueError(
|
||||
f"Invalid input variables. None of the variables can be named {', '.join(input_variables)}. "
|
||||
)
|
||||
|
||||
try:
|
||||
PromptTemplate(template=prompt_template, input_variables=input_variables)
|
||||
except Exception as exc:
|
||||
logger.error(f"Invalid prompt: {exc}")
|
||||
if not silent_errors:
|
||||
raise ValueError(f"Invalid prompt: {exc}") from exc
|
||||
|
||||
return input_variables
|
||||
|
||||
|
||||
def get_old_custom_fields(custom_fields, name):
|
||||
try:
|
||||
if len(custom_fields) == 1 and name == "":
|
||||
# If there is only one custom field and the name is empty string
|
||||
# then we are dealing with the first prompt request after the node was created
|
||||
name = list(custom_fields.keys())[0]
|
||||
|
||||
old_custom_fields = custom_fields[name]
|
||||
if not old_custom_fields:
|
||||
old_custom_fields = []
|
||||
|
||||
old_custom_fields = old_custom_fields.copy()
|
||||
except KeyError:
|
||||
old_custom_fields = []
|
||||
custom_fields[name] = []
|
||||
return old_custom_fields
|
||||
|
||||
|
||||
def add_new_variables_to_template(input_variables, custom_fields, template, name):
|
||||
for variable in input_variables:
|
||||
try:
|
||||
template_field = DefaultPromptField(name=variable, display_name=variable)
|
||||
if variable in template:
|
||||
# Set the new field with the old value
|
||||
template_field.value = template[variable]["value"]
|
||||
|
||||
template[variable] = template_field.to_dict()
|
||||
|
||||
# Check if variable is not already in the list before appending
|
||||
if variable not in custom_fields[name]:
|
||||
custom_fields[name].append(variable)
|
||||
|
||||
except Exception as exc:
|
||||
logger.exception(exc)
|
||||
raise HTTPException(status_code=500, detail=str(exc)) from exc
|
||||
|
||||
|
||||
def remove_old_variables_from_template(
|
||||
old_custom_fields, input_variables, custom_fields, template, name
|
||||
):
|
||||
for variable in old_custom_fields:
|
||||
if variable not in input_variables:
|
||||
try:
|
||||
# Remove the variable from custom_fields associated with the given name
|
||||
if variable in custom_fields[name]:
|
||||
custom_fields[name].remove(variable)
|
||||
|
||||
# Remove the variable from the template
|
||||
template.pop(variable, None)
|
||||
|
||||
except Exception as exc:
|
||||
logger.exception(exc)
|
||||
raise HTTPException(status_code=500, detail=str(exc)) from exc
|
||||
|
||||
|
||||
def update_input_variables_field(input_variables, template):
|
||||
if "input_variables" in template:
|
||||
template["input_variables"]["value"] = input_variables
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue