fix: condition for invalid prompt fixed

This commit is contained in:
Gabriel Almeida 2023-04-06 08:58:36 -03:00
commit fe1358a0df

View file

@ -28,10 +28,52 @@ class PromptValidationResponse(BaseModel):
input_variables: list
INVALID_CHARACTERS = {
" ",
",",
".",
":",
";",
"!",
"?",
"/",
"\\",
"(",
")",
"[",
"]",
"{",
"}",
}
def validate_prompt(template: str):
input_variables = extract_input_variables_from_prompt(template)
if invalid := [variable for variable in input_variables if " " in variable]:
raise ValueError(
f"Invalid input variables: {invalid}. Please remove spaces from input variables"
)
# Check if there are invalid characters in the input_variables
input_variables = check_input_variables(input_variables)
return PromptValidationResponse(input_variables=input_variables)
def check_input_variables(input_variables: list):
invalid_chars = []
fixed_variables = []
for variable in input_variables:
new_var = variable
for char in INVALID_CHARACTERS:
if char in variable:
invalid_chars.append(char)
new_var = new_var.replace(char, "")
fixed_variables.append(new_var)
if new_var != variable:
input_variables.remove(variable)
input_variables.append(new_var)
# If any of the input_variables is not in the fixed_variables, then it means that
# there are invalid characters in the input_variables
if any(var not in fixed_variables for var in input_variables):
raise ValueError(
f"Invalid input variables: {input_variables}. Please, use something like {fixed_variables} instead."
)
return input_variables