sets up websocket client streaming

This commit is contained in:
Ajay Raj 2023-03-14 00:32:20 -07:00
commit 472f553ea0
4 changed files with 76 additions and 29 deletions

View file

@ -46,12 +46,12 @@ if __name__ == "__main__":
transcriber_config=DeepgramTranscriberConfig.from_input_device(
microphone_input
),
agent_config=ChatGPTAgentConfig(
agent_config=WebSocketUserImplementedAgentConfig(
initial_message=BaseMessage(text="Hello!"),
prompt_preamble="The AI is having a pleasant conversation about life.",
generate_responses=False,
end_conversation_on_goodbye=True,
send_filler_audio=FillerAudioConfig(use_typing_noise=True),
generate_responses=True,
respond=WebSocketUserImplementedAgentConfig.RouteConfig(
url="wss://9b1ff0eee874.ngrok.app/respond",
),
),
synthesizer_config=AzureSynthesizerConfig.from_output_device(speaker_output),
)

View file

@ -1,27 +1,48 @@
from typing import AsyncGenerator
from vocode.user_implemented_agent.restful_agent import RESTfulAgent
from vocode.models.agent import RESTfulAgentOutput, RESTfulAgentText, RESTfulAgentEnd, WebSocketAgentMessage, WebSocketAgentTextMessage, WebSocketAgentStopMessage
from vocode.user_implemented_agent.websocket_agent import WebSocketAgent
from vocode.models.agent import (
RESTfulAgentOutput,
RESTfulAgentText,
RESTfulAgentEnd,
WebSocketAgentMessage,
WebSocketAgentTextEndMessage,
WebSocketAgentTextMessage,
WebSocketAgentStopMessage,
)
from vocode.user_implemented_agent.websocket_agent import WebSocketAgent
class TestRESTfulAgent(RESTfulAgent):
    """Sample RESTful agent that spells the caller's input back out."""

    async def respond(self, input: str, conversation_id: str) -> RESTfulAgentOutput:
        """Answer one piece of input.

        Prints the input and conversation id, then returns ``RESTfulAgentEnd``
        when the input contains "bye"; otherwise returns a ``RESTfulAgentText``
        holding the input with a space after every character.
        """
        print(input, conversation_id)
        if "bye" in input:
            return RESTfulAgentEnd()
        # Same result as zipping the input with an equally long run of spaces.
        spelt = "".join(char + " " for char in input)
        return RESTfulAgentText(response=spelt)
class TestWebSocketAgent(WebSocketAgent):
    """Sample WebSocket agent with a one-shot and a streaming reply path."""

    async def respond(self, input: str, conversation_id: str) -> WebSocketAgentMessage:
        """Answer one piece of input with a single message.

        Prints the input and conversation id, then returns a
        ``WebSocketAgentStopMessage`` when the input contains "bye"; otherwise
        returns a text message (built with ``from_text``) holding the input
        with a space after every character.
        """
        print(input, conversation_id)
        if "bye" in input:
            return WebSocketAgentStopMessage()
        # Same result as zipping the input with an equally long run of spaces.
        spelt = "".join(char + " " for char in input)
        return WebSocketAgentTextMessage.from_text(spelt)

    async def generate_response(
        self, input: str, conversation_id: str
    ) -> AsyncGenerator[WebSocketAgentMessage, None]:
        """Answer one piece of input as a stream of messages.

        Prints the input and conversation id. When the input contains "bye",
        only a ``WebSocketAgentTextEndMessage`` is yielded; otherwise one text
        message is yielded per whitespace-separated word, followed by a
        ``WebSocketAgentTextEndMessage``.
        """
        print(input, conversation_id)
        if "bye" in input:
            yield WebSocketAgentTextEndMessage()
        else:
            for word in input.split():
                yield WebSocketAgentTextMessage.from_text(word)
            # NOTE(review): nesting reconstructed -- assumed to run once, after
            # the last word, rather than once per word. Confirm against the
            # original file.
            yield WebSocketAgentTextEndMessage()
if __name__ == "__main__":
    # Serve the streaming variant of the sample agent on port 3001.
    agent = TestWebSocketAgent(generate_responses=True)
    agent.run(port=3001)

View file

@ -134,6 +134,7 @@ class WebSocketAgentMessageType(str, Enum):
BASE = "websocket_agent_base"
START = "websocket_agent_start"
TEXT = "websocket_agent_text"
TEXT_END = "websocket_agent_text_end"
READY = "websocket_agent_ready"
STOP = "websocket_agent_stop"
@ -171,3 +172,9 @@ class WebSocketAgentStopMessage(
WebSocketAgentMessage, type=WebSocketAgentMessageType.STOP
):
pass
class WebSocketAgentTextEndMessage(
    WebSocketAgentMessage, type=WebSocketAgentMessageType.TEXT_END
):
    """WebSocket agent message tagged with the ``TEXT_END`` message type."""

    pass

View file

@ -1,24 +1,35 @@
from .base_agent import BaseAgent
import uuid
import typing
from typing import Union, Optional
from typing import AsyncGenerator, Union, Optional
from fastapi import WebSocket
from ..models.agent import (
WebSocketAgentStartMessage,
WebSocketAgentReadyMessage,
WebSocketAgentTextMessage,
WebSocketAgentStopMessage,
WebSocketAgentMessage,
WebSocketAgentMessageType
WebSocketAgentStartMessage,
WebSocketAgentReadyMessage,
WebSocketAgentTextEndMessage,
WebSocketAgentTextMessage,
WebSocketAgentStopMessage,
WebSocketAgentMessage,
WebSocketAgentMessageType,
)
class WebSocketAgent(BaseAgent):
def __init__(self, generate_responses: bool = False):
    """Set up the agent and register its WebSocket route.

    Args:
        generate_responses: Stored on the instance as
            ``self.generate_responses``. The WebSocket handler reads it to
            choose between streaming via ``generate_response`` (True) and a
            single reply via ``respond`` (False, the default).
    """
    super().__init__()
    self.generate_responses = generate_responses
    # NOTE(review): self.app is presumably created by BaseAgent.__init__ --
    # confirm; it must exist before the route is registered here.
    self.app.websocket("/respond")(self.respond_websocket)
async def respond(
    self, human_input: str, conversation_id: Optional[str] = None
) -> Union[WebSocketAgentTextMessage, WebSocketAgentStopMessage]:
    """Produce a single reply message for ``human_input``.

    Subclasses are expected to override this; the base implementation
    always raises.

    Raises:
        NotImplementedError: Always, in this base implementation.
    """
    raise NotImplementedError
async def generate_response(
    self, human_input: str, conversation_id: Optional[str] = None
) -> AsyncGenerator[
    Union[WebSocketAgentTextMessage, WebSocketAgentTextEndMessage], None
]:
    """Produce a stream of reply messages for ``human_input``.

    Subclasses are expected to override this as an async generator; the
    base implementation raises as soon as iteration starts.

    Raises:
        NotImplementedError: On the first iteration step, in this base
            implementation.
    """
    raise NotImplementedError
    # Unreachable, but its presence makes this function an async generator.
    # Without it the call returns a plain coroutine and `async for` fails
    # with a TypeError instead of surfacing NotImplementedError.
    yield
async def respond_websocket(self, websocket: WebSocket):
@ -26,12 +37,20 @@ class WebSocketAgent(BaseAgent):
WebSocketAgentStartMessage.parse_obj(await websocket.receive_json())
await websocket.send_text(WebSocketAgentReadyMessage().json())
while True:
input_message = WebSocketAgentMessage.parse_obj(await websocket.receive_json())
input_message: WebSocketAgentMessage = WebSocketAgentMessage.parse_obj(
await websocket.receive_json()
)
if input_message.type == WebSocketAgentMessageType.STOP:
break
text_message = typing.cast(WebSocketAgentTextMessage, input_message)
print(text_message)
output_response = await self.respond(text_message.data.text, text_message.conversation_id)
await websocket.send_text(output_response.json())
if self.generate_responses:
async for output_response in self.generate_response(
text_message.data.text, text_message.conversation_id
):
await websocket.send_text(output_response.json())
else:
output_response = await self.respond(
text_message.data.text, text_message.conversation_id
)
await websocket.send_text(output_response.json())
await websocket.close()