🐛 fix(endpoints.py): import validate_api_key function from auth.utils to fix missing dependency
✨ feat(endpoints.py): add validation of API key in process_flow endpoint to ensure only valid requests are processed 🔧 chore(utils.py): add validate_api_key function to validate API key against database 🔧 chore(test_endpoints.py): add test case for process_flow endpoint to test API key validation and processing
This commit is contained in:
parent
dc20de76b5
commit
282e6b0c18
3 changed files with 59 additions and 2 deletions
|
|
@ -1,6 +1,6 @@
|
|||
from http import HTTPStatus
|
||||
from typing import Annotated, Optional, Union
|
||||
from langflow.services.auth.utils import get_current_active_user
|
||||
from langflow.services.auth.utils import get_current_active_user, validate_api_key
|
||||
|
||||
from langflow.services.cache.utils import save_uploaded_file
|
||||
from langflow.services.database.models.flow import Flow
|
||||
|
|
@ -86,10 +86,13 @@ async def process_flow(
|
|||
clear_cache: Annotated[bool, Body(embed=True)] = False, # noqa: F821
|
||||
session_id: Annotated[Union[None, str], Body(embed=True)] = None, # noqa: F821
|
||||
session: Session = Depends(get_session),
|
||||
valid: bool = Depends(validate_api_key),
|
||||
):
|
||||
"""
|
||||
Endpoint to process an input with a given flow_id.
|
||||
"""
|
||||
if not valid:
|
||||
raise HTTPException(status_code=401, detail="Invalid API key")
|
||||
|
||||
try:
|
||||
flow = session.get(Flow, flow_id)
|
||||
|
|
|
|||
|
|
@ -4,6 +4,7 @@ from jose import JWTError, jwt
|
|||
from typing import Annotated, Coroutine
|
||||
from uuid import UUID
|
||||
from langflow.services.auth.service import AuthManager
|
||||
from langflow.services.database.models.api_key.api_key import ApiKey
|
||||
from langflow.services.database.models.user.user import User
|
||||
from langflow.services.database.models.user.crud import (
|
||||
get_user_by_id,
|
||||
|
|
@ -11,7 +12,7 @@ from langflow.services.database.models.user.crud import (
|
|||
update_user_last_login_at,
|
||||
)
|
||||
from langflow.services.utils import get_session, get_settings_manager
|
||||
from sqlmodel import Session
|
||||
from sqlmodel import Session, select
|
||||
|
||||
|
||||
async def auth_scheme_dependency(request: Request):
|
||||
|
|
@ -57,6 +58,14 @@ async def get_current_user(
|
|||
return user
|
||||
|
||||
|
||||
async def validate_api_key(
|
||||
token: Annotated[str, Depends(auth_scheme_dependency)],
|
||||
db: Session = Depends(get_session),
|
||||
) -> bool:
|
||||
hashed_api_key = get_password_hash(token)
|
||||
return db.exec(select(ApiKey).where(ApiKey.hashed_api_key == hashed_api_key)).one_or_none() # type: ignore
|
||||
|
||||
|
||||
def get_current_active_user(current_user: Annotated[User, Depends(get_current_user)]):
|
||||
if not current_user.is_active:
|
||||
raise HTTPException(status_code=400, detail="Inactive user")
|
||||
|
|
|
|||
|
|
@ -1,3 +1,5 @@
|
|||
from langflow.services.auth.utils import get_password_hash
|
||||
from langflow.services.database.models.api_key.api_key import ApiKey
|
||||
import pytest
|
||||
from fastapi.testclient import TestClient
|
||||
from langflow.interface.tools.constants import CUSTOM_TOOLS
|
||||
|
|
@ -83,6 +85,49 @@ PROMPT_REQUEST = {
|
|||
}
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def created_api_key(active_user):
|
||||
hashed = get_password_hash("random_key")
|
||||
return ApiKey(
|
||||
name="test_api_key",
|
||||
user_id=active_user.id,
|
||||
api_key="random_key",
|
||||
hashed_api_key=hashed,
|
||||
)
|
||||
|
||||
|
||||
def test_process_flow(client, mocker, created_api_key):
|
||||
# Mock de process_graph_cached
|
||||
mock_process_graph_cached = mocker.patch(
|
||||
"langflow.processing.process.process_graph_cached", autospec=True
|
||||
)
|
||||
|
||||
# Defina o valor de retorno para o mock
|
||||
mock_process_graph_cached.return_value = ("result_mock", "session_id_mock")
|
||||
|
||||
api_key = created_api_key.api_key
|
||||
headers = {"Authorization": f"Bearer {api_key}"}
|
||||
|
||||
# Dummy POST data
|
||||
post_data = {
|
||||
"inputs": {"key": "value"},
|
||||
"tweaks": None,
|
||||
"clear_cache": False,
|
||||
"session_id": None,
|
||||
}
|
||||
|
||||
# Make the request to the FastAPI TestClient
|
||||
response = client.post("api/v1/process/flow_test", headers=headers, json=post_data)
|
||||
|
||||
# Check the response
|
||||
assert response.status_code == 200
|
||||
assert response.json()["result"] == "result_mock"
|
||||
assert response.json()["session_id"] == "session_id_mock"
|
||||
|
||||
# Ensure mock was called once
|
||||
mock_process_graph_cached.assert_called_once()
|
||||
|
||||
|
||||
def test_get_all(client: TestClient, logged_in_headers):
|
||||
response = client.get("api/v1/all", headers=logged_in_headers)
|
||||
assert response.status_code == 200
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue