conversation ID tracking

This commit is contained in:
Ajay Raj 2023-03-08 15:33:59 -08:00
commit 3ad8a109e7
9 changed files with 24 additions and 13 deletions

View file

@ -28,7 +28,9 @@ class Conversation:
transcriber_config: TranscriberConfig,
agent_config: AgentConfig,
synthesizer_config: SynthesizerConfig,
id: str = None,
):
self.id = id
self.input_device = input_device
self.output_device = output_device
self.transcriber_config = transcriber_config
@ -68,6 +70,7 @@ class Conversation:
transcriber_config=self.transcriber_config,
agent_config=self.agent_config,
synthesizer_config=self.synthesizer_config,
conversation_id=self.id
)
await ws.send(start_message.json())
await self.wait_for_ready()

View file

@ -56,6 +56,7 @@ class RESTfulUserImplementedAgentConfig(AgentConfig, type=AgentType.RESTFUL_USER
# update_last_bot_message_on_cut_off: Optional[EndpointConfig]
class RESTfulAgentInput(BaseModel):
conversation_id: str
human_input: str
class RESTfulAgentOutputType(str, Enum):
@ -88,7 +89,8 @@ class WebSocketAgentMessageType(str, Enum):
READY = 'websocket_agent_ready'
STOP = 'websocket_agent_stop'
class WebSocketAgentMessage(TypedModel, type=WebSocketAgentMessageType.BASE): pass
class WebSocketAgentMessage(TypedModel, type=WebSocketAgentMessageType.BASE):
conversation_id: Optional[str] = None
class WebSocketAgentTextMessage(WebSocketAgentMessage, type=WebSocketAgentMessageType.TEXT):
class Payload(BaseModel):
@ -97,8 +99,8 @@ class WebSocketAgentTextMessage(WebSocketAgentMessage, type=WebSocketAgentMessag
data: Payload
@classmethod
def from_text(cls, text: str):
return cls(data=cls.Payload(text=text))
def from_text(cls, text: str, conversation_id: Optional[str] = None):
return cls(data=cls.Payload(text=text), conversation_id=conversation_id)
class WebSocketAgentStartMessage(WebSocketAgentMessage, type=WebSocketAgentMessageType.START):

View file

@ -11,6 +11,7 @@ class CallEntity(BaseModel):
class CreateInboundCall(BaseModel):
agent_config: AgentConfig
twilio_sid: str
conversation_id: Optional[str] = None
class CreateOutboundCall(BaseModel):
@ -18,4 +19,5 @@ class CreateOutboundCall(BaseModel):
caller: CallEntity
agent_config: AgentConfig
synthesizer_config: Optional[SynthesizerConfig] = None
conversation_id: Optional[str] = None
# TODO add IVR/etc.

View file

@ -1,5 +1,6 @@
import base64
from enum import Enum
from typing import Optional
from .model import TypedModel
from .transcriber import TranscriberConfig
from .agent import AgentConfig
@ -28,6 +29,7 @@ class StartMessage(WebSocketMessage, type=WebSocketMessageType.START):
transcriber_config: TranscriberConfig
agent_config: AgentConfig
synthesizer_config: SynthesizerConfig
conversation_id: Optional[str] = None
class ReadyMessage(WebSocketMessage, type=WebSocketMessageType.READY):
pass

View file

@ -10,10 +10,10 @@ class RESTfulAgent(BaseAgent):
super().__init__()
self.app.post("/respond")(self.respond_rest)
async def respond(self, human_input) -> RESTfulAgentOutput:
async def respond(self, human_input, conversation_id) -> RESTfulAgentOutput:
raise NotImplementedError
async def respond_rest(self, request: RESTfulAgentInput) -> Union[RESTfulAgentText, RESTfulAgentEnd]:
response = await self.respond(request.human_input)
response = await self.respond(request.human_input, request.conversation_id)
return response

View file

@ -23,7 +23,6 @@ class WebSocketAgent(BaseAgent):
async def respond_websocket(self, websocket: WebSocket):
await websocket.accept()
conversation_id = str(uuid.uuid4())
WebSocketAgentStartMessage.parse_obj(await websocket.receive_json())
await websocket.send_text(WebSocketAgentReadyMessage().json())
while True:
@ -31,7 +30,8 @@ class WebSocketAgent(BaseAgent):
if input_message.type == WebSocketAgentMessageType.STOP:
break
text_message = typing.cast(WebSocketAgentTextMessage, input_message)
output_response = await self.respond(text_message.data.text, conversation_id=conversation_id)
print(text_message)
output_response = await self.respond(text_message.data.text, text_message.conversation_id)
await websocket.send_text(output_response.json())
await websocket.close()