diff --git a/src/backend/base/langflow/__main__.py b/src/backend/base/langflow/__main__.py index b825d4cb0..f07caac33 100644 --- a/src/backend/base/langflow/__main__.py +++ b/src/backend/base/langflow/__main__.py @@ -27,7 +27,7 @@ from langflow.logging.logger import configure, logger from langflow.main import setup_app from langflow.services.database.models.folder.utils import create_default_folder_if_it_doesnt_exist from langflow.services.database.utils import session_getter -from langflow.services.deps import async_session_scope, get_db_service, get_settings_service +from langflow.services.deps import get_db_service, get_settings_service, session_scope from langflow.services.settings.constants import DEFAULT_SUPERUSER from langflow.services.utils import initialize_services from langflow.utils.version import fetch_latest_version, get_version_info @@ -490,7 +490,7 @@ async def _migration(*, test: bool, fix: bool) -> None: db_service = get_db_service() if not test: await db_service.run_migrations() - results = db_service.run_migrations_test() + results = await db_service.run_migrations_test() display_results(results) @@ -533,7 +533,7 @@ def api_key( typer.echo("Auto login is disabled. API keys cannot be created through the CLI.") return None - async with async_session_scope() as session: + async with session_scope() as session: from langflow.services.database.models.user.model import User stmt = select(User).where(User.username == DEFAULT_SUPERUSER) diff --git a/src/backend/base/langflow/api/utils.py b/src/backend/base/langflow/api/utils.py index 61912ae1f..62bd81cc3 100644 --- a/src/backend/base/langflow/api/utils.py +++ b/src/backend/base/langflow/api/utils.py @@ -16,7 +16,7 @@ from langflow.services.database.models import User from langflow.services.database.models.flow import Flow from langflow.services.database.models.transactions.model import TransactionTable from langflow.services.database.models.vertex_builds.model import VertexBuildTable -from langflow.services.deps import async_session_scope, get_session +from langflow.services.deps import get_session, session_scope from langflow.services.store.utils import get_lf_version_from_pypi if TYPE_CHECKING: @@ -141,7 +141,7 @@ def format_elapsed_time(elapsed_time: float) -> str: async def _get_flow_name(flow_id: uuid.UUID) -> str: - async with async_session_scope() as session: + async with session_scope() as session: flow = await session.get(Flow, flow_id) if flow is None: msg = f"Flow {flow_id} not found" diff --git a/src/backend/base/langflow/api/v1/chat.py b/src/backend/base/langflow/api/v1/chat.py index 6aaf2988b..f4dc156d1 100644 --- a/src/backend/base/langflow/api/v1/chat.py +++ b/src/backend/base/langflow/api/v1/chat.py @@ -44,7 +44,7 @@ from langflow.schema.schema import OutputValue from langflow.services.cache.utils import CacheMiss from langflow.services.chat.service import ChatService from langflow.services.database.models.flow.model import Flow -from langflow.services.deps import async_session_scope, get_chat_service, get_session, get_telemetry_service +from langflow.services.deps import get_chat_service, get_session, get_telemetry_service, session_scope from langflow.services.telemetry.schema import ComponentPayload, PlaygroundPayload if TYPE_CHECKING: @@ -167,7 +167,7 @@ async def build_flow( if not data: graph = await build_graph_from_db(flow_id=flow_id, session=session, chat_service=chat_service) else: - async with async_session_scope() as new_session: + async with session_scope() as new_session: result = await new_session.exec(select(Flow.name).where(Flow.id == flow_id)) flow_name = result.first() graph = await build_graph_from_data( diff --git a/src/backend/base/langflow/custom/custom_component/custom_component.py b/src/backend/base/langflow/custom/custom_component/custom_component.py index fed586fb0..9cb9e1523 100644 --- a/src/backend/base/langflow/custom/custom_component/custom_component.py +++ b/src/backend/base/langflow/custom/custom_component/custom_component.py @@ -13,7 +13,7 @@ from pydantic import BaseModel from langflow.custom.custom_component.base_component import BaseComponent from langflow.helpers.flow import list_flows, load_flow, run_flow from langflow.schema import Data -from langflow.services.deps import async_session_scope, get_storage_service, get_variable_service +from langflow.services.deps import get_storage_service, get_variable_service, session_scope from langflow.services.storage.service import StorageService from langflow.template.utils import update_frontend_node_with_template_values from langflow.type_extraction.type_extraction import post_process_type @@ -437,7 +437,7 @@ class CustomComponent(BaseComponent): else: msg = f"Invalid user id: {self.user_id}" raise TypeError(msg) - async with async_session_scope() as session: + async with session_scope() as session: return await variable_service.get_variable(user_id=user_id, name=name, field=field, session=session) async def list_key_names(self): @@ -454,7 +454,7 @@ class CustomComponent(BaseComponent): raise ValueError(msg) variable_service = get_variable_service() - async with async_session_scope() as session: + async with session_scope() as session: return await variable_service.list_variables(user_id=self.user_id, session=session) def index(self, value: int = 0): diff --git a/src/backend/base/langflow/helpers/flow.py b/src/backend/base/langflow/helpers/flow.py index 1e379c1c6..9526c8aee 100644 --- a/src/backend/base/langflow/helpers/flow.py +++ b/src/backend/base/langflow/helpers/flow.py @@ -10,7 +10,7 @@ from sqlmodel import select from langflow.schema.schema import INPUT_FIELD_NAME from langflow.services.database.models.flow import Flow from langflow.services.database.models.flow.model import FlowRead -from langflow.services.deps import async_session_scope, get_settings_service +from langflow.services.deps import get_settings_service, session_scope if TYPE_CHECKING: from collections.abc import Awaitable, Callable @@ -32,7 +32,7 @@ async def list_flows(*, user_id: str | None = None) -> list[Data]: msg = "Session is invalid" raise ValueError(msg) try: - async with async_session_scope() as session: + async with session_scope() as session: uuid_user_id = UUID(user_id) if isinstance(user_id, str) else user_id stmt = select(Flow).where(Flow.user_id == uuid_user_id).where(Flow.is_component == False) # noqa: E712 flows = (await session.exec(stmt)).all() @@ -58,7 +58,7 @@ async def load_flow( msg = f"Flow {flow_name} not found" raise ValueError(msg) - async with async_session_scope() as session: + async with session_scope() as session: graph_data = flow.data if (flow := await session.get(Flow, flow_id)) else None if not graph_data: msg = f"Flow {flow_id} not found" @@ -69,7 +69,7 @@ async def load_flow( async def find_flow(flow_name: str, user_id: str) -> str | None: - async with async_session_scope() as session: + async with session_scope() as session: stmt = select(Flow).where(Flow.name == flow_name).where(Flow.user_id == user_id) flow = (await session.exec(stmt)).first() return flow.id if flow else None @@ -275,7 +275,7 @@ def get_arg_names(inputs: list[Vertex]) -> list[dict[str, str]]: async def get_flow_by_id_or_endpoint_name(flow_id_or_name: str, user_id: UUID | None = None) -> FlowRead | None: - async with async_session_scope() as session: + async with session_scope() as session: endpoint_name = None try: flow_id = UUID(flow_id_or_name) diff --git a/src/backend/base/langflow/helpers/user.py b/src/backend/base/langflow/helpers/user.py index e5b956b59..df0e3a1c6 100644 --- a/src/backend/base/langflow/helpers/user.py +++ b/src/backend/base/langflow/helpers/user.py @@ -9,7 +9,7 @@ from langflow.services.deps import get_db_service async def get_user_by_flow_id_or_endpoint_name(flow_id_or_name: str) -> UserRead | None: - async with get_db_service().with_async_session() as session: + async with get_db_service().with_session() as session: try: flow_id = UUID(flow_id_or_name) flow = await session.get(Flow, flow_id) diff --git a/src/backend/base/langflow/initial_setup/setup.py b/src/backend/base/langflow/initial_setup/setup.py index ddbdd86d1..88bd30138 100644 --- a/src/backend/base/langflow/initial_setup/setup.py +++ b/src/backend/base/langflow/initial_setup/setup.py @@ -28,10 +28,10 @@ from langflow.services.database.models.folder.utils import ( ) from langflow.services.database.models.user.crud import get_user_by_username from langflow.services.deps import ( - async_session_scope, get_settings_service, get_storage_service, get_variable_service, + session_scope, ) from langflow.template.field.prompt import DEFAULT_PROMPT_INTUT_TYPES from langflow.utils.util import escape_json_dump @@ -544,7 +544,7 @@ async def load_flows_from_directory() -> None: logger.warning("AUTO_LOGIN is disabled, not loading flows from directory") return - async with async_session_scope() as session: + async with session_scope() as session: user = await get_user_by_username(session, settings_service.auth_settings.SUPERUSER) if user is None: msg = "Superuser not found in the database" @@ -618,7 +618,7 @@ async def create_or_update_starter_projects(all_types_dict: dict, *, do_create: all_types_dict (dict): Dictionary containing all component types and their templates do_create (bool, optional): Whether to create new projects. Defaults to True. """ - async with async_session_scope() as session: + async with session_scope() as session: new_folder = await create_starter_folder(session) starter_projects = await load_starter_projects() await delete_start_projects(session, new_folder.id) @@ -674,7 +674,7 @@ async def initialize_super_user_if_needed() -> None: msg = "SUPERUSER and SUPERUSER_PASSWORD must be set in the settings if AUTO_LOGIN is true." raise ValueError(msg) - async with async_session_scope() as async_session: + async with session_scope() as async_session: super_user = await create_super_user(db=async_session, username=username, password=password) await get_variable_service().initialize_user_variables(super_user.id, async_session) await create_default_folder_if_it_doesnt_exist(async_session, super_user.id) diff --git a/src/backend/base/langflow/memory.py b/src/backend/base/langflow/memory.py index ebf8fac3d..c6ec20348 100644 --- a/src/backend/base/langflow/memory.py +++ b/src/backend/base/langflow/memory.py @@ -11,7 +11,7 @@ from sqlmodel.ext.asyncio.session import AsyncSession from langflow.schema.message import Message from langflow.services.database.models.message.model import MessageRead, MessageTable -from langflow.services.deps import async_session_scope +from langflow.services.deps import session_scope from langflow.utils.async_helpers import run_until_complete @@ -92,7 +92,7 @@ async def aget_messages( Returns: List[Data]: A list of Data objects representing the retrieved messages. """ - async with async_session_scope() as session: + async with session_scope() as session: stmt = _get_variable_query(sender, sender_name, session_id, order_by, order, flow_id, limit) messages = await session.exec(stmt) return [await Message.create(**d.model_dump()) for d in messages] @@ -118,7 +118,7 @@ async def aadd_messages(messages: Message | list[Message], flow_id: str | UUID | try: messages_models = [MessageTable.from_message(msg, flow_id=flow_id) for msg in messages] - async with async_session_scope() as session: + async with session_scope() as session: messages_models = await aadd_messagetables(messages_models, session) return [await Message.create(**message.model_dump()) for message in messages_models] except Exception as e: @@ -130,7 +130,7 @@ async def aupdate_messages(messages: Message | list[Message]) -> list[Message]: if not isinstance(messages, list): messages = [messages] - async with async_session_scope() as session: + async with session_scope() as session: updated_messages: list[MessageTable] = [] for message in messages: msg = await session.get(MessageTable, message.id) @@ -186,7 +186,7 @@ async def adelete_messages(session_id: str) -> None: Args: session_id (str): The session ID associated with the messages to delete. """ - async with async_session_scope() as session: + async with session_scope() as session: stmt = ( delete(MessageTable) .where(col(MessageTable.session_id) == session_id) @@ -201,7 +201,7 @@ async def delete_message(id_: str) -> None: Args: id_ (str): The ID of the message to delete. """ - async with async_session_scope() as session: + async with session_scope() as session: message = await session.get(MessageTable, id_) if message: await session.delete(message) diff --git a/src/backend/base/langflow/services/auth/utils.py b/src/backend/base/langflow/services/auth/utils.py index 197c0cdc7..fc2d99b49 100644 --- a/src/backend/base/langflow/services/auth/utils.py +++ b/src/backend/base/langflow/services/auth/utils.py @@ -41,7 +41,7 @@ async def api_key_security( settings_service = get_settings_service() result: ApiKey | User | None - async with get_db_service().with_async_session() as db: + async with get_db_service().with_session() as db: if settings_service.auth_settings.AUTO_LOGIN: # Get the first user if not settings_service.auth_settings.SUPERUSER: diff --git a/src/backend/base/langflow/services/database/models/message/crud.py b/src/backend/base/langflow/services/database/models/message/crud.py index 5031960e4..577d97b56 100644 --- a/src/backend/base/langflow/services/database/models/message/crud.py +++ b/src/backend/base/langflow/services/database/models/message/crud.py @@ -1,14 +1,14 @@ from uuid import UUID from langflow.services.database.models.message.model import MessageTable, MessageUpdate -from langflow.services.deps import async_session_scope +from langflow.services.deps import session_scope from langflow.utils.async_helpers import run_until_complete async def _update_message(message_id: UUID | str, message: MessageUpdate | dict): if not isinstance(message, MessageUpdate): message = MessageUpdate(**message) - async with async_session_scope() as session: + async with session_scope() as session: db_message = await session.get(MessageTable, message_id) if not db_message: msg = "Message not found" diff --git a/src/backend/base/langflow/services/database/service.py b/src/backend/base/langflow/services/database/service.py index 0256b8d2a..dc96aa692 100644 --- a/src/backend/base/langflow/services/database/service.py +++ b/src/backend/base/langflow/services/database/service.py @@ -4,7 +4,7 @@ import asyncio import re import sqlite3 import time -from contextlib import asynccontextmanager, contextmanager +from contextlib import asynccontextmanager from datetime import datetime, timezone from pathlib import Path from typing import TYPE_CHECKING @@ -19,7 +19,7 @@ from sqlalchemy.dialects import sqlite as dialect_sqlite from sqlalchemy.engine import Engine from sqlalchemy.exc import OperationalError from sqlalchemy.ext.asyncio import AsyncEngine, create_async_engine -from sqlmodel import Session, SQLModel, create_engine, select, text +from sqlmodel import SQLModel, select, text from sqlmodel.ext.asyncio.session import AsyncSession from langflow.initial_setup.constants import STARTER_FOLDER_NAME @@ -55,7 +55,6 @@ class DatabaseService(Service): # Using decorator will make the method not able to use self event.listen(Engine, "connect", self.on_connection) self.engine = self._create_engine() - self.async_engine = self._create_async_engine() alembic_log_file = self.settings_service.settings.alembic_log_file # Check if the provided path is absolute, cross-platform. @@ -74,7 +73,6 @@ class DatabaseService(Service): def reload_engine(self) -> None: self._sanitize_database_url() self.engine = self._create_engine() - self.async_engine = self._create_async_engine() def _sanitize_database_url(self): if self.database_url.startswith("postgres://"): @@ -84,16 +82,7 @@ class DatabaseService(Service): "To avoid this warning, update the database URL." ) - def _create_engine(self) -> Engine: - """Create the engine for the database.""" - return create_engine( - self.database_url, - connect_args=self._get_connect_args(), - pool_size=self.settings_service.settings.pool_size, - max_overflow=self.settings_service.settings.max_overflow, - ) - - def _create_async_engine(self) -> AsyncEngine: + def _create_engine(self) -> AsyncEngine: """Create the engine for the database.""" url_components = self.database_url.split("://", maxsplit=1) if url_components[0].startswith("sqlite"): @@ -144,14 +133,9 @@ class DatabaseService(Service): finally: cursor.close() - @contextmanager - def with_session(self): - with Session(self.engine) as session: - yield session - @asynccontextmanager - async def with_async_session(self): - async with AsyncSession(self.async_engine, expire_on_commit=False) as session: + async def with_session(self): + async with AsyncSession(self.engine, expire_on_commit=False) as session: yield session async def assign_orphaned_flows_to_superuser(self) -> None: @@ -161,7 +145,7 @@ class DatabaseService(Service): if not settings_service.auth_settings.AUTO_LOGIN: return - async with self.with_async_session() as session: + async with self.with_session() as session: # Fetch orphaned flows stmt = ( select(models.Flow) @@ -262,7 +246,7 @@ class DatabaseService(Service): return True async def check_schema_health(self) -> None: - async with self.with_async_session() as session, session.bind.connect() as conn: + async with self.with_session() as session, session.bind.connect() as conn: await conn.run_sync(self._check_schema_health) def init_alembic(self, alembic_cfg) -> None: @@ -323,7 +307,7 @@ class DatabaseService(Service): async def run_migrations(self, *, fix=False) -> None: should_initialize_alembic = False - async with self.with_async_session() as session: + async with self.with_session() as session: # If the table does not exist it throws an error # so we need to catch it try: @@ -348,7 +332,7 @@ class DatabaseService(Service): time.sleep(3) command.upgrade(alembic_cfg, "head") - def run_migrations_test(self): + async def run_migrations_test(self): # This method is used for testing purposes only # We will check that all models are in the database # and that the database is up to date with all columns @@ -356,11 +340,16 @@ class DatabaseService(Service): sql_models = [ model for model in models.__dict__.values() if isinstance(model, type) and issubclass(model, SQLModel) ] - return [TableResults(sql_model.__tablename__, self.check_table(sql_model)) for sql_model in sql_models] + async with self.with_session() as session, session.bind.connect() as conn: + return [ + TableResults(sql_model.__tablename__, conn.run_sync(self.check_table, sql_model)) + for sql_model in sql_models + ] - def check_table(self, model): + @staticmethod + def check_table(connection, model): results = [] - inspector = inspect(self.engine) + inspector = inspect(connection) table_name = model.__tablename__ expected_columns = list(model.__fields__.keys()) available_columns = [] @@ -416,7 +405,7 @@ class DatabaseService(Service): logger.debug("Database and tables created successfully") async def create_db_and_tables(self) -> None: - async with self.with_async_session() as session, session.bind.connect() as conn: + async with self.with_session() as session, session.bind.connect() as conn: await conn.run_sync(self._create_db_and_tables) async def teardown(self) -> None: @@ -425,9 +414,8 @@ class DatabaseService(Service): settings_service = get_settings_service() # remove the default superuser if auto_login is enabled # using the SUPERUSER to get the user - async with self.with_async_session() as session: + async with self.with_session() as session: await teardown_superuser(settings_service, session) except Exception: # noqa: BLE001 logger.exception("Error tearing down database") - await self.async_engine.dispose() - await asyncio.to_thread(self.engine.dispose) + await self.engine.dispose() diff --git a/src/backend/base/langflow/services/database/utils.py b/src/backend/base/langflow/services/database/utils.py index f5f2201aa..942a11cfa 100644 --- a/src/backend/base/langflow/services/database/utils.py +++ b/src/backend/base/langflow/services/database/utils.py @@ -61,7 +61,7 @@ async def initialize_database(*, fix_migration: bool = False) -> None: @asynccontextmanager async def session_getter(db_service: DatabaseService): try: - session = AsyncSession(db_service.async_engine, expire_on_commit=False) + session = AsyncSession(db_service.engine, expire_on_commit=False) yield session except Exception: logger.exception("Session rollback because of exception") diff --git a/src/backend/base/langflow/services/deps.py b/src/backend/base/langflow/services/deps.py index c18e09872..4dd4a2639 100644 --- a/src/backend/base/langflow/services/deps.py +++ b/src/backend/base/langflow/services/deps.py @@ -1,6 +1,6 @@ from __future__ import annotations -from contextlib import asynccontextmanager, contextmanager +from contextlib import asynccontextmanager from typing import TYPE_CHECKING from loguru import logger @@ -8,9 +8,8 @@ from loguru import logger from langflow.services.schema import ServiceType if TYPE_CHECKING: - from collections.abc import AsyncGenerator, Generator + from collections.abc import AsyncGenerator - from sqlmodel import Session from sqlmodel.ext.asyncio.session import AsyncSession from langflow.services.cache.service import AsyncBaseCacheService, CacheService @@ -149,38 +148,12 @@ async def get_session() -> AsyncGenerator[AsyncSession, None]: AsyncSession: An async session object. """ - async with get_db_service().with_async_session() as session: + async with get_db_service().with_session() as session: yield session -@contextmanager -def session_scope() -> Generator[Session, None, None]: - """Context manager for managing a session scope. - - This context manager is used to manage a session scope for database operations. - It ensures that the session is properly committed if no exceptions occur, - and rolled back if an exception is raised. - - Yields: - Session: The session object. - - Raises: - Exception: If an error occurs during the session scope. - - """ - db_service = get_db_service() - with db_service.with_session() as session: - try: - yield session - session.commit() - except Exception: - logger.exception("An error occurred during the session scope.") - session.rollback() - raise - - @asynccontextmanager -async def async_session_scope() -> AsyncGenerator[AsyncSession, None]: +async def session_scope() -> AsyncGenerator[AsyncSession, None]: """Context manager for managing an async session scope. This context manager is used to manage an async session scope for database operations. @@ -195,7 +168,7 @@ async def async_session_scope() -> AsyncGenerator[AsyncSession, None]: """ db_service = get_db_service() - async with db_service.with_async_session() as session: + async with db_service.with_session() as session: try: yield session await session.commit() diff --git a/src/backend/base/langflow/services/utils.py b/src/backend/base/langflow/services/utils.py index 0bcca7a99..9bf41e353 100644 --- a/src/backend/base/langflow/services/utils.py +++ b/src/backend/base/langflow/services/utils.py @@ -131,7 +131,7 @@ async def teardown_superuser(settings_service, session: AsyncSession) -> None: async def teardown_services() -> None: """Teardown all the services.""" try: - async with get_db_service().with_async_session() as session: + async with get_db_service().with_session() as session: await teardown_superuser(get_settings_service(), session) except Exception as exc: # noqa: BLE001 logger.exception(exc) @@ -240,7 +240,7 @@ async def initialize_services(*, fix_migration: bool = False) -> None: await initialize_database(fix_migration=fix_migration) db_service = get_db_service() await db_service.initialize_alembic_log_file() - async with db_service.with_async_session() as session: + async with db_service.with_session() as session: settings_service = get_service(ServiceType.SETTINGS_SERVICE) await setup_superuser(settings_service, session) try: diff --git a/src/backend/tests/conftest.py b/src/backend/tests/conftest.py index 6e0ce6b9a..f80410eed 100644 --- a/src/backend/tests/conftest.py +++ b/src/backend/tests/conftest.py @@ -400,7 +400,7 @@ async def test_user(client): @pytest.fixture async def active_user(client): # noqa: ARG001 db_manager = get_db_service() - async with db_manager.with_async_session() as session: + async with db_manager.with_session() as session: user = User( username="activeuser", password=get_password_hash("testpassword"), @@ -418,7 +418,7 @@ async def active_user(client): # noqa: ARG001 yield user # Clean up # Now cleanup transactions, vertex_build - async with db_manager.with_async_session() as session: + async with db_manager.with_session() as session: user = await session.get(User, user.id, options=[selectinload(User.flows)]) await _delete_transactions_and_vertex_builds(session, user.flows) await session.delete(user) @@ -439,7 +439,7 @@ async def logged_in_headers(client, active_user): @pytest.fixture async def active_super_user(client): # noqa: ARG001 db_manager = get_db_service() - async with db_manager.with_async_session() as session: + async with db_manager.with_session() as session: user = User( username="activeuser", password=get_password_hash("testpassword"), @@ -457,7 +457,7 @@ async def active_super_user(client): # noqa: ARG001 yield user # Clean up # Now cleanup transactions, vertex_build - async with db_manager.with_async_session() as session: + async with db_manager.with_session() as session: user = await session.get(User, user.id, options=[selectinload(User.flows)]) await _delete_transactions_and_vertex_builds(session, user.flows) await session.delete(user) diff --git a/src/backend/tests/unit/test_database.py b/src/backend/tests/unit/test_database.py index 7dcc0847c..adea3a63b 100644 --- a/src/backend/tests/unit/test_database.py +++ b/src/backend/tests/unit/test_database.py @@ -615,14 +615,14 @@ async def test_read_only_starter_projects(client: AsyncClient, logged_in_headers assert len(response.json()) == len(starter_projects) -def test_sqlite_pragmas(): +async def test_sqlite_pragmas(): db_service = get_db_service() - with db_service.with_session() as session: + async with db_service.with_session() as session: from sqlalchemy import text - assert session.exec(text("PRAGMA journal_mode;")).scalar() == "wal" - assert session.exec(text("PRAGMA synchronous;")).scalar() == 1 + assert (await session.exec(text("PRAGMA journal_mode;"))).scalar() == "wal" + assert (await session.exec(text("PRAGMA synchronous;"))).scalar() == 1 @pytest.mark.usefixtures("active_user") diff --git a/src/backend/tests/unit/test_initial_setup.py b/src/backend/tests/unit/test_initial_setup.py index 5236f0c7f..4fd736df3 100644 --- a/src/backend/tests/unit/test_initial_setup.py +++ b/src/backend/tests/unit/test_initial_setup.py @@ -11,7 +11,7 @@ from langflow.initial_setup.setup import ( ) from langflow.interface.types import aget_all_types_dict from langflow.services.database.models.folder.model import Folder -from langflow.services.deps import async_session_scope +from langflow.services.deps import session_scope from sqlalchemy.orm import selectinload from sqlmodel import select @@ -52,7 +52,7 @@ async def test_get_project_data(): @pytest.mark.usefixtures("client") async def test_create_or_update_starter_projects(): - async with async_session_scope() as session: + async with session_scope() as session: # Get the number of projects returned by load_starter_projects num_projects = len(await load_starter_projects()) diff --git a/src/backend/tests/unit/test_login.py b/src/backend/tests/unit/test_login.py index d045a0407..e22016267 100644 --- a/src/backend/tests/unit/test_login.py +++ b/src/backend/tests/unit/test_login.py @@ -1,7 +1,7 @@ import pytest from langflow.services.auth.utils import get_password_hash from langflow.services.database.models.user import User -from langflow.services.deps import async_session_scope +from langflow.services.deps import session_scope from sqlalchemy.exc import IntegrityError @@ -18,7 +18,7 @@ def test_user(): async def test_login_successful(client, test_user): # Adding the test user to the database try: - async with async_session_scope() as session: + async with session_scope() as session: session.add(test_user) await session.commit() except IntegrityError: diff --git a/src/backend/tests/unit/test_messages.py b/src/backend/tests/unit/test_messages.py index 66942b392..b1eb7b606 100644 --- a/src/backend/tests/unit/test_messages.py +++ b/src/backend/tests/unit/test_messages.py @@ -22,13 +22,13 @@ from langflow.schema.properties import Properties, Source # Assuming you have these imports available from langflow.services.database.models.message import MessageCreate, MessageRead from langflow.services.database.models.message.model import MessageTable -from langflow.services.deps import async_session_scope +from langflow.services.deps import session_scope from langflow.services.tracing.utils import convert_to_langchain_type @pytest.fixture async def created_message(): - async with async_session_scope() as session: + async with session_scope() as session: message = MessageCreate(text="Test message", sender="User", sender_name="User", session_id="session_id") messagetable = MessageTable.model_validate(message, from_attributes=True) messagetables = await aadd_messagetables([messagetable], session) @@ -37,7 +37,7 @@ async def created_message(): @pytest.fixture async def created_messages(async_session): # noqa: ARG001 - async with async_session_scope() as _session: + async with session_scope() as _session: messages = [ MessageCreate(text="Test message 1", sender="User", sender_name="User", session_id="session_id2"), MessageCreate(text="Test message 2", sender="User", sender_name="User", session_id="session_id2"), diff --git a/src/backend/tests/unit/test_messages_endpoints.py b/src/backend/tests/unit/test_messages_endpoints.py index 2031cb5ef..61cd6f133 100644 --- a/src/backend/tests/unit/test_messages_endpoints.py +++ b/src/backend/tests/unit/test_messages_endpoints.py @@ -8,12 +8,12 @@ from langflow.memory import aadd_messagetables # Assuming you have these imports available from langflow.services.database.models.message import MessageCreate, MessageRead, MessageUpdate from langflow.services.database.models.message.model import MessageTable -from langflow.services.deps import async_session_scope +from langflow.services.deps import session_scope @pytest.fixture async def created_message(): - async with async_session_scope() as session: + async with session_scope() as session: message = MessageCreate(text="Test message", sender="User", sender_name="User", session_id="session_id") messagetable = MessageTable.model_validate(message, from_attributes=True) messagetables = await aadd_messagetables([messagetable], session) @@ -22,7 +22,7 @@ async def created_message(): @pytest.fixture async def created_messages(session): # noqa: ARG001 - async with async_session_scope() as _session: + async with session_scope() as _session: messages = [ MessageCreate(text="Test message 1", sender="User", sender_name="User", session_id="session_id2"), MessageCreate(text="Test message 2", sender="User", sender_name="User", session_id="session_id2"),