diff --git a/src/backend/base/langflow/api/v1/voice_mode.py b/src/backend/base/langflow/api/v1/voice_mode.py index f5d4ea38b..553a4abf7 100644 --- a/src/backend/base/langflow/api/v1/voice_mode.py +++ b/src/backend/base/langflow/api/v1/voice_mode.py @@ -18,7 +18,7 @@ import webrtcvad import websockets from cryptography.fernet import InvalidToken from elevenlabs import ElevenLabs -from fastapi import APIRouter, BackgroundTasks, Security +from fastapi import APIRouter, BackgroundTasks from openai import OpenAI from sqlalchemy import select from starlette.websockets import WebSocket, WebSocketDisconnect @@ -29,8 +29,8 @@ from langflow.api.v1.schemas import InputValueRequest from langflow.logging import logger from langflow.memory import aadd_messagetables from langflow.schema.properties import Properties -from langflow.services.auth.utils import api_key_header, api_key_query, api_key_security, get_current_user_by_jwt -from langflow.services.database.models import MessageTable +from langflow.services.auth.utils import get_current_user_for_websocket +from langflow.services.database.models import MessageTable, User from langflow.services.database.models.flow.model import Flow from langflow.services.deps import get_variable_service, session_scope from langflow.utils.voice_utils import ( @@ -75,38 +75,34 @@ Your instructions will be divided into three mutually exclusive sections: "Perma [ADDITIONAL] The following instructions are to be considered only "Additional" """ +DIRECTION_TO_OPENAI = "Client → OpenAI" +DIRECTION_TO_CLIENT = "OpenAI → Client" # --- Helper Functions --- -async def authenticate_and_get_openai_key(client_websocket: WebSocket, session: DbSession): +async def authenticate_and_get_openai_key(session: DbSession, user: User, safe_send_json): """Authenticate the user using a token or API key and retrieve the OpenAI API key. Returns a tuple: (current_user, openai_key). If authentication fails, sends an error message to the client and returns (None, None). 
""" - token = client_websocket.cookies.get("access_token_lf") - current_user = None - if token: - current_user = await get_current_user_by_jwt(token, session) - if current_user is None: - current_user = await api_key_security(Security(api_key_query), Security(api_key_header)) - if current_user is None: - await client_websocket.send_json( - { - "type": "error", - "code": "langflow_auth", - "message": "You must pass a valid Langflow token or cookie.", - } - ) - return None, None + if user is None: + await safe_send_json( + { + "type": "error", + "code": "langflow_auth", + "message": "You must pass a valid Langflow token or cookie.", + } + ) + return None, None variable_service = get_variable_service() try: openai_key_value = await variable_service.get_variable( - user_id=current_user.id, name="OPENAI_API_KEY", field="openai_api_key", session=session + user_id=user.id, name="OPENAI_API_KEY", field="openai_api_key", session=session ) openai_key = openai_key_value if openai_key_value is not None else os.getenv("OPENAI_API_KEY", "") if not openai_key or openai_key == "dummy": - await client_websocket.send_json( + await safe_send_json( { "type": "error", "code": "api_key_missing", @@ -119,7 +115,7 @@ async def authenticate_and_get_openai_key(client_websocket: WebSocket, session: logger.error(f"Error with API key: {e}") logger.error(traceback.format_exc()) return None, None - return current_user, openai_key + return user, openai_key # --- Synchronous Text Chunker --- @@ -151,105 +147,16 @@ def sync_text_chunker(sync_queue_obj: queue.Queue, timeout: float = 0.3): yield buffer + " " -async def handle_function_call( - websocket: WebSocket, - openai_ws: websockets.WebSocketClientProtocol, - function_call: dict, - function_call_args: str, - flow_id: str, - background_tasks: BackgroundTasks, - current_user: CurrentActiveUser, - conversation_id: str, -): - """Handle function calls from the OpenAI API.""" - try: - args = json.loads(function_call_args) if function_call_args else {} 
- input_request = InputValueRequest( - input_value=args.get("input"), components=[], type="chat", session=conversation_id - ) - response = await build_flow_and_stream( - flow_id=UUID(flow_id), - inputs=input_request, - background_tasks=background_tasks, - current_user=current_user, - ) - result = "" - async for line in response.body_iterator: - if not line: - continue - event_data = json.loads(line) - await websocket.send_json({"type": "flow.build.progress", "data": event_data}) - if event_data.get("event") == "end_vertex": - text_part = ( - event_data.get("data", {}) - .get("build_data", "") - .get("data", {}) - .get("results", {}) - .get("message", {}) - .get("text", "") - ) - result += text_part - function_output = { - "type": "conversation.item.create", - "item": { - "type": "function_call_output", - "call_id": function_call.get("call_id"), - "output": str(result), - }, - } - await openai_ws.send(json.dumps(function_output)) - await openai_ws.send(json.dumps({"type": "response.create"})) - except json.JSONDecodeError as e: - trace = traceback.format_exc() - logger.error(f"JSON decode error: {e!s}\ntrace: {trace}") - function_output = { - "type": "conversation.item.create", - "item": { - "type": "function_call_output", - "call_id": function_call.get("call_id"), - "output": f"Error parsing arguments: {e!s}", - }, - } - await openai_ws.send(json.dumps(function_output)) - except ValueError as e: - trace = traceback.format_exc() - logger.error(f"Value error: {e!s}\ntrace: {trace}") - function_output = { - "type": "conversation.item.create", - "item": { - "type": "function_call_output", - "call_id": function_call.get("call_id"), - "output": f"Error with input values: {e!s}", - }, - } - await openai_ws.send(json.dumps(function_output)) - except (ConnectionError, websockets.exceptions.WebSocketException) as e: - trace = traceback.format_exc() - logger.error(f"Connection error: {e!s}\ntrace: {trace}") - function_output = { - "type": "conversation.item.create", - "item": 
{ - "type": "function_call_output", - "call_id": function_call.get("call_id"), - "output": f"Connection error: {e!s}", - }, - } - await openai_ws.send(json.dumps(function_output)) - except (KeyError, AttributeError, TypeError) as e: - logger.error(f"Error executing flow: {e}") - logger.error(traceback.format_exc()) - function_output = { - "type": "conversation.item.create", - "item": { - "type": "function_call_output", - "call_id": function_call.get("call_id"), - "output": f"Error executing flow: {e}", - }, - } - await openai_ws.send(json.dumps(function_output)) - - -# --- Config Classes and Caches --- +def common_response_create(session_id, original: dict | None = None) -> dict: + msg = {} + if original is not None: + msg = original + msg["type"] = "response.create" + voice_config = get_voice_config(session_id) + if voice_config.use_elevenlabs: + response = msg.setdefault("response", {}) + response["modalities"] = ["text"] + return msg class VoiceConfig: @@ -261,6 +168,7 @@ class VoiceConfig: self.elevenlabs_client = None self.elevenlabs_key = None self.barge_in_enabled = False + self.progress_enabled = True self.default_openai_realtime_session = { "modalities": ["text", "audio"], @@ -318,9 +226,6 @@ class ElevenLabsClientManager: return cls._instance -voice_config_cache: dict[str, VoiceConfig] = {} - - def get_voice_config(session_id: str) -> VoiceConfig: if session_id is None: msg = "session_id cannot be None" @@ -369,9 +274,6 @@ class TTSConfig: return self.openai_voice -tts_config_cache: dict[str, TTSConfig] = {} - - def get_tts_config(session_id: str, openai_key: str) -> TTSConfig: if session_id is None: msg = "session_id cannot be None" @@ -455,6 +357,132 @@ async def process_message_queue(queue_key, session): logger.error(traceback.format_exc()) +async def handle_function_call( + function_call: dict, + function_call_args: str, + flow_id: str, + background_tasks: BackgroundTasks, + current_user: CurrentActiveUser, + conversation_id: str, + session_id: str, + 
voice_config: VoiceConfig, + client_safe_send_json, + openai_send, +): + """Handle function calls from the OpenAI API.""" + try: + # trigger response that tool was called + if voice_config.progress_enabled: + await openai_send( + { + "type": "conversation.item.create", + "item": { + "type": "message", + "role": "system", + "content": [ + { + "type": "input_text", + "text": "Tell the user you are now looking into or solving " + "a request that will be explained later. Do not repeat " + "the prompt exactly, summarize what's being requested." + "Keep it very short." + f"\n\nThe request: {function_call_args}", + } + ], + }, + } + ) + await openai_send(common_response_create(session_id)) + args = json.loads(function_call_args) if function_call_args else {} + input_request = InputValueRequest( + input_value=args.get("input"), components=[], type="chat", session=conversation_id + ) + response = await build_flow_and_stream( + flow_id=UUID(flow_id), + inputs=input_request, + background_tasks=background_tasks, + current_user=current_user, + ) + result = "" + async for line in response.body_iterator: + if not line: + continue + event_data = json.loads(line) + client_safe_send_json({"type": "flow.build.progress", "data": event_data}) + if event_data.get("event") == "end_vertex": + text_part = ( + event_data.get("data", {}) + .get("build_data", "") + .get("data", {}) + .get("results", {}) + .get("message", {}) + .get("text", "") + ) + result += text_part + function_output = { + "type": "conversation.item.create", + "item": { + "type": "function_call_output", + "call_id": function_call.get("call_id"), + "output": str(result), + }, + } + await openai_send(function_output) + await openai_send(common_response_create(session_id)) + except json.JSONDecodeError as e: + trace = traceback.format_exc() + logger.error(f"JSON decode error: {e!s}\ntrace: {trace}") + function_output = { + "type": "conversation.item.create", + "item": { + "type": "function_call_output", + "call_id": 
function_call.get("call_id"), + "output": f"Error parsing arguments: {e!s}", + }, + } + await openai_send(function_output) + except ValueError as e: + trace = traceback.format_exc() + logger.error(f"Value error: {e!s}\ntrace: {trace}") + function_output = { + "type": "conversation.item.create", + "item": { + "type": "function_call_output", + "call_id": function_call.get("call_id"), + "output": f"Error with input values: {e!s}", + }, + } + await openai_send(function_output) + except (ConnectionError, websockets.exceptions.WebSocketException) as e: + trace = traceback.format_exc() + logger.error(f"Connection error: {e!s}\ntrace: {trace}") + function_output = { + "type": "conversation.item.create", + "item": { + "type": "function_call_output", + "call_id": function_call.get("call_id"), + "output": f"Connection error: {e!s}", + }, + } + await openai_send(function_output) + except (KeyError, AttributeError, TypeError) as e: + logger.error(f"Error executing flow: {e}") + logger.error(traceback.format_exc()) + function_output = { + "type": "conversation.item.create", + "item": { + "type": "function_call_output", + "call_id": function_call.get("call_id"), + "output": f"Error executing flow: {e}", + }, + } + await openai_send(function_output) + + +voice_config_cache: dict[str, VoiceConfig] = {} +tts_config_cache: dict[str, TTSConfig] = {} + + # --- Global Queues and Message Processing --- message_queues: dict[str, asyncio.Queue] = defaultdict(asyncio.Queue) @@ -525,11 +553,17 @@ def create_event_logger(session_id: str): state = {"last_event_type": None, "event_count": 0} def log_event(event: dict, direction: str) -> None: - event_type = event["type"] + event_type = event.get("type", "None") + event_id = event.get("event_id", "None") if event_type != state["last_event_type"]: - logger.debug(f"Event (session - {session_id}): {direction} {event_type}") + logger.debug(f"Event (id - {event_id}) (session - {session_id}): {direction} {event_type}") state["last_event_type"] = 
event_type state["event_count"] = 0 + if event_type == "response.created": + response_id = event.get("response", {}).get("id") + logger.debug(f"response_id: {response_id}") + if "error" in event_type: + logger.debug(f"Error {event}") current_count = 0 if state["event_count"] is None else state["event_count"] state["event_count"] = current_count + 1 @@ -565,8 +599,60 @@ async def flow_as_tool_websocket( """WebSocket endpoint registering the flow as a tool for real-time interaction.""" try: await client_websocket.accept() + + openai_send_q: asyncio.Queue[str] = asyncio.Queue() + client_send_q: asyncio.Queue[str] = asyncio.Queue() + + async def openai_writer(): + while True: + msg = await openai_send_q.get() + if msg is None: + break + await openai_ws.send(msg) + + async def client_writer(): + while True: + msg = await client_send_q.get() + if msg is None: + break + await client_websocket.send_text(msg) + + log_event = create_event_logger(session_id) + + response_q: asyncio.Queue = asyncio.Queue() + await response_q.put(None) + + async def openai_send(payload): + log_event(payload, DIRECTION_TO_OPENAI) + logger.trace(f"Sending text {DIRECTION_TO_OPENAI}: {payload['type']}") + if payload.get("type") == "response.create": + await response_q.get() + openai_send_q.put_nowait(json.dumps(payload)) + logger.trace("JSON sent.") + + def client_send(payload): + log_event(payload, DIRECTION_TO_CLIENT) + logger.trace(f"Sending JSON {DIRECTION_TO_CLIENT}: {payload.get('type', 'None')}") + if payload.get("type") == "response.done": + try: + response_q.put_nowait(None) + except asyncio.QueueFull: + logger.warning("Queue is full, skipping put_nowait") + client_send_q.put_nowait(json.dumps(payload)) + logger.trace("JSON sent.") + + async def close(): + openai_send_q.put_nowait(None) + client_send_q.put_nowait(None) + await openai_writer_task + await client_writer_task + await client_websocket.close() + await openai_ws.close() + + vad_task = None voice_config = 
get_voice_config(session_id) - current_user, openai_key = await authenticate_and_get_openai_key(client_websocket, session) + current_user: User = await get_current_user_for_websocket(client_websocket, session) + current_user, openai_key = await authenticate_and_get_openai_key(session, current_user, client_send) if current_user is None or openai_key is None: return try: @@ -582,7 +668,7 @@ async def flow_as_tool_websocket( }, } except Exception as e: # noqa: BLE001 - await client_websocket.send_json({"error": f"Failed to load flow: {e!s}"}) + client_send({"error": f"Failed to load flow: {e!s}"}) logger.error(f"Failed to load flow: {e}") return @@ -599,8 +685,12 @@ async def flow_as_tool_websocket( async with websockets.connect(url, extra_headers=headers) as openai_ws: openai_realtime_session = init_session_dict() + + openai_writer_task = asyncio.create_task(openai_writer()) + client_writer_task = asyncio.create_task(client_writer()) + session_update = {"type": "session.update", "session": openai_realtime_session} - await openai_ws.send(json.dumps(session_update)) + await openai_send(session_update) # Setup for VAD processing. 
vad_queue: asyncio.Queue = asyncio.Queue() @@ -626,7 +716,7 @@ async def flow_as_tool_websocket( has_speech = True logger.trace("!", end="") if bot_speaking_flag[0]: - await openai_ws.send(json.dumps({"type": "response.cancel"})) + await openai_send({"type": "response.cancel"}) bot_speaking_flag[0] = False except Exception as e: # noqa: BLE001 logger.error(f"[ERROR] VAD processing failed (ValueError): {e}") @@ -639,33 +729,8 @@ async def flow_as_tool_websocket( if time_since_speech >= 1.0: logger.trace("_", end="") - shared_state = {"last_event_type": None, "event_count": 0} - - def log_event(event, _direction: str) -> None: - event_type = event["type"] - - # Ensure shared_state has necessary keys initialized - if "last_event_type" not in shared_state: - shared_state["last_event_type"] = None - if "event_count" not in shared_state: - shared_state["event_count"] = 0 - - if event_type != shared_state["last_event_type"]: - logger.debug(f"Event (session - {session_id}): {_direction} {event_type}") - shared_state["last_event_type"] = event_type - shared_state["event_count"] = 0 - - # Explicitly convert to integer if needed - current_count = int(shared_state["event_count"]) if shared_state["event_count"] is not None else 0 - - shared_state["event_count"] = current_count + 1 - - def send_event(websocket, event, loop, direction) -> None: - asyncio.run_coroutine_threadsafe( - websocket.send_json(event), - loop, - ).result() - log_event(event, direction) + def client_send_event_from_thread(event, loop) -> None: + return loop.call_soon_threadsafe(client_send, event) def pass_through(from_dict, to_dict, keys): for key in keys: @@ -709,11 +774,18 @@ async def flow_as_tool_websocket( ) return new_session - # --- Spawn a text delta queue and task for TTS --- - text_delta_queue: asyncio.Queue = asyncio.Queue() - text_delta_task: asyncio.Task | None = None # Will hold our background task. 
+ class Response: + def __init__(self, response_id: str, use_elevenlabs: bool | None = None): + if use_elevenlabs is None: + use_elevenlabs = False + self.response_id = response_id + if use_elevenlabs: + self.text_delta_queue: asyncio.Queue = asyncio.Queue() + self.text_delta_task = asyncio.create_task(process_text_deltas(self)) - async def process_text_deltas(async_q: asyncio.Queue): + responses = {} + + async def process_text_deltas(rsp: Response): """Transfer text deltas from the async queue to a synchronous queue. then run the ElevenLabs TTS call (which expects a sync generator) in a separate thread. @@ -722,7 +794,7 @@ async def flow_as_tool_websocket( async def transfer_text_deltas(): while True: - item = await async_q.get() + item = await rsp.text_delta_queue.get() sync_q.put(item) if item is None: break @@ -757,11 +829,15 @@ async def flow_as_tool_websocket( for chunk in audio_stream: base64_audio = base64.b64encode(chunk).decode("utf-8") # Schedule sending the audio chunk in the main event loop. 
- event = {"type": "response.audio.delta", "delta": base64_audio} - send_event(client_websocket, event, main_loop, "↓") + event = { + "type": "response.audio.delta", + "delta": base64_audio, + "response_id": rsp.response_id, + } + client_send_event_from_thread(event, main_loop) - event = {"type": "response.done"} - send_event(client_websocket, event, main_loop, "↓") + event = {"type": "response.done", "response": {"id": rsp.response_id}} + client_send_event_from_thread(event, main_loop) except Exception as e: # noqa: BLE001 logger.error(f"Error in TTS processing (ValueError): {e}") @@ -785,20 +861,18 @@ async def flow_as_tool_websocket( # Ensure we're adding to an integer num_audio_samples += len(base64_data) event = {"type": "input_audio_buffer.append", "audio": base64_data} - await openai_ws.send(json.dumps(event)) - log_event(event, "↑") + await openai_send(event) if voice_config.barge_in_enabled: await vad_queue.put(base64_data) elif msg.get("type") == "response.create": - if voice_config.use_elevenlabs: - response = msg.setdefault("response", {}) - response["modalities"] = ["text"] - await openai_ws.send(json.dumps(msg)) + await openai_send(common_response_create(session_id, msg)) elif msg.get("type") == "input_audio_buffer.commit": if num_audio_samples > AUDIO_SAMPLE_THRESHOLD: - await openai_ws.send(message_text) - log_event(msg, "↑") + await openai_send(msg) num_audio_samples = 0 + elif msg.get("type") == "langflow.voice_mode.config": + logger.info(f"langflow.voice_mode.config {msg}") + voice_config.progress_enabled = msg.get("progress_enabled", True) elif msg.get("type") == "langflow.elevenlabs.config": logger.info(f"langflow.elevenlabs.config {msg}") voice_config.use_elevenlabs = msg["enabled"] @@ -808,21 +882,18 @@ async def flow_as_tool_websocket( modalities = ["text"] if voice_config.use_elevenlabs else ["audio", "text"] openai_realtime_session["modalities"] = modalities session_update = {"type": "session.update", "session": openai_realtime_session} - 
await openai_ws.send(json.dumps(session_update)) - log_event(session_update, "↑") + await openai_send(session_update) elif msg.get("type") == "session.update": openai_realtime_session = update_global_session(msg["session"]) session_update = {"type": "session.update", "session": openai_realtime_session} - await openai_ws.send(json.dumps(session_update)) - log_event(session_update, "↑") + await openai_send(session_update) else: - await openai_ws.send(message_text) - log_event(msg, "↑") + await openai_send(msg) except (WebSocketDisconnect, websockets.ConnectionClosedOK, websockets.ConnectionClosedError): pass async def forward_to_client() -> None: - nonlocal bot_speaking_flag, text_delta_queue, text_delta_task + nonlocal bot_speaking_flag, responses function_call = None function_call_args = "" conversation_id = str(uuid4()) @@ -838,22 +909,27 @@ async def flow_as_tool_websocket( do_forward = True do_forward = do_forward and not (event_type == "response.done" and voice_config.use_elevenlabs) do_forward = do_forward and event_type.find("flow.") != 0 - if do_forward: - await client_websocket.send_text(data) - if event_type == "response.text.delta": + response_id = None + if do_forward: + client_send(event) + if event_type == "response.created": + response_id = event["response"]["id"] + responses[response_id] = Response(response_id, voice_config.use_elevenlabs) + elif event_type == "response.text.delta": if voice_config.use_elevenlabs: + response_id = event["response_id"] delta = event.get("delta", "") - await text_delta_queue.put(delta) - if text_delta_task is None: - # if text_delta_task is None or text_delta_task.done(): - text_delta_task = asyncio.create_task(process_text_deltas(text_delta_queue)) + rsp: Response = responses[response_id] + await rsp.text_delta_queue.put(delta) elif event_type == "response.text.done": if voice_config.use_elevenlabs: - await text_delta_queue.put(None) - if text_delta_task and not text_delta_task.done(): - await text_delta_task - 
text_delta_task = None + response_id = event["response_id"] + rsp = responses[response_id] + await rsp.text_delta_queue.put(None) + if rsp.text_delta_task and not rsp.text_delta_task.done(): + await rsp.text_delta_task + responses.pop(response_id) try: message_text = event.get("text", "") @@ -892,14 +968,16 @@ async def flow_as_tool_websocket( # Create and store reference to the task function_call_task = asyncio.create_task( handle_function_call( - client_websocket, - openai_ws, function_call, function_call_args, flow_id, background_tasks, current_user, conversation_id, + session_id, + voice_config, + client_send, + openai_send, ) ) # Store the task reference to prevent garbage collection @@ -925,34 +1003,30 @@ async def flow_as_tool_websocket( logger.error(traceback.format_exc()) elif event_type == "error": pass - else: - await client_websocket.send_text(data) - log_event(event, "↓") except (WebSocketDisconnect, websockets.ConnectionClosedOK, websockets.ConnectionClosedError): pass - # Fix for storing references to asyncio tasks - vad_task = None if voice_config.barge_in_enabled: # Store the task reference to prevent it from being garbage collected vad_task = asyncio.create_task(process_vad_audio()) - await asyncio.gather( - forward_to_openai(), - forward_to_client(), - ) - + try: + # Create tasks and gather them for concurrent execution + task1 = asyncio.create_task(forward_to_openai()) + task2 = asyncio.create_task(forward_to_client()) + await asyncio.gather(task1, task2) + except Exception as exc: # noqa: BLE001 + # handle any exceptions from any task + logger.error("WS loop failed:", exc_info=exc) + logger.error(traceback.format_exc()) + finally: + # shared cleanup for writers & sockets + await close() except Exception as e: # noqa: BLE001 - logger.error(f"Value error: {e}") + logger.error(f"Unexpected error: {e}") logger.error(traceback.format_exc()) finally: - # Ensure that the client websocket is closed. 
-        try:
-            await client_websocket.close()
-        except Exception as e:  # noqa: BLE001
-            logger.debug(f"{e} ")
-        logger.info("Client websocket cleanup complete.")
         # Make sure to clean up the task
         if vad_task and not vad_task.done():
             vad_task.cancel()
@@ -986,11 +1060,52 @@ async def flow_tts_websocket(
     """WebSocket endpoint for direct flow text-to-speech interaction."""
     try:
         await client_websocket.accept()
-        log_event = create_event_logger(session_id)
-        current_user, openai_key = await authenticate_and_get_openai_key(client_websocket, session)
-        if current_user is None or openai_key is None:
-            return
+
+        openai_send_q: asyncio.Queue[str] = asyncio.Queue()
+        client_send_q: asyncio.Queue[str] = asyncio.Queue()
+
+        async def openai_writer():
+            while True:
+                msg = await openai_send_q.get()
+                if msg is None:
+                    break
+                await openai_ws.send(msg)
+
+        async def client_writer():
+            while True:
+                msg = await client_send_q.get()
+                if msg is None:
+                    break
+                await client_websocket.send_text(msg)
+
+        log_event = create_event_logger(session_id)
+
+        def openai_send(payload):
+            log_event(payload, DIRECTION_TO_OPENAI)
+            logger.trace(f"Sending text {DIRECTION_TO_OPENAI}: {payload['type']}")
+            openai_send_q.put_nowait(json.dumps(payload))
+            logger.trace("JSON sent.")
+
+        def client_send(payload):
+            log_event(payload, DIRECTION_TO_CLIENT)
+            logger.trace(f"Sending JSON {DIRECTION_TO_CLIENT}: {payload['type']}")
+            client_send_q.put_nowait(json.dumps(payload))
+            logger.trace("JSON sent.")
+
+        async def close():
+            openai_send_q.put_nowait(None)
+            client_send_q.put_nowait(None)
+            await openai_writer_task
+            await client_writer_task
+            await client_websocket.close()
+            await openai_ws.close()
+
+        current_user: User = await get_current_user_for_websocket(client_websocket, session)
+        current_user, openai_key = await authenticate_and_get_openai_key(session, current_user, client_send)
+        if current_user is None or openai_key is None:
+            # NOTE(review): errors queued via client_send here are never flushed (client_writer starts later)
+            return
 
         url = "wss://api.openai.com/v1/realtime?intent=transcription"
         headers = {
"Authorization": f"Bearer {openai_key}", @@ -999,23 +1111,26 @@ async def flow_tts_websocket( tts_config = get_tts_config(session_id, openai_key) async with websockets.connect(url, extra_headers=headers) as openai_ws: + openai_writer_task = asyncio.create_task(openai_writer()) + client_writer_task = asyncio.create_task(client_writer()) + tts_realtime_session = tts_config.get_session_dict() - await openai_ws.send(json.dumps(tts_realtime_session)) + + openai_send(tts_realtime_session) async def forward_to_openai() -> None: try: while True: message_text = await client_websocket.receive_text() event = json.loads(message_text) - log_event(event, "Client → OpenAI") if event.get("type") == "input_audio_buffer.append": base64_data = event.get("audio", "") if not base64_data: continue out_event = {"type": "input_audio_buffer.append", "audio": base64_data} - await openai_ws.send(json.dumps(out_event)) + openai_send(out_event) elif event.get("type") == "input_audio_buffer.commit": - await openai_ws.send(message_text) + openai_send(event) elif event.get("type") == "langflow.elevenlabs.config": logger.info(f"langflow.elevenlabs.config {event}") tts_config.use_elevenlabs = event["enabled"] @@ -1033,8 +1148,7 @@ async def flow_tts_websocket( while True: data = await openai_ws.recv() event = json.loads(data) - log_event(event, "OpenAI → Client") - await client_websocket.send_text(data) + client_send(event) if event.get("type") == "conversation.item.input_audio_transcription.completed": transcript = event.get("transcript") if transcript is not None and transcript != "": @@ -1052,9 +1166,7 @@ async def flow_tts_websocket( if not line: continue event_data = json.loads(line) - await client_websocket.send_json( - {"type": "flow.build.progress", "data": event_data} - ) + client_send({"type": "flow.build.progress", "data": event_data}) if event_data.get("event") == "end_vertex": text = ( event_data.get("data", {}) @@ -1084,7 +1196,7 @@ async def flow_tts_websocket( for chunk in 
audio_stream: base64_audio = base64.b64encode(chunk).decode("utf-8") audio_event = {"type": "response.audio.delta", "delta": base64_audio} - await client_websocket.send_json(audio_event) + client_send(audio_event) else: oai_client = tts_config.get_openai_client() voice = tts_config.get_openai_voice() @@ -1098,26 +1210,24 @@ async def flow_tts_websocket( base64_audio = base64.b64encode(response.content).decode("utf-8") audio_event = {"type": "response.audio.delta", "delta": base64_audio} - await client_websocket.send_json(audio_event) + client_send(audio_event) except Exception as e: # noqa: BLE001 logger.error(f"Error in WebSocket communication: {e}") - forward_to_openai_task = asyncio.create_task(forward_to_openai()) - forward_to_client_task = asyncio.create_task(forward_to_client()) try: - await asyncio.gather( - forward_to_openai_task, - forward_to_client_task, - ) - except Exception as e: # noqa: BLE001 - logger.error(f"Error in WebSocket communication: {e}") - logger.error(traceback.format_exc()) + # Create tasks and gather them for concurrent execution + task1 = asyncio.create_task(forward_to_openai()) + task2 = asyncio.create_task(forward_to_client()) + await asyncio.gather(task1, task2) + except Exception as exc: # noqa: BLE001 + # handle any exceptions from any task + logger.error("WS loop failed:", exc_info=exc) finally: - forward_to_openai_task.cancel() + # shared cleanup for writers & sockets + await close() except Exception as e: # noqa: BLE001 - logger.error(f"WebSocket error: {e}") + logger.error(f"Unexpected error: {e}") logger.error(traceback.format_exc()) - await client_websocket.close() def extract_transcript(json_data): diff --git a/src/backend/base/langflow/services/auth/utils.py b/src/backend/base/langflow/services/auth/utils.py index 030f6d5b4..c045384c8 100644 --- a/src/backend/base/langflow/services/auth/utils.py +++ b/src/backend/base/langflow/services/auth/utils.py @@ -7,7 +7,7 @@ from typing import TYPE_CHECKING, Annotated from uuid 
import UUID from cryptography.fernet import Fernet -from fastapi import Depends, HTTPException, Security, status +from fastapi import Depends, HTTPException, Security, WebSocketException, status from fastapi.security import APIKeyHeader, APIKeyQuery, OAuth2PasswordBearer from jose import JWTError, jwt from loguru import logger @@ -86,6 +86,55 @@ async def api_key_security( raise ValueError(msg) +async def ws_api_key_security( + api_key: str | None, +) -> UserRead: + settings = get_settings_service() + async with get_db_service().with_session() as db: + if settings.auth_settings.AUTO_LOGIN: + if not settings.auth_settings.SUPERUSER: + # internal server misconfiguration + raise WebSocketException( + code=status.WS_1011_INTERNAL_ERROR, + reason="Missing first superuser credentials", + ) + warnings.warn( + ("In v1.5, AUTO_LOGIN will *require* a valid API key or JWT. Please update your clients accordingly."), + DeprecationWarning, + stacklevel=2, + ) + if api_key: + result = await check_key(db, api_key) + else: + result = await get_user_by_username(db, settings.auth_settings.SUPERUSER) + + # normal path: must provide an API key + else: + if not api_key: + raise WebSocketException( + code=status.WS_1008_POLICY_VIOLATION, + reason="An API key must be passed as query or header", + ) + result = await check_key(db, api_key) + + # key was invalid or missing + if not result: + raise WebSocketException( + code=status.WS_1008_POLICY_VIOLATION, + reason="Invalid or missing API key", + ) + + # convert SQL-model User → pydantic UserRead + if isinstance(result, User): + return UserRead.model_validate(result, from_attributes=True) + + # fallback: something unexpected happened + raise WebSocketException( + code=status.WS_1011_INTERNAL_ERROR, + reason="Authentication subsystem error", + ) + + async def get_current_user( token: Annotated[str, Security(oauth2_login)], query_param: Annotated[str, Security(api_key_query)], @@ -167,16 +216,28 @@ async def get_current_user_by_jwt( async def 
get_current_user_for_websocket( websocket: WebSocket, - db: Annotated[AsyncSession, Depends(get_session)], - query_param: Annotated[str, Security(api_key_query)], -) -> User | None: - token = websocket.query_params.get("token") - api_key = websocket.query_params.get("x-api-key") + db: AsyncSession, +) -> User | UserRead: + token = websocket.cookies.get("access_token_lf") or websocket.query_params.get("token") if token: - return await get_current_user_by_jwt(token, db) + user = await get_current_user_by_jwt(token, db) + if user: + return user + + api_key = ( + websocket.query_params.get("x-api-key") + or websocket.query_params.get("api_key") + or websocket.headers.get("x-api-key") + or websocket.headers.get("api_key") + ) if api_key: - return await api_key_security(api_key, query_param) - return None + user_read = await ws_api_key_security(api_key) + if user_read: + return user_read + + raise WebSocketException( + code=status.WS_1008_POLICY_VIOLATION, reason="Missing or invalid credentials (cookie, token or API key)." + ) async def get_current_active_user(current_user: Annotated[User, Depends(get_current_user)]):