🔒 chore(auth.py): add support for retrieving current user and current active user

🔒 chore(auth.py): add support for creating user tokens with user id

🔒 chore(auth.py): add support for updating last_login_at when creating user tokens

🔒 chore(auth.py): add support for updating user information

🔒 chore(login.py): add support for retrieving current active user in login route

🐛 fix(users.py): fix import order and remove unused imports to improve code readability
 feat(users.py): add support for pagination in read_all_users endpoint to retrieve a list of users with pagination
🔧 refactor(users.py): rename update_user function to patch_user for better semantics and consistency
🔧 refactor(users.py): refactor update_user function to use update_user function from models module for better code organization and reusability
 feat(users.py): add support for creating a superuser for testing purposes in add_super_user_for_testing_purposes_delete_me_before_merge_into_dev endpoint
This commit is contained in:
gustavoschaedler 2023-08-14 21:18:47 +01:00
commit 501d7399a8
5 changed files with 140 additions and 96 deletions

View file

@ -1,3 +1,4 @@
from uuid import UUID
from typing import Annotated
from jose import JWTError, jwt
from sqlalchemy.orm import Session
@ -7,8 +8,12 @@ from fastapi import Depends, HTTPException, status
from datetime import datetime, timedelta, timezone
from langflow.services.utils import get_session
from langflow.database.models.token import TokenData
from langflow.database.models.user import User, get_user_by_username
from langflow.database.models.user import (
User,
get_user_by_id,
get_user_by_username,
update_user_last_login_at,
)
# TODO: Move to env - JUST FOR TEST!!!!!
@ -21,6 +26,38 @@ pwd_context = CryptContext(schemes=["bcrypt"], deprecated="auto")
oauth2_scheme = OAuth2PasswordBearer(tokenUrl="login")
async def get_current_user(
token: Annotated[str, Depends(oauth2_scheme)], db: Session = Depends(get_session)
) -> User:
credentials_exception = HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail="Could not validate credentials",
headers={"WWW-Authenticate": "Bearer"},
)
try:
payload = jwt.decode(token, SECRET_KEY, algorithms=[ALGORITHM])
user_id: UUID = payload.get("sub") # type: ignore
token_type: str = payload.get("type") # type: ignore
if user_id is None or token_type:
raise credentials_exception
except JWTError as e:
raise credentials_exception from e
user = get_user_by_id(db, user_id) # type: ignore
if user is None:
raise credentials_exception
return user
async def get_current_active_user(
current_user: Annotated[User, Depends(get_current_user)]
):
if not current_user.is_active:
raise HTTPException(status_code=400, detail="Inactive user")
return current_user
def verify_password(plain_password, hashed_password):
return pwd_context.verify(plain_password, hashed_password)
@ -38,19 +75,22 @@ def create_token(data: dict, expires_delta: timedelta):
return jwt.encode(to_encode, SECRET_KEY, algorithm=ALGORITHM)
def create_user_tokens(username: str) -> dict:
def create_user_tokens(user_id: UUID, db: Session = Depends(get_session)) -> dict:
access_token_expires = timedelta(minutes=ACCESS_TOKEN_EXPIRE_MINUTES)
access_token = create_token(
data={"sub": username},
data={"sub": str(user_id)},
expires_delta=access_token_expires,
)
refresh_token_expires = timedelta(minutes=REFRESH_TOKEN_EXPIRE_MINUTES)
refresh_token = create_token(
data={"sub": username, "type": "rf"},
data={"sub": str(user_id), "type": "rf"},
expires_delta=refresh_token_expires,
)
# Update: last_login_at
update_user_last_login_at(user_id, db)
return {
"access_token": access_token,
"refresh_token": refresh_token,
@ -61,15 +101,15 @@ def create_user_tokens(username: str) -> dict:
def create_refresh_token(refresh_token: str):
try:
payload = jwt.decode(refresh_token, SECRET_KEY, algorithms=[ALGORITHM])
username: str = payload.get("sub") # type: ignore
user_id: UUID = payload.get("sub") # type: ignore
token_type: str = payload.get("type") # type: ignore
if username is None or token_type is None:
if user_id is None or token_type is None:
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED, detail="Invalid refresh token"
)
return create_user_tokens(username)
return create_user_tokens(user_id)
except JWTError as e:
raise HTTPException(
@ -78,41 +118,10 @@ def create_refresh_token(refresh_token: str):
) from e
def authenticate_user(db: Session, username: str, password: str):
def authenticate_user(
username: str, password: str, db: Session = Depends(get_session)
) -> User | None:
if user := get_user_by_username(db, username):
return user if verify_password(password, user.password) else False
return user if verify_password(password, user.password) else None
else:
return False
async def get_current_user(
token: Annotated[str, Depends(oauth2_scheme)], db: Session = Depends(get_session)
) -> User:
credentials_exception = HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail="Could not validate credentials",
headers={"WWW-Authenticate": "Bearer"},
)
try:
payload = jwt.decode(token, SECRET_KEY, algorithms=[ALGORITHM])
username: str = payload.get("sub") # type: ignore
token_type: str = payload.get("type") # type: ignore
if username is None or token_type:
raise credentials_exception
token_data = TokenData(username=username)
except JWTError as e:
raise credentials_exception from e
user = get_user_by_username(db, token_data.username) # type: ignore
if user is None:
raise credentials_exception
return user
async def get_current_active_user(
current_user: Annotated[User, Depends(get_current_user)]
):
if current_user.is_disabled:
raise HTTPException(status_code=400, detail="Inactive user")
return current_user
return None

