# vocode-python/vocode/streaming/transcriber/deepgram_transcriber.py
#
# 228 lines
# 8.5 KiB
# Python

import asyncio
import json
import logging
import websockets
from websockets.client import WebSocketClientProtocol
import audioop
from urllib.parse import urlencode
from vocode import getenv
from vocode.streaming.transcriber.base_transcriber import (
BaseTranscriber,
Transcription,
)
from vocode.streaming.models.transcriber import (
DeepgramTranscriberConfig,
EndpointingConfig,
EndpointingType,
)
from vocode.streaming.models.audio_encoding import AudioEncoding
# Characters that count as sentence-ending punctuation when the transcriber
# uses punctuation-based endpointing (checked against the last character of
# the transcript in DeepgramTranscriber.is_speech_final).
PUNCTUATION_TERMINATORS = [
    ".",
    "!",
    "?",
]

# Upper bound on how many times DeepgramTranscriber.run() will call process()
# before giving up.
NUM_RESTARTS = 5
class DeepgramTranscriber(BaseTranscriber):
    """Streaming transcriber that forwards audio to Deepgram over a websocket.

    Audio chunks passed to :meth:`send_audio` are placed on ``self.audio_queue``
    and sent to Deepgram by :meth:`process`. Parsed results are delivered by
    awaiting ``self.on_response`` with a ``Transcription`` (``on_response`` is
    not defined in this class; presumably provided by ``BaseTranscriber``).
    """

    def __init__(
        self,
        transcriber_config: DeepgramTranscriberConfig,
        logger: logging.Logger = None,
    ):
        """Store the configuration and reset the lifecycle flags.

        Args:
            transcriber_config: Deepgram-specific transcriber settings.
            logger: Logger to use; defaults to this module's logger.
        """
        super().__init__(transcriber_config)
        self.api_key = getenv("DEEPGRAM_API_KEY")
        self.transcriber_config = transcriber_config
        self._ended = False
        self.warmed_up = False
        self.is_ready = False
        self.logger = logger or logging.getLogger(__name__)
        # Resampler state carried from one send_audio() call to the next so
        # that downsampling is continuous across chunk boundaries.
        self._ratecv_state = None

    def create_warmup_chunks(self):
        """Split the warm-up audio into full ``chunk_size``-byte pieces.

        Returns:
            A list of byte slices of ``self.get_warmup_bytes()``; a trailing
            remainder shorter than ``chunk_size`` is dropped.
        """
        warmup_bytes = self.get_warmup_bytes()
        chunk_size = self.transcriber_config.chunk_size
        return [
            warmup_bytes[i * chunk_size : (i + 1) * chunk_size]
            for i in range(len(warmup_bytes) // chunk_size)
        ]

    async def ready(self):
        """Wait until warm-up has finished, then return ``self.is_ready``.

        Polls ``self.warmed_up`` every 0.1 seconds.
        """
        while not self.warmed_up:
            await asyncio.sleep(0.1)
        return self.is_ready

    async def run(self):
        """Run :meth:`process` repeatedly until terminated or out of restarts.

        ``process`` is re-invoked each time it returns, at most
        ``NUM_RESTARTS`` times in total, unless :meth:`terminate` has been
        called in the meantime.
        """
        restarts = 0
        while not self._ended and restarts < NUM_RESTARTS:
            await self.process(self.transcriber_config.should_warmup_model)
            if self._ended:
                # Deliberate shutdown, not a dropped connection: do not count
                # or log it as a restart.
                break
            restarts += 1
            self.logger.debug(
                "Deepgram connection died, restarting, num_restarts: %s", restarts
            )

    def send_audio(self, chunk):
        """Queue one audio chunk for sending, downsampling it first if configured.

        When ``downsampling`` is set and the encoding is LINEAR16, the chunk is
        resampled from ``sampling_rate * downsampling`` to ``sampling_rate``.

        Args:
            chunk: Raw audio bytes.
        """
        if (
            self.transcriber_config.downsampling
            and self.transcriber_config.audio_encoding == AudioEncoding.LINEAR16
        ):
            # Feed the previous call's converter state back in; passing None
            # every time would restart the resampler on every chunk.
            chunk, self._ratecv_state = audioop.ratecv(
                chunk,
                2,  # sample width in bytes
                1,  # number of channels
                self.transcriber_config.sampling_rate
                * self.transcriber_config.downsampling,
                self.transcriber_config.sampling_rate,
                self._ratecv_state,
            )
        self.audio_queue.put_nowait(chunk)

    def terminate(self):
        """Queue a ``CloseStream`` message and mark the transcriber as ended."""
        terminate_msg = json.dumps({"type": "CloseStream"})
        self.audio_queue.put_nowait(terminate_msg)
        self._ended = True

    def get_deepgram_url(self):
        """Build the Deepgram streaming URL from the transcriber configuration.

        Returns:
            The ``wss://api.deepgram.com/v1/listen`` URL with its query string.

        Raises:
            ValueError: If the configured audio encoding is neither LINEAR16
                nor MULAW.
        """
        audio_encoding = self.transcriber_config.audio_encoding
        if audio_encoding == AudioEncoding.LINEAR16:
            encoding = "linear16"
        elif audio_encoding == AudioEncoding.MULAW:
            encoding = "mulaw"
        else:
            # Previously fell through and failed with UnboundLocalError below.
            raise ValueError(f"Audio encoding {audio_encoding} not supported")

        url_params = {
            "encoding": encoding,
            "sample_rate": self.transcriber_config.sampling_rate,
            "channels": 1,
            "interim_results": "true",
        }
        # Optional parameters are only sent when set in the configuration.
        if self.transcriber_config.model:
            url_params["model"] = self.transcriber_config.model
        if self.transcriber_config.tier:
            url_params["tier"] = self.transcriber_config.tier
        if self.transcriber_config.version:
            url_params["version"] = self.transcriber_config.version
        endpointing_config = self.transcriber_config.endpointing_config
        if (
            endpointing_config
            and endpointing_config.type == EndpointingType.PUNCTUATION_BASED
        ):
            url_params["punctuate"] = "true"
        return f"wss://api.deepgram.com/v1/listen?{urlencode(url_params)}"

    def is_speech_final(
        self, current_buffer: str, deepgram_response: dict, time_silent: float
    ):
        """Decide whether the buffered speech should be emitted as final.

        Args:
            current_buffer: Transcript text accumulated so far.
            deepgram_response: One parsed result message from Deepgram.
            time_silent: Seconds of silence accumulated before this message.

        Returns:
            ``True`` if the utterance is considered finished under the
            configured endpointing strategy, else ``False``.

        Raises:
            Exception: If the endpointing type is not supported.
        """
        transcript = deepgram_response["channel"]["alternatives"][0]["transcript"]
        endpointing_config = self.transcriber_config.endpointing_config

        def silence_cutoff_reached():
            # No new words in this message, something is already buffered, and
            # the silence so far plus this message's duration exceeds the cutoff.
            return (
                not transcript
                and bool(current_buffer)
                and (time_silent + deepgram_response["duration"])
                > endpointing_config.time_cutoff_seconds
            )

        # Without an endpointing config, rely on Deepgram's own speech_final
        # flag, provided there is a transcript.
        if not endpointing_config:
            return bool(transcript and deepgram_response["speech_final"])

        if endpointing_config.type == EndpointingType.TIME_BASED:
            return silence_cutoff_reached()

        if endpointing_config.type == EndpointingType.PUNCTUATION_BASED:
            # endswith() on an empty string is simply False, so a
            # whitespace-only transcript no longer raises IndexError.
            ends_with_terminator = transcript.strip().endswith(
                tuple(PUNCTUATION_TERMINATORS)
            )
            return bool(
                (
                    transcript
                    and deepgram_response["speech_final"]
                    and ends_with_terminator
                )
                or silence_cutoff_reached()
            )

        raise Exception("Endpointing config not supported")

    def calculate_time_silent(self, data: dict):
        """Return the silence, in seconds, at the end of a result message.

        This is the gap between the end of the last recognised word and the
        end of the message window, or the whole ``duration`` if the message
        contains no words.
        """
        end = data["start"] + data["duration"]
        words = data["channel"]["alternatives"][0]["words"]
        if words:
            return end - words[-1]["end"]
        return data["duration"]

    async def process(self, warmup=True):
        """Open one websocket session with Deepgram and pump audio and results.

        A fresh ``self.audio_queue`` is created on every call, so anything
        still queued from a previous session is discarded. Three coroutines
        run concurrently until all of them finish: an optional warm-up sender,
        the audio sender and the result receiver.

        Args:
            warmup: Whether to send the warm-up audio before marking the
                transcriber as ready.
        """
        extra_headers = {"Authorization": f"Token {self.api_key}"}
        self.audio_queue = asyncio.Queue()

        async with websockets.connect(
            self.get_deepgram_url(), extra_headers=extra_headers
        ) as ws:

            async def warmup_sender(ws: WebSocketClientProtocol):
                """Optionally send warm-up audio, then flag the transcriber ready."""
                if warmup:
                    for chunk in self.create_warmup_chunks():
                        await ws.send(chunk)
                    # Wait before flagging ready; the receiver ignores
                    # transcripts while warmed_up is still False.
                    await asyncio.sleep(5)
                self.warmed_up = True
                self.is_ready = True

            async def sender(ws: WebSocketClientProtocol):
                """Forward queued items to the websocket.

                Stops when no item arrives within 5 seconds or once
                ``self._ended`` is set. NOTE: because the loop condition is
                checked after every send, items still queued when ``_ended``
                becomes true (possibly including the CloseStream message)
                are not sent.
                """
                while not self._ended:
                    try:
                        data = await asyncio.wait_for(self.audio_queue.get(), 5)
                    except asyncio.exceptions.TimeoutError:
                        break
                    await ws.send(data)
                self.logger.debug("Terminating Deepgram transcriber sender")

            async def receiver(ws: WebSocketClientProtocol):
                """Read result messages and emit interim/final transcriptions."""
                buffer = ""
                time_silent = 0
                while not self._ended:
                    try:
                        msg = await ws.recv()
                    except Exception as e:
                        self.logger.debug(f"Got error {e} in Deepgram receiver")
                        break
                    data = json.loads(msg)
                    if "is_final" not in data:
                        # means we've finished receiving transcriptions
                        break
                    is_final = data["is_final"]
                    # Evaluated against the buffer as it was *before* this
                    # message's transcript is appended.
                    speech_final = self.is_speech_final(buffer, data, time_silent)
                    top_choice = data["channel"]["alternatives"][0]
                    confidence = top_choice["confidence"]
                    has_usable_transcript = (
                        top_choice["transcript"]
                        and confidence > 0.0
                        and self.warmed_up
                    )

                    # Only finalized segments are accumulated into the buffer.
                    if has_usable_transcript and is_final:
                        buffer = f"{buffer} {top_choice['transcript']}"

                    if speech_final:
                        await self.on_response(Transcription(buffer, confidence, True))
                        buffer = ""
                        time_silent = 0
                    elif has_usable_transcript:
                        await self.on_response(
                            Transcription(
                                buffer,
                                confidence,
                                False,
                            )
                        )
                        time_silent = self.calculate_time_silent(data)
                    else:
                        # Nothing usable in this message: count it as silence.
                        time_silent += data["duration"]
                self.logger.debug("Terminating Deepgram transcriber receiver")

            await asyncio.gather(warmup_sender(ws), sender(ws), receiver(ws))