refactor(session): migrate to server-based session management and add tests (#9077)
* update MCP Tests * [autofix.ci] apply automated fixes * Update util.py * [autofix.ci] apply automated fixes * Refactor MCP session manager for better configurability and cleanup (#9176) * Add log rotation and header validation features Introduces support for log rotation via the LANGFLOW_LOG_ROTATION environment variable and CLI/config options, with documentation updates. Adds header validation and sanitization for MCP connections, ensuring RFC 7230 compliance and security. Frontend and backend now support passing custom headers for MCP servers. Includes extensive new and updated unit tests for header handling, MCP utilities, and log rotation. * Add unit tests for MCP utilities and update disconnect logic Added comprehensive unit tests for MCP utility functions, session management, header validation, and client classes in test_mcp_util.py. Updated MCPStdioClient and MCPSseClient disconnect methods for clearer session cleanup logic. Refactored test_mcp_component.py to remove redundant and duplicated tests, consolidating coverage in the new test suite. * [autofix.ci] apply automated fixes * Update test_mcp_memory_leak.py * Update util.py --------- Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com> Co-authored-by: Gabriel Luiz Freitas Almeida <gabriel@langflow.org>
This commit is contained in:
parent
80ebe03d94
commit
b093c1fadb
6 changed files with 1586 additions and 664 deletions
|
|
@ -1,4 +1,5 @@
|
|||
import asyncio
|
||||
import contextlib
|
||||
import os
|
||||
import platform
|
||||
import re
|
||||
|
|
@ -30,6 +31,13 @@ HTTP_NOT_FOUND = 404
|
|||
HTTP_BAD_REQUEST = 400
|
||||
HTTP_INTERNAL_SERVER_ERROR = 500
|
||||
|
||||
# MCP Session Manager constants
|
||||
settings = get_settings_service().settings
|
||||
MAX_SESSIONS_PER_SERVER = (
|
||||
settings.mcp_max_sessions_per_server
|
||||
) # Maximum number of sessions per server to prevent resource exhaustion
|
||||
SESSION_IDLE_TIMEOUT = settings.mcp_session_idle_timeout # 5 minutes idle timeout for sessions
|
||||
SESSION_CLEANUP_INTERVAL = settings.mcp_session_cleanup_interval # Cleanup interval in seconds
|
||||
# RFC 7230 compliant header name pattern: token = 1*tchar
|
||||
# tchar = "!" / "#" / "$" / "%" / "&" / "'" / "*" / "+" / "-" / "." /
|
||||
# "^" / "_" / "`" / "|" / "~" / DIGIT / ALPHA
|
||||
|
|
@ -460,12 +468,90 @@ async def _validate_connection_params(mode: str, command: str | None = None, url
|
|||
|
||||
|
||||
class MCPSessionManager:
|
||||
"""Manages persistent MCP sessions with proper context manager lifecycle."""
|
||||
"""Manages persistent MCP sessions with proper context manager lifecycle.
|
||||
|
||||
Fixed version that addresses the memory leak issue by:
|
||||
1. Session reuse based on server identity rather than unique context IDs
|
||||
2. Maximum session limits per server to prevent resource exhaustion
|
||||
3. Idle timeout for automatic session cleanup
|
||||
4. Periodic cleanup of stale sessions
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
self.sessions = {} # context_id -> session_info
|
||||
# Structure: server_key -> {"sessions": {session_id: session_info}, "last_cleanup": timestamp}
|
||||
self.sessions_by_server = {}
|
||||
self._background_tasks = set() # Keep references to background tasks
|
||||
self._last_server_by_session = {} # context_id -> server_name for tracking switches
|
||||
# Backwards-compatibility maps: which context_id uses which (server_key, session_id)
|
||||
self._context_to_session: dict[str, tuple[str, str]] = {}
|
||||
# Reference count for each active (server_key, session_id)
|
||||
self._session_refcount: dict[tuple[str, str], int] = {}
|
||||
self._cleanup_task = None
|
||||
self._start_cleanup_task()
|
||||
|
||||
def _start_cleanup_task(self):
|
||||
"""Start the periodic cleanup task."""
|
||||
if self._cleanup_task is None or self._cleanup_task.done():
|
||||
self._cleanup_task = asyncio.create_task(self._periodic_cleanup())
|
||||
self._background_tasks.add(self._cleanup_task)
|
||||
self._cleanup_task.add_done_callback(self._background_tasks.discard)
|
||||
|
||||
async def _periodic_cleanup(self):
|
||||
"""Periodically clean up idle sessions."""
|
||||
while True:
|
||||
try:
|
||||
await asyncio.sleep(SESSION_CLEANUP_INTERVAL)
|
||||
await self._cleanup_idle_sessions()
|
||||
except asyncio.CancelledError:
|
||||
break
|
||||
except (RuntimeError, KeyError, ClosedResourceError, ValueError, asyncio.TimeoutError) as e:
|
||||
# Handle common recoverable errors without stopping the cleanup loop
|
||||
logger.warning(f"Error in periodic cleanup: {e}")
|
||||
|
||||
async def _cleanup_idle_sessions(self):
|
||||
"""Clean up sessions that have been idle for too long."""
|
||||
current_time = asyncio.get_event_loop().time()
|
||||
servers_to_remove = []
|
||||
|
||||
for server_key, server_data in self.sessions_by_server.items():
|
||||
sessions = server_data.get("sessions", {})
|
||||
sessions_to_remove = []
|
||||
|
||||
for session_id, session_info in sessions.items():
|
||||
if current_time - session_info["last_used"] > SESSION_IDLE_TIMEOUT:
|
||||
sessions_to_remove.append(session_id)
|
||||
|
||||
# Clean up idle sessions
|
||||
for session_id in sessions_to_remove:
|
||||
logger.info(f"Cleaning up idle session {session_id} for server {server_key}")
|
||||
await self._cleanup_session_by_id(server_key, session_id)
|
||||
|
||||
# Remove server entry if no sessions left
|
||||
if not sessions:
|
||||
servers_to_remove.append(server_key)
|
||||
|
||||
# Clean up empty server entries
|
||||
for server_key in servers_to_remove:
|
||||
del self.sessions_by_server[server_key]
|
||||
|
||||
def _get_server_key(self, connection_params, transport_type: str) -> str:
|
||||
"""Generate a consistent server key based on connection parameters."""
|
||||
if transport_type == "stdio":
|
||||
if hasattr(connection_params, "command"):
|
||||
# Include command, args, and environment for uniqueness
|
||||
command_str = f"{connection_params.command} {' '.join(connection_params.args or [])}"
|
||||
env_str = str(sorted((connection_params.env or {}).items()))
|
||||
key_input = f"{command_str}|{env_str}"
|
||||
return f"stdio_{hash(key_input)}"
|
||||
elif transport_type == "sse" and (isinstance(connection_params, dict) and "url" in connection_params):
|
||||
# Include URL and headers for uniqueness
|
||||
url = connection_params["url"]
|
||||
headers = str(sorted((connection_params.get("headers", {})).items()))
|
||||
key_input = f"{url}|{headers}"
|
||||
return f"sse_{hash(key_input)}"
|
||||
|
||||
# Fallback to a generic key
|
||||
# TODO: add option for streamable HTTP in future.
|
||||
return f"{transport_type}_{hash(str(connection_params))}"
|
||||
|
||||
async def _validate_session_connectivity(self, session) -> bool:
|
||||
"""Validate that the session is actually usable by testing a simple operation."""
|
||||
|
|
@ -483,6 +569,7 @@ class MCPSessionManager:
|
|||
"ClosedResourceError" in str(type(e))
|
||||
or "Connection closed" in error_str
|
||||
or "Connection lost" in error_str
|
||||
or "Connection failed" in error_str
|
||||
or "Transport closed" in error_str
|
||||
or "Stream closed" in error_str
|
||||
):
|
||||
|
|
@ -510,117 +597,83 @@ class MCPSessionManager:
|
|||
return True
|
||||
|
||||
async def get_session(self, context_id: str, connection_params, transport_type: str):
|
||||
"""Get or create a persistent session."""
|
||||
# Extract server identifier from connection params for tracking
|
||||
server_identifier = None
|
||||
if transport_type == "stdio" and hasattr(connection_params, "command"):
|
||||
server_identifier = f"stdio_{connection_params.command}"
|
||||
elif transport_type == "sse" and isinstance(connection_params, dict) and "url" in connection_params:
|
||||
server_identifier = f"sse_{connection_params['url']}"
|
||||
"""Get or create a session with improved reuse strategy.
|
||||
|
||||
# Check if we're switching servers for this context
|
||||
server_switched = False
|
||||
if context_id in self._last_server_by_session:
|
||||
last_server = self._last_server_by_session[context_id]
|
||||
if last_server != server_identifier:
|
||||
server_switched = True
|
||||
logger.info(f"Detected server switch for context {context_id}: {last_server} -> {server_identifier}")
|
||||
The key insight is that we should reuse sessions based on the server
|
||||
identity (command + args for stdio, URL for SSE) rather than the context_id.
|
||||
This prevents creating a new subprocess for each unique context.
|
||||
"""
|
||||
server_key = self._get_server_key(connection_params, transport_type)
|
||||
|
||||
# Update server tracking
|
||||
if server_identifier:
|
||||
self._last_server_by_session[context_id] = server_identifier
|
||||
# Ensure server entry exists
|
||||
if server_key not in self.sessions_by_server:
|
||||
self.sessions_by_server[server_key] = {"sessions": {}, "last_cleanup": asyncio.get_event_loop().time()}
|
||||
|
||||
if context_id in self.sessions:
|
||||
session_info = self.sessions[context_id]
|
||||
# Check if session and background task are still alive
|
||||
try:
|
||||
session = session_info["session"]
|
||||
task = session_info["task"]
|
||||
server_data = self.sessions_by_server[server_key]
|
||||
sessions = server_data["sessions"]
|
||||
|
||||
# Break down the health check to understand why cleanup is triggered
|
||||
task_not_done = not task.done()
|
||||
# Try to find a healthy existing session
|
||||
for session_id, session_info in sessions.items():
|
||||
session = session_info["session"]
|
||||
task = session_info["task"]
|
||||
|
||||
# Additional check for stream health
|
||||
stream_is_healthy = True
|
||||
try:
|
||||
# Check if the session's write stream is still open
|
||||
if hasattr(session, "_write_stream"):
|
||||
write_stream = session._write_stream
|
||||
# Check if session is still alive
|
||||
if not task.done():
|
||||
# Update last used time
|
||||
session_info["last_used"] = asyncio.get_event_loop().time()
|
||||
|
||||
# Check for explicit closed state
|
||||
if hasattr(write_stream, "_closed") and write_stream._closed:
|
||||
stream_is_healthy = False
|
||||
# Check anyio stream state for send channels
|
||||
elif hasattr(write_stream, "_state") and hasattr(write_stream._state, "open_send_channels"):
|
||||
# Stream is healthy if there are open send channels
|
||||
stream_is_healthy = write_stream._state.open_send_channels > 0
|
||||
# Check for other stream closed indicators
|
||||
elif hasattr(write_stream, "is_closing") and callable(write_stream.is_closing):
|
||||
stream_is_healthy = not write_stream.is_closing()
|
||||
# If we can't determine state definitively, try a simple write test
|
||||
else:
|
||||
# For streams we can't easily check, assume healthy unless proven otherwise
|
||||
# The actual tool call will reveal if the stream is truly dead
|
||||
stream_is_healthy = True
|
||||
|
||||
except (AttributeError, TypeError) as e:
|
||||
# If we can't check stream health due to missing attributes,
|
||||
# assume it's healthy and let the tool call fail if it's not
|
||||
logger.debug(f"Could not check stream health for context_id {context_id}: {e}")
|
||||
stream_is_healthy = True
|
||||
|
||||
logger.debug(f"Session health check for context_id {context_id}:")
|
||||
logger.debug(f" - task_not_done: {task_not_done}")
|
||||
logger.debug(f" - stream_is_healthy: {stream_is_healthy}")
|
||||
|
||||
# For MCP ClientSession, we need both task and stream to be healthy
|
||||
session_is_healthy = task_not_done and stream_is_healthy
|
||||
|
||||
logger.debug(f" - session_is_healthy: {session_is_healthy}")
|
||||
|
||||
# If we switched servers, always recreate the session to avoid cross-server contamination
|
||||
if server_switched:
|
||||
logger.info(f"Server switch detected for context_id {context_id}, forcing session recreation")
|
||||
session_is_healthy = False
|
||||
|
||||
# Always run connectivity test for sessions to ensure they're truly responsive
|
||||
# This is especially important when switching between servers
|
||||
elif session_is_healthy:
|
||||
logger.debug(f"Running connectivity test for context_id {context_id}")
|
||||
connectivity_ok = await self._validate_session_connectivity(session)
|
||||
logger.debug(f" - connectivity_ok: {connectivity_ok}")
|
||||
if not connectivity_ok:
|
||||
session_is_healthy = False
|
||||
logger.info(
|
||||
f"Session for context_id {context_id} failed connectivity test, marking as unhealthy"
|
||||
)
|
||||
|
||||
if session_is_healthy:
|
||||
logger.debug(f"Session for context_id {context_id} is healthy and responsive, reusing")
|
||||
# Quick health check
|
||||
if await self._validate_session_connectivity(session):
|
||||
logger.debug(f"Reusing existing session {session_id} for server {server_key}")
|
||||
# record mapping & bump ref-count for backwards compatibility
|
||||
self._context_to_session[context_id] = (server_key, session_id)
|
||||
self._session_refcount[(server_key, session_id)] = (
|
||||
self._session_refcount.get((server_key, session_id), 0) + 1
|
||||
)
|
||||
return session
|
||||
logger.info(f"Session {session_id} for server {server_key} failed health check, cleaning up")
|
||||
await self._cleanup_session_by_id(server_key, session_id)
|
||||
else:
|
||||
# Task is done, clean up
|
||||
logger.info(f"Session {session_id} for server {server_key} task is done, cleaning up")
|
||||
await self._cleanup_session_by_id(server_key, session_id)
|
||||
|
||||
if not task_not_done:
|
||||
msg = f"Session for context_id {context_id} failed health check: background task is done"
|
||||
logger.info(msg)
|
||||
elif not stream_is_healthy:
|
||||
msg = f"Session for context_id {context_id} failed health check: stream is closed"
|
||||
logger.info(msg)
|
||||
|
||||
except Exception as e: # noqa: BLE001
|
||||
msg = f"Session for context_id {context_id} is dead due to exception: {e}"
|
||||
logger.info(msg)
|
||||
# Session is dead, clean it up
|
||||
await self._cleanup_session(context_id)
|
||||
# Check if we've reached the maximum number of sessions for this server
|
||||
if len(sessions) >= MAX_SESSIONS_PER_SERVER:
|
||||
# Remove the oldest session
|
||||
oldest_session_id = min(sessions.keys(), key=lambda x: sessions[x]["last_used"])
|
||||
logger.info(
|
||||
f"Maximum sessions reached for server {server_key}, removing oldest session {oldest_session_id}"
|
||||
)
|
||||
await self._cleanup_session_by_id(server_key, oldest_session_id)
|
||||
|
||||
# Create new session
|
||||
if transport_type == "stdio":
|
||||
return await self._create_stdio_session(context_id, connection_params)
|
||||
if transport_type == "sse":
|
||||
return await self._create_sse_session(context_id, connection_params)
|
||||
msg = f"Unknown transport type: {transport_type}"
|
||||
raise ValueError(msg)
|
||||
session_id = f"{server_key}_{len(sessions)}"
|
||||
logger.info(f"Creating new session {session_id} for server {server_key}")
|
||||
|
||||
async def _create_stdio_session(self, context_id: str, connection_params):
|
||||
if transport_type == "stdio":
|
||||
session, task = await self._create_stdio_session(session_id, connection_params)
|
||||
elif transport_type == "sse":
|
||||
session, task = await self._create_sse_session(session_id, connection_params)
|
||||
else:
|
||||
msg = f"Unknown transport type: {transport_type}"
|
||||
raise ValueError(msg)
|
||||
|
||||
# Store session info
|
||||
sessions[session_id] = {
|
||||
"session": session,
|
||||
"task": task,
|
||||
"type": transport_type,
|
||||
"last_used": asyncio.get_event_loop().time(),
|
||||
}
|
||||
|
||||
# register mapping & initial ref-count for the new session
|
||||
self._context_to_session[context_id] = (server_key, session_id)
|
||||
self._session_refcount[(server_key, session_id)] = 1
|
||||
|
||||
return session
|
||||
|
||||
async def _create_stdio_session(self, session_id: str, connection_params):
|
||||
"""Create a new stdio session as a background task to avoid context issues."""
|
||||
import asyncio
|
||||
|
||||
|
|
@ -646,9 +699,7 @@ class MCPSessionManager:
|
|||
try:
|
||||
await event.wait()
|
||||
except asyncio.CancelledError:
|
||||
# Session is being shut down
|
||||
msg = "Message is shutting down"
|
||||
logger.info(msg)
|
||||
logger.info(f"Session {session_id} is shutting down")
|
||||
except Exception as e: # noqa: BLE001
|
||||
if not session_future.done():
|
||||
session_future.set_exception(e)
|
||||
|
|
@ -660,7 +711,7 @@ class MCPSessionManager:
|
|||
|
||||
# Wait for session to be ready
|
||||
try:
|
||||
session = await asyncio.wait_for(session_future, timeout=10.0) # 10 second timeout for session creation
|
||||
session = await asyncio.wait_for(session_future, timeout=10.0)
|
||||
except asyncio.TimeoutError as timeout_err:
|
||||
# Clean up the failed task
|
||||
if not task.done():
|
||||
|
|
@ -670,15 +721,13 @@ class MCPSessionManager:
|
|||
with contextlib.suppress(asyncio.CancelledError):
|
||||
await task
|
||||
self._background_tasks.discard(task)
|
||||
msg = f"Timeout waiting for STDIO session to initialize for context {context_id}"
|
||||
msg = f"Timeout waiting for STDIO session {session_id} to initialize"
|
||||
logger.error(msg)
|
||||
raise ValueError(msg) from timeout_err
|
||||
else:
|
||||
# Store session info
|
||||
self.sessions[context_id] = {"session": session, "task": task, "type": "stdio"}
|
||||
return session
|
||||
|
||||
async def _create_sse_session(self, context_id: str, connection_params):
|
||||
return session, task
|
||||
|
||||
async def _create_sse_session(self, session_id: str, connection_params):
|
||||
"""Create a new SSE session as a background task to avoid context issues."""
|
||||
import asyncio
|
||||
|
||||
|
|
@ -709,9 +758,7 @@ class MCPSessionManager:
|
|||
try:
|
||||
await event.wait()
|
||||
except asyncio.CancelledError:
|
||||
# Session is being shut down
|
||||
msg = "Message is shutting down"
|
||||
logger.info(msg)
|
||||
logger.info(f"Session {session_id} is shutting down")
|
||||
except Exception as e: # noqa: BLE001
|
||||
if not session_future.done():
|
||||
session_future.set_exception(e)
|
||||
|
|
@ -723,7 +770,7 @@ class MCPSessionManager:
|
|||
|
||||
# Wait for session to be ready
|
||||
try:
|
||||
session = await asyncio.wait_for(session_future, timeout=10.0) # 10 second timeout for session creation
|
||||
session = await asyncio.wait_for(session_future, timeout=10.0)
|
||||
except asyncio.TimeoutError as timeout_err:
|
||||
# Clean up the failed task
|
||||
if not task.done():
|
||||
|
|
@ -733,20 +780,29 @@ class MCPSessionManager:
|
|||
with contextlib.suppress(asyncio.CancelledError):
|
||||
await task
|
||||
self._background_tasks.discard(task)
|
||||
msg = f"Timeout waiting for SSE session to initialize for context {context_id}"
|
||||
msg = f"Timeout waiting for SSE session {session_id} to initialize"
|
||||
logger.error(msg)
|
||||
raise ValueError(msg) from timeout_err
|
||||
else:
|
||||
# Store session info
|
||||
self.sessions[context_id] = {"session": session, "task": task, "type": "sse"}
|
||||
return session
|
||||
|
||||
async def _cleanup_session(self, context_id: str):
|
||||
"""Clean up a session by cancelling its background task."""
|
||||
if context_id not in self.sessions:
|
||||
return session, task
|
||||
|
||||
async def _cleanup_session_by_id(self, server_key: str, session_id: str):
|
||||
"""Clean up a specific session by server key and session ID."""
|
||||
if server_key not in self.sessions_by_server:
|
||||
return
|
||||
|
||||
session_info = self.sessions[context_id]
|
||||
server_data = self.sessions_by_server[server_key]
|
||||
# Handle both old and new session structure
|
||||
if isinstance(server_data, dict) and "sessions" in server_data:
|
||||
sessions = server_data["sessions"]
|
||||
else:
|
||||
# Handle old structure where sessions were stored directly
|
||||
sessions = server_data
|
||||
|
||||
if session_id not in sessions:
|
||||
return
|
||||
|
||||
session_info = sessions[session_id]
|
||||
try:
|
||||
# Cancel the background task which will properly close the session
|
||||
if "task" in session_info:
|
||||
|
|
@ -756,19 +812,72 @@ class MCPSessionManager:
|
|||
try:
|
||||
await task
|
||||
except asyncio.CancelledError:
|
||||
logger.info(f"Issue cancelling task for context_id {context_id}")
|
||||
logger.info(f"Cancelled task for session {session_id}")
|
||||
except Exception as e: # noqa: BLE001
|
||||
logger.info(f"issue cleaning up mcp session: {e}")
|
||||
logger.warning(f"Error cleaning up session {session_id}: {e}")
|
||||
finally:
|
||||
del self.sessions[context_id]
|
||||
# Also clean up server tracking
|
||||
if context_id in self._last_server_by_session:
|
||||
del self._last_server_by_session[context_id]
|
||||
# Remove from sessions dict
|
||||
del sessions[session_id]
|
||||
|
||||
async def cleanup_all(self):
|
||||
"""Clean up all sessions."""
|
||||
for context_id in list(self.sessions.keys()):
|
||||
await self._cleanup_session(context_id)
|
||||
# Cancel periodic cleanup task
|
||||
if self._cleanup_task and not self._cleanup_task.done():
|
||||
self._cleanup_task.cancel()
|
||||
with contextlib.suppress(asyncio.CancelledError):
|
||||
await self._cleanup_task
|
||||
|
||||
# Clean up all sessions
|
||||
for server_key in list(self.sessions_by_server.keys()):
|
||||
server_data = self.sessions_by_server[server_key]
|
||||
# Handle both old and new session structure
|
||||
if isinstance(server_data, dict) and "sessions" in server_data:
|
||||
sessions = server_data["sessions"]
|
||||
else:
|
||||
# Handle old structure where sessions were stored directly
|
||||
sessions = server_data
|
||||
|
||||
for session_id in list(sessions.keys()):
|
||||
await self._cleanup_session_by_id(server_key, session_id)
|
||||
|
||||
# Clear the sessions_by_server structure completely
|
||||
self.sessions_by_server.clear()
|
||||
|
||||
# Clear compatibility maps
|
||||
self._context_to_session.clear()
|
||||
self._session_refcount.clear()
|
||||
|
||||
# Clear all background tasks
|
||||
for task in list(self._background_tasks):
|
||||
if not task.done():
|
||||
task.cancel()
|
||||
with contextlib.suppress(asyncio.CancelledError):
|
||||
await task
|
||||
|
||||
async def _cleanup_session(self, context_id: str):
|
||||
"""Backward-compat cleanup by context_id.
|
||||
|
||||
Decrements the ref-count for the session used by *context_id* and only
|
||||
tears the session down when the last context that references it goes
|
||||
away.
|
||||
"""
|
||||
mapping = self._context_to_session.get(context_id)
|
||||
if not mapping:
|
||||
logger.debug(f"No session mapping found for context_id {context_id}")
|
||||
return
|
||||
|
||||
server_key, session_id = mapping
|
||||
ref_key = (server_key, session_id)
|
||||
remaining = self._session_refcount.get(ref_key, 1) - 1
|
||||
|
||||
if remaining <= 0:
|
||||
await self._cleanup_session_by_id(server_key, session_id)
|
||||
self._session_refcount.pop(ref_key, None)
|
||||
else:
|
||||
self._session_refcount[ref_key] = remaining
|
||||
|
||||
# Remove the mapping for this context
|
||||
self._context_to_session.pop(context_id, None)
|
||||
|
||||
|
||||
class MCPStdioClient:
|
||||
|
|
@ -963,11 +1072,15 @@ class MCPStdioClient:
|
|||
|
||||
async def disconnect(self):
|
||||
"""Properly close the connection and clean up resources."""
|
||||
# Clean up session using session manager
|
||||
# For stdio transport, there is no remote session to terminate explicitly
|
||||
# The session cleanup happens when the background task is cancelled
|
||||
|
||||
# Clean up local session using the session manager
|
||||
if self._session_context:
|
||||
session_manager = self._get_session_manager()
|
||||
await session_manager._cleanup_session(self._session_context)
|
||||
|
||||
# Reset local state
|
||||
self.session = None
|
||||
self._connection_params = None
|
||||
self._connected = False
|
||||
|
|
@ -1127,19 +1240,34 @@ class MCPSseClient:
|
|||
|
||||
# Use cached session manager to get/create persistent session
|
||||
session_manager = self._get_session_manager()
|
||||
return await session_manager.get_session(self._session_context, self._connection_params, "sse")
|
||||
# Cache session so we can access server-assigned session_id later for DELETE
|
||||
self.session = await session_manager.get_session(self._session_context, self._connection_params, "sse")
|
||||
return self.session
|
||||
|
||||
async def disconnect(self):
|
||||
"""Properly close the connection and clean up resources."""
|
||||
# Clean up session using session manager
|
||||
if self._session_context:
|
||||
session_manager = self._get_session_manager()
|
||||
await session_manager._cleanup_session(self._session_context)
|
||||
async def _terminate_remote_session(self) -> None:
|
||||
"""Attempt to explicitly terminate the remote MCP session via HTTP DELETE (best-effort)."""
|
||||
# Only relevant for SSE transport
|
||||
if not self._connection_params or "url" not in self._connection_params:
|
||||
return
|
||||
|
||||
self.session = None
|
||||
self._connection_params = None
|
||||
self._connected = False
|
||||
self._session_context = None
|
||||
url: str = self._connection_params["url"]
|
||||
|
||||
# Retrieve session id from the underlying SDK if exposed
|
||||
session_id = None
|
||||
if getattr(self, "session", None) is not None:
|
||||
# Common attributes in MCP python SDK: `session_id` or `id`
|
||||
session_id = getattr(self.session, "session_id", None) or getattr(self.session, "id", None)
|
||||
|
||||
headers: dict[str, str] = dict(self._connection_params.get("headers", {}))
|
||||
if session_id:
|
||||
headers["Mcp-Session-Id"] = str(session_id)
|
||||
|
||||
try:
|
||||
async with httpx.AsyncClient(timeout=5.0) as client:
|
||||
await client.delete(url, headers=headers)
|
||||
except Exception as e: # noqa: BLE001
|
||||
# DELETE is advisory—log and continue
|
||||
logger.debug(f"Unable to send session DELETE to '{url}': {e}")
|
||||
|
||||
async def run_tool(self, tool_name: str, arguments: dict[str, Any]) -> Any:
|
||||
"""Run a tool with the given arguments using context-specific session.
|
||||
|
|
@ -1253,6 +1381,22 @@ class MCPSseClient:
|
|||
logger.error(msg)
|
||||
raise ValueError(msg)
|
||||
|
||||
async def disconnect(self):
|
||||
"""Properly close the connection and clean up resources."""
|
||||
# Attempt best-effort remote session termination first
|
||||
await self._terminate_remote_session()
|
||||
|
||||
# Clean up local session using the session manager
|
||||
if self._session_context:
|
||||
session_manager = self._get_session_manager()
|
||||
await session_manager._cleanup_session(self._session_context)
|
||||
|
||||
# Reset local state
|
||||
self.session = None
|
||||
self._connection_params = None
|
||||
self._connected = False
|
||||
self._session_context = None
|
||||
|
||||
async def __aenter__(self):
|
||||
return self
|
||||
|
||||
|
|
|
|||
|
|
@ -96,6 +96,22 @@ class Settings(BaseSettings):
|
|||
"""The number of seconds to wait before giving up on a lock to released or establishing a connection to the
|
||||
database."""
|
||||
|
||||
# ---------------------------------------------------------------------
|
||||
# MCP Session-manager tuning
|
||||
# ---------------------------------------------------------------------
|
||||
mcp_max_sessions_per_server: int = 10
|
||||
"""Maximum number of MCP sessions to keep per unique server (command/url).
|
||||
Mirrors the default constant MAX_SESSIONS_PER_SERVER in util.py. Adjust to
|
||||
control resource usage or concurrency per server."""
|
||||
|
||||
mcp_session_idle_timeout: int = 400 # seconds
|
||||
"""How long (in seconds) an MCP session can stay idle before the background
|
||||
cleanup task disposes of it. Defaults to 5 minutes."""
|
||||
|
||||
mcp_session_cleanup_interval: int = 120 # seconds
|
||||
"""Frequency (in seconds) at which the background cleanup task wakes up to
|
||||
reap idle sessions."""
|
||||
|
||||
# sqlite configuration
|
||||
sqlite_pragmas: dict | None = {"synchronous": "NORMAL", "journal_mode": "WAL"}
|
||||
"""SQLite pragmas to use when connecting to the database."""
|
||||
|
|
|
|||
|
|
@ -0,0 +1,365 @@
|
|||
"""Integration tests for MCP memory leak fix.
|
||||
|
||||
These tests verify that the MCP session manager properly handles session reuse
|
||||
and cleanup to prevent subprocess leaks.
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import contextlib
|
||||
import os
|
||||
import platform
|
||||
import shutil
|
||||
|
||||
import psutil
|
||||
import pytest
|
||||
from langflow.base.mcp.util import MCPSessionManager
|
||||
from loguru import logger
|
||||
from mcp import StdioServerParameters
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mcp_server_params():
|
||||
"""Create MCP server parameters for testing."""
|
||||
command = ["npx", "-y", "@modelcontextprotocol/server-everything"]
|
||||
env_data = {"DEBUG": "true", "PATH": os.environ["PATH"]}
|
||||
|
||||
if platform.system() == "Windows":
|
||||
return StdioServerParameters(
|
||||
command="cmd",
|
||||
args=["/c", f"{command[0]} {' '.join(command[1:])}"],
|
||||
env=env_data,
|
||||
)
|
||||
return StdioServerParameters(
|
||||
command="bash",
|
||||
args=["-c", f"exec {' '.join(command)}"],
|
||||
env=env_data,
|
||||
)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def process_tracker():
|
||||
"""Track subprocess count for memory leak detection."""
|
||||
process = psutil.Process()
|
||||
initial_count = len(process.children(recursive=True))
|
||||
|
||||
yield process, initial_count
|
||||
|
||||
# Cleanup any remaining child processes
|
||||
try:
|
||||
for child in process.children(recursive=True):
|
||||
try:
|
||||
child.terminate()
|
||||
child.wait(timeout=3)
|
||||
except (psutil.NoSuchProcess, psutil.TimeoutExpired):
|
||||
with contextlib.suppress(psutil.NoSuchProcess):
|
||||
child.kill()
|
||||
except Exception as e:
|
||||
logger.exception("Error cleaning up child processes: %s", e)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@pytest.mark.skipif(not shutil.which("npx"), reason="Node.js not available")
|
||||
async def test_session_reuse_prevents_subprocess_leak(mcp_server_params, process_tracker):
|
||||
"""Test that session reuse prevents subprocess proliferation."""
|
||||
process, initial_count = process_tracker
|
||||
|
||||
session_manager = MCPSessionManager()
|
||||
|
||||
try:
|
||||
# Create multiple sessions with different context IDs but same server
|
||||
sessions = []
|
||||
for i in range(3):
|
||||
context_id = f"test_context_{i}"
|
||||
session = await session_manager.get_session(context_id, mcp_server_params, "stdio")
|
||||
sessions.append(session)
|
||||
|
||||
# Verify session is working
|
||||
tools_response = await session.list_tools()
|
||||
assert len(tools_response.tools) > 0
|
||||
|
||||
# Check subprocess count after creating sessions
|
||||
current_count = len(process.children(recursive=True))
|
||||
subprocess_increase = current_count - initial_count
|
||||
|
||||
# With the fix, we should have minimal subprocess increase
|
||||
# (ideally 2 subprocesses max for the MCP server)
|
||||
assert subprocess_increase <= 4, f"Too many subprocesses created: {subprocess_increase}"
|
||||
|
||||
# Verify all sessions are functional
|
||||
for session in sessions:
|
||||
tools_response = await session.list_tools()
|
||||
assert len(tools_response.tools) > 0
|
||||
|
||||
finally:
|
||||
await session_manager.cleanup_all()
|
||||
await asyncio.sleep(2) # Allow cleanup to complete
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@pytest.mark.skipif(not shutil.which("npx"), reason="Node.js not available")
|
||||
async def test_session_cleanup_removes_subprocesses(mcp_server_params, process_tracker):
|
||||
"""Test that session cleanup properly removes subprocesses."""
|
||||
process, initial_count = process_tracker
|
||||
|
||||
session_manager = MCPSessionManager()
|
||||
|
||||
try:
|
||||
# Create a session
|
||||
session = await session_manager.get_session("cleanup_test", mcp_server_params, "stdio")
|
||||
tools_response = await session.list_tools()
|
||||
assert len(tools_response.tools) > 0
|
||||
|
||||
# Verify subprocess was created
|
||||
after_creation_count = len(process.children(recursive=True))
|
||||
assert after_creation_count > initial_count
|
||||
|
||||
finally:
|
||||
# Clean up session
|
||||
await session_manager.cleanup_all()
|
||||
await asyncio.sleep(2) # Allow cleanup to complete
|
||||
|
||||
# Verify subprocess was cleaned up
|
||||
after_cleanup_count = len(process.children(recursive=True))
|
||||
# Allow some tolerance for cleanup timing and system processes
|
||||
assert after_cleanup_count <= initial_count + 1, (
|
||||
f"Subprocesses not cleaned up properly: {after_cleanup_count} vs {initial_count}"
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@pytest.mark.skipif(not shutil.which("npx"), reason="Node.js not available")
|
||||
async def test_session_health_check_and_recovery(mcp_server_params, process_tracker):
|
||||
"""Test that unhealthy sessions are properly detected and recreated."""
|
||||
process, initial_count = process_tracker
|
||||
|
||||
session_manager = MCPSessionManager()
|
||||
|
||||
try:
|
||||
# Create a session
|
||||
session1 = await session_manager.get_session("health_test", mcp_server_params, "stdio")
|
||||
tools_response = await session1.list_tools()
|
||||
assert len(tools_response.tools) > 0
|
||||
|
||||
# Simulate session becoming unhealthy by accessing internal state
|
||||
# This is a bit of a hack but necessary for testing
|
||||
server_key = session_manager._get_server_key(mcp_server_params, "stdio")
|
||||
if hasattr(session_manager, "sessions_by_server"):
|
||||
# For the fixed version
|
||||
sessions = session_manager.sessions_by_server.get(server_key, {})
|
||||
if sessions:
|
||||
session_id = next(iter(sessions.keys()))
|
||||
session_info = sessions[session_id]
|
||||
if "task" in session_info:
|
||||
task = session_info["task"]
|
||||
if not task.done():
|
||||
task.cancel()
|
||||
with contextlib.suppress(asyncio.CancelledError):
|
||||
await task
|
||||
elif hasattr(session_manager, "sessions"):
|
||||
# For the original version
|
||||
for session_info in session_manager.sessions.values():
|
||||
if "task" in session_info:
|
||||
task = session_info["task"]
|
||||
if not task.done():
|
||||
task.cancel()
|
||||
with contextlib.suppress(asyncio.CancelledError):
|
||||
await task
|
||||
|
||||
# Wait a bit for the task to be cancelled
|
||||
await asyncio.sleep(1)
|
||||
|
||||
# Try to get a session again - should create a new healthy one
|
||||
session2 = await session_manager.get_session("health_test_2", mcp_server_params, "stdio")
|
||||
tools_response = await session2.list_tools()
|
||||
assert len(tools_response.tools) > 0
|
||||
|
||||
finally:
|
||||
await session_manager.cleanup_all()
|
||||
await asyncio.sleep(2)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@pytest.mark.skipif(not shutil.which("npx"), reason="Node.js not available")
|
||||
async def test_multiple_servers_isolation(process_tracker):
|
||||
"""Test that different servers get separate sessions."""
|
||||
process, initial_count = process_tracker
|
||||
|
||||
session_manager = MCPSessionManager()
|
||||
|
||||
# Create parameters for different servers
|
||||
server1_params = StdioServerParameters(
|
||||
command="bash",
|
||||
args=["-c", "exec npx -y @modelcontextprotocol/server-everything"],
|
||||
env={"DEBUG": "true", "PATH": os.environ["PATH"]},
|
||||
)
|
||||
|
||||
server2_params = StdioServerParameters(
|
||||
command="bash",
|
||||
args=["-c", "exec npx -y @modelcontextprotocol/server-everything"],
|
||||
env={"DEBUG": "false", "PATH": os.environ["PATH"]}, # Different env
|
||||
)
|
||||
|
||||
try:
|
||||
# Create sessions for different servers
|
||||
session1 = await session_manager.get_session("server1_test", server1_params, "stdio")
|
||||
session2 = await session_manager.get_session("server2_test", server2_params, "stdio")
|
||||
|
||||
# Verify both sessions work
|
||||
tools1 = await session1.list_tools()
|
||||
tools2 = await session2.list_tools()
|
||||
|
||||
assert len(tools1.tools) > 0
|
||||
assert len(tools2.tools) > 0
|
||||
|
||||
# Sessions should be different objects for different servers (different environments)
|
||||
# Since the servers have different environments, they should get different server keys
|
||||
server_key1 = session_manager._get_server_key(server1_params, "stdio")
|
||||
server_key2 = session_manager._get_server_key(server2_params, "stdio")
|
||||
assert server_key1 != server_key2, "Different server environments should generate different keys"
|
||||
assert session1 is not session2
|
||||
|
||||
finally:
|
||||
await session_manager.cleanup_all()
|
||||
await asyncio.sleep(2)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_session_manager_server_key_generation():
|
||||
"""Test that server key generation works correctly."""
|
||||
session_manager = MCPSessionManager()
|
||||
|
||||
# Test stdio server key
|
||||
stdio_params = StdioServerParameters(
|
||||
command="test_command",
|
||||
args=["arg1", "arg2"],
|
||||
env={"TEST": "value"},
|
||||
)
|
||||
|
||||
key1 = session_manager._get_server_key(stdio_params, "stdio")
|
||||
key2 = session_manager._get_server_key(stdio_params, "stdio")
|
||||
|
||||
# Same parameters should generate same key
|
||||
assert key1 == key2
|
||||
assert key1.startswith("stdio_")
|
||||
|
||||
# Different parameters should generate different keys
|
||||
stdio_params2 = StdioServerParameters(
|
||||
command="different_command",
|
||||
args=["arg1", "arg2"],
|
||||
env={"TEST": "value"},
|
||||
)
|
||||
|
||||
key3 = session_manager._get_server_key(stdio_params2, "stdio")
|
||||
assert key1 != key3
|
||||
|
||||
# Test SSE server key
|
||||
sse_params = {
|
||||
"url": "http://example.com/sse",
|
||||
"headers": {"Authorization": "Bearer token"},
|
||||
"timeout_seconds": 30,
|
||||
"sse_read_timeout_seconds": 30,
|
||||
}
|
||||
|
||||
sse_key1 = session_manager._get_server_key(sse_params, "sse")
|
||||
sse_key2 = session_manager._get_server_key(sse_params, "sse")
|
||||
|
||||
assert sse_key1 == sse_key2
|
||||
assert sse_key1.startswith("sse_")
|
||||
|
||||
# Different URL should generate different key
|
||||
sse_params2 = sse_params.copy()
|
||||
sse_params2["url"] = "http://different.com/sse"
|
||||
|
||||
sse_key3 = session_manager._get_server_key(sse_params2, "sse")
|
||||
assert sse_key1 != sse_key3
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_session_manager_connectivity_validation():
|
||||
"""Test session connectivity validation."""
|
||||
session_manager = MCPSessionManager()
|
||||
|
||||
# Mock a session that responds to list_tools
|
||||
class MockSession:
|
||||
def __init__(self, should_fail=False): # noqa: FBT002
|
||||
self.should_fail = should_fail
|
||||
|
||||
async def list_tools(self):
|
||||
if self.should_fail:
|
||||
msg = "Connection failed"
|
||||
raise Exception(msg) # noqa: TRY002
|
||||
|
||||
class MockResponse:
|
||||
def __init__(self):
|
||||
self.tools = ["tool1", "tool2"]
|
||||
|
||||
return MockResponse()
|
||||
|
||||
# Test healthy session
|
||||
healthy_session = MockSession(should_fail=False)
|
||||
is_healthy = await session_manager._validate_session_connectivity(healthy_session)
|
||||
assert is_healthy is True
|
||||
|
||||
# Test unhealthy session
|
||||
unhealthy_session = MockSession(should_fail=True)
|
||||
is_healthy = await session_manager._validate_session_connectivity(unhealthy_session)
|
||||
assert is_healthy is False
|
||||
|
||||
# Test session that returns None
|
||||
class MockNoneSession:
|
||||
async def list_tools(self):
|
||||
return None
|
||||
|
||||
none_session = MockNoneSession()
|
||||
is_healthy = await session_manager._validate_session_connectivity(none_session)
|
||||
assert is_healthy is False
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_session_manager_cleanup_all():
|
||||
"""Test that cleanup_all properly cleans up all sessions."""
|
||||
session_manager = MCPSessionManager()
|
||||
|
||||
# Mock some sessions using the correct structure
|
||||
session_manager.sessions_by_server = {
|
||||
"server1": {
|
||||
"sessions": {
|
||||
"session1": {
|
||||
"session": "mock_session",
|
||||
"task": asyncio.create_task(asyncio.sleep(10)),
|
||||
"type": "stdio",
|
||||
"last_used": asyncio.get_event_loop().time(),
|
||||
}
|
||||
}
|
||||
},
|
||||
"server2": {
|
||||
"sessions": {
|
||||
"session2": {
|
||||
"session": "mock_session",
|
||||
"task": asyncio.create_task(asyncio.sleep(10)),
|
||||
"type": "sse",
|
||||
"last_used": asyncio.get_event_loop().time(),
|
||||
}
|
||||
}
|
||||
},
|
||||
}
|
||||
|
||||
# Add some background tasks
|
||||
task1 = asyncio.create_task(asyncio.sleep(10))
|
||||
task2 = asyncio.create_task(asyncio.sleep(10))
|
||||
session_manager._background_tasks = {task1, task2}
|
||||
|
||||
# Cleanup all
|
||||
await session_manager.cleanup_all()
|
||||
|
||||
# Verify cleanup
|
||||
if hasattr(session_manager, "sessions_by_server"):
|
||||
# For fixed version
|
||||
assert len(session_manager.sessions_by_server) == 0
|
||||
elif hasattr(session_manager, "sessions"):
|
||||
# For original version
|
||||
assert len(session_manager.sessions) == 0
|
||||
|
||||
# Verify background tasks were cancelled
|
||||
assert task1.done()
|
||||
assert task2.done()
|
||||
0
src/backend/tests/unit/base/mcp/__init__.py
Normal file
0
src/backend/tests/unit/base/mcp/__init__.py
Normal file
806
src/backend/tests/unit/base/mcp/test_mcp_util.py
Normal file
806
src/backend/tests/unit/base/mcp/test_mcp_util.py
Normal file
|
|
@ -0,0 +1,806 @@
|
|||
"""Unit tests for MCP utility functions.
|
||||
|
||||
This test suite validates the MCP utility functions including:
|
||||
- Session management
|
||||
- Header validation and processing
|
||||
- Utility functions for name sanitization and schema conversion
|
||||
"""
|
||||
|
||||
import shutil
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
|
||||
import pytest
|
||||
from langflow.base.mcp import util
|
||||
from langflow.base.mcp.util import MCPSessionManager, MCPSseClient, MCPStdioClient, _process_headers, validate_headers
|
||||
|
||||
|
||||
class TestMCPSessionManager:
|
||||
@pytest.fixture
|
||||
async def session_manager(self):
|
||||
"""Create a session manager and clean it up after the test."""
|
||||
manager = MCPSessionManager()
|
||||
yield manager
|
||||
# Clean up after test
|
||||
await manager.cleanup_all()
|
||||
|
||||
async def test_session_caching(self, session_manager):
|
||||
"""Test that sessions are properly cached and reused."""
|
||||
context_id = "test_context"
|
||||
connection_params = MagicMock()
|
||||
transport_type = "stdio"
|
||||
|
||||
# Create a mock session that will appear healthy
|
||||
mock_session = AsyncMock()
|
||||
mock_session._write_stream = MagicMock()
|
||||
mock_session._write_stream._closed = False
|
||||
|
||||
# Create a mock task that appears to be running
|
||||
mock_task = AsyncMock()
|
||||
mock_task.done = MagicMock(return_value=False)
|
||||
|
||||
with (
|
||||
patch.object(session_manager, "_create_stdio_session") as mock_create,
|
||||
patch.object(session_manager, "_validate_session_connectivity", return_value=True),
|
||||
):
|
||||
mock_create.return_value = (mock_session, mock_task)
|
||||
|
||||
# First call should create session
|
||||
session1 = await session_manager.get_session(context_id, connection_params, transport_type)
|
||||
|
||||
# Second call should return cached session without creating new one
|
||||
session2 = await session_manager.get_session(context_id, connection_params, transport_type)
|
||||
|
||||
assert session1 == session2
|
||||
assert session1 == mock_session
|
||||
# Should only create once since the second call should use the cached session
|
||||
mock_create.assert_called_once()
|
||||
|
||||
async def test_session_cleanup(self, session_manager):
|
||||
"""Test session cleanup functionality."""
|
||||
context_id = "test_context"
|
||||
server_key = "test_server"
|
||||
session_id = "test_session"
|
||||
|
||||
# Add a session to the manager with proper mock setup using new structure
|
||||
mock_task = AsyncMock()
|
||||
mock_task.done = MagicMock(return_value=False) # Use MagicMock for sync method
|
||||
mock_task.cancel = MagicMock() # Use MagicMock for sync method
|
||||
|
||||
# Set up the new session structure
|
||||
session_manager.sessions_by_server[server_key] = {
|
||||
"sessions": {session_id: {"session": AsyncMock(), "task": mock_task, "type": "stdio", "last_used": 0}},
|
||||
"last_cleanup": 0,
|
||||
}
|
||||
|
||||
# Set up mapping for backwards compatibility
|
||||
session_manager._context_to_session[context_id] = (server_key, session_id)
|
||||
|
||||
await session_manager._cleanup_session(context_id)
|
||||
|
||||
# Should cancel the task and remove from sessions
|
||||
mock_task.cancel.assert_called_once()
|
||||
assert session_id not in session_manager.sessions_by_server[server_key]["sessions"]
|
||||
|
||||
async def test_server_switch_detection(self, session_manager):
|
||||
"""Test that server switches are properly detected and handled."""
|
||||
context_id = "test_context"
|
||||
|
||||
# First server
|
||||
server1_params = MagicMock()
|
||||
server1_params.command = "server1"
|
||||
|
||||
# Second server
|
||||
server2_params = MagicMock()
|
||||
server2_params.command = "server2"
|
||||
|
||||
with (
|
||||
patch.object(session_manager, "_create_stdio_session") as mock_create,
|
||||
patch.object(session_manager, "_validate_session_connectivity", return_value=True),
|
||||
):
|
||||
mock_session1 = AsyncMock()
|
||||
mock_session2 = AsyncMock()
|
||||
mock_task1 = AsyncMock()
|
||||
mock_task2 = AsyncMock()
|
||||
mock_create.side_effect = [(mock_session1, mock_task1), (mock_session2, mock_task2)]
|
||||
|
||||
# First connection
|
||||
session1 = await session_manager.get_session(context_id, server1_params, "stdio")
|
||||
|
||||
# Switch to different server should create new session
|
||||
session2 = await session_manager.get_session(context_id, server2_params, "stdio")
|
||||
|
||||
assert session1 != session2
|
||||
assert mock_create.call_count == 2
|
||||
|
||||
|
||||
class TestHeaderValidation:
|
||||
"""Test the header validation functionality."""
|
||||
|
||||
def test_validate_headers_valid_input(self):
|
||||
"""Test header validation with valid headers."""
|
||||
headers = {"Authorization": "Bearer token123", "Content-Type": "application/json", "X-API-Key": "secret-key"}
|
||||
|
||||
result = validate_headers(headers)
|
||||
|
||||
# Headers should be normalized to lowercase
|
||||
expected = {"authorization": "Bearer token123", "content-type": "application/json", "x-api-key": "secret-key"}
|
||||
assert result == expected
|
||||
|
||||
def test_validate_headers_empty_input(self):
|
||||
"""Test header validation with empty/None input."""
|
||||
assert validate_headers({}) == {}
|
||||
assert validate_headers(None) == {}
|
||||
|
||||
def test_validate_headers_invalid_names(self):
|
||||
"""Test header validation with invalid header names."""
|
||||
headers = {
|
||||
"Invalid Header": "value", # spaces not allowed
|
||||
"Header@Name": "value", # @ not allowed
|
||||
"Header Name": "value", # spaces not allowed
|
||||
"Valid-Header": "value", # this should pass
|
||||
}
|
||||
|
||||
result = validate_headers(headers)
|
||||
|
||||
# Only the valid header should remain
|
||||
assert result == {"valid-header": "value"}
|
||||
|
||||
def test_validate_headers_sanitize_values(self):
|
||||
"""Test header value sanitization."""
|
||||
headers = {
|
||||
"Authorization": "Bearer \x00token\x1f with\r\ninjection",
|
||||
"Clean-Header": " clean value ",
|
||||
"Empty-After-Clean": "\x00\x01\x02",
|
||||
"Tab-Header": "value\twith\ttabs", # tabs should be preserved
|
||||
}
|
||||
|
||||
result = validate_headers(headers)
|
||||
|
||||
# Control characters should be removed, whitespace trimmed
|
||||
# Header with injection attempts should be skipped
|
||||
expected = {"clean-header": "clean value", "tab-header": "value\twith\ttabs"}
|
||||
assert result == expected
|
||||
|
||||
def test_validate_headers_non_string_values(self):
|
||||
"""Test header validation with non-string values."""
|
||||
headers = {"String-Header": "valid", "Number-Header": 123, "None-Header": None, "List-Header": ["value"]}
|
||||
|
||||
result = validate_headers(headers)
|
||||
|
||||
# Only string headers should remain
|
||||
assert result == {"string-header": "valid"}
|
||||
|
||||
def test_validate_headers_injection_attempts(self):
|
||||
"""Test header validation against injection attempts."""
|
||||
headers = {
|
||||
"Injection1": "value\r\nInjected-Header: malicious",
|
||||
"Injection2": "value\nX-Evil: attack",
|
||||
"Safe-Header": "safe-value",
|
||||
}
|
||||
|
||||
result = validate_headers(headers)
|
||||
|
||||
# Injection attempts should be filtered out
|
||||
assert result == {"safe-header": "safe-value"}
|
||||
|
||||
|
||||
class TestSSEHeaderIntegration:
|
||||
"""Integration test to verify headers are properly passed through the entire SSE flow."""
|
||||
|
||||
async def test_headers_processing(self):
|
||||
"""Test that headers flow properly from server config through to SSE client connection."""
|
||||
# Test the header processing function directly
|
||||
headers_input = [
|
||||
{"key": "Authorization", "value": "Bearer test-token"},
|
||||
{"key": "X-API-Key", "value": "secret-key"},
|
||||
]
|
||||
|
||||
expected_headers = {
|
||||
"authorization": "Bearer test-token", # normalized to lowercase
|
||||
"x-api-key": "secret-key",
|
||||
}
|
||||
|
||||
# Test _process_headers function with validation
|
||||
processed_headers = _process_headers(headers_input)
|
||||
assert processed_headers == expected_headers
|
||||
|
||||
# Test different input formats
|
||||
# Test dict input with validation
|
||||
dict_headers = {"Authorization": "Bearer dict-token", "Invalid Header": "bad"}
|
||||
result = _process_headers(dict_headers)
|
||||
# Invalid header should be filtered out, valid header normalized
|
||||
assert result == {"authorization": "Bearer dict-token"}
|
||||
|
||||
# Test None input
|
||||
assert _process_headers(None) == {}
|
||||
|
||||
# Test empty list
|
||||
assert _process_headers([]) == {}
|
||||
|
||||
# Test malformed list
|
||||
malformed_headers = [{"key": "Auth"}, {"value": "token"}] # Missing value/key
|
||||
assert _process_headers(malformed_headers) == {}
|
||||
|
||||
# Test list with invalid header names
|
||||
invalid_headers = [
|
||||
{"key": "Valid-Header", "value": "good"},
|
||||
{"key": "Invalid Header", "value": "bad"}, # spaces not allowed
|
||||
]
|
||||
result = _process_headers(invalid_headers)
|
||||
assert result == {"valid-header": "good"}
|
||||
|
||||
async def test_sse_client_header_storage(self):
|
||||
"""Test that SSE client properly stores headers in connection params."""
|
||||
sse_client = MCPSseClient()
|
||||
test_url = "http://test.url"
|
||||
test_headers = {"Authorization": "Bearer test123", "Custom": "value"}
|
||||
|
||||
# Test that headers are properly stored in connection params
|
||||
# Set connection params as a dict like the implementation expects
|
||||
sse_client._connection_params = {
|
||||
"url": test_url,
|
||||
"headers": test_headers,
|
||||
"timeout_seconds": 30,
|
||||
"sse_read_timeout_seconds": 30,
|
||||
}
|
||||
|
||||
# Verify headers are stored
|
||||
assert sse_client._connection_params["url"] == test_url
|
||||
assert sse_client._connection_params["headers"] == test_headers
|
||||
|
||||
|
||||
class TestMCPUtilityFunctions:
|
||||
"""Test utility functions from util.py that don't have dedicated test classes."""
|
||||
|
||||
def test_sanitize_mcp_name(self):
|
||||
"""Test MCP name sanitization."""
|
||||
assert util.sanitize_mcp_name("Test Name 123") == "test_name_123"
|
||||
assert util.sanitize_mcp_name(" ") == ""
|
||||
assert util.sanitize_mcp_name("123abc") == "_123abc"
|
||||
assert util.sanitize_mcp_name("Tést-😀-Námé") == "test_name"
|
||||
assert util.sanitize_mcp_name("a" * 100) == "a" * 46
|
||||
|
||||
def test_get_unique_name(self):
|
||||
"""Test unique name generation."""
|
||||
names = {"foo", "foo_1"}
|
||||
assert util.get_unique_name("foo", 10, names) == "foo_2"
|
||||
assert util.get_unique_name("bar", 10, names) == "bar"
|
||||
assert util.get_unique_name("longname", 4, {"long"}) == "lo_1"
|
||||
|
||||
def test_is_valid_key_value_item(self):
|
||||
"""Test key-value item validation."""
|
||||
assert util._is_valid_key_value_item({"key": "a", "value": "b"}) is True
|
||||
assert util._is_valid_key_value_item({"key": "a"}) is False
|
||||
assert util._is_valid_key_value_item(["key", "value"]) is False
|
||||
assert util._is_valid_key_value_item(None) is False
|
||||
|
||||
def test_validate_node_installation(self):
|
||||
"""Test Node.js installation validation."""
|
||||
if shutil.which("node"):
|
||||
assert util._validate_node_installation("npx something") == "npx something"
|
||||
else:
|
||||
with pytest.raises(ValueError, match="Node.js is not installed"):
|
||||
util._validate_node_installation("npx something")
|
||||
assert util._validate_node_installation("echo test") == "echo test"
|
||||
|
||||
def test_create_input_schema_from_json_schema(self):
|
||||
"""Test JSON schema to Pydantic model conversion."""
|
||||
schema = {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"foo": {"type": "string", "description": "desc"},
|
||||
"bar": {"type": "integer"},
|
||||
},
|
||||
"required": ["foo"],
|
||||
}
|
||||
model_class = util.create_input_schema_from_json_schema(schema)
|
||||
instance = model_class(foo="abc", bar=1)
|
||||
assert instance.foo == "abc"
|
||||
assert instance.bar == 1
|
||||
|
||||
with pytest.raises(Exception): # noqa: B017, PT011
|
||||
model_class(bar=1) # missing required field
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_validate_connection_params(self):
|
||||
"""Test connection parameter validation."""
|
||||
# Valid parameters
|
||||
await util._validate_connection_params("Stdio", command="echo test")
|
||||
await util._validate_connection_params("SSE", url="http://test")
|
||||
|
||||
# Invalid parameters
|
||||
with pytest.raises(ValueError, match="Command is required for Stdio mode"):
|
||||
await util._validate_connection_params("Stdio", command=None)
|
||||
with pytest.raises(ValueError, match="URL is required for SSE mode"):
|
||||
await util._validate_connection_params("SSE", url=None)
|
||||
with pytest.raises(ValueError, match="Invalid mode"):
|
||||
await util._validate_connection_params("InvalidMode")
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_flow_snake_case_mocked(self):
|
||||
"""Test flow lookup by snake case name with mocked session."""
|
||||
|
||||
class DummyFlow:
|
||||
def __init__(self, name: str, user_id: str, *, is_component: bool = False, action_name: str | None = None):
|
||||
self.name = name
|
||||
self.user_id = user_id
|
||||
self.is_component = is_component
|
||||
self.action_name = action_name
|
||||
|
||||
class DummyExec:
|
||||
def __init__(self, flows: list[DummyFlow]):
|
||||
self._flows = flows
|
||||
|
||||
def all(self):
|
||||
return self._flows
|
||||
|
||||
class DummySession:
|
||||
def __init__(self, flows: list[DummyFlow]):
|
||||
self._flows = flows
|
||||
|
||||
async def exec(self, stmt): # noqa: ARG002
|
||||
return DummyExec(self._flows)
|
||||
|
||||
user_id = "123e4567-e89b-12d3-a456-426614174000"
|
||||
flows = [DummyFlow("Test Flow", user_id), DummyFlow("Other", user_id)]
|
||||
|
||||
# Should match sanitized name
|
||||
result = await util.get_flow_snake_case(util.sanitize_mcp_name("Test Flow"), user_id, DummySession(flows))
|
||||
assert result is flows[0]
|
||||
|
||||
# Should return None if not found
|
||||
result = await util.get_flow_snake_case("notfound", user_id, DummySession(flows))
|
||||
assert result is None
|
||||
|
||||
|
||||
class TestMCPStdioClientWithEverythingServer:
|
||||
"""Test MCPStdioClient with the Everything MCP server."""
|
||||
|
||||
@pytest.fixture
|
||||
def stdio_client(self):
|
||||
"""Create a stdio client for testing."""
|
||||
return MCPStdioClient()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@pytest.mark.skipif(not shutil.which("npx"), reason="Node.js not available")
|
||||
async def test_connect_to_everything_server(self, stdio_client):
|
||||
"""Test connecting to the Everything MCP server."""
|
||||
command = "npx -y @modelcontextprotocol/server-everything"
|
||||
|
||||
try:
|
||||
# Connect to the server
|
||||
tools = await stdio_client.connect_to_server(command)
|
||||
|
||||
# Verify tools were returned
|
||||
assert len(tools) > 0
|
||||
|
||||
# Find the echo tool
|
||||
echo_tool = None
|
||||
for tool in tools:
|
||||
if hasattr(tool, "name") and tool.name == "echo":
|
||||
echo_tool = tool
|
||||
break
|
||||
|
||||
assert echo_tool is not None, "Echo tool not found in server tools"
|
||||
assert echo_tool.description is not None
|
||||
|
||||
# Verify the echo tool has the expected input schema
|
||||
assert hasattr(echo_tool, "inputSchema")
|
||||
assert echo_tool.inputSchema is not None
|
||||
|
||||
finally:
|
||||
# Clean up the connection
|
||||
await stdio_client.disconnect()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@pytest.mark.skipif(not shutil.which("npx"), reason="Node.js not available")
|
||||
async def test_run_echo_tool(self, stdio_client):
|
||||
"""Test running the echo tool from the Everything server."""
|
||||
command = "npx -y @modelcontextprotocol/server-everything"
|
||||
|
||||
try:
|
||||
# Connect to the server
|
||||
tools = await stdio_client.connect_to_server(command)
|
||||
|
||||
# Find the echo tool
|
||||
echo_tool = None
|
||||
for tool in tools:
|
||||
if hasattr(tool, "name") and tool.name == "echo":
|
||||
echo_tool = tool
|
||||
break
|
||||
|
||||
assert echo_tool is not None, "Echo tool not found"
|
||||
|
||||
# Run the echo tool
|
||||
test_message = "Hello, MCP!"
|
||||
result = await stdio_client.run_tool("echo", {"message": test_message})
|
||||
|
||||
# Verify the result
|
||||
assert result is not None
|
||||
assert hasattr(result, "content")
|
||||
assert len(result.content) > 0
|
||||
|
||||
# Check that the echo worked - content should contain our message
|
||||
content_text = str(result.content[0])
|
||||
assert test_message in content_text or "Echo:" in content_text
|
||||
|
||||
finally:
|
||||
await stdio_client.disconnect()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@pytest.mark.skipif(not shutil.which("npx"), reason="Node.js not available")
|
||||
async def test_list_all_tools(self, stdio_client):
|
||||
"""Test listing all available tools from the Everything server."""
|
||||
command = "npx -y @modelcontextprotocol/server-everything"
|
||||
|
||||
try:
|
||||
# Connect to the server
|
||||
tools = await stdio_client.connect_to_server(command)
|
||||
|
||||
# Verify we have multiple tools
|
||||
assert len(tools) >= 3 # Everything server typically has several tools
|
||||
|
||||
# Check that tools have the expected attributes
|
||||
for tool in tools:
|
||||
assert hasattr(tool, "name")
|
||||
assert hasattr(tool, "description")
|
||||
assert hasattr(tool, "inputSchema")
|
||||
assert tool.name is not None
|
||||
assert len(tool.name) > 0
|
||||
|
||||
# Common tools that should be available
|
||||
expected_tools = ["echo"] # Echo is typically available
|
||||
for expected_tool in expected_tools:
|
||||
assert any(tool.name == expected_tool for tool in tools), f"Expected tool '{expected_tool}' not found"
|
||||
|
||||
finally:
|
||||
await stdio_client.disconnect()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@pytest.mark.skipif(not shutil.which("npx"), reason="Node.js not available")
|
||||
async def test_session_reuse(self, stdio_client):
|
||||
"""Test that sessions are properly reused."""
|
||||
command = "npx -y @modelcontextprotocol/server-everything"
|
||||
|
||||
try:
|
||||
# Set session context
|
||||
stdio_client.set_session_context("test_session_reuse")
|
||||
|
||||
# Connect to the server
|
||||
tools1 = await stdio_client.connect_to_server(command)
|
||||
|
||||
# Connect again - should reuse the session
|
||||
tools2 = await stdio_client.connect_to_server(command)
|
||||
|
||||
# Should have the same tools
|
||||
assert len(tools1) == len(tools2)
|
||||
|
||||
# Run a tool to verify the session is working
|
||||
result = await stdio_client.run_tool("echo", {"message": "Session reuse test"})
|
||||
assert result is not None
|
||||
|
||||
finally:
|
||||
await stdio_client.disconnect()
|
||||
|
||||
|
||||
class TestMCPSseClientWithDeepWikiServer:
|
||||
"""Test MCPSseClient with the DeepWiki MCP server."""
|
||||
|
||||
@pytest.fixture
|
||||
def sse_client(self):
|
||||
"""Create an SSE client for testing."""
|
||||
return MCPSseClient()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_connect_to_deepwiki_server(self, sse_client):
|
||||
"""Test connecting to the DeepWiki MCP server."""
|
||||
url = "https://mcp.deepwiki.com/sse"
|
||||
|
||||
try:
|
||||
# Connect to the server
|
||||
tools = await sse_client.connect_to_server(url)
|
||||
|
||||
# Verify tools were returned
|
||||
assert len(tools) > 0
|
||||
|
||||
# Check for expected DeepWiki tools
|
||||
expected_tools = ["read_wiki_structure", "read_wiki_contents", "ask_question"]
|
||||
|
||||
# Verify we have the expected tools
|
||||
for expected_tool in expected_tools:
|
||||
assert any(tool.name == expected_tool for tool in tools), f"Expected tool '{expected_tool}' not found"
|
||||
|
||||
except Exception as e:
|
||||
# If the server is not accessible, skip the test
|
||||
pytest.skip(f"DeepWiki server not accessible: {e}")
|
||||
finally:
|
||||
await sse_client.disconnect()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_run_wiki_structure_tool(self, sse_client):
|
||||
"""Test running the read_wiki_structure tool."""
|
||||
url = "https://mcp.deepwiki.com/sse"
|
||||
|
||||
try:
|
||||
# Connect to the server
|
||||
tools = await sse_client.connect_to_server(url)
|
||||
|
||||
# Find the read_wiki_structure tool
|
||||
wiki_tool = None
|
||||
for tool in tools:
|
||||
if hasattr(tool, "name") and tool.name == "read_wiki_structure":
|
||||
wiki_tool = tool
|
||||
break
|
||||
|
||||
assert wiki_tool is not None, "read_wiki_structure tool not found"
|
||||
|
||||
# Run the tool with a test repository (use repoName as expected by the API)
|
||||
result = await sse_client.run_tool("read_wiki_structure", {"repoName": "microsoft/vscode"})
|
||||
|
||||
# Verify the result
|
||||
assert result is not None
|
||||
assert hasattr(result, "content")
|
||||
assert len(result.content) > 0
|
||||
|
||||
except Exception as e:
|
||||
# If the server is not accessible or the tool fails, skip the test
|
||||
pytest.skip(f"DeepWiki server test failed: {e}")
|
||||
finally:
|
||||
await sse_client.disconnect()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_ask_question_tool(self, sse_client):
|
||||
"""Test running the ask_question tool."""
|
||||
url = "https://mcp.deepwiki.com/sse"
|
||||
|
||||
try:
|
||||
# Connect to the server
|
||||
tools = await sse_client.connect_to_server(url)
|
||||
|
||||
# Find the ask_question tool
|
||||
ask_tool = None
|
||||
for tool in tools:
|
||||
if hasattr(tool, "name") and tool.name == "ask_question":
|
||||
ask_tool = tool
|
||||
break
|
||||
|
||||
assert ask_tool is not None, "ask_question tool not found"
|
||||
|
||||
# Run the tool with a test question (use repoName as expected by the API)
|
||||
result = await sse_client.run_tool(
|
||||
"ask_question", {"repoName": "microsoft/vscode", "question": "What is VS Code?"}
|
||||
)
|
||||
|
||||
# Verify the result
|
||||
assert result is not None
|
||||
assert hasattr(result, "content")
|
||||
assert len(result.content) > 0
|
||||
|
||||
except Exception as e:
|
||||
# If the server is not accessible or the tool fails, skip the test
|
||||
pytest.skip(f"DeepWiki server test failed: {e}")
|
||||
finally:
|
||||
await sse_client.disconnect()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_url_validation(self, sse_client):
|
||||
"""Test URL validation for SSE connections."""
|
||||
# Test valid URL
|
||||
valid_url = "https://mcp.deepwiki.com/sse"
|
||||
is_valid, error = await sse_client.validate_url(valid_url)
|
||||
assert is_valid or error == "" # Either valid or accessible
|
||||
|
||||
# Test invalid URL
|
||||
invalid_url = "not_a_url"
|
||||
is_valid, error = await sse_client.validate_url(invalid_url)
|
||||
assert not is_valid
|
||||
assert error != ""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_redirect_handling(self, sse_client):
|
||||
"""Test redirect handling for SSE connections."""
|
||||
# Test with the DeepWiki URL
|
||||
url = "https://mcp.deepwiki.com/sse"
|
||||
|
||||
try:
|
||||
# Check for redirects
|
||||
final_url = await sse_client.pre_check_redirect(url)
|
||||
|
||||
# Should return a URL (either original or redirected)
|
||||
assert final_url is not None
|
||||
assert isinstance(final_url, str)
|
||||
assert final_url.startswith("http")
|
||||
|
||||
except Exception as e:
|
||||
# If the server is not accessible, skip the test
|
||||
pytest.skip(f"DeepWiki server not accessible for redirect test: {e}")
|
||||
|
||||
@pytest.fixture
|
||||
def mock_tool(self):
|
||||
"""Create a mock MCP tool."""
|
||||
tool = MagicMock()
|
||||
tool.name = "test_tool"
|
||||
tool.description = "Test tool description"
|
||||
tool.inputSchema = {
|
||||
"type": "object",
|
||||
"properties": {"test_param": {"type": "string", "description": "Test parameter"}},
|
||||
"required": ["test_param"],
|
||||
}
|
||||
return tool
|
||||
|
||||
@pytest.fixture
|
||||
def mock_session(self, mock_tool):
|
||||
"""Create a mock ClientSession."""
|
||||
session = AsyncMock()
|
||||
session.initialize = AsyncMock()
|
||||
list_tools_result = MagicMock()
|
||||
list_tools_result.tools = [mock_tool]
|
||||
session.list_tools = AsyncMock(return_value=list_tools_result)
|
||||
session.call_tool = AsyncMock(
|
||||
return_value=MagicMock(content=[MagicMock(model_dump=lambda: {"result": "success"})])
|
||||
)
|
||||
return session
|
||||
|
||||
|
||||
class TestMCPSseClientUnit:
|
||||
"""Unit tests for MCPSseClient functionality."""
|
||||
|
||||
@pytest.fixture
|
||||
def sse_client(self):
|
||||
return MCPSseClient()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_client_initialization(self, sse_client):
|
||||
"""Test that SSE client initializes correctly."""
|
||||
# Client should initialize with default values
|
||||
assert sse_client.session is None
|
||||
assert sse_client._connection_params is None
|
||||
assert sse_client._connected is False
|
||||
assert sse_client._session_context is None
|
||||
|
||||
async def test_validate_url_valid(self, sse_client):
|
||||
"""Test URL validation with valid URL."""
|
||||
with patch("httpx.AsyncClient") as mock_client:
|
||||
mock_response = MagicMock()
|
||||
mock_response.status_code = 200
|
||||
mock_client.return_value.__aenter__.return_value.get.return_value = mock_response
|
||||
|
||||
is_valid, error_msg = await sse_client.validate_url("http://test.url", {})
|
||||
|
||||
assert is_valid is True
|
||||
assert error_msg == ""
|
||||
|
||||
async def test_validate_url_invalid_format(self, sse_client):
|
||||
"""Test URL validation with invalid format."""
|
||||
is_valid, error_msg = await sse_client.validate_url("invalid-url", {})
|
||||
|
||||
assert is_valid is False
|
||||
assert "Invalid URL format" in error_msg
|
||||
|
||||
async def test_validate_url_with_404_response(self, sse_client):
|
||||
"""Test URL validation with 404 response (should be valid for SSE)."""
|
||||
with patch("httpx.AsyncClient") as mock_client:
|
||||
mock_response = MagicMock()
|
||||
mock_response.status_code = 404
|
||||
mock_client.return_value.__aenter__.return_value.get.return_value = mock_response
|
||||
|
||||
is_valid, error_msg = await sse_client.validate_url("http://test.url", {})
|
||||
|
||||
assert is_valid is True
|
||||
assert error_msg == ""
|
||||
|
||||
async def test_connect_to_server_with_headers(self, sse_client):
|
||||
"""Test connecting to server via SSE with custom headers."""
|
||||
test_url = "http://test.url"
|
||||
test_headers = {"Authorization": "Bearer token123", "Custom-Header": "value"}
|
||||
expected_headers = {"authorization": "Bearer token123", "custom-header": "value"} # normalized
|
||||
|
||||
with (
|
||||
patch.object(sse_client, "validate_url", return_value=(True, "")),
|
||||
patch.object(sse_client, "pre_check_redirect", return_value=test_url),
|
||||
patch.object(sse_client, "_get_or_create_session") as mock_get_session,
|
||||
):
|
||||
# Mock session
|
||||
mock_session = AsyncMock()
|
||||
mock_tool = MagicMock()
|
||||
mock_tool.name = "test_tool"
|
||||
list_tools_result = MagicMock()
|
||||
list_tools_result.tools = [mock_tool]
|
||||
mock_session.list_tools = AsyncMock(return_value=list_tools_result)
|
||||
mock_get_session.return_value = mock_session
|
||||
|
||||
tools = await sse_client.connect_to_server(test_url, test_headers)
|
||||
|
||||
assert len(tools) == 1
|
||||
assert tools[0].name == "test_tool"
|
||||
assert sse_client._connected is True
|
||||
|
||||
# Verify headers are stored in connection params (normalized)
|
||||
assert sse_client._connection_params is not None
|
||||
assert sse_client._connection_params["headers"] == expected_headers
|
||||
assert sse_client._connection_params["url"] == test_url
|
||||
|
||||
async def test_headers_passed_to_session_manager(self, sse_client):
|
||||
"""Test that headers are properly passed to the session manager."""
|
||||
test_url = "http://test.url"
|
||||
expected_headers = {"authorization": "Bearer token123", "x-api-key": "secret"} # normalized
|
||||
|
||||
sse_client._session_context = "test_context"
|
||||
sse_client._connection_params = {
|
||||
"url": test_url,
|
||||
"headers": expected_headers, # Use normalized headers
|
||||
"timeout_seconds": 30,
|
||||
"sse_read_timeout_seconds": 30,
|
||||
}
|
||||
|
||||
with patch.object(sse_client, "_get_session_manager") as mock_get_manager:
|
||||
mock_manager = AsyncMock()
|
||||
mock_session = AsyncMock()
|
||||
mock_manager.get_session = AsyncMock(return_value=mock_session)
|
||||
mock_get_manager.return_value = mock_manager
|
||||
|
||||
result_session = await sse_client._get_or_create_session()
|
||||
|
||||
# Verify session manager was called with correct parameters including normalized headers
|
||||
mock_manager.get_session.assert_called_once_with("test_context", sse_client._connection_params, "sse")
|
||||
assert result_session == mock_session
|
||||
|
||||
async def test_pre_check_redirect_with_headers(self, sse_client):
|
||||
"""Test pre-check redirect functionality with custom headers."""
|
||||
test_url = "http://test.url"
|
||||
redirect_url = "http://redirect.url"
|
||||
# Use pre-validated headers since pre_check_redirect expects already validated headers
|
||||
test_headers = {"authorization": "Bearer token123"} # already normalized
|
||||
|
||||
with patch("httpx.AsyncClient") as mock_client:
|
||||
mock_response = MagicMock()
|
||||
mock_response.status_code = 307
|
||||
mock_response.headers.get.return_value = redirect_url
|
||||
mock_client.return_value.__aenter__.return_value.get.return_value = mock_response
|
||||
|
||||
result = await sse_client.pre_check_redirect(test_url, test_headers)
|
||||
|
||||
assert result == redirect_url
|
||||
# Verify validated headers were passed to the request
|
||||
mock_client.return_value.__aenter__.return_value.get.assert_called_with(
|
||||
test_url, timeout=2.0, headers={"Accept": "text/event-stream", **test_headers}
|
||||
)
|
||||
|
||||
async def test_run_tool_with_retry_on_connection_error(self, sse_client):
|
||||
"""Test that run_tool retries on connection errors."""
|
||||
# Setup connection state
|
||||
sse_client._connected = True
|
||||
sse_client._connection_params = {"url": "http://test.url", "headers": {}}
|
||||
sse_client._session_context = "test_context"
|
||||
|
||||
call_count = 0
|
||||
|
||||
async def mock_get_session_side_effect():
|
||||
nonlocal call_count
|
||||
call_count += 1
|
||||
session = AsyncMock()
|
||||
if call_count == 1:
|
||||
# First call fails with connection error
|
||||
from anyio import ClosedResourceError
|
||||
|
||||
session.call_tool = AsyncMock(side_effect=ClosedResourceError())
|
||||
else:
|
||||
# Second call succeeds
|
||||
mock_result = MagicMock()
|
||||
session.call_tool = AsyncMock(return_value=mock_result)
|
||||
return session
|
||||
|
||||
with (
|
||||
patch.object(sse_client, "_get_or_create_session", side_effect=mock_get_session_side_effect),
|
||||
patch.object(sse_client, "_get_session_manager") as mock_get_manager,
|
||||
):
|
||||
mock_manager = AsyncMock()
|
||||
mock_get_manager.return_value = mock_manager
|
||||
|
||||
result = await sse_client.run_tool("test_tool", {"param": "value"})
|
||||
|
||||
# Should have retried and succeeded on second attempt
|
||||
assert call_count == 2
|
||||
assert result is not None
|
||||
# Should have cleaned up the failed session
|
||||
mock_manager._cleanup_session.assert_called_once_with("test_context")
|
||||
|
|
@ -1,8 +1,15 @@
|
|||
"""Unit tests for MCP component with actual MCP servers.
|
||||
|
||||
This test suite validates the MCP component functionality using real MCP servers:
|
||||
- Everything server (stdio mode) - provides echo and other tools
|
||||
- DeepWiki server (SSE mode) - provides wiki-related tools
|
||||
"""
|
||||
|
||||
import shutil
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
|
||||
import pytest
|
||||
from langflow.base.mcp import util
|
||||
from langflow.base.mcp.util import MCPSessionManager, MCPSseClient, MCPStdioClient, _process_headers, validate_headers
|
||||
from langflow.base.mcp.util import MCPSessionManager, MCPSseClient, MCPStdioClient
|
||||
from langflow.components.agents.mcp_component import MCPToolsComponent
|
||||
|
||||
from tests.base import ComponentTestBaseWithoutClient, VersionComponentMapping
|
||||
|
|
@ -18,8 +25,11 @@ class TestMCPToolsComponent(ComponentTestBaseWithoutClient):
|
|||
def default_kwargs(self):
|
||||
"""Return the default kwargs for the component."""
|
||||
return {
|
||||
"mode": "Stdio",
|
||||
"command": "npx -y @modelcontextprotocol/server-everything",
|
||||
"sse_url": "https://mcp.deepwiki.com/sse",
|
||||
"tool": "echo",
|
||||
"mcp_server": {"name": "test_server", "config": {"command": "uvx mcp-server-fetch"}},
|
||||
"tool": "",
|
||||
}
|
||||
|
||||
@pytest.fixture
|
||||
|
|
@ -27,34 +37,106 @@ class TestMCPToolsComponent(ComponentTestBaseWithoutClient):
|
|||
"""Return the file names mapping for different versions."""
|
||||
return []
|
||||
|
||||
@pytest.fixture
|
||||
def mock_tool(self):
|
||||
"""Create a mock MCP tool."""
|
||||
tool = MagicMock()
|
||||
tool.name = "test_tool"
|
||||
tool.description = "Test tool description"
|
||||
tool.inputSchema = {
|
||||
"type": "object",
|
||||
"properties": {"test_param": {"type": "string", "description": "Test parameter"}},
|
||||
"required": ["test_param"],
|
||||
}
|
||||
return tool
|
||||
@pytest.mark.asyncio
|
||||
@pytest.mark.skipif(not shutil.which("npx"), reason="Node.js not available")
|
||||
async def test_component_initialization(self, component_class, default_kwargs):
|
||||
"""Test that the component initializes correctly."""
|
||||
component = component_class(**default_kwargs)
|
||||
|
||||
# Check that the component has the expected attributes
|
||||
assert hasattr(component, "stdio_client")
|
||||
assert hasattr(component, "sse_client")
|
||||
assert isinstance(component.stdio_client, MCPStdioClient)
|
||||
assert isinstance(component.sse_client, MCPSseClient)
|
||||
|
||||
# Check that the component has a session manager
|
||||
session_manager = component.stdio_client._get_session_manager()
|
||||
assert isinstance(session_manager, MCPSessionManager)
|
||||
|
||||
|
||||
class TestMCPToolsComponentIntegration:
|
||||
"""Integration tests for the MCPToolsComponent."""
|
||||
|
||||
@pytest.fixture
|
||||
def mock_session(self, mock_tool):
|
||||
"""Create a mock ClientSession."""
|
||||
session = AsyncMock()
|
||||
session.initialize = AsyncMock()
|
||||
list_tools_result = MagicMock()
|
||||
list_tools_result.tools = [mock_tool]
|
||||
session.list_tools = AsyncMock(return_value=list_tools_result)
|
||||
session.call_tool = AsyncMock(
|
||||
return_value=MagicMock(content=[MagicMock(model_dump=lambda: {"result": "success"})])
|
||||
)
|
||||
return session
|
||||
def component(self):
|
||||
"""Create a component for testing."""
|
||||
return MCPToolsComponent()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@pytest.mark.skipif(not shutil.which("npx"), reason="Node.js not available")
|
||||
async def test_stdio_mode_integration(self, component):
|
||||
"""Test the component in stdio mode with Everything server."""
|
||||
# Configure for stdio mode
|
||||
component.mode = "Stdio"
|
||||
component.command = "npx -y @modelcontextprotocol/server-everything"
|
||||
component.tool = "echo"
|
||||
|
||||
try:
|
||||
# Mock the update_tool_list method to simulate server connection
|
||||
tools, server_info = await component.update_tool_list()
|
||||
|
||||
# Should have tools
|
||||
assert len(tools) > 0
|
||||
|
||||
# Should have server info
|
||||
assert server_info is not None
|
||||
assert isinstance(server_info, dict)
|
||||
|
||||
except Exception as e:
|
||||
# If the server is not accessible, skip the test
|
||||
pytest.skip(f"Everything server not accessible: {e}")
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_sse_mode_integration(self, component):
|
||||
"""Test the component in SSE mode with DeepWiki server."""
|
||||
# Configure for SSE mode
|
||||
component.mode = "SSE"
|
||||
component.sse_url = "https://mcp.deepwiki.com/sse"
|
||||
|
||||
try:
|
||||
# Mock the update_tool_list method to simulate server connection
|
||||
tools, server_info = await component.update_tool_list()
|
||||
|
||||
# Should have tools
|
||||
assert len(tools) > 0
|
||||
|
||||
# Should have server info
|
||||
assert server_info is not None
|
||||
assert isinstance(server_info, dict)
|
||||
|
||||
except Exception as e:
|
||||
# If the server is not accessible, skip the test
|
||||
pytest.skip(f"DeepWiki server not accessible: {e}")
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_session_context_setting(self, component):
|
||||
"""Test that session context is properly set."""
|
||||
# Set session context
|
||||
component.stdio_client.set_session_context("test_context")
|
||||
component.sse_client.set_session_context("test_context")
|
||||
|
||||
# Verify context was set
|
||||
assert component.stdio_client._session_context == "test_context"
|
||||
assert component.sse_client._session_context == "test_context"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_session_manager_sharing(self, component):
|
||||
"""Test that session managers are shared through component cache."""
|
||||
# Get session managers
|
||||
stdio_manager = component.stdio_client._get_session_manager()
|
||||
sse_manager = component.sse_client._get_session_manager()
|
||||
|
||||
# Both should be MCPSessionManager instances
|
||||
assert isinstance(stdio_manager, MCPSessionManager)
|
||||
assert isinstance(sse_manager, MCPSessionManager)
|
||||
|
||||
# They should be the same instance (shared through cache)
|
||||
assert stdio_manager is sse_manager
|
||||
|
||||
|
||||
class TestMCPStdioClient:
|
||||
class TestMCPComponentErrorHandling:
|
||||
"""Test error handling in MCP components."""
|
||||
|
||||
@pytest.fixture
|
||||
def stdio_client(self):
|
||||
return MCPStdioClient()
|
||||
|
|
@ -122,494 +204,3 @@ class TestMCPStdioClient:
|
|||
mock_manager._cleanup_session.assert_called_once_with("test_context")
|
||||
assert stdio_client.session is None
|
||||
assert stdio_client._connected is False
|
||||
|
||||
|
||||
class TestMCPSseClient:
|
||||
@pytest.fixture
|
||||
def sse_client(self):
|
||||
return MCPSseClient()
|
||||
|
||||
async def test_validate_url_valid(self, sse_client):
|
||||
"""Test URL validation with valid URL."""
|
||||
with patch("httpx.AsyncClient") as mock_client:
|
||||
mock_response = MagicMock()
|
||||
mock_response.status_code = 200
|
||||
mock_client.return_value.__aenter__.return_value.get.return_value = mock_response
|
||||
|
||||
is_valid, error_msg = await sse_client.validate_url("http://test.url", {})
|
||||
|
||||
assert is_valid is True
|
||||
assert error_msg == ""
|
||||
|
||||
async def test_validate_url_invalid_format(self, sse_client):
|
||||
"""Test URL validation with invalid format."""
|
||||
is_valid, error_msg = await sse_client.validate_url("invalid-url", {})
|
||||
|
||||
assert is_valid is False
|
||||
assert "Invalid URL format" in error_msg
|
||||
|
||||
async def test_validate_url_with_404_response(self, sse_client):
|
||||
"""Test URL validation with 404 response (should be valid for SSE)."""
|
||||
with patch("httpx.AsyncClient") as mock_client:
|
||||
mock_response = MagicMock()
|
||||
mock_response.status_code = 404
|
||||
mock_client.return_value.__aenter__.return_value.get.return_value = mock_response
|
||||
|
||||
is_valid, error_msg = await sse_client.validate_url("http://test.url", {})
|
||||
|
||||
assert is_valid is True
|
||||
assert error_msg == ""
|
||||
|
||||
async def test_connect_to_server_with_headers(self, sse_client):
|
||||
"""Test connecting to server via SSE with custom headers."""
|
||||
test_url = "http://test.url"
|
||||
test_headers = {"Authorization": "Bearer token123", "Custom-Header": "value"}
|
||||
expected_headers = {"authorization": "Bearer token123", "custom-header": "value"} # normalized
|
||||
|
||||
with (
|
||||
patch.object(sse_client, "validate_url", return_value=(True, "")),
|
||||
patch.object(sse_client, "pre_check_redirect", return_value=test_url),
|
||||
patch.object(sse_client, "_get_or_create_session") as mock_get_session,
|
||||
):
|
||||
# Mock session
|
||||
mock_session = AsyncMock()
|
||||
mock_tool = MagicMock()
|
||||
mock_tool.name = "test_tool"
|
||||
list_tools_result = MagicMock()
|
||||
list_tools_result.tools = [mock_tool]
|
||||
mock_session.list_tools = AsyncMock(return_value=list_tools_result)
|
||||
mock_get_session.return_value = mock_session
|
||||
|
||||
tools = await sse_client.connect_to_server(test_url, test_headers)
|
||||
|
||||
assert len(tools) == 1
|
||||
assert tools[0].name == "test_tool"
|
||||
assert sse_client._connected is True
|
||||
|
||||
# Verify headers are stored in connection params (normalized)
|
||||
assert sse_client._connection_params is not None
|
||||
assert sse_client._connection_params["headers"] == expected_headers
|
||||
assert sse_client._connection_params["url"] == test_url
|
||||
|
||||
async def test_headers_passed_to_session_manager(self, sse_client):
|
||||
"""Test that headers are properly passed to the session manager."""
|
||||
test_url = "http://test.url"
|
||||
expected_headers = {"authorization": "Bearer token123", "x-api-key": "secret"} # normalized
|
||||
|
||||
sse_client._session_context = "test_context"
|
||||
sse_client._connection_params = {
|
||||
"url": test_url,
|
||||
"headers": expected_headers, # Use normalized headers
|
||||
"timeout_seconds": 30,
|
||||
"sse_read_timeout_seconds": 30,
|
||||
}
|
||||
|
||||
with patch.object(sse_client, "_get_session_manager") as mock_get_manager:
|
||||
mock_manager = AsyncMock()
|
||||
mock_session = AsyncMock()
|
||||
mock_manager.get_session = AsyncMock(return_value=mock_session)
|
||||
mock_get_manager.return_value = mock_manager
|
||||
|
||||
result_session = await sse_client._get_or_create_session()
|
||||
|
||||
# Verify session manager was called with correct parameters including normalized headers
|
||||
mock_manager.get_session.assert_called_once_with("test_context", sse_client._connection_params, "sse")
|
||||
assert result_session == mock_session
|
||||
|
||||
async def test_pre_check_redirect_with_headers(self, sse_client):
|
||||
"""Test pre-check redirect functionality with custom headers."""
|
||||
test_url = "http://test.url"
|
||||
redirect_url = "http://redirect.url"
|
||||
# Use pre-validated headers since pre_check_redirect expects already validated headers
|
||||
test_headers = {"authorization": "Bearer token123"} # already normalized
|
||||
|
||||
with patch("httpx.AsyncClient") as mock_client:
|
||||
mock_response = MagicMock()
|
||||
mock_response.status_code = 307
|
||||
mock_response.headers.get.return_value = redirect_url
|
||||
mock_client.return_value.__aenter__.return_value.get.return_value = mock_response
|
||||
|
||||
result = await sse_client.pre_check_redirect(test_url, test_headers)
|
||||
|
||||
assert result == redirect_url
|
||||
# Verify validated headers were passed to the request
|
||||
mock_client.return_value.__aenter__.return_value.get.assert_called_with(
|
||||
test_url, timeout=2.0, headers={"Accept": "text/event-stream", **test_headers}
|
||||
)
|
||||
|
||||
async def test_run_tool_with_retry_on_connection_error(self, sse_client):
|
||||
"""Test that run_tool retries on connection errors."""
|
||||
# Setup connection state
|
||||
sse_client._connected = True
|
||||
sse_client._connection_params = {"url": "http://test.url", "headers": {}}
|
||||
sse_client._session_context = "test_context"
|
||||
|
||||
call_count = 0
|
||||
|
||||
async def mock_get_session_side_effect():
|
||||
nonlocal call_count
|
||||
call_count += 1
|
||||
session = AsyncMock()
|
||||
if call_count == 1:
|
||||
# First call fails with connection error
|
||||
from anyio import ClosedResourceError
|
||||
|
||||
session.call_tool = AsyncMock(side_effect=ClosedResourceError())
|
||||
else:
|
||||
# Second call succeeds
|
||||
mock_result = MagicMock()
|
||||
session.call_tool = AsyncMock(return_value=mock_result)
|
||||
return session
|
||||
|
||||
with (
|
||||
patch.object(sse_client, "_get_or_create_session", side_effect=mock_get_session_side_effect),
|
||||
patch.object(sse_client, "_get_session_manager") as mock_get_manager,
|
||||
):
|
||||
mock_manager = AsyncMock()
|
||||
mock_get_manager.return_value = mock_manager
|
||||
|
||||
result = await sse_client.run_tool("test_tool", {"param": "value"})
|
||||
|
||||
# Should have retried and succeeded on second attempt
|
||||
assert call_count == 2
|
||||
assert result is not None
|
||||
# Should have cleaned up the failed session
|
||||
mock_manager._cleanup_session.assert_called_once_with("test_context")
|
||||
|
||||
|
||||
class TestMCPSessionManager:
|
||||
@pytest.fixture
|
||||
def session_manager(self):
|
||||
return MCPSessionManager()
|
||||
|
||||
async def test_session_caching(self, session_manager):
|
||||
"""Test that sessions are properly cached and reused."""
|
||||
context_id = "test_context"
|
||||
connection_params = MagicMock()
|
||||
transport_type = "stdio"
|
||||
|
||||
# Create a mock session that will appear healthy
|
||||
mock_session = AsyncMock()
|
||||
mock_session._write_stream = MagicMock()
|
||||
mock_session._write_stream._closed = False
|
||||
|
||||
# Create a mock task that appears to be running
|
||||
mock_task = AsyncMock()
|
||||
mock_task.done = MagicMock(return_value=False)
|
||||
|
||||
with (
|
||||
patch.object(session_manager, "_create_stdio_session") as mock_create,
|
||||
patch.object(session_manager, "_validate_session_connectivity", return_value=True),
|
||||
):
|
||||
mock_create.return_value = mock_session
|
||||
|
||||
# First call should create session
|
||||
session1 = await session_manager.get_session(context_id, connection_params, transport_type)
|
||||
|
||||
# Manually populate the sessions cache as if the session was created properly
|
||||
session_manager.sessions[context_id] = {"session": mock_session, "task": mock_task, "type": transport_type}
|
||||
|
||||
# Second call should return cached session without creating new one
|
||||
session2 = await session_manager.get_session(context_id, connection_params, transport_type)
|
||||
|
||||
assert session1 == session2
|
||||
assert session1 == mock_session
|
||||
# Should only create once since the second call should use the cached session
|
||||
mock_create.assert_called_once()
|
||||
|
||||
async def test_session_cleanup(self, session_manager):
|
||||
"""Test session cleanup functionality."""
|
||||
context_id = "test_context"
|
||||
|
||||
# Add a session to the manager with proper mock setup
|
||||
mock_task = AsyncMock()
|
||||
mock_task.done = MagicMock(return_value=False) # Use MagicMock for sync method
|
||||
mock_task.cancel = MagicMock() # Use MagicMock for sync method
|
||||
|
||||
session_manager.sessions[context_id] = {"session": AsyncMock(), "task": mock_task, "type": "stdio"}
|
||||
|
||||
await session_manager._cleanup_session(context_id)
|
||||
|
||||
# Should cancel the task and remove from sessions
|
||||
mock_task.cancel.assert_called_once()
|
||||
assert context_id not in session_manager.sessions
|
||||
|
||||
async def test_server_switch_detection(self, session_manager):
|
||||
"""Test that server switches are properly detected and handled."""
|
||||
context_id = "test_context"
|
||||
|
||||
# First server
|
||||
server1_params = MagicMock()
|
||||
server1_params.command = "server1"
|
||||
|
||||
# Second server
|
||||
server2_params = MagicMock()
|
||||
server2_params.command = "server2"
|
||||
|
||||
with (
|
||||
patch.object(session_manager, "_create_stdio_session") as mock_create,
|
||||
patch.object(session_manager, "_validate_session_connectivity", return_value=True),
|
||||
):
|
||||
mock_session1 = AsyncMock()
|
||||
mock_session2 = AsyncMock()
|
||||
mock_create.side_effect = [mock_session1, mock_session2]
|
||||
|
||||
# First connection
|
||||
session1 = await session_manager.get_session(context_id, server1_params, "stdio")
|
||||
|
||||
# Switch to different server should create new session
|
||||
session2 = await session_manager.get_session(context_id, server2_params, "stdio")
|
||||
|
||||
assert session1 != session2
|
||||
assert mock_create.call_count == 2
|
||||
|
||||
|
||||
# Integration test for header functionality
|
||||
class TestHeaderValidation:
|
||||
"""Test the header validation functionality."""
|
||||
|
||||
def test_validate_headers_valid_input(self):
|
||||
"""Test header validation with valid headers."""
|
||||
headers = {"Authorization": "Bearer token123", "Content-Type": "application/json", "X-API-Key": "secret-key"}
|
||||
|
||||
result = validate_headers(headers)
|
||||
|
||||
# Headers should be normalized to lowercase
|
||||
expected = {"authorization": "Bearer token123", "content-type": "application/json", "x-api-key": "secret-key"}
|
||||
assert result == expected
|
||||
|
||||
def test_validate_headers_empty_input(self):
|
||||
"""Test header validation with empty/None input."""
|
||||
assert validate_headers({}) == {}
|
||||
assert validate_headers(None) == {}
|
||||
|
||||
def test_validate_headers_invalid_names(self):
|
||||
"""Test header validation with invalid header names."""
|
||||
headers = {
|
||||
"Invalid Header": "value", # spaces not allowed
|
||||
"Header@Name": "value", # @ not allowed
|
||||
"Header Name": "value", # spaces not allowed
|
||||
"Valid-Header": "value", # this should pass
|
||||
}
|
||||
|
||||
result = validate_headers(headers)
|
||||
|
||||
# Only the valid header should remain
|
||||
assert result == {"valid-header": "value"}
|
||||
|
||||
def test_validate_headers_sanitize_values(self):
|
||||
"""Test header value sanitization."""
|
||||
headers = {
|
||||
"Authorization": "Bearer \x00token\x1f with\r\ninjection",
|
||||
"Clean-Header": " clean value ",
|
||||
"Empty-After-Clean": "\x00\x01\x02",
|
||||
"Tab-Header": "value\twith\ttabs", # tabs should be preserved
|
||||
}
|
||||
|
||||
result = validate_headers(headers)
|
||||
|
||||
# Control characters should be removed, whitespace trimmed
|
||||
# Header with injection attempts should be skipped
|
||||
expected = {"clean-header": "clean value", "tab-header": "value\twith\ttabs"}
|
||||
assert result == expected
|
||||
|
||||
def test_validate_headers_non_string_values(self):
|
||||
"""Test header validation with non-string values."""
|
||||
headers = {"String-Header": "valid", "Number-Header": 123, "None-Header": None, "List-Header": ["value"]}
|
||||
|
||||
result = validate_headers(headers)
|
||||
|
||||
# Only string headers should remain
|
||||
assert result == {"string-header": "valid"}
|
||||
|
||||
def test_validate_headers_injection_attempts(self):
|
||||
"""Test header validation against injection attempts."""
|
||||
headers = {
|
||||
"Injection1": "value\r\nInjected-Header: malicious",
|
||||
"Injection2": "value\nX-Evil: attack",
|
||||
"Safe-Header": "safe-value",
|
||||
}
|
||||
|
||||
result = validate_headers(headers)
|
||||
|
||||
# Injection attempts should be filtered out
|
||||
assert result == {"safe-header": "safe-value"}
|
||||
|
||||
|
||||
class TestSSEHeaderIntegration:
|
||||
"""Integration test to verify headers are properly passed through the entire SSE flow."""
|
||||
|
||||
async def test_headers_processing(self):
|
||||
"""Test that headers flow properly from server config through to SSE client connection."""
|
||||
# Test the header processing function directly
|
||||
headers_input = [
|
||||
{"key": "Authorization", "value": "Bearer test-token"},
|
||||
{"key": "X-API-Key", "value": "secret-key"},
|
||||
]
|
||||
|
||||
expected_headers = {
|
||||
"authorization": "Bearer test-token", # normalized to lowercase
|
||||
"x-api-key": "secret-key",
|
||||
}
|
||||
|
||||
# Test _process_headers function with validation
|
||||
processed_headers = _process_headers(headers_input)
|
||||
assert processed_headers == expected_headers
|
||||
|
||||
# Test different input formats
|
||||
# Test dict input with validation
|
||||
dict_headers = {"Authorization": "Bearer dict-token", "Invalid Header": "bad"}
|
||||
result = _process_headers(dict_headers)
|
||||
# Invalid header should be filtered out, valid header normalized
|
||||
assert result == {"authorization": "Bearer dict-token"}
|
||||
|
||||
# Test None input
|
||||
assert _process_headers(None) == {}
|
||||
|
||||
# Test empty list
|
||||
assert _process_headers([]) == {}
|
||||
|
||||
# Test malformed list
|
||||
malformed_headers = [{"key": "Auth"}, {"value": "token"}] # Missing value/key
|
||||
assert _process_headers(malformed_headers) == {}
|
||||
|
||||
# Test list with invalid header names
|
||||
invalid_headers = [
|
||||
{"key": "Valid-Header", "value": "good"},
|
||||
{"key": "Invalid Header", "value": "bad"}, # spaces not allowed
|
||||
]
|
||||
result = _process_headers(invalid_headers)
|
||||
assert result == {"valid-header": "good"}
|
||||
|
||||
async def test_sse_client_header_storage(self):
|
||||
"""Test that SSE client properly stores headers in connection params."""
|
||||
sse_client = MCPSseClient()
|
||||
test_url = "http://test.url"
|
||||
test_headers = {"Authorization": "Bearer test123", "Custom": "value"}
|
||||
expected_headers = {"authorization": "Bearer test123", "custom": "value"} # normalized
|
||||
|
||||
with (
|
||||
patch.object(sse_client, "validate_url", return_value=(True, "")),
|
||||
patch.object(sse_client, "pre_check_redirect", return_value=test_url),
|
||||
patch.object(sse_client, "_get_or_create_session") as mock_get_session,
|
||||
):
|
||||
mock_session = AsyncMock()
|
||||
mock_tool = MagicMock()
|
||||
mock_tool.name = "test_tool"
|
||||
list_tools_result = MagicMock()
|
||||
list_tools_result.tools = [mock_tool]
|
||||
mock_session.list_tools = AsyncMock(return_value=list_tools_result)
|
||||
mock_get_session.return_value = mock_session
|
||||
|
||||
await sse_client.connect_to_server(test_url, test_headers)
|
||||
|
||||
# Verify headers are stored correctly in connection params (normalized)
|
||||
assert sse_client._connection_params is not None
|
||||
assert sse_client._connection_params["headers"] == expected_headers
|
||||
assert sse_client._connection_params["url"] == test_url
|
||||
|
||||
|
||||
class TestMCPUtilityFunctions:
|
||||
"""Test utility functions from util.py that don't have dedicated test classes."""
|
||||
|
||||
def test_sanitize_mcp_name(self):
|
||||
"""Test MCP name sanitization."""
|
||||
assert util.sanitize_mcp_name("Test Name 123") == "test_name_123"
|
||||
assert util.sanitize_mcp_name(" ") == ""
|
||||
assert util.sanitize_mcp_name("123abc") == "_123abc"
|
||||
assert util.sanitize_mcp_name("Tést-😀-Námé") == "test_name"
|
||||
assert util.sanitize_mcp_name("a" * 100) == "a" * 46
|
||||
|
||||
def test_get_unique_name(self):
|
||||
"""Test unique name generation."""
|
||||
names = {"foo", "foo_1"}
|
||||
assert util.get_unique_name("foo", 10, names) == "foo_2"
|
||||
assert util.get_unique_name("bar", 10, names) == "bar"
|
||||
assert util.get_unique_name("longname", 4, {"long"}) == "lo_1"
|
||||
|
||||
def test_is_valid_key_value_item(self):
|
||||
"""Test key-value item validation."""
|
||||
assert util._is_valid_key_value_item({"key": "a", "value": "b"}) is True
|
||||
assert util._is_valid_key_value_item({"key": "a"}) is False
|
||||
assert util._is_valid_key_value_item(["key", "value"]) is False
|
||||
assert util._is_valid_key_value_item(None) is False
|
||||
|
||||
def test_validate_node_installation(self):
|
||||
"""Test Node.js installation validation."""
|
||||
import shutil
|
||||
|
||||
if shutil.which("node"):
|
||||
assert util._validate_node_installation("npx something") == "npx something"
|
||||
else:
|
||||
with pytest.raises(ValueError, match="Node.js is not installed"):
|
||||
util._validate_node_installation("npx something")
|
||||
assert util._validate_node_installation("echo test") == "echo test"
|
||||
|
||||
def test_create_input_schema_from_json_schema(self):
|
||||
"""Test JSON schema to Pydantic model conversion."""
|
||||
schema = {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"foo": {"type": "string", "description": "desc"},
|
||||
"bar": {"type": "integer"},
|
||||
},
|
||||
"required": ["foo"],
|
||||
}
|
||||
model_class = util.create_input_schema_from_json_schema(schema)
|
||||
instance = model_class(foo="abc", bar=1)
|
||||
assert instance.foo == "abc"
|
||||
assert instance.bar == 1
|
||||
|
||||
with pytest.raises(Exception): # noqa: B017, PT011
|
||||
model_class(bar=1) # missing required field
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_validate_connection_params(self):
|
||||
"""Test connection parameter validation."""
|
||||
# Valid parameters
|
||||
await util._validate_connection_params("Stdio", command="echo test")
|
||||
await util._validate_connection_params("SSE", url="http://test")
|
||||
|
||||
# Invalid parameters
|
||||
with pytest.raises(ValueError, match="Command is required for Stdio mode"):
|
||||
await util._validate_connection_params("Stdio", command=None)
|
||||
with pytest.raises(ValueError, match="URL is required for SSE mode"):
|
||||
await util._validate_connection_params("SSE", url=None)
|
||||
with pytest.raises(ValueError, match="Invalid mode"):
|
||||
await util._validate_connection_params("InvalidMode")
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_flow_snake_case_mocked(self):
|
||||
"""Test flow lookup by snake case name with mocked session."""
|
||||
|
||||
class DummyFlow:
|
||||
def __init__(self, name: str, user_id: str, *, is_component: bool = False, action_name: str | None = None):
|
||||
self.name = name
|
||||
self.user_id = user_id
|
||||
self.is_component = is_component
|
||||
self.action_name = action_name
|
||||
|
||||
class DummyExec:
|
||||
def __init__(self, flows: list[DummyFlow]):
|
||||
self._flows = flows
|
||||
|
||||
def all(self):
|
||||
return self._flows
|
||||
|
||||
class DummySession:
|
||||
def __init__(self, flows: list[DummyFlow]):
|
||||
self._flows = flows
|
||||
|
||||
async def exec(self, stmt): # noqa: ARG002
|
||||
return DummyExec(self._flows)
|
||||
|
||||
user_id = "123e4567-e89b-12d3-a456-426614174000"
|
||||
flows = [DummyFlow("Test Flow", user_id), DummyFlow("Other", user_id)]
|
||||
|
||||
# Should match sanitized name
|
||||
result = await util.get_flow_snake_case(util.sanitize_mcp_name("Test Flow"), user_id, DummySession(flows))
|
||||
assert result is flows[0]
|
||||
|
||||
# Should return None if not found
|
||||
result = await util.get_flow_snake_case("notfound", user_id, DummySession(flows))
|
||||
assert result is None
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue