checkpoint
This commit is contained in:
parent
94f94cbebf
commit
d9f12ec0de
5 changed files with 63 additions and 6 deletions
|
|
@ -5,25 +5,40 @@ import signal
|
||||||
from vocode.conversation import Conversation
|
from vocode.conversation import Conversation
|
||||||
from vocode.helpers import create_microphone_input_and_speaker_output
|
from vocode.helpers import create_microphone_input_and_speaker_output
|
||||||
from vocode.models.transcriber import DeepgramTranscriberConfig
|
from vocode.models.transcriber import DeepgramTranscriberConfig
|
||||||
from vocode.models.agent import ChatGPTAgentConfig
|
from vocode.models.agent import ChatGPTAgentConfig, RESTfulUserImplementedAgentConfig
|
||||||
from vocode.models.synthesizer import AzureSynthesizerConfig
|
from vocode.models.synthesizer import AzureSynthesizerConfig
|
||||||
|
from vocode.user_implemented_agent.restful_agent import RESTfulAgent
|
||||||
|
|
||||||
logging.basicConfig()
|
logging.basicConfig()
|
||||||
logging.root.setLevel(logging.INFO)
|
logging.root.setLevel(logging.INFO)
|
||||||
|
|
||||||
|
class EchoAgent(RESTfulAgent):
|
||||||
|
|
||||||
|
async def respond(self, input: str) -> str:
|
||||||
|
return input
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
|
import threading
|
||||||
|
|
||||||
microphone_input, speaker_output = create_microphone_input_and_speaker_output(use_first_available_device=False)
|
microphone_input, speaker_output = create_microphone_input_and_speaker_output(use_first_available_device=False)
|
||||||
|
user_agent_thread = threading.Thread(target=EchoAgent().run)
|
||||||
|
user_agent_thread.start()
|
||||||
|
|
||||||
conversation = Conversation(
|
conversation = Conversation(
|
||||||
input_device=microphone_input,
|
input_device=microphone_input,
|
||||||
output_device=speaker_output,
|
output_device=speaker_output,
|
||||||
transcriber_config=DeepgramTranscriberConfig.from_input_device(microphone_input),
|
transcriber_config=DeepgramTranscriberConfig.from_input_device(microphone_input),
|
||||||
agent_config=ChatGPTAgentConfig(
|
agent_config=RESTfulUserImplementedAgentConfig(
|
||||||
initial_message="Hello!",
|
initial_message="Hello!",
|
||||||
prompt_preamble="The AI is having a pleasant conversation about life."
|
generate_responses=False,
|
||||||
|
respond=RESTfulUserImplementedAgentConfig.EndpointConfig(
|
||||||
|
url="http://localhost:3001/respond",
|
||||||
|
method="POST"
|
||||||
|
)
|
||||||
),
|
),
|
||||||
synthesizer_config=AzureSynthesizerConfig.from_output_device(speaker_output)
|
synthesizer_config=AzureSynthesizerConfig.from_output_device(speaker_output)
|
||||||
)
|
)
|
||||||
signal.signal(signal.SIGINT, lambda _0, _1: conversation.deactivate())
|
signal.signal(signal.SIGINT, lambda _0, _1: conversation.deactivate())
|
||||||
asyncio.run(conversation.start())
|
asyncio.run(conversation.start())
|
||||||
|
user_agent_thread.join()
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -16,7 +16,7 @@ from .models.synthesizer import SynthesizerConfig
|
||||||
from .models.websocket import ReadyMessage, AudioMessage, StartMessage, StopMessage
|
from .models.websocket import ReadyMessage, AudioMessage, StartMessage, StopMessage
|
||||||
from . import api_key
|
from . import api_key
|
||||||
|
|
||||||
VOCODE_WEBSOCKET_URL = f'wss://api.vocode.dev/conversation'
|
VOCODE_WEBSOCKET_URL = f'wss://a6eb64f4a9b7.ngrok.io/conversation'
|
||||||
|
|
||||||
class Conversation:
|
class Conversation:
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -1,6 +1,6 @@
|
||||||
from typing import Optional
|
from typing import Optional
|
||||||
from enum import Enum
|
from enum import Enum
|
||||||
from .model import TypedModel
|
from .model import TypedModel, BaseModel
|
||||||
|
|
||||||
|
|
||||||
class AgentType(str, Enum):
|
class AgentType(str, Enum):
|
||||||
|
|
@ -9,11 +9,12 @@ class AgentType(str, Enum):
|
||||||
CHAT_GPT = "chat_gpt"
|
CHAT_GPT = "chat_gpt"
|
||||||
ECHO = "echo"
|
ECHO = "echo"
|
||||||
INFORMATION_RETRIEVAL = "information_retrieval"
|
INFORMATION_RETRIEVAL = "information_retrieval"
|
||||||
|
RESTFUL_USER_IMPLEMENTED = "restful_user_implemented"
|
||||||
|
|
||||||
|
|
||||||
class AgentConfig(TypedModel, type=AgentType.BASE):
|
class AgentConfig(TypedModel, type=AgentType.BASE):
|
||||||
initial_message: Optional[str] = None
|
initial_message: Optional[str] = None
|
||||||
|
generate_responses: bool = True
|
||||||
|
|
||||||
class LLMAgentConfig(AgentConfig, type=AgentType.LLM):
|
class LLMAgentConfig(AgentConfig, type=AgentType.LLM):
|
||||||
prompt_preamble: str
|
prompt_preamble: str
|
||||||
|
|
@ -35,3 +36,14 @@ class InformationRetrievalAgentConfig(
|
||||||
|
|
||||||
class EchoAgentConfig(AgentConfig, type=AgentType.ECHO):
|
class EchoAgentConfig(AgentConfig, type=AgentType.ECHO):
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
class RESTfulUserImplementedAgentConfig(AgentConfig, type=AgentType.RESTFUL_USER_IMPLEMENTED):
|
||||||
|
class EndpointConfig(BaseModel):
|
||||||
|
url: str
|
||||||
|
method: str = "POST"
|
||||||
|
input_param_name: str = "human_input"
|
||||||
|
output_jsonpath: str = "response"
|
||||||
|
|
||||||
|
respond: EndpointConfig
|
||||||
|
generate_response: Optional[EndpointConfig]
|
||||||
|
update_last_bot_message_on_cut_off: Optional[EndpointConfig]
|
||||||
|
|
|
||||||
13
vocode/user_implemented_agent/base_agent.py
Normal file
13
vocode/user_implemented_agent/base_agent.py
Normal file
|
|
@ -0,0 +1,13 @@
|
||||||
|
from fastapi import FastAPI, APIRouter
|
||||||
|
import uvicorn
|
||||||
|
|
||||||
|
class BaseAgent():
|
||||||
|
|
||||||
|
def __init__(self):
|
||||||
|
self.app = FastAPI()
|
||||||
|
|
||||||
|
async def respond(self, human_input) -> str:
|
||||||
|
raise NotImplementedError
|
||||||
|
|
||||||
|
def run(self, host="localhost", port=3001):
|
||||||
|
uvicorn.run(self.app, host=host, port=port)
|
||||||
17
vocode/user_implemented_agent/restful_agent.py
Normal file
17
vocode/user_implemented_agent/restful_agent.py
Normal file
|
|
@ -0,0 +1,17 @@
|
||||||
|
from .base_agent import BaseAgent
|
||||||
|
from pydantic import BaseModel
|
||||||
|
from fastapi import APIRouter
|
||||||
|
|
||||||
|
class RESTfulAgent(BaseAgent):
|
||||||
|
|
||||||
|
class HumanInput(BaseModel):
|
||||||
|
human_input: str
|
||||||
|
|
||||||
|
def __init__(self):
|
||||||
|
super().__init__()
|
||||||
|
self.app.post("/respond")(self.respond_rest)
|
||||||
|
|
||||||
|
async def respond_rest(self, request: HumanInput):
|
||||||
|
response = await self.respond(request.human_input)
|
||||||
|
return {"response": response}
|
||||||
|
|
||||||
Loading…
Add table
Add a link
Reference in a new issue