From d461074107fd7d878d49c7733dbedd81e4402481 Mon Sep 17 00:00:00 2001 From: Edwin Jose Date: Mon, 12 May 2025 15:51:40 -0400 Subject: [PATCH] feat: add validation for node js for npx based mcp command (#7907) * add validation for node js for npx based packages * [autofix.ci] apply automated fixes * Update mcp_component.py * Update mcp_component.py * Update mcp_component.py * Update util.py * [autofix.ci] apply automated fixes * Update util.py * fix format errors * merge * fix async and format issues * wacky idea * bug: error on match only * [autofix.ci] apply automated fixes * fix error: Name "env" already defined on line 225 [no-redef] * refactor: update MCPStdioClient to use async file handling for stderr logging - Replaced synchronous tempfile usage with async version using aiofiles. - Ensured proper type casting for the temporary file name. - Updated flush operation to be asynchronous for improved performance. --------- Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com> Co-authored-by: phact Co-authored-by: Gabriel Luiz Freitas Almeida --- src/backend/base/langflow/base/mcp/util.py | 116 ++++++++++++++++-- .../components/tools/mcp_component.py | 17 ++- 2 files changed, 120 insertions(+), 13 deletions(-) diff --git a/src/backend/base/langflow/base/mcp/util.py b/src/backend/base/langflow/base/mcp/util.py index 71687856f..03bd2904f 100644 --- a/src/backend/base/langflow/base/mcp/util.py +++ b/src/backend/base/langflow/base/mcp/util.py @@ -1,11 +1,15 @@ import asyncio +import contextlib import os +import platform from collections.abc import Awaitable, Callable from contextlib import AsyncExitStack -from typing import Any +from pathlib import Path +from typing import Any, cast from urllib.parse import urlparse from uuid import UUID +import aiofiles import httpx from httpx import codes as httpx_codes from loguru import logger @@ -213,6 +217,9 @@ class MCPStdioClient: def __init__(self): self.session: ClientSession | None = None self.exit_stack = AsyncExitStack() + self.max_retries = 1 + self.retry_delay = 1.0 # seconds + self.timeout_seconds = 30 # default timeout async def connect_to_server(self, command_str: str, env: list[str] | None = None): env_dict: dict[str, str] = {} @@ -224,17 +231,102 @@ class MCPStdioClient: raise ValueError(msg) env_dict[var.split("=")[0]] = var.split("=")[1] command = command_str.split(" ") - server_params = StdioServerParameters( - command=command[0], - args=command[1:], - env={"DEBUG": "true", "PATH": os.environ["PATH"], **(env_dict or {})}, - ) - stdio_transport = await self.exit_stack.enter_async_context(stdio_client(server_params)) - self.stdio, self.write = stdio_transport - self.session = await self.exit_stack.enter_async_context(ClientSession(self.stdio, self.write)) - await self.session.initialize() - response = await self.session.list_tools() - return response.tools + server_params = None + env_data: dict[str, str] = {"DEBUG": "true", "PATH": os.environ["PATH"], **(env_dict or {})} + + # Create platform-specific command wrapper + if platform.system() == "Windows": + # For Windows, use cmd.exe with error reporting + server_params = StdioServerParameters( + command="cmd", + args=[ + "/c", + f"{command[0]} {' '.join(command[1:])} || echo Command failed with exit code %errorlevel% 1>&2", + ], + env=env_data, + ) + else: + # For Unix-like systems, use bash with error reporting + server_params = StdioServerParameters( + command="bash", + args=["-c", f"{command_str} || echo 'Command failed with exit code $?' >&2"], + env=env_data, + ) + + # Create a temporary file to capture stderr + errlog_path = "" + async with aiofiles.tempfile.NamedTemporaryFile(mode="w+", encoding="utf-8", delete=False) as tmp: + errlog_path = cast(str, tmp.name) + + try: + # Pass the temp file as errlog to capture stderr + stdio_transport = await self.exit_stack.enter_async_context(stdio_client(server_params, errlog=tmp)) + self.stdio, self.write = stdio_transport + self.session = await self.exit_stack.enter_async_context(ClientSession(self.stdio, self.write)) + + # Create a watcher task to monitor stderr + async def watch_stderr(): + last_size = 0 + full_log = "" + while True: + await asyncio.sleep(0.05) + await tmp.flush() + current = Path(errlog_path).stat().st_size + if current > last_size: + async with aiofiles.open(errlog_path, encoding="utf-8") as f: + await f.seek(last_size) + data = await f.read() + full_log += data + data = data.strip() + + # Check for our specific error message pattern + if "Command failed with exit code" in data: + msg = f"MCP server command failed: {command_str}\nFull error log:\n{full_log}" + raise RuntimeError(msg) + last_size = current + + # Create tasks for both operations + watcher = asyncio.create_task(watch_stderr()) + initializer = asyncio.create_task(self.session.initialize()) + + # Race them: first to finish wins + done, pending = await asyncio.wait({watcher, initializer}, return_when=asyncio.FIRST_COMPLETED) + + if watcher in done: + # stderr watcher fired → cancel and propagate its error + initializer.cancel() + watcher.result() # This will re-raise the RuntimeError + else: + # initialize succeeded → cancel watcher + watcher.cancel() + initializer.result() # Will re-raise any initialization errors + + # If we get here, initialization succeeded + response = await self.session.list_tools() + # return response.tools + + except FileNotFoundError as e: + # Command not found, raise immediately + msg = f"Command not found: {command[0]}. Error: {e}" + raise ValueError(msg) from e + except OSError as e: + # Other OS errors (e.g., permission denied) + msg = f"Failed to start command '{command[0]}': {e}" + raise ValueError(msg) from e + except RuntimeError as e: + # This is from our stderr watcher + msg = f"MCP server error: {e}" + raise ConnectionError(msg) from e + except Exception as e: + msg = f"Failed to initialize MCP session: {e}" + logger.warning(msg) + raise ConnectionError(msg) from e + else: + return response.tools + finally: + # Clean up the temp file + with contextlib.suppress(FileNotFoundError, PermissionError): + Path(errlog_path).unlink() class MCPSseClient: diff --git a/src/backend/base/langflow/components/tools/mcp_component.py b/src/backend/base/langflow/components/tools/mcp_component.py index bf0fa301e..06ecfd46b 100644 --- a/src/backend/base/langflow/components/tools/mcp_component.py +++ b/src/backend/base/langflow/components/tools/mcp_component.py @@ -1,4 +1,5 @@ import re +import shutil from typing import Any from langchain_core.tools import StructuredTool @@ -174,10 +175,19 @@ class MCPToolsComponent(Component): if mode == "Stdio" and not command: msg = "Command is required for Stdio mode" raise ValueError(msg) + if mode == "Stdio" and command: + self._validate_node_installation(command) if mode == "SSE" and not url: msg = "URL is required for SSE mode" raise ValueError(msg) + def _validate_node_installation(self, command: str) -> str: + """Validate the npx command.""" + if "npx" in command and not shutil.which("node"): + msg = "Node.js is not installed. Please install Node.js to use npx commands." + raise ValueError(msg) + return command + def _process_headers(self, headers: Any) -> dict: """Process the headers input into a valid dictionary. @@ -448,7 +458,12 @@ class MCPToolsComponent(Component): if mode == "Stdio": if not self.stdio_client.session: - self.tools = await self.stdio_client.connect_to_server(command, env) + try: + self.tools = await self.stdio_client.connect_to_server(command, env) + except ValueError as e: + msg = f"Error connecting to MCP server: {e}" + logger.exception(msg) + raise ValueError(msg) from e elif mode == "SSE" and not self.sse_client.session: try: self.tools = await self.sse_client.connect_to_server(url, headers)