add goodbye option

This commit is contained in:
Ajay Raj 2023-03-08 20:41:03 -08:00
commit 9a0677e88d
3 changed files with 74 additions and 28 deletions

View file

@ -4,8 +4,20 @@ import signal
from vocode.conversation import Conversation
from vocode.helpers import create_microphone_input_and_speaker_output
from vocode.models.transcriber import DeepgramTranscriberConfig, PunctuationEndpointingConfig, GoogleTranscriberConfig
from vocode.models.agent import ChatGPTAgentConfig, RESTfulUserImplementedAgentConfig, WebSocketUserImplementedAgentConfig, EchoAgentConfig, ChatGPTAlphaAgentConfig, ChatGPTAgentConfig
from vocode.models.transcriber import (
DeepgramTranscriberConfig,
PunctuationEndpointingConfig,
GoogleTranscriberConfig,
)
from vocode.models.agent import (
ChatGPTAgentConfig,
RESTfulUserImplementedAgentConfig,
WebSocketUserImplementedAgentConfig,
EchoAgentConfig,
ChatGPTAlphaAgentConfig,
LLMAgentConfig,
ChatGPTAgentConfig,
)
from vocode.models.synthesizer import AzureSynthesizerConfig
from vocode.user_implemented_agent.restful_agent import RESTfulAgent
@ -14,25 +26,30 @@ logging.root.setLevel(logging.INFO)
if __name__ == "__main__":
microphone_input, speaker_output = create_microphone_input_and_speaker_output(use_default_devices=False)
microphone_input, speaker_output = create_microphone_input_and_speaker_output(
use_default_devices=False
)
conversation = Conversation(
input_device=microphone_input,
output_device=speaker_output,
transcriber_config=DeepgramTranscriberConfig.from_input_device(
microphone_input,
endpointing_config=PunctuationEndpointingConfig()
microphone_input, endpointing_config=PunctuationEndpointingConfig()
),
agent_config=WebSocketUserImplementedAgentConfig(
initial_message="Hello!",
respond=WebSocketUserImplementedAgentConfig.RouteConfig(
url="wss://8b7425d5b2ab.ngrok.io/respond",
)
# agent_config=WebSocketUserImplementedAgentConfig(
# initial_message="Hello!",
# respond=WebSocketUserImplementedAgentConfig.RouteConfig(
# url="wss://8b7425d5b2ab.ngrok.io/respond",
# )
# ),
# id="ajay",
agent_config=ChatGPTAgentConfig(
initial_message="goodbye",
prompt_preamble="you are an expert on the NBA",
generate_responses=True,
end_conversation_on_goodbye=True,
),
id="ajay",
# agent_config=ChatGPTAgentConfig(initial_message="hello", prompt_preamble="you are an expert on the NBA"),
synthesizer_config=AzureSynthesizerConfig.from_output_device(speaker_output)
synthesizer_config=AzureSynthesizerConfig.from_output_device(speaker_output),
)
signal.signal(signal.SIGINT, lambda _0, _1: conversation.deactivate())
asyncio.run(conversation.start())

View file

@ -18,20 +18,25 @@ class AgentConfig(TypedModel, type=AgentType.BASE):
initial_message: Optional[str] = None
generate_responses: bool = True
allowed_idle_time_seconds: Optional[float] = None
end_conversation_on_goodbye: bool = False
class LLMAgentConfig(AgentConfig, type=AgentType.LLM):
prompt_preamble: str
expected_first_prompt: Optional[str] = None
class ChatGPTAlphaAgentConfig(AgentConfig, type=AgentType.CHAT_GPT_ALPHA):
prompt_preamble: str
expected_first_prompt: Optional[str] = None
class ChatGPTAgentConfig(AgentConfig, type=AgentType.CHAT_GPT):
prompt_preamble: str
expected_first_prompt: Optional[str] = None
generate_responses: bool = False
class InformationRetrievalAgentConfig(
AgentConfig, type=AgentType.INFORMATION_RETRIEVAL
):
@ -45,7 +50,10 @@ class InformationRetrievalAgentConfig(
class EchoAgentConfig(AgentConfig, type=AgentType.ECHO):
pass
class RESTfulUserImplementedAgentConfig(AgentConfig, type=AgentType.RESTFUL_USER_IMPLEMENTED):
class RESTfulUserImplementedAgentConfig(
AgentConfig, type=AgentType.RESTFUL_USER_IMPLEMENTED
):
class EndpointConfig(BaseModel):
url: str
method: str = "POST"
@ -55,25 +63,33 @@ class RESTfulUserImplementedAgentConfig(AgentConfig, type=AgentType.RESTFUL_USER
# generate_response: Optional[EndpointConfig]
# update_last_bot_message_on_cut_off: Optional[EndpointConfig]
class RESTfulAgentInput(BaseModel):
conversation_id: str
human_input: str
class RESTfulAgentOutputType(str, Enum):
BASE = "restful_agent_base"
TEXT = "restful_agent_text"
END = "restful_agent_end"
class RESTfulAgentOutput(TypedModel, type=RESTfulAgentOutputType.BASE):
pass
class RESTfulAgentText(RESTfulAgentOutput, type=RESTfulAgentOutputType.TEXT):
response: str
class RESTfulAgentEnd(RESTfulAgentOutput, type=RESTfulAgentOutputType.END):
pass
class WebSocketUserImplementedAgentConfig(AgentConfig, type=AgentType.WEBSOCKET_USER_IMPLEMENTED):
class WebSocketUserImplementedAgentConfig(
AgentConfig, type=AgentType.WEBSOCKET_USER_IMPLEMENTED
):
class RouteConfig(BaseModel):
url: str
@ -82,17 +98,22 @@ class WebSocketUserImplementedAgentConfig(AgentConfig, type=AgentType.WEBSOCKET_
# generate_response: Optional[RouteConfig]
# send_message_on_cut_off: bool = False
class WebSocketAgentMessageType(str, Enum):
BASE = 'websocket_agent_base'
START = 'websocket_agent_start'
TEXT = 'websocket_agent_text'
READY = 'websocket_agent_ready'
STOP = 'websocket_agent_stop'
BASE = "websocket_agent_base"
START = "websocket_agent_start"
TEXT = "websocket_agent_text"
READY = "websocket_agent_ready"
STOP = "websocket_agent_stop"
class WebSocketAgentMessage(TypedModel, type=WebSocketAgentMessageType.BASE):
conversation_id: Optional[str] = None
class WebSocketAgentTextMessage(WebSocketAgentMessage, type=WebSocketAgentMessageType.TEXT):
class WebSocketAgentTextMessage(
WebSocketAgentMessage, type=WebSocketAgentMessageType.TEXT
):
class Payload(BaseModel):
text: str
@ -103,11 +124,19 @@ class WebSocketAgentTextMessage(WebSocketAgentMessage, type=WebSocketAgentMessag
return cls(data=cls.Payload(text=text), conversation_id=conversation_id)
class WebSocketAgentStartMessage(WebSocketAgentMessage, type=WebSocketAgentMessageType.START):
class WebSocketAgentStartMessage(
WebSocketAgentMessage, type=WebSocketAgentMessageType.START
):
pass
class WebSocketAgentReadyMessage(WebSocketAgentMessage, type=WebSocketAgentMessageType.READY):
class WebSocketAgentReadyMessage(
WebSocketAgentMessage, type=WebSocketAgentMessageType.READY
):
pass
class WebSocketAgentStopMessage(WebSocketAgentMessage, type=WebSocketAgentMessageType.STOP):
class WebSocketAgentStopMessage(
WebSocketAgentMessage, type=WebSocketAgentMessageType.STOP
):
pass

View file

@ -1,10 +1,10 @@
from fastapi import FastAPI
import uvicorn
class BaseAgent():
class BaseAgent:
def __init__(self):
self.app = FastAPI()
def run(self, host="localhost", port=3000):
uvicorn.run(self.app, host=host, port=port)
uvicorn.run(self.app, host=host, port=port)