# mimic3/mimic3_http/app.py

# 361 lines
# 12 KiB
# Python

# Copyright 2022 Mycroft AI Inc.
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU Affero General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU Affero General Public License for more details.
#
# You should have received a copy of the GNU Affero General Public License
# along with this program. If not, see <http://www.gnu.org/licenses/>.
#
import argparse
import asyncio
import dataclasses
import json
import logging
import re
import shlex
import subprocess
import typing
from pathlib import Path
from queue import Queue
from urllib.parse import parse_qs
from uuid import uuid4
import quart_cors
from quart import (
Quart,
Response,
jsonify,
render_template,
request,
send_from_directory,
)
from swagger_ui import api_doc
from mimic3_tts import DEFAULT_VOICE, Mimic3Settings, Mimic3TextToSpeechSystem
from mimic3_tts.download import is_voice_downloaded
from mimic3_tts.utils import LANG_NAMES, SAMPLE_SENTENCES
from ._resources import _DIR, _PACKAGE
from .args import _MISSING
from .const import SynthesisRequest, TextToWavParams
_LOGGER = logging.getLogger(__name__)
def get_app(args: argparse.Namespace, request_queue: Queue, temp_dir: str):
    """Build the Quart application for the Mimic 3 HTTP server.

    Args:
        args: parsed command-line options (``voices_dir`` and ``cache_dir``
            are read here).
        request_queue: queue onto which synthesis requests are placed.
        temp_dir: directory used as the WAV cache when ``args.cache_dir``
            is ``None``.

    Returns:
        The configured application object.
    """
    _MIMIC3 = Mimic3TextToSpeechSystem(
        Mimic3Settings(voices_directories=args.voices_dir)
    )

    # Stays None when args.cache_dir is the _MISSING sentinel
    _TEMP_DIR: typing.Optional[Path] = None

    if args.cache_dir != _MISSING:
        if args.cache_dir is not None:
            # Use user-supplied cache directory
            _TEMP_DIR = Path(args.cache_dir)
            _TEMP_DIR.mkdir(parents=True, exist_ok=True)
        else:
            # Use temporary directory
            _TEMP_DIR = Path(temp_dir)

    if _TEMP_DIR:
        _LOGGER.debug("Cache directory: %s", _TEMP_DIR)
async def text_to_wav(params: TextToWavParams, no_cache: bool = False) -> bytes:
    """Turn the given synthesis parameters into WAV audio.

    A cached WAV is returned when a cache directory is configured, caching
    is not disabled for this call, and a file for ``params.cache_key``
    exists. Otherwise a ``SynthesisRequest`` is put on ``request_queue`` and
    the coroutine waits on the attached future for the audio.

    Args:
        params: synthesis parameters; the noise settings are overwritten
            with 0.0 when ``args.deterministic`` is set.
        no_cache: when true, the cache is neither read nor written.

    Returns: WAV bytes
    """
    if args.deterministic:
        # Disable noise
        _LOGGER.debug("Disabling noise in deterministic mode")
        params.noise_scale = 0.0
        params.noise_w = 0.0

    _LOGGER.debug(params)

    use_cache = bool(_TEMP_DIR) and (not no_cache)

    if use_cache:
        # Look up in cache
        cached_wav = _TEMP_DIR / f"{params.cache_key}.wav"
        if cached_wav.is_file():
            _LOGGER.debug("Loading WAV from cache: %s", cached_wav)
            return cached_wav.read_bytes()

    # Queue the request; the future carries the result back to this loop
    event_loop = asyncio.get_running_loop()
    pending = event_loop.create_future()
    synthesis_request = SynthesisRequest(
        params=params,
        loop=event_loop,
        future=pending,
    )
    request_queue.put_nowait(synthesis_request)
    wav_bytes = await pending

    if not use_cache:
        return wav_bytes

    # Store in cache
    target = _TEMP_DIR / f"{params.cache_key}.wav"
    target.parent.mkdir(parents=True, exist_ok=True)
    target.write_bytes(wav_bytes)
    _LOGGER.debug("Cached WAV at %s", target.absolute())

    return wav_bytes
# -----------------------------------------------------------------------------
# Resource directories shipped with the package
_TEMPLATES_DIR = _DIR / "templates"
_CSS_DIR = _DIR / "css"
_IMG_DIR = _DIR / "img"

# -----------------------------------------------------------------------------
app = Quart(_PACKAGE, template_folder=str(_TEMPLATES_DIR))

# A fresh random secret key is generated every time the app is created
app.secret_key = str(uuid4())

if args.debug:
    app.config["TEMPLATES_AUTO_RELOAD"] = True

# Wrap the app with quart_cors using its default settings
app = quart_cors.cors(app)
def _to_bool(s: str) -> bool:
    """Interpret a textual flag as a boolean.

    Surrounding whitespace and letter case are ignored. Only "true", "1",
    "yes" and "on" count as true; every other value (including the empty
    string) is false.
    """
    normalized = s.strip().lower()
    return normalized in ("true", "1", "yes", "on")
