add goodbye option
This commit is contained in:
parent
385d1386af
commit
9a0677e88d
3 changed files with 74 additions and 28 deletions
|
|
@ -4,8 +4,20 @@ import signal
|
|||
|
||||
from vocode.conversation import Conversation
|
||||
from vocode.helpers import create_microphone_input_and_speaker_output
|
||||
from vocode.models.transcriber import DeepgramTranscriberConfig, PunctuationEndpointingConfig, GoogleTranscriberConfig
|
||||
from vocode.models.agent import ChatGPTAgentConfig, RESTfulUserImplementedAgentConfig, WebSocketUserImplementedAgentConfig, EchoAgentConfig, ChatGPTAlphaAgentConfig, ChatGPTAgentConfig
|
||||
from vocode.models.transcriber import (
|
||||
DeepgramTranscriberConfig,
|
||||
PunctuationEndpointingConfig,
|
||||
GoogleTranscriberConfig,
|
||||
)
|
||||
from vocode.models.agent import (
|
||||
ChatGPTAgentConfig,
|
||||
RESTfulUserImplementedAgentConfig,
|
||||
WebSocketUserImplementedAgentConfig,
|
||||
EchoAgentConfig,
|
||||
ChatGPTAlphaAgentConfig,
|
||||
LLMAgentConfig,
|
||||
ChatGPTAgentConfig,
|
||||
)
|
||||
from vocode.models.synthesizer import AzureSynthesizerConfig
|
||||
from vocode.user_implemented_agent.restful_agent import RESTfulAgent
|
||||
|
||||
|
|
@ -14,25 +26,30 @@ logging.root.setLevel(logging.INFO)
|
|||
|
||||
|
||||
if __name__ == "__main__":
|
||||
microphone_input, speaker_output = create_microphone_input_and_speaker_output(use_default_devices=False)
|
||||
microphone_input, speaker_output = create_microphone_input_and_speaker_output(
|
||||
use_default_devices=False
|
||||
)
|
||||
|
||||
conversation = Conversation(
|
||||
input_device=microphone_input,
|
||||
output_device=speaker_output,
|
||||
transcriber_config=DeepgramTranscriberConfig.from_input_device(
|
||||
microphone_input,
|
||||
endpointing_config=PunctuationEndpointingConfig()
|
||||
microphone_input, endpointing_config=PunctuationEndpointingConfig()
|
||||
),
|
||||
agent_config=WebSocketUserImplementedAgentConfig(
|
||||
initial_message="Hello!",
|
||||
respond=WebSocketUserImplementedAgentConfig.RouteConfig(
|
||||
url="wss://8b7425d5b2ab.ngrok.io/respond",
|
||||
)
|
||||
# agent_config=WebSocketUserImplementedAgentConfig(
|
||||
# initial_message="Hello!",
|
||||
# respond=WebSocketUserImplementedAgentConfig.RouteConfig(
|
||||
# url="wss://8b7425d5b2ab.ngrok.io/respond",
|
||||
# )
|
||||
# ),
|
||||
# id="ajay",
|
||||
agent_config=ChatGPTAgentConfig(
|
||||
initial_message="goodbye",
|
||||
prompt_preamble="you are an expert on the NBA",
|
||||
generate_responses=True,
|
||||
end_conversation_on_goodbye=True,
|
||||
),
|
||||
id="ajay",
|
||||
# agent_config=ChatGPTAgentConfig(initial_message="hello", prompt_preamble="you are an expert on the NBA"),
|
||||
synthesizer_config=AzureSynthesizerConfig.from_output_device(speaker_output)
|
||||
synthesizer_config=AzureSynthesizerConfig.from_output_device(speaker_output),
|
||||
)
|
||||
signal.signal(signal.SIGINT, lambda _0, _1: conversation.deactivate())
|
||||
asyncio.run(conversation.start())
|
||||
|
||||
|
|
|
|||
|
|
@ -18,20 +18,25 @@ class AgentConfig(TypedModel, type=AgentType.BASE):
|
|||
initial_message: Optional[str] = None
|
||||
generate_responses: bool = True
|
||||
allowed_idle_time_seconds: Optional[float] = None
|
||||
end_conversation_on_goodbye: bool = False
|
||||
|
||||
|
||||
class LLMAgentConfig(AgentConfig, type=AgentType.LLM):
|
||||
prompt_preamble: str
|
||||
expected_first_prompt: Optional[str] = None
|
||||
|
||||
|
||||
class ChatGPTAlphaAgentConfig(AgentConfig, type=AgentType.CHAT_GPT_ALPHA):
|
||||
prompt_preamble: str
|
||||
expected_first_prompt: Optional[str] = None
|
||||
|
||||
|
||||
class ChatGPTAgentConfig(AgentConfig, type=AgentType.CHAT_GPT):
|
||||
prompt_preamble: str
|
||||
expected_first_prompt: Optional[str] = None
|
||||
generate_responses: bool = False
|
||||
|
||||
|
||||
class InformationRetrievalAgentConfig(
|
||||
AgentConfig, type=AgentType.INFORMATION_RETRIEVAL
|
||||
):
|
||||
|
|
@ -45,7 +50,10 @@ class InformationRetrievalAgentConfig(
|
|||
class EchoAgentConfig(AgentConfig, type=AgentType.ECHO):
|
||||
pass
|
||||
|
||||
class RESTfulUserImplementedAgentConfig(AgentConfig, type=AgentType.RESTFUL_USER_IMPLEMENTED):
|
||||
|
||||
class RESTfulUserImplementedAgentConfig(
|
||||
AgentConfig, type=AgentType.RESTFUL_USER_IMPLEMENTED
|
||||
):
|
||||
class EndpointConfig(BaseModel):
|
||||
url: str
|
||||
method: str = "POST"
|
||||
|
|
@ -55,25 +63,33 @@ class RESTfulUserImplementedAgentConfig(AgentConfig, type=AgentType.RESTFUL_USER
|
|||
# generate_response: Optional[EndpointConfig]
|
||||
# update_last_bot_message_on_cut_off: Optional[EndpointConfig]
|
||||
|
||||
|
||||
class RESTfulAgentInput(BaseModel):
|
||||
conversation_id: str
|
||||
human_input: str
|
||||
|
||||
|
||||
class RESTfulAgentOutputType(str, Enum):
|
||||
BASE = "restful_agent_base"
|
||||
TEXT = "restful_agent_text"
|
||||
END = "restful_agent_end"
|
||||
|
||||
|
||||
class RESTfulAgentOutput(TypedModel, type=RESTfulAgentOutputType.BASE):
|
||||
pass
|
||||
|
||||
|
||||
class RESTfulAgentText(RESTfulAgentOutput, type=RESTfulAgentOutputType.TEXT):
|
||||
response: str
|
||||
|
||||
|
||||
class RESTfulAgentEnd(RESTfulAgentOutput, type=RESTfulAgentOutputType.END):
|
||||
pass
|
||||
|
||||
class WebSocketUserImplementedAgentConfig(AgentConfig, type=AgentType.WEBSOCKET_USER_IMPLEMENTED):
|
||||
|
||||
class WebSocketUserImplementedAgentConfig(
|
||||
AgentConfig, type=AgentType.WEBSOCKET_USER_IMPLEMENTED
|
||||
):
|
||||
class RouteConfig(BaseModel):
|
||||
url: str
|
||||
|
||||
|
|
@ -82,17 +98,22 @@ class WebSocketUserImplementedAgentConfig(AgentConfig, type=AgentType.WEBSOCKET_
|
|||
# generate_response: Optional[RouteConfig]
|
||||
# send_message_on_cut_off: bool = False
|
||||
|
||||
|
||||
class WebSocketAgentMessageType(str, Enum):
|
||||
BASE = 'websocket_agent_base'
|
||||
START = 'websocket_agent_start'
|
||||
TEXT = 'websocket_agent_text'
|
||||
READY = 'websocket_agent_ready'
|
||||
STOP = 'websocket_agent_stop'
|
||||
BASE = "websocket_agent_base"
|
||||
START = "websocket_agent_start"
|
||||
TEXT = "websocket_agent_text"
|
||||
READY = "websocket_agent_ready"
|
||||
STOP = "websocket_agent_stop"
|
||||
|
||||
|
||||
class WebSocketAgentMessage(TypedModel, type=WebSocketAgentMessageType.BASE):
|
||||
conversation_id: Optional[str] = None
|
||||
|
||||
class WebSocketAgentTextMessage(WebSocketAgentMessage, type=WebSocketAgentMessageType.TEXT):
|
||||
|
||||
class WebSocketAgentTextMessage(
|
||||
WebSocketAgentMessage, type=WebSocketAgentMessageType.TEXT
|
||||
):
|
||||
class Payload(BaseModel):
|
||||
text: str
|
||||
|
||||
|
|
@ -103,11 +124,19 @@ class WebSocketAgentTextMessage(WebSocketAgentMessage, type=WebSocketAgentMessag
|
|||
return cls(data=cls.Payload(text=text), conversation_id=conversation_id)
|
||||
|
||||
|
||||
class WebSocketAgentStartMessage(WebSocketAgentMessage, type=WebSocketAgentMessageType.START):
|
||||
class WebSocketAgentStartMessage(
|
||||
WebSocketAgentMessage, type=WebSocketAgentMessageType.START
|
||||
):
|
||||
pass
|
||||
|
||||
class WebSocketAgentReadyMessage(WebSocketAgentMessage, type=WebSocketAgentMessageType.READY):
|
||||
|
||||
class WebSocketAgentReadyMessage(
|
||||
WebSocketAgentMessage, type=WebSocketAgentMessageType.READY
|
||||
):
|
||||
pass
|
||||
|
||||
class WebSocketAgentStopMessage(WebSocketAgentMessage, type=WebSocketAgentMessageType.STOP):
|
||||
|
||||
class WebSocketAgentStopMessage(
|
||||
WebSocketAgentMessage, type=WebSocketAgentMessageType.STOP
|
||||
):
|
||||
pass
|
||||
|
|
|
|||
|
|
@ -1,10 +1,10 @@
|
|||
from fastapi import FastAPI
|
||||
import uvicorn
|
||||
|
||||
class BaseAgent():
|
||||
|
||||
class BaseAgent:
|
||||
def __init__(self):
|
||||
self.app = FastAPI()
|
||||
|
||||
|
||||
def run(self, host="localhost", port=3000):
|
||||
uvicorn.run(self.app, host=host, port=port)
|
||||
uvicorn.run(self.app, host=host, port=port)
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue