Migrate messages from monitor service to database

This commit is contained in:
Gabriel Luiz Freitas Almeida 2024-06-26 16:11:54 -03:00
commit 46dcc2ef56

View file

@ -4,28 +4,76 @@ from typing import TYPE_CHECKING
from alembic.util.exc import CommandError
from loguru import logger
from sqlmodel import Session, text
from sqlmodel import Session, select, text
from langflow.services.deps import get_monitor_service
if TYPE_CHECKING:
from langflow.services.database.service import DatabaseService
from typing import Dict, List
def migrate_messages_from_monitor_service_to_database(session):
def migrate_messages_from_monitor_service_to_database(session: Session) -> bool:
from langflow.schema.message import Message
from langflow.services.database.models.message import MessageTable
monitor_service = get_monitor_service()
messages_df = monitor_service.get_messages()
if not messages_df.empty:
messages_ids = []
for message in messages_df.to_dict(orient="records"):
messages_ids.append(message["index"])
message = Message(**message)
session.add(MessageTable.from_message(message))
if messages_df.empty:
logger.info("No messages to migrate.")
return True
original_messages: List[Dict] = messages_df.to_dict(orient="records")
db_messages = session.exec(select(MessageTable)).all()
db_messages = [msg[0] for msg in db_messages]
db_msg_dict = {(msg.text, msg.timestamp.isoformat(), str(msg.flow_id, msg.session_id)): msg for msg in db_messages}
# Filter out messages that already exist in the database
original_messages_filtered = []
for message in original_messages:
key = (message["text"], message["timestamp"].isoformat(), str(message["flow_id"]))
if key not in db_msg_dict:
original_messages_filtered.append(message)
if not original_messages_filtered:
logger.info("No messages to migrate.")
return True
try:
# Bulk insert messages
session.bulk_insert_mappings(
MessageTable, [MessageTable.from_message(Message(**msg)).model_dump() for msg in original_messages_filtered]
)
session.commit()
except Exception as e:
logger.error(f"Error during message insertion: {str(e)}")
session.rollback()
return False
# Create a dictionary for faster lookup
all_ok = True
for orig_msg in original_messages_filtered:
key = (orig_msg["text"], orig_msg["timestamp"].isoformat(), str(orig_msg["flow_id"]))
matching_db_msg = db_msg_dict.get(key)
if matching_db_msg is None:
logger.warning(f"Message not found in database: {orig_msg}")
all_ok = False
else:
# Validate other fields
if any(getattr(matching_db_msg, k) != v for k, v in orig_msg.items() if k != "index"):
logger.warning(f"Message mismatch in database: {orig_msg}")
all_ok = False
if all_ok:
messages_ids = [message["index"] for message in original_messages]
monitor_service.delete_messages(messages_ids)
logger.info("Migration completed successfully. Original messages deleted.")
else:
logger.warning("Migration completed with errors. Original messages not deleted.")
return all_ok
def initialize_database(fix_migration: bool = False):