From 9a0677e88d677d982579b80d52ac380b58db8de8 Mon Sep 17 00:00:00 2001 From: Ajay Raj Date: Wed, 8 Mar 2023 20:41:03 -0800 Subject: [PATCH] add goodbye option --- simple_conversation.py | 45 ++++++++++++------ vocode/models/agent.py | 51 ++++++++++++++++----- vocode/user_implemented_agent/base_agent.py | 6 +-- 3 files changed, 74 insertions(+), 28 deletions(-) diff --git a/simple_conversation.py b/simple_conversation.py index 8cf4f3b..ab0bb55 100644 --- a/simple_conversation.py +++ b/simple_conversation.py @@ -4,8 +4,20 @@ import signal from vocode.conversation import Conversation from vocode.helpers import create_microphone_input_and_speaker_output -from vocode.models.transcriber import DeepgramTranscriberConfig, PunctuationEndpointingConfig, GoogleTranscriberConfig -from vocode.models.agent import ChatGPTAgentConfig, RESTfulUserImplementedAgentConfig, WebSocketUserImplementedAgentConfig, EchoAgentConfig, ChatGPTAlphaAgentConfig, ChatGPTAgentConfig +from vocode.models.transcriber import ( + DeepgramTranscriberConfig, + PunctuationEndpointingConfig, + GoogleTranscriberConfig, +) +from vocode.models.agent import ( + ChatGPTAgentConfig, + RESTfulUserImplementedAgentConfig, + WebSocketUserImplementedAgentConfig, + EchoAgentConfig, + ChatGPTAlphaAgentConfig, + LLMAgentConfig, + ChatGPTAgentConfig, +) from vocode.models.synthesizer import AzureSynthesizerConfig from vocode.user_implemented_agent.restful_agent import RESTfulAgent @@ -14,25 +26,30 @@ logging.root.setLevel(logging.INFO) if __name__ == "__main__": - microphone_input, speaker_output = create_microphone_input_and_speaker_output(use_default_devices=False) + microphone_input, speaker_output = create_microphone_input_and_speaker_output( + use_default_devices=False + ) conversation = Conversation( input_device=microphone_input, output_device=speaker_output, transcriber_config=DeepgramTranscriberConfig.from_input_device( - microphone_input, - endpointing_config=PunctuationEndpointingConfig() + microphone_input, endpointing_config=PunctuationEndpointingConfig() ), - agent_config=WebSocketUserImplementedAgentConfig( - initial_message="Hello!", - respond=WebSocketUserImplementedAgentConfig.RouteConfig( - url="wss://8b7425d5b2ab.ngrok.io/respond", - ) + # agent_config=WebSocketUserImplementedAgentConfig( + # initial_message="Hello!", + # respond=WebSocketUserImplementedAgentConfig.RouteConfig( + # url="wss://8b7425d5b2ab.ngrok.io/respond", + # ) + # ), + # id="ajay", + agent_config=ChatGPTAgentConfig( + initial_message="goodbye", + prompt_preamble="you are an expert on the NBA", + generate_responses=True, + end_conversation_on_goodbye=True, ), - id="ajay", - # agent_config=ChatGPTAgentConfig(initial_message="hello", prompt_preamble="you are an expert on the NBA"), - synthesizer_config=AzureSynthesizerConfig.from_output_device(speaker_output) + synthesizer_config=AzureSynthesizerConfig.from_output_device(speaker_output), ) signal.signal(signal.SIGINT, lambda _0, _1: conversation.deactivate()) asyncio.run(conversation.start()) - diff --git a/vocode/models/agent.py b/vocode/models/agent.py index f7832bb..4cf1487 100644 --- a/vocode/models/agent.py +++ b/vocode/models/agent.py @@ -18,20 +18,25 @@ class AgentConfig(TypedModel, type=AgentType.BASE): initial_message: Optional[str] = None generate_responses: bool = True allowed_idle_time_seconds: Optional[float] = None + end_conversation_on_goodbye: bool = False + class LLMAgentConfig(AgentConfig, type=AgentType.LLM): prompt_preamble: str expected_first_prompt: Optional[str] = None + class ChatGPTAlphaAgentConfig(AgentConfig, type=AgentType.CHAT_GPT_ALPHA): prompt_preamble: str expected_first_prompt: Optional[str] = None + class ChatGPTAgentConfig(AgentConfig, type=AgentType.CHAT_GPT): prompt_preamble: str expected_first_prompt: Optional[str] = None generate_responses: bool = False + class InformationRetrievalAgentConfig( AgentConfig, type=AgentType.INFORMATION_RETRIEVAL ): @@ -45,7 +50,10 @@ class InformationRetrievalAgentConfig( class EchoAgentConfig(AgentConfig, type=AgentType.ECHO): pass -class RESTfulUserImplementedAgentConfig(AgentConfig, type=AgentType.RESTFUL_USER_IMPLEMENTED): + +class RESTfulUserImplementedAgentConfig( + AgentConfig, type=AgentType.RESTFUL_USER_IMPLEMENTED +): class EndpointConfig(BaseModel): url: str method: str = "POST" @@ -55,25 +63,33 @@ class RESTfulUserImplementedAgentConfig(AgentConfig, type=AgentType.RESTFUL_USER # generate_response: Optional[EndpointConfig] # update_last_bot_message_on_cut_off: Optional[EndpointConfig] + class RESTfulAgentInput(BaseModel): conversation_id: str human_input: str + class RESTfulAgentOutputType(str, Enum): BASE = "restful_agent_base" TEXT = "restful_agent_text" END = "restful_agent_end" + class RESTfulAgentOutput(TypedModel, type=RESTfulAgentOutputType.BASE): pass + class RESTfulAgentText(RESTfulAgentOutput, type=RESTfulAgentOutputType.TEXT): response: str + class RESTfulAgentEnd(RESTfulAgentOutput, type=RESTfulAgentOutputType.END): pass -class WebSocketUserImplementedAgentConfig(AgentConfig, type=AgentType.WEBSOCKET_USER_IMPLEMENTED): + +class WebSocketUserImplementedAgentConfig( + AgentConfig, type=AgentType.WEBSOCKET_USER_IMPLEMENTED +): class RouteConfig(BaseModel): url: str @@ -82,17 +98,22 @@ class WebSocketUserImplementedAgentConfig(AgentConfig, type=AgentType.WEBSOCKET_ # generate_response: Optional[RouteConfig] # send_message_on_cut_off: bool = False + class WebSocketAgentMessageType(str, Enum): - BASE = 'websocket_agent_base' - START = 'websocket_agent_start' - TEXT = 'websocket_agent_text' - READY = 'websocket_agent_ready' - STOP = 'websocket_agent_stop' + BASE = "websocket_agent_base" + START = "websocket_agent_start" + TEXT = "websocket_agent_text" + READY = "websocket_agent_ready" + STOP = "websocket_agent_stop" + class WebSocketAgentMessage(TypedModel, type=WebSocketAgentMessageType.BASE): conversation_id: Optional[str] = None -class WebSocketAgentTextMessage(WebSocketAgentMessage, type=WebSocketAgentMessageType.TEXT): + +class WebSocketAgentTextMessage( + WebSocketAgentMessage, type=WebSocketAgentMessageType.TEXT +): class Payload(BaseModel): text: str @@ -103,11 +124,19 @@ class WebSocketAgentTextMessage(WebSocketAgentMessage, type=WebSocketAgentMessag return cls(data=cls.Payload(text=text), conversation_id=conversation_id) -class WebSocketAgentStartMessage(WebSocketAgentMessage, type=WebSocketAgentMessageType.START): +class WebSocketAgentStartMessage( + WebSocketAgentMessage, type=WebSocketAgentMessageType.START +): pass -class WebSocketAgentReadyMessage(WebSocketAgentMessage, type=WebSocketAgentMessageType.READY): + +class WebSocketAgentReadyMessage( + WebSocketAgentMessage, type=WebSocketAgentMessageType.READY +): pass -class WebSocketAgentStopMessage(WebSocketAgentMessage, type=WebSocketAgentMessageType.STOP): + +class WebSocketAgentStopMessage( + WebSocketAgentMessage, type=WebSocketAgentMessageType.STOP +): pass diff --git a/vocode/user_implemented_agent/base_agent.py b/vocode/user_implemented_agent/base_agent.py index 4ddee94..8009be2 100644 --- a/vocode/user_implemented_agent/base_agent.py +++ b/vocode/user_implemented_agent/base_agent.py @@ -1,10 +1,10 @@ from fastapi import FastAPI import uvicorn -class BaseAgent(): +class BaseAgent: def __init__(self): self.app = FastAPI() - + def run(self, host="localhost", port=3000): - uvicorn.run(self.app, host=host, port=port) \ No newline at end of file + uvicorn.run(self.app, host=host, port=port)