From eabc363d7d2a6c070d4f45bcd171fe089cd10842 Mon Sep 17 00:00:00 2001 From: ogabrielluiz Date: Mon, 17 Jun 2024 16:33:38 -0300 Subject: [PATCH] refactor: Rebuild HumanMessages in load_lc_prompt if they are instances of BaseMessage --- src/backend/base/langflow/schema/message.py | 18 ++++++++++++++++-- 1 file changed, 16 insertions(+), 2 deletions(-) diff --git a/src/backend/base/langflow/schema/message.py b/src/backend/base/langflow/schema/message.py index 84428a6d7..f6689dd27 100644 --- a/src/backend/base/langflow/schema/message.py +++ b/src/backend/base/langflow/schema/message.py @@ -3,7 +3,7 @@ from typing import Annotated, Any, AsyncIterator, Iterator, Optional from fastapi.encoders import jsonable_encoder from langchain_core.load import load -from langchain_core.messages import AIMessage, BaseMessage, HumanMessage +from langchain_core.messages import AIMessage, BaseMessage, HumanMessage, SystemMessage from langchain_core.prompt_values import ImagePromptValue from langchain_core.prompts import BaseChatPromptTemplate, ChatPromptTemplate, PromptTemplate from langchain_core.prompts.image import ImagePromptTemplate @@ -123,7 +123,21 @@ class Message(Data): def load_lc_prompt(self): if "prompt" not in self: raise ValueError("Prompt is required.") - return load(self.prompt) + loaded_prompt = load(self.prompt) + # Rebuild HumanMessages if they are instance of BaseMessage + if isinstance(loaded_prompt, ChatPromptTemplate): + messages = [] + for message in loaded_prompt.messages: + if isinstance(message, HumanMessage): + messages.append(message) + elif message.type == "human": + messages.append(HumanMessage(content=message.content)) + elif message.type == "system": + messages.append(SystemMessage(content=message.content)) + elif message.type == "ai": + messages.append(AIMessage(content=message.content)) + loaded_prompt.messages = messages + return loaded_prompt @classmethod def from_lc_prompt(