🐛 fix(loading.py): fix conditional statement for instantiating prompt objects based on node_type
✨ feat(loading.py): improve formatting of input variables for prompt objects
This commit is contained in:
parent
00c1734fb8
commit
cd9fc3a2e4
1 changed files with 19 additions and 13 deletions
|
|
@ -1,3 +1,4 @@
|
|||
import contextlib
|
||||
import json
|
||||
from typing import Any, Callable, Dict, List, Sequence, Type
|
||||
|
||||
|
|
@ -196,7 +197,7 @@ def instantiate_prompt(node_type, class_object, params: Dict):
|
|||
if "tools" not in params:
|
||||
params["tools"] = []
|
||||
return ZeroShotAgent.create_prompt(**params)
|
||||
if "MessagePromptTemplate" in node_type:
|
||||
elif "MessagePromptTemplate" in node_type:
|
||||
# Then we only need the template
|
||||
from_template_params = {
|
||||
"template": params.pop("prompt", params.pop("template", ""))
|
||||
|
|
@ -204,12 +205,12 @@ def instantiate_prompt(node_type, class_object, params: Dict):
|
|||
|
||||
if not from_template_params.get("template"):
|
||||
raise ValueError("Prompt template is required")
|
||||
return class_object.from_template(**from_template_params)
|
||||
prompt = class_object.from_template(**from_template_params)
|
||||
|
||||
if node_type == "ChatPromptTemplate":
|
||||
return class_object.from_messages(**params)
|
||||
|
||||
prompt = class_object(**params)
|
||||
elif node_type == "ChatPromptTemplate":
|
||||
prompt = class_object.from_messages(**params)
|
||||
else:
|
||||
prompt = class_object(**params)
|
||||
|
||||
format_kwargs: Dict[str, Any] = {}
|
||||
for input_variable in prompt.input_variables:
|
||||
|
|
@ -221,18 +222,23 @@ def instantiate_prompt(node_type, class_object, params: Dict):
|
|||
variable, "get_format_instructions"
|
||||
):
|
||||
format_kwargs[input_variable] = variable.get_format_instructions()
|
||||
# check if is a list of Document
|
||||
elif isinstance(variable, List) and all(
|
||||
isinstance(item, Document) for item in variable
|
||||
):
|
||||
# Format document to contain page_content and metadata
|
||||
# as one string separated by a newline
|
||||
format_kwargs[input_variable] = "\n".join(
|
||||
[
|
||||
f"Document:{item.page_content}\nMetadata:{item.metadata}"
|
||||
for item in variable
|
||||
]
|
||||
)
|
||||
if len(variable) > 1:
|
||||
content = "\n".join(
|
||||
[item.page_content for item in variable if item.page_content]
|
||||
)
|
||||
else:
|
||||
content = variable[0].page_content
|
||||
# content could be a json list of strings
|
||||
with contextlib.suppress(json.JSONDecodeError):
|
||||
content = json.loads(content)
|
||||
if isinstance(content, list):
|
||||
content = ",".join([str(item) for item in content])
|
||||
format_kwargs[input_variable] = content
|
||||
# handle_keys will be a list but it does not exist yet
|
||||
# so we need to create it
|
||||
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue