diff --git a/src/backend/base/langflow/memory.py b/src/backend/base/langflow/memory.py index 2835645ba..7b314fb2c 100644 --- a/src/backend/base/langflow/memory.py +++ b/src/backend/base/langflow/memory.py @@ -1,12 +1,14 @@ import warnings from typing import List, Optional +from uuid import UUID from loguru import logger -from sqlmodel import Session +from sqlalchemy import delete +from sqlmodel import Session, col, select from langflow.schema.message import Message -from langflow.services.database.models.message.model import MessageTable -from langflow.services.deps import get_monitor_service, session_scope +from langflow.services.database.models.message.model import MessageRead, MessageTable +from langflow.services.deps import session_scope def get_messages( @@ -15,6 +17,7 @@ def get_messages( session_id: Optional[str] = None, order_by: Optional[str] = "timestamp", order: Optional[str] = "DESC", + flow_id: Optional[UUID] = None, limit: Optional[int] = None, ): """ @@ -30,29 +33,36 @@ def get_messages( Returns: List[Data]: A list of Data objects representing the retrieved messages. """ - monitor_service = get_monitor_service() - messages_df = monitor_service.get_messages( - sender=sender, - sender_name=sender_name, - session_id=session_id, - order_by=order_by, - limit=limit, - order=order, - ) + with session_scope() as session: + stmt = select(MessageTable) + if sender: + stmt = stmt.where(MessageTable.sender == sender) + if sender_name: + stmt = stmt.where(MessageTable.sender_name == sender_name) + if session_id: + stmt = stmt.where(MessageTable.session_id == session_id) + if flow_id: + stmt = stmt.where(MessageTable.flow_id == flow_id) + if order_by: + if order == "DESC": + col = getattr(MessageTable, order_by).desc() + else: + col = getattr(MessageTable, order_by).asc() + stmt = stmt.order_by(col) + if limit: + stmt = stmt.limit(limit) + messages = session.exec(stmt) + messages_read = [MessageRead.model_validate(d, from_attributes=True) for d in messages] messages: list[Message] = [] - # messages_df has a timestamp - # it gets the last 5 messages, for example - # but now they are ordered from most recent to least recent - # so we need to reverse the order - messages_df = messages_df[::-1] if order == "DESC" else messages_df - for row in messages_df.itertuples(): + + for msg_read in messages_read: msg = Message( - text=row.text, - sender=row.sender, - session_id=row.session_id, - sender_name=row.sender_name, - timestamp=row.timestamp, + text=msg_read.text, + sender=msg_read.sender, + session_id=msg_read.session_id, + sender_name=msg_read.sender_name, + timestamp=msg_read.timestamp, ) messages.append(msg) @@ -102,8 +112,13 @@ def delete_messages(session_id: str): Args: session_id (str): The session ID associated with the messages to delete. """ - monitor_service = get_monitor_service() - monitor_service.delete_messages_session(session_id) + with session_scope() as session: + session.exec( + delete(MessageTable) + .where(col(MessageTable.session_id) == session_id) + .execution_options(synchronize_session="fetch") + ) + session.commit() def store_message( diff --git a/src/backend/base/langflow/services/monitor/schema.py b/src/backend/base/langflow/services/monitor/schema.py index 2294678fe..eeea846a1 100644 --- a/src/backend/base/langflow/services/monitor/schema.py +++ b/src/backend/base/langflow/services/monitor/schema.py @@ -1,6 +1,7 @@ import json from datetime import datetime, timezone from typing import Any, Optional +from uuid import UUID from pydantic import BaseModel, Field, field_serializer, field_validator @@ -81,8 +82,8 @@ class TransactionModelResponse(DefaultModel): class MessageModel(DefaultModel): - index: Optional[int] = Field(default=None) - flow_id: Optional[str] = Field(default=None, alias="flow_id") + id: Optional[str | UUID] = Field(default=None) + flow_id: Optional[UUID] = Field(default=None) timestamp: datetime = Field(default_factory=lambda: datetime.now(timezone.utc)) sender: str sender_name: str @@ -127,16 +128,7 @@ class MessageModel(DefaultModel): class MessageModelResponse(MessageModel): - index: Optional[int] = Field(default=None) - - @field_validator("index", mode="before") - def validate_id(cls, v): - if isinstance(v, float): - try: - return int(v) - except ValueError: - return None - return v + pass class MessageModelRequest(MessageModel):