fix: store mcp sse headers and use them on connection (#9148)

* Store mcp server headers

* Add headers on pre check url and is valid url

* adds validation of headers according to RFC 7230

* Fixed sanitized value

* Added backend tests for mcp util.py to increase coverage

* [autofix.ci] apply automated fixes

* [autofix.ci] apply automated fixes (attempt 2/3)

* make key pair input use flatmap id on data test ids

* added testids

* added random test names and added tests for persistence

* fix ruff lint

* [autofix.ci] apply automated fixes

* Fix mypy lint errors

---------

Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com>
This commit is contained in:
Lucas Oliveira 2025-07-25 22:16:18 -03:00 committed by GitHub
commit 41e101499a
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
8 changed files with 1101 additions and 197 deletions

View file

@ -30,6 +30,86 @@ HTTP_NOT_FOUND = 404
HTTP_BAD_REQUEST = 400
HTTP_INTERNAL_SERVER_ERROR = 500
# RFC 7230 compliant header name pattern: token = 1*tchar
# tchar = "!" / "#" / "$" / "%" / "&" / "'" / "*" / "+" / "-" / "." /
# "^" / "_" / "`" / "|" / "~" / DIGIT / ALPHA
HEADER_NAME_PATTERN = re.compile(r"^[!#$%&\'*+\-.0-9A-Z^_`a-z|~]+$")
# Common allowed headers for MCP connections
ALLOWED_HEADERS = {
"authorization",
"accept",
"accept-encoding",
"accept-language",
"cache-control",
"content-type",
"user-agent",
"x-api-key",
"x-auth-token",
"x-custom-header",
"x-langflow-session",
"x-mcp-client",
"x-requested-with",
}
def validate_headers(headers: dict[str, str]) -> dict[str, str]:
"""Validate and sanitize HTTP headers according to RFC 7230.
Args:
headers: Dictionary of header name-value pairs
Returns:
Dictionary of validated and sanitized headers
Raises:
ValueError: If headers contain invalid names or values
"""
if not headers:
return {}
sanitized_headers = {}
for name, value in headers.items():
if not isinstance(name, str) or not isinstance(value, str):
logger.warning(f"Skipping non-string header: {name}={value}")
continue
# Validate header name according to RFC 7230
if not HEADER_NAME_PATTERN.match(name):
logger.warning(f"Invalid header name '{name}', skipping")
continue
# Normalize header name to lowercase (HTTP headers are case-insensitive)
normalized_name = name.lower()
# Optional: Check against whitelist of allowed headers
if normalized_name not in ALLOWED_HEADERS:
# For MCP, we'll be permissive and allow non-standard headers
# but log a warning for security awareness
logger.debug(f"Using non-standard header: {normalized_name}")
# Check for potential header injection attempts BEFORE sanitizing
if "\r" in value or "\n" in value:
logger.warning(f"Potential header injection detected in '{name}', skipping")
continue
# Sanitize header value - remove control characters and newlines
# RFC 7230: field-value = *( field-content / obs-fold )
# We'll remove control characters (0x00-0x1F, 0x7F) except tab (0x09) and space (0x20)
sanitized_value = re.sub(r"[\x00-\x08\x0A-\x1F\x7F]", "", value)
# Remove leading/trailing whitespace
sanitized_value = sanitized_value.strip()
if not sanitized_value:
logger.warning(f"Header '{name}' has empty value after sanitization, skipping")
continue
sanitized_headers[normalized_name] = sanitized_value
return sanitized_headers
def sanitize_mcp_name(name: str, max_length: int = 46) -> str:
"""Sanitize a name for MCP usage by removing emojis, diacritics, and special characters.
@ -334,12 +414,12 @@ def _process_headers(headers: Any) -> dict:
Args:
headers: The headers to process, can be dict, str, or list
Returns:
Processed dictionary
Processed and validated dictionary
"""
if headers is None:
return {}
if isinstance(headers, dict):
return headers
return validate_headers(headers)
if isinstance(headers, list):
processed_headers = {}
try:
@ -351,7 +431,7 @@ def _process_headers(headers: Any) -> dict:
processed_headers[key] = value
except (KeyError, TypeError, ValueError):
return {} # Return empty dictionary instead of None
return processed_headers
return validate_headers(processed_headers)
return {}
@ -924,7 +1004,7 @@ class MCPSseClient:
self._component_cache.set("mcp_session_manager", session_manager)
return session_manager
async def validate_url(self, url: str | None) -> tuple[bool, str]:
async def validate_url(self, url: str | None, headers: dict[str, str] | None = None) -> tuple[bool, str]:
"""Validate the SSE URL before attempting connection."""
try:
parsed = urlparse(url)
@ -935,7 +1015,9 @@ class MCPSseClient:
try:
# For SSE endpoints, try a GET request with short timeout
# Many SSE servers don't support HEAD requests and return 404
response = await client.get(url, timeout=2.0, headers={"Accept": "text/event-stream"})
response = await client.get(
url, timeout=2.0, headers={"Accept": "text/event-stream", **(headers or {})}
)
# For SSE, we expect the server to either:
# 1. Start streaming (200)
@ -967,14 +1049,16 @@ class MCPSseClient:
except (httpx.HTTPError, ValueError, OSError) as e:
return False, f"URL validation error: {e!s}"
async def pre_check_redirect(self, url: str | None) -> str | None:
async def pre_check_redirect(self, url: str | None, headers: dict[str, str] | None = None) -> str | None:
"""Check for redirects and return the final URL."""
if url is None:
return url
try:
async with httpx.AsyncClient(follow_redirects=False) as client:
# Use GET with SSE headers instead of HEAD since many SSE servers don't support HEAD
response = await client.get(url, timeout=2.0, headers={"Accept": "text/event-stream"})
response = await client.get(
url, timeout=2.0, headers={"Accept": "text/event-stream", **(headers or {})}
)
if response.status_code == httpx.codes.TEMPORARY_REDIRECT:
return response.headers.get("Location", url)
# Don't treat 404 as an error here - let the main connection handle it
@ -990,22 +1074,23 @@ class MCPSseClient:
sse_read_timeout_seconds: int = 30,
) -> list[StructuredTool]:
"""Connect to MCP server using SSE transport (SDK style)."""
if headers is None:
headers = {}
# Validate and sanitize headers early
validated_headers = _process_headers(headers)
if url is None:
msg = "URL is required for SSE mode"
raise ValueError(msg)
is_valid, error_msg = await self.validate_url(url)
is_valid, error_msg = await self.validate_url(url, validated_headers)
if not is_valid:
msg = f"Invalid SSE URL ({url}): {error_msg}"
raise ValueError(msg)
url = await self.pre_check_redirect(url)
url = await self.pre_check_redirect(url, validated_headers)
# Store connection parameters for later use in run_tool
self._connection_params = {
"url": url,
"headers": headers,
"headers": validated_headers,
"timeout_seconds": timeout_seconds,
"sse_read_timeout_seconds": sse_read_timeout_seconds,
}

View file

@ -1,14 +1,12 @@
import asyncio
from unittest.mock import AsyncMock, MagicMock, patch
import pytest
from langflow.components.agents.mcp_component import MCPSseClient, MCPStdioClient, MCPToolsComponent
from langflow.base.mcp import util
from langflow.base.mcp.util import MCPSessionManager, MCPSseClient, MCPStdioClient, _process_headers, validate_headers
from langflow.components.agents.mcp_component import MCPToolsComponent
from tests.base import ComponentTestBaseWithoutClient, VersionComponentMapping
# TODO: This test suite is incomplete and is in need of an update to handle the latest MCP component changes.
pytestmark = pytest.mark.skip(reason="Skipping entire file")
class TestMCPToolsComponent(ComponentTestBaseWithoutClient):
@pytest.fixture
@ -20,9 +18,7 @@ class TestMCPToolsComponent(ComponentTestBaseWithoutClient):
def default_kwargs(self):
"""Return the default kwargs for the component."""
return {
"mode": "Stdio",
"command": "uvx mcp-server-fetch",
"sse_url": "http://localhost:7860/api/v1/mcp/sse",
"mcp_server": {"name": "test_server", "config": {"command": "uvx mcp-server-fetch"}},
"tool": "",
}
@ -40,24 +36,22 @@ class TestMCPToolsComponent(ComponentTestBaseWithoutClient):
tool.inputSchema = {
"type": "object",
"properties": {"test_param": {"type": "string", "description": "Test parameter"}},
"required": ["test_param"],
}
return tool
@pytest.fixture
def mock_stdio_client(self, mock_tool):
"""Create a mock stdio client."""
stdio_client = AsyncMock()
stdio_client.connect_to_server = AsyncMock(return_value=[mock_tool])
stdio_client.session = AsyncMock()
return stdio_client
@pytest.fixture
def mock_sse_client(self, mock_tool):
"""Create a mock SSE client."""
sse_client = AsyncMock()
sse_client.connect_to_server = AsyncMock(return_value=[mock_tool])
sse_client.session = AsyncMock()
return sse_client
def mock_session(self, mock_tool):
"""Create a mock ClientSession."""
session = AsyncMock()
session.initialize = AsyncMock()
list_tools_result = MagicMock()
list_tools_result.tools = [mock_tool]
session.list_tools = AsyncMock(return_value=list_tools_result)
session.call_tool = AsyncMock(
return_value=MagicMock(content=[MagicMock(model_dump=lambda: {"result": "success"})])
)
return session
class TestMCPStdioClient:
@ -65,40 +59,69 @@ class TestMCPStdioClient:
def stdio_client(self):
return MCPStdioClient()
async def test_connect_to_server(self, stdio_client):
"""Test connecting to server via Stdio."""
# Create mock for stdio transport
mock_stdio = AsyncMock()
mock_write = AsyncMock()
mock_stdio_transport = (mock_stdio, mock_write)
mock_stdio_cm = AsyncMock()
mock_stdio_cm.__aenter__.return_value = mock_stdio_transport
@pytest.fixture
def mock_session_manager(self):
"""Create a mock session manager."""
return AsyncMock(spec=MCPSessionManager)
# Mock the stdio_client function to return our mock context manager
with patch("mcp.client.stdio.stdio_client", return_value=mock_stdio_cm):
# Mock ClientSession
async def test_connect_to_server_with_command(self, stdio_client):
"""Test connecting to server via Stdio with command."""
with patch.object(stdio_client, "_get_or_create_session") as mock_get_session:
# Mock session
mock_session = AsyncMock()
mock_session.initialize = AsyncMock()
mock_session.list_tools.return_value.tools = [MagicMock()]
mock_tool = MagicMock()
mock_tool.name = "test_tool"
list_tools_result = MagicMock()
list_tools_result.tools = [mock_tool]
mock_session.list_tools = AsyncMock(return_value=list_tools_result)
mock_get_session.return_value = mock_session
# Mock the AsyncExitStack
mock_exit_stack = AsyncMock()
mock_exit_stack.enter_async_context = AsyncMock()
mock_exit_stack.enter_async_context.side_effect = [
mock_stdio_transport, # For stdio_client
mock_session, # For ClientSession
]
stdio_client.exit_stack = mock_exit_stack
tools = await stdio_client.connect_to_server("test_command")
tools = await stdio_client.connect_to_server("uvx test-command")
assert len(tools) == 1
assert stdio_client.session is not None
# Verify the exit stack was used correctly
assert mock_exit_stack.enter_async_context.call_count == 2
# Verify the stdio transport was properly set
assert stdio_client.stdio == mock_stdio
assert stdio_client.write == mock_write
assert tools[0].name == "test_tool"
assert stdio_client._connected is True
assert stdio_client._connection_params is not None
async def test_run_tool_success(self, stdio_client):
"""Test successfully running a tool."""
# Setup connection state
stdio_client._connected = True
stdio_client._connection_params = MagicMock()
stdio_client._session_context = "test_context"
with patch.object(stdio_client, "_get_or_create_session") as mock_get_session:
mock_session = AsyncMock()
mock_result = MagicMock()
mock_session.call_tool = AsyncMock(return_value=mock_result)
mock_get_session.return_value = mock_session
result = await stdio_client.run_tool("test_tool", {"param": "value"})
assert result == mock_result
mock_session.call_tool.assert_called_once_with("test_tool", arguments={"param": "value"})
async def test_run_tool_without_connection(self, stdio_client):
"""Test running a tool without being connected."""
stdio_client._connected = False
with pytest.raises(ValueError, match="Session not initialized"):
await stdio_client.run_tool("test_tool", {})
async def test_disconnect_cleanup(self, stdio_client):
"""Test that disconnect properly cleans up resources."""
stdio_client._session_context = "test_context"
stdio_client._connected = True
with patch.object(stdio_client, "_get_session_manager") as mock_get_manager:
mock_manager = AsyncMock()
mock_get_manager.return_value = mock_manager
await stdio_client.disconnect()
mock_manager._cleanup_session.assert_called_once_with("test_context")
assert stdio_client.session is None
assert stdio_client._connected is False
class TestMCPSseClient:
@ -106,78 +129,487 @@ class TestMCPSseClient:
def sse_client(self):
return MCPSseClient()
async def test_pre_check_redirect(self, sse_client):
"""Test pre-checking URL for redirects."""
async def test_validate_url_valid(self, sse_client):
"""Test URL validation with valid URL."""
with patch("httpx.AsyncClient") as mock_client:
mock_response = MagicMock()
mock_response.status_code = 200
mock_client.return_value.__aenter__.return_value.get.return_value = mock_response
is_valid, error_msg = await sse_client.validate_url("http://test.url", {})
assert is_valid is True
assert error_msg == ""
async def test_validate_url_invalid_format(self, sse_client):
"""Test URL validation with invalid format."""
is_valid, error_msg = await sse_client.validate_url("invalid-url", {})
assert is_valid is False
assert "Invalid URL format" in error_msg
async def test_validate_url_with_404_response(self, sse_client):
"""Test URL validation with 404 response (should be valid for SSE)."""
with patch("httpx.AsyncClient") as mock_client:
mock_response = MagicMock()
mock_response.status_code = 404
mock_client.return_value.__aenter__.return_value.get.return_value = mock_response
is_valid, error_msg = await sse_client.validate_url("http://test.url", {})
assert is_valid is True
assert error_msg == ""
async def test_connect_to_server_with_headers(self, sse_client):
"""Test connecting to server via SSE with custom headers."""
test_url = "http://test.url"
test_headers = {"Authorization": "Bearer token123", "Custom-Header": "value"}
expected_headers = {"authorization": "Bearer token123", "custom-header": "value"} # normalized
with (
patch.object(sse_client, "validate_url", return_value=(True, "")),
patch.object(sse_client, "pre_check_redirect", return_value=test_url),
patch.object(sse_client, "_get_or_create_session") as mock_get_session,
):
# Mock session
mock_session = AsyncMock()
mock_tool = MagicMock()
mock_tool.name = "test_tool"
list_tools_result = MagicMock()
list_tools_result.tools = [mock_tool]
mock_session.list_tools = AsyncMock(return_value=list_tools_result)
mock_get_session.return_value = mock_session
tools = await sse_client.connect_to_server(test_url, test_headers)
assert len(tools) == 1
assert tools[0].name == "test_tool"
assert sse_client._connected is True
# Verify headers are stored in connection params (normalized)
assert sse_client._connection_params is not None
assert sse_client._connection_params["headers"] == expected_headers
assert sse_client._connection_params["url"] == test_url
async def test_headers_passed_to_session_manager(self, sse_client):
"""Test that headers are properly passed to the session manager."""
test_url = "http://test.url"
expected_headers = {"authorization": "Bearer token123", "x-api-key": "secret"} # normalized
sse_client._session_context = "test_context"
sse_client._connection_params = {
"url": test_url,
"headers": expected_headers, # Use normalized headers
"timeout_seconds": 30,
"sse_read_timeout_seconds": 30,
}
with patch.object(sse_client, "_get_session_manager") as mock_get_manager:
mock_manager = AsyncMock()
mock_session = AsyncMock()
mock_manager.get_session = AsyncMock(return_value=mock_session)
mock_get_manager.return_value = mock_manager
result_session = await sse_client._get_or_create_session()
# Verify session manager was called with correct parameters including normalized headers
mock_manager.get_session.assert_called_once_with("test_context", sse_client._connection_params, "sse")
assert result_session == mock_session
async def test_pre_check_redirect_with_headers(self, sse_client):
"""Test pre-check redirect functionality with custom headers."""
test_url = "http://test.url"
redirect_url = "http://redirect.url"
# Use pre-validated headers since pre_check_redirect expects already validated headers
test_headers = {"authorization": "Bearer token123"} # already normalized
with patch("httpx.AsyncClient") as mock_client:
mock_response = MagicMock()
mock_response.status_code = 307
mock_response.headers.get.return_value = redirect_url
mock_client.return_value.__aenter__.return_value.request.return_value = mock_response
mock_client.return_value.__aenter__.return_value.get.return_value = mock_response
result = await sse_client.pre_check_redirect(test_url, test_headers)
result = await sse_client.pre_check_redirect(test_url)
assert result == redirect_url
# Verify validated headers were passed to the request
mock_client.return_value.__aenter__.return_value.get.assert_called_with(
test_url, timeout=2.0, headers={"Accept": "text/event-stream", **test_headers}
)
async def test_run_tool_with_retry_on_connection_error(self, sse_client):
"""Test that run_tool retries on connection errors."""
# Setup connection state
sse_client._connected = True
sse_client._connection_params = {"url": "http://test.url", "headers": {}}
sse_client._session_context = "test_context"
call_count = 0
async def mock_get_session_side_effect():
nonlocal call_count
call_count += 1
session = AsyncMock()
if call_count == 1:
# First call fails with connection error
from anyio import ClosedResourceError
session.call_tool = AsyncMock(side_effect=ClosedResourceError())
else:
# Second call succeeds
mock_result = MagicMock()
session.call_tool = AsyncMock(return_value=mock_result)
return session
with (
patch.object(sse_client, "_get_or_create_session", side_effect=mock_get_session_side_effect),
patch.object(sse_client, "_get_session_manager") as mock_get_manager,
):
mock_manager = AsyncMock()
mock_get_manager.return_value = mock_manager
result = await sse_client.run_tool("test_tool", {"param": "value"})
# Should have retried and succeeded on second attempt
assert call_count == 2
assert result is not None
# Should have cleaned up the failed session
mock_manager._cleanup_session.assert_called_once_with("test_context")
class TestMCPSessionManager:
@pytest.fixture
def session_manager(self):
return MCPSessionManager()
async def test_session_caching(self, session_manager):
"""Test that sessions are properly cached and reused."""
context_id = "test_context"
connection_params = MagicMock()
transport_type = "stdio"
# Create a mock session that will appear healthy
mock_session = AsyncMock()
mock_session._write_stream = MagicMock()
mock_session._write_stream._closed = False
# Create a mock task that appears to be running
mock_task = AsyncMock()
mock_task.done = MagicMock(return_value=False)
with (
patch.object(session_manager, "_create_stdio_session") as mock_create,
patch.object(session_manager, "_validate_session_connectivity", return_value=True),
):
mock_create.return_value = mock_session
# First call should create session
session1 = await session_manager.get_session(context_id, connection_params, transport_type)
# Manually populate the sessions cache as if the session was created properly
session_manager.sessions[context_id] = {"session": mock_session, "task": mock_task, "type": transport_type}
# Second call should return cached session without creating new one
session2 = await session_manager.get_session(context_id, connection_params, transport_type)
assert session1 == session2
assert session1 == mock_session
# Should only create once since the second call should use the cached session
mock_create.assert_called_once()
async def test_session_cleanup(self, session_manager):
"""Test session cleanup functionality."""
context_id = "test_context"
# Add a session to the manager with proper mock setup
mock_task = AsyncMock()
mock_task.done = MagicMock(return_value=False) # Use MagicMock for sync method
mock_task.cancel = MagicMock() # Use MagicMock for sync method
session_manager.sessions[context_id] = {"session": AsyncMock(), "task": mock_task, "type": "stdio"}
await session_manager._cleanup_session(context_id)
# Should cancel the task and remove from sessions
mock_task.cancel.assert_called_once()
assert context_id not in session_manager.sessions
async def test_server_switch_detection(self, session_manager):
"""Test that server switches are properly detected and handled."""
context_id = "test_context"
# First server
server1_params = MagicMock()
server1_params.command = "server1"
# Second server
server2_params = MagicMock()
server2_params.command = "server2"
with (
patch.object(session_manager, "_create_stdio_session") as mock_create,
patch.object(session_manager, "_validate_session_connectivity", return_value=True),
):
mock_session1 = AsyncMock()
mock_session2 = AsyncMock()
mock_create.side_effect = [mock_session1, mock_session2]
# First connection
session1 = await session_manager.get_session(context_id, server1_params, "stdio")
# Switch to different server should create new session
session2 = await session_manager.get_session(context_id, server2_params, "stdio")
assert session1 != session2
assert mock_create.call_count == 2
# Integration test for header functionality
class TestHeaderValidation:
"""Test the header validation functionality."""
def test_validate_headers_valid_input(self):
"""Test header validation with valid headers."""
headers = {"Authorization": "Bearer token123", "Content-Type": "application/json", "X-API-Key": "secret-key"}
result = validate_headers(headers)
# Headers should be normalized to lowercase
expected = {"authorization": "Bearer token123", "content-type": "application/json", "x-api-key": "secret-key"}
assert result == expected
def test_validate_headers_empty_input(self):
"""Test header validation with empty/None input."""
assert validate_headers({}) == {}
assert validate_headers(None) == {}
def test_validate_headers_invalid_names(self):
"""Test header validation with invalid header names."""
headers = {
"Invalid Header": "value", # spaces not allowed
"Header@Name": "value", # @ not allowed
"Header Name": "value", # spaces not allowed
"Valid-Header": "value", # this should pass
}
result = validate_headers(headers)
# Only the valid header should remain
assert result == {"valid-header": "value"}
def test_validate_headers_sanitize_values(self):
"""Test header value sanitization."""
headers = {
"Authorization": "Bearer \x00token\x1f with\r\ninjection",
"Clean-Header": " clean value ",
"Empty-After-Clean": "\x00\x01\x02",
"Tab-Header": "value\twith\ttabs", # tabs should be preserved
}
result = validate_headers(headers)
# Control characters should be removed, whitespace trimmed
# Header with injection attempts should be skipped
expected = {"clean-header": "clean value", "tab-header": "value\twith\ttabs"}
assert result == expected
def test_validate_headers_non_string_values(self):
"""Test header validation with non-string values."""
headers = {"String-Header": "valid", "Number-Header": 123, "None-Header": None, "List-Header": ["value"]}
result = validate_headers(headers)
# Only string headers should remain
assert result == {"string-header": "valid"}
def test_validate_headers_injection_attempts(self):
"""Test header validation against injection attempts."""
headers = {
"Injection1": "value\r\nInjected-Header: malicious",
"Injection2": "value\nX-Evil: attack",
"Safe-Header": "safe-value",
}
result = validate_headers(headers)
# Injection attempts should be filtered out
assert result == {"safe-header": "safe-value"}
class TestSSEHeaderIntegration:
"""Integration test to verify headers are properly passed through the entire SSE flow."""
async def test_headers_processing(self):
"""Test that headers flow properly from server config through to SSE client connection."""
# Test the header processing function directly
headers_input = [
{"key": "Authorization", "value": "Bearer test-token"},
{"key": "X-API-Key", "value": "secret-key"},
]
expected_headers = {
"authorization": "Bearer test-token", # normalized to lowercase
"x-api-key": "secret-key",
}
# Test _process_headers function with validation
processed_headers = _process_headers(headers_input)
assert processed_headers == expected_headers
# Test different input formats
# Test dict input with validation
dict_headers = {"Authorization": "Bearer dict-token", "Invalid Header": "bad"}
result = _process_headers(dict_headers)
# Invalid header should be filtered out, valid header normalized
assert result == {"authorization": "Bearer dict-token"}
# Test None input
assert _process_headers(None) == {}
# Test empty list
assert _process_headers([]) == {}
# Test malformed list
malformed_headers = [{"key": "Auth"}, {"value": "token"}] # Missing value/key
assert _process_headers(malformed_headers) == {}
# Test list with invalid header names
invalid_headers = [
{"key": "Valid-Header", "value": "good"},
{"key": "Invalid Header", "value": "bad"}, # spaces not allowed
]
result = _process_headers(invalid_headers)
assert result == {"valid-header": "good"}
async def test_sse_client_header_storage(self):
"""Test that SSE client properly stores headers in connection params."""
sse_client = MCPSseClient()
test_url = "http://test.url"
test_headers = {"Authorization": "Bearer test123", "Custom": "value"}
expected_headers = {"authorization": "Bearer test123", "custom": "value"} # normalized
async def test_connect_to_server(self, sse_client):
"""Test connecting to server via SSE."""
# Mock the pre_check_redirect first
with (
patch.object(sse_client, "pre_check_redirect", return_value="http://test.url"),
patch.object(sse_client, "validate_url", return_value=(True, "")),
patch.object(sse_client, "pre_check_redirect", return_value=test_url),
patch.object(sse_client, "_get_or_create_session") as mock_get_session,
):
# Create mock for sse_client context manager
mock_sse = AsyncMock()
mock_write = AsyncMock()
mock_sse_transport = (mock_sse, mock_write)
mock_sse_cm = AsyncMock()
mock_sse_cm.__aenter__.return_value = mock_sse_transport
mock_session = AsyncMock()
mock_tool = MagicMock()
mock_tool.name = "test_tool"
list_tools_result = MagicMock()
list_tools_result.tools = [mock_tool]
mock_session.list_tools = AsyncMock(return_value=list_tools_result)
mock_get_session.return_value = mock_session
# Mock the sse_client function to return our mock context manager
with patch("mcp.client.sse.sse_client", return_value=mock_sse_cm):
# Mock ClientSession
mock_session = AsyncMock()
mock_session.initialize = AsyncMock()
mock_session.list_tools.return_value.tools = [MagicMock()]
await sse_client.connect_to_server(test_url, test_headers)
# Mock the AsyncExitStack
mock_exit_stack = AsyncMock()
mock_exit_stack.enter_async_context = AsyncMock()
mock_exit_stack.enter_async_context.side_effect = [
mock_sse_transport, # For sse_client
mock_session, # For ClientSession
]
sse_client.exit_stack = mock_exit_stack
# Verify headers are stored correctly in connection params (normalized)
assert sse_client._connection_params is not None
assert sse_client._connection_params["headers"] == expected_headers
assert sse_client._connection_params["url"] == test_url
tools = await sse_client.connect_to_server("http://test.url", {})
assert len(tools) == 1
assert sse_client.session is not None
# Verify the exit stack was used correctly
assert mock_exit_stack.enter_async_context.call_count == 2
# Verify the SSE transport was properly set
assert sse_client.sse == mock_sse
assert sse_client.write == mock_write
class TestMCPUtilityFunctions:
"""Test utility functions from util.py that don't have dedicated test classes."""
async def test_connect_timeout(self, sse_client):
"""Test connection timeout handling."""
# Set max_retries to 1 to avoid multiple retry attempts
sse_client.max_retries = 1
def test_sanitize_mcp_name(self):
"""Test MCP name sanitization."""
assert util.sanitize_mcp_name("Test Name 123") == "test_name_123"
assert util.sanitize_mcp_name(" ") == ""
assert util.sanitize_mcp_name("123abc") == "_123abc"
assert util.sanitize_mcp_name("Tést-😀-Námé") == "test_name"
assert util.sanitize_mcp_name("a" * 100) == "a" * 46
with (
patch.object(sse_client, "pre_check_redirect", return_value="http://test.url"),
patch.object(sse_client, "validate_url", return_value=(True, "")), # Mock URL validation
patch.object(sse_client, "_connect_with_timeout") as mock_connect,
):
mock_connect.side_effect = asyncio.TimeoutError()
def test_get_unique_name(self):
"""Test unique name generation."""
names = {"foo", "foo_1"}
assert util.get_unique_name("foo", 10, names) == "foo_2"
assert util.get_unique_name("bar", 10, names) == "bar"
assert util.get_unique_name("longname", 4, {"long"}) == "lo_1"
# Expect ConnectionError instead of TimeoutError
with pytest.raises(
ConnectionError,
match=(
"Failed to connect after 1 attempts. "
"Last error: Connection to http://test.url timed out after 1 seconds"
),
):
await sse_client.connect_to_server("http://test.url", {}, timeout_seconds=1)
def test_is_valid_key_value_item(self):
"""Test key-value item validation."""
assert util._is_valid_key_value_item({"key": "a", "value": "b"}) is True
assert util._is_valid_key_value_item({"key": "a"}) is False
assert util._is_valid_key_value_item(["key", "value"]) is False
assert util._is_valid_key_value_item(None) is False
def test_validate_node_installation(self):
"""Test Node.js installation validation."""
import shutil
if shutil.which("node"):
assert util._validate_node_installation("npx something") == "npx something"
else:
with pytest.raises(ValueError, match="Node.js is not installed"):
util._validate_node_installation("npx something")
assert util._validate_node_installation("echo test") == "echo test"
def test_create_input_schema_from_json_schema(self):
"""Test JSON schema to Pydantic model conversion."""
schema = {
"type": "object",
"properties": {
"foo": {"type": "string", "description": "desc"},
"bar": {"type": "integer"},
},
"required": ["foo"],
}
model_class = util.create_input_schema_from_json_schema(schema)
instance = model_class(foo="abc", bar=1)
assert instance.foo == "abc"
assert instance.bar == 1
with pytest.raises(Exception): # noqa: B017, PT011
model_class(bar=1) # missing required field
@pytest.mark.asyncio
async def test_validate_connection_params(self):
"""Test connection parameter validation."""
# Valid parameters
await util._validate_connection_params("Stdio", command="echo test")
await util._validate_connection_params("SSE", url="http://test")
# Invalid parameters
with pytest.raises(ValueError, match="Command is required for Stdio mode"):
await util._validate_connection_params("Stdio", command=None)
with pytest.raises(ValueError, match="URL is required for SSE mode"):
await util._validate_connection_params("SSE", url=None)
with pytest.raises(ValueError, match="Invalid mode"):
await util._validate_connection_params("InvalidMode")
@pytest.mark.asyncio
async def test_get_flow_snake_case_mocked(self):
"""Test flow lookup by snake case name with mocked session."""
class DummyFlow:
def __init__(self, name: str, user_id: str, *, is_component: bool = False, action_name: str | None = None):
self.name = name
self.user_id = user_id
self.is_component = is_component
self.action_name = action_name
class DummyExec:
def __init__(self, flows: list[DummyFlow]):
self._flows = flows
def all(self):
return self._flows
class DummySession:
def __init__(self, flows: list[DummyFlow]):
self._flows = flows
async def exec(self, stmt): # noqa: ARG002
return DummyExec(self._flows)
user_id = "123e4567-e89b-12d3-a456-426614174000"
flows = [DummyFlow("Test Flow", user_id), DummyFlow("Other", user_id)]
# Should match sanitized name
result = await util.get_flow_snake_case(util.sanitize_mcp_name("Test Flow"), user_id, DummySession(flows))
assert result is flows[0]
# Should return None if not found
result = await util.get_flow_snake_case("notfound", user_id, DummySession(flows))
assert result is None

View file

@ -34,6 +34,9 @@ export const useAddMCPServer: useMutationFunctionType<
if (body.env && Object.keys(body.env).length > 0) {
payload.env = body.env;
}
if (body.headers && Object.keys(body.headers).length > 0) {
payload.headers = body.headers;
}
const res = await api.post(
`${getURL("MCP_SERVERS", undefined, true)}/${body.name}`,

View file

@ -34,6 +34,9 @@ export const usePatchMCPServer: useMutationFunctionType<
if (body.env && Object.keys(body.env).length > 0) {
payload.env = body.env;
}
if (body.headers && Object.keys(body.headers).length > 0) {
payload.headers = body.headers;
}
const res = await api.patch(
`${getURL("MCP_SERVERS", undefined, true)}/${body.name}`,

View file

@ -1,5 +1,5 @@
import _ from "lodash";
import { useEffect, useState } from "react";
import { useCallback, useEffect, useState } from "react";
import IconComponent from "../../../../../components/common/genericIconComponent";
import { Input } from "../../../../../components/ui/input";
import { classNames } from "../../../../../utils/utils";
@ -10,6 +10,7 @@ export type IOKeyPairInputProps = {
duplicateKey: boolean;
isList: boolean;
isInputField?: boolean;
testId?: string;
};
const IOKeyPairInput = ({
@ -18,10 +19,11 @@ const IOKeyPairInput = ({
duplicateKey,
isList = true,
isInputField,
testId,
}: IOKeyPairInputProps) => {
const checkValueType = (value) => {
const checkValueType = useCallback((value) => {
return Array.isArray(value) ? value : [value];
};
}, []);
const [currentData, setCurrentData] = useState<any[]>(() => {
return !value || value?.length === 0 ? [{ "": "" }] : checkValueType(value);
@ -32,90 +34,105 @@ const IOKeyPairInput = ({
const newData =
!value || value?.length === 0 ? [{ "": "" }] : checkValueType(value);
setCurrentData(newData);
}, [value]);
}, [value, checkValueType]);
const handleChangeKey = (event, idx) => {
const oldKey = Object.keys(currentData[idx])[0];
const updatedObj = { [event.target.value]: currentData[idx][oldKey] };
const handleChangeKey = (event, objIndex) => {
const oldKey = Object.keys(currentData[objIndex])[0];
const updatedObj = { [event.target.value]: currentData[objIndex][oldKey] };
const newData = [...currentData];
newData[idx] = updatedObj;
newData[objIndex] = updatedObj;
setCurrentData(newData);
onChange(newData);
};
const handleChangeValue = (newValue, idx) => {
const key = Object.keys(currentData[idx])[0];
const handleChangeValue = (newValue, objIndex) => {
const key = Object.keys(currentData[objIndex])[0];
const newData = [...currentData];
newData[idx] = { ...newData[idx], [key]: newValue };
newData[objIndex] = { ...newData[objIndex], [key]: newValue };
setCurrentData(newData);
onChange(newData);
};
// Create flat array with additional metadata for rendering
const flattenedData =
currentData?.flatMap((obj, objIndex) => {
return Object.keys(obj).map((key) => ({
key,
value: obj[key],
objIndex,
uniqueId: `${objIndex}-${key}`, // Create unique identifier for React key
}));
}) || [];
return (
<>
<div className={classNames("flex h-full flex-col gap-3")}>
{currentData?.map((obj, index) => {
return Object.keys(obj).map((key, idx) => {
return (
<div key={idx} className="flex w-full gap-2">
<Input
type="text"
value={key.trim()}
className={classNames(duplicateKey ? "input-invalid" : "")}
placeholder="Type key..."
onChange={(event) => handleChangeKey(event, index)}
disabled={!isInputField}
/>
<div className={classNames("flex h-full flex-col gap-3")}>
{flattenedData.map((item, idx) => {
return (
<div key={item.uniqueId} className="flex w-full gap-2">
<Input
type="text"
value={item.key.trim()}
className={classNames(duplicateKey ? "input-invalid" : "")}
placeholder="Type key..."
onChange={(event) => handleChangeKey(event, item.objIndex)}
disabled={!isInputField}
data-testid={testId ? `${testId}-key-${idx}` : undefined}
/>
<Input
type="text"
value={obj[key]}
placeholder="Type a value..."
onChange={(event) =>
handleChangeValue(event.target.value, index)
}
disabled={!isInputField}
/>
<Input
type="text"
value={item.value}
placeholder="Type a value..."
onChange={(event) =>
handleChangeValue(event.target.value, item.objIndex)
}
disabled={!isInputField}
data-testid={testId ? `${testId}-value-${idx}` : undefined}
/>
{isList && isInputField && index === currentData.length - 1 ? (
<button
type="button"
onClick={() => {
const newInputList = _.cloneDeep(currentData);
newInputList.push({ "": "" });
setCurrentData(newInputList);
onChange(newInputList);
}}
>
<IconComponent
name="Plus"
className={"h-4 w-4 hover:text-accent-foreground"}
/>
</button>
) : isList && isInputField ? (
<button
type="button"
onClick={() => {
const newInputList = _.cloneDeep(currentData);
newInputList.splice(index, 1);
setCurrentData(newInputList);
onChange(newInputList);
}}
>
<IconComponent
name="X"
className="h-4 w-4 hover:text-status-red"
/>
</button>
) : (
""
)}
</div>
);
});
})}
</div>
</>
{isList &&
isInputField &&
item.objIndex === currentData.length - 1 ? (
<button
type="button"
onClick={() => {
const newInputList = _.cloneDeep(currentData);
newInputList.push({ "": "" });
setCurrentData(newInputList);
onChange(newInputList);
}}
data-testid={testId ? `${testId}-plus-btn-0` : undefined}
>
<IconComponent
name="Plus"
className={"h-4 w-4 hover:text-accent-foreground"}
/>
</button>
) : isList && isInputField ? (
<button
type="button"
onClick={() => {
const newInputList = _.cloneDeep(currentData);
newInputList.splice(item.objIndex, 1);
setCurrentData(newInputList);
onChange(newInputList);
}}
data-testid={
testId ? `${testId}-minus-btn-${item.objIndex}` : undefined
}
>
<IconComponent
name="X"
className="h-4 w-4 hover:text-status-red"
/>
</button>
) : (
""
)}
</div>
);
})}
</div>
);
};

View file

@ -358,6 +358,7 @@ export default function AddMcpServerModal({
listAddLabel="Add Argument"
editNode={false}
id="stdio-args"
data-testid="stdio-args-input"
/>
</div>
<div className="flex flex-col gap-2">
@ -368,6 +369,7 @@ export default function AddMcpServerModal({
duplicateKey={false}
isList={true}
isInputField={true}
testId="stdio-env"
/>
</div>
</div>
@ -382,6 +384,7 @@ export default function AddMcpServerModal({
value={sseName}
onChange={(e) => setSseName(e.target.value)}
placeholder="Name"
data-testid="sse-name-input"
disabled={isPending}
/>
</div>
@ -393,6 +396,7 @@ export default function AddMcpServerModal({
value={sseUrl}
onChange={(e) => setSseUrl(e.target.value)}
placeholder="SSE URL"
data-testid="sse-url-input"
disabled={isPending}
/>
</div>
@ -404,6 +408,7 @@ export default function AddMcpServerModal({
duplicateKey={false}
isList={true}
isInputField={true}
testId="sse-headers"
/>
</div>
<div className="flex flex-col gap-2">
@ -414,6 +419,7 @@ export default function AddMcpServerModal({
duplicateKey={false}
isList={true}
isInputField={true}
testId="sse-env"
/>
</div>
</div>

View file

@ -95,5 +95,9 @@ export function extractMcpServersFromJson(
args: server.args || [],
env: server.env && typeof server.env === "object" ? server.env : {},
url: server.url,
headers:
server.headers && typeof server.headers === "object"
? server.headers
: {},
}));
}

View file

@ -54,7 +54,9 @@ test(
timeout: 30000,
});
await page.getByTestId("stdio-name-input").fill("test server");
const randomSuffix = Math.floor(Math.random() * 90000) + 10000; // 5-digit random number
const testName = `test_server_${randomSuffix}`;
await page.getByTestId("stdio-name-input").fill(testName);
await page.getByTestId("stdio-command-input").fill("uvx mcp-server-fetch");
@ -115,12 +117,12 @@ test(
timeout: 3000,
});
await expect(page.getByText("test_server")).toBeVisible({
await expect(page.getByText(testName)).toBeVisible({
timeout: 3000,
});
await page
.getByTestId(`mcp-server-menu-button-test_server`)
.getByTestId(`mcp-server-menu-button-${testName}`)
.click({ timeout: 3000 });
await page
@ -152,7 +154,7 @@ test(
await page.getByTestId("add-mcp-server-button").click();
await page
.getByTestId(`mcp-server-menu-button-test_server`)
.getByTestId(`mcp-server-menu-button-${testName}`)
.click({ timeout: 3000 });
await page
@ -177,7 +179,7 @@ test(
await page.waitForTimeout(3000);
await expect(page.getByText("test_server")).not.toBeVisible({
await expect(page.getByText(testName)).not.toBeVisible({
timeout: 3000,
});
@ -199,8 +201,360 @@ test(
});
await page.getByTestId("mcp-server-dropdown").click({ timeout: 10000 });
await expect(page.getByText("test_server")).toHaveCount(2, {
await expect(page.getByText(testName)).toHaveCount(2, {
timeout: 10000,
});
},
);
test(
"STDIO MCP server fields should persist after saving and editing",
{ tag: ["@release", "@workspace", "@components"] },
async ({ page }) => {
await awaitBootstrapTest(page);
await page.waitForSelector('[data-testid="blank-flow"]', {
timeout: 30000,
});
await page.getByTestId("blank-flow").click();
await page.getByTestId("sidebar-search-input").click();
await page.getByTestId("sidebar-search-input").fill("mcp tools");
await page.waitForSelector('[data-testid="agentsMCP Tools"]', {
timeout: 30000,
});
await page
.getByTestId("agentsMCP Tools")
.dragTo(page.locator('//*[@id="react-flow-id"]'), {
targetPosition: { x: 0, y: 0 },
});
await page.getByTestId("fit_view").click();
await zoomOut(page, 3);
try {
await page.getByText("Add MCP Server", { exact: true }).click({
timeout: 5000,
});
} catch (_error) {
await page.getByTestId("mcp-server-dropdown").click({ timeout: 3000 });
await page.getByText("Add MCP Server", { exact: true }).click({
timeout: 5000,
});
}
await page.waitForSelector('[data-testid="add-mcp-server-button"]', {
state: "visible",
timeout: 30000,
});
// Go to STDIO tab and fill all fields
await page.getByTestId("stdio-tab").click();
await page.waitForSelector('[data-testid="stdio-name-input"]', {
state: "visible",
timeout: 30000,
});
// Test data with random suffix
const randomSuffix = Math.floor(Math.random() * 90000) + 10000; // 5-digit random number
const testName = `test_stdio_server_${randomSuffix}`;
const testCommand = "uvx mcp-server-test";
const testArg1 = "--verbose";
const testArg2 = "--port=8080";
const testArg3 = "--config=test.json";
const testEnvKey1 = "NODE_ENV";
const testEnvValue1 = "production";
const testEnvKey2 = "DEBUG_MODE";
const testEnvValue2 = "true";
// Fill basic fields
await page.getByTestId("stdio-name-input").fill(testName);
await page.getByTestId("stdio-command-input").fill(testCommand);
// Add first argument
await page.getByTestId("stdio-args_0").fill(testArg1);
// Add second argument by clicking plus button
await page.getByTestId("input-list-plus-btn_-0").click();
await page.getByTestId("stdio-args_1").fill(testArg2);
// Add third argument
await page.getByTestId("input-list-plus-btn_-0").click();
await page.getByTestId("stdio-args_2").fill(testArg3);
// Add first environment variable
await page.getByTestId("stdio-env-key-0").fill(testEnvKey1);
await page.getByTestId("stdio-env-value-0").fill(testEnvValue1);
// Add second environment variable
await page.getByTestId("stdio-env-plus-btn-0").click();
await page.getByTestId("stdio-env-key-1").fill(testEnvKey2);
await page.getByTestId("stdio-env-value-1").fill(testEnvValue2);
// Save the server
await page.getByTestId("add-mcp-server-button").click();
// Wait for server to be created
await page.waitForTimeout(2000);
// Go to settings to edit the server
await page.getByTestId("user_menu_button").click({ timeout: 3000 });
await page.getByTestId("menu_settings_button").click({ timeout: 3000 });
await page.waitForSelector('[data-testid="sidebar-nav-MCP Servers"]', {
timeout: 30000,
});
await page.getByTestId("sidebar-nav-MCP Servers").click({ timeout: 3000 });
await page.waitForSelector('[data-testid="add-mcp-server-button-page"]', {
timeout: 3000,
});
// Find and edit the server
await expect(page.getByText(testName)).toBeVisible({
timeout: 3000,
});
await page
.getByTestId(`mcp-server-menu-button-${testName}`)
.click({ timeout: 3000 });
await page
.getByText("Edit", { exact: true })
.first()
.click({ timeout: 3000 });
await page.waitForSelector('[data-testid="add-mcp-server-button"]', {
state: "visible",
timeout: 30000,
});
// Verify all fields persisted correctly
expect(await page.getByTestId("stdio-name-input").inputValue()).toBe(
testName,
);
expect(await page.getByTestId("stdio-command-input").inputValue()).toBe(
testCommand,
);
expect(await page.getByTestId("stdio-args_0").inputValue()).toBe(testArg1);
expect(await page.getByTestId("stdio-args_1").inputValue()).toBe(testArg2);
expect(await page.getByTestId("stdio-args_2").inputValue()).toBe(testArg3);
expect(await page.getByTestId("stdio-env-key-0").last().inputValue()).toBe(
testEnvKey1,
);
expect(
await page.getByTestId("stdio-env-value-0").last().inputValue(),
).toBe(testEnvValue1);
expect(await page.getByTestId("stdio-env-key-1").last().inputValue()).toBe(
testEnvKey2,
);
expect(
await page.getByTestId("stdio-env-value-1").last().inputValue(),
).toBe(testEnvValue2);
// Clean up - cancel the edit modal
await page.keyboard.press("Escape");
// Delete the test server
await page
.getByTestId(`mcp-server-menu-button-${testName}`)
.click({ timeout: 3000 });
await page
.getByText("Delete", { exact: true })
.first()
.click({ timeout: 3000 });
await page.waitForSelector(
'[data-testid="btn_delete_delete_confirmation_modal"]',
{
timeout: 3000,
},
);
await page
.getByTestId("btn_delete_delete_confirmation_modal")
.click({ timeout: 3000 });
},
);
test(
"SSE MCP server fields should persist after saving and editing",
{ tag: ["@release", "@workspace", "@components"] },
async ({ page }) => {
await awaitBootstrapTest(page);
await page.waitForSelector('[data-testid="blank-flow"]', {
timeout: 30000,
});
await page.getByTestId("blank-flow").click();
await page.getByTestId("sidebar-search-input").click();
await page.getByTestId("sidebar-search-input").fill("mcp tools");
await page.waitForSelector('[data-testid="agentsMCP Tools"]', {
timeout: 30000,
});
await page
.getByTestId("agentsMCP Tools")
.dragTo(page.locator('//*[@id="react-flow-id"]'), {
targetPosition: { x: 0, y: 0 },
});
await page.getByTestId("fit_view").click();
await zoomOut(page, 3);
try {
await page.getByText("Add MCP Server", { exact: true }).click({
timeout: 5000,
});
} catch (_error) {
await page.getByTestId("mcp-server-dropdown").click({ timeout: 3000 });
await page.getByText("Add MCP Server", { exact: true }).click({
timeout: 5000,
});
}
await page.waitForSelector('[data-testid="add-mcp-server-button"]', {
state: "visible",
timeout: 30000,
});
// Go to SSE tab and fill all fields
await page.getByTestId("sse-tab").click();
await page.waitForSelector('[data-testid="sse-name-input"]', {
state: "visible",
timeout: 30000,
});
// Test data with random suffix
const randomSuffix = Math.floor(Math.random() * 90000) + 10000; // 5-digit random number
const testName = `test_sse_server_${randomSuffix}`;
const testUrl = "https://api.example.com/mcp";
const testHeaderKey1 = "Authorization";
const testHeaderValue1 = "Bearer token123";
const testHeaderKey2 = "Content-Type";
const testHeaderValue2 = "application/json";
const testEnvKey1 = "API_TIMEOUT";
const testEnvValue1 = "30000";
const testEnvKey2 = "RETRY_COUNT";
const testEnvValue2 = "3";
// Fill basic fields
await page.getByTestId("sse-name-input").fill(testName);
await page.getByTestId("sse-url-input").fill(testUrl);
// Add first header
await page.getByTestId("sse-headers-key-0").fill(testHeaderKey1);
await page.getByTestId("sse-headers-value-0").fill(testHeaderValue1);
// Add second header
await page.getByTestId("sse-headers-plus-btn-0").click();
await page.getByTestId("sse-headers-key-1").fill(testHeaderKey2);
await page.getByTestId("sse-headers-value-1").fill(testHeaderValue2);
// Add first environment variable
await page.getByTestId("sse-env-key-0").fill(testEnvKey1);
await page.getByTestId("sse-env-value-0").fill(testEnvValue1);
// Add second environment variable
await page.getByTestId("sse-env-plus-btn-0").click();
await page.getByTestId("sse-env-key-1").fill(testEnvKey2);
await page.getByTestId("sse-env-value-1").fill(testEnvValue2);
// Save the server
await page.getByTestId("add-mcp-server-button").click();
// Wait for server to be created
await page.waitForTimeout(2000);
// Go to settings to edit the server
await page.getByTestId("user_menu_button").click({ timeout: 3000 });
await page.getByTestId("menu_settings_button").click({ timeout: 3000 });
await page.waitForSelector('[data-testid="sidebar-nav-MCP Servers"]', {
timeout: 30000,
});
await page.getByTestId("sidebar-nav-MCP Servers").click({ timeout: 3000 });
await page.waitForSelector('[data-testid="add-mcp-server-button-page"]', {
timeout: 3000,
});
// Find and edit the server
await expect(page.getByText(testName)).toBeVisible({
timeout: 3000,
});
await page
.getByTestId(`mcp-server-menu-button-${testName}`)
.click({ timeout: 3000 });
await page
.getByText("Edit", { exact: true })
.first()
.click({ timeout: 3000 });
await page.waitForSelector('[data-testid="add-mcp-server-button"]', {
state: "visible",
timeout: 30000,
});
// Verify all fields persisted correctly
expect(await page.getByTestId("sse-name-input").inputValue()).toBe(
testName,
);
expect(await page.getByTestId("sse-url-input").inputValue()).toBe(testUrl);
expect(await page.getByTestId("sse-headers-key-0").inputValue()).toBe(
testHeaderKey1,
);
expect(await page.getByTestId("sse-headers-value-0").inputValue()).toBe(
testHeaderValue1,
);
expect(await page.getByTestId("sse-headers-key-1").inputValue()).toBe(
testHeaderKey2,
);
expect(await page.getByTestId("sse-headers-value-1").inputValue()).toBe(
testHeaderValue2,
);
expect(await page.getByTestId("sse-env-key-0").inputValue()).toBe(
testEnvKey1,
);
expect(await page.getByTestId("sse-env-value-0").inputValue()).toBe(
testEnvValue1,
);
expect(await page.getByTestId("sse-env-key-1").inputValue()).toBe(
testEnvKey2,
);
expect(await page.getByTestId("sse-env-value-1").inputValue()).toBe(
testEnvValue2,
);
// Clean up - cancel the edit modal
await page.keyboard.press("Escape");
// Delete the test server
await page
.getByTestId(`mcp-server-menu-button-${testName}`)
.click({ timeout: 3000 });
await page
.getByText("Delete", { exact: true })
.first()
.click({ timeout: 3000 });
await page.waitForSelector(
'[data-testid="btn_delete_delete_confirmation_modal"]',
{
timeout: 3000,
},
);
await page
.getByTestId("btn_delete_delete_confirmation_modal")
.click({ timeout: 3000 });
},
);