fix: Refactor MCP API to fix value truncation (#8298)
* refactor: Simplify flow execution and update input handling in MCP API - Replaced `InputValueRequest` with `SimplifiedAPIRequest` for cleaner input management. - Updated flow execution logic to utilize `simple_run_flow`, enhancing clarity and performance. - Removed unnecessary background task handling and streamlined progress updates. - Improved message collection from flow outputs, ensuring robust handling of results. * fix: Add error handling for tool execution in MCP API - Implemented a try-except block around the flow execution to catch and handle exceptions gracefully. - Enhanced message collection logic to ensure that errors during tool execution are communicated back as text content. - Improved robustness of the `handle_call_tool` function by ensuring all potential errors are captured and reported. * fix: Improve error messaging in tool execution for MCP API - Updated error handling in the `handle_call_tool` function to provide more descriptive error messages. - Changed the error message format to include the flow name, enhancing clarity for debugging purposes. - Ensured that all exceptions during tool execution are captured and reported as text content. * refactor: Enhance message and result handling in handle_call_tool - Improved the logic for processing outputs in the `handle_call_tool` function to handle messages and results more comprehensively. - Streamlined the collection of text content from both messages and results, ensuring all relevant outputs are captured. - Enhanced robustness by ensuring that all outputs are processed, regardless of their structure. * refactor: Improve progress notification handling in handle_call_tool - Updated the logic for progress task creation in the `handle_call_tool` function to ensure it only initializes when progress notifications are enabled and a progress token is present. - Enhanced the cancellation and exception handling of the progress task to prevent potential errors when it is not created. - Improved overall robustness of the function by ensuring that progress updates are managed correctly based on the current context. * refactor: Streamline flow execution and message handling in ProjectMCPServer - Replaced `InputValueRequest` with `SimplifiedAPIRequest` for improved input management. - Updated flow execution to utilize `simple_run_flow`, enhancing clarity and performance. - Refined progress notification handling to ensure tasks are only created when necessary. - Improved message collection from flow outputs, ensuring robust handling of both messages and results. - Enhanced error handling during tool execution to provide clearer feedback on failures. * refactor: enhance progress update handling in ProjectMCPServer Updated the send_progress_updates function to accept a progress token as an argument, improving its flexibility. Adjusted the task cancellation logic to use asyncio.gather for better exception handling. This change aims to streamline progress notifications when enabled. * refactor: add group_outputs property to message configurations in starter projects
This commit is contained in:
parent
0896f51983
commit
a276c2de48
2 changed files with 76 additions and 100 deletions
|
|
@ -1,6 +1,5 @@
|
|||
import asyncio
|
||||
import base64
|
||||
import json
|
||||
from collections.abc import Awaitable, Callable
|
||||
from contextvars import ContextVar
|
||||
from functools import wraps
|
||||
|
|
@ -17,12 +16,12 @@ from mcp import types
|
|||
from mcp.server import NotificationOptions, Server
|
||||
from mcp.server.sse import SseServerTransport
|
||||
from sqlmodel import select
|
||||
from starlette.background import BackgroundTasks
|
||||
|
||||
from langflow.api.v1.chat import build_flow_and_stream
|
||||
from langflow.api.v1.schemas import InputValueRequest
|
||||
from langflow.api.v1.endpoints import simple_run_flow
|
||||
from langflow.api.v1.schemas import SimplifiedAPIRequest
|
||||
from langflow.base.mcp.util import get_flow_snake_case
|
||||
from langflow.helpers.flow import json_schema_from_flow
|
||||
from langflow.schema.message import Message
|
||||
from langflow.services.auth.utils import get_current_active_user
|
||||
from langflow.services.database.models import Flow, User
|
||||
from langflow.services.deps import (
|
||||
|
|
@ -214,7 +213,6 @@ async def handle_call_tool(name: str, arguments: dict) -> list[types.TextContent
|
|||
settings_service = get_settings_service()
|
||||
mcp_config.enable_progress_notifications = settings_service.settings.mcp_server_enable_progress_notifications
|
||||
|
||||
background_tasks = BackgroundTasks()
|
||||
current_user = current_user_ctx.get()
|
||||
|
||||
async def execute_tool(session):
|
||||
|
|
@ -223,7 +221,6 @@ async def handle_call_tool(name: str, arguments: dict) -> list[types.TextContent
|
|||
if not flow:
|
||||
msg = f"Flow with name '{name}' not found"
|
||||
raise ValueError(msg)
|
||||
flow_id = flow.id
|
||||
|
||||
# Process inputs
|
||||
processed_inputs = dict(arguments)
|
||||
|
|
@ -235,8 +232,8 @@ async def handle_call_tool(name: str, arguments: dict) -> list[types.TextContent
|
|||
)
|
||||
|
||||
conversation_id = str(uuid4())
|
||||
input_request = InputValueRequest(
|
||||
input_value=processed_inputs.get("input_value", ""), components=[], type="chat", session=conversation_id
|
||||
input_request = SimplifiedAPIRequest(
|
||||
input_value=processed_inputs.get("input_value", ""), session_id=conversation_id
|
||||
)
|
||||
|
||||
async def send_progress_updates():
|
||||
|
|
@ -260,48 +257,43 @@ async def handle_call_tool(name: str, arguments: dict) -> list[types.TextContent
|
|||
|
||||
collected_results = []
|
||||
try:
|
||||
progress_task = asyncio.create_task(send_progress_updates())
|
||||
progress_task = None
|
||||
if mcp_config.enable_progress_notifications and server.request_context.meta.progressToken:
|
||||
progress_task = asyncio.create_task(send_progress_updates())
|
||||
|
||||
try:
|
||||
response = await build_flow_and_stream(
|
||||
flow_id=flow_id,
|
||||
inputs=input_request,
|
||||
background_tasks=background_tasks,
|
||||
current_user=current_user,
|
||||
)
|
||||
|
||||
async for line in response.body_iterator:
|
||||
if not line:
|
||||
continue
|
||||
try:
|
||||
event_data = json.loads(line)
|
||||
if event_data.get("event") == "end_vertex":
|
||||
message = (
|
||||
event_data.get("data", {})
|
||||
.get("build_data", {})
|
||||
.get("data", {})
|
||||
.get("results", {})
|
||||
.get("message", {})
|
||||
.get("text", "")
|
||||
)
|
||||
if message:
|
||||
collected_results.append(types.TextContent(type="text", text=str(message)))
|
||||
if event_data.get("event") == "error":
|
||||
content_blocks = event_data.get("data", {}).get("content_blocks", [])
|
||||
text = event_data.get("data", {}).get("text", "")
|
||||
error_msg = f"Error Executing the {flow.name} tool. Error: {text} Details: {content_blocks}"
|
||||
collected_results.append(types.TextContent(type="text", text=error_msg))
|
||||
except json.JSONDecodeError:
|
||||
msg = f"Failed to parse event data: {line}"
|
||||
logger.warning(msg)
|
||||
continue
|
||||
try:
|
||||
result = await simple_run_flow(
|
||||
flow=flow,
|
||||
input_request=input_request,
|
||||
stream=False,
|
||||
api_key_user=current_user,
|
||||
)
|
||||
# Process all outputs and messages
|
||||
for run_output in result.outputs:
|
||||
for component_output in run_output.outputs:
|
||||
# Handle messages
|
||||
for msg in component_output.messages or []:
|
||||
text_content = types.TextContent(type="text", text=msg.message)
|
||||
collected_results.append(text_content)
|
||||
# Handle results
|
||||
for value in (component_output.results or {}).values():
|
||||
if isinstance(value, Message):
|
||||
text_content = types.TextContent(type="text", text=value.get_text())
|
||||
collected_results.append(text_content)
|
||||
else:
|
||||
collected_results.append(types.TextContent(type="text", text=str(value)))
|
||||
except Exception as e: # noqa: BLE001
|
||||
error_msg = f"Error Executing the {flow.name} tool. Error: {e!s}"
|
||||
collected_results.append(types.TextContent(type="text", text=error_msg))
|
||||
|
||||
return collected_results
|
||||
finally:
|
||||
progress_task.cancel()
|
||||
await asyncio.wait([progress_task])
|
||||
if not progress_task.cancelled() and (exc := progress_task.exception()) is not None:
|
||||
raise exc
|
||||
if progress_task:
|
||||
progress_task.cancel()
|
||||
await asyncio.wait([progress_task])
|
||||
if not progress_task.cancelled() and (exc := progress_task.exception()) is not None:
|
||||
raise exc
|
||||
|
||||
except Exception:
|
||||
if mcp_config.enable_progress_notifications and (
|
||||
|
|
|
|||
|
|
@ -14,7 +14,7 @@ from urllib.parse import quote, unquote, urlparse
|
|||
from uuid import UUID, uuid4
|
||||
|
||||
from anyio import BrokenResourceError
|
||||
from fastapi import APIRouter, BackgroundTasks, Depends, HTTPException, Request, Response
|
||||
from fastapi import APIRouter, Depends, HTTPException, Request, Response
|
||||
from fastapi.responses import HTMLResponse
|
||||
from mcp import types
|
||||
from mcp.server import NotificationOptions, Server
|
||||
|
|
@ -22,16 +22,17 @@ from mcp.server.sse import SseServerTransport
|
|||
from sqlalchemy.orm import selectinload
|
||||
from sqlmodel import select
|
||||
|
||||
from langflow.api.v1.chat import build_flow_and_stream
|
||||
from langflow.api.v1.endpoints import simple_run_flow
|
||||
from langflow.api.v1.mcp import (
|
||||
current_user_ctx,
|
||||
get_mcp_config,
|
||||
handle_mcp_errors,
|
||||
with_db_session,
|
||||
)
|
||||
from langflow.api.v1.schemas import InputValueRequest, MCPInstallRequest, MCPSettings
|
||||
from langflow.api.v1.schemas import MCPInstallRequest, MCPSettings, SimplifiedAPIRequest
|
||||
from langflow.base.mcp.util import get_flow_snake_case
|
||||
from langflow.helpers.flow import json_schema_from_flow
|
||||
from langflow.schema.message import Message
|
||||
from langflow.services.auth.utils import get_current_active_user
|
||||
from langflow.services.database.models import Flow, Folder, User
|
||||
from langflow.services.deps import get_settings_service, get_storage_service, session_scope
|
||||
|
|
@ -632,7 +633,6 @@ class ProjectMCPServer:
|
|||
settings_service.settings.mcp_server_enable_progress_notifications
|
||||
)
|
||||
|
||||
background_tasks = BackgroundTasks()
|
||||
current_user = current_user_ctx.get()
|
||||
|
||||
async def execute_tool(session):
|
||||
|
|
@ -641,7 +641,6 @@ class ProjectMCPServer:
|
|||
if not flow:
|
||||
msg = f"Flow with name '{name}' not found"
|
||||
raise ValueError(msg)
|
||||
flow_id = flow.id
|
||||
|
||||
# Process inputs
|
||||
processed_inputs = dict(arguments)
|
||||
|
|
@ -655,19 +654,11 @@ class ProjectMCPServer:
|
|||
)
|
||||
|
||||
conversation_id = str(uuid4())
|
||||
input_request = InputValueRequest(
|
||||
input_value=processed_inputs.get("input_value", ""),
|
||||
components=[],
|
||||
type="chat",
|
||||
session=conversation_id,
|
||||
input_request = SimplifiedAPIRequest(
|
||||
input_value=processed_inputs.get("input_value", ""), session_id=conversation_id
|
||||
)
|
||||
|
||||
async def send_progress_updates():
|
||||
if not (
|
||||
mcp_config.enable_progress_notifications and self.server.request_context.meta.progressToken
|
||||
):
|
||||
return
|
||||
|
||||
async def send_progress_updates(progress_token):
|
||||
try:
|
||||
progress = 0.0
|
||||
while True:
|
||||
|
|
@ -685,50 +676,43 @@ class ProjectMCPServer:
|
|||
|
||||
collected_results = []
|
||||
try:
|
||||
progress_task = asyncio.create_task(send_progress_updates())
|
||||
|
||||
try:
|
||||
response = await build_flow_and_stream(
|
||||
flow_id=flow_id,
|
||||
inputs=input_request,
|
||||
background_tasks=background_tasks,
|
||||
current_user=current_user,
|
||||
progress_task = None
|
||||
if mcp_config.enable_progress_notifications and self.server.request_context.meta.progressToken:
|
||||
progress_task = asyncio.create_task(
|
||||
send_progress_updates(self.server.request_context.meta.progressToken)
|
||||
)
|
||||
|
||||
async for line in response.body_iterator:
|
||||
if not line:
|
||||
continue
|
||||
try:
|
||||
event_data = json.loads(line)
|
||||
if event_data.get("event") == "end_vertex":
|
||||
message = (
|
||||
event_data.get("data", {})
|
||||
.get("build_data", {})
|
||||
.get("data", {})
|
||||
.get("results", {})
|
||||
.get("message", {})
|
||||
.get("text", "")
|
||||
)
|
||||
if message:
|
||||
collected_results.append(types.TextContent(type="text", text=str(message)))
|
||||
if event_data.get("event") == "error":
|
||||
content_blocks = event_data.get("data", {}).get("content_blocks", [])
|
||||
text = event_data.get("data", {}).get("text", "")
|
||||
error_msg = (
|
||||
f"Error Executing the {flow.name} tool. Error: {text} Details: {content_blocks}"
|
||||
)
|
||||
collected_results.append(types.TextContent(type="text", text=error_msg))
|
||||
except json.JSONDecodeError:
|
||||
msg = f"Failed to parse event data: {line}"
|
||||
logger.warning(msg)
|
||||
continue
|
||||
try:
|
||||
try:
|
||||
result = await simple_run_flow(
|
||||
flow=flow,
|
||||
input_request=input_request,
|
||||
stream=False,
|
||||
api_key_user=current_user,
|
||||
)
|
||||
# Process all outputs and messages
|
||||
for run_output in result.outputs:
|
||||
for component_output in run_output.outputs:
|
||||
# Handle messages
|
||||
for msg in component_output.messages or []:
|
||||
text_content = types.TextContent(type="text", text=msg.message)
|
||||
collected_results.append(text_content)
|
||||
# Handle results
|
||||
for value in (component_output.results or {}).values():
|
||||
if isinstance(value, Message):
|
||||
text_content = types.TextContent(type="text", text=value.get_text())
|
||||
collected_results.append(text_content)
|
||||
else:
|
||||
collected_results.append(types.TextContent(type="text", text=str(value)))
|
||||
except Exception as e: # noqa: BLE001
|
||||
error_msg = f"Error Executing the {flow.name} tool. Error: {e!s}"
|
||||
collected_results.append(types.TextContent(type="text", text=error_msg))
|
||||
|
||||
return collected_results
|
||||
finally:
|
||||
progress_task.cancel()
|
||||
await asyncio.wait([progress_task])
|
||||
if not progress_task.cancelled() and (exc := progress_task.exception()) is not None:
|
||||
raise exc
|
||||
if progress_task:
|
||||
progress_task.cancel()
|
||||
await asyncio.gather(progress_task, return_exceptions=True)
|
||||
|
||||
except Exception:
|
||||
if mcp_config.enable_progress_notifications and (
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue