fix: auth current user on code validation (#6911)
* Add user auth for /code endpoint * import * revert * [autofix.ci] apply automated fixes * [autofix.ci] apply automated fixes (attempt 2/3) * ruff * fix tests * [autofix.ci] apply automated fixes * ruff --------- Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com> Co-authored-by: Nadir J <31660040+NadirJ@users.noreply.github.com>
This commit is contained in:
parent
e1ee081d32
commit
faac4db133
3 changed files with 36 additions and 12 deletions
|
|
@ -1,6 +1,7 @@
|
|||
from fastapi import APIRouter, HTTPException
|
||||
from loguru import logger
|
||||
|
||||
from langflow.api.utils import CurrentActiveUser
|
||||
from langflow.api.v1.base import Code, CodeValidationResponse, PromptValidationResponse, ValidatePromptRequest
|
||||
from langflow.base.prompts.api_utils import process_prompt_template
|
||||
from langflow.utils.validate import validate_code
|
||||
|
|
@ -10,7 +11,7 @@ router = APIRouter(prefix="/validate", tags=["Validate"])
|
|||
|
||||
|
||||
@router.post("/code", status_code=200)
|
||||
async def post_validate_code(code: Code) -> CodeValidationResponse:
|
||||
async def post_validate_code(code: Code, _current_user: CurrentActiveUser) -> CodeValidationResponse:
|
||||
try:
|
||||
errors = validate_code(code.code)
|
||||
return CodeValidationResponse(
|
||||
|
|
|
|||
|
|
@ -1,14 +1,16 @@
|
|||
import pytest
|
||||
from fastapi import status
|
||||
from httpx import AsyncClient
|
||||
|
||||
|
||||
async def test_post_validate_code(client: AsyncClient):
|
||||
@pytest.mark.usefixtures("active_user")
|
||||
async def test_post_validate_code(client: AsyncClient, logged_in_headers):
|
||||
good_code = """
|
||||
from pprint import pprint
|
||||
var = {"a": 1, "b": 2}
|
||||
pprint(var)
|
||||
"""
|
||||
response = await client.post("api/v1/validate/code", json={"code": good_code})
|
||||
response = await client.post("api/v1/validate/code", json={"code": good_code}, headers=logged_in_headers)
|
||||
result = response.json()
|
||||
|
||||
assert response.status_code == status.HTTP_200_OK
|
||||
|
|
@ -17,7 +19,8 @@ pprint(var)
|
|||
assert "function" in result, "The result must have a 'function' key"
|
||||
|
||||
|
||||
async def test_post_validate_prompt(client: AsyncClient):
|
||||
@pytest.mark.usefixtures("active_user")
|
||||
async def test_post_validate_prompt(client: AsyncClient, logged_in_headers):
|
||||
basic_case = {
|
||||
"name": "string",
|
||||
"template": "string",
|
||||
|
|
@ -48,10 +51,29 @@ async def test_post_validate_prompt(client: AsyncClient):
|
|||
"metadata": {},
|
||||
},
|
||||
}
|
||||
response = await client.post("api/v1/validate/prompt", json=basic_case)
|
||||
response = await client.post("api/v1/validate/prompt", json=basic_case, headers=logged_in_headers)
|
||||
result = response.json()
|
||||
|
||||
assert response.status_code == status.HTTP_200_OK
|
||||
assert isinstance(result, dict), "The result must be a dictionary"
|
||||
assert "frontend_node" in result, "The result must have a 'frontend_node' key"
|
||||
assert "input_variables" in result, "The result must have an 'input_variables' key"
|
||||
|
||||
|
||||
@pytest.mark.usefixtures("active_user")
|
||||
async def test_post_validate_prompt_with_invalid_data(client: AsyncClient, logged_in_headers):
|
||||
invalid_case = {
|
||||
"name": "string",
|
||||
# Missing required fields
|
||||
"frontend_node": {"template": {}, "is_input": True},
|
||||
}
|
||||
response = await client.post("api/v1/validate/prompt", json=invalid_case, headers=logged_in_headers)
|
||||
assert response.status_code == status.HTTP_422_UNPROCESSABLE_ENTITY
|
||||
|
||||
|
||||
async def test_post_validate_code_with_unauthenticated_user(client: AsyncClient):
|
||||
code = """
|
||||
print("Hello World")
|
||||
"""
|
||||
response = await client.post("api/v1/validate/code", json={"code": code}, headers={"Authorization": "Bearer fake"})
|
||||
assert response.status_code == status.HTTP_401_UNAUTHORIZED
|
||||
|
|
|
|||
|
|
@ -127,7 +127,8 @@ async def test_get_all(client: AsyncClient, logged_in_headers):
|
|||
assert "ChatOutput" in json_response["outputs"]
|
||||
|
||||
|
||||
async def test_post_validate_code(client: AsyncClient):
|
||||
@pytest.mark.usefixtures("active_user")
|
||||
async def test_post_validate_code(client: AsyncClient, logged_in_headers):
|
||||
# Test case with a valid import and function
|
||||
code1 = """
|
||||
import math
|
||||
|
|
@ -135,7 +136,7 @@ import math
|
|||
def square(x):
|
||||
return x ** 2
|
||||
"""
|
||||
response1 = await client.post("api/v1/validate/code", json={"code": code1})
|
||||
response1 = await client.post("api/v1/validate/code", json={"code": code1}, headers=logged_in_headers)
|
||||
assert response1.status_code == 200
|
||||
assert response1.json() == {"imports": {"errors": []}, "function": {"errors": []}}
|
||||
|
||||
|
|
@ -146,7 +147,7 @@ import non_existent_module
|
|||
def square(x):
|
||||
return x ** 2
|
||||
"""
|
||||
response2 = await client.post("api/v1/validate/code", json={"code": code2})
|
||||
response2 = await client.post("api/v1/validate/code", json={"code": code2}, headers=logged_in_headers)
|
||||
assert response2.status_code == 200
|
||||
assert response2.json() == {
|
||||
"imports": {"errors": ["No module named 'non_existent_module'"]},
|
||||
|
|
@ -160,7 +161,7 @@ import math
|
|||
def square(x)
|
||||
return x ** 2
|
||||
"""
|
||||
response3 = await client.post("api/v1/validate/code", json={"code": code3})
|
||||
response3 = await client.post("api/v1/validate/code", json={"code": code3}, headers=logged_in_headers)
|
||||
assert response3.status_code == 200
|
||||
assert response3.json() == {
|
||||
"imports": {"errors": []},
|
||||
|
|
@ -168,11 +169,11 @@ def square(x)
|
|||
}
|
||||
|
||||
# Test case with invalid JSON payload
|
||||
response4 = await client.post("api/v1/validate/code", json={"invalid_key": code1})
|
||||
response4 = await client.post("api/v1/validate/code", json={"invalid_key": code1}, headers=logged_in_headers)
|
||||
assert response4.status_code == 422
|
||||
|
||||
# Test case with an empty code string
|
||||
response5 = await client.post("api/v1/validate/code", json={"code": ""})
|
||||
response5 = await client.post("api/v1/validate/code", json={"code": ""}, headers=logged_in_headers)
|
||||
assert response5.status_code == 200
|
||||
assert response5.json() == {"imports": {"errors": []}, "function": {"errors": []}}
|
||||
|
||||
|
|
@ -183,7 +184,7 @@ import math
|
|||
def square(x)
|
||||
return x ** 2
|
||||
"""
|
||||
response6 = await client.post("api/v1/validate/code", json={"code": code6})
|
||||
response6 = await client.post("api/v1/validate/code", json={"code": code6}, headers=logged_in_headers)
|
||||
assert response6.status_code == 200
|
||||
assert response6.json() == {
|
||||
"imports": {"errors": []},
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue