refactor: Rebuild HumanMessages in load_lc_prompt if they are instances of BaseMessage
This commit is contained in:
parent
d6468e4283
commit
eabc363d7d
1 changed files with 16 additions and 2 deletions
|
|
@ -3,7 +3,7 @@ from typing import Annotated, Any, AsyncIterator, Iterator, Optional
|
|||
|
||||
from fastapi.encoders import jsonable_encoder
|
||||
from langchain_core.load import load
|
||||
from langchain_core.messages import AIMessage, BaseMessage, HumanMessage
|
||||
from langchain_core.messages import AIMessage, BaseMessage, HumanMessage, SystemMessage
|
||||
from langchain_core.prompt_values import ImagePromptValue
|
||||
from langchain_core.prompts import BaseChatPromptTemplate, ChatPromptTemplate, PromptTemplate
|
||||
from langchain_core.prompts.image import ImagePromptTemplate
|
||||
|
|
@ -123,7 +123,21 @@ class Message(Data):
|
|||
def load_lc_prompt(self):
|
||||
if "prompt" not in self:
|
||||
raise ValueError("Prompt is required.")
|
||||
return load(self.prompt)
|
||||
loaded_prompt = load(self.prompt)
|
||||
# Rebuild HumanMessages if they are instance of BaseMessage
|
||||
if isinstance(loaded_prompt, ChatPromptTemplate):
|
||||
messages = []
|
||||
for message in loaded_prompt.messages:
|
||||
if isinstance(message, HumanMessage):
|
||||
messages.append(message)
|
||||
elif message.type == "human":
|
||||
messages.append(HumanMessage(content=message.content))
|
||||
elif message.type == "system":
|
||||
messages.append(SystemMessage(content=message.content))
|
||||
elif message.type == "ai":
|
||||
messages.append(AIMessage(content=message.content))
|
||||
loaded_prompt.messages = messages
|
||||
return loaded_prompt
|
||||
|
||||
@classmethod
|
||||
def from_lc_prompt(
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue