fix: voice mode progress reliability (#7830)

* progress update nudge

* progress update nudge

* nick's fix

* progress toggle

* typo

* fix auth and ws thread safety

* [autofix.ci] apply automated fixes

* fix deadlock

* delete commented function

* fix duplicate events

* [autofix.ci] apply automated fixes

* merge

* clean up log_event

* queues not locks

* response ids for 11L flow

* async bug

* Fix awaits

* ElevenLabsResponse -> Response

* queues not locks

* response_q

* comment

* small fix

* ♻️ (voice_mode.py): refactor code to use individual tasks and gather them for concurrent execution instead of TaskGroup for better readability and error handling.

* mypy

---------

Co-authored-by: phact <estevezsebastian@gmail.com>
Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com>
Co-authored-by: Cristhian Zanforlin Lousa <cristhian.lousa@gmail.com>
This commit is contained in:
Nicholas Freybler 2025-04-30 20:37:49 -04:00 committed by GitHub
commit 7c0beedcab
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
2 changed files with 421 additions and 250 deletions

View file

@ -18,7 +18,7 @@ import webrtcvad
import websockets
from cryptography.fernet import InvalidToken
from elevenlabs import ElevenLabs
from fastapi import APIRouter, BackgroundTasks, Security
from fastapi import APIRouter, BackgroundTasks
from openai import OpenAI
from sqlalchemy import select
from starlette.websockets import WebSocket, WebSocketDisconnect
@ -29,8 +29,8 @@ from langflow.api.v1.schemas import InputValueRequest
from langflow.logging import logger
from langflow.memory import aadd_messagetables
from langflow.schema.properties import Properties
from langflow.services.auth.utils import api_key_header, api_key_query, api_key_security, get_current_user_by_jwt
from langflow.services.database.models import MessageTable
from langflow.services.auth.utils import get_current_user_for_websocket
from langflow.services.database.models import MessageTable, User
from langflow.services.database.models.flow.model import Flow
from langflow.services.deps import get_variable_service, session_scope
from langflow.utils.voice_utils import (
@ -75,38 +75,34 @@ Your instructions will be divided into three mutually exclusive sections: "Perma
[ADDITIONAL] The following instructions are to be considered only "Additional"
"""
DIRECTION_TO_OPENAI = "Client → OpenAI"
DIRECTION_TO_CLIENT = "OpenAI → Client"
# --- Helper Functions ---
async def authenticate_and_get_openai_key(client_websocket: WebSocket, session: DbSession):
async def authenticate_and_get_openai_key(session: DbSession, user: User, safe_send_json):
"""Authenticate the user using a token or API key and retrieve the OpenAI API key.
Returns a tuple: (current_user, openai_key). If authentication fails, sends an error
message to the client and returns (None, None).
"""
token = client_websocket.cookies.get("access_token_lf")
current_user = None
if token:
current_user = await get_current_user_by_jwt(token, session)
if current_user is None:
current_user = await api_key_security(Security(api_key_query), Security(api_key_header))
if current_user is None:
await client_websocket.send_json(
{
"type": "error",
"code": "langflow_auth",
"message": "You must pass a valid Langflow token or cookie.",
}
)
return None, None
if user is None:
await safe_send_json(
{
"type": "error",
"code": "langflow_auth",
"message": "You must pass a valid Langflow token or cookie.",
}
)
return None, None
variable_service = get_variable_service()
try:
openai_key_value = await variable_service.get_variable(
user_id=current_user.id, name="OPENAI_API_KEY", field="openai_api_key", session=session
user_id=user.id, name="OPENAI_API_KEY", field="openai_api_key", session=session
)
openai_key = openai_key_value if openai_key_value is not None else os.getenv("OPENAI_API_KEY", "")
if not openai_key or openai_key == "dummy":
await client_websocket.send_json(
await safe_send_json(
{
"type": "error",
"code": "api_key_missing",
@ -119,7 +115,7 @@ async def authenticate_and_get_openai_key(client_websocket: WebSocket, session:
logger.error(f"Error with API key: {e}")
logger.error(traceback.format_exc())
return None, None
return current_user, openai_key
return user, openai_key
# --- Synchronous Text Chunker ---
@ -151,105 +147,16 @@ def sync_text_chunker(sync_queue_obj: queue.Queue, timeout: float = 0.3):
yield buffer + " "
async def handle_function_call(
    websocket: WebSocket,
    openai_ws: websockets.WebSocketClientProtocol,
    function_call: dict,
    function_call_args: str,
    flow_id: str,
    background_tasks: BackgroundTasks,
    current_user: CurrentActiveUser,
    conversation_id: str,
):
    """Handle function calls from the OpenAI API.

    Runs the Langflow flow named by *flow_id* with the JSON arguments OpenAI
    supplied, streams every build event to the client websocket as
    ``flow.build.progress``, accumulates the flow's text output, and finally
    reports the result (or an error message) back to OpenAI as a
    ``function_call_output`` conversation item followed by ``response.create``.
    """
    try:
        args = json.loads(function_call_args) if function_call_args else {}
        input_request = InputValueRequest(
            input_value=args.get("input"), components=[], type="chat", session=conversation_id
        )
        response = await build_flow_and_stream(
            flow_id=UUID(flow_id),
            inputs=input_request,
            background_tasks=background_tasks,
            current_user=current_user,
        )
        result = ""
        async for line in response.body_iterator:
            if not line:
                continue
            event_data = json.loads(line)
            # Mirror every build event to the client so the UI can show progress.
            await websocket.send_json({"type": "flow.build.progress", "data": event_data})
            if event_data.get("event") == "end_vertex":
                # NOTE(review): "build_data" defaults to "" — if the key is missing the
                # chained .get raises AttributeError, which the (KeyError, AttributeError,
                # TypeError) handler below absorbs. Presumably intentional — confirm.
                text_part = (
                    event_data.get("data", {})
                    .get("build_data", "")
                    .get("data", {})
                    .get("results", {})
                    .get("message", {})
                    .get("text", "")
                )
                result += text_part
        function_output = {
            "type": "conversation.item.create",
            "item": {
                "type": "function_call_output",
                "call_id": function_call.get("call_id"),
                "output": str(result),
            },
        }
        await openai_ws.send(json.dumps(function_output))
        # Ask OpenAI to generate a response that covers the tool result.
        await openai_ws.send(json.dumps({"type": "response.create"}))
    except json.JSONDecodeError as e:
        trace = traceback.format_exc()
        logger.error(f"JSON decode error: {e!s}\ntrace: {trace}")
        function_output = {
            "type": "conversation.item.create",
            "item": {
                "type": "function_call_output",
                "call_id": function_call.get("call_id"),
                "output": f"Error parsing arguments: {e!s}",
            },
        }
        await openai_ws.send(json.dumps(function_output))
    except ValueError as e:
        trace = traceback.format_exc()
        logger.error(f"Value error: {e!s}\ntrace: {trace}")
        function_output = {
            "type": "conversation.item.create",
            "item": {
                "type": "function_call_output",
                "call_id": function_call.get("call_id"),
                "output": f"Error with input values: {e!s}",
            },
        }
        await openai_ws.send(json.dumps(function_output))
    except (ConnectionError, websockets.exceptions.WebSocketException) as e:
        trace = traceback.format_exc()
        logger.error(f"Connection error: {e!s}\ntrace: {trace}")
        function_output = {
            "type": "conversation.item.create",
            "item": {
                "type": "function_call_output",
                "call_id": function_call.get("call_id"),
                "output": f"Connection error: {e!s}",
            },
        }
        await openai_ws.send(json.dumps(function_output))
    except (KeyError, AttributeError, TypeError) as e:
        logger.error(f"Error executing flow: {e}")
        logger.error(traceback.format_exc())
        function_output = {
            "type": "conversation.item.create",
            "item": {
                "type": "function_call_output",
                "call_id": function_call.get("call_id"),
                "output": f"Error executing flow: {e}",
            },
        }
        await openai_ws.send(json.dumps(function_output))
# --- Config Classes and Caches ---
def common_response_create(session_id, original: dict | None = None) -> dict:
    """Build a ``response.create`` message for the given voice session.

    When *original* is supplied it is mutated in place (its ``type`` is
    overwritten); otherwise a fresh dict is created. If the session uses
    ElevenLabs for TTS, the response modalities are forced to text-only so
    OpenAI does not also synthesize audio.
    """
    payload = original if original is not None else {}
    payload["type"] = "response.create"
    if get_voice_config(session_id).use_elevenlabs:
        payload.setdefault("response", {})["modalities"] = ["text"]
    return payload
class VoiceConfig:
@ -261,6 +168,7 @@ class VoiceConfig:
self.elevenlabs_client = None
self.elevenlabs_key = None
self.barge_in_enabled = False
self.progress_enabled = True
self.default_openai_realtime_session = {
"modalities": ["text", "audio"],
@ -318,9 +226,6 @@ class ElevenLabsClientManager:
return cls._instance
voice_config_cache: dict[str, VoiceConfig] = {}
def get_voice_config(session_id: str) -> VoiceConfig:
if session_id is None:
msg = "session_id cannot be None"
@ -369,9 +274,6 @@ class TTSConfig:
return self.openai_voice
tts_config_cache: dict[str, TTSConfig] = {}
def get_tts_config(session_id: str, openai_key: str) -> TTSConfig:
if session_id is None:
msg = "session_id cannot be None"
@ -455,6 +357,132 @@ async def process_message_queue(queue_key, session):
logger.error(traceback.format_exc())
def _function_call_output(call_id, output: str) -> dict:
    """Wrap *output* in a ``function_call_output`` conversation item for OpenAI."""
    return {
        "type": "conversation.item.create",
        "item": {
            "type": "function_call_output",
            "call_id": call_id,
            "output": output,
        },
    }


async def handle_function_call(
    function_call: dict,
    function_call_args: str,
    flow_id: str,
    background_tasks: BackgroundTasks,
    current_user: CurrentActiveUser,
    conversation_id: str,
    session_id: str,
    voice_config: VoiceConfig,
    client_safe_send_json,
    openai_send,
):
    """Handle function calls from the OpenAI API.

    Optionally nudges the model to verbally acknowledge the tool call, runs
    the Langflow flow named by *flow_id*, streams build progress to the
    client, and reports the accumulated text result (or an error) back to
    OpenAI as a ``function_call_output`` item.

    Args:
        function_call: The function-call event from OpenAI (read for ``call_id``).
        function_call_args: JSON-encoded arguments; ``input`` carries the request text.
        flow_id: UUID string of the flow to build and run.
        background_tasks: FastAPI background tasks forwarded to the flow build.
        current_user: Authenticated user executing the flow.
        conversation_id: Chat session id used for the flow input.
        session_id: Voice-mode session id (selects the per-session voice config).
        voice_config: Per-session voice settings (progress toggle, ElevenLabs flag).
        client_safe_send_json: Synchronous callable that queues JSON to the client.
        openai_send: Async callable that queues events to the OpenAI realtime socket.
    """
    try:
        # Let the model tell the user a tool is being invoked while the flow runs.
        if voice_config.progress_enabled:
            await openai_send(
                {
                    "type": "conversation.item.create",
                    "item": {
                        "type": "message",
                        "role": "system",
                        "content": [
                            {
                                "type": "input_text",
                                # Fixed: the original concatenation was missing the
                                # space between "requested." and "Keep", producing
                                # "requested.Keep it very short." in the prompt.
                                "text": "Tell the user you are now looking into or solving "
                                "a request that will be explained later. Do not repeat "
                                "the prompt exactly, summarize what's being requested. "
                                "Keep it very short."
                                f"\n\nThe request: {function_call_args}",
                            }
                        ],
                    },
                }
            )
            # Trigger an immediate response so the acknowledgement is delivered.
            await openai_send(common_response_create(session_id))
        args = json.loads(function_call_args) if function_call_args else {}
        input_request = InputValueRequest(
            input_value=args.get("input"), components=[], type="chat", session=conversation_id
        )
        response = await build_flow_and_stream(
            flow_id=UUID(flow_id),
            inputs=input_request,
            background_tasks=background_tasks,
            current_user=current_user,
        )
        result = ""
        async for line in response.body_iterator:
            if not line:
                continue
            event_data = json.loads(line)
            # Synchronous queue put — deliberately not awaited.
            client_safe_send_json({"type": "flow.build.progress", "data": event_data})
            if event_data.get("event") == "end_vertex":
                # NOTE(review): "build_data" defaults to "" — a missing key makes the
                # next .get raise AttributeError, absorbed by the handler below.
                text_part = (
                    event_data.get("data", {})
                    .get("build_data", "")
                    .get("data", {})
                    .get("results", {})
                    .get("message", {})
                    .get("text", "")
                )
                result += text_part
        await openai_send(_function_call_output(function_call.get("call_id"), str(result)))
        # Ask OpenAI for a response that covers the tool result.
        await openai_send(common_response_create(session_id))
    except json.JSONDecodeError as e:
        trace = traceback.format_exc()
        logger.error(f"JSON decode error: {e!s}\ntrace: {trace}")
        await openai_send(
            _function_call_output(function_call.get("call_id"), f"Error parsing arguments: {e!s}")
        )
    except ValueError as e:
        trace = traceback.format_exc()
        logger.error(f"Value error: {e!s}\ntrace: {trace}")
        await openai_send(
            _function_call_output(function_call.get("call_id"), f"Error with input values: {e!s}")
        )
    except (ConnectionError, websockets.exceptions.WebSocketException) as e:
        trace = traceback.format_exc()
        logger.error(f"Connection error: {e!s}\ntrace: {trace}")
        await openai_send(
            _function_call_output(function_call.get("call_id"), f"Connection error: {e!s}")
        )
    except (KeyError, AttributeError, TypeError) as e:
        logger.error(f"Error executing flow: {e}")
        logger.error(traceback.format_exc())
        await openai_send(
            _function_call_output(function_call.get("call_id"), f"Error executing flow: {e}")
        )
voice_config_cache: dict[str, VoiceConfig] = {}
tts_config_cache: dict[str, TTSConfig] = {}
# --- Global Queues and Message Processing ---
message_queues: dict[str, asyncio.Queue] = defaultdict(asyncio.Queue)
@ -525,11 +553,17 @@ def create_event_logger(session_id: str):
state = {"last_event_type": None, "event_count": 0}
def log_event(event: dict, direction: str) -> None:
event_type = event["type"]
event_type = event.get("type", "None")
event_id = event.get("event_id", "None")
if event_type != state["last_event_type"]:
logger.debug(f"Event (session - {session_id}): {direction} {event_type}")
logger.debug(f"Event (id - {event_id}) (session - {session_id}): {direction} {event_type}")
state["last_event_type"] = event_type
state["event_count"] = 0
if event_type == "response.created":
response_id = event.get("response", {}).get("id")
logger.debug(f"response_id: {response_id}")
if "error" in event_type:
logger.debug(f"Error {event}")
current_count = 0 if state["event_count"] is None else state["event_count"]
state["event_count"] = current_count + 1
@ -565,8 +599,60 @@ async def flow_as_tool_websocket(
"""WebSocket endpoint registering the flow as a tool for real-time interaction."""
try:
await client_websocket.accept()
openai_send_q: asyncio.Queue[str] = asyncio.Queue()
client_send_q: asyncio.Queue[str] = asyncio.Queue()
async def openai_writer():
    """Drain openai_send_q, relaying each queued message to the OpenAI socket.

    A ``None`` item is the shutdown sentinel (enqueued by ``close``).
    """
    while (outbound := await openai_send_q.get()) is not None:
        await openai_ws.send(outbound)
async def client_writer():
    """Drain client_send_q, relaying each queued message to the browser client.

    A ``None`` item is the shutdown sentinel (enqueued by ``close``).
    """
    while (outbound := await client_send_q.get()) is not None:
        await client_websocket.send_text(outbound)
log_event = create_event_logger(session_id)
response_q: asyncio.Queue = asyncio.Queue()
await response_q.put(None)
async def openai_send(payload):
    """Log *payload* and queue it (JSON-encoded) for the OpenAI writer task.

    ``response.create`` payloads first await a token from ``response_q``: a
    token is only put back when a ``response.done`` arrives (see
    ``client_send``), so at most one response is in flight at a time.
    """
    log_event(payload, DIRECTION_TO_OPENAI)
    # NOTE(review): payload["type"] would raise KeyError on an untyped payload;
    # callers appear to always set "type" — confirm.
    logger.trace(f"Sending text {DIRECTION_TO_OPENAI}: {payload['type']}")
    if payload.get("type") == "response.create":
        await response_q.get()
    openai_send_q.put_nowait(json.dumps(payload))
    logger.trace("JSON sent.")
def client_send(payload):
    """Log *payload* and queue it (JSON-encoded) for the client writer task.

    On ``response.done`` a token is returned to ``response_q`` so the next
    ``response.create`` awaiting in ``openai_send`` may proceed.
    """
    log_event(payload, DIRECTION_TO_CLIENT)
    logger.trace(f"Sending JSON {DIRECTION_TO_CLIENT}: {payload.get('type', 'None')}")
    if payload.get("type") == "response.done":
        try:
            response_q.put_nowait(None)
        except asyncio.QueueFull:
            # An unbounded asyncio.Queue never raises QueueFull; kept as a guard.
            logger.warning("Queue is full, skipping put_nowait")
    client_send_q.put_nowait(json.dumps(payload))
    logger.trace("JSON sent.")
async def close():
    """Shut down the writer tasks and close both websockets.

    The ``None`` sentinels are queued first so each writer drains any pending
    messages and exits cleanly before the sockets are closed.
    """
    openai_send_q.put_nowait(None)
    client_send_q.put_nowait(None)
    await openai_writer_task
    await client_writer_task
    await client_websocket.close()
    await openai_ws.close()
vad_task = None
voice_config = get_voice_config(session_id)
current_user, openai_key = await authenticate_and_get_openai_key(client_websocket, session)
current_user: User = await get_current_user_for_websocket(client_websocket, session)
current_user, openai_key = await authenticate_and_get_openai_key(session, current_user, client_send)
if current_user is None or openai_key is None:
return
try:
@ -582,7 +668,7 @@ async def flow_as_tool_websocket(
},
}
except Exception as e: # noqa: BLE001
await client_websocket.send_json({"error": f"Failed to load flow: {e!s}"})
client_send({"error": f"Failed to load flow: {e!s}"})
logger.error(f"Failed to load flow: {e}")
return
@ -599,8 +685,12 @@ async def flow_as_tool_websocket(
async with websockets.connect(url, extra_headers=headers) as openai_ws:
openai_realtime_session = init_session_dict()
openai_writer_task = asyncio.create_task(openai_writer())
client_writer_task = asyncio.create_task(client_writer())
session_update = {"type": "session.update", "session": openai_realtime_session}
await openai_ws.send(json.dumps(session_update))
await openai_send(session_update)
# Setup for VAD processing.
vad_queue: asyncio.Queue = asyncio.Queue()
@ -626,7 +716,7 @@ async def flow_as_tool_websocket(
has_speech = True
logger.trace("!", end="")
if bot_speaking_flag[0]:
await openai_ws.send(json.dumps({"type": "response.cancel"}))
await openai_send({"type": "response.cancel"})
bot_speaking_flag[0] = False
except Exception as e: # noqa: BLE001
logger.error(f"[ERROR] VAD processing failed (ValueError): {e}")
@ -639,33 +729,8 @@ async def flow_as_tool_websocket(
if time_since_speech >= 1.0:
logger.trace("_", end="")
shared_state = {"last_event_type": None, "event_count": 0}
def log_event(event, _direction: str) -> None:
    """Debug-log *event*, collapsing consecutive events of the same type.

    Only logs when the event type differs from the previous one; a per-type
    counter is kept in the ``shared_state`` closure dict.
    """
    event_type = event["type"]
    # Ensure shared_state has necessary keys initialized
    if "last_event_type" not in shared_state:
        shared_state["last_event_type"] = None
    if "event_count" not in shared_state:
        shared_state["event_count"] = 0
    if event_type != shared_state["last_event_type"]:
        logger.debug(f"Event (session - {session_id}): {_direction} {event_type}")
        shared_state["last_event_type"] = event_type
        shared_state["event_count"] = 0
    # Explicitly convert to integer if needed
    current_count = int(shared_state["event_count"]) if shared_state["event_count"] is not None else 0
    shared_state["event_count"] = current_count + 1
def send_event(websocket, event, loop, direction) -> None:
    """Send *event* to *websocket* from a worker thread, blocking until done.

    ``run_coroutine_threadsafe(...).result()`` blocks the calling thread until
    the coroutine completes on the main event *loop*.
    """
    asyncio.run_coroutine_threadsafe(
        websocket.send_json(event),
        loop,
    ).result()
    log_event(event, direction)
def client_send_event_from_thread(event, loop) -> None:
    """Schedule ``client_send(event)`` on the main event loop from a worker thread.

    ``loop.call_soon_threadsafe`` is the thread-safe way to hand work to an
    asyncio loop. The ``TimerHandle`` it returns is deliberately discarded so
    the function actually returns ``None``, matching its annotation (the
    original ``return loop.call_soon_threadsafe(...)`` returned a Handle
    despite the ``-> None`` annotation).
    """
    loop.call_soon_threadsafe(client_send, event)
def pass_through(from_dict, to_dict, keys):
for key in keys:
@ -709,11 +774,18 @@ async def flow_as_tool_websocket(
)
return new_session
# --- Spawn a text delta queue and task for TTS ---
text_delta_queue: asyncio.Queue = asyncio.Queue()
text_delta_task: asyncio.Task | None = None # Will hold our background task.
class Response:
def __init__(self, response_id: str, use_elevenlabs: bool | None = None):
if use_elevenlabs is None:
use_elevenlabs = False
self.response_id = response_id
if use_elevenlabs:
self.text_delta_queue: asyncio.Queue = asyncio.Queue()
self.text_delta_task = asyncio.create_task(process_text_deltas(self))
async def process_text_deltas(async_q: asyncio.Queue):
responses = {}
async def process_text_deltas(rsp: Response):
"""Transfer text deltas from the async queue to a synchronous queue.
then run the ElevenLabs TTS call (which expects a sync generator) in a separate thread.
@ -722,7 +794,7 @@ async def flow_as_tool_websocket(
async def transfer_text_deltas():
while True:
item = await async_q.get()
item = await rsp.text_delta_queue.get()
sync_q.put(item)
if item is None:
break
@ -757,11 +829,15 @@ async def flow_as_tool_websocket(
for chunk in audio_stream:
base64_audio = base64.b64encode(chunk).decode("utf-8")
# Schedule sending the audio chunk in the main event loop.
event = {"type": "response.audio.delta", "delta": base64_audio}
send_event(client_websocket, event, main_loop, "")
event = {
"type": "response.audio.delta",
"delta": base64_audio,
"response_id": rsp.response_id,
}
client_send_event_from_thread(event, main_loop)
event = {"type": "response.done"}
send_event(client_websocket, event, main_loop, "")
event = {"type": "response.done", "response": {"id": rsp.response_id}}
client_send_event_from_thread(event, main_loop)
except Exception as e: # noqa: BLE001
logger.error(f"Error in TTS processing (ValueError): {e}")
@ -785,20 +861,18 @@ async def flow_as_tool_websocket(
# Ensure we're adding to an integer
num_audio_samples += len(base64_data)
event = {"type": "input_audio_buffer.append", "audio": base64_data}
await openai_ws.send(json.dumps(event))
log_event(event, "")
await openai_send(event)
if voice_config.barge_in_enabled:
await vad_queue.put(base64_data)
elif msg.get("type") == "response.create":
if voice_config.use_elevenlabs:
response = msg.setdefault("response", {})
response["modalities"] = ["text"]
await openai_ws.send(json.dumps(msg))
await openai_send(common_response_create(session_id, msg))
elif msg.get("type") == "input_audio_buffer.commit":
if num_audio_samples > AUDIO_SAMPLE_THRESHOLD:
await openai_ws.send(message_text)
log_event(msg, "")
await openai_send(msg)
num_audio_samples = 0
elif msg.get("type") == "langflow.voice_mode.config":
logger.info(f"langflow.voice_mode.config {msg}")
voice_config.progress_enabled = msg.get("progress_enabled", True)
elif msg.get("type") == "langflow.elevenlabs.config":
logger.info(f"langflow.elevenlabs.config {msg}")
voice_config.use_elevenlabs = msg["enabled"]
@ -808,21 +882,18 @@ async def flow_as_tool_websocket(
modalities = ["text"] if voice_config.use_elevenlabs else ["audio", "text"]
openai_realtime_session["modalities"] = modalities
session_update = {"type": "session.update", "session": openai_realtime_session}
await openai_ws.send(json.dumps(session_update))
log_event(session_update, "")
await openai_send(session_update)
elif msg.get("type") == "session.update":
openai_realtime_session = update_global_session(msg["session"])
session_update = {"type": "session.update", "session": openai_realtime_session}
await openai_ws.send(json.dumps(session_update))
log_event(session_update, "")
await openai_send(session_update)
else:
await openai_ws.send(message_text)
log_event(msg, "")
await openai_send(msg)
except (WebSocketDisconnect, websockets.ConnectionClosedOK, websockets.ConnectionClosedError):
pass
async def forward_to_client() -> None:
nonlocal bot_speaking_flag, text_delta_queue, text_delta_task
nonlocal bot_speaking_flag, responses
function_call = None
function_call_args = ""
conversation_id = str(uuid4())
@ -838,22 +909,27 @@ async def flow_as_tool_websocket(
do_forward = True
do_forward = do_forward and not (event_type == "response.done" and voice_config.use_elevenlabs)
do_forward = do_forward and event_type.find("flow.") != 0
if do_forward:
await client_websocket.send_text(data)
if event_type == "response.text.delta":
response_id = None
if do_forward:
client_send(event)
if event_type == "response.created":
response_id = event["response"]["id"]
responses[response_id] = Response(response_id, voice_config.use_elevenlabs)
elif event_type == "response.text.delta":
if voice_config.use_elevenlabs:
response_id = event["response_id"]
delta = event.get("delta", "")
await text_delta_queue.put(delta)
if text_delta_task is None:
# if text_delta_task is None or text_delta_task.done():
text_delta_task = asyncio.create_task(process_text_deltas(text_delta_queue))
rsp: Response = responses[response_id]
await rsp.text_delta_queue.put(delta)
elif event_type == "response.text.done":
if voice_config.use_elevenlabs:
await text_delta_queue.put(None)
if text_delta_task and not text_delta_task.done():
await text_delta_task
text_delta_task = None
response_id = event["response_id"]
rsp = responses[response_id]
await rsp.text_delta_queue.put(None)
if rsp.text_delta_task and not rsp.text_delta_task.done():
await rsp.text_delta_task
responses.pop(response_id)
try:
message_text = event.get("text", "")
@ -892,14 +968,16 @@ async def flow_as_tool_websocket(
# Create and store reference to the task
function_call_task = asyncio.create_task(
handle_function_call(
client_websocket,
openai_ws,
function_call,
function_call_args,
flow_id,
background_tasks,
current_user,
conversation_id,
session_id,
voice_config,
client_send,
openai_send,
)
)
# Store the task reference to prevent garbage collection
@ -925,34 +1003,30 @@ async def flow_as_tool_websocket(
logger.error(traceback.format_exc())
elif event_type == "error":
pass
else:
await client_websocket.send_text(data)
log_event(event, "")
except (WebSocketDisconnect, websockets.ConnectionClosedOK, websockets.ConnectionClosedError):
pass
# Fix for storing references to asyncio tasks
vad_task = None
if voice_config.barge_in_enabled:
# Store the task reference to prevent it from being garbage collected
vad_task = asyncio.create_task(process_vad_audio())
await asyncio.gather(
forward_to_openai(),
forward_to_client(),
)
try:
# Create tasks and gather them for concurrent execution
task1 = asyncio.create_task(forward_to_openai())
task2 = asyncio.create_task(forward_to_client())
await asyncio.gather(task1, task2)
except Exception as exc: # noqa: BLE001
# handle any exceptions from any task
logger.error("WS loop failed:", exc_info=exc)
logger.error(traceback.format_exc())
finally:
# shared cleanup for writers & sockets
await close()
except Exception as e: # noqa: BLE001
logger.error(f"Value error: {e}")
logger.error(f"Unexpected error: {e}")
logger.error(traceback.format_exc())
finally:
# Ensure that the client websocket is closed.
try:
await client_websocket.close()
except Exception as e: # noqa: BLE001
logger.debug(f"{e} ")
logger.info("Client websocket cleanup complete.")
# Make sure to clean up the task
if vad_task and not vad_task.done():
vad_task.cancel()
@ -986,11 +1060,49 @@ async def flow_tts_websocket(
"""WebSocket endpoint for direct flow text-to-speech interaction."""
try:
await client_websocket.accept()
log_event = create_event_logger(session_id)
current_user, openai_key = await authenticate_and_get_openai_key(client_websocket, session)
if current_user is None or openai_key is None:
return
openai_send_q: asyncio.Queue[str] = asyncio.Queue()
client_send_q: asyncio.Queue[str] = asyncio.Queue()
async def openai_writer():
    """Drain openai_send_q, relaying each message to the OpenAI socket.

    A ``None`` item is the shutdown sentinel (enqueued by ``close``).
    """
    while True:
        msg = await openai_send_q.get()
        if msg is None:
            break
        await openai_ws.send(msg)
async def client_writer():
    """Drain client_send_q, relaying each message to the browser client.

    A ``None`` item is the shutdown sentinel (enqueued by ``close``).
    """
    while True:
        msg = await client_send_q.get()
        if msg is None:
            break
        await client_websocket.send_text(msg)
log_event = create_event_logger(session_id)
def openai_send(payload):
    """Log *payload* and queue it (JSON-encoded) for the OpenAI writer task."""
    log_event(payload, DIRECTION_TO_OPENAI)
    # NOTE(review): payload["type"] would raise KeyError on an untyped payload;
    # callers appear to always set "type" — confirm.
    logger.trace(f"Sending text {DIRECTION_TO_OPENAI}: {payload['type']}")
    openai_send_q.put_nowait(json.dumps(payload))
    logger.trace("JSON sent.")
def client_send(payload):
    """Log *payload* and queue it (JSON-encoded) for the client writer task."""
    log_event(payload, DIRECTION_TO_CLIENT)
    # NOTE(review): payload["type"] would raise KeyError on an untyped payload;
    # callers appear to always set "type" — confirm.
    logger.trace(f"Sending JSON {DIRECTION_TO_CLIENT}: {payload['type']}")
    client_send_q.put_nowait(json.dumps(payload))
    logger.trace("JSON sent.")
async def close():
    """Shut down the writer tasks and close both websockets.

    The ``None`` sentinels are queued first so each writer drains any pending
    messages and exits cleanly before the sockets are closed.
    """
    openai_send_q.put_nowait(None)
    client_send_q.put_nowait(None)
    await openai_writer_task
    await client_writer_task
    await client_websocket.close()
    await openai_ws.close()
log_event = create_event_logger(session_id)
current_user: User = await get_current_user_for_websocket(client_websocket, session)
current_user, openai_key = await authenticate_and_get_openai_key(session, current_user, client_send)
url = "wss://api.openai.com/v1/realtime?intent=transcription"
headers = {
"Authorization": f"Bearer {openai_key}",
@ -999,23 +1111,26 @@ async def flow_tts_websocket(
tts_config = get_tts_config(session_id, openai_key)
async with websockets.connect(url, extra_headers=headers) as openai_ws:
openai_writer_task = asyncio.create_task(openai_writer())
client_writer_task = asyncio.create_task(client_writer())
tts_realtime_session = tts_config.get_session_dict()
await openai_ws.send(json.dumps(tts_realtime_session))
openai_send(tts_realtime_session)
async def forward_to_openai() -> None:
try:
while True:
message_text = await client_websocket.receive_text()
event = json.loads(message_text)
log_event(event, "Client → OpenAI")
if event.get("type") == "input_audio_buffer.append":
base64_data = event.get("audio", "")
if not base64_data:
continue
out_event = {"type": "input_audio_buffer.append", "audio": base64_data}
await openai_ws.send(json.dumps(out_event))
openai_send(out_event)
elif event.get("type") == "input_audio_buffer.commit":
await openai_ws.send(message_text)
openai_send(event)
elif event.get("type") == "langflow.elevenlabs.config":
logger.info(f"langflow.elevenlabs.config {event}")
tts_config.use_elevenlabs = event["enabled"]
@ -1033,8 +1148,7 @@ async def flow_tts_websocket(
while True:
data = await openai_ws.recv()
event = json.loads(data)
log_event(event, "OpenAI → Client")
await client_websocket.send_text(data)
client_send(event)
if event.get("type") == "conversation.item.input_audio_transcription.completed":
transcript = event.get("transcript")
if transcript is not None and transcript != "":
@ -1052,9 +1166,7 @@ async def flow_tts_websocket(
if not line:
continue
event_data = json.loads(line)
await client_websocket.send_json(
{"type": "flow.build.progress", "data": event_data}
)
client_send({"type": "flow.build.progress", "data": event_data})
if event_data.get("event") == "end_vertex":
text = (
event_data.get("data", {})
@ -1084,7 +1196,7 @@ async def flow_tts_websocket(
for chunk in audio_stream:
base64_audio = base64.b64encode(chunk).decode("utf-8")
audio_event = {"type": "response.audio.delta", "delta": base64_audio}
await client_websocket.send_json(audio_event)
client_send(audio_event)
else:
oai_client = tts_config.get_openai_client()
voice = tts_config.get_openai_voice()
@ -1098,26 +1210,24 @@ async def flow_tts_websocket(
base64_audio = base64.b64encode(response.content).decode("utf-8")
audio_event = {"type": "response.audio.delta", "delta": base64_audio}
await client_websocket.send_json(audio_event)
client_send(audio_event)
except Exception as e: # noqa: BLE001
logger.error(f"Error in WebSocket communication: {e}")
forward_to_openai_task = asyncio.create_task(forward_to_openai())
forward_to_client_task = asyncio.create_task(forward_to_client())
try:
await asyncio.gather(
forward_to_openai_task,
forward_to_client_task,
)
except Exception as e: # noqa: BLE001
logger.error(f"Error in WebSocket communication: {e}")
logger.error(traceback.format_exc())
# Create tasks and gather them for concurrent execution
task1 = asyncio.create_task(forward_to_openai())
task2 = asyncio.create_task(forward_to_client())
await asyncio.gather(task1, task2)
except Exception as exc: # noqa: BLE001
# handle any exceptions from any task
logger.error("WS loop failed:", exc_info=exc)
finally:
forward_to_openai_task.cancel()
# shared cleanup for writers & sockets
await close()
except Exception as e: # noqa: BLE001
logger.error(f"WebSocket error: {e}")
logger.error(f"Unexpected error: {e}")
logger.error(traceback.format_exc())
await client_websocket.close()
def extract_transcript(json_data):

View file

@ -7,7 +7,7 @@ from typing import TYPE_CHECKING, Annotated
from uuid import UUID
from cryptography.fernet import Fernet
from fastapi import Depends, HTTPException, Security, status
from fastapi import Depends, HTTPException, Security, WebSocketException, status
from fastapi.security import APIKeyHeader, APIKeyQuery, OAuth2PasswordBearer
from jose import JWTError, jwt
from loguru import logger
@ -86,6 +86,55 @@ async def api_key_security(
raise ValueError(msg)
async def ws_api_key_security(
    api_key: str | None,
) -> UserRead:
    """Authenticate a websocket client by API key, returning a ``UserRead``.

    With ``AUTO_LOGIN`` enabled, a missing key falls back to the configured
    superuser (deprecated behavior — see the warning below); otherwise a valid
    API key is mandatory. Raises ``WebSocketException`` with a policy-violation
    close code on bad credentials, or an internal-error code on server
    misconfiguration.
    """
    settings = get_settings_service()
    async with get_db_service().with_session() as db:
        if settings.auth_settings.AUTO_LOGIN:
            if not settings.auth_settings.SUPERUSER:
                # internal server misconfiguration
                raise WebSocketException(
                    code=status.WS_1011_INTERNAL_ERROR,
                    reason="Missing first superuser credentials",
                )
            warnings.warn(
                ("In v1.5, AUTO_LOGIN will *require* a valid API key or JWT. Please update your clients accordingly."),
                DeprecationWarning,
                stacklevel=2,
            )
            if api_key:
                result = await check_key(db, api_key)
            else:
                result = await get_user_by_username(db, settings.auth_settings.SUPERUSER)
        # normal path: must provide an API key
        else:
            if not api_key:
                raise WebSocketException(
                    code=status.WS_1008_POLICY_VIOLATION,
                    reason="An API key must be passed as query or header",
                )
            result = await check_key(db, api_key)
        # key was invalid or missing
        if not result:
            raise WebSocketException(
                code=status.WS_1008_POLICY_VIOLATION,
                reason="Invalid or missing API key",
            )
        # convert SQL-model User → pydantic UserRead
        if isinstance(result, User):
            return UserRead.model_validate(result, from_attributes=True)
        # fallback: something unexpected happened
        # NOTE(review): if check_key can return a UserRead (not a User), this
        # raises an internal error instead of returning it — confirm intended.
        raise WebSocketException(
            code=status.WS_1011_INTERNAL_ERROR,
            reason="Authentication subsystem error",
        )
async def get_current_user(
token: Annotated[str, Security(oauth2_login)],
query_param: Annotated[str, Security(api_key_query)],
@ -167,16 +216,28 @@ async def get_current_user_by_jwt(
async def get_current_user_for_websocket(
websocket: WebSocket,
db: Annotated[AsyncSession, Depends(get_session)],
query_param: Annotated[str, Security(api_key_query)],
) -> User | None:
token = websocket.query_params.get("token")
api_key = websocket.query_params.get("x-api-key")
db: AsyncSession,
) -> User | UserRead:
token = websocket.cookies.get("access_token_lf") or websocket.query_params.get("token")
if token:
return await get_current_user_by_jwt(token, db)
user = await get_current_user_by_jwt(token, db)
if user:
return user
api_key = (
websocket.query_params.get("x-api-key")
or websocket.query_params.get("api_key")
or websocket.headers.get("x-api-key")
or websocket.headers.get("api_key")
)
if api_key:
return await api_key_security(api_key, query_param)
return None
user_read = await ws_api_key_security(api_key)
if user_read:
return user_read
raise WebSocketException(
code=status.WS_1008_POLICY_VIOLATION, reason="Missing or invalid credentials (cookie, token or API key)."
)
async def get_current_active_user(current_user: Annotated[User, Depends(get_current_user)]):