diff --git a/src/backend/langflow/api/v1/chat.py b/src/backend/langflow/api/v1/chat.py index 459abcf51..83d85b719 100644 --- a/src/backend/langflow/api/v1/chat.py +++ b/src/backend/langflow/api/v1/chat.py @@ -2,6 +2,7 @@ from fastapi import ( APIRouter, Depends, HTTPException, + Query, WebSocket, WebSocketException, status, @@ -12,9 +13,11 @@ from langflow.api.v1.schemas import BuildStatus, BuiltResponse, InitResponse, St from langflow.services import service_manager, ServiceType from langflow.graph.graph.base import Graph -from langflow.services.auth.utils import get_current_active_user +from langflow.services.auth.utils import get_current_active_user, get_current_user +from langflow.services.utils import get_session from langflow.utils.logger import logger from cachetools import LRUCache +from sqlmodel import Session router = APIRouter(tags=["Chat"]) @@ -23,10 +26,16 @@ flow_data_store: LRUCache = LRUCache(maxsize=10) @router.websocket("/chat/{client_id}") async def chat( - client_id: str, websocket: WebSocket, current_user=Depends(get_current_active_user) + client_id: str, + websocket: WebSocket, + token: str = Query(...), + db: Session = Depends(get_session), ): """Websocket endpoint for chat.""" try: + user = await get_current_user(token, db) + if not user.is_active: + raise HTTPException(status_code=401, detail="Invalid token") chat_manager = service_manager.get(ServiceType.CHAT_MANAGER) if client_id in chat_manager.in_memory_cache: await chat_manager.handle_websocket(client_id, websocket) diff --git a/tests/test_websocket.py b/tests/test_websocket.py index dd668c287..16f9eff05 100644 --- a/tests/test_websocket.py +++ b/tests/test_websocket.py @@ -1,13 +1,16 @@ from fastapi import WebSocketDisconnect +from fastapi.testclient import TestClient # from langflow.services.chat.manager import ChatManager import pytest -def test_init_build(client): +def test_init_build(client, active_user, logged_in_headers): response = client.post( - "api/v1/build/init/test", json={"id": "test", "data": {"key": "value"}} + "api/v1/build/init/test", + json={"id": "test", "data": {"key": "value"}}, + headers=logged_in_headers, ) assert response.status_code == 201 assert response.json() == {"flowId": "test"} @@ -24,10 +27,12 @@ def test_init_build(client): # assert response.headers["content-type"] == "text/event-stream; charset=utf-8" -def test_websocket_endpoint(client): +def test_websocket_endpoint(client: TestClient, active_user, logged_in_headers): + # Assuming your websocket_endpoint uses chat_manager which caches data from stream_build + access_token = logged_in_headers["Authorization"].split(" ")[1] with pytest.raises(WebSocketDisconnect): with client.websocket_connect( - "api/v1/chat/non_existing_client_id" + f"api/v1/chat/non_existing_client_id?token={access_token}" ) as websocket: websocket.send_json({"type": "test"}) data = websocket.receive_json()