feat: lmprove MCP langflow port selection and error handling (#7327)
* langflow port search and error handling * [autofix.ci] apply automated fixes * Update mcp_component.py * Update util.py * ✨ (test_mcp_component.py): add support for validating URL and setting max retries to improve connection handling 🐛 (test_mcp_component.py): fix incorrect exception type in test_connect_timeout method to match expected behavior * [autofix.ci] apply automated fixes * ✅ (test_mcp_component.py): refactor test_connect_to_server method for better readability and maintainability 🔧 (test_mcp_component.py): refactor test_connect_timeout method for improved error message formatting * [autofix.ci] apply automated fixes --------- Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com> Co-authored-by: cristhianzl <cristhian.lousa@gmail.com>
This commit is contained in:
parent
7ba4bff956
commit
aea98a4019
3 changed files with 193 additions and 40 deletions
|
|
@ -3,9 +3,12 @@ import os
|
|||
from collections.abc import Awaitable, Callable
|
||||
from contextlib import AsyncExitStack
|
||||
from typing import Any
|
||||
from urllib.parse import urlparse
|
||||
from uuid import UUID
|
||||
|
||||
import httpx
|
||||
from httpx import codes as httpx_codes
|
||||
from loguru import logger
|
||||
from mcp import ClientSession, StdioServerParameters, stdio_client
|
||||
from mcp.client.sse import sse_client
|
||||
from pydantic import Field, create_model
|
||||
|
|
@ -14,6 +17,8 @@ from sqlmodel import select
|
|||
from langflow.helpers.base_model import BaseModel
|
||||
from langflow.services.database.models import Flow
|
||||
|
||||
HTTP_ERROR_STATUS_CODE = httpx_codes.BAD_REQUEST # HTTP status code for client errors
|
||||
|
||||
|
||||
def create_tool_coroutine(tool_name: str, arg_schema: type[BaseModel], session) -> Callable[..., Awaitable]:
|
||||
async def tool_coroutine(*args, **kwargs):
|
||||
|
|
@ -139,40 +144,110 @@ class MCPSseClient:
|
|||
self.sse = None
|
||||
self.session: ClientSession | None = None
|
||||
self.exit_stack = AsyncExitStack()
|
||||
self.max_retries = 3
|
||||
self.retry_delay = 1.0 # seconds
|
||||
|
||||
async def pre_check_redirect(self, url: str):
|
||||
async with httpx.AsyncClient(follow_redirects=False) as client:
|
||||
response = await client.request("HEAD", url)
|
||||
if response.status_code == httpx.codes.TEMPORARY_REDIRECT:
|
||||
return response.headers.get("Location")
|
||||
async def validate_url(self, url: str | None) -> tuple[bool, str]:
|
||||
"""Validate the SSE URL before attempting connection."""
|
||||
try:
|
||||
parsed = urlparse(url)
|
||||
if not parsed.scheme or not parsed.netloc:
|
||||
return False, "Invalid URL format. Must include scheme (http/https) and host."
|
||||
|
||||
async with httpx.AsyncClient() as client:
|
||||
try:
|
||||
# First try a HEAD request to check if server is reachable
|
||||
response = await client.head(url, timeout=5.0)
|
||||
if response.status_code >= HTTP_ERROR_STATUS_CODE:
|
||||
return False, f"Server returned error status: {response.status_code}"
|
||||
|
||||
except httpx.TimeoutException:
|
||||
return False, "Connection timed out. Server may be down or unreachable."
|
||||
except httpx.NetworkError:
|
||||
return False, "Network error. Could not reach the server."
|
||||
else:
|
||||
return True, ""
|
||||
|
||||
except (httpx.HTTPError, ValueError, OSError) as e:
|
||||
return False, f"URL validation error: {e!s}"
|
||||
|
||||
async def pre_check_redirect(self, url: str | None) -> str | None:
|
||||
"""Check for redirects and return the final URL."""
|
||||
if url is None:
|
||||
return url
|
||||
try:
|
||||
async with httpx.AsyncClient(follow_redirects=False) as client:
|
||||
response = await client.request("HEAD", url)
|
||||
if response.status_code == httpx.codes.TEMPORARY_REDIRECT:
|
||||
return response.headers.get("Location", url)
|
||||
except (httpx.RequestError, httpx.HTTPError) as e:
|
||||
logger.warning(f"Error checking redirects: {e}")
|
||||
return url
|
||||
|
||||
async def _connect_with_timeout(
|
||||
self, url: str, headers: dict[str, str] | None, timeout_seconds: int, sse_read_timeout_seconds: int
|
||||
self, url: str | None, headers: dict[str, str] | None, timeout_seconds: int, sse_read_timeout_seconds: int
|
||||
):
|
||||
sse_transport = await self.exit_stack.enter_async_context(
|
||||
sse_client(url, headers, timeout_seconds, sse_read_timeout_seconds)
|
||||
)
|
||||
self.sse, self.write = sse_transport
|
||||
self.session = await self.exit_stack.enter_async_context(ClientSession(self.sse, self.write))
|
||||
await self.session.initialize()
|
||||
"""Attempt to connect with timeout."""
|
||||
try:
|
||||
if url is None:
|
||||
return
|
||||
sse_transport = await self.exit_stack.enter_async_context(
|
||||
sse_client(url, headers, timeout_seconds, sse_read_timeout_seconds)
|
||||
)
|
||||
self.sse, self.write = sse_transport
|
||||
self.session = await self.exit_stack.enter_async_context(ClientSession(self.sse, self.write))
|
||||
await self.session.initialize()
|
||||
except Exception as e:
|
||||
msg = f"Failed to establish SSE connection: {e!s}"
|
||||
raise ConnectionError(msg) from e
|
||||
|
||||
async def connect_to_server(
|
||||
self, url: str, headers: dict[str, str] | None, timeout_seconds: int = 500, sse_read_timeout_seconds: int = 500
|
||||
self,
|
||||
url: str | None,
|
||||
headers: dict[str, str] | None,
|
||||
timeout_seconds: int = 30,
|
||||
sse_read_timeout_seconds: int = 30,
|
||||
):
|
||||
"""Connect to server with retries and improved error handling."""
|
||||
if headers is None:
|
||||
headers = {}
|
||||
|
||||
# First validate the URL
|
||||
is_valid, error_msg = await self.validate_url(url)
|
||||
if not is_valid:
|
||||
msg = f"Invalid SSE URL ({url}): {error_msg}"
|
||||
raise ValueError(msg)
|
||||
|
||||
url = await self.pre_check_redirect(url)
|
||||
try:
|
||||
await asyncio.wait_for(
|
||||
self._connect_with_timeout(url, headers, timeout_seconds, sse_read_timeout_seconds),
|
||||
timeout=timeout_seconds,
|
||||
)
|
||||
if self.session is None:
|
||||
msg = "Session not initialized"
|
||||
raise ValueError(msg)
|
||||
response = await self.session.list_tools()
|
||||
except asyncio.TimeoutError as err:
|
||||
msg = f"Connection to {url} timed out after {timeout_seconds} seconds"
|
||||
raise TimeoutError(msg) from err
|
||||
return response.tools
|
||||
last_error = None
|
||||
|
||||
for attempt in range(self.max_retries):
|
||||
try:
|
||||
await asyncio.wait_for(
|
||||
self._connect_with_timeout(url, headers, timeout_seconds, sse_read_timeout_seconds),
|
||||
timeout=timeout_seconds,
|
||||
)
|
||||
|
||||
if self.session is None:
|
||||
msg = "Session not initialized"
|
||||
raise ValueError(msg)
|
||||
|
||||
response = await self.session.list_tools()
|
||||
|
||||
except asyncio.TimeoutError:
|
||||
last_error = f"Connection to {url} timed out after {timeout_seconds} seconds"
|
||||
logger.warning(f"Connection attempt {attempt + 1} failed: {last_error}")
|
||||
except ConnectionError as err:
|
||||
last_error = str(err)
|
||||
logger.warning(f"Connection attempt {attempt + 1} failed: {last_error}")
|
||||
except (ValueError, httpx.HTTPError, OSError) as err:
|
||||
last_error = f"Connection error: {err!s}"
|
||||
logger.warning(f"Connection attempt {attempt + 1} failed: {last_error}")
|
||||
else:
|
||||
return response.tools
|
||||
|
||||
if attempt < self.max_retries - 1:
|
||||
await asyncio.sleep(self.retry_delay * (attempt + 1))
|
||||
|
||||
msg = f"Failed to connect after {self.max_retries} attempts. Last error: {last_error}"
|
||||
raise ConnectionError(msg)
|
||||
|
|
|
|||
|
|
@ -1,9 +1,11 @@
|
|||
import asyncio
|
||||
import os
|
||||
from typing import Any
|
||||
|
||||
import httpx
|
||||
from langchain_core.tools import StructuredTool
|
||||
|
||||
from langflow.base.mcp.util import (
|
||||
HTTP_ERROR_STATUS_CODE,
|
||||
MCPSseClient,
|
||||
MCPStdioClient,
|
||||
create_input_schema_from_json_schema,
|
||||
|
|
@ -21,12 +23,14 @@ from langflow.schema import Message
|
|||
|
||||
class MCPToolsComponent(Component):
|
||||
schema_inputs: list[InputTypes] = []
|
||||
stdio_client = MCPStdioClient()
|
||||
sse_client = MCPSseClient()
|
||||
stdio_client: MCPStdioClient = MCPStdioClient()
|
||||
sse_client: MCPSseClient = MCPSseClient()
|
||||
tools: list = []
|
||||
tool_names: list[str] = []
|
||||
_tool_cache: dict = {} # Cache for tool objects
|
||||
default_keys = ["code", "_type", "mode", "command", "sse_url", "tool_placeholder", "tool_mode", "tool"]
|
||||
default_keys: list[str] = ["code", "_type", "mode", "command", "sse_url", "tool_placeholder", "tool_mode", "tool"]
|
||||
|
||||
sse_url: str | None = None
|
||||
|
||||
display_name = "MCP Server"
|
||||
description = "Connect to an MCP server and expose tools."
|
||||
|
|
@ -54,7 +58,6 @@ class MCPToolsComponent(Component):
|
|||
name="sse_url",
|
||||
display_name="MCP SSE URL",
|
||||
info="URL for MCP SSE connection",
|
||||
value="http://localhost:7860/api/v1/mcp/sse",
|
||||
show=False,
|
||||
refresh_button=True,
|
||||
),
|
||||
|
|
@ -82,6 +85,21 @@ class MCPToolsComponent(Component):
|
|||
Output(display_name="Response", name="response", method="build_output"),
|
||||
]
|
||||
|
||||
async def find_langflow_instance(self) -> tuple[bool, int | None, str]:
|
||||
"""Find Langflow instance by checking env variable first, then scanning common ports."""
|
||||
# First check environment variable
|
||||
env_port = os.getenv("LANGFLOW_PORT")
|
||||
port = int(env_port) if env_port else 7860
|
||||
try:
|
||||
url = f"http://localhost:{port}/api/v1/mcp/sse"
|
||||
async with httpx.AsyncClient() as client:
|
||||
response = await client.head(url, timeout=2.0)
|
||||
if response.status_code < HTTP_ERROR_STATUS_CODE:
|
||||
return True, port, f"Langflow instance found at configured port {port}"
|
||||
except (ValueError, httpx.TimeoutException, httpx.NetworkError, httpx.HTTPError):
|
||||
logger.warning(f"Could not connect to Langflow at configured port {env_port}")
|
||||
return False, None, "No Langflow instance found on configured port or common ports"
|
||||
|
||||
async def _validate_connection_params(self, mode: str, command: str | None = None, url: str | None = None) -> None:
|
||||
"""Validate connection parameters based on mode."""
|
||||
if mode not in ["Stdio", "SSE"]:
|
||||
|
|
@ -131,8 +149,26 @@ class MCPToolsComponent(Component):
|
|||
elif field_value == "SSE":
|
||||
build_config["command"]["show"] = False
|
||||
build_config["sse_url"]["show"] = True
|
||||
_, port, _ = await self.find_langflow_instance()
|
||||
if port:
|
||||
build_config["sse_url"]["value"] = f"http://localhost:{port}/api/v1/mcp/sse"
|
||||
self.sse_url = build_config["sse_url"]["value"]
|
||||
return build_config
|
||||
if field_name in ("command", "sse_url", "mode"):
|
||||
try:
|
||||
# If SSE mode and localhost URL is not valid, try to find correct port
|
||||
if field_name == "sse_url":
|
||||
self.sse_url = field_value
|
||||
elif self.mode == "SSE" and ("localhost" in str(self.sse_url) or "127.0.0.1" in str(self.sse_url)):
|
||||
is_valid, _ = await self.sse_client.validate_url(self.sse_url)
|
||||
if not is_valid:
|
||||
found, port, message = await self.find_langflow_instance()
|
||||
if found:
|
||||
new_url = f"http://localhost:{port}/api/v1/mcp/sse"
|
||||
logger.info(f"Original URL {self.sse_url} not valid. {message}")
|
||||
build_config["sse_url"]["value"] = new_url
|
||||
self.sse_url = new_url
|
||||
|
||||
await self.update_tools()
|
||||
if "tool" in build_config:
|
||||
build_config["tool"]["options"] = self.tool_names
|
||||
|
|
@ -285,7 +321,33 @@ class MCPToolsComponent(Component):
|
|||
if not self.stdio_client.session:
|
||||
self.tools = await self.stdio_client.connect_to_server(self.command)
|
||||
elif self.mode == "SSE" and not self.sse_client.session:
|
||||
self.tools = await self.sse_client.connect_to_server(self.sse_url, {})
|
||||
try:
|
||||
is_valid, _ = await self.sse_client.validate_url(self.sse_url)
|
||||
if not is_valid:
|
||||
msg = f"Invalid SSE URL configuration: {self.sse_url}. Please check the SSE URL and try again."
|
||||
logger.error(msg)
|
||||
return []
|
||||
self.tools = await self.sse_client.connect_to_server(self.sse_url, {})
|
||||
except ValueError as e:
|
||||
# URL validation error
|
||||
logger.error(f"SSE URL validation error: {e}")
|
||||
msg = f"Invalid SSE URL configuration: {e}. Please check your Langflow deployment URL and port."
|
||||
raise ValueError(msg) from e
|
||||
except ConnectionError as e:
|
||||
# Connection failed after retries
|
||||
logger.error(f"SSE connection error: {e}")
|
||||
msg = (
|
||||
f"Could not connect to Langflow SSE endpoint: {e}. "
|
||||
"Please verify:\n"
|
||||
"1. Langflow server is running\n"
|
||||
"2. The SSE URL matches your Langflow deployment port\n"
|
||||
"3. There are no network issues preventing the connection"
|
||||
)
|
||||
raise ValueError(msg) from e
|
||||
except Exception as e:
|
||||
logger.error(f"Unexpected SSE error: {e}")
|
||||
msg = f"Unexpected error connecting to SSE endpoint: {e}"
|
||||
raise ValueError(msg) from e
|
||||
|
||||
if not self.tools:
|
||||
logger.warning("No tools returned from server")
|
||||
|
|
@ -300,8 +362,7 @@ class MCPToolsComponent(Component):
|
|||
try:
|
||||
args_schema = create_input_schema_from_json_schema(tool.inputSchema)
|
||||
if not args_schema:
|
||||
msg = f"Empty schema for tool '{tool.name}', skipping"
|
||||
logger.warning(msg)
|
||||
logger.warning(f"Empty schema for tool '{tool.name}', skipping")
|
||||
continue
|
||||
|
||||
client = self.stdio_client if self.mode == "Stdio" else self.sse_client
|
||||
|
|
@ -320,15 +381,18 @@ class MCPToolsComponent(Component):
|
|||
tool_list.append(tool_obj)
|
||||
self._tool_cache[tool.name] = tool_obj
|
||||
except (AttributeError, ValueError, TypeError, KeyError) as e:
|
||||
msg = f"Error creating tool {getattr(tool, 'name', 'unknown')}: {e!s}"
|
||||
msg = f"Error creating tool {getattr(tool, 'name', 'unknown')}: {e}"
|
||||
logger.exception(msg)
|
||||
continue
|
||||
|
||||
self.tool_names = [tool.name for tool in self.tools if hasattr(tool, "name")]
|
||||
|
||||
except (ValueError, RuntimeError, asyncio.TimeoutError) as e:
|
||||
msg = f"Error updating tools: {e!s}"
|
||||
logger.exception(msg)
|
||||
except ValueError as e:
|
||||
# Re-raise validation errors with clear messages
|
||||
raise ValueError(str(e)) from e
|
||||
except Exception as e:
|
||||
logger.exception("Error updating tools")
|
||||
msg = f"Failed to update tools: {e!s}"
|
||||
raise ValueError(msg) from e
|
||||
else:
|
||||
return tool_list
|
||||
|
|
|
|||
|
|
@ -209,7 +209,10 @@ class TestMCPSseClient:
|
|||
async def test_connect_to_server(self, sse_client):
|
||||
"""Test connecting to server via SSE."""
|
||||
# Mock the pre_check_redirect first
|
||||
with patch.object(sse_client, "pre_check_redirect", return_value="http://test.url"):
|
||||
with (
|
||||
patch.object(sse_client, "pre_check_redirect", return_value="http://test.url"),
|
||||
patch.object(sse_client, "validate_url", return_value=(True, "")),
|
||||
):
|
||||
# Create mock for sse_client context manager
|
||||
mock_sse = AsyncMock()
|
||||
mock_write = AsyncMock()
|
||||
|
|
@ -245,11 +248,22 @@ class TestMCPSseClient:
|
|||
|
||||
async def test_connect_timeout(self, sse_client):
|
||||
"""Test connection timeout handling."""
|
||||
# Set max_retries to 1 to avoid multiple retry attempts
|
||||
sse_client.max_retries = 1
|
||||
|
||||
with (
|
||||
patch.object(sse_client, "pre_check_redirect", return_value="http://test.url"),
|
||||
patch.object(sse_client, "validate_url", return_value=(True, "")), # Mock URL validation
|
||||
patch.object(sse_client, "_connect_with_timeout") as mock_connect,
|
||||
):
|
||||
mock_connect.side_effect = asyncio.TimeoutError()
|
||||
|
||||
with pytest.raises(TimeoutError, match="Connection to http://test.url timed out after 1 seconds"):
|
||||
# Expect ConnectionError instead of TimeoutError
|
||||
with pytest.raises(
|
||||
ConnectionError,
|
||||
match=(
|
||||
"Failed to connect after 1 attempts. "
|
||||
"Last error: Connection to http://test.url timed out after 1 seconds"
|
||||
),
|
||||
):
|
||||
await sse_client.connect_to_server("http://test.url", {}, timeout_seconds=1)
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue