diff --git a/mimic3-tts/mimic3_tts/download.py b/mimic3-tts/mimic3_tts/download.py
index bf05bff..693ea8d 100644
--- a/mimic3-tts/mimic3_tts/download.py
+++ b/mimic3-tts/mimic3_tts/download.py
@@ -13,9 +13,11 @@
# You should have received a copy of the GNU Affero General Public License
# along with this program. If not, see <https://www.gnu.org/licenses/>.
#
+"""A command-line tool for downloading Mimic 3 voices"""
import argparse
import json
import logging
+import re
import sys
import typing
import urllib.request
@@ -25,9 +27,12 @@ from urllib.error import HTTPError
from ._resources import _PACKAGE, _VOICES
from .const import DEFAULT_VOICES_DOWNLOAD_DIR, DEFAULT_VOICES_URL_FORMAT
+from .utils import file_sha256_sum, wildcard_to_regex
_LOGGER = logging.getLogger(__name__)
+_WILDCARD = "*"
+
# -----------------------------------------------------------------------------
@@ -50,6 +55,7 @@ def download_voice(
voice_files: typing.Iterable[VoiceFile],
voices_dir: typing.Union[str, Path],
chunk_bytes: int = 4096,
+ redownload: bool = False,
):
"""Downloads a voice to a directory"""
from tqdm.auto import tqdm
@@ -67,7 +73,19 @@ def download_voice(
file_url = f"{url_base}/{voice_file.relative_path}"
file_path = voice_dir / voice_file.relative_path
+ if (not redownload) and voice_file.sha256_sum and file_path.is_file():
+ # Check if file exists and has correct sha256
+ expected_sha256 = voice_file.sha256_sum
+
+ with open(file_path, "rb") as check_file:
+ actual_sha256 = file_sha256_sum(check_file)
+
+ if actual_sha256 == expected_sha256:
+ _LOGGER.debug("Skipping download of %s (sha256 match)", file_path)
+ continue
+
try:
+ # Download file, show progress with tqdm
with urllib.request.urlopen(file_url) as response:
with open(file_path, mode="wb") as dest_file:
with tqdm(
@@ -101,7 +119,7 @@ def main():
parser.add_argument(
"key",
nargs="*",
- help="Keys of voices to download (e.g., en_US/vctk_low)",
+ help="Keys of voices to download (e.g., en_US/vctk_low). May contain wildcards (*)",
)
parser.add_argument(
"--output-dir",
@@ -113,6 +131,11 @@ def main():
default=DEFAULT_VOICES_URL_FORMAT,
help="URL format string for voices (contains {key}, {lang}, {name})",
)
+ parser.add_argument(
+ "--redownload",
+ action="store_true",
+ help="Force re-downloading of files if they already exist",
+ )
parser.add_argument(
"--debug", action="store_true", help="Print DEBUG messages to console"
)
@@ -133,21 +156,45 @@ def main():
json.dump(_VOICES, sys.stdout, indent=4, ensure_ascii=False)
sys.exit(0)
+ args.key = [
+ wildcard_to_regex(key, wildcard=_WILDCARD) if _WILDCARD in key else key
+ for key in args.key
+ ]
+
args.output_dir.mkdir(parents=True, exist_ok=True)
- for voice_key in args.key:
- voice_lang, voice_name = voice_key.split("/", maxsplit=1)
- voice_info = _VOICES[voice_key]
- voice_url = str.format(
- args.url_format, key=voice_key, lang=voice_lang, name=voice_name
- )
- voice_files = voice_info["files"]
- download_voice(
- voice_key=voice_key,
- url_base=voice_url,
- voice_files=[VoiceFile(file_key) for file_key in voice_files.keys()],
- voices_dir=args.output_dir,
- )
+ for key_or_pattern in args.key:
+ if isinstance(key_or_pattern, re.Pattern):
+ # Wildcards
+ voice_keys = []
+ for maybe_key in _VOICES.keys():
+ if key_or_pattern.match(maybe_key):
+ voice_keys.append(maybe_key)
+
+ _LOGGER.debug("%s matched %s", key_or_pattern, voice_keys)
+ else:
+ # No wildcards
+ voice_keys = [key_or_pattern]
+
+ for voice_key in voice_keys:
+ voice_lang, voice_name = voice_key.split("/", maxsplit=1)
+ voice_info = _VOICES[voice_key]
+ voice_url = str.format(
+ args.url_format, key=voice_key, lang=voice_lang, name=voice_name
+ )
+ voice_files = voice_info["files"]
+
+ _LOGGER.info("Downloading %s", voice_key)
+ download_voice(
+ voice_key=voice_key,
+ url_base=voice_url,
+ voice_files=[
+ VoiceFile(file_key, sha256_sum=file_info.get("sha256_sum"))
+ for file_key, file_info in voice_files.items()
+ ],
+ voices_dir=args.output_dir,
+ redownload=args.redownload,
+ )
# -----------------------------------------------------------------------------
diff --git a/mimic3-tts/mimic3_tts/utils.py b/mimic3-tts/mimic3_tts/utils.py
index 72863f6..96d4c4a 100644
--- a/mimic3-tts/mimic3_tts/utils.py
+++ b/mimic3-tts/mimic3_tts/utils.py
@@ -13,6 +13,11 @@
# You should have received a copy of the GNU Affero General Public License
# along with this program. If not, see <https://www.gnu.org/licenses/>.
#
+"""Utility methods for Mimic 3"""
+import hashlib
+import re
+import typing
+
import numpy as np
@@ -24,3 +29,35 @@ def audio_float_to_int16(
audio_norm = np.clip(audio_norm, -max_wav_value, max_wav_value)
audio_norm = audio_norm.astype("int16")
return audio_norm
+
+
+def wildcard_to_regex(template: str, wildcard: str = "*") -> re.Pattern:
+    """Compile a wildcard template into a regex that matches whole strings.
+
+    Each occurrence of ``wildcard`` matches any run of characters (possibly
+    empty, newlines included); all other text is matched literally.
+
+    Raises ValueError if ``wildcard`` is empty.
+    """
+    if not wildcard:
+        raise ValueError("wildcard must be a non-empty string")
+
+    # str.split drops the separators, so joining the escaped literal parts
+    # with ".*" puts one "match anything" between each pair of fixed parts.
+    literal_parts = (re.escape(part) for part in template.split(wildcard))
+
+    # \Z (unlike $) rejects a trailing newline; DOTALL lets ".*" cross lines.
+    return re.compile("^" + ".*".join(literal_parts) + r"\Z", re.DOTALL)
+
+
+def file_sha256_sum(fp: typing.BinaryIO, block_bytes: int = 4096) -> str:
+    """Compute the hex-encoded sha256 digest of an open binary file.
+
+    The file is consumed from its current position in reads of at most
+    ``block_bytes`` bytes, so a large file is never held in memory at once.
+    """
+    digest = hashlib.sha256()
+    # iter() with a b"" sentinel stops at the first empty read (end of file)
+    for chunk in iter(lambda: fp.read(block_bytes), b""):
+        digest.update(chunk)
+    return digest.hexdigest()