v0 of user rolled AI

This commit is contained in:
Ajay Raj 2023-03-02 13:15:03 -08:00
commit 37b5ffd674
8 changed files with 90 additions and 21 deletions

View file

@ -5,24 +5,16 @@ import signal
from vocode.conversation import Conversation
from vocode.helpers import create_microphone_input_and_speaker_output
from vocode.models.transcriber import DeepgramTranscriberConfig
from vocode.models.agent import ChatGPTAgentConfig, RESTfulUserImplementedAgentConfig
from vocode.models.agent import ChatGPTAgentConfig, RESTfulUserImplementedAgentConfig, WebSocketUserImplementedAgentConfig
from vocode.models.synthesizer import AzureSynthesizerConfig
from vocode.user_implemented_agent.restful_agent import RESTfulAgent
logging.basicConfig()
logging.root.setLevel(logging.INFO)
class EchoAgent(RESTfulAgent):
async def respond(self, input: str) -> str:
return input
if __name__ == "__main__":
import threading
microphone_input, speaker_output = create_microphone_input_and_speaker_output(use_first_available_device=False)
user_agent_thread = threading.Thread(target=EchoAgent().run)
user_agent_thread.start()
conversation = Conversation(
input_device=microphone_input,
@ -32,7 +24,7 @@ if __name__ == "__main__":
initial_message="Hello!",
generate_responses=False,
respond=RESTfulUserImplementedAgentConfig.EndpointConfig(
url="http://localhost:3001/respond",
url="http://a6eb64f4a9b7.ngrok.io/respond",
method="POST"
)
),
@ -40,5 +32,4 @@ if __name__ == "__main__":
)
signal.signal(signal.SIGINT, lambda _0, _1: conversation.deactivate())
asyncio.run(conversation.start())
user_agent_thread.join()

11
user_implemented_agent.py Normal file
View file

@ -0,0 +1,11 @@
from vocode.user_implemented_agent.restful_agent import RESTfulAgent
from vocode.user_implemented_agent.websocket_agent import WebSocketAgent
class EchoAgent(RESTfulAgent):
async def respond(self, input: str) -> str:
return input
if __name__ == "__main__":
agent = EchoAgent()
agent.run()

View file

@ -16,7 +16,7 @@ from .models.synthesizer import SynthesizerConfig
from .models.websocket import ReadyMessage, AudioMessage, StartMessage, StopMessage
from . import api_key
VOCODE_WEBSOCKET_URL = f'wss://a6eb64f4a9b7.ngrok.io/conversation'
VOCODE_WEBSOCKET_URL = f'wss://api.vocode.dev/conversation'
class Conversation:

View file

@ -10,6 +10,7 @@ class AgentType(str, Enum):
ECHO = "echo"
INFORMATION_RETRIEVAL = "information_retrieval"
RESTFUL_USER_IMPLEMENTED = "restful_user_implemented"
WEBSOCKET_USER_IMPLEMENTED = "websocket_user_implemented"
class AgentConfig(TypedModel, type=AgentType.BASE):
@ -41,9 +42,50 @@ class RESTfulUserImplementedAgentConfig(AgentConfig, type=AgentType.RESTFUL_USER
class EndpointConfig(BaseModel):
url: str
method: str = "POST"
input_param_name: str = "human_input"
output_jsonpath: str = "response"
respond: EndpointConfig
generate_response: Optional[EndpointConfig]
update_last_bot_message_on_cut_off: Optional[EndpointConfig]
class RESTfulAgentInput(BaseModel):
human_input: str
class RESTfulAgentOutput(BaseModel):
response: str
class WebSocketUserImplementedAgentConfig(AgentConfig, type=AgentType.WEBSOCKET_USER_IMPLEMENTED):
class RouteConfig(BaseModel):
url: str
respond: RouteConfig
generate_response: Optional[RouteConfig]
send_message_on_cut_off: bool = False
class WebSocketAgentMessageType(str, Enum):
AGENT_BASE = 'agent_base'
AGENT_START = 'agent_start'
AGENT_TEXT = 'agent_text'
AGENT_READY = 'agent_ready'
AGENT_STOP = 'agent_stop'
class WebSocketAgentMessage(TypedModel, type=WebSocketAgentMessageType.AGENT_BASE): pass
class AgentTextMessage(WebSocketAgentMessage, type=WebSocketAgentMessageType.AGENT_TEXT):
class Payload(BaseModel):
text: str
data: Payload
@classmethod
def from_text(cls, text: str):
return cls(data=cls.Payload(text=text))
class AgentStartMessage(WebSocketAgentMessage, type=WebSocketAgentMessageType.AGENT_START):
pass
class AgentReadyMessage(WebSocketAgentMessage, type=WebSocketAgentMessageType.AGENT_READY):
pass
class AgentStopMessage(WebSocketAgentMessage, type=WebSocketAgentMessageType.AGENT_STOP):
pass

View file

@ -5,7 +5,8 @@ class BaseModel(pydantic.BaseModel):
def __init__(self, **data):
for key, value in data.items():
if isinstance(value, dict):
data[key] = self.parse_obj(value)
if 'type' in value:
data[key] = TypedModel.parse_obj(value)
super().__init__(**data)
# Adapted from https://github.com/pydantic/pydantic/discussions/3091

View file

@ -9,5 +9,5 @@ class BaseAgent():
async def respond(self, human_input) -> str:
raise NotImplementedError
def run(self, host="localhost", port=3001):
def run(self, host="localhost", port=3000):
uvicorn.run(self.app, host=host, port=port)

View file

@ -1,17 +1,15 @@
from .base_agent import BaseAgent
from ..models.agent import RESTfulAgentInput, RESTfulAgentOutput
from pydantic import BaseModel
from fastapi import APIRouter
class RESTfulAgent(BaseAgent):
class HumanInput(BaseModel):
human_input: str
def __init__(self):
super().__init__()
self.app.post("/respond")(self.respond_rest)
async def respond_rest(self, request: HumanInput):
async def respond_rest(self, request: RESTfulAgentInput) -> RESTfulAgentOutput:
response = await self.respond(request.human_input)
return {"response": response}
return RESTfulAgentOutput(response=response)

View file

@ -0,0 +1,26 @@
from .base_agent import BaseAgent
from pydantic import BaseModel
import typing
from fastapi import APIRouter, WebSocket
from ..models.agent import AgentStartMessage, AgentReadyMessage, AgentTextMessage, WebSocketAgentMessage, WebSocketAgentMessageType
from jsonpath_ng import parse
class WebSocketAgent(BaseAgent):
def __init__(self):
super().__init__()
self.app.websocket("/respond")(self.respond_websocket)
async def respond_websocket(self, websocket: WebSocket):
await websocket.accept()
AgentStartMessage.parse_obj(await websocket.receive_json())
await websocket.send_text(AgentReadyMessage().json())
while True:
message = WebSocketAgentMessage.parse_obj(await websocket.receive_json())
if message.type == WebSocketAgentMessageType.AGENT_STOP:
break
text_message = typing.cast(AgentTextMessage, message)
response = await self.respond(text_message.data.text)
await websocket.send_text(AgentTextMessage.from_text(response).json())
await websocket.close()