chore: improve db session access (#3138)

* chore: improve db session access

* chore: improve db session access

* fix

* [autofix.ci] apply automated fixes

* Refactor session management in test_sqlite_pragmas to use with_session method

---------

Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com>
Co-authored-by: Gabriel Luiz Freitas Almeida <gabriel@langflow.org>
This commit is contained in:
Nicolò Boschi 2024-08-30 23:03:29 +02:00 committed by GitHub
commit 1c87a804bf
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
6 changed files with 21 additions and 35 deletions

View file

@ -529,14 +529,12 @@ def load_flows_from_directory():
existing.updated_at = datetime.utcnow()
existing.user_id = user_id
session.add(existing)
session.commit()
else:
logger.info(f"Creating new flow: {flow_id} with endpoint name {flow_endpoint_name}")
flow["user_id"] = user_id
flow = Flow.model_validate(flow, from_attributes=True)
flow.updated_at = datetime.utcnow()
session.add(flow)
session.commit()
def find_existing_flow(session, flow_id, flow_endpoint_name):
@ -614,7 +612,6 @@ def initialize_super_user_if_needed():
super_user = create_super_user(db=session, username=username, password=password)
get_variable_service().initialize_user_variables(super_user.id, session)
create_default_folder_if_it_doesnt_exist(session, super_user.id)
session.commit()
logger.info("Super user initialized")

View file

@ -108,7 +108,6 @@ def delete_messages(session_id: str):
.where(col(MessageTable.session_id) == session_id)
.execution_options(synchronize_session="fetch")
)
session.commit()
def store_message(

View file

@ -1,4 +1,5 @@
import time
from contextlib import contextmanager
from datetime import datetime
from pathlib import Path
from typing import TYPE_CHECKING, Optional, Type
@ -97,19 +98,8 @@ class DatabaseService(Service):
finally:
cursor.close()
def __enter__(self):
self._session = Session(self.engine)
return self._session
def __exit__(self, exc_type, exc_value, traceback):
if exc_type is not None: # If an exception has been raised
logger.error(f"Session rollback because of exception: {exc_type.__name__} {exc_value}")
self._session.rollback()
else:
self._session.commit()
self._session.close()
def get_session(self):
@contextmanager
def with_session(self):
with Session(self.engine) as session:
yield session
@ -119,7 +109,7 @@ class DatabaseService(Service):
# associated with them
settings_service = get_settings_service()
if settings_service.auth_settings.AUTO_LOGIN:
with Session(self.engine) as session:
with self.with_session() as session:
flows = session.exec(select(models.Flow).where(models.Flow.user_id is None)).all()
if flows:
logger.debug("Migrating flows to default superuser")
@ -190,7 +180,7 @@ class DatabaseService(Service):
alembic_cfg.set_main_option("sqlalchemy.url", self.database_url.replace("%", "%%"))
should_initialize_alembic = False
with Session(self.engine) as session:
with self.with_session() as session:
# If the table does not exist it throws an error
# so we need to catch it
try:
@ -322,11 +312,10 @@ class DatabaseService(Service):
settings_service = get_settings_service()
# remove the default superuser if auto_login is enabled
# using the SUPERUSER to get the user
with Session(self.engine) as session:
with self.with_session() as session:
teardown_superuser(settings_service, session)
except Exception as exc:
logger.error(f"Error tearing down database: {exc}")
self.engine.dispose()
self.engine.dispose()

View file

@ -1,6 +1,6 @@
from contextlib import contextmanager
from typing import TYPE_CHECKING, Generator
from loguru import logger
from langflow.services.schema import ServiceType
if TYPE_CHECKING:
@ -162,12 +162,12 @@ def get_session() -> Generator["Session", None, None]:
Session: A session object.
"""
db_service = get_db_service()
yield from db_service.get_session()
with get_db_service().with_session() as session:
yield session
@contextmanager
def session_scope():
def session_scope() -> Generator["Session", None, None]:
"""
Context manager for managing a session scope.
@ -182,15 +182,15 @@ def session_scope():
Exception: If an error occurs during the session scope.
"""
session = next(get_session())
try:
yield session
session.commit()
except:
session.rollback()
raise
finally:
session.close()
db_service = get_db_service()
with db_service.with_session() as session:
try:
yield session
session.commit()
except Exception as e:
logger.exception("An error occurred during the session scope.", e)
session.rollback()
raise
def get_cache_service() -> "CacheService":

View file

@ -106,6 +106,7 @@ def teardown_superuser(settings_service, session):
except Exception as exc:
logger.exception(exc)
session.rollback()
raise RuntimeError("Could not remove default superuser.") from exc

View file

@ -407,7 +407,7 @@ def test_migrate_transactions_no_duckdb(client: TestClient):
def test_sqlite_pragmas():
db_service = get_db_service()
with db_service as session:
with db_service.with_session() as session:
from sqlalchemy import text
assert "wal" == session.execute(text("PRAGMA journal_mode;")).fetchone()[0]