From c3a0bf38650f120c2c497f6f079484332f7bdcc8 Mon Sep 17 00:00:00 2001
From: Michael Hansen
Date: Thu, 31 Mar 2022 11:27:35 -0400
Subject: [PATCH] Don't redownload voice files if sha256 sum matches

---
 mimic3-tts/mimic3_tts/download.py | 75 +++++++++++++++++++++++++------
 mimic3-tts/mimic3_tts/utils.py    | 37 +++++++++++++++
 2 files changed, 98 insertions(+), 14 deletions(-)

diff --git a/mimic3-tts/mimic3_tts/download.py b/mimic3-tts/mimic3_tts/download.py
index bf05bff..693ea8d 100644
--- a/mimic3-tts/mimic3_tts/download.py
+++ b/mimic3-tts/mimic3_tts/download.py
@@ -13,9 +13,11 @@
 # You should have received a copy of the GNU Affero General Public License
 # along with this program. If not, see .
 #
+"""A command-line tool for downloading Mimic 3 voices"""
 import argparse
 import json
 import logging
+import re
 import sys
 import typing
 import urllib.request
@@ -25,9 +27,12 @@ from urllib.error import HTTPError
 
 from ._resources import _PACKAGE, _VOICES
 from .const import DEFAULT_VOICES_DOWNLOAD_DIR, DEFAULT_VOICES_URL_FORMAT
+from .utils import file_sha256_sum, wildcard_to_regex
 
 _LOGGER = logging.getLogger(__name__)
 
+_WILDCARD = "*"
+
 # -----------------------------------------------------------------------------
 
 
@@ -50,6 +55,7 @@ def download_voice(
     voice_files: typing.Iterable[VoiceFile],
     voices_dir: typing.Union[str, Path],
     chunk_bytes: int = 4096,
+    redownload: bool = False,
 ):
     """Downloads a voice to a directory"""
     from tqdm.auto import tqdm
@@ -67,7 +73,19 @@ def download_voice(
         file_url = f"{url_base}/{voice_file.relative_path}"
         file_path = voice_dir / voice_file.relative_path
 
+        if (not redownload) and voice_file.sha256_sum and file_path.is_file():
+            # Check if file exists and has correct sha256
+            expected_sha256 = voice_file.sha256_sum
+
+            with open(file_path, "rb") as check_file:
+                actual_sha256 = file_sha256_sum(check_file)
+
+            if actual_sha256 == expected_sha256:
+                _LOGGER.debug("Skipping download of %s (sha256 match)", file_path)
+                continue
+
         try:
+            # Download file, show progress with tqdm
             with urllib.request.urlopen(file_url) as response:
                 with open(file_path, mode="wb") as dest_file:
                     with tqdm(
@@ -101,7 +119,7 @@ def main():
     parser.add_argument(
         "key",
         nargs="*",
-        help="Keys of voices to download (e.g., en_US/vctk_low)",
+        help="Keys of voices to download (e.g., en_US/vctk_low). May contain wildcards (*)",
     )
     parser.add_argument(
         "--output-dir",
@@ -113,6 +131,11 @@ def main():
         default=DEFAULT_VOICES_URL_FORMAT,
         help="URL format string for voices (contains {key}, {lang}, {name})",
     )
+    parser.add_argument(
+        "--redownload",
+        action="store_true",
+        help="Force re-downloading of files if they already exist",
+    )
     parser.add_argument(
         "--debug", action="store_true", help="Print DEBUG messages to console"
     )
@@ -133,21 +156,45 @@ def main():
         json.dump(_VOICES, sys.stdout, indent=4, ensure_ascii=False)
         sys.exit(0)
 
+    args.key = [
+        wildcard_to_regex(key, wildcard=_WILDCARD) if _WILDCARD in key else key
+        for key in args.key
+    ]
+
     args.output_dir.mkdir(parents=True, exist_ok=True)
 
-    for voice_key in args.key:
-        voice_lang, voice_name = voice_key.split("/", maxsplit=1)
-        voice_info = _VOICES[voice_key]
-        voice_url = str.format(
-            args.url_format, key=voice_key, lang=voice_lang, name=voice_name
-        )
-        voice_files = voice_info["files"]
-        download_voice(
-            voice_key=voice_key,
-            url_base=voice_url,
-            voice_files=[VoiceFile(file_key) for file_key in voice_files.keys()],
-            voices_dir=args.output_dir,
-        )
+    for key_or_pattern in args.key:
+        if isinstance(key_or_pattern, re.Pattern):
+            # Wildcards
+            voice_keys = []
+            for maybe_key in _VOICES.keys():
+                if key_or_pattern.match(maybe_key):
+                    voice_keys.append(maybe_key)
+
+            _LOGGER.debug("%s matched %s", key_or_pattern, voice_keys)
+        else:
+            # No wildcards
+            voice_keys = [key_or_pattern]
+
+        for voice_key in voice_keys:
+            voice_lang, voice_name = voice_key.split("/", maxsplit=1)
+            voice_info = _VOICES[voice_key]
+            voice_url = str.format(
+                args.url_format, key=voice_key, lang=voice_lang, name=voice_name
+            )
+            voice_files = voice_info["files"]
+
+            _LOGGER.info("Downloading %s", voice_key)
+            download_voice(
+                voice_key=voice_key,
+                url_base=voice_url,
+                voice_files=[
+                    VoiceFile(file_key, sha256_sum=file_info.get("sha256_sum"))
+                    for file_key, file_info in voice_files.items()
+                ],
+                voices_dir=args.output_dir,
+                redownload=args.redownload,
+            )
 
 
 # -----------------------------------------------------------------------------
diff --git a/mimic3-tts/mimic3_tts/utils.py b/mimic3-tts/mimic3_tts/utils.py
index 72863f6..96d4c4a 100644
--- a/mimic3-tts/mimic3_tts/utils.py
+++ b/mimic3-tts/mimic3_tts/utils.py
@@ -13,6 +13,11 @@
 # You should have received a copy of the GNU Affero General Public License
 # along with this program. If not, see .
 #
+"""Utility methods for Mimic 3"""
+import hashlib
+import re
+import typing
+
 import numpy as np
 
 
@@ -24,3 +29,35 @@ def audio_float_to_int16(
     audio_norm = np.clip(audio_norm, -max_wav_value, max_wav_value)
     audio_norm = audio_norm.astype("int16")
     return audio_norm
+
+
+def wildcard_to_regex(template: str, wildcard: str = "*") -> re.Pattern:
+    """Convert a string with wildcards into a regex pattern"""
+    wildcard_escaped = re.escape(wildcard)
+
+    pattern_parts = ["^"]
+    for i, template_part in enumerate(re.split(f"({wildcard_escaped})", template)):
+        if (i % 2) == 0:
+            # Fixed string
+            pattern_parts.append(re.escape(template_part))
+        else:
+            # Wildcard separator
+            pattern_parts.append(".*")
+
+    pattern_parts.append("$")
+    pattern_str = "".join(pattern_parts)
+
+    return re.compile(pattern_str)
+
+
+def file_sha256_sum(fp: typing.BinaryIO, block_bytes: int = 4096) -> str:
+    """Return the sha256 sum of a (possibly large) file"""
+    current_hash = hashlib.sha256()
+
+    # Read in blocks in case file is very large
+    block = fp.read(block_bytes)
+    while len(block) > 0:
+        current_hash.update(block)
+        block = fp.read(block_bytes)
+
+    return current_hash.hexdigest()