v0 of user rolled AI
This commit is contained in:
parent
d9f12ec0de
commit
37b5ffd674
8 changed files with 90 additions and 21 deletions
|
|
@ -5,24 +5,16 @@ import signal
|
|||
from vocode.conversation import Conversation
|
||||
from vocode.helpers import create_microphone_input_and_speaker_output
|
||||
from vocode.models.transcriber import DeepgramTranscriberConfig
|
||||
from vocode.models.agent import ChatGPTAgentConfig, RESTfulUserImplementedAgentConfig
|
||||
from vocode.models.agent import ChatGPTAgentConfig, RESTfulUserImplementedAgentConfig, WebSocketUserImplementedAgentConfig
|
||||
from vocode.models.synthesizer import AzureSynthesizerConfig
|
||||
from vocode.user_implemented_agent.restful_agent import RESTfulAgent
|
||||
|
||||
logging.basicConfig()
|
||||
logging.root.setLevel(logging.INFO)
|
||||
|
||||
class EchoAgent(RESTfulAgent):
|
||||
|
||||
async def respond(self, input: str) -> str:
|
||||
return input
|
||||
|
||||
if __name__ == "__main__":
|
||||
import threading
|
||||
|
||||
microphone_input, speaker_output = create_microphone_input_and_speaker_output(use_first_available_device=False)
|
||||
user_agent_thread = threading.Thread(target=EchoAgent().run)
|
||||
user_agent_thread.start()
|
||||
|
||||
conversation = Conversation(
|
||||
input_device=microphone_input,
|
||||
|
|
@ -32,7 +24,7 @@ if __name__ == "__main__":
|
|||
initial_message="Hello!",
|
||||
generate_responses=False,
|
||||
respond=RESTfulUserImplementedAgentConfig.EndpointConfig(
|
||||
url="http://localhost:3001/respond",
|
||||
url="http://a6eb64f4a9b7.ngrok.io/respond",
|
||||
method="POST"
|
||||
)
|
||||
),
|
||||
|
|
@ -40,5 +32,4 @@ if __name__ == "__main__":
|
|||
)
|
||||
signal.signal(signal.SIGINT, lambda _0, _1: conversation.deactivate())
|
||||
asyncio.run(conversation.start())
|
||||
user_agent_thread.join()
|
||||
|
||||
|
|
|
|||
11
user_implemented_agent.py
Normal file
11
user_implemented_agent.py
Normal file
|
|
@ -0,0 +1,11 @@
|
|||
from vocode.user_implemented_agent.restful_agent import RESTfulAgent
|
||||
from vocode.user_implemented_agent.websocket_agent import WebSocketAgent
|
||||
|
||||
class EchoAgent(RESTfulAgent):
|
||||
|
||||
async def respond(self, input: str) -> str:
|
||||
return input
|
||||
|
||||
if __name__ == "__main__":
|
||||
agent = EchoAgent()
|
||||
agent.run()
|
||||
|
|
@ -16,7 +16,7 @@ from .models.synthesizer import SynthesizerConfig
|
|||
from .models.websocket import ReadyMessage, AudioMessage, StartMessage, StopMessage
|
||||
from . import api_key
|
||||
|
||||
VOCODE_WEBSOCKET_URL = f'wss://a6eb64f4a9b7.ngrok.io/conversation'
|
||||
VOCODE_WEBSOCKET_URL = f'wss://api.vocode.dev/conversation'
|
||||
|
||||
class Conversation:
|
||||
|
||||
|
|
|
|||
|
|
@ -10,6 +10,7 @@ class AgentType(str, Enum):
|
|||
ECHO = "echo"
|
||||
INFORMATION_RETRIEVAL = "information_retrieval"
|
||||
RESTFUL_USER_IMPLEMENTED = "restful_user_implemented"
|
||||
WEBSOCKET_USER_IMPLEMENTED = "websocket_user_implemented"
|
||||
|
||||
|
||||
class AgentConfig(TypedModel, type=AgentType.BASE):
|
||||
|
|
@ -41,9 +42,50 @@ class RESTfulUserImplementedAgentConfig(AgentConfig, type=AgentType.RESTFUL_USER
|
|||
class EndpointConfig(BaseModel):
|
||||
url: str
|
||||
method: str = "POST"
|
||||
input_param_name: str = "human_input"
|
||||
output_jsonpath: str = "response"
|
||||
|
||||
respond: EndpointConfig
|
||||
generate_response: Optional[EndpointConfig]
|
||||
update_last_bot_message_on_cut_off: Optional[EndpointConfig]
|
||||
|
||||
class RESTfulAgentInput(BaseModel):
|
||||
human_input: str
|
||||
|
||||
class RESTfulAgentOutput(BaseModel):
|
||||
response: str
|
||||
|
||||
class WebSocketUserImplementedAgentConfig(AgentConfig, type=AgentType.WEBSOCKET_USER_IMPLEMENTED):
|
||||
class RouteConfig(BaseModel):
|
||||
url: str
|
||||
|
||||
respond: RouteConfig
|
||||
generate_response: Optional[RouteConfig]
|
||||
send_message_on_cut_off: bool = False
|
||||
|
||||
class WebSocketAgentMessageType(str, Enum):
|
||||
AGENT_BASE = 'agent_base'
|
||||
AGENT_START = 'agent_start'
|
||||
AGENT_TEXT = 'agent_text'
|
||||
AGENT_READY = 'agent_ready'
|
||||
AGENT_STOP = 'agent_stop'
|
||||
|
||||
class WebSocketAgentMessage(TypedModel, type=WebSocketAgentMessageType.AGENT_BASE): pass
|
||||
|
||||
class AgentTextMessage(WebSocketAgentMessage, type=WebSocketAgentMessageType.AGENT_TEXT):
|
||||
class Payload(BaseModel):
|
||||
text: str
|
||||
|
||||
data: Payload
|
||||
|
||||
@classmethod
|
||||
def from_text(cls, text: str):
|
||||
return cls(data=cls.Payload(text=text))
|
||||
|
||||
|
||||
class AgentStartMessage(WebSocketAgentMessage, type=WebSocketAgentMessageType.AGENT_START):
|
||||
pass
|
||||
|
||||
class AgentReadyMessage(WebSocketAgentMessage, type=WebSocketAgentMessageType.AGENT_READY):
|
||||
pass
|
||||
|
||||
class AgentStopMessage(WebSocketAgentMessage, type=WebSocketAgentMessageType.AGENT_STOP):
|
||||
pass
|
||||
|
|
@ -5,7 +5,8 @@ class BaseModel(pydantic.BaseModel):
|
|||
def __init__(self, **data):
|
||||
for key, value in data.items():
|
||||
if isinstance(value, dict):
|
||||
data[key] = self.parse_obj(value)
|
||||
if 'type' in value:
|
||||
data[key] = TypedModel.parse_obj(value)
|
||||
super().__init__(**data)
|
||||
|
||||
# Adapted from https://github.com/pydantic/pydantic/discussions/3091
|
||||
|
|
|
|||
|
|
@ -9,5 +9,5 @@ class BaseAgent():
|
|||
async def respond(self, human_input) -> str:
|
||||
raise NotImplementedError
|
||||
|
||||
def run(self, host="localhost", port=3001):
|
||||
def run(self, host="localhost", port=3000):
|
||||
uvicorn.run(self.app, host=host, port=port)
|
||||
|
|
@ -1,17 +1,15 @@
|
|||
from .base_agent import BaseAgent
|
||||
from ..models.agent import RESTfulAgentInput, RESTfulAgentOutput
|
||||
from pydantic import BaseModel
|
||||
from fastapi import APIRouter
|
||||
|
||||
class RESTfulAgent(BaseAgent):
|
||||
|
||||
class HumanInput(BaseModel):
|
||||
human_input: str
|
||||
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self.app.post("/respond")(self.respond_rest)
|
||||
|
||||
async def respond_rest(self, request: HumanInput):
|
||||
async def respond_rest(self, request: RESTfulAgentInput) -> RESTfulAgentOutput:
|
||||
response = await self.respond(request.human_input)
|
||||
return {"response": response}
|
||||
return RESTfulAgentOutput(response=response)
|
||||
|
||||
|
|
|
|||
26
vocode/user_implemented_agent/websocket_agent.py
Normal file
26
vocode/user_implemented_agent/websocket_agent.py
Normal file
|
|
@ -0,0 +1,26 @@
|
|||
from .base_agent import BaseAgent
|
||||
from pydantic import BaseModel
|
||||
import typing
|
||||
from fastapi import APIRouter, WebSocket
|
||||
from ..models.agent import AgentStartMessage, AgentReadyMessage, AgentTextMessage, WebSocketAgentMessage, WebSocketAgentMessageType
|
||||
from jsonpath_ng import parse
|
||||
|
||||
class WebSocketAgent(BaseAgent):
|
||||
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self.app.websocket("/respond")(self.respond_websocket)
|
||||
|
||||
async def respond_websocket(self, websocket: WebSocket):
|
||||
await websocket.accept()
|
||||
AgentStartMessage.parse_obj(await websocket.receive_json())
|
||||
await websocket.send_text(AgentReadyMessage().json())
|
||||
while True:
|
||||
message = WebSocketAgentMessage.parse_obj(await websocket.receive_json())
|
||||
if message.type == WebSocketAgentMessageType.AGENT_STOP:
|
||||
break
|
||||
text_message = typing.cast(AgentTextMessage, message)
|
||||
response = await self.respond(text_message.data.text)
|
||||
await websocket.send_text(AgentTextMessage.from_text(response).json())
|
||||
await websocket.close()
|
||||
|
||||
Loading…
Add table
Add a link
Reference in a new issue