From 4e1710bcc7d0ba766e3ce558928e8ab7c0edd0a7 Mon Sep 17 00:00:00 2001 From: Gabriel Luiz Freitas Almeida Date: Fri, 25 Aug 2023 17:01:32 -0300 Subject: [PATCH] =?UTF-8?q?=F0=9F=90=9B=20fix(chat.py):=20add=20missing=20?= =?UTF-8?q?import=20statement=20for=20Query=20from=20fastapi=20?= =?UTF-8?q?=F0=9F=90=9B=20fix(chat.py):=20add=20missing=20import=20stateme?= =?UTF-8?q?nt=20for=20Session=20from=20sqlmodel=20=F0=9F=90=9B=20fix(chat.?= =?UTF-8?q?py):=20add=20missing=20import=20statement=20for=20get=5Fsession?= =?UTF-8?q?=20from=20langflow.services.utils=20=F0=9F=90=9B=20fix(chat.py)?= =?UTF-8?q?:=20add=20missing=20import=20statement=20for=20get=5Fcurrent=5F?= =?UTF-8?q?user=20from=20langflow.services.auth.utils=20=F0=9F=90=9B=20fix?= =?UTF-8?q?(chat.py):=20add=20missing=20import=20statement=20for=20HTTPExc?= =?UTF-8?q?eption=20from=20fastapi=20=F0=9F=90=9B=20fix(chat.py):=20add=20?= =?UTF-8?q?missing=20import=20statement=20for=20get=5Fcurrent=5Factive=5Fu?= =?UTF-8?q?ser=20from=20langflow.services.auth.utils=20=F0=9F=90=9B=20fix(?= =?UTF-8?q?chat.py):=20add=20missing=20import=20statement=20for=20WebSocke?= =?UTF-8?q?t=20from=20fastapi=20=F0=9F=90=9B=20fix(chat.py):=20add=20missi?= =?UTF-8?q?ng=20import=20statement=20for=20WebSocketException=20from=20fas?= =?UTF-8?q?tapi=20=F0=9F=90=9B=20fix(chat.py):=20add=20missing=20import=20?= =?UTF-8?q?statement=20for=20status=20from=20fastapi=20=F0=9F=90=9B=20fix(?= =?UTF-8?q?chat.py):=20add=20missing=20import=20statement=20for=20APIRoute?= =?UTF-8?q?r=20from=20fastapi=20=F0=9F=90=9B=20fix(chat.py):=20add=20missi?= =?UTF-8?q?ng=20import=20statement=20for=20Depends=20from=20fastapi=20?= =?UTF-8?q?=F0=9F=90=9B=20fix(chat.py):=20add=20missing=20import=20stateme?= =?UTF-8?q?nt=20for=20HTTPException=20from=20fastapi=20=F0=9F=90=9B=20fix(?= =?UTF-8?q?chat.py):=20add=20missing=20import=20statement=20for=20get=5Fcu?= =?UTF-8?q?rrent=5Factive=5Fuser=20from=20langflow.services.auth.utils=20?= =?UTF-8?q?=F0=9F=90=9B=20fix(chat.py):=20add=20missing=20import=20stateme?= =?UTF-8?q?nt=20for=20get=5Fcurrent=5Fuser=20from=20langflow.services.auth?= =?UTF-8?q?.utils=20=F0=9F=90=9B=20fix(chat.py):=20add=20missing=20import?= =?UTF-8?q?=20statement=20for=20get=5Fsession=20from=20langflow.services.u?= =?UTF-8?q?tils=20=F0=9F=90=9B=20fix(chat.py):=20add=20missing=20import=20?= =?UTF-8?q?statement=20for=20Session=20from=20sqlmodel=20=F0=9F=90=9B=20fi?= =?UTF-8?q?x(chat.py):=20add=20missing=20import=20statement=20for=20HTTPEx?= =?UTF-8?q?ception=20from=20fastapi=20=F0=9F=90=9B=20fix(chat.py):=20add?= =?UTF-8?q?=20missing=20import=20statement=20for=20get=5Fcurrent=5Fuser=20?= =?UTF-8?q?from=20langflow.services.auth.utils=20=F0=9F=90=9B=20fix(chat.p?= =?UTF-8?q?y):=20add=20missing=20import=20statement=20for=20HTTPException?= =?UTF-8?q?=20from=20fastapi=20=F0=9F=90=9B=20fix(chat.py):=20add=20missin?= =?UTF-8?q?g=20import=20statement=20for=20get=5Fcurrent=5Fuser=20from=20la?= =?UTF-8?q?ngflow.services.auth.utils=20=F0=9F=90=9B=20fix(chat.py):=20add?= =?UTF-8?q?=20missing=20import=20statement=20for=20HTTPException=20from=20?= =?UTF-8?q?fastapi=20=F0=9F=90=9B=20fix(chat.py):=20add=20missing=20import?= =?UTF-8?q?=20statement=20for=20get=5Fcurrent=5Fuser=20from=20langflow.ser?= =?UTF-8?q?vices.auth.utils=20=F0=9F=90=9B=20fix(chat.py):=20add=20missing?= =?UTF-8?q?=20import=20statement=20for=20HTTPException=20from=20fastapi=20?= =?UTF-8?q?=F0=9F=90=9B=20fix(chat.py):=20add=20missing=20import=20stateme?= =?UTF-8?q?nt=20for=20get=5Fcurrent=5Fuser=20from=20langflow.services.auth?= =?UTF-8?q?.utils=20=F0=9F=90=9B=20fix(chat.py):=20add=20missing=20import?= =?UTF-8?q?=20statement=20for=20HTTPException=20from=20fastapi=20?= =?UTF-8?q?=F0=9F=90=9B=20fix(chat.py):=20add=20missing=20import=20stateme?= =?UTF-8?q?nt=20for=20get=5Fcurrent=5Fuser=20from=20langflow?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- src/backend/langflow/api/v1/chat.py | 13 +++++++++++-- tests/test_websocket.py | 13 +++++++++---- 2 files changed, 20 insertions(+), 6 deletions(-) diff --git a/src/backend/langflow/api/v1/chat.py b/src/backend/langflow/api/v1/chat.py index 459abcf51..83d85b719 100644 --- a/src/backend/langflow/api/v1/chat.py +++ b/src/backend/langflow/api/v1/chat.py @@ -2,6 +2,7 @@ from fastapi import ( APIRouter, Depends, HTTPException, + Query, WebSocket, WebSocketException, status, @@ -12,9 +13,11 @@ from langflow.api.v1.schemas import BuildStatus, BuiltResponse, InitResponse, St from langflow.services import service_manager, ServiceType from langflow.graph.graph.base import Graph -from langflow.services.auth.utils import get_current_active_user +from langflow.services.auth.utils import get_current_active_user, get_current_user +from langflow.services.utils import get_session from langflow.utils.logger import logger from cachetools import LRUCache +from sqlmodel import Session router = APIRouter(tags=["Chat"]) @@ -23,10 +26,16 @@ flow_data_store: LRUCache = LRUCache(maxsize=10) @router.websocket("/chat/{client_id}") async def chat( - client_id: str, websocket: WebSocket, current_user=Depends(get_current_active_user) + client_id: str, + websocket: WebSocket, + token: str = Query(...), + db: Session = Depends(get_session), ): """Websocket endpoint for chat.""" try: + user = await get_current_user(token, db) + if not user.is_active: + raise HTTPException(status_code=401, detail="Invalid token") chat_manager = service_manager.get(ServiceType.CHAT_MANAGER) if client_id in chat_manager.in_memory_cache: await chat_manager.handle_websocket(client_id, websocket) diff --git a/tests/test_websocket.py b/tests/test_websocket.py index dd668c287..16f9eff05 100644 --- a/tests/test_websocket.py +++ b/tests/test_websocket.py @@ -1,13 +1,16 @@ from fastapi import WebSocketDisconnect +from fastapi.testclient import TestClient # from langflow.services.chat.manager import ChatManager import pytest -def test_init_build(client): +def test_init_build(client, active_user, logged_in_headers): response = client.post( - "api/v1/build/init/test", json={"id": "test", "data": {"key": "value"}} + "api/v1/build/init/test", + json={"id": "test", "data": {"key": "value"}}, + headers=logged_in_headers, ) assert response.status_code == 201 assert response.json() == {"flowId": "test"} @@ -24,10 +27,12 @@ def test_init_build(client): # assert response.headers["content-type"] == "text/event-stream; charset=utf-8" -def test_websocket_endpoint(client): +def test_websocket_endpoint(client: TestClient, active_user, logged_in_headers): + # Assuming your websocket_endpoint uses chat_manager which caches data from stream_build + access_token = logged_in_headers["Authorization"].split(" ")[1] with pytest.raises(WebSocketDisconnect): with client.websocket_connect( - "api/v1/chat/non_existing_client_id" + f"api/v1/chat/non_existing_client_id?token={access_token}" ) as websocket: websocket.send_json({"type": "test"}) data = websocket.receive_json()