Merge branch 'authentication' into login

This commit is contained in:
Cristhian Zanforlin Lousa 2023-08-11 08:22:37 -03:00
commit c30cb3e002
4 changed files with 176 additions and 80 deletions

View file

@ -8,14 +8,14 @@ from datetime import datetime, timedelta, timezone
from langflow.services.utils import get_session
from langflow.database.models.token import TokenData
from langflow.database.models.user import get_user, User
from langflow.database.models.user import User, get_user_by_username
# TODO: Move to env - Test propose!!!!!
# TODO: Move to env - JUST FOR TEST!!!!!
SECRET_KEY = "698619adad2d916f1f32d264540976964b3c0d3828e0870a65add5800a8cc6b9"
ALGORITHM = "HS256"
ACCESS_TOKEN_EXPIRE_MINUTES = 60
REFRESH_TOKEN_EXPIRE_MINUTES = 180
REFRESH_TOKEN_EXPIRE_MINUTES = 70
pwd_context = CryptContext(schemes=["bcrypt"], deprecated="auto")
oauth2_scheme = OAuth2PasswordBearer(tokenUrl="login")
@ -29,29 +29,57 @@ def get_password_hash(password):
return pwd_context.hash(password)
def create_access_token(data: dict, expires_delta: timedelta = None): # type: ignore
def create_token(data: dict, expires_delta: timedelta):
to_encode = data.copy()
if expires_delta:
expire = datetime.now(timezone.utc) + expires_delta
else:
expire = datetime.now(timezone.utc) + timedelta(
minutes=ACCESS_TOKEN_EXPIRE_MINUTES
)
expire = datetime.now(timezone.utc) + expires_delta
to_encode["exp"] = expire
return jwt.encode(to_encode, SECRET_KEY, algorithm=ALGORITHM)
def create_refresh_token(data: dict):
to_encode = data.copy()
expire = datetime.now(timezone.utc) + timedelta(
minutes=REFRESH_TOKEN_EXPIRE_MINUTES
def create_user_tokens(username: str) -> dict:
access_token_expires = timedelta(minutes=ACCESS_TOKEN_EXPIRE_MINUTES)
access_token = create_token(
data={"sub": username},
expires_delta=access_token_expires,
)
to_encode["exp"] = expire
return jwt.encode(to_encode, SECRET_KEY, algorithm=ALGORITHM)
refresh_token_expires = timedelta(minutes=REFRESH_TOKEN_EXPIRE_MINUTES)
refresh_token = create_token(
data={"sub": username, "type": "rf"},
expires_delta=refresh_token_expires,
)
return {
"access_token": access_token,
"refresh_token": refresh_token,
"token_type": "bearer",
}
def create_refresh_token(refresh_token: str):
try:
payload = jwt.decode(refresh_token, SECRET_KEY, algorithms=[ALGORITHM])
username: str = payload.get("sub") # type: ignore
token_type: str = payload.get("type") # type: ignore
if username is None or token_type is None:
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED, detail="Invalid refresh token"
)
return create_user_tokens(username)
except JWTError as e:
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail="Invalid refresh token",
) from e
def authenticate_user(db: Session, username: str, password: str):
if user := get_user(db, username):
if user := get_user_by_username(db, username):
return user if verify_password(password, user.password) else False
else:
return False
@ -68,13 +96,15 @@ async def get_current_user(
try:
payload = jwt.decode(token, SECRET_KEY, algorithms=[ALGORITHM])
username: str = payload.get("sub") # type: ignore
if username is None:
token_type: str = payload.get("type") # type: ignore
if username is None or token_type:
raise credentials_exception
token_data = TokenData(username=username)
except JWTError as e:
raise credentials_exception from e
user = get_user(db, token_data.username) # type: ignore
user = get_user_by_username(db, token_data.username) # type: ignore
if user is None:
raise credentials_exception
return user

View file

@ -32,6 +32,17 @@ class UserListModel(SQLModel):
updated_at: datetime = Field()
def get_user(db: Session, username: str) -> User:
class UserPatchModel(SQLModel):
username: str = Field()
is_disabled: bool = Field()
is_superuser: bool = Field()
def get_user_by_username(db: Session, username: str) -> User:
db_user = db.query(User).filter(User.username == username).first()
return User.from_orm(db_user) if db_user else None # type: ignore
def get_user_by_id(db: Session, id: UUID) -> User:
db_user = db.query(User).filter(User.id == id).first()
return User.from_orm(db_user) if db_user else None # type: ignore

View file

@ -1,48 +1,39 @@
from datetime import timedelta
from sqlalchemy.orm import Session
from fastapi import APIRouter, Depends, HTTPException, status
from fastapi.security import OAuth2PasswordRequestForm
from langflow.services.utils import get_session
from langflow.database.models.token import Token
from langflow.auth.auth import (
ACCESS_TOKEN_EXPIRE_MINUTES,
authenticate_user,
create_access_token,
create_user_tokens,
create_refresh_token,
)
from sqlalchemy.orm import Session
from langflow.services.utils import get_session
from langflow.database.models.user import User
router = APIRouter()
def create_user_token(user: User) -> dict:
access_token_expires = timedelta(minutes=ACCESS_TOKEN_EXPIRE_MINUTES)
access_token = create_access_token(
data={"sub": user.username},
expires_delta=access_token_expires,
)
refresh_token = create_refresh_token(data={"sub": user.username})
return {
"access_token": access_token,
"refresh_token": refresh_token,
"token_type": "bearer",
}
@router.post("/login", response_model=Token)
async def login_to_get_access_token(
form_data: OAuth2PasswordRequestForm = Depends(), db: Session = Depends(get_session)
):
if user := authenticate_user(db, form_data.username, form_data.password):
return create_user_token(user)
return create_user_tokens(user.username)
else:
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail="Incorrect username or password",
headers={"WWW-Authenticate": "Bearer"},
)
@router.post("/refresh")
async def refresh_token(token: str):
if token:
return create_refresh_token(token)
else:
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail="Invalid refresh token",
headers={"WWW-Authenticate": "Bearer"},
)

