vocode-python/simple_user_implemented_agent.py
2023-03-14 00:32:20 -07:00

48 lines
1.7 KiB
Python

from typing import AsyncGenerator
from vocode.user_implemented_agent.restful_agent import RESTfulAgent
from vocode.models.agent import (
RESTfulAgentOutput,
RESTfulAgentText,
RESTfulAgentEnd,
WebSocketAgentMessage,
WebSocketAgentTextEndMessage,
WebSocketAgentTextMessage,
WebSocketAgentStopMessage,
)
from vocode.user_implemented_agent.websocket_agent import WebSocketAgent
class TestRESTfulAgent(RESTfulAgent):
    """Toy RESTful agent for exercising the user-implemented agent server.

    Echoes each utterance back "spelled out" (every character followed by a
    space) and ends the conversation when the utterance mentions "bye".
    """

    async def respond(self, input: str, conversation_id: str) -> RESTfulAgentOutput:
        """Produce a single reply for one user utterance.

        Args:
            input: The user's utterance text. (The name shadows the builtin but
                is kept to match the overridden base-class signature.)
            conversation_id: Identifier of the conversation; only printed here.

        Returns:
            ``RESTfulAgentEnd()`` when ``input`` contains "bye" in any letter
            case, otherwise ``RESTfulAgentText`` whose ``response`` is
            ``input`` with a space appended after every character.
        """
        # Debug trace of each incoming utterance.
        print(input, conversation_id)
        # Case-insensitive so capitalized transcripts ("Bye.", "Goodbye") also
        # end the call, not only all-lowercase ones.
        if "bye" in input.lower():
            return RESTfulAgentEnd()
        # "abc" -> "a b c " (trailing space included).
        spelt = "".join(f"{char} " for char in input)
        return RESTfulAgentText(response=spelt)
class TestWebSocketAgent(WebSocketAgent):
    """Toy WebSocket agent with both a one-shot and a streaming reply path.

    ``respond`` echoes the utterance spelled out (every character followed by
    a space); ``generate_response`` streams the utterance back one word per
    message. Both stop the conversation when the utterance mentions "bye".
    """

    async def respond(self, input: str, conversation_id: str) -> WebSocketAgentMessage:
        """Produce a single reply message for one user utterance.

        Args:
            input: The user's utterance text. (The name shadows the builtin but
                is kept to match the overridden base-class signature.)
            conversation_id: Identifier of the conversation; only printed here.

        Returns:
            ``WebSocketAgentStopMessage()`` when ``input`` contains "bye" in
            any letter case, otherwise a ``WebSocketAgentTextMessage`` built
            from ``input`` with a space appended after every character.
        """
        # Debug trace of each incoming utterance.
        print(input, conversation_id)
        # Case-insensitive so capitalized transcripts ("Bye.", "Goodbye") also
        # end the call, not only all-lowercase ones.
        if "bye" in input.lower():
            return WebSocketAgentStopMessage()
        # "abc" -> "a b c " (trailing space included).
        spelt = "".join(f"{char} " for char in input)
        return WebSocketAgentTextMessage.from_text(spelt)

    async def generate_response(
        self, input: str, conversation_id: str
    ) -> AsyncGenerator[WebSocketAgentMessage, None]:
        """Stream a reply for one user utterance, one word per message.

        Args:
            input: The user's utterance text.
            conversation_id: Identifier of the conversation; only printed here.

        Yields:
            ``WebSocketAgentStopMessage()`` alone when ``input`` contains "bye"
            in any letter case; otherwise one ``WebSocketAgentTextMessage`` per
            whitespace-separated word of ``input``, followed by a
            ``WebSocketAgentTextEndMessage()``.
        """
        # Debug trace of each incoming utterance.
        print(input, conversation_id)
        if "bye" in input.lower():
            # Stop the conversation, matching respond(). Yielding only a
            # text-end marker here would close the turn without ending the call.
            # NOTE(review): assumes the server accepts a stop message on the
            # streaming path — confirm against WebSocketAgent.
            yield WebSocketAgentStopMessage()
            return
        for word in input.split():
            yield WebSocketAgentTextMessage.from_text(word)
        # Presumably signals the end of this streamed reply to the server.
        yield WebSocketAgentTextEndMessage()
if __name__ == "__main__":
    # Serve the WebSocket test agent in streaming mode
    # (generate_responses=True) on port 3001.
    TestWebSocketAgent(generate_responses=True).run(port=3001)