🐛 fix(loading.py): fix conditional statement for instantiating prompt objects based on node_type

 feat(loading.py): improve formatting of input variables for prompt objects
This commit is contained in:
Gabriel Luiz Freitas Almeida 2023-07-10 17:49:20 -03:00
commit cd9fc3a2e4

View file

@ -1,3 +1,4 @@
import contextlib
import json
from typing import Any, Callable, Dict, List, Sequence, Type
@ -196,7 +197,7 @@ def instantiate_prompt(node_type, class_object, params: Dict):
if "tools" not in params:
params["tools"] = []
return ZeroShotAgent.create_prompt(**params)
if "MessagePromptTemplate" in node_type:
elif "MessagePromptTemplate" in node_type:
# Then we only need the template
from_template_params = {
"template": params.pop("prompt", params.pop("template", ""))
@ -204,12 +205,12 @@ def instantiate_prompt(node_type, class_object, params: Dict):
if not from_template_params.get("template"):
raise ValueError("Prompt template is required")
return class_object.from_template(**from_template_params)
prompt = class_object.from_template(**from_template_params)
if node_type == "ChatPromptTemplate":
return class_object.from_messages(**params)
prompt = class_object(**params)
elif node_type == "ChatPromptTemplate":
prompt = class_object.from_messages(**params)
else:
prompt = class_object(**params)
format_kwargs: Dict[str, Any] = {}
for input_variable in prompt.input_variables:
@ -221,18 +222,23 @@ def instantiate_prompt(node_type, class_object, params: Dict):
variable, "get_format_instructions"
):
format_kwargs[input_variable] = variable.get_format_instructions()
# check if is a list of Document
elif isinstance(variable, List) and all(
isinstance(item, Document) for item in variable
):
# Format document to contain page_content and metadata
# as one string separated by a newline
format_kwargs[input_variable] = "\n".join(
[
f"Document:{item.page_content}\nMetadata:{item.metadata}"
for item in variable
]
)
if len(variable) > 1:
content = "\n".join(
[item.page_content for item in variable if item.page_content]
)
else:
content = variable[0].page_content
# content could be a json list of strings
with contextlib.suppress(json.JSONDecodeError):
content = json.loads(content)
if isinstance(content, list):
content = ",".join([str(item) for item in content])
format_kwargs[input_variable] = content
# handle_keys will be a list but it does not exist yet
# so we need to create it