fix: Refactor monitor.py messages endpoints

This commit is contained in:
Gabriel Luiz Freitas Almeida 2024-06-25 12:14:17 -03:00
commit 609c0d34c0

View file

@ -3,9 +3,11 @@ from uuid import UUID
from fastapi import APIRouter, Depends, HTTPException, Query
from sqlalchemy import delete
from sqlmodel import Session, select
from sqlmodel import Session, col, select
from langflow.services.auth.utils import get_current_active_user
from langflow.services.database.models.message.model import MessageRead, MessageTable, MessageUpdate
from langflow.services.database.models.user.model import User
from langflow.services.deps import get_monitor_service, get_session
from langflow.services.monitor.schema import MessageModelResponse, TransactionModelResponse, VertexBuildMapModel
from langflow.services.monitor.service import MonitorService
@ -66,39 +68,42 @@ async def get_messages(
col = getattr(MessageTable, order_by).asc()
stmt = stmt.order_by(col)
messages = session.exec(stmt)
return [MessageModelResponse(**d) for d in messages]
return [MessageModelResponse.model_validate(d, from_attributes=True) for d in messages]
except Exception as e:
raise HTTPException(status_code=500, detail=str(e))
@router.delete("/messages", status_code=204)
async def delete_messages(
message_ids: List[int],
message_ids: List[UUID],
session: Session = Depends(get_session),
current_user: User = Depends(get_current_active_user),
):
try:
session.exec(select(MessageTable).where(MessageTable.id.in_(message_ids)))
return {"message": "Messages deleted successfully"}
session.exec(select(MessageTable).where(MessageTable.id.in_(message_ids))) # type: ignore
except Exception as e:
raise HTTPException(status_code=500, detail=str(e))
@router.post("/messages/{message_id}", response_model=MessageRead)
@router.put("/messages/{message_id}", response_model=MessageRead)
async def update_message(
message_id: UUID,
message: MessageUpdate,
session: Session = Depends(get_session),
user: User = Depends(get_current_active_user),
):
try:
db_message = session.get(MessageTable, message_id)
if not db_message:
raise HTTPException(status_code=404, detail="Message not found")
message_dict = message.model_dump(exclude_unset=True)
message_dict = message.model_dump(exclude_unset=True, exclude_none=True)
db_message.sqlmodel_update(message_dict)
session.add(db_message)
session.commit()
session.refresh(db_message)
return db_message
except HTTPException as e:
raise e
except Exception as e:
raise HTTPException(status_code=500, detail=str(e))
@ -109,7 +114,11 @@ async def delete_messages_session(
session: Session = Depends(get_session),
):
try:
session.exec(delete(MessageTable).where(MessageTable.session_id == session_id))
session.exec(
delete(MessageTable)
.where(col(MessageTable.session_id) == session_id)
.execution_options(synchronize_session="fetch")
)
session.commit()
return {"message": "Messages deleted successfully"}
except Exception as e:
@ -147,4 +156,3 @@ async def get_transactions(
return result
except Exception as e:
raise HTTPException(status_code=500, detail=str(e))
raise HTTPException(status_code=500, detail=str(e))