# Source: langflow/src/backend/base/langflow/components/tools/mcp_component.py
# Snapshot: 2025-05-07 20:48:45 +00:00 (531 lines, 21 KiB, Python)

import re
from typing import Any
from langchain_core.tools import StructuredTool
from langflow.base.mcp.util import (
MCPSseClient,
MCPStdioClient,
create_input_schema_from_json_schema,
create_tool_coroutine,
create_tool_func,
)
from langflow.custom import Component
from langflow.inputs import DropdownInput, TableInput
from langflow.inputs.inputs import InputTypes
from langflow.io import MessageTextInput, MultilineInput, Output, TabInput
from langflow.io.schema import flatten_schema, schema_to_langflow_inputs
from langflow.logging import logger
from langflow.schema import Message
def maybe_unflatten_dict(flat: dict[str, Any]) -> dict[str, Any]:
    """Rebuild a nested structure from a flat mapping with path-like keys.

    Keys use ``.`` to descend into objects and ``[index]`` to descend into
    lists, e.g. ``"user.tags[1]"`` or ``"grid[0][2].x"``. If no key contains
    a dot or an ``[index]`` suffix, ``flat`` itself is returned unchanged
    (the same object, not a copy).

    Args:
        flat: Mapping of flattened keys to leaf values.

    Returns:
        The nested dict/list structure, or ``flat`` when nothing is nested.
        List slots that were skipped over are padded with empty dicts.
    """
    # Quick check: do we have any nested keys?
    if not any(re.search(r"\.|\[\d+\]", key) for key in flat):
        return flat

    # A segment such as "name[0]" or "name[0][1]": a name followed by one or
    # more indices. The lazy name group keeps every trailing index out of the
    # name, so multi-dimensional segments are not misread as a key that
    # literally ends in "[0]".
    segment_re = re.compile(r"^(.+?)((?:\[\d+\])+)$")
    index_re = re.compile(r"\[(\d+)\]")

    nested: dict[str, Any] = {}
    for key, val in flat.items():
        parts = key.split(".")
        last_part = len(parts) - 1
        cur: Any = nested
        for i, part in enumerate(parts):
            is_leaf = i == last_part
            m = segment_re.match(part)
            if not m:
                # Normal object key
                if is_leaf:
                    cur[part] = val
                else:
                    cur = cur.setdefault(part, {})
                continue

            # Array segment
            name = m.group(1)
            indices = [int(n) for n in index_re.findall(m.group(2))]
            lst = cur.setdefault(name, [])
            # Every index except the last one selects an inner list.
            for idx in indices[:-1]:
                while len(lst) <= idx:
                    lst.append([])
                if lst[idx] == {}:
                    # Slot was only padding from an earlier key; make it a list.
                    lst[idx] = []
                lst = lst[idx]
            idx = indices[-1]
            # Ensure list is big enough
            while len(lst) <= idx:
                lst.append({})
            if is_leaf:
                lst[idx] = val
            else:
                cur = lst[idx]
    return nested
class MCPToolsComponent(Component):
    """Component that connects to an MCP server and runs one of its tools.

    The connection is made either over stdio (a command is launched) or over
    SSE (a URL is contacted). The tools reported by the server populate the
    ``tool`` dropdown, and the selected tool's input schema is turned into
    extra inputs in the build config.
    """

    # NOTE(review): everything below is defined at class level. Unless
    # ``Component`` copies these per instance, the two clients and the mutable
    # containers (notably ``_tool_cache``, which is mutated in place by
    # ``update_tools``) are shared by every instance of this component.
    # Confirm whether they should be created per instance instead.
    schema_inputs: list[InputTypes] = []
    stdio_client: MCPStdioClient = MCPStdioClient()
    sse_client: MCPSseClient = MCPSseClient()
    tools: list = []
    tool_names: list[str] = []
    _tool_cache: dict = {}  # Cache for tool objects
    # Build-config keys that belong to the component itself. Every other key
    # is treated as a tool-specific input and may be removed or hidden.
    default_keys: list[str] = [
        "code",
        "_type",
        "mode",
        "command",
        "env",
        "sse_url",
        "tool_placeholder",
        "tool_mode",
        "tool",
        "headers_input",
    ]

    display_name = "MCP Connection"
    description = "Connect to an MCP server to use its tools."
    icon = "Mcp"
    name = "MCPTools"

    inputs = [
        TabInput(
            name="mode",
            display_name="Mode",
            options=["Stdio", "SSE"],
            value="Stdio",
            info="Select the connection mode",
            real_time_refresh=True,
        ),
        MessageTextInput(
            name="command",
            display_name="MCP Command",
            info="Command for MCP stdio connection",
            value="uvx mcp-server-fetch",
            show=True,
            refresh_button=True,
        ),
        MessageTextInput(
            name="env",
            display_name="Env",
            info="Env vars to include in mcp stdio connection (i.e. DEBUG=true)",
            value="",
            is_list=True,
            show=True,
            tool_mode=False,
            advanced=True,
        ),
        MultilineInput(
            name="sse_url",
            display_name="MCP SSE URL",
            info="URL for MCP SSE connection",
            show=False,
            refresh_button=True,
            value="MCP_SSE",
            real_time_refresh=True,
        ),
        TableInput(
            name="headers_input",
            display_name="Headers",
            info="Headers to include in the tool",
            show=False,
            real_time_refresh=True,
            table_schema=[
                {
                    "name": "key",
                    "display_name": "Header",
                    "type": "str",
                    "description": "Header name",
                },
                {
                    "name": "value",
                    "display_name": "Value",
                    "type": "str",
                    "description": "Header value",
                },
            ],
            value=[],
            advanced=True,
        ),
        DropdownInput(
            name="tool",
            display_name="Tool",
            options=[],
            value="",
            info="Select the tool to execute",
            show=True,
            required=True,
            real_time_refresh=True,
        ),
        MessageTextInput(
            name="tool_placeholder",
            display_name="Tool Placeholder",
            info="Placeholder for the tool",
            value="",
            show=False,
            tool_mode=True,
        ),
    ]

    outputs = [
        Output(display_name="Response", name="response", method="build_output"),
    ]

    async def _validate_connection_params(self, mode: str, command: str | None = None, url: str | None = None) -> None:
        """Validate connection parameters based on mode.

        Args:
            mode: Connection mode; must be ``"Stdio"`` or ``"SSE"``.
            command: Command to launch; required when mode is ``"Stdio"``.
            url: Endpoint URL; required when mode is ``"SSE"``.

        Raises:
            ValueError: If the mode is unknown or the parameter required by
                the mode is missing or empty.
        """
        if mode not in ("Stdio", "SSE"):
            msg = f"Invalid mode: {mode}. Must be either 'Stdio' or 'SSE'"
            raise ValueError(msg)
        if mode == "Stdio" and not command:
            msg = "Command is required for Stdio mode"
            raise ValueError(msg)
        if mode == "SSE" and not url:
            msg = "URL is required for SSE mode"
            raise ValueError(msg)

    def _process_headers(self, headers: Any) -> dict:
        """Process the headers input into a valid dictionary.

        Args:
            headers: The headers to process. A dict is returned as is; a list
                is read as rows of ``{"key": ..., "value": ...}``; anything
                else (including ``None``) yields an empty dict.

        Returns:
            Processed dictionary. Rows that are not key/value dicts are
            skipped. If a row cannot be stored (for example an unhashable
            key), the failure is logged and an empty dict is returned.
        """
        if headers is None:
            return {}
        if isinstance(headers, dict):
            return headers
        if isinstance(headers, list):
            processed_headers = {}
            try:
                for item in headers:
                    if not self._is_valid_key_value_item(item):
                        continue
                    processed_headers[item["key"]] = item["value"]
            except (KeyError, TypeError, ValueError) as e:
                self.log(f"Failed to process headers list: {e}")
                return {}  # Return empty dictionary instead of None
            return processed_headers
        return {}

    def _is_valid_key_value_item(self, item: Any) -> bool:
        """Check if an item is a dict holding both ``"key"`` and ``"value"``."""
        return isinstance(item, dict) and "key" in item and "value" in item

    async def _validate_schema_inputs(self, tool_obj) -> list[InputTypes]:
        """Validate and process schema inputs for a tool.

        Args:
            tool_obj: Tool description exposing ``name`` and ``inputSchema``.

        Returns:
            The inputs derived from the tool's schema, or an empty list when
            the schema defines no parameters (a warning is logged).

        Raises:
            ValueError: If the tool object is missing or has no input schema,
                the derived schema is empty, or any conversion step fails.
                The original error is chained as the cause.
        """
        try:
            if not tool_obj or not hasattr(tool_obj, "inputSchema"):
                msg = "Invalid tool object or missing input schema"
                raise ValueError(msg)

            flat_schema = flatten_schema(tool_obj.inputSchema)
            input_schema = create_input_schema_from_json_schema(flat_schema)
            if not input_schema:
                msg = f"Empty input schema for tool '{tool_obj.name}'"
                raise ValueError(msg)

            schema_inputs = schema_to_langflow_inputs(input_schema)
            if not schema_inputs:
                msg = f"No input parameters defined for tool '{tool_obj.name}'"
                logger.warning(msg)
                return []
        except Exception as e:
            msg = f"Error validating schema inputs: {e!s}"
            logger.exception(msg)
            raise ValueError(msg) from e
        else:
            return schema_inputs

    async def _update_tools_from_build_config(self, build_config: dict) -> None:
        """Refresh the tool list using the connection values in ``build_config``."""
        await self.update_tools(
            mode=build_config["mode"]["value"],
            command=build_config["command"]["value"],
            url=build_config["sse_url"]["value"],
            env=build_config["env"]["value"],
            headers=build_config["headers_input"]["value"],
        )

    async def update_build_config(self, build_config: dict, field_value: str, field_name: str | None = None) -> dict:
        """Update the build config in response to a changed field.

        Handles four kinds of change:

        * ``mode``: drops tool-specific keys, clears the tool options and
          toggles which connection fields are shown.
        * ``command`` / ``sse_url`` (and ``mode`` when switching to Stdio):
          reconnects and refreshes the tool dropdown options.
        * ``tool``: replaces the tool-specific inputs with those of the
          selected tool.
        * ``tool_mode``: shows or hides the tool dropdown and the
          tool-specific inputs.

        Args:
            build_config: The build configuration; mutated in place.
            field_value: New value of the changed field.
            field_name: Name of the changed field.

        Returns:
            The same ``build_config`` object.

        Raises:
            ValueError: If any step fails; the original error is chained.
        """
        try:
            if field_name == "mode":
                self.remove_non_default_keys(build_config)
                build_config["tool"]["options"] = []
                if field_value == "Stdio":
                    build_config["command"]["show"] = True
                    build_config["env"]["show"] = True
                    build_config["headers_input"]["show"] = False
                    build_config["sse_url"]["show"] = False
                elif field_value == "SSE":
                    build_config["command"]["show"] = False
                    build_config["env"]["show"] = False
                    build_config["sse_url"]["show"] = True
                    build_config["sse_url"]["value"] = "MCP_SSE"
                    build_config["headers_input"]["show"] = True
                    # NOTE(review): this early return is scoped to the SSE
                    # branch, so switching to Stdio falls through and refreshes
                    # the tools below (which is why "mode" is in the tuple).
                    # Confirm this is the intended placement.
                    return build_config

            if field_name in ("command", "sse_url", "mode"):
                try:
                    await self._update_tools_from_build_config(build_config)
                    if "tool" in build_config:
                        build_config["tool"]["options"] = self.tool_names
                except Exception as e:
                    # Never leave stale options behind after a failed refresh.
                    build_config["tool"]["options"] = []
                    msg = f"Failed to update tools: {e!s}"
                    raise ValueError(msg) from e
                else:
                    return build_config
            elif field_name == "tool":
                if not self.tools:
                    await self._update_tools_from_build_config(build_config)
                if self.tool is None:
                    return build_config
                if not any(tool.name == self.tool for tool in self.tools):
                    msg = f"Tool {self.tool} not found in available tools: {self.tools}"
                    logger.warning(msg)
                    return build_config
                self.remove_non_default_keys(build_config)
                await self._update_tool_config(build_config, field_value)
            elif field_name == "tool_mode":
                # In tool mode the dropdown and the tool-specific inputs are hidden.
                build_config["tool"]["show"] = not field_value
                for key, value in build_config.items():
                    if key not in self.default_keys and isinstance(value, dict) and "show" in value:
                        value["show"] = not field_value
        except Exception as e:
            msg = f"Error in update_build_config: {e!s}"
            logger.exception(msg)
            raise ValueError(msg) from e
        else:
            return build_config

    def get_inputs_for_all_tools(self, tools: list) -> dict:
        """Get input schemas for all tools.

        Args:
            tools: Tool descriptions exposing ``name`` and ``inputSchema``.

        Returns:
            Mapping of tool name to the inputs derived from its schema. Tools
            that are falsy, unnamed, or whose schema cannot be converted are
            left out (conversion errors are logged).
        """
        inputs = {}
        for tool in tools:
            if not tool or not hasattr(tool, "name"):
                continue
            try:
                flat_schema = flatten_schema(tool.inputSchema)
                input_schema = create_input_schema_from_json_schema(flat_schema)
                inputs[tool.name] = schema_to_langflow_inputs(input_schema)
            except (AttributeError, ValueError, TypeError, KeyError) as e:
                msg = f"Error getting inputs for tool {getattr(tool, 'name', 'unknown')}: {e!s}"
                logger.exception(msg)
                continue
        return inputs

    def remove_input_schema_from_build_config(
        self, build_config: dict, tool_name: str, input_schema: dict[str, list[InputTypes]]
    ) -> None:
        """Remove the inputs of every tool except ``tool_name`` from the build config.

        Args:
            build_config: The build configuration; mutated in place.
            tool_name: Name of the tool whose inputs are left alone.
            input_schema: Mapping of tool name to that tool's inputs.
        """
        for other_name, other_inputs in input_schema.items():
            if other_name == tool_name:
                continue
            for _input in other_inputs:
                build_config.pop(_input.name, None)

    def remove_non_default_keys(self, build_config: dict) -> None:
        """Remove every key that is not listed in ``default_keys``."""
        # Snapshot the keys first: the dict is mutated while we walk it.
        for key in list(build_config):
            if key not in self.default_keys:
                build_config.pop(key)

    async def _update_tool_config(self, build_config: dict, tool_name: str) -> None:
        """Put the inputs of ``tool_name`` into the build config.

        Connects to the server first if no tools are known yet. Does nothing
        when ``tool_name`` is empty or is not among the available tools.

        Args:
            build_config: The build configuration; mutated in place.
            tool_name: Name of the selected tool.

        Raises:
            ValueError: If an attribute, key or type error occurs while
                updating the config. Schema validation errors are logged and
                swallowed, leaving ``schema_inputs`` empty.
        """
        if not self.tools:
            await self._update_tools_from_build_config(build_config)

        if not tool_name:
            return

        tool_obj = next((tool for tool in self.tools if tool.name == tool_name), None)
        if not tool_obj:
            msg = f"Tool {tool_name} not found in available tools: {self.tools}"
            logger.warning(msg)
            return

        try:
            # Get all tool inputs and remove old ones
            input_schema_for_all_tools = self.get_inputs_for_all_tools(self.tools)
            self.remove_input_schema_from_build_config(build_config, tool_name, input_schema_for_all_tools)

            # Get and validate new inputs
            self.schema_inputs = await self._validate_schema_inputs(tool_obj)
            if not self.schema_inputs:
                msg = f"No input parameters to configure for tool '{tool_name}'"
                logger.info(msg)
                return

            # Add new inputs to build config
            for schema_input in self.schema_inputs:
                if not schema_input or not hasattr(schema_input, "name"):
                    msg = "Invalid schema input detected, skipping"
                    logger.warning(msg)
                    continue
                try:
                    input_dict = schema_input.to_dict()
                    input_dict.setdefault("value", None)
                    input_dict.setdefault("required", True)
                    build_config[schema_input.name] = input_dict
                except (AttributeError, KeyError, TypeError) as e:
                    msg = f"Error processing schema input {schema_input}: {e!s}"
                    logger.exception(msg)
                    continue
        except ValueError as e:
            msg = f"Schema validation error for tool {tool_name}: {e!s}"
            logger.exception(msg)
            self.schema_inputs = []
            return
        except (AttributeError, KeyError, TypeError) as e:
            msg = f"Error updating tool config: {e!s}"
            logger.exception(msg)
            raise ValueError(msg) from e

    async def build_output(self) -> Message:
        """Run the selected tool and return the text of its last content item.

        Returns:
            A message with the tool's output, or an error message asking the
            user to select a tool when none is selected.

        Raises:
            ValueError: If connecting, looking up the tool, or running it
                fails; the original error is chained.
        """
        try:
            await self.update_tools()
            if not self.tool:
                return Message(text="You must select a tool", error=True)

            exec_tool = self._tool_cache.get(self.tool)
            if exec_tool is None:
                msg = f"Tool '{self.tool}' is not available from the connected MCP server"
                raise ValueError(msg)

            tool_args = self.get_inputs_for_all_tools(self.tools)[self.tool]
            # NOTE(review): the truthiness filter also drops legitimate falsy
            # arguments such as 0 or False. Kept as is because sending them
            # would change what servers receive for untouched inputs; confirm
            # the intended behaviour.
            kwargs = {}
            for arg in tool_args:
                value = getattr(self, arg.name, None)
                if value:
                    kwargs[arg.name] = value

            unflattened_kwargs = maybe_unflatten_dict(kwargs)
            output = await exec_tool.coroutine(**unflattened_kwargs)
            return Message(text=output.content[-1].text)
        except Exception as e:
            msg = f"Error in build_output: {e!s}"
            logger.exception(msg)
            raise ValueError(msg) from e

    async def update_tools(
        self,
        mode: str | None = None,
        command: str | None = None,
        url: str | None = None,
        env: list[str] | None = None,
        headers: dict[str, str] | list[dict[str, str]] | None = None,
    ) -> list[StructuredTool]:
        """Connect to the MCP server and update available tools.

        Any argument left as ``None`` falls back to the matching component
        attribute. A connection is only opened when the client for the chosen
        mode has no session yet; otherwise the tools already stored on the
        component are reused.

        Args:
            mode: ``"Stdio"`` or ``"SSE"``.
            command: Command for the stdio connection.
            url: URL for the SSE connection.
            env: Environment entries passed to the stdio connection.
            headers: Headers for the SSE connection, as a dict or as rows of
                ``{"key": ..., "value": ...}``.

        Returns:
            The tools built from the server's tool list; empty when the server
            reported none. Tools that cannot be built are logged and skipped.

        Raises:
            ValueError: If the parameters are invalid or the connection fails.
        """
        try:
            if mode is None:
                mode = self.mode
            if command is None:
                command = self.command
            if env is None:
                env = self.env
            if url is None:
                url = self.sse_url
            if headers is None:
                headers = self.headers_input
            headers = self._process_headers(headers)

            await self._validate_connection_params(mode, command, url)

            if mode == "Stdio":
                if not self.stdio_client.session:
                    self.tools = await self.stdio_client.connect_to_server(command, env)
            elif mode == "SSE" and not self.sse_client.session:
                try:
                    self.tools = await self.sse_client.connect_to_server(url, headers)
                except ValueError as e:
                    # URL validation error
                    logger.error(f"SSE URL validation error: {e}")
                    msg = f"Invalid SSE URL configuration: {e}. Please check your Langflow deployment URL and port."
                    raise ValueError(msg) from e
                except ConnectionError as e:
                    # Connection failed after retries
                    logger.error(f"SSE connection error: {e}")
                    msg = (
                        f"Could not connect to Langflow SSE endpoint: {e}. "
                        "Please verify:\n"
                        "1. Langflow server is running\n"
                        "2. The SSE URL matches your Langflow deployment port\n"
                        "3. There are no network issues preventing the connection"
                    )
                    raise ValueError(msg) from e
                except Exception as e:
                    logger.error(f"Unexpected SSE error: {e}")
                    msg = f"Unexpected error connecting to SSE endpoint: {e}"
                    raise ValueError(msg) from e

            if not self.tools:
                logger.warning("No tools returned from server")
                return []

            # Use the mode resolved above, not ``self.mode``: callers may pass
            # a mode that differs from the attribute, and the tools must be
            # bound to the client that was actually connected.
            client = self.stdio_client if mode == "Stdio" else self.sse_client

            tool_list = []
            for tool in self.tools:
                if not tool or not hasattr(tool, "name"):
                    logger.warning("Invalid tool object detected, skipping")
                    continue
                try:
                    args_schema = create_input_schema_from_json_schema(tool.inputSchema)
                    if not args_schema:
                        logger.warning(f"Empty schema for tool '{tool.name}', skipping")
                        continue

                    if not client or not client.session:
                        msg = f"Invalid client session for tool '{tool.name}'"
                        raise ValueError(msg)

                    tool_obj = StructuredTool(
                        name=tool.name,
                        description=tool.description or "",
                        args_schema=args_schema,
                        func=create_tool_func(tool.name, args_schema, client.session),
                        coroutine=create_tool_coroutine(tool.name, args_schema, client.session),
                        tags=[tool.name],
                    )
                    tool_list.append(tool_obj)
                    self._tool_cache[tool.name] = tool_obj
                except (AttributeError, ValueError, TypeError, KeyError) as e:
                    msg = f"Error creating tool {getattr(tool, 'name', 'unknown')}: {e}"
                    logger.exception(msg)
                    continue

            self.tool_names = [tool.name for tool in self.tools if hasattr(tool, "name")]
        except ValueError:
            # Validation and connection errors already carry clear messages.
            raise
        except Exception as e:
            logger.exception("Error updating tools")
            msg = f"Failed to update tools: {e!s}"
            raise ValueError(msg) from e
        else:
            return tool_list

    async def _get_tools(self):
        """Reconnect if needed and return the current tools.

        Raises:
            ValueError: If the mode is SSE and no SSE URL is set, or if
                updating the tools fails.
        """
        if self.mode == "SSE" and self.sse_url is None:
            msg = "SSE URL is not set"
            raise ValueError(msg)
        return await self.update_tools()