fix: Use AsyncSession for user management (#4491)
* Use AsyncSession for user management * Simplify check_key * Don't trigger blockbuster on settings service initialize * Fix mypy * Fix api key update_total_uses * Fix auto-login * Revert making CustomComponent.list_key_names async
This commit is contained in:
parent
2881346400
commit
6573ca14cc
24 changed files with 430 additions and 339 deletions
|
|
@ -90,6 +90,8 @@ def _wrap_file_read_blocking(func):
|
|||
"_read_pyc",
|
||||
}:
|
||||
return func(self, *args, **kwargs)
|
||||
if frame_info.filename.endswith("settings/service.py") and frame_info.function == "initialize":
|
||||
return func(self, *args, **kwargs)
|
||||
raise _blocking_error(func)
|
||||
|
||||
return file_op
|
||||
|
|
@ -104,6 +106,8 @@ def _wrap_file_write_blocking(func):
|
|||
for frame_info in inspect.stack():
|
||||
if frame_info.filename.endswith("_pytest/assertion/rewrite.py") and frame_info.function == "_write_pyc":
|
||||
return func(self, *args, **kwargs)
|
||||
if frame_info.filename.endswith("settings/service.py") and frame_info.function == "initialize":
|
||||
return func(self, *args, **kwargs)
|
||||
if self not in {sys.stdout, sys.stderr}:
|
||||
raise _blocking_error(func)
|
||||
return func(self, *args, **kwargs)
|
||||
|
|
|
|||
|
|
@ -24,7 +24,7 @@ async def test_initialize_services():
|
|||
"""Benchmark the initialization of services."""
|
||||
from langflow.services.utils import initialize_services
|
||||
|
||||
await asyncio.to_thread(initialize_services, fix_migration=False)
|
||||
await initialize_services(fix_migration=False)
|
||||
settings_service = await asyncio.to_thread(get_settings_service)
|
||||
assert "test_performance.db" in settings_service.settings.database_url
|
||||
|
||||
|
|
@ -45,8 +45,8 @@ async def test_initialize_super_user():
|
|||
from langflow.initial_setup.setup import initialize_super_user_if_needed
|
||||
from langflow.services.utils import initialize_services
|
||||
|
||||
await asyncio.to_thread(initialize_services, fix_migration=False)
|
||||
await asyncio.to_thread(initialize_super_user_if_needed)
|
||||
await initialize_services(fix_migration=False)
|
||||
await initialize_super_user_if_needed()
|
||||
settings_service = await asyncio.to_thread(get_settings_service)
|
||||
assert "test_performance.db" in settings_service.settings.database_url
|
||||
|
||||
|
|
@ -69,7 +69,7 @@ async def test_create_starter_projects():
|
|||
from langflow.interface.types import get_and_cache_all_types_dict
|
||||
from langflow.services.utils import initialize_services
|
||||
|
||||
await asyncio.to_thread(initialize_services, fix_migration=False)
|
||||
await initialize_services(fix_migration=False)
|
||||
settings_service = await asyncio.to_thread(get_settings_service)
|
||||
types_dict = await get_and_cache_all_types_dict(settings_service)
|
||||
await asyncio.to_thread(create_or_update_starter_projects, types_dict)
|
||||
|
|
@ -81,6 +81,6 @@ async def test_load_flows():
|
|||
"""Benchmark loading flows from directory."""
|
||||
from langflow.initial_setup.setup import load_flows_from_directory
|
||||
|
||||
await asyncio.to_thread(load_flows_from_directory)
|
||||
await load_flows_from_directory()
|
||||
settings_service = await asyncio.to_thread(get_settings_service)
|
||||
assert "test_performance.db" in settings_service.settings.database_url
|
||||
|
|
|
|||
|
|
@ -1,6 +1,6 @@
|
|||
from datetime import datetime
|
||||
from unittest.mock import patch
|
||||
from uuid import uuid4
|
||||
from uuid import UUID, uuid4
|
||||
|
||||
import pytest
|
||||
from langflow.services.database.models.variable.model import VariableUpdate
|
||||
|
|
@ -8,7 +8,9 @@ from langflow.services.deps import get_settings_service
|
|||
from langflow.services.settings.constants import VARIABLES_TO_GET_FROM_ENVIRONMENT
|
||||
from langflow.services.variable.constants import CREDENTIAL_TYPE, GENERIC_TYPE
|
||||
from langflow.services.variable.service import DatabaseVariableService
|
||||
from sqlmodel import Session, SQLModel, create_engine
|
||||
from sqlalchemy.ext.asyncio import create_async_engine
|
||||
from sqlmodel import Session, SQLModel
|
||||
from sqlmodel.ext.asyncio.session import AsyncSession
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
|
|
@ -18,114 +20,125 @@ def service():
|
|||
|
||||
|
||||
@pytest.fixture
|
||||
def session():
|
||||
engine = create_engine("sqlite:///:memory:")
|
||||
SQLModel.metadata.create_all(engine)
|
||||
with Session(engine) as session:
|
||||
async def session():
|
||||
engine = create_async_engine("sqlite+aiosqlite:///:memory:")
|
||||
async with engine.begin() as conn:
|
||||
await conn.run_sync(SQLModel.metadata.create_all)
|
||||
async with AsyncSession(engine) as session:
|
||||
yield session
|
||||
|
||||
|
||||
def test_initialize_user_variables__create_and_update(service, session):
|
||||
def _get_variable(
|
||||
session: Session,
|
||||
service,
|
||||
user_id: UUID | str,
|
||||
name: str,
|
||||
field: str,
|
||||
):
|
||||
return service.get_variable(user_id, name, field, session=session)
|
||||
|
||||
|
||||
async def test_initialize_user_variables__create_and_update(service, session: AsyncSession):
|
||||
user_id = uuid4()
|
||||
field = ""
|
||||
good_vars = {k: f"value{i}" for i, k in enumerate(VARIABLES_TO_GET_FROM_ENVIRONMENT)}
|
||||
bad_vars = {"VAR1": "value1", "VAR2": "value2", "VAR3": "value3"}
|
||||
env_vars = {**good_vars, **bad_vars}
|
||||
|
||||
service.create_variable(user_id, "OPENAI_API_KEY", "outdate", session=session)
|
||||
await service.create_variable(user_id, "OPENAI_API_KEY", "outdate", session=session)
|
||||
env_vars["OPENAI_API_KEY"] = "updated_value"
|
||||
|
||||
with patch.dict("os.environ", env_vars, clear=True):
|
||||
service.initialize_user_variables(user_id=user_id, session=session)
|
||||
await service.initialize_user_variables(user_id=user_id, session=session)
|
||||
|
||||
variables = service.list_variables(user_id, session=session)
|
||||
variables = await service.list_variables(user_id, session=session)
|
||||
for name in variables:
|
||||
value = service.get_variable(user_id, name, field, session=session)
|
||||
value = await session.run_sync(_get_variable, service, user_id, name, field)
|
||||
assert value == env_vars[name]
|
||||
|
||||
assert all(i in variables for i in good_vars)
|
||||
assert all(i not in variables for i in bad_vars)
|
||||
|
||||
|
||||
def test_initialize_user_variables__not_found_variable(service, session):
|
||||
async def test_initialize_user_variables__not_found_variable(service, session: AsyncSession):
|
||||
with patch("langflow.services.variable.service.DatabaseVariableService.create_variable") as m:
|
||||
m.side_effect = Exception()
|
||||
service.initialize_user_variables(uuid4(), session=session)
|
||||
await service.initialize_user_variables(uuid4(), session=session)
|
||||
assert True
|
||||
|
||||
|
||||
def test_initialize_user_variables__skipping_environment_variable_storage(service, session):
|
||||
async def test_initialize_user_variables__skipping_environment_variable_storage(service, session: AsyncSession):
|
||||
service.settings_service.settings.store_environment_variables = False
|
||||
service.initialize_user_variables(uuid4(), session=session)
|
||||
await service.initialize_user_variables(uuid4(), session=session)
|
||||
assert True
|
||||
|
||||
|
||||
def test_get_variable(service, session):
|
||||
async def test_get_variable(service, session: AsyncSession):
|
||||
user_id = uuid4()
|
||||
name = "name"
|
||||
value = "value"
|
||||
field = ""
|
||||
service.create_variable(user_id, name, value, session=session)
|
||||
await service.create_variable(user_id, name, value, session=session)
|
||||
|
||||
result = service.get_variable(user_id, name, field, session=session)
|
||||
result = await session.run_sync(_get_variable, service, user_id, name, field)
|
||||
|
||||
assert result == value
|
||||
|
||||
|
||||
def test_get_variable__valueerror(service, session):
|
||||
async def test_get_variable__valueerror(service, session: AsyncSession):
|
||||
user_id = uuid4()
|
||||
name = "name"
|
||||
field = ""
|
||||
|
||||
with pytest.raises(ValueError, match=f"{name} variable not found."):
|
||||
service.get_variable(user_id, name, field, session)
|
||||
await session.run_sync(_get_variable, service, user_id, name, field)
|
||||
|
||||
|
||||
def test_get_variable__typeerror(service, session):
|
||||
async def test_get_variable__typeerror(service, session: AsyncSession):
|
||||
user_id = uuid4()
|
||||
name = "name"
|
||||
value = "value"
|
||||
field = "session_id"
|
||||
_type = CREDENTIAL_TYPE
|
||||
service.create_variable(user_id, name, value, _type=_type, session=session)
|
||||
await service.create_variable(user_id, name, value, _type=_type, session=session)
|
||||
|
||||
with pytest.raises(TypeError) as exc:
|
||||
service.get_variable(user_id, name, field, session)
|
||||
await session.run_sync(_get_variable, service, user_id, name, field)
|
||||
|
||||
assert name in str(exc.value)
|
||||
assert "purpose is to prevent the exposure of value" in str(exc.value)
|
||||
|
||||
|
||||
def test_list_variables(service, session):
|
||||
async def test_list_variables(service, session: AsyncSession):
|
||||
user_id = uuid4()
|
||||
names = ["name1", "name2", "name3"]
|
||||
value = "value"
|
||||
for name in names:
|
||||
service.create_variable(user_id, name, value, session=session)
|
||||
await service.create_variable(user_id, name, value, session=session)
|
||||
|
||||
result = service.list_variables(user_id, session=session)
|
||||
result = await service.list_variables(user_id, session=session)
|
||||
|
||||
assert all(name in result for name in names)
|
||||
|
||||
|
||||
def test_list_variables__empty(service, session):
|
||||
result = service.list_variables(uuid4(), session=session)
|
||||
async def test_list_variables__empty(service, session: AsyncSession):
|
||||
result = await service.list_variables(uuid4(), session=session)
|
||||
|
||||
assert not result
|
||||
assert isinstance(result, list)
|
||||
|
||||
|
||||
def test_update_variable(service, session):
|
||||
async def test_update_variable(service, session: AsyncSession):
|
||||
user_id = uuid4()
|
||||
name = "name"
|
||||
old_value = "old_value"
|
||||
new_value = "new_value"
|
||||
field = ""
|
||||
service.create_variable(user_id, name, old_value, session=session)
|
||||
await service.create_variable(user_id, name, old_value, session=session)
|
||||
|
||||
old_recovered = service.get_variable(user_id, name, field, session=session)
|
||||
result = service.update_variable(user_id, name, new_value, session=session)
|
||||
new_recovered = service.get_variable(user_id, name, field, session=session)
|
||||
old_recovered = await session.run_sync(_get_variable, service, user_id, name, field)
|
||||
result = await service.update_variable(user_id, name, new_value, session=session)
|
||||
new_recovered = await session.run_sync(_get_variable, service, user_id, name, field)
|
||||
|
||||
assert old_value == old_recovered
|
||||
assert new_value == new_recovered
|
||||
|
|
@ -139,26 +152,26 @@ def test_update_variable(service, session):
|
|||
assert isinstance(result.updated_at, datetime)
|
||||
|
||||
|
||||
def test_update_variable__valueerror(service, session):
|
||||
async def test_update_variable__valueerror(service, session: AsyncSession):
|
||||
user_id = uuid4()
|
||||
name = "name"
|
||||
value = "value"
|
||||
|
||||
with pytest.raises(ValueError, match=f"{name} variable not found."):
|
||||
service.update_variable(user_id, name, value, session=session)
|
||||
await service.update_variable(user_id, name, value, session=session)
|
||||
|
||||
|
||||
def test_update_variable_fields(service, session):
|
||||
async def test_update_variable_fields(service, session: AsyncSession):
|
||||
user_id = uuid4()
|
||||
new_name = new_value = "donkey"
|
||||
variable = service.create_variable(user_id, "old_name", "old_value", session=session)
|
||||
variable = await service.create_variable(user_id, "old_name", "old_value", session=session)
|
||||
saved = variable.model_dump()
|
||||
variable = VariableUpdate(**saved)
|
||||
variable.name = new_name
|
||||
variable.value = new_value
|
||||
variable.default_fields = ["new_field"]
|
||||
|
||||
result = service.update_variable_fields(
|
||||
result = await service.update_variable_fields(
|
||||
user_id=user_id,
|
||||
variable_id=saved.get("id"),
|
||||
variable=variable,
|
||||
|
|
@ -177,58 +190,58 @@ def test_update_variable_fields(service, session):
|
|||
assert saved.get("updated_at") != result.updated_at
|
||||
|
||||
|
||||
def test_delete_variable(service, session):
|
||||
async def test_delete_variable(service, session: AsyncSession):
|
||||
user_id = uuid4()
|
||||
name = "name"
|
||||
value = "value"
|
||||
field = ""
|
||||
|
||||
service.create_variable(user_id, name, value, session=session)
|
||||
recovered = service.get_variable(user_id, name, field, session=session)
|
||||
service.delete_variable(user_id, name, session=session)
|
||||
await service.create_variable(user_id, name, value, session=session)
|
||||
recovered = await session.run_sync(_get_variable, service, user_id, name, field)
|
||||
await service.delete_variable(user_id, name, session=session)
|
||||
with pytest.raises(ValueError, match=f"{name} variable not found."):
|
||||
service.get_variable(user_id, name, field, session)
|
||||
await session.run_sync(_get_variable, service, user_id, name, field)
|
||||
|
||||
assert recovered == value
|
||||
|
||||
|
||||
def test_delete_variable__valueerror(service, session):
|
||||
async def test_delete_variable__valueerror(service, session: AsyncSession):
|
||||
user_id = uuid4()
|
||||
name = "name"
|
||||
|
||||
with pytest.raises(ValueError, match=f"{name} variable not found."):
|
||||
service.delete_variable(user_id, name, session=session)
|
||||
await service.delete_variable(user_id, name, session=session)
|
||||
|
||||
|
||||
def test_delete_variable_by_id(service, session):
|
||||
async def test_delete_variable_by_id(service, session: AsyncSession):
|
||||
user_id = uuid4()
|
||||
name = "name"
|
||||
value = "value"
|
||||
field = "field"
|
||||
|
||||
saved = service.create_variable(user_id, name, value, session=session)
|
||||
recovered = service.get_variable(user_id, name, field, session=session)
|
||||
service.delete_variable_by_id(user_id, saved.id, session=session)
|
||||
saved = await service.create_variable(user_id, name, value, session=session)
|
||||
recovered = await session.run_sync(_get_variable, service, user_id, name, field)
|
||||
await service.delete_variable_by_id(user_id, saved.id, session=session)
|
||||
with pytest.raises(ValueError, match=f"{name} variable not found."):
|
||||
service.get_variable(user_id, name, field, session)
|
||||
await session.run_sync(_get_variable, service, user_id, name, field)
|
||||
|
||||
assert recovered == value
|
||||
|
||||
|
||||
def test_delete_variable_by_id__valueerror(service, session):
|
||||
async def test_delete_variable_by_id__valueerror(service, session: AsyncSession):
|
||||
user_id = uuid4()
|
||||
variable_id = uuid4()
|
||||
|
||||
with pytest.raises(ValueError, match=f"{variable_id} variable not found."):
|
||||
service.delete_variable_by_id(user_id, variable_id, session=session)
|
||||
await service.delete_variable_by_id(user_id, variable_id, session=session)
|
||||
|
||||
|
||||
def test_create_variable(service, session):
|
||||
async def test_create_variable(service, session: AsyncSession):
|
||||
user_id = uuid4()
|
||||
name = "name"
|
||||
value = "value"
|
||||
|
||||
result = service.create_variable(user_id, name, value, session=session)
|
||||
result = await service.create_variable(user_id, name, value, session=session)
|
||||
|
||||
assert result.user_id == user_id
|
||||
assert result.name == name
|
||||
|
|
|
|||
|
|
@ -1,4 +1,5 @@
|
|||
from unittest.mock import MagicMock, patch
|
||||
import asyncio
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
|
||||
from langflow.services.settings.constants import (
|
||||
DEFAULT_SUPERUSER,
|
||||
|
|
@ -91,7 +92,7 @@ from langflow.services.utils import teardown_superuser
|
|||
|
||||
@patch("langflow.services.deps.get_settings_service")
|
||||
@patch("langflow.services.deps.get_session")
|
||||
def test_teardown_superuser_default_superuser(mock_get_session, mock_get_settings_service):
|
||||
async def test_teardown_superuser_default_superuser(mock_get_session, mock_get_settings_service):
|
||||
mock_settings_service = MagicMock()
|
||||
mock_settings_service.auth_settings.AUTO_LOGIN = True
|
||||
mock_settings_service.auth_settings.SUPERUSER = DEFAULT_SUPERUSER
|
||||
|
|
@ -104,29 +105,28 @@ def test_teardown_superuser_default_superuser(mock_get_session, mock_get_setting
|
|||
mock_session.query.return_value.filter.return_value.first.return_value = mock_user
|
||||
mock_get_session.return_value = iter([mock_session])
|
||||
|
||||
teardown_superuser(mock_settings_service, mock_session)
|
||||
await teardown_superuser(mock_settings_service, mock_session)
|
||||
|
||||
mock_session.query.assert_not_called()
|
||||
|
||||
|
||||
@patch("langflow.services.deps.get_settings_service")
|
||||
@patch("langflow.services.deps.get_session")
|
||||
def test_teardown_superuser_no_default_superuser(mock_get_session, mock_get_settings_service):
|
||||
async def test_teardown_superuser_no_default_superuser():
|
||||
admin_user_name = "admin_user"
|
||||
mock_settings_service = MagicMock()
|
||||
mock_settings_service.auth_settings.AUTO_LOGIN = False
|
||||
mock_settings_service.auth_settings.SUPERUSER = admin_user_name
|
||||
mock_settings_service.auth_settings.SUPERUSER_PASSWORD = "password" # noqa: S105
|
||||
mock_get_settings_service.return_value = mock_settings_service
|
||||
|
||||
mock_session = MagicMock()
|
||||
mock_session = AsyncMock(return_value=asyncio.Future())
|
||||
mock_user = MagicMock()
|
||||
mock_user.is_superuser = False
|
||||
mock_session.query.return_value.filter.return_value.first.return_value = mock_user
|
||||
mock_get_session.return_value = [mock_session]
|
||||
mock_user.last_login_at = None
|
||||
|
||||
teardown_superuser(mock_settings_service, mock_session)
|
||||
mock_result = MagicMock()
|
||||
mock_result.first.return_value = mock_user
|
||||
mock_session.exec.return_value = mock_result
|
||||
|
||||
mock_session.query.assert_not_called()
|
||||
mock_session.delete.assert_not_called()
|
||||
mock_session.commit.assert_not_called()
|
||||
await teardown_superuser(mock_settings_service, mock_session)
|
||||
|
||||
mock_session.delete.assert_not_awaited()
|
||||
mock_session.commit.assert_not_awaited()
|
||||
|
|
|
|||
|
|
@ -5,18 +5,18 @@ from httpx import AsyncClient
|
|||
from langflow.services.auth.utils import create_super_user, get_password_hash
|
||||
from langflow.services.database.models.user import UserUpdate
|
||||
from langflow.services.database.models.user.model import User
|
||||
from langflow.services.database.utils import session_getter
|
||||
from langflow.services.database.utils import async_session_getter, session_getter
|
||||
from langflow.services.deps import get_db_service, get_settings_service
|
||||
from sqlmodel import select
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def super_user(client): # noqa: ARG001
|
||||
async def super_user(client): # noqa: ARG001
|
||||
settings_manager = get_settings_service()
|
||||
auth_settings = settings_manager.auth_settings
|
||||
with session_getter(get_db_service()) as session:
|
||||
return create_super_user(
|
||||
db=session,
|
||||
async with async_session_getter(get_db_service()) as db:
|
||||
return await create_super_user(
|
||||
db=db,
|
||||
username=auth_settings.SUPERUSER,
|
||||
password=auth_settings.SUPERUSER_PASSWORD,
|
||||
)
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue