vocode-python/vocode/user_implemented_agent/websocket_agent.py
2023-03-14 00:32:20 -07:00

56 lines
2.1 KiB
Python

from .base_agent import BaseAgent
import uuid
import typing
from typing import AsyncGenerator, Union, Optional
from fastapi import WebSocket
from ..models.agent import (
WebSocketAgentStartMessage,
WebSocketAgentReadyMessage,
WebSocketAgentTextEndMessage,
WebSocketAgentTextMessage,
WebSocketAgentStopMessage,
WebSocketAgentMessage,
WebSocketAgentMessageType,
)
class WebSocketAgent(BaseAgent):
    """Agent that serves a client over a WebSocket route at ``/respond``.

    Session flow, as implemented by :meth:`respond_websocket`:

    1. Accept the connection and read one JSON message, validated as a
       ``WebSocketAgentStartMessage``.
    2. Send a ``WebSocketAgentReadyMessage``.
    3. Loop: read a ``WebSocketAgentMessage``. A STOP message ends the loop;
       any other message is treated as a text message and answered via
       :meth:`respond` (one reply) or :meth:`generate_response` (a stream of
       replies), depending on ``generate_responses``.
    4. Close the socket once the loop ends via STOP.

    Subclasses override :meth:`respond`, or :meth:`generate_response` when
    constructed with ``generate_responses=True``.
    """

    def __init__(self, generate_responses: bool = False):
        """Register the websocket handler.

        Args:
            generate_responses: If True, replies are streamed by iterating
                :meth:`generate_response`; otherwise a single reply per input
                is obtained from :meth:`respond`.
        """
        super().__init__()
        self.generate_responses = generate_responses
        # ``self.app`` is expected to be set up by ``BaseAgent.__init__``
        # (not visible here); mount the session handler at /respond.
        self.app.websocket("/respond")(self.respond_websocket)

    async def respond(
        self, human_input: str, conversation_id: Optional[str] = None
    ) -> Union[WebSocketAgentTextMessage, WebSocketAgentStopMessage]:
        """Return a single reply to ``human_input``.

        Used when ``generate_responses`` is False. The returned message is
        serialized with ``.json()`` and sent to the client unchanged.

        Args:
            human_input: ``data.text`` of the incoming text message.
            conversation_id: ``conversation_id`` of the incoming text message.

        Raises:
            NotImplementedError: Subclasses must override this method.
        """
        raise NotImplementedError

    def generate_response(
        self, human_input: str, conversation_id: Optional[str] = None
    ) -> AsyncGenerator[
        Union[WebSocketAgentTextMessage, WebSocketAgentTextEndMessage], None
    ]:
        """Stream replies to ``human_input``.

        Used when ``generate_responses`` is True. Every yielded message is
        serialized with ``.json()`` and sent to the client.

        Subclasses override this with an async generator function
        (``async def ...: ... yield ...``). The base stub is deliberately a
        plain ``def``: an ``async def`` body without ``yield`` would return a
        coroutine, which ``async for`` rejects with ``TypeError`` instead of
        surfacing the ``NotImplementedError`` below, and its type would not
        match async-generator overrides.

        Args:
            human_input: ``data.text`` of the incoming text message.
            conversation_id: ``conversation_id`` of the incoming text message.

        Raises:
            NotImplementedError: Subclasses must override this method.
        """
        raise NotImplementedError

    async def respond_websocket(self, websocket: WebSocket):
        """Run one client session on ``websocket`` (flow in class docstring).

        Args:
            websocket: The connection handed over by the ``/respond`` route.
        """
        await websocket.accept()
        # The start message is only validated; its contents are not used.
        WebSocketAgentStartMessage.parse_obj(await websocket.receive_json())
        await websocket.send_text(WebSocketAgentReadyMessage().json())
        while True:
            input_message: WebSocketAgentMessage = WebSocketAgentMessage.parse_obj(
                await websocket.receive_json()
            )
            if input_message.type == WebSocketAgentMessageType.STOP:
                break
            # Every non-STOP message is assumed to be a text message.
            # NOTE(review): any other message type would fail on ``.data.text``.
            text_message = typing.cast(WebSocketAgentTextMessage, input_message)
            if self.generate_responses:
                async for output_response in self.generate_response(
                    text_message.data.text, text_message.conversation_id
                ):
                    await websocket.send_text(output_response.json())
            else:
                output_response = await self.respond(
                    text_message.data.text, text_message.conversation_id
                )
                await websocket.send_text(output_response.json())
        # Only reached via the STOP branch; an exception raised inside the
        # loop propagates without this close() being called.
        await websocket.close()