diff --git a/tests/conftest.py b/tests/conftest.py index d0829d0c8..1d1fb9ac7 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -6,6 +6,7 @@ from langflow.api.v1.flows import get_session from langflow.graph.graph.base import Graph from langflow.services.auth.utils import get_password_hash +from langflow.services.database.models.flow.flow import Flow from langflow.services.database.models.user.user import User, UserCreate import pytest from fastapi.testclient import TestClient @@ -192,3 +193,18 @@ def logged_in_headers(client, active_user): tokens = response.json() a_token = tokens["access_token"] return {"Authorization": f"Bearer {a_token}"} + + +@pytest.fixture +def flow(client, json_flow: str, session, active_user): + from langflow.services.database.models.flow.flow import FlowCreate + + loaded_json = json.loads(json_flow) + flow_data = FlowCreate( + name="test_flow", data=loaded_json.get("data"), user_id=active_user.id + ) + flow = Flow(**flow_data.dict()) + session.add(flow) + session.commit() + + return flow diff --git a/tests/test_endpoints.py b/tests/test_endpoints.py index 22b10cddd..f8596f9c9 100644 --- a/tests/test_endpoints.py +++ b/tests/test_endpoints.py @@ -1,5 +1,6 @@ from langflow.services.auth.utils import get_password_hash from langflow.services.database.models.api_key.api_key import ApiKey +from langflow.services.utils import get_settings_manager import pytest from fastapi.testclient import TestClient from langflow.interface.tools.constants import CUSTOM_TOOLS @@ -86,27 +87,35 @@ PROMPT_REQUEST = { @pytest.fixture -def created_api_key(active_user): +def created_api_key(session, active_user): hashed = get_password_hash("random_key") - return ApiKey( + api_key = ApiKey( name="test_api_key", user_id=active_user.id, api_key="random_key", hashed_api_key=hashed, ) + session.add(api_key) + session.commit() + session.refresh(api_key) + return api_key -def test_process_flow(client, mocker, created_api_key): + +def test_process_flow_without_autologin(client, flow, monkeypatch, created_api_key): # Mock de process_graph_cached - mock_process_graph_cached = mocker.patch( - "langflow.processing.process.process_graph_cached", autospec=True - ) + from langflow.api.v1 import endpoints - # Defina o valor de retorno para o mock - mock_process_graph_cached.return_value = ("result_mock", "session_id_mock") + settings_manager = get_settings_manager() + settings_manager.auth_settings.AUTO_LOGIN = False + + def mock_process_graph_cached(*args, **kwargs): + return {}, "session_id_mock" + + monkeypatch.setattr(endpoints, "process_graph_cached", mock_process_graph_cached) api_key = created_api_key.api_key - headers = {"Authorization": f"Bearer {api_key}"} + headers = {"api-key": api_key} # Dummy POST data post_data = { @@ -117,15 +126,44 @@ def test_process_flow(client, mocker, created_api_key): } # Make the request to the FastAPI TestClient - response = client.post("api/v1/process/flow_test", headers=headers, json=post_data) + + response = client.post(f"api/v1/process/{flow.id}", headers=headers, json=post_data) # Check the response - assert response.status_code == 200 - assert response.json()["result"] == "result_mock" + assert response.status_code == 200, response.json() + assert response.json()["result"] == {} assert response.json()["session_id"] == "session_id_mock" - # Ensure mock was called once - mock_process_graph_cached.assert_called_once() + +def test_process_flow_fails_autologin_off(client, flow, monkeypatch): + # Mock de process_graph_cached + from langflow.api.v1 import endpoints + + settings_manager = get_settings_manager() + settings_manager.auth_settings.AUTO_LOGIN = False + + def mock_process_graph_cached(*args, **kwargs): + return {}, "session_id_mock" + + monkeypatch.setattr(endpoints, "process_graph_cached", mock_process_graph_cached) + + headers = {"api-key": "api_key"} + + # Dummy POST data + post_data = { + "inputs": {"key": "value"}, + "tweaks": None, + "clear_cache": False, + "session_id": None, + } + + # Make the request to the FastAPI TestClient + + response = client.post(f"api/v1/process/{flow.id}", headers=headers, json=post_data) + + # Check the response + assert response.status_code == 403, response.json() + assert response.json() == {"detail": "Invalid or missing API key"} def test_get_all(client: TestClient, logged_in_headers):