From 282e6b0c186bbeebe8500195c249d49c9c9ece22 Mon Sep 17 00:00:00 2001 From: Gabriel Luiz Freitas Almeida Date: Fri, 25 Aug 2023 20:02:53 -0300 Subject: [PATCH] =?UTF-8?q?=F0=9F=90=9B=20fix(endpoints.py):=20import=20va?= =?UTF-8?q?lidate=5Fapi=5Fkey=20function=20from=20auth.utils=20to=20fix=20?= =?UTF-8?q?missing=20dependency=20=E2=9C=A8=20feat(endpoints.py):=20add=20?= =?UTF-8?q?validation=20of=20API=20key=20in=20process=5Fflow=20endpoint=20?= =?UTF-8?q?to=20ensure=20only=20valid=20requests=20are=20processed=20?= =?UTF-8?q?=F0=9F=94=A7=20chore(utils.py):=20add=20validate=5Fapi=5Fkey=20?= =?UTF-8?q?function=20to=20validate=20API=20key=20against=20database=20?= =?UTF-8?q?=F0=9F=94=A7=20chore(test=5Fendpoints.py):=20add=20test=20case?= =?UTF-8?q?=20for=20process=5Fflow=20endpoint=20to=20test=20API=20key=20va?= =?UTF-8?q?lidation=20and=20processing?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- src/backend/langflow/api/v1/endpoints.py | 5 ++- src/backend/langflow/services/auth/utils.py | 11 ++++- tests/test_endpoints.py | 45 +++++++++++++++++++++ 3 files changed, 59 insertions(+), 2 deletions(-) diff --git a/src/backend/langflow/api/v1/endpoints.py b/src/backend/langflow/api/v1/endpoints.py index 088c00b13..dd98f77d9 100644 --- a/src/backend/langflow/api/v1/endpoints.py +++ b/src/backend/langflow/api/v1/endpoints.py @@ -1,6 +1,6 @@ from http import HTTPStatus from typing import Annotated, Optional, Union -from langflow.services.auth.utils import get_current_active_user +from langflow.services.auth.utils import get_current_active_user, validate_api_key from langflow.services.cache.utils import save_uploaded_file from langflow.services.database.models.flow import Flow @@ -86,10 +86,13 @@ async def process_flow( clear_cache: Annotated[bool, Body(embed=True)] = False, # noqa: F821 session_id: Annotated[Union[None, str], Body(embed=True)] = None, # noqa: F821 session: Session = Depends(get_session), + valid: bool = Depends(validate_api_key), ): """ Endpoint to process an input with a given flow_id. """ + if not valid: + raise HTTPException(status_code=401, detail="Invalid API key") try: flow = session.get(Flow, flow_id) diff --git a/src/backend/langflow/services/auth/utils.py b/src/backend/langflow/services/auth/utils.py index 3e9f0f582..4904a179b 100644 --- a/src/backend/langflow/services/auth/utils.py +++ b/src/backend/langflow/services/auth/utils.py @@ -4,6 +4,7 @@ from jose import JWTError, jwt from typing import Annotated, Coroutine from uuid import UUID from langflow.services.auth.service import AuthManager +from langflow.services.database.models.api_key.api_key import ApiKey from langflow.services.database.models.user.user import User from langflow.services.database.models.user.crud import ( get_user_by_id, @@ -11,7 +12,7 @@ from langflow.services.database.models.user.crud import ( update_user_last_login_at, ) from langflow.services.utils import get_session, get_settings_manager -from sqlmodel import Session +from sqlmodel import Session, select async def auth_scheme_dependency(request: Request): @@ -57,6 +58,14 @@ async def get_current_user( return user +async def validate_api_key( + token: Annotated[str, Depends(auth_scheme_dependency)], + db: Session = Depends(get_session), +) -> bool: + hashed_api_key = get_password_hash(token) + return db.exec(select(ApiKey).where(ApiKey.hashed_api_key == hashed_api_key)).one_or_none() # type: ignore + + def get_current_active_user(current_user: Annotated[User, Depends(get_current_user)]): if not current_user.is_active: raise HTTPException(status_code=400, detail="Inactive user") diff --git a/tests/test_endpoints.py b/tests/test_endpoints.py index 431999a0c..22b10cddd 100644 --- a/tests/test_endpoints.py +++ b/tests/test_endpoints.py @@ -1,3 +1,5 @@ +from langflow.services.auth.utils import get_password_hash +from langflow.services.database.models.api_key.api_key import ApiKey import pytest from fastapi.testclient import TestClient from langflow.interface.tools.constants import CUSTOM_TOOLS @@ -83,6 +85,49 @@ PROMPT_REQUEST = { } +@pytest.fixture +def created_api_key(active_user): + hashed = get_password_hash("random_key") + return ApiKey( + name="test_api_key", + user_id=active_user.id, + api_key="random_key", + hashed_api_key=hashed, + ) + + +def test_process_flow(client, mocker, created_api_key): + # Mock de process_graph_cached + mock_process_graph_cached = mocker.patch( + "langflow.processing.process.process_graph_cached", autospec=True + ) + + # Defina o valor de retorno para o mock + mock_process_graph_cached.return_value = ("result_mock", "session_id_mock") + + api_key = created_api_key.api_key + headers = {"Authorization": f"Bearer {api_key}"} + + # Dummy POST data + post_data = { + "inputs": {"key": "value"}, + "tweaks": None, + "clear_cache": False, + "session_id": None, + } + + # Make the request to the FastAPI TestClient + response = client.post("api/v1/process/flow_test", headers=headers, json=post_data) + + # Check the response + assert response.status_code == 200 + assert response.json()["result"] == "result_mock" + assert response.json()["session_id"] == "session_id_mock" + + # Ensure mock was called once + mock_process_graph_cached.assert_called_once() + + def test_get_all(client: TestClient, logged_in_headers): response = client.get("api/v1/all", headers=logged_in_headers) assert response.status_code == 200