fix: store mcp sse headers and use them on connection (#9148)
* Store mcp server headers * Add headers on pre check url and is valid url * adds validation of headers according to RFC 7230 * Fixed sanitized value * Added backend tests for mcp util.py to increase coverage * [autofix.ci] apply automated fixes * [autofix.ci] apply automated fixes (attempt 2/3) * make key pair input use flatmap id on data test ids * added testids * added random test names and added tests for persistence * fix ruff lint * [autofix.ci] apply automated fixes * Fix mypy lint errors --------- Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com>
This commit is contained in:
parent
324caf486c
commit
41e101499a
8 changed files with 1101 additions and 197 deletions
|
|
@ -30,6 +30,86 @@ HTTP_NOT_FOUND = 404
|
|||
HTTP_BAD_REQUEST = 400
|
||||
HTTP_INTERNAL_SERVER_ERROR = 500
|
||||
|
||||
# RFC 7230 compliant header name pattern: token = 1*tchar
|
||||
# tchar = "!" / "#" / "$" / "%" / "&" / "'" / "*" / "+" / "-" / "." /
|
||||
# "^" / "_" / "`" / "|" / "~" / DIGIT / ALPHA
|
||||
HEADER_NAME_PATTERN = re.compile(r"^[!#$%&\'*+\-.0-9A-Z^_`a-z|~]+$")
|
||||
|
||||
# Common allowed headers for MCP connections
|
||||
ALLOWED_HEADERS = {
|
||||
"authorization",
|
||||
"accept",
|
||||
"accept-encoding",
|
||||
"accept-language",
|
||||
"cache-control",
|
||||
"content-type",
|
||||
"user-agent",
|
||||
"x-api-key",
|
||||
"x-auth-token",
|
||||
"x-custom-header",
|
||||
"x-langflow-session",
|
||||
"x-mcp-client",
|
||||
"x-requested-with",
|
||||
}
|
||||
|
||||
|
||||
def validate_headers(headers: dict[str, str]) -> dict[str, str]:
|
||||
"""Validate and sanitize HTTP headers according to RFC 7230.
|
||||
|
||||
Args:
|
||||
headers: Dictionary of header name-value pairs
|
||||
|
||||
Returns:
|
||||
Dictionary of validated and sanitized headers
|
||||
|
||||
Raises:
|
||||
ValueError: If headers contain invalid names or values
|
||||
"""
|
||||
if not headers:
|
||||
return {}
|
||||
|
||||
sanitized_headers = {}
|
||||
|
||||
for name, value in headers.items():
|
||||
if not isinstance(name, str) or not isinstance(value, str):
|
||||
logger.warning(f"Skipping non-string header: {name}={value}")
|
||||
continue
|
||||
|
||||
# Validate header name according to RFC 7230
|
||||
if not HEADER_NAME_PATTERN.match(name):
|
||||
logger.warning(f"Invalid header name '{name}', skipping")
|
||||
continue
|
||||
|
||||
# Normalize header name to lowercase (HTTP headers are case-insensitive)
|
||||
normalized_name = name.lower()
|
||||
|
||||
# Optional: Check against whitelist of allowed headers
|
||||
if normalized_name not in ALLOWED_HEADERS:
|
||||
# For MCP, we'll be permissive and allow non-standard headers
|
||||
# but log a warning for security awareness
|
||||
logger.debug(f"Using non-standard header: {normalized_name}")
|
||||
|
||||
# Check for potential header injection attempts BEFORE sanitizing
|
||||
if "\r" in value or "\n" in value:
|
||||
logger.warning(f"Potential header injection detected in '{name}', skipping")
|
||||
continue
|
||||
|
||||
# Sanitize header value - remove control characters and newlines
|
||||
# RFC 7230: field-value = *( field-content / obs-fold )
|
||||
# We'll remove control characters (0x00-0x1F, 0x7F) except tab (0x09) and space (0x20)
|
||||
sanitized_value = re.sub(r"[\x00-\x08\x0A-\x1F\x7F]", "", value)
|
||||
|
||||
# Remove leading/trailing whitespace
|
||||
sanitized_value = sanitized_value.strip()
|
||||
|
||||
if not sanitized_value:
|
||||
logger.warning(f"Header '{name}' has empty value after sanitization, skipping")
|
||||
continue
|
||||
|
||||
sanitized_headers[normalized_name] = sanitized_value
|
||||
|
||||
return sanitized_headers
|
||||
|
||||
|
||||
def sanitize_mcp_name(name: str, max_length: int = 46) -> str:
|
||||
"""Sanitize a name for MCP usage by removing emojis, diacritics, and special characters.
|
||||
|
|
@ -334,12 +414,12 @@ def _process_headers(headers: Any) -> dict:
|
|||
Args:
|
||||
headers: The headers to process, can be dict, str, or list
|
||||
Returns:
|
||||
Processed dictionary
|
||||
Processed and validated dictionary
|
||||
"""
|
||||
if headers is None:
|
||||
return {}
|
||||
if isinstance(headers, dict):
|
||||
return headers
|
||||
return validate_headers(headers)
|
||||
if isinstance(headers, list):
|
||||
processed_headers = {}
|
||||
try:
|
||||
|
|
@ -351,7 +431,7 @@ def _process_headers(headers: Any) -> dict:
|
|||
processed_headers[key] = value
|
||||
except (KeyError, TypeError, ValueError):
|
||||
return {} # Return empty dictionary instead of None
|
||||
return processed_headers
|
||||
return validate_headers(processed_headers)
|
||||
return {}
|
||||
|
||||
|
||||
|
|
@ -924,7 +1004,7 @@ class MCPSseClient:
|
|||
self._component_cache.set("mcp_session_manager", session_manager)
|
||||
return session_manager
|
||||
|
||||
async def validate_url(self, url: str | None) -> tuple[bool, str]:
|
||||
async def validate_url(self, url: str | None, headers: dict[str, str] | None = None) -> tuple[bool, str]:
|
||||
"""Validate the SSE URL before attempting connection."""
|
||||
try:
|
||||
parsed = urlparse(url)
|
||||
|
|
@ -935,7 +1015,9 @@ class MCPSseClient:
|
|||
try:
|
||||
# For SSE endpoints, try a GET request with short timeout
|
||||
# Many SSE servers don't support HEAD requests and return 404
|
||||
response = await client.get(url, timeout=2.0, headers={"Accept": "text/event-stream"})
|
||||
response = await client.get(
|
||||
url, timeout=2.0, headers={"Accept": "text/event-stream", **(headers or {})}
|
||||
)
|
||||
|
||||
# For SSE, we expect the server to either:
|
||||
# 1. Start streaming (200)
|
||||
|
|
@ -967,14 +1049,16 @@ class MCPSseClient:
|
|||
except (httpx.HTTPError, ValueError, OSError) as e:
|
||||
return False, f"URL validation error: {e!s}"
|
||||
|
||||
async def pre_check_redirect(self, url: str | None) -> str | None:
|
||||
async def pre_check_redirect(self, url: str | None, headers: dict[str, str] | None = None) -> str | None:
|
||||
"""Check for redirects and return the final URL."""
|
||||
if url is None:
|
||||
return url
|
||||
try:
|
||||
async with httpx.AsyncClient(follow_redirects=False) as client:
|
||||
# Use GET with SSE headers instead of HEAD since many SSE servers don't support HEAD
|
||||
response = await client.get(url, timeout=2.0, headers={"Accept": "text/event-stream"})
|
||||
response = await client.get(
|
||||
url, timeout=2.0, headers={"Accept": "text/event-stream", **(headers or {})}
|
||||
)
|
||||
if response.status_code == httpx.codes.TEMPORARY_REDIRECT:
|
||||
return response.headers.get("Location", url)
|
||||
# Don't treat 404 as an error here - let the main connection handle it
|
||||
|
|
@ -990,22 +1074,23 @@ class MCPSseClient:
|
|||
sse_read_timeout_seconds: int = 30,
|
||||
) -> list[StructuredTool]:
|
||||
"""Connect to MCP server using SSE transport (SDK style)."""
|
||||
if headers is None:
|
||||
headers = {}
|
||||
# Validate and sanitize headers early
|
||||
validated_headers = _process_headers(headers)
|
||||
|
||||
if url is None:
|
||||
msg = "URL is required for SSE mode"
|
||||
raise ValueError(msg)
|
||||
is_valid, error_msg = await self.validate_url(url)
|
||||
is_valid, error_msg = await self.validate_url(url, validated_headers)
|
||||
if not is_valid:
|
||||
msg = f"Invalid SSE URL ({url}): {error_msg}"
|
||||
raise ValueError(msg)
|
||||
|
||||
url = await self.pre_check_redirect(url)
|
||||
url = await self.pre_check_redirect(url, validated_headers)
|
||||
|
||||
# Store connection parameters for later use in run_tool
|
||||
self._connection_params = {
|
||||
"url": url,
|
||||
"headers": headers,
|
||||
"headers": validated_headers,
|
||||
"timeout_seconds": timeout_seconds,
|
||||
"sse_read_timeout_seconds": sse_read_timeout_seconds,
|
||||
}
|
||||
|
|
|
|||
|
|
@ -1,14 +1,12 @@
|
|||
import asyncio
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
|
||||
import pytest
|
||||
from langflow.components.agents.mcp_component import MCPSseClient, MCPStdioClient, MCPToolsComponent
|
||||
from langflow.base.mcp import util
|
||||
from langflow.base.mcp.util import MCPSessionManager, MCPSseClient, MCPStdioClient, _process_headers, validate_headers
|
||||
from langflow.components.agents.mcp_component import MCPToolsComponent
|
||||
|
||||
from tests.base import ComponentTestBaseWithoutClient, VersionComponentMapping
|
||||
|
||||
# TODO: This test suite is incomplete and is in need of an update to handle the latest MCP component changes.
|
||||
pytestmark = pytest.mark.skip(reason="Skipping entire file")
|
||||
|
||||
|
||||
class TestMCPToolsComponent(ComponentTestBaseWithoutClient):
|
||||
@pytest.fixture
|
||||
|
|
@ -20,9 +18,7 @@ class TestMCPToolsComponent(ComponentTestBaseWithoutClient):
|
|||
def default_kwargs(self):
|
||||
"""Return the default kwargs for the component."""
|
||||
return {
|
||||
"mode": "Stdio",
|
||||
"command": "uvx mcp-server-fetch",
|
||||
"sse_url": "http://localhost:7860/api/v1/mcp/sse",
|
||||
"mcp_server": {"name": "test_server", "config": {"command": "uvx mcp-server-fetch"}},
|
||||
"tool": "",
|
||||
}
|
||||
|
||||
|
|
@ -40,24 +36,22 @@ class TestMCPToolsComponent(ComponentTestBaseWithoutClient):
|
|||
tool.inputSchema = {
|
||||
"type": "object",
|
||||
"properties": {"test_param": {"type": "string", "description": "Test parameter"}},
|
||||
"required": ["test_param"],
|
||||
}
|
||||
return tool
|
||||
|
||||
@pytest.fixture
|
||||
def mock_stdio_client(self, mock_tool):
|
||||
"""Create a mock stdio client."""
|
||||
stdio_client = AsyncMock()
|
||||
stdio_client.connect_to_server = AsyncMock(return_value=[mock_tool])
|
||||
stdio_client.session = AsyncMock()
|
||||
return stdio_client
|
||||
|
||||
@pytest.fixture
|
||||
def mock_sse_client(self, mock_tool):
|
||||
"""Create a mock SSE client."""
|
||||
sse_client = AsyncMock()
|
||||
sse_client.connect_to_server = AsyncMock(return_value=[mock_tool])
|
||||
sse_client.session = AsyncMock()
|
||||
return sse_client
|
||||
def mock_session(self, mock_tool):
|
||||
"""Create a mock ClientSession."""
|
||||
session = AsyncMock()
|
||||
session.initialize = AsyncMock()
|
||||
list_tools_result = MagicMock()
|
||||
list_tools_result.tools = [mock_tool]
|
||||
session.list_tools = AsyncMock(return_value=list_tools_result)
|
||||
session.call_tool = AsyncMock(
|
||||
return_value=MagicMock(content=[MagicMock(model_dump=lambda: {"result": "success"})])
|
||||
)
|
||||
return session
|
||||
|
||||
|
||||
class TestMCPStdioClient:
|
||||
|
|
@ -65,40 +59,69 @@ class TestMCPStdioClient:
|
|||
def stdio_client(self):
|
||||
return MCPStdioClient()
|
||||
|
||||
async def test_connect_to_server(self, stdio_client):
|
||||
"""Test connecting to server via Stdio."""
|
||||
# Create mock for stdio transport
|
||||
mock_stdio = AsyncMock()
|
||||
mock_write = AsyncMock()
|
||||
mock_stdio_transport = (mock_stdio, mock_write)
|
||||
mock_stdio_cm = AsyncMock()
|
||||
mock_stdio_cm.__aenter__.return_value = mock_stdio_transport
|
||||
@pytest.fixture
|
||||
def mock_session_manager(self):
|
||||
"""Create a mock session manager."""
|
||||
return AsyncMock(spec=MCPSessionManager)
|
||||
|
||||
# Mock the stdio_client function to return our mock context manager
|
||||
with patch("mcp.client.stdio.stdio_client", return_value=mock_stdio_cm):
|
||||
# Mock ClientSession
|
||||
async def test_connect_to_server_with_command(self, stdio_client):
|
||||
"""Test connecting to server via Stdio with command."""
|
||||
with patch.object(stdio_client, "_get_or_create_session") as mock_get_session:
|
||||
# Mock session
|
||||
mock_session = AsyncMock()
|
||||
mock_session.initialize = AsyncMock()
|
||||
mock_session.list_tools.return_value.tools = [MagicMock()]
|
||||
mock_tool = MagicMock()
|
||||
mock_tool.name = "test_tool"
|
||||
list_tools_result = MagicMock()
|
||||
list_tools_result.tools = [mock_tool]
|
||||
mock_session.list_tools = AsyncMock(return_value=list_tools_result)
|
||||
mock_get_session.return_value = mock_session
|
||||
|
||||
# Mock the AsyncExitStack
|
||||
mock_exit_stack = AsyncMock()
|
||||
mock_exit_stack.enter_async_context = AsyncMock()
|
||||
mock_exit_stack.enter_async_context.side_effect = [
|
||||
mock_stdio_transport, # For stdio_client
|
||||
mock_session, # For ClientSession
|
||||
]
|
||||
stdio_client.exit_stack = mock_exit_stack
|
||||
|
||||
tools = await stdio_client.connect_to_server("test_command")
|
||||
tools = await stdio_client.connect_to_server("uvx test-command")
|
||||
|
||||
assert len(tools) == 1
|
||||
assert stdio_client.session is not None
|
||||
# Verify the exit stack was used correctly
|
||||
assert mock_exit_stack.enter_async_context.call_count == 2
|
||||
# Verify the stdio transport was properly set
|
||||
assert stdio_client.stdio == mock_stdio
|
||||
assert stdio_client.write == mock_write
|
||||
assert tools[0].name == "test_tool"
|
||||
assert stdio_client._connected is True
|
||||
assert stdio_client._connection_params is not None
|
||||
|
||||
async def test_run_tool_success(self, stdio_client):
|
||||
"""Test successfully running a tool."""
|
||||
# Setup connection state
|
||||
stdio_client._connected = True
|
||||
stdio_client._connection_params = MagicMock()
|
||||
stdio_client._session_context = "test_context"
|
||||
|
||||
with patch.object(stdio_client, "_get_or_create_session") as mock_get_session:
|
||||
mock_session = AsyncMock()
|
||||
mock_result = MagicMock()
|
||||
mock_session.call_tool = AsyncMock(return_value=mock_result)
|
||||
mock_get_session.return_value = mock_session
|
||||
|
||||
result = await stdio_client.run_tool("test_tool", {"param": "value"})
|
||||
|
||||
assert result == mock_result
|
||||
mock_session.call_tool.assert_called_once_with("test_tool", arguments={"param": "value"})
|
||||
|
||||
async def test_run_tool_without_connection(self, stdio_client):
|
||||
"""Test running a tool without being connected."""
|
||||
stdio_client._connected = False
|
||||
|
||||
with pytest.raises(ValueError, match="Session not initialized"):
|
||||
await stdio_client.run_tool("test_tool", {})
|
||||
|
||||
async def test_disconnect_cleanup(self, stdio_client):
|
||||
"""Test that disconnect properly cleans up resources."""
|
||||
stdio_client._session_context = "test_context"
|
||||
stdio_client._connected = True
|
||||
|
||||
with patch.object(stdio_client, "_get_session_manager") as mock_get_manager:
|
||||
mock_manager = AsyncMock()
|
||||
mock_get_manager.return_value = mock_manager
|
||||
|
||||
await stdio_client.disconnect()
|
||||
|
||||
mock_manager._cleanup_session.assert_called_once_with("test_context")
|
||||
assert stdio_client.session is None
|
||||
assert stdio_client._connected is False
|
||||
|
||||
|
||||
class TestMCPSseClient:
|
||||
|
|
@ -106,78 +129,487 @@ class TestMCPSseClient:
|
|||
def sse_client(self):
|
||||
return MCPSseClient()
|
||||
|
||||
async def test_pre_check_redirect(self, sse_client):
|
||||
"""Test pre-checking URL for redirects."""
|
||||
async def test_validate_url_valid(self, sse_client):
|
||||
"""Test URL validation with valid URL."""
|
||||
with patch("httpx.AsyncClient") as mock_client:
|
||||
mock_response = MagicMock()
|
||||
mock_response.status_code = 200
|
||||
mock_client.return_value.__aenter__.return_value.get.return_value = mock_response
|
||||
|
||||
is_valid, error_msg = await sse_client.validate_url("http://test.url", {})
|
||||
|
||||
assert is_valid is True
|
||||
assert error_msg == ""
|
||||
|
||||
async def test_validate_url_invalid_format(self, sse_client):
|
||||
"""Test URL validation with invalid format."""
|
||||
is_valid, error_msg = await sse_client.validate_url("invalid-url", {})
|
||||
|
||||
assert is_valid is False
|
||||
assert "Invalid URL format" in error_msg
|
||||
|
||||
async def test_validate_url_with_404_response(self, sse_client):
|
||||
"""Test URL validation with 404 response (should be valid for SSE)."""
|
||||
with patch("httpx.AsyncClient") as mock_client:
|
||||
mock_response = MagicMock()
|
||||
mock_response.status_code = 404
|
||||
mock_client.return_value.__aenter__.return_value.get.return_value = mock_response
|
||||
|
||||
is_valid, error_msg = await sse_client.validate_url("http://test.url", {})
|
||||
|
||||
assert is_valid is True
|
||||
assert error_msg == ""
|
||||
|
||||
async def test_connect_to_server_with_headers(self, sse_client):
|
||||
"""Test connecting to server via SSE with custom headers."""
|
||||
test_url = "http://test.url"
|
||||
test_headers = {"Authorization": "Bearer token123", "Custom-Header": "value"}
|
||||
expected_headers = {"authorization": "Bearer token123", "custom-header": "value"} # normalized
|
||||
|
||||
with (
|
||||
patch.object(sse_client, "validate_url", return_value=(True, "")),
|
||||
patch.object(sse_client, "pre_check_redirect", return_value=test_url),
|
||||
patch.object(sse_client, "_get_or_create_session") as mock_get_session,
|
||||
):
|
||||
# Mock session
|
||||
mock_session = AsyncMock()
|
||||
mock_tool = MagicMock()
|
||||
mock_tool.name = "test_tool"
|
||||
list_tools_result = MagicMock()
|
||||
list_tools_result.tools = [mock_tool]
|
||||
mock_session.list_tools = AsyncMock(return_value=list_tools_result)
|
||||
mock_get_session.return_value = mock_session
|
||||
|
||||
tools = await sse_client.connect_to_server(test_url, test_headers)
|
||||
|
||||
assert len(tools) == 1
|
||||
assert tools[0].name == "test_tool"
|
||||
assert sse_client._connected is True
|
||||
|
||||
# Verify headers are stored in connection params (normalized)
|
||||
assert sse_client._connection_params is not None
|
||||
assert sse_client._connection_params["headers"] == expected_headers
|
||||
assert sse_client._connection_params["url"] == test_url
|
||||
|
||||
async def test_headers_passed_to_session_manager(self, sse_client):
|
||||
"""Test that headers are properly passed to the session manager."""
|
||||
test_url = "http://test.url"
|
||||
expected_headers = {"authorization": "Bearer token123", "x-api-key": "secret"} # normalized
|
||||
|
||||
sse_client._session_context = "test_context"
|
||||
sse_client._connection_params = {
|
||||
"url": test_url,
|
||||
"headers": expected_headers, # Use normalized headers
|
||||
"timeout_seconds": 30,
|
||||
"sse_read_timeout_seconds": 30,
|
||||
}
|
||||
|
||||
with patch.object(sse_client, "_get_session_manager") as mock_get_manager:
|
||||
mock_manager = AsyncMock()
|
||||
mock_session = AsyncMock()
|
||||
mock_manager.get_session = AsyncMock(return_value=mock_session)
|
||||
mock_get_manager.return_value = mock_manager
|
||||
|
||||
result_session = await sse_client._get_or_create_session()
|
||||
|
||||
# Verify session manager was called with correct parameters including normalized headers
|
||||
mock_manager.get_session.assert_called_once_with("test_context", sse_client._connection_params, "sse")
|
||||
assert result_session == mock_session
|
||||
|
||||
async def test_pre_check_redirect_with_headers(self, sse_client):
|
||||
"""Test pre-check redirect functionality with custom headers."""
|
||||
test_url = "http://test.url"
|
||||
redirect_url = "http://redirect.url"
|
||||
# Use pre-validated headers since pre_check_redirect expects already validated headers
|
||||
test_headers = {"authorization": "Bearer token123"} # already normalized
|
||||
|
||||
with patch("httpx.AsyncClient") as mock_client:
|
||||
mock_response = MagicMock()
|
||||
mock_response.status_code = 307
|
||||
mock_response.headers.get.return_value = redirect_url
|
||||
mock_client.return_value.__aenter__.return_value.request.return_value = mock_response
|
||||
mock_client.return_value.__aenter__.return_value.get.return_value = mock_response
|
||||
|
||||
result = await sse_client.pre_check_redirect(test_url, test_headers)
|
||||
|
||||
result = await sse_client.pre_check_redirect(test_url)
|
||||
assert result == redirect_url
|
||||
# Verify validated headers were passed to the request
|
||||
mock_client.return_value.__aenter__.return_value.get.assert_called_with(
|
||||
test_url, timeout=2.0, headers={"Accept": "text/event-stream", **test_headers}
|
||||
)
|
||||
|
||||
async def test_run_tool_with_retry_on_connection_error(self, sse_client):
|
||||
"""Test that run_tool retries on connection errors."""
|
||||
# Setup connection state
|
||||
sse_client._connected = True
|
||||
sse_client._connection_params = {"url": "http://test.url", "headers": {}}
|
||||
sse_client._session_context = "test_context"
|
||||
|
||||
call_count = 0
|
||||
|
||||
async def mock_get_session_side_effect():
|
||||
nonlocal call_count
|
||||
call_count += 1
|
||||
session = AsyncMock()
|
||||
if call_count == 1:
|
||||
# First call fails with connection error
|
||||
from anyio import ClosedResourceError
|
||||
|
||||
session.call_tool = AsyncMock(side_effect=ClosedResourceError())
|
||||
else:
|
||||
# Second call succeeds
|
||||
mock_result = MagicMock()
|
||||
session.call_tool = AsyncMock(return_value=mock_result)
|
||||
return session
|
||||
|
||||
with (
|
||||
patch.object(sse_client, "_get_or_create_session", side_effect=mock_get_session_side_effect),
|
||||
patch.object(sse_client, "_get_session_manager") as mock_get_manager,
|
||||
):
|
||||
mock_manager = AsyncMock()
|
||||
mock_get_manager.return_value = mock_manager
|
||||
|
||||
result = await sse_client.run_tool("test_tool", {"param": "value"})
|
||||
|
||||
# Should have retried and succeeded on second attempt
|
||||
assert call_count == 2
|
||||
assert result is not None
|
||||
# Should have cleaned up the failed session
|
||||
mock_manager._cleanup_session.assert_called_once_with("test_context")
|
||||
|
||||
|
||||
class TestMCPSessionManager:
|
||||
@pytest.fixture
|
||||
def session_manager(self):
|
||||
return MCPSessionManager()
|
||||
|
||||
async def test_session_caching(self, session_manager):
|
||||
"""Test that sessions are properly cached and reused."""
|
||||
context_id = "test_context"
|
||||
connection_params = MagicMock()
|
||||
transport_type = "stdio"
|
||||
|
||||
# Create a mock session that will appear healthy
|
||||
mock_session = AsyncMock()
|
||||
mock_session._write_stream = MagicMock()
|
||||
mock_session._write_stream._closed = False
|
||||
|
||||
# Create a mock task that appears to be running
|
||||
mock_task = AsyncMock()
|
||||
mock_task.done = MagicMock(return_value=False)
|
||||
|
||||
with (
|
||||
patch.object(session_manager, "_create_stdio_session") as mock_create,
|
||||
patch.object(session_manager, "_validate_session_connectivity", return_value=True),
|
||||
):
|
||||
mock_create.return_value = mock_session
|
||||
|
||||
# First call should create session
|
||||
session1 = await session_manager.get_session(context_id, connection_params, transport_type)
|
||||
|
||||
# Manually populate the sessions cache as if the session was created properly
|
||||
session_manager.sessions[context_id] = {"session": mock_session, "task": mock_task, "type": transport_type}
|
||||
|
||||
# Second call should return cached session without creating new one
|
||||
session2 = await session_manager.get_session(context_id, connection_params, transport_type)
|
||||
|
||||
assert session1 == session2
|
||||
assert session1 == mock_session
|
||||
# Should only create once since the second call should use the cached session
|
||||
mock_create.assert_called_once()
|
||||
|
||||
async def test_session_cleanup(self, session_manager):
|
||||
"""Test session cleanup functionality."""
|
||||
context_id = "test_context"
|
||||
|
||||
# Add a session to the manager with proper mock setup
|
||||
mock_task = AsyncMock()
|
||||
mock_task.done = MagicMock(return_value=False) # Use MagicMock for sync method
|
||||
mock_task.cancel = MagicMock() # Use MagicMock for sync method
|
||||
|
||||
session_manager.sessions[context_id] = {"session": AsyncMock(), "task": mock_task, "type": "stdio"}
|
||||
|
||||
await session_manager._cleanup_session(context_id)
|
||||
|
||||
# Should cancel the task and remove from sessions
|
||||
mock_task.cancel.assert_called_once()
|
||||
assert context_id not in session_manager.sessions
|
||||
|
||||
async def test_server_switch_detection(self, session_manager):
|
||||
"""Test that server switches are properly detected and handled."""
|
||||
context_id = "test_context"
|
||||
|
||||
# First server
|
||||
server1_params = MagicMock()
|
||||
server1_params.command = "server1"
|
||||
|
||||
# Second server
|
||||
server2_params = MagicMock()
|
||||
server2_params.command = "server2"
|
||||
|
||||
with (
|
||||
patch.object(session_manager, "_create_stdio_session") as mock_create,
|
||||
patch.object(session_manager, "_validate_session_connectivity", return_value=True),
|
||||
):
|
||||
mock_session1 = AsyncMock()
|
||||
mock_session2 = AsyncMock()
|
||||
mock_create.side_effect = [mock_session1, mock_session2]
|
||||
|
||||
# First connection
|
||||
session1 = await session_manager.get_session(context_id, server1_params, "stdio")
|
||||
|
||||
# Switch to different server should create new session
|
||||
session2 = await session_manager.get_session(context_id, server2_params, "stdio")
|
||||
|
||||
assert session1 != session2
|
||||
assert mock_create.call_count == 2
|
||||
|
||||
|
||||
# Integration test for header functionality
|
||||
class TestHeaderValidation:
|
||||
"""Test the header validation functionality."""
|
||||
|
||||
def test_validate_headers_valid_input(self):
|
||||
"""Test header validation with valid headers."""
|
||||
headers = {"Authorization": "Bearer token123", "Content-Type": "application/json", "X-API-Key": "secret-key"}
|
||||
|
||||
result = validate_headers(headers)
|
||||
|
||||
# Headers should be normalized to lowercase
|
||||
expected = {"authorization": "Bearer token123", "content-type": "application/json", "x-api-key": "secret-key"}
|
||||
assert result == expected
|
||||
|
||||
def test_validate_headers_empty_input(self):
|
||||
"""Test header validation with empty/None input."""
|
||||
assert validate_headers({}) == {}
|
||||
assert validate_headers(None) == {}
|
||||
|
||||
def test_validate_headers_invalid_names(self):
|
||||
"""Test header validation with invalid header names."""
|
||||
headers = {
|
||||
"Invalid Header": "value", # spaces not allowed
|
||||
"Header@Name": "value", # @ not allowed
|
||||
"Header Name": "value", # spaces not allowed
|
||||
"Valid-Header": "value", # this should pass
|
||||
}
|
||||
|
||||
result = validate_headers(headers)
|
||||
|
||||
# Only the valid header should remain
|
||||
assert result == {"valid-header": "value"}
|
||||
|
||||
def test_validate_headers_sanitize_values(self):
|
||||
"""Test header value sanitization."""
|
||||
headers = {
|
||||
"Authorization": "Bearer \x00token\x1f with\r\ninjection",
|
||||
"Clean-Header": " clean value ",
|
||||
"Empty-After-Clean": "\x00\x01\x02",
|
||||
"Tab-Header": "value\twith\ttabs", # tabs should be preserved
|
||||
}
|
||||
|
||||
result = validate_headers(headers)
|
||||
|
||||
# Control characters should be removed, whitespace trimmed
|
||||
# Header with injection attempts should be skipped
|
||||
expected = {"clean-header": "clean value", "tab-header": "value\twith\ttabs"}
|
||||
assert result == expected
|
||||
|
||||
def test_validate_headers_non_string_values(self):
|
||||
"""Test header validation with non-string values."""
|
||||
headers = {"String-Header": "valid", "Number-Header": 123, "None-Header": None, "List-Header": ["value"]}
|
||||
|
||||
result = validate_headers(headers)
|
||||
|
||||
# Only string headers should remain
|
||||
assert result == {"string-header": "valid"}
|
||||
|
||||
def test_validate_headers_injection_attempts(self):
|
||||
"""Test header validation against injection attempts."""
|
||||
headers = {
|
||||
"Injection1": "value\r\nInjected-Header: malicious",
|
||||
"Injection2": "value\nX-Evil: attack",
|
||||
"Safe-Header": "safe-value",
|
||||
}
|
||||
|
||||
result = validate_headers(headers)
|
||||
|
||||
# Injection attempts should be filtered out
|
||||
assert result == {"safe-header": "safe-value"}
|
||||
|
||||
|
||||
class TestSSEHeaderIntegration:
|
||||
"""Integration test to verify headers are properly passed through the entire SSE flow."""
|
||||
|
||||
async def test_headers_processing(self):
|
||||
"""Test that headers flow properly from server config through to SSE client connection."""
|
||||
# Test the header processing function directly
|
||||
headers_input = [
|
||||
{"key": "Authorization", "value": "Bearer test-token"},
|
||||
{"key": "X-API-Key", "value": "secret-key"},
|
||||
]
|
||||
|
||||
expected_headers = {
|
||||
"authorization": "Bearer test-token", # normalized to lowercase
|
||||
"x-api-key": "secret-key",
|
||||
}
|
||||
|
||||
# Test _process_headers function with validation
|
||||
processed_headers = _process_headers(headers_input)
|
||||
assert processed_headers == expected_headers
|
||||
|
||||
# Test different input formats
|
||||
# Test dict input with validation
|
||||
dict_headers = {"Authorization": "Bearer dict-token", "Invalid Header": "bad"}
|
||||
result = _process_headers(dict_headers)
|
||||
# Invalid header should be filtered out, valid header normalized
|
||||
assert result == {"authorization": "Bearer dict-token"}
|
||||
|
||||
# Test None input
|
||||
assert _process_headers(None) == {}
|
||||
|
||||
# Test empty list
|
||||
assert _process_headers([]) == {}
|
||||
|
||||
# Test malformed list
|
||||
malformed_headers = [{"key": "Auth"}, {"value": "token"}] # Missing value/key
|
||||
assert _process_headers(malformed_headers) == {}
|
||||
|
||||
# Test list with invalid header names
|
||||
invalid_headers = [
|
||||
{"key": "Valid-Header", "value": "good"},
|
||||
{"key": "Invalid Header", "value": "bad"}, # spaces not allowed
|
||||
]
|
||||
result = _process_headers(invalid_headers)
|
||||
assert result == {"valid-header": "good"}
|
||||
|
||||
async def test_sse_client_header_storage(self):
|
||||
"""Test that SSE client properly stores headers in connection params."""
|
||||
sse_client = MCPSseClient()
|
||||
test_url = "http://test.url"
|
||||
test_headers = {"Authorization": "Bearer test123", "Custom": "value"}
|
||||
expected_headers = {"authorization": "Bearer test123", "custom": "value"} # normalized
|
||||
|
||||
async def test_connect_to_server(self, sse_client):
|
||||
"""Test connecting to server via SSE."""
|
||||
# Mock the pre_check_redirect first
|
||||
with (
|
||||
patch.object(sse_client, "pre_check_redirect", return_value="http://test.url"),
|
||||
patch.object(sse_client, "validate_url", return_value=(True, "")),
|
||||
patch.object(sse_client, "pre_check_redirect", return_value=test_url),
|
||||
patch.object(sse_client, "_get_or_create_session") as mock_get_session,
|
||||
):
|
||||
# Create mock for sse_client context manager
|
||||
mock_sse = AsyncMock()
|
||||
mock_write = AsyncMock()
|
||||
mock_sse_transport = (mock_sse, mock_write)
|
||||
mock_sse_cm = AsyncMock()
|
||||
mock_sse_cm.__aenter__.return_value = mock_sse_transport
|
||||
mock_session = AsyncMock()
|
||||
mock_tool = MagicMock()
|
||||
mock_tool.name = "test_tool"
|
||||
list_tools_result = MagicMock()
|
||||
list_tools_result.tools = [mock_tool]
|
||||
mock_session.list_tools = AsyncMock(return_value=list_tools_result)
|
||||
mock_get_session.return_value = mock_session
|
||||
|
||||
# Mock the sse_client function to return our mock context manager
|
||||
with patch("mcp.client.sse.sse_client", return_value=mock_sse_cm):
|
||||
# Mock ClientSession
|
||||
mock_session = AsyncMock()
|
||||
mock_session.initialize = AsyncMock()
|
||||
mock_session.list_tools.return_value.tools = [MagicMock()]
|
||||
await sse_client.connect_to_server(test_url, test_headers)
|
||||
|
||||
# Mock the AsyncExitStack
|
||||
mock_exit_stack = AsyncMock()
|
||||
mock_exit_stack.enter_async_context = AsyncMock()
|
||||
mock_exit_stack.enter_async_context.side_effect = [
|
||||
mock_sse_transport, # For sse_client
|
||||
mock_session, # For ClientSession
|
||||
]
|
||||
sse_client.exit_stack = mock_exit_stack
|
||||
# Verify headers are stored correctly in connection params (normalized)
|
||||
assert sse_client._connection_params is not None
|
||||
assert sse_client._connection_params["headers"] == expected_headers
|
||||
assert sse_client._connection_params["url"] == test_url
|
||||
|
||||
tools = await sse_client.connect_to_server("http://test.url", {})
|
||||
|
||||
assert len(tools) == 1
|
||||
assert sse_client.session is not None
|
||||
# Verify the exit stack was used correctly
|
||||
assert mock_exit_stack.enter_async_context.call_count == 2
|
||||
# Verify the SSE transport was properly set
|
||||
assert sse_client.sse == mock_sse
|
||||
assert sse_client.write == mock_write
|
||||
class TestMCPUtilityFunctions:
|
||||
"""Test utility functions from util.py that don't have dedicated test classes."""
|
||||
|
||||
async def test_connect_timeout(self, sse_client):
|
||||
"""Test connection timeout handling."""
|
||||
# Set max_retries to 1 to avoid multiple retry attempts
|
||||
sse_client.max_retries = 1
|
||||
def test_sanitize_mcp_name(self):
|
||||
"""Test MCP name sanitization."""
|
||||
assert util.sanitize_mcp_name("Test Name 123") == "test_name_123"
|
||||
assert util.sanitize_mcp_name(" ") == ""
|
||||
assert util.sanitize_mcp_name("123abc") == "_123abc"
|
||||
assert util.sanitize_mcp_name("Tést-😀-Námé") == "test_name"
|
||||
assert util.sanitize_mcp_name("a" * 100) == "a" * 46
|
||||
|
||||
with (
|
||||
patch.object(sse_client, "pre_check_redirect", return_value="http://test.url"),
|
||||
patch.object(sse_client, "validate_url", return_value=(True, "")), # Mock URL validation
|
||||
patch.object(sse_client, "_connect_with_timeout") as mock_connect,
|
||||
):
|
||||
mock_connect.side_effect = asyncio.TimeoutError()
|
||||
def test_get_unique_name(self):
|
||||
"""Test unique name generation."""
|
||||
names = {"foo", "foo_1"}
|
||||
assert util.get_unique_name("foo", 10, names) == "foo_2"
|
||||
assert util.get_unique_name("bar", 10, names) == "bar"
|
||||
assert util.get_unique_name("longname", 4, {"long"}) == "lo_1"
|
||||
|
||||
# Expect ConnectionError instead of TimeoutError
|
||||
with pytest.raises(
|
||||
ConnectionError,
|
||||
match=(
|
||||
"Failed to connect after 1 attempts. "
|
||||
"Last error: Connection to http://test.url timed out after 1 seconds"
|
||||
),
|
||||
):
|
||||
await sse_client.connect_to_server("http://test.url", {}, timeout_seconds=1)
|
||||
def test_is_valid_key_value_item(self):
|
||||
"""Test key-value item validation."""
|
||||
assert util._is_valid_key_value_item({"key": "a", "value": "b"}) is True
|
||||
assert util._is_valid_key_value_item({"key": "a"}) is False
|
||||
assert util._is_valid_key_value_item(["key", "value"]) is False
|
||||
assert util._is_valid_key_value_item(None) is False
|
||||
|
||||
def test_validate_node_installation(self):
|
||||
"""Test Node.js installation validation."""
|
||||
import shutil
|
||||
|
||||
if shutil.which("node"):
|
||||
assert util._validate_node_installation("npx something") == "npx something"
|
||||
else:
|
||||
with pytest.raises(ValueError, match="Node.js is not installed"):
|
||||
util._validate_node_installation("npx something")
|
||||
assert util._validate_node_installation("echo test") == "echo test"
|
||||
|
||||
def test_create_input_schema_from_json_schema(self):
|
||||
"""Test JSON schema to Pydantic model conversion."""
|
||||
schema = {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"foo": {"type": "string", "description": "desc"},
|
||||
"bar": {"type": "integer"},
|
||||
},
|
||||
"required": ["foo"],
|
||||
}
|
||||
model_class = util.create_input_schema_from_json_schema(schema)
|
||||
instance = model_class(foo="abc", bar=1)
|
||||
assert instance.foo == "abc"
|
||||
assert instance.bar == 1
|
||||
|
||||
with pytest.raises(Exception): # noqa: B017, PT011
|
||||
model_class(bar=1) # missing required field
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_validate_connection_params(self):
|
||||
"""Test connection parameter validation."""
|
||||
# Valid parameters
|
||||
await util._validate_connection_params("Stdio", command="echo test")
|
||||
await util._validate_connection_params("SSE", url="http://test")
|
||||
|
||||
# Invalid parameters
|
||||
with pytest.raises(ValueError, match="Command is required for Stdio mode"):
|
||||
await util._validate_connection_params("Stdio", command=None)
|
||||
with pytest.raises(ValueError, match="URL is required for SSE mode"):
|
||||
await util._validate_connection_params("SSE", url=None)
|
||||
with pytest.raises(ValueError, match="Invalid mode"):
|
||||
await util._validate_connection_params("InvalidMode")
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_flow_snake_case_mocked(self):
|
||||
"""Test flow lookup by snake case name with mocked session."""
|
||||
|
||||
class DummyFlow:
|
||||
def __init__(self, name: str, user_id: str, *, is_component: bool = False, action_name: str | None = None):
|
||||
self.name = name
|
||||
self.user_id = user_id
|
||||
self.is_component = is_component
|
||||
self.action_name = action_name
|
||||
|
||||
class DummyExec:
|
||||
def __init__(self, flows: list[DummyFlow]):
|
||||
self._flows = flows
|
||||
|
||||
def all(self):
|
||||
return self._flows
|
||||
|
||||
class DummySession:
|
||||
def __init__(self, flows: list[DummyFlow]):
|
||||
self._flows = flows
|
||||
|
||||
async def exec(self, stmt): # noqa: ARG002
|
||||
return DummyExec(self._flows)
|
||||
|
||||
user_id = "123e4567-e89b-12d3-a456-426614174000"
|
||||
flows = [DummyFlow("Test Flow", user_id), DummyFlow("Other", user_id)]
|
||||
|
||||
# Should match sanitized name
|
||||
result = await util.get_flow_snake_case(util.sanitize_mcp_name("Test Flow"), user_id, DummySession(flows))
|
||||
assert result is flows[0]
|
||||
|
||||
# Should return None if not found
|
||||
result = await util.get_flow_snake_case("notfound", user_id, DummySession(flows))
|
||||
assert result is None
|
||||
|
|
|
|||
|
|
@ -34,6 +34,9 @@ export const useAddMCPServer: useMutationFunctionType<
|
|||
if (body.env && Object.keys(body.env).length > 0) {
|
||||
payload.env = body.env;
|
||||
}
|
||||
if (body.headers && Object.keys(body.headers).length > 0) {
|
||||
payload.headers = body.headers;
|
||||
}
|
||||
|
||||
const res = await api.post(
|
||||
`${getURL("MCP_SERVERS", undefined, true)}/${body.name}`,
|
||||
|
|
|
|||
|
|
@ -34,6 +34,9 @@ export const usePatchMCPServer: useMutationFunctionType<
|
|||
if (body.env && Object.keys(body.env).length > 0) {
|
||||
payload.env = body.env;
|
||||
}
|
||||
if (body.headers && Object.keys(body.headers).length > 0) {
|
||||
payload.headers = body.headers;
|
||||
}
|
||||
|
||||
const res = await api.patch(
|
||||
`${getURL("MCP_SERVERS", undefined, true)}/${body.name}`,
|
||||
|
|
|
|||
|
|
@ -1,5 +1,5 @@
|
|||
import _ from "lodash";
|
||||
import { useEffect, useState } from "react";
|
||||
import { useCallback, useEffect, useState } from "react";
|
||||
import IconComponent from "../../../../../components/common/genericIconComponent";
|
||||
import { Input } from "../../../../../components/ui/input";
|
||||
import { classNames } from "../../../../../utils/utils";
|
||||
|
|
@ -10,6 +10,7 @@ export type IOKeyPairInputProps = {
|
|||
duplicateKey: boolean;
|
||||
isList: boolean;
|
||||
isInputField?: boolean;
|
||||
testId?: string;
|
||||
};
|
||||
|
||||
const IOKeyPairInput = ({
|
||||
|
|
@ -18,10 +19,11 @@ const IOKeyPairInput = ({
|
|||
duplicateKey,
|
||||
isList = true,
|
||||
isInputField,
|
||||
testId,
|
||||
}: IOKeyPairInputProps) => {
|
||||
const checkValueType = (value) => {
|
||||
const checkValueType = useCallback((value) => {
|
||||
return Array.isArray(value) ? value : [value];
|
||||
};
|
||||
}, []);
|
||||
|
||||
const [currentData, setCurrentData] = useState<any[]>(() => {
|
||||
return !value || value?.length === 0 ? [{ "": "" }] : checkValueType(value);
|
||||
|
|
@ -32,90 +34,105 @@ const IOKeyPairInput = ({
|
|||
const newData =
|
||||
!value || value?.length === 0 ? [{ "": "" }] : checkValueType(value);
|
||||
setCurrentData(newData);
|
||||
}, [value]);
|
||||
}, [value, checkValueType]);
|
||||
|
||||
const handleChangeKey = (event, idx) => {
|
||||
const oldKey = Object.keys(currentData[idx])[0];
|
||||
const updatedObj = { [event.target.value]: currentData[idx][oldKey] };
|
||||
const handleChangeKey = (event, objIndex) => {
|
||||
const oldKey = Object.keys(currentData[objIndex])[0];
|
||||
const updatedObj = { [event.target.value]: currentData[objIndex][oldKey] };
|
||||
const newData = [...currentData];
|
||||
newData[idx] = updatedObj;
|
||||
newData[objIndex] = updatedObj;
|
||||
setCurrentData(newData);
|
||||
onChange(newData);
|
||||
};
|
||||
|
||||
const handleChangeValue = (newValue, idx) => {
|
||||
const key = Object.keys(currentData[idx])[0];
|
||||
const handleChangeValue = (newValue, objIndex) => {
|
||||
const key = Object.keys(currentData[objIndex])[0];
|
||||
const newData = [...currentData];
|
||||
newData[idx] = { ...newData[idx], [key]: newValue };
|
||||
newData[objIndex] = { ...newData[objIndex], [key]: newValue };
|
||||
setCurrentData(newData);
|
||||
onChange(newData);
|
||||
};
|
||||
|
||||
// Create flat array with additional metadata for rendering
|
||||
const flattenedData =
|
||||
currentData?.flatMap((obj, objIndex) => {
|
||||
return Object.keys(obj).map((key) => ({
|
||||
key,
|
||||
value: obj[key],
|
||||
objIndex,
|
||||
uniqueId: `${objIndex}-${key}`, // Create unique identifier for React key
|
||||
}));
|
||||
}) || [];
|
||||
|
||||
return (
|
||||
<>
|
||||
<div className={classNames("flex h-full flex-col gap-3")}>
|
||||
{currentData?.map((obj, index) => {
|
||||
return Object.keys(obj).map((key, idx) => {
|
||||
return (
|
||||
<div key={idx} className="flex w-full gap-2">
|
||||
<Input
|
||||
type="text"
|
||||
value={key.trim()}
|
||||
className={classNames(duplicateKey ? "input-invalid" : "")}
|
||||
placeholder="Type key..."
|
||||
onChange={(event) => handleChangeKey(event, index)}
|
||||
disabled={!isInputField}
|
||||
/>
|
||||
<div className={classNames("flex h-full flex-col gap-3")}>
|
||||
{flattenedData.map((item, idx) => {
|
||||
return (
|
||||
<div key={item.uniqueId} className="flex w-full gap-2">
|
||||
<Input
|
||||
type="text"
|
||||
value={item.key.trim()}
|
||||
className={classNames(duplicateKey ? "input-invalid" : "")}
|
||||
placeholder="Type key..."
|
||||
onChange={(event) => handleChangeKey(event, item.objIndex)}
|
||||
disabled={!isInputField}
|
||||
data-testid={testId ? `${testId}-key-${idx}` : undefined}
|
||||
/>
|
||||
|
||||
<Input
|
||||
type="text"
|
||||
value={obj[key]}
|
||||
placeholder="Type a value..."
|
||||
onChange={(event) =>
|
||||
handleChangeValue(event.target.value, index)
|
||||
}
|
||||
disabled={!isInputField}
|
||||
/>
|
||||
<Input
|
||||
type="text"
|
||||
value={item.value}
|
||||
placeholder="Type a value..."
|
||||
onChange={(event) =>
|
||||
handleChangeValue(event.target.value, item.objIndex)
|
||||
}
|
||||
disabled={!isInputField}
|
||||
data-testid={testId ? `${testId}-value-${idx}` : undefined}
|
||||
/>
|
||||
|
||||
{isList && isInputField && index === currentData.length - 1 ? (
|
||||
<button
|
||||
type="button"
|
||||
onClick={() => {
|
||||
const newInputList = _.cloneDeep(currentData);
|
||||
newInputList.push({ "": "" });
|
||||
setCurrentData(newInputList);
|
||||
onChange(newInputList);
|
||||
}}
|
||||
>
|
||||
<IconComponent
|
||||
name="Plus"
|
||||
className={"h-4 w-4 hover:text-accent-foreground"}
|
||||
/>
|
||||
</button>
|
||||
) : isList && isInputField ? (
|
||||
<button
|
||||
type="button"
|
||||
onClick={() => {
|
||||
const newInputList = _.cloneDeep(currentData);
|
||||
newInputList.splice(index, 1);
|
||||
setCurrentData(newInputList);
|
||||
onChange(newInputList);
|
||||
}}
|
||||
>
|
||||
<IconComponent
|
||||
name="X"
|
||||
className="h-4 w-4 hover:text-status-red"
|
||||
/>
|
||||
</button>
|
||||
) : (
|
||||
""
|
||||
)}
|
||||
</div>
|
||||
);
|
||||
});
|
||||
})}
|
||||
</div>
|
||||
</>
|
||||
{isList &&
|
||||
isInputField &&
|
||||
item.objIndex === currentData.length - 1 ? (
|
||||
<button
|
||||
type="button"
|
||||
onClick={() => {
|
||||
const newInputList = _.cloneDeep(currentData);
|
||||
newInputList.push({ "": "" });
|
||||
setCurrentData(newInputList);
|
||||
onChange(newInputList);
|
||||
}}
|
||||
data-testid={testId ? `${testId}-plus-btn-0` : undefined}
|
||||
>
|
||||
<IconComponent
|
||||
name="Plus"
|
||||
className={"h-4 w-4 hover:text-accent-foreground"}
|
||||
/>
|
||||
</button>
|
||||
) : isList && isInputField ? (
|
||||
<button
|
||||
type="button"
|
||||
onClick={() => {
|
||||
const newInputList = _.cloneDeep(currentData);
|
||||
newInputList.splice(item.objIndex, 1);
|
||||
setCurrentData(newInputList);
|
||||
onChange(newInputList);
|
||||
}}
|
||||
data-testid={
|
||||
testId ? `${testId}-minus-btn-${item.objIndex}` : undefined
|
||||
}
|
||||
>
|
||||
<IconComponent
|
||||
name="X"
|
||||
className="h-4 w-4 hover:text-status-red"
|
||||
/>
|
||||
</button>
|
||||
) : (
|
||||
""
|
||||
)}
|
||||
</div>
|
||||
);
|
||||
})}
|
||||
</div>
|
||||
);
|
||||
};
|
||||
|
||||
|
|
|
|||
|
|
@ -358,6 +358,7 @@ export default function AddMcpServerModal({
|
|||
listAddLabel="Add Argument"
|
||||
editNode={false}
|
||||
id="stdio-args"
|
||||
data-testid="stdio-args-input"
|
||||
/>
|
||||
</div>
|
||||
<div className="flex flex-col gap-2">
|
||||
|
|
@ -368,6 +369,7 @@ export default function AddMcpServerModal({
|
|||
duplicateKey={false}
|
||||
isList={true}
|
||||
isInputField={true}
|
||||
testId="stdio-env"
|
||||
/>
|
||||
</div>
|
||||
</div>
|
||||
|
|
@ -382,6 +384,7 @@ export default function AddMcpServerModal({
|
|||
value={sseName}
|
||||
onChange={(e) => setSseName(e.target.value)}
|
||||
placeholder="Name"
|
||||
data-testid="sse-name-input"
|
||||
disabled={isPending}
|
||||
/>
|
||||
</div>
|
||||
|
|
@ -393,6 +396,7 @@ export default function AddMcpServerModal({
|
|||
value={sseUrl}
|
||||
onChange={(e) => setSseUrl(e.target.value)}
|
||||
placeholder="SSE URL"
|
||||
data-testid="sse-url-input"
|
||||
disabled={isPending}
|
||||
/>
|
||||
</div>
|
||||
|
|
@ -404,6 +408,7 @@ export default function AddMcpServerModal({
|
|||
duplicateKey={false}
|
||||
isList={true}
|
||||
isInputField={true}
|
||||
testId="sse-headers"
|
||||
/>
|
||||
</div>
|
||||
<div className="flex flex-col gap-2">
|
||||
|
|
@ -414,6 +419,7 @@ export default function AddMcpServerModal({
|
|||
duplicateKey={false}
|
||||
isList={true}
|
||||
isInputField={true}
|
||||
testId="sse-env"
|
||||
/>
|
||||
</div>
|
||||
</div>
|
||||
|
|
|
|||
|
|
@ -95,5 +95,9 @@ export function extractMcpServersFromJson(
|
|||
args: server.args || [],
|
||||
env: server.env && typeof server.env === "object" ? server.env : {},
|
||||
url: server.url,
|
||||
headers:
|
||||
server.headers && typeof server.headers === "object"
|
||||
? server.headers
|
||||
: {},
|
||||
}));
|
||||
}
|
||||
|
|
|
|||
|
|
@ -54,7 +54,9 @@ test(
|
|||
timeout: 30000,
|
||||
});
|
||||
|
||||
await page.getByTestId("stdio-name-input").fill("test server");
|
||||
const randomSuffix = Math.floor(Math.random() * 90000) + 10000; // 5-digit random number
|
||||
const testName = `test_server_${randomSuffix}`;
|
||||
await page.getByTestId("stdio-name-input").fill(testName);
|
||||
|
||||
await page.getByTestId("stdio-command-input").fill("uvx mcp-server-fetch");
|
||||
|
||||
|
|
@ -115,12 +117,12 @@ test(
|
|||
timeout: 3000,
|
||||
});
|
||||
|
||||
await expect(page.getByText("test_server")).toBeVisible({
|
||||
await expect(page.getByText(testName)).toBeVisible({
|
||||
timeout: 3000,
|
||||
});
|
||||
|
||||
await page
|
||||
.getByTestId(`mcp-server-menu-button-test_server`)
|
||||
.getByTestId(`mcp-server-menu-button-${testName}`)
|
||||
.click({ timeout: 3000 });
|
||||
|
||||
await page
|
||||
|
|
@ -152,7 +154,7 @@ test(
|
|||
await page.getByTestId("add-mcp-server-button").click();
|
||||
|
||||
await page
|
||||
.getByTestId(`mcp-server-menu-button-test_server`)
|
||||
.getByTestId(`mcp-server-menu-button-${testName}`)
|
||||
.click({ timeout: 3000 });
|
||||
|
||||
await page
|
||||
|
|
@ -177,7 +179,7 @@ test(
|
|||
|
||||
await page.waitForTimeout(3000);
|
||||
|
||||
await expect(page.getByText("test_server")).not.toBeVisible({
|
||||
await expect(page.getByText(testName)).not.toBeVisible({
|
||||
timeout: 3000,
|
||||
});
|
||||
|
||||
|
|
@ -199,8 +201,360 @@ test(
|
|||
});
|
||||
|
||||
await page.getByTestId("mcp-server-dropdown").click({ timeout: 10000 });
|
||||
await expect(page.getByText("test_server")).toHaveCount(2, {
|
||||
await expect(page.getByText(testName)).toHaveCount(2, {
|
||||
timeout: 10000,
|
||||
});
|
||||
},
|
||||
);
|
||||
|
||||
test(
|
||||
"STDIO MCP server fields should persist after saving and editing",
|
||||
{ tag: ["@release", "@workspace", "@components"] },
|
||||
async ({ page }) => {
|
||||
await awaitBootstrapTest(page);
|
||||
|
||||
await page.waitForSelector('[data-testid="blank-flow"]', {
|
||||
timeout: 30000,
|
||||
});
|
||||
await page.getByTestId("blank-flow").click();
|
||||
await page.getByTestId("sidebar-search-input").click();
|
||||
await page.getByTestId("sidebar-search-input").fill("mcp tools");
|
||||
|
||||
await page.waitForSelector('[data-testid="agentsMCP Tools"]', {
|
||||
timeout: 30000,
|
||||
});
|
||||
|
||||
await page
|
||||
.getByTestId("agentsMCP Tools")
|
||||
.dragTo(page.locator('//*[@id="react-flow-id"]'), {
|
||||
targetPosition: { x: 0, y: 0 },
|
||||
});
|
||||
|
||||
await page.getByTestId("fit_view").click();
|
||||
|
||||
await zoomOut(page, 3);
|
||||
|
||||
try {
|
||||
await page.getByText("Add MCP Server", { exact: true }).click({
|
||||
timeout: 5000,
|
||||
});
|
||||
} catch (_error) {
|
||||
await page.getByTestId("mcp-server-dropdown").click({ timeout: 3000 });
|
||||
await page.getByText("Add MCP Server", { exact: true }).click({
|
||||
timeout: 5000,
|
||||
});
|
||||
}
|
||||
|
||||
await page.waitForSelector('[data-testid="add-mcp-server-button"]', {
|
||||
state: "visible",
|
||||
timeout: 30000,
|
||||
});
|
||||
|
||||
// Go to STDIO tab and fill all fields
|
||||
await page.getByTestId("stdio-tab").click();
|
||||
await page.waitForSelector('[data-testid="stdio-name-input"]', {
|
||||
state: "visible",
|
||||
timeout: 30000,
|
||||
});
|
||||
|
||||
// Test data with random suffix
|
||||
const randomSuffix = Math.floor(Math.random() * 90000) + 10000; // 5-digit random number
|
||||
const testName = `test_stdio_server_${randomSuffix}`;
|
||||
const testCommand = "uvx mcp-server-test";
|
||||
const testArg1 = "--verbose";
|
||||
const testArg2 = "--port=8080";
|
||||
const testArg3 = "--config=test.json";
|
||||
const testEnvKey1 = "NODE_ENV";
|
||||
const testEnvValue1 = "production";
|
||||
const testEnvKey2 = "DEBUG_MODE";
|
||||
const testEnvValue2 = "true";
|
||||
|
||||
// Fill basic fields
|
||||
await page.getByTestId("stdio-name-input").fill(testName);
|
||||
await page.getByTestId("stdio-command-input").fill(testCommand);
|
||||
|
||||
// Add first argument
|
||||
await page.getByTestId("stdio-args_0").fill(testArg1);
|
||||
|
||||
// Add second argument by clicking plus button
|
||||
await page.getByTestId("input-list-plus-btn_-0").click();
|
||||
await page.getByTestId("stdio-args_1").fill(testArg2);
|
||||
|
||||
// Add third argument
|
||||
await page.getByTestId("input-list-plus-btn_-0").click();
|
||||
await page.getByTestId("stdio-args_2").fill(testArg3);
|
||||
|
||||
// Add first environment variable
|
||||
await page.getByTestId("stdio-env-key-0").fill(testEnvKey1);
|
||||
await page.getByTestId("stdio-env-value-0").fill(testEnvValue1);
|
||||
|
||||
// Add second environment variable
|
||||
await page.getByTestId("stdio-env-plus-btn-0").click();
|
||||
await page.getByTestId("stdio-env-key-1").fill(testEnvKey2);
|
||||
await page.getByTestId("stdio-env-value-1").fill(testEnvValue2);
|
||||
|
||||
// Save the server
|
||||
await page.getByTestId("add-mcp-server-button").click();
|
||||
|
||||
// Wait for server to be created
|
||||
await page.waitForTimeout(2000);
|
||||
|
||||
// Go to settings to edit the server
|
||||
await page.getByTestId("user_menu_button").click({ timeout: 3000 });
|
||||
await page.getByTestId("menu_settings_button").click({ timeout: 3000 });
|
||||
|
||||
await page.waitForSelector('[data-testid="sidebar-nav-MCP Servers"]', {
|
||||
timeout: 30000,
|
||||
});
|
||||
await page.getByTestId("sidebar-nav-MCP Servers").click({ timeout: 3000 });
|
||||
|
||||
await page.waitForSelector('[data-testid="add-mcp-server-button-page"]', {
|
||||
timeout: 3000,
|
||||
});
|
||||
|
||||
// Find and edit the server
|
||||
await expect(page.getByText(testName)).toBeVisible({
|
||||
timeout: 3000,
|
||||
});
|
||||
|
||||
await page
|
||||
.getByTestId(`mcp-server-menu-button-${testName}`)
|
||||
.click({ timeout: 3000 });
|
||||
|
||||
await page
|
||||
.getByText("Edit", { exact: true })
|
||||
.first()
|
||||
.click({ timeout: 3000 });
|
||||
|
||||
await page.waitForSelector('[data-testid="add-mcp-server-button"]', {
|
||||
state: "visible",
|
||||
timeout: 30000,
|
||||
});
|
||||
|
||||
// Verify all fields persisted correctly
|
||||
expect(await page.getByTestId("stdio-name-input").inputValue()).toBe(
|
||||
testName,
|
||||
);
|
||||
expect(await page.getByTestId("stdio-command-input").inputValue()).toBe(
|
||||
testCommand,
|
||||
);
|
||||
expect(await page.getByTestId("stdio-args_0").inputValue()).toBe(testArg1);
|
||||
expect(await page.getByTestId("stdio-args_1").inputValue()).toBe(testArg2);
|
||||
expect(await page.getByTestId("stdio-args_2").inputValue()).toBe(testArg3);
|
||||
expect(await page.getByTestId("stdio-env-key-0").last().inputValue()).toBe(
|
||||
testEnvKey1,
|
||||
);
|
||||
expect(
|
||||
await page.getByTestId("stdio-env-value-0").last().inputValue(),
|
||||
).toBe(testEnvValue1);
|
||||
expect(await page.getByTestId("stdio-env-key-1").last().inputValue()).toBe(
|
||||
testEnvKey2,
|
||||
);
|
||||
expect(
|
||||
await page.getByTestId("stdio-env-value-1").last().inputValue(),
|
||||
).toBe(testEnvValue2);
|
||||
|
||||
// Clean up - cancel the edit modal
|
||||
await page.keyboard.press("Escape");
|
||||
|
||||
// Delete the test server
|
||||
await page
|
||||
.getByTestId(`mcp-server-menu-button-${testName}`)
|
||||
.click({ timeout: 3000 });
|
||||
|
||||
await page
|
||||
.getByText("Delete", { exact: true })
|
||||
.first()
|
||||
.click({ timeout: 3000 });
|
||||
|
||||
await page.waitForSelector(
|
||||
'[data-testid="btn_delete_delete_confirmation_modal"]',
|
||||
{
|
||||
timeout: 3000,
|
||||
},
|
||||
);
|
||||
|
||||
await page
|
||||
.getByTestId("btn_delete_delete_confirmation_modal")
|
||||
.click({ timeout: 3000 });
|
||||
},
|
||||
);
|
||||
|
||||
test(
|
||||
"SSE MCP server fields should persist after saving and editing",
|
||||
{ tag: ["@release", "@workspace", "@components"] },
|
||||
async ({ page }) => {
|
||||
await awaitBootstrapTest(page);
|
||||
|
||||
await page.waitForSelector('[data-testid="blank-flow"]', {
|
||||
timeout: 30000,
|
||||
});
|
||||
await page.getByTestId("blank-flow").click();
|
||||
await page.getByTestId("sidebar-search-input").click();
|
||||
await page.getByTestId("sidebar-search-input").fill("mcp tools");
|
||||
|
||||
await page.waitForSelector('[data-testid="agentsMCP Tools"]', {
|
||||
timeout: 30000,
|
||||
});
|
||||
|
||||
await page
|
||||
.getByTestId("agentsMCP Tools")
|
||||
.dragTo(page.locator('//*[@id="react-flow-id"]'), {
|
||||
targetPosition: { x: 0, y: 0 },
|
||||
});
|
||||
|
||||
await page.getByTestId("fit_view").click();
|
||||
|
||||
await zoomOut(page, 3);
|
||||
|
||||
try {
|
||||
await page.getByText("Add MCP Server", { exact: true }).click({
|
||||
timeout: 5000,
|
||||
});
|
||||
} catch (_error) {
|
||||
await page.getByTestId("mcp-server-dropdown").click({ timeout: 3000 });
|
||||
await page.getByText("Add MCP Server", { exact: true }).click({
|
||||
timeout: 5000,
|
||||
});
|
||||
}
|
||||
|
||||
await page.waitForSelector('[data-testid="add-mcp-server-button"]', {
|
||||
state: "visible",
|
||||
timeout: 30000,
|
||||
});
|
||||
|
||||
// Go to SSE tab and fill all fields
|
||||
await page.getByTestId("sse-tab").click();
|
||||
await page.waitForSelector('[data-testid="sse-name-input"]', {
|
||||
state: "visible",
|
||||
timeout: 30000,
|
||||
});
|
||||
|
||||
// Test data with random suffix
|
||||
const randomSuffix = Math.floor(Math.random() * 90000) + 10000; // 5-digit random number
|
||||
const testName = `test_sse_server_${randomSuffix}`;
|
||||
const testUrl = "https://api.example.com/mcp";
|
||||
const testHeaderKey1 = "Authorization";
|
||||
const testHeaderValue1 = "Bearer token123";
|
||||
const testHeaderKey2 = "Content-Type";
|
||||
const testHeaderValue2 = "application/json";
|
||||
const testEnvKey1 = "API_TIMEOUT";
|
||||
const testEnvValue1 = "30000";
|
||||
const testEnvKey2 = "RETRY_COUNT";
|
||||
const testEnvValue2 = "3";
|
||||
|
||||
// Fill basic fields
|
||||
await page.getByTestId("sse-name-input").fill(testName);
|
||||
await page.getByTestId("sse-url-input").fill(testUrl);
|
||||
|
||||
// Add first header
|
||||
await page.getByTestId("sse-headers-key-0").fill(testHeaderKey1);
|
||||
await page.getByTestId("sse-headers-value-0").fill(testHeaderValue1);
|
||||
|
||||
// Add second header
|
||||
await page.getByTestId("sse-headers-plus-btn-0").click();
|
||||
await page.getByTestId("sse-headers-key-1").fill(testHeaderKey2);
|
||||
await page.getByTestId("sse-headers-value-1").fill(testHeaderValue2);
|
||||
|
||||
// Add first environment variable
|
||||
await page.getByTestId("sse-env-key-0").fill(testEnvKey1);
|
||||
await page.getByTestId("sse-env-value-0").fill(testEnvValue1);
|
||||
|
||||
// Add second environment variable
|
||||
await page.getByTestId("sse-env-plus-btn-0").click();
|
||||
await page.getByTestId("sse-env-key-1").fill(testEnvKey2);
|
||||
await page.getByTestId("sse-env-value-1").fill(testEnvValue2);
|
||||
|
||||
// Save the server
|
||||
await page.getByTestId("add-mcp-server-button").click();
|
||||
|
||||
// Wait for server to be created
|
||||
await page.waitForTimeout(2000);
|
||||
|
||||
// Go to settings to edit the server
|
||||
await page.getByTestId("user_menu_button").click({ timeout: 3000 });
|
||||
await page.getByTestId("menu_settings_button").click({ timeout: 3000 });
|
||||
|
||||
await page.waitForSelector('[data-testid="sidebar-nav-MCP Servers"]', {
|
||||
timeout: 30000,
|
||||
});
|
||||
await page.getByTestId("sidebar-nav-MCP Servers").click({ timeout: 3000 });
|
||||
|
||||
await page.waitForSelector('[data-testid="add-mcp-server-button-page"]', {
|
||||
timeout: 3000,
|
||||
});
|
||||
|
||||
// Find and edit the server
|
||||
await expect(page.getByText(testName)).toBeVisible({
|
||||
timeout: 3000,
|
||||
});
|
||||
|
||||
await page
|
||||
.getByTestId(`mcp-server-menu-button-${testName}`)
|
||||
.click({ timeout: 3000 });
|
||||
|
||||
await page
|
||||
.getByText("Edit", { exact: true })
|
||||
.first()
|
||||
.click({ timeout: 3000 });
|
||||
|
||||
await page.waitForSelector('[data-testid="add-mcp-server-button"]', {
|
||||
state: "visible",
|
||||
timeout: 30000,
|
||||
});
|
||||
|
||||
// Verify all fields persisted correctly
|
||||
expect(await page.getByTestId("sse-name-input").inputValue()).toBe(
|
||||
testName,
|
||||
);
|
||||
expect(await page.getByTestId("sse-url-input").inputValue()).toBe(testUrl);
|
||||
expect(await page.getByTestId("sse-headers-key-0").inputValue()).toBe(
|
||||
testHeaderKey1,
|
||||
);
|
||||
expect(await page.getByTestId("sse-headers-value-0").inputValue()).toBe(
|
||||
testHeaderValue1,
|
||||
);
|
||||
expect(await page.getByTestId("sse-headers-key-1").inputValue()).toBe(
|
||||
testHeaderKey2,
|
||||
);
|
||||
expect(await page.getByTestId("sse-headers-value-1").inputValue()).toBe(
|
||||
testHeaderValue2,
|
||||
);
|
||||
expect(await page.getByTestId("sse-env-key-0").inputValue()).toBe(
|
||||
testEnvKey1,
|
||||
);
|
||||
expect(await page.getByTestId("sse-env-value-0").inputValue()).toBe(
|
||||
testEnvValue1,
|
||||
);
|
||||
expect(await page.getByTestId("sse-env-key-1").inputValue()).toBe(
|
||||
testEnvKey2,
|
||||
);
|
||||
expect(await page.getByTestId("sse-env-value-1").inputValue()).toBe(
|
||||
testEnvValue2,
|
||||
);
|
||||
|
||||
// Clean up - cancel the edit modal
|
||||
await page.keyboard.press("Escape");
|
||||
|
||||
// Delete the test server
|
||||
await page
|
||||
.getByTestId(`mcp-server-menu-button-${testName}`)
|
||||
.click({ timeout: 3000 });
|
||||
|
||||
await page
|
||||
.getByText("Delete", { exact: true })
|
||||
.first()
|
||||
.click({ timeout: 3000 });
|
||||
|
||||
await page.waitForSelector(
|
||||
'[data-testid="btn_delete_delete_confirmation_modal"]',
|
||||
{
|
||||
timeout: 3000,
|
||||
},
|
||||
);
|
||||
|
||||
await page
|
||||
.getByTestId("btn_delete_delete_confirmation_modal")
|
||||
.click({ timeout: 3000 });
|
||||
},
|
||||
);
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue