fix: Fix db session used in different threads (#4381)
Fix db session used in different threads
This commit is contained in:
parent
e477782c46
commit
965271c6e3
4 changed files with 47 additions and 35 deletions
|
|
@ -21,7 +21,7 @@ router = APIRouter(tags=["Login"])
|
|||
|
||||
|
||||
@router.post("/login", response_model=Token)
|
||||
async def login_to_get_access_token(
|
||||
def login_to_get_access_token(
|
||||
response: Response,
|
||||
form_data: Annotated[OAuth2PasswordRequestForm, Depends()],
|
||||
db: DbSession,
|
||||
|
|
|
|||
|
|
@ -19,7 +19,7 @@ from langflow.services.database.models.api_key.crud import check_key
|
|||
from langflow.services.database.models.api_key.model import ApiKey
|
||||
from langflow.services.database.models.user.crud import get_user_by_id, get_user_by_username, update_user_last_login_at
|
||||
from langflow.services.database.models.user.model import User, UserRead
|
||||
from langflow.services.deps import get_session, get_settings_service
|
||||
from langflow.services.deps import get_db_service, get_session, get_settings_service
|
||||
from langflow.services.settings.service import SettingsService
|
||||
|
||||
oauth2_login = OAuth2PasswordBearer(tokenUrl="api/v1/login", auto_error=False)
|
||||
|
|
@ -36,41 +36,42 @@ MINIMUM_KEY_LENGTH = 32
|
|||
def api_key_security(
|
||||
query_param: Annotated[str, Security(api_key_query)],
|
||||
header_param: Annotated[str, Security(api_key_header)],
|
||||
db: Annotated[Session, Depends(get_session)],
|
||||
) -> UserRead | None:
|
||||
settings_service = get_settings_service()
|
||||
result: ApiKey | User | None = None
|
||||
if settings_service.auth_settings.AUTO_LOGIN:
|
||||
# Get the first user
|
||||
if not settings_service.auth_settings.SUPERUSER:
|
||||
|
||||
with get_db_service().with_session() as db:
|
||||
if settings_service.auth_settings.AUTO_LOGIN:
|
||||
# Get the first user
|
||||
if not settings_service.auth_settings.SUPERUSER:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail="Missing first superuser credentials",
|
||||
)
|
||||
|
||||
result = get_user_by_username(db, settings_service.auth_settings.SUPERUSER)
|
||||
|
||||
elif not query_param and not header_param:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail="Missing first superuser credentials",
|
||||
status_code=status.HTTP_403_FORBIDDEN,
|
||||
detail="An API key must be passed as query or header",
|
||||
)
|
||||
|
||||
result = get_user_by_username(db, settings_service.auth_settings.SUPERUSER)
|
||||
elif query_param:
|
||||
result = check_key(db, query_param)
|
||||
|
||||
elif not query_param and not header_param:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_403_FORBIDDEN,
|
||||
detail="An API key must be passed as query or header",
|
||||
)
|
||||
else:
|
||||
result = check_key(db, header_param)
|
||||
|
||||
elif query_param:
|
||||
result = check_key(db, query_param)
|
||||
|
||||
else:
|
||||
result = check_key(db, header_param)
|
||||
|
||||
if not result:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_403_FORBIDDEN,
|
||||
detail="Invalid or missing API key",
|
||||
)
|
||||
if isinstance(result, ApiKey):
|
||||
return UserRead.model_validate(result.user, from_attributes=True)
|
||||
if isinstance(result, User):
|
||||
return UserRead.model_validate(result, from_attributes=True)
|
||||
if not result:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_403_FORBIDDEN,
|
||||
detail="Invalid or missing API key",
|
||||
)
|
||||
if isinstance(result, ApiKey):
|
||||
return UserRead.model_validate(result.user, from_attributes=True)
|
||||
if isinstance(result, User):
|
||||
return UserRead.model_validate(result, from_attributes=True)
|
||||
msg = "Invalid result type"
|
||||
raise ValueError(msg)
|
||||
|
||||
|
|
@ -83,7 +84,7 @@ async def get_current_user(
|
|||
) -> User:
|
||||
if token:
|
||||
return await get_current_user_by_jwt(token, db)
|
||||
user = await asyncio.to_thread(api_key_security, query_param, header_param, db)
|
||||
user = await asyncio.to_thread(api_key_security, query_param, header_param)
|
||||
if user:
|
||||
return user
|
||||
|
||||
|
|
@ -164,17 +165,17 @@ async def get_current_user_for_websocket(
|
|||
if token:
|
||||
return await get_current_user_by_jwt(token, db)
|
||||
if api_key:
|
||||
return await asyncio.to_thread(api_key_security, api_key, query_param, db)
|
||||
return await asyncio.to_thread(api_key_security, api_key, query_param)
|
||||
return None
|
||||
|
||||
|
||||
def get_current_active_user(current_user: Annotated[User, Depends(get_current_user)]):
|
||||
async def get_current_active_user(current_user: Annotated[User, Depends(get_current_user)]):
|
||||
if not current_user.is_active:
|
||||
raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED, detail="Inactive user")
|
||||
return current_user
|
||||
|
||||
|
||||
def get_current_active_superuser(current_user: Annotated[User, Depends(get_current_user)]) -> User:
|
||||
async def get_current_active_superuser(current_user: Annotated[User, Depends(get_current_user)]) -> User:
|
||||
if not current_user.is_active:
|
||||
raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED, detail="Inactive user")
|
||||
if not current_user.is_superuser:
|
||||
|
|
|
|||
|
|
@ -1,5 +1,6 @@
|
|||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import time
|
||||
from contextlib import contextmanager
|
||||
from datetime import datetime, timezone
|
||||
|
|
@ -317,7 +318,7 @@ class DatabaseService(Service):
|
|||
|
||||
logger.debug("Database and tables created successfully")
|
||||
|
||||
async def teardown(self) -> None:
|
||||
def _teardown(self) -> None:
|
||||
logger.debug("Tearing down database")
|
||||
try:
|
||||
settings_service = get_settings_service()
|
||||
|
|
@ -330,3 +331,6 @@ class DatabaseService(Service):
|
|||
logger.exception("Error tearing down database")
|
||||
|
||||
self.engine.dispose()
|
||||
|
||||
async def teardown(self) -> None:
|
||||
await asyncio.to_thread(self._teardown)
|
||||
|
|
|
|||
|
|
@ -1,3 +1,5 @@
|
|||
import asyncio
|
||||
|
||||
from loguru import logger
|
||||
from sqlmodel import Session, select
|
||||
|
||||
|
|
@ -110,10 +112,15 @@ def teardown_superuser(settings_service, session) -> None:
|
|||
raise RuntimeError(msg) from exc
|
||||
|
||||
|
||||
def _teardown_superuser():
|
||||
with get_db_service().with_session() as session:
|
||||
teardown_superuser(get_settings_service(), session)
|
||||
|
||||
|
||||
async def teardown_services() -> None:
|
||||
"""Teardown all the services."""
|
||||
try:
|
||||
teardown_superuser(get_settings_service(), next(get_session()))
|
||||
await asyncio.to_thread(_teardown_superuser)
|
||||
except Exception as exc: # noqa: BLE001
|
||||
logger.exception(exc)
|
||||
try:
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue