Don't redownload voice files if sha256 sum matches

This commit is contained in:
Michael Hansen 2022-03-31 11:27:35 -04:00
commit c3a0bf3865
2 changed files with 98 additions and 14 deletions

View file

@ -13,9 +13,11 @@
# You should have received a copy of the GNU Affero General Public License
# along with this program. If not, see <http://www.gnu.org/licenses/>.
#
"""A command-line tool for downloading Mimic 3 voices"""
import argparse
import json
import logging
import re
import sys
import typing
import urllib.request
@ -25,9 +27,12 @@ from urllib.error import HTTPError
from ._resources import _PACKAGE, _VOICES
from .const import DEFAULT_VOICES_DOWNLOAD_DIR, DEFAULT_VOICES_URL_FORMAT
from .utils import file_sha256_sum, wildcard_to_regex
_LOGGER = logging.getLogger(__name__)
_WILDCARD = "*"
# -----------------------------------------------------------------------------
@ -50,6 +55,7 @@ def download_voice(
voice_files: typing.Iterable[VoiceFile],
voices_dir: typing.Union[str, Path],
chunk_bytes: int = 4096,
redownload: bool = False,
):
"""Downloads a voice to a directory"""
from tqdm.auto import tqdm
@ -67,7 +73,19 @@ def download_voice(
file_url = f"{url_base}/{voice_file.relative_path}"
file_path = voice_dir / voice_file.relative_path
if (not redownload) and voice_file.sha256_sum and file_path.is_file():
# Check if file exists and has correct sha256
expected_sha256 = voice_file.sha256_sum
with open(file_path, "rb") as check_file:
actual_sha256 = file_sha256_sum(check_file)
if actual_sha256 == expected_sha256:
_LOGGER.debug("Skipping download of %s (sha256 match)", file_path)
continue
try:
# Download file, show progress with tqdm
with urllib.request.urlopen(file_url) as response:
with open(file_path, mode="wb") as dest_file:
with tqdm(
@ -101,7 +119,7 @@ def main():
parser.add_argument(
"key",
nargs="*",
help="Keys of voices to download (e.g., en_US/vctk_low)",
help="Keys of voices to download (e.g., en_US/vctk_low). May contain wildcards (*)",
)
parser.add_argument(
"--output-dir",
@ -113,6 +131,11 @@ def main():
default=DEFAULT_VOICES_URL_FORMAT,
help="URL format string for voices (contains {key}, {lang}, {name})",
)
parser.add_argument(
"--redownload",
action="store_true",
help="Force re-downloading of files if they already exist",
)
parser.add_argument(
"--debug", action="store_true", help="Print DEBUG messages to console"
)
@ -133,21 +156,45 @@ def main():
json.dump(_VOICES, sys.stdout, indent=4, ensure_ascii=False)
sys.exit(0)
args.key = [
wildcard_to_regex(key, wildcard=_WILDCARD) if _WILDCARD in key else key
for key in args.key
]
args.output_dir.mkdir(parents=True, exist_ok=True)
for voice_key in args.key:
voice_lang, voice_name = voice_key.split("/", maxsplit=1)
voice_info = _VOICES[voice_key]
voice_url = str.format(
args.url_format, key=voice_key, lang=voice_lang, name=voice_name
)
voice_files = voice_info["files"]
download_voice(
voice_key=voice_key,
url_base=voice_url,
voice_files=[VoiceFile(file_key) for file_key in voice_files.keys()],
voices_dir=args.output_dir,
)
for key_or_pattern in args.key:
if isinstance(key_or_pattern, re.Pattern):
# Wildcards
voice_keys = []
for maybe_key in _VOICES.keys():
if key_or_pattern.match(maybe_key):
voice_keys.append(maybe_key)
_LOGGER.debug("%s matched %s", key_or_pattern, voice_keys)
else:
# No wildcards
voice_keys = [key_or_pattern]
for voice_key in voice_keys:
voice_lang, voice_name = voice_key.split("/", maxsplit=1)
voice_info = _VOICES[voice_key]
voice_url = str.format(
args.url_format, key=voice_key, lang=voice_lang, name=voice_name
)
voice_files = voice_info["files"]
_LOGGER.info("Downloading %s", voice_key)
download_voice(
voice_key=voice_key,
url_base=voice_url,
voice_files=[
VoiceFile(file_key, sha256_sum=file_info.get("sha256_sum"))
for file_key, file_info in voice_files.items()
],
voices_dir=args.output_dir,
redownload=args.redownload,
)
# -----------------------------------------------------------------------------

View file

@ -13,6 +13,11 @@
# You should have received a copy of the GNU Affero General Public License
# along with this program. If not, see <http://www.gnu.org/licenses/>.
#
"""Utility methods for Mimic 3"""
import hashlib
import re
import typing
import numpy as np
@ -24,3 +29,35 @@ def audio_float_to_int16(
audio_norm = np.clip(audio_norm, -max_wav_value, max_wav_value)
audio_norm = audio_norm.astype("int16")
return audio_norm
def wildcard_to_regex(template: str, wildcard: str = "*") -> re.Pattern:
    """Convert a string with wildcards into a compiled regex pattern.

    Every occurrence of ``wildcard`` in ``template`` becomes ``.*`` (any run
    of characters except newlines); all other text is escaped and matched
    literally. The pattern is anchored at both ends, so it only matches
    whole strings.

    Args:
        template: Text to match, optionally containing ``wildcard``.
        wildcard: Substring that stands for "match anything". An empty
            wildcard cannot occur in the template, so the template is then
            matched literally.

    Returns:
        A compiled pattern suitable for ``pattern.match(candidate)``.
    """
    # str.split raises on an empty separator; treat that case as "no wildcards"
    fixed_parts = template.split(wildcard) if wildcard else [template]

    # Literal pieces are escaped, and each wildcard between them becomes ".*"
    pattern_body = ".*".join(re.escape(part) for part in fixed_parts)

    # Anchor with \A and \Z instead of ^ and $: "$" also matches just before
    # a trailing newline, which would let "key\n" pass as "key".
    return re.compile(r"\A" + pattern_body + r"\Z")
def file_sha256_sum(fp: typing.BinaryIO, block_bytes: int = 4096) -> str:
    """Compute the hex sha256 digest of everything readable from ``fp``.

    The stream is consumed ``block_bytes`` at a time, starting from its
    current position, so a large file is never held in memory all at once.
    """
    digest = hashlib.sha256()

    while True:
        chunk = fp.read(block_bytes)
        if len(chunk) == 0:
            # An empty read means the stream is exhausted
            break

        digest.update(chunk)

    return digest.hexdigest()