monster commit

This commit is contained in:
Ajay Raj 2023-03-03 18:24:56 -08:00
commit de6d76c955
14 changed files with 155 additions and 66 deletions

View file

@ -1,8 +1,16 @@
from .base_agent import BaseAgent
from pydantic import BaseModel
import typing
from typing import Union
from fastapi import APIRouter, WebSocket
from ..models.agent import AgentStartMessage, AgentReadyMessage, AgentTextMessage, WebSocketAgentMessage, WebSocketAgentMessageType
from ..models.agent import (
WebSocketAgentStartMessage,
WebSocketAgentReadyMessage,
WebSocketAgentTextMessage,
WebSocketAgentStopMessage,
WebSocketAgentMessage,
WebSocketAgentMessageType
)
class WebSocketAgent(BaseAgent):
@ -10,16 +18,19 @@ class WebSocketAgent(BaseAgent):
super().__init__()
self.app.websocket("/respond")(self.respond_websocket)
async def respond(self, human_input) -> Union[WebSocketAgentTextMessage, WebSocketAgentStopMessage]:
raise NotImplementedError
async def respond_websocket(self, websocket: WebSocket):
await websocket.accept()
AgentStartMessage.parse_obj(await websocket.receive_json())
await websocket.send_text(AgentReadyMessage().json())
WebSocketAgentStartMessage.parse_obj(await websocket.receive_json())
await websocket.send_text(WebSocketAgentReadyMessage().json())
while True:
message = WebSocketAgentMessage.parse_obj(await websocket.receive_json())
if message.type == WebSocketAgentMessageType.AGENT_STOP:
input_message = WebSocketAgentMessage.parse_obj(await websocket.receive_json())
if input_message.type == WebSocketAgentMessageType.STOP:
break
text_message = typing.cast(AgentTextMessage, message)
response = await self.respond(text_message.data.text)
await websocket.send_text(AgentTextMessage.from_text(response).json())
text_message = typing.cast(WebSocketAgentTextMessage, input_message)
output_response = await self.respond(text_message.data.text)
await websocket.send_text(output_response.json())
await websocket.close()