conversation ID tracking
This commit is contained in:
parent
2b3dc08f99
commit
3ad8a109e7
9 changed files with 24 additions and 13 deletions
|
|
@@ -1,6 +1,6 @@
|
||||||
[tool.poetry]
|
[tool.poetry]
|
||||||
name = "vocode"
|
name = "vocode"
|
||||||
version = "0.1.26"
|
version = "0.1.28"
|
||||||
description = "The all-in-one voice SDK"
|
description = "The all-in-one voice SDK"
|
||||||
authors = ["Ajay Raj <ajay@vocode.dev>"]
|
authors = ["Ajay Raj <ajay@vocode.dev>"]
|
||||||
license = "MIT License"
|
license = "MIT License"
|
||||||
|
|
|
||||||
|
|
@@ -4,7 +4,7 @@ import signal
|
||||||
|
|
||||||
from vocode.conversation import Conversation
|
from vocode.conversation import Conversation
|
||||||
from vocode.helpers import create_microphone_input_and_speaker_output
|
from vocode.helpers import create_microphone_input_and_speaker_output
|
||||||
from vocode.models.transcriber import DeepgramTranscriberConfig, PunctuationEndpointingConfig
|
from vocode.models.transcriber import DeepgramTranscriberConfig, PunctuationEndpointingConfig, GoogleTranscriberConfig
|
||||||
from vocode.models.agent import ChatGPTAgentConfig, RESTfulUserImplementedAgentConfig, WebSocketUserImplementedAgentConfig, EchoAgentConfig, ChatGPTAlphaAgentConfig, ChatGPTAgentConfig
|
from vocode.models.agent import ChatGPTAgentConfig, RESTfulUserImplementedAgentConfig, WebSocketUserImplementedAgentConfig, EchoAgentConfig, ChatGPTAlphaAgentConfig, ChatGPTAgentConfig
|
||||||
from vocode.models.synthesizer import AzureSynthesizerConfig
|
from vocode.models.synthesizer import AzureSynthesizerConfig
|
||||||
from vocode.user_implemented_agent.restful_agent import RESTfulAgent
|
from vocode.user_implemented_agent.restful_agent import RESTfulAgent
|
||||||
|
|
@@ -26,9 +26,11 @@ if __name__ == "__main__":
|
||||||
agent_config=WebSocketUserImplementedAgentConfig(
|
agent_config=WebSocketUserImplementedAgentConfig(
|
||||||
initial_message="Hello!",
|
initial_message="Hello!",
|
||||||
respond=WebSocketUserImplementedAgentConfig.RouteConfig(
|
respond=WebSocketUserImplementedAgentConfig.RouteConfig(
|
||||||
url="wss://1d7e6ab7b588.ngrok.io/respond",
|
url="wss://8b7425d5b2ab.ngrok.io/respond",
|
||||||
)
|
)
|
||||||
),
|
),
|
||||||
|
id="ajay",
|
||||||
|
# agent_config=ChatGPTAgentConfig(initial_message="hello", prompt_preamble="you are an expert on the NBA"),
|
||||||
synthesizer_config=AzureSynthesizerConfig.from_output_device(speaker_output)
|
synthesizer_config=AzureSynthesizerConfig.from_output_device(speaker_output)
|
||||||
)
|
)
|
||||||
signal.signal(signal.SIGINT, lambda _0, _1: conversation.deactivate())
|
signal.signal(signal.SIGINT, lambda _0, _1: conversation.deactivate())
|
||||||
|
|
|
||||||
|
|
@@ -4,8 +4,8 @@ from vocode.user_implemented_agent.websocket_agent import WebSocketAgent
|
||||||
|
|
||||||
class TestRESTfulAgent(RESTfulAgent):
|
class TestRESTfulAgent(RESTfulAgent):
|
||||||
|
|
||||||
async def respond(self, input: str) -> RESTfulAgentOutput:
|
async def respond(self, input: str, conversation_id: str) -> RESTfulAgentOutput:
|
||||||
print(input)
|
print(input, conversation_id)
|
||||||
if "bye" in input:
|
if "bye" in input:
|
||||||
return RESTfulAgentEnd()
|
return RESTfulAgentEnd()
|
||||||
else:
|
else:
|
||||||
|
|
@@ -24,4 +24,4 @@ class TestWebSocketAgent(WebSocketAgent):
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
agent = TestWebSocketAgent()
|
agent = TestWebSocketAgent()
|
||||||
agent.run(port=3000)
|
agent.run(port=3001)
|
||||||
|
|
|
||||||
|
|
@@ -28,7 +28,9 @@ class Conversation:
|
||||||
transcriber_config: TranscriberConfig,
|
transcriber_config: TranscriberConfig,
|
||||||
agent_config: AgentConfig,
|
agent_config: AgentConfig,
|
||||||
synthesizer_config: SynthesizerConfig,
|
synthesizer_config: SynthesizerConfig,
|
||||||
|
id: str = None,
|
||||||
):
|
):
|
||||||
|
self.id = id
|
||||||
self.input_device = input_device
|
self.input_device = input_device
|
||||||
self.output_device = output_device
|
self.output_device = output_device
|
||||||
self.transcriber_config = transcriber_config
|
self.transcriber_config = transcriber_config
|
||||||
|
|
@@ -68,6 +70,7 @@ class Conversation:
|
||||||
transcriber_config=self.transcriber_config,
|
transcriber_config=self.transcriber_config,
|
||||||
agent_config=self.agent_config,
|
agent_config=self.agent_config,
|
||||||
synthesizer_config=self.synthesizer_config,
|
synthesizer_config=self.synthesizer_config,
|
||||||
|
conversation_id=self.id
|
||||||
)
|
)
|
||||||
await ws.send(start_message.json())
|
await ws.send(start_message.json())
|
||||||
await self.wait_for_ready()
|
await self.wait_for_ready()
|
||||||
|
|
|
||||||
|
|
@@ -56,6 +56,7 @@ class RESTfulUserImplementedAgentConfig(AgentConfig, type=AgentType.RESTFUL_USER
|
||||||
# update_last_bot_message_on_cut_off: Optional[EndpointConfig]
|
# update_last_bot_message_on_cut_off: Optional[EndpointConfig]
|
||||||
|
|
||||||
class RESTfulAgentInput(BaseModel):
|
class RESTfulAgentInput(BaseModel):
|
||||||
|
conversation_id: str
|
||||||
human_input: str
|
human_input: str
|
||||||
|
|
||||||
class RESTfulAgentOutputType(str, Enum):
|
class RESTfulAgentOutputType(str, Enum):
|
||||||
|
|
@@ -88,7 +89,8 @@ class WebSocketAgentMessageType(str, Enum):
|
||||||
READY = 'websocket_agent_ready'
|
READY = 'websocket_agent_ready'
|
||||||
STOP = 'websocket_agent_stop'
|
STOP = 'websocket_agent_stop'
|
||||||
|
|
||||||
class WebSocketAgentMessage(TypedModel, type=WebSocketAgentMessageType.BASE): pass
|
class WebSocketAgentMessage(TypedModel, type=WebSocketAgentMessageType.BASE):
|
||||||
|
conversation_id: Optional[str] = None
|
||||||
|
|
||||||
class WebSocketAgentTextMessage(WebSocketAgentMessage, type=WebSocketAgentMessageType.TEXT):
|
class WebSocketAgentTextMessage(WebSocketAgentMessage, type=WebSocketAgentMessageType.TEXT):
|
||||||
class Payload(BaseModel):
|
class Payload(BaseModel):
|
||||||
|
|
@@ -97,8 +99,8 @@ class WebSocketAgentTextMessage(WebSocketAgentMessage, type=WebSocketAgentMessag
|
||||||
data: Payload
|
data: Payload
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def from_text(cls, text: str):
|
def from_text(cls, text: str, conversation_id: Optional[str] = None):
|
||||||
return cls(data=cls.Payload(text=text))
|
return cls(data=cls.Payload(text=text), conversation_id=conversation_id)
|
||||||
|
|
||||||
|
|
||||||
class WebSocketAgentStartMessage(WebSocketAgentMessage, type=WebSocketAgentMessageType.START):
|
class WebSocketAgentStartMessage(WebSocketAgentMessage, type=WebSocketAgentMessageType.START):
|
||||||
|
|
|
||||||
|
|
@@ -11,6 +11,7 @@ class CallEntity(BaseModel):
|
||||||
class CreateInboundCall(BaseModel):
|
class CreateInboundCall(BaseModel):
|
||||||
agent_config: AgentConfig
|
agent_config: AgentConfig
|
||||||
twilio_sid: str
|
twilio_sid: str
|
||||||
|
conversation_id: Optional[str] = None
|
||||||
|
|
||||||
|
|
||||||
class CreateOutboundCall(BaseModel):
|
class CreateOutboundCall(BaseModel):
|
||||||
|
|
@@ -18,4 +19,5 @@ class CreateOutboundCall(BaseModel):
|
||||||
caller: CallEntity
|
caller: CallEntity
|
||||||
agent_config: AgentConfig
|
agent_config: AgentConfig
|
||||||
synthesizer_config: Optional[SynthesizerConfig] = None
|
synthesizer_config: Optional[SynthesizerConfig] = None
|
||||||
|
conversation_id: Optional[str] = None
|
||||||
# TODO add IVR/etc.
|
# TODO add IVR/etc.
|
||||||
|
|
|
||||||
|
|
@@ -1,5 +1,6 @@
|
||||||
import base64
|
import base64
|
||||||
from enum import Enum
|
from enum import Enum
|
||||||
|
from typing import Optional
|
||||||
from .model import TypedModel
|
from .model import TypedModel
|
||||||
from .transcriber import TranscriberConfig
|
from .transcriber import TranscriberConfig
|
||||||
from .agent import AgentConfig
|
from .agent import AgentConfig
|
||||||
|
|
@@ -28,6 +29,7 @@ class StartMessage(WebSocketMessage, type=WebSocketMessageType.START):
|
||||||
transcriber_config: TranscriberConfig
|
transcriber_config: TranscriberConfig
|
||||||
agent_config: AgentConfig
|
agent_config: AgentConfig
|
||||||
synthesizer_config: SynthesizerConfig
|
synthesizer_config: SynthesizerConfig
|
||||||
|
conversation_id: Optional[str] = None
|
||||||
|
|
||||||
class ReadyMessage(WebSocketMessage, type=WebSocketMessageType.READY):
|
class ReadyMessage(WebSocketMessage, type=WebSocketMessageType.READY):
|
||||||
pass
|
pass
|
||||||
|
|
|
||||||
|
|
@@ -10,10 +10,10 @@ class RESTfulAgent(BaseAgent):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.app.post("/respond")(self.respond_rest)
|
self.app.post("/respond")(self.respond_rest)
|
||||||
|
|
||||||
async def respond(self, human_input) -> RESTfulAgentOutput:
|
async def respond(self, human_input, conversation_id) -> RESTfulAgentOutput:
|
||||||
raise NotImplementedError
|
raise NotImplementedError
|
||||||
|
|
||||||
async def respond_rest(self, request: RESTfulAgentInput) -> Union[RESTfulAgentText, RESTfulAgentEnd]:
|
async def respond_rest(self, request: RESTfulAgentInput) -> Union[RESTfulAgentText, RESTfulAgentEnd]:
|
||||||
response = await self.respond(request.human_input)
|
response = await self.respond(request.human_input, request.conversation_id)
|
||||||
return response
|
return response
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@@ -23,7 +23,6 @@ class WebSocketAgent(BaseAgent):
|
||||||
|
|
||||||
async def respond_websocket(self, websocket: WebSocket):
|
async def respond_websocket(self, websocket: WebSocket):
|
||||||
await websocket.accept()
|
await websocket.accept()
|
||||||
conversation_id = str(uuid.uuid4())
|
|
||||||
WebSocketAgentStartMessage.parse_obj(await websocket.receive_json())
|
WebSocketAgentStartMessage.parse_obj(await websocket.receive_json())
|
||||||
await websocket.send_text(WebSocketAgentReadyMessage().json())
|
await websocket.send_text(WebSocketAgentReadyMessage().json())
|
||||||
while True:
|
while True:
|
||||||
|
|
@@ -31,7 +30,8 @@ class WebSocketAgent(BaseAgent):
|
||||||
if input_message.type == WebSocketAgentMessageType.STOP:
|
if input_message.type == WebSocketAgentMessageType.STOP:
|
||||||
break
|
break
|
||||||
text_message = typing.cast(WebSocketAgentTextMessage, input_message)
|
text_message = typing.cast(WebSocketAgentTextMessage, input_message)
|
||||||
output_response = await self.respond(text_message.data.text, conversation_id=conversation_id)
|
print(text_message)
|
||||||
|
output_response = await self.respond(text_message.data.text, text_message.conversation_id)
|
||||||
await websocket.send_text(output_response.json())
|
await websocket.send_text(output_response.json())
|
||||||
await websocket.close()
|
await websocket.close()
|
||||||
|
|
||||||
|
|
|
||||||
Loading…
Add table
Add a link
Reference in a new issue