🐛 fix(test_database.py): remove unused imports to improve code readability
✨ feat(test_database.py): add support for session management using session_getter to improve code organization and maintainability 🐛 fix(test_endpoints.py): remove unused imports to improve code readability ✨ feat(test_endpoints.py): add support for session management using session_getter to improve code organization and maintainability 🐛 fix(test_login.py): remove unused imports to improve code readability ✨ feat(test_login.py): add support for session management using session_getter to improve code organization and maintainability
This commit is contained in:
parent
b441d42559
commit
20e14d49b4
3 changed files with 26 additions and 19 deletions
|
|
@ -1,4 +1,6 @@
|
|||
from langflow.services.database.models.base import orjson_dumps
|
||||
from langflow.services.database.utils import session_getter
|
||||
from langflow.services.getters import get_db_manager
|
||||
import orjson
|
||||
import pytest
|
||||
|
||||
|
|
@ -178,9 +180,7 @@ def test_upload_file(
|
|||
assert response_data[1]["data"] == data
|
||||
|
||||
|
||||
def test_download_file(
|
||||
client: TestClient, session: Session, json_flow, active_user, logged_in_headers
|
||||
):
|
||||
def test_download_file(client: TestClient, json_flow, active_user, logged_in_headers):
|
||||
flow = orjson.loads(json_flow)
|
||||
data = flow["data"]
|
||||
# Create test data
|
||||
|
|
@ -190,18 +190,20 @@ def test_download_file(
|
|||
FlowCreate(name="Flow 2", description="description", data=data),
|
||||
]
|
||||
)
|
||||
for flow in flow_list.flows:
|
||||
flow.user_id = active_user.id
|
||||
db_flow = Flow.from_orm(flow)
|
||||
session.add(db_flow)
|
||||
session.commit()
|
||||
db_manager = get_db_manager()
|
||||
with session_getter(db_manager) as session:
|
||||
for flow in flow_list.flows:
|
||||
flow.user_id = active_user.id
|
||||
db_flow = Flow.from_orm(flow)
|
||||
session.add(db_flow)
|
||||
session.commit()
|
||||
# Make request to endpoint
|
||||
response = client.get("api/v1/flows/download/", headers=logged_in_headers)
|
||||
# Check response status code
|
||||
assert response.status_code == 200
|
||||
assert response.status_code == 200, response.json()
|
||||
# Check response data
|
||||
response_data = response.json()["flows"]
|
||||
assert len(response_data) == 2
|
||||
assert len(response_data) == 2, response_data
|
||||
assert response_data[0]["name"] == "Flow 1"
|
||||
assert response_data[0]["description"] == "description"
|
||||
assert response_data[0]["data"] == data
|
||||
|
|
|
|||
|
|
@ -1,7 +1,8 @@
|
|||
import uuid
|
||||
from langflow.services.auth.utils import get_password_hash
|
||||
from langflow.services.database.models.api_key.api_key import ApiKey
|
||||
from langflow.services.getters import get_settings_manager
|
||||
from langflow.services.database.utils import session_getter
|
||||
from langflow.services.getters import get_db_manager, get_settings_manager
|
||||
import pytest
|
||||
from fastapi.testclient import TestClient
|
||||
from langflow.interface.tools.constants import CUSTOM_TOOLS
|
||||
|
|
@ -88,7 +89,7 @@ PROMPT_REQUEST = {
|
|||
|
||||
|
||||
@pytest.fixture
|
||||
def created_api_key(session, active_user):
|
||||
def created_api_key(active_user):
|
||||
hashed = get_password_hash("random_key")
|
||||
api_key = ApiKey(
|
||||
name="test_api_key",
|
||||
|
|
@ -96,10 +97,11 @@ def created_api_key(session, active_user):
|
|||
api_key="random_key",
|
||||
hashed_api_key=hashed,
|
||||
)
|
||||
|
||||
session.add(api_key)
|
||||
session.commit()
|
||||
session.refresh(api_key)
|
||||
db_manager = get_db_manager()
|
||||
with session_getter(db_manager) as session:
|
||||
session.add(api_key)
|
||||
session.commit()
|
||||
session.refresh(api_key)
|
||||
return api_key
|
||||
|
||||
|
||||
|
|
|
|||
|
|
@ -1,3 +1,5 @@
|
|||
from langflow.services.database.utils import session_getter
|
||||
from langflow.services.getters import get_db_manager
|
||||
import pytest
|
||||
from langflow.services.database.models.user import User
|
||||
from langflow.services.auth.utils import get_password_hash
|
||||
|
|
@ -15,10 +17,11 @@ def test_user():
|
|||
)
|
||||
|
||||
|
||||
def test_login_successful(client, test_user, session):
|
||||
def test_login_successful(client, test_user):
|
||||
# Adding the test user to the database
|
||||
session.add(test_user)
|
||||
session.commit()
|
||||
with session_getter(get_db_manager()) as session:
|
||||
session.add(test_user)
|
||||
session.commit()
|
||||
|
||||
response = client.post(
|
||||
"api/v1/login", data={"username": "testuser", "password": "testpassword"}
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue