fix: condition for invalid prompt fixed
This commit is contained in:
parent
87b1b39475
commit
fe1358a0df
1 changed files with 46 additions and 4 deletions
|
|
@ -28,10 +28,52 @@ class PromptValidationResponse(BaseModel):
|
|||
input_variables: list
|
||||
|
||||
|
||||
INVALID_CHARACTERS = {
|
||||
" ",
|
||||
",",
|
||||
".",
|
||||
":",
|
||||
";",
|
||||
"!",
|
||||
"?",
|
||||
"/",
|
||||
"\\",
|
||||
"(",
|
||||
")",
|
||||
"[",
|
||||
"]",
|
||||
"{",
|
||||
"}",
|
||||
}
|
||||
|
||||
|
||||
def validate_prompt(template: str):
|
||||
input_variables = extract_input_variables_from_prompt(template)
|
||||
if invalid := [variable for variable in input_variables if " " in variable]:
|
||||
raise ValueError(
|
||||
f"Invalid input variables: {invalid}. Please remove spaces from input variables"
|
||||
)
|
||||
|
||||
# Check if there are invalid characters in the input_variables
|
||||
input_variables = check_input_variables(input_variables)
|
||||
|
||||
return PromptValidationResponse(input_variables=input_variables)
|
||||
|
||||
|
||||
def check_input_variables(input_variables: list):
|
||||
invalid_chars = []
|
||||
fixed_variables = []
|
||||
for variable in input_variables:
|
||||
new_var = variable
|
||||
for char in INVALID_CHARACTERS:
|
||||
if char in variable:
|
||||
invalid_chars.append(char)
|
||||
new_var = new_var.replace(char, "")
|
||||
fixed_variables.append(new_var)
|
||||
if new_var != variable:
|
||||
input_variables.remove(variable)
|
||||
input_variables.append(new_var)
|
||||
# If any of the input_variables is not in the fixed_variables, then it means that
|
||||
# there are invalid characters in the input_variables
|
||||
if any(var not in fixed_variables for var in input_variables):
|
||||
raise ValueError(
|
||||
f"Invalid input variables: {input_variables}. Please, use something like {fixed_variables} instead."
|
||||
)
|
||||
|
||||
return input_variables
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue