diff --git a/.gitattributes b/.gitattributes index c79a33b70..6d062c470 100644 --- a/.gitattributes +++ b/.gitattributes @@ -33,3 +33,4 @@ Dockerfile text *.svg binary *.csv binary *.wav binary +*.raw binary \ No newline at end of file diff --git a/.gitignore b/.gitignore index 9be9589dc..4bb12b759 100644 --- a/.gitignore +++ b/.gitignore @@ -276,4 +276,4 @@ src/frontend/temp .history .dspy_cache/ -*.db +*.db \ No newline at end of file diff --git a/pyproject.toml b/pyproject.toml index d59bfb706..56ddecfc3 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -108,6 +108,8 @@ dependencies = [ "crewai==0.102.0", "mcp>=0.9.1", "uv>=0.5.7", + "webrtcvad>=2.0.10", + "scipy>=1.14.1", "ag2>=0.1.0", "scrapegraph-py>=1.12.0", "pydantic-ai>=0.0.19", @@ -160,6 +162,9 @@ dev = [ "hypothesis>=6.123.17", "locust>=2.32.9", "pytest-rerunfailures>=15.0", + "scrapegraph-py>=1.10.2", + "pydantic-ai>=0.0.19", + "elevenlabs>=1.52.0", "faker>=37.0.0", ] diff --git a/src/backend/base/langflow/api/router.py b/src/backend/base/langflow/api/router.py index 17cf33204..571914ea6 100644 --- a/src/backend/base/langflow/api/router.py +++ b/src/backend/base/langflow/api/router.py @@ -9,12 +9,14 @@ from langflow.api.v1 import ( flows_router, folders_router, login_router, + mcp_router, monitor_router, starter_projects_router, store_router, users_router, validate_router, variables_router, + voice_mode_router, ) from langflow.api.v2 import files_router as files_router_v2 @@ -43,6 +45,8 @@ router_v1.include_router(files_router) router_v1.include_router(monitor_router) router_v1.include_router(folders_router) router_v1.include_router(starter_projects_router) +router_v1.include_router(voice_mode_router) +router_v1.include_router(mcp_router) router_v2.include_router(files_router_v2) diff --git a/src/backend/base/langflow/api/v1/__init__.py b/src/backend/base/langflow/api/v1/__init__.py index 567415af1..9f9bbe166 100644 --- a/src/backend/base/langflow/api/v1/__init__.py +++ 
b/src/backend/base/langflow/api/v1/__init__.py @@ -12,6 +12,7 @@ from langflow.api.v1.store import router as store_router from langflow.api.v1.users import router as users_router from langflow.api.v1.validate import router as validate_router from langflow.api.v1.variable import router as variables_router +from langflow.api.v1.voice_mode import router as voice_mode_router __all__ = [ "api_key_router", @@ -28,4 +29,5 @@ __all__ = [ "users_router", "validate_router", "variables_router", + "voice_mode_router", ] diff --git a/src/backend/base/langflow/api/v1/chat.py b/src/backend/base/langflow/api/v1/chat.py index 0af57def0..42e0ebce2 100644 --- a/src/backend/base/langflow/api/v1/chat.py +++ b/src/backend/base/langflow/api/v1/chat.py @@ -526,6 +526,19 @@ async def build_vertex_stream( raise HTTPException(status_code=500, detail="Error building Component") from exc +async def build_flow_and_stream(flow_id, inputs, background_tasks, current_user): + queue_service = get_queue_service() + build_response = await build_flow( + flow_id=flow_id, + inputs=inputs, + background_tasks=background_tasks, + current_user=current_user, + queue_service=queue_service, + ) + job_id = build_response["job_id"] + return await get_build_events(job_id, queue_service) + + @router.post("/build_public_tmp/{flow_id}/flow") async def build_public_tmp( *, diff --git a/src/backend/base/langflow/api/v1/mcp.py b/src/backend/base/langflow/api/v1/mcp.py index 5db1f8354..0c358037d 100644 --- a/src/backend/base/langflow/api/v1/mcp.py +++ b/src/backend/base/langflow/api/v1/mcp.py @@ -10,7 +10,7 @@ from uuid import UUID, uuid4 import pydantic from anyio import BrokenResourceError -from fastapi import APIRouter, Depends, Request +from fastapi import APIRouter, Depends, HTTPException, Request from fastapi.responses import StreamingResponse from mcp import types from mcp.server import NotificationOptions, Server @@ -18,12 +18,17 @@ from mcp.server.sse import SseServerTransport from sqlmodel import select from 
starlette.background import BackgroundTasks -from langflow.api.v1.chat import build_flow +from langflow.api.v1.chat import build_flow_and_stream from langflow.api.v1.schemas import InputValueRequest from langflow.helpers.flow import json_schema_from_flow from langflow.services.auth.utils import get_current_active_user from langflow.services.database.models import Flow, User -from langflow.services.deps import get_db_service, get_session, get_settings_service, get_storage_service +from langflow.services.deps import ( + get_db_service, + get_session, + get_settings_service, + get_storage_service, +) from langflow.services.storage.utils import build_content_type_from_extension logger = logging.getLogger(__name__) @@ -45,6 +50,20 @@ if False: logger.debug("MCP module loaded - debug logging enabled") +class MCPConfig: + _instance = None + + def __new__(cls): + if cls._instance is None: + cls._instance = super().__new__(cls) + cls._instance.enable_progress_notifications = None + return cls._instance + + +def get_mcp_config(): + return MCPConfig() + + router = APIRouter(prefix="/mcp", tags=["mcp"]) server = Server("langflow-mcp-server") @@ -177,10 +196,12 @@ async def handle_list_tools(): @server.call_tool() -async def handle_call_tool( - name: str, arguments: dict, *, enable_progress_notifications: bool = Depends(get_enable_progress_notifications) -) -> list[types.TextContent]: +async def handle_call_tool(name: str, arguments: dict) -> list[types.TextContent]: """Handle tool execution requests.""" + mcp_config = get_mcp_config() + if mcp_config.enable_progress_notifications is None: + settings_service = get_settings_service() + mcp_config.enable_progress_notifications = settings_service.settings.mcp_server_enable_progress_notifications try: session = await anext(get_session()) background_tasks = BackgroundTasks() @@ -196,7 +217,7 @@ async def handle_call_tool( processed_inputs = dict(arguments) # Initial progress notification - if enable_progress_notifications and 
(progress_token := server.request_context.meta.progressToken): + if mcp_config.enable_progress_notifications and (progress_token := server.request_context.meta.progressToken): await server.request_context.session.send_progress_notification( progress_token=progress_token, progress=0.0, total=1.0 ) @@ -207,7 +228,7 @@ async def handle_call_tool( ) async def send_progress_updates(): - if not (enable_progress_notifications and server.request_context.meta.progressToken): + if not (mcp_config.enable_progress_notifications and server.request_context.meta.progressToken): return try: @@ -220,7 +241,7 @@ async def handle_call_tool( await asyncio.sleep(1.0) except asyncio.CancelledError: # Send final 100% progress - if enable_progress_notifications: + if mcp_config.enable_progress_notifications: await server.request_context.session.send_progress_notification( progress_token=progress_token, progress=1.0, total=1.0 ) @@ -228,17 +249,16 @@ async def handle_call_tool( db_service = get_db_service() collected_results = [] - async with db_service.with_session() as async_session: + async with db_service.with_session(): try: progress_task = asyncio.create_task(send_progress_updates()) try: - response = await build_flow( + response = await build_flow_and_stream( flow_id=UUID(name), inputs=input_request, background_tasks=background_tasks, current_user=current_user, - session=async_session, ) async for line in response.body_iterator: @@ -276,7 +296,7 @@ async def handle_call_tool( except Exception as e: context = server.request_context # Send error progress if there's an exception - if enable_progress_notifications and (progress_token := context.meta.progressToken): + if mcp_config.enable_progress_notifications and (progress_token := context.meta.progressToken): await server.request_context.session.send_progress_notification( progress_token=progress_token, progress=1.0, total=1.0 ) @@ -346,4 +366,8 @@ async def handle_sse(request: Request, current_user: Annotated[User, Depends(get 
@router.post("/") async def handle_messages(request: Request): - await sse.handle_post_message(request.scope, request.receive, request._send) + try: + await sse.handle_post_message(request.scope, request.receive, request._send) + except BrokenResourceError as e: + logger.info("MCP Server disconnected") + raise HTTPException(status_code=404, detail=f"MCP Server disconnected, error: {e}") from e diff --git a/src/backend/base/langflow/api/v1/voice_mode.py b/src/backend/base/langflow/api/v1/voice_mode.py new file mode 100644 index 000000000..05a1a2eee --- /dev/null +++ b/src/backend/base/langflow/api/v1/voice_mode.py @@ -0,0 +1,947 @@ +import asyncio +import base64 +import json +import os + +# For sync queue and thread +import queue +import threading +import traceback +import uuid +from collections import defaultdict +from datetime import datetime, timezone +from typing import Any +from uuid import UUID, uuid4 + +import numpy as np +import requests +import sqlalchemy.exc +import webrtcvad +import websockets +from cryptography.fernet import InvalidToken +from elevenlabs.client import ElevenLabs +from fastapi import APIRouter, BackgroundTasks, Security +from sqlalchemy import select +from starlette.websockets import WebSocket, WebSocketDisconnect + +from langflow.api.utils import CurrentActiveUser, DbSession +from langflow.api.v1.chat import build_flow_and_stream +from langflow.api.v1.schemas import InputValueRequest +from langflow.logging import logger +from langflow.memory import aadd_messagetables +from langflow.schema.properties import Properties +from langflow.services.auth.utils import api_key_header, api_key_query, api_key_security, get_current_user_by_jwt +from langflow.services.database.models.flow.model import Flow +from langflow.services.database.models.message.model import MessageTable +from langflow.services.deps import get_variable_service, session_scope +from langflow.utils.voice_utils import ( + BYTES_PER_24K_FRAME, + VAD_SAMPLE_RATE_16K, + 
resample_24k_to_16k, +) + +router = APIRouter(prefix="/voice", tags=["Voice"]) + +SILENCE_THRESHOLD = 0.1 +PREFIX_PADDING_MS = 100 +SILENCE_DURATION_MS = 100 +AUDIO_SAMPLE_THRESHOLD = 100 +SESSION_INSTRUCTIONS = """ +Your instructions will be divided into three mutually exclusive sections: "Permanent", "Default", and "Additional". +"Permanent" instructions are to never be overrided, superceded or otherwise ignored. +"Default" instructions are provided by default. They may never override "Permanent" + or "Additional" instructions, and they may likewise be superceded by those same other rules. +"Additional" instructions may be empty. When relevant, they override "Default" instructions, + but never "Permanent" instructions. + +[PERMANENT] The following instructions are to be considered "Permanent" +* When the user's query necessitates use of one of the enumerated tools, call the execute_flow + function to assist, and pass in the user's entire query as the input parameter, and use that + to craft your responses. +* No other function is allowed to be registered besides the execute_flow function + +[DEFAULT] The following instructions are to be considered only "Default" +* Converse with the user to assist with their question. +* Never provide URLs in repsonses, but you may use URLs in tool calls or when processing those + URLs' content. +* Always (and I mean *always*) let the user know before you call a function that you will be + doing so. +* Always update the user with the required information, when the function returns. +* Unless otherwise requested, only summarize the return results. Do not repeat everything. +* Always call the function again when requested, regardless of whether execute_flow previously + succeeded or failed. +* Never provide URLs in repsonses, but you may use URLs in tool calls or when processing those + URLs' content. 
+ +[ADDITIONAL] The following instructions are to be considered only "Additional" +""" + + +class VoiceConfig: + def __init__(self, session_id: str): + self.session_id = session_id + self.use_elevenlabs = False + self.elevenlabs_voice = "JBFqnCBsd6RMkjVDRZzb" + self.elevenlabs_model = "eleven_multilingual_v2" + self.elevenlabs_client = None + self.elevenlabs_key = None + self.barge_in_enabled = False + + self.default_openai_realtime_session = { + "modalities": ["text", "audio"], + "instructions": SESSION_INSTRUCTIONS, + "voice": "echo", + "temperature": 0.8, + "input_audio_format": "pcm16", + "output_audio_format": "pcm16", + "turn_detection": { + "type": "server_vad", + "threshold": SILENCE_THRESHOLD, + "prefix_padding_ms": PREFIX_PADDING_MS, + "silence_duration_ms": SILENCE_DURATION_MS, + }, + "input_audio_transcription": {"model": "whisper-1"}, + "tools": [], + "tool_choice": "auto", + } + self.openai_realtime_session: dict[str, Any] = {} + + def get_session_dict(self): + """Return a copy of the default session dictionary with current settings.""" + return dict(self.default_openai_realtime_session) + + +# Create a cache for voice configs +voice_config_cache: dict[str, VoiceConfig] = {} + + +def get_voice_config(session_id: str) -> VoiceConfig: + """Get or create a VoiceConfig instance for the given session_id.""" + if session_id is None: + msg = "session_id cannot be None" + raise ValueError(msg) + + if session_id not in voice_config_cache: + voice_config_cache[session_id] = VoiceConfig(session_id) + return voice_config_cache[session_id] + + +# Create a global dictionary to store queues for each session +message_queues: dict[str, asyncio.Queue] = defaultdict(asyncio.Queue) +# Track active message processing tasks +message_tasks: dict[str, asyncio.Task] = {} + + +async def get_flow_desc_from_db(flow_id: str) -> Flow: + """Get flow from database.""" + async with session_scope() as session: + stmt = select(Flow).where(Flow.id == UUID(flow_id)) + result = await 
session.exec(stmt) + flow = result.scalar_one_or_none() + if not flow: + error_message = f"Flow with id {flow_id} not found" + raise ValueError(error_message) + return flow.description + + +def pcm16_to_float_array(pcm_data): + values = np.frombuffer(pcm_data, dtype=np.int16).astype(np.float32) + return values / 32768.0 # Normalize to -1.0 to 1.0 + + +async def text_chunker_with_timeout(chunks, timeout=0.3): + """Async generator that takes an async iterable (of text pieces),. + + accumulates them and yields chunks without breaking sentences. + If no new text is received within 'timeout' seconds and there is + buffered text, it flushes that text. + """ + splitters = (".", ",", "?", "!", ";", ":", "—", "-", "(", ")", "[", "]", "}", " ") + buffer = "" + ait = chunks.__aiter__() + while True: + try: + text = await asyncio.wait_for(ait.__anext__(), timeout=timeout) + except asyncio.TimeoutError: + if buffer: + yield buffer + " " + buffer = "" + continue + except StopAsyncIteration: + break + if text is None: + if buffer: + yield buffer + " " + break + if buffer and buffer[-1] in splitters: + yield buffer + " " + buffer = text + elif text and text[0] in splitters: + yield buffer + text[0] + " " + buffer = text[1:] + else: + buffer += text + if buffer: + yield buffer + " " + + +async def queue_generator(queue: asyncio.Queue): + """Async generator that yields items from a queue.""" + while True: + item = await queue.get() + if item is None: + break + yield item + + +async def handle_function_call( + websocket: WebSocket, + openai_ws: websockets.WebSocketClientProtocol, + function_call: dict, + function_call_args: str, + flow_id: str, + background_tasks: BackgroundTasks, + current_user: CurrentActiveUser, + conversation_id: str, +): + """Handle function calls from the OpenAI API.""" + try: + args = json.loads(function_call_args) if function_call_args else {} + input_request = InputValueRequest( + input_value=args.get("input"), components=[], type="chat", 
session=conversation_id + ) + response = await build_flow_and_stream( + flow_id=UUID(flow_id), + inputs=input_request, + background_tasks=background_tasks, + current_user=current_user, + ) + + result = "" + async for line in response.body_iterator: + if not line: + continue + event_data = json.loads(line) + await websocket.send_json({"type": "flow.build.progress", "data": event_data}) + if event_data.get("event") == "end_vertex": + text_part = ( + event_data.get("data", {}) + .get("build_data", "") + .get("data", {}) + .get("results", {}) + .get("message", {}) + .get("text", "") + ) + result += text_part + function_output = { + "type": "conversation.item.create", + "item": { + "type": "function_call_output", + "call_id": function_call.get("call_id"), + "output": str(result), + }, + } + await openai_ws.send(json.dumps(function_output)) + await openai_ws.send(json.dumps({"type": "response.create"})) + except json.JSONDecodeError as e: + trace = traceback.format_exc() + logger.error(f"JSON decode error: {e!s}\ntrace: {trace}") + function_output = { + "type": "conversation.item.create", + "item": { + "type": "function_call_output", + "call_id": function_call.get("call_id"), + "output": f"Error parsing arguments: {e!s}", + }, + } + await openai_ws.send(json.dumps(function_output)) + except ValueError as e: + trace = traceback.format_exc() + logger.error(f"Value error: {e!s}\ntrace: {trace}") + function_output = { + "type": "conversation.item.create", + "item": { + "type": "function_call_output", + "call_id": function_call.get("call_id"), + "output": f"Error with input values: {e!s}", + }, + } + await openai_ws.send(json.dumps(function_output)) + except (ConnectionError, websockets.exceptions.WebSocketException) as e: + trace = traceback.format_exc() + logger.error(f"Connection error: {e!s}\ntrace: {trace}") + function_output = { + "type": "conversation.item.create", + "item": { + "type": "function_call_output", + "call_id": function_call.get("call_id"), + "output": 
f"Connection error: {e!s}", + }, + } + await openai_ws.send(json.dumps(function_output)) + except (KeyError, AttributeError, TypeError) as e: + logger.error(f"Error executing flow: {e}") + logger.error(traceback.format_exc()) + function_output = { + "type": "conversation.item.create", + "item": { + "type": "function_call_output", + "call_id": function_call.get("call_id"), + "output": f"Error executing flow: {e}", + }, + } + await openai_ws.send(json.dumps(function_output)) + + +# --- Synchronous text chunker using a standard queue --- +def sync_text_chunker(sync_queue_obj: queue.Queue, timeout: float = 0.3): + """Synchronous generator that reads text pieces from a sync queue. + + accumulates them and yields complete chunks. + """ + splitters = (".", ",", "?", "!", ";", ":", "—", "-", "(", ")", "[", "]", "}", " ") + buffer = "" + while True: + try: + text = sync_queue_obj.get(timeout=timeout) + except queue.Empty: + if buffer: + yield buffer + " " + buffer = "" + continue + if text is None: + if buffer: + yield buffer + " " + break + if buffer and buffer[-1] in splitters: + yield buffer + " " + buffer = text + elif text and text[0] in splitters: + yield buffer + text[0] + " " + buffer = text[1:] + else: + buffer += text + if buffer: + yield buffer + " " + + +@router.websocket("/ws/flow_as_tool/{flow_id}") +async def flow_as_tool_websocket_no_session( + client_websocket: WebSocket, + flow_id: str, + background_tasks: BackgroundTasks, + session: DbSession, +): + session_id = str(uuid4()) + await flow_as_tool_websocket( + client_websocket=client_websocket, + flow_id=flow_id, + background_tasks=background_tasks, + session=session, + session_id=session_id, + ) + + +@router.websocket("/ws/flow_as_tool/{flow_id}/{session_id}") +async def flow_as_tool_websocket( + client_websocket: WebSocket, + flow_id: str, + background_tasks: BackgroundTasks, + session: DbSession, + session_id: str, +): + """WebSocket endpoint registering the flow as a tool for real-time interaction.""" + 
try: + await client_websocket.accept() + voice_config = get_voice_config(session_id) + token = client_websocket.cookies.get("access_token_lf") + current_user = None + if token: + current_user = await get_current_user_by_jwt(token, session) + + if current_user is None: + current_user = await api_key_security(Security(api_key_query), Security(api_key_header)) + if current_user is None: + await client_websocket.send_json( + { + "type": "error", + "code": "langflow_auth", + "message": "You must pass a valid Langflow token or cookie.", + } + ) + return + + variable_service = get_variable_service() + try: + openai_key_value = await variable_service.get_variable( + user_id=current_user.id, name="OPENAI_API_KEY", field="openai_api_key", session=session + ) + openai_key = openai_key_value if openai_key_value is not None else os.getenv("OPENAI_API_KEY", "") + if not openai_key or openai_key == "dummy": + await client_websocket.send_json( + { + "type": "error", + "code": "api_key_missing", + "key_name": "OPENAI_API_KEY", + "message": "OpenAI API key not found. 
Please set your API key as an env var or a " + "global variable.", + } + ) + return + except Exception as e: # noqa: BLE001 + logger.error(f"Error with API key: {e}") + logger.error(traceback.format_exc()) + return + + try: + flow_description = await get_flow_desc_from_db(flow_id) + flow_tool = { + "name": "execute_flow", + "type": "function", + "description": flow_description or "Execute the flow with the given input", + "parameters": { + "type": "object", + "properties": {"input": {"type": "string", "description": "The input to send to the flow"}}, + "required": ["input"], + }, + } + except Exception as e: # noqa: BLE001 + await client_websocket.send_json({"error": f"Failed to load flow: {e!s}"}) + logger.error(f"Failed to load flow: {e}") + return + + url = "wss://api.openai.com/v1/realtime?model=gpt-4o-mini-realtime-preview" + headers = { + "Authorization": f"Bearer {openai_key}", + "OpenAI-Beta": "realtime=v1", + } + + def init_session_dict(): + session_dict = voice_config.get_session_dict() + session_dict["tools"] = [flow_tool] + return session_dict + + async with websockets.connect(url, extra_headers=headers) as openai_ws: + openai_realtime_session = init_session_dict() + session_update = {"type": "session.update", "session": openai_realtime_session} + await openai_ws.send(json.dumps(session_update)) + + # Setup for VAD processing. 
+ vad_queue: asyncio.Queue = asyncio.Queue() + vad_audio_buffer = bytearray() + bot_speaking_flag = [False] + vad = webrtcvad.Vad(mode=3) + + async def process_vad_audio() -> None: + nonlocal vad_audio_buffer + last_speech_time = datetime.now(tz=timezone.utc) + while True: + base64_data = await vad_queue.get() + raw_chunk_24k = base64.b64decode(base64_data) + vad_audio_buffer.extend(raw_chunk_24k) + has_speech = False + while len(vad_audio_buffer) >= BYTES_PER_24K_FRAME: + frame_24k = vad_audio_buffer[:BYTES_PER_24K_FRAME] + del vad_audio_buffer[:BYTES_PER_24K_FRAME] + try: + frame_16k = resample_24k_to_16k(frame_24k) + is_speech = vad.is_speech(frame_16k, VAD_SAMPLE_RATE_16K) + if is_speech: + has_speech = True + logger.trace("!", end="") + if bot_speaking_flag[0]: + await openai_ws.send(json.dumps({"type": "response.cancel"})) + bot_speaking_flag[0] = False + except Exception as e: # noqa: BLE001 + logger.error(f"[ERROR] VAD processing failed (ValueError): {e}") + continue + if has_speech: + last_speech_time = datetime.now(tz=timezone.utc) + logger.trace(".", end="") + else: + time_since_speech = (datetime.now(tz=timezone.utc) - last_speech_time).total_seconds() + if time_since_speech >= 1.0: + logger.trace("_", end="") + + shared_state = {"last_event_type": None, "event_count": 0} + + def log_event(event, _direction: str) -> None: + event_type = event["type"] + + # Ensure shared_state has necessary keys initialized + if "last_event_type" not in shared_state: + shared_state["last_event_type"] = None + if "event_count" not in shared_state: + shared_state["event_count"] = 0 + + if event_type != shared_state["last_event_type"]: + logger.debug(f"Event (session - {session_id}): {_direction} {event_type}") + shared_state["last_event_type"] = event_type + shared_state["event_count"] = 0 + + # Explicitly convert to integer if needed + current_count = int(shared_state["event_count"]) if shared_state["event_count"] is not None else 0 + + shared_state["event_count"] = 
current_count + 1 + + def send_event(websocket, event, loop, direction) -> None: + asyncio.run_coroutine_threadsafe( + websocket.send_json(event), + loop, + ).result() + log_event(event, direction) + + def pass_through(from_dict, to_dict, keys): + for key in keys: + if key in from_dict: + to_dict[key] = from_dict[key] + + def merge(from_dict, to_dict, keys): + for key in keys: + if key in from_dict: + if not isinstance(from_dict[key], str): + msg = f"Only string values are supported for merge. Issue with key: {key}" + raise ValueError(msg) + new_value = from_dict[key] + + if key not in to_dict: + to_dict[key] = new_value + else: + if not isinstance(to_dict[key], str): + msg = f"Only string values are supported for merge. Issue with key: {key}" + raise ValueError(msg) + old_value = to_dict[key] + + to_dict[key] = f"{old_value}\n{new_value}" + + def warn_if_present(config_dict, keys): + for key in keys: + if key in config_dict: + logger.warning(f"Removing key {key} from session.update.") + + def update_global_session(from_session): + # Create a new session dict instead of modifying global + new_session = init_session_dict() + pass_through( + from_session, + new_session, + ["voice", "temperature", "turn_detection", "input_audio_transcription"], + ) + merge(from_session, new_session, ["instructions"]) + warn_if_present( + from_session, ["modalities", "tools", "tool_choice", "input_audio_format", "output_audio_format"] + ) + return new_session + + # --- Spawn a text delta queue and task for TTS --- + text_delta_queue: asyncio.Queue = asyncio.Queue() + text_delta_task: asyncio.Task | None = None # Will hold our background task. + + async def process_text_deltas(async_q: asyncio.Queue): + """Transfer text deltas from the async queue to a synchronous queue. + + then run the ElevenLabs TTS call (which expects a sync generator) in a separate thread. 
+ """ + sync_q: queue.Queue = queue.Queue() + + async def transfer_text_deltas(): + while True: + item = await async_q.get() + sync_q.put(item) + if item is None: + break + + # Schedule the transfer task in the main event loop. + transfer_task = asyncio.create_task(transfer_text_deltas()) + + # Create the synchronous generator from the sync queue. + sync_gen = sync_text_chunker(sync_q, timeout=0.3) + elevenlabs_client = await get_or_create_elevenlabs_client(current_user.id, session) + if elevenlabs_client is None: + transfer_task.cancel() + return + # Capture the current event loop to schedule send operations. + main_loop = asyncio.get_running_loop() + + def tts_thread(): + # Create a new event loop for this thread. + new_loop = asyncio.new_event_loop() + asyncio.set_event_loop(new_loop) + + async def run_tts(): + try: + audio_stream = elevenlabs_client.generate( + voice=voice_config.elevenlabs_voice, + output_format="pcm_24000", + text=sync_gen, # synchronous generator expected by ElevenLabs + model=voice_config.elevenlabs_model, + voice_settings=None, + stream=True, + ) + for chunk in audio_stream: + base64_audio = base64.b64encode(chunk).decode("utf-8") + # Schedule sending the audio chunk in the main event loop. 
+ event = {"type": "response.audio.delta", "delta": base64_audio} + send_event(client_websocket, event, main_loop, "↓") + + event = {"type": "response.done"} + send_event(client_websocket, event, main_loop, "↓") + except Exception as e: # noqa: BLE001 + logger.error(f"Error in TTS processing (ValueError): {e}") + + new_loop.run_until_complete(run_tts()) + new_loop.close() + + threading.Thread(target=tts_thread, daemon=True).start() + + async def forward_to_openai() -> None: + nonlocal openai_realtime_session + try: + num_audio_samples = 0 # Initialize as an integer instead of None + while True: + message_text = await client_websocket.receive_text() + msg = json.loads(message_text) + if msg.get("type") == "input_audio_buffer.append": + logger.trace(f"buffer_id {msg.get('buffer_id', '')}") + base64_data = msg.get("audio", "") + if not base64_data: + continue + # Ensure we're adding to an integer + num_audio_samples += len(base64_data) + event = {"type": "input_audio_buffer.append", "audio": base64_data} + await openai_ws.send(json.dumps(event)) + log_event(event, "↑") + if voice_config.barge_in_enabled: + await vad_queue.put(base64_data) + elif msg.get("type") == "input_audio_buffer.commit": + if num_audio_samples > AUDIO_SAMPLE_THRESHOLD: + await openai_ws.send(message_text) + log_event(msg, "↑") + num_audio_samples = 0 + elif msg.get("type") == "langflow.elevenlabs.config": + logger.info(f"langflow.elevenlabs.config {msg}") + voice_config.use_elevenlabs = msg["enabled"] + voice_config.elevenlabs_voice = msg.get("voice_id", voice_config.elevenlabs_voice) + + # Update modalities based on TTS choice + modalities = ["text"] if voice_config.use_elevenlabs else ["audio", "text"] + openai_realtime_session["modalities"] = modalities + session_update = {"type": "session.update", "session": openai_realtime_session} + await openai_ws.send(json.dumps(session_update)) + log_event(session_update, "↑") + elif msg.get("type") == "session.update": + openai_realtime_session = 
update_global_session(msg["session"]) + session_update = {"type": "session.update", "session": openai_realtime_session} + await openai_ws.send(json.dumps(session_update)) + log_event(session_update, "↑") + else: + await openai_ws.send(message_text) + log_event(msg, "↑") + except (WebSocketDisconnect, websockets.ConnectionClosedOK, websockets.ConnectionClosedError): + pass + + async def forward_to_client() -> None: + nonlocal bot_speaking_flag, text_delta_queue, text_delta_task + function_call = None + function_call_args = "" + conversation_id = str(uuid4()) + # Store function call tasks to prevent garbage collection + function_call_tasks = [] + + try: + while True: + data = await openai_ws.recv() + event = json.loads(data) + event_type = event.get("type") + + do_forward = True + do_forward = do_forward and not (event_type == "response.done" and voice_config.use_elevenlabs) + do_forward = do_forward and event_type.find("flow.") != 0 + if do_forward: + await client_websocket.send_text(data) + + if event_type == "response.text.delta": + if voice_config.use_elevenlabs: + delta = event.get("delta", "") + await text_delta_queue.put(delta) + if text_delta_task is None: + # if text_delta_task is None or text_delta_task.done(): + text_delta_task = asyncio.create_task(process_text_deltas(text_delta_queue)) + elif event_type == "response.text.done": + if voice_config.use_elevenlabs: + await text_delta_queue.put(None) + if text_delta_task and not text_delta_task.done(): + await text_delta_task + text_delta_task = None + + try: + message_text = event.get("text", "") + await add_message_to_db(message_text, session, flow_id, session_id, "Machine", "AI") + except ValueError as e: + logger.error(f"Error saving message to database (ValueError): {e}") + logger.error(traceback.format_exc()) + except (KeyError, AttributeError, TypeError) as e: + # Replace blind Exception with specific exceptions + logger.error(f"Error saving message to database: {e}") + 
logger.error(traceback.format_exc()) + + elif event_type == "response.output_item.added": + bot_speaking_flag[0] = True + item = event.get("item", {}) + if item.get("type") == "function_call": + function_call = item + function_call_args = "" + elif event_type == "response.output_item.done": + try: + transcript = extract_transcript(event) + if transcript and transcript.strip(): + await add_message_to_db(transcript, session, flow_id, session_id, "Machine", "AI") + except ValueError as e: + logger.error(f"Error saving message to database (ValueError): {e}") + logger.error(traceback.format_exc()) + except (KeyError, AttributeError, TypeError) as e: + # Replace blind Exception with specific exceptions + logger.error(f"Error saving message to database: {e}") + logger.error(traceback.format_exc()) + bot_speaking_flag[0] = False + elif event_type == "response.function_call_arguments.delta": + function_call_args += event.get("delta", "") + elif event_type == "response.function_call_arguments.done": + if function_call: + # Create and store reference to the task + function_call_task = asyncio.create_task( + handle_function_call( + client_websocket, + openai_ws, + function_call, + function_call_args, + flow_id, + background_tasks, + current_user, + conversation_id, + ) + ) + # Store the task reference to prevent garbage collection + function_call_tasks.append(function_call_task) + # Clean up completed tasks periodically + function_call_tasks = [t for t in function_call_tasks if not t.done()] + function_call = None + function_call_args = "" + elif event_type == "response.audio.delta": + # there are no audio deltas from OpenAI if ElevenLabs is used (because modality = ["text"]). 
+ event.get("delta", "") + elif event_type == "conversation.item.input_audio_transcription.completed": + try: + message_text = event.get("transcript", "") + if message_text and message_text.strip(): + await add_message_to_db(message_text, session, flow_id, session_id, "User", "User") + except ValueError as e: + logger.error(f"Error saving message to database (ValueError): {e}") + logger.error(traceback.format_exc()) + except (KeyError, AttributeError, TypeError) as e: + # Replace blind Exception with specific exceptions + logger.error(f"Error saving message to database: {e}") + logger.error(traceback.format_exc()) + elif event_type == "error": + pass + else: + await client_websocket.send_text(data) + log_event(event, "↓") + + except (WebSocketDisconnect, websockets.ConnectionClosedOK, websockets.ConnectionClosedError): + pass + + # Fix for storing references to asyncio tasks + vad_task = None + if voice_config.barge_in_enabled: + # Store the task reference to prevent it from being garbage collected + vad_task = asyncio.create_task(process_vad_audio()) + + await asyncio.gather( + forward_to_openai(), + forward_to_client(), + ) + + except Exception as e: # noqa: BLE001 + logger.error(f"Value error: {e}") + logger.error(traceback.format_exc()) + finally: + # Ensure that the client websocket is closed. 
+ try: + await client_websocket.close() + except Exception as e: # noqa: BLE001 + logger.debug(f"{e} ") + logger.info("Client websocket cleanup complete.") + # Make sure to clean up the task + if vad_task and not vad_task.done(): + vad_task.cancel() + + +@router.get("/elevenlabs/voice_ids") +async def get_elevenlabs_voice_ids( + current_user: CurrentActiveUser, + session: DbSession, +): + """Get available voice IDs from ElevenLabs API.""" + try: + # Get or create the ElevenLabs client + elevenlabs_client = await get_or_create_elevenlabs_client(current_user.id, session) + if elevenlabs_client is None: + return {"error": "ElevenLabs API key not found or invalid"} + + voices_response = elevenlabs_client.voices.get_all() + voices = voices_response.voices + + # Fix for PERF401: Use list comprehension + return [ + { + "voice_id": voice.voice_id, + "name": voice.name, + } + for voice in voices + ] + except ValueError as e: + logger.error(f"Error fetching ElevenLabs voices (ValueError): {e}") + return {"error": str(e)} + except requests.RequestException as e: + logger.error(f"Error fetching ElevenLabs voices (RequestException): {e}") + return {"error": str(e)} + except (KeyError, AttributeError, TypeError) as e: + # More specific exceptions instead of blind Exception + logger.error(f"Error fetching ElevenLabs voices: {e}") + logger.error(traceback.format_exc()) + return {"error": str(e)} + + +# Replace ElevenLabsClient class with a better implementation +class ElevenLabsClientManager: + _instance = None + _api_key = None + + @classmethod + async def get_client(cls, user_id=None, session=None): + """Get or create an ElevenLabs client with the API key.""" + if cls._instance is None: + if cls._api_key is None and user_id and session: + variable_service = get_variable_service() + try: + cls._api_key = await variable_service.get_variable( + user_id=user_id, + name="ELEVENLABS_API_KEY", + field="elevenlabs_api_key", + session=session, + ) + except (InvalidToken, ValueError) as 
e: + logger.error(f"Error with ElevenLabs API key: {e}") + cls._api_key = os.getenv("ELEVENLABS_API_KEY", "") + if not cls._api_key: + logger.error("ElevenLabs API key not found") + return None + except (KeyError, AttributeError, sqlalchemy.exc.SQLAlchemyError) as e: + logger.error(f"Exception getting ElevenLabs API key: {e}") + return None + + if cls._api_key: + cls._instance = ElevenLabs(api_key=cls._api_key) + + return cls._instance + + +# Update the get_or_create_elevenlabs_client function to use the new manager +async def get_or_create_elevenlabs_client(user_id=None, session=None): + """Get or create an ElevenLabs client with the API key.""" + return await ElevenLabsClientManager.get_client(user_id, session) + + +# Global dictionary to track the last sender for each session (identified by queue_key) +last_sender_by_session: defaultdict[str, str | None] = defaultdict(lambda: None) + + +async def wait_for_sender_change(queue_key, current_sender, timeout=5): + """Wait until the last sender for this session is not the same as current_sender. + + or until the timeout expires. + """ + waited = 0 + interval = 0.05 + while last_sender_by_session[queue_key] == current_sender and waited < timeout: + await asyncio.sleep(interval) + waited += interval + + +async def add_message_to_db(message, session, flow_id, session_id, sender, sender_name): + """Enforce alternating sequence by checking the last sender. + + If two consecutive messages come from the same party (e.g. AI/AI), wait briefly. + """ + queue_key = f"{flow_id}:{session_id}" + + # If the incoming sender is the same as the last recorded sender, + # wait for a change (with a timeout as a fallback). 
+ if last_sender_by_session[queue_key] == sender: + await wait_for_sender_change(queue_key, sender, timeout=5) + last_sender_by_session[queue_key] = sender + + # Now proceed to create the message + message_obj = MessageTable( + text=message, + sender=sender, + sender_name=sender_name, + session_id=session_id, + files=[], + flow_id=uuid.UUID(flow_id) if isinstance(flow_id, str) else flow_id, + properties=Properties().model_dump(), + content_blocks=[], + category="audio", + ) + + await message_queues[queue_key].put(message_obj) + # Update last sender for this session + + if queue_key not in message_tasks or message_tasks[queue_key].done(): + message_tasks[queue_key] = asyncio.create_task(process_message_queue(queue_key, session)) + + +async def process_message_queue(queue_key, session): + """Process messages from the queue one by one.""" + try: + while True: + message = await message_queues[queue_key].get() + + try: + await aadd_messagetables([message], session) + logger.debug(f"Added message to DB: {message.text[:30]}...") + except ValueError as e: + logger.error(f"Error saving message to database (ValueError): {e}") + logger.error(traceback.format_exc()) + except sqlalchemy.exc.SQLAlchemyError as e: + logger.error(f"Error saving message to database (SQLAlchemyError): {e}") + logger.error(traceback.format_exc()) + except (KeyError, AttributeError, TypeError) as e: + # More specific exceptions instead of blind Exception + logger.error(f"Error saving message to database: {e}") + logger.error(traceback.format_exc()) + finally: + message_queues[queue_key].task_done() + + if message_queues[queue_key].empty(): + break + except Exception as e: # noqa: BLE001 + logger.debug(f"Message queue processor for {queue_key} was cancelled: {e}") + logger.error(traceback.format_exc()) + + +def extract_transcript(json_data): + try: + content_list = json_data.get("item", {}).get("content", []) + + for content_item in content_list: + if content_item.get("type") == "audio": + return 
content_item.get("transcript", "") + # Move this to the else block + except (KeyError, TypeError, AttributeError) as e: + logger.debug(f"Error extracting transcript: {e}") + return "" + else: + # This is now properly in the else block + return "" diff --git a/src/backend/base/langflow/base/mcp/util.py b/src/backend/base/langflow/base/mcp/util.py index e4b867b0d..5bcfb06cd 100644 --- a/src/backend/base/langflow/base/mcp/util.py +++ b/src/backend/base/langflow/base/mcp/util.py @@ -1,15 +1,21 @@ import asyncio from collections.abc import Awaitable, Callable +from typing import Any + +from pydantic import Field, create_model from langflow.helpers.base_model import BaseModel def create_tool_coroutine(tool_name: str, arg_schema: type[BaseModel], session) -> Callable[[dict], Awaitable]: - async def tool_coroutine(*args): - if len(args) == 0: - msg = f"at least one positional argument is required {args}" + async def tool_coroutine(*args, **kwargs): + fields = arg_schema.model_fields.keys() + expected_field_count = len(fields) + if len(args) + len(kwargs) != expected_field_count: + msg = f"{expected_field_count} arguments are required. Received: {args} {kwargs}" raise ValueError(msg) - arg_dict = dict(zip(arg_schema.model_fields.keys(), args, strict=False)) + arg_dict = dict(zip(fields, args, strict=False)) + arg_dict.update(kwargs) return await session.call_tool(tool_name, arguments=arg_dict) return tool_coroutine @@ -24,3 +30,43 @@ def create_tool_func(tool_name: str, session) -> Callable[..., str]: return loop.run_until_complete(session.call_tool(tool_name, arguments=kwargs)) return tool_func + + +def create_input_schema_from_json_schema(schema: dict[str, Any]) -> type[BaseModel]: + """Converts a JSON schema into a Pydantic model dynamically. + + :param schema: The JSON schema as a dictionary. + :return: A Pydantic model class. + """ + if schema.get("type") != "object": + msg = "JSON schema must be of type 'object' at the root level." 
+ raise ValueError(msg) + + fields = {} + properties = schema.get("properties", {}) + required_fields = set(schema.get("required", [])) + + for field_name, field_def in properties.items(): + # Extract type + field_type_str = field_def.get("type", "str") # Default to string type if not specified + field_type = { + "string": str, + "str": str, + "integer": int, + "int": int, + "number": float, + "boolean": bool, + "array": list, + "object": dict, + }.get(field_type_str, Any) + + # Extract description and default if present + field_metadata = {"description": field_def.get("description", "")} + if field_name not in required_fields: + field_metadata["default"] = field_def.get("default", None) + + # Create Pydantic field + fields[field_name] = (field_type, Field(**field_metadata)) + + # Dynamically create the model + return create_model("InputSchema", **fields) diff --git a/src/backend/base/langflow/base/memory/model.py b/src/backend/base/langflow/base/memory/model.py index c48227ca7..ae696f226 100644 --- a/src/backend/base/langflow/base/memory/model.py +++ b/src/backend/base/langflow/base/memory/model.py @@ -30,6 +30,7 @@ class LCChatMemoryComponent(Component): raise ValueError(msg) def build_base_memory(self) -> BaseChatMemory: + """Builds the base memory.""" return ConversationBufferMemory(chat_memory=self.build_message_history()) @abstractmethod diff --git a/src/backend/base/langflow/components/tools/mcp_sse.py b/src/backend/base/langflow/components/tools/mcp_sse.py index af9d857e2..5a184fe7f 100644 --- a/src/backend/base/langflow/components/tools/mcp_sse.py +++ b/src/backend/base/langflow/components/tools/mcp_sse.py @@ -3,11 +3,11 @@ import asyncio from contextlib import AsyncExitStack import httpx +from langchain_core.tools import StructuredTool from mcp import ClientSession, types from mcp.client.sse import sse_client -from langflow.base.mcp.util import create_tool_coroutine, create_tool_func -from langflow.components.tools.mcp_stdio import 
create_input_schema_from_json_schema +from langflow.base.mcp.util import create_input_schema_from_json_schema, create_tool_coroutine, create_tool_func from langflow.custom import Component from langflow.field_typing import Tool from langflow.io import MessageTextInput, Output @@ -32,6 +32,17 @@ class MCPSseClient: return response.headers.get("Location") # Return the redirect URL return url # Return the original URL if no redirect + async def _connect_with_timeout( + self, url: str, headers: dict[str, str] | None, timeout_seconds: int, sse_read_timeout_seconds: int + ): + """Connect to the SSE server with timeout.""" + sse_transport = await self.exit_stack.enter_async_context( + sse_client(url, headers, timeout_seconds, sse_read_timeout_seconds) + ) + self.sse, self.write = sse_transport + self.session = await self.exit_stack.enter_async_context(ClientSession(self.sse, self.write)) + await self.session.initialize() + async def connect_to_server( self, url: str, headers: dict[str, str] | None, timeout_seconds: int = 500, sse_read_timeout_seconds: int = 500 ): @@ -51,18 +62,7 @@ class MCPSseClient: except asyncio.TimeoutError as err: error_message = f"Connection to {url} timed out after {timeout_seconds} seconds" raise TimeoutError(error_message) from err - else: # Only executed if no TimeoutError - return response.tools - - async def _connect_with_timeout( - self, url: str, headers: dict[str, str] | None, timeout_seconds: int, sse_read_timeout_seconds: int - ): - sse_transport = await self.exit_stack.enter_async_context( - sse_client(url, headers, timeout_seconds, sse_read_timeout_seconds) - ) - self.sse, self.write = sse_transport - self.session = await self.exit_stack.enter_async_context(ClientSession(self.sse, self.write)) - await self.session.initialize() + return response.tools class MCPSse(Component): @@ -98,12 +98,12 @@ class MCPSse(Component): for tool in self.tools: args_schema = create_input_schema_from_json_schema(tool.inputSchema) tool_list.append( - 
Tool( + StructuredTool( name=tool.name, # maybe format this description=tool.description, args_schema=args_schema, - coroutine=create_tool_coroutine(tool.name, args_schema, self.client.session), func=create_tool_func(tool.name, self.client.session), + coroutine=create_tool_coroutine(tool.name, args_schema, self.client.session), ) ) diff --git a/src/backend/base/langflow/components/tools/mcp_stdio.py b/src/backend/base/langflow/components/tools/mcp_stdio.py index f1bfe8b6d..48535a00d 100644 --- a/src/backend/base/langflow/components/tools/mcp_stdio.py +++ b/src/backend/base/langflow/components/tools/mcp_stdio.py @@ -1,13 +1,12 @@ # from langflow.field_typing import Data import os from contextlib import AsyncExitStack -from typing import Any +from langchain_core.tools import StructuredTool from mcp import ClientSession, StdioServerParameters, types from mcp.client.stdio import stdio_client -from pydantic import BaseModel, Field, create_model -from langflow.base.mcp.util import create_tool_coroutine, create_tool_func +from langflow.base.mcp.util import create_input_schema_from_json_schema, create_tool_coroutine, create_tool_func from langflow.custom import Component from langflow.field_typing import Tool from langflow.io import MessageTextInput, Output @@ -36,46 +35,6 @@ class MCPStdioClient: return response.tools -def create_input_schema_from_json_schema(schema: dict[str, Any]) -> type[BaseModel]: - """Converts a JSON schema into a Pydantic model dynamically. - - :param schema: The JSON schema as a dictionary. - :return: A Pydantic model class. - """ - if schema.get("type") != "object": - msg = "JSON schema must be of type 'object' at the root level." 
- raise ValueError(msg) - - fields = {} - properties = schema.get("properties", {}) - required_fields = set(schema.get("required", [])) - - for field_name, field_def in properties.items(): - # Extract type - field_type_str = field_def.get("type", "str") # Default to string type if not specified - field_type = { - "string": str, - "str": str, - "integer": int, - "int": int, - "number": float, - "boolean": bool, - "array": list, - "object": dict, - }.get(field_type_str, Any) - - # Extract description and default if present - field_metadata = {"description": field_def.get("description", "")} - if field_name not in required_fields: - field_metadata["default"] = field_def.get("default", None) - - # Create Pydantic field - fields[field_name] = (field_type, Field(**field_metadata)) - - # Dynamically create the model - return create_model("InputSchema", **fields) - - class MCPStdio(Component): client = MCPStdioClient() tools = types.ListToolsResult @@ -111,11 +70,12 @@ class MCPStdio(Component): for tool in self.tools: args_schema = create_input_schema_from_json_schema(tool.inputSchema) tool_list.append( - Tool( + StructuredTool( name=tool.name, description=tool.description, - coroutine=create_tool_coroutine(tool.name, args_schema, self.client.session), + args_schema=args_schema, func=create_tool_func(tool.name, args_schema), + coroutine=create_tool_coroutine(tool.name, args_schema, self.client.session), ) ) self.tool_names = [tool.name for tool in self.tools] diff --git a/src/backend/base/langflow/custom/utils.py b/src/backend/base/langflow/custom/utils.py index 13afaaf84..13a0ac0ec 100644 --- a/src/backend/base/langflow/custom/utils.py +++ b/src/backend/base/langflow/custom/utils.py @@ -1,9 +1,11 @@ +# mypy: ignore-errors import ast import asyncio import contextlib import inspect import re import traceback +from pathlib import Path from typing import Any from uuid import UUID @@ -560,3 +562,134 @@ async def update_component_build_config( if 
inspect.iscoroutinefunction(component.update_build_config): return await component.update_build_config(build_config, field_value, field_name) return await asyncio.to_thread(component.update_build_config, build_config, field_value, field_name) + + +async def get_all_types_dict(components_paths: list[str]): + """Get all types dictionary with full component loading.""" + # This is the async version of the existing function + return await abuild_custom_components(components_paths=components_paths) + + +async def get_single_component_dict(component_type: str, component_name: str, components_paths: list[str]): + """Get a single component dictionary.""" + # For example, if components are loaded by importing Python modules: + for base_path in components_paths: + module_path = Path(base_path) / component_type / f"{component_name}.py" + if module_path.exists(): + # Try to import the module + module_name = f"langflow.components.{component_type}.{component_name}" + try: + # This is a simplified example - actual implementation may vary + import importlib.util + + spec = importlib.util.spec_from_file_location(module_name, module_path) + if spec and spec.loader: + module = importlib.util.module_from_spec(spec) + spec.loader.exec_module(module) + if hasattr(module, "template"): + return module.template + except ImportError as e: + logger.error(f"Import error loading component {module_path}: {e!s}") + except AttributeError as e: + logger.error(f"Attribute error loading component {module_path}: {e!s}") + except ValueError as e: + logger.error(f"Value error loading component {module_path}: {e!s}") + except (KeyError, IndexError) as e: + logger.error(f"Data structure error loading component {module_path}: {e!s}") + except RuntimeError as e: + logger.error(f"Runtime error loading component {module_path}: {e!s}") + logger.debug("Full traceback for runtime error", exc_info=True) + except OSError as e: + logger.error(f"OS error loading component {module_path}: {e!s}") + + # If we get 
here, the component wasn't found or couldn't be loaded + return None + + +async def load_custom_component(component_name: str, components_paths: list[str]): + """Load a custom component by name. + + Args: + component_name: Name of the component to load + components_paths: List of paths to search for components + """ + from langflow.interface.custom_component import get_custom_component_from_name + + try: + # First try to get the component from the registered components + component_class = get_custom_component_from_name(component_name) + if component_class: + # Define the function locally if it's not imported + def get_custom_component_template(component_cls): + """Get template for a custom component class.""" + # This is a simplified implementation - adjust as needed + if hasattr(component_cls, "get_template"): + return component_cls.get_template() + if hasattr(component_cls, "template"): + return component_cls.template + return None + + return get_custom_component_template(component_class) + + # If not found in registered components, search in the provided paths + for path in components_paths: + # Try to find the component in different category directories + base_path = Path(path) + if base_path.exists() and base_path.is_dir(): + # Search for the component in all subdirectories + for category_dir in base_path.iterdir(): + if category_dir.is_dir(): + component_file = category_dir / f"{component_name}.py" + if component_file.exists(): + # Try to import the module + module_name = f"langflow.components.{category_dir.name}.{component_name}" + try: + import importlib.util + + spec = importlib.util.spec_from_file_location(module_name, component_file) + if spec and spec.loader: + module = importlib.util.module_from_spec(spec) + spec.loader.exec_module(module) + if hasattr(module, "template"): + return module.template + if hasattr(module, "get_template"): + return module.get_template() + except ImportError as e: + logger.error(f"Import error loading component 
{component_file}: {e!s}") + logger.debug("Import error traceback", exc_info=True) + except AttributeError as e: + logger.error(f"Attribute error loading component {component_file}: {e!s}") + logger.debug("Attribute error traceback", exc_info=True) + except (ValueError, TypeError) as e: + logger.error(f"Value/Type error loading component {component_file}: {e!s}") + logger.debug("Value/Type error traceback", exc_info=True) + except (KeyError, IndexError) as e: + logger.error(f"Data structure error loading component {component_file}: {e!s}") + logger.debug("Data structure error traceback", exc_info=True) + except RuntimeError as e: + logger.error(f"Runtime error loading component {component_file}: {e!s}") + logger.debug("Runtime error traceback", exc_info=True) + except OSError as e: + logger.error(f"OS error loading component {component_file}: {e!s}") + logger.debug("OS error traceback", exc_info=True) + + except ImportError as e: + logger.error(f"Import error loading custom component {component_name}: {e!s}") + return None + except AttributeError as e: + logger.error(f"Attribute error loading custom component {component_name}: {e!s}") + return None + except ValueError as e: + logger.error(f"Value error loading custom component {component_name}: {e!s}") + return None + except (KeyError, IndexError) as e: + logger.error(f"Data structure error loading custom component {component_name}: {e!s}") + return None + except RuntimeError as e: + logger.error(f"Runtime error loading custom component {component_name}: {e!s}") + logger.debug("Full traceback for runtime error", exc_info=True) + return None + + # If we get here, the component wasn't found in any of the paths + logger.warning(f"Component {component_name} not found in any of the provided paths") + return None diff --git a/src/backend/base/langflow/graph/vertex/base.py b/src/backend/base/langflow/graph/vertex/base.py index 63473cdf5..f448b1307 100644 --- a/src/backend/base/langflow/graph/vertex/base.py +++ 
b/src/backend/base/langflow/graph/vertex/base.py @@ -702,6 +702,16 @@ class Vertex: event_manager: EventManager | None = None, **kwargs, ) -> Any: + # Add lazy loading check at the beginning + # Check if we need to fully load this component first + from langflow.interface.components import ensure_component_loaded + from langflow.services.deps import get_settings_service + + if get_settings_service().settings.lazy_load_components: + component_name = self.id.split("-")[0] + await ensure_component_loaded(self.vertex_type, component_name, get_settings_service()) + + # Continue with the original implementation async with self._lock: if self.state == VertexStates.INACTIVE: # If the vertex is inactive, return None diff --git a/src/backend/base/langflow/interface/components.py b/src/backend/base/langflow/interface/components.py index 904025a30..592415696 100644 --- a/src/backend/base/langflow/interface/components.py +++ b/src/backend/base/langflow/interface/components.py @@ -1,24 +1,280 @@ from __future__ import annotations import json -from typing import TYPE_CHECKING +from pathlib import Path +from typing import TYPE_CHECKING, Any from loguru import logger -from langflow.custom.utils import abuild_custom_components, build_custom_components +from langflow.custom.utils import abuild_custom_components if TYPE_CHECKING: from langflow.services.settings.service import SettingsService -async def aget_all_types_dict(components_paths): - """Get all types dictionary combining native and custom components.""" +# Create a class to manage component cache instead of using globals +class ComponentCache: + def __init__(self): + self.all_types_dict: dict[str, Any] | None = None + self.fully_loaded_components: dict[str, bool] = {} + + +# Singleton instance +component_cache = ComponentCache() + + +async def get_and_cache_all_types_dict( + settings_service: SettingsService, +): + """Get and cache the types dictionary, with partial loading support.""" + if component_cache.all_types_dict is 
None: + logger.debug("Building langchain types dict") + + if settings_service.settings.lazy_load_components: + # Partial loading mode - just load component metadata + logger.debug("Using partial component loading") + component_cache.all_types_dict = await aget_component_metadata(settings_service.settings.components_path) + else: + # Traditional full loading + component_cache.all_types_dict = await aget_all_types_dict(settings_service.settings.components_path) + + # Log loading stats + component_count = sum(len(comps) for comps in component_cache.all_types_dict.get("components", {}).values()) + logger.debug(f"Loaded {component_count} components") + + return component_cache.all_types_dict + + +async def aget_all_types_dict(components_paths: list[str]): + """Get all types dictionary with full component loading.""" return await abuild_custom_components(components_paths=components_paths) -def get_all_types_dict(components_paths): - """Get all types dictionary combining native and custom components.""" - return build_custom_components(components_paths=components_paths) +async def aget_component_metadata(components_paths: list[str]): + """Get just the metadata for all components without loading full templates.""" + # This builds a skeleton of the all_types_dict with just basic component info + + components_dict: dict = {"components": {}} + + # Get all component types + component_types = await discover_component_types(components_paths) + logger.debug(f"Discovered {len(component_types)} component types: {', '.join(component_types)}") + + # For each component type directory + for component_type in component_types: + components_dict["components"][component_type] = {} + + # Get list of components in this type + component_names = await discover_component_names(component_type, components_paths) + logger.debug(f"Found {len(component_names)} components for type {component_type}") + + # Create stub entries with just basic metadata + for name in component_names: + # Get minimal 
metadata for component + metadata = await get_component_minimal_metadata(component_type, name, components_paths) + + if metadata: + components_dict["components"][component_type][name] = metadata + # Mark as needing full loading + components_dict["components"][component_type][name]["lazy_loaded"] = True + + return components_dict + + +async def discover_component_types(components_paths: list[str]) -> list[str]: + """Discover available component types by scanning directories.""" + component_types: set[str] = set() + + for path in components_paths: + path_obj = Path(path) + if not path_obj.exists(): + continue + + for item in path_obj.iterdir(): + # Only include directories that don't start with _ or . + if item.is_dir() and not item.name.startswith(("_", ".")): + component_types.add(item.name) + + # Add known types that might not be in directories + standard_types = { + "agents", + "chains", + "embeddings", + "llms", + "memories", + "prompts", + "tools", + "retrievers", + "textsplitters", + "toolkits", + "utilities", + "vectorstores", + "custom_components", + "documentloaders", + "outputparsers", + "wrappers", + } + + component_types.update(standard_types) + + return sorted(component_types) + + +async def discover_component_names(component_type: str, components_paths: list[str]) -> list[str]: + """Discover component names for a specific type by scanning directories.""" + component_names: set[str] = set() + + for path in components_paths: + type_dir = Path(path) / component_type + + if type_dir.exists(): + for filename in type_dir.iterdir(): + # Get Python files that don't start with __ + if filename.name.endswith(".py") and not filename.name.startswith("__"): + component_name = filename.name[:-3] # Remove .py extension + component_names.add(component_name) + + return sorted(component_names) + + +async def get_component_minimal_metadata(component_type: str, component_name: str, components_paths: list[str]): + """Extract minimal metadata for a component without loading 
its full implementation.""" + # Create a more complete metadata structure that the UI needs + metadata = { + "display_name": component_name.replace("_", " ").title(), + "name": component_name, + "type": component_type, + "description": f"A {component_type} component (not fully loaded)", + "template": { + "_type": component_type, + "inputs": {}, + "outputs": {}, + "output_types": [], + "documentation": f"A {component_type} component", + "display_name": component_name.replace("_", " ").title(), + "base_classes": [component_type], + }, + } + + # Try to find the file to verify it exists + component_path = None + for path in components_paths: + candidate_path = Path(path) / component_type / f"{component_name}.py" + if candidate_path.exists(): + component_path = candidate_path + break + + if not component_path: + return None + + return metadata + + +async def ensure_component_loaded(component_type: str, component_name: str, settings_service: SettingsService): + """Ensure a component is fully loaded if it was only partially loaded.""" + # If already fully loaded, return immediately + component_key = f"{component_type}:{component_name}" + if component_key in component_cache.fully_loaded_components: + return + + # If we don't have a cache or the component doesn't exist in the cache, nothing to do + if ( + not component_cache.all_types_dict + or "components" not in component_cache.all_types_dict + or component_type not in component_cache.all_types_dict["components"] + or component_name not in component_cache.all_types_dict["components"][component_type] + ): + return + + # Check if component is marked for lazy loading + if component_cache.all_types_dict["components"][component_type][component_name].get("lazy_loaded", False): + logger.debug(f"Fully loading component {component_type}:{component_name}") + + # Load just this specific component + full_component = await load_single_component( + component_type, component_name, settings_service.settings.components_path + ) + + if 
full_component: + # Replace the stub with the fully loaded component + component_cache.all_types_dict["components"][component_type][component_name] = full_component + # Remove lazy_loaded flag if it exists + if "lazy_loaded" in component_cache.all_types_dict["components"][component_type][component_name]: + del component_cache.all_types_dict["components"][component_type][component_name]["lazy_loaded"] + + # Mark as fully loaded + component_cache.fully_loaded_components[component_key] = True + logger.debug(f"Component {component_type}:{component_name} fully loaded") + else: + logger.warning(f"Failed to fully load component {component_type}:{component_name}") + + +async def load_single_component(component_type: str, component_name: str, components_paths: list[str]): + """Load a single component fully.""" + from langflow.custom.utils import get_single_component_dict + + try: + # Delegate to a more specific function that knows how to load + # a single component of a specific type + return await get_single_component_dict(component_type, component_name, components_paths) + except (ImportError, ModuleNotFoundError) as e: + # Handle issues with importing the component or its dependencies + logger.error(f"Import error loading component {component_type}:{component_name}: {e!s}") + return None + except (AttributeError, TypeError) as e: + # Handle issues with component structure or type errors + logger.error(f"Component structure error for {component_type}:{component_name}: {e!s}") + return None + except FileNotFoundError as e: + # Handle missing files + logger.error(f"File not found for component {component_type}:{component_name}: {e!s}") + return None + except ValueError as e: + # Handle invalid values or configurations + logger.error(f"Invalid configuration for component {component_type}:{component_name}: {e!s}") + return None + except (KeyError, IndexError) as e: + # Handle data structure access errors + logger.error(f"Data structure error for component 
{component_type}:{component_name}: {e!s}") + return None + except RuntimeError as e: + # Handle runtime errors + logger.error(f"Runtime error loading component {component_type}:{component_name}: {e!s}") + logger.debug("Full traceback for runtime error", exc_info=True) + return None + except OSError as e: + # Handle OS-related errors (file system, permissions, etc.) + logger.error(f"OS error loading component {component_type}:{component_name}: {e!s}") + return None + + +# Also add a utility function to load specific component types +async def get_type_dict(component_type: str, settings_service: SettingsService | None = None): + """Get a specific component type dictionary, loading if needed.""" + if settings_service is None: + # Import here to avoid circular imports + from langflow.services.deps import get_settings_service + + settings_service = get_settings_service() + + # Make sure all_types_dict is loaded (at least partially) + if component_cache.all_types_dict is None: + await get_and_cache_all_types_dict(settings_service) + + # Check if component type exists in the cache + if ( + component_cache.all_types_dict + and "components" in component_cache.all_types_dict + and component_type in component_cache.all_types_dict["components"] + ): + # If in lazy mode, ensure all components of this type are fully loaded + if settings_service.settings.lazy_load_components: + for component_name in list(component_cache.all_types_dict["components"][component_type].keys()): + await ensure_component_loaded(component_type, component_name, settings_service) + + return component_cache.all_types_dict["components"][component_type] + + return {} # TypeError: unhashable type: 'list' @@ -43,7 +299,10 @@ async def aget_all_components(components_paths, *, as_dict=False): def get_all_components(components_paths, *, as_dict=False): """Get all components names combining native and custom components.""" - all_types_dict = get_all_types_dict(components_paths) + # Import here to avoid circular 
imports + from langflow.custom.utils import build_custom_components + + all_types_dict = build_custom_components(components_paths=components_paths) components = [] if not as_dict else {} for category in all_types_dict.values(): for component in category.values(): @@ -53,17 +312,3 @@ def get_all_components(components_paths, *, as_dict=False): else: components.append(component) return components - - -all_types_dict_cache = None - - -async def get_and_cache_all_types_dict( - settings_service: SettingsService, -): - global all_types_dict_cache # noqa: PLW0603 - if all_types_dict_cache is None: - logger.debug("Building langchain types dict") - all_types_dict_cache = await aget_all_types_dict(settings_service.settings.components_path) - - return all_types_dict_cache diff --git a/src/backend/base/langflow/logging/logger.py b/src/backend/base/langflow/logging/logger.py index 51a78a9d7..c3124844b 100644 --- a/src/backend/base/langflow/logging/logger.py +++ b/src/backend/base/langflow/logging/logger.py @@ -19,7 +19,7 @@ from typing_extensions import NotRequired, override from langflow.settings import DEV -VALID_LOG_LEVELS = ["DEBUG", "INFO", "WARNING", "ERROR", "CRITICAL"] +VALID_LOG_LEVELS = ["TRACE", "DEBUG", "INFO", "WARNING", "ERROR", "CRITICAL"] # Human-readable DEFAULT_LOG_FORMAT = ( "{time:YYYY-MM-DD HH:mm:ss} - {level: <8} - {module} - {message}" diff --git a/src/backend/base/langflow/main.py b/src/backend/base/langflow/main.py index 9371805a8..1b9152e77 100644 --- a/src/backend/base/langflow/main.py +++ b/src/backend/base/langflow/main.py @@ -125,19 +125,51 @@ def get_lifespan(*, fix_migration=False, version=None): temp_dirs: list[TemporaryDirectory] = [] sync_flows_from_fs_task = None try: + start_time = asyncio.get_event_loop().time() + + rprint("[bold blue]Initializing services[/bold blue]") await initialize_services(fix_migration=fix_migration) + rprint(f"✓ Services initialized in {asyncio.get_event_loop().time() - start_time:.2f}s") + + current_time = 
asyncio.get_event_loop().time() + rprint("[bold blue]Setting up LLM caching[/bold blue]") setup_llm_caching() + rprint(f"✓ LLM caching setup in {asyncio.get_event_loop().time() - current_time:.2f}s") + + current_time = asyncio.get_event_loop().time() + rprint("[bold blue]Initializing super user[/bold blue]") await initialize_super_user_if_needed() + rprint(f"✓ Super user initialized in {asyncio.get_event_loop().time() - current_time:.2f}s") + + current_time = asyncio.get_event_loop().time() + rprint("[bold blue]Loading bundles[/bold blue]") temp_dirs, bundles_components_paths = await load_bundles_with_error_handling() get_settings_service().settings.components_path.extend(bundles_components_paths) + rprint(f"✓ Bundles loaded in {asyncio.get_event_loop().time() - current_time:.2f}s") + + current_time = asyncio.get_event_loop().time() + rprint("[bold blue]Caching types[/bold blue]") all_types_dict = await get_and_cache_all_types_dict(get_settings_service()) + rprint(f"✓ Types cached in {asyncio.get_event_loop().time() - current_time:.2f}s") + + current_time = asyncio.get_event_loop().time() + rprint("[bold blue]Creating/updating starter projects[/bold blue]") await create_or_update_starter_projects(all_types_dict) + rprint(f"✓ Starter projects updated in {asyncio.get_event_loop().time() - current_time:.2f}s") + telemetry_service.start() + + current_time = asyncio.get_event_loop().time() + rprint("[bold blue]Loading flows[/bold blue]") await load_flows_from_directory() sync_flows_from_fs_task = asyncio.create_task(sync_flows_from_fs()) queue_service = get_queue_service() if not queue_service.is_started(): # Start if not already started queue_service.start() + rprint(f"✓ Flows loaded in {asyncio.get_event_loop().time() - current_time:.2f}s") + + total_time = asyncio.get_event_loop().time() - start_time + rprint(f"[bold green]✓ Total initialization time: {total_time:.2f}s[/bold green]") yield except Exception as exc: @@ -166,6 +198,7 @@ def create_app(): __version__ = 
get_version_info()["version"] + rprint("configuring") configure() lifespan = get_lifespan(version=__version__) app = FastAPI(lifespan=lifespan, title="Langflow", version=__version__) diff --git a/src/backend/base/langflow/services/settings/base.py b/src/backend/base/langflow/services/settings/base.py index c2e930f42..77dec1817 100644 --- a/src/backend/base/langflow/services/settings/base.py +++ b/src/backend/base/langflow/services/settings/base.py @@ -238,6 +238,9 @@ class Settings(BaseSettings): Default is 24 hours (86400 seconds). Minimum is 600 seconds (10 minutes).""" event_delivery: Literal["polling", "streaming"] = "polling" """How to deliver build events to the frontend. Can be 'polling' or 'streaming'.""" + lazy_load_components: bool = False + """If set to True, Langflow will only partially load components at startup and fully load them on demand. + This significantly reduces startup time but may cause a slight delay when a component is first used.""" @field_validator("dev") @classmethod diff --git a/src/backend/base/langflow/utils/voice_utils.py b/src/backend/base/langflow/utils/voice_utils.py new file mode 100644 index 000000000..6afa8897f --- /dev/null +++ b/src/backend/base/langflow/utils/voice_utils.py @@ -0,0 +1,92 @@ +import asyncio +import base64 +from pathlib import Path + +import numpy as np +from scipy.signal import resample + +from langflow.logging import logger + +SAMPLE_RATE_24K = 24000 +VAD_SAMPLE_RATE_16K = 16000 +FRAME_DURATION_MS = 20 +BYTES_PER_SAMPLE = 2 + +BYTES_PER_24K_FRAME = int(SAMPLE_RATE_24K * FRAME_DURATION_MS / 1000) * BYTES_PER_SAMPLE +BYTES_PER_16K_FRAME = int(VAD_SAMPLE_RATE_16K * FRAME_DURATION_MS / 1000) * BYTES_PER_SAMPLE + + +def resample_24k_to_16k(frame_24k_bytes): + """Resample a 20ms frame from 24kHz to 16kHz. 
+ + Args: + frame_24k_bytes: A bytes object containing 20ms of 24kHz audio (960 bytes) + + Returns: + A bytes object containing 20ms of 16kHz audio (640 bytes) + + Raises: + ValueError: If the input frame is not exactly 960 bytes + """ + if len(frame_24k_bytes) != BYTES_PER_24K_FRAME: + msg = f"Expected exactly {BYTES_PER_24K_FRAME} bytes for 24kHz frame, got {len(frame_24k_bytes)}" + raise ValueError(msg) + + # Convert bytes to numpy array of int16 + frame_24k = np.frombuffer(frame_24k_bytes, dtype=np.int16) + + # Resample from 24kHz to 16kHz (2/3 ratio) + # For a 20ms frame, we go from 480 samples to 320 samples + frame_16k = resample(frame_24k, int(len(frame_24k) * 2 / 3)) + + # Convert back to int16 and then to bytes + frame_16k = frame_16k.astype(np.int16) + return frame_16k.tobytes() + + +# def resample_24k_to_16k(frame_24k_bytes: bytes) -> bytes: +# """ +# Convert one 20ms chunk (960 bytes @ 24kHz) to 20ms @ 16kHz (640 bytes). +# Raises ValueError if the frame is not exactly 960 bytes. 
+# if len(frame_24k_bytes) != BYTES_PER_24K_FRAME: +# raise ValueError( +# f"Expected exactly {BYTES_PER_24K_FRAME} bytes for a 20ms 24k frame, " +# f"but got {len(frame_24k_bytes)}" +# ) +# # Convert bytes -> int16 array (480 samples) +# samples_24k = np.frombuffer(frame_24k_bytes, dtype=np.int16) +# +# # Resample 24k => 16k (ratio=2/3) +# # Should get 320 samples out if the chunk was exactly 480 samples in +# samples_16k = resample_poly(samples_24k, up=2, down=3) +# +# # Round & convert to int16 +# samples_16k = np.rint(samples_16k).astype(np.int16) +# +# # Convert back to bytes +# frame_16k_bytes = samples_16k.tobytes() +# if len(frame_16k_bytes) != BYTES_PER_16K_FRAME: +# raise ValueError( +# f"Expected exactly {BYTES_PER_16K_FRAME} bytes after resampling " +# f"to 20ms@16kHz, got {len(frame_16k_bytes)}" +# ) +# return frame_16k_bytes +# + + +async def write_audio_to_file(audio_base64: str, filename: str = "output_audio.raw") -> None: + """Decode the base64-encoded audio and write (append) it to a file asynchronously.""" + try: + audio_bytes = base64.b64decode(audio_base64) + # Use asyncio.to_thread to perform file I/O without blocking the event loop + await asyncio.to_thread(_write_bytes_to_file, audio_bytes, filename) + logger.info(f"Wrote {len(audio_bytes)} bytes to {filename}") + except (OSError, base64.binascii.Error) as e: # type: ignore[attr-defined] + logger.error(f"Error writing audio to file: {e}") + + +def _write_bytes_to_file(data: bytes, filename: str) -> None: + """Helper function to write bytes to a file using a context manager.""" + with Path(filename).open("ab") as f: + f.write(data) diff --git a/src/backend/base/pyproject.toml b/src/backend/base/pyproject.toml index af26c5a91..2b554477e 100644 --- a/src/backend/base/pyproject.toml +++ b/src/backend/base/pyproject.toml @@ -83,6 +83,7 @@ dependencies = [ "greenlet>=3.1.1", "jsonquerylang>=1.1.1", "sqlalchemy[aiosqlite]>=2.0.38,<3.0.0", + "elevenlabs>=1.54.0", ] [dependency-groups] diff
--git a/src/backend/tests/data/debug_incoming_24k.raw b/src/backend/tests/data/debug_incoming_24k.raw new file mode 100644 index 000000000..efc09ee4e Binary files /dev/null and b/src/backend/tests/data/debug_incoming_24k.raw differ diff --git a/src/backend/tests/unit/test_voice_mode.py b/src/backend/tests/unit/test_voice_mode.py new file mode 100644 index 000000000..320eb18a5 --- /dev/null +++ b/src/backend/tests/unit/test_voice_mode.py @@ -0,0 +1,101 @@ +import numpy as np +import pytest +import webrtcvad +from langflow.utils.voice_utils import ( + BYTES_PER_16K_FRAME, + BYTES_PER_24K_FRAME, + SAMPLE_RATE_24K, + VAD_SAMPLE_RATE_16K, + resample_24k_to_16k, +) + + +def test_resample_24k_to_16k_valid_frame(): + """Test that valid 960-byte frames (20ms @ 24kHz) resample to 640 bytes (20ms @ 16kHz).""" + # Generate a fake 20ms @ 24kHz frame (960 bytes) + duration_samples_24k = int(0.02 * SAMPLE_RATE_24K) # 480 samples + # Use the newer numpy random Generator + rng = np.random.default_rng() + fake_frame_24k = (rng.random(duration_samples_24k) * 32767).astype(np.int16) + frame_24k_bytes = fake_frame_24k.tobytes() + + assert len(frame_24k_bytes) == BYTES_PER_24K_FRAME # 960 + + # Resample + frame_16k_bytes = resample_24k_to_16k(frame_24k_bytes) + + # Check length after resampling + assert len(frame_16k_bytes) == BYTES_PER_16K_FRAME # 640 + + +def test_resample_24k_to_16k_invalid_frame(): + """Test that passing an invalid size frame raises a ValueError.""" + invalid_frame = b"\x00\x01" * 100 # only 200 bytes, not 960 + with pytest.raises(ValueError, match="Expected exactly"): + _ = resample_24k_to_16k(invalid_frame) + + +def test_webrtcvad_silence_detection(): + """Make sure that passing all-zero frames leads to is_speech == False.""" + vad = webrtcvad.Vad(mode=0) + + # Generate 1 second of silence @16k, chunk it in 20ms frames + num_samples = VAD_SAMPLE_RATE_16K # 1 second + silent_audio = np.zeros(num_samples, dtype=np.int16).tobytes() + + frame_size = 
BYTES_PER_16K_FRAME # 640 + num_frames = len(silent_audio) // frame_size + + speech_frames = 0 + for i in range(num_frames): + frame_16k = silent_audio[i * frame_size : (i + 1) * frame_size] + + is_speech = vad.is_speech(frame_16k, VAD_SAMPLE_RATE_16K) + if is_speech: + speech_frames += 1 + + # Expect zero frames labeled as speech + assert speech_frames == 0 + + +def test_webrtcvad_with_real_data(): + """End-to-end test. + + - Generate synthetic 24kHz audio + - Break into 20ms frames + - Resample to 16k + - Check how many frames VAD detects as speech. + This test is approximate, since random audio won't always be "speech." + """ + # Instead of reading from a file, generate synthetic audio + # Create 1 second of random audio data at 24kHz + num_samples = SAMPLE_RATE_24K # 1 second + rng = np.random.default_rng(seed=42) # Use a fixed seed for reproducibility + + # Generate random audio (this won't be detected as speech, but that's fine for testing) + raw_data_24k = (rng.random(num_samples) * 32767).astype(np.int16).tobytes() + + # We'll chunk into 20ms frames (960 bytes each) + frame_size_24k = BYTES_PER_24K_FRAME # 960 + total_frames = len(raw_data_24k) // frame_size_24k + + vad = webrtcvad.Vad(mode=2) + + resampled_all = bytearray() + speech_count = 0 + for i in range(total_frames): + frame_24k = raw_data_24k[i * frame_size_24k : (i + 1) * frame_size_24k] + frame_16k = resample_24k_to_16k(frame_24k) + + resampled_all.extend(frame_16k) # Append to our buffer + + is_speech = vad.is_speech(frame_16k, VAD_SAMPLE_RATE_16K) + if is_speech: + speech_count += 1 + + # For random noise, we expect very few frames to be detected as speech + # We're not making a strict assertion, just verifying the process works + assert len(resampled_all) == (total_frames * BYTES_PER_16K_FRAME) + + # Log the speech detection rate + speech_count / total_frames if total_frames > 0 else 0 diff --git a/src/frontend/package-lock.json b/src/frontend/package-lock.json index 1ac06ef0a..0d1a22df8 
100644 --- a/src/frontend/package-lock.json +++ b/src/frontend/package-lock.json @@ -1005,14 +1005,14 @@ }, "node_modules/@esbuild/darwin-arm64": { "version": "0.21.5", - "resolved": "https://registry.npmjs.org/@esbuild/darwin-arm64/-/darwin-arm64-0.21.5.tgz", - "integrity": "sha512-DwqXqZyuk5AiWWf3UfLiRDJ5EDd49zg6O9wclZ7kUMv2WRFr4HKjXp/5t8JZ11QbQfUS6/cRCKGwYhtNAY88kQ==", + "resolved": "https://registry.npmjs.org/@esbuild/linux-x64/-/linux-x64-0.21.5.tgz", + "integrity": "sha512-1rYdTpyv03iycF1+BhzrzQJCdOuAOtaqHTWJZCWvijKD2N5Xu0TtVC8/+1faWqcP9iBCWOmjmhoH94dH82BxPQ==", "cpu": [ - "arm64" + "x64" ], "optional": true, "os": [ - "darwin" + "linux" ], "engines": { "node": ">=12" @@ -2007,14 +2007,14 @@ }, "node_modules/@million/lint/node_modules/@esbuild/darwin-arm64": { "version": "0.20.2", - "resolved": "https://registry.npmjs.org/@esbuild/darwin-arm64/-/darwin-arm64-0.20.2.tgz", - "integrity": "sha512-4J6IRT+10J3aJH3l1yzEg9y3wkTDgDk7TSDFX+wKFiWjqWp/iCfLIYzGyasx9l0SAFPT1HwSCR+0w/h1ES/MjA==", + "resolved": "https://registry.npmjs.org/@esbuild/linux-x64/-/linux-x64-0.20.2.tgz", + "integrity": "sha512-1MdwI6OOTsfQfek8sLwgyjOXAu+wKhLEoaOLTjbijk6E2WONYpH9ZU2mNtR+lZ2B4uwr+usqGuVfFT9tMtGvGw==", "cpu": [ - "arm64" + "x64" ], "optional": true, "os": [ - "darwin" + "linux" ], "engines": { "node": ">=12" @@ -2415,14 +2415,14 @@ }, "node_modules/@napi-rs/nice-darwin-arm64": { "version": "1.0.1", - "resolved": "https://registry.npmjs.org/@napi-rs/nice-darwin-arm64/-/nice-darwin-arm64-1.0.1.tgz", - "integrity": "sha512-91k3HEqUl2fsrz/sKkuEkscj6EAj3/eZNCLqzD2AA0TtVbkQi8nqxZCZDMkfklULmxLkMxuUdKe7RvG/T6s2AA==", + "resolved": "https://registry.npmjs.org/@napi-rs/nice-linux-x64-gnu/-/nice-linux-x64-gnu-1.0.1.tgz", + "integrity": "sha512-XQAJs7DRN2GpLN6Fb+ZdGFeYZDdGl2Fn3TmFlqEL5JorgWKrQGRUrpGKbgZ25UeZPILuTKJ+OowG2avN8mThBA==", "cpu": [ - "arm64" + "x64" ], "optional": true, "os": [ - "darwin" + "linux" ], "engines": { "node": ">= 10" @@ -4122,11 +4122,11 @@ "resolved": 
"https://registry.npmjs.org/@rollup/rollup-darwin-arm64/-/rollup-darwin-arm64-4.36.0.tgz", "integrity": "sha512-JQ1Jk5G4bGrD4pWJQzWsD8I1n1mgPXq33+/vP4sk8j/z/C2siRuxZtaUA7yMTf71TCZTZl/4e1bfzwUmFb3+rw==", "cpu": [ - "arm64" + "x64" ], "optional": true, "os": [ - "darwin" + "linux" ] }, "node_modules/@rollup/rollup-darwin-x64": { @@ -4699,12 +4699,12 @@ "resolved": "https://registry.npmjs.org/@swc/core-darwin-arm64/-/core-darwin-arm64-1.11.11.tgz", "integrity": "sha512-vJcjGVDB8cZH7zyOkC0AfpFYI/7GHKG0NSsH3tpuKrmoAXJyCYspKPGid7FT53EAlWreN7+Pew+bukYf5j+Fmg==", "cpu": [ - "arm64" + "x64" ], "dev": true, "optional": true, "os": [ - "darwin" + "linux" ], "engines": { "node": ">=10" @@ -8449,19 +8449,6 @@ "optional": true, "peer": true }, - "node_modules/fsevents": { - "version": "2.3.2", - "resolved": "https://registry.npmjs.org/fsevents/-/fsevents-2.3.2.tgz", - "integrity": "sha512-xiqMQR4xAeHTuB9uWm+fFRcIOgKBMiOBP+eXiyT7jsgVCq1bkVygt00oASowB7EdtpOHaaPgKt812P9ab+DDKA==", - "hasInstallScript": true, - "optional": true, - "os": [ - "darwin" - ], - "engines": { - "node": "^8.16.0 || ^10.6.0 || >=11.0.0" - } - }, "node_modules/function-bind": { "version": "1.1.2", "resolved": "https://registry.npmjs.org/function-bind/-/function-bind-1.1.2.tgz", @@ -15644,19 +15631,6 @@ } } }, - "node_modules/vite/node_modules/fsevents": { - "version": "2.3.3", - "resolved": "https://registry.npmjs.org/fsevents/-/fsevents-2.3.3.tgz", - "integrity": "sha512-5xoDfX+fL7faATnagmWPpbFtwh/R77WmMMqqHGS65C3vvB0YHrgF+B1YmZ3441tMj5n63k0212XNoJwzlhffQw==", - "hasInstallScript": true, - "optional": true, - "os": [ - "darwin" - ], - "engines": { - "node": "^8.16.0 || ^10.6.0 || >=11.0.0" - } - }, "node_modules/w3c-keyname": { "version": "2.2.8", "resolved": "https://registry.npmjs.org/w3c-keyname/-/w3c-keyname-2.2.8.tgz", diff --git a/src/frontend/src/components/core/GlobalVariableModal/GlobalVariableModal.tsx b/src/frontend/src/components/core/GlobalVariableModal/GlobalVariableModal.tsx index 
f1da51631..6932af455 100644 --- a/src/frontend/src/components/core/GlobalVariableModal/GlobalVariableModal.tsx +++ b/src/frontend/src/components/core/GlobalVariableModal/GlobalVariableModal.tsx @@ -166,7 +166,7 @@ export default function GlobalVariableModal({ -
+
{option}
- - onRemove(e as unknown as React.MouseEvent) - } - data-testid="remove-icon-badge" - /> +
+ + onRemove(e as unknown as React.MouseEvent) + } + data-testid="remove-icon-badge" + /> +
); @@ -71,17 +73,24 @@ const CommandItemContent = ({ isSelected, optionButton, nodeStyle, + commandWidth, }: { option: string; isSelected: boolean; optionButton: (option: string) => ReactNode; nodeStyle?: string; + commandWidth?: string; }) => (
-
+
{option}
@@ -119,6 +128,7 @@ const getInputClassName = ( disabled: boolean, password: boolean, selectedOptions: string[], + blockAddNewGlobalVariable: boolean = false, ) => { return cn( "popover-input nodrag w-full truncate px-1 pr-4", @@ -127,6 +137,7 @@ const getInputClassName = ( disabled && "disabled:text-muted disabled:opacity-100 placeholder:disabled:text-muted-foreground", password && "text-clip pr-14", + blockAddNewGlobalVariable && "text-clip pr-8", selectedOptions?.length >= 0 && "cursor-default", ); }; @@ -173,6 +184,8 @@ const CustomInputPopover = ({ optionButton, autoFocus, popoverWidth, + commandWidth, + blockAddNewGlobalVariable, }) => { const [isFocused, setIsFocused] = useState(false); const memoizedOptions = useMemo(() => new Set(options), [options]); @@ -230,7 +243,11 @@ const CustomInputPopover = ({
) : selectedOption?.length > 0 ? ( -
+
handleRemoveOption(selectedOption, e)} @@ -266,6 +283,7 @@ const CustomInputPopover = ({ disabled, password, selectedOptions, + blockAddNewGlobalVariable, )} placeholder={ selectedOptions?.length > 0 || selectedOption ? "" : placeholder @@ -318,6 +336,7 @@ const CustomInputPopover = ({ } optionButton={optionButton} nodeStyle={nodeStyle} + commandWidth={commandWidth} /> ))} diff --git a/src/frontend/src/components/core/parameterRenderComponent/components/inputComponent/index.tsx b/src/frontend/src/components/core/parameterRenderComponent/components/inputComponent/index.tsx index 8d9ac3db4..ec7302279 100644 --- a/src/frontend/src/components/core/parameterRenderComponent/components/inputComponent/index.tsx +++ b/src/frontend/src/components/core/parameterRenderComponent/components/inputComponent/index.tsx @@ -40,6 +40,8 @@ export default function InputComponent({ nodeStyle, isToolMode, popoverWidth, + commandWidth, + blockAddNewGlobalVariable = false, }: InputComponentType): JSX.Element { const [pwdVisible, setPwdVisible] = useState(false); const refInput = useRef(null); @@ -151,54 +153,57 @@ export default function InputComponent({ optionsPlaceholder={optionsPlaceholder} nodeStyle={nodeStyle} popoverWidth={popoverWidth} + commandWidth={commandWidth} + blockAddNewGlobalVariable={blockAddNewGlobalVariable} /> )} )} - {(setSelectedOption || setSelectedOptions) && ( - - - - )} + > +