diff --git a/src/backend/base/langflow/__main__.py b/src/backend/base/langflow/__main__.py index 99fa7ceab..b825d4cb0 100644 --- a/src/backend/base/langflow/__main__.py +++ b/src/backend/base/langflow/__main__.py @@ -26,7 +26,7 @@ from sqlmodel import select from langflow.logging.logger import configure, logger from langflow.main import setup_app from langflow.services.database.models.folder.utils import create_default_folder_if_it_doesnt_exist -from langflow.services.database.utils import async_session_getter +from langflow.services.database.utils import session_getter from langflow.services.deps import async_session_scope, get_db_service, get_settings_service from langflow.services.settings.constants import DEFAULT_SUPERUSER from langflow.services.utils import initialize_services @@ -423,7 +423,7 @@ def superuser( async def _create_superuser(): await initialize_services() - async with async_session_getter(db_service) as session: + async with session_getter(db_service) as session: from langflow.services.auth.utils import create_super_user if await create_super_user(db=session, username=username, password=password): @@ -485,6 +485,15 @@ def copy_db() -> None: typer.echo("Pre-release database not found in the cache directory.") +async def _migration(*, test: bool, fix: bool) -> None: + await initialize_services(fix_migration=fix) + db_service = get_db_service() + if not test: + await db_service.run_migrations() + results = db_service.run_migrations_test() + display_results(results) + + @app.command() def migration( test: bool = typer.Option(default=True, help="Run migrations in test mode."), # noqa: FBT001 @@ -499,12 +508,7 @@ def migration( ): raise typer.Abort - asyncio.run(initialize_services(fix_migration=fix)) - db_service = get_db_service() - if not test: - db_service.run_migrations() - results = db_service.run_migrations_test() - display_results(results) + asyncio.run(_migration(test=test, fix=fix)) @app.command() diff --git a/src/backend/base/langflow/graph/utils.py b/src/backend/base/langflow/graph/utils.py index b1a69efc1..bfb4fadb1 100644 --- a/src/backend/base/langflow/graph/utils.py +++ b/src/backend/base/langflow/graph/utils.py @@ -18,7 +18,7 @@ from langflow.services.database.models.transactions.crud import log_transaction from langflow.services.database.models.transactions.model import TransactionBase from langflow.services.database.models.vertex_builds.crud import log_vertex_build as crud_log_vertex_build from langflow.services.database.models.vertex_builds.model import VertexBuildBase -from langflow.services.database.utils import async_session_getter +from langflow.services.database.utils import session_getter from langflow.services.deps import get_db_service, get_settings_service if TYPE_CHECKING: @@ -157,7 +157,7 @@ async def log_transaction( error=error, flow_id=flow_id if isinstance(flow_id, UUID) else UUID(flow_id), ) - async with async_session_getter(get_db_service()) as session: + async with session_getter(get_db_service()) as session: inserted = await crud_log_transaction(session, transaction) logger.debug(f"Logged transaction: {inserted.id}") except Exception: # noqa: BLE001 @@ -186,7 +186,7 @@ async def log_vertex_build( # ugly hack to get the model dump with weird datatypes artifacts=json.loads(json.dumps(artifacts, default=str)), ) - async with async_session_getter(get_db_service()) as session: + async with session_getter(get_db_service()) as session: inserted = await crud_log_vertex_build(session, vertex_build) logger.debug(f"Logged vertex build: {inserted.build_id}") except Exception: # noqa: BLE001 diff --git a/src/backend/base/langflow/services/database/models/api_key/crud.py b/src/backend/base/langflow/services/database/models/api_key/crud.py index 7c3063357..dfc8fd6e7 100644 --- a/src/backend/base/langflow/services/database/models/api_key/crud.py +++ b/src/backend/base/langflow/services/database/models/api_key/crud.py @@ -10,7 +10,7 @@ from sqlmodel.ext.asyncio.session import AsyncSession from langflow.services.database.models import User from langflow.services.database.models.api_key import ApiKey, ApiKeyCreate, ApiKeyRead, UnmaskedApiKeyRead -from langflow.services.database.utils import async_session_getter +from langflow.services.database.utils import session_getter from langflow.services.deps import get_db_service if TYPE_CHECKING: @@ -68,7 +68,7 @@ async def check_key(session: AsyncSession, api_key: str) -> User | None: async def update_total_uses(api_key_id: UUID): """Update the total uses and last used at.""" - async with async_session_getter(get_db_service()) as session: + async with session_getter(get_db_service()) as session: new_api_key = await session.get(ApiKey, api_key_id) if new_api_key is None: msg = "API Key not found" diff --git a/src/backend/base/langflow/services/database/models/message/crud.py b/src/backend/base/langflow/services/database/models/message/crud.py deleted file mode 100644 index 840d015ba..000000000 --- a/src/backend/base/langflow/services/database/models/message/crud.py +++ /dev/null @@ -1,20 +0,0 @@ -from uuid import UUID - -from langflow.services.database.models.message.model import MessageTable, MessageUpdate -from langflow.services.deps import session_scope - - -def update_message(message_id: UUID | str, message: MessageUpdate | dict): - if not isinstance(message, MessageUpdate): - message = MessageUpdate(**message) - with session_scope() as session: - db_message = session.get(MessageTable, message_id) - if not db_message: - msg = "Message not found" - raise ValueError(msg) - message_dict = message.model_dump(exclude_unset=True, exclude_none=True) - db_message.sqlmodel_update(message_dict) - session.add(db_message) - session.commit() - session.refresh(db_message) - return db_message diff --git a/src/backend/base/langflow/services/database/service.py b/src/backend/base/langflow/services/database/service.py index 33087a750..d319c94d4 100644 --- a/src/backend/base/langflow/services/database/service.py +++ b/src/backend/base/langflow/services/database/service.py @@ -218,8 +218,8 @@ class DatabaseService(Service): return new_name - def check_schema_health(self) -> bool: - inspector = inspect(self.engine) + def _check_schema_health(self, connection) -> bool: + inspector = inspect(connection) model_mapping: dict[str, type[SQLModel]] = { "flow": models.Flow, @@ -251,6 +251,10 @@ class DatabaseService(Service): return True + async def check_schema_health(self) -> None: + async with self.with_async_session() as session, session.bind.connect() as conn: + await conn.run_sync(self._check_schema_health) + def init_alembic(self, alembic_cfg) -> None: logger.info("Initializing alembic") command.ensure_version(alembic_cfg) @@ -258,7 +262,7 @@ class DatabaseService(Service): command.upgrade(alembic_cfg, "head") logger.info("Alembic initialized") - def run_migrations(self, *, fix=False) -> None: + def _run_migrations(self, should_initialize_alembic, fix) -> None: # First we need to check if alembic has been initialized # If not, we need to initialize it # if not self.script_location.exists(): # this is not the correct way to check if alembic has been initialized @@ -274,16 +278,6 @@ class DatabaseService(Service): alembic_cfg.set_main_option("script_location", str(self.script_location)) alembic_cfg.set_main_option("sqlalchemy.url", self.database_url.replace("%", "%%")) - should_initialize_alembic = False - with self.with_session() as session: - # If the table does not exist it throws an error - # so we need to catch it - try: - session.exec(text("SELECT * FROM alembic_version")) - except Exception: # noqa: BLE001 - logger.debug("Alembic not initialized") - should_initialize_alembic = True - if should_initialize_alembic: try: self.init_alembic(alembic_cfg) @@ -317,6 +311,18 @@ class DatabaseService(Service): if fix: self.try_downgrade_upgrade_until_success(alembic_cfg) + async def run_migrations(self, *, fix=False) -> None: + should_initialize_alembic = False + async with self.with_async_session() as session: + # If the table does not exist it throws an error + # so we need to catch it + try: + await session.exec(text("SELECT * FROM alembic_version")) + except Exception: # noqa: BLE001 + logger.debug("Alembic not initialized") + should_initialize_alembic = True + await asyncio.to_thread(self._run_migrations, should_initialize_alembic, fix) + def try_downgrade_upgrade_until_success(self, alembic_cfg, retries=5) -> None: # Try -1 then head, if it fails, try -2 then head, etc. # until we reach the number of retries @@ -363,10 +369,11 @@ class DatabaseService(Service): results.append(Result(name=column, type="column", success=True)) return results - def create_db_and_tables(self) -> None: + @staticmethod + def _create_db_and_tables(connection) -> None: from sqlalchemy import inspect - inspector = inspect(self.engine) + inspector = inspect(connection) table_names = inspector.get_table_names() current_tables = ["flow", "user", "apikey", "folder", "message", "variable", "transaction", "vertex_build"] @@ -378,7 +385,7 @@ class DatabaseService(Service): for table in SQLModel.metadata.sorted_tables: try: - table.create(self.engine, checkfirst=True) + table.create(connection, checkfirst=True) except OperationalError as oe: logger.warning(f"Table {table} already exists, skipping. Exception: {oe}") except Exception as exc: @@ -387,7 +394,7 @@ class DatabaseService(Service): raise RuntimeError(msg) from exc # Now check if the required tables exist, if not, something went wrong. - inspector = inspect(self.engine) + inspector = inspect(connection) table_names = inspector.get_table_names() for table in current_tables: if table not in table_names: @@ -398,6 +405,10 @@ class DatabaseService(Service): logger.debug("Database and tables created successfully") + async def create_db_and_tables(self) -> None: + async with self.with_async_session() as session, session.bind.connect() as conn: + await conn.run_sync(self._create_db_and_tables) + async def teardown(self) -> None: logger.debug("Tearing down database") try: diff --git a/src/backend/base/langflow/services/database/utils.py b/src/backend/base/langflow/services/database/utils.py index 97c800f08..f5f2201aa 100644 --- a/src/backend/base/langflow/services/database/utils.py +++ b/src/backend/base/langflow/services/database/utils.py @@ -1,25 +1,25 @@ from __future__ import annotations -from contextlib import asynccontextmanager, contextmanager +from contextlib import asynccontextmanager from dataclasses import dataclass from typing import TYPE_CHECKING from alembic.util.exc import CommandError from loguru import logger -from sqlmodel import Session, text +from sqlmodel import text from sqlmodel.ext.asyncio.session import AsyncSession if TYPE_CHECKING: from langflow.services.database.service import DatabaseService -def initialize_database(*, fix_migration: bool = False) -> None: +async def initialize_database(*, fix_migration: bool = False) -> None: logger.debug("Initializing database") from langflow.services.deps import get_db_service database_service: DatabaseService = get_db_service() try: - database_service.create_db_and_tables() + await database_service.create_db_and_tables() except Exception as exc: # if the exception involves tables already existing # we can ignore it @@ -28,13 +28,13 @@ def initialize_database(*, fix_migration: bool = False) -> None: logger.exception(msg) raise RuntimeError(msg) from exc try: - database_service.check_schema_health() + await database_service.check_schema_health() except Exception as exc: msg = "Error checking schema health" logger.exception(msg) raise RuntimeError(msg) from exc try: - database_service.run_migrations(fix=fix_migration) + await database_service.run_migrations(fix=fix_migration) except CommandError as exc: # if "overlaps with other requested revisions" or "Can't locate revision identified by" # are not in the exception, we can't handle it @@ -46,9 +46,9 @@ def initialize_database(*, fix_migration: bool = False) -> None: # We need to delete the alembic_version table # and run the migrations again logger.warning("Wrong revision in DB, deleting alembic_version table and running migrations again") - with session_getter(database_service) as session: - session.exec(text("DROP TABLE alembic_version")) - database_service.run_migrations(fix=fix_migration) + async with session_getter(database_service) as session: + await session.exec(text("DROP TABLE alembic_version")) + await database_service.run_migrations(fix=fix_migration) except Exception as exc: # if the exception involves tables already existing # we can ignore it @@ -58,21 +58,8 @@ def initialize_database(*, fix_migration: bool = False) -> None: logger.debug("Database initialized") -@contextmanager -def session_getter(db_service: DatabaseService): - try: - session = Session(db_service.engine) - yield session - except Exception: - logger.exception("Session rollback because of exception") - session.rollback() - raise - finally: - session.close() - - @asynccontextmanager -async def async_session_getter(db_service: DatabaseService): +async def session_getter(db_service: DatabaseService): try: session = AsyncSession(db_service.async_engine, expire_on_commit=False) yield session diff --git a/src/backend/base/langflow/services/utils.py b/src/backend/base/langflow/services/utils.py index 974c77795..50bf6a376 100644 --- a/src/backend/base/langflow/services/utils.py +++ b/src/backend/base/langflow/services/utils.py @@ -237,7 +237,7 @@ async def initialize_services(*, fix_migration: bool = False) -> None: # Test cache connection get_service(ServiceType.CACHE_SERVICE, default=CacheServiceFactory()) # Setup the superuser - await asyncio.to_thread(initialize_database, fix_migration=fix_migration) + await initialize_database(fix_migration=fix_migration) async with get_db_service().with_async_session() as session: settings_service = get_service(ServiceType.SETTINGS_SERVICE) await setup_superuser(settings_service, session) diff --git a/src/backend/tests/conftest.py b/src/backend/tests/conftest.py index 391a86999..ede70343c 100644 --- a/src/backend/tests/conftest.py +++ b/src/backend/tests/conftest.py @@ -5,9 +5,8 @@ import shutil # we need to import tmpdir import tempfile from collections.abc import AsyncGenerator -from contextlib import contextmanager, suppress +from contextlib import suppress from pathlib import Path -from typing import TYPE_CHECKING from uuid import UUID import anyio @@ -27,7 +26,7 @@ from langflow.services.database.models.folder.model import Folder from langflow.services.database.models.transactions.model import TransactionTable from langflow.services.database.models.user.model import User, UserCreate, UserRead from langflow.services.database.models.vertex_builds.crud import delete_vertex_builds_by_flow_id -from langflow.services.database.utils import async_session_getter +from langflow.services.database.utils import session_getter from langflow.services.deps import get_db_service from loguru import logger from sqlalchemy.ext.asyncio import create_async_engine @@ -39,10 +38,6 @@ from typer.testing import CliRunner from tests.api_keys import get_openai_api_key -if TYPE_CHECKING: - from langflow.services.database.service import DatabaseService - - load_dotenv() @@ -369,17 +364,6 @@ async def client_fixture( await anyio.Path(db_path).unlink() -# create a fixture for session_getter above -@pytest.fixture(name="session_getter") -def session_getter_fixture(client): # noqa: ARG001 - @contextmanager - def blank_session_getter(db_service: "DatabaseService"): - with Session(db_service.engine) as session: - yield session - - return blank_session_getter - - @pytest.fixture def runner(): return CliRunner() @@ -489,7 +473,7 @@ async def flow( flow_data = FlowCreate(name="test_flow", data=loaded_json.get("data"), user_id=active_user.id) flow = Flow.model_validate(flow_data) - async with async_session_getter(get_db_service()) as session: + async with session_getter(get_db_service()) as session: session.add(flow) await session.commit() await session.refresh(flow) @@ -600,7 +584,7 @@ async def created_api_key(active_user): hashed_api_key=hashed, ) db_manager = get_db_service() - async with async_session_getter(db_manager) as session: + async with session_getter(db_manager) as session: stmt = select(ApiKey).where(ApiKey.api_key == api_key.api_key) if existing_api_key := (await session.exec(stmt)).first(): yield existing_api_key @@ -630,7 +614,7 @@ async def get_simple_api_test(client, logged_in_headers, json_simple_api_test): @pytest.fixture(name="starter_project") async def get_starter_project(active_user): # once the client is created, we can get the starter project - async with async_session_getter(get_db_service()) as session: + async with session_getter(get_db_service()) as session: stmt = ( select(Flow) .where(Flow.folder.has(Folder.name == STARTER_FOLDER_NAME)) diff --git a/src/backend/tests/unit/test_database.py b/src/backend/tests/unit/test_database.py index f32cf2edb..d228d3e05 100644 --- a/src/backend/tests/unit/test_database.py +++ b/src/backend/tests/unit/test_database.py @@ -12,7 +12,7 @@ from langflow.initial_setup.setup import load_starter_projects from langflow.services.database.models.base import orjson_dumps from langflow.services.database.models.flow import Flow, FlowCreate, FlowUpdate from langflow.services.database.models.folder.model import FolderCreate -from langflow.services.database.utils import async_session_getter +from langflow.services.database.utils import session_getter from langflow.services.deps import get_db_service @@ -530,7 +530,7 @@ async def test_download_file( ] ) db_manager = get_db_service() - async with async_session_getter(db_manager) as _session: + async with session_getter(db_manager) as _session: saved_flows = [] for flow in flow_list.flows: flow.user_id = active_user.id diff --git a/src/backend/tests/unit/test_user.py b/src/backend/tests/unit/test_user.py index d2122109c..c9eb9cfd9 100644 --- a/src/backend/tests/unit/test_user.py +++ b/src/backend/tests/unit/test_user.py @@ -5,7 +5,7 @@ from httpx import AsyncClient from langflow.services.auth.utils import create_super_user, get_password_hash from langflow.services.database.models.user import UserUpdate from langflow.services.database.models.user.model import User -from langflow.services.database.utils import async_session_getter +from langflow.services.database.utils import session_getter from langflow.services.deps import get_db_service, get_settings_service from sqlmodel import select @@ -14,7 +14,7 @@ from sqlmodel import select async def super_user(client): # noqa: ARG001 settings_manager = get_settings_service() auth_settings = settings_manager.auth_settings - async with async_session_getter(get_db_service()) as db: + async with session_getter(get_db_service()) as db: return await create_super_user( db=db, username=auth_settings.SUPERUSER, @@ -42,7 +42,7 @@ async def super_user_headers( @pytest.fixture async def deactivated_user(client): # noqa: ARG001 - async with async_session_getter(get_db_service()) as session: + async with session_getter(get_db_service()) as session: user = User( username="deactivateduser", password=get_password_hash("testpassword"), @@ -61,7 +61,7 @@ async def test_user_waiting_for_approval(client): password = "testpassword" # noqa: S105 # Debug: Check if the user already exists - async with async_session_getter(get_db_service()) as session: + async with session_getter(get_db_service()) as session: stmt = select(User).where(User.username == username) existing_user = (await session.exec(stmt)).first() if existing_user: @@ -70,7 +70,7 @@ async def test_user_waiting_for_approval(client): ) # Create a user that is not active and has never logged in - async with async_session_getter(get_db_service()) as session: + async with session_getter(get_db_service()) as session: user = User( username=username, password=get_password_hash(password), @@ -86,7 +86,7 @@ async def test_user_waiting_for_approval(client): assert response.json()["detail"] == "Waiting for approval" # Debug: Check if the user still exists after the test - async with async_session_getter(get_db_service()) as session: + async with session_getter(get_db_service()) as session: stmt = select(User).where(User.username == username) existing_user = (await session.exec(stmt)).first() if existing_user: @@ -140,7 +140,7 @@ async def test_data_consistency_after_delete(client: AsyncClient, test_user, sup @pytest.mark.api_key_required async def test_inactive_user(client: AsyncClient): # Create a user that is not active and has a last_login_at value - async with async_session_getter(get_db_service()) as session: + async with session_getter(get_db_service()) as session: user = User( username="inactiveuser", password=get_password_hash("testpassword"),