From ec9dd59b0b49a3b09d9071250a0a4528ccf1835d Mon Sep 17 00:00:00 2001 From: Gabriel Luiz Freitas Almeida Date: Wed, 21 Feb 2024 16:00:23 -0300 Subject: [PATCH] Refactor validate.py: Add PROMPT_INPUT_TYPES and rearrange imports --- src/backend/langflow/api/v1/validate.py | 54 ++++++++++++++++++------- 1 file changed, 39 insertions(+), 15 deletions(-) diff --git a/src/backend/langflow/api/v1/validate.py b/src/backend/langflow/api/v1/validate.py index 1cebeedb6..f1c2010db 100644 --- a/src/backend/langflow/api/v1/validate.py +++ b/src/backend/langflow/api/v1/validate.py @@ -1,6 +1,4 @@ from fastapi import APIRouter, HTTPException -from loguru import logger - from langflow.api.v1.base import ( Code, CodeValidationResponse, @@ -9,7 +7,8 @@ from langflow.api.v1.base import ( validate_prompt, ) from langflow.template.field.base import TemplateField -from langflow.utils.validate import validate_code +from langflow.utils.validate import PROMPT_INPUT_TYPES, validate_code +from loguru import logger # build router router = APIRouter(prefix="/validate", tags=["Validate"]) @@ -41,7 +40,9 @@ def post_validate_prompt(prompt_request: ValidatePromptRequest): add_new_variables_to_template(input_variables, prompt_request) - remove_old_variables_from_template(old_custom_fields, input_variables, prompt_request) + remove_old_variables_from_template( + old_custom_fields, input_variables, prompt_request + ) update_input_variables_field(input_variables, prompt_request) @@ -56,12 +57,19 @@ def post_validate_prompt(prompt_request: ValidatePromptRequest): def get_old_custom_fields(prompt_request): try: - if len(prompt_request.frontend_node.custom_fields) == 1 and prompt_request.name == "": + if ( + len(prompt_request.frontend_node.custom_fields) == 1 + and prompt_request.name == "" + ): # If there is only one custom field and the name is empty string # then we are dealing with the first prompt request after the node was created - prompt_request.name = list(prompt_request.frontend_node.custom_fields.keys())[0] + prompt_request.name = list( + prompt_request.frontend_node.custom_fields.keys() + )[0] - old_custom_fields = prompt_request.frontend_node.custom_fields[prompt_request.name] + old_custom_fields = prompt_request.frontend_node.custom_fields[ + prompt_request.name + ] if old_custom_fields is None: old_custom_fields = [] @@ -82,31 +90,45 @@ def add_new_variables_to_template(input_variables, prompt_request): show=True, advanced=False, multiline=True, - input_types=["Document", "BaseOutputParser", "Text"], + input_types=PROMPT_INPUT_TYPES, value="", # Set the value to empty string ) if variable in prompt_request.frontend_node.template: # Set the new field with the old value - template_field.value = prompt_request.frontend_node.template[variable]["value"] + template_field.value = prompt_request.frontend_node.template[variable][ + "value" + ] prompt_request.frontend_node.template[variable] = template_field.to_dict() # Check if variable is not already in the list before appending - if variable not in prompt_request.frontend_node.custom_fields[prompt_request.name]: - prompt_request.frontend_node.custom_fields[prompt_request.name].append(variable) + if ( + variable + not in prompt_request.frontend_node.custom_fields[prompt_request.name] + ): + prompt_request.frontend_node.custom_fields[prompt_request.name].append( + variable + ) except Exception as exc: logger.exception(exc) raise HTTPException(status_code=500, detail=str(exc)) from exc -def remove_old_variables_from_template(old_custom_fields, input_variables, prompt_request): +def remove_old_variables_from_template( + old_custom_fields, input_variables, prompt_request +): for variable in old_custom_fields: if variable not in input_variables: try: # Remove the variable from custom_fields associated with the given name - if variable in prompt_request.frontend_node.custom_fields[prompt_request.name]: - prompt_request.frontend_node.custom_fields[prompt_request.name].remove(variable) + if ( + variable + in prompt_request.frontend_node.custom_fields[prompt_request.name] + ): + prompt_request.frontend_node.custom_fields[ + prompt_request.name + ].remove(variable) # Remove the variable from the template prompt_request.frontend_node.template.pop(variable, None) @@ -118,4 +140,6 @@ def remove_old_variables_from_template(old_custom_fields, input_variables, promp def update_input_variables_field(input_variables, prompt_request): if "input_variables" in prompt_request.frontend_node.template: - prompt_request.frontend_node.template["input_variables"]["value"] = input_variables + prompt_request.frontend_node.template["input_variables"][ + "value" + ] = input_variables