diff --git a/src/backend/langflow/api/base.py b/src/backend/langflow/api/base.py index f8bd25913..4ae313fd3 100644 --- a/src/backend/langflow/api/base.py +++ b/src/backend/langflow/api/base.py @@ -28,10 +28,52 @@ class PromptValidationResponse(BaseModel): input_variables: list +INVALID_CHARACTERS = { + " ", + ",", + ".", + ":", + ";", + "!", + "?", + "/", + "\\", + "(", + ")", + "[", + "]", + "{", + "}", +} + + def validate_prompt(template: str): input_variables = extract_input_variables_from_prompt(template) - if invalid := [variable for variable in input_variables if " " in variable]: - raise ValueError( - f"Invalid input variables: {invalid}. Please remove spaces from input variables" - ) + + # Check if there are invalid characters in the input_variables + input_variables = check_input_variables(input_variables) + return PromptValidationResponse(input_variables=input_variables) + + +def check_input_variables(input_variables: list): + invalid_chars = [] + fixed_variables = [] + for variable in input_variables: + new_var = variable + for char in INVALID_CHARACTERS: + if char in variable: + invalid_chars.append(char) + new_var = new_var.replace(char, "") + fixed_variables.append(new_var) + if new_var != variable: + input_variables.remove(variable) + input_variables.append(new_var) + # If any of the input_variables is not in the fixed_variables, then it means that + # there are invalid characters in the input_variables + if any(var not in fixed_variables for var in input_variables): + raise ValueError( + f"Invalid input variables: {input_variables}. Please, use something like {fixed_variables} instead." + ) + + return input_variables