class VoiceEncoder(json.JSONEncoder):
    """JSON encoder that additionally serializes sets as lists."""

    def default(self, o):
        """Convert a set to a list; defer everything else to the base class."""
        if not isinstance(o, set):
            # Base implementation raises TypeError for unsupported objects
            return super().default(o)

        return list(o)
# Install VoiceEncoder (which turns sets into lists) as the app's JSON encoder.
# NOTE(review): presumably this is what jsonify() uses; confirm for the Quart
# version in use.
app.json_encoder = VoiceEncoder # type: ignore
@app.route("/img/<path:filename>", methods=["GET"])
async def img(filename) -> Response:
    """Serve a file from the package's image directory."""
    image_response = await send_from_directory(_IMG_DIR, filename)
    return image_response
@app.route("/css/<path:filename>", methods=["GET"])
async def css(filename) -> Response:
    """Serve a file from the package's CSS directory."""
    stylesheet_response = await send_from_directory(_CSS_DIR, filename)
    return stylesheet_response
# Initial value; read by app_index() each time the page is rendered
show_openapi = True


@app.route("/")
async def app_index():
    """Render the main page from the ``index.html`` template."""
    template_context = {
        "show_openapi": show_openapi,
        "max_text_length": args.max_text_length,
        "default_voice": args.default_voice,
    }
    return await render_template("index.html", **template_context)
@app.route("/api/tts", methods=["GET", "POST"])
async def app_tts() -> typing.Union[Response, str]:
    """Speak text to WAV.

    The text comes from the POST body or, for GET, from the ``text`` query
    argument. Recognized query arguments:

    * ``voice``: voice to use (falls back to ``args.voice``, then
      ``DEFAULT_VOICE``)
    * ``noiseScale``, ``noiseW``, ``lengthScale``: float overrides of the
      server-wide settings
    * ``ssml``: boolean flag; SSML is also enabled when the request content
      type is ``application/ssml+xml``
    * ``textLanguage``, ``cacheId``: passed through to the synthesis params
    * ``noCache``: boolean flag that bypasses the WAV cache
    * ``audioTarget``: ``client`` (default) returns the audio; any other
      value plays it on the server with ``args.play_program``

    Returns:
        A WAV response, or the string ``"OK"`` after server-side playback.

    Raises:
        ValueError: if no text was provided, or a float override cannot be
            parsed.
        subprocess.CalledProcessError: if the play program exits non-zero.
    """
    _LOGGER.debug("Request args: %s", request.args)

    voice = request.args.get("voice") or args.voice or DEFAULT_VOICE
    tts_args: typing.Dict[str, typing.Any] = {
        "length_scale": args.length_scale,
        "noise_scale": args.noise_scale,
        "noise_w": args.noise_w,
        "voice": str(voice),
    }

    # TTS settings (empty values leave the server default in place)
    float_overrides = (
        ("noiseScale", "noise_scale"),
        ("noiseW", "noise_w"),
        ("lengthScale", "length_scale"),
    )
    for query_name, param_name in float_overrides:
        raw_value = request.args.get(query_name)
        if raw_value:
            tts_args[param_name] = float(raw_value)

    # Set SSML flag either from arg or content type
    ssml_str = request.args.get("ssml")
    if ssml_str:
        tts_args["ssml"] = _to_bool(ssml_str)
    elif request.content_type == "application/ssml+xml":
        tts_args["ssml"] = True

    text_language = request.args.get("textLanguage")
    if text_language:
        tts_args["text_language"] = str(text_language)

    # Id used for cache
    cache_id = request.args.get("cacheId")
    if cache_id:
        tts_args["cache_id"] = str(cache_id)

    # Text can come from POST body or GET ?text arg
    if request.method == "POST":
        text = (await request.data).decode()
    else:
        text = request.args.get("text", "")

    # Explicit check instead of `assert`, which is stripped under `python -O`
    if not text:
        raise ValueError("No text provided")

    if args.max_text_length is not None:
        text = text[: args.max_text_length]

    # Cache settings
    no_cache = _to_bool(request.args.get("noCache", ""))

    wav_bytes = await text_to_wav(
        TextToWavParams(text=text, **tts_args), no_cache=no_cache
    )

    audio_target = request.args.get("audioTarget", "client").strip().lower()
    if audio_target == "client":
        return Response(wav_bytes, mimetype="audio/wav")

    # Play audio on server: WAV data is piped to the play program's stdin
    play_cmd = shlex.split(args.play_program)
    subprocess.run(play_cmd, input=wav_bytes, check=True)

    return "OK"
@app.route("/api/voices", methods=["GET"])
async def api_voices():
    """Return all known voices as a JSON list, sorted by voice key.

    Each voice is converted to a dict and extended with
    ``language_native``, ``language_english`` and ``sample_text``.
    """
    # Keyed by voice key, so a repeated key keeps only the last voice
    voices_by_key = {v.key: v for v in _MIMIC3.get_voices()}
    voice_dicts = [
        dataclasses.asdict(voices_by_key[key]) for key in sorted(voices_by_key)
    ]

    # Add more fields to voices
    for info in voice_dicts:
        full_lang = info["language"]

        # en_US => English (US); unknown languages fall back to the code
        names = LANG_NAMES.get(full_lang, full_lang)
        if not isinstance(names, str):
            # Native and English language name are different
            info["language_native"], info["language_english"] = names
        else:
            # Native and English language name are the same
            info["language_native"] = names
            info["language_english"] = names

        # en_US => en
        lang_prefix = full_lang.split("_", maxsplit=1)[0]

        # Collapse every whitespace run in the sample into a single space
        raw_sample = SAMPLE_SENTENCES.get(lang_prefix, "")
        info["sample_text"] = re.sub(r"\s+", " ", raw_sample)

    return jsonify(voice_dicts)
@app.route("/process", methods=["GET", "POST"])
async def api_marytts_process():
    """MaryTTS-compatible /process endpoint.

    Reads ``INPUT_TEXT`` and an optional ``VOICE`` from the form-encoded POST
    body or from the GET query string, synthesizes the text with the
    server-wide length/noise settings and returns the audio.

    When ``VOICE`` is missing or blank, ``args.voice`` is used, then
    ``DEFAULT_VOICE``.

    Returns:
        A response with the WAV bytes and mimetype ``audio/wav``.
    """
    if request.method == "POST":
        data = parse_qs((await request.data).decode())
        text = data.get("INPUT_TEXT", [""])[0]
        requested_voice = data.get("VOICE", [""])[0]
    else:
        text = request.args.get("INPUT_TEXT", "")
        requested_voice = request.args.get("VOICE", "")

    if args.max_text_length is not None:
        text = text[: args.max_text_length]

    # Default to "" rather than args.voice before str(): str(None) would
    # produce the truthy string "None" and defeat the fallback chain.
    voice = str(requested_voice).strip() or args.voice or DEFAULT_VOICE

    # Assume SSML if text begins with an angle bracket
    ssml = text.strip().startswith("<")

    _LOGGER.debug("Speaking with voice '%s': %s", voice, text)
    wav_bytes = await text_to_wav(
        TextToWavParams(
            text=text,
            voice=voice,
            ssml=ssml,
            length_scale=args.length_scale,
            noise_scale=args.noise_scale,
            noise_w=args.noise_w,
        )
    )

    return Response(wav_bytes, mimetype="audio/wav")
@app.route("/voices", methods=["GET"])
async def api_marytts_voices():
    """MaryTTS-compatible /voices endpoint.

    Returns one line per downloaded voice, sorted by voice key, in the form
    ``<voice> <language> <gender> <tech>``. Multi-speaker voices get one
    ``<voice>#<speaker>`` line per speaker.
    """
    gender = "NA"  # don't have this information for every speaker yet
    tech = "vits"

    unique_voices = {v.key: v for v in _MIMIC3.get_voices()}

    # [voice] [language] [gender] [tech=hmm]
    lines = []
    for key in sorted(unique_voices):
        voice = unique_voices[key]

        # Skip voices that are not yet installed
        if not is_voice_downloaded(voice.location):
            continue

        if voice.is_multispeaker:
            # List each speaker separately
            names = [f"{voice.key}#{speaker}" for speaker in voice.speakers]
        else:
            names = [f"{voice.key}"]

        lines.extend(f"{name} {voice.language} {gender} {tech}" for name in names)

    return "\n".join(lines)
@app.route("/api/healthcheck", methods=["GET"])
async def api_healthcheck():
    """Health probe: always answers with the plain text ``OK``."""
    status = "OK"
    return status
# Swagger UI
# The flag ends up true only if the user did not opt out and setup succeeded.
show_openapi = False
if not args.no_show_openapi:
    swagger_config = _DIR / "swagger.yaml"
    try:
        api_doc(
            app,
            config_path=swagger_config,
            url_prefix="/openapi",
            title="Mimic 3",
        )
    except Exception:
        # Fails with PyInstaller for some reason
        _LOGGER.exception("Error setting up swagger UI page")
    else:
        show_openapi = True
@app.errorhandler(Exception)
async def handle_error(err) -> typing.Tuple[str, int]:
    """Log the exception and report it as plain text with status 500."""
    _LOGGER.exception(err)

    error_name = type(err).__name__
    message = f"{error_name}: {err}"
    return (message, 500)
# Hand the configured application back to the caller of get_app()
return app