From 20e14d49b45e66267c236706f40250cdee6a6bc1 Mon Sep 17 00:00:00 2001 From: Gabriel Luiz Freitas Almeida Date: Wed, 20 Sep 2023 18:40:19 -0300 Subject: [PATCH] =?UTF-8?q?=F0=9F=90=9B=20fix(test=5Fdatabase.py):=20remov?= =?UTF-8?q?e=20unused=20imports=20to=20improve=20code=20readability=20?= =?UTF-8?q?=E2=9C=A8=20feat(test=5Fdatabase.py):=20add=20support=20for=20s?= =?UTF-8?q?ession=20management=20using=20session=5Fgetter=20to=20improve?= =?UTF-8?q?=20code=20organization=20and=20maintainability=20=F0=9F=90=9B?= =?UTF-8?q?=20fix(test=5Fendpoints.py):=20remove=20unused=20imports=20to?= =?UTF-8?q?=20improve=20code=20readability=20=E2=9C=A8=20feat(test=5Fendpo?= =?UTF-8?q?ints.py):=20add=20support=20for=20session=20management=20using?= =?UTF-8?q?=20session=5Fgetter=20to=20improve=20code=20organization=20and?= =?UTF-8?q?=20maintainability=20=F0=9F=90=9B=20fix(test=5Flogin.py):=20rem?= =?UTF-8?q?ove=20unused=20imports=20to=20improve=20code=20readability=20?= =?UTF-8?q?=E2=9C=A8=20feat(test=5Flogin.py):=20add=20support=20for=20sess?= =?UTF-8?q?ion=20management=20using=20session=5Fgetter=20to=20improve=20co?= =?UTF-8?q?de=20organization=20and=20maintainability?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- tests/test_database.py | 22 ++++++++++++---------- tests/test_endpoints.py | 14 ++++++++------ tests/test_login.py | 9 ++++++--- 3 files changed, 26 insertions(+), 19 deletions(-) diff --git a/tests/test_database.py b/tests/test_database.py index e4f68ca56..7641f1e65 100644 --- a/tests/test_database.py +++ b/tests/test_database.py @@ -1,4 +1,6 @@ from langflow.services.database.models.base import orjson_dumps +from langflow.services.database.utils import session_getter +from langflow.services.getters import get_db_manager import orjson import pytest @@ -178,9 +180,7 @@ def test_upload_file( assert response_data[1]["data"] == data -def test_download_file( - client: TestClient, session: Session, json_flow, active_user, logged_in_headers -): +def test_download_file(client: TestClient, json_flow, active_user, logged_in_headers): flow = orjson.loads(json_flow) data = flow["data"] # Create test data @@ -190,18 +190,20 @@ def test_download_file( FlowCreate(name="Flow 2", description="description", data=data), ] ) - for flow in flow_list.flows: - flow.user_id = active_user.id - db_flow = Flow.from_orm(flow) - session.add(db_flow) - session.commit() + db_manager = get_db_manager() + with session_getter(db_manager) as session: + for flow in flow_list.flows: + flow.user_id = active_user.id + db_flow = Flow.from_orm(flow) + session.add(db_flow) + session.commit() # Make request to endpoint response = client.get("api/v1/flows/download/", headers=logged_in_headers) # Check response status code - assert response.status_code == 200 + assert response.status_code == 200, response.json() # Check response data response_data = response.json()["flows"] - assert len(response_data) == 2 + assert len(response_data) == 2, response_data assert response_data[0]["name"] == "Flow 1" assert response_data[0]["description"] == "description" assert response_data[0]["data"] == data diff --git a/tests/test_endpoints.py b/tests/test_endpoints.py index 2b706ba31..474a72e31 100644 --- a/tests/test_endpoints.py +++ b/tests/test_endpoints.py @@ -1,7 +1,8 @@ import uuid from langflow.services.auth.utils import get_password_hash from langflow.services.database.models.api_key.api_key import ApiKey -from langflow.services.getters import get_settings_manager +from langflow.services.database.utils import session_getter +from langflow.services.getters import get_db_manager, get_settings_manager import pytest from fastapi.testclient import TestClient from langflow.interface.tools.constants import CUSTOM_TOOLS @@ -88,7 +89,7 @@ PROMPT_REQUEST = { @pytest.fixture -def created_api_key(session, active_user): +def created_api_key(active_user): hashed = get_password_hash("random_key") api_key = ApiKey( name="test_api_key", @@ -96,10 +97,11 @@ def created_api_key(session, active_user): api_key="random_key", hashed_api_key=hashed, ) - - session.add(api_key) - session.commit() - session.refresh(api_key) + db_manager = get_db_manager() + with session_getter(db_manager) as session: + session.add(api_key) + session.commit() + session.refresh(api_key) return api_key diff --git a/tests/test_login.py b/tests/test_login.py index 07abb35ab..651e2264b 100644 --- a/tests/test_login.py +++ b/tests/test_login.py @@ -1,3 +1,5 @@ +from langflow.services.database.utils import session_getter +from langflow.services.getters import get_db_manager import pytest from langflow.services.database.models.user import User from langflow.services.auth.utils import get_password_hash @@ -15,10 +17,11 @@ def test_user(): ) -def test_login_successful(client, test_user, session): +def test_login_successful(client, test_user): # Adding the test user to the database - session.add(test_user) - session.commit() + with session_getter(get_db_manager()) as session: + session.add(test_user) + session.commit() response = client.post( "api/v1/login", data={"username": "testuser", "password": "testpassword"}