diff --git a/src/backend/langflow/auth/auth.py b/src/backend/langflow/auth/auth.py index 439e1017e..36be8e14a 100644 --- a/src/backend/langflow/auth/auth.py +++ b/src/backend/langflow/auth/auth.py @@ -1,3 +1,4 @@ +from uuid import UUID from typing import Annotated from jose import JWTError, jwt from sqlalchemy.orm import Session @@ -7,8 +8,12 @@ from fastapi import Depends, HTTPException, status from datetime import datetime, timedelta, timezone from langflow.services.utils import get_session -from langflow.database.models.token import TokenData -from langflow.database.models.user import User, get_user_by_username +from langflow.database.models.user import ( + User, + get_user_by_id, + get_user_by_username, + update_user_last_login_at, +) # TODO: Move to env - JUST FOR TEST!!!!! @@ -21,6 +26,38 @@ pwd_context = CryptContext(schemes=["bcrypt"], deprecated="auto") oauth2_scheme = OAuth2PasswordBearer(tokenUrl="login") +async def get_current_user( + token: Annotated[str, Depends(oauth2_scheme)], db: Session = Depends(get_session) +) -> User: + credentials_exception = HTTPException( + status_code=status.HTTP_401_UNAUTHORIZED, + detail="Could not validate credentials", + headers={"WWW-Authenticate": "Bearer"}, + ) + try: + payload = jwt.decode(token, SECRET_KEY, algorithms=[ALGORITHM]) + user_id: UUID = payload.get("sub") # type: ignore + token_type: str = payload.get("type") # type: ignore + + if user_id is None or token_type: + raise credentials_exception + except JWTError as e: + raise credentials_exception from e + + user = get_user_by_id(db, user_id) # type: ignore + if user is None: + raise credentials_exception + return user + + +async def get_current_active_user( + current_user: Annotated[User, Depends(get_current_user)] +): + if not current_user.is_active: + raise HTTPException(status_code=400, detail="Inactive user") + return current_user + + def verify_password(plain_password, hashed_password): return pwd_context.verify(plain_password, hashed_password) @@ -38,19 +75,22 @@ def create_token(data: dict, expires_delta: timedelta): return jwt.encode(to_encode, SECRET_KEY, algorithm=ALGORITHM) -def create_user_tokens(username: str) -> dict: +def create_user_tokens(user_id: UUID, db: Session = Depends(get_session)) -> dict: access_token_expires = timedelta(minutes=ACCESS_TOKEN_EXPIRE_MINUTES) access_token = create_token( - data={"sub": username}, + data={"sub": str(user_id)}, expires_delta=access_token_expires, ) refresh_token_expires = timedelta(minutes=REFRESH_TOKEN_EXPIRE_MINUTES) refresh_token = create_token( - data={"sub": username, "type": "rf"}, + data={"sub": str(user_id), "type": "rf"}, expires_delta=refresh_token_expires, ) + # Update: last_login_at + update_user_last_login_at(user_id, db) + return { "access_token": access_token, "refresh_token": refresh_token, @@ -61,15 +101,15 @@ def create_user_tokens(username: str) -> dict: def create_refresh_token(refresh_token: str): try: payload = jwt.decode(refresh_token, SECRET_KEY, algorithms=[ALGORITHM]) - username: str = payload.get("sub") # type: ignore + user_id: UUID = payload.get("sub") # type: ignore token_type: str = payload.get("type") # type: ignore - if username is None or token_type is None: + if user_id is None or token_type is None: raise HTTPException( status_code=status.HTTP_401_UNAUTHORIZED, detail="Invalid refresh token" ) - return create_user_tokens(username) + return create_user_tokens(user_id) except JWTError as e: raise HTTPException( @@ -78,41 +118,10 @@ def create_refresh_token(refresh_token: str): ) from e -def authenticate_user(db: Session, username: str, password: str): +def authenticate_user( + username: str, password: str, db: Session = Depends(get_session) +) -> User | None: if user := get_user_by_username(db, username): - return user if verify_password(password, user.password) else False + return user if verify_password(password, user.password) else None else: - return False - - -async def get_current_user( - token: Annotated[str, Depends(oauth2_scheme)], db: Session = Depends(get_session) -) -> User: - credentials_exception = HTTPException( - status_code=status.HTTP_401_UNAUTHORIZED, - detail="Could not validate credentials", - headers={"WWW-Authenticate": "Bearer"}, - ) - try: - payload = jwt.decode(token, SECRET_KEY, algorithms=[ALGORITHM]) - username: str = payload.get("sub") # type: ignore - token_type: str = payload.get("type") # type: ignore - - if username is None or token_type: - raise credentials_exception - token_data = TokenData(username=username) - except JWTError as e: - raise credentials_exception from e - - user = get_user_by_username(db, token_data.username) # type: ignore - if user is None: - raise credentials_exception - return user - - -async def get_current_active_user( - current_user: Annotated[User, Depends(get_current_user)] -): - if current_user.is_disabled: - raise HTTPException(status_code=400, detail="Inactive user") - return current_user + return None diff --git a/src/backend/langflow/database/models/token.py b/src/backend/langflow/database/models/token.py index e78743877..68c70f07f 100644 --- a/src/backend/langflow/database/models/token.py +++ b/src/backend/langflow/database/models/token.py @@ -5,7 +5,3 @@ class Token(BaseModel): access_token: str refresh_token: str token_type: str - - -class TokenData(BaseModel): - username: str | None = None diff --git a/src/backend/langflow/database/models/user.py b/src/backend/langflow/database/models/user.py index 75c38d98a..94ceb4e15 100644 --- a/src/backend/langflow/database/models/user.py +++ b/src/backend/langflow/database/models/user.py @@ -1,41 +1,52 @@ -from datetime import datetime -from sqlalchemy.orm import Session - -from langflow.services.database.models.base import SQLModelSerializable, SQLModel from sqlmodel import Field from uuid import UUID, uuid4 +from pydantic import BaseModel +from typing import Optional, List +from sqlalchemy.orm import Session +from datetime import timezone, datetime +from sqlalchemy.exc import IntegrityError +from fastapi import HTTPException, Depends + +from langflow.services.utils import get_session +from langflow.services.database.models.base import SQLModelSerializable, SQLModel class User(SQLModelSerializable, table=True): id: UUID = Field(default_factory=uuid4, primary_key=True, unique=True) username: str = Field(index=True, unique=True) password: str = Field() - is_disabled: bool = Field(default=False) + is_active: bool = Field(default=False) is_superuser: bool = Field(default=False) create_at: datetime = Field(default_factory=datetime.utcnow) updated_at: datetime = Field(default_factory=datetime.utcnow) + last_login_at: Optional[datetime] = Field() class UserAddModel(SQLModel): username: str = Field() password: str = Field() - is_disabled: bool = Field(default=False) - is_superuser: bool = Field(default=False) class UserListModel(SQLModel): id: UUID = Field(default_factory=uuid4) username: str = Field() - is_disabled: bool = Field() + is_active: bool = Field() is_superuser: bool = Field() create_at: datetime = Field() updated_at: datetime = Field() + last_login_at: Optional[datetime] = Field() class UserPatchModel(SQLModel): - username: str = Field() - is_disabled: bool = Field() - is_superuser: bool = Field() + username: Optional[str] = Field() + is_active: Optional[bool] = Field() + is_superuser: Optional[bool] = Field() + last_login_at: Optional[datetime] = Field() + + +class UsersResponse(BaseModel): + total_count: int + users: List[UserListModel] def get_user_by_username(db: Session, username: str) -> User: @@ -46,3 +57,38 @@ def get_user_by_username(db: Session, username: str) -> User: def get_user_by_id(db: Session, id: UUID) -> User: db_user = db.query(User).filter(User.id == id).first() return User.from_orm(db_user) if db_user else None # type: ignore + + +def update_user( + user_id: UUID, user: UserPatchModel, db: Session = Depends(get_session) +) -> User: + user_db = get_user_by_username(db, user.username) # type: ignore + if user_db and user_db.id != user_id: + raise HTTPException(status_code=409, detail="Username already exists") + + user_db = get_user_by_id(db, user_id) + if not user_db: + raise HTTPException(status_code=404, detail="User not found") + + try: + user_data = user.dict(exclude_unset=True) + for key, value in user_data.items(): + setattr(user_db, key, value) + + user_db.updated_at = datetime.now(timezone.utc) + user_db = db.merge(user_db) + db.commit() + if db.identity_key(instance=user_db) is not None: + db.refresh(user_db) + + except IntegrityError as e: + db.rollback() + raise HTTPException(status_code=400, detail=str(e)) from e + + return user_db + + +def update_user_last_login_at(user_id: UUID, db: Session = Depends(get_session)): + user_data = UserPatchModel(last_login_at=datetime.now(timezone.utc)) # type: ignore + + return update_user(user_id, user_data, db) diff --git a/src/backend/langflow/routers/login.py b/src/backend/langflow/routers/login.py index b2814b262..35fdb9cdb 100644 --- a/src/backend/langflow/routers/login.py +++ b/src/backend/langflow/routers/login.py @@ -2,6 +2,7 @@ from sqlalchemy.orm import Session from fastapi import APIRouter, Depends, HTTPException, status from fastapi.security import OAuth2PasswordRequestForm + from langflow.services.utils import get_session from langflow.database.models.token import Token from langflow.auth.auth import ( @@ -15,10 +16,12 @@ router = APIRouter() @router.post("/login", response_model=Token) async def login_to_get_access_token( - form_data: OAuth2PasswordRequestForm = Depends(), db: Session = Depends(get_session) + form_data: OAuth2PasswordRequestForm = Depends(), + db: Session = Depends(get_session), + # _: Session = Depends(get_current_active_user) ): - if user := authenticate_user(db, form_data.username, form_data.password): - return create_user_tokens(user.username) + if user := authenticate_user(form_data.username, form_data.password, db): + return create_user_tokens(user_id=user.id, db=db) else: raise HTTPException( status_code=status.HTTP_401_UNAUTHORIZED, diff --git a/src/backend/langflow/routers/users.py b/src/backend/langflow/routers/users.py index 5cf20581b..da738a5cd 100644 --- a/src/backend/langflow/routers/users.py +++ b/src/backend/langflow/routers/users.py @@ -1,10 +1,10 @@ -from typing import List from uuid import UUID -from datetime import timezone, datetime +from sqlalchemy import func from sqlalchemy.exc import IntegrityError -from fastapi import APIRouter, Depends, HTTPException + from sqlmodel import Session, select +from fastapi import APIRouter, Depends, HTTPException from langflow.services.utils import get_session from langflow.auth.auth import get_current_active_user, get_password_hash @@ -13,8 +13,8 @@ from langflow.database.models.user import ( UserAddModel, UserListModel, UserPatchModel, - get_user_by_id, - get_user_by_username, + UsersResponse, + update_user, ) router = APIRouter(tags=["Login"]) @@ -23,7 +23,6 @@ router = APIRouter(tags=["Login"]) @router.post("/user", response_model=UserListModel) def add_user( user: UserAddModel, - _: Session = Depends(get_current_active_user), db: Session = Depends(get_session), ) -> User: """ @@ -32,6 +31,7 @@ def add_user( new_user = User(**user.dict()) try: new_user.password = get_password_hash(user.password) + db.add(new_user) db.commit() db.refresh(new_user) @@ -50,24 +50,30 @@ def read_current_user(current_user: User = Depends(get_current_active_user)) -> return current_user -@router.get("/users", response_model=List[UserListModel]) +@router.get("/users", response_model=UsersResponse) def read_all_users( skip: int = 0, limit: int = 10, _: Session = Depends(get_current_active_user), db: Session = Depends(get_session), -) -> List[UserListModel]: +) -> UsersResponse: """ Retrieve a list of users from the database with pagination. """ query = select(User).offset(skip).limit(limit) users = db.execute(query).fetchall() - return [UserListModel(**dict(user.User)) for user in users] + count_query = select(func.count()).select_from(User) # type: ignore + total_count = db.execute(count_query).scalar() + + return UsersResponse( + total_count=total_count, # type: ignore + users=[UserListModel(**dict(user.User)) for user in users], + ) @router.patch("/user/{user_id}", response_model=UserListModel) -def update_user( +def patch_user( user_id: UUID, user: UserPatchModel, _: Session = Depends(get_current_active_user), @@ -76,29 +82,7 @@ def update_user( """ Update an existing user's data. """ - user_db = get_user_by_username(db, user.username) - if user_db and user_db.id != user_id: - raise HTTPException(status_code=409, detail="Username already exists") - - user_db = get_user_by_id(db, user_id) - if not user_db: - raise HTTPException(status_code=404, detail="User not found") - - try: - user_data = user.dict(exclude_unset=True) - for key, value in user_data.items(): - setattr(user_db, key, value) - - user_db.updated_at = datetime.now(timezone.utc) - user_db = db.merge(user_db) - db.commit() - if db.identity_key(instance=user_db) is not None: - db.refresh(user_db) - except IntegrityError as e: - db.rollback() - raise HTTPException(status_code=400, detail=str(e)) from e - - return user_db + return update_user(user_id, user, db) @router.delete("/user/{user_id}") @@ -129,7 +113,13 @@ def add_super_user_for_testing_purposes_delete_me_before_merge_into_dev( Add a superuser for testing purposes. (This should be removed in production) """ - new_user = User(username="superuser", password="12345", is_superuser=True) + new_user = User( + username="superuser", + password="12345", + is_active=True, + is_superuser=True, + last_login_at=None, + ) try: new_user.password = get_password_hash(new_user.password)