class Tools:
    # Post-patch form of the ``Tools`` methods shown in this diff hunk. Other
    # members used below (``_resolve_under_restriction``, ``_result``,
    # ``_get_relative_path``, ``_ensure_parent_dir``) and the module-level
    # ``logger`` are defined outside this view.

    async def compress_file(
        self,
        file_name: Union[str, List[str]],
        output_filename: str,
        format: str = "zip",
        base_dir: Optional[str] = None,
    ) -> Dict[str, Any]:
        """
        Compress a single file, an entire directory, or a list of files into the specified format.

        Symlinks are never archived: a symlinked source is rejected, and symlinks
        found while walking a directory are skipped. The output archive itself is
        never added to its own contents.

        :param file_name: A path to a file or directory, or a list of file paths to include.
        :param output_filename: The name of the output compressed file.
        :param format: The compression format ('zip', 'tar', 'gztar').
        :param base_dir: The base directory where the input path(s) are located.
        :return: The ``self._result(...)`` payload: on success it carries the output
            path, format, source list and file count; on failure an ``error`` message.
        """
        try:
            # Fail fast: reject an unknown format before scanning sources or
            # creating the output's parent directory.
            if format not in ("zip", "tar", "gztar"):
                return self._result(False, action="compress", subject_type="file", error=f"Unsupported compression format: {format}")

            base_path = base_dir if base_dir else "."
            output_path = self._resolve_under_restriction(os.path.join(base_path, output_filename))

            # A list (or tuple) selects "list mode", which only accepts regular files.
            list_mode = isinstance(file_name, (list, tuple))
            requested = file_name if list_mode else [file_name]

            resolved_inputs: List[str] = []
            for name in requested:
                candidate = self._resolve_under_restriction(os.path.join(base_path, name))
                if not os.path.exists(candidate):
                    return self._result(
                        False,
                        action="compress",
                        subject_type="file",
                        error=f"Source does not exist: {name}" if list_mode else "Source path does not exist",
                        path=self._get_relative_path(candidate)
                    )
                if os.path.islink(candidate):
                    return self._result(
                        False,
                        action="compress",
                        subject_type="file",
                        error=f"Cannot compress symlink: {name}" if list_mode else "Cannot compress symlinks"
                    )
                if list_mode and os.path.isdir(candidate):
                    return self._result(False, action="compress", subject_type="file", error=f"List mode only supports files (got directory: {name})")
                resolved_inputs.append(candidate)

            # Walking a directory tree is blocking filesystem work; keep it off the event loop.
            items_to_add = await asyncio.to_thread(self._collect_archive_items, resolved_inputs, output_path)

            if not items_to_add:
                return self._result(False, action="compress", subject_type="file", error="No files to compress")

            await self._ensure_parent_dir(output_path)

            # Compression is CPU/IO heavy, so it runs in a worker thread as well.
            if format == "zip":
                await asyncio.to_thread(self._compress_zip_items, items_to_add, output_path)
            else:
                await asyncio.to_thread(self._compress_tar_items, items_to_add, output_path, format == "gztar")

            # For result metadata
            display_sources = [self._get_relative_path(p) for p in resolved_inputs]

            logger.info("Compressed %d item(s) to %s", len(items_to_add), output_path)
            return self._result(
                True,
                action="compress",
                subject_type="file",
                message="Compressed successfully",
                output=self._get_relative_path(output_path),
                format=format,
                sources=display_sources,
                file_count=len(items_to_add)
            )
        except ValueError as e:
            return self._result(False, action="compress", subject_type="file", error=str(e))
        # NOTE(review): the diff elides two lines here (old lines 1281-1282, between
        # the two hunks) -- presumably another ``except`` clause. Keep whatever the
        # real file has at this point; it is not visible in this view.
        except OSError as e:
            return self._result(False, action="compress", subject_type="file", error=f"Failed to compress file: {str(e)}")

    @staticmethod
    def _collect_archive_items(sources: List[str], exclude: Optional[str] = None) -> List[tuple[str, str]]:
        """
        Expand source paths into the ``(abs_path, arcname)`` pairs to archive.

        A file is stored under its basename; a directory is walked and its files
        are stored under ``<dir name>/<path inside dir>``. Symlinks (files or
        directories) are skipped and never followed.

        :param sources: Paths of files and/or directories to include.
        :param exclude: Optional path (normally the output archive) that must not
            be archived even if it lies inside a source directory.
        :return: Pairs in insertion order, with directory contents sorted by name.
        :raises ValueError: If two different files would get the same archive name.
        """
        excluded = os.path.realpath(exclude) if exclude else None
        seen: Dict[str, str] = {}  # arcname -> real path of the file that claimed it
        items: List[tuple[str, str]] = []

        def add(file_abs: str, arcname: str) -> None:
            real = os.path.realpath(file_abs)
            if real == excluded:
                # Archiving the output into itself would read a file that is being rewritten.
                return
            claimed_by = seen.setdefault(arcname, real)
            if claimed_by != real:
                # Two distinct files under one name would overwrite each other on extraction.
                raise ValueError(f"Duplicate archive entry name: {arcname}")
            if not any(existing_arc == arcname for _, existing_arc in items[-1:]) and (file_abs, arcname) not in items:
                items.append((file_abs, arcname))

        for src in sources:
            if os.path.islink(src):
                continue
            if os.path.isfile(src):
                add(src, os.path.basename(src))
            elif os.path.isdir(src):
                root_name = os.path.basename(src.rstrip(os.sep))
                for walk_root, dirnames, filenames in os.walk(src, followlinks=False):
                    # Prune symlinked directories and fix the traversal order.
                    dirnames[:] = sorted(d for d in dirnames if not os.path.islink(os.path.join(walk_root, d)))
                    for fname in sorted(filenames):
                        file_abs = os.path.join(walk_root, fname)
                        if os.path.islink(file_abs):
                            continue
                        add(file_abs, os.path.join(root_name, os.path.relpath(file_abs, src)))
        return items

    def _compress_zip_items(self, items: List[tuple[str, str]], output_path: str) -> None:
        """
        Create a deflate-compressed ZIP archive at ``output_path``.

        :param items: ``(abs_path, arcname)`` pairs; each file is stored as ``arcname``.
        :param output_path: Destination archive path (overwritten if it exists).
        """
        with zipfile.ZipFile(output_path, "w", compression=zipfile.ZIP_DEFLATED) as zipf:
            for abs_path, arcname in items:
                zipf.write(abs_path, arcname)

    def _compress_tar_items(self, items: List[tuple[str, str]], output_path: str, gzip: bool) -> None:
        """
        Create a TAR (``gzip=False``) or gzip-compressed TAR (``gzip=True``) archive.

        :param items: ``(abs_path, arcname)`` pairs; each file is stored as ``arcname``.
        :param output_path: Destination archive path (overwritten if it exists).
        :param gzip: Whether to gzip-compress the tar stream.
        """
        mode = "w:gz" if gzip else "w"
        with tarfile.open(output_path, mode) as tarf:
            for abs_path, arcname in items:
                tarf.add(abs_path, arcname)

    # NOTE: the hunk's trailing context continues with
    # ``async def decompress_file(self, file_name: str, output_directory: str,
    # base_dir: Optional[str] = None, ...``; that definition runs past this view
    # and is not reproduced here.