196 lines
7.2 KiB
Python
196 lines
7.2 KiB
Python
from datetime import datetime, timezone
|
|
from typing import Annotated, Any, AsyncIterator, Iterator, List, Optional
|
|
from uuid import UUID
|
|
|
|
from fastapi.encoders import jsonable_encoder
|
|
from langchain_core.load import load
|
|
from langchain_core.messages import AIMessage, BaseMessage, HumanMessage, SystemMessage
|
|
from langchain_core.prompt_values import ImagePromptValue
|
|
from langchain_core.prompts import BaseChatPromptTemplate, ChatPromptTemplate, PromptTemplate
|
|
from langchain_core.prompts.image import ImagePromptTemplate
|
|
from loguru import logger
|
|
from pydantic import BeforeValidator, ConfigDict, Field, field_serializer, field_validator
|
|
|
|
from langflow.base.prompts.utils import dict_values_to_string
|
|
from langflow.schema.data import Data
|
|
from langflow.schema.image import Image, get_file_paths, is_image_file
|
|
|
|
|
|
def _timestamp_to_str(timestamp: datetime) -> str:
|
|
return timestamp.strftime("%Y-%m-%d %H:%M:%S")
|
|
|
|
|
|
class Message(Data):
    """Chat message record built on top of ``Data``.

    Adds the chat-specific fields (text, sender, files, session, timestamp and
    flow id) and helpers to convert to LangChain messages and to build, store
    and reload LangChain prompts.
    """

    model_config = ConfigDict(arbitrary_types_allowed=True)

    text_key: str = "text"
    # ``text`` may be a plain string or a (possibly async) iterator of chunks.
    text: Optional[str | AsyncIterator | Iterator] = Field(default="")
    sender: Optional[str] = None
    sender_name: Optional[str] = None
    files: Optional[list[str | Image]] = Field(default=[])
    session_id: Optional[str] = Field(default="")
    # ``default_factory`` so each instance gets its own creation time; a plain
    # ``default=`` would be evaluated once, when the class is defined.
    timestamp: Annotated[str, BeforeValidator(_timestamp_to_str)] = Field(
        default_factory=lambda: _timestamp_to_str(datetime.now(timezone.utc))
    )
    flow_id: Optional[str | UUID] = None

    @field_validator("flow_id", mode="before")
    @classmethod
    def validate_flow_id(cls, value):
        """Convert a ``UUID`` flow id to its string form; pass other values through."""
        if isinstance(value, UUID):
            value = str(value)
        return value

    @field_validator("files", mode="before")
    @classmethod
    def validate_files(cls, value):
        """Coerce ``files`` to a list: falsy becomes ``[]``, a single item is wrapped."""
        if not value:
            value = []
        elif not isinstance(value, list):
            value = [value]
        return value

    def model_post_init(self, __context: Any) -> None:
        """Wrap image entries of ``files`` in ``Image`` and mirror the timestamp into ``data``."""
        new_files: List[Any] = []
        for file in self.files or []:
            if is_image_file(file):
                new_files.append(Image(path=file))
            else:
                new_files.append(file)
        self.files = new_files
        if "timestamp" not in self.data:
            self.data["timestamp"] = self.timestamp

    def set_flow_id(self, flow_id: str):
        """Set the id of the flow this message belongs to."""
        self.flow_id = flow_id

    def _get_file_content_dicts_sync(self) -> list:
        """Run ``get_file_content_dicts`` to completion from synchronous code.

        ``get_file_content_dicts`` is a coroutine function, so its result
        cannot be consumed directly by a synchronous caller.

        Returns:
            The list produced by ``get_file_content_dicts``.
        """
        # Imported locally so this helper does not depend on module-level imports.
        import asyncio
        from concurrent.futures import ThreadPoolExecutor

        try:
            asyncio.get_running_loop()
        except RuntimeError:
            # No event loop is running in this thread: drive the coroutine here.
            return asyncio.run(self.get_file_content_dicts())

        # A loop is already running in this thread and cannot be re-entered, so
        # run the coroutine on a fresh loop in a worker thread and wait for it.
        # NOTE(review): this blocks the calling thread until the files are
        # resolved, and assumes get_file_paths is not bound to the caller's
        # event loop -- confirm against its implementation.
        with ThreadPoolExecutor(max_workers=1) as pool:
            return pool.submit(lambda: asyncio.run(self.get_file_content_dicts())).result()

    def to_lc_message(
        self,
    ) -> BaseMessage:
        """
        Converts the Message to a LangChain BaseMessage.

        The ``sender`` field selects the message class: ``"User"`` or a
        missing/empty sender produces a ``HumanMessage``; any other sender
        produces an ``AIMessage``. When a human message has files, its content
        is a list holding the text part followed by the file content dicts.

        Returns:
            BaseMessage: The converted BaseMessage.
        """
        if self.text is None or not self.sender:
            logger.warning("Missing required keys ('text', 'sender') in Message, defaulting to HumanMessage.")

        # Only a plain string can be used as message content; a stream
        # (iterator) or None is replaced by an empty string.
        text = self.text if isinstance(self.text, str) else ""

        if self.sender == "User" or not self.sender:
            if self.files:
                contents = [{"type": "text", "text": text}]
                contents.extend(self._get_file_content_dicts_sync())
                return HumanMessage(content=contents)  # type: ignore
            return HumanMessage(content=text)

        return AIMessage(content=text)

    @classmethod
    def from_data(cls, data: "Data") -> "Message":
        """
        Builds a Message from a Data object.

        Args:
            data (Data): The source object; its ``text``, ``sender``,
                ``sender_name``, ``files``, ``session_id``, ``timestamp`` and
                ``flow_id`` attributes are read.

        Returns:
            Message: A new Message populated from those attributes.
        """
        return cls(
            text=data.text,
            sender=data.sender,
            sender_name=data.sender_name,
            files=data.files,
            session_id=data.session_id,
            timestamp=data.timestamp,
            flow_id=data.flow_id,
        )

    @field_serializer("text", mode="plain")
    def serialize_text(self, value):
        """Serialize ``text``; sync or async iterators are emitted as an empty string."""
        if isinstance(value, (AsyncIterator, Iterator)):
            return ""
        return value

    async def get_file_content_dicts(self):
        """Build one content dict per attached file.

        ``Image`` entries contribute ``to_content_dict()``; any other entry is
        rendered through ``ImagePromptTemplate`` into an ``image_url`` dict.

        Returns:
            list: The content dicts, in the order returned by ``get_file_paths``.
        """
        content_dicts = []
        files = await get_file_paths(self.files)

        for file in files:
            if isinstance(file, Image):
                content_dicts.append(file.to_content_dict())
            else:
                image_template = ImagePromptTemplate()
                image_prompt_value: ImagePromptValue = image_template.invoke(input={"path": file})
                content_dicts.append({"type": "image_url", "image_url": image_prompt_value.image_url})
        return content_dicts

    def load_lc_prompt(self):
        """Load the stored ``prompt`` back into a LangChain prompt object.

        For a ``ChatPromptTemplate`` the messages are rebuilt as concrete
        ``HumanMessage`` / ``SystemMessage`` / ``AIMessage`` instances.

        Returns:
            The loaded prompt.

        Raises:
            ValueError: If this message has no ``prompt`` entry.
        """
        if "prompt" not in self:
            raise ValueError("Prompt is required.")
        loaded_prompt = load(self.prompt)
        # Rebuild HumanMessages if they are instance of BaseMessage
        if isinstance(loaded_prompt, ChatPromptTemplate):
            messages = []
            for message in loaded_prompt.messages:
                if isinstance(message, HumanMessage):
                    messages.append(message)
                elif message.type == "human":
                    messages.append(HumanMessage(content=message.content))
                elif message.type == "system":
                    messages.append(SystemMessage(content=message.content))
                elif message.type == "ai":
                    messages.append(AIMessage(content=message.content))
                # NOTE(review): messages of any other type are dropped here --
                # confirm that this is intended.
            loaded_prompt.messages = messages
        return loaded_prompt

    @classmethod
    def from_lc_prompt(
        cls,
        prompt: BaseChatPromptTemplate,
    ):
        """Create a Message whose ``prompt`` is the JSON form of ``prompt``."""
        prompt_json = prompt.to_json()
        return cls(prompt=prompt_json)

    def format_text(self):
        """Format ``template`` with ``variables`` and store the result in ``text``.

        Variable values are converted with ``dict_values_to_string`` before
        formatting.

        Returns:
            str: The formatted prompt text.
        """
        prompt_template = PromptTemplate.from_template(self.template)
        variables_with_str_values = dict_values_to_string(self.variables)
        formatted_prompt = prompt_template.format(**variables_with_str_values)
        self.text = formatted_prompt
        return formatted_prompt

    @classmethod
    async def from_template_and_variables(cls, template: str, **variables):
        """Build a Message from a template and its variables.

        The formatted text becomes a ``HumanMessage``. File content dicts from
        any variable that is itself a Message with files are appended to that
        message's content. The resulting chat prompt is stored JSON-encoded in
        ``prompt`` and its messages in ``messages``.

        Args:
            template (str): The prompt template text.
            **variables: Values substituted into the template.

        Returns:
            Message: The populated instance.
        """
        instance = cls(template=template, variables=variables)
        text = instance.format_text()
        message = HumanMessage(content=text)
        # Collect file contents from every Message passed in as a variable.
        contents = []
        for value in variables.values():
            if isinstance(value, cls) and value.files:
                content_dicts = await value.get_file_content_dicts()
                contents.extend(content_dicts)
        if contents:
            message = HumanMessage(content=[{"type": "text", "text": text}] + contents)

        prompt_template = ChatPromptTemplate.from_messages([message])  # type: ignore
        instance.prompt = jsonable_encoder(prompt_template.to_json())
        instance.messages = instance.prompt.get("kwargs", {}).get("messages", [])
        return instance
|