Adds Tests for Login, Users and API keys (#821)

This commit is contained in:
Gabriel Luiz Freitas Almeida 2023-08-25 15:27:02 +00:00 committed by GitHub
commit eab34e2fdc
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
44 changed files with 998 additions and 269 deletions

View file

@ -0,0 +1,100 @@
"""Add ApiKey table
Revision ID: 5512e39b4012
Revises: 0a534bdfd84b
Create Date: 2023-08-23 21:05:51.042203
"""
import contextlib
from typing import Sequence, Union
from alembic import op
import sqlalchemy as sa
import sqlmodel
# revision identifiers, used by Alembic.
revision: str = "5512e39b4012"
down_revision: Union[str, None] = "0a534bdfd84b"
branch_labels: Union[str, Sequence[str], None] = None
depends_on: Union[str, Sequence[str], None] = None
def upgrade() -> None:
# ### commands auto generated by Alembic - please adjust! ###
with contextlib.suppress(sa.exc.OperationalError):
op.create_table(
"apikey",
sa.Column("api_key", sqlmodel.sql.sqltypes.AutoString(), nullable=False),
sa.Column("name", sqlmodel.sql.sqltypes.AutoString(), nullable=False),
sa.Column("create_at", sa.DateTime(), nullable=False),
sa.Column("last_used_at", sa.DateTime(), nullable=True),
sa.Column("id", sqlmodel.sql.sqltypes.GUID(), nullable=False),
sa.PrimaryKeyConstraint("id"),
sa.UniqueConstraint("id"),
)
op.create_index(op.f("ix_apikey_api_key"), "apikey", ["api_key"], unique=True)
with contextlib.suppress(sa.exc.OperationalError):
op.create_table(
"user",
sa.Column("id", sqlmodel.sql.sqltypes.GUID(), nullable=False),
sa.Column("username", sqlmodel.sql.sqltypes.AutoString(), nullable=False),
sa.Column("password", sqlmodel.sql.sqltypes.AutoString(), nullable=False),
sa.Column("is_active", sa.Boolean(), nullable=False),
sa.Column("is_superuser", sa.Boolean(), nullable=False),
sa.Column("create_at", sa.DateTime(), nullable=False),
sa.Column("updated_at", sa.DateTime(), nullable=False),
sa.Column("last_login_at", sa.DateTime(), nullable=True),
sa.PrimaryKeyConstraint("id"),
sa.UniqueConstraint("id"),
)
op.create_index(op.f("ix_user_username"), "user", ["username"], unique=True)
with contextlib.suppress(sa.exc.OperationalError):
op.drop_table("flowstyle")
with contextlib.suppress(sa.exc.OperationalError):
op.drop_index("ix_component_frontend_node_id", table_name="component")
op.drop_index("ix_component_name", table_name="component")
op.drop_table("component")
# ### end Alembic commands ###
def downgrade() -> None:
# ### commands auto generated by Alembic - please adjust! ###
op.create_table(
"component",
sa.Column("id", sa.CHAR(length=32), nullable=False),
sa.Column("frontend_node_id", sa.CHAR(length=32), nullable=False),
sa.Column("name", sa.VARCHAR(), nullable=False),
sa.Column("description", sa.VARCHAR(), nullable=True),
sa.Column("python_code", sa.VARCHAR(), nullable=True),
sa.Column("return_type", sa.VARCHAR(), nullable=True),
sa.Column("is_disabled", sa.BOOLEAN(), nullable=False),
sa.Column("is_read_only", sa.BOOLEAN(), nullable=False),
sa.Column("create_at", sa.DATETIME(), nullable=False),
sa.Column("update_at", sa.DATETIME(), nullable=False),
sa.PrimaryKeyConstraint("id"),
)
op.create_index("ix_component_name", "component", ["name"], unique=False)
op.create_index(
"ix_component_frontend_node_id", "component", ["frontend_node_id"], unique=False
)
op.create_table(
"flowstyle",
sa.Column("color", sa.VARCHAR(), nullable=False),
sa.Column("emoji", sa.VARCHAR(), nullable=False),
sa.Column("flow_id", sa.CHAR(length=32), nullable=True),
sa.Column("id", sa.CHAR(length=32), nullable=False),
sa.ForeignKeyConstraint(
["flow_id"],
["flow.id"],
),
sa.PrimaryKeyConstraint("id"),
sa.UniqueConstraint("id"),
)
op.drop_index(op.f("ix_user_username"), table_name="user")
op.drop_table("user")
op.drop_index(op.f("ix_apikey_api_key"), table_name="apikey")
op.drop_table("apikey")
# ### end Alembic commands ###

View file

@ -6,6 +6,9 @@ from langflow.api.v1 import (
validate_router,
flows_router,
component_router,
users_router,
api_key_router,
login_router,
)
router = APIRouter(
@ -16,3 +19,6 @@ router.include_router(endpoints_router)
router.include_router(validate_router)
router.include_router(component_router)
router.include_router(flows_router)
router.include_router(users_router)
router.include_router(api_key_router)
router.include_router(login_router)

View file

@ -3,6 +3,9 @@ from langflow.api.v1.validate import router as validate_router
from langflow.api.v1.chat import router as chat_router
from langflow.api.v1.flows import router as flows_router
from langflow.api.v1.components import router as component_router
from langflow.api.v1.users import router as users_router
from langflow.api.v1.api_key import router as api_key_router
from langflow.api.v1.login import router as login_router
__all__ = [
"chat_router",
@ -10,4 +13,7 @@ __all__ = [
"component_router",
"validate_router",
"flows_router",
"users_router",
"api_key_router",
"login_router",
]

View file

@ -0,0 +1,61 @@
from uuid import UUID
from fastapi import APIRouter, HTTPException, Depends
from langflow.api.v1.schemas import ApiKeysResponse
from langflow.services.auth.utils import get_current_active_user
from langflow.services.database.models.api_key.api_key import (
ApiKeyCreate,
UnmaskedApiKeyRead,
)
# Assuming you have these methods in your service layer
from langflow.services.database.models.api_key.crud import (
get_api_keys,
create_api_key,
delete_api_key,
)
from langflow.services.database.models.user.user import User
from langflow.services.utils import get_session
from sqlmodel import Session
router = APIRouter(tags=["APIKey"])
@router.get("/api_key", response_model=ApiKeysResponse)
def get_api_keys_route(
db: Session = Depends(get_session),
current_user: User = Depends(get_current_active_user),
):
try:
user_id = current_user.id
keys = get_api_keys(db, user_id)
return ApiKeysResponse(total_count=len(keys), user_id=user_id, api_keys=keys)
except Exception as e:
raise HTTPException(status_code=400, detail=str(e))
@router.post("/api_key", response_model=UnmaskedApiKeyRead)
def create_api_key_route(
req: ApiKeyCreate,
current_user: User = Depends(get_current_active_user),
db: Session = Depends(get_session),
):
try:
user_id = current_user.id
return create_api_key(db, req, user_id=user_id)
except Exception as e:
raise HTTPException(status_code=400, detail=str(e)) from e
@router.delete("/api_key/{api_key_id}")
def delete_api_key_route(
api_key_id: UUID,
current_user=Depends(get_current_active_user),
db: Session = Depends(get_session),
):
try:
delete_api_key(db, api_key_id)
return {"detail": "API Key deleted"}
except Exception as e:
raise HTTPException(status_code=400, detail=str(e))

View file

@ -58,8 +58,12 @@ def get_all():
logger.info(f"Loading {len(custom_component_dicts)} category(ies)")
for custom_component_dict in custom_component_dicts:
logger.debug(
{key: len(value) for key, value in custom_component_dict.items()}
# custom_component_dict is a dict of dicts
if not custom_component_dict:
continue
category = list(custom_component_dict.keys())[0]
logger.info(
f"Loading {len(custom_component_dict[category])} component(s) from category {category}"
)
custom_components_from_file = merge_nested_dicts_with_renaming(
custom_components_from_file, custom_component_dict

View file

@ -1,10 +1,10 @@
from sqlalchemy.orm import Session
from sqlmodel import Session
from fastapi import APIRouter, Depends, HTTPException, status
from fastapi.security import OAuth2PasswordRequestForm
from langflow.services.utils import get_session
from langflow.database.models.token import Token
from langflow.auth.auth import (
from langflow.services.database.models import Token
from langflow.services.auth.utils import (
authenticate_user,
create_user_tokens,
create_refresh_token,
@ -13,7 +13,7 @@ from langflow.auth.auth import (
from langflow.services.utils import get_settings_manager
router = APIRouter()
router = APIRouter(tags=["Login"])
@router.post("/login", response_model=Token)
@ -36,7 +36,7 @@ async def login_to_get_access_token(
async def auto_login(db: Session = Depends(get_session)):
settings_manager = get_settings_manager()
if settings_manager.settings.AUTO_LOGIN:
if settings_manager.auth_settings.AUTO_LOGIN:
return create_user_longterm_token(db)
raise HTTPException(

View file

@ -1,7 +1,10 @@
from enum import Enum
from pathlib import Path
from typing import Any, Dict, List, Optional, Union
from uuid import UUID
from langflow.services.database.models.api_key.api_key import ApiKeyRead
from langflow.services.database.models.flow import FlowCreate, FlowRead
from langflow.services.database.models.user import UserRead
from langflow.services.database.models.base import orjson_dumps
from pydantic import BaseModel, Field, validator
@ -137,3 +140,26 @@ class ComponentListCreate(BaseModel):
class ComponentListRead(BaseModel):
flows: List[FlowRead]
class UsersResponse(BaseModel):
total_count: int
users: List[UserRead]
class ApiKeyResponse(BaseModel):
id: str
api_key: str
name: str
created_at: str
last_used_at: str
class ApiKeysResponse(BaseModel):
total_count: int
user_id: UUID
api_keys: List[ApiKeyRead]
class CreateApiKeyRequest(BaseModel):
name: str

View file

@ -1,4 +1,11 @@
from uuid import UUID
from langflow.api.v1.schemas import UsersResponse
from langflow.services.database.models.user import (
User,
UserCreate,
UserRead,
UserUpdate,
)
from sqlalchemy import func
from sqlalchemy.exc import IntegrityError
@ -7,28 +14,27 @@ from sqlmodel import Session, select
from fastapi import APIRouter, Depends, HTTPException
from langflow.services.utils import get_session
from langflow.auth.auth import get_current_active_user, get_password_hash
from langflow.database.models.user import (
User,
UserAddModel,
UserListModel,
UserPatchModel,
UsersResponse,
from langflow.services.auth.utils import (
get_current_active_superuser,
get_current_active_user,
get_password_hash,
)
from langflow.services.database.models.user.crud import (
update_user,
)
router = APIRouter(tags=["Login"])
router = APIRouter(tags=["Users"])
@router.post("/user", response_model=UserListModel)
@router.post("/user", response_model=UserRead, status_code=201)
def add_user(
user: UserAddModel,
user: UserCreate,
db: Session = Depends(get_session),
) -> User:
"""
Add a new user to the database.
"""
new_user = User(**user.dict())
new_user = User.from_orm(user)
try:
new_user.password = get_password_hash(user.password)
@ -42,8 +48,10 @@ def add_user(
return new_user
@router.get("/user", response_model=UserListModel)
def read_current_user(current_user: User = Depends(get_current_active_user)) -> User:
@router.get("/user", response_model=UserRead)
def read_current_user(
current_user: User = Depends(get_current_active_user),
) -> User:
"""
Retrieve the current user's data.
"""
@ -54,7 +62,7 @@ def read_current_user(current_user: User = Depends(get_current_active_user)) ->
def read_all_users(
skip: int = 0,
limit: int = 10,
_: Session = Depends(get_current_active_user),
current_user: Session = Depends(get_current_active_superuser),
db: Session = Depends(get_session),
) -> UsersResponse:
"""
@ -68,14 +76,14 @@ def read_all_users(
return UsersResponse(
total_count=total_count, # type: ignore
users=[UserListModel(**dict(user.User)) for user in users],
users=[UserRead(**dict(user.User)) for user in users],
)
@router.patch("/user/{user_id}", response_model=UserListModel)
@router.patch("/user/{user_id}", response_model=UserRead)
def patch_user(
user_id: UUID,
user: UserPatchModel,
user: UserUpdate,
_: Session = Depends(get_current_active_user),
db: Session = Depends(get_session),
) -> User:
@ -88,12 +96,21 @@ def patch_user(
@router.delete("/user/{user_id}")
def delete_user(
user_id: UUID,
_: Session = Depends(get_current_active_user),
current_user: User = Depends(get_current_active_superuser),
db: Session = Depends(get_session),
) -> dict:
"""
Delete a user from the database.
"""
if current_user.id == user_id:
raise HTTPException(
status_code=400, detail="You can't delete your own user account"
)
elif not current_user.is_superuser:
raise HTTPException(
status_code=403, detail="You don't have the permission to delete this user"
)
user_db = db.query(User).filter(User.id == user_id).first()
if not user_db:
raise HTTPException(status_code=404, detail="User not found")

View file

@ -1,94 +0,0 @@
from sqlmodel import Field
from uuid import UUID, uuid4
from pydantic import BaseModel
from typing import Optional, List
from sqlalchemy.orm import Session
from datetime import timezone, datetime
from sqlalchemy.exc import IntegrityError
from fastapi import HTTPException, Depends
from langflow.services.utils import get_session
from langflow.services.database.models.base import SQLModelSerializable, SQLModel
class User(SQLModelSerializable, table=True):
id: UUID = Field(default_factory=uuid4, primary_key=True, unique=True)
username: str = Field(index=True, unique=True)
password: str = Field()
is_active: bool = Field(default=False)
is_superuser: bool = Field(default=False)
create_at: datetime = Field(default_factory=datetime.utcnow)
updated_at: datetime = Field(default_factory=datetime.utcnow)
last_login_at: Optional[datetime] = Field()
class UserAddModel(SQLModel):
username: str = Field()
password: str = Field()
class UserListModel(SQLModel):
id: UUID = Field(default_factory=uuid4)
username: str = Field()
is_active: bool = Field()
is_superuser: bool = Field()
create_at: datetime = Field()
updated_at: datetime = Field()
last_login_at: Optional[datetime] = Field()
class UserPatchModel(SQLModel):
username: Optional[str] = Field()
is_active: Optional[bool] = Field()
is_superuser: Optional[bool] = Field()
last_login_at: Optional[datetime] = Field()
class UsersResponse(BaseModel):
total_count: int
users: List[UserListModel]
def get_user_by_username(db: Session, username: str) -> User:
db_user = db.query(User).filter(User.username == username).first()
return User.from_orm(db_user) if db_user else None # type: ignore
def get_user_by_id(db: Session, id: UUID) -> User:
db_user = db.query(User).filter(User.id == id).first()
return User.from_orm(db_user) if db_user else None # type: ignore
def update_user(
user_id: UUID, user: UserPatchModel, db: Session = Depends(get_session)
) -> User:
user_db = get_user_by_username(db, user.username) # type: ignore
if user_db and user_db.id != user_id:
raise HTTPException(status_code=409, detail="Username already exists")
user_db = get_user_by_id(db, user_id)
if not user_db:
raise HTTPException(status_code=404, detail="User not found")
try:
user_data = user.dict(exclude_unset=True)
for key, value in user_data.items():
setattr(user_db, key, value)
user_db.updated_at = datetime.now(timezone.utc)
user_db = db.merge(user_db)
db.commit()
if db.identity_key(instance=user_db) is not None:
db.refresh(user_db)
except IntegrityError as e:
db.rollback()
raise HTTPException(status_code=400, detail=str(e)) from e
return user_db
def update_user_last_login_at(user_id: UUID, db: Session = Depends(get_session)):
user_data = UserPatchModel(last_login_at=datetime.now(timezone.utc)) # type: ignore
return update_user(user_id, user_data, db)

View file

@ -6,7 +6,7 @@ from fastapi.responses import FileResponse
from fastapi.staticfiles import StaticFiles
from langflow.api import router
from langflow.routers import api_key, login, users, health
from langflow.interface.utils import setup_llm_caching
from langflow.services.database.utils import initialize_database
@ -31,11 +31,6 @@ def create_app():
allow_headers=["*"],
)
app.include_router(login.router)
app.include_router(api_key.router)
app.include_router(users.router)
app.include_router(health.router)
app.include_router(router)
app.on_event("startup")(initialize_services)

View file

@ -1,49 +0,0 @@
from fastapi import APIRouter
router = APIRouter(tags=["APIKey"])
@router.get("/api_key/{user_id}")
def get_api_key(user_id: str):
return {
"total_count": 3,
"user_id": user_id,
"api_keys": [
{
"id": "4425707e-cce4-4d1b-a54e-bd2632064657",
"api_key": "lf-...abcd",
"name": "my api_key name - 01",
"created_at": "2023-08-15T19:28:40.019613",
"last_used_at": "2023-08-16T18:38:20.875210",
},
{
"id": "6fb7282b-9f2e-4efe-9bda-0c3d8f899473",
"api_key": "lf-...abcd",
"name": "my api_key name - 02",
"created_at": "2023-08-15T19:41:30.077942",
"last_used_at": "2023-08-15T19:45:32.067899",
},
{
"id": "c55f3b32-4920-42b6-a5cd-698b4251806e",
"api_key": "lf-...abcd",
"name": "my api_key name - 03",
"created_at": "2023-08-15T20:29:40.577808",
"last_used_at": "2023-08-15T20:29:40.577816",
},
],
}
@router.post("/api_key/{user_id}")
def create_api_key(user_id: str):
return {
"user_id": user_id,
"name": "my api-key 01",
"api_key": "lf-eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJzdWIiOiI1YTBmODM1ZS0yMTQxLTQ2YWItYmQ4NS0yMWEzMjQ1MTE2ZDAiLCJleHAiOjE2OTIyMTUwMTN9.c_s0ZPRtjSI9yUrhi8ACIwyXf0feRLYfaeIZEbRVKQg",
}
@router.delete("/api_key/{api_key_id}")
def delete_api_key(api_key_id: str):
return {"detail": "API Key deleted"}

View file

@ -1,8 +0,0 @@
from fastapi import APIRouter
router = APIRouter()
@router.get("/health")
def get_health():
return {"status": "OK"}

View file

@ -0,0 +1,12 @@
from langflow.services.factory import ServiceFactory
from langflow.services.auth.service import AuthManager
class AuthManagerFactory(ServiceFactory):
name = "auth_manager"
def __init__(self):
super().__init__(AuthManager)
def create(self, settings_manager):
return AuthManager(settings_manager)

View file

@ -0,0 +1,18 @@
from fastapi import Request
from langflow.services.base import Service
from typing import TYPE_CHECKING
if TYPE_CHECKING:
from langflow.services.settings.manager import SettingsManager
class AuthManager(Service):
name = "auth_manager"
def __init__(self, settings_manager: "SettingsManager"):
self.settings_manager = settings_manager
# We need to define a function that can be passed to the Depends() function.
# This function will be called by FastAPI to run oauth2_scheme
def run_oauth2_scheme(self, request: Request):
return self.settings_manager.auth_settings.oauth2_scheme(request=request)

View file

@ -1,28 +1,30 @@
from uuid import UUID
from typing import Annotated
from jose import JWTError, jwt
from sqlalchemy.orm import Session
from passlib.context import CryptContext
from fastapi.security import OAuth2PasswordBearer
from fastapi import Depends, HTTPException, status
from datetime import datetime, timedelta, timezone
from langflow.services.utils import get_settings_manager, get_session
from langflow.database.models.user import (
User,
from fastapi import Depends, HTTPException, Request, status
from jose import JWTError, jwt
from typing import Annotated, Coroutine
from uuid import UUID
from langflow.services.auth.service import AuthManager
from langflow.services.database.models.user.user import User
from langflow.services.database.models.user.crud import (
get_user_by_id,
get_user_by_username,
update_user_last_login_at,
)
from langflow.services.utils import get_session, get_settings_manager
from sqlmodel import Session
pwd_context = CryptContext(schemes=["bcrypt"], deprecated="auto")
oauth2_scheme = OAuth2PasswordBearer(tokenUrl="login")
def auth_scheme_dependency(request: Request):
settings_manager = (
get_settings_manager()
) # Assuming get_settings_manager is defined
return AuthManager(settings_manager).run_oauth2_scheme(request)
async def get_current_user(
token: Annotated[str, Depends(oauth2_scheme)], db: Session = Depends(get_session)
token: Annotated[str, Depends(auth_scheme_dependency)],
db: Session = Depends(get_session),
) -> User:
settings_manager = get_settings_manager()
@ -32,11 +34,14 @@ async def get_current_user(
headers={"WWW-Authenticate": "Bearer"},
)
if isinstance(token, Coroutine):
token = await token
try:
payload = jwt.decode(
token,
settings_manager.settings.SECRET_KEY,
algorithms=[settings_manager.settings.ALGORITHM],
settings_manager.auth_settings.SECRET_KEY,
algorithms=[settings_manager.auth_settings.ALGORITHM],
)
user_id: UUID = payload.get("sub") # type: ignore
token_type: str = payload.get("type") # type: ignore
@ -47,25 +52,39 @@ async def get_current_user(
raise credentials_exception from e
user = get_user_by_id(db, user_id) # type: ignore
if user is None:
if user is None or not user.is_active:
raise credentials_exception
return user
async def get_current_active_user(
current_user: Annotated[User, Depends(get_current_user)]
):
def get_current_active_user(current_user: Annotated[User, Depends(get_current_user)]):
if not current_user.is_active:
raise HTTPException(status_code=400, detail="Inactive user")
return current_user
def get_current_active_superuser(
current_user: Annotated[User, Depends(get_current_user)]
) -> User:
if not current_user.is_active:
raise HTTPException(status_code=401, detail="Inactive user")
if not current_user.is_superuser:
raise HTTPException(
status_code=400, detail="The user doesn't have enough privileges"
)
return current_user
def verify_password(plain_password, hashed_password):
return pwd_context.verify(plain_password, hashed_password)
settings_manager = get_settings_manager()
return settings_manager.auth_settings.pwd_context.verify(
plain_password, hashed_password
)
def get_password_hash(password):
return pwd_context.hash(password)
settings_manager = get_settings_manager()
return settings_manager.auth_settings.pwd_context.hash(password)
def create_token(data: dict, expires_delta: timedelta):
@ -77,21 +96,23 @@ def create_token(data: dict, expires_delta: timedelta):
return jwt.encode(
to_encode,
settings_manager.settings.SECRET_KEY,
algorithm=settings_manager.settings.ALGORITHM,
settings_manager.auth_settings.SECRET_KEY,
algorithm=settings_manager.auth_settings.ALGORITHM,
)
def create_super_user(db: Session = Depends(get_session)) -> User:
settings_manager = get_settings_manager()
super_user = get_user_by_username(db, settings_manager.settings.FIRST_SUPERUSER)
super_user = get_user_by_username(
db, settings_manager.auth_settings.FIRST_SUPERUSER
)
if not super_user:
super_user = User(
username=settings_manager.settings.FIRST_SUPERUSER,
username=settings_manager.auth_settings.FIRST_SUPERUSER,
password=get_password_hash(
settings_manager.settings.FIRST_SUPERUSER_PASSWORD
settings_manager.auth_settings.FIRST_SUPERUSER_PASSWORD
),
is_superuser=True,
is_active=True,
@ -147,7 +168,7 @@ def create_user_tokens(
settings_manager = get_settings_manager()
access_token_expires = timedelta(
minutes=settings_manager.settings.ACCESS_TOKEN_EXPIRE_MINUTES
minutes=settings_manager.auth_settings.ACCESS_TOKEN_EXPIRE_MINUTES
)
access_token = create_token(
data={"sub": str(user_id)},
@ -155,7 +176,7 @@ def create_user_tokens(
)
refresh_token_expires = timedelta(
minutes=settings_manager.settings.REFRESH_TOKEN_EXPIRE_MINUTES
minutes=settings_manager.auth_settings.REFRESH_TOKEN_EXPIRE_MINUTES
)
refresh_token = create_token(
data={"sub": str(user_id), "type": "rf"},
@ -179,8 +200,8 @@ def create_refresh_token(refresh_token: str, db: Session = Depends(get_session))
try:
payload = jwt.decode(
refresh_token,
settings_manager.settings.SECRET_KEY,
algorithms=[settings_manager.settings.ALGORITHM],
settings_manager.auth_settings.SECRET_KEY,
algorithms=[settings_manager.auth_settings.ALGORITHM],
)
user_id: UUID = payload.get("sub") # type: ignore
token_type: str = payload.get("type") # type: ignore

View file

@ -6,6 +6,6 @@ class CacheManagerFactory(ServiceFactory):
def __init__(self):
super().__init__(CacheManager)
def create(self, settings_service):
def create(self):
# Here you would have logic to create and configure a CacheManager
return CacheManager()

View file

@ -6,6 +6,6 @@ class ChatManagerFactory(ServiceFactory):
def __init__(self):
super().__init__(ChatManager)
def create(self, settings_service):
def create(self):
# Here you would have logic to create and configure a ChatManager
return ChatManager()

View file

@ -10,8 +10,8 @@ class DatabaseManagerFactory(ServiceFactory):
def __init__(self):
super().__init__(DatabaseManager)
def create(self, settings_service: "SettingsManager"):
def create(self, settings_manager: "SettingsManager"):
# Here you would have logic to create and configure a DatabaseManager
if not settings_service.settings.DATABASE_URL:
if not settings_manager.settings.DATABASE_URL:
raise ValueError("No database URL provided")
return DatabaseManager(settings_service.settings.DATABASE_URL)
return DatabaseManager(settings_manager.settings.DATABASE_URL)

View file

@ -1,4 +1,6 @@
from .flow import Flow
from .user import User
from .token import Token
from .api_key import ApiKey
__all__ = ["Flow"]
__all__ = ["Flow", "User", "Token", "ApiKey"]

View file

@ -0,0 +1,3 @@
from .api_key import ApiKey, ApiKeyCreate, UnmaskedApiKeyRead, ApiKeyRead
__all__ = ["ApiKey", "ApiKeyCreate", "UnmaskedApiKeyRead", "ApiKeyRead"]

View file

@ -0,0 +1,45 @@
from pydantic import validator
from sqlmodel import Field, Relationship
from uuid import UUID, uuid4
from typing import Optional, TYPE_CHECKING
from datetime import datetime
from langflow.services.database.models.base import SQLModelSerializable
if TYPE_CHECKING:
from langflow.services.database.models.user import User
class ApiKeyBase(SQLModelSerializable):
name: Optional[str] = Field(index=True)
created_at: datetime = Field(default_factory=datetime.utcnow)
last_used_at: Optional[datetime] = Field(default=None)
class ApiKey(ApiKeyBase, table=True):
id: UUID = Field(default_factory=uuid4, primary_key=True, unique=True)
api_key: str = Field(index=True, unique=True)
# User relationship
user_id: UUID = Field(index=True, foreign_key="user.id")
user: "User" = Relationship(back_populates="api_keys")
class ApiKeyCreate(ApiKeyBase):
api_key: Optional[str] = None
user_id: Optional[UUID] = None
class UnmaskedApiKeyRead(ApiKeyBase):
id: UUID
api_key: str = Field(index=True, unique=True)
user_id: UUID = Field()
class ApiKeyRead(ApiKeyBase):
id: UUID
api_key: str = Field(index=True, unique=True)
user_id: UUID = Field()
@validator("api_key", always=True)
def mask_api_key(cls, v):
# This validator will always run, and will mask the API key
return f"{'*' * 8}{v[-4:]}"

View file

@ -0,0 +1,45 @@
import secrets
from uuid import UUID
from typing import List
from langflow.services.auth.utils import get_password_hash
from sqlmodel import Session, select
from langflow.services.database.models.api_key import (
ApiKey,
ApiKeyCreate,
UnmaskedApiKeyRead,
ApiKeyRead,
)
def get_api_keys(session: Session, user_id: UUID) -> List[ApiKeyRead]:
query = select(ApiKey).where(ApiKey.user_id == user_id)
api_keys = session.exec(query).all()
return [ApiKeyRead.from_orm(api_key) for api_key in api_keys]
def create_api_key(
session: Session, api_key_create: ApiKeyCreate, user_id: UUID
) -> UnmaskedApiKeyRead:
# Generate a random API key with 32 bytes of randomness
generated_api_key = secrets.token_urlsafe(32)
# hash the API key
hashed_api_key = get_password_hash(generated_api_key)
# Use the generated key to create the ApiKey object
api_key = ApiKey(api_key=hashed_api_key, name=api_key_create.name, user_id=user_id)
session.add(api_key)
session.commit()
session.refresh(api_key)
unmasked = UnmaskedApiKeyRead.from_orm(api_key)
unmasked.api_key = generated_api_key
return unmasked
def delete_api_key(session: Session, api_key_id: UUID) -> None:
api_key = session.get(ApiKey, api_key_id)
if api_key is None:
raise ValueError("API Key not found")
session.delete(api_key)
session.commit()

View file

@ -0,0 +1,3 @@
from .component import Component, ComponentModel
__all__ = ["Component", "ComponentModel"]

View file

@ -0,0 +1,3 @@
from .flow import Flow, FlowCreate, FlowRead, FlowUpdate
__all__ = ["Flow", "FlowCreate", "FlowRead", "FlowUpdate"]

View file

@ -6,8 +6,6 @@ from sqlmodel import Field, JSON, Column
from uuid import UUID, uuid4
from typing import Dict, Optional
# if TYPE_CHECKING:
class FlowBase(SQLModelSerializable):
name: str = Field(index=True)
@ -16,7 +14,6 @@ class FlowBase(SQLModelSerializable):
@validator("data")
def validate_json(v):
# dict_keys(['description', 'name', 'id', 'data'])
if not v:
return v
if not isinstance(v, dict):

View file

@ -0,0 +1,5 @@
from .token import Token
__all__ = [
"Token",
]

View file

@ -0,0 +1,8 @@
from .user import User, UserCreate, UserRead, UserUpdate
__all__ = [
"User",
"UserCreate",
"UserRead",
"UserUpdate",
]

View file

@ -0,0 +1,53 @@
from datetime import datetime, timezone
from typing import Union
from uuid import UUID
from fastapi import Depends, HTTPException
from langflow.services.database.models.user.user import User, UserUpdate
from langflow.services.utils import get_session
from sqlalchemy.exc import IntegrityError
from sqlmodel import Session
from sqlalchemy.orm.attributes import flag_modified
def get_user_by_username(db: Session, username: str) -> Union[User, None]:
return db.query(User).filter(User.username == username).first()
def get_user_by_id(db: Session, id: UUID) -> Union[User, None]:
return db.query(User).filter(User.id == id).first()
def update_user(
user_id: UUID, user: UserUpdate, db: Session = Depends(get_session)
) -> User:
user_db = get_user_by_id(db, user_id)
if not user_db:
raise HTTPException(status_code=404, detail="User not found")
user_db_by_username = get_user_by_username(db, user.username) # type: ignore
if user_db_by_username and user_db_by_username.id != user_id:
raise HTTPException(status_code=409, detail="Username already exists")
user_data = user.dict(exclude_unset=True)
for attr, value in user_data.items():
if hasattr(user_db, attr) and value is not None:
setattr(user_db, attr, value)
user_db.updated_at = datetime.now(timezone.utc)
flag_modified(user_db, "updated_at")
try:
db.commit()
except IntegrityError as e:
db.rollback()
raise HTTPException(status_code=400, detail=str(e)) from e
return user_db
def update_user_last_login_at(user_id: UUID, db: Session = Depends(get_session)):
user_data = UserUpdate(last_login_at=datetime.now(timezone.utc)) # type: ignore
return update_user(user_id, user_data, db)

View file

@ -0,0 +1,44 @@
from langflow.services.database.models.base import SQLModel, SQLModelSerializable
from sqlmodel import Field, Relationship
from datetime import datetime
from typing import Optional, TYPE_CHECKING
from uuid import UUID, uuid4
if TYPE_CHECKING:
from langflow.services.database.models.api_key import ApiKey
class User(SQLModelSerializable, table=True):
id: UUID = Field(default_factory=uuid4, primary_key=True, unique=True)
username: str = Field(index=True, unique=True)
password: str = Field()
is_active: bool = Field(default=False)
is_superuser: bool = Field(default=False)
create_at: datetime = Field(default_factory=datetime.utcnow)
updated_at: datetime = Field(default_factory=datetime.utcnow)
last_login_at: Optional[datetime] = Field()
api_keys: list["ApiKey"] = Relationship(back_populates="user")
class UserCreate(SQLModel):
username: str = Field()
password: str = Field()
class UserRead(SQLModel):
id: UUID = Field(default_factory=uuid4)
username: str = Field()
is_active: bool = Field()
is_superuser: bool = Field()
create_at: datetime = Field()
updated_at: datetime = Field()
last_login_at: Optional[datetime] = Field()
class UserUpdate(SQLModel):
username: Optional[str] = Field()
is_active: Optional[bool] = Field()
is_superuser: Optional[bool] = Field()
last_login_at: Optional[datetime] = Field()

View file

@ -1,5 +1,5 @@
from langflow.services.schema import ServiceType
from typing import TYPE_CHECKING
from typing import TYPE_CHECKING, List, Optional
if TYPE_CHECKING:
from langflow.services.factory import ServiceFactory
@ -13,13 +13,21 @@ class ServiceManager:
def __init__(self):
self.services = {}
self.factories = {}
self.dependencies = {}
def register_factory(self, service_factory: "ServiceFactory"):
def register_factory(
self,
service_factory: "ServiceFactory",
dependencies: Optional[List[ServiceType]] = None,
):
"""
Registers a new factory.
Registers a new factory with dependencies.
"""
if service_factory.service_class.name not in self.factories:
self.factories[service_factory.service_class.name] = service_factory
if dependencies is None:
dependencies = []
service_name = service_factory.service_class.name
self.factories[service_name] = service_factory
self.dependencies[service_name] = dependencies
def get(self, service_name: ServiceType):
"""
@ -32,17 +40,25 @@ class ServiceManager:
def _create_service(self, service_name: ServiceType):
"""
Create a new service given its name.
Create a new service given its name, handling dependencies.
"""
self._validate_service_creation(service_name)
if service_name == ServiceType.SETTINGS_MANAGER:
self.services[service_name] = self.factories[service_name].create()
else:
settings_service = self.get(ServiceType.SETTINGS_MANAGER)
self.services[service_name] = self.factories[service_name].create(
settings_service
)
# Create dependencies first
for dependency in self.dependencies.get(service_name, []):
if dependency not in self.services:
self._create_service(dependency)
# Collect the dependent services
dependent_services = {
dep.value: self.services[dep]
for dep in self.dependencies.get(service_name, [])
}
# Create the actual service
self.services[service_name] = self.factories[service_name].create(
**dependent_services
)
def _validate_service_creation(self, service_name: ServiceType):
"""
@ -53,14 +69,6 @@ class ServiceManager:
f"No factory registered for the service class '{service_name.name}'"
)
if (
ServiceType.SETTINGS_MANAGER not in self.factories
and service_name != ServiceType.SETTINGS_MANAGER
):
raise ValueError(
f"Cannot create service '{service_name.name}' before the settings service"
)
def update(self, service_name: ServiceType):
"""
Update a service by its name.
@ -81,12 +89,24 @@ def initialize_services():
from langflow.services.cache import factory as cache_factory
from langflow.services.chat import factory as chat_factory
from langflow.services.settings import factory as settings_factory
from langflow.services.auth import factory as auth_factory
service_manager.register_factory(settings_factory.SettingsManagerFactory())
service_manager.register_factory(database_factory.DatabaseManagerFactory())
service_manager.register_factory(
auth_factory.AuthManagerFactory(), dependencies=[ServiceType.SETTINGS_MANAGER]
)
service_manager.register_factory(
database_factory.DatabaseManagerFactory(),
dependencies=[ServiceType.SETTINGS_MANAGER],
)
service_manager.register_factory(cache_factory.CacheManagerFactory())
service_manager.register_factory(chat_factory.ChatManagerFactory())
# Test cache connection
service_manager.get(ServiceType.CACHE_MANAGER)
# Test database connection
service_manager.get(ServiceType.DATABASE_MANAGER)
def initialize_settings_manager():
"""
@ -95,3 +115,22 @@ def initialize_settings_manager():
from langflow.services.settings import factory as settings_factory
service_manager.register_factory(settings_factory.SettingsManagerFactory())
def initialize_session_manager():
"""
Initialize the session manager.
"""
from langflow.services.session import factory as session_manager_factory
from langflow.services.cache import factory as cache_factory
initialize_settings_manager()
service_manager.register_factory(
cache_factory.CacheManagerFactory(), dependencies=[ServiceType.SETTINGS_MANAGER]
)
service_manager.register_factory(
session_manager_factory.SessionManagerFactory(),
dependencies=[ServiceType.CACHE_MANAGER],
)

View file

@ -7,6 +7,7 @@ class ServiceType(str, Enum):
registered with the service manager.
"""
AUTH_MANAGER = "auth_manager"
CACHE_MANAGER = "cache_manager"
SETTINGS_MANAGER = "settings_manager"
DATABASE_MANAGER = "database_manager"

View file

@ -0,0 +1,35 @@
from typing import Optional
import secrets
from pydantic import BaseSettings
from passlib.context import CryptContext
from fastapi.security import OAuth2PasswordBearer
class AuthSettings(BaseSettings):
# Login settings
SECRET_KEY: str = secrets.token_hex(32)
ALGORITHM: str = "HS256"
ACCESS_TOKEN_EXPIRE_MINUTES: int = 60
REFRESH_TOKEN_EXPIRE_MINUTES: int = 70
# API Key to execute /process endpoint
API_KEY_SECRET_KEY: Optional[
str
] = "b82818e0ad4ff76615c5721ee21004b07d84cd9b87ba4d9cb42374da134b841a"
API_KEY_ALGORITHM: str = "HS256"
API_V1_STR: str = "/api/v1"
# If AUTO_LOGIN = True
# > The application does not request login and logs in automatically as a super user.
AUTO_LOGIN: bool = True
FIRST_SUPERUSER: str = "langflow"
FIRST_SUPERUSER_PASSWORD: str = "langflow"
pwd_context = CryptContext(schemes=["bcrypt"], deprecated="auto")
oauth2_scheme = OAuth2PasswordBearer(tokenUrl=f"{API_V1_STR}/login")
class Config:
validate_assignment = True
extra = "ignore"
env_prefix = "LANGFLOW_"

View file

@ -3,7 +3,6 @@ import json
import orjson
import os
from shutil import copy2
import secrets
from typing import Optional, List
from pathlib import Path
@ -42,24 +41,6 @@ class Settings(BaseSettings):
REMOVE_API_KEYS: bool = False
COMPONENTS_PATH: List[str] = []
# Login settings
SECRET_KEY: str = secrets.token_hex(32)
ALGORITHM: str = "HS256"
ACCESS_TOKEN_EXPIRE_MINUTES: int = 60
REFRESH_TOKEN_EXPIRE_MINUTES: int = 70
# API Key to execute /process endpoint
API_KEY_SECRET_KEY: Optional[
str
] = "b82818e0ad4ff76615c5721ee21004b07d84cd9b87ba4d9cb42374da134b841a"
API_KEY_ALGORITHM: str = "HS256"
# If AUTO_LOGIN = True
# > The application does not request login and logs in automatically as a super user.
AUTO_LOGIN: bool = True
FIRST_SUPERUSER: str = "langflow"
FIRST_SUPERUSER_PASSWORD: str = "langflow"
@validator("CONFIG_DIR", pre=True, allow_reuse=True)
def set_langflow_dir(cls, value):
if not value:

View file

@ -1,4 +1,5 @@
from langflow.services.base import Service
from langflow.services.settings.auth import AuthSettings
from langflow.services.settings.base import Settings
from langflow.utils.logger import logger
import os
@ -8,9 +9,10 @@ import yaml
class SettingsManager(Service):
name = "settings_manager"
def __init__(self, settings: Settings):
def __init__(self, settings: Settings, auth_settings: AuthSettings):
super().__init__()
self.settings = settings
self.auth_settings = auth_settings
@classmethod
def load_settings_from_yaml(cls, file_path: str) -> "SettingsManager":
@ -33,4 +35,5 @@ class SettingsManager(Service):
)
settings = Settings(**settings_dict)
return cls(settings)
auth_settings = AuthSettings()
return cls(settings, auth_settings)

View file

@ -5,6 +5,8 @@ from typing import AsyncGenerator, TYPE_CHECKING
from langflow.api.v1.flows import get_session
from langflow.graph.graph.base import Graph
from langflow.services.auth.utils import get_password_hash
from langflow.services.database.models.user.user import User, UserCreate
import pytest
from fastapi.testclient import TestClient
from httpx import AsyncClient
@ -155,3 +157,38 @@ def session_getter_fixture(client):
@pytest.fixture
def runner():
return CliRunner()
@pytest.fixture
def test_user(client):
user_data = UserCreate(
username="testuser",
password="testpassword",
)
response = client.post("/api/v1/user", json=user_data.dict())
return response.json()
@pytest.fixture(scope="function")
def active_user(session):
user = User(
username="activeuser",
password=get_password_hash(
"testpassword"
), # Assuming password needs to be hashed
is_active=True,
is_superuser=False,
)
session.add(user)
session.commit()
return user
@pytest.fixture
def logged_in_headers(client, active_user):
login_data = {"username": active_user.username, "password": "testpassword"}
response = client.post("/api/v1/login", data=login_data)
assert response.status_code == 200
tokens = response.json()
a_token = tokens["access_token"]
return {"Authorization": f"Bearer {a_token}"}

50
tests/test_api_key.py Normal file
View file

@ -0,0 +1,50 @@
import pytest
from langflow.services.database.models.api_key import ApiKeyCreate
@pytest.fixture
def api_key(client, logged_in_headers, active_user):
api_key = ApiKeyCreate(name="test-api-key")
response = client.post(
"api/v1/api_key", data=api_key.json(), headers=logged_in_headers
)
assert response.status_code == 200, response.text
return response.json()
def test_get_api_keys(client, logged_in_headers, api_key):
response = client.get("api/v1/api_key", headers=logged_in_headers)
assert response.status_code == 200, response.text
data = response.json()
assert "total_count" in data
assert "user_id" in data
assert "api_keys" in data
assert any("test-api-key" in api_key["name"] for api_key in data["api_keys"])
# assert all api keys in data["api_keys"] are masked
assert all("**" in api_key["api_key"] for api_key in data["api_keys"])
# Add more assertions as needed based on the expected data structure and content
def test_create_api_key(client, logged_in_headers):
api_key_name = "test-api-key"
response = client.post(
"api/v1/api_key", json={"name": api_key_name}, headers=logged_in_headers
)
assert response.status_code == 200
data = response.json()
assert "name" in data and data["name"] == api_key_name
assert "api_key" in data
# When creating the API key is returned which is
# the only time the API key is unmasked
assert "**" not in data["api_key"]
def test_delete_api_key(client, logged_in_headers, active_user, api_key):
# Assuming a function to create a test API key, returning the key ID
api_key_id = api_key["id"]
response = client.delete(f"api/v1/api_key/{api_key_id}", headers=logged_in_headers)
assert response.status_code == 200
data = response.json()
assert data["detail"] == "API Key deleted"
# Optionally, add a follow-up check to ensure that the key is actually removed from the database

View file

@ -27,4 +27,4 @@ def test_components_path(runner, client, default_settings):
)
assert result.exit_code == 0, result.stdout
settings_manager = utils.get_settings_manager()
assert temp_dir in settings_manager.settings.COMPONENTS_PATH
assert str(temp_dir) in settings_manager.settings.COMPONENTS_PATH

View file

@ -3,7 +3,7 @@ import orjson
import pytest
from uuid import UUID, uuid4
from sqlalchemy.orm import Session
from sqlmodel import Session
from fastapi.testclient import TestClient

47
tests/test_login.py Normal file
View file

@ -0,0 +1,47 @@
import pytest
from langflow.services.database.models.user import User
from langflow.services.auth.utils import get_password_hash
@pytest.fixture
def test_user():
return User(
username="testuser",
password=get_password_hash(
"testpassword"
), # Assuming password needs to be hashed
is_active=True,
is_superuser=False,
)
def test_login_successful(client, test_user, session):
# Adding the test user to the database
session.add(test_user)
session.commit()
response = client.post(
"api/v1/login", data={"username": "testuser", "password": "testpassword"}
)
assert response.status_code == 200
assert "access_token" in response.json()
def test_login_unsuccessful_wrong_username(client):
response = client.post(
"api/v1/login", data={"username": "wrongusername", "password": "testpassword"}
)
assert response.status_code == 401
assert response.json()["detail"] == "Incorrect username or password"
def test_login_unsuccessful_wrong_password(client, test_user, session):
# Adding the test user to the database
session.add(test_user)
session.commit()
response = client.post(
"api/v1/login", data={"username": "testuser", "password": "wrongpassword"}
)
assert response.status_code == 401
assert response.json()["detail"] == "Incorrect username or password"

213
tests/test_user.py Normal file
View file

@ -0,0 +1,213 @@
from datetime import datetime
from langflow.services.auth.utils import create_super_user, get_password_hash
from langflow.services.database.models.user.user import User
from langflow.services.utils import get_settings_manager
import pytest
from langflow.services.database.models.user import UserUpdate
@pytest.fixture
def super_user(client, session):
return create_super_user(session)
@pytest.fixture
def super_user_headers(client, super_user):
settings_manager = get_settings_manager()
auth_settings = settings_manager.auth_settings
login_data = {
"username": auth_settings.FIRST_SUPERUSER,
"password": auth_settings.FIRST_SUPERUSER_PASSWORD,
}
response = client.post("/api/v1/login", data=login_data)
assert response.status_code == 200
tokens = response.json()
a_token = tokens["access_token"]
return {"Authorization": f"Bearer {a_token}"}
@pytest.fixture
def deactivated_user(session):
user = User(
username="deactivateduser",
password=get_password_hash("testpassword"),
is_active=False,
is_superuser=False,
last_login_at=datetime.now(),
)
session.add(user)
session.commit()
return user
def test_user_waiting_for_approval(client, session):
# Create a user that is not active and has never logged in
user = User(
username="waitingforapproval",
password=get_password_hash("testpassword"),
is_active=False,
last_login_at=None,
)
session.add(user)
session.commit()
login_data = {"username": "waitingforapproval", "password": "testpassword"}
response = client.post("/api/v1/login", data=login_data)
assert response.status_code == 400
assert response.json()["detail"] == "Waiting for approval"
def test_deactivated_user_cannot_login(client, deactivated_user):
login_data = {"username": deactivated_user.username, "password": "testpassword"}
response = client.post("/api/v1/login", data=login_data)
assert response.status_code == 400, response.json()
assert response.json()["detail"] == "Inactive user"
def test_deactivated_user_cannot_access(client, deactivated_user, logged_in_headers):
# Assuming the headers for deactivated_user
response = client.get("/api/v1/users", headers=logged_in_headers)
assert response.status_code == 400, response.json()
assert response.json()["detail"] == "The user doesn't have enough privileges"
def test_data_consistency_after_update(client, active_user, logged_in_headers):
user_id = active_user.id
update_data = UserUpdate(username="newname")
response = client.patch(
f"/api/v1/user/{user_id}", json=update_data.dict(), headers=logged_in_headers
)
assert response.status_code == 200
# Fetch the updated user from the database
response = client.get("/api/v1/user", headers=logged_in_headers)
assert response.json()["username"] == "newname", response.json()
def test_data_consistency_after_delete(client, test_user, super_user_headers):
user_id = test_user["id"]
response = client.delete(f"/api/v1/user/{user_id}", headers=super_user_headers)
assert response.status_code == 200
# Attempt to fetch the deleted user from the database
response = client.get("/api/v1/users", headers=super_user_headers)
assert response.status_code == 200
assert all(user["id"] != user_id for user in response.json()["users"])
def test_inactive_user(client, session):
# Create a user that is not active and has a last_login_at value
user = User(
username="inactiveuser",
password=get_password_hash("testpassword"),
is_active=False,
last_login_at="2023-01-01T00:00:00", # Set to a valid datetime string
)
session.add(user)
session.commit()
login_data = {"username": "inactiveuser", "password": "testpassword"}
response = client.post("/api/v1/login", data=login_data)
assert response.status_code == 400
assert response.json()["detail"] == "Inactive user"
def test_add_user(client, test_user):
assert test_user["username"] == "testuser"
# This is not used in the Frontend at the moment
# def test_read_current_user(client: TestClient, active_user):
# # First we need to login to get the access token
# login_data = {"username": "testuser", "password": "testpassword"}
# response = client.post("/api/v1/login", data=login_data)
# assert response.status_code == 200
# headers = {"Authorization": f"Bearer {response.json()['access_token']}"}
# response = client.get("/api/v1/user", headers=headers)
# assert response.status_code == 200, response.json()
# assert response.json()["username"] == "testuser"
def test_read_all_users(client, super_user_headers):
response = client.get("/api/v1/users", headers=super_user_headers)
assert response.status_code == 200, response.json()
assert isinstance(response.json()["users"], list)
def test_normal_user_cant_read_all_users(client, logged_in_headers):
response = client.get("/api/v1/users", headers=logged_in_headers)
assert response.status_code == 400, response.json()
assert response.json() == {"detail": "The user doesn't have enough privileges"}
def test_patch_user(client, active_user, logged_in_headers):
user_id = active_user.id
update_data = UserUpdate(
username="newname",
)
response = client.patch(
f"/api/v1/user/{user_id}", json=update_data.dict(), headers=logged_in_headers
)
assert response.status_code == 200, response.json()
def test_patch_user_wrong_id(client, active_user, logged_in_headers):
user_id = "wrong_id"
update_data = UserUpdate(
username="newname",
)
response = client.patch(
f"/api/v1/user/{user_id}", json=update_data.dict(), headers=logged_in_headers
)
assert response.status_code == 422, response.json()
assert response.json() == {
"detail": [
{
"loc": ["path", "user_id"],
"msg": "value is not a valid uuid",
"type": "type_error.uuid",
}
]
}
def test_delete_user(client, test_user, super_user_headers):
user_id = test_user["id"]
response = client.delete(f"/api/v1/user/{user_id}", headers=super_user_headers)
assert response.status_code == 200
assert response.json() == {"detail": "User deleted"}
def test_delete_user_wrong_id(client, test_user, super_user_headers):
user_id = "wrong_id"
response = client.delete(f"/api/v1/user/{user_id}", headers=super_user_headers)
assert response.status_code == 422
assert response.json() == {
"detail": [
{
"loc": ["path", "user_id"],
"msg": "value is not a valid uuid",
"type": "type_error.uuid",
}
]
}
def test_normal_user_cant_delete_user(client, test_user, logged_in_headers):
user_id = test_user["id"]
response = client.delete(f"/api/v1/user/{user_id}", headers=logged_in_headers)
assert response.status_code == 400
assert response.json() == {"detail": "The user doesn't have enough privileges"}
# If you still want to test the superuser endpoint
def test_add_super_user_for_testing_purposes_delete_me_before_merge_into_dev(client):
response = client.post("/api/v1/super_user")
assert response.status_code == 200
assert response.json()["username"] == "superuser"