diff --git a/user_implemented_agent.py b/simple_user_implemented_agent.py similarity index 88% rename from user_implemented_agent.py rename to simple_user_implemented_agent.py index 71d2429..d035332 100644 --- a/user_implemented_agent.py +++ b/simple_user_implemented_agent.py @@ -14,8 +14,8 @@ class TestRESTfulAgent(RESTfulAgent): class TestWebSocketAgent(WebSocketAgent): - async def respond(self, input: str) -> WebSocketAgentMessage: - print(input) + async def respond(self, input: str, conversation_id: str) -> WebSocketAgentMessage: + print(input, conversation_id) if "bye" in input: return WebSocketAgentStopMessage() else: diff --git a/vocode/user_implemented_agent/websocket_agent.py b/vocode/user_implemented_agent/websocket_agent.py index d191d86..6468e08 100644 --- a/vocode/user_implemented_agent/websocket_agent.py +++ b/vocode/user_implemented_agent/websocket_agent.py @@ -1,8 +1,8 @@ from .base_agent import BaseAgent -from pydantic import BaseModel +import uuid import typing -from typing import Union -from fastapi import APIRouter, WebSocket +from typing import Union, Optional +from fastapi import WebSocket from ..models.agent import ( WebSocketAgentStartMessage, WebSocketAgentReadyMessage, @@ -18,11 +18,12 @@ class WebSocketAgent(BaseAgent): super().__init__() self.app.websocket("/respond")(self.respond_websocket) - async def respond(self, human_input) -> Union[WebSocketAgentTextMessage, WebSocketAgentStopMessage]: + async def respond(self, human_input: str, conversation_id: Optional[str] = None) -> Union[WebSocketAgentTextMessage, WebSocketAgentStopMessage]: raise NotImplementedError async def respond_websocket(self, websocket: WebSocket): await websocket.accept() + conversation_id = str(uuid.uuid4()) WebSocketAgentStartMessage.parse_obj(await websocket.receive_json()) await websocket.send_text(WebSocketAgentReadyMessage().json()) while True: @@ -30,7 +31,7 @@ class WebSocketAgent(BaseAgent): if input_message.type == WebSocketAgentMessageType.STOP: break text_message = typing.cast(WebSocketAgentTextMessage, input_message) - output_response = await self.respond(text_message.data.text) + output_response = await self.respond(text_message.data.text, conversation_id=conversation_id) await websocket.send_text(output_response.json()) await websocket.close()