View file

@ -5,7 +5,3 @@ class Token(BaseModel):
access_token: str
refresh_token: str
token_type: str
class TokenData(BaseModel):
username: str | None = None

View file

@ -1,41 +1,52 @@
from datetime import datetime
from sqlalchemy.orm import Session
from langflow.services.database.models.base import SQLModelSerializable, SQLModel
from sqlmodel import Field
from uuid import UUID, uuid4
from pydantic import BaseModel
from typing import Optional, List
from sqlalchemy.orm import Session
from datetime import timezone, datetime
from sqlalchemy.exc import IntegrityError
from fastapi import HTTPException, Depends
from langflow.services.utils import get_session
from langflow.services.database.models.base import SQLModelSerializable, SQLModel
class User(SQLModelSerializable, table=True):
id: UUID = Field(default_factory=uuid4, primary_key=True, unique=True)
username: str = Field(index=True, unique=True)
password: str = Field()
is_disabled: bool = Field(default=False)
is_active: bool = Field(default=False)
is_superuser: bool = Field(default=False)
create_at: datetime = Field(default_factory=datetime.utcnow)
updated_at: datetime = Field(default_factory=datetime.utcnow)
last_login_at: Optional[datetime] = Field()
class UserAddModel(SQLModel):
username: str = Field()
password: str = Field()
is_disabled: bool = Field(default=False)
is_superuser: bool = Field(default=False)
class UserListModel(SQLModel):
id: UUID = Field(default_factory=uuid4)
username: str = Field()
is_disabled: bool = Field()
is_active: bool = Field()
is_superuser: bool = Field()
create_at: datetime = Field()
updated_at: datetime = Field()
last_login_at: Optional[datetime] = Field()
class UserPatchModel(SQLModel):
username: str = Field()
is_disabled: bool = Field()
is_superuser: bool = Field()
username: Optional[str] = Field()
is_active: Optional[bool] = Field()
is_superuser: Optional[bool] = Field()
last_login_at: Optional[datetime] = Field()
class UsersResponse(BaseModel):
total_count: int
users: List[UserListModel]
def get_user_by_username(db: Session, username: str) -> User:
@ -46,3 +57,38 @@ def get_user_by_username(db: Session, username: str) -> User:
def get_user_by_id(db: Session, id: UUID) -> User:
db_user = db.query(User).filter(User.id == id).first()
return User.from_orm(db_user) if db_user else None # type: ignore
def update_user(
user_id: UUID, user: UserPatchModel, db: Session = Depends(get_session)
) -> User:
user_db = get_user_by_username(db, user.username) # type: ignore
if user_db and user_db.id != user_id:
raise HTTPException(status_code=409, detail="Username already exists")
user_db = get_user_by_id(db, user_id)
if not user_db:
raise HTTPException(status_code=404, detail="User not found")
try:
user_data = user.dict(exclude_unset=True)
for key, value in user_data.items():
setattr(user_db, key, value)
user_db.updated_at = datetime.now(timezone.utc)
user_db = db.merge(user_db)
db.commit()
if db.identity_key(instance=user_db) is not None:
db.refresh(user_db)
except IntegrityError as e:
db.rollback()
raise HTTPException(status_code=400, detail=str(e)) from e
return user_db
def update_user_last_login_at(user_id: UUID, db: Session = Depends(get_session)):
user_data = UserPatchModel(last_login_at=datetime.now(timezone.utc)) # type: ignore
return update_user(user_id, user_data, db)

