feat: add validation for node js for npx based mcp command (#7907)

* add validation for node js for npx based packages

* [autofix.ci] apply automated fixes

* Update mcp_component.py

* Update mcp_component.py

* Update mcp_component.py

* Update util.py

* [autofix.ci] apply automated fixes

* Update util.py

* fix format errors

* merge

* fix async and format issues

* wacky idea

* bug: error on match only

* [autofix.ci] apply automated fixes

* fix error: Name "env" already defined on line 225  [no-redef]

* refactor: update MCPStdioClient to use async file handling for stderr logging

- Replaced synchronous tempfile usage with async version using aiofiles.
- Ensured proper type casting for the temporary file name.
- Updated flush operation to be asynchronous for improved performance.

---------

Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com>
Co-authored-by: phact <estevezsebastian@gmail.com>
Co-authored-by: Gabriel Luiz Freitas Almeida <gabriel@langflow.org>
This commit is contained in:
Edwin Jose 2025-05-12 15:51:40 -04:00 committed by GitHub
commit d461074107
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
2 changed files with 120 additions and 13 deletions

View file

@ -1,11 +1,15 @@
import asyncio
import contextlib
import os
import platform
from collections.abc import Awaitable, Callable
from contextlib import AsyncExitStack
from typing import Any
from pathlib import Path
from typing import Any, cast
from urllib.parse import urlparse
from uuid import UUID
import aiofiles
import httpx
from httpx import codes as httpx_codes
from loguru import logger
@ -213,6 +217,9 @@ class MCPStdioClient:
def __init__(self):
self.session: ClientSession | None = None
self.exit_stack = AsyncExitStack()
self.max_retries = 1
self.retry_delay = 1.0 # seconds
self.timeout_seconds = 30 # default timeout
async def connect_to_server(self, command_str: str, env: list[str] | None = None):
env_dict: dict[str, str] = {}
@ -224,17 +231,102 @@ class MCPStdioClient:
raise ValueError(msg)
env_dict[var.split("=")[0]] = var.split("=")[1]
command = command_str.split(" ")
server_params = StdioServerParameters(
command=command[0],
args=command[1:],
env={"DEBUG": "true", "PATH": os.environ["PATH"], **(env_dict or {})},
)
stdio_transport = await self.exit_stack.enter_async_context(stdio_client(server_params))
self.stdio, self.write = stdio_transport
self.session = await self.exit_stack.enter_async_context(ClientSession(self.stdio, self.write))
await self.session.initialize()
response = await self.session.list_tools()
return response.tools
server_params = None
env_data: dict[str, str] = {"DEBUG": "true", "PATH": os.environ["PATH"], **(env_dict or {})}
# Create platform-specific command wrapper
if platform.system() == "Windows":
# For Windows, use cmd.exe with error reporting
server_params = StdioServerParameters(
command="cmd",
args=[
"/c",
f"{command[0]} {' '.join(command[1:])} || echo Command failed with exit code %errorlevel% 1>&2",
],
env=env_data,
)
else:
# For Unix-like systems, use bash with error reporting
server_params = StdioServerParameters(
command="bash",
args=["-c", f"{command_str} || echo 'Command failed with exit code $?' >&2"],
env=env_data,
)
# Create a temporary file to capture stderr
errlog_path = ""
async with aiofiles.tempfile.NamedTemporaryFile(mode="w+", encoding="utf-8", delete=False) as tmp:
errlog_path = cast(str, tmp.name)
try:
# Pass the temp file as errlog to capture stderr
stdio_transport = await self.exit_stack.enter_async_context(stdio_client(server_params, errlog=tmp))
self.stdio, self.write = stdio_transport
self.session = await self.exit_stack.enter_async_context(ClientSession(self.stdio, self.write))
# Create a watcher task to monitor stderr
async def watch_stderr():
last_size = 0
full_log = ""
while True:
await asyncio.sleep(0.05)
await tmp.flush()
current = Path(errlog_path).stat().st_size
if current > last_size:
async with aiofiles.open(errlog_path, encoding="utf-8") as f:
await f.seek(last_size)
data = await f.read()
full_log += data
data = data.strip()
# Check for our specific error message pattern
if "Command failed with exit code" in data:
msg = f"MCP server command failed: {command_str}\nFull error log:\n{full_log}"
raise RuntimeError(msg)
last_size = current
# Create tasks for both operations
watcher = asyncio.create_task(watch_stderr())
initializer = asyncio.create_task(self.session.initialize())
# Race them: first to finish wins
done, pending = await asyncio.wait({watcher, initializer}, return_when=asyncio.FIRST_COMPLETED)
if watcher in done:
# stderr watcher fired → cancel and propagate its error
initializer.cancel()
watcher.result() # This will re-raise the RuntimeError
else:
# initialize succeeded → cancel watcher
watcher.cancel()
initializer.result() # Will re-raise any initialization errors
# If we get here, initialization succeeded
response = await self.session.list_tools()
# return response.tools
except FileNotFoundError as e:
# Command not found, raise immediately
msg = f"Command not found: {command[0]}. Error: {e}"
raise ValueError(msg) from e
except OSError as e:
# Other OS errors (e.g., permission denied)
msg = f"Failed to start command '{command[0]}': {e}"
raise ValueError(msg) from e
except RuntimeError as e:
# This is from our stderr watcher
msg = f"MCP server error: {e}"
raise ConnectionError(msg) from e
except Exception as e:
msg = f"Failed to initialize MCP session: {e}"
logger.warning(msg)
raise ConnectionError(msg) from e
else:
return response.tools
finally:
# Clean up the temp file
with contextlib.suppress(FileNotFoundError, PermissionError):
Path(errlog_path).unlink()
class MCPSseClient:

View file

@ -1,4 +1,5 @@
import re
import shutil
from typing import Any
from langchain_core.tools import StructuredTool
@ -174,10 +175,19 @@ class MCPToolsComponent(Component):
if mode == "Stdio" and not command:
msg = "Command is required for Stdio mode"
raise ValueError(msg)
if mode == "Stdio" and command:
self._validate_node_installation(command)
if mode == "SSE" and not url:
msg = "URL is required for SSE mode"
raise ValueError(msg)
def _validate_node_installation(self, command: str) -> str:
"""Validate the npx command."""
if "npx" in command and not shutil.which("node"):
msg = "Node.js is not installed. Please install Node.js to use npx commands."
raise ValueError(msg)
return command
def _process_headers(self, headers: Any) -> dict:
"""Process the headers input into a valid dictionary.
@ -448,7 +458,12 @@ class MCPToolsComponent(Component):
if mode == "Stdio":
if not self.stdio_client.session:
self.tools = await self.stdio_client.connect_to_server(command, env)
try:
self.tools = await self.stdio_client.connect_to_server(command, env)
except ValueError as e:
msg = f"Error connecting to MCP server: {e}"
logger.exception(msg)
raise ValueError(msg) from e
elif mode == "SSE" and not self.sse_client.session:
try:
self.tools = await self.sse_client.connect_to_server(url, headers)