🐛 fix(utils.py): remove unnecessary type casting in get_user_by_username and get_user_by_id functions

🐛 fix(utils.py): fix update_user function to correctly update user attributes and handle username conflicts
This commit is contained in:
Gabriel Luiz Freitas Almeida 2023-08-25 10:16:04 -03:00
commit 92a7ae6be7

View file

@ -4,41 +4,41 @@ from fastapi import Depends, HTTPException
from langflow.services.database.models.user.user import User, UserUpdate
from langflow.services.utils import get_session
from sqlalchemy.exc import IntegrityError
from sqlalchemy.orm import Session
from sqlmodel import Session
from sqlalchemy.orm.attributes import flag_modified
def get_user_by_username(db: Session, username: str) -> User:
db_user = db.query(User).filter(User.username == username).first()
return User.from_orm(db_user) if db_user else None # type: ignore
return db.query(User).filter(User.username == username).first()
def get_user_by_id(db: Session, id: UUID) -> User:
db_user = db.query(User).filter(User.id == id).first()
return User.from_orm(db_user) if db_user else None # type: ignore
return db.query(User).filter(User.id == id).first()
def update_user(
user_id: UUID, user: UserUpdate, db: Session = Depends(get_session)
) -> User:
user_db = get_user_by_username(db, user.username) # type: ignore
if user_db and user_db.id != user_id:
raise HTTPException(status_code=409, detail="Username already exists")
user_db = get_user_by_id(db, user_id)
if not user_db:
raise HTTPException(status_code=404, detail="User not found")
user_db_by_username = get_user_by_username(db, user.username) # type: ignore
if user_db_by_username and user_db_by_username.id != user_id:
raise HTTPException(status_code=409, detail="Username already exists")
user_data = user.dict(exclude_unset=True)
for attr, value in user_data.items():
if hasattr(user_db, attr) and value is not None:
setattr(user_db, attr, value)
user_db.updated_at = datetime.now(timezone.utc)
flag_modified(user_db, "updated_at")
try:
user_data = user.dict(exclude_unset=True)
for key, value in user_data.items():
setattr(user_db, key, value)
user_db.updated_at = datetime.now(timezone.utc)
user_db = db.merge(user_db)
db.commit()
if db.identity_key(instance=user_db) is not None:
db.refresh(user_db)
except IntegrityError as e:
db.rollback()
raise HTTPException(status_code=400, detail=str(e)) from e