📝 (memory.py): Refactor get_messages function to use SQLAlchemy select statement for better performance and readability

📝 (memory.py): Refactor delete_messages function to use SQLAlchemy delete statement for better performance and readability
📝 (monitor/schema.py): Update MessageModel to use UUID type for id and flow_id for consistency and better data handling
This commit is contained in:
Gabriel Luiz Freitas Almeida 2024-06-25 19:37:14 -03:00
commit b690834f6b
2 changed files with 44 additions and 37 deletions

View file

@ -1,12 +1,14 @@
import warnings
from typing import List, Optional
from uuid import UUID
from loguru import logger
from sqlmodel import Session
from sqlalchemy import delete
from sqlmodel import Session, col, select
from langflow.schema.message import Message
from langflow.services.database.models.message.model import MessageTable
from langflow.services.deps import get_monitor_service, session_scope
from langflow.services.database.models.message.model import MessageRead, MessageTable
from langflow.services.deps import session_scope
def get_messages(
@ -15,6 +17,7 @@ def get_messages(
session_id: Optional[str] = None,
order_by: Optional[str] = "timestamp",
order: Optional[str] = "DESC",
flow_id: Optional[UUID] = None,
limit: Optional[int] = None,
):
"""
@ -30,29 +33,36 @@ def get_messages(
Returns:
List[Data]: A list of Data objects representing the retrieved messages.
"""
monitor_service = get_monitor_service()
messages_df = monitor_service.get_messages(
sender=sender,
sender_name=sender_name,
session_id=session_id,
order_by=order_by,
limit=limit,
order=order,
)
with session_scope() as session:
stmt = select(MessageTable)
if sender:
stmt = stmt.where(MessageTable.sender == sender)
if sender_name:
stmt = stmt.where(MessageTable.sender_name == sender_name)
if session_id:
stmt = stmt.where(MessageTable.session_id == session_id)
if flow_id:
stmt = stmt.where(MessageTable.flow_id == flow_id)
if order_by:
if order == "DESC":
col = getattr(MessageTable, order_by).desc()
else:
col = getattr(MessageTable, order_by).asc()
stmt = stmt.order_by(col)
if limit:
stmt = stmt.limit(limit)
messages = session.exec(stmt)
messages_read = [MessageRead.model_validate(d, from_attributes=True) for d in messages]
messages: list[Message] = []
# messages_df has a timestamp
# it gets the last 5 messages, for example
# but now they are ordered from most recent to least recent
# so we need to reverse the order
messages_df = messages_df[::-1] if order == "DESC" else messages_df
for row in messages_df.itertuples():
for msg_read in messages_read:
msg = Message(
text=row.text,
sender=row.sender,
session_id=row.session_id,
sender_name=row.sender_name,
timestamp=row.timestamp,
text=msg_read.text,
sender=msg_read.sender,
session_id=msg_read.session_id,
sender_name=msg_read.sender_name,
timestamp=msg_read.timestamp,
)
messages.append(msg)
@ -102,8 +112,13 @@ def delete_messages(session_id: str):
Args:
session_id (str): The session ID associated with the messages to delete.
"""
monitor_service = get_monitor_service()
monitor_service.delete_messages_session(session_id)
with session_scope() as session:
session.exec(
delete(MessageTable)
.where(col(MessageTable.session_id) == session_id)
.execution_options(synchronize_session="fetch")
)
session.commit()
def store_message(

View file

@ -1,6 +1,7 @@
import json
from datetime import datetime, timezone
from typing import Any, Optional
from uuid import UUID
from pydantic import BaseModel, Field, field_serializer, field_validator
@ -81,8 +82,8 @@ class TransactionModelResponse(DefaultModel):
class MessageModel(DefaultModel):
index: Optional[int] = Field(default=None)
flow_id: Optional[str] = Field(default=None, alias="flow_id")
id: Optional[str | UUID] = Field(default=None)
flow_id: Optional[UUID] = Field(default=None)
timestamp: datetime = Field(default_factory=lambda: datetime.now(timezone.utc))
sender: str
sender_name: str
@ -127,16 +128,7 @@ class MessageModel(DefaultModel):
class MessageModelResponse(MessageModel):
index: Optional[int] = Field(default=None)
@field_validator("index", mode="before")
def validate_id(cls, v):
if isinstance(v, float):
try:
return int(v)
except ValueError:
return None
return v
pass
class MessageModelRequest(MessageModel):