diff --git a/src/backend/langflow/api/v1/chat.py b/src/backend/langflow/api/v1/chat.py index 9fa1ec3a3..4c034016d 100644 --- a/src/backend/langflow/api/v1/chat.py +++ b/src/backend/langflow/api/v1/chat.py @@ -1,11 +1,11 @@ import time -from fastapi import APIRouter, Depends, HTTPException, Query, WebSocket, WebSocketException, status +from fastapi import APIRouter, Depends, HTTPException, WebSocket, WebSocketException, status from fastapi.responses import StreamingResponse from langflow.api.utils import build_input_keys_response, format_elapsed_time from langflow.api.v1.schemas import BuildStatus, BuiltResponse, InitResponse, StreamData from langflow.graph.graph.base import Graph -from langflow.services.auth.utils import get_current_active_user, get_current_user_by_jwt +from langflow.services.auth.utils import get_current_active_user, get_current_user_for_websocket from langflow.services.cache.service import BaseCacheService from langflow.services.cache.utils import update_build_status from langflow.services.chat.service import ChatService @@ -20,17 +20,16 @@ router = APIRouter(tags=["Chat"]) async def chat( client_id: str, websocket: WebSocket, - token: str = Query(...), db: Session = Depends(get_session), chat_service: "ChatService" = Depends(get_chat_service), ): """Websocket endpoint for chat.""" try: - user = await get_current_user_by_jwt(token, db) + user = await get_current_user_for_websocket(websocket, db) await websocket.accept() if not user: await websocket.close(code=status.WS_1008_POLICY_VIOLATION, reason="Unauthorized") - if not user.is_active: + elif not user.is_active: await websocket.close(code=status.WS_1008_POLICY_VIOLATION, reason="Unauthorized") if client_id in chat_service.cache_service: diff --git a/src/backend/langflow/services/auth/utils.py b/src/backend/langflow/services/auth/utils.py index 912d1fbe8..e3090a09f 100644 --- a/src/backend/langflow/services/auth/utils.py +++ b/src/backend/langflow/services/auth/utils.py @@ -7,6 +7,7 @@ from fastapi import Depends, HTTPException, Security, status from fastapi.security import APIKeyHeader, APIKeyQuery, OAuth2PasswordBearer from jose import JWTError, jwt from sqlmodel import Session +from starlette.websockets import WebSocket from langflow.services.database.models.api_key.model import ApiKey from langflow.services.database.models.api_key.crud import check_key @@ -130,6 +131,21 @@ async def get_current_user_by_jwt( return user +async def get_current_user_for_websocket( + websocket: WebSocket, + db: Session = Depends(get_session), + query_param: str = Security(api_key_query), +) -> Optional[User]: + token = websocket.query_params.get("token") + api_key = websocket.query_params.get("x-api-key") + if token: + return await get_current_user_by_jwt(token, db) + elif api_key: + return await api_key_security(api_key, query_param, db) + else: + return None + + def get_current_active_user(current_user: Annotated[User, Depends(get_current_user)]): if not current_user.is_active: raise HTTPException(status_code=400, detail="Inactive user")