open source
This commit is contained in:
parent
70b6e17c69
commit
a93bfc1ec9
61 changed files with 4013 additions and 126 deletions
230
vocode/streaming/transcriber/deepgram_transcriber.py
Normal file
230
vocode/streaming/transcriber/deepgram_transcriber.py
Normal file
|
|
@ -0,0 +1,230 @@
|
|||
import asyncio
|
||||
import json
|
||||
import logging
|
||||
import os
|
||||
from dotenv import load_dotenv
|
||||
import websockets
|
||||
from websockets.client import WebSocketClientProtocol
|
||||
import audioop
|
||||
from urllib.parse import urlencode
|
||||
|
||||
from vocode.streaming.transcriber.base_transcriber import (
|
||||
BaseTranscriber,
|
||||
Transcription,
|
||||
)
|
||||
from vocode.streaming.models.transcriber import (
|
||||
DeepgramTranscriberConfig,
|
||||
EndpointingConfig,
|
||||
EndpointingType,
|
||||
)
|
||||
from vocode.streaming.models.audio_encoding import AudioEncoding
|
||||
|
||||
load_dotenv()
|
||||
|
||||
DEEPGRAM_API_KEY = os.environ.get("DEEPGRAM_API_KEY")
|
||||
PUNCTUATION_TERMINATORS = [".", "!", "?"]
|
||||
NUM_RESTARTS = 5
|
||||
|
||||
|
||||
class DeepgramTranscriber(BaseTranscriber):
|
||||
def __init__(
|
||||
self,
|
||||
transcriber_config: DeepgramTranscriberConfig,
|
||||
logger: logging.Logger = None,
|
||||
):
|
||||
super().__init__(transcriber_config)
|
||||
self.transcriber_config = transcriber_config
|
||||
self._ended = False
|
||||
self.warmed_up = False
|
||||
self.is_ready = False
|
||||
self.logger = logger or logging.getLogger(__name__)
|
||||
|
||||
def create_warmup_chunks(self):
|
||||
warmup_chunks = []
|
||||
warmup_bytes = self.get_warmup_bytes()
|
||||
chunk_size = self.transcriber_config.chunk_size
|
||||
for i in range(len(warmup_bytes) // chunk_size):
|
||||
warmup_chunks.append(warmup_bytes[i * chunk_size : (i + 1) * chunk_size])
|
||||
return warmup_chunks
|
||||
|
||||
async def ready(self):
|
||||
while not self.warmed_up:
|
||||
await asyncio.sleep(0.1)
|
||||
return self.is_ready
|
||||
|
||||
async def run(self):
|
||||
# warmup_chunks = await self.create_warmup_chunks()
|
||||
restarts = 0
|
||||
while not self._ended and restarts < NUM_RESTARTS:
|
||||
await self.process(self.transcriber_config.should_warmup_model)
|
||||
restarts += 1
|
||||
self.logger.debug(
|
||||
"Deepgram connection died, restarting, num_restarts: %s", restarts
|
||||
)
|
||||
|
||||
def send_audio(self, chunk):
|
||||
if (
|
||||
self.transcriber_config.downsampling
|
||||
and self.transcriber_config.audio_encoding == AudioEncoding.LINEAR16
|
||||
):
|
||||
chunk, _ = audioop.ratecv(
|
||||
chunk,
|
||||
2,
|
||||
1,
|
||||
self.transcriber_config.sampling_rate
|
||||
* self.transcriber_config.downsampling,
|
||||
self.transcriber_config.sampling_rate,
|
||||
None,
|
||||
)
|
||||
self.audio_queue.put_nowait(chunk)
|
||||
|
||||
def terminate(self):
|
||||
terminate_msg = json.dumps({"type": "CloseStream"})
|
||||
self.audio_queue.put_nowait(terminate_msg)
|
||||
self._ended = True
|
||||
|
||||
def get_deepgram_url(self):
|
||||
if self.transcriber_config.audio_encoding == AudioEncoding.LINEAR16:
|
||||
encoding = "linear16"
|
||||
elif self.transcriber_config.audio_encoding == AudioEncoding.MULAW:
|
||||
encoding = "mulaw"
|
||||
url_params = {
|
||||
"encoding": encoding,
|
||||
"sample_rate": self.transcriber_config.sampling_rate,
|
||||
"channels": 1,
|
||||
"interim_results": "true",
|
||||
}
|
||||
extra_params = {}
|
||||
if self.transcriber_config.model:
|
||||
extra_params["model"] = self.transcriber_config.model
|
||||
if self.transcriber_config.tier:
|
||||
extra_params["tier"] = self.transcriber_config.tier
|
||||
if self.transcriber_config.version:
|
||||
extra_params["version"] = self.transcriber_config.version
|
||||
if (
|
||||
self.transcriber_config.endpointing_config
|
||||
and self.transcriber_config.endpointing_config.type
|
||||
== EndpointingType.PUNCTUATION_BASED
|
||||
):
|
||||
extra_params["punctuate"] = "true"
|
||||
url_params.update(extra_params)
|
||||
return f"wss://api.deepgram.com/v1/listen?{urlencode(url_params)}"
|
||||
|
||||
def is_speech_final(
|
||||
self, current_buffer: str, deepgram_response: dict, time_silent: float
|
||||
):
|
||||
transcript = deepgram_response["channel"]["alternatives"][0]["transcript"]
|
||||
|
||||
# if it is not time based, then return true if speech is final and there is a transcript
|
||||
if not self.transcriber_config.endpointing_config:
|
||||
return transcript and deepgram_response["speech_final"]
|
||||
elif (
|
||||
self.transcriber_config.endpointing_config.type
|
||||
== EndpointingType.TIME_BASED
|
||||
):
|
||||
# if it is time based, then return true if there is no transcript
|
||||
# and there is some speech to send
|
||||
# and the time_silent is greater than the cutoff
|
||||
return (
|
||||
not transcript
|
||||
and current_buffer
|
||||
and (time_silent + deepgram_response["duration"])
|
||||
> self.transcriber_config.endpointing_config.time_cutoff_seconds
|
||||
)
|
||||
elif (
|
||||
self.transcriber_config.endpointing_config.type
|
||||
== EndpointingType.PUNCTUATION_BASED
|
||||
):
|
||||
return (
|
||||
transcript
|
||||
and deepgram_response["speech_final"]
|
||||
and transcript.strip()[-1] in PUNCTUATION_TERMINATORS
|
||||
) or (
|
||||
not transcript
|
||||
and current_buffer
|
||||
and (time_silent + deepgram_response["duration"])
|
||||
> self.transcriber_config.endpointing_config.time_cutoff_seconds
|
||||
)
|
||||
raise Exception("Endpointing config not supported")
|
||||
|
||||
def calculate_time_silent(self, data: dict):
|
||||
end = data["start"] + data["duration"]
|
||||
words = data["channel"]["alternatives"][0]["words"]
|
||||
if words:
|
||||
return end - words[-1]["end"]
|
||||
return data["duration"]
|
||||
|
||||
async def process(self, warmup=True):
|
||||
extra_headers = {"Authorization": f"Token {DEEPGRAM_API_KEY}"}
|
||||
self.audio_queue = asyncio.Queue()
|
||||
|
||||
async with websockets.connect(
|
||||
self.get_deepgram_url(), extra_headers=extra_headers
|
||||
) as ws:
|
||||
|
||||
async def warmup_sender(ws: WebSocketClientProtocol):
|
||||
if warmup:
|
||||
warmup_chunks = self.create_warmup_chunks()
|
||||
for chunk in warmup_chunks:
|
||||
await ws.send(chunk)
|
||||
await asyncio.sleep(5)
|
||||
self.warmed_up = True
|
||||
self.is_ready = True
|
||||
|
||||
async def sender(ws: WebSocketClientProtocol): # sends audio to websocket
|
||||
while not self._ended:
|
||||
try:
|
||||
data = await asyncio.wait_for(self.audio_queue.get(), 5)
|
||||
except asyncio.exceptions.TimeoutError:
|
||||
break
|
||||
await ws.send(data)
|
||||
self.logger.debug("Terminating Deepgram transcriber sender")
|
||||
|
||||
async def receiver(ws: WebSocketClientProtocol):
|
||||
buffer = ""
|
||||
time_silent = 0
|
||||
while not self._ended:
|
||||
try:
|
||||
msg = await ws.recv()
|
||||
except Exception as e:
|
||||
self.logger.debug(f"Got error {e} in Deepgram receiver")
|
||||
break
|
||||
data = json.loads(msg)
|
||||
if (
|
||||
not "is_final" in data
|
||||
): # means we've finished receiving transcriptions
|
||||
break
|
||||
is_final = data["is_final"]
|
||||
speech_final = self.is_speech_final(buffer, data, time_silent)
|
||||
top_choice = data["channel"]["alternatives"][0]
|
||||
confidence = top_choice["confidence"]
|
||||
|
||||
if (
|
||||
top_choice["transcript"]
|
||||
and confidence > 0.0
|
||||
and self.warmed_up
|
||||
and is_final
|
||||
):
|
||||
buffer = f"{buffer} {top_choice['transcript']}"
|
||||
|
||||
if speech_final:
|
||||
await self.on_response(Transcription(buffer, confidence, True))
|
||||
buffer = ""
|
||||
time_silent = 0
|
||||
elif (
|
||||
top_choice["transcript"] and confidence > 0.0 and self.warmed_up
|
||||
):
|
||||
await self.on_response(
|
||||
Transcription(
|
||||
buffer,
|
||||
confidence,
|
||||
False,
|
||||
)
|
||||
)
|
||||
time_silent = self.calculate_time_silent(data)
|
||||
else:
|
||||
time_silent += data["duration"]
|
||||
|
||||
self.logger.debug("Terminating Deepgram transcriber receiver")
|
||||
|
||||
await asyncio.gather(warmup_sender(ws), sender(ws), receiver(ws))
|
||||
Loading…
Add table
Add a link
Reference in a new issue