ref: Make initialize_database async (#5163)
Make initialize_database async
This commit is contained in:
parent
e545d12c40
commit
63bdcb9d03
10 changed files with 70 additions and 104 deletions
|
|
@ -26,7 +26,7 @@ from sqlmodel import select
|
|||
from langflow.logging.logger import configure, logger
|
||||
from langflow.main import setup_app
|
||||
from langflow.services.database.models.folder.utils import create_default_folder_if_it_doesnt_exist
|
||||
from langflow.services.database.utils import async_session_getter
|
||||
from langflow.services.database.utils import session_getter
|
||||
from langflow.services.deps import async_session_scope, get_db_service, get_settings_service
|
||||
from langflow.services.settings.constants import DEFAULT_SUPERUSER
|
||||
from langflow.services.utils import initialize_services
|
||||
|
|
@ -423,7 +423,7 @@ def superuser(
|
|||
|
||||
async def _create_superuser():
|
||||
await initialize_services()
|
||||
async with async_session_getter(db_service) as session:
|
||||
async with session_getter(db_service) as session:
|
||||
from langflow.services.auth.utils import create_super_user
|
||||
|
||||
if await create_super_user(db=session, username=username, password=password):
|
||||
|
|
@ -485,6 +485,15 @@ def copy_db() -> None:
|
|||
typer.echo("Pre-release database not found in the cache directory.")
|
||||
|
||||
|
||||
async def _migration(*, test: bool, fix: bool) -> None:
|
||||
await initialize_services(fix_migration=fix)
|
||||
db_service = get_db_service()
|
||||
if not test:
|
||||
await db_service.run_migrations()
|
||||
results = db_service.run_migrations_test()
|
||||
display_results(results)
|
||||
|
||||
|
||||
@app.command()
|
||||
def migration(
|
||||
test: bool = typer.Option(default=True, help="Run migrations in test mode."), # noqa: FBT001
|
||||
|
|
@ -499,12 +508,7 @@ def migration(
|
|||
):
|
||||
raise typer.Abort
|
||||
|
||||
asyncio.run(initialize_services(fix_migration=fix))
|
||||
db_service = get_db_service()
|
||||
if not test:
|
||||
db_service.run_migrations()
|
||||
results = db_service.run_migrations_test()
|
||||
display_results(results)
|
||||
asyncio.run(_migration(test=test, fix=fix))
|
||||
|
||||
|
||||
@app.command()
|
||||
|
|
|
|||
|
|
@ -18,7 +18,7 @@ from langflow.services.database.models.transactions.crud import log_transaction
|
|||
from langflow.services.database.models.transactions.model import TransactionBase
|
||||
from langflow.services.database.models.vertex_builds.crud import log_vertex_build as crud_log_vertex_build
|
||||
from langflow.services.database.models.vertex_builds.model import VertexBuildBase
|
||||
from langflow.services.database.utils import async_session_getter
|
||||
from langflow.services.database.utils import session_getter
|
||||
from langflow.services.deps import get_db_service, get_settings_service
|
||||
|
||||
if TYPE_CHECKING:
|
||||
|
|
@ -157,7 +157,7 @@ async def log_transaction(
|
|||
error=error,
|
||||
flow_id=flow_id if isinstance(flow_id, UUID) else UUID(flow_id),
|
||||
)
|
||||
async with async_session_getter(get_db_service()) as session:
|
||||
async with session_getter(get_db_service()) as session:
|
||||
inserted = await crud_log_transaction(session, transaction)
|
||||
logger.debug(f"Logged transaction: {inserted.id}")
|
||||
except Exception: # noqa: BLE001
|
||||
|
|
@ -186,7 +186,7 @@ async def log_vertex_build(
|
|||
# ugly hack to get the model dump with weird datatypes
|
||||
artifacts=json.loads(json.dumps(artifacts, default=str)),
|
||||
)
|
||||
async with async_session_getter(get_db_service()) as session:
|
||||
async with session_getter(get_db_service()) as session:
|
||||
inserted = await crud_log_vertex_build(session, vertex_build)
|
||||
logger.debug(f"Logged vertex build: {inserted.build_id}")
|
||||
except Exception: # noqa: BLE001
|
||||
|
|
|
|||
|
|
@ -10,7 +10,7 @@ from sqlmodel.ext.asyncio.session import AsyncSession
|
|||
|
||||
from langflow.services.database.models import User
|
||||
from langflow.services.database.models.api_key import ApiKey, ApiKeyCreate, ApiKeyRead, UnmaskedApiKeyRead
|
||||
from langflow.services.database.utils import async_session_getter
|
||||
from langflow.services.database.utils import session_getter
|
||||
from langflow.services.deps import get_db_service
|
||||
|
||||
if TYPE_CHECKING:
|
||||
|
|
@ -68,7 +68,7 @@ async def check_key(session: AsyncSession, api_key: str) -> User | None:
|
|||
|
||||
async def update_total_uses(api_key_id: UUID):
|
||||
"""Update the total uses and last used at."""
|
||||
async with async_session_getter(get_db_service()) as session:
|
||||
async with session_getter(get_db_service()) as session:
|
||||
new_api_key = await session.get(ApiKey, api_key_id)
|
||||
if new_api_key is None:
|
||||
msg = "API Key not found"
|
||||
|
|
|
|||
|
|
@ -1,20 +0,0 @@
|
|||
from uuid import UUID
|
||||
|
||||
from langflow.services.database.models.message.model import MessageTable, MessageUpdate
|
||||
from langflow.services.deps import session_scope
|
||||
|
||||
|
||||
def update_message(message_id: UUID | str, message: MessageUpdate | dict):
|
||||
if not isinstance(message, MessageUpdate):
|
||||
message = MessageUpdate(**message)
|
||||
with session_scope() as session:
|
||||
db_message = session.get(MessageTable, message_id)
|
||||
if not db_message:
|
||||
msg = "Message not found"
|
||||
raise ValueError(msg)
|
||||
message_dict = message.model_dump(exclude_unset=True, exclude_none=True)
|
||||
db_message.sqlmodel_update(message_dict)
|
||||
session.add(db_message)
|
||||
session.commit()
|
||||
session.refresh(db_message)
|
||||
return db_message
|
||||
|
|
@ -218,8 +218,8 @@ class DatabaseService(Service):
|
|||
|
||||
return new_name
|
||||
|
||||
def check_schema_health(self) -> bool:
|
||||
inspector = inspect(self.engine)
|
||||
def _check_schema_health(self, connection) -> bool:
|
||||
inspector = inspect(connection)
|
||||
|
||||
model_mapping: dict[str, type[SQLModel]] = {
|
||||
"flow": models.Flow,
|
||||
|
|
@ -251,6 +251,10 @@ class DatabaseService(Service):
|
|||
|
||||
return True
|
||||
|
||||
async def check_schema_health(self) -> None:
|
||||
async with self.with_async_session() as session, session.bind.connect() as conn:
|
||||
await conn.run_sync(self._check_schema_health)
|
||||
|
||||
def init_alembic(self, alembic_cfg) -> None:
|
||||
logger.info("Initializing alembic")
|
||||
command.ensure_version(alembic_cfg)
|
||||
|
|
@ -258,7 +262,7 @@ class DatabaseService(Service):
|
|||
command.upgrade(alembic_cfg, "head")
|
||||
logger.info("Alembic initialized")
|
||||
|
||||
def run_migrations(self, *, fix=False) -> None:
|
||||
def _run_migrations(self, should_initialize_alembic, fix) -> None:
|
||||
# First we need to check if alembic has been initialized
|
||||
# If not, we need to initialize it
|
||||
# if not self.script_location.exists(): # this is not the correct way to check if alembic has been initialized
|
||||
|
|
@ -274,16 +278,6 @@ class DatabaseService(Service):
|
|||
alembic_cfg.set_main_option("script_location", str(self.script_location))
|
||||
alembic_cfg.set_main_option("sqlalchemy.url", self.database_url.replace("%", "%%"))
|
||||
|
||||
should_initialize_alembic = False
|
||||
with self.with_session() as session:
|
||||
# If the table does not exist it throws an error
|
||||
# so we need to catch it
|
||||
try:
|
||||
session.exec(text("SELECT * FROM alembic_version"))
|
||||
except Exception: # noqa: BLE001
|
||||
logger.debug("Alembic not initialized")
|
||||
should_initialize_alembic = True
|
||||
|
||||
if should_initialize_alembic:
|
||||
try:
|
||||
self.init_alembic(alembic_cfg)
|
||||
|
|
@ -317,6 +311,18 @@ class DatabaseService(Service):
|
|||
if fix:
|
||||
self.try_downgrade_upgrade_until_success(alembic_cfg)
|
||||
|
||||
async def run_migrations(self, *, fix=False) -> None:
|
||||
should_initialize_alembic = False
|
||||
async with self.with_async_session() as session:
|
||||
# If the table does not exist it throws an error
|
||||
# so we need to catch it
|
||||
try:
|
||||
await session.exec(text("SELECT * FROM alembic_version"))
|
||||
except Exception: # noqa: BLE001
|
||||
logger.debug("Alembic not initialized")
|
||||
should_initialize_alembic = True
|
||||
await asyncio.to_thread(self._run_migrations, should_initialize_alembic, fix)
|
||||
|
||||
def try_downgrade_upgrade_until_success(self, alembic_cfg, retries=5) -> None:
|
||||
# Try -1 then head, if it fails, try -2 then head, etc.
|
||||
# until we reach the number of retries
|
||||
|
|
@ -363,10 +369,11 @@ class DatabaseService(Service):
|
|||
results.append(Result(name=column, type="column", success=True))
|
||||
return results
|
||||
|
||||
def create_db_and_tables(self) -> None:
|
||||
@staticmethod
|
||||
def _create_db_and_tables(connection) -> None:
|
||||
from sqlalchemy import inspect
|
||||
|
||||
inspector = inspect(self.engine)
|
||||
inspector = inspect(connection)
|
||||
table_names = inspector.get_table_names()
|
||||
current_tables = ["flow", "user", "apikey", "folder", "message", "variable", "transaction", "vertex_build"]
|
||||
|
||||
|
|
@ -378,7 +385,7 @@ class DatabaseService(Service):
|
|||
|
||||
for table in SQLModel.metadata.sorted_tables:
|
||||
try:
|
||||
table.create(self.engine, checkfirst=True)
|
||||
table.create(connection, checkfirst=True)
|
||||
except OperationalError as oe:
|
||||
logger.warning(f"Table {table} already exists, skipping. Exception: {oe}")
|
||||
except Exception as exc:
|
||||
|
|
@ -387,7 +394,7 @@ class DatabaseService(Service):
|
|||
raise RuntimeError(msg) from exc
|
||||
|
||||
# Now check if the required tables exist, if not, something went wrong.
|
||||
inspector = inspect(self.engine)
|
||||
inspector = inspect(connection)
|
||||
table_names = inspector.get_table_names()
|
||||
for table in current_tables:
|
||||
if table not in table_names:
|
||||
|
|
@ -398,6 +405,10 @@ class DatabaseService(Service):
|
|||
|
||||
logger.debug("Database and tables created successfully")
|
||||
|
||||
async def create_db_and_tables(self) -> None:
|
||||
async with self.with_async_session() as session, session.bind.connect() as conn:
|
||||
await conn.run_sync(self._create_db_and_tables)
|
||||
|
||||
async def teardown(self) -> None:
|
||||
logger.debug("Tearing down database")
|
||||
try:
|
||||
|
|
|
|||
|
|
@ -1,25 +1,25 @@
|
|||
from __future__ import annotations
|
||||
|
||||
from contextlib import asynccontextmanager, contextmanager
|
||||
from contextlib import asynccontextmanager
|
||||
from dataclasses import dataclass
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
from alembic.util.exc import CommandError
|
||||
from loguru import logger
|
||||
from sqlmodel import Session, text
|
||||
from sqlmodel import text
|
||||
from sqlmodel.ext.asyncio.session import AsyncSession
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from langflow.services.database.service import DatabaseService
|
||||
|
||||
|
||||
def initialize_database(*, fix_migration: bool = False) -> None:
|
||||
async def initialize_database(*, fix_migration: bool = False) -> None:
|
||||
logger.debug("Initializing database")
|
||||
from langflow.services.deps import get_db_service
|
||||
|
||||
database_service: DatabaseService = get_db_service()
|
||||
try:
|
||||
database_service.create_db_and_tables()
|
||||
await database_service.create_db_and_tables()
|
||||
except Exception as exc:
|
||||
# if the exception involves tables already existing
|
||||
# we can ignore it
|
||||
|
|
@ -28,13 +28,13 @@ def initialize_database(*, fix_migration: bool = False) -> None:
|
|||
logger.exception(msg)
|
||||
raise RuntimeError(msg) from exc
|
||||
try:
|
||||
database_service.check_schema_health()
|
||||
await database_service.check_schema_health()
|
||||
except Exception as exc:
|
||||
msg = "Error checking schema health"
|
||||
logger.exception(msg)
|
||||
raise RuntimeError(msg) from exc
|
||||
try:
|
||||
database_service.run_migrations(fix=fix_migration)
|
||||
await database_service.run_migrations(fix=fix_migration)
|
||||
except CommandError as exc:
|
||||
# if "overlaps with other requested revisions" or "Can't locate revision identified by"
|
||||
# are not in the exception, we can't handle it
|
||||
|
|
@ -46,9 +46,9 @@ def initialize_database(*, fix_migration: bool = False) -> None:
|
|||
# We need to delete the alembic_version table
|
||||
# and run the migrations again
|
||||
logger.warning("Wrong revision in DB, deleting alembic_version table and running migrations again")
|
||||
with session_getter(database_service) as session:
|
||||
session.exec(text("DROP TABLE alembic_version"))
|
||||
database_service.run_migrations(fix=fix_migration)
|
||||
async with session_getter(database_service) as session:
|
||||
await session.exec(text("DROP TABLE alembic_version"))
|
||||
await database_service.run_migrations(fix=fix_migration)
|
||||
except Exception as exc:
|
||||
# if the exception involves tables already existing
|
||||
# we can ignore it
|
||||
|
|
@ -58,21 +58,8 @@ def initialize_database(*, fix_migration: bool = False) -> None:
|
|||
logger.debug("Database initialized")
|
||||
|
||||
|
||||
@contextmanager
|
||||
def session_getter(db_service: DatabaseService):
|
||||
try:
|
||||
session = Session(db_service.engine)
|
||||
yield session
|
||||
except Exception:
|
||||
logger.exception("Session rollback because of exception")
|
||||
session.rollback()
|
||||
raise
|
||||
finally:
|
||||
session.close()
|
||||
|
||||
|
||||
@asynccontextmanager
|
||||
async def async_session_getter(db_service: DatabaseService):
|
||||
async def session_getter(db_service: DatabaseService):
|
||||
try:
|
||||
session = AsyncSession(db_service.async_engine, expire_on_commit=False)
|
||||
yield session
|
||||
|
|
|
|||
|
|
@ -237,7 +237,7 @@ async def initialize_services(*, fix_migration: bool = False) -> None:
|
|||
# Test cache connection
|
||||
get_service(ServiceType.CACHE_SERVICE, default=CacheServiceFactory())
|
||||
# Setup the superuser
|
||||
await asyncio.to_thread(initialize_database, fix_migration=fix_migration)
|
||||
await initialize_database(fix_migration=fix_migration)
|
||||
async with get_db_service().with_async_session() as session:
|
||||
settings_service = get_service(ServiceType.SETTINGS_SERVICE)
|
||||
await setup_superuser(settings_service, session)
|
||||
|
|
|
|||
|
|
@ -5,9 +5,8 @@ import shutil
|
|||
# we need to import tmpdir
|
||||
import tempfile
|
||||
from collections.abc import AsyncGenerator
|
||||
from contextlib import contextmanager, suppress
|
||||
from contextlib import suppress
|
||||
from pathlib import Path
|
||||
from typing import TYPE_CHECKING
|
||||
from uuid import UUID
|
||||
|
||||
import anyio
|
||||
|
|
@ -27,7 +26,7 @@ from langflow.services.database.models.folder.model import Folder
|
|||
from langflow.services.database.models.transactions.model import TransactionTable
|
||||
from langflow.services.database.models.user.model import User, UserCreate, UserRead
|
||||
from langflow.services.database.models.vertex_builds.crud import delete_vertex_builds_by_flow_id
|
||||
from langflow.services.database.utils import async_session_getter
|
||||
from langflow.services.database.utils import session_getter
|
||||
from langflow.services.deps import get_db_service
|
||||
from loguru import logger
|
||||
from sqlalchemy.ext.asyncio import create_async_engine
|
||||
|
|
@ -39,10 +38,6 @@ from typer.testing import CliRunner
|
|||
|
||||
from tests.api_keys import get_openai_api_key
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from langflow.services.database.service import DatabaseService
|
||||
|
||||
|
||||
load_dotenv()
|
||||
|
||||
|
||||
|
|
@ -369,17 +364,6 @@ async def client_fixture(
|
|||
await anyio.Path(db_path).unlink()
|
||||
|
||||
|
||||
# create a fixture for session_getter above
|
||||
@pytest.fixture(name="session_getter")
|
||||
def session_getter_fixture(client): # noqa: ARG001
|
||||
@contextmanager
|
||||
def blank_session_getter(db_service: "DatabaseService"):
|
||||
with Session(db_service.engine) as session:
|
||||
yield session
|
||||
|
||||
return blank_session_getter
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def runner():
|
||||
return CliRunner()
|
||||
|
|
@ -489,7 +473,7 @@ async def flow(
|
|||
flow_data = FlowCreate(name="test_flow", data=loaded_json.get("data"), user_id=active_user.id)
|
||||
|
||||
flow = Flow.model_validate(flow_data)
|
||||
async with async_session_getter(get_db_service()) as session:
|
||||
async with session_getter(get_db_service()) as session:
|
||||
session.add(flow)
|
||||
await session.commit()
|
||||
await session.refresh(flow)
|
||||
|
|
@ -600,7 +584,7 @@ async def created_api_key(active_user):
|
|||
hashed_api_key=hashed,
|
||||
)
|
||||
db_manager = get_db_service()
|
||||
async with async_session_getter(db_manager) as session:
|
||||
async with session_getter(db_manager) as session:
|
||||
stmt = select(ApiKey).where(ApiKey.api_key == api_key.api_key)
|
||||
if existing_api_key := (await session.exec(stmt)).first():
|
||||
yield existing_api_key
|
||||
|
|
@ -630,7 +614,7 @@ async def get_simple_api_test(client, logged_in_headers, json_simple_api_test):
|
|||
@pytest.fixture(name="starter_project")
|
||||
async def get_starter_project(active_user):
|
||||
# once the client is created, we can get the starter project
|
||||
async with async_session_getter(get_db_service()) as session:
|
||||
async with session_getter(get_db_service()) as session:
|
||||
stmt = (
|
||||
select(Flow)
|
||||
.where(Flow.folder.has(Folder.name == STARTER_FOLDER_NAME))
|
||||
|
|
|
|||
|
|
@ -12,7 +12,7 @@ from langflow.initial_setup.setup import load_starter_projects
|
|||
from langflow.services.database.models.base import orjson_dumps
|
||||
from langflow.services.database.models.flow import Flow, FlowCreate, FlowUpdate
|
||||
from langflow.services.database.models.folder.model import FolderCreate
|
||||
from langflow.services.database.utils import async_session_getter
|
||||
from langflow.services.database.utils import session_getter
|
||||
from langflow.services.deps import get_db_service
|
||||
|
||||
|
||||
|
|
@ -530,7 +530,7 @@ async def test_download_file(
|
|||
]
|
||||
)
|
||||
db_manager = get_db_service()
|
||||
async with async_session_getter(db_manager) as _session:
|
||||
async with session_getter(db_manager) as _session:
|
||||
saved_flows = []
|
||||
for flow in flow_list.flows:
|
||||
flow.user_id = active_user.id
|
||||
|
|
|
|||
|
|
@ -5,7 +5,7 @@ from httpx import AsyncClient
|
|||
from langflow.services.auth.utils import create_super_user, get_password_hash
|
||||
from langflow.services.database.models.user import UserUpdate
|
||||
from langflow.services.database.models.user.model import User
|
||||
from langflow.services.database.utils import async_session_getter
|
||||
from langflow.services.database.utils import session_getter
|
||||
from langflow.services.deps import get_db_service, get_settings_service
|
||||
from sqlmodel import select
|
||||
|
||||
|
|
@ -14,7 +14,7 @@ from sqlmodel import select
|
|||
async def super_user(client): # noqa: ARG001
|
||||
settings_manager = get_settings_service()
|
||||
auth_settings = settings_manager.auth_settings
|
||||
async with async_session_getter(get_db_service()) as db:
|
||||
async with session_getter(get_db_service()) as db:
|
||||
return await create_super_user(
|
||||
db=db,
|
||||
username=auth_settings.SUPERUSER,
|
||||
|
|
@ -42,7 +42,7 @@ async def super_user_headers(
|
|||
|
||||
@pytest.fixture
|
||||
async def deactivated_user(client): # noqa: ARG001
|
||||
async with async_session_getter(get_db_service()) as session:
|
||||
async with session_getter(get_db_service()) as session:
|
||||
user = User(
|
||||
username="deactivateduser",
|
||||
password=get_password_hash("testpassword"),
|
||||
|
|
@ -61,7 +61,7 @@ async def test_user_waiting_for_approval(client):
|
|||
password = "testpassword" # noqa: S105
|
||||
|
||||
# Debug: Check if the user already exists
|
||||
async with async_session_getter(get_db_service()) as session:
|
||||
async with session_getter(get_db_service()) as session:
|
||||
stmt = select(User).where(User.username == username)
|
||||
existing_user = (await session.exec(stmt)).first()
|
||||
if existing_user:
|
||||
|
|
@ -70,7 +70,7 @@ async def test_user_waiting_for_approval(client):
|
|||
)
|
||||
|
||||
# Create a user that is not active and has never logged in
|
||||
async with async_session_getter(get_db_service()) as session:
|
||||
async with session_getter(get_db_service()) as session:
|
||||
user = User(
|
||||
username=username,
|
||||
password=get_password_hash(password),
|
||||
|
|
@ -86,7 +86,7 @@ async def test_user_waiting_for_approval(client):
|
|||
assert response.json()["detail"] == "Waiting for approval"
|
||||
|
||||
# Debug: Check if the user still exists after the test
|
||||
async with async_session_getter(get_db_service()) as session:
|
||||
async with session_getter(get_db_service()) as session:
|
||||
stmt = select(User).where(User.username == username)
|
||||
existing_user = (await session.exec(stmt)).first()
|
||||
if existing_user:
|
||||
|
|
@ -140,7 +140,7 @@ async def test_data_consistency_after_delete(client: AsyncClient, test_user, sup
|
|||
@pytest.mark.api_key_required
|
||||
async def test_inactive_user(client: AsyncClient):
|
||||
# Create a user that is not active and has a last_login_at value
|
||||
async with async_session_getter(get_db_service()) as session:
|
||||
async with session_getter(get_db_service()) as session:
|
||||
user = User(
|
||||
username="inactiveuser",
|
||||
password=get_password_hash("testpassword"),
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue