ref: make superuser creation more multi-user friendly (#9019)
* Fix superuser creationg race condition * remove now unnecessary race condition check * Add tests * [autofix.ci] apply automated fixes * ruff * [autofix.ci] apply automated fixes * clean up tests * [autofix.ci] apply automated fixes --------- Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com>
This commit is contained in:
parent
e6ca374d07
commit
115b0f535f
5 changed files with 129 additions and 15 deletions
|
|
@ -454,7 +454,7 @@ class MCPToolsComponent(ComponentWithCache):
|
|||
kwargs = {}
|
||||
for arg in tool_args:
|
||||
value = getattr(self, arg.name, None)
|
||||
if value:
|
||||
if value is not None:
|
||||
if isinstance(value, Message):
|
||||
kwargs[arg.name] = value.text
|
||||
else:
|
||||
|
|
|
|||
File diff suppressed because one or more lines are too long
|
|
@ -11,6 +11,7 @@ from fastapi import Depends, HTTPException, Security, WebSocketException, status
|
|||
from fastapi.security import APIKeyHeader, APIKeyQuery, OAuth2PasswordBearer
|
||||
from jose import JWTError, jwt
|
||||
from loguru import logger
|
||||
from sqlalchemy.exc import IntegrityError
|
||||
from sqlmodel.ext.asyncio.session import AsyncSession
|
||||
from starlette.websockets import WebSocket
|
||||
|
||||
|
|
@ -299,8 +300,17 @@ async def create_super_user(
|
|||
)
|
||||
|
||||
db.add(super_user)
|
||||
await db.commit()
|
||||
await db.refresh(super_user)
|
||||
try:
|
||||
await db.commit()
|
||||
await db.refresh(super_user)
|
||||
except IntegrityError:
|
||||
# Race condition - another worker created the user
|
||||
await db.rollback()
|
||||
super_user = await get_user_by_username(db, username)
|
||||
if not super_user:
|
||||
raise # Re-raise if it's not a race condition
|
||||
except Exception: # noqa: BLE001
|
||||
logger.opt(exception=True).debug("Error creating superuser.")
|
||||
|
||||
return super_user
|
||||
|
||||
|
|
|
|||
|
|
@ -65,16 +65,7 @@ async def get_or_create_super_user(session: AsyncSession, username, password, is
|
|||
logger.debug("Creating default superuser.")
|
||||
else:
|
||||
logger.debug("Creating superuser.")
|
||||
try:
|
||||
return await create_super_user(username, password, db=session)
|
||||
except Exception as exc: # noqa: BLE001
|
||||
if "UNIQUE constraint failed: user.username" in str(exc):
|
||||
# This is to deal with workers running this
|
||||
# at startup and trying to create the superuser
|
||||
# at the same time.
|
||||
logger.opt(exception=True).debug("Superuser already exists.")
|
||||
return None
|
||||
logger.opt(exception=True).debug("Error creating superuser.")
|
||||
return await create_super_user(username, password, db=session)
|
||||
|
||||
|
||||
async def setup_superuser(settings_service, session: AsyncSession) -> None:
|
||||
|
|
|
|||
|
|
@ -1,11 +1,15 @@
|
|||
import asyncio
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
|
||||
import pytest
|
||||
from langflow.services.auth.utils import create_super_user
|
||||
from langflow.services.database.models.user.model import User
|
||||
from langflow.services.settings.constants import (
|
||||
DEFAULT_SUPERUSER,
|
||||
DEFAULT_SUPERUSER_PASSWORD,
|
||||
)
|
||||
from langflow.services.utils import teardown_superuser
|
||||
from sqlalchemy.exc import IntegrityError
|
||||
|
||||
# @patch("langflow.services.deps.get_session")
|
||||
# @patch("langflow.services.utils.create_super_user")
|
||||
|
|
@ -130,3 +134,112 @@ async def test_teardown_superuser_no_default_superuser():
|
|||
|
||||
mock_session.delete.assert_not_awaited()
|
||||
mock_session.commit.assert_not_awaited()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_create_super_user_race_condition():
|
||||
"""Test create_super_user handles race conditions gracefully when multiple workers try to create the same user."""
|
||||
# Mock the database session
|
||||
mock_session = AsyncMock()
|
||||
|
||||
# Create a mock user that will be "created" by the first worker
|
||||
mock_user = MagicMock(spec=User)
|
||||
mock_user.username = "testuser"
|
||||
mock_user.is_superuser = True
|
||||
|
||||
# Mock get_password_hash to return a fixed value
|
||||
mock_get_password_hash = MagicMock(return_value="hashed_password")
|
||||
|
||||
# Set up the race condition scenario:
|
||||
# 1. First call to get_user_by_username returns None (user doesn't exist)
|
||||
# 2. commit() raises IntegrityError (simulating race condition)
|
||||
# 3. After rollback, second call to get_user_by_username returns the existing user
|
||||
mock_get_user_by_username = AsyncMock()
|
||||
mock_get_user_by_username.side_effect = [None, mock_user] # None first, then existing user
|
||||
|
||||
mock_session.commit.side_effect = IntegrityError("statement", "params", Exception("orig"))
|
||||
with (
|
||||
patch("langflow.services.auth.utils.get_user_by_username", mock_get_user_by_username),
|
||||
patch("langflow.services.auth.utils.get_password_hash", mock_get_password_hash),
|
||||
patch("langflow.services.database.models.user.model.User") as mock_user_class,
|
||||
):
|
||||
# Configure the User class mock to return our mock_user when instantiated
|
||||
mock_user_class.return_value = mock_user
|
||||
|
||||
result = await create_super_user("testuser", "password", mock_session)
|
||||
|
||||
# Verify that the function handled the race condition correctly
|
||||
assert result == mock_user
|
||||
assert mock_session.add.call_count == 1 # User was added to session
|
||||
assert mock_session.commit.call_count == 1 # Commit was attempted once (and failed)
|
||||
assert mock_session.rollback.call_count == 1 # Session was rolled back after IntegrityError
|
||||
assert mock_get_user_by_username.call_count == 2 # Called twice: initial check + after rollback
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_create_super_user_race_condition_no_user_found():
|
||||
"""Test that create_super_user re-raises exception if no user is found after IntegrityError."""
|
||||
# Mock the database session
|
||||
mock_session = AsyncMock()
|
||||
|
||||
# Mock get_user_by_username to always return None (even after rollback)
|
||||
mock_get_user_by_username = AsyncMock()
|
||||
mock_get_user_by_username.side_effect = [None, None] # None for initial check and after rollback
|
||||
|
||||
# Mock other dependencies
|
||||
mock_get_password_hash = MagicMock(return_value="hashed_password")
|
||||
mock_user = MagicMock(spec=User)
|
||||
|
||||
# Set up scenario where IntegrityError occurs but no user is found afterward
|
||||
integrity_error = IntegrityError("statement", "params", Exception("orig"))
|
||||
mock_session.commit.side_effect = integrity_error
|
||||
|
||||
with (
|
||||
patch("langflow.services.auth.utils.get_user_by_username", mock_get_user_by_username),
|
||||
patch("langflow.services.auth.utils.get_password_hash", mock_get_password_hash),
|
||||
patch("langflow.services.database.models.user.model.User", return_value=mock_user),
|
||||
pytest.raises(IntegrityError),
|
||||
):
|
||||
await create_super_user("testuser", "password", mock_session)
|
||||
|
||||
# Verify rollback was called but exception was re-raised
|
||||
assert mock_session.rollback.call_count == 1
|
||||
assert mock_get_user_by_username.call_count == 2 # Initial + after rollback
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_create_super_user_concurrent_workers():
|
||||
"""Test multiple concurrent calls to create_super_user with the same username."""
|
||||
# This would require a real database to properly test, but we can simulate
|
||||
# the behavior with mocks to verify the logic works correctly
|
||||
|
||||
mock_session1 = AsyncMock()
|
||||
mock_session2 = AsyncMock()
|
||||
|
||||
# Create mock users
|
||||
mock_user = MagicMock(spec=User)
|
||||
mock_user.username = "admin"
|
||||
mock_user.is_superuser = True
|
||||
|
||||
mock_get_user_by_username = AsyncMock()
|
||||
|
||||
# Worker 1 succeeds, Worker 2 gets IntegrityError then finds existing user
|
||||
mock_session1.commit.return_value = None # Success
|
||||
mock_session2.commit.side_effect = IntegrityError("statement", "params", Exception("orig")) # Race condition
|
||||
|
||||
# get_user_by_username returns None initially, then the created user for worker 2
|
||||
mock_get_user_by_username.side_effect = [None, None, mock_user]
|
||||
|
||||
with patch("langflow.services.auth.utils.get_user_by_username", mock_get_user_by_username):
|
||||
# Simulate concurrent execution using asyncio.gather
|
||||
result1, result2 = await asyncio.gather(
|
||||
create_super_user("admin", "password", mock_session1),
|
||||
create_super_user("admin", "password", mock_session2),
|
||||
)
|
||||
|
||||
# Both workers should end up with a user (worker 1 creates, worker 2 finds existing)
|
||||
assert result1 is not None
|
||||
assert result2 == mock_user
|
||||
|
||||
# Worker 2 should have rolled back and fetched existing user
|
||||
assert mock_session2.rollback.call_count == 1
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue