From c30b40c4b48393324c7ecf17ee9878c533b786ce Mon Sep 17 00:00:00 2001 From: Gabriel Luiz Freitas Almeida Date: Sun, 23 Jun 2024 22:08:01 -0300 Subject: [PATCH] refactor: Update messages endpoints to use database table --- src/backend/base/langflow/api/v1/monitor.py | 67 ++++++++++++--------- 1 file changed, 37 insertions(+), 30 deletions(-) diff --git a/src/backend/base/langflow/api/v1/monitor.py b/src/backend/base/langflow/api/v1/monitor.py index a99c86bf8..c4e595f63 100644 --- a/src/backend/base/langflow/api/v1/monitor.py +++ b/src/backend/base/langflow/api/v1/monitor.py @@ -1,15 +1,12 @@ from typing import List, Optional - from fastapi import APIRouter, Depends, HTTPException, Query +from sqlalchemy import delete +from sqlmodel import Session, select -from langflow.services.deps import get_monitor_service -from langflow.services.monitor.schema import ( - MessageModelRequest, - MessageModelResponse, - TransactionModelResponse, - VertexBuildMapModel, -) +from langflow.services.database.models.message.model import MessageRead, MessageTable, MessageUpdate +from langflow.services.deps import get_monitor_service, get_session +from langflow.services.monitor.schema import MessageModelResponse, TransactionModelResponse, VertexBuildMapModel from langflow.services.monitor.service import MonitorService router = APIRouter(prefix="/monitor", tags=["Monitor"]) @@ -52,18 +49,23 @@ async def get_messages( sender: Optional[str] = Query(None), sender_name: Optional[str] = Query(None), order_by: Optional[str] = Query("timestamp"), - monitor_service: MonitorService = Depends(get_monitor_service), + session: Session = Depends(get_session), ): try: - df = monitor_service.get_messages( - flow_id=flow_id, - sender=sender, - sender_name=sender_name, - session_id=session_id, - order_by=order_by, - ) - dicts = df.to_dict(orient="records") - return [MessageModelResponse(**d) for d in dicts] + stmt = select(MessageTable) + if flow_id: + stmt = stmt.where(MessageTable.flow_id == flow_id) + if session_id: + stmt = stmt.where(MessageTable.session_id == session_id) + if sender: + stmt = stmt.where(MessageTable.sender == sender) + if sender_name: + stmt = stmt.where(MessageTable.sender_name == sender_name) + if order_by: + col = getattr(MessageTable, order_by).asc() + stmt = stmt.order_by(col) + messages = session.exec(stmt) + return [MessageModelResponse(**d) for d in messages] except Exception as e: raise HTTPException(status_code=500, detail=str(e)) @@ -71,26 +73,29 @@ async def get_messages( @router.delete("/messages", status_code=204) async def delete_messages( message_ids: List[int], - monitor_service: MonitorService = Depends(get_monitor_service), + session: Session = Depends(get_session), ): try: - monitor_service.delete_messages(message_ids=message_ids) + session.exec(select(MessageTable).where(MessageTable.id.in_(message_ids))) + return {"message": "Messages deleted successfully"} except Exception as e: raise HTTPException(status_code=500, detail=str(e)) -@router.post("/messages/{message_id}", response_model=MessageModelResponse) +@router.post("/messages/{message_id}", response_model=MessageRead) async def update_message( message_id: int, - message: MessageModelRequest, - monitor_service: MonitorService = Depends(get_monitor_service), + message: MessageUpdate, + session: Session = Depends(get_session), ): try: - message_dict = message.model_dump(exclude_none=True) - message_dict.pop("index", None) - monitor_service.update_message(message_id=message_id, **message_dict) # type: ignore - return MessageModelResponse(index=message_id, **message_dict) - + db_message = session.get(MessageTable, message_id) + message_dict = message.model_dump(exclude_unset=True) + db_message.sqlmodel_update(message_dict) + session.add(db_message) + session.commit() + session.refresh(db_message) + return db_message except Exception as e: raise HTTPException(status_code=500, detail=str(e)) @@ -98,10 +103,12 @@ async def update_message( @router.delete("/messages/session/{session_id}", status_code=204) async def delete_messages_session( session_id: str, - monitor_service: MonitorService = Depends(get_monitor_service), + session: Session = Depends(get_session), ): try: - monitor_service.delete_messages_session(session_id=session_id) + session.exec(delete(MessageTable).where(MessageTable.session_id == session_id)) + session.commit() + return {"message": "Messages deleted successfully"} except Exception as e: raise HTTPException(status_code=500, detail=str(e))