diff --git a/src/backend/langflow/alembic/versions/5512e39b4012_add_apikey_table.py b/src/backend/langflow/alembic/versions/5512e39b4012_add_apikey_table.py new file mode 100644 index 000000000..02db82e71 --- /dev/null +++ b/src/backend/langflow/alembic/versions/5512e39b4012_add_apikey_table.py @@ -0,0 +1,100 @@ +"""Add ApiKey table + +Revision ID: 5512e39b4012 +Revises: 0a534bdfd84b +Create Date: 2023-08-23 21:05:51.042203 + +""" + +import contextlib +from typing import Sequence, Union + +from alembic import op +import sqlalchemy as sa +import sqlmodel + + +# revision identifiers, used by Alembic. +revision: str = "5512e39b4012" +down_revision: Union[str, None] = "0a534bdfd84b" +branch_labels: Union[str, Sequence[str], None] = None +depends_on: Union[str, Sequence[str], None] = None + + +def upgrade() -> None: + # ### commands auto generated by Alembic - please adjust! ### + with contextlib.suppress(sa.exc.OperationalError): + op.create_table( + "apikey", + sa.Column("api_key", sqlmodel.sql.sqltypes.AutoString(), nullable=False), + sa.Column("name", sqlmodel.sql.sqltypes.AutoString(), nullable=False), + sa.Column("create_at", sa.DateTime(), nullable=False), + sa.Column("last_used_at", sa.DateTime(), nullable=True), + sa.Column("id", sqlmodel.sql.sqltypes.GUID(), nullable=False), + sa.PrimaryKeyConstraint("id"), + sa.UniqueConstraint("id"), + ) + op.create_index(op.f("ix_apikey_api_key"), "apikey", ["api_key"], unique=True) + + with contextlib.suppress(sa.exc.OperationalError): + op.create_table( + "user", + sa.Column("id", sqlmodel.sql.sqltypes.GUID(), nullable=False), + sa.Column("username", sqlmodel.sql.sqltypes.AutoString(), nullable=False), + sa.Column("password", sqlmodel.sql.sqltypes.AutoString(), nullable=False), + sa.Column("is_active", sa.Boolean(), nullable=False), + sa.Column("is_superuser", sa.Boolean(), nullable=False), + sa.Column("create_at", sa.DateTime(), nullable=False), + sa.Column("updated_at", sa.DateTime(), nullable=False), + sa.Column("last_login_at", sa.DateTime(), nullable=True), + sa.PrimaryKeyConstraint("id"), + sa.UniqueConstraint("id"), + ) + op.create_index(op.f("ix_user_username"), "user", ["username"], unique=True) + with contextlib.suppress(sa.exc.OperationalError): + op.drop_table("flowstyle") + with contextlib.suppress(sa.exc.OperationalError): + op.drop_index("ix_component_frontend_node_id", table_name="component") + op.drop_index("ix_component_name", table_name="component") + op.drop_table("component") + # ### end Alembic commands ### + + +def downgrade() -> None: + # ### commands auto generated by Alembic - please adjust! ### + op.create_table( + "component", + sa.Column("id", sa.CHAR(length=32), nullable=False), + sa.Column("frontend_node_id", sa.CHAR(length=32), nullable=False), + sa.Column("name", sa.VARCHAR(), nullable=False), + sa.Column("description", sa.VARCHAR(), nullable=True), + sa.Column("python_code", sa.VARCHAR(), nullable=True), + sa.Column("return_type", sa.VARCHAR(), nullable=True), + sa.Column("is_disabled", sa.BOOLEAN(), nullable=False), + sa.Column("is_read_only", sa.BOOLEAN(), nullable=False), + sa.Column("create_at", sa.DATETIME(), nullable=False), + sa.Column("update_at", sa.DATETIME(), nullable=False), + sa.PrimaryKeyConstraint("id"), + ) + op.create_index("ix_component_name", "component", ["name"], unique=False) + op.create_index( + "ix_component_frontend_node_id", "component", ["frontend_node_id"], unique=False + ) + op.create_table( + "flowstyle", + sa.Column("color", sa.VARCHAR(), nullable=False), + sa.Column("emoji", sa.VARCHAR(), nullable=False), + sa.Column("flow_id", sa.CHAR(length=32), nullable=True), + sa.Column("id", sa.CHAR(length=32), nullable=False), + sa.ForeignKeyConstraint( + ["flow_id"], + ["flow.id"], + ), + sa.PrimaryKeyConstraint("id"), + sa.UniqueConstraint("id"), + ) + op.drop_index(op.f("ix_user_username"), table_name="user") + op.drop_table("user") + op.drop_index(op.f("ix_apikey_api_key"), table_name="apikey") + op.drop_table("apikey") + # ### end Alembic commands ### diff --git a/src/backend/langflow/api/router.py b/src/backend/langflow/api/router.py index ea1938a75..dbaf20e75 100644 --- a/src/backend/langflow/api/router.py +++ b/src/backend/langflow/api/router.py @@ -6,6 +6,9 @@ from langflow.api.v1 import ( validate_router, flows_router, component_router, + users_router, + api_key_router, + login_router, ) router = APIRouter( @@ -16,3 +19,6 @@ router.include_router(endpoints_router) router.include_router(validate_router) router.include_router(component_router) router.include_router(flows_router) +router.include_router(users_router) +router.include_router(api_key_router) +router.include_router(login_router) diff --git a/src/backend/langflow/api/v1/__init__.py b/src/backend/langflow/api/v1/__init__.py index b6e7b36d8..9335a4607 100644 --- a/src/backend/langflow/api/v1/__init__.py +++ b/src/backend/langflow/api/v1/__init__.py @@ -3,6 +3,9 @@ from langflow.api.v1.validate import router as validate_router from langflow.api.v1.chat import router as chat_router from langflow.api.v1.flows import router as flows_router from langflow.api.v1.components import router as component_router +from langflow.api.v1.users import router as users_router +from langflow.api.v1.api_key import router as api_key_router +from langflow.api.v1.login import router as login_router __all__ = [ "chat_router", @@ -10,4 +13,7 @@ __all__ = [ "component_router", "validate_router", "flows_router", + "users_router", + "api_key_router", + "login_router", ] diff --git a/src/backend/langflow/api/v1/api_key.py b/src/backend/langflow/api/v1/api_key.py new file mode 100644 index 000000000..df2d3e420 --- /dev/null +++ b/src/backend/langflow/api/v1/api_key.py @@ -0,0 +1,61 @@ +from uuid import UUID +from fastapi import APIRouter, HTTPException, Depends +from langflow.api.v1.schemas import ApiKeysResponse +from langflow.services.auth.utils import get_current_active_user +from langflow.services.database.models.api_key.api_key import ( + ApiKeyCreate, + UnmaskedApiKeyRead, +) + +# Assuming you have these methods in your service layer +from langflow.services.database.models.api_key.crud import ( + get_api_keys, + create_api_key, + delete_api_key, +) +from langflow.services.database.models.user.user import User +from langflow.services.utils import get_session +from sqlmodel import Session + + +router = APIRouter(tags=["APIKey"]) + + +@router.get("/api_key", response_model=ApiKeysResponse) +def get_api_keys_route( + db: Session = Depends(get_session), + current_user: User = Depends(get_current_active_user), +): + try: + user_id = current_user.id + keys = get_api_keys(db, user_id) + + return ApiKeysResponse(total_count=len(keys), user_id=user_id, api_keys=keys) + except Exception as e: + raise HTTPException(status_code=400, detail=str(e)) + + +@router.post("/api_key", response_model=UnmaskedApiKeyRead) +def create_api_key_route( + req: ApiKeyCreate, + current_user: User = Depends(get_current_active_user), + db: Session = Depends(get_session), +): + try: + user_id = current_user.id + return create_api_key(db, req, user_id=user_id) + except Exception as e: + raise HTTPException(status_code=400, detail=str(e)) from e + + +@router.delete("/api_key/{api_key_id}") +def delete_api_key_route( + api_key_id: UUID, + current_user=Depends(get_current_active_user), + db: Session = Depends(get_session), +): + try: + delete_api_key(db, api_key_id) + return {"detail": "API Key deleted"} + except Exception as e: + raise HTTPException(status_code=400, detail=str(e)) diff --git a/src/backend/langflow/api/v1/endpoints.py b/src/backend/langflow/api/v1/endpoints.py index a39b6bc20..46acd4683 100644 --- a/src/backend/langflow/api/v1/endpoints.py +++ b/src/backend/langflow/api/v1/endpoints.py @@ -58,8 +58,12 @@ def get_all(): logger.info(f"Loading {len(custom_component_dicts)} category(ies)") for custom_component_dict in custom_component_dicts: - logger.debug( - {key: len(value) for key, value in custom_component_dict.items()} + # custom_component_dict is a dict of dicts + if not custom_component_dict: + continue + category = list(custom_component_dict.keys())[0] + logger.info( + f"Loading {len(custom_component_dict[category])} component(s) from category {category}" ) custom_components_from_file = merge_nested_dicts_with_renaming( custom_components_from_file, custom_component_dict diff --git a/src/backend/langflow/routers/login.py b/src/backend/langflow/api/v1/login.py similarity index 88% rename from src/backend/langflow/routers/login.py rename to src/backend/langflow/api/v1/login.py index de255a0d5..a11167a40 100644 --- a/src/backend/langflow/routers/login.py +++ b/src/backend/langflow/api/v1/login.py @@ -1,10 +1,10 @@ -from sqlalchemy.orm import Session +from sqlmodel import Session from fastapi import APIRouter, Depends, HTTPException, status from fastapi.security import OAuth2PasswordRequestForm from langflow.services.utils import get_session -from langflow.database.models.token import Token -from langflow.auth.auth import ( +from langflow.services.database.models import Token +from langflow.services.auth.utils import ( authenticate_user, create_user_tokens, create_refresh_token, @@ -13,7 +13,7 @@ from langflow.auth.auth import ( from langflow.services.utils import get_settings_manager -router = APIRouter() +router = APIRouter(tags=["Login"]) @router.post("/login", response_model=Token) @@ -36,7 +36,7 @@ async def login_to_get_access_token( async def auto_login(db: Session = Depends(get_session)): settings_manager = get_settings_manager() - if settings_manager.settings.AUTO_LOGIN: + if settings_manager.auth_settings.AUTO_LOGIN: return create_user_longterm_token(db) raise HTTPException( diff --git a/src/backend/langflow/api/v1/schemas.py b/src/backend/langflow/api/v1/schemas.py index 28fa40389..47f58d830 100644 --- a/src/backend/langflow/api/v1/schemas.py +++ b/src/backend/langflow/api/v1/schemas.py @@ -1,7 +1,10 @@ from enum import Enum from pathlib import Path from typing import Any, Dict, List, Optional, Union +from uuid import UUID +from langflow.services.database.models.api_key.api_key import ApiKeyRead from langflow.services.database.models.flow import FlowCreate, FlowRead +from langflow.services.database.models.user import UserRead from langflow.services.database.models.base import orjson_dumps from pydantic import BaseModel, Field, validator @@ -137,3 +140,26 @@ class ComponentListCreate(BaseModel): class ComponentListRead(BaseModel): flows: List[FlowRead] + + +class UsersResponse(BaseModel): + total_count: int + users: List[UserRead] + + +class ApiKeyResponse(BaseModel): + id: str + api_key: str + name: str + created_at: str + last_used_at: str + + +class ApiKeysResponse(BaseModel): + total_count: int + user_id: UUID + api_keys: List[ApiKeyRead] + + +class CreateApiKeyRequest(BaseModel): + name: str diff --git a/src/backend/langflow/routers/users.py b/src/backend/langflow/api/v1/users.py similarity index 68% rename from src/backend/langflow/routers/users.py rename to src/backend/langflow/api/v1/users.py index 04972c976..140ee773f 100644 --- a/src/backend/langflow/routers/users.py +++ b/src/backend/langflow/api/v1/users.py @@ -1,4 +1,11 @@ from uuid import UUID +from langflow.api.v1.schemas import UsersResponse +from langflow.services.database.models.user import ( + User, + UserCreate, + UserRead, + UserUpdate, +) from sqlalchemy import func from sqlalchemy.exc import IntegrityError @@ -7,28 +14,27 @@ from sqlmodel import Session, select from fastapi import APIRouter, Depends, HTTPException from langflow.services.utils import get_session -from langflow.auth.auth import get_current_active_user, get_password_hash -from langflow.database.models.user import ( - User, - UserAddModel, - UserListModel, - UserPatchModel, - UsersResponse, +from langflow.services.auth.utils import ( + get_current_active_superuser, + get_current_active_user, + get_password_hash, +) +from langflow.services.database.models.user.crud import ( update_user, ) -router = APIRouter(tags=["Login"]) +router = APIRouter(tags=["Users"]) -@router.post("/user", response_model=UserListModel) +@router.post("/user", response_model=UserRead, status_code=201) def add_user( - user: UserAddModel, + user: UserCreate, db: Session = Depends(get_session), ) -> User: """ Add a new user to the database. """ - new_user = User(**user.dict()) + new_user = User.from_orm(user) try: new_user.password = get_password_hash(user.password) @@ -42,8 +48,10 @@ def add_user( return new_user -@router.get("/user", response_model=UserListModel) -def read_current_user(current_user: User = Depends(get_current_active_user)) -> User: +@router.get("/user", response_model=UserRead) +def read_current_user( + current_user: User = Depends(get_current_active_user), +) -> User: """ Retrieve the current user's data. """ @@ -54,7 +62,7 @@ def read_current_user(current_user: User = Depends(get_current_active_user)) -> def read_all_users( skip: int = 0, limit: int = 10, - _: Session = Depends(get_current_active_user), + current_user: Session = Depends(get_current_active_superuser), db: Session = Depends(get_session), ) -> UsersResponse: """ @@ -68,14 +76,14 @@ def read_all_users( return UsersResponse( total_count=total_count, # type: ignore - users=[UserListModel(**dict(user.User)) for user in users], + users=[UserRead(**dict(user.User)) for user in users], ) -@router.patch("/user/{user_id}", response_model=UserListModel) +@router.patch("/user/{user_id}", response_model=UserRead) def patch_user( user_id: UUID, - user: UserPatchModel, + user: UserUpdate, _: Session = Depends(get_current_active_user), db: Session = Depends(get_session), ) -> User: @@ -88,12 +96,21 @@ def patch_user( @router.delete("/user/{user_id}") def delete_user( user_id: UUID, - _: Session = Depends(get_current_active_user), + current_user: User = Depends(get_current_active_superuser), db: Session = Depends(get_session), ) -> dict: """ Delete a user from the database. """ + if current_user.id == user_id: + raise HTTPException( + status_code=400, detail="You can't delete your own user account" + ) + elif not current_user.is_superuser: + raise HTTPException( + status_code=403, detail="You don't have the permission to delete this user" + ) + user_db = db.query(User).filter(User.id == user_id).first() if not user_db: raise HTTPException(status_code=404, detail="User not found") diff --git a/src/backend/langflow/database/models/user.py b/src/backend/langflow/database/models/user.py deleted file mode 100644 index 94ceb4e15..000000000 --- a/src/backend/langflow/database/models/user.py +++ /dev/null @@ -1,94 +0,0 @@ -from sqlmodel import Field -from uuid import UUID, uuid4 -from pydantic import BaseModel -from typing import Optional, List -from sqlalchemy.orm import Session -from datetime import timezone, datetime -from sqlalchemy.exc import IntegrityError -from fastapi import HTTPException, Depends - -from langflow.services.utils import get_session -from langflow.services.database.models.base import SQLModelSerializable, SQLModel - - -class User(SQLModelSerializable, table=True): - id: UUID = Field(default_factory=uuid4, primary_key=True, unique=True) - username: str = Field(index=True, unique=True) - password: str = Field() - is_active: bool = Field(default=False) - is_superuser: bool = Field(default=False) - create_at: datetime = Field(default_factory=datetime.utcnow) - updated_at: datetime = Field(default_factory=datetime.utcnow) - last_login_at: Optional[datetime] = Field() - - -class UserAddModel(SQLModel): - username: str = Field() - password: str = Field() - - -class UserListModel(SQLModel): - id: UUID = Field(default_factory=uuid4) - username: str = Field() - is_active: bool = Field() - is_superuser: bool = Field() - create_at: datetime = Field() - updated_at: datetime = Field() - last_login_at: Optional[datetime] = Field() - - -class UserPatchModel(SQLModel): - username: Optional[str] = Field() - is_active: Optional[bool] = Field() - is_superuser: Optional[bool] = Field() - last_login_at: Optional[datetime] = Field() - - -class UsersResponse(BaseModel): - total_count: int - users: List[UserListModel] - - -def get_user_by_username(db: Session, username: str) -> User: - db_user = db.query(User).filter(User.username == username).first() - return User.from_orm(db_user) if db_user else None # type: ignore - - -def get_user_by_id(db: Session, id: UUID) -> User: - db_user = db.query(User).filter(User.id == id).first() - return User.from_orm(db_user) if db_user else None # type: ignore - - -def update_user( - user_id: UUID, user: UserPatchModel, db: Session = Depends(get_session) -) -> User: - user_db = get_user_by_username(db, user.username) # type: ignore - if user_db and user_db.id != user_id: - raise HTTPException(status_code=409, detail="Username already exists") - - user_db = get_user_by_id(db, user_id) - if not user_db: - raise HTTPException(status_code=404, detail="User not found") - - try: - user_data = user.dict(exclude_unset=True) - for key, value in user_data.items(): - setattr(user_db, key, value) - - user_db.updated_at = datetime.now(timezone.utc) - user_db = db.merge(user_db) - db.commit() - if db.identity_key(instance=user_db) is not None: - db.refresh(user_db) - - except IntegrityError as e: - db.rollback() - raise HTTPException(status_code=400, detail=str(e)) from e - - return user_db - - -def update_user_last_login_at(user_id: UUID, db: Session = Depends(get_session)): - user_data = UserPatchModel(last_login_at=datetime.now(timezone.utc)) # type: ignore - - return update_user(user_id, user_data, db) diff --git a/src/backend/langflow/main.py b/src/backend/langflow/main.py index 7045ec99d..a383a2afa 100644 --- a/src/backend/langflow/main.py +++ b/src/backend/langflow/main.py @@ -6,7 +6,7 @@ from fastapi.responses import FileResponse from fastapi.staticfiles import StaticFiles from langflow.api import router -from langflow.routers import api_key, login, users, health + from langflow.interface.utils import setup_llm_caching from langflow.services.database.utils import initialize_database @@ -31,11 +31,6 @@ def create_app(): allow_headers=["*"], ) - app.include_router(login.router) - app.include_router(api_key.router) - app.include_router(users.router) - app.include_router(health.router) - app.include_router(router) app.on_event("startup")(initialize_services) diff --git a/src/backend/langflow/routers/__init__.py b/src/backend/langflow/routers/__init__.py deleted file mode 100644 index e69de29bb..000000000 diff --git a/src/backend/langflow/routers/api_key.py b/src/backend/langflow/routers/api_key.py deleted file mode 100644 index 9fa6acec5..000000000 --- a/src/backend/langflow/routers/api_key.py +++ /dev/null @@ -1,49 +0,0 @@ -from fastapi import APIRouter - - -router = APIRouter(tags=["APIKey"]) - - -@router.get("/api_key/{user_id}") -def get_api_key(user_id: str): - return { - "total_count": 3, - "user_id": user_id, - "api_keys": [ - { - "id": "4425707e-cce4-4d1b-a54e-bd2632064657", - "api_key": "lf-...abcd", - "name": "my api_key name - 01", - "created_at": "2023-08-15T19:28:40.019613", - "last_used_at": "2023-08-16T18:38:20.875210", - }, - { - "id": "6fb7282b-9f2e-4efe-9bda-0c3d8f899473", - "api_key": "lf-...abcd", - "name": "my api_key name - 02", - "created_at": "2023-08-15T19:41:30.077942", - "last_used_at": "2023-08-15T19:45:32.067899", - }, - { - "id": "c55f3b32-4920-42b6-a5cd-698b4251806e", - "api_key": "lf-...abcd", - "name": "my api_key name - 03", - "created_at": "2023-08-15T20:29:40.577808", - "last_used_at": "2023-08-15T20:29:40.577816", - }, - ], - } - - -@router.post("/api_key/{user_id}") -def create_api_key(user_id: str): - return { - "user_id": user_id, - "name": "my api-key 01", - "api_key": "lf-eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJzdWIiOiI1YTBmODM1ZS0yMTQxLTQ2YWItYmQ4NS0yMWEzMjQ1MTE2ZDAiLCJleHAiOjE2OTIyMTUwMTN9.c_s0ZPRtjSI9yUrhi8ACIwyXf0feRLYfaeIZEbRVKQg", - } - - -@router.delete("/api_key/{api_key_id}") -def delete_api_key(api_key_id: str): - return {"detail": "API Key deleted"} diff --git a/src/backend/langflow/routers/health.py b/src/backend/langflow/routers/health.py deleted file mode 100644 index 244ef001d..000000000 --- a/src/backend/langflow/routers/health.py +++ /dev/null @@ -1,8 +0,0 @@ -from fastapi import APIRouter - -router = APIRouter() - - -@router.get("/health") -def get_health(): - return {"status": "OK"} diff --git a/src/backend/langflow/auth/__init__.py b/src/backend/langflow/services/auth/__init__.py similarity index 100% rename from src/backend/langflow/auth/__init__.py rename to src/backend/langflow/services/auth/__init__.py diff --git a/src/backend/langflow/services/auth/factory.py b/src/backend/langflow/services/auth/factory.py new file mode 100644 index 000000000..4914ce645 --- /dev/null +++ b/src/backend/langflow/services/auth/factory.py @@ -0,0 +1,12 @@ +from langflow.services.factory import ServiceFactory +from langflow.services.auth.service import AuthManager + + +class AuthManagerFactory(ServiceFactory): + name = "auth_manager" + + def __init__(self): + super().__init__(AuthManager) + + def create(self, settings_manager): + return AuthManager(settings_manager) diff --git a/src/backend/langflow/services/auth/service.py b/src/backend/langflow/services/auth/service.py new file mode 100644 index 000000000..c80b984bb --- /dev/null +++ b/src/backend/langflow/services/auth/service.py @@ -0,0 +1,18 @@ +from fastapi import Request +from langflow.services.base import Service +from typing import TYPE_CHECKING + +if TYPE_CHECKING: + from langflow.services.settings.manager import SettingsManager + + +class AuthManager(Service): + name = "auth_manager" + + def __init__(self, settings_manager: "SettingsManager"): + self.settings_manager = settings_manager + + # We need to define a function that can be passed to the Depends() function. + # This function will be called by FastAPI to run oauth2_scheme + def run_oauth2_scheme(self, request: Request): + return self.settings_manager.auth_settings.oauth2_scheme(request=request) diff --git a/src/backend/langflow/auth/auth.py b/src/backend/langflow/services/auth/utils.py similarity index 68% rename from src/backend/langflow/auth/auth.py rename to src/backend/langflow/services/auth/utils.py index 9d4f12862..8cc67d216 100644 --- a/src/backend/langflow/auth/auth.py +++ b/src/backend/langflow/services/auth/utils.py @@ -1,28 +1,30 @@ -from uuid import UUID -from typing import Annotated -from jose import JWTError, jwt -from sqlalchemy.orm import Session -from passlib.context import CryptContext -from fastapi.security import OAuth2PasswordBearer -from fastapi import Depends, HTTPException, status from datetime import datetime, timedelta, timezone - -from langflow.services.utils import get_settings_manager, get_session - -from langflow.database.models.user import ( - User, +from fastapi import Depends, HTTPException, Request, status +from jose import JWTError, jwt +from typing import Annotated, Coroutine +from uuid import UUID +from langflow.services.auth.service import AuthManager +from langflow.services.database.models.user.user import User +from langflow.services.database.models.user.crud import ( get_user_by_id, get_user_by_username, update_user_last_login_at, ) +from langflow.services.utils import get_session, get_settings_manager +from sqlmodel import Session -pwd_context = CryptContext(schemes=["bcrypt"], deprecated="auto") -oauth2_scheme = OAuth2PasswordBearer(tokenUrl="login") +def auth_scheme_dependency(request: Request): + settings_manager = ( + get_settings_manager() + ) # Assuming get_settings_manager is defined + + return AuthManager(settings_manager).run_oauth2_scheme(request) async def get_current_user( - token: Annotated[str, Depends(oauth2_scheme)], db: Session = Depends(get_session) + token: Annotated[str, Depends(auth_scheme_dependency)], + db: Session = Depends(get_session), ) -> User: settings_manager = get_settings_manager() @@ -32,11 +34,14 @@ async def get_current_user( headers={"WWW-Authenticate": "Bearer"}, ) + if isinstance(token, Coroutine): + token = await token + try: payload = jwt.decode( token, - settings_manager.settings.SECRET_KEY, - algorithms=[settings_manager.settings.ALGORITHM], + settings_manager.auth_settings.SECRET_KEY, + algorithms=[settings_manager.auth_settings.ALGORITHM], ) user_id: UUID = payload.get("sub") # type: ignore token_type: str = payload.get("type") # type: ignore @@ -47,25 +52,39 @@ async def get_current_user( raise credentials_exception from e user = get_user_by_id(db, user_id) # type: ignore - if user is None: + if user is None or not user.is_active: raise credentials_exception return user -async def get_current_active_user( - current_user: Annotated[User, Depends(get_current_user)] -): +def get_current_active_user(current_user: Annotated[User, Depends(get_current_user)]): if not current_user.is_active: raise HTTPException(status_code=400, detail="Inactive user") return current_user +def get_current_active_superuser( + current_user: Annotated[User, Depends(get_current_user)] +) -> User: + if not current_user.is_active: + raise HTTPException(status_code=401, detail="Inactive user") + if not current_user.is_superuser: + raise HTTPException( + status_code=400, detail="The user doesn't have enough privileges" + ) + return current_user + + def verify_password(plain_password, hashed_password): - return pwd_context.verify(plain_password, hashed_password) + settings_manager = get_settings_manager() + return settings_manager.auth_settings.pwd_context.verify( + plain_password, hashed_password + ) def get_password_hash(password): - return pwd_context.hash(password) + settings_manager = get_settings_manager() + return settings_manager.auth_settings.pwd_context.hash(password) def create_token(data: dict, expires_delta: timedelta): @@ -77,21 +96,23 @@ def create_token(data: dict, expires_delta: timedelta): return jwt.encode( to_encode, - settings_manager.settings.SECRET_KEY, - algorithm=settings_manager.settings.ALGORITHM, + settings_manager.auth_settings.SECRET_KEY, + algorithm=settings_manager.auth_settings.ALGORITHM, ) def create_super_user(db: Session = Depends(get_session)) -> User: settings_manager = get_settings_manager() - super_user = get_user_by_username(db, settings_manager.settings.FIRST_SUPERUSER) + super_user = get_user_by_username( + db, settings_manager.auth_settings.FIRST_SUPERUSER + ) if not super_user: super_user = User( - username=settings_manager.settings.FIRST_SUPERUSER, + username=settings_manager.auth_settings.FIRST_SUPERUSER, password=get_password_hash( - settings_manager.settings.FIRST_SUPERUSER_PASSWORD + settings_manager.auth_settings.FIRST_SUPERUSER_PASSWORD ), is_superuser=True, is_active=True, @@ -147,7 +168,7 @@ def create_user_tokens( settings_manager = get_settings_manager() access_token_expires = timedelta( - minutes=settings_manager.settings.ACCESS_TOKEN_EXPIRE_MINUTES + minutes=settings_manager.auth_settings.ACCESS_TOKEN_EXPIRE_MINUTES ) access_token = create_token( data={"sub": str(user_id)}, @@ -155,7 +176,7 @@ def create_user_tokens( ) refresh_token_expires = timedelta( - minutes=settings_manager.settings.REFRESH_TOKEN_EXPIRE_MINUTES + minutes=settings_manager.auth_settings.REFRESH_TOKEN_EXPIRE_MINUTES ) refresh_token = create_token( data={"sub": str(user_id), "type": "rf"}, @@ -179,8 +200,8 @@ def create_refresh_token(refresh_token: str, db: Session = Depends(get_session)) try: payload = jwt.decode( refresh_token, - settings_manager.settings.SECRET_KEY, - algorithms=[settings_manager.settings.ALGORITHM], + settings_manager.auth_settings.SECRET_KEY, + algorithms=[settings_manager.auth_settings.ALGORITHM], ) user_id: UUID = payload.get("sub") # type: ignore token_type: str = payload.get("type") # type: ignore diff --git a/src/backend/langflow/services/cache/factory.py b/src/backend/langflow/services/cache/factory.py index 77f8d58d1..f180f67c0 100644 --- a/src/backend/langflow/services/cache/factory.py +++ b/src/backend/langflow/services/cache/factory.py @@ -6,6 +6,6 @@ class CacheManagerFactory(ServiceFactory): def __init__(self): super().__init__(CacheManager) - def create(self, settings_service): + def create(self): # Here you would have logic to create and configure a CacheManager return CacheManager() diff --git a/src/backend/langflow/services/chat/factory.py b/src/backend/langflow/services/chat/factory.py index 03597ed11..ca844893a 100644 --- a/src/backend/langflow/services/chat/factory.py +++ b/src/backend/langflow/services/chat/factory.py @@ -6,6 +6,6 @@ class ChatManagerFactory(ServiceFactory): def __init__(self): super().__init__(ChatManager) - def create(self, settings_service): + def create(self): # Here you would have logic to create and configure a ChatManager return ChatManager() diff --git a/src/backend/langflow/services/database/factory.py b/src/backend/langflow/services/database/factory.py index fecf24543..25427b7b9 100644 --- a/src/backend/langflow/services/database/factory.py +++ b/src/backend/langflow/services/database/factory.py @@ -10,8 +10,8 @@ class DatabaseManagerFactory(ServiceFactory): def __init__(self): super().__init__(DatabaseManager) - def create(self, settings_service: "SettingsManager"): + def create(self, settings_manager: "SettingsManager"): # Here you would have logic to create and configure a DatabaseManager - if not settings_service.settings.DATABASE_URL: + if not settings_manager.settings.DATABASE_URL: raise ValueError("No database URL provided") - return DatabaseManager(settings_service.settings.DATABASE_URL) + return DatabaseManager(settings_manager.settings.DATABASE_URL) diff --git a/src/backend/langflow/services/database/models/__init__.py b/src/backend/langflow/services/database/models/__init__.py index da47bc5fe..01e81e277 100644 --- a/src/backend/langflow/services/database/models/__init__.py +++ b/src/backend/langflow/services/database/models/__init__.py @@ -1,4 +1,6 @@ from .flow import Flow +from .user import User +from .token import Token +from .api_key import ApiKey - -__all__ = ["Flow"] +__all__ = ["Flow", "User", "Token", "ApiKey"] diff --git a/src/backend/langflow/services/database/models/api_key/__init__.py b/src/backend/langflow/services/database/models/api_key/__init__.py new file mode 100644 index 000000000..fbb8265b9 --- /dev/null +++ b/src/backend/langflow/services/database/models/api_key/__init__.py @@ -0,0 +1,3 @@ +from .api_key import ApiKey, ApiKeyCreate, UnmaskedApiKeyRead, ApiKeyRead + +__all__ = ["ApiKey", "ApiKeyCreate", "UnmaskedApiKeyRead", "ApiKeyRead"] diff --git a/src/backend/langflow/services/database/models/api_key/api_key.py b/src/backend/langflow/services/database/models/api_key/api_key.py new file mode 100644 index 000000000..601d060b5 --- /dev/null +++ b/src/backend/langflow/services/database/models/api_key/api_key.py @@ -0,0 +1,45 @@ +from pydantic import validator +from sqlmodel import Field, Relationship +from uuid import UUID, uuid4 +from typing import Optional, TYPE_CHECKING +from datetime import datetime +from langflow.services.database.models.base import SQLModelSerializable + +if TYPE_CHECKING: + from langflow.services.database.models.user import User + + +class ApiKeyBase(SQLModelSerializable): + name: Optional[str] = Field(index=True) + created_at: datetime = Field(default_factory=datetime.utcnow) + last_used_at: Optional[datetime] = Field(default=None) + + +class ApiKey(ApiKeyBase, table=True): + id: UUID = Field(default_factory=uuid4, primary_key=True, unique=True) + api_key: str = Field(index=True, unique=True) + # User relationship + user_id: UUID = Field(index=True, foreign_key="user.id") + user: "User" = Relationship(back_populates="api_keys") + + +class ApiKeyCreate(ApiKeyBase): + api_key: Optional[str] = None + user_id: Optional[UUID] = None + + +class UnmaskedApiKeyRead(ApiKeyBase): + id: UUID + api_key: str = Field(index=True, unique=True) + user_id: UUID = Field() + + +class ApiKeyRead(ApiKeyBase): + id: UUID + api_key: str = Field(index=True, unique=True) + user_id: UUID = Field() + + @validator("api_key", always=True) + def mask_api_key(cls, v): + # This validator will always run, and will mask the API key + return f"{'*' * 8}{v[-4:]}" diff --git a/src/backend/langflow/services/database/models/api_key/crud.py b/src/backend/langflow/services/database/models/api_key/crud.py new file mode 100644 index 000000000..af697b6d5 --- /dev/null +++ b/src/backend/langflow/services/database/models/api_key/crud.py @@ -0,0 +1,45 @@ +import secrets +from uuid import UUID +from typing import List +from langflow.services.auth.utils import get_password_hash +from sqlmodel import Session, select +from langflow.services.database.models.api_key import ( + ApiKey, + ApiKeyCreate, + UnmaskedApiKeyRead, + ApiKeyRead, +) + + +def get_api_keys(session: Session, user_id: UUID) -> List[ApiKeyRead]: + query = select(ApiKey).where(ApiKey.user_id == user_id) + api_keys = session.exec(query).all() + return [ApiKeyRead.from_orm(api_key) for api_key in api_keys] + + +def create_api_key( + session: Session, api_key_create: ApiKeyCreate, user_id: UUID +) -> UnmaskedApiKeyRead: + # Generate a random API key with 32 bytes of randomness + generated_api_key = secrets.token_urlsafe(32) + + # hash the API key + hashed_api_key = get_password_hash(generated_api_key) + # Use the generated key to create the ApiKey object + + api_key = ApiKey(api_key=hashed_api_key, name=api_key_create.name, user_id=user_id) + + session.add(api_key) + session.commit() + session.refresh(api_key) + unmasked = UnmaskedApiKeyRead.from_orm(api_key) + unmasked.api_key = generated_api_key + return unmasked + + +def delete_api_key(session: Session, api_key_id: UUID) -> None: + api_key = session.get(ApiKey, api_key_id) + if api_key is None: + raise ValueError("API Key not found") + session.delete(api_key) + session.commit() diff --git a/src/backend/langflow/services/database/models/component/__init__.py b/src/backend/langflow/services/database/models/component/__init__.py new file mode 100644 index 000000000..c787c3e04 --- /dev/null +++ b/src/backend/langflow/services/database/models/component/__init__.py @@ -0,0 +1,3 @@ +from .component import Component, ComponentModel + +__all__ = ["Component", "ComponentModel"] diff --git a/src/backend/langflow/services/database/models/component.py b/src/backend/langflow/services/database/models/component/component.py similarity index 100% rename from src/backend/langflow/services/database/models/component.py rename to src/backend/langflow/services/database/models/component/component.py diff --git a/src/backend/langflow/services/database/models/flow/__init__.py b/src/backend/langflow/services/database/models/flow/__init__.py new file mode 100644 index 000000000..7c7cc0172 --- /dev/null +++ b/src/backend/langflow/services/database/models/flow/__init__.py @@ -0,0 +1,3 @@ +from .flow import Flow, FlowCreate, FlowRead, FlowUpdate + +__all__ = ["Flow", "FlowCreate", "FlowRead", "FlowUpdate"] diff --git a/src/backend/langflow/services/database/models/flow.py b/src/backend/langflow/services/database/models/flow/flow.py similarity index 94% rename from src/backend/langflow/services/database/models/flow.py rename to src/backend/langflow/services/database/models/flow/flow.py index 2bc83f9dc..a05de5791 100644 --- a/src/backend/langflow/services/database/models/flow.py +++ b/src/backend/langflow/services/database/models/flow/flow.py @@ -6,8 +6,6 @@ from sqlmodel import Field, JSON, Column from uuid import UUID, uuid4 from typing import Dict, Optional -# if TYPE_CHECKING: - class FlowBase(SQLModelSerializable): name: str = Field(index=True) @@ -16,7 +14,6 @@ class FlowBase(SQLModelSerializable): @validator("data") def validate_json(v): - # dict_keys(['description', 'name', 'id', 'data']) if not v: return v if not isinstance(v, dict): diff --git a/src/backend/langflow/services/database/models/token/__init__.py b/src/backend/langflow/services/database/models/token/__init__.py new file mode 100644 index 000000000..9b9fa397d --- /dev/null +++ b/src/backend/langflow/services/database/models/token/__init__.py @@ -0,0 +1,5 @@ +from .token import Token + +__all__ = [ + "Token", +] diff --git a/src/backend/langflow/database/models/token.py b/src/backend/langflow/services/database/models/token/token.py similarity index 100% rename from src/backend/langflow/database/models/token.py rename to src/backend/langflow/services/database/models/token/token.py diff --git a/src/backend/langflow/services/database/models/user/__init__.py b/src/backend/langflow/services/database/models/user/__init__.py new file mode 100644 index 000000000..da9170eb7 --- /dev/null +++ b/src/backend/langflow/services/database/models/user/__init__.py @@ -0,0 +1,8 @@ +from .user import User, UserCreate, UserRead, UserUpdate + +__all__ = [ + "User", + "UserCreate", + "UserRead", + "UserUpdate", +] diff --git a/src/backend/langflow/services/database/models/user/crud.py b/src/backend/langflow/services/database/models/user/crud.py new file mode 100644 index 000000000..3dc02a499 --- /dev/null +++ b/src/backend/langflow/services/database/models/user/crud.py @@ -0,0 +1,53 @@ +from datetime import datetime, timezone +from typing import Union +from uuid import UUID +from fastapi import Depends, HTTPException +from langflow.services.database.models.user.user import User, UserUpdate +from langflow.services.utils import get_session +from sqlalchemy.exc import IntegrityError +from sqlmodel import Session + + +from sqlalchemy.orm.attributes import flag_modified + + +def get_user_by_username(db: Session, username: str) -> Union[User, None]: + return db.query(User).filter(User.username == username).first() + + +def get_user_by_id(db: Session, id: UUID) -> Union[User, None]: + return db.query(User).filter(User.id == id).first() + + +def update_user( + user_id: UUID, user: UserUpdate, db: Session = Depends(get_session) +) -> User: + user_db = get_user_by_id(db, user_id) + if not user_db: + raise HTTPException(status_code=404, detail="User not found") + + user_db_by_username = get_user_by_username(db, user.username) # type: ignore + if user_db_by_username and user_db_by_username.id != user_id: + raise HTTPException(status_code=409, detail="Username already exists") + + user_data = user.dict(exclude_unset=True) + for attr, value in user_data.items(): + if hasattr(user_db, attr) and value is not None: + setattr(user_db, attr, value) + + user_db.updated_at = datetime.now(timezone.utc) + flag_modified(user_db, "updated_at") + + try: + db.commit() + except IntegrityError as e: + db.rollback() + raise HTTPException(status_code=400, detail=str(e)) from e + + return user_db + + +def update_user_last_login_at(user_id: UUID, db: Session = Depends(get_session)): + user_data = UserUpdate(last_login_at=datetime.now(timezone.utc)) # type: ignore + + return update_user(user_id, user_data, db) diff --git a/src/backend/langflow/services/database/models/user/user.py b/src/backend/langflow/services/database/models/user/user.py new file mode 100644 index 000000000..b6c27c2dc --- /dev/null +++ b/src/backend/langflow/services/database/models/user/user.py @@ -0,0 +1,44 @@ +from langflow.services.database.models.base import SQLModel, SQLModelSerializable +from sqlmodel import Field, Relationship + + +from datetime import datetime +from typing import Optional, TYPE_CHECKING +from uuid import UUID, uuid4 + +if TYPE_CHECKING: + from langflow.services.database.models.api_key import ApiKey + + +class User(SQLModelSerializable, table=True): + id: UUID = Field(default_factory=uuid4, primary_key=True, unique=True) + username: str = Field(index=True, unique=True) + password: str = Field() + is_active: bool = Field(default=False) + is_superuser: bool = Field(default=False) + create_at: datetime = Field(default_factory=datetime.utcnow) + updated_at: datetime = Field(default_factory=datetime.utcnow) + last_login_at: Optional[datetime] = Field() + api_keys: list["ApiKey"] = Relationship(back_populates="user") + + +class UserCreate(SQLModel): + username: str = Field() + password: str = Field() + + +class UserRead(SQLModel): + id: UUID = Field(default_factory=uuid4) + username: str = Field() + is_active: bool = Field() + is_superuser: bool = Field() + create_at: datetime = Field() + updated_at: datetime = Field() + last_login_at: Optional[datetime] = Field() + + +class UserUpdate(SQLModel): + username: Optional[str] = Field() + is_active: Optional[bool] = Field() + is_superuser: Optional[bool] = Field() + last_login_at: Optional[datetime] = Field() diff --git a/src/backend/langflow/services/manager.py b/src/backend/langflow/services/manager.py index f05102d0e..e9895adab 100644 --- a/src/backend/langflow/services/manager.py +++ b/src/backend/langflow/services/manager.py @@ -1,5 +1,5 @@ from langflow.services.schema import ServiceType -from typing import TYPE_CHECKING +from typing import TYPE_CHECKING, List, Optional if TYPE_CHECKING: from langflow.services.factory import ServiceFactory @@ -13,13 +13,21 @@ class ServiceManager: def __init__(self): self.services = {} self.factories = {} + self.dependencies = {} - def register_factory(self, service_factory: "ServiceFactory"): + def register_factory( + self, + service_factory: "ServiceFactory", + dependencies: Optional[List[ServiceType]] = None, + ): """ - Registers a new factory. + Registers a new factory with dependencies. """ - if service_factory.service_class.name not in self.factories: - self.factories[service_factory.service_class.name] = service_factory + if dependencies is None: + dependencies = [] + service_name = service_factory.service_class.name + self.factories[service_name] = service_factory + self.dependencies[service_name] = dependencies def get(self, service_name: ServiceType): """ @@ -32,17 +40,25 @@ class ServiceManager: def _create_service(self, service_name: ServiceType): """ - Create a new service given its name. + Create a new service given its name, handling dependencies. """ self._validate_service_creation(service_name) - if service_name == ServiceType.SETTINGS_MANAGER: - self.services[service_name] = self.factories[service_name].create() - else: - settings_service = self.get(ServiceType.SETTINGS_MANAGER) - self.services[service_name] = self.factories[service_name].create( - settings_service - ) + # Create dependencies first + for dependency in self.dependencies.get(service_name, []): + if dependency not in self.services: + self._create_service(dependency) + + # Collect the dependent services + dependent_services = { + dep.value: self.services[dep] + for dep in self.dependencies.get(service_name, []) + } + + # Create the actual service + self.services[service_name] = self.factories[service_name].create( + **dependent_services + ) def _validate_service_creation(self, service_name: ServiceType): """ @@ -53,14 +69,6 @@ class ServiceManager: f"No factory registered for the service class '{service_name.name}'" ) - if ( - ServiceType.SETTINGS_MANAGER not in self.factories - and service_name != ServiceType.SETTINGS_MANAGER - ): - raise ValueError( - f"Cannot create service '{service_name.name}' before the settings service" - ) - def update(self, service_name: ServiceType): """ Update a service by its name. @@ -81,12 +89,24 @@ def initialize_services(): from langflow.services.cache import factory as cache_factory from langflow.services.chat import factory as chat_factory from langflow.services.settings import factory as settings_factory + from langflow.services.auth import factory as auth_factory service_manager.register_factory(settings_factory.SettingsManagerFactory()) - service_manager.register_factory(database_factory.DatabaseManagerFactory()) + service_manager.register_factory( + auth_factory.AuthManagerFactory(), dependencies=[ServiceType.SETTINGS_MANAGER] + ) + service_manager.register_factory( + database_factory.DatabaseManagerFactory(), + dependencies=[ServiceType.SETTINGS_MANAGER], + ) service_manager.register_factory(cache_factory.CacheManagerFactory()) service_manager.register_factory(chat_factory.ChatManagerFactory()) + # Test cache connection + service_manager.get(ServiceType.CACHE_MANAGER) + # Test database connection + service_manager.get(ServiceType.DATABASE_MANAGER) + def initialize_settings_manager(): """ @@ -95,3 +115,22 @@ def initialize_settings_manager(): from langflow.services.settings import factory as settings_factory service_manager.register_factory(settings_factory.SettingsManagerFactory()) + + +def initialize_session_manager(): + """ + Initialize the session manager. + """ + from langflow.services.session import factory as session_manager_factory + from langflow.services.cache import factory as cache_factory + + initialize_settings_manager() + + service_manager.register_factory( + cache_factory.CacheManagerFactory(), dependencies=[ServiceType.SETTINGS_MANAGER] + ) + + service_manager.register_factory( + session_manager_factory.SessionManagerFactory(), + dependencies=[ServiceType.CACHE_MANAGER], + ) diff --git a/src/backend/langflow/services/schema.py b/src/backend/langflow/services/schema.py index 695763afc..6291a0d0b 100644 --- a/src/backend/langflow/services/schema.py +++ b/src/backend/langflow/services/schema.py @@ -7,6 +7,7 @@ class ServiceType(str, Enum): registered with the service manager. """ + AUTH_MANAGER = "auth_manager" CACHE_MANAGER = "cache_manager" SETTINGS_MANAGER = "settings_manager" DATABASE_MANAGER = "database_manager" diff --git a/src/backend/langflow/services/settings/auth.py b/src/backend/langflow/services/settings/auth.py new file mode 100644 index 000000000..2aa4e17bc --- /dev/null +++ b/src/backend/langflow/services/settings/auth.py @@ -0,0 +1,35 @@ +from typing import Optional +import secrets + +from pydantic import BaseSettings +from passlib.context import CryptContext +from fastapi.security import OAuth2PasswordBearer + + +class AuthSettings(BaseSettings): + # Login settings + SECRET_KEY: str = secrets.token_hex(32) + ALGORITHM: str = "HS256" + ACCESS_TOKEN_EXPIRE_MINUTES: int = 60 + REFRESH_TOKEN_EXPIRE_MINUTES: int = 70 + + # API Key to execute /process endpoint + API_KEY_SECRET_KEY: Optional[ + str + ] = "b82818e0ad4ff76615c5721ee21004b07d84cd9b87ba4d9cb42374da134b841a" + API_KEY_ALGORITHM: str = "HS256" + API_V1_STR: str = "/api/v1" + + # If AUTO_LOGIN = True + # > The application does not request login and logs in automatically as a super user. + AUTO_LOGIN: bool = True + FIRST_SUPERUSER: str = "langflow" + FIRST_SUPERUSER_PASSWORD: str = "langflow" + + pwd_context = CryptContext(schemes=["bcrypt"], deprecated="auto") + oauth2_scheme = OAuth2PasswordBearer(tokenUrl=f"{API_V1_STR}/login") + + class Config: + validate_assignment = True + extra = "ignore" + env_prefix = "LANGFLOW_" diff --git a/src/backend/langflow/services/settings/base.py b/src/backend/langflow/services/settings/base.py index c32c25809..00cd2085f 100644 --- a/src/backend/langflow/services/settings/base.py +++ b/src/backend/langflow/services/settings/base.py @@ -3,7 +3,6 @@ import json import orjson import os from shutil import copy2 -import secrets from typing import Optional, List from pathlib import Path @@ -42,24 +41,6 @@ class Settings(BaseSettings): REMOVE_API_KEYS: bool = False COMPONENTS_PATH: List[str] = [] - # Login settings - SECRET_KEY: str = secrets.token_hex(32) - ALGORITHM: str = "HS256" - ACCESS_TOKEN_EXPIRE_MINUTES: int = 60 - REFRESH_TOKEN_EXPIRE_MINUTES: int = 70 - - # API Key to execute /process endpoint - API_KEY_SECRET_KEY: Optional[ - str - ] = "b82818e0ad4ff76615c5721ee21004b07d84cd9b87ba4d9cb42374da134b841a" - API_KEY_ALGORITHM: str = "HS256" - - # If AUTO_LOGIN = True - # > The application does not request login and logs in automatically as a super user. - AUTO_LOGIN: bool = True - FIRST_SUPERUSER: str = "langflow" - FIRST_SUPERUSER_PASSWORD: str = "langflow" - @validator("CONFIG_DIR", pre=True, allow_reuse=True) def set_langflow_dir(cls, value): if not value: diff --git a/src/backend/langflow/services/settings/manager.py b/src/backend/langflow/services/settings/manager.py index a357c4804..1a6c0feeb 100644 --- a/src/backend/langflow/services/settings/manager.py +++ b/src/backend/langflow/services/settings/manager.py @@ -1,4 +1,5 @@ from langflow.services.base import Service +from langflow.services.settings.auth import AuthSettings from langflow.services.settings.base import Settings from langflow.utils.logger import logger import os @@ -8,9 +9,10 @@ import yaml class SettingsManager(Service): name = "settings_manager" - def __init__(self, settings: Settings): + def __init__(self, settings: Settings, auth_settings: AuthSettings): super().__init__() self.settings = settings + self.auth_settings = auth_settings @classmethod def load_settings_from_yaml(cls, file_path: str) -> "SettingsManager": @@ -33,4 +35,5 @@ class SettingsManager(Service): ) settings = Settings(**settings_dict) - return cls(settings) + auth_settings = AuthSettings() + return cls(settings, auth_settings) diff --git a/tests/conftest.py b/tests/conftest.py index e90d03d0a..9abe89d49 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -5,6 +5,8 @@ from typing import AsyncGenerator, TYPE_CHECKING from langflow.api.v1.flows import get_session from langflow.graph.graph.base import Graph +from langflow.services.auth.utils import get_password_hash +from langflow.services.database.models.user.user import User, UserCreate import pytest from fastapi.testclient import TestClient from httpx import AsyncClient @@ -155,3 +157,38 @@ def session_getter_fixture(client): @pytest.fixture def runner(): return CliRunner() + + +@pytest.fixture +def test_user(client): + user_data = UserCreate( + username="testuser", + password="testpassword", + ) + response = client.post("/api/v1/user", json=user_data.dict()) + return response.json() + + +@pytest.fixture(scope="function") +def active_user(session): + user = User( + username="activeuser", + password=get_password_hash( + "testpassword" + ), # Assuming password needs to be hashed + is_active=True, + is_superuser=False, + ) + session.add(user) + session.commit() + return user + + +@pytest.fixture +def logged_in_headers(client, active_user): + login_data = {"username": active_user.username, "password": "testpassword"} + response = client.post("/api/v1/login", data=login_data) + assert response.status_code == 200 + tokens = response.json() + a_token = tokens["access_token"] + return {"Authorization": f"Bearer {a_token}"} diff --git a/tests/test_api_key.py b/tests/test_api_key.py new file mode 100644 index 000000000..43b91fa43 --- /dev/null +++ b/tests/test_api_key.py @@ -0,0 +1,50 @@ +import pytest +from langflow.services.database.models.api_key import ApiKeyCreate + + +@pytest.fixture +def api_key(client, logged_in_headers, active_user): + api_key = ApiKeyCreate(name="test-api-key") + + response = client.post( + "api/v1/api_key", data=api_key.json(), headers=logged_in_headers + ) + assert response.status_code == 200, response.text + return response.json() + + +def test_get_api_keys(client, logged_in_headers, api_key): + response = client.get("api/v1/api_key", headers=logged_in_headers) + assert response.status_code == 200, response.text + data = response.json() + assert "total_count" in data + assert "user_id" in data + assert "api_keys" in data + assert any("test-api-key" in api_key["name"] for api_key in data["api_keys"]) + # assert all api keys in data["api_keys"] are masked + assert all("**" in api_key["api_key"] for api_key in data["api_keys"]) + # Add more assertions as needed based on the expected data structure and content + + +def test_create_api_key(client, logged_in_headers): + api_key_name = "test-api-key" + response = client.post( + "api/v1/api_key", json={"name": api_key_name}, headers=logged_in_headers + ) + assert response.status_code == 200 + data = response.json() + assert "name" in data and data["name"] == api_key_name + assert "api_key" in data + # When creating the API key is returned which is + # the only time the API key is unmasked + assert "**" not in data["api_key"] + + +def test_delete_api_key(client, logged_in_headers, active_user, api_key): + # Assuming a function to create a test API key, returning the key ID + api_key_id = api_key["id"] + response = client.delete(f"api/v1/api_key/{api_key_id}", headers=logged_in_headers) + assert response.status_code == 200 + data = response.json() + assert data["detail"] == "API Key deleted" + # Optionally, add a follow-up check to ensure that the key is actually removed from the database diff --git a/tests/test_cli.py b/tests/test_cli.py index 408500d7a..c990ef9e8 100644 --- a/tests/test_cli.py +++ b/tests/test_cli.py @@ -27,4 +27,4 @@ def test_components_path(runner, client, default_settings): ) assert result.exit_code == 0, result.stdout settings_manager = utils.get_settings_manager() - assert temp_dir in settings_manager.settings.COMPONENTS_PATH + assert str(temp_dir) in settings_manager.settings.COMPONENTS_PATH diff --git a/tests/test_database.py b/tests/test_database.py index 6976f963a..48e253026 100644 --- a/tests/test_database.py +++ b/tests/test_database.py @@ -3,7 +3,7 @@ import orjson import pytest from uuid import UUID, uuid4 -from sqlalchemy.orm import Session +from sqlmodel import Session from fastapi.testclient import TestClient diff --git a/tests/test_login.py b/tests/test_login.py new file mode 100644 index 000000000..07abb35ab --- /dev/null +++ b/tests/test_login.py @@ -0,0 +1,47 @@ +import pytest +from langflow.services.database.models.user import User +from langflow.services.auth.utils import get_password_hash + + +@pytest.fixture +def test_user(): + return User( + username="testuser", + password=get_password_hash( + "testpassword" + ), # Assuming password needs to be hashed + is_active=True, + is_superuser=False, + ) + + +def test_login_successful(client, test_user, session): + # Adding the test user to the database + session.add(test_user) + session.commit() + + response = client.post( + "api/v1/login", data={"username": "testuser", "password": "testpassword"} + ) + assert response.status_code == 200 + assert "access_token" in response.json() + + +def test_login_unsuccessful_wrong_username(client): + response = client.post( + "api/v1/login", data={"username": "wrongusername", "password": "testpassword"} + ) + assert response.status_code == 401 + assert response.json()["detail"] == "Incorrect username or password" + + +def test_login_unsuccessful_wrong_password(client, test_user, session): + # Adding the test user to the database + session.add(test_user) + session.commit() + + response = client.post( + "api/v1/login", data={"username": "testuser", "password": "wrongpassword"} + ) + assert response.status_code == 401 + assert response.json()["detail"] == "Incorrect username or password" diff --git a/tests/test_user.py b/tests/test_user.py new file mode 100644 index 000000000..d734e4d61 --- /dev/null +++ b/tests/test_user.py @@ -0,0 +1,213 @@ +from datetime import datetime +from langflow.services.auth.utils import create_super_user, get_password_hash + +from langflow.services.database.models.user.user import User +from langflow.services.utils import get_settings_manager +import pytest +from langflow.services.database.models.user import UserUpdate + + +@pytest.fixture +def super_user(client, session): + return create_super_user(session) + + +@pytest.fixture +def super_user_headers(client, super_user): + settings_manager = get_settings_manager() + auth_settings = settings_manager.auth_settings + login_data = { + "username": auth_settings.FIRST_SUPERUSER, + "password": auth_settings.FIRST_SUPERUSER_PASSWORD, + } + response = client.post("/api/v1/login", data=login_data) + assert response.status_code == 200 + tokens = response.json() + a_token = tokens["access_token"] + return {"Authorization": f"Bearer {a_token}"} + + +@pytest.fixture +def deactivated_user(session): + user = User( + username="deactivateduser", + password=get_password_hash("testpassword"), + is_active=False, + is_superuser=False, + last_login_at=datetime.now(), + ) + session.add(user) + session.commit() + return user + + +def test_user_waiting_for_approval(client, session): + # Create a user that is not active and has never logged in + user = User( + username="waitingforapproval", + password=get_password_hash("testpassword"), + is_active=False, + last_login_at=None, + ) + session.add(user) + session.commit() + + login_data = {"username": "waitingforapproval", "password": "testpassword"} + response = client.post("/api/v1/login", data=login_data) + assert response.status_code == 400 + assert response.json()["detail"] == "Waiting for approval" + + +def test_deactivated_user_cannot_login(client, deactivated_user): + login_data = {"username": deactivated_user.username, "password": "testpassword"} + response = client.post("/api/v1/login", data=login_data) + assert response.status_code == 400, response.json() + assert response.json()["detail"] == "Inactive user" + + +def test_deactivated_user_cannot_access(client, deactivated_user, logged_in_headers): + # Assuming the headers for deactivated_user + response = client.get("/api/v1/users", headers=logged_in_headers) + assert response.status_code == 400, response.json() + assert response.json()["detail"] == "The user doesn't have enough privileges" + + +def test_data_consistency_after_update(client, active_user, logged_in_headers): + user_id = active_user.id + update_data = UserUpdate(username="newname") + + response = client.patch( + f"/api/v1/user/{user_id}", json=update_data.dict(), headers=logged_in_headers + ) + assert response.status_code == 200 + + # Fetch the updated user from the database + response = client.get("/api/v1/user", headers=logged_in_headers) + assert response.json()["username"] == "newname", response.json() + + +def test_data_consistency_after_delete(client, test_user, super_user_headers): + user_id = test_user["id"] + response = client.delete(f"/api/v1/user/{user_id}", headers=super_user_headers) + assert response.status_code == 200 + + # Attempt to fetch the deleted user from the database + response = client.get("/api/v1/users", headers=super_user_headers) + assert response.status_code == 200 + assert all(user["id"] != user_id for user in response.json()["users"]) + + +def test_inactive_user(client, session): + # Create a user that is not active and has a last_login_at value + user = User( + username="inactiveuser", + password=get_password_hash("testpassword"), + is_active=False, + last_login_at="2023-01-01T00:00:00", # Set to a valid datetime string + ) + session.add(user) + session.commit() + + login_data = {"username": "inactiveuser", "password": "testpassword"} + response = client.post("/api/v1/login", data=login_data) + assert response.status_code == 400 + assert response.json()["detail"] == "Inactive user" + + +def test_add_user(client, test_user): + assert test_user["username"] == "testuser" + + +# This is not used in the Frontend at the moment +# def test_read_current_user(client: TestClient, active_user): +# # First we need to login to get the access token +# login_data = {"username": "testuser", "password": "testpassword"} +# response = client.post("/api/v1/login", data=login_data) +# assert response.status_code == 200 + +# headers = {"Authorization": f"Bearer {response.json()['access_token']}"} + +# response = client.get("/api/v1/user", headers=headers) +# assert response.status_code == 200, response.json() +# assert response.json()["username"] == "testuser" + + +def test_read_all_users(client, super_user_headers): + response = client.get("/api/v1/users", headers=super_user_headers) + assert response.status_code == 200, response.json() + assert isinstance(response.json()["users"], list) + + +def test_normal_user_cant_read_all_users(client, logged_in_headers): + response = client.get("/api/v1/users", headers=logged_in_headers) + assert response.status_code == 400, response.json() + assert response.json() == {"detail": "The user doesn't have enough privileges"} + + +def test_patch_user(client, active_user, logged_in_headers): + user_id = active_user.id + update_data = UserUpdate( + username="newname", + ) + + response = client.patch( + f"/api/v1/user/{user_id}", json=update_data.dict(), headers=logged_in_headers + ) + assert response.status_code == 200, response.json() + + +def test_patch_user_wrong_id(client, active_user, logged_in_headers): + user_id = "wrong_id" + update_data = UserUpdate( + username="newname", + ) + + response = client.patch( + f"/api/v1/user/{user_id}", json=update_data.dict(), headers=logged_in_headers + ) + assert response.status_code == 422, response.json() + assert response.json() == { + "detail": [ + { + "loc": ["path", "user_id"], + "msg": "value is not a valid uuid", + "type": "type_error.uuid", + } + ] + } + + +def test_delete_user(client, test_user, super_user_headers): + user_id = test_user["id"] + response = client.delete(f"/api/v1/user/{user_id}", headers=super_user_headers) + assert response.status_code == 200 + assert response.json() == {"detail": "User deleted"} + + +def test_delete_user_wrong_id(client, test_user, super_user_headers): + user_id = "wrong_id" + response = client.delete(f"/api/v1/user/{user_id}", headers=super_user_headers) + assert response.status_code == 422 + assert response.json() == { + "detail": [ + { + "loc": ["path", "user_id"], + "msg": "value is not a valid uuid", + "type": "type_error.uuid", + } + ] + } + + +def test_normal_user_cant_delete_user(client, test_user, logged_in_headers): + user_id = test_user["id"] + response = client.delete(f"/api/v1/user/{user_id}", headers=logged_in_headers) + assert response.status_code == 400 + assert response.json() == {"detail": "The user doesn't have enough privileges"} + + +# If you still want to test the superuser endpoint +def test_add_super_user_for_testing_purposes_delete_me_before_merge_into_dev(client): + response = client.post("/api/v1/super_user") + assert response.status_code == 200 + assert response.json()["username"] == "superuser"