Add deterministic mode for testing (the same text is guaranteed to produce the same audio)

This commit is contained in:
Michael Hansen 2022-04-20 10:44:35 -04:00
commit 503126da4a
8 changed files with 38 additions and 2 deletions

View file

@ -1 +1 @@
0.1.0
0.1.1

View file

@ -70,6 +70,11 @@ def get_app(args: argparse.Namespace, request_queue: Queue, temp_dir: str):
Returns: WAV bytes
"""
if args.deterministic:
# Disable noise
_LOGGER.debug("Disabling noise in deterministic mode")
params.noise_scale = 0.0
params.noise_w = 0.0
_LOGGER.debug(params)

View file

@ -70,6 +70,11 @@ def get_args(argv=None) -> argparse.Namespace:
action="store_true",
help="Use Onnx CUDA execution provider (requires onnxruntime-gpu)",
)
parser.add_argument(
"--deterministic",
action="store_true",
help="Ensure that the same audio is always synthesized from the same text",
)
parser.add_argument(
"--num-threads",
type=int,

View file

@ -85,6 +85,7 @@ def do_synthesis_proc(args: argparse.Namespace, request_queue: Queue):
noise_w=args.noise_w,
use_cuda=args.cuda,
voices_directories=args.voices_dir,
use_deterministic_compute=args.deterministic,
)
)

View file

@ -1 +1 @@
0.1.6
0.1.7

View file

@ -221,6 +221,12 @@ def initialize_args(state: CommandLineInterfaceState):
# Split apart voice
args.voice, args.speaker = args.voice.split("#", maxsplit=1)
if args.deterministic:
# Disable noise
_LOGGER.debug("Disabling noise in deterministic mode")
args.noise_scale = 0.0
args.noise_w = 0.0
def initialize_tts(state: CommandLineInterfaceState):
"""Create Mimic 3 TTS from command-line arguments"""
@ -232,9 +238,13 @@ def initialize_tts(state: CommandLineInterfaceState):
# Local TTS
state.tts = Mimic3TextToSpeechSystem(
Mimic3Settings(
length_scale=args.length_scale,
noise_scale=args.noise_scale,
noise_w=args.noise_w,
voices_directories=args.voices_dir,
speaker=args.speaker,
use_cuda=args.cuda,
use_deterministic_compute=args.deterministic,
)
)
@ -694,6 +704,11 @@ def get_args(argv=None):
action="store_true",
help="Use Onnx CUDA execution provider (requires onnxruntime-gpu)",
)
parser.add_argument(
"--deterministic",
action="store_true",
help="Ensure that the same audio is always synthesized from the same text",
)
parser.add_argument("--seed", type=int, help="Set random seed (default: not set)")
parser.add_argument("--version", action="store_true", help="Print version and exit")
parser.add_argument(

View file

@ -117,6 +117,9 @@ class Mimic3Settings:
rate: float = DEFAULT_RATE
"""Voice speaking rate (< 1 is slower, > 1 is faster)"""
use_deterministic_compute: bool = False
"""Force onnxruntime to use deterministic compute mode. For fully deterministic synthesis, also set noise_scale and noise_w to 0."""
@dataclass
class Mimic3Phonemes:
@ -543,6 +546,7 @@ class Mimic3TextToSpeechSystem(TextToSpeechSystem):
model_dir,
providers=providers,
share_models=self.settings.share_onnx_models_between_threads,
use_deterministic_compute=self.settings.use_deterministic_compute,
)
_LOGGER.info("Loaded voice from %s", model_dir)

View file

@ -252,6 +252,7 @@ class Mimic3Voice(metaclass=ABCMeta):
]
] = None,
share_models: bool = True,
use_deterministic_compute: bool = False,
) -> "Mimic3Voice":
"""Load a Mimic 3 voice from a directory"""
voice_dir = Path(voice_dir)
@ -283,6 +284,7 @@ class Mimic3Voice(metaclass=ABCMeta):
generator_path,
session_options=session_options,
providers=providers,
use_deterministic_compute=use_deterministic_compute,
)
Mimic3Voice._SHARED_MODELS[model_key] = onnx_model
@ -293,6 +295,7 @@ class Mimic3Voice(metaclass=ABCMeta):
generator_path,
session_options=session_options,
providers=providers,
use_deterministic_compute=use_deterministic_compute,
)
# phoneme -> phoneme, phoneme, ...
@ -381,6 +384,7 @@ class Mimic3Voice(metaclass=ABCMeta):
typing.Union[str, typing.Tuple[str, typing.Dict[str, typing.Any]]]
]
] = None,
use_deterministic_compute: bool = False,
) -> onnxruntime.InferenceSession:
_LOGGER.debug("Loading model from %s", generator_path)
@ -394,6 +398,8 @@ class Mimic3Voice(metaclass=ABCMeta):
onnxruntime.GraphOptimizationLevel.ORT_DISABLE_ALL
)
session_options.use_deterministic_compute = use_deterministic_compute
onnx_model = onnxruntime.InferenceSession(
str(generator_path), sess_options=session_options, providers=providers
)