fix: update file size limit to use middleware and add tests for uploads (#4883)

This commit is contained in:
Gabriel Luiz Freitas Almeida 2024-11-28 09:25:26 -03:00 committed by GitHub
commit 712a43958c
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
8 changed files with 133 additions and 13 deletions

View file

@ -176,6 +176,7 @@ dev-dependencies = [
"pytest-github-actions-annotate-failures>=0.2.0",
"pytest-codspeed>=3.0.0",
"blockbuster>=1.1.1,<1.2",
"types-aiofiles>=24.1.0.20240626",
]

View file

@ -12,7 +12,7 @@ from fastapi.responses import StreamingResponse
from langflow.api.utils import AsyncDbSession, CurrentActiveUser
from langflow.api.v1.schemas import UploadFileResponse
from langflow.services.database.models.flow import Flow
from langflow.services.deps import get_settings_service, get_storage_service
from langflow.services.deps import get_storage_service
from langflow.services.storage.service import StorageService
from langflow.services.storage.utils import build_content_type_from_extension
@ -46,16 +46,6 @@ async def upload_file(
session: AsyncDbSession,
storage_service: Annotated[StorageService, Depends(get_storage_service)],
) -> UploadFileResponse:
try:
max_file_size_upload = get_settings_service().settings.max_file_size_upload
except Exception as e:
raise HTTPException(status_code=500, detail=str(e)) from e
if file.size > max_file_size_upload * 1024 * 1024:
raise HTTPException(
status_code=413, detail=f"File size is larger than the maximum file size {max_file_size_upload}MB."
)
try:
flow_id_str = str(flow_id)
flow = await session.get(Flow, flow_id_str)

View file

@ -30,7 +30,7 @@ class CloudflareWorkersAIEmbeddingsComponent(LCModelComponent):
display_name="Model Name",
info="List of supported models https://developers.cloudflare.com/workers-ai/models/#text-embeddings",
required=True,
value="@cf/baai/bge-base-en-v1.5"
value="@cf/baai/bge-base-en-v1.5",
),
BoolInput(
name="strip_new_lines",
@ -75,6 +75,6 @@ class CloudflareWorkersAIEmbeddingsComponent(LCModelComponent):
strip_new_lines=self.strip_new_lines,
)
except Exception as e:
raise ValueError(f"Could not connect to CloudflareWorkersAIEmbeddings API: {str(e)}") from e
raise ValueError(f"Could not connect to CloudflareWorkersAIEmbeddings API: {e!s}") from e
return embeddings

View file

@ -28,6 +28,7 @@ from langflow.initial_setup.setup import (
from langflow.interface.types import get_and_cache_all_types_dict
from langflow.interface.utils import setup_llm_caching
from langflow.logging.logger import configure
from langflow.middleware import ContentSizeLimitMiddleware
from langflow.services.deps import get_settings_service, get_telemetry_service
from langflow.services.utils import initialize_services, teardown_services
@ -132,6 +133,10 @@ def create_app():
configure()
lifespan = get_lifespan(version=__version__)
app = FastAPI(lifespan=lifespan, title="Langflow", version=__version__)
app.add_middleware(
ContentSizeLimitMiddleware,
)
setup_sentry(app)
origins = ["*"]

View file

@ -0,0 +1,58 @@
from fastapi import HTTPException
from loguru import logger

from langflow.services.deps import get_settings_service


class MaxFileSizeException(HTTPException):
    """Raised when a request body exceeds the configured upload size limit (HTTP 413)."""

    # NOTE: the original default detail contained an unfilled "{}" placeholder
    # ("maximum file size {}MB"); callers always pass an explicit message, but the
    # fallback is now a well-formed sentence.
    def __init__(self, detail: str = "File size is larger than the maximum file size allowed."):
        super().__init__(status_code=413, detail=detail)


# Adapted from https://github.com/steinnes/content-size-limit-asgi/blob/master/content_size_limit_asgi/middleware.py#L26
class ContentSizeLimitMiddleware:
    """Content size limiting middleware for ASGI applications.

    The limit (in MB) is read from the ``max_file_size_upload`` setting on each
    received body chunk, so a ``None`` setting disables the check entirely.

    Args:
        app: the wrapped ASGI application.
    """

    def __init__(
        self,
        app,
    ):
        self.app = app
        self.logger = logger

    def receive_wrapper(self, receive):
        """Wrap an ASGI ``receive`` callable, accumulating body bytes and raising once the limit is exceeded."""
        received = 0

        async def inner():
            nonlocal received
            max_file_size_upload = get_settings_service().settings.max_file_size_upload
            message = await receive()
            # Only "http.request" messages carry body bytes; pass everything else
            # through untouched, and skip accounting when no limit is configured.
            if message["type"] != "http.request" or max_file_size_upload is None:
                return message
            body_len = len(message.get("body", b""))
            received += body_len
            if received > max_file_size_upload * 1024 * 1024:
                # The limit is configured in MB, so report the received size in MB too.
                received_in_mb = round(received / (1024 * 1024), 3)
                msg = (
                    f"Content size limit exceeded. Maximum allowed is {max_file_size_upload}MB"
                    f" and got {received_in_mb}MB."
                )
                raise MaxFileSizeException(msg)
            return message

        return inner

    async def __call__(self, scope, receive, send):
        # Only HTTP requests carry a request body worth limiting.
        if scope["type"] != "http":
            await self.app(scope, receive, send)
            return

        wrapper = self.receive_wrapper(receive)
        await self.app(scope, wrapper, send)

View file

@ -78,6 +78,7 @@ dev-dependencies = [
"asgi-lifespan>=2.1.0",
"pytest-codspeed>=3.0.0",
"pytest-github-actions-annotate-failures>=0.2.0",
"types-aiofiles>=24.1.0.20240626",
]
[build-system]

View file

@ -3,6 +3,7 @@ import re
import shutil
import tempfile
from contextlib import suppress
from io import BytesIO
from pathlib import Path
from unittest.mock import MagicMock
@ -23,6 +24,13 @@ def mock_storage_service():
service.get_file.return_value = b"file content" # Binary content for files
service.list_files.return_value = ["file1.txt", "file2.jpg"]
service.delete_file.return_value = None
# Mock the settings service with proper max_file_size_upload attribute
settings_mock = MagicMock()
settings_mock.settings = MagicMock()
settings_mock.settings.max_file_size_upload = 1 # Default 1MB limit
service.settings_service = settings_mock
return service
@ -71,6 +79,20 @@ async def files_client_fixture(
db_path.unlink()
@pytest.fixture
async def max_file_size_upload_fixture(monkeypatch):
    """Cap uploads at 1MB for the duration of the test via the LANGFLOW_MAX_FILE_SIZE_UPLOAD env var."""
    monkeypatch.setenv("LANGFLOW_MAX_FILE_SIZE_UPLOAD", "1")
    yield
    # monkeypatch also undoes automatically at teardown; explicit undo keeps intent clear.
    monkeypatch.undo()


@pytest.fixture
async def max_file_size_upload_10mb_fixture(monkeypatch):
    """Cap uploads at 10MB for the duration of the test via the LANGFLOW_MAX_FILE_SIZE_UPLOAD env var."""
    monkeypatch.setenv("LANGFLOW_MAX_FILE_SIZE_UPLOAD", "10")
    yield
    monkeypatch.undo()
async def test_upload_file(files_client, created_api_key, flow):
headers = {"x-api-key": created_api_key.api_key}
@ -154,3 +176,33 @@ async def test_file_operations(client, created_api_key, flow):
# Verify that the file is indeed deleted
response = await client.get(f"api/v1/files/list/{flow_id}", headers=headers)
assert full_file_name not in response.json()["files"]
@pytest.mark.usefixtures("max_file_size_upload_fixture")
async def test_upload_file_size_limit(files_client, created_api_key, flow):
    """Uploads under the configured 1MB cap succeed (201); larger ones are rejected (413)."""
    headers = {"x-api-key": created_api_key.api_key}

    # A 500KB payload sits comfortably below the 1MB limit and must be accepted.
    under_limit = b"x" * (500 * 1024)
    headers["Content-Length"] = str(len(under_limit))
    response = await files_client.post(
        f"api/v1/files/upload/{flow.id}",
        files={"file": ("small_file.txt", under_limit, "application/octet-stream")},
        headers=headers,
    )
    assert response.status_code == 201, f"Expected 201, got {response.status_code}: {response.json()}"

    # 1MB + 1KB crosses the limit; the middleware must reject it with 413.
    over_limit = b"x" * (1024 * 1024 + 1024)
    headers["Content-Length"] = str(len(over_limit))
    response = await files_client.post(
        f"api/v1/files/upload/{flow.id}",
        files={"file": ("large_file.txt", BytesIO(over_limit), "application/octet-stream")},
        headers=headers,
    )
    assert response.status_code == 413, f"Expected 413, got {response.status_code}: {response.json()}"
    assert "Content size limit exceeded. Maximum allowed is 1MB and got 1.001MB." in response.json()["detail"]

13
uv.lock generated
View file

@ -3649,6 +3649,7 @@ dev = [
{ name = "requests" },
{ name = "respx" },
{ name = "ruff" },
{ name = "types-aiofiles" },
{ name = "types-google-cloud-ndb" },
{ name = "types-markdown" },
{ name = "types-passlib" },
@ -3780,6 +3781,7 @@ dev = [
{ name = "requests", specifier = ">=2.32.0" },
{ name = "respx", specifier = ">=0.21.1" },
{ name = "ruff", specifier = ">=0.6.2,<0.7.0" },
{ name = "types-aiofiles", specifier = ">=24.1.0.20240626" },
{ name = "types-google-cloud-ndb", specifier = ">=2.2.0.0" },
{ name = "types-markdown", specifier = ">=3.7.0.20240822" },
{ name = "types-passlib", specifier = ">=1.7.7.13" },
@ -3913,6 +3915,7 @@ dev = [
{ name = "asgi-lifespan" },
{ name = "pytest-codspeed" },
{ name = "pytest-github-actions-annotate-failures" },
{ name = "types-aiofiles" },
]
[package.metadata]
@ -4023,6 +4026,7 @@ dev = [
{ name = "asgi-lifespan", specifier = ">=2.1.0" },
{ name = "pytest-codspeed", specifier = ">=3.0.0" },
{ name = "pytest-github-actions-annotate-failures", specifier = ">=0.2.0" },
{ name = "types-aiofiles", specifier = ">=24.1.0.20240626" },
]
[[package]]
@ -7761,6 +7765,15 @@ wheels = [
{ url = "https://files.pythonhosted.org/packages/18/7e/c8bfa8cbcd3ea1d25d2beb359b5c5a3f4339a7e2e5d9e3ef3e29ba3ab3b9/typer-0.13.0-py3-none-any.whl", hash = "sha256:d85fe0b777b2517cc99c8055ed735452f2659cd45e451507c76f48ce5c1d00e2", size = 44194 },
]
[[package]]
name = "types-aiofiles"
version = "24.1.0.20240626"
source = { registry = "https://pypi.org/simple" }
sdist = { url = "https://files.pythonhosted.org/packages/13/e9/013940b017c313c2e15c64017268fdb0c25e0638621fb8a5d9ebe00fb0f4/types-aiofiles-24.1.0.20240626.tar.gz", hash = "sha256:48604663e24bc2d5038eac05ccc33e75799b0779e93e13d6a8f711ddc306ac08", size = 9357 }
wheels = [
{ url = "https://files.pythonhosted.org/packages/c3/ad/c4b3275d21c5be79487c4f6ed7cd13336997746fe099236cb29256a44a90/types_aiofiles-24.1.0.20240626-py3-none-any.whl", hash = "sha256:7939eca4a8b4f9c6491b6e8ef160caee9a21d32e18534a57d5ed90aee47c66b4", size = 9389 },
]
[[package]]
name = "types-cachetools"
version = "5.5.0.20240820"