feat: lmprove MCP langflow port selection and error handling (#7327)

* langflow port search and error handling

* [autofix.ci] apply automated fixes

* Update mcp_component.py

* Update util.py

*  (test_mcp_component.py): add support for validating URL and setting max retries to improve connection handling
🐛 (test_mcp_component.py): fix incorrect exception type in test_connect_timeout method to match expected behavior

* [autofix.ci] apply automated fixes

*  (test_mcp_component.py): refactor test_connect_to_server method for better readability and maintainability
🔧 (test_mcp_component.py): refactor test_connect_timeout method for improved error message formatting

* [autofix.ci] apply automated fixes

---------

Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com>
Co-authored-by: cristhianzl <cristhian.lousa@gmail.com>
This commit is contained in:
Edwin Jose 2025-03-31 10:43:17 -04:00 committed by GitHub
commit aea98a4019
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
3 changed files with 193 additions and 40 deletions

View file

@ -3,9 +3,12 @@ import os
from collections.abc import Awaitable, Callable
from contextlib import AsyncExitStack
from typing import Any
from urllib.parse import urlparse
from uuid import UUID
import httpx
from httpx import codes as httpx_codes
from loguru import logger
from mcp import ClientSession, StdioServerParameters, stdio_client
from mcp.client.sse import sse_client
from pydantic import Field, create_model
@ -14,6 +17,8 @@ from sqlmodel import select
from langflow.helpers.base_model import BaseModel
from langflow.services.database.models import Flow
HTTP_ERROR_STATUS_CODE = httpx_codes.BAD_REQUEST # HTTP status code for client errors
def create_tool_coroutine(tool_name: str, arg_schema: type[BaseModel], session) -> Callable[..., Awaitable]:
async def tool_coroutine(*args, **kwargs):
@ -139,40 +144,110 @@ class MCPSseClient:
self.sse = None
self.session: ClientSession | None = None
self.exit_stack = AsyncExitStack()
self.max_retries = 3
self.retry_delay = 1.0 # seconds
async def pre_check_redirect(self, url: str):
async with httpx.AsyncClient(follow_redirects=False) as client:
response = await client.request("HEAD", url)
if response.status_code == httpx.codes.TEMPORARY_REDIRECT:
return response.headers.get("Location")
async def validate_url(self, url: str | None) -> tuple[bool, str]:
"""Validate the SSE URL before attempting connection."""
try:
parsed = urlparse(url)
if not parsed.scheme or not parsed.netloc:
return False, "Invalid URL format. Must include scheme (http/https) and host."
async with httpx.AsyncClient() as client:
try:
# First try a HEAD request to check if server is reachable
response = await client.head(url, timeout=5.0)
if response.status_code >= HTTP_ERROR_STATUS_CODE:
return False, f"Server returned error status: {response.status_code}"
except httpx.TimeoutException:
return False, "Connection timed out. Server may be down or unreachable."
except httpx.NetworkError:
return False, "Network error. Could not reach the server."
else:
return True, ""
except (httpx.HTTPError, ValueError, OSError) as e:
return False, f"URL validation error: {e!s}"
async def pre_check_redirect(self, url: str | None) -> str | None:
"""Check for redirects and return the final URL."""
if url is None:
return url
try:
async with httpx.AsyncClient(follow_redirects=False) as client:
response = await client.request("HEAD", url)
if response.status_code == httpx.codes.TEMPORARY_REDIRECT:
return response.headers.get("Location", url)
except (httpx.RequestError, httpx.HTTPError) as e:
logger.warning(f"Error checking redirects: {e}")
return url
async def _connect_with_timeout(
self, url: str, headers: dict[str, str] | None, timeout_seconds: int, sse_read_timeout_seconds: int
self, url: str | None, headers: dict[str, str] | None, timeout_seconds: int, sse_read_timeout_seconds: int
):
sse_transport = await self.exit_stack.enter_async_context(
sse_client(url, headers, timeout_seconds, sse_read_timeout_seconds)
)
self.sse, self.write = sse_transport
self.session = await self.exit_stack.enter_async_context(ClientSession(self.sse, self.write))
await self.session.initialize()
"""Attempt to connect with timeout."""
try:
if url is None:
return
sse_transport = await self.exit_stack.enter_async_context(
sse_client(url, headers, timeout_seconds, sse_read_timeout_seconds)
)
self.sse, self.write = sse_transport
self.session = await self.exit_stack.enter_async_context(ClientSession(self.sse, self.write))
await self.session.initialize()
except Exception as e:
msg = f"Failed to establish SSE connection: {e!s}"
raise ConnectionError(msg) from e
async def connect_to_server(
self, url: str, headers: dict[str, str] | None, timeout_seconds: int = 500, sse_read_timeout_seconds: int = 500
self,
url: str | None,
headers: dict[str, str] | None,
timeout_seconds: int = 30,
sse_read_timeout_seconds: int = 30,
):
"""Connect to server with retries and improved error handling."""
if headers is None:
headers = {}
# First validate the URL
is_valid, error_msg = await self.validate_url(url)
if not is_valid:
msg = f"Invalid SSE URL ({url}): {error_msg}"
raise ValueError(msg)
url = await self.pre_check_redirect(url)
try:
await asyncio.wait_for(
self._connect_with_timeout(url, headers, timeout_seconds, sse_read_timeout_seconds),
timeout=timeout_seconds,
)
if self.session is None:
msg = "Session not initialized"
raise ValueError(msg)
response = await self.session.list_tools()
except asyncio.TimeoutError as err:
msg = f"Connection to {url} timed out after {timeout_seconds} seconds"
raise TimeoutError(msg) from err
return response.tools
last_error = None
for attempt in range(self.max_retries):
try:
await asyncio.wait_for(
self._connect_with_timeout(url, headers, timeout_seconds, sse_read_timeout_seconds),
timeout=timeout_seconds,
)
if self.session is None:
msg = "Session not initialized"
raise ValueError(msg)
response = await self.session.list_tools()
except asyncio.TimeoutError:
last_error = f"Connection to {url} timed out after {timeout_seconds} seconds"
logger.warning(f"Connection attempt {attempt + 1} failed: {last_error}")
except ConnectionError as err:
last_error = str(err)
logger.warning(f"Connection attempt {attempt + 1} failed: {last_error}")
except (ValueError, httpx.HTTPError, OSError) as err:
last_error = f"Connection error: {err!s}"
logger.warning(f"Connection attempt {attempt + 1} failed: {last_error}")
else:
return response.tools
if attempt < self.max_retries - 1:
await asyncio.sleep(self.retry_delay * (attempt + 1))
msg = f"Failed to connect after {self.max_retries} attempts. Last error: {last_error}"
raise ConnectionError(msg)

View file

@ -1,9 +1,11 @@
import asyncio
import os
from typing import Any
import httpx
from langchain_core.tools import StructuredTool
from langflow.base.mcp.util import (
HTTP_ERROR_STATUS_CODE,
MCPSseClient,
MCPStdioClient,
create_input_schema_from_json_schema,
@ -21,12 +23,14 @@ from langflow.schema import Message
class MCPToolsComponent(Component):
schema_inputs: list[InputTypes] = []
stdio_client = MCPStdioClient()
sse_client = MCPSseClient()
stdio_client: MCPStdioClient = MCPStdioClient()
sse_client: MCPSseClient = MCPSseClient()
tools: list = []
tool_names: list[str] = []
_tool_cache: dict = {} # Cache for tool objects
default_keys = ["code", "_type", "mode", "command", "sse_url", "tool_placeholder", "tool_mode", "tool"]
default_keys: list[str] = ["code", "_type", "mode", "command", "sse_url", "tool_placeholder", "tool_mode", "tool"]
sse_url: str | None = None
display_name = "MCP Server"
description = "Connect to an MCP server and expose tools."
@ -54,7 +58,6 @@ class MCPToolsComponent(Component):
name="sse_url",
display_name="MCP SSE URL",
info="URL for MCP SSE connection",
value="http://localhost:7860/api/v1/mcp/sse",
show=False,
refresh_button=True,
),
@ -82,6 +85,21 @@ class MCPToolsComponent(Component):
Output(display_name="Response", name="response", method="build_output"),
]
async def find_langflow_instance(self) -> tuple[bool, int | None, str]:
"""Find Langflow instance by checking env variable first, then scanning common ports."""
# First check environment variable
env_port = os.getenv("LANGFLOW_PORT")
port = int(env_port) if env_port else 7860
try:
url = f"http://localhost:{port}/api/v1/mcp/sse"
async with httpx.AsyncClient() as client:
response = await client.head(url, timeout=2.0)
if response.status_code < HTTP_ERROR_STATUS_CODE:
return True, port, f"Langflow instance found at configured port {port}"
except (ValueError, httpx.TimeoutException, httpx.NetworkError, httpx.HTTPError):
logger.warning(f"Could not connect to Langflow at configured port {env_port}")
return False, None, "No Langflow instance found on configured port or common ports"
async def _validate_connection_params(self, mode: str, command: str | None = None, url: str | None = None) -> None:
"""Validate connection parameters based on mode."""
if mode not in ["Stdio", "SSE"]:
@ -131,8 +149,26 @@ class MCPToolsComponent(Component):
elif field_value == "SSE":
build_config["command"]["show"] = False
build_config["sse_url"]["show"] = True
_, port, _ = await self.find_langflow_instance()
if port:
build_config["sse_url"]["value"] = f"http://localhost:{port}/api/v1/mcp/sse"
self.sse_url = build_config["sse_url"]["value"]
return build_config
if field_name in ("command", "sse_url", "mode"):
try:
# If SSE mode and localhost URL is not valid, try to find correct port
if field_name == "sse_url":
self.sse_url = field_value
elif self.mode == "SSE" and ("localhost" in str(self.sse_url) or "127.0.0.1" in str(self.sse_url)):
is_valid, _ = await self.sse_client.validate_url(self.sse_url)
if not is_valid:
found, port, message = await self.find_langflow_instance()
if found:
new_url = f"http://localhost:{port}/api/v1/mcp/sse"
logger.info(f"Original URL {self.sse_url} not valid. {message}")
build_config["sse_url"]["value"] = new_url
self.sse_url = new_url
await self.update_tools()
if "tool" in build_config:
build_config["tool"]["options"] = self.tool_names
@ -285,7 +321,33 @@ class MCPToolsComponent(Component):
if not self.stdio_client.session:
self.tools = await self.stdio_client.connect_to_server(self.command)
elif self.mode == "SSE" and not self.sse_client.session:
self.tools = await self.sse_client.connect_to_server(self.sse_url, {})
try:
is_valid, _ = await self.sse_client.validate_url(self.sse_url)
if not is_valid:
msg = f"Invalid SSE URL configuration: {self.sse_url}. Please check the SSE URL and try again."
logger.error(msg)
return []
self.tools = await self.sse_client.connect_to_server(self.sse_url, {})
except ValueError as e:
# URL validation error
logger.error(f"SSE URL validation error: {e}")
msg = f"Invalid SSE URL configuration: {e}. Please check your Langflow deployment URL and port."
raise ValueError(msg) from e
except ConnectionError as e:
# Connection failed after retries
logger.error(f"SSE connection error: {e}")
msg = (
f"Could not connect to Langflow SSE endpoint: {e}. "
"Please verify:\n"
"1. Langflow server is running\n"
"2. The SSE URL matches your Langflow deployment port\n"
"3. There are no network issues preventing the connection"
)
raise ValueError(msg) from e
except Exception as e:
logger.error(f"Unexpected SSE error: {e}")
msg = f"Unexpected error connecting to SSE endpoint: {e}"
raise ValueError(msg) from e
if not self.tools:
logger.warning("No tools returned from server")
@ -300,8 +362,7 @@ class MCPToolsComponent(Component):
try:
args_schema = create_input_schema_from_json_schema(tool.inputSchema)
if not args_schema:
msg = f"Empty schema for tool '{tool.name}', skipping"
logger.warning(msg)
logger.warning(f"Empty schema for tool '{tool.name}', skipping")
continue
client = self.stdio_client if self.mode == "Stdio" else self.sse_client
@ -320,15 +381,18 @@ class MCPToolsComponent(Component):
tool_list.append(tool_obj)
self._tool_cache[tool.name] = tool_obj
except (AttributeError, ValueError, TypeError, KeyError) as e:
msg = f"Error creating tool {getattr(tool, 'name', 'unknown')}: {e!s}"
msg = f"Error creating tool {getattr(tool, 'name', 'unknown')}: {e}"
logger.exception(msg)
continue
self.tool_names = [tool.name for tool in self.tools if hasattr(tool, "name")]
except (ValueError, RuntimeError, asyncio.TimeoutError) as e:
msg = f"Error updating tools: {e!s}"
logger.exception(msg)
except ValueError as e:
# Re-raise validation errors with clear messages
raise ValueError(str(e)) from e
except Exception as e:
logger.exception("Error updating tools")
msg = f"Failed to update tools: {e!s}"
raise ValueError(msg) from e
else:
return tool_list

View file

@ -209,7 +209,10 @@ class TestMCPSseClient:
async def test_connect_to_server(self, sse_client):
"""Test connecting to server via SSE."""
# Mock the pre_check_redirect first
with patch.object(sse_client, "pre_check_redirect", return_value="http://test.url"):
with (
patch.object(sse_client, "pre_check_redirect", return_value="http://test.url"),
patch.object(sse_client, "validate_url", return_value=(True, "")),
):
# Create mock for sse_client context manager
mock_sse = AsyncMock()
mock_write = AsyncMock()
@ -245,11 +248,22 @@ class TestMCPSseClient:
async def test_connect_timeout(self, sse_client):
"""Test connection timeout handling."""
# Set max_retries to 1 to avoid multiple retry attempts
sse_client.max_retries = 1
with (
patch.object(sse_client, "pre_check_redirect", return_value="http://test.url"),
patch.object(sse_client, "validate_url", return_value=(True, "")), # Mock URL validation
patch.object(sse_client, "_connect_with_timeout") as mock_connect,
):
mock_connect.side_effect = asyncio.TimeoutError()
with pytest.raises(TimeoutError, match="Connection to http://test.url timed out after 1 seconds"):
# Expect ConnectionError instead of TimeoutError
with pytest.raises(
ConnectionError,
match=(
"Failed to connect after 1 attempts. "
"Last error: Connection to http://test.url timed out after 1 seconds"
),
):
await sse_client.connect_to_server("http://test.url", {}, timeout_seconds=1)