feat: add anthropic mcp endpoint (#5148)
* mcp WIP * [autofix.ci] apply automated fixes * logging and flow user check * mcp stdio client component * handle disconnect better * initialization * session fix and type fix * [autofix.ci] apply automated fixes * defensive against mcp server bugs * [autofix.ci] apply automated fixes * notifications and sse component * enabled flags and resource support * remove unneeded print * extract json schema util * [autofix.ci] apply automated fixes * ruff * fix tools [] bug and db async session api change * Tool instead of StructuredTool * ruff fixes * ruff * validation optimization * fix frontend test * another playwright fix * Update src/frontend/tests/extended/features/notifications.spec.ts Co-authored-by: Gabriel Luiz Freitas Almeida <gabriel@langflow.org> * mcp component descriptions * mypy fixes * fix setup_database_url test --------- Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com> Co-authored-by: Gabriel Luiz Freitas Almeida <gabriel@langflow.org>
This commit is contained in:
parent
ee19feaf39
commit
63d649b0f4
14 changed files with 1229 additions and 525 deletions
|
|
@ -113,8 +113,10 @@ dependencies = [
|
|||
"aiofile>=3.9.0,<4.0.0",
|
||||
"sseclient-py==1.8.0",
|
||||
"arize-phoenix-otel>=0.6.1",
|
||||
"openinference-instrumentation-langchain==0.1.29",
|
||||
"openinference-instrumentation-langchain>=0.1.29",
|
||||
"crewai~=0.86.0",
|
||||
"mcp>=0.9.1",
|
||||
"uv>=0.5.7",
|
||||
"ag2",
|
||||
"pydantic-ai>=0.0.12",
|
||||
]
|
||||
|
|
|
|||
|
|
@ -9,6 +9,7 @@ from langflow.api.v1 import (
|
|||
flows_router,
|
||||
folders_router,
|
||||
login_router,
|
||||
mcp_router,
|
||||
monitor_router,
|
||||
starter_projects_router,
|
||||
store_router,
|
||||
|
|
@ -16,6 +17,7 @@ from langflow.api.v1 import (
|
|||
validate_router,
|
||||
variables_router,
|
||||
)
|
||||
from langflow.services.deps import get_settings_service
|
||||
|
||||
router = APIRouter(
|
||||
prefix="/api/v1",
|
||||
|
|
@ -33,3 +35,6 @@ router.include_router(files_router)
|
|||
router.include_router(monitor_router)
|
||||
router.include_router(folders_router)
|
||||
router.include_router(starter_projects_router)
|
||||
|
||||
if get_settings_service().settings.mcp_server_enabled:
|
||||
router.include_router(mcp_router)
|
||||
|
|
|
|||
|
|
@ -5,6 +5,7 @@ from langflow.api.v1.files import router as files_router
|
|||
from langflow.api.v1.flows import router as flows_router
|
||||
from langflow.api.v1.folders import router as folders_router
|
||||
from langflow.api.v1.login import router as login_router
|
||||
from langflow.api.v1.mcp import router as mcp_router
|
||||
from langflow.api.v1.monitor import router as monitor_router
|
||||
from langflow.api.v1.starter_projects import router as starter_projects_router
|
||||
from langflow.api.v1.store import router as store_router
|
||||
|
|
@ -20,6 +21,7 @@ __all__ = [
|
|||
"flows_router",
|
||||
"folders_router",
|
||||
"login_router",
|
||||
"mcp_router",
|
||||
"monitor_router",
|
||||
"starter_projects_router",
|
||||
"store_router",
|
||||
|
|
|
|||
343
src/backend/base/langflow/api/v1/mcp.py
Normal file
343
src/backend/base/langflow/api/v1/mcp.py
Normal file
|
|
@ -0,0 +1,343 @@
|
|||
import asyncio
|
||||
import base64
|
||||
import json
|
||||
import logging
|
||||
import traceback
|
||||
from contextlib import suppress
|
||||
from contextvars import ContextVar
|
||||
from typing import Annotated
|
||||
from urllib.parse import quote, unquote, urlparse
|
||||
from uuid import UUID, uuid4
|
||||
|
||||
import pydantic
|
||||
from anyio import BrokenResourceError
|
||||
from fastapi import APIRouter, Depends, Request
|
||||
from fastapi.responses import StreamingResponse
|
||||
from mcp import types
|
||||
from mcp.server import NotificationOptions, Server
|
||||
from mcp.server.sse import SseServerTransport
|
||||
from sqlmodel import select
|
||||
from starlette.background import BackgroundTasks
|
||||
|
||||
from langflow.api.v1.chat import build_flow
|
||||
from langflow.api.v1.schemas import InputValueRequest
|
||||
from langflow.helpers.flow import json_schema_from_flow
|
||||
from langflow.services.auth.utils import get_current_active_user
|
||||
from langflow.services.database.models import Flow, User
|
||||
from langflow.services.deps import get_db_service, get_session, get_settings_service, get_storage_service
|
||||
from langflow.services.storage.utils import build_content_type_from_extension
|
||||
|
||||
logger = logging.getLogger(__name__)

# Debug switch for this module. The original code guarded the handler setup
# with a bare `if False:` (dead code with an unnamed magic literal); a named
# constant keeps the same disabled-by-default behavior but is discoverable.
DEBUG = False

if DEBUG:
    logger.setLevel(logging.DEBUG)
    if not logger.handlers:
        handler = logging.StreamHandler()
        handler.setLevel(logging.DEBUG)
        formatter = logging.Formatter("[%(asctime)s][%(levelname)s] %(message)s")
        handler.setFormatter(formatter)
        logger.addHandler(handler)

        # Enable debug logging for MCP package, sharing the same handler so
        # `handler` is always defined where it is used.
        mcp_logger = logging.getLogger("mcp")
        mcp_logger.setLevel(logging.DEBUG)
        if not mcp_logger.handlers:
            mcp_logger.addHandler(handler)

# No-op unless debug logging was enabled above.
logger.debug("MCP module loaded - debug logging enabled")
|
||||
|
||||
# Whether to emit MCP progress notifications, taken from server settings.
enable_progress_notifications = get_settings_service().settings.mcp_server_enable_progress_notifications

# FastAPI router and the MCP server instance exposed by this module.
router = APIRouter(prefix="/mcp", tags=["mcp"])
server = Server("langflow-mcp-server")

# Context variable carrying the authenticated user into MCP handlers.
current_user_ctx: ContextVar[User] = ContextVar("current_user_ctx")

# Retry budget constant.
MAX_RETRIES = 2
|
||||
|
||||
|
||||
@server.list_prompts()
async def handle_list_prompts():
    """MCP prompt listing — this server exposes no prompts."""
    return []
|
||||
|
||||
|
||||
@server.list_resources()
async def handle_list_resources():
    """List files attached to every flow as MCP resources.

    Builds one ``types.Resource`` per stored file, with a download URL rooted
    at this server's configured host/port. Missing file listings are skipped;
    any other failure is logged and re-raised.
    """
    resources = []
    try:
        session = await anext(get_session())
        storage_service = get_storage_service()
        settings_service = get_settings_service()

        # Build full URL from settings.
        # Fix: the settings attribute is "host"; the previous lookup used a
        # misspelled key ("holst") so it always fell back to "localhost".
        host = getattr(settings_service.settings, "host", "localhost")
        port = getattr(settings_service.settings, "port", 3000)

        base_url = f"http://{host}:{port}".rstrip("/")

        flows = (await session.exec(select(Flow))).all()

        for flow in flows:
            if flow.id:
                try:
                    files = await storage_service.list_files(flow_id=str(flow.id))
                    for file_name in files:
                        # URL encode the filename so it is safe as a path segment.
                        safe_filename = quote(file_name)
                        resource = types.Resource(
                            uri=f"{base_url}/api/v1/files/{flow.id}/{safe_filename}",
                            name=file_name,
                            description=f"File in flow: {flow.name}",
                            mimeType=build_content_type_from_extension(file_name),
                        )
                        resources.append(resource)
                except FileNotFoundError as e:
                    # A flow without a storage folder is not an error; move on.
                    msg = f"Error listing files for flow {flow.id}: {e}"
                    logger.debug(msg)
                    continue
    except Exception as e:
        msg = f"Error in listing resources: {e!s}"
        logger.exception(msg)
        trace = traceback.format_exc()
        logger.exception(trace)
        raise
    return resources
|
||||
|
||||
|
||||
@server.read_resource()
async def handle_read_resource(uri: str) -> bytes:
    """Handle resource read requests.

    Parses a file URI of the form ``.../api/v1/files/{flow_id}/{filename}``,
    fetches the file from storage and returns its base64-encoded content.

    Raises ValueError for malformed URIs or missing files.
    """
    try:
        # Parse the URI properly
        parsed_uri = urlparse(str(uri))
        # Path will be like /api/v1/files/{flow_id}/{filename}
        path_parts = parsed_uri.path.split("/")
        # Remove empty strings from split
        path_parts = [p for p in path_parts if p]

        # The flow_id and filename should be the last two parts
        two = 2
        if len(path_parts) < two:
            msg = f"Invalid URI format: {uri}"
            raise ValueError(msg)

        flow_id = path_parts[-2]
        filename = unquote(path_parts[-1])  # URL decode the filename

        storage_service = get_storage_service()

        # Read the file content
        content = await storage_service.get_file(flow_id=flow_id, file_name=filename)
        if not content:
            # Fix: include the missing filename in the error message (the
            # placeholder had been lost, leaving the message uninformative).
            msg = f"File {filename} not found in flow {flow_id}"
            raise ValueError(msg)

        # Ensure content is base64 encoded
        if isinstance(content, str):
            content = content.encode()
        return base64.b64encode(content)
    except Exception as e:
        msg = f"Error reading resource {uri}: {e!s}"
        logger.exception(msg)
        trace = traceback.format_exc()
        logger.exception(trace)
        raise
|
||||
|
||||
|
||||
@server.list_tools()
async def handle_list_tools():
    """Expose every user-owned flow as an MCP tool, keyed by flow id."""
    tools = []
    try:
        session = await anext(get_session())
        flows = (await session.exec(select(Flow))).all()

        for flow in flows:
            # Skip flows that have no owning user.
            if flow.user_id is None:
                continue

            if flow.description:
                description = f"{flow.name}: {flow.description}"
            else:
                description = f"Tool generated from flow: {flow.name}"

            tools.append(
                types.Tool(
                    name=str(flow.id),  # Use flow.id instead of name
                    description=description,
                    inputSchema=json_schema_from_flow(flow),
                )
            )
    except Exception as e:
        msg = f"Error in listing tools: {e!s}"
        logger.exception(msg)
        trace = traceback.format_exc()
        logger.exception(trace)
        raise
    return tools
|
||||
|
||||
|
||||
@server.call_tool()
async def handle_call_tool(name: str, arguments: dict) -> list[types.TextContent]:
    """Handle tool execution requests.

    ``name`` is a flow id. The flow is executed with the given arguments and
    any chat text carried by "end_vertex" events is collected and returned.
    While the flow runs, a background task emits periodic MCP progress
    notifications (when enabled in settings).
    """
    try:
        session = await anext(get_session())
        background_tasks = BackgroundTasks()

        current_user = current_user_ctx.get()
        flow = (await session.exec(select(Flow).where(Flow.id == UUID(name)))).first()

        if not flow:
            msg = f"Flow with id '{name}' not found"
            raise ValueError(msg)

        # Process inputs
        processed_inputs = dict(arguments)

        # Initial progress notification
        progress_token = server.request_context.meta.progressToken if enable_progress_notifications else None
        if progress_token:
            await server.request_context.session.send_progress_notification(
                progress_token=progress_token, progress=0.0, total=1.0
            )

        conversation_id = str(uuid4())
        input_request = InputValueRequest(
            input_value=processed_inputs.get("input_value", ""), components=[], type="chat", session=conversation_id
        )

        def _extract_text(event_data: dict) -> str:
            # Drill into the nested event payload for the chat message text.
            return (
                event_data.get("data", {})
                .get("build_data", {})
                .get("data", {})
                .get("results", {})
                .get("message", {})
                .get("text", "")
            )

        async def send_progress_updates():
            # Report up to 90% progress once a second while the flow runs;
            # on cancellation, emit a final 100% notification.
            if not (enable_progress_notifications and server.request_context.meta.progressToken):
                return

            try:
                progress = 0.0
                while True:
                    await server.request_context.session.send_progress_notification(
                        progress_token=progress_token, progress=min(0.9, progress), total=1.0
                    )
                    progress += 0.1
                    await asyncio.sleep(1.0)
            except asyncio.CancelledError:
                # Send final 100% progress
                if enable_progress_notifications:
                    await server.request_context.session.send_progress_notification(
                        progress_token=progress_token, progress=1.0, total=1.0
                    )
                raise

        db_service = get_db_service()
        collected_results = []
        async with db_service.with_session() as async_session:
            try:
                progress_task = asyncio.create_task(send_progress_updates())
                try:
                    response = await build_flow(
                        flow_id=UUID(name),
                        inputs=input_request,
                        background_tasks=background_tasks,
                        current_user=current_user,
                        session=async_session,
                    )

                    async for line in response.body_iterator:
                        if not line:
                            continue
                        try:
                            event_data = json.loads(line)
                        except json.JSONDecodeError:
                            msg = f"Failed to parse event data: {line}"
                            logger.warning(msg)
                            continue
                        if event_data.get("event") == "end_vertex":
                            message = _extract_text(event_data)
                            if message:
                                collected_results.append(types.TextContent(type="text", text=str(message)))

                    return collected_results
                finally:
                    # Always stop the progress reporter, letting it send its
                    # final notification before we leave the session.
                    progress_task.cancel()
                    with suppress(asyncio.CancelledError):
                        await progress_task
            except Exception as e:
                msg = f"Error in async session: {e}"
                logger.exception(msg)
                raise

    except Exception as e:
        context = server.request_context
        # Send error progress if there's an exception
        if enable_progress_notifications and (progress_token := context.meta.progressToken):
            await server.request_context.session.send_progress_notification(
                progress_token=progress_token, progress=1.0, total=1.0
            )
        msg = f"Error executing tool {name}: {e!s}"
        logger.exception(msg)
        trace = traceback.format_exc()
        logger.exception(trace)
        raise
|
||||
|
||||
|
||||
# SSE transport shared by the GET (stream) and POST (message) endpoints.
sse = SseServerTransport("/api/v1/mcp/")


def find_validation_error(exc):
    """Searches for a pydantic.ValidationError in the exception chain."""
    current = exc
    while current is not None:
        if isinstance(current, pydantic.ValidationError):
            return current
        # Walk explicit (`raise ... from`) then implicit chaining.
        current = current.__cause__ or current.__context__
    return None
|
||||
|
||||
|
||||
@router.get("/sse", response_class=StreamingResponse)
async def handle_sse(request: Request, current_user: Annotated[User, Depends(get_current_active_user)]):
    """Accept an SSE connection and run the MCP server over it.

    The authenticated user is stashed in a context variable so the tool
    handlers can pick it up, and is reset when the connection ends. Client
    disconnects and cancellation are handled gracefully; validation errors
    raised inside the MCP run are logged at debug level only.
    """
    token = current_user_ctx.set(current_user)
    try:
        async with sse.connect_sse(request.scope, request.receive, request._send) as streams:
            read_stream, write_stream = streams
            try:
                msg = "Starting SSE connection"
                logger.debug(msg)
                msg = f"Stream types: read={type(read_stream)}, write={type(write_stream)}"
                logger.debug(msg)

                notification_options = NotificationOptions(
                    prompts_changed=True, resources_changed=True, tools_changed=True
                )
                init_options = server.create_initialization_options(notification_options)
                msg = f"Initialization options: {init_options}"
                logger.debug(msg)

                try:
                    await server.run(read_stream, write_stream, init_options)
                except Exception as exc:  # noqa: BLE001
                    # Defensive: a misbehaving MCP client should not take the
                    # endpoint down — log and end the stream quietly.
                    validation_error = find_validation_error(exc)
                    if validation_error:
                        msg = "Validation error in MCP:" + str(validation_error)
                    else:
                        msg = f"Error in MCP: {exc!s}"
                    logger.debug(msg)
                    return
            except BrokenResourceError:
                # Handle gracefully when client disconnects
                logger.info("Client disconnected from SSE connection")
            except asyncio.CancelledError:
                logger.info("SSE connection was cancelled")
            except Exception as e:
                msg = f"Error in MCP: {e!s}"
                logger.exception(msg)
                trace = traceback.format_exc()
                logger.exception(trace)
                raise
    finally:
        current_user_ctx.reset(token)
|
||||
|
||||
|
||||
@router.post("/")
async def handle_messages(request: Request):
    """Forward a posted MCP message to the SSE transport."""
    await sse.handle_post_message(request.scope, request.receive, request._send)
|
||||
0
src/backend/base/langflow/base/mcp/__init__.py
Normal file
0
src/backend/base/langflow/base/mcp/__init__.py
Normal file
26
src/backend/base/langflow/base/mcp/util.py
Normal file
26
src/backend/base/langflow/base/mcp/util.py
Normal file
|
|
@ -0,0 +1,26 @@
|
|||
import asyncio
|
||||
from collections.abc import Awaitable, Callable
|
||||
|
||||
from langflow.helpers.base_model import BaseModel
|
||||
|
||||
|
||||
def create_tool_coroutine(tool_name: str, arg_schema: type[BaseModel], session) -> Callable[[dict], Awaitable]:
    """Build an async callable that forwards positional args to an MCP tool.

    Positional arguments are matched, in order, against the schema's field
    names and sent to ``session.call_tool`` as keyword arguments.
    """

    async def tool_coroutine(*args):
        if not args:
            msg = f"at least one positional argument is required {args}"
            raise ValueError(msg)
        # Pair each positional value with the corresponding schema field;
        # extra values on either side are silently dropped (strict=False).
        field_names = arg_schema.model_fields.keys()
        arg_dict = dict(zip(field_names, args, strict=False))
        return await session.call_tool(tool_name, arguments=arg_dict)

    return tool_coroutine
|
||||
|
||||
|
||||
def create_tool_func(tool_name: str, session) -> Callable[..., str]:
    """Build a synchronous callable that invokes an MCP tool.

    Keyword arguments are forwarded verbatim; the coroutine returned by
    ``session.call_tool`` is driven to completion on the current event loop.
    """

    def tool_func(**kwargs):
        if not kwargs:
            msg = f"at least one named argument is required {kwargs}"
            raise ValueError(msg)
        # NOTE(review): run_until_complete fails if a loop is already running
        # in this thread — callers must invoke this from synchronous code.
        loop = asyncio.get_event_loop()
        return loop.run_until_complete(session.call_tool(tool_name, arguments=kwargs))

    return tool_func
|
||||
|
|
@ -9,6 +9,7 @@ from .exa_search import ExaSearchToolkit
|
|||
from .glean_search_api import GleanSearchAPIComponent
|
||||
from .google_search_api import GoogleSearchAPIComponent
|
||||
from .google_serper_api import GoogleSerperAPIComponent
|
||||
from .mcp_stdio import MCPStdio
|
||||
from .python_code_structured_tool import PythonCodeStructuredTool
|
||||
from .python_repl import PythonREPLToolComponent
|
||||
from .search_api import SearchAPIComponent
|
||||
|
|
@ -36,6 +37,7 @@ __all__ = [
|
|||
"GleanSearchAPIComponent",
|
||||
"GoogleSearchAPIComponent",
|
||||
"GoogleSerperAPIComponent",
|
||||
"MCPStdio",
|
||||
"PythonCodeStructuredTool",
|
||||
"PythonREPLToolComponent",
|
||||
"SearXNGToolComponent",
|
||||
|
|
|
|||
98
src/backend/base/langflow/components/tools/mcp_sse.py
Normal file
98
src/backend/base/langflow/components/tools/mcp_sse.py
Normal file
|
|
@ -0,0 +1,98 @@
|
|||
# from langflow.field_typing import Data
|
||||
import asyncio
|
||||
from contextlib import AsyncExitStack
|
||||
|
||||
import httpx
|
||||
from mcp import ClientSession, types
|
||||
from mcp.client.sse import sse_client
|
||||
|
||||
from langflow.base.mcp.util import create_tool_coroutine, create_tool_func
|
||||
from langflow.components.tools.mcp_stdio import create_input_schema_from_json_schema
|
||||
from langflow.custom import Component
|
||||
from langflow.field_typing import Tool
|
||||
from langflow.io import MessageTextInput, Output
|
||||
|
||||
# Define constant for status code
|
||||
HTTP_TEMPORARY_REDIRECT = 307
|
||||
|
||||
|
||||
class MCPSseClient:
    """SSE-transport wrapper around an MCP ClientSession."""

    def __init__(self):
        # Initialize session and client objects; populated on connect.
        self.write = None
        self.sse = None
        self.session: ClientSession | None = None
        self.exit_stack = AsyncExitStack()

    async def pre_check_redirect(self, url: str):
        """Check if the URL responds with a 307 Redirect."""
        async with httpx.AsyncClient(follow_redirects=False) as client:
            response = await client.request("HEAD", url)
            if response.status_code == HTTP_TEMPORARY_REDIRECT:
                # Follow a single temporary redirect if the server issues one.
                return response.headers.get("Location")
            return url  # Return the original URL if no redirect

    async def connect_to_server(
        self, url: str, headers: dict[str, str] | None, timeout_seconds: int = 500, sse_read_timeout_seconds: int = 500
    ):
        """Open the SSE transport, initialize the session and list tools."""
        if headers is None:
            headers = {}
        url = await self.pre_check_redirect(url)

        async with asyncio.timeout(timeout_seconds):
            transport = await self.exit_stack.enter_async_context(
                sse_client(url, headers, timeout_seconds, sse_read_timeout_seconds)
            )
            self.sse, self.write = transport
            self.session = await self.exit_stack.enter_async_context(ClientSession(self.sse, self.write))

            await self.session.initialize()

            # List available tools
            tool_response = await self.session.list_tools()
            return tool_response.tools
|
||||
|
||||
|
||||
class MCPSse(Component):
    """Langflow component exposing an SSE MCP server's tools to an Agent."""

    # NOTE(review): these are class-level attributes, so the client (and its
    # session) is shared by every instance of this component — confirm that
    # this sharing is intentional.
    client = MCPSseClient()
    tools = types.ListToolsResult
    tool_names = [str]
    display_name = "MCP Tools (SSE)"
    description = "Connects to an MCP server over SSE and exposes it's tools as langflow tools to be used by an Agent."
    documentation: str = "http://docs.langflow.org/components/custom"
    icon = "code"
    name = "MCPSse"

    inputs = [
        MessageTextInput(
            name="url",
            display_name="mcp sse url",
            info="sse url",
            value="http://localhost:7860/api/v1/mcp/sse",
            tool_mode=True,
        ),
    ]

    outputs = [
        Output(display_name="Tools", name="tools", method="build_output"),
    ]

    async def build_output(self) -> list[Tool]:
        """Connect (once) to the configured server and wrap its MCP tools."""
        if self.client.session is None:
            self.tools = await self.client.connect_to_server(self.url, {})

        wrapped = []

        for tool in self.tools:
            input_model = create_input_schema_from_json_schema(tool.inputSchema)
            wrapped.append(
                Tool(
                    name=tool.name,  # maybe format this
                    description=tool.description,
                    coroutine=create_tool_coroutine(tool.name, input_model, self.client.session),
                    func=create_tool_func(tool.name, self.client.session),
                )
            )

        self.tool_names = [tool.name for tool in self.tools]
        return wrapped
|
||||
122
src/backend/base/langflow/components/tools/mcp_stdio.py
Normal file
122
src/backend/base/langflow/components/tools/mcp_stdio.py
Normal file
|
|
@ -0,0 +1,122 @@
|
|||
# from langflow.field_typing import Data
|
||||
import os
|
||||
from contextlib import AsyncExitStack
|
||||
from typing import Any
|
||||
|
||||
from mcp import ClientSession, StdioServerParameters, types
|
||||
from mcp.client.stdio import stdio_client
|
||||
from pydantic import BaseModel, Field, create_model
|
||||
|
||||
from langflow.base.mcp.util import create_tool_coroutine, create_tool_func
|
||||
from langflow.custom import Component
|
||||
from langflow.field_typing import Tool
|
||||
from langflow.io import MessageTextInput, Output
|
||||
|
||||
|
||||
class MCPStdioClient:
    """Stdio-transport wrapper around an MCP ClientSession."""

    def __init__(self):
        # Initialize session and client objects; populated on connect.
        self.session: ClientSession | None = None
        self.exit_stack = AsyncExitStack()

    async def connect_to_server(self, command_str: str):
        """Spawn the MCP server subprocess and return its tool list."""
        # Split the command line on spaces: first token is the executable,
        # the remainder are its arguments.
        executable, *arguments = command_str.split(" ")
        server_params = StdioServerParameters(
            command=executable, args=arguments, env={"DEBUG": "true", "PATH": os.environ["PATH"]}
        )

        stdio_transport = await self.exit_stack.enter_async_context(stdio_client(server_params))
        self.stdio, self.write = stdio_transport
        self.session = await self.exit_stack.enter_async_context(ClientSession(self.stdio, self.write))

        await self.session.initialize()

        # List available tools
        response = await self.session.list_tools()
        return response.tools
|
||||
|
||||
|
||||
def create_input_schema_from_json_schema(schema: dict[str, Any]) -> type[BaseModel]:
    """Converts a JSON schema into a Pydantic model dynamically.

    :param schema: The JSON schema as a dictionary.
    :return: A Pydantic model class.
    :raises ValueError: if the schema root is not of type "object".
    """
    if schema.get("type") != "object":
        msg = "JSON schema must be of type 'object' at the root level."
        raise ValueError(msg)

    # JSON-schema type name -> Python type; unknown names fall back to Any.
    type_map = {
        "string": str,
        "str": str,
        "integer": int,
        "int": int,
        "number": float,
        "boolean": bool,
        "array": list,
        "object": dict,
    }

    required_fields = set(schema.get("required", []))
    fields = {}

    for field_name, field_def in schema.get("properties", {}).items():
        # Default to string type if no "type" is specified.
        python_type = type_map.get(field_def.get("type", "str"), Any)

        # Carry over the description; optional fields get their declared
        # default (or None).
        field_kwargs = {"description": field_def.get("description", "")}
        if field_name not in required_fields:
            field_kwargs["default"] = field_def.get("default", None)

        fields[field_name] = (python_type, Field(**field_kwargs))

    # Dynamically create the model
    return create_model("InputSchema", **fields)
|
||||
|
||||
|
||||
class MCPStdio(Component):
    """Langflow component exposing a stdio MCP server's tools to an Agent."""

    # NOTE(review): class-level attributes — the client (and its session) is
    # shared by every instance of this component; confirm this is intentional.
    client = MCPStdioClient()
    tools = types.ListToolsResult
    tool_names = [str]
    display_name = "MCP Tools (stdio)"
    description = (
        "Connects to an MCP server over stdio and exposes it's tools as langflow tools to be used by an Agent."
    )
    documentation: str = "http://docs.langflow.org/components/custom"
    icon = "code"
    name = "MCPStdio"

    inputs = [
        MessageTextInput(
            name="command",
            display_name="mcp command",
            info="mcp command",
            value="uvx mcp-sse-shim@latest",
            tool_mode=True,
        ),
    ]

    outputs = [
        Output(display_name="Tools", name="tools", method="build_output"),
    ]

    async def build_output(self) -> list[Tool]:
        """Connect (once) to the configured command and wrap its MCP tools.

        Returns one langflow Tool per MCP tool, with both async and sync
        entry points bound to the shared client session.
        """
        if self.client.session is None:
            self.tools = await self.client.connect_to_server(self.command)

        tool_list = []

        for tool in self.tools:
            args_schema = create_input_schema_from_json_schema(tool.inputSchema)
            tool_list.append(
                Tool(
                    name=tool.name,
                    description=tool.description,
                    coroutine=create_tool_coroutine(tool.name, args_schema, self.client.session),
                    # Fix: create_tool_func expects the client session as its
                    # second argument; the schema was being passed before
                    # (unlike the SSE sibling component), which would break
                    # every synchronous tool invocation.
                    func=create_tool_func(tool.name, self.client.session),
                )
            )
        self.tool_names = [tool.name for tool in self.tools]
        return tool_list
|
||||
|
|
@ -4,6 +4,7 @@ from typing import TYPE_CHECKING, Any, cast
|
|||
from uuid import UUID
|
||||
|
||||
from fastapi import HTTPException
|
||||
from loguru import logger
|
||||
from pydantic.v1 import BaseModel, Field, create_model
|
||||
from sqlmodel import select
|
||||
|
||||
|
|
@ -314,3 +315,46 @@ async def generate_unique_flow_name(flow_name, user_id, session):
|
|||
# If a flow with the name already exists, append (n) to the name and increment n
|
||||
flow_name = f"{original_name} ({n})"
|
||||
n += 1
|
||||
|
||||
|
||||
def json_schema_from_flow(flow: Flow) -> dict:
    """Generate JSON schema from flow input nodes."""
    from langflow.graph.graph.base import Graph

    # Get the flow's data which contains the nodes and their configurations
    flow_data = flow.data if flow.data else {}

    graph = Graph.from_payload(flow_data)
    input_nodes = [vertex for vertex in graph.vertices if vertex.is_input]

    properties = {}
    required = []
    for node in input_nodes:
        template = node.data["node"]["template"]

        for field_name, field_data in template.items():
            # Some template entries hold the plain string "Component" rather
            # than a field-definition dict; skip those (checked first so we
            # never call .get on a string), along with hidden/advanced fields.
            if field_data == "Component" or not field_data.get("show", False) or field_data.get("advanced", False):
                continue

            field_type = field_data.get("type", "string")
            properties[field_name] = {
                "type": field_type,
                "description": field_data.get("info", f"Input for {field_name}"),
            }

            # Normalize langflow type names to JSON Schema type names; anything
            # unrecognized (including already-normalized names) is logged and
            # coerced to "string".
            if field_type == "str":
                field_type = "string"
            elif field_type == "int":
                field_type = "integer"
            elif field_type == "float":
                field_type = "number"
            elif field_type == "bool":
                field_type = "boolean"
            else:
                logger.warning(f"Unknown field type: {field_type} defaulting to string")
                field_type = "string"
            properties[field_name]["type"] = field_type

            if field_data.get("required", False):
                required.append(field_name)

    return {"type": "object", "properties": properties, "required": required}
|
||||
|
|
|
|||
|
|
@ -181,6 +181,12 @@ class Settings(BaseSettings):
|
|||
max_vertex_builds_to_keep: int = 3000
|
||||
"""The maximum number of vertex builds to keep in the database."""
|
||||
|
||||
# MCP Server
|
||||
mcp_server_enabled: bool = True
|
||||
"""If set to False, Langflow will not enable the MCP server."""
|
||||
mcp_server_enable_progress_notifications: bool = False
|
||||
"""If set to False, Langflow will not send progress notifications in the MCP server."""
|
||||
|
||||
@field_validator("dev")
|
||||
@classmethod
|
||||
def set_dev(cls, value):
|
||||
|
|
|
|||
|
|
@ -1,3 +1,5 @@
|
|||
import os
|
||||
|
||||
import pytest
|
||||
from langflow.services.deps import get_settings_service
|
||||
|
||||
|
|
@ -5,14 +7,18 @@ from langflow.services.deps import get_settings_service
|
|||
@pytest.fixture(autouse=True)
def setup_database_url(tmp_path, monkeypatch):
    """Setup a temporary database URL for testing."""
    settings_service = get_settings_service()
    db_path = tmp_path / "test_performance.db"
    # Fix: capture the original value BEFORE deleting the variable. The
    # previous code assigned the result of monkeypatch.delenv (always None)
    # and deleted the variable first, so the original value could never be
    # observed or restored.
    original_value = os.getenv("LANGFLOW_DATABASE_URL")
    monkeypatch.delenv("LANGFLOW_DATABASE_URL", raising=False)
    test_db_url = f"sqlite:///{db_path}"
    monkeypatch.setenv("LANGFLOW_DATABASE_URL", test_db_url)
    settings_service.set("database_url", test_db_url)
    yield
    # Restore original value if it existed
    if original_value is not None:
        monkeypatch.setenv("LANGFLOW_DATABASE_URL", original_value)
        settings_service.set("database_url", original_value)
    else:
        monkeypatch.delenv("LANGFLOW_DATABASE_URL", raising=False)
|
||||
|
|
|
|||
|
|
@ -90,7 +90,7 @@ test(
|
|||
}
|
||||
|
||||
await page.locator(".react-flow__pane").click();
|
||||
|
||||
await adjustScreenView(page, { numberOfZoomOut: 1 });
|
||||
await visibleElementHandle.hover();
|
||||
await page.mouse.down();
|
||||
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue