open source
This commit is contained in:
parent
70b6e17c69
commit
a93bfc1ec9
61 changed files with 4013 additions and 126 deletions
145
vocode/streaming/transcriber/google_transcriber.py
Normal file
145
vocode/streaming/transcriber/google_transcriber.py
Normal file
|
|
@ -0,0 +1,145 @@
|
|||
import asyncio
|
||||
import time
|
||||
import queue
|
||||
from google.cloud import speech
|
||||
import threading
|
||||
|
||||
from vocode.streaming.models.audio_encoding import AudioEncoding
|
||||
from vocode.streaming.transcriber.base_transcriber import (
|
||||
BaseTranscriber,
|
||||
Transcription,
|
||||
)
|
||||
from vocode.streaming.models.transcriber import GoogleTranscriberConfig
|
||||
from vocode.streaming.utils import create_loop_in_thread
|
||||
|
||||
|
||||
class GoogleTranscriber(BaseTranscriber):
    """Streaming transcriber that uses Google Cloud Speech-to-Text.

    Audio chunks passed to `send_audio` are buffered in a thread-safe queue.
    `run` starts a dedicated thread that (optionally) warms up the model,
    streams the buffered audio to Google, and forwards every recognition
    result to `self.on_response` as a `Transcription`.
    """

    def __init__(self, transcriber_config: GoogleTranscriberConfig):
        """Build the Google client, streaming config and worker thread.

        Raises:
            Exception: if `transcriber_config.endpointing_config` is set,
                since endpointing is not supported by this transcriber.
        """
        super().__init__(transcriber_config)
        # Validate first so an unsupported config is rejected before any
        # Google client is constructed.
        if self.transcriber_config.endpointing_config:
            raise Exception("Google endpointing config not supported yet")
        self._queue = queue.Queue()
        self._ended = False
        self.google_streaming_config = self.create_google_streaming_config()
        self.client = speech.SpeechClient()
        self.warmed_up = False
        self.is_ready = False
        self.event_loop = asyncio.new_event_loop()
        # NOTE(review): `self.process()` creates the coroutine object here, at
        # construction time; it only starts executing once the thread started
        # by `run` hands it to `create_loop_in_thread`.
        self.thread = threading.Thread(
            name="google_transcriber",
            target=create_loop_in_thread,
            args=(self.event_loop, self.process()),
        )

    def create_google_streaming_config(self):
        """Translate the transcriber config into a Google streaming config.

        Returns:
            A `speech.StreamingRecognitionConfig` with interim results enabled.

        Raises:
            ValueError: if the configured audio encoding is neither LINEAR16
                nor MULAW.
        """
        extra_params = {}
        if self.transcriber_config.model:
            extra_params["model"] = self.transcriber_config.model
            extra_params["use_enhanced"] = True

        audio_encoding = self.transcriber_config.audio_encoding
        if audio_encoding == AudioEncoding.LINEAR16:
            google_audio_encoding = speech.RecognitionConfig.AudioEncoding.LINEAR16
        elif audio_encoding == AudioEncoding.MULAW:
            google_audio_encoding = speech.RecognitionConfig.AudioEncoding.MULAW
        else:
            # Without this branch the name below would be unbound and the
            # failure would surface as an opaque UnboundLocalError.
            raise ValueError(
                f"Unsupported audio encoding for Google transcriber: {audio_encoding!r}"
            )

        return speech.StreamingRecognitionConfig(
            config=speech.RecognitionConfig(
                encoding=google_audio_encoding,
                sample_rate_hertz=self.transcriber_config.sampling_rate,
                language_code="en-US",
                **extra_params,
            ),
            interim_results=True,
        )

    async def ready(self):
        """Wait until warmup has finished and report whether it succeeded.

        Returns:
            True immediately when warmup is disabled; otherwise the value of
            `self.is_ready` once `self.warmed_up` becomes true.
        """
        if not self.transcriber_config.should_warmup_model:
            return True
        # Polls the flag that `warmup` sets when it finishes.
        while not self.warmed_up:
            await asyncio.sleep(0.1)
        return self.is_ready

    def warmup(self):
        """Send the warmup audio through a throwaway recognition stream.

        Sets `self.is_ready` on success and always sets `self.warmed_up`, so
        that `ready` stops polling even when the request fails (the exception
        still propagates to the caller).
        """
        warmup_bytes = self.get_warmup_bytes()
        chunk_size = self.transcriber_config.sampling_rate * 2

        def stream():
            # A stepped range also sends the trailing partial chunk, and
            # still sends something when the audio is shorter than one chunk.
            for start in range(0, len(warmup_bytes), chunk_size):
                yield speech.StreamingRecognizeRequest(
                    audio_content=warmup_bytes[start : start + chunk_size]
                )
                time.sleep(0.01)

        try:
            # Drain the responses; only completion of the stream matters.
            for _ in self.client.streaming_recognize(
                self.google_streaming_config, stream()
            ):
                pass
            # Set before `warmed_up` so `ready` never observes a finished
            # warmup with a stale `is_ready` value.
            self.is_ready = True
        finally:
            self.warmed_up = True

    async def run(self):
        """Start the worker thread that executes `process`."""
        self.thread.start()

    async def process(self):
        """Warm up if configured, then stream queued audio to Google.

        Uses blocking calls (`warmup`, iteration over the response stream);
        it is passed to the worker thread created in `__init__`.
        """
        if self.transcriber_config.should_warmup_model:
            self.warmup()
        stream = self.generator()
        requests = (
            speech.StreamingRecognizeRequest(audio_content=content)
            for content in stream
        )
        responses = self.client.streaming_recognize(
            self.google_streaming_config, requests
        )
        await self.process_responses_loop(responses)

    def terminate(self):
        """Stop transcription and unblock the audio generator."""
        self._ended = True
        # `generator` blocks in `Queue.get()`; the None sentinel wakes it up
        # so the request stream actually ends instead of waiting forever.
        self._queue.put(None, block=False)

    def send_audio(self, chunk: bytes):
        """Queue a chunk of audio for transcription without blocking."""
        self._queue.put(chunk, block=False)

    async def process_responses_loop(self, responses):
        """Forward each streaming response until the stream ends or `terminate` is called."""
        for response in responses:
            await self._on_response(response)

            if self._ended:
                break

    async def _on_response(self, response):
        """Convert one Google response into a `Transcription` and emit it.

        Responses with no results, or whose first result has no alternatives,
        are ignored. Only the first result's top alternative is used.
        """
        if not response.results:
            return

        result = response.results[0]
        if not result.alternatives:
            return

        top_choice = result.alternatives[0]
        message = top_choice.transcript
        confidence = top_choice.confidence

        return await self.on_response(
            Transcription(message, confidence, result.is_final)
        )

    def generator(self):
        """Yield queued audio as byte strings until the stream is ended.

        Each yielded value is the concatenation of every chunk buffered at
        that moment. A None item in the queue ends the generator.
        """
        while not self._ended:
            # Use a blocking get() to ensure there's at least one chunk of
            # data, and stop iteration if the chunk is None, indicating the
            # end of the audio stream.
            chunk = self._queue.get()
            if chunk is None:
                return
            data = [chunk]

            # Now consume whatever other data's still buffered.
            while True:
                try:
                    chunk = self._queue.get(block=False)
                except queue.Empty:
                    break
                if chunk is None:
                    return
                data.append(chunk)

            yield b"".join(data)
|
||||
Loading…
Add table
Add a link
Reference in a new issue