refactor: Update auth utils to include token type in create_token function

This commit is contained in:
Gabriel Luiz Freitas Almeida 2024-06-20 14:39:17 -03:00
commit 8f0ca52e9c

View file

@ -7,14 +7,15 @@ from cryptography.fernet import Fernet
from fastapi import Depends, HTTPException, Security, status
from fastapi.security import APIKeyHeader, APIKeyQuery, OAuth2PasswordBearer
from jose import JWTError, jwt
from loguru import logger
from sqlmodel import Session
from starlette.websockets import WebSocket
from langflow.services.database.models.api_key.crud import check_key
from langflow.services.database.models.api_key.model import ApiKey
from langflow.services.database.models.user.crud import get_user_by_id, get_user_by_username, update_user_last_login_at
from langflow.services.database.models.user.model import User
from langflow.services.deps import get_session, get_settings_service
from loguru import logger
from sqlmodel import Session
from starlette.websockets import WebSocket
oauth2_login = OAuth2PasswordBearer(tokenUrl="api/v1/login", auto_error=False)
@ -119,7 +120,7 @@ async def get_current_user_by_jwt(
headers={"WWW-Authenticate": "Bearer"},
)
if user_id is None or token_type:
if user_id is None or token_type is None:
logger.info(f"Invalid token payload. Token type: {token_type}")
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
@ -231,7 +232,7 @@ def create_user_longterm_token(db: Session = Depends(get_session)) -> tuple[UUID
raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="Super user hasn't been created")
access_token_expires_longterm = timedelta(days=365)
access_token = create_token(
data={"sub": str(super_user.id)},
data={"sub": str(super_user.id), "type": "access"},
expires_delta=access_token_expires_longterm,
)
@ -247,7 +248,7 @@ def create_user_longterm_token(db: Session = Depends(get_session)) -> tuple[UUID
def create_user_api_key(user_id: UUID) -> dict:
access_token = create_token(
data={"sub": str(user_id), "role": "api_key"},
data={"sub": str(user_id), "type": "api_key"},
expires_delta=timedelta(days=365 * 2),
)
@ -267,13 +268,13 @@ def create_user_tokens(user_id: UUID, db: Session = Depends(get_session), update
access_token_expires = timedelta(seconds=settings_service.auth_settings.ACCESS_TOKEN_EXPIRE_SECONDS)
access_token = create_token(
data={"sub": str(user_id)},
data={"sub": str(user_id), "type": "access"},
expires_delta=access_token_expires,
)
refresh_token_expires = timedelta(seconds=settings_service.auth_settings.REFRESH_TOKEN_EXPIRE_SECONDS)
refresh_token = create_token(
data={"sub": str(user_id), "type": "rf"},
data={"sub": str(user_id), "type": "refresh"},
expires_delta=refresh_token_expires,
)
@ -302,13 +303,13 @@ def create_refresh_token(refresh_token: str, db: Session = Depends(get_session))
)
user_id: UUID = payload.get("sub") # type: ignore
token_type: str = payload.get("type") # type: ignore
if user_id is None or token_type is None:
raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED, detail="Invalid refresh token")
return create_user_tokens(user_id, db)
except JWTError as e:
logger.error(f"JWT decoding error: {e}")
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail="Invalid refresh token",