From 41e101499a5f885fffcbdc305e25f890c7ee2a67 Mon Sep 17 00:00:00 2001 From: Lucas Oliveira <62335616+lucaseduoli@users.noreply.github.com> Date: Fri, 25 Jul 2025 22:16:18 -0300 Subject: [PATCH] fix: store mcp sse headers and use them on connection (#9148) * Store mcp server headers * Add headers on pre check url and is valid url * adds validation of headers according to RFC 7230 * Fixed sanitized value * Added backend tests for mcp util.py to increase coverage * [autofix.ci] apply automated fixes * [autofix.ci] apply automated fixes (attempt 2/3) * make key pair input use flatmap id on data test ids * added testids * added random test names and added tests for persistence * fix ruff lint * [autofix.ci] apply automated fixes * Fix mypy lint errors --------- Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com> --- src/backend/base/langflow/base/mcp/util.py | 109 ++- .../components/data/test_mcp_component.py | 644 +++++++++++++++--- .../API/queries/mcp/use-add-mcp-server.ts | 3 + .../API/queries/mcp/use-patch-mcp-server.ts | 3 + .../IOFieldView/components/key-pair-input.tsx | 163 +++-- .../src/modals/addMcpServerModal/index.tsx | 6 + src/frontend/src/utils/mcpUtils.ts | 4 + .../extended/features/mcp-server.spec.ts | 366 +++++++++- 8 files changed, 1101 insertions(+), 197 deletions(-) diff --git a/src/backend/base/langflow/base/mcp/util.py b/src/backend/base/langflow/base/mcp/util.py index 2f35b312f..353b5d657 100644 --- a/src/backend/base/langflow/base/mcp/util.py +++ b/src/backend/base/langflow/base/mcp/util.py @@ -30,6 +30,86 @@ HTTP_NOT_FOUND = 404 HTTP_BAD_REQUEST = 400 HTTP_INTERNAL_SERVER_ERROR = 500 +# RFC 7230 compliant header name pattern: token = 1*tchar +# tchar = "!" / "#" / "$" / "%" / "&" / "'" / "*" / "+" / "-" / "." / +# "^" / "_" / "`" / "|" / "~" / DIGIT / ALPHA +HEADER_NAME_PATTERN = re.compile(r"^[!#$%&\'*+\-.0-9A-Z^_`a-z|~]+$") + +# Common allowed headers for MCP connections +ALLOWED_HEADERS = { + "authorization", + "accept", + "accept-encoding", + "accept-language", + "cache-control", + "content-type", + "user-agent", + "x-api-key", + "x-auth-token", + "x-custom-header", + "x-langflow-session", + "x-mcp-client", + "x-requested-with", +} + + +def validate_headers(headers: dict[str, str]) -> dict[str, str]: + """Validate and sanitize HTTP headers according to RFC 7230. + + Args: + headers: Dictionary of header name-value pairs + + Returns: + Dictionary of validated and sanitized headers + + Raises: + ValueError: If headers contain invalid names or values + """ + if not headers: + return {} + + sanitized_headers = {} + + for name, value in headers.items(): + if not isinstance(name, str) or not isinstance(value, str): + logger.warning(f"Skipping non-string header: {name}={value}") + continue + + # Validate header name according to RFC 7230 + if not HEADER_NAME_PATTERN.match(name): + logger.warning(f"Invalid header name '{name}', skipping") + continue + + # Normalize header name to lowercase (HTTP headers are case-insensitive) + normalized_name = name.lower() + + # Optional: Check against whitelist of allowed headers + if normalized_name not in ALLOWED_HEADERS: + # For MCP, we'll be permissive and allow non-standard headers + # but log a warning for security awareness + logger.debug(f"Using non-standard header: {normalized_name}") + + # Check for potential header injection attempts BEFORE sanitizing + if "\r" in value or "\n" in value: + logger.warning(f"Potential header injection detected in '{name}', skipping") + continue + + # Sanitize header value - remove control characters and newlines + # RFC 7230: field-value = *( field-content / obs-fold ) + # We'll remove control characters (0x00-0x1F, 0x7F) except tab (0x09) and space (0x20) + sanitized_value = re.sub(r"[\x00-\x08\x0A-\x1F\x7F]", "", value) + + # Remove leading/trailing whitespace + sanitized_value = sanitized_value.strip() + + if not sanitized_value: + logger.warning(f"Header '{name}' has empty value after sanitization, skipping") + continue + + sanitized_headers[normalized_name] = sanitized_value + + return sanitized_headers + def sanitize_mcp_name(name: str, max_length: int = 46) -> str: """Sanitize a name for MCP usage by removing emojis, diacritics, and special characters. @@ -334,12 +414,12 @@ def _process_headers(headers: Any) -> dict: Args: headers: The headers to process, can be dict, str, or list Returns: - Processed dictionary + Processed and validated dictionary """ if headers is None: return {} if isinstance(headers, dict): - return headers + return validate_headers(headers) if isinstance(headers, list): processed_headers = {} try: @@ -351,7 +431,7 @@ def _process_headers(headers: Any) -> dict: processed_headers[key] = value except (KeyError, TypeError, ValueError): return {} # Return empty dictionary instead of None - return processed_headers + return validate_headers(processed_headers) return {} @@ -924,7 +1004,7 @@ class MCPSseClient: self._component_cache.set("mcp_session_manager", session_manager) return session_manager - async def validate_url(self, url: str | None) -> tuple[bool, str]: + async def validate_url(self, url: str | None, headers: dict[str, str] | None = None) -> tuple[bool, str]: """Validate the SSE URL before attempting connection.""" try: parsed = urlparse(url) @@ -935,7 +1015,9 @@ class MCPSseClient: try: # For SSE endpoints, try a GET request with short timeout # Many SSE servers don't support HEAD requests and return 404 - response = await client.get(url, timeout=2.0, headers={"Accept": "text/event-stream"}) + response = await client.get( + url, timeout=2.0, headers={"Accept": "text/event-stream", **(headers or {})} + ) # For SSE, we expect the server to either: # 1. Start streaming (200) @@ -967,14 +1049,16 @@ class MCPSseClient: except (httpx.HTTPError, ValueError, OSError) as e: return False, f"URL validation error: {e!s}" - async def pre_check_redirect(self, url: str | None) -> str | None: + async def pre_check_redirect(self, url: str | None, headers: dict[str, str] | None = None) -> str | None: """Check for redirects and return the final URL.""" if url is None: return url try: async with httpx.AsyncClient(follow_redirects=False) as client: # Use GET with SSE headers instead of HEAD since many SSE servers don't support HEAD - response = await client.get(url, timeout=2.0, headers={"Accept": "text/event-stream"}) + response = await client.get( + url, timeout=2.0, headers={"Accept": "text/event-stream", **(headers or {})} + ) if response.status_code == httpx.codes.TEMPORARY_REDIRECT: return response.headers.get("Location", url) # Don't treat 404 as an error here - let the main connection handle it @@ -990,22 +1074,23 @@ class MCPSseClient: sse_read_timeout_seconds: int = 30, ) -> list[StructuredTool]: """Connect to MCP server using SSE transport (SDK style).""" - if headers is None: - headers = {} + # Validate and sanitize headers early + validated_headers = _process_headers(headers) + if url is None: msg = "URL is required for SSE mode" raise ValueError(msg) - is_valid, error_msg = await self.validate_url(url) + is_valid, error_msg = await self.validate_url(url, validated_headers) if not is_valid: msg = f"Invalid SSE URL ({url}): {error_msg}" raise ValueError(msg) - url = await self.pre_check_redirect(url) + url = await self.pre_check_redirect(url, validated_headers) # Store connection parameters for later use in run_tool self._connection_params = { "url": url, - "headers": headers, + "headers": validated_headers, "timeout_seconds": timeout_seconds, "sse_read_timeout_seconds": sse_read_timeout_seconds, } diff --git a/src/backend/tests/unit/components/data/test_mcp_component.py b/src/backend/tests/unit/components/data/test_mcp_component.py index 00529ca67..339eec4a6 100644 --- a/src/backend/tests/unit/components/data/test_mcp_component.py +++ b/src/backend/tests/unit/components/data/test_mcp_component.py @@ -1,14 +1,12 @@ -import asyncio from unittest.mock import AsyncMock, MagicMock, patch import pytest -from langflow.components.agents.mcp_component import MCPSseClient, MCPStdioClient, MCPToolsComponent +from langflow.base.mcp import util +from langflow.base.mcp.util import MCPSessionManager, MCPSseClient, MCPStdioClient, _process_headers, validate_headers +from langflow.components.agents.mcp_component import MCPToolsComponent from tests.base import ComponentTestBaseWithoutClient, VersionComponentMapping -# TODO: This test suite is incomplete and is in need of an update to handle the latest MCP component changes. -pytestmark = pytest.mark.skip(reason="Skipping entire file") - class TestMCPToolsComponent(ComponentTestBaseWithoutClient): @pytest.fixture @@ -20,9 +18,7 @@ class TestMCPToolsComponent(ComponentTestBaseWithoutClient): def default_kwargs(self): """Return the default kwargs for the component.""" return { - "mode": "Stdio", - "command": "uvx mcp-server-fetch", - "sse_url": "http://localhost:7860/api/v1/mcp/sse", + "mcp_server": {"name": "test_server", "config": {"command": "uvx mcp-server-fetch"}}, "tool": "", } @@ -40,24 +36,22 @@ class TestMCPToolsComponent(ComponentTestBaseWithoutClient): tool.inputSchema = { "type": "object", "properties": {"test_param": {"type": "string", "description": "Test parameter"}}, + "required": ["test_param"], } return tool @pytest.fixture - def mock_stdio_client(self, mock_tool): - """Create a mock stdio client.""" - stdio_client = AsyncMock() - stdio_client.connect_to_server = AsyncMock(return_value=[mock_tool]) - stdio_client.session = AsyncMock() - return stdio_client - - @pytest.fixture - def mock_sse_client(self, mock_tool): - """Create a mock SSE client.""" - sse_client = AsyncMock() - sse_client.connect_to_server = AsyncMock(return_value=[mock_tool]) - sse_client.session = AsyncMock() - return sse_client + def mock_session(self, mock_tool): + """Create a mock ClientSession.""" + session = AsyncMock() + session.initialize = AsyncMock() + list_tools_result = MagicMock() + list_tools_result.tools = [mock_tool] + session.list_tools = AsyncMock(return_value=list_tools_result) + session.call_tool = AsyncMock( + return_value=MagicMock(content=[MagicMock(model_dump=lambda: {"result": "success"})]) + ) + return session class TestMCPStdioClient: @@ -65,40 +59,69 @@ class TestMCPStdioClient: def stdio_client(self): return MCPStdioClient() - async def test_connect_to_server(self, stdio_client): - """Test connecting to server via Stdio.""" - # Create mock for stdio transport - mock_stdio = AsyncMock() - mock_write = AsyncMock() - mock_stdio_transport = (mock_stdio, mock_write) - mock_stdio_cm = AsyncMock() - mock_stdio_cm.__aenter__.return_value = mock_stdio_transport + @pytest.fixture + def mock_session_manager(self): + """Create a mock session manager.""" + return AsyncMock(spec=MCPSessionManager) - # Mock the stdio_client function to return our mock context manager - with patch("mcp.client.stdio.stdio_client", return_value=mock_stdio_cm): - # Mock ClientSession + async def test_connect_to_server_with_command(self, stdio_client): + """Test connecting to server via Stdio with command.""" + with patch.object(stdio_client, "_get_or_create_session") as mock_get_session: + # Mock session mock_session = AsyncMock() - mock_session.initialize = AsyncMock() - mock_session.list_tools.return_value.tools = [MagicMock()] + mock_tool = MagicMock() + mock_tool.name = "test_tool" + list_tools_result = MagicMock() + list_tools_result.tools = [mock_tool] + mock_session.list_tools = AsyncMock(return_value=list_tools_result) + mock_get_session.return_value = mock_session - # Mock the AsyncExitStack - mock_exit_stack = AsyncMock() - mock_exit_stack.enter_async_context = AsyncMock() - mock_exit_stack.enter_async_context.side_effect = [ - mock_stdio_transport, # For stdio_client - mock_session, # For ClientSession - ] - stdio_client.exit_stack = mock_exit_stack - - tools = await stdio_client.connect_to_server("test_command") + tools = await stdio_client.connect_to_server("uvx test-command") assert len(tools) == 1 - assert stdio_client.session is not None - # Verify the exit stack was used correctly - assert mock_exit_stack.enter_async_context.call_count == 2 - # Verify the stdio transport was properly set - assert stdio_client.stdio == mock_stdio - assert stdio_client.write == mock_write + assert tools[0].name == "test_tool" + assert stdio_client._connected is True + assert stdio_client._connection_params is not None + + async def test_run_tool_success(self, stdio_client): + """Test successfully running a tool.""" + # Setup connection state + stdio_client._connected = True + stdio_client._connection_params = MagicMock() + stdio_client._session_context = "test_context" + + with patch.object(stdio_client, "_get_or_create_session") as mock_get_session: + mock_session = AsyncMock() + mock_result = MagicMock() + mock_session.call_tool = AsyncMock(return_value=mock_result) + mock_get_session.return_value = mock_session + + result = await stdio_client.run_tool("test_tool", {"param": "value"}) + + assert result == mock_result + mock_session.call_tool.assert_called_once_with("test_tool", arguments={"param": "value"}) + + async def test_run_tool_without_connection(self, stdio_client): + """Test running a tool without being connected.""" + stdio_client._connected = False + + with pytest.raises(ValueError, match="Session not initialized"): + await stdio_client.run_tool("test_tool", {}) + + async def test_disconnect_cleanup(self, stdio_client): + """Test that disconnect properly cleans up resources.""" + stdio_client._session_context = "test_context" + stdio_client._connected = True + + with patch.object(stdio_client, "_get_session_manager") as mock_get_manager: + mock_manager = AsyncMock() + mock_get_manager.return_value = mock_manager + + await stdio_client.disconnect() + + mock_manager._cleanup_session.assert_called_once_with("test_context") + assert stdio_client.session is None + assert stdio_client._connected is False class TestMCPSseClient: @@ -106,78 +129,487 @@ class TestMCPSseClient: def sse_client(self): return MCPSseClient() - async def test_pre_check_redirect(self, sse_client): - """Test pre-checking URL for redirects.""" + async def test_validate_url_valid(self, sse_client): + """Test URL validation with valid URL.""" + with patch("httpx.AsyncClient") as mock_client: + mock_response = MagicMock() + mock_response.status_code = 200 + mock_client.return_value.__aenter__.return_value.get.return_value = mock_response + + is_valid, error_msg = await sse_client.validate_url("http://test.url", {}) + + assert is_valid is True + assert error_msg == "" + + async def test_validate_url_invalid_format(self, sse_client): + """Test URL validation with invalid format.""" + is_valid, error_msg = await sse_client.validate_url("invalid-url", {}) + + assert is_valid is False + assert "Invalid URL format" in error_msg + + async def test_validate_url_with_404_response(self, sse_client): + """Test URL validation with 404 response (should be valid for SSE).""" + with patch("httpx.AsyncClient") as mock_client: + mock_response = MagicMock() + mock_response.status_code = 404 + mock_client.return_value.__aenter__.return_value.get.return_value = mock_response + + is_valid, error_msg = await sse_client.validate_url("http://test.url", {}) + + assert is_valid is True + assert error_msg == "" + + async def test_connect_to_server_with_headers(self, sse_client): + """Test connecting to server via SSE with custom headers.""" + test_url = "http://test.url" + test_headers = {"Authorization": "Bearer token123", "Custom-Header": "value"} + expected_headers = {"authorization": "Bearer token123", "custom-header": "value"} # normalized + + with ( + patch.object(sse_client, "validate_url", return_value=(True, "")), + patch.object(sse_client, "pre_check_redirect", return_value=test_url), + patch.object(sse_client, "_get_or_create_session") as mock_get_session, + ): + # Mock session + mock_session = AsyncMock() + mock_tool = MagicMock() + mock_tool.name = "test_tool" + list_tools_result = MagicMock() + list_tools_result.tools = [mock_tool] + mock_session.list_tools = AsyncMock(return_value=list_tools_result) + mock_get_session.return_value = mock_session + + tools = await sse_client.connect_to_server(test_url, test_headers) + + assert len(tools) == 1 + assert tools[0].name == "test_tool" + assert sse_client._connected is True + + # Verify headers are stored in connection params (normalized) + assert sse_client._connection_params is not None + assert sse_client._connection_params["headers"] == expected_headers + assert sse_client._connection_params["url"] == test_url + + async def test_headers_passed_to_session_manager(self, sse_client): + """Test that headers are properly passed to the session manager.""" + test_url = "http://test.url" + expected_headers = {"authorization": "Bearer token123", "x-api-key": "secret"} # normalized + + sse_client._session_context = "test_context" + sse_client._connection_params = { + "url": test_url, + "headers": expected_headers, # Use normalized headers + "timeout_seconds": 30, + "sse_read_timeout_seconds": 30, + } + + with patch.object(sse_client, "_get_session_manager") as mock_get_manager: + mock_manager = AsyncMock() + mock_session = AsyncMock() + mock_manager.get_session = AsyncMock(return_value=mock_session) + mock_get_manager.return_value = mock_manager + + result_session = await sse_client._get_or_create_session() + + # Verify session manager was called with correct parameters including normalized headers + mock_manager.get_session.assert_called_once_with("test_context", sse_client._connection_params, "sse") + assert result_session == mock_session + + async def test_pre_check_redirect_with_headers(self, sse_client): + """Test pre-check redirect functionality with custom headers.""" test_url = "http://test.url" redirect_url = "http://redirect.url" + # Use pre-validated headers since pre_check_redirect expects already validated headers + test_headers = {"authorization": "Bearer token123"} # already normalized with patch("httpx.AsyncClient") as mock_client: mock_response = MagicMock() mock_response.status_code = 307 mock_response.headers.get.return_value = redirect_url - mock_client.return_value.__aenter__.return_value.request.return_value = mock_response + mock_client.return_value.__aenter__.return_value.get.return_value = mock_response + + result = await sse_client.pre_check_redirect(test_url, test_headers) - result = await sse_client.pre_check_redirect(test_url) assert result == redirect_url + # Verify validated headers were passed to the request + mock_client.return_value.__aenter__.return_value.get.assert_called_with( + test_url, timeout=2.0, headers={"Accept": "text/event-stream", **test_headers} + ) + + async def test_run_tool_with_retry_on_connection_error(self, sse_client): + """Test that run_tool retries on connection errors.""" + # Setup connection state + sse_client._connected = True + sse_client._connection_params = {"url": "http://test.url", "headers": {}} + sse_client._session_context = "test_context" + + call_count = 0 + + async def mock_get_session_side_effect(): + nonlocal call_count + call_count += 1 + session = AsyncMock() + if call_count == 1: + # First call fails with connection error + from anyio import ClosedResourceError + + session.call_tool = AsyncMock(side_effect=ClosedResourceError()) + else: + # Second call succeeds + mock_result = MagicMock() + session.call_tool = AsyncMock(return_value=mock_result) + return session + + with ( + patch.object(sse_client, "_get_or_create_session", side_effect=mock_get_session_side_effect), + patch.object(sse_client, "_get_session_manager") as mock_get_manager, + ): + mock_manager = AsyncMock() + mock_get_manager.return_value = mock_manager + + result = await sse_client.run_tool("test_tool", {"param": "value"}) + + # Should have retried and succeeded on second attempt + assert call_count == 2 + assert result is not None + # Should have cleaned up the failed session + mock_manager._cleanup_session.assert_called_once_with("test_context") + + +class TestMCPSessionManager: + @pytest.fixture + def session_manager(self): + return MCPSessionManager() + + async def test_session_caching(self, session_manager): + """Test that sessions are properly cached and reused.""" + context_id = "test_context" + connection_params = MagicMock() + transport_type = "stdio" + + # Create a mock session that will appear healthy + mock_session = AsyncMock() + mock_session._write_stream = MagicMock() + mock_session._write_stream._closed = False + + # Create a mock task that appears to be running + mock_task = AsyncMock() + mock_task.done = MagicMock(return_value=False) + + with ( + patch.object(session_manager, "_create_stdio_session") as mock_create, + patch.object(session_manager, "_validate_session_connectivity", return_value=True), + ): + mock_create.return_value = mock_session + + # First call should create session + session1 = await session_manager.get_session(context_id, connection_params, transport_type) + + # Manually populate the sessions cache as if the session was created properly + session_manager.sessions[context_id] = {"session": mock_session, "task": mock_task, "type": transport_type} + + # Second call should return cached session without creating new one + session2 = await session_manager.get_session(context_id, connection_params, transport_type) + + assert session1 == session2 + assert session1 == mock_session + # Should only create once since the second call should use the cached session + mock_create.assert_called_once() + + async def test_session_cleanup(self, session_manager): + """Test session cleanup functionality.""" + context_id = "test_context" + + # Add a session to the manager with proper mock setup + mock_task = AsyncMock() + mock_task.done = MagicMock(return_value=False) # Use MagicMock for sync method + mock_task.cancel = MagicMock() # Use MagicMock for sync method + + session_manager.sessions[context_id] = {"session": AsyncMock(), "task": mock_task, "type": "stdio"} + + await session_manager._cleanup_session(context_id) + + # Should cancel the task and remove from sessions + mock_task.cancel.assert_called_once() + assert context_id not in session_manager.sessions + + async def test_server_switch_detection(self, session_manager): + """Test that server switches are properly detected and handled.""" + context_id = "test_context" + + # First server + server1_params = MagicMock() + server1_params.command = "server1" + + # Second server + server2_params = MagicMock() + server2_params.command = "server2" + + with ( + patch.object(session_manager, "_create_stdio_session") as mock_create, + patch.object(session_manager, "_validate_session_connectivity", return_value=True), + ): + mock_session1 = AsyncMock() + mock_session2 = AsyncMock() + mock_create.side_effect = [mock_session1, mock_session2] + + # First connection + session1 = await session_manager.get_session(context_id, server1_params, "stdio") + + # Switch to different server should create new session + session2 = await session_manager.get_session(context_id, server2_params, "stdio") + + assert session1 != session2 + assert mock_create.call_count == 2 + + +# Integration test for header functionality +class TestHeaderValidation: + """Test the header validation functionality.""" + + def test_validate_headers_valid_input(self): + """Test header validation with valid headers.""" + headers = {"Authorization": "Bearer token123", "Content-Type": "application/json", "X-API-Key": "secret-key"} + + result = validate_headers(headers) + + # Headers should be normalized to lowercase + expected = {"authorization": "Bearer token123", "content-type": "application/json", "x-api-key": "secret-key"} + assert result == expected + + def test_validate_headers_empty_input(self): + """Test header validation with empty/None input.""" + assert validate_headers({}) == {} + assert validate_headers(None) == {} + + def test_validate_headers_invalid_names(self): + """Test header validation with invalid header names.""" + headers = { + "Invalid Header": "value", # spaces not allowed + "Header@Name": "value", # @ not allowed + "Header Name": "value", # spaces not allowed + "Valid-Header": "value", # this should pass + } + + result = validate_headers(headers) + + # Only the valid header should remain + assert result == {"valid-header": "value"} + + def test_validate_headers_sanitize_values(self): + """Test header value sanitization.""" + headers = { + "Authorization": "Bearer \x00token\x1f with\r\ninjection", + "Clean-Header": " clean value ", + "Empty-After-Clean": "\x00\x01\x02", + "Tab-Header": "value\twith\ttabs", # tabs should be preserved + } + + result = validate_headers(headers) + + # Control characters should be removed, whitespace trimmed + # Header with injection attempts should be skipped + expected = {"clean-header": "clean value", "tab-header": "value\twith\ttabs"} + assert result == expected + + def test_validate_headers_non_string_values(self): + """Test header validation with non-string values.""" + headers = {"String-Header": "valid", "Number-Header": 123, "None-Header": None, "List-Header": ["value"]} + + result = validate_headers(headers) + + # Only string headers should remain + assert result == {"string-header": "valid"} + + def test_validate_headers_injection_attempts(self): + """Test header validation against injection attempts.""" + headers = { + "Injection1": "value\r\nInjected-Header: malicious", + "Injection2": "value\nX-Evil: attack", + "Safe-Header": "safe-value", + } + + result = validate_headers(headers) + + # Injection attempts should be filtered out + assert result == {"safe-header": "safe-value"} + + +class TestSSEHeaderIntegration: + """Integration test to verify headers are properly passed through the entire SSE flow.""" + + async def test_headers_processing(self): + """Test that headers flow properly from server config through to SSE client connection.""" + # Test the header processing function directly + headers_input = [ + {"key": "Authorization", "value": "Bearer test-token"}, + {"key": "X-API-Key", "value": "secret-key"}, + ] + + expected_headers = { + "authorization": "Bearer test-token", # normalized to lowercase + "x-api-key": "secret-key", + } + + # Test _process_headers function with validation + processed_headers = _process_headers(headers_input) + assert processed_headers == expected_headers + + # Test different input formats + # Test dict input with validation + dict_headers = {"Authorization": "Bearer dict-token", "Invalid Header": "bad"} + result = _process_headers(dict_headers) + # Invalid header should be filtered out, valid header normalized + assert result == {"authorization": "Bearer dict-token"} + + # Test None input + assert _process_headers(None) == {} + + # Test empty list + assert _process_headers([]) == {} + + # Test malformed list + malformed_headers = [{"key": "Auth"}, {"value": "token"}] # Missing value/key + assert _process_headers(malformed_headers) == {} + + # Test list with invalid header names + invalid_headers = [ + {"key": "Valid-Header", "value": "good"}, + {"key": "Invalid Header", "value": "bad"}, # spaces not allowed + ] + result = _process_headers(invalid_headers) + assert result == {"valid-header": "good"} + + async def test_sse_client_header_storage(self): + """Test that SSE client properly stores headers in connection params.""" + sse_client = MCPSseClient() + test_url = "http://test.url" + test_headers = {"Authorization": "Bearer test123", "Custom": "value"} + expected_headers = {"authorization": "Bearer test123", "custom": "value"} # normalized - async def test_connect_to_server(self, sse_client): - """Test connecting to server via SSE.""" - # Mock the pre_check_redirect first with ( - patch.object(sse_client, "pre_check_redirect", return_value="http://test.url"), patch.object(sse_client, "validate_url", return_value=(True, "")), + patch.object(sse_client, "pre_check_redirect", return_value=test_url), + patch.object(sse_client, "_get_or_create_session") as mock_get_session, ): - # Create mock for sse_client context manager - mock_sse = AsyncMock() - mock_write = AsyncMock() - mock_sse_transport = (mock_sse, mock_write) - mock_sse_cm = AsyncMock() - mock_sse_cm.__aenter__.return_value = mock_sse_transport + mock_session = AsyncMock() + mock_tool = MagicMock() + mock_tool.name = "test_tool" + list_tools_result = MagicMock() + list_tools_result.tools = [mock_tool] + mock_session.list_tools = AsyncMock(return_value=list_tools_result) + mock_get_session.return_value = mock_session - # Mock the sse_client function to return our mock context manager - with patch("mcp.client.sse.sse_client", return_value=mock_sse_cm): - # Mock ClientSession - mock_session = AsyncMock() - mock_session.initialize = AsyncMock() - mock_session.list_tools.return_value.tools = [MagicMock()] + await sse_client.connect_to_server(test_url, test_headers) - # Mock the AsyncExitStack - mock_exit_stack = AsyncMock() - mock_exit_stack.enter_async_context = AsyncMock() - mock_exit_stack.enter_async_context.side_effect = [ - mock_sse_transport, # For sse_client - mock_session, # For ClientSession - ] - sse_client.exit_stack = mock_exit_stack + # Verify headers are stored correctly in connection params (normalized) + assert sse_client._connection_params is not None + assert sse_client._connection_params["headers"] == expected_headers + assert sse_client._connection_params["url"] == test_url - tools = await sse_client.connect_to_server("http://test.url", {}) - assert len(tools) == 1 - assert sse_client.session is not None - # Verify the exit stack was used correctly - assert mock_exit_stack.enter_async_context.call_count == 2 - # Verify the SSE transport was properly set - assert sse_client.sse == mock_sse - assert sse_client.write == mock_write +class TestMCPUtilityFunctions: + """Test utility functions from util.py that don't have dedicated test classes.""" - async def test_connect_timeout(self, sse_client): - """Test connection timeout handling.""" - # Set max_retries to 1 to avoid multiple retry attempts - sse_client.max_retries = 1 + def test_sanitize_mcp_name(self): + """Test MCP name sanitization.""" + assert util.sanitize_mcp_name("Test Name 123") == "test_name_123" + assert util.sanitize_mcp_name(" ") == "" + assert util.sanitize_mcp_name("123abc") == "_123abc" + assert util.sanitize_mcp_name("Tést-😀-Námé") == "test_name" + assert util.sanitize_mcp_name("a" * 100) == "a" * 46 - with ( - patch.object(sse_client, "pre_check_redirect", return_value="http://test.url"), - patch.object(sse_client, "validate_url", return_value=(True, "")), # Mock URL validation - patch.object(sse_client, "_connect_with_timeout") as mock_connect, - ): - mock_connect.side_effect = asyncio.TimeoutError() + def test_get_unique_name(self): + """Test unique name generation.""" + names = {"foo", "foo_1"} + assert util.get_unique_name("foo", 10, names) == "foo_2" + assert util.get_unique_name("bar", 10, names) == "bar" + assert util.get_unique_name("longname", 4, {"long"}) == "lo_1" - # Expect ConnectionError instead of TimeoutError - with pytest.raises( - ConnectionError, - match=( - "Failed to connect after 1 attempts. " - "Last error: Connection to http://test.url timed out after 1 seconds" - ), - ): - await sse_client.connect_to_server("http://test.url", {}, timeout_seconds=1) + def test_is_valid_key_value_item(self): + """Test key-value item validation.""" + assert util._is_valid_key_value_item({"key": "a", "value": "b"}) is True + assert util._is_valid_key_value_item({"key": "a"}) is False + assert util._is_valid_key_value_item(["key", "value"]) is False + assert util._is_valid_key_value_item(None) is False + + def test_validate_node_installation(self): + """Test Node.js installation validation.""" + import shutil + + if shutil.which("node"): + assert util._validate_node_installation("npx something") == "npx something" + else: + with pytest.raises(ValueError, match="Node.js is not installed"): + util._validate_node_installation("npx something") + assert util._validate_node_installation("echo test") == "echo test" + + def test_create_input_schema_from_json_schema(self): + """Test JSON schema to Pydantic model conversion.""" + schema = { + "type": "object", + "properties": { + "foo": {"type": "string", "description": "desc"}, + "bar": {"type": "integer"}, + }, + "required": ["foo"], + } + model_class = util.create_input_schema_from_json_schema(schema) + instance = model_class(foo="abc", bar=1) + assert instance.foo == "abc" + assert instance.bar == 1 + + with pytest.raises(Exception): # noqa: B017, PT011 + model_class(bar=1) # missing required field + + @pytest.mark.asyncio + async def test_validate_connection_params(self): + """Test connection parameter validation.""" + # Valid parameters + await util._validate_connection_params("Stdio", command="echo test") + await util._validate_connection_params("SSE", url="http://test") + + # Invalid parameters + with pytest.raises(ValueError, match="Command is required for Stdio mode"): + await util._validate_connection_params("Stdio", command=None) + with pytest.raises(ValueError, match="URL is required for SSE mode"): + await util._validate_connection_params("SSE", url=None) + with pytest.raises(ValueError, match="Invalid mode"): + await util._validate_connection_params("InvalidMode") + + @pytest.mark.asyncio + async def test_get_flow_snake_case_mocked(self): + """Test flow lookup by snake case name with mocked session.""" + + class DummyFlow: + def __init__(self, name: str, user_id: str, *, is_component: bool = False, action_name: str | None = None): + self.name = name + self.user_id = user_id + self.is_component = is_component + self.action_name = action_name + + class DummyExec: + def __init__(self, flows: list[DummyFlow]): + self._flows = flows + + def all(self): + return self._flows + + class DummySession: + def __init__(self, flows: list[DummyFlow]): + self._flows = flows + + async def exec(self, stmt): # noqa: ARG002 + return DummyExec(self._flows) + + user_id = "123e4567-e89b-12d3-a456-426614174000" + flows = [DummyFlow("Test Flow", user_id), DummyFlow("Other", user_id)] + + # Should match sanitized name + result = await util.get_flow_snake_case(util.sanitize_mcp_name("Test Flow"), user_id, DummySession(flows)) + assert result is flows[0] + + # Should return None if not found + result = await util.get_flow_snake_case("notfound", user_id, DummySession(flows)) + assert result is None diff --git a/src/frontend/src/controllers/API/queries/mcp/use-add-mcp-server.ts b/src/frontend/src/controllers/API/queries/mcp/use-add-mcp-server.ts index 3bf5b34c2..3fe626f8e 100644 --- a/src/frontend/src/controllers/API/queries/mcp/use-add-mcp-server.ts +++ b/src/frontend/src/controllers/API/queries/mcp/use-add-mcp-server.ts @@ -34,6 +34,9 @@ export const useAddMCPServer: useMutationFunctionType< if (body.env && Object.keys(body.env).length > 0) { payload.env = body.env; } + if (body.headers && Object.keys(body.headers).length > 0) { + payload.headers = body.headers; + } const res = await api.post( `${getURL("MCP_SERVERS", undefined, true)}/${body.name}`, diff --git a/src/frontend/src/controllers/API/queries/mcp/use-patch-mcp-server.ts b/src/frontend/src/controllers/API/queries/mcp/use-patch-mcp-server.ts index 8eff3eae8..d2c4bcc75 100644 --- a/src/frontend/src/controllers/API/queries/mcp/use-patch-mcp-server.ts +++ b/src/frontend/src/controllers/API/queries/mcp/use-patch-mcp-server.ts @@ -34,6 +34,9 @@ export const usePatchMCPServer: useMutationFunctionType< if (body.env && Object.keys(body.env).length > 0) { payload.env = body.env; } + if (body.headers && Object.keys(body.headers).length > 0) { + payload.headers = body.headers; + } const res = await api.patch( `${getURL("MCP_SERVERS", undefined, true)}/${body.name}`, diff --git a/src/frontend/src/modals/IOModal/components/IOFieldView/components/key-pair-input.tsx b/src/frontend/src/modals/IOModal/components/IOFieldView/components/key-pair-input.tsx index 476259092..5cf856563 100644 --- a/src/frontend/src/modals/IOModal/components/IOFieldView/components/key-pair-input.tsx +++ b/src/frontend/src/modals/IOModal/components/IOFieldView/components/key-pair-input.tsx @@ -1,5 +1,5 @@ import _ from "lodash"; -import { useEffect, useState } from "react"; +import { useCallback, useEffect, useState } from "react"; import IconComponent from "../../../../../components/common/genericIconComponent"; import { Input } from "../../../../../components/ui/input"; import { classNames } from "../../../../../utils/utils"; @@ -10,6 +10,7 @@ export type IOKeyPairInputProps = { duplicateKey: boolean; isList: boolean; isInputField?: boolean; + testId?: string; }; const IOKeyPairInput = ({ @@ -18,10 +19,11 @@ const IOKeyPairInput = ({ duplicateKey, isList = true, isInputField, + testId, }: IOKeyPairInputProps) => { - const checkValueType = (value) => { + const checkValueType = useCallback((value) => { return Array.isArray(value) ? value : [value]; - }; + }, []); const [currentData, setCurrentData] = useState(() => { return !value || value?.length === 0 ? [{ "": "" }] : checkValueType(value); @@ -32,90 +34,105 @@ const IOKeyPairInput = ({ const newData = !value || value?.length === 0 ? [{ "": "" }] : checkValueType(value); setCurrentData(newData); - }, [value]); + }, [value, checkValueType]); - const handleChangeKey = (event, idx) => { - const oldKey = Object.keys(currentData[idx])[0]; - const updatedObj = { [event.target.value]: currentData[idx][oldKey] }; + const handleChangeKey = (event, objIndex) => { + const oldKey = Object.keys(currentData[objIndex])[0]; + const updatedObj = { [event.target.value]: currentData[objIndex][oldKey] }; const newData = [...currentData]; - newData[idx] = updatedObj; + newData[objIndex] = updatedObj; setCurrentData(newData); onChange(newData); }; - const handleChangeValue = (newValue, idx) => { - const key = Object.keys(currentData[idx])[0]; + const handleChangeValue = (newValue, objIndex) => { + const key = Object.keys(currentData[objIndex])[0]; const newData = [...currentData]; - newData[idx] = { ...newData[idx], [key]: newValue }; + newData[objIndex] = { ...newData[objIndex], [key]: newValue }; setCurrentData(newData); onChange(newData); }; + // Create flat array with additional metadata for rendering + const flattenedData = + currentData?.flatMap((obj, objIndex) => { + return Object.keys(obj).map((key) => ({ + key, + value: obj[key], + objIndex, + uniqueId: `${objIndex}-${key}`, // Create unique identifier for React key + })); + }) || []; + return ( - <> -
- {currentData?.map((obj, index) => { - return Object.keys(obj).map((key, idx) => { - return ( -
- handleChangeKey(event, index)} - disabled={!isInputField} - /> +
+ {flattenedData.map((item, idx) => { + return ( +
+ handleChangeKey(event, item.objIndex)} + disabled={!isInputField} + data-testid={testId ? `${testId}-key-${idx}` : undefined} + /> - - handleChangeValue(event.target.value, index) - } - disabled={!isInputField} - /> + + handleChangeValue(event.target.value, item.objIndex) + } + disabled={!isInputField} + data-testid={testId ? `${testId}-value-${idx}` : undefined} + /> - {isList && isInputField && index === currentData.length - 1 ? ( - - ) : isList && isInputField ? ( - - ) : ( - "" - )} -
- ); - }); - })} -
- + {isList && + isInputField && + item.objIndex === currentData.length - 1 ? ( + + ) : isList && isInputField ? ( + + ) : ( + "" + )} +
+ ); + })} +
); }; diff --git a/src/frontend/src/modals/addMcpServerModal/index.tsx b/src/frontend/src/modals/addMcpServerModal/index.tsx index d56c30bfd..bc3892a5c 100644 --- a/src/frontend/src/modals/addMcpServerModal/index.tsx +++ b/src/frontend/src/modals/addMcpServerModal/index.tsx @@ -358,6 +358,7 @@ export default function AddMcpServerModal({ listAddLabel="Add Argument" editNode={false} id="stdio-args" + data-testid="stdio-args-input" />
@@ -368,6 +369,7 @@ export default function AddMcpServerModal({ duplicateKey={false} isList={true} isInputField={true} + testId="stdio-env" />
@@ -382,6 +384,7 @@ export default function AddMcpServerModal({ value={sseName} onChange={(e) => setSseName(e.target.value)} placeholder="Name" + data-testid="sse-name-input" disabled={isPending} /> @@ -393,6 +396,7 @@ export default function AddMcpServerModal({ value={sseUrl} onChange={(e) => setSseUrl(e.target.value)} placeholder="SSE URL" + data-testid="sse-url-input" disabled={isPending} /> @@ -404,6 +408,7 @@ export default function AddMcpServerModal({ duplicateKey={false} isList={true} isInputField={true} + testId="sse-headers" />
@@ -414,6 +419,7 @@ export default function AddMcpServerModal({ duplicateKey={false} isList={true} isInputField={true} + testId="sse-env" />
diff --git a/src/frontend/src/utils/mcpUtils.ts b/src/frontend/src/utils/mcpUtils.ts index f1a109493..f14f4bf7a 100644 --- a/src/frontend/src/utils/mcpUtils.ts +++ b/src/frontend/src/utils/mcpUtils.ts @@ -95,5 +95,9 @@ export function extractMcpServersFromJson( args: server.args || [], env: server.env && typeof server.env === "object" ? server.env : {}, url: server.url, + headers: + server.headers && typeof server.headers === "object" + ? server.headers + : {}, })); } diff --git a/src/frontend/tests/extended/features/mcp-server.spec.ts b/src/frontend/tests/extended/features/mcp-server.spec.ts index e0cef158a..c46949655 100644 --- a/src/frontend/tests/extended/features/mcp-server.spec.ts +++ b/src/frontend/tests/extended/features/mcp-server.spec.ts @@ -54,7 +54,9 @@ test( timeout: 30000, }); - await page.getByTestId("stdio-name-input").fill("test server"); + const randomSuffix = Math.floor(Math.random() * 90000) + 10000; // 5-digit random number + const testName = `test_server_${randomSuffix}`; + await page.getByTestId("stdio-name-input").fill(testName); await page.getByTestId("stdio-command-input").fill("uvx mcp-server-fetch"); @@ -115,12 +117,12 @@ test( timeout: 3000, }); - await expect(page.getByText("test_server")).toBeVisible({ + await expect(page.getByText(testName)).toBeVisible({ timeout: 3000, }); await page - .getByTestId(`mcp-server-menu-button-test_server`) + .getByTestId(`mcp-server-menu-button-${testName}`) .click({ timeout: 3000 }); await page @@ -152,7 +154,7 @@ test( await page.getByTestId("add-mcp-server-button").click(); await page - .getByTestId(`mcp-server-menu-button-test_server`) + .getByTestId(`mcp-server-menu-button-${testName}`) .click({ timeout: 3000 }); await page @@ -177,7 +179,7 @@ test( await page.waitForTimeout(3000); - await expect(page.getByText("test_server")).not.toBeVisible({ + await expect(page.getByText(testName)).not.toBeVisible({ timeout: 3000, }); @@ -199,8 +201,360 @@ test( }); await page.getByTestId("mcp-server-dropdown").click({ timeout: 10000 }); - await expect(page.getByText("test_server")).toHaveCount(2, { + await expect(page.getByText(testName)).toHaveCount(2, { timeout: 10000, }); }, ); + +test( + "STDIO MCP server fields should persist after saving and editing", + { tag: ["@release", "@workspace", "@components"] }, + async ({ page }) => { + await awaitBootstrapTest(page); + + await page.waitForSelector('[data-testid="blank-flow"]', { + timeout: 30000, + }); + await page.getByTestId("blank-flow").click(); + await page.getByTestId("sidebar-search-input").click(); + await page.getByTestId("sidebar-search-input").fill("mcp tools"); + + await page.waitForSelector('[data-testid="agentsMCP Tools"]', { + timeout: 30000, + }); + + await page + .getByTestId("agentsMCP Tools") + .dragTo(page.locator('//*[@id="react-flow-id"]'), { + targetPosition: { x: 0, y: 0 }, + }); + + await page.getByTestId("fit_view").click(); + + await zoomOut(page, 3); + + try { + await page.getByText("Add MCP Server", { exact: true }).click({ + timeout: 5000, + }); + } catch (_error) { + await page.getByTestId("mcp-server-dropdown").click({ timeout: 3000 }); + await page.getByText("Add MCP Server", { exact: true }).click({ + timeout: 5000, + }); + } + + await page.waitForSelector('[data-testid="add-mcp-server-button"]', { + state: "visible", + timeout: 30000, + }); + + // Go to STDIO tab and fill all fields + await page.getByTestId("stdio-tab").click(); + await page.waitForSelector('[data-testid="stdio-name-input"]', { + state: "visible", + timeout: 30000, + }); + + // Test data with random suffix + const randomSuffix = Math.floor(Math.random() * 90000) + 10000; // 5-digit random number + const testName = `test_stdio_server_${randomSuffix}`; + const testCommand = "uvx mcp-server-test"; + const testArg1 = "--verbose"; + const testArg2 = "--port=8080"; + const testArg3 = "--config=test.json"; + const testEnvKey1 = "NODE_ENV"; + const testEnvValue1 = "production"; + const testEnvKey2 = "DEBUG_MODE"; + const testEnvValue2 = "true"; + + // Fill basic fields + await page.getByTestId("stdio-name-input").fill(testName); + await page.getByTestId("stdio-command-input").fill(testCommand); + + // Add first argument + await page.getByTestId("stdio-args_0").fill(testArg1); + + // Add second argument by clicking plus button + await page.getByTestId("input-list-plus-btn_-0").click(); + await page.getByTestId("stdio-args_1").fill(testArg2); + + // Add third argument + await page.getByTestId("input-list-plus-btn_-0").click(); + await page.getByTestId("stdio-args_2").fill(testArg3); + + // Add first environment variable + await page.getByTestId("stdio-env-key-0").fill(testEnvKey1); + await page.getByTestId("stdio-env-value-0").fill(testEnvValue1); + + // Add second environment variable + await page.getByTestId("stdio-env-plus-btn-0").click(); + await page.getByTestId("stdio-env-key-1").fill(testEnvKey2); + await page.getByTestId("stdio-env-value-1").fill(testEnvValue2); + + // Save the server + await page.getByTestId("add-mcp-server-button").click(); + + // Wait for server to be created + await page.waitForTimeout(2000); + + // Go to settings to edit the server + await page.getByTestId("user_menu_button").click({ timeout: 3000 }); + await page.getByTestId("menu_settings_button").click({ timeout: 3000 }); + + await page.waitForSelector('[data-testid="sidebar-nav-MCP Servers"]', { + timeout: 30000, + }); + await page.getByTestId("sidebar-nav-MCP Servers").click({ timeout: 3000 }); + + await page.waitForSelector('[data-testid="add-mcp-server-button-page"]', { + timeout: 3000, + }); + + // Find and edit the server + await expect(page.getByText(testName)).toBeVisible({ + timeout: 3000, + }); + + await page + .getByTestId(`mcp-server-menu-button-${testName}`) + .click({ timeout: 3000 }); + + await page + .getByText("Edit", { exact: true }) + .first() + .click({ timeout: 3000 }); + + await page.waitForSelector('[data-testid="add-mcp-server-button"]', { + state: "visible", + timeout: 30000, + }); + + // Verify all fields persisted correctly + expect(await page.getByTestId("stdio-name-input").inputValue()).toBe( + testName, + ); + expect(await page.getByTestId("stdio-command-input").inputValue()).toBe( + testCommand, + ); + expect(await page.getByTestId("stdio-args_0").inputValue()).toBe(testArg1); + expect(await page.getByTestId("stdio-args_1").inputValue()).toBe(testArg2); + expect(await page.getByTestId("stdio-args_2").inputValue()).toBe(testArg3); + expect(await page.getByTestId("stdio-env-key-0").last().inputValue()).toBe( + testEnvKey1, + ); + expect( + await page.getByTestId("stdio-env-value-0").last().inputValue(), + ).toBe(testEnvValue1); + expect(await page.getByTestId("stdio-env-key-1").last().inputValue()).toBe( + testEnvKey2, + ); + expect( + await page.getByTestId("stdio-env-value-1").last().inputValue(), + ).toBe(testEnvValue2); + + // Clean up - cancel the edit modal + await page.keyboard.press("Escape"); + + // Delete the test server + await page + .getByTestId(`mcp-server-menu-button-${testName}`) + .click({ timeout: 3000 }); + + await page + .getByText("Delete", { exact: true }) + .first() + .click({ timeout: 3000 }); + + await page.waitForSelector( + '[data-testid="btn_delete_delete_confirmation_modal"]', + { + timeout: 3000, + }, + ); + + await page + .getByTestId("btn_delete_delete_confirmation_modal") + .click({ timeout: 3000 }); + }, +); + +test( + "SSE MCP server fields should persist after saving and editing", + { tag: ["@release", "@workspace", "@components"] }, + async ({ page }) => { + await awaitBootstrapTest(page); + + await page.waitForSelector('[data-testid="blank-flow"]', { + timeout: 30000, + }); + await page.getByTestId("blank-flow").click(); + await page.getByTestId("sidebar-search-input").click(); + await page.getByTestId("sidebar-search-input").fill("mcp tools"); + + await page.waitForSelector('[data-testid="agentsMCP Tools"]', { + timeout: 30000, + }); + + await page + .getByTestId("agentsMCP Tools") + .dragTo(page.locator('//*[@id="react-flow-id"]'), { + targetPosition: { x: 0, y: 0 }, + }); + + await page.getByTestId("fit_view").click(); + + await zoomOut(page, 3); + + try { + await page.getByText("Add MCP Server", { exact: true }).click({ + timeout: 5000, + }); + } catch (_error) { + await page.getByTestId("mcp-server-dropdown").click({ timeout: 3000 }); + await page.getByText("Add MCP Server", { exact: true }).click({ + timeout: 5000, + }); + } + + await page.waitForSelector('[data-testid="add-mcp-server-button"]', { + state: "visible", + timeout: 30000, + }); + + // Go to SSE tab and fill all fields + await page.getByTestId("sse-tab").click(); + await page.waitForSelector('[data-testid="sse-name-input"]', { + state: "visible", + timeout: 30000, + }); + + // Test data with random suffix + const randomSuffix = Math.floor(Math.random() * 90000) + 10000; // 5-digit random number + const testName = `test_sse_server_${randomSuffix}`; + const testUrl = "https://api.example.com/mcp"; + const testHeaderKey1 = "Authorization"; + const testHeaderValue1 = "Bearer token123"; + const testHeaderKey2 = "Content-Type"; + const testHeaderValue2 = "application/json"; + const testEnvKey1 = "API_TIMEOUT"; + const testEnvValue1 = "30000"; + const testEnvKey2 = "RETRY_COUNT"; + const testEnvValue2 = "3"; + + // Fill basic fields + await page.getByTestId("sse-name-input").fill(testName); + await page.getByTestId("sse-url-input").fill(testUrl); + + // Add first header + await page.getByTestId("sse-headers-key-0").fill(testHeaderKey1); + await page.getByTestId("sse-headers-value-0").fill(testHeaderValue1); + + // Add second header + await page.getByTestId("sse-headers-plus-btn-0").click(); + await page.getByTestId("sse-headers-key-1").fill(testHeaderKey2); + await page.getByTestId("sse-headers-value-1").fill(testHeaderValue2); + + // Add first environment variable + await page.getByTestId("sse-env-key-0").fill(testEnvKey1); + await page.getByTestId("sse-env-value-0").fill(testEnvValue1); + + // Add second environment variable + await page.getByTestId("sse-env-plus-btn-0").click(); + await page.getByTestId("sse-env-key-1").fill(testEnvKey2); + await page.getByTestId("sse-env-value-1").fill(testEnvValue2); + + // Save the server + await page.getByTestId("add-mcp-server-button").click(); + + // Wait for server to be created + await page.waitForTimeout(2000); + + // Go to settings to edit the server + await page.getByTestId("user_menu_button").click({ timeout: 3000 }); + await page.getByTestId("menu_settings_button").click({ timeout: 3000 }); + + await page.waitForSelector('[data-testid="sidebar-nav-MCP Servers"]', { + timeout: 30000, + }); + await page.getByTestId("sidebar-nav-MCP Servers").click({ timeout: 3000 }); + + await page.waitForSelector('[data-testid="add-mcp-server-button-page"]', { + timeout: 3000, + }); + + // Find and edit the server + await expect(page.getByText(testName)).toBeVisible({ + timeout: 3000, + }); + + await page + .getByTestId(`mcp-server-menu-button-${testName}`) + .click({ timeout: 3000 }); + + await page + .getByText("Edit", { exact: true }) + .first() + .click({ timeout: 3000 }); + + await page.waitForSelector('[data-testid="add-mcp-server-button"]', { + state: "visible", + timeout: 30000, + }); + + // Verify all fields persisted correctly + expect(await page.getByTestId("sse-name-input").inputValue()).toBe( + testName, + ); + expect(await page.getByTestId("sse-url-input").inputValue()).toBe(testUrl); + expect(await page.getByTestId("sse-headers-key-0").inputValue()).toBe( + testHeaderKey1, + ); + expect(await page.getByTestId("sse-headers-value-0").inputValue()).toBe( + testHeaderValue1, + ); + expect(await page.getByTestId("sse-headers-key-1").inputValue()).toBe( + testHeaderKey2, + ); + expect(await page.getByTestId("sse-headers-value-1").inputValue()).toBe( + testHeaderValue2, + ); + expect(await page.getByTestId("sse-env-key-0").inputValue()).toBe( + testEnvKey1, + ); + expect(await page.getByTestId("sse-env-value-0").inputValue()).toBe( + testEnvValue1, + ); + expect(await page.getByTestId("sse-env-key-1").inputValue()).toBe( + testEnvKey2, + ); + expect(await page.getByTestId("sse-env-value-1").inputValue()).toBe( + testEnvValue2, + ); + + // Clean up - cancel the edit modal + await page.keyboard.press("Escape"); + + // Delete the test server + await page + .getByTestId(`mcp-server-menu-button-${testName}`) + .click({ timeout: 3000 }); + + await page + .getByText("Delete", { exact: true }) + .first() + .click({ timeout: 3000 }); + + await page.waitForSelector( + '[data-testid="btn_delete_delete_confirmation_modal"]', + { + timeout: 3000, + }, + ); + + await page + .getByTestId("btn_delete_delete_confirmation_modal") + .click({ timeout: 3000 }); + }, +);