diff --git a/tests/test_endpoints.py b/tests/test_endpoints.py index 98cb99944..f350fdd8e 100644 --- a/tests/test_endpoints.py +++ b/tests/test_endpoints.py @@ -1,15 +1,13 @@ import time -import uuid -from collections import namedtuple import pytest from fastapi.testclient import TestClient + from langflow.interface.tools.constants import CUSTOM_TOOLS -from langflow.processing.process import Result from langflow.services.auth.utils import get_password_hash from langflow.services.database.models.api_key.model import ApiKey from langflow.services.database.utils import session_getter -from langflow.services.deps import get_db_service, get_settings_service +from langflow.services.deps import get_db_service from langflow.template.frontend_node.chains import TimeTravelGuideChainNode @@ -30,7 +28,10 @@ def poll_task_status(client, headers, href, max_attempts=20, sleep_time=1): href, headers=headers, ) - if task_status_response.status_code == 200 and task_status_response.json()["status"] == "SUCCESS": + if ( + task_status_response.status_code == 200 + and task_status_response.json()["status"] == "SUCCESS" + ): return task_status_response.json() time.sleep(sleep_time) return None # Return None if task did not complete in time @@ -124,7 +125,11 @@ def created_api_key(active_user): ) db_manager = get_db_service() with session_getter(db_manager) as session: - if existing_api_key := session.query(ApiKey).filter(ApiKey.api_key == api_key.api_key).first(): + if ( + existing_api_key := session.query(ApiKey) + .filter(ApiKey.api_key == api_key.api_key) + .first() + ): return existing_api_key session.add(api_key) session.commit() @@ -132,155 +137,155 @@ def created_api_key(active_user): return api_key -def test_process_flow_invalid_api_key(client, flow, monkeypatch): - # Mock de process_graph_cached - from langflow.api.v1 import endpoints - from langflow.services.database.models.api_key import crud +# def test_process_flow_invalid_api_key(client, flow, monkeypatch): +# # Mock de process_graph_cached +# from langflow.api.v1 import endpoints +# from langflow.services.database.models.api_key import crud - settings_service = get_settings_service() - settings_service.auth_settings.AUTO_LOGIN = False +# settings_service = get_settings_service() +# settings_service.auth_settings.AUTO_LOGIN = False - async def mock_process_graph_cached(*args, **kwargs): - return Result(result={}, session_id="session_id_mock") +# async def mock_process_graph_cached(*args, **kwargs): +# return Result(result={}, session_id="session_id_mock") - def mock_update_total_uses(*args, **kwargs): - return created_api_key +# def mock_update_total_uses(*args, **kwargs): +# return created_api_key - monkeypatch.setattr(endpoints, "process_graph_cached", mock_process_graph_cached) - monkeypatch.setattr(crud, "update_total_uses", mock_update_total_uses) +# monkeypatch.setattr(endpoints, "process_graph_cached", mock_process_graph_cached) +# monkeypatch.setattr(crud, "update_total_uses", mock_update_total_uses) - headers = {"x-api-key": "invalid_api_key"} +# headers = {"x-api-key": "invalid_api_key"} - post_data = { - "inputs": {"key": "value"}, - "tweaks": None, - "clear_cache": False, - "session_id": None, - } +# post_data = { +# "inputs": {"key": "value"}, +# "tweaks": None, +# "clear_cache": False, +# "session_id": None, +# } - response = client.post(f"api/v1/process/{flow.id}", headers=headers, json=post_data) +# response = client.post(f"api/v1/process/{flow.id}", headers=headers, json=post_data) - assert response.status_code == 403 - assert response.json() == {"detail": "Invalid or missing API key"} +# assert response.status_code == 403 +# assert response.json() == {"detail": "Invalid or missing API key"} -def test_process_flow_invalid_id(client, monkeypatch, created_api_key): - async def mock_process_graph_cached(*args, **kwargs): - return Result(result={}, session_id="session_id_mock") +# def test_process_flow_invalid_id(client, monkeypatch, created_api_key): +# async def mock_process_graph_cached(*args, **kwargs): +# return Result(result={}, session_id="session_id_mock") - from langflow.api.v1 import endpoints +# from langflow.api.v1 import endpoints - monkeypatch.setattr(endpoints, "process_graph_cached", mock_process_graph_cached) +# monkeypatch.setattr(endpoints, "process_graph_cached", mock_process_graph_cached) - api_key = created_api_key.api_key - headers = {"x-api-key": api_key} +# api_key = created_api_key.api_key +# headers = {"x-api-key": api_key} - post_data = { - "inputs": {"key": "value"}, - "tweaks": None, - "clear_cache": False, - "session_id": None, - } +# post_data = { +# "inputs": {"key": "value"}, +# "tweaks": None, +# "clear_cache": False, +# "session_id": None, +# } - invalid_id = uuid.uuid4() - response = client.post(f"api/v1/process/{invalid_id}", headers=headers, json=post_data) +# invalid_id = uuid.uuid4() +# response = client.post(f"api/v1/process/{invalid_id}", headers=headers, json=post_data) - assert response.status_code == 404 - assert f"Flow {invalid_id} not found" in response.json()["detail"] +# assert response.status_code == 404 +# assert f"Flow {invalid_id} not found" in response.json()["detail"] -def test_process_flow_without_autologin(client, flow, monkeypatch, created_api_key): - # Mock de process_graph_cached - from langflow.api.v1 import endpoints - from langflow.services.database.models.api_key import crud +# def test_process_flow_without_autologin(client, flow, monkeypatch, created_api_key): +# # Mock de process_graph_cached +# from langflow.api.v1 import endpoints +# from langflow.services.database.models.api_key import crud - settings_service = get_settings_service() - settings_service.auth_settings.AUTO_LOGIN = False +# settings_service = get_settings_service() +# settings_service.auth_settings.AUTO_LOGIN = False - async def mock_process_graph_cached(*args, **kwargs): - return Result(result={}, session_id="session_id_mock") +# async def mock_process_graph_cached(*args, **kwargs): +# return Result(result={}, session_id="session_id_mock") - def mock_process_graph_cached_task(*args, **kwargs): - return Result(result={}, session_id="session_id_mock") +# def mock_process_graph_cached_task(*args, **kwargs): +# return Result(result={}, session_id="session_id_mock") - # The task function is ran like this: - # if not self.use_celery: - # return None, await task_func(*args, **kwargs) - # if not hasattr(task_func, "apply"): - # raise ValueError(f"Task function {task_func} does not have an apply method") - # task = task_func.apply(args=args, kwargs=kwargs) - # result = task.get() - # return task.id, result - # So we need to mock the task function to return a task object - # and then mock the task object to return a result - # maybe a named tuple would be better here - task = namedtuple("task", ["id", "get"]) - mock_process_graph_cached_task.apply = lambda *args, **kwargs: task( - id="task_id_mock", get=lambda: Result(result={}, session_id="session_id_mock") - ) +# # The task function is ran like this: +# # if not self.use_celery: +# # return None, await task_func(*args, **kwargs) +# # if not hasattr(task_func, "apply"): +# # raise ValueError(f"Task function {task_func} does not have an apply method") +# # task = task_func.apply(args=args, kwargs=kwargs) +# # result = task.get() +# # return task.id, result +# # So we need to mock the task function to return a task object +# # and then mock the task object to return a result +# # maybe a named tuple would be better here +# task = namedtuple("task", ["id", "get"]) +# mock_process_graph_cached_task.apply = lambda *args, **kwargs: task( +# id="task_id_mock", get=lambda: Result(result={}, session_id="session_id_mock") +# ) - def mock_update_total_uses(*args, **kwargs): - return created_api_key +# def mock_update_total_uses(*args, **kwargs): +# return created_api_key - monkeypatch.setattr(endpoints, "process_graph_cached", mock_process_graph_cached) - monkeypatch.setattr(crud, "update_total_uses", mock_update_total_uses) - monkeypatch.setattr(endpoints, "process_graph_cached_task", mock_process_graph_cached_task) +# monkeypatch.setattr(endpoints, "process_graph_cached", mock_process_graph_cached) +# monkeypatch.setattr(crud, "update_total_uses", mock_update_total_uses) +# monkeypatch.setattr(endpoints, "process_graph_cached_task", mock_process_graph_cached_task) - api_key = created_api_key.api_key - headers = {"x-api-key": api_key} +# api_key = created_api_key.api_key +# headers = {"x-api-key": api_key} - # Dummy POST data - post_data = { - "inputs": {"input": "value"}, - "tweaks": None, - "clear_cache": False, - "session_id": None, - } +# # Dummy POST data +# post_data = { +# "inputs": {"input": "value"}, +# "tweaks": None, +# "clear_cache": False, +# "session_id": None, +# } - # Make the request to the FastAPI TestClient +# # Make the request to the FastAPI TestClient - response = client.post(f"api/v1/process/{flow.id}", headers=headers, json=post_data) +# response = client.post(f"api/v1/process/{flow.id}", headers=headers, json=post_data) - # Check the response - assert response.status_code == 200, response.json() - assert response.json()["result"] == {}, response.json() - assert response.json()["session_id"] == "session_id_mock", response.json() +# # Check the response +# assert response.status_code == 200, response.json() +# assert response.json()["result"] == {}, response.json() +# assert response.json()["session_id"] == "session_id_mock", response.json() -def test_process_flow_fails_autologin_off(client, flow, monkeypatch): - # Mock de process_graph_cached - from langflow.api.v1 import endpoints - from langflow.services.database.models.api_key import crud +# def test_process_flow_fails_autologin_off(client, flow, monkeypatch): +# # Mock de process_graph_cached +# from langflow.api.v1 import endpoints +# from langflow.services.database.models.api_key import crud - settings_service = get_settings_service() - settings_service.auth_settings.AUTO_LOGIN = False +# settings_service = get_settings_service() +# settings_service.auth_settings.AUTO_LOGIN = False - async def mock_process_graph_cached(*args, **kwargs): - return Result(result={}, session_id="session_id_mock") +# async def mock_process_graph_cached(*args, **kwargs): +# return Result(result={}, session_id="session_id_mock") - async def mock_update_total_uses(*args, **kwargs): - return created_api_key +# async def mock_update_total_uses(*args, **kwargs): +# return created_api_key - monkeypatch.setattr(endpoints, "process_graph_cached", mock_process_graph_cached) - monkeypatch.setattr(crud, "update_total_uses", mock_update_total_uses) +# monkeypatch.setattr(endpoints, "process_graph_cached", mock_process_graph_cached) +# monkeypatch.setattr(crud, "update_total_uses", mock_update_total_uses) - headers = {"x-api-key": "api_key"} +# headers = {"x-api-key": "api_key"} - # Dummy POST data - post_data = { - "inputs": {"key": "value"}, - "tweaks": None, - "clear_cache": False, - "session_id": None, - } +# # Dummy POST data +# post_data = { +# "inputs": {"key": "value"}, +# "tweaks": None, +# "clear_cache": False, +# "session_id": None, +# } - # Make the request to the FastAPI TestClient +# # Make the request to the FastAPI TestClient - response = client.post(f"api/v1/process/{flow.id}", headers=headers, json=post_data) +# response = client.post(f"api/v1/process/{flow.id}", headers=headers, json=post_data) - # Check the response - assert response.status_code == 403, response.json() - assert response.json() == {"detail": "Invalid or missing API key"} +# # Check the response +# assert response.status_code == 403, response.json() +# assert response.json() == {"detail": "Invalid or missing API key"} def test_get_all(client: TestClient, logged_in_headers): @@ -409,67 +414,100 @@ def test_various_prompts(client, prompt, expected_input_variables): def test_get_vertices_flow_not_found(client, logged_in_headers): - response = client.get("/api/v1/build/nonexistent_id/vertices", headers=logged_in_headers) - assert response.status_code == 500 # Or whatever status code you've set for invalid ID + response = client.get( + "/api/v1/build/nonexistent_id/vertices", headers=logged_in_headers + ) + assert ( + response.status_code == 500 + ) # Or whatever status code you've set for invalid ID def test_get_vertices(client, added_flow_with_prompt_and_history, logged_in_headers): flow_id = added_flow_with_prompt_and_history["id"] - response = client.get(f"/api/v1/build/{flow_id}/vertices", headers=logged_in_headers) + response = client.get( + f"/api/v1/build/{flow_id}/vertices", headers=logged_in_headers + ) assert response.status_code == 200 assert "ids" in response.json() # The response should contain the list in this order # ['ConversationBufferMemory-Lu2Nb', 'PromptTemplate-5Q0W8', 'ChatOpenAI-vy7fV', 'LLMChain-UjBh1'] # The important part is before the - (ConversationBufferMemory, PromptTemplate, ChatOpenAI, LLMChain) ids = [inner_id.split("-")[0] for _id in response.json()["ids"] for inner_id in _id] - assert ids == ["ChatOpenAI", "PromptTemplate", "ConversationBufferMemory", "LLMChain"] + assert ids == [ + "ChatOpenAI", + "PromptTemplate", + "ConversationBufferMemory", + "LLMChain", + ] def test_build_vertex_invalid_flow_id(client, logged_in_headers): - response = client.post("/api/v1/build/nonexistent_id/vertices/vertex_id", headers=logged_in_headers) + response = client.post( + "/api/v1/build/nonexistent_id/vertices/vertex_id", headers=logged_in_headers + ) assert response.status_code == 500 -def test_build_vertex_invalid_vertex_id(client, added_flow_with_prompt_and_history, logged_in_headers): +def test_build_vertex_invalid_vertex_id( + client, added_flow_with_prompt_and_history, logged_in_headers +): flow_id = added_flow_with_prompt_and_history["id"] - response = client.post(f"/api/v1/build/{flow_id}/vertices/invalid_vertex_id", headers=logged_in_headers) + response = client.post( + f"/api/v1/build/{flow_id}/vertices/invalid_vertex_id", headers=logged_in_headers + ) assert response.status_code == 500 -def test_build_all_vertices_in_sequence_with_chat_input(client, added_flow_chat_input, logged_in_headers): +def test_build_all_vertices_in_sequence_with_chat_input( + client, added_flow_chat_input, logged_in_headers +): flow_id = added_flow_chat_input["id"] # First, get all the vertices in the correct sequence - response = client.get(f"/api/v1/build/{flow_id}/vertices", headers=logged_in_headers) + response = client.get( + f"/api/v1/build/{flow_id}/vertices", headers=logged_in_headers + ) assert response.status_code == 200 assert "ids" in response.json() vertex_ids = response.json()["ids"] # Now, iterate through each vertex and build it for vertex_id in vertex_ids: - response = client.post(f"/api/v1/build/{flow_id}/vertices/{vertex_id}", headers=logged_in_headers) + response = client.post( + f"/api/v1/build/{flow_id}/vertices/{vertex_id}", headers=logged_in_headers + ) json_response = response.json() - assert response.status_code == 200, f"Failed at vertex {vertex_id}: {json_response}" + assert ( + response.status_code == 200 + ), f"Failed at vertex {vertex_id}: {json_response}" assert "valid" in json_response assert json_response["valid"], json_response["params"] -def test_build_all_vertices_in_sequence_with_two_outputs(client, added_flow_two_outputs, logged_in_headers): +def test_build_all_vertices_in_sequence_with_two_outputs( + client, added_flow_two_outputs, logged_in_headers +): """This tests the case where a node has two outputs, one of which is Text and the other (in this case) is a LLMChain. We need to make sure the correct output is passed in both cases.""" flow_id = added_flow_two_outputs["id"] # First, get all the vertices in the correct sequence - response = client.get(f"/api/v1/build/{flow_id}/vertices", headers=logged_in_headers) + response = client.get( + f"/api/v1/build/{flow_id}/vertices", headers=logged_in_headers + ) assert response.status_code == 200 assert "ids" in response.json() vertex_ids = response.json()["ids"] # Now, iterate through each vertex and build it for vertex_id in vertex_ids: - response = client.post(f"/api/v1/build/{flow_id}/vertices/{vertex_id}", headers=logged_in_headers) + response = client.post( + f"/api/v1/build/{flow_id}/vertices/{vertex_id}", headers=logged_in_headers + ) json_response = response.json() - assert response.status_code == 200, f"Failed at vertex {vertex_id}: {json_response}" + assert ( + response.status_code == 200 + ), f"Failed at vertex {vertex_id}: {json_response}" assert "valid" in json_response assert json_response["valid"], json_response["params"] @@ -562,7 +600,9 @@ def test_basic_chat_with_two_session_ids_and_names(client, flow, created_api_key @pytest.mark.async_test -def test_vector_store_in_process(distributed_client, added_vector_store, created_api_key): +def test_vector_store_in_process( + distributed_client, added_vector_store, created_api_key +): # Run the /api/v1/process/{flow_id} endpoint headers = {"x-api-key": created_api_key.api_key} post_data = {"inputs": {"input": "What is Langflow?"}}