refactor: update imports and move functions out of MCPToolsComponent (#8976)

This commit is contained in:
Gabriel Luiz Freitas Almeida 2025-07-10 10:53:56 -03:00 committed by GitHub
commit 8033f2c9d3
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
2 changed files with 67 additions and 67 deletions

View file

@ -1,3 +1,4 @@
import re
from collections.abc import Callable, Sequence
from typing import Any
@ -14,7 +15,10 @@ from langchain_core.prompts import BasePromptTemplate, ChatPromptTemplate
from langchain_core.tools import BaseTool
from pydantic import BaseModel
from langflow.logging import logger
from langflow.schema.data import Data
from langflow.services.cache.base import CacheService
from langflow.services.cache.utils import CacheMiss
from .default_prompts import XML_AGENT_PROMPT
@ -141,3 +145,61 @@ AGENTS: dict[str, AgentSpec] = {
def get_agents_list():
return list(AGENTS.keys())
def safe_cache_get(cache: CacheService, key, default=None):
"""Safely get a value from cache, handling CacheMiss objects."""
try:
value = cache.get(key)
if isinstance(value, CacheMiss):
return default
except (AttributeError, KeyError, TypeError):
return default
else:
return value
def safe_cache_set(cache: CacheService, key, value):
"""Safely set a value in cache, handling potential errors."""
try:
cache.set(key, value)
except (AttributeError, TypeError) as e:
logger.warning(f"Failed to set cache key '{key}': {e}")
def maybe_unflatten_dict(flat: dict[str, Any]) -> dict[str, Any]:
"""If any key looks nested (contains a dot or "[index]"), rebuild the.
full nested structure; otherwise return flat as is.
"""
# Quick check: do we have any nested keys?
if not any(re.search(r"\.|\[\d+\]", key) for key in flat):
return flat
# Otherwise, unflatten into dicts/lists
nested: dict[str, Any] = {}
array_re = re.compile(r"^(.+)\[(\d+)\]$")
for key, val in flat.items():
parts = key.split(".")
cur = nested
for i, part in enumerate(parts):
m = array_re.match(part)
# Array segment?
if m:
name, idx = m.group(1), int(m.group(2))
lst = cur.setdefault(name, [])
# Ensure list is big enough
while len(lst) <= idx:
lst.append({})
if i == len(parts) - 1:
lst[idx] = val
else:
cur = lst[idx]
# Normal object key
elif i == len(parts) - 1:
cur[part] = val
else:
cur = cur.setdefault(part, {})
return nested

View file

@ -1,11 +1,13 @@
from __future__ import annotations
import asyncio
import re
import uuid
from typing import TYPE_CHECKING, Any
from typing import Any
from langchain_core.tools import StructuredTool # noqa: TC002
from langflow.api.v2.mcp import get_server
from langflow.base.agents.utils import maybe_unflatten_dict, safe_cache_get, safe_cache_set
from langflow.base.mcp.util import (
MCPSseClient,
MCPStdioClient,
@ -13,82 +15,18 @@ from langflow.base.mcp.util import (
update_tools,
)
from langflow.custom.custom_component.component_with_cache import ComponentWithCache
from langflow.inputs.inputs import InputTypes # noqa: TC001
from langflow.io import DropdownInput, McpInput, MessageTextInput, Output
from langflow.io.schema import flatten_schema, schema_to_langflow_inputs
from langflow.logging import logger
from langflow.schema.dataframe import DataFrame
from langflow.schema.message import Message
from langflow.services.auth.utils import create_user_longterm_token
from langflow.services.cache.utils import CacheMiss
# Import get_server from the backend API
from langflow.services.database.models.user.crud import get_user_by_id
from langflow.services.deps import get_session, get_settings_service, get_storage_service
if TYPE_CHECKING:
from langchain_core.tools import StructuredTool
from langflow.inputs.inputs import InputTypes
from langflow.services.cache.base import CacheService
def maybe_unflatten_dict(flat: dict[str, Any]) -> dict[str, Any]:
"""If any key looks nested (contains a dot or "[index]"), rebuild the.
full nested structure; otherwise return flat as is.
"""
# Quick check: do we have any nested keys?
if not any(re.search(r"\.|\[\d+\]", key) for key in flat):
return flat
# Otherwise, unflatten into dicts/lists
nested: dict[str, Any] = {}
array_re = re.compile(r"^(.+)\[(\d+)\]$")
for key, val in flat.items():
parts = key.split(".")
cur = nested
for i, part in enumerate(parts):
m = array_re.match(part)
# Array segment?
if m:
name, idx = m.group(1), int(m.group(2))
lst = cur.setdefault(name, [])
# Ensure list is big enough
while len(lst) <= idx:
lst.append({})
if i == len(parts) - 1:
lst[idx] = val
else:
cur = lst[idx]
# Normal object key
elif i == len(parts) - 1:
cur[part] = val
else:
cur = cur.setdefault(part, {})
return nested
def safe_cache_get(cache: CacheService, key, default=None):
"""Safely get a value from cache, handling CacheMiss objects."""
try:
value = cache.get(key)
if isinstance(value, CacheMiss):
return default
except (AttributeError, KeyError, TypeError):
return default
else:
return value
def safe_cache_set(cache: CacheService, key, value):
"""Safely set a value in cache, handling potential errors."""
try:
cache.set(key, value)
except (AttributeError, TypeError) as e:
logger.warning(f"Failed to set cache key '{key}': {e}")
class MCPToolsComponent(ComponentWithCache):
schema_inputs: list = []