fix: support mcp env vars and nested inputs (#7772)

* env vars for mcp

* support nested mcp schemas

* [autofix.ci] apply automated fixes

* Update mcp_component.py

* [autofix.ci] apply automated fixes

* Update mcp_component.py

* Update mcp_component.py

* fix lint and mypy

* fix tests

---------

Co-authored-by: Edwin Jose <edwin.jose@datastax.com>
Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com>
This commit is contained in:
Sebastián Estévez 2025-04-24 00:40:16 -04:00 committed by GitHub
commit c161a2e68d
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
4 changed files with 426 additions and 125 deletions

View file

@ -11,10 +11,9 @@ from httpx import codes as httpx_codes
from loguru import logger
from mcp import ClientSession, StdioServerParameters, stdio_client
from mcp.client.sse import sse_client
from pydantic import Field, create_model
from pydantic import BaseModel, Field, create_model
from sqlmodel import select
from langflow.helpers.base_model import BaseModel
from langflow.services.database.models import Flow
HTTP_ERROR_STATUS_CODE = httpx_codes.BAD_REQUEST # HTTP status code for client errors
@ -78,44 +77,117 @@ async def get_flow_snake_case(flow_name: str, user_id: str, session) -> Flow | N
def create_input_schema_from_json_schema(schema: dict[str, Any]) -> type[BaseModel]:
"""Converts a JSON schema into a Pydantic model dynamically.
"""Dynamically build a Pydantic model from a JSON schema (with $defs).
Fields not listed as required are wrapped in Optional[...] and default to None if not provided.
:param schema: The JSON schema as a dictionary.
:return: A Pydantic model class.
Non-required fields become Optional[...] with default=None.
"""
if schema.get("type") != "object":
msg = "JSON schema must be of type 'object' at the root level."
msg = "Root schema must be type 'object'"
raise ValueError(msg)
fields = {}
properties = schema.get("properties", {})
required_fields = set(schema.get("required", []))
defs: dict[str, dict[str, Any]] = schema.get("$defs", {})
model_cache: dict[str, type[BaseModel]] = {}
for field_name, field_def in properties.items():
# Determine the base type from the JSON schema type string.
field_type_str = field_def.get("type", "str") # Defaults to string if not specified.
base_type = {
def resolve_ref(s: dict[str, Any] | None) -> dict[str, Any]:
"""Follow a $ref chain until you land on a real subschema."""
if s is None:
return {}
while "$ref" in s:
ref_name = s["$ref"].split("/")[-1]
s = defs.get(ref_name)
if s is None:
msg = f"Definition '{ref_name}' not found"
raise ValueError(msg)
return s
def parse_type(s: dict[str, Any] | None) -> Any:
"""Map a JSON Schema subschema to a Python type (possibly nested)."""
if s is None:
return None
s = resolve_ref(s)
if "anyOf" in s:
subtypes = [parse_type(sub) for sub in s["anyOf"]]
return tuple | subtypes
t = s.get("type", Any)
if t == "array":
item_schema = s.get("items", {})
# schema_type: type[Any]
schema_type: Any = parse_type(item_schema)
return list[schema_type]
if t == "object":
# inline object not in $defs ⇒ anonymous nested model
return _build_model(f"AnonModel{len(model_cache)}", s)
# primitive fallback
return {
"string": str,
"str": str,
"integer": int,
"int": int,
"number": float,
"boolean": bool,
"array": list,
"object": dict,
}.get(field_type_str, Any)
"array": list,
}.get(t, Any)
# if result == "":
# return Any
# return result
field_metadata = {"description": field_def.get("description", "")}
def _build_model(name: str, subschema: dict[str, Any]) -> type[BaseModel]:
"""Create (or fetch) a BaseModel subclass for the given object schema."""
# If this came via a named $ref, use that name
if "$ref" in subschema:
refname = subschema["$ref"].split("/")[-1]
if refname in model_cache:
return model_cache[refname]
target = defs.get(refname)
if not target:
msg = f"Definition '{refname}' not found"
raise ValueError(msg)
cls = _build_model(refname, target)
model_cache[refname] = cls
return cls
# For non-required fields, wrap the type in Optional[...] and set a default value.
if field_name not in required_fields:
field_metadata["default"] = field_def.get("default", None)
# Named anonymous or inline: avoid clashes by name
if name in model_cache:
return model_cache[name]
fields[field_name] = (base_type, Field(**field_metadata))
props = subschema.get("properties", {})
reqs = set(subschema.get("required", []))
fields: dict[str, Any] = {}
return create_model("InputSchema", **fields)
for prop_name, prop_schema in props.items():
py_type = parse_type(prop_schema)
is_required = prop_name in reqs
if not is_required:
py_type = py_type | None
default = prop_schema.get("default", None)
else:
default = ... # required by Pydantic
fields[prop_name] = (py_type, Field(default, description=prop_schema.get("description")))
model_cls = create_model(name, **fields)
model_cache[name] = model_cls
return model_cls
# build the top - level “InputSchema” from the root properties
top_props = schema.get("properties", {})
top_reqs = set(schema.get("required", []))
top_fields: dict[str, Any] = {}
for fname, fdef in top_props.items():
py_type = parse_type(fdef)
if fname not in top_reqs:
py_type = py_type | None
default = fdef.get("default", None)
else:
default = ...
top_fields[fname] = (py_type, Field(default, description=fdef.get("description")))
return create_model("InputSchema", **top_fields)
class MCPStdioClient:
@ -123,12 +195,20 @@ class MCPStdioClient:
self.session: ClientSession | None = None
self.exit_stack = AsyncExitStack()
async def connect_to_server(self, command_str: str):
async def connect_to_server(self, command_str: str, env: list[str] | None = None):
env_dict: dict[str, str] = {}
if env is None:
env = []
for var in env:
if "=" not in var:
msg = f"Invalid env var format: {var}. Must be in the format 'VAR_NAME=VAR_VALUE'"
raise ValueError(msg)
env_dict[var.split("=")[0]] = var.split("=")[1]
command = command_str.split(" ")
server_params = StdioServerParameters(
command=command[0],
args=command[1:],
env={"DEBUG": "true", "PATH": os.environ["PATH"]},
env={"DEBUG": "true", "PATH": os.environ["PATH"], **(env_dict or {})},
)
stdio_transport = await self.exit_stack.enter_async_context(stdio_client(server_params))
self.stdio, self.write = stdio_transport

View file

@ -1,11 +1,9 @@
import os
import re
from typing import Any
import httpx
from langchain_core.tools import StructuredTool
from langflow.base.mcp.util import (
HTTP_ERROR_STATUS_CODE,
MCPSseClient,
MCPStdioClient,
create_input_schema_from_json_schema,
@ -13,14 +11,52 @@ from langflow.base.mcp.util import (
create_tool_func,
)
from langflow.custom import Component
from langflow.inputs import DropdownInput
from langflow.inputs import DropdownInput, TableInput
from langflow.inputs.inputs import InputTypes
from langflow.io import MessageTextInput, MultilineInput, Output, TabInput
from langflow.io.schema import schema_to_langflow_inputs
from langflow.io.schema import flatten_schema, schema_to_langflow_inputs
from langflow.logging import logger
from langflow.schema import Message
def maybe_unflatten_dict(flat: dict[str, Any]) -> dict[str, Any]:
"""If any key looks nested (contains a dot or “[index]”), rebuild the.
full nested structure; otherwise return flat as is.
"""
# Quick check: do we have any nested keys?
if not any(re.search(r"\.|\[\d+\]", key) for key in flat):
return flat
# Otherwise, unflatten into dicts/lists
nested: dict[str, Any] = {}
array_re = re.compile(r"^(.+)\[(\d+)\]$")
for key, val in flat.items():
parts = key.split(".")
cur = nested
for i, part in enumerate(parts):
m = array_re.match(part)
# Array segment?
if m:
name, idx = m.group(1), int(m.group(2))
lst = cur.setdefault(name, [])
# Ensure list is big enough
while len(lst) <= idx:
lst.append({})
if i == len(parts) - 1:
lst[idx] = val
else:
cur = lst[idx]
# Normal object key
elif i == len(parts) - 1:
cur[part] = val
else:
cur = cur.setdefault(part, {})
return nested
class MCPToolsComponent(Component):
schema_inputs: list[InputTypes] = []
stdio_client: MCPStdioClient = MCPStdioClient()
@ -28,7 +64,18 @@ class MCPToolsComponent(Component):
tools: list = []
tool_names: list[str] = []
_tool_cache: dict = {} # Cache for tool objects
default_keys: list[str] = ["code", "_type", "mode", "command", "sse_url", "tool_placeholder", "tool_mode", "tool"]
default_keys: list[str] = [
"code",
"_type",
"mode",
"command",
"env",
"sse_url",
"tool_placeholder",
"tool_mode",
"tool",
"headers_input",
]
display_name = "MCP Server"
description = "Connect to an MCP server and expose tools."
@ -52,12 +99,45 @@ class MCPToolsComponent(Component):
show=True,
refresh_button=True,
),
MessageTextInput(
name="env",
display_name="Env",
info="Env vars to include in mcp stdio connection (i.e. DEBUG=true)",
value="",
is_list=True,
show=True,
tool_mode=False,
),
MultilineInput(
name="sse_url",
display_name="MCP SSE URL",
info="URL for MCP SSE connection",
show=False,
refresh_button=True,
value="MCP_SSE",
real_time_refresh=True,
),
TableInput(
name="headers_input",
display_name="Headers",
info="Headers to include in the tool",
show=False,
real_time_refresh=True,
table_schema=[
{
"name": "key",
"display_name": "Header",
"type": "str",
"description": "Header name",
},
{
"name": "value",
"display_name": "Value",
"type": "str",
"description": "Header value",
},
],
value=[],
),
DropdownInput(
name="tool",
@ -83,21 +163,6 @@ class MCPToolsComponent(Component):
Output(display_name="Response", name="response", method="build_output"),
]
async def find_langflow_instance(self) -> tuple[bool, int | None, str]:
"""Find Langflow instance by checking env variable first, then scanning common ports."""
# First check environment variable
env_port = os.getenv("LANGFLOW_PORT")
port = int(env_port) if env_port else 7860
try:
url = f"http://localhost:{port}/api/v1/mcp/sse"
async with httpx.AsyncClient() as client:
response = await client.head(url, timeout=2.0)
if response.status_code < HTTP_ERROR_STATUS_CODE:
return True, port, f"Langflow instance found at configured port {port}"
except (ValueError, httpx.TimeoutException, httpx.NetworkError, httpx.HTTPError):
logger.warning(f"Could not connect to Langflow at configured port {env_port}")
return False, None, "No Langflow instance found on configured port or common ports"
async def _validate_connection_params(self, mode: str, command: str | None = None, url: str | None = None) -> None:
"""Validate connection parameters based on mode."""
if mode not in ["Stdio", "SSE"]:
@ -111,6 +176,33 @@ class MCPToolsComponent(Component):
msg = "URL is required for SSE mode"
raise ValueError(msg)
def _process_headers(self, headers: Any) -> dict:
"""Process the headers input into a valid dictionary.
Args:
headers: The headers to process, can be dict, str, or list
Returns:
Processed dictionary
"""
if headers is None:
return {}
if isinstance(headers, dict):
return headers
if isinstance(headers, list):
processed_headers = {}
try:
for item in headers:
if not self._is_valid_key_value_item(item):
continue
key = item["key"]
value = item["value"]
processed_headers[key] = value
except (KeyError, TypeError, ValueError) as e:
self.log(f"Failed to process headers list: {e}")
return {} # Return empty dictionary instead of None
return processed_headers
return {}
async def _validate_schema_inputs(self, tool_obj) -> list[InputTypes]:
"""Validate and process schema inputs for a tool."""
try:
@ -118,7 +210,8 @@ class MCPToolsComponent(Component):
msg = "Invalid tool object or missing input schema"
raise ValueError(msg)
input_schema = create_input_schema_from_json_schema(tool_obj.inputSchema)
flat_schema = flatten_schema(tool_obj.inputSchema)
input_schema = create_input_schema_from_json_schema(flat_schema)
if not input_schema:
msg = f"Empty input schema for tool '{tool_obj.name}'"
raise ValueError(msg)
@ -143,43 +236,24 @@ class MCPToolsComponent(Component):
self.remove_non_default_keys(build_config)
if field_value == "Stdio":
build_config["command"]["show"] = True
build_config["env"]["show"] = True
build_config["headers_input"]["show"] = False
build_config["sse_url"]["show"] = False
elif field_value == "SSE":
build_config["command"]["show"] = False
build_config["env"]["show"] = False
build_config["sse_url"]["show"] = True
build_config["sse_url"]["value"] = "MCP_SSE"
build_config["headers_input"]["show"] = True
return build_config
if field_name in ("command", "sse_url", "mode"):
try:
# If SSE mode and localhost URL is not valid, try to find correct port
if build_config["mode"]["value"] == "SSE" and (
"localhost" in str(build_config["sse_url"]["value"])
or "127.0.0.1" in str(build_config["sse_url"]["value"])
):
is_valid, _ = await self.sse_client.validate_url(build_config["sse_url"]["value"])
if not is_valid:
found, port, message = await self.find_langflow_instance()
if found:
new_url = f"http://localhost:{port}/api/v1/mcp/sse"
logger.info(f"Original URL {build_config['sse_url']['value']} not valid. {message}")
build_config["sse_url"]["value"] = new_url
elif build_config["mode"]["value"] == "SSE":
if len(build_config["sse_url"]["value"]) > 0:
is_valid, _ = await self.sse_client.validate_url(build_config["sse_url"]["value"])
if not is_valid:
msg = (
f"Invalid SSE URL configuration: {build_config['sse_url']['value']}. "
"Please check the SSE URL and try again."
)
raise ValueError(msg)
else:
build_config["tool"]["options"] = []
return build_config
await self.update_tools(
mode=build_config["mode"]["value"],
command=build_config["command"]["value"],
url=build_config["sse_url"]["value"],
env=build_config["env"]["value"],
headers=build_config["headers_input"]["value"],
)
if "tool" in build_config:
build_config["tool"]["options"] = self.tool_names
@ -195,6 +269,8 @@ class MCPToolsComponent(Component):
mode=build_config["mode"]["value"],
command=build_config["command"]["value"],
url=build_config["sse_url"]["value"],
env=build_config["env"]["value"],
headers=build_config["headers_input"]["value"],
)
if self.tool is None:
return build_config
@ -229,8 +305,10 @@ class MCPToolsComponent(Component):
if not tool or not hasattr(tool, "name"):
continue
try:
input_schema = schema_to_langflow_inputs(create_input_schema_from_json_schema(tool.inputSchema))
inputs[tool.name] = input_schema
flat_schema = flatten_schema(tool.inputSchema)
input_schema = create_input_schema_from_json_schema(flat_schema)
langflow_inputs = schema_to_langflow_inputs(input_schema)
inputs[tool.name] = langflow_inputs
except (AttributeError, ValueError, TypeError, KeyError) as e:
msg = f"Error getting inputs for tool {getattr(tool, 'name', 'unknown')}: {e!s}"
logger.exception(msg)
@ -262,6 +340,8 @@ class MCPToolsComponent(Component):
mode=build_config["mode"]["value"],
command=build_config["command"]["value"],
url=build_config["sse_url"]["value"],
env=build_config["env"]["value"],
headers=build_config["headers_input"]["value"],
)
if not tool_name:
@ -302,7 +382,6 @@ class MCPToolsComponent(Component):
msg = f"Error processing schema input {schema_input}: {e!s}"
logger.exception(msg)
continue
except ValueError as e:
msg = f"Schema validation error for tool {tool_name}: {e!s}"
logger.exception(msg)
@ -325,7 +404,11 @@ class MCPToolsComponent(Component):
value = getattr(self, arg.name, None)
if value:
kwargs[arg.name] = value
output = await exec_tool.coroutine(**kwargs)
unflattened_kwargs = maybe_unflatten_dict(kwargs)
output = await exec_tool.coroutine(**unflattened_kwargs)
return Message(text=output.content[len(output.content) - 1].text)
return Message(text="You must select a tool", error=True)
except Exception as e:
@ -334,7 +417,12 @@ class MCPToolsComponent(Component):
raise ValueError(msg) from e
async def update_tools(
self, mode: str | None = None, command: str | None = None, url: str | None = None
self,
mode: str | None = None,
command: str | None = None,
url: str | None = None,
env: list[str] | None = None,
headers: dict[str, str] | None = None,
) -> list[StructuredTool]:
"""Connect to the MCP server and update available tools with improved error handling."""
try:
@ -342,21 +430,21 @@ class MCPToolsComponent(Component):
mode = self.mode
if command is None:
command = self.command
if env is None:
env = self.env
if url is None:
url = self.sse_url
if headers is None:
headers = self.headers_input
headers = self._process_headers(headers)
await self._validate_connection_params(mode, command, url)
if mode == "Stdio":
if not self.stdio_client.session:
self.tools = await self.stdio_client.connect_to_server(command)
self.tools = await self.stdio_client.connect_to_server(command, env)
elif mode == "SSE" and not self.sse_client.session:
try:
is_valid, _ = await self.sse_client.validate_url(url)
if not is_valid:
msg = f"Invalid SSE URL configuration: {url}. Please check the SSE URL and try again."
logger.error(msg)
return []
self.tools = await self.sse_client.connect_to_server(url, {})
self.tools = await self.sse_client.connect_to_server(url, headers)
except ValueError as e:
# URL validation error
logger.error(f"SSE URL validation error: {e}")

View file

@ -1,8 +1,18 @@
from typing import Literal, Union, get_args, get_origin
from types import UnionType
from typing import Any, Literal, Union, get_args, get_origin
from pydantic import BaseModel, Field, create_model
from langflow.inputs.inputs import BoolInput, DictInput, FieldTypes, FloatInput, InputTypes, IntInput, MessageTextInput
from langflow.inputs.inputs import (
BoolInput,
DictInput,
DropdownInput,
FieldTypes,
FloatInput,
InputTypes,
IntInput,
MessageTextInput,
)
from langflow.schema.dotdict import dotdict
_convert_field_type_to_type: dict[FieldTypes, type] = {
@ -32,53 +42,174 @@ _convert_type_to_field_type = {
}
def schema_to_langflow_inputs(schema: type[BaseModel]) -> list["InputTypes"]:
"""Given a Pydantic schema, convert its fields to Langflow input definitions."""
inputs = []
def flatten_schema(root_schema: dict[str, Any]) -> dict[str, Any]:
"""Flatten a JSON RPC style schema into a single level JSON Schema.
If the input schema is already flat (no $defs / $ref / nested objects or arrays)
the function simply returns the original i.e. a noop.
"""
defs = root_schema.get("$defs", {})
# --- Fast path: schema is already flat ---------------------------------
props = root_schema.get("properties", {})
if not defs and all("$ref" not in v and v.get("type") not in ("object", "array") for v in props.values()):
return root_schema
# -----------------------------------------------------------------------
flat_props: dict[str, dict[str, Any]] = {}
required_list: list[str] = []
def _resolve_if_ref(schema: dict[str, Any]) -> dict[str, Any]:
while "$ref" in schema:
ref_name = schema["$ref"].split("/")[-1]
schema = defs.get(ref_name, {})
return schema
def _walk(name: str, schema: dict[str, Any], *, inherited_req: bool) -> None:
schema = _resolve_if_ref(schema)
t = schema.get("type")
# ── objects ─────────────────────────────────────────────────────────
if t == "object":
req_here = set(schema.get("required", []))
for k, subschema in schema.get("properties", {}).items():
child_name = f"{name}.{k}" if name else k
_walk(name=child_name, schema=subschema, inherited_req=inherited_req and k in req_here)
return
# ── arrays (always recurse into the first item as “[0]”) ───────────
if t == "array":
items = schema.get("items", {})
_walk(name=f"{name}[0]", schema=items, inherited_req=inherited_req)
return
leaf: dict[str, Any] = {
k: v
for k, v in schema.items()
if k
in (
"type",
"description",
"pattern",
"format",
"enum",
"default",
"minLength",
"maxLength",
"minimum",
"maximum",
"exclusiveMinimum",
"exclusiveMaximum",
"additionalProperties",
"examples",
)
}
flat_props[name] = leaf
if inherited_req:
required_list.append(name)
# kick things off at the true root
root_required = set(root_schema.get("required", []))
for k, subschema in props.items():
_walk(k, subschema, inherited_req=k in root_required)
# build the flattened schema; keep any descriptive metadata
result: dict[str, Any] = {
"type": "object",
"properties": flat_props,
**{k: v for k, v in root_schema.items() if k not in ("properties", "$defs")},
}
if required_list:
result["required"] = required_list
return result
def schema_to_langflow_inputs(schema: type[BaseModel]) -> list[InputTypes]:
inputs: list[InputTypes] = []
for field_name, model_field in schema.model_fields.items():
# Start with the field's annotation type
field_type = model_field.annotation
ann = model_field.annotation
if isinstance(ann, UnionType):
# Extract non-None types from Union
non_none_types = [t for t in get_args(ann) if t is not type(None)]
if len(non_none_types) == 1:
ann = non_none_types[0]
is_list = False
options = None
# If the field is a list, record that and extract its inner type.
if get_origin(field_type) is list:
if get_origin(ann) is list:
is_list = True
field_type = get_args(field_type)[0]
ann = get_args(ann)[0]
# If the field type is a Literal, extract its allowed values.
if get_origin(field_type) is Literal:
options = list(get_args(field_type))
# Optionally, set field_type to the type of the literal values.
options: list[Any] | None = None
if get_origin(ann) is Literal:
options = list(get_args(ann))
if options:
field_type = type(options[0])
ann = type(options[0])
# Handle Union types (e.g., Optional fields)
if get_origin(field_type) is Union:
# Get the first non-None type from the Union
field_type = next(t for t in get_args(field_type) if t is not type(None))
if get_origin(ann) is Union:
non_none = [t for t in get_args(ann) if t is not type(None)]
if len(non_none) == 1:
ann = non_none[0]
# Convert the Python type to the Langflow field type using our reverse mapping.
# 1) Nested Pydantic model?
# if isinstance(ann, type) and issubclass(ann, BaseModel):
# nested = schema_to_langflow_inputs(ann)
# inputs.append(
# ObjectInput(
# display_name=model_field.title or field_name.replace("_", " ").title(),
# name=field_name,
# info=model_field.description or "",
# required=model_field.is_required(),
# is_list=is_list,
# inputs=nested,
# )
# )
# continue
# 2) Enumerated choices
if options is not None:
inputs.append(
DropdownInput(
display_name=model_field.title or field_name.replace("_", " ").title(),
name=field_name,
info=model_field.description or "",
required=model_field.is_required(),
is_list=is_list,
options=options,
)
)
continue
# 3) “Any” fallback → text
if ann is Any:
inputs.append(
MessageTextInput(
display_name=model_field.title or field_name.replace("_", " ").title(),
name=field_name,
info=model_field.description or "",
required=model_field.is_required(),
is_list=is_list,
)
)
continue
# 4) Primitive via your mapping
try:
langflow_field_type = _convert_type_to_field_type[field_type]
except KeyError as e:
msg = f"Unsupported field type: {field_type}"
raise TypeError(msg) from e
# Get metadata from the Pydantic Field.
title = model_field.title or field_name.replace("_", " ").title()
description = model_field.description or ""
required = model_field.is_required()
# Construct the Langflow input.
input_obj = langflow_field_type(
display_name=title,
name=field_name,
info=description,
required=required,
is_list=is_list,
lf_cls = _convert_type_to_field_type[ann]
except KeyError as err:
msg = f"Unsupported field type: {ann}"
raise TypeError(msg) from err
inputs.append(
lf_cls(
display_name=model_field.title or field_name.replace("_", " ").title(),
name=field_name,
info=model_field.description or "",
required=model_field.is_required(),
is_list=is_list,
)
)
inputs.append(input_obj)
return inputs

View file

@ -82,6 +82,8 @@ class TestMCPToolsComponent(ComponentTestBaseWithoutClient):
"sse_url": {"show": True, "value": "http://localhost:7860/api/v1/mcp/sse"},
"tool": {"options": [], "show": True},
"mode": {"value": "Stdio"},
"env": {"show": True, "value": []},
"headers_input": {"show": False, "value": []},
}
# Test switching to Stdio mode