View file

@ -2,6 +2,7 @@ from sqlalchemy.orm import Session
from fastapi import APIRouter, Depends, HTTPException, status
from fastapi.security import OAuth2PasswordRequestForm
from langflow.services.utils import get_session
from langflow.database.models.token import Token
from langflow.auth.auth import (
@ -15,10 +16,12 @@ router = APIRouter()
@router.post("/login", response_model=Token)
async def login_to_get_access_token(
form_data: OAuth2PasswordRequestForm = Depends(), db: Session = Depends(get_session)
form_data: OAuth2PasswordRequestForm = Depends(),
db: Session = Depends(get_session),
# _: Session = Depends(get_current_active_user)
):
if user := authenticate_user(db, form_data.username, form_data.password):
return create_user_tokens(user.username)
if user := authenticate_user(form_data.username, form_data.password, db):
return create_user_tokens(user_id=user.id, db=db)
else:
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,

View file

@ -1,10 +1,10 @@
from typing import List
from uuid import UUID
from datetime import timezone, datetime
from sqlalchemy import func
from sqlalchemy.exc import IntegrityError
from fastapi import APIRouter, Depends, HTTPException
from sqlmodel import Session, select
from fastapi import APIRouter, Depends, HTTPException
from langflow.services.utils import get_session
from langflow.auth.auth import get_current_active_user, get_password_hash
@ -13,8 +13,8 @@ from langflow.database.models.user import (
UserAddModel,
UserListModel,
UserPatchModel,
get_user_by_id,
get_user_by_username,
UsersResponse,
update_user,
)
router = APIRouter(tags=["Login"])
@ -23,7 +23,6 @@ router = APIRouter(tags=["Login"])
@router.post("/user", response_model=UserListModel)
def add_user(
user: UserAddModel,
_: Session = Depends(get_current_active_user),
db: Session = Depends(get_session),
) -> User:
"""
@ -32,6 +31,7 @@ def add_user(
new_user = User(**user.dict())
try:
new_user.password = get_password_hash(user.password)
db.add(new_user)
db.commit()
db.refresh(new_user)
@ -50,24 +50,30 @@ def read_current_user(current_user: User = Depends(get_current_active_user)) ->
return current_user
@router.get("/users", response_model=List[UserListModel])
@router.get("/users", response_model=UsersResponse)
def read_all_users(
skip: int = 0,
limit: int = 10,
_: Session = Depends(get_current_active_user),
db: Session = Depends(get_session),
) -> List[UserListModel]:
) -> UsersResponse:
"""
Retrieve a list of users from the database with pagination.
"""
query = select(User).offset(skip).limit(limit)
users = db.execute(query).fetchall()
return [UserListModel(**dict(user.User)) for user in users]
count_query = select(func.count()).select_from(User) # type: ignore
total_count = db.execute(count_query).scalar()
return UsersResponse(
total_count=total_count, # type: ignore
users=[UserListModel(**dict(user.User)) for user in users],
)
@router.patch("/user/{user_id}", response_model=UserListModel)
def update_user(
def patch_user(
user_id: UUID,
user: UserPatchModel,
_: Session = Depends(get_current_active_user),
@ -76,29 +82,7 @@ def update_user(
"""
Update an existing user's data.
"""
user_db = get_user_by_username(db, user.username)
if user_db and user_db.id != user_id:
raise HTTPException(status_code=409, detail="Username already exists")
user_db = get_user_by_id(db, user_id)
if not user_db:
raise HTTPException(status_code=404, detail="User not found")
try:
user_data = user.dict(exclude_unset=True)
for key, value in user_data.items():
setattr(user_db, key, value)
user_db.updated_at = datetime.now(timezone.utc)
user_db = db.merge(user_db)
db.commit()
if db.identity_key(instance=user_db) is not None:
db.refresh(user_db)
except IntegrityError as e:
db.rollback()
raise HTTPException(status_code=400, detail=str(e)) from e
return user_db
return update_user(user_id, user, db)
@router.delete("/user/{user_id}")
@ -129,7 +113,13 @@ def add_super_user_for_testing_purposes_delete_me_before_merge_into_dev(
Add a superuser for testing purposes.
(This should be removed in production)
"""
new_user = User(username="superuser", password="12345", is_superuser=True)
new_user = User(
username="superuser",
password="12345",
is_active=True,
is_superuser=True,
last_login_at=None,
)
try:
new_user.password = get_password_hash(new_user.password)