From 224f5b436e0c6a573ac7ffe90cfb88c922899b2e Mon Sep 17 00:00:00 2001 From: Gabriel Luiz Freitas Almeida Date: Thu, 7 Mar 2024 12:27:43 -0300 Subject: [PATCH] Refactor prompt validation and variable handling --- src/backend/langflow/api/v1/base.py | 38 ++++------ src/backend/langflow/api/v1/validate.py | 88 ++++++---------------- src/backend/langflow/base/prompts/utils.py | 86 +++++++++++++++++++++ 3 files changed, 126 insertions(+), 86 deletions(-) diff --git a/src/backend/langflow/api/v1/base.py b/src/backend/langflow/api/v1/base.py index cc16c6d1b..bad43c437 100644 --- a/src/backend/langflow/api/v1/base.py +++ b/src/backend/langflow/api/v1/base.py @@ -1,9 +1,7 @@ from typing import Optional -from langchain.prompts import PromptTemplate from pydantic import BaseModel, field_validator, model_serializer -from langflow.interface.utils import extract_input_variables_from_prompt from langflow.template.frontend_node.base import FrontendNode @@ -80,22 +78,6 @@ INVALID_NAMES = { } -def validate_prompt(template: str): - input_variables = extract_input_variables_from_prompt(template) - - # Check if there are invalid characters in the input_variables - input_variables = check_input_variables(input_variables) - if any(var in INVALID_NAMES for var in input_variables): - raise ValueError(f"Invalid input variables. None of the variables can be named {', '.join(input_variables)}. ") - - try: - PromptTemplate(template=template, input_variables=input_variables) - except Exception as exc: - raise ValueError(f"Invalid prompt: {exc}") from exc - - return input_variables - - def is_json_like(var): if var.startswith("{{") and var.endswith("}}"): # If it is a double brance variable @@ -121,7 +103,9 @@ def fix_variable(var, invalid_chars, wrong_variables): # Handle variables starting with a number if var[0].isdigit(): invalid_chars.append(var[0]) - new_var, invalid_chars, wrong_variables = fix_variable(var[1:], invalid_chars, wrong_variables) + new_var, invalid_chars, wrong_variables = fix_variable( + var[1:], invalid_chars, wrong_variables + ) # Temporarily replace {{ and }} to avoid treating them as invalid new_var = new_var.replace("{{", "ᴛᴇᴍᴘᴏᴘᴇɴ").replace("}}", "ᴛᴇᴍᴘᴄʟᴏsᴇ") @@ -148,7 +132,9 @@ def check_variable(var, invalid_chars, wrong_variables, empty_variables): return wrong_variables, empty_variables -def check_for_errors(input_variables, fixed_variables, wrong_variables, empty_variables): +def check_for_errors( + input_variables, fixed_variables, wrong_variables, empty_variables +): if any(var for var in input_variables if var not in fixed_variables): error_message = ( f"Error: Input variables contain invalid characters or formats. \n" @@ -173,11 +159,17 @@ def check_input_variables(input_variables): if is_json_like(var): continue - new_var, wrong_variables, empty_variables = fix_variable(var, invalid_chars, wrong_variables) - wrong_variables, empty_variables = check_variable(var, INVALID_CHARACTERS, wrong_variables, empty_variables) + new_var, wrong_variables, empty_variables = fix_variable( + var, invalid_chars, wrong_variables + ) + wrong_variables, empty_variables = check_variable( + var, INVALID_CHARACTERS, wrong_variables, empty_variables + ) fixed_variables.append(new_var) variables_to_check.append(var) - check_for_errors(variables_to_check, fixed_variables, wrong_variables, empty_variables) + check_for_errors( + variables_to_check, fixed_variables, wrong_variables, empty_variables + ) return fixed_variables diff --git a/src/backend/langflow/api/v1/validate.py b/src/backend/langflow/api/v1/validate.py index 02c17686b..b7b43c376 100644 --- a/src/backend/langflow/api/v1/validate.py +++ b/src/backend/langflow/api/v1/validate.py @@ -6,9 +6,14 @@ from langflow.api.v1.base import ( CodeValidationResponse, PromptValidationResponse, ValidatePromptRequest, +) +from langflow.base.prompts.utils import ( + add_new_variables_to_template, + get_old_custom_fields, + remove_old_variables_from_template, + update_input_variables_field, validate_prompt, ) -from langflow.template.field.prompt import DefaultPromptField from langflow.utils.validate import validate_code # build router @@ -37,13 +42,28 @@ def post_validate_prompt(prompt_request: ValidatePromptRequest): input_variables=input_variables, frontend_node=None, ) - old_custom_fields = get_old_custom_fields(prompt_request) + old_custom_fields = get_old_custom_fields( + prompt_request.custom_fields, prompt_request.name + ) - add_new_variables_to_template(input_variables, prompt_request) + add_new_variables_to_template( + input_variables, + prompt_request.custom_fields, + prompt_request.frontend_node.template, + prompt_request.name, + ) - remove_old_variables_from_template(old_custom_fields, input_variables, prompt_request) + remove_old_variables_from_template( + old_custom_fields, + input_variables, + prompt_request.custom_fields, + prompt_request.frontend_node.template, + prompt_request.name, + ) - update_input_variables_field(input_variables, prompt_request) + update_input_variables_field( + input_variables, prompt_request.frontend_node.template + ) return PromptValidationResponse( input_variables=input_variables, @@ -52,61 +72,3 @@ def post_validate_prompt(prompt_request: ValidatePromptRequest): except Exception as e: logger.exception(e) raise HTTPException(status_code=500, detail=str(e)) from e - - -def get_old_custom_fields(prompt_request): - try: - if len(prompt_request.frontend_node.custom_fields) == 1 and prompt_request.name == "": - # If there is only one custom field and the name is empty string - # then we are dealing with the first prompt request after the node was created - prompt_request.name = list(prompt_request.frontend_node.custom_fields.keys())[0] - - old_custom_fields = prompt_request.frontend_node.custom_fields[prompt_request.name] - if old_custom_fields is None: - old_custom_fields = [] - - old_custom_fields = old_custom_fields.copy() - except KeyError: - old_custom_fields = [] - prompt_request.frontend_node.custom_fields[prompt_request.name] = [] - return old_custom_fields - - -def add_new_variables_to_template(input_variables, prompt_request): - for variable in input_variables: - try: - template_field = DefaultPromptField(name=variable, display_name=variable) - if variable in prompt_request.frontend_node.template: - # Set the new field with the old value - template_field.value = prompt_request.frontend_node.template[variable]["value"] - - prompt_request.frontend_node.template[variable] = template_field.to_dict() - - # Check if variable is not already in the list before appending - if variable not in prompt_request.frontend_node.custom_fields[prompt_request.name]: - prompt_request.frontend_node.custom_fields[prompt_request.name].append(variable) - - except Exception as exc: - logger.exception(exc) - raise HTTPException(status_code=500, detail=str(exc)) from exc - - -def remove_old_variables_from_template(old_custom_fields, input_variables, prompt_request): - for variable in old_custom_fields: - if variable not in input_variables: - try: - # Remove the variable from custom_fields associated with the given name - if variable in prompt_request.frontend_node.custom_fields[prompt_request.name]: - prompt_request.frontend_node.custom_fields[prompt_request.name].remove(variable) - - # Remove the variable from the template - prompt_request.frontend_node.template.pop(variable, None) - - except Exception as exc: - logger.exception(exc) - raise HTTPException(status_code=500, detail=str(exc)) from exc - - -def update_input_variables_field(input_variables, prompt_request): - if "input_variables" in prompt_request.frontend_node.template: - prompt_request.frontend_node.template["input_variables"]["value"] = input_variables diff --git a/src/backend/langflow/base/prompts/utils.py b/src/backend/langflow/base/prompts/utils.py index 1f41ebda1..c30d2d6a2 100644 --- a/src/backend/langflow/base/prompts/utils.py +++ b/src/backend/langflow/base/prompts/utils.py @@ -1,6 +1,12 @@ +from fastapi import HTTPException +from langchain.prompts import PromptTemplate from langchain_core.documents import Document +from loguru import logger +from langflow.api.v1.base import INVALID_NAMES, check_input_variables +from langflow.interface.utils import extract_input_variables_from_prompt from langflow.schema import Record +from langflow.template.field.prompt import DefaultPromptField def dict_values_to_string(d: dict) -> dict: @@ -53,3 +59,83 @@ def document_to_string(document: Document) -> str: str: The document as a string. """ return document.page_content + + +def validate_prompt(prompt_template: str, silent_errors: bool = False) -> list[str]: + input_variables = extract_input_variables_from_prompt(prompt_template) + + # Check if there are invalid characters in the input_variables + input_variables = check_input_variables(input_variables) + if any(var in INVALID_NAMES for var in input_variables): + raise ValueError( + f"Invalid input variables. None of the variables can be named {', '.join(input_variables)}. " + ) + + try: + PromptTemplate(template=prompt_template, input_variables=input_variables) + except Exception as exc: + logger.error(f"Invalid prompt: {exc}") + if not silent_errors: + raise ValueError(f"Invalid prompt: {exc}") from exc + + return input_variables + + +def get_old_custom_fields(custom_fields, name): + try: + if len(custom_fields) == 1 and name == "": + # If there is only one custom field and the name is empty string + # then we are dealing with the first prompt request after the node was created + name = list(custom_fields.keys())[0] + + old_custom_fields = custom_fields[name] + if not old_custom_fields: + old_custom_fields = [] + + old_custom_fields = old_custom_fields.copy() + except KeyError: + old_custom_fields = [] + custom_fields[name] = [] + return old_custom_fields + + +def add_new_variables_to_template(input_variables, custom_fields, template, name): + for variable in input_variables: + try: + template_field = DefaultPromptField(name=variable, display_name=variable) + if variable in template: + # Set the new field with the old value + template_field.value = template[variable]["value"] + + template[variable] = template_field.to_dict() + + # Check if variable is not already in the list before appending + if variable not in custom_fields[name]: + custom_fields[name].append(variable) + + except Exception as exc: + logger.exception(exc) + raise HTTPException(status_code=500, detail=str(exc)) from exc + + +def remove_old_variables_from_template( + old_custom_fields, input_variables, custom_fields, template, name +): + for variable in old_custom_fields: + if variable not in input_variables: + try: + # Remove the variable from custom_fields associated with the given name + if variable in custom_fields[name]: + custom_fields[name].remove(variable) + + # Remove the variable from the template + template.pop(variable, None) + + except Exception as exc: + logger.exception(exc) + raise HTTPException(status_code=500, detail=str(exc)) from exc + + +def update_input_variables_field(input_variables, template): + if "input_variables" in template: + template["input_variables"]["value"] = input_variables