# langflow/src/backend/base/langflow/schema/message.py
# Source-viewer metadata: 196 lines, 7.2 KiB, Python.

from datetime import datetime, timezone
from typing import Annotated, Any, AsyncIterator, Iterator, List, Optional
from uuid import UUID
from fastapi.encoders import jsonable_encoder
from langchain_core.load import load
from langchain_core.messages import AIMessage, BaseMessage, HumanMessage, SystemMessage
from langchain_core.prompt_values import ImagePromptValue
from langchain_core.prompts import BaseChatPromptTemplate, ChatPromptTemplate, PromptTemplate
from langchain_core.prompts.image import ImagePromptTemplate
from loguru import logger
from pydantic import BeforeValidator, ConfigDict, Field, field_serializer, field_validator
from langflow.base.prompts.utils import dict_values_to_string
from langflow.schema.data import Data
from langflow.schema.image import Image, get_file_paths, is_image_file
def _timestamp_to_str(timestamp: datetime) -> str:
return timestamp.strftime("%Y-%m-%d %H:%M:%S")
class Message(Data):
    """A ``Data`` record describing a single message exchanged in a flow.

    On top of the generic ``Data`` payload it carries the message text, sender
    information, attached files, session/flow identifiers and a timestamp. It
    can be converted to LangChain messages (``to_lc_message``) and to/from
    LangChain prompts (``load_lc_prompt``, ``from_lc_prompt``,
    ``from_template_and_variables``).
    """

    # NOTE(review): presumably needed for field types pydantic cannot build a
    # schema for by itself (e.g. the iterator types on ``text``) -- confirm.
    model_config = ConfigDict(arbitrary_types_allowed=True)
    text_key: str = "text"
    # A plain string, or an (async) iterator (e.g. while text is being streamed).
    text: Optional[str | AsyncIterator | Iterator] = Field(default="")
    sender: Optional[str] = None
    sender_name: Optional[str] = None
    # File references or ``Image`` objects; entries recognized by
    # ``is_image_file`` are wrapped in ``Image`` by ``model_post_init``.
    files: Optional[list[str | Image]] = Field(default_factory=list)
    session_id: Optional[str] = Field(default="")
    # ``default_factory`` so every Message is stamped with its own creation
    # time. A plain ``default=datetime.now(...)`` is evaluated once, when the
    # class body runs at import, so all such Messages shared that one value.
    timestamp: Annotated[str, BeforeValidator(_timestamp_to_str)] = Field(
        default_factory=lambda: datetime.now(timezone.utc).strftime("%Y-%m-%d %H:%M:%S")
    )
    flow_id: Optional[str | UUID] = None

    @field_validator("flow_id", mode="before")
    @classmethod
    def validate_flow_id(cls, value):
        """Store ``UUID`` flow ids as strings; any other value passes through unchanged."""
        if isinstance(value, UUID):
            return str(value)
        return value

    @field_validator("files", mode="before")
    @classmethod
    def validate_files(cls, value):
        """Normalize ``files`` to a list.

        Falsy input (``None``, an empty list, an empty string) becomes ``[]``;
        a single non-list value is wrapped in a one-element list.
        """
        if not value:
            return []
        if not isinstance(value, list):
            return [value]
        return value

    def model_post_init(self, __context: Any) -> None:
        """Finish initialization after pydantic validation.

        Wraps every file entry recognized by ``is_image_file`` in an ``Image``
        and copies ``timestamp`` into ``self.data`` unless it already has one.
        """
        # NOTE(review): entries that are already ``Image`` objects are also
        # passed to is_image_file -- assumes it tolerates them; confirm.
        self.files = [Image(path=file) if is_image_file(file) else file for file in self.files or []]
        if "timestamp" not in self.data:
            self.data["timestamp"] = self.timestamp

    def set_flow_id(self, flow_id: str):
        """Set ``flow_id`` on this message."""
        self.flow_id = flow_id

    def to_lc_message(
        self,
    ) -> BaseMessage:
        """Convert this Message to a LangChain ``BaseMessage``.

        A ``sender`` of ``"User"`` (or no sender at all) yields a
        ``HumanMessage``; any other sender yields an ``AIMessage``. A human
        message with ``files`` gets list content: a text part followed by the
        content dicts produced by ``get_file_content_dicts``.

        ``text`` values that are not strings (``None`` or a stream iterator)
        are sent as ``""``.

        Returns:
            BaseMessage: The converted message.
        """
        # "text" and "sender" are expected; without a sender we fall back to a
        # HumanMessage (see the branch below).
        if self.text is None or not self.sender:
            logger.warning("Missing required keys ('text', 'sender') in Message, defaulting to HumanMessage.")
        # Normalize once so every branch (files, plain human, AI) gets a str;
        # previously only the plain-human branch did this.
        text = self.text if isinstance(self.text, str) else ""

        if self.sender == "User" or not self.sender:
            if self.files:
                contents: list[Any] = [{"type": "text", "text": text}]
                # get_file_content_dicts is ``async``; calling it without
                # awaiting handed a coroutine to list.extend (TypeError).
                contents.extend(self._get_file_content_dicts_sync())
                return HumanMessage(content=contents)  # type: ignore
            return HumanMessage(content=text)
        return AIMessage(content=text)

    def _get_file_content_dicts_sync(self) -> list:
        """Run the async ``get_file_content_dicts`` to completion from synchronous code."""
        import asyncio
        from concurrent.futures import ThreadPoolExecutor

        try:
            asyncio.get_running_loop()
        except RuntimeError:
            # No event loop running in this thread: drive the coroutine directly.
            return asyncio.run(self.get_file_content_dicts())
        # A loop is already running in this thread and asyncio.run() cannot
        # nest, so run the coroutine on a private loop in a worker thread and
        # block until it finishes.
        # NOTE(review): this blocks the running loop meanwhile; assumes
        # get_file_paths does not need that loop to make progress -- confirm.
        with ThreadPoolExecutor(max_workers=1) as executor:
            return executor.submit(lambda: asyncio.run(self.get_file_content_dicts())).result()

    @classmethod
    def from_data(cls, data: "Data") -> "Message":
        """Build a Message from a generic ``Data`` object.

        Args:
            data (Data): Source object. ``text``, ``sender``, ``sender_name``,
                ``files``, ``session_id``, ``timestamp`` and ``flow_id`` are read
                from it by plain attribute access, so it is expected to provide
                all of them.

        Returns:
            Message: A new Message populated from those values.
        """
        return cls(
            text=data.text,
            sender=data.sender,
            sender_name=data.sender_name,
            files=data.files,
            session_id=data.session_id,
            timestamp=data.timestamp,
            flow_id=data.flow_id,
        )

    @field_serializer("text", mode="plain")
    def serialize_text(self, value):
        """Serialize ``text``; (async) iterators cannot be serialized and become ``""``."""
        if isinstance(value, AsyncIterator) or isinstance(value, Iterator):
            return ""
        return value

    async def get_file_content_dicts(self):
        """Build LangChain multimodal content dicts for this message's files.

        File references are resolved with ``get_file_paths`` first. ``Image``
        entries supply their own dict via ``to_content_dict()``; every other
        entry is treated as an image path and rendered through an
        ``ImagePromptTemplate`` into ``{"type": "image_url", "image_url": ...}``.
        """
        content_dicts = []
        files = await get_file_paths(self.files)
        for file in files:
            if isinstance(file, Image):
                content_dicts.append(file.to_content_dict())
            else:
                image_template = ImagePromptTemplate()
                image_prompt_value: ImagePromptValue = image_template.invoke(input={"path": file})
                content_dicts.append({"type": "image_url", "image_url": image_prompt_value.image_url})
        return content_dicts

    def load_lc_prompt(self):
        """Deserialize ``self.prompt`` back into a LangChain prompt object.

        Returns:
            The object produced by ``langchain_core.load.load``. If it is a
            ``ChatPromptTemplate``, messages that are not already
            ``HumanMessage`` instances but report a ``type`` of
            ``"human"``/``"system"``/``"ai"`` are rebuilt as the matching
            concrete message class; all other entries are kept as-is.

        Raises:
            ValueError: If this message has no ``prompt``.
        """
        if "prompt" not in self:
            raise ValueError("Prompt is required.")
        loaded_prompt = load(self.prompt)
        if isinstance(loaded_prompt, ChatPromptTemplate):
            concrete_classes = {"human": HumanMessage, "system": SystemMessage, "ai": AIMessage}
            messages = []
            for message in loaded_prompt.messages:
                # getattr: prompt-template entries may not expose ``.type``.
                message_type = getattr(message, "type", None)
                if isinstance(message, HumanMessage) or message_type not in concrete_classes:
                    # Keep unrecognized entries instead of silently dropping them.
                    messages.append(message)
                else:
                    messages.append(concrete_classes[message_type](content=message.content))
            loaded_prompt.messages = messages
        return loaded_prompt

    @classmethod
    def from_lc_prompt(
        cls,
        prompt: BaseChatPromptTemplate,
    ):
        """Create a Message whose ``prompt`` is ``prompt.to_json()``.

        The ``to_json()`` result is stored as-is (it is not passed through
        ``jsonable_encoder``).
        """
        prompt_json = prompt.to_json()
        return cls(prompt=prompt_json)

    def format_text(self):
        """Render ``self.template`` with ``self.variables`` and store it in ``self.text``.

        Variable values are converted with ``dict_values_to_string`` before
        formatting.

        Returns:
            The formatted string (also assigned to ``self.text``).
        """
        prompt_template = PromptTemplate.from_template(self.template)
        variables_with_str_values = dict_values_to_string(self.variables)
        formatted_prompt = prompt_template.format(**variables_with_str_values)
        self.text = formatted_prompt
        return formatted_prompt

    @classmethod
    async def from_template_and_variables(cls, template: str, **variables):
        """Create a Message by formatting ``template`` with ``variables``.

        The formatted text becomes ``text``. Files attached to any ``Message``
        passed as a variable are added as extra content parts of a single
        ``HumanMessage``, which is wrapped in a ``ChatPromptTemplate``. That
        template's JSON form is stored in ``prompt`` and its messages in
        ``messages``.
        """
        instance = cls(template=template, variables=variables)
        text = instance.format_text()

        # Collect file content from every Message instance among the variables.
        contents = []
        for value in variables.values():
            if isinstance(value, cls) and value.files:
                contents.extend(await value.get_file_content_dicts())

        if contents:
            message = HumanMessage(content=[{"type": "text", "text": text}] + contents)
        else:
            message = HumanMessage(content=text)
        prompt_template = ChatPromptTemplate.from_messages([message])  # type: ignore
        instance.prompt = jsonable_encoder(prompt_template.to_json())
        instance.messages = instance.prompt.get("kwargs", {}).get("messages", [])
        return instance