diff --git a/src/backend/base/langflow/initial_setup/setup.py b/src/backend/base/langflow/initial_setup/setup.py index 210995051..47f922adf 100644 --- a/src/backend/base/langflow/initial_setup/setup.py +++ b/src/backend/base/langflow/initial_setup/setup.py @@ -529,14 +529,12 @@ def load_flows_from_directory(): existing.updated_at = datetime.utcnow() existing.user_id = user_id session.add(existing) - session.commit() else: logger.info(f"Creating new flow: {flow_id} with endpoint name {flow_endpoint_name}") flow["user_id"] = user_id flow = Flow.model_validate(flow, from_attributes=True) flow.updated_at = datetime.utcnow() session.add(flow) - session.commit() def find_existing_flow(session, flow_id, flow_endpoint_name): @@ -614,7 +612,6 @@ def initialize_super_user_if_needed(): super_user = create_super_user(db=session, username=username, password=password) get_variable_service().initialize_user_variables(super_user.id, session) create_default_folder_if_it_doesnt_exist(session, super_user.id) - session.commit() logger.info("Super user initialized") diff --git a/src/backend/base/langflow/memory.py b/src/backend/base/langflow/memory.py index 7561e22c6..aeda7e9a0 100644 --- a/src/backend/base/langflow/memory.py +++ b/src/backend/base/langflow/memory.py @@ -108,7 +108,6 @@ def delete_messages(session_id: str): .where(col(MessageTable.session_id) == session_id) .execution_options(synchronize_session="fetch") ) - session.commit() def store_message( diff --git a/src/backend/base/langflow/services/database/service.py b/src/backend/base/langflow/services/database/service.py index 2befcd3c7..7003abd85 100644 --- a/src/backend/base/langflow/services/database/service.py +++ b/src/backend/base/langflow/services/database/service.py @@ -1,4 +1,5 @@ import time +from contextlib import contextmanager from datetime import datetime from pathlib import Path from typing import TYPE_CHECKING, Optional, Type @@ -97,19 +98,8 @@ class DatabaseService(Service): finally: cursor.close() - def __enter__(self): - self._session = Session(self.engine) - return self._session - - def __exit__(self, exc_type, exc_value, traceback): - if exc_type is not None: # If an exception has been raised - logger.error(f"Session rollback because of exception: {exc_type.__name__} {exc_value}") - self._session.rollback() - else: - self._session.commit() - self._session.close() - - def get_session(self): + @contextmanager + def with_session(self): with Session(self.engine) as session: yield session @@ -119,7 +109,7 @@ class DatabaseService(Service): # associated with them settings_service = get_settings_service() if settings_service.auth_settings.AUTO_LOGIN: - with Session(self.engine) as session: + with self.with_session() as session: flows = session.exec(select(models.Flow).where(models.Flow.user_id is None)).all() if flows: logger.debug("Migrating flows to default superuser") @@ -190,7 +180,7 @@ class DatabaseService(Service): alembic_cfg.set_main_option("sqlalchemy.url", self.database_url.replace("%", "%%")) should_initialize_alembic = False - with Session(self.engine) as session: + with self.with_session() as session: # If the table does not exist it throws an error # so we need to catch it try: @@ -322,11 +312,10 @@ class DatabaseService(Service): settings_service = get_settings_service() # remove the default superuser if auto_login is enabled # using the SUPERUSER to get the user - with Session(self.engine) as session: + with self.with_session() as session: teardown_superuser(settings_service, session) except Exception as exc: logger.error(f"Error tearing down database: {exc}") self.engine.dispose() - self.engine.dispose() diff --git a/src/backend/base/langflow/services/deps.py b/src/backend/base/langflow/services/deps.py index 588d0ced5..217104160 100644 --- a/src/backend/base/langflow/services/deps.py +++ b/src/backend/base/langflow/services/deps.py @@ -1,6 +1,6 @@ from contextlib import contextmanager from typing import TYPE_CHECKING, Generator - +from loguru import logger from langflow.services.schema import ServiceType if TYPE_CHECKING: @@ -162,12 +162,12 @@ def get_session() -> Generator["Session", None, None]: Session: A session object. """ - db_service = get_db_service() - yield from db_service.get_session() + with get_db_service().with_session() as session: + yield session @contextmanager -def session_scope(): +def session_scope() -> Generator["Session", None, None]: """ Context manager for managing a session scope. @@ -182,15 +182,15 @@ def session_scope(): Exception: If an error occurs during the session scope. """ - session = next(get_session()) - try: - yield session - session.commit() - except: - session.rollback() - raise - finally: - session.close() + db_service = get_db_service() + with db_service.with_session() as session: + try: + yield session + session.commit() + except Exception as e: + logger.exception("An error occurred during the session scope.", e) + session.rollback() + raise def get_cache_service() -> "CacheService": diff --git a/src/backend/base/langflow/services/utils.py b/src/backend/base/langflow/services/utils.py index 558eaf028..6c3e6e7b1 100644 --- a/src/backend/base/langflow/services/utils.py +++ b/src/backend/base/langflow/services/utils.py @@ -106,6 +106,7 @@ def teardown_superuser(settings_service, session): except Exception as exc: logger.exception(exc) + session.rollback() raise RuntimeError("Could not remove default superuser.") from exc diff --git a/src/backend/tests/unit/test_database.py b/src/backend/tests/unit/test_database.py index 44e347f1e..f3f10bed4 100644 --- a/src/backend/tests/unit/test_database.py +++ b/src/backend/tests/unit/test_database.py @@ -407,7 +407,7 @@ def test_migrate_transactions_no_duckdb(client: TestClient): def test_sqlite_pragmas(): db_service = get_db_service() - with db_service as session: + with db_service.with_session() as session: from sqlalchemy import text assert "wal" == session.execute(text("PRAGMA journal_mode;")).fetchone()[0]