sets up websocket client streaming

This commit is contained in:
Ajay Raj 2023-03-14 00:32:20 -07:00
commit 472f553ea0
4 changed files with 76 additions and 29 deletions

View file

@ -46,12 +46,12 @@ if __name__ == "__main__":
transcriber_config=DeepgramTranscriberConfig.from_input_device( transcriber_config=DeepgramTranscriberConfig.from_input_device(
microphone_input microphone_input
), ),
agent_config=ChatGPTAgentConfig( agent_config=WebSocketUserImplementedAgentConfig(
initial_message=BaseMessage(text="Hello!"), initial_message=BaseMessage(text="Hello!"),
prompt_preamble="The AI is having a pleasant conversation about life.", generate_responses=True,
generate_responses=False, respond=WebSocketUserImplementedAgentConfig.RouteConfig(
end_conversation_on_goodbye=True, url="wss://9b1ff0eee874.ngrok.app/respond",
send_filler_audio=FillerAudioConfig(use_typing_noise=True), ),
), ),
synthesizer_config=AzureSynthesizerConfig.from_output_device(speaker_output), synthesizer_config=AzureSynthesizerConfig.from_output_device(speaker_output),
) )

View file

@ -1,27 +1,48 @@
from typing import AsyncGenerator
from vocode.user_implemented_agent.restful_agent import RESTfulAgent from vocode.user_implemented_agent.restful_agent import RESTfulAgent
from vocode.models.agent import RESTfulAgentOutput, RESTfulAgentText, RESTfulAgentEnd, WebSocketAgentMessage, WebSocketAgentTextMessage, WebSocketAgentStopMessage from vocode.models.agent import (
from vocode.user_implemented_agent.websocket_agent import WebSocketAgent RESTfulAgentOutput,
RESTfulAgentText,
RESTfulAgentEnd,
WebSocketAgentMessage,
WebSocketAgentTextEndMessage,
WebSocketAgentTextMessage,
WebSocketAgentStopMessage,
)
from vocode.user_implemented_agent.websocket_agent import WebSocketAgent
class TestRESTfulAgent(RESTfulAgent): class TestRESTfulAgent(RESTfulAgent):
async def respond(self, input: str, conversation_id: str) -> RESTfulAgentOutput: async def respond(self, input: str, conversation_id: str) -> RESTfulAgentOutput:
print(input, conversation_id) print(input, conversation_id)
if "bye" in input: if "bye" in input:
return RESTfulAgentEnd() return RESTfulAgentEnd()
else: else:
spelt = ''.join(i + j for i, j in zip(input, ' ' * len(input))) spelt = "".join(i + j for i, j in zip(input, " " * len(input)))
return RESTfulAgentText(response=spelt) return RESTfulAgentText(response=spelt)
class TestWebSocketAgent(WebSocketAgent):
class TestWebSocketAgent(WebSocketAgent):
async def respond(self, input: str, conversation_id: str) -> WebSocketAgentMessage: async def respond(self, input: str, conversation_id: str) -> WebSocketAgentMessage:
print(input, conversation_id) print(input, conversation_id)
if "bye" in input: if "bye" in input:
return WebSocketAgentStopMessage() return WebSocketAgentStopMessage()
else: else:
spelt = ''.join(i + j for i, j in zip(input, ' ' * len(input))) spelt = "".join(i + j for i, j in zip(input, " " * len(input)))
return WebSocketAgentTextMessage.from_text(spelt) return WebSocketAgentTextMessage.from_text(spelt)
async def generate_response(
self, input: str, conversation_id: str
) -> AsyncGenerator[WebSocketAgentMessage, None]:
print(input, conversation_id)
if "bye" in input:
yield WebSocketAgentTextEndMessage()
else:
for word in input.split():
yield WebSocketAgentTextMessage.from_text(word)
yield WebSocketAgentTextEndMessage()
if __name__ == "__main__": if __name__ == "__main__":
agent = TestWebSocketAgent() agent = TestWebSocketAgent(generate_responses=True)
agent.run(port=3001) agent.run(port=3001)

View file

@ -134,6 +134,7 @@ class WebSocketAgentMessageType(str, Enum):
BASE = "websocket_agent_base" BASE = "websocket_agent_base"
START = "websocket_agent_start" START = "websocket_agent_start"
TEXT = "websocket_agent_text" TEXT = "websocket_agent_text"
TEXT_END = "websocket_agent_text_end"
READY = "websocket_agent_ready" READY = "websocket_agent_ready"
STOP = "websocket_agent_stop" STOP = "websocket_agent_stop"
@ -171,3 +172,9 @@ class WebSocketAgentStopMessage(
WebSocketAgentMessage, type=WebSocketAgentMessageType.STOP WebSocketAgentMessage, type=WebSocketAgentMessageType.STOP
): ):
pass pass
class WebSocketAgentTextEndMessage(
WebSocketAgentMessage, type=WebSocketAgentMessageType.TEXT_END
):
pass

View file

@ -1,24 +1,35 @@
from .base_agent import BaseAgent from .base_agent import BaseAgent
import uuid import uuid
import typing import typing
from typing import Union, Optional from typing import AsyncGenerator, Union, Optional
from fastapi import WebSocket from fastapi import WebSocket
from ..models.agent import ( from ..models.agent import (
WebSocketAgentStartMessage, WebSocketAgentStartMessage,
WebSocketAgentReadyMessage, WebSocketAgentReadyMessage,
WebSocketAgentTextMessage, WebSocketAgentTextEndMessage,
WebSocketAgentStopMessage, WebSocketAgentTextMessage,
WebSocketAgentMessage, WebSocketAgentStopMessage,
WebSocketAgentMessageType WebSocketAgentMessage,
WebSocketAgentMessageType,
) )
class WebSocketAgent(BaseAgent): class WebSocketAgent(BaseAgent):
def __init__(self, generate_responses: bool = False):
def __init__(self):
super().__init__() super().__init__()
self.generate_responses = generate_responses
self.app.websocket("/respond")(self.respond_websocket) self.app.websocket("/respond")(self.respond_websocket)
async def respond(self, human_input: str, conversation_id: Optional[str] = None) -> Union[WebSocketAgentTextMessage, WebSocketAgentStopMessage]: async def respond(
self, human_input: str, conversation_id: Optional[str] = None
) -> Union[WebSocketAgentTextMessage, WebSocketAgentStopMessage]:
raise NotImplementedError
async def generate_response(
self, human_input: str, conversation_id: Optional[str] = None
) -> AsyncGenerator[
Union[WebSocketAgentTextMessage, WebSocketAgentTextEndMessage], None
]:
raise NotImplementedError raise NotImplementedError
async def respond_websocket(self, websocket: WebSocket): async def respond_websocket(self, websocket: WebSocket):
@ -26,12 +37,20 @@ class WebSocketAgent(BaseAgent):
WebSocketAgentStartMessage.parse_obj(await websocket.receive_json()) WebSocketAgentStartMessage.parse_obj(await websocket.receive_json())
await websocket.send_text(WebSocketAgentReadyMessage().json()) await websocket.send_text(WebSocketAgentReadyMessage().json())
while True: while True:
input_message = WebSocketAgentMessage.parse_obj(await websocket.receive_json()) input_message: WebSocketAgentMessage = WebSocketAgentMessage.parse_obj(
await websocket.receive_json()
)
if input_message.type == WebSocketAgentMessageType.STOP: if input_message.type == WebSocketAgentMessageType.STOP:
break break
text_message = typing.cast(WebSocketAgentTextMessage, input_message) text_message = typing.cast(WebSocketAgentTextMessage, input_message)
print(text_message) if self.generate_responses:
output_response = await self.respond(text_message.data.text, text_message.conversation_id) async for output_response in self.generate_response(
await websocket.send_text(output_response.json()) text_message.data.text, text_message.conversation_id
):
await websocket.send_text(output_response.json())
else:
output_response = await self.respond(
text_message.data.text, text_message.conversation_id
)
await websocket.send_text(output_response.json())
await websocket.close() await websocket.close()