diff --git a/src/backend/base/langflow/api/v1/validate.py b/src/backend/base/langflow/api/v1/validate.py index 06617152d..1bc8219ab 100644 --- a/src/backend/base/langflow/api/v1/validate.py +++ b/src/backend/base/langflow/api/v1/validate.py @@ -1,6 +1,7 @@ from fastapi import APIRouter, HTTPException from loguru import logger +from langflow.api.utils import CurrentActiveUser from langflow.api.v1.base import Code, CodeValidationResponse, PromptValidationResponse, ValidatePromptRequest from langflow.base.prompts.api_utils import process_prompt_template from langflow.utils.validate import validate_code @@ -10,7 +11,7 @@ router = APIRouter(prefix="/validate", tags=["Validate"]) @router.post("/code", status_code=200) -async def post_validate_code(code: Code) -> CodeValidationResponse: +async def post_validate_code(code: Code, _current_user: CurrentActiveUser) -> CodeValidationResponse: try: errors = validate_code(code.code) return CodeValidationResponse( diff --git a/src/backend/tests/unit/api/v1/test_validate.py b/src/backend/tests/unit/api/v1/test_validate.py index 752d381b7..3957f65d1 100644 --- a/src/backend/tests/unit/api/v1/test_validate.py +++ b/src/backend/tests/unit/api/v1/test_validate.py @@ -1,14 +1,16 @@ +import pytest from fastapi import status from httpx import AsyncClient -async def test_post_validate_code(client: AsyncClient): +@pytest.mark.usefixtures("active_user") +async def test_post_validate_code(client: AsyncClient, logged_in_headers): good_code = """ from pprint import pprint var = {"a": 1, "b": 2} pprint(var) """ - response = await client.post("api/v1/validate/code", json={"code": good_code}) + response = await client.post("api/v1/validate/code", json={"code": good_code}, headers=logged_in_headers) result = response.json() assert response.status_code == status.HTTP_200_OK @@ -17,7 +19,8 @@ pprint(var) assert "function" in result, "The result must have a 'function' key" -async def test_post_validate_prompt(client: AsyncClient): +@pytest.mark.usefixtures("active_user") +async def test_post_validate_prompt(client: AsyncClient, logged_in_headers): basic_case = { "name": "string", "template": "string", @@ -48,10 +51,29 @@ async def test_post_validate_prompt(client: AsyncClient): "metadata": {}, }, } - response = await client.post("api/v1/validate/prompt", json=basic_case) + response = await client.post("api/v1/validate/prompt", json=basic_case, headers=logged_in_headers) result = response.json() assert response.status_code == status.HTTP_200_OK assert isinstance(result, dict), "The result must be a dictionary" assert "frontend_node" in result, "The result must have a 'frontend_node' key" assert "input_variables" in result, "The result must have an 'input_variables' key" + + +@pytest.mark.usefixtures("active_user") +async def test_post_validate_prompt_with_invalid_data(client: AsyncClient, logged_in_headers): + invalid_case = { + "name": "string", + # Missing required fields + "frontend_node": {"template": {}, "is_input": True}, + } + response = await client.post("api/v1/validate/prompt", json=invalid_case, headers=logged_in_headers) + assert response.status_code == status.HTTP_422_UNPROCESSABLE_ENTITY + + +async def test_post_validate_code_with_unauthenticated_user(client: AsyncClient): + code = """ + print("Hello World") + """ + response = await client.post("api/v1/validate/code", json={"code": code}, headers={"Authorization": "Bearer fake"}) + assert response.status_code == status.HTTP_401_UNAUTHORIZED diff --git a/src/backend/tests/unit/test_endpoints.py b/src/backend/tests/unit/test_endpoints.py index 99073f60b..81a4b13c8 100644 --- a/src/backend/tests/unit/test_endpoints.py +++ b/src/backend/tests/unit/test_endpoints.py @@ -127,7 +127,8 @@ async def test_get_all(client: AsyncClient, logged_in_headers): assert "ChatOutput" in json_response["outputs"] -async def test_post_validate_code(client: AsyncClient): +@pytest.mark.usefixtures("active_user") +async def test_post_validate_code(client: AsyncClient, logged_in_headers): # Test case with a valid import and function code1 = """ import math @@ -135,7 +136,7 @@ import math def square(x): return x ** 2 """ - response1 = await client.post("api/v1/validate/code", json={"code": code1}) + response1 = await client.post("api/v1/validate/code", json={"code": code1}, headers=logged_in_headers) assert response1.status_code == 200 assert response1.json() == {"imports": {"errors": []}, "function": {"errors": []}} @@ -146,7 +147,7 @@ import non_existent_module def square(x): return x ** 2 """ - response2 = await client.post("api/v1/validate/code", json={"code": code2}) + response2 = await client.post("api/v1/validate/code", json={"code": code2}, headers=logged_in_headers) assert response2.status_code == 200 assert response2.json() == { "imports": {"errors": ["No module named 'non_existent_module'"]}, @@ -160,7 +161,7 @@ import math def square(x) return x ** 2 """ - response3 = await client.post("api/v1/validate/code", json={"code": code3}) + response3 = await client.post("api/v1/validate/code", json={"code": code3}, headers=logged_in_headers) assert response3.status_code == 200 assert response3.json() == { "imports": {"errors": []}, @@ -168,11 +169,11 @@ def square(x) } # Test case with invalid JSON payload - response4 = await client.post("api/v1/validate/code", json={"invalid_key": code1}) + response4 = await client.post("api/v1/validate/code", json={"invalid_key": code1}, headers=logged_in_headers) assert response4.status_code == 422 # Test case with an empty code string - response5 = await client.post("api/v1/validate/code", json={"code": ""}) + response5 = await client.post("api/v1/validate/code", json={"code": ""}, headers=logged_in_headers) assert response5.status_code == 200 assert response5.json() == {"imports": {"errors": []}, "function": {"errors": []}} @@ -183,7 +184,7 @@ import math def square(x) return x ** 2 """ - response6 = await client.post("api/v1/validate/code", json={"code": code6}) + response6 = await client.post("api/v1/validate/code", json={"code": code6}, headers=logged_in_headers) assert response6.status_code == 200 assert response6.json() == { "imports": {"errors": []},