ref: Use AsyncSession in some tests (#5151)
This commit is contained in:
parent
f97a326023
commit
624a2dde5d
7 changed files with 53 additions and 47 deletions
|
|
@ -74,7 +74,7 @@ def session_getter(db_service: DatabaseService):
|
|||
@asynccontextmanager
|
||||
async def async_session_getter(db_service: DatabaseService):
|
||||
try:
|
||||
session = AsyncSession(db_service.async_engine)
|
||||
session = AsyncSession(db_service.async_engine, expire_on_commit=False)
|
||||
yield session
|
||||
except Exception:
|
||||
logger.exception("Session rollback because of exception")
|
||||
|
|
|
|||
|
|
@ -26,7 +26,7 @@ from langflow.services.database.models.folder.model import Folder
|
|||
from langflow.services.database.models.transactions.model import TransactionTable
|
||||
from langflow.services.database.models.user.model import User, UserCreate, UserRead
|
||||
from langflow.services.database.models.vertex_builds.crud import delete_vertex_builds_by_flow_id
|
||||
from langflow.services.database.utils import session_getter
|
||||
from langflow.services.database.utils import async_session_getter
|
||||
from langflow.services.deps import get_db_service
|
||||
from loguru import logger
|
||||
from sqlalchemy.ext.asyncio import create_async_engine
|
||||
|
|
@ -157,7 +157,7 @@ async def async_session():
|
|||
engine = create_async_engine("sqlite+aiosqlite://", connect_args={"check_same_thread": False}, poolclass=StaticPool)
|
||||
async with engine.begin() as conn:
|
||||
await conn.run_sync(SQLModel.metadata.create_all)
|
||||
async with AsyncSession(engine) as session:
|
||||
async with AsyncSession(engine, expire_on_commit=False) as session:
|
||||
yield session
|
||||
async with engine.begin() as conn:
|
||||
await conn.run_sync(SQLModel.metadata.drop_all)
|
||||
|
|
@ -469,7 +469,7 @@ async def logged_in_headers_super_user(client, active_super_user):
|
|||
|
||||
|
||||
@pytest.fixture
|
||||
def flow(
|
||||
async def flow(
|
||||
client, # noqa: ARG001
|
||||
json_flow: str,
|
||||
active_user,
|
||||
|
|
@ -480,14 +480,14 @@ def flow(
|
|||
flow_data = FlowCreate(name="test_flow", data=loaded_json.get("data"), user_id=active_user.id)
|
||||
|
||||
flow = Flow.model_validate(flow_data)
|
||||
with session_getter(get_db_service()) as session:
|
||||
async with async_session_getter(get_db_service()) as session:
|
||||
session.add(flow)
|
||||
session.commit()
|
||||
session.refresh(flow)
|
||||
await session.commit()
|
||||
await session.refresh(flow)
|
||||
yield flow
|
||||
# Clean up
|
||||
session.delete(flow)
|
||||
session.commit()
|
||||
await session.delete(flow)
|
||||
await session.commit()
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
|
|
@ -582,7 +582,7 @@ async def flow_component(client: AsyncClient, logged_in_headers):
|
|||
|
||||
|
||||
@pytest.fixture
|
||||
def created_api_key(active_user):
|
||||
async def created_api_key(active_user):
|
||||
hashed = get_password_hash("random_key")
|
||||
api_key = ApiKey(
|
||||
name="test_api_key",
|
||||
|
|
@ -591,17 +591,18 @@ def created_api_key(active_user):
|
|||
hashed_api_key=hashed,
|
||||
)
|
||||
db_manager = get_db_service()
|
||||
with session_getter(db_manager) as session:
|
||||
if existing_api_key := session.exec(select(ApiKey).where(ApiKey.api_key == api_key.api_key)).first():
|
||||
async with async_session_getter(db_manager) as session:
|
||||
stmt = select(ApiKey).where(ApiKey.api_key == api_key.api_key)
|
||||
if existing_api_key := (await session.exec(stmt)).first():
|
||||
yield existing_api_key
|
||||
return
|
||||
session.add(api_key)
|
||||
session.commit()
|
||||
session.refresh(api_key)
|
||||
await session.commit()
|
||||
await session.refresh(api_key)
|
||||
yield api_key
|
||||
# Clean up
|
||||
session.delete(api_key)
|
||||
session.commit()
|
||||
await session.delete(api_key)
|
||||
await session.commit()
|
||||
|
||||
|
||||
@pytest.fixture(name="simple_api_test")
|
||||
|
|
@ -618,14 +619,15 @@ async def get_simple_api_test(client, logged_in_headers, json_simple_api_test):
|
|||
|
||||
|
||||
@pytest.fixture(name="starter_project")
|
||||
def get_starter_project(active_user):
|
||||
async def get_starter_project(active_user):
|
||||
# once the client is created, we can get the starter project
|
||||
with session_getter(get_db_service()) as session:
|
||||
flow = session.exec(
|
||||
async with async_session_getter(get_db_service()) as session:
|
||||
stmt = (
|
||||
select(Flow)
|
||||
.where(Flow.folder.has(Folder.name == STARTER_FOLDER_NAME))
|
||||
.where(Flow.name == "Basic Prompting (Hello, World)")
|
||||
).first()
|
||||
)
|
||||
flow = (await session.exec(stmt)).first()
|
||||
if not flow:
|
||||
msg = "No starter project found"
|
||||
raise ValueError(msg)
|
||||
|
|
@ -640,10 +642,10 @@ def get_starter_project(active_user):
|
|||
)
|
||||
new_flow = Flow.model_validate(new_flow_create, from_attributes=True)
|
||||
session.add(new_flow)
|
||||
session.commit()
|
||||
session.refresh(new_flow)
|
||||
await session.commit()
|
||||
await session.refresh(new_flow)
|
||||
new_flow_dict = new_flow.model_dump()
|
||||
yield new_flow_dict
|
||||
# Clean up
|
||||
session.delete(new_flow)
|
||||
session.commit()
|
||||
await session.delete(new_flow)
|
||||
await session.commit()
|
||||
|
|
|
|||
|
|
@ -24,7 +24,7 @@ async def session():
|
|||
engine = create_async_engine("sqlite+aiosqlite:///:memory:")
|
||||
async with engine.begin() as conn:
|
||||
await conn.run_sync(SQLModel.metadata.create_all)
|
||||
async with AsyncSession(engine) as session:
|
||||
async with AsyncSession(engine, expire_on_commit=False) as session:
|
||||
yield session
|
||||
|
||||
|
||||
|
|
|
|||
|
|
@ -12,7 +12,7 @@ from langflow.initial_setup.setup import load_starter_projects
|
|||
from langflow.services.database.models.base import orjson_dumps
|
||||
from langflow.services.database.models.flow import Flow, FlowCreate, FlowUpdate
|
||||
from langflow.services.database.models.folder.model import FolderCreate
|
||||
from langflow.services.database.utils import session_getter
|
||||
from langflow.services.database.utils import async_session_getter
|
||||
from langflow.services.deps import get_db_service
|
||||
|
||||
|
||||
|
|
@ -530,14 +530,14 @@ async def test_download_file(
|
|||
]
|
||||
)
|
||||
db_manager = get_db_service()
|
||||
with session_getter(db_manager) as _session:
|
||||
async with async_session_getter(db_manager) as _session:
|
||||
saved_flows = []
|
||||
for flow in flow_list.flows:
|
||||
flow.user_id = active_user.id
|
||||
db_flow = Flow.model_validate(flow, from_attributes=True)
|
||||
_session.add(db_flow)
|
||||
saved_flows.append(db_flow)
|
||||
_session.commit()
|
||||
await _session.commit()
|
||||
# Make request to endpoint inside the session context
|
||||
flow_ids = [str(db_flow.id) for db_flow in saved_flows] # Convert UUIDs to strings
|
||||
flow_ids_json = json.dumps(flow_ids)
|
||||
|
|
|
|||
|
|
@ -12,7 +12,8 @@ from langflow.initial_setup.setup import (
|
|||
)
|
||||
from langflow.interface.types import aget_all_types_dict
|
||||
from langflow.services.database.models.folder.model import Folder
|
||||
from langflow.services.deps import session_scope
|
||||
from langflow.services.deps import async_session_scope
|
||||
from sqlalchemy.orm import selectinload
|
||||
from sqlmodel import select
|
||||
|
||||
|
||||
|
|
@ -52,12 +53,13 @@ def test_get_project_data():
|
|||
|
||||
@pytest.mark.usefixtures("client")
|
||||
async def test_create_or_update_starter_projects():
|
||||
with session_scope() as session:
|
||||
async with async_session_scope() as session:
|
||||
# Get the number of projects returned by load_starter_projects
|
||||
num_projects = len(await asyncio.to_thread(load_starter_projects))
|
||||
|
||||
# Get the number of projects in the database
|
||||
folder = session.exec(select(Folder).where(Folder.name == STARTER_FOLDER_NAME)).first()
|
||||
stmt = select(Folder).options(selectinload(Folder.flows)).where(Folder.name == STARTER_FOLDER_NAME)
|
||||
folder = (await session.exec(stmt)).first()
|
||||
assert folder is not None
|
||||
num_db_projects = len(folder.flows)
|
||||
|
||||
|
|
|
|||
|
|
@ -1,7 +1,7 @@
|
|||
import pytest
|
||||
from langflow.services.auth.utils import get_password_hash
|
||||
from langflow.services.database.models.user import User
|
||||
from langflow.services.deps import session_scope
|
||||
from langflow.services.deps import async_session_scope
|
||||
from sqlalchemy.exc import IntegrityError
|
||||
|
||||
|
||||
|
|
@ -18,9 +18,9 @@ def test_user():
|
|||
async def test_login_successful(client, test_user):
|
||||
# Adding the test user to the database
|
||||
try:
|
||||
with session_scope() as session:
|
||||
async with async_session_scope() as session:
|
||||
session.add(test_user)
|
||||
session.commit()
|
||||
await session.commit()
|
||||
except IntegrityError:
|
||||
pass
|
||||
|
||||
|
|
|
|||
|
|
@ -5,7 +5,7 @@ from httpx import AsyncClient
|
|||
from langflow.services.auth.utils import create_super_user, get_password_hash
|
||||
from langflow.services.database.models.user import UserUpdate
|
||||
from langflow.services.database.models.user.model import User
|
||||
from langflow.services.database.utils import async_session_getter, session_getter
|
||||
from langflow.services.database.utils import async_session_getter
|
||||
from langflow.services.deps import get_db_service, get_settings_service
|
||||
from sqlmodel import select
|
||||
|
||||
|
|
@ -41,8 +41,8 @@ async def super_user_headers(
|
|||
|
||||
|
||||
@pytest.fixture
|
||||
def deactivated_user(client): # noqa: ARG001
|
||||
with session_getter(get_db_service()) as session:
|
||||
async def deactivated_user(client): # noqa: ARG001
|
||||
async with async_session_getter(get_db_service()) as session:
|
||||
user = User(
|
||||
username="deactivateduser",
|
||||
password=get_password_hash("testpassword"),
|
||||
|
|
@ -51,8 +51,8 @@ def deactivated_user(client): # noqa: ARG001
|
|||
last_login_at=datetime.now(tz=timezone.utc),
|
||||
)
|
||||
session.add(user)
|
||||
session.commit()
|
||||
session.refresh(user)
|
||||
await session.commit()
|
||||
await session.refresh(user)
|
||||
return user
|
||||
|
||||
|
||||
|
|
@ -61,15 +61,16 @@ async def test_user_waiting_for_approval(client):
|
|||
password = "testpassword" # noqa: S105
|
||||
|
||||
# Debug: Check if the user already exists
|
||||
with session_getter(get_db_service()) as session:
|
||||
existing_user = session.exec(select(User).where(User.username == username)).first()
|
||||
async with async_session_getter(get_db_service()) as session:
|
||||
stmt = select(User).where(User.username == username)
|
||||
existing_user = (await session.exec(stmt)).first()
|
||||
if existing_user:
|
||||
pytest.fail(
|
||||
f"User {username} already exists before the test. Database URL: {get_db_service().database_url}"
|
||||
)
|
||||
|
||||
# Create a user that is not active and has never logged in
|
||||
with session_getter(get_db_service()) as session:
|
||||
async with async_session_getter(get_db_service()) as session:
|
||||
user = User(
|
||||
username=username,
|
||||
password=get_password_hash(password),
|
||||
|
|
@ -77,7 +78,7 @@ async def test_user_waiting_for_approval(client):
|
|||
last_login_at=None,
|
||||
)
|
||||
session.add(user)
|
||||
session.commit()
|
||||
await session.commit()
|
||||
|
||||
login_data = {"username": "waitingforapproval", "password": "testpassword"}
|
||||
response = await client.post("api/v1/login", data=login_data)
|
||||
|
|
@ -85,8 +86,9 @@ async def test_user_waiting_for_approval(client):
|
|||
assert response.json()["detail"] == "Waiting for approval"
|
||||
|
||||
# Debug: Check if the user still exists after the test
|
||||
with session_getter(get_db_service()) as session:
|
||||
existing_user = session.exec(select(User).where(User.username == username)).first()
|
||||
async with async_session_getter(get_db_service()) as session:
|
||||
stmt = select(User).where(User.username == username)
|
||||
existing_user = (await session.exec(stmt)).first()
|
||||
if existing_user:
|
||||
pass
|
||||
else:
|
||||
|
|
@ -138,7 +140,7 @@ async def test_data_consistency_after_delete(client: AsyncClient, test_user, sup
|
|||
@pytest.mark.api_key_required
|
||||
async def test_inactive_user(client: AsyncClient):
|
||||
# Create a user that is not active and has a last_login_at value
|
||||
with session_getter(get_db_service()) as session:
|
||||
async with async_session_getter(get_db_service()) as session:
|
||||
user = User(
|
||||
username="inactiveuser",
|
||||
password=get_password_hash("testpassword"),
|
||||
|
|
@ -146,7 +148,7 @@ async def test_inactive_user(client: AsyncClient):
|
|||
last_login_at=datetime(2023, 1, 1, 0, 0, 0, tzinfo=timezone.utc),
|
||||
)
|
||||
session.add(user)
|
||||
session.commit()
|
||||
await session.commit()
|
||||
|
||||
login_data = {"username": "inactiveuser", "password": "testpassword"}
|
||||
response = await client.post("api/v1/login", data=login_data)
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue