ref: refactor MCP-related functionality centralizing common utilities (#9059)
* reactor to have common mcp codes in mcp_support * [autofix.ci] apply automated fixes * Refactor MCP API argument passing and function signatures Updated function calls in mcp_projects.py to use explicit keyword arguments for clarity. Refactored mcp_support.py to use more concise query assignment and added keyword-only arguments to handle_call_tool and handle_list_tools for improved code readability and maintainability. * Rename mcp_support.py to mcp_utils.py and update imports Renamed mcp_support.py to mcp_utils.py for clarity and updated all relevant import statements in mcp.py and mcp_projects.py to reflect the new module name. --------- Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com>
This commit is contained in:
parent
607cadfe17
commit
e20ae03c38
3 changed files with 395 additions and 518 deletions
|
|
@ -1,11 +1,4 @@
|
|||
import asyncio
|
||||
import base64
|
||||
from collections.abc import Awaitable, Callable
|
||||
from contextvars import ContextVar
|
||||
from functools import wraps
|
||||
from typing import Any, ParamSpec, TypeVar
|
||||
from urllib.parse import quote, unquote, urlparse
|
||||
from uuid import uuid4
|
||||
|
||||
import pydantic
|
||||
from anyio import BrokenResourceError
|
||||
|
|
@ -15,70 +8,22 @@ from loguru import logger
|
|||
from mcp import types
|
||||
from mcp.server import NotificationOptions, Server
|
||||
from mcp.server.sse import SseServerTransport
|
||||
from sqlmodel import select
|
||||
|
||||
from langflow.api.utils import CurrentActiveMCPUser
|
||||
from langflow.api.v1.endpoints import simple_run_flow
|
||||
from langflow.api.v1.schemas import SimplifiedAPIRequest
|
||||
from langflow.base.mcp.constants import MAX_MCP_TOOL_NAME_LENGTH
|
||||
from langflow.base.mcp.util import get_flow_snake_case, sanitize_mcp_name
|
||||
from langflow.helpers.flow import json_schema_from_flow
|
||||
from langflow.schema.message import Message
|
||||
from langflow.services.database.models.flow.model import Flow
|
||||
from langflow.services.database.models.user.model import User
|
||||
from langflow.services.deps import (
|
||||
get_db_service,
|
||||
get_settings_service,
|
||||
get_storage_service,
|
||||
session_scope,
|
||||
from langflow.api.v1.mcp_utils import (
|
||||
current_user_ctx,
|
||||
handle_call_tool,
|
||||
handle_list_resources,
|
||||
handle_list_tools,
|
||||
handle_mcp_errors,
|
||||
handle_read_resource,
|
||||
)
|
||||
from langflow.services.storage.utils import build_content_type_from_extension
|
||||
|
||||
T = TypeVar("T")
|
||||
P = ParamSpec("P")
|
||||
|
||||
|
||||
def handle_mcp_errors(func: Callable[P, Awaitable[T]]) -> Callable[P, Awaitable[T]]:
|
||||
"""Decorator to handle MCP endpoint errors consistently."""
|
||||
|
||||
@wraps(func)
|
||||
async def wrapper(*args: P.args, **kwargs: P.kwargs) -> T:
|
||||
try:
|
||||
return await func(*args, **kwargs)
|
||||
except Exception as e:
|
||||
msg = f"Error in {func.__name__}: {e!s}"
|
||||
logger.exception(msg)
|
||||
raise
|
||||
|
||||
return wrapper
|
||||
|
||||
|
||||
async def with_db_session(operation: Callable[[Any], Awaitable[T]]) -> T:
|
||||
"""Execute an operation within a database session context."""
|
||||
async with session_scope() as session:
|
||||
return await operation(session)
|
||||
|
||||
|
||||
class MCPConfig:
|
||||
_instance = None
|
||||
|
||||
def __new__(cls):
|
||||
if cls._instance is None:
|
||||
cls._instance = super().__new__(cls)
|
||||
cls._instance.enable_progress_notifications = None
|
||||
return cls._instance
|
||||
|
||||
|
||||
def get_mcp_config():
|
||||
return MCPConfig()
|
||||
|
||||
from langflow.services.deps import get_settings_service
|
||||
|
||||
router = APIRouter(prefix="/mcp", tags=["mcp"])
|
||||
|
||||
server = Server("langflow-mcp-server")
|
||||
|
||||
# Create a context variable to store the current user
|
||||
current_user_ctx: ContextVar[User] = ContextVar("current_user_ctx")
|
||||
|
||||
# Define constants
|
||||
MAX_RETRIES = 2
|
||||
|
|
@ -94,237 +39,28 @@ async def handle_list_prompts():
|
|||
|
||||
|
||||
@server.list_resources()
|
||||
async def handle_list_resources():
|
||||
resources = []
|
||||
try:
|
||||
db_service = get_db_service()
|
||||
storage_service = get_storage_service()
|
||||
settings_service = get_settings_service()
|
||||
|
||||
# Build full URL from settings
|
||||
host = getattr(settings_service.settings, "host", "localhost")
|
||||
port = getattr(settings_service.settings, "port", 3000)
|
||||
|
||||
base_url = f"http://{host}:{port}".rstrip("/")
|
||||
|
||||
async with db_service.with_session() as session:
|
||||
flows = (await session.exec(select(Flow))).all()
|
||||
|
||||
for flow in flows:
|
||||
if flow.id:
|
||||
try:
|
||||
files = await storage_service.list_files(flow_id=str(flow.id))
|
||||
for file_name in files:
|
||||
# URL encode the filename
|
||||
safe_filename = quote(file_name)
|
||||
resource = types.Resource(
|
||||
uri=f"{base_url}/api/v1/files/{flow.id}/{safe_filename}",
|
||||
name=file_name,
|
||||
description=f"File in flow: {flow.name}",
|
||||
mimeType=build_content_type_from_extension(file_name),
|
||||
)
|
||||
resources.append(resource)
|
||||
except FileNotFoundError as e:
|
||||
msg = f"Error listing files for flow {flow.id}: {e}"
|
||||
logger.debug(msg)
|
||||
continue
|
||||
except Exception as e:
|
||||
msg = f"Error in listing resources: {e!s}"
|
||||
logger.exception(msg)
|
||||
raise
|
||||
return resources
|
||||
async def handle_global_resources():
|
||||
"""Handle listing resources for global MCP server."""
|
||||
return await handle_list_resources()
|
||||
|
||||
|
||||
@server.read_resource()
|
||||
async def handle_read_resource(uri: str) -> bytes:
|
||||
"""Handle resource read requests."""
|
||||
try:
|
||||
# Parse the URI properly
|
||||
parsed_uri = urlparse(str(uri))
|
||||
# Path will be like /api/v1/files/{flow_id}/{filename}
|
||||
path_parts = parsed_uri.path.split("/")
|
||||
# Remove empty strings from split
|
||||
path_parts = [p for p in path_parts if p]
|
||||
|
||||
# The flow_id and filename should be the last two parts
|
||||
two = 2
|
||||
if len(path_parts) < two:
|
||||
msg = f"Invalid URI format: {uri}"
|
||||
raise ValueError(msg)
|
||||
|
||||
flow_id = path_parts[-2]
|
||||
filename = unquote(path_parts[-1]) # URL decode the filename
|
||||
|
||||
storage_service = get_storage_service()
|
||||
|
||||
# Read the file content
|
||||
content = await storage_service.get_file(flow_id=flow_id, file_name=filename)
|
||||
if not content:
|
||||
msg = f"File {filename} not found in flow {flow_id}"
|
||||
raise ValueError(msg)
|
||||
|
||||
# Ensure content is base64 encoded
|
||||
if isinstance(content, str):
|
||||
content = content.encode()
|
||||
return base64.b64encode(content)
|
||||
except Exception as e:
|
||||
msg = f"Error reading resource {uri}: {e!s}"
|
||||
logger.exception(msg)
|
||||
raise
|
||||
async def handle_global_read_resource(uri: str) -> bytes:
|
||||
"""Handle resource read requests for global MCP server."""
|
||||
return await handle_read_resource(uri)
|
||||
|
||||
|
||||
@server.list_tools()
|
||||
async def handle_list_tools():
|
||||
tools = []
|
||||
try:
|
||||
db_service = get_db_service()
|
||||
async with db_service.with_session() as session:
|
||||
flows = (await session.exec(select(Flow))).all()
|
||||
|
||||
existing_names = set()
|
||||
for flow in flows:
|
||||
if flow.user_id is None:
|
||||
continue
|
||||
|
||||
base_name = sanitize_mcp_name(flow.name)
|
||||
name = base_name[:MAX_MCP_TOOL_NAME_LENGTH]
|
||||
if name in existing_names:
|
||||
i = 1
|
||||
while True:
|
||||
suffix = f"_{i}"
|
||||
truncated_base = base_name[: MAX_MCP_TOOL_NAME_LENGTH - len(suffix)]
|
||||
candidate = f"{truncated_base}{suffix}"
|
||||
if candidate not in existing_names:
|
||||
name = candidate
|
||||
break
|
||||
i += 1
|
||||
try:
|
||||
tool = types.Tool(
|
||||
name=name,
|
||||
description=f"{flow.id}: {flow.description}"
|
||||
if flow.description
|
||||
else f"Tool generated from flow: {name}",
|
||||
inputSchema=json_schema_from_flow(flow),
|
||||
)
|
||||
tools.append(tool)
|
||||
existing_names.add(name)
|
||||
except Exception as e: # noqa: BLE001
|
||||
msg = f"Error in listing tools: {e!s} from flow: {base_name}"
|
||||
logger.warning(msg)
|
||||
continue
|
||||
except Exception as e:
|
||||
msg = f"Error in listing tools: {e!s}"
|
||||
logger.exception(msg)
|
||||
raise
|
||||
return tools
|
||||
async def handle_global_tools():
|
||||
"""Handle listing tools for global MCP server."""
|
||||
return await handle_list_tools()
|
||||
|
||||
|
||||
@server.call_tool()
|
||||
@handle_mcp_errors
|
||||
async def handle_call_tool(name: str, arguments: dict) -> list[types.TextContent]:
|
||||
"""Handle tool execution requests."""
|
||||
mcp_config = get_mcp_config()
|
||||
if mcp_config.enable_progress_notifications is None:
|
||||
settings_service = get_settings_service()
|
||||
mcp_config.enable_progress_notifications = settings_service.settings.mcp_server_enable_progress_notifications
|
||||
|
||||
current_user = current_user_ctx.get()
|
||||
|
||||
async def execute_tool(session):
|
||||
# get flow id from name
|
||||
flow = await get_flow_snake_case(name, current_user.id, session)
|
||||
if not flow:
|
||||
msg = f"Flow with name '{name}' not found"
|
||||
raise ValueError(msg)
|
||||
|
||||
# Process inputs
|
||||
processed_inputs = dict(arguments)
|
||||
|
||||
# Initial progress notification
|
||||
if mcp_config.enable_progress_notifications and (progress_token := server.request_context.meta.progressToken):
|
||||
await server.request_context.session.send_progress_notification(
|
||||
progress_token=progress_token, progress=0.0, total=1.0
|
||||
)
|
||||
|
||||
conversation_id = str(uuid4())
|
||||
input_request = SimplifiedAPIRequest(
|
||||
input_value=processed_inputs.get("input_value", ""), session_id=conversation_id
|
||||
)
|
||||
|
||||
async def send_progress_updates(progress_token):
|
||||
try:
|
||||
progress = 0.0
|
||||
while True:
|
||||
await server.request_context.session.send_progress_notification(
|
||||
progress_token=progress_token, progress=min(0.9, progress), total=1.0
|
||||
)
|
||||
progress += 0.1
|
||||
await asyncio.sleep(1.0)
|
||||
except asyncio.CancelledError:
|
||||
if mcp_config.enable_progress_notifications:
|
||||
await server.request_context.session.send_progress_notification(
|
||||
progress_token=progress_token, progress=1.0, total=1.0
|
||||
)
|
||||
raise
|
||||
|
||||
collected_results = []
|
||||
try:
|
||||
progress_task = None
|
||||
if mcp_config.enable_progress_notifications and server.request_context.meta.progressToken:
|
||||
progress_task = asyncio.create_task(send_progress_updates(server.request_context.meta.progressToken))
|
||||
|
||||
try:
|
||||
try:
|
||||
result = await simple_run_flow(
|
||||
flow=flow,
|
||||
input_request=input_request,
|
||||
stream=False,
|
||||
api_key_user=current_user,
|
||||
)
|
||||
# Process all outputs and messages, ensuring no duplicates
|
||||
processed_texts = set()
|
||||
|
||||
def add_result(text: str):
|
||||
if text not in processed_texts:
|
||||
processed_texts.add(text)
|
||||
collected_results.append(types.TextContent(type="text", text=text))
|
||||
|
||||
for run_output in result.outputs:
|
||||
for component_output in run_output.outputs:
|
||||
# Handle messages
|
||||
for msg in component_output.messages or []:
|
||||
add_result(msg.message)
|
||||
# Handle results
|
||||
for value in (component_output.results or {}).values():
|
||||
if isinstance(value, Message):
|
||||
add_result(value.get_text())
|
||||
else:
|
||||
add_result(str(value))
|
||||
except Exception as e: # noqa: BLE001
|
||||
error_msg = f"Error Executing the {flow.name} tool. Error: {e!s}"
|
||||
collected_results.append(types.TextContent(type="text", text=error_msg))
|
||||
|
||||
return collected_results
|
||||
finally:
|
||||
if progress_task:
|
||||
progress_task.cancel()
|
||||
await asyncio.gather(progress_task, return_exceptions=True)
|
||||
|
||||
except Exception:
|
||||
if mcp_config.enable_progress_notifications and (
|
||||
progress_token := server.request_context.meta.progressToken
|
||||
):
|
||||
await server.request_context.session.send_progress_notification(
|
||||
progress_token=progress_token, progress=1.0, total=1.0
|
||||
)
|
||||
raise
|
||||
|
||||
try:
|
||||
return await with_db_session(execute_tool)
|
||||
except Exception as e:
|
||||
msg = f"Error executing tool {name}: {e!s}"
|
||||
logger.exception(msg)
|
||||
raise
|
||||
async def handle_global_call_tool(name: str, arguments: dict) -> list[types.TextContent]:
|
||||
"""Handle tool execution requests for global MCP server."""
|
||||
return await handle_call_tool(name, arguments, server)
|
||||
|
||||
|
||||
sse = SseServerTransport("/api/v1/mcp/")
|
||||
|
|
|
|||
|
|
@ -1,5 +1,4 @@
|
|||
import asyncio
|
||||
import base64
|
||||
import json
|
||||
import logging
|
||||
import os
|
||||
|
|
@ -10,8 +9,7 @@ from datetime import datetime, timezone
|
|||
from ipaddress import ip_address
|
||||
from pathlib import Path
|
||||
from subprocess import CalledProcessError
|
||||
from urllib.parse import quote, unquote, urlparse
|
||||
from uuid import UUID, uuid4
|
||||
from uuid import UUID
|
||||
|
||||
from anyio import BrokenResourceError
|
||||
from fastapi import APIRouter, HTTPException, Request, Response
|
||||
|
|
@ -23,27 +21,25 @@ from sqlalchemy.orm import selectinload
|
|||
from sqlmodel import select
|
||||
|
||||
from langflow.api.utils import CurrentActiveMCPUser
|
||||
from langflow.api.v1.endpoints import simple_run_flow
|
||||
from langflow.api.v1.mcp import (
|
||||
from langflow.api.v1.mcp_utils import (
|
||||
current_user_ctx,
|
||||
get_mcp_config,
|
||||
handle_call_tool,
|
||||
handle_list_resources,
|
||||
handle_list_tools,
|
||||
handle_mcp_errors,
|
||||
with_db_session,
|
||||
handle_read_resource,
|
||||
)
|
||||
from langflow.api.v1.schemas import MCPInstallRequest, MCPSettings, SimplifiedAPIRequest
|
||||
from langflow.base.mcp.constants import MAX_MCP_SERVER_NAME_LENGTH, MAX_MCP_TOOL_NAME_LENGTH
|
||||
from langflow.base.mcp.util import get_flow_snake_case, get_unique_name, sanitize_mcp_name
|
||||
from langflow.helpers.flow import json_schema_from_flow
|
||||
from langflow.schema.message import Message
|
||||
from langflow.api.v1.schemas import MCPInstallRequest, MCPSettings
|
||||
from langflow.base.mcp.constants import MAX_MCP_SERVER_NAME_LENGTH
|
||||
from langflow.base.mcp.util import sanitize_mcp_name
|
||||
from langflow.services.database.models import Flow, Folder
|
||||
from langflow.services.deps import get_settings_service, get_storage_service, session_scope
|
||||
from langflow.services.storage.utils import build_content_type_from_extension
|
||||
from langflow.services.deps import get_settings_service, session_scope
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
router = APIRouter(prefix="/mcp/project", tags=["mcp_projects"])
|
||||
|
||||
# Create a context variable to store the current project
|
||||
# Create project-specific context variable
|
||||
current_project_ctx: ContextVar[UUID | None] = ContextVar("current_project_ctx", default=None)
|
||||
|
||||
# Create a mapping of project-specific SSE transports
|
||||
|
|
@ -652,236 +648,33 @@ class ProjectMCPServer:
|
|||
@handle_mcp_errors
|
||||
async def handle_list_project_tools():
|
||||
"""Handle listing tools for this specific project."""
|
||||
tools = []
|
||||
try:
|
||||
async with session_scope() as session:
|
||||
# Get flows with mcp_enabled flag set to True and in this project
|
||||
flows = (
|
||||
await session.exec(
|
||||
select(Flow).where(Flow.mcp_enabled == True, Flow.folder_id == self.project_id) # noqa: E712
|
||||
)
|
||||
).all()
|
||||
existing_names = set()
|
||||
for flow in flows:
|
||||
if flow.user_id is None:
|
||||
continue
|
||||
|
||||
# Use action_name if available, otherwise construct from flow name
|
||||
base_name = (
|
||||
sanitize_mcp_name(flow.action_name) if flow.action_name else sanitize_mcp_name(flow.name)
|
||||
)
|
||||
name = get_unique_name(base_name, MAX_MCP_TOOL_NAME_LENGTH, existing_names)
|
||||
|
||||
# Use action_description if available, otherwise use defaults
|
||||
description = flow.action_description or (
|
||||
flow.description if flow.description else f"Tool generated from flow: {name}"
|
||||
)
|
||||
|
||||
tool = types.Tool(
|
||||
name=name,
|
||||
description=description,
|
||||
inputSchema=json_schema_from_flow(flow),
|
||||
)
|
||||
tools.append(tool)
|
||||
existing_names.add(name)
|
||||
except Exception as e: # noqa: BLE001
|
||||
msg = f"Error in listing project tools: {e!s} from flow: {name}"
|
||||
logger.warning(msg)
|
||||
return tools
|
||||
return await handle_list_tools(project_id=self.project_id, mcp_enabled_only=True)
|
||||
|
||||
@self.server.list_prompts()
|
||||
async def handle_list_prompts():
|
||||
return []
|
||||
|
||||
@self.server.list_resources()
|
||||
async def handle_list_resources():
|
||||
resources = []
|
||||
try:
|
||||
storage_service = get_storage_service()
|
||||
settings_service = get_settings_service()
|
||||
|
||||
# Build full URL from settings
|
||||
host = getattr(settings_service.settings, "host", "localhost")
|
||||
port = getattr(settings_service.settings, "port", 3000)
|
||||
|
||||
base_url = f"http://{host}:{port}".rstrip("/")
|
||||
|
||||
async with session_scope() as session:
|
||||
flows = (await session.exec(select(Flow))).all()
|
||||
|
||||
for flow in flows:
|
||||
if flow.id:
|
||||
try:
|
||||
files = await storage_service.list_files(flow_id=str(flow.id))
|
||||
for file_name in files:
|
||||
# URL encode the filename
|
||||
safe_filename = quote(file_name)
|
||||
resource = types.Resource(
|
||||
uri=f"{base_url}/api/v1/files/{flow.id}/{safe_filename}",
|
||||
name=file_name,
|
||||
description=f"File in flow: {flow.name}",
|
||||
mimeType=build_content_type_from_extension(file_name),
|
||||
)
|
||||
resources.append(resource)
|
||||
except FileNotFoundError as e:
|
||||
msg = f"Error listing files for flow {flow.id}: {e}"
|
||||
logger.debug(msg)
|
||||
continue
|
||||
except Exception as e:
|
||||
msg = f"Error in listing resources: {e!s}"
|
||||
logger.exception(msg)
|
||||
raise
|
||||
return resources
|
||||
async def handle_list_project_resources():
|
||||
"""Handle listing resources for this specific project."""
|
||||
return await handle_list_resources(project_id=self.project_id)
|
||||
|
||||
@self.server.read_resource()
|
||||
async def handle_read_resource(uri: str) -> bytes:
|
||||
"""Handle resource read requests."""
|
||||
try:
|
||||
# Parse the URI properly
|
||||
parsed_uri = urlparse(str(uri))
|
||||
# Path will be like /api/v1/files/{flow_id}/{filename}
|
||||
path_parts = parsed_uri.path.split("/")
|
||||
# Remove empty strings from split
|
||||
path_parts = [p for p in path_parts if p]
|
||||
|
||||
# The flow_id and filename should be the last two parts
|
||||
two = 2
|
||||
if len(path_parts) < two:
|
||||
msg = f"Invalid URI format: {uri}"
|
||||
raise ValueError(msg)
|
||||
|
||||
flow_id = path_parts[-2]
|
||||
filename = unquote(path_parts[-1]) # URL decode the filename
|
||||
|
||||
storage_service = get_storage_service()
|
||||
|
||||
# Read the file content
|
||||
content = await storage_service.get_file(flow_id=flow_id, file_name=filename)
|
||||
if not content:
|
||||
msg = f"File {filename} not found in flow {flow_id}"
|
||||
raise ValueError(msg)
|
||||
|
||||
# Ensure content is base64 encoded
|
||||
if isinstance(content, str):
|
||||
content = content.encode()
|
||||
return base64.b64encode(content)
|
||||
except Exception as e:
|
||||
msg = f"Error reading resource {uri}: {e!s}"
|
||||
logger.exception(msg)
|
||||
raise
|
||||
async def handle_read_project_resource(uri: str) -> bytes:
|
||||
"""Handle resource read requests for this specific project."""
|
||||
return await handle_read_resource(uri=uri)
|
||||
|
||||
@self.server.call_tool()
|
||||
@handle_mcp_errors
|
||||
async def handle_call_tool(name: str, arguments: dict) -> list[types.TextContent]:
|
||||
"""Handle tool execution requests."""
|
||||
mcp_config = get_mcp_config()
|
||||
if mcp_config.enable_progress_notifications is None:
|
||||
settings_service = get_settings_service()
|
||||
mcp_config.enable_progress_notifications = (
|
||||
settings_service.settings.mcp_server_enable_progress_notifications
|
||||
)
|
||||
|
||||
current_user = current_user_ctx.get()
|
||||
|
||||
async def execute_tool(session):
|
||||
# get flow id from name
|
||||
flow = await get_flow_snake_case(name, current_user.id, session, is_action=True)
|
||||
if not flow:
|
||||
msg = f"Flow with name '{name}' not found"
|
||||
raise ValueError(msg)
|
||||
|
||||
# Process inputs
|
||||
processed_inputs = dict(arguments)
|
||||
|
||||
# Initial progress notification
|
||||
if mcp_config.enable_progress_notifications and (
|
||||
progress_token := self.server.request_context.meta.progressToken
|
||||
):
|
||||
await self.server.request_context.session.send_progress_notification(
|
||||
progress_token=progress_token, progress=0.0, total=1.0
|
||||
)
|
||||
|
||||
conversation_id = str(uuid4())
|
||||
input_request = SimplifiedAPIRequest(
|
||||
input_value=processed_inputs.get("input_value", ""), session_id=conversation_id
|
||||
)
|
||||
|
||||
async def send_progress_updates(progress_token):
|
||||
try:
|
||||
progress = 0.0
|
||||
while True:
|
||||
await self.server.request_context.session.send_progress_notification(
|
||||
progress_token=progress_token, progress=min(0.9, progress), total=1.0
|
||||
)
|
||||
progress += 0.1
|
||||
await asyncio.sleep(1.0)
|
||||
except asyncio.CancelledError:
|
||||
if mcp_config.enable_progress_notifications:
|
||||
await self.server.request_context.session.send_progress_notification(
|
||||
progress_token=progress_token, progress=1.0, total=1.0
|
||||
)
|
||||
raise
|
||||
|
||||
collected_results = []
|
||||
try:
|
||||
progress_task = None
|
||||
if mcp_config.enable_progress_notifications and self.server.request_context.meta.progressToken:
|
||||
progress_task = asyncio.create_task(
|
||||
send_progress_updates(self.server.request_context.meta.progressToken)
|
||||
)
|
||||
|
||||
try:
|
||||
try:
|
||||
result = await simple_run_flow(
|
||||
flow=flow,
|
||||
input_request=input_request,
|
||||
stream=False,
|
||||
api_key_user=current_user,
|
||||
)
|
||||
# Process all outputs and messages, ensuring no duplicates
|
||||
processed_texts = set()
|
||||
|
||||
def add_result(text: str):
|
||||
if text not in processed_texts:
|
||||
processed_texts.add(text)
|
||||
collected_results.append(types.TextContent(type="text", text=text))
|
||||
|
||||
for run_output in result.outputs:
|
||||
for component_output in run_output.outputs:
|
||||
# Handle messages
|
||||
for msg in component_output.messages or []:
|
||||
add_result(msg.message)
|
||||
# Handle results
|
||||
for value in (component_output.results or {}).values():
|
||||
if isinstance(value, Message):
|
||||
add_result(value.get_text())
|
||||
else:
|
||||
add_result(str(value))
|
||||
except Exception as e: # noqa: BLE001
|
||||
error_msg = f"Error Executing the {flow.name} tool. Error: {e!s}"
|
||||
collected_results.append(types.TextContent(type="text", text=error_msg))
|
||||
|
||||
return collected_results
|
||||
finally:
|
||||
if progress_task:
|
||||
progress_task.cancel()
|
||||
await asyncio.gather(progress_task, return_exceptions=True)
|
||||
|
||||
except Exception:
|
||||
if mcp_config.enable_progress_notifications and (
|
||||
progress_token := self.server.request_context.meta.progressToken
|
||||
):
|
||||
await self.server.request_context.session.send_progress_notification(
|
||||
progress_token=progress_token, progress=1.0, total=1.0
|
||||
)
|
||||
raise
|
||||
|
||||
try:
|
||||
return await with_db_session(execute_tool)
|
||||
except Exception as e:
|
||||
msg = f"Error executing tool {name}: {e!s}"
|
||||
logger.exception(msg)
|
||||
raise
|
||||
async def handle_call_project_tool(name: str, arguments: dict) -> list[types.TextContent]:
|
||||
"""Handle tool execution requests for this specific project."""
|
||||
return await handle_call_tool(
|
||||
name=name,
|
||||
arguments=arguments,
|
||||
server=self.server,
|
||||
project_id=self.project_id,
|
||||
is_action=True,
|
||||
)
|
||||
|
||||
|
||||
# Cache of project MCP servers
|
||||
|
|
|
|||
348
src/backend/base/langflow/api/v1/mcp_utils.py
Normal file
348
src/backend/base/langflow/api/v1/mcp_utils.py
Normal file
|
|
@ -0,0 +1,348 @@
|
|||
"""Common MCP handler functions shared between mcp.py and mcp_projects.py.
|
||||
|
||||
This module serves as the single source of truth for MCP functionality.
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import base64
|
||||
from collections.abc import Awaitable, Callable
|
||||
from contextvars import ContextVar
|
||||
from functools import wraps
|
||||
from typing import Any, ParamSpec, TypeVar
|
||||
from urllib.parse import quote, unquote, urlparse
|
||||
from uuid import uuid4
|
||||
|
||||
from loguru import logger
|
||||
from mcp import types
|
||||
from sqlmodel import select
|
||||
|
||||
from langflow.api.v1.endpoints import simple_run_flow
|
||||
from langflow.api.v1.schemas import SimplifiedAPIRequest
|
||||
from langflow.base.mcp.constants import MAX_MCP_TOOL_NAME_LENGTH
|
||||
from langflow.base.mcp.util import get_flow_snake_case, get_unique_name, sanitize_mcp_name
|
||||
from langflow.helpers.flow import json_schema_from_flow
|
||||
from langflow.schema.message import Message
|
||||
from langflow.services.database.models import Flow
|
||||
from langflow.services.database.models.user.model import User
|
||||
from langflow.services.deps import get_settings_service, get_storage_service, session_scope
|
||||
from langflow.services.storage.utils import build_content_type_from_extension
|
||||
|
||||
T = TypeVar("T")
|
||||
P = ParamSpec("P")
|
||||
|
||||
# Create context variables
|
||||
current_user_ctx: ContextVar[User] = ContextVar("current_user_ctx")
|
||||
|
||||
|
||||
def handle_mcp_errors(func: Callable[P, Awaitable[T]]) -> Callable[P, Awaitable[T]]:
|
||||
"""Decorator to handle MCP endpoint errors consistently."""
|
||||
|
||||
@wraps(func)
|
||||
async def wrapper(*args: P.args, **kwargs: P.kwargs) -> T:
|
||||
try:
|
||||
return await func(*args, **kwargs)
|
||||
except Exception as e:
|
||||
msg = f"Error in {func.__name__}: {e!s}"
|
||||
logger.exception(msg)
|
||||
raise
|
||||
|
||||
return wrapper
|
||||
|
||||
|
||||
async def with_db_session(operation: Callable[[Any], Awaitable[T]]) -> T:
|
||||
"""Execute an operation within a database session context."""
|
||||
async with session_scope() as session:
|
||||
return await operation(session)
|
||||
|
||||
|
||||
class MCPConfig:
|
||||
_instance = None
|
||||
|
||||
def __new__(cls):
|
||||
if cls._instance is None:
|
||||
cls._instance = super().__new__(cls)
|
||||
cls._instance.enable_progress_notifications = None
|
||||
return cls._instance
|
||||
|
||||
|
||||
def get_mcp_config():
|
||||
return MCPConfig()
|
||||
|
||||
|
||||
async def handle_list_resources(project_id=None):
|
||||
"""Handle listing resources for MCP.
|
||||
|
||||
Args:
|
||||
project_id: Optional project ID to filter resources by project
|
||||
"""
|
||||
resources = []
|
||||
try:
|
||||
storage_service = get_storage_service()
|
||||
settings_service = get_settings_service()
|
||||
|
||||
# Build full URL from settings
|
||||
host = getattr(settings_service.settings, "host", "localhost")
|
||||
port = getattr(settings_service.settings, "port", 3000)
|
||||
|
||||
base_url = f"http://{host}:{port}".rstrip("/")
|
||||
|
||||
async with session_scope() as session:
|
||||
# Build query based on whether project_id is provided
|
||||
flows_query = select(Flow).where(Flow.folder_id == project_id) if project_id else select(Flow)
|
||||
|
||||
flows = (await session.exec(flows_query)).all()
|
||||
|
||||
for flow in flows:
|
||||
if flow.id:
|
||||
try:
|
||||
files = await storage_service.list_files(flow_id=str(flow.id))
|
||||
for file_name in files:
|
||||
# URL encode the filename
|
||||
safe_filename = quote(file_name)
|
||||
resource = types.Resource(
|
||||
uri=f"{base_url}/api/v1/files/{flow.id}/{safe_filename}",
|
||||
name=file_name,
|
||||
description=f"File in flow: {flow.name}",
|
||||
mimeType=build_content_type_from_extension(file_name),
|
||||
)
|
||||
resources.append(resource)
|
||||
except FileNotFoundError as e:
|
||||
msg = f"Error listing files for flow {flow.id}: {e}"
|
||||
logger.debug(msg)
|
||||
continue
|
||||
except Exception as e:
|
||||
msg = f"Error in listing resources: {e!s}"
|
||||
logger.exception(msg)
|
||||
raise
|
||||
return resources
|
||||
|
||||
|
||||
async def handle_read_resource(uri: str) -> bytes:
|
||||
"""Handle resource read requests."""
|
||||
try:
|
||||
# Parse the URI properly
|
||||
parsed_uri = urlparse(str(uri))
|
||||
# Path will be like /api/v1/files/{flow_id}/{filename}
|
||||
path_parts = parsed_uri.path.split("/")
|
||||
# Remove empty strings from split
|
||||
path_parts = [p for p in path_parts if p]
|
||||
|
||||
# The flow_id and filename should be the last two parts
|
||||
two = 2
|
||||
if len(path_parts) < two:
|
||||
msg = f"Invalid URI format: {uri}"
|
||||
raise ValueError(msg)
|
||||
|
||||
flow_id = path_parts[-2]
|
||||
filename = unquote(path_parts[-1]) # URL decode the filename
|
||||
|
||||
storage_service = get_storage_service()
|
||||
|
||||
# Read the file content
|
||||
content = await storage_service.get_file(flow_id=flow_id, file_name=filename)
|
||||
if not content:
|
||||
msg = f"File {filename} not found in flow {flow_id}"
|
||||
raise ValueError(msg)
|
||||
|
||||
# Ensure content is base64 encoded
|
||||
if isinstance(content, str):
|
||||
content = content.encode()
|
||||
return base64.b64encode(content)
|
||||
except Exception as e:
|
||||
msg = f"Error reading resource {uri}: {e!s}"
|
||||
logger.exception(msg)
|
||||
raise
|
||||
|
||||
|
||||
async def handle_call_tool(
|
||||
name: str, arguments: dict, server, project_id=None, *, is_action=False
|
||||
) -> list[types.TextContent]:
|
||||
"""Handle tool execution requests.
|
||||
|
||||
Args:
|
||||
name: Tool name
|
||||
arguments: Tool arguments
|
||||
server: MCP server instance
|
||||
project_id: Optional project ID to filter flows by project
|
||||
is_action: Whether to use action name for flow lookup
|
||||
"""
|
||||
mcp_config = get_mcp_config()
|
||||
if mcp_config.enable_progress_notifications is None:
|
||||
settings_service = get_settings_service()
|
||||
mcp_config.enable_progress_notifications = settings_service.settings.mcp_server_enable_progress_notifications
|
||||
|
||||
current_user = current_user_ctx.get()
|
||||
|
||||
async def execute_tool(session):
|
||||
# Get flow id from name
|
||||
flow = await get_flow_snake_case(name, current_user.id, session, is_action=is_action)
|
||||
if not flow:
|
||||
msg = f"Flow with name '{name}' not found"
|
||||
raise ValueError(msg)
|
||||
|
||||
# If project_id is provided, verify the flow belongs to the project
|
||||
if project_id and flow.folder_id != project_id:
|
||||
msg = f"Flow '{name}' not found in project {project_id}"
|
||||
raise ValueError(msg)
|
||||
|
||||
# Process inputs
|
||||
processed_inputs = dict(arguments)
|
||||
|
||||
# Initial progress notification
|
||||
if mcp_config.enable_progress_notifications and (progress_token := server.request_context.meta.progressToken):
|
||||
await server.request_context.session.send_progress_notification(
|
||||
progress_token=progress_token, progress=0.0, total=1.0
|
||||
)
|
||||
|
||||
conversation_id = str(uuid4())
|
||||
input_request = SimplifiedAPIRequest(
|
||||
input_value=processed_inputs.get("input_value", ""), session_id=conversation_id
|
||||
)
|
||||
|
||||
async def send_progress_updates(progress_token):
|
||||
try:
|
||||
progress = 0.0
|
||||
while True:
|
||||
await server.request_context.session.send_progress_notification(
|
||||
progress_token=progress_token, progress=min(0.9, progress), total=1.0
|
||||
)
|
||||
progress += 0.1
|
||||
await asyncio.sleep(1.0)
|
||||
except asyncio.CancelledError:
|
||||
if mcp_config.enable_progress_notifications:
|
||||
await server.request_context.session.send_progress_notification(
|
||||
progress_token=progress_token, progress=1.0, total=1.0
|
||||
)
|
||||
raise
|
||||
|
||||
collected_results = []
|
||||
try:
|
||||
progress_task = None
|
||||
if mcp_config.enable_progress_notifications and server.request_context.meta.progressToken:
|
||||
progress_task = asyncio.create_task(send_progress_updates(server.request_context.meta.progressToken))
|
||||
|
||||
try:
|
||||
try:
|
||||
result = await simple_run_flow(
|
||||
flow=flow,
|
||||
input_request=input_request,
|
||||
stream=False,
|
||||
api_key_user=current_user,
|
||||
)
|
||||
# Process all outputs and messages, ensuring no duplicates
|
||||
processed_texts = set()
|
||||
|
||||
def add_result(text: str):
|
||||
if text not in processed_texts:
|
||||
processed_texts.add(text)
|
||||
collected_results.append(types.TextContent(type="text", text=text))
|
||||
|
||||
for run_output in result.outputs:
|
||||
for component_output in run_output.outputs:
|
||||
# Handle messages
|
||||
for msg in component_output.messages or []:
|
||||
add_result(msg.message)
|
||||
# Handle results
|
||||
for value in (component_output.results or {}).values():
|
||||
if isinstance(value, Message):
|
||||
add_result(value.get_text())
|
||||
else:
|
||||
add_result(str(value))
|
||||
except Exception as e: # noqa: BLE001
|
||||
error_msg = f"Error Executing the {flow.name} tool. Error: {e!s}"
|
||||
collected_results.append(types.TextContent(type="text", text=error_msg))
|
||||
|
||||
return collected_results
|
||||
finally:
|
||||
if progress_task:
|
||||
progress_task.cancel()
|
||||
await asyncio.gather(progress_task, return_exceptions=True)
|
||||
|
||||
except Exception:
|
||||
if mcp_config.enable_progress_notifications and (
|
||||
progress_token := server.request_context.meta.progressToken
|
||||
):
|
||||
await server.request_context.session.send_progress_notification(
|
||||
progress_token=progress_token, progress=1.0, total=1.0
|
||||
)
|
||||
raise
|
||||
|
||||
try:
|
||||
return await with_db_session(execute_tool)
|
||||
except Exception as e:
|
||||
msg = f"Error executing tool {name}: {e!s}"
|
||||
logger.exception(msg)
|
||||
raise
|
||||
|
||||
|
||||
async def handle_list_tools(project_id=None, *, mcp_enabled_only=False):
|
||||
"""Handle listing tools for MCP.
|
||||
|
||||
Args:
|
||||
project_id: Optional project ID to filter tools by project
|
||||
mcp_enabled_only: Whether to filter for MCP-enabled flows only
|
||||
"""
|
||||
tools = []
|
||||
try:
|
||||
async with session_scope() as session:
|
||||
# Build query based on parameters
|
||||
if project_id:
|
||||
# Filter flows by project and optionally by MCP enabled status
|
||||
flows_query = select(Flow).where(Flow.folder_id == project_id, Flow.is_component == False) # noqa: E712
|
||||
if mcp_enabled_only:
|
||||
flows_query = flows_query.where(Flow.mcp_enabled == True) # noqa: E712
|
||||
else:
|
||||
# Get all flows
|
||||
flows_query = select(Flow)
|
||||
|
||||
flows = (await session.exec(flows_query)).all()
|
||||
|
||||
existing_names = set()
|
||||
for flow in flows:
|
||||
if flow.user_id is None:
|
||||
continue
|
||||
|
||||
# For project-specific tools, use action names if available
|
||||
if project_id:
|
||||
base_name = (
|
||||
sanitize_mcp_name(flow.action_name) if flow.action_name else sanitize_mcp_name(flow.name)
|
||||
)
|
||||
name = get_unique_name(base_name, MAX_MCP_TOOL_NAME_LENGTH, existing_names)
|
||||
description = flow.action_description or (
|
||||
flow.description if flow.description else f"Tool generated from flow: {name}"
|
||||
)
|
||||
else:
|
||||
# For global tools, use simple sanitized names
|
||||
base_name = sanitize_mcp_name(flow.name)
|
||||
name = base_name[:MAX_MCP_TOOL_NAME_LENGTH]
|
||||
if name in existing_names:
|
||||
i = 1
|
||||
while True:
|
||||
suffix = f"_{i}"
|
||||
truncated_base = base_name[: MAX_MCP_TOOL_NAME_LENGTH - len(suffix)]
|
||||
candidate = f"{truncated_base}{suffix}"
|
||||
if candidate not in existing_names:
|
||||
name = candidate
|
||||
break
|
||||
i += 1
|
||||
description = (
|
||||
f"{flow.id}: {flow.description}" if flow.description else f"Tool generated from flow: {name}"
|
||||
)
|
||||
|
||||
try:
|
||||
tool = types.Tool(
|
||||
name=name,
|
||||
description=description,
|
||||
inputSchema=json_schema_from_flow(flow),
|
||||
)
|
||||
tools.append(tool)
|
||||
existing_names.add(name)
|
||||
except Exception as e: # noqa: BLE001
|
||||
msg = f"Error in listing tools: {e!s} from flow: {base_name}"
|
||||
logger.warning(msg)
|
||||
continue
|
||||
except Exception as e:
|
||||
msg = f"Error in listing tools: {e!s}"
|
||||
logger.exception(msg)
|
||||
raise
|
||||
return tools
|
||||
Loading…
Add table
Add a link
Reference in a new issue