Don't redownload voice files if sha256 sum matches
This commit is contained in:
parent
fbd26a02d5
commit
c3a0bf3865
2 changed files with 98 additions and 14 deletions
|
|
@ -13,9 +13,11 @@
|
|||
# You should have received a copy of the GNU Affero General Public License
|
||||
# along with this program. If not, see <http://www.gnu.org/licenses/>.
|
||||
#
|
||||
"""A command-line tool for downloading Mimic 3 voices"""
|
||||
import argparse
|
||||
import json
|
||||
import logging
|
||||
import re
|
||||
import sys
|
||||
import typing
|
||||
import urllib.request
|
||||
|
|
@ -25,9 +27,12 @@ from urllib.error import HTTPError
|
|||
|
||||
from ._resources import _PACKAGE, _VOICES
|
||||
from .const import DEFAULT_VOICES_DOWNLOAD_DIR, DEFAULT_VOICES_URL_FORMAT
|
||||
from .utils import file_sha256_sum, wildcard_to_regex
|
||||
|
||||
_LOGGER = logging.getLogger(__name__)
|
||||
|
||||
_WILDCARD = "*"
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
|
||||
|
||||
|
|
@ -50,6 +55,7 @@ def download_voice(
|
|||
voice_files: typing.Iterable[VoiceFile],
|
||||
voices_dir: typing.Union[str, Path],
|
||||
chunk_bytes: int = 4096,
|
||||
redownload: bool = False,
|
||||
):
|
||||
"""Downloads a voice to a directory"""
|
||||
from tqdm.auto import tqdm
|
||||
|
|
@ -67,7 +73,19 @@ def download_voice(
|
|||
file_url = f"{url_base}/{voice_file.relative_path}"
|
||||
file_path = voice_dir / voice_file.relative_path
|
||||
|
||||
if (not redownload) and voice_file.sha256_sum and file_path.is_file():
|
||||
# Check if file exists and has correct sha256
|
||||
expected_sha256 = voice_file.sha256_sum
|
||||
|
||||
with open(file_path, "rb") as check_file:
|
||||
actual_sha256 = file_sha256_sum(check_file)
|
||||
|
||||
if actual_sha256 == expected_sha256:
|
||||
_LOGGER.debug("Skipping download of %s (sha256 match)", file_path)
|
||||
continue
|
||||
|
||||
try:
|
||||
# Download file, show progress with tqdm
|
||||
with urllib.request.urlopen(file_url) as response:
|
||||
with open(file_path, mode="wb") as dest_file:
|
||||
with tqdm(
|
||||
|
|
@ -101,7 +119,7 @@ def main():
|
|||
parser.add_argument(
|
||||
"key",
|
||||
nargs="*",
|
||||
help="Keys of voices to download (e.g., en_US/vctk_low)",
|
||||
help="Keys of voices to download (e.g., en_US/vctk_low). May contain wildcards (*)",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--output-dir",
|
||||
|
|
@ -113,6 +131,11 @@ def main():
|
|||
default=DEFAULT_VOICES_URL_FORMAT,
|
||||
help="URL format string for voices (contains {key}, {lang}, {name})",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--redownload",
|
||||
action="store_true",
|
||||
help="Force re-downloading of files if they already exist",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--debug", action="store_true", help="Print DEBUG messages to console"
|
||||
)
|
||||
|
|
@ -133,21 +156,45 @@ def main():
|
|||
json.dump(_VOICES, sys.stdout, indent=4, ensure_ascii=False)
|
||||
sys.exit(0)
|
||||
|
||||
args.key = [
|
||||
wildcard_to_regex(key, wildcard=_WILDCARD) if _WILDCARD in key else key
|
||||
for key in args.key
|
||||
]
|
||||
|
||||
args.output_dir.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
for voice_key in args.key:
|
||||
voice_lang, voice_name = voice_key.split("/", maxsplit=1)
|
||||
voice_info = _VOICES[voice_key]
|
||||
voice_url = str.format(
|
||||
args.url_format, key=voice_key, lang=voice_lang, name=voice_name
|
||||
)
|
||||
voice_files = voice_info["files"]
|
||||
download_voice(
|
||||
voice_key=voice_key,
|
||||
url_base=voice_url,
|
||||
voice_files=[VoiceFile(file_key) for file_key in voice_files.keys()],
|
||||
voices_dir=args.output_dir,
|
||||
)
|
||||
for key_or_pattern in args.key:
|
||||
if isinstance(key_or_pattern, re.Pattern):
|
||||
# Wildcards
|
||||
voice_keys = []
|
||||
for maybe_key in _VOICES.keys():
|
||||
if key_or_pattern.match(maybe_key):
|
||||
voice_keys.append(maybe_key)
|
||||
|
||||
_LOGGER.debug("%s matched %s", key_or_pattern, voice_keys)
|
||||
else:
|
||||
# No wildcards
|
||||
voice_keys = [key_or_pattern]
|
||||
|
||||
for voice_key in voice_keys:
|
||||
voice_lang, voice_name = voice_key.split("/", maxsplit=1)
|
||||
voice_info = _VOICES[voice_key]
|
||||
voice_url = str.format(
|
||||
args.url_format, key=voice_key, lang=voice_lang, name=voice_name
|
||||
)
|
||||
voice_files = voice_info["files"]
|
||||
|
||||
_LOGGER.info("Downloading %s", voice_key)
|
||||
download_voice(
|
||||
voice_key=voice_key,
|
||||
url_base=voice_url,
|
||||
voice_files=[
|
||||
VoiceFile(file_key, sha256_sum=file_info.get("sha256_sum"))
|
||||
for file_key, file_info in voice_files.items()
|
||||
],
|
||||
voices_dir=args.output_dir,
|
||||
redownload=args.redownload,
|
||||
)
|
||||
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
|
|
|
|||
|
|
@ -13,6 +13,11 @@
|
|||
# You should have received a copy of the GNU Affero General Public License
|
||||
# along with this program. If not, see <http://www.gnu.org/licenses/>.
|
||||
#
|
||||
"""Utility methods for Mimic 3"""
|
||||
import hashlib
|
||||
import re
|
||||
import typing
|
||||
|
||||
import numpy as np
|
||||
|
||||
|
||||
|
|
@ -24,3 +29,35 @@ def audio_float_to_int16(
|
|||
audio_norm = np.clip(audio_norm, -max_wav_value, max_wav_value)
|
||||
audio_norm = audio_norm.astype("int16")
|
||||
return audio_norm
|
||||
|
||||
|
||||
def wildcard_to_regex(template: str, wildcard: str = "*") -> re.Pattern:
|
||||
"""Convert a string with wildcards into a regex pattern"""
|
||||
wildcard_escaped = re.escape(wildcard)
|
||||
|
||||
pattern_parts = ["^"]
|
||||
for i, template_part in enumerate(re.split(f"({wildcard_escaped})", template)):
|
||||
if (i % 2) == 0:
|
||||
# Fixed string
|
||||
pattern_parts.append(re.escape(template_part))
|
||||
else:
|
||||
# Wildcard separator
|
||||
pattern_parts.append(".*")
|
||||
|
||||
pattern_parts.append("$")
|
||||
pattern_str = "".join(pattern_parts)
|
||||
|
||||
return re.compile(pattern_str)
|
||||
|
||||
|
||||
def file_sha256_sum(fp: typing.BinaryIO, block_bytes: int = 4096) -> str:
|
||||
"""Return the sha256 sum of a (possibly large) file"""
|
||||
current_hash = hashlib.sha256()
|
||||
|
||||
# Read in blocks in case file is very large
|
||||
block = fp.read(block_bytes)
|
||||
while len(block) > 0:
|
||||
current_hash.update(block)
|
||||
block = fp.read(block_bytes)
|
||||
|
||||
return current_hash.hexdigest()
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue