📝 (memory.py): Refactor get_messages function to use SQLAlchemy select statement for better performance and readability
📝 (memory.py): Refactor delete_messages function to use SQLAlchemy delete statement for better performance and readability 📝 (monitor/schema.py): Update MessageModel to use UUID type for id and flow_id for consistency and better data handling
This commit is contained in:
parent
d44cf6fc4a
commit
b690834f6b
2 changed files with 44 additions and 37 deletions
|
|
@ -1,12 +1,14 @@
|
|||
import warnings
|
||||
from typing import List, Optional
|
||||
from uuid import UUID
|
||||
|
||||
from loguru import logger
|
||||
from sqlmodel import Session
|
||||
from sqlalchemy import delete
|
||||
from sqlmodel import Session, col, select
|
||||
|
||||
from langflow.schema.message import Message
|
||||
from langflow.services.database.models.message.model import MessageTable
|
||||
from langflow.services.deps import get_monitor_service, session_scope
|
||||
from langflow.services.database.models.message.model import MessageRead, MessageTable
|
||||
from langflow.services.deps import session_scope
|
||||
|
||||
|
||||
def get_messages(
|
||||
|
|
@ -15,6 +17,7 @@ def get_messages(
|
|||
session_id: Optional[str] = None,
|
||||
order_by: Optional[str] = "timestamp",
|
||||
order: Optional[str] = "DESC",
|
||||
flow_id: Optional[UUID] = None,
|
||||
limit: Optional[int] = None,
|
||||
):
|
||||
"""
|
||||
|
|
@ -30,29 +33,36 @@ def get_messages(
|
|||
Returns:
|
||||
List[Data]: A list of Data objects representing the retrieved messages.
|
||||
"""
|
||||
monitor_service = get_monitor_service()
|
||||
messages_df = monitor_service.get_messages(
|
||||
sender=sender,
|
||||
sender_name=sender_name,
|
||||
session_id=session_id,
|
||||
order_by=order_by,
|
||||
limit=limit,
|
||||
order=order,
|
||||
)
|
||||
with session_scope() as session:
|
||||
stmt = select(MessageTable)
|
||||
if sender:
|
||||
stmt = stmt.where(MessageTable.sender == sender)
|
||||
if sender_name:
|
||||
stmt = stmt.where(MessageTable.sender_name == sender_name)
|
||||
if session_id:
|
||||
stmt = stmt.where(MessageTable.session_id == session_id)
|
||||
if flow_id:
|
||||
stmt = stmt.where(MessageTable.flow_id == flow_id)
|
||||
if order_by:
|
||||
if order == "DESC":
|
||||
col = getattr(MessageTable, order_by).desc()
|
||||
else:
|
||||
col = getattr(MessageTable, order_by).asc()
|
||||
stmt = stmt.order_by(col)
|
||||
if limit:
|
||||
stmt = stmt.limit(limit)
|
||||
messages = session.exec(stmt)
|
||||
messages_read = [MessageRead.model_validate(d, from_attributes=True) for d in messages]
|
||||
|
||||
messages: list[Message] = []
|
||||
# messages_df has a timestamp
|
||||
# it gets the last 5 messages, for example
|
||||
# but now they are ordered from most recent to least recent
|
||||
# so we need to reverse the order
|
||||
messages_df = messages_df[::-1] if order == "DESC" else messages_df
|
||||
for row in messages_df.itertuples():
|
||||
|
||||
for msg_read in messages_read:
|
||||
msg = Message(
|
||||
text=row.text,
|
||||
sender=row.sender,
|
||||
session_id=row.session_id,
|
||||
sender_name=row.sender_name,
|
||||
timestamp=row.timestamp,
|
||||
text=msg_read.text,
|
||||
sender=msg_read.sender,
|
||||
session_id=msg_read.session_id,
|
||||
sender_name=msg_read.sender_name,
|
||||
timestamp=msg_read.timestamp,
|
||||
)
|
||||
|
||||
messages.append(msg)
|
||||
|
|
@ -102,8 +112,13 @@ def delete_messages(session_id: str):
|
|||
Args:
|
||||
session_id (str): The session ID associated with the messages to delete.
|
||||
"""
|
||||
monitor_service = get_monitor_service()
|
||||
monitor_service.delete_messages_session(session_id)
|
||||
with session_scope() as session:
|
||||
session.exec(
|
||||
delete(MessageTable)
|
||||
.where(col(MessageTable.session_id) == session_id)
|
||||
.execution_options(synchronize_session="fetch")
|
||||
)
|
||||
session.commit()
|
||||
|
||||
|
||||
def store_message(
|
||||
|
|
|
|||
|
|
@ -1,6 +1,7 @@
|
|||
import json
|
||||
from datetime import datetime, timezone
|
||||
from typing import Any, Optional
|
||||
from uuid import UUID
|
||||
|
||||
from pydantic import BaseModel, Field, field_serializer, field_validator
|
||||
|
||||
|
|
@ -81,8 +82,8 @@ class TransactionModelResponse(DefaultModel):
|
|||
|
||||
|
||||
class MessageModel(DefaultModel):
|
||||
index: Optional[int] = Field(default=None)
|
||||
flow_id: Optional[str] = Field(default=None, alias="flow_id")
|
||||
id: Optional[str | UUID] = Field(default=None)
|
||||
flow_id: Optional[UUID] = Field(default=None)
|
||||
timestamp: datetime = Field(default_factory=lambda: datetime.now(timezone.utc))
|
||||
sender: str
|
||||
sender_name: str
|
||||
|
|
@ -127,16 +128,7 @@ class MessageModel(DefaultModel):
|
|||
|
||||
|
||||
class MessageModelResponse(MessageModel):
|
||||
index: Optional[int] = Field(default=None)
|
||||
|
||||
@field_validator("index", mode="before")
|
||||
def validate_id(cls, v):
|
||||
if isinstance(v, float):
|
||||
try:
|
||||
return int(v)
|
||||
except ValueError:
|
||||
return None
|
||||
return v
|
||||
pass
|
||||
|
||||
|
||||
class MessageModelRequest(MessageModel):
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue