From e20ae03c380776e950881df425db7441078de98f Mon Sep 17 00:00:00 2001 From: Edwin Jose Date: Thu, 17 Jul 2025 14:21:04 -0500 Subject: [PATCH] ref: refactor MCP-related functionality centralizing common utilities (#9059) * reactor to have common mcp codes in mcp_support * [autofix.ci] apply automated fixes * Refactor MCP API argument passing and function signatures Updated function calls in mcp_projects.py to use explicit keyword arguments for clarity. Refactored mcp_support.py to use more concise query assignment and added keyword-only arguments to handle_call_tool and handle_list_tools for improved code readability and maintainability. * Rename mcp_support.py to mcp_utils.py and update imports Renamed mcp_support.py to mcp_utils.py for clarity and updated all relevant import statements in mcp.py and mcp_projects.py to reflect the new module name. --------- Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com> --- src/backend/base/langflow/api/v1/mcp.py | 304 +-------------- .../base/langflow/api/v1/mcp_projects.py | 261 ++----------- src/backend/base/langflow/api/v1/mcp_utils.py | 348 ++++++++++++++++++ 3 files changed, 395 insertions(+), 518 deletions(-) create mode 100644 src/backend/base/langflow/api/v1/mcp_utils.py diff --git a/src/backend/base/langflow/api/v1/mcp.py b/src/backend/base/langflow/api/v1/mcp.py index 099312ab3..742eccc64 100644 --- a/src/backend/base/langflow/api/v1/mcp.py +++ b/src/backend/base/langflow/api/v1/mcp.py @@ -1,11 +1,4 @@ import asyncio -import base64 -from collections.abc import Awaitable, Callable -from contextvars import ContextVar -from functools import wraps -from typing import Any, ParamSpec, TypeVar -from urllib.parse import quote, unquote, urlparse -from uuid import uuid4 import pydantic from anyio import BrokenResourceError @@ -15,70 +8,22 @@ from loguru import logger from mcp import types from mcp.server import NotificationOptions, Server from mcp.server.sse import SseServerTransport -from sqlmodel import select from langflow.api.utils import CurrentActiveMCPUser -from langflow.api.v1.endpoints import simple_run_flow -from langflow.api.v1.schemas import SimplifiedAPIRequest -from langflow.base.mcp.constants import MAX_MCP_TOOL_NAME_LENGTH -from langflow.base.mcp.util import get_flow_snake_case, sanitize_mcp_name -from langflow.helpers.flow import json_schema_from_flow -from langflow.schema.message import Message -from langflow.services.database.models.flow.model import Flow -from langflow.services.database.models.user.model import User -from langflow.services.deps import ( - get_db_service, - get_settings_service, - get_storage_service, - session_scope, +from langflow.api.v1.mcp_utils import ( + current_user_ctx, + handle_call_tool, + handle_list_resources, + handle_list_tools, + handle_mcp_errors, + handle_read_resource, ) -from langflow.services.storage.utils import build_content_type_from_extension - -T = TypeVar("T") -P = ParamSpec("P") - - -def handle_mcp_errors(func: Callable[P, Awaitable[T]]) -> Callable[P, Awaitable[T]]: - """Decorator to handle MCP endpoint errors consistently.""" - - @wraps(func) - async def wrapper(*args: P.args, **kwargs: P.kwargs) -> T: - try: - return await func(*args, **kwargs) - except Exception as e: - msg = f"Error in {func.__name__}: {e!s}" - logger.exception(msg) - raise - - return wrapper - - -async def with_db_session(operation: Callable[[Any], Awaitable[T]]) -> T: - """Execute an operation within a database session context.""" - async with session_scope() as session: - return await operation(session) - - -class MCPConfig: - _instance = None - - def __new__(cls): - if cls._instance is None: - cls._instance = super().__new__(cls) - cls._instance.enable_progress_notifications = None - return cls._instance - - -def get_mcp_config(): - return MCPConfig() - +from langflow.services.deps import get_settings_service router = APIRouter(prefix="/mcp", tags=["mcp"]) server = Server("langflow-mcp-server") -# Create a context variable to store the current user -current_user_ctx: ContextVar[User] = ContextVar("current_user_ctx") # Define constants MAX_RETRIES = 2 @@ -94,237 +39,28 @@ async def handle_list_prompts(): @server.list_resources() -async def handle_list_resources(): - resources = [] - try: - db_service = get_db_service() - storage_service = get_storage_service() - settings_service = get_settings_service() - - # Build full URL from settings - host = getattr(settings_service.settings, "host", "localhost") - port = getattr(settings_service.settings, "port", 3000) - - base_url = f"http://{host}:{port}".rstrip("/") - - async with db_service.with_session() as session: - flows = (await session.exec(select(Flow))).all() - - for flow in flows: - if flow.id: - try: - files = await storage_service.list_files(flow_id=str(flow.id)) - for file_name in files: - # URL encode the filename - safe_filename = quote(file_name) - resource = types.Resource( - uri=f"{base_url}/api/v1/files/{flow.id}/{safe_filename}", - name=file_name, - description=f"File in flow: {flow.name}", - mimeType=build_content_type_from_extension(file_name), - ) - resources.append(resource) - except FileNotFoundError as e: - msg = f"Error listing files for flow {flow.id}: {e}" - logger.debug(msg) - continue - except Exception as e: - msg = f"Error in listing resources: {e!s}" - logger.exception(msg) - raise - return resources +async def handle_global_resources(): + """Handle listing resources for global MCP server.""" + return await handle_list_resources() @server.read_resource() -async def handle_read_resource(uri: str) -> bytes: - """Handle resource read requests.""" - try: - # Parse the URI properly - parsed_uri = urlparse(str(uri)) - # Path will be like /api/v1/files/{flow_id}/{filename} - path_parts = parsed_uri.path.split("/") - # Remove empty strings from split - path_parts = [p for p in path_parts if p] - - # The flow_id and filename should be the last two parts - two = 2 - if len(path_parts) < two: - msg = f"Invalid URI format: {uri}" - raise ValueError(msg) - - flow_id = path_parts[-2] - filename = unquote(path_parts[-1]) # URL decode the filename - - storage_service = get_storage_service() - - # Read the file content - content = await storage_service.get_file(flow_id=flow_id, file_name=filename) - if not content: - msg = f"File {filename} not found in flow {flow_id}" - raise ValueError(msg) - - # Ensure content is base64 encoded - if isinstance(content, str): - content = content.encode() - return base64.b64encode(content) - except Exception as e: - msg = f"Error reading resource {uri}: {e!s}" - logger.exception(msg) - raise +async def handle_global_read_resource(uri: str) -> bytes: + """Handle resource read requests for global MCP server.""" + return await handle_read_resource(uri) @server.list_tools() -async def handle_list_tools(): - tools = [] - try: - db_service = get_db_service() - async with db_service.with_session() as session: - flows = (await session.exec(select(Flow))).all() - - existing_names = set() - for flow in flows: - if flow.user_id is None: - continue - - base_name = sanitize_mcp_name(flow.name) - name = base_name[:MAX_MCP_TOOL_NAME_LENGTH] - if name in existing_names: - i = 1 - while True: - suffix = f"_{i}" - truncated_base = base_name[: MAX_MCP_TOOL_NAME_LENGTH - len(suffix)] - candidate = f"{truncated_base}{suffix}" - if candidate not in existing_names: - name = candidate - break - i += 1 - try: - tool = types.Tool( - name=name, - description=f"{flow.id}: {flow.description}" - if flow.description - else f"Tool generated from flow: {name}", - inputSchema=json_schema_from_flow(flow), - ) - tools.append(tool) - existing_names.add(name) - except Exception as e: # noqa: BLE001 - msg = f"Error in listing tools: {e!s} from flow: {base_name}" - logger.warning(msg) - continue - except Exception as e: - msg = f"Error in listing tools: {e!s}" - logger.exception(msg) - raise - return tools +async def handle_global_tools(): + """Handle listing tools for global MCP server.""" + return await handle_list_tools() @server.call_tool() @handle_mcp_errors -async def handle_call_tool(name: str, arguments: dict) -> list[types.TextContent]: - """Handle tool execution requests.""" - mcp_config = get_mcp_config() - if mcp_config.enable_progress_notifications is None: - settings_service = get_settings_service() - mcp_config.enable_progress_notifications = settings_service.settings.mcp_server_enable_progress_notifications - - current_user = current_user_ctx.get() - - async def execute_tool(session): - # get flow id from name - flow = await get_flow_snake_case(name, current_user.id, session) - if not flow: - msg = f"Flow with name '{name}' not found" - raise ValueError(msg) - - # Process inputs - processed_inputs = dict(arguments) - - # Initial progress notification - if mcp_config.enable_progress_notifications and (progress_token := server.request_context.meta.progressToken): - await server.request_context.session.send_progress_notification( - progress_token=progress_token, progress=0.0, total=1.0 - ) - - conversation_id = str(uuid4()) - input_request = SimplifiedAPIRequest( - input_value=processed_inputs.get("input_value", ""), session_id=conversation_id - ) - - async def send_progress_updates(progress_token): - try: - progress = 0.0 - while True: - await server.request_context.session.send_progress_notification( - progress_token=progress_token, progress=min(0.9, progress), total=1.0 - ) - progress += 0.1 - await asyncio.sleep(1.0) - except asyncio.CancelledError: - if mcp_config.enable_progress_notifications: - await server.request_context.session.send_progress_notification( - progress_token=progress_token, progress=1.0, total=1.0 - ) - raise - - collected_results = [] - try: - progress_task = None - if mcp_config.enable_progress_notifications and server.request_context.meta.progressToken: - progress_task = asyncio.create_task(send_progress_updates(server.request_context.meta.progressToken)) - - try: - try: - result = await simple_run_flow( - flow=flow, - input_request=input_request, - stream=False, - api_key_user=current_user, - ) - # Process all outputs and messages, ensuring no duplicates - processed_texts = set() - - def add_result(text: str): - if text not in processed_texts: - processed_texts.add(text) - collected_results.append(types.TextContent(type="text", text=text)) - - for run_output in result.outputs: - for component_output in run_output.outputs: - # Handle messages - for msg in component_output.messages or []: - add_result(msg.message) - # Handle results - for value in (component_output.results or {}).values(): - if isinstance(value, Message): - add_result(value.get_text()) - else: - add_result(str(value)) - except Exception as e: # noqa: BLE001 - error_msg = f"Error Executing the {flow.name} tool. Error: {e!s}" - collected_results.append(types.TextContent(type="text", text=error_msg)) - - return collected_results - finally: - if progress_task: - progress_task.cancel() - await asyncio.gather(progress_task, return_exceptions=True) - - except Exception: - if mcp_config.enable_progress_notifications and ( - progress_token := server.request_context.meta.progressToken - ): - await server.request_context.session.send_progress_notification( - progress_token=progress_token, progress=1.0, total=1.0 - ) - raise - - try: - return await with_db_session(execute_tool) - except Exception as e: - msg = f"Error executing tool {name}: {e!s}" - logger.exception(msg) - raise +async def handle_global_call_tool(name: str, arguments: dict) -> list[types.TextContent]: + """Handle tool execution requests for global MCP server.""" + return await handle_call_tool(name, arguments, server) sse = SseServerTransport("/api/v1/mcp/") diff --git a/src/backend/base/langflow/api/v1/mcp_projects.py b/src/backend/base/langflow/api/v1/mcp_projects.py index 39648dd73..0c313f38f 100644 --- a/src/backend/base/langflow/api/v1/mcp_projects.py +++ b/src/backend/base/langflow/api/v1/mcp_projects.py @@ -1,5 +1,4 @@ import asyncio -import base64 import json import logging import os @@ -10,8 +9,7 @@ from datetime import datetime, timezone from ipaddress import ip_address from pathlib import Path from subprocess import CalledProcessError -from urllib.parse import quote, unquote, urlparse -from uuid import UUID, uuid4 +from uuid import UUID from anyio import BrokenResourceError from fastapi import APIRouter, HTTPException, Request, Response @@ -23,27 +21,25 @@ from sqlalchemy.orm import selectinload from sqlmodel import select from langflow.api.utils import CurrentActiveMCPUser -from langflow.api.v1.endpoints import simple_run_flow -from langflow.api.v1.mcp import ( +from langflow.api.v1.mcp_utils import ( current_user_ctx, - get_mcp_config, + handle_call_tool, + handle_list_resources, + handle_list_tools, handle_mcp_errors, - with_db_session, + handle_read_resource, ) -from langflow.api.v1.schemas import MCPInstallRequest, MCPSettings, SimplifiedAPIRequest -from langflow.base.mcp.constants import MAX_MCP_SERVER_NAME_LENGTH, MAX_MCP_TOOL_NAME_LENGTH -from langflow.base.mcp.util import get_flow_snake_case, get_unique_name, sanitize_mcp_name -from langflow.helpers.flow import json_schema_from_flow -from langflow.schema.message import Message +from langflow.api.v1.schemas import MCPInstallRequest, MCPSettings +from langflow.base.mcp.constants import MAX_MCP_SERVER_NAME_LENGTH +from langflow.base.mcp.util import sanitize_mcp_name from langflow.services.database.models import Flow, Folder -from langflow.services.deps import get_settings_service, get_storage_service, session_scope -from langflow.services.storage.utils import build_content_type_from_extension +from langflow.services.deps import get_settings_service, session_scope logger = logging.getLogger(__name__) router = APIRouter(prefix="/mcp/project", tags=["mcp_projects"]) -# Create a context variable to store the current project +# Create project-specific context variable current_project_ctx: ContextVar[UUID | None] = ContextVar("current_project_ctx", default=None) # Create a mapping of project-specific SSE transports @@ -652,236 +648,33 @@ class ProjectMCPServer: @handle_mcp_errors async def handle_list_project_tools(): """Handle listing tools for this specific project.""" - tools = [] - try: - async with session_scope() as session: - # Get flows with mcp_enabled flag set to True and in this project - flows = ( - await session.exec( - select(Flow).where(Flow.mcp_enabled == True, Flow.folder_id == self.project_id) # noqa: E712 - ) - ).all() - existing_names = set() - for flow in flows: - if flow.user_id is None: - continue - - # Use action_name if available, otherwise construct from flow name - base_name = ( - sanitize_mcp_name(flow.action_name) if flow.action_name else sanitize_mcp_name(flow.name) - ) - name = get_unique_name(base_name, MAX_MCP_TOOL_NAME_LENGTH, existing_names) - - # Use action_description if available, otherwise use defaults - description = flow.action_description or ( - flow.description if flow.description else f"Tool generated from flow: {name}" - ) - - tool = types.Tool( - name=name, - description=description, - inputSchema=json_schema_from_flow(flow), - ) - tools.append(tool) - existing_names.add(name) - except Exception as e: # noqa: BLE001 - msg = f"Error in listing project tools: {e!s} from flow: {name}" - logger.warning(msg) - return tools + return await handle_list_tools(project_id=self.project_id, mcp_enabled_only=True) @self.server.list_prompts() async def handle_list_prompts(): return [] @self.server.list_resources() - async def handle_list_resources(): - resources = [] - try: - storage_service = get_storage_service() - settings_service = get_settings_service() - - # Build full URL from settings - host = getattr(settings_service.settings, "host", "localhost") - port = getattr(settings_service.settings, "port", 3000) - - base_url = f"http://{host}:{port}".rstrip("/") - - async with session_scope() as session: - flows = (await session.exec(select(Flow))).all() - - for flow in flows: - if flow.id: - try: - files = await storage_service.list_files(flow_id=str(flow.id)) - for file_name in files: - # URL encode the filename - safe_filename = quote(file_name) - resource = types.Resource( - uri=f"{base_url}/api/v1/files/{flow.id}/{safe_filename}", - name=file_name, - description=f"File in flow: {flow.name}", - mimeType=build_content_type_from_extension(file_name), - ) - resources.append(resource) - except FileNotFoundError as e: - msg = f"Error listing files for flow {flow.id}: {e}" - logger.debug(msg) - continue - except Exception as e: - msg = f"Error in listing resources: {e!s}" - logger.exception(msg) - raise - return resources + async def handle_list_project_resources(): + """Handle listing resources for this specific project.""" + return await handle_list_resources(project_id=self.project_id) @self.server.read_resource() - async def handle_read_resource(uri: str) -> bytes: - """Handle resource read requests.""" - try: - # Parse the URI properly - parsed_uri = urlparse(str(uri)) - # Path will be like /api/v1/files/{flow_id}/{filename} - path_parts = parsed_uri.path.split("/") - # Remove empty strings from split - path_parts = [p for p in path_parts if p] - - # The flow_id and filename should be the last two parts - two = 2 - if len(path_parts) < two: - msg = f"Invalid URI format: {uri}" - raise ValueError(msg) - - flow_id = path_parts[-2] - filename = unquote(path_parts[-1]) # URL decode the filename - - storage_service = get_storage_service() - - # Read the file content - content = await storage_service.get_file(flow_id=flow_id, file_name=filename) - if not content: - msg = f"File {filename} not found in flow {flow_id}" - raise ValueError(msg) - - # Ensure content is base64 encoded - if isinstance(content, str): - content = content.encode() - return base64.b64encode(content) - except Exception as e: - msg = f"Error reading resource {uri}: {e!s}" - logger.exception(msg) - raise + async def handle_read_project_resource(uri: str) -> bytes: + """Handle resource read requests for this specific project.""" + return await handle_read_resource(uri=uri) @self.server.call_tool() @handle_mcp_errors - async def handle_call_tool(name: str, arguments: dict) -> list[types.TextContent]: - """Handle tool execution requests.""" - mcp_config = get_mcp_config() - if mcp_config.enable_progress_notifications is None: - settings_service = get_settings_service() - mcp_config.enable_progress_notifications = ( - settings_service.settings.mcp_server_enable_progress_notifications - ) - - current_user = current_user_ctx.get() - - async def execute_tool(session): - # get flow id from name - flow = await get_flow_snake_case(name, current_user.id, session, is_action=True) - if not flow: - msg = f"Flow with name '{name}' not found" - raise ValueError(msg) - - # Process inputs - processed_inputs = dict(arguments) - - # Initial progress notification - if mcp_config.enable_progress_notifications and ( - progress_token := self.server.request_context.meta.progressToken - ): - await self.server.request_context.session.send_progress_notification( - progress_token=progress_token, progress=0.0, total=1.0 - ) - - conversation_id = str(uuid4()) - input_request = SimplifiedAPIRequest( - input_value=processed_inputs.get("input_value", ""), session_id=conversation_id - ) - - async def send_progress_updates(progress_token): - try: - progress = 0.0 - while True: - await self.server.request_context.session.send_progress_notification( - progress_token=progress_token, progress=min(0.9, progress), total=1.0 - ) - progress += 0.1 - await asyncio.sleep(1.0) - except asyncio.CancelledError: - if mcp_config.enable_progress_notifications: - await self.server.request_context.session.send_progress_notification( - progress_token=progress_token, progress=1.0, total=1.0 - ) - raise - - collected_results = [] - try: - progress_task = None - if mcp_config.enable_progress_notifications and self.server.request_context.meta.progressToken: - progress_task = asyncio.create_task( - send_progress_updates(self.server.request_context.meta.progressToken) - ) - - try: - try: - result = await simple_run_flow( - flow=flow, - input_request=input_request, - stream=False, - api_key_user=current_user, - ) - # Process all outputs and messages, ensuring no duplicates - processed_texts = set() - - def add_result(text: str): - if text not in processed_texts: - processed_texts.add(text) - collected_results.append(types.TextContent(type="text", text=text)) - - for run_output in result.outputs: - for component_output in run_output.outputs: - # Handle messages - for msg in component_output.messages or []: - add_result(msg.message) - # Handle results - for value in (component_output.results or {}).values(): - if isinstance(value, Message): - add_result(value.get_text()) - else: - add_result(str(value)) - except Exception as e: # noqa: BLE001 - error_msg = f"Error Executing the {flow.name} tool. Error: {e!s}" - collected_results.append(types.TextContent(type="text", text=error_msg)) - - return collected_results - finally: - if progress_task: - progress_task.cancel() - await asyncio.gather(progress_task, return_exceptions=True) - - except Exception: - if mcp_config.enable_progress_notifications and ( - progress_token := self.server.request_context.meta.progressToken - ): - await self.server.request_context.session.send_progress_notification( - progress_token=progress_token, progress=1.0, total=1.0 - ) - raise - - try: - return await with_db_session(execute_tool) - except Exception as e: - msg = f"Error executing tool {name}: {e!s}" - logger.exception(msg) - raise + async def handle_call_project_tool(name: str, arguments: dict) -> list[types.TextContent]: + """Handle tool execution requests for this specific project.""" + return await handle_call_tool( + name=name, + arguments=arguments, + server=self.server, + project_id=self.project_id, + is_action=True, + ) # Cache of project MCP servers diff --git a/src/backend/base/langflow/api/v1/mcp_utils.py b/src/backend/base/langflow/api/v1/mcp_utils.py new file mode 100644 index 000000000..6dc0ec110 --- /dev/null +++ b/src/backend/base/langflow/api/v1/mcp_utils.py @@ -0,0 +1,348 @@ +"""Common MCP handler functions shared between mcp.py and mcp_projects.py. + +This module serves as the single source of truth for MCP functionality. +""" + +import asyncio +import base64 +from collections.abc import Awaitable, Callable +from contextvars import ContextVar +from functools import wraps +from typing import Any, ParamSpec, TypeVar +from urllib.parse import quote, unquote, urlparse +from uuid import uuid4 + +from loguru import logger +from mcp import types +from sqlmodel import select + +from langflow.api.v1.endpoints import simple_run_flow +from langflow.api.v1.schemas import SimplifiedAPIRequest +from langflow.base.mcp.constants import MAX_MCP_TOOL_NAME_LENGTH +from langflow.base.mcp.util import get_flow_snake_case, get_unique_name, sanitize_mcp_name +from langflow.helpers.flow import json_schema_from_flow +from langflow.schema.message import Message +from langflow.services.database.models import Flow +from langflow.services.database.models.user.model import User +from langflow.services.deps import get_settings_service, get_storage_service, session_scope +from langflow.services.storage.utils import build_content_type_from_extension + +T = TypeVar("T") +P = ParamSpec("P") + +# Create context variables +current_user_ctx: ContextVar[User] = ContextVar("current_user_ctx") + + +def handle_mcp_errors(func: Callable[P, Awaitable[T]]) -> Callable[P, Awaitable[T]]: + """Decorator to handle MCP endpoint errors consistently.""" + + @wraps(func) + async def wrapper(*args: P.args, **kwargs: P.kwargs) -> T: + try: + return await func(*args, **kwargs) + except Exception as e: + msg = f"Error in {func.__name__}: {e!s}" + logger.exception(msg) + raise + + return wrapper + + +async def with_db_session(operation: Callable[[Any], Awaitable[T]]) -> T: + """Execute an operation within a database session context.""" + async with session_scope() as session: + return await operation(session) + + +class MCPConfig: + _instance = None + + def __new__(cls): + if cls._instance is None: + cls._instance = super().__new__(cls) + cls._instance.enable_progress_notifications = None + return cls._instance + + +def get_mcp_config(): + return MCPConfig() + + +async def handle_list_resources(project_id=None): + """Handle listing resources for MCP. + + Args: + project_id: Optional project ID to filter resources by project + """ + resources = [] + try: + storage_service = get_storage_service() + settings_service = get_settings_service() + + # Build full URL from settings + host = getattr(settings_service.settings, "host", "localhost") + port = getattr(settings_service.settings, "port", 3000) + + base_url = f"http://{host}:{port}".rstrip("/") + + async with session_scope() as session: + # Build query based on whether project_id is provided + flows_query = select(Flow).where(Flow.folder_id == project_id) if project_id else select(Flow) + + flows = (await session.exec(flows_query)).all() + + for flow in flows: + if flow.id: + try: + files = await storage_service.list_files(flow_id=str(flow.id)) + for file_name in files: + # URL encode the filename + safe_filename = quote(file_name) + resource = types.Resource( + uri=f"{base_url}/api/v1/files/{flow.id}/{safe_filename}", + name=file_name, + description=f"File in flow: {flow.name}", + mimeType=build_content_type_from_extension(file_name), + ) + resources.append(resource) + except FileNotFoundError as e: + msg = f"Error listing files for flow {flow.id}: {e}" + logger.debug(msg) + continue + except Exception as e: + msg = f"Error in listing resources: {e!s}" + logger.exception(msg) + raise + return resources + + +async def handle_read_resource(uri: str) -> bytes: + """Handle resource read requests.""" + try: + # Parse the URI properly + parsed_uri = urlparse(str(uri)) + # Path will be like /api/v1/files/{flow_id}/{filename} + path_parts = parsed_uri.path.split("/") + # Remove empty strings from split + path_parts = [p for p in path_parts if p] + + # The flow_id and filename should be the last two parts + two = 2 + if len(path_parts) < two: + msg = f"Invalid URI format: {uri}" + raise ValueError(msg) + + flow_id = path_parts[-2] + filename = unquote(path_parts[-1]) # URL decode the filename + + storage_service = get_storage_service() + + # Read the file content + content = await storage_service.get_file(flow_id=flow_id, file_name=filename) + if not content: + msg = f"File {filename} not found in flow {flow_id}" + raise ValueError(msg) + + # Ensure content is base64 encoded + if isinstance(content, str): + content = content.encode() + return base64.b64encode(content) + except Exception as e: + msg = f"Error reading resource {uri}: {e!s}" + logger.exception(msg) + raise + + +async def handle_call_tool( + name: str, arguments: dict, server, project_id=None, *, is_action=False +) -> list[types.TextContent]: + """Handle tool execution requests. + + Args: + name: Tool name + arguments: Tool arguments + server: MCP server instance + project_id: Optional project ID to filter flows by project + is_action: Whether to use action name for flow lookup + """ + mcp_config = get_mcp_config() + if mcp_config.enable_progress_notifications is None: + settings_service = get_settings_service() + mcp_config.enable_progress_notifications = settings_service.settings.mcp_server_enable_progress_notifications + + current_user = current_user_ctx.get() + + async def execute_tool(session): + # Get flow id from name + flow = await get_flow_snake_case(name, current_user.id, session, is_action=is_action) + if not flow: + msg = f"Flow with name '{name}' not found" + raise ValueError(msg) + + # If project_id is provided, verify the flow belongs to the project + if project_id and flow.folder_id != project_id: + msg = f"Flow '{name}' not found in project {project_id}" + raise ValueError(msg) + + # Process inputs + processed_inputs = dict(arguments) + + # Initial progress notification + if mcp_config.enable_progress_notifications and (progress_token := server.request_context.meta.progressToken): + await server.request_context.session.send_progress_notification( + progress_token=progress_token, progress=0.0, total=1.0 + ) + + conversation_id = str(uuid4()) + input_request = SimplifiedAPIRequest( + input_value=processed_inputs.get("input_value", ""), session_id=conversation_id + ) + + async def send_progress_updates(progress_token): + try: + progress = 0.0 + while True: + await server.request_context.session.send_progress_notification( + progress_token=progress_token, progress=min(0.9, progress), total=1.0 + ) + progress += 0.1 + await asyncio.sleep(1.0) + except asyncio.CancelledError: + if mcp_config.enable_progress_notifications: + await server.request_context.session.send_progress_notification( + progress_token=progress_token, progress=1.0, total=1.0 + ) + raise + + collected_results = [] + try: + progress_task = None + if mcp_config.enable_progress_notifications and server.request_context.meta.progressToken: + progress_task = asyncio.create_task(send_progress_updates(server.request_context.meta.progressToken)) + + try: + try: + result = await simple_run_flow( + flow=flow, + input_request=input_request, + stream=False, + api_key_user=current_user, + ) + # Process all outputs and messages, ensuring no duplicates + processed_texts = set() + + def add_result(text: str): + if text not in processed_texts: + processed_texts.add(text) + collected_results.append(types.TextContent(type="text", text=text)) + + for run_output in result.outputs: + for component_output in run_output.outputs: + # Handle messages + for msg in component_output.messages or []: + add_result(msg.message) + # Handle results + for value in (component_output.results or {}).values(): + if isinstance(value, Message): + add_result(value.get_text()) + else: + add_result(str(value)) + except Exception as e: # noqa: BLE001 + error_msg = f"Error Executing the {flow.name} tool. Error: {e!s}" + collected_results.append(types.TextContent(type="text", text=error_msg)) + + return collected_results + finally: + if progress_task: + progress_task.cancel() + await asyncio.gather(progress_task, return_exceptions=True) + + except Exception: + if mcp_config.enable_progress_notifications and ( + progress_token := server.request_context.meta.progressToken + ): + await server.request_context.session.send_progress_notification( + progress_token=progress_token, progress=1.0, total=1.0 + ) + raise + + try: + return await with_db_session(execute_tool) + except Exception as e: + msg = f"Error executing tool {name}: {e!s}" + logger.exception(msg) + raise + + +async def handle_list_tools(project_id=None, *, mcp_enabled_only=False): + """Handle listing tools for MCP. + + Args: + project_id: Optional project ID to filter tools by project + mcp_enabled_only: Whether to filter for MCP-enabled flows only + """ + tools = [] + try: + async with session_scope() as session: + # Build query based on parameters + if project_id: + # Filter flows by project and optionally by MCP enabled status + flows_query = select(Flow).where(Flow.folder_id == project_id, Flow.is_component == False) # noqa: E712 + if mcp_enabled_only: + flows_query = flows_query.where(Flow.mcp_enabled == True) # noqa: E712 + else: + # Get all flows + flows_query = select(Flow) + + flows = (await session.exec(flows_query)).all() + + existing_names = set() + for flow in flows: + if flow.user_id is None: + continue + + # For project-specific tools, use action names if available + if project_id: + base_name = ( + sanitize_mcp_name(flow.action_name) if flow.action_name else sanitize_mcp_name(flow.name) + ) + name = get_unique_name(base_name, MAX_MCP_TOOL_NAME_LENGTH, existing_names) + description = flow.action_description or ( + flow.description if flow.description else f"Tool generated from flow: {name}" + ) + else: + # For global tools, use simple sanitized names + base_name = sanitize_mcp_name(flow.name) + name = base_name[:MAX_MCP_TOOL_NAME_LENGTH] + if name in existing_names: + i = 1 + while True: + suffix = f"_{i}" + truncated_base = base_name[: MAX_MCP_TOOL_NAME_LENGTH - len(suffix)] + candidate = f"{truncated_base}{suffix}" + if candidate not in existing_names: + name = candidate + break + i += 1 + description = ( + f"{flow.id}: {flow.description}" if flow.description else f"Tool generated from flow: {name}" + ) + + try: + tool = types.Tool( + name=name, + description=description, + inputSchema=json_schema_from_flow(flow), + ) + tools.append(tool) + existing_names.add(name) + except Exception as e: # noqa: BLE001 + msg = f"Error in listing tools: {e!s} from flow: {base_name}" + logger.warning(msg) + continue + except Exception as e: + msg = f"Error in listing tools: {e!s}" + logger.exception(msg) + raise + return tools