conversation ID
This commit is contained in:
parent
425d2cf21c
commit
65ec1f9770
2 changed files with 8 additions and 7 deletions
|
|
@ -14,8 +14,8 @@ class TestRESTfulAgent(RESTfulAgent):
|
||||||
|
|
||||||
class TestWebSocketAgent(WebSocketAgent):
|
class TestWebSocketAgent(WebSocketAgent):
|
||||||
|
|
||||||
async def respond(self, input: str) -> WebSocketAgentMessage:
|
async def respond(self, input: str, conversation_id: str) -> WebSocketAgentMessage:
|
||||||
print(input)
|
print(input, conversation_id)
|
||||||
if "bye" in input:
|
if "bye" in input:
|
||||||
return WebSocketAgentStopMessage()
|
return WebSocketAgentStopMessage()
|
||||||
else:
|
else:
|
||||||
|
|
@ -1,8 +1,8 @@
|
||||||
from .base_agent import BaseAgent
|
from .base_agent import BaseAgent
|
||||||
from pydantic import BaseModel
|
import uuid
|
||||||
import typing
|
import typing
|
||||||
from typing import Union
|
from typing import Union, Optional
|
||||||
from fastapi import APIRouter, WebSocket
|
from fastapi import WebSocket
|
||||||
from ..models.agent import (
|
from ..models.agent import (
|
||||||
WebSocketAgentStartMessage,
|
WebSocketAgentStartMessage,
|
||||||
WebSocketAgentReadyMessage,
|
WebSocketAgentReadyMessage,
|
||||||
|
|
@ -18,11 +18,12 @@ class WebSocketAgent(BaseAgent):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.app.websocket("/respond")(self.respond_websocket)
|
self.app.websocket("/respond")(self.respond_websocket)
|
||||||
|
|
||||||
async def respond(self, human_input) -> Union[WebSocketAgentTextMessage, WebSocketAgentStopMessage]:
|
async def respond(self, human_input: str, conversation_id: Optional[str] = None) -> Union[WebSocketAgentTextMessage, WebSocketAgentStopMessage]:
|
||||||
raise NotImplementedError
|
raise NotImplementedError
|
||||||
|
|
||||||
async def respond_websocket(self, websocket: WebSocket):
|
async def respond_websocket(self, websocket: WebSocket):
|
||||||
await websocket.accept()
|
await websocket.accept()
|
||||||
|
conversation_id = str(uuid.uuid4())
|
||||||
WebSocketAgentStartMessage.parse_obj(await websocket.receive_json())
|
WebSocketAgentStartMessage.parse_obj(await websocket.receive_json())
|
||||||
await websocket.send_text(WebSocketAgentReadyMessage().json())
|
await websocket.send_text(WebSocketAgentReadyMessage().json())
|
||||||
while True:
|
while True:
|
||||||
|
|
@ -30,7 +31,7 @@ class WebSocketAgent(BaseAgent):
|
||||||
if input_message.type == WebSocketAgentMessageType.STOP:
|
if input_message.type == WebSocketAgentMessageType.STOP:
|
||||||
break
|
break
|
||||||
text_message = typing.cast(WebSocketAgentTextMessage, input_message)
|
text_message = typing.cast(WebSocketAgentTextMessage, input_message)
|
||||||
output_response = await self.respond(text_message.data.text)
|
output_response = await self.respond(text_message.data.text, conversation_id=conversation_id)
|
||||||
await websocket.send_text(output_response.json())
|
await websocket.send_text(output_response.json())
|
||||||
await websocket.close()
|
await websocket.close()
|
||||||
|
|
||||||
|
|
|
||||||
Loading…
Add table
Add a link
Reference in a new issue