diff --git a/src/backend/langflow/api/v1/validate.py b/src/backend/langflow/api/v1/validate.py index 96900edb7..7c9279801 100644 --- a/src/backend/langflow/api/v1/validate.py +++ b/src/backend/langflow/api/v1/validate.py @@ -28,49 +28,76 @@ def post_validate_code(code: Code): @router.post("/prompt", status_code=200, response_model=PromptValidationResponse) -def post_validate_prompt(prompt: ValidatePromptRequest): +def post_validate_prompt(prompt_request: ValidatePromptRequest): try: - input_variables = validate_prompt(prompt.template) - # Reinitialize custom_fields - old_custom_fields = prompt.frontend_node.custom_fields.copy() - prompt.frontend_node.custom_fields = [] - # Add new variables to the template - for variable in input_variables: - try: - template_field = TemplateField( - name=variable, - display_name=variable, - field_type="str", - show=True, - advanced=False, - input_types=["Document", "BaseOutputParser"], - ) + input_variables = validate_prompt(prompt_request.template) - prompt.frontend_node.template[variable] = template_field.to_dict() - prompt.frontend_node.custom_fields.append(variable) + old_custom_fields = get_old_custom_fields(prompt_request) - except Exception as exc: - logger.exception(exc) - raise HTTPException(status_code=500, detail=str(exc)) from exc + add_new_variables_to_template(input_variables, prompt_request) - # Remove variables that are not in the template anymore - for variable in old_custom_fields: - if variable not in input_variables: - try: - prompt.frontend_node.template.pop(variable, None) - except Exception as exc: - logger.exception(exc) - raise HTTPException(status_code=500, detail=str(exc)) from exc + remove_old_variables_from_template( + old_custom_fields, input_variables, prompt_request + ) - # Now we will set the field "input_variables" to the new list of variables - # if it exists - if "input_variables" in prompt.frontend_node.template: - prompt.frontend_node.template["input_variables"]["value"] = input_variables + update_input_variables_field(input_variables, prompt_request) return PromptValidationResponse( input_variables=input_variables, - frontend_node=prompt.frontend_node, + frontend_node=prompt_request.frontend_node, ) except Exception as e: logger.exception(e) raise HTTPException(status_code=500, detail=str(e)) from e + + +def get_old_custom_fields(prompt_request): + try: + old_custom_fields = prompt_request.frontend_node.custom_fields[ + prompt_request.name + ].copy() + except KeyError: + old_custom_fields = [] + prompt_request.frontend_node.custom_fields[prompt_request.name] = [] + return old_custom_fields + + +def add_new_variables_to_template(input_variables, prompt_request): + for variable in input_variables: + try: + template_field = TemplateField( + name=variable, + display_name=variable, + field_type="str", + show=True, + advanced=False, + input_types=["BaseLoader", "BaseOutputParser"], + ) + + prompt_request.frontend_node.template[variable] = template_field.to_dict() + prompt_request.frontend_node.custom_fields[prompt_request.name].append( + variable + ) + + except Exception as exc: + logger.exception(exc) + raise HTTPException(status_code=500, detail=str(exc)) from exc + + +def remove_old_variables_from_template( + old_custom_fields, input_variables, prompt_request +): + for variable in old_custom_fields: + if variable not in input_variables: + try: + prompt_request.frontend_node.template.pop(variable, None) + except Exception as exc: + logger.exception(exc) + raise HTTPException(status_code=500, detail=str(exc)) from exc + + +def update_input_variables_field(input_variables, prompt_request): + if "input_variables" in prompt_request.frontend_node.template: + prompt_request.frontend_node.template["input_variables"][ + "value" + ] = input_variables