feat: add logic to migrate from duckdb table to database (#2385)
This PR adds a function that migrates any data in the duckdb messages table to the message table in the database.
This commit is contained in:
commit
96665b2bfe
9 changed files with 240 additions and 62 deletions
|
|
@ -147,7 +147,7 @@ ignore-regex = '.*(Stati Uniti|Tense=Pres).*'
|
|||
minversion = "6.0"
|
||||
testpaths = ["tests", "integration"]
|
||||
console_output_style = "progress"
|
||||
filterwarnings = ["ignore::DeprecationWarning"]
|
||||
filterwarnings = ["ignore::DeprecationWarning", "ignore::ResourceWarning"]
|
||||
log_cli = true
|
||||
markers = ["async_test", "api_key_required"]
|
||||
|
||||
|
|
|
|||
|
|
@ -1,5 +1,4 @@
|
|||
import warnings
|
||||
from typing import List, Optional
|
||||
from uuid import UUID
|
||||
|
||||
from loguru import logger
|
||||
|
|
@ -8,17 +7,18 @@ from sqlmodel import Session, col, select
|
|||
|
||||
from langflow.schema.message import Message
|
||||
from langflow.services.database.models.message.model import MessageRead, MessageTable
|
||||
from langflow.services.database.utils import migrate_messages_from_monitor_service_to_database
|
||||
from langflow.services.deps import session_scope
|
||||
|
||||
|
||||
def get_messages(
|
||||
sender: Optional[str] = None,
|
||||
sender_name: Optional[str] = None,
|
||||
session_id: Optional[str] = None,
|
||||
order_by: Optional[str] = "timestamp",
|
||||
order: Optional[str] = "DESC",
|
||||
flow_id: Optional[UUID] = None,
|
||||
limit: Optional[int] = None,
|
||||
sender: str | None = None,
|
||||
sender_name: str | None = None,
|
||||
session_id: str | None = None,
|
||||
order_by: str | None = "timestamp",
|
||||
order: str | None = "DESC",
|
||||
flow_id: UUID | None = None,
|
||||
limit: int | None = None,
|
||||
):
|
||||
"""
|
||||
Retrieves messages from the monitor service based on the provided filters.
|
||||
|
|
@ -33,6 +33,8 @@ def get_messages(
|
|||
Returns:
|
||||
List[Data]: A list of Data objects representing the retrieved messages.
|
||||
"""
|
||||
with session_scope() as session:
|
||||
migrate_messages_from_monitor_service_to_database(session)
|
||||
messages_read: list[Message] = []
|
||||
with session_scope() as session:
|
||||
stmt = select(MessageTable)
|
||||
|
|
@ -58,7 +60,7 @@ def get_messages(
|
|||
return messages_read
|
||||
|
||||
|
||||
def add_messages(messages: Message | list[Message], flow_id: Optional[str] = None):
|
||||
def add_messages(messages: Message | list[Message], flow_id: str | None = None):
|
||||
"""
|
||||
Add a message to the monitor service.
|
||||
"""
|
||||
|
|
@ -111,8 +113,8 @@ def delete_messages(session_id: str):
|
|||
|
||||
def store_message(
|
||||
message: Message,
|
||||
flow_id: Optional[str] = None,
|
||||
) -> List[Message]:
|
||||
flow_id: str | None = None,
|
||||
) -> list[Message]:
|
||||
"""
|
||||
Stores a message in the memory.
|
||||
|
||||
|
|
|
|||
|
|
@ -41,6 +41,12 @@ class Message(Data):
|
|||
value = str(value)
|
||||
return value
|
||||
|
||||
@field_serializer("flow_id")
def serialize_flow_id(value):
    """Serialize ``flow_id``: a string is converted with ``UUID(value)``, anything else is returned as-is."""
    return UUID(value) if isinstance(value, str) else value
|
||||
|
||||
@field_validator("files", mode="before")
|
||||
@classmethod
|
||||
def validate_files(cls, value):
|
||||
|
|
|
|||
|
|
@ -26,7 +26,7 @@ class MessageBase(SQLModel):
|
|||
return value
|
||||
|
||||
@classmethod
|
||||
def from_message(cls, message: "Message", flow_id: str | None = None):
|
||||
def from_message(cls, message: "Message", flow_id: str | UUID | None = None):
|
||||
# first check if the record has all the required fields
|
||||
if message.text is None or not message.sender or not message.sender_name:
|
||||
raise ValueError("The message does not have the required fields (text, sender, sender_name).")
|
||||
|
|
@ -34,6 +34,8 @@ class MessageBase(SQLModel):
|
|||
timestamp = datetime.fromisoformat(message.timestamp)
|
||||
else:
|
||||
timestamp = message.timestamp
|
||||
if not flow_id and message.flow_id:
|
||||
flow_id = message.flow_id
|
||||
return cls(
|
||||
sender=message.sender,
|
||||
sender_name=message.sender_name,
|
||||
|
|
@ -52,6 +54,15 @@ class MessageTable(MessageBase, table=True):
|
|||
flow: "Flow" = Relationship(back_populates="messages")
|
||||
files: List[str] = Field(sa_column=Column(JSON))
|
||||
|
||||
@field_validator("flow_id", mode="before")
@classmethod
def validate_flow_id(cls, value):
    """Convert a string ``flow_id`` to ``UUID`` before validation; every other value passes through."""
    # ``None`` is not a ``str``, so it is returned unchanged like any other non-string value.
    return UUID(value) if isinstance(value, str) else value
|
||||
|
||||
# Needed for Column(JSON)
|
||||
class Config:
|
||||
arbitrary_types_allowed = True
|
||||
|
|
|
|||
|
|
@ -6,22 +6,24 @@ from typing import TYPE_CHECKING
|
|||
import sqlalchemy as sa
|
||||
from alembic import command, util
|
||||
from alembic.config import Config
|
||||
from langflow.services.base import Service
|
||||
from langflow.services.database import models # noqa
|
||||
from langflow.services.database.models.user.crud import get_user_by_username
|
||||
from langflow.services.database.utils import Result, TableResults
|
||||
from langflow.services.deps import get_settings_service
|
||||
from langflow.services.utils import teardown_superuser
|
||||
from loguru import logger
|
||||
from sqlalchemy import event, inspect
|
||||
from sqlalchemy.engine import Engine
|
||||
from sqlalchemy.exc import OperationalError
|
||||
from sqlmodel import Session, SQLModel, create_engine, select, text
|
||||
|
||||
from langflow.services.base import Service
|
||||
from langflow.services.database import models # noqa
|
||||
from langflow.services.database.models.user.crud import get_user_by_username
|
||||
from langflow.services.database.utils import Result, TableResults, migrate_messages_from_monitor_service_to_database
|
||||
from langflow.services.deps import get_settings_service
|
||||
from langflow.services.utils import teardown_superuser
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from langflow.services.settings.service import SettingsService
|
||||
from sqlalchemy.engine import Engine
|
||||
|
||||
from langflow.services.settings.service import SettingsService
|
||||
|
||||
|
||||
class DatabaseService(Service):
|
||||
name = "database_service"
|
||||
|
|
@ -205,6 +207,10 @@ class DatabaseService(Service):
|
|||
logger.error(f"AutogenerateDiffsDetected: {exc}")
|
||||
if not fix:
|
||||
raise RuntimeError(f"There's a mismatch between the models and the database.\n{exc}")
|
||||
try:
|
||||
migrate_messages_from_monitor_service_to_database(session)
|
||||
except Exception as exc:
|
||||
logger.error(f"Error migrating messages from monitor service to database: {exc}")
|
||||
|
||||
if fix:
|
||||
self.try_downgrade_upgrade_until_success(alembic_cfg)
|
||||
|
|
|
|||
|
|
@ -4,11 +4,78 @@ from typing import TYPE_CHECKING
|
|||
|
||||
from alembic.util.exc import CommandError
|
||||
from loguru import logger
|
||||
from sqlmodel import Session, text
|
||||
from sqlmodel import Session, select, text
|
||||
|
||||
from langflow.services.deps import get_monitor_service
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from langflow.services.database.service import DatabaseService
|
||||
|
||||
from typing import Dict, List
|
||||
|
||||
|
||||
def _message_identity(text, timestamp, flow_id, session_id) -> tuple:
    """Build the tuple used to recognise the same message in DuckDB and in the database."""
    # flow_id is stringified because the two stores may hold it as different types.
    return (text, timestamp.isoformat(), str(flow_id), session_id)


def _index_database_messages(session: Session, table) -> dict:
    """Read every row of ``table`` and return the rows indexed by ``_message_identity``."""
    indexed = {}
    for row in session.exec(select(table)).all():
        # Accept both result shapes: a model instance, or a 1-tuple wrapping the instance.
        record = row if isinstance(row, table) else row[0]
        key = _message_identity(record.text, record.timestamp, record.flow_id, record.session_id)
        indexed[key] = record
    return indexed


def migrate_messages_from_monitor_service_to_database(session: Session) -> bool:
    """Copy messages from the monitor service (DuckDB) into ``MessageTable``.

    Steps:
      1. Read all messages from the monitor service.
      2. Skip those whose (text, timestamp, flow_id, session_id) already exist in the database.
      3. Bulk insert the remaining ones and commit.
      4. Re-read the database and check that every inserted message is present.
      5. Only if that check passes, delete the originals from the monitor service.

    Args:
        session: Database session used for reading, inserting and committing.

    Returns:
        True when there was nothing to migrate or the migration was verified;
        False when the insert failed (the session is rolled back) or the
        verification found missing/mismatching rows (originals are kept).
    """
    from langflow.schema.message import Message
    from langflow.services.database.models.message import MessageTable

    monitor_service = get_monitor_service()
    messages_df = monitor_service.get_messages()

    if messages_df.empty:
        logger.info("No messages to migrate.")
        return True

    original_messages: list[dict] = messages_df.to_dict(orient="records")

    def key_of(message: dict) -> tuple:
        return _message_identity(
            message["text"], message["timestamp"], message["flow_id"], message["session_id"]
        )

    # Filter out messages that already exist in the database
    existing_by_key = _index_database_messages(session, MessageTable)
    original_messages_filtered = [msg for msg in original_messages if key_of(msg) not in existing_by_key]
    if not original_messages_filtered:
        # NOTE(review): the originals stay in the monitor service in this case and
        # will be re-checked on the next call; confirm whether they should be deleted.
        logger.info("No messages to migrate.")
        return True

    try:
        # Bulk insert messages
        session.bulk_insert_mappings(
            MessageTable,  # type: ignore
            [MessageTable.from_message(Message(**msg)).model_dump() for msg in original_messages_filtered],
        )
        session.commit()
    except Exception as e:
        logger.error(f"Error during message insertion: {str(e)}")
        session.rollback()
        return False

    # Rebuild the lookup AFTER the insert. The pre-insert index cannot contain any
    # of the filtered messages, so verifying against it would always fail.
    migrated_by_key = _index_database_messages(session, MessageTable)

    # Fields that are part of the identity key were already matched in normalised
    # form (isoformat / str); comparing them raw would flag type-only differences.
    already_compared = ("index", "text", "timestamp", "flow_id", "session_id")

    all_ok = True
    for orig_msg in original_messages_filtered:
        matching_db_msg = migrated_by_key.get(key_of(orig_msg))

        if matching_db_msg is None:
            logger.warning(f"Message not found in database: {orig_msg}")
            all_ok = False
            continue

        # Validate other fields
        # NOTE(review): ``files`` may be stored in a different representation in the
        # two stores; if so this comparison reports a mismatch - verify against the schema.
        if any(getattr(matching_db_msg, k) != v for k, v in orig_msg.items() if k not in already_compared):
            logger.warning(f"Message mismatch in database: {orig_msg}")
            all_ok = False

    if all_ok:
        messages_ids = [message["index"] for message in original_messages]
        monitor_service.delete_messages(messages_ids)
        logger.info("Migration completed successfully. Original messages deleted.")
    else:
        logger.warning("Migration completed with errors. Original messages not deleted.")

    return all_ok
|
||||
|
||||
|
||||
def initialize_database(fix_migration: bool = False):
|
||||
logger.debug("Initializing database")
|
||||
|
|
|
|||
|
|
@ -1,6 +1,6 @@
|
|||
import json
|
||||
from datetime import datetime, timezone
|
||||
from typing import Any, Optional
|
||||
from typing import Any
|
||||
from uuid import UUID
|
||||
|
||||
from pydantic import BaseModel, Field, field_serializer, field_validator
|
||||
|
|
@ -28,15 +28,15 @@ class DefaultModel(BaseModel):
|
|||
|
||||
|
||||
class TransactionModel(DefaultModel):
|
||||
index: Optional[int] = Field(default=None)
|
||||
timestamp: Optional[datetime] = Field(default_factory=datetime.now, alias="timestamp")
|
||||
index: int | None = Field(default=None)
|
||||
timestamp: datetime | None = Field(default_factory=datetime.now, alias="timestamp")
|
||||
vertex_id: str
|
||||
target_id: str | None = None
|
||||
inputs: dict
|
||||
outputs: Optional[dict] = None
|
||||
outputs: dict | None = None
|
||||
status: str
|
||||
error: Optional[str] = None
|
||||
flow_id: Optional[str] = Field(default=None, alias="flow_id")
|
||||
error: str | None = None
|
||||
flow_id: str | None = Field(default=None, alias="flow_id")
|
||||
|
||||
# validate target_args in case it is a JSON
|
||||
@field_validator("outputs", "inputs", mode="before")
|
||||
|
|
@ -53,16 +53,16 @@ class TransactionModel(DefaultModel):
|
|||
|
||||
|
||||
class TransactionModelResponse(DefaultModel):
|
||||
index: Optional[int] = Field(default=None)
|
||||
timestamp: Optional[datetime] = Field(default_factory=datetime.now, alias="timestamp")
|
||||
index: int | None = Field(default=None)
|
||||
timestamp: datetime | None = Field(default_factory=datetime.now, alias="timestamp")
|
||||
vertex_id: str
|
||||
inputs: dict
|
||||
outputs: Optional[dict] = None
|
||||
outputs: dict | None = None
|
||||
status: str
|
||||
error: Optional[str] = None
|
||||
flow_id: Optional[str] = Field(default=None, alias="flow_id")
|
||||
source: Optional[str] = None
|
||||
target: Optional[str] = None
|
||||
error: str | None = None
|
||||
flow_id: str | None = Field(default=None, alias="flow_id")
|
||||
source: str | None = None
|
||||
target: str | None = None
|
||||
|
||||
# validate target_args in case it is a JSON
|
||||
@field_validator("outputs", "inputs", mode="before")
|
||||
|
|
@ -81,9 +81,9 @@ class TransactionModelResponse(DefaultModel):
|
|||
return v
|
||||
|
||||
|
||||
class MessageModel(DefaultModel):
|
||||
id: Optional[str | UUID] = Field(default=None)
|
||||
flow_id: Optional[UUID] = Field(default=None)
|
||||
class DuckDbMessageModel(DefaultModel):
|
||||
index: int | None = Field(default=None, alias="index")
|
||||
flow_id: str | None = Field(default=None, alias="flow_id")
|
||||
timestamp: datetime = Field(default_factory=lambda: datetime.now(timezone.utc))
|
||||
sender: str
|
||||
sender_name: str
|
||||
|
|
@ -112,7 +112,53 @@ class MessageModel(DefaultModel):
|
|||
return v
|
||||
|
||||
@classmethod
|
||||
def from_message(cls, message: Message, flow_id: Optional[str] = None):
|
||||
def from_message(cls, message: Message, flow_id: str | None = None):
|
||||
# first check if the record has all the required fields
|
||||
if message.text is None or not message.sender or not message.sender_name:
|
||||
raise ValueError("The message does not have the required fields (text, sender, sender_name).")
|
||||
return cls(
|
||||
sender=message.sender,
|
||||
sender_name=message.sender_name,
|
||||
text=message.text,
|
||||
session_id=message.session_id,
|
||||
files=message.files or [],
|
||||
timestamp=message.timestamp,
|
||||
flow_id=flow_id,
|
||||
)
|
||||
|
||||
|
||||
class MessageModel(DefaultModel):
|
||||
id: str | UUID | None = Field(default=None)
|
||||
flow_id: UUID | None = Field(default=None)
|
||||
timestamp: datetime = Field(default_factory=lambda: datetime.now(timezone.utc))
|
||||
sender: str
|
||||
sender_name: str
|
||||
session_id: str
|
||||
text: str
|
||||
files: list[str] = []
|
||||
|
||||
@field_validator("files", mode="before")
|
||||
@classmethod
|
||||
def validate_files(cls, v):
|
||||
if isinstance(v, str):
|
||||
v = json.loads(v)
|
||||
return v
|
||||
|
||||
@field_serializer("timestamp")
|
||||
@classmethod
|
||||
def serialize_timestamp(cls, v):
|
||||
v = v.replace(microsecond=0)
|
||||
return v.strftime("%Y-%m-%d %H:%M:%S")
|
||||
|
||||
@field_serializer("files")
|
||||
@classmethod
|
||||
def serialize_files(cls, v):
|
||||
if isinstance(v, list):
|
||||
return json.dumps(v)
|
||||
return v
|
||||
|
||||
@classmethod
|
||||
def from_message(cls, message: Message, flow_id: str | None = None):
|
||||
# first check if the record has all the required fields
|
||||
if message.text is None or not message.sender or not message.sender_name:
|
||||
raise ValueError("The message does not have the required fields (text, sender, sender_name).")
|
||||
|
|
@ -139,8 +185,8 @@ class MessageModelRequest(MessageModel):
|
|||
|
||||
|
||||
class VertexBuildModel(DefaultModel):
|
||||
index: Optional[int] = Field(default=None, alias="index", exclude=True)
|
||||
id: Optional[str] = Field(default=None, alias="id")
|
||||
index: int | None = Field(default=None, alias="index", exclude=True)
|
||||
id: str | None = Field(default=None, alias="id")
|
||||
flow_id: str
|
||||
valid: bool
|
||||
params: Any
|
||||
|
|
|
|||
|
|
@ -1,6 +1,6 @@
|
|||
from datetime import datetime
|
||||
from pathlib import Path
|
||||
from typing import TYPE_CHECKING, List, Optional, Union
|
||||
from typing import TYPE_CHECKING, Union
|
||||
|
||||
import duckdb
|
||||
from loguru import logger
|
||||
|
|
@ -10,7 +10,7 @@ from langflow.services.base import Service
|
|||
from langflow.services.monitor.utils import add_row_to_table, drop_and_create_table_if_schema_mismatch
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from langflow.services.monitor.schema import MessageModel, TransactionModel, VertexBuildModel
|
||||
from langflow.services.monitor.schema import DuckDbMessageModel, TransactionModel, VertexBuildModel
|
||||
from langflow.services.settings.service import SettingsService
|
||||
|
||||
|
||||
|
|
@ -18,14 +18,14 @@ class MonitorService(Service):
|
|||
name = "monitor_service"
|
||||
|
||||
def __init__(self, settings_service: "SettingsService"):
|
||||
from langflow.services.monitor.schema import MessageModel, TransactionModel, VertexBuildModel
|
||||
from langflow.services.monitor.schema import DuckDbMessageModel, TransactionModel, VertexBuildModel
|
||||
|
||||
self.settings_service = settings_service
|
||||
self.base_cache_dir = Path(user_cache_dir("langflow"))
|
||||
self.db_path = self.base_cache_dir / "monitor.duckdb"
|
||||
self.table_map: dict[str, type[TransactionModel | MessageModel | VertexBuildModel]] = {
|
||||
self.table_map: dict[str, type[TransactionModel | DuckDbMessageModel | VertexBuildModel]] = {
|
||||
"transactions": TransactionModel,
|
||||
"messages": MessageModel,
|
||||
"messages": DuckDbMessageModel,
|
||||
"vertex_builds": VertexBuildModel,
|
||||
}
|
||||
|
||||
|
|
@ -48,7 +48,7 @@ class MonitorService(Service):
|
|||
def add_row(
|
||||
self,
|
||||
table_name: str,
|
||||
data: Union[dict, "TransactionModel", "MessageModel", "VertexBuildModel"],
|
||||
data: Union[dict, "TransactionModel", "DuckDbMessageModel", "VertexBuildModel"],
|
||||
):
|
||||
# Make sure the model passed matches the table
|
||||
|
||||
|
|
@ -68,12 +68,48 @@ class MonitorService(Service):
|
|||
def get_timestamp():
|
||||
return datetime.now().strftime("%Y-%m-%d %H:%M:%S")
|
||||
|
||||
def get_messages(
    self,
    flow_id: str | None = None,
    sender: str | None = None,
    sender_name: str | None = None,
    session_id: str | None = None,
    order_by: str | None = "timestamp",
    order: str | None = "DESC",
    limit: int | None = None,
):
    """Read rows from the DuckDB ``messages`` table as a DataFrame.

    Filter values are bound as query parameters rather than interpolated into
    the SQL text, so quotes in a value cannot alter the statement.

    Args:
        flow_id: Keep only rows with this ``flow_id`` (ignored when falsy).
        sender: Keep only rows with this ``sender`` (ignored when falsy).
        sender_name: Keep only rows with this ``sender_name`` (ignored when falsy).
        session_id: Keep only rows with this ``session_id`` (ignored when falsy).
        order_by: Column to sort by; must be one of the selected columns.
        order: ``"ASC"`` or ``"DESC"`` (case-insensitive). Sorting is applied
            only when both ``order_by`` and ``order`` are set.
        limit: Maximum number of rows to return; no limit when ``None``.

    Returns:
        The query result converted with ``.df()``.

    Raises:
        ValueError: If ``order_by`` is not a selected column or ``order`` is
            neither ``ASC`` nor ``DESC``.
    """
    selectable = ("index", "flow_id", "sender_name", "sender", "session_id", "text", "files", "timestamp")
    query = "SELECT index, flow_id, sender_name, sender, session_id, text, files, timestamp FROM messages"

    # Same filter order as before: sender, sender_name, session_id, flow_id.
    filters = (
        ("sender", sender),
        ("sender_name", sender_name),
        ("session_id", session_id),
        ("flow_id", flow_id),
    )
    conditions = []
    params = []
    for column, value in filters:
        if value:
            conditions.append(f"{column} = ?")
            params.append(str(value))

    if conditions:
        query += " WHERE " + " AND ".join(conditions)

    if order_by and order:
        # Identifiers and keywords cannot be bound as parameters, so they are
        # checked against an allow-list before being placed in the SQL text.
        direction = order.upper()
        if order_by not in selectable:
            raise ValueError(f"Invalid order_by column: {order_by!r}")
        if direction not in ("ASC", "DESC"):
            raise ValueError(f"Invalid order: {order!r}")
        query += f" ORDER BY {order_by} {direction}"

    if limit is not None:
        # int() guarantees that only a number reaches the SQL text.
        query += f" LIMIT {int(limit)}"

    with duckdb.connect(str(self.db_path), read_only=True) as conn:
        df = conn.execute(query, params).df()

    return df
|
||||
|
||||
def get_vertex_builds(
|
||||
self,
|
||||
flow_id: Optional[str] = None,
|
||||
vertex_id: Optional[str] = None,
|
||||
valid: Optional[bool] = None,
|
||||
order_by: Optional[str] = "timestamp",
|
||||
flow_id: str | None = None,
|
||||
vertex_id: str | None = None,
|
||||
valid: bool | None = None,
|
||||
order_by: str | None = "timestamp",
|
||||
):
|
||||
query = "SELECT id, index,flow_id, valid, params, data, artifacts, timestamp FROM vertex_builds"
|
||||
conditions = []
|
||||
|
|
@ -96,7 +132,7 @@ class MonitorService(Service):
|
|||
|
||||
return df.to_dict(orient="records")
|
||||
|
||||
def delete_vertex_builds(self, flow_id: Optional[str] = None):
|
||||
def delete_vertex_builds(self, flow_id: str | None = None):
|
||||
query = "DELETE FROM vertex_builds"
|
||||
if flow_id:
|
||||
query += f" WHERE flow_id = '{flow_id}'"
|
||||
|
|
@ -109,7 +145,7 @@ class MonitorService(Service):
|
|||
|
||||
return self.exec_query(query, read_only=False)
|
||||
|
||||
def delete_messages(self, message_ids: Union[List[int], str]):
|
||||
def delete_messages(self, message_ids: list[int] | str):
|
||||
if isinstance(message_ids, list):
|
||||
# If message_ids is a list, join the string representations of the integers
|
||||
ids_str = ",".join(map(str, message_ids))
|
||||
|
|
@ -132,11 +168,11 @@ class MonitorService(Service):
|
|||
|
||||
def get_transactions(
|
||||
self,
|
||||
source: Optional[str] = None,
|
||||
target: Optional[str] = None,
|
||||
status: Optional[str] = None,
|
||||
order_by: Optional[str] = "timestamp",
|
||||
flow_id: Optional[str] = None,
|
||||
source: str | None = None,
|
||||
target: str | None = None,
|
||||
status: str | None = None,
|
||||
order_by: str | None = "timestamp",
|
||||
flow_id: str | None = None,
|
||||
):
|
||||
query = (
|
||||
"SELECT index,flow_id, status, error, timestamp, vertex_id, inputs, outputs, target_id FROM transactions"
|
||||
|
|
|
|||
|
|
@ -35,16 +35,20 @@ def created_messages(session):
|
|||
return messages_read
|
||||
|
||||
|
||||
def test_get_messages(session):
|
||||
add_messages(Message(text="Test message 1", sender="User", sender_name="User", session_id="session_id2"))
|
||||
add_messages(Message(text="Test message 2", sender="User", sender_name="User", session_id="session_id2"))
|
||||
def test_get_messages():
|
||||
add_messages(
|
||||
[
|
||||
Message(text="Test message 1", sender="User", sender_name="User", session_id="session_id2"),
|
||||
Message(text="Test message 2", sender="User", sender_name="User", session_id="session_id2"),
|
||||
]
|
||||
)
|
||||
messages = get_messages(sender="User", session_id="session_id2", limit=2)
|
||||
assert len(messages) == 2
|
||||
assert messages[0].text == "Test message 1"
|
||||
assert messages[1].text == "Test message 2"
|
||||
|
||||
|
||||
def test_add_messages(session):
|
||||
def test_add_messages():
|
||||
message = Message(text="New Test message", sender="User", sender_name="User", session_id="new_session_id")
|
||||
messages = add_messages(message)
|
||||
assert len(messages) == 1
|
||||
|
|
@ -65,7 +69,7 @@ def test_delete_messages(session):
|
|||
assert len(messages) == 0
|
||||
|
||||
|
||||
def test_store_message(session):
|
||||
def test_store_message():
|
||||
message = Message(text="Stored message", sender="User", sender_name="User", session_id="stored_session_id")
|
||||
stored_messages = store_message(message)
|
||||
assert len(stored_messages) == 1
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue