From 37b5ffd6748bc5574aaf4a13e57fb088ebb72793 Mon Sep 17 00:00:00 2001 From: Ajay Raj Date: Thu, 2 Mar 2023 13:15:03 -0800 Subject: [PATCH] v0 of user rolled AI --- simple_conversation.py | 13 +----- user_implemented_agent.py | 11 +++++ vocode/conversation.py | 2 +- vocode/models/agent.py | 46 ++++++++++++++++++- vocode/models/model.py | 3 +- vocode/user_implemented_agent/base_agent.py | 2 +- .../user_implemented_agent/restful_agent.py | 8 ++-- .../user_implemented_agent/websocket_agent.py | 26 +++++++++++ 8 files changed, 90 insertions(+), 21 deletions(-) create mode 100644 user_implemented_agent.py create mode 100644 vocode/user_implemented_agent/websocket_agent.py diff --git a/simple_conversation.py b/simple_conversation.py index fff4bfd..bb9611d 100644 --- a/simple_conversation.py +++ b/simple_conversation.py @@ -5,24 +5,16 @@ import signal from vocode.conversation import Conversation from vocode.helpers import create_microphone_input_and_speaker_output from vocode.models.transcriber import DeepgramTranscriberConfig -from vocode.models.agent import ChatGPTAgentConfig, RESTfulUserImplementedAgentConfig +from vocode.models.agent import ChatGPTAgentConfig, RESTfulUserImplementedAgentConfig, WebSocketUserImplementedAgentConfig from vocode.models.synthesizer import AzureSynthesizerConfig from vocode.user_implemented_agent.restful_agent import RESTfulAgent logging.basicConfig() logging.root.setLevel(logging.INFO) -class EchoAgent(RESTfulAgent): - - async def respond(self, input: str) -> str: - return input if __name__ == "__main__": - import threading - microphone_input, speaker_output = create_microphone_input_and_speaker_output(use_first_available_device=False) - user_agent_thread = threading.Thread(target=EchoAgent().run) - user_agent_thread.start() conversation = Conversation( input_device=microphone_input, @@ -32,7 +24,7 @@ if __name__ == "__main__": initial_message="Hello!", generate_responses=False, respond=RESTfulUserImplementedAgentConfig.EndpointConfig( - url="http://localhost:3001/respond", + url="http://a6eb64f4a9b7.ngrok.io/respond", method="POST" ) ), @@ -40,5 +32,4 @@ if __name__ == "__main__": ) signal.signal(signal.SIGINT, lambda _0, _1: conversation.deactivate()) asyncio.run(conversation.start()) - user_agent_thread.join() diff --git a/user_implemented_agent.py b/user_implemented_agent.py new file mode 100644 index 0000000..65a35c6 --- /dev/null +++ b/user_implemented_agent.py @@ -0,0 +1,11 @@ +from vocode.user_implemented_agent.restful_agent import RESTfulAgent +from vocode.user_implemented_agent.websocket_agent import WebSocketAgent + +class EchoAgent(RESTfulAgent): + + async def respond(self, input: str) -> str: + return input + +if __name__ == "__main__": + agent = EchoAgent() + agent.run() \ No newline at end of file diff --git a/vocode/conversation.py b/vocode/conversation.py index 0073fbb..fa22d77 100644 --- a/vocode/conversation.py +++ b/vocode/conversation.py @@ -16,7 +16,7 @@ from .models.synthesizer import SynthesizerConfig from .models.websocket import ReadyMessage, AudioMessage, StartMessage, StopMessage from . import api_key -VOCODE_WEBSOCKET_URL = f'wss://a6eb64f4a9b7.ngrok.io/conversation' +VOCODE_WEBSOCKET_URL = f'wss://api.vocode.dev/conversation' class Conversation: diff --git a/vocode/models/agent.py b/vocode/models/agent.py index 9ac8923..0b9267a 100644 --- a/vocode/models/agent.py +++ b/vocode/models/agent.py @@ -10,6 +10,7 @@ class AgentType(str, Enum): ECHO = "echo" INFORMATION_RETRIEVAL = "information_retrieval" RESTFUL_USER_IMPLEMENTED = "restful_user_implemented" + WEBSOCKET_USER_IMPLEMENTED = "websocket_user_implemented" class AgentConfig(TypedModel, type=AgentType.BASE): @@ -41,9 +42,50 @@ class RESTfulUserImplementedAgentConfig(AgentConfig, type=AgentType.RESTFUL_USER class EndpointConfig(BaseModel): url: str method: str = "POST" - input_param_name: str = "human_input" - output_jsonpath: str = "response" respond: EndpointConfig generate_response: Optional[EndpointConfig] update_last_bot_message_on_cut_off: Optional[EndpointConfig] + +class RESTfulAgentInput(BaseModel): + human_input: str + +class RESTfulAgentOutput(BaseModel): + response: str + +class WebSocketUserImplementedAgentConfig(AgentConfig, type=AgentType.WEBSOCKET_USER_IMPLEMENTED): + class RouteConfig(BaseModel): + url: str + + respond: RouteConfig + generate_response: Optional[RouteConfig] + send_message_on_cut_off: bool = False + +class WebSocketAgentMessageType(str, Enum): + AGENT_BASE = 'agent_base' + AGENT_START = 'agent_start' + AGENT_TEXT = 'agent_text' + AGENT_READY = 'agent_ready' + AGENT_STOP = 'agent_stop' + +class WebSocketAgentMessage(TypedModel, type=WebSocketAgentMessageType.AGENT_BASE): pass + +class AgentTextMessage(WebSocketAgentMessage, type=WebSocketAgentMessageType.AGENT_TEXT): + class Payload(BaseModel): + text: str + + data: Payload + + @classmethod + def from_text(cls, text: str): + return cls(data=cls.Payload(text=text)) + + +class AgentStartMessage(WebSocketAgentMessage, type=WebSocketAgentMessageType.AGENT_START): + pass + +class AgentReadyMessage(WebSocketAgentMessage, type=WebSocketAgentMessageType.AGENT_READY): + pass + +class AgentStopMessage(WebSocketAgentMessage, type=WebSocketAgentMessageType.AGENT_STOP): + pass \ No newline at end of file diff --git a/vocode/models/model.py b/vocode/models/model.py index 5b7fc92..01777d5 100644 --- a/vocode/models/model.py +++ b/vocode/models/model.py @@ -5,7 +5,8 @@ class BaseModel(pydantic.BaseModel): def __init__(self, **data): for key, value in data.items(): if isinstance(value, dict): - data[key] = self.parse_obj(value) + if 'type' in value: + data[key] = TypedModel.parse_obj(value) super().__init__(**data) # Adapted from https://github.com/pydantic/pydantic/discussions/3091 diff --git a/vocode/user_implemented_agent/base_agent.py b/vocode/user_implemented_agent/base_agent.py index 01f09b5..904687e 100644 --- a/vocode/user_implemented_agent/base_agent.py +++ b/vocode/user_implemented_agent/base_agent.py @@ -9,5 +9,5 @@ class BaseAgent(): async def respond(self, human_input) -> str: raise NotImplementedError - def run(self, host="localhost", port=3001): + def run(self, host="localhost", port=3000): uvicorn.run(self.app, host=host, port=port) \ No newline at end of file diff --git a/vocode/user_implemented_agent/restful_agent.py b/vocode/user_implemented_agent/restful_agent.py index d06044a..385f28d 100644 --- a/vocode/user_implemented_agent/restful_agent.py +++ b/vocode/user_implemented_agent/restful_agent.py @@ -1,17 +1,15 @@ from .base_agent import BaseAgent +from ..models.agent import RESTfulAgentInput, RESTfulAgentOutput from pydantic import BaseModel from fastapi import APIRouter class RESTfulAgent(BaseAgent): - - class HumanInput(BaseModel): - human_input: str def __init__(self): super().__init__() self.app.post("/respond")(self.respond_rest) - async def respond_rest(self, request: HumanInput): + async def respond_rest(self, request: RESTfulAgentInput) -> RESTfulAgentOutput: response = await self.respond(request.human_input) - return {"response": response} + return RESTfulAgentOutput(response=response) diff --git a/vocode/user_implemented_agent/websocket_agent.py b/vocode/user_implemented_agent/websocket_agent.py new file mode 100644 index 0000000..9b4e1e1 --- /dev/null +++ b/vocode/user_implemented_agent/websocket_agent.py @@ -0,0 +1,26 @@ +from .base_agent import BaseAgent +from pydantic import BaseModel +import typing +from fastapi import APIRouter, WebSocket +from ..models.agent import AgentStartMessage, AgentReadyMessage, AgentTextMessage, WebSocketAgentMessage, WebSocketAgentMessageType +from jsonpath_ng import parse + +class WebSocketAgent(BaseAgent): + + def __init__(self): + super().__init__() + self.app.websocket("/respond")(self.respond_websocket) + + async def respond_websocket(self, websocket: WebSocket): + await websocket.accept() + AgentStartMessage.parse_obj(await websocket.receive_json()) + await websocket.send_text(AgentReadyMessage().json()) + while True: + message = WebSocketAgentMessage.parse_obj(await websocket.receive_json()) + if message.type == WebSocketAgentMessageType.AGENT_STOP: + break + text_message = typing.cast(AgentTextMessage, message) + response = await self.respond(text_message.data.text) + await websocket.send_text(AgentTextMessage.from_text(response).json()) + await websocket.close() +