🐛 fix(base.py): change wrong_variables from a set to a list to preserve order and improve error message generation

🔀 refactor(base.py): refactor check_input_variables function to simplify logic and improve readability
The wrong_variables variable is now a list instead of a set to preserve the order of the variables. This change improves the error message generation by ensuring that the variables are displayed in the same order as they appear in the input. The check_input_variables function has been refactored to simplify the logic and improve readability. The code now handles invalid characters and wrong variables separately, making it easier to understand and maintain.
This commit is contained in:
Gabriel Luiz Freitas Almeida 2023-07-06 15:23:17 -03:00
commit c3886ed219

View file

@ -77,7 +77,7 @@ def validate_prompt(template: str):
def check_input_variables(input_variables: list):
invalid_chars = []
fixed_variables = []
wrong_variables = set()
wrong_variables = []
empty_variables = []
for variable in input_variables:
new_var = variable
@ -92,17 +92,14 @@ def check_input_variables(input_variables: list):
if variable[0].isdigit():
invalid_chars.append(variable[0])
new_var = new_var.replace(variable[0], "")
wrong_variables.add(variable)
for char in INVALID_CHARACTERS:
if char in variable:
invalid_chars.append(char)
new_var = new_var.replace(char, "")
wrong_variables.add(variable)
wrong_variables.append(variable)
else:
for char in INVALID_CHARACTERS:
if char in variable:
invalid_chars.append(char)
new_var = new_var.replace(char, "")
wrong_variables.append(variable)
fixed_variables.append(new_var)
# if new_var != variable and new_var not in input_variables:
# input_variables.remove(variable)
# input_variables.append(new_var)
# If any of the input_variables is not in the fixed_variables, then it means that
# there are invalid characters in the input_variables
@ -122,17 +119,20 @@ def build_error_message(
input_variables, invalid_chars, wrong_variables, fixed_variables, empty_variables
):
input_variables_str = ", ".join([f"'{var}'" for var in input_variables])
error_string = f"Invalid input variables: {input_variables_str}."
error_string = f"Invalid input variables: {input_variables_str}. "
if wrong_variables and invalid_chars:
", ".join([f"'{var}'" for var in wrong_variables])
invalid_chars_str = ", ".join([f"'{char}'" for char in invalid_chars])
error_string += (
f" Please, remove the invalid characters: {invalid_chars_str}"
" from the variables: {wrong_variables_str}."
)
# fix the wrong variables replacing invalid chars and find them in the fixed variables
error_string_vars = "You can fix them by replacing the invalid characters: "
wvars = wrong_variables.copy()
for i, wrong_var in enumerate(wvars):
for char in invalid_chars:
wrong_var = wrong_var.replace(char, "")
if wrong_var in fixed_variables:
error_string_vars += f"'{wrong_variables[i]}' -> '{wrong_var}'"
error_string += error_string_vars
elif empty_variables:
error_string += f" There are {len(empty_variables)} empty variable{'s' if len(empty_variables) > 1 else ''}."
elif len(set(fixed_variables)) != len(fixed_variables):
error_string += " There are duplicate variables."
error_string += "There are duplicate variables."
return error_string