diff --git a/src/backend/base/langflow/api/v1/login.py b/src/backend/base/langflow/api/v1/login.py index e04469c57..05d583e75 100644 --- a/src/backend/base/langflow/api/v1/login.py +++ b/src/backend/base/langflow/api/v1/login.py @@ -21,7 +21,7 @@ router = APIRouter(tags=["Login"]) @router.post("/login", response_model=Token) -async def login_to_get_access_token( +def login_to_get_access_token( response: Response, form_data: Annotated[OAuth2PasswordRequestForm, Depends()], db: DbSession, diff --git a/src/backend/base/langflow/services/auth/utils.py b/src/backend/base/langflow/services/auth/utils.py index 6b61f7167..71dec85ec 100644 --- a/src/backend/base/langflow/services/auth/utils.py +++ b/src/backend/base/langflow/services/auth/utils.py @@ -19,7 +19,7 @@ from langflow.services.database.models.api_key.crud import check_key from langflow.services.database.models.api_key.model import ApiKey from langflow.services.database.models.user.crud import get_user_by_id, get_user_by_username, update_user_last_login_at from langflow.services.database.models.user.model import User, UserRead -from langflow.services.deps import get_session, get_settings_service +from langflow.services.deps import get_db_service, get_session, get_settings_service from langflow.services.settings.service import SettingsService oauth2_login = OAuth2PasswordBearer(tokenUrl="api/v1/login", auto_error=False) @@ -36,41 +36,42 @@ MINIMUM_KEY_LENGTH = 32 def api_key_security( query_param: Annotated[str, Security(api_key_query)], header_param: Annotated[str, Security(api_key_header)], - db: Annotated[Session, Depends(get_session)], ) -> UserRead | None: settings_service = get_settings_service() result: ApiKey | User | None = None - if settings_service.auth_settings.AUTO_LOGIN: - # Get the first user - if not settings_service.auth_settings.SUPERUSER: + + with get_db_service().with_session() as db: + if settings_service.auth_settings.AUTO_LOGIN: + # Get the first user + if not settings_service.auth_settings.SUPERUSER: + raise HTTPException( + status_code=status.HTTP_400_BAD_REQUEST, + detail="Missing first superuser credentials", + ) + + result = get_user_by_username(db, settings_service.auth_settings.SUPERUSER) + + elif not query_param and not header_param: raise HTTPException( - status_code=status.HTTP_400_BAD_REQUEST, - detail="Missing first superuser credentials", + status_code=status.HTTP_403_FORBIDDEN, + detail="An API key must be passed as query or header", ) - result = get_user_by_username(db, settings_service.auth_settings.SUPERUSER) + elif query_param: + result = check_key(db, query_param) - elif not query_param and not header_param: - raise HTTPException( - status_code=status.HTTP_403_FORBIDDEN, - detail="An API key must be passed as query or header", - ) + else: + result = check_key(db, header_param) - elif query_param: - result = check_key(db, query_param) - - else: - result = check_key(db, header_param) - - if not result: - raise HTTPException( - status_code=status.HTTP_403_FORBIDDEN, - detail="Invalid or missing API key", - ) - if isinstance(result, ApiKey): - return UserRead.model_validate(result.user, from_attributes=True) - if isinstance(result, User): - return UserRead.model_validate(result, from_attributes=True) + if not result: + raise HTTPException( + status_code=status.HTTP_403_FORBIDDEN, + detail="Invalid or missing API key", + ) + if isinstance(result, ApiKey): + return UserRead.model_validate(result.user, from_attributes=True) + if isinstance(result, User): + return UserRead.model_validate(result, from_attributes=True) msg = "Invalid result type" raise ValueError(msg) @@ -83,7 +84,7 @@ async def get_current_user( ) -> User: if token: return await get_current_user_by_jwt(token, db) - user = await asyncio.to_thread(api_key_security, query_param, header_param, db) + user = await asyncio.to_thread(api_key_security, query_param, header_param) if user: return user @@ -164,17 +165,17 @@ async def get_current_user_for_websocket( if token: return await get_current_user_by_jwt(token, db) if api_key: - return await asyncio.to_thread(api_key_security, api_key, query_param, db) + return await asyncio.to_thread(api_key_security, api_key, query_param) return None -def get_current_active_user(current_user: Annotated[User, Depends(get_current_user)]): +async def get_current_active_user(current_user: Annotated[User, Depends(get_current_user)]): if not current_user.is_active: raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED, detail="Inactive user") return current_user -def get_current_active_superuser(current_user: Annotated[User, Depends(get_current_user)]) -> User: +async def get_current_active_superuser(current_user: Annotated[User, Depends(get_current_user)]) -> User: if not current_user.is_active: raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED, detail="Inactive user") if not current_user.is_superuser: diff --git a/src/backend/base/langflow/services/database/service.py b/src/backend/base/langflow/services/database/service.py index 5e8a33dc6..b5059eb79 100644 --- a/src/backend/base/langflow/services/database/service.py +++ b/src/backend/base/langflow/services/database/service.py @@ -1,5 +1,6 @@ from __future__ import annotations +import asyncio import time from contextlib import contextmanager from datetime import datetime, timezone @@ -317,7 +318,7 @@ class DatabaseService(Service): logger.debug("Database and tables created successfully") - async def teardown(self) -> None: + def _teardown(self) -> None: logger.debug("Tearing down database") try: settings_service = get_settings_service() @@ -330,3 +331,6 @@ class DatabaseService(Service): logger.exception("Error tearing down database") self.engine.dispose() + + async def teardown(self) -> None: + await asyncio.to_thread(self._teardown) diff --git a/src/backend/base/langflow/services/utils.py b/src/backend/base/langflow/services/utils.py index b1cb950da..a017e7934 100644 --- a/src/backend/base/langflow/services/utils.py +++ b/src/backend/base/langflow/services/utils.py @@ -1,3 +1,5 @@ +import asyncio + from loguru import logger from sqlmodel import Session, select @@ -110,10 +112,15 @@ def teardown_superuser(settings_service, session) -> None: raise RuntimeError(msg) from exc +def _teardown_superuser(): + with get_db_service().with_session() as session: + teardown_superuser(get_settings_service(), session) + + async def teardown_services() -> None: """Teardown all the services.""" try: - teardown_superuser(get_settings_service(), next(get_session())) + await asyncio.to_thread(_teardown_superuser) except Exception as exc: # noqa: BLE001 logger.exception(exc) try: