From 3ad8a109e73feb468eb72eee2886d3625e1443b5 Mon Sep 17 00:00:00 2001 From: Ajay Raj Date: Wed, 8 Mar 2023 15:33:59 -0800 Subject: [PATCH] conversation ID tracking --- pyproject.toml | 2 +- simple_conversation.py | 6 ++++-- simple_user_implemented_agent.py | 6 +++--- vocode/conversation.py | 3 +++ vocode/models/agent.py | 8 +++++--- vocode/models/telephony.py | 2 ++ vocode/models/websocket.py | 2 ++ vocode/user_implemented_agent/restful_agent.py | 4 ++-- vocode/user_implemented_agent/websocket_agent.py | 4 ++-- 9 files changed, 24 insertions(+), 13 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index 0eaa8e9..da4b6e1 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,6 +1,6 @@ [tool.poetry] name = "vocode" -version = "0.1.26" +version = "0.1.28" description = "The all-in-one voice SDK" authors = ["Ajay Raj "] license = "MIT License" diff --git a/simple_conversation.py b/simple_conversation.py index f152e90..8cf4f3b 100644 --- a/simple_conversation.py +++ b/simple_conversation.py @@ -4,7 +4,7 @@ import signal from vocode.conversation import Conversation from vocode.helpers import create_microphone_input_and_speaker_output -from vocode.models.transcriber import DeepgramTranscriberConfig, PunctuationEndpointingConfig +from vocode.models.transcriber import DeepgramTranscriberConfig, PunctuationEndpointingConfig, GoogleTranscriberConfig from vocode.models.agent import ChatGPTAgentConfig, RESTfulUserImplementedAgentConfig, WebSocketUserImplementedAgentConfig, EchoAgentConfig, ChatGPTAlphaAgentConfig, ChatGPTAgentConfig from vocode.models.synthesizer import AzureSynthesizerConfig from vocode.user_implemented_agent.restful_agent import RESTfulAgent @@ -26,9 +26,11 @@ if __name__ == "__main__": agent_config=WebSocketUserImplementedAgentConfig( initial_message="Hello!", respond=WebSocketUserImplementedAgentConfig.RouteConfig( - url="wss://1d7e6ab7b588.ngrok.io/respond", + url="wss://8b7425d5b2ab.ngrok.io/respond", ) ), + id="ajay", + # agent_config=ChatGPTAgentConfig(initial_message="hello", prompt_preamble="you are an expert on the NBA"), synthesizer_config=AzureSynthesizerConfig.from_output_device(speaker_output) ) signal.signal(signal.SIGINT, lambda _0, _1: conversation.deactivate()) diff --git a/simple_user_implemented_agent.py b/simple_user_implemented_agent.py index d035332..c4ca065 100644 --- a/simple_user_implemented_agent.py +++ b/simple_user_implemented_agent.py @@ -4,8 +4,8 @@ from vocode.user_implemented_agent.websocket_agent import WebSocketAgent class TestRESTfulAgent(RESTfulAgent): - async def respond(self, input: str) -> RESTfulAgentOutput: - print(input) + async def respond(self, input: str, conversation_id: str) -> RESTfulAgentOutput: + print(input, conversation_id) if "bye" in input: return RESTfulAgentEnd() else: @@ -24,4 +24,4 @@ class TestWebSocketAgent(WebSocketAgent): if __name__ == "__main__": agent = TestWebSocketAgent() - agent.run(port=3000) + agent.run(port=3001) diff --git a/vocode/conversation.py b/vocode/conversation.py index 02bf99a..89ace0a 100644 --- a/vocode/conversation.py +++ b/vocode/conversation.py @@ -28,7 +28,9 @@ class Conversation: transcriber_config: TranscriberConfig, agent_config: AgentConfig, synthesizer_config: SynthesizerConfig, + id: str = None, ): + self.id = id self.input_device = input_device self.output_device = output_device self.transcriber_config = transcriber_config @@ -68,6 +70,7 @@ class Conversation: transcriber_config=self.transcriber_config, agent_config=self.agent_config, synthesizer_config=self.synthesizer_config, + conversation_id=self.id ) await ws.send(start_message.json()) await self.wait_for_ready() diff --git a/vocode/models/agent.py b/vocode/models/agent.py index 9bae43e..f7832bb 100644 --- a/vocode/models/agent.py +++ b/vocode/models/agent.py @@ -56,6 +56,7 @@ class RESTfulUserImplementedAgentConfig(AgentConfig, type=AgentType.RESTFUL_USER # update_last_bot_message_on_cut_off: Optional[EndpointConfig] class RESTfulAgentInput(BaseModel): + conversation_id: str human_input: str class RESTfulAgentOutputType(str, Enum): @@ -88,7 +89,8 @@ class WebSocketAgentMessageType(str, Enum): READY = 'websocket_agent_ready' STOP = 'websocket_agent_stop' -class WebSocketAgentMessage(TypedModel, type=WebSocketAgentMessageType.BASE): pass +class WebSocketAgentMessage(TypedModel, type=WebSocketAgentMessageType.BASE): + conversation_id: Optional[str] = None class WebSocketAgentTextMessage(WebSocketAgentMessage, type=WebSocketAgentMessageType.TEXT): class Payload(BaseModel): @@ -97,8 +99,8 @@ class WebSocketAgentTextMessage(WebSocketAgentMessage, type=WebSocketAgentMessag data: Payload @classmethod - def from_text(cls, text: str): - return cls(data=cls.Payload(text=text)) + def from_text(cls, text: str, conversation_id: Optional[str] = None): + return cls(data=cls.Payload(text=text), conversation_id=conversation_id) class WebSocketAgentStartMessage(WebSocketAgentMessage, type=WebSocketAgentMessageType.START): diff --git a/vocode/models/telephony.py b/vocode/models/telephony.py index 1bf57ec..8e24e6a 100644 --- a/vocode/models/telephony.py +++ b/vocode/models/telephony.py @@ -11,6 +11,7 @@ class CallEntity(BaseModel): class CreateInboundCall(BaseModel): agent_config: AgentConfig twilio_sid: str + conversation_id: Optional[str] = None class CreateOutboundCall(BaseModel): @@ -18,4 +19,5 @@ class CreateOutboundCall(BaseModel): caller: CallEntity agent_config: AgentConfig synthesizer_config: Optional[SynthesizerConfig] = None + conversation_id: Optional[str] = None # TODO add IVR/etc. diff --git a/vocode/models/websocket.py b/vocode/models/websocket.py index b88d116..b398ff2 100644 --- a/vocode/models/websocket.py +++ b/vocode/models/websocket.py @@ -1,5 +1,6 @@ import base64 from enum import Enum +from typing import Optional from .model import TypedModel from .transcriber import TranscriberConfig from .agent import AgentConfig @@ -28,6 +29,7 @@ class StartMessage(WebSocketMessage, type=WebSocketMessageType.START): transcriber_config: TranscriberConfig agent_config: AgentConfig synthesizer_config: SynthesizerConfig + conversation_id: Optional[str] = None class ReadyMessage(WebSocketMessage, type=WebSocketMessageType.READY): pass diff --git a/vocode/user_implemented_agent/restful_agent.py b/vocode/user_implemented_agent/restful_agent.py index 9d346a0..ecae491 100644 --- a/vocode/user_implemented_agent/restful_agent.py +++ b/vocode/user_implemented_agent/restful_agent.py @@ -10,10 +10,10 @@ class RESTfulAgent(BaseAgent): super().__init__() self.app.post("/respond")(self.respond_rest) - async def respond(self, human_input) -> RESTfulAgentOutput: + async def respond(self, human_input, conversation_id) -> RESTfulAgentOutput: raise NotImplementedError async def respond_rest(self, request: RESTfulAgentInput) -> Union[RESTfulAgentText, RESTfulAgentEnd]: - response = await self.respond(request.human_input) + response = await self.respond(request.human_input, request.conversation_id) return response diff --git a/vocode/user_implemented_agent/websocket_agent.py b/vocode/user_implemented_agent/websocket_agent.py index 6468e08..67f5581 100644 --- a/vocode/user_implemented_agent/websocket_agent.py +++ b/vocode/user_implemented_agent/websocket_agent.py @@ -23,7 +23,6 @@ class WebSocketAgent(BaseAgent): async def respond_websocket(self, websocket: WebSocket): await websocket.accept() - conversation_id = str(uuid.uuid4()) WebSocketAgentStartMessage.parse_obj(await websocket.receive_json()) await websocket.send_text(WebSocketAgentReadyMessage().json()) while True: @@ -31,7 +30,8 @@ class WebSocketAgent(BaseAgent): if input_message.type == WebSocketAgentMessageType.STOP: break text_message = typing.cast(WebSocketAgentTextMessage, input_message) - output_response = await self.respond(text_message.data.text, conversation_id=conversation_id) + print(text_message) + output_response = await self.respond(text_message.data.text, text_message.conversation_id) await websocket.send_text(output_response.json()) await websocket.close()