sets up websocket client streaming
This commit is contained in:
parent
9cbec6f416
commit
472f553ea0
4 changed files with 76 additions and 29 deletions
|
|
@ -46,12 +46,12 @@ if __name__ == "__main__":
|
||||||
transcriber_config=DeepgramTranscriberConfig.from_input_device(
|
transcriber_config=DeepgramTranscriberConfig.from_input_device(
|
||||||
microphone_input
|
microphone_input
|
||||||
),
|
),
|
||||||
agent_config=ChatGPTAgentConfig(
|
agent_config=WebSocketUserImplementedAgentConfig(
|
||||||
initial_message=BaseMessage(text="Hello!"),
|
initial_message=BaseMessage(text="Hello!"),
|
||||||
prompt_preamble="The AI is having a pleasant conversation about life.",
|
generate_responses=True,
|
||||||
generate_responses=False,
|
respond=WebSocketUserImplementedAgentConfig.RouteConfig(
|
||||||
end_conversation_on_goodbye=True,
|
url="wss://9b1ff0eee874.ngrok.app/respond",
|
||||||
send_filler_audio=FillerAudioConfig(use_typing_noise=True),
|
),
|
||||||
),
|
),
|
||||||
synthesizer_config=AzureSynthesizerConfig.from_output_device(speaker_output),
|
synthesizer_config=AzureSynthesizerConfig.from_output_device(speaker_output),
|
||||||
)
|
)
|
||||||
|
|
|
||||||
|
|
@ -1,27 +1,48 @@
|
||||||
|
from typing import AsyncGenerator
|
||||||
from vocode.user_implemented_agent.restful_agent import RESTfulAgent
|
from vocode.user_implemented_agent.restful_agent import RESTfulAgent
|
||||||
from vocode.models.agent import RESTfulAgentOutput, RESTfulAgentText, RESTfulAgentEnd, WebSocketAgentMessage, WebSocketAgentTextMessage, WebSocketAgentStopMessage
|
from vocode.models.agent import (
|
||||||
from vocode.user_implemented_agent.websocket_agent import WebSocketAgent
|
RESTfulAgentOutput,
|
||||||
|
RESTfulAgentText,
|
||||||
|
RESTfulAgentEnd,
|
||||||
|
WebSocketAgentMessage,
|
||||||
|
WebSocketAgentTextEndMessage,
|
||||||
|
WebSocketAgentTextMessage,
|
||||||
|
WebSocketAgentStopMessage,
|
||||||
|
)
|
||||||
|
from vocode.user_implemented_agent.websocket_agent import WebSocketAgent
|
||||||
|
|
||||||
|
|
||||||
class TestRESTfulAgent(RESTfulAgent):
|
class TestRESTfulAgent(RESTfulAgent):
|
||||||
|
|
||||||
async def respond(self, input: str, conversation_id: str) -> RESTfulAgentOutput:
|
async def respond(self, input: str, conversation_id: str) -> RESTfulAgentOutput:
|
||||||
print(input, conversation_id)
|
print(input, conversation_id)
|
||||||
if "bye" in input:
|
if "bye" in input:
|
||||||
return RESTfulAgentEnd()
|
return RESTfulAgentEnd()
|
||||||
else:
|
else:
|
||||||
spelt = ''.join(i + j for i, j in zip(input, ' ' * len(input)))
|
spelt = "".join(i + j for i, j in zip(input, " " * len(input)))
|
||||||
return RESTfulAgentText(response=spelt)
|
return RESTfulAgentText(response=spelt)
|
||||||
|
|
||||||
class TestWebSocketAgent(WebSocketAgent):
|
|
||||||
|
|
||||||
|
|
||||||
|
class TestWebSocketAgent(WebSocketAgent):
|
||||||
async def respond(self, input: str, conversation_id: str) -> WebSocketAgentMessage:
|
async def respond(self, input: str, conversation_id: str) -> WebSocketAgentMessage:
|
||||||
print(input, conversation_id)
|
print(input, conversation_id)
|
||||||
if "bye" in input:
|
if "bye" in input:
|
||||||
return WebSocketAgentStopMessage()
|
return WebSocketAgentStopMessage()
|
||||||
else:
|
else:
|
||||||
spelt = ''.join(i + j for i, j in zip(input, ' ' * len(input)))
|
spelt = "".join(i + j for i, j in zip(input, " " * len(input)))
|
||||||
return WebSocketAgentTextMessage.from_text(spelt)
|
return WebSocketAgentTextMessage.from_text(spelt)
|
||||||
|
|
||||||
|
async def generate_response(
|
||||||
|
self, input: str, conversation_id: str
|
||||||
|
) -> AsyncGenerator[WebSocketAgentMessage, None]:
|
||||||
|
print(input, conversation_id)
|
||||||
|
if "bye" in input:
|
||||||
|
yield WebSocketAgentTextEndMessage()
|
||||||
|
else:
|
||||||
|
for word in input.split():
|
||||||
|
yield WebSocketAgentTextMessage.from_text(word)
|
||||||
|
yield WebSocketAgentTextEndMessage()
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
agent = TestWebSocketAgent()
|
agent = TestWebSocketAgent(generate_responses=True)
|
||||||
agent.run(port=3001)
|
agent.run(port=3001)
|
||||||
|
|
|
||||||
|
|
@ -134,6 +134,7 @@ class WebSocketAgentMessageType(str, Enum):
|
||||||
BASE = "websocket_agent_base"
|
BASE = "websocket_agent_base"
|
||||||
START = "websocket_agent_start"
|
START = "websocket_agent_start"
|
||||||
TEXT = "websocket_agent_text"
|
TEXT = "websocket_agent_text"
|
||||||
|
TEXT_END = "websocket_agent_text_end"
|
||||||
READY = "websocket_agent_ready"
|
READY = "websocket_agent_ready"
|
||||||
STOP = "websocket_agent_stop"
|
STOP = "websocket_agent_stop"
|
||||||
|
|
||||||
|
|
@ -171,3 +172,9 @@ class WebSocketAgentStopMessage(
|
||||||
WebSocketAgentMessage, type=WebSocketAgentMessageType.STOP
|
WebSocketAgentMessage, type=WebSocketAgentMessageType.STOP
|
||||||
):
|
):
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
class WebSocketAgentTextEndMessage(
|
||||||
|
WebSocketAgentMessage, type=WebSocketAgentMessageType.TEXT_END
|
||||||
|
):
|
||||||
|
pass
|
||||||
|
|
|
||||||
|
|
@ -1,24 +1,35 @@
|
||||||
from .base_agent import BaseAgent
|
from .base_agent import BaseAgent
|
||||||
import uuid
|
import uuid
|
||||||
import typing
|
import typing
|
||||||
from typing import Union, Optional
|
from typing import AsyncGenerator, Union, Optional
|
||||||
from fastapi import WebSocket
|
from fastapi import WebSocket
|
||||||
from ..models.agent import (
|
from ..models.agent import (
|
||||||
WebSocketAgentStartMessage,
|
WebSocketAgentStartMessage,
|
||||||
WebSocketAgentReadyMessage,
|
WebSocketAgentReadyMessage,
|
||||||
WebSocketAgentTextMessage,
|
WebSocketAgentTextEndMessage,
|
||||||
WebSocketAgentStopMessage,
|
WebSocketAgentTextMessage,
|
||||||
WebSocketAgentMessage,
|
WebSocketAgentStopMessage,
|
||||||
WebSocketAgentMessageType
|
WebSocketAgentMessage,
|
||||||
|
WebSocketAgentMessageType,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
class WebSocketAgent(BaseAgent):
|
class WebSocketAgent(BaseAgent):
|
||||||
|
def __init__(self, generate_responses: bool = False):
|
||||||
def __init__(self):
|
|
||||||
super().__init__()
|
super().__init__()
|
||||||
|
self.generate_responses = generate_responses
|
||||||
self.app.websocket("/respond")(self.respond_websocket)
|
self.app.websocket("/respond")(self.respond_websocket)
|
||||||
|
|
||||||
async def respond(self, human_input: str, conversation_id: Optional[str] = None) -> Union[WebSocketAgentTextMessage, WebSocketAgentStopMessage]:
|
async def respond(
|
||||||
|
self, human_input: str, conversation_id: Optional[str] = None
|
||||||
|
) -> Union[WebSocketAgentTextMessage, WebSocketAgentStopMessage]:
|
||||||
|
raise NotImplementedError
|
||||||
|
|
||||||
|
async def generate_response(
|
||||||
|
self, human_input: str, conversation_id: Optional[str] = None
|
||||||
|
) -> AsyncGenerator[
|
||||||
|
Union[WebSocketAgentTextMessage, WebSocketAgentTextEndMessage], None
|
||||||
|
]:
|
||||||
raise NotImplementedError
|
raise NotImplementedError
|
||||||
|
|
||||||
async def respond_websocket(self, websocket: WebSocket):
|
async def respond_websocket(self, websocket: WebSocket):
|
||||||
|
|
@ -26,12 +37,20 @@ class WebSocketAgent(BaseAgent):
|
||||||
WebSocketAgentStartMessage.parse_obj(await websocket.receive_json())
|
WebSocketAgentStartMessage.parse_obj(await websocket.receive_json())
|
||||||
await websocket.send_text(WebSocketAgentReadyMessage().json())
|
await websocket.send_text(WebSocketAgentReadyMessage().json())
|
||||||
while True:
|
while True:
|
||||||
input_message = WebSocketAgentMessage.parse_obj(await websocket.receive_json())
|
input_message: WebSocketAgentMessage = WebSocketAgentMessage.parse_obj(
|
||||||
|
await websocket.receive_json()
|
||||||
|
)
|
||||||
if input_message.type == WebSocketAgentMessageType.STOP:
|
if input_message.type == WebSocketAgentMessageType.STOP:
|
||||||
break
|
break
|
||||||
text_message = typing.cast(WebSocketAgentTextMessage, input_message)
|
text_message = typing.cast(WebSocketAgentTextMessage, input_message)
|
||||||
print(text_message)
|
if self.generate_responses:
|
||||||
output_response = await self.respond(text_message.data.text, text_message.conversation_id)
|
async for output_response in self.generate_response(
|
||||||
await websocket.send_text(output_response.json())
|
text_message.data.text, text_message.conversation_id
|
||||||
|
):
|
||||||
|
await websocket.send_text(output_response.json())
|
||||||
|
else:
|
||||||
|
output_response = await self.respond(
|
||||||
|
text_message.data.text, text_message.conversation_id
|
||||||
|
)
|
||||||
|
await websocket.send_text(output_response.json())
|
||||||
await websocket.close()
|
await websocket.close()
|
||||||
|
|
||||||
|
|
|
||||||
Loading…
Add table
Add a link
Reference in a new issue