chore: merge mcp components (#7167)

* take1

* deprecate stdio and sse mcp components

* optionals

* rodrigo fixes

* session management

* update init

* mcp component integration test

* broken

* [autofix.ci] apply automated fixes

* fix url input name

* updated MCP

* Update mcp_component.py

* [autofix.ci] apply automated fixes

* update to the MCP component

* [autofix.ci] apply automated fixes

* mostly working

* [autofix.ci] apply automated fixes

* Update mcp_component.py

* [autofix.ci] apply automated fixes

* update component

* [autofix.ci] apply automated fixes

* Update mcp_component.py

* rename component because Simon

* icon and description for simon

* fix integration test

* fix test

* Update mcp_component.py

* update and basic QoL

* [autofix.ci] apply automated fixes

* refactor clients to util and use flow names not IDs in mcp.py

* integration test

* take out traces

*  (edit-tools.spec.ts): add test for user to be able to edit tools in the frontend application.

* session fix

* fix content output

* ♻️ (util.py): remove redundant constant HTTP_TEMPORARY_REDIRECT and replace its usage with httpx.codes.TEMPORARY_REDIRECT for better code readability and maintainability

* [autofix.ci] apply automated fixes

* 🐛 (utils.ts): fix potential null pointer error when converting words to title case by adding null check before accessing properties

* 🐛 (genericIconComponent/index.tsx): Fix issue with optional chaining in mapping function
🐛 (renderIconComponent/index.tsx): Fix issue with optional chaining in mapping function
🐛 (button.tsx): Fix issue with optional chaining in mapping function
🐛 (utils.ts): Fix issue with optional chaining in mapping functions

* 🐛 (language-select.tsx): Fix potential null pointer error when mapping over allLanguages array

*  (constants.ts): add support for multiple languages in the application by defining an array of language options
♻️ (audio-settings-dialog.tsx, language-select.tsx): refactor to import the array of all languages from constants.ts instead of duplicating it in each file

*  (auto-login-off.spec.ts): add a 2-second delay before continuing the test to ensure proper loading and rendering of elements on the page

* ⬆️ (filterEdge-shard-0.spec.ts): reduce wait time for page interactions to improve test performance
⬆️ (playground.spec.ts): optimize wait times for page interactions to enhance test efficiency

---------

Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com>
Co-authored-by: Edwin Jose <edwin.jose@datastax.com>
Co-authored-by: cristhianzl <cristhian.lousa@gmail.com>
This commit is contained in:
Sebastián Estévez 2025-03-21 18:11:01 -04:00 committed by GitHub
commit 59b2ed7765
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
25 changed files with 1200 additions and 331 deletions

View file

@ -2,11 +2,12 @@ import asyncio
import base64
import json
import logging
import traceback
from collections.abc import Awaitable, Callable
from contextvars import ContextVar
from typing import Annotated
from functools import wraps
from typing import Annotated, Any, ParamSpec, TypeVar
from urllib.parse import quote, unquote, urlparse
from uuid import UUID, uuid4
from uuid import uuid4
import pydantic
from anyio import BrokenResourceError
@ -20,34 +21,43 @@ from starlette.background import BackgroundTasks
from langflow.api.v1.chat import build_flow_and_stream
from langflow.api.v1.schemas import InputValueRequest
from langflow.base.mcp.util import get_flow
from langflow.helpers.flow import json_schema_from_flow
from langflow.services.auth.utils import get_current_active_user
from langflow.services.database.models import Flow, User
from langflow.services.deps import (
get_db_service,
get_session,
get_settings_service,
get_storage_service,
session_scope,
)
from langflow.services.storage.utils import build_content_type_from_extension
logger = logging.getLogger(__name__)
if False:
logger.setLevel(logging.DEBUG)
if not logger.handlers:
handler = logging.StreamHandler()
handler.setLevel(logging.DEBUG)
formatter = logging.Formatter("[%(asctime)s][%(levelname)s] %(message)s")
handler.setFormatter(formatter)
logger.addHandler(handler)
# Enable debug logging for MCP package
mcp_logger = logging.getLogger("mcp")
mcp_logger.setLevel(logging.DEBUG)
if not mcp_logger.handlers:
mcp_logger.addHandler(handler)
T = TypeVar("T")
P = ParamSpec("P")
logger.debug("MCP module loaded - debug logging enabled")
def handle_mcp_errors(func: Callable[P, Awaitable[T]]) -> Callable[P, Awaitable[T]]:
"""Decorator to handle MCP endpoint errors consistently."""
@wraps(func)
async def wrapper(*args: P.args, **kwargs: P.kwargs) -> T:
try:
return await func(*args, **kwargs)
except Exception as e:
msg = f"Error in {func.__name__}: {e!s}"
logger.exception(msg)
raise
return wrapper
async def with_db_session(operation: Callable[[Any], Awaitable[T]]) -> T:
"""Execute an operation within a database session context."""
async with session_scope() as session:
return await operation(session)
class MCPConfig:
@ -88,7 +98,7 @@ async def handle_list_prompts():
async def handle_list_resources():
resources = []
try:
session = await anext(get_session())
db_service = get_db_service()
storage_service = get_storage_service()
settings_service = get_settings_service()
@ -98,31 +108,30 @@ async def handle_list_resources():
base_url = f"http://{host}:{port}".rstrip("/")
flows = (await session.exec(select(Flow))).all()
async with db_service.with_session() as session:
flows = (await session.exec(select(Flow))).all()
for flow in flows:
if flow.id:
try:
files = await storage_service.list_files(flow_id=str(flow.id))
for file_name in files:
# URL encode the filename
safe_filename = quote(file_name)
resource = types.Resource(
uri=f"{base_url}/api/v1/files/{flow.id}/{safe_filename}",
name=file_name,
description=f"File in flow: {flow.name}",
mimeType=build_content_type_from_extension(file_name),
)
resources.append(resource)
except FileNotFoundError as e:
msg = f"Error listing files for flow {flow.id}: {e}"
logger.debug(msg)
continue
for flow in flows:
if flow.id:
try:
files = await storage_service.list_files(flow_id=str(flow.id))
for file_name in files:
# URL encode the filename
safe_filename = quote(file_name)
resource = types.Resource(
uri=f"{base_url}/api/v1/files/{flow.id}/{safe_filename}",
name=file_name,
description=f"File in flow: {flow.name}",
mimeType=build_content_type_from_extension(file_name),
)
resources.append(resource)
except FileNotFoundError as e:
msg = f"Error listing files for flow {flow.id}: {e}"
logger.debug(msg)
continue
except Exception as e:
msg = f"Error in listing resources: {e!s}"
logger.exception(msg)
trace = traceback.format_exc()
logger.exception(trace)
raise
return resources
@ -162,8 +171,6 @@ async def handle_read_resource(uri: str) -> bytes:
except Exception as e:
msg = f"Error reading resource {uri}: {e!s}"
logger.exception(msg)
trace = traceback.format_exc()
logger.exception(trace)
raise
@ -171,47 +178,48 @@ async def handle_read_resource(uri: str) -> bytes:
async def handle_list_tools():
tools = []
try:
session = await anext(get_session())
flows = (await session.exec(select(Flow))).all()
db_service = get_db_service()
async with db_service.with_session() as session:
flows = (await session.exec(select(Flow))).all()
for flow in flows:
if flow.user_id is None:
continue
for flow in flows:
if flow.user_id is None:
continue
tool = types.Tool(
name=str(flow.id), # Use flow.id instead of name
description=f"{flow.name}: {flow.description}"
if flow.description
else f"Tool generated from flow: {flow.name}",
inputSchema=json_schema_from_flow(flow),
)
tools.append(tool)
tool = types.Tool(
name=flow.name,
description=f"{flow.id}: {flow.description}"
if flow.description
else f"Tool generated from flow: {flow.name}",
inputSchema=json_schema_from_flow(flow),
)
tools.append(tool)
except Exception as e:
msg = f"Error in listing tools: {e!s}"
logger.exception(msg)
trace = traceback.format_exc()
logger.exception(trace)
raise
return tools
@server.call_tool()
@handle_mcp_errors
async def handle_call_tool(name: str, arguments: dict) -> list[types.TextContent]:
"""Handle tool execution requests."""
mcp_config = get_mcp_config()
if mcp_config.enable_progress_notifications is None:
settings_service = get_settings_service()
mcp_config.enable_progress_notifications = settings_service.settings.mcp_server_enable_progress_notifications
try:
session = await anext(get_session())
background_tasks = BackgroundTasks()
current_user = current_user_ctx.get()
flow = (await session.exec(select(Flow).where(Flow.id == UUID(name)))).first()
background_tasks = BackgroundTasks()
current_user = current_user_ctx.get()
async def execute_tool(session):
# get flow id from name
flow = await get_flow(name, current_user.id, session)
if not flow:
msg = f"Flow with id '{name}' not found"
msg = f"Flow with name '{name}' not found"
raise ValueError(msg)
flow_id = flow.id
# Process inputs
processed_inputs = dict(arguments)
@ -240,70 +248,66 @@ async def handle_call_tool(name: str, arguments: dict) -> list[types.TextContent
progress += 0.1
await asyncio.sleep(1.0)
except asyncio.CancelledError:
# Send final 100% progress
if mcp_config.enable_progress_notifications:
await server.request_context.session.send_progress_notification(
progress_token=progress_token, progress=1.0, total=1.0
)
raise
db_service = get_db_service()
collected_results = []
async with db_service.with_session():
try:
progress_task = asyncio.create_task(send_progress_updates())
try:
progress_task = asyncio.create_task(send_progress_updates())
response = await build_flow_and_stream(
flow_id=flow_id,
inputs=input_request,
background_tasks=background_tasks,
current_user=current_user,
)
try:
response = await build_flow_and_stream(
flow_id=UUID(name),
inputs=input_request,
background_tasks=background_tasks,
current_user=current_user,
)
async for line in response.body_iterator:
if not line:
continue
try:
event_data = json.loads(line)
if event_data.get("event") == "end_vertex":
message = (
event_data.get("data", {})
.get("build_data", {})
.get("data", {})
.get("results", {})
.get("message", {})
.get("text", "")
)
if message:
collected_results.append(types.TextContent(type="text", text=str(message)))
except json.JSONDecodeError:
msg = f"Failed to parse event data: {line}"
logger.warning(msg)
continue
async for line in response.body_iterator:
if not line:
continue
try:
event_data = json.loads(line)
if event_data.get("event") == "end_vertex":
message = (
event_data.get("data", {})
.get("build_data", {})
.get("data", {})
.get("results", {})
.get("message", {})
.get("text", "")
)
if message:
collected_results.append(types.TextContent(type="text", text=str(message)))
except json.JSONDecodeError:
msg = f"Failed to parse event data: {line}"
logger.warning(msg)
continue
return collected_results
finally:
progress_task.cancel()
await asyncio.wait([progress_task])
if not progress_task.cancelled() and (exc := progress_task.exception()) is not None:
raise exc
return collected_results
finally:
progress_task.cancel()
await asyncio.wait([progress_task])
if not progress_task.cancelled() and (exc := progress_task.exception()) is not None:
raise exc
except Exception as e:
msg = f"Error in async session: {e}"
logger.exception(msg)
raise
except Exception:
if mcp_config.enable_progress_notifications and (
progress_token := server.request_context.meta.progressToken
):
await server.request_context.session.send_progress_notification(
progress_token=progress_token, progress=1.0, total=1.0
)
raise
try:
return await with_db_session(execute_tool)
except Exception as e:
context = server.request_context
# Send error progress if there's an exception
if mcp_config.enable_progress_notifications and (progress_token := context.meta.progressToken):
await server.request_context.session.send_progress_notification(
progress_token=progress_token, progress=1.0, total=1.0
)
msg = f"Error executing tool {name}: {e!s}"
logger.exception(msg)
trace = traceback.format_exc()
logger.exception(trace)
raise
@ -357,8 +361,6 @@ async def handle_sse(request: Request, current_user: Annotated[User, Depends(get
except Exception as e:
msg = f"Error in MCP: {e!s}"
logger.exception(msg)
trace = traceback.format_exc()
logger.exception(trace)
raise
finally:
current_user_ctx.reset(token)

View file

@ -17,7 +17,7 @@ from langchain_core.tools import BaseTool
from pydantic import BaseModel
from requests.exceptions import RequestException
from langflow.components.tools.mcp_stdio import create_input_schema_from_json_schema
from langflow.base.mcp.util import create_input_schema_from_json_schema
from langflow.services.cache.utils import CacheMiss
client_lock = threading.Lock()

View file

@ -1,40 +1,81 @@
import asyncio
import os
from collections.abc import Awaitable, Callable
from contextlib import AsyncExitStack
from typing import Any
from uuid import UUID
import httpx
from mcp import ClientSession, StdioServerParameters, stdio_client
from mcp.client.sse import sse_client
from pydantic import Field, create_model
from sqlmodel import select
from langflow.helpers.base_model import BaseModel
from langflow.services.database.models import Flow
def create_tool_coroutine(tool_name: str, arg_schema: type[BaseModel], session) -> Callable[[dict], Awaitable]:
def create_tool_coroutine(tool_name: str, arg_schema: type[BaseModel], session) -> Callable[..., Awaitable]:
async def tool_coroutine(*args, **kwargs):
fields = arg_schema.model_fields.keys()
expected_field_count = len(fields)
if len(args) + len(kwargs) != expected_field_count:
msg = f"{expected_field_count} arguments are required. Received: {args} {kwargs}"
raise ValueError(msg)
arg_dict = dict(zip(fields, args, strict=False))
arg_dict.update(kwargs)
return await session.call_tool(tool_name, arguments=arg_dict)
# Get field names from the model (preserving order)
field_names = list(arg_schema.__fields__.keys())
provided_args = {}
# Map positional arguments to their corresponding field names
for i, arg in enumerate(args):
if i >= len(field_names):
msg = "Too many positional arguments provided"
raise ValueError(msg)
provided_args[field_names[i]] = arg
# Merge in keyword arguments
provided_args.update(kwargs)
# Validate input and fill defaults for missing optional fields
try:
validated = arg_schema.parse_obj(provided_args)
except Exception as e:
msg = f"Invalid input: {e}"
raise ValueError(msg) from e
return await session.call_tool(tool_name, arguments=validated.dict())
return tool_coroutine
def create_tool_func(tool_name: str, session) -> Callable[..., str]:
def tool_func(**kwargs):
if len(kwargs) == 0:
msg = f"at least one named argument is required {kwargs}"
raise ValueError(msg)
def create_tool_func(tool_name: str, arg_schema: type[BaseModel], session) -> Callable[..., str]:
def tool_func(*args, **kwargs):
field_names = list(arg_schema.__fields__.keys())
provided_args = {}
for i, arg in enumerate(args):
if i >= len(field_names):
msg = "Too many positional arguments provided"
raise ValueError(msg)
provided_args[field_names[i]] = arg
provided_args.update(kwargs)
try:
validated = arg_schema.parse_obj(provided_args)
except Exception as e:
msg = f"Invalid input: {e}"
raise ValueError(msg) from e
loop = asyncio.get_event_loop()
return loop.run_until_complete(session.call_tool(tool_name, arguments=kwargs))
return loop.run_until_complete(session.call_tool(tool_name, arguments=validated.dict()))
return tool_func
async def get_flow(flow_name: str, user_id: str, session) -> Flow | None:
uuid_user_id = UUID(user_id) if isinstance(user_id, str) else user_id
stmt = select(Flow).where(Flow.user_id == uuid_user_id).where(Flow.is_component == False) # noqa: E712
flows = (await session.exec(stmt)).all()
for flow in flows:
if flow.to_data().name == flow_name:
return flow
return None
def create_input_schema_from_json_schema(schema: dict[str, Any]) -> type[BaseModel]:
"""Converts a JSON schema into a Pydantic model dynamically.
Fields not listed as required are wrapped in Optional[...] and default to None if not provided.
:param schema: The JSON schema as a dictionary.
:return: A Pydantic model class.
"""
@ -47,9 +88,9 @@ def create_input_schema_from_json_schema(schema: dict[str, Any]) -> type[BaseMod
required_fields = set(schema.get("required", []))
for field_name, field_def in properties.items():
# Extract type
field_type_str = field_def.get("type", "str") # Default to string type if not specified
field_type = {
# Determine the base type from the JSON schema type string.
field_type_str = field_def.get("type", "str") # Defaults to string if not specified.
base_type = {
"string": str,
"str": str,
"integer": int,
@ -60,13 +101,77 @@ def create_input_schema_from_json_schema(schema: dict[str, Any]) -> type[BaseMod
"object": dict,
}.get(field_type_str, Any)
# Extract description and default if present
field_metadata = {"description": field_def.get("description", "")}
# For non-required fields, wrap the type in Optional[...] and set a default value.
if field_name not in required_fields:
field_metadata["default"] = field_def.get("default", None)
# Create Pydantic field
fields[field_name] = (field_type, Field(**field_metadata))
fields[field_name] = (base_type, Field(**field_metadata))
# Dynamically create the model
return create_model("InputSchema", **fields)
class MCPStdioClient:
def __init__(self):
self.session: ClientSession | None = None
self.exit_stack = AsyncExitStack()
async def connect_to_server(self, command_str: str):
command = command_str.split(" ")
server_params = StdioServerParameters(
command=command[0],
args=command[1:],
env={"DEBUG": "true", "PATH": os.environ["PATH"]},
)
stdio_transport = await self.exit_stack.enter_async_context(stdio_client(server_params))
self.stdio, self.write = stdio_transport
self.session = await self.exit_stack.enter_async_context(ClientSession(self.stdio, self.write))
await self.session.initialize()
response = await self.session.list_tools()
return response.tools
class MCPSseClient:
def __init__(self):
self.write = None
self.sse = None
self.session: ClientSession | None = None
self.exit_stack = AsyncExitStack()
async def pre_check_redirect(self, url: str):
async with httpx.AsyncClient(follow_redirects=False) as client:
response = await client.request("HEAD", url)
if response.status_code == httpx.codes.TEMPORARY_REDIRECT:
return response.headers.get("Location")
return url
async def _connect_with_timeout(
self, url: str, headers: dict[str, str] | None, timeout_seconds: int, sse_read_timeout_seconds: int
):
sse_transport = await self.exit_stack.enter_async_context(
sse_client(url, headers, timeout_seconds, sse_read_timeout_seconds)
)
self.sse, self.write = sse_transport
self.session = await self.exit_stack.enter_async_context(ClientSession(self.sse, self.write))
await self.session.initialize()
async def connect_to_server(
self, url: str, headers: dict[str, str] | None, timeout_seconds: int = 500, sse_read_timeout_seconds: int = 500
):
if headers is None:
headers = {}
url = await self.pre_check_redirect(url)
try:
await asyncio.wait_for(
self._connect_with_timeout(url, headers, timeout_seconds, sse_read_timeout_seconds),
timeout=timeout_seconds,
)
if self.session is None:
msg = "Session not initialized"
raise ValueError(msg)
response = await self.session.list_tools()
except asyncio.TimeoutError as err:
msg = f"Connection to {url} timed out after {timeout_seconds} seconds"
raise TimeoutError(msg) from err
return response.tools

View file

@ -0,0 +1,61 @@
# from langflow.field_typing import Data
from langchain_core.tools import StructuredTool
from mcp import types
from langflow.base.mcp.util import (
MCPSseClient,
create_input_schema_from_json_schema,
create_tool_coroutine,
create_tool_func,
)
from langflow.custom import Component
from langflow.field_typing import Tool
from langflow.io import MessageTextInput, Output
class MCPSse(Component):
client = MCPSseClient()
tools = types.ListToolsResult
tool_names = [str]
display_name = "MCP Tools (SSE) [DEPRECATED]"
description = "Connects to an MCP server over SSE and exposes it's tools as langflow tools to be used by an Agent."
documentation: str = "https://docs.langflow.org/components-custom-components"
icon = "code"
name = "MCPSse"
legacy = True
inputs = [
MessageTextInput(
name="url",
display_name="mcp sse url",
info="sse url",
value="http://localhost:7860/api/v1/mcp/sse",
tool_mode=True,
),
]
outputs = [
Output(display_name="Tools", name="tools", method="build_output"),
]
async def build_output(self) -> list[Tool]:
if self.client.session is None:
self.tools = await self.client.connect_to_server(self.url, {})
tool_list = []
for tool in self.tools:
args_schema = create_input_schema_from_json_schema(tool.inputSchema)
tool_list.append(
StructuredTool(
name=tool.name, # maybe format this
description=tool.description,
args_schema=args_schema,
func=create_tool_func(tool.name, args_schema, self.client.session),
coroutine=create_tool_coroutine(tool.name, args_schema, self.client.session),
)
)
self.tool_names = [tool.name for tool in self.tools]
return tool_list

View file

@ -1,51 +1,31 @@
# from langflow.field_typing import Data
import os
from contextlib import AsyncExitStack
from langchain_core.tools import StructuredTool
from mcp import ClientSession, StdioServerParameters, types
from mcp.client.stdio import stdio_client
from mcp import types
from langflow.base.mcp.util import create_input_schema_from_json_schema, create_tool_coroutine, create_tool_func
from langflow.base.mcp.util import (
MCPStdioClient,
create_input_schema_from_json_schema,
create_tool_coroutine,
create_tool_func,
)
from langflow.custom import Component
from langflow.field_typing import Tool
from langflow.io import MessageTextInput, Output
class MCPStdioClient:
def __init__(self):
# Initialize session and client objects
self.session: ClientSession | None = None
self.exit_stack = AsyncExitStack()
async def connect_to_server(self, command_str: str):
command = command_str.split(" ")
server_params = StdioServerParameters(
command=command[0], args=command[1:], env={"DEBUG": "true", "PATH": os.environ["PATH"]}
)
stdio_transport = await self.exit_stack.enter_async_context(stdio_client(server_params))
self.stdio, self.write = stdio_transport
self.session = await self.exit_stack.enter_async_context(ClientSession(self.stdio, self.write))
await self.session.initialize()
# List available tools
response = await self.session.list_tools()
return response.tools
class MCPStdio(Component):
client = MCPStdioClient()
tools = types.ListToolsResult
tool_names = [str]
display_name = "MCP Tools (stdio)"
display_name = "MCP Tools (stdio) [DEPRECATED]"
description = (
"Connects to an MCP server over stdio and exposes it's tools as langflow tools to be used by an Agent."
)
documentation: str = "https://docs.langflow.org/components-custom-components"
icon = "code"
name = "MCPStdio"
legacy = True
inputs = [
MessageTextInput(
@ -74,7 +54,7 @@ class MCPStdio(Component):
name=tool.name,
description=tool.description,
args_schema=args_schema,
func=create_tool_func(tool.name, args_schema),
func=create_tool_func(tool.name, args_schema, self.client.session),
coroutine=create_tool_coroutine(tool.name, args_schema, self.client.session),
)
)

View file

@ -13,7 +13,7 @@ from .google_search_api import GoogleSearchAPIComponent
from .google_search_api_core import GoogleSearchAPICore
from .google_serper_api import GoogleSerperAPIComponent
from .google_serper_api_core import GoogleSerperAPICore
from .mcp_stdio import MCPStdio
from .mcp_component import MCPToolsComponent
from .python_code_structured_tool import PythonCodeStructuredTool
from .python_repl import PythonREPLToolComponent
from .python_repl_core import PythonREPLComponent
@ -51,7 +51,7 @@ __all__ = [
"GoogleSearchAPICore",
"GoogleSerperAPIComponent",
"GoogleSerperAPICore",
"MCPStdio",
"MCPToolsComponent",
"PythonCodeStructuredTool",
"PythonREPLComponent",
"PythonREPLToolComponent",

View file

@ -0,0 +1,340 @@
import asyncio
from typing import Any
from langchain_core.tools import StructuredTool
from langflow.base.mcp.util import (
MCPSseClient,
MCPStdioClient,
create_input_schema_from_json_schema,
create_tool_coroutine,
create_tool_func,
)
from langflow.custom import Component
from langflow.inputs import DropdownInput
from langflow.inputs.inputs import InputTypes
from langflow.io import MessageTextInput, Output, TabInput
from langflow.io.schema import schema_to_langflow_inputs
from langflow.logging import logger
from langflow.schema import Message
class MCPToolsComponent(Component):
schema_inputs: list[InputTypes] = []
stdio_client = MCPStdioClient()
sse_client = MCPSseClient()
tools: list = []
tool_names: list[str] = []
_tool_cache: dict = {} # Cache for tool objects
default_keys = ["code", "_type", "mode", "command", "sse_url", "tool_placeholder", "tool_mode", "tool"]
display_name = "MCP Server"
description = "Connect to an MCP server and expose tools."
icon = "server"
name = "MCPTools"
inputs = [
TabInput(
name="mode",
display_name="Mode",
options=["Stdio", "SSE"],
value="Stdio",
info="Select the connection mode",
real_time_refresh=True,
),
MessageTextInput(
name="command",
display_name="MCP Command",
info="Command for MCP stdio connection",
value="uvx mcp-server-fetch",
show=True,
refresh_button=True,
),
MessageTextInput(
name="sse_url",
display_name="MCP SSE URL",
info="URL for MCP SSE connection",
value="http://localhost:7860/api/v1/mcp/sse",
show=False,
refresh_button=True,
),
DropdownInput(
name="tool",
display_name="Tool",
options=[],
value="",
info="Select the tool to execute",
show=True,
required=True,
real_time_refresh=True,
),
MessageTextInput(
name="tool_placeholder",
display_name="Tool Placeholder",
info="Placeholder for the tool",
value="",
show=False,
tool_mode=True,
),
]
outputs = [
Output(display_name="Response", name="response", method="build_output"),
]
async def _validate_connection_params(self, mode: str, command: str | None = None, url: str | None = None) -> None:
"""Validate connection parameters based on mode."""
if mode not in ["Stdio", "SSE"]:
msg = f"Invalid mode: {mode}. Must be either 'Stdio' or 'SSE'"
raise ValueError(msg)
if mode == "Stdio" and not command:
msg = "Command is required for Stdio mode"
raise ValueError(msg)
if mode == "SSE" and not url:
msg = "URL is required for SSE mode"
raise ValueError(msg)
async def _validate_schema_inputs(self, tool_obj) -> list[InputTypes]:
"""Validate and process schema inputs for a tool."""
try:
if not tool_obj or not hasattr(tool_obj, "inputSchema"):
msg = "Invalid tool object or missing input schema"
raise ValueError(msg)
input_schema = create_input_schema_from_json_schema(tool_obj.inputSchema)
if not input_schema:
msg = f"Empty input schema for tool '{tool_obj.name}'"
raise ValueError(msg)
schema_inputs = schema_to_langflow_inputs(input_schema)
if not schema_inputs:
msg = f"No input parameters defined for tool '{tool_obj.name}'"
logger.warning(msg)
return []
except Exception as e:
msg = f"Error validating schema inputs: {e!s}"
logger.exception(msg)
raise ValueError(msg) from e
else:
return schema_inputs
async def update_build_config(self, build_config: dict, field_value: str, field_name: str | None = None) -> dict:
"""Toggle the visibility of connection-specific fields based on the selected mode."""
try:
if field_name == "mode":
self.remove_non_default_keys(build_config)
if field_value == "Stdio":
build_config["command"]["show"] = True
build_config["sse_url"]["show"] = False
elif field_value == "SSE":
build_config["command"]["show"] = False
build_config["sse_url"]["show"] = True
if field_name in ("command", "sse_url", "mode"):
try:
await self.update_tools()
if "tool" in build_config:
build_config["tool"]["options"] = self.tool_names
except Exception as e:
build_config["tool"]["options"] = []
msg = f"Failed to update tools: {e!s}"
raise ValueError(msg) from e
elif field_name == "tool":
if len(self.tools) == 0:
await self.update_tools()
if self.tool is None:
return build_config
tool_obj = None
for tool in self.tools:
if tool.name == self.tool:
tool_obj = tool
break
if tool_obj is None:
msg = f"Tool {self.tool} not found in available tools: {self.tools}"
logger.warning(msg)
return build_config
self.remove_non_default_keys(build_config)
await self._update_tool_config(build_config, field_value)
elif field_name == "tool_mode":
build_config["tool"]["show"] = not field_value
for key, value in list(build_config.items()):
if key not in self.default_keys and isinstance(value, dict) and "show" in value:
build_config[key]["show"] = not field_value
except Exception as e:
msg = f"Error in update_build_config: {e!s}"
logger.exception(msg)
raise ValueError(msg) from e
else:
return build_config
def get_inputs_for_all_tools(self, tools: list) -> dict:
"""Get input schemas for all tools."""
inputs = {}
for tool in tools:
if not tool or not hasattr(tool, "name"):
continue
try:
input_schema = schema_to_langflow_inputs(create_input_schema_from_json_schema(tool.inputSchema))
inputs[tool.name] = input_schema
except (AttributeError, ValueError, TypeError, KeyError) as e:
msg = f"Error getting inputs for tool {getattr(tool, 'name', 'unknown')}: {e!s}"
logger.exception(msg)
continue
return inputs
def remove_input_schema_from_build_config(
self, build_config: dict, tool_name: str, input_schema: dict[list[InputTypes], Any]
):
"""Remove the input schema for the tool from the build config."""
# Keep only schemas that don't belong to the current tool
input_schema = {k: v for k, v in input_schema.items() if k != tool_name}
# Remove all inputs from other tools
for value in input_schema.values():
for _input in value:
if _input.name in build_config:
build_config.pop(_input.name)
def remove_non_default_keys(self, build_config: dict) -> None:
"""Remove non-default keys from the build config."""
for key in list(build_config.keys()):
if key not in self.default_keys:
build_config.pop(key)
async def _update_tool_config(self, build_config: dict, tool_name: str) -> None:
"""Update tool configuration with proper error handling."""
if not self.tools:
await self.update_tools()
if not tool_name:
return
tool_obj = next((tool for tool in self.tools if tool.name == tool_name), None)
if not tool_obj:
msg = f"Tool {tool_name} not found in available tools: {self.tools}"
logger.warning(msg)
return
try:
# Get all tool inputs and remove old ones
input_schema_for_all_tools = self.get_inputs_for_all_tools(self.tools)
self.remove_input_schema_from_build_config(build_config, tool_name, input_schema_for_all_tools)
# Get and validate new inputs
self.schema_inputs = await self._validate_schema_inputs(tool_obj)
if not self.schema_inputs:
msg = f"No input parameters to configure for tool '{tool_name}'"
logger.info(msg)
return
# Add new inputs to build config
for schema_input in self.schema_inputs:
if not schema_input or not hasattr(schema_input, "name"):
msg = "Invalid schema input detected, skipping"
logger.warning(msg)
continue
try:
name = schema_input.name
input_dict = schema_input.to_dict()
input_dict.setdefault("value", None)
input_dict.setdefault("required", True)
build_config[name] = input_dict
except (AttributeError, KeyError, TypeError) as e:
msg = f"Error processing schema input {schema_input}: {e!s}"
logger.exception(msg)
continue
except ValueError as e:
msg = f"Schema validation error for tool {tool_name}: {e!s}"
logger.exception(msg)
self.schema_inputs = []
return
except (AttributeError, KeyError, TypeError) as e:
msg = f"Error updating tool config: {e!s}"
logger.exception(msg)
raise ValueError(msg) from e
async def build_output(self) -> Message:
"""Build output with improved error handling and validation."""
try:
await self.update_tools()
if self.tool != "":
exec_tool = self._tool_cache[self.tool]
tool_args = self.get_inputs_for_all_tools(self.tools)[self.tool]
kwargs = {}
for arg in tool_args:
value = getattr(self, arg.name, None)
if value:
kwargs[arg.name] = value
output = await exec_tool.coroutine(**kwargs)
return Message(text=output.content[len(output.content) - 1].text)
return Message(text="You must select a tool", error=True)
except Exception as e:
msg = f"Error in build_output: {e!s}"
logger.exception(msg)
raise ValueError(msg) from e
async def update_tools(self) -> list[StructuredTool]:
    """Connect to the configured MCP server and rebuild the tool list.

    Reuses an existing client session when one is already open, wraps
    each server tool in a LangChain ``StructuredTool``, and caches the
    wrappers in ``self._tool_cache`` keyed by tool name.

    Returns:
        list[StructuredTool]: the wrapped tools ([] if the server reported none).

    Raises:
        ValueError: wrapping connection/validation failures.
    """
    try:
        await self._validate_connection_params(self.mode, self.command, self.sse_url)

        # Only open a new connection when no session exists for the active mode.
        if self.mode == "Stdio":
            if not self.stdio_client.session:
                self.tools = await self.stdio_client.connect_to_server(self.command)
        elif self.mode == "SSE" and not self.sse_client.session:
            self.tools = await self.sse_client.connect_to_server(self.sse_url, {})

        if not self.tools:
            logger.warning("No tools returned from server")
            return []

        wrapped_tools = []
        for server_tool in self.tools:
            if not server_tool or not hasattr(server_tool, "name"):
                logger.warning("Invalid tool object detected, skipping")
                continue
            try:
                schema = create_input_schema_from_json_schema(server_tool.inputSchema)
                if not schema:
                    msg = f"Empty schema for tool '{server_tool.name}', skipping"
                    logger.warning(msg)
                    continue
                active_client = self.stdio_client if self.mode == "Stdio" else self.sse_client
                if not active_client or not active_client.session:
                    msg = f"Invalid client session for tool '{server_tool.name}'"
                    raise ValueError(msg)
                structured = StructuredTool(
                    name=server_tool.name,
                    description=server_tool.description or "",
                    args_schema=schema,
                    func=create_tool_func(server_tool.name, schema, active_client.session),
                    coroutine=create_tool_coroutine(server_tool.name, schema, active_client.session),
                    tags=[server_tool.name],
                )
                wrapped_tools.append(structured)
                self._tool_cache[server_tool.name] = structured
            except (AttributeError, ValueError, TypeError, KeyError) as e:
                # Skip the bad tool but keep building the rest.
                msg = f"Error creating tool {getattr(server_tool, 'name', 'unknown')}: {e!s}"
                logger.exception(msg)
                continue
        self.tool_names = [t.name for t in self.tools if hasattr(t, "name")]
    except (ValueError, RuntimeError, asyncio.TimeoutError) as e:
        msg = f"Error updating tools: {e!s}"
        logger.exception(msg)
        raise ValueError(msg) from e
    else:
        return wrapped_tools
async def _get_tools(self):
    """Return the cached tool list, refreshing from the server when empty."""
    if self.tools:
        return self.tools
    return await self.update_tools()

View file

@ -1,111 +0,0 @@
# from langflow.field_typing import Data
import asyncio
from contextlib import AsyncExitStack
import httpx
from langchain_core.tools import StructuredTool
from mcp import ClientSession, types
from mcp.client.sse import sse_client
from langflow.base.mcp.util import create_input_schema_from_json_schema, create_tool_coroutine, create_tool_func
from langflow.custom import Component
from langflow.field_typing import Tool
from langflow.io import MessageTextInput, Output
# Define constant for status code
HTTP_TEMPORARY_REDIRECT = 307
class MCPSseClient:
    """Thin SSE client for an MCP server.

    Owns the SSE transport and the MCP ``ClientSession`` through a single
    ``AsyncExitStack`` so resources are torn down in reverse order of setup.
    """

    def __init__(self):
        # Transport halves and session are populated by connect_to_server().
        self.write = None
        self.sse = None
        self.session: ClientSession | None = None
        self.exit_stack = AsyncExitStack()

    async def pre_check_redirect(self, url: str):
        """Return the redirect target if `url` answers HEAD with 307, else `url`."""
        async with httpx.AsyncClient(follow_redirects=False) as client:
            response = await client.request("HEAD", url)
            # Prefer httpx's named status-code constant over a local magic number.
            if response.status_code == httpx.codes.TEMPORARY_REDIRECT:
                return response.headers.get("Location")  # Return the redirect URL
        return url  # Return the original URL if no redirect

    async def _connect_with_timeout(
        self, url: str, headers: dict[str, str] | None, timeout_seconds: int, sse_read_timeout_seconds: int
    ):
        """Open the SSE transport and initialize the MCP session."""
        sse_transport = await self.exit_stack.enter_async_context(
            sse_client(url, headers, timeout_seconds, sse_read_timeout_seconds)
        )
        self.sse, self.write = sse_transport
        self.session = await self.exit_stack.enter_async_context(ClientSession(self.sse, self.write))
        await self.session.initialize()

    async def connect_to_server(
        self, url: str, headers: dict[str, str] | None, timeout_seconds: int = 500, sse_read_timeout_seconds: int = 500
    ):
        """Connect to the MCP server at `url` and return its advertised tools.

        Raises:
            TimeoutError: if the connection does not complete within `timeout_seconds`.
            ValueError: if the session failed to initialize.
        """
        if headers is None:
            headers = {}
        url = await self.pre_check_redirect(url)
        try:
            await asyncio.wait_for(
                self._connect_with_timeout(url, headers, timeout_seconds, sse_read_timeout_seconds),
                timeout=timeout_seconds,
            )
            # List available tools
            if self.session is None:
                msg = "Session not initialized"
                raise ValueError(msg)
            response = await self.session.list_tools()
        except asyncio.TimeoutError as err:
            error_message = f"Connection to {url} timed out after {timeout_seconds} seconds"
            raise TimeoutError(error_message) from err
        return response.tools
class MCPSse(Component):
    """Langflow component exposing an MCP server's tools over SSE."""

    # NOTE(review): this client is a class attribute, so it is shared by every
    # instance of the component — confirm a shared session is intentional.
    client = MCPSseClient()
    tools = types.ListToolsResult
    tool_names = [str]
    display_name = "MCP Tools (SSE)"
    description = "Connects to an MCP server over SSE and exposes its tools as langflow tools to be used by an Agent."
    documentation: str = "https://docs.langflow.org/components-custom-components"
    icon = "code"
    name = "MCPSse"

    inputs = [
        MessageTextInput(
            name="url",
            display_name="mcp sse url",
            info="sse url",
            value="http://localhost:7860/api/v1/mcp/sse",
            tool_mode=True,
        ),
    ]

    outputs = [
        Output(display_name="Tools", name="tools", method="build_output"),
    ]

    async def build_output(self) -> list[Tool]:
        """Connect (once) to the SSE server and wrap its tools for Langflow.

        Returns:
            list[Tool]: one StructuredTool per tool advertised by the server.
        """
        if self.client.session is None:
            self.tools = await self.client.connect_to_server(self.url, {})
        tool_list = []
        for tool in self.tools:
            args_schema = create_input_schema_from_json_schema(tool.inputSchema)
            tool_list.append(
                StructuredTool(
                    name=tool.name,  # maybe format this
                    # Guard against a None description (StructuredTool expects str).
                    description=tool.description or "",
                    args_schema=args_schema,
                    # Pass args_schema to the sync wrapper too, matching the
                    # create_tool_func(name, schema, session) signature used
                    # by the merged MCP component.
                    func=create_tool_func(tool.name, args_schema, self.client.session),
                    coroutine=create_tool_coroutine(tool.name, args_schema, self.client.session),
                )
            )
        self.tool_names = [tool.name for tool in self.tools]
        return tool_list

View file

@ -1,8 +1,8 @@
from typing import TYPE_CHECKING, Literal
from typing import Literal, Union, get_args, get_origin
from pydantic import BaseModel, Field, create_model
from langflow.inputs.inputs import FieldTypes
from langflow.inputs.inputs import BoolInput, DictInput, FieldTypes, FloatInput, InputTypes, IntInput, MessageTextInput
from langflow.schema.dotdict import dotdict
_convert_field_type_to_type: dict[FieldTypes, type] = {
@ -20,8 +20,65 @@ _convert_field_type_to_type: dict[FieldTypes, type] = {
FieldTypes.TAB: str,
}
if TYPE_CHECKING:
from langflow.inputs.inputs import InputTypes
_convert_type_to_field_type = {
str: MessageTextInput,
int: IntInput,
float: FloatInput,
bool: BoolInput,
dict: DictInput,
list: MessageTextInput,
}
def schema_to_langflow_inputs(schema: type[BaseModel]) -> list["InputTypes"]:
    """Given a Pydantic schema, convert its fields to Langflow input definitions.

    Handles ``list[...]`` fields (marked ``is_list``), ``Literal`` fields
    (narrowed to the type of their first allowed value), and optional fields
    declared either as ``Optional[X]`` / ``Union[X, None]`` or with PEP 604
    syntax (``X | None``).

    Raises:
        TypeError: if a field's resolved Python type has no Langflow mapping.
    """
    import types as _types  # local: only needed to detect PEP 604 unions

    inputs = []
    for field_name, model_field in schema.model_fields.items():
        # Start with the field's annotation type
        field_type = model_field.annotation
        is_list = False
        options = None
        # If the field is a list, record that and extract its inner type.
        if get_origin(field_type) is list:
            is_list = True
            field_type = get_args(field_type)[0]
        # If the field type is a Literal, extract its allowed values.
        if get_origin(field_type) is Literal:
            options = list(get_args(field_type))
            # Optionally, set field_type to the type of the literal values.
            if options:
                field_type = type(options[0])
        # Handle Union types (e.g. Optional fields). `X | None` annotations
        # report types.UnionType from get_origin, not typing.Union.
        if get_origin(field_type) in (Union, _types.UnionType):
            # Get the first non-None type from the Union
            field_type = next(t for t in get_args(field_type) if t is not type(None))
        # Convert the Python type to the Langflow field type using our reverse mapping.
        try:
            langflow_field_type = _convert_type_to_field_type[field_type]
        except KeyError as e:
            msg = f"Unsupported field type: {field_type}"
            raise TypeError(msg) from e
        # Get metadata from the Pydantic Field.
        title = model_field.title or field_name.replace("_", " ").title()
        description = model_field.description or ""
        required = model_field.is_required()
        # Construct the Langflow input.
        input_obj = langflow_field_type(
            display_name=title,
            name=field_name,
            info=description,
            required=required,
            is_list=is_list,
        )
        inputs.append(input_obj)
    return inputs
def create_input_schema(inputs: list["InputTypes"]) -> type[BaseModel]:

View file

@ -0,0 +1,11 @@
from tests.integration.utils import run_single_component
async def test_mcp_component():
    """Smoke-test the MCP component end to end with its default inputs."""
    from langflow.components.tools.mcp_component import MCPToolsComponent

    # No overrides: exercise the component exactly as it ships.
    await run_single_component(MCPToolsComponent, inputs={})

View file

@ -0,0 +1,255 @@
import asyncio
from unittest.mock import AsyncMock, MagicMock, patch
import pytest
from langflow.components.tools.mcp_component import MCPSseClient, MCPStdioClient, MCPToolsComponent
from tests.base import ComponentTestBaseWithoutClient, VersionComponentMapping
class TestMCPToolsComponent(ComponentTestBaseWithoutClient):
    """Unit tests for MCPToolsComponent; all MCP client traffic is mocked."""

    @pytest.fixture
    def component_class(self):
        """Return the component class to test."""
        return MCPToolsComponent

    @pytest.fixture
    def default_kwargs(self):
        """Return the default kwargs for the component."""
        return {
            "mode": "Stdio",
            "command": "uvx mcp-server-fetch",
            "sse_url": "http://localhost:7860/api/v1/mcp/sse",
            "tool": "",
        }

    @pytest.fixture
    def file_names_mapping(self) -> list[VersionComponentMapping]:
        """Return the file names mapping for different versions."""
        return []

    @pytest.fixture
    def mock_tool(self):
        """Create a mock MCP tool."""
        tool = MagicMock()
        tool.name = "test_tool"
        tool.description = "Test tool description"
        # Minimal JSON schema: a single string parameter named test_param.
        tool.inputSchema = {
            "type": "object",
            "properties": {"test_param": {"type": "string", "description": "Test parameter"}},
        }
        return tool

    @pytest.fixture
    def mock_stdio_client(self, mock_tool):
        """Create a mock stdio client."""
        stdio_client = AsyncMock()
        stdio_client.connect_to_server = AsyncMock(return_value=[mock_tool])
        stdio_client.session = AsyncMock()
        return stdio_client

    @pytest.fixture
    def mock_sse_client(self, mock_tool):
        """Create a mock SSE client."""
        sse_client = AsyncMock()
        sse_client.connect_to_server = AsyncMock(return_value=[mock_tool])
        sse_client.session = AsyncMock()
        return sse_client

    async def test_validate_connection_params_invalid_mode(self, component_class, default_kwargs):
        """Test validation with invalid mode."""
        component = component_class(**default_kwargs)
        with pytest.raises(ValueError, match="Invalid mode: invalid. Must be either 'Stdio' or 'SSE'"):
            await component._validate_connection_params("invalid")

    async def test_validate_connection_params_missing_command(self, component_class, default_kwargs):
        """Test validation with missing command in Stdio mode."""
        component = component_class(**default_kwargs)
        with pytest.raises(ValueError, match="Command is required for Stdio mode"):
            await component._validate_connection_params("Stdio", command=None)

    async def test_validate_connection_params_missing_url(self, component_class, default_kwargs):
        """Test validation with missing URL in SSE mode."""
        component = component_class(**default_kwargs)
        with pytest.raises(ValueError, match="URL is required for SSE mode"):
            await component._validate_connection_params("SSE", url=None)

    async def test_update_build_config_mode_change(self, component_class, default_kwargs):
        """Test build config updates when mode changes."""
        component = component_class(**default_kwargs)
        build_config = {
            "command": {"show": False},
            "sse_url": {"show": True},
            "tool": {"options": [], "show": True},  # Add tool field since component uses it
        }
        # Test switching to Stdio mode
        updated_config = await component.update_build_config(build_config, "Stdio", "mode")
        assert updated_config["command"]["show"] is True
        assert updated_config["sse_url"]["show"] is False
        # Test switching to SSE mode
        updated_config = await component.update_build_config(build_config, "SSE", "mode")
        assert updated_config["command"]["show"] is False
        assert updated_config["sse_url"]["show"] is True
        # Test tool options are updated
        assert "options" in updated_config["tool"]

    @patch("langflow.components.tools.mcp_component.create_tool_coroutine")
    async def test_build_output(self, mock_create_coroutine, component_class, default_kwargs, mock_tool):
        """Test building output with a tool."""
        component = component_class(**default_kwargs)
        component.tool = "test_tool"
        component.tools = [mock_tool]
        # Mock the coroutine response
        mock_response = AsyncMock()
        mock_response.content = [MagicMock(text="Test response")]
        mock_create_coroutine.return_value = AsyncMock(return_value=mock_response)
        # Create a mock tool and add it to the cache
        mock_structured_tool = MagicMock()
        mock_structured_tool.coroutine = mock_create_coroutine.return_value
        component._tool_cache = {"test_tool": mock_structured_tool}
        # Set the test parameter value
        component.test_param = "test value"
        # Mock get_inputs_for_all_tools to return our mock input
        mock_input = MagicMock()
        mock_input.name = "test_param"
        with patch.object(component, "get_inputs_for_all_tools") as mock_get_inputs:
            mock_get_inputs.return_value = {"test_tool": [mock_input]}
            output = await component.build_output()
            assert output.text == "Test response"
            # Verify the mocks were called correctly
            mock_get_inputs.assert_called_once_with(component.tools)
            mock_structured_tool.coroutine.assert_called_once_with(test_param="test value")

    async def test_get_inputs_for_all_tools(self, component_class, default_kwargs, mock_tool):
        """Test getting input schemas for all tools."""
        component = component_class(**default_kwargs)
        inputs = component.get_inputs_for_all_tools([mock_tool])
        assert "test_tool" in inputs
        assert len(inputs["test_tool"]) > 0  # Should have at least one input parameter

    async def test_remove_non_default_keys(self, component_class, default_kwargs):
        """Test removing non-default keys from build config."""
        component = component_class(**default_kwargs)
        build_config = {"code": {}, "mode": {}, "command": {}, "custom_key": {}}
        component.remove_non_default_keys(build_config)
        assert "custom_key" not in build_config
        assert all(key in build_config for key in ["code", "mode", "command"])
class TestMCPStdioClient:
    """Unit tests for MCPStdioClient with the subprocess transport mocked out."""

    @pytest.fixture
    def stdio_client(self):
        # Fresh client per test so session state never leaks between tests.
        return MCPStdioClient()

    async def test_connect_to_server(self, stdio_client):
        """Test connecting to server via Stdio."""
        # Create mock for stdio transport
        mock_stdio = AsyncMock()
        mock_write = AsyncMock()
        mock_stdio_transport = (mock_stdio, mock_write)
        mock_stdio_cm = AsyncMock()
        mock_stdio_cm.__aenter__.return_value = mock_stdio_transport
        # Mock the stdio_client function to return our mock context manager
        with patch("mcp.client.stdio.stdio_client", return_value=mock_stdio_cm):
            # Mock ClientSession
            mock_session = AsyncMock()
            mock_session.initialize = AsyncMock()
            mock_session.list_tools.return_value.tools = [MagicMock()]
            # Mock the AsyncExitStack; connect_to_server enters the transport
            # first and the session second, hence the side_effect order.
            mock_exit_stack = AsyncMock()
            mock_exit_stack.enter_async_context = AsyncMock()
            mock_exit_stack.enter_async_context.side_effect = [
                mock_stdio_transport,  # For stdio_client
                mock_session,  # For ClientSession
            ]
            stdio_client.exit_stack = mock_exit_stack
            tools = await stdio_client.connect_to_server("test_command")
            assert len(tools) == 1
            assert stdio_client.session is not None
            # Verify the exit stack was used correctly
            assert mock_exit_stack.enter_async_context.call_count == 2
            # Verify the stdio transport was properly set
            assert stdio_client.stdio == mock_stdio
            assert stdio_client.write == mock_write
class TestMCPSseClient:
    """Unit tests for MCPSseClient with the HTTP/SSE layer fully mocked."""

    @pytest.fixture
    def sse_client(self):
        # Fresh client per test so session state never leaks between tests.
        return MCPSseClient()

    async def test_pre_check_redirect(self, sse_client):
        """Test pre-checking URL for redirects."""
        test_url = "http://test.url"
        redirect_url = "http://redirect.url"
        with patch("httpx.AsyncClient") as mock_client:
            mock_response = MagicMock()
            # 307 Temporary Redirect: the Location header should be returned.
            mock_response.status_code = 307
            mock_response.headers.get.return_value = redirect_url
            mock_client.return_value.__aenter__.return_value.request.return_value = mock_response
            result = await sse_client.pre_check_redirect(test_url)
            assert result == redirect_url

    async def test_connect_to_server(self, sse_client):
        """Test connecting to server via SSE."""
        # Mock the pre_check_redirect first
        with patch.object(sse_client, "pre_check_redirect", return_value="http://test.url"):
            # Create mock for sse_client context manager
            mock_sse = AsyncMock()
            mock_write = AsyncMock()
            mock_sse_transport = (mock_sse, mock_write)
            mock_sse_cm = AsyncMock()
            mock_sse_cm.__aenter__.return_value = mock_sse_transport
            # Mock the sse_client function to return our mock context manager
            with patch("mcp.client.sse.sse_client", return_value=mock_sse_cm):
                # Mock ClientSession
                mock_session = AsyncMock()
                mock_session.initialize = AsyncMock()
                mock_session.list_tools.return_value.tools = [MagicMock()]
                # Mock the AsyncExitStack; entered in order: transport, session.
                mock_exit_stack = AsyncMock()
                mock_exit_stack.enter_async_context = AsyncMock()
                mock_exit_stack.enter_async_context.side_effect = [
                    mock_sse_transport,  # For sse_client
                    mock_session,  # For ClientSession
                ]
                sse_client.exit_stack = mock_exit_stack
                tools = await sse_client.connect_to_server("http://test.url", {})
                assert len(tools) == 1
                assert sse_client.session is not None
                # Verify the exit stack was used correctly
                assert mock_exit_stack.enter_async_context.call_count == 2
                # Verify the SSE transport was properly set
                assert sse_client.sse == mock_sse
                assert sse_client.write == mock_write

    async def test_connect_timeout(self, sse_client):
        """Test connection timeout handling."""
        with (
            patch.object(sse_client, "pre_check_redirect", return_value="http://test.url"),
            patch.object(sse_client, "_connect_with_timeout") as mock_connect,
        ):
            # Force the inner connect to time out so the wrapper's
            # TimeoutError (with the formatted message) is exercised.
            mock_connect.side_effect = asyncio.TimeoutError()
            with pytest.raises(TimeoutError, match="Connection to http://test.url timed out after 1 seconds"):
                await sse_client.connect_to_server("http://test.url", {}, timeout_seconds=1)

View file

@ -3,11 +3,13 @@ from types import NoneType
from typing import Union
import pytest
from langflow.inputs.inputs import BoolInput, DictInput, FloatInput, InputTypes, IntInput, MessageTextInput
from langflow.io.schema import schema_to_langflow_inputs
from langflow.schema.data import Data
from langflow.template import Input, Output
from langflow.template.field.base import UNDEFINED
from langflow.type_extraction.type_extraction import post_process_type
from pydantic import ValidationError
from pydantic import BaseModel, Field, ValidationError
class TestInput:
@ -178,3 +180,65 @@ class TestPostProcessType:
pass
assert set(post_process_type(Union[CustomType, int])) == {CustomType, int} # noqa: UP007
def test_schema_to_langflow_inputs():
    """schema_to_langflow_inputs maps each Pydantic field to the right Langflow input type."""
    # Define a test Pydantic model with various field types
    class TestSchema(BaseModel):
        text_field: str = Field(title="Custom Text Title", description="A text field")
        number_field: int = Field(description="A number field")
        bool_field: bool = Field(description="A boolean field")
        dict_field: dict = Field(description="A dictionary field")
        list_field: list[str] = Field(description="A list of strings")

    # Convert schema to Langflow inputs
    inputs = schema_to_langflow_inputs(TestSchema)
    # Verify the number of inputs matches the schema fields
    assert len(inputs) == 5

    # Helper function to find input by name
    def find_input(name: str) -> InputTypes | None:
        for _input in inputs:
            if _input.name == name:
                return _input
        return None

    # Test text field: explicit Field(title=...) wins over the derived title.
    text_input = find_input("text_field")
    assert text_input.display_name == "Custom Text Title"
    assert text_input.info == "A text field"
    assert isinstance(text_input, MessageTextInput)  # Check the instance type instead of field_type
    # Test number field: title is derived from the field name when not given.
    number_input = find_input("number_field")
    assert number_input.display_name == "Number Field"
    assert number_input.info == "A number field"
    assert isinstance(number_input, IntInput | FloatInput)
    # Test boolean field
    bool_input = find_input("bool_field")
    assert isinstance(bool_input, BoolInput)
    # Test dictionary field
    dict_input = find_input("dict_field")
    assert isinstance(dict_input, DictInput)
    # Test list field: list[str] maps to a MessageTextInput with is_list set.
    list_input = find_input("list_field")
    assert list_input.is_list is True
    assert isinstance(list_input, MessageTextInput)
def test_schema_to_langflow_inputs_invalid_type():
    """A field whose type has no Langflow mapping must raise TypeError."""
    # Define a schema with an unsupported type
    class CustomType:
        pass

    class InvalidSchema(BaseModel):
        # arbitrary_types_allowed lets Pydantic accept CustomType so the
        # failure happens inside schema_to_langflow_inputs, not at model build.
        model_config = {"arbitrary_types_allowed": True}  # Add this line
        invalid_field: CustomType

    # Test that attempting to convert an unsupported type raises TypeError
    with pytest.raises(TypeError, match="Unsupported field type:"):
        schema_to_langflow_inputs(InvalidSchema)

View file

@ -706,6 +706,7 @@
},
"node_modules/@clack/prompts/node_modules/is-unicode-supported": {
"version": "1.3.0",
"extraneous": true,
"inBundle": true,
"license": "MIT",
"engines": {

View file

@ -37,7 +37,7 @@ export const ForwardedIconComponent = memo(
nodeIconsLucide[
name
?.split("-")
?.map((x) => String(x[0]).toUpperCase() + String(x).slice(1))
?.map((x) => String(x[0])?.toUpperCase() + String(x).slice(1))
?.join("")
];
if (!TargetIcon) {

View file

@ -32,7 +32,7 @@ export default function RenderKey({
className={cn(tableRender ? "h-4 w-4" : "h-3 w-3")}
/>
) : (
<span>{value.toUpperCase()}</span>
<span>{value?.toUpperCase()}</span>
)}
</div>
);

View file

@ -61,9 +61,11 @@ export interface ButtonProps
function toTitleCase(text: string) {
return text
.split(" ")
.map((word) => word.charAt(0).toUpperCase() + word.slice(1).toLowerCase())
.join(" ");
?.split(" ")
?.map(
(word) => word?.charAt(0)?.toUpperCase() + word?.slice(1)?.toLowerCase(),
)
?.join(" ");
}
const Button = React.forwardRef<HTMLButtonElement, ButtonProps>(

View file

@ -1022,3 +1022,18 @@ export const IS_AUTO_LOGIN =
export const AUTO_LOGIN_RETRY_DELAY = 2000;
export const AUTO_LOGIN_MAX_RETRY_DELAY = 60000;
// Speech/voice languages offered in the audio settings dialog.
// `value` is the BCP 47 locale tag; `name` is the display label.
export const ALL_LANGUAGES = [
  { value: "en-US", name: "English (US)" },
  { value: "en-GB", name: "English (UK)" },
  { value: "it-IT", name: "Italian" },
  { value: "fr-FR", name: "French" },
  { value: "es-ES", name: "Spanish" },
  { value: "de-DE", name: "German" },
  { value: "ja-JP", name: "Japanese" },
  { value: "pt-BR", name: "Portuguese (Brazil)" },
  { value: "zh-CN", name: "Chinese (Simplified)" },
  { value: "ru-RU", name: "Russian" },
  { value: "ar-SA", name: "Arabic" },
  { value: "hi-IN", name: "Hindi" },
];

View file

@ -21,21 +21,6 @@ import LanguageSelect from "./components/language-select";
import MicrophoneSelect from "./components/microphone-select";
import VoiceSelect from "./components/voice-select";
const ALL_LANGUAGES = [
{ value: "en-US", name: "English (US)" },
{ value: "en-GB", name: "English (UK)" },
{ value: "it-IT", name: "Italian" },
{ value: "fr-FR", name: "French" },
{ value: "es-ES", name: "Spanish" },
{ value: "de-DE", name: "German" },
{ value: "ja-JP", name: "Japanese" },
{ value: "pt-BR", name: "Portuguese (Brazil)" },
{ value: "zh-CN", name: "Chinese (Simplified)" },
{ value: "ru-RU", name: "Russian" },
{ value: "ar-SA", name: "Arabic" },
{ value: "hi-IN", name: "Hindi" },
];
interface SettingsVoiceModalProps {
children?: React.ReactNode;
userOpenaiApiKey?: string;
@ -415,7 +400,6 @@ const SettingsVoiceModal = ({
<LanguageSelect
language={currentLanguage}
handleSetLanguage={handleSetLanguage}
allLanguages={ALL_LANGUAGES}
/>
</>
)}

View file

@ -1,3 +1,4 @@
import { ALL_LANGUAGES } from "@/constants/constants";
import IconComponent from "../../../../../../../../../../components/common/genericIconComponent";
import ShadTooltip from "../../../../../../../../../../components/common/shadTooltipComponent";
import {
@ -12,13 +13,11 @@ import {
interface LanguageSelectProps {
language: string;
handleSetLanguage: (value: string) => void;
allLanguages: { value: string; name: string }[];
}
const LanguageSelect = ({
language,
handleSetLanguage,
allLanguages,
}: LanguageSelectProps) => {
return (
<div className="grid w-full items-center gap-2">
@ -41,7 +40,7 @@ const LanguageSelect = ({
</SelectTrigger>
<SelectContent className="max-h-[200px]">
<SelectGroup>
{allLanguages.map((lang) => (
{ALL_LANGUAGES?.map((lang) => (
<SelectItem key={lang?.value} value={lang?.value}>
<div className="max-w-[220px] truncate text-left">
{lang?.name}

View file

@ -47,9 +47,9 @@ export function toNormalCase(str: string): string {
.split("_")
.map((word, index) => {
if (index === 0) {
return word[0].toUpperCase() + word.slice(1).toLowerCase();
return word[0]?.toUpperCase() + word.slice(1)?.toLowerCase();
}
return word.toLowerCase();
return word?.toLowerCase();
})
.join(" ");
@ -57,9 +57,9 @@ export function toNormalCase(str: string): string {
.split("-")
.map((word, index) => {
if (index === 0) {
return word[0].toUpperCase() + word.slice(1).toLowerCase();
return word[0]?.toUpperCase() + word.slice(1)?.toLowerCase();
}
return word.toLowerCase();
return word?.toLowerCase();
})
.join(" ");
}
@ -69,11 +69,11 @@ export function normalCaseToSnakeCase(str: string): string {
.split(" ")
.map((word, index) => {
if (index === 0) {
return word[0].toUpperCase() + word.slice(1).toLowerCase();
return word[0]?.toUpperCase() + word.slice(1)?.toLowerCase();
}
return word.toLowerCase();
return word?.toLowerCase();
})
.join("_");
?.join("_");
}
export function toTitleCase(
@ -82,41 +82,41 @@ export function toTitleCase(
): string {
if (!str) return "";
let result = str
.split("_")
.map((word, index) => {
?.split("_")
?.map((word, index) => {
if (isNodeField) return word;
if (index === 0) {
return checkUpperWords(
word[0].toUpperCase() + word.slice(1).toLowerCase(),
word[0]?.toUpperCase() + word.slice(1)?.toLowerCase(),
);
}
return checkUpperWords(word.toLowerCase());
return checkUpperWords(word?.toLowerCase());
})
.join(" ");
return result
.split("-")
.map((word, index) => {
?.split("-")
?.map((word, index) => {
if (isNodeField) return word;
if (index === 0) {
return checkUpperWords(
word[0].toUpperCase() + word.slice(1).toLowerCase(),
word[0]?.toUpperCase() + word.slice(1)?.toLowerCase(),
);
}
return checkUpperWords(word.toLowerCase());
return checkUpperWords(word?.toLowerCase());
})
.join(" ");
?.join(" ");
}
export const upperCaseWords: string[] = ["llm", "uri"];
export function checkUpperWords(str: string): string {
const words = str.split(" ").map((word) => {
return upperCaseWords.includes(word.toLowerCase())
? word.toUpperCase()
: word[0].toUpperCase() + word.slice(1).toLowerCase();
const words = str?.split(" ")?.map((word) => {
return upperCaseWords.includes(word?.toLowerCase())
? word?.toUpperCase()
: word[0]?.toUpperCase() + word.slice(1)?.toLowerCase();
});
return words.join(" ");
return words?.join(" ");
}
export function buildInputs(): string {

View file

@ -215,6 +215,8 @@ test(
).isVisible(),
);
await page.waitForTimeout(2000);
await awaitBootstrapTest(page, { skipGoto: true });
await page.getByTestId("side_nav_options_all-templates").click();

View file

@ -73,9 +73,11 @@ test(
}
}
await page.waitForTimeout(1000);
await page.waitForTimeout(500);
await visibleElementHandle.hover().then(async () => {
await page.waitForTimeout(1000);
await expect(
page.getByText("Drag to connect compatible outputs").first(),
).toBeVisible();
@ -105,7 +107,11 @@ test(
}
}
await page.waitForTimeout(500);
await visibleElementHandle.hover().then(async () => {
await page.waitForTimeout(1000);
await expect(
page.getByText("Drag to connect compatible outputs").first(),
).toBeVisible();

View file

@ -164,13 +164,21 @@ test(
await page.getByTestId("chat-message-User-session_after_delete").click();
await expect(page.getByTestId("session-selector")).toBeVisible();
await page.waitForTimeout(500);
// check helpful button
await page.getByTestId("chat-message-AI-session_after_delete").hover();
await page.getByTestId("helpful-button").click();
await page.waitForTimeout(500);
await page.getByTestId("chat-message-AI-session_after_delete").hover();
await expect(page.getByTestId("icon-ThumbUpIconCustom")).toBeVisible({
timeout: 10000,
});
await page.waitForTimeout(500);
await page.getByTestId("helpful-button").click();
await page.getByTestId("chat-message-AI-session_after_delete").hover();
await expect(page.getByTestId("icon-ThumbUpIconCustom")).toBeVisible({
@ -178,26 +186,38 @@ test(
visible: false,
});
// check not helpful button
await page.waitForTimeout(500);
await page.getByTestId("chat-message-AI-session_after_delete").hover();
await page.getByTestId("not-helpful-button").click();
await page.waitForTimeout(500);
await page.getByTestId("chat-message-AI-session_after_delete").hover();
await expect(page.getByTestId("icon-ThumbDownIconCustom")).toBeVisible({
timeout: 10000,
});
await page.getByTestId("not-helpful-button").click();
await page.waitForTimeout(500);
await page.getByTestId("chat-message-AI-session_after_delete").hover();
await expect(page.getByTestId("icon-ThumbDownIconCustom")).toBeVisible({
timeout: 10000,
visible: false,
});
// check switch feedback
await page.waitForTimeout(500);
await page.getByTestId("chat-message-AI-session_after_delete").hover();
await page.getByTestId("helpful-button").click();
await page.waitForTimeout(500);
await page.getByTestId("chat-message-AI-session_after_delete").hover();
await expect(page.getByTestId("icon-ThumbUpIconCustom")).toBeVisible({
timeout: 10000,
});
await page.getByTestId("not-helpful-button").click();
await page.waitForTimeout(500);
await page.getByTestId("chat-message-AI-session_after_delete").hover();
await expect(page.getByTestId("icon-ThumbDownIconCustom")).toBeVisible({
timeout: 10000,

View file

@ -0,0 +1,76 @@
import { expect, test } from "@playwright/test";
import { awaitBootstrapTest } from "../../utils/await-bootstrap-test";

// Regression test: a user can open a tool-mode component's "edit tools"
// dialog, toggle a tool off, save, and the choice persists on reopen.
test(
  "user should be able to edit tools",
  { tag: ["@release"] },
  async ({ page }) => {
    await awaitBootstrapTest(page);

    // Start from a blank flow and add the "API Request" component.
    await page.getByTestId("blank-flow").click();
    await page.getByTestId("sidebar-search-input").click();
    await page.getByTestId("sidebar-search-input").fill("api request");
    await page.waitForSelector('[data-testid="dataAPI Request"]', {
      timeout: 3000,
    });
    await page
      .getByTestId("dataAPI Request")
      .hover()
      .then(async () => {
        await page.getByTestId("add-component-button-api-request").click();
      });
    await page.waitForSelector(
      '[data-testid="generic-node-title-arrangement"]',
      {
        timeout: 3000,
      },
    );
    await page.getByTestId("generic-node-title-arrangement").click();
    await page.waitForTimeout(500);

    // Switch the node into tool mode, then open the edit-tools dialog
    // via the hammer icon.
    await page.getByTestId("tool-mode-button").click();
    await page.locator('[data-testid="icon-Hammer"]').nth(1).waitFor({
      timeout: 3000,
      state: "visible",
    });
    await page.getByTestId("icon-Hammer").nth(1).click();
    await page.waitForSelector("text=edit tools", { timeout: 30000 });

    // The grid should list several tools; toggle the first one off and save.
    const rowsCount = await page.getByRole("gridcell").count();
    expect(rowsCount).toBeGreaterThan(3);
    expect(await page.getByRole("switch").nth(0).isChecked()).toBe(true);
    await page.getByRole("switch").nth(0).click();
    expect(await page.getByRole("switch").nth(0).isChecked()).toBe(false);
    await page.getByText("Save").last().click();

    // Reopen the dialog and verify the toggled-off state persisted.
    await page.waitForSelector(
      '[data-testid="generic-node-title-arrangement"]',
      {
        timeout: 3000,
      },
    );
    await page.waitForTimeout(500);
    await page.getByTestId("icon-Hammer").nth(1).click();
    await page.waitForSelector("text=edit tools", { timeout: 30000 });
    await page.waitForTimeout(500);
    expect(await page.getByRole("switch").nth(0).isChecked()).toBe(false);
  },
);