fix: improved health check and stream URL check on MCP, improved JSON recognition (#8982)

* Improved health check and stream URL check on MCP

* Improved health check by validating session connectivity

* Changed the MCP servers-from-JSON checks to accept command- or URL-based entries

* Fixed imports

* Fixed mcp server tab test
Lucas Oliveira 2025-07-10 13:56:02 -03:00 committed by GitHub
commit 87795931f0
3 changed files with 380 additions and 140 deletions
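The heart of the health-check change is a lightweight connectivity probe: instead of only trusting that the session's background task is still running, the session is asked to list its tools under a short timeout. A condensed sketch of the idea (the helper name and wiring are illustrative, not the exact code in the diff below):

import asyncio
from loguru import logger

async def probe_session(session) -> bool:
    # Illustrative helper, not part of the diff: ask the session for its tools with a
    # short timeout so a dead transport fails fast instead of hanging the caller.
    try:
        response = await asyncio.wait_for(session.list_tools(), timeout=3.0)
    except (asyncio.TimeoutError, ConnectionError, OSError) as exc:
        logger.debug(f"Connectivity probe failed: {exc}")
        return False
    # A usable session must at least expose a tools list, even if it is empty.
    return response is not None and getattr(response, "tools", None) is not None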


@@ -10,10 +10,12 @@ from urllib.parse import urlparse
from uuid import UUID
import httpx
from anyio import ClosedResourceError
from httpx import codes as httpx_codes
from langchain_core.tools import StructuredTool
from loguru import logger
from mcp import ClientSession
from mcp.shared.exceptions import McpError
from pydantic import BaseModel, Field, create_model
from sqlmodel import select
@@ -23,6 +25,11 @@ from langflow.services.deps import get_settings_service
HTTP_ERROR_STATUS_CODE = httpx_codes.BAD_REQUEST # HTTP status code for client errors
NULLABLE_TYPE_LENGTH = 2 # Number of types in a nullable union (the type itself + null)
# HTTP status codes used in validation
HTTP_NOT_FOUND = 404
HTTP_BAD_REQUEST = 400
HTTP_INTERNAL_SERVER_ERROR = 500
def sanitize_mcp_name(name: str, max_length: int = 46) -> str:
"""Sanitize a name for MCP usage by removing emojis, diacritics, and special characters.
@@ -378,9 +385,71 @@ class MCPSessionManager:
def __init__(self):
self.sessions = {} # context_id -> session_info
self._background_tasks = set() # Keep references to background tasks
self._last_server_by_session = {} # context_id -> server_name for tracking switches
async def _validate_session_connectivity(self, session) -> bool:
"""Validate that the session is actually usable by testing a simple operation."""
try:
# Try to list tools as a connectivity test (this is a lightweight operation)
# Use a shorter timeout for the connectivity test to fail fast
response = await asyncio.wait_for(session.list_tools(), timeout=3.0)
except (asyncio.TimeoutError, ConnectionError, OSError, ValueError) as e:
logger.debug(f"Session connectivity test failed (standard error): {e}")
return False
except Exception as e:
# Handle MCP-specific errors that might not be in the standard list
error_str = str(e)
if (
"ClosedResourceError" in str(type(e))
or "Connection closed" in error_str
or "Connection lost" in error_str
or "Transport closed" in error_str
or "Stream closed" in error_str
):
logger.debug(f"Session connectivity test failed (MCP connection error): {e}")
return False
# Re-raise unexpected errors
logger.warning(f"Unexpected error in connectivity test: {e}")
raise
else:
# Validate that we got a meaningful response
if response is None:
logger.debug("Session connectivity test failed: received None response")
return False
try:
# Check if we can access the tools list (even if empty)
tools = getattr(response, "tools", None)
if tools is None:
logger.debug("Session connectivity test failed: no tools attribute in response")
return False
except (AttributeError, TypeError) as e:
logger.debug(f"Session connectivity test failed while validating response: {e}")
return False
else:
logger.debug(f"Session connectivity test passed: found {len(tools)} tools")
return True
async def get_session(self, context_id: str, connection_params, transport_type: str):
"""Get or create a persistent session."""
# Extract server identifier from connection params for tracking
server_identifier = None
if transport_type == "stdio" and hasattr(connection_params, "command"):
server_identifier = f"stdio_{connection_params.command}"
elif transport_type == "sse" and isinstance(connection_params, dict) and "url" in connection_params:
server_identifier = f"sse_{connection_params['url']}"
# Check if we're switching servers for this context
server_switched = False
if context_id in self._last_server_by_session:
last_server = self._last_server_by_session[context_id]
if last_server != server_identifier:
server_switched = True
logger.info(f"Detected server switch for context {context_id}: {last_server} -> {server_identifier}")
# Update server tracking
if server_identifier:
self._last_server_by_session[context_id] = server_identifier
if context_id in self.sessions:
session_info = self.sessions[context_id]
# Check if session and background task are still alive
@@ -391,19 +460,71 @@ class MCPSessionManager:
# Break down the health check to understand why cleanup is triggered
task_not_done = not task.done()
# Additional check for stream health
stream_is_healthy = True
try:
# Check if the session's write stream is still open
if hasattr(session, "_write_stream"):
write_stream = session._write_stream
# Check for explicit closed state
if hasattr(write_stream, "_closed") and write_stream._closed:
stream_is_healthy = False
# Check anyio stream state for send channels
elif hasattr(write_stream, "_state") and hasattr(write_stream._state, "open_send_channels"):
# Stream is healthy if there are open send channels
stream_is_healthy = write_stream._state.open_send_channels > 0
# Check for other stream closed indicators
elif hasattr(write_stream, "is_closing") and callable(write_stream.is_closing):
stream_is_healthy = not write_stream.is_closing()
# If we can't determine state definitively, try a simple write test
else:
# For streams we can't easily check, assume healthy unless proven otherwise
# The actual tool call will reveal if the stream is truly dead
stream_is_healthy = True
except (AttributeError, TypeError) as e:
# If we can't check stream health due to missing attributes,
# assume it's healthy and let the tool call fail if it's not
logger.debug(f"Could not check stream health for context_id {context_id}: {e}")
stream_is_healthy = True
logger.debug(f"Session health check for context_id {context_id}:")
logger.debug(f" - task_not_done: {task_not_done}")
logger.debug(f" - stream_is_healthy: {stream_is_healthy}")
# For MCP ClientSession, we rely on the background task being alive
session_is_healthy = task_not_done
# For MCP ClientSession, we need both task and stream to be healthy
session_is_healthy = task_not_done and stream_is_healthy
logger.debug(f" - session_is_healthy: {session_is_healthy}")
# If we switched servers, always recreate the session to avoid cross-server contamination
if server_switched:
logger.info(f"Server switch detected for context_id {context_id}, forcing session recreation")
session_is_healthy = False
# Always run connectivity test for sessions to ensure they're truly responsive
# This is especially important when switching between servers
elif session_is_healthy:
logger.debug(f"Running connectivity test for context_id {context_id}")
connectivity_ok = await self._validate_session_connectivity(session)
logger.debug(f" - connectivity_ok: {connectivity_ok}")
if not connectivity_ok:
session_is_healthy = False
logger.info(
f"Session for context_id {context_id} failed connectivity test, marking as unhealthy"
)
if session_is_healthy:
logger.debug(f"Session for context_id {context_id} is healthy, reusing")
logger.debug(f"Session for context_id {context_id} is healthy and responsive, reusing")
return session
msg = f"Session for context_id {context_id} failed health check: background task is done"
logger.info(msg)
if not task_not_done:
msg = f"Session for context_id {context_id} failed health check: background task is done"
logger.info(msg)
elif not stream_is_healthy:
msg = f"Session for context_id {context_id} failed health check: stream is closed"
logger.info(msg)
except Exception as e: # noqa: BLE001
msg = f"Session for context_id {context_id} is dead due to exception: {e}"
@@ -458,12 +579,24 @@ class MCPSessionManager:
task.add_done_callback(self._background_tasks.discard)
# Wait for session to be ready
session = await session_future
try:
session = await asyncio.wait_for(session_future, timeout=10.0) # 10 second timeout for session creation
except asyncio.TimeoutError as timeout_err:
# Clean up the failed task
if not task.done():
task.cancel()
import contextlib
# Store session info
self.sessions[context_id] = {"session": session, "task": task, "type": "stdio"}
return session
with contextlib.suppress(asyncio.CancelledError):
await task
self._background_tasks.discard(task)
msg = f"Timeout waiting for STDIO session to initialize for context {context_id}"
logger.error(msg)
raise ValueError(msg) from timeout_err
else:
# Store session info
self.sessions[context_id] = {"session": session, "task": task, "type": "stdio"}
return session
async def _create_sse_session(self, context_id: str, connection_params):
"""Create a new SSE session as a background task to avoid context issues."""
@@ -509,12 +642,24 @@ class MCPSessionManager:
task.add_done_callback(self._background_tasks.discard)
# Wait for session to be ready
session = await session_future
try:
session = await asyncio.wait_for(session_future, timeout=10.0) # 10 second timeout for session creation
except asyncio.TimeoutError as timeout_err:
# Clean up the failed task
if not task.done():
task.cancel()
import contextlib
# Store session info
self.sessions[context_id] = {"session": session, "task": task, "type": "sse"}
return session
with contextlib.suppress(asyncio.CancelledError):
await task
self._background_tasks.discard(task)
msg = f"Timeout waiting for SSE session to initialize for context {context_id}"
logger.error(msg)
raise ValueError(msg) from timeout_err
else:
# Store session info
self.sessions[context_id] = {"session": session, "task": task, "type": "sse"}
return session
async def _cleanup_session(self, context_id: str):
"""Clean up a session by cancelling its background task."""
@@ -536,6 +681,9 @@ class MCPSessionManager:
logger.info(f"issue cleaning up mcp session: {e}")
finally:
del self.sessions[context_id]
# Also clean up server tracking
if context_id in self._last_server_by_session:
del self._last_server_by_session[context_id]
async def cleanup_all(self):
"""Clean up all sessions."""
@@ -652,20 +800,98 @@ class MCPStdioClient:
param_hash = uuid.uuid4().hex[:8]
self._session_context = f"default_{param_hash}"
try:
# Get or create persistent session
session = await self._get_or_create_session()
return await session.call_tool(tool_name, arguments=arguments)
max_retries = 2
last_error_type = None
except (ConnectionError, TimeoutError, OSError, ValueError) as e:
msg = f"Failed to run tool '{tool_name}': {e}"
logger.error(msg)
# Clean up failed session from cache
if self._session_context and self._component_cache:
cache_key = f"mcp_session_stdio_{self._session_context}"
self._component_cache.delete(cache_key)
self._connected = False
raise ValueError(msg) from e
for attempt in range(max_retries):
try:
logger.debug(f"Attempting to run tool '{tool_name}' (attempt {attempt + 1}/{max_retries})")
# Get or create persistent session
session = await self._get_or_create_session()
result = await asyncio.wait_for(
session.call_tool(tool_name, arguments=arguments),
timeout=30.0, # 30 second timeout
)
except Exception as e:
current_error_type = type(e).__name__
logger.warning(f"Tool '{tool_name}' failed on attempt {attempt + 1}: {current_error_type} - {e}")
# Import specific MCP error types for detection
try:
is_closed_resource_error = isinstance(e, ClosedResourceError)
is_mcp_connection_error = isinstance(e, McpError) and "Connection closed" in str(e)
except ImportError:
is_closed_resource_error = "ClosedResourceError" in str(type(e))
is_mcp_connection_error = "Connection closed" in str(e)
# Detect timeout errors
is_timeout_error = isinstance(e, asyncio.TimeoutError | TimeoutError)
# If we're getting the same error type repeatedly, don't retry
if last_error_type == current_error_type and attempt > 0:
logger.error(f"Repeated {current_error_type} error for tool '{tool_name}', not retrying")
break
last_error_type = current_error_type
# If it's a connection error (ClosedResourceError or MCP connection closed) and we have retries left
if (is_closed_resource_error or is_mcp_connection_error) and attempt < max_retries - 1:
logger.warning(
f"MCP session connection issue for tool '{tool_name}', retrying with fresh session..."
)
# Clean up the dead session
if self._session_context:
session_manager = self._get_session_manager()
await session_manager._cleanup_session(self._session_context)
# Add a small delay before retry
await asyncio.sleep(0.5)
continue
# If it's a timeout error and we have retries left, try once more
if is_timeout_error and attempt < max_retries - 1:
logger.warning(f"Tool '{tool_name}' timed out, retrying...")
# Don't clean up session for timeouts, might just be a slow response
await asyncio.sleep(1.0)
continue
# For other errors or no retries left, handle as before
if (
isinstance(e, ConnectionError | TimeoutError | OSError | ValueError)
or is_closed_resource_error
or is_mcp_connection_error
or is_timeout_error
):
msg = f"Failed to run tool '{tool_name}' after {attempt + 1} attempts: {e}"
logger.error(msg)
# Clean up failed session from cache
if self._session_context and self._component_cache:
cache_key = f"mcp_session_stdio_{self._session_context}"
self._component_cache.delete(cache_key)
self._connected = False
raise ValueError(msg) from e
# Re-raise unexpected errors
raise
else:
logger.debug(f"Tool '{tool_name}' completed successfully")
return result
# This should never be reached due to the exception handling above
msg = f"Failed to run tool '{tool_name}': Maximum retries exceeded with repeated {last_error_type} errors"
logger.error(msg)
raise ValueError(msg)
async def disconnect(self):
"""Properly close the connection and clean up resources."""
# Clean up session using session manager
if self._session_context:
session_manager = self._get_session_manager()
await session_manager._cleanup_session(self._session_context)
self.session = None
self._connection_params = None
self._connected = False
self._session_context = None
async def __aenter__(self):
return self
@@ -707,13 +933,32 @@ class MCPSseClient:
async with httpx.AsyncClient() as client:
try:
# First try a HEAD request to check if server is reachable
response = await client.head(url, timeout=5.0)
if response.status_code >= HTTP_ERROR_STATUS_CODE:
return False, f"Server returned error status: {response.status_code}"
# For SSE endpoints, try a GET request with short timeout
# Many SSE servers don't support HEAD requests and return 404
response = await client.get(url, timeout=2.0, headers={"Accept": "text/event-stream"})
# For SSE, we expect the server to either:
# 1. Start streaming (200)
# 2. Return 404 if HEAD/GET without proper SSE handshake is not supported
# 3. Return other status codes that we should handle gracefully
# Don't fail on 404 since many SSE endpoints return this for non-SSE requests
if response.status_code == HTTP_NOT_FOUND:
# This is likely an SSE endpoint that doesn't support regular GET
# Let the actual SSE connection attempt handle this
return True, ""
# Fail on client errors except 404, but allow server errors and redirects
if (
HTTP_BAD_REQUEST <= response.status_code < HTTP_INTERNAL_SERVER_ERROR
and response.status_code != HTTP_NOT_FOUND
):
return False, f"Server returned client error status: {response.status_code}"
except httpx.TimeoutException:
return False, "Connection timed out. Server may be down or unreachable."
# Timeout on a short request might indicate the server is trying to stream
# This is actually expected behavior for SSE endpoints
return True, ""
except httpx.NetworkError:
return False, "Network error. Could not reach the server."
else:
@@ -728,9 +973,11 @@ class MCPSseClient:
return url
try:
async with httpx.AsyncClient(follow_redirects=False) as client:
response = await client.request("HEAD", url)
# Use GET with SSE headers instead of HEAD since many SSE servers don't support HEAD
response = await client.get(url, timeout=2.0, headers={"Accept": "text/event-stream"})
if response.status_code == httpx.codes.TEMPORARY_REDIRECT:
return response.headers.get("Location", url)
# Don't treat 404 as an error here - let the main connection handle it
except (httpx.RequestError, httpx.HTTPError) as e:
logger.warning(f"Error checking redirects: {e}")
return url
@@ -834,20 +1081,92 @@ class MCPSseClient:
param_hash = uuid.uuid4().hex[:8]
self._session_context = f"default_sse_{param_hash}"
try:
# Get or create persistent session
session = await self._get_or_create_session()
return await session.call_tool(tool_name, arguments=arguments)
max_retries = 2
last_error_type = None
except (ConnectionError, TimeoutError, OSError, ValueError) as e:
msg = f"Failed to run tool '{tool_name}': {e}"
logger.error(msg)
# Clean up failed session from cache
if self._session_context and self._component_cache:
cache_key = f"mcp_session_sse_{self._session_context}"
self._component_cache.delete(cache_key)
self._connected = False
raise ValueError(msg) from e
for attempt in range(max_retries):
try:
logger.debug(f"Attempting to run tool '{tool_name}' (attempt {attempt + 1}/{max_retries})")
# Get or create persistent session
session = await self._get_or_create_session()
# Add timeout to prevent hanging
import asyncio
result = await asyncio.wait_for(
session.call_tool(tool_name, arguments=arguments),
timeout=30.0, # 30 second timeout
)
except Exception as e:
current_error_type = type(e).__name__
logger.warning(f"Tool '{tool_name}' failed on attempt {attempt + 1}: {current_error_type} - {e}")
# Import specific MCP error types for detection
try:
from anyio import ClosedResourceError
from mcp.shared.exceptions import McpError
is_closed_resource_error = isinstance(e, ClosedResourceError)
is_mcp_connection_error = isinstance(e, McpError) and "Connection closed" in str(e)
except ImportError:
is_closed_resource_error = "ClosedResourceError" in str(type(e))
is_mcp_connection_error = "Connection closed" in str(e)
# Detect timeout errors
is_timeout_error = isinstance(e, asyncio.TimeoutError | TimeoutError)
# If we're getting the same error type repeatedly, don't retry
if last_error_type == current_error_type and attempt > 0:
logger.error(f"Repeated {current_error_type} error for tool '{tool_name}', not retrying")
break
last_error_type = current_error_type
# If it's a connection error (ClosedResourceError or MCP connection closed) and we have retries left
if (is_closed_resource_error or is_mcp_connection_error) and attempt < max_retries - 1:
logger.warning(
f"MCP session connection issue for tool '{tool_name}', retrying with fresh session..."
)
# Clean up the dead session
if self._session_context:
session_manager = self._get_session_manager()
await session_manager._cleanup_session(self._session_context)
# Add a small delay before retry
await asyncio.sleep(0.5)
continue
# If it's a timeout error and we have retries left, try once more
if is_timeout_error and attempt < max_retries - 1:
logger.warning(f"Tool '{tool_name}' timed out, retrying...")
# Don't clean up session for timeouts, might just be a slow response
await asyncio.sleep(1.0)
continue
# For other errors or no retries left, handle as before
if (
isinstance(e, ConnectionError | TimeoutError | OSError | ValueError)
or is_closed_resource_error
or is_mcp_connection_error
or is_timeout_error
):
msg = f"Failed to run tool '{tool_name}' after {attempt + 1} attempts: {e}"
logger.error(msg)
# Clean up failed session from cache
if self._session_context and self._component_cache:
cache_key = f"mcp_session_sse_{self._session_context}"
self._component_cache.delete(cache_key)
self._connected = False
raise ValueError(msg) from e
# Re-raise unexpected errors
raise
else:
logger.debug(f"Tool '{tool_name}' completed successfully")
return result
# This should never be reached due to the exception handling above
msg = f"Failed to run tool '{tool_name}': Maximum retries exceeded with repeated {last_error_type} errors"
logger.error(msg)
raise ValueError(msg)
async def __aenter__(self):
return self
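Both the stdio and SSE clients now wrap call_tool in the same bounded retry: a hard per-attempt timeout, session recreation only when the transport looks dead, and a plain retry without teardown when the call merely times out. A simplified sketch of that pattern (get_session and cleanup_session are assumed caller-supplied callables standing in for the session manager; error classification is reduced to the two cases that matter):

import asyncio
from anyio import ClosedResourceError

async def call_tool_with_retry(get_session, cleanup_session, tool_name, arguments, max_retries=2):
    # Sketch only: get_session/cleanup_session are placeholders for the session manager used in the diff.
    last_error = None
    for attempt in range(max_retries):
        session = await get_session()
        try:
            # Cap each attempt so a hung stream cannot block the caller indefinitely.
            return await asyncio.wait_for(session.call_tool(tool_name, arguments=arguments), timeout=30.0)
        except ClosedResourceError as exc:
            # Dead transport: drop the cached session and retry with a fresh one.
            last_error = exc
            await cleanup_session()
            await asyncio.sleep(0.5)
        except asyncio.TimeoutError as exc:
            # Slow response: keep the session and simply try again.
            last_error = exc
            await asyncio.sleep(1.0)
    msg = f"Failed to run tool '{tool_name}' after {max_retries} attempts: {last_error}"
    raise ValueError(msg) from last_error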


@@ -1,82 +1,5 @@
import { MCPServerType } from "@/types/mcp";
/**
* Extracts the first MCP server from a JSON string or object.
* Supports:
* 1. { mcpServers: { ... } }
* 2. { ... } (object with server keys)
* 3. a single server object
* Returns: { name, server } or throws an error.
*/
export function extractFirstMcpServerFromJson(json: string | object): {
name: string;
server: Omit<MCPServerType, "name">;
} {
let parsed: any = json;
if (typeof json === "string") {
try {
parsed = JSON.parse(json);
} catch (e) {
throw new Error("Invalid JSON format.");
}
}
let serverEntries: [string, Omit<MCPServerType, "name">][] = [];
// Case 1: { mcpServers: { ... } }
if (
parsed &&
typeof parsed === "object" &&
parsed.mcpServers &&
typeof parsed.mcpServers === "object"
) {
serverEntries = Object.entries(parsed.mcpServers) as [
string,
Omit<MCPServerType, "name">,
][];
}
// Case 2: { ... } (object with server keys)
else if (
parsed &&
typeof parsed === "object" &&
Object.values(parsed).some(
(v) =>
v &&
typeof v === "object" &&
"command" in v &&
Array.isArray((v as any).args),
)
) {
serverEntries = Object.entries(parsed).filter(
([, v]) =>
v &&
typeof v === "object" &&
"command" in v &&
Array.isArray((v as any).args),
) as [string, Omit<MCPServerType, "name">][];
}
// Case 3: single server object
else if (
parsed &&
typeof parsed === "object" &&
"command" in parsed &&
Array.isArray((parsed as any).args)
) {
serverEntries = [["server", parsed]];
}
if (serverEntries.length === 0) {
throw new Error("No valid MCP server found in the input.");
}
const [name, server] = serverEntries[0];
if (!server.command || !Array.isArray(server.args)) {
throw new Error(
"Each MCP server must have a 'command' and an 'args' array.",
);
}
return { name, server };
}
/**
* Extracts all MCP servers from a JSON string or object.
* Supports:
@@ -117,27 +40,18 @@ export function extractMcpServersFromJson(
parsed &&
typeof parsed === "object" &&
Object.values(parsed).some(
(v) =>
v &&
typeof v === "object" &&
"command" in v &&
Array.isArray((v as any).args),
(v) => v && typeof v === "object" && ("command" in v || "url" in v),
)
) {
serverEntries = Object.entries(parsed).filter(
([, v]) =>
v &&
typeof v === "object" &&
"command" in v &&
Array.isArray((v as any).args),
([, v]) => v && typeof v === "object" && ("command" in v || "url" in v),
);
}
// Case 3: single server object
else if (
parsed &&
typeof parsed === "object" &&
"command" in parsed &&
Array.isArray((parsed as any).args)
("command" in parsed || "url" in parsed)
) {
serverEntries = [["server", parsed]];
}
@@ -147,7 +61,7 @@ export function extractMcpServersFromJson(
}
// Validate and map all servers
const validServers = serverEntries.filter(
([, server]) => server.command && Array.isArray(server.args),
([, server]) => server.command || server.url,
);
if (validServers.length === 0) {
throw new Error("No valid MCP server found in the input.");
@@ -155,7 +69,7 @@ export function extractMcpServersFromJson(
return validServers.map(([name, server]) => ({
name: name.slice(0, 30),
command: server.command,
args: server.args,
args: server.args || [],
env: server.env && typeof server.env === "object" ? server.env : {},
url: server.url,
}));
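With the relaxed frontend checks, an entry counts as a server as soon as it carries either a command or a url, and args falls back to an empty list. Example configs that should now be recognized (server names and the URL are made up for illustration; shown as Python literals):

example_config = {
    "mcpServers": {
        # Classic stdio server: command plus args, as before (names/URL are illustrative).
        "local-tools": {"command": "uvx", "args": ["mcp-proxy", "http://localhost:8000/sse"]},
        # URL-only SSE server: no command or args required any more.
        "remote-sse": {"url": "http://localhost:8000/sse"},
    }
}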


@@ -165,7 +165,9 @@ test(
await expect(page.getByTestId("icon-check")).toBeVisible();
// Get the SSE URL from the configuration
const configJson = await page.locator("pre").textContent();
const configJson = await page.evaluate(() => {
return navigator.clipboard.readText();
});
expect(configJson).toContain("mcpServers");
expect(configJson).toContain("mcp-proxy");
expect(configJson).toContain("uvx");
@@ -180,8 +182,13 @@ test(
await page.getByText("macOS/Linux", { exact: true }).click();
await page.waitForSelector("pre", { state: "visible", timeout: 3000 });
// Copy configuration
await page.getByTestId("icon-copy").click();
await expect(page.getByTestId("icon-check")).toBeVisible();
const configJsonLinux = await page.locator("pre").textContent();
const configJsonLinux = await page.evaluate(() => {
return navigator.clipboard.readText();
});
const sseUrlMatchLinux = configJsonLinux?.match(
/"args":\s*\[\s*"mcp-proxy"\s*,\s*"([^"]+)"/,