checkpoint

This commit is contained in:
Ajay Raj 2023-03-01 14:50:30 -08:00
commit d9f12ec0de
5 changed files with 63 additions and 6 deletions

View file

@ -5,25 +5,40 @@ import signal
from vocode.conversation import Conversation from vocode.conversation import Conversation
from vocode.helpers import create_microphone_input_and_speaker_output from vocode.helpers import create_microphone_input_and_speaker_output
from vocode.models.transcriber import DeepgramTranscriberConfig from vocode.models.transcriber import DeepgramTranscriberConfig
from vocode.models.agent import ChatGPTAgentConfig from vocode.models.agent import ChatGPTAgentConfig, RESTfulUserImplementedAgentConfig
from vocode.models.synthesizer import AzureSynthesizerConfig from vocode.models.synthesizer import AzureSynthesizerConfig
from vocode.user_implemented_agent.restful_agent import RESTfulAgent
logging.basicConfig() logging.basicConfig()
logging.root.setLevel(logging.INFO) logging.root.setLevel(logging.INFO)
class EchoAgent(RESTfulAgent):
async def respond(self, input: str) -> str:
return input
if __name__ == "__main__": if __name__ == "__main__":
import threading
microphone_input, speaker_output = create_microphone_input_and_speaker_output(use_first_available_device=False) microphone_input, speaker_output = create_microphone_input_and_speaker_output(use_first_available_device=False)
user_agent_thread = threading.Thread(target=EchoAgent().run)
user_agent_thread.start()
conversation = Conversation( conversation = Conversation(
input_device=microphone_input, input_device=microphone_input,
output_device=speaker_output, output_device=speaker_output,
transcriber_config=DeepgramTranscriberConfig.from_input_device(microphone_input), transcriber_config=DeepgramTranscriberConfig.from_input_device(microphone_input),
agent_config=ChatGPTAgentConfig( agent_config=RESTfulUserImplementedAgentConfig(
initial_message="Hello!", initial_message="Hello!",
prompt_preamble="The AI is having a pleasant conversation about life." generate_responses=False,
respond=RESTfulUserImplementedAgentConfig.EndpointConfig(
url="http://localhost:3001/respond",
method="POST"
)
), ),
synthesizer_config=AzureSynthesizerConfig.from_output_device(speaker_output) synthesizer_config=AzureSynthesizerConfig.from_output_device(speaker_output)
) )
signal.signal(signal.SIGINT, lambda _0, _1: conversation.deactivate()) signal.signal(signal.SIGINT, lambda _0, _1: conversation.deactivate())
asyncio.run(conversation.start()) asyncio.run(conversation.start())
user_agent_thread.join()

View file

@ -16,7 +16,7 @@ from .models.synthesizer import SynthesizerConfig
from .models.websocket import ReadyMessage, AudioMessage, StartMessage, StopMessage from .models.websocket import ReadyMessage, AudioMessage, StartMessage, StopMessage
from . import api_key from . import api_key
VOCODE_WEBSOCKET_URL = f'wss://api.vocode.dev/conversation' VOCODE_WEBSOCKET_URL = f'wss://a6eb64f4a9b7.ngrok.io/conversation'
class Conversation: class Conversation:

View file

@ -1,6 +1,6 @@
from typing import Optional from typing import Optional
from enum import Enum from enum import Enum
from .model import TypedModel from .model import TypedModel, BaseModel
class AgentType(str, Enum): class AgentType(str, Enum):
@ -9,11 +9,12 @@ class AgentType(str, Enum):
CHAT_GPT = "chat_gpt" CHAT_GPT = "chat_gpt"
ECHO = "echo" ECHO = "echo"
INFORMATION_RETRIEVAL = "information_retrieval" INFORMATION_RETRIEVAL = "information_retrieval"
RESTFUL_USER_IMPLEMENTED = "restful_user_implemented"
class AgentConfig(TypedModel, type=AgentType.BASE): class AgentConfig(TypedModel, type=AgentType.BASE):
initial_message: Optional[str] = None initial_message: Optional[str] = None
generate_responses: bool = True
class LLMAgentConfig(AgentConfig, type=AgentType.LLM): class LLMAgentConfig(AgentConfig, type=AgentType.LLM):
prompt_preamble: str prompt_preamble: str
@ -35,3 +36,14 @@ class InformationRetrievalAgentConfig(
class EchoAgentConfig(AgentConfig, type=AgentType.ECHO): class EchoAgentConfig(AgentConfig, type=AgentType.ECHO):
pass pass
class RESTfulUserImplementedAgentConfig(AgentConfig, type=AgentType.RESTFUL_USER_IMPLEMENTED):
class EndpointConfig(BaseModel):
url: str
method: str = "POST"
input_param_name: str = "human_input"
output_jsonpath: str = "response"
respond: EndpointConfig
generate_response: Optional[EndpointConfig]
update_last_bot_message_on_cut_off: Optional[EndpointConfig]

View file

@ -0,0 +1,13 @@
from fastapi import FastAPI, APIRouter
import uvicorn
class BaseAgent():
def __init__(self):
self.app = FastAPI()
async def respond(self, human_input) -> str:
raise NotImplementedError
def run(self, host="localhost", port=3001):
uvicorn.run(self.app, host=host, port=port)

View file

@ -0,0 +1,17 @@
from .base_agent import BaseAgent
from pydantic import BaseModel
from fastapi import APIRouter
class RESTfulAgent(BaseAgent):
class HumanInput(BaseModel):
human_input: str
def __init__(self):
super().__init__()
self.app.post("/respond")(self.respond_rest)
async def respond_rest(self, request: HumanInput):
response = await self.respond(request.human_input)
return {"response": response}