ref: Remove unused build_lc_memory (#5228)
This commit is contained in:
parent
e8667009b7
commit
8d66754380
5 changed files with 24 additions and 311 deletions
|
|
@ -1,11 +1,8 @@
|
|||
from langchain.memory import ConversationBufferMemory
|
||||
|
||||
from langflow.custom import Component
|
||||
from langflow.field_typing import BaseChatMemory
|
||||
from langflow.helpers.data import data_to_text
|
||||
from langflow.inputs import HandleInput
|
||||
from langflow.io import DropdownInput, IntInput, MessageTextInput, MultilineInput, Output
|
||||
from langflow.memory import LCBuiltinChatMemory, aget_messages
|
||||
from langflow.memory import aget_messages
|
||||
from langflow.schema import Data
|
||||
from langflow.schema.message import Message
|
||||
from langflow.utils.constants import MESSAGE_SENDER_AI, MESSAGE_SENDER_USER
|
||||
|
|
@ -114,7 +111,3 @@ class MemoryComponent(Component):
|
|||
stored_text = data_to_text(self.template, await self.retrieve_messages())
|
||||
self.status = stored_text
|
||||
return Message(text=stored_text)
|
||||
|
||||
def build_lc_memory(self) -> BaseChatMemory:
|
||||
chat_memory = self.memory or LCBuiltinChatMemory(flow_id=self.flow_id, session_id=self.session_id)
|
||||
return ConversationBufferMemory(chat_memory=chat_memory)
|
||||
|
|
|
|||
|
|
@ -573,7 +573,7 @@
|
|||
"show": true,
|
||||
"title_case": false,
|
||||
"type": "code",
|
||||
"value": "from langchain.memory import ConversationBufferMemory\n\nfrom langflow.custom import Component\nfrom langflow.field_typing import BaseChatMemory\nfrom langflow.helpers.data import data_to_text\nfrom langflow.inputs import HandleInput\nfrom langflow.io import DropdownInput, IntInput, MessageTextInput, MultilineInput, Output\nfrom langflow.memory import LCBuiltinChatMemory, aget_messages\nfrom langflow.schema import Data\nfrom langflow.schema.message import Message\nfrom langflow.utils.constants import MESSAGE_SENDER_AI, MESSAGE_SENDER_USER\n\n\nclass MemoryComponent(Component):\n display_name = \"Message History\"\n description = \"Retrieves stored chat messages from Langflow tables or an external memory.\"\n icon = \"message-square-more\"\n name = \"Memory\"\n\n inputs = [\n HandleInput(\n name=\"memory\",\n display_name=\"External Memory\",\n input_types=[\"BaseChatMessageHistory\"],\n info=\"Retrieve messages from an external memory. If empty, it will use the Langflow tables.\",\n ),\n DropdownInput(\n name=\"sender\",\n display_name=\"Sender Type\",\n options=[MESSAGE_SENDER_AI, MESSAGE_SENDER_USER, \"Machine and User\"],\n value=\"Machine and User\",\n info=\"Filter by sender type.\",\n advanced=True,\n ),\n MessageTextInput(\n name=\"sender_name\",\n display_name=\"Sender Name\",\n info=\"Filter by sender name.\",\n advanced=True,\n ),\n IntInput(\n name=\"n_messages\",\n display_name=\"Number of Messages\",\n value=100,\n info=\"Number of messages to retrieve.\",\n advanced=True,\n ),\n MessageTextInput(\n name=\"session_id\",\n display_name=\"Session ID\",\n info=\"The session ID of the chat. If empty, the current session ID parameter will be used.\",\n advanced=True,\n ),\n DropdownInput(\n name=\"order\",\n display_name=\"Order\",\n options=[\"Ascending\", \"Descending\"],\n value=\"Ascending\",\n info=\"Order of the messages.\",\n advanced=True,\n tool_mode=True,\n ),\n MultilineInput(\n name=\"template\",\n display_name=\"Template\",\n info=\"The template to use for formatting the data. \"\n \"It can contain the keys {text}, {sender} or any other key in the message data.\",\n value=\"{sender_name}: {text}\",\n advanced=True,\n ),\n ]\n\n outputs = [\n Output(display_name=\"Data\", name=\"messages\", method=\"retrieve_messages\"),\n Output(display_name=\"Text\", name=\"messages_text\", method=\"retrieve_messages_as_text\"),\n ]\n\n async def retrieve_messages(self) -> Data:\n sender = self.sender\n sender_name = self.sender_name\n session_id = self.session_id\n n_messages = self.n_messages\n order = \"DESC\" if self.order == \"Descending\" else \"ASC\"\n\n if sender == \"Machine and User\":\n sender = None\n\n if self.memory:\n # override session_id\n self.memory.session_id = session_id\n\n stored = await self.memory.aget_messages()\n # langchain memories are supposed to return messages in ascending order\n if order == \"DESC\":\n stored = stored[::-1]\n if n_messages:\n stored = stored[:n_messages]\n stored = [Message.from_lc_message(m) for m in stored]\n if sender:\n expected_type = MESSAGE_SENDER_AI if sender == MESSAGE_SENDER_AI else MESSAGE_SENDER_USER\n stored = [m for m in stored if m.type == expected_type]\n else:\n stored = await aget_messages(\n sender=sender,\n sender_name=sender_name,\n session_id=session_id,\n limit=n_messages,\n order=order,\n )\n self.status = stored\n return stored\n\n async def retrieve_messages_as_text(self) -> Message:\n stored_text = data_to_text(self.template, await self.retrieve_messages())\n self.status = stored_text\n return Message(text=stored_text)\n\n def build_lc_memory(self) -> BaseChatMemory:\n chat_memory = self.memory or LCBuiltinChatMemory(flow_id=self.flow_id, session_id=self.session_id)\n return ConversationBufferMemory(chat_memory=chat_memory)\n"
|
||||
"value": "from langflow.custom import Component\nfrom langflow.helpers.data import data_to_text\nfrom langflow.inputs import HandleInput\nfrom langflow.io import DropdownInput, IntInput, MessageTextInput, MultilineInput, Output\nfrom langflow.memory import aget_messages\nfrom langflow.schema import Data\nfrom langflow.schema.message import Message\nfrom langflow.utils.constants import MESSAGE_SENDER_AI, MESSAGE_SENDER_USER\n\n\nclass MemoryComponent(Component):\n display_name = \"Message History\"\n description = \"Retrieves stored chat messages from Langflow tables or an external memory.\"\n icon = \"message-square-more\"\n name = \"Memory\"\n\n inputs = [\n HandleInput(\n name=\"memory\",\n display_name=\"External Memory\",\n input_types=[\"BaseChatMessageHistory\"],\n info=\"Retrieve messages from an external memory. If empty, it will use the Langflow tables.\",\n ),\n DropdownInput(\n name=\"sender\",\n display_name=\"Sender Type\",\n options=[MESSAGE_SENDER_AI, MESSAGE_SENDER_USER, \"Machine and User\"],\n value=\"Machine and User\",\n info=\"Filter by sender type.\",\n advanced=True,\n ),\n MessageTextInput(\n name=\"sender_name\",\n display_name=\"Sender Name\",\n info=\"Filter by sender name.\",\n advanced=True,\n ),\n IntInput(\n name=\"n_messages\",\n display_name=\"Number of Messages\",\n value=100,\n info=\"Number of messages to retrieve.\",\n advanced=True,\n ),\n MessageTextInput(\n name=\"session_id\",\n display_name=\"Session ID\",\n info=\"The session ID of the chat. If empty, the current session ID parameter will be used.\",\n advanced=True,\n ),\n DropdownInput(\n name=\"order\",\n display_name=\"Order\",\n options=[\"Ascending\", \"Descending\"],\n value=\"Ascending\",\n info=\"Order of the messages.\",\n advanced=True,\n tool_mode=True,\n ),\n MultilineInput(\n name=\"template\",\n display_name=\"Template\",\n info=\"The template to use for formatting the data. \"\n \"It can contain the keys {text}, {sender} or any other key in the message data.\",\n value=\"{sender_name}: {text}\",\n advanced=True,\n ),\n ]\n\n outputs = [\n Output(display_name=\"Data\", name=\"messages\", method=\"retrieve_messages\"),\n Output(display_name=\"Text\", name=\"messages_text\", method=\"retrieve_messages_as_text\"),\n ]\n\n async def retrieve_messages(self) -> Data:\n sender = self.sender\n sender_name = self.sender_name\n session_id = self.session_id\n n_messages = self.n_messages\n order = \"DESC\" if self.order == \"Descending\" else \"ASC\"\n\n if sender == \"Machine and User\":\n sender = None\n\n if self.memory:\n # override session_id\n self.memory.session_id = session_id\n\n stored = await self.memory.aget_messages()\n # langchain memories are supposed to return messages in ascending order\n if order == \"DESC\":\n stored = stored[::-1]\n if n_messages:\n stored = stored[:n_messages]\n stored = [Message.from_lc_message(m) for m in stored]\n if sender:\n expected_type = MESSAGE_SENDER_AI if sender == MESSAGE_SENDER_AI else MESSAGE_SENDER_USER\n stored = [m for m in stored if m.type == expected_type]\n else:\n stored = await aget_messages(\n sender=sender,\n sender_name=sender_name,\n session_id=session_id,\n limit=n_messages,\n order=order,\n )\n self.status = stored\n return stored\n\n async def retrieve_messages_as_text(self) -> Message:\n stored_text = data_to_text(self.template, await self.retrieve_messages())\n self.status = stored_text\n return Message(text=stored_text)\n"
|
||||
},
|
||||
"memory": {
|
||||
"_input_type": "HandleInput",
|
||||
|
|
|
|||
|
|
@ -1128,7 +1128,7 @@
|
|||
"show": true,
|
||||
"title_case": false,
|
||||
"type": "code",
|
||||
"value": "from langchain.memory import ConversationBufferMemory\n\nfrom langflow.custom import Component\nfrom langflow.field_typing import BaseChatMemory\nfrom langflow.helpers.data import data_to_text\nfrom langflow.inputs import HandleInput\nfrom langflow.io import DropdownInput, IntInput, MessageTextInput, MultilineInput, Output\nfrom langflow.memory import LCBuiltinChatMemory, aget_messages\nfrom langflow.schema import Data\nfrom langflow.schema.message import Message\nfrom langflow.utils.constants import MESSAGE_SENDER_AI, MESSAGE_SENDER_USER\n\n\nclass MemoryComponent(Component):\n display_name = \"Message History\"\n description = \"Retrieves stored chat messages from Langflow tables or an external memory.\"\n icon = \"message-square-more\"\n name = \"Memory\"\n\n inputs = [\n HandleInput(\n name=\"memory\",\n display_name=\"External Memory\",\n input_types=[\"BaseChatMessageHistory\"],\n info=\"Retrieve messages from an external memory. If empty, it will use the Langflow tables.\",\n ),\n DropdownInput(\n name=\"sender\",\n display_name=\"Sender Type\",\n options=[MESSAGE_SENDER_AI, MESSAGE_SENDER_USER, \"Machine and User\"],\n value=\"Machine and User\",\n info=\"Filter by sender type.\",\n advanced=True,\n ),\n MessageTextInput(\n name=\"sender_name\",\n display_name=\"Sender Name\",\n info=\"Filter by sender name.\",\n advanced=True,\n ),\n IntInput(\n name=\"n_messages\",\n display_name=\"Number of Messages\",\n value=100,\n info=\"Number of messages to retrieve.\",\n advanced=True,\n ),\n MessageTextInput(\n name=\"session_id\",\n display_name=\"Session ID\",\n info=\"The session ID of the chat. If empty, the current session ID parameter will be used.\",\n advanced=True,\n ),\n DropdownInput(\n name=\"order\",\n display_name=\"Order\",\n options=[\"Ascending\", \"Descending\"],\n value=\"Ascending\",\n info=\"Order of the messages.\",\n advanced=True,\n tool_mode=True,\n ),\n MultilineInput(\n name=\"template\",\n display_name=\"Template\",\n info=\"The template to use for formatting the data. \"\n \"It can contain the keys {text}, {sender} or any other key in the message data.\",\n value=\"{sender_name}: {text}\",\n advanced=True,\n ),\n ]\n\n outputs = [\n Output(display_name=\"Data\", name=\"messages\", method=\"retrieve_messages\"),\n Output(display_name=\"Text\", name=\"messages_text\", method=\"retrieve_messages_as_text\"),\n ]\n\n async def retrieve_messages(self) -> Data:\n sender = self.sender\n sender_name = self.sender_name\n session_id = self.session_id\n n_messages = self.n_messages\n order = \"DESC\" if self.order == \"Descending\" else \"ASC\"\n\n if sender == \"Machine and User\":\n sender = None\n\n if self.memory:\n # override session_id\n self.memory.session_id = session_id\n\n stored = await self.memory.aget_messages()\n # langchain memories are supposed to return messages in ascending order\n if order == \"DESC\":\n stored = stored[::-1]\n if n_messages:\n stored = stored[:n_messages]\n stored = [Message.from_lc_message(m) for m in stored]\n if sender:\n expected_type = MESSAGE_SENDER_AI if sender == MESSAGE_SENDER_AI else MESSAGE_SENDER_USER\n stored = [m for m in stored if m.type == expected_type]\n else:\n stored = await aget_messages(\n sender=sender,\n sender_name=sender_name,\n session_id=session_id,\n limit=n_messages,\n order=order,\n )\n self.status = stored\n return stored\n\n async def retrieve_messages_as_text(self) -> Message:\n stored_text = data_to_text(self.template, await self.retrieve_messages())\n self.status = stored_text\n return Message(text=stored_text)\n\n def build_lc_memory(self) -> BaseChatMemory:\n chat_memory = self.memory or LCBuiltinChatMemory(flow_id=self.flow_id, session_id=self.session_id)\n return ConversationBufferMemory(chat_memory=chat_memory)\n"
|
||||
"value": "from langflow.custom import Component\nfrom langflow.helpers.data import data_to_text\nfrom langflow.inputs import HandleInput\nfrom langflow.io import DropdownInput, IntInput, MessageTextInput, MultilineInput, Output\nfrom langflow.memory import aget_messages\nfrom langflow.schema import Data\nfrom langflow.schema.message import Message\nfrom langflow.utils.constants import MESSAGE_SENDER_AI, MESSAGE_SENDER_USER\n\n\nclass MemoryComponent(Component):\n display_name = \"Message History\"\n description = \"Retrieves stored chat messages from Langflow tables or an external memory.\"\n icon = \"message-square-more\"\n name = \"Memory\"\n\n inputs = [\n HandleInput(\n name=\"memory\",\n display_name=\"External Memory\",\n input_types=[\"BaseChatMessageHistory\"],\n info=\"Retrieve messages from an external memory. If empty, it will use the Langflow tables.\",\n ),\n DropdownInput(\n name=\"sender\",\n display_name=\"Sender Type\",\n options=[MESSAGE_SENDER_AI, MESSAGE_SENDER_USER, \"Machine and User\"],\n value=\"Machine and User\",\n info=\"Filter by sender type.\",\n advanced=True,\n ),\n MessageTextInput(\n name=\"sender_name\",\n display_name=\"Sender Name\",\n info=\"Filter by sender name.\",\n advanced=True,\n ),\n IntInput(\n name=\"n_messages\",\n display_name=\"Number of Messages\",\n value=100,\n info=\"Number of messages to retrieve.\",\n advanced=True,\n ),\n MessageTextInput(\n name=\"session_id\",\n display_name=\"Session ID\",\n info=\"The session ID of the chat. If empty, the current session ID parameter will be used.\",\n advanced=True,\n ),\n DropdownInput(\n name=\"order\",\n display_name=\"Order\",\n options=[\"Ascending\", \"Descending\"],\n value=\"Ascending\",\n info=\"Order of the messages.\",\n advanced=True,\n tool_mode=True,\n ),\n MultilineInput(\n name=\"template\",\n display_name=\"Template\",\n info=\"The template to use for formatting the data. \"\n \"It can contain the keys {text}, {sender} or any other key in the message data.\",\n value=\"{sender_name}: {text}\",\n advanced=True,\n ),\n ]\n\n outputs = [\n Output(display_name=\"Data\", name=\"messages\", method=\"retrieve_messages\"),\n Output(display_name=\"Text\", name=\"messages_text\", method=\"retrieve_messages_as_text\"),\n ]\n\n async def retrieve_messages(self) -> Data:\n sender = self.sender\n sender_name = self.sender_name\n session_id = self.session_id\n n_messages = self.n_messages\n order = \"DESC\" if self.order == \"Descending\" else \"ASC\"\n\n if sender == \"Machine and User\":\n sender = None\n\n if self.memory:\n # override session_id\n self.memory.session_id = session_id\n\n stored = await self.memory.aget_messages()\n # langchain memories are supposed to return messages in ascending order\n if order == \"DESC\":\n stored = stored[::-1]\n if n_messages:\n stored = stored[:n_messages]\n stored = [Message.from_lc_message(m) for m in stored]\n if sender:\n expected_type = MESSAGE_SENDER_AI if sender == MESSAGE_SENDER_AI else MESSAGE_SENDER_USER\n stored = [m for m in stored if m.type == expected_type]\n else:\n stored = await aget_messages(\n sender=sender,\n sender_name=sender_name,\n session_id=session_id,\n limit=n_messages,\n order=order,\n )\n self.status = stored\n return stored\n\n async def retrieve_messages_as_text(self) -> Message:\n stored_text = data_to_text(self.template, await self.retrieve_messages())\n self.status = stored_text\n return Message(text=stored_text)\n"
|
||||
},
|
||||
"memory": {
|
||||
"_input_type": "HandleInput",
|
||||
|
|
|
|||
|
|
@ -6,13 +6,13 @@ from langchain_core.chat_history import BaseChatMessageHistory
|
|||
from langchain_core.messages import BaseMessage
|
||||
from loguru import logger
|
||||
from sqlalchemy import delete
|
||||
from sqlmodel import Session, col, select
|
||||
from sqlmodel import col, select
|
||||
from sqlmodel.ext.asyncio.session import AsyncSession
|
||||
|
||||
from langflow.schema.message import Message
|
||||
from langflow.services.database.models.message.model import MessageRead, MessageTable
|
||||
from langflow.services.deps import async_session_scope, session_scope
|
||||
from langflow.utils.constants import MESSAGE_SENDER_AI, MESSAGE_SENDER_USER
|
||||
from langflow.services.deps import async_session_scope
|
||||
from langflow.utils.async_helpers import run_until_complete
|
||||
|
||||
|
||||
def _get_variable_query(
|
||||
|
|
@ -50,7 +50,9 @@ def get_messages(
|
|||
flow_id: UUID | None = None,
|
||||
limit: int | None = None,
|
||||
) -> list[Message]:
|
||||
"""Retrieves messages from the monitor service based on the provided filters.
|
||||
"""DEPRECATED - Retrieves messages from the monitor service based on the provided filters.
|
||||
|
||||
DEPRECATED: Use `aget_messages` instead.
|
||||
|
||||
Args:
|
||||
sender (Optional[str]): The sender of the messages (e.g., "Machine" or "User")
|
||||
|
|
@ -64,10 +66,7 @@ def get_messages(
|
|||
Returns:
|
||||
List[Data]: A list of Data objects representing the retrieved messages.
|
||||
"""
|
||||
with session_scope() as session:
|
||||
stmt = _get_variable_query(sender, sender_name, session_id, order_by, order, flow_id, limit)
|
||||
messages = session.exec(stmt)
|
||||
return [Message(**d.model_dump()) for d in messages]
|
||||
return run_until_complete(aget_messages(sender, sender_name, session_id, order_by, order, flow_id, limit))
|
||||
|
||||
|
||||
async def aget_messages(
|
||||
|
|
@ -100,27 +99,11 @@ async def aget_messages(
|
|||
|
||||
|
||||
def add_messages(messages: Message | list[Message], flow_id: str | UUID | None = None):
|
||||
"""Add a message to the monitor service."""
|
||||
if not isinstance(messages, list):
|
||||
messages = [messages]
|
||||
"""DEPRECATED - Add a message to the monitor service.
|
||||
|
||||
if not all(isinstance(message, Message) for message in messages):
|
||||
types = ", ".join([str(type(message)) for message in messages])
|
||||
msg = f"The messages must be instances of Message. Found: {types}"
|
||||
raise ValueError(msg)
|
||||
|
||||
try:
|
||||
# Convert flow_id to UUID if it's a string
|
||||
if isinstance(flow_id, str):
|
||||
flow_id = UUID(flow_id)
|
||||
|
||||
messages_models = [MessageTable.from_message(msg, flow_id=flow_id) for msg in messages]
|
||||
with session_scope() as session:
|
||||
messages_models = add_messagetables(messages_models, session)
|
||||
return [Message(**message.model_dump()) for message in messages_models]
|
||||
except Exception as e:
|
||||
logger.exception(e)
|
||||
raise
|
||||
DEPRECATED: Use `aadd_messages` instead.
|
||||
"""
|
||||
return run_until_complete(aadd_messages(messages, flow_id=flow_id))
|
||||
|
||||
|
||||
async def aadd_messages(messages: Message | list[Message], flow_id: str | UUID | None = None):
|
||||
|
|
@ -143,31 +126,6 @@ async def aadd_messages(messages: Message | list[Message], flow_id: str | UUID |
|
|||
raise
|
||||
|
||||
|
||||
def update_messages(messages: Message | list[Message]) -> list[Message]:
|
||||
if not isinstance(messages, list):
|
||||
messages = [messages]
|
||||
|
||||
with session_scope() as session:
|
||||
updated_messages: list[MessageTable] = []
|
||||
for message in messages:
|
||||
message_id = UUID(message.id) if isinstance(message.id, str) else message.id
|
||||
msg = session.get(MessageTable, message_id)
|
||||
if msg:
|
||||
if hasattr(message, "data"):
|
||||
msg = msg.sqlmodel_update(message.data)
|
||||
else:
|
||||
msg = msg.sqlmodel_update(message.model_dump(exclude_unset=True, exclude_none=True))
|
||||
if isinstance(msg.flow_id, str):
|
||||
msg.flow_id = UUID(msg.flow_id)
|
||||
session.add(msg)
|
||||
session.commit()
|
||||
session.refresh(msg)
|
||||
updated_messages.append(msg)
|
||||
else:
|
||||
logger.warning(f"Message with id {message.id} not found")
|
||||
return [MessageRead.model_validate(message, from_attributes=True) for message in updated_messages]
|
||||
|
||||
|
||||
async def aupdate_messages(messages: Message | list[Message]) -> list[Message]:
|
||||
if not isinstance(messages, list):
|
||||
messages = [messages]
|
||||
|
|
@ -190,26 +148,6 @@ async def aupdate_messages(messages: Message | list[Message]) -> list[Message]:
|
|||
return [MessageRead.model_validate(message, from_attributes=True) for message in updated_messages]
|
||||
|
||||
|
||||
def add_messagetables(messages: list[MessageTable], session: Session):
|
||||
for message in messages:
|
||||
try:
|
||||
session.add(message)
|
||||
session.commit()
|
||||
session.refresh(message)
|
||||
except Exception as e:
|
||||
logger.exception(e)
|
||||
raise
|
||||
|
||||
new_messages = []
|
||||
for msg in messages:
|
||||
msg.properties = json.loads(msg.properties) if isinstance(msg.properties, str) else msg.properties # type: ignore[arg-type]
|
||||
msg.content_blocks = [json.loads(j) if isinstance(j, str) else j for j in msg.content_blocks] # type: ignore[arg-type]
|
||||
msg.category = msg.category or ""
|
||||
new_messages.append(msg)
|
||||
|
||||
return [MessageRead.model_validate(message, from_attributes=True) for message in new_messages]
|
||||
|
||||
|
||||
async def aadd_messagetables(messages: list[MessageTable], session: AsyncSession):
|
||||
try:
|
||||
for message in messages:
|
||||
|
|
@ -232,17 +170,14 @@ async def aadd_messagetables(messages: list[MessageTable], session: AsyncSession
|
|||
|
||||
|
||||
def delete_messages(session_id: str) -> None:
|
||||
"""Delete messages from the monitor service based on the provided session ID.
|
||||
"""DEPRECATED - Delete messages from the monitor service based on the provided session ID.
|
||||
|
||||
DEPRECATED: Use `adelete_messages` instead.
|
||||
|
||||
Args:
|
||||
session_id (str): The session ID associated with the messages to delete.
|
||||
"""
|
||||
with session_scope() as session:
|
||||
session.exec(
|
||||
delete(MessageTable)
|
||||
.where(col(MessageTable.session_id) == session_id)
|
||||
.execution_options(synchronize_session="fetch")
|
||||
)
|
||||
return run_until_complete(adelete_messages(session_id))
|
||||
|
||||
|
||||
async def adelete_messages(session_id: str) -> None:
|
||||
|
|
@ -277,7 +212,9 @@ def store_message(
|
|||
message: Message,
|
||||
flow_id: str | UUID | None = None,
|
||||
) -> list[Message]:
|
||||
"""Stores a message in the memory.
|
||||
"""DEPRECATED: Stores a message in the memory.
|
||||
|
||||
DEPRECATED: Use `astore_message` instead.
|
||||
|
||||
Args:
|
||||
message (Message): The message to store.
|
||||
|
|
@ -290,31 +227,7 @@ def store_message(
|
|||
Raises:
|
||||
ValueError: If any of the required parameters (session_id, sender, sender_name) is not provided.
|
||||
"""
|
||||
if not message:
|
||||
logger.warning("No message provided.")
|
||||
return []
|
||||
|
||||
# Convert flow_id to UUID if it's a string
|
||||
if isinstance(flow_id, str):
|
||||
flow_id = UUID(flow_id)
|
||||
|
||||
required_fields = ["session_id", "sender", "sender_name"]
|
||||
missing_fields = [field for field in required_fields if not getattr(message, field)]
|
||||
if missing_fields:
|
||||
missing_descriptions = {
|
||||
"session_id": "session_id (unique conversation identifier)",
|
||||
"sender": f"sender (e.g., '{MESSAGE_SENDER_USER}' or '{MESSAGE_SENDER_AI}')",
|
||||
"sender_name": "sender_name (display name, e.g., 'User' or 'Assistant')",
|
||||
}
|
||||
missing = ", ".join(missing_descriptions[field] for field in missing_fields)
|
||||
msg = (
|
||||
f"It looks like we're missing some important information: {missing}. "
|
||||
"Please ensure that your message includes all the required fields."
|
||||
)
|
||||
raise ValueError(msg)
|
||||
if hasattr(message, "id") and message.id:
|
||||
return update_messages([message])
|
||||
return add_messages([message], flow_id=flow_id)
|
||||
return run_until_complete(astore_message(message, flow_id=flow_id))
|
||||
|
||||
|
||||
async def astore_message(
|
||||
|
|
@ -349,6 +262,8 @@ async def astore_message(
|
|||
|
||||
|
||||
class LCBuiltinChatMemory(BaseChatMessageHistory):
|
||||
"""DEPRECATED: Kept for backward compatibility."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
flow_id: str,
|
||||
|
|
|
|||
|
|
@ -6,7 +6,6 @@ from langflow.memory import (
|
|||
aadd_messages,
|
||||
aadd_messagetables,
|
||||
add_messages,
|
||||
add_messagetables,
|
||||
adelete_messages,
|
||||
aget_messages,
|
||||
astore_message,
|
||||
|
|
@ -14,7 +13,6 @@ from langflow.memory import (
|
|||
delete_messages,
|
||||
get_messages,
|
||||
store_message,
|
||||
update_messages,
|
||||
)
|
||||
from langflow.schema.content_block import ContentBlock
|
||||
from langflow.schema.content_types import TextContent, ToolContent
|
||||
|
|
@ -94,14 +92,6 @@ async def test_aadd_messages():
|
|||
assert messages[0].text == "New Test message"
|
||||
|
||||
|
||||
@pytest.mark.usefixtures("client")
|
||||
def test_add_messagetables(session):
|
||||
messages = [MessageTable(text="New Test message", sender="User", sender_name="User", session_id="new_session_id")]
|
||||
added_messages = add_messagetables(messages, session)
|
||||
assert len(added_messages) == 1
|
||||
assert added_messages[0].text == "New Test message"
|
||||
|
||||
|
||||
@pytest.mark.usefixtures("client")
|
||||
async def test_aadd_messagetables(async_session):
|
||||
messages = [MessageTable(text="New Test message", sender="User", sender_name="User", session_id="new_session_id")]
|
||||
|
|
@ -179,191 +169,6 @@ def test_convert_to_langchain(method_name):
|
|||
assert len(list(iterator)) == 2
|
||||
|
||||
|
||||
@pytest.mark.usefixtures("client")
|
||||
def test_update_single_message(created_message):
|
||||
# Modify the message
|
||||
created_message.text = "Updated message"
|
||||
updated = update_messages(created_message)
|
||||
|
||||
assert len(updated) == 1
|
||||
assert updated[0].text == "Updated message"
|
||||
assert updated[0].id == created_message.id
|
||||
|
||||
|
||||
@pytest.mark.usefixtures("client")
|
||||
def test_update_multiple_messages(created_messages):
|
||||
# Modify the messages
|
||||
for i, message in enumerate(created_messages):
|
||||
message.text = f"Updated message {i}"
|
||||
|
||||
updated = update_messages(created_messages)
|
||||
|
||||
assert len(updated) == len(created_messages)
|
||||
for i, message in enumerate(updated):
|
||||
assert message.text == f"Updated message {i}"
|
||||
assert message.id == created_messages[i].id
|
||||
|
||||
|
||||
@pytest.mark.usefixtures("client")
|
||||
def test_update_nonexistent_message():
|
||||
# Create a message with a non-existent UUID
|
||||
message = MessageRead(
|
||||
id=uuid4(), # Generate a random UUID that won't exist in the database
|
||||
text="Test message",
|
||||
sender="User",
|
||||
sender_name="User",
|
||||
session_id="session_id",
|
||||
flow_id=uuid4(),
|
||||
)
|
||||
|
||||
updated = update_messages(message)
|
||||
assert len(updated) == 0
|
||||
|
||||
|
||||
@pytest.mark.usefixtures("client")
|
||||
def test_update_mixed_messages(created_messages):
|
||||
# Create a mix of existing and non-existing messages
|
||||
nonexistent_message = MessageRead(
|
||||
id=uuid4(), # Generate a random UUID that won't exist in the database
|
||||
text="Test message",
|
||||
sender="User",
|
||||
sender_name="User",
|
||||
session_id="session_id",
|
||||
flow_id=uuid4(),
|
||||
)
|
||||
|
||||
messages_to_update = created_messages[:1] + [nonexistent_message]
|
||||
created_messages[0].text = "Updated existing message"
|
||||
|
||||
updated = update_messages(messages_to_update)
|
||||
|
||||
assert len(updated) == 1
|
||||
assert updated[0].text == "Updated existing message"
|
||||
assert updated[0].id == created_messages[0].id
|
||||
assert isinstance(updated[0].id, UUID) # Verify ID is UUID type
|
||||
|
||||
|
||||
@pytest.mark.usefixtures("client")
|
||||
def test_update_message_with_timestamp(created_message):
|
||||
# Set a specific timestamp
|
||||
new_timestamp = datetime(2024, 1, 1, 12, 0, 0, tzinfo=timezone.utc)
|
||||
created_message.timestamp = new_timestamp
|
||||
created_message.text = "Updated message with timestamp"
|
||||
|
||||
updated = update_messages(created_message)
|
||||
|
||||
assert len(updated) == 1
|
||||
assert updated[0].text == "Updated message with timestamp"
|
||||
|
||||
# Compare timestamps without timezone info since DB doesn't preserve it
|
||||
assert updated[0].timestamp.replace(tzinfo=None) == new_timestamp.replace(tzinfo=None)
|
||||
assert updated[0].id == created_message.id
|
||||
|
||||
|
||||
@pytest.mark.usefixtures("client")
|
||||
def test_update_multiple_messages_with_timestamps(created_messages):
|
||||
# Modify messages with different timestamps
|
||||
for i, message in enumerate(created_messages):
|
||||
message.text = f"Updated message {i}"
|
||||
message.timestamp = datetime(2024, 1, 1, i, 0, 0, tzinfo=timezone.utc)
|
||||
|
||||
updated = update_messages(created_messages)
|
||||
|
||||
assert len(updated) == len(created_messages)
|
||||
for i, message in enumerate(updated):
|
||||
assert message.text == f"Updated message {i}"
|
||||
# Compare timestamps without timezone info
|
||||
expected_timestamp = datetime(2024, 1, 1, i, 0, 0, tzinfo=timezone.utc)
|
||||
assert message.timestamp.replace(tzinfo=None) == expected_timestamp.replace(tzinfo=None)
|
||||
assert message.id == created_messages[i].id
|
||||
|
||||
|
||||
@pytest.mark.usefixtures("client")
|
||||
def test_update_message_with_content_blocks(created_message):
|
||||
# Create a content block using proper models
|
||||
text_content = TextContent(
|
||||
type="text", text="Test content", duration=5, header={"title": "Test Header", "icon": "TestIcon"}
|
||||
)
|
||||
|
||||
tool_content = ToolContent(type="tool_use", name="test_tool", tool_input={"param": "value"}, duration=10)
|
||||
|
||||
content_block = ContentBlock(title="Test Block", contents=[text_content, tool_content], allow_markdown=True)
|
||||
|
||||
created_message.content_blocks = [content_block]
|
||||
created_message.text = "Message with content blocks"
|
||||
|
||||
updated = update_messages(created_message)
|
||||
|
||||
assert len(updated) == 1
|
||||
assert updated[0].text == "Message with content blocks"
|
||||
assert len(updated[0].content_blocks) == 1
|
||||
|
||||
# Verify the content block structure
|
||||
updated_block = updated[0].content_blocks[0]
|
||||
assert updated_block.title == "Test Block"
|
||||
assert len(updated_block.contents) == 2
|
||||
|
||||
# Verify text content
|
||||
text_content = updated_block.contents[0]
|
||||
assert text_content.type == "text"
|
||||
assert text_content.text == "Test content"
|
||||
assert text_content.duration == 5
|
||||
assert text_content.header["title"] == "Test Header"
|
||||
|
||||
# Verify tool content
|
||||
tool_content = updated_block.contents[1]
|
||||
assert tool_content.type == "tool_use"
|
||||
assert tool_content.name == "test_tool"
|
||||
assert tool_content.tool_input == {"param": "value"}
|
||||
assert tool_content.duration == 10
|
||||
|
||||
|
||||
@pytest.mark.usefixtures("client")
|
||||
def test_update_message_with_nested_properties(created_message):
|
||||
# Create a text content with nested properties
|
||||
text_content = TextContent(
|
||||
type="text", text="Test content", header={"title": "Test Header", "icon": "TestIcon"}, duration=15
|
||||
)
|
||||
|
||||
content_block = ContentBlock(
|
||||
title="Test Properties",
|
||||
contents=[text_content],
|
||||
allow_markdown=True,
|
||||
media_url=["http://example.com/image.jpg"],
|
||||
)
|
||||
|
||||
# Set properties according to the Properties model structure
|
||||
created_message.properties = Properties(
|
||||
text_color="blue",
|
||||
background_color="white",
|
||||
edited=False,
|
||||
source=Source(id="test_id", display_name="Test Source", source="test"),
|
||||
icon="TestIcon",
|
||||
allow_markdown=True,
|
||||
state="complete",
|
||||
targets=[],
|
||||
)
|
||||
created_message.text = "Message with nested properties"
|
||||
created_message.content_blocks = [content_block]
|
||||
|
||||
updated = update_messages(created_message)
|
||||
|
||||
assert len(updated) == 1
|
||||
assert updated[0].text == "Message with nested properties"
|
||||
|
||||
# Verify the properties were properly serialized and stored
|
||||
assert updated[0].properties.text_color == "blue"
|
||||
assert updated[0].properties.background_color == "white"
|
||||
assert updated[0].properties.edited is False
|
||||
assert updated[0].properties.source.id == "test_id"
|
||||
assert updated[0].properties.source.display_name == "Test Source"
|
||||
assert updated[0].properties.source.source == "test"
|
||||
assert updated[0].properties.icon == "TestIcon"
|
||||
assert updated[0].properties.allow_markdown is True
|
||||
assert updated[0].properties.state == "complete"
|
||||
assert updated[0].properties.targets == []
|
||||
|
||||
|
||||
@pytest.mark.usefixtures("client")
|
||||
async def test_aupdate_single_message(created_message):
|
||||
# Modify the message
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue