chore: merge mcp components (#7167)
* take1 * depreacate stdio and sse mcp components * optionals * rodrigo fixes * session management * update init * mcp component integration test * broken * [autofix.ci] apply automated fixes * fix url input name * upated MCP * Update mcp_component.py * [autofix.ci] apply automated fixes * update to the MCP component * [autofix.ci] apply automated fixes * mostly working * [autofix.ci] apply automated fixes * Update mcp_component.py * [autofix.ci] apply automated fixes * update component * [autofix.ci] apply automated fixes * Update mcp_component.py * rename component because Simon * icon and description for simon * fix integration test * fix test * Update mcp_component.py * update and basic QoL * [autofix.ci] apply automated fixes * refactor clients to util and use flow names not IDs in mcp.py * integration test * take out traces * ✨ (edit-tools.spec.ts): add test for user to be able to edit tools in the frontend application. * session fix * fix content output * ♻️ (util.py): remove redundant constant HTTP_TEMPORARY_REDIRECT and replace its usage with httpx.codes.TEMPORARY_REDIRECT for better code readability and maintainability * [autofix.ci] apply automated fixes * 🐛 (utils.ts): fix potential null pointer error when converting words to title case by adding null check before accessing properties * 🐛 (genericIconComponent/index.tsx): Fix issue with optional chaining in mapping function 🐛 (renderIconComponent/index.tsx): Fix issue with optional chaining in mapping function 🐛 (button.tsx): Fix issue with optional chaining in mapping function 🐛 (utils.ts): Fix issue with optional chaining in mapping functions * 🐛 (language-select.tsx): Fix potential null pointer error when mapping over allLanguages array * ✨ (constants.ts): add support for multiple languages in the application by defining an array of language options ♻️ (audio-settings-dialog.tsx, language-select.tsx): refactor to import the array of all languages from constants.ts instead of duplicating it in each file * ✅ (auto-login-off.spec.ts): add a 2-second delay before continuing the test to ensure proper loading and rendering of elements on the page * ⬆️ (filterEdge-shard-0.spec.ts): reduce wait time for page interactions to improve test performance ⬆️ (playground.spec.ts): optimize wait times for page interactions to enhance test efficiency --------- Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com> Co-authored-by: Edwin Jose <edwin.jose@datastax.com> Co-authored-by: cristhianzl <cristhian.lousa@gmail.com>
This commit is contained in:
parent
4527c473be
commit
59b2ed7765
25 changed files with 1200 additions and 331 deletions
|
|
@ -2,11 +2,12 @@ import asyncio
|
|||
import base64
|
||||
import json
|
||||
import logging
|
||||
import traceback
|
||||
from collections.abc import Awaitable, Callable
|
||||
from contextvars import ContextVar
|
||||
from typing import Annotated
|
||||
from functools import wraps
|
||||
from typing import Annotated, Any, ParamSpec, TypeVar
|
||||
from urllib.parse import quote, unquote, urlparse
|
||||
from uuid import UUID, uuid4
|
||||
from uuid import uuid4
|
||||
|
||||
import pydantic
|
||||
from anyio import BrokenResourceError
|
||||
|
|
@ -20,34 +21,43 @@ from starlette.background import BackgroundTasks
|
|||
|
||||
from langflow.api.v1.chat import build_flow_and_stream
|
||||
from langflow.api.v1.schemas import InputValueRequest
|
||||
from langflow.base.mcp.util import get_flow
|
||||
from langflow.helpers.flow import json_schema_from_flow
|
||||
from langflow.services.auth.utils import get_current_active_user
|
||||
from langflow.services.database.models import Flow, User
|
||||
from langflow.services.deps import (
|
||||
get_db_service,
|
||||
get_session,
|
||||
get_settings_service,
|
||||
get_storage_service,
|
||||
session_scope,
|
||||
)
|
||||
from langflow.services.storage.utils import build_content_type_from_extension
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
if False:
|
||||
logger.setLevel(logging.DEBUG)
|
||||
if not logger.handlers:
|
||||
handler = logging.StreamHandler()
|
||||
handler.setLevel(logging.DEBUG)
|
||||
formatter = logging.Formatter("[%(asctime)s][%(levelname)s] %(message)s")
|
||||
handler.setFormatter(formatter)
|
||||
logger.addHandler(handler)
|
||||
|
||||
# Enable debug logging for MCP package
|
||||
mcp_logger = logging.getLogger("mcp")
|
||||
mcp_logger.setLevel(logging.DEBUG)
|
||||
if not mcp_logger.handlers:
|
||||
mcp_logger.addHandler(handler)
|
||||
T = TypeVar("T")
|
||||
P = ParamSpec("P")
|
||||
|
||||
logger.debug("MCP module loaded - debug logging enabled")
|
||||
|
||||
def handle_mcp_errors(func: Callable[P, Awaitable[T]]) -> Callable[P, Awaitable[T]]:
|
||||
"""Decorator to handle MCP endpoint errors consistently."""
|
||||
|
||||
@wraps(func)
|
||||
async def wrapper(*args: P.args, **kwargs: P.kwargs) -> T:
|
||||
try:
|
||||
return await func(*args, **kwargs)
|
||||
except Exception as e:
|
||||
msg = f"Error in {func.__name__}: {e!s}"
|
||||
logger.exception(msg)
|
||||
raise
|
||||
|
||||
return wrapper
|
||||
|
||||
|
||||
async def with_db_session(operation: Callable[[Any], Awaitable[T]]) -> T:
|
||||
"""Execute an operation within a database session context."""
|
||||
async with session_scope() as session:
|
||||
return await operation(session)
|
||||
|
||||
|
||||
class MCPConfig:
|
||||
|
|
@ -88,7 +98,7 @@ async def handle_list_prompts():
|
|||
async def handle_list_resources():
|
||||
resources = []
|
||||
try:
|
||||
session = await anext(get_session())
|
||||
db_service = get_db_service()
|
||||
storage_service = get_storage_service()
|
||||
settings_service = get_settings_service()
|
||||
|
||||
|
|
@ -98,31 +108,30 @@ async def handle_list_resources():
|
|||
|
||||
base_url = f"http://{host}:{port}".rstrip("/")
|
||||
|
||||
flows = (await session.exec(select(Flow))).all()
|
||||
async with db_service.with_session() as session:
|
||||
flows = (await session.exec(select(Flow))).all()
|
||||
|
||||
for flow in flows:
|
||||
if flow.id:
|
||||
try:
|
||||
files = await storage_service.list_files(flow_id=str(flow.id))
|
||||
for file_name in files:
|
||||
# URL encode the filename
|
||||
safe_filename = quote(file_name)
|
||||
resource = types.Resource(
|
||||
uri=f"{base_url}/api/v1/files/{flow.id}/{safe_filename}",
|
||||
name=file_name,
|
||||
description=f"File in flow: {flow.name}",
|
||||
mimeType=build_content_type_from_extension(file_name),
|
||||
)
|
||||
resources.append(resource)
|
||||
except FileNotFoundError as e:
|
||||
msg = f"Error listing files for flow {flow.id}: {e}"
|
||||
logger.debug(msg)
|
||||
continue
|
||||
for flow in flows:
|
||||
if flow.id:
|
||||
try:
|
||||
files = await storage_service.list_files(flow_id=str(flow.id))
|
||||
for file_name in files:
|
||||
# URL encode the filename
|
||||
safe_filename = quote(file_name)
|
||||
resource = types.Resource(
|
||||
uri=f"{base_url}/api/v1/files/{flow.id}/{safe_filename}",
|
||||
name=file_name,
|
||||
description=f"File in flow: {flow.name}",
|
||||
mimeType=build_content_type_from_extension(file_name),
|
||||
)
|
||||
resources.append(resource)
|
||||
except FileNotFoundError as e:
|
||||
msg = f"Error listing files for flow {flow.id}: {e}"
|
||||
logger.debug(msg)
|
||||
continue
|
||||
except Exception as e:
|
||||
msg = f"Error in listing resources: {e!s}"
|
||||
logger.exception(msg)
|
||||
trace = traceback.format_exc()
|
||||
logger.exception(trace)
|
||||
raise
|
||||
return resources
|
||||
|
||||
|
|
@ -162,8 +171,6 @@ async def handle_read_resource(uri: str) -> bytes:
|
|||
except Exception as e:
|
||||
msg = f"Error reading resource {uri}: {e!s}"
|
||||
logger.exception(msg)
|
||||
trace = traceback.format_exc()
|
||||
logger.exception(trace)
|
||||
raise
|
||||
|
||||
|
||||
|
|
@ -171,47 +178,48 @@ async def handle_read_resource(uri: str) -> bytes:
|
|||
async def handle_list_tools():
|
||||
tools = []
|
||||
try:
|
||||
session = await anext(get_session())
|
||||
flows = (await session.exec(select(Flow))).all()
|
||||
db_service = get_db_service()
|
||||
async with db_service.with_session() as session:
|
||||
flows = (await session.exec(select(Flow))).all()
|
||||
|
||||
for flow in flows:
|
||||
if flow.user_id is None:
|
||||
continue
|
||||
for flow in flows:
|
||||
if flow.user_id is None:
|
||||
continue
|
||||
|
||||
tool = types.Tool(
|
||||
name=str(flow.id), # Use flow.id instead of name
|
||||
description=f"{flow.name}: {flow.description}"
|
||||
if flow.description
|
||||
else f"Tool generated from flow: {flow.name}",
|
||||
inputSchema=json_schema_from_flow(flow),
|
||||
)
|
||||
tools.append(tool)
|
||||
tool = types.Tool(
|
||||
name=flow.name,
|
||||
description=f"{flow.id}: {flow.description}"
|
||||
if flow.description
|
||||
else f"Tool generated from flow: {flow.name}",
|
||||
inputSchema=json_schema_from_flow(flow),
|
||||
)
|
||||
tools.append(tool)
|
||||
except Exception as e:
|
||||
msg = f"Error in listing tools: {e!s}"
|
||||
logger.exception(msg)
|
||||
trace = traceback.format_exc()
|
||||
logger.exception(trace)
|
||||
raise
|
||||
return tools
|
||||
|
||||
|
||||
@server.call_tool()
|
||||
@handle_mcp_errors
|
||||
async def handle_call_tool(name: str, arguments: dict) -> list[types.TextContent]:
|
||||
"""Handle tool execution requests."""
|
||||
mcp_config = get_mcp_config()
|
||||
if mcp_config.enable_progress_notifications is None:
|
||||
settings_service = get_settings_service()
|
||||
mcp_config.enable_progress_notifications = settings_service.settings.mcp_server_enable_progress_notifications
|
||||
try:
|
||||
session = await anext(get_session())
|
||||
background_tasks = BackgroundTasks()
|
||||
|
||||
current_user = current_user_ctx.get()
|
||||
flow = (await session.exec(select(Flow).where(Flow.id == UUID(name)))).first()
|
||||
background_tasks = BackgroundTasks()
|
||||
current_user = current_user_ctx.get()
|
||||
|
||||
async def execute_tool(session):
|
||||
# get flow id from name
|
||||
flow = await get_flow(name, current_user.id, session)
|
||||
if not flow:
|
||||
msg = f"Flow with id '{name}' not found"
|
||||
msg = f"Flow with name '{name}' not found"
|
||||
raise ValueError(msg)
|
||||
flow_id = flow.id
|
||||
|
||||
# Process inputs
|
||||
processed_inputs = dict(arguments)
|
||||
|
|
@ -240,70 +248,66 @@ async def handle_call_tool(name: str, arguments: dict) -> list[types.TextContent
|
|||
progress += 0.1
|
||||
await asyncio.sleep(1.0)
|
||||
except asyncio.CancelledError:
|
||||
# Send final 100% progress
|
||||
if mcp_config.enable_progress_notifications:
|
||||
await server.request_context.session.send_progress_notification(
|
||||
progress_token=progress_token, progress=1.0, total=1.0
|
||||
)
|
||||
raise
|
||||
|
||||
db_service = get_db_service()
|
||||
collected_results = []
|
||||
async with db_service.with_session():
|
||||
try:
|
||||
progress_task = asyncio.create_task(send_progress_updates())
|
||||
|
||||
try:
|
||||
progress_task = asyncio.create_task(send_progress_updates())
|
||||
response = await build_flow_and_stream(
|
||||
flow_id=flow_id,
|
||||
inputs=input_request,
|
||||
background_tasks=background_tasks,
|
||||
current_user=current_user,
|
||||
)
|
||||
|
||||
try:
|
||||
response = await build_flow_and_stream(
|
||||
flow_id=UUID(name),
|
||||
inputs=input_request,
|
||||
background_tasks=background_tasks,
|
||||
current_user=current_user,
|
||||
)
|
||||
async for line in response.body_iterator:
|
||||
if not line:
|
||||
continue
|
||||
try:
|
||||
event_data = json.loads(line)
|
||||
if event_data.get("event") == "end_vertex":
|
||||
message = (
|
||||
event_data.get("data", {})
|
||||
.get("build_data", {})
|
||||
.get("data", {})
|
||||
.get("results", {})
|
||||
.get("message", {})
|
||||
.get("text", "")
|
||||
)
|
||||
if message:
|
||||
collected_results.append(types.TextContent(type="text", text=str(message)))
|
||||
except json.JSONDecodeError:
|
||||
msg = f"Failed to parse event data: {line}"
|
||||
logger.warning(msg)
|
||||
continue
|
||||
|
||||
async for line in response.body_iterator:
|
||||
if not line:
|
||||
continue
|
||||
try:
|
||||
event_data = json.loads(line)
|
||||
if event_data.get("event") == "end_vertex":
|
||||
message = (
|
||||
event_data.get("data", {})
|
||||
.get("build_data", {})
|
||||
.get("data", {})
|
||||
.get("results", {})
|
||||
.get("message", {})
|
||||
.get("text", "")
|
||||
)
|
||||
if message:
|
||||
collected_results.append(types.TextContent(type="text", text=str(message)))
|
||||
except json.JSONDecodeError:
|
||||
msg = f"Failed to parse event data: {line}"
|
||||
logger.warning(msg)
|
||||
continue
|
||||
return collected_results
|
||||
finally:
|
||||
progress_task.cancel()
|
||||
await asyncio.wait([progress_task])
|
||||
if not progress_task.cancelled() and (exc := progress_task.exception()) is not None:
|
||||
raise exc
|
||||
|
||||
return collected_results
|
||||
finally:
|
||||
progress_task.cancel()
|
||||
await asyncio.wait([progress_task])
|
||||
if not progress_task.cancelled() and (exc := progress_task.exception()) is not None:
|
||||
raise exc
|
||||
except Exception as e:
|
||||
msg = f"Error in async session: {e}"
|
||||
logger.exception(msg)
|
||||
raise
|
||||
except Exception:
|
||||
if mcp_config.enable_progress_notifications and (
|
||||
progress_token := server.request_context.meta.progressToken
|
||||
):
|
||||
await server.request_context.session.send_progress_notification(
|
||||
progress_token=progress_token, progress=1.0, total=1.0
|
||||
)
|
||||
raise
|
||||
|
||||
try:
|
||||
return await with_db_session(execute_tool)
|
||||
except Exception as e:
|
||||
context = server.request_context
|
||||
# Send error progress if there's an exception
|
||||
if mcp_config.enable_progress_notifications and (progress_token := context.meta.progressToken):
|
||||
await server.request_context.session.send_progress_notification(
|
||||
progress_token=progress_token, progress=1.0, total=1.0
|
||||
)
|
||||
msg = f"Error executing tool {name}: {e!s}"
|
||||
logger.exception(msg)
|
||||
trace = traceback.format_exc()
|
||||
logger.exception(trace)
|
||||
raise
|
||||
|
||||
|
||||
|
|
@ -357,8 +361,6 @@ async def handle_sse(request: Request, current_user: Annotated[User, Depends(get
|
|||
except Exception as e:
|
||||
msg = f"Error in MCP: {e!s}"
|
||||
logger.exception(msg)
|
||||
trace = traceback.format_exc()
|
||||
logger.exception(trace)
|
||||
raise
|
||||
finally:
|
||||
current_user_ctx.reset(token)
|
||||
|
|
|
|||
|
|
@ -17,7 +17,7 @@ from langchain_core.tools import BaseTool
|
|||
from pydantic import BaseModel
|
||||
from requests.exceptions import RequestException
|
||||
|
||||
from langflow.components.tools.mcp_stdio import create_input_schema_from_json_schema
|
||||
from langflow.base.mcp.util import create_input_schema_from_json_schema
|
||||
from langflow.services.cache.utils import CacheMiss
|
||||
|
||||
client_lock = threading.Lock()
|
||||
|
|
|
|||
|
|
@ -1,40 +1,81 @@
|
|||
import asyncio
|
||||
import os
|
||||
from collections.abc import Awaitable, Callable
|
||||
from contextlib import AsyncExitStack
|
||||
from typing import Any
|
||||
from uuid import UUID
|
||||
|
||||
import httpx
|
||||
from mcp import ClientSession, StdioServerParameters, stdio_client
|
||||
from mcp.client.sse import sse_client
|
||||
from pydantic import Field, create_model
|
||||
from sqlmodel import select
|
||||
|
||||
from langflow.helpers.base_model import BaseModel
|
||||
from langflow.services.database.models import Flow
|
||||
|
||||
|
||||
def create_tool_coroutine(tool_name: str, arg_schema: type[BaseModel], session) -> Callable[[dict], Awaitable]:
|
||||
def create_tool_coroutine(tool_name: str, arg_schema: type[BaseModel], session) -> Callable[..., Awaitable]:
|
||||
async def tool_coroutine(*args, **kwargs):
|
||||
fields = arg_schema.model_fields.keys()
|
||||
expected_field_count = len(fields)
|
||||
if len(args) + len(kwargs) != expected_field_count:
|
||||
msg = f"{expected_field_count} arguments are required. Received: {args} {kwargs}"
|
||||
raise ValueError(msg)
|
||||
arg_dict = dict(zip(fields, args, strict=False))
|
||||
arg_dict.update(kwargs)
|
||||
return await session.call_tool(tool_name, arguments=arg_dict)
|
||||
# Get field names from the model (preserving order)
|
||||
field_names = list(arg_schema.__fields__.keys())
|
||||
provided_args = {}
|
||||
# Map positional arguments to their corresponding field names
|
||||
for i, arg in enumerate(args):
|
||||
if i >= len(field_names):
|
||||
msg = "Too many positional arguments provided"
|
||||
raise ValueError(msg)
|
||||
provided_args[field_names[i]] = arg
|
||||
# Merge in keyword arguments
|
||||
provided_args.update(kwargs)
|
||||
# Validate input and fill defaults for missing optional fields
|
||||
try:
|
||||
validated = arg_schema.parse_obj(provided_args)
|
||||
except Exception as e:
|
||||
msg = f"Invalid input: {e}"
|
||||
raise ValueError(msg) from e
|
||||
return await session.call_tool(tool_name, arguments=validated.dict())
|
||||
|
||||
return tool_coroutine
|
||||
|
||||
|
||||
def create_tool_func(tool_name: str, session) -> Callable[..., str]:
|
||||
def tool_func(**kwargs):
|
||||
if len(kwargs) == 0:
|
||||
msg = f"at least one named argument is required {kwargs}"
|
||||
raise ValueError(msg)
|
||||
def create_tool_func(tool_name: str, arg_schema: type[BaseModel], session) -> Callable[..., str]:
|
||||
def tool_func(*args, **kwargs):
|
||||
field_names = list(arg_schema.__fields__.keys())
|
||||
provided_args = {}
|
||||
for i, arg in enumerate(args):
|
||||
if i >= len(field_names):
|
||||
msg = "Too many positional arguments provided"
|
||||
raise ValueError(msg)
|
||||
provided_args[field_names[i]] = arg
|
||||
provided_args.update(kwargs)
|
||||
try:
|
||||
validated = arg_schema.parse_obj(provided_args)
|
||||
except Exception as e:
|
||||
msg = f"Invalid input: {e}"
|
||||
raise ValueError(msg) from e
|
||||
loop = asyncio.get_event_loop()
|
||||
return loop.run_until_complete(session.call_tool(tool_name, arguments=kwargs))
|
||||
return loop.run_until_complete(session.call_tool(tool_name, arguments=validated.dict()))
|
||||
|
||||
return tool_func
|
||||
|
||||
|
||||
async def get_flow(flow_name: str, user_id: str, session) -> Flow | None:
|
||||
uuid_user_id = UUID(user_id) if isinstance(user_id, str) else user_id
|
||||
stmt = select(Flow).where(Flow.user_id == uuid_user_id).where(Flow.is_component == False) # noqa: E712
|
||||
flows = (await session.exec(stmt)).all()
|
||||
|
||||
for flow in flows:
|
||||
if flow.to_data().name == flow_name:
|
||||
return flow
|
||||
return None
|
||||
|
||||
|
||||
def create_input_schema_from_json_schema(schema: dict[str, Any]) -> type[BaseModel]:
|
||||
"""Converts a JSON schema into a Pydantic model dynamically.
|
||||
|
||||
Fields not listed as required are wrapped in Optional[...] and default to None if not provided.
|
||||
|
||||
:param schema: The JSON schema as a dictionary.
|
||||
:return: A Pydantic model class.
|
||||
"""
|
||||
|
|
@ -47,9 +88,9 @@ def create_input_schema_from_json_schema(schema: dict[str, Any]) -> type[BaseMod
|
|||
required_fields = set(schema.get("required", []))
|
||||
|
||||
for field_name, field_def in properties.items():
|
||||
# Extract type
|
||||
field_type_str = field_def.get("type", "str") # Default to string type if not specified
|
||||
field_type = {
|
||||
# Determine the base type from the JSON schema type string.
|
||||
field_type_str = field_def.get("type", "str") # Defaults to string if not specified.
|
||||
base_type = {
|
||||
"string": str,
|
||||
"str": str,
|
||||
"integer": int,
|
||||
|
|
@ -60,13 +101,77 @@ def create_input_schema_from_json_schema(schema: dict[str, Any]) -> type[BaseMod
|
|||
"object": dict,
|
||||
}.get(field_type_str, Any)
|
||||
|
||||
# Extract description and default if present
|
||||
field_metadata = {"description": field_def.get("description", "")}
|
||||
|
||||
# For non-required fields, wrap the type in Optional[...] and set a default value.
|
||||
if field_name not in required_fields:
|
||||
field_metadata["default"] = field_def.get("default", None)
|
||||
|
||||
# Create Pydantic field
|
||||
fields[field_name] = (field_type, Field(**field_metadata))
|
||||
fields[field_name] = (base_type, Field(**field_metadata))
|
||||
|
||||
# Dynamically create the model
|
||||
return create_model("InputSchema", **fields)
|
||||
|
||||
|
||||
class MCPStdioClient:
|
||||
def __init__(self):
|
||||
self.session: ClientSession | None = None
|
||||
self.exit_stack = AsyncExitStack()
|
||||
|
||||
async def connect_to_server(self, command_str: str):
|
||||
command = command_str.split(" ")
|
||||
server_params = StdioServerParameters(
|
||||
command=command[0],
|
||||
args=command[1:],
|
||||
env={"DEBUG": "true", "PATH": os.environ["PATH"]},
|
||||
)
|
||||
stdio_transport = await self.exit_stack.enter_async_context(stdio_client(server_params))
|
||||
self.stdio, self.write = stdio_transport
|
||||
self.session = await self.exit_stack.enter_async_context(ClientSession(self.stdio, self.write))
|
||||
await self.session.initialize()
|
||||
response = await self.session.list_tools()
|
||||
return response.tools
|
||||
|
||||
|
||||
class MCPSseClient:
|
||||
def __init__(self):
|
||||
self.write = None
|
||||
self.sse = None
|
||||
self.session: ClientSession | None = None
|
||||
self.exit_stack = AsyncExitStack()
|
||||
|
||||
async def pre_check_redirect(self, url: str):
|
||||
async with httpx.AsyncClient(follow_redirects=False) as client:
|
||||
response = await client.request("HEAD", url)
|
||||
if response.status_code == httpx.codes.TEMPORARY_REDIRECT:
|
||||
return response.headers.get("Location")
|
||||
return url
|
||||
|
||||
async def _connect_with_timeout(
|
||||
self, url: str, headers: dict[str, str] | None, timeout_seconds: int, sse_read_timeout_seconds: int
|
||||
):
|
||||
sse_transport = await self.exit_stack.enter_async_context(
|
||||
sse_client(url, headers, timeout_seconds, sse_read_timeout_seconds)
|
||||
)
|
||||
self.sse, self.write = sse_transport
|
||||
self.session = await self.exit_stack.enter_async_context(ClientSession(self.sse, self.write))
|
||||
await self.session.initialize()
|
||||
|
||||
async def connect_to_server(
|
||||
self, url: str, headers: dict[str, str] | None, timeout_seconds: int = 500, sse_read_timeout_seconds: int = 500
|
||||
):
|
||||
if headers is None:
|
||||
headers = {}
|
||||
url = await self.pre_check_redirect(url)
|
||||
try:
|
||||
await asyncio.wait_for(
|
||||
self._connect_with_timeout(url, headers, timeout_seconds, sse_read_timeout_seconds),
|
||||
timeout=timeout_seconds,
|
||||
)
|
||||
if self.session is None:
|
||||
msg = "Session not initialized"
|
||||
raise ValueError(msg)
|
||||
response = await self.session.list_tools()
|
||||
except asyncio.TimeoutError as err:
|
||||
msg = f"Connection to {url} timed out after {timeout_seconds} seconds"
|
||||
raise TimeoutError(msg) from err
|
||||
return response.tools
|
||||
|
|
|
|||
61
src/backend/base/langflow/components/deactivated/mcp_sse.py
Normal file
61
src/backend/base/langflow/components/deactivated/mcp_sse.py
Normal file
|
|
@ -0,0 +1,61 @@
|
|||
# from langflow.field_typing import Data
|
||||
|
||||
from langchain_core.tools import StructuredTool
|
||||
from mcp import types
|
||||
|
||||
from langflow.base.mcp.util import (
|
||||
MCPSseClient,
|
||||
create_input_schema_from_json_schema,
|
||||
create_tool_coroutine,
|
||||
create_tool_func,
|
||||
)
|
||||
from langflow.custom import Component
|
||||
from langflow.field_typing import Tool
|
||||
from langflow.io import MessageTextInput, Output
|
||||
|
||||
|
||||
class MCPSse(Component):
|
||||
client = MCPSseClient()
|
||||
tools = types.ListToolsResult
|
||||
tool_names = [str]
|
||||
display_name = "MCP Tools (SSE) [DEPRECATED]"
|
||||
description = "Connects to an MCP server over SSE and exposes it's tools as langflow tools to be used by an Agent."
|
||||
documentation: str = "https://docs.langflow.org/components-custom-components"
|
||||
icon = "code"
|
||||
name = "MCPSse"
|
||||
legacy = True
|
||||
|
||||
inputs = [
|
||||
MessageTextInput(
|
||||
name="url",
|
||||
display_name="mcp sse url",
|
||||
info="sse url",
|
||||
value="http://localhost:7860/api/v1/mcp/sse",
|
||||
tool_mode=True,
|
||||
),
|
||||
]
|
||||
|
||||
outputs = [
|
||||
Output(display_name="Tools", name="tools", method="build_output"),
|
||||
]
|
||||
|
||||
async def build_output(self) -> list[Tool]:
|
||||
if self.client.session is None:
|
||||
self.tools = await self.client.connect_to_server(self.url, {})
|
||||
|
||||
tool_list = []
|
||||
|
||||
for tool in self.tools:
|
||||
args_schema = create_input_schema_from_json_schema(tool.inputSchema)
|
||||
tool_list.append(
|
||||
StructuredTool(
|
||||
name=tool.name, # maybe format this
|
||||
description=tool.description,
|
||||
args_schema=args_schema,
|
||||
func=create_tool_func(tool.name, args_schema, self.client.session),
|
||||
coroutine=create_tool_coroutine(tool.name, args_schema, self.client.session),
|
||||
)
|
||||
)
|
||||
|
||||
self.tool_names = [tool.name for tool in self.tools]
|
||||
return tool_list
|
||||
|
|
@ -1,51 +1,31 @@
|
|||
# from langflow.field_typing import Data
|
||||
import os
|
||||
from contextlib import AsyncExitStack
|
||||
|
||||
from langchain_core.tools import StructuredTool
|
||||
from mcp import ClientSession, StdioServerParameters, types
|
||||
from mcp.client.stdio import stdio_client
|
||||
from mcp import types
|
||||
|
||||
from langflow.base.mcp.util import create_input_schema_from_json_schema, create_tool_coroutine, create_tool_func
|
||||
from langflow.base.mcp.util import (
|
||||
MCPStdioClient,
|
||||
create_input_schema_from_json_schema,
|
||||
create_tool_coroutine,
|
||||
create_tool_func,
|
||||
)
|
||||
from langflow.custom import Component
|
||||
from langflow.field_typing import Tool
|
||||
from langflow.io import MessageTextInput, Output
|
||||
|
||||
|
||||
class MCPStdioClient:
|
||||
def __init__(self):
|
||||
# Initialize session and client objects
|
||||
self.session: ClientSession | None = None
|
||||
self.exit_stack = AsyncExitStack()
|
||||
|
||||
async def connect_to_server(self, command_str: str):
|
||||
command = command_str.split(" ")
|
||||
server_params = StdioServerParameters(
|
||||
command=command[0], args=command[1:], env={"DEBUG": "true", "PATH": os.environ["PATH"]}
|
||||
)
|
||||
|
||||
stdio_transport = await self.exit_stack.enter_async_context(stdio_client(server_params))
|
||||
self.stdio, self.write = stdio_transport
|
||||
self.session = await self.exit_stack.enter_async_context(ClientSession(self.stdio, self.write))
|
||||
|
||||
await self.session.initialize()
|
||||
|
||||
# List available tools
|
||||
response = await self.session.list_tools()
|
||||
return response.tools
|
||||
|
||||
|
||||
class MCPStdio(Component):
|
||||
client = MCPStdioClient()
|
||||
tools = types.ListToolsResult
|
||||
tool_names = [str]
|
||||
display_name = "MCP Tools (stdio)"
|
||||
display_name = "MCP Tools (stdio) [DEPRECATED]"
|
||||
description = (
|
||||
"Connects to an MCP server over stdio and exposes it's tools as langflow tools to be used by an Agent."
|
||||
)
|
||||
documentation: str = "https://docs.langflow.org/components-custom-components"
|
||||
icon = "code"
|
||||
name = "MCPStdio"
|
||||
legacy = True
|
||||
|
||||
inputs = [
|
||||
MessageTextInput(
|
||||
|
|
@ -74,7 +54,7 @@ class MCPStdio(Component):
|
|||
name=tool.name,
|
||||
description=tool.description,
|
||||
args_schema=args_schema,
|
||||
func=create_tool_func(tool.name, args_schema),
|
||||
func=create_tool_func(tool.name, args_schema, self.client.session),
|
||||
coroutine=create_tool_coroutine(tool.name, args_schema, self.client.session),
|
||||
)
|
||||
)
|
||||
|
|
@ -13,7 +13,7 @@ from .google_search_api import GoogleSearchAPIComponent
|
|||
from .google_search_api_core import GoogleSearchAPICore
|
||||
from .google_serper_api import GoogleSerperAPIComponent
|
||||
from .google_serper_api_core import GoogleSerperAPICore
|
||||
from .mcp_stdio import MCPStdio
|
||||
from .mcp_component import MCPToolsComponent
|
||||
from .python_code_structured_tool import PythonCodeStructuredTool
|
||||
from .python_repl import PythonREPLToolComponent
|
||||
from .python_repl_core import PythonREPLComponent
|
||||
|
|
@ -51,7 +51,7 @@ __all__ = [
|
|||
"GoogleSearchAPICore",
|
||||
"GoogleSerperAPIComponent",
|
||||
"GoogleSerperAPICore",
|
||||
"MCPStdio",
|
||||
"MCPToolsComponent",
|
||||
"PythonCodeStructuredTool",
|
||||
"PythonREPLComponent",
|
||||
"PythonREPLToolComponent",
|
||||
|
|
|
|||
340
src/backend/base/langflow/components/tools/mcp_component.py
Normal file
340
src/backend/base/langflow/components/tools/mcp_component.py
Normal file
|
|
@ -0,0 +1,340 @@
|
|||
import asyncio
|
||||
from typing import Any
|
||||
|
||||
from langchain_core.tools import StructuredTool
|
||||
|
||||
from langflow.base.mcp.util import (
|
||||
MCPSseClient,
|
||||
MCPStdioClient,
|
||||
create_input_schema_from_json_schema,
|
||||
create_tool_coroutine,
|
||||
create_tool_func,
|
||||
)
|
||||
from langflow.custom import Component
|
||||
from langflow.inputs import DropdownInput
|
||||
from langflow.inputs.inputs import InputTypes
|
||||
from langflow.io import MessageTextInput, Output, TabInput
|
||||
from langflow.io.schema import schema_to_langflow_inputs
|
||||
from langflow.logging import logger
|
||||
from langflow.schema import Message
|
||||
|
||||
|
||||
class MCPToolsComponent(Component):
|
||||
schema_inputs: list[InputTypes] = []
|
||||
stdio_client = MCPStdioClient()
|
||||
sse_client = MCPSseClient()
|
||||
tools: list = []
|
||||
tool_names: list[str] = []
|
||||
_tool_cache: dict = {} # Cache for tool objects
|
||||
default_keys = ["code", "_type", "mode", "command", "sse_url", "tool_placeholder", "tool_mode", "tool"]
|
||||
|
||||
display_name = "MCP Server"
|
||||
description = "Connect to an MCP server and expose tools."
|
||||
icon = "server"
|
||||
name = "MCPTools"
|
||||
|
||||
inputs = [
|
||||
TabInput(
|
||||
name="mode",
|
||||
display_name="Mode",
|
||||
options=["Stdio", "SSE"],
|
||||
value="Stdio",
|
||||
info="Select the connection mode",
|
||||
real_time_refresh=True,
|
||||
),
|
||||
MessageTextInput(
|
||||
name="command",
|
||||
display_name="MCP Command",
|
||||
info="Command for MCP stdio connection",
|
||||
value="uvx mcp-server-fetch",
|
||||
show=True,
|
||||
refresh_button=True,
|
||||
),
|
||||
MessageTextInput(
|
||||
name="sse_url",
|
||||
display_name="MCP SSE URL",
|
||||
info="URL for MCP SSE connection",
|
||||
value="http://localhost:7860/api/v1/mcp/sse",
|
||||
show=False,
|
||||
refresh_button=True,
|
||||
),
|
||||
DropdownInput(
|
||||
name="tool",
|
||||
display_name="Tool",
|
||||
options=[],
|
||||
value="",
|
||||
info="Select the tool to execute",
|
||||
show=True,
|
||||
required=True,
|
||||
real_time_refresh=True,
|
||||
),
|
||||
MessageTextInput(
|
||||
name="tool_placeholder",
|
||||
display_name="Tool Placeholder",
|
||||
info="Placeholder for the tool",
|
||||
value="",
|
||||
show=False,
|
||||
tool_mode=True,
|
||||
),
|
||||
]
|
||||
|
||||
outputs = [
|
||||
Output(display_name="Response", name="response", method="build_output"),
|
||||
]
|
||||
|
||||
async def _validate_connection_params(self, mode: str, command: str | None = None, url: str | None = None) -> None:
|
||||
"""Validate connection parameters based on mode."""
|
||||
if mode not in ["Stdio", "SSE"]:
|
||||
msg = f"Invalid mode: {mode}. Must be either 'Stdio' or 'SSE'"
|
||||
raise ValueError(msg)
|
||||
|
||||
if mode == "Stdio" and not command:
|
||||
msg = "Command is required for Stdio mode"
|
||||
raise ValueError(msg)
|
||||
if mode == "SSE" and not url:
|
||||
msg = "URL is required for SSE mode"
|
||||
raise ValueError(msg)
|
||||
|
||||
async def _validate_schema_inputs(self, tool_obj) -> list[InputTypes]:
|
||||
"""Validate and process schema inputs for a tool."""
|
||||
try:
|
||||
if not tool_obj or not hasattr(tool_obj, "inputSchema"):
|
||||
msg = "Invalid tool object or missing input schema"
|
||||
raise ValueError(msg)
|
||||
|
||||
input_schema = create_input_schema_from_json_schema(tool_obj.inputSchema)
|
||||
if not input_schema:
|
||||
msg = f"Empty input schema for tool '{tool_obj.name}'"
|
||||
raise ValueError(msg)
|
||||
|
||||
schema_inputs = schema_to_langflow_inputs(input_schema)
|
||||
if not schema_inputs:
|
||||
msg = f"No input parameters defined for tool '{tool_obj.name}'"
|
||||
logger.warning(msg)
|
||||
return []
|
||||
|
||||
except Exception as e:
|
||||
msg = f"Error validating schema inputs: {e!s}"
|
||||
logger.exception(msg)
|
||||
raise ValueError(msg) from e
|
||||
else:
|
||||
return schema_inputs
|
||||
|
||||
async def update_build_config(self, build_config: dict, field_value: str, field_name: str | None = None) -> dict:
|
||||
"""Toggle the visibility of connection-specific fields based on the selected mode."""
|
||||
try:
|
||||
if field_name == "mode":
|
||||
self.remove_non_default_keys(build_config)
|
||||
if field_value == "Stdio":
|
||||
build_config["command"]["show"] = True
|
||||
build_config["sse_url"]["show"] = False
|
||||
elif field_value == "SSE":
|
||||
build_config["command"]["show"] = False
|
||||
build_config["sse_url"]["show"] = True
|
||||
if field_name in ("command", "sse_url", "mode"):
|
||||
try:
|
||||
await self.update_tools()
|
||||
if "tool" in build_config:
|
||||
build_config["tool"]["options"] = self.tool_names
|
||||
except Exception as e:
|
||||
build_config["tool"]["options"] = []
|
||||
msg = f"Failed to update tools: {e!s}"
|
||||
raise ValueError(msg) from e
|
||||
elif field_name == "tool":
|
||||
if len(self.tools) == 0:
|
||||
await self.update_tools()
|
||||
if self.tool is None:
|
||||
return build_config
|
||||
tool_obj = None
|
||||
for tool in self.tools:
|
||||
if tool.name == self.tool:
|
||||
tool_obj = tool
|
||||
break
|
||||
if tool_obj is None:
|
||||
msg = f"Tool {self.tool} not found in available tools: {self.tools}"
|
||||
logger.warning(msg)
|
||||
return build_config
|
||||
self.remove_non_default_keys(build_config)
|
||||
await self._update_tool_config(build_config, field_value)
|
||||
elif field_name == "tool_mode":
|
||||
build_config["tool"]["show"] = not field_value
|
||||
for key, value in list(build_config.items()):
|
||||
if key not in self.default_keys and isinstance(value, dict) and "show" in value:
|
||||
build_config[key]["show"] = not field_value
|
||||
|
||||
except Exception as e:
|
||||
msg = f"Error in update_build_config: {e!s}"
|
||||
logger.exception(msg)
|
||||
raise ValueError(msg) from e
|
||||
else:
|
||||
return build_config
|
||||
|
||||
def get_inputs_for_all_tools(self, tools: list) -> dict:
|
||||
"""Get input schemas for all tools."""
|
||||
inputs = {}
|
||||
for tool in tools:
|
||||
if not tool or not hasattr(tool, "name"):
|
||||
continue
|
||||
try:
|
||||
input_schema = schema_to_langflow_inputs(create_input_schema_from_json_schema(tool.inputSchema))
|
||||
inputs[tool.name] = input_schema
|
||||
except (AttributeError, ValueError, TypeError, KeyError) as e:
|
||||
msg = f"Error getting inputs for tool {getattr(tool, 'name', 'unknown')}: {e!s}"
|
||||
logger.exception(msg)
|
||||
continue
|
||||
return inputs
|
||||
|
||||
def remove_input_schema_from_build_config(
|
||||
self, build_config: dict, tool_name: str, input_schema: dict[list[InputTypes], Any]
|
||||
):
|
||||
"""Remove the input schema for the tool from the build config."""
|
||||
# Keep only schemas that don't belong to the current tool
|
||||
input_schema = {k: v for k, v in input_schema.items() if k != tool_name}
|
||||
# Remove all inputs from other tools
|
||||
for value in input_schema.values():
|
||||
for _input in value:
|
||||
if _input.name in build_config:
|
||||
build_config.pop(_input.name)
|
||||
|
||||
def remove_non_default_keys(self, build_config: dict) -> None:
|
||||
"""Remove non-default keys from the build config."""
|
||||
for key in list(build_config.keys()):
|
||||
if key not in self.default_keys:
|
||||
build_config.pop(key)
|
||||
|
||||
async def _update_tool_config(self, build_config: dict, tool_name: str) -> None:
|
||||
"""Update tool configuration with proper error handling."""
|
||||
if not self.tools:
|
||||
await self.update_tools()
|
||||
|
||||
if not tool_name:
|
||||
return
|
||||
|
||||
tool_obj = next((tool for tool in self.tools if tool.name == tool_name), None)
|
||||
if not tool_obj:
|
||||
msg = f"Tool {tool_name} not found in available tools: {self.tools}"
|
||||
logger.warning(msg)
|
||||
return
|
||||
|
||||
try:
|
||||
# Get all tool inputs and remove old ones
|
||||
input_schema_for_all_tools = self.get_inputs_for_all_tools(self.tools)
|
||||
self.remove_input_schema_from_build_config(build_config, tool_name, input_schema_for_all_tools)
|
||||
|
||||
# Get and validate new inputs
|
||||
self.schema_inputs = await self._validate_schema_inputs(tool_obj)
|
||||
if not self.schema_inputs:
|
||||
msg = f"No input parameters to configure for tool '{tool_name}'"
|
||||
logger.info(msg)
|
||||
return
|
||||
|
||||
# Add new inputs to build config
|
||||
for schema_input in self.schema_inputs:
|
||||
if not schema_input or not hasattr(schema_input, "name"):
|
||||
msg = "Invalid schema input detected, skipping"
|
||||
logger.warning(msg)
|
||||
continue
|
||||
|
||||
try:
|
||||
name = schema_input.name
|
||||
input_dict = schema_input.to_dict()
|
||||
input_dict.setdefault("value", None)
|
||||
input_dict.setdefault("required", True)
|
||||
build_config[name] = input_dict
|
||||
except (AttributeError, KeyError, TypeError) as e:
|
||||
msg = f"Error processing schema input {schema_input}: {e!s}"
|
||||
logger.exception(msg)
|
||||
continue
|
||||
|
||||
except ValueError as e:
|
||||
msg = f"Schema validation error for tool {tool_name}: {e!s}"
|
||||
logger.exception(msg)
|
||||
self.schema_inputs = []
|
||||
return
|
||||
except (AttributeError, KeyError, TypeError) as e:
|
||||
msg = f"Error updating tool config: {e!s}"
|
||||
logger.exception(msg)
|
||||
raise ValueError(msg) from e
|
||||
|
||||
async def build_output(self) -> Message:
|
||||
"""Build output with improved error handling and validation."""
|
||||
try:
|
||||
await self.update_tools()
|
||||
if self.tool != "":
|
||||
exec_tool = self._tool_cache[self.tool]
|
||||
tool_args = self.get_inputs_for_all_tools(self.tools)[self.tool]
|
||||
kwargs = {}
|
||||
for arg in tool_args:
|
||||
value = getattr(self, arg.name, None)
|
||||
if value:
|
||||
kwargs[arg.name] = value
|
||||
output = await exec_tool.coroutine(**kwargs)
|
||||
return Message(text=output.content[len(output.content) - 1].text)
|
||||
return Message(text="You must select a tool", error=True)
|
||||
except Exception as e:
|
||||
msg = f"Error in build_output: {e!s}"
|
||||
logger.exception(msg)
|
||||
raise ValueError(msg) from e
|
||||
|
||||
async def update_tools(self) -> list[StructuredTool]:
|
||||
"""Connect to the MCP server and update available tools with improved error handling."""
|
||||
try:
|
||||
await self._validate_connection_params(self.mode, self.command, self.sse_url)
|
||||
|
||||
if self.mode == "Stdio":
|
||||
if not self.stdio_client.session:
|
||||
self.tools = await self.stdio_client.connect_to_server(self.command)
|
||||
elif self.mode == "SSE" and not self.sse_client.session:
|
||||
self.tools = await self.sse_client.connect_to_server(self.sse_url, {})
|
||||
|
||||
if not self.tools:
|
||||
logger.warning("No tools returned from server")
|
||||
return []
|
||||
|
||||
tool_list = []
|
||||
for tool in self.tools:
|
||||
if not tool or not hasattr(tool, "name"):
|
||||
logger.warning("Invalid tool object detected, skipping")
|
||||
continue
|
||||
|
||||
try:
|
||||
args_schema = create_input_schema_from_json_schema(tool.inputSchema)
|
||||
if not args_schema:
|
||||
msg = f"Empty schema for tool '{tool.name}', skipping"
|
||||
logger.warning(msg)
|
||||
continue
|
||||
|
||||
client = self.stdio_client if self.mode == "Stdio" else self.sse_client
|
||||
if not client or not client.session:
|
||||
msg = f"Invalid client session for tool '{tool.name}'"
|
||||
raise ValueError(msg)
|
||||
|
||||
tool_obj = StructuredTool(
|
||||
name=tool.name,
|
||||
description=tool.description or "",
|
||||
args_schema=args_schema,
|
||||
func=create_tool_func(tool.name, args_schema, client.session),
|
||||
coroutine=create_tool_coroutine(tool.name, args_schema, client.session),
|
||||
tags=[tool.name],
|
||||
)
|
||||
tool_list.append(tool_obj)
|
||||
self._tool_cache[tool.name] = tool_obj
|
||||
except (AttributeError, ValueError, TypeError, KeyError) as e:
|
||||
msg = f"Error creating tool {getattr(tool, 'name', 'unknown')}: {e!s}"
|
||||
logger.exception(msg)
|
||||
continue
|
||||
|
||||
self.tool_names = [tool.name for tool in self.tools if hasattr(tool, "name")]
|
||||
|
||||
except (ValueError, RuntimeError, asyncio.TimeoutError) as e:
|
||||
msg = f"Error updating tools: {e!s}"
|
||||
logger.exception(msg)
|
||||
raise ValueError(msg) from e
|
||||
else:
|
||||
return tool_list
|
||||
|
||||
async def _get_tools(self):
|
||||
"""Get cached tools or update if necessary."""
|
||||
if not self.tools:
|
||||
return await self.update_tools()
|
||||
return self.tools
|
||||
|
|
@ -1,111 +0,0 @@
|
|||
# from langflow.field_typing import Data
|
||||
import asyncio
|
||||
from contextlib import AsyncExitStack
|
||||
|
||||
import httpx
|
||||
from langchain_core.tools import StructuredTool
|
||||
from mcp import ClientSession, types
|
||||
from mcp.client.sse import sse_client
|
||||
|
||||
from langflow.base.mcp.util import create_input_schema_from_json_schema, create_tool_coroutine, create_tool_func
|
||||
from langflow.custom import Component
|
||||
from langflow.field_typing import Tool
|
||||
from langflow.io import MessageTextInput, Output
|
||||
|
||||
# Define constant for status code
|
||||
HTTP_TEMPORARY_REDIRECT = 307
|
||||
|
||||
|
||||
class MCPSseClient:
|
||||
def __init__(self):
|
||||
# Initialize session and client objects
|
||||
self.write = None
|
||||
self.sse = None
|
||||
self.session: ClientSession | None = None
|
||||
self.exit_stack = AsyncExitStack()
|
||||
|
||||
async def pre_check_redirect(self, url: str):
|
||||
"""Check if the URL responds with a 307 Redirect."""
|
||||
async with httpx.AsyncClient(follow_redirects=False) as client:
|
||||
response = await client.request("HEAD", url)
|
||||
if response.status_code == HTTP_TEMPORARY_REDIRECT:
|
||||
return response.headers.get("Location") # Return the redirect URL
|
||||
return url # Return the original URL if no redirect
|
||||
|
||||
async def _connect_with_timeout(
|
||||
self, url: str, headers: dict[str, str] | None, timeout_seconds: int, sse_read_timeout_seconds: int
|
||||
):
|
||||
"""Connect to the SSE server with timeout."""
|
||||
sse_transport = await self.exit_stack.enter_async_context(
|
||||
sse_client(url, headers, timeout_seconds, sse_read_timeout_seconds)
|
||||
)
|
||||
self.sse, self.write = sse_transport
|
||||
self.session = await self.exit_stack.enter_async_context(ClientSession(self.sse, self.write))
|
||||
await self.session.initialize()
|
||||
|
||||
async def connect_to_server(
|
||||
self, url: str, headers: dict[str, str] | None, timeout_seconds: int = 500, sse_read_timeout_seconds: int = 500
|
||||
):
|
||||
if headers is None:
|
||||
headers = {}
|
||||
url = await self.pre_check_redirect(url)
|
||||
try:
|
||||
await asyncio.wait_for(
|
||||
self._connect_with_timeout(url, headers, timeout_seconds, sse_read_timeout_seconds),
|
||||
timeout=timeout_seconds,
|
||||
)
|
||||
# List available tools
|
||||
if self.session is None:
|
||||
msg = "Session not initialized"
|
||||
raise ValueError(msg)
|
||||
response = await self.session.list_tools()
|
||||
except asyncio.TimeoutError as err:
|
||||
error_message = f"Connection to {url} timed out after {timeout_seconds} seconds"
|
||||
raise TimeoutError(error_message) from err
|
||||
return response.tools
|
||||
|
||||
|
||||
class MCPSse(Component):
|
||||
client = MCPSseClient()
|
||||
tools = types.ListToolsResult
|
||||
tool_names = [str]
|
||||
display_name = "MCP Tools (SSE)"
|
||||
description = "Connects to an MCP server over SSE and exposes it's tools as langflow tools to be used by an Agent."
|
||||
documentation: str = "https://docs.langflow.org/components-custom-components"
|
||||
icon = "code"
|
||||
name = "MCPSse"
|
||||
|
||||
inputs = [
|
||||
MessageTextInput(
|
||||
name="url",
|
||||
display_name="mcp sse url",
|
||||
info="sse url",
|
||||
value="http://localhost:7860/api/v1/mcp/sse",
|
||||
tool_mode=True,
|
||||
),
|
||||
]
|
||||
|
||||
outputs = [
|
||||
Output(display_name="Tools", name="tools", method="build_output"),
|
||||
]
|
||||
|
||||
async def build_output(self) -> list[Tool]:
|
||||
if self.client.session is None:
|
||||
self.tools = await self.client.connect_to_server(self.url, {})
|
||||
|
||||
tool_list = []
|
||||
|
||||
for tool in self.tools:
|
||||
args_schema = create_input_schema_from_json_schema(tool.inputSchema)
|
||||
tool_list.append(
|
||||
StructuredTool(
|
||||
name=tool.name, # maybe format this
|
||||
description=tool.description,
|
||||
args_schema=args_schema,
|
||||
func=create_tool_func(tool.name, self.client.session),
|
||||
coroutine=create_tool_coroutine(tool.name, args_schema, self.client.session),
|
||||
)
|
||||
)
|
||||
|
||||
self.tool_names = [tool.name for tool in self.tools]
|
||||
return tool_list
|
||||
|
|
@ -1,8 +1,8 @@
|
|||
from typing import TYPE_CHECKING, Literal
|
||||
from typing import Literal, Union, get_args, get_origin
|
||||
|
||||
from pydantic import BaseModel, Field, create_model
|
||||
|
||||
from langflow.inputs.inputs import FieldTypes
|
||||
from langflow.inputs.inputs import BoolInput, DictInput, FieldTypes, FloatInput, InputTypes, IntInput, MessageTextInput
|
||||
from langflow.schema.dotdict import dotdict
|
||||
|
||||
_convert_field_type_to_type: dict[FieldTypes, type] = {
|
||||
|
|
@ -20,8 +20,65 @@ _convert_field_type_to_type: dict[FieldTypes, type] = {
|
|||
FieldTypes.TAB: str,
|
||||
}
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from langflow.inputs.inputs import InputTypes
|
||||
|
||||
_convert_type_to_field_type = {
|
||||
str: MessageTextInput,
|
||||
int: IntInput,
|
||||
float: FloatInput,
|
||||
bool: BoolInput,
|
||||
dict: DictInput,
|
||||
list: MessageTextInput,
|
||||
}
|
||||
|
||||
|
||||
def schema_to_langflow_inputs(schema: type[BaseModel]) -> list["InputTypes"]:
|
||||
"""Given a Pydantic schema, convert its fields to Langflow input definitions."""
|
||||
inputs = []
|
||||
for field_name, model_field in schema.model_fields.items():
|
||||
# Start with the field's annotation type
|
||||
field_type = model_field.annotation
|
||||
is_list = False
|
||||
options = None
|
||||
|
||||
# If the field is a list, record that and extract its inner type.
|
||||
if get_origin(field_type) is list:
|
||||
is_list = True
|
||||
field_type = get_args(field_type)[0]
|
||||
|
||||
# If the field type is a Literal, extract its allowed values.
|
||||
if get_origin(field_type) is Literal:
|
||||
options = list(get_args(field_type))
|
||||
# Optionally, set field_type to the type of the literal values.
|
||||
if options:
|
||||
field_type = type(options[0])
|
||||
|
||||
# Handle Union types (e.g., Optional fields)
|
||||
if get_origin(field_type) is Union:
|
||||
# Get the first non-None type from the Union
|
||||
field_type = next(t for t in get_args(field_type) if t is not type(None))
|
||||
|
||||
# Convert the Python type to the Langflow field type using our reverse mapping.
|
||||
try:
|
||||
langflow_field_type = _convert_type_to_field_type[field_type]
|
||||
except KeyError as e:
|
||||
msg = f"Unsupported field type: {field_type}"
|
||||
raise TypeError(msg) from e
|
||||
|
||||
# Get metadata from the Pydantic Field.
|
||||
title = model_field.title or field_name.replace("_", " ").title()
|
||||
description = model_field.description or ""
|
||||
required = model_field.is_required()
|
||||
|
||||
# Construct the Langflow input.
|
||||
input_obj = langflow_field_type(
|
||||
display_name=title,
|
||||
name=field_name,
|
||||
info=description,
|
||||
required=required,
|
||||
is_list=is_list,
|
||||
)
|
||||
inputs.append(input_obj)
|
||||
return inputs
|
||||
|
||||
|
||||
def create_input_schema(inputs: list["InputTypes"]) -> type[BaseModel]:
|
||||
|
|
|
|||
0
src/backend/tests/integration/components/mcp/__init__.py
Normal file
0
src/backend/tests/integration/components/mcp/__init__.py
Normal file
|
|
@ -0,0 +1,11 @@
|
|||
from tests.integration.utils import run_single_component
|
||||
|
||||
|
||||
async def test_mcp_component():
|
||||
from langflow.components.tools.mcp_component import MCPToolsComponent
|
||||
|
||||
inputs = {}
|
||||
await run_single_component(
|
||||
MCPToolsComponent,
|
||||
inputs=inputs, # test default inputs
|
||||
)
|
||||
255
src/backend/tests/unit/components/tools/test_mcp_component.py
Normal file
255
src/backend/tests/unit/components/tools/test_mcp_component.py
Normal file
|
|
@ -0,0 +1,255 @@
|
|||
import asyncio
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
|
||||
import pytest
|
||||
from langflow.components.tools.mcp_component import MCPSseClient, MCPStdioClient, MCPToolsComponent
|
||||
|
||||
from tests.base import ComponentTestBaseWithoutClient, VersionComponentMapping
|
||||
|
||||
|
||||
class TestMCPToolsComponent(ComponentTestBaseWithoutClient):
|
||||
@pytest.fixture
|
||||
def component_class(self):
|
||||
"""Return the component class to test."""
|
||||
return MCPToolsComponent
|
||||
|
||||
@pytest.fixture
|
||||
def default_kwargs(self):
|
||||
"""Return the default kwargs for the component."""
|
||||
return {
|
||||
"mode": "Stdio",
|
||||
"command": "uvx mcp-server-fetch",
|
||||
"sse_url": "http://localhost:7860/api/v1/mcp/sse",
|
||||
"tool": "",
|
||||
}
|
||||
|
||||
@pytest.fixture
|
||||
def file_names_mapping(self) -> list[VersionComponentMapping]:
|
||||
"""Return the file names mapping for different versions."""
|
||||
return []
|
||||
|
||||
@pytest.fixture
|
||||
def mock_tool(self):
|
||||
"""Create a mock MCP tool."""
|
||||
tool = MagicMock()
|
||||
tool.name = "test_tool"
|
||||
tool.description = "Test tool description"
|
||||
tool.inputSchema = {
|
||||
"type": "object",
|
||||
"properties": {"test_param": {"type": "string", "description": "Test parameter"}},
|
||||
}
|
||||
return tool
|
||||
|
||||
@pytest.fixture
|
||||
def mock_stdio_client(self, mock_tool):
|
||||
"""Create a mock stdio client."""
|
||||
stdio_client = AsyncMock()
|
||||
stdio_client.connect_to_server = AsyncMock(return_value=[mock_tool])
|
||||
stdio_client.session = AsyncMock()
|
||||
return stdio_client
|
||||
|
||||
@pytest.fixture
|
||||
def mock_sse_client(self, mock_tool):
|
||||
"""Create a mock SSE client."""
|
||||
sse_client = AsyncMock()
|
||||
sse_client.connect_to_server = AsyncMock(return_value=[mock_tool])
|
||||
sse_client.session = AsyncMock()
|
||||
return sse_client
|
||||
|
||||
async def test_validate_connection_params_invalid_mode(self, component_class, default_kwargs):
|
||||
"""Test validation with invalid mode."""
|
||||
component = component_class(**default_kwargs)
|
||||
with pytest.raises(ValueError, match="Invalid mode: invalid. Must be either 'Stdio' or 'SSE'"):
|
||||
await component._validate_connection_params("invalid")
|
||||
|
||||
async def test_validate_connection_params_missing_command(self, component_class, default_kwargs):
|
||||
"""Test validation with missing command in Stdio mode."""
|
||||
component = component_class(**default_kwargs)
|
||||
with pytest.raises(ValueError, match="Command is required for Stdio mode"):
|
||||
await component._validate_connection_params("Stdio", command=None)
|
||||
|
||||
async def test_validate_connection_params_missing_url(self, component_class, default_kwargs):
|
||||
"""Test validation with missing URL in SSE mode."""
|
||||
component = component_class(**default_kwargs)
|
||||
with pytest.raises(ValueError, match="URL is required for SSE mode"):
|
||||
await component._validate_connection_params("SSE", url=None)
|
||||
|
||||
async def test_update_build_config_mode_change(self, component_class, default_kwargs):
|
||||
"""Test build config updates when mode changes."""
|
||||
component = component_class(**default_kwargs)
|
||||
build_config = {
|
||||
"command": {"show": False},
|
||||
"sse_url": {"show": True},
|
||||
"tool": {"options": [], "show": True}, # Add tool field since component uses it
|
||||
}
|
||||
|
||||
# Test switching to Stdio mode
|
||||
updated_config = await component.update_build_config(build_config, "Stdio", "mode")
|
||||
assert updated_config["command"]["show"] is True
|
||||
assert updated_config["sse_url"]["show"] is False
|
||||
|
||||
# Test switching to SSE mode
|
||||
updated_config = await component.update_build_config(build_config, "SSE", "mode")
|
||||
assert updated_config["command"]["show"] is False
|
||||
assert updated_config["sse_url"]["show"] is True
|
||||
|
||||
# Test tool options are updated
|
||||
assert "options" in updated_config["tool"]
|
||||
|
||||
@patch("langflow.components.tools.mcp_component.create_tool_coroutine")
|
||||
async def test_build_output(self, mock_create_coroutine, component_class, default_kwargs, mock_tool):
|
||||
"""Test building output with a tool."""
|
||||
component = component_class(**default_kwargs)
|
||||
component.tool = "test_tool"
|
||||
component.tools = [mock_tool]
|
||||
|
||||
# Mock the coroutine response
|
||||
mock_response = AsyncMock()
|
||||
mock_response.content = [MagicMock(text="Test response")]
|
||||
mock_create_coroutine.return_value = AsyncMock(return_value=mock_response)
|
||||
|
||||
# Create a mock tool and add it to the cache
|
||||
mock_structured_tool = MagicMock()
|
||||
mock_structured_tool.coroutine = mock_create_coroutine.return_value
|
||||
component._tool_cache = {"test_tool": mock_structured_tool}
|
||||
|
||||
# Set the test parameter value
|
||||
component.test_param = "test value"
|
||||
|
||||
# Mock get_inputs_for_all_tools to return our mock input
|
||||
mock_input = MagicMock()
|
||||
mock_input.name = "test_param"
|
||||
with patch.object(component, "get_inputs_for_all_tools") as mock_get_inputs:
|
||||
mock_get_inputs.return_value = {"test_tool": [mock_input]}
|
||||
output = await component.build_output()
|
||||
|
||||
assert output.text == "Test response"
|
||||
# Verify the mocks were called correctly
|
||||
mock_get_inputs.assert_called_once_with(component.tools)
|
||||
mock_structured_tool.coroutine.assert_called_once_with(test_param="test value")
|
||||
|
||||
async def test_get_inputs_for_all_tools(self, component_class, default_kwargs, mock_tool):
|
||||
"""Test getting input schemas for all tools."""
|
||||
component = component_class(**default_kwargs)
|
||||
inputs = component.get_inputs_for_all_tools([mock_tool])
|
||||
|
||||
assert "test_tool" in inputs
|
||||
assert len(inputs["test_tool"]) > 0 # Should have at least one input parameter
|
||||
|
||||
async def test_remove_non_default_keys(self, component_class, default_kwargs):
|
||||
"""Test removing non-default keys from build config."""
|
||||
component = component_class(**default_kwargs)
|
||||
build_config = {"code": {}, "mode": {}, "command": {}, "custom_key": {}}
|
||||
|
||||
component.remove_non_default_keys(build_config)
|
||||
assert "custom_key" not in build_config
|
||||
assert all(key in build_config for key in ["code", "mode", "command"])
|
||||
|
||||
|
||||
class TestMCPStdioClient:
|
||||
@pytest.fixture
|
||||
def stdio_client(self):
|
||||
return MCPStdioClient()
|
||||
|
||||
async def test_connect_to_server(self, stdio_client):
|
||||
"""Test connecting to server via Stdio."""
|
||||
# Create mock for stdio transport
|
||||
mock_stdio = AsyncMock()
|
||||
mock_write = AsyncMock()
|
||||
mock_stdio_transport = (mock_stdio, mock_write)
|
||||
mock_stdio_cm = AsyncMock()
|
||||
mock_stdio_cm.__aenter__.return_value = mock_stdio_transport
|
||||
|
||||
# Mock the stdio_client function to return our mock context manager
|
||||
with patch("mcp.client.stdio.stdio_client", return_value=mock_stdio_cm):
|
||||
# Mock ClientSession
|
||||
mock_session = AsyncMock()
|
||||
mock_session.initialize = AsyncMock()
|
||||
mock_session.list_tools.return_value.tools = [MagicMock()]
|
||||
|
||||
# Mock the AsyncExitStack
|
||||
mock_exit_stack = AsyncMock()
|
||||
mock_exit_stack.enter_async_context = AsyncMock()
|
||||
mock_exit_stack.enter_async_context.side_effect = [
|
||||
mock_stdio_transport, # For stdio_client
|
||||
mock_session, # For ClientSession
|
||||
]
|
||||
stdio_client.exit_stack = mock_exit_stack
|
||||
|
||||
tools = await stdio_client.connect_to_server("test_command")
|
||||
|
||||
assert len(tools) == 1
|
||||
assert stdio_client.session is not None
|
||||
# Verify the exit stack was used correctly
|
||||
assert mock_exit_stack.enter_async_context.call_count == 2
|
||||
# Verify the stdio transport was properly set
|
||||
assert stdio_client.stdio == mock_stdio
|
||||
assert stdio_client.write == mock_write
|
||||
|
||||
|
||||
class TestMCPSseClient:
|
||||
@pytest.fixture
|
||||
def sse_client(self):
|
||||
return MCPSseClient()
|
||||
|
||||
async def test_pre_check_redirect(self, sse_client):
|
||||
"""Test pre-checking URL for redirects."""
|
||||
test_url = "http://test.url"
|
||||
redirect_url = "http://redirect.url"
|
||||
|
||||
with patch("httpx.AsyncClient") as mock_client:
|
||||
mock_response = MagicMock()
|
||||
mock_response.status_code = 307
|
||||
mock_response.headers.get.return_value = redirect_url
|
||||
mock_client.return_value.__aenter__.return_value.request.return_value = mock_response
|
||||
|
||||
result = await sse_client.pre_check_redirect(test_url)
|
||||
assert result == redirect_url
|
||||
|
||||
async def test_connect_to_server(self, sse_client):
|
||||
"""Test connecting to server via SSE."""
|
||||
# Mock the pre_check_redirect first
|
||||
with patch.object(sse_client, "pre_check_redirect", return_value="http://test.url"):
|
||||
# Create mock for sse_client context manager
|
||||
mock_sse = AsyncMock()
|
||||
mock_write = AsyncMock()
|
||||
mock_sse_transport = (mock_sse, mock_write)
|
||||
mock_sse_cm = AsyncMock()
|
||||
mock_sse_cm.__aenter__.return_value = mock_sse_transport
|
||||
|
||||
# Mock the sse_client function to return our mock context manager
|
||||
with patch("mcp.client.sse.sse_client", return_value=mock_sse_cm):
|
||||
# Mock ClientSession
|
||||
mock_session = AsyncMock()
|
||||
mock_session.initialize = AsyncMock()
|
||||
mock_session.list_tools.return_value.tools = [MagicMock()]
|
||||
|
||||
# Mock the AsyncExitStack
|
||||
mock_exit_stack = AsyncMock()
|
||||
mock_exit_stack.enter_async_context = AsyncMock()
|
||||
mock_exit_stack.enter_async_context.side_effect = [
|
||||
mock_sse_transport, # For sse_client
|
||||
mock_session, # For ClientSession
|
||||
]
|
||||
sse_client.exit_stack = mock_exit_stack
|
||||
|
||||
tools = await sse_client.connect_to_server("http://test.url", {})
|
||||
|
||||
assert len(tools) == 1
|
||||
assert sse_client.session is not None
|
||||
# Verify the exit stack was used correctly
|
||||
assert mock_exit_stack.enter_async_context.call_count == 2
|
||||
# Verify the SSE transport was properly set
|
||||
assert sse_client.sse == mock_sse
|
||||
assert sse_client.write == mock_write
|
||||
|
||||
async def test_connect_timeout(self, sse_client):
|
||||
"""Test connection timeout handling."""
|
||||
with (
|
||||
patch.object(sse_client, "pre_check_redirect", return_value="http://test.url"),
|
||||
patch.object(sse_client, "_connect_with_timeout") as mock_connect,
|
||||
):
|
||||
mock_connect.side_effect = asyncio.TimeoutError()
|
||||
|
||||
with pytest.raises(TimeoutError, match="Connection to http://test.url timed out after 1 seconds"):
|
||||
await sse_client.connect_to_server("http://test.url", {}, timeout_seconds=1)
|
||||
|
|
@ -3,11 +3,13 @@ from types import NoneType
|
|||
from typing import Union
|
||||
|
||||
import pytest
|
||||
from langflow.inputs.inputs import BoolInput, DictInput, FloatInput, InputTypes, IntInput, MessageTextInput
|
||||
from langflow.io.schema import schema_to_langflow_inputs
|
||||
from langflow.schema.data import Data
|
||||
from langflow.template import Input, Output
|
||||
from langflow.template.field.base import UNDEFINED
|
||||
from langflow.type_extraction.type_extraction import post_process_type
|
||||
from pydantic import ValidationError
|
||||
from pydantic import BaseModel, Field, ValidationError
|
||||
|
||||
|
||||
class TestInput:
|
||||
|
|
@ -178,3 +180,65 @@ class TestPostProcessType:
|
|||
pass
|
||||
|
||||
assert set(post_process_type(Union[CustomType, int])) == {CustomType, int} # noqa: UP007
|
||||
|
||||
|
||||
def test_schema_to_langflow_inputs():
|
||||
# Define a test Pydantic model with various field types
|
||||
class TestSchema(BaseModel):
|
||||
text_field: str = Field(title="Custom Text Title", description="A text field")
|
||||
number_field: int = Field(description="A number field")
|
||||
bool_field: bool = Field(description="A boolean field")
|
||||
dict_field: dict = Field(description="A dictionary field")
|
||||
list_field: list[str] = Field(description="A list of strings")
|
||||
|
||||
# Convert schema to Langflow inputs
|
||||
inputs = schema_to_langflow_inputs(TestSchema)
|
||||
|
||||
# Verify the number of inputs matches the schema fields
|
||||
assert len(inputs) == 5
|
||||
|
||||
# Helper function to find input by name
|
||||
def find_input(name: str) -> InputTypes | None:
|
||||
for _input in inputs:
|
||||
if _input.name == name:
|
||||
return _input
|
||||
return None
|
||||
|
||||
# Test text field
|
||||
text_input = find_input("text_field")
|
||||
assert text_input.display_name == "Custom Text Title"
|
||||
assert text_input.info == "A text field"
|
||||
assert isinstance(text_input, MessageTextInput) # Check the instance type instead of field_type
|
||||
|
||||
# Test number field
|
||||
number_input = find_input("number_field")
|
||||
assert number_input.display_name == "Number Field"
|
||||
assert number_input.info == "A number field"
|
||||
assert isinstance(number_input, IntInput | FloatInput)
|
||||
|
||||
# Test boolean field
|
||||
bool_input = find_input("bool_field")
|
||||
assert isinstance(bool_input, BoolInput)
|
||||
|
||||
# Test dictionary field
|
||||
dict_input = find_input("dict_field")
|
||||
assert isinstance(dict_input, DictInput)
|
||||
|
||||
# Test list field
|
||||
list_input = find_input("list_field")
|
||||
assert list_input.is_list is True
|
||||
assert isinstance(list_input, MessageTextInput)
|
||||
|
||||
|
||||
def test_schema_to_langflow_inputs_invalid_type():
|
||||
# Define a schema with an unsupported type
|
||||
class CustomType:
|
||||
pass
|
||||
|
||||
class InvalidSchema(BaseModel):
|
||||
model_config = {"arbitrary_types_allowed": True} # Add this line
|
||||
invalid_field: CustomType
|
||||
|
||||
# Test that attempting to convert an unsupported type raises TypeError
|
||||
with pytest.raises(TypeError, match="Unsupported field type:"):
|
||||
schema_to_langflow_inputs(InvalidSchema)
|
||||
|
|
|
|||
1
src/frontend/package-lock.json
generated
1
src/frontend/package-lock.json
generated
|
|
@ -706,6 +706,7 @@
|
|||
},
|
||||
"node_modules/@clack/prompts/node_modules/is-unicode-supported": {
|
||||
"version": "1.3.0",
|
||||
"extraneous": true,
|
||||
"inBundle": true,
|
||||
"license": "MIT",
|
||||
"engines": {
|
||||
|
|
|
|||
|
|
@ -37,7 +37,7 @@ export const ForwardedIconComponent = memo(
|
|||
nodeIconsLucide[
|
||||
name
|
||||
?.split("-")
|
||||
?.map((x) => String(x[0]).toUpperCase() + String(x).slice(1))
|
||||
?.map((x) => String(x[0])?.toUpperCase() + String(x).slice(1))
|
||||
?.join("")
|
||||
];
|
||||
if (!TargetIcon) {
|
||||
|
|
|
|||
|
|
@ -32,7 +32,7 @@ export default function RenderKey({
|
|||
className={cn(tableRender ? "h-4 w-4" : "h-3 w-3")}
|
||||
/>
|
||||
) : (
|
||||
<span>{value.toUpperCase()}</span>
|
||||
<span>{value?.toUpperCase()}</span>
|
||||
)}
|
||||
</div>
|
||||
);
|
||||
|
|
|
|||
|
|
@ -61,9 +61,11 @@ export interface ButtonProps
|
|||
|
||||
function toTitleCase(text: string) {
|
||||
return text
|
||||
.split(" ")
|
||||
.map((word) => word.charAt(0).toUpperCase() + word.slice(1).toLowerCase())
|
||||
.join(" ");
|
||||
?.split(" ")
|
||||
?.map(
|
||||
(word) => word?.charAt(0)?.toUpperCase() + word?.slice(1)?.toLowerCase(),
|
||||
)
|
||||
?.join(" ");
|
||||
}
|
||||
|
||||
const Button = React.forwardRef<HTMLButtonElement, ButtonProps>(
|
||||
|
|
|
|||
|
|
@ -1022,3 +1022,18 @@ export const IS_AUTO_LOGIN =
|
|||
|
||||
export const AUTO_LOGIN_RETRY_DELAY = 2000;
|
||||
export const AUTO_LOGIN_MAX_RETRY_DELAY = 60000;
|
||||
|
||||
export const ALL_LANGUAGES = [
|
||||
{ value: "en-US", name: "English (US)" },
|
||||
{ value: "en-GB", name: "English (UK)" },
|
||||
{ value: "it-IT", name: "Italian" },
|
||||
{ value: "fr-FR", name: "French" },
|
||||
{ value: "es-ES", name: "Spanish" },
|
||||
{ value: "de-DE", name: "German" },
|
||||
{ value: "ja-JP", name: "Japanese" },
|
||||
{ value: "pt-BR", name: "Portuguese (Brazil)" },
|
||||
{ value: "zh-CN", name: "Chinese (Simplified)" },
|
||||
{ value: "ru-RU", name: "Russian" },
|
||||
{ value: "ar-SA", name: "Arabic" },
|
||||
{ value: "hi-IN", name: "Hindi" },
|
||||
];
|
||||
|
|
|
|||
|
|
@ -21,21 +21,6 @@ import LanguageSelect from "./components/language-select";
|
|||
import MicrophoneSelect from "./components/microphone-select";
|
||||
import VoiceSelect from "./components/voice-select";
|
||||
|
||||
const ALL_LANGUAGES = [
|
||||
{ value: "en-US", name: "English (US)" },
|
||||
{ value: "en-GB", name: "English (UK)" },
|
||||
{ value: "it-IT", name: "Italian" },
|
||||
{ value: "fr-FR", name: "French" },
|
||||
{ value: "es-ES", name: "Spanish" },
|
||||
{ value: "de-DE", name: "German" },
|
||||
{ value: "ja-JP", name: "Japanese" },
|
||||
{ value: "pt-BR", name: "Portuguese (Brazil)" },
|
||||
{ value: "zh-CN", name: "Chinese (Simplified)" },
|
||||
{ value: "ru-RU", name: "Russian" },
|
||||
{ value: "ar-SA", name: "Arabic" },
|
||||
{ value: "hi-IN", name: "Hindi" },
|
||||
];
|
||||
|
||||
interface SettingsVoiceModalProps {
|
||||
children?: React.ReactNode;
|
||||
userOpenaiApiKey?: string;
|
||||
|
|
@ -415,7 +400,6 @@ const SettingsVoiceModal = ({
|
|||
<LanguageSelect
|
||||
language={currentLanguage}
|
||||
handleSetLanguage={handleSetLanguage}
|
||||
allLanguages={ALL_LANGUAGES}
|
||||
/>
|
||||
</>
|
||||
)}
|
||||
|
|
|
|||
|
|
@ -1,3 +1,4 @@
|
|||
import { ALL_LANGUAGES } from "@/constants/constants";
|
||||
import IconComponent from "../../../../../../../../../../components/common/genericIconComponent";
|
||||
import ShadTooltip from "../../../../../../../../../../components/common/shadTooltipComponent";
|
||||
import {
|
||||
|
|
@ -12,13 +13,11 @@ import {
|
|||
interface LanguageSelectProps {
|
||||
language: string;
|
||||
handleSetLanguage: (value: string) => void;
|
||||
allLanguages: { value: string; name: string }[];
|
||||
}
|
||||
|
||||
const LanguageSelect = ({
|
||||
language,
|
||||
handleSetLanguage,
|
||||
allLanguages,
|
||||
}: LanguageSelectProps) => {
|
||||
return (
|
||||
<div className="grid w-full items-center gap-2">
|
||||
|
|
@ -41,7 +40,7 @@ const LanguageSelect = ({
|
|||
</SelectTrigger>
|
||||
<SelectContent className="max-h-[200px]">
|
||||
<SelectGroup>
|
||||
{allLanguages.map((lang) => (
|
||||
{ALL_LANGUAGES?.map((lang) => (
|
||||
<SelectItem key={lang?.value} value={lang?.value}>
|
||||
<div className="max-w-[220px] truncate text-left">
|
||||
{lang?.name}
|
||||
|
|
|
|||
|
|
@ -47,9 +47,9 @@ export function toNormalCase(str: string): string {
|
|||
.split("_")
|
||||
.map((word, index) => {
|
||||
if (index === 0) {
|
||||
return word[0].toUpperCase() + word.slice(1).toLowerCase();
|
||||
return word[0]?.toUpperCase() + word.slice(1)?.toLowerCase();
|
||||
}
|
||||
return word.toLowerCase();
|
||||
return word?.toLowerCase();
|
||||
})
|
||||
.join(" ");
|
||||
|
||||
|
|
@ -57,9 +57,9 @@ export function toNormalCase(str: string): string {
|
|||
.split("-")
|
||||
.map((word, index) => {
|
||||
if (index === 0) {
|
||||
return word[0].toUpperCase() + word.slice(1).toLowerCase();
|
||||
return word[0]?.toUpperCase() + word.slice(1)?.toLowerCase();
|
||||
}
|
||||
return word.toLowerCase();
|
||||
return word?.toLowerCase();
|
||||
})
|
||||
.join(" ");
|
||||
}
|
||||
|
|
@ -69,11 +69,11 @@ export function normalCaseToSnakeCase(str: string): string {
|
|||
.split(" ")
|
||||
.map((word, index) => {
|
||||
if (index === 0) {
|
||||
return word[0].toUpperCase() + word.slice(1).toLowerCase();
|
||||
return word[0]?.toUpperCase() + word.slice(1)?.toLowerCase();
|
||||
}
|
||||
return word.toLowerCase();
|
||||
return word?.toLowerCase();
|
||||
})
|
||||
.join("_");
|
||||
?.join("_");
|
||||
}
|
||||
|
||||
export function toTitleCase(
|
||||
|
|
@ -82,41 +82,41 @@ export function toTitleCase(
|
|||
): string {
|
||||
if (!str) return "";
|
||||
let result = str
|
||||
.split("_")
|
||||
.map((word, index) => {
|
||||
?.split("_")
|
||||
?.map((word, index) => {
|
||||
if (isNodeField) return word;
|
||||
if (index === 0) {
|
||||
return checkUpperWords(
|
||||
word[0].toUpperCase() + word.slice(1).toLowerCase(),
|
||||
word[0]?.toUpperCase() + word.slice(1)?.toLowerCase(),
|
||||
);
|
||||
}
|
||||
return checkUpperWords(word.toLowerCase());
|
||||
return checkUpperWords(word?.toLowerCase());
|
||||
})
|
||||
.join(" ");
|
||||
|
||||
return result
|
||||
.split("-")
|
||||
.map((word, index) => {
|
||||
?.split("-")
|
||||
?.map((word, index) => {
|
||||
if (isNodeField) return word;
|
||||
if (index === 0) {
|
||||
return checkUpperWords(
|
||||
word[0].toUpperCase() + word.slice(1).toLowerCase(),
|
||||
word[0]?.toUpperCase() + word.slice(1)?.toLowerCase(),
|
||||
);
|
||||
}
|
||||
return checkUpperWords(word.toLowerCase());
|
||||
return checkUpperWords(word?.toLowerCase());
|
||||
})
|
||||
.join(" ");
|
||||
?.join(" ");
|
||||
}
|
||||
|
||||
export const upperCaseWords: string[] = ["llm", "uri"];
|
||||
export function checkUpperWords(str: string): string {
|
||||
const words = str.split(" ").map((word) => {
|
||||
return upperCaseWords.includes(word.toLowerCase())
|
||||
? word.toUpperCase()
|
||||
: word[0].toUpperCase() + word.slice(1).toLowerCase();
|
||||
const words = str?.split(" ")?.map((word) => {
|
||||
return upperCaseWords.includes(word?.toLowerCase())
|
||||
? word?.toUpperCase()
|
||||
: word[0]?.toUpperCase() + word.slice(1)?.toLowerCase();
|
||||
});
|
||||
|
||||
return words.join(" ");
|
||||
return words?.join(" ");
|
||||
}
|
||||
|
||||
export function buildInputs(): string {
|
||||
|
|
|
|||
|
|
@ -215,6 +215,8 @@ test(
|
|||
).isVisible(),
|
||||
);
|
||||
|
||||
await page.waitForTimeout(2000);
|
||||
|
||||
await awaitBootstrapTest(page, { skipGoto: true });
|
||||
|
||||
await page.getByTestId("side_nav_options_all-templates").click();
|
||||
|
|
|
|||
|
|
@ -73,9 +73,11 @@ test(
|
|||
}
|
||||
}
|
||||
|
||||
await page.waitForTimeout(1000);
|
||||
await page.waitForTimeout(500);
|
||||
|
||||
await visibleElementHandle.hover().then(async () => {
|
||||
await page.waitForTimeout(1000);
|
||||
|
||||
await expect(
|
||||
page.getByText("Drag to connect compatible outputs").first(),
|
||||
).toBeVisible();
|
||||
|
|
@ -105,7 +107,11 @@ test(
|
|||
}
|
||||
}
|
||||
|
||||
await page.waitForTimeout(500);
|
||||
|
||||
await visibleElementHandle.hover().then(async () => {
|
||||
await page.waitForTimeout(1000);
|
||||
|
||||
await expect(
|
||||
page.getByText("Drag to connect compatible outputs").first(),
|
||||
).toBeVisible();
|
||||
|
|
|
|||
|
|
@ -164,13 +164,21 @@ test(
|
|||
await page.getByTestId("chat-message-User-session_after_delete").click();
|
||||
await expect(page.getByTestId("session-selector")).toBeVisible();
|
||||
|
||||
await page.waitForTimeout(500);
|
||||
|
||||
// check helpful button
|
||||
await page.getByTestId("chat-message-AI-session_after_delete").hover();
|
||||
await page.getByTestId("helpful-button").click();
|
||||
|
||||
await page.waitForTimeout(500);
|
||||
|
||||
await page.getByTestId("chat-message-AI-session_after_delete").hover();
|
||||
await expect(page.getByTestId("icon-ThumbUpIconCustom")).toBeVisible({
|
||||
timeout: 10000,
|
||||
});
|
||||
|
||||
await page.waitForTimeout(500);
|
||||
|
||||
await page.getByTestId("helpful-button").click();
|
||||
await page.getByTestId("chat-message-AI-session_after_delete").hover();
|
||||
await expect(page.getByTestId("icon-ThumbUpIconCustom")).toBeVisible({
|
||||
|
|
@ -178,26 +186,38 @@ test(
|
|||
visible: false,
|
||||
});
|
||||
// check not helpful button
|
||||
await page.waitForTimeout(500);
|
||||
|
||||
await page.getByTestId("chat-message-AI-session_after_delete").hover();
|
||||
await page.getByTestId("not-helpful-button").click();
|
||||
await page.waitForTimeout(500);
|
||||
|
||||
await page.getByTestId("chat-message-AI-session_after_delete").hover();
|
||||
await expect(page.getByTestId("icon-ThumbDownIconCustom")).toBeVisible({
|
||||
timeout: 10000,
|
||||
});
|
||||
await page.getByTestId("not-helpful-button").click();
|
||||
await page.waitForTimeout(500);
|
||||
|
||||
await page.getByTestId("chat-message-AI-session_after_delete").hover();
|
||||
await expect(page.getByTestId("icon-ThumbDownIconCustom")).toBeVisible({
|
||||
timeout: 10000,
|
||||
visible: false,
|
||||
});
|
||||
// check switch feedback
|
||||
await page.waitForTimeout(500);
|
||||
|
||||
await page.getByTestId("chat-message-AI-session_after_delete").hover();
|
||||
await page.getByTestId("helpful-button").click();
|
||||
await page.waitForTimeout(500);
|
||||
|
||||
await page.getByTestId("chat-message-AI-session_after_delete").hover();
|
||||
await expect(page.getByTestId("icon-ThumbUpIconCustom")).toBeVisible({
|
||||
timeout: 10000,
|
||||
});
|
||||
await page.getByTestId("not-helpful-button").click();
|
||||
await page.waitForTimeout(500);
|
||||
|
||||
await page.getByTestId("chat-message-AI-session_after_delete").hover();
|
||||
await expect(page.getByTestId("icon-ThumbDownIconCustom")).toBeVisible({
|
||||
timeout: 10000,
|
||||
|
|
|
|||
76
src/frontend/tests/extended/features/edit-tools.spec.ts
Normal file
76
src/frontend/tests/extended/features/edit-tools.spec.ts
Normal file
|
|
@ -0,0 +1,76 @@
|
|||
import { expect, test } from "@playwright/test";
|
||||
import { awaitBootstrapTest } from "../../utils/await-bootstrap-test";
|
||||
test(
|
||||
"user should be able to edit tools",
|
||||
{ tag: ["@release"] },
|
||||
async ({ page }) => {
|
||||
await awaitBootstrapTest(page);
|
||||
|
||||
await page.getByTestId("blank-flow").click();
|
||||
|
||||
await page.getByTestId("sidebar-search-input").click();
|
||||
await page.getByTestId("sidebar-search-input").fill("api request");
|
||||
|
||||
await page.waitForSelector('[data-testid="dataAPI Request"]', {
|
||||
timeout: 3000,
|
||||
});
|
||||
|
||||
await page
|
||||
.getByTestId("dataAPI Request")
|
||||
.hover()
|
||||
.then(async () => {
|
||||
await page.getByTestId("add-component-button-api-request").click();
|
||||
});
|
||||
|
||||
await page.waitForSelector(
|
||||
'[data-testid="generic-node-title-arrangement"]',
|
||||
{
|
||||
timeout: 3000,
|
||||
},
|
||||
);
|
||||
|
||||
await page.getByTestId("generic-node-title-arrangement").click();
|
||||
|
||||
await page.waitForTimeout(500);
|
||||
|
||||
await page.getByTestId("tool-mode-button").click();
|
||||
|
||||
await page.locator('[data-testid="icon-Hammer"]').nth(1).waitFor({
|
||||
timeout: 3000,
|
||||
state: "visible",
|
||||
});
|
||||
|
||||
await page.getByTestId("icon-Hammer").nth(1).click();
|
||||
|
||||
await page.waitForSelector("text=edit tools", { timeout: 30000 });
|
||||
|
||||
const rowsCount = await page.getByRole("gridcell").count();
|
||||
|
||||
expect(rowsCount).toBeGreaterThan(3);
|
||||
|
||||
expect(await page.getByRole("switch").nth(0).isChecked()).toBe(true);
|
||||
|
||||
await page.getByRole("switch").nth(0).click();
|
||||
|
||||
expect(await page.getByRole("switch").nth(0).isChecked()).toBe(false);
|
||||
|
||||
await page.getByText("Save").last().click();
|
||||
|
||||
await page.waitForSelector(
|
||||
'[data-testid="generic-node-title-arrangement"]',
|
||||
{
|
||||
timeout: 3000,
|
||||
},
|
||||
);
|
||||
|
||||
await page.waitForTimeout(500);
|
||||
|
||||
await page.getByTestId("icon-Hammer").nth(1).click();
|
||||
|
||||
await page.waitForSelector("text=edit tools", { timeout: 30000 });
|
||||
|
||||
await page.waitForTimeout(500);
|
||||
|
||||
expect(await page.getByRole("switch").nth(0).isChecked()).toBe(false);
|
||||
},
|
||||
);
|
||||
Loading…
Add table
Add a link
Reference in a new issue