From 8f0ca52e9ceb56884a28be80ed49b2dbc12f14e3 Mon Sep 17 00:00:00 2001 From: Gabriel Luiz Freitas Almeida Date: Thu, 20 Jun 2024 14:39:17 -0300 Subject: [PATCH] refactor: Update auth utils to include token type in create_token function --- .../base/langflow/services/auth/utils.py | 19 ++++++++++--------- 1 file changed, 10 insertions(+), 9 deletions(-) diff --git a/src/backend/base/langflow/services/auth/utils.py b/src/backend/base/langflow/services/auth/utils.py index 06836d21a..b0a35e474 100644 --- a/src/backend/base/langflow/services/auth/utils.py +++ b/src/backend/base/langflow/services/auth/utils.py @@ -7,14 +7,15 @@ from cryptography.fernet import Fernet from fastapi import Depends, HTTPException, Security, status from fastapi.security import APIKeyHeader, APIKeyQuery, OAuth2PasswordBearer from jose import JWTError, jwt +from loguru import logger +from sqlmodel import Session +from starlette.websockets import WebSocket + from langflow.services.database.models.api_key.crud import check_key from langflow.services.database.models.api_key.model import ApiKey from langflow.services.database.models.user.crud import get_user_by_id, get_user_by_username, update_user_last_login_at from langflow.services.database.models.user.model import User from langflow.services.deps import get_session, get_settings_service -from loguru import logger -from sqlmodel import Session -from starlette.websockets import WebSocket oauth2_login = OAuth2PasswordBearer(tokenUrl="api/v1/login", auto_error=False) @@ -119,7 +120,7 @@ async def get_current_user_by_jwt( headers={"WWW-Authenticate": "Bearer"}, ) - if user_id is None or token_type: + if user_id is None or token_type is None: logger.info(f"Invalid token payload. Token type: {token_type}") raise HTTPException( status_code=status.HTTP_401_UNAUTHORIZED, @@ -231,7 +232,7 @@ def create_user_longterm_token(db: Session = Depends(get_session)) -> tuple[UUID raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="Super user hasn't been created") access_token_expires_longterm = timedelta(days=365) access_token = create_token( - data={"sub": str(super_user.id)}, + data={"sub": str(super_user.id), "type": "access"}, expires_delta=access_token_expires_longterm, ) @@ -247,7 +248,7 @@ def create_user_longterm_token(db: Session = Depends(get_session)) -> tuple[UUID def create_user_api_key(user_id: UUID) -> dict: access_token = create_token( - data={"sub": str(user_id), "role": "api_key"}, + data={"sub": str(user_id), "type": "api_key"}, expires_delta=timedelta(days=365 * 2), ) @@ -267,13 +268,13 @@ def create_user_tokens(user_id: UUID, db: Session = Depends(get_session), update access_token_expires = timedelta(seconds=settings_service.auth_settings.ACCESS_TOKEN_EXPIRE_SECONDS) access_token = create_token( - data={"sub": str(user_id)}, + data={"sub": str(user_id), "type": "access"}, expires_delta=access_token_expires, ) refresh_token_expires = timedelta(seconds=settings_service.auth_settings.REFRESH_TOKEN_EXPIRE_SECONDS) refresh_token = create_token( - data={"sub": str(user_id), "type": "rf"}, + data={"sub": str(user_id), "type": "refresh"}, expires_delta=refresh_token_expires, ) @@ -302,13 +303,13 @@ def create_refresh_token(refresh_token: str, db: Session = Depends(get_session)) ) user_id: UUID = payload.get("sub") # type: ignore token_type: str = payload.get("type") # type: ignore - if user_id is None or token_type is None: raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED, detail="Invalid refresh token") return create_user_tokens(user_id, db) except JWTError as e: + logger.error(f"JWT decoding error: {e}") raise HTTPException( status_code=status.HTTP_401_UNAUTHORIZED, detail="Invalid refresh token",