fix: voice mode progress reliability (#7830)
* progress update nudge
* progress update nudge
* nick's fix
* progress toggle
* typo
* fix auth and ws thread safety
* [autofix.ci] apply automated fixes
* fix deadlock
* delete commented function
* fix duplicate events
* [autofix.ci] apply automated fixes
* merge
* clean up log_event
* queues not locks
* response ids for 11L flow
* async bug
* Fix awaits
* ElevenLabsResponse -> Response
* queues not locks
* response_q
* comment
* small fix
* ♻️ (voice_mode.py): refactor code to use individual tasks and gather them for concurrent execution instead of TaskGroup for better readability and error handling.
* mypy
---------
Co-authored-by: phact <estevezsebastian@gmail.com>
Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com>
Co-authored-by: Cristhian Zanforlin Lousa <cristhian.lousa@gmail.com>
This commit is contained in:
parent
e18b55042e
commit
7c0beedcab
2 changed files with 421 additions and 250 deletions
|
|
@ -18,7 +18,7 @@ import webrtcvad
|
|||
import websockets
|
||||
from cryptography.fernet import InvalidToken
|
||||
from elevenlabs import ElevenLabs
|
||||
from fastapi import APIRouter, BackgroundTasks, Security
|
||||
from fastapi import APIRouter, BackgroundTasks
|
||||
from openai import OpenAI
|
||||
from sqlalchemy import select
|
||||
from starlette.websockets import WebSocket, WebSocketDisconnect
|
||||
|
|
@ -29,8 +29,8 @@ from langflow.api.v1.schemas import InputValueRequest
|
|||
from langflow.logging import logger
|
||||
from langflow.memory import aadd_messagetables
|
||||
from langflow.schema.properties import Properties
|
||||
from langflow.services.auth.utils import api_key_header, api_key_query, api_key_security, get_current_user_by_jwt
|
||||
from langflow.services.database.models import MessageTable
|
||||
from langflow.services.auth.utils import get_current_user_for_websocket
|
||||
from langflow.services.database.models import MessageTable, User
|
||||
from langflow.services.database.models.flow.model import Flow
|
||||
from langflow.services.deps import get_variable_service, session_scope
|
||||
from langflow.utils.voice_utils import (
|
||||
|
|
@ -75,38 +75,34 @@ Your instructions will be divided into three mutually exclusive sections: "Perma
|
|||
[ADDITIONAL] The following instructions are to be considered only "Additional"
|
||||
"""
|
||||
|
||||
DIRECTION_TO_OPENAI = "Client → OpenAI"
|
||||
DIRECTION_TO_CLIENT = "OpenAI → Client"
|
||||
# --- Helper Functions ---
|
||||
|
||||
|
||||
async def authenticate_and_get_openai_key(client_websocket: WebSocket, session: DbSession):
|
||||
async def authenticate_and_get_openai_key(session: DbSession, user: User, safe_send_json):
|
||||
"""Authenticate the user using a token or API key and retrieve the OpenAI API key.
|
||||
|
||||
Returns a tuple: (current_user, openai_key). If authentication fails, sends an error
|
||||
message to the client and returns (None, None).
|
||||
"""
|
||||
token = client_websocket.cookies.get("access_token_lf")
|
||||
current_user = None
|
||||
if token:
|
||||
current_user = await get_current_user_by_jwt(token, session)
|
||||
if current_user is None:
|
||||
current_user = await api_key_security(Security(api_key_query), Security(api_key_header))
|
||||
if current_user is None:
|
||||
await client_websocket.send_json(
|
||||
{
|
||||
"type": "error",
|
||||
"code": "langflow_auth",
|
||||
"message": "You must pass a valid Langflow token or cookie.",
|
||||
}
|
||||
)
|
||||
return None, None
|
||||
if user is None:
|
||||
await safe_send_json(
|
||||
{
|
||||
"type": "error",
|
||||
"code": "langflow_auth",
|
||||
"message": "You must pass a valid Langflow token or cookie.",
|
||||
}
|
||||
)
|
||||
return None, None
|
||||
variable_service = get_variable_service()
|
||||
try:
|
||||
openai_key_value = await variable_service.get_variable(
|
||||
user_id=current_user.id, name="OPENAI_API_KEY", field="openai_api_key", session=session
|
||||
user_id=user.id, name="OPENAI_API_KEY", field="openai_api_key", session=session
|
||||
)
|
||||
openai_key = openai_key_value if openai_key_value is not None else os.getenv("OPENAI_API_KEY", "")
|
||||
if not openai_key or openai_key == "dummy":
|
||||
await client_websocket.send_json(
|
||||
await safe_send_json(
|
||||
{
|
||||
"type": "error",
|
||||
"code": "api_key_missing",
|
||||
|
|
@ -119,7 +115,7 @@ async def authenticate_and_get_openai_key(client_websocket: WebSocket, session:
|
|||
logger.error(f"Error with API key: {e}")
|
||||
logger.error(traceback.format_exc())
|
||||
return None, None
|
||||
return current_user, openai_key
|
||||
return user, openai_key
|
||||
|
||||
|
||||
# --- Synchronous Text Chunker ---
|
||||
|
|
@ -151,105 +147,16 @@ def sync_text_chunker(sync_queue_obj: queue.Queue, timeout: float = 0.3):
|
|||
yield buffer + " "
|
||||
|
||||
|
||||
def _end_vertex_text(event_data: dict) -> str:
    """Return the message text carried by an ``end_vertex`` build event, or "" if any level is missing."""
    node = event_data
    # Walk data -> build_data -> data -> results -> message, tolerating absent or non-dict levels.
    for key in ("data", "build_data", "data", "results", "message"):
        if not isinstance(node, dict):
            return ""
        node = node.get(key)
    if not isinstance(node, dict):
        return ""
    text = node.get("text")
    return "" if text is None else str(text)


async def handle_function_call(
    websocket: WebSocket,
    openai_ws: websockets.WebSocketClientProtocol,
    function_call: dict,
    function_call_args: str,
    flow_id: str,
    background_tasks: BackgroundTasks,
    current_user: CurrentActiveUser,
    conversation_id: str,
):
    """Handle function calls from the OpenAI API.

    Parses ``function_call_args`` as JSON, runs the flow identified by ``flow_id`` through
    ``build_flow_and_stream`` (using ``args["input"]`` as the chat input and
    ``conversation_id`` as the session), forwards every non-empty streamed line to
    ``websocket`` as a ``flow.build.progress`` event, and concatenates the text of all
    ``end_vertex`` events. The concatenated text is sent to ``openai_ws`` as a
    ``function_call_output`` item followed by ``response.create``.

    On a handled error the exception is logged and a ``function_call_output`` item carrying
    the error text is sent instead; no ``response.create`` follows in that case.
    """

    async def send_output(output: str) -> None:
        # Single place that shapes the function_call_output item for success and error paths.
        payload = {
            "type": "conversation.item.create",
            "item": {
                "type": "function_call_output",
                "call_id": function_call.get("call_id"),
                "output": output,
            },
        }
        await openai_ws.send(json.dumps(payload))

    try:
        args = json.loads(function_call_args) if function_call_args else {}
        input_request = InputValueRequest(
            input_value=args.get("input"), components=[], type="chat", session=conversation_id
        )
        response = await build_flow_and_stream(
            flow_id=UUID(flow_id),
            inputs=input_request,
            background_tasks=background_tasks,
            current_user=current_user,
        )
        result = ""
        async for line in response.body_iterator:
            if not line:
                continue
            event_data = json.loads(line)
            await websocket.send_json({"type": "flow.build.progress", "data": event_data})
            if event_data.get("event") == "end_vertex":
                # A missing "build_data" used to default to "" and crash on the next .get().
                result += _end_vertex_text(event_data)
        await send_output(str(result))
        await openai_ws.send(json.dumps({"type": "response.create"}))
    except json.JSONDecodeError as e:
        # Must precede ValueError: JSONDecodeError is a ValueError subclass.
        trace = traceback.format_exc()
        logger.error(f"JSON decode error: {e!s}\ntrace: {trace}")
        await send_output(f"Error parsing arguments: {e!s}")
    except ValueError as e:
        trace = traceback.format_exc()
        logger.error(f"Value error: {e!s}\ntrace: {trace}")
        await send_output(f"Error with input values: {e!s}")
    except (ConnectionError, websockets.exceptions.WebSocketException) as e:
        trace = traceback.format_exc()
        logger.error(f"Connection error: {e!s}\ntrace: {trace}")
        await send_output(f"Connection error: {e!s}")
    except (KeyError, AttributeError, TypeError) as e:
        logger.error(f"Error executing flow: {e}")
        logger.error(traceback.format_exc())
        await send_output(f"Error executing flow: {e}")
|
||||
|
||||
|
||||
# --- Config Classes and Caches ---
|
||||
def common_response_create(session_id, original: dict | None = None) -> dict:
    """Build a ``response.create`` event for the given session.

    Args:
        session_id: Session whose voice configuration is looked up via ``get_voice_config``.
        original: Optional event to base the result on. It is copied, not modified.

    Returns:
        A dict whose ``type`` is ``"response.create"``. When the session's voice config has
        ``use_elevenlabs`` set, ``response["modalities"]`` is forced to ``["text"]``.
    """
    # Shallow-copy so the caller's dict is never mutated behind its back.
    msg: dict = dict(original) if original is not None else {}
    msg["type"] = "response.create"
    if get_voice_config(session_id).use_elevenlabs:
        # Copy the nested dict too; it would otherwise still be shared with ``original``.
        response = dict(msg.get("response") or {})
        response["modalities"] = ["text"]
        msg["response"] = response
    return msg
|
||||
|
||||
|
||||
class VoiceConfig:
|
||||
|
|
@ -261,6 +168,7 @@ class VoiceConfig:
|
|||
self.elevenlabs_client = None
|
||||
self.elevenlabs_key = None
|
||||
self.barge_in_enabled = False
|
||||
self.progress_enabled = True
|
||||
|
||||
self.default_openai_realtime_session = {
|
||||
"modalities": ["text", "audio"],
|
||||
|
|
@ -318,9 +226,6 @@ class ElevenLabsClientManager:
|
|||
return cls._instance
|
||||
|
||||
|
||||
voice_config_cache: dict[str, VoiceConfig] = {}
|
||||
|
||||
|
||||
def get_voice_config(session_id: str) -> VoiceConfig:
|
||||
if session_id is None:
|
||||
msg = "session_id cannot be None"
|
||||
|
|
@ -369,9 +274,6 @@ class TTSConfig:
|
|||
return self.openai_voice
|
||||
|
||||
|
||||
tts_config_cache: dict[str, TTSConfig] = {}
|
||||
|
||||
|
||||
def get_tts_config(session_id: str, openai_key: str) -> TTSConfig:
|
||||
if session_id is None:
|
||||
msg = "session_id cannot be None"
|
||||
|
|
@ -455,6 +357,132 @@ async def process_message_queue(queue_key, session):
|
|||
logger.error(traceback.format_exc())
|
||||
|
||||
|
||||
def _end_vertex_text(event_data: dict) -> str:
    """Return the message text carried by an ``end_vertex`` build event, or "" if any level is missing."""
    node = event_data
    # Walk data -> build_data -> data -> results -> message, tolerating absent or non-dict levels.
    for key in ("data", "build_data", "data", "results", "message"):
        if not isinstance(node, dict):
            return ""
        node = node.get(key)
    if not isinstance(node, dict):
        return ""
    text = node.get("text")
    return "" if text is None else str(text)


async def handle_function_call(
    function_call: dict,
    function_call_args: str,
    flow_id: str,
    background_tasks: BackgroundTasks,
    current_user: CurrentActiveUser,
    conversation_id: str,
    session_id: str,
    voice_config: VoiceConfig,
    client_safe_send_json,
    openai_send,
):
    """Handle function calls from the OpenAI API.

    Steps:
      1. If ``voice_config.progress_enabled``, send a system message asking the model to
         briefly tell the user that work has started, then a ``response.create`` built by
         ``common_response_create``.
      2. Parse ``function_call_args`` as JSON and run the flow ``flow_id`` through
         ``build_flow_and_stream`` with ``args["input"]`` as chat input and
         ``conversation_id`` as the session.
      3. Pass every non-empty streamed line to ``client_safe_send_json`` as a
         ``flow.build.progress`` event and concatenate the text of ``end_vertex`` events.
      4. Send the text as a ``function_call_output`` item, then another ``response.create``.

    On a handled error the exception is logged and a ``function_call_output`` item carrying
    the error text is sent; no ``response.create`` follows in that case.

    Args:
        client_safe_send_json: Callable taking one event dict. It is invoked without
            ``await``, so it must not be a coroutine function.
        openai_send: Awaitable callable taking one event dict.
    """

    async def send_output(output: str) -> None:
        # Single place that shapes the function_call_output item for success and error paths.
        await openai_send(
            {
                "type": "conversation.item.create",
                "item": {
                    "type": "function_call_output",
                    "call_id": function_call.get("call_id"),
                    "output": output,
                },
            }
        )

    try:
        # trigger response that tool was called
        if voice_config.progress_enabled:
            await openai_send(
                {
                    "type": "conversation.item.create",
                    "item": {
                        "type": "message",
                        "role": "system",
                        "content": [
                            {
                                "type": "input_text",
                                # The space after "requested." was missing, which fused two sentences.
                                "text": "Tell the user you are now looking into or solving "
                                "a request that will be explained later. Do not repeat "
                                "the prompt exactly, summarize what's being requested. "
                                "Keep it very short."
                                f"\n\nThe request: {function_call_args}",
                            }
                        ],
                    },
                }
            )
            await openai_send(common_response_create(session_id))
        args = json.loads(function_call_args) if function_call_args else {}
        input_request = InputValueRequest(
            input_value=args.get("input"), components=[], type="chat", session=conversation_id
        )
        response = await build_flow_and_stream(
            flow_id=UUID(flow_id),
            inputs=input_request,
            background_tasks=background_tasks,
            current_user=current_user,
        )
        result = ""
        async for line in response.body_iterator:
            if not line:
                continue
            event_data = json.loads(line)
            client_safe_send_json({"type": "flow.build.progress", "data": event_data})
            if event_data.get("event") == "end_vertex":
                # A missing "build_data" used to default to "" and crash on the next .get().
                result += _end_vertex_text(event_data)
        await send_output(str(result))
        await openai_send(common_response_create(session_id))
    except json.JSONDecodeError as e:
        # Must precede ValueError: JSONDecodeError is a ValueError subclass.
        trace = traceback.format_exc()
        logger.error(f"JSON decode error: {e!s}\ntrace: {trace}")
        await send_output(f"Error parsing arguments: {e!s}")
    except ValueError as e:
        trace = traceback.format_exc()
        logger.error(f"Value error: {e!s}\ntrace: {trace}")
        await send_output(f"Error with input values: {e!s}")
    except (ConnectionError, websockets.exceptions.WebSocketException) as e:
        trace = traceback.format_exc()
        logger.error(f"Connection error: {e!s}\ntrace: {trace}")
        await send_output(f"Connection error: {e!s}")
    except (KeyError, AttributeError, TypeError) as e:
        logger.error(f"Error executing flow: {e}")
        logger.error(traceback.format_exc())
        await send_output(f"Error executing flow: {e}")
|
||||
|
||||
|
||||
voice_config_cache: dict[str, VoiceConfig] = {}
|
||||
tts_config_cache: dict[str, TTSConfig] = {}
|
||||
|
||||
|
||||
# --- Global Queues and Message Processing ---
|
||||
|
||||
message_queues: dict[str, asyncio.Queue] = defaultdict(asyncio.Queue)
|
||||
|
|
@ -525,11 +553,17 @@ def create_event_logger(session_id: str):
|
|||
state = {"last_event_type": None, "event_count": 0}
|
||||
|
||||
def log_event(event: dict, direction: str) -> None:
    """Debug-log an event, collapsing consecutive events of the same type.

    Reads and updates the enclosing ``state`` dict (``last_event_type`` / ``event_count``).
    A debug line is emitted only when the event type differs from the previous one.

    Args:
        event: Event payload; ``type`` and ``event_id`` are optional.
        direction: Label describing which way the event is flowing.
    """
    # Tolerant lookups only: the earlier event["type"] access raised KeyError on untyped payloads.
    event_type = event.get("type", "None")
    event_id = event.get("event_id", "None")
    if event_type != state["last_event_type"]:
        logger.debug(f"Event (id - {event_id}) (session - {session_id}): {direction} {event_type}")
        state["last_event_type"] = event_type
        state["event_count"] = 0
    # NOTE(review): the flattened source does not show whether the next two checks were nested
    # under the branch above; they are kept unconditional so no response id or error is dropped.
    if event_type == "response.created":
        response_id = event.get("response", {}).get("id")
        logger.debug(f"response_id: {response_id}")
    if "error" in str(event_type):
        logger.debug(f"Error {event}")
    current_count = 0 if state["event_count"] is None else state["event_count"]
    state["event_count"] = current_count + 1
|
||||
|
||||
|
|
@ -565,8 +599,60 @@ async def flow_as_tool_websocket(
|
|||
"""WebSocket endpoint registering the flow as a tool for real-time interaction."""
|
||||
try:
|
||||
await client_websocket.accept()
|
||||
|
||||
openai_send_q: asyncio.Queue[str] = asyncio.Queue()
|
||||
client_send_q: asyncio.Queue[str] = asyncio.Queue()
|
||||
|
||||
async def openai_writer():
    """Drain ``openai_send_q`` into ``openai_ws`` until a ``None`` sentinel is dequeued."""
    while (outgoing := await openai_send_q.get()) is not None:
        await openai_ws.send(outgoing)
|
||||
|
||||
async def client_writer():
    """Drain ``client_send_q`` into ``client_websocket`` as text frames until a ``None`` sentinel is dequeued."""
    while (outgoing := await client_send_q.get()) is not None:
        await client_websocket.send_text(outgoing)
|
||||
|
||||
log_event = create_event_logger(session_id)
|
||||
|
||||
response_q: asyncio.Queue = asyncio.Queue()
|
||||
await response_q.put(None)
|
||||
|
||||
async def openai_send(payload):
    """Log ``payload`` and enqueue it, JSON-encoded, on ``openai_send_q``.

    A ``response.create`` payload first takes one item from ``response_q`` and therefore
    waits while that queue is empty.

    Args:
        payload: Event dict destined for OpenAI. ``type`` may be absent.
    """
    # .get() rather than payload["type"]: a message without a type must not raise KeyError here.
    payload_type = payload.get("type")
    log_event(payload, DIRECTION_TO_OPENAI)
    logger.trace(f"Sending text {DIRECTION_TO_OPENAI}: {payload_type}")
    if payload_type == "response.create":
        await response_q.get()
    openai_send_q.put_nowait(json.dumps(payload))
    logger.trace("JSON sent.")
|
||||
|
||||
def client_send(payload):
    """Log ``payload`` and enqueue it, JSON-encoded, on ``client_send_q``.

    A ``response.done`` payload additionally puts one ``None`` item on ``response_q``.
    """
    is_response_done = payload.get("type") == "response.done"
    log_event(payload, DIRECTION_TO_CLIENT)
    logger.trace(f"Sending JSON {DIRECTION_TO_CLIENT}: {payload.get('type', 'None')}")
    if is_response_done:
        try:
            response_q.put_nowait(None)
        except asyncio.QueueFull:
            logger.warning("Queue is full, skipping put_nowait")
    encoded = json.dumps(payload)
    client_send_q.put_nowait(encoded)
    logger.trace("JSON sent.")
|
||||
|
||||
async def close():
    """Stop both writer tasks via ``None`` sentinels, wait for them, then close both sockets."""
    # Sentinels go in first so each writer sees one after finishing its pending messages.
    for outbound_q in (openai_send_q, client_send_q):
        outbound_q.put_nowait(None)
    for writer_task in (openai_writer_task, client_writer_task):
        await writer_task
    await client_websocket.close()
    await openai_ws.close()
|
||||
|
||||
vad_task = None
|
||||
voice_config = get_voice_config(session_id)
|
||||
current_user, openai_key = await authenticate_and_get_openai_key(client_websocket, session)
|
||||
current_user: User = await get_current_user_for_websocket(client_websocket, session)
|
||||
current_user, openai_key = await authenticate_and_get_openai_key(session, current_user, client_send)
|
||||
if current_user is None or openai_key is None:
|
||||
return
|
||||
try:
|
||||
|
|
@ -582,7 +668,7 @@ async def flow_as_tool_websocket(
|
|||
},
|
||||
}
|
||||
except Exception as e: # noqa: BLE001
|
||||
await client_websocket.send_json({"error": f"Failed to load flow: {e!s}"})
|
||||
client_send({"error": f"Failed to load flow: {e!s}"})
|
||||
logger.error(f"Failed to load flow: {e}")
|
||||
return
|
||||
|
||||
|
|
@ -599,8 +685,12 @@ async def flow_as_tool_websocket(
|
|||
|
||||
async with websockets.connect(url, extra_headers=headers) as openai_ws:
|
||||
openai_realtime_session = init_session_dict()
|
||||
|
||||
openai_writer_task = asyncio.create_task(openai_writer())
|
||||
client_writer_task = asyncio.create_task(client_writer())
|
||||
|
||||
session_update = {"type": "session.update", "session": openai_realtime_session}
|
||||
await openai_ws.send(json.dumps(session_update))
|
||||
await openai_send(session_update)
|
||||
|
||||
# Setup for VAD processing.
|
||||
vad_queue: asyncio.Queue = asyncio.Queue()
|
||||
|
|
@ -626,7 +716,7 @@ async def flow_as_tool_websocket(
|
|||
has_speech = True
|
||||
logger.trace("!", end="")
|
||||
if bot_speaking_flag[0]:
|
||||
await openai_ws.send(json.dumps({"type": "response.cancel"}))
|
||||
await openai_send({"type": "response.cancel"})
|
||||
bot_speaking_flag[0] = False
|
||||
except Exception as e: # noqa: BLE001
|
||||
logger.error(f"[ERROR] VAD processing failed (ValueError): {e}")
|
||||
|
|
@ -639,33 +729,8 @@ async def flow_as_tool_websocket(
|
|||
if time_since_speech >= 1.0:
|
||||
logger.trace("_", end="")
|
||||
|
||||
shared_state = {"last_event_type": None, "event_count": 0}
|
||||
|
||||
def log_event(event, _direction: str) -> None:
    """Debug-log an event when its type differs from the previous one and count it in ``shared_state``."""
    kind = event["type"]

    # Make sure both bookkeeping keys exist before they are read.
    shared_state.setdefault("last_event_type", None)
    shared_state.setdefault("event_count", 0)

    if kind != shared_state["last_event_type"]:
        logger.debug(f"Event (session - {session_id}): {_direction} {kind}")
        shared_state["last_event_type"] = kind
        shared_state["event_count"] = 0

    # Coerce the stored counter to int; a None counter counts as zero.
    seen = shared_state["event_count"]
    previous = 0 if seen is None else int(seen)
    shared_state["event_count"] = previous + 1
|
||||
|
||||
def send_event(websocket, event, loop, direction) -> None:
    """Submit ``websocket.send_json(event)`` to ``loop``, block until it finishes, then log the event."""
    pending = asyncio.run_coroutine_threadsafe(websocket.send_json(event), loop)
    pending.result()
    log_event(event, direction)
|
||||
def client_send_event_from_thread(event, loop) -> None:
    """Schedule ``client_send(event)`` on ``loop`` via ``call_soon_threadsafe`` and return its result."""
    scheduled = loop.call_soon_threadsafe(client_send, event)
    return scheduled
|
||||
|
||||
def pass_through(from_dict, to_dict, keys):
|
||||
for key in keys:
|
||||
|
|
@ -709,11 +774,18 @@ async def flow_as_tool_websocket(
|
|||
)
|
||||
return new_session
|
||||
|
||||
# --- Spawn a text delta queue and task for TTS ---
|
||||
text_delta_queue: asyncio.Queue = asyncio.Queue()
|
||||
text_delta_task: asyncio.Task | None = None # Will hold our background task.
|
||||
class Response:
    """Per-response state, identified by ``response_id``.

    When ``use_elevenlabs`` is truthy the instance also owns a ``text_delta_queue`` and a
    ``text_delta_task`` running ``process_text_deltas`` for this response. Otherwise
    neither attribute is set.
    """

    def __init__(self, response_id: str, use_elevenlabs: bool | None = None):
        # None is treated the same as False.
        elevenlabs_enabled = bool(use_elevenlabs) if use_elevenlabs is not None else False
        self.response_id = response_id
        if not elevenlabs_enabled:
            return
        self.text_delta_queue: asyncio.Queue = asyncio.Queue()
        self.text_delta_task = asyncio.create_task(process_text_deltas(self))
|
||||
|
||||
async def process_text_deltas(async_q: asyncio.Queue):
|
||||
responses = {}
|
||||
|
||||
async def process_text_deltas(rsp: Response):
|
||||
"""Transfer text deltas from the async queue to a synchronous queue.
|
||||
|
||||
then run the ElevenLabs TTS call (which expects a sync generator) in a separate thread.
|
||||
|
|
@ -722,7 +794,7 @@ async def flow_as_tool_websocket(
|
|||
|
||||
async def transfer_text_deltas():
    """Copy items from ``rsp.text_delta_queue`` to ``sync_q`` until a ``None`` sentinel has been forwarded."""
    while True:
        # Single read per iteration: the stale ``async_q.get()`` referenced a name that no
        # longer exists and discarded one item on every pass.
        item = await rsp.text_delta_queue.get()
        sync_q.put(item)
        if item is None:
            break
|
||||
|
|
@ -757,11 +829,15 @@ async def flow_as_tool_websocket(
|
|||
for chunk in audio_stream:
|
||||
base64_audio = base64.b64encode(chunk).decode("utf-8")
|
||||
# Schedule sending the audio chunk in the main event loop.
|
||||
event = {"type": "response.audio.delta", "delta": base64_audio}
|
||||
send_event(client_websocket, event, main_loop, "↓")
|
||||
event = {
|
||||
"type": "response.audio.delta",
|
||||
"delta": base64_audio,
|
||||
"response_id": rsp.response_id,
|
||||
}
|
||||
client_send_event_from_thread(event, main_loop)
|
||||
|
||||
event = {"type": "response.done"}
|
||||
send_event(client_websocket, event, main_loop, "↓")
|
||||
event = {"type": "response.done", "response": {"id": rsp.response_id}}
|
||||
client_send_event_from_thread(event, main_loop)
|
||||
except Exception as e: # noqa: BLE001
|
||||
logger.error(f"Error in TTS processing (ValueError): {e}")
|
||||
|
||||
|
|
@ -785,20 +861,18 @@ async def flow_as_tool_websocket(
|
|||
# Ensure we're adding to an integer
|
||||
num_audio_samples += len(base64_data)
|
||||
event = {"type": "input_audio_buffer.append", "audio": base64_data}
|
||||
await openai_ws.send(json.dumps(event))
|
||||
log_event(event, "↑")
|
||||
await openai_send(event)
|
||||
if voice_config.barge_in_enabled:
|
||||
await vad_queue.put(base64_data)
|
||||
elif msg.get("type") == "response.create":
|
||||
if voice_config.use_elevenlabs:
|
||||
response = msg.setdefault("response", {})
|
||||
response["modalities"] = ["text"]
|
||||
await openai_ws.send(json.dumps(msg))
|
||||
await openai_send(common_response_create(session_id, msg))
|
||||
elif msg.get("type") == "input_audio_buffer.commit":
|
||||
if num_audio_samples > AUDIO_SAMPLE_THRESHOLD:
|
||||
await openai_ws.send(message_text)
|
||||
log_event(msg, "↑")
|
||||
await openai_send(msg)
|
||||
num_audio_samples = 0
|
||||
elif msg.get("type") == "langflow.voice_mode.config":
|
||||
logger.info(f"langflow.voice_mode.config {msg}")
|
||||
voice_config.progress_enabled = msg.get("progress_enabled", True)
|
||||
elif msg.get("type") == "langflow.elevenlabs.config":
|
||||
logger.info(f"langflow.elevenlabs.config {msg}")
|
||||
voice_config.use_elevenlabs = msg["enabled"]
|
||||
|
|
@ -808,21 +882,18 @@ async def flow_as_tool_websocket(
|
|||
modalities = ["text"] if voice_config.use_elevenlabs else ["audio", "text"]
|
||||
openai_realtime_session["modalities"] = modalities
|
||||
session_update = {"type": "session.update", "session": openai_realtime_session}
|
||||
await openai_ws.send(json.dumps(session_update))
|
||||
log_event(session_update, "↑")
|
||||
await openai_send(session_update)
|
||||
elif msg.get("type") == "session.update":
|
||||
openai_realtime_session = update_global_session(msg["session"])
|
||||
session_update = {"type": "session.update", "session": openai_realtime_session}
|
||||
await openai_ws.send(json.dumps(session_update))
|
||||
log_event(session_update, "↑")
|
||||
await openai_send(session_update)
|
||||
else:
|
||||
await openai_ws.send(message_text)
|
||||
log_event(msg, "↑")
|
||||
await openai_send(msg)
|
||||
except (WebSocketDisconnect, websockets.ConnectionClosedOK, websockets.ConnectionClosedError):
|
||||
pass
|
||||
|
||||
async def forward_to_client() -> None:
|
||||
nonlocal bot_speaking_flag, text_delta_queue, text_delta_task
|
||||
nonlocal bot_speaking_flag, responses
|
||||
function_call = None
|
||||
function_call_args = ""
|
||||
conversation_id = str(uuid4())
|
||||
|
|
@ -838,22 +909,27 @@ async def flow_as_tool_websocket(
|
|||
do_forward = True
|
||||
do_forward = do_forward and not (event_type == "response.done" and voice_config.use_elevenlabs)
|
||||
do_forward = do_forward and event_type.find("flow.") != 0
|
||||
if do_forward:
|
||||
await client_websocket.send_text(data)
|
||||
|
||||
if event_type == "response.text.delta":
|
||||
response_id = None
|
||||
if do_forward:
|
||||
client_send(event)
|
||||
if event_type == "response.created":
|
||||
response_id = event["response"]["id"]
|
||||
responses[response_id] = Response(response_id, voice_config.use_elevenlabs)
|
||||
elif event_type == "response.text.delta":
|
||||
if voice_config.use_elevenlabs:
|
||||
response_id = event["response_id"]
|
||||
delta = event.get("delta", "")
|
||||
await text_delta_queue.put(delta)
|
||||
if text_delta_task is None:
|
||||
# if text_delta_task is None or text_delta_task.done():
|
||||
text_delta_task = asyncio.create_task(process_text_deltas(text_delta_queue))
|
||||
rsp: Response = responses[response_id]
|
||||
await rsp.text_delta_queue.put(delta)
|
||||
elif event_type == "response.text.done":
|
||||
if voice_config.use_elevenlabs:
|
||||
await text_delta_queue.put(None)
|
||||
if text_delta_task and not text_delta_task.done():
|
||||
await text_delta_task
|
||||
text_delta_task = None
|
||||
response_id = event["response_id"]
|
||||
rsp = responses[response_id]
|
||||
await rsp.text_delta_queue.put(None)
|
||||
if rsp.text_delta_task and not rsp.text_delta_task.done():
|
||||
await rsp.text_delta_task
|
||||
responses.pop(response_id)
|
||||
|
||||
try:
|
||||
message_text = event.get("text", "")
|
||||
|
|
@ -892,14 +968,16 @@ async def flow_as_tool_websocket(
|
|||
# Create and store reference to the task
|
||||
function_call_task = asyncio.create_task(
|
||||
handle_function_call(
|
||||
client_websocket,
|
||||
openai_ws,
|
||||
function_call,
|
||||
function_call_args,
|
||||
flow_id,
|
||||
background_tasks,
|
||||
current_user,
|
||||
conversation_id,
|
||||
session_id,
|
||||
voice_config,
|
||||
client_send,
|
||||
openai_send,
|
||||
)
|
||||
)
|
||||
# Store the task reference to prevent garbage collection
|
||||
|
|
@ -925,34 +1003,30 @@ async def flow_as_tool_websocket(
|
|||
logger.error(traceback.format_exc())
|
||||
elif event_type == "error":
|
||||
pass
|
||||
else:
|
||||
await client_websocket.send_text(data)
|
||||
log_event(event, "↓")
|
||||
|
||||
except (WebSocketDisconnect, websockets.ConnectionClosedOK, websockets.ConnectionClosedError):
|
||||
pass
|
||||
|
||||
# Fix for storing references to asyncio tasks
|
||||
vad_task = None
|
||||
if voice_config.barge_in_enabled:
|
||||
# Store the task reference to prevent it from being garbage collected
|
||||
vad_task = asyncio.create_task(process_vad_audio())
|
||||
|
||||
await asyncio.gather(
|
||||
forward_to_openai(),
|
||||
forward_to_client(),
|
||||
)
|
||||
|
||||
try:
|
||||
# Create tasks and gather them for concurrent execution
|
||||
task1 = asyncio.create_task(forward_to_openai())
|
||||
task2 = asyncio.create_task(forward_to_client())
|
||||
await asyncio.gather(task1, task2)
|
||||
except Exception as exc: # noqa: BLE001
|
||||
# handle any exceptions from any task
|
||||
logger.error("WS loop failed:", exc_info=exc)
|
||||
logger.error(traceback.format_exc())
|
||||
finally:
|
||||
# shared cleanup for writers & sockets
|
||||
await close()
|
||||
except Exception as e: # noqa: BLE001
|
||||
logger.error(f"Value error: {e}")
|
||||
logger.error(f"Unexpected error: {e}")
|
||||
logger.error(traceback.format_exc())
|
||||
finally:
|
||||
# Ensure that the client websocket is closed.
|
||||
try:
|
||||
await client_websocket.close()
|
||||
except Exception as e: # noqa: BLE001
|
||||
logger.debug(f"{e} ")
|
||||
logger.info("Client websocket cleanup complete.")
|
||||
# Make sure to clean up the task
|
||||
if vad_task and not vad_task.done():
|
||||
vad_task.cancel()
|
||||
|
|
@ -986,11 +1060,49 @@ async def flow_tts_websocket(
|
|||
"""WebSocket endpoint for direct flow text-to-speech interaction."""
|
||||
try:
|
||||
await client_websocket.accept()
|
||||
log_event = create_event_logger(session_id)
|
||||
current_user, openai_key = await authenticate_and_get_openai_key(client_websocket, session)
|
||||
if current_user is None or openai_key is None:
|
||||
return
|
||||
|
||||
openai_send_q: asyncio.Queue[str] = asyncio.Queue()
|
||||
client_send_q: asyncio.Queue[str] = asyncio.Queue()
|
||||
|
||||
async def openai_writer():
    """Drain ``openai_send_q`` into ``openai_ws`` until a ``None`` sentinel is dequeued."""
    while (outgoing := await openai_send_q.get()) is not None:
        await openai_ws.send(outgoing)
|
||||
|
||||
async def client_writer():
    """Drain ``client_send_q`` into ``client_websocket`` as text frames until a ``None`` sentinel is dequeued."""
    while (outgoing := await client_send_q.get()) is not None:
        await client_websocket.send_text(outgoing)
|
||||
|
||||
log_event = create_event_logger(session_id)
|
||||
|
||||
def openai_send(payload):
    """Log ``payload`` and enqueue it, JSON-encoded, on ``openai_send_q``.

    Args:
        payload: Event dict destined for OpenAI. ``type`` may be absent.
    """
    log_event(payload, DIRECTION_TO_OPENAI)
    # .get() rather than payload["type"], so an untyped payload cannot raise KeyError in a trace line.
    logger.trace(f"Sending text {DIRECTION_TO_OPENAI}: {payload.get('type', 'None')}")
    openai_send_q.put_nowait(json.dumps(payload))
    logger.trace("JSON sent.")
|
||||
|
||||
def client_send(payload):
    """Log ``payload`` and enqueue it, JSON-encoded, on ``client_send_q``.

    Args:
        payload: Event dict destined for the client. ``type`` may be absent.
    """
    log_event(payload, DIRECTION_TO_CLIENT)
    # .get() rather than payload["type"], so an untyped payload cannot raise KeyError in a trace line.
    logger.trace(f"Sending JSON {DIRECTION_TO_CLIENT}: {payload.get('type', 'None')}")
    client_send_q.put_nowait(json.dumps(payload))
    logger.trace("JSON sent.")
|
||||
|
||||
async def close():
    """Stop both writer tasks via ``None`` sentinels, wait for them, then close both sockets."""
    # Sentinels go in first so each writer sees one after finishing its pending messages.
    for outbound_q in (openai_send_q, client_send_q):
        outbound_q.put_nowait(None)
    for writer_task in (openai_writer_task, client_writer_task):
        await writer_task
    await client_websocket.close()
    await openai_ws.close()
|
||||
|
||||
log_event = create_event_logger(session_id)
|
||||
current_user: User = await get_current_user_for_websocket(client_websocket, session)
|
||||
current_user, openai_key = await authenticate_and_get_openai_key(session, current_user, client_send)
|
||||
url = "wss://api.openai.com/v1/realtime?intent=transcription"
|
||||
headers = {
|
||||
"Authorization": f"Bearer {openai_key}",
|
||||
|
|
@ -999,23 +1111,26 @@ async def flow_tts_websocket(
|
|||
|
||||
tts_config = get_tts_config(session_id, openai_key)
|
||||
async with websockets.connect(url, extra_headers=headers) as openai_ws:
|
||||
openai_writer_task = asyncio.create_task(openai_writer())
|
||||
client_writer_task = asyncio.create_task(client_writer())
|
||||
|
||||
tts_realtime_session = tts_config.get_session_dict()
|
||||
await openai_ws.send(json.dumps(tts_realtime_session))
|
||||
|
||||
openai_send(tts_realtime_session)
|
||||
|
||||
async def forward_to_openai() -> None:
|
||||
try:
|
||||
while True:
|
||||
message_text = await client_websocket.receive_text()
|
||||
event = json.loads(message_text)
|
||||
log_event(event, "Client → OpenAI")
|
||||
if event.get("type") == "input_audio_buffer.append":
|
||||
base64_data = event.get("audio", "")
|
||||
if not base64_data:
|
||||
continue
|
||||
out_event = {"type": "input_audio_buffer.append", "audio": base64_data}
|
||||
await openai_ws.send(json.dumps(out_event))
|
||||
openai_send(out_event)
|
||||
elif event.get("type") == "input_audio_buffer.commit":
|
||||
await openai_ws.send(message_text)
|
||||
openai_send(event)
|
||||
elif event.get("type") == "langflow.elevenlabs.config":
|
||||
logger.info(f"langflow.elevenlabs.config {event}")
|
||||
tts_config.use_elevenlabs = event["enabled"]
|
||||
|
|
@ -1033,8 +1148,7 @@ async def flow_tts_websocket(
|
|||
while True:
|
||||
data = await openai_ws.recv()
|
||||
event = json.loads(data)
|
||||
log_event(event, "OpenAI → Client")
|
||||
await client_websocket.send_text(data)
|
||||
client_send(event)
|
||||
if event.get("type") == "conversation.item.input_audio_transcription.completed":
|
||||
transcript = event.get("transcript")
|
||||
if transcript is not None and transcript != "":
|
||||
|
|
@ -1052,9 +1166,7 @@ async def flow_tts_websocket(
|
|||
if not line:
|
||||
continue
|
||||
event_data = json.loads(line)
|
||||
await client_websocket.send_json(
|
||||
{"type": "flow.build.progress", "data": event_data}
|
||||
)
|
||||
client_send({"type": "flow.build.progress", "data": event_data})
|
||||
if event_data.get("event") == "end_vertex":
|
||||
text = (
|
||||
event_data.get("data", {})
|
||||
|
|
@ -1084,7 +1196,7 @@ async def flow_tts_websocket(
|
|||
for chunk in audio_stream:
|
||||
base64_audio = base64.b64encode(chunk).decode("utf-8")
|
||||
audio_event = {"type": "response.audio.delta", "delta": base64_audio}
|
||||
await client_websocket.send_json(audio_event)
|
||||
client_send(audio_event)
|
||||
else:
|
||||
oai_client = tts_config.get_openai_client()
|
||||
voice = tts_config.get_openai_voice()
|
||||
|
|
@ -1098,26 +1210,24 @@ async def flow_tts_websocket(
|
|||
|
||||
base64_audio = base64.b64encode(response.content).decode("utf-8")
|
||||
audio_event = {"type": "response.audio.delta", "delta": base64_audio}
|
||||
await client_websocket.send_json(audio_event)
|
||||
client_send(audio_event)
|
||||
except Exception as e: # noqa: BLE001
|
||||
logger.error(f"Error in WebSocket communication: {e}")
|
||||
|
||||
forward_to_openai_task = asyncio.create_task(forward_to_openai())
|
||||
forward_to_client_task = asyncio.create_task(forward_to_client())
|
||||
try:
|
||||
await asyncio.gather(
|
||||
forward_to_openai_task,
|
||||
forward_to_client_task,
|
||||
)
|
||||
except Exception as e: # noqa: BLE001
|
||||
logger.error(f"Error in WebSocket communication: {e}")
|
||||
logger.error(traceback.format_exc())
|
||||
# Create tasks and gather them for concurrent execution
|
||||
task1 = asyncio.create_task(forward_to_openai())
|
||||
task2 = asyncio.create_task(forward_to_client())
|
||||
await asyncio.gather(task1, task2)
|
||||
except Exception as exc: # noqa: BLE001
|
||||
# handle any exceptions from any task
|
||||
logger.error("WS loop failed:", exc_info=exc)
|
||||
finally:
|
||||
forward_to_openai_task.cancel()
|
||||
# shared cleanup for writers & sockets
|
||||
await close()
|
||||
except Exception as e: # noqa: BLE001
|
||||
logger.error(f"WebSocket error: {e}")
|
||||
logger.error(f"Unexpected error: {e}")
|
||||
logger.error(traceback.format_exc())
|
||||
await client_websocket.close()
|
||||
|
||||
|
||||
def extract_transcript(json_data):
|
||||
|
|
|
|||
|
|
@ -7,7 +7,7 @@ from typing import TYPE_CHECKING, Annotated
|
|||
from uuid import UUID
|
||||
|
||||
from cryptography.fernet import Fernet
|
||||
from fastapi import Depends, HTTPException, Security, status
|
||||
from fastapi import Depends, HTTPException, Security, WebSocketException, status
|
||||
from fastapi.security import APIKeyHeader, APIKeyQuery, OAuth2PasswordBearer
|
||||
from jose import JWTError, jwt
|
||||
from loguru import logger
|
||||
|
|
@ -86,6 +86,55 @@ async def api_key_security(
|
|||
raise ValueError(msg)
|
||||
|
||||
|
||||
async def ws_api_key_security(
    api_key: str | None,
) -> UserRead:
    """Authenticate a websocket caller by API key and return the user as ``UserRead``.

    With ``AUTO_LOGIN`` enabled, a missing key falls back to the configured
    superuser (a ``DeprecationWarning`` is emitted on this path whether or
    not a key was supplied). With ``AUTO_LOGIN`` disabled, a key is required.

    Args:
        api_key: API key taken from the connection, or ``None`` if absent.

    Returns:
        The authenticated user converted to ``UserRead``.

    Raises:
        WebSocketException: ``WS_1011_INTERNAL_ERROR`` when ``AUTO_LOGIN`` is
            on but no superuser is configured, or when the lookup result is
            not a ``User``; ``WS_1008_POLICY_VIOLATION`` when the key is
            missing (non auto-login) or the lookup yields nothing.
    """
    settings = get_settings_service()
    async with get_db_service().with_session() as db:
        if not settings.auth_settings.AUTO_LOGIN:
            # normal path: must provide an API key
            if not api_key:
                raise WebSocketException(
                    code=status.WS_1008_POLICY_VIOLATION,
                    reason="An API key must be passed as query or header",
                )
            candidate = await check_key(db, api_key)
        else:
            superuser_name = settings.auth_settings.SUPERUSER
            if not superuser_name:
                # internal server misconfiguration
                raise WebSocketException(
                    code=status.WS_1011_INTERNAL_ERROR,
                    reason="Missing first superuser credentials",
                )
            warnings.warn(
                ("In v1.5, AUTO_LOGIN will *require* a valid API key or JWT. Please update your clients accordingly."),
                DeprecationWarning,
                stacklevel=2,
            )
            # An explicit key wins; otherwise fall back to the superuser account.
            if not api_key:
                candidate = await get_user_by_username(db, settings.auth_settings.SUPERUSER)
            else:
                candidate = await check_key(db, api_key)

        # key was invalid or missing
        if not candidate:
            raise WebSocketException(
                code=status.WS_1008_POLICY_VIOLATION,
                reason="Invalid or missing API key",
            )

        # convert SQL-model User → pydantic UserRead
        if isinstance(candidate, User):
            return UserRead.model_validate(candidate, from_attributes=True)

    # fallback: something unexpected happened
    raise WebSocketException(
        code=status.WS_1011_INTERNAL_ERROR,
        reason="Authentication subsystem error",
    )
|
||||
|
||||
|
||||
async def get_current_user(
|
||||
token: Annotated[str, Security(oauth2_login)],
|
||||
query_param: Annotated[str, Security(api_key_query)],
|
||||
|
|
@ -167,16 +216,28 @@ async def get_current_user_by_jwt(
|
|||
|
||||
async def get_current_user_for_websocket(
    websocket: WebSocket,
    db: AsyncSession,
) -> User | UserRead:
    """Resolve the user behind a websocket connection from a JWT or an API key.

    Credentials are tried in this order:

    1. A JWT from the ``access_token_lf`` cookie, else the ``token`` query
       parameter. If it resolves to a user, that user is returned.
    2. An API key from the ``x-api-key`` / ``api_key`` query parameters, else
       the ``x-api-key`` / ``api_key`` headers, validated through
       ``ws_api_key_security``.

    Args:
        websocket: The incoming websocket connection to read credentials from.
        db: Database session passed to the JWT lookup.

    Returns:
        The ``User`` resolved from the JWT, or the ``UserRead`` resolved from
        the API key.

    Raises:
        WebSocketException: ``WS_1008_POLICY_VIOLATION`` when neither
            credential yields a user. Exceptions raised by the JWT or API-key
            helpers propagate unchanged.
    """
    token = websocket.cookies.get("access_token_lf") or websocket.query_params.get("token")
    if token:
        user = await get_current_user_by_jwt(token, db)
        if user:
            return user

    # Accept the key under either spelling, query string first, then headers.
    api_key = (
        websocket.query_params.get("x-api-key")
        or websocket.query_params.get("api_key")
        or websocket.headers.get("x-api-key")
        or websocket.headers.get("api_key")
    )
    if api_key:
        user_read = await ws_api_key_security(api_key)
        if user_read:
            return user_read

    raise WebSocketException(
        code=status.WS_1008_POLICY_VIOLATION, reason="Missing or invalid credentials (cookie, token or API key)."
    )
|
||||
|
||||
|
||||
async def get_current_active_user(current_user: Annotated[User, Depends(get_current_user)]):
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue