diff --git a/src/backend/base/langflow/api/v1/monitor.py b/src/backend/base/langflow/api/v1/monitor.py index a5daa61e5..c5b4e48f4 100644 --- a/src/backend/base/langflow/api/v1/monitor.py +++ b/src/backend/base/langflow/api/v1/monitor.py @@ -107,6 +107,39 @@ async def update_message( raise HTTPException(status_code=500, detail=str(e)) +@router.patch("/messages/session/{old_session_id}", response_model=list[MessageResponse]) +async def update_session_id( + old_session_id: str, + new_session_id: str = Query(..., description="The new session ID to update to"), + session: Session = Depends(get_session), + current_user: User = Depends(get_current_active_user), +): + try: + # Get all messages with the old session ID + stmt = select(MessageTable).where(MessageTable.session_id == old_session_id) + messages = session.exec(stmt).all() + + if not messages: + raise HTTPException(status_code=404, detail="No messages found with the given session ID") + + # Update all messages with the new session ID + for message in messages: + message.session_id = new_session_id + + session.add_all(messages) + + session.commit() + message_responses = [] + for message in messages: + session.refresh(message) + message_responses.append(MessageResponse.model_validate(message, from_attributes=True)) + return message_responses + except HTTPException as e: + raise e + except Exception as e: + raise HTTPException(status_code=500, detail=str(e)) + + @router.delete("/messages/session/{session_id}", status_code=204) async def delete_messages_session( session_id: str, diff --git a/src/backend/tests/unit/test_messages_endpoints.py b/src/backend/tests/unit/test_messages_endpoints.py index ee4021784..74d9badc1 100644 --- a/src/backend/tests/unit/test_messages_endpoints.py +++ b/src/backend/tests/unit/test_messages_endpoints.py @@ -74,3 +74,40 @@ def test_delete_messages_session(client: TestClient, created_messages, logged_in response = client.get("api/v1/monitor/messages", headers=logged_in_headers) assert response.status_code == 200 assert len(response.json()) == 0 + + +# Successfully update session ID for all messages with the old session ID +def test_successfully_update_session_id(client, session, logged_in_headers, created_messages): + old_session_id = "session_id2" + new_session_id = "new_session_id" + + response = client.patch( + f"api/v1/monitor/messages/session/{old_session_id}", + params={"new_session_id": new_session_id}, + headers=logged_in_headers, + ) + + assert response.status_code == 200, response.text + updated_messages = response.json() + assert len(updated_messages) == len(created_messages) + for message in updated_messages: + assert message["session_id"] == new_session_id + + response = client.get("api/v1/monitor/messages", headers=logged_in_headers, params={"session_id": new_session_id}) + assert response.status_code == 200 + assert len(response.json()) == len(created_messages) + for message in response.json(): + assert message["session_id"] == new_session_id + + +# No messages found with the given session ID +def test_no_messages_found_with_given_session_id(client, session, logged_in_headers): + old_session_id = "non_existent_session_id" + new_session_id = "new_session_id" + + response = client.patch( + f"/messages/session/{old_session_id}", params={"new_session_id": new_session_id}, headers=logged_in_headers + ) + + assert response.status_code == 404, response.text + assert response.json()["detail"] == "Not Found"