ref: Make initialize_database async (#5163)

Make initialize_database async
This commit is contained in:
Christophe Bornet 2024-12-10 07:44:34 +01:00 committed by GitHub
commit 63bdcb9d03
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
10 changed files with 70 additions and 104 deletions

View file

@ -26,7 +26,7 @@ from sqlmodel import select
from langflow.logging.logger import configure, logger
from langflow.main import setup_app
from langflow.services.database.models.folder.utils import create_default_folder_if_it_doesnt_exist
from langflow.services.database.utils import async_session_getter
from langflow.services.database.utils import session_getter
from langflow.services.deps import async_session_scope, get_db_service, get_settings_service
from langflow.services.settings.constants import DEFAULT_SUPERUSER
from langflow.services.utils import initialize_services
@ -423,7 +423,7 @@ def superuser(
async def _create_superuser():
await initialize_services()
async with async_session_getter(db_service) as session:
async with session_getter(db_service) as session:
from langflow.services.auth.utils import create_super_user
if await create_super_user(db=session, username=username, password=password):
@ -485,6 +485,15 @@ def copy_db() -> None:
typer.echo("Pre-release database not found in the cache directory.")
async def _migration(*, test: bool, fix: bool) -> None:
    """Initialize services, optionally apply migrations, then report the test results.

    Args:
        test: When true, the real migration run is skipped and only the
            migration test is executed and displayed.
        fix: Forwarded to ``initialize_services`` as ``fix_migration``.
    """
    await initialize_services(fix_migration=fix)
    database = get_db_service()
    if not test:
        # Only touch the schema when the caller explicitly left test mode.
        await database.run_migrations()
    display_results(database.run_migrations_test())
@app.command()
def migration(
test: bool = typer.Option(default=True, help="Run migrations in test mode."), # noqa: FBT001
@ -499,12 +508,7 @@ def migration(
):
raise typer.Abort
asyncio.run(initialize_services(fix_migration=fix))
db_service = get_db_service()
if not test:
db_service.run_migrations()
results = db_service.run_migrations_test()
display_results(results)
asyncio.run(_migration(test=test, fix=fix))
@app.command()

View file

@ -18,7 +18,7 @@ from langflow.services.database.models.transactions.crud import log_transaction
from langflow.services.database.models.transactions.model import TransactionBase
from langflow.services.database.models.vertex_builds.crud import log_vertex_build as crud_log_vertex_build
from langflow.services.database.models.vertex_builds.model import VertexBuildBase
from langflow.services.database.utils import async_session_getter
from langflow.services.database.utils import session_getter
from langflow.services.deps import get_db_service, get_settings_service
if TYPE_CHECKING:
@ -157,7 +157,7 @@ async def log_transaction(
error=error,
flow_id=flow_id if isinstance(flow_id, UUID) else UUID(flow_id),
)
async with async_session_getter(get_db_service()) as session:
async with session_getter(get_db_service()) as session:
inserted = await crud_log_transaction(session, transaction)
logger.debug(f"Logged transaction: {inserted.id}")
except Exception: # noqa: BLE001
@ -186,7 +186,7 @@ async def log_vertex_build(
# ugly hack to get the model dump with weird datatypes
artifacts=json.loads(json.dumps(artifacts, default=str)),
)
async with async_session_getter(get_db_service()) as session:
async with session_getter(get_db_service()) as session:
inserted = await crud_log_vertex_build(session, vertex_build)
logger.debug(f"Logged vertex build: {inserted.build_id}")
except Exception: # noqa: BLE001

View file

@ -10,7 +10,7 @@ from sqlmodel.ext.asyncio.session import AsyncSession
from langflow.services.database.models import User
from langflow.services.database.models.api_key import ApiKey, ApiKeyCreate, ApiKeyRead, UnmaskedApiKeyRead
from langflow.services.database.utils import async_session_getter
from langflow.services.database.utils import session_getter
from langflow.services.deps import get_db_service
if TYPE_CHECKING:
@ -68,7 +68,7 @@ async def check_key(session: AsyncSession, api_key: str) -> User | None:
async def update_total_uses(api_key_id: UUID):
"""Update the total uses and last used at."""
async with async_session_getter(get_db_service()) as session:
async with session_getter(get_db_service()) as session:
new_api_key = await session.get(ApiKey, api_key_id)
if new_api_key is None:
msg = "API Key not found"

View file

@ -1,20 +0,0 @@
from uuid import UUID
from langflow.services.database.models.message.model import MessageTable, MessageUpdate
from langflow.services.deps import session_scope
def update_message(message_id: UUID | str, message: MessageUpdate | dict):
    """Apply a partial update to a stored message and return the refreshed row.

    Args:
        message_id: Primary key of the ``MessageTable`` row to update.
        message: The changes, either as a ``MessageUpdate`` or as a dict that
            is expanded into one.

    Returns:
        The updated ``MessageTable`` instance after commit and refresh.

    Raises:
        ValueError: If no row is found for ``message_id``.
    """
    update = message if isinstance(message, MessageUpdate) else MessageUpdate(**message)
    with session_scope() as session:
        stored = session.get(MessageTable, message_id)
        if not stored:
            msg = "Message not found"
            raise ValueError(msg)
        # Only fields that were explicitly set and are not None are written back.
        changes = update.model_dump(exclude_unset=True, exclude_none=True)
        stored.sqlmodel_update(changes)
        session.add(stored)
        session.commit()
        session.refresh(stored)
        return stored

View file

@ -218,8 +218,8 @@ class DatabaseService(Service):
return new_name
def check_schema_health(self) -> bool:
inspector = inspect(self.engine)
def _check_schema_health(self, connection) -> bool:
inspector = inspect(connection)
model_mapping: dict[str, type[SQLModel]] = {
"flow": models.Flow,
@ -251,6 +251,10 @@ class DatabaseService(Service):
return True
async def check_schema_health(self) -> bool:
    """Run the schema health check through the async engine.

    Opens an async session, takes a connection from the session's bind and
    executes the synchronous ``_check_schema_health`` on it via ``run_sync``.

    Returns:
        The boolean produced by ``_check_schema_health``. It was previously
        computed and then discarded; callers that ignore the return value are
        unaffected.
    """
    async with self.with_async_session() as session, session.bind.connect() as conn:
        # run_sync hands the callable a sync connection and passes its result back.
        return await conn.run_sync(self._check_schema_health)
def init_alembic(self, alembic_cfg) -> None:
logger.info("Initializing alembic")
command.ensure_version(alembic_cfg)
@ -258,7 +262,7 @@ class DatabaseService(Service):
command.upgrade(alembic_cfg, "head")
logger.info("Alembic initialized")
def run_migrations(self, *, fix=False) -> None:
def _run_migrations(self, should_initialize_alembic, fix) -> None:
# First we need to check if alembic has been initialized
# If not, we need to initialize it
# if not self.script_location.exists(): # this is not the correct way to check if alembic has been initialized
@ -274,16 +278,6 @@ class DatabaseService(Service):
alembic_cfg.set_main_option("script_location", str(self.script_location))
alembic_cfg.set_main_option("sqlalchemy.url", self.database_url.replace("%", "%%"))
should_initialize_alembic = False
with self.with_session() as session:
# If the table does not exist it throws an error
# so we need to catch it
try:
session.exec(text("SELECT * FROM alembic_version"))
except Exception: # noqa: BLE001
logger.debug("Alembic not initialized")
should_initialize_alembic = True
if should_initialize_alembic:
try:
self.init_alembic(alembic_cfg)
@ -317,6 +311,18 @@ class DatabaseService(Service):
if fix:
self.try_downgrade_upgrade_until_success(alembic_cfg)
async def run_migrations(self, *, fix=False) -> None:
    """Probe for the alembic version table, then run the migrations in a worker thread.

    Args:
        fix: Passed through unchanged to ``_run_migrations``.
    """
    alembic_missing = False
    async with self.with_async_session() as session:
        # Selecting from a missing table raises, so a failure here is treated
        # as "alembic has not been initialized yet".
        try:
            await session.exec(text("SELECT * FROM alembic_version"))
        except Exception:  # noqa: BLE001
            logger.debug("Alembic not initialized")
            alembic_missing = True
    # _run_migrations is synchronous, so it is dispatched with asyncio.to_thread.
    await asyncio.to_thread(self._run_migrations, alembic_missing, fix)
def try_downgrade_upgrade_until_success(self, alembic_cfg, retries=5) -> None:
# Try -1 then head, if it fails, try -2 then head, etc.
# until we reach the number of retries
@ -363,10 +369,11 @@ class DatabaseService(Service):
results.append(Result(name=column, type="column", success=True))
return results
def create_db_and_tables(self) -> None:
@staticmethod
def _create_db_and_tables(connection) -> None:
from sqlalchemy import inspect
inspector = inspect(self.engine)
inspector = inspect(connection)
table_names = inspector.get_table_names()
current_tables = ["flow", "user", "apikey", "folder", "message", "variable", "transaction", "vertex_build"]
@ -378,7 +385,7 @@ class DatabaseService(Service):
for table in SQLModel.metadata.sorted_tables:
try:
table.create(self.engine, checkfirst=True)
table.create(connection, checkfirst=True)
except OperationalError as oe:
logger.warning(f"Table {table} already exists, skipping. Exception: {oe}")
except Exception as exc:
@ -387,7 +394,7 @@ class DatabaseService(Service):
raise RuntimeError(msg) from exc
# Now check if the required tables exist, if not, something went wrong.
inspector = inspect(self.engine)
inspector = inspect(connection)
table_names = inspector.get_table_names()
for table in current_tables:
if table not in table_names:
@ -398,6 +405,10 @@ class DatabaseService(Service):
logger.debug("Database and tables created successfully")
async def create_db_and_tables(self) -> None:
    """Create the database tables by running ``_create_db_and_tables`` on an async connection."""
    async with self.with_async_session() as session:
        async with session.bind.connect() as conn:
            # The table-creation helper is synchronous; run_sync bridges it
            # onto this async connection.
            await conn.run_sync(self._create_db_and_tables)
async def teardown(self) -> None:
logger.debug("Tearing down database")
try:

View file

@ -1,25 +1,25 @@
from __future__ import annotations
from contextlib import asynccontextmanager, contextmanager
from contextlib import asynccontextmanager
from dataclasses import dataclass
from typing import TYPE_CHECKING
from alembic.util.exc import CommandError
from loguru import logger
from sqlmodel import Session, text
from sqlmodel import text
from sqlmodel.ext.asyncio.session import AsyncSession
if TYPE_CHECKING:
from langflow.services.database.service import DatabaseService
def initialize_database(*, fix_migration: bool = False) -> None:
async def initialize_database(*, fix_migration: bool = False) -> None:
logger.debug("Initializing database")
from langflow.services.deps import get_db_service
database_service: DatabaseService = get_db_service()
try:
database_service.create_db_and_tables()
await database_service.create_db_and_tables()
except Exception as exc:
# if the exception involves tables already existing
# we can ignore it
@ -28,13 +28,13 @@ def initialize_database(*, fix_migration: bool = False) -> None:
logger.exception(msg)
raise RuntimeError(msg) from exc
try:
database_service.check_schema_health()
await database_service.check_schema_health()
except Exception as exc:
msg = "Error checking schema health"
logger.exception(msg)
raise RuntimeError(msg) from exc
try:
database_service.run_migrations(fix=fix_migration)
await database_service.run_migrations(fix=fix_migration)
except CommandError as exc:
# if "overlaps with other requested revisions" or "Can't locate revision identified by"
# are not in the exception, we can't handle it
@ -46,9 +46,9 @@ def initialize_database(*, fix_migration: bool = False) -> None:
# We need to delete the alembic_version table
# and run the migrations again
logger.warning("Wrong revision in DB, deleting alembic_version table and running migrations again")
with session_getter(database_service) as session:
session.exec(text("DROP TABLE alembic_version"))
database_service.run_migrations(fix=fix_migration)
async with session_getter(database_service) as session:
await session.exec(text("DROP TABLE alembic_version"))
await database_service.run_migrations(fix=fix_migration)
except Exception as exc:
# if the exception involves tables already existing
# we can ignore it
@ -58,21 +58,8 @@ def initialize_database(*, fix_migration: bool = False) -> None:
logger.debug("Database initialized")
@contextmanager
def session_getter(db_service: DatabaseService):
try:
session = Session(db_service.engine)
yield session
except Exception:
logger.exception("Session rollback because of exception")
session.rollback()
raise
finally:
session.close()
@asynccontextmanager
async def async_session_getter(db_service: DatabaseService):
async def session_getter(db_service: DatabaseService):
try:
session = AsyncSession(db_service.async_engine, expire_on_commit=False)
yield session

View file

@ -237,7 +237,7 @@ async def initialize_services(*, fix_migration: bool = False) -> None:
# Test cache connection
get_service(ServiceType.CACHE_SERVICE, default=CacheServiceFactory())
# Setup the superuser
await asyncio.to_thread(initialize_database, fix_migration=fix_migration)
await initialize_database(fix_migration=fix_migration)
async with get_db_service().with_async_session() as session:
settings_service = get_service(ServiceType.SETTINGS_SERVICE)
await setup_superuser(settings_service, session)

View file

@ -5,9 +5,8 @@ import shutil
# we need to import tmpdir
import tempfile
from collections.abc import AsyncGenerator
from contextlib import contextmanager, suppress
from contextlib import suppress
from pathlib import Path
from typing import TYPE_CHECKING
from uuid import UUID
import anyio
@ -27,7 +26,7 @@ from langflow.services.database.models.folder.model import Folder
from langflow.services.database.models.transactions.model import TransactionTable
from langflow.services.database.models.user.model import User, UserCreate, UserRead
from langflow.services.database.models.vertex_builds.crud import delete_vertex_builds_by_flow_id
from langflow.services.database.utils import async_session_getter
from langflow.services.database.utils import session_getter
from langflow.services.deps import get_db_service
from loguru import logger
from sqlalchemy.ext.asyncio import create_async_engine
@ -39,10 +38,6 @@ from typer.testing import CliRunner
from tests.api_keys import get_openai_api_key
if TYPE_CHECKING:
from langflow.services.database.service import DatabaseService
load_dotenv()
@ -369,17 +364,6 @@ async def client_fixture(
await anyio.Path(db_path).unlink()
# create a fixture for session_getter above
@pytest.fixture(name="session_getter")
def session_getter_fixture(client):  # noqa: ARG001
    """Provide a minimal synchronous session factory under the fixture name ``session_getter``.

    ``client`` is requested but never used in the body (presumably so the
    application and database are set up first; confirm against the fixture).
    """

    @contextmanager
    def blank_session_getter(db_service: "DatabaseService"):
        # A plain Session on the service's engine, with no extra handling.
        with Session(db_service.engine) as db_session:
            yield db_session

    return blank_session_getter
@pytest.fixture
def runner():
    """Return a new ``CliRunner`` instance."""
    cli_runner = CliRunner()
    return cli_runner
@ -489,7 +473,7 @@ async def flow(
flow_data = FlowCreate(name="test_flow", data=loaded_json.get("data"), user_id=active_user.id)
flow = Flow.model_validate(flow_data)
async with async_session_getter(get_db_service()) as session:
async with session_getter(get_db_service()) as session:
session.add(flow)
await session.commit()
await session.refresh(flow)
@ -600,7 +584,7 @@ async def created_api_key(active_user):
hashed_api_key=hashed,
)
db_manager = get_db_service()
async with async_session_getter(db_manager) as session:
async with session_getter(db_manager) as session:
stmt = select(ApiKey).where(ApiKey.api_key == api_key.api_key)
if existing_api_key := (await session.exec(stmt)).first():
yield existing_api_key
@ -630,7 +614,7 @@ async def get_simple_api_test(client, logged_in_headers, json_simple_api_test):
@pytest.fixture(name="starter_project")
async def get_starter_project(active_user):
# once the client is created, we can get the starter project
async with async_session_getter(get_db_service()) as session:
async with session_getter(get_db_service()) as session:
stmt = (
select(Flow)
.where(Flow.folder.has(Folder.name == STARTER_FOLDER_NAME))

View file

@ -12,7 +12,7 @@ from langflow.initial_setup.setup import load_starter_projects
from langflow.services.database.models.base import orjson_dumps
from langflow.services.database.models.flow import Flow, FlowCreate, FlowUpdate
from langflow.services.database.models.folder.model import FolderCreate
from langflow.services.database.utils import async_session_getter
from langflow.services.database.utils import session_getter
from langflow.services.deps import get_db_service
@ -530,7 +530,7 @@ async def test_download_file(
]
)
db_manager = get_db_service()
async with async_session_getter(db_manager) as _session:
async with session_getter(db_manager) as _session:
saved_flows = []
for flow in flow_list.flows:
flow.user_id = active_user.id

View file

@ -5,7 +5,7 @@ from httpx import AsyncClient
from langflow.services.auth.utils import create_super_user, get_password_hash
from langflow.services.database.models.user import UserUpdate
from langflow.services.database.models.user.model import User
from langflow.services.database.utils import async_session_getter
from langflow.services.database.utils import session_getter
from langflow.services.deps import get_db_service, get_settings_service
from sqlmodel import select
@ -14,7 +14,7 @@ from sqlmodel import select
async def super_user(client): # noqa: ARG001
settings_manager = get_settings_service()
auth_settings = settings_manager.auth_settings
async with async_session_getter(get_db_service()) as db:
async with session_getter(get_db_service()) as db:
return await create_super_user(
db=db,
username=auth_settings.SUPERUSER,
@ -42,7 +42,7 @@ async def super_user_headers(
@pytest.fixture
async def deactivated_user(client): # noqa: ARG001
async with async_session_getter(get_db_service()) as session:
async with session_getter(get_db_service()) as session:
user = User(
username="deactivateduser",
password=get_password_hash("testpassword"),
@ -61,7 +61,7 @@ async def test_user_waiting_for_approval(client):
password = "testpassword" # noqa: S105
# Debug: Check if the user already exists
async with async_session_getter(get_db_service()) as session:
async with session_getter(get_db_service()) as session:
stmt = select(User).where(User.username == username)
existing_user = (await session.exec(stmt)).first()
if existing_user:
@ -70,7 +70,7 @@ async def test_user_waiting_for_approval(client):
)
# Create a user that is not active and has never logged in
async with async_session_getter(get_db_service()) as session:
async with session_getter(get_db_service()) as session:
user = User(
username=username,
password=get_password_hash(password),
@ -86,7 +86,7 @@ async def test_user_waiting_for_approval(client):
assert response.json()["detail"] == "Waiting for approval"
# Debug: Check if the user still exists after the test
async with async_session_getter(get_db_service()) as session:
async with session_getter(get_db_service()) as session:
stmt = select(User).where(User.username == username)
existing_user = (await session.exec(stmt)).first()
if existing_user:
@ -140,7 +140,7 @@ async def test_data_consistency_after_delete(client: AsyncClient, test_user, sup
@pytest.mark.api_key_required
async def test_inactive_user(client: AsyncClient):
# Create a user that is not active and has a last_login_at value
async with async_session_getter(get_db_service()) as session:
async with session_getter(get_db_service()) as session:
user = User(
username="inactiveuser",
password=get_password_hash("testpassword"),