fix: voice mode progress response queue (#7938)
* progress update nudge * progress update nudge * nick's fix * progress toggle * typo * fix auth and ws thread safety * [autofix.ci] apply automated fixes * fix deadlock * delete commented function * fix duplicate events * [autofix.ci] apply automated fixes * merge * clean up log_event * queues not locks * response ids for 11L flow * async bug * Fix awaits * ElevenLabsResponse -> Response * queues not locks * response_q * comment * small fix * new response queue mechanism * ignore duplicate events * event logging * process_text_deltas refactor to be a single thread/task and use AsyncElevenLabs. No longer need a text chunker. Significantly more stable. * return the chunking * reliable function calling and de-duplicating * fix shadowed var names * fix shadowed var names * cleanup * fix response id in response.audio.done * FunctionCall * [autofix.ci] apply automated fixes * mypy * 3.10 exceptions --------- Co-authored-by: phact <estevezsebastian@gmail.com> Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com>
This commit is contained in:
parent
ad82222ecc
commit
66fb4d471a
1 changed files with 345 additions and 258 deletions
|
|
@ -2,12 +2,12 @@ import asyncio
|
|||
import base64
|
||||
import json
|
||||
import os
|
||||
import queue
|
||||
import threading
|
||||
import time
|
||||
import traceback
|
||||
import uuid
|
||||
from collections import defaultdict
|
||||
from datetime import datetime, timezone
|
||||
from functools import partial
|
||||
from typing import Any
|
||||
from uuid import UUID, uuid4
|
||||
|
||||
|
|
@ -75,19 +75,21 @@ Your instructions will be divided into three mutually exclusive sections: "Perma
|
|||
[ADDITIONAL] The following instructions are to be considered only "Additional"
|
||||
"""
|
||||
|
||||
DIRECTION_TO_OPENAI = "Client → OpenAI"
|
||||
DIRECTION_TO_CLIENT = "OpenAI → Client"
|
||||
LF_TO_OPENAI = "LF → OpenAI"
|
||||
LF_TO_CLIENT = "LF → Client"
|
||||
OPENAI_TO_LF = "OpenAI → LF"
|
||||
CLIENT_TO_LF = "Client → LF"
|
||||
# --- Helper Functions ---
|
||||
|
||||
|
||||
async def authenticate_and_get_openai_key(session: DbSession, user: User, safe_send_json):
|
||||
async def authenticate_and_get_openai_key(session: DbSession, user: User, websocket: WebSocket):
|
||||
"""Authenticate the user using a token or API key and retrieve the OpenAI API key.
|
||||
|
||||
Returns a tuple: (current_user, openai_key). If authentication fails, sends an error
|
||||
message to the client and returns (None, None).
|
||||
"""
|
||||
if user is None:
|
||||
await safe_send_json(
|
||||
await websocket.send_json(
|
||||
{
|
||||
"type": "error",
|
||||
"code": "langflow_auth",
|
||||
|
|
@ -102,7 +104,7 @@ async def authenticate_and_get_openai_key(session: DbSession, user: User, safe_s
|
|||
)
|
||||
openai_key = openai_key_value if openai_key_value is not None else os.getenv("OPENAI_API_KEY", "")
|
||||
if not openai_key or openai_key == "dummy":
|
||||
await safe_send_json(
|
||||
await websocket.send_json(
|
||||
{
|
||||
"type": "error",
|
||||
"code": "api_key_missing",
|
||||
|
|
@ -118,47 +120,6 @@ async def authenticate_and_get_openai_key(session: DbSession, user: User, safe_s
|
|||
return user, openai_key
|
||||
|
||||
|
||||
# --- Synchronous Text Chunker ---
|
||||
def sync_text_chunker(sync_queue_obj: queue.Queue, timeout: float = 0.3):
    """Synchronous generator that reads text pieces from a sync queue and yields complete chunks.

    Pieces are accumulated in a buffer. The buffer is flushed when it already ends
    with a splitter character, when the incoming piece starts with one, or when no
    piece arrives within ``timeout`` seconds. A ``None`` item marks the end of the
    stream.

    Args:
        sync_queue_obj: Queue delivering ``str`` pieces, terminated by ``None``.
        timeout: Seconds to wait for the next piece before flushing the buffer.

    Yields:
        Text chunks; every chunk ends with a single appended space.
    """
    splitters = (".", ",", "?", "!", ";", ":", "—", "-", "(", ")", "[", "]", "}", " ")
    buffer = ""
    while True:
        try:
            text = sync_queue_obj.get(timeout=timeout)
        except queue.Empty:
            # Nothing new arrived in time: hand over what we have so playback is not stalled.
            if buffer:
                yield buffer + " "
                buffer = ""
            continue
        if text is None:
            # End of stream. The leftover buffer is flushed exactly once after the loop;
            # yielding it here as well would emit the final chunk twice.
            break
        if buffer and buffer[-1] in splitters:
            # The buffered text already ends on a boundary: emit it and start over.
            yield buffer + " "
            buffer = text
        elif text and text[0] in splitters:
            # The new piece opens with a boundary character: attach it to the buffered text.
            yield buffer + text[0] + " "
            buffer = text[1:]
        else:
            buffer += text
    if buffer:
        yield buffer + " "
|
||||
|
||||
|
||||
def common_response_create(session_id, original: dict | None = None) -> dict:
    """Build a ``response.create`` event, optionally based on an existing message.

    The event is built on a shallow copy of ``original`` so the caller's dict is
    never modified. When the session's voice config has ``use_elevenlabs`` set,
    the response modalities are forced to ``["text"]``.

    Args:
        session_id: Key passed to ``get_voice_config`` to look up the voice config.
        original: Optional message whose fields are carried over into the event.

    Returns:
        A new dict with ``type`` set to ``"response.create"``.
    """
    msg = dict(original) if original is not None else {}
    msg["type"] = "response.create"
    voice_config = get_voice_config(session_id)
    if voice_config.use_elevenlabs:
        # ``or {}`` also covers an explicit ``"response": None`` in the incoming message,
        # which ``setdefault`` would hand back unchanged and then fail on item assignment.
        response = dict(msg.get("response") or {})
        response["modalities"] = ["text"]
        msg["response"] = response
    return msg
|
||||
|
||||
|
||||
class VoiceConfig:
|
||||
def __init__(self, session_id: str):
|
||||
self.session_id = session_id
|
||||
|
|
@ -357,6 +318,86 @@ async def process_message_queue(queue_key, session):
|
|||
logger.error(traceback.format_exc())
|
||||
|
||||
|
||||
class SendQueues:
    """Funnel outbound websocket traffic through one FIFO queue per peer.

    Each queue is drained by a single writer task, so payloads reach a socket in
    the order they were queued. A payload queued for OpenAI with
    ``is_blocking=True`` makes the OpenAI writer hold every later payload until
    ``openai_unblock()`` is called.

    Must be instantiated while an event loop is running, because the writer
    tasks are created in ``__init__``.
    """

    def __init__(self, openai_ws: websockets.WebSocketClientProtocol, client_ws: WebSocket, log_event):
        # Assign all state before creating the writer tasks that read it.
        self.openai_ws: websockets.WebSocketClientProtocol = openai_ws
        self.client_ws: WebSocket = client_ws
        self.log_event = log_event

        # Items are (payload, is_blocking) pairs; a bare None is the shutdown sentinel.
        self.openai_send_q: asyncio.Queue[tuple | None] = asyncio.Queue()
        # Items are payload dicts; None is the shutdown sentinel.
        self.client_send_q: asyncio.Queue[dict | None] = asyncio.Queue()

        # Set means the OpenAI writer may send; cleared after a blocking payload.
        self.block: asyncio.Event = asyncio.Event()
        self.block.set()
        # Once True, blocking payloads no longer close the gate (see close()).
        self._closing = False

        self.openai_writer_task: asyncio.Task = asyncio.create_task(self.__openai_writer())
        self.client_writer_task: asyncio.Task = asyncio.create_task(self.__client_writer())

    def openai_send(self, payload, *, is_blocking=False):
        """Queue ``payload`` for OpenAI; errors are logged, never raised.

        Args:
            payload: JSON-serialisable event.
            is_blocking: When True, later payloads wait until ``openai_unblock()``.
        """
        try:
            self.openai_send_q.put_nowait((payload, is_blocking))
        except Exception:  # noqa: BLE001
            logger.error(traceback.format_exc())

    def openai_unblock(self):
        """Re-open the gate closed by the last blocking payload."""
        logger.trace("OPENAI UNBLOCKING")
        self.block.set()

    async def __openai_writer(self):
        """Drain the OpenAI queue until the ``None`` sentinel is received."""
        try:
            while True:
                item = await self.openai_send_q.get()
                # Check for the sentinel before unpacking: None cannot be unpacked.
                if item is None:
                    break
                msg, is_blocking = item
                await self.block.wait()
                await self.openai_ws.send(json.dumps(msg))
                self.log_event(msg, LF_TO_OPENAI)
                if is_blocking and not self._closing:
                    self.block.clear()
                    logger.trace("OPENAI BLOCKING")
        except Exception:  # noqa: BLE001
            logger.error(traceback.format_exc())

    def client_send(self, payload):
        """Queue ``payload`` for the client; errors are logged, never raised.

        The event is logged by the writer when it is actually sent, so it is not
        logged here as well.
        """
        try:
            self.client_send_q.put_nowait(payload)
        except Exception:  # noqa: BLE001
            logger.error(traceback.format_exc())

    async def __client_writer(self):
        """Drain the client queue until the ``None`` sentinel is received."""
        try:
            while True:
                msg = await self.client_send_q.get()
                if msg is None:
                    break
                self.log_event(msg, LF_TO_CLIENT)
                await self.client_ws.send_text(json.dumps(msg))
        except Exception:  # noqa: BLE001
            logger.error(traceback.format_exc())

    async def close(self):
        """Stop both writers and wait for them to finish.

        Payloads queued before this call are still processed ahead of the sentinels.
        """
        self._closing = True
        self.openai_send_q.put_nowait(None)
        self.client_send_q.put_nowait(None)
        # Release a writer parked on the gate; otherwise it never reaches the
        # sentinel and the await below would hang.
        self.block.set()
        await self.openai_writer_task
        await self.client_writer_task
|
||||
|
||||
|
||||
def get_create_response(send_handler: SendQueues, session_id):
    """Return a callable that queues a ``response.create`` event for OpenAI.

    Args:
        send_handler: Queue wrapper whose ``openai_send`` receives the event.
        session_id: Key passed to ``get_voice_config`` each time the callable runs.

    Returns:
        ``create_response(original=None)``. It builds the event on a shallow copy
        of ``original`` and queues it with ``is_blocking=True``.
    """

    def create_response(original: dict | None = None):
        # Copy so the caller's message is not modified behind its back.
        msg = dict(original) if original is not None else {}
        msg["type"] = "response.create"
        voice_config = get_voice_config(session_id)
        if voice_config.use_elevenlabs:
            # ``or {}`` also covers an explicit ``"response": None`` from the client,
            # which ``setdefault`` would hand back unchanged and then fail on item assignment.
            response = dict(msg.get("response") or {})
            response["modalities"] = ["text"]
            msg["response"] = response
        send_handler.openai_send(payload=msg, is_blocking=True)

    return create_response
|
||||
|
||||
|
||||
async def handle_function_call(
|
||||
function_call: dict,
|
||||
function_call_args: str,
|
||||
|
|
@ -365,34 +406,11 @@ async def handle_function_call(
|
|||
current_user: CurrentActiveUser,
|
||||
conversation_id: str,
|
||||
session_id: str,
|
||||
voice_config: VoiceConfig,
|
||||
client_safe_send_json,
|
||||
openai_send,
|
||||
msg_handler: SendQueues,
|
||||
):
|
||||
create_response = get_create_response(msg_handler, session_id)
|
||||
"""Handle function calls from the OpenAI API."""
|
||||
try:
|
||||
# trigger response that tool was called
|
||||
if voice_config.progress_enabled:
|
||||
await openai_send(
|
||||
{
|
||||
"type": "conversation.item.create",
|
||||
"item": {
|
||||
"type": "message",
|
||||
"role": "system",
|
||||
"content": [
|
||||
{
|
||||
"type": "input_text",
|
||||
"text": "Tell the user you are now looking into or solving "
|
||||
"a request that will be explained later. Do not repeat "
|
||||
"the prompt exactly, summarize what's being requested."
|
||||
"Keep it very short."
|
||||
f"\n\nThe request: {function_call_args}",
|
||||
}
|
||||
],
|
||||
},
|
||||
}
|
||||
)
|
||||
await openai_send(common_response_create(session_id))
|
||||
args = json.loads(function_call_args) if function_call_args else {}
|
||||
input_request = InputValueRequest(
|
||||
input_value=args.get("input"), components=[], type="chat", session=conversation_id
|
||||
|
|
@ -408,7 +426,7 @@ async def handle_function_call(
|
|||
if not line:
|
||||
continue
|
||||
event_data = json.loads(line)
|
||||
client_safe_send_json({"type": "flow.build.progress", "data": event_data})
|
||||
msg_handler.client_send({"type": "flow.build.progress", "data": event_data})
|
||||
if event_data.get("event") == "end_vertex":
|
||||
text_part = (
|
||||
event_data.get("data", {})
|
||||
|
|
@ -427,8 +445,8 @@ async def handle_function_call(
|
|||
"output": str(result),
|
||||
},
|
||||
}
|
||||
await openai_send(function_output)
|
||||
await openai_send(common_response_create(session_id))
|
||||
msg_handler.openai_send(function_output)
|
||||
create_response()
|
||||
except json.JSONDecodeError as e:
|
||||
trace = traceback.format_exc()
|
||||
logger.error(f"JSON decode error: {e!s}\ntrace: {trace}")
|
||||
|
|
@ -440,7 +458,7 @@ async def handle_function_call(
|
|||
"output": f"Error parsing arguments: {e!s}",
|
||||
},
|
||||
}
|
||||
await openai_send(function_output)
|
||||
msg_handler.openai_send(function_output)
|
||||
except ValueError as e:
|
||||
trace = traceback.format_exc()
|
||||
logger.error(f"Value error: {e!s}\ntrace: {trace}")
|
||||
|
|
@ -452,7 +470,7 @@ async def handle_function_call(
|
|||
"output": f"Error with input values: {e!s}",
|
||||
},
|
||||
}
|
||||
await openai_send(function_output)
|
||||
msg_handler.openai_send(function_output)
|
||||
except (ConnectionError, websockets.exceptions.WebSocketException) as e:
|
||||
trace = traceback.format_exc()
|
||||
logger.error(f"Connection error: {e!s}\ntrace: {trace}")
|
||||
|
|
@ -464,7 +482,7 @@ async def handle_function_call(
|
|||
"output": f"Connection error: {e!s}",
|
||||
},
|
||||
}
|
||||
await openai_send(function_output)
|
||||
msg_handler.openai_send(function_output)
|
||||
except (KeyError, AttributeError, TypeError) as e:
|
||||
logger.error(f"Error executing flow: {e}")
|
||||
logger.error(traceback.format_exc())
|
||||
|
|
@ -476,7 +494,7 @@ async def handle_function_call(
|
|||
"output": f"Error executing flow: {e}",
|
||||
},
|
||||
}
|
||||
await openai_send(function_output)
|
||||
msg_handler.openai_send(function_output)
|
||||
|
||||
|
||||
voice_config_cache: dict[str, VoiceConfig] = {}
|
||||
|
|
@ -549,14 +567,14 @@ async def queue_generator(queue: asyncio.Queue):
|
|||
yield item
|
||||
|
||||
|
||||
def create_event_logger(session_id: str):
|
||||
def create_event_logger():
|
||||
state = {"last_event_type": None, "event_count": 0}
|
||||
|
||||
def log_event(event: dict, direction: str) -> None:
|
||||
def log_event(event: dict, provenance: str) -> None:
|
||||
event_type = event.get("type", "None")
|
||||
event_id = event.get("event_id", "None")
|
||||
response_id = event.get("response_id") or event.get("response", {}).get("id", None)
|
||||
if event_type != state["last_event_type"]:
|
||||
logger.debug(f"Event (id - {event_id}) (session - {session_id}): {direction} {event_type}")
|
||||
logger.debug(f"Event (response_id - {response_id}): {provenance} {event_type}")
|
||||
state["last_event_type"] = event_type
|
||||
state["event_count"] = 0
|
||||
if event_type == "response.created":
|
||||
|
|
@ -570,6 +588,103 @@ def create_event_logger(session_id: str):
|
|||
return log_event
|
||||
|
||||
|
||||
# Seconds a completed response id is remembered before it is pruned.
TTL_SECONDS = 60
# response_id -> time.monotonic() reading taken when the response was marked done.
_completed: dict[str, float] = {}


def _prune_completed(now: float) -> None:
    """Drop every entry older than ``TTL_SECONDS`` relative to ``now``."""
    expired = [rid for rid, ts in _completed.items() if now - ts > TTL_SECONDS]
    for rid in expired:
        del _completed[rid]


def mark_response_done(response_id: str):
    """Record ``response_id`` as completed.

    Expired entries are pruned here too, so the table stays bounded even when
    ``is_response_done`` is never called.
    """
    logger.debug(f"Marking response {response_id} as done")
    # Monotonic clock: TTL arithmetic must not be affected by wall-clock adjustments.
    now = time.monotonic()
    _prune_completed(now)
    _completed[response_id] = now


def is_response_done(response_id: str) -> bool:
    """Return True if ``response_id`` was marked done within the last ``TTL_SECONDS``."""
    _prune_completed(time.monotonic())
    completed = response_id in _completed
    if completed:
        logger.debug(f"Response {response_id} is done: {completed}")
    return completed
|
||||
|
||||
|
||||
class FunctionCall:
    """State and execution of one function call requested during a voice session.

    Argument fragments are accumulated with ``append_args``. ``execute`` then
    optionally queues a progress prompt and starts ``handle_function_call`` as a
    background task.
    """

    def __init__(
        self,
        item: dict,
        msg_handler,
        flow_id: str,
        background_tasks,
        current_user,
        conversation_id: str,
        session_id: str,
        *,
        is_prog_enabled: bool,
    ):
        self.item = item
        # Argument text accumulated so far.
        self.args = ""
        # True once the background call has finished, whether it succeeded or failed.
        self.done = False
        # Response ids; never assigned inside this class, callers populate them.
        self.prog_rsp_id: str | None = None
        self.func_rsp_id: str | None = None
        # Reference to the running call; holding it keeps the task from being garbage collected.
        self.func_task: asyncio.Task | None = None
        self.is_prog_enabled = is_prog_enabled
        self.msg_handler = msg_handler
        self.flow_id = flow_id
        self.background_tasks = background_tasks
        self.current_user = current_user
        self.conversation_id = conversation_id
        self.session_id = session_id

    def append_args(self, args: str):
        """Append one fragment of argument text."""
        self.args += args

    def execute(self):
        """Start the call; repeated invocations are ignored.

        Must be called while an event loop is running, because it creates a task.
        """
        if self.func_task is not None:
            # Already started: do not run the call a second time.
            return
        if self.is_prog_enabled:
            self._send_progress_message()
        self._send_function_call()

    def _send_progress_message(self):
        """Queue a system prompt asking the model to announce the pending request."""
        # Summarize and notify user of in-progress function call
        self.msg_handler.openai_send(
            {
                "type": "conversation.item.create",
                "item": {
                    "type": "message",
                    "role": "system",
                    "content": [
                        {
                            "type": "input_text",
                            # Adjacent literals are concatenated as-is, so each piece
                            # carries its own trailing space.
                            "text": "Tell the user you are now looking into or solving a request "
                            "and summarize what is being requested. "
                            "Keep it very short."
                            f"\n\nThe request: {self.args}",
                        }
                    ],
                },
            }
        )
        create_response = get_create_response(self.msg_handler, self.session_id)
        create_response()

    def _send_function_call(self):
        """Run ``handle_function_call`` in a background task stored on ``func_task``."""

        async def _call():
            try:
                await handle_function_call(
                    function_call=self.item,
                    function_call_args=self.args,
                    flow_id=self.flow_id,
                    background_tasks=self.background_tasks,
                    current_user=self.current_user,
                    conversation_id=self.conversation_id,
                    session_id=self.session_id,
                    msg_handler=self.msg_handler,
                )
            except Exception:  # noqa: BLE001
                logger.error(traceback.format_exc())
            finally:
                # Always mark completion: a failed call must not leave this object
                # "in progress" forever for code that checks ``done``.
                self.done = True

        self.func_task = asyncio.create_task(_call())
|
||||
|
||||
|
||||
# --- WebSocket Endpoints for Flow-as-Tool ---
|
||||
@router.websocket("/ws/flow_as_tool/{flow_id}")
|
||||
async def flow_as_tool_websocket_no_session(
|
||||
|
|
@ -600,59 +715,12 @@ async def flow_as_tool_websocket(
|
|||
try:
|
||||
await client_websocket.accept()
|
||||
|
||||
openai_send_q: asyncio.Queue[str] = asyncio.Queue()
|
||||
client_send_q: asyncio.Queue[str] = asyncio.Queue()
|
||||
log_event = create_event_logger()
|
||||
|
||||
async def openai_writer():
|
||||
while True:
|
||||
msg = await openai_send_q.get()
|
||||
if msg is None:
|
||||
break
|
||||
await openai_ws.send(msg)
|
||||
|
||||
async def client_writer():
|
||||
while True:
|
||||
msg = await client_send_q.get()
|
||||
if msg is None:
|
||||
break
|
||||
await client_websocket.send_text(msg)
|
||||
|
||||
log_event = create_event_logger(session_id)
|
||||
|
||||
response_q: asyncio.Queue = asyncio.Queue()
|
||||
await response_q.put(None)
|
||||
|
||||
async def openai_send(payload):
|
||||
log_event(payload, DIRECTION_TO_OPENAI)
|
||||
logger.trace(f"Sending text {DIRECTION_TO_OPENAI}: {payload['type']}")
|
||||
if payload.get("type") == "response.create":
|
||||
await response_q.get()
|
||||
openai_send_q.put_nowait(json.dumps(payload))
|
||||
logger.trace("JSON sent.")
|
||||
|
||||
def client_send(payload):
|
||||
log_event(payload, DIRECTION_TO_CLIENT)
|
||||
logger.trace(f"Sending JSON {DIRECTION_TO_CLIENT}: {payload.get('type', 'None')}")
|
||||
if payload.get("type") == "response.done":
|
||||
try:
|
||||
response_q.put_nowait(None)
|
||||
except asyncio.QueueFull:
|
||||
logger.warning("Queue is full, skipping put_nowait")
|
||||
client_send_q.put_nowait(json.dumps(payload))
|
||||
logger.trace("JSON sent.")
|
||||
|
||||
async def close():
|
||||
openai_send_q.put_nowait(None)
|
||||
client_send_q.put_nowait(None)
|
||||
await openai_writer_task
|
||||
await client_writer_task
|
||||
await client_websocket.close()
|
||||
await openai_ws.close()
|
||||
|
||||
vad_task = None
|
||||
vad_task: asyncio.Task | None = None
|
||||
voice_config = get_voice_config(session_id)
|
||||
current_user: User = await get_current_user_for_websocket(client_websocket, session)
|
||||
current_user, openai_key = await authenticate_and_get_openai_key(session, current_user, client_send)
|
||||
current_user, openai_key = await authenticate_and_get_openai_key(session, current_user, client_websocket)
|
||||
if current_user is None or openai_key is None:
|
||||
return
|
||||
try:
|
||||
|
|
@ -668,7 +736,8 @@ async def flow_as_tool_websocket(
|
|||
},
|
||||
}
|
||||
except Exception as e: # noqa: BLE001
|
||||
client_send({"error": f"Failed to load flow: {e!s}"})
|
||||
err_msg = {"error": f"Failed to load flow: {e!s}"}
|
||||
await client_websocket.send_json(err_msg)
|
||||
logger.error(f"Failed to load flow: {e}")
|
||||
return
|
||||
|
||||
|
|
@ -684,13 +753,12 @@ async def flow_as_tool_websocket(
|
|||
return session_dict
|
||||
|
||||
async with websockets.connect(url, extra_headers=headers) as openai_ws:
|
||||
msg_handler = SendQueues(openai_ws, client_websocket, log_event)
|
||||
|
||||
openai_realtime_session = init_session_dict()
|
||||
|
||||
openai_writer_task = asyncio.create_task(openai_writer())
|
||||
client_writer_task = asyncio.create_task(client_writer())
|
||||
|
||||
session_update = {"type": "session.update", "session": openai_realtime_session}
|
||||
await openai_send(session_update)
|
||||
msg_handler.openai_send(session_update)
|
||||
|
||||
# Setup for VAD processing.
|
||||
vad_queue: asyncio.Queue = asyncio.Queue()
|
||||
|
|
@ -716,7 +784,7 @@ async def flow_as_tool_websocket(
|
|||
has_speech = True
|
||||
logger.trace("!", end="")
|
||||
if bot_speaking_flag[0]:
|
||||
await openai_send({"type": "response.cancel"})
|
||||
msg_handler.openai_send({"type": "response.cancel"})
|
||||
bot_speaking_flag[0] = False
|
||||
except Exception as e: # noqa: BLE001
|
||||
logger.error(f"[ERROR] VAD processing failed (ValueError): {e}")
|
||||
|
|
@ -730,7 +798,7 @@ async def flow_as_tool_websocket(
|
|||
logger.trace("_", end="")
|
||||
|
||||
def client_send_event_from_thread(event, loop) -> None:
|
||||
return loop.call_soon_threadsafe(client_send, event)
|
||||
return loop.call_soon_threadsafe(msg_handler.client_send, event)
|
||||
|
||||
def pass_through(from_dict, to_dict, keys):
|
||||
for key in keys:
|
||||
|
|
@ -790,69 +858,71 @@ async def flow_as_tool_websocket(
|
|||
|
||||
then run the ElevenLabs TTS call (which expects a sync generator) in a separate thread.
|
||||
"""
|
||||
sync_q: queue.Queue = queue.Queue()
|
||||
try:
|
||||
elevenlabs_client = await get_or_create_elevenlabs_client(current_user.id, session)
|
||||
if elevenlabs_client is None:
|
||||
return
|
||||
|
||||
async def transfer_text_deltas():
|
||||
while True:
|
||||
item = await rsp.text_delta_queue.get()
|
||||
sync_q.put(item)
|
||||
if item is None:
|
||||
break
|
||||
async def get_chunks(q: asyncio.Queue):
|
||||
delims = [".", "?", ";", "!"]
|
||||
buf: str = ""
|
||||
while True:
|
||||
text = await q.get()
|
||||
if text is None:
|
||||
if len(buf) > 0:
|
||||
yield buf
|
||||
return
|
||||
buf += text
|
||||
delim_locs = []
|
||||
for delim in delims:
|
||||
i = buf.find(delim)
|
||||
while i != -1:
|
||||
delim_locs.append(i)
|
||||
i = buf.find(delim, i + 1)
|
||||
substr_begin = 0
|
||||
for delim_loc in delim_locs:
|
||||
chunk = buf[substr_begin : delim_loc + 1]
|
||||
substr_begin = delim_loc + 1
|
||||
yield chunk
|
||||
buf = buf[substr_begin:]
|
||||
|
||||
# Schedule the transfer task in the main event loop.
|
||||
transfer_task = asyncio.create_task(transfer_text_deltas())
|
||||
chunk_gen = get_chunks(rsp.text_delta_queue)
|
||||
|
||||
# Create the synchronous generator from the sync queue.
|
||||
sync_gen = sync_text_chunker(sync_q, timeout=0.3)
|
||||
elevenlabs_client = await get_or_create_elevenlabs_client(current_user.id, session)
|
||||
if elevenlabs_client is None:
|
||||
transfer_task.cancel()
|
||||
return
|
||||
# Capture the current event loop to schedule send operations.
|
||||
main_loop = asyncio.get_running_loop()
|
||||
async for text_chunk in chunk_gen:
|
||||
audio_chunks = elevenlabs_client.generate(
|
||||
voice=voice_config.elevenlabs_voice,
|
||||
output_format="pcm_24000",
|
||||
text=text_chunk, # synchronous generator expected by ElevenLabs
|
||||
model=voice_config.elevenlabs_model,
|
||||
voice_settings=None,
|
||||
stream=True,
|
||||
)
|
||||
for audio_chunk in audio_chunks:
|
||||
base64_audio = base64.b64encode(audio_chunk).decode("utf-8")
|
||||
# Schedule sending the audio chunk in the main event loop.
|
||||
event = {
|
||||
"type": "response.audio.delta",
|
||||
"delta": base64_audio,
|
||||
"response_id": rsp.response_id,
|
||||
}
|
||||
# client_send_event_from_thread(event, main_loop)
|
||||
msg_handler.client_send(event)
|
||||
|
||||
def tts_thread():
|
||||
# Create a new event loop for this thread.
|
||||
new_loop = asyncio.new_event_loop()
|
||||
asyncio.set_event_loop(new_loop)
|
||||
|
||||
async def run_tts():
|
||||
try:
|
||||
audio_stream = elevenlabs_client.generate(
|
||||
voice=voice_config.elevenlabs_voice,
|
||||
output_format="pcm_24000",
|
||||
text=sync_gen, # synchronous generator expected by ElevenLabs
|
||||
model=voice_config.elevenlabs_model,
|
||||
voice_settings=None,
|
||||
stream=True,
|
||||
)
|
||||
for chunk in audio_stream:
|
||||
base64_audio = base64.b64encode(chunk).decode("utf-8")
|
||||
# Schedule sending the audio chunk in the main event loop.
|
||||
event = {
|
||||
"type": "response.audio.delta",
|
||||
"delta": base64_audio,
|
||||
"response_id": rsp.response_id,
|
||||
}
|
||||
client_send_event_from_thread(event, main_loop)
|
||||
|
||||
event = {"type": "response.done", "response": {"id": rsp.response_id}}
|
||||
client_send_event_from_thread(event, main_loop)
|
||||
except Exception as e: # noqa: BLE001
|
||||
logger.error(f"Error in TTS processing (ValueError): {e}")
|
||||
|
||||
new_loop.run_until_complete(run_tts())
|
||||
new_loop.close()
|
||||
|
||||
threading.Thread(target=tts_thread, daemon=True).start()
|
||||
event = {"type": "response.audio.done", "response_id": rsp.response_id}
|
||||
# client_send_event_from_thread(event, main_loop)
|
||||
msg_handler.client_send(event)
|
||||
except Exception: # noqa: BLE001
|
||||
logger.error(traceback.format_exc())
|
||||
|
||||
async def forward_to_openai() -> None:
|
||||
nonlocal openai_realtime_session
|
||||
create_response = get_create_response(msg_handler, session_id)
|
||||
try:
|
||||
num_audio_samples = 0 # Initialize as an integer instead of None
|
||||
while True:
|
||||
message_text = await client_websocket.receive_text()
|
||||
msg = json.loads(message_text)
|
||||
log_event(msg, CLIENT_TO_LF)
|
||||
if msg.get("type") == "input_audio_buffer.append":
|
||||
logger.trace(f"buffer_id {msg.get('buffer_id', '')}")
|
||||
base64_data = msg.get("audio", "")
|
||||
|
|
@ -861,14 +931,14 @@ async def flow_as_tool_websocket(
|
|||
# Ensure we're adding to an integer
|
||||
num_audio_samples += len(base64_data)
|
||||
event = {"type": "input_audio_buffer.append", "audio": base64_data}
|
||||
await openai_send(event)
|
||||
msg_handler.openai_send(event)
|
||||
if voice_config.barge_in_enabled:
|
||||
await vad_queue.put(base64_data)
|
||||
elif msg.get("type") == "response.create":
|
||||
await openai_send(common_response_create(session_id, msg))
|
||||
create_response(msg)
|
||||
elif msg.get("type") == "input_audio_buffer.commit":
|
||||
if num_audio_samples > AUDIO_SAMPLE_THRESHOLD:
|
||||
await openai_send(msg)
|
||||
msg_handler.openai_send(msg)
|
||||
num_audio_samples = 0
|
||||
elif msg.get("type") == "langflow.voice_mode.config":
|
||||
logger.info(f"langflow.voice_mode.config {msg}")
|
||||
|
|
@ -882,110 +952,112 @@ async def flow_as_tool_websocket(
|
|||
modalities = ["text"] if voice_config.use_elevenlabs else ["audio", "text"]
|
||||
openai_realtime_session["modalities"] = modalities
|
||||
session_update = {"type": "session.update", "session": openai_realtime_session}
|
||||
await openai_send(session_update)
|
||||
msg_handler.openai_send(session_update)
|
||||
elif msg.get("type") == "session.update":
|
||||
openai_realtime_session = update_global_session(msg["session"])
|
||||
session_update = {"type": "session.update", "session": openai_realtime_session}
|
||||
await openai_send(session_update)
|
||||
msg_handler.openai_send(session_update)
|
||||
else:
|
||||
await openai_send(msg)
|
||||
msg_handler.openai_send(msg)
|
||||
except (WebSocketDisconnect, websockets.ConnectionClosedOK, websockets.ConnectionClosedError):
|
||||
pass
|
||||
|
||||
async def forward_to_client() -> None:
|
||||
nonlocal bot_speaking_flag, responses
|
||||
function_call = None
|
||||
function_call_args = ""
|
||||
conversation_id = str(uuid4())
|
||||
function_call = None
|
||||
rsp: Response | None = None
|
||||
# Store function call tasks to prevent garbage collection
|
||||
function_call_tasks = []
|
||||
|
||||
try:
|
||||
while True:
|
||||
data = await openai_ws.recv()
|
||||
event = json.loads(data)
|
||||
log_event(event, OPENAI_TO_LF)
|
||||
event_type = event.get("type")
|
||||
response_id = event.get("response_id", None) or event.get("response", {}).get("id", None)
|
||||
|
||||
do_forward = True
|
||||
do_forward = do_forward and not (event_type == "response.done" and voice_config.use_elevenlabs)
|
||||
do_forward = do_forward and event_type.find("flow.") != 0
|
||||
|
||||
response_id = None
|
||||
if do_forward:
|
||||
client_send(event)
|
||||
msg_handler.client_send(event)
|
||||
if event_type == "response.created":
|
||||
response_id = event["response"]["id"]
|
||||
responses[response_id] = Response(response_id, voice_config.use_elevenlabs)
|
||||
if function_call:
|
||||
if function_call.is_prog_enabled and not function_call.prog_rsp_id:
|
||||
function_call.prog_rsp_id = response_id
|
||||
elif not function_call.func_rsp_id:
|
||||
function_call.func_rsp_id = response_id
|
||||
elif event_type == "response.text.delta":
|
||||
rsp = responses[response_id]
|
||||
if voice_config.use_elevenlabs:
|
||||
response_id = event["response_id"]
|
||||
delta = event.get("delta", "")
|
||||
rsp: Response = responses[response_id]
|
||||
await rsp.text_delta_queue.put(delta)
|
||||
elif event_type == "response.text.done":
|
||||
rsp = responses[response_id]
|
||||
if voice_config.use_elevenlabs:
|
||||
response_id = event["response_id"]
|
||||
rsp = responses[response_id]
|
||||
await rsp.text_delta_queue.put(None)
|
||||
if rsp.text_delta_task and not rsp.text_delta_task.done():
|
||||
await rsp.text_delta_task
|
||||
responses.pop(response_id)
|
||||
msg_handler.client_send({"type": "response.done", "response": {"id": response_id}})
|
||||
|
||||
try:
|
||||
message_text = event.get("text", "")
|
||||
await add_message_to_db(message_text, session, flow_id, session_id, "Machine", "AI")
|
||||
except ValueError as e:
|
||||
logger.error(f"Error saving message to database (ValueError): {e}")
|
||||
except ValueError as err:
|
||||
logger.error(f"Error saving message to database (ValueError): {err}")
|
||||
logger.error(traceback.format_exc())
|
||||
except (KeyError, AttributeError, TypeError) as e:
|
||||
except (KeyError, AttributeError, TypeError) as err:
|
||||
# Replace blind Exception with specific exceptions
|
||||
logger.error(f"Error saving message to database: {e}")
|
||||
logger.error(f"Error saving message to database: {err}")
|
||||
logger.error(traceback.format_exc())
|
||||
|
||||
elif event_type == "response.output_item.added":
|
||||
bot_speaking_flag[0] = True
|
||||
item = event.get("item", {})
|
||||
if item.get("type") == "function_call":
|
||||
function_call = item
|
||||
function_call_args = ""
|
||||
if item.get("type") == "function_call" and (
|
||||
not function_call or (function_call and function_call.done)
|
||||
):
|
||||
function_call = FunctionCall(
|
||||
item=item,
|
||||
msg_handler=msg_handler,
|
||||
flow_id=flow_id,
|
||||
background_tasks=background_tasks,
|
||||
current_user=current_user,
|
||||
conversation_id=conversation_id,
|
||||
session_id=session_id,
|
||||
is_prog_enabled=voice_config.progress_enabled,
|
||||
)
|
||||
elif event_type == "response.output_item.done":
|
||||
try:
|
||||
transcript = extract_transcript(event)
|
||||
if transcript and transcript.strip():
|
||||
await add_message_to_db(transcript, session, flow_id, session_id, "Machine", "AI")
|
||||
except ValueError as e:
|
||||
logger.error(f"Error saving message to database (ValueError): {e}")
|
||||
except ValueError as err:
|
||||
logger.error(f"Error saving message to database (ValueError): {err}")
|
||||
logger.error(traceback.format_exc())
|
||||
except (KeyError, AttributeError, TypeError) as e:
|
||||
except (KeyError, AttributeError, TypeError) as err:
|
||||
# Replace blind Exception with specific exceptions
|
||||
logger.error(f"Error saving message to database: {e}")
|
||||
logger.error(f"Error saving message to database: {err}")
|
||||
logger.error(traceback.format_exc())
|
||||
bot_speaking_flag[0] = False
|
||||
elif event_type == "response.done":
|
||||
msg_handler.openai_unblock()
|
||||
elif event_type == "response.function_call_arguments.delta":
|
||||
function_call_args += event.get("delta", "")
|
||||
if function_call and response_id not in (
|
||||
function_call.prog_rsp_id,
|
||||
function_call.func_rsp_id,
|
||||
):
|
||||
function_call.append_args(event.get("delta", ""))
|
||||
elif event_type == "response.function_call_arguments.done":
|
||||
if function_call:
|
||||
# Create and store reference to the task
|
||||
function_call_task = asyncio.create_task(
|
||||
handle_function_call(
|
||||
function_call,
|
||||
function_call_args,
|
||||
flow_id,
|
||||
background_tasks,
|
||||
current_user,
|
||||
conversation_id,
|
||||
session_id,
|
||||
voice_config,
|
||||
client_send,
|
||||
openai_send,
|
||||
)
|
||||
)
|
||||
# Store the task reference to prevent garbage collection
|
||||
function_call_tasks.append(function_call_task)
|
||||
# Clean up completed tasks periodically
|
||||
function_call_tasks = [t for t in function_call_tasks if not t.done()]
|
||||
function_call = None
|
||||
function_call_args = ""
|
||||
if function_call and response_id not in (
|
||||
function_call.prog_rsp_id,
|
||||
function_call.func_rsp_id,
|
||||
):
|
||||
function_call.execute()
|
||||
elif event_type == "response.audio.delta":
|
||||
# there are no audio deltas from OpenAI if ElevenLabs is used (because modality = ["text"]).
|
||||
event.get("delta", "")
|
||||
|
|
@ -1012,16 +1084,26 @@ async def flow_as_tool_websocket(
|
|||
vad_task = asyncio.create_task(process_vad_audio())
|
||||
|
||||
try:
|
||||
# Create tasks and gather them for concurrent execution
|
||||
task1 = asyncio.create_task(forward_to_openai())
|
||||
task2 = asyncio.create_task(forward_to_client())
|
||||
await asyncio.gather(task1, task2)
|
||||
except Exception as exc: # noqa: BLE001
|
||||
# handle any exceptions from any task
|
||||
logger.error("WS loop failed:", exc_info=exc)
|
||||
# Use gather with return_exceptions to collect any exceptions
|
||||
tasks = [asyncio.create_task(forward_to_openai()), asyncio.create_task(forward_to_client())]
|
||||
results = await asyncio.gather(*tasks, return_exceptions=True)
|
||||
|
||||
# Check for exceptions in results
|
||||
for result in results:
|
||||
if isinstance(result, Exception):
|
||||
logger.error("WS loop failed:", exc_info=result)
|
||||
logger.error(traceback.format_exc())
|
||||
except Exception as e: # noqa: BLE001
|
||||
# Handle any other exceptions
|
||||
logger.error(f"WS loop failed: {e}")
|
||||
logger.error(traceback.format_exc())
|
||||
finally:
|
||||
# shared cleanup for writers & sockets
|
||||
async def close():
|
||||
await msg_handler.close()
|
||||
await client_websocket.close()
|
||||
await openai_ws.close()
|
||||
|
||||
await close()
|
||||
except Exception as e: # noqa: BLE001
|
||||
logger.error(f"Unexpected error: {e}")
|
||||
|
|
@ -1064,31 +1146,37 @@ async def flow_tts_websocket(
|
|||
openai_send_q: asyncio.Queue[str] = asyncio.Queue()
|
||||
client_send_q: asyncio.Queue[str] = asyncio.Queue()
|
||||
|
||||
log_event = create_event_logger()
|
||||
|
||||
async def openai_writer():
|
||||
while True:
|
||||
msg = await openai_send_q.get()
|
||||
if msg is None:
|
||||
break
|
||||
await openai_ws.send(msg)
|
||||
logger.trace(f"Sending text {LF_TO_OPENAI}: {msg['type']}")
|
||||
await openai_ws.send(json.dumps(msg))
|
||||
logger.trace("JSON sent.")
|
||||
log_event(msg, LF_TO_OPENAI)
|
||||
|
||||
async def client_writer():
|
||||
while True:
|
||||
msg = await client_send_q.get()
|
||||
if msg is None:
|
||||
break
|
||||
await client_websocket.send_text(msg)
|
||||
|
||||
log_event = create_event_logger(session_id)
|
||||
logger.trace(f"Sending JSON {LF_TO_CLIENT}: {msg['type']}")
|
||||
await client_websocket.send_text(json.dumps(msg))
|
||||
logger.trace("JSON sent.")
|
||||
log_event(msg, LF_TO_CLIENT)
|
||||
|
||||
def openai_send(payload):
|
||||
log_event(payload, DIRECTION_TO_OPENAI)
|
||||
logger.trace(f"Sending text {DIRECTION_TO_OPENAI}: {payload['type']}")
|
||||
log_event(payload, LF_TO_OPENAI)
|
||||
logger.trace(f"Sending text {LF_TO_OPENAI}: {payload['type']}")
|
||||
openai_send_q.put_nowait(json.dumps(payload))
|
||||
logger.trace("JSON sent.")
|
||||
|
||||
def client_send(payload):
|
||||
log_event(payload, DIRECTION_TO_CLIENT)
|
||||
logger.trace(f"Sending JSON {DIRECTION_TO_CLIENT}: {payload['type']}")
|
||||
log_event(payload, LF_TO_CLIENT)
|
||||
logger.trace(f"Sending JSON {LF_TO_CLIENT}: {payload['type']}")
|
||||
client_send_q.put_nowait(json.dumps(payload))
|
||||
logger.trace("JSON sent.")
|
||||
|
||||
|
|
@ -1100,7 +1188,6 @@ async def flow_tts_websocket(
|
|||
await client_websocket.close()
|
||||
await openai_ws.close()
|
||||
|
||||
log_event = create_event_logger(session_id)
|
||||
current_user: User = await get_current_user_for_websocket(client_websocket, session)
|
||||
current_user, openai_key = await authenticate_and_get_openai_key(session, current_user, client_send)
|
||||
url = "wss://api.openai.com/v1/realtime?intent=transcription"
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue