conversation ID

This commit is contained in:
Ajay Raj 2023-03-04 19:37:52 -08:00
commit 65ec1f9770
2 changed files with 8 additions and 7 deletions

View file

@ -14,8 +14,8 @@ class TestRESTfulAgent(RESTfulAgent):
class TestWebSocketAgent(WebSocketAgent): class TestWebSocketAgent(WebSocketAgent):
async def respond(self, input: str) -> WebSocketAgentMessage: async def respond(self, input: str, conversation_id: str) -> WebSocketAgentMessage:
print(input) print(input, conversation_id)
if "bye" in input: if "bye" in input:
return WebSocketAgentStopMessage() return WebSocketAgentStopMessage()
else: else:

View file

@ -1,8 +1,8 @@
from .base_agent import BaseAgent from .base_agent import BaseAgent
from pydantic import BaseModel import uuid
import typing import typing
from typing import Union from typing import Union, Optional
from fastapi import APIRouter, WebSocket from fastapi import WebSocket
from ..models.agent import ( from ..models.agent import (
WebSocketAgentStartMessage, WebSocketAgentStartMessage,
WebSocketAgentReadyMessage, WebSocketAgentReadyMessage,
@ -18,11 +18,12 @@ class WebSocketAgent(BaseAgent):
super().__init__() super().__init__()
self.app.websocket("/respond")(self.respond_websocket) self.app.websocket("/respond")(self.respond_websocket)
async def respond(self, human_input) -> Union[WebSocketAgentTextMessage, WebSocketAgentStopMessage]: async def respond(self, human_input: str, conversation_id: Optional[str] = None) -> Union[WebSocketAgentTextMessage, WebSocketAgentStopMessage]:
raise NotImplementedError raise NotImplementedError
async def respond_websocket(self, websocket: WebSocket): async def respond_websocket(self, websocket: WebSocket):
await websocket.accept() await websocket.accept()
conversation_id = str(uuid.uuid4())
WebSocketAgentStartMessage.parse_obj(await websocket.receive_json()) WebSocketAgentStartMessage.parse_obj(await websocket.receive_json())
await websocket.send_text(WebSocketAgentReadyMessage().json()) await websocket.send_text(WebSocketAgentReadyMessage().json())
while True: while True:
@ -30,7 +31,7 @@ class WebSocketAgent(BaseAgent):
if input_message.type == WebSocketAgentMessageType.STOP: if input_message.type == WebSocketAgentMessageType.STOP:
break break
text_message = typing.cast(WebSocketAgentTextMessage, input_message) text_message = typing.cast(WebSocketAgentTextMessage, input_message)
output_response = await self.respond(text_message.data.text) output_response = await self.respond(text_message.data.text, conversation_id=conversation_id)
await websocket.send_text(output_response.json()) await websocket.send_text(output_response.json())
await websocket.close() await websocket.close()