Merge branch 'authentication' into login
This commit is contained in:
commit
c30cb3e002
4 changed files with 176 additions and 80 deletions
|
|
@ -8,14 +8,14 @@ from datetime import datetime, timedelta, timezone
|
|||
|
||||
from langflow.services.utils import get_session
|
||||
from langflow.database.models.token import TokenData
|
||||
from langflow.database.models.user import get_user, User
|
||||
from langflow.database.models.user import User, get_user_by_username
|
||||
|
||||
|
||||
# TODO: Move to env - Test propose!!!!!
|
||||
# TODO: Move to env - JUST FOR TEST!!!!!
|
||||
SECRET_KEY = "698619adad2d916f1f32d264540976964b3c0d3828e0870a65add5800a8cc6b9"
|
||||
ALGORITHM = "HS256"
|
||||
ACCESS_TOKEN_EXPIRE_MINUTES = 60
|
||||
REFRESH_TOKEN_EXPIRE_MINUTES = 180
|
||||
REFRESH_TOKEN_EXPIRE_MINUTES = 70
|
||||
|
||||
pwd_context = CryptContext(schemes=["bcrypt"], deprecated="auto")
|
||||
oauth2_scheme = OAuth2PasswordBearer(tokenUrl="login")
|
||||
|
|
@ -29,29 +29,57 @@ def get_password_hash(password):
|
|||
return pwd_context.hash(password)
|
||||
|
||||
|
||||
def create_access_token(data: dict, expires_delta: timedelta = None): # type: ignore
|
||||
def create_token(data: dict, expires_delta: timedelta):
|
||||
to_encode = data.copy()
|
||||
if expires_delta:
|
||||
expire = datetime.now(timezone.utc) + expires_delta
|
||||
else:
|
||||
expire = datetime.now(timezone.utc) + timedelta(
|
||||
minutes=ACCESS_TOKEN_EXPIRE_MINUTES
|
||||
)
|
||||
|
||||
expire = datetime.now(timezone.utc) + expires_delta
|
||||
to_encode["exp"] = expire
|
||||
|
||||
return jwt.encode(to_encode, SECRET_KEY, algorithm=ALGORITHM)
|
||||
|
||||
|
||||
def create_refresh_token(data: dict):
|
||||
to_encode = data.copy()
|
||||
expire = datetime.now(timezone.utc) + timedelta(
|
||||
minutes=REFRESH_TOKEN_EXPIRE_MINUTES
|
||||
def create_user_tokens(username: str) -> dict:
|
||||
access_token_expires = timedelta(minutes=ACCESS_TOKEN_EXPIRE_MINUTES)
|
||||
access_token = create_token(
|
||||
data={"sub": username},
|
||||
expires_delta=access_token_expires,
|
||||
)
|
||||
to_encode["exp"] = expire
|
||||
return jwt.encode(to_encode, SECRET_KEY, algorithm=ALGORITHM)
|
||||
|
||||
refresh_token_expires = timedelta(minutes=REFRESH_TOKEN_EXPIRE_MINUTES)
|
||||
refresh_token = create_token(
|
||||
data={"sub": username, "type": "rf"},
|
||||
expires_delta=refresh_token_expires,
|
||||
)
|
||||
|
||||
return {
|
||||
"access_token": access_token,
|
||||
"refresh_token": refresh_token,
|
||||
"token_type": "bearer",
|
||||
}
|
||||
|
||||
|
||||
def create_refresh_token(refresh_token: str):
|
||||
try:
|
||||
payload = jwt.decode(refresh_token, SECRET_KEY, algorithms=[ALGORITHM])
|
||||
username: str = payload.get("sub") # type: ignore
|
||||
token_type: str = payload.get("type") # type: ignore
|
||||
|
||||
if username is None or token_type is None:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_401_UNAUTHORIZED, detail="Invalid refresh token"
|
||||
)
|
||||
|
||||
return create_user_tokens(username)
|
||||
|
||||
except JWTError as e:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||
detail="Invalid refresh token",
|
||||
) from e
|
||||
|
||||
|
||||
def authenticate_user(db: Session, username: str, password: str):
|
||||
if user := get_user(db, username):
|
||||
if user := get_user_by_username(db, username):
|
||||
return user if verify_password(password, user.password) else False
|
||||
else:
|
||||
return False
|
||||
|
|
@ -68,13 +96,15 @@ async def get_current_user(
|
|||
try:
|
||||
payload = jwt.decode(token, SECRET_KEY, algorithms=[ALGORITHM])
|
||||
username: str = payload.get("sub") # type: ignore
|
||||
if username is None:
|
||||
token_type: str = payload.get("type") # type: ignore
|
||||
|
||||
if username is None or token_type:
|
||||
raise credentials_exception
|
||||
token_data = TokenData(username=username)
|
||||
except JWTError as e:
|
||||
raise credentials_exception from e
|
||||
|
||||
user = get_user(db, token_data.username) # type: ignore
|
||||
user = get_user_by_username(db, token_data.username) # type: ignore
|
||||
if user is None:
|
||||
raise credentials_exception
|
||||
return user
|
||||
|
|
|
|||
|
|
@ -32,6 +32,17 @@ class UserListModel(SQLModel):
|
|||
updated_at: datetime = Field()
|
||||
|
||||
|
||||
def get_user(db: Session, username: str) -> User:
|
||||
class UserPatchModel(SQLModel):
|
||||
username: str = Field()
|
||||
is_disabled: bool = Field()
|
||||
is_superuser: bool = Field()
|
||||
|
||||
|
||||
def get_user_by_username(db: Session, username: str) -> User:
|
||||
db_user = db.query(User).filter(User.username == username).first()
|
||||
return User.from_orm(db_user) if db_user else None # type: ignore
|
||||
|
||||
|
||||
def get_user_by_id(db: Session, id: UUID) -> User:
|
||||
db_user = db.query(User).filter(User.id == id).first()
|
||||
return User.from_orm(db_user) if db_user else None # type: ignore
|
||||
|
|
|
|||
|
|
@ -1,48 +1,39 @@
|
|||
from datetime import timedelta
|
||||
|
||||
from sqlalchemy.orm import Session
|
||||
from fastapi import APIRouter, Depends, HTTPException, status
|
||||
from fastapi.security import OAuth2PasswordRequestForm
|
||||
|
||||
from langflow.services.utils import get_session
|
||||
from langflow.database.models.token import Token
|
||||
from langflow.auth.auth import (
|
||||
ACCESS_TOKEN_EXPIRE_MINUTES,
|
||||
authenticate_user,
|
||||
create_access_token,
|
||||
create_user_tokens,
|
||||
create_refresh_token,
|
||||
)
|
||||
|
||||
from sqlalchemy.orm import Session
|
||||
from langflow.services.utils import get_session
|
||||
from langflow.database.models.user import User
|
||||
|
||||
|
||||
router = APIRouter()
|
||||
|
||||
|
||||
def create_user_token(user: User) -> dict:
|
||||
access_token_expires = timedelta(minutes=ACCESS_TOKEN_EXPIRE_MINUTES)
|
||||
access_token = create_access_token(
|
||||
data={"sub": user.username},
|
||||
expires_delta=access_token_expires,
|
||||
)
|
||||
|
||||
refresh_token = create_refresh_token(data={"sub": user.username})
|
||||
|
||||
return {
|
||||
"access_token": access_token,
|
||||
"refresh_token": refresh_token,
|
||||
"token_type": "bearer",
|
||||
}
|
||||
|
||||
|
||||
@router.post("/login", response_model=Token)
|
||||
async def login_to_get_access_token(
|
||||
form_data: OAuth2PasswordRequestForm = Depends(), db: Session = Depends(get_session)
|
||||
):
|
||||
if user := authenticate_user(db, form_data.username, form_data.password):
|
||||
return create_user_token(user)
|
||||
return create_user_tokens(user.username)
|
||||
else:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||
detail="Incorrect username or password",
|
||||
headers={"WWW-Authenticate": "Bearer"},
|
||||
)
|
||||
|
||||
|
||||
@router.post("/refresh")
|
||||
async def refresh_token(token: str):
|
||||
if token:
|
||||
return create_refresh_token(token)
|
||||
else:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||
detail="Invalid refresh token",
|
||||
headers={"WWW-Authenticate": "Bearer"},
|
||||
)
|
||||
|
|
|
|||
|
|
@ -1,24 +1,52 @@
|
|||
from typing import List
|
||||
from sqlmodel import Session, select
|
||||
from uuid import UUID
|
||||
from datetime import timezone, datetime
|
||||
|
||||
from sqlalchemy.exc import IntegrityError
|
||||
from fastapi import APIRouter, Depends, HTTPException
|
||||
from sqlmodel import Session, select
|
||||
|
||||
from langflow.services.utils import get_session
|
||||
from langflow.auth.auth import get_current_active_user
|
||||
from langflow.database.models.user import UserAddModel, UserListModel, User
|
||||
|
||||
from passlib.context import CryptContext
|
||||
from langflow.auth.auth import get_current_active_user, get_password_hash
|
||||
from langflow.database.models.user import (
|
||||
User,
|
||||
UserAddModel,
|
||||
UserListModel,
|
||||
UserPatchModel,
|
||||
get_user_by_id,
|
||||
get_user_by_username,
|
||||
)
|
||||
|
||||
router = APIRouter(tags=["Login"])
|
||||
|
||||
|
||||
def get_password_hash(password):
|
||||
pwd_context = CryptContext(schemes=["bcrypt"], deprecated="auto")
|
||||
return pwd_context.hash(password)
|
||||
@router.post("/user", response_model=UserListModel)
|
||||
def add_user(
|
||||
user: UserAddModel,
|
||||
_: Session = Depends(get_current_active_user),
|
||||
db: Session = Depends(get_session),
|
||||
) -> User:
|
||||
"""
|
||||
Add a new user to the database.
|
||||
"""
|
||||
new_user = User(**user.dict())
|
||||
try:
|
||||
new_user.password = get_password_hash(user.password)
|
||||
db.add(new_user)
|
||||
db.commit()
|
||||
db.refresh(new_user)
|
||||
except IntegrityError as e:
|
||||
db.rollback()
|
||||
raise HTTPException(status_code=400, detail="User exists") from e
|
||||
|
||||
return new_user
|
||||
|
||||
|
||||
@router.get("/user", response_model=UserListModel)
|
||||
def read_current_user(current_user: User = Depends(get_current_active_user)):
|
||||
def read_current_user(current_user: User = Depends(get_current_active_user)) -> User:
|
||||
"""
|
||||
Retrieve the current user's data.
|
||||
"""
|
||||
return current_user
|
||||
|
||||
|
||||
|
|
@ -28,52 +56,88 @@ def read_all_users(
|
|||
limit: int = 10,
|
||||
_: Session = Depends(get_current_active_user),
|
||||
db: Session = Depends(get_session),
|
||||
):
|
||||
query = select(User)
|
||||
query = query.offset(skip).limit(limit)
|
||||
) -> List[UserListModel]:
|
||||
"""
|
||||
Retrieve a list of users from the database with pagination.
|
||||
"""
|
||||
query = select(User).offset(skip).limit(limit)
|
||||
users = db.execute(query).fetchall()
|
||||
|
||||
return db.execute(query).fetchall()
|
||||
return [UserListModel(**dict(user.User)) for user in users]
|
||||
|
||||
|
||||
@router.post("/user", response_model=User)
|
||||
def add_user(
|
||||
user: UserAddModel,
|
||||
@router.patch("/user/{user_id}", response_model=UserListModel)
|
||||
def update_user(
|
||||
user_id: UUID,
|
||||
user: UserPatchModel,
|
||||
_: Session = Depends(get_current_active_user),
|
||||
db: Session = Depends(get_session),
|
||||
):
|
||||
new_user = User(**user.dict())
|
||||
try:
|
||||
new_user.password = get_password_hash(user.password)
|
||||
) -> User:
|
||||
"""
|
||||
Update an existing user's data.
|
||||
"""
|
||||
user_db = get_user_by_username(db, user.username)
|
||||
if user_db and user_db.id != user_id:
|
||||
raise HTTPException(status_code=409, detail="Username already exists")
|
||||
|
||||
db.add(new_user)
|
||||
user_db = get_user_by_id(db, user_id)
|
||||
if not user_db:
|
||||
raise HTTPException(status_code=404, detail="User not found")
|
||||
|
||||
try:
|
||||
user_data = user.dict(exclude_unset=True)
|
||||
for key, value in user_data.items():
|
||||
setattr(user_db, key, value)
|
||||
|
||||
user_db.updated_at = datetime.now(timezone.utc)
|
||||
user_db = db.merge(user_db)
|
||||
db.commit()
|
||||
db.refresh(new_user)
|
||||
if db.identity_key(instance=user_db) is not None:
|
||||
db.refresh(user_db)
|
||||
except IntegrityError as e:
|
||||
db.rollback()
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail="User exists",
|
||||
) from e
|
||||
raise HTTPException(status_code=400, detail=str(e)) from e
|
||||
|
||||
return new_user
|
||||
return user_db
|
||||
|
||||
|
||||
# TODO: Remove - Just for testing purposes
|
||||
@router.delete("/user/{user_id}")
|
||||
def delete_user(
|
||||
user_id: UUID,
|
||||
_: Session = Depends(get_current_active_user),
|
||||
db: Session = Depends(get_session),
|
||||
) -> dict:
|
||||
"""
|
||||
Delete a user from the database.
|
||||
"""
|
||||
user_db = db.query(User).filter(User.id == user_id).first()
|
||||
if not user_db:
|
||||
raise HTTPException(status_code=404, detail="User not found")
|
||||
|
||||
db.delete(user_db)
|
||||
db.commit()
|
||||
|
||||
return {"detail": "User deleted"}
|
||||
|
||||
|
||||
# TODO: REMOVE - Just for testing purposes
|
||||
@router.post("/super_user", response_model=User)
|
||||
def add_super_user_to_testing_purposes(db: Session = Depends(get_session)):
|
||||
def add_super_user_for_testing_purposes_delete_me_before_merge_into_dev(
|
||||
db: Session = Depends(get_session),
|
||||
) -> User:
|
||||
"""
|
||||
Add a superuser for testing purposes.
|
||||
(This should be removed in production)
|
||||
"""
|
||||
new_user = User(username="superuser", password="12345", is_superuser=True)
|
||||
|
||||
try:
|
||||
new_user.password = get_password_hash(new_user.password)
|
||||
|
||||
db.add(new_user)
|
||||
db.commit()
|
||||
db.refresh(new_user)
|
||||
except IntegrityError as e:
|
||||
db.rollback()
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail="User exists",
|
||||
) from e
|
||||
raise HTTPException(status_code=400, detail="User exists") from e
|
||||
|
||||
return new_user
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue