checkpoint
This commit is contained in:
parent
94f94cbebf
commit
d9f12ec0de
5 changed files with 63 additions and 6 deletions
|
|
@ -5,25 +5,40 @@ import signal
|
|||
from vocode.conversation import Conversation
|
||||
from vocode.helpers import create_microphone_input_and_speaker_output
|
||||
from vocode.models.transcriber import DeepgramTranscriberConfig
|
||||
from vocode.models.agent import ChatGPTAgentConfig
|
||||
from vocode.models.agent import ChatGPTAgentConfig, RESTfulUserImplementedAgentConfig
|
||||
from vocode.models.synthesizer import AzureSynthesizerConfig
|
||||
from vocode.user_implemented_agent.restful_agent import RESTfulAgent
|
||||
|
||||
logging.basicConfig()
|
||||
logging.root.setLevel(logging.INFO)
|
||||
|
||||
class EchoAgent(RESTfulAgent):
|
||||
|
||||
async def respond(self, input: str) -> str:
|
||||
return input
|
||||
|
||||
if __name__ == "__main__":
|
||||
import threading
|
||||
|
||||
microphone_input, speaker_output = create_microphone_input_and_speaker_output(use_first_available_device=False)
|
||||
user_agent_thread = threading.Thread(target=EchoAgent().run)
|
||||
user_agent_thread.start()
|
||||
|
||||
conversation = Conversation(
|
||||
input_device=microphone_input,
|
||||
output_device=speaker_output,
|
||||
transcriber_config=DeepgramTranscriberConfig.from_input_device(microphone_input),
|
||||
agent_config=ChatGPTAgentConfig(
|
||||
agent_config=RESTfulUserImplementedAgentConfig(
|
||||
initial_message="Hello!",
|
||||
prompt_preamble="The AI is having a pleasant conversation about life."
|
||||
generate_responses=False,
|
||||
respond=RESTfulUserImplementedAgentConfig.EndpointConfig(
|
||||
url="http://localhost:3001/respond",
|
||||
method="POST"
|
||||
)
|
||||
),
|
||||
synthesizer_config=AzureSynthesizerConfig.from_output_device(speaker_output)
|
||||
)
|
||||
signal.signal(signal.SIGINT, lambda _0, _1: conversation.deactivate())
|
||||
asyncio.run(conversation.start())
|
||||
user_agent_thread.join()
|
||||
|
||||
|
|
|
|||
|
|
@ -16,7 +16,7 @@ from .models.synthesizer import SynthesizerConfig
|
|||
from .models.websocket import ReadyMessage, AudioMessage, StartMessage, StopMessage
|
||||
from . import api_key
|
||||
|
||||
VOCODE_WEBSOCKET_URL = f'wss://api.vocode.dev/conversation'
|
||||
VOCODE_WEBSOCKET_URL = f'wss://a6eb64f4a9b7.ngrok.io/conversation'
|
||||
|
||||
class Conversation:
|
||||
|
||||
|
|
|
|||
|
|
@ -1,6 +1,6 @@
|
|||
from typing import Optional
|
||||
from enum import Enum
|
||||
from .model import TypedModel
|
||||
from .model import TypedModel, BaseModel
|
||||
|
||||
|
||||
class AgentType(str, Enum):
|
||||
|
|
@ -9,11 +9,12 @@ class AgentType(str, Enum):
|
|||
CHAT_GPT = "chat_gpt"
|
||||
ECHO = "echo"
|
||||
INFORMATION_RETRIEVAL = "information_retrieval"
|
||||
RESTFUL_USER_IMPLEMENTED = "restful_user_implemented"
|
||||
|
||||
|
||||
class AgentConfig(TypedModel, type=AgentType.BASE):
|
||||
initial_message: Optional[str] = None
|
||||
|
||||
generate_responses: bool = True
|
||||
|
||||
class LLMAgentConfig(AgentConfig, type=AgentType.LLM):
|
||||
prompt_preamble: str
|
||||
|
|
@ -35,3 +36,14 @@ class InformationRetrievalAgentConfig(
|
|||
|
||||
class EchoAgentConfig(AgentConfig, type=AgentType.ECHO):
|
||||
pass
|
||||
|
||||
class RESTfulUserImplementedAgentConfig(AgentConfig, type=AgentType.RESTFUL_USER_IMPLEMENTED):
|
||||
class EndpointConfig(BaseModel):
|
||||
url: str
|
||||
method: str = "POST"
|
||||
input_param_name: str = "human_input"
|
||||
output_jsonpath: str = "response"
|
||||
|
||||
respond: EndpointConfig
|
||||
generate_response: Optional[EndpointConfig]
|
||||
update_last_bot_message_on_cut_off: Optional[EndpointConfig]
|
||||
|
|
|
|||
13
vocode/user_implemented_agent/base_agent.py
Normal file
13
vocode/user_implemented_agent/base_agent.py
Normal file
|
|
@ -0,0 +1,13 @@
|
|||
from fastapi import FastAPI, APIRouter
|
||||
import uvicorn
|
||||
|
||||
class BaseAgent():
|
||||
|
||||
def __init__(self):
|
||||
self.app = FastAPI()
|
||||
|
||||
async def respond(self, human_input) -> str:
|
||||
raise NotImplementedError
|
||||
|
||||
def run(self, host="localhost", port=3001):
|
||||
uvicorn.run(self.app, host=host, port=port)
|
||||
17
vocode/user_implemented_agent/restful_agent.py
Normal file
17
vocode/user_implemented_agent/restful_agent.py
Normal file
|
|
@ -0,0 +1,17 @@
|
|||
from .base_agent import BaseAgent
|
||||
from pydantic import BaseModel
|
||||
from fastapi import APIRouter
|
||||
|
||||
class RESTfulAgent(BaseAgent):
|
||||
|
||||
class HumanInput(BaseModel):
|
||||
human_input: str
|
||||
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self.app.post("/respond")(self.respond_rest)
|
||||
|
||||
async def respond_rest(self, request: HumanInput):
|
||||
response = await self.respond(request.human_input)
|
||||
return {"response": response}
|
||||
|
||||
Loading…
Add table
Add a link
Reference in a new issue