Refactor custom fields handling in validate.py

This commit is contained in:
Gabriel Luiz Freitas Almeida 2024-03-20 17:55:49 -03:00
commit 8c0a2b62a3

View file

@ -44,13 +44,13 @@ def post_validate_prompt(prompt_request: ValidatePromptRequest):
input_variables=input_variables,
frontend_node=None,
)
if not prompt_request.custom_fields:
prompt_request.custom_fields = defaultdict(list)
old_custom_fields = get_old_custom_fields(prompt_request.custom_fields, prompt_request.name)
if not prompt_request.frontend_node.custom_fields:
prompt_request.frontend_node.custom_fields = defaultdict(list)
old_custom_fields = get_old_custom_fields(prompt_request.frontend_node.custom_fields, prompt_request.name)
add_new_variables_to_template(
input_variables,
prompt_request.custom_fields,
prompt_request.frontend_node.custom_fields,
prompt_request.frontend_node.template,
prompt_request.name,
)
@ -58,13 +58,25 @@ def post_validate_prompt(prompt_request: ValidatePromptRequest):
remove_old_variables_from_template(
old_custom_fields,
input_variables,
prompt_request.custom_fields,
prompt_request.frontend_node.custom_fields,
prompt_request.frontend_node.template,
prompt_request.name,
)
update_input_variables_field(input_variables, prompt_request.frontend_node.template)
# If frontend_node.template contains only one field that is type == 'prompt', then we can remove all fields that are not
# 'code', and not in the input_variables list.
prompt_fields = [
key
for key, field in prompt_request.frontend_node.template.items()
if isinstance(field, dict) and field["type"] == "prompt"
]
if len(prompt_fields) == 1:
for key, field in prompt_request.frontend_node.template.copy().items():
if isinstance(field, dict) and field["type"] != "code" and key not in input_variables + prompt_fields:
del prompt_request.frontend_node.template[key]
return PromptValidationResponse(
input_variables=input_variables,
frontend_node=prompt_request.frontend_node,