refactor: Rebuild HumanMessages in load_lc_prompt if they are instances of BaseMessage

This commit is contained in:
ogabrielluiz 2024-06-17 16:33:38 -03:00
commit eabc363d7d

View file

@ -3,7 +3,7 @@ from typing import Annotated, Any, AsyncIterator, Iterator, Optional
from fastapi.encoders import jsonable_encoder
from langchain_core.load import load
from langchain_core.messages import AIMessage, BaseMessage, HumanMessage
from langchain_core.messages import AIMessage, BaseMessage, HumanMessage, SystemMessage
from langchain_core.prompt_values import ImagePromptValue
from langchain_core.prompts import BaseChatPromptTemplate, ChatPromptTemplate, PromptTemplate
from langchain_core.prompts.image import ImagePromptTemplate
@ -123,7 +123,21 @@ class Message(Data):
def load_lc_prompt(self):
if "prompt" not in self:
raise ValueError("Prompt is required.")
return load(self.prompt)
loaded_prompt = load(self.prompt)
# Rebuild HumanMessages if they are instance of BaseMessage
if isinstance(loaded_prompt, ChatPromptTemplate):
messages = []
for message in loaded_prompt.messages:
if isinstance(message, HumanMessage):
messages.append(message)
elif message.type == "human":
messages.append(HumanMessage(content=message.content))
elif message.type == "system":
messages.append(SystemMessage(content=message.content))
elif message.type == "ai":
messages.append(AIMessage(content=message.content))
loaded_prompt.messages = messages
return loaded_prompt
@classmethod
def from_lc_prompt(