diff --git a/pyproject.toml b/pyproject.toml index af5cbc606..5e3744a89 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -147,7 +147,7 @@ ignore-regex = '.*(Stati Uniti|Tense=Pres).*' minversion = "6.0" testpaths = ["tests", "integration"] console_output_style = "progress" -filterwarnings = ["ignore::DeprecationWarning"] +filterwarnings = ["ignore::DeprecationWarning", "ignore::ResourceWarning"] log_cli = true markers = ["async_test", "api_key_required"] diff --git a/src/backend/base/langflow/memory.py b/src/backend/base/langflow/memory.py index d1793740d..f052acdee 100644 --- a/src/backend/base/langflow/memory.py +++ b/src/backend/base/langflow/memory.py @@ -1,5 +1,4 @@ import warnings -from typing import List, Optional from uuid import UUID from loguru import logger @@ -8,17 +7,18 @@ from sqlmodel import Session, col, select from langflow.schema.message import Message from langflow.services.database.models.message.model import MessageRead, MessageTable +from langflow.services.database.utils import migrate_messages_from_monitor_service_to_database from langflow.services.deps import session_scope def get_messages( - sender: Optional[str] = None, - sender_name: Optional[str] = None, - session_id: Optional[str] = None, - order_by: Optional[str] = "timestamp", - order: Optional[str] = "DESC", - flow_id: Optional[UUID] = None, - limit: Optional[int] = None, + sender: str | None = None, + sender_name: str | None = None, + session_id: str | None = None, + order_by: str | None = "timestamp", + order: str | None = "DESC", + flow_id: UUID | None = None, + limit: int | None = None, ): """ Retrieves messages from the monitor service based on the provided filters. @@ -33,6 +33,8 @@ def get_messages( Returns: List[Data]: A list of Data objects representing the retrieved messages. """ + with session_scope() as session: + migrate_messages_from_monitor_service_to_database(session) messages_read: list[Message] = [] with session_scope() as session: stmt = select(MessageTable) @@ -58,7 +60,7 @@ def get_messages( return messages_read -def add_messages(messages: Message | list[Message], flow_id: Optional[str] = None): +def add_messages(messages: Message | list[Message], flow_id: str | None = None): """ Add a message to the monitor service. """ @@ -111,8 +113,8 @@ def delete_messages(session_id: str): def store_message( message: Message, - flow_id: Optional[str] = None, -) -> List[Message]: + flow_id: str | None = None, +) -> list[Message]: """ Stores a message in the memory. diff --git a/src/backend/base/langflow/schema/message.py b/src/backend/base/langflow/schema/message.py index c03c0cb71..5ad7d1922 100644 --- a/src/backend/base/langflow/schema/message.py +++ b/src/backend/base/langflow/schema/message.py @@ -41,6 +41,12 @@ class Message(Data): value = str(value) return value + @field_serializer("flow_id") + def serialize_flow_id(value): + if isinstance(value, str): + return UUID(value) + return value + @field_validator("files", mode="before") @classmethod def validate_files(cls, value): diff --git a/src/backend/base/langflow/services/database/models/message/model.py b/src/backend/base/langflow/services/database/models/message/model.py index 5a775d3d9..7c0b9dc8f 100644 --- a/src/backend/base/langflow/services/database/models/message/model.py +++ b/src/backend/base/langflow/services/database/models/message/model.py @@ -26,7 +26,7 @@ class MessageBase(SQLModel): return value @classmethod - def from_message(cls, message: "Message", flow_id: str | None = None): + def from_message(cls, message: "Message", flow_id: str | UUID | None = None): # first check if the record has all the required fields if message.text is None or not message.sender or not message.sender_name: raise ValueError("The message does not have the required fields (text, sender, sender_name).") @@ -34,6 +34,8 @@ class MessageBase(SQLModel): timestamp = datetime.fromisoformat(message.timestamp) else: timestamp = message.timestamp + if not flow_id and message.flow_id: + flow_id = message.flow_id return cls( sender=message.sender, sender_name=message.sender_name, @@ -52,6 +54,15 @@ class MessageTable(MessageBase, table=True): flow: "Flow" = Relationship(back_populates="messages") files: List[str] = Field(sa_column=Column(JSON)) + @field_validator("flow_id", mode="before") + @classmethod + def validate_flow_id(cls, value): + if value is None: + return value + if isinstance(value, str): + value = UUID(value) + return value + # Needed for Column(JSON) class Config: arbitrary_types_allowed = True diff --git a/src/backend/base/langflow/services/database/service.py b/src/backend/base/langflow/services/database/service.py index ceeaf3e38..32e0b08e2 100644 --- a/src/backend/base/langflow/services/database/service.py +++ b/src/backend/base/langflow/services/database/service.py @@ -6,22 +6,24 @@ from typing import TYPE_CHECKING import sqlalchemy as sa from alembic import command, util from alembic.config import Config -from langflow.services.base import Service -from langflow.services.database import models # noqa -from langflow.services.database.models.user.crud import get_user_by_username -from langflow.services.database.utils import Result, TableResults -from langflow.services.deps import get_settings_service -from langflow.services.utils import teardown_superuser from loguru import logger from sqlalchemy import event, inspect from sqlalchemy.engine import Engine from sqlalchemy.exc import OperationalError from sqlmodel import Session, SQLModel, create_engine, select, text +from langflow.services.base import Service +from langflow.services.database import models # noqa +from langflow.services.database.models.user.crud import get_user_by_username +from langflow.services.database.utils import Result, TableResults, migrate_messages_from_monitor_service_to_database +from langflow.services.deps import get_settings_service +from langflow.services.utils import teardown_superuser + if TYPE_CHECKING: - from langflow.services.settings.service import SettingsService from sqlalchemy.engine import Engine + from langflow.services.settings.service import SettingsService + class DatabaseService(Service): name = "database_service" @@ -205,6 +207,10 @@ class DatabaseService(Service): logger.error(f"AutogenerateDiffsDetected: {exc}") if not fix: raise RuntimeError(f"There's a mismatch between the models and the database.\n{exc}") + try: + migrate_messages_from_monitor_service_to_database(session) + except Exception as exc: + logger.error(f"Error migrating messages from monitor service to database: {exc}") if fix: self.try_downgrade_upgrade_until_success(alembic_cfg) diff --git a/src/backend/base/langflow/services/database/utils.py b/src/backend/base/langflow/services/database/utils.py index cf2c92cb3..fa40c725f 100644 --- a/src/backend/base/langflow/services/database/utils.py +++ b/src/backend/base/langflow/services/database/utils.py @@ -4,11 +4,78 @@ from typing import TYPE_CHECKING from alembic.util.exc import CommandError from loguru import logger -from sqlmodel import Session, text +from sqlmodel import Session, select, text + +from langflow.services.deps import get_monitor_service if TYPE_CHECKING: from langflow.services.database.service import DatabaseService +from typing import Dict, List + + +def migrate_messages_from_monitor_service_to_database(session: Session) -> bool: + from langflow.schema.message import Message + from langflow.services.database.models.message import MessageTable + + monitor_service = get_monitor_service() + messages_df = monitor_service.get_messages() + + if messages_df.empty: + logger.info("No messages to migrate.") + return True + + original_messages: List[Dict] = messages_df.to_dict(orient="records") + + db_messages = session.exec(select(MessageTable)).all() + db_messages = [msg[0] for msg in db_messages] # type: ignore + db_msg_dict = {(msg.text, msg.timestamp.isoformat(), str(msg.flow_id), msg.session_id): msg for msg in db_messages} + # Filter out messages that already exist in the database + original_messages_filtered = [] + for message in original_messages: + key = (message["text"], message["timestamp"].isoformat(), str(message["flow_id"]), message["session_id"]) + if key not in db_msg_dict: + original_messages_filtered.append(message) + if not original_messages_filtered: + logger.info("No messages to migrate.") + return True + try: + # Bulk insert messages + session.bulk_insert_mappings( + MessageTable, # type: ignore + [MessageTable.from_message(Message(**msg)).model_dump() for msg in original_messages_filtered], + ) + session.commit() + except Exception as e: + logger.error(f"Error during message insertion: {str(e)}") + session.rollback() + return False + + # Create a dictionary for faster lookup + + all_ok = True + for orig_msg in original_messages_filtered: + key = (orig_msg["text"], orig_msg["timestamp"].isoformat(), str(orig_msg["flow_id"]), orig_msg["session_id"]) + matching_db_msg = db_msg_dict.get(key) + + if matching_db_msg is None: + logger.warning(f"Message not found in database: {orig_msg}") + all_ok = False + else: + # Validate other fields + if any(getattr(matching_db_msg, k) != v for k, v in orig_msg.items() if k != "index"): + logger.warning(f"Message mismatch in database: {orig_msg}") + all_ok = False + + if all_ok: + messages_ids = [message["index"] for message in original_messages] + monitor_service.delete_messages(messages_ids) + logger.info("Migration completed successfully. Original messages deleted.") + else: + logger.warning("Migration completed with errors. Original messages not deleted.") + + return all_ok + def initialize_database(fix_migration: bool = False): logger.debug("Initializing database") diff --git a/src/backend/base/langflow/services/monitor/schema.py b/src/backend/base/langflow/services/monitor/schema.py index eeea846a1..de0bb17bb 100644 --- a/src/backend/base/langflow/services/monitor/schema.py +++ b/src/backend/base/langflow/services/monitor/schema.py @@ -1,6 +1,6 @@ import json from datetime import datetime, timezone -from typing import Any, Optional +from typing import Any from uuid import UUID from pydantic import BaseModel, Field, field_serializer, field_validator @@ -28,15 +28,15 @@ class DefaultModel(BaseModel): class TransactionModel(DefaultModel): - index: Optional[int] = Field(default=None) - timestamp: Optional[datetime] = Field(default_factory=datetime.now, alias="timestamp") + index: int | None = Field(default=None) + timestamp: datetime | None = Field(default_factory=datetime.now, alias="timestamp") vertex_id: str target_id: str | None = None inputs: dict - outputs: Optional[dict] = None + outputs: dict | None = None status: str - error: Optional[str] = None - flow_id: Optional[str] = Field(default=None, alias="flow_id") + error: str | None = None + flow_id: str | None = Field(default=None, alias="flow_id") # validate target_args in case it is a JSON @field_validator("outputs", "inputs", mode="before") @@ -53,16 +53,16 @@ class TransactionModel(DefaultModel): class TransactionModelResponse(DefaultModel): - index: Optional[int] = Field(default=None) - timestamp: Optional[datetime] = Field(default_factory=datetime.now, alias="timestamp") + index: int | None = Field(default=None) + timestamp: datetime | None = Field(default_factory=datetime.now, alias="timestamp") vertex_id: str inputs: dict - outputs: Optional[dict] = None + outputs: dict | None = None status: str - error: Optional[str] = None - flow_id: Optional[str] = Field(default=None, alias="flow_id") - source: Optional[str] = None - target: Optional[str] = None + error: str | None = None + flow_id: str | None = Field(default=None, alias="flow_id") + source: str | None = None + target: str | None = None # validate target_args in case it is a JSON @field_validator("outputs", "inputs", mode="before") @@ -81,9 +81,9 @@ class TransactionModelResponse(DefaultModel): return v -class MessageModel(DefaultModel): - id: Optional[str | UUID] = Field(default=None) - flow_id: Optional[UUID] = Field(default=None) +class DuckDbMessageModel(DefaultModel): + index: int | None = Field(default=None, alias="index") + flow_id: str | None = Field(default=None, alias="flow_id") timestamp: datetime = Field(default_factory=lambda: datetime.now(timezone.utc)) sender: str sender_name: str @@ -112,7 +112,53 @@ class MessageModel(DefaultModel): return v @classmethod - def from_message(cls, message: Message, flow_id: Optional[str] = None): + def from_message(cls, message: Message, flow_id: str | None = None): + # first check if the record has all the required fields + if message.text is None or not message.sender or not message.sender_name: + raise ValueError("The message does not have the required fields (text, sender, sender_name).") + return cls( + sender=message.sender, + sender_name=message.sender_name, + text=message.text, + session_id=message.session_id, + files=message.files or [], + timestamp=message.timestamp, + flow_id=flow_id, + ) + + +class MessageModel(DefaultModel): + id: str | UUID | None = Field(default=None) + flow_id: UUID | None = Field(default=None) + timestamp: datetime = Field(default_factory=lambda: datetime.now(timezone.utc)) + sender: str + sender_name: str + session_id: str + text: str + files: list[str] = [] + + @field_validator("files", mode="before") + @classmethod + def validate_files(cls, v): + if isinstance(v, str): + v = json.loads(v) + return v + + @field_serializer("timestamp") + @classmethod + def serialize_timestamp(cls, v): + v = v.replace(microsecond=0) + return v.strftime("%Y-%m-%d %H:%M:%S") + + @field_serializer("files") + @classmethod + def serialize_files(cls, v): + if isinstance(v, list): + return json.dumps(v) + return v + + @classmethod + def from_message(cls, message: Message, flow_id: str | None = None): # first check if the record has all the required fields if message.text is None or not message.sender or not message.sender_name: raise ValueError("The message does not have the required fields (text, sender, sender_name).") @@ -139,8 +185,8 @@ class MessageModelRequest(MessageModel): class VertexBuildModel(DefaultModel): - index: Optional[int] = Field(default=None, alias="index", exclude=True) - id: Optional[str] = Field(default=None, alias="id") + index: int | None = Field(default=None, alias="index", exclude=True) + id: str | None = Field(default=None, alias="id") flow_id: str valid: bool params: Any diff --git a/src/backend/base/langflow/services/monitor/service.py b/src/backend/base/langflow/services/monitor/service.py index d15d31329..f644fd871 100644 --- a/src/backend/base/langflow/services/monitor/service.py +++ b/src/backend/base/langflow/services/monitor/service.py @@ -1,6 +1,6 @@ from datetime import datetime from pathlib import Path -from typing import TYPE_CHECKING, List, Optional, Union +from typing import TYPE_CHECKING, Union import duckdb from loguru import logger @@ -10,7 +10,7 @@ from langflow.services.base import Service from langflow.services.monitor.utils import add_row_to_table, drop_and_create_table_if_schema_mismatch if TYPE_CHECKING: - from langflow.services.monitor.schema import MessageModel, TransactionModel, VertexBuildModel + from langflow.services.monitor.schema import DuckDbMessageModel, TransactionModel, VertexBuildModel from langflow.services.settings.service import SettingsService @@ -18,14 +18,14 @@ class MonitorService(Service): name = "monitor_service" def __init__(self, settings_service: "SettingsService"): - from langflow.services.monitor.schema import MessageModel, TransactionModel, VertexBuildModel + from langflow.services.monitor.schema import DuckDbMessageModel, TransactionModel, VertexBuildModel self.settings_service = settings_service self.base_cache_dir = Path(user_cache_dir("langflow")) self.db_path = self.base_cache_dir / "monitor.duckdb" - self.table_map: dict[str, type[TransactionModel | MessageModel | VertexBuildModel]] = { + self.table_map: dict[str, type[TransactionModel | DuckDbMessageModel | VertexBuildModel]] = { "transactions": TransactionModel, - "messages": MessageModel, + "messages": DuckDbMessageModel, "vertex_builds": VertexBuildModel, } @@ -48,7 +48,7 @@ class MonitorService(Service): def add_row( self, table_name: str, - data: Union[dict, "TransactionModel", "MessageModel", "VertexBuildModel"], + data: Union[dict, "TransactionModel", "DuckDbMessageModel", "VertexBuildModel"], ): # Make sure the model passed matches the table @@ -68,12 +68,48 @@ class MonitorService(Service): def get_timestamp(): return datetime.now().strftime("%Y-%m-%d %H:%M:%S") + def get_messages( + self, + flow_id: str | None = None, + sender: str | None = None, + sender_name: str | None = None, + session_id: str | None = None, + order_by: str | None = "timestamp", + order: str | None = "DESC", + limit: int | None = None, + ): + query = "SELECT index, flow_id, sender_name, sender, session_id, text, files, timestamp FROM messages" + conditions = [] + if sender: + conditions.append(f"sender = '{sender}'") + if sender_name: + conditions.append(f"sender_name = '{sender_name}'") + if session_id: + conditions.append(f"session_id = '{session_id}'") + if flow_id: + conditions.append(f"flow_id = '{flow_id}'") + + if conditions: + query += " WHERE " + " AND ".join(conditions) + + if order_by and order: + # Make sure the order is from newest to oldest + query += f" ORDER BY {order_by} {order.upper()}" + + if limit is not None: + query += f" LIMIT {limit}" + + with duckdb.connect(str(self.db_path), read_only=True) as conn: + df = conn.execute(query).df() + + return df + def get_vertex_builds( self, - flow_id: Optional[str] = None, - vertex_id: Optional[str] = None, - valid: Optional[bool] = None, - order_by: Optional[str] = "timestamp", + flow_id: str | None = None, + vertex_id: str | None = None, + valid: bool | None = None, + order_by: str | None = "timestamp", ): query = "SELECT id, index,flow_id, valid, params, data, artifacts, timestamp FROM vertex_builds" conditions = [] @@ -96,7 +132,7 @@ class MonitorService(Service): return df.to_dict(orient="records") - def delete_vertex_builds(self, flow_id: Optional[str] = None): + def delete_vertex_builds(self, flow_id: str | None = None): query = "DELETE FROM vertex_builds" if flow_id: query += f" WHERE flow_id = '{flow_id}'" @@ -109,7 +145,7 @@ class MonitorService(Service): return self.exec_query(query, read_only=False) - def delete_messages(self, message_ids: Union[List[int], str]): + def delete_messages(self, message_ids: list[int] | str): if isinstance(message_ids, list): # If message_ids is a list, join the string representations of the integers ids_str = ",".join(map(str, message_ids)) @@ -132,11 +168,11 @@ class MonitorService(Service): def get_transactions( self, - source: Optional[str] = None, - target: Optional[str] = None, - status: Optional[str] = None, - order_by: Optional[str] = "timestamp", - flow_id: Optional[str] = None, + source: str | None = None, + target: str | None = None, + status: str | None = None, + order_by: str | None = "timestamp", + flow_id: str | None = None, ): query = ( "SELECT index,flow_id, status, error, timestamp, vertex_id, inputs, outputs, target_id FROM transactions" diff --git a/tests/unit/test_messages.py b/tests/unit/test_messages.py index 198387db8..059d82b61 100644 --- a/tests/unit/test_messages.py +++ b/tests/unit/test_messages.py @@ -35,16 +35,20 @@ def created_messages(session): return messages_read -def test_get_messages(session): - add_messages(Message(text="Test message 1", sender="User", sender_name="User", session_id="session_id2")) - add_messages(Message(text="Test message 2", sender="User", sender_name="User", session_id="session_id2")) +def test_get_messages(): + add_messages( + [ + Message(text="Test message 1", sender="User", sender_name="User", session_id="session_id2"), + Message(text="Test message 2", sender="User", sender_name="User", session_id="session_id2"), + ] + ) messages = get_messages(sender="User", session_id="session_id2", limit=2) assert len(messages) == 2 assert messages[0].text == "Test message 1" assert messages[1].text == "Test message 2" -def test_add_messages(session): +def test_add_messages(): message = Message(text="New Test message", sender="User", sender_name="User", session_id="new_session_id") messages = add_messages(message) assert len(messages) == 1 @@ -65,7 +69,7 @@ def test_delete_messages(session): assert len(messages) == 0 -def test_store_message(session): +def test_store_message(): message = Message(text="Stored message", sender="User", sender_name="User", session_id="stored_session_id") stored_messages = store_message(message) assert len(stored_messages) == 1