feat: add logic to migrate from duckdb table to database (#2385)

This PR adds a function that migrates any data in the duckdb messages
table to the message table in the database.
This commit is contained in:
Gabriel Luiz Freitas Almeida 2024-06-26 23:03:17 +00:00 committed by GitHub
commit 96665b2bfe
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
9 changed files with 240 additions and 62 deletions

View file

@ -147,7 +147,7 @@ ignore-regex = '.*(Stati Uniti|Tense=Pres).*'
minversion = "6.0"
testpaths = ["tests", "integration"]
console_output_style = "progress"
filterwarnings = ["ignore::DeprecationWarning"]
filterwarnings = ["ignore::DeprecationWarning", "ignore::ResourceWarning"]
log_cli = true
markers = ["async_test", "api_key_required"]

View file

@ -1,5 +1,4 @@
import warnings
from typing import List, Optional
from uuid import UUID
from loguru import logger
@ -8,17 +7,18 @@ from sqlmodel import Session, col, select
from langflow.schema.message import Message
from langflow.services.database.models.message.model import MessageRead, MessageTable
from langflow.services.database.utils import migrate_messages_from_monitor_service_to_database
from langflow.services.deps import session_scope
def get_messages(
sender: Optional[str] = None,
sender_name: Optional[str] = None,
session_id: Optional[str] = None,
order_by: Optional[str] = "timestamp",
order: Optional[str] = "DESC",
flow_id: Optional[UUID] = None,
limit: Optional[int] = None,
sender: str | None = None,
sender_name: str | None = None,
session_id: str | None = None,
order_by: str | None = "timestamp",
order: str | None = "DESC",
flow_id: UUID | None = None,
limit: int | None = None,
):
"""
Retrieves messages from the monitor service based on the provided filters.
@ -33,6 +33,8 @@ def get_messages(
Returns:
List[Data]: A list of Data objects representing the retrieved messages.
"""
with session_scope() as session:
migrate_messages_from_monitor_service_to_database(session)
messages_read: list[Message] = []
with session_scope() as session:
stmt = select(MessageTable)
@ -58,7 +60,7 @@ def get_messages(
return messages_read
def add_messages(messages: Message | list[Message], flow_id: Optional[str] = None):
def add_messages(messages: Message | list[Message], flow_id: str | None = None):
"""
Add a message to the monitor service.
"""
@ -111,8 +113,8 @@ def delete_messages(session_id: str):
def store_message(
message: Message,
flow_id: Optional[str] = None,
) -> List[Message]:
flow_id: str | None = None,
) -> list[Message]:
"""
Stores a message in the memory.

View file

@ -41,6 +41,12 @@ class Message(Data):
value = str(value)
return value
@field_serializer("flow_id")
def serialize_flow_id(value):
    """Serialize ``flow_id``: a string value is converted to ``UUID``.

    Any non-string value is returned unchanged.
    """
    return UUID(value) if isinstance(value, str) else value
@field_validator("files", mode="before")
@classmethod
def validate_files(cls, value):

View file

@ -26,7 +26,7 @@ class MessageBase(SQLModel):
return value
@classmethod
def from_message(cls, message: "Message", flow_id: str | None = None):
def from_message(cls, message: "Message", flow_id: str | UUID | None = None):
# first check if the record has all the required fields
if message.text is None or not message.sender or not message.sender_name:
raise ValueError("The message does not have the required fields (text, sender, sender_name).")
@ -34,6 +34,8 @@ class MessageBase(SQLModel):
timestamp = datetime.fromisoformat(message.timestamp)
else:
timestamp = message.timestamp
if not flow_id and message.flow_id:
flow_id = message.flow_id
return cls(
sender=message.sender,
sender_name=message.sender_name,
@ -52,6 +54,15 @@ class MessageTable(MessageBase, table=True):
flow: "Flow" = Relationship(back_populates="messages")
files: List[str] = Field(sa_column=Column(JSON))
@field_validator("flow_id", mode="before")
@classmethod
def validate_flow_id(cls, value):
    """Coerce a string ``flow_id`` into a ``UUID`` before field validation.

    ``None`` and any non-string value are passed through untouched.
    """
    # None is not a str, so it falls through to the unchanged branch.
    return UUID(value) if isinstance(value, str) else value
# Needed for Column(JSON)
class Config:
arbitrary_types_allowed = True

View file

@ -6,22 +6,24 @@ from typing import TYPE_CHECKING
import sqlalchemy as sa
from alembic import command, util
from alembic.config import Config
from langflow.services.base import Service
from langflow.services.database import models # noqa
from langflow.services.database.models.user.crud import get_user_by_username
from langflow.services.database.utils import Result, TableResults
from langflow.services.deps import get_settings_service
from langflow.services.utils import teardown_superuser
from loguru import logger
from sqlalchemy import event, inspect
from sqlalchemy.engine import Engine
from sqlalchemy.exc import OperationalError
from sqlmodel import Session, SQLModel, create_engine, select, text
from langflow.services.base import Service
from langflow.services.database import models # noqa
from langflow.services.database.models.user.crud import get_user_by_username
from langflow.services.database.utils import Result, TableResults, migrate_messages_from_monitor_service_to_database
from langflow.services.deps import get_settings_service
from langflow.services.utils import teardown_superuser
if TYPE_CHECKING:
from langflow.services.settings.service import SettingsService
from sqlalchemy.engine import Engine
from langflow.services.settings.service import SettingsService
class DatabaseService(Service):
name = "database_service"
@ -205,6 +207,10 @@ class DatabaseService(Service):
logger.error(f"AutogenerateDiffsDetected: {exc}")
if not fix:
raise RuntimeError(f"There's a mismatch between the models and the database.\n{exc}")
try:
migrate_messages_from_monitor_service_to_database(session)
except Exception as exc:
logger.error(f"Error migrating messages from monitor service to database: {exc}")
if fix:
self.try_downgrade_upgrade_until_success(alembic_cfg)

View file

@ -4,11 +4,78 @@ from typing import TYPE_CHECKING
from alembic.util.exc import CommandError
from loguru import logger
from sqlmodel import Session, text
from sqlmodel import Session, select, text
from langflow.services.deps import get_monitor_service
if TYPE_CHECKING:
from langflow.services.database.service import DatabaseService
from typing import Dict, List
def migrate_messages_from_monitor_service_to_database(session: Session) -> bool:
    """Copy messages from the monitor service (DuckDB) into ``MessageTable``.

    Messages are matched between the two stores by the tuple
    ``(text, timestamp.isoformat(), str(flow_id), session_id)``. Monitor
    messages whose key is already present in the database are not inserted
    again. After a successful insert the database is read back to verify the
    migrated rows; only when every row verifies are the original monitor
    messages deleted.

    Args:
        session: Database session used to read, insert and commit.

    Returns:
        True when there was nothing to migrate or every migrated message was
        verified (and the originals were deleted); False when the insert
        failed (the session is rolled back) or verification found a missing
        or mismatching row (the originals are kept).
    """
    from langflow.schema.message import Message
    from langflow.services.database.models.message import MessageTable

    def message_key(text, timestamp, flow_id, session_id):
        # Identity used to pair a monitor record with a database row.
        return (text, timestamp.isoformat(), str(flow_id), session_id)

    def record_key(record: dict):
        return message_key(record["text"], record["timestamp"], record["flow_id"], record["session_id"])

    def load_db_messages() -> dict:
        rows = session.exec(select(MessageTable)).all()
        # A single-entity select may yield the model itself or a one-element
        # row wrapping it, depending on the session/result type; accept both.
        models = [row if isinstance(row, MessageTable) else row[0] for row in rows]
        return {message_key(m.text, m.timestamp, m.flow_id, m.session_id): m for m in models}

    monitor_service = get_monitor_service()
    messages_df = monitor_service.get_messages()

    if messages_df.empty:
        logger.info("No messages to migrate.")
        return True

    original_messages: list[dict] = messages_df.to_dict(orient="records")

    # Filter out messages that already exist in the database
    existing_by_key = load_db_messages()
    original_messages_filtered = [msg for msg in original_messages if record_key(msg) not in existing_by_key]
    if not original_messages_filtered:
        logger.info("No messages to migrate.")
        return True

    try:
        # Bulk insert messages
        session.bulk_insert_mappings(
            MessageTable,  # type: ignore
            [MessageTable.from_message(Message(**msg)).model_dump() for msg in original_messages_filtered],
        )
        session.commit()
    except Exception as e:
        logger.error(f"Error during message insertion: {str(e)}")
        session.rollback()
        return False

    # Re-read the table AFTER the commit. The pre-insert snapshot cannot contain
    # any of the rows just inserted (they were selected precisely because their
    # keys were absent), so verifying against it would always fail.
    migrated_by_key = load_db_messages()

    all_ok = True
    for orig_msg in original_messages_filtered:
        matching_db_msg = migrated_by_key.get(record_key(orig_msg))
        if matching_db_msg is None:
            logger.warning(f"Message not found in database: {orig_msg}")
            all_ok = False
        # Validate other fields ("index" only exists on the monitor side).
        # NOTE(review): this is a raw equality check; a value stored with a
        # different representation in the two stores (e.g. a string vs. UUID
        # flow_id) would be reported as a mismatch - confirm this is intended.
        elif any(getattr(matching_db_msg, k) != v for k, v in orig_msg.items() if k != "index"):
            logger.warning(f"Message mismatch in database: {orig_msg}")
            all_ok = False

    if all_ok:
        # Every monitor message is now in the database (either pre-existing or
        # freshly verified), so all originals can be removed.
        messages_ids = [message["index"] for message in original_messages]
        monitor_service.delete_messages(messages_ids)
        logger.info("Migration completed successfully. Original messages deleted.")
    else:
        logger.warning("Migration completed with errors. Original messages not deleted.")

    return all_ok
def initialize_database(fix_migration: bool = False):
logger.debug("Initializing database")

View file

@ -1,6 +1,6 @@
import json
from datetime import datetime, timezone
from typing import Any, Optional
from typing import Any
from uuid import UUID
from pydantic import BaseModel, Field, field_serializer, field_validator
@ -28,15 +28,15 @@ class DefaultModel(BaseModel):
class TransactionModel(DefaultModel):
index: Optional[int] = Field(default=None)
timestamp: Optional[datetime] = Field(default_factory=datetime.now, alias="timestamp")
index: int | None = Field(default=None)
timestamp: datetime | None = Field(default_factory=datetime.now, alias="timestamp")
vertex_id: str
target_id: str | None = None
inputs: dict
outputs: Optional[dict] = None
outputs: dict | None = None
status: str
error: Optional[str] = None
flow_id: Optional[str] = Field(default=None, alias="flow_id")
error: str | None = None
flow_id: str | None = Field(default=None, alias="flow_id")
# validate target_args in case it is a JSON
@field_validator("outputs", "inputs", mode="before")
@ -53,16 +53,16 @@ class TransactionModel(DefaultModel):
class TransactionModelResponse(DefaultModel):
index: Optional[int] = Field(default=None)
timestamp: Optional[datetime] = Field(default_factory=datetime.now, alias="timestamp")
index: int | None = Field(default=None)
timestamp: datetime | None = Field(default_factory=datetime.now, alias="timestamp")
vertex_id: str
inputs: dict
outputs: Optional[dict] = None
outputs: dict | None = None
status: str
error: Optional[str] = None
flow_id: Optional[str] = Field(default=None, alias="flow_id")
source: Optional[str] = None
target: Optional[str] = None
error: str | None = None
flow_id: str | None = Field(default=None, alias="flow_id")
source: str | None = None
target: str | None = None
# validate target_args in case it is a JSON
@field_validator("outputs", "inputs", mode="before")
@ -81,9 +81,9 @@ class TransactionModelResponse(DefaultModel):
return v
class MessageModel(DefaultModel):
id: Optional[str | UUID] = Field(default=None)
flow_id: Optional[UUID] = Field(default=None)
class DuckDbMessageModel(DefaultModel):
index: int | None = Field(default=None, alias="index")
flow_id: str | None = Field(default=None, alias="flow_id")
timestamp: datetime = Field(default_factory=lambda: datetime.now(timezone.utc))
sender: str
sender_name: str
@ -112,7 +112,53 @@ class MessageModel(DefaultModel):
return v
@classmethod
def from_message(cls, message: Message, flow_id: Optional[str] = None):
def from_message(cls, message: Message, flow_id: str | None = None):
# first check if the record has all the required fields
if message.text is None or not message.sender or not message.sender_name:
raise ValueError("The message does not have the required fields (text, sender, sender_name).")
return cls(
sender=message.sender,
sender_name=message.sender_name,
text=message.text,
session_id=message.session_id,
files=message.files or [],
timestamp=message.timestamp,
flow_id=flow_id,
)
class MessageModel(DefaultModel):
id: str | UUID | None = Field(default=None)
flow_id: UUID | None = Field(default=None)
timestamp: datetime = Field(default_factory=lambda: datetime.now(timezone.utc))
sender: str
sender_name: str
session_id: str
text: str
files: list[str] = []
@field_validator("files", mode="before")
@classmethod
def validate_files(cls, v):
    """Decode ``files`` from JSON when it arrives as a string.

    Non-string values are returned as given.
    """
    return json.loads(v) if isinstance(v, str) else v
@field_serializer("timestamp")
@classmethod
def serialize_timestamp(cls, v):
    """Serialize ``timestamp`` as ``YYYY-MM-DD HH:MM:SS``, dropping microseconds."""
    truncated = v.replace(microsecond=0)
    formatted = truncated.strftime("%Y-%m-%d %H:%M:%S")
    return formatted
@field_serializer("files")
@classmethod
def serialize_files(cls, v):
    """Serialize ``files`` to a JSON string when it is a list.

    Any other value is returned unchanged.
    """
    return json.dumps(v) if isinstance(v, list) else v
@classmethod
def from_message(cls, message: Message, flow_id: str | None = None):
# first check if the record has all the required fields
if message.text is None or not message.sender or not message.sender_name:
raise ValueError("The message does not have the required fields (text, sender, sender_name).")
@ -139,8 +185,8 @@ class MessageModelRequest(MessageModel):
class VertexBuildModel(DefaultModel):
index: Optional[int] = Field(default=None, alias="index", exclude=True)
id: Optional[str] = Field(default=None, alias="id")
index: int | None = Field(default=None, alias="index", exclude=True)
id: str | None = Field(default=None, alias="id")
flow_id: str
valid: bool
params: Any

View file

@ -1,6 +1,6 @@
from datetime import datetime
from pathlib import Path
from typing import TYPE_CHECKING, List, Optional, Union
from typing import TYPE_CHECKING, Union
import duckdb
from loguru import logger
@ -10,7 +10,7 @@ from langflow.services.base import Service
from langflow.services.monitor.utils import add_row_to_table, drop_and_create_table_if_schema_mismatch
if TYPE_CHECKING:
from langflow.services.monitor.schema import MessageModel, TransactionModel, VertexBuildModel
from langflow.services.monitor.schema import DuckDbMessageModel, TransactionModel, VertexBuildModel
from langflow.services.settings.service import SettingsService
@ -18,14 +18,14 @@ class MonitorService(Service):
name = "monitor_service"
def __init__(self, settings_service: "SettingsService"):
from langflow.services.monitor.schema import MessageModel, TransactionModel, VertexBuildModel
from langflow.services.monitor.schema import DuckDbMessageModel, TransactionModel, VertexBuildModel
self.settings_service = settings_service
self.base_cache_dir = Path(user_cache_dir("langflow"))
self.db_path = self.base_cache_dir / "monitor.duckdb"
self.table_map: dict[str, type[TransactionModel | MessageModel | VertexBuildModel]] = {
self.table_map: dict[str, type[TransactionModel | DuckDbMessageModel | VertexBuildModel]] = {
"transactions": TransactionModel,
"messages": MessageModel,
"messages": DuckDbMessageModel,
"vertex_builds": VertexBuildModel,
}
@ -48,7 +48,7 @@ class MonitorService(Service):
def add_row(
self,
table_name: str,
data: Union[dict, "TransactionModel", "MessageModel", "VertexBuildModel"],
data: Union[dict, "TransactionModel", "DuckDbMessageModel", "VertexBuildModel"],
):
# Make sure the model passed matches the table
@ -68,12 +68,48 @@ class MonitorService(Service):
def get_timestamp():
return datetime.now().strftime("%Y-%m-%d %H:%M:%S")
def get_messages(
self,
flow_id: str | None = None,
sender: str | None = None,
sender_name: str | None = None,
session_id: str | None = None,
order_by: str | None = "timestamp",
order: str | None = "DESC",
limit: int | None = None,
):
query = "SELECT index, flow_id, sender_name, sender, session_id, text, files, timestamp FROM messages"
conditions = []
if sender:
conditions.append(f"sender = '{sender}'")
if sender_name:
conditions.append(f"sender_name = '{sender_name}'")
if session_id:
conditions.append(f"session_id = '{session_id}'")
if flow_id:
conditions.append(f"flow_id = '{flow_id}'")
if conditions:
query += " WHERE " + " AND ".join(conditions)
if order_by and order:
# Make sure the order is from newest to oldest
query += f" ORDER BY {order_by} {order.upper()}"
if limit is not None:
query += f" LIMIT {limit}"
with duckdb.connect(str(self.db_path), read_only=True) as conn:
df = conn.execute(query).df()
return df
def get_vertex_builds(
self,
flow_id: Optional[str] = None,
vertex_id: Optional[str] = None,
valid: Optional[bool] = None,
order_by: Optional[str] = "timestamp",
flow_id: str | None = None,
vertex_id: str | None = None,
valid: bool | None = None,
order_by: str | None = "timestamp",
):
query = "SELECT id, index,flow_id, valid, params, data, artifacts, timestamp FROM vertex_builds"
conditions = []
@ -96,7 +132,7 @@ class MonitorService(Service):
return df.to_dict(orient="records")
def delete_vertex_builds(self, flow_id: Optional[str] = None):
def delete_vertex_builds(self, flow_id: str | None = None):
query = "DELETE FROM vertex_builds"
if flow_id:
query += f" WHERE flow_id = '{flow_id}'"
@ -109,7 +145,7 @@ class MonitorService(Service):
return self.exec_query(query, read_only=False)
def delete_messages(self, message_ids: Union[List[int], str]):
def delete_messages(self, message_ids: list[int] | str):
if isinstance(message_ids, list):
# If message_ids is a list, join the string representations of the integers
ids_str = ",".join(map(str, message_ids))
@ -132,11 +168,11 @@ class MonitorService(Service):
def get_transactions(
self,
source: Optional[str] = None,
target: Optional[str] = None,
status: Optional[str] = None,
order_by: Optional[str] = "timestamp",
flow_id: Optional[str] = None,
source: str | None = None,
target: str | None = None,
status: str | None = None,
order_by: str | None = "timestamp",
flow_id: str | None = None,
):
query = (
"SELECT index,flow_id, status, error, timestamp, vertex_id, inputs, outputs, target_id FROM transactions"

View file

@ -35,16 +35,20 @@ def created_messages(session):
return messages_read
def test_get_messages(session):
add_messages(Message(text="Test message 1", sender="User", sender_name="User", session_id="session_id2"))
add_messages(Message(text="Test message 2", sender="User", sender_name="User", session_id="session_id2"))
def test_get_messages():
add_messages(
[
Message(text="Test message 1", sender="User", sender_name="User", session_id="session_id2"),
Message(text="Test message 2", sender="User", sender_name="User", session_id="session_id2"),
]
)
messages = get_messages(sender="User", session_id="session_id2", limit=2)
assert len(messages) == 2
assert messages[0].text == "Test message 1"
assert messages[1].text == "Test message 2"
def test_add_messages(session):
def test_add_messages():
message = Message(text="New Test message", sender="User", sender_name="User", session_id="new_session_id")
messages = add_messages(message)
assert len(messages) == 1
@ -65,7 +69,7 @@ def test_delete_messages(session):
assert len(messages) == 0
def test_store_message(session):
def test_store_message():
message = Message(text="Stored message", sender="User", sender_name="User", session_id="stored_session_id")
stored_messages = store_message(message)
assert len(stored_messages) == 1