🐛 fix(endpoints.py): import validate_api_key function from auth.utils to fix missing dependency

 feat(endpoints.py): add validation of API key in process_flow endpoint to ensure only valid requests are processed
🔧 chore(utils.py): add validate_api_key function to validate API key against database
🔧 chore(test_endpoints.py): add test case for process_flow endpoint to test API key validation and processing
This commit is contained in:
Gabriel Luiz Freitas Almeida 2023-08-25 20:02:53 -03:00
commit 282e6b0c18
3 changed files with 59 additions and 2 deletions

View file

@ -1,6 +1,6 @@
from http import HTTPStatus
from typing import Annotated, Optional, Union
from langflow.services.auth.utils import get_current_active_user
from langflow.services.auth.utils import get_current_active_user, validate_api_key
from langflow.services.cache.utils import save_uploaded_file
from langflow.services.database.models.flow import Flow
@ -86,10 +86,13 @@ async def process_flow(
clear_cache: Annotated[bool, Body(embed=True)] = False, # noqa: F821
session_id: Annotated[Union[None, str], Body(embed=True)] = None, # noqa: F821
session: Session = Depends(get_session),
valid: bool = Depends(validate_api_key),
):
"""
Endpoint to process an input with a given flow_id.
"""
if not valid:
raise HTTPException(status_code=401, detail="Invalid API key")
try:
flow = session.get(Flow, flow_id)

View file

@ -4,6 +4,7 @@ from jose import JWTError, jwt
from typing import Annotated, Coroutine
from uuid import UUID
from langflow.services.auth.service import AuthManager
from langflow.services.database.models.api_key.api_key import ApiKey
from langflow.services.database.models.user.user import User
from langflow.services.database.models.user.crud import (
get_user_by_id,
@ -11,7 +12,7 @@ from langflow.services.database.models.user.crud import (
update_user_last_login_at,
)
from langflow.services.utils import get_session, get_settings_manager
from sqlmodel import Session
from sqlmodel import Session, select
async def auth_scheme_dependency(request: Request):
@ -57,6 +58,14 @@ async def get_current_user(
return user
async def validate_api_key(
token: Annotated[str, Depends(auth_scheme_dependency)],
db: Session = Depends(get_session),
) -> bool:
hashed_api_key = get_password_hash(token)
return db.exec(select(ApiKey).where(ApiKey.hashed_api_key == hashed_api_key)).one_or_none() # type: ignore
def get_current_active_user(current_user: Annotated[User, Depends(get_current_user)]):
if not current_user.is_active:
raise HTTPException(status_code=400, detail="Inactive user")

View file

@ -1,3 +1,5 @@
from langflow.services.auth.utils import get_password_hash
from langflow.services.database.models.api_key.api_key import ApiKey
import pytest
from fastapi.testclient import TestClient
from langflow.interface.tools.constants import CUSTOM_TOOLS
@ -83,6 +85,49 @@ PROMPT_REQUEST = {
}
@pytest.fixture
def created_api_key(active_user):
hashed = get_password_hash("random_key")
return ApiKey(
name="test_api_key",
user_id=active_user.id,
api_key="random_key",
hashed_api_key=hashed,
)
def test_process_flow(client, mocker, created_api_key):
# Mock de process_graph_cached
mock_process_graph_cached = mocker.patch(
"langflow.processing.process.process_graph_cached", autospec=True
)
# Defina o valor de retorno para o mock
mock_process_graph_cached.return_value = ("result_mock", "session_id_mock")
api_key = created_api_key.api_key
headers = {"Authorization": f"Bearer {api_key}"}
# Dummy POST data
post_data = {
"inputs": {"key": "value"},
"tweaks": None,
"clear_cache": False,
"session_id": None,
}
# Make the request to the FastAPI TestClient
response = client.post("api/v1/process/flow_test", headers=headers, json=post_data)
# Check the response
assert response.status_code == 200
assert response.json()["result"] == "result_mock"
assert response.json()["session_id"] == "session_id_mock"
# Ensure mock was called once
mock_process_graph_cached.assert_called_once()
def test_get_all(client: TestClient, logged_in_headers):
response = client.get("api/v1/all", headers=logged_in_headers)
assert response.status_code == 200