View file

@ -1,24 +1,52 @@
from typing import List
from sqlmodel import Session, select
from uuid import UUID
from datetime import timezone, datetime
from sqlalchemy.exc import IntegrityError
from fastapi import APIRouter, Depends, HTTPException
from sqlmodel import Session, select
from langflow.services.utils import get_session
from langflow.auth.auth import get_current_active_user
from langflow.database.models.user import UserAddModel, UserListModel, User
from passlib.context import CryptContext
from langflow.auth.auth import get_current_active_user, get_password_hash
from langflow.database.models.user import (
User,
UserAddModel,
UserListModel,
UserPatchModel,
get_user_by_id,
get_user_by_username,
)
router = APIRouter(tags=["Login"])
def get_password_hash(password):
pwd_context = CryptContext(schemes=["bcrypt"], deprecated="auto")
return pwd_context.hash(password)
@router.post("/user", response_model=UserListModel)
def add_user(
user: UserAddModel,
_: Session = Depends(get_current_active_user),
db: Session = Depends(get_session),
) -> User:
"""
Add a new user to the database.
"""
new_user = User(**user.dict())
try:
new_user.password = get_password_hash(user.password)
db.add(new_user)
db.commit()
db.refresh(new_user)
except IntegrityError as e:
db.rollback()
raise HTTPException(status_code=400, detail="User exists") from e
return new_user
@router.get("/user", response_model=UserListModel)
def read_current_user(current_user: User = Depends(get_current_active_user)):
def read_current_user(current_user: User = Depends(get_current_active_user)) -> User:
"""
Retrieve the current user's data.
"""
return current_user
@ -28,52 +56,88 @@ def read_all_users(
limit: int = 10,
_: Session = Depends(get_current_active_user),
db: Session = Depends(get_session),
):
query = select(User)
query = query.offset(skip).limit(limit)
) -> List[UserListModel]:
"""
Retrieve a list of users from the database with pagination.
"""
query = select(User).offset(skip).limit(limit)
users = db.execute(query).fetchall()
return db.execute(query).fetchall()
return [UserListModel(**dict(user.User)) for user in users]
@router.post("/user", response_model=User)
def add_user(
user: UserAddModel,
@router.patch("/user/{user_id}", response_model=UserListModel)
def update_user(
user_id: UUID,
user: UserPatchModel,
_: Session = Depends(get_current_active_user),
db: Session = Depends(get_session),
):
new_user = User(**user.dict())
try:
new_user.password = get_password_hash(user.password)
) -> User:
"""
Update an existing user's data.
"""
user_db = get_user_by_username(db, user.username)
if user_db and user_db.id != user_id:
raise HTTPException(status_code=409, detail="Username already exists")
db.add(new_user)
user_db = get_user_by_id(db, user_id)
if not user_db:
raise HTTPException(status_code=404, detail="User not found")
try:
user_data = user.dict(exclude_unset=True)
for key, value in user_data.items():
setattr(user_db, key, value)
user_db.updated_at = datetime.now(timezone.utc)
user_db = db.merge(user_db)
db.commit()
db.refresh(new_user)
if db.identity_key(instance=user_db) is not None:
db.refresh(user_db)
except IntegrityError as e:
db.rollback()
raise HTTPException(
status_code=400,
detail="User exists",
) from e
raise HTTPException(status_code=400, detail=str(e)) from e
return new_user
return user_db
# TODO: Remove - Just for testing purposes
@router.delete("/user/{user_id}")
def delete_user(
user_id: UUID,
_: Session = Depends(get_current_active_user),
db: Session = Depends(get_session),
) -> dict:
"""
Delete a user from the database.
"""
user_db = db.query(User).filter(User.id == user_id).first()
if not user_db:
raise HTTPException(status_code=404, detail="User not found")
db.delete(user_db)
db.commit()
return {"detail": "User deleted"}
# TODO: REMOVE - Just for testing purposes
@router.post("/super_user", response_model=User)
def add_super_user_to_testing_purposes(db: Session = Depends(get_session)):
def add_super_user_for_testing_purposes_delete_me_before_merge_into_dev(
db: Session = Depends(get_session),
) -> User:
"""
Add a superuser for testing purposes.
(This should be removed in production)
"""
new_user = User(username="superuser", password="12345", is_superuser=True)
try:
new_user.password = get_password_hash(new_user.password)
db.add(new_user)
db.commit()
db.refresh(new_user)
except IntegrityError as e:
db.rollback()
raise HTTPException(
status_code=400,
detail="User exists",
) from e
raise HTTPException(status_code=400, detail="User exists") from e
return new_user