diff --git a/src/backend/base/langflow/api/v1/voice_mode.py b/src/backend/base/langflow/api/v1/voice_mode.py index 553a4abf7..af6752c1e 100644 --- a/src/backend/base/langflow/api/v1/voice_mode.py +++ b/src/backend/base/langflow/api/v1/voice_mode.py @@ -2,12 +2,12 @@ import asyncio import base64 import json import os -import queue -import threading +import time import traceback import uuid from collections import defaultdict from datetime import datetime, timezone +from functools import partial from typing import Any from uuid import UUID, uuid4 @@ -75,19 +75,21 @@ Your instructions will be divided into three mutually exclusive sections: "Perma [ADDITIONAL] The following instructions are to be considered only "Additional" """ -DIRECTION_TO_OPENAI = "Client → OpenAI" -DIRECTION_TO_CLIENT = "OpenAI → Client" +LF_TO_OPENAI = "LF → OpenAI" +LF_TO_CLIENT = "LF → Client" +OPENAI_TO_LF = "OpenAI → LF" +CLIENT_TO_LF = "Client → LF" # --- Helper Functions --- -async def authenticate_and_get_openai_key(session: DbSession, user: User, safe_send_json): +async def authenticate_and_get_openai_key(session: DbSession, user: User, websocket: WebSocket): """Authenticate the user using a token or API key and retrieve the OpenAI API key. Returns a tuple: (current_user, openai_key). If authentication fails, sends an error message to the client and returns (None, None). 
""" if user is None: - await safe_send_json( + await websocket.send_json( { "type": "error", "code": "langflow_auth", @@ -102,7 +104,7 @@ async def authenticate_and_get_openai_key(session: DbSession, user: User, safe_s ) openai_key = openai_key_value if openai_key_value is not None else os.getenv("OPENAI_API_KEY", "") if not openai_key or openai_key == "dummy": - await safe_send_json( + await websocket.send_json( { "type": "error", "code": "api_key_missing", @@ -118,47 +120,6 @@ async def authenticate_and_get_openai_key(session: DbSession, user: User, safe_s return user, openai_key -# --- Synchronous Text Chunker --- -def sync_text_chunker(sync_queue_obj: queue.Queue, timeout: float = 0.3): - """Synchronous generator that reads text pieces from a sync queue and yields complete chunks.""" - splitters = (".", ",", "?", "!", ";", ":", "—", "-", "(", ")", "[", "]", "}", " ") - buffer = "" - while True: - try: - text = sync_queue_obj.get(timeout=timeout) - except queue.Empty: - if buffer: - yield buffer + " " - buffer = "" - continue - if text is None: - if buffer: - yield buffer + " " - break - if buffer and buffer[-1] in splitters: - yield buffer + " " - buffer = text - elif text and text[0] in splitters: - yield buffer + text[0] + " " - buffer = text[1:] - else: - buffer += text - if buffer: - yield buffer + " " - - -def common_response_create(session_id, original: dict | None = None) -> dict: - msg = {} - if original is not None: - msg = original - msg["type"] = "response.create" - voice_config = get_voice_config(session_id) - if voice_config.use_elevenlabs: - response = msg.setdefault("response", {}) - response["modalities"] = ["text"] - return msg - - class VoiceConfig: def __init__(self, session_id: str): self.session_id = session_id @@ -357,6 +318,86 @@ async def process_message_queue(queue_key, session): logger.error(traceback.format_exc()) +class SendQueues: + def __init__(self, openai_ws: websockets.WebSocketClientProtocol, client_ws: WebSocket, log_event): + 
self.openai_ws: websockets.WebSocketClientProtocol = openai_ws + self.openai_send_q: asyncio.Queue[tuple] = asyncio.Queue() + self.openai_writer_task: asyncio.Task = asyncio.create_task(self.__openai_writer()) + + self.block: asyncio.Event = asyncio.Event() + self.block.set() + + self.client_ws: WebSocket = client_ws + self.client_send_q: asyncio.Queue[dict] = asyncio.Queue() + self.client_writer_task: asyncio.Task = asyncio.create_task(self.__client_writer()) + self.log_event = log_event + + def openai_send(self, payload, *, is_blocking=False): + try: + self.openai_send_q.put_nowait([payload, is_blocking]) + except Exception: # noqa: BLE001 + logger.error(traceback.format_exc()) + + def openai_unblock(self): + logger.trace("OPENAI UNBLOCKING") + self.block.set() + + async def __openai_writer(self): + try: + while True: + msg, is_blocking = await self.openai_send_q.get() + if msg is None: + break + await self.block.wait() + await self.openai_ws.send(json.dumps(msg)) + self.log_event(msg, LF_TO_OPENAI) + if is_blocking: + self.block.clear() + logger.trace("OPENAI BLOCKING") + # unblocked by openai_unblock() on response.done + except Exception: # noqa: BLE001 + logger.error(traceback.format_exc()) + + def client_send(self, payload): + try: + self.client_send_q.put_nowait(payload) + # logged as LF_TO_CLIENT in __client_writer + except Exception: # noqa: BLE001 + logger.error(traceback.format_exc()) + + async def __client_writer(self): + try: + while True: + msg = await self.client_send_q.get() + if msg is None: + break + self.log_event(msg, LF_TO_CLIENT) + await self.client_ws.send_text(json.dumps(msg)) + except Exception: # noqa: BLE001 + logger.error(traceback.format_exc()) + + async def close(self): + self.openai_send_q.put_nowait((None, False)) + self.client_send_q.put_nowait(None) + await self.openai_writer_task + await self.client_writer_task + + +def get_create_response(send_handler: SendQueues, session_id): + def create_response(original: dict | None = None): + msg = {} + if original 
is not None: + msg = original + msg["type"] = "response.create" + voice_config = get_voice_config(session_id) + if voice_config.use_elevenlabs: + response = msg.setdefault("response", {}) + response["modalities"] = ["text"] + send_handler.openai_send(payload=msg, is_blocking=True) + + return create_response + + async def handle_function_call( function_call: dict, function_call_args: str, @@ -365,34 +406,11 @@ async def handle_function_call( current_user: CurrentActiveUser, conversation_id: str, session_id: str, - voice_config: VoiceConfig, - client_safe_send_json, - openai_send, + msg_handler: SendQueues, ): + create_response = get_create_response(msg_handler, session_id) """Handle function calls from the OpenAI API.""" try: - # trigger response that tool was called - if voice_config.progress_enabled: - await openai_send( - { - "type": "conversation.item.create", - "item": { - "type": "message", - "role": "system", - "content": [ - { - "type": "input_text", - "text": "Tell the user you are now looking into or solving " - "a request that will be explained later. Do not repeat " - "the prompt exactly, summarize what's being requested." - "Keep it very short." 
- f"\n\nThe request: {function_call_args}", - } - ], - }, - } - ) - await openai_send(common_response_create(session_id)) args = json.loads(function_call_args) if function_call_args else {} input_request = InputValueRequest( input_value=args.get("input"), components=[], type="chat", session=conversation_id @@ -408,7 +426,7 @@ async def handle_function_call( if not line: continue event_data = json.loads(line) - client_safe_send_json({"type": "flow.build.progress", "data": event_data}) + msg_handler.client_send({"type": "flow.build.progress", "data": event_data}) if event_data.get("event") == "end_vertex": text_part = ( event_data.get("data", {}) @@ -427,8 +445,8 @@ async def handle_function_call( "output": str(result), }, } - await openai_send(function_output) - await openai_send(common_response_create(session_id)) + msg_handler.openai_send(function_output) + create_response() except json.JSONDecodeError as e: trace = traceback.format_exc() logger.error(f"JSON decode error: {e!s}\ntrace: {trace}") @@ -440,7 +458,7 @@ async def handle_function_call( "output": f"Error parsing arguments: {e!s}", }, } - await openai_send(function_output) + msg_handler.openai_send(function_output) except ValueError as e: trace = traceback.format_exc() logger.error(f"Value error: {e!s}\ntrace: {trace}") @@ -452,7 +470,7 @@ async def handle_function_call( "output": f"Error with input values: {e!s}", }, } - await openai_send(function_output) + msg_handler.openai_send(function_output) except (ConnectionError, websockets.exceptions.WebSocketException) as e: trace = traceback.format_exc() logger.error(f"Connection error: {e!s}\ntrace: {trace}") @@ -464,7 +482,7 @@ async def handle_function_call( "output": f"Connection error: {e!s}", }, } - await openai_send(function_output) + msg_handler.openai_send(function_output) except (KeyError, AttributeError, TypeError) as e: logger.error(f"Error executing flow: {e}") logger.error(traceback.format_exc()) @@ -476,7 +494,7 @@ async def 
handle_function_call( "output": f"Error executing flow: {e}", }, } - await openai_send(function_output) + msg_handler.openai_send(function_output) voice_config_cache: dict[str, VoiceConfig] = {} @@ -549,14 +567,14 @@ async def queue_generator(queue: asyncio.Queue): yield item -def create_event_logger(session_id: str): +def create_event_logger(): state = {"last_event_type": None, "event_count": 0} - def log_event(event: dict, direction: str) -> None: + def log_event(event: dict, provenance: str) -> None: event_type = event.get("type", "None") - event_id = event.get("event_id", "None") + response_id = event.get("response_id") or event.get("response", {}).get("id", None) if event_type != state["last_event_type"]: - logger.debug(f"Event (id - {event_id}) (session - {session_id}): {direction} {event_type}") + logger.debug(f"Event (response_id - {response_id}): {provenance} {event_type}") state["last_event_type"] = event_type state["event_count"] = 0 if event_type == "response.created": @@ -570,6 +588,103 @@ def create_event_logger(session_id: str): return log_event +TTL_SECONDS = 60 +_completed: dict[str, float] = {} + + +def mark_response_done(response_id: str): + logger.debug(f"Marking response {response_id} as done") + _completed[response_id] = time.time() + + +# Don't let this grow unbounded +def is_response_done(response_id: str) -> bool: + now = time.time() + # prune old entries + for k, ts in list(_completed.items()): + if now - ts > TTL_SECONDS: + del _completed[k] + completed = response_id in _completed + if completed: + logger.debug(f"Response {response_id} is done: {completed}") + return completed + + +class FunctionCall: + def __init__( + self, + item: dict, + msg_handler, + flow_id: str, + background_tasks, + current_user, + conversation_id: str, + session_id: str, + *, + is_prog_enabled: bool, + ): + self.item = item + self.args = "" + self.done = False + self.prog_rsp_id: str | None = None + self.func_rsp_id: str | None = None + self.func_task: 
asyncio.Task | None = None + self.is_prog_enabled = is_prog_enabled + self.msg_handler = msg_handler + self.flow_id = flow_id + self.background_tasks = background_tasks + self.current_user = current_user + self.conversation_id = conversation_id + self.session_id = session_id + + def append_args(self, args: str): + self.args += args + + def execute(self): + if self.is_prog_enabled: + self._send_progress_message() + self._send_function_call() + + def _send_progress_message(self): + # Summarize and notify user of in-progress function call + self.msg_handler.openai_send( + { + "type": "conversation.item.create", + "item": { + "type": "message", + "role": "system", + "content": [ + { + "type": "input_text", + "text": "Tell the user you are now looking into or solving a request." + "and summarize what is being requested." + "Keep it very short." + f"\n\nThe request: {self.args}", + } + ], + }, + } + ) + create_response = partial(get_create_response(self.msg_handler, self.session_id)) + create_response() + + def _send_function_call(self): + async def _call(): + await handle_function_call( + function_call=self.item, + function_call_args=self.args, + flow_id=self.flow_id, + background_tasks=self.background_tasks, + current_user=self.current_user, + conversation_id=self.conversation_id, + session_id=self.session_id, + msg_handler=self.msg_handler, + ) + self.done = True + + self.func_task = asyncio.create_task(_call()) + + # --- WebSocket Endpoints for Flow-as-Tool --- @router.websocket("/ws/flow_as_tool/{flow_id}") async def flow_as_tool_websocket_no_session( @@ -600,59 +715,12 @@ async def flow_as_tool_websocket( try: await client_websocket.accept() - openai_send_q: asyncio.Queue[str] = asyncio.Queue() - client_send_q: asyncio.Queue[str] = asyncio.Queue() + log_event = create_event_logger() - async def openai_writer(): - while True: - msg = await openai_send_q.get() - if msg is None: - break - await openai_ws.send(msg) - - async def client_writer(): - while True: - msg = 
await client_send_q.get() - if msg is None: - break - await client_websocket.send_text(msg) - - log_event = create_event_logger(session_id) - - response_q: asyncio.Queue = asyncio.Queue() - await response_q.put(None) - - async def openai_send(payload): - log_event(payload, DIRECTION_TO_OPENAI) - logger.trace(f"Sending text {DIRECTION_TO_OPENAI}: {payload['type']}") - if payload.get("type") == "response.create": - await response_q.get() - openai_send_q.put_nowait(json.dumps(payload)) - logger.trace("JSON sent.") - - def client_send(payload): - log_event(payload, DIRECTION_TO_CLIENT) - logger.trace(f"Sending JSON {DIRECTION_TO_CLIENT}: {payload.get('type', 'None')}") - if payload.get("type") == "response.done": - try: - response_q.put_nowait(None) - except asyncio.QueueFull: - logger.warning("Queue is full, skipping put_nowait") - client_send_q.put_nowait(json.dumps(payload)) - logger.trace("JSON sent.") - - async def close(): - openai_send_q.put_nowait(None) - client_send_q.put_nowait(None) - await openai_writer_task - await client_writer_task - await client_websocket.close() - await openai_ws.close() - - vad_task = None + vad_task: asyncio.Task | None = None voice_config = get_voice_config(session_id) current_user: User = await get_current_user_for_websocket(client_websocket, session) - current_user, openai_key = await authenticate_and_get_openai_key(session, current_user, client_send) + current_user, openai_key = await authenticate_and_get_openai_key(session, current_user, client_websocket) if current_user is None or openai_key is None: return try: @@ -668,7 +736,8 @@ async def flow_as_tool_websocket( }, } except Exception as e: # noqa: BLE001 - client_send({"error": f"Failed to load flow: {e!s}"}) + err_msg = {"error": f"Failed to load flow: {e!s}"} + await client_websocket.send_json(err_msg) logger.error(f"Failed to load flow: {e}") return @@ -684,13 +753,12 @@ async def flow_as_tool_websocket( return session_dict async with websockets.connect(url, 
extra_headers=headers) as openai_ws: + msg_handler = SendQueues(openai_ws, client_websocket, log_event) + openai_realtime_session = init_session_dict() - openai_writer_task = asyncio.create_task(openai_writer()) - client_writer_task = asyncio.create_task(client_writer()) - session_update = {"type": "session.update", "session": openai_realtime_session} - await openai_send(session_update) + msg_handler.openai_send(session_update) # Setup for VAD processing. vad_queue: asyncio.Queue = asyncio.Queue() @@ -716,7 +784,7 @@ async def flow_as_tool_websocket( has_speech = True logger.trace("!", end="") if bot_speaking_flag[0]: - await openai_send({"type": "response.cancel"}) + msg_handler.openai_send({"type": "response.cancel"}) bot_speaking_flag[0] = False except Exception as e: # noqa: BLE001 logger.error(f"[ERROR] VAD processing failed (ValueError): {e}") @@ -730,7 +798,7 @@ async def flow_as_tool_websocket( logger.trace("_", end="") def client_send_event_from_thread(event, loop) -> None: - return loop.call_soon_threadsafe(client_send, event) + return loop.call_soon_threadsafe(msg_handler.client_send, event) def pass_through(from_dict, to_dict, keys): for key in keys: @@ -790,69 +858,71 @@ async def flow_as_tool_websocket( then run the ElevenLabs TTS call (which expects a sync generator) in a separate thread. 
""" - sync_q: queue.Queue = queue.Queue() + try: + elevenlabs_client = await get_or_create_elevenlabs_client(current_user.id, session) + if elevenlabs_client is None: + return - async def transfer_text_deltas(): - while True: - item = await rsp.text_delta_queue.get() - sync_q.put(item) - if item is None: - break + async def get_chunks(q: asyncio.Queue): + delims = [".", "?", ";", "!"] + buf: str = "" + while True: + text = await q.get() + if text is None: + if len(buf) > 0: + yield buf + return + buf += text + delim_locs = [] + for delim in delims: + i = buf.find(delim) + while i != -1: + delim_locs.append(i) + i = buf.find(delim, i + 1) + substr_begin = 0 + for delim_loc in delim_locs: + chunk = buf[substr_begin : delim_loc + 1] + substr_begin = delim_loc + 1 + yield chunk + buf = buf[substr_begin:] - # Schedule the transfer task in the main event loop. - transfer_task = asyncio.create_task(transfer_text_deltas()) + chunk_gen = get_chunks(rsp.text_delta_queue) - # Create the synchronous generator from the sync queue. - sync_gen = sync_text_chunker(sync_q, timeout=0.3) - elevenlabs_client = await get_or_create_elevenlabs_client(current_user.id, session) - if elevenlabs_client is None: - transfer_task.cancel() - return - # Capture the current event loop to schedule send operations. - main_loop = asyncio.get_running_loop() + async for text_chunk in chunk_gen: + audio_chunks = elevenlabs_client.generate( + voice=voice_config.elevenlabs_voice, + output_format="pcm_24000", + text=text_chunk, # synchronous generator expected by ElevenLabs + model=voice_config.elevenlabs_model, + voice_settings=None, + stream=True, + ) + for audio_chunk in audio_chunks: + base64_audio = base64.b64encode(audio_chunk).decode("utf-8") + # Schedule sending the audio chunk in the main event loop. 
+ event = { + "type": "response.audio.delta", + "delta": base64_audio, + "response_id": rsp.response_id, + } + # client_send_event_from_thread(event, main_loop) + msg_handler.client_send(event) - def tts_thread(): - # Create a new event loop for this thread. - new_loop = asyncio.new_event_loop() - asyncio.set_event_loop(new_loop) - - async def run_tts(): - try: - audio_stream = elevenlabs_client.generate( - voice=voice_config.elevenlabs_voice, - output_format="pcm_24000", - text=sync_gen, # synchronous generator expected by ElevenLabs - model=voice_config.elevenlabs_model, - voice_settings=None, - stream=True, - ) - for chunk in audio_stream: - base64_audio = base64.b64encode(chunk).decode("utf-8") - # Schedule sending the audio chunk in the main event loop. - event = { - "type": "response.audio.delta", - "delta": base64_audio, - "response_id": rsp.response_id, - } - client_send_event_from_thread(event, main_loop) - - event = {"type": "response.done", "response": {"id": rsp.response_id}} - client_send_event_from_thread(event, main_loop) - except Exception as e: # noqa: BLE001 - logger.error(f"Error in TTS processing (ValueError): {e}") - - new_loop.run_until_complete(run_tts()) - new_loop.close() - - threading.Thread(target=tts_thread, daemon=True).start() + event = {"type": "response.audio.done", "response_id": rsp.response_id} + # client_send_event_from_thread(event, main_loop) + msg_handler.client_send(event) + except Exception: # noqa: BLE001 + logger.error(traceback.format_exc()) async def forward_to_openai() -> None: nonlocal openai_realtime_session + create_response = get_create_response(msg_handler, session_id) try: num_audio_samples = 0 # Initialize as an integer instead of None while True: message_text = await client_websocket.receive_text() msg = json.loads(message_text) + log_event(msg, CLIENT_TO_LF) if msg.get("type") == "input_audio_buffer.append": logger.trace(f"buffer_id {msg.get('buffer_id', '')}") base64_data = msg.get("audio", "") @@ -861,14 
+931,14 @@ async def flow_as_tool_websocket( # Ensure we're adding to an integer num_audio_samples += len(base64_data) event = {"type": "input_audio_buffer.append", "audio": base64_data} - await openai_send(event) + msg_handler.openai_send(event) if voice_config.barge_in_enabled: await vad_queue.put(base64_data) elif msg.get("type") == "response.create": - await openai_send(common_response_create(session_id, msg)) + create_response(msg) elif msg.get("type") == "input_audio_buffer.commit": if num_audio_samples > AUDIO_SAMPLE_THRESHOLD: - await openai_send(msg) + msg_handler.openai_send(msg) num_audio_samples = 0 elif msg.get("type") == "langflow.voice_mode.config": logger.info(f"langflow.voice_mode.config {msg}") @@ -882,110 +952,112 @@ async def flow_as_tool_websocket( modalities = ["text"] if voice_config.use_elevenlabs else ["audio", "text"] openai_realtime_session["modalities"] = modalities session_update = {"type": "session.update", "session": openai_realtime_session} - await openai_send(session_update) + msg_handler.openai_send(session_update) elif msg.get("type") == "session.update": openai_realtime_session = update_global_session(msg["session"]) session_update = {"type": "session.update", "session": openai_realtime_session} - await openai_send(session_update) + msg_handler.openai_send(session_update) else: - await openai_send(msg) + msg_handler.openai_send(msg) except (WebSocketDisconnect, websockets.ConnectionClosedOK, websockets.ConnectionClosedError): pass async def forward_to_client() -> None: nonlocal bot_speaking_flag, responses - function_call = None - function_call_args = "" conversation_id = str(uuid4()) + function_call = None + rsp: Response | None = None # Store function call tasks to prevent garbage collection - function_call_tasks = [] try: while True: data = await openai_ws.recv() event = json.loads(data) + log_event(event, OPENAI_TO_LF) event_type = event.get("type") + response_id = event.get("response_id", None) or event.get("response", 
{}).get("id", None) do_forward = True do_forward = do_forward and not (event_type == "response.done" and voice_config.use_elevenlabs) do_forward = do_forward and event_type.find("flow.") != 0 - response_id = None if do_forward: - client_send(event) + msg_handler.client_send(event) if event_type == "response.created": - response_id = event["response"]["id"] responses[response_id] = Response(response_id, voice_config.use_elevenlabs) + if function_call: + if function_call.is_prog_enabled and not function_call.prog_rsp_id: + function_call.prog_rsp_id = response_id + elif not function_call.func_rsp_id: + function_call.func_rsp_id = response_id elif event_type == "response.text.delta": + rsp = responses[response_id] if voice_config.use_elevenlabs: - response_id = event["response_id"] delta = event.get("delta", "") - rsp: Response = responses[response_id] await rsp.text_delta_queue.put(delta) elif event_type == "response.text.done": + rsp = responses[response_id] if voice_config.use_elevenlabs: - response_id = event["response_id"] - rsp = responses[response_id] await rsp.text_delta_queue.put(None) if rsp.text_delta_task and not rsp.text_delta_task.done(): await rsp.text_delta_task responses.pop(response_id) + msg_handler.client_send({"type": "response.done", "response": {"id": response_id}}) try: message_text = event.get("text", "") await add_message_to_db(message_text, session, flow_id, session_id, "Machine", "AI") - except ValueError as e: - logger.error(f"Error saving message to database (ValueError): {e}") + except ValueError as err: + logger.error(f"Error saving message to database (ValueError): {err}") logger.error(traceback.format_exc()) - except (KeyError, AttributeError, TypeError) as e: + except (KeyError, AttributeError, TypeError) as err: # Replace blind Exception with specific exceptions - logger.error(f"Error saving message to database: {e}") + logger.error(f"Error saving message to database: {err}") logger.error(traceback.format_exc()) elif event_type == 
"response.output_item.added": bot_speaking_flag[0] = True item = event.get("item", {}) - if item.get("type") == "function_call": - function_call = item - function_call_args = "" + if item.get("type") == "function_call" and ( + not function_call or (function_call and function_call.done) + ): + function_call = FunctionCall( + item=item, + msg_handler=msg_handler, + flow_id=flow_id, + background_tasks=background_tasks, + current_user=current_user, + conversation_id=conversation_id, + session_id=session_id, + is_prog_enabled=voice_config.progress_enabled, + ) elif event_type == "response.output_item.done": try: transcript = extract_transcript(event) if transcript and transcript.strip(): await add_message_to_db(transcript, session, flow_id, session_id, "Machine", "AI") - except ValueError as e: - logger.error(f"Error saving message to database (ValueError): {e}") + except ValueError as err: + logger.error(f"Error saving message to database (ValueError): {err}") logger.error(traceback.format_exc()) - except (KeyError, AttributeError, TypeError) as e: + except (KeyError, AttributeError, TypeError) as err: # Replace blind Exception with specific exceptions - logger.error(f"Error saving message to database: {e}") + logger.error(f"Error saving message to database: {err}") logger.error(traceback.format_exc()) bot_speaking_flag[0] = False + elif event_type == "response.done": + msg_handler.openai_unblock() elif event_type == "response.function_call_arguments.delta": - function_call_args += event.get("delta", "") + if function_call and response_id not in ( + function_call.prog_rsp_id, + function_call.func_rsp_id, + ): + function_call.append_args(event.get("delta", "")) elif event_type == "response.function_call_arguments.done": - if function_call: - # Create and store reference to the task - function_call_task = asyncio.create_task( - handle_function_call( - function_call, - function_call_args, - flow_id, - background_tasks, - current_user, - conversation_id, - session_id, - 
voice_config, - client_send, - openai_send, - ) - ) - # Store the task reference to prevent garbage collection - function_call_tasks.append(function_call_task) - # Clean up completed tasks periodically - function_call_tasks = [t for t in function_call_tasks if not t.done()] - function_call = None - function_call_args = "" + if function_call and response_id not in ( + function_call.prog_rsp_id, + function_call.func_rsp_id, + ): + function_call.execute() elif event_type == "response.audio.delta": # there are no audio deltas from OpenAI if ElevenLabs is used (because modality = ["text"]). event.get("delta", "") @@ -1012,16 +1084,26 @@ async def flow_as_tool_websocket( vad_task = asyncio.create_task(process_vad_audio()) try: - # Create tasks and gather them for concurrent execution - task1 = asyncio.create_task(forward_to_openai()) - task2 = asyncio.create_task(forward_to_client()) - await asyncio.gather(task1, task2) - except Exception as exc: # noqa: BLE001 - # handle any exceptions from any task - logger.error("WS loop failed:", exc_info=exc) + # Use gather with return_exceptions to collect any exceptions + tasks = [asyncio.create_task(forward_to_openai()), asyncio.create_task(forward_to_client())] + results = await asyncio.gather(*tasks, return_exceptions=True) + + # Check for exceptions in results + for result in results: + if isinstance(result, Exception): + logger.error("WS loop failed:", exc_info=result) + logger.error(traceback.format_exc()) + except Exception as e: # noqa: BLE001 + # Handle any other exceptions + logger.error(f"WS loop failed: {e}") logger.error(traceback.format_exc()) finally: # shared cleanup for writers & sockets + async def close(): + await msg_handler.close() + await client_websocket.close() + await openai_ws.close() + await close() except Exception as e: # noqa: BLE001 logger.error(f"Unexpected error: {e}") @@ -1064,31 +1146,37 @@ async def flow_tts_websocket( openai_send_q: asyncio.Queue[str] = asyncio.Queue() client_send_q: 
asyncio.Queue[str] = asyncio.Queue() + log_event = create_event_logger() + async def openai_writer(): while True: msg = await openai_send_q.get() if msg is None: break - await openai_ws.send(msg) + logger.trace(f"Sending text {LF_TO_OPENAI}") + await openai_ws.send(msg) + logger.trace("JSON sent.") + # msg is already a JSON string; logged in openai_send async def client_writer(): while True: msg = await client_send_q.get() if msg is None: break - await client_websocket.send_text(msg) - - log_event = create_event_logger(session_id) + logger.trace(f"Sending JSON {LF_TO_CLIENT}") + await client_websocket.send_text(msg) + logger.trace("JSON sent.") + # msg is already a JSON string; logged in client_send def openai_send(payload): - log_event(payload, DIRECTION_TO_OPENAI) - logger.trace(f"Sending text {DIRECTION_TO_OPENAI}: {payload['type']}") + log_event(payload, LF_TO_OPENAI) + logger.trace(f"Sending text {LF_TO_OPENAI}: {payload['type']}") openai_send_q.put_nowait(json.dumps(payload)) logger.trace("JSON sent.") def client_send(payload): - log_event(payload, DIRECTION_TO_CLIENT) - logger.trace(f"Sending JSON {DIRECTION_TO_CLIENT}: {payload['type']}") + log_event(payload, LF_TO_CLIENT) + logger.trace(f"Sending JSON {LF_TO_CLIENT}: {payload['type']}") client_send_q.put_nowait(json.dumps(payload)) logger.trace("JSON sent.") @@ -1100,7 +1188,6 @@ async def flow_tts_websocket( await client_websocket.close() await openai_ws.close() - log_event = create_event_logger(session_id) current_user: User = await get_current_user_for_websocket(client_websocket, session) - current_user, openai_key = await authenticate_and_get_openai_key(session, current_user, client_send) + current_user, openai_key = await authenticate_and_get_openai_key(session, current_user, client_websocket) url = "wss://api.openai.com/v1/realtime?intent=transcription"