From 94003cf39aa7e6ff20262c028925a9637cf6ee4c Mon Sep 17 00:00:00 2001 From: Gabriel Luiz Freitas Almeida Date: Wed, 13 Dec 2023 20:01:33 -0300 Subject: [PATCH] =?UTF-8?q?=F0=9F=90=9B=20fix(utils.py):=20add=20support?= =?UTF-8?q?=20for=20initializing=20class=20objects=20with=20a=20template?= =?UTF-8?q?=20parameter?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- src/backend/langflow/interface/initialize/utils.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/src/backend/langflow/interface/initialize/utils.py b/src/backend/langflow/interface/initialize/utils.py index 69e436fcf..5ecae3d70 100644 --- a/src/backend/langflow/interface/initialize/utils.py +++ b/src/backend/langflow/interface/initialize/utils.py @@ -5,6 +5,7 @@ from typing import Any, Dict, List import orjson from langchain.agents import ZeroShotAgent from langchain.schema import BaseOutputParser, Document + from langflow.services.database.models.base import orjson_dumps @@ -16,6 +17,8 @@ def handle_node_type(node_type, class_object, params: Dict): prompt = instantiate_from_template(class_object, params) elif node_type == "ChatPromptTemplate": prompt = class_object.from_messages(**params) + elif hasattr(class_object, "from_template") and params.get("template"): + prompt = class_object.from_template(template=params.pop("template")) else: prompt = class_object(**params) return params, prompt