diff --git a/.github/workflows/lint-js.yml b/.github/workflows/lint-js.yml index aafd15f4a..b528ee8ed 100644 --- a/.github/workflows/lint-js.yml +++ b/.github/workflows/lint-js.yml @@ -5,7 +5,7 @@ on: paths: - "src/frontend/**" merge_group: - branches: [dev] + types: [checks_requested] env: NODE_VERSION: "21" diff --git a/.github/workflows/python_test.yml b/.github/workflows/python_test.yml index 9e065c58f..8e37f40f7 100644 --- a/.github/workflows/python_test.yml +++ b/.github/workflows/python_test.yml @@ -45,7 +45,7 @@ jobs: poetry run python -m langflow run --host 127.0.0.1 --port 7860 --backend-only & SERVER_PID=$! # Wait for the server to start - timeout 120 bash -c 'until curl -f http://127.0.0.1:7860/auto_login; do sleep 2; done' || (echo "Server did not start in time" && kill $SERVER_PID && exit 1) + timeout 120 bash -c 'until curl -f http://127.0.0.1:7860/api/v1/auto_login; do sleep 5; done' || (echo "Server did not start in time" && kill $SERVER_PID && exit 1) # Terminate the server kill $SERVER_PID || (echo "Failed to terminate the server" && exit 1) sleep 10 # give the server some time to terminate diff --git a/.github/workflows/release.yml b/.github/workflows/release.yml index 9f84be063..f16dc35d3 100644 --- a/.github/workflows/release.yml +++ b/.github/workflows/release.yml @@ -62,7 +62,7 @@ jobs: python -m langflow run --host 127.0.0.1 --port 7860 & SERVER_PID=$! # Wait for the server to start - timeout 120 bash -c 'until curl -f http://127.0.0.1:7860/auto_login; do sleep 2; done' || (echo "Server did not start in time" && kill $SERVER_PID && exit 1) + timeout 120 bash -c 'until curl -f http://127.0.0.1:7860/api/v1/auto_login; do sleep 2; done' || (echo "Server did not start in time" && kill $SERVER_PID && exit 1) # Terminate the server kill $SERVER_PID || (echo "Failed to terminate the server" && exit 1) sleep 10 # give the server some time to terminate @@ -124,7 +124,7 @@ jobs: python -m langflow run --host 127.0.0.1 --port 7860 & SERVER_PID=$! # Wait for the server to start - timeout 120 bash -c 'until curl -f http://127.0.0.1:7860/auto_login; do sleep 2; done' || (echo "Server did not start in time" && kill $SERVER_PID && exit 1) + timeout 120 bash -c 'until curl -f http://127.0.0.1:7860/api/v1/auto_login; do sleep 2; done' || (echo "Server did not start in time" && kill $SERVER_PID && exit 1) # Terminate the server kill $SERVER_PID || (echo "Failed to terminate the server" && exit 1) sleep 10 # give the server some time to terminate diff --git a/src/backend/base/langflow/alembic/versions/d066bfd22890_add_message_table.py b/src/backend/base/langflow/alembic/versions/d066bfd22890_add_message_table.py new file mode 100644 index 000000000..dd63398b7 --- /dev/null +++ b/src/backend/base/langflow/alembic/versions/d066bfd22890_add_message_table.py @@ -0,0 +1,52 @@ +"""Add message table + +Revision ID: 325180f0c4e1 +Revises: 631faacf5da2 +Create Date: 2024-06-23 21:29:28.220100 + +""" + +from typing import Sequence, Union + +import sqlalchemy as sa +import sqlmodel +from alembic import op + +from langflow.utils import migration + +# revision identifiers, used by Alembic. +revision: str = "325180f0c4e1" +down_revision: Union[str, None] = "631faacf5da2" +branch_labels: Union[str, Sequence[str], None] = None +depends_on: Union[str, Sequence[str], None] = None + + +def upgrade() -> None: + conn = op.get_bind() + # ### commands auto generated by Alembic - please adjust! ### + if not migration.table_exists("message", conn): + op.create_table( + "message", + sa.Column("timestamp", sa.DateTime(), nullable=False), + sa.Column("sender", sqlmodel.sql.sqltypes.AutoString(), nullable=False), + sa.Column("sender_name", sqlmodel.sql.sqltypes.AutoString(), nullable=False), + sa.Column("session_id", sqlmodel.sql.sqltypes.AutoString(), nullable=False), + sa.Column("text", sqlmodel.sql.sqltypes.AutoString(), nullable=False), + sa.Column("id", sqlmodel.sql.sqltypes.GUID(), nullable=False), + sa.Column("flow_id", sqlmodel.sql.sqltypes.GUID(), nullable=True), + sa.Column("files", sa.JSON(), nullable=True), + sa.ForeignKeyConstraint( + ["flow_id"], + ["flow.id"], + ), + sa.PrimaryKeyConstraint("id"), + ) + # ### end Alembic commands ### + + +def downgrade() -> None: + conn = op.get_bind() + # ### commands auto generated by Alembic - please adjust! ### + if migration.table_exists("message", conn): + op.drop_table("message") + # ### end Alembic commands ### diff --git a/src/backend/base/langflow/api/v1/monitor.py b/src/backend/base/langflow/api/v1/monitor.py index a99c86bf8..f6c1fc4ac 100644 --- a/src/backend/base/langflow/api/v1/monitor.py +++ b/src/backend/base/langflow/api/v1/monitor.py @@ -1,15 +1,15 @@ from typing import List, Optional - +from uuid import UUID from fastapi import APIRouter, Depends, HTTPException, Query +from sqlalchemy import delete +from sqlmodel import Session, col, select -from langflow.services.deps import get_monitor_service -from langflow.services.monitor.schema import ( - MessageModelRequest, - MessageModelResponse, - TransactionModelResponse, - VertexBuildMapModel, -) +from langflow.services.auth.utils import get_current_active_user +from langflow.services.database.models.message.model import MessageRead, MessageTable, MessageUpdate +from langflow.services.database.models.user.model import User +from langflow.services.deps import get_monitor_service, get_session +from langflow.services.monitor.schema import MessageModelResponse, TransactionModelResponse, VertexBuildMapModel from langflow.services.monitor.service import MonitorService router = APIRouter(prefix="/monitor", tags=["Monitor"]) @@ -52,45 +52,58 @@ async def get_messages( sender: Optional[str] = Query(None), sender_name: Optional[str] = Query(None), order_by: Optional[str] = Query("timestamp"), - monitor_service: MonitorService = Depends(get_monitor_service), + session: Session = Depends(get_session), ): try: - df = monitor_service.get_messages( - flow_id=flow_id, - sender=sender, - sender_name=sender_name, - session_id=session_id, - order_by=order_by, - ) - dicts = df.to_dict(orient="records") - return [MessageModelResponse(**d) for d in dicts] + stmt = select(MessageTable) + if flow_id: + stmt = stmt.where(MessageTable.flow_id == flow_id) + if session_id: + stmt = stmt.where(MessageTable.session_id == session_id) + if sender: + stmt = stmt.where(MessageTable.sender == sender) + if sender_name: + stmt = stmt.where(MessageTable.sender_name == sender_name) + if order_by: + col = getattr(MessageTable, order_by).asc() + stmt = stmt.order_by(col) + messages = session.exec(stmt) + return [MessageModelResponse.model_validate(d, from_attributes=True) for d in messages] except Exception as e: raise HTTPException(status_code=500, detail=str(e)) @router.delete("/messages", status_code=204) async def delete_messages( - message_ids: List[int], - monitor_service: MonitorService = Depends(get_monitor_service), + message_ids: List[UUID], + session: Session = Depends(get_session), + current_user: User = Depends(get_current_active_user), ): try: - monitor_service.delete_messages(message_ids=message_ids) + session.exec(select(MessageTable).where(MessageTable.id.in_(message_ids))) # type: ignore except Exception as e: raise HTTPException(status_code=500, detail=str(e)) -@router.post("/messages/{message_id}", response_model=MessageModelResponse) +@router.put("/messages/{message_id}", response_model=MessageRead) async def update_message( - message_id: int, - message: MessageModelRequest, - monitor_service: MonitorService = Depends(get_monitor_service), + message_id: UUID, + message: MessageUpdate, + session: Session = Depends(get_session), + user: User = Depends(get_current_active_user), ): try: - message_dict = message.model_dump(exclude_none=True) - message_dict.pop("index", None) - monitor_service.update_message(message_id=message_id, **message_dict) # type: ignore - return MessageModelResponse(index=message_id, **message_dict) - + db_message = session.get(MessageTable, message_id) + if not db_message: + raise HTTPException(status_code=404, detail="Message not found") + message_dict = message.model_dump(exclude_unset=True, exclude_none=True) + db_message.sqlmodel_update(message_dict) + session.add(db_message) + session.commit() + session.refresh(db_message) + return db_message + except HTTPException as e: + raise e except Exception as e: raise HTTPException(status_code=500, detail=str(e)) @@ -98,10 +111,16 @@ async def update_message( @router.delete("/messages/session/{session_id}", status_code=204) async def delete_messages_session( session_id: str, - monitor_service: MonitorService = Depends(get_monitor_service), + session: Session = Depends(get_session), ): try: - monitor_service.delete_messages_session(session_id=session_id) + session.exec( # type: ignore + delete(MessageTable) + .where(col(MessageTable.session_id) == session_id) + .execution_options(synchronize_session="fetch") + ) + session.commit() + return {"message": "Messages deleted successfully"} except Exception as e: raise HTTPException(status_code=500, detail=str(e)) @@ -137,4 +156,3 @@ async def get_transactions( return result except Exception as e: raise HTTPException(status_code=500, detail=str(e)) - raise HTTPException(status_code=500, detail=str(e)) diff --git a/src/backend/base/langflow/base/models/model.py b/src/backend/base/langflow/base/models/model.py index 81f124fd7..bea9c88d8 100644 --- a/src/backend/base/langflow/base/models/model.py +++ b/src/backend/base/langflow/base/models/model.py @@ -119,7 +119,11 @@ class LCModelComponent(Component): return status_message def get_chat_result( - self, runnable: LanguageModel, stream: bool, input_value: str | Message, system_message: Optional[str] = None + self, + runnable: LanguageModel, + stream: bool, + input_value: str | Message, + system_message: Optional[str] = None, ): messages: list[Union[BaseMessage]] = [] if not input_value and not system_message: diff --git a/src/backend/base/langflow/memory.py b/src/backend/base/langflow/memory.py index e89682969..d1793740d 100644 --- a/src/backend/base/langflow/memory.py +++ b/src/backend/base/langflow/memory.py @@ -1,11 +1,14 @@ import warnings from typing import List, Optional +from uuid import UUID from loguru import logger +from sqlalchemy import delete +from sqlmodel import Session, col, select from langflow.schema.message import Message -from langflow.services.deps import get_monitor_service -from langflow.services.monitor.schema import MessageModel +from langflow.services.database.models.message.model import MessageRead, MessageTable +from langflow.services.deps import session_scope def get_messages( @@ -14,6 +17,7 @@ def get_messages( session_id: Optional[str] = None, order_by: Optional[str] = "timestamp", order: Optional[str] = "DESC", + flow_id: Optional[UUID] = None, limit: Optional[int] = None, ): """ @@ -29,34 +33,29 @@ def get_messages( Returns: List[Data]: A list of Data objects representing the retrieved messages. """ - monitor_service = get_monitor_service() - messages_df = monitor_service.get_messages( - sender=sender, - sender_name=sender_name, - session_id=session_id, - order_by=order_by, - limit=limit, - order=order, - ) + messages_read: list[Message] = [] + with session_scope() as session: + stmt = select(MessageTable) + if sender: + stmt = stmt.where(MessageTable.sender == sender) + if sender_name: + stmt = stmt.where(MessageTable.sender_name == sender_name) + if session_id: + stmt = stmt.where(MessageTable.session_id == session_id) + if flow_id: + stmt = stmt.where(MessageTable.flow_id == flow_id) + if order_by: + if order == "DESC": + col = getattr(MessageTable, order_by).desc() + else: + col = getattr(MessageTable, order_by).asc() + stmt = stmt.order_by(col) + if limit: + stmt = stmt.limit(limit) + messages = session.exec(stmt) + messages_read = [Message(**d.model_dump()) for d in messages] - messages: list[Message] = [] - # messages_df has a timestamp - # it gets the last 5 messages, for example - # but now they are ordered from most recent to least recent - # so we need to reverse the order - messages_df = messages_df[::-1] if order == "DESC" else messages_df - for row in messages_df.itertuples(): - msg = Message( - text=row.text, - sender=row.sender, - session_id=row.session_id, - sender_name=row.sender_name, - timestamp=row.timestamp, - ) - - messages.append(msg) - - return messages + return messages_read def add_messages(messages: Message | list[Message], flow_id: Optional[str] = None): @@ -64,7 +63,6 @@ def add_messages(messages: Message | list[Message], flow_id: Optional[str] = Non Add a message to the monitor service. """ try: - monitor_service = get_monitor_service() if not isinstance(messages, list): messages = [messages] @@ -72,25 +70,29 @@ def add_messages(messages: Message | list[Message], flow_id: Optional[str] = Non types = ", ".join([str(type(message)) for message in messages]) raise ValueError(f"The messages must be instances of Message. Found: {types}") - messages_models: list[MessageModel] = [] + messages_models: list[MessageTable] = [] for msg in messages: - if not msg.timestamp: - msg.timestamp = monitor_service.get_timestamp() - messages_models.append(MessageModel.from_message(msg, flow_id=flow_id)) - - for message_model in messages_models: - try: - monitor_service.add_message(message_model) - except Exception as e: - logger.error(f"Error adding message to monitor service: {e}") - logger.exception(e) - raise e - return messages_models + messages_models.append(MessageTable.from_message(msg, flow_id=flow_id)) + with session_scope() as session: + messages_models = add_messagetables(messages_models, session) + return [Message(**message.model_dump()) for message in messages_models] except Exception as e: logger.exception(e) raise e +def add_messagetables(messages: list[MessageTable], session: Session): + for message in messages: + try: + session.add(message) + session.commit() + session.refresh(message) + except Exception as e: + logger.exception(e) + raise e + return [MessageRead.model_validate(message, from_attributes=True) for message in messages] + + def delete_messages(session_id: str): """ Delete messages from the monitor service based on the provided session ID. @@ -98,8 +100,13 @@ def delete_messages(session_id: str): Args: session_id (str): The session ID associated with the messages to delete. """ - monitor_service = get_monitor_service() - monitor_service.delete_messages_session(session_id) + with session_scope() as session: + session.exec( + delete(MessageTable) + .where(col(MessageTable.session_id) == session_id) + .execution_options(synchronize_session="fetch") + ) + session.commit() def store_message( diff --git a/src/backend/base/langflow/schema/message.py b/src/backend/base/langflow/schema/message.py index c50dab880..c03c0cb71 100644 --- a/src/backend/base/langflow/schema/message.py +++ b/src/backend/base/langflow/schema/message.py @@ -1,5 +1,6 @@ from datetime import datetime, timezone from typing import Annotated, Any, AsyncIterator, Iterator, List, Optional +from uuid import UUID from fastapi.encoders import jsonable_encoder from langchain_core.load import load @@ -31,7 +32,14 @@ class Message(Data): timestamp: Annotated[str, BeforeValidator(_timestamp_to_str)] = Field( default=datetime.now(timezone.utc).strftime("%Y-%m-%d %H:%M:%S") ) - flow_id: Optional[str] = None + flow_id: Optional[str | UUID] = None + + @field_validator("flow_id", mode="before") + @classmethod + def validate_flow_id(cls, value): + if isinstance(value, UUID): + value = str(value) + return value @field_validator("files", mode="before") @classmethod diff --git a/src/backend/base/langflow/services/database/models/__init__.py b/src/backend/base/langflow/services/database/models/__init__.py index ce12a6fce..6e1f09fe3 100644 --- a/src/backend/base/langflow/services/database/models/__init__.py +++ b/src/backend/base/langflow/services/database/models/__init__.py @@ -1,7 +1,8 @@ from .api_key import ApiKey from .flow import Flow from .folder import Folder +from .message import MessageTable from .user import User from .variable import Variable -__all__ = ["Flow", "User", "ApiKey", "Variable", "Folder"] +__all__ = ["Flow", "User", "ApiKey", "Variable", "Folder", "MessageTable"] diff --git a/src/backend/base/langflow/services/database/models/flow/model.py b/src/backend/base/langflow/services/database/models/flow/model.py index 624ea0543..6d3e4aea8 100644 --- a/src/backend/base/langflow/services/database/models/flow/model.py +++ b/src/backend/base/langflow/services/database/models/flow/model.py @@ -3,7 +3,7 @@ import re import warnings from datetime import datetime, timezone -from typing import TYPE_CHECKING, Dict, Optional +from typing import TYPE_CHECKING, Dict, List, Optional from uuid import UUID, uuid4 import emoji @@ -17,6 +17,7 @@ from langflow.schema import Data if TYPE_CHECKING: from langflow.services.database.models.folder import Folder + from langflow.services.database.models.message import MessageTable from langflow.services.database.models.user import User @@ -141,6 +142,7 @@ class Flow(FlowBase, table=True): user: "User" = Relationship(back_populates="flows") folder_id: Optional[UUID] = Field(default=None, foreign_key="folder.id", nullable=True, index=True) folder: Optional["Folder"] = Relationship(back_populates="flows") + messages: List["MessageTable"] = Relationship(back_populates="flow") def to_data(self): serialized = self.model_dump() diff --git a/src/backend/base/langflow/services/database/models/message/__init__.py b/src/backend/base/langflow/services/database/models/message/__init__.py new file mode 100644 index 000000000..8cfb2ff4f --- /dev/null +++ b/src/backend/base/langflow/services/database/models/message/__init__.py @@ -0,0 +1,3 @@ +from .model import MessageTable, MessageCreate, MessageRead, MessageUpdate + +__all__ = ["MessageTable", "MessageCreate", "MessageRead", "MessageUpdate"] diff --git a/src/backend/base/langflow/services/database/models/message/model.py b/src/backend/base/langflow/services/database/models/message/model.py new file mode 100644 index 000000000..5a775d3d9 --- /dev/null +++ b/src/backend/base/langflow/services/database/models/message/model.py @@ -0,0 +1,74 @@ +from datetime import datetime, timezone +from typing import TYPE_CHECKING, List, Optional +from uuid import UUID, uuid4 + +from pydantic import field_validator +from sqlmodel import JSON, Column, Field, Relationship, SQLModel + +if TYPE_CHECKING: + from langflow.schema.message import Message + from langflow.services.database.models.flow.model import Flow + + +class MessageBase(SQLModel): + timestamp: datetime = Field(default_factory=lambda: datetime.now(timezone.utc)) + sender: str + sender_name: str + session_id: str + text: str + files: list[str] = Field(default_factory=list) + + @field_validator("files", mode="before") + @classmethod + def validate_files(cls, value): + if not value: + value = [] + return value + + @classmethod + def from_message(cls, message: "Message", flow_id: str | None = None): + # first check if the record has all the required fields + if message.text is None or not message.sender or not message.sender_name: + raise ValueError("The message does not have the required fields (text, sender, sender_name).") + if isinstance(message.timestamp, str): + timestamp = datetime.fromisoformat(message.timestamp) + else: + timestamp = message.timestamp + return cls( + sender=message.sender, + sender_name=message.sender_name, + text=message.text, + session_id=message.session_id, + files=message.files or [], + timestamp=timestamp, + flow_id=flow_id, + ) + + +class MessageTable(MessageBase, table=True): + __tablename__ = "message" + id: UUID = Field(default_factory=uuid4, primary_key=True) + flow_id: Optional[UUID] = Field(default=None, foreign_key="flow.id") + flow: "Flow" = Relationship(back_populates="messages") + files: List[str] = Field(sa_column=Column(JSON)) + + # Needed for Column(JSON) + class Config: + arbitrary_types_allowed = True + + +class MessageRead(MessageBase): + id: UUID + flow_id: Optional[UUID] = Field() + + +class MessageCreate(MessageBase): + pass + + +class MessageUpdate(SQLModel): + text: Optional[str] = None + sender: Optional[str] = None + sender_name: Optional[str] = None + session_id: Optional[str] = None + files: Optional[list[str]] = None diff --git a/src/backend/base/langflow/services/monitor/schema.py b/src/backend/base/langflow/services/monitor/schema.py index 2294678fe..eeea846a1 100644 --- a/src/backend/base/langflow/services/monitor/schema.py +++ b/src/backend/base/langflow/services/monitor/schema.py @@ -1,6 +1,7 @@ import json from datetime import datetime, timezone from typing import Any, Optional +from uuid import UUID from pydantic import BaseModel, Field, field_serializer, field_validator @@ -81,8 +82,8 @@ class TransactionModelResponse(DefaultModel): class MessageModel(DefaultModel): - index: Optional[int] = Field(default=None) - flow_id: Optional[str] = Field(default=None, alias="flow_id") + id: Optional[str | UUID] = Field(default=None) + flow_id: Optional[UUID] = Field(default=None) timestamp: datetime = Field(default_factory=lambda: datetime.now(timezone.utc)) sender: str sender_name: str @@ -127,16 +128,7 @@ class MessageModel(DefaultModel): class MessageModelResponse(MessageModel): - index: Optional[int] = Field(default=None) - - @field_validator("index", mode="before") - def validate_id(cls, v): - if isinstance(v, float): - try: - return int(v) - except ValueError: - return None - return v + pass class MessageModelRequest(MessageModel): diff --git a/src/backend/base/langflow/services/monitor/service.py b/src/backend/base/langflow/services/monitor/service.py index 6b99b9760..d15d31329 100644 --- a/src/backend/base/langflow/services/monitor/service.py +++ b/src/backend/base/langflow/services/monitor/service.py @@ -3,14 +3,15 @@ from pathlib import Path from typing import TYPE_CHECKING, List, Optional, Union import duckdb -from langflow.services.base import Service -from langflow.services.monitor.utils import add_row_to_table, drop_and_create_table_if_schema_mismatch from loguru import logger from platformdirs import user_cache_dir +from langflow.services.base import Service +from langflow.services.monitor.utils import add_row_to_table, drop_and_create_table_if_schema_mismatch + if TYPE_CHECKING: - from langflow.services.settings.service import SettingsService from langflow.services.monitor.schema import MessageModel, TransactionModel, VertexBuildModel + from langflow.services.settings.service import SettingsService class MonitorService(Service): @@ -129,45 +130,6 @@ class MonitorService(Service): return self.exec_query(query, read_only=False) - def add_message(self, message: "MessageModel"): - self.add_row("messages", message) - - def get_messages( - self, - flow_id: Optional[str] = None, - sender: Optional[str] = None, - sender_name: Optional[str] = None, - session_id: Optional[str] = None, - order_by: Optional[str] = "timestamp", - order: Optional[str] = "DESC", - limit: Optional[int] = None, - ): - query = "SELECT index, flow_id, sender_name, sender, session_id, text, files, timestamp FROM messages" - conditions = [] - if sender: - conditions.append(f"sender = '{sender}'") - if sender_name: - conditions.append(f"sender_name = '{sender_name}'") - if session_id: - conditions.append(f"session_id = '{session_id}'") - if flow_id: - conditions.append(f"flow_id = '{flow_id}'") - - if conditions: - query += " WHERE " + " AND ".join(conditions) - - if order_by and order: - # Make sure the order is from newest to oldest - query += f" ORDER BY {order_by} {order.upper()}" - - if limit is not None: - query += f" LIMIT {limit}" - - with duckdb.connect(str(self.db_path), read_only=True) as conn: - df = conn.execute(query).df() - - return df - def get_transactions( self, source: Optional[str] = None, diff --git a/tests/test_messages_endpoints.py b/tests/test_messages_endpoints.py new file mode 100644 index 000000000..ee4021784 --- /dev/null +++ b/tests/test_messages_endpoints.py @@ -0,0 +1,76 @@ +from uuid import UUID + +import pytest +from fastapi.testclient import TestClient + +from langflow.memory import add_messagetables + +# Assuming you have these imports available +from langflow.services.database.models.message import MessageCreate, MessageRead, MessageUpdate +from langflow.services.database.models.message.model import MessageTable +from langflow.services.deps import session_scope + + +@pytest.fixture() +def created_message(): + with session_scope() as session: + message = MessageCreate(text="Test message", sender="User", sender_name="User", session_id="session_id") + messagetable = MessageTable.model_validate(message, from_attributes=True) + messagetables = add_messagetables([messagetable], session) + message_read = MessageRead.model_validate(messagetables[0], from_attributes=True) + return message_read + + +@pytest.fixture() +def created_messages(session): + with session_scope() as session: + messages = [ + MessageCreate(text="Test message 1", sender="User", sender_name="User", session_id="session_id2"), + MessageCreate(text="Test message 2", sender="User", sender_name="User", session_id="session_id2"), + MessageCreate(text="Test message 3", sender="User", sender_name="User", session_id="session_id2"), + ] + messagetables = [MessageTable.model_validate(message, from_attributes=True) for message in messages] + message_list = add_messagetables(messagetables, session) + + return message_list + + +def test_delete_messages(client: TestClient, created_messages, logged_in_headers): + response = client.request( + "DELETE", "api/v1/monitor/messages", json=[str(msg.id) for msg in created_messages], headers=logged_in_headers + ) + assert response.status_code == 204, response.text + assert response.reason_phrase == "No Content" + + +def test_update_message(client: TestClient, logged_in_headers, created_message): + message_id = created_message.id + message_update = MessageUpdate(text="Updated content") + response = client.put( + f"api/v1/monitor/messages/{message_id}", json=message_update.model_dump(), headers=logged_in_headers + ) + assert response.status_code == 200, response.text + updated_message = MessageRead(**response.json()) + assert updated_message.text == "Updated content" + + +def test_update_message_not_found(client: TestClient, logged_in_headers): + non_existent_id = UUID("00000000-0000-0000-0000-000000000000") + message_update = MessageUpdate(text="Updated content") + response = client.put( + f"api/v1/monitor/messages/{non_existent_id}", json=message_update.model_dump(), headers=logged_in_headers + ) + assert response.status_code == 404, response.text + assert response.json()["detail"] == "Message not found" + + +def test_delete_messages_session(client: TestClient, created_messages, logged_in_headers): + session_id = "session_id2" + response = client.delete(f"api/v1/monitor/messages/session/{session_id}", headers=logged_in_headers) + assert response.status_code == 204 + assert response.reason_phrase == "No Content" + + assert len(created_messages) == 3 + response = client.get("api/v1/monitor/messages", headers=logged_in_headers) + assert response.status_code == 200 + assert len(response.json()) == 0 diff --git a/tests/unit/test_messages.py b/tests/unit/test_messages.py new file mode 100644 index 000000000..198387db8 --- /dev/null +++ b/tests/unit/test_messages.py @@ -0,0 +1,72 @@ +import pytest + +from langflow.memory import add_messages, add_messagetables, delete_messages, get_messages, store_message +from langflow.schema.message import Message + +# Assuming you have these imports available +from langflow.services.database.models.message import MessageCreate, MessageRead +from langflow.services.database.models.message.model import MessageTable +from langflow.services.deps import session_scope + + +@pytest.fixture() +def created_message(): + with session_scope() as session: + message = MessageCreate(text="Test message", sender="User", sender_name="User", session_id="session_id") + messagetable = MessageTable.model_validate(message, from_attributes=True) + messagetables = add_messagetables([messagetable], session) + message_read = MessageRead.model_validate(messagetables[0], from_attributes=True) + return message_read + + +@pytest.fixture() +def created_messages(session): + with session_scope() as session: + messages = [ + MessageCreate(text="Test message 1", sender="User", sender_name="User", session_id="session_id2"), + MessageCreate(text="Test message 2", sender="User", sender_name="User", session_id="session_id2"), + MessageCreate(text="Test message 3", sender="User", sender_name="User", session_id="session_id2"), + ] + messagetables = [MessageTable.model_validate(message, from_attributes=True) for message in messages] + messagetables = add_messagetables(messagetables, session) + messages_read = [ + MessageRead.model_validate(messagetable, from_attributes=True) for messagetable in messagetables + ] + return messages_read + + +def test_get_messages(session): + add_messages(Message(text="Test message 1", sender="User", sender_name="User", session_id="session_id2")) + add_messages(Message(text="Test message 2", sender="User", sender_name="User", session_id="session_id2")) + messages = get_messages(sender="User", session_id="session_id2", limit=2) + assert len(messages) == 2 + assert messages[0].text == "Test message 1" + assert messages[1].text == "Test message 2" + + +def test_add_messages(session): + message = Message(text="New Test message", sender="User", sender_name="User", session_id="new_session_id") + messages = add_messages(message) + assert len(messages) == 1 + assert messages[0].text == "New Test message" + + +def test_add_messagetables(session): + messages = [MessageTable(text="New Test message", sender="User", sender_name="User", session_id="new_session_id")] + added_messages = add_messagetables(messages, session) + assert len(added_messages) == 1 + assert added_messages[0].text == "New Test message" + + +def test_delete_messages(session): + session_id = "session_id2" + delete_messages(session_id) + messages = session.query(MessageTable).filter(MessageTable.session_id == session_id).all() + assert len(messages) == 0 + + +def test_store_message(session): + message = Message(text="Stored message", sender="User", sender_name="User", session_id="stored_session_id") + stored_messages = store_message(message) + assert len(stored_messages) == 1 + assert stored_messages[0].text == "Stored message"