🐛 fix(validate.py): rename parameter 'prompt' to 'prompt_request' in post_validate_prompt function for clarity

 feat(validate.py): refactor post_validate_prompt function to improve code readability and maintainability
The parameter 'prompt' in the 'post_validate_prompt' function has been renamed to 'prompt_request' to improve clarity and avoid confusion with the 'prompt' variable used within the function. The function has also been refactored to improve code readability and maintainability by extracting logic into separate helper functions. The helper functions 'get_old_custom_fields', 'add_new_variables_to_template', 'remove_old_variables_from_template', and 'update_input_variables_field' have been added to handle specific tasks within the 'post_validate_prompt' function. This refactoring improves the overall structure and organization of the code.
This commit is contained in:
Gabriel Luiz Freitas Almeida 2023-07-03 23:10:41 -03:00
commit d5c7fb9dc5

View file

@ -28,49 +28,76 @@ def post_validate_code(code: Code):
@router.post("/prompt", status_code=200, response_model=PromptValidationResponse)
def post_validate_prompt(prompt: ValidatePromptRequest):
def post_validate_prompt(prompt_request: ValidatePromptRequest):
try:
input_variables = validate_prompt(prompt.template)
# Reinitialize custom_fields
old_custom_fields = prompt.frontend_node.custom_fields.copy()
prompt.frontend_node.custom_fields = []
# Add new variables to the template
for variable in input_variables:
try:
template_field = TemplateField(
name=variable,
display_name=variable,
field_type="str",
show=True,
advanced=False,
input_types=["Document", "BaseOutputParser"],
)
input_variables = validate_prompt(prompt_request.template)
prompt.frontend_node.template[variable] = template_field.to_dict()
prompt.frontend_node.custom_fields.append(variable)
old_custom_fields = get_old_custom_fields(prompt_request)
except Exception as exc:
logger.exception(exc)
raise HTTPException(status_code=500, detail=str(exc)) from exc
add_new_variables_to_template(input_variables, prompt_request)
# Remove variables that are not in the template anymore
for variable in old_custom_fields:
if variable not in input_variables:
try:
prompt.frontend_node.template.pop(variable, None)
except Exception as exc:
logger.exception(exc)
raise HTTPException(status_code=500, detail=str(exc)) from exc
remove_old_variables_from_template(
old_custom_fields, input_variables, prompt_request
)
# Now we will set the field "input_variables" to the new list of variables
# if it exists
if "input_variables" in prompt.frontend_node.template:
prompt.frontend_node.template["input_variables"]["value"] = input_variables
update_input_variables_field(input_variables, prompt_request)
return PromptValidationResponse(
input_variables=input_variables,
frontend_node=prompt.frontend_node,
frontend_node=prompt_request.frontend_node,
)
except Exception as e:
logger.exception(e)
raise HTTPException(status_code=500, detail=str(e)) from e
def get_old_custom_fields(prompt_request):
try:
old_custom_fields = prompt_request.frontend_node.custom_fields[
prompt_request.name
].copy()
except KeyError:
old_custom_fields = []
prompt_request.frontend_node.custom_fields[prompt_request.name] = []
return old_custom_fields
def add_new_variables_to_template(input_variables, prompt_request):
for variable in input_variables:
try:
template_field = TemplateField(
name=variable,
display_name=variable,
field_type="str",
show=True,
advanced=False,
input_types=["BaseLoader", "BaseOutputParser"],
)
prompt_request.frontend_node.template[variable] = template_field.to_dict()
prompt_request.frontend_node.custom_fields[prompt_request.name].append(
variable
)
except Exception as exc:
logger.exception(exc)
raise HTTPException(status_code=500, detail=str(exc)) from exc
def remove_old_variables_from_template(
old_custom_fields, input_variables, prompt_request
):
for variable in old_custom_fields:
if variable not in input_variables:
try:
prompt_request.frontend_node.template.pop(variable, None)
except Exception as exc:
logger.exception(exc)
raise HTTPException(status_code=500, detail=str(exc)) from exc
def update_input_variables_field(input_variables, prompt_request):
if "input_variables" in prompt_request.frontend_node.template:
prompt_request.frontend_node.template["input_variables"][
"value"
] = input_variables