ref: Remove unused build_lc_memory (#5228)

This commit is contained in:
Christophe Bornet 2024-12-16 16:34:48 +01:00 committed by GitHub
commit 8d66754380
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
5 changed files with 24 additions and 311 deletions

View file

@ -1,11 +1,8 @@
from langchain.memory import ConversationBufferMemory
from langflow.custom import Component
from langflow.field_typing import BaseChatMemory
from langflow.helpers.data import data_to_text
from langflow.inputs import HandleInput
from langflow.io import DropdownInput, IntInput, MessageTextInput, MultilineInput, Output
from langflow.memory import LCBuiltinChatMemory, aget_messages
from langflow.memory import aget_messages
from langflow.schema import Data
from langflow.schema.message import Message
from langflow.utils.constants import MESSAGE_SENDER_AI, MESSAGE_SENDER_USER
@ -114,7 +111,3 @@ class MemoryComponent(Component):
stored_text = data_to_text(self.template, await self.retrieve_messages())
self.status = stored_text
return Message(text=stored_text)
def build_lc_memory(self) -> BaseChatMemory:
chat_memory = self.memory or LCBuiltinChatMemory(flow_id=self.flow_id, session_id=self.session_id)
return ConversationBufferMemory(chat_memory=chat_memory)

View file

@ -573,7 +573,7 @@
"show": true,
"title_case": false,
"type": "code",
"value": "from langchain.memory import ConversationBufferMemory\n\nfrom langflow.custom import Component\nfrom langflow.field_typing import BaseChatMemory\nfrom langflow.helpers.data import data_to_text\nfrom langflow.inputs import HandleInput\nfrom langflow.io import DropdownInput, IntInput, MessageTextInput, MultilineInput, Output\nfrom langflow.memory import LCBuiltinChatMemory, aget_messages\nfrom langflow.schema import Data\nfrom langflow.schema.message import Message\nfrom langflow.utils.constants import MESSAGE_SENDER_AI, MESSAGE_SENDER_USER\n\n\nclass MemoryComponent(Component):\n display_name = \"Message History\"\n description = \"Retrieves stored chat messages from Langflow tables or an external memory.\"\n icon = \"message-square-more\"\n name = \"Memory\"\n\n inputs = [\n HandleInput(\n name=\"memory\",\n display_name=\"External Memory\",\n input_types=[\"BaseChatMessageHistory\"],\n info=\"Retrieve messages from an external memory. If empty, it will use the Langflow tables.\",\n ),\n DropdownInput(\n name=\"sender\",\n display_name=\"Sender Type\",\n options=[MESSAGE_SENDER_AI, MESSAGE_SENDER_USER, \"Machine and User\"],\n value=\"Machine and User\",\n info=\"Filter by sender type.\",\n advanced=True,\n ),\n MessageTextInput(\n name=\"sender_name\",\n display_name=\"Sender Name\",\n info=\"Filter by sender name.\",\n advanced=True,\n ),\n IntInput(\n name=\"n_messages\",\n display_name=\"Number of Messages\",\n value=100,\n info=\"Number of messages to retrieve.\",\n advanced=True,\n ),\n MessageTextInput(\n name=\"session_id\",\n display_name=\"Session ID\",\n info=\"The session ID of the chat. If empty, the current session ID parameter will be used.\",\n advanced=True,\n ),\n DropdownInput(\n name=\"order\",\n display_name=\"Order\",\n options=[\"Ascending\", \"Descending\"],\n value=\"Ascending\",\n info=\"Order of the messages.\",\n advanced=True,\n tool_mode=True,\n ),\n MultilineInput(\n name=\"template\",\n display_name=\"Template\",\n info=\"The template to use for formatting the data. \"\n \"It can contain the keys {text}, {sender} or any other key in the message data.\",\n value=\"{sender_name}: {text}\",\n advanced=True,\n ),\n ]\n\n outputs = [\n Output(display_name=\"Data\", name=\"messages\", method=\"retrieve_messages\"),\n Output(display_name=\"Text\", name=\"messages_text\", method=\"retrieve_messages_as_text\"),\n ]\n\n async def retrieve_messages(self) -> Data:\n sender = self.sender\n sender_name = self.sender_name\n session_id = self.session_id\n n_messages = self.n_messages\n order = \"DESC\" if self.order == \"Descending\" else \"ASC\"\n\n if sender == \"Machine and User\":\n sender = None\n\n if self.memory:\n # override session_id\n self.memory.session_id = session_id\n\n stored = await self.memory.aget_messages()\n # langchain memories are supposed to return messages in ascending order\n if order == \"DESC\":\n stored = stored[::-1]\n if n_messages:\n stored = stored[:n_messages]\n stored = [Message.from_lc_message(m) for m in stored]\n if sender:\n expected_type = MESSAGE_SENDER_AI if sender == MESSAGE_SENDER_AI else MESSAGE_SENDER_USER\n stored = [m for m in stored if m.type == expected_type]\n else:\n stored = await aget_messages(\n sender=sender,\n sender_name=sender_name,\n session_id=session_id,\n limit=n_messages,\n order=order,\n )\n self.status = stored\n return stored\n\n async def retrieve_messages_as_text(self) -> Message:\n stored_text = data_to_text(self.template, await self.retrieve_messages())\n self.status = stored_text\n return Message(text=stored_text)\n\n def build_lc_memory(self) -> BaseChatMemory:\n chat_memory = self.memory or LCBuiltinChatMemory(flow_id=self.flow_id, session_id=self.session_id)\n return ConversationBufferMemory(chat_memory=chat_memory)\n"
"value": "from langflow.custom import Component\nfrom langflow.helpers.data import data_to_text\nfrom langflow.inputs import HandleInput\nfrom langflow.io import DropdownInput, IntInput, MessageTextInput, MultilineInput, Output\nfrom langflow.memory import aget_messages\nfrom langflow.schema import Data\nfrom langflow.schema.message import Message\nfrom langflow.utils.constants import MESSAGE_SENDER_AI, MESSAGE_SENDER_USER\n\n\nclass MemoryComponent(Component):\n display_name = \"Message History\"\n description = \"Retrieves stored chat messages from Langflow tables or an external memory.\"\n icon = \"message-square-more\"\n name = \"Memory\"\n\n inputs = [\n HandleInput(\n name=\"memory\",\n display_name=\"External Memory\",\n input_types=[\"BaseChatMessageHistory\"],\n info=\"Retrieve messages from an external memory. If empty, it will use the Langflow tables.\",\n ),\n DropdownInput(\n name=\"sender\",\n display_name=\"Sender Type\",\n options=[MESSAGE_SENDER_AI, MESSAGE_SENDER_USER, \"Machine and User\"],\n value=\"Machine and User\",\n info=\"Filter by sender type.\",\n advanced=True,\n ),\n MessageTextInput(\n name=\"sender_name\",\n display_name=\"Sender Name\",\n info=\"Filter by sender name.\",\n advanced=True,\n ),\n IntInput(\n name=\"n_messages\",\n display_name=\"Number of Messages\",\n value=100,\n info=\"Number of messages to retrieve.\",\n advanced=True,\n ),\n MessageTextInput(\n name=\"session_id\",\n display_name=\"Session ID\",\n info=\"The session ID of the chat. If empty, the current session ID parameter will be used.\",\n advanced=True,\n ),\n DropdownInput(\n name=\"order\",\n display_name=\"Order\",\n options=[\"Ascending\", \"Descending\"],\n value=\"Ascending\",\n info=\"Order of the messages.\",\n advanced=True,\n tool_mode=True,\n ),\n MultilineInput(\n name=\"template\",\n display_name=\"Template\",\n info=\"The template to use for formatting the data. \"\n \"It can contain the keys {text}, {sender} or any other key in the message data.\",\n value=\"{sender_name}: {text}\",\n advanced=True,\n ),\n ]\n\n outputs = [\n Output(display_name=\"Data\", name=\"messages\", method=\"retrieve_messages\"),\n Output(display_name=\"Text\", name=\"messages_text\", method=\"retrieve_messages_as_text\"),\n ]\n\n async def retrieve_messages(self) -> Data:\n sender = self.sender\n sender_name = self.sender_name\n session_id = self.session_id\n n_messages = self.n_messages\n order = \"DESC\" if self.order == \"Descending\" else \"ASC\"\n\n if sender == \"Machine and User\":\n sender = None\n\n if self.memory:\n # override session_id\n self.memory.session_id = session_id\n\n stored = await self.memory.aget_messages()\n # langchain memories are supposed to return messages in ascending order\n if order == \"DESC\":\n stored = stored[::-1]\n if n_messages:\n stored = stored[:n_messages]\n stored = [Message.from_lc_message(m) for m in stored]\n if sender:\n expected_type = MESSAGE_SENDER_AI if sender == MESSAGE_SENDER_AI else MESSAGE_SENDER_USER\n stored = [m for m in stored if m.type == expected_type]\n else:\n stored = await aget_messages(\n sender=sender,\n sender_name=sender_name,\n session_id=session_id,\n limit=n_messages,\n order=order,\n )\n self.status = stored\n return stored\n\n async def retrieve_messages_as_text(self) -> Message:\n stored_text = data_to_text(self.template, await self.retrieve_messages())\n self.status = stored_text\n return Message(text=stored_text)\n"
},
"memory": {
"_input_type": "HandleInput",

View file

@ -1128,7 +1128,7 @@
"show": true,
"title_case": false,
"type": "code",
"value": "from langchain.memory import ConversationBufferMemory\n\nfrom langflow.custom import Component\nfrom langflow.field_typing import BaseChatMemory\nfrom langflow.helpers.data import data_to_text\nfrom langflow.inputs import HandleInput\nfrom langflow.io import DropdownInput, IntInput, MessageTextInput, MultilineInput, Output\nfrom langflow.memory import LCBuiltinChatMemory, aget_messages\nfrom langflow.schema import Data\nfrom langflow.schema.message import Message\nfrom langflow.utils.constants import MESSAGE_SENDER_AI, MESSAGE_SENDER_USER\n\n\nclass MemoryComponent(Component):\n display_name = \"Message History\"\n description = \"Retrieves stored chat messages from Langflow tables or an external memory.\"\n icon = \"message-square-more\"\n name = \"Memory\"\n\n inputs = [\n HandleInput(\n name=\"memory\",\n display_name=\"External Memory\",\n input_types=[\"BaseChatMessageHistory\"],\n info=\"Retrieve messages from an external memory. If empty, it will use the Langflow tables.\",\n ),\n DropdownInput(\n name=\"sender\",\n display_name=\"Sender Type\",\n options=[MESSAGE_SENDER_AI, MESSAGE_SENDER_USER, \"Machine and User\"],\n value=\"Machine and User\",\n info=\"Filter by sender type.\",\n advanced=True,\n ),\n MessageTextInput(\n name=\"sender_name\",\n display_name=\"Sender Name\",\n info=\"Filter by sender name.\",\n advanced=True,\n ),\n IntInput(\n name=\"n_messages\",\n display_name=\"Number of Messages\",\n value=100,\n info=\"Number of messages to retrieve.\",\n advanced=True,\n ),\n MessageTextInput(\n name=\"session_id\",\n display_name=\"Session ID\",\n info=\"The session ID of the chat. If empty, the current session ID parameter will be used.\",\n advanced=True,\n ),\n DropdownInput(\n name=\"order\",\n display_name=\"Order\",\n options=[\"Ascending\", \"Descending\"],\n value=\"Ascending\",\n info=\"Order of the messages.\",\n advanced=True,\n tool_mode=True,\n ),\n MultilineInput(\n name=\"template\",\n display_name=\"Template\",\n info=\"The template to use for formatting the data. \"\n \"It can contain the keys {text}, {sender} or any other key in the message data.\",\n value=\"{sender_name}: {text}\",\n advanced=True,\n ),\n ]\n\n outputs = [\n Output(display_name=\"Data\", name=\"messages\", method=\"retrieve_messages\"),\n Output(display_name=\"Text\", name=\"messages_text\", method=\"retrieve_messages_as_text\"),\n ]\n\n async def retrieve_messages(self) -> Data:\n sender = self.sender\n sender_name = self.sender_name\n session_id = self.session_id\n n_messages = self.n_messages\n order = \"DESC\" if self.order == \"Descending\" else \"ASC\"\n\n if sender == \"Machine and User\":\n sender = None\n\n if self.memory:\n # override session_id\n self.memory.session_id = session_id\n\n stored = await self.memory.aget_messages()\n # langchain memories are supposed to return messages in ascending order\n if order == \"DESC\":\n stored = stored[::-1]\n if n_messages:\n stored = stored[:n_messages]\n stored = [Message.from_lc_message(m) for m in stored]\n if sender:\n expected_type = MESSAGE_SENDER_AI if sender == MESSAGE_SENDER_AI else MESSAGE_SENDER_USER\n stored = [m for m in stored if m.type == expected_type]\n else:\n stored = await aget_messages(\n sender=sender,\n sender_name=sender_name,\n session_id=session_id,\n limit=n_messages,\n order=order,\n )\n self.status = stored\n return stored\n\n async def retrieve_messages_as_text(self) -> Message:\n stored_text = data_to_text(self.template, await self.retrieve_messages())\n self.status = stored_text\n return Message(text=stored_text)\n\n def build_lc_memory(self) -> BaseChatMemory:\n chat_memory = self.memory or LCBuiltinChatMemory(flow_id=self.flow_id, session_id=self.session_id)\n return ConversationBufferMemory(chat_memory=chat_memory)\n"
"value": "from langflow.custom import Component\nfrom langflow.helpers.data import data_to_text\nfrom langflow.inputs import HandleInput\nfrom langflow.io import DropdownInput, IntInput, MessageTextInput, MultilineInput, Output\nfrom langflow.memory import aget_messages\nfrom langflow.schema import Data\nfrom langflow.schema.message import Message\nfrom langflow.utils.constants import MESSAGE_SENDER_AI, MESSAGE_SENDER_USER\n\n\nclass MemoryComponent(Component):\n display_name = \"Message History\"\n description = \"Retrieves stored chat messages from Langflow tables or an external memory.\"\n icon = \"message-square-more\"\n name = \"Memory\"\n\n inputs = [\n HandleInput(\n name=\"memory\",\n display_name=\"External Memory\",\n input_types=[\"BaseChatMessageHistory\"],\n info=\"Retrieve messages from an external memory. If empty, it will use the Langflow tables.\",\n ),\n DropdownInput(\n name=\"sender\",\n display_name=\"Sender Type\",\n options=[MESSAGE_SENDER_AI, MESSAGE_SENDER_USER, \"Machine and User\"],\n value=\"Machine and User\",\n info=\"Filter by sender type.\",\n advanced=True,\n ),\n MessageTextInput(\n name=\"sender_name\",\n display_name=\"Sender Name\",\n info=\"Filter by sender name.\",\n advanced=True,\n ),\n IntInput(\n name=\"n_messages\",\n display_name=\"Number of Messages\",\n value=100,\n info=\"Number of messages to retrieve.\",\n advanced=True,\n ),\n MessageTextInput(\n name=\"session_id\",\n display_name=\"Session ID\",\n info=\"The session ID of the chat. If empty, the current session ID parameter will be used.\",\n advanced=True,\n ),\n DropdownInput(\n name=\"order\",\n display_name=\"Order\",\n options=[\"Ascending\", \"Descending\"],\n value=\"Ascending\",\n info=\"Order of the messages.\",\n advanced=True,\n tool_mode=True,\n ),\n MultilineInput(\n name=\"template\",\n display_name=\"Template\",\n info=\"The template to use for formatting the data. \"\n \"It can contain the keys {text}, {sender} or any other key in the message data.\",\n value=\"{sender_name}: {text}\",\n advanced=True,\n ),\n ]\n\n outputs = [\n Output(display_name=\"Data\", name=\"messages\", method=\"retrieve_messages\"),\n Output(display_name=\"Text\", name=\"messages_text\", method=\"retrieve_messages_as_text\"),\n ]\n\n async def retrieve_messages(self) -> Data:\n sender = self.sender\n sender_name = self.sender_name\n session_id = self.session_id\n n_messages = self.n_messages\n order = \"DESC\" if self.order == \"Descending\" else \"ASC\"\n\n if sender == \"Machine and User\":\n sender = None\n\n if self.memory:\n # override session_id\n self.memory.session_id = session_id\n\n stored = await self.memory.aget_messages()\n # langchain memories are supposed to return messages in ascending order\n if order == \"DESC\":\n stored = stored[::-1]\n if n_messages:\n stored = stored[:n_messages]\n stored = [Message.from_lc_message(m) for m in stored]\n if sender:\n expected_type = MESSAGE_SENDER_AI if sender == MESSAGE_SENDER_AI else MESSAGE_SENDER_USER\n stored = [m for m in stored if m.type == expected_type]\n else:\n stored = await aget_messages(\n sender=sender,\n sender_name=sender_name,\n session_id=session_id,\n limit=n_messages,\n order=order,\n )\n self.status = stored\n return stored\n\n async def retrieve_messages_as_text(self) -> Message:\n stored_text = data_to_text(self.template, await self.retrieve_messages())\n self.status = stored_text\n return Message(text=stored_text)\n"
},
"memory": {
"_input_type": "HandleInput",

View file

@ -6,13 +6,13 @@ from langchain_core.chat_history import BaseChatMessageHistory
from langchain_core.messages import BaseMessage
from loguru import logger
from sqlalchemy import delete
from sqlmodel import Session, col, select
from sqlmodel import col, select
from sqlmodel.ext.asyncio.session import AsyncSession
from langflow.schema.message import Message
from langflow.services.database.models.message.model import MessageRead, MessageTable
from langflow.services.deps import async_session_scope, session_scope
from langflow.utils.constants import MESSAGE_SENDER_AI, MESSAGE_SENDER_USER
from langflow.services.deps import async_session_scope
from langflow.utils.async_helpers import run_until_complete
def _get_variable_query(
@ -50,7 +50,9 @@ def get_messages(
flow_id: UUID | None = None,
limit: int | None = None,
) -> list[Message]:
"""Retrieves messages from the monitor service based on the provided filters.
"""DEPRECATED - Retrieves messages from the monitor service based on the provided filters.
DEPRECATED: Use `aget_messages` instead.
Args:
sender (Optional[str]): The sender of the messages (e.g., "Machine" or "User")
@ -64,10 +66,7 @@ def get_messages(
Returns:
List[Data]: A list of Data objects representing the retrieved messages.
"""
with session_scope() as session:
stmt = _get_variable_query(sender, sender_name, session_id, order_by, order, flow_id, limit)
messages = session.exec(stmt)
return [Message(**d.model_dump()) for d in messages]
return run_until_complete(aget_messages(sender, sender_name, session_id, order_by, order, flow_id, limit))
async def aget_messages(
@ -100,27 +99,11 @@ async def aget_messages(
def add_messages(messages: Message | list[Message], flow_id: str | UUID | None = None):
"""Add a message to the monitor service."""
if not isinstance(messages, list):
messages = [messages]
"""DEPRECATED - Add a message to the monitor service.
if not all(isinstance(message, Message) for message in messages):
types = ", ".join([str(type(message)) for message in messages])
msg = f"The messages must be instances of Message. Found: {types}"
raise ValueError(msg)
try:
# Convert flow_id to UUID if it's a string
if isinstance(flow_id, str):
flow_id = UUID(flow_id)
messages_models = [MessageTable.from_message(msg, flow_id=flow_id) for msg in messages]
with session_scope() as session:
messages_models = add_messagetables(messages_models, session)
return [Message(**message.model_dump()) for message in messages_models]
except Exception as e:
logger.exception(e)
raise
DEPRECATED: Use `aadd_messages` instead.
"""
return run_until_complete(aadd_messages(messages, flow_id=flow_id))
async def aadd_messages(messages: Message | list[Message], flow_id: str | UUID | None = None):
@ -143,31 +126,6 @@ async def aadd_messages(messages: Message | list[Message], flow_id: str | UUID |
raise
def update_messages(messages: Message | list[Message]) -> list[Message]:
if not isinstance(messages, list):
messages = [messages]
with session_scope() as session:
updated_messages: list[MessageTable] = []
for message in messages:
message_id = UUID(message.id) if isinstance(message.id, str) else message.id
msg = session.get(MessageTable, message_id)
if msg:
if hasattr(message, "data"):
msg = msg.sqlmodel_update(message.data)
else:
msg = msg.sqlmodel_update(message.model_dump(exclude_unset=True, exclude_none=True))
if isinstance(msg.flow_id, str):
msg.flow_id = UUID(msg.flow_id)
session.add(msg)
session.commit()
session.refresh(msg)
updated_messages.append(msg)
else:
logger.warning(f"Message with id {message.id} not found")
return [MessageRead.model_validate(message, from_attributes=True) for message in updated_messages]
async def aupdate_messages(messages: Message | list[Message]) -> list[Message]:
if not isinstance(messages, list):
messages = [messages]
@ -190,26 +148,6 @@ async def aupdate_messages(messages: Message | list[Message]) -> list[Message]:
return [MessageRead.model_validate(message, from_attributes=True) for message in updated_messages]
def add_messagetables(messages: list[MessageTable], session: Session):
for message in messages:
try:
session.add(message)
session.commit()
session.refresh(message)
except Exception as e:
logger.exception(e)
raise
new_messages = []
for msg in messages:
msg.properties = json.loads(msg.properties) if isinstance(msg.properties, str) else msg.properties # type: ignore[arg-type]
msg.content_blocks = [json.loads(j) if isinstance(j, str) else j for j in msg.content_blocks] # type: ignore[arg-type]
msg.category = msg.category or ""
new_messages.append(msg)
return [MessageRead.model_validate(message, from_attributes=True) for message in new_messages]
async def aadd_messagetables(messages: list[MessageTable], session: AsyncSession):
try:
for message in messages:
@ -232,17 +170,14 @@ async def aadd_messagetables(messages: list[MessageTable], session: AsyncSession
def delete_messages(session_id: str) -> None:
"""Delete messages from the monitor service based on the provided session ID.
"""DEPRECATED - Delete messages from the monitor service based on the provided session ID.
DEPRECATED: Use `adelete_messages` instead.
Args:
session_id (str): The session ID associated with the messages to delete.
"""
with session_scope() as session:
session.exec(
delete(MessageTable)
.where(col(MessageTable.session_id) == session_id)
.execution_options(synchronize_session="fetch")
)
return run_until_complete(adelete_messages(session_id))
async def adelete_messages(session_id: str) -> None:
@ -277,7 +212,9 @@ def store_message(
message: Message,
flow_id: str | UUID | None = None,
) -> list[Message]:
"""Stores a message in the memory.
"""DEPRECATED: Stores a message in the memory.
DEPRECATED: Use `astore_message` instead.
Args:
message (Message): The message to store.
@ -290,31 +227,7 @@ def store_message(
Raises:
ValueError: If any of the required parameters (session_id, sender, sender_name) is not provided.
"""
if not message:
logger.warning("No message provided.")
return []
# Convert flow_id to UUID if it's a string
if isinstance(flow_id, str):
flow_id = UUID(flow_id)
required_fields = ["session_id", "sender", "sender_name"]
missing_fields = [field for field in required_fields if not getattr(message, field)]
if missing_fields:
missing_descriptions = {
"session_id": "session_id (unique conversation identifier)",
"sender": f"sender (e.g., '{MESSAGE_SENDER_USER}' or '{MESSAGE_SENDER_AI}')",
"sender_name": "sender_name (display name, e.g., 'User' or 'Assistant')",
}
missing = ", ".join(missing_descriptions[field] for field in missing_fields)
msg = (
f"It looks like we're missing some important information: {missing}. "
"Please ensure that your message includes all the required fields."
)
raise ValueError(msg)
if hasattr(message, "id") and message.id:
return update_messages([message])
return add_messages([message], flow_id=flow_id)
return run_until_complete(astore_message(message, flow_id=flow_id))
async def astore_message(
@ -349,6 +262,8 @@ async def astore_message(
class LCBuiltinChatMemory(BaseChatMessageHistory):
"""DEPRECATED: Kept for backward compatibility."""
def __init__(
self,
flow_id: str,

View file

@ -6,7 +6,6 @@ from langflow.memory import (
aadd_messages,
aadd_messagetables,
add_messages,
add_messagetables,
adelete_messages,
aget_messages,
astore_message,
@ -14,7 +13,6 @@ from langflow.memory import (
delete_messages,
get_messages,
store_message,
update_messages,
)
from langflow.schema.content_block import ContentBlock
from langflow.schema.content_types import TextContent, ToolContent
@ -94,14 +92,6 @@ async def test_aadd_messages():
assert messages[0].text == "New Test message"
@pytest.mark.usefixtures("client")
def test_add_messagetables(session):
messages = [MessageTable(text="New Test message", sender="User", sender_name="User", session_id="new_session_id")]
added_messages = add_messagetables(messages, session)
assert len(added_messages) == 1
assert added_messages[0].text == "New Test message"
@pytest.mark.usefixtures("client")
async def test_aadd_messagetables(async_session):
messages = [MessageTable(text="New Test message", sender="User", sender_name="User", session_id="new_session_id")]
@ -179,191 +169,6 @@ def test_convert_to_langchain(method_name):
assert len(list(iterator)) == 2
@pytest.mark.usefixtures("client")
def test_update_single_message(created_message):
# Modify the message
created_message.text = "Updated message"
updated = update_messages(created_message)
assert len(updated) == 1
assert updated[0].text == "Updated message"
assert updated[0].id == created_message.id
@pytest.mark.usefixtures("client")
def test_update_multiple_messages(created_messages):
# Modify the messages
for i, message in enumerate(created_messages):
message.text = f"Updated message {i}"
updated = update_messages(created_messages)
assert len(updated) == len(created_messages)
for i, message in enumerate(updated):
assert message.text == f"Updated message {i}"
assert message.id == created_messages[i].id
@pytest.mark.usefixtures("client")
def test_update_nonexistent_message():
# Create a message with a non-existent UUID
message = MessageRead(
id=uuid4(), # Generate a random UUID that won't exist in the database
text="Test message",
sender="User",
sender_name="User",
session_id="session_id",
flow_id=uuid4(),
)
updated = update_messages(message)
assert len(updated) == 0
@pytest.mark.usefixtures("client")
def test_update_mixed_messages(created_messages):
# Create a mix of existing and non-existing messages
nonexistent_message = MessageRead(
id=uuid4(), # Generate a random UUID that won't exist in the database
text="Test message",
sender="User",
sender_name="User",
session_id="session_id",
flow_id=uuid4(),
)
messages_to_update = created_messages[:1] + [nonexistent_message]
created_messages[0].text = "Updated existing message"
updated = update_messages(messages_to_update)
assert len(updated) == 1
assert updated[0].text == "Updated existing message"
assert updated[0].id == created_messages[0].id
assert isinstance(updated[0].id, UUID) # Verify ID is UUID type
@pytest.mark.usefixtures("client")
def test_update_message_with_timestamp(created_message):
# Set a specific timestamp
new_timestamp = datetime(2024, 1, 1, 12, 0, 0, tzinfo=timezone.utc)
created_message.timestamp = new_timestamp
created_message.text = "Updated message with timestamp"
updated = update_messages(created_message)
assert len(updated) == 1
assert updated[0].text == "Updated message with timestamp"
# Compare timestamps without timezone info since DB doesn't preserve it
assert updated[0].timestamp.replace(tzinfo=None) == new_timestamp.replace(tzinfo=None)
assert updated[0].id == created_message.id
@pytest.mark.usefixtures("client")
def test_update_multiple_messages_with_timestamps(created_messages):
# Modify messages with different timestamps
for i, message in enumerate(created_messages):
message.text = f"Updated message {i}"
message.timestamp = datetime(2024, 1, 1, i, 0, 0, tzinfo=timezone.utc)
updated = update_messages(created_messages)
assert len(updated) == len(created_messages)
for i, message in enumerate(updated):
assert message.text == f"Updated message {i}"
# Compare timestamps without timezone info
expected_timestamp = datetime(2024, 1, 1, i, 0, 0, tzinfo=timezone.utc)
assert message.timestamp.replace(tzinfo=None) == expected_timestamp.replace(tzinfo=None)
assert message.id == created_messages[i].id
@pytest.mark.usefixtures("client")
def test_update_message_with_content_blocks(created_message):
# Create a content block using proper models
text_content = TextContent(
type="text", text="Test content", duration=5, header={"title": "Test Header", "icon": "TestIcon"}
)
tool_content = ToolContent(type="tool_use", name="test_tool", tool_input={"param": "value"}, duration=10)
content_block = ContentBlock(title="Test Block", contents=[text_content, tool_content], allow_markdown=True)
created_message.content_blocks = [content_block]
created_message.text = "Message with content blocks"
updated = update_messages(created_message)
assert len(updated) == 1
assert updated[0].text == "Message with content blocks"
assert len(updated[0].content_blocks) == 1
# Verify the content block structure
updated_block = updated[0].content_blocks[0]
assert updated_block.title == "Test Block"
assert len(updated_block.contents) == 2
# Verify text content
text_content = updated_block.contents[0]
assert text_content.type == "text"
assert text_content.text == "Test content"
assert text_content.duration == 5
assert text_content.header["title"] == "Test Header"
# Verify tool content
tool_content = updated_block.contents[1]
assert tool_content.type == "tool_use"
assert tool_content.name == "test_tool"
assert tool_content.tool_input == {"param": "value"}
assert tool_content.duration == 10
@pytest.mark.usefixtures("client")
def test_update_message_with_nested_properties(created_message):
# Create a text content with nested properties
text_content = TextContent(
type="text", text="Test content", header={"title": "Test Header", "icon": "TestIcon"}, duration=15
)
content_block = ContentBlock(
title="Test Properties",
contents=[text_content],
allow_markdown=True,
media_url=["http://example.com/image.jpg"],
)
# Set properties according to the Properties model structure
created_message.properties = Properties(
text_color="blue",
background_color="white",
edited=False,
source=Source(id="test_id", display_name="Test Source", source="test"),
icon="TestIcon",
allow_markdown=True,
state="complete",
targets=[],
)
created_message.text = "Message with nested properties"
created_message.content_blocks = [content_block]
updated = update_messages(created_message)
assert len(updated) == 1
assert updated[0].text == "Message with nested properties"
# Verify the properties were properly serialized and stored
assert updated[0].properties.text_color == "blue"
assert updated[0].properties.background_color == "white"
assert updated[0].properties.edited is False
assert updated[0].properties.source.id == "test_id"
assert updated[0].properties.source.display_name == "Test Source"
assert updated[0].properties.source.source == "test"
assert updated[0].properties.icon == "TestIcon"
assert updated[0].properties.allow_markdown is True
assert updated[0].properties.state == "complete"
assert updated[0].properties.targets == []
@pytest.mark.usefixtures("client")
async def test_aupdate_single_message(created_message):
# Modify the message