ref: Remove unused sync session_scope, with_session and engine (#5333)
Remove unused sync session_scope, with_session and engine
This commit is contained in:
parent
9bf372ee3e
commit
3f0e383135
20 changed files with 75 additions and 114 deletions
|
|
@ -27,7 +27,7 @@ from langflow.logging.logger import configure, logger
|
|||
from langflow.main import setup_app
|
||||
from langflow.services.database.models.folder.utils import create_default_folder_if_it_doesnt_exist
|
||||
from langflow.services.database.utils import session_getter
|
||||
from langflow.services.deps import async_session_scope, get_db_service, get_settings_service
|
||||
from langflow.services.deps import get_db_service, get_settings_service, session_scope
|
||||
from langflow.services.settings.constants import DEFAULT_SUPERUSER
|
||||
from langflow.services.utils import initialize_services
|
||||
from langflow.utils.version import fetch_latest_version, get_version_info
|
||||
|
|
@ -490,7 +490,7 @@ async def _migration(*, test: bool, fix: bool) -> None:
|
|||
db_service = get_db_service()
|
||||
if not test:
|
||||
await db_service.run_migrations()
|
||||
results = db_service.run_migrations_test()
|
||||
results = await db_service.run_migrations_test()
|
||||
display_results(results)
|
||||
|
||||
|
||||
|
|
@ -533,7 +533,7 @@ def api_key(
|
|||
typer.echo("Auto login is disabled. API keys cannot be created through the CLI.")
|
||||
return None
|
||||
|
||||
async with async_session_scope() as session:
|
||||
async with session_scope() as session:
|
||||
from langflow.services.database.models.user.model import User
|
||||
|
||||
stmt = select(User).where(User.username == DEFAULT_SUPERUSER)
|
||||
|
|
|
|||
|
|
@ -16,7 +16,7 @@ from langflow.services.database.models import User
|
|||
from langflow.services.database.models.flow import Flow
|
||||
from langflow.services.database.models.transactions.model import TransactionTable
|
||||
from langflow.services.database.models.vertex_builds.model import VertexBuildTable
|
||||
from langflow.services.deps import async_session_scope, get_session
|
||||
from langflow.services.deps import get_session, session_scope
|
||||
from langflow.services.store.utils import get_lf_version_from_pypi
|
||||
|
||||
if TYPE_CHECKING:
|
||||
|
|
@ -141,7 +141,7 @@ def format_elapsed_time(elapsed_time: float) -> str:
|
|||
|
||||
|
||||
async def _get_flow_name(flow_id: uuid.UUID) -> str:
|
||||
async with async_session_scope() as session:
|
||||
async with session_scope() as session:
|
||||
flow = await session.get(Flow, flow_id)
|
||||
if flow is None:
|
||||
msg = f"Flow {flow_id} not found"
|
||||
|
|
|
|||
|
|
@ -44,7 +44,7 @@ from langflow.schema.schema import OutputValue
|
|||
from langflow.services.cache.utils import CacheMiss
|
||||
from langflow.services.chat.service import ChatService
|
||||
from langflow.services.database.models.flow.model import Flow
|
||||
from langflow.services.deps import async_session_scope, get_chat_service, get_session, get_telemetry_service
|
||||
from langflow.services.deps import get_chat_service, get_session, get_telemetry_service, session_scope
|
||||
from langflow.services.telemetry.schema import ComponentPayload, PlaygroundPayload
|
||||
|
||||
if TYPE_CHECKING:
|
||||
|
|
@ -167,7 +167,7 @@ async def build_flow(
|
|||
if not data:
|
||||
graph = await build_graph_from_db(flow_id=flow_id, session=session, chat_service=chat_service)
|
||||
else:
|
||||
async with async_session_scope() as new_session:
|
||||
async with session_scope() as new_session:
|
||||
result = await new_session.exec(select(Flow.name).where(Flow.id == flow_id))
|
||||
flow_name = result.first()
|
||||
graph = await build_graph_from_data(
|
||||
|
|
|
|||
|
|
@ -13,7 +13,7 @@ from pydantic import BaseModel
|
|||
from langflow.custom.custom_component.base_component import BaseComponent
|
||||
from langflow.helpers.flow import list_flows, load_flow, run_flow
|
||||
from langflow.schema import Data
|
||||
from langflow.services.deps import async_session_scope, get_storage_service, get_variable_service
|
||||
from langflow.services.deps import get_storage_service, get_variable_service, session_scope
|
||||
from langflow.services.storage.service import StorageService
|
||||
from langflow.template.utils import update_frontend_node_with_template_values
|
||||
from langflow.type_extraction.type_extraction import post_process_type
|
||||
|
|
@ -437,7 +437,7 @@ class CustomComponent(BaseComponent):
|
|||
else:
|
||||
msg = f"Invalid user id: {self.user_id}"
|
||||
raise TypeError(msg)
|
||||
async with async_session_scope() as session:
|
||||
async with session_scope() as session:
|
||||
return await variable_service.get_variable(user_id=user_id, name=name, field=field, session=session)
|
||||
|
||||
async def list_key_names(self):
|
||||
|
|
@ -454,7 +454,7 @@ class CustomComponent(BaseComponent):
|
|||
raise ValueError(msg)
|
||||
variable_service = get_variable_service()
|
||||
|
||||
async with async_session_scope() as session:
|
||||
async with session_scope() as session:
|
||||
return await variable_service.list_variables(user_id=self.user_id, session=session)
|
||||
|
||||
def index(self, value: int = 0):
|
||||
|
|
|
|||
|
|
@ -10,7 +10,7 @@ from sqlmodel import select
|
|||
from langflow.schema.schema import INPUT_FIELD_NAME
|
||||
from langflow.services.database.models.flow import Flow
|
||||
from langflow.services.database.models.flow.model import FlowRead
|
||||
from langflow.services.deps import async_session_scope, get_settings_service
|
||||
from langflow.services.deps import get_settings_service, session_scope
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from collections.abc import Awaitable, Callable
|
||||
|
|
@ -32,7 +32,7 @@ async def list_flows(*, user_id: str | None = None) -> list[Data]:
|
|||
msg = "Session is invalid"
|
||||
raise ValueError(msg)
|
||||
try:
|
||||
async with async_session_scope() as session:
|
||||
async with session_scope() as session:
|
||||
uuid_user_id = UUID(user_id) if isinstance(user_id, str) else user_id
|
||||
stmt = select(Flow).where(Flow.user_id == uuid_user_id).where(Flow.is_component == False) # noqa: E712
|
||||
flows = (await session.exec(stmt)).all()
|
||||
|
|
@ -58,7 +58,7 @@ async def load_flow(
|
|||
msg = f"Flow {flow_name} not found"
|
||||
raise ValueError(msg)
|
||||
|
||||
async with async_session_scope() as session:
|
||||
async with session_scope() as session:
|
||||
graph_data = flow.data if (flow := await session.get(Flow, flow_id)) else None
|
||||
if not graph_data:
|
||||
msg = f"Flow {flow_id} not found"
|
||||
|
|
@ -69,7 +69,7 @@ async def load_flow(
|
|||
|
||||
|
||||
async def find_flow(flow_name: str, user_id: str) -> str | None:
|
||||
async with async_session_scope() as session:
|
||||
async with session_scope() as session:
|
||||
stmt = select(Flow).where(Flow.name == flow_name).where(Flow.user_id == user_id)
|
||||
flow = (await session.exec(stmt)).first()
|
||||
return flow.id if flow else None
|
||||
|
|
@ -275,7 +275,7 @@ def get_arg_names(inputs: list[Vertex]) -> list[dict[str, str]]:
|
|||
|
||||
|
||||
async def get_flow_by_id_or_endpoint_name(flow_id_or_name: str, user_id: UUID | None = None) -> FlowRead | None:
|
||||
async with async_session_scope() as session:
|
||||
async with session_scope() as session:
|
||||
endpoint_name = None
|
||||
try:
|
||||
flow_id = UUID(flow_id_or_name)
|
||||
|
|
|
|||
|
|
@ -9,7 +9,7 @@ from langflow.services.deps import get_db_service
|
|||
|
||||
|
||||
async def get_user_by_flow_id_or_endpoint_name(flow_id_or_name: str) -> UserRead | None:
|
||||
async with get_db_service().with_async_session() as session:
|
||||
async with get_db_service().with_session() as session:
|
||||
try:
|
||||
flow_id = UUID(flow_id_or_name)
|
||||
flow = await session.get(Flow, flow_id)
|
||||
|
|
|
|||
|
|
@ -28,10 +28,10 @@ from langflow.services.database.models.folder.utils import (
|
|||
)
|
||||
from langflow.services.database.models.user.crud import get_user_by_username
|
||||
from langflow.services.deps import (
|
||||
async_session_scope,
|
||||
get_settings_service,
|
||||
get_storage_service,
|
||||
get_variable_service,
|
||||
session_scope,
|
||||
)
|
||||
from langflow.template.field.prompt import DEFAULT_PROMPT_INTUT_TYPES
|
||||
from langflow.utils.util import escape_json_dump
|
||||
|
|
@ -544,7 +544,7 @@ async def load_flows_from_directory() -> None:
|
|||
logger.warning("AUTO_LOGIN is disabled, not loading flows from directory")
|
||||
return
|
||||
|
||||
async with async_session_scope() as session:
|
||||
async with session_scope() as session:
|
||||
user = await get_user_by_username(session, settings_service.auth_settings.SUPERUSER)
|
||||
if user is None:
|
||||
msg = "Superuser not found in the database"
|
||||
|
|
@ -618,7 +618,7 @@ async def create_or_update_starter_projects(all_types_dict: dict, *, do_create:
|
|||
all_types_dict (dict): Dictionary containing all component types and their templates
|
||||
do_create (bool, optional): Whether to create new projects. Defaults to True.
|
||||
"""
|
||||
async with async_session_scope() as session:
|
||||
async with session_scope() as session:
|
||||
new_folder = await create_starter_folder(session)
|
||||
starter_projects = await load_starter_projects()
|
||||
await delete_start_projects(session, new_folder.id)
|
||||
|
|
@ -674,7 +674,7 @@ async def initialize_super_user_if_needed() -> None:
|
|||
msg = "SUPERUSER and SUPERUSER_PASSWORD must be set in the settings if AUTO_LOGIN is true."
|
||||
raise ValueError(msg)
|
||||
|
||||
async with async_session_scope() as async_session:
|
||||
async with session_scope() as async_session:
|
||||
super_user = await create_super_user(db=async_session, username=username, password=password)
|
||||
await get_variable_service().initialize_user_variables(super_user.id, async_session)
|
||||
await create_default_folder_if_it_doesnt_exist(async_session, super_user.id)
|
||||
|
|
|
|||
|
|
@ -11,7 +11,7 @@ from sqlmodel.ext.asyncio.session import AsyncSession
|
|||
|
||||
from langflow.schema.message import Message
|
||||
from langflow.services.database.models.message.model import MessageRead, MessageTable
|
||||
from langflow.services.deps import async_session_scope
|
||||
from langflow.services.deps import session_scope
|
||||
from langflow.utils.async_helpers import run_until_complete
|
||||
|
||||
|
||||
|
|
@ -92,7 +92,7 @@ async def aget_messages(
|
|||
Returns:
|
||||
List[Data]: A list of Data objects representing the retrieved messages.
|
||||
"""
|
||||
async with async_session_scope() as session:
|
||||
async with session_scope() as session:
|
||||
stmt = _get_variable_query(sender, sender_name, session_id, order_by, order, flow_id, limit)
|
||||
messages = await session.exec(stmt)
|
||||
return [await Message.create(**d.model_dump()) for d in messages]
|
||||
|
|
@ -118,7 +118,7 @@ async def aadd_messages(messages: Message | list[Message], flow_id: str | UUID |
|
|||
|
||||
try:
|
||||
messages_models = [MessageTable.from_message(msg, flow_id=flow_id) for msg in messages]
|
||||
async with async_session_scope() as session:
|
||||
async with session_scope() as session:
|
||||
messages_models = await aadd_messagetables(messages_models, session)
|
||||
return [await Message.create(**message.model_dump()) for message in messages_models]
|
||||
except Exception as e:
|
||||
|
|
@ -130,7 +130,7 @@ async def aupdate_messages(messages: Message | list[Message]) -> list[Message]:
|
|||
if not isinstance(messages, list):
|
||||
messages = [messages]
|
||||
|
||||
async with async_session_scope() as session:
|
||||
async with session_scope() as session:
|
||||
updated_messages: list[MessageTable] = []
|
||||
for message in messages:
|
||||
msg = await session.get(MessageTable, message.id)
|
||||
|
|
@ -186,7 +186,7 @@ async def adelete_messages(session_id: str) -> None:
|
|||
Args:
|
||||
session_id (str): The session ID associated with the messages to delete.
|
||||
"""
|
||||
async with async_session_scope() as session:
|
||||
async with session_scope() as session:
|
||||
stmt = (
|
||||
delete(MessageTable)
|
||||
.where(col(MessageTable.session_id) == session_id)
|
||||
|
|
@ -201,7 +201,7 @@ async def delete_message(id_: str) -> None:
|
|||
Args:
|
||||
id_ (str): The ID of the message to delete.
|
||||
"""
|
||||
async with async_session_scope() as session:
|
||||
async with session_scope() as session:
|
||||
message = await session.get(MessageTable, id_)
|
||||
if message:
|
||||
await session.delete(message)
|
||||
|
|
|
|||
|
|
@ -41,7 +41,7 @@ async def api_key_security(
|
|||
settings_service = get_settings_service()
|
||||
result: ApiKey | User | None
|
||||
|
||||
async with get_db_service().with_async_session() as db:
|
||||
async with get_db_service().with_session() as db:
|
||||
if settings_service.auth_settings.AUTO_LOGIN:
|
||||
# Get the first user
|
||||
if not settings_service.auth_settings.SUPERUSER:
|
||||
|
|
|
|||
|
|
@ -1,14 +1,14 @@
|
|||
from uuid import UUID
|
||||
|
||||
from langflow.services.database.models.message.model import MessageTable, MessageUpdate
|
||||
from langflow.services.deps import async_session_scope
|
||||
from langflow.services.deps import session_scope
|
||||
from langflow.utils.async_helpers import run_until_complete
|
||||
|
||||
|
||||
async def _update_message(message_id: UUID | str, message: MessageUpdate | dict):
|
||||
if not isinstance(message, MessageUpdate):
|
||||
message = MessageUpdate(**message)
|
||||
async with async_session_scope() as session:
|
||||
async with session_scope() as session:
|
||||
db_message = await session.get(MessageTable, message_id)
|
||||
if not db_message:
|
||||
msg = "Message not found"
|
||||
|
|
|
|||
|
|
@ -4,7 +4,7 @@ import asyncio
|
|||
import re
|
||||
import sqlite3
|
||||
import time
|
||||
from contextlib import asynccontextmanager, contextmanager
|
||||
from contextlib import asynccontextmanager
|
||||
from datetime import datetime, timezone
|
||||
from pathlib import Path
|
||||
from typing import TYPE_CHECKING
|
||||
|
|
@ -19,7 +19,7 @@ from sqlalchemy.dialects import sqlite as dialect_sqlite
|
|||
from sqlalchemy.engine import Engine
|
||||
from sqlalchemy.exc import OperationalError
|
||||
from sqlalchemy.ext.asyncio import AsyncEngine, create_async_engine
|
||||
from sqlmodel import Session, SQLModel, create_engine, select, text
|
||||
from sqlmodel import SQLModel, select, text
|
||||
from sqlmodel.ext.asyncio.session import AsyncSession
|
||||
|
||||
from langflow.initial_setup.constants import STARTER_FOLDER_NAME
|
||||
|
|
@ -55,7 +55,6 @@ class DatabaseService(Service):
|
|||
# Using decorator will make the method not able to use self
|
||||
event.listen(Engine, "connect", self.on_connection)
|
||||
self.engine = self._create_engine()
|
||||
self.async_engine = self._create_async_engine()
|
||||
|
||||
alembic_log_file = self.settings_service.settings.alembic_log_file
|
||||
# Check if the provided path is absolute, cross-platform.
|
||||
|
|
@ -74,7 +73,6 @@ class DatabaseService(Service):
|
|||
def reload_engine(self) -> None:
|
||||
self._sanitize_database_url()
|
||||
self.engine = self._create_engine()
|
||||
self.async_engine = self._create_async_engine()
|
||||
|
||||
def _sanitize_database_url(self):
|
||||
if self.database_url.startswith("postgres://"):
|
||||
|
|
@ -84,16 +82,7 @@ class DatabaseService(Service):
|
|||
"To avoid this warning, update the database URL."
|
||||
)
|
||||
|
||||
def _create_engine(self) -> Engine:
|
||||
"""Create the engine for the database."""
|
||||
return create_engine(
|
||||
self.database_url,
|
||||
connect_args=self._get_connect_args(),
|
||||
pool_size=self.settings_service.settings.pool_size,
|
||||
max_overflow=self.settings_service.settings.max_overflow,
|
||||
)
|
||||
|
||||
def _create_async_engine(self) -> AsyncEngine:
|
||||
def _create_engine(self) -> AsyncEngine:
|
||||
"""Create the engine for the database."""
|
||||
url_components = self.database_url.split("://", maxsplit=1)
|
||||
if url_components[0].startswith("sqlite"):
|
||||
|
|
@ -144,14 +133,9 @@ class DatabaseService(Service):
|
|||
finally:
|
||||
cursor.close()
|
||||
|
||||
@contextmanager
|
||||
def with_session(self):
|
||||
with Session(self.engine) as session:
|
||||
yield session
|
||||
|
||||
@asynccontextmanager
|
||||
async def with_async_session(self):
|
||||
async with AsyncSession(self.async_engine, expire_on_commit=False) as session:
|
||||
async def with_session(self):
|
||||
async with AsyncSession(self.engine, expire_on_commit=False) as session:
|
||||
yield session
|
||||
|
||||
async def assign_orphaned_flows_to_superuser(self) -> None:
|
||||
|
|
@ -161,7 +145,7 @@ class DatabaseService(Service):
|
|||
if not settings_service.auth_settings.AUTO_LOGIN:
|
||||
return
|
||||
|
||||
async with self.with_async_session() as session:
|
||||
async with self.with_session() as session:
|
||||
# Fetch orphaned flows
|
||||
stmt = (
|
||||
select(models.Flow)
|
||||
|
|
@ -262,7 +246,7 @@ class DatabaseService(Service):
|
|||
return True
|
||||
|
||||
async def check_schema_health(self) -> None:
|
||||
async with self.with_async_session() as session, session.bind.connect() as conn:
|
||||
async with self.with_session() as session, session.bind.connect() as conn:
|
||||
await conn.run_sync(self._check_schema_health)
|
||||
|
||||
def init_alembic(self, alembic_cfg) -> None:
|
||||
|
|
@ -323,7 +307,7 @@ class DatabaseService(Service):
|
|||
|
||||
async def run_migrations(self, *, fix=False) -> None:
|
||||
should_initialize_alembic = False
|
||||
async with self.with_async_session() as session:
|
||||
async with self.with_session() as session:
|
||||
# If the table does not exist it throws an error
|
||||
# so we need to catch it
|
||||
try:
|
||||
|
|
@ -348,7 +332,7 @@ class DatabaseService(Service):
|
|||
time.sleep(3)
|
||||
command.upgrade(alembic_cfg, "head")
|
||||
|
||||
def run_migrations_test(self):
|
||||
async def run_migrations_test(self):
|
||||
# This method is used for testing purposes only
|
||||
# We will check that all models are in the database
|
||||
# and that the database is up to date with all columns
|
||||
|
|
@ -356,11 +340,16 @@ class DatabaseService(Service):
|
|||
sql_models = [
|
||||
model for model in models.__dict__.values() if isinstance(model, type) and issubclass(model, SQLModel)
|
||||
]
|
||||
return [TableResults(sql_model.__tablename__, self.check_table(sql_model)) for sql_model in sql_models]
|
||||
async with self.with_session() as session, session.bind.connect() as conn:
|
||||
return [
|
||||
TableResults(sql_model.__tablename__, conn.run_sync(self.check_table, sql_model))
|
||||
for sql_model in sql_models
|
||||
]
|
||||
|
||||
def check_table(self, model):
|
||||
@staticmethod
|
||||
def check_table(connection, model):
|
||||
results = []
|
||||
inspector = inspect(self.engine)
|
||||
inspector = inspect(connection)
|
||||
table_name = model.__tablename__
|
||||
expected_columns = list(model.__fields__.keys())
|
||||
available_columns = []
|
||||
|
|
@ -416,7 +405,7 @@ class DatabaseService(Service):
|
|||
logger.debug("Database and tables created successfully")
|
||||
|
||||
async def create_db_and_tables(self) -> None:
|
||||
async with self.with_async_session() as session, session.bind.connect() as conn:
|
||||
async with self.with_session() as session, session.bind.connect() as conn:
|
||||
await conn.run_sync(self._create_db_and_tables)
|
||||
|
||||
async def teardown(self) -> None:
|
||||
|
|
@ -425,9 +414,8 @@ class DatabaseService(Service):
|
|||
settings_service = get_settings_service()
|
||||
# remove the default superuser if auto_login is enabled
|
||||
# using the SUPERUSER to get the user
|
||||
async with self.with_async_session() as session:
|
||||
async with self.with_session() as session:
|
||||
await teardown_superuser(settings_service, session)
|
||||
except Exception: # noqa: BLE001
|
||||
logger.exception("Error tearing down database")
|
||||
await self.async_engine.dispose()
|
||||
await asyncio.to_thread(self.engine.dispose)
|
||||
await self.engine.dispose()
|
||||
|
|
|
|||
|
|
@ -61,7 +61,7 @@ async def initialize_database(*, fix_migration: bool = False) -> None:
|
|||
@asynccontextmanager
|
||||
async def session_getter(db_service: DatabaseService):
|
||||
try:
|
||||
session = AsyncSession(db_service.async_engine, expire_on_commit=False)
|
||||
session = AsyncSession(db_service.engine, expire_on_commit=False)
|
||||
yield session
|
||||
except Exception:
|
||||
logger.exception("Session rollback because of exception")
|
||||
|
|
|
|||
|
|
@ -1,6 +1,6 @@
|
|||
from __future__ import annotations
|
||||
|
||||
from contextlib import asynccontextmanager, contextmanager
|
||||
from contextlib import asynccontextmanager
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
from loguru import logger
|
||||
|
|
@ -8,9 +8,8 @@ from loguru import logger
|
|||
from langflow.services.schema import ServiceType
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from collections.abc import AsyncGenerator, Generator
|
||||
from collections.abc import AsyncGenerator
|
||||
|
||||
from sqlmodel import Session
|
||||
from sqlmodel.ext.asyncio.session import AsyncSession
|
||||
|
||||
from langflow.services.cache.service import AsyncBaseCacheService, CacheService
|
||||
|
|
@ -149,38 +148,12 @@ async def get_session() -> AsyncGenerator[AsyncSession, None]:
|
|||
AsyncSession: An async session object.
|
||||
|
||||
"""
|
||||
async with get_db_service().with_async_session() as session:
|
||||
async with get_db_service().with_session() as session:
|
||||
yield session
|
||||
|
||||
|
||||
@contextmanager
|
||||
def session_scope() -> Generator[Session, None, None]:
|
||||
"""Context manager for managing a session scope.
|
||||
|
||||
This context manager is used to manage a session scope for database operations.
|
||||
It ensures that the session is properly committed if no exceptions occur,
|
||||
and rolled back if an exception is raised.
|
||||
|
||||
Yields:
|
||||
Session: The session object.
|
||||
|
||||
Raises:
|
||||
Exception: If an error occurs during the session scope.
|
||||
|
||||
"""
|
||||
db_service = get_db_service()
|
||||
with db_service.with_session() as session:
|
||||
try:
|
||||
yield session
|
||||
session.commit()
|
||||
except Exception:
|
||||
logger.exception("An error occurred during the session scope.")
|
||||
session.rollback()
|
||||
raise
|
||||
|
||||
|
||||
@asynccontextmanager
|
||||
async def async_session_scope() -> AsyncGenerator[AsyncSession, None]:
|
||||
async def session_scope() -> AsyncGenerator[AsyncSession, None]:
|
||||
"""Context manager for managing an async session scope.
|
||||
|
||||
This context manager is used to manage an async session scope for database operations.
|
||||
|
|
@ -195,7 +168,7 @@ async def async_session_scope() -> AsyncGenerator[AsyncSession, None]:
|
|||
|
||||
"""
|
||||
db_service = get_db_service()
|
||||
async with db_service.with_async_session() as session:
|
||||
async with db_service.with_session() as session:
|
||||
try:
|
||||
yield session
|
||||
await session.commit()
|
||||
|
|
|
|||
|
|
@ -131,7 +131,7 @@ async def teardown_superuser(settings_service, session: AsyncSession) -> None:
|
|||
async def teardown_services() -> None:
|
||||
"""Teardown all the services."""
|
||||
try:
|
||||
async with get_db_service().with_async_session() as session:
|
||||
async with get_db_service().with_session() as session:
|
||||
await teardown_superuser(get_settings_service(), session)
|
||||
except Exception as exc: # noqa: BLE001
|
||||
logger.exception(exc)
|
||||
|
|
@ -240,7 +240,7 @@ async def initialize_services(*, fix_migration: bool = False) -> None:
|
|||
await initialize_database(fix_migration=fix_migration)
|
||||
db_service = get_db_service()
|
||||
await db_service.initialize_alembic_log_file()
|
||||
async with db_service.with_async_session() as session:
|
||||
async with db_service.with_session() as session:
|
||||
settings_service = get_service(ServiceType.SETTINGS_SERVICE)
|
||||
await setup_superuser(settings_service, session)
|
||||
try:
|
||||
|
|
|
|||
|
|
@ -400,7 +400,7 @@ async def test_user(client):
|
|||
@pytest.fixture
|
||||
async def active_user(client): # noqa: ARG001
|
||||
db_manager = get_db_service()
|
||||
async with db_manager.with_async_session() as session:
|
||||
async with db_manager.with_session() as session:
|
||||
user = User(
|
||||
username="activeuser",
|
||||
password=get_password_hash("testpassword"),
|
||||
|
|
@ -418,7 +418,7 @@ async def active_user(client): # noqa: ARG001
|
|||
yield user
|
||||
# Clean up
|
||||
# Now cleanup transactions, vertex_build
|
||||
async with db_manager.with_async_session() as session:
|
||||
async with db_manager.with_session() as session:
|
||||
user = await session.get(User, user.id, options=[selectinload(User.flows)])
|
||||
await _delete_transactions_and_vertex_builds(session, user.flows)
|
||||
await session.delete(user)
|
||||
|
|
@ -439,7 +439,7 @@ async def logged_in_headers(client, active_user):
|
|||
@pytest.fixture
|
||||
async def active_super_user(client): # noqa: ARG001
|
||||
db_manager = get_db_service()
|
||||
async with db_manager.with_async_session() as session:
|
||||
async with db_manager.with_session() as session:
|
||||
user = User(
|
||||
username="activeuser",
|
||||
password=get_password_hash("testpassword"),
|
||||
|
|
@ -457,7 +457,7 @@ async def active_super_user(client): # noqa: ARG001
|
|||
yield user
|
||||
# Clean up
|
||||
# Now cleanup transactions, vertex_build
|
||||
async with db_manager.with_async_session() as session:
|
||||
async with db_manager.with_session() as session:
|
||||
user = await session.get(User, user.id, options=[selectinload(User.flows)])
|
||||
await _delete_transactions_and_vertex_builds(session, user.flows)
|
||||
await session.delete(user)
|
||||
|
|
|
|||
|
|
@ -615,14 +615,14 @@ async def test_read_only_starter_projects(client: AsyncClient, logged_in_headers
|
|||
assert len(response.json()) == len(starter_projects)
|
||||
|
||||
|
||||
def test_sqlite_pragmas():
|
||||
async def test_sqlite_pragmas():
|
||||
db_service = get_db_service()
|
||||
|
||||
with db_service.with_session() as session:
|
||||
async with db_service.with_session() as session:
|
||||
from sqlalchemy import text
|
||||
|
||||
assert session.exec(text("PRAGMA journal_mode;")).scalar() == "wal"
|
||||
assert session.exec(text("PRAGMA synchronous;")).scalar() == 1
|
||||
assert (await session.exec(text("PRAGMA journal_mode;"))).scalar() == "wal"
|
||||
assert (await session.exec(text("PRAGMA synchronous;"))).scalar() == 1
|
||||
|
||||
|
||||
@pytest.mark.usefixtures("active_user")
|
||||
|
|
|
|||
|
|
@ -11,7 +11,7 @@ from langflow.initial_setup.setup import (
|
|||
)
|
||||
from langflow.interface.types import aget_all_types_dict
|
||||
from langflow.services.database.models.folder.model import Folder
|
||||
from langflow.services.deps import async_session_scope
|
||||
from langflow.services.deps import session_scope
|
||||
from sqlalchemy.orm import selectinload
|
||||
from sqlmodel import select
|
||||
|
||||
|
|
@ -52,7 +52,7 @@ async def test_get_project_data():
|
|||
|
||||
@pytest.mark.usefixtures("client")
|
||||
async def test_create_or_update_starter_projects():
|
||||
async with async_session_scope() as session:
|
||||
async with session_scope() as session:
|
||||
# Get the number of projects returned by load_starter_projects
|
||||
num_projects = len(await load_starter_projects())
|
||||
|
||||
|
|
|
|||
|
|
@ -1,7 +1,7 @@
|
|||
import pytest
|
||||
from langflow.services.auth.utils import get_password_hash
|
||||
from langflow.services.database.models.user import User
|
||||
from langflow.services.deps import async_session_scope
|
||||
from langflow.services.deps import session_scope
|
||||
from sqlalchemy.exc import IntegrityError
|
||||
|
||||
|
||||
|
|
@ -18,7 +18,7 @@ def test_user():
|
|||
async def test_login_successful(client, test_user):
|
||||
# Adding the test user to the database
|
||||
try:
|
||||
async with async_session_scope() as session:
|
||||
async with session_scope() as session:
|
||||
session.add(test_user)
|
||||
await session.commit()
|
||||
except IntegrityError:
|
||||
|
|
|
|||
|
|
@ -22,13 +22,13 @@ from langflow.schema.properties import Properties, Source
|
|||
# Assuming you have these imports available
|
||||
from langflow.services.database.models.message import MessageCreate, MessageRead
|
||||
from langflow.services.database.models.message.model import MessageTable
|
||||
from langflow.services.deps import async_session_scope
|
||||
from langflow.services.deps import session_scope
|
||||
from langflow.services.tracing.utils import convert_to_langchain_type
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
async def created_message():
|
||||
async with async_session_scope() as session:
|
||||
async with session_scope() as session:
|
||||
message = MessageCreate(text="Test message", sender="User", sender_name="User", session_id="session_id")
|
||||
messagetable = MessageTable.model_validate(message, from_attributes=True)
|
||||
messagetables = await aadd_messagetables([messagetable], session)
|
||||
|
|
@ -37,7 +37,7 @@ async def created_message():
|
|||
|
||||
@pytest.fixture
|
||||
async def created_messages(async_session): # noqa: ARG001
|
||||
async with async_session_scope() as _session:
|
||||
async with session_scope() as _session:
|
||||
messages = [
|
||||
MessageCreate(text="Test message 1", sender="User", sender_name="User", session_id="session_id2"),
|
||||
MessageCreate(text="Test message 2", sender="User", sender_name="User", session_id="session_id2"),
|
||||
|
|
|
|||
|
|
@ -8,12 +8,12 @@ from langflow.memory import aadd_messagetables
|
|||
# Assuming you have these imports available
|
||||
from langflow.services.database.models.message import MessageCreate, MessageRead, MessageUpdate
|
||||
from langflow.services.database.models.message.model import MessageTable
|
||||
from langflow.services.deps import async_session_scope
|
||||
from langflow.services.deps import session_scope
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
async def created_message():
|
||||
async with async_session_scope() as session:
|
||||
async with session_scope() as session:
|
||||
message = MessageCreate(text="Test message", sender="User", sender_name="User", session_id="session_id")
|
||||
messagetable = MessageTable.model_validate(message, from_attributes=True)
|
||||
messagetables = await aadd_messagetables([messagetable], session)
|
||||
|
|
@ -22,7 +22,7 @@ async def created_message():
|
|||
|
||||
@pytest.fixture
|
||||
async def created_messages(session): # noqa: ARG001
|
||||
async with async_session_scope() as _session:
|
||||
async with session_scope() as _session:
|
||||
messages = [
|
||||
MessageCreate(text="Test message 1", sender="User", sender_name="User", session_id="session_id2"),
|
||||
MessageCreate(text="Test message 2", sender="User", sender_name="User", session_id="session_id2"),
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue