diff --git a/src/backend/base/langflow/base/mcp/util.py b/src/backend/base/langflow/base/mcp/util.py index 05c356c60..2f35b312f 100644 --- a/src/backend/base/langflow/base/mcp/util.py +++ b/src/backend/base/langflow/base/mcp/util.py @@ -10,10 +10,12 @@ from urllib.parse import urlparse from uuid import UUID import httpx +from anyio import ClosedResourceError from httpx import codes as httpx_codes from langchain_core.tools import StructuredTool from loguru import logger from mcp import ClientSession +from mcp.shared.exceptions import McpError from pydantic import BaseModel, Field, create_model from sqlmodel import select @@ -23,6 +25,11 @@ from langflow.services.deps import get_settings_service HTTP_ERROR_STATUS_CODE = httpx_codes.BAD_REQUEST # HTTP status code for client errors NULLABLE_TYPE_LENGTH = 2 # Number of types in a nullable union (the type itself + null) +# HTTP status codes used in validation +HTTP_NOT_FOUND = 404 +HTTP_BAD_REQUEST = 400 +HTTP_INTERNAL_SERVER_ERROR = 500 + def sanitize_mcp_name(name: str, max_length: int = 46) -> str: """Sanitize a name for MCP usage by removing emojis, diacritics, and special characters. @@ -378,9 +385,71 @@ class MCPSessionManager: def __init__(self): self.sessions = {} # context_id -> session_info self._background_tasks = set() # Keep references to background tasks + self._last_server_by_session = {} # context_id -> server_name for tracking switches + + async def _validate_session_connectivity(self, session) -> bool: + """Validate that the session is actually usable by testing a simple operation.""" + try: + # Try to list tools as a connectivity test (this is a lightweight operation) + # Use a shorter timeout for the connectivity test to fail fast + response = await asyncio.wait_for(session.list_tools(), timeout=3.0) + except (asyncio.TimeoutError, ConnectionError, OSError, ValueError) as e: + logger.debug(f"Session connectivity test failed (standard error): {e}") + return False + except Exception as e: + # Handle MCP-specific errors that might not be in the standard list + error_str = str(e) + if ( + "ClosedResourceError" in str(type(e)) + or "Connection closed" in error_str + or "Connection lost" in error_str + or "Transport closed" in error_str + or "Stream closed" in error_str + ): + logger.debug(f"Session connectivity test failed (MCP connection error): {e}") + return False + # Re-raise unexpected errors + logger.warning(f"Unexpected error in connectivity test: {e}") + raise + else: + # Validate that we got a meaningful response + if response is None: + logger.debug("Session connectivity test failed: received None response") + return False + try: + # Check if we can access the tools list (even if empty) + tools = getattr(response, "tools", None) + if tools is None: + logger.debug("Session connectivity test failed: no tools attribute in response") + return False + except (AttributeError, TypeError) as e: + logger.debug(f"Session connectivity test failed while validating response: {e}") + return False + else: + logger.debug(f"Session connectivity test passed: found {len(tools)} tools") + return True async def get_session(self, context_id: str, connection_params, transport_type: str): """Get or create a persistent session.""" + # Extract server identifier from connection params for tracking + server_identifier = None + if transport_type == "stdio" and hasattr(connection_params, "command"): + server_identifier = f"stdio_{connection_params.command}" + elif transport_type == "sse" and isinstance(connection_params, dict) and "url" in connection_params: + server_identifier = f"sse_{connection_params['url']}" + + # Check if we're switching servers for this context + server_switched = False + if context_id in self._last_server_by_session: + last_server = self._last_server_by_session[context_id] + if last_server != server_identifier: + server_switched = True + logger.info(f"Detected server switch for context {context_id}: {last_server} -> {server_identifier}") + + # Update server tracking + if server_identifier: + self._last_server_by_session[context_id] = server_identifier + if context_id in self.sessions: session_info = self.sessions[context_id] # Check if session and background task are still alive @@ -391,19 +460,71 @@ class MCPSessionManager: # Break down the health check to understand why cleanup is triggered task_not_done = not task.done() + # Additional check for stream health + stream_is_healthy = True + try: + # Check if the session's write stream is still open + if hasattr(session, "_write_stream"): + write_stream = session._write_stream + + # Check for explicit closed state + if hasattr(write_stream, "_closed") and write_stream._closed: + stream_is_healthy = False + # Check anyio stream state for send channels + elif hasattr(write_stream, "_state") and hasattr(write_stream._state, "open_send_channels"): + # Stream is healthy if there are open send channels + stream_is_healthy = write_stream._state.open_send_channels > 0 + # Check for other stream closed indicators + elif hasattr(write_stream, "is_closing") and callable(write_stream.is_closing): + stream_is_healthy = not write_stream.is_closing() + # If we can't determine state definitively, try a simple write test + else: + # For streams we can't easily check, assume healthy unless proven otherwise + # The actual tool call will reveal if the stream is truly dead + stream_is_healthy = True + + except (AttributeError, TypeError) as e: + # If we can't check stream health due to missing attributes, + # assume it's healthy and let the tool call fail if it's not + logger.debug(f"Could not check stream health for context_id {context_id}: {e}") + stream_is_healthy = True + logger.debug(f"Session health check for context_id {context_id}:") logger.debug(f" - task_not_done: {task_not_done}") + logger.debug(f" - stream_is_healthy: {stream_is_healthy}") - # For MCP ClientSession, we rely on the background task being alive - session_is_healthy = task_not_done + # For MCP ClientSession, we need both task and stream to be healthy + session_is_healthy = task_not_done and stream_is_healthy logger.debug(f" - session_is_healthy: {session_is_healthy}") + # If we switched servers, always recreate the session to avoid cross-server contamination + if server_switched: + logger.info(f"Server switch detected for context_id {context_id}, forcing session recreation") + session_is_healthy = False + + # Always run connectivity test for sessions to ensure they're truly responsive + # This is especially important when switching between servers + elif session_is_healthy: + logger.debug(f"Running connectivity test for context_id {context_id}") + connectivity_ok = await self._validate_session_connectivity(session) + logger.debug(f" - connectivity_ok: {connectivity_ok}") + if not connectivity_ok: + session_is_healthy = False + logger.info( + f"Session for context_id {context_id} failed connectivity test, marking as unhealthy" + ) + if session_is_healthy: - logger.debug(f"Session for context_id {context_id} is healthy, reusing") + logger.debug(f"Session for context_id {context_id} is healthy and responsive, reusing") return session - msg = f"Session for context_id {context_id} failed health check: background task is done" - logger.info(msg) + + if not task_not_done: + msg = f"Session for context_id {context_id} failed health check: background task is done" + logger.info(msg) + elif not stream_is_healthy: + msg = f"Session for context_id {context_id} failed health check: stream is closed" + logger.info(msg) except Exception as e: # noqa: BLE001 msg = f"Session for context_id {context_id} is dead due to exception: {e}" @@ -458,12 +579,24 @@ class MCPSessionManager: task.add_done_callback(self._background_tasks.discard) # Wait for session to be ready - session = await session_future + try: + session = await asyncio.wait_for(session_future, timeout=10.0) # 10 second timeout for session creation + except asyncio.TimeoutError as timeout_err: + # Clean up the failed task + if not task.done(): + task.cancel() + import contextlib - # Store session info - self.sessions[context_id] = {"session": session, "task": task, "type": "stdio"} - - return session + with contextlib.suppress(asyncio.CancelledError): + await task + self._background_tasks.discard(task) + msg = f"Timeout waiting for STDIO session to initialize for context {context_id}" + logger.error(msg) + raise ValueError(msg) from timeout_err + else: + # Store session info + self.sessions[context_id] = {"session": session, "task": task, "type": "stdio"} + return session async def _create_sse_session(self, context_id: str, connection_params): """Create a new SSE session as a background task to avoid context issues.""" @@ -509,12 +642,24 @@ class MCPSessionManager: task.add_done_callback(self._background_tasks.discard) # Wait for session to be ready - session = await session_future + try: + session = await asyncio.wait_for(session_future, timeout=10.0) # 10 second timeout for session creation + except asyncio.TimeoutError as timeout_err: + # Clean up the failed task + if not task.done(): + task.cancel() + import contextlib - # Store session info - self.sessions[context_id] = {"session": session, "task": task, "type": "sse"} - - return session + with contextlib.suppress(asyncio.CancelledError): + await task + self._background_tasks.discard(task) + msg = f"Timeout waiting for SSE session to initialize for context {context_id}" + logger.error(msg) + raise ValueError(msg) from timeout_err + else: + # Store session info + self.sessions[context_id] = {"session": session, "task": task, "type": "sse"} + return session async def _cleanup_session(self, context_id: str): """Clean up a session by cancelling its background task.""" @@ -536,6 +681,9 @@ class MCPSessionManager: logger.info(f"issue cleaning up mcp session: {e}") finally: del self.sessions[context_id] + # Also clean up server tracking + if context_id in self._last_server_by_session: + del self._last_server_by_session[context_id] async def cleanup_all(self): """Clean up all sessions.""" @@ -652,20 +800,98 @@ class MCPStdioClient: param_hash = uuid.uuid4().hex[:8] self._session_context = f"default_{param_hash}" - try: - # Get or create persistent session - session = await self._get_or_create_session() - return await session.call_tool(tool_name, arguments=arguments) + max_retries = 2 + last_error_type = None - except (ConnectionError, TimeoutError, OSError, ValueError) as e: - msg = f"Failed to run tool '{tool_name}': {e}" - logger.error(msg) - # Clean up failed session from cache - if self._session_context and self._component_cache: - cache_key = f"mcp_session_stdio_{self._session_context}" - self._component_cache.delete(cache_key) - self._connected = False - raise ValueError(msg) from e + for attempt in range(max_retries): + try: + logger.debug(f"Attempting to run tool '{tool_name}' (attempt {attempt + 1}/{max_retries})") + # Get or create persistent session + session = await self._get_or_create_session() + + result = await asyncio.wait_for( + session.call_tool(tool_name, arguments=arguments), + timeout=30.0, # 30 second timeout + ) + except Exception as e: + current_error_type = type(e).__name__ + logger.warning(f"Tool '{tool_name}' failed on attempt {attempt + 1}: {current_error_type} - {e}") + + # Import specific MCP error types for detection + try: + is_closed_resource_error = isinstance(e, ClosedResourceError) + is_mcp_connection_error = isinstance(e, McpError) and "Connection closed" in str(e) + except ImportError: + is_closed_resource_error = "ClosedResourceError" in str(type(e)) + is_mcp_connection_error = "Connection closed" in str(e) + + # Detect timeout errors + is_timeout_error = isinstance(e, asyncio.TimeoutError | TimeoutError) + + # If we're getting the same error type repeatedly, don't retry + if last_error_type == current_error_type and attempt > 0: + logger.error(f"Repeated {current_error_type} error for tool '{tool_name}', not retrying") + break + + last_error_type = current_error_type + + # If it's a connection error (ClosedResourceError or MCP connection closed) and we have retries left + if (is_closed_resource_error or is_mcp_connection_error) and attempt < max_retries - 1: + logger.warning( + f"MCP session connection issue for tool '{tool_name}', retrying with fresh session..." + ) + # Clean up the dead session + if self._session_context: + session_manager = self._get_session_manager() + await session_manager._cleanup_session(self._session_context) + # Add a small delay before retry + await asyncio.sleep(0.5) + continue + + # If it's a timeout error and we have retries left, try once more + if is_timeout_error and attempt < max_retries - 1: + logger.warning(f"Tool '{tool_name}' timed out, retrying...") + # Don't clean up session for timeouts, might just be a slow response + await asyncio.sleep(1.0) + continue + + # For other errors or no retries left, handle as before + if ( + isinstance(e, ConnectionError | TimeoutError | OSError | ValueError) + or is_closed_resource_error + or is_mcp_connection_error + or is_timeout_error + ): + msg = f"Failed to run tool '{tool_name}' after {attempt + 1} attempts: {e}" + logger.error(msg) + # Clean up failed session from cache + if self._session_context and self._component_cache: + cache_key = f"mcp_session_stdio_{self._session_context}" + self._component_cache.delete(cache_key) + self._connected = False + raise ValueError(msg) from e + # Re-raise unexpected errors + raise + else: + logger.debug(f"Tool '{tool_name}' completed successfully") + return result + + # This should never be reached due to the exception handling above + msg = f"Failed to run tool '{tool_name}': Maximum retries exceeded with repeated {last_error_type} errors" + logger.error(msg) + raise ValueError(msg) + + async def disconnect(self): + """Properly close the connection and clean up resources.""" + # Clean up session using session manager + if self._session_context: + session_manager = self._get_session_manager() + await session_manager._cleanup_session(self._session_context) + + self.session = None + self._connection_params = None + self._connected = False + self._session_context = None async def __aenter__(self): return self @@ -707,13 +933,32 @@ class MCPSseClient: async with httpx.AsyncClient() as client: try: - # First try a HEAD request to check if server is reachable - response = await client.head(url, timeout=5.0) - if response.status_code >= HTTP_ERROR_STATUS_CODE: - return False, f"Server returned error status: {response.status_code}" + # For SSE endpoints, try a GET request with short timeout + # Many SSE servers don't support HEAD requests and return 404 + response = await client.get(url, timeout=2.0, headers={"Accept": "text/event-stream"}) + + # For SSE, we expect the server to either: + # 1. Start streaming (200) + # 2. Return 404 if HEAD/GET without proper SSE handshake is not supported + # 3. Return other status codes that we should handle gracefully + + # Don't fail on 404 since many SSE endpoints return this for non-SSE requests + if response.status_code == HTTP_NOT_FOUND: + # This is likely an SSE endpoint that doesn't support regular GET + # Let the actual SSE connection attempt handle this + return True, "" + + # Fail on client errors except 404, but allow server errors and redirects + if ( + HTTP_BAD_REQUEST <= response.status_code < HTTP_INTERNAL_SERVER_ERROR + and response.status_code != HTTP_NOT_FOUND + ): + return False, f"Server returned client error status: {response.status_code}" except httpx.TimeoutException: - return False, "Connection timed out. Server may be down or unreachable." + # Timeout on a short request might indicate the server is trying to stream + # This is actually expected behavior for SSE endpoints + return True, "" except httpx.NetworkError: return False, "Network error. Could not reach the server." else: @@ -728,9 +973,11 @@ class MCPSseClient: return url try: async with httpx.AsyncClient(follow_redirects=False) as client: - response = await client.request("HEAD", url) + # Use GET with SSE headers instead of HEAD since many SSE servers don't support HEAD + response = await client.get(url, timeout=2.0, headers={"Accept": "text/event-stream"}) if response.status_code == httpx.codes.TEMPORARY_REDIRECT: return response.headers.get("Location", url) + # Don't treat 404 as an error here - let the main connection handle it except (httpx.RequestError, httpx.HTTPError) as e: logger.warning(f"Error checking redirects: {e}") return url @@ -834,20 +1081,92 @@ class MCPSseClient: param_hash = uuid.uuid4().hex[:8] self._session_context = f"default_sse_{param_hash}" - try: - # Get or create persistent session - session = await self._get_or_create_session() - return await session.call_tool(tool_name, arguments=arguments) + max_retries = 2 + last_error_type = None - except (ConnectionError, TimeoutError, OSError, ValueError) as e: - msg = f"Failed to run tool '{tool_name}': {e}" - logger.error(msg) - # Clean up failed session from cache - if self._session_context and self._component_cache: - cache_key = f"mcp_session_sse_{self._session_context}" - self._component_cache.delete(cache_key) - self._connected = False - raise ValueError(msg) from e + for attempt in range(max_retries): + try: + logger.debug(f"Attempting to run tool '{tool_name}' (attempt {attempt + 1}/{max_retries})") + # Get or create persistent session + session = await self._get_or_create_session() + + # Add timeout to prevent hanging + import asyncio + + result = await asyncio.wait_for( + session.call_tool(tool_name, arguments=arguments), + timeout=30.0, # 30 second timeout + ) + except Exception as e: + current_error_type = type(e).__name__ + logger.warning(f"Tool '{tool_name}' failed on attempt {attempt + 1}: {current_error_type} - {e}") + + # Import specific MCP error types for detection + try: + from anyio import ClosedResourceError + from mcp.shared.exceptions import McpError + + is_closed_resource_error = isinstance(e, ClosedResourceError) + is_mcp_connection_error = isinstance(e, McpError) and "Connection closed" in str(e) + except ImportError: + is_closed_resource_error = "ClosedResourceError" in str(type(e)) + is_mcp_connection_error = "Connection closed" in str(e) + + # Detect timeout errors + is_timeout_error = isinstance(e, asyncio.TimeoutError | TimeoutError) + + # If we're getting the same error type repeatedly, don't retry + if last_error_type == current_error_type and attempt > 0: + logger.error(f"Repeated {current_error_type} error for tool '{tool_name}', not retrying") + break + + last_error_type = current_error_type + + # If it's a connection error (ClosedResourceError or MCP connection closed) and we have retries left + if (is_closed_resource_error or is_mcp_connection_error) and attempt < max_retries - 1: + logger.warning( + f"MCP session connection issue for tool '{tool_name}', retrying with fresh session..." + ) + # Clean up the dead session + if self._session_context: + session_manager = self._get_session_manager() + await session_manager._cleanup_session(self._session_context) + # Add a small delay before retry + await asyncio.sleep(0.5) + continue + + # If it's a timeout error and we have retries left, try once more + if is_timeout_error and attempt < max_retries - 1: + logger.warning(f"Tool '{tool_name}' timed out, retrying...") + # Don't clean up session for timeouts, might just be a slow response + await asyncio.sleep(1.0) + continue + + # For other errors or no retries left, handle as before + if ( + isinstance(e, ConnectionError | TimeoutError | OSError | ValueError) + or is_closed_resource_error + or is_mcp_connection_error + or is_timeout_error + ): + msg = f"Failed to run tool '{tool_name}' after {attempt + 1} attempts: {e}" + logger.error(msg) + # Clean up failed session from cache + if self._session_context and self._component_cache: + cache_key = f"mcp_session_sse_{self._session_context}" + self._component_cache.delete(cache_key) + self._connected = False + raise ValueError(msg) from e + # Re-raise unexpected errors + raise + else: + logger.debug(f"Tool '{tool_name}' completed successfully") + return result + + # This should never be reached due to the exception handling above + msg = f"Failed to run tool '{tool_name}': Maximum retries exceeded with repeated {last_error_type} errors" + logger.error(msg) + raise ValueError(msg) async def __aenter__(self): return self diff --git a/src/frontend/src/utils/mcpUtils.ts b/src/frontend/src/utils/mcpUtils.ts index 57953095a..9095ce886 100644 --- a/src/frontend/src/utils/mcpUtils.ts +++ b/src/frontend/src/utils/mcpUtils.ts @@ -1,82 +1,5 @@ import { MCPServerType } from "@/types/mcp"; -/** - * Extracts the first MCP server from a JSON string or object. - * Supports: - * 1. { mcpServers: { ... } } - * 2. { ... } (object with server keys) - * 3. a single server object - * Returns: { name, server } or throws an error. - */ -export function extractFirstMcpServerFromJson(json: string | object): { - name: string; - server: Omit; -} { - let parsed: any = json; - if (typeof json === "string") { - try { - parsed = JSON.parse(json); - } catch (e) { - throw new Error("Invalid JSON format."); - } - } - - let serverEntries: [string, Omit][] = []; - - // Case 1: { mcpServers: { ... } } - if ( - parsed && - typeof parsed === "object" && - parsed.mcpServers && - typeof parsed.mcpServers === "object" - ) { - serverEntries = Object.entries(parsed.mcpServers) as [ - string, - Omit, - ][]; - } - // Case 2: { ... } (object with server keys) - else if ( - parsed && - typeof parsed === "object" && - Object.values(parsed).some( - (v) => - v && - typeof v === "object" && - "command" in v && - Array.isArray((v as any).args), - ) - ) { - serverEntries = Object.entries(parsed).filter( - ([, v]) => - v && - typeof v === "object" && - "command" in v && - Array.isArray((v as any).args), - ) as [string, Omit][]; - } - // Case 3: single server object - else if ( - parsed && - typeof parsed === "object" && - "command" in parsed && - Array.isArray((parsed as any).args) - ) { - serverEntries = [["server", parsed]]; - } - - if (serverEntries.length === 0) { - throw new Error("No valid MCP server found in the input."); - } - const [name, server] = serverEntries[0]; - if (!server.command || !Array.isArray(server.args)) { - throw new Error( - "Each MCP server must have a 'command' and an 'args' array.", - ); - } - return { name, server }; -} - /** * Extracts all MCP servers from a JSON string or object. * Supports: @@ -117,27 +40,18 @@ export function extractMcpServersFromJson( parsed && typeof parsed === "object" && Object.values(parsed).some( - (v) => - v && - typeof v === "object" && - "command" in v && - Array.isArray((v as any).args), + (v) => v && typeof v === "object" && ("command" in v || "url" in v), ) ) { serverEntries = Object.entries(parsed).filter( - ([, v]) => - v && - typeof v === "object" && - "command" in v && - Array.isArray((v as any).args), + ([, v]) => v && typeof v === "object" && ("command" in v || "url" in v), ); } // Case 3: single server object else if ( parsed && typeof parsed === "object" && - "command" in parsed && - Array.isArray((parsed as any).args) + ("command" in parsed || "url" in parsed) ) { serverEntries = [["server", parsed]]; } @@ -147,7 +61,7 @@ export function extractMcpServersFromJson( } // Validate and map all servers const validServers = serverEntries.filter( - ([, server]) => server.command && Array.isArray(server.args), + ([, server]) => server.command || server.url, ); if (validServers.length === 0) { throw new Error("No valid MCP server found in the input."); @@ -155,7 +69,7 @@ export function extractMcpServersFromJson( return validServers.map(([name, server]) => ({ name: name.slice(0, 30), command: server.command, - args: server.args, + args: server.args || [], env: server.env && typeof server.env === "object" ? server.env : {}, url: server.url, })); diff --git a/src/frontend/tests/extended/features/mcp-server-tab.spec.ts b/src/frontend/tests/extended/features/mcp-server-tab.spec.ts index ebcf7357c..2de7d0ad2 100644 --- a/src/frontend/tests/extended/features/mcp-server-tab.spec.ts +++ b/src/frontend/tests/extended/features/mcp-server-tab.spec.ts @@ -165,7 +165,9 @@ test( await expect(page.getByTestId("icon-check")).toBeVisible(); // Get the SSE URL from the configuration - const configJson = await page.locator("pre").textContent(); + const configJson = await page.evaluate(() => { + return navigator.clipboard.readText(); + }); expect(configJson).toContain("mcpServers"); expect(configJson).toContain("mcp-proxy"); expect(configJson).toContain("uvx"); @@ -180,8 +182,13 @@ test( await page.getByText("macOS/Linux", { exact: true }).click(); await page.waitForSelector("pre", { state: "visible", timeout: 3000 }); + // Copy configuration + await page.getByTestId("icon-copy").click(); + await expect(page.getByTestId("icon-check")).toBeVisible(); - const configJsonLinux = await page.locator("pre").textContent(); + const configJsonLinux = await page.evaluate(() => { + return navigator.clipboard.readText(); + }); const sseUrlMatchLinux = configJsonLinux?.match( /"args":\s*\[\s*"mcp-proxy"\s*,\s*"([^"]+)"/,