diff --git a/src/backend/langflow/auth/auth.py b/src/backend/langflow/auth/auth.py index 067184053..439e1017e 100644 --- a/src/backend/langflow/auth/auth.py +++ b/src/backend/langflow/auth/auth.py @@ -8,14 +8,14 @@ from datetime import datetime, timedelta, timezone from langflow.services.utils import get_session from langflow.database.models.token import TokenData -from langflow.database.models.user import get_user, User +from langflow.database.models.user import User, get_user_by_username -# TODO: Move to env - Test propose!!!!! +# TODO: Move to env - JUST FOR TEST!!!!! SECRET_KEY = "698619adad2d916f1f32d264540976964b3c0d3828e0870a65add5800a8cc6b9" ALGORITHM = "HS256" ACCESS_TOKEN_EXPIRE_MINUTES = 60 -REFRESH_TOKEN_EXPIRE_MINUTES = 180 +REFRESH_TOKEN_EXPIRE_MINUTES = 70 pwd_context = CryptContext(schemes=["bcrypt"], deprecated="auto") oauth2_scheme = OAuth2PasswordBearer(tokenUrl="login") @@ -29,29 +29,57 @@ def get_password_hash(password): return pwd_context.hash(password) -def create_access_token(data: dict, expires_delta: timedelta = None): # type: ignore +def create_token(data: dict, expires_delta: timedelta): to_encode = data.copy() - if expires_delta: - expire = datetime.now(timezone.utc) + expires_delta - else: - expire = datetime.now(timezone.utc) + timedelta( - minutes=ACCESS_TOKEN_EXPIRE_MINUTES - ) + + expire = datetime.now(timezone.utc) + expires_delta to_encode["exp"] = expire + return jwt.encode(to_encode, SECRET_KEY, algorithm=ALGORITHM) -def create_refresh_token(data: dict): - to_encode = data.copy() - expire = datetime.now(timezone.utc) + timedelta( - minutes=REFRESH_TOKEN_EXPIRE_MINUTES +def create_user_tokens(username: str) -> dict: + access_token_expires = timedelta(minutes=ACCESS_TOKEN_EXPIRE_MINUTES) + access_token = create_token( + data={"sub": username}, + expires_delta=access_token_expires, ) - to_encode["exp"] = expire - return jwt.encode(to_encode, SECRET_KEY, algorithm=ALGORITHM) + + refresh_token_expires = timedelta(minutes=REFRESH_TOKEN_EXPIRE_MINUTES) + refresh_token = create_token( + data={"sub": username, "type": "rf"}, + expires_delta=refresh_token_expires, + ) + + return { + "access_token": access_token, + "refresh_token": refresh_token, + "token_type": "bearer", + } + + +def create_refresh_token(refresh_token: str): + try: + payload = jwt.decode(refresh_token, SECRET_KEY, algorithms=[ALGORITHM]) + username: str = payload.get("sub") # type: ignore + token_type: str = payload.get("type") # type: ignore + + if username is None or token_type is None: + raise HTTPException( + status_code=status.HTTP_401_UNAUTHORIZED, detail="Invalid refresh token" + ) + + return create_user_tokens(username) + + except JWTError as e: + raise HTTPException( + status_code=status.HTTP_401_UNAUTHORIZED, + detail="Invalid refresh token", + ) from e def authenticate_user(db: Session, username: str, password: str): - if user := get_user(db, username): + if user := get_user_by_username(db, username): return user if verify_password(password, user.password) else False else: return False @@ -68,13 +96,15 @@ async def get_current_user( try: payload = jwt.decode(token, SECRET_KEY, algorithms=[ALGORITHM]) username: str = payload.get("sub") # type: ignore - if username is None: + token_type: str = payload.get("type") # type: ignore + + if username is None or token_type: raise credentials_exception token_data = TokenData(username=username) except JWTError as e: raise credentials_exception from e - user = get_user(db, token_data.username) # type: ignore + user = get_user_by_username(db, token_data.username) # type: ignore if user is None: raise credentials_exception return user diff --git a/src/backend/langflow/database/models/user.py b/src/backend/langflow/database/models/user.py index 144e71fae..75c38d98a 100644 --- a/src/backend/langflow/database/models/user.py +++ b/src/backend/langflow/database/models/user.py @@ -32,6 +32,17 @@ class UserListModel(SQLModel): updated_at: datetime = Field() -def get_user(db: Session, username: str) -> User: +class UserPatchModel(SQLModel): + username: str = Field() + is_disabled: bool = Field() + is_superuser: bool = Field() + + +def get_user_by_username(db: Session, username: str) -> User: db_user = db.query(User).filter(User.username == username).first() return User.from_orm(db_user) if db_user else None # type: ignore + + +def get_user_by_id(db: Session, id: UUID) -> User: + db_user = db.query(User).filter(User.id == id).first() + return User.from_orm(db_user) if db_user else None # type: ignore diff --git a/src/backend/langflow/routers/login.py b/src/backend/langflow/routers/login.py index 8108a2d18..b2814b262 100644 --- a/src/backend/langflow/routers/login.py +++ b/src/backend/langflow/routers/login.py @@ -1,48 +1,39 @@ -from datetime import timedelta - +from sqlalchemy.orm import Session from fastapi import APIRouter, Depends, HTTPException, status from fastapi.security import OAuth2PasswordRequestForm + +from langflow.services.utils import get_session from langflow.database.models.token import Token from langflow.auth.auth import ( - ACCESS_TOKEN_EXPIRE_MINUTES, authenticate_user, - create_access_token, + create_user_tokens, create_refresh_token, ) -from sqlalchemy.orm import Session -from langflow.services.utils import get_session -from langflow.database.models.user import User - - router = APIRouter() -def create_user_token(user: User) -> dict: - access_token_expires = timedelta(minutes=ACCESS_TOKEN_EXPIRE_MINUTES) - access_token = create_access_token( - data={"sub": user.username}, - expires_delta=access_token_expires, - ) - - refresh_token = create_refresh_token(data={"sub": user.username}) - - return { - "access_token": access_token, - "refresh_token": refresh_token, - "token_type": "bearer", - } - - @router.post("/login", response_model=Token) async def login_to_get_access_token( form_data: OAuth2PasswordRequestForm = Depends(), db: Session = Depends(get_session) ): if user := authenticate_user(db, form_data.username, form_data.password): - return create_user_token(user) + return create_user_tokens(user.username) else: raise HTTPException( status_code=status.HTTP_401_UNAUTHORIZED, detail="Incorrect username or password", headers={"WWW-Authenticate": "Bearer"}, ) + + +@router.post("/refresh") +async def refresh_token(token: str): + if token: + return create_refresh_token(token) + else: + raise HTTPException( + status_code=status.HTTP_401_UNAUTHORIZED, + detail="Invalid refresh token", + headers={"WWW-Authenticate": "Bearer"}, + ) diff --git a/src/backend/langflow/routers/users.py b/src/backend/langflow/routers/users.py index bcf1a7075..ab235e5c3 100644 --- a/src/backend/langflow/routers/users.py +++ b/src/backend/langflow/routers/users.py @@ -1,28 +1,29 @@ -from typing import List +from uuid import UUID from sqlmodel import Session, select +from datetime import timezone, datetime from sqlalchemy.exc import IntegrityError from fastapi import APIRouter, Depends, HTTPException from langflow.services.utils import get_session -from langflow.auth.auth import get_current_active_user -from langflow.database.models.user import UserAddModel, UserListModel, User - -from passlib.context import CryptContext +from langflow.auth.auth import get_current_active_user, get_password_hash +from langflow.database.models.user import ( + User, + UserAddModel, + UserListModel, + UserPatchModel, + get_user_by_id, + get_user_by_username, +) router = APIRouter(tags=["Login"]) -def get_password_hash(password): - pwd_context = CryptContext(schemes=["bcrypt"], deprecated="auto") - return pwd_context.hash(password) - - @router.get("/user", response_model=UserListModel) def read_current_user(current_user: User = Depends(get_current_active_user)): return current_user -@router.get("/users", response_model=List[UserListModel]) +@router.get("/users") def read_all_users( skip: int = 0, limit: int = 10, @@ -35,7 +36,7 @@ def read_all_users( return db.execute(query).fetchall() -@router.post("/user", response_model=User) +@router.post("/user", response_model=UserListModel) def add_user( user: UserAddModel, _: Session = Depends(get_current_active_user), @@ -50,6 +51,7 @@ def add_user( db.refresh(new_user) except IntegrityError as e: db.rollback() + raise HTTPException( status_code=400, detail="User exists", @@ -58,7 +60,62 @@ def add_user( return new_user -# TODO: Remove - Just for testing purposes +@router.patch("/user/{user_id}", response_model=UserListModel) +def update_user( + user_id: UUID, + user: UserPatchModel, + _: Session = Depends(get_current_active_user), + db: Session = Depends(get_session), +): + user_db = get_user_by_username(db, user.username) + if user_db and user_db.id != user_id: + raise HTTPException(status_code=409, detail="Username already exists") + + user_db = get_user_by_id(db, user_id) + if not user_db: + raise HTTPException(status_code=404, detail="User not found") + + try: + user_data = user.dict(exclude_unset=True) + + for key, value in user_data.items(): + setattr(user_db, key, value) + + user_db.updated_at = datetime.now(timezone.utc) + user_db = db.merge(user_db) + + db.commit() + + if db.identity_key(instance=user_db) is not None: + db.refresh(user_db) + except IntegrityError as e: + db.rollback() + + raise HTTPException( + status_code=400, + detail=str(e), + ) from e + + return user_db + + +@router.delete("/user/{user_id}") +def delete_user( + user_id: UUID, + _: Session = Depends(get_current_active_user), + db: Session = Depends(get_session), +): + user_db = db.query(User).filter(User.id == user_id).first() + if not user_db: + raise HTTPException(status_code=404, detail="User not found") + + db.delete(user_db) + db.commit() + + return {"detail": "User deleted"} + + +# TODO: REMOVE - Just for testing purposes @router.post("/super_user", response_model=User) def add_super_user_to_testing_purposes(db: Session = Depends(get_session)): new_user = User(username="superuser", password="12345", is_superuser=True)