chore: merge mcp components (#7167)
* take1 * depreacate stdio and sse mcp components * optionals * rodrigo fixes * session management * update init * mcp component integration test * broken * [autofix.ci] apply automated fixes * fix url input name * upated MCP * Update mcp_component.py * [autofix.ci] apply automated fixes * update to the MCP component * [autofix.ci] apply automated fixes * mostly working * [autofix.ci] apply automated fixes * Update mcp_component.py * [autofix.ci] apply automated fixes * update component * [autofix.ci] apply automated fixes * Update mcp_component.py * rename component because Simon * icon and description for simon * fix integration test * fix test * Update mcp_component.py * update and basic QoL * [autofix.ci] apply automated fixes * refactor clients to util and use flow names not IDs in mcp.py * integration test * take out traces * ✨ (edit-tools.spec.ts): add test for user to be able to edit tools in the frontend application. * session fix * fix content output * ♻️ (util.py): remove redundant constant HTTP_TEMPORARY_REDIRECT and replace its usage with httpx.codes.TEMPORARY_REDIRECT for better code readability and maintainability * [autofix.ci] apply automated fixes * 🐛 (utils.ts): fix potential null pointer error when converting words to title case by adding null check before accessing properties * 🐛 (genericIconComponent/index.tsx): Fix issue with optional chaining in mapping function 🐛 (renderIconComponent/index.tsx): Fix issue with optional chaining in mapping function 🐛 (button.tsx): Fix issue with optional chaining in mapping function 🐛 (utils.ts): Fix issue with optional chaining in mapping functions * 🐛 (language-select.tsx): Fix potential null pointer error when mapping over allLanguages array * ✨ (constants.ts): add support for multiple languages in the application by defining an array of language options ♻️ (audio-settings-dialog.tsx, language-select.tsx): refactor to import the array of all languages from constants.ts instead of duplicating it in each file * ✅ (auto-login-off.spec.ts): add a 2-second delay before continuing the test to ensure proper loading and rendering of elements on the page * ⬆️ (filterEdge-shard-0.spec.ts): reduce wait time for page interactions to improve test performance ⬆️ (playground.spec.ts): optimize wait times for page interactions to enhance test efficiency --------- Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com> Co-authored-by: Edwin Jose <edwin.jose@datastax.com> Co-authored-by: cristhianzl <cristhian.lousa@gmail.com>
This commit is contained in:
parent
4527c473be
commit
59b2ed7765
25 changed files with 1200 additions and 331 deletions
0
src/backend/tests/integration/components/mcp/__init__.py
Normal file
0
src/backend/tests/integration/components/mcp/__init__.py
Normal file
|
|
@ -0,0 +1,11 @@
|
|||
from tests.integration.utils import run_single_component
|
||||
|
||||
|
||||
async def test_mcp_component():
|
||||
from langflow.components.tools.mcp_component import MCPToolsComponent
|
||||
|
||||
inputs = {}
|
||||
await run_single_component(
|
||||
MCPToolsComponent,
|
||||
inputs=inputs, # test default inputs
|
||||
)
|
||||
255
src/backend/tests/unit/components/tools/test_mcp_component.py
Normal file
255
src/backend/tests/unit/components/tools/test_mcp_component.py
Normal file
|
|
@ -0,0 +1,255 @@
|
|||
import asyncio
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
|
||||
import pytest
|
||||
from langflow.components.tools.mcp_component import MCPSseClient, MCPStdioClient, MCPToolsComponent
|
||||
|
||||
from tests.base import ComponentTestBaseWithoutClient, VersionComponentMapping
|
||||
|
||||
|
||||
class TestMCPToolsComponent(ComponentTestBaseWithoutClient):
|
||||
@pytest.fixture
|
||||
def component_class(self):
|
||||
"""Return the component class to test."""
|
||||
return MCPToolsComponent
|
||||
|
||||
@pytest.fixture
|
||||
def default_kwargs(self):
|
||||
"""Return the default kwargs for the component."""
|
||||
return {
|
||||
"mode": "Stdio",
|
||||
"command": "uvx mcp-server-fetch",
|
||||
"sse_url": "http://localhost:7860/api/v1/mcp/sse",
|
||||
"tool": "",
|
||||
}
|
||||
|
||||
@pytest.fixture
|
||||
def file_names_mapping(self) -> list[VersionComponentMapping]:
|
||||
"""Return the file names mapping for different versions."""
|
||||
return []
|
||||
|
||||
@pytest.fixture
|
||||
def mock_tool(self):
|
||||
"""Create a mock MCP tool."""
|
||||
tool = MagicMock()
|
||||
tool.name = "test_tool"
|
||||
tool.description = "Test tool description"
|
||||
tool.inputSchema = {
|
||||
"type": "object",
|
||||
"properties": {"test_param": {"type": "string", "description": "Test parameter"}},
|
||||
}
|
||||
return tool
|
||||
|
||||
@pytest.fixture
|
||||
def mock_stdio_client(self, mock_tool):
|
||||
"""Create a mock stdio client."""
|
||||
stdio_client = AsyncMock()
|
||||
stdio_client.connect_to_server = AsyncMock(return_value=[mock_tool])
|
||||
stdio_client.session = AsyncMock()
|
||||
return stdio_client
|
||||
|
||||
@pytest.fixture
|
||||
def mock_sse_client(self, mock_tool):
|
||||
"""Create a mock SSE client."""
|
||||
sse_client = AsyncMock()
|
||||
sse_client.connect_to_server = AsyncMock(return_value=[mock_tool])
|
||||
sse_client.session = AsyncMock()
|
||||
return sse_client
|
||||
|
||||
async def test_validate_connection_params_invalid_mode(self, component_class, default_kwargs):
|
||||
"""Test validation with invalid mode."""
|
||||
component = component_class(**default_kwargs)
|
||||
with pytest.raises(ValueError, match="Invalid mode: invalid. Must be either 'Stdio' or 'SSE'"):
|
||||
await component._validate_connection_params("invalid")
|
||||
|
||||
async def test_validate_connection_params_missing_command(self, component_class, default_kwargs):
|
||||
"""Test validation with missing command in Stdio mode."""
|
||||
component = component_class(**default_kwargs)
|
||||
with pytest.raises(ValueError, match="Command is required for Stdio mode"):
|
||||
await component._validate_connection_params("Stdio", command=None)
|
||||
|
||||
async def test_validate_connection_params_missing_url(self, component_class, default_kwargs):
|
||||
"""Test validation with missing URL in SSE mode."""
|
||||
component = component_class(**default_kwargs)
|
||||
with pytest.raises(ValueError, match="URL is required for SSE mode"):
|
||||
await component._validate_connection_params("SSE", url=None)
|
||||
|
||||
async def test_update_build_config_mode_change(self, component_class, default_kwargs):
|
||||
"""Test build config updates when mode changes."""
|
||||
component = component_class(**default_kwargs)
|
||||
build_config = {
|
||||
"command": {"show": False},
|
||||
"sse_url": {"show": True},
|
||||
"tool": {"options": [], "show": True}, # Add tool field since component uses it
|
||||
}
|
||||
|
||||
# Test switching to Stdio mode
|
||||
updated_config = await component.update_build_config(build_config, "Stdio", "mode")
|
||||
assert updated_config["command"]["show"] is True
|
||||
assert updated_config["sse_url"]["show"] is False
|
||||
|
||||
# Test switching to SSE mode
|
||||
updated_config = await component.update_build_config(build_config, "SSE", "mode")
|
||||
assert updated_config["command"]["show"] is False
|
||||
assert updated_config["sse_url"]["show"] is True
|
||||
|
||||
# Test tool options are updated
|
||||
assert "options" in updated_config["tool"]
|
||||
|
||||
@patch("langflow.components.tools.mcp_component.create_tool_coroutine")
|
||||
async def test_build_output(self, mock_create_coroutine, component_class, default_kwargs, mock_tool):
|
||||
"""Test building output with a tool."""
|
||||
component = component_class(**default_kwargs)
|
||||
component.tool = "test_tool"
|
||||
component.tools = [mock_tool]
|
||||
|
||||
# Mock the coroutine response
|
||||
mock_response = AsyncMock()
|
||||
mock_response.content = [MagicMock(text="Test response")]
|
||||
mock_create_coroutine.return_value = AsyncMock(return_value=mock_response)
|
||||
|
||||
# Create a mock tool and add it to the cache
|
||||
mock_structured_tool = MagicMock()
|
||||
mock_structured_tool.coroutine = mock_create_coroutine.return_value
|
||||
component._tool_cache = {"test_tool": mock_structured_tool}
|
||||
|
||||
# Set the test parameter value
|
||||
component.test_param = "test value"
|
||||
|
||||
# Mock get_inputs_for_all_tools to return our mock input
|
||||
mock_input = MagicMock()
|
||||
mock_input.name = "test_param"
|
||||
with patch.object(component, "get_inputs_for_all_tools") as mock_get_inputs:
|
||||
mock_get_inputs.return_value = {"test_tool": [mock_input]}
|
||||
output = await component.build_output()
|
||||
|
||||
assert output.text == "Test response"
|
||||
# Verify the mocks were called correctly
|
||||
mock_get_inputs.assert_called_once_with(component.tools)
|
||||
mock_structured_tool.coroutine.assert_called_once_with(test_param="test value")
|
||||
|
||||
async def test_get_inputs_for_all_tools(self, component_class, default_kwargs, mock_tool):
|
||||
"""Test getting input schemas for all tools."""
|
||||
component = component_class(**default_kwargs)
|
||||
inputs = component.get_inputs_for_all_tools([mock_tool])
|
||||
|
||||
assert "test_tool" in inputs
|
||||
assert len(inputs["test_tool"]) > 0 # Should have at least one input parameter
|
||||
|
||||
async def test_remove_non_default_keys(self, component_class, default_kwargs):
|
||||
"""Test removing non-default keys from build config."""
|
||||
component = component_class(**default_kwargs)
|
||||
build_config = {"code": {}, "mode": {}, "command": {}, "custom_key": {}}
|
||||
|
||||
component.remove_non_default_keys(build_config)
|
||||
assert "custom_key" not in build_config
|
||||
assert all(key in build_config for key in ["code", "mode", "command"])
|
||||
|
||||
|
||||
class TestMCPStdioClient:
|
||||
@pytest.fixture
|
||||
def stdio_client(self):
|
||||
return MCPStdioClient()
|
||||
|
||||
async def test_connect_to_server(self, stdio_client):
|
||||
"""Test connecting to server via Stdio."""
|
||||
# Create mock for stdio transport
|
||||
mock_stdio = AsyncMock()
|
||||
mock_write = AsyncMock()
|
||||
mock_stdio_transport = (mock_stdio, mock_write)
|
||||
mock_stdio_cm = AsyncMock()
|
||||
mock_stdio_cm.__aenter__.return_value = mock_stdio_transport
|
||||
|
||||
# Mock the stdio_client function to return our mock context manager
|
||||
with patch("mcp.client.stdio.stdio_client", return_value=mock_stdio_cm):
|
||||
# Mock ClientSession
|
||||
mock_session = AsyncMock()
|
||||
mock_session.initialize = AsyncMock()
|
||||
mock_session.list_tools.return_value.tools = [MagicMock()]
|
||||
|
||||
# Mock the AsyncExitStack
|
||||
mock_exit_stack = AsyncMock()
|
||||
mock_exit_stack.enter_async_context = AsyncMock()
|
||||
mock_exit_stack.enter_async_context.side_effect = [
|
||||
mock_stdio_transport, # For stdio_client
|
||||
mock_session, # For ClientSession
|
||||
]
|
||||
stdio_client.exit_stack = mock_exit_stack
|
||||
|
||||
tools = await stdio_client.connect_to_server("test_command")
|
||||
|
||||
assert len(tools) == 1
|
||||
assert stdio_client.session is not None
|
||||
# Verify the exit stack was used correctly
|
||||
assert mock_exit_stack.enter_async_context.call_count == 2
|
||||
# Verify the stdio transport was properly set
|
||||
assert stdio_client.stdio == mock_stdio
|
||||
assert stdio_client.write == mock_write
|
||||
|
||||
|
||||
class TestMCPSseClient:
|
||||
@pytest.fixture
|
||||
def sse_client(self):
|
||||
return MCPSseClient()
|
||||
|
||||
async def test_pre_check_redirect(self, sse_client):
|
||||
"""Test pre-checking URL for redirects."""
|
||||
test_url = "http://test.url"
|
||||
redirect_url = "http://redirect.url"
|
||||
|
||||
with patch("httpx.AsyncClient") as mock_client:
|
||||
mock_response = MagicMock()
|
||||
mock_response.status_code = 307
|
||||
mock_response.headers.get.return_value = redirect_url
|
||||
mock_client.return_value.__aenter__.return_value.request.return_value = mock_response
|
||||
|
||||
result = await sse_client.pre_check_redirect(test_url)
|
||||
assert result == redirect_url
|
||||
|
||||
async def test_connect_to_server(self, sse_client):
|
||||
"""Test connecting to server via SSE."""
|
||||
# Mock the pre_check_redirect first
|
||||
with patch.object(sse_client, "pre_check_redirect", return_value="http://test.url"):
|
||||
# Create mock for sse_client context manager
|
||||
mock_sse = AsyncMock()
|
||||
mock_write = AsyncMock()
|
||||
mock_sse_transport = (mock_sse, mock_write)
|
||||
mock_sse_cm = AsyncMock()
|
||||
mock_sse_cm.__aenter__.return_value = mock_sse_transport
|
||||
|
||||
# Mock the sse_client function to return our mock context manager
|
||||
with patch("mcp.client.sse.sse_client", return_value=mock_sse_cm):
|
||||
# Mock ClientSession
|
||||
mock_session = AsyncMock()
|
||||
mock_session.initialize = AsyncMock()
|
||||
mock_session.list_tools.return_value.tools = [MagicMock()]
|
||||
|
||||
# Mock the AsyncExitStack
|
||||
mock_exit_stack = AsyncMock()
|
||||
mock_exit_stack.enter_async_context = AsyncMock()
|
||||
mock_exit_stack.enter_async_context.side_effect = [
|
||||
mock_sse_transport, # For sse_client
|
||||
mock_session, # For ClientSession
|
||||
]
|
||||
sse_client.exit_stack = mock_exit_stack
|
||||
|
||||
tools = await sse_client.connect_to_server("http://test.url", {})
|
||||
|
||||
assert len(tools) == 1
|
||||
assert sse_client.session is not None
|
||||
# Verify the exit stack was used correctly
|
||||
assert mock_exit_stack.enter_async_context.call_count == 2
|
||||
# Verify the SSE transport was properly set
|
||||
assert sse_client.sse == mock_sse
|
||||
assert sse_client.write == mock_write
|
||||
|
||||
async def test_connect_timeout(self, sse_client):
|
||||
"""Test connection timeout handling."""
|
||||
with (
|
||||
patch.object(sse_client, "pre_check_redirect", return_value="http://test.url"),
|
||||
patch.object(sse_client, "_connect_with_timeout") as mock_connect,
|
||||
):
|
||||
mock_connect.side_effect = asyncio.TimeoutError()
|
||||
|
||||
with pytest.raises(TimeoutError, match="Connection to http://test.url timed out after 1 seconds"):
|
||||
await sse_client.connect_to_server("http://test.url", {}, timeout_seconds=1)
|
||||
|
|
@ -3,11 +3,13 @@ from types import NoneType
|
|||
from typing import Union
|
||||
|
||||
import pytest
|
||||
from langflow.inputs.inputs import BoolInput, DictInput, FloatInput, InputTypes, IntInput, MessageTextInput
|
||||
from langflow.io.schema import schema_to_langflow_inputs
|
||||
from langflow.schema.data import Data
|
||||
from langflow.template import Input, Output
|
||||
from langflow.template.field.base import UNDEFINED
|
||||
from langflow.type_extraction.type_extraction import post_process_type
|
||||
from pydantic import ValidationError
|
||||
from pydantic import BaseModel, Field, ValidationError
|
||||
|
||||
|
||||
class TestInput:
|
||||
|
|
@ -178,3 +180,65 @@ class TestPostProcessType:
|
|||
pass
|
||||
|
||||
assert set(post_process_type(Union[CustomType, int])) == {CustomType, int} # noqa: UP007
|
||||
|
||||
|
||||
def test_schema_to_langflow_inputs():
|
||||
# Define a test Pydantic model with various field types
|
||||
class TestSchema(BaseModel):
|
||||
text_field: str = Field(title="Custom Text Title", description="A text field")
|
||||
number_field: int = Field(description="A number field")
|
||||
bool_field: bool = Field(description="A boolean field")
|
||||
dict_field: dict = Field(description="A dictionary field")
|
||||
list_field: list[str] = Field(description="A list of strings")
|
||||
|
||||
# Convert schema to Langflow inputs
|
||||
inputs = schema_to_langflow_inputs(TestSchema)
|
||||
|
||||
# Verify the number of inputs matches the schema fields
|
||||
assert len(inputs) == 5
|
||||
|
||||
# Helper function to find input by name
|
||||
def find_input(name: str) -> InputTypes | None:
|
||||
for _input in inputs:
|
||||
if _input.name == name:
|
||||
return _input
|
||||
return None
|
||||
|
||||
# Test text field
|
||||
text_input = find_input("text_field")
|
||||
assert text_input.display_name == "Custom Text Title"
|
||||
assert text_input.info == "A text field"
|
||||
assert isinstance(text_input, MessageTextInput) # Check the instance type instead of field_type
|
||||
|
||||
# Test number field
|
||||
number_input = find_input("number_field")
|
||||
assert number_input.display_name == "Number Field"
|
||||
assert number_input.info == "A number field"
|
||||
assert isinstance(number_input, IntInput | FloatInput)
|
||||
|
||||
# Test boolean field
|
||||
bool_input = find_input("bool_field")
|
||||
assert isinstance(bool_input, BoolInput)
|
||||
|
||||
# Test dictionary field
|
||||
dict_input = find_input("dict_field")
|
||||
assert isinstance(dict_input, DictInput)
|
||||
|
||||
# Test list field
|
||||
list_input = find_input("list_field")
|
||||
assert list_input.is_list is True
|
||||
assert isinstance(list_input, MessageTextInput)
|
||||
|
||||
|
||||
def test_schema_to_langflow_inputs_invalid_type():
|
||||
# Define a schema with an unsupported type
|
||||
class CustomType:
|
||||
pass
|
||||
|
||||
class InvalidSchema(BaseModel):
|
||||
model_config = {"arbitrary_types_allowed": True} # Add this line
|
||||
invalid_field: CustomType
|
||||
|
||||
# Test that attempting to convert an unsupported type raises TypeError
|
||||
with pytest.raises(TypeError, match="Unsupported field type:"):
|
||||
schema_to_langflow_inputs(InvalidSchema)
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue