open source
This commit is contained in:
parent
70b6e17c69
commit
a93bfc1ec9
61 changed files with 4013 additions and 126 deletions
63
vocode/streaming/utils/__init__.py
Normal file
63
vocode/streaming/utils/__init__.py
Normal file
|
|
@ -0,0 +1,63 @@
|
|||
import asyncio
|
||||
import audioop
|
||||
import secrets
|
||||
from typing import Any
|
||||
import wave
|
||||
|
||||
from ..models.audio_encoding import AudioEncoding
|
||||
|
||||
|
||||
def create_loop_in_thread(loop: asyncio.AbstractEventLoop, long_running_task=None):
    """Install ``loop`` as the current thread's event loop and run it.

    Args:
        loop: Event loop to register for the calling thread and drive.
        long_running_task: Optional awaitable. When given (and truthy) the
            loop runs until it completes; otherwise the loop runs forever
            (until something calls ``loop.stop()``).
    """
    asyncio.set_event_loop(loop)

    # No task supplied: just keep the loop alive for work scheduled onto it.
    if not long_running_task:
        loop.run_forever()
        return

    loop.run_until_complete(long_running_task)
|
||||
|
||||
|
||||
def convert_linear_audio(
    raw_wav: bytes,
    input_sample_rate=24000,
    output_sample_rate=8000,
    output_encoding=AudioEncoding.LINEAR16,
    output_sample_width=2,
):
    """Resample linear PCM audio and optionally encode it as mu-law.

    Args:
        raw_wav: Raw linear PCM frames (no WAV header).
        input_sample_rate: Sample rate of ``raw_wav`` in Hz.
        output_sample_rate: Desired sample rate in Hz.
        output_encoding: ``AudioEncoding.LINEAR16`` returns the (resampled)
            PCM bytes unchanged; ``AudioEncoding.MULAW`` returns them mu-law
            encoded.
        output_sample_width: Byte width of the linear samples handed to
            ``audioop.lin2ulaw`` (only used for mu-law output).

    Returns:
        The converted audio as ``bytes``.

    Raises:
        Exception: If ``output_encoding`` is neither LINEAR16 nor MULAW.
    """
    # Resample only when the rates differ. The sample width (2 bytes) and the
    # channel count (1) passed to ratecv are fixed here.
    # NOTE(review): the width is not taken from ``output_sample_width``, so
    # input that is not 16-bit mono PCM would be resampled incorrectly -
    # confirm all callers supply 16-bit mono audio.
    if input_sample_rate != output_sample_rate:
        raw_wav, _ = audioop.ratecv(
            raw_wav, 2, 1, input_sample_rate, output_sample_rate, None
        )

    if output_encoding == AudioEncoding.LINEAR16:
        return raw_wav
    if output_encoding == AudioEncoding.MULAW:
        return audioop.lin2ulaw(raw_wav, output_sample_width)

    # Previously an unknown encoding fell through and returned None; fail
    # loudly instead, using the same error as get_chunk_size_per_second.
    raise Exception("Unsupported audio encoding")
|
||||
|
||||
|
||||
def convert_wav(
    file: Any,
    output_sample_rate=8000,
    output_encoding=AudioEncoding.LINEAR16,
):
    """Read every frame of a WAV file and convert it via convert_linear_audio.

    Args:
        file: Anything ``wave.open`` accepts in ``"rb"`` mode.
        output_sample_rate: Target sample rate in Hz.
        output_encoding: Target ``AudioEncoding``.

    Returns:
        Whatever ``convert_linear_audio`` returns for the file's frames.
    """
    with wave.open(file, "rb") as wav_reader:
        frames = wav_reader.readframes(wav_reader.getnframes())
        # The source rate and sample width come from the WAV header.
        conversion_options = dict(
            input_sample_rate=wav_reader.getframerate(),
            output_sample_rate=output_sample_rate,
            output_encoding=output_encoding,
            output_sample_width=wav_reader.getsampwidth(),
        )
        return convert_linear_audio(frames, **conversion_options)
|
||||
|
||||
|
||||
def get_chunk_size_per_second(audio_encoding: AudioEncoding, sampling_rate: int) -> int:
    """Return the chunk size that corresponds to one second of audio.

    LINEAR16 yields ``sampling_rate * 2`` and MULAW yields ``sampling_rate``.

    Raises:
        Exception: If ``audio_encoding`` is neither LINEAR16 nor MULAW.
    """
    if audio_encoding == AudioEncoding.LINEAR16:
        linear16_size = sampling_rate * 2
        return linear16_size
    if audio_encoding == AudioEncoding.MULAW:
        return sampling_rate
    raise Exception("Unsupported audio encoding")
|
||||
|
||||
|
||||
def create_conversation_id() -> str:
    """Return a new random, URL-safe conversation id (16 random bytes)."""
    conversation_id = secrets.token_urlsafe(16)
    return conversation_id
|
||||
0
vocode/streaming/utils/goodbye_embeddings/.gitkeep
Normal file
0
vocode/streaming/utils/goodbye_embeddings/.gitkeep
Normal file
102
vocode/streaming/utils/goodbye_model.py
Normal file
102
vocode/streaming/utils/goodbye_model.py
Normal file
|
|
@ -0,0 +1,102 @@
|
|||
import os
|
||||
import asyncio
|
||||
import openai
|
||||
from dotenv import load_dotenv
|
||||
import numpy as np
|
||||
import requests
|
||||
|
||||
# Load variables from a .env file first so the getenv() calls below see them.
load_dotenv()
openai.api_key = os.getenv("OPENAI_API_KEY")

# Embedding backend: "pyq" only when USE_PYQ_EMBEDDINGS is exactly "true".
PLATFORM = "pyq" if os.getenv("USE_PYQ_EMBEDDINGS", "false") == "true" else "openai"

# Similarity a phrase must exceed to count as a goodbye, per backend.
SIMILARITY_THRESHOLD = 0.9
SIMILARITY_THRESHOLD_PYQ = 0.7

# Length of one embedding vector, per backend.
EMBEDDING_SIZE = 1536
PYQ_EMBEDDING_SIZE = 768

# Reference phrases whose embeddings are compared against incoming text.
GOODBYE_PHRASES = [
    "bye",
    "goodbye",
    "see you",
    "see you later",
    "talk to you later",
    "talk to you soon",
    "have a good day",
    "have a good night",
]

