feat: add anthropic mcp endpoint (#5148)

* mcp WIP

* [autofix.ci] apply automated fixes

* logging and flow user check

* mcp stdio client component

* handle disconnect better

* initialization

* session fix and type fix

* [autofix.ci] apply automated fixes

* defensive against mcp server bugs

* [autofix.ci] apply automated fixes

* notifications and sse component

* enabled flags and resource support

* remove unneeded print

* extract json schema util

* [autofix.ci] apply automated fixes

* ruff

* fix tools [] bug and db async session api change

* Tool instead of StructuredTool

* ruff fixes

* ruff

* validation optimization

* fix frontend test

* another playwright fix

* Update src/frontend/tests/extended/features/notifications.spec.ts

Co-authored-by: Gabriel Luiz Freitas Almeida <gabriel@langflow.org>

* mcp component descriptions

* mypy fixes

* fix setup_database_url test

---------

Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com>
Co-authored-by: Gabriel Luiz Freitas Almeida <gabriel@langflow.org>
This commit is contained in:
Sebastián Estévez 2025-01-02 23:45:54 -05:00 committed by GitHub
commit 63d649b0f4
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
14 changed files with 1229 additions and 525 deletions

View file

@ -113,8 +113,10 @@ dependencies = [
"aiofile>=3.9.0,<4.0.0",
"sseclient-py==1.8.0",
"arize-phoenix-otel>=0.6.1",
"openinference-instrumentation-langchain==0.1.29",
"openinference-instrumentation-langchain>=0.1.29",
"crewai~=0.86.0",
"mcp>=0.9.1",
"uv>=0.5.7",
"ag2",
"pydantic-ai>=0.0.12",
]

View file

@ -9,6 +9,7 @@ from langflow.api.v1 import (
flows_router,
folders_router,
login_router,
mcp_router,
monitor_router,
starter_projects_router,
store_router,
@ -16,6 +17,7 @@ from langflow.api.v1 import (
validate_router,
variables_router,
)
from langflow.services.deps import get_settings_service
router = APIRouter(
prefix="/api/v1",
@ -33,3 +35,6 @@ router.include_router(files_router)
router.include_router(monitor_router)
router.include_router(folders_router)
router.include_router(starter_projects_router)
if get_settings_service().settings.mcp_server_enabled:
router.include_router(mcp_router)

View file

@ -5,6 +5,7 @@ from langflow.api.v1.files import router as files_router
from langflow.api.v1.flows import router as flows_router
from langflow.api.v1.folders import router as folders_router
from langflow.api.v1.login import router as login_router
from langflow.api.v1.mcp import router as mcp_router
from langflow.api.v1.monitor import router as monitor_router
from langflow.api.v1.starter_projects import router as starter_projects_router
from langflow.api.v1.store import router as store_router
@ -20,6 +21,7 @@ __all__ = [
"flows_router",
"folders_router",
"login_router",
"mcp_router",
"monitor_router",
"starter_projects_router",
"store_router",

View file

@ -0,0 +1,343 @@
import asyncio
import base64
import json
import logging
import traceback
from contextlib import suppress
from contextvars import ContextVar
from typing import Annotated
from urllib.parse import quote, unquote, urlparse
from uuid import UUID, uuid4
import pydantic
from anyio import BrokenResourceError
from fastapi import APIRouter, Depends, Request
from fastapi.responses import StreamingResponse
from mcp import types
from mcp.server import NotificationOptions, Server
from mcp.server.sse import SseServerTransport
from sqlmodel import select
from starlette.background import BackgroundTasks
from langflow.api.v1.chat import build_flow
from langflow.api.v1.schemas import InputValueRequest
from langflow.helpers.flow import json_schema_from_flow
from langflow.services.auth.utils import get_current_active_user
from langflow.services.database.models import Flow, User
from langflow.services.deps import get_db_service, get_session, get_settings_service, get_storage_service
from langflow.services.storage.utils import build_content_type_from_extension
# Module-level wiring for the Langflow MCP server endpoint.
logger = logging.getLogger(__name__)

# Debug-logging scaffold, deliberately disabled (`if False:`). Flip the
# condition to True during local development to get verbose MCP traces.
# NOTE(review): consider removing this or driving it from settings instead
# of a hard-coded constant.
if False:
    logger.setLevel(logging.DEBUG)
    if not logger.handlers:
        handler = logging.StreamHandler()
        handler.setLevel(logging.DEBUG)
        formatter = logging.Formatter("[%(asctime)s][%(levelname)s] %(message)s")
        handler.setFormatter(formatter)
        logger.addHandler(handler)
        # Enable debug logging for MCP package
        mcp_logger = logging.getLogger("mcp")
        mcp_logger.setLevel(logging.DEBUG)
        if not mcp_logger.handlers:
            mcp_logger.addHandler(handler)
    logger.debug("MCP module loaded - debug logging enabled")

# Read once at import time; handlers below consult this flag before sending
# MCP progress notifications.
enable_progress_notifications = get_settings_service().settings.mcp_server_enable_progress_notifications

router = APIRouter(prefix="/mcp", tags=["mcp"])

server = Server("langflow-mcp-server")

# Create a context variable to store the current user
current_user_ctx: ContextVar[User] = ContextVar("current_user_ctx")

# Define constants
MAX_RETRIES = 2
@server.list_prompts()
async def handle_list_prompts():
    """Return the MCP prompts available on this server (currently none)."""
    prompts: list = []
    return prompts
@server.list_resources()
async def handle_list_resources():
    """List every file stored under each flow as an MCP resource.

    Builds one resource per stored file, with a URI pointing at the Langflow
    files API. Flows whose storage folder does not exist are skipped.
    """
    resources = []
    try:
        session = await anext(get_session())
        storage_service = get_storage_service()
        settings_service = get_settings_service()

        # Build full URL from settings.
        # BUG FIX: the settings attribute is "host", not "holst" -- the typo
        # made the getattr fallback fire every time, so the base URL was
        # always "localhost" regardless of configuration.
        host = getattr(settings_service.settings, "host", "localhost")
        port = getattr(settings_service.settings, "port", 3000)

        base_url = f"http://{host}:{port}".rstrip("/")

        flows = (await session.exec(select(Flow))).all()
        for flow in flows:
            if flow.id:
                try:
                    files = await storage_service.list_files(flow_id=str(flow.id))
                    for file_name in files:
                        # URL encode the filename so it is safe inside the URI
                        safe_filename = quote(file_name)
                        resource = types.Resource(
                            uri=f"{base_url}/api/v1/files/{flow.id}/{safe_filename}",
                            name=file_name,
                            description=f"File in flow: {flow.name}",
                            mimeType=build_content_type_from_extension(file_name),
                        )
                        resources.append(resource)
                except FileNotFoundError as e:
                    # A flow without a storage folder is expected; skip it.
                    msg = f"Error listing files for flow {flow.id}: {e}"
                    logger.debug(msg)
                    continue
    except Exception as e:
        msg = f"Error in listing resources: {e!s}"
        logger.exception(msg)
        trace = traceback.format_exc()
        logger.exception(trace)
        raise
    return resources
@server.read_resource()
async def handle_read_resource(uri: str) -> bytes:
    """Handle resource read requests.

    Parses a file URI of the form ``.../api/v1/files/{flow_id}/{filename}``,
    fetches the file from storage, and returns it base64-encoded.

    Raises:
        ValueError: If the URI has too few path segments or the file is missing.
    """
    try:
        # Parse the URI properly
        parsed_uri = urlparse(str(uri))
        # Path will be like /api/v1/files/{flow_id}/{filename}
        path_parts = parsed_uri.path.split("/")
        # Remove empty strings from split
        path_parts = [p for p in path_parts if p]

        # The flow_id and filename should be the last two parts
        two = 2
        if len(path_parts) < two:
            msg = f"Invalid URI format: {uri}"
            raise ValueError(msg)

        flow_id = path_parts[-2]
        filename = unquote(path_parts[-1])  # URL decode the filename

        storage_service = get_storage_service()

        # Read the file content
        content = await storage_service.get_file(flow_id=flow_id, file_name=filename)
        if not content:
            # BUG FIX: the error message had lost its filename placeholder;
            # include the missing file's name so the error is actionable.
            msg = f"File {filename} not found in flow {flow_id}"
            raise ValueError(msg)

        # Ensure content is base64 encoded
        if isinstance(content, str):
            content = content.encode()
        return base64.b64encode(content)
    except Exception as e:
        msg = f"Error reading resource {uri}: {e!s}"
        logger.exception(msg)
        trace = traceback.format_exc()
        logger.exception(trace)
        raise
@server.list_tools()
async def handle_list_tools():
    """Expose every user-owned flow as an MCP tool."""
    tools = []
    try:
        session = await anext(get_session())
        flows = (await session.exec(select(Flow))).all()

        for flow in flows:
            # Flows without an owner are not exposed as tools.
            if flow.user_id is None:
                continue

            description = (
                f"{flow.name}: {flow.description}"
                if flow.description
                else f"Tool generated from flow: {flow.name}"
            )
            tools.append(
                types.Tool(
                    name=str(flow.id),  # Use flow.id instead of name
                    description=description,
                    inputSchema=json_schema_from_flow(flow),
                )
            )
    except Exception as e:
        msg = f"Error in listing tools: {e!s}"
        logger.exception(msg)
        trace = traceback.format_exc()
        logger.exception(trace)
        raise
    return tools
@server.call_tool()
async def handle_call_tool(name: str, arguments: dict) -> list[types.TextContent]:
    """Handle tool execution requests.

    ``name`` is a flow UUID (see handle_list_tools). The flow is executed via
    build_flow and the text of every "end_vertex" event is collected as the
    tool result. While the flow runs, a background task ticks MCP progress
    notifications up to 90%, then 100% on completion.
    """
    try:
        session = await anext(get_session())
        background_tasks = BackgroundTasks()
        # The authenticated user is stashed in a ContextVar by the SSE endpoint.
        current_user = current_user_ctx.get()
        flow = (await session.exec(select(Flow).where(Flow.id == UUID(name)))).first()

        if not flow:
            msg = f"Flow with id '{name}' not found"
            raise ValueError(msg)

        # Process inputs
        processed_inputs = dict(arguments)

        # Initial progress notification
        if enable_progress_notifications and (progress_token := server.request_context.meta.progressToken):
            await server.request_context.session.send_progress_notification(
                progress_token=progress_token, progress=0.0, total=1.0
            )

        conversation_id = str(uuid4())
        input_request = InputValueRequest(
            input_value=processed_inputs.get("input_value", ""), components=[], type="chat", session=conversation_id
        )

        async def send_progress_updates():
            # NOTE(review): this closure reads `progress_token` bound by the
            # walrus above; that name is only assigned when notifications are
            # enabled, which the guard below also checks -- confirm the two
            # conditions stay in sync.
            if not (enable_progress_notifications and server.request_context.meta.progressToken):
                return
            try:
                progress = 0.0
                # Tick progress toward (but never past) 90% once per second
                # until cancelled by the caller's finally block.
                while True:
                    await server.request_context.session.send_progress_notification(
                        progress_token=progress_token, progress=min(0.9, progress), total=1.0
                    )
                    progress += 0.1
                    await asyncio.sleep(1.0)
            except asyncio.CancelledError:
                # Send final 100% progress
                if enable_progress_notifications:
                    await server.request_context.session.send_progress_notification(
                        progress_token=progress_token, progress=1.0, total=1.0
                    )
                raise

        db_service = get_db_service()
        collected_results = []

        async with db_service.with_session() as async_session:
            try:
                progress_task = asyncio.create_task(send_progress_updates())

                try:
                    response = await build_flow(
                        flow_id=UUID(name),
                        inputs=input_request,
                        background_tasks=background_tasks,
                        current_user=current_user,
                        session=async_session,
                    )

                    # build_flow streams newline-delimited JSON events; pull the
                    # message text out of each "end_vertex" event.
                    async for line in response.body_iterator:
                        if not line:
                            continue
                        try:
                            event_data = json.loads(line)
                            if event_data.get("event") == "end_vertex":
                                message = (
                                    event_data.get("data", {})
                                    .get("build_data", {})
                                    .get("data", {})
                                    .get("results", {})
                                    .get("message", {})
                                    .get("text", "")
                                )
                                if message:
                                    collected_results.append(types.TextContent(type="text", text=str(message)))
                        except json.JSONDecodeError:
                            # Tolerate malformed event lines rather than failing
                            # the whole tool call.
                            msg = f"Failed to parse event data: {line}"
                            logger.warning(msg)
                            continue

                    return collected_results
                finally:
                    # Stop the progress ticker; it sends a final 100% update on
                    # cancellation, which we deliberately suppress here.
                    progress_task.cancel()
                    with suppress(asyncio.CancelledError):
                        await progress_task
            except Exception as e:
                msg = f"Error in async session: {e}"
                logger.exception(msg)
                raise
    except Exception as e:
        context = server.request_context
        # Send error progress if there's an exception
        if enable_progress_notifications and (progress_token := context.meta.progressToken):
            await server.request_context.session.send_progress_notification(
                progress_token=progress_token, progress=1.0, total=1.0
            )
        msg = f"Error executing tool {name}: {e!s}"
        logger.exception(msg)
        trace = traceback.format_exc()
        logger.exception(trace)
        raise
sse = SseServerTransport("/api/v1/mcp/")


def find_validation_error(exc):
    """Searches for a pydantic.ValidationError in the exception chain."""
    if not exc:
        return None
    if isinstance(exc, pydantic.ValidationError):
        return exc
    # Follow explicit (`raise ... from`) then implicit chaining.
    parent = getattr(exc, "__cause__", None) or getattr(exc, "__context__", None)
    return find_validation_error(parent)
@router.get("/sse", response_class=StreamingResponse)
async def handle_sse(request: Request, current_user: Annotated[User, Depends(get_current_active_user)]):
    """Open the MCP SSE stream and run the MCP server over it.

    The authenticated user is stored in `current_user_ctx` so tool calls made
    during this connection can retrieve it, and is restored on exit.
    """
    token = current_user_ctx.set(current_user)
    try:
        # NOTE(review): `request._send` is a private Starlette attribute; this
        # is the usual pattern for mcp's SseServerTransport, but may break on
        # Starlette upgrades -- verify on dependency bumps.
        async with sse.connect_sse(request.scope, request.receive, request._send) as streams:
            try:
                msg = "Starting SSE connection"
                logger.debug(msg)
                msg = f"Stream types: read={type(streams[0])}, write={type(streams[1])}"
                logger.debug(msg)

                # Advertise change notifications for prompts/resources/tools.
                notification_options = NotificationOptions(
                    prompts_changed=True, resources_changed=True, tools_changed=True
                )
                init_options = server.create_initialization_options(notification_options)
                msg = f"Initialization options: {init_options}"
                logger.debug(msg)

                try:
                    await server.run(streams[0], streams[1], init_options)
                except Exception as exc:  # noqa: BLE001
                    # Defensive against misbehaving MCP peers: log validation
                    # errors quietly instead of tearing down the stream.
                    validation_error = find_validation_error(exc)
                    if validation_error:
                        msg = "Validation error in MCP:" + str(validation_error)
                        logger.debug(msg)
                    else:
                        msg = f"Error in MCP: {exc!s}"
                        logger.debug(msg)
                    return
            except BrokenResourceError:
                # Handle gracefully when client disconnects
                logger.info("Client disconnected from SSE connection")
            except asyncio.CancelledError:
                logger.info("SSE connection was cancelled")
            except Exception as e:
                msg = f"Error in MCP: {e!s}"
                logger.exception(msg)
                trace = traceback.format_exc()
                logger.exception(trace)
                raise
    finally:
        # Always restore the previous ContextVar value.
        current_user_ctx.reset(token)
@router.post("/")
async def handle_messages(request: Request):
    """Receive a client->server MCP message POSTed back on an open SSE session."""
    # NOTE(review): `request._send` is private Starlette API (same caveat as
    # in handle_sse).
    await sse.handle_post_message(request.scope, request.receive, request._send)

View file

@ -0,0 +1,26 @@
import asyncio
from collections.abc import Awaitable, Callable
from langflow.helpers.base_model import BaseModel
def create_tool_coroutine(tool_name: str, arg_schema: type[BaseModel], session) -> Callable[[dict], Awaitable]:
    """Build an async callable that forwards positional args to an MCP tool call.

    Positional arguments are matched, in order, against the fields of
    ``arg_schema`` to form the ``arguments`` mapping for ``session.call_tool``.
    """

    async def tool_coroutine(*args):
        if not args:
            msg = f"at least one positional argument is required {args}"
            raise ValueError(msg)
        field_names = arg_schema.model_fields.keys()
        payload = dict(zip(field_names, args, strict=False))
        return await session.call_tool(tool_name, arguments=payload)

    return tool_coroutine
def create_tool_func(tool_name: str, session) -> Callable[..., str]:
    """Build a synchronous callable that executes an MCP tool call.

    Args:
        tool_name: Name of the MCP tool to invoke.
        session: An MCP client session (anything exposing an async ``call_tool``).

    Returns:
        A function taking keyword arguments for the tool; it raises
        ``ValueError`` when called with no arguments.
    """

    def tool_func(**kwargs):
        if len(kwargs) == 0:
            msg = f"at least one named argument is required {kwargs}"
            raise ValueError(msg)
        # BUG FIX / robustness: asyncio.get_event_loop() raises RuntimeError
        # when no event loop exists in the current thread (non-main threads,
        # and newer Python versions) -- create one instead of crashing.
        try:
            loop = asyncio.get_event_loop()
        except RuntimeError:
            loop = asyncio.new_event_loop()
            asyncio.set_event_loop(loop)
        return loop.run_until_complete(session.call_tool(tool_name, arguments=kwargs))

    return tool_func

View file

@ -9,6 +9,7 @@ from .exa_search import ExaSearchToolkit
from .glean_search_api import GleanSearchAPIComponent
from .google_search_api import GoogleSearchAPIComponent
from .google_serper_api import GoogleSerperAPIComponent
from .mcp_stdio import MCPStdio
from .python_code_structured_tool import PythonCodeStructuredTool
from .python_repl import PythonREPLToolComponent
from .search_api import SearchAPIComponent
@ -36,6 +37,7 @@ __all__ = [
"GleanSearchAPIComponent",
"GoogleSearchAPIComponent",
"GoogleSerperAPIComponent",
"MCPStdio",
"PythonCodeStructuredTool",
"PythonREPLToolComponent",
"SearXNGToolComponent",

View file

@ -0,0 +1,98 @@
# from langflow.field_typing import Data
import asyncio
from contextlib import AsyncExitStack
import httpx
from mcp import ClientSession, types
from mcp.client.sse import sse_client
from langflow.base.mcp.util import create_tool_coroutine, create_tool_func
from langflow.components.tools.mcp_stdio import create_input_schema_from_json_schema
from langflow.custom import Component
from langflow.field_typing import Tool
from langflow.io import MessageTextInput, Output
# Define constant for status code
HTTP_TEMPORARY_REDIRECT = 307
class MCPSseClient:
    """Thin SSE client wrapper around the MCP python SDK."""

    def __init__(self):
        # Connection state; populated by connect_to_server().
        self.write = None
        self.sse = None
        self.session: ClientSession | None = None
        self.exit_stack = AsyncExitStack()

    async def pre_check_redirect(self, url: str):
        """Check if the URL responds with a 307 Redirect."""
        async with httpx.AsyncClient(follow_redirects=False) as client:
            response = await client.request("HEAD", url)
        if response.status_code == HTTP_TEMPORARY_REDIRECT:
            return response.headers.get("Location")  # Return the redirect URL
        return url  # Return the original URL if no redirect

    async def connect_to_server(
        self, url: str, headers: dict[str, str] | None, timeout_seconds: int = 500, sse_read_timeout_seconds: int = 500
    ):
        """Open the SSE transport, initialize an MCP session, and list its tools."""
        headers = headers if headers is not None else {}
        url = await self.pre_check_redirect(url)
        async with asyncio.timeout(timeout_seconds):
            transport = await self.exit_stack.enter_async_context(
                sse_client(url, headers, timeout_seconds, sse_read_timeout_seconds)
            )
            self.sse, self.write = transport
            self.session = await self.exit_stack.enter_async_context(ClientSession(self.sse, self.write))
            await self.session.initialize()

            # List available tools
            listing = await self.session.list_tools()
            return listing.tools
class MCPSse(Component):
    # NOTE(review): `client` is a class attribute, so every MCPSse instance
    # shares one connection/session -- confirm this is intended.
    client = MCPSseClient()
    # `tools` starts out as the ListToolsResult *class*; build_output()
    # replaces it with the actual tool list on first connect.
    tools = types.ListToolsResult
    tool_names = [str]
    display_name = "MCP Tools (SSE)"
    description = "Connects to an MCP server over SSE and exposes it's tools as langflow tools to be used by an Agent."
    documentation: str = "http://docs.langflow.org/components/custom"
    icon = "code"
    name = "MCPSse"

    inputs = [
        MessageTextInput(
            name="url",
            display_name="mcp sse url",
            info="sse url",
            value="http://localhost:7860/api/v1/mcp/sse",
            tool_mode=True,
        ),
    ]

    outputs = [
        Output(display_name="Tools", name="tools", method="build_output"),
    ]

    async def build_output(self) -> list[Tool]:
        """Connect (once) to the SSE MCP server and wrap its tools as langflow Tools."""
        # Only connect on first use; afterwards reuse the cached session/tools.
        if self.client.session is None:
            self.tools = await self.client.connect_to_server(self.url, {})
        tool_list = []
        for tool in self.tools:
            # Each MCP tool's JSON input schema becomes a pydantic args schema.
            args_schema = create_input_schema_from_json_schema(tool.inputSchema)
            tool_list.append(
                Tool(
                    name=tool.name,  # maybe format this
                    description=tool.description,
                    coroutine=create_tool_coroutine(tool.name, args_schema, self.client.session),
                    func=create_tool_func(tool.name, self.client.session),
                )
            )
        self.tool_names = [tool.name for tool in self.tools]
        return tool_list

View file

@ -0,0 +1,122 @@
# from langflow.field_typing import Data
import os
from contextlib import AsyncExitStack
from typing import Any
from mcp import ClientSession, StdioServerParameters, types
from mcp.client.stdio import stdio_client
from pydantic import BaseModel, Field, create_model
from langflow.base.mcp.util import create_tool_coroutine, create_tool_func
from langflow.custom import Component
from langflow.field_typing import Tool
from langflow.io import MessageTextInput, Output
class MCPStdioClient:
    """Thin stdio client wrapper around the MCP python SDK."""

    def __init__(self):
        # Initialize session and client objects
        self.session: ClientSession | None = None
        self.exit_stack = AsyncExitStack()

    async def connect_to_server(self, command_str: str):
        """Spawn the MCP server subprocess and initialize an MCP session.

        Args:
            command_str: Full command line, e.g. ``"uvx mcp-sse-shim@latest"``.

        Returns:
            The list of tools advertised by the server.
        """
        import shlex  # local import: only used here

        # BUG FIX: use shlex.split instead of str.split(" ") so quoted
        # arguments and repeated spaces are parsed correctly.
        command = shlex.split(command_str)
        server_params = StdioServerParameters(
            command=command[0], args=command[1:], env={"DEBUG": "true", "PATH": os.environ["PATH"]}
        )

        stdio_transport = await self.exit_stack.enter_async_context(stdio_client(server_params))
        self.stdio, self.write = stdio_transport
        self.session = await self.exit_stack.enter_async_context(ClientSession(self.stdio, self.write))

        await self.session.initialize()

        # List available tools
        response = await self.session.list_tools()
        return response.tools
def create_input_schema_from_json_schema(schema: dict[str, Any]) -> type[BaseModel]:
    """Converts a JSON schema into a Pydantic model dynamically.

    :param schema: The JSON schema as a dictionary.
    :return: A Pydantic model class.
    """
    if schema.get("type") != "object":
        msg = "JSON schema must be of type 'object' at the root level."
        raise ValueError(msg)

    # Map JSON-schema (and a few python-style) type names onto Python types.
    type_mapping: dict[str, Any] = {
        "string": str,
        "str": str,
        "integer": int,
        "int": int,
        "number": float,
        "boolean": bool,
        "array": list,
        "object": dict,
    }

    required_fields = set(schema.get("required", []))
    fields = {}
    for field_name, field_def in schema.get("properties", {}).items():
        # Default to string type if not specified; unknown names fall to Any.
        python_type = type_mapping.get(field_def.get("type", "str"), Any)

        field_kwargs = {"description": field_def.get("description", "")}
        # Optional fields get an explicit default (None unless the schema says otherwise).
        if field_name not in required_fields:
            field_kwargs["default"] = field_def.get("default", None)

        fields[field_name] = (python_type, Field(**field_kwargs))

    # Dynamically create the model
    return create_model("InputSchema", **fields)
class MCPStdio(Component):
    # NOTE(review): `client` is a class attribute shared by all instances.
    client = MCPStdioClient()
    # Replaced with the actual tool list by build_output() on first connect.
    tools = types.ListToolsResult
    tool_names = [str]
    display_name = "MCP Tools (stdio)"
    description = (
        "Connects to an MCP server over stdio and exposes it's tools as langflow tools to be used by an Agent."
    )
    documentation: str = "http://docs.langflow.org/components/custom"
    icon = "code"
    name = "MCPStdio"

    inputs = [
        MessageTextInput(
            name="command",
            display_name="mcp command",
            info="mcp command",
            value="uvx mcp-sse-shim@latest",
            tool_mode=True,
        ),
    ]

    outputs = [
        Output(display_name="Tools", name="tools", method="build_output"),
    ]

    async def build_output(self) -> list[Tool]:
        """Connect (once) to the stdio MCP server and wrap its tools as langflow Tools."""
        if self.client.session is None:
            self.tools = await self.client.connect_to_server(self.command)

        tool_list = []
        for tool in self.tools:
            args_schema = create_input_schema_from_json_schema(tool.inputSchema)
            tool_list.append(
                Tool(
                    name=tool.name,
                    description=tool.description,
                    coroutine=create_tool_coroutine(tool.name, args_schema, self.client.session),
                    # BUG FIX: pass the MCP session, not the args schema --
                    # create_tool_func expects an object exposing `call_tool`
                    # (this matches the SSE component).
                    func=create_tool_func(tool.name, self.client.session),
                )
            )
        self.tool_names = [tool.name for tool in self.tools]
        return tool_list

View file

@ -4,6 +4,7 @@ from typing import TYPE_CHECKING, Any, cast
from uuid import UUID
from fastapi import HTTPException
from loguru import logger
from pydantic.v1 import BaseModel, Field, create_model
from sqlmodel import select
@ -314,3 +315,46 @@ async def generate_unique_flow_name(flow_name, user_id, session):
# If a flow with the name already exists, append (n) to the name and increment n
flow_name = f"{original_name} ({n})"
n += 1
def json_schema_from_flow(flow: Flow) -> dict:
    """Generate JSON schema from flow input nodes.

    Walks the flow graph's input vertices and builds a JSON-Schema ``object``
    whose properties are the visible, non-advanced template fields.
    """
    from langflow.graph.graph.base import Graph

    # Get the flow's data which contains the nodes and their configurations
    flow_data = flow.data if flow.data else {}
    graph = Graph.from_payload(flow_data)
    input_nodes = [vertex for vertex in graph.vertices if vertex.is_input]

    # Langflow template type -> JSON-Schema type name.
    type_map = {"str": "string", "int": "integer", "float": "number", "bool": "boolean"}
    json_schema_types = {"string", "integer", "number", "boolean", "array", "object"}

    properties = {}
    required = []
    for node in input_nodes:
        node_data = node.data["node"]
        template = node_data["template"]

        for field_name, field_data in template.items():
            # ROBUSTNESS FIX: template values can be plain strings (the old
            # code only special-cased "Component" and would crash on any other
            # non-dict value); only dict-valued entries describe real inputs.
            if not isinstance(field_data, dict):
                continue
            if field_data.get("show", False) and not field_data.get("advanced", False):
                field_type = field_data.get("type", "string")
                properties[field_name] = {
                    "type": field_type,
                    "description": field_data.get("info", f"Input for {field_name}"),
                }

                # Normalize to a JSON-Schema type name.
                # BUG FIX: previously even valid JSON-Schema types such as
                # "string" fell into the "Unknown field type" warning branch.
                if field_type in type_map:
                    field_type = type_map[field_type]
                elif field_type not in json_schema_types:
                    logger.warning(f"Unknown field type: {field_type} defaulting to string")
                    field_type = "string"
                properties[field_name]["type"] = field_type

                if field_data.get("required", False):
                    required.append(field_name)

    return {"type": "object", "properties": properties, "required": required}

View file

@ -181,6 +181,12 @@ class Settings(BaseSettings):
max_vertex_builds_to_keep: int = 3000
"""The maximum number of vertex builds to keep in the database."""
# MCP Server
mcp_server_enabled: bool = True
"""If set to False, Langflow will not enable the MCP server."""
mcp_server_enable_progress_notifications: bool = False
"""If set to False, Langflow will not send progress notifications in the MCP server."""
@field_validator("dev")
@classmethod
def set_dev(cls, value):

View file

@ -1,3 +1,5 @@
import os
import pytest
from langflow.services.deps import get_settings_service
@ -5,14 +7,18 @@ from langflow.services.deps import get_settings_service
@pytest.fixture(autouse=True)
def setup_database_url(tmp_path, monkeypatch):
    """Setup a temporary database URL for testing.

    Points both the environment variable and the live settings service at a
    throwaway SQLite file, then restores the previous state on teardown.
    """
    settings_service = get_settings_service()
    db_path = tmp_path / "test_performance.db"
    # BUG FIX: a stale line assigned the result of monkeypatch.delenv (always
    # None) to original_value; capture the previous value with os.getenv
    # before deleting it.
    original_value = os.getenv("LANGFLOW_DATABASE_URL")
    monkeypatch.delenv("LANGFLOW_DATABASE_URL", raising=False)
    test_db_url = f"sqlite:///{db_path}"
    monkeypatch.setenv("LANGFLOW_DATABASE_URL", test_db_url)
    settings_service.set("database_url", test_db_url)
    yield
    # Restore original value if it existed
    if original_value is not None:
        monkeypatch.setenv("LANGFLOW_DATABASE_URL", original_value)
        settings_service.set("database_url", original_value)
    else:
        monkeypatch.delenv("LANGFLOW_DATABASE_URL", raising=False)

View file

@ -90,7 +90,7 @@ test(
}
await page.locator(".react-flow__pane").click();
await adjustScreenView(page, { numberOfZoomOut: 1 });
await visibleElementHandle.hover();
await page.mouse.down();

1092
uv.lock generated

File diff suppressed because it is too large Load diff