From b690834f6b95de67bd4f5381972e30e04acb5bf5 Mon Sep 17 00:00:00 2001 From: Gabriel Luiz Freitas Almeida Date: Tue, 25 Jun 2024 19:37:14 -0300 Subject: [PATCH] =?UTF-8?q?=F0=9F=93=9D=20(memory.py):=20Refactor=20get=5F?= =?UTF-8?q?messages=20function=20to=20use=20SQLAlchemy=20select=20statemen?= =?UTF-8?q?t=20for=20better=20performance=20and=20readability?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit 📝 (memory.py): Refactor delete_messages function to use SQLAlchemy delete statement for better performance and readability 📝 (monitor/schema.py): Update MessageModel to use UUID type for id and flow_id for consistency and better data handling --- src/backend/base/langflow/memory.py | 65 ++++++++++++------- .../base/langflow/services/monitor/schema.py | 16 ++--- 2 files changed, 44 insertions(+), 37 deletions(-) diff --git a/src/backend/base/langflow/memory.py b/src/backend/base/langflow/memory.py index 2835645ba..7b314fb2c 100644 --- a/src/backend/base/langflow/memory.py +++ b/src/backend/base/langflow/memory.py @@ -1,12 +1,14 @@ import warnings from typing import List, Optional +from uuid import UUID from loguru import logger -from sqlmodel import Session +from sqlalchemy import delete +from sqlmodel import Session, col, select from langflow.schema.message import Message -from langflow.services.database.models.message.model import MessageTable -from langflow.services.deps import get_monitor_service, session_scope +from langflow.services.database.models.message.model import MessageRead, MessageTable +from langflow.services.deps import session_scope def get_messages( @@ -15,6 +17,7 @@ def get_messages( session_id: Optional[str] = None, order_by: Optional[str] = "timestamp", order: Optional[str] = "DESC", + flow_id: Optional[UUID] = None, limit: Optional[int] = None, ): """ @@ -30,29 +33,36 @@ def get_messages( Returns: List[Data]: A list of Data objects representing the retrieved messages. """ - monitor_service = get_monitor_service() - messages_df = monitor_service.get_messages( - sender=sender, - sender_name=sender_name, - session_id=session_id, - order_by=order_by, - limit=limit, - order=order, - ) + with session_scope() as session: + stmt = select(MessageTable) + if sender: + stmt = stmt.where(MessageTable.sender == sender) + if sender_name: + stmt = stmt.where(MessageTable.sender_name == sender_name) + if session_id: + stmt = stmt.where(MessageTable.session_id == session_id) + if flow_id: + stmt = stmt.where(MessageTable.flow_id == flow_id) + if order_by: + if order == "DESC": + col = getattr(MessageTable, order_by).desc() + else: + col = getattr(MessageTable, order_by).asc() + stmt = stmt.order_by(col) + if limit: + stmt = stmt.limit(limit) + messages = session.exec(stmt) + messages_read = [MessageRead.model_validate(d, from_attributes=True) for d in messages] messages: list[Message] = [] - # messages_df has a timestamp - # it gets the last 5 messages, for example - # but now they are ordered from most recent to least recent - # so we need to reverse the order - messages_df = messages_df[::-1] if order == "DESC" else messages_df - for row in messages_df.itertuples(): + + for msg_read in messages_read: msg = Message( - text=row.text, - sender=row.sender, - session_id=row.session_id, - sender_name=row.sender_name, - timestamp=row.timestamp, + text=msg_read.text, + sender=msg_read.sender, + session_id=msg_read.session_id, + sender_name=msg_read.sender_name, + timestamp=msg_read.timestamp, ) messages.append(msg) @@ -102,8 +112,13 @@ def delete_messages(session_id: str): Args: session_id (str): The session ID associated with the messages to delete. """ - monitor_service = get_monitor_service() - monitor_service.delete_messages_session(session_id) + with session_scope() as session: + session.exec( + delete(MessageTable) + .where(col(MessageTable.session_id) == session_id) + .execution_options(synchronize_session="fetch") + ) + session.commit() def store_message( diff --git a/src/backend/base/langflow/services/monitor/schema.py b/src/backend/base/langflow/services/monitor/schema.py index 2294678fe..eeea846a1 100644 --- a/src/backend/base/langflow/services/monitor/schema.py +++ b/src/backend/base/langflow/services/monitor/schema.py @@ -1,6 +1,7 @@ import json from datetime import datetime, timezone from typing import Any, Optional +from uuid import UUID from pydantic import BaseModel, Field, field_serializer, field_validator @@ -81,8 +82,8 @@ class TransactionModelResponse(DefaultModel): class MessageModel(DefaultModel): - index: Optional[int] = Field(default=None) - flow_id: Optional[str] = Field(default=None, alias="flow_id") + id: Optional[str | UUID] = Field(default=None) + flow_id: Optional[UUID] = Field(default=None) timestamp: datetime = Field(default_factory=lambda: datetime.now(timezone.utc)) sender: str sender_name: str @@ -127,16 +128,7 @@ class MessageModel(DefaultModel): class MessageModelResponse(MessageModel): - index: Optional[int] = Field(default=None) - - @field_validator("index", mode="before") - def validate_id(cls, v): - if isinstance(v, float): - try: - return int(v) - except ValueError: - return None - return v + pass class MessageModelRequest(MessageModel):