Refactor prompt validation and variable handling

This commit is contained in:
Gabriel Luiz Freitas Almeida 2024-03-07 12:27:43 -03:00
commit 224f5b436e
3 changed files with 126 additions and 86 deletions

View file

@ -1,9 +1,7 @@
from typing import Optional
from langchain.prompts import PromptTemplate
from pydantic import BaseModel, field_validator, model_serializer
from langflow.interface.utils import extract_input_variables_from_prompt
from langflow.template.frontend_node.base import FrontendNode
@ -80,22 +78,6 @@ INVALID_NAMES = {
}
def validate_prompt(template: str):
input_variables = extract_input_variables_from_prompt(template)
# Check if there are invalid characters in the input_variables
input_variables = check_input_variables(input_variables)
if any(var in INVALID_NAMES for var in input_variables):
raise ValueError(f"Invalid input variables. None of the variables can be named {', '.join(input_variables)}. ")
try:
PromptTemplate(template=template, input_variables=input_variables)
except Exception as exc:
raise ValueError(f"Invalid prompt: {exc}") from exc
return input_variables
def is_json_like(var):
if var.startswith("{{") and var.endswith("}}"):
# If it is a double brance variable
@ -121,7 +103,9 @@ def fix_variable(var, invalid_chars, wrong_variables):
# Handle variables starting with a number
if var[0].isdigit():
invalid_chars.append(var[0])
new_var, invalid_chars, wrong_variables = fix_variable(var[1:], invalid_chars, wrong_variables)
new_var, invalid_chars, wrong_variables = fix_variable(
var[1:], invalid_chars, wrong_variables
)
# Temporarily replace {{ and }} to avoid treating them as invalid
new_var = new_var.replace("{{", "ᴛᴇᴍᴘᴏᴘᴇɴ").replace("}}", "ᴛᴇᴍᴘʟsᴇ")
@ -148,7 +132,9 @@ def check_variable(var, invalid_chars, wrong_variables, empty_variables):
return wrong_variables, empty_variables
def check_for_errors(input_variables, fixed_variables, wrong_variables, empty_variables):
def check_for_errors(
input_variables, fixed_variables, wrong_variables, empty_variables
):
if any(var for var in input_variables if var not in fixed_variables):
error_message = (
f"Error: Input variables contain invalid characters or formats. \n"
@ -173,11 +159,17 @@ def check_input_variables(input_variables):
if is_json_like(var):
continue
new_var, wrong_variables, empty_variables = fix_variable(var, invalid_chars, wrong_variables)
wrong_variables, empty_variables = check_variable(var, INVALID_CHARACTERS, wrong_variables, empty_variables)
new_var, wrong_variables, empty_variables = fix_variable(
var, invalid_chars, wrong_variables
)
wrong_variables, empty_variables = check_variable(
var, INVALID_CHARACTERS, wrong_variables, empty_variables
)
fixed_variables.append(new_var)
variables_to_check.append(var)
check_for_errors(variables_to_check, fixed_variables, wrong_variables, empty_variables)
check_for_errors(
variables_to_check, fixed_variables, wrong_variables, empty_variables
)
return fixed_variables

View file

@ -6,9 +6,14 @@ from langflow.api.v1.base import (
CodeValidationResponse,
PromptValidationResponse,
ValidatePromptRequest,
)
from langflow.base.prompts.utils import (
add_new_variables_to_template,
get_old_custom_fields,
remove_old_variables_from_template,
update_input_variables_field,
validate_prompt,
)
from langflow.template.field.prompt import DefaultPromptField
from langflow.utils.validate import validate_code
# build router
@ -37,13 +42,28 @@ def post_validate_prompt(prompt_request: ValidatePromptRequest):
input_variables=input_variables,
frontend_node=None,
)
old_custom_fields = get_old_custom_fields(prompt_request)
old_custom_fields = get_old_custom_fields(
prompt_request.custom_fields, prompt_request.name
)
add_new_variables_to_template(input_variables, prompt_request)
add_new_variables_to_template(
input_variables,
prompt_request.custom_fields,
prompt_request.frontend_node.template,
prompt_request.name,
)
remove_old_variables_from_template(old_custom_fields, input_variables, prompt_request)
remove_old_variables_from_template(
old_custom_fields,
input_variables,
prompt_request.custom_fields,
prompt_request.frontend_node.template,
prompt_request.name,
)
update_input_variables_field(input_variables, prompt_request)
update_input_variables_field(
input_variables, prompt_request.frontend_node.template
)
return PromptValidationResponse(
input_variables=input_variables,
@ -52,61 +72,3 @@ def post_validate_prompt(prompt_request: ValidatePromptRequest):
except Exception as e:
logger.exception(e)
raise HTTPException(status_code=500, detail=str(e)) from e
def get_old_custom_fields(prompt_request):
try:
if len(prompt_request.frontend_node.custom_fields) == 1 and prompt_request.name == "":
# If there is only one custom field and the name is empty string
# then we are dealing with the first prompt request after the node was created
prompt_request.name = list(prompt_request.frontend_node.custom_fields.keys())[0]
old_custom_fields = prompt_request.frontend_node.custom_fields[prompt_request.name]
if old_custom_fields is None:
old_custom_fields = []
old_custom_fields = old_custom_fields.copy()
except KeyError:
old_custom_fields = []
prompt_request.frontend_node.custom_fields[prompt_request.name] = []
return old_custom_fields
def add_new_variables_to_template(input_variables, prompt_request):
for variable in input_variables:
try:
template_field = DefaultPromptField(name=variable, display_name=variable)
if variable in prompt_request.frontend_node.template:
# Set the new field with the old value
template_field.value = prompt_request.frontend_node.template[variable]["value"]
prompt_request.frontend_node.template[variable] = template_field.to_dict()
# Check if variable is not already in the list before appending
if variable not in prompt_request.frontend_node.custom_fields[prompt_request.name]:
prompt_request.frontend_node.custom_fields[prompt_request.name].append(variable)
except Exception as exc:
logger.exception(exc)
raise HTTPException(status_code=500, detail=str(exc)) from exc
def remove_old_variables_from_template(old_custom_fields, input_variables, prompt_request):
for variable in old_custom_fields:
if variable not in input_variables:
try:
# Remove the variable from custom_fields associated with the given name
if variable in prompt_request.frontend_node.custom_fields[prompt_request.name]:
prompt_request.frontend_node.custom_fields[prompt_request.name].remove(variable)
# Remove the variable from the template
prompt_request.frontend_node.template.pop(variable, None)
except Exception as exc:
logger.exception(exc)
raise HTTPException(status_code=500, detail=str(exc)) from exc
def update_input_variables_field(input_variables, prompt_request):
if "input_variables" in prompt_request.frontend_node.template:
prompt_request.frontend_node.template["input_variables"]["value"] = input_variables

View file

@ -1,6 +1,12 @@
from fastapi import HTTPException
from langchain.prompts import PromptTemplate
from langchain_core.documents import Document
from loguru import logger
from langflow.api.v1.base import INVALID_NAMES, check_input_variables
from langflow.interface.utils import extract_input_variables_from_prompt
from langflow.schema import Record
from langflow.template.field.prompt import DefaultPromptField
def dict_values_to_string(d: dict) -> dict:
@ -53,3 +59,83 @@ def document_to_string(document: Document) -> str:
str: The document as a string.
"""
return document.page_content
def validate_prompt(prompt_template: str, silent_errors: bool = False) -> list[str]:
input_variables = extract_input_variables_from_prompt(prompt_template)
# Check if there are invalid characters in the input_variables
input_variables = check_input_variables(input_variables)
if any(var in INVALID_NAMES for var in input_variables):
raise ValueError(
f"Invalid input variables. None of the variables can be named {', '.join(input_variables)}. "
)
try:
PromptTemplate(template=prompt_template, input_variables=input_variables)
except Exception as exc:
logger.error(f"Invalid prompt: {exc}")
if not silent_errors:
raise ValueError(f"Invalid prompt: {exc}") from exc
return input_variables
def get_old_custom_fields(custom_fields, name):
try:
if len(custom_fields) == 1 and name == "":
# If there is only one custom field and the name is empty string
# then we are dealing with the first prompt request after the node was created
name = list(custom_fields.keys())[0]
old_custom_fields = custom_fields[name]
if not old_custom_fields:
old_custom_fields = []
old_custom_fields = old_custom_fields.copy()
except KeyError:
old_custom_fields = []
custom_fields[name] = []
return old_custom_fields
def add_new_variables_to_template(input_variables, custom_fields, template, name):
for variable in input_variables:
try:
template_field = DefaultPromptField(name=variable, display_name=variable)
if variable in template:
# Set the new field with the old value
template_field.value = template[variable]["value"]
template[variable] = template_field.to_dict()
# Check if variable is not already in the list before appending
if variable not in custom_fields[name]:
custom_fields[name].append(variable)
except Exception as exc:
logger.exception(exc)
raise HTTPException(status_code=500, detail=str(exc)) from exc
def remove_old_variables_from_template(
old_custom_fields, input_variables, custom_fields, template, name
):
for variable in old_custom_fields:
if variable not in input_variables:
try:
# Remove the variable from custom_fields associated with the given name
if variable in custom_fields[name]:
custom_fields[name].remove(variable)
# Remove the variable from the template
template.pop(variable, None)
except Exception as exc:
logger.exception(exc)
raise HTTPException(status_code=500, detail=str(exc)) from exc
def update_input_variables_field(input_variables, template):
if "input_variables" in template:
template["input_variables"]["value"] = input_variables