# Endpoint used to request embeddings when the platform is "pyq".
PYQ_API_URL = "https://embeddings.pyqai.com"
|
||||
|
||||
|
||||
class GoodbyeModel:
    """Decide whether an utterance is a goodbye using embedding similarity.

    Reference embeddings for ``GOODBYE_PHRASES`` are cached on disk as ``.npy``
    files, one file per embedding platform ("openai" or "pyq"). Each cached
    array has one column per phrase.
    """

    def __init__(
        self,
        embeddings_cache_path=os.path.join(
            os.path.dirname(__file__), "goodbye_embeddings"
        ),
    ):
        """Prepare the reference embeddings for the configured platform.

        Args:
            embeddings_cache_path: Directory that holds the cached ``.npy``
                embedding files.
        """
        self.embeddings_cache_path = embeddings_cache_path
        self.goodbye_embeddings = None
        self.goodbye_embeddings_pyq = None
        # Only the active platform is built eagerly. The other platform is
        # built on first use so constructing the model never needs credentials
        # for a backend that is not in use, and so each cache file only ever
        # receives vectors produced by its own platform.
        self._get_goodbye_embeddings(PLATFORM)

    @staticmethod
    def _embedding_size(platform):
        """Return the embedding vector length used for ``platform``."""
        return EMBEDDING_SIZE if platform == "openai" else PYQ_EMBEDDING_SIZE

    def _get_goodbye_embeddings(self, platform):
        """Return the reference embeddings for ``platform``, loading lazily."""
        if platform == "openai":
            if self.goodbye_embeddings is None:
                self.goodbye_embeddings = self.load_or_create_embeddings(
                    os.path.join(self.embeddings_cache_path, "goodbye_embeddings.npy"),
                    platform=platform,
                )
            return self.goodbye_embeddings

        if self.goodbye_embeddings_pyq is None:
            self.goodbye_embeddings_pyq = self.load_or_create_embeddings(
                os.path.join(
                    self.embeddings_cache_path, "goodbye_embeddings_pyq.npy"
                ),
                platform=platform,
            )
        return self.goodbye_embeddings_pyq

    def load_or_create_embeddings(self, path, platform=PLATFORM):
        """Load cached reference embeddings from ``path`` or build and save them.

        Args:
            path: Location of the ``.npy`` cache file.
            platform: Embedding platform the cache belongs to.

        Returns:
            Array of shape ``(embedding_size, len(GOODBYE_PHRASES))``.
        """
        expected_shape = (self._embedding_size(platform), len(GOODBYE_PHRASES))
        if os.path.exists(path):
            cached = np.load(path)
            if cached.shape == expected_shape:
                return cached
            # A cache with the wrong shape (e.g. written by another platform
            # or for a different phrase list) cannot be used in the dot
            # product in is_goodbye, so it is rebuilt below.

        embeddings = self.create_embeddings(platform=platform)
        np.save(path, embeddings)
        return embeddings

    def create_embeddings(self, platform=PLATFORM):
        """Embed every phrase in ``GOODBYE_PHRASES``; one column per phrase."""
        print("Creating embeddings...")
        embeddings = np.empty((self._embedding_size(platform), len(GOODBYE_PHRASES)))
        for column, goodbye_phrase in enumerate(GOODBYE_PHRASES):
            embeddings[:, column] = self.create_embedding(
                goodbye_phrase, platform=platform
            )
        return embeddings

    async def is_goodbye(self, text: str, platform=PLATFORM) -> bool:
        """Return True when ``text`` looks like a goodbye.

        Any text containing "bye" (case-insensitive) matches immediately.
        Otherwise the text is embedded and compared with the reference
        phrases; it matches when the best similarity exceeds the platform's
        threshold.
        """
        if "bye" in text.lower():
            return True
        # create_embedding performs blocking HTTP requests; run the comparison
        # in the default executor so the event loop is not stalled meanwhile.
        loop = asyncio.get_running_loop()
        return await loop.run_in_executor(
            None, self._is_similar_to_goodbye, text, platform
        )

    def _is_similar_to_goodbye(self, text, platform):
        """Blocking part of is_goodbye: embed ``text`` and compare."""
        embedding = self.create_embedding(text.strip().lower(), platform=platform)
        goodbye_embeddings = self._get_goodbye_embeddings(platform)
        threshold = (
            SIMILARITY_THRESHOLD if platform == "openai" else SIMILARITY_THRESHOLD_PYQ
        )
        # One dot product per reference phrase (columns of goodbye_embeddings).
        similarity_results = embedding @ goodbye_embeddings
        return bool(np.max(similarity_results) > threshold)

    def create_embedding(self, text, platform=PLATFORM) -> np.ndarray:
        """Request an embedding vector for ``text`` from ``platform``.

        Raises:
            ValueError: If ``platform`` is neither "openai" nor "pyq".
        """
        if platform == "openai":
            return np.array(
                openai.Embedding.create(input=text, model="text-embedding-ada-002")[
                    "data"
                ][0]["embedding"]
            )
        if platform == "pyq":
            return np.array(
                requests.post(
                    PYQ_API_URL,
                    headers={
                        "Content-Type": "application/json",
                        "Authorization": os.getenv("PYQ_API_KEY"),
                    },
                    json={"input_sequence": [text], "account_id": "400"},
                ).json()["response"][0]
            )
        # Previously an unknown platform silently returned None.
        raise ValueError(f"Unsupported embedding platform: {platform!r}")
|
||||
|
||||
|
||||
if __name__ == "__main__":

    async def main():
        """Prompt for text forever and print whether each line is a goodbye."""
        goodbye_model = GoodbyeModel()
        while True:
            user_text = input("Text: ")
            verdict = await goodbye_model.is_goodbye(user_text)
            print(verdict)

    asyncio.run(main())
|
||||
236
vocode/streaming/utils/sse_client.py
Normal file
236
vocode/streaming/utils/sse_client.py
Normal file
|
|
@ -0,0 +1,236 @@
|
|||
"""
|
||||
A port of sseclient (https://pypi.org/project/sseclient/) that allows you to get server-sent events with a POST request
|
||||
|
||||
Copyright (c) 2015 Brent Tubbs
|
||||
|
||||
|
||||
|
||||
Permission is hereby granted, free of charge, to any person obtaining a copy
|
||||
of this software and associated documentation files (the "Software"), to deal
|
||||
in the Software without restriction, including without limitation the rights
|
||||
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
||||
copies of the Software, and to permit persons to whom the Software is
|
||||
furnished to do so, subject to the following conditions:
|
||||
|
||||
|
||||
|
||||
The above copyright notice and this permission notice shall be included in
|
||||
all copies or substantial portions of the Software.
|
||||
|
||||
|
||||
|
||||
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
||||
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
||||
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
||||
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
||||
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
||||
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
|
||||
THE SOFTWARE."""
|
||||
#
|
||||
# Distributed under the terms of the MIT license.
|
||||
#
|
||||
from __future__ import unicode_literals
|
||||
|
||||
import codecs
|
||||
import re
|
||||
import time
|
||||
import warnings
|
||||
|
||||
import six
|
||||
|
||||
import requests
|
||||
|
||||
__version__ = "0.0.27"

# An event ends at a blank line, i.e. two consecutive line terminators.
# Strictly, a stream may mix line-ending styles; this pattern assumes each
# stream uses one style consistently (\r\n, \r or \n).
end_of_field = re.compile(r"\r\n\r\n|\r\r|\n\n")
|
||||
|
||||
|
||||
class SSEClient(object):
    """Iterator over server-sent events read from a streaming HTTP response.

    The request is issued with an arbitrary HTTP method. Iterating yields one
    ``Event`` per complete message. When the stream ends or a transport error
    occurs, the client sleeps for ``retry`` milliseconds and issues the
    request again.
    """

    def __init__(
        self,
        method,
        url,
        last_id=None,
        retry=3000,
        session=None,
        chunk_size=1024,
        **kwargs
    ):
        """Open the event stream.

        Args:
            method: HTTP method used for the request.
            url: URL of the event stream.
            last_id: Optional id sent as ``Last-Event-ID`` when connecting.
            retry: Delay before reconnecting, in milliseconds.
            session: Optional ``requests.Session`` (or compatible object with
                a ``request`` method); the ``requests`` module is used when
                omitted.
            chunk_size: Number of bytes requested per read.
            **kwargs: Passed through to the ``request`` call.
        """
        self.url = url
        self.method = method
        self.last_id = last_id
        self.retry = retry
        self.chunk_size = chunk_size

        # Optional support for passing in a requests.Session()
        self.session = session

        # Any extra kwargs will be fed into the request call later.
        self.requests_kwargs = kwargs

        # Work on a private copy of the headers so the caller's dict is never
        # mutated (this also tolerates an explicit headers=None).
        headers = dict(self.requests_kwargs.get("headers") or {})

        # The SSE spec requires making requests with Cache-Control: no-cache
        headers["Cache-Control"] = "no-cache"

        # The 'Accept' header is not required, but explicit > implicit
        headers["Accept"] = "text/event-stream"
        self.requests_kwargs["headers"] = headers

        # Keep data here as it streams in
        self.buf = ""

        self._connect()

    def _connect(self):
        """Issue the request and reset the chunk iterator and text decoder.

        Raises:
            requests.HTTPError: Via ``raise_for_status`` on an error status.
        """
        if self.last_id:
            self.requests_kwargs["headers"]["Last-Event-ID"] = self.last_id

        # On reconnect, release the previous response before replacing it so
        # its connection is not left dangling.
        previous = getattr(self, "resp", None)
        if previous is not None:
            previous.close()

        # Use session if set. Otherwise fall back to requests module.
        requester = self.session or requests
        self.resp = requester.request(
            self.method, self.url, stream=True, **self.requests_kwargs
        )
        self.resp_iterator = self.iter_content()
        encoding = self.resp.encoding or self.resp.apparent_encoding
        # Incremental decoding copes with multi-byte characters that are
        # split across chunks; undecodable bytes are replaced.
        self.decoder = codecs.getincrementaldecoder(encoding)(errors="replace")

        # TODO: Ensure we're handling redirects. Might also stick the 'origin'
        # attribute on Events like the Javascript spec requires.
        self.resp.raise_for_status()

    def iter_content(self):
        """Return a generator yielding raw byte chunks from the response."""

        def generate():
            while True:
                if (
                    hasattr(self.resp.raw, "_fp")
                    and hasattr(self.resp.raw._fp, "fp")
                    and hasattr(self.resp.raw._fp.fp, "read1")
                ):
                    # read1 allows short reads, so data is delivered as soon
                    # as it is available.
                    chunk = self.resp.raw._fp.fp.read1(self.chunk_size)
                else:
                    # _fp is not available, this means that we cannot use short
                    # reads and this will block until the full chunk size is
                    # actually read
                    chunk = self.resp.raw.read(self.chunk_size)
                if not chunk:
                    break
                yield chunk

        return generate()

    def _event_complete(self):
        """Return True when the buffer contains at least one whole event."""
        return end_of_field.search(self.buf) is not None

    def __iter__(self):
        return self

    def __next__(self):
        """Block until a complete event is buffered, then return it."""
        while not self._event_complete():
            try:
                next_chunk = next(self.resp_iterator)
                if not next_chunk:
                    raise EOFError()
                self.buf += self.decoder.decode(next_chunk)

            except (
                StopIteration,
                requests.RequestException,
                EOFError,
                six.moves.http_client.IncompleteRead,
            ) as e:
                print(e)
                time.sleep(self.retry / 1000.0)
                self._connect()

                # The SSE spec only supports resuming from a whole message, so
                # if we have half a message we should throw it out.
                head, sep, tail = self.buf.rpartition("\n")
                self.buf = head + sep
                continue

        # Split the complete event (up to the end_of_field) into event_string,
        # and retain anything after the current complete event in self.buf
        # for next time.
        (event_string, self.buf) = end_of_field.split(self.buf, maxsplit=1)
        msg = Event.parse(event_string)

        # If the server requests a specific retry delay, we need to honor it.
        if msg.retry:
            self.retry = msg.retry

        # last_id should only be set if included in the message. It's not
        # forgotten if a message omits it.
        if msg.id:
            self.last_id = msg.id

        return msg

    # Python 2 spells the iterator method ``next``; the alias is harmless on
    # Python 3, so it is defined unconditionally.
    next = __next__
|
||||
|
||||
|
||||
class Event(object):
    """A single server-sent event: data, event name, id and retry delay."""

    # Splits a line into a field name and an optional value. The colon and
    # one space after it are optional and excluded from the value.
    sse_line_pattern = re.compile("(?P<name>[^:]*):?( ?(?P<value>.*))?")

    def __init__(self, data="", event="message", id=None, retry=None):
        """Create an event.

        Args:
            data: Payload text.
            event: Event name; "message" is the default.
            id: Optional event id.
            retry: Optional reconnection delay requested by the server.
        """
        assert isinstance(data, six.string_types), "Data must be text"
        self.data = data
        self.event = event
        self.id = id
        self.retry = retry

    def dump(self):
        """Serialize the event back to SSE wire format (blank-line ended)."""
        lines = []
        if self.id:
            lines.append("id: %s" % self.id)

        # Only include an event line if it's not the default already.
        if self.event != "message":
            lines.append("event: %s" % self.event)

        if self.retry:
            lines.append("retry: %s" % self.retry)

        # Multi-line data is sent as one "data:" line per line of payload.
        lines.extend("data: %s" % d for d in self.data.split("\n"))
        return "\n".join(lines) + "\n\n"

    @classmethod
    def parse(cls, raw):
        """
        Given a possibly-multiline string representing an SSE message, parse it
        and return a Event object.

        Unknown field names are ignored. A ``retry`` value that is not an
        integer is ignored with a ``SyntaxWarning`` instead of raising.
        """
        msg = cls()
        for line in raw.splitlines():
            m = cls.sse_line_pattern.match(line)
            if m is None:
                # Malformed line. Discard but warn.
                warnings.warn('Invalid SSE line: "%s"' % line, SyntaxWarning)
                continue

            name = m.group("name")
            if name == "":
                # line began with a ":", so is a comment. Ignore
                continue
            value = m.group("value")

            if name == "data":
                # If we already have some data, then join to it with a newline.
                # Else this is it.
                if msg.data:
                    msg.data = "%s\n%s" % (msg.data, value)
                else:
                    msg.data = value
            elif name == "event":
                msg.event = value
            elif name == "id":
                msg.id = value
            elif name == "retry":
                # A bad retry value from the server must not abort the whole
                # stream; treat it like any other malformed line.
                try:
                    msg.retry = int(value)
                except (TypeError, ValueError):
                    warnings.warn(
                        'Invalid SSE retry value: "%s"' % value, SyntaxWarning
                    )

        return msg

    def __str__(self):
        return self.data
|
||||
40
vocode/streaming/utils/transcript.py
Normal file
40
vocode/streaming/utils/transcript.py
Normal file
|
|
@ -0,0 +1,40 @@
|
|||
import time
|
||||
from pydantic import BaseModel, Field
|
||||
from enum import Enum
|
||||
|
||||
|
||||
class Sender(str, Enum):
    """Who produced a transcript message; members are also plain strings."""

    # The person talking to the bot.
    HUMAN = "human"
    # The automated agent.
    BOT = "bot"
|
||||
|
||||
|
||||
class Message(BaseModel):
    """One utterance in a transcript."""

    # What was said.
    text: str
    # Which side said it.
    sender: Sender
    # When it was recorded (Transcript fills this from time.time()).
    timestamp: float

    def to_string(self, include_timestamp: bool = False) -> str:
        """Render as ``"<SENDER NAME>: <text>"``, optionally with the timestamp.

        Args:
            include_timestamp: When true, append `` (<timestamp>)``.
        """
        rendered = f"{self.sender.name}: {self.text}"
        if include_timestamp:
            rendered = f"{rendered} ({self.timestamp})"
        return rendered
|
||||
|
||||
|
||||
class Transcript(BaseModel):
    """Ordered record of the messages exchanged in a conversation."""

    # Messages in the order they were added.
    messages: list[Message] = []
    # Creation time of the transcript, taken from time.time().
    start_time: float = Field(default_factory=time.time)

    def to_string(self, include_timestamps: bool = False) -> str:
        """Render every message on its own line.

        Args:
            include_timestamps: Forwarded to each message's ``to_string``.
        """
        rendered_lines = [
            entry.to_string(include_timestamp=include_timestamps)
            for entry in self.messages
        ]
        return "\n".join(rendered_lines)

    def add_human_message(self, text: str):
        """Append ``text`` as a message from the human, stamped with now."""
        entry = Message(text=text, sender=Sender.HUMAN, timestamp=time.time())
        self.messages.append(entry)

    def add_bot_message(self, text: str):
        """Append ``text`` as a message from the bot, stamped with now."""
        entry = Message(text=text, sender=Sender.BOT, timestamp=time.time())
        self.messages.append(entry)
|
||||
Loading…
Add table
Add a link
Reference in a new issue