🔒 chore(auth.py): add support for retrieving current user and current active user
🔒 chore(auth.py): add support for creating user tokens with user id 🔒 chore(auth.py): add support for updating last_login_at when creating user tokens 🔒 chore(auth.py): add support for updating user information 🔒 chore(login.py): add support for retrieving current active user in login route 🐛 fix(users.py): fix import order and remove unused imports to improve code readability ✨ feat(users.py): add support for pagination in read_all_users endpoint to retrieve a list of users with pagination 🔧 refactor(users.py): rename update_user function to patch_user for better semantics and consistency 🔧 refactor(users.py): refactor update_user function to use update_user function from models module for better code organization and reusability ✨ feat(users.py): add support for creating a superuser for testing purposes in add_super_user_for_testing_purposes_delete_me_before_merge_into_dev endpoint
This commit is contained in:
parent
f1b2fea20f
commit
501d7399a8
5 changed files with 140 additions and 96 deletions
|
|
@ -1,3 +1,4 @@
|
|||
from uuid import UUID
|
||||
from typing import Annotated
|
||||
from jose import JWTError, jwt
|
||||
from sqlalchemy.orm import Session
|
||||
|
|
@ -7,8 +8,12 @@ from fastapi import Depends, HTTPException, status
|
|||
from datetime import datetime, timedelta, timezone
|
||||
|
||||
from langflow.services.utils import get_session
|
||||
from langflow.database.models.token import TokenData
|
||||
from langflow.database.models.user import User, get_user_by_username
|
||||
from langflow.database.models.user import (
|
||||
User,
|
||||
get_user_by_id,
|
||||
get_user_by_username,
|
||||
update_user_last_login_at,
|
||||
)
|
||||
|
||||
|
||||
# TODO: Move to env - JUST FOR TEST!!!!!
|
||||
|
|
@ -21,6 +26,38 @@ pwd_context = CryptContext(schemes=["bcrypt"], deprecated="auto")
|
|||
oauth2_scheme = OAuth2PasswordBearer(tokenUrl="login")
|
||||
|
||||
|
||||
async def get_current_user(
|
||||
token: Annotated[str, Depends(oauth2_scheme)], db: Session = Depends(get_session)
|
||||
) -> User:
|
||||
credentials_exception = HTTPException(
|
||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||
detail="Could not validate credentials",
|
||||
headers={"WWW-Authenticate": "Bearer"},
|
||||
)
|
||||
try:
|
||||
payload = jwt.decode(token, SECRET_KEY, algorithms=[ALGORITHM])
|
||||
user_id: UUID = payload.get("sub") # type: ignore
|
||||
token_type: str = payload.get("type") # type: ignore
|
||||
|
||||
if user_id is None or token_type:
|
||||
raise credentials_exception
|
||||
except JWTError as e:
|
||||
raise credentials_exception from e
|
||||
|
||||
user = get_user_by_id(db, user_id) # type: ignore
|
||||
if user is None:
|
||||
raise credentials_exception
|
||||
return user
|
||||
|
||||
|
||||
async def get_current_active_user(
|
||||
current_user: Annotated[User, Depends(get_current_user)]
|
||||
):
|
||||
if not current_user.is_active:
|
||||
raise HTTPException(status_code=400, detail="Inactive user")
|
||||
return current_user
|
||||
|
||||
|
||||
def verify_password(plain_password, hashed_password):
|
||||
return pwd_context.verify(plain_password, hashed_password)
|
||||
|
||||
|
|
@ -38,19 +75,22 @@ def create_token(data: dict, expires_delta: timedelta):
|
|||
return jwt.encode(to_encode, SECRET_KEY, algorithm=ALGORITHM)
|
||||
|
||||
|
||||
def create_user_tokens(username: str) -> dict:
|
||||
def create_user_tokens(user_id: UUID, db: Session = Depends(get_session)) -> dict:
|
||||
access_token_expires = timedelta(minutes=ACCESS_TOKEN_EXPIRE_MINUTES)
|
||||
access_token = create_token(
|
||||
data={"sub": username},
|
||||
data={"sub": str(user_id)},
|
||||
expires_delta=access_token_expires,
|
||||
)
|
||||
|
||||
refresh_token_expires = timedelta(minutes=REFRESH_TOKEN_EXPIRE_MINUTES)
|
||||
refresh_token = create_token(
|
||||
data={"sub": username, "type": "rf"},
|
||||
data={"sub": str(user_id), "type": "rf"},
|
||||
expires_delta=refresh_token_expires,
|
||||
)
|
||||
|
||||
# Update: last_login_at
|
||||
update_user_last_login_at(user_id, db)
|
||||
|
||||
return {
|
||||
"access_token": access_token,
|
||||
"refresh_token": refresh_token,
|
||||
|
|
@ -61,15 +101,15 @@ def create_user_tokens(username: str) -> dict:
|
|||
def create_refresh_token(refresh_token: str):
|
||||
try:
|
||||
payload = jwt.decode(refresh_token, SECRET_KEY, algorithms=[ALGORITHM])
|
||||
username: str = payload.get("sub") # type: ignore
|
||||
user_id: UUID = payload.get("sub") # type: ignore
|
||||
token_type: str = payload.get("type") # type: ignore
|
||||
|
||||
if username is None or token_type is None:
|
||||
if user_id is None or token_type is None:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_401_UNAUTHORIZED, detail="Invalid refresh token"
|
||||
)
|
||||
|
||||
return create_user_tokens(username)
|
||||
return create_user_tokens(user_id)
|
||||
|
||||
except JWTError as e:
|
||||
raise HTTPException(
|
||||
|
|
@ -78,41 +118,10 @@ def create_refresh_token(refresh_token: str):
|
|||
) from e
|
||||
|
||||
|
||||
def authenticate_user(db: Session, username: str, password: str):
|
||||
def authenticate_user(
|
||||
username: str, password: str, db: Session = Depends(get_session)
|
||||
) -> User | None:
|
||||
if user := get_user_by_username(db, username):
|
||||
return user if verify_password(password, user.password) else False
|
||||
return user if verify_password(password, user.password) else None
|
||||
else:
|
||||
return False
|
||||
|
||||
|
||||
async def get_current_user(
|
||||
token: Annotated[str, Depends(oauth2_scheme)], db: Session = Depends(get_session)
|
||||
) -> User:
|
||||
credentials_exception = HTTPException(
|
||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||
detail="Could not validate credentials",
|
||||
headers={"WWW-Authenticate": "Bearer"},
|
||||
)
|
||||
try:
|
||||
payload = jwt.decode(token, SECRET_KEY, algorithms=[ALGORITHM])
|
||||
username: str = payload.get("sub") # type: ignore
|
||||
token_type: str = payload.get("type") # type: ignore
|
||||
|
||||
if username is None or token_type:
|
||||
raise credentials_exception
|
||||
token_data = TokenData(username=username)
|
||||
except JWTError as e:
|
||||
raise credentials_exception from e
|
||||
|
||||
user = get_user_by_username(db, token_data.username) # type: ignore
|
||||
if user is None:
|
||||
raise credentials_exception
|
||||
return user
|
||||
|
||||
|
||||
async def get_current_active_user(
|
||||
current_user: Annotated[User, Depends(get_current_user)]
|
||||
):
|
||||
if current_user.is_disabled:
|
||||
raise HTTPException(status_code=400, detail="Inactive user")
|
||||
return current_user
|
||||
return None
|
||||
|
|
|
|||
|
|
@ -5,7 +5,3 @@ class Token(BaseModel):
|
|||
access_token: str
|
||||
refresh_token: str
|
||||
token_type: str
|
||||
|
||||
|
||||
class TokenData(BaseModel):
|
||||
username: str | None = None
|
||||
|
|
|
|||
|
|
@ -1,41 +1,52 @@
|
|||
from datetime import datetime
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from langflow.services.database.models.base import SQLModelSerializable, SQLModel
|
||||
from sqlmodel import Field
|
||||
from uuid import UUID, uuid4
|
||||
from pydantic import BaseModel
|
||||
from typing import Optional, List
|
||||
from sqlalchemy.orm import Session
|
||||
from datetime import timezone, datetime
|
||||
from sqlalchemy.exc import IntegrityError
|
||||
from fastapi import HTTPException, Depends
|
||||
|
||||
from langflow.services.utils import get_session
|
||||
from langflow.services.database.models.base import SQLModelSerializable, SQLModel
|
||||
|
||||
|
||||
class User(SQLModelSerializable, table=True):
|
||||
id: UUID = Field(default_factory=uuid4, primary_key=True, unique=True)
|
||||
username: str = Field(index=True, unique=True)
|
||||
password: str = Field()
|
||||
is_disabled: bool = Field(default=False)
|
||||
is_active: bool = Field(default=False)
|
||||
is_superuser: bool = Field(default=False)
|
||||
create_at: datetime = Field(default_factory=datetime.utcnow)
|
||||
updated_at: datetime = Field(default_factory=datetime.utcnow)
|
||||
last_login_at: Optional[datetime] = Field()
|
||||
|
||||
|
||||
class UserAddModel(SQLModel):
|
||||
username: str = Field()
|
||||
password: str = Field()
|
||||
is_disabled: bool = Field(default=False)
|
||||
is_superuser: bool = Field(default=False)
|
||||
|
||||
|
||||
class UserListModel(SQLModel):
|
||||
id: UUID = Field(default_factory=uuid4)
|
||||
username: str = Field()
|
||||
is_disabled: bool = Field()
|
||||
is_active: bool = Field()
|
||||
is_superuser: bool = Field()
|
||||
create_at: datetime = Field()
|
||||
updated_at: datetime = Field()
|
||||
last_login_at: Optional[datetime] = Field()
|
||||
|
||||
|
||||
class UserPatchModel(SQLModel):
|
||||
username: str = Field()
|
||||
is_disabled: bool = Field()
|
||||
is_superuser: bool = Field()
|
||||
username: Optional[str] = Field()
|
||||
is_active: Optional[bool] = Field()
|
||||
is_superuser: Optional[bool] = Field()
|
||||
last_login_at: Optional[datetime] = Field()
|
||||
|
||||
|
||||
class UsersResponse(BaseModel):
|
||||
total_count: int
|
||||
users: List[UserListModel]
|
||||
|
||||
|
||||
def get_user_by_username(db: Session, username: str) -> User:
|
||||
|
|
@ -46,3 +57,38 @@ def get_user_by_username(db: Session, username: str) -> User:
|
|||
def get_user_by_id(db: Session, id: UUID) -> User:
|
||||
db_user = db.query(User).filter(User.id == id).first()
|
||||
return User.from_orm(db_user) if db_user else None # type: ignore
|
||||
|
||||
|
||||
def update_user(
|
||||
user_id: UUID, user: UserPatchModel, db: Session = Depends(get_session)
|
||||
) -> User:
|
||||
user_db = get_user_by_username(db, user.username) # type: ignore
|
||||
if user_db and user_db.id != user_id:
|
||||
raise HTTPException(status_code=409, detail="Username already exists")
|
||||
|
||||
user_db = get_user_by_id(db, user_id)
|
||||
if not user_db:
|
||||
raise HTTPException(status_code=404, detail="User not found")
|
||||
|
||||
try:
|
||||
user_data = user.dict(exclude_unset=True)
|
||||
for key, value in user_data.items():
|
||||
setattr(user_db, key, value)
|
||||
|
||||
user_db.updated_at = datetime.now(timezone.utc)
|
||||
user_db = db.merge(user_db)
|
||||
db.commit()
|
||||
if db.identity_key(instance=user_db) is not None:
|
||||
db.refresh(user_db)
|
||||
|
||||
except IntegrityError as e:
|
||||
db.rollback()
|
||||
raise HTTPException(status_code=400, detail=str(e)) from e
|
||||
|
||||
return user_db
|
||||
|
||||
|
||||
def update_user_last_login_at(user_id: UUID, db: Session = Depends(get_session)):
|
||||
user_data = UserPatchModel(last_login_at=datetime.now(timezone.utc)) # type: ignore
|
||||
|
||||
return update_user(user_id, user_data, db)
|
||||
|
|
|
|||
|
|
@ -2,6 +2,7 @@ from sqlalchemy.orm import Session
|
|||
from fastapi import APIRouter, Depends, HTTPException, status
|
||||
from fastapi.security import OAuth2PasswordRequestForm
|
||||
|
||||
|
||||
from langflow.services.utils import get_session
|
||||
from langflow.database.models.token import Token
|
||||
from langflow.auth.auth import (
|
||||
|
|
@ -15,10 +16,12 @@ router = APIRouter()
|
|||
|
||||
@router.post("/login", response_model=Token)
|
||||
async def login_to_get_access_token(
|
||||
form_data: OAuth2PasswordRequestForm = Depends(), db: Session = Depends(get_session)
|
||||
form_data: OAuth2PasswordRequestForm = Depends(),
|
||||
db: Session = Depends(get_session),
|
||||
# _: Session = Depends(get_current_active_user)
|
||||
):
|
||||
if user := authenticate_user(db, form_data.username, form_data.password):
|
||||
return create_user_tokens(user.username)
|
||||
if user := authenticate_user(form_data.username, form_data.password, db):
|
||||
return create_user_tokens(user_id=user.id, db=db)
|
||||
else:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||
|
|
|
|||
|
|
@ -1,10 +1,10 @@
|
|||
from typing import List
|
||||
from uuid import UUID
|
||||
from datetime import timezone, datetime
|
||||
|
||||
from sqlalchemy import func
|
||||
from sqlalchemy.exc import IntegrityError
|
||||
from fastapi import APIRouter, Depends, HTTPException
|
||||
|
||||
from sqlmodel import Session, select
|
||||
from fastapi import APIRouter, Depends, HTTPException
|
||||
|
||||
from langflow.services.utils import get_session
|
||||
from langflow.auth.auth import get_current_active_user, get_password_hash
|
||||
|
|
@ -13,8 +13,8 @@ from langflow.database.models.user import (
|
|||
UserAddModel,
|
||||
UserListModel,
|
||||
UserPatchModel,
|
||||
get_user_by_id,
|
||||
get_user_by_username,
|
||||
UsersResponse,
|
||||
update_user,
|
||||
)
|
||||
|
||||
router = APIRouter(tags=["Login"])
|
||||
|
|
@ -23,7 +23,6 @@ router = APIRouter(tags=["Login"])
|
|||
@router.post("/user", response_model=UserListModel)
|
||||
def add_user(
|
||||
user: UserAddModel,
|
||||
_: Session = Depends(get_current_active_user),
|
||||
db: Session = Depends(get_session),
|
||||
) -> User:
|
||||
"""
|
||||
|
|
@ -32,6 +31,7 @@ def add_user(
|
|||
new_user = User(**user.dict())
|
||||
try:
|
||||
new_user.password = get_password_hash(user.password)
|
||||
|
||||
db.add(new_user)
|
||||
db.commit()
|
||||
db.refresh(new_user)
|
||||
|
|
@ -50,24 +50,30 @@ def read_current_user(current_user: User = Depends(get_current_active_user)) ->
|
|||
return current_user
|
||||
|
||||
|
||||
@router.get("/users", response_model=List[UserListModel])
|
||||
@router.get("/users", response_model=UsersResponse)
|
||||
def read_all_users(
|
||||
skip: int = 0,
|
||||
limit: int = 10,
|
||||
_: Session = Depends(get_current_active_user),
|
||||
db: Session = Depends(get_session),
|
||||
) -> List[UserListModel]:
|
||||
) -> UsersResponse:
|
||||
"""
|
||||
Retrieve a list of users from the database with pagination.
|
||||
"""
|
||||
query = select(User).offset(skip).limit(limit)
|
||||
users = db.execute(query).fetchall()
|
||||
|
||||
return [UserListModel(**dict(user.User)) for user in users]
|
||||
count_query = select(func.count()).select_from(User) # type: ignore
|
||||
total_count = db.execute(count_query).scalar()
|
||||
|
||||
return UsersResponse(
|
||||
total_count=total_count, # type: ignore
|
||||
users=[UserListModel(**dict(user.User)) for user in users],
|
||||
)
|
||||
|
||||
|
||||
@router.patch("/user/{user_id}", response_model=UserListModel)
|
||||
def update_user(
|
||||
def patch_user(
|
||||
user_id: UUID,
|
||||
user: UserPatchModel,
|
||||
_: Session = Depends(get_current_active_user),
|
||||
|
|
@ -76,29 +82,7 @@ def update_user(
|
|||
"""
|
||||
Update an existing user's data.
|
||||
"""
|
||||
user_db = get_user_by_username(db, user.username)
|
||||
if user_db and user_db.id != user_id:
|
||||
raise HTTPException(status_code=409, detail="Username already exists")
|
||||
|
||||
user_db = get_user_by_id(db, user_id)
|
||||
if not user_db:
|
||||
raise HTTPException(status_code=404, detail="User not found")
|
||||
|
||||
try:
|
||||
user_data = user.dict(exclude_unset=True)
|
||||
for key, value in user_data.items():
|
||||
setattr(user_db, key, value)
|
||||
|
||||
user_db.updated_at = datetime.now(timezone.utc)
|
||||
user_db = db.merge(user_db)
|
||||
db.commit()
|
||||
if db.identity_key(instance=user_db) is not None:
|
||||
db.refresh(user_db)
|
||||
except IntegrityError as e:
|
||||
db.rollback()
|
||||
raise HTTPException(status_code=400, detail=str(e)) from e
|
||||
|
||||
return user_db
|
||||
return update_user(user_id, user, db)
|
||||
|
||||
|
||||
@router.delete("/user/{user_id}")
|
||||
|
|
@ -129,7 +113,13 @@ def add_super_user_for_testing_purposes_delete_me_before_merge_into_dev(
|
|||
Add a superuser for testing purposes.
|
||||
(This should be removed in production)
|
||||
"""
|
||||
new_user = User(username="superuser", password="12345", is_superuser=True)
|
||||
new_user = User(
|
||||
username="superuser",
|
||||
password="12345",
|
||||
is_active=True,
|
||||
is_superuser=True,
|
||||
last_login_at=None,
|
||||
)
|
||||
|
||||
try:
|
||||
new_user.password = get_password_hash(new_user.password)
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue