From fffc7fb73ffaba7bef257a71f7533d84a04535b3 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=C3=8Dtalo=20Johnny?= Date: Wed, 18 Dec 2024 15:51:48 -0300 Subject: [PATCH] fix: file path handling for cross-os compatibility (#5342) * test: add more unit tests * fix: correct file path splitting to handle OS differences * [autofix.ci] apply automated fixes * fix: ruff error --------- Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com> --- src/backend/base/langflow/schema/image.py | 9 ++- src/backend/tests/unit/schema/test_image.py | 84 +++++++++++++++++++++ 2 files changed, 90 insertions(+), 3 deletions(-) create mode 100644 src/backend/tests/unit/schema/test_image.py diff --git a/src/backend/base/langflow/schema/image.py b/src/backend/base/langflow/schema/image.py index c2b07225e..8a269a06a 100644 --- a/src/backend/base/langflow/schema/image.py +++ b/src/backend/base/langflow/schema/image.py @@ -1,4 +1,5 @@ import base64 +from pathlib import Path from PIL import Image as PILImage from pydantic import BaseModel @@ -21,7 +22,8 @@ def get_file_paths(files: list[str]): storage_service = get_storage_service() file_paths = [] for file in files: - flow_id, file_name = file.split("/") + file_path = Path(file) + flow_id, file_name = str(file_path.parent), file_path.name file_paths.append(storage_service.build_full_path(flow_id=flow_id, file_name=file_name)) return file_paths @@ -33,8 +35,9 @@ async def get_files( ): storage_service = get_storage_service() file_objects: list[str | bytes] = [] - for file_path in file_paths: - flow_id, file_name = file_path.split("/") + for file in file_paths: + file_path = Path(file) + flow_id, file_name = str(file_path.parent), file_path.name file_object = await storage_service.get_file(flow_id=flow_id, file_name=file_name) if convert_to_base64: file_base64 = base64.b64encode(file_object).decode("utf-8") diff --git a/src/backend/tests/unit/schema/test_image.py b/src/backend/tests/unit/schema/test_image.py new file mode 100644 index 000000000..ce7d840da --- /dev/null +++ b/src/backend/tests/unit/schema/test_image.py @@ -0,0 +1,84 @@ +import tempfile + +import aiofiles +import pytest +from langflow.schema.image import ( + get_file_paths, + get_files, + is_image_file, +) +from PIL import Image as PILImage + + +@pytest.fixture +def file_image(): + image = PILImage.new("RGB", (100, 100), (255, 0, 0)) + with tempfile.NamedTemporaryFile(suffix=".png", delete=True) as temp_file: + image.save(temp_file.name) + yield temp_file.name + + +@pytest.fixture +def file_txt(): + content = """\ +line1: This is an example text file. +line2: It can be used for testing. +line3: End of file. +""" + with tempfile.NamedTemporaryFile(suffix=".txt", delete=False) as temp_file: + temp_file.write(content.encode()) + temp_file.flush() + yield temp_file.name + + +def test_is_image_file(file_image): + assert is_image_file(file_image) is True + + +def test_is_image_file__not_image(file_txt): + assert is_image_file(file_txt) is False + + +def test_get_file_paths(file_image, file_txt): + files = [file_image, file_txt] + + result = get_file_paths(files) + + assert len(result) == 2 + assert result[0].endswith(".png") + assert result[1].endswith(".txt") + + +def test_get_file_paths__empty(): + result = get_file_paths([]) + + assert len(result) == 0 + + +@pytest.mark.asyncio +async def test_get_files(file_image, file_txt, caplog): # noqa: ARG001 + file_paths = [file_image, file_txt] + + result = await get_files(file_paths) + + for index, file in enumerate(result): + async with aiofiles.open(file_paths[index], "rb") as f: + assert file == await f.read() + + +@pytest.mark.asyncio +async def test_get_files__convert_to_base64(file_image, file_txt, caplog): # noqa: ARG001 + file_paths = [file_image, file_txt] + + result = await get_files(file_paths, convert_to_base64=True) + + for index, file in enumerate(result): + async with aiofiles.open(file_paths[index], "rb") as f: + assert file != await f.read() + + +@pytest.mark.asyncio +async def test_get_files__empty(): + result = await get_files([]) + + assert len(result) == 0