feat: Unified File Management API (#6100)
* feat: FIrst pass at file management API * [autofix.ci] apply automated fixes * Add delete and edit endpoints * [autofix.ci] apply automated fixes * Add file size and duplicate name handling * Ensure the File model has a unique name * Ensure count is before extension * [autofix.ci] apply automated fixes * Add the correct path to the return * Added function to handle list of paths in File component * [autofix.ci] apply automated fixes * Update input_mixin.py * Refactor to a v2 endpoint * Add unit tests * Update test_files.py * Update frontend.ts * [autofix.ci] apply automated fixes * Remove extension from name * Cast the string type for like * Update files.py * Update base.py * Update base.py --------- Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com> Co-authored-by: Lucas Oliveira <lucas.edu.oli@hotmail.com>
This commit is contained in:
parent
cc3417bec7
commit
28e07be870
18 changed files with 635 additions and 15 deletions
|
|
@ -92,6 +92,7 @@ export class Web extends Construct {
|
|||
defaultBehavior: { origin: s3SpaOrigin },
|
||||
additionalBehaviors: {
|
||||
'/api/v1/*': albBehaviorOptions,
|
||||
'/api/v2/*': albBehaviorOptions,
|
||||
'/health' : albBehaviorOptions,
|
||||
},
|
||||
enableLogging: true, // ログ出力設定
|
||||
|
|
|
|||
|
|
@ -0,0 +1,49 @@
|
|||
"""Add V2 File Table
|
||||
|
||||
Revision ID: dd9e0804ebd1
|
||||
Revises: e3162c1804e6
|
||||
Create Date: 2025-02-03 11:47:16.101523
|
||||
|
||||
"""
|
||||
from typing import Sequence, Union
|
||||
|
||||
from alembic import op
|
||||
import sqlalchemy as sa
|
||||
import sqlmodel
|
||||
from langflow.utils import migration
|
||||
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision: str = 'dd9e0804ebd1'
|
||||
down_revision: Union[str, None] = 'e3162c1804e6'
|
||||
branch_labels: Union[str, Sequence[str], None] = None
|
||||
depends_on: Union[str, Sequence[str], None] = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
conn = op.get_bind()
|
||||
if not migration.table_exists("file", conn):
|
||||
# ### commands auto generated by Alembic - please adjust! ###
|
||||
op.create_table(
|
||||
"file",
|
||||
sa.Column("id", sqlmodel.sql.sqltypes.types.Uuid(), nullable=False),
|
||||
sa.Column("user_id", sqlmodel.sql.sqltypes.types.Uuid(), nullable=False),
|
||||
sa.Column("name", sqlmodel.sql.sqltypes.AutoString(), nullable=False, unique=True),
|
||||
sa.Column("path", sqlmodel.sql.sqltypes.AutoString(), nullable=False),
|
||||
sa.Column("size", sa.Integer(), nullable=False),
|
||||
sa.Column("provider", sqlmodel.sql.sqltypes.AutoString(), nullable=True),
|
||||
sa.Column("created_at", sa.DateTime(), nullable=False),
|
||||
sa.Column("updated_at", sa.DateTime(), nullable=False),
|
||||
sa.PrimaryKeyConstraint("id"),
|
||||
sa.ForeignKeyConstraint(["user_id"], ["user.id"], name="fk_file_user_id_user"),
|
||||
sa.UniqueConstraint("name"),
|
||||
)
|
||||
# ### end Alembic commands ###
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
conn = op.get_bind()
|
||||
# ### commands auto generated by Alembic - please adjust! ###
|
||||
if migration.table_exists("file", conn):
|
||||
op.drop_table("file")
|
||||
# ### end Alembic commands ###
|
||||
|
|
@ -1,5 +1,5 @@
|
|||
from langflow.api.health_check_router import health_check_router
|
||||
from langflow.api.log_router import log_router
|
||||
from langflow.api.router import router
|
||||
from langflow.api.router import router, router_v2
|
||||
|
||||
__all__ = ["health_check_router", "log_router", "router"]
|
||||
__all__ = ["health_check_router", "log_router", "router", "router_v2"]
|
||||
|
|
|
|||
|
|
@ -16,10 +16,16 @@ from langflow.api.v1 import (
|
|||
validate_router,
|
||||
variables_router,
|
||||
)
|
||||
from langflow.api.v2 import files_router as files_router_v2
|
||||
|
||||
router = APIRouter(
|
||||
prefix="/api/v1",
|
||||
)
|
||||
|
||||
router_v2 = APIRouter(
|
||||
prefix="/api/v2",
|
||||
)
|
||||
|
||||
router.include_router(chat_router)
|
||||
router.include_router(endpoints_router)
|
||||
router.include_router(validate_router)
|
||||
|
|
@ -33,3 +39,5 @@ router.include_router(files_router)
|
|||
router.include_router(monitor_router)
|
||||
router.include_router(folders_router)
|
||||
router.include_router(starter_projects_router)
|
||||
|
||||
router_v2.include_router(files_router_v2)
|
||||
|
|
|
|||
14
src/backend/base/langflow/api/schemas.py
Normal file
14
src/backend/base/langflow/api/schemas.py
Normal file
|
|
@ -0,0 +1,14 @@
|
|||
from pathlib import Path
|
||||
from uuid import UUID
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
|
||||
class UploadFileResponse(BaseModel):
|
||||
"""File upload response schema."""
|
||||
|
||||
id: UUID
|
||||
name: str
|
||||
path: Path
|
||||
size: int
|
||||
provider: str | None = None
|
||||
5
src/backend/base/langflow/api/v2/__init__.py
Normal file
5
src/backend/base/langflow/api/v2/__init__.py
Normal file
|
|
@ -0,0 +1,5 @@
|
|||
from langflow.api.v2.files import router as files_router
|
||||
|
||||
__all__ = [
|
||||
"files_router",
|
||||
]
|
||||
228
src/backend/base/langflow/api/v2/files.py
Normal file
228
src/backend/base/langflow/api/v2/files.py
Normal file
|
|
@ -0,0 +1,228 @@
|
|||
import uuid
|
||||
from collections.abc import AsyncGenerator
|
||||
from http import HTTPStatus
|
||||
from pathlib import Path
|
||||
from typing import Annotated
|
||||
|
||||
from fastapi import APIRouter, Depends, File, HTTPException, UploadFile
|
||||
from fastapi.responses import StreamingResponse
|
||||
from sqlmodel import String, cast, select
|
||||
|
||||
from langflow.api.schemas import UploadFileResponse
|
||||
from langflow.api.utils import CurrentActiveUser, DbSession
|
||||
from langflow.services.database.models.file import File as UserFile
|
||||
from langflow.services.deps import get_settings_service, get_storage_service
|
||||
from langflow.services.storage.service import StorageService
|
||||
|
||||
router = APIRouter(tags=["Files"], prefix="/files")
|
||||
|
||||
|
||||
async def byte_stream_generator(file_bytes: bytes, chunk_size: int = 8192) -> AsyncGenerator[bytes, None]:
|
||||
"""Convert bytes object into an async generator that yields chunks."""
|
||||
for i in range(0, len(file_bytes), chunk_size):
|
||||
yield file_bytes[i : i + chunk_size]
|
||||
|
||||
|
||||
async def fetch_file_object(file_id: uuid.UUID, current_user: CurrentActiveUser, session: DbSession):
|
||||
# Fetch the file from the DB
|
||||
stmt = select(UserFile).where(UserFile.id == file_id)
|
||||
results = await session.exec(stmt)
|
||||
file = results.first()
|
||||
|
||||
# Check if the file exists
|
||||
if not file:
|
||||
raise HTTPException(status_code=404, detail="File not found")
|
||||
|
||||
# Make sure the user has access to the file
|
||||
if file.user_id != current_user.id:
|
||||
raise HTTPException(status_code=403, detail="You don't have access to this file")
|
||||
|
||||
return file
|
||||
|
||||
|
||||
@router.post("", status_code=HTTPStatus.CREATED)
|
||||
async def upload_user_file(
|
||||
file: Annotated[UploadFile, File(...)],
|
||||
session: DbSession,
|
||||
current_user: CurrentActiveUser,
|
||||
storage_service=Depends(get_storage_service),
|
||||
settings_service=Depends(get_settings_service),
|
||||
) -> UploadFileResponse:
|
||||
"""Upload a file for the current user and track it in the database."""
|
||||
# Get the max allowed file size from settings (in MB)
|
||||
try:
|
||||
max_file_size_upload = settings_service.settings.max_file_size_upload
|
||||
except Exception as e:
|
||||
raise HTTPException(status_code=500, detail=f"Settings error: {e}") from e
|
||||
|
||||
# Validate that a file is actually provided
|
||||
if not file or not file.filename:
|
||||
raise HTTPException(status_code=400, detail="No file provided")
|
||||
|
||||
# Validate file size (convert MB to bytes)
|
||||
if file.size > max_file_size_upload * 1024 * 1024:
|
||||
raise HTTPException(
|
||||
status_code=413,
|
||||
detail=f"File size is larger than the maximum file size {max_file_size_upload}MB.",
|
||||
)
|
||||
|
||||
# Read file content and create a unique file name
|
||||
try:
|
||||
# Create a unique file name
|
||||
file_id = uuid.uuid4()
|
||||
file_content = await file.read()
|
||||
|
||||
# Get file extension of the file
|
||||
file_extension = "." + file.filename.split(".")[-1] if file.filename and "." in file.filename else ""
|
||||
anonymized_file_name = f"{file_id!s}{file_extension}"
|
||||
|
||||
# Here we use the current user's id as the folder name
|
||||
folder = str(current_user.id)
|
||||
# Save the file using the storage service.
|
||||
await storage_service.save_file(flow_id=folder, file_name=anonymized_file_name, data=file_content)
|
||||
except Exception as e:
|
||||
raise HTTPException(status_code=500, detail=f"Error saving file: {e}") from e
|
||||
|
||||
# Create a new database record for the uploaded file.
|
||||
try:
|
||||
# Enforce unique constraint on name
|
||||
# Name it as filename (1), (2), etc.
|
||||
# Check if the file name already exists
|
||||
new_filename = file.filename
|
||||
try:
|
||||
root_filename, _ = new_filename.rsplit(".", 1)
|
||||
except ValueError:
|
||||
root_filename, _ = new_filename, ""
|
||||
|
||||
# Check if there are files with the same name
|
||||
stmt = select(UserFile).where(cast(UserFile.name, String).like(f"{root_filename}%"))
|
||||
existing_files = await session.exec(stmt)
|
||||
files = existing_files.all() # Fetch all matching records
|
||||
|
||||
# If there are files with the same name, append a count to the filename
|
||||
if files:
|
||||
count = len(files) # Count occurrences
|
||||
|
||||
# Split the extension from the filename
|
||||
root_filename = f"{root_filename} ({count})"
|
||||
|
||||
# Compute the file size based on the path
|
||||
file_size = await storage_service.get_file_size(flow_id=folder, file_name=anonymized_file_name)
|
||||
|
||||
# Compute the file path
|
||||
file_path = f"{folder}/{anonymized_file_name}"
|
||||
|
||||
# Create a new file record
|
||||
new_file = UserFile(
|
||||
id=file_id,
|
||||
user_id=current_user.id,
|
||||
name=root_filename,
|
||||
path=file_path,
|
||||
size=file_size,
|
||||
)
|
||||
session.add(new_file)
|
||||
|
||||
await session.commit()
|
||||
await session.refresh(new_file)
|
||||
except Exception as e:
|
||||
# Optionally, you could also delete the file from disk if the DB insert fails.
|
||||
raise HTTPException(status_code=500, detail=f"Database error: {e}") from e
|
||||
|
||||
return UploadFileResponse(id=new_file.id, name=new_file.name, path=Path(new_file.path), size=new_file.size)
|
||||
|
||||
|
||||
@router.get("")
|
||||
async def list_files(
|
||||
current_user: CurrentActiveUser,
|
||||
session: DbSession,
|
||||
) -> list[UserFile]:
|
||||
"""List the files available to the current user."""
|
||||
try:
|
||||
# Fetch from the UserFile table
|
||||
stmt = select(UserFile).where(UserFile.user_id == current_user.id)
|
||||
results = await session.exec(stmt)
|
||||
|
||||
return list(results)
|
||||
except Exception as e:
|
||||
raise HTTPException(status_code=500, detail=f"Error listing files: {e}") from e
|
||||
|
||||
|
||||
@router.get("/{file_id}")
|
||||
async def download_file(
|
||||
file_id: uuid.UUID,
|
||||
current_user: CurrentActiveUser,
|
||||
session: DbSession,
|
||||
storage_service: Annotated[StorageService, Depends(get_storage_service)],
|
||||
):
|
||||
"""Download a file by its ID."""
|
||||
try:
|
||||
# Fetch the file from the DB
|
||||
file = await fetch_file_object(file_id, current_user, session)
|
||||
|
||||
# Get the basename of the file path
|
||||
file_name = file.path.split("/")[-1]
|
||||
|
||||
# Get file stream
|
||||
file_stream = await storage_service.get_file(flow_id=str(current_user.id), file_name=file_name)
|
||||
|
||||
# Ensure file_stream is an async iterator returning bytes
|
||||
byte_stream = byte_stream_generator(file_stream)
|
||||
except Exception as e:
|
||||
raise HTTPException(status_code=500, detail=f"Error downloading file: {e}") from e
|
||||
|
||||
# Return the file as a streaming response
|
||||
return StreamingResponse(
|
||||
byte_stream,
|
||||
media_type="application/octet-stream",
|
||||
headers={"Content-Disposition": f'attachment; filename="{file.name}"'},
|
||||
)
|
||||
|
||||
|
||||
@router.put("/{file_id}")
|
||||
async def edit_file_name(
|
||||
file_id: uuid.UUID,
|
||||
name: str,
|
||||
current_user: CurrentActiveUser,
|
||||
session: DbSession,
|
||||
) -> UploadFileResponse:
|
||||
"""Edit the name of a file by its ID."""
|
||||
try:
|
||||
# Fetch the file from the DB
|
||||
file = await fetch_file_object(file_id, current_user, session)
|
||||
|
||||
# Update the file name
|
||||
file.name = name
|
||||
await session.commit()
|
||||
except Exception as e:
|
||||
raise HTTPException(status_code=500, detail=f"Error editing file: {e}") from e
|
||||
|
||||
return UploadFileResponse(id=file.id, name=file.name, path=file.path, size=file.size)
|
||||
|
||||
|
||||
@router.delete("/{file_id}")
|
||||
async def delete_file(
|
||||
file_id: uuid.UUID,
|
||||
current_user: CurrentActiveUser,
|
||||
session: DbSession,
|
||||
storage_service: Annotated[StorageService, Depends(get_storage_service)],
|
||||
):
|
||||
"""Delete a file by its ID."""
|
||||
try:
|
||||
# Fetch the file from the DB
|
||||
file = await fetch_file_object(file_id, current_user, session)
|
||||
if not file:
|
||||
raise HTTPException(status_code=404, detail="File not found")
|
||||
|
||||
# Delete the file from the storage service
|
||||
await storage_service.delete_file(flow_id=str(current_user.id), file_name=file.path)
|
||||
|
||||
# Delete from the database
|
||||
await session.delete(file)
|
||||
await session.flush() # Ensures delete is staged
|
||||
await session.commit() # Commit deletion
|
||||
|
||||
except Exception as e:
|
||||
await session.rollback() # Rollback on failure
|
||||
raise HTTPException(status_code=500, detail=f"Error deleting file: {e}") from e
|
||||
|
||||
return {"message": "File deleted successfully"}
|
||||
|
|
@ -100,7 +100,10 @@ class BaseFileComponent(Component, ABC):
|
|||
def __init__(self, *args, **kwargs):
|
||||
super().__init__(*args, **kwargs)
|
||||
# Dynamically update FileInput to include valid extensions and bundles
|
||||
self._base_inputs[0].file_types = [*self.valid_extensions, *self.SUPPORTED_BUNDLE_EXTENSIONS]
|
||||
self._base_inputs[0].file_types = [
|
||||
*self.valid_extensions,
|
||||
*self.SUPPORTED_BUNDLE_EXTENSIONS,
|
||||
]
|
||||
|
||||
file_types = ", ".join(self.valid_extensions)
|
||||
bundles = ", ".join(self.SUPPORTED_BUNDLE_EXTENSIONS)
|
||||
|
|
@ -342,8 +345,13 @@ class BaseFileComponent(Component, ABC):
|
|||
|
||||
if self.path and not file_path:
|
||||
# Wrap self.path into a Data object
|
||||
data_obj = Data(data={self.SERVER_FILE_PATH_FIELDNAME: self.path})
|
||||
add_file(data=data_obj, path=self.path, delete_after_processing=False)
|
||||
if isinstance(self.path, list):
|
||||
for path in self.path:
|
||||
data_obj = Data(data={self.SERVER_FILE_PATH_FIELDNAME: path})
|
||||
add_file(data=data_obj, path=path, delete_after_processing=False)
|
||||
else:
|
||||
data_obj = Data(data={self.SERVER_FILE_PATH_FIELDNAME: self.path})
|
||||
add_file(data=data_obj, path=self.path, delete_after_processing=False)
|
||||
elif file_path:
|
||||
for obj in file_path:
|
||||
server_file_path = obj.data.get(self.SERVER_FILE_PATH_FIELDNAME)
|
||||
|
|
@ -384,7 +392,11 @@ class BaseFileComponent(Component, ABC):
|
|||
# Recurse into directories
|
||||
collected_files.extend(
|
||||
[
|
||||
BaseFileComponent.BaseFile(data, sub_path, delete_after_processing=delete_after_processing)
|
||||
BaseFileComponent.BaseFile(
|
||||
data,
|
||||
sub_path,
|
||||
delete_after_processing=delete_after_processing,
|
||||
)
|
||||
for sub_path in path.rglob("*")
|
||||
if sub_path.is_file()
|
||||
]
|
||||
|
|
@ -399,7 +411,11 @@ class BaseFileComponent(Component, ABC):
|
|||
self.log(f"Unpacked bundle {path.name} into {subpaths}")
|
||||
collected_files.extend(
|
||||
[
|
||||
BaseFileComponent.BaseFile(data, sub_path, delete_after_processing=delete_after_processing)
|
||||
BaseFileComponent.BaseFile(
|
||||
data,
|
||||
sub_path,
|
||||
delete_after_processing=delete_after_processing,
|
||||
)
|
||||
for sub_path in subpaths
|
||||
]
|
||||
)
|
||||
|
|
|
|||
|
|
@ -14,7 +14,12 @@ import pandas as pd
|
|||
from loguru import logger
|
||||
|
||||
from langflow.exceptions.component import ComponentBuildError
|
||||
from langflow.graph.schema import INPUT_COMPONENTS, OUTPUT_COMPONENTS, InterfaceComponentTypes, ResultData
|
||||
from langflow.graph.schema import (
|
||||
INPUT_COMPONENTS,
|
||||
OUTPUT_COMPONENTS,
|
||||
InterfaceComponentTypes,
|
||||
ResultData,
|
||||
)
|
||||
from langflow.graph.utils import UnbuiltObject, UnbuiltResult, log_transaction
|
||||
from langflow.interface import initialize
|
||||
from langflow.interface.listing import lazy_load_dict
|
||||
|
|
@ -355,8 +360,16 @@ class Vertex:
|
|||
if file_path := field.get("file_path"):
|
||||
storage_service = get_storage_service()
|
||||
try:
|
||||
flow_id, file_name = os.path.split(file_path)
|
||||
full_path = storage_service.build_full_path(flow_id, file_name)
|
||||
full_path: str | list[str] = ""
|
||||
if field.get("list"):
|
||||
full_path = []
|
||||
for p in file_path:
|
||||
flow_id, file_name = os.path.split(p)
|
||||
path = storage_service.build_full_path(flow_id, file_name)
|
||||
full_path.append(path)
|
||||
else:
|
||||
flow_id, file_name = os.path.split(file_path)
|
||||
full_path = storage_service.build_full_path(flow_id, file_name)
|
||||
except ValueError as e:
|
||||
if "too many values to unpack" in str(e):
|
||||
full_path = file_path
|
||||
|
|
@ -621,7 +634,12 @@ class Vertex:
|
|||
return await self._get_result(requester, target_handle_name)
|
||||
|
||||
async def _log_transaction_async(
|
||||
self, flow_id: str | UUID, source: Vertex, status, target: Vertex | None = None, error=None
|
||||
self,
|
||||
flow_id: str | UUID,
|
||||
source: Vertex,
|
||||
status,
|
||||
target: Vertex | None = None,
|
||||
error=None,
|
||||
) -> None:
|
||||
"""Log a transaction asynchronously with proper task handling and cancellation.
|
||||
|
||||
|
|
@ -723,7 +741,12 @@ class Vertex:
|
|||
self.params[key].extend(result)
|
||||
|
||||
async def _build_results(
|
||||
self, custom_component, custom_params, base_type: str, *, fallback_to_env_vars=False
|
||||
self,
|
||||
custom_component,
|
||||
custom_params,
|
||||
base_type: str,
|
||||
*,
|
||||
fallback_to_env_vars=False,
|
||||
) -> None:
|
||||
try:
|
||||
result = await initialize.loading.get_instance_results(
|
||||
|
|
|
|||
|
|
@ -132,9 +132,27 @@ class DatabaseLoadMixin(BaseModel):
|
|||
|
||||
# Specific mixin for fields needing file interaction
|
||||
class FileMixin(BaseModel):
|
||||
file_path: str | None = Field(default="")
|
||||
file_path: list[str] | str | None = Field(default="")
|
||||
file_types: list[str] = Field(default=[], alias="fileTypes")
|
||||
|
||||
@field_validator("file_path")
|
||||
@classmethod
|
||||
def validate_file_path(cls, v):
|
||||
if v is None or v == "":
|
||||
return v
|
||||
# If it's already a list, validate each element is a string
|
||||
if isinstance(v, list):
|
||||
for item in v:
|
||||
if not isinstance(item, str):
|
||||
msg = "All file paths must be strings"
|
||||
raise TypeError(msg)
|
||||
return v
|
||||
# If it's a single string, that's also valid
|
||||
if isinstance(v, str):
|
||||
return v
|
||||
msg = "file_path must be a string, list of strings, or None"
|
||||
raise ValueError(msg)
|
||||
|
||||
@field_validator("file_types")
|
||||
@classmethod
|
||||
def validate_file_types(cls, v):
|
||||
|
|
|
|||
|
|
@ -23,7 +23,7 @@ from pydantic_core import PydanticSerializationError
|
|||
from rich import print as rprint
|
||||
from starlette.middleware.base import BaseHTTPMiddleware, RequestResponseEndpoint
|
||||
|
||||
from langflow.api import health_check_router, log_router, router
|
||||
from langflow.api import health_check_router, log_router, router, router_v2
|
||||
from langflow.initial_setup.setup import (
|
||||
create_or_update_starter_projects,
|
||||
initialize_super_user_if_needed,
|
||||
|
|
@ -239,6 +239,7 @@ def create_app():
|
|||
router.include_router(mcp_router)
|
||||
|
||||
app.include_router(router)
|
||||
app.include_router(router_v2)
|
||||
app.include_router(health_check_router)
|
||||
app.include_router(log_router)
|
||||
|
||||
|
|
|
|||
|
|
@ -0,0 +1,5 @@
|
|||
from .model import File
|
||||
|
||||
__all__ = [
|
||||
"File",
|
||||
]
|
||||
|
|
@ -0,0 +1,14 @@
|
|||
from uuid import UUID
|
||||
|
||||
from sqlmodel import select
|
||||
from sqlmodel.ext.asyncio.session import AsyncSession
|
||||
|
||||
from langflow.services.database.models.file.model import File
|
||||
|
||||
|
||||
async def get_file_by_id(db: AsyncSession, file_id: UUID) -> File | None:
|
||||
if isinstance(file_id, str):
|
||||
file_id = UUID(file_id)
|
||||
stmt = select(File).where(File.id == file_id)
|
||||
|
||||
return (await db.exec(stmt)).first()
|
||||
|
|
@ -0,0 +1,17 @@
|
|||
from datetime import datetime, timezone
|
||||
from uuid import UUID, uuid4
|
||||
|
||||
from sqlmodel import Field, SQLModel
|
||||
|
||||
from langflow.schema.serialize import UUIDstr
|
||||
|
||||
|
||||
class File(SQLModel, table=True): # type: ignore[call-arg]
|
||||
id: UUIDstr = Field(default_factory=uuid4, primary_key=True)
|
||||
user_id: UUID = Field(foreign_key="user.id")
|
||||
name: str = Field(unique=True, nullable=False)
|
||||
path: str = Field(nullable=False)
|
||||
size: int = Field(nullable=False)
|
||||
provider: str | None = Field(default=None)
|
||||
created_at: datetime = Field(default_factory=lambda: datetime.now(timezone.utc))
|
||||
updated_at: datetime = Field(default_factory=lambda: datetime.now(timezone.utc))
|
||||
|
|
@ -108,3 +108,15 @@ class LocalStorageService(StorageService):
|
|||
async def teardown(self) -> None:
|
||||
"""Perform any cleanup operations when the service is being torn down."""
|
||||
# No specific teardown actions required for local
|
||||
|
||||
async def get_file_size(self, flow_id: str, file_name: str) -> None:
|
||||
"""Get the size of a file in the local storage."""
|
||||
# Get the file size from the file path
|
||||
file_path = self.data_dir / flow_id / file_name
|
||||
if not await file_path.exists():
|
||||
logger.warning(f"File {file_name} not found in flow {flow_id}.")
|
||||
msg = f"File {file_name} not found in flow {flow_id}"
|
||||
raise FileNotFoundError(msg)
|
||||
|
||||
file_size_stat = await file_path.stat()
|
||||
return file_size_stat.st_size
|
||||
|
|
|
|||
0
src/backend/tests/unit/api/v2/__init__.py
Normal file
0
src/backend/tests/unit/api/v2/__init__.py
Normal file
209
src/backend/tests/unit/api/v2/test_files.py
Normal file
209
src/backend/tests/unit/api/v2/test_files.py
Normal file
|
|
@ -0,0 +1,209 @@
|
|||
import asyncio
|
||||
import tempfile
|
||||
from contextlib import suppress
|
||||
from pathlib import Path
|
||||
|
||||
# we need to import tmpdir
|
||||
import anyio
|
||||
import pytest
|
||||
from asgi_lifespan import LifespanManager
|
||||
from httpx import ASGITransport, AsyncClient
|
||||
from langflow.main import create_app
|
||||
from langflow.services.auth.utils import get_password_hash
|
||||
from langflow.services.database.models.api_key.model import ApiKey
|
||||
from langflow.services.database.models.user.model import User, UserRead
|
||||
from langflow.services.database.utils import session_getter
|
||||
from langflow.services.deps import get_db_service
|
||||
from sqlalchemy.orm import selectinload
|
||||
from sqlmodel import select
|
||||
|
||||
from tests.conftest import _delete_transactions_and_vertex_builds
|
||||
|
||||
|
||||
@pytest.fixture(name="files_created_api_key")
|
||||
async def files_created_api_key(files_client, files_active_user): # noqa: ARG001
|
||||
hashed = get_password_hash("random_key")
|
||||
api_key = ApiKey(
|
||||
name="files_created_api_key",
|
||||
user_id=files_active_user.id,
|
||||
api_key="random_key",
|
||||
hashed_api_key=hashed,
|
||||
)
|
||||
db_manager = get_db_service()
|
||||
async with session_getter(db_manager) as session:
|
||||
stmt = select(ApiKey).where(ApiKey.api_key == api_key.api_key)
|
||||
if existing_api_key := (await session.exec(stmt)).first():
|
||||
yield existing_api_key
|
||||
return
|
||||
session.add(api_key)
|
||||
await session.commit()
|
||||
await session.refresh(api_key)
|
||||
yield api_key
|
||||
# Clean up
|
||||
await session.delete(api_key)
|
||||
await session.commit()
|
||||
|
||||
|
||||
@pytest.fixture(name="files_active_user")
|
||||
async def files_active_user(files_client): # noqa: ARG001
|
||||
db_manager = get_db_service()
|
||||
async with db_manager.with_session() as session:
|
||||
user = User(
|
||||
username="files_active_user",
|
||||
password=get_password_hash("testpassword"),
|
||||
is_active=True,
|
||||
is_superuser=False,
|
||||
)
|
||||
stmt = select(User).where(User.username == user.username)
|
||||
if active_user := (await session.exec(stmt)).first():
|
||||
user = active_user
|
||||
else:
|
||||
session.add(user)
|
||||
await session.commit()
|
||||
await session.refresh(user)
|
||||
user = UserRead.model_validate(user, from_attributes=True)
|
||||
yield user
|
||||
# Clean up
|
||||
# Now cleanup transactions, vertex_build
|
||||
async with db_manager.with_session() as session:
|
||||
user = await session.get(User, user.id, options=[selectinload(User.flows)])
|
||||
await _delete_transactions_and_vertex_builds(session, user.flows)
|
||||
await session.delete(user)
|
||||
|
||||
await session.commit()
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def max_file_size_upload_fixture(monkeypatch):
|
||||
monkeypatch.setenv("LANGFLOW_MAX_FILE_SIZE_UPLOAD", "1")
|
||||
yield
|
||||
monkeypatch.undo()
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def max_file_size_upload_10mb_fixture(monkeypatch):
|
||||
monkeypatch.setenv("LANGFLOW_MAX_FILE_SIZE_UPLOAD", "10")
|
||||
yield
|
||||
monkeypatch.undo()
|
||||
|
||||
|
||||
@pytest.fixture(name="files_client")
|
||||
async def files_client_fixture(
|
||||
monkeypatch,
|
||||
request,
|
||||
):
|
||||
# Set the database url to a test database
|
||||
if "noclient" in request.keywords:
|
||||
yield
|
||||
else:
|
||||
|
||||
def init_app():
|
||||
db_dir = tempfile.mkdtemp()
|
||||
db_path = Path(db_dir) / "test.db"
|
||||
monkeypatch.setenv("LANGFLOW_DATABASE_URL", f"sqlite:///{db_path}")
|
||||
monkeypatch.setenv("LANGFLOW_AUTO_LOGIN", "false")
|
||||
from langflow.services.manager import service_manager
|
||||
|
||||
service_manager.factories.clear()
|
||||
service_manager.services.clear() # Clear the services cache
|
||||
app = create_app()
|
||||
return app, db_path
|
||||
|
||||
app, db_path = await asyncio.to_thread(init_app)
|
||||
|
||||
async with (
|
||||
LifespanManager(app, startup_timeout=None, shutdown_timeout=None) as manager,
|
||||
AsyncClient(transport=ASGITransport(app=manager.app), base_url="http://testserver/") as client,
|
||||
):
|
||||
yield client
|
||||
# app.dependency_overrides.clear()
|
||||
monkeypatch.undo()
|
||||
# clear the temp db
|
||||
with suppress(FileNotFoundError):
|
||||
await anyio.Path(db_path).unlink()
|
||||
|
||||
|
||||
async def test_upload_file(files_client, files_created_api_key):
|
||||
headers = {"x-api-key": files_created_api_key.api_key}
|
||||
|
||||
response = await files_client.post(
|
||||
"api/v2/files",
|
||||
files={"file": ("test.txt", b"test content")},
|
||||
headers=headers,
|
||||
)
|
||||
assert response.status_code == 201, f"Expected 201, got {response.status_code}: {response.json()}"
|
||||
|
||||
response_json = response.json()
|
||||
assert "id" in response_json
|
||||
|
||||
|
||||
async def test_download_file(files_client, files_created_api_key):
|
||||
headers = {"x-api-key": files_created_api_key.api_key}
|
||||
|
||||
# First upload a file
|
||||
response = await files_client.post(
|
||||
"api/v2/files",
|
||||
files={"file": ("test.txt", b"test content")},
|
||||
headers=headers,
|
||||
)
|
||||
assert response.status_code == 201
|
||||
upload_response = response.json()
|
||||
|
||||
# Then try to download it
|
||||
response = await files_client.get(f"api/v2/files/{upload_response['id']}", headers=headers)
|
||||
|
||||
assert response.status_code == 200
|
||||
assert response.content == b"test content"
|
||||
|
||||
|
||||
async def test_list_files(files_client, files_created_api_key):
|
||||
headers = {"x-api-key": files_created_api_key.api_key}
|
||||
|
||||
# First upload a file
|
||||
response = await files_client.post(
|
||||
"api/v2/files",
|
||||
files={"file": ("test.txt", b"test content")},
|
||||
headers=headers,
|
||||
)
|
||||
assert response.status_code == 201
|
||||
|
||||
# Then list the files
|
||||
response = await files_client.get("api/v2/files", headers=headers)
|
||||
assert response.status_code == 200
|
||||
files = response.json()
|
||||
assert len(files) == 1
|
||||
|
||||
|
||||
async def test_delete_file(files_client, files_created_api_key):
|
||||
headers = {"x-api-key": files_created_api_key.api_key}
|
||||
|
||||
response = await files_client.post(
|
||||
"api/v2/files",
|
||||
files={"file": ("test.txt", b"test content")},
|
||||
headers=headers,
|
||||
)
|
||||
assert response.status_code == 201
|
||||
upload_response = response.json()
|
||||
|
||||
response = await files_client.delete(f"api/v2/files/{upload_response['id']}", headers=headers)
|
||||
assert response.status_code == 200
|
||||
assert response.json() == {"message": "File deleted successfully"}
|
||||
|
||||
|
||||
async def test_edit_file(files_client, files_created_api_key):
|
||||
headers = {"x-api-key": files_created_api_key.api_key}
|
||||
|
||||
# First upload a file
|
||||
response = await files_client.post(
|
||||
"api/v2/files",
|
||||
files={"file": ("test.txt", b"test content")},
|
||||
headers=headers,
|
||||
)
|
||||
assert response.status_code == 201
|
||||
upload_response = response.json()
|
||||
|
||||
# Then list the files
|
||||
response = await files_client.put(f"api/v2/files/{upload_response['id']}?name=potato.txt", headers=headers)
|
||||
assert response.status_code == 200
|
||||
file = response.json()
|
||||
assert file["name"] == "potato.txt"
|
||||
|
|
@ -1,7 +1,7 @@
|
|||
export const BASENAME = "";
|
||||
export const PORT = 3000;
|
||||
export const PROXY_TARGET = "http://127.0.0.1:7860";
|
||||
export const API_ROUTES = ["^/api/v1/", "/health"];
|
||||
export const API_ROUTES = ["^/api/v1/", "/api/v2/", "/health"];
|
||||
export const BASE_URL_API = "/api/v1/";
|
||||
export const HEALTH_CHECK_URL = "/health_check";
|
||||
export const DOCS_LINK = "https://docs.langflow.org";
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue