diff --git a/src/backend/base/langflow/base/mcp/util.py b/src/backend/base/langflow/base/mcp/util.py index 353b5d657..62d0ca2b9 100644 --- a/src/backend/base/langflow/base/mcp/util.py +++ b/src/backend/base/langflow/base/mcp/util.py @@ -1,4 +1,5 @@ import asyncio +import contextlib import os import platform import re @@ -30,6 +31,13 @@ HTTP_NOT_FOUND = 404 HTTP_BAD_REQUEST = 400 HTTP_INTERNAL_SERVER_ERROR = 500 +# MCP Session Manager constants +settings = get_settings_service().settings +MAX_SESSIONS_PER_SERVER = ( + settings.mcp_max_sessions_per_server +) # Maximum number of sessions per server to prevent resource exhaustion +SESSION_IDLE_TIMEOUT = settings.mcp_session_idle_timeout # 5 minutes idle timeout for sessions +SESSION_CLEANUP_INTERVAL = settings.mcp_session_cleanup_interval # Cleanup interval in seconds # RFC 7230 compliant header name pattern: token = 1*tchar # tchar = "!" / "#" / "$" / "%" / "&" / "'" / "*" / "+" / "-" / "." / # "^" / "_" / "`" / "|" / "~" / DIGIT / ALPHA @@ -460,12 +468,90 @@ async def _validate_connection_params(mode: str, command: str | None = None, url class MCPSessionManager: - """Manages persistent MCP sessions with proper context manager lifecycle.""" + """Manages persistent MCP sessions with proper context manager lifecycle. + + Fixed version that addresses the memory leak issue by: + 1. Session reuse based on server identity rather than unique context IDs + 2. Maximum session limits per server to prevent resource exhaustion + 3. Idle timeout for automatic session cleanup + 4. 
async def _cleanup_idle_sessions(self):
    """Tear down sessions that have been idle longer than ``SESSION_IDLE_TIMEOUT``.

    A session counts as idle when the event-loop clock has advanced more than
    ``SESSION_IDLE_TIMEOUT`` past its ``last_used`` stamp. Each idle session is
    closed through ``_cleanup_session_by_id``; its ref-count and every context
    mapping still pointing at it are dropped too, and server entries that end
    up without sessions are removed.
    """
    now = asyncio.get_running_loop().time()

    # Work on a snapshot: every await below yields to the event loop, and a
    # concurrent get_session() may register a new server in the meantime.
    # Mutating a dict while iterating it live raises RuntimeError.
    snapshot = list(self.sessions_by_server.items())

    for server_key, server_data in snapshot:
        sessions = server_data.get("sessions", {})
        idle_ids = [
            session_id
            for session_id, session_info in sessions.items()
            if now - session_info["last_used"] > SESSION_IDLE_TIMEOUT
        ]

        for session_id in idle_ids:
            logger.info(f"Cleaning up idle session {session_id} for server {server_key}")
            await self._cleanup_session_by_id(server_key, session_id)

            # Drop bookkeeping for the reaped session so a stale context mapping
            # cannot later tear down a different session that reuses this id.
            ref_key = (server_key, session_id)
            self._session_refcount.pop(ref_key, None)
            stale_contexts = [
                context_id for context_id, target in self._context_to_session.items() if target == ref_key
            ]
            for context_id in stale_contexts:
                del self._context_to_session[context_id]

    # Prune empty servers in one pass with no awaits, re-checking emptiness at
    # deletion time: a server that gained a session while we were awaiting
    # above must keep its entry, otherwise that session would be orphaned.
    for server_key, _ in snapshot:
        current = self.sessions_by_server.get(server_key)
        if current is not None and not current.get("sessions"):
            del self.sessions_by_server[server_key]
async def get_session(self, context_id: str, connection_params, transport_type: str):
    """Return a live session for the target server, creating one when needed.

    Sessions are pooled per server key (see ``_get_server_key``) rather than per
    ``context_id``, so several contexts talking to the same server share one
    session instead of each spawning its own.

    Args:
        context_id: Identifier of the caller; recorded so ``_cleanup_session``
            can later release this caller's reference.
        connection_params: Transport-specific connection parameters.
        transport_type: ``"stdio"`` or ``"sse"``.

    Returns:
        The reused or newly created session object.

    Raises:
        ValueError: If ``transport_type`` is neither ``"stdio"`` nor ``"sse"``
            and no existing session could be reused.
    """
    loop = asyncio.get_running_loop()
    server_key = self._get_server_key(connection_params, transport_type)

    def unused_session_id(pool) -> str:
        # len(pool) alone can collide with a live id once an earlier session has
        # been removed, so probe upwards until the id is free.
        index = len(pool)
        while f"{server_key}_{index}" in pool:
            index += 1
        return f"{server_key}_{index}"

    # Ensure server entry exists
    server_data = self.sessions_by_server.setdefault(
        server_key, {"sessions": {}, "last_cleanup": loop.time()}
    )
    sessions = server_data["sessions"]

    # Try to find a healthy existing session. Iterate over a snapshot because
    # _cleanup_session_by_id deletes entries from `sessions`; deleting during a
    # live dict iteration raises RuntimeError.
    for session_id, session_info in list(sessions.items()):
        session = session_info["session"]
        task = session_info["task"]

        if task.done():
            logger.info(f"Session {session_id} for server {server_key} task is done, cleaning up")
            await self._cleanup_session_by_id(server_key, session_id)
            continue

        # Update last used time
        session_info["last_used"] = loop.time()

        # Quick health check
        if await self._validate_session_connectivity(session):
            logger.debug(f"Reusing existing session {session_id} for server {server_key}")
            ref_key = (server_key, session_id)
            # Count each context once: callers invoke get_session repeatedly
            # with the same context_id but release it only once, so bumping on
            # every call would keep the session referenced forever.
            if self._context_to_session.get(context_id) != ref_key:
                self._context_to_session[context_id] = ref_key
                self._session_refcount[ref_key] = self._session_refcount.get(ref_key, 0) + 1
            return session

        logger.info(f"Session {session_id} for server {server_key} failed health check, cleaning up")
        await self._cleanup_session_by_id(server_key, session_id)

    # Check if we've reached the maximum number of sessions for this server
    # (the `sessions` guard keeps min() off an empty dict when the limit is 0).
    if sessions and len(sessions) >= MAX_SESSIONS_PER_SERVER:
        oldest_session_id = min(sessions, key=lambda sid: sessions[sid]["last_used"])
        logger.info(
            f"Maximum sessions reached for server {server_key}, removing oldest session {oldest_session_id}"
        )
        await self._cleanup_session_by_id(server_key, oldest_session_id)

    # Create new session
    session_id = unused_session_id(sessions)
    logger.info(f"Creating new session {session_id} for server {server_key}")

    if transport_type == "stdio":
        session, task = await self._create_stdio_session(session_id, connection_params)
    elif transport_type == "sse":
        session, task = await self._create_sse_session(session_id, connection_params)
    else:
        msg = f"Unknown transport type: {transport_type}"
        raise ValueError(msg)

    # While creation was awaited the (still empty) server entry may have been
    # pruned or replaced; re-attach to the live entry so the new session stays
    # reachable for later cleanup instead of being stored in an orphaned dict.
    server_data = self.sessions_by_server.setdefault(server_key, server_data)
    sessions = server_data["sessions"]
    if session_id in sessions:
        # A concurrent call registered the same id first; never overwrite it.
        session_id = unused_session_id(sessions)

    # Store session info
    sessions[session_id] = {
        "session": session,
        "task": task,
        "type": transport_type,
        "last_used": loop.time(),
    }

    # register mapping & initial ref-count for the new session
    ref_key = (server_key, session_id)
    self._context_to_session[context_id] = ref_key
    self._session_refcount[ref_key] = 1

    return session
_create_sse_session(self, session_id: str, connection_params): """Create a new SSE session as a background task to avoid context issues.""" import asyncio @@ -709,9 +758,7 @@ class MCPSessionManager: try: await event.wait() except asyncio.CancelledError: - # Session is being shut down - msg = "Message is shutting down" - logger.info(msg) + logger.info(f"Session {session_id} is shutting down") except Exception as e: # noqa: BLE001 if not session_future.done(): session_future.set_exception(e) @@ -723,7 +770,7 @@ class MCPSessionManager: # Wait for session to be ready try: - session = await asyncio.wait_for(session_future, timeout=10.0) # 10 second timeout for session creation + session = await asyncio.wait_for(session_future, timeout=10.0) except asyncio.TimeoutError as timeout_err: # Clean up the failed task if not task.done(): @@ -733,20 +780,29 @@ class MCPSessionManager: with contextlib.suppress(asyncio.CancelledError): await task self._background_tasks.discard(task) - msg = f"Timeout waiting for SSE session to initialize for context {context_id}" + msg = f"Timeout waiting for SSE session {session_id} to initialize" logger.error(msg) raise ValueError(msg) from timeout_err - else: - # Store session info - self.sessions[context_id] = {"session": session, "task": task, "type": "sse"} - return session - async def _cleanup_session(self, context_id: str): - """Clean up a session by cancelling its background task.""" - if context_id not in self.sessions: + return session, task + + async def _cleanup_session_by_id(self, server_key: str, session_id: str): + """Clean up a specific session by server key and session ID.""" + if server_key not in self.sessions_by_server: return - session_info = self.sessions[context_id] + server_data = self.sessions_by_server[server_key] + # Handle both old and new session structure + if isinstance(server_data, dict) and "sessions" in server_data: + sessions = server_data["sessions"] + else: + # Handle old structure where sessions were stored 
directly + sessions = server_data + + if session_id not in sessions: + return + + session_info = sessions[session_id] try: # Cancel the background task which will properly close the session if "task" in session_info: @@ -756,19 +812,72 @@ class MCPSessionManager: try: await task except asyncio.CancelledError: - logger.info(f"Issue cancelling task for context_id {context_id}") + logger.info(f"Cancelled task for session {session_id}") except Exception as e: # noqa: BLE001 - logger.info(f"issue cleaning up mcp session: {e}") + logger.warning(f"Error cleaning up session {session_id}: {e}") finally: - del self.sessions[context_id] - # Also clean up server tracking - if context_id in self._last_server_by_session: - del self._last_server_by_session[context_id] + # Remove from sessions dict + del sessions[session_id] async def cleanup_all(self): """Clean up all sessions.""" - for context_id in list(self.sessions.keys()): - await self._cleanup_session(context_id) + # Cancel periodic cleanup task + if self._cleanup_task and not self._cleanup_task.done(): + self._cleanup_task.cancel() + with contextlib.suppress(asyncio.CancelledError): + await self._cleanup_task + + # Clean up all sessions + for server_key in list(self.sessions_by_server.keys()): + server_data = self.sessions_by_server[server_key] + # Handle both old and new session structure + if isinstance(server_data, dict) and "sessions" in server_data: + sessions = server_data["sessions"] + else: + # Handle old structure where sessions were stored directly + sessions = server_data + + for session_id in list(sessions.keys()): + await self._cleanup_session_by_id(server_key, session_id) + + # Clear the sessions_by_server structure completely + self.sessions_by_server.clear() + + # Clear compatibility maps + self._context_to_session.clear() + self._session_refcount.clear() + + # Clear all background tasks + for task in list(self._background_tasks): + if not task.done(): + task.cancel() + with 
async def _cleanup_session(self, context_id: str):
    """Release the session reference held by *context_id*.

    Kept for callers that still clean up by context id. The ref-count of the
    session mapped to the context is decremented; the session itself is torn
    down via ``_cleanup_session_by_id`` only when that count reaches zero.
    The context's mapping is removed in either case.
    """
    target = self._context_to_session.get(context_id)
    if target:
        server_key, session_id = target
        key = (server_key, session_id)
        refs_left = self._session_refcount.get(key, 1) - 1

        if refs_left > 0:
            # Other contexts still use this session; just record the release.
            self._session_refcount[key] = refs_left
        else:
            await self._cleanup_session_by_id(server_key, session_id)
            self._session_refcount.pop(key, None)

        # The context no longer points at any session.
        self._context_to_session.pop(context_id, None)
        return

    logger.debug(f"No session mapping found for context_id {context_id}")
"""Properly close the connection and clean up resources.""" - # Clean up session using session manager - if self._session_context: - session_manager = self._get_session_manager() - await session_manager._cleanup_session(self._session_context) + async def _terminate_remote_session(self) -> None: + """Attempt to explicitly terminate the remote MCP session via HTTP DELETE (best-effort).""" + # Only relevant for SSE transport + if not self._connection_params or "url" not in self._connection_params: + return - self.session = None - self._connection_params = None - self._connected = False - self._session_context = None + url: str = self._connection_params["url"] + + # Retrieve session id from the underlying SDK if exposed + session_id = None + if getattr(self, "session", None) is not None: + # Common attributes in MCP python SDK: `session_id` or `id` + session_id = getattr(self.session, "session_id", None) or getattr(self.session, "id", None) + + headers: dict[str, str] = dict(self._connection_params.get("headers", {})) + if session_id: + headers["Mcp-Session-Id"] = str(session_id) + + try: + async with httpx.AsyncClient(timeout=5.0) as client: + await client.delete(url, headers=headers) + except Exception as e: # noqa: BLE001 + # DELETE is advisory—log and continue + logger.debug(f"Unable to send session DELETE to '{url}': {e}") async def run_tool(self, tool_name: str, arguments: dict[str, Any]) -> Any: """Run a tool with the given arguments using context-specific session. 
async def disconnect(self):
    """Properly close the connection and clean up resources.

    Runs the best-effort remote termination, releases this client's context
    through the session manager, and always resets the local connection state,
    even when one of those steps raises, so the client never keeps reporting a
    connection whose teardown has already started.

    Raises:
        Exception: Whatever remote termination or the session-manager cleanup
            raises; it propagates after the local state has been reset.
    """
    try:
        # Attempt best-effort remote session termination first
        await self._terminate_remote_session()

        # Clean up local session using the session manager
        if self._session_context:
            session_manager = self._get_session_manager()
            await session_manager._cleanup_session(self._session_context)
    finally:
        # Reset local state unconditionally
        self.session = None
        self._connection_params = None
        self._connected = False
        self._session_context = None
@pytest.fixture
def process_tracker():
    """Track subprocess count for memory leak detection.

    Yields:
        A ``(process, initial_count)`` tuple: the ``psutil.Process`` for the
        current process and the number of child processes (recursive) that
        existed before the test ran.

    After the test, any remaining child processes are terminated, and killed
    if they do not exit within three seconds.
    """
    process = psutil.Process()
    initial_count = len(process.children(recursive=True))

    yield process, initial_count

    # Cleanup any remaining child processes
    try:
        for child in process.children(recursive=True):
            try:
                child.terminate()
                child.wait(timeout=3)
            except (psutil.NoSuchProcess, psutil.TimeoutExpired):
                with contextlib.suppress(psutil.NoSuchProcess):
                    child.kill()
    except Exception:  # noqa: BLE001
        # loguru formats messages with str.format, so a "%s" placeholder would be
        # emitted literally; logger.exception already attaches the active
        # exception and its traceback.
        logger.exception("Error cleaning up child processes")
@pytest.mark.asyncio
@pytest.mark.skipif(not shutil.which("npx"), reason="Node.js not available")
async def test_session_health_check_and_recovery(mcp_server_params, process_tracker):
    """Test that unhealthy sessions are properly detected and recreated."""
    # process_tracker is requested for its teardown, which reaps stray children.
    _process, _initial_count = process_tracker

    session_manager = MCPSessionManager()

    try:
        # Create a session
        session1 = await session_manager.get_session("health_test", mcp_server_params, "stdio")
        tools_response = await session1.list_tools()
        assert len(tools_response.tools) > 0

        # Simulate the session dying by cancelling its background task.
        # sessions_by_server maps server_key -> {"sessions": {session_id: info}, ...},
        # so the per-session records live one level down, under "sessions".
        server_key = session_manager._get_server_key(mcp_server_params, "stdio")
        sessions = session_manager.sessions_by_server[server_key]["sessions"]
        assert sessions, "get_session should have registered a session for this server"

        for session_info in list(sessions.values()):
            task = session_info["task"]
            if not task.done():
                task.cancel()
                with contextlib.suppress(asyncio.CancelledError):
                    await task
            # Guard against the test passing without ever killing the session.
            assert task.done()

        # Wait a bit for the task to be cancelled
        await asyncio.sleep(1)

        # Try to get a session again - should create a new healthy one
        session2 = await session_manager.get_session("health_test_2", mcp_server_params, "stdio")
        tools_response = await session2.list_tools()
        assert len(tools_response.tools) > 0
        assert session2 is not session1, "a dead session must be replaced, not reused"

    finally:
        await session_manager.cleanup_all()
        await asyncio.sleep(2)
"Different server environments should generate different keys" + assert session1 is not session2 + + finally: + await session_manager.cleanup_all() + await asyncio.sleep(2) + + +@pytest.mark.asyncio +async def test_session_manager_server_key_generation(): + """Test that server key generation works correctly.""" + session_manager = MCPSessionManager() + + # Test stdio server key + stdio_params = StdioServerParameters( + command="test_command", + args=["arg1", "arg2"], + env={"TEST": "value"}, + ) + + key1 = session_manager._get_server_key(stdio_params, "stdio") + key2 = session_manager._get_server_key(stdio_params, "stdio") + + # Same parameters should generate same key + assert key1 == key2 + assert key1.startswith("stdio_") + + # Different parameters should generate different keys + stdio_params2 = StdioServerParameters( + command="different_command", + args=["arg1", "arg2"], + env={"TEST": "value"}, + ) + + key3 = session_manager._get_server_key(stdio_params2, "stdio") + assert key1 != key3 + + # Test SSE server key + sse_params = { + "url": "http://example.com/sse", + "headers": {"Authorization": "Bearer token"}, + "timeout_seconds": 30, + "sse_read_timeout_seconds": 30, + } + + sse_key1 = session_manager._get_server_key(sse_params, "sse") + sse_key2 = session_manager._get_server_key(sse_params, "sse") + + assert sse_key1 == sse_key2 + assert sse_key1.startswith("sse_") + + # Different URL should generate different key + sse_params2 = sse_params.copy() + sse_params2["url"] = "http://different.com/sse" + + sse_key3 = session_manager._get_server_key(sse_params2, "sse") + assert sse_key1 != sse_key3 + + +@pytest.mark.asyncio +async def test_session_manager_connectivity_validation(): + """Test session connectivity validation.""" + session_manager = MCPSessionManager() + + # Mock a session that responds to list_tools + class MockSession: + def __init__(self, should_fail=False): # noqa: FBT002 + self.should_fail = should_fail + + async def list_tools(self): + if 
self.should_fail: + msg = "Connection failed" + raise Exception(msg) # noqa: TRY002 + + class MockResponse: + def __init__(self): + self.tools = ["tool1", "tool2"] + + return MockResponse() + + # Test healthy session + healthy_session = MockSession(should_fail=False) + is_healthy = await session_manager._validate_session_connectivity(healthy_session) + assert is_healthy is True + + # Test unhealthy session + unhealthy_session = MockSession(should_fail=True) + is_healthy = await session_manager._validate_session_connectivity(unhealthy_session) + assert is_healthy is False + + # Test session that returns None + class MockNoneSession: + async def list_tools(self): + return None + + none_session = MockNoneSession() + is_healthy = await session_manager._validate_session_connectivity(none_session) + assert is_healthy is False + + +@pytest.mark.asyncio +async def test_session_manager_cleanup_all(): + """Test that cleanup_all properly cleans up all sessions.""" + session_manager = MCPSessionManager() + + # Mock some sessions using the correct structure + session_manager.sessions_by_server = { + "server1": { + "sessions": { + "session1": { + "session": "mock_session", + "task": asyncio.create_task(asyncio.sleep(10)), + "type": "stdio", + "last_used": asyncio.get_event_loop().time(), + } + } + }, + "server2": { + "sessions": { + "session2": { + "session": "mock_session", + "task": asyncio.create_task(asyncio.sleep(10)), + "type": "sse", + "last_used": asyncio.get_event_loop().time(), + } + } + }, + } + + # Add some background tasks + task1 = asyncio.create_task(asyncio.sleep(10)) + task2 = asyncio.create_task(asyncio.sleep(10)) + session_manager._background_tasks = {task1, task2} + + # Cleanup all + await session_manager.cleanup_all() + + # Verify cleanup + if hasattr(session_manager, "sessions_by_server"): + # For fixed version + assert len(session_manager.sessions_by_server) == 0 + elif hasattr(session_manager, "sessions"): + # For original version + assert 
len(session_manager.sessions) == 0 + + # Verify background tasks were cancelled + assert task1.done() + assert task2.done() diff --git a/src/backend/tests/unit/base/mcp/__init__.py b/src/backend/tests/unit/base/mcp/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/src/backend/tests/unit/base/mcp/test_mcp_util.py b/src/backend/tests/unit/base/mcp/test_mcp_util.py new file mode 100644 index 000000000..6fa047ba7 --- /dev/null +++ b/src/backend/tests/unit/base/mcp/test_mcp_util.py @@ -0,0 +1,806 @@ +"""Unit tests for MCP utility functions. + +This test suite validates the MCP utility functions including: +- Session management +- Header validation and processing +- Utility functions for name sanitization and schema conversion +""" + +import shutil +from unittest.mock import AsyncMock, MagicMock, patch + +import pytest +from langflow.base.mcp import util +from langflow.base.mcp.util import MCPSessionManager, MCPSseClient, MCPStdioClient, _process_headers, validate_headers + + +class TestMCPSessionManager: + @pytest.fixture + async def session_manager(self): + """Create a session manager and clean it up after the test.""" + manager = MCPSessionManager() + yield manager + # Clean up after test + await manager.cleanup_all() + + async def test_session_caching(self, session_manager): + """Test that sessions are properly cached and reused.""" + context_id = "test_context" + connection_params = MagicMock() + transport_type = "stdio" + + # Create a mock session that will appear healthy + mock_session = AsyncMock() + mock_session._write_stream = MagicMock() + mock_session._write_stream._closed = False + + # Create a mock task that appears to be running + mock_task = AsyncMock() + mock_task.done = MagicMock(return_value=False) + + with ( + patch.object(session_manager, "_create_stdio_session") as mock_create, + patch.object(session_manager, "_validate_session_connectivity", return_value=True), + ): + mock_create.return_value = (mock_session, mock_task) + + # 
First call should create session + session1 = await session_manager.get_session(context_id, connection_params, transport_type) + + # Second call should return cached session without creating new one + session2 = await session_manager.get_session(context_id, connection_params, transport_type) + + assert session1 == session2 + assert session1 == mock_session + # Should only create once since the second call should use the cached session + mock_create.assert_called_once() + + async def test_session_cleanup(self, session_manager): + """Test session cleanup functionality.""" + context_id = "test_context" + server_key = "test_server" + session_id = "test_session" + + # Add a session to the manager with proper mock setup using new structure + mock_task = AsyncMock() + mock_task.done = MagicMock(return_value=False) # Use MagicMock for sync method + mock_task.cancel = MagicMock() # Use MagicMock for sync method + + # Set up the new session structure + session_manager.sessions_by_server[server_key] = { + "sessions": {session_id: {"session": AsyncMock(), "task": mock_task, "type": "stdio", "last_used": 0}}, + "last_cleanup": 0, + } + + # Set up mapping for backwards compatibility + session_manager._context_to_session[context_id] = (server_key, session_id) + + await session_manager._cleanup_session(context_id) + + # Should cancel the task and remove from sessions + mock_task.cancel.assert_called_once() + assert session_id not in session_manager.sessions_by_server[server_key]["sessions"] + + async def test_server_switch_detection(self, session_manager): + """Test that server switches are properly detected and handled.""" + context_id = "test_context" + + # First server + server1_params = MagicMock() + server1_params.command = "server1" + + # Second server + server2_params = MagicMock() + server2_params.command = "server2" + + with ( + patch.object(session_manager, "_create_stdio_session") as mock_create, + patch.object(session_manager, "_validate_session_connectivity", 
return_value=True), + ): + mock_session1 = AsyncMock() + mock_session2 = AsyncMock() + mock_task1 = AsyncMock() + mock_task2 = AsyncMock() + mock_create.side_effect = [(mock_session1, mock_task1), (mock_session2, mock_task2)] + + # First connection + session1 = await session_manager.get_session(context_id, server1_params, "stdio") + + # Switch to different server should create new session + session2 = await session_manager.get_session(context_id, server2_params, "stdio") + + assert session1 != session2 + assert mock_create.call_count == 2 + + +class TestHeaderValidation: + """Test the header validation functionality.""" + + def test_validate_headers_valid_input(self): + """Test header validation with valid headers.""" + headers = {"Authorization": "Bearer token123", "Content-Type": "application/json", "X-API-Key": "secret-key"} + + result = validate_headers(headers) + + # Headers should be normalized to lowercase + expected = {"authorization": "Bearer token123", "content-type": "application/json", "x-api-key": "secret-key"} + assert result == expected + + def test_validate_headers_empty_input(self): + """Test header validation with empty/None input.""" + assert validate_headers({}) == {} + assert validate_headers(None) == {} + + def test_validate_headers_invalid_names(self): + """Test header validation with invalid header names.""" + headers = { + "Invalid Header": "value", # spaces not allowed + "Header@Name": "value", # @ not allowed + "Header Name": "value", # spaces not allowed + "Valid-Header": "value", # this should pass + } + + result = validate_headers(headers) + + # Only the valid header should remain + assert result == {"valid-header": "value"} + + def test_validate_headers_sanitize_values(self): + """Test header value sanitization.""" + headers = { + "Authorization": "Bearer \x00token\x1f with\r\ninjection", + "Clean-Header": " clean value ", + "Empty-After-Clean": "\x00\x01\x02", + "Tab-Header": "value\twith\ttabs", # tabs should be preserved + } + + 
result = validate_headers(headers) + + # Control characters should be removed, whitespace trimmed + # Header with injection attempts should be skipped + expected = {"clean-header": "clean value", "tab-header": "value\twith\ttabs"} + assert result == expected + + def test_validate_headers_non_string_values(self): + """Test header validation with non-string values.""" + headers = {"String-Header": "valid", "Number-Header": 123, "None-Header": None, "List-Header": ["value"]} + + result = validate_headers(headers) + + # Only string headers should remain + assert result == {"string-header": "valid"} + + def test_validate_headers_injection_attempts(self): + """Test header validation against injection attempts.""" + headers = { + "Injection1": "value\r\nInjected-Header: malicious", + "Injection2": "value\nX-Evil: attack", + "Safe-Header": "safe-value", + } + + result = validate_headers(headers) + + # Injection attempts should be filtered out + assert result == {"safe-header": "safe-value"} + + +class TestSSEHeaderIntegration: + """Integration test to verify headers are properly passed through the entire SSE flow.""" + + async def test_headers_processing(self): + """Test that headers flow properly from server config through to SSE client connection.""" + # Test the header processing function directly + headers_input = [ + {"key": "Authorization", "value": "Bearer test-token"}, + {"key": "X-API-Key", "value": "secret-key"}, + ] + + expected_headers = { + "authorization": "Bearer test-token", # normalized to lowercase + "x-api-key": "secret-key", + } + + # Test _process_headers function with validation + processed_headers = _process_headers(headers_input) + assert processed_headers == expected_headers + + # Test different input formats + # Test dict input with validation + dict_headers = {"Authorization": "Bearer dict-token", "Invalid Header": "bad"} + result = _process_headers(dict_headers) + # Invalid header should be filtered out, valid header normalized + assert result 
== {"authorization": "Bearer dict-token"} + + # Test None input + assert _process_headers(None) == {} + + # Test empty list + assert _process_headers([]) == {} + + # Test malformed list + malformed_headers = [{"key": "Auth"}, {"value": "token"}] # Missing value/key + assert _process_headers(malformed_headers) == {} + + # Test list with invalid header names + invalid_headers = [ + {"key": "Valid-Header", "value": "good"}, + {"key": "Invalid Header", "value": "bad"}, # spaces not allowed + ] + result = _process_headers(invalid_headers) + assert result == {"valid-header": "good"} + + async def test_sse_client_header_storage(self): + """Test that SSE client properly stores headers in connection params.""" + sse_client = MCPSseClient() + test_url = "http://test.url" + test_headers = {"Authorization": "Bearer test123", "Custom": "value"} + + # Test that headers are properly stored in connection params + # Set connection params as a dict like the implementation expects + sse_client._connection_params = { + "url": test_url, + "headers": test_headers, + "timeout_seconds": 30, + "sse_read_timeout_seconds": 30, + } + + # Verify headers are stored + assert sse_client._connection_params["url"] == test_url + assert sse_client._connection_params["headers"] == test_headers + + +class TestMCPUtilityFunctions: + """Test utility functions from util.py that don't have dedicated test classes.""" + + def test_sanitize_mcp_name(self): + """Test MCP name sanitization.""" + assert util.sanitize_mcp_name("Test Name 123") == "test_name_123" + assert util.sanitize_mcp_name(" ") == "" + assert util.sanitize_mcp_name("123abc") == "_123abc" + assert util.sanitize_mcp_name("Tést-😀-Námé") == "test_name" + assert util.sanitize_mcp_name("a" * 100) == "a" * 46 + + def test_get_unique_name(self): + """Test unique name generation.""" + names = {"foo", "foo_1"} + assert util.get_unique_name("foo", 10, names) == "foo_2" + assert util.get_unique_name("bar", 10, names) == "bar" + assert 
util.get_unique_name("longname", 4, {"long"}) == "lo_1" + + def test_is_valid_key_value_item(self): + """Test key-value item validation.""" + assert util._is_valid_key_value_item({"key": "a", "value": "b"}) is True + assert util._is_valid_key_value_item({"key": "a"}) is False + assert util._is_valid_key_value_item(["key", "value"]) is False + assert util._is_valid_key_value_item(None) is False + + def test_validate_node_installation(self): + """Test Node.js installation validation.""" + if shutil.which("node"): + assert util._validate_node_installation("npx something") == "npx something" + else: + with pytest.raises(ValueError, match="Node.js is not installed"): + util._validate_node_installation("npx something") + assert util._validate_node_installation("echo test") == "echo test" + + def test_create_input_schema_from_json_schema(self): + """Test JSON schema to Pydantic model conversion.""" + schema = { + "type": "object", + "properties": { + "foo": {"type": "string", "description": "desc"}, + "bar": {"type": "integer"}, + }, + "required": ["foo"], + } + model_class = util.create_input_schema_from_json_schema(schema) + instance = model_class(foo="abc", bar=1) + assert instance.foo == "abc" + assert instance.bar == 1 + + with pytest.raises(Exception): # noqa: B017, PT011 + model_class(bar=1) # missing required field + + @pytest.mark.asyncio + async def test_validate_connection_params(self): + """Test connection parameter validation.""" + # Valid parameters + await util._validate_connection_params("Stdio", command="echo test") + await util._validate_connection_params("SSE", url="http://test") + + # Invalid parameters + with pytest.raises(ValueError, match="Command is required for Stdio mode"): + await util._validate_connection_params("Stdio", command=None) + with pytest.raises(ValueError, match="URL is required for SSE mode"): + await util._validate_connection_params("SSE", url=None) + with pytest.raises(ValueError, match="Invalid mode"): + await 
util._validate_connection_params("InvalidMode") + + @pytest.mark.asyncio + async def test_get_flow_snake_case_mocked(self): + """Test flow lookup by snake case name with mocked session.""" + + class DummyFlow: + def __init__(self, name: str, user_id: str, *, is_component: bool = False, action_name: str | None = None): + self.name = name + self.user_id = user_id + self.is_component = is_component + self.action_name = action_name + + class DummyExec: + def __init__(self, flows: list[DummyFlow]): + self._flows = flows + + def all(self): + return self._flows + + class DummySession: + def __init__(self, flows: list[DummyFlow]): + self._flows = flows + + async def exec(self, stmt): # noqa: ARG002 + return DummyExec(self._flows) + + user_id = "123e4567-e89b-12d3-a456-426614174000" + flows = [DummyFlow("Test Flow", user_id), DummyFlow("Other", user_id)] + + # Should match sanitized name + result = await util.get_flow_snake_case(util.sanitize_mcp_name("Test Flow"), user_id, DummySession(flows)) + assert result is flows[0] + + # Should return None if not found + result = await util.get_flow_snake_case("notfound", user_id, DummySession(flows)) + assert result is None + + +class TestMCPStdioClientWithEverythingServer: + """Test MCPStdioClient with the Everything MCP server.""" + + @pytest.fixture + def stdio_client(self): + """Create a stdio client for testing.""" + return MCPStdioClient() + + @pytest.mark.asyncio + @pytest.mark.skipif(not shutil.which("npx"), reason="Node.js not available") + async def test_connect_to_everything_server(self, stdio_client): + """Test connecting to the Everything MCP server.""" + command = "npx -y @modelcontextprotocol/server-everything" + + try: + # Connect to the server + tools = await stdio_client.connect_to_server(command) + + # Verify tools were returned + assert len(tools) > 0 + + # Find the echo tool + echo_tool = None + for tool in tools: + if hasattr(tool, "name") and tool.name == "echo": + echo_tool = tool + break + + assert echo_tool 
is not None, "Echo tool not found in server tools" + assert echo_tool.description is not None + + # Verify the echo tool has the expected input schema + assert hasattr(echo_tool, "inputSchema") + assert echo_tool.inputSchema is not None + + finally: + # Clean up the connection + await stdio_client.disconnect() + + @pytest.mark.asyncio + @pytest.mark.skipif(not shutil.which("npx"), reason="Node.js not available") + async def test_run_echo_tool(self, stdio_client): + """Test running the echo tool from the Everything server.""" + command = "npx -y @modelcontextprotocol/server-everything" + + try: + # Connect to the server + tools = await stdio_client.connect_to_server(command) + + # Find the echo tool + echo_tool = None + for tool in tools: + if hasattr(tool, "name") and tool.name == "echo": + echo_tool = tool + break + + assert echo_tool is not None, "Echo tool not found" + + # Run the echo tool + test_message = "Hello, MCP!" + result = await stdio_client.run_tool("echo", {"message": test_message}) + + # Verify the result + assert result is not None + assert hasattr(result, "content") + assert len(result.content) > 0 + + # Check that the echo worked - content should contain our message + content_text = str(result.content[0]) + assert test_message in content_text or "Echo:" in content_text + + finally: + await stdio_client.disconnect() + + @pytest.mark.asyncio + @pytest.mark.skipif(not shutil.which("npx"), reason="Node.js not available") + async def test_list_all_tools(self, stdio_client): + """Test listing all available tools from the Everything server.""" + command = "npx -y @modelcontextprotocol/server-everything" + + try: + # Connect to the server + tools = await stdio_client.connect_to_server(command) + + # Verify we have multiple tools + assert len(tools) >= 3 # Everything server typically has several tools + + # Check that tools have the expected attributes + for tool in tools: + assert hasattr(tool, "name") + assert hasattr(tool, "description") + assert 
hasattr(tool, "inputSchema") + assert tool.name is not None + assert len(tool.name) > 0 + + # Common tools that should be available + expected_tools = ["echo"] # Echo is typically available + for expected_tool in expected_tools: + assert any(tool.name == expected_tool for tool in tools), f"Expected tool '{expected_tool}' not found" + + finally: + await stdio_client.disconnect() + + @pytest.mark.asyncio + @pytest.mark.skipif(not shutil.which("npx"), reason="Node.js not available") + async def test_session_reuse(self, stdio_client): + """Test that sessions are properly reused.""" + command = "npx -y @modelcontextprotocol/server-everything" + + try: + # Set session context + stdio_client.set_session_context("test_session_reuse") + + # Connect to the server + tools1 = await stdio_client.connect_to_server(command) + + # Connect again - should reuse the session + tools2 = await stdio_client.connect_to_server(command) + + # Should have the same tools + assert len(tools1) == len(tools2) + + # Run a tool to verify the session is working + result = await stdio_client.run_tool("echo", {"message": "Session reuse test"}) + assert result is not None + + finally: + await stdio_client.disconnect() + + +class TestMCPSseClientWithDeepWikiServer: + """Test MCPSseClient with the DeepWiki MCP server.""" + + @pytest.fixture + def sse_client(self): + """Create an SSE client for testing.""" + return MCPSseClient() + + @pytest.mark.asyncio + async def test_connect_to_deepwiki_server(self, sse_client): + """Test connecting to the DeepWiki MCP server.""" + url = "https://mcp.deepwiki.com/sse" + + try: + # Connect to the server + tools = await sse_client.connect_to_server(url) + + # Verify tools were returned + assert len(tools) > 0 + + # Check for expected DeepWiki tools + expected_tools = ["read_wiki_structure", "read_wiki_contents", "ask_question"] + + # Verify we have the expected tools + for expected_tool in expected_tools: + assert any(tool.name == expected_tool for tool in tools), 
f"Expected tool '{expected_tool}' not found" + + except Exception as e: + # If the server is not accessible, skip the test + pytest.skip(f"DeepWiki server not accessible: {e}") + finally: + await sse_client.disconnect() + + @pytest.mark.asyncio + async def test_run_wiki_structure_tool(self, sse_client): + """Test running the read_wiki_structure tool.""" + url = "https://mcp.deepwiki.com/sse" + + try: + # Connect to the server + tools = await sse_client.connect_to_server(url) + + # Find the read_wiki_structure tool + wiki_tool = None + for tool in tools: + if hasattr(tool, "name") and tool.name == "read_wiki_structure": + wiki_tool = tool + break + + assert wiki_tool is not None, "read_wiki_structure tool not found" + + # Run the tool with a test repository (use repoName as expected by the API) + result = await sse_client.run_tool("read_wiki_structure", {"repoName": "microsoft/vscode"}) + + # Verify the result + assert result is not None + assert hasattr(result, "content") + assert len(result.content) > 0 + + except Exception as e: + # If the server is not accessible or the tool fails, skip the test + pytest.skip(f"DeepWiki server test failed: {e}") + finally: + await sse_client.disconnect() + + @pytest.mark.asyncio + async def test_ask_question_tool(self, sse_client): + """Test running the ask_question tool.""" + url = "https://mcp.deepwiki.com/sse" + + try: + # Connect to the server + tools = await sse_client.connect_to_server(url) + + # Find the ask_question tool + ask_tool = None + for tool in tools: + if hasattr(tool, "name") and tool.name == "ask_question": + ask_tool = tool + break + + assert ask_tool is not None, "ask_question tool not found" + + # Run the tool with a test question (use repoName as expected by the API) + result = await sse_client.run_tool( + "ask_question", {"repoName": "microsoft/vscode", "question": "What is VS Code?"} + ) + + # Verify the result + assert result is not None + assert hasattr(result, "content") + assert len(result.content) > 
0 + + except Exception as e: + # If the server is not accessible or the tool fails, skip the test + pytest.skip(f"DeepWiki server test failed: {e}") + finally: + await sse_client.disconnect() + + @pytest.mark.asyncio + async def test_url_validation(self, sse_client): + """Test URL validation for SSE connections.""" + # Test valid URL + valid_url = "https://mcp.deepwiki.com/sse" + is_valid, error = await sse_client.validate_url(valid_url) + assert is_valid or error == "" # Either valid or accessible + + # Test invalid URL + invalid_url = "not_a_url" + is_valid, error = await sse_client.validate_url(invalid_url) + assert not is_valid + assert error != "" + + @pytest.mark.asyncio + async def test_redirect_handling(self, sse_client): + """Test redirect handling for SSE connections.""" + # Test with the DeepWiki URL + url = "https://mcp.deepwiki.com/sse" + + try: + # Check for redirects + final_url = await sse_client.pre_check_redirect(url) + + # Should return a URL (either original or redirected) + assert final_url is not None + assert isinstance(final_url, str) + assert final_url.startswith("http") + + except Exception as e: + # If the server is not accessible, skip the test + pytest.skip(f"DeepWiki server not accessible for redirect test: {e}") + + @pytest.fixture + def mock_tool(self): + """Create a mock MCP tool.""" + tool = MagicMock() + tool.name = "test_tool" + tool.description = "Test tool description" + tool.inputSchema = { + "type": "object", + "properties": {"test_param": {"type": "string", "description": "Test parameter"}}, + "required": ["test_param"], + } + return tool + + @pytest.fixture + def mock_session(self, mock_tool): + """Create a mock ClientSession.""" + session = AsyncMock() + session.initialize = AsyncMock() + list_tools_result = MagicMock() + list_tools_result.tools = [mock_tool] + session.list_tools = AsyncMock(return_value=list_tools_result) + session.call_tool = AsyncMock( + return_value=MagicMock(content=[MagicMock(model_dump=lambda: 
{"result": "success"})]) + ) + return session + + +class TestMCPSseClientUnit: + """Unit tests for MCPSseClient functionality.""" + + @pytest.fixture + def sse_client(self): + return MCPSseClient() + + @pytest.mark.asyncio + async def test_client_initialization(self, sse_client): + """Test that SSE client initializes correctly.""" + # Client should initialize with default values + assert sse_client.session is None + assert sse_client._connection_params is None + assert sse_client._connected is False + assert sse_client._session_context is None + + async def test_validate_url_valid(self, sse_client): + """Test URL validation with valid URL.""" + with patch("httpx.AsyncClient") as mock_client: + mock_response = MagicMock() + mock_response.status_code = 200 + mock_client.return_value.__aenter__.return_value.get.return_value = mock_response + + is_valid, error_msg = await sse_client.validate_url("http://test.url", {}) + + assert is_valid is True + assert error_msg == "" + + async def test_validate_url_invalid_format(self, sse_client): + """Test URL validation with invalid format.""" + is_valid, error_msg = await sse_client.validate_url("invalid-url", {}) + + assert is_valid is False + assert "Invalid URL format" in error_msg + + async def test_validate_url_with_404_response(self, sse_client): + """Test URL validation with 404 response (should be valid for SSE).""" + with patch("httpx.AsyncClient") as mock_client: + mock_response = MagicMock() + mock_response.status_code = 404 + mock_client.return_value.__aenter__.return_value.get.return_value = mock_response + + is_valid, error_msg = await sse_client.validate_url("http://test.url", {}) + + assert is_valid is True + assert error_msg == "" + + async def test_connect_to_server_with_headers(self, sse_client): + """Test connecting to server via SSE with custom headers.""" + test_url = "http://test.url" + test_headers = {"Authorization": "Bearer token123", "Custom-Header": "value"} + expected_headers = {"authorization": 
"Bearer token123", "custom-header": "value"} # normalized + + with ( + patch.object(sse_client, "validate_url", return_value=(True, "")), + patch.object(sse_client, "pre_check_redirect", return_value=test_url), + patch.object(sse_client, "_get_or_create_session") as mock_get_session, + ): + # Mock session + mock_session = AsyncMock() + mock_tool = MagicMock() + mock_tool.name = "test_tool" + list_tools_result = MagicMock() + list_tools_result.tools = [mock_tool] + mock_session.list_tools = AsyncMock(return_value=list_tools_result) + mock_get_session.return_value = mock_session + + tools = await sse_client.connect_to_server(test_url, test_headers) + + assert len(tools) == 1 + assert tools[0].name == "test_tool" + assert sse_client._connected is True + + # Verify headers are stored in connection params (normalized) + assert sse_client._connection_params is not None + assert sse_client._connection_params["headers"] == expected_headers + assert sse_client._connection_params["url"] == test_url + + async def test_headers_passed_to_session_manager(self, sse_client): + """Test that headers are properly passed to the session manager.""" + test_url = "http://test.url" + expected_headers = {"authorization": "Bearer token123", "x-api-key": "secret"} # normalized + + sse_client._session_context = "test_context" + sse_client._connection_params = { + "url": test_url, + "headers": expected_headers, # Use normalized headers + "timeout_seconds": 30, + "sse_read_timeout_seconds": 30, + } + + with patch.object(sse_client, "_get_session_manager") as mock_get_manager: + mock_manager = AsyncMock() + mock_session = AsyncMock() + mock_manager.get_session = AsyncMock(return_value=mock_session) + mock_get_manager.return_value = mock_manager + + result_session = await sse_client._get_or_create_session() + + # Verify session manager was called with correct parameters including normalized headers + mock_manager.get_session.assert_called_once_with("test_context", sse_client._connection_params, 
"sse") + assert result_session == mock_session + + async def test_pre_check_redirect_with_headers(self, sse_client): + """Test pre-check redirect functionality with custom headers.""" + test_url = "http://test.url" + redirect_url = "http://redirect.url" + # Use pre-validated headers since pre_check_redirect expects already validated headers + test_headers = {"authorization": "Bearer token123"} # already normalized + + with patch("httpx.AsyncClient") as mock_client: + mock_response = MagicMock() + mock_response.status_code = 307 + mock_response.headers.get.return_value = redirect_url + mock_client.return_value.__aenter__.return_value.get.return_value = mock_response + + result = await sse_client.pre_check_redirect(test_url, test_headers) + + assert result == redirect_url + # Verify validated headers were passed to the request + mock_client.return_value.__aenter__.return_value.get.assert_called_with( + test_url, timeout=2.0, headers={"Accept": "text/event-stream", **test_headers} + ) + + async def test_run_tool_with_retry_on_connection_error(self, sse_client): + """Test that run_tool retries on connection errors.""" + # Setup connection state + sse_client._connected = True + sse_client._connection_params = {"url": "http://test.url", "headers": {}} + sse_client._session_context = "test_context" + + call_count = 0 + + async def mock_get_session_side_effect(): + nonlocal call_count + call_count += 1 + session = AsyncMock() + if call_count == 1: + # First call fails with connection error + from anyio import ClosedResourceError + + session.call_tool = AsyncMock(side_effect=ClosedResourceError()) + else: + # Second call succeeds + mock_result = MagicMock() + session.call_tool = AsyncMock(return_value=mock_result) + return session + + with ( + patch.object(sse_client, "_get_or_create_session", side_effect=mock_get_session_side_effect), + patch.object(sse_client, "_get_session_manager") as mock_get_manager, + ): + mock_manager = AsyncMock() + mock_get_manager.return_value = 
mock_manager + + result = await sse_client.run_tool("test_tool", {"param": "value"}) + + # Should have retried and succeeded on second attempt + assert call_count == 2 + assert result is not None + # Should have cleaned up the failed session + mock_manager._cleanup_session.assert_called_once_with("test_context") diff --git a/src/backend/tests/unit/components/data/test_mcp_component.py b/src/backend/tests/unit/components/data/test_mcp_component.py index 339eec4a6..4b4e2d099 100644 --- a/src/backend/tests/unit/components/data/test_mcp_component.py +++ b/src/backend/tests/unit/components/data/test_mcp_component.py @@ -1,8 +1,15 @@ +"""Unit tests for MCP component with actual MCP servers. + +This test suite validates the MCP component functionality using real MCP servers: +- Everything server (stdio mode) - provides echo and other tools +- DeepWiki server (SSE mode) - provides wiki-related tools +""" + +import shutil from unittest.mock import AsyncMock, MagicMock, patch import pytest -from langflow.base.mcp import util -from langflow.base.mcp.util import MCPSessionManager, MCPSseClient, MCPStdioClient, _process_headers, validate_headers +from langflow.base.mcp.util import MCPSessionManager, MCPSseClient, MCPStdioClient from langflow.components.agents.mcp_component import MCPToolsComponent from tests.base import ComponentTestBaseWithoutClient, VersionComponentMapping @@ -18,8 +25,11 @@ class TestMCPToolsComponent(ComponentTestBaseWithoutClient): def default_kwargs(self): """Return the default kwargs for the component.""" return { + "mode": "Stdio", + "command": "npx -y @modelcontextprotocol/server-everything", + "sse_url": "https://mcp.deepwiki.com/sse", + "tool": "echo", "mcp_server": {"name": "test_server", "config": {"command": "uvx mcp-server-fetch"}}, - "tool": "", } @pytest.fixture @@ -27,34 +37,106 @@ class TestMCPToolsComponent(ComponentTestBaseWithoutClient): """Return the file names mapping for different versions.""" return [] - @pytest.fixture - def 
mock_tool(self): - """Create a mock MCP tool.""" - tool = MagicMock() - tool.name = "test_tool" - tool.description = "Test tool description" - tool.inputSchema = { - "type": "object", - "properties": {"test_param": {"type": "string", "description": "Test parameter"}}, - "required": ["test_param"], - } - return tool + @pytest.mark.asyncio + @pytest.mark.skipif(not shutil.which("npx"), reason="Node.js not available") + async def test_component_initialization(self, component_class, default_kwargs): + """Test that the component initializes correctly.""" + component = component_class(**default_kwargs) + + # Check that the component has the expected attributes + assert hasattr(component, "stdio_client") + assert hasattr(component, "sse_client") + assert isinstance(component.stdio_client, MCPStdioClient) + assert isinstance(component.sse_client, MCPSseClient) + + # Check that the component has a session manager + session_manager = component.stdio_client._get_session_manager() + assert isinstance(session_manager, MCPSessionManager) + + +class TestMCPToolsComponentIntegration: + """Integration tests for the MCPToolsComponent.""" @pytest.fixture - def mock_session(self, mock_tool): - """Create a mock ClientSession.""" - session = AsyncMock() - session.initialize = AsyncMock() - list_tools_result = MagicMock() - list_tools_result.tools = [mock_tool] - session.list_tools = AsyncMock(return_value=list_tools_result) - session.call_tool = AsyncMock( - return_value=MagicMock(content=[MagicMock(model_dump=lambda: {"result": "success"})]) - ) - return session + def component(self): + """Create a component for testing.""" + return MCPToolsComponent() + + @pytest.mark.asyncio + @pytest.mark.skipif(not shutil.which("npx"), reason="Node.js not available") + async def test_stdio_mode_integration(self, component): + """Test the component in stdio mode with Everything server.""" + # Configure for stdio mode + component.mode = "Stdio" + component.command = "npx -y 
@modelcontextprotocol/server-everything" + component.tool = "echo" + + try: + # Mock the update_tool_list method to simulate server connection + tools, server_info = await component.update_tool_list() + + # Should have tools + assert len(tools) > 0 + + # Should have server info + assert server_info is not None + assert isinstance(server_info, dict) + + except Exception as e: + # If the server is not accessible, skip the test + pytest.skip(f"Everything server not accessible: {e}") + + @pytest.mark.asyncio + async def test_sse_mode_integration(self, component): + """Test the component in SSE mode with DeepWiki server.""" + # Configure for SSE mode + component.mode = "SSE" + component.sse_url = "https://mcp.deepwiki.com/sse" + + try: + # Mock the update_tool_list method to simulate server connection + tools, server_info = await component.update_tool_list() + + # Should have tools + assert len(tools) > 0 + + # Should have server info + assert server_info is not None + assert isinstance(server_info, dict) + + except Exception as e: + # If the server is not accessible, skip the test + pytest.skip(f"DeepWiki server not accessible: {e}") + + @pytest.mark.asyncio + async def test_session_context_setting(self, component): + """Test that session context is properly set.""" + # Set session context + component.stdio_client.set_session_context("test_context") + component.sse_client.set_session_context("test_context") + + # Verify context was set + assert component.stdio_client._session_context == "test_context" + assert component.sse_client._session_context == "test_context" + + @pytest.mark.asyncio + async def test_session_manager_sharing(self, component): + """Test that session managers are shared through component cache.""" + # Get session managers + stdio_manager = component.stdio_client._get_session_manager() + sse_manager = component.sse_client._get_session_manager() + + # Both should be MCPSessionManager instances + assert isinstance(stdio_manager, MCPSessionManager) + 
assert isinstance(sse_manager, MCPSessionManager) + + # They should be the same instance (shared through cache) + assert stdio_manager is sse_manager -class TestMCPStdioClient: +class TestMCPComponentErrorHandling: + """Test error handling in MCP components.""" + @pytest.fixture def stdio_client(self): return MCPStdioClient() @@ -122,494 +204,3 @@ class TestMCPStdioClient: mock_manager._cleanup_session.assert_called_once_with("test_context") assert stdio_client.session is None assert stdio_client._connected is False - - -class TestMCPSseClient: - @pytest.fixture - def sse_client(self): - return MCPSseClient() - - async def test_validate_url_valid(self, sse_client): - """Test URL validation with valid URL.""" - with patch("httpx.AsyncClient") as mock_client: - mock_response = MagicMock() - mock_response.status_code = 200 - mock_client.return_value.__aenter__.return_value.get.return_value = mock_response - - is_valid, error_msg = await sse_client.validate_url("http://test.url", {}) - - assert is_valid is True - assert error_msg == "" - - async def test_validate_url_invalid_format(self, sse_client): - """Test URL validation with invalid format.""" - is_valid, error_msg = await sse_client.validate_url("invalid-url", {}) - - assert is_valid is False - assert "Invalid URL format" in error_msg - - async def test_validate_url_with_404_response(self, sse_client): - """Test URL validation with 404 response (should be valid for SSE).""" - with patch("httpx.AsyncClient") as mock_client: - mock_response = MagicMock() - mock_response.status_code = 404 - mock_client.return_value.__aenter__.return_value.get.return_value = mock_response - - is_valid, error_msg = await sse_client.validate_url("http://test.url", {}) - - assert is_valid is True - assert error_msg == "" - - async def test_connect_to_server_with_headers(self, sse_client): - """Test connecting to server via SSE with custom headers.""" - test_url = "http://test.url" - test_headers = {"Authorization": "Bearer token123", 
"Custom-Header": "value"} - expected_headers = {"authorization": "Bearer token123", "custom-header": "value"} # normalized - - with ( - patch.object(sse_client, "validate_url", return_value=(True, "")), - patch.object(sse_client, "pre_check_redirect", return_value=test_url), - patch.object(sse_client, "_get_or_create_session") as mock_get_session, - ): - # Mock session - mock_session = AsyncMock() - mock_tool = MagicMock() - mock_tool.name = "test_tool" - list_tools_result = MagicMock() - list_tools_result.tools = [mock_tool] - mock_session.list_tools = AsyncMock(return_value=list_tools_result) - mock_get_session.return_value = mock_session - - tools = await sse_client.connect_to_server(test_url, test_headers) - - assert len(tools) == 1 - assert tools[0].name == "test_tool" - assert sse_client._connected is True - - # Verify headers are stored in connection params (normalized) - assert sse_client._connection_params is not None - assert sse_client._connection_params["headers"] == expected_headers - assert sse_client._connection_params["url"] == test_url - - async def test_headers_passed_to_session_manager(self, sse_client): - """Test that headers are properly passed to the session manager.""" - test_url = "http://test.url" - expected_headers = {"authorization": "Bearer token123", "x-api-key": "secret"} # normalized - - sse_client._session_context = "test_context" - sse_client._connection_params = { - "url": test_url, - "headers": expected_headers, # Use normalized headers - "timeout_seconds": 30, - "sse_read_timeout_seconds": 30, - } - - with patch.object(sse_client, "_get_session_manager") as mock_get_manager: - mock_manager = AsyncMock() - mock_session = AsyncMock() - mock_manager.get_session = AsyncMock(return_value=mock_session) - mock_get_manager.return_value = mock_manager - - result_session = await sse_client._get_or_create_session() - - # Verify session manager was called with correct parameters including normalized headers - 
mock_manager.get_session.assert_called_once_with("test_context", sse_client._connection_params, "sse") - assert result_session == mock_session - - async def test_pre_check_redirect_with_headers(self, sse_client): - """Test pre-check redirect functionality with custom headers.""" - test_url = "http://test.url" - redirect_url = "http://redirect.url" - # Use pre-validated headers since pre_check_redirect expects already validated headers - test_headers = {"authorization": "Bearer token123"} # already normalized - - with patch("httpx.AsyncClient") as mock_client: - mock_response = MagicMock() - mock_response.status_code = 307 - mock_response.headers.get.return_value = redirect_url - mock_client.return_value.__aenter__.return_value.get.return_value = mock_response - - result = await sse_client.pre_check_redirect(test_url, test_headers) - - assert result == redirect_url - # Verify validated headers were passed to the request - mock_client.return_value.__aenter__.return_value.get.assert_called_with( - test_url, timeout=2.0, headers={"Accept": "text/event-stream", **test_headers} - ) - - async def test_run_tool_with_retry_on_connection_error(self, sse_client): - """Test that run_tool retries on connection errors.""" - # Setup connection state - sse_client._connected = True - sse_client._connection_params = {"url": "http://test.url", "headers": {}} - sse_client._session_context = "test_context" - - call_count = 0 - - async def mock_get_session_side_effect(): - nonlocal call_count - call_count += 1 - session = AsyncMock() - if call_count == 1: - # First call fails with connection error - from anyio import ClosedResourceError - - session.call_tool = AsyncMock(side_effect=ClosedResourceError()) - else: - # Second call succeeds - mock_result = MagicMock() - session.call_tool = AsyncMock(return_value=mock_result) - return session - - with ( - patch.object(sse_client, "_get_or_create_session", side_effect=mock_get_session_side_effect), - patch.object(sse_client, 
"_get_session_manager") as mock_get_manager, - ): - mock_manager = AsyncMock() - mock_get_manager.return_value = mock_manager - - result = await sse_client.run_tool("test_tool", {"param": "value"}) - - # Should have retried and succeeded on second attempt - assert call_count == 2 - assert result is not None - # Should have cleaned up the failed session - mock_manager._cleanup_session.assert_called_once_with("test_context") - - -class TestMCPSessionManager: - @pytest.fixture - def session_manager(self): - return MCPSessionManager() - - async def test_session_caching(self, session_manager): - """Test that sessions are properly cached and reused.""" - context_id = "test_context" - connection_params = MagicMock() - transport_type = "stdio" - - # Create a mock session that will appear healthy - mock_session = AsyncMock() - mock_session._write_stream = MagicMock() - mock_session._write_stream._closed = False - - # Create a mock task that appears to be running - mock_task = AsyncMock() - mock_task.done = MagicMock(return_value=False) - - with ( - patch.object(session_manager, "_create_stdio_session") as mock_create, - patch.object(session_manager, "_validate_session_connectivity", return_value=True), - ): - mock_create.return_value = mock_session - - # First call should create session - session1 = await session_manager.get_session(context_id, connection_params, transport_type) - - # Manually populate the sessions cache as if the session was created properly - session_manager.sessions[context_id] = {"session": mock_session, "task": mock_task, "type": transport_type} - - # Second call should return cached session without creating new one - session2 = await session_manager.get_session(context_id, connection_params, transport_type) - - assert session1 == session2 - assert session1 == mock_session - # Should only create once since the second call should use the cached session - mock_create.assert_called_once() - - async def test_session_cleanup(self, session_manager): - 
"""Test session cleanup functionality.""" - context_id = "test_context" - - # Add a session to the manager with proper mock setup - mock_task = AsyncMock() - mock_task.done = MagicMock(return_value=False) # Use MagicMock for sync method - mock_task.cancel = MagicMock() # Use MagicMock for sync method - - session_manager.sessions[context_id] = {"session": AsyncMock(), "task": mock_task, "type": "stdio"} - - await session_manager._cleanup_session(context_id) - - # Should cancel the task and remove from sessions - mock_task.cancel.assert_called_once() - assert context_id not in session_manager.sessions - - async def test_server_switch_detection(self, session_manager): - """Test that server switches are properly detected and handled.""" - context_id = "test_context" - - # First server - server1_params = MagicMock() - server1_params.command = "server1" - - # Second server - server2_params = MagicMock() - server2_params.command = "server2" - - with ( - patch.object(session_manager, "_create_stdio_session") as mock_create, - patch.object(session_manager, "_validate_session_connectivity", return_value=True), - ): - mock_session1 = AsyncMock() - mock_session2 = AsyncMock() - mock_create.side_effect = [mock_session1, mock_session2] - - # First connection - session1 = await session_manager.get_session(context_id, server1_params, "stdio") - - # Switch to different server should create new session - session2 = await session_manager.get_session(context_id, server2_params, "stdio") - - assert session1 != session2 - assert mock_create.call_count == 2 - - -# Integration test for header functionality -class TestHeaderValidation: - """Test the header validation functionality.""" - - def test_validate_headers_valid_input(self): - """Test header validation with valid headers.""" - headers = {"Authorization": "Bearer token123", "Content-Type": "application/json", "X-API-Key": "secret-key"} - - result = validate_headers(headers) - - # Headers should be normalized to lowercase - expected 
= {"authorization": "Bearer token123", "content-type": "application/json", "x-api-key": "secret-key"} - assert result == expected - - def test_validate_headers_empty_input(self): - """Test header validation with empty/None input.""" - assert validate_headers({}) == {} - assert validate_headers(None) == {} - - def test_validate_headers_invalid_names(self): - """Test header validation with invalid header names.""" - headers = { - "Invalid Header": "value", # spaces not allowed - "Header@Name": "value", # @ not allowed - "Header Name": "value", # spaces not allowed - "Valid-Header": "value", # this should pass - } - - result = validate_headers(headers) - - # Only the valid header should remain - assert result == {"valid-header": "value"} - - def test_validate_headers_sanitize_values(self): - """Test header value sanitization.""" - headers = { - "Authorization": "Bearer \x00token\x1f with\r\ninjection", - "Clean-Header": " clean value ", - "Empty-After-Clean": "\x00\x01\x02", - "Tab-Header": "value\twith\ttabs", # tabs should be preserved - } - - result = validate_headers(headers) - - # Control characters should be removed, whitespace trimmed - # Header with injection attempts should be skipped - expected = {"clean-header": "clean value", "tab-header": "value\twith\ttabs"} - assert result == expected - - def test_validate_headers_non_string_values(self): - """Test header validation with non-string values.""" - headers = {"String-Header": "valid", "Number-Header": 123, "None-Header": None, "List-Header": ["value"]} - - result = validate_headers(headers) - - # Only string headers should remain - assert result == {"string-header": "valid"} - - def test_validate_headers_injection_attempts(self): - """Test header validation against injection attempts.""" - headers = { - "Injection1": "value\r\nInjected-Header: malicious", - "Injection2": "value\nX-Evil: attack", - "Safe-Header": "safe-value", - } - - result = validate_headers(headers) - - # Injection attempts should be 
filtered out - assert result == {"safe-header": "safe-value"} - - -class TestSSEHeaderIntegration: - """Integration test to verify headers are properly passed through the entire SSE flow.""" - - async def test_headers_processing(self): - """Test that headers flow properly from server config through to SSE client connection.""" - # Test the header processing function directly - headers_input = [ - {"key": "Authorization", "value": "Bearer test-token"}, - {"key": "X-API-Key", "value": "secret-key"}, - ] - - expected_headers = { - "authorization": "Bearer test-token", # normalized to lowercase - "x-api-key": "secret-key", - } - - # Test _process_headers function with validation - processed_headers = _process_headers(headers_input) - assert processed_headers == expected_headers - - # Test different input formats - # Test dict input with validation - dict_headers = {"Authorization": "Bearer dict-token", "Invalid Header": "bad"} - result = _process_headers(dict_headers) - # Invalid header should be filtered out, valid header normalized - assert result == {"authorization": "Bearer dict-token"} - - # Test None input - assert _process_headers(None) == {} - - # Test empty list - assert _process_headers([]) == {} - - # Test malformed list - malformed_headers = [{"key": "Auth"}, {"value": "token"}] # Missing value/key - assert _process_headers(malformed_headers) == {} - - # Test list with invalid header names - invalid_headers = [ - {"key": "Valid-Header", "value": "good"}, - {"key": "Invalid Header", "value": "bad"}, # spaces not allowed - ] - result = _process_headers(invalid_headers) - assert result == {"valid-header": "good"} - - async def test_sse_client_header_storage(self): - """Test that SSE client properly stores headers in connection params.""" - sse_client = MCPSseClient() - test_url = "http://test.url" - test_headers = {"Authorization": "Bearer test123", "Custom": "value"} - expected_headers = {"authorization": "Bearer test123", "custom": "value"} # normalized - - 
with ( - patch.object(sse_client, "validate_url", return_value=(True, "")), - patch.object(sse_client, "pre_check_redirect", return_value=test_url), - patch.object(sse_client, "_get_or_create_session") as mock_get_session, - ): - mock_session = AsyncMock() - mock_tool = MagicMock() - mock_tool.name = "test_tool" - list_tools_result = MagicMock() - list_tools_result.tools = [mock_tool] - mock_session.list_tools = AsyncMock(return_value=list_tools_result) - mock_get_session.return_value = mock_session - - await sse_client.connect_to_server(test_url, test_headers) - - # Verify headers are stored correctly in connection params (normalized) - assert sse_client._connection_params is not None - assert sse_client._connection_params["headers"] == expected_headers - assert sse_client._connection_params["url"] == test_url - - -class TestMCPUtilityFunctions: - """Test utility functions from util.py that don't have dedicated test classes.""" - - def test_sanitize_mcp_name(self): - """Test MCP name sanitization.""" - assert util.sanitize_mcp_name("Test Name 123") == "test_name_123" - assert util.sanitize_mcp_name(" ") == "" - assert util.sanitize_mcp_name("123abc") == "_123abc" - assert util.sanitize_mcp_name("Tést-😀-Námé") == "test_name" - assert util.sanitize_mcp_name("a" * 100) == "a" * 46 - - def test_get_unique_name(self): - """Test unique name generation.""" - names = {"foo", "foo_1"} - assert util.get_unique_name("foo", 10, names) == "foo_2" - assert util.get_unique_name("bar", 10, names) == "bar" - assert util.get_unique_name("longname", 4, {"long"}) == "lo_1" - - def test_is_valid_key_value_item(self): - """Test key-value item validation.""" - assert util._is_valid_key_value_item({"key": "a", "value": "b"}) is True - assert util._is_valid_key_value_item({"key": "a"}) is False - assert util._is_valid_key_value_item(["key", "value"]) is False - assert util._is_valid_key_value_item(None) is False - - def test_validate_node_installation(self): - """Test Node.js installation 
validation.""" - import shutil - - if shutil.which("node"): - assert util._validate_node_installation("npx something") == "npx something" - else: - with pytest.raises(ValueError, match="Node.js is not installed"): - util._validate_node_installation("npx something") - assert util._validate_node_installation("echo test") == "echo test" - - def test_create_input_schema_from_json_schema(self): - """Test JSON schema to Pydantic model conversion.""" - schema = { - "type": "object", - "properties": { - "foo": {"type": "string", "description": "desc"}, - "bar": {"type": "integer"}, - }, - "required": ["foo"], - } - model_class = util.create_input_schema_from_json_schema(schema) - instance = model_class(foo="abc", bar=1) - assert instance.foo == "abc" - assert instance.bar == 1 - - with pytest.raises(Exception): # noqa: B017, PT011 - model_class(bar=1) # missing required field - - @pytest.mark.asyncio - async def test_validate_connection_params(self): - """Test connection parameter validation.""" - # Valid parameters - await util._validate_connection_params("Stdio", command="echo test") - await util._validate_connection_params("SSE", url="http://test") - - # Invalid parameters - with pytest.raises(ValueError, match="Command is required for Stdio mode"): - await util._validate_connection_params("Stdio", command=None) - with pytest.raises(ValueError, match="URL is required for SSE mode"): - await util._validate_connection_params("SSE", url=None) - with pytest.raises(ValueError, match="Invalid mode"): - await util._validate_connection_params("InvalidMode") - - @pytest.mark.asyncio - async def test_get_flow_snake_case_mocked(self): - """Test flow lookup by snake case name with mocked session.""" - - class DummyFlow: - def __init__(self, name: str, user_id: str, *, is_component: bool = False, action_name: str | None = None): - self.name = name - self.user_id = user_id - self.is_component = is_component - self.action_name = action_name - - class DummyExec: - def __init__(self, 
flows: list[DummyFlow]): - self._flows = flows - - def all(self): - return self._flows - - class DummySession: - def __init__(self, flows: list[DummyFlow]): - self._flows = flows - - async def exec(self, stmt): # noqa: ARG002 - return DummyExec(self._flows) - - user_id = "123e4567-e89b-12d3-a456-426614174000" - flows = [DummyFlow("Test Flow", user_id), DummyFlow("Other", user_id)] - - # Should match sanitized name - result = await util.get_flow_snake_case(util.sanitize_mcp_name("Test Flow"), user_id, DummySession(flows)) - assert result is flows[0] - - # Should return None if not found - result = await util.get_flow_snake_case("notfound", user_id, DummySession(flows)) - assert result is None