diff --git a/src/backend/base/langflow/api/v1/mcp.py b/src/backend/base/langflow/api/v1/mcp.py index 2ad85b8c8..64d7135a4 100644 --- a/src/backend/base/langflow/api/v1/mcp.py +++ b/src/backend/base/langflow/api/v1/mcp.py @@ -1,6 +1,5 @@ import asyncio import base64 -import json from collections.abc import Awaitable, Callable from contextvars import ContextVar from functools import wraps @@ -17,12 +16,12 @@ from mcp import types from mcp.server import NotificationOptions, Server from mcp.server.sse import SseServerTransport from sqlmodel import select -from starlette.background import BackgroundTasks -from langflow.api.v1.chat import build_flow_and_stream -from langflow.api.v1.schemas import InputValueRequest +from langflow.api.v1.endpoints import simple_run_flow +from langflow.api.v1.schemas import SimplifiedAPIRequest from langflow.base.mcp.util import get_flow_snake_case from langflow.helpers.flow import json_schema_from_flow +from langflow.schema.message import Message from langflow.services.auth.utils import get_current_active_user from langflow.services.database.models import Flow, User from langflow.services.deps import ( @@ -214,7 +213,6 @@ async def handle_call_tool(name: str, arguments: dict) -> list[types.TextContent settings_service = get_settings_service() mcp_config.enable_progress_notifications = settings_service.settings.mcp_server_enable_progress_notifications - background_tasks = BackgroundTasks() current_user = current_user_ctx.get() async def execute_tool(session): @@ -223,7 +221,6 @@ async def handle_call_tool(name: str, arguments: dict) -> list[types.TextContent if not flow: msg = f"Flow with name '{name}' not found" raise ValueError(msg) - flow_id = flow.id # Process inputs processed_inputs = dict(arguments) @@ -235,8 +232,8 @@ async def handle_call_tool(name: str, arguments: dict) -> list[types.TextContent ) conversation_id = str(uuid4()) - input_request = InputValueRequest( - input_value=processed_inputs.get("input_value", ""), components=[], type="chat", session=conversation_id + input_request = SimplifiedAPIRequest( + input_value=processed_inputs.get("input_value", ""), session_id=conversation_id ) async def send_progress_updates(): @@ -260,48 +257,43 @@ async def handle_call_tool(name: str, arguments: dict) -> list[types.TextContent collected_results = [] try: - progress_task = asyncio.create_task(send_progress_updates()) + progress_task = None + if mcp_config.enable_progress_notifications and server.request_context.meta.progressToken: + progress_task = asyncio.create_task(send_progress_updates()) try: - response = await build_flow_and_stream( - flow_id=flow_id, - inputs=input_request, - background_tasks=background_tasks, - current_user=current_user, - ) - - async for line in response.body_iterator: - if not line: - continue - try: - event_data = json.loads(line) - if event_data.get("event") == "end_vertex": - message = ( - event_data.get("data", {}) - .get("build_data", {}) - .get("data", {}) - .get("results", {}) - .get("message", {}) - .get("text", "") - ) - if message: - collected_results.append(types.TextContent(type="text", text=str(message))) - if event_data.get("event") == "error": - content_blocks = event_data.get("data", {}).get("content_blocks", []) - text = event_data.get("data", {}).get("text", "") - error_msg = f"Error Executing the {flow.name} tool. Error: {text} Details: {content_blocks}" - collected_results.append(types.TextContent(type="text", text=error_msg)) - except json.JSONDecodeError: - msg = f"Failed to parse event data: {line}" - logger.warning(msg) - continue + try: + result = await simple_run_flow( + flow=flow, + input_request=input_request, + stream=False, + api_key_user=current_user, + ) + # Process all outputs and messages + for run_output in result.outputs: + for component_output in run_output.outputs: + # Handle messages + for msg in component_output.messages or []: + text_content = types.TextContent(type="text", text=msg.message) + collected_results.append(text_content) + # Handle results + for value in (component_output.results or {}).values(): + if isinstance(value, Message): + text_content = types.TextContent(type="text", text=value.get_text()) + collected_results.append(text_content) + else: + collected_results.append(types.TextContent(type="text", text=str(value))) + except Exception as e: # noqa: BLE001 + error_msg = f"Error Executing the {flow.name} tool. Error: {e!s}" + collected_results.append(types.TextContent(type="text", text=error_msg)) return collected_results finally: - progress_task.cancel() - await asyncio.wait([progress_task]) - if not progress_task.cancelled() and (exc := progress_task.exception()) is not None: - raise exc + if progress_task: + progress_task.cancel() + await asyncio.wait([progress_task]) + if not progress_task.cancelled() and (exc := progress_task.exception()) is not None: + raise exc except Exception: if mcp_config.enable_progress_notifications and ( diff --git a/src/backend/base/langflow/api/v1/mcp_projects.py b/src/backend/base/langflow/api/v1/mcp_projects.py index 67f037f04..de1ba33a0 100644 --- a/src/backend/base/langflow/api/v1/mcp_projects.py +++ b/src/backend/base/langflow/api/v1/mcp_projects.py @@ -14,7 +14,7 @@ from urllib.parse import quote, unquote, urlparse from uuid import UUID, uuid4 from anyio import BrokenResourceError -from fastapi import APIRouter, BackgroundTasks, Depends, HTTPException, Request, Response +from fastapi import APIRouter, Depends, HTTPException, Request, Response from fastapi.responses import HTMLResponse from mcp import types from mcp.server import NotificationOptions, Server @@ -22,16 +22,17 @@ from mcp.server.sse import SseServerTransport from sqlalchemy.orm import selectinload from sqlmodel import select -from langflow.api.v1.chat import build_flow_and_stream +from langflow.api.v1.endpoints import simple_run_flow from langflow.api.v1.mcp import ( current_user_ctx, get_mcp_config, handle_mcp_errors, with_db_session, ) -from langflow.api.v1.schemas import InputValueRequest, MCPInstallRequest, MCPSettings +from langflow.api.v1.schemas import MCPInstallRequest, MCPSettings, SimplifiedAPIRequest from langflow.base.mcp.util import get_flow_snake_case from langflow.helpers.flow import json_schema_from_flow +from langflow.schema.message import Message from langflow.services.auth.utils import get_current_active_user from langflow.services.database.models import Flow, Folder, User from langflow.services.deps import get_settings_service, get_storage_service, session_scope @@ -632,7 +633,6 @@ class ProjectMCPServer: settings_service.settings.mcp_server_enable_progress_notifications ) - background_tasks = BackgroundTasks() current_user = current_user_ctx.get() async def execute_tool(session): @@ -641,7 +641,6 @@ class ProjectMCPServer: if not flow: msg = f"Flow with name '{name}' not found" raise ValueError(msg) - flow_id = flow.id # Process inputs processed_inputs = dict(arguments) @@ -655,19 +654,11 @@ class ProjectMCPServer: ) conversation_id = str(uuid4()) - input_request = InputValueRequest( - input_value=processed_inputs.get("input_value", ""), - components=[], - type="chat", - session=conversation_id, + input_request = SimplifiedAPIRequest( + input_value=processed_inputs.get("input_value", ""), session_id=conversation_id ) - async def send_progress_updates(): - if not ( - mcp_config.enable_progress_notifications and self.server.request_context.meta.progressToken - ): - return - + async def send_progress_updates(progress_token): try: progress = 0.0 while True: @@ -685,50 +676,43 @@ class ProjectMCPServer: collected_results = [] try: - progress_task = asyncio.create_task(send_progress_updates()) - - try: - response = await build_flow_and_stream( - flow_id=flow_id, - inputs=input_request, - background_tasks=background_tasks, - current_user=current_user, + progress_task = None + if mcp_config.enable_progress_notifications and self.server.request_context.meta.progressToken: + progress_task = asyncio.create_task( + send_progress_updates(self.server.request_context.meta.progressToken) ) - async for line in response.body_iterator: - if not line: - continue - try: - event_data = json.loads(line) - if event_data.get("event") == "end_vertex": - message = ( - event_data.get("data", {}) - .get("build_data", {}) - .get("data", {}) - .get("results", {}) - .get("message", {}) - .get("text", "") - ) - if message: - collected_results.append(types.TextContent(type="text", text=str(message))) - if event_data.get("event") == "error": - content_blocks = event_data.get("data", {}).get("content_blocks", []) - text = event_data.get("data", {}).get("text", "") - error_msg = ( - f"Error Executing the {flow.name} tool. Error: {text} Details: {content_blocks}" - ) - collected_results.append(types.TextContent(type="text", text=error_msg)) - except json.JSONDecodeError: - msg = f"Failed to parse event data: {line}" - logger.warning(msg) - continue + try: + try: + result = await simple_run_flow( + flow=flow, + input_request=input_request, + stream=False, + api_key_user=current_user, + ) + # Process all outputs and messages + for run_output in result.outputs: + for component_output in run_output.outputs: + # Handle messages + for msg in component_output.messages or []: + text_content = types.TextContent(type="text", text=msg.message) + collected_results.append(text_content) + # Handle results + for value in (component_output.results or {}).values(): + if isinstance(value, Message): + text_content = types.TextContent(type="text", text=value.get_text()) + collected_results.append(text_content) + else: + collected_results.append(types.TextContent(type="text", text=str(value))) + except Exception as e: # noqa: BLE001 + error_msg = f"Error Executing the {flow.name} tool. Error: {e!s}" + collected_results.append(types.TextContent(type="text", text=error_msg)) return collected_results finally: - progress_task.cancel() - await asyncio.wait([progress_task]) - if not progress_task.cancelled() and (exc := progress_task.exception()) is not None: - raise exc + if progress_task: + progress_task.cancel() + await asyncio.gather(progress_task, return_exceptions=True) except Exception: if mcp_config.enable_progress_notifications and (