sets up websocket client streaming

This commit is contained in:
Ajay Raj 2023-03-14 00:32:20 -07:00
commit 472f553ea0
4 changed files with 76 additions and 29 deletions

View file

@ -1,27 +1,48 @@
from typing import AsyncGenerator
from vocode.user_implemented_agent.restful_agent import RESTfulAgent
from vocode.models.agent import RESTfulAgentOutput, RESTfulAgentText, RESTfulAgentEnd, WebSocketAgentMessage, WebSocketAgentTextMessage, WebSocketAgentStopMessage
from vocode.user_implemented_agent.websocket_agent import WebSocketAgent
from vocode.models.agent import (
RESTfulAgentOutput,
RESTfulAgentText,
RESTfulAgentEnd,
WebSocketAgentMessage,
WebSocketAgentTextEndMessage,
WebSocketAgentTextMessage,
WebSocketAgentStopMessage,
)
from vocode.user_implemented_agent.websocket_agent import WebSocketAgent
class TestRESTfulAgent(RESTfulAgent):
async def respond(self, input: str, conversation_id: str) -> RESTfulAgentOutput:
print(input, conversation_id)
if "bye" in input:
return RESTfulAgentEnd()
else:
spelt = ''.join(i + j for i, j in zip(input, ' ' * len(input)))
spelt = "".join(i + j for i, j in zip(input, " " * len(input)))
return RESTfulAgentText(response=spelt)
class TestWebSocketAgent(WebSocketAgent):
class TestWebSocketAgent(WebSocketAgent):
async def respond(self, input: str, conversation_id: str) -> WebSocketAgentMessage:
print(input, conversation_id)
if "bye" in input:
return WebSocketAgentStopMessage()
else:
spelt = ''.join(i + j for i, j in zip(input, ' ' * len(input)))
spelt = "".join(i + j for i, j in zip(input, " " * len(input)))
return WebSocketAgentTextMessage.from_text(spelt)
async def generate_response(
self, input: str, conversation_id: str
) -> AsyncGenerator[WebSocketAgentMessage, None]:
print(input, conversation_id)
if "bye" in input:
yield WebSocketAgentTextEndMessage()
else:
for word in input.split():
yield WebSocketAgentTextMessage.from_text(word)
yield WebSocketAgentTextEndMessage()
if __name__ == "__main__":
agent = TestWebSocketAgent()
agent = TestWebSocketAgent(generate_responses=True)
agent.run(port=3001)