conversation ID

This commit is contained in:
Ajay Raj 2023-03-04 19:37:52 -08:00
commit 65ec1f9770
2 changed files with 8 additions and 7 deletions

View file

@ -14,8 +14,8 @@ class TestRESTfulAgent(RESTfulAgent):
class TestWebSocketAgent(WebSocketAgent):
async def respond(self, input: str) -> WebSocketAgentMessage:
print(input)
async def respond(self, input: str, conversation_id: str) -> WebSocketAgentMessage:
print(input, conversation_id)
if "bye" in input:
return WebSocketAgentStopMessage()
else:

View file

@ -1,8 +1,8 @@
from .base_agent import BaseAgent
from pydantic import BaseModel
import uuid
import typing
from typing import Union
from fastapi import APIRouter, WebSocket
from typing import Union, Optional
from fastapi import WebSocket
from ..models.agent import (
WebSocketAgentStartMessage,
WebSocketAgentReadyMessage,
@ -18,11 +18,12 @@ class WebSocketAgent(BaseAgent):
super().__init__()
self.app.websocket("/respond")(self.respond_websocket)
async def respond(self, human_input) -> Union[WebSocketAgentTextMessage, WebSocketAgentStopMessage]:
async def respond(self, human_input: str, conversation_id: Optional[str] = None) -> Union[WebSocketAgentTextMessage, WebSocketAgentStopMessage]:
raise NotImplementedError
async def respond_websocket(self, websocket: WebSocket):
await websocket.accept()
conversation_id = str(uuid.uuid4())
WebSocketAgentStartMessage.parse_obj(await websocket.receive_json())
await websocket.send_text(WebSocketAgentReadyMessage().json())
while True:
@ -30,7 +31,7 @@ class WebSocketAgent(BaseAgent):
if input_message.type == WebSocketAgentMessageType.STOP:
break
text_message = typing.cast(WebSocketAgentTextMessage, input_message)
output_response = await self.respond(text_message.data.text)
output_response = await self.respond(text_message.data.text, conversation_id=conversation_id)
await websocket.send_text(output_response.json())
await websocket.close()