refactor(session): migrate to server-based session management and add tests (#9077)

* update MCP Tests

* [autofix.ci] apply automated fixes

* Update util.py

* [autofix.ci] apply automated fixes

* Refactor MCP session manager for better configurability and cleanup (#9176)

* Add log rotation and header validation features

Introduces support for log rotation via the LANGFLOW_LOG_ROTATION environment variable and CLI/config options, with documentation updates. Adds header validation and sanitization for MCP connections, ensuring RFC 7230 compliance and security. Frontend and backend now support passing custom headers for MCP servers. Includes extensive new and updated unit tests for header handling, MCP utilities, and log rotation.

* Add unit tests for MCP utilities and update disconnect logic

Added comprehensive unit tests for MCP utility functions, session management, header validation, and client classes in test_mcp_util.py. Updated MCPStdioClient and MCPSseClient disconnect methods for clearer session cleanup logic. Refactored test_mcp_component.py to remove redundant and duplicated tests, consolidating coverage in the new test suite.

* [autofix.ci] apply automated fixes

* Update test_mcp_memory_leak.py

* Update util.py

---------

Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com>
Co-authored-by: Gabriel Luiz Freitas Almeida <gabriel@langflow.org>
This commit is contained in:
Edwin Jose 2025-08-09 05:30:39 -04:00 committed by GitHub
commit b093c1fadb
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
6 changed files with 1586 additions and 664 deletions

View file

@ -1,4 +1,5 @@
import asyncio
import contextlib
import os
import platform
import re
@ -30,6 +31,13 @@ HTTP_NOT_FOUND = 404
HTTP_BAD_REQUEST = 400
HTTP_INTERNAL_SERVER_ERROR = 500
# MCP Session Manager constants.
# These are read once from the settings service at module import time.
settings = get_settings_service().settings
MAX_SESSIONS_PER_SERVER = (
settings.mcp_max_sessions_per_server
) # Maximum number of sessions per server to prevent resource exhaustion
SESSION_IDLE_TIMEOUT = settings.mcp_session_idle_timeout # Idle time in seconds after which _cleanup_idle_sessions reaps a session
SESSION_CLEANUP_INTERVAL = settings.mcp_session_cleanup_interval # Seconds _periodic_cleanup sleeps between cleanup passes
# RFC 7230 compliant header name pattern: token = 1*tchar
# tchar = "!" / "#" / "$" / "%" / "&" / "'" / "*" / "+" / "-" / "." /
# "^" / "_" / "`" / "|" / "~" / DIGIT / ALPHA
@ -460,12 +468,90 @@ async def _validate_connection_params(mode: str, command: str | None = None, url
class MCPSessionManager:
"""Manages persistent MCP sessions with proper context manager lifecycle."""
"""Manages persistent MCP sessions with proper context manager lifecycle.
Fixed version that addresses the memory leak issue by:
1. Session reuse based on server identity rather than unique context IDs
2. Maximum session limits per server to prevent resource exhaustion
3. Idle timeout for automatic session cleanup
4. Periodic cleanup of stale sessions
"""
def __init__(self):
"""Create empty session bookkeeping and start the periodic cleanup task.

NOTE: the last statement calls ``_start_cleanup_task``, which schedules a
task with ``asyncio.create_task``; that call needs a running event loop.
"""
# NOTE(review): none of the session-manager methods visible here read
# ``self.sessions`` any more — presumably legacy; confirm before removing.
self.sessions = {} # context_id -> session_info
# Structure: server_key -> {"sessions": {session_id: session_info}, "last_cleanup": timestamp}
self.sessions_by_server = {}
self._background_tasks = set() # Keep references to background tasks
self._last_server_by_session = {} # context_id -> server_name for tracking switches
# Backwards-compatibility maps: which context_id uses which (server_key, session_id)
self._context_to_session: dict[str, tuple[str, str]] = {}
# Reference count for each active (server_key, session_id)
self._session_refcount: dict[tuple[str, str], int] = {}
# Handle of the periodic idle-session cleanup task (None until started).
self._cleanup_task = None
self._start_cleanup_task()
def _start_cleanup_task(self):
"""Start the periodic cleanup task."""
if self._cleanup_task is None or self._cleanup_task.done():
self._cleanup_task = asyncio.create_task(self._periodic_cleanup())
self._background_tasks.add(self._cleanup_task)
self._cleanup_task.add_done_callback(self._background_tasks.discard)
async def _periodic_cleanup(self):
    """Run the idle-session sweep forever, once per ``SESSION_CLEANUP_INTERVAL``.

    The loop ends when the task is cancelled. The listed recoverable errors
    are logged as warnings and the loop carries on with the next interval.
    """
    while True:
        try:
            await asyncio.sleep(SESSION_CLEANUP_INTERVAL)
            await self._cleanup_idle_sessions()
        except asyncio.CancelledError:
            # Cancellation is the normal shutdown path for this task.
            return
        except (RuntimeError, KeyError, ClosedResourceError, ValueError, asyncio.TimeoutError) as err:
            # Keep sweeping: one failed pass must not stop future passes.
            logger.warning(f"Error in periodic cleanup: {err}")
async def _cleanup_idle_sessions(self):
    """Tear down every session idle for longer than ``SESSION_IDLE_TIMEOUT``.

    For each server entry, sessions whose ``last_used`` timestamp is older
    than the timeout are cleaned up via ``_cleanup_session_by_id``. Their
    reference counts and context mappings are dropped as well, and server
    entries left without sessions are removed.
    """
    current_time = asyncio.get_running_loop().time()

    # Iterate over a snapshot: the awaits below yield to other coroutines,
    # which may add or remove server entries while this sweep is running.
    for server_key, server_data in list(self.sessions_by_server.items()):
        sessions = server_data.get("sessions", {})
        idle_session_ids = [
            session_id
            for session_id, session_info in sessions.items()
            if current_time - session_info["last_used"] > SESSION_IDLE_TIMEOUT
        ]

        for session_id in idle_session_ids:
            logger.info(f"Cleaning up idle session {session_id} for server {server_key}")
            await self._cleanup_session_by_id(server_key, session_id)

            # Forget bookkeeping that still points at the reaped session so
            # the compatibility maps do not grow without bound.
            ref_key = (server_key, session_id)
            self._session_refcount.pop(ref_key, None)
            stale_contexts = [ctx for ctx, target in self._context_to_session.items() if target == ref_key]
            for ctx in stale_contexts:
                self._context_to_session.pop(ctx, None)

        # Remove the server entry only if it is empty and still the entry we
        # inspected (it may have been replaced while we were awaiting).
        if not sessions and self.sessions_by_server.get(server_key) is server_data:
            del self.sessions_by_server[server_key]
def _get_server_key(self, connection_params, transport_type: str) -> str:
"""Generate a consistent server key based on connection parameters."""
if transport_type == "stdio":
if hasattr(connection_params, "command"):
# Include command, args, and environment for uniqueness
command_str = f"{connection_params.command} {' '.join(connection_params.args or [])}"
env_str = str(sorted((connection_params.env or {}).items()))
key_input = f"{command_str}|{env_str}"
return f"stdio_{hash(key_input)}"
elif transport_type == "sse" and (isinstance(connection_params, dict) and "url" in connection_params):
# Include URL and headers for uniqueness
url = connection_params["url"]
headers = str(sorted((connection_params.get("headers", {})).items()))
key_input = f"{url}|{headers}"
return f"sse_{hash(key_input)}"
# Fallback to a generic key
# TODO: add option for streamable HTTP in future.
return f"{transport_type}_{hash(str(connection_params))}"
async def _validate_session_connectivity(self, session) -> bool:
"""Validate that the session is actually usable by testing a simple operation."""
@ -483,6 +569,7 @@ class MCPSessionManager:
"ClosedResourceError" in str(type(e))
or "Connection closed" in error_str
or "Connection lost" in error_str
or "Connection failed" in error_str
or "Transport closed" in error_str
or "Stream closed" in error_str
):
@ -510,117 +597,83 @@ class MCPSessionManager:
return True
async def get_session(self, context_id: str, connection_params, transport_type: str):
    """Get or create a session, reusing one per server identity where possible.

    Sessions are keyed by the server identity computed by ``_get_server_key``
    rather than by ``context_id``, so several contexts talking to the same
    server share one session instead of each spawning a new one.

    Args:
        context_id: Identifier of the caller; recorded so that
            ``_cleanup_session(context_id)`` can release the session later.
        connection_params: Transport-specific connection parameters.
        transport_type: ``"stdio"`` or ``"sse"``.

    Returns:
        A healthy session for the requested server.

    Raises:
        ValueError: If ``transport_type`` is unknown (session creation may
            also raise ``ValueError`` when initialisation times out).
    """
    import uuid

    # Reject unsupported transports before touching any bookkeeping.
    if transport_type not in ("stdio", "sse"):
        msg = f"Unknown transport type: {transport_type}"
        raise ValueError(msg)

    # Idempotent; (re)starts the periodic cleanup if it is not running.
    self._start_cleanup_task()

    loop = asyncio.get_running_loop()
    server_key = self._get_server_key(connection_params, transport_type)

    # Ensure server entry exists
    server_data = self.sessions_by_server.setdefault(server_key, {"sessions": {}, "last_cleanup": loop.time()})
    sessions = server_data["sessions"]

    # Try to find a healthy existing session. Iterate over a snapshot because
    # _cleanup_session_by_id deletes from ``sessions``; mutating the dict
    # while iterating it raises "dictionary changed size during iteration".
    for session_id, session_info in list(sessions.items()):
        if session_id not in sessions:
            # Removed by another coroutine while we were awaiting.
            continue

        session = session_info["session"]
        task = session_info["task"]
        ref_key = (server_key, session_id)

        if task.done():
            logger.info(f"Session {session_id} for server {server_key} task is done, cleaning up")
            await self._cleanup_session_by_id(server_key, session_id)
            self._session_refcount.pop(ref_key, None)
            continue

        # Update last used time
        session_info["last_used"] = loop.time()

        # Quick health check
        if await self._validate_session_connectivity(session):
            logger.debug(f"Reusing existing session {session_id} for server {server_key}")
            # Count each context once: repeated calls from the same context
            # must not inflate the ref-count, otherwise a single
            # _cleanup_session() call could never release the session.
            if self._context_to_session.get(context_id) != ref_key:
                self._context_to_session[context_id] = ref_key
                self._session_refcount[ref_key] = self._session_refcount.get(ref_key, 0) + 1
            return session

        logger.info(f"Session {session_id} for server {server_key} failed health check, cleaning up")
        await self._cleanup_session_by_id(server_key, session_id)
        self._session_refcount.pop(ref_key, None)

    # Check if we've reached the maximum number of sessions for this server
    if sessions and len(sessions) >= MAX_SESSIONS_PER_SERVER:
        # Remove the oldest session
        oldest_session_id = min(sessions, key=lambda sid: sessions[sid]["last_used"])
        logger.info(
            f"Maximum sessions reached for server {server_key}, removing oldest session {oldest_session_id}"
        )
        await self._cleanup_session_by_id(server_key, oldest_session_id)
        self._session_refcount.pop((server_key, oldest_session_id), None)

    # Create new session. A random suffix keeps ids unique even after
    # evictions; a len(sessions)-based id could collide with a live session
    # and silently overwrite it without ever closing it.
    session_id = f"{server_key}_{uuid.uuid4().hex[:8]}"
    logger.info(f"Creating new session {session_id} for server {server_key}")

    if transport_type == "stdio":
        session, task = await self._create_stdio_session(session_id, connection_params)
    else:
        session, task = await self._create_sse_session(session_id, connection_params)

    # The idle cleanup may have dropped the (then empty) server entry while
    # the session was being created; re-attach so the session stays tracked.
    sessions = self.sessions_by_server.setdefault(server_key, server_data)["sessions"]

    # Store session info
    sessions[session_id] = {
        "session": session,
        "task": task,
        "type": transport_type,
        "last_used": loop.time(),
    }

    # register mapping & initial ref-count for the new session
    self._context_to_session[context_id] = (server_key, session_id)
    self._session_refcount[(server_key, session_id)] = 1

    return session
async def _create_stdio_session(self, session_id: str, connection_params):
"""Create a new stdio session as a background task to avoid context issues."""
import asyncio
@ -646,9 +699,7 @@ class MCPSessionManager:
try:
await event.wait()
except asyncio.CancelledError:
# Session is being shut down
msg = "Message is shutting down"
logger.info(msg)
logger.info(f"Session {session_id} is shutting down")
except Exception as e: # noqa: BLE001
if not session_future.done():
session_future.set_exception(e)
@ -660,7 +711,7 @@ class MCPSessionManager:
# Wait for session to be ready
try:
session = await asyncio.wait_for(session_future, timeout=10.0) # 10 second timeout for session creation
session = await asyncio.wait_for(session_future, timeout=10.0)
except asyncio.TimeoutError as timeout_err:
# Clean up the failed task
if not task.done():
@ -670,15 +721,13 @@ class MCPSessionManager:
with contextlib.suppress(asyncio.CancelledError):
await task
self._background_tasks.discard(task)
msg = f"Timeout waiting for STDIO session to initialize for context {context_id}"
msg = f"Timeout waiting for STDIO session {session_id} to initialize"
logger.error(msg)
raise ValueError(msg) from timeout_err
else:
# Store session info
self.sessions[context_id] = {"session": session, "task": task, "type": "stdio"}
return session
async def _create_sse_session(self, context_id: str, connection_params):
return session, task
async def _create_sse_session(self, session_id: str, connection_params):
"""Create a new SSE session as a background task to avoid context issues."""
import asyncio
@ -709,9 +758,7 @@ class MCPSessionManager:
try:
await event.wait()
except asyncio.CancelledError:
# Session is being shut down
msg = "Message is shutting down"
logger.info(msg)
logger.info(f"Session {session_id} is shutting down")
except Exception as e: # noqa: BLE001
if not session_future.done():
session_future.set_exception(e)
@ -723,7 +770,7 @@ class MCPSessionManager:
# Wait for session to be ready
try:
session = await asyncio.wait_for(session_future, timeout=10.0) # 10 second timeout for session creation
session = await asyncio.wait_for(session_future, timeout=10.0)
except asyncio.TimeoutError as timeout_err:
# Clean up the failed task
if not task.done():
@ -733,20 +780,29 @@ class MCPSessionManager:
with contextlib.suppress(asyncio.CancelledError):
await task
self._background_tasks.discard(task)
msg = f"Timeout waiting for SSE session to initialize for context {context_id}"
msg = f"Timeout waiting for SSE session {session_id} to initialize"
logger.error(msg)
raise ValueError(msg) from timeout_err
else:
# Store session info
self.sessions[context_id] = {"session": session, "task": task, "type": "sse"}
return session
async def _cleanup_session(self, context_id: str):
"""Clean up a session by cancelling its background task."""
if context_id not in self.sessions:
return session, task
async def _cleanup_session_by_id(self, server_key: str, session_id: str):
"""Clean up a specific session by server key and session ID."""
if server_key not in self.sessions_by_server:
return
session_info = self.sessions[context_id]
server_data = self.sessions_by_server[server_key]
# Handle both old and new session structure
if isinstance(server_data, dict) and "sessions" in server_data:
sessions = server_data["sessions"]
else:
# Handle old structure where sessions were stored directly
sessions = server_data
if session_id not in sessions:
return
session_info = sessions[session_id]
try:
# Cancel the background task which will properly close the session
if "task" in session_info:
@ -756,19 +812,72 @@ class MCPSessionManager:
try:
await task
except asyncio.CancelledError:
logger.info(f"Issue cancelling task for context_id {context_id}")
logger.info(f"Cancelled task for session {session_id}")
except Exception as e: # noqa: BLE001
logger.info(f"issue cleaning up mcp session: {e}")
logger.warning(f"Error cleaning up session {session_id}: {e}")
finally:
del self.sessions[context_id]
# Also clean up server tracking
if context_id in self._last_server_by_session:
del self._last_server_by_session[context_id]
# Remove from sessions dict
del sessions[session_id]
async def cleanup_all(self):
    """Stop the periodic cleanup task and tear down every tracked session.

    All sessions in ``sessions_by_server`` are cleaned up through
    ``_cleanup_session_by_id``, the bookkeeping maps are cleared, and any
    background task that is still running is cancelled and awaited.
    """
    # Cancel periodic cleanup task first so it cannot race with the teardown.
    if self._cleanup_task and not self._cleanup_task.done():
        self._cleanup_task.cancel()
        with contextlib.suppress(asyncio.CancelledError):
            await self._cleanup_task

    # Clean up all sessions
    for server_key in list(self.sessions_by_server.keys()):
        server_data = self.sessions_by_server.get(server_key)
        if server_data is None:
            # Entry disappeared while a previous cleanup was awaiting.
            continue

        # Handle both old and new session structure
        if isinstance(server_data, dict) and "sessions" in server_data:
            sessions = server_data["sessions"]
        else:
            # Handle old structure where sessions were stored directly
            sessions = server_data

        for session_id in list(sessions.keys()):
            await self._cleanup_session_by_id(server_key, session_id)

    # Clear the sessions_by_server structure completely
    self.sessions_by_server.clear()

    # Clear compatibility maps
    self._context_to_session.clear()
    self._session_refcount.clear()

    # Cancel whatever background work is left. gather(return_exceptions=True)
    # collects CancelledError and any other shutdown error, so one
    # misbehaving task cannot abort the cleanup of the remaining ones.
    pending = [task for task in self._background_tasks if not task.done()]
    for task in pending:
        task.cancel()
    if pending:
        await asyncio.gather(*pending, return_exceptions=True)
async def _cleanup_session(self, context_id: str):
"""Backward-compat cleanup by context_id.
Decrements the ref-count for the session used by *context_id* and only
tears the session down when the last context that references it goes
away.
"""
mapping = self._context_to_session.get(context_id)
if not mapping:
logger.debug(f"No session mapping found for context_id {context_id}")
return
server_key, session_id = mapping
ref_key = (server_key, session_id)
remaining = self._session_refcount.get(ref_key, 1) - 1
if remaining <= 0:
await self._cleanup_session_by_id(server_key, session_id)
self._session_refcount.pop(ref_key, None)
else:
self._session_refcount[ref_key] = remaining
# Remove the mapping for this context
self._context_to_session.pop(context_id, None)
class MCPStdioClient:
@ -963,11 +1072,15 @@ class MCPStdioClient:
async def disconnect(self):
"""Properly close the connection and clean up resources."""
# Clean up session using session manager
# For stdio transport, there is no remote session to terminate explicitly
# The session cleanup happens when the background task is cancelled
# Clean up local session using the session manager
if self._session_context:
session_manager = self._get_session_manager()
await session_manager._cleanup_session(self._session_context)
# Reset local state
self.session = None
self._connection_params = None
self._connected = False
@ -1127,19 +1240,34 @@ class MCPSseClient:
# Use cached session manager to get/create persistent session
session_manager = self._get_session_manager()
return await session_manager.get_session(self._session_context, self._connection_params, "sse")
# Cache session so we can access server-assigned session_id later for DELETE
self.session = await session_manager.get_session(self._session_context, self._connection_params, "sse")
return self.session
async def disconnect(self):
"""Properly close the connection and clean up resources."""
# Clean up session using session manager
if self._session_context:
session_manager = self._get_session_manager()
await session_manager._cleanup_session(self._session_context)
async def _terminate_remote_session(self) -> None:
"""Attempt to explicitly terminate the remote MCP session via HTTP DELETE (best-effort)."""
# Only relevant for SSE transport
if not self._connection_params or "url" not in self._connection_params:
return
self.session = None
self._connection_params = None
self._connected = False
self._session_context = None
url: str = self._connection_params["url"]
# Retrieve session id from the underlying SDK if exposed
session_id = None
if getattr(self, "session", None) is not None:
# Common attributes in MCP python SDK: `session_id` or `id`
session_id = getattr(self.session, "session_id", None) or getattr(self.session, "id", None)
headers: dict[str, str] = dict(self._connection_params.get("headers", {}))
if session_id:
headers["Mcp-Session-Id"] = str(session_id)
try:
async with httpx.AsyncClient(timeout=5.0) as client:
await client.delete(url, headers=headers)
except Exception as e: # noqa: BLE001
# DELETE is advisory—log and continue
logger.debug(f"Unable to send session DELETE to '{url}': {e}")
async def run_tool(self, tool_name: str, arguments: dict[str, Any]) -> Any:
"""Run a tool with the given arguments using context-specific session.
@ -1253,6 +1381,22 @@ class MCPSseClient:
logger.error(msg)
raise ValueError(msg)
async def disconnect(self):
    """Close the connection: notify the server, release the session, reset state.

    The remote termination is attempted first (best-effort). The session
    manager is then asked to release the session bound to this client's
    context, if any, and the client's connection attributes are reset.
    """
    # Attempt best-effort remote session termination first
    await self._terminate_remote_session()

    # Release the locally tracked session, if this client ever obtained one.
    context_id = self._session_context
    if context_id:
        await self._get_session_manager()._cleanup_session(context_id)

    # Reset local state
    self.session = None
    self._connection_params = None
    self._connected = False
    self._session_context = None
async def __aenter__(self):
"""Enter the async context manager; returns this client object unchanged."""
return self

View file

@ -96,6 +96,22 @@ class Settings(BaseSettings):
"""The number of seconds to wait before giving up on a lock to released or establishing a connection to the
database."""
# ---------------------------------------------------------------------
# MCP Session-manager tuning
# ---------------------------------------------------------------------
mcp_max_sessions_per_server: int = 10
"""Maximum number of MCP sessions to keep per unique server (command/url).
Mirrors the default constant MAX_SESSIONS_PER_SERVER in util.py. Adjust to
control resource usage or concurrency per server."""
# NOTE(review): the default below is 400 seconds (about 6.7 minutes), while
# the docstring says "5 minutes" (300 seconds) — confirm which is intended.
mcp_session_idle_timeout: int = 400 # seconds
"""How long (in seconds) an MCP session can stay idle before the background
cleanup task disposes of it. Defaults to 5 minutes."""
mcp_session_cleanup_interval: int = 120 # seconds
"""Frequency (in seconds) at which the background cleanup task wakes up to
reap idle sessions."""
# sqlite configuration
sqlite_pragmas: dict | None = {"synchronous": "NORMAL", "journal_mode": "WAL"}
"""SQLite pragmas to use when connecting to the database."""

View file

@ -0,0 +1,365 @@
"""Integration tests for MCP memory leak fix.
These tests verify that the MCP session manager properly handles session reuse
and cleanup to prevent subprocess leaks.
"""
import asyncio
import contextlib
import os
import platform
import shutil
import psutil
import pytest
from langflow.base.mcp.util import MCPSessionManager
from loguru import logger
from mcp import StdioServerParameters
@pytest.fixture
def mcp_server_params():
    """Build stdio parameters that launch the reference "everything" MCP server.

    The npx command is wrapped in the platform shell: ``cmd /c`` on Windows,
    ``bash -c exec`` everywhere else.
    """
    npx_command = " ".join(("npx", "-y", "@modelcontextprotocol/server-everything"))
    environment = {"DEBUG": "true", "PATH": os.environ["PATH"]}

    if platform.system() == "Windows":
        shell, shell_args = "cmd", ["/c", npx_command]
    else:
        shell, shell_args = "bash", ["-c", f"exec {npx_command}"]

    return StdioServerParameters(command=shell, args=shell_args, env=environment)
@pytest.fixture
def process_tracker():
    """Yield ``(process, initial_child_count)`` for subprocess-leak checks.

    After the test finishes, any child process still alive is terminated
    (and killed if it does not exit within three seconds) so one test cannot
    leak subprocesses into the next.
    """
    process = psutil.Process()
    initial_count = len(process.children(recursive=True))

    yield process, initial_count

    # Cleanup any remaining child processes
    try:
        for child in process.children(recursive=True):
            try:
                child.terminate()
                child.wait(timeout=3)
            except (psutil.NoSuchProcess, psutil.TimeoutExpired):
                with contextlib.suppress(psutil.NoSuchProcess):
                    child.kill()
    except Exception as e:  # noqa: BLE001
        # loguru formats with "{}" placeholders, not "%s"; an f-string keeps
        # the error text in the message instead of a literal "%s".
        logger.exception(f"Error cleaning up child processes: {e}")
@pytest.mark.asyncio
@pytest.mark.skipif(not shutil.which("npx"), reason="Node.js not available")
async def test_session_reuse_prevents_subprocess_leak(mcp_server_params, process_tracker):
    """Requesting sessions from several contexts must not multiply subprocesses."""
    process, initial_count = process_tracker
    session_manager = MCPSessionManager()

    async def assert_lists_tools(candidate):
        # A working session answers list_tools with at least one tool.
        listing = await candidate.list_tools()
        assert len(listing.tools) > 0

    try:
        # Different context IDs, same server parameters.
        sessions = []
        for index in range(3):
            session = await session_manager.get_session(f"test_context_{index}", mcp_server_params, "stdio")
            sessions.append(session)
            await assert_lists_tools(session)

        # With session reuse the number of extra subprocesses stays small
        # (ideally 2 subprocesses max for the MCP server).
        subprocess_increase = len(process.children(recursive=True)) - initial_count
        assert subprocess_increase <= 4, f"Too many subprocesses created: {subprocess_increase}"

        # Every handle that was returned must still be usable.
        for session in sessions:
            await assert_lists_tools(session)
    finally:
        await session_manager.cleanup_all()
        await asyncio.sleep(2)  # Allow cleanup to complete
@pytest.mark.asyncio
@pytest.mark.skipif(not shutil.which("npx"), reason="Node.js not available")
async def test_session_cleanup_removes_subprocesses(mcp_server_params, process_tracker):
"""Test that session cleanup properly removes subprocesses.

Creates one stdio session, checks that the child-process count grew, then
calls ``cleanup_all`` and checks the count is back to (at most one above)
the initial value.
"""
process, initial_count = process_tracker
session_manager = MCPSessionManager()
try:
# Create a session
session = await session_manager.get_session("cleanup_test", mcp_server_params, "stdio")
tools_response = await session.list_tools()
assert len(tools_response.tools) > 0
# Verify subprocess was created
after_creation_count = len(process.children(recursive=True))
assert after_creation_count > initial_count
finally:
# Clean up session
await session_manager.cleanup_all()
await asyncio.sleep(2) # Allow cleanup to complete
# Verify subprocess was cleaned up
after_cleanup_count = len(process.children(recursive=True))
# Allow some tolerance for cleanup timing and system processes
# (one extra child process is accepted by the bound below)
assert after_cleanup_count <= initial_count + 1, (
f"Subprocesses not cleaned up properly: {after_cleanup_count} vs {initial_count}"
)
@pytest.mark.asyncio
@pytest.mark.skipif(not shutil.which("npx"), reason="Node.js not available")
async def test_session_health_check_and_recovery(mcp_server_params, process_tracker):
    """Test that unhealthy sessions are properly detected and recreated.

    The first session's background task is cancelled to simulate a dead
    session; the next ``get_session`` call must hand out a fresh, working one.
    """
    process, initial_count = process_tracker
    session_manager = MCPSessionManager()

    try:
        # Create a session
        session1 = await session_manager.get_session("health_test", mcp_server_params, "stdio")
        tools_response = await session1.list_tools()
        assert len(tools_response.tools) > 0

        # Simulate session becoming unhealthy by accessing internal state
        # This is a bit of a hack but necessary for testing
        server_key = session_manager._get_server_key(mcp_server_params, "stdio")
        if hasattr(session_manager, "sessions_by_server"):
            # For the fixed version: the server entry is
            # {"sessions": {session_id: session_info}, "last_cleanup": ...},
            # so the session infos live one level down under "sessions".
            server_data = session_manager.sessions_by_server.get(server_key, {})
            session_infos = list(server_data.get("sessions", {}).values())
        else:
            # For the original version
            session_infos = list(session_manager.sessions.values())

        tasks = [info["task"] for info in session_infos if "task" in info]
        # Guard against the test silently cancelling nothing.
        assert tasks, "Expected at least one tracked session task to cancel"
        for task in tasks:
            if not task.done():
                task.cancel()
                with contextlib.suppress(asyncio.CancelledError):
                    await task

        # Wait a bit for the task to be cancelled
        await asyncio.sleep(1)

        # Try to get a session again - should create a new healthy one
        session2 = await session_manager.get_session("health_test_2", mcp_server_params, "stdio")
        assert session2 is not session1, "A dead session must not be reused"
        tools_response = await session2.list_tools()
        assert len(tools_response.tools) > 0
    finally:
        await session_manager.cleanup_all()
        await asyncio.sleep(2)
@pytest.mark.asyncio
@pytest.mark.skipif(not shutil.which("npx"), reason="Node.js not available")
async def test_multiple_servers_isolation(process_tracker):
    """Servers that differ only in their environment must get separate sessions."""
    process, initial_count = process_tracker
    session_manager = MCPSessionManager()

    def build_params(debug_flag: str) -> StdioServerParameters:
        # Identical command line; only the DEBUG environment value varies.
        return StdioServerParameters(
            command="bash",
            args=["-c", "exec npx -y @modelcontextprotocol/server-everything"],
            env={"DEBUG": debug_flag, "PATH": os.environ["PATH"]},
        )

    server1_params = build_params("true")
    server2_params = build_params("false")  # Different env

    try:
        session1 = await session_manager.get_session("server1_test", server1_params, "stdio")
        session2 = await session_manager.get_session("server2_test", server2_params, "stdio")

        # Both sessions must be functional.
        tools1 = await session1.list_tools()
        tools2 = await session2.list_tools()
        assert len(tools1.tools) > 0
        assert len(tools2.tools) > 0

        # Different environments yield different server keys, and therefore
        # different session objects.
        server_key1 = session_manager._get_server_key(server1_params, "stdio")
        server_key2 = session_manager._get_server_key(server2_params, "stdio")
        assert server_key1 != server_key2, "Different server environments should generate different keys"
        assert session1 is not session2
    finally:
        await session_manager.cleanup_all()
        await asyncio.sleep(2)
@pytest.mark.asyncio
async def test_session_manager_server_key_generation():
    """Test that server key generation works correctly.

    Equal parameters must produce equal keys with the transport prefix, and
    a different command or URL must produce a different key.
    """
    session_manager = MCPSessionManager()

    try:
        # Test stdio server key
        stdio_params = StdioServerParameters(
            command="test_command",
            args=["arg1", "arg2"],
            env={"TEST": "value"},
        )

        key1 = session_manager._get_server_key(stdio_params, "stdio")
        key2 = session_manager._get_server_key(stdio_params, "stdio")

        # Same parameters should generate same key
        assert key1 == key2
        assert key1.startswith("stdio_")

        # Different parameters should generate different keys
        stdio_params2 = StdioServerParameters(
            command="different_command",
            args=["arg1", "arg2"],
            env={"TEST": "value"},
        )
        key3 = session_manager._get_server_key(stdio_params2, "stdio")
        assert key1 != key3

        # Test SSE server key
        sse_params = {
            "url": "http://example.com/sse",
            "headers": {"Authorization": "Bearer token"},
            "timeout_seconds": 30,
            "sse_read_timeout_seconds": 30,
        }

        sse_key1 = session_manager._get_server_key(sse_params, "sse")
        sse_key2 = session_manager._get_server_key(sse_params, "sse")

        assert sse_key1 == sse_key2
        assert sse_key1.startswith("sse_")

        # Different URL should generate different key
        sse_params2 = sse_params.copy()
        sse_params2["url"] = "http://different.com/sse"
        sse_key3 = session_manager._get_server_key(sse_params2, "sse")
        assert sse_key1 != sse_key3
    finally:
        # Constructing the manager starts its periodic cleanup task; stop it
        # so the test does not leave a pending task behind on the event loop.
        await session_manager.cleanup_all()
@pytest.mark.asyncio
async def test_session_manager_connectivity_validation():
    """Check ``_validate_session_connectivity`` against three stub sessions.

    A session whose ``list_tools`` returns an object with tools must validate as
    healthy; one that raises, or one that returns ``None``, must not.
    """
    manager = MCPSessionManager()

    class _ToolListing:
        """Minimal stand-in for a ``list_tools`` response."""

        def __init__(self):
            self.tools = ["tool1", "tool2"]

    class _StubSession:
        """Session stub that either returns a listing or raises on ``list_tools``."""

        def __init__(self, should_fail=False):  # noqa: FBT002
            self.should_fail = should_fail

        async def list_tools(self):
            if not self.should_fail:
                return _ToolListing()
            msg = "Connection failed"
            raise Exception(msg)  # noqa: TRY002

    class _NoneSession:
        """Session stub whose ``list_tools`` yields nothing."""

        async def list_tools(self):
            return None

    # Responding session -> healthy.
    verdict = await manager._validate_session_connectivity(_StubSession(should_fail=False))
    assert verdict is True

    # Raising session -> unhealthy.
    verdict = await manager._validate_session_connectivity(_StubSession(should_fail=True))
    assert verdict is False

    # Session returning None -> unhealthy.
    verdict = await manager._validate_session_connectivity(_NoneSession())
    assert verdict is False
@pytest.mark.asyncio
async def test_session_manager_cleanup_all():
    """Test that cleanup_all properly cleans up all sessions.

    Two fake session entries (one stdio, one sse) and two background tasks are
    injected into a fresh manager. After ``cleanup_all`` the per-server session
    registry must be empty and both background tasks must have finished.
    """
    session_manager = MCPSessionManager()
    # Inside a coroutine the running loop is the one to query for the clock.
    now = asyncio.get_running_loop().time()

    # Keep references to every sleeper so none can outlive the test.
    session_task1 = asyncio.create_task(asyncio.sleep(10))
    session_task2 = asyncio.create_task(asyncio.sleep(10))
    task1 = asyncio.create_task(asyncio.sleep(10))
    task2 = asyncio.create_task(asyncio.sleep(10))

    try:
        # Mock some sessions using the per-server structure
        session_manager.sessions_by_server = {
            "server1": {
                "sessions": {
                    "session1": {
                        "session": "mock_session",
                        "task": session_task1,
                        "type": "stdio",
                        "last_used": now,
                    }
                }
            },
            "server2": {
                "sessions": {
                    "session2": {
                        "session": "mock_session",
                        "task": session_task2,
                        "type": "sse",
                        "last_used": now,
                    }
                }
            },
        }
        # Add some background tasks
        session_manager._background_tasks = {task1, task2}

        await session_manager.cleanup_all()

        # The attribute was assigned above, so it is checked directly; the former
        # fallback to a legacy ``sessions`` attribute could never be reached.
        assert len(session_manager.sessions_by_server) == 0
        # Verify background tasks were cancelled
        assert task1.done()
        assert task2.done()
    finally:
        # If cleanup_all raised or an assertion failed, do not leave 10-second
        # sleepers pending on the event loop.
        leftovers = [t for t in (session_task1, session_task2, task1, task2) if not t.done()]
        for leftover in leftovers:
            leftover.cancel()
        if leftovers:
            await asyncio.gather(*leftovers, return_exceptions=True)

View file

@ -0,0 +1,806 @@
"""Unit tests for MCP utility functions.
This test suite validates the MCP utility functions including:
- Session management
- Header validation and processing
- Utility functions for name sanitization and schema conversion
"""
import shutil
from unittest.mock import AsyncMock, MagicMock, patch
import pytest
from langflow.base.mcp import util
from langflow.base.mcp.util import MCPSessionManager, MCPSseClient, MCPStdioClient, _process_headers, validate_headers
class TestMCPSessionManager:
    """Unit tests for ``MCPSessionManager`` caching, cleanup and server switching."""

    @pytest.fixture
    async def session_manager(self):
        """Yield a fresh manager and run ``cleanup_all`` on it once the test is over."""
        mgr = MCPSessionManager()
        yield mgr
        await mgr.cleanup_all()

    async def test_session_caching(self, session_manager):
        """A repeated ``get_session`` for one context must return the cached session."""
        # Session stub whose write stream reports as not closed.
        fake_session = AsyncMock()
        fake_session._write_stream = MagicMock()
        fake_session._write_stream._closed = False

        # Task stub reporting as still running; done() is synchronous, hence MagicMock.
        fake_task = AsyncMock()
        fake_task.done = MagicMock(return_value=False)

        params = MagicMock()

        with (
            patch.object(session_manager, "_validate_session_connectivity", return_value=True),
            patch.object(
                session_manager, "_create_stdio_session", return_value=(fake_session, fake_task)
            ) as creator,
        ):
            first = await session_manager.get_session("test_context", params, "stdio")
            second = await session_manager.get_session("test_context", params, "stdio")

            assert first == second
            assert first == fake_session
            # The second lookup must not have built another session.
            creator.assert_called_once()

    async def test_session_cleanup(self, session_manager):
        """``_cleanup_session`` must cancel the session task and drop the session entry."""
        # cancel()/done() are synchronous on real tasks, so they are plain MagicMocks.
        task = AsyncMock()
        task.done = MagicMock(return_value=False)
        task.cancel = MagicMock()

        entry = {"session": AsyncMock(), "task": task, "type": "stdio", "last_used": 0}
        session_manager.sessions_by_server["test_server"] = {
            "sessions": {"test_session": entry},
            "last_cleanup": 0,
        }
        # Context id -> (server key, session id) mapping consulted by _cleanup_session.
        session_manager._context_to_session["test_context"] = ("test_server", "test_session")

        await session_manager._cleanup_session("test_context")

        task.cancel.assert_called_once()
        assert "test_session" not in session_manager.sessions_by_server["test_server"]["sessions"]

    async def test_server_switch_detection(self, session_manager):
        """Requesting a different server under the same context must create a new session."""
        params_a = MagicMock(command="server1")
        params_b = MagicMock(command="server2")
        created = [(AsyncMock(), AsyncMock()), (AsyncMock(), AsyncMock())]

        with (
            patch.object(session_manager, "_validate_session_connectivity", return_value=True),
            patch.object(session_manager, "_create_stdio_session", side_effect=created) as creator,
        ):
            before_switch = await session_manager.get_session("test_context", params_a, "stdio")
            after_switch = await session_manager.get_session("test_context", params_b, "stdio")

            assert before_switch != after_switch
            assert creator.call_count == 2
class TestHeaderValidation:
    """Tests for ``validate_headers``: name checks, value sanitization, injection filtering."""

    def test_validate_headers_valid_input(self):
        """Well-formed headers are kept, with names lower-cased."""
        cleaned = validate_headers(
            {"Authorization": "Bearer token123", "Content-Type": "application/json", "X-API-Key": "secret-key"}
        )
        assert cleaned == {
            "authorization": "Bearer token123",
            "content-type": "application/json",
            "x-api-key": "secret-key",
        }

    def test_validate_headers_empty_input(self):
        """An empty dict and ``None`` both produce an empty dict."""
        for nothing in ({}, None):
            assert validate_headers(nothing) == {}

    def test_validate_headers_invalid_names(self):
        """Headers whose names contain spaces or ``@`` are dropped."""
        candidates = {
            "Invalid Header": "value",  # contains a space
            "Header@Name": "value",  # contains @
            "Header Name": "value",  # contains a space
            "Valid-Header": "value",  # the only acceptable name
        }
        assert validate_headers(candidates) == {"valid-header": "value"}

    def test_validate_headers_sanitize_values(self):
        """Values are trimmed; tabs survive; control-only or CR/LF-bearing values are dropped."""
        candidates = {
            "Authorization": "Bearer \x00token\x1f with\r\ninjection",
            "Clean-Header": " clean value ",
            "Empty-After-Clean": "\x00\x01\x02",
            "Tab-Header": "value\twith\ttabs",  # tabs are expected to be preserved
        }
        # Only the trimmed value and the tabbed value are expected to remain.
        survivors = {"clean-header": "clean value", "tab-header": "value\twith\ttabs"}
        assert validate_headers(candidates) == survivors

    def test_validate_headers_non_string_values(self):
        """Non-string values (int, ``None``, list) are discarded."""
        candidates = {"String-Header": "valid", "Number-Header": 123, "None-Header": None, "List-Header": ["value"]}
        assert validate_headers(candidates) == {"string-header": "valid"}

    def test_validate_headers_injection_attempts(self):
        """Values carrying CR/LF header-injection payloads are filtered out."""
        candidates = {
            "Injection1": "value\r\nInjected-Header: malicious",
            "Injection2": "value\nX-Evil: attack",
            "Safe-Header": "safe-value",
        }
        assert validate_headers(candidates) == {"safe-header": "safe-value"}
class TestSSEHeaderIntegration:
    """Checks on header processing and on header storage in the SSE client's connection params."""

    async def test_headers_processing(self):
        """``_process_headers`` accepts list and dict input and drops invalid entries."""
        # Key/value list form: names come back lower-cased.
        from_list = _process_headers(
            [
                {"key": "Authorization", "value": "Bearer test-token"},
                {"key": "X-API-Key", "value": "secret-key"},
            ]
        )
        assert from_list == {
            "authorization": "Bearer test-token",
            "x-api-key": "secret-key",
        }

        # Dict form: the header with a space in its name is filtered out.
        from_dict = _process_headers({"Authorization": "Bearer dict-token", "Invalid Header": "bad"})
        assert from_dict == {"authorization": "Bearer dict-token"}

        # Absent or empty input.
        assert _process_headers(None) == {}
        assert _process_headers([]) == {}

        # Items lacking either "key" or "value" contribute nothing.
        incomplete_items = [{"key": "Auth"}, {"value": "token"}]
        assert _process_headers(incomplete_items) == {}

        # List form with one invalid header name.
        mixed_items = [
            {"key": "Valid-Header", "value": "good"},
            {"key": "Invalid Header", "value": "bad"},  # space in the name
        ]
        assert _process_headers(mixed_items) == {"valid-header": "good"}

    async def test_sse_client_header_storage(self):
        """Connection params assigned on the client expose the same URL and headers."""
        client = MCPSseClient()
        url = "http://test.url"
        headers = {"Authorization": "Bearer test123", "Custom": "value"}

        # Assign the params as a plain dict, then read them back.
        params = {
            "url": url,
            "headers": headers,
            "timeout_seconds": 30,
            "sse_read_timeout_seconds": 30,
        }
        client._connection_params = params

        assert client._connection_params["url"] == url
        assert client._connection_params["headers"] == headers
class TestMCPUtilityFunctions:
    """Tests for the small helper functions in ``langflow.base.mcp.util``."""

    def test_sanitize_mcp_name(self):
        """Names are lower-cased, stripped of accents/emoji, prefixed if numeric, and truncated."""
        expectations = [
            ("Test Name 123", "test_name_123"),
            (" ", ""),
            ("123abc", "_123abc"),
            ("Tést-😀-Námé", "test_name"),
            ("a" * 100, "a" * 46),
        ]
        for raw, sanitized in expectations:
            assert util.sanitize_mcp_name(raw) == sanitized

    def test_get_unique_name(self):
        """A numeric suffix is added on collision, within the length limit."""
        taken = {"foo", "foo_1"}
        assert util.get_unique_name("foo", 10, taken) == "foo_2"
        assert util.get_unique_name("bar", 10, taken) == "bar"
        assert util.get_unique_name("longname", 4, {"long"}) == "lo_1"

    def test_is_valid_key_value_item(self):
        """Only dicts holding both ``key`` and ``value`` count as valid items."""
        check = util._is_valid_key_value_item
        assert check({"key": "a", "value": "b"}) is True
        assert check({"key": "a"}) is False
        assert check(["key", "value"]) is False
        assert check(None) is False

    def test_validate_node_installation(self):
        """``npx`` commands need Node.js; other commands pass through unchanged."""
        node_missing = not shutil.which("node")
        if node_missing:
            with pytest.raises(ValueError, match="Node.js is not installed"):
                util._validate_node_installation("npx something")
        else:
            assert util._validate_node_installation("npx something") == "npx something"
        assert util._validate_node_installation("echo test") == "echo test"

    def test_create_input_schema_from_json_schema(self):
        """A JSON schema becomes a model that enforces its required fields."""
        json_schema = {
            "type": "object",
            "properties": {
                "foo": {"type": "string", "description": "desc"},
                "bar": {"type": "integer"},
            },
            "required": ["foo"],
        }
        generated_model = util.create_input_schema_from_json_schema(json_schema)

        populated = generated_model(foo="abc", bar=1)
        assert populated.foo == "abc"
        assert populated.bar == 1

        # Omitting the required "foo" field must be rejected.
        with pytest.raises(Exception):  # noqa: B017, PT011
            generated_model(bar=1)

    @pytest.mark.asyncio
    async def test_validate_connection_params(self):
        """Each mode requires its own parameter; unknown modes are rejected."""
        # Accepted combinations.
        await util._validate_connection_params("Stdio", command="echo test")
        await util._validate_connection_params("SSE", url="http://test")

        # Rejected combinations.
        with pytest.raises(ValueError, match="Command is required for Stdio mode"):
            await util._validate_connection_params("Stdio", command=None)
        with pytest.raises(ValueError, match="URL is required for SSE mode"):
            await util._validate_connection_params("SSE", url=None)
        with pytest.raises(ValueError, match="Invalid mode"):
            await util._validate_connection_params("InvalidMode")

    @pytest.mark.asyncio
    async def test_get_flow_snake_case_mocked(self):
        """``get_flow_snake_case`` finds a flow by sanitized name using a stub DB session."""

        class DummyFlow:
            """Bare record carrying the attributes set below."""

            def __init__(self, name: str, user_id: str, *, is_component: bool = False, action_name: str | None = None):
                self.name = name
                self.user_id = user_id
                self.is_component = is_component
                self.action_name = action_name

        class DummyExec:
            """Result wrapper exposing ``all()``."""

            def __init__(self, flows: list[DummyFlow]):
                self._flows = flows

            def all(self):
                return self._flows

        class DummySession:
            """Session stub whose ``exec`` ignores the statement and returns every flow."""

            def __init__(self, flows: list[DummyFlow]):
                self._flows = flows

            async def exec(self, stmt):  # noqa: ARG002
                return DummyExec(self._flows)

        owner = "123e4567-e89b-12d3-a456-426614174000"
        stored = [DummyFlow("Test Flow", owner), DummyFlow("Other", owner)]

        # Lookup by the sanitized form of an existing flow name.
        found = await util.get_flow_snake_case(util.sanitize_mcp_name("Test Flow"), owner, DummySession(stored))
        assert found is stored[0]

        # Unknown name yields None.
        found = await util.get_flow_snake_case("notfound", owner, DummySession(stored))
        assert found is None
class TestMCPStdioClientWithEverythingServer:
    """Integration tests running ``MCPStdioClient`` against the Everything MCP server via ``npx``."""

    @pytest.fixture
    def stdio_client(self):
        """Provide a fresh stdio client."""
        return MCPStdioClient()

    @pytest.mark.asyncio
    @pytest.mark.skipif(not shutil.which("npx"), reason="Node.js not available")
    async def test_connect_to_everything_server(self, stdio_client):
        """Connecting returns tools, including a described ``echo`` tool with an input schema."""
        try:
            available = await stdio_client.connect_to_server("npx -y @modelcontextprotocol/server-everything")
            assert len(available) > 0

            # First tool named "echo", or None when absent.
            echo = next((t for t in available if hasattr(t, "name") and t.name == "echo"), None)
            assert echo is not None, "Echo tool not found in server tools"
            assert echo.description is not None

            assert hasattr(echo, "inputSchema")
            assert echo.inputSchema is not None
        finally:
            # Always release the connection.
            await stdio_client.disconnect()

    @pytest.mark.asyncio
    @pytest.mark.skipif(not shutil.which("npx"), reason="Node.js not available")
    async def test_run_echo_tool(self, stdio_client):
        """Running ``echo`` returns non-empty content mentioning the message or ``Echo:``."""
        try:
            available = await stdio_client.connect_to_server("npx -y @modelcontextprotocol/server-everything")
            echo = next((t for t in available if hasattr(t, "name") and t.name == "echo"), None)
            assert echo is not None, "Echo tool not found"

            sent = "Hello, MCP!"
            outcome = await stdio_client.run_tool("echo", {"message": sent})

            assert outcome is not None
            assert hasattr(outcome, "content")
            assert len(outcome.content) > 0

            # The first content item, stringified, should reflect the echo.
            rendered = str(outcome.content[0])
            assert sent in rendered or "Echo:" in rendered
        finally:
            await stdio_client.disconnect()

    @pytest.mark.asyncio
    @pytest.mark.skipif(not shutil.which("npx"), reason="Node.js not available")
    async def test_list_all_tools(self, stdio_client):
        """The server exposes at least three well-formed tools, ``echo`` among them."""
        try:
            available = await stdio_client.connect_to_server("npx -y @modelcontextprotocol/server-everything")
            assert len(available) >= 3  # the Everything server is expected to ship several tools

            for entry in available:
                assert hasattr(entry, "name")
                assert hasattr(entry, "description")
                assert hasattr(entry, "inputSchema")
                assert entry.name is not None
                assert len(entry.name) > 0

            for wanted in ("echo",):
                assert any(entry.name == wanted for entry in available), f"Expected tool '{wanted}' not found"
        finally:
            await stdio_client.disconnect()

    @pytest.mark.asyncio
    @pytest.mark.skipif(not shutil.which("npx"), reason="Node.js not available")
    async def test_session_reuse(self, stdio_client):
        """Connecting twice under one session context yields the same number of tools."""
        launch = "npx -y @modelcontextprotocol/server-everything"
        try:
            stdio_client.set_session_context("test_session_reuse")

            initial = await stdio_client.connect_to_server(launch)
            # Second connect under the same context.
            repeated = await stdio_client.connect_to_server(launch)
            assert len(initial) == len(repeated)

            # A tool call confirms the session still works.
            outcome = await stdio_client.run_tool("echo", {"message": "Session reuse test"})
            assert outcome is not None
        finally:
            await stdio_client.disconnect()
class TestMCPSseClientWithDeepWikiServer:
    """Test MCPSseClient with the DeepWiki MCP server.

    These tests reach a public network endpoint. Connection or tool errors turn
    the test into a skip, but assertion failures are re-raised so that a wrong
    result is reported as a failure rather than silently skipped.
    """

    @pytest.fixture
    def sse_client(self):
        """Create an SSE client for testing."""
        return MCPSseClient()

    @pytest.mark.asyncio
    async def test_connect_to_deepwiki_server(self, sse_client):
        """Test connecting to the DeepWiki MCP server."""
        url = "https://mcp.deepwiki.com/sse"
        try:
            # Connect to the server
            tools = await sse_client.connect_to_server(url)
            # Verify tools were returned
            assert len(tools) > 0
            # Check for expected DeepWiki tools
            expected_tools = ["read_wiki_structure", "read_wiki_contents", "ask_question"]
            for expected_tool in expected_tools:
                assert any(tool.name == expected_tool for tool in tools), f"Expected tool '{expected_tool}' not found"
        except AssertionError:
            # AssertionError is an Exception subclass; without this clause a failed
            # check would be converted into a skip by the handler below.
            raise
        except Exception as e:
            # If the server is not accessible, skip the test
            pytest.skip(f"DeepWiki server not accessible: {e}")
        finally:
            await sse_client.disconnect()

    @pytest.mark.asyncio
    async def test_run_wiki_structure_tool(self, sse_client):
        """Test running the read_wiki_structure tool."""
        url = "https://mcp.deepwiki.com/sse"
        try:
            # Connect to the server
            tools = await sse_client.connect_to_server(url)
            # Find the read_wiki_structure tool
            wiki_tool = next(
                (tool for tool in tools if hasattr(tool, "name") and tool.name == "read_wiki_structure"),
                None,
            )
            assert wiki_tool is not None, "read_wiki_structure tool not found"
            # Run the tool with a test repository (use repoName as expected by the API)
            result = await sse_client.run_tool("read_wiki_structure", {"repoName": "microsoft/vscode"})
            # Verify the result
            assert result is not None
            assert hasattr(result, "content")
            assert len(result.content) > 0
        except AssertionError:
            # Let failed checks fail the test instead of being reported as a skip.
            raise
        except Exception as e:
            # If the server is not accessible or the tool call raises, skip the test
            pytest.skip(f"DeepWiki server test failed: {e}")
        finally:
            await sse_client.disconnect()

    @pytest.mark.asyncio
    async def test_ask_question_tool(self, sse_client):
        """Test running the ask_question tool."""
        url = "https://mcp.deepwiki.com/sse"
        try:
            # Connect to the server
            tools = await sse_client.connect_to_server(url)
            # Find the ask_question tool
            ask_tool = next(
                (tool for tool in tools if hasattr(tool, "name") and tool.name == "ask_question"),
                None,
            )
            assert ask_tool is not None, "ask_question tool not found"
            # Run the tool with a test question (use repoName as expected by the API)
            result = await sse_client.run_tool(
                "ask_question", {"repoName": "microsoft/vscode", "question": "What is VS Code?"}
            )
            # Verify the result
            assert result is not None
            assert hasattr(result, "content")
            assert len(result.content) > 0
        except AssertionError:
            # Let failed checks fail the test instead of being reported as a skip.
            raise
        except Exception as e:
            # If the server is not accessible or the tool call raises, skip the test
            pytest.skip(f"DeepWiki server test failed: {e}")
        finally:
            await sse_client.disconnect()

    @pytest.mark.asyncio
    async def test_url_validation(self, sse_client):
        """Test URL validation for SSE connections."""
        # Test valid URL
        valid_url = "https://mcp.deepwiki.com/sse"
        is_valid, error = await sse_client.validate_url(valid_url)
        # NOTE(review): this passes when the URL is valid OR the error text is empty;
        # confirm that is the intended tolerance for an unreachable server.
        assert is_valid or error == ""
        # Test invalid URL
        invalid_url = "not_a_url"
        is_valid, error = await sse_client.validate_url(invalid_url)
        assert not is_valid
        assert error != ""

    @pytest.mark.asyncio
    async def test_redirect_handling(self, sse_client):
        """Test redirect handling for SSE connections."""
        # Test with the DeepWiki URL
        url = "https://mcp.deepwiki.com/sse"
        try:
            # Check for redirects
            final_url = await sse_client.pre_check_redirect(url)
            # Should return a URL (either original or redirected)
            assert final_url is not None
            assert isinstance(final_url, str)
            assert final_url.startswith("http")
        except AssertionError:
            # Let failed checks fail the test instead of being reported as a skip.
            raise
        except Exception as e:
            # If the server is not accessible, skip the test
            pytest.skip(f"DeepWiki server not accessible for redirect test: {e}")

    @pytest.fixture
    def mock_tool(self):
        """Create a mock MCP tool."""
        tool = MagicMock()
        tool.name = "test_tool"
        tool.description = "Test tool description"
        tool.inputSchema = {
            "type": "object",
            "properties": {"test_param": {"type": "string", "description": "Test parameter"}},
            "required": ["test_param"],
        }
        return tool

    @pytest.fixture
    def mock_session(self, mock_tool):
        """Create a mock ClientSession whose list_tools returns ``mock_tool``."""
        session = AsyncMock()
        session.initialize = AsyncMock()
        list_tools_result = MagicMock()
        list_tools_result.tools = [mock_tool]
        session.list_tools = AsyncMock(return_value=list_tools_result)
        session.call_tool = AsyncMock(
            return_value=MagicMock(content=[MagicMock(model_dump=lambda: {"result": "success"})])
        )
        return session
class TestMCPSseClientUnit:
    """Mock-based unit tests for ``MCPSseClient`` (no network access)."""

    @pytest.fixture
    def sse_client(self):
        """Provide a fresh SSE client."""
        return MCPSseClient()

    @pytest.mark.asyncio
    async def test_client_initialization(self, sse_client):
        """A new client has no session, params or context and is not connected."""
        assert sse_client.session is None
        assert sse_client._connection_params is None
        assert sse_client._connected is False
        assert sse_client._session_context is None

    async def test_validate_url_valid(self, sse_client):
        """A 200 response makes the URL valid with an empty error message."""
        with patch("httpx.AsyncClient") as client_cls:
            # Object bound by ``async with httpx.AsyncClient(...) as client``.
            http = client_cls.return_value.__aenter__.return_value
            reply = MagicMock()
            reply.status_code = 200
            http.get.return_value = reply

            ok, message = await sse_client.validate_url("http://test.url", {})

            assert ok is True
            assert message == ""

    async def test_validate_url_invalid_format(self, sse_client):
        """A string that is not a URL is rejected with a format error."""
        ok, message = await sse_client.validate_url("invalid-url", {})
        assert ok is False
        assert "Invalid URL format" in message

    async def test_validate_url_with_404_response(self, sse_client):
        """A 404 response is still treated as a valid SSE URL."""
        with patch("httpx.AsyncClient") as client_cls:
            http = client_cls.return_value.__aenter__.return_value
            reply = MagicMock()
            reply.status_code = 404
            http.get.return_value = reply

            ok, message = await sse_client.validate_url("http://test.url", {})

            assert ok is True
            assert message == ""

    async def test_connect_to_server_with_headers(self, sse_client):
        """Connecting with custom headers stores them lower-cased in the connection params."""
        url = "http://test.url"
        supplied = {"Authorization": "Bearer token123", "Custom-Header": "value"}
        normalized = {"authorization": "Bearer token123", "custom-header": "value"}

        # Session stub exposing a single tool.
        only_tool = MagicMock()
        only_tool.name = "test_tool"
        listing = MagicMock()
        listing.tools = [only_tool]
        fake_session = AsyncMock()
        fake_session.list_tools = AsyncMock(return_value=listing)

        with (
            patch.object(sse_client, "validate_url", return_value=(True, "")),
            patch.object(sse_client, "pre_check_redirect", return_value=url),
            patch.object(sse_client, "_get_or_create_session") as session_getter,
        ):
            session_getter.return_value = fake_session

            tools = await sse_client.connect_to_server(url, supplied)

            assert len(tools) == 1
            assert tools[0].name == "test_tool"
            assert sse_client._connected is True
            # Stored params carry the normalized headers and the URL.
            assert sse_client._connection_params is not None
            assert sse_client._connection_params["headers"] == normalized
            assert sse_client._connection_params["url"] == url

    async def test_headers_passed_to_session_manager(self, sse_client):
        """``_get_or_create_session`` forwards context, params and transport to the manager."""
        normalized = {"authorization": "Bearer token123", "x-api-key": "secret"}
        sse_client._session_context = "test_context"
        sse_client._connection_params = {
            "url": "http://test.url",
            "headers": normalized,  # already normalized
            "timeout_seconds": 30,
            "sse_read_timeout_seconds": 30,
        }

        with patch.object(sse_client, "_get_session_manager") as manager_getter:
            fake_session = AsyncMock()
            fake_manager = AsyncMock()
            fake_manager.get_session = AsyncMock(return_value=fake_session)
            manager_getter.return_value = fake_manager

            obtained = await sse_client._get_or_create_session()

            fake_manager.get_session.assert_called_once_with("test_context", sse_client._connection_params, "sse")
            assert obtained == fake_session

    async def test_pre_check_redirect_with_headers(self, sse_client):
        """A 307 response yields the redirect target, and the request carries the headers."""
        origin = "http://test.url"
        target = "http://redirect.url"
        # pre_check_redirect is given headers that are already normalized.
        headers = {"authorization": "Bearer token123"}

        with patch("httpx.AsyncClient") as client_cls:
            http = client_cls.return_value.__aenter__.return_value
            reply = MagicMock()
            reply.status_code = 307
            reply.headers.get.return_value = target
            http.get.return_value = reply

            resolved = await sse_client.pre_check_redirect(origin, headers)

            assert resolved == target
            # The GET must merge the SSE Accept header with the supplied headers.
            http.get.assert_called_with(origin, timeout=2.0, headers={"Accept": "text/event-stream", **headers})

    async def test_run_tool_with_retry_on_connection_error(self, sse_client):
        """``run_tool`` retries after a closed-resource error and cleans up the failed session."""
        sse_client._connected = True
        sse_client._connection_params = {"url": "http://test.url", "headers": {}}
        sse_client._session_context = "test_context"

        call_count = 0

        async def mock_get_session_side_effect():
            nonlocal call_count
            call_count += 1
            session = AsyncMock()
            if call_count != 1:
                # Later attempts succeed.
                session.call_tool = AsyncMock(return_value=MagicMock())
                return session
            # First attempt: the tool call fails with a connection error.
            from anyio import ClosedResourceError

            session.call_tool = AsyncMock(side_effect=ClosedResourceError())
            return session

        with (
            patch.object(sse_client, "_get_or_create_session", side_effect=mock_get_session_side_effect),
            patch.object(sse_client, "_get_session_manager") as manager_getter,
        ):
            fake_manager = AsyncMock()
            manager_getter.return_value = fake_manager

            outcome = await sse_client.run_tool("test_tool", {"param": "value"})

            # One failed attempt plus one successful retry.
            assert call_count == 2
            assert outcome is not None
            # The failed session must have been cleaned up exactly once.
            fake_manager._cleanup_session.assert_called_once_with("test_context")

View file

@ -1,8 +1,15 @@
"""Unit tests for MCP component with actual MCP servers.
This test suite validates the MCP component functionality using real MCP servers:
- Everything server (stdio mode) - provides echo and other tools
- DeepWiki server (SSE mode) - provides wiki-related tools
"""
import shutil
from unittest.mock import AsyncMock, MagicMock, patch
import pytest
from langflow.base.mcp import util
from langflow.base.mcp.util import MCPSessionManager, MCPSseClient, MCPStdioClient, _process_headers, validate_headers
from langflow.base.mcp.util import MCPSessionManager, MCPSseClient, MCPStdioClient
from langflow.components.agents.mcp_component import MCPToolsComponent
from tests.base import ComponentTestBaseWithoutClient, VersionComponentMapping
@ -18,8 +25,11 @@ class TestMCPToolsComponent(ComponentTestBaseWithoutClient):
def default_kwargs(self):
"""Return the default kwargs for the component."""
return {
"mode": "Stdio",
"command": "npx -y @modelcontextprotocol/server-everything",
"sse_url": "https://mcp.deepwiki.com/sse",
"tool": "echo",
"mcp_server": {"name": "test_server", "config": {"command": "uvx mcp-server-fetch"}},
"tool": "",
}
@pytest.fixture
@ -27,34 +37,106 @@ class TestMCPToolsComponent(ComponentTestBaseWithoutClient):
"""Return the file names mapping for different versions."""
return []
@pytest.fixture
def mock_tool(self):
    """Build a ``MagicMock`` standing in for an MCP tool with one required string parameter."""
    schema = {
        "type": "object",
        "properties": {"test_param": {"type": "string", "description": "Test parameter"}},
        "required": ["test_param"],
    }
    fake_tool = MagicMock()
    # ``name`` is assigned after construction so it is stored as a plain attribute.
    fake_tool.name = "test_tool"
    fake_tool.description = "Test tool description"
    fake_tool.inputSchema = schema
    return fake_tool
@pytest.mark.asyncio
@pytest.mark.skipif(not shutil.which("npx"), reason="Node.js not available")
async def test_component_initialization(self, component_class, default_kwargs):
    """Instantiate the component and check its clients and session manager types."""
    instance = component_class(**default_kwargs)
    expected_clients = (("stdio_client", MCPStdioClient), ("sse_client", MCPSseClient))

    # Both client attributes must exist ...
    for attr_name, _ in expected_clients:
        assert hasattr(instance, attr_name)
    # ... and be of the matching client class.
    for attr_name, client_type in expected_clients:
        assert isinstance(getattr(instance, attr_name), client_type)

    # The stdio client must hand out an MCPSessionManager.
    manager = instance.stdio_client._get_session_manager()
    assert isinstance(manager, MCPSessionManager)
class TestMCPToolsComponentIntegration:
    """Integration tests for the MCPToolsComponent."""

    @pytest.fixture
    def mock_session(self, mock_tool):
        """Create a mock ClientSession that lists ``mock_tool`` and returns a canned tool result."""
        # NOTE(review): this requests a ``mock_tool`` fixture; confirm one is
        # resolvable from this class before relying on this fixture.
        session = AsyncMock()
        session.initialize = AsyncMock()
        list_tools_result = MagicMock()
        list_tools_result.tools = [mock_tool]
        session.list_tools = AsyncMock(return_value=list_tools_result)
        session.call_tool = AsyncMock(
            return_value=MagicMock(content=[MagicMock(model_dump=lambda: {"result": "success"})])
        )
        return session

    @pytest.fixture
    def component(self):
        """Create a component for testing.

        Registered as a fixture because every test in this class requests
        ``component`` as an argument.
        """
        return MCPToolsComponent()

    @pytest.mark.asyncio
    @pytest.mark.skipif(not shutil.which("npx"), reason="Node.js not available")
    async def test_stdio_mode_integration(self, component):
        """Test the component in stdio mode with Everything server."""
        # Configure for stdio mode
        component.mode = "Stdio"
        component.command = "npx -y @modelcontextprotocol/server-everything"
        component.tool = "echo"

        try:
            # No mocking is set up here: this calls the component's own update_tool_list().
            tools, server_info = await component.update_tool_list()
        except Exception as e:
            # If the server is not accessible, skip the test
            pytest.skip(f"Everything server not accessible: {e}")

        # The checks sit outside the ``try`` so that a failed expectation is
        # reported as a failure rather than being caught above and turned into a skip.
        assert len(tools) > 0
        assert server_info is not None
        assert isinstance(server_info, dict)

    @pytest.mark.asyncio
    async def test_sse_mode_integration(self, component):
        """Test the component in SSE mode with DeepWiki server."""
        # Configure for SSE mode
        component.mode = "SSE"
        component.sse_url = "https://mcp.deepwiki.com/sse"

        try:
            # No mocking is set up here: this calls the component's own update_tool_list().
            tools, server_info = await component.update_tool_list()
        except Exception as e:
            # If the server is not accessible, skip the test
            pytest.skip(f"DeepWiki server not accessible: {e}")

        # Kept outside the ``try`` so assertion failures are not converted into skips.
        assert len(tools) > 0
        assert server_info is not None
        assert isinstance(server_info, dict)

    @pytest.mark.asyncio
    async def test_session_context_setting(self, component):
        """Test that session context is properly set."""
        # Set session context
        component.stdio_client.set_session_context("test_context")
        component.sse_client.set_session_context("test_context")

        # Verify context was set
        assert component.stdio_client._session_context == "test_context"
        assert component.sse_client._session_context == "test_context"

    @pytest.mark.asyncio
    async def test_session_manager_sharing(self, component):
        """Test that session managers are shared through component cache."""
        # Get session managers
        stdio_manager = component.stdio_client._get_session_manager()
        sse_manager = component.sse_client._get_session_manager()

        # Both should be MCPSessionManager instances
        assert isinstance(stdio_manager, MCPSessionManager)
        assert isinstance(sse_manager, MCPSessionManager)

        # They should be the same instance (shared through cache)
        assert stdio_manager is sse_manager
class TestMCPStdioClient:
class TestMCPComponentErrorHandling:
"""Test error handling in MCP components."""
@pytest.fixture
def stdio_client(self):
    """Provide a newly constructed ``MCPStdioClient``."""
    return MCPStdioClient()
@ -122,494 +204,3 @@ class TestMCPStdioClient:
mock_manager._cleanup_session.assert_called_once_with("test_context")
assert stdio_client.session is None
assert stdio_client._connected is False
class TestMCPSseClient:
    """Unit tests for ``MCPSseClient``: URL validation, connecting, redirects and retries.

    Every coroutine test carries ``@pytest.mark.asyncio`` explicitly, matching
    the other async tests in this module, so they do not depend on the
    pytest-asyncio mode configured for the project.
    """

    @pytest.fixture
    def sse_client(self):
        """Provide a newly constructed ``MCPSseClient``."""
        return MCPSseClient()

    @pytest.mark.asyncio
    async def test_validate_url_valid(self, sse_client):
        """Test URL validation with valid URL."""
        with patch("httpx.AsyncClient") as mock_client:
            mock_response = MagicMock()
            mock_response.status_code = 200
            # Response returned by ``get`` on the object yielded by ``async with httpx.AsyncClient()``.
            mock_client.return_value.__aenter__.return_value.get.return_value = mock_response

            is_valid, error_msg = await sse_client.validate_url("http://test.url", {})

            assert is_valid is True
            assert error_msg == ""

    @pytest.mark.asyncio
    async def test_validate_url_invalid_format(self, sse_client):
        """Test URL validation with invalid format."""
        is_valid, error_msg = await sse_client.validate_url("invalid-url", {})

        assert is_valid is False
        assert "Invalid URL format" in error_msg

    @pytest.mark.asyncio
    async def test_validate_url_with_404_response(self, sse_client):
        """Test URL validation with 404 response (should be valid for SSE)."""
        with patch("httpx.AsyncClient") as mock_client:
            mock_response = MagicMock()
            mock_response.status_code = 404
            mock_client.return_value.__aenter__.return_value.get.return_value = mock_response

            is_valid, error_msg = await sse_client.validate_url("http://test.url", {})

            assert is_valid is True
            assert error_msg == ""

    @pytest.mark.asyncio
    async def test_connect_to_server_with_headers(self, sse_client):
        """Test connecting to server via SSE with custom headers."""
        test_url = "http://test.url"
        test_headers = {"Authorization": "Bearer token123", "Custom-Header": "value"}
        expected_headers = {"authorization": "Bearer token123", "custom-header": "value"}  # normalized

        with (
            patch.object(sse_client, "validate_url", return_value=(True, "")),
            patch.object(sse_client, "pre_check_redirect", return_value=test_url),
            patch.object(sse_client, "_get_or_create_session") as mock_get_session,
        ):
            # Mock session exposing a single tool
            mock_session = AsyncMock()
            mock_tool = MagicMock()
            mock_tool.name = "test_tool"
            list_tools_result = MagicMock()
            list_tools_result.tools = [mock_tool]
            mock_session.list_tools = AsyncMock(return_value=list_tools_result)
            mock_get_session.return_value = mock_session

            tools = await sse_client.connect_to_server(test_url, test_headers)

            assert len(tools) == 1
            assert tools[0].name == "test_tool"
            assert sse_client._connected is True

            # Verify headers are stored in connection params (normalized)
            assert sse_client._connection_params is not None
            assert sse_client._connection_params["headers"] == expected_headers
            assert sse_client._connection_params["url"] == test_url

    @pytest.mark.asyncio
    async def test_headers_passed_to_session_manager(self, sse_client):
        """Test that headers are properly passed to the session manager."""
        test_url = "http://test.url"
        expected_headers = {"authorization": "Bearer token123", "x-api-key": "secret"}  # normalized

        sse_client._session_context = "test_context"
        sse_client._connection_params = {
            "url": test_url,
            "headers": expected_headers,  # Use normalized headers
            "timeout_seconds": 30,
            "sse_read_timeout_seconds": 30,
        }

        with patch.object(sse_client, "_get_session_manager") as mock_get_manager:
            mock_manager = AsyncMock()
            mock_session = AsyncMock()
            mock_manager.get_session = AsyncMock(return_value=mock_session)
            mock_get_manager.return_value = mock_manager

            result_session = await sse_client._get_or_create_session()

            # Verify session manager was called with correct parameters including normalized headers
            mock_manager.get_session.assert_called_once_with("test_context", sse_client._connection_params, "sse")
            assert result_session == mock_session

    @pytest.mark.asyncio
    async def test_pre_check_redirect_with_headers(self, sse_client):
        """Test pre-check redirect functionality with custom headers."""
        test_url = "http://test.url"
        redirect_url = "http://redirect.url"
        # Use pre-validated headers since pre_check_redirect expects already validated headers
        test_headers = {"authorization": "Bearer token123"}  # already normalized

        with patch("httpx.AsyncClient") as mock_client:
            mock_response = MagicMock()
            mock_response.status_code = 307
            mock_response.headers.get.return_value = redirect_url
            mock_client.return_value.__aenter__.return_value.get.return_value = mock_response

            result = await sse_client.pre_check_redirect(test_url, test_headers)

            assert result == redirect_url
            # Verify validated headers were passed to the request
            mock_client.return_value.__aenter__.return_value.get.assert_called_with(
                test_url, timeout=2.0, headers={"Accept": "text/event-stream", **test_headers}
            )

    @pytest.mark.asyncio
    async def test_run_tool_with_retry_on_connection_error(self, sse_client):
        """Test that run_tool retries on connection errors."""
        # Imported once per test rather than inside the side-effect closure.
        from anyio import ClosedResourceError

        # Setup connection state
        sse_client._connected = True
        sse_client._connection_params = {"url": "http://test.url", "headers": {}}
        sse_client._session_context = "test_context"

        call_count = 0

        async def mock_get_session_side_effect():
            """Hand out a session whose call_tool fails on the first request and succeeds afterwards."""
            nonlocal call_count
            call_count += 1
            session = AsyncMock()
            if call_count == 1:
                # First call fails with connection error
                session.call_tool = AsyncMock(side_effect=ClosedResourceError())
            else:
                # Second call succeeds
                mock_result = MagicMock()
                session.call_tool = AsyncMock(return_value=mock_result)
            return session

        with (
            patch.object(sse_client, "_get_or_create_session", side_effect=mock_get_session_side_effect),
            patch.object(sse_client, "_get_session_manager") as mock_get_manager,
        ):
            mock_manager = AsyncMock()
            mock_get_manager.return_value = mock_manager

            result = await sse_client.run_tool("test_tool", {"param": "value"})

            # Should have retried and succeeded on second attempt
            assert call_count == 2
            assert result is not None
            # Should have cleaned up the failed session
            mock_manager._cleanup_session.assert_called_once_with("test_context")
class TestMCPSessionManager:
    """Unit tests for ``MCPSessionManager`` session caching, cleanup and server switching.

    Every coroutine test carries ``@pytest.mark.asyncio`` explicitly, matching
    the other async tests in this module.
    """

    @pytest.fixture
    def session_manager(self):
        """Provide a newly constructed ``MCPSessionManager``."""
        return MCPSessionManager()

    @pytest.mark.asyncio
    async def test_session_caching(self, session_manager):
        """Test that sessions are properly cached and reused."""
        context_id = "test_context"
        connection_params = MagicMock()
        transport_type = "stdio"

        # Create a mock session that will appear healthy
        mock_session = AsyncMock()
        mock_session._write_stream = MagicMock()
        mock_session._write_stream._closed = False

        # Create a mock task that appears to be running
        mock_task = AsyncMock()
        mock_task.done = MagicMock(return_value=False)

        with (
            patch.object(session_manager, "_create_stdio_session") as mock_create,
            patch.object(session_manager, "_validate_session_connectivity", return_value=True),
        ):
            mock_create.return_value = mock_session

            # First call should create session
            session1 = await session_manager.get_session(context_id, connection_params, transport_type)

            # Manually populate the sessions cache as if the session was created properly
            session_manager.sessions[context_id] = {"session": mock_session, "task": mock_task, "type": transport_type}

            # Second call should return cached session without creating new one
            session2 = await session_manager.get_session(context_id, connection_params, transport_type)

            assert session1 == session2
            assert session1 == mock_session
            # Should only create once since the second call should use the cached session
            mock_create.assert_called_once()

    @pytest.mark.asyncio
    async def test_session_cleanup(self, session_manager):
        """Test session cleanup functionality."""
        context_id = "test_context"

        # Add a session to the manager with proper mock setup
        mock_task = AsyncMock()
        mock_task.done = MagicMock(return_value=False)  # Use MagicMock for sync method
        mock_task.cancel = MagicMock()  # Use MagicMock for sync method
        session_manager.sessions[context_id] = {"session": AsyncMock(), "task": mock_task, "type": "stdio"}

        await session_manager._cleanup_session(context_id)

        # Should cancel the task and remove from sessions
        mock_task.cancel.assert_called_once()
        assert context_id not in session_manager.sessions

    @pytest.mark.asyncio
    async def test_server_switch_detection(self, session_manager):
        """Test that server switches are properly detected and handled."""
        context_id = "test_context"

        # First server
        server1_params = MagicMock()
        server1_params.command = "server1"

        # Second server
        server2_params = MagicMock()
        server2_params.command = "server2"

        with (
            patch.object(session_manager, "_create_stdio_session") as mock_create,
            patch.object(session_manager, "_validate_session_connectivity", return_value=True),
        ):
            mock_session1 = AsyncMock()
            mock_session2 = AsyncMock()
            # One distinct session per creation request, in order.
            mock_create.side_effect = [mock_session1, mock_session2]

            # First connection
            session1 = await session_manager.get_session(context_id, server1_params, "stdio")

            # Switch to different server should create new session
            session2 = await session_manager.get_session(context_id, server2_params, "stdio")

            assert session1 != session2
            assert mock_create.call_count == 2
# Tests for header validation and for header handling in the SSE client
class TestHeaderValidation:
    """Checks for ``validate_headers``: normalization, filtering and sanitization."""

    def test_validate_headers_valid_input(self):
        """Well-formed headers come back with lowercased names and unchanged values."""
        raw_headers = {
            "Authorization": "Bearer token123",
            "Content-Type": "application/json",
            "X-API-Key": "secret-key",
        }

        cleaned = validate_headers(raw_headers)

        # Names are expected in lowercase form.
        assert cleaned == {
            "authorization": "Bearer token123",
            "content-type": "application/json",
            "x-api-key": "secret-key",
        }

    def test_validate_headers_empty_input(self):
        """An empty mapping and ``None`` both produce an empty result."""
        assert validate_headers({}) == {}
        assert validate_headers(None) == {}

    def test_validate_headers_invalid_names(self):
        """Headers whose names contain disallowed characters are dropped."""
        raw_headers = {
            "Invalid Header": "value",  # contains a space
            "Header@Name": "value",  # contains "@"
            "Header Name": "value",  # contains a space
            "Valid-Header": "value",  # acceptable name
        }

        cleaned = validate_headers(raw_headers)

        # Only the acceptable header survives.
        assert cleaned == {"valid-header": "value"}

    def test_validate_headers_sanitize_values(self):
        """Values are trimmed, tabs are kept, and unsafe or emptied values are dropped."""
        raw_headers = {
            "Authorization": "Bearer \x00token\x1f with\r\ninjection",
            "Clean-Header": " clean value ",
            "Empty-After-Clean": "\x00\x01\x02",
            "Tab-Header": "value\twith\ttabs",  # tabs are expected to survive
        }

        cleaned = validate_headers(raw_headers)

        # The injection-bearing header and the one left empty after cleaning are
        # expected to be absent; surrounding whitespace is expected to be trimmed.
        assert cleaned == {"clean-header": "clean value", "tab-header": "value\twith\ttabs"}

    def test_validate_headers_non_string_values(self):
        """Only headers with string values are kept."""
        raw_headers = {
            "String-Header": "valid",
            "Number-Header": 123,
            "None-Header": None,
            "List-Header": ["value"],
        }

        cleaned = validate_headers(raw_headers)

        assert cleaned == {"string-header": "valid"}

    def test_validate_headers_injection_attempts(self):
        """Values carrying CR/LF header-injection payloads are filtered out."""
        raw_headers = {
            "Injection1": "value\r\nInjected-Header: malicious",
            "Injection2": "value\nX-Evil: attack",
            "Safe-Header": "safe-value",
        }

        cleaned = validate_headers(raw_headers)

        assert cleaned == {"safe-header": "safe-value"}
class TestSSEHeaderIntegration:
    """Integration test to verify headers are properly passed through the entire SSE flow.

    Both coroutine tests carry ``@pytest.mark.asyncio`` explicitly, matching
    the other async tests in this module.
    """

    @pytest.mark.asyncio
    async def test_headers_processing(self):
        """Test that headers flow properly from server config through to SSE client connection."""
        # Test the header processing function directly
        headers_input = [
            {"key": "Authorization", "value": "Bearer test-token"},
            {"key": "X-API-Key", "value": "secret-key"},
        ]

        expected_headers = {
            "authorization": "Bearer test-token",  # normalized to lowercase
            "x-api-key": "secret-key",
        }

        # Test _process_headers function with validation
        processed_headers = _process_headers(headers_input)
        assert processed_headers == expected_headers

        # Test different input formats
        # Test dict input with validation
        dict_headers = {"Authorization": "Bearer dict-token", "Invalid Header": "bad"}
        result = _process_headers(dict_headers)
        # Invalid header should be filtered out, valid header normalized
        assert result == {"authorization": "Bearer dict-token"}

        # Test None input
        assert _process_headers(None) == {}

        # Test empty list
        assert _process_headers([]) == {}

        # Test malformed list
        malformed_headers = [{"key": "Auth"}, {"value": "token"}]  # Missing value/key
        assert _process_headers(malformed_headers) == {}

        # Test list with invalid header names
        invalid_headers = [
            {"key": "Valid-Header", "value": "good"},
            {"key": "Invalid Header", "value": "bad"},  # spaces not allowed
        ]
        result = _process_headers(invalid_headers)
        assert result == {"valid-header": "good"}

    @pytest.mark.asyncio
    async def test_sse_client_header_storage(self):
        """Test that SSE client properly stores headers in connection params."""
        sse_client = MCPSseClient()
        test_url = "http://test.url"
        test_headers = {"Authorization": "Bearer test123", "Custom": "value"}
        expected_headers = {"authorization": "Bearer test123", "custom": "value"}  # normalized

        with (
            patch.object(sse_client, "validate_url", return_value=(True, "")),
            patch.object(sse_client, "pre_check_redirect", return_value=test_url),
            patch.object(sse_client, "_get_or_create_session") as mock_get_session,
        ):
            # Session stub exposing a single tool
            mock_session = AsyncMock()
            mock_tool = MagicMock()
            mock_tool.name = "test_tool"
            list_tools_result = MagicMock()
            list_tools_result.tools = [mock_tool]
            mock_session.list_tools = AsyncMock(return_value=list_tools_result)
            mock_get_session.return_value = mock_session

            await sse_client.connect_to_server(test_url, test_headers)

            # Verify headers are stored correctly in connection params (normalized)
            assert sse_client._connection_params is not None
            assert sse_client._connection_params["headers"] == expected_headers
            assert sse_client._connection_params["url"] == test_url
class TestMCPUtilityFunctions:
"""Test utility functions from util.py that don't have dedicated test classes."""
def test_sanitize_mcp_name(self):
    """Check ``sanitize_mcp_name`` against a table of inputs and expected results."""
    expectations = [
        ("Test Name 123", "test_name_123"),
        (" ", ""),
        ("123abc", "_123abc"),
        ("Tést-😀-Námé", "test_name"),
        ("a" * 100, "a" * 46),
    ]
    for raw_name, sanitized in expectations:
        assert util.sanitize_mcp_name(raw_name) == sanitized
def test_get_unique_name(self):
    """Check ``get_unique_name`` for a taken name, a free name and a length-limited name."""
    taken = {"foo", "foo_1"}
    # (base name, maximum length, names already in use) -> expected result
    scenarios = [
        (("foo", 10, taken), "foo_2"),
        (("bar", 10, taken), "bar"),
        (("longname", 4, {"long"}), "lo_1"),
    ]
    for call_args, expected_name in scenarios:
        assert util.get_unique_name(*call_args) == expected_name
def test_is_valid_key_value_item(self):
    """Check which shapes ``_is_valid_key_value_item`` accepts."""
    verdicts = [
        ({"key": "a", "value": "b"}, True),
        ({"key": "a"}, False),
        (["key", "value"], False),
        (None, False),
    ]
    for candidate, expected in verdicts:
        # Identity comparison: the helper must return the bool singletons themselves.
        assert util._is_valid_key_value_item(candidate) is expected
def test_validate_node_installation(self):
    """Check ``_validate_node_installation`` with and without a ``node`` executable on PATH."""
    import shutil

    npx_command = "npx something"

    if not shutil.which("node"):
        # No Node.js available: an npx command must be rejected.
        with pytest.raises(ValueError, match="Node.js is not installed"):
            util._validate_node_installation(npx_command)
    else:
        # Node.js available: the npx command is returned unchanged.
        assert util._validate_node_installation(npx_command) == npx_command

    # A non-npx command is returned unchanged in either case.
    plain_command = "echo test"
    assert util._validate_node_installation(plain_command) == plain_command
def test_create_input_schema_from_json_schema(self):
    """Build a model class from a JSON schema and check its fields and required-field handling."""
    properties = {
        "foo": {"type": "string", "description": "desc"},
        "bar": {"type": "integer"},
    }
    json_schema = {"type": "object", "properties": properties, "required": ["foo"]}

    model_class = util.create_input_schema_from_json_schema(json_schema)

    populated = model_class(foo="abc", bar=1)
    assert populated.foo == "abc"
    assert populated.bar == 1

    # Leaving out the required "foo" field must be rejected.
    with pytest.raises(Exception):  # noqa: B017, PT011
        model_class(bar=1)
@pytest.mark.asyncio
async def test_validate_connection_params(self):
    """Check that ``_validate_connection_params`` accepts valid input and rejects invalid input."""
    # Accepted combinations: these must complete without raising.
    await util._validate_connection_params("Stdio", command="echo test")
    await util._validate_connection_params("SSE", url="http://test")

    # Rejected combinations: (mode, keyword arguments, expected error message pattern)
    rejected = [
        ("Stdio", {"command": None}, "Command is required for Stdio mode"),
        ("SSE", {"url": None}, "URL is required for SSE mode"),
        ("InvalidMode", {}, "Invalid mode"),
    ]
    for mode, extra_kwargs, message in rejected:
        with pytest.raises(ValueError, match=message):
            await util._validate_connection_params(mode, **extra_kwargs)
@pytest.mark.asyncio
async def test_get_flow_snake_case_mocked(self):
    """Look up a flow by its sanitized name using stand-in session objects."""

    class StubFlow:
        """Flow record carrying only the attributes assigned in ``__init__``."""

        def __init__(self, name: str, user_id: str, *, is_component: bool = False, action_name: str | None = None):
            self.name = name
            self.user_id = user_id
            self.is_component = is_component
            self.action_name = action_name

    class StubResult:
        """Query result whose ``all()`` returns the flows it was given."""

        def __init__(self, flows: list[StubFlow]):
            self._flows = flows

        def all(self):
            return self._flows

    class StubSession:
        """Session whose ``exec`` ignores the statement and returns every stored flow."""

        def __init__(self, flows: list[StubFlow]):
            self._flows = flows

        async def exec(self, stmt):  # noqa: ARG002
            return StubResult(self._flows)

    owner_id = "123e4567-e89b-12d3-a456-426614174000"
    known_flows = [StubFlow("Test Flow", owner_id), StubFlow("Other", owner_id)]

    # The sanitized form of an existing flow name must resolve to that flow.
    lookup_name = util.sanitize_mcp_name("Test Flow")
    found = await util.get_flow_snake_case(lookup_name, owner_id, StubSession(known_flows))
    assert found is known_flows[0]

    # A name that matches nothing must give None.
    missing = await util.get_flow_snake_case("notfound", owner_id, StubSession(known_flows))
    assert missing is None