diff --git a/src/backend/base/langflow/api/v1/mcp.py b/src/backend/base/langflow/api/v1/mcp.py index 0c358037d..a609863b6 100644 --- a/src/backend/base/langflow/api/v1/mcp.py +++ b/src/backend/base/langflow/api/v1/mcp.py @@ -2,11 +2,12 @@ import asyncio import base64 import json import logging -import traceback +from collections.abc import Awaitable, Callable from contextvars import ContextVar -from typing import Annotated +from functools import wraps +from typing import Annotated, Any, ParamSpec, TypeVar from urllib.parse import quote, unquote, urlparse -from uuid import UUID, uuid4 +from uuid import uuid4 import pydantic from anyio import BrokenResourceError @@ -20,34 +21,43 @@ from starlette.background import BackgroundTasks from langflow.api.v1.chat import build_flow_and_stream from langflow.api.v1.schemas import InputValueRequest +from langflow.base.mcp.util import get_flow from langflow.helpers.flow import json_schema_from_flow from langflow.services.auth.utils import get_current_active_user from langflow.services.database.models import Flow, User from langflow.services.deps import ( get_db_service, - get_session, get_settings_service, get_storage_service, + session_scope, ) from langflow.services.storage.utils import build_content_type_from_extension logger = logging.getLogger(__name__) -if False: - logger.setLevel(logging.DEBUG) - if not logger.handlers: - handler = logging.StreamHandler() - handler.setLevel(logging.DEBUG) - formatter = logging.Formatter("[%(asctime)s][%(levelname)s] %(message)s") - handler.setFormatter(formatter) - logger.addHandler(handler) - # Enable debug logging for MCP package - mcp_logger = logging.getLogger("mcp") - mcp_logger.setLevel(logging.DEBUG) - if not mcp_logger.handlers: - mcp_logger.addHandler(handler) +T = TypeVar("T") +P = ParamSpec("P") - logger.debug("MCP module loaded - debug logging enabled") + +def handle_mcp_errors(func: Callable[P, Awaitable[T]]) -> Callable[P, Awaitable[T]]: + """Decorator to handle MCP endpoint errors consistently.""" + + @wraps(func) + async def wrapper(*args: P.args, **kwargs: P.kwargs) -> T: + try: + return await func(*args, **kwargs) + except Exception as e: + msg = f"Error in {func.__name__}: {e!s}" + logger.exception(msg) + raise + + return wrapper + + +async def with_db_session(operation: Callable[[Any], Awaitable[T]]) -> T: + """Execute an operation within a database session context.""" + async with session_scope() as session: + return await operation(session) class MCPConfig: @@ -88,7 +98,7 @@ async def handle_list_prompts(): async def handle_list_resources(): resources = [] try: - session = await anext(get_session()) + db_service = get_db_service() storage_service = get_storage_service() settings_service = get_settings_service() @@ -98,31 +108,30 @@ async def handle_list_resources(): base_url = f"http://{host}:{port}".rstrip("/") - flows = (await session.exec(select(Flow))).all() + async with db_service.with_session() as session: + flows = (await session.exec(select(Flow))).all() - for flow in flows: - if flow.id: - try: - files = await storage_service.list_files(flow_id=str(flow.id)) - for file_name in files: - # URL encode the filename - safe_filename = quote(file_name) - resource = types.Resource( - uri=f"{base_url}/api/v1/files/{flow.id}/{safe_filename}", - name=file_name, - description=f"File in flow: {flow.name}", - mimeType=build_content_type_from_extension(file_name), - ) - resources.append(resource) - except FileNotFoundError as e: - msg = f"Error listing files for flow {flow.id}: {e}" - logger.debug(msg) - continue + for flow in flows: + if flow.id: + try: + files = await storage_service.list_files(flow_id=str(flow.id)) + for file_name in files: + # URL encode the filename + safe_filename = quote(file_name) + resource = types.Resource( + uri=f"{base_url}/api/v1/files/{flow.id}/{safe_filename}", + name=file_name, + description=f"File in flow: {flow.name}", + mimeType=build_content_type_from_extension(file_name), + ) + resources.append(resource) + except FileNotFoundError as e: + msg = f"Error listing files for flow {flow.id}: {e}" + logger.debug(msg) + continue except Exception as e: msg = f"Error in listing resources: {e!s}" logger.exception(msg) - trace = traceback.format_exc() - logger.exception(trace) raise return resources @@ -162,8 +171,6 @@ async def handle_read_resource(uri: str) -> bytes: except Exception as e: msg = f"Error reading resource {uri}: {e!s}" logger.exception(msg) - trace = traceback.format_exc() - logger.exception(trace) raise @@ -171,47 +178,48 @@ async def handle_read_resource(uri: str) -> bytes: async def handle_list_tools(): tools = [] try: - session = await anext(get_session()) - flows = (await session.exec(select(Flow))).all() + db_service = get_db_service() + async with db_service.with_session() as session: + flows = (await session.exec(select(Flow))).all() - for flow in flows: - if flow.user_id is None: - continue + for flow in flows: + if flow.user_id is None: + continue - tool = types.Tool( - name=str(flow.id), # Use flow.id instead of name - description=f"{flow.name}: {flow.description}" - if flow.description - else f"Tool generated from flow: {flow.name}", - inputSchema=json_schema_from_flow(flow), - ) - tools.append(tool) + tool = types.Tool( + name=flow.name, + description=f"{flow.id}: {flow.description}" + if flow.description + else f"Tool generated from flow: {flow.name}", + inputSchema=json_schema_from_flow(flow), + ) + tools.append(tool) except Exception as e: msg = f"Error in listing tools: {e!s}" logger.exception(msg) - trace = traceback.format_exc() - logger.exception(trace) raise return tools @server.call_tool() +@handle_mcp_errors async def handle_call_tool(name: str, arguments: dict) -> list[types.TextContent]: """Handle tool execution requests.""" mcp_config = get_mcp_config() if mcp_config.enable_progress_notifications is None: settings_service = get_settings_service() mcp_config.enable_progress_notifications = settings_service.settings.mcp_server_enable_progress_notifications - try: - session = await anext(get_session()) - background_tasks = BackgroundTasks() - current_user = current_user_ctx.get() - flow = (await session.exec(select(Flow).where(Flow.id == UUID(name)))).first() + background_tasks = BackgroundTasks() + current_user = current_user_ctx.get() + async def execute_tool(session): + # get flow id from name + flow = await get_flow(name, current_user.id, session) if not flow: - msg = f"Flow with id '{name}' not found" + msg = f"Flow with name '{name}' not found" raise ValueError(msg) + flow_id = flow.id # Process inputs processed_inputs = dict(arguments) @@ -240,70 +248,66 @@ async def handle_call_tool(name: str, arguments: dict) -> list[types.TextContent progress += 0.1 await asyncio.sleep(1.0) except asyncio.CancelledError: - # Send final 100% progress if mcp_config.enable_progress_notifications: await server.request_context.session.send_progress_notification( progress_token=progress_token, progress=1.0, total=1.0 ) raise - db_service = get_db_service() collected_results = [] - async with db_service.with_session(): + try: + progress_task = asyncio.create_task(send_progress_updates()) + try: - progress_task = asyncio.create_task(send_progress_updates()) + response = await build_flow_and_stream( + flow_id=flow_id, + inputs=input_request, + background_tasks=background_tasks, + current_user=current_user, + ) - try: - response = await build_flow_and_stream( - flow_id=UUID(name), - inputs=input_request, - background_tasks=background_tasks, - current_user=current_user, - ) + async for line in response.body_iterator: + if not line: + continue + try: + event_data = json.loads(line) + if event_data.get("event") == "end_vertex": + message = ( + event_data.get("data", {}) + .get("build_data", {}) + .get("data", {}) + .get("results", {}) + .get("message", {}) + .get("text", "") + ) + if message: + collected_results.append(types.TextContent(type="text", text=str(message))) + except json.JSONDecodeError: + msg = f"Failed to parse event data: {line}" + logger.warning(msg) + continue - async for line in response.body_iterator: - if not line: - continue - try: - event_data = json.loads(line) - if event_data.get("event") == "end_vertex": - message = ( - event_data.get("data", {}) - .get("build_data", {}) - .get("data", {}) - .get("results", {}) - .get("message", {}) - .get("text", "") - ) - if message: - collected_results.append(types.TextContent(type="text", text=str(message))) - except json.JSONDecodeError: - msg = f"Failed to parse event data: {line}" - logger.warning(msg) - continue + return collected_results + finally: + progress_task.cancel() + await asyncio.wait([progress_task]) + if not progress_task.cancelled() and (exc := progress_task.exception()) is not None: + raise exc - return collected_results - finally: - progress_task.cancel() - await asyncio.wait([progress_task]) - if not progress_task.cancelled() and (exc := progress_task.exception()) is not None: - raise exc - except Exception as e: - msg = f"Error in async session: {e}" - logger.exception(msg) - raise + except Exception: + if mcp_config.enable_progress_notifications and ( + progress_token := server.request_context.meta.progressToken + ): + await server.request_context.session.send_progress_notification( + progress_token=progress_token, progress=1.0, total=1.0 + ) + raise + try: + return await with_db_session(execute_tool) except Exception as e: - context = server.request_context - # Send error progress if there's an exception - if mcp_config.enable_progress_notifications and (progress_token := context.meta.progressToken): - await server.request_context.session.send_progress_notification( - progress_token=progress_token, progress=1.0, total=1.0 - ) msg = f"Error executing tool {name}: {e!s}" logger.exception(msg) - trace = traceback.format_exc() - logger.exception(trace) raise @@ -357,8 +361,6 @@ async def handle_sse(request: Request, current_user: Annotated[User, Depends(get except Exception as e: msg = f"Error in MCP: {e!s}" logger.exception(msg) - trace = traceback.format_exc() - logger.exception(trace) raise finally: current_user_ctx.reset(token) diff --git a/src/backend/base/langflow/base/astra_assistants/util.py b/src/backend/base/langflow/base/astra_assistants/util.py index 2cde2d63b..6bb588f46 100644 --- a/src/backend/base/langflow/base/astra_assistants/util.py +++ b/src/backend/base/langflow/base/astra_assistants/util.py @@ -17,7 +17,7 @@ from langchain_core.tools import BaseTool from pydantic import BaseModel from requests.exceptions import RequestException -from langflow.components.tools.mcp_stdio import create_input_schema_from_json_schema +from langflow.base.mcp.util import create_input_schema_from_json_schema from langflow.services.cache.utils import CacheMiss client_lock = threading.Lock() diff --git a/src/backend/base/langflow/base/mcp/util.py b/src/backend/base/langflow/base/mcp/util.py index 5bcfb06cd..bb90a2492 100644 --- a/src/backend/base/langflow/base/mcp/util.py +++ b/src/backend/base/langflow/base/mcp/util.py @@ -1,40 +1,81 @@ import asyncio +import os from collections.abc import Awaitable, Callable +from contextlib import AsyncExitStack from typing import Any +from uuid import UUID +import httpx +from mcp import ClientSession, StdioServerParameters, stdio_client +from mcp.client.sse import sse_client from pydantic import Field, create_model +from sqlmodel import select from langflow.helpers.base_model import BaseModel +from langflow.services.database.models import Flow -def create_tool_coroutine(tool_name: str, arg_schema: type[BaseModel], session) -> Callable[[dict], Awaitable]: +def create_tool_coroutine(tool_name: str, arg_schema: type[BaseModel], session) -> Callable[..., Awaitable]: async def tool_coroutine(*args, **kwargs): - fields = arg_schema.model_fields.keys() - expected_field_count = len(fields) - if len(args) + len(kwargs) != expected_field_count: - msg = f"{expected_field_count} arguments are required. Received: {args} {kwargs}" - raise ValueError(msg) - arg_dict = dict(zip(fields, args, strict=False)) - arg_dict.update(kwargs) - return await session.call_tool(tool_name, arguments=arg_dict) + # Get field names from the model (preserving order) + field_names = list(arg_schema.__fields__.keys()) + provided_args = {} + # Map positional arguments to their corresponding field names + for i, arg in enumerate(args): + if i >= len(field_names): + msg = "Too many positional arguments provided" + raise ValueError(msg) + provided_args[field_names[i]] = arg + # Merge in keyword arguments + provided_args.update(kwargs) + # Validate input and fill defaults for missing optional fields + try: + validated = arg_schema.parse_obj(provided_args) + except Exception as e: + msg = f"Invalid input: {e}" + raise ValueError(msg) from e + return await session.call_tool(tool_name, arguments=validated.dict()) return tool_coroutine -def create_tool_func(tool_name: str, session) -> Callable[..., str]: - def tool_func(**kwargs): - if len(kwargs) == 0: - msg = f"at least one named argument is required {kwargs}" - raise ValueError(msg) +def create_tool_func(tool_name: str, arg_schema: type[BaseModel], session) -> Callable[..., str]: + def tool_func(*args, **kwargs): + field_names = list(arg_schema.__fields__.keys()) + provided_args = {} + for i, arg in enumerate(args): + if i >= len(field_names): + msg = "Too many positional arguments provided" + raise ValueError(msg) + provided_args[field_names[i]] = arg + provided_args.update(kwargs) + try: + validated = arg_schema.parse_obj(provided_args) + except Exception as e: + msg = f"Invalid input: {e}" + raise ValueError(msg) from e loop = asyncio.get_event_loop() - return loop.run_until_complete(session.call_tool(tool_name, arguments=kwargs)) + return loop.run_until_complete(session.call_tool(tool_name, arguments=validated.dict())) return tool_func +async def get_flow(flow_name: str, user_id: str, session) -> Flow | None: + uuid_user_id = UUID(user_id) if isinstance(user_id, str) else user_id + stmt = select(Flow).where(Flow.user_id == uuid_user_id).where(Flow.is_component == False) # noqa: E712 + flows = (await session.exec(stmt)).all() + + for flow in flows: + if flow.to_data().name == flow_name: + return flow + return None + + def create_input_schema_from_json_schema(schema: dict[str, Any]) -> type[BaseModel]: """Converts a JSON schema into a Pydantic model dynamically. + Fields not listed as required are wrapped in Optional[...] and default to None if not provided. + :param schema: The JSON schema as a dictionary. :return: A Pydantic model class. """ @@ -47,9 +88,9 @@ def create_input_schema_from_json_schema(schema: dict[str, Any]) -> type[BaseMod required_fields = set(schema.get("required", [])) for field_name, field_def in properties.items(): - # Extract type - field_type_str = field_def.get("type", "str") # Default to string type if not specified - field_type = { + # Determine the base type from the JSON schema type string. + field_type_str = field_def.get("type", "str") # Defaults to string if not specified. + base_type = { "string": str, "str": str, "integer": int, @@ -60,13 +101,77 @@ def create_input_schema_from_json_schema(schema: dict[str, Any]) -> type[BaseMod "object": dict, }.get(field_type_str, Any) - # Extract description and default if present field_metadata = {"description": field_def.get("description", "")} + + # For non-required fields, wrap the type in Optional[...] and set a default value. if field_name not in required_fields: field_metadata["default"] = field_def.get("default", None) - # Create Pydantic field - fields[field_name] = (field_type, Field(**field_metadata)) + fields[field_name] = (base_type, Field(**field_metadata)) - # Dynamically create the model return create_model("InputSchema", **fields) + + +class MCPStdioClient: + def __init__(self): + self.session: ClientSession | None = None + self.exit_stack = AsyncExitStack() + + async def connect_to_server(self, command_str: str): + command = command_str.split(" ") + server_params = StdioServerParameters( + command=command[0], + args=command[1:], + env={"DEBUG": "true", "PATH": os.environ["PATH"]}, + ) + stdio_transport = await self.exit_stack.enter_async_context(stdio_client(server_params)) + self.stdio, self.write = stdio_transport + self.session = await self.exit_stack.enter_async_context(ClientSession(self.stdio, self.write)) + await self.session.initialize() + response = await self.session.list_tools() + return response.tools + + +class MCPSseClient: + def __init__(self): + self.write = None + self.sse = None + self.session: ClientSession | None = None + self.exit_stack = AsyncExitStack() + + async def pre_check_redirect(self, url: str): + async with httpx.AsyncClient(follow_redirects=False) as client: + response = await client.request("HEAD", url) + if response.status_code == httpx.codes.TEMPORARY_REDIRECT: + return response.headers.get("Location") + return url + + async def _connect_with_timeout( + self, url: str, headers: dict[str, str] | None, timeout_seconds: int, sse_read_timeout_seconds: int + ): + sse_transport = await self.exit_stack.enter_async_context( + sse_client(url, headers, timeout_seconds, sse_read_timeout_seconds) + ) + self.sse, self.write = sse_transport + self.session = await self.exit_stack.enter_async_context(ClientSession(self.sse, self.write)) + await self.session.initialize() + + async def connect_to_server( + self, url: str, headers: dict[str, str] | None, timeout_seconds: int = 500, sse_read_timeout_seconds: int = 500 + ): + if headers is None: + headers = {} + url = await self.pre_check_redirect(url) + try: + await asyncio.wait_for( + self._connect_with_timeout(url, headers, timeout_seconds, sse_read_timeout_seconds), + timeout=timeout_seconds, + ) + if self.session is None: + msg = "Session not initialized" + raise ValueError(msg) + response = await self.session.list_tools() + except asyncio.TimeoutError as err: + msg = f"Connection to {url} timed out after {timeout_seconds} seconds" + raise TimeoutError(msg) from err + return response.tools diff --git a/src/backend/base/langflow/components/deactivated/mcp_sse.py b/src/backend/base/langflow/components/deactivated/mcp_sse.py new file mode 100644 index 000000000..e3588b2f1 --- /dev/null +++ b/src/backend/base/langflow/components/deactivated/mcp_sse.py @@ -0,0 +1,61 @@ +# from langflow.field_typing import Data + +from langchain_core.tools import StructuredTool +from mcp import types + +from langflow.base.mcp.util import ( + MCPSseClient, + create_input_schema_from_json_schema, + create_tool_coroutine, + create_tool_func, +) +from langflow.custom import Component +from langflow.field_typing import Tool +from langflow.io import MessageTextInput, Output + + +class MCPSse(Component): + client = MCPSseClient() + tools = types.ListToolsResult + tool_names = [str] + display_name = "MCP Tools (SSE) [DEPRECATED]" + description = "Connects to an MCP server over SSE and exposes it's tools as langflow tools to be used by an Agent." + documentation: str = "https://docs.langflow.org/components-custom-components" + icon = "code" + name = "MCPSse" + legacy = True + + inputs = [ + MessageTextInput( + name="url", + display_name="mcp sse url", + info="sse url", + value="http://localhost:7860/api/v1/mcp/sse", + tool_mode=True, + ), + ] + + outputs = [ + Output(display_name="Tools", name="tools", method="build_output"), + ] + + async def build_output(self) -> list[Tool]: + if self.client.session is None: + self.tools = await self.client.connect_to_server(self.url, {}) + + tool_list = [] + + for tool in self.tools: + args_schema = create_input_schema_from_json_schema(tool.inputSchema) + tool_list.append( + StructuredTool( + name=tool.name, # maybe format this + description=tool.description, + args_schema=args_schema, + func=create_tool_func(tool.name, args_schema, self.client.session), + coroutine=create_tool_coroutine(tool.name, args_schema, self.client.session), + ) + ) + + self.tool_names = [tool.name for tool in self.tools] + return tool_list diff --git a/src/backend/base/langflow/components/tools/mcp_stdio.py b/src/backend/base/langflow/components/deactivated/mcp_stdio.py similarity index 59% rename from src/backend/base/langflow/components/tools/mcp_stdio.py rename to src/backend/base/langflow/components/deactivated/mcp_stdio.py index 48535a00d..e68d47f0d 100644 --- a/src/backend/base/langflow/components/tools/mcp_stdio.py +++ b/src/backend/base/langflow/components/deactivated/mcp_stdio.py @@ -1,51 +1,31 @@ # from langflow.field_typing import Data -import os -from contextlib import AsyncExitStack from langchain_core.tools import StructuredTool -from mcp import ClientSession, StdioServerParameters, types -from mcp.client.stdio import stdio_client +from mcp import types -from langflow.base.mcp.util import create_input_schema_from_json_schema, create_tool_coroutine, create_tool_func +from langflow.base.mcp.util import ( + MCPStdioClient, + create_input_schema_from_json_schema, + create_tool_coroutine, + create_tool_func, +) from langflow.custom import Component from langflow.field_typing import Tool from langflow.io import MessageTextInput, Output -class MCPStdioClient: - def __init__(self): - # Initialize session and client objects - self.session: ClientSession | None = None - self.exit_stack = AsyncExitStack() - - async def connect_to_server(self, command_str: str): - command = command_str.split(" ") - server_params = StdioServerParameters( - command=command[0], args=command[1:], env={"DEBUG": "true", "PATH": os.environ["PATH"]} - ) - - stdio_transport = await self.exit_stack.enter_async_context(stdio_client(server_params)) - self.stdio, self.write = stdio_transport - self.session = await self.exit_stack.enter_async_context(ClientSession(self.stdio, self.write)) - - await self.session.initialize() - - # List available tools - response = await self.session.list_tools() - return response.tools - - class MCPStdio(Component): client = MCPStdioClient() tools = types.ListToolsResult tool_names = [str] - display_name = "MCP Tools (stdio)" + display_name = "MCP Tools (stdio) [DEPRECATED]" description = ( "Connects to an MCP server over stdio and exposes it's tools as langflow tools to be used by an Agent." ) documentation: str = "https://docs.langflow.org/components-custom-components" icon = "code" name = "MCPStdio" + legacy = True inputs = [ MessageTextInput( @@ -74,7 +54,7 @@ class MCPStdio(Component): name=tool.name, description=tool.description, args_schema=args_schema, - func=create_tool_func(tool.name, args_schema), + func=create_tool_func(tool.name, args_schema, self.client.session), coroutine=create_tool_coroutine(tool.name, args_schema, self.client.session), ) ) diff --git a/src/backend/base/langflow/components/tools/__init__.py b/src/backend/base/langflow/components/tools/__init__.py index 78820ed0f..e7b67cfbc 100644 --- a/src/backend/base/langflow/components/tools/__init__.py +++ b/src/backend/base/langflow/components/tools/__init__.py @@ -13,7 +13,7 @@ from .google_search_api import GoogleSearchAPIComponent from .google_search_api_core import GoogleSearchAPICore from .google_serper_api import GoogleSerperAPIComponent from .google_serper_api_core import GoogleSerperAPICore -from .mcp_stdio import MCPStdio +from .mcp_component import MCPToolsComponent from .python_code_structured_tool import PythonCodeStructuredTool from .python_repl import PythonREPLToolComponent from .python_repl_core import PythonREPLComponent @@ -51,7 +51,7 @@ __all__ = [ "GoogleSearchAPICore", "GoogleSerperAPIComponent", "GoogleSerperAPICore", - "MCPStdio", + "MCPToolsComponent", "PythonCodeStructuredTool", "PythonREPLComponent", "PythonREPLToolComponent", diff --git a/src/backend/base/langflow/components/tools/mcp_component.py b/src/backend/base/langflow/components/tools/mcp_component.py new file mode 100644 index 000000000..ef32db397 --- /dev/null +++ b/src/backend/base/langflow/components/tools/mcp_component.py @@ -0,0 +1,340 @@ +import asyncio +from typing import Any + +from langchain_core.tools import StructuredTool + +from langflow.base.mcp.util import ( + MCPSseClient, + MCPStdioClient, + create_input_schema_from_json_schema, + create_tool_coroutine, + create_tool_func, +) +from langflow.custom import Component +from langflow.inputs import DropdownInput +from langflow.inputs.inputs import InputTypes +from langflow.io import MessageTextInput, Output, TabInput +from langflow.io.schema import schema_to_langflow_inputs +from langflow.logging import logger +from langflow.schema import Message + + +class MCPToolsComponent(Component): + schema_inputs: list[InputTypes] = [] + stdio_client = MCPStdioClient() + sse_client = MCPSseClient() + tools: list = [] + tool_names: list[str] = [] + _tool_cache: dict = {} # Cache for tool objects + default_keys = ["code", "_type", "mode", "command", "sse_url", "tool_placeholder", "tool_mode", "tool"] + + display_name = "MCP Server" + description = "Connect to an MCP server and expose tools." + icon = "server" + name = "MCPTools" + + inputs = [ + TabInput( + name="mode", + display_name="Mode", + options=["Stdio", "SSE"], + value="Stdio", + info="Select the connection mode", + real_time_refresh=True, + ), + MessageTextInput( + name="command", + display_name="MCP Command", + info="Command for MCP stdio connection", + value="uvx mcp-server-fetch", + show=True, + refresh_button=True, + ), + MessageTextInput( + name="sse_url", + display_name="MCP SSE URL", + info="URL for MCP SSE connection", + value="http://localhost:7860/api/v1/mcp/sse", + show=False, + refresh_button=True, + ), + DropdownInput( + name="tool", + display_name="Tool", + options=[], + value="", + info="Select the tool to execute", + show=True, + required=True, + real_time_refresh=True, + ), + MessageTextInput( + name="tool_placeholder", + display_name="Tool Placeholder", + info="Placeholder for the tool", + value="", + show=False, + tool_mode=True, + ), + ] + + outputs = [ + Output(display_name="Response", name="response", method="build_output"), + ] + + async def _validate_connection_params(self, mode: str, command: str | None = None, url: str | None = None) -> None: + """Validate connection parameters based on mode.""" + if mode not in ["Stdio", "SSE"]: + msg = f"Invalid mode: {mode}. Must be either 'Stdio' or 'SSE'" + raise ValueError(msg) + + if mode == "Stdio" and not command: + msg = "Command is required for Stdio mode" + raise ValueError(msg) + if mode == "SSE" and not url: + msg = "URL is required for SSE mode" + raise ValueError(msg) + + async def _validate_schema_inputs(self, tool_obj) -> list[InputTypes]: + """Validate and process schema inputs for a tool.""" + try: + if not tool_obj or not hasattr(tool_obj, "inputSchema"): + msg = "Invalid tool object or missing input schema" + raise ValueError(msg) + + input_schema = create_input_schema_from_json_schema(tool_obj.inputSchema) + if not input_schema: + msg = f"Empty input schema for tool '{tool_obj.name}'" + raise ValueError(msg) + + schema_inputs = schema_to_langflow_inputs(input_schema) + if not schema_inputs: + msg = f"No input parameters defined for tool '{tool_obj.name}'" + logger.warning(msg) + return [] + + except Exception as e: + msg = f"Error validating schema inputs: {e!s}" + logger.exception(msg) + raise ValueError(msg) from e + else: + return schema_inputs + + async def update_build_config(self, build_config: dict, field_value: str, field_name: str | None = None) -> dict: + """Toggle the visibility of connection-specific fields based on the selected mode.""" + try: + if field_name == "mode": + self.remove_non_default_keys(build_config) + if field_value == "Stdio": + build_config["command"]["show"] = True + build_config["sse_url"]["show"] = False + elif field_value == "SSE": + build_config["command"]["show"] = False + build_config["sse_url"]["show"] = True + if field_name in ("command", "sse_url", "mode"): + try: + await self.update_tools() + if "tool" in build_config: + build_config["tool"]["options"] = self.tool_names + except Exception as e: + build_config["tool"]["options"] = [] + msg = f"Failed to update tools: {e!s}" + raise ValueError(msg) from e + elif field_name == "tool": + if len(self.tools) == 0: + await self.update_tools() + if self.tool is None: + return build_config + tool_obj = None + for tool in self.tools: + if tool.name == self.tool: + tool_obj = tool + break + if tool_obj is None: + msg = f"Tool {self.tool} not found in available tools: {self.tools}" + logger.warning(msg) + return build_config + self.remove_non_default_keys(build_config) + await self._update_tool_config(build_config, field_value) + elif field_name == "tool_mode": + build_config["tool"]["show"] = not field_value + for key, value in list(build_config.items()): + if key not in self.default_keys and isinstance(value, dict) and "show" in value: + build_config[key]["show"] = not field_value + + except Exception as e: + msg = f"Error in update_build_config: {e!s}" + logger.exception(msg) + raise ValueError(msg) from e + else: + return build_config + + def get_inputs_for_all_tools(self, tools: list) -> dict: + """Get input schemas for all tools.""" + inputs = {} + for tool in tools: + if not tool or not hasattr(tool, "name"): + continue + try: + input_schema = schema_to_langflow_inputs(create_input_schema_from_json_schema(tool.inputSchema)) + inputs[tool.name] = input_schema + except (AttributeError, ValueError, TypeError, KeyError) as e: + msg = f"Error getting inputs for tool {getattr(tool, 'name', 'unknown')}: {e!s}" + logger.exception(msg) + continue + return inputs + + def remove_input_schema_from_build_config( + self, build_config: dict, tool_name: str, input_schema: dict[list[InputTypes], Any] + ): + """Remove the input schema for the tool from the build config.""" + # Keep only schemas that don't belong to the current tool + input_schema = {k: v for k, v in input_schema.items() if k != tool_name} + # Remove all inputs from other tools + for value in input_schema.values(): + for _input in value: + if _input.name in build_config: + build_config.pop(_input.name) + + def remove_non_default_keys(self, build_config: dict) -> None: + """Remove non-default keys from the build config.""" + for key in list(build_config.keys()): + if key not in self.default_keys: + build_config.pop(key) + + async def _update_tool_config(self, build_config: dict, tool_name: str) -> None: + """Update tool configuration with proper error handling.""" + if not self.tools: + await self.update_tools() + + if not tool_name: + return + + tool_obj = next((tool for tool in self.tools if tool.name == tool_name), None) + if not tool_obj: + msg = f"Tool {tool_name} not found in available tools: {self.tools}" + logger.warning(msg) + return + + try: + # Get all tool inputs and remove old ones + input_schema_for_all_tools = self.get_inputs_for_all_tools(self.tools) + self.remove_input_schema_from_build_config(build_config, tool_name, input_schema_for_all_tools) + + # Get and validate new inputs + self.schema_inputs = await self._validate_schema_inputs(tool_obj) + if not self.schema_inputs: + msg = f"No input parameters to configure for tool '{tool_name}'" + logger.info(msg) + return + + # Add new inputs to build config + for schema_input in self.schema_inputs: + if not schema_input or not hasattr(schema_input, "name"): + msg = "Invalid schema input detected, skipping" + logger.warning(msg) + continue + + try: + name = schema_input.name + input_dict = schema_input.to_dict() + input_dict.setdefault("value", None) + input_dict.setdefault("required", True) + build_config[name] = input_dict + except (AttributeError, KeyError, TypeError) as e: + msg = f"Error processing schema input {schema_input}: {e!s}" + logger.exception(msg) + continue + + except ValueError as e: + msg = f"Schema validation error for tool {tool_name}: {e!s}" + logger.exception(msg) + self.schema_inputs = [] + return + except (AttributeError, KeyError, TypeError) as e: + msg = f"Error updating tool config: {e!s}" + logger.exception(msg) + raise ValueError(msg) from e + + async def build_output(self) -> Message: + """Build output with improved error handling and validation.""" + try: + await self.update_tools() + if self.tool != "": + exec_tool = self._tool_cache[self.tool] + tool_args = self.get_inputs_for_all_tools(self.tools)[self.tool] + kwargs = {} + for arg in tool_args: + value = getattr(self, arg.name, None) + if value: + kwargs[arg.name] = value + output = await exec_tool.coroutine(**kwargs) + return Message(text=output.content[len(output.content) - 1].text) + return Message(text="You must select a tool", error=True) + except Exception as e: + msg = f"Error in build_output: {e!s}" + logger.exception(msg) + raise ValueError(msg) from e + + async def update_tools(self) -> list[StructuredTool]: + """Connect to the MCP server and update available tools with improved error handling.""" + try: + await self._validate_connection_params(self.mode, self.command, self.sse_url) + + if self.mode == "Stdio": + if not self.stdio_client.session: + self.tools = await self.stdio_client.connect_to_server(self.command) + elif self.mode == "SSE" and not self.sse_client.session: + self.tools = await self.sse_client.connect_to_server(self.sse_url, {}) + + if not self.tools: + logger.warning("No tools returned from server") + return [] + + tool_list = [] + for tool in self.tools: + if not tool or not hasattr(tool, "name"): + logger.warning("Invalid tool object detected, skipping") + continue + + try: + args_schema = create_input_schema_from_json_schema(tool.inputSchema) + if not args_schema: + msg = f"Empty schema for tool '{tool.name}', skipping" + logger.warning(msg) + continue + + client = self.stdio_client if self.mode == "Stdio" else self.sse_client + if not client or not client.session: + msg = f"Invalid client session for tool '{tool.name}'" + raise ValueError(msg) + + tool_obj = StructuredTool( + name=tool.name, + description=tool.description or "", + args_schema=args_schema, + func=create_tool_func(tool.name, args_schema, client.session), + coroutine=create_tool_coroutine(tool.name, args_schema, client.session), + tags=[tool.name], + ) + tool_list.append(tool_obj) + self._tool_cache[tool.name] = tool_obj + except (AttributeError, ValueError, TypeError, KeyError) as e: + msg = f"Error creating tool {getattr(tool, 'name', 'unknown')}: {e!s}" + logger.exception(msg) + continue + + self.tool_names = [tool.name for tool in self.tools if hasattr(tool, "name")] + + except (ValueError, RuntimeError, asyncio.TimeoutError) as e: + msg = f"Error updating tools: {e!s}" + logger.exception(msg) + raise ValueError(msg) from e + else: + return tool_list + + async def _get_tools(self): + """Get cached tools or update if necessary.""" + if not self.tools: + return await self.update_tools() + return self.tools diff --git a/src/backend/base/langflow/components/tools/mcp_sse.py b/src/backend/base/langflow/components/tools/mcp_sse.py deleted file mode 100644 index 5a184fe7f..000000000 --- a/src/backend/base/langflow/components/tools/mcp_sse.py +++ /dev/null @@ -1,111 +0,0 @@ -# from langflow.field_typing import Data -import asyncio -from contextlib import AsyncExitStack - -import httpx -from langchain_core.tools import StructuredTool -from mcp import ClientSession, types -from mcp.client.sse import sse_client - -from langflow.base.mcp.util import create_input_schema_from_json_schema, create_tool_coroutine, create_tool_func -from langflow.custom import Component -from langflow.field_typing import Tool -from langflow.io import MessageTextInput, Output - -# Define constant for status code -HTTP_TEMPORARY_REDIRECT = 307 - - -class MCPSseClient: - def __init__(self): - # Initialize session and client objects - self.write = None - self.sse = None - self.session: ClientSession | None = None - self.exit_stack = AsyncExitStack() - - async def pre_check_redirect(self, url: str): - """Check if the URL responds with a 307 Redirect.""" - async with httpx.AsyncClient(follow_redirects=False) as client: - response = await client.request("HEAD", url) - if response.status_code == HTTP_TEMPORARY_REDIRECT: - return response.headers.get("Location") # Return the redirect URL - return url # Return the original URL if no redirect - - async def _connect_with_timeout( - self, url: str, headers: dict[str, str] | None, timeout_seconds: int, sse_read_timeout_seconds: int - ): - """Connect to the SSE server with timeout.""" - sse_transport = await self.exit_stack.enter_async_context( - sse_client(url, headers, timeout_seconds, sse_read_timeout_seconds) - ) - self.sse, self.write = sse_transport - self.session = await self.exit_stack.enter_async_context(ClientSession(self.sse, self.write)) - await self.session.initialize() - - async def connect_to_server( - self, url: str, headers: dict[str, str] | None, timeout_seconds: int = 500, sse_read_timeout_seconds: int = 500 - ): - if headers is None: - headers = {} - url = await self.pre_check_redirect(url) - try: - await asyncio.wait_for( - self._connect_with_timeout(url, headers, timeout_seconds, sse_read_timeout_seconds), - timeout=timeout_seconds, - ) - # List available tools - if self.session is None: - msg = "Session not initialized" - raise ValueError(msg) - response = await self.session.list_tools() - except asyncio.TimeoutError as err: - error_message = f"Connection to {url} timed out after {timeout_seconds} seconds" - raise TimeoutError(error_message) from err - return response.tools - - -class MCPSse(Component): - client = MCPSseClient() - tools = types.ListToolsResult - tool_names = [str] - display_name = "MCP Tools (SSE)" - description = "Connects to an MCP server over SSE and exposes it's tools as langflow tools to be used by an Agent." - documentation: str = "https://docs.langflow.org/components-custom-components" - icon = "code" - name = "MCPSse" - - inputs = [ - MessageTextInput( - name="url", - display_name="mcp sse url", - info="sse url", - value="http://localhost:7860/api/v1/mcp/sse", - tool_mode=True, - ), - ] - - outputs = [ - Output(display_name="Tools", name="tools", method="build_output"), - ] - - async def build_output(self) -> list[Tool]: - if self.client.session is None: - self.tools = await self.client.connect_to_server(self.url, {}) - - tool_list = [] - - for tool in self.tools: - args_schema = create_input_schema_from_json_schema(tool.inputSchema) - tool_list.append( - StructuredTool( - name=tool.name, # maybe format this - description=tool.description, - args_schema=args_schema, - func=create_tool_func(tool.name, self.client.session), - coroutine=create_tool_coroutine(tool.name, args_schema, self.client.session), - ) - ) - - self.tool_names = [tool.name for tool in self.tools] - return tool_list diff --git a/src/backend/base/langflow/io/schema.py b/src/backend/base/langflow/io/schema.py index 61f6a33a8..1915b9230 100644 --- a/src/backend/base/langflow/io/schema.py +++ b/src/backend/base/langflow/io/schema.py @@ -1,8 +1,8 @@ -from typing import TYPE_CHECKING, Literal +from typing import Literal, Union, get_args, get_origin from pydantic import BaseModel, Field, create_model -from langflow.inputs.inputs import FieldTypes +from langflow.inputs.inputs import BoolInput, DictInput, FieldTypes, FloatInput, InputTypes, IntInput, MessageTextInput from langflow.schema.dotdict import dotdict _convert_field_type_to_type: dict[FieldTypes, type] = { @@ -20,8 +20,65 @@ _convert_field_type_to_type: dict[FieldTypes, type] = { FieldTypes.TAB: str, } -if TYPE_CHECKING: - from langflow.inputs.inputs import InputTypes + +_convert_type_to_field_type = { + str: MessageTextInput, + int: IntInput, + float: FloatInput, + bool: BoolInput, + dict: DictInput, + list: MessageTextInput, +} + + +def schema_to_langflow_inputs(schema: type[BaseModel]) -> list["InputTypes"]: + """Given a Pydantic schema, convert its fields to Langflow input definitions.""" + inputs = [] + for field_name, model_field in schema.model_fields.items(): + # Start with the field's annotation type + field_type = model_field.annotation + is_list = False + options = None + + # If the field is a list, record that and extract its inner type. + if get_origin(field_type) is list: + is_list = True + field_type = get_args(field_type)[0] + + # If the field type is a Literal, extract its allowed values. + if get_origin(field_type) is Literal: + options = list(get_args(field_type)) + # Optionally, set field_type to the type of the literal values. + if options: + field_type = type(options[0]) + + # Handle Union types (e.g., Optional fields) + if get_origin(field_type) is Union: + # Get the first non-None type from the Union + field_type = next(t for t in get_args(field_type) if t is not type(None)) + + # Convert the Python type to the Langflow field type using our reverse mapping. + try: + langflow_field_type = _convert_type_to_field_type[field_type] + except KeyError as e: + msg = f"Unsupported field type: {field_type}" + raise TypeError(msg) from e + + # Get metadata from the Pydantic Field. + title = model_field.title or field_name.replace("_", " ").title() + description = model_field.description or "" + required = model_field.is_required() + + # Construct the Langflow input. + input_obj = langflow_field_type( + display_name=title, + name=field_name, + info=description, + required=required, + is_list=is_list, + ) + inputs.append(input_obj) + return inputs def create_input_schema(inputs: list["InputTypes"]) -> type[BaseModel]: diff --git a/src/backend/tests/integration/components/mcp/__init__.py b/src/backend/tests/integration/components/mcp/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/src/backend/tests/integration/components/mcp/test_mcp_component.py b/src/backend/tests/integration/components/mcp/test_mcp_component.py new file mode 100644 index 000000000..0e2338897 --- /dev/null +++ b/src/backend/tests/integration/components/mcp/test_mcp_component.py @@ -0,0 +1,11 @@ +from tests.integration.utils import run_single_component + + +async def test_mcp_component(): + from langflow.components.tools.mcp_component import MCPToolsComponent + + inputs = {} + await run_single_component( + MCPToolsComponent, + inputs=inputs, # test default inputs + ) diff --git a/src/backend/tests/unit/components/tools/test_mcp_component.py b/src/backend/tests/unit/components/tools/test_mcp_component.py new file mode 100644 index 000000000..3ec0c0f37 --- /dev/null +++ b/src/backend/tests/unit/components/tools/test_mcp_component.py @@ -0,0 +1,255 @@ +import asyncio +from unittest.mock import AsyncMock, MagicMock, patch + +import pytest +from langflow.components.tools.mcp_component import MCPSseClient, MCPStdioClient, MCPToolsComponent + +from tests.base import ComponentTestBaseWithoutClient, VersionComponentMapping + + +class TestMCPToolsComponent(ComponentTestBaseWithoutClient): + @pytest.fixture + def component_class(self): + """Return the component class to test.""" + return MCPToolsComponent + + @pytest.fixture + def default_kwargs(self): + """Return the default kwargs for the component.""" + return { + "mode": "Stdio", + "command": "uvx mcp-server-fetch", + "sse_url": "http://localhost:7860/api/v1/mcp/sse", + "tool": "", + } + + @pytest.fixture + def file_names_mapping(self) -> list[VersionComponentMapping]: + """Return the file names mapping for different versions.""" + return [] + + @pytest.fixture + def mock_tool(self): + """Create a mock MCP tool.""" + tool = MagicMock() + tool.name = "test_tool" + tool.description = "Test tool description" + tool.inputSchema = { + "type": "object", + "properties": {"test_param": {"type": "string", "description": "Test parameter"}}, + } + return tool + + @pytest.fixture + def mock_stdio_client(self, mock_tool): + """Create a mock stdio client.""" + stdio_client = AsyncMock() + stdio_client.connect_to_server = AsyncMock(return_value=[mock_tool]) + stdio_client.session = AsyncMock() + return stdio_client + + @pytest.fixture + def mock_sse_client(self, mock_tool): + """Create a mock SSE client.""" + sse_client = AsyncMock() + sse_client.connect_to_server = AsyncMock(return_value=[mock_tool]) + sse_client.session = AsyncMock() + return sse_client + + async def test_validate_connection_params_invalid_mode(self, component_class, default_kwargs): + """Test validation with invalid mode.""" + component = component_class(**default_kwargs) + with pytest.raises(ValueError, match="Invalid mode: invalid. Must be either 'Stdio' or 'SSE'"): + await component._validate_connection_params("invalid") + + async def test_validate_connection_params_missing_command(self, component_class, default_kwargs): + """Test validation with missing command in Stdio mode.""" + component = component_class(**default_kwargs) + with pytest.raises(ValueError, match="Command is required for Stdio mode"): + await component._validate_connection_params("Stdio", command=None) + + async def test_validate_connection_params_missing_url(self, component_class, default_kwargs): + """Test validation with missing URL in SSE mode.""" + component = component_class(**default_kwargs) + with pytest.raises(ValueError, match="URL is required for SSE mode"): + await component._validate_connection_params("SSE", url=None) + + async def test_update_build_config_mode_change(self, component_class, default_kwargs): + """Test build config updates when mode changes.""" + component = component_class(**default_kwargs) + build_config = { + "command": {"show": False}, + "sse_url": {"show": True}, + "tool": {"options": [], "show": True}, # Add tool field since component uses it + } + + # Test switching to Stdio mode + updated_config = await component.update_build_config(build_config, "Stdio", "mode") + assert updated_config["command"]["show"] is True + assert updated_config["sse_url"]["show"] is False + + # Test switching to SSE mode + updated_config = await component.update_build_config(build_config, "SSE", "mode") + assert updated_config["command"]["show"] is False + assert updated_config["sse_url"]["show"] is True + + # Test tool options are updated + assert "options" in updated_config["tool"] + + @patch("langflow.components.tools.mcp_component.create_tool_coroutine") + async def test_build_output(self, mock_create_coroutine, component_class, default_kwargs, mock_tool): + """Test building output with a tool.""" + component = component_class(**default_kwargs) + component.tool = "test_tool" + component.tools = [mock_tool] + + # Mock the coroutine response + mock_response = AsyncMock() + mock_response.content = [MagicMock(text="Test response")] + mock_create_coroutine.return_value = AsyncMock(return_value=mock_response) + + # Create a mock tool and add it to the cache + mock_structured_tool = MagicMock() + mock_structured_tool.coroutine = mock_create_coroutine.return_value + component._tool_cache = {"test_tool": mock_structured_tool} + + # Set the test parameter value + component.test_param = "test value" + + # Mock get_inputs_for_all_tools to return our mock input + mock_input = MagicMock() + mock_input.name = "test_param" + with patch.object(component, "get_inputs_for_all_tools") as mock_get_inputs: + mock_get_inputs.return_value = {"test_tool": [mock_input]} + output = await component.build_output() + + assert output.text == "Test response" + # Verify the mocks were called correctly + mock_get_inputs.assert_called_once_with(component.tools) + mock_structured_tool.coroutine.assert_called_once_with(test_param="test value") + + async def test_get_inputs_for_all_tools(self, component_class, default_kwargs, mock_tool): + """Test getting input schemas for all tools.""" + component = component_class(**default_kwargs) + inputs = component.get_inputs_for_all_tools([mock_tool]) + + assert "test_tool" in inputs + assert len(inputs["test_tool"]) > 0 # Should have at least one input parameter + + async def test_remove_non_default_keys(self, component_class, default_kwargs): + """Test removing non-default keys from build config.""" + component = component_class(**default_kwargs) + build_config = {"code": {}, "mode": {}, "command": {}, "custom_key": {}} + + component.remove_non_default_keys(build_config) + assert "custom_key" not in build_config + assert all(key in build_config for key in ["code", "mode", "command"]) + + +class TestMCPStdioClient: + @pytest.fixture + def stdio_client(self): + return MCPStdioClient() + + async def test_connect_to_server(self, stdio_client): + """Test connecting to server via Stdio.""" + # Create mock for stdio transport + mock_stdio = AsyncMock() + mock_write = AsyncMock() + mock_stdio_transport = (mock_stdio, mock_write) + mock_stdio_cm = AsyncMock() + mock_stdio_cm.__aenter__.return_value = mock_stdio_transport + + # Mock the stdio_client function to return our mock context manager + with patch("mcp.client.stdio.stdio_client", return_value=mock_stdio_cm): + # Mock ClientSession + mock_session = AsyncMock() + mock_session.initialize = AsyncMock() + mock_session.list_tools.return_value.tools = [MagicMock()] + + # Mock the AsyncExitStack + mock_exit_stack = AsyncMock() + mock_exit_stack.enter_async_context = AsyncMock() + mock_exit_stack.enter_async_context.side_effect = [ + mock_stdio_transport, # For stdio_client + mock_session, # For ClientSession + ] + stdio_client.exit_stack = mock_exit_stack + + tools = await stdio_client.connect_to_server("test_command") + + assert len(tools) == 1 + assert stdio_client.session is not None + # Verify the exit stack was used correctly + assert mock_exit_stack.enter_async_context.call_count == 2 + # Verify the stdio transport was properly set + assert stdio_client.stdio == mock_stdio + assert stdio_client.write == mock_write + + +class TestMCPSseClient: + @pytest.fixture + def sse_client(self): + return MCPSseClient() + + async def test_pre_check_redirect(self, sse_client): + """Test pre-checking URL for redirects.""" + test_url = "http://test.url" + redirect_url = "http://redirect.url" + + with patch("httpx.AsyncClient") as mock_client: + mock_response = MagicMock() + mock_response.status_code = 307 + mock_response.headers.get.return_value = redirect_url + mock_client.return_value.__aenter__.return_value.request.return_value = mock_response + + result = await sse_client.pre_check_redirect(test_url) + assert result == redirect_url + + async def test_connect_to_server(self, sse_client): + """Test connecting to server via SSE.""" + # Mock the pre_check_redirect first + with patch.object(sse_client, "pre_check_redirect", return_value="http://test.url"): + # Create mock for sse_client context manager + mock_sse = AsyncMock() + mock_write = AsyncMock() + mock_sse_transport = (mock_sse, mock_write) + mock_sse_cm = AsyncMock() + mock_sse_cm.__aenter__.return_value = mock_sse_transport + + # Mock the sse_client function to return our mock context manager + with patch("mcp.client.sse.sse_client", return_value=mock_sse_cm): + # Mock ClientSession + mock_session = AsyncMock() + mock_session.initialize = AsyncMock() + mock_session.list_tools.return_value.tools = [MagicMock()] + + # Mock the AsyncExitStack + mock_exit_stack = AsyncMock() + mock_exit_stack.enter_async_context = AsyncMock() + mock_exit_stack.enter_async_context.side_effect = [ + mock_sse_transport, # For sse_client + mock_session, # For ClientSession + ] + sse_client.exit_stack = mock_exit_stack + + tools = await sse_client.connect_to_server("http://test.url", {}) + + assert len(tools) == 1 + assert sse_client.session is not None + # Verify the exit stack was used correctly + assert mock_exit_stack.enter_async_context.call_count == 2 + # Verify the SSE transport was properly set + assert sse_client.sse == mock_sse + assert sse_client.write == mock_write + + async def test_connect_timeout(self, sse_client): + """Test connection timeout handling.""" + with ( + patch.object(sse_client, "pre_check_redirect", return_value="http://test.url"), + patch.object(sse_client, "_connect_with_timeout") as mock_connect, + ): + mock_connect.side_effect = asyncio.TimeoutError() + + with pytest.raises(TimeoutError, match="Connection to http://test.url timed out after 1 seconds"): + await sse_client.connect_to_server("http://test.url", {}, timeout_seconds=1) diff --git a/src/backend/tests/unit/test_schema.py b/src/backend/tests/unit/test_schema.py index 97fd930e3..72c8b4571 100644 --- a/src/backend/tests/unit/test_schema.py +++ b/src/backend/tests/unit/test_schema.py @@ -3,11 +3,13 @@ from types import NoneType from typing import Union import pytest +from langflow.inputs.inputs import BoolInput, DictInput, FloatInput, InputTypes, IntInput, MessageTextInput +from langflow.io.schema import schema_to_langflow_inputs from langflow.schema.data import Data from langflow.template import Input, Output from langflow.template.field.base import UNDEFINED from langflow.type_extraction.type_extraction import post_process_type -from pydantic import ValidationError +from pydantic import BaseModel, Field, ValidationError class TestInput: @@ -178,3 +180,65 @@ class TestPostProcessType: pass assert set(post_process_type(Union[CustomType, int])) == {CustomType, int} # noqa: UP007 + + +def test_schema_to_langflow_inputs(): + # Define a test Pydantic model with various field types + class TestSchema(BaseModel): + text_field: str = Field(title="Custom Text Title", description="A text field") + number_field: int = Field(description="A number field") + bool_field: bool = Field(description="A boolean field") + dict_field: dict = Field(description="A dictionary field") + list_field: list[str] = Field(description="A list of strings") + + # Convert schema to Langflow inputs + inputs = schema_to_langflow_inputs(TestSchema) + + # Verify the number of inputs matches the schema fields + assert len(inputs) == 5 + + # Helper function to find input by name + def find_input(name: str) -> InputTypes | None: + for _input in inputs: + if _input.name == name: + return _input + return None + + # Test text field + text_input = find_input("text_field") + assert text_input.display_name == "Custom Text Title" + assert text_input.info == "A text field" + assert isinstance(text_input, MessageTextInput) # Check the instance type instead of field_type + + # Test number field + number_input = find_input("number_field") + assert number_input.display_name == "Number Field" + assert number_input.info == "A number field" + assert isinstance(number_input, IntInput | FloatInput) + + # Test boolean field + bool_input = find_input("bool_field") + assert isinstance(bool_input, BoolInput) + + # Test dictionary field + dict_input = find_input("dict_field") + assert isinstance(dict_input, DictInput) + + # Test list field + list_input = find_input("list_field") + assert list_input.is_list is True + assert isinstance(list_input, MessageTextInput) + + +def test_schema_to_langflow_inputs_invalid_type(): + # Define a schema with an unsupported type + class CustomType: + pass + + class InvalidSchema(BaseModel): + model_config = {"arbitrary_types_allowed": True} # Add this line + invalid_field: CustomType + + # Test that attempting to convert an unsupported type raises TypeError + with pytest.raises(TypeError, match="Unsupported field type:"): + schema_to_langflow_inputs(InvalidSchema) diff --git a/src/frontend/package-lock.json b/src/frontend/package-lock.json index 1ac06ef0a..f3233d206 100644 --- a/src/frontend/package-lock.json +++ b/src/frontend/package-lock.json @@ -706,6 +706,7 @@ }, "node_modules/@clack/prompts/node_modules/is-unicode-supported": { "version": "1.3.0", + "extraneous": true, "inBundle": true, "license": "MIT", "engines": { diff --git a/src/frontend/src/components/common/genericIconComponent/index.tsx b/src/frontend/src/components/common/genericIconComponent/index.tsx index 9b40cb6ff..5f55df1a7 100644 --- a/src/frontend/src/components/common/genericIconComponent/index.tsx +++ b/src/frontend/src/components/common/genericIconComponent/index.tsx @@ -37,7 +37,7 @@ export const ForwardedIconComponent = memo( nodeIconsLucide[ name ?.split("-") - ?.map((x) => String(x[0]).toUpperCase() + String(x).slice(1)) + ?.map((x) => String(x[0])?.toUpperCase() + String(x).slice(1)) ?.join("") ]; if (!TargetIcon) { diff --git a/src/frontend/src/components/common/renderIconComponent/components/renderKey/index.tsx b/src/frontend/src/components/common/renderIconComponent/components/renderKey/index.tsx index 4793596c2..9f5f27719 100644 --- a/src/frontend/src/components/common/renderIconComponent/components/renderKey/index.tsx +++ b/src/frontend/src/components/common/renderIconComponent/components/renderKey/index.tsx @@ -32,7 +32,7 @@ export default function RenderKey({ className={cn(tableRender ? "h-4 w-4" : "h-3 w-3")} /> ) : ( - {value.toUpperCase()} + {value?.toUpperCase()} )} ); diff --git a/src/frontend/src/components/ui/button.tsx b/src/frontend/src/components/ui/button.tsx index 7c51e5a7e..0ae72f3c0 100644 --- a/src/frontend/src/components/ui/button.tsx +++ b/src/frontend/src/components/ui/button.tsx @@ -61,9 +61,11 @@ export interface ButtonProps function toTitleCase(text: string) { return text - .split(" ") - .map((word) => word.charAt(0).toUpperCase() + word.slice(1).toLowerCase()) - .join(" "); + ?.split(" ") + ?.map( + (word) => word?.charAt(0)?.toUpperCase() + word?.slice(1)?.toLowerCase(), + ) + ?.join(" "); } const Button = React.forwardRef( diff --git a/src/frontend/src/constants/constants.ts b/src/frontend/src/constants/constants.ts index d5c708504..04b91654d 100644 --- a/src/frontend/src/constants/constants.ts +++ b/src/frontend/src/constants/constants.ts @@ -1022,3 +1022,18 @@ export const IS_AUTO_LOGIN = export const AUTO_LOGIN_RETRY_DELAY = 2000; export const AUTO_LOGIN_MAX_RETRY_DELAY = 60000; + +export const ALL_LANGUAGES = [ + { value: "en-US", name: "English (US)" }, + { value: "en-GB", name: "English (UK)" }, + { value: "it-IT", name: "Italian" }, + { value: "fr-FR", name: "French" }, + { value: "es-ES", name: "Spanish" }, + { value: "de-DE", name: "German" }, + { value: "ja-JP", name: "Japanese" }, + { value: "pt-BR", name: "Portuguese (Brazil)" }, + { value: "zh-CN", name: "Chinese (Simplified)" }, + { value: "ru-RU", name: "Russian" }, + { value: "ar-SA", name: "Arabic" }, + { value: "hi-IN", name: "Hindi" }, +]; diff --git a/src/frontend/src/modals/IOModal/components/chatView/chatInput/components/voice-assistant/components/audio-settings/audio-settings-dialog.tsx b/src/frontend/src/modals/IOModal/components/chatView/chatInput/components/voice-assistant/components/audio-settings/audio-settings-dialog.tsx index a1dbc7ab7..11b3b4914 100644 --- a/src/frontend/src/modals/IOModal/components/chatView/chatInput/components/voice-assistant/components/audio-settings/audio-settings-dialog.tsx +++ b/src/frontend/src/modals/IOModal/components/chatView/chatInput/components/voice-assistant/components/audio-settings/audio-settings-dialog.tsx @@ -21,21 +21,6 @@ import LanguageSelect from "./components/language-select"; import MicrophoneSelect from "./components/microphone-select"; import VoiceSelect from "./components/voice-select"; -const ALL_LANGUAGES = [ - { value: "en-US", name: "English (US)" }, - { value: "en-GB", name: "English (UK)" }, - { value: "it-IT", name: "Italian" }, - { value: "fr-FR", name: "French" }, - { value: "es-ES", name: "Spanish" }, - { value: "de-DE", name: "German" }, - { value: "ja-JP", name: "Japanese" }, - { value: "pt-BR", name: "Portuguese (Brazil)" }, - { value: "zh-CN", name: "Chinese (Simplified)" }, - { value: "ru-RU", name: "Russian" }, - { value: "ar-SA", name: "Arabic" }, - { value: "hi-IN", name: "Hindi" }, -]; - interface SettingsVoiceModalProps { children?: React.ReactNode; userOpenaiApiKey?: string; @@ -415,7 +400,6 @@ const SettingsVoiceModal = ({ )} diff --git a/src/frontend/src/modals/IOModal/components/chatView/chatInput/components/voice-assistant/components/audio-settings/components/language-select.tsx b/src/frontend/src/modals/IOModal/components/chatView/chatInput/components/voice-assistant/components/audio-settings/components/language-select.tsx index c5ce1ec22..8658004cf 100644 --- a/src/frontend/src/modals/IOModal/components/chatView/chatInput/components/voice-assistant/components/audio-settings/components/language-select.tsx +++ b/src/frontend/src/modals/IOModal/components/chatView/chatInput/components/voice-assistant/components/audio-settings/components/language-select.tsx @@ -1,3 +1,4 @@ +import { ALL_LANGUAGES } from "@/constants/constants"; import IconComponent from "../../../../../../../../../../components/common/genericIconComponent"; import ShadTooltip from "../../../../../../../../../../components/common/shadTooltipComponent"; import { @@ -12,13 +13,11 @@ import { interface LanguageSelectProps { language: string; handleSetLanguage: (value: string) => void; - allLanguages: { value: string; name: string }[]; } const LanguageSelect = ({ language, handleSetLanguage, - allLanguages, }: LanguageSelectProps) => { return (
@@ -41,7 +40,7 @@ const LanguageSelect = ({ - {allLanguages.map((lang) => ( + {ALL_LANGUAGES?.map((lang) => (
{lang?.name} diff --git a/src/frontend/src/utils/utils.ts b/src/frontend/src/utils/utils.ts index 95b1897ca..3f90be599 100644 --- a/src/frontend/src/utils/utils.ts +++ b/src/frontend/src/utils/utils.ts @@ -47,9 +47,9 @@ export function toNormalCase(str: string): string { .split("_") .map((word, index) => { if (index === 0) { - return word[0].toUpperCase() + word.slice(1).toLowerCase(); + return word[0]?.toUpperCase() + word.slice(1)?.toLowerCase(); } - return word.toLowerCase(); + return word?.toLowerCase(); }) .join(" "); @@ -57,9 +57,9 @@ export function toNormalCase(str: string): string { .split("-") .map((word, index) => { if (index === 0) { - return word[0].toUpperCase() + word.slice(1).toLowerCase(); + return word[0]?.toUpperCase() + word.slice(1)?.toLowerCase(); } - return word.toLowerCase(); + return word?.toLowerCase(); }) .join(" "); } @@ -69,11 +69,11 @@ export function normalCaseToSnakeCase(str: string): string { .split(" ") .map((word, index) => { if (index === 0) { - return word[0].toUpperCase() + word.slice(1).toLowerCase(); + return word[0]?.toUpperCase() + word.slice(1)?.toLowerCase(); } - return word.toLowerCase(); + return word?.toLowerCase(); }) - .join("_"); + ?.join("_"); } export function toTitleCase( @@ -82,41 +82,41 @@ export function toTitleCase( ): string { if (!str) return ""; let result = str - .split("_") - .map((word, index) => { + ?.split("_") + ?.map((word, index) => { if (isNodeField) return word; if (index === 0) { return checkUpperWords( - word[0].toUpperCase() + word.slice(1).toLowerCase(), + word[0]?.toUpperCase() + word.slice(1)?.toLowerCase(), ); } - return checkUpperWords(word.toLowerCase()); + return checkUpperWords(word?.toLowerCase()); }) .join(" "); return result - .split("-") - .map((word, index) => { + ?.split("-") + ?.map((word, index) => { if (isNodeField) return word; if (index === 0) { return checkUpperWords( - word[0].toUpperCase() + word.slice(1).toLowerCase(), + word[0]?.toUpperCase() + word.slice(1)?.toLowerCase(), ); } - return checkUpperWords(word.toLowerCase()); + return checkUpperWords(word?.toLowerCase()); }) - .join(" "); + ?.join(" "); } export const upperCaseWords: string[] = ["llm", "uri"]; export function checkUpperWords(str: string): string { - const words = str.split(" ").map((word) => { - return upperCaseWords.includes(word.toLowerCase()) - ? word.toUpperCase() - : word[0].toUpperCase() + word.slice(1).toLowerCase(); + const words = str?.split(" ")?.map((word) => { + return upperCaseWords.includes(word?.toLowerCase()) + ? word?.toUpperCase() + : word[0]?.toUpperCase() + word.slice(1)?.toLowerCase(); }); - return words.join(" "); + return words?.join(" "); } export function buildInputs(): string { diff --git a/src/frontend/tests/core/features/auto-login-off.spec.ts b/src/frontend/tests/core/features/auto-login-off.spec.ts index cc6b94203..bedf0c12f 100644 --- a/src/frontend/tests/core/features/auto-login-off.spec.ts +++ b/src/frontend/tests/core/features/auto-login-off.spec.ts @@ -215,6 +215,8 @@ test( ).isVisible(), ); + await page.waitForTimeout(2000); + await awaitBootstrapTest(page, { skipGoto: true }); await page.getByTestId("side_nav_options_all-templates").click(); diff --git a/src/frontend/tests/core/features/filterEdge-shard-0.spec.ts b/src/frontend/tests/core/features/filterEdge-shard-0.spec.ts index 27acb1338..94c578dff 100644 --- a/src/frontend/tests/core/features/filterEdge-shard-0.spec.ts +++ b/src/frontend/tests/core/features/filterEdge-shard-0.spec.ts @@ -73,9 +73,11 @@ test( } } - await page.waitForTimeout(1000); + await page.waitForTimeout(500); await visibleElementHandle.hover().then(async () => { + await page.waitForTimeout(1000); + await expect( page.getByText("Drag to connect compatible outputs").first(), ).toBeVisible(); @@ -105,7 +107,11 @@ test( } } + await page.waitForTimeout(500); + await visibleElementHandle.hover().then(async () => { + await page.waitForTimeout(1000); + await expect( page.getByText("Drag to connect compatible outputs").first(), ).toBeVisible(); diff --git a/src/frontend/tests/core/features/playground.spec.ts b/src/frontend/tests/core/features/playground.spec.ts index db723bd3d..c1a9e1012 100644 --- a/src/frontend/tests/core/features/playground.spec.ts +++ b/src/frontend/tests/core/features/playground.spec.ts @@ -164,13 +164,21 @@ test( await page.getByTestId("chat-message-User-session_after_delete").click(); await expect(page.getByTestId("session-selector")).toBeVisible(); + await page.waitForTimeout(500); + // check helpful button await page.getByTestId("chat-message-AI-session_after_delete").hover(); await page.getByTestId("helpful-button").click(); + + await page.waitForTimeout(500); + await page.getByTestId("chat-message-AI-session_after_delete").hover(); await expect(page.getByTestId("icon-ThumbUpIconCustom")).toBeVisible({ timeout: 10000, }); + + await page.waitForTimeout(500); + await page.getByTestId("helpful-button").click(); await page.getByTestId("chat-message-AI-session_after_delete").hover(); await expect(page.getByTestId("icon-ThumbUpIconCustom")).toBeVisible({ @@ -178,26 +186,38 @@ test( visible: false, }); // check not helpful button + await page.waitForTimeout(500); + await page.getByTestId("chat-message-AI-session_after_delete").hover(); await page.getByTestId("not-helpful-button").click(); + await page.waitForTimeout(500); + await page.getByTestId("chat-message-AI-session_after_delete").hover(); await expect(page.getByTestId("icon-ThumbDownIconCustom")).toBeVisible({ timeout: 10000, }); await page.getByTestId("not-helpful-button").click(); + await page.waitForTimeout(500); + await page.getByTestId("chat-message-AI-session_after_delete").hover(); await expect(page.getByTestId("icon-ThumbDownIconCustom")).toBeVisible({ timeout: 10000, visible: false, }); // check switch feedback + await page.waitForTimeout(500); + await page.getByTestId("chat-message-AI-session_after_delete").hover(); await page.getByTestId("helpful-button").click(); + await page.waitForTimeout(500); + await page.getByTestId("chat-message-AI-session_after_delete").hover(); await expect(page.getByTestId("icon-ThumbUpIconCustom")).toBeVisible({ timeout: 10000, }); await page.getByTestId("not-helpful-button").click(); + await page.waitForTimeout(500); + await page.getByTestId("chat-message-AI-session_after_delete").hover(); await expect(page.getByTestId("icon-ThumbDownIconCustom")).toBeVisible({ timeout: 10000, diff --git a/src/frontend/tests/extended/features/edit-tools.spec.ts b/src/frontend/tests/extended/features/edit-tools.spec.ts new file mode 100644 index 000000000..d137bcceb --- /dev/null +++ b/src/frontend/tests/extended/features/edit-tools.spec.ts @@ -0,0 +1,76 @@ +import { expect, test } from "@playwright/test"; +import { awaitBootstrapTest } from "../../utils/await-bootstrap-test"; +test( + "user should be able to edit tools", + { tag: ["@release"] }, + async ({ page }) => { + await awaitBootstrapTest(page); + + await page.getByTestId("blank-flow").click(); + + await page.getByTestId("sidebar-search-input").click(); + await page.getByTestId("sidebar-search-input").fill("api request"); + + await page.waitForSelector('[data-testid="dataAPI Request"]', { + timeout: 3000, + }); + + await page + .getByTestId("dataAPI Request") + .hover() + .then(async () => { + await page.getByTestId("add-component-button-api-request").click(); + }); + + await page.waitForSelector( + '[data-testid="generic-node-title-arrangement"]', + { + timeout: 3000, + }, + ); + + await page.getByTestId("generic-node-title-arrangement").click(); + + await page.waitForTimeout(500); + + await page.getByTestId("tool-mode-button").click(); + + await page.locator('[data-testid="icon-Hammer"]').nth(1).waitFor({ + timeout: 3000, + state: "visible", + }); + + await page.getByTestId("icon-Hammer").nth(1).click(); + + await page.waitForSelector("text=edit tools", { timeout: 30000 }); + + const rowsCount = await page.getByRole("gridcell").count(); + + expect(rowsCount).toBeGreaterThan(3); + + expect(await page.getByRole("switch").nth(0).isChecked()).toBe(true); + + await page.getByRole("switch").nth(0).click(); + + expect(await page.getByRole("switch").nth(0).isChecked()).toBe(false); + + await page.getByText("Save").last().click(); + + await page.waitForSelector( + '[data-testid="generic-node-title-arrangement"]', + { + timeout: 3000, + }, + ); + + await page.waitForTimeout(500); + + await page.getByTestId("icon-Hammer").nth(1).click(); + + await page.waitForSelector("text=edit tools", { timeout: 30000 }); + + await page.waitForTimeout(500); + + expect(await page.getByRole("switch").nth(0).isChecked()).toBe(false); + }, +);