diff --git a/tests/test_database.py b/tests/test_database.py index e4f68ca56..7641f1e65 100644 --- a/tests/test_database.py +++ b/tests/test_database.py @@ -1,4 +1,6 @@ from langflow.services.database.models.base import orjson_dumps +from langflow.services.database.utils import session_getter +from langflow.services.getters import get_db_manager import orjson import pytest @@ -178,9 +180,7 @@ def test_upload_file( assert response_data[1]["data"] == data -def test_download_file( - client: TestClient, session: Session, json_flow, active_user, logged_in_headers -): +def test_download_file(client: TestClient, json_flow, active_user, logged_in_headers): flow = orjson.loads(json_flow) data = flow["data"] # Create test data @@ -190,18 +190,20 @@ def test_download_file( FlowCreate(name="Flow 2", description="description", data=data), ] ) - for flow in flow_list.flows: - flow.user_id = active_user.id - db_flow = Flow.from_orm(flow) - session.add(db_flow) - session.commit() + db_manager = get_db_manager() + with session_getter(db_manager) as session: + for flow in flow_list.flows: + flow.user_id = active_user.id + db_flow = Flow.from_orm(flow) + session.add(db_flow) + session.commit() # Make request to endpoint response = client.get("api/v1/flows/download/", headers=logged_in_headers) # Check response status code - assert response.status_code == 200 + assert response.status_code == 200, response.json() # Check response data response_data = response.json()["flows"] - assert len(response_data) == 2 + assert len(response_data) == 2, response_data assert response_data[0]["name"] == "Flow 1" assert response_data[0]["description"] == "description" assert response_data[0]["data"] == data diff --git a/tests/test_endpoints.py b/tests/test_endpoints.py index 2b706ba31..474a72e31 100644 --- a/tests/test_endpoints.py +++ b/tests/test_endpoints.py @@ -1,7 +1,8 @@ import uuid from langflow.services.auth.utils import get_password_hash from langflow.services.database.models.api_key.api_key import ApiKey -from langflow.services.getters import get_settings_manager +from langflow.services.database.utils import session_getter +from langflow.services.getters import get_db_manager, get_settings_manager import pytest from fastapi.testclient import TestClient from langflow.interface.tools.constants import CUSTOM_TOOLS @@ -88,7 +89,7 @@ PROMPT_REQUEST = { @pytest.fixture -def created_api_key(session, active_user): +def created_api_key(active_user): hashed = get_password_hash("random_key") api_key = ApiKey( name="test_api_key", @@ -96,10 +97,11 @@ def created_api_key(session, active_user): api_key="random_key", hashed_api_key=hashed, ) - - session.add(api_key) - session.commit() - session.refresh(api_key) + db_manager = get_db_manager() + with session_getter(db_manager) as session: + session.add(api_key) + session.commit() + session.refresh(api_key) return api_key diff --git a/tests/test_login.py b/tests/test_login.py index 07abb35ab..651e2264b 100644 --- a/tests/test_login.py +++ b/tests/test_login.py @@ -1,3 +1,5 @@ +from langflow.services.database.utils import session_getter +from langflow.services.getters import get_db_manager import pytest from langflow.services.database.models.user import User from langflow.services.auth.utils import get_password_hash @@ -15,10 +17,11 @@ def test_user(): ) -def test_login_successful(client, test_user, session): +def test_login_successful(client, test_user): # Adding the test user to the database - session.add(test_user) - session.commit() + with session_getter(get_db_manager()) as session: + session.add(test_user) + session.commit() response = client.post( "api/v1/login", data={"username": "testuser", "password": "testpassword"}