From ef949838e26e97166605f976ac079c85433e86c8 Mon Sep 17 00:00:00 2001 From: Maryam Abdoli Date: Fri, 3 Nov 2023 13:01:28 -0400 Subject: [PATCH] fix lints and also the bug in the get_current_user --- src/backend/langflow/api/v1/chat.py | 7 +++---- src/backend/langflow/services/auth/utils.py | 17 ++++++++--------- 2 files changed, 11 insertions(+), 13 deletions(-) diff --git a/src/backend/langflow/api/v1/chat.py b/src/backend/langflow/api/v1/chat.py index 0f277fbb7..0647ef71f 100644 --- a/src/backend/langflow/api/v1/chat.py +++ b/src/backend/langflow/api/v1/chat.py @@ -2,7 +2,6 @@ from fastapi import ( APIRouter, Depends, HTTPException, - Query, WebSocket, WebSocketException, status, @@ -11,8 +10,9 @@ from fastapi.responses import StreamingResponse from langflow.api.utils import build_input_keys_response from langflow.api.v1.schemas import BuildStatus, BuiltResponse, InitResponse, StreamData +from langflow.services.database.models.user.user import User from langflow.graph.graph.base import Graph -from langflow.services.auth.utils import get_current_active_user, get_current_user +from langflow.services.auth.utils import get_current_active_user from langflow.services.cache.utils import update_build_status from loguru import logger from langflow.services.getters import get_chat_service, get_session, get_cache_service @@ -28,14 +28,13 @@ router = APIRouter(tags=["Chat"]) async def chat( client_id: str, websocket: WebSocket, - token: str = Query(...), db: Session = Depends(get_session), chat_service: "ChatService" = Depends(get_chat_service), + user: User = Depends(get_current_active_user), ): """Websocket endpoint for chat.""" try: await websocket.accept() - user = await get_current_user(token, db) if not user: await websocket.close( code=status.WS_1008_POLICY_VIOLATION, reason="Unauthorized" diff --git a/src/backend/langflow/services/auth/utils.py b/src/backend/langflow/services/auth/utils.py index 9881b09f2..7cc91c117 100644 --- a/src/backend/langflow/services/auth/utils.py +++ b/src/backend/langflow/services/auth/utils.py @@ -15,7 +15,7 @@ from langflow.services.database.models.user.crud import ( from langflow.services.getters import get_session, get_settings_service from sqlmodel import Session -oauth2_login = OAuth2PasswordBearer(tokenUrl="api/v1/login") +oauth2_login = OAuth2PasswordBearer(tokenUrl="api/v1/login", auto_error=False) API_KEY_NAME = "x-api-key" @@ -74,23 +74,22 @@ async def get_current_user( header_param: str = Security(api_key_header), db: Session = Depends(get_session), ) -> User: - try: + if token: return await get_current_user_by_jwt(token, db) - except HTTPException as exc: + else: if not query_param and not header_param: - raise exc + raise HTTPException( + status_code=status.HTTP_403_FORBIDDEN, + detail="An API key must be passed as query or header", + ) user = await api_key_security(query_param, header_param, db) if user: return user + raise HTTPException( status_code=status.HTTP_403_FORBIDDEN, detail="Invalid or missing API key", ) - except Exception as exc: - raise HTTPException( - status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, - detail=f"Internal server error: {exc}", - ) async def get_current_user_by_jwt(