"""Streaming transcriber that forwards audio to Deepgram over a websocket."""

import asyncio
import audioop
import json
import logging
from typing import Optional
from urllib.parse import urlencode

import websockets
from websockets.client import WebSocketClientProtocol

from vocode import getenv
from vocode.streaming.models.audio_encoding import AudioEncoding
from vocode.streaming.models.transcriber import (
    DeepgramTranscriberConfig,
    EndpointingConfig,
    EndpointingType,
)
from vocode.streaming.transcriber.base_transcriber import (
    BaseTranscriber,
    Transcription,
)


# Characters that count as the end of a sentence when punctuation-based
# endpointing decides whether an utterance is complete.
PUNCTUATION_TERMINATORS = [".", "!", "?"]

# Upper bound on how many times DeepgramTranscriber.run() will (re)open the
# Deepgram connection before giving up.
NUM_RESTARTS = 5


class DeepgramTranscriber(BaseTranscriber):
    """Transcriber that streams audio to Deepgram's ``/v1/listen`` websocket.

    Audio chunks are queued with :meth:`send_audio`, forwarded to Deepgram by
    :meth:`run` / :meth:`process`, and every transcription that comes back is
    passed to ``self.on_response`` wrapped in a ``Transcription``.
    """

    def __init__(
        self,
        transcriber_config: DeepgramTranscriberConfig,
        logger: Optional[logging.Logger] = None,
        api_key: Optional[str] = None,
    ):
        """Store the configuration and resolve the Deepgram API key.

        Args:
            transcriber_config: Deepgram-specific transcriber settings.
            logger: Logger to use; defaults to this module's logger.
            api_key: Deepgram API key; falls back to the ``DEEPGRAM_API_KEY``
                value returned by ``getenv``.

        Raises:
            Exception: If no API key is passed and none is configured.
        """
        super().__init__(transcriber_config)
        self.api_key = api_key or getenv("DEEPGRAM_API_KEY")
        if not self.api_key:
            raise Exception(
                "Please set DEEPGRAM_API_KEY environment variable or pass it as a parameter"
            )
        self.transcriber_config = transcriber_config
        self._ended = False
        self.warmed_up = False
        self.is_ready = False
        self.logger = logger or logging.getLogger(__name__)
        # Resampler state carried from one send_audio() call to the next so
        # that downsampling is continuous across chunk boundaries.
        self._ratecv_state = None

    def create_warmup_chunks(self):
        """Split the warm-up audio into ``chunk_size``-byte pieces.

        Only whole chunks are produced; a trailing partial chunk is dropped.

        Returns:
            A list of byte chunks, each exactly ``chunk_size`` bytes long.
        """
        warmup_bytes = self.get_warmup_bytes()
        chunk_size = self.transcriber_config.chunk_size
        return [
            warmup_bytes[index * chunk_size : (index + 1) * chunk_size]
            for index in range(len(warmup_bytes) // chunk_size)
        ]

    async def ready(self):
        """Wait until warm-up has finished, then report readiness.

        Polls every 0.1 seconds. Also stops waiting once the transcriber has
        been terminated, so a caller is never left polling forever.

        Returns:
            ``self.is_ready`` -- still ``False`` if :meth:`terminate` was
            called before warm-up completed.
        """
        while not self.warmed_up and not self._ended:
            await asyncio.sleep(0.1)
        return self.is_ready

    async def run(self):
        """Run the transcription session, reconnecting when the socket dies.

        :meth:`process` is invoked at most ``NUM_RESTARTS`` times, and the
        loop stops as soon as :meth:`terminate` has been called.
        """
        restarts = 0
        while not self._ended and restarts < NUM_RESTARTS:
            await self.process(self.transcriber_config.should_warmup_model)
            if self._ended:
                # Normal shutdown, not a dropped connection.
                break
            restarts += 1
            if restarts < NUM_RESTARTS:
                self.logger.debug(
                    "Deepgram connection died, restarting, num_restarts: %s", restarts
                )
            else:
                self.logger.error(
                    "Deepgram connection died, giving up after %s attempts", restarts
                )

    def send_audio(self, chunk):
        """Queue one audio chunk for delivery to Deepgram.

        When ``downsampling`` is configured and the encoding is LINEAR16, the
        chunk is first resampled from ``sampling_rate * downsampling`` down to
        ``sampling_rate``.

        Args:
            chunk: Raw audio bytes.
        """
        if (
            self.transcriber_config.downsampling
            and self.transcriber_config.audio_encoding == AudioEncoding.LINEAR16
        ):
            target_rate = self.transcriber_config.sampling_rate
            # NOTE: audioop is deprecated since Python 3.11 and removed in 3.13.
            # The converter state must be fed back into the next call;
            # passing None every time restarts the filter on each chunk.
            chunk, self._ratecv_state = audioop.ratecv(
                chunk,
                2,  # sample width in bytes
                1,  # number of channels
                target_rate * self.transcriber_config.downsampling,
                target_rate,
                self._ratecv_state,
            )
        self.audio_queue.put_nowait(chunk)

    def terminate(self):
        """Mark the transcriber as ended and queue a ``CloseStream`` message.

        Safe to call before :meth:`process` has created the audio queue; in
        that case only the ended flag is set.
        """
        self._ended = True
        audio_queue = getattr(self, "audio_queue", None)
        if audio_queue is not None:
            audio_queue.put_nowait(json.dumps({"type": "CloseStream"}))

    def get_deepgram_url(self):
        """Build the Deepgram websocket URL from the transcriber config.

        Returns:
            The ``wss://api.deepgram.com/v1/listen`` URL with query parameters.

        Raises:
            ValueError: If the configured audio encoding is neither LINEAR16
                nor MULAW.
        """
        audio_encoding = self.transcriber_config.audio_encoding
        if audio_encoding == AudioEncoding.LINEAR16:
            encoding = "linear16"
        elif audio_encoding == AudioEncoding.MULAW:
            encoding = "mulaw"
        else:
            raise ValueError(
                f"Unsupported audio encoding for Deepgram: {audio_encoding}"
            )

        url_params = {
            "encoding": encoding,
            "sample_rate": self.transcriber_config.sampling_rate,
            "channels": 1,
            "interim_results": "true",
        }
        # Optional settings are only sent when they are configured.
        if self.transcriber_config.model:
            url_params["model"] = self.transcriber_config.model
        if self.transcriber_config.tier:
            url_params["tier"] = self.transcriber_config.tier
        if self.transcriber_config.version:
            url_params["version"] = self.transcriber_config.version
        endpointing_config = self.transcriber_config.endpointing_config
        if (
            endpointing_config
            and endpointing_config.type == EndpointingType.PUNCTUATION_BASED
        ):
            url_params["punctuate"] = "true"
        return f"wss://api.deepgram.com/v1/listen?{urlencode(url_params)}"

    def _silence_cutoff_reached(
        self,
        transcript: str,
        current_buffer: str,
        deepgram_response: dict,
        time_silent: float,
    ) -> bool:
        """Return True when buffered speech is followed by enough silence.

        That is: this response has no transcript, there is buffered text to
        send, and the accumulated silence plus this response's duration is
        greater than the configured ``time_cutoff_seconds``.
        """
        if transcript or not current_buffer:
            return False
        cutoff = self.transcriber_config.endpointing_config.time_cutoff_seconds
        return (time_silent + deepgram_response["duration"]) > cutoff

    def is_speech_final(
        self, current_buffer: str, deepgram_response: dict, time_silent: float
    ) -> bool:
        """Decide whether the current utterance is complete.

        Args:
            current_buffer: Text accumulated so far for this utterance.
            deepgram_response: Parsed Deepgram result message.
            time_silent: Seconds of silence accumulated before this message.

        Returns:
            ``True`` if the buffered utterance should be emitted as final.

        Raises:
            Exception: If the endpointing type is not supported.
        """
        transcript = deepgram_response["channel"]["alternatives"][0]["transcript"]
        endpointing_config = self.transcriber_config.endpointing_config

        # Without an endpointing config, rely on Deepgram's own speech_final
        # flag, provided there is a transcript.
        if not endpointing_config:
            return bool(transcript and deepgram_response["speech_final"])

        if endpointing_config.type == EndpointingType.TIME_BASED:
            return self._silence_cutoff_reached(
                transcript, current_buffer, deepgram_response, time_silent
            )

        if endpointing_config.type == EndpointingType.PUNCTUATION_BASED:
            # endswith() is safe for a whitespace-only transcript, where
            # indexing the stripped string would raise IndexError.
            ends_sentence = bool(
                transcript
                and deepgram_response["speech_final"]
                and transcript.strip().endswith(tuple(PUNCTUATION_TERMINATORS))
            )
            return ends_sentence or self._silence_cutoff_reached(
                transcript, current_buffer, deepgram_response, time_silent
            )

        raise Exception("Endpointing config not supported")

    def calculate_time_silent(self, data: dict):
        """Return the silence at the tail of a Deepgram result message.

        Args:
            data: Parsed Deepgram result with ``start``, ``duration`` and a
                ``words`` list on the first alternative.

        Returns:
            The time between the end of the last word and the end of the
            message, or the whole ``duration`` when there are no words.
        """
        end = data["start"] + data["duration"]
        words = data["channel"]["alternatives"][0]["words"]
        if words:
            return end - words[-1]["end"]
        return data["duration"]

    async def process(self, warmup=True):
        """Open one Deepgram connection and pump audio and transcriptions.

        Creates a fresh audio queue, connects, then runs three coroutines
        concurrently until all of them finish: warm-up, sender and receiver.

        Args:
            warmup: Whether to send the warm-up audio before marking the
                transcriber as warmed up.
        """
        extra_headers = {"Authorization": f"Token {self.api_key}"}
        self.audio_queue = asyncio.Queue()

        async with websockets.connect(
            self.get_deepgram_url(), extra_headers=extra_headers
        ) as ws:

            async def warmup_sender(ws: WebSocketClientProtocol):
                """Send warm-up audio if requested, then flag readiness."""
                if warmup:
                    for chunk in self.create_warmup_chunks():
                        await ws.send(chunk)
                    # NOTE(review): presumably gives Deepgram time to work
                    # through the warm-up audio -- confirm the 5 s value.
                    await asyncio.sleep(5)
                self.warmed_up = True
                self.is_ready = True

            async def sender(ws: WebSocketClientProtocol):
                """Forward queued audio to the websocket."""
                try:
                    while not self._ended:
                        try:
                            data = await asyncio.wait_for(self.audio_queue.get(), 5)
                        except asyncio.TimeoutError:
                            # No audio for 5 seconds: stop sending.
                            break
                        await ws.send(data)
                    # terminate() may have run while a send was in flight;
                    # flush what is left so the CloseStream message still
                    # reaches Deepgram.
                    while not self.audio_queue.empty():
                        await ws.send(self.audio_queue.get_nowait())
                except websockets.exceptions.ConnectionClosed as e:
                    # Handled like the receiver, so that run() can restart
                    # instead of the whole task failing.
                    self.logger.debug("Got error %s in Deepgram sender", e)
                self.logger.debug("Terminating Deepgram transcriber sender")

            async def receiver(ws: WebSocketClientProtocol):
                """Read Deepgram results and emit transcriptions."""
                buffer = ""
                time_silent = 0
                while not self._ended:
                    try:
                        msg = await ws.recv()
                    except Exception as e:
                        self.logger.debug("Got error %s in Deepgram receiver", e)
                        break
                    data = json.loads(msg)
                    if "is_final" not in data:
                        # means we've finished receiving transcriptions
                        break
                    is_final = data["is_final"]
                    # Evaluated against the buffer *before* this message's
                    # transcript is appended.
                    speech_final = self.is_speech_final(buffer, data, time_silent)
                    top_choice = data["channel"]["alternatives"][0]
                    confidence = top_choice["confidence"]
                    # Results that arrive before warm-up completes are ignored.
                    has_speech = bool(
                        top_choice["transcript"]
                        and confidence > 0.0
                        and self.warmed_up
                    )

                    if has_speech and is_final:
                        buffer = f"{buffer} {top_choice['transcript']}"

                    if speech_final:
                        await self.on_response(Transcription(buffer, confidence, True))
                        buffer = ""
                        time_silent = 0
                    elif has_speech:
                        await self.on_response(
                            Transcription(buffer, confidence, False)
                        )
                        time_silent = self.calculate_time_silent(data)
                    else:
                        time_silent += data["duration"]

                self.logger.debug("Terminating Deepgram transcriber receiver")

            await asyncio.gather(warmup_sender(ws), sender(ws), receiver(ws))