From 2b0765d41b35e35fb4f92f41d84da29e0156d485 Mon Sep 17 00:00:00 2001 From: Gabriel Luiz Freitas Almeida Date: Sun, 23 Jun 2024 21:51:38 -0300 Subject: [PATCH 01/24] feat: Add message table to the database This commit adds a new table called "message" to the database. The table includes columns for timestamp, sender, sender_name, session_id, text, id, flow_id, and files. The "message" table is created using Alembic migration. This addition allows for storing and retrieving messages in the application. --- .../d066bfd22890_add_message_table.py | 52 ++++++++++++++++ .../base/langflow/base/models/model.py | 6 +- .../services/database/models/__init__.py | 3 +- .../services/database/models/flow/model.py | 4 +- .../database/models/message/__init__.py | 3 + .../services/database/models/message/model.py | 62 +++++++++++++++++++ 6 files changed, 127 insertions(+), 3 deletions(-) create mode 100644 src/backend/base/langflow/alembic/versions/d066bfd22890_add_message_table.py create mode 100644 src/backend/base/langflow/services/database/models/message/__init__.py create mode 100644 src/backend/base/langflow/services/database/models/message/model.py diff --git a/src/backend/base/langflow/alembic/versions/d066bfd22890_add_message_table.py b/src/backend/base/langflow/alembic/versions/d066bfd22890_add_message_table.py new file mode 100644 index 000000000..dd63398b7 --- /dev/null +++ b/src/backend/base/langflow/alembic/versions/d066bfd22890_add_message_table.py @@ -0,0 +1,52 @@ +"""Add message table + +Revision ID: 325180f0c4e1 +Revises: 631faacf5da2 +Create Date: 2024-06-23 21:29:28.220100 + +""" + +from typing import Sequence, Union + +import sqlalchemy as sa +import sqlmodel +from alembic import op + +from langflow.utils import migration + +# revision identifiers, used by Alembic. +revision: str = "325180f0c4e1" +down_revision: Union[str, None] = "631faacf5da2" +branch_labels: Union[str, Sequence[str], None] = None +depends_on: Union[str, Sequence[str], None] = None + + +def upgrade() -> None: + conn = op.get_bind() + # ### commands auto generated by Alembic - please adjust! ### + if not migration.table_exists("message", conn): + op.create_table( + "message", + sa.Column("timestamp", sa.DateTime(), nullable=False), + sa.Column("sender", sqlmodel.sql.sqltypes.AutoString(), nullable=False), + sa.Column("sender_name", sqlmodel.sql.sqltypes.AutoString(), nullable=False), + sa.Column("session_id", sqlmodel.sql.sqltypes.AutoString(), nullable=False), + sa.Column("text", sqlmodel.sql.sqltypes.AutoString(), nullable=False), + sa.Column("id", sqlmodel.sql.sqltypes.GUID(), nullable=False), + sa.Column("flow_id", sqlmodel.sql.sqltypes.GUID(), nullable=True), + sa.Column("files", sa.JSON(), nullable=True), + sa.ForeignKeyConstraint( + ["flow_id"], + ["flow.id"], + ), + sa.PrimaryKeyConstraint("id"), + ) + # ### end Alembic commands ### + + +def downgrade() -> None: + conn = op.get_bind() + # ### commands auto generated by Alembic - please adjust! ### + if migration.table_exists("message", conn): + op.drop_table("message") + # ### end Alembic commands ### diff --git a/src/backend/base/langflow/base/models/model.py b/src/backend/base/langflow/base/models/model.py index 81f124fd7..bea9c88d8 100644 --- a/src/backend/base/langflow/base/models/model.py +++ b/src/backend/base/langflow/base/models/model.py @@ -119,7 +119,11 @@ class LCModelComponent(Component): return status_message def get_chat_result( - self, runnable: LanguageModel, stream: bool, input_value: str | Message, system_message: Optional[str] = None + self, + runnable: LanguageModel, + stream: bool, + input_value: str | Message, + system_message: Optional[str] = None, ): messages: list[Union[BaseMessage]] = [] if not input_value and not system_message: diff --git a/src/backend/base/langflow/services/database/models/__init__.py b/src/backend/base/langflow/services/database/models/__init__.py index ce12a6fce..6e1f09fe3 100644 --- a/src/backend/base/langflow/services/database/models/__init__.py +++ b/src/backend/base/langflow/services/database/models/__init__.py @@ -1,7 +1,8 @@ from .api_key import ApiKey from .flow import Flow from .folder import Folder +from .message import MessageTable from .user import User from .variable import Variable -__all__ = ["Flow", "User", "ApiKey", "Variable", "Folder"] +__all__ = ["Flow", "User", "ApiKey", "Variable", "Folder", "MessageTable"] diff --git a/src/backend/base/langflow/services/database/models/flow/model.py b/src/backend/base/langflow/services/database/models/flow/model.py index 624ea0543..6d3e4aea8 100644 --- a/src/backend/base/langflow/services/database/models/flow/model.py +++ b/src/backend/base/langflow/services/database/models/flow/model.py @@ -3,7 +3,7 @@ import re import warnings from datetime import datetime, timezone -from typing import TYPE_CHECKING, Dict, Optional +from typing import TYPE_CHECKING, Dict, List, Optional from uuid import UUID, uuid4 import emoji @@ -17,6 +17,7 @@ from langflow.schema import Data if TYPE_CHECKING: from langflow.services.database.models.folder import Folder + from langflow.services.database.models.message import MessageTable from langflow.services.database.models.user import User @@ -141,6 +142,7 @@ class Flow(FlowBase, table=True): user: "User" = Relationship(back_populates="flows") folder_id: Optional[UUID] = Field(default=None, foreign_key="folder.id", nullable=True, index=True) folder: Optional["Folder"] = Relationship(back_populates="flows") + messages: List["MessageTable"] = Relationship(back_populates="flow") def to_data(self): serialized = self.model_dump() diff --git a/src/backend/base/langflow/services/database/models/message/__init__.py b/src/backend/base/langflow/services/database/models/message/__init__.py new file mode 100644 index 000000000..8cfb2ff4f --- /dev/null +++ b/src/backend/base/langflow/services/database/models/message/__init__.py @@ -0,0 +1,3 @@ +from .model import MessageTable, MessageCreate, MessageRead, MessageUpdate + +__all__ = ["MessageTable", "MessageCreate", "MessageRead", "MessageUpdate"] diff --git a/src/backend/base/langflow/services/database/models/message/model.py b/src/backend/base/langflow/services/database/models/message/model.py new file mode 100644 index 000000000..87da2ef7d --- /dev/null +++ b/src/backend/base/langflow/services/database/models/message/model.py @@ -0,0 +1,62 @@ +from datetime import datetime, timezone +from typing import TYPE_CHECKING, List, Optional +from uuid import UUID, uuid4 + +from sqlmodel import JSON, Column, Field, Relationship, SQLModel + +if TYPE_CHECKING: + from langflow.schema.message import Message + from langflow.services.database.models.flow.model import Flow + + +class MessageBase(SQLModel): + timestamp: datetime = Field(default_factory=lambda: datetime.now(timezone.utc)) + sender: str + sender_name: str + session_id: str + text: str + files: list[str] = Field(default_factory=list) + + @classmethod + def from_message(cls, message: "Message", flow_id: str | None = None): + # first check if the record has all the required fields + if message.text is None or not message.sender or not message.sender_name: + raise ValueError("The message does not have the required fields (text, sender, sender_name).") + return cls( + sender=message.sender, + sender_name=message.sender_name, + text=message.text, + session_id=message.session_id, + files=message.files or [], + timestamp=message.timestamp, + flow_id=flow_id, + ) + + +class MessageTable(MessageBase, table=True): + __tablename__ = "message" + id: UUID = Field(default_factory=uuid4, primary_key=True) + flow_id: Optional[UUID] = Field(default=None, foreign_key="flow.id") + flow: "Flow" = Relationship(back_populates="messages") + files: List[str] = Field(sa_column=Column(JSON)) + + # Needed for Column(JSON) + class Config: + arbitrary_types_allowed = True + + +class MessageRead(MessageBase): + id: UUID + flow_id: Optional[UUID] = Field() + + +class MessageCreate(MessageBase): + pass + + +class MessageUpdate(SQLModel): + text: Optional[str] = None + sender: Optional[str] = None + sender_name: Optional[str] = None + session_id: Optional[str] = None + files: Optional[list[str]] = None From c1c478e8c4ab2d1bce0a8de3f696ef39d31fb37f Mon Sep 17 00:00:00 2001 From: Gabriel Luiz Freitas Almeida Date: Sun, 23 Jun 2024 22:07:41 -0300 Subject: [PATCH 02/24] refactor: Update add_messages function to use database session This commit refactors the add_messages function in memory.py to use a database session for adding messages to the monitor service. Instead of directly calling the monitor_service.add_message method, the messages are now added using a session object. This change ensures that the messages are properly persisted in the database and improves the reliability of the application. --- src/backend/base/langflow/memory.py | 28 ++++++++++++++-------------- 1 file changed, 14 insertions(+), 14 deletions(-) diff --git a/src/backend/base/langflow/memory.py b/src/backend/base/langflow/memory.py index e89682969..c824310d9 100644 --- a/src/backend/base/langflow/memory.py +++ b/src/backend/base/langflow/memory.py @@ -4,8 +4,8 @@ from typing import List, Optional from loguru import logger from langflow.schema.message import Message -from langflow.services.deps import get_monitor_service -from langflow.services.monitor.schema import MessageModel +from langflow.services.database.models.message.model import MessageTable +from langflow.services.deps import get_monitor_service, session_scope def get_messages( @@ -64,7 +64,6 @@ def add_messages(messages: Message | list[Message], flow_id: Optional[str] = Non Add a message to the monitor service. """ try: - monitor_service = get_monitor_service() if not isinstance(messages, list): messages = [messages] @@ -72,19 +71,20 @@ def add_messages(messages: Message | list[Message], flow_id: Optional[str] = Non types = ", ".join([str(type(message)) for message in messages]) raise ValueError(f"The messages must be instances of Message. Found: {types}") - messages_models: list[MessageModel] = [] + messages_models: list[MessageTable] = [] for msg in messages: - if not msg.timestamp: - msg.timestamp = monitor_service.get_timestamp() - messages_models.append(MessageModel.from_message(msg, flow_id=flow_id)) + messages_models.append(MessageTable.from_message(msg, flow_id=flow_id)) + with session_scope() as session: + for message_model in messages_models: + try: + session.add(message_model) + session.commit() + session.refresh(message_model) + except Exception as e: + logger.error(f"Error adding message to monitor service: {e}") + logger.exception(e) + raise e - for message_model in messages_models: - try: - monitor_service.add_message(message_model) - except Exception as e: - logger.error(f"Error adding message to monitor service: {e}") - logger.exception(e) - raise e return messages_models except Exception as e: logger.exception(e) From 3bb5f9d5e74c6ced2e99de930ba3abcb8aebb5ae Mon Sep 17 00:00:00 2001 From: Gabriel Luiz Freitas Almeida Date: Sun, 23 Jun 2024 22:08:01 -0300 Subject: [PATCH 03/24] refactor: Update messages endpoints to use database table --- src/backend/base/langflow/api/v1/monitor.py | 67 ++++++++++++--------- 1 file changed, 37 insertions(+), 30 deletions(-) diff --git a/src/backend/base/langflow/api/v1/monitor.py b/src/backend/base/langflow/api/v1/monitor.py index a99c86bf8..c4e595f63 100644 --- a/src/backend/base/langflow/api/v1/monitor.py +++ b/src/backend/base/langflow/api/v1/monitor.py @@ -1,15 +1,12 @@ from typing import List, Optional - from fastapi import APIRouter, Depends, HTTPException, Query +from sqlalchemy import delete +from sqlmodel import Session, select -from langflow.services.deps import get_monitor_service -from langflow.services.monitor.schema import ( - MessageModelRequest, - MessageModelResponse, - TransactionModelResponse, - VertexBuildMapModel, -) +from langflow.services.database.models.message.model import MessageRead, MessageTable, MessageUpdate +from langflow.services.deps import get_monitor_service, get_session +from langflow.services.monitor.schema import MessageModelResponse, TransactionModelResponse, VertexBuildMapModel from langflow.services.monitor.service import MonitorService router = APIRouter(prefix="/monitor", tags=["Monitor"]) @@ -52,18 +49,23 @@ async def get_messages( sender: Optional[str] = Query(None), sender_name: Optional[str] = Query(None), order_by: Optional[str] = Query("timestamp"), - monitor_service: MonitorService = Depends(get_monitor_service), + session: Session = Depends(get_session), ): try: - df = monitor_service.get_messages( - flow_id=flow_id, - sender=sender, - sender_name=sender_name, - session_id=session_id, - order_by=order_by, - ) - dicts = df.to_dict(orient="records") - return [MessageModelResponse(**d) for d in dicts] + stmt = select(MessageTable) + if flow_id: + stmt = stmt.where(MessageTable.flow_id == flow_id) + if session_id: + stmt = stmt.where(MessageTable.session_id == session_id) + if sender: + stmt = stmt.where(MessageTable.sender == sender) + if sender_name: + stmt = stmt.where(MessageTable.sender_name == sender_name) + if order_by: + col = getattr(MessageTable, order_by).asc() + stmt = stmt.order_by(col) + messages = session.exec(stmt) + return [MessageModelResponse(**d) for d in messages] except Exception as e: raise HTTPException(status_code=500, detail=str(e)) @@ -71,26 +73,29 @@ async def get_messages( @router.delete("/messages", status_code=204) async def delete_messages( message_ids: List[int], - monitor_service: MonitorService = Depends(get_monitor_service), + session: Session = Depends(get_session), ): try: - monitor_service.delete_messages(message_ids=message_ids) + session.exec(select(MessageTable).where(MessageTable.id.in_(message_ids))) + return {"message": "Messages deleted successfully"} except Exception as e: raise HTTPException(status_code=500, detail=str(e)) -@router.post("/messages/{message_id}", response_model=MessageModelResponse) +@router.post("/messages/{message_id}", response_model=MessageRead) async def update_message( message_id: int, - message: MessageModelRequest, - monitor_service: MonitorService = Depends(get_monitor_service), + message: MessageUpdate, + session: Session = Depends(get_session), ): try: - message_dict = message.model_dump(exclude_none=True) - message_dict.pop("index", None) - monitor_service.update_message(message_id=message_id, **message_dict) # type: ignore - return MessageModelResponse(index=message_id, **message_dict) - + db_message = session.get(MessageTable, message_id) + message_dict = message.model_dump(exclude_unset=True) + db_message.sqlmodel_update(message_dict) + session.add(db_message) + session.commit() + session.refresh(db_message) + return db_message except Exception as e: raise HTTPException(status_code=500, detail=str(e)) @@ -98,10 +103,12 @@ async def update_message( @router.delete("/messages/session/{session_id}", status_code=204) async def delete_messages_session( session_id: str, - monitor_service: MonitorService = Depends(get_monitor_service), + session: Session = Depends(get_session), ): try: - monitor_service.delete_messages_session(session_id=session_id) + session.exec(delete(MessageTable).where(MessageTable.session_id == session_id)) + session.commit() + return {"message": "Messages deleted successfully"} except Exception as e: raise HTTPException(status_code=500, detail=str(e)) From 35d80fe97854bb8e3334f7099f9270f5193b8d17 Mon Sep 17 00:00:00 2001 From: anovazzi1 Date: Mon, 24 Jun 2024 17:33:20 -0300 Subject: [PATCH 04/24] fix fetch data to work even with autologin true --- src/frontend/src/contexts/authContext.tsx | 2 ++ 1 file changed, 2 insertions(+) diff --git a/src/frontend/src/contexts/authContext.tsx b/src/frontend/src/contexts/authContext.tsx index cfdbdfb4f..ab845963a 100644 --- a/src/frontend/src/contexts/authContext.tsx +++ b/src/frontend/src/contexts/authContext.tsx @@ -12,6 +12,8 @@ import { useGlobalVariablesStore } from "../stores/globalVariablesStore/globalVa import { useStoreStore } from "../stores/storeStore"; import { Users } from "../types/api"; import { AuthContextType } from "../types/contexts/auth"; +import { useGlobalVariablesStore } from "../stores/globalVariablesStore/globalVariables"; +import { useStoreStore } from "../stores/storeStore"; const initialValue: AuthContextType = { isAdmin: false, From 6d9e2e4350f246b5b110e6d1bdaa494793e25469 Mon Sep 17 00:00:00 2001 From: anovazzi1 Date: Mon, 24 Jun 2024 17:34:03 -0300 Subject: [PATCH 05/24] format code --- src/frontend/src/contexts/authContext.tsx | 2 ++ 1 file changed, 2 insertions(+) diff --git a/src/frontend/src/contexts/authContext.tsx b/src/frontend/src/contexts/authContext.tsx index ab845963a..50a0cf4e7 100644 --- a/src/frontend/src/contexts/authContext.tsx +++ b/src/frontend/src/contexts/authContext.tsx @@ -14,6 +14,8 @@ import { Users } from "../types/api"; import { AuthContextType } from "../types/contexts/auth"; import { useGlobalVariablesStore } from "../stores/globalVariablesStore/globalVariables"; import { useStoreStore } from "../stores/storeStore"; +import { Users } from "../types/api"; +import { AuthContextType } from "../types/contexts/auth"; const initialValue: AuthContextType = { isAdmin: false, From f5fec47f76c43fc6d915569a933563b7e94cd936 Mon Sep 17 00:00:00 2001 From: Gabriel Luiz Freitas Almeida Date: Tue, 25 Jun 2024 09:13:13 -0300 Subject: [PATCH 06/24] chore: Add error handling for message not found in update_message endpoint --- src/backend/base/langflow/api/v1/monitor.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/src/backend/base/langflow/api/v1/monitor.py b/src/backend/base/langflow/api/v1/monitor.py index c4e595f63..244c9a2fa 100644 --- a/src/backend/base/langflow/api/v1/monitor.py +++ b/src/backend/base/langflow/api/v1/monitor.py @@ -1,4 +1,5 @@ from typing import List, Optional +from uuid import UUID from fastapi import APIRouter, Depends, HTTPException, Query from sqlalchemy import delete @@ -84,12 +85,14 @@ async def delete_messages( @router.post("/messages/{message_id}", response_model=MessageRead) async def update_message( - message_id: int, + message_id: UUID, message: MessageUpdate, session: Session = Depends(get_session), ): try: db_message = session.get(MessageTable, message_id) + if not db_message: + raise HTTPException(status_code=404, detail="Message not found") message_dict = message.model_dump(exclude_unset=True) db_message.sqlmodel_update(message_dict) session.add(db_message) From 8fbf026476527094fe1f88688e199e8f51bf649e Mon Sep 17 00:00:00 2001 From: Gabriel Luiz Freitas Almeida Date: Tue, 25 Jun 2024 09:27:37 -0300 Subject: [PATCH 07/24] Fix issue with message timestamp conversion in MessageBase model --- .../base/langflow/services/database/models/message/model.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/src/backend/base/langflow/services/database/models/message/model.py b/src/backend/base/langflow/services/database/models/message/model.py index 87da2ef7d..26161ac8b 100644 --- a/src/backend/base/langflow/services/database/models/message/model.py +++ b/src/backend/base/langflow/services/database/models/message/model.py @@ -22,6 +22,8 @@ class MessageBase(SQLModel): # first check if the record has all the required fields if message.text is None or not message.sender or not message.sender_name: raise ValueError("The message does not have the required fields (text, sender, sender_name).") + if isinstance(message.timestamp, str): + message.timestamp = datetime.fromisoformat(message.timestamp) return cls( sender=message.sender, sender_name=message.sender_name, From 1105e61200d6f65c7e36d2875d15bb0dbb6947f5 Mon Sep 17 00:00:00 2001 From: Gabriel Luiz Freitas Almeida Date: Tue, 25 Jun 2024 12:13:56 -0300 Subject: [PATCH 08/24] Refactor add_messages function to separate message addition and commit logic --- src/backend/base/langflow/memory.py | 24 ++++++++++++++---------- 1 file changed, 14 insertions(+), 10 deletions(-) diff --git a/src/backend/base/langflow/memory.py b/src/backend/base/langflow/memory.py index c824310d9..2835645ba 100644 --- a/src/backend/base/langflow/memory.py +++ b/src/backend/base/langflow/memory.py @@ -2,6 +2,7 @@ import warnings from typing import List, Optional from loguru import logger +from sqlmodel import Session from langflow.schema.message import Message from langflow.services.database.models.message.model import MessageTable @@ -75,22 +76,25 @@ def add_messages(messages: Message | list[Message], flow_id: Optional[str] = Non for msg in messages: messages_models.append(MessageTable.from_message(msg, flow_id=flow_id)) with session_scope() as session: - for message_model in messages_models: - try: - session.add(message_model) - session.commit() - session.refresh(message_model) - except Exception as e: - logger.error(f"Error adding message to monitor service: {e}") - logger.exception(e) - raise e - + messages_models = add_messagetables(messages_models, session) return messages_models except Exception as e: logger.exception(e) raise e +def add_messagetables(messages: list[MessageTable], session: Session): + for message in messages: + try: + session.add(message) + session.commit() + session.refresh(message) + except Exception as e: + logger.exception(e) + raise e + return messages + + def delete_messages(session_id: str): """ Delete messages from the monitor service based on the provided session ID. From 988d2cf10b0205294acee7a88d5d16d7b3e3b4b6 Mon Sep 17 00:00:00 2001 From: Gabriel Luiz Freitas Almeida Date: Tue, 25 Jun 2024 12:14:17 -0300 Subject: [PATCH 09/24] fix: Refactor monitor.py messages endpoints --- src/backend/base/langflow/api/v1/monitor.py | 26 ++++++++++++++------- 1 file changed, 17 insertions(+), 9 deletions(-) diff --git a/src/backend/base/langflow/api/v1/monitor.py b/src/backend/base/langflow/api/v1/monitor.py index 244c9a2fa..e6ea4ed18 100644 --- a/src/backend/base/langflow/api/v1/monitor.py +++ b/src/backend/base/langflow/api/v1/monitor.py @@ -3,9 +3,11 @@ from uuid import UUID from fastapi import APIRouter, Depends, HTTPException, Query from sqlalchemy import delete -from sqlmodel import Session, select +from sqlmodel import Session, col, select +from langflow.services.auth.utils import get_current_active_user from langflow.services.database.models.message.model import MessageRead, MessageTable, MessageUpdate +from langflow.services.database.models.user.model import User from langflow.services.deps import get_monitor_service, get_session from langflow.services.monitor.schema import MessageModelResponse, TransactionModelResponse, VertexBuildMapModel from langflow.services.monitor.service import MonitorService @@ -66,39 +68,42 @@ async def get_messages( col = getattr(MessageTable, order_by).asc() stmt = stmt.order_by(col) messages = session.exec(stmt) - return [MessageModelResponse(**d) for d in messages] + return [MessageModelResponse.model_validate(d, from_attributes=True) for d in messages] except Exception as e: raise HTTPException(status_code=500, detail=str(e)) @router.delete("/messages", status_code=204) async def delete_messages( - message_ids: List[int], + message_ids: List[UUID], session: Session = Depends(get_session), + current_user: User = Depends(get_current_active_user), ): try: - session.exec(select(MessageTable).where(MessageTable.id.in_(message_ids))) - return {"message": "Messages deleted successfully"} + session.exec(select(MessageTable).where(MessageTable.id.in_(message_ids))) # type: ignore except Exception as e: raise HTTPException(status_code=500, detail=str(e)) -@router.post("/messages/{message_id}", response_model=MessageRead) +@router.put("/messages/{message_id}", response_model=MessageRead) async def update_message( message_id: UUID, message: MessageUpdate, session: Session = Depends(get_session), + user: User = Depends(get_current_active_user), ): try: db_message = session.get(MessageTable, message_id) if not db_message: raise HTTPException(status_code=404, detail="Message not found") - message_dict = message.model_dump(exclude_unset=True) + message_dict = message.model_dump(exclude_unset=True, exclude_none=True) db_message.sqlmodel_update(message_dict) session.add(db_message) session.commit() session.refresh(db_message) return db_message + except HTTPException as e: + raise e except Exception as e: raise HTTPException(status_code=500, detail=str(e)) @@ -109,7 +114,11 @@ async def delete_messages_session( session: Session = Depends(get_session), ): try: - session.exec(delete(MessageTable).where(MessageTable.session_id == session_id)) + session.exec( + delete(MessageTable) + .where(col(MessageTable.session_id) == session_id) + .execution_options(synchronize_session="fetch") + ) session.commit() return {"message": "Messages deleted successfully"} except Exception as e: @@ -147,4 +156,3 @@ async def get_transactions( return result except Exception as e: raise HTTPException(status_code=500, detail=str(e)) - raise HTTPException(status_code=500, detail=str(e)) From 115f6fbb1181d2ae4c0c010a0e827da260bca001 Mon Sep 17 00:00:00 2001 From: Gabriel Luiz Freitas Almeida Date: Tue, 25 Jun 2024 12:14:35 -0300 Subject: [PATCH 10/24] test: add messages tests --- tests/test_messages_endpoints.py | 78 ++++++++++++++++++++++++++++++++ 1 file changed, 78 insertions(+) create mode 100644 tests/test_messages_endpoints.py diff --git a/tests/test_messages_endpoints.py b/tests/test_messages_endpoints.py new file mode 100644 index 000000000..62fd0c941 --- /dev/null +++ b/tests/test_messages_endpoints.py @@ -0,0 +1,78 @@ +from uuid import UUID + +import pytest +from fastapi.testclient import TestClient + +from langflow.memory import add_messagetables + +# Assuming you have these imports available +from langflow.services.database.models.message import MessageCreate, MessageRead, MessageUpdate +from langflow.services.database.models.message.model import MessageTable +from langflow.services.deps import session_scope + + +@pytest.fixture() +def created_message(): + with session_scope() as session: + message = MessageCreate(text="Test message", sender="User", sender_name="User", session_id="session_id") + messagetable = MessageTable.model_validate(message, from_attributes=True) + messagetables = add_messagetables([messagetable], session) + message_read = MessageRead.model_validate(messagetables[0], from_attributes=True) + return message_read + + +@pytest.fixture() +def created_messages(session): + with session_scope() as session: + messages = [ + MessageCreate(text="Test message 1", sender="User", sender_name="User", session_id="session_id2"), + MessageCreate(text="Test message 2", sender="User", sender_name="User", session_id="session_id2"), + MessageCreate(text="Test message 3", sender="User", sender_name="User", session_id="session_id2"), + ] + messagetables = [MessageTable.model_validate(message, from_attributes=True) for message in messages] + messagetables = add_messagetables(messagetables, session) + messages_read = [ + MessageRead.model_validate(messagetable, from_attributes=True) for messagetable in messagetables + ] + return messages_read + + +def test_delete_messages(client: TestClient, created_messages, logged_in_headers): + response = client.request( + "DELETE", "api/v1/monitor/messages", json=[str(msg.id) for msg in created_messages], headers=logged_in_headers + ) + assert response.status_code == 204, response.text + assert response.reason_phrase == "No Content" + + +def test_update_message(client: TestClient, logged_in_headers, created_message): + message_id = created_message.id + message_update = MessageUpdate(text="Updated content") + response = client.put( + f"api/v1/monitor/messages/{message_id}", json=message_update.model_dump(), headers=logged_in_headers + ) + assert response.status_code == 200, response.text + updated_message = MessageRead(**response.json()) + assert updated_message.text == "Updated content" + + +def test_update_message_not_found(client: TestClient, logged_in_headers): + non_existent_id = UUID("00000000-0000-0000-0000-000000000000") + message_update = MessageUpdate(text="Updated content") + response = client.put( + f"api/v1/monitor/messages/{non_existent_id}", json=message_update.model_dump(), headers=logged_in_headers + ) + assert response.status_code == 404, response.text + assert response.json()["detail"] == "Message not found" + + +def test_delete_messages_session(client: TestClient, created_messages, logged_in_headers): + session_id = "session_id2" + response = client.delete(f"api/v1/monitor/messages/session/{session_id}", headers=logged_in_headers) + assert response.status_code == 204 + assert response.reason_phrase == "No Content" + + assert len(created_messages) == 3 + response = client.get("api/v1/monitor/messages", headers=logged_in_headers) + assert response.status_code == 200 + assert len(response.json()) == 0 From cd73904095a40a6cffdb93b8eac6c357d74fc566 Mon Sep 17 00:00:00 2001 From: ogabrielluiz Date: Tue, 25 Jun 2024 19:18:41 +0000 Subject: [PATCH 11/24] Apply Prettier formatting --- src/frontend/src/contexts/authContext.tsx | 15 +++++++-------- 1 file changed, 7 insertions(+), 8 deletions(-) diff --git a/src/frontend/src/contexts/authContext.tsx b/src/frontend/src/contexts/authContext.tsx index 50a0cf4e7..0c5906a1b 100644 --- a/src/frontend/src/contexts/authContext.tsx +++ b/src/frontend/src/contexts/authContext.tsx @@ -8,14 +8,13 @@ import { } from "../controllers/API"; import useAlertStore from "../stores/alertStore"; import { useFolderStore } from "../stores/foldersStore"; -import { useGlobalVariablesStore } from "../stores/globalVariablesStore/globalVariables"; -import { useStoreStore } from "../stores/storeStore"; -import { Users } from "../types/api"; -import { AuthContextType } from "../types/contexts/auth"; -import { useGlobalVariablesStore } from "../stores/globalVariablesStore/globalVariables"; -import { useStoreStore } from "../stores/storeStore"; -import { Users } from "../types/api"; -import { AuthContextType } from "../types/contexts/auth"; +import { + useGlobalVariablesStore, + useGlobalVariablesStore, +} from "../stores/globalVariablesStore/globalVariables"; +import { useStoreStore, useStoreStore } from "../stores/storeStore"; +import { Users, Users } from "../types/api"; +import { AuthContextType, AuthContextType } from "../types/contexts/auth"; const initialValue: AuthContextType = { isAdmin: false, From c1df05f295e6b57c83c94497c3589acc08ff8ef4 Mon Sep 17 00:00:00 2001 From: Gabriel Luiz Freitas Almeida Date: Tue, 25 Jun 2024 16:54:05 -0300 Subject: [PATCH 12/24] refactor: Add test for adding messages --- tests/test_messages_endpoints.py | 20 +++++++++++++++++++- 1 file changed, 19 insertions(+), 1 deletion(-) diff --git a/tests/test_messages_endpoints.py b/tests/test_messages_endpoints.py index 62fd0c941..163ce610b 100644 --- a/tests/test_messages_endpoints.py +++ b/tests/test_messages_endpoints.py @@ -1,9 +1,12 @@ +import uuid from uuid import UUID import pytest from fastapi.testclient import TestClient +from sqlmodel import select -from langflow.memory import add_messagetables +from langflow.memory import add_messages, add_messagetables +from langflow.schema.message import Message # Assuming you have these imports available from langflow.services.database.models.message import MessageCreate, MessageRead, MessageUpdate @@ -37,6 +40,21 @@ def created_messages(session): return messages_read +def test_add_message(session, flow): + session_id = str(uuid.uuid4()) + message = Message(text="Test message", sender="User", sender_name="User", session_id=session_id) + messages = add_messages([message], flow.id) + + with session_scope() as session: + message = session.exec(select(MessageTable).where(MessageTable.session_id == session_id)).first() + assert message is not None + assert len(messages) == 1 + assert message.text == "Test message" + assert message.sender == "User" + assert message.sender_name == "User" + assert message.session_id == session_id + + def test_delete_messages(client: TestClient, created_messages, logged_in_headers): response = client.request( "DELETE", "api/v1/monitor/messages", json=[str(msg.id) for msg in created_messages], headers=logged_in_headers From 22609eac7087404e59a4238f767760d162ab0faf Mon Sep 17 00:00:00 2001 From: Gabriel Luiz Freitas Almeida Date: Tue, 25 Jun 2024 17:03:41 -0300 Subject: [PATCH 13/24] chore: Fix issue with message timestamp conversion in MessageBase model --- .../base/langflow/services/database/models/message/model.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/src/backend/base/langflow/services/database/models/message/model.py b/src/backend/base/langflow/services/database/models/message/model.py index 26161ac8b..ed36af3c9 100644 --- a/src/backend/base/langflow/services/database/models/message/model.py +++ b/src/backend/base/langflow/services/database/models/message/model.py @@ -23,14 +23,16 @@ class MessageBase(SQLModel): if message.text is None or not message.sender or not message.sender_name: raise ValueError("The message does not have the required fields (text, sender, sender_name).") if isinstance(message.timestamp, str): - message.timestamp = datetime.fromisoformat(message.timestamp) + timestamp = datetime.fromisoformat(message.timestamp) + else: + timestamp = message.timestamp return cls( sender=message.sender, sender_name=message.sender_name, text=message.text, session_id=message.session_id, files=message.files or [], - timestamp=message.timestamp, + timestamp=timestamp, flow_id=flow_id, ) From 7be4f88678418a40898b1269b7217a3122125dc4 Mon Sep 17 00:00:00 2001 From: Gabriel Luiz Freitas Almeida Date: Tue, 25 Jun 2024 17:05:02 -0300 Subject: [PATCH 14/24] refactor: ignore type error --- src/backend/base/langflow/api/v1/monitor.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/backend/base/langflow/api/v1/monitor.py b/src/backend/base/langflow/api/v1/monitor.py index e6ea4ed18..f6c1fc4ac 100644 --- a/src/backend/base/langflow/api/v1/monitor.py +++ b/src/backend/base/langflow/api/v1/monitor.py @@ -114,7 +114,7 @@ async def delete_messages_session( session: Session = Depends(get_session), ): try: - session.exec( + session.exec( # type: ignore delete(MessageTable) .where(col(MessageTable.session_id) == session_id) .execution_options(synchronize_session="fetch") From a85737501600b13304eea2027601b51a00bfeee9 Mon Sep 17 00:00:00 2001 From: Gabriel Luiz Freitas Almeida Date: Tue, 25 Jun 2024 17:10:04 -0300 Subject: [PATCH 15/24] refactor: Remove unused imports in authContext.tsx --- src/frontend/src/contexts/authContext.tsx | 11 ++++------- 1 file changed, 4 insertions(+), 7 deletions(-) diff --git a/src/frontend/src/contexts/authContext.tsx b/src/frontend/src/contexts/authContext.tsx index 0c5906a1b..cfdbdfb4f 100644 --- a/src/frontend/src/contexts/authContext.tsx +++ b/src/frontend/src/contexts/authContext.tsx @@ -8,13 +8,10 @@ import { } from "../controllers/API"; import useAlertStore from "../stores/alertStore"; import { useFolderStore } from "../stores/foldersStore"; -import { - useGlobalVariablesStore, - useGlobalVariablesStore, -} from "../stores/globalVariablesStore/globalVariables"; -import { useStoreStore, useStoreStore } from "../stores/storeStore"; -import { Users, Users } from "../types/api"; -import { AuthContextType, AuthContextType } from "../types/contexts/auth"; +import { useGlobalVariablesStore } from "../stores/globalVariablesStore/globalVariables"; +import { useStoreStore } from "../stores/storeStore"; +import { Users } from "../types/api"; +import { AuthContextType } from "../types/contexts/auth"; const initialValue: AuthContextType = { isAdmin: false, From a1afbf86cf131730a677d365f90b680a65ccccc7 Mon Sep 17 00:00:00 2001 From: Gabriel Luiz Freitas Almeida Date: Tue, 25 Jun 2024 18:26:13 -0300 Subject: [PATCH 16/24] Refactor curl command in python_test.yml to use the correct API endpoint for auto_login --- .github/workflows/python_test.yml | 2 +- .github/workflows/release.yml | 4 ++-- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/.github/workflows/python_test.yml b/.github/workflows/python_test.yml index 9e065c58f..8e37f40f7 100644 --- a/.github/workflows/python_test.yml +++ b/.github/workflows/python_test.yml @@ -45,7 +45,7 @@ jobs: poetry run python -m langflow run --host 127.0.0.1 --port 7860 --backend-only & SERVER_PID=$! # Wait for the server to start - timeout 120 bash -c 'until curl -f http://127.0.0.1:7860/auto_login; do sleep 2; done' || (echo "Server did not start in time" && kill $SERVER_PID && exit 1) + timeout 120 bash -c 'until curl -f http://127.0.0.1:7860/api/v1/auto_login; do sleep 5; done' || (echo "Server did not start in time" && kill $SERVER_PID && exit 1) # Terminate the server kill $SERVER_PID || (echo "Failed to terminate the server" && exit 1) sleep 10 # give the server some time to terminate diff --git a/.github/workflows/release.yml b/.github/workflows/release.yml index 9f84be063..f16dc35d3 100644 --- a/.github/workflows/release.yml +++ b/.github/workflows/release.yml @@ -62,7 +62,7 @@ jobs: python -m langflow run --host 127.0.0.1 --port 7860 & SERVER_PID=$! # Wait for the server to start - timeout 120 bash -c 'until curl -f http://127.0.0.1:7860/auto_login; do sleep 2; done' || (echo "Server did not start in time" && kill $SERVER_PID && exit 1) + timeout 120 bash -c 'until curl -f http://127.0.0.1:7860/api/v1/auto_login; do sleep 2; done' || (echo "Server did not start in time" && kill $SERVER_PID && exit 1) # Terminate the server kill $SERVER_PID || (echo "Failed to terminate the server" && exit 1) sleep 10 # give the server some time to terminate @@ -124,7 +124,7 @@ jobs: python -m langflow run --host 127.0.0.1 --port 7860 & SERVER_PID=$! # Wait for the server to start - timeout 120 bash -c 'until curl -f http://127.0.0.1:7860/auto_login; do sleep 2; done' || (echo "Server did not start in time" && kill $SERVER_PID && exit 1) + timeout 120 bash -c 'until curl -f http://127.0.0.1:7860/api/v1/auto_login; do sleep 2; done' || (echo "Server did not start in time" && kill $SERVER_PID && exit 1) # Terminate the server kill $SERVER_PID || (echo "Failed to terminate the server" && exit 1) sleep 10 # give the server some time to terminate From f6d7bcfd88b99e33c39cf8b1658a6f251e0c67d5 Mon Sep 17 00:00:00 2001 From: Gabriel Luiz Freitas Almeida Date: Tue, 25 Jun 2024 18:30:08 -0300 Subject: [PATCH 17/24] chore: Update lint-js.yml to trigger checks on requested actions --- .github/workflows/lint-js.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/lint-js.yml b/.github/workflows/lint-js.yml index aafd15f4a..b528ee8ed 100644 --- a/.github/workflows/lint-js.yml +++ b/.github/workflows/lint-js.yml @@ -5,7 +5,7 @@ on: paths: - "src/frontend/**" merge_group: - branches: [dev] + types: [checks_requested] env: NODE_VERSION: "21" From 3204af12638c979c149bce13b39ad83683aba21a Mon Sep 17 00:00:00 2001 From: Gabriel Luiz Freitas Almeida Date: Tue, 25 Jun 2024 19:37:14 -0300 Subject: [PATCH 18/24] =?UTF-8?q?=F0=9F=93=9D=20(memory.py):=20Refactor=20?= =?UTF-8?q?get=5Fmessages=20function=20to=20use=20SQLAlchemy=20select=20st?= =?UTF-8?q?atement=20for=20better=20performance=20and=20readability?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit 📝 (memory.py): Refactor delete_messages function to use SQLAlchemy delete statement for better performance and readability 📝 (monitor/schema.py): Update MessageModel to use UUID type for id and flow_id for consistency and better data handling --- src/backend/base/langflow/memory.py | 65 ++++++++++++------- .../base/langflow/services/monitor/schema.py | 16 ++--- 2 files changed, 44 insertions(+), 37 deletions(-) diff --git a/src/backend/base/langflow/memory.py b/src/backend/base/langflow/memory.py index 2835645ba..7b314fb2c 100644 --- a/src/backend/base/langflow/memory.py +++ b/src/backend/base/langflow/memory.py @@ -1,12 +1,14 @@ import warnings from typing import List, Optional +from uuid import UUID from loguru import logger -from sqlmodel import Session +from sqlalchemy import delete +from sqlmodel import Session, col, select from langflow.schema.message import Message -from langflow.services.database.models.message.model import MessageTable -from langflow.services.deps import get_monitor_service, session_scope +from langflow.services.database.models.message.model import MessageRead, MessageTable +from langflow.services.deps import session_scope def get_messages( @@ -15,6 +17,7 @@ def get_messages( session_id: Optional[str] = None, order_by: Optional[str] = "timestamp", order: Optional[str] = "DESC", + flow_id: Optional[UUID] = None, limit: Optional[int] = None, ): """ @@ -30,29 +33,36 @@ def get_messages( Returns: List[Data]: A list of Data objects representing the retrieved messages. """ - monitor_service = get_monitor_service() - messages_df = monitor_service.get_messages( - sender=sender, - sender_name=sender_name, - session_id=session_id, - order_by=order_by, - limit=limit, - order=order, - ) + with session_scope() as session: + stmt = select(MessageTable) + if sender: + stmt = stmt.where(MessageTable.sender == sender) + if sender_name: + stmt = stmt.where(MessageTable.sender_name == sender_name) + if session_id: + stmt = stmt.where(MessageTable.session_id == session_id) + if flow_id: + stmt = stmt.where(MessageTable.flow_id == flow_id) + if order_by: + if order == "DESC": + col = getattr(MessageTable, order_by).desc() + else: + col = getattr(MessageTable, order_by).asc() + stmt = stmt.order_by(col) + if limit: + stmt = stmt.limit(limit) + messages = session.exec(stmt) + messages_read = [MessageRead.model_validate(d, from_attributes=True) for d in messages] messages: list[Message] = [] - # messages_df has a timestamp - # it gets the last 5 messages, for example - # but now they are ordered from most recent to least recent - # so we need to reverse the order - messages_df = messages_df[::-1] if order == "DESC" else messages_df - for row in messages_df.itertuples(): + + for msg_read in messages_read: msg = Message( - text=row.text, - sender=row.sender, - session_id=row.session_id, - sender_name=row.sender_name, - timestamp=row.timestamp, + text=msg_read.text, + sender=msg_read.sender, + session_id=msg_read.session_id, + sender_name=msg_read.sender_name, + timestamp=msg_read.timestamp, ) messages.append(msg) @@ -102,8 +112,13 @@ def delete_messages(session_id: str): Args: session_id (str): The session ID associated with the messages to delete. """ - monitor_service = get_monitor_service() - monitor_service.delete_messages_session(session_id) + with session_scope() as session: + session.exec( + delete(MessageTable) + .where(col(MessageTable.session_id) == session_id) + .execution_options(synchronize_session="fetch") + ) + session.commit() def store_message( diff --git a/src/backend/base/langflow/services/monitor/schema.py b/src/backend/base/langflow/services/monitor/schema.py index 2294678fe..eeea846a1 100644 --- a/src/backend/base/langflow/services/monitor/schema.py +++ b/src/backend/base/langflow/services/monitor/schema.py @@ -1,6 +1,7 @@ import json from datetime import datetime, timezone from typing import Any, Optional +from uuid import UUID from pydantic import BaseModel, Field, field_serializer, field_validator @@ -81,8 +82,8 @@ class TransactionModelResponse(DefaultModel): class MessageModel(DefaultModel): - index: Optional[int] = Field(default=None) - flow_id: Optional[str] = Field(default=None, alias="flow_id") + id: Optional[str | UUID] = Field(default=None) + flow_id: Optional[UUID] = Field(default=None) timestamp: datetime = Field(default_factory=lambda: datetime.now(timezone.utc)) sender: str sender_name: str @@ -127,16 +128,7 @@ class MessageModel(DefaultModel): class MessageModelResponse(MessageModel): - index: Optional[int] = Field(default=None) - - @field_validator("index", mode="before") - def validate_id(cls, v): - if isinstance(v, float): - try: - return int(v) - except ValueError: - return None - return v + pass class MessageModelRequest(MessageModel): From 76fffe5990a80fbc12ad3357ba4874b714b8876d Mon Sep 17 00:00:00 2001 From: Gabriel Luiz Freitas Almeida Date: Tue, 25 Jun 2024 19:37:36 -0300 Subject: [PATCH 19/24] =?UTF-8?q?=E2=99=BB=EF=B8=8F=20(service.py):=20remo?= =?UTF-8?q?ve=20unused=20code=20related=20to=20adding=20and=20retrieving?= =?UTF-8?q?=20messages=20in=20MonitorService?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- .../base/langflow/services/monitor/service.py | 46 ++----------------- 1 file changed, 4 insertions(+), 42 deletions(-) diff --git a/src/backend/base/langflow/services/monitor/service.py b/src/backend/base/langflow/services/monitor/service.py index 6b99b9760..d15d31329 100644 --- a/src/backend/base/langflow/services/monitor/service.py +++ b/src/backend/base/langflow/services/monitor/service.py @@ -3,14 +3,15 @@ from pathlib import Path from typing import TYPE_CHECKING, List, Optional, Union import duckdb -from langflow.services.base import Service -from langflow.services.monitor.utils import add_row_to_table, drop_and_create_table_if_schema_mismatch from loguru import logger from platformdirs import user_cache_dir +from langflow.services.base import Service +from langflow.services.monitor.utils import add_row_to_table, drop_and_create_table_if_schema_mismatch + if TYPE_CHECKING: - from langflow.services.settings.service import SettingsService from langflow.services.monitor.schema import MessageModel, TransactionModel, VertexBuildModel + from langflow.services.settings.service import SettingsService class MonitorService(Service): @@ -129,45 +130,6 @@ class MonitorService(Service): return self.exec_query(query, read_only=False) - def add_message(self, message: "MessageModel"): - self.add_row("messages", message) - - def get_messages( - self, - flow_id: Optional[str] = None, - sender: Optional[str] = None, - sender_name: Optional[str] = None, - session_id: Optional[str] = None, - order_by: Optional[str] = "timestamp", - order: Optional[str] = "DESC", - limit: Optional[int] = None, - ): - query = "SELECT index, flow_id, sender_name, sender, session_id, text, files, timestamp FROM messages" - conditions = [] - if sender: - conditions.append(f"sender = '{sender}'") - if sender_name: - conditions.append(f"sender_name = '{sender_name}'") - if session_id: - conditions.append(f"session_id = '{session_id}'") - if flow_id: - conditions.append(f"flow_id = '{flow_id}'") - - if conditions: - query += " WHERE " + " AND ".join(conditions) - - if order_by and order: - # Make sure the order is from newest to oldest - query += f" ORDER BY {order_by} {order.upper()}" - - if limit is not None: - query += f" LIMIT {limit}" - - with duckdb.connect(str(self.db_path), read_only=True) as conn: - df = conn.execute(query).df() - - return df - def get_transactions( self, source: Optional[str] = None, From 69f7a9a159ca3ac1d66d427f5bab6891b9f1b9de Mon Sep 17 00:00:00 2001 From: Gabriel Luiz Freitas Almeida Date: Tue, 25 Jun 2024 19:50:44 -0300 Subject: [PATCH 20/24] refactor: Remove unused imports and test adding messages --- tests/test_messages_endpoints.py | 20 +------------------- 1 file changed, 1 insertion(+), 19 deletions(-) diff --git a/tests/test_messages_endpoints.py b/tests/test_messages_endpoints.py index 163ce610b..62fd0c941 100644 --- a/tests/test_messages_endpoints.py +++ b/tests/test_messages_endpoints.py @@ -1,12 +1,9 @@ -import uuid from uuid import UUID import pytest from fastapi.testclient import TestClient -from sqlmodel import select -from langflow.memory import add_messages, add_messagetables -from langflow.schema.message import Message +from langflow.memory import add_messagetables # Assuming you have these imports available from langflow.services.database.models.message import MessageCreate, MessageRead, MessageUpdate @@ -40,21 +37,6 @@ def created_messages(session): return messages_read -def test_add_message(session, flow): - session_id = str(uuid.uuid4()) - message = Message(text="Test message", sender="User", sender_name="User", session_id=session_id) - messages = add_messages([message], flow.id) - - with session_scope() as session: - message = session.exec(select(MessageTable).where(MessageTable.session_id == session_id)).first() - assert message is not None - assert len(messages) == 1 - assert message.text == "Test message" - assert message.sender == "User" - assert message.sender_name == "User" - assert message.session_id == session_id - - def test_delete_messages(client: TestClient, created_messages, logged_in_headers): response = client.request( "DELETE", "api/v1/monitor/messages", json=[str(msg.id) for msg in created_messages], headers=logged_in_headers From 0bfb702736dfc9be271e7abcdbc6c2c0b6546e73 Mon Sep 17 00:00:00 2001 From: Gabriel Luiz Freitas Almeida Date: Tue, 25 Jun 2024 19:50:56 -0300 Subject: [PATCH 21/24] refactor: Remove unused imports and optimize get_messages function --- src/backend/base/langflow/memory.py | 22 +++++----------------- 1 file changed, 5 insertions(+), 17 deletions(-) diff --git a/src/backend/base/langflow/memory.py b/src/backend/base/langflow/memory.py index 7b314fb2c..b0115d9b2 100644 --- a/src/backend/base/langflow/memory.py +++ b/src/backend/base/langflow/memory.py @@ -7,7 +7,7 @@ from sqlalchemy import delete from sqlmodel import Session, col, select from langflow.schema.message import Message -from langflow.services.database.models.message.model import MessageRead, MessageTable +from langflow.services.database.models.message.model import MessageTable from langflow.services.deps import session_scope @@ -33,6 +33,7 @@ def get_messages( Returns: List[Data]: A list of Data objects representing the retrieved messages. """ + messages_read: list[Message] = [] with session_scope() as session: stmt = select(MessageTable) if sender: @@ -52,22 +53,9 @@ def get_messages( if limit: stmt = stmt.limit(limit) messages = session.exec(stmt) - messages_read = [MessageRead.model_validate(d, from_attributes=True) for d in messages] + messages_read = [Message(**d.model_dump()) for d in messages] - messages: list[Message] = [] - - for msg_read in messages_read: - msg = Message( - text=msg_read.text, - sender=msg_read.sender, - session_id=msg_read.session_id, - sender_name=msg_read.sender_name, - timestamp=msg_read.timestamp, - ) - - messages.append(msg) - - return messages + return messages_read def add_messages(messages: Message | list[Message], flow_id: Optional[str] = None): @@ -102,7 +90,7 @@ def add_messagetables(messages: list[MessageTable], session: Session): except Exception as e: logger.exception(e) raise e - return messages + return [Message(**message.model_dump()) for message in messages] def delete_messages(session_id: str): From 2065dba21549c35bd1832fc07acb5b1f5311184f Mon Sep 17 00:00:00 2001 From: Gabriel Luiz Freitas Almeida Date: Tue, 25 Jun 2024 19:51:12 -0300 Subject: [PATCH 22/24] =?UTF-8?q?=E2=9C=A8=20(test=5Fmessages.py):=20Add?= =?UTF-8?q?=20unit=20tests=20for=20message=20handling=20functions=20in=20l?= =?UTF-8?q?angflow=20module.?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- tests/unit/test_messages.py | 72 +++++++++++++++++++++++++++++++++++++ 1 file changed, 72 insertions(+) create mode 100644 tests/unit/test_messages.py diff --git a/tests/unit/test_messages.py b/tests/unit/test_messages.py new file mode 100644 index 000000000..198387db8 --- /dev/null +++ b/tests/unit/test_messages.py @@ -0,0 +1,72 @@ +import pytest + +from langflow.memory import add_messages, add_messagetables, delete_messages, get_messages, store_message +from langflow.schema.message import Message + +# Assuming you have these imports available +from langflow.services.database.models.message import MessageCreate, MessageRead +from langflow.services.database.models.message.model import MessageTable +from langflow.services.deps import session_scope + + +@pytest.fixture() +def created_message(): + with session_scope() as session: + message = MessageCreate(text="Test message", sender="User", sender_name="User", session_id="session_id") + messagetable = MessageTable.model_validate(message, from_attributes=True) + messagetables = add_messagetables([messagetable], session) + message_read = MessageRead.model_validate(messagetables[0], from_attributes=True) + return message_read + + +@pytest.fixture() +def created_messages(session): + with session_scope() as session: + messages = [ + MessageCreate(text="Test message 1", sender="User", sender_name="User", session_id="session_id2"), + MessageCreate(text="Test message 2", sender="User", sender_name="User", session_id="session_id2"), + MessageCreate(text="Test message 3", sender="User", sender_name="User", session_id="session_id2"), + ] + messagetables = [MessageTable.model_validate(message, from_attributes=True) for message in messages] + messagetables = add_messagetables(messagetables, session) + messages_read = [ + MessageRead.model_validate(messagetable, from_attributes=True) for messagetable in messagetables + ] + return messages_read + + +def test_get_messages(session): + add_messages(Message(text="Test message 1", sender="User", sender_name="User", session_id="session_id2")) + add_messages(Message(text="Test message 2", sender="User", sender_name="User", session_id="session_id2")) + messages = get_messages(sender="User", session_id="session_id2", limit=2) + assert len(messages) == 2 + assert messages[0].text == "Test message 1" + assert messages[1].text == "Test message 2" + + +def test_add_messages(session): + message = Message(text="New Test message", sender="User", sender_name="User", session_id="new_session_id") + messages = add_messages(message) + assert len(messages) == 1 + assert messages[0].text == "New Test message" + + +def test_add_messagetables(session): + messages = [MessageTable(text="New Test message", sender="User", sender_name="User", session_id="new_session_id")] + added_messages = add_messagetables(messages, session) + assert len(added_messages) == 1 + assert added_messages[0].text == "New Test message" + + +def test_delete_messages(session): + session_id = "session_id2" + delete_messages(session_id) + messages = session.query(MessageTable).filter(MessageTable.session_id == session_id).all() + assert len(messages) == 0 + + +def test_store_message(session): + message = Message(text="Stored message", sender="User", sender_name="User", session_id="stored_session_id") + stored_messages = store_message(message) + assert len(stored_messages) == 1 + assert stored_messages[0].text == "Stored message" From 12f35f298b5bda2c12b393926f22ee323bb16226 Mon Sep 17 00:00:00 2001 From: Gabriel Luiz Freitas Almeida Date: Tue, 25 Jun 2024 20:20:12 -0300 Subject: [PATCH 23/24] refactor: Update add_messages and add_messagetables functions to return Message objects --- src/backend/base/langflow/memory.py | 6 +++--- src/backend/base/langflow/schema/message.py | 10 +++++++++- tests/test_messages_endpoints.py | 8 +++----- 3 files changed, 15 insertions(+), 9 deletions(-) diff --git a/src/backend/base/langflow/memory.py b/src/backend/base/langflow/memory.py index b0115d9b2..d1793740d 100644 --- a/src/backend/base/langflow/memory.py +++ b/src/backend/base/langflow/memory.py @@ -7,7 +7,7 @@ from sqlalchemy import delete from sqlmodel import Session, col, select from langflow.schema.message import Message -from langflow.services.database.models.message.model import MessageTable +from langflow.services.database.models.message.model import MessageRead, MessageTable from langflow.services.deps import session_scope @@ -75,7 +75,7 @@ def add_messages(messages: Message | list[Message], flow_id: Optional[str] = Non messages_models.append(MessageTable.from_message(msg, flow_id=flow_id)) with session_scope() as session: messages_models = add_messagetables(messages_models, session) - return messages_models + return [Message(**message.model_dump()) for message in messages_models] except Exception as e: logger.exception(e) raise e @@ -90,7 +90,7 @@ def add_messagetables(messages: list[MessageTable], session: Session): except Exception as e: logger.exception(e) raise e - return [Message(**message.model_dump()) for message in messages] + return [MessageRead.model_validate(message, from_attributes=True) for message in messages] def delete_messages(session_id: str): diff --git a/src/backend/base/langflow/schema/message.py b/src/backend/base/langflow/schema/message.py index c50dab880..c03c0cb71 100644 --- a/src/backend/base/langflow/schema/message.py +++ b/src/backend/base/langflow/schema/message.py @@ -1,5 +1,6 @@ from datetime import datetime, timezone from typing import Annotated, Any, AsyncIterator, Iterator, List, Optional +from uuid import UUID from fastapi.encoders import jsonable_encoder from langchain_core.load import load @@ -31,7 +32,14 @@ class Message(Data): timestamp: Annotated[str, BeforeValidator(_timestamp_to_str)] = Field( default=datetime.now(timezone.utc).strftime("%Y-%m-%d %H:%M:%S") ) - flow_id: Optional[str] = None + flow_id: Optional[str | UUID] = None + + @field_validator("flow_id", mode="before") + @classmethod + def validate_flow_id(cls, value): + if isinstance(value, UUID): + value = str(value) + return value @field_validator("files", mode="before") @classmethod diff --git a/tests/test_messages_endpoints.py b/tests/test_messages_endpoints.py index 62fd0c941..ee4021784 100644 --- a/tests/test_messages_endpoints.py +++ b/tests/test_messages_endpoints.py @@ -30,11 +30,9 @@ def created_messages(session): MessageCreate(text="Test message 3", sender="User", sender_name="User", session_id="session_id2"), ] messagetables = [MessageTable.model_validate(message, from_attributes=True) for message in messages] - messagetables = add_messagetables(messagetables, session) - messages_read = [ - MessageRead.model_validate(messagetable, from_attributes=True) for messagetable in messagetables - ] - return messages_read + message_list = add_messagetables(messagetables, session) + + return message_list def test_delete_messages(client: TestClient, created_messages, logged_in_headers): From 5bdd035dc3dda910a81284f80d4267ac31deb350 Mon Sep 17 00:00:00 2001 From: Gabriel Luiz Freitas Almeida Date: Tue, 25 Jun 2024 21:48:50 -0300 Subject: [PATCH 24/24] refactor: Add field_validator for files in MessageBase model --- .../langflow/services/database/models/message/model.py | 8 ++++++++ 1 file changed, 8 insertions(+) diff --git a/src/backend/base/langflow/services/database/models/message/model.py b/src/backend/base/langflow/services/database/models/message/model.py index ed36af3c9..5a775d3d9 100644 --- a/src/backend/base/langflow/services/database/models/message/model.py +++ b/src/backend/base/langflow/services/database/models/message/model.py @@ -2,6 +2,7 @@ from datetime import datetime, timezone from typing import TYPE_CHECKING, List, Optional from uuid import UUID, uuid4 +from pydantic import field_validator from sqlmodel import JSON, Column, Field, Relationship, SQLModel if TYPE_CHECKING: @@ -17,6 +18,13 @@ class MessageBase(SQLModel): text: str files: list[str] = Field(default_factory=list) + @field_validator("files", mode="before") + @classmethod + def validate_files(cls, value): + if not value: + value = [] + return value + @classmethod def from_message(cls, message: "Message", flow_id: str | None = None): # first check if the record has all the required fields