fix: support mcp env vars and nested inputs (#7772)
* env vars for mcp * support nested mcp schemas * [autofix.ci] apply automated fixes * Update mcp_component.py * [autofix.ci] apply automated fixes * Update mcp_component.py * Update mcp_component.py * fix lint and mypy * fix tests --------- Co-authored-by: Edwin Jose <edwin.jose@datastax.com> Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com>
This commit is contained in:
parent
60de34074b
commit
c161a2e68d
4 changed files with 426 additions and 125 deletions
|
|
@ -11,10 +11,9 @@ from httpx import codes as httpx_codes
|
|||
from loguru import logger
|
||||
from mcp import ClientSession, StdioServerParameters, stdio_client
|
||||
from mcp.client.sse import sse_client
|
||||
from pydantic import Field, create_model
|
||||
from pydantic import BaseModel, Field, create_model
|
||||
from sqlmodel import select
|
||||
|
||||
from langflow.helpers.base_model import BaseModel
|
||||
from langflow.services.database.models import Flow
|
||||
|
||||
HTTP_ERROR_STATUS_CODE = httpx_codes.BAD_REQUEST # HTTP status code for client errors
|
||||
|
|
@ -78,44 +77,117 @@ async def get_flow_snake_case(flow_name: str, user_id: str, session) -> Flow | N
|
|||
|
||||
|
||||
def create_input_schema_from_json_schema(schema: dict[str, Any]) -> type[BaseModel]:
|
||||
"""Converts a JSON schema into a Pydantic model dynamically.
|
||||
"""Dynamically build a Pydantic model from a JSON schema (with $defs).
|
||||
|
||||
Fields not listed as required are wrapped in Optional[...] and default to None if not provided.
|
||||
|
||||
:param schema: The JSON schema as a dictionary.
|
||||
:return: A Pydantic model class.
|
||||
Non-required fields become Optional[...] with default=None.
|
||||
"""
|
||||
if schema.get("type") != "object":
|
||||
msg = "JSON schema must be of type 'object' at the root level."
|
||||
msg = "Root schema must be type 'object'"
|
||||
raise ValueError(msg)
|
||||
|
||||
fields = {}
|
||||
properties = schema.get("properties", {})
|
||||
required_fields = set(schema.get("required", []))
|
||||
defs: dict[str, dict[str, Any]] = schema.get("$defs", {})
|
||||
model_cache: dict[str, type[BaseModel]] = {}
|
||||
|
||||
for field_name, field_def in properties.items():
|
||||
# Determine the base type from the JSON schema type string.
|
||||
field_type_str = field_def.get("type", "str") # Defaults to string if not specified.
|
||||
base_type = {
|
||||
def resolve_ref(s: dict[str, Any] | None) -> dict[str, Any]:
|
||||
"""Follow a $ref chain until you land on a real subschema."""
|
||||
if s is None:
|
||||
return {}
|
||||
while "$ref" in s:
|
||||
ref_name = s["$ref"].split("/")[-1]
|
||||
s = defs.get(ref_name)
|
||||
if s is None:
|
||||
msg = f"Definition '{ref_name}' not found"
|
||||
raise ValueError(msg)
|
||||
return s
|
||||
|
||||
def parse_type(s: dict[str, Any] | None) -> Any:
|
||||
"""Map a JSON Schema subschema to a Python type (possibly nested)."""
|
||||
if s is None:
|
||||
return None
|
||||
s = resolve_ref(s)
|
||||
|
||||
if "anyOf" in s:
|
||||
subtypes = [parse_type(sub) for sub in s["anyOf"]]
|
||||
return tuple | subtypes
|
||||
|
||||
t = s.get("type", Any)
|
||||
if t == "array":
|
||||
item_schema = s.get("items", {})
|
||||
# schema_type: type[Any]
|
||||
schema_type: Any = parse_type(item_schema)
|
||||
|
||||
return list[schema_type]
|
||||
|
||||
if t == "object":
|
||||
# inline object not in $defs ⇒ anonymous nested model
|
||||
return _build_model(f"AnonModel{len(model_cache)}", s)
|
||||
|
||||
# primitive fallback
|
||||
return {
|
||||
"string": str,
|
||||
"str": str,
|
||||
"integer": int,
|
||||
"int": int,
|
||||
"number": float,
|
||||
"boolean": bool,
|
||||
"array": list,
|
||||
"object": dict,
|
||||
}.get(field_type_str, Any)
|
||||
"array": list,
|
||||
}.get(t, Any)
|
||||
# if result == "":
|
||||
# return Any
|
||||
# return result
|
||||
|
||||
field_metadata = {"description": field_def.get("description", "")}
|
||||
def _build_model(name: str, subschema: dict[str, Any]) -> type[BaseModel]:
|
||||
"""Create (or fetch) a BaseModel subclass for the given object schema."""
|
||||
# If this came via a named $ref, use that name
|
||||
if "$ref" in subschema:
|
||||
refname = subschema["$ref"].split("/")[-1]
|
||||
if refname in model_cache:
|
||||
return model_cache[refname]
|
||||
target = defs.get(refname)
|
||||
if not target:
|
||||
msg = f"Definition '{refname}' not found"
|
||||
raise ValueError(msg)
|
||||
cls = _build_model(refname, target)
|
||||
model_cache[refname] = cls
|
||||
return cls
|
||||
|
||||
# For non-required fields, wrap the type in Optional[...] and set a default value.
|
||||
if field_name not in required_fields:
|
||||
field_metadata["default"] = field_def.get("default", None)
|
||||
# Named anonymous or inline: avoid clashes by name
|
||||
if name in model_cache:
|
||||
return model_cache[name]
|
||||
|
||||
fields[field_name] = (base_type, Field(**field_metadata))
|
||||
props = subschema.get("properties", {})
|
||||
reqs = set(subschema.get("required", []))
|
||||
fields: dict[str, Any] = {}
|
||||
|
||||
return create_model("InputSchema", **fields)
|
||||
for prop_name, prop_schema in props.items():
|
||||
py_type = parse_type(prop_schema)
|
||||
is_required = prop_name in reqs
|
||||
if not is_required:
|
||||
py_type = py_type | None
|
||||
default = prop_schema.get("default", None)
|
||||
else:
|
||||
default = ... # required by Pydantic
|
||||
|
||||
fields[prop_name] = (py_type, Field(default, description=prop_schema.get("description")))
|
||||
|
||||
model_cls = create_model(name, **fields)
|
||||
model_cache[name] = model_cls
|
||||
return model_cls
|
||||
|
||||
# build the top - level “InputSchema” from the root properties
|
||||
top_props = schema.get("properties", {})
|
||||
top_reqs = set(schema.get("required", []))
|
||||
top_fields: dict[str, Any] = {}
|
||||
|
||||
for fname, fdef in top_props.items():
|
||||
py_type = parse_type(fdef)
|
||||
if fname not in top_reqs:
|
||||
py_type = py_type | None
|
||||
default = fdef.get("default", None)
|
||||
else:
|
||||
default = ...
|
||||
top_fields[fname] = (py_type, Field(default, description=fdef.get("description")))
|
||||
|
||||
return create_model("InputSchema", **top_fields)
|
||||
|
||||
|
||||
class MCPStdioClient:
|
||||
|
|
@ -123,12 +195,20 @@ class MCPStdioClient:
|
|||
self.session: ClientSession | None = None
|
||||
self.exit_stack = AsyncExitStack()
|
||||
|
||||
async def connect_to_server(self, command_str: str):
|
||||
async def connect_to_server(self, command_str: str, env: list[str] | None = None):
|
||||
env_dict: dict[str, str] = {}
|
||||
if env is None:
|
||||
env = []
|
||||
for var in env:
|
||||
if "=" not in var:
|
||||
msg = f"Invalid env var format: {var}. Must be in the format 'VAR_NAME=VAR_VALUE'"
|
||||
raise ValueError(msg)
|
||||
env_dict[var.split("=")[0]] = var.split("=")[1]
|
||||
command = command_str.split(" ")
|
||||
server_params = StdioServerParameters(
|
||||
command=command[0],
|
||||
args=command[1:],
|
||||
env={"DEBUG": "true", "PATH": os.environ["PATH"]},
|
||||
env={"DEBUG": "true", "PATH": os.environ["PATH"], **(env_dict or {})},
|
||||
)
|
||||
stdio_transport = await self.exit_stack.enter_async_context(stdio_client(server_params))
|
||||
self.stdio, self.write = stdio_transport
|
||||
|
|
|
|||
|
|
@ -1,11 +1,9 @@
|
|||
import os
|
||||
import re
|
||||
from typing import Any
|
||||
|
||||
import httpx
|
||||
from langchain_core.tools import StructuredTool
|
||||
|
||||
from langflow.base.mcp.util import (
|
||||
HTTP_ERROR_STATUS_CODE,
|
||||
MCPSseClient,
|
||||
MCPStdioClient,
|
||||
create_input_schema_from_json_schema,
|
||||
|
|
@ -13,14 +11,52 @@ from langflow.base.mcp.util import (
|
|||
create_tool_func,
|
||||
)
|
||||
from langflow.custom import Component
|
||||
from langflow.inputs import DropdownInput
|
||||
from langflow.inputs import DropdownInput, TableInput
|
||||
from langflow.inputs.inputs import InputTypes
|
||||
from langflow.io import MessageTextInput, MultilineInput, Output, TabInput
|
||||
from langflow.io.schema import schema_to_langflow_inputs
|
||||
from langflow.io.schema import flatten_schema, schema_to_langflow_inputs
|
||||
from langflow.logging import logger
|
||||
from langflow.schema import Message
|
||||
|
||||
|
||||
def maybe_unflatten_dict(flat: dict[str, Any]) -> dict[str, Any]:
|
||||
"""If any key looks nested (contains a dot or “[index]”), rebuild the.
|
||||
|
||||
full nested structure; otherwise return flat as is.
|
||||
"""
|
||||
# Quick check: do we have any nested keys?
|
||||
if not any(re.search(r"\.|\[\d+\]", key) for key in flat):
|
||||
return flat
|
||||
|
||||
# Otherwise, unflatten into dicts/lists
|
||||
nested: dict[str, Any] = {}
|
||||
array_re = re.compile(r"^(.+)\[(\d+)\]$")
|
||||
|
||||
for key, val in flat.items():
|
||||
parts = key.split(".")
|
||||
cur = nested
|
||||
for i, part in enumerate(parts):
|
||||
m = array_re.match(part)
|
||||
# Array segment?
|
||||
if m:
|
||||
name, idx = m.group(1), int(m.group(2))
|
||||
lst = cur.setdefault(name, [])
|
||||
# Ensure list is big enough
|
||||
while len(lst) <= idx:
|
||||
lst.append({})
|
||||
if i == len(parts) - 1:
|
||||
lst[idx] = val
|
||||
else:
|
||||
cur = lst[idx]
|
||||
# Normal object key
|
||||
elif i == len(parts) - 1:
|
||||
cur[part] = val
|
||||
else:
|
||||
cur = cur.setdefault(part, {})
|
||||
|
||||
return nested
|
||||
|
||||
|
||||
class MCPToolsComponent(Component):
|
||||
schema_inputs: list[InputTypes] = []
|
||||
stdio_client: MCPStdioClient = MCPStdioClient()
|
||||
|
|
@ -28,7 +64,18 @@ class MCPToolsComponent(Component):
|
|||
tools: list = []
|
||||
tool_names: list[str] = []
|
||||
_tool_cache: dict = {} # Cache for tool objects
|
||||
default_keys: list[str] = ["code", "_type", "mode", "command", "sse_url", "tool_placeholder", "tool_mode", "tool"]
|
||||
default_keys: list[str] = [
|
||||
"code",
|
||||
"_type",
|
||||
"mode",
|
||||
"command",
|
||||
"env",
|
||||
"sse_url",
|
||||
"tool_placeholder",
|
||||
"tool_mode",
|
||||
"tool",
|
||||
"headers_input",
|
||||
]
|
||||
|
||||
display_name = "MCP Server"
|
||||
description = "Connect to an MCP server and expose tools."
|
||||
|
|
@ -52,12 +99,45 @@ class MCPToolsComponent(Component):
|
|||
show=True,
|
||||
refresh_button=True,
|
||||
),
|
||||
MessageTextInput(
|
||||
name="env",
|
||||
display_name="Env",
|
||||
info="Env vars to include in mcp stdio connection (i.e. DEBUG=true)",
|
||||
value="",
|
||||
is_list=True,
|
||||
show=True,
|
||||
tool_mode=False,
|
||||
),
|
||||
MultilineInput(
|
||||
name="sse_url",
|
||||
display_name="MCP SSE URL",
|
||||
info="URL for MCP SSE connection",
|
||||
show=False,
|
||||
refresh_button=True,
|
||||
value="MCP_SSE",
|
||||
real_time_refresh=True,
|
||||
),
|
||||
TableInput(
|
||||
name="headers_input",
|
||||
display_name="Headers",
|
||||
info="Headers to include in the tool",
|
||||
show=False,
|
||||
real_time_refresh=True,
|
||||
table_schema=[
|
||||
{
|
||||
"name": "key",
|
||||
"display_name": "Header",
|
||||
"type": "str",
|
||||
"description": "Header name",
|
||||
},
|
||||
{
|
||||
"name": "value",
|
||||
"display_name": "Value",
|
||||
"type": "str",
|
||||
"description": "Header value",
|
||||
},
|
||||
],
|
||||
value=[],
|
||||
),
|
||||
DropdownInput(
|
||||
name="tool",
|
||||
|
|
@ -83,21 +163,6 @@ class MCPToolsComponent(Component):
|
|||
Output(display_name="Response", name="response", method="build_output"),
|
||||
]
|
||||
|
||||
async def find_langflow_instance(self) -> tuple[bool, int | None, str]:
|
||||
"""Find Langflow instance by checking env variable first, then scanning common ports."""
|
||||
# First check environment variable
|
||||
env_port = os.getenv("LANGFLOW_PORT")
|
||||
port = int(env_port) if env_port else 7860
|
||||
try:
|
||||
url = f"http://localhost:{port}/api/v1/mcp/sse"
|
||||
async with httpx.AsyncClient() as client:
|
||||
response = await client.head(url, timeout=2.0)
|
||||
if response.status_code < HTTP_ERROR_STATUS_CODE:
|
||||
return True, port, f"Langflow instance found at configured port {port}"
|
||||
except (ValueError, httpx.TimeoutException, httpx.NetworkError, httpx.HTTPError):
|
||||
logger.warning(f"Could not connect to Langflow at configured port {env_port}")
|
||||
return False, None, "No Langflow instance found on configured port or common ports"
|
||||
|
||||
async def _validate_connection_params(self, mode: str, command: str | None = None, url: str | None = None) -> None:
|
||||
"""Validate connection parameters based on mode."""
|
||||
if mode not in ["Stdio", "SSE"]:
|
||||
|
|
@ -111,6 +176,33 @@ class MCPToolsComponent(Component):
|
|||
msg = "URL is required for SSE mode"
|
||||
raise ValueError(msg)
|
||||
|
||||
def _process_headers(self, headers: Any) -> dict:
|
||||
"""Process the headers input into a valid dictionary.
|
||||
|
||||
Args:
|
||||
headers: The headers to process, can be dict, str, or list
|
||||
Returns:
|
||||
Processed dictionary
|
||||
"""
|
||||
if headers is None:
|
||||
return {}
|
||||
if isinstance(headers, dict):
|
||||
return headers
|
||||
if isinstance(headers, list):
|
||||
processed_headers = {}
|
||||
try:
|
||||
for item in headers:
|
||||
if not self._is_valid_key_value_item(item):
|
||||
continue
|
||||
key = item["key"]
|
||||
value = item["value"]
|
||||
processed_headers[key] = value
|
||||
except (KeyError, TypeError, ValueError) as e:
|
||||
self.log(f"Failed to process headers list: {e}")
|
||||
return {} # Return empty dictionary instead of None
|
||||
return processed_headers
|
||||
return {}
|
||||
|
||||
async def _validate_schema_inputs(self, tool_obj) -> list[InputTypes]:
|
||||
"""Validate and process schema inputs for a tool."""
|
||||
try:
|
||||
|
|
@ -118,7 +210,8 @@ class MCPToolsComponent(Component):
|
|||
msg = "Invalid tool object or missing input schema"
|
||||
raise ValueError(msg)
|
||||
|
||||
input_schema = create_input_schema_from_json_schema(tool_obj.inputSchema)
|
||||
flat_schema = flatten_schema(tool_obj.inputSchema)
|
||||
input_schema = create_input_schema_from_json_schema(flat_schema)
|
||||
if not input_schema:
|
||||
msg = f"Empty input schema for tool '{tool_obj.name}'"
|
||||
raise ValueError(msg)
|
||||
|
|
@ -143,43 +236,24 @@ class MCPToolsComponent(Component):
|
|||
self.remove_non_default_keys(build_config)
|
||||
if field_value == "Stdio":
|
||||
build_config["command"]["show"] = True
|
||||
build_config["env"]["show"] = True
|
||||
build_config["headers_input"]["show"] = False
|
||||
build_config["sse_url"]["show"] = False
|
||||
elif field_value == "SSE":
|
||||
build_config["command"]["show"] = False
|
||||
build_config["env"]["show"] = False
|
||||
build_config["sse_url"]["show"] = True
|
||||
build_config["sse_url"]["value"] = "MCP_SSE"
|
||||
build_config["headers_input"]["show"] = True
|
||||
return build_config
|
||||
if field_name in ("command", "sse_url", "mode"):
|
||||
try:
|
||||
# If SSE mode and localhost URL is not valid, try to find correct port
|
||||
if build_config["mode"]["value"] == "SSE" and (
|
||||
"localhost" in str(build_config["sse_url"]["value"])
|
||||
or "127.0.0.1" in str(build_config["sse_url"]["value"])
|
||||
):
|
||||
is_valid, _ = await self.sse_client.validate_url(build_config["sse_url"]["value"])
|
||||
if not is_valid:
|
||||
found, port, message = await self.find_langflow_instance()
|
||||
if found:
|
||||
new_url = f"http://localhost:{port}/api/v1/mcp/sse"
|
||||
logger.info(f"Original URL {build_config['sse_url']['value']} not valid. {message}")
|
||||
build_config["sse_url"]["value"] = new_url
|
||||
elif build_config["mode"]["value"] == "SSE":
|
||||
if len(build_config["sse_url"]["value"]) > 0:
|
||||
is_valid, _ = await self.sse_client.validate_url(build_config["sse_url"]["value"])
|
||||
if not is_valid:
|
||||
msg = (
|
||||
f"Invalid SSE URL configuration: {build_config['sse_url']['value']}. "
|
||||
"Please check the SSE URL and try again."
|
||||
)
|
||||
raise ValueError(msg)
|
||||
else:
|
||||
build_config["tool"]["options"] = []
|
||||
return build_config
|
||||
|
||||
await self.update_tools(
|
||||
mode=build_config["mode"]["value"],
|
||||
command=build_config["command"]["value"],
|
||||
url=build_config["sse_url"]["value"],
|
||||
env=build_config["env"]["value"],
|
||||
headers=build_config["headers_input"]["value"],
|
||||
)
|
||||
if "tool" in build_config:
|
||||
build_config["tool"]["options"] = self.tool_names
|
||||
|
|
@ -195,6 +269,8 @@ class MCPToolsComponent(Component):
|
|||
mode=build_config["mode"]["value"],
|
||||
command=build_config["command"]["value"],
|
||||
url=build_config["sse_url"]["value"],
|
||||
env=build_config["env"]["value"],
|
||||
headers=build_config["headers_input"]["value"],
|
||||
)
|
||||
if self.tool is None:
|
||||
return build_config
|
||||
|
|
@ -229,8 +305,10 @@ class MCPToolsComponent(Component):
|
|||
if not tool or not hasattr(tool, "name"):
|
||||
continue
|
||||
try:
|
||||
input_schema = schema_to_langflow_inputs(create_input_schema_from_json_schema(tool.inputSchema))
|
||||
inputs[tool.name] = input_schema
|
||||
flat_schema = flatten_schema(tool.inputSchema)
|
||||
input_schema = create_input_schema_from_json_schema(flat_schema)
|
||||
langflow_inputs = schema_to_langflow_inputs(input_schema)
|
||||
inputs[tool.name] = langflow_inputs
|
||||
except (AttributeError, ValueError, TypeError, KeyError) as e:
|
||||
msg = f"Error getting inputs for tool {getattr(tool, 'name', 'unknown')}: {e!s}"
|
||||
logger.exception(msg)
|
||||
|
|
@ -262,6 +340,8 @@ class MCPToolsComponent(Component):
|
|||
mode=build_config["mode"]["value"],
|
||||
command=build_config["command"]["value"],
|
||||
url=build_config["sse_url"]["value"],
|
||||
env=build_config["env"]["value"],
|
||||
headers=build_config["headers_input"]["value"],
|
||||
)
|
||||
|
||||
if not tool_name:
|
||||
|
|
@ -302,7 +382,6 @@ class MCPToolsComponent(Component):
|
|||
msg = f"Error processing schema input {schema_input}: {e!s}"
|
||||
logger.exception(msg)
|
||||
continue
|
||||
|
||||
except ValueError as e:
|
||||
msg = f"Schema validation error for tool {tool_name}: {e!s}"
|
||||
logger.exception(msg)
|
||||
|
|
@ -325,7 +404,11 @@ class MCPToolsComponent(Component):
|
|||
value = getattr(self, arg.name, None)
|
||||
if value:
|
||||
kwargs[arg.name] = value
|
||||
output = await exec_tool.coroutine(**kwargs)
|
||||
|
||||
unflattened_kwargs = maybe_unflatten_dict(kwargs)
|
||||
|
||||
output = await exec_tool.coroutine(**unflattened_kwargs)
|
||||
|
||||
return Message(text=output.content[len(output.content) - 1].text)
|
||||
return Message(text="You must select a tool", error=True)
|
||||
except Exception as e:
|
||||
|
|
@ -334,7 +417,12 @@ class MCPToolsComponent(Component):
|
|||
raise ValueError(msg) from e
|
||||
|
||||
async def update_tools(
|
||||
self, mode: str | None = None, command: str | None = None, url: str | None = None
|
||||
self,
|
||||
mode: str | None = None,
|
||||
command: str | None = None,
|
||||
url: str | None = None,
|
||||
env: list[str] | None = None,
|
||||
headers: dict[str, str] | None = None,
|
||||
) -> list[StructuredTool]:
|
||||
"""Connect to the MCP server and update available tools with improved error handling."""
|
||||
try:
|
||||
|
|
@ -342,21 +430,21 @@ class MCPToolsComponent(Component):
|
|||
mode = self.mode
|
||||
if command is None:
|
||||
command = self.command
|
||||
if env is None:
|
||||
env = self.env
|
||||
if url is None:
|
||||
url = self.sse_url
|
||||
if headers is None:
|
||||
headers = self.headers_input
|
||||
headers = self._process_headers(headers)
|
||||
await self._validate_connection_params(mode, command, url)
|
||||
|
||||
if mode == "Stdio":
|
||||
if not self.stdio_client.session:
|
||||
self.tools = await self.stdio_client.connect_to_server(command)
|
||||
self.tools = await self.stdio_client.connect_to_server(command, env)
|
||||
elif mode == "SSE" and not self.sse_client.session:
|
||||
try:
|
||||
is_valid, _ = await self.sse_client.validate_url(url)
|
||||
if not is_valid:
|
||||
msg = f"Invalid SSE URL configuration: {url}. Please check the SSE URL and try again."
|
||||
logger.error(msg)
|
||||
return []
|
||||
self.tools = await self.sse_client.connect_to_server(url, {})
|
||||
self.tools = await self.sse_client.connect_to_server(url, headers)
|
||||
except ValueError as e:
|
||||
# URL validation error
|
||||
logger.error(f"SSE URL validation error: {e}")
|
||||
|
|
|
|||
|
|
@ -1,8 +1,18 @@
|
|||
from typing import Literal, Union, get_args, get_origin
|
||||
from types import UnionType
|
||||
from typing import Any, Literal, Union, get_args, get_origin
|
||||
|
||||
from pydantic import BaseModel, Field, create_model
|
||||
|
||||
from langflow.inputs.inputs import BoolInput, DictInput, FieldTypes, FloatInput, InputTypes, IntInput, MessageTextInput
|
||||
from langflow.inputs.inputs import (
|
||||
BoolInput,
|
||||
DictInput,
|
||||
DropdownInput,
|
||||
FieldTypes,
|
||||
FloatInput,
|
||||
InputTypes,
|
||||
IntInput,
|
||||
MessageTextInput,
|
||||
)
|
||||
from langflow.schema.dotdict import dotdict
|
||||
|
||||
_convert_field_type_to_type: dict[FieldTypes, type] = {
|
||||
|
|
@ -32,53 +42,174 @@ _convert_type_to_field_type = {
|
|||
}
|
||||
|
||||
|
||||
def schema_to_langflow_inputs(schema: type[BaseModel]) -> list["InputTypes"]:
|
||||
"""Given a Pydantic schema, convert its fields to Langflow input definitions."""
|
||||
inputs = []
|
||||
def flatten_schema(root_schema: dict[str, Any]) -> dict[str, Any]:
|
||||
"""Flatten a JSON RPC style schema into a single level JSON Schema.
|
||||
|
||||
If the input schema is already flat (no $defs / $ref / nested objects or arrays)
|
||||
the function simply returns the original i.e. a noop.
|
||||
"""
|
||||
defs = root_schema.get("$defs", {})
|
||||
|
||||
# --- Fast path: schema is already flat ---------------------------------
|
||||
props = root_schema.get("properties", {})
|
||||
if not defs and all("$ref" not in v and v.get("type") not in ("object", "array") for v in props.values()):
|
||||
return root_schema
|
||||
# -----------------------------------------------------------------------
|
||||
|
||||
flat_props: dict[str, dict[str, Any]] = {}
|
||||
required_list: list[str] = []
|
||||
|
||||
def _resolve_if_ref(schema: dict[str, Any]) -> dict[str, Any]:
|
||||
while "$ref" in schema:
|
||||
ref_name = schema["$ref"].split("/")[-1]
|
||||
schema = defs.get(ref_name, {})
|
||||
return schema
|
||||
|
||||
def _walk(name: str, schema: dict[str, Any], *, inherited_req: bool) -> None:
|
||||
schema = _resolve_if_ref(schema)
|
||||
t = schema.get("type")
|
||||
|
||||
# ── objects ─────────────────────────────────────────────────────────
|
||||
if t == "object":
|
||||
req_here = set(schema.get("required", []))
|
||||
for k, subschema in schema.get("properties", {}).items():
|
||||
child_name = f"{name}.{k}" if name else k
|
||||
_walk(name=child_name, schema=subschema, inherited_req=inherited_req and k in req_here)
|
||||
return
|
||||
|
||||
# ── arrays (always recurse into the first item as “[0]”) ───────────
|
||||
if t == "array":
|
||||
items = schema.get("items", {})
|
||||
_walk(name=f"{name}[0]", schema=items, inherited_req=inherited_req)
|
||||
return
|
||||
|
||||
leaf: dict[str, Any] = {
|
||||
k: v
|
||||
for k, v in schema.items()
|
||||
if k
|
||||
in (
|
||||
"type",
|
||||
"description",
|
||||
"pattern",
|
||||
"format",
|
||||
"enum",
|
||||
"default",
|
||||
"minLength",
|
||||
"maxLength",
|
||||
"minimum",
|
||||
"maximum",
|
||||
"exclusiveMinimum",
|
||||
"exclusiveMaximum",
|
||||
"additionalProperties",
|
||||
"examples",
|
||||
)
|
||||
}
|
||||
flat_props[name] = leaf
|
||||
if inherited_req:
|
||||
required_list.append(name)
|
||||
|
||||
# kick things off at the true root
|
||||
root_required = set(root_schema.get("required", []))
|
||||
for k, subschema in props.items():
|
||||
_walk(k, subschema, inherited_req=k in root_required)
|
||||
|
||||
# build the flattened schema; keep any descriptive metadata
|
||||
result: dict[str, Any] = {
|
||||
"type": "object",
|
||||
"properties": flat_props,
|
||||
**{k: v for k, v in root_schema.items() if k not in ("properties", "$defs")},
|
||||
}
|
||||
if required_list:
|
||||
result["required"] = required_list
|
||||
return result
|
||||
|
||||
|
||||
def schema_to_langflow_inputs(schema: type[BaseModel]) -> list[InputTypes]:
|
||||
inputs: list[InputTypes] = []
|
||||
|
||||
for field_name, model_field in schema.model_fields.items():
|
||||
# Start with the field's annotation type
|
||||
field_type = model_field.annotation
|
||||
ann = model_field.annotation
|
||||
if isinstance(ann, UnionType):
|
||||
# Extract non-None types from Union
|
||||
non_none_types = [t for t in get_args(ann) if t is not type(None)]
|
||||
if len(non_none_types) == 1:
|
||||
ann = non_none_types[0]
|
||||
|
||||
is_list = False
|
||||
options = None
|
||||
|
||||
# If the field is a list, record that and extract its inner type.
|
||||
if get_origin(field_type) is list:
|
||||
if get_origin(ann) is list:
|
||||
is_list = True
|
||||
field_type = get_args(field_type)[0]
|
||||
ann = get_args(ann)[0]
|
||||
|
||||
# If the field type is a Literal, extract its allowed values.
|
||||
if get_origin(field_type) is Literal:
|
||||
options = list(get_args(field_type))
|
||||
# Optionally, set field_type to the type of the literal values.
|
||||
options: list[Any] | None = None
|
||||
if get_origin(ann) is Literal:
|
||||
options = list(get_args(ann))
|
||||
if options:
|
||||
field_type = type(options[0])
|
||||
ann = type(options[0])
|
||||
|
||||
# Handle Union types (e.g., Optional fields)
|
||||
if get_origin(field_type) is Union:
|
||||
# Get the first non-None type from the Union
|
||||
field_type = next(t for t in get_args(field_type) if t is not type(None))
|
||||
if get_origin(ann) is Union:
|
||||
non_none = [t for t in get_args(ann) if t is not type(None)]
|
||||
if len(non_none) == 1:
|
||||
ann = non_none[0]
|
||||
|
||||
# Convert the Python type to the Langflow field type using our reverse mapping.
|
||||
# 1) Nested Pydantic model?
|
||||
# if isinstance(ann, type) and issubclass(ann, BaseModel):
|
||||
# nested = schema_to_langflow_inputs(ann)
|
||||
# inputs.append(
|
||||
# ObjectInput(
|
||||
# display_name=model_field.title or field_name.replace("_", " ").title(),
|
||||
# name=field_name,
|
||||
# info=model_field.description or "",
|
||||
# required=model_field.is_required(),
|
||||
# is_list=is_list,
|
||||
# inputs=nested,
|
||||
# )
|
||||
# )
|
||||
# continue
|
||||
|
||||
# 2) Enumerated choices
|
||||
if options is not None:
|
||||
inputs.append(
|
||||
DropdownInput(
|
||||
display_name=model_field.title or field_name.replace("_", " ").title(),
|
||||
name=field_name,
|
||||
info=model_field.description or "",
|
||||
required=model_field.is_required(),
|
||||
is_list=is_list,
|
||||
options=options,
|
||||
)
|
||||
)
|
||||
continue
|
||||
|
||||
# 3) “Any” fallback → text
|
||||
if ann is Any:
|
||||
inputs.append(
|
||||
MessageTextInput(
|
||||
display_name=model_field.title or field_name.replace("_", " ").title(),
|
||||
name=field_name,
|
||||
info=model_field.description or "",
|
||||
required=model_field.is_required(),
|
||||
is_list=is_list,
|
||||
)
|
||||
)
|
||||
continue
|
||||
|
||||
# 4) Primitive via your mapping
|
||||
try:
|
||||
langflow_field_type = _convert_type_to_field_type[field_type]
|
||||
except KeyError as e:
|
||||
msg = f"Unsupported field type: {field_type}"
|
||||
raise TypeError(msg) from e
|
||||
|
||||
# Get metadata from the Pydantic Field.
|
||||
title = model_field.title or field_name.replace("_", " ").title()
|
||||
description = model_field.description or ""
|
||||
required = model_field.is_required()
|
||||
|
||||
# Construct the Langflow input.
|
||||
input_obj = langflow_field_type(
|
||||
display_name=title,
|
||||
name=field_name,
|
||||
info=description,
|
||||
required=required,
|
||||
is_list=is_list,
|
||||
lf_cls = _convert_type_to_field_type[ann]
|
||||
except KeyError as err:
|
||||
msg = f"Unsupported field type: {ann}"
|
||||
raise TypeError(msg) from err
|
||||
inputs.append(
|
||||
lf_cls(
|
||||
display_name=model_field.title or field_name.replace("_", " ").title(),
|
||||
name=field_name,
|
||||
info=model_field.description or "",
|
||||
required=model_field.is_required(),
|
||||
is_list=is_list,
|
||||
)
|
||||
)
|
||||
inputs.append(input_obj)
|
||||
|
||||
return inputs
|
||||
|
||||
|
||||
|
|
|
|||
|
|
@ -82,6 +82,8 @@ class TestMCPToolsComponent(ComponentTestBaseWithoutClient):
|
|||
"sse_url": {"show": True, "value": "http://localhost:7860/api/v1/mcp/sse"},
|
||||
"tool": {"options": [], "show": True},
|
||||
"mode": {"value": "Stdio"},
|
||||
"env": {"show": True, "value": []},
|
||||
"headers_input": {"show": False, "value": []},
|
||||
}
|
||||
|
||||
# Test switching to Stdio mode
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue