adds ability to pass in api key for every transcriber agent and synthesizer

This commit is contained in:
Ajay Raj 2023-03-28 23:33:36 -07:00
commit ecebe4c1a5
11 changed files with 71 additions and 17 deletions

View file

@ -21,11 +21,17 @@ class BotSentiment(BaseModel):
class BotSentimentAnalyser:
def __init__(self, emotions: list[str], model_name: str = "text-davinci-003"):
def __init__(
self,
emotions: list[str],
model_name: str = "text-davinci-003",
openai_api_key: Optional[str] = None,
):
self.model_name = model_name
self.llm = OpenAI(
model_name=self.model_name, openai_api_key=getenv("OPENAI_API_KEY")
)
openai_api_key = openai_api_key or getenv("OPENAI_API_KEY")
if not openai_api_key:
raise ValueError("OPENAI_API_KEY must be set in environment or passed in")
self.llm = OpenAI(model_name=self.model_name, openai_api_key=openai_api_key)
assert len(emotions) > 0
self.emotions = [e.lower() for e in emotions]
self.prompt = PromptTemplate(

View file

@ -26,9 +26,16 @@ from vocode.streaming.agent.utils import stream_llm_response
class ChatGPTAgent(BaseAgent):
def __init__(self, agent_config: ChatGPTAgentConfig, logger: logging.Logger = None):
def __init__(
self,
agent_config: ChatGPTAgentConfig,
logger: logging.Logger = None,
openai_api_key: Optional[str] = None,
):
super().__init__(agent_config)
openai.api_key = getenv("OPENAI_API_KEY")
openai.api_key = openai_api_key or getenv("OPENAI_API_KEY")
if not openai.api_key:
raise ValueError("OPENAI_API_KEY must be set in environment or passed in")
self.agent_config = agent_config
self.logger = logger or logging.getLogger(__name__)
self.logger.setLevel(logging.DEBUG)

View file

@ -23,6 +23,7 @@ class LLMAgent(BaseAgent):
logger: logging.Logger = None,
sender="AI",
recipient="Human",
openai_api_key: Optional[str] = None,
):
super().__init__(agent_config)
self.agent_config = agent_config
@ -40,11 +41,14 @@ class LLMAgent(BaseAgent):
if agent_config.initial_message
else []
)
openai_api_key = openai_api_key or getenv("OPENAI_API_KEY")
if not openai_api_key:
raise ValueError("OPENAI_API_KEY must be set in environment or passed in")
self.llm = OpenAI(
model_name=self.agent_config.model_name,
temperature=self.agent_config.temperature,
max_tokens=self.agent_config.max_tokens,
openai_api_key=getenv("OPENAI_API_KEY"),
openai_api_key=openai_api_key,
)
self.stop_tokens = [f"{recipient}:"]
self.first_response = (