conversation ID tracking

This commit is contained in:
Ajay Raj 2023-03-08 15:33:59 -08:00
commit 3ad8a109e7
9 changed files with 24 additions and 13 deletions

View file

@ -1,6 +1,6 @@
[tool.poetry] [tool.poetry]
name = "vocode" name = "vocode"
version = "0.1.26" version = "0.1.28"
description = "The all-in-one voice SDK" description = "The all-in-one voice SDK"
authors = ["Ajay Raj <ajay@vocode.dev>"] authors = ["Ajay Raj <ajay@vocode.dev>"]
license = "MIT License" license = "MIT License"

View file

@ -4,7 +4,7 @@ import signal
from vocode.conversation import Conversation from vocode.conversation import Conversation
from vocode.helpers import create_microphone_input_and_speaker_output from vocode.helpers import create_microphone_input_and_speaker_output
from vocode.models.transcriber import DeepgramTranscriberConfig, PunctuationEndpointingConfig from vocode.models.transcriber import DeepgramTranscriberConfig, PunctuationEndpointingConfig, GoogleTranscriberConfig
from vocode.models.agent import ChatGPTAgentConfig, RESTfulUserImplementedAgentConfig, WebSocketUserImplementedAgentConfig, EchoAgentConfig, ChatGPTAlphaAgentConfig, ChatGPTAgentConfig from vocode.models.agent import ChatGPTAgentConfig, RESTfulUserImplementedAgentConfig, WebSocketUserImplementedAgentConfig, EchoAgentConfig, ChatGPTAlphaAgentConfig, ChatGPTAgentConfig
from vocode.models.synthesizer import AzureSynthesizerConfig from vocode.models.synthesizer import AzureSynthesizerConfig
from vocode.user_implemented_agent.restful_agent import RESTfulAgent from vocode.user_implemented_agent.restful_agent import RESTfulAgent
@ -26,9 +26,11 @@ if __name__ == "__main__":
agent_config=WebSocketUserImplementedAgentConfig( agent_config=WebSocketUserImplementedAgentConfig(
initial_message="Hello!", initial_message="Hello!",
respond=WebSocketUserImplementedAgentConfig.RouteConfig( respond=WebSocketUserImplementedAgentConfig.RouteConfig(
url="wss://1d7e6ab7b588.ngrok.io/respond", url="wss://8b7425d5b2ab.ngrok.io/respond",
) )
), ),
id="ajay",
# agent_config=ChatGPTAgentConfig(initial_message="hello", prompt_preamble="you are an expert on the NBA"),
synthesizer_config=AzureSynthesizerConfig.from_output_device(speaker_output) synthesizer_config=AzureSynthesizerConfig.from_output_device(speaker_output)
) )
signal.signal(signal.SIGINT, lambda _0, _1: conversation.deactivate()) signal.signal(signal.SIGINT, lambda _0, _1: conversation.deactivate())

View file

@ -4,8 +4,8 @@ from vocode.user_implemented_agent.websocket_agent import WebSocketAgent
class TestRESTfulAgent(RESTfulAgent): class TestRESTfulAgent(RESTfulAgent):
async def respond(self, input: str) -> RESTfulAgentOutput: async def respond(self, input: str, conversation_id: str) -> RESTfulAgentOutput:
print(input) print(input, conversation_id)
if "bye" in input: if "bye" in input:
return RESTfulAgentEnd() return RESTfulAgentEnd()
else: else:
@ -24,4 +24,4 @@ class TestWebSocketAgent(WebSocketAgent):
if __name__ == "__main__": if __name__ == "__main__":
agent = TestWebSocketAgent() agent = TestWebSocketAgent()
agent.run(port=3000) agent.run(port=3001)

View file

@ -28,7 +28,9 @@ class Conversation:
transcriber_config: TranscriberConfig, transcriber_config: TranscriberConfig,
agent_config: AgentConfig, agent_config: AgentConfig,
synthesizer_config: SynthesizerConfig, synthesizer_config: SynthesizerConfig,
id: str = None,
): ):
self.id = id
self.input_device = input_device self.input_device = input_device
self.output_device = output_device self.output_device = output_device
self.transcriber_config = transcriber_config self.transcriber_config = transcriber_config
@ -68,6 +70,7 @@ class Conversation:
transcriber_config=self.transcriber_config, transcriber_config=self.transcriber_config,
agent_config=self.agent_config, agent_config=self.agent_config,
synthesizer_config=self.synthesizer_config, synthesizer_config=self.synthesizer_config,
conversation_id=self.id
) )
await ws.send(start_message.json()) await ws.send(start_message.json())
await self.wait_for_ready() await self.wait_for_ready()

View file

@ -56,6 +56,7 @@ class RESTfulUserImplementedAgentConfig(AgentConfig, type=AgentType.RESTFUL_USER
# update_last_bot_message_on_cut_off: Optional[EndpointConfig] # update_last_bot_message_on_cut_off: Optional[EndpointConfig]
class RESTfulAgentInput(BaseModel): class RESTfulAgentInput(BaseModel):
conversation_id: str
human_input: str human_input: str
class RESTfulAgentOutputType(str, Enum): class RESTfulAgentOutputType(str, Enum):
@ -88,7 +89,8 @@ class WebSocketAgentMessageType(str, Enum):
READY = 'websocket_agent_ready' READY = 'websocket_agent_ready'
STOP = 'websocket_agent_stop' STOP = 'websocket_agent_stop'
class WebSocketAgentMessage(TypedModel, type=WebSocketAgentMessageType.BASE): pass class WebSocketAgentMessage(TypedModel, type=WebSocketAgentMessageType.BASE):
conversation_id: Optional[str] = None
class WebSocketAgentTextMessage(WebSocketAgentMessage, type=WebSocketAgentMessageType.TEXT): class WebSocketAgentTextMessage(WebSocketAgentMessage, type=WebSocketAgentMessageType.TEXT):
class Payload(BaseModel): class Payload(BaseModel):
@ -97,8 +99,8 @@ class WebSocketAgentTextMessage(WebSocketAgentMessage, type=WebSocketAgentMessag
data: Payload data: Payload
@classmethod @classmethod
def from_text(cls, text: str): def from_text(cls, text: str, conversation_id: Optional[str] = None):
return cls(data=cls.Payload(text=text)) return cls(data=cls.Payload(text=text), conversation_id=conversation_id)
class WebSocketAgentStartMessage(WebSocketAgentMessage, type=WebSocketAgentMessageType.START): class WebSocketAgentStartMessage(WebSocketAgentMessage, type=WebSocketAgentMessageType.START):

View file

@ -11,6 +11,7 @@ class CallEntity(BaseModel):
class CreateInboundCall(BaseModel): class CreateInboundCall(BaseModel):
agent_config: AgentConfig agent_config: AgentConfig
twilio_sid: str twilio_sid: str
conversation_id: Optional[str] = None
class CreateOutboundCall(BaseModel): class CreateOutboundCall(BaseModel):
@ -18,4 +19,5 @@ class CreateOutboundCall(BaseModel):
caller: CallEntity caller: CallEntity
agent_config: AgentConfig agent_config: AgentConfig
synthesizer_config: Optional[SynthesizerConfig] = None synthesizer_config: Optional[SynthesizerConfig] = None
conversation_id: Optional[str] = None
# TODO add IVR/etc. # TODO add IVR/etc.

View file

@ -1,5 +1,6 @@
import base64 import base64
from enum import Enum from enum import Enum
from typing import Optional
from .model import TypedModel from .model import TypedModel
from .transcriber import TranscriberConfig from .transcriber import TranscriberConfig
from .agent import AgentConfig from .agent import AgentConfig
@ -28,6 +29,7 @@ class StartMessage(WebSocketMessage, type=WebSocketMessageType.START):
transcriber_config: TranscriberConfig transcriber_config: TranscriberConfig
agent_config: AgentConfig agent_config: AgentConfig
synthesizer_config: SynthesizerConfig synthesizer_config: SynthesizerConfig
conversation_id: Optional[str] = None
class ReadyMessage(WebSocketMessage, type=WebSocketMessageType.READY): class ReadyMessage(WebSocketMessage, type=WebSocketMessageType.READY):
pass pass

View file

@ -10,10 +10,10 @@ class RESTfulAgent(BaseAgent):
super().__init__() super().__init__()
self.app.post("/respond")(self.respond_rest) self.app.post("/respond")(self.respond_rest)
async def respond(self, human_input) -> RESTfulAgentOutput: async def respond(self, human_input, conversation_id) -> RESTfulAgentOutput:
raise NotImplementedError raise NotImplementedError
async def respond_rest(self, request: RESTfulAgentInput) -> Union[RESTfulAgentText, RESTfulAgentEnd]: async def respond_rest(self, request: RESTfulAgentInput) -> Union[RESTfulAgentText, RESTfulAgentEnd]:
response = await self.respond(request.human_input) response = await self.respond(request.human_input, request.conversation_id)
return response return response

View file

@ -23,7 +23,6 @@ class WebSocketAgent(BaseAgent):
async def respond_websocket(self, websocket: WebSocket): async def respond_websocket(self, websocket: WebSocket):
await websocket.accept() await websocket.accept()
conversation_id = str(uuid.uuid4())
WebSocketAgentStartMessage.parse_obj(await websocket.receive_json()) WebSocketAgentStartMessage.parse_obj(await websocket.receive_json())
await websocket.send_text(WebSocketAgentReadyMessage().json()) await websocket.send_text(WebSocketAgentReadyMessage().json())
while True: while True:
@ -31,7 +30,8 @@ class WebSocketAgent(BaseAgent):
if input_message.type == WebSocketAgentMessageType.STOP: if input_message.type == WebSocketAgentMessageType.STOP:
break break
text_message = typing.cast(WebSocketAgentTextMessage, input_message) text_message = typing.cast(WebSocketAgentTextMessage, input_message)
output_response = await self.respond(text_message.data.text, conversation_id=conversation_id) print(text_message)
output_response = await self.respond(text_message.data.text, text_message.conversation_id)
await websocket.send_text(output_response.json()) await websocket.send_text(output_response.json())
await websocket.close() await websocket.close()