sets up websocket client streaming
This commit is contained in:
parent
9cbec6f416
commit
472f553ea0
4 changed files with 76 additions and 29 deletions
|
|
@ -1,24 +1,35 @@
|
|||
from .base_agent import BaseAgent
|
||||
import uuid
|
||||
import typing
|
||||
from typing import Union, Optional
|
||||
from typing import AsyncGenerator, Union, Optional
|
||||
from fastapi import WebSocket
|
||||
from ..models.agent import (
|
||||
WebSocketAgentStartMessage,
|
||||
WebSocketAgentReadyMessage,
|
||||
WebSocketAgentTextMessage,
|
||||
WebSocketAgentStopMessage,
|
||||
WebSocketAgentMessage,
|
||||
WebSocketAgentMessageType
|
||||
WebSocketAgentStartMessage,
|
||||
WebSocketAgentReadyMessage,
|
||||
WebSocketAgentTextEndMessage,
|
||||
WebSocketAgentTextMessage,
|
||||
WebSocketAgentStopMessage,
|
||||
WebSocketAgentMessage,
|
||||
WebSocketAgentMessageType,
|
||||
)
|
||||
|
||||
|
||||
class WebSocketAgent(BaseAgent):
|
||||
|
||||
def __init__(self):
|
||||
def __init__(self, generate_responses: bool = False):
|
||||
super().__init__()
|
||||
self.generate_responses = generate_responses
|
||||
self.app.websocket("/respond")(self.respond_websocket)
|
||||
|
||||
async def respond(self, human_input: str, conversation_id: Optional[str] = None) -> Union[WebSocketAgentTextMessage, WebSocketAgentStopMessage]:
|
||||
async def respond(
|
||||
self, human_input: str, conversation_id: Optional[str] = None
|
||||
) -> Union[WebSocketAgentTextMessage, WebSocketAgentStopMessage]:
|
||||
raise NotImplementedError
|
||||
|
||||
async def generate_response(
|
||||
self, human_input: str, conversation_id: Optional[str] = None
|
||||
) -> AsyncGenerator[
|
||||
Union[WebSocketAgentTextMessage, WebSocketAgentTextEndMessage], None
|
||||
]:
|
||||
raise NotImplementedError
|
||||
|
||||
async def respond_websocket(self, websocket: WebSocket):
|
||||
|
|
@ -26,12 +37,20 @@ class WebSocketAgent(BaseAgent):
|
|||
WebSocketAgentStartMessage.parse_obj(await websocket.receive_json())
|
||||
await websocket.send_text(WebSocketAgentReadyMessage().json())
|
||||
while True:
|
||||
input_message = WebSocketAgentMessage.parse_obj(await websocket.receive_json())
|
||||
input_message: WebSocketAgentMessage = WebSocketAgentMessage.parse_obj(
|
||||
await websocket.receive_json()
|
||||
)
|
||||
if input_message.type == WebSocketAgentMessageType.STOP:
|
||||
break
|
||||
text_message = typing.cast(WebSocketAgentTextMessage, input_message)
|
||||
print(text_message)
|
||||
output_response = await self.respond(text_message.data.text, text_message.conversation_id)
|
||||
await websocket.send_text(output_response.json())
|
||||
if self.generate_responses:
|
||||
async for output_response in self.generate_response(
|
||||
text_message.data.text, text_message.conversation_id
|
||||
):
|
||||
await websocket.send_text(output_response.json())
|
||||
else:
|
||||
output_response = await self.respond(
|
||||
text_message.data.text, text_message.conversation_id
|
||||
)
|
||||
await websocket.send_text(output_response.json())
|
||||
await websocket.close()
|
||||
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue