feat(loading.py): add support for new prompt templates in instantiate_prompt function

The config.yaml file has been updated to include new prompt templates. In types.py, the input_variables list was not being cleared when prompt or messages were present, which has been fixed. The instantiate_prompt function in loading.py has been updated to support the new prompt templates.
🔀 chore(config): add new prompt templates to config.yaml
🐛 fix(types.py): fix input_variables not being cleared when prompt or messages are present
This commit is contained in:
Gabriel Luiz Freitas Almeida 2023-06-27 12:06:23 -03:00
commit 8107f9e332
3 changed files with 29 additions and 5 deletions

View file

@ -73,6 +73,11 @@ prompts:
- PromptTemplate
- FewShotPromptTemplate
- ZeroShotPrompt
- ChatPromptTemplate
- SystemMessagePromptTemplate
- AIMessagePromptTemplate
- HumanMessagePromptTemplate
- ChatMessagePromptTemplate
textsplitters:
- CharacterTextSplitter
- RecursiveCharacterTextSplitter

View file

@ -180,11 +180,17 @@ class PromptVertex(Vertex):
]
else:
prompt_params = ["template"]
for param in prompt_params:
prompt_text = self.params[param]
variables = extract_input_variables_from_prompt(prompt_text)
self.params["input_variables"].extend(variables)
self.params["input_variables"] = list(set(self.params["input_variables"]))
if "prompt" not in self.params and "messages" not in self.params:
for param in prompt_params:
prompt_text = self.params[param]
variables = extract_input_variables_from_prompt(prompt_text)
self.params["input_variables"].extend(variables)
self.params["input_variables"] = list(
set(self.params["input_variables"])
)
else:
self.params.pop("input_variables", None)
self._build()
return self._built_object

View file

@ -97,6 +97,19 @@ def instantiate_prompt(node_type, class_object, params):
if "tools" not in params:
params["tools"] = []
return ZeroShotAgent.create_prompt(**params)
if "MessagePromptTemplate" in node_type:
# Then we only need the template
from_template_params = {
"template": params.pop("prompt", params.pop("template", ""))
}
if not from_template_params.get("template"):
raise ValueError("Prompt template is required")
return class_object.from_template(**from_template_params)
if node_type == "ChatPromptTemplate":
return class_object.from_messages(**params)
return class_object(**params)