vocode-python/vocode/streaming/synthesizer/azure_synthesizer.py
import logging
import os
import re
from typing import Any, Optional
from xml.etree import ElementTree

import azure.cognitiveservices.speech as speechsdk
from dotenv import load_dotenv

from vocode.streaming.agent.bot_sentiment_analyser import BotSentiment
from vocode.streaming.models.audio_encoding import AudioEncoding
from vocode.streaming.models.message import BaseMessage, SSMLMessage
from vocode.streaming.models.synthesizer import AzureSynthesizerConfig
from vocode.streaming.synthesizer.base_synthesizer import (
    BaseSynthesizer,
    SynthesisResult,
    FILLER_PHRASES,
    FILLER_AUDIO_PATH,
    FillerAudio,
    encode_as_wav,
)

load_dotenv()
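
# SSML namespaces used when constructing <speak> documents below. Registering
# them up front keeps ElementTree from emitting auto-generated ns0:-style
# prefixes when the SSML is serialized.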
NAMESPACES = {
    "mstts": "https://www.w3.org/2001/mstts",
    "": "https://www.w3.org/2001/10/synthesis",
}

ElementTree.register_namespace("", NAMESPACES.get(""))
ElementTree.register_namespace("mstts", NAMESPACES.get("mstts"))


class WordBoundaryEventPool:
    """Collects word-boundary events emitted by Azure during synthesis."""

    def __init__(self):
        self.events = []

    def add(self, event):
        self.events.append(
            {
                "text": event.text,
                "text_offset": event.text_offset,
                # Azure reports audio_offset in 100-nanosecond ticks; convert
                # to seconds (the +5000 mirrors Azure's sample code for
                # millisecond rounding).
                "audio_offset": (event.audio_offset + 5000) / (10000 * 1000),
                "boundary_type": event.boundary_type,
            }
        )

    def get_events_sorted(self):
        return sorted(self.events, key=lambda event: event["audio_offset"])


class AzureSynthesizer(BaseSynthesizer):
    OFFSET_MS = 100

    def __init__(
        self,
        synthesizer_config: AzureSynthesizerConfig,
        logger: Optional[logging.Logger] = None,
    ):
        super().__init__(synthesizer_config)
        self.synthesizer_config = synthesizer_config
        # Instantiates a client
        speech_config = speechsdk.SpeechConfig(
            subscription=os.environ.get("AZURE_SPEECH_KEY"),
            region=os.environ.get("AZURE_SPEECH_REGION"),
        )
        # Pick the raw output format matching the configured encoding and
        # sampling rate.
        if self.synthesizer_config.audio_encoding == AudioEncoding.LINEAR16:
            if self.synthesizer_config.sampling_rate == 44100:
                speech_config.set_speech_synthesis_output_format(
                    speechsdk.SpeechSynthesisOutputFormat.Raw44100Hz16BitMonoPcm
                )
            elif self.synthesizer_config.sampling_rate == 48000:
                speech_config.set_speech_synthesis_output_format(
                    speechsdk.SpeechSynthesisOutputFormat.Raw48Khz16BitMonoPcm
                )
            elif self.synthesizer_config.sampling_rate == 24000:
                speech_config.set_speech_synthesis_output_format(
                    speechsdk.SpeechSynthesisOutputFormat.Raw24Khz16BitMonoPcm
                )
            elif self.synthesizer_config.sampling_rate == 16000:
                speech_config.set_speech_synthesis_output_format(
                    speechsdk.SpeechSynthesisOutputFormat.Raw16Khz16BitMonoPcm
                )
            elif self.synthesizer_config.sampling_rate == 8000:
                speech_config.set_speech_synthesis_output_format(
                    speechsdk.SpeechSynthesisOutputFormat.Raw8Khz16BitMonoPcm
                )
        elif self.synthesizer_config.audio_encoding == AudioEncoding.MULAW:
            speech_config.set_speech_synthesis_output_format(
                speechsdk.SpeechSynthesisOutputFormat.Raw8Khz8BitMonoMULaw
            )
        self.synthesizer = speechsdk.SpeechSynthesizer(
            speech_config=speech_config, audio_config=None
        )
        self.voice_name = self.synthesizer_config.voice_name
        self.pitch = self.synthesizer_config.pitch
        self.rate = self.synthesizer_config.rate
        self.logger = logger or logging.getLogger(__name__)

    def get_phrase_filler_audios(self) -> list[FillerAudio]:
        filler_phrase_audios = []
        for filler_phrase in FILLER_PHRASES:
            # Cache key covers every parameter that affects the rendered audio.
            cache_key = "-".join(
                (
                    str(filler_phrase.text),
                    str(self.synthesizer_config.type),
                    str(self.synthesizer_config.audio_encoding),
                    str(self.synthesizer_config.sampling_rate),
                    str(self.voice_name),
                    str(self.pitch),
                    str(self.rate),
                )
            )
            filler_audio_path = os.path.join(FILLER_AUDIO_PATH, f"{cache_key}.bytes")
            if os.path.exists(filler_audio_path):
                with open(filler_audio_path, "rb") as f:
                    audio_data = f.read()
            else:
                self.logger.debug(f"Generating filler audio for {filler_phrase.text}")
                ssml = self.create_ssml(filler_phrase.text)
                result = self.synthesizer.speak_ssml(ssml)
                # Skip roughly the first OFFSET_MS of audio to drop the
                # leading silence Azure tends to produce.
                offset = self.synthesizer_config.sampling_rate * self.OFFSET_MS // 1000
                audio_data = result.audio_data[offset:]
                with open(filler_audio_path, "wb") as f:
                    f.write(audio_data)
            filler_phrase_audios.append(
                FillerAudio(
                    filler_phrase,
                    audio_data,
                    self.synthesizer_config,
                )
            )
        return filler_phrase_audios

    def add_marks(self, message: str, index=0) -> str:
        """Recursively insert SSML <mark/> tags before each punctuation break."""
        search_result = re.search(r"([\.\,\:\;\-\—]+)", message)
        if search_result is None:
            return message
        start, end = search_result.span()
        with_mark = message[:start] + f'<mark name="{index}" />' + message[start:end]
        rest = message[end:]
        rest_stripped = re.sub(r"^(.+)([\.\,\:\;\-\—]+)$", r"\1", rest)
        if len(rest_stripped) == 0:
            return with_mark
        return with_mark + self.add_marks(rest_stripped, index + 1)

    def word_boundary_cb(self, evt, pool):
        pool.add(evt)

    def create_ssml(
        self, message: str, bot_sentiment: Optional[BotSentiment] = None
    ) -> str:
        ssml_root = ElementTree.fromstring(
            '<speak version="1.0" xmlns="https://www.w3.org/2001/10/synthesis" xml:lang="en-US"></speak>'
        )
        voice = ElementTree.SubElement(ssml_root, "voice")
        voice.set("name", self.voice_name)
        voice_root = voice
        if bot_sentiment and bot_sentiment.emotion:
            styled = ElementTree.SubElement(
                voice, "{%s}express-as" % NAMESPACES.get("mstts")
            )
            styled.set("style", bot_sentiment.emotion)
            styled.set(
                "styledegree", str(bot_sentiment.degree * 2)
            )  # Azure-specific: styledegree is a 0-2 scale
            voice_root = styled
        prosody = ElementTree.SubElement(voice_root, "prosody")
        prosody.set("pitch", f"{self.pitch}%")
        prosody.set("rate", f"{self.rate}%")
        prosody.text = message.strip()
        return ElementTree.tostring(ssml_root, encoding="unicode")
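
    # For reference, create_ssml("Hello there.") yields SSML shaped like:
    #   <speak version="1.0" xmlns="https://www.w3.org/2001/10/synthesis" xml:lang="en-US">
    #     <voice name="..."><prosody pitch="0%" rate="0%">Hello there.</prosody></voice>
    #   </speak>
    # with the voice name, pitch, and rate taken from the synthesizer config.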

    def synthesize_ssml(self, ssml: str) -> speechsdk.AudioDataStream:
        result = self.synthesizer.start_speaking_ssml_async(ssml).get()
        return speechsdk.AudioDataStream(result)

    def ready_synthesizer(self):
        # Pre-open the connection to the Azure service so the first synthesis
        # request does not pay connection-setup latency.
        connection = speechsdk.Connection.from_speech_synthesizer(self.synthesizer)
        connection.open(True)

    def get_message_up_to(
        self,
        message: str,
        ssml: str,
        seconds: int,
        word_boundary_event_pool: WordBoundaryEventPool,
    ) -> str:
        """Given the number of seconds the message was allowed to play, return
        the portion of the message that was actually spoken."""
        events = word_boundary_event_pool.get_events_sorted()
        for event in events:
            if event["audio_offset"] > seconds:
                ssml_fragment = ssml[: event["text_offset"]]
                return ssml_fragment.split(">")[-1]
        return message
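
    # For example, with boundary events at 0.1s ("Hello") and 0.9s ("world"),
    # cutting playback off at 0.5 seconds returns the SSML text content up to
    # the "world" boundary, i.e. roughly "Hello".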

    def create_speech(
        self,
        message: BaseMessage,
        chunk_size: int,
        bot_sentiment: Optional[BotSentiment] = None,
    ) -> SynthesisResult:
        # offset = int(self.OFFSET_MS * (self.synthesizer_config.sampling_rate / 1000))
        offset = 0
        self.logger.debug(f"Synthesizing message: {message}")

        def chunk_generator(
            audio_data_stream: speechsdk.AudioDataStream, chunk_transform=lambda x: x
        ):
            # read_data fills the buffer with up to chunk_size bytes; a short
            # read signals the end of the stream.
            audio_buffer = bytes(chunk_size)
            filled_size = audio_data_stream.read_data(audio_buffer)
            if filled_size != chunk_size:
                # Trim to filled_size so zero padding isn't emitted as audio.
                yield SynthesisResult.ChunkResult(
                    chunk_transform(audio_buffer[offset:filled_size]), True
                )
                return
            yield SynthesisResult.ChunkResult(
                chunk_transform(audio_buffer[offset:]), False
            )
            while True:
                filled_size = audio_data_stream.read_data(audio_buffer)
                if filled_size != chunk_size:
                    yield SynthesisResult.ChunkResult(
                        chunk_transform(audio_buffer[: filled_size - offset]), True
                    )
                    break
                yield SynthesisResult.ChunkResult(chunk_transform(audio_buffer), False)

        word_boundary_event_pool = WordBoundaryEventPool()
        self.synthesizer.synthesis_word_boundary.connect(
            lambda event: self.word_boundary_cb(event, word_boundary_event_pool)
        )
        ssml = (
            message.ssml
            if isinstance(message, SSMLMessage)
            else self.create_ssml(message.text, bot_sentiment=bot_sentiment)
        )
        audio_data_stream = self.synthesize_ssml(ssml)
        if self.synthesizer_config.should_encode_as_wav:
            output_generator = chunk_generator(
                audio_data_stream,
                lambda chunk: encode_as_wav(chunk, self.synthesizer_config),
            )
        else:
            output_generator = chunk_generator(audio_data_stream)
        return SynthesisResult(
            output_generator,
            lambda seconds: self.get_message_up_to(
                message, ssml, seconds, word_boundary_event_pool
            ),
        )
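

# A minimal usage sketch, not part of the original module. It assumes
# AZURE_SPEECH_KEY and AZURE_SPEECH_REGION are set in the environment (or in
# a .env file picked up by load_dotenv above), and that SynthesisResult
# exposes a chunk_generator attribute and ChunkResult carries chunk /
# is_last_chunk fields, as defined in base_synthesizer.
if __name__ == "__main__":
    config = AzureSynthesizerConfig(
        sampling_rate=16000,
        audio_encoding=AudioEncoding.LINEAR16,
    )
    synthesizer = AzureSynthesizer(config)
    synthesizer.ready_synthesizer()
    result = synthesizer.create_speech(
        BaseMessage(text="Hello from Azure."), chunk_size=4096
    )
    # Drain the generator and write the raw PCM to disk.
    with open("output.pcm", "wb") as f:
        for chunk_result in result.chunk_generator:
            f.write(chunk_result.chunk)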