From cd9fc3a2e47df95f6962398ab73c9aa1c12b7655 Mon Sep 17 00:00:00 2001 From: Gabriel Luiz Freitas Almeida Date: Mon, 10 Jul 2023 17:49:20 -0300 Subject: [PATCH] =?UTF-8?q?=F0=9F=90=9B=20fix(loading.py):=20fix=20conditi?= =?UTF-8?q?onal=20statement=20for=20instantiating=20prompt=20objects=20bas?= =?UTF-8?q?ed=20on=20node=5Ftype=20=E2=9C=A8=20feat(loading.py):=20improve?= =?UTF-8?q?=20formatting=20of=20input=20variables=20for=20prompt=20objects?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- .../langflow/interface/initialize/loading.py | 32 +++++++++++-------- 1 file changed, 19 insertions(+), 13 deletions(-) diff --git a/src/backend/langflow/interface/initialize/loading.py b/src/backend/langflow/interface/initialize/loading.py index 5418f6906..ed80b0adb 100644 --- a/src/backend/langflow/interface/initialize/loading.py +++ b/src/backend/langflow/interface/initialize/loading.py @@ -1,3 +1,4 @@ +import contextlib import json from typing import Any, Callable, Dict, List, Sequence, Type @@ -196,7 +197,7 @@ def instantiate_prompt(node_type, class_object, params: Dict): if "tools" not in params: params["tools"] = [] return ZeroShotAgent.create_prompt(**params) - if "MessagePromptTemplate" in node_type: + elif "MessagePromptTemplate" in node_type: # Then we only need the template from_template_params = { "template": params.pop("prompt", params.pop("template", "")) @@ -204,12 +205,12 @@ def instantiate_prompt(node_type, class_object, params: Dict): if not from_template_params.get("template"): raise ValueError("Prompt template is required") - return class_object.from_template(**from_template_params) + prompt = class_object.from_template(**from_template_params) - if node_type == "ChatPromptTemplate": - return class_object.from_messages(**params) - - prompt = class_object(**params) + elif node_type == "ChatPromptTemplate": + prompt = class_object.from_messages(**params) + else: + prompt = class_object(**params) format_kwargs: Dict[str, Any] = {} for input_variable in prompt.input_variables: @@ -221,18 +222,23 @@ def instantiate_prompt(node_type, class_object, params: Dict): variable, "get_format_instructions" ): format_kwargs[input_variable] = variable.get_format_instructions() - # check if is a list of Document elif isinstance(variable, List) and all( isinstance(item, Document) for item in variable ): # Format document to contain page_content and metadata # as one string separated by a newline - format_kwargs[input_variable] = "\n".join( - [ - f"Document:{item.page_content}\nMetadata:{item.metadata}" - for item in variable - ] - ) + if len(variable) > 1: + content = "\n".join( + [item.page_content for item in variable if item.page_content] + ) + else: + content = variable[0].page_content + # content could be a json list of strings + with contextlib.suppress(json.JSONDecodeError): + content = json.loads(content) + if isinstance(content, list): + content = ",".join([str(item) for item in content]) + format_kwargs[input_variable] = content # handle_keys will be a list but it does not exist yet # so we need to create it