adds ability to pass in api key for every transcriber agent and synthesizer
This commit is contained in:
parent
a669d3f535
commit
ecebe4c1a5
11 changed files with 71 additions and 17 deletions
|
|
@ -21,11 +21,17 @@ class BotSentiment(BaseModel):
|
||||||
|
|
||||||
|
|
||||||
class BotSentimentAnalyser:
|
class BotSentimentAnalyser:
|
||||||
def __init__(self, emotions: list[str], model_name: str = "text-davinci-003"):
|
def __init__(
|
||||||
|
self,
|
||||||
|
emotions: list[str],
|
||||||
|
model_name: str = "text-davinci-003",
|
||||||
|
openai_api_key: Optional[str] = None,
|
||||||
|
):
|
||||||
self.model_name = model_name
|
self.model_name = model_name
|
||||||
self.llm = OpenAI(
|
openai_api_key = openai_api_key or getenv("OPENAI_API_KEY")
|
||||||
model_name=self.model_name, openai_api_key=getenv("OPENAI_API_KEY")
|
if not openai_api_key:
|
||||||
)
|
raise ValueError("OPENAI_API_KEY must be set in environment or passed in")
|
||||||
|
self.llm = OpenAI(model_name=self.model_name, openai_api_key=openai_api_key)
|
||||||
assert len(emotions) > 0
|
assert len(emotions) > 0
|
||||||
self.emotions = [e.lower() for e in emotions]
|
self.emotions = [e.lower() for e in emotions]
|
||||||
self.prompt = PromptTemplate(
|
self.prompt = PromptTemplate(
|
||||||
|
|
|
||||||
|
|
@ -26,9 +26,16 @@ from vocode.streaming.agent.utils import stream_llm_response
|
||||||
|
|
||||||
|
|
||||||
class ChatGPTAgent(BaseAgent):
|
class ChatGPTAgent(BaseAgent):
|
||||||
def __init__(self, agent_config: ChatGPTAgentConfig, logger: logging.Logger = None):
|
def __init__(
|
||||||
|
self,
|
||||||
|
agent_config: ChatGPTAgentConfig,
|
||||||
|
logger: logging.Logger = None,
|
||||||
|
openai_api_key: Optional[str] = None,
|
||||||
|
):
|
||||||
super().__init__(agent_config)
|
super().__init__(agent_config)
|
||||||
openai.api_key = getenv("OPENAI_API_KEY")
|
openai.api_key = openai_api_key or getenv("OPENAI_API_KEY")
|
||||||
|
if not openai.api_key:
|
||||||
|
raise ValueError("OPENAI_API_KEY must be set in environment or passed in")
|
||||||
self.agent_config = agent_config
|
self.agent_config = agent_config
|
||||||
self.logger = logger or logging.getLogger(__name__)
|
self.logger = logger or logging.getLogger(__name__)
|
||||||
self.logger.setLevel(logging.DEBUG)
|
self.logger.setLevel(logging.DEBUG)
|
||||||
|
|
|
||||||
|
|
@ -23,6 +23,7 @@ class LLMAgent(BaseAgent):
|
||||||
logger: logging.Logger = None,
|
logger: logging.Logger = None,
|
||||||
sender="AI",
|
sender="AI",
|
||||||
recipient="Human",
|
recipient="Human",
|
||||||
|
openai_api_key: Optional[str] = None,
|
||||||
):
|
):
|
||||||
super().__init__(agent_config)
|
super().__init__(agent_config)
|
||||||
self.agent_config = agent_config
|
self.agent_config = agent_config
|
||||||
|
|
@ -40,11 +41,14 @@ class LLMAgent(BaseAgent):
|
||||||
if agent_config.initial_message
|
if agent_config.initial_message
|
||||||
else []
|
else []
|
||||||
)
|
)
|
||||||
|
openai_api_key = openai_api_key or getenv("OPENAI_API_KEY")
|
||||||
|
if not openai_api_key:
|
||||||
|
raise ValueError("OPENAI_API_KEY must be set in environment or passed in")
|
||||||
self.llm = OpenAI(
|
self.llm = OpenAI(
|
||||||
model_name=self.agent_config.model_name,
|
model_name=self.agent_config.model_name,
|
||||||
temperature=self.agent_config.temperature,
|
temperature=self.agent_config.temperature,
|
||||||
max_tokens=self.agent_config.max_tokens,
|
max_tokens=self.agent_config.max_tokens,
|
||||||
openai_api_key=getenv("OPENAI_API_KEY"),
|
openai_api_key=openai_api_key,
|
||||||
)
|
)
|
||||||
self.stop_tokens = [f"{recipient}:"]
|
self.stop_tokens = [f"{recipient}:"]
|
||||||
self.first_response = (
|
self.first_response = (
|
||||||
|
|
|
||||||
|
|
@ -116,7 +116,7 @@ ELEVEN_LABS_ADAM_VOICE_ID = "pNInz6obpgDQGcFmaJgB"
|
||||||
|
|
||||||
|
|
||||||
class ElevenLabsSynthesizerConfig(SynthesizerConfig, type=SynthesizerType.ELEVEN_LABS):
|
class ElevenLabsSynthesizerConfig(SynthesizerConfig, type=SynthesizerType.ELEVEN_LABS):
|
||||||
api_key: str
|
api_key: Optional[str] = None
|
||||||
voice_id: Optional[str] = ELEVEN_LABS_ADAM_VOICE_ID
|
voice_id: Optional[str] = ELEVEN_LABS_ADAM_VOICE_ID
|
||||||
|
|
||||||
@validator("voice_id")
|
@validator("voice_id")
|
||||||
|
|
@ -127,7 +127,7 @@ class ElevenLabsSynthesizerConfig(SynthesizerConfig, type=SynthesizerType.ELEVEN
|
||||||
def from_output_device(
|
def from_output_device(
|
||||||
cls,
|
cls,
|
||||||
output_device: BaseOutputDevice,
|
output_device: BaseOutputDevice,
|
||||||
api_key: str,
|
api_key: Optional[str] = None,
|
||||||
voice_id: Optional[str] = None,
|
voice_id: Optional[str] = None,
|
||||||
):
|
):
|
||||||
return cls(
|
return cls(
|
||||||
|
|
@ -140,7 +140,7 @@ class ElevenLabsSynthesizerConfig(SynthesizerConfig, type=SynthesizerType.ELEVEN
|
||||||
@classmethod
|
@classmethod
|
||||||
def from_telephone_output_device(
|
def from_telephone_output_device(
|
||||||
cls,
|
cls,
|
||||||
api_key: str,
|
api_key: Optional[str] = None,
|
||||||
voice_id: Optional[str] = None,
|
voice_id: Optional[str] = None,
|
||||||
):
|
):
|
||||||
return cls(
|
return cls(
|
||||||
|
|
|
||||||
|
|
@ -52,14 +52,27 @@ class AzureSynthesizer(BaseSynthesizer):
|
||||||
OFFSET_MS = 100
|
OFFSET_MS = 100
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self, synthesizer_config: AzureSynthesizerConfig, logger: logging.Logger = None
|
self,
|
||||||
|
synthesizer_config: AzureSynthesizerConfig,
|
||||||
|
logger: logging.Logger = None,
|
||||||
|
azure_speech_key: str = None,
|
||||||
|
azure_speech_region: str = None,
|
||||||
):
|
):
|
||||||
super().__init__(synthesizer_config)
|
super().__init__(synthesizer_config)
|
||||||
self.synthesizer_config = synthesizer_config
|
self.synthesizer_config = synthesizer_config
|
||||||
# Instantiates a client
|
# Instantiates a client
|
||||||
|
azure_speech_key = azure_speech_key or getenv("AZURE_SPEECH_KEY")
|
||||||
|
azure_speech_region = azure_speech_region or getenv("AZURE_SPEECH_REGION")
|
||||||
|
if not azure_speech_key:
|
||||||
|
raise ValueError(
|
||||||
|
"Please set AZURE_SPEECH_KEY environment variable or pass it as a parameter"
|
||||||
|
)
|
||||||
|
if not azure_speech_region:
|
||||||
|
raise ValueError(
|
||||||
|
"Please set AZURE_SPEECH_REGION environment variable or pass it as a parameter"
|
||||||
|
)
|
||||||
speech_config = speechsdk.SpeechConfig(
|
speech_config = speechsdk.SpeechConfig(
|
||||||
subscription=getenv("AZURE_SPEECH_KEY"),
|
subscription=azure_speech_key, region=azure_speech_region
|
||||||
region=getenv("AZURE_SPEECH_REGION"),
|
|
||||||
)
|
)
|
||||||
if self.synthesizer_config.audio_encoding == AudioEncoding.LINEAR16:
|
if self.synthesizer_config.audio_encoding == AudioEncoding.LINEAR16:
|
||||||
if self.synthesizer_config.sampling_rate == 44100:
|
if self.synthesizer_config.sampling_rate == 44100:
|
||||||
|
|
|
||||||
|
|
@ -24,7 +24,7 @@ ADAM_VOICE_ID = "pNInz6obpgDQGcFmaJgB"
|
||||||
class ElevenLabsSynthesizer(BaseSynthesizer):
|
class ElevenLabsSynthesizer(BaseSynthesizer):
|
||||||
def __init__(self, config: ElevenLabsSynthesizerConfig):
|
def __init__(self, config: ElevenLabsSynthesizerConfig):
|
||||||
super().__init__(config)
|
super().__init__(config)
|
||||||
self.api_key = getenv("ELEVEN_LABS_API_KEY")
|
self.api_key = config.api_key or getenv("ELEVEN_LABS_API_KEY")
|
||||||
self.voice_id = config.voice_id or ADAM_VOICE_ID
|
self.voice_id = config.voice_id or ADAM_VOICE_ID
|
||||||
self.words_per_minute = 150
|
self.words_per_minute = 150
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -3,6 +3,7 @@ import wave
|
||||||
from typing import Any, Optional
|
from typing import Any, Optional
|
||||||
|
|
||||||
from google.cloud import texttospeech_v1beta1 as tts
|
from google.cloud import texttospeech_v1beta1 as tts
|
||||||
|
from vocode import getenv
|
||||||
|
|
||||||
from vocode.streaming.agent.bot_sentiment_analyser import BotSentiment
|
from vocode.streaming.agent.bot_sentiment_analyser import BotSentiment
|
||||||
from vocode.streaming.models.message import BaseMessage
|
from vocode.streaming.models.message import BaseMessage
|
||||||
|
|
@ -22,6 +23,10 @@ class GoogleSynthesizer(BaseSynthesizer):
|
||||||
def __init__(self, synthesizer_config: GoogleSynthesizerConfig):
|
def __init__(self, synthesizer_config: GoogleSynthesizerConfig):
|
||||||
super().__init__(synthesizer_config)
|
super().__init__(synthesizer_config)
|
||||||
# Instantiates a client
|
# Instantiates a client
|
||||||
|
if not getenv("GOOGLE_APPLICATION_CREDENTIALS"):
|
||||||
|
raise Exception(
|
||||||
|
"GOOGLE_APPLICATION_CREDENTIALS environment variable must be set"
|
||||||
|
)
|
||||||
self.client = tts.TextToSpeechClient()
|
self.client = tts.TextToSpeechClient()
|
||||||
|
|
||||||
# Build the voice request, select the language code ("en-US") and the ssml
|
# Build the voice request, select the language code ("en-US") and the ssml
|
||||||
|
|
|
||||||
|
|
@ -22,9 +22,14 @@ class AssemblyAITranscriber(BaseTranscriber):
|
||||||
self,
|
self,
|
||||||
transcriber_config: AssemblyAITranscriberConfig,
|
transcriber_config: AssemblyAITranscriberConfig,
|
||||||
logger: logging.Logger = None,
|
logger: logging.Logger = None,
|
||||||
|
api_key: str = None,
|
||||||
):
|
):
|
||||||
super().__init__(transcriber_config)
|
super().__init__(transcriber_config)
|
||||||
self.api_key = getenv("ASSEMBLY_AI_API_KEY")
|
self.api_key = api_key or getenv("ASSEMBLY_AI_API_KEY")
|
||||||
|
if not self.api_key:
|
||||||
|
raise Exception(
|
||||||
|
"Please set ASSEMBLY_AI_API_KEY environment variable or pass it as a parameter"
|
||||||
|
)
|
||||||
self._ended = False
|
self._ended = False
|
||||||
self.is_ready = False
|
self.is_ready = False
|
||||||
self.logger = logger or logging.getLogger(__name__)
|
self.logger = logger or logging.getLogger(__name__)
|
||||||
|
|
|
||||||
|
|
@ -28,9 +28,14 @@ class DeepgramTranscriber(BaseTranscriber):
|
||||||
self,
|
self,
|
||||||
transcriber_config: DeepgramTranscriberConfig,
|
transcriber_config: DeepgramTranscriberConfig,
|
||||||
logger: logging.Logger = None,
|
logger: logging.Logger = None,
|
||||||
|
api_key: str = None,
|
||||||
):
|
):
|
||||||
super().__init__(transcriber_config)
|
super().__init__(transcriber_config)
|
||||||
self.api_key = getenv("DEEPGRAM_API_KEY")
|
self.api_key = api_key or getenv("DEEPGRAM_API_KEY")
|
||||||
|
if not self.api_key:
|
||||||
|
raise Exception(
|
||||||
|
"Please set DEEPGRAM_API_KEY environment variable or pass it as a parameter"
|
||||||
|
)
|
||||||
self.transcriber_config = transcriber_config
|
self.transcriber_config = transcriber_config
|
||||||
self._ended = False
|
self._ended = False
|
||||||
self.warmed_up = False
|
self.warmed_up = False
|
||||||
|
|
|
||||||
|
|
@ -3,6 +3,7 @@ import time
|
||||||
import queue
|
import queue
|
||||||
from google.cloud import speech
|
from google.cloud import speech
|
||||||
import threading
|
import threading
|
||||||
|
from vocode import getenv
|
||||||
|
|
||||||
from vocode.streaming.models.audio_encoding import AudioEncoding
|
from vocode.streaming.models.audio_encoding import AudioEncoding
|
||||||
from vocode.streaming.transcriber.base_transcriber import (
|
from vocode.streaming.transcriber.base_transcriber import (
|
||||||
|
|
@ -18,6 +19,10 @@ class GoogleTranscriber(BaseTranscriber):
|
||||||
super().__init__(transcriber_config)
|
super().__init__(transcriber_config)
|
||||||
self._queue = queue.Queue()
|
self._queue = queue.Queue()
|
||||||
self._ended = False
|
self._ended = False
|
||||||
|
if not getenv("GOOGLE_APPLICATION_CREDENTIALS"):
|
||||||
|
raise Exception(
|
||||||
|
"Please set GOOGLE_APPLICATION_CREDENTIALS environment variable"
|
||||||
|
)
|
||||||
self.google_streaming_config = self.create_google_streaming_config()
|
self.google_streaming_config = self.create_google_streaming_config()
|
||||||
self.client = speech.SpeechClient()
|
self.client = speech.SpeechClient()
|
||||||
self.warmed_up = False
|
self.warmed_up = False
|
||||||
|
|
|
||||||
|
|
@ -1,5 +1,6 @@
|
||||||
import os
|
import os
|
||||||
import asyncio
|
import asyncio
|
||||||
|
from typing import Optional
|
||||||
import openai
|
import openai
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import requests
|
import requests
|
||||||
|
|
@ -26,8 +27,11 @@ class GoodbyeModel:
|
||||||
embeddings_cache_path=os.path.join(
|
embeddings_cache_path=os.path.join(
|
||||||
os.path.dirname(__file__), "goodbye_embeddings"
|
os.path.dirname(__file__), "goodbye_embeddings"
|
||||||
),
|
),
|
||||||
|
openai_api_key: Optional[str] = None,
|
||||||
):
|
):
|
||||||
openai.api_key = getenv("OPENAI_API_KEY")
|
openai.api_key = openai_api_key or getenv("OPENAI_API_KEY")
|
||||||
|
if not openai.api_key:
|
||||||
|
raise ValueError("OPENAI_API_KEY must be set in environment or passed in")
|
||||||
self.goodbye_embeddings = self.load_or_create_embeddings(
|
self.goodbye_embeddings = self.load_or_create_embeddings(
|
||||||
f"{embeddings_cache_path}/goodbye_embeddings.npy"
|
f"{embeddings_cache_path}/goodbye_embeddings.npy"
|
||||||
)
|
)
|
||||||
|
|
|
||||||
Loading…
Add table
Add a link
Reference in a new issue