From 59937ee9e7604965aed5a56b70ad70a82ff6c9ee Mon Sep 17 00:00:00 2001 From: Eric Hare Date: Thu, 21 Aug 2025 16:30:54 -0700 Subject: [PATCH] feat: Make knowledge bases user-stored and support global vars (#9458) * feat: Make knowledge bases user-stored * [autofix.ci] apply automated fixes * [autofix.ci] apply automated fixes (attempt 2/3) * Fix ruff error * [autofix.ci] apply automated fixes * Reuse code * [autofix.ci] apply automated fixes * [autofix.ci] apply automated fixes (attempt 2/3) * Don't show options by default * [autofix.ci] apply automated fixes * Pass in the Langflow API key if set * [autofix.ci] apply automated fixes * Update files.py * [autofix.ci] apply automated fixes * Properly handle secret retrieval * [autofix.ci] apply automated fixes * Update src/backend/base/langflow/base/data/kb_utils.py Co-authored-by: Gabriel Luiz Freitas Almeida * Update src/backend/base/langflow/base/data/kb_utils.py Co-authored-by: Gabriel Luiz Freitas Almeida * Update src/backend/base/langflow/components/data/kb_ingest.py Co-authored-by: Gabriel Luiz Freitas Almeida * [autofix.ci] apply automated fixes * [autofix.ci] apply automated fixes (attempt 2/3) * Feedback from review * [autofix.ci] apply automated fixes * Fix other uses of incorrect user * [autofix.ci] apply automated fixes * [autofix.ci] apply automated fixes (attempt 2/3) * Feedback from review 2 * [autofix.ci] apply automated fixes * Update kb_ingest.py * [autofix.ci] apply automated fixes * Update tests * [autofix.ci] apply automated fixes * Update kb_ingest.py * [autofix.ci] apply automated fixes * Fix mypy issues * [autofix.ci] apply automated fixes * Update kb_utils.py * Update test_kb_ingest.py * Fix tests * [autofix.ci] apply automated fixes --------- Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com> Co-authored-by: Gabriel Luiz Freitas Almeida --- .../base/langflow/api/v1/knowledge_bases.py | 25 ++- src/backend/base/langflow/api/v2/files.py | 4 +- .../base/langflow/base/data/kb_utils.py | 33 ++++ .../components/agents/mcp_component.py | 87 ++++------ .../langflow/components/data/kb_ingest.py | 159 +++++++++++++----- .../langflow/components/data/kb_retrieval.py | 50 +++--- .../components/processing/save_file.py | 38 +---- .../starter_projects/Knowledge Ingestion.json | 4 +- .../starter_projects/Knowledge Retrieval.json | 4 +- .../starter_projects/News Aggregator.json | 21 +-- .../starter_projects/Nvidia Remix.json | 21 +-- .../unit/components/data/test_kb_ingest.py | 90 ++++++---- .../unit/components/data/test_kb_retrieval.py | 124 ++++++++------ 13 files changed, 375 insertions(+), 285 deletions(-) diff --git a/src/backend/base/langflow/api/v1/knowledge_bases.py b/src/backend/base/langflow/api/v1/knowledge_bases.py index 138fda815..41c5ec255 100644 --- a/src/backend/base/langflow/api/v1/knowledge_bases.py +++ b/src/backend/base/langflow/api/v1/knowledge_bases.py @@ -9,6 +9,7 @@ from langchain_chroma import Chroma from loguru import logger from pydantic import BaseModel +from langflow.api.utils import CurrentActiveUser from langflow.services.deps import get_settings_service router = APIRouter(tags=["Knowledge Bases"], prefix="/knowledge_bases") @@ -290,17 +291,19 @@ def get_kb_metadata(kb_path: Path) -> dict: @router.get("", status_code=HTTPStatus.OK) @router.get("/", status_code=HTTPStatus.OK) -async def list_knowledge_bases() -> list[KnowledgeBaseInfo]: +async def list_knowledge_bases(current_user: CurrentActiveUser) -> list[KnowledgeBaseInfo]: """List all available knowledge bases.""" try: kb_root_path = get_kb_root_path() + kb_user = current_user.username + kb_path = kb_root_path / kb_user - if not kb_root_path.exists(): + if not kb_path.exists(): return [] knowledge_bases = [] - for kb_dir in kb_root_path.iterdir(): + for kb_dir in kb_path.iterdir(): if not kb_dir.is_dir() or kb_dir.name.startswith("."): continue @@ -340,11 +343,12 @@ async def list_knowledge_bases() -> list[KnowledgeBaseInfo]: @router.get("/{kb_name}", status_code=HTTPStatus.OK) -async def get_knowledge_base(kb_name: str) -> KnowledgeBaseInfo: +async def get_knowledge_base(kb_name: str, current_user: CurrentActiveUser) -> KnowledgeBaseInfo: """Get detailed information about a specific knowledge base.""" try: kb_root_path = get_kb_root_path() - kb_path = kb_root_path / kb_name + kb_user = current_user.username + kb_path = kb_root_path / kb_user / kb_name if not kb_path.exists() or not kb_path.is_dir(): raise HTTPException(status_code=404, detail=f"Knowledge base '{kb_name}' not found") @@ -374,11 +378,12 @@ async def get_knowledge_base(kb_name: str) -> KnowledgeBaseInfo: @router.delete("/{kb_name}", status_code=HTTPStatus.OK) -async def delete_knowledge_base(kb_name: str) -> dict[str, str]: +async def delete_knowledge_base(kb_name: str, current_user: CurrentActiveUser) -> dict[str, str]: """Delete a specific knowledge base.""" try: kb_root_path = get_kb_root_path() - kb_path = kb_root_path / kb_name + kb_user = current_user.username + kb_path = kb_root_path / kb_user / kb_name if not kb_path.exists() or not kb_path.is_dir(): raise HTTPException(status_code=404, detail=f"Knowledge base '{kb_name}' not found") @@ -396,15 +401,17 @@ async def delete_knowledge_base(kb_name: str) -> dict[str, str]: @router.delete("", status_code=HTTPStatus.OK) @router.delete("/", status_code=HTTPStatus.OK) -async def delete_knowledge_bases_bulk(request: BulkDeleteRequest) -> dict[str, object]: +async def delete_knowledge_bases_bulk(request: BulkDeleteRequest, current_user: CurrentActiveUser) -> dict[str, object]: """Delete multiple knowledge bases.""" try: kb_root_path = get_kb_root_path() + kb_user = current_user.username + kb_user_path = kb_root_path / kb_user deleted_count = 0 not_found_kbs = [] for kb_name in request.kb_names: - kb_path = kb_root_path / kb_name + kb_path = kb_user_path / kb_name if not kb_path.exists() or not kb_path.is_dir(): not_found_kbs.append(kb_name) diff --git a/src/backend/base/langflow/api/v2/files.py b/src/backend/base/langflow/api/v2/files.py index 3777a5da6..1f0c3d483 100644 --- a/src/backend/base/langflow/api/v2/files.py +++ b/src/backend/base/langflow/api/v2/files.py @@ -123,7 +123,9 @@ async def upload_user_file( unique_filename = new_filename else: # For normal files, ensure unique name by appending a count if necessary - stmt = select(UserFile).where(col(UserFile.name).like(f"{root_filename}%")) + stmt = select(UserFile).where( + col(UserFile.name).like(f"{root_filename}%"), UserFile.user_id == current_user.id + ) existing_files = await session.exec(stmt) files = existing_files.all() # Fetch all matching records diff --git a/src/backend/base/langflow/base/data/kb_utils.py b/src/backend/base/langflow/base/data/kb_utils.py index f453eef6f..f58ede2df 100644 --- a/src/backend/base/langflow/base/data/kb_utils.py +++ b/src/backend/base/langflow/base/data/kb_utils.py @@ -1,5 +1,10 @@ import math from collections import Counter +from pathlib import Path +from uuid import UUID + +from langflow.services.database.models.user.crud import get_user_by_id +from langflow.services.deps import session_scope def compute_tfidf(documents: list[str], query_terms: list[str]) -> list[float]: @@ -102,3 +107,31 @@ def compute_bm25(documents: list[str], query_terms: list[str], k1: float = 1.2, scores.append(doc_score) return scores + + +async def get_knowledge_bases(kb_root: Path, user_id: UUID | str) -> list[str]: + """Retrieve a list of available knowledge bases. + + Returns: + A list of knowledge base names. + """ + if not kb_root.exists(): + return [] + + # Get the current user + async with session_scope() as db: + if not user_id: + msg = "User ID is required for fetching knowledge bases." + raise ValueError(msg) + user_id = UUID(user_id) if isinstance(user_id, str) else user_id + current_user = await get_user_by_id(db, user_id) + if not current_user: + msg = f"User with ID {user_id} not found." + raise ValueError(msg) + kb_user = current_user.username + kb_path = kb_root / kb_user + + if not kb_path.exists(): + return [] + + return [str(d.name) for d in kb_path.iterdir() if not d.name.startswith(".") and d.is_dir()] diff --git a/src/backend/base/langflow/components/agents/mcp_component.py b/src/backend/base/langflow/components/agents/mcp_component.py index 99a9bf6fc..a9847e4fc 100644 --- a/src/backend/base/langflow/components/agents/mcp_component.py +++ b/src/backend/base/langflow/components/agents/mcp_component.py @@ -16,16 +16,15 @@ from langflow.base.mcp.util import ( ) from langflow.custom.custom_component.component_with_cache import ComponentWithCache from langflow.inputs.inputs import InputTypes # noqa: TC001 -from langflow.io import DropdownInput, McpInput, MessageTextInput, Output, SecretStrInput +from langflow.io import DropdownInput, McpInput, MessageTextInput, Output from langflow.io.schema import flatten_schema, schema_to_langflow_inputs from langflow.logging import logger from langflow.schema.dataframe import DataFrame from langflow.schema.message import Message # Import get_server from the backend API -from langflow.services.auth.utils import create_user_longterm_token, get_current_user from langflow.services.database.models.user.crud import get_user_by_id -from langflow.services.deps import get_session, get_settings_service, get_storage_service +from langflow.services.deps import get_settings_service, get_storage_service, session_scope class MCPToolsComponent(ComponentWithCache): @@ -96,13 +95,6 @@ class MCPToolsComponent(ComponentWithCache): show=False, tool_mode=False, ), - SecretStrInput( - name="api_key", - display_name="Langflow API Key", - info="Langflow API key for authentication when fetching MCP servers and tools.", - required=False, - advanced=True, - ), ] outputs = [ @@ -161,19 +153,11 @@ class MCPToolsComponent(ComponentWithCache): return self.tools, {"name": server_name, "config": server_config_from_value} try: - async for db in get_session(): - # TODO: In 1.6, this may need to be removed or adjusted - # Try to get the super user token, if possible - if self.api_key: - current_user = await get_current_user( - token=None, - query_param=self.api_key, - header_param=None, - db=db, - ) - else: - user_id, _ = await create_user_longterm_token(db) - current_user = await get_user_by_id(db, user_id) + async with session_scope() as db: + if not self.user_id: + msg = "User ID is required for fetching MCP tools." + raise ValueError(msg) + current_user = await get_user_by_id(db, self.user_id) # Try to get server config from DB/API server_config = await get_server( @@ -184,39 +168,38 @@ class MCPToolsComponent(ComponentWithCache): settings_service=get_settings_service(), ) - # If get_server returns empty but we have a config, use it - if not server_config and server_config_from_value: - server_config = server_config_from_value + # If get_server returns empty but we have a config, use it + if not server_config and server_config_from_value: + server_config = server_config_from_value - if not server_config: - self.tools = [] - return [], {"name": server_name, "config": server_config} + if not server_config: + self.tools = [] + return [], {"name": server_name, "config": server_config} - _, tool_list, tool_cache = await update_tools( - server_name=server_name, - server_config=server_config, - mcp_stdio_client=self.stdio_client, - mcp_sse_client=self.sse_client, - ) + _, tool_list, tool_cache = await update_tools( + server_name=server_name, + server_config=server_config, + mcp_stdio_client=self.stdio_client, + mcp_sse_client=self.sse_client, + ) - self.tool_names = [tool.name for tool in tool_list if hasattr(tool, "name")] - self._tool_cache = tool_cache - self.tools = tool_list - # Cache the result using shared cache - cache_data = { - "tools": tool_list, - "tool_names": self.tool_names, - "tool_cache": tool_cache, - "config": server_config, - } + self.tool_names = [tool.name for tool in tool_list if hasattr(tool, "name")] + self._tool_cache = tool_cache + self.tools = tool_list + # Cache the result using shared cache + cache_data = { + "tools": tool_list, + "tool_names": self.tool_names, + "tool_cache": tool_cache, + "config": server_config, + } - # Safely update the servers cache - current_servers_cache = safe_cache_get(self._shared_component_cache, "servers", {}) - if isinstance(current_servers_cache, dict): - current_servers_cache[server_name] = cache_data - safe_cache_set(self._shared_component_cache, "servers", current_servers_cache) + # Safely update the servers cache + current_servers_cache = safe_cache_get(self._shared_component_cache, "servers", {}) + if isinstance(current_servers_cache, dict): + current_servers_cache[server_name] = cache_data + safe_cache_set(self._shared_component_cache, "servers", current_servers_cache) - return tool_list, {"name": server_name, "config": server_config} except (TimeoutError, asyncio.TimeoutError) as e: msg = f"Timeout updating tool list: {e!s}" logger.exception(msg) @@ -225,6 +208,8 @@ class MCPToolsComponent(ComponentWithCache): msg = f"Error updating tool list: {e!s}" logger.exception(msg) raise ValueError(msg) from e + else: + return tool_list, {"name": server_name, "config": server_config} async def update_build_config(self, build_config: dict, field_value: str, field_name: str | None = None) -> dict: """Toggle the visibility of connection-specific fields based on the selected mode.""" diff --git a/src/backend/base/langflow/components/data/kb_ingest.py b/src/backend/base/langflow/components/data/kb_ingest.py index 1dd63088d..4a734ba81 100644 --- a/src/backend/base/langflow/components/data/kb_ingest.py +++ b/src/backend/base/langflow/components/data/kb_ingest.py @@ -1,5 +1,7 @@ from __future__ import annotations +import asyncio +import contextlib import hashlib import json import re @@ -14,6 +16,7 @@ from cryptography.fernet import InvalidToken from langchain_chroma import Chroma from loguru import logger +from langflow.base.data.kb_utils import get_knowledge_bases from langflow.base.models.openai_constants import OPENAI_EMBEDDING_MODEL_NAMES from langflow.custom import Component from langflow.io import BoolInput, DataFrameInput, DropdownInput, IntInput, Output, SecretStrInput, StrInput, TableInput @@ -21,7 +24,8 @@ from langflow.schema.data import Data from langflow.schema.dotdict import dotdict # noqa: TC001 from langflow.schema.table import EditMode from langflow.services.auth.utils import decrypt_api_key, encrypt_api_key -from langflow.services.deps import get_settings_service +from langflow.services.database.models.user.crud import get_user_by_id +from langflow.services.deps import get_settings_service, get_variable_service, session_scope HUGGINGFACE_MODEL_NAMES = ["sentence-transformers/all-MiniLM-L6-v2", "sentence-transformers/all-mpnet-base-v2"] COHERE_MODEL_NAMES = ["embed-english-v3.0", "embed-multilingual-v3.0"] @@ -43,6 +47,10 @@ class KBIngestionComponent(Component): icon = "database" name = "KBIngestion" + def __init__(self, *args, **kwargs) -> None: + super().__init__(*args, **kwargs) + self._cached_kb_path: Path | None = None + @dataclass class NewKnowledgeBaseInput: functionality: str = "create" @@ -76,7 +84,7 @@ class KBIngestionComponent(Component): display_name="API Key", info="Provider API key for embedding model", required=True, - load_from_db=True, + load_from_db=False, ), }, }, @@ -91,11 +99,7 @@ class KBIngestionComponent(Component): display_name="Knowledge", info="Select the knowledge to load data from.", required=True, - options=[ - str(d.name) for d in KNOWLEDGE_BASES_ROOT_PATH.iterdir() if not d.name.startswith(".") and d.is_dir() - ] - if KNOWLEDGE_BASES_ROOT_PATH.exists() - else [], + options=[], refresh_button=True, dialog_inputs=asdict(NewKnowledgeBaseInput()), ), @@ -329,22 +333,23 @@ class KBIngestionComponent(Component): return metadata - def _create_vector_store( + async def _create_vector_store( self, df_source: pd.DataFrame, config_list: list[dict[str, Any]], embedding_model: str, api_key: str ) -> None: """Create vector store following Local DB component pattern.""" try: # Set up vector store directory - base_dir = self._get_kb_root() - - vector_store_dir = base_dir / self.knowledge_base + vector_store_dir = await self._kb_path() + if not vector_store_dir: + msg = "Knowledge base path is not set. Please create a new knowledge base first." + raise ValueError(msg) vector_store_dir.mkdir(parents=True, exist_ok=True) # Create embeddings model embedding_function = self._build_embeddings(embedding_model, api_key) # Convert DataFrame to Data objects (following Local DB pattern) - data_objects = self._convert_df_to_data_objects(df_source, config_list) + data_objects = await self._convert_df_to_data_objects(df_source, config_list) # Create vector store chroma = Chroma( @@ -367,16 +372,18 @@ class KBIngestionComponent(Component): except (OSError, ValueError, RuntimeError) as e: self.log(f"Error creating vector store: {e}") - def _convert_df_to_data_objects(self, df_source: pd.DataFrame, config_list: list[dict[str, Any]]) -> list[Data]: + async def _convert_df_to_data_objects( + self, df_source: pd.DataFrame, config_list: list[dict[str, Any]] + ) -> list[Data]: """Convert DataFrame to Data objects for vector store.""" data_objects: list[Data] = [] # Set up vector store directory - base_dir = self._get_kb_root() + kb_path = await self._kb_path() # If we don't allow duplicates, we need to get the existing hashes chroma = Chroma( - persist_directory=str(base_dir / self.knowledge_base), + persist_directory=str(kb_path), collection_name=self.knowledge_base, ) @@ -466,10 +473,34 @@ class KBIngestionComponent(Component): # Check allowed characters (condition 3) return re.match(r"^[a-zA-Z0-9_-]+$", name) is not None + async def _kb_path(self) -> Path | None: + # Check if we already have the path cached + cached_path = getattr(self, "_cached_kb_path", None) + if cached_path is not None: + return cached_path + + # If not cached, compute it + async with session_scope() as db: + if not self.user_id: + msg = "User ID is required for fetching knowledge base path." + raise ValueError(msg) + current_user = await get_user_by_id(db, self.user_id) + if not current_user: + msg = f"User with ID {self.user_id} not found." + raise ValueError(msg) + kb_user = current_user.username + + kb_root = self._get_kb_root() + + # Cache the result + self._cached_kb_path = kb_root / kb_user / self.knowledge_base + + return self._cached_kb_path + # --------------------------------------------------------------------- # OUTPUT METHODS # --------------------------------------------------------------------- - def build_kb_info(self) -> Data: + async def build_kb_info(self) -> Data: """Main ingestion routine → returns a dict with KB metadata.""" try: # Get source DataFrame @@ -479,11 +510,11 @@ class KBIngestionComponent(Component): config_list = self._validate_column_config(df_source) column_metadata = self._build_column_metadata(config_list, df_source) - # Prepare KB folder (using File Component patterns) - kb_root = self._get_kb_root() - kb_path = kb_root / self.knowledge_base - # Read the embedding info from the knowledge base folder + kb_path = await self._kb_path() + if not kb_path: + msg = "Knowledge base path is not set. Please create a new knowledge base first." + raise ValueError(msg) metadata_path = kb_path / "embedding_metadata.json" # If the API key is not provided, try to read it from the metadata file @@ -506,7 +537,7 @@ class KBIngestionComponent(Component): ) # Create vector store following Local DB component pattern - self._create_vector_store(df_source, config_list, embedding_model=embedding_model, api_key=api_key) + await self._create_vector_store(df_source, config_list, embedding_model=embedding_model, api_key=api_key) # Save KB files (using File Component storage patterns) self._save_kb_files(kb_path, config_list) @@ -532,40 +563,77 @@ class KBIngestionComponent(Component): self.status = f"❌ KB ingestion failed: {e}" return Data(data={"error": str(e), "kb_name": self.knowledge_base}) - def _get_knowledge_bases(self) -> list[str]: - """Retrieve a list of available knowledge bases. + async def _get_api_key_variable(self, field_value: dict[str, Any]): + async with session_scope() as db: + if not self.user_id: + msg = "User ID is required for fetching global variables." + raise ValueError(msg) + current_user = await get_user_by_id(db, self.user_id) + if not current_user: + msg = f"User with ID {self.user_id} not found." + raise ValueError(msg) + variable_service = get_variable_service() - Returns: - A list of knowledge base names. - """ - # Return the list of directories in the knowledge base root path - kb_root_path = self._get_kb_root() + # Process the api_key field variable + return await variable_service.get_variable( + user_id=current_user.id, + name=field_value["03_api_key"], + field="", + session=db, + ) - if not kb_root_path.exists(): - return [] - - return [str(d.name) for d in kb_root_path.iterdir() if not d.name.startswith(".") and d.is_dir()] - - def update_build_config(self, build_config: dotdict, field_value: Any, field_name: str | None = None) -> dotdict: + async def update_build_config( + self, + build_config: dotdict, + field_value: Any, + field_name: str | None = None, + ) -> dotdict: """Update build configuration based on provider selection.""" # Create a new knowledge base if field_name == "knowledge_base": + async with session_scope() as db: + if not self.user_id: + msg = "User ID is required for fetching knowledge base list." + raise ValueError(msg) + current_user = await get_user_by_id(db, self.user_id) + if not current_user: + msg = f"User with ID {self.user_id} not found." + raise ValueError(msg) + kb_user = current_user.username if isinstance(field_value, dict) and "01_new_kb_name" in field_value: # Validate the knowledge base name - Make sure it follows these rules: if not self.is_valid_collection_name(field_value["01_new_kb_name"]): msg = f"Invalid knowledge base name: {field_value['01_new_kb_name']}" raise ValueError(msg) - # We need to test the API Key one time against the embedding model - embed_model = self._build_embeddings( - embedding_model=field_value["02_embedding_model"], api_key=field_value["03_api_key"] - ) + api_key = field_value.get("03_api_key", None) + with contextlib.suppress(Exception): + # If the API key is a variable, resolve it + api_key = await self._get_api_key_variable(field_value) - # Try to generate a dummy embedding to validate the API key - embed_model.embed_query("test") + # Make sure api_key is a string + if not isinstance(api_key, str): + msg = "API key must be a string." + raise ValueError(msg) + + # We need to test the API Key one time against the embedding model + embed_model = self._build_embeddings(embedding_model=field_value["02_embedding_model"], api_key=api_key) + + # Try to generate a dummy embedding to validate the API key without blocking the event loop + try: + await asyncio.wait_for( + asyncio.to_thread(embed_model.embed_query, "test"), + timeout=10, + ) + except TimeoutError as e: + msg = "Embedding validation timed out. Please verify network connectivity and key." + raise ValueError(msg) from e + except Exception as e: + msg = f"Embedding validation failed: {e!s}" + raise ValueError(msg) from e # Create the new knowledge base directory - kb_path = KNOWLEDGE_BASES_ROOT_PATH / field_value["01_new_kb_name"] + kb_path = KNOWLEDGE_BASES_ROOT_PATH / kb_user / field_value["01_new_kb_name"] kb_path.mkdir(parents=True, exist_ok=True) # Save the embedding metadata @@ -573,11 +641,16 @@ class KBIngestionComponent(Component): self._save_embedding_metadata( kb_path=kb_path, embedding_model=field_value["02_embedding_model"], - api_key=field_value["03_api_key"], + api_key=api_key, ) # Update the knowledge base options dynamically - build_config["knowledge_base"]["options"] = self._get_knowledge_bases() + build_config["knowledge_base"]["options"] = await get_knowledge_bases( + KNOWLEDGE_BASES_ROOT_PATH, + user_id=self.user_id, + ) + + # If the selected knowledge base is not available, reset it if build_config["knowledge_base"]["value"] not in build_config["knowledge_base"]["options"]: build_config["knowledge_base"]["value"] = None diff --git a/src/backend/base/langflow/components/data/kb_retrieval.py b/src/backend/base/langflow/components/data/kb_retrieval.py index 24d5559e8..d633c62b0 100644 --- a/src/backend/base/langflow/components/data/kb_retrieval.py +++ b/src/backend/base/langflow/components/data/kb_retrieval.py @@ -5,13 +5,16 @@ from typing import Any from cryptography.fernet import InvalidToken from langchain_chroma import Chroma from loguru import logger +from pydantic import SecretStr +from langflow.base.data.kb_utils import get_knowledge_bases from langflow.custom import Component from langflow.io import BoolInput, DropdownInput, IntInput, MessageTextInput, Output, SecretStrInput from langflow.schema.data import Data from langflow.schema.dataframe import DataFrame from langflow.services.auth.utils import decrypt_api_key -from langflow.services.deps import get_settings_service +from langflow.services.database.models.user.crud import get_user_by_id +from langflow.services.deps import get_settings_service, session_scope settings = get_settings_service().settings knowledge_directory = settings.knowledge_bases_dir @@ -33,11 +36,7 @@ class KBRetrievalComponent(Component): display_name="Knowledge", info="Select the knowledge to load data from.", required=True, - options=[ - str(d.name) for d in KNOWLEDGE_BASES_ROOT_PATH.iterdir() if not d.name.startswith(".") and d.is_dir() - ] - if KNOWLEDGE_BASES_ROOT_PATH.exists() - else [], + options=[], refresh_button=True, real_time_refresh=True, ), @@ -79,21 +78,13 @@ class KBRetrievalComponent(Component): ), ] - def _get_knowledge_bases(self) -> list[str]: - """Retrieve a list of available knowledge bases. - - Returns: - A list of knowledge base names. - """ - if not KNOWLEDGE_BASES_ROOT_PATH.exists(): - return [] - - return [str(d.name) for d in KNOWLEDGE_BASES_ROOT_PATH.iterdir() if not d.name.startswith(".") and d.is_dir()] - - def update_build_config(self, build_config, field_value, field_name=None): # noqa: ARG002 + async def update_build_config(self, build_config, field_value, field_name=None): # noqa: ARG002 if field_name == "knowledge_base": # Update the knowledge base options dynamically - build_config["knowledge_base"]["options"] = self._get_knowledge_bases() + build_config["knowledge_base"]["options"] = await get_knowledge_bases( + KNOWLEDGE_BASES_ROOT_PATH, + user_id=self.user_id, # Use the user_id from the component context + ) # If the selected knowledge base is not available, reset it if build_config["knowledge_base"]["value"] not in build_config["knowledge_base"]["options"]: @@ -129,15 +120,12 @@ class KBRetrievalComponent(Component): def _build_embeddings(self, metadata: dict): """Build embedding model from metadata.""" + runtime_api_key = self.api_key.get_secret_value() if isinstance(self.api_key, SecretStr) else self.api_key provider = metadata.get("embedding_provider") model = metadata.get("embedding_model") - api_key = metadata.get("api_key") + api_key = runtime_api_key or metadata.get("api_key") chunk_size = metadata.get("chunk_size") - # If user provided a key in the input, it overrides the stored one. - if self.api_key and self.api_key.get_secret_value(): - api_key = self.api_key.get_secret_value() - # Handle various providers if provider == "OpenAI": from langchain_openai import OpenAIEmbeddings @@ -174,13 +162,23 @@ class KBRetrievalComponent(Component): msg = f"Embedding provider '{provider}' is not supported for retrieval." raise NotImplementedError(msg) - def get_chroma_kb_data(self) -> DataFrame: + async def get_chroma_kb_data(self) -> DataFrame: """Retrieve data from the selected knowledge base by reading the Chroma collection. Returns: A DataFrame containing the data rows from the knowledge base. """ - kb_path = KNOWLEDGE_BASES_ROOT_PATH / self.knowledge_base + # Get the current user + async with session_scope() as db: + if not self.user_id: + msg = "User ID is required for fetching Knowledge Base data." + raise ValueError(msg) + current_user = await get_user_by_id(db, self.user_id) + if not current_user: + msg = f"User with ID {self.user_id} not found." + raise ValueError(msg) + kb_user = current_user.username + kb_path = KNOWLEDGE_BASES_ROOT_PATH / kb_user / self.knowledge_base metadata = self._get_kb_metadata(kb_path) if not metadata: diff --git a/src/backend/base/langflow/components/processing/save_file.py b/src/backend/base/langflow/components/processing/save_file.py index 16ff06712..e0af7f546 100644 --- a/src/backend/base/langflow/components/processing/save_file.py +++ b/src/backend/base/langflow/components/processing/save_file.py @@ -1,7 +1,6 @@ import json from collections.abc import AsyncIterator, Iterator from pathlib import Path -from typing import TYPE_CHECKING import orjson import pandas as pd @@ -10,16 +9,12 @@ from fastapi.encoders import jsonable_encoder from langflow.api.v2.files import upload_user_file from langflow.custom import Component -from langflow.io import DropdownInput, HandleInput, SecretStrInput, StrInput +from langflow.io import DropdownInput, HandleInput, StrInput from langflow.schema import Data, DataFrame, Message -from langflow.services.auth.utils import create_user_longterm_token, get_current_user from langflow.services.database.models.user.crud import get_user_by_id -from langflow.services.deps import get_session, get_settings_service, get_storage_service +from langflow.services.deps import get_settings_service, get_storage_service, session_scope from langflow.template.field.base import Output -if TYPE_CHECKING: - from langflow.services.database.models.user.model import User - class SaveToFileComponent(Component): display_name = "Save File" @@ -55,13 +50,6 @@ class SaveToFileComponent(Component): value="", advanced=True, ), - SecretStrInput( - name="api_key", - display_name="Langflow API Key", - info="Langflow API key for authentication when saving the file.", - required=False, - advanced=True, - ), ] outputs = [Output(display_name="File Path", name="message", method="save_to_file")] @@ -148,25 +136,11 @@ class SaveToFileComponent(Component): raise FileNotFoundError(msg) with file_path.open("rb") as f: - async for db in get_session(): - # TODO: In 1.6, this may need to be removed or adjusted - # Try to get the super user token, if possible - current_user: User | None = None - if self.api_key: - current_user = await get_current_user( - token="", - query_param=self.api_key, - header_param="", - db=db, - ) - else: - user_id, _ = await create_user_longterm_token(db) - current_user = await get_user_by_id(db, user_id) - - # Fail if the user is not found - if not current_user: - msg = "User not found. Please provide a valid API key or ensure the user exists." + async with session_scope() as db: + if not self.user_id: + msg = "User ID is required for file saving." raise ValueError(msg) + current_user = await get_user_by_id(db, self.user_id) await upload_user_file( file=UploadFile(filename=file_path.name, file=f, size=file_path.stat().st_size), diff --git a/src/backend/base/langflow/initial_setup/starter_projects/Knowledge Ingestion.json b/src/backend/base/langflow/initial_setup/starter_projects/Knowledge Ingestion.json index 0371c75ec..84023b12e 100644 --- a/src/backend/base/langflow/initial_setup/starter_projects/Knowledge Ingestion.json +++ b/src/backend/base/langflow/initial_setup/starter_projects/Knowledge Ingestion.json @@ -702,7 +702,7 @@ "last_updated": "2025-08-13T19:45:49.122Z", "legacy": false, "metadata": { - "code_hash": "e1ebcd66ecbc", + "code_hash": "6c62063f2c09", "module": "langflow.components.data.kb_ingest.KBIngestionComponent" }, "minimized": false, @@ -795,7 +795,7 @@ "show": true, "title_case": false, "type": "code", - "value": "from __future__ import annotations\n\nimport hashlib\nimport json\nimport re\nimport uuid\nfrom dataclasses import asdict, dataclass, field\nfrom datetime import datetime, timezone\nfrom pathlib import Path\nfrom typing import Any\n\nimport pandas as pd\nfrom cryptography.fernet import InvalidToken\nfrom langchain_chroma import Chroma\nfrom loguru import logger\n\nfrom langflow.base.models.openai_constants import OPENAI_EMBEDDING_MODEL_NAMES\nfrom langflow.custom import Component\nfrom langflow.io import BoolInput, DataFrameInput, DropdownInput, IntInput, Output, SecretStrInput, StrInput, TableInput\nfrom langflow.schema.data import Data\nfrom langflow.schema.dotdict import dotdict # noqa: TC001\nfrom langflow.schema.table import EditMode\nfrom langflow.services.auth.utils import decrypt_api_key, encrypt_api_key\nfrom langflow.services.deps import get_settings_service\n\nHUGGINGFACE_MODEL_NAMES = [\"sentence-transformers/all-MiniLM-L6-v2\", \"sentence-transformers/all-mpnet-base-v2\"]\nCOHERE_MODEL_NAMES = [\"embed-english-v3.0\", \"embed-multilingual-v3.0\"]\n\nsettings = get_settings_service().settings\nknowledge_directory = settings.knowledge_bases_dir\nif not knowledge_directory:\n msg = \"Knowledge bases directory is not set in the settings.\"\n raise ValueError(msg)\nKNOWLEDGE_BASES_ROOT_PATH = Path(knowledge_directory).expanduser()\n\n\nclass KBIngestionComponent(Component):\n \"\"\"Create or append to Langflow Knowledge from a DataFrame.\"\"\"\n\n # ------ UI metadata ---------------------------------------------------\n display_name = \"Knowledge Ingestion\"\n description = \"Create or update knowledge in Langflow.\"\n icon = \"database\"\n name = \"KBIngestion\"\n\n @dataclass\n class NewKnowledgeBaseInput:\n functionality: str = \"create\"\n fields: dict[str, dict] = field(\n default_factory=lambda: {\n \"data\": {\n \"node\": {\n \"name\": \"create_knowledge_base\",\n \"description\": \"Create new knowledge in Langflow.\",\n \"display_name\": \"Create new knowledge\",\n \"field_order\": [\"01_new_kb_name\", \"02_embedding_model\", \"03_api_key\"],\n \"template\": {\n \"01_new_kb_name\": StrInput(\n name=\"new_kb_name\",\n display_name=\"Knowledge Name\",\n info=\"Name of the new knowledge to create.\",\n required=True,\n ),\n \"02_embedding_model\": DropdownInput(\n name=\"embedding_model\",\n display_name=\"Model Name\",\n info=\"Select the embedding model to use for this knowledge base.\",\n required=True,\n options=OPENAI_EMBEDDING_MODEL_NAMES + HUGGINGFACE_MODEL_NAMES + COHERE_MODEL_NAMES,\n options_metadata=[{\"icon\": \"OpenAI\"} for _ in OPENAI_EMBEDDING_MODEL_NAMES]\n + [{\"icon\": \"HuggingFace\"} for _ in HUGGINGFACE_MODEL_NAMES]\n + [{\"icon\": \"Cohere\"} for _ in COHERE_MODEL_NAMES],\n ),\n \"03_api_key\": SecretStrInput(\n name=\"api_key\",\n display_name=\"API Key\",\n info=\"Provider API key for embedding model\",\n required=True,\n load_from_db=True,\n ),\n },\n },\n }\n }\n )\n\n # ------ Inputs --------------------------------------------------------\n inputs = [\n DropdownInput(\n name=\"knowledge_base\",\n display_name=\"Knowledge\",\n info=\"Select the knowledge to load data from.\",\n required=True,\n options=[\n str(d.name) for d in KNOWLEDGE_BASES_ROOT_PATH.iterdir() if not d.name.startswith(\".\") and d.is_dir()\n ]\n if KNOWLEDGE_BASES_ROOT_PATH.exists()\n else [],\n refresh_button=True,\n dialog_inputs=asdict(NewKnowledgeBaseInput()),\n ),\n DataFrameInput(\n name=\"input_df\",\n display_name=\"Data\",\n info=\"Table with all original columns (already chunked / processed).\",\n required=True,\n ),\n TableInput(\n name=\"column_config\",\n display_name=\"Column Configuration\",\n info=\"Configure column behavior for the knowledge base.\",\n required=True,\n table_schema=[\n {\n \"name\": \"column_name\",\n \"display_name\": \"Column Name\",\n \"type\": \"str\",\n \"description\": \"Name of the column in the source DataFrame\",\n \"edit_mode\": EditMode.INLINE,\n },\n {\n \"name\": \"vectorize\",\n \"display_name\": \"Vectorize\",\n \"type\": \"boolean\",\n \"description\": \"Create embeddings for this column\",\n \"default\": False,\n \"edit_mode\": EditMode.INLINE,\n },\n {\n \"name\": \"identifier\",\n \"display_name\": \"Identifier\",\n \"type\": \"boolean\",\n \"description\": \"Use this column as unique identifier\",\n \"default\": False,\n \"edit_mode\": EditMode.INLINE,\n },\n ],\n value=[\n {\n \"column_name\": \"text\",\n \"vectorize\": True,\n \"identifier\": True,\n },\n ],\n ),\n IntInput(\n name=\"chunk_size\",\n display_name=\"Chunk Size\",\n info=\"Batch size for processing embeddings\",\n advanced=True,\n value=1000,\n ),\n SecretStrInput(\n name=\"api_key\",\n display_name=\"Embedding Provider API Key\",\n info=\"API key for the embedding provider to generate embeddings.\",\n advanced=True,\n required=False,\n ),\n BoolInput(\n name=\"allow_duplicates\",\n display_name=\"Allow Duplicates\",\n info=\"Allow duplicate rows in the knowledge base\",\n advanced=True,\n value=False,\n ),\n ]\n\n # ------ Outputs -------------------------------------------------------\n outputs = [Output(display_name=\"DataFrame\", name=\"dataframe\", method=\"build_kb_info\")]\n\n # ------ Internal helpers ---------------------------------------------\n def _get_kb_root(self) -> Path:\n \"\"\"Return the root directory for knowledge bases.\"\"\"\n return KNOWLEDGE_BASES_ROOT_PATH\n\n def _validate_column_config(self, df_source: pd.DataFrame) -> list[dict[str, Any]]:\n \"\"\"Validate column configuration using Structured Output patterns.\"\"\"\n if not self.column_config:\n msg = \"Column configuration cannot be empty\"\n raise ValueError(msg)\n\n # Convert table input to list of dicts (similar to Structured Output)\n config_list = self.column_config if isinstance(self.column_config, list) else []\n\n # Validate column names exist in DataFrame\n df_columns = set(df_source.columns)\n for config in config_list:\n col_name = config.get(\"column_name\")\n if col_name not in df_columns:\n msg = f\"Column '{col_name}' not found in DataFrame. Available columns: {sorted(df_columns)}\"\n raise ValueError(msg)\n\n return config_list\n\n def _get_embedding_provider(self, embedding_model: str) -> str:\n \"\"\"Get embedding provider by matching model name to lists.\"\"\"\n if embedding_model in OPENAI_EMBEDDING_MODEL_NAMES:\n return \"OpenAI\"\n if embedding_model in HUGGINGFACE_MODEL_NAMES:\n return \"HuggingFace\"\n if embedding_model in COHERE_MODEL_NAMES:\n return \"Cohere\"\n return \"Custom\"\n\n def _build_embeddings(self, embedding_model: str, api_key: str):\n \"\"\"Build embedding model using provider patterns.\"\"\"\n # Get provider by matching model name to lists\n provider = self._get_embedding_provider(embedding_model)\n\n # Validate provider and model\n if provider == \"OpenAI\":\n from langchain_openai import OpenAIEmbeddings\n\n if not api_key:\n msg = \"OpenAI API key is required when using OpenAI provider\"\n raise ValueError(msg)\n return OpenAIEmbeddings(\n model=embedding_model,\n api_key=api_key,\n chunk_size=self.chunk_size,\n )\n if provider == \"HuggingFace\":\n from langchain_huggingface import HuggingFaceEmbeddings\n\n return HuggingFaceEmbeddings(\n model=embedding_model,\n )\n if provider == \"Cohere\":\n from langchain_cohere import CohereEmbeddings\n\n if not api_key:\n msg = \"Cohere API key is required when using Cohere provider\"\n raise ValueError(msg)\n return CohereEmbeddings(\n model=embedding_model,\n cohere_api_key=api_key,\n )\n if provider == \"Custom\":\n # For custom embedding models, we would need additional configuration\n msg = \"Custom embedding models not yet supported\"\n raise NotImplementedError(msg)\n msg = f\"Unknown provider: {provider}\"\n raise ValueError(msg)\n\n def _build_embedding_metadata(self, embedding_model, api_key) -> dict[str, Any]:\n \"\"\"Build embedding model metadata.\"\"\"\n # Get provider by matching model name to lists\n embedding_provider = self._get_embedding_provider(embedding_model)\n\n api_key_to_save = None\n if api_key and hasattr(api_key, \"get_secret_value\"):\n api_key_to_save = api_key.get_secret_value()\n elif isinstance(api_key, str):\n api_key_to_save = api_key\n\n encrypted_api_key = None\n if api_key_to_save:\n settings_service = get_settings_service()\n try:\n encrypted_api_key = encrypt_api_key(api_key_to_save, settings_service=settings_service)\n except (TypeError, ValueError) as e:\n self.log(f\"Could not encrypt API key: {e}\")\n logger.error(f\"Could not encrypt API key: {e}\")\n\n return {\n \"embedding_provider\": embedding_provider,\n \"embedding_model\": embedding_model,\n \"api_key\": encrypted_api_key,\n \"api_key_used\": bool(api_key),\n \"chunk_size\": self.chunk_size,\n \"created_at\": datetime.now(timezone.utc).isoformat(),\n }\n\n def _save_embedding_metadata(self, kb_path: Path, embedding_model: str, api_key: str) -> None:\n \"\"\"Save embedding model metadata.\"\"\"\n embedding_metadata = self._build_embedding_metadata(embedding_model, api_key)\n metadata_path = kb_path / \"embedding_metadata.json\"\n metadata_path.write_text(json.dumps(embedding_metadata, indent=2))\n\n def _save_kb_files(\n self,\n kb_path: Path,\n config_list: list[dict[str, Any]],\n ) -> None:\n \"\"\"Save KB files using File Component storage patterns.\"\"\"\n try:\n # Create directory (following File Component patterns)\n kb_path.mkdir(parents=True, exist_ok=True)\n\n # Save column configuration\n # Only do this if the file doesn't exist already\n cfg_path = kb_path / \"schema.json\"\n if not cfg_path.exists():\n cfg_path.write_text(json.dumps(config_list, indent=2))\n\n except (OSError, TypeError, ValueError) as e:\n self.log(f\"Error saving KB files: {e}\")\n\n def _build_column_metadata(self, config_list: list[dict[str, Any]], df_source: pd.DataFrame) -> dict[str, Any]:\n \"\"\"Build detailed column metadata.\"\"\"\n metadata: dict[str, Any] = {\n \"total_columns\": len(df_source.columns),\n \"mapped_columns\": len(config_list),\n \"unmapped_columns\": len(df_source.columns) - len(config_list),\n \"columns\": [],\n \"summary\": {\"vectorized_columns\": [], \"identifier_columns\": []},\n }\n\n for config in config_list:\n col_name = config.get(\"column_name\")\n vectorize = config.get(\"vectorize\") == \"True\" or config.get(\"vectorize\") is True\n identifier = config.get(\"identifier\") == \"True\" or config.get(\"identifier\") is True\n\n # Add to columns list\n metadata[\"columns\"].append(\n {\n \"name\": col_name,\n \"vectorize\": vectorize,\n \"identifier\": identifier,\n }\n )\n\n # Update summary\n if vectorize:\n metadata[\"summary\"][\"vectorized_columns\"].append(col_name)\n if identifier:\n metadata[\"summary\"][\"identifier_columns\"].append(col_name)\n\n return metadata\n\n def _create_vector_store(\n self, df_source: pd.DataFrame, config_list: list[dict[str, Any]], embedding_model: str, api_key: str\n ) -> None:\n \"\"\"Create vector store following Local DB component pattern.\"\"\"\n try:\n # Set up vector store directory\n base_dir = self._get_kb_root()\n\n vector_store_dir = base_dir / self.knowledge_base\n vector_store_dir.mkdir(parents=True, exist_ok=True)\n\n # Create embeddings model\n embedding_function = self._build_embeddings(embedding_model, api_key)\n\n # Convert DataFrame to Data objects (following Local DB pattern)\n data_objects = self._convert_df_to_data_objects(df_source, config_list)\n\n # Create vector store\n chroma = Chroma(\n persist_directory=str(vector_store_dir),\n embedding_function=embedding_function,\n collection_name=self.knowledge_base,\n )\n\n # Convert Data objects to LangChain Documents\n documents = []\n for data_obj in data_objects:\n doc = data_obj.to_lc_document()\n documents.append(doc)\n\n # Add documents to vector store\n if documents:\n chroma.add_documents(documents)\n self.log(f\"Added {len(documents)} documents to vector store '{self.knowledge_base}'\")\n\n except (OSError, ValueError, RuntimeError) as e:\n self.log(f\"Error creating vector store: {e}\")\n\n def _convert_df_to_data_objects(self, df_source: pd.DataFrame, config_list: list[dict[str, Any]]) -> list[Data]:\n \"\"\"Convert DataFrame to Data objects for vector store.\"\"\"\n data_objects: list[Data] = []\n\n # Set up vector store directory\n base_dir = self._get_kb_root()\n\n # If we don't allow duplicates, we need to get the existing hashes\n chroma = Chroma(\n persist_directory=str(base_dir / self.knowledge_base),\n collection_name=self.knowledge_base,\n )\n\n # Get all documents and their metadata\n all_docs = chroma.get()\n\n # Extract all _id values from metadata\n id_list = [metadata.get(\"_id\") for metadata in all_docs[\"metadatas\"] if metadata.get(\"_id\")]\n\n # Get column roles\n content_cols = []\n identifier_cols = []\n\n for config in config_list:\n col_name = config.get(\"column_name\")\n vectorize = config.get(\"vectorize\") == \"True\" or config.get(\"vectorize\") is True\n identifier = config.get(\"identifier\") == \"True\" or config.get(\"identifier\") is True\n\n if vectorize:\n content_cols.append(col_name)\n elif identifier:\n identifier_cols.append(col_name)\n\n # Convert each row to a Data object\n for _, row in df_source.iterrows():\n # Build content text from identifier columns using list comprehension\n identifier_parts = [str(row[col]) for col in content_cols if col in row and pd.notna(row[col])]\n\n # Join all parts into a single string\n page_content = \" \".join(identifier_parts)\n\n # Build metadata from NON-vectorized columns only (simple key-value pairs)\n data_dict = {\n \"text\": page_content, # Main content for vectorization\n }\n\n # Add identifier columns if they exist\n if identifier_cols:\n identifier_parts = [str(row[col]) for col in identifier_cols if col in row and pd.notna(row[col])]\n page_content = \" \".join(identifier_parts)\n\n # Add metadata columns as simple key-value pairs\n for col in df_source.columns:\n if col not in content_cols and col in row and pd.notna(row[col]):\n # Convert to simple types for Chroma metadata\n value = row[col]\n data_dict[col] = str(value) # Convert complex types to string\n\n # Hash the page_content for unique ID\n page_content_hash = hashlib.sha256(page_content.encode()).hexdigest()\n data_dict[\"_id\"] = page_content_hash\n\n # If duplicates are disallowed, and hash exists, prevent adding this row\n if not self.allow_duplicates and page_content_hash in id_list:\n self.log(f\"Skipping duplicate row with hash {page_content_hash}\")\n continue\n\n # Create Data object - everything except \"text\" becomes metadata\n data_obj = Data(data=data_dict)\n data_objects.append(data_obj)\n\n return data_objects\n\n def is_valid_collection_name(self, name, min_length: int = 3, max_length: int = 63) -> bool:\n \"\"\"Validates collection name against conditions 1-3.\n\n 1. Contains 3-63 characters\n 2. Starts and ends with alphanumeric character\n 3. Contains only alphanumeric characters, underscores, or hyphens.\n\n Args:\n name (str): Collection name to validate\n min_length (int): Minimum length of the name\n max_length (int): Maximum length of the name\n\n Returns:\n bool: True if valid, False otherwise\n \"\"\"\n # Check length (condition 1)\n if not (min_length <= len(name) <= max_length):\n return False\n\n # Check start/end with alphanumeric (condition 2)\n if not (name[0].isalnum() and name[-1].isalnum()):\n return False\n\n # Check allowed characters (condition 3)\n return re.match(r\"^[a-zA-Z0-9_-]+$\", name) is not None\n\n # ---------------------------------------------------------------------\n # OUTPUT METHODS\n # ---------------------------------------------------------------------\n def build_kb_info(self) -> Data:\n \"\"\"Main ingestion routine → returns a dict with KB metadata.\"\"\"\n try:\n # Get source DataFrame\n df_source: pd.DataFrame = self.input_df\n\n # Validate column configuration (using Structured Output patterns)\n config_list = self._validate_column_config(df_source)\n column_metadata = self._build_column_metadata(config_list, df_source)\n\n # Prepare KB folder (using File Component patterns)\n kb_root = self._get_kb_root()\n kb_path = kb_root / self.knowledge_base\n\n # Read the embedding info from the knowledge base folder\n metadata_path = kb_path / \"embedding_metadata.json\"\n\n # If the API key is not provided, try to read it from the metadata file\n if metadata_path.exists():\n settings_service = get_settings_service()\n metadata = json.loads(metadata_path.read_text())\n embedding_model = metadata.get(\"embedding_model\")\n try:\n api_key = decrypt_api_key(metadata[\"api_key\"], settings_service)\n except (InvalidToken, TypeError, ValueError) as e:\n logger.error(f\"Could not decrypt API key. Please provide it manually. Error: {e}\")\n\n # Check if a custom API key was provided, update metadata if so\n if self.api_key:\n api_key = self.api_key\n self._save_embedding_metadata(\n kb_path=kb_path,\n embedding_model=embedding_model,\n api_key=api_key,\n )\n\n # Create vector store following Local DB component pattern\n self._create_vector_store(df_source, config_list, embedding_model=embedding_model, api_key=api_key)\n\n # Save KB files (using File Component storage patterns)\n self._save_kb_files(kb_path, config_list)\n\n # Build metadata response\n meta: dict[str, Any] = {\n \"kb_id\": str(uuid.uuid4()),\n \"kb_name\": self.knowledge_base,\n \"rows\": len(df_source),\n \"column_metadata\": column_metadata,\n \"path\": str(kb_path),\n \"config_columns\": len(config_list),\n \"timestamp\": datetime.now(tz=timezone.utc).isoformat(),\n }\n\n # Set status message\n self.status = f\"✅ KB **{self.knowledge_base}** saved · {len(df_source)} chunks.\"\n\n return Data(data=meta)\n\n except (OSError, ValueError, RuntimeError, KeyError) as e:\n self.log(f\"Error in KB ingestion: {e}\")\n self.status = f\"❌ KB ingestion failed: {e}\"\n return Data(data={\"error\": str(e), \"kb_name\": self.knowledge_base})\n\n def _get_knowledge_bases(self) -> list[str]:\n \"\"\"Retrieve a list of available knowledge bases.\n\n Returns:\n A list of knowledge base names.\n \"\"\"\n # Return the list of directories in the knowledge base root path\n kb_root_path = self._get_kb_root()\n\n if not kb_root_path.exists():\n return []\n\n return [str(d.name) for d in kb_root_path.iterdir() if not d.name.startswith(\".\") and d.is_dir()]\n\n def update_build_config(self, build_config: dotdict, field_value: Any, field_name: str | None = None) -> dotdict:\n \"\"\"Update build configuration based on provider selection.\"\"\"\n # Create a new knowledge base\n if field_name == \"knowledge_base\":\n if isinstance(field_value, dict) and \"01_new_kb_name\" in field_value:\n # Validate the knowledge base name - Make sure it follows these rules:\n if not self.is_valid_collection_name(field_value[\"01_new_kb_name\"]):\n msg = f\"Invalid knowledge base name: {field_value['01_new_kb_name']}\"\n raise ValueError(msg)\n\n # We need to test the API Key one time against the embedding model\n embed_model = self._build_embeddings(\n embedding_model=field_value[\"02_embedding_model\"], api_key=field_value[\"03_api_key\"]\n )\n\n # Try to generate a dummy embedding to validate the API key\n embed_model.embed_query(\"test\")\n\n # Create the new knowledge base directory\n kb_path = KNOWLEDGE_BASES_ROOT_PATH / field_value[\"01_new_kb_name\"]\n kb_path.mkdir(parents=True, exist_ok=True)\n\n # Save the embedding metadata\n build_config[\"knowledge_base\"][\"value\"] = field_value[\"01_new_kb_name\"]\n self._save_embedding_metadata(\n kb_path=kb_path,\n embedding_model=field_value[\"02_embedding_model\"],\n api_key=field_value[\"03_api_key\"],\n )\n\n # Update the knowledge base options dynamically\n build_config[\"knowledge_base\"][\"options\"] = self._get_knowledge_bases()\n if build_config[\"knowledge_base\"][\"value\"] not in build_config[\"knowledge_base\"][\"options\"]:\n build_config[\"knowledge_base\"][\"value\"] = None\n\n return build_config\n" + "value": "from __future__ import annotations\n\nimport asyncio\nimport contextlib\nimport hashlib\nimport json\nimport re\nimport uuid\nfrom dataclasses import asdict, dataclass, field\nfrom datetime import datetime, timezone\nfrom pathlib import Path\nfrom typing import Any\n\nimport pandas as pd\nfrom cryptography.fernet import InvalidToken\nfrom langchain_chroma import Chroma\nfrom loguru import logger\n\nfrom langflow.base.data.kb_utils import get_knowledge_bases\nfrom langflow.base.models.openai_constants import OPENAI_EMBEDDING_MODEL_NAMES\nfrom langflow.custom import Component\nfrom langflow.io import BoolInput, DataFrameInput, DropdownInput, IntInput, Output, SecretStrInput, StrInput, TableInput\nfrom langflow.schema.data import Data\nfrom langflow.schema.dotdict import dotdict # noqa: TC001\nfrom langflow.schema.table import EditMode\nfrom langflow.services.auth.utils import decrypt_api_key, encrypt_api_key\nfrom langflow.services.database.models.user.crud import get_user_by_id\nfrom langflow.services.deps import get_settings_service, get_variable_service, session_scope\n\nHUGGINGFACE_MODEL_NAMES = [\"sentence-transformers/all-MiniLM-L6-v2\", \"sentence-transformers/all-mpnet-base-v2\"]\nCOHERE_MODEL_NAMES = [\"embed-english-v3.0\", \"embed-multilingual-v3.0\"]\n\nsettings = get_settings_service().settings\nknowledge_directory = settings.knowledge_bases_dir\nif not knowledge_directory:\n msg = \"Knowledge bases directory is not set in the settings.\"\n raise ValueError(msg)\nKNOWLEDGE_BASES_ROOT_PATH = Path(knowledge_directory).expanduser()\n\n\nclass KBIngestionComponent(Component):\n \"\"\"Create or append to Langflow Knowledge from a DataFrame.\"\"\"\n\n # ------ UI metadata ---------------------------------------------------\n display_name = \"Knowledge Ingestion\"\n description = \"Create or update knowledge in Langflow.\"\n icon = \"database\"\n name = \"KBIngestion\"\n\n def __init__(self, *args, **kwargs) -> None:\n super().__init__(*args, **kwargs)\n self._cached_kb_path: Path | None = None\n\n @dataclass\n class NewKnowledgeBaseInput:\n functionality: str = \"create\"\n fields: dict[str, dict] = field(\n default_factory=lambda: {\n \"data\": {\n \"node\": {\n \"name\": \"create_knowledge_base\",\n \"description\": \"Create new knowledge in Langflow.\",\n \"display_name\": \"Create new knowledge\",\n \"field_order\": [\"01_new_kb_name\", \"02_embedding_model\", \"03_api_key\"],\n \"template\": {\n \"01_new_kb_name\": StrInput(\n name=\"new_kb_name\",\n display_name=\"Knowledge Name\",\n info=\"Name of the new knowledge to create.\",\n required=True,\n ),\n \"02_embedding_model\": DropdownInput(\n name=\"embedding_model\",\n display_name=\"Model Name\",\n info=\"Select the embedding model to use for this knowledge base.\",\n required=True,\n options=OPENAI_EMBEDDING_MODEL_NAMES + HUGGINGFACE_MODEL_NAMES + COHERE_MODEL_NAMES,\n options_metadata=[{\"icon\": \"OpenAI\"} for _ in OPENAI_EMBEDDING_MODEL_NAMES]\n + [{\"icon\": \"HuggingFace\"} for _ in HUGGINGFACE_MODEL_NAMES]\n + [{\"icon\": \"Cohere\"} for _ in COHERE_MODEL_NAMES],\n ),\n \"03_api_key\": SecretStrInput(\n name=\"api_key\",\n display_name=\"API Key\",\n info=\"Provider API key for embedding model\",\n required=True,\n load_from_db=False,\n ),\n },\n },\n }\n }\n )\n\n # ------ Inputs --------------------------------------------------------\n inputs = [\n DropdownInput(\n name=\"knowledge_base\",\n display_name=\"Knowledge\",\n info=\"Select the knowledge to load data from.\",\n required=True,\n options=[],\n refresh_button=True,\n dialog_inputs=asdict(NewKnowledgeBaseInput()),\n ),\n DataFrameInput(\n name=\"input_df\",\n display_name=\"Data\",\n info=\"Table with all original columns (already chunked / processed).\",\n required=True,\n ),\n TableInput(\n name=\"column_config\",\n display_name=\"Column Configuration\",\n info=\"Configure column behavior for the knowledge base.\",\n required=True,\n table_schema=[\n {\n \"name\": \"column_name\",\n \"display_name\": \"Column Name\",\n \"type\": \"str\",\n \"description\": \"Name of the column in the source DataFrame\",\n \"edit_mode\": EditMode.INLINE,\n },\n {\n \"name\": \"vectorize\",\n \"display_name\": \"Vectorize\",\n \"type\": \"boolean\",\n \"description\": \"Create embeddings for this column\",\n \"default\": False,\n \"edit_mode\": EditMode.INLINE,\n },\n {\n \"name\": \"identifier\",\n \"display_name\": \"Identifier\",\n \"type\": \"boolean\",\n \"description\": \"Use this column as unique identifier\",\n \"default\": False,\n \"edit_mode\": EditMode.INLINE,\n },\n ],\n value=[\n {\n \"column_name\": \"text\",\n \"vectorize\": True,\n \"identifier\": True,\n },\n ],\n ),\n IntInput(\n name=\"chunk_size\",\n display_name=\"Chunk Size\",\n info=\"Batch size for processing embeddings\",\n advanced=True,\n value=1000,\n ),\n SecretStrInput(\n name=\"api_key\",\n display_name=\"Embedding Provider API Key\",\n info=\"API key for the embedding provider to generate embeddings.\",\n advanced=True,\n required=False,\n ),\n BoolInput(\n name=\"allow_duplicates\",\n display_name=\"Allow Duplicates\",\n info=\"Allow duplicate rows in the knowledge base\",\n advanced=True,\n value=False,\n ),\n ]\n\n # ------ Outputs -------------------------------------------------------\n outputs = [Output(display_name=\"DataFrame\", name=\"dataframe\", method=\"build_kb_info\")]\n\n # ------ Internal helpers ---------------------------------------------\n def _get_kb_root(self) -> Path:\n \"\"\"Return the root directory for knowledge bases.\"\"\"\n return KNOWLEDGE_BASES_ROOT_PATH\n\n def _validate_column_config(self, df_source: pd.DataFrame) -> list[dict[str, Any]]:\n \"\"\"Validate column configuration using Structured Output patterns.\"\"\"\n if not self.column_config:\n msg = \"Column configuration cannot be empty\"\n raise ValueError(msg)\n\n # Convert table input to list of dicts (similar to Structured Output)\n config_list = self.column_config if isinstance(self.column_config, list) else []\n\n # Validate column names exist in DataFrame\n df_columns = set(df_source.columns)\n for config in config_list:\n col_name = config.get(\"column_name\")\n if col_name not in df_columns:\n msg = f\"Column '{col_name}' not found in DataFrame. Available columns: {sorted(df_columns)}\"\n raise ValueError(msg)\n\n return config_list\n\n def _get_embedding_provider(self, embedding_model: str) -> str:\n \"\"\"Get embedding provider by matching model name to lists.\"\"\"\n if embedding_model in OPENAI_EMBEDDING_MODEL_NAMES:\n return \"OpenAI\"\n if embedding_model in HUGGINGFACE_MODEL_NAMES:\n return \"HuggingFace\"\n if embedding_model in COHERE_MODEL_NAMES:\n return \"Cohere\"\n return \"Custom\"\n\n def _build_embeddings(self, embedding_model: str, api_key: str):\n \"\"\"Build embedding model using provider patterns.\"\"\"\n # Get provider by matching model name to lists\n provider = self._get_embedding_provider(embedding_model)\n\n # Validate provider and model\n if provider == \"OpenAI\":\n from langchain_openai import OpenAIEmbeddings\n\n if not api_key:\n msg = \"OpenAI API key is required when using OpenAI provider\"\n raise ValueError(msg)\n return OpenAIEmbeddings(\n model=embedding_model,\n api_key=api_key,\n chunk_size=self.chunk_size,\n )\n if provider == \"HuggingFace\":\n from langchain_huggingface import HuggingFaceEmbeddings\n\n return HuggingFaceEmbeddings(\n model=embedding_model,\n )\n if provider == \"Cohere\":\n from langchain_cohere import CohereEmbeddings\n\n if not api_key:\n msg = \"Cohere API key is required when using Cohere provider\"\n raise ValueError(msg)\n return CohereEmbeddings(\n model=embedding_model,\n cohere_api_key=api_key,\n )\n if provider == \"Custom\":\n # For custom embedding models, we would need additional configuration\n msg = \"Custom embedding models not yet supported\"\n raise NotImplementedError(msg)\n msg = f\"Unknown provider: {provider}\"\n raise ValueError(msg)\n\n def _build_embedding_metadata(self, embedding_model, api_key) -> dict[str, Any]:\n \"\"\"Build embedding model metadata.\"\"\"\n # Get provider by matching model name to lists\n embedding_provider = self._get_embedding_provider(embedding_model)\n\n api_key_to_save = None\n if api_key and hasattr(api_key, \"get_secret_value\"):\n api_key_to_save = api_key.get_secret_value()\n elif isinstance(api_key, str):\n api_key_to_save = api_key\n\n encrypted_api_key = None\n if api_key_to_save:\n settings_service = get_settings_service()\n try:\n encrypted_api_key = encrypt_api_key(api_key_to_save, settings_service=settings_service)\n except (TypeError, ValueError) as e:\n self.log(f\"Could not encrypt API key: {e}\")\n logger.error(f\"Could not encrypt API key: {e}\")\n\n return {\n \"embedding_provider\": embedding_provider,\n \"embedding_model\": embedding_model,\n \"api_key\": encrypted_api_key,\n \"api_key_used\": bool(api_key),\n \"chunk_size\": self.chunk_size,\n \"created_at\": datetime.now(timezone.utc).isoformat(),\n }\n\n def _save_embedding_metadata(self, kb_path: Path, embedding_model: str, api_key: str) -> None:\n \"\"\"Save embedding model metadata.\"\"\"\n embedding_metadata = self._build_embedding_metadata(embedding_model, api_key)\n metadata_path = kb_path / \"embedding_metadata.json\"\n metadata_path.write_text(json.dumps(embedding_metadata, indent=2))\n\n def _save_kb_files(\n self,\n kb_path: Path,\n config_list: list[dict[str, Any]],\n ) -> None:\n \"\"\"Save KB files using File Component storage patterns.\"\"\"\n try:\n # Create directory (following File Component patterns)\n kb_path.mkdir(parents=True, exist_ok=True)\n\n # Save column configuration\n # Only do this if the file doesn't exist already\n cfg_path = kb_path / \"schema.json\"\n if not cfg_path.exists():\n cfg_path.write_text(json.dumps(config_list, indent=2))\n\n except (OSError, TypeError, ValueError) as e:\n self.log(f\"Error saving KB files: {e}\")\n\n def _build_column_metadata(self, config_list: list[dict[str, Any]], df_source: pd.DataFrame) -> dict[str, Any]:\n \"\"\"Build detailed column metadata.\"\"\"\n metadata: dict[str, Any] = {\n \"total_columns\": len(df_source.columns),\n \"mapped_columns\": len(config_list),\n \"unmapped_columns\": len(df_source.columns) - len(config_list),\n \"columns\": [],\n \"summary\": {\"vectorized_columns\": [], \"identifier_columns\": []},\n }\n\n for config in config_list:\n col_name = config.get(\"column_name\")\n vectorize = config.get(\"vectorize\") == \"True\" or config.get(\"vectorize\") is True\n identifier = config.get(\"identifier\") == \"True\" or config.get(\"identifier\") is True\n\n # Add to columns list\n metadata[\"columns\"].append(\n {\n \"name\": col_name,\n \"vectorize\": vectorize,\n \"identifier\": identifier,\n }\n )\n\n # Update summary\n if vectorize:\n metadata[\"summary\"][\"vectorized_columns\"].append(col_name)\n if identifier:\n metadata[\"summary\"][\"identifier_columns\"].append(col_name)\n\n return metadata\n\n async def _create_vector_store(\n self, df_source: pd.DataFrame, config_list: list[dict[str, Any]], embedding_model: str, api_key: str\n ) -> None:\n \"\"\"Create vector store following Local DB component pattern.\"\"\"\n try:\n # Set up vector store directory\n vector_store_dir = await self._kb_path()\n if not vector_store_dir:\n msg = \"Knowledge base path is not set. Please create a new knowledge base first.\"\n raise ValueError(msg)\n vector_store_dir.mkdir(parents=True, exist_ok=True)\n\n # Create embeddings model\n embedding_function = self._build_embeddings(embedding_model, api_key)\n\n # Convert DataFrame to Data objects (following Local DB pattern)\n data_objects = await self._convert_df_to_data_objects(df_source, config_list)\n\n # Create vector store\n chroma = Chroma(\n persist_directory=str(vector_store_dir),\n embedding_function=embedding_function,\n collection_name=self.knowledge_base,\n )\n\n # Convert Data objects to LangChain Documents\n documents = []\n for data_obj in data_objects:\n doc = data_obj.to_lc_document()\n documents.append(doc)\n\n # Add documents to vector store\n if documents:\n chroma.add_documents(documents)\n self.log(f\"Added {len(documents)} documents to vector store '{self.knowledge_base}'\")\n\n except (OSError, ValueError, RuntimeError) as e:\n self.log(f\"Error creating vector store: {e}\")\n\n async def _convert_df_to_data_objects(\n self, df_source: pd.DataFrame, config_list: list[dict[str, Any]]\n ) -> list[Data]:\n \"\"\"Convert DataFrame to Data objects for vector store.\"\"\"\n data_objects: list[Data] = []\n\n # Set up vector store directory\n kb_path = await self._kb_path()\n\n # If we don't allow duplicates, we need to get the existing hashes\n chroma = Chroma(\n persist_directory=str(kb_path),\n collection_name=self.knowledge_base,\n )\n\n # Get all documents and their metadata\n all_docs = chroma.get()\n\n # Extract all _id values from metadata\n id_list = [metadata.get(\"_id\") for metadata in all_docs[\"metadatas\"] if metadata.get(\"_id\")]\n\n # Get column roles\n content_cols = []\n identifier_cols = []\n\n for config in config_list:\n col_name = config.get(\"column_name\")\n vectorize = config.get(\"vectorize\") == \"True\" or config.get(\"vectorize\") is True\n identifier = config.get(\"identifier\") == \"True\" or config.get(\"identifier\") is True\n\n if vectorize:\n content_cols.append(col_name)\n elif identifier:\n identifier_cols.append(col_name)\n\n # Convert each row to a Data object\n for _, row in df_source.iterrows():\n # Build content text from identifier columns using list comprehension\n identifier_parts = [str(row[col]) for col in content_cols if col in row and pd.notna(row[col])]\n\n # Join all parts into a single string\n page_content = \" \".join(identifier_parts)\n\n # Build metadata from NON-vectorized columns only (simple key-value pairs)\n data_dict = {\n \"text\": page_content, # Main content for vectorization\n }\n\n # Add identifier columns if they exist\n if identifier_cols:\n identifier_parts = [str(row[col]) for col in identifier_cols if col in row and pd.notna(row[col])]\n page_content = \" \".join(identifier_parts)\n\n # Add metadata columns as simple key-value pairs\n for col in df_source.columns:\n if col not in content_cols and col in row and pd.notna(row[col]):\n # Convert to simple types for Chroma metadata\n value = row[col]\n data_dict[col] = str(value) # Convert complex types to string\n\n # Hash the page_content for unique ID\n page_content_hash = hashlib.sha256(page_content.encode()).hexdigest()\n data_dict[\"_id\"] = page_content_hash\n\n # If duplicates are disallowed, and hash exists, prevent adding this row\n if not self.allow_duplicates and page_content_hash in id_list:\n self.log(f\"Skipping duplicate row with hash {page_content_hash}\")\n continue\n\n # Create Data object - everything except \"text\" becomes metadata\n data_obj = Data(data=data_dict)\n data_objects.append(data_obj)\n\n return data_objects\n\n def is_valid_collection_name(self, name, min_length: int = 3, max_length: int = 63) -> bool:\n \"\"\"Validates collection name against conditions 1-3.\n\n 1. Contains 3-63 characters\n 2. Starts and ends with alphanumeric character\n 3. Contains only alphanumeric characters, underscores, or hyphens.\n\n Args:\n name (str): Collection name to validate\n min_length (int): Minimum length of the name\n max_length (int): Maximum length of the name\n\n Returns:\n bool: True if valid, False otherwise\n \"\"\"\n # Check length (condition 1)\n if not (min_length <= len(name) <= max_length):\n return False\n\n # Check start/end with alphanumeric (condition 2)\n if not (name[0].isalnum() and name[-1].isalnum()):\n return False\n\n # Check allowed characters (condition 3)\n return re.match(r\"^[a-zA-Z0-9_-]+$\", name) is not None\n\n async def _kb_path(self) -> Path | None:\n # Check if we already have the path cached\n cached_path = getattr(self, \"_cached_kb_path\", None)\n if cached_path is not None:\n return cached_path\n\n # If not cached, compute it\n async with session_scope() as db:\n if not self.user_id:\n msg = \"User ID is required for fetching knowledge base path.\"\n raise ValueError(msg)\n current_user = await get_user_by_id(db, self.user_id)\n if not current_user:\n msg = f\"User with ID {self.user_id} not found.\"\n raise ValueError(msg)\n kb_user = current_user.username\n\n kb_root = self._get_kb_root()\n\n # Cache the result\n self._cached_kb_path = kb_root / kb_user / self.knowledge_base\n\n return self._cached_kb_path\n\n # ---------------------------------------------------------------------\n # OUTPUT METHODS\n # ---------------------------------------------------------------------\n async def build_kb_info(self) -> Data:\n \"\"\"Main ingestion routine → returns a dict with KB metadata.\"\"\"\n try:\n # Get source DataFrame\n df_source: pd.DataFrame = self.input_df\n\n # Validate column configuration (using Structured Output patterns)\n config_list = self._validate_column_config(df_source)\n column_metadata = self._build_column_metadata(config_list, df_source)\n\n # Read the embedding info from the knowledge base folder\n kb_path = await self._kb_path()\n if not kb_path:\n msg = \"Knowledge base path is not set. Please create a new knowledge base first.\"\n raise ValueError(msg)\n metadata_path = kb_path / \"embedding_metadata.json\"\n\n # If the API key is not provided, try to read it from the metadata file\n if metadata_path.exists():\n settings_service = get_settings_service()\n metadata = json.loads(metadata_path.read_text())\n embedding_model = metadata.get(\"embedding_model\")\n try:\n api_key = decrypt_api_key(metadata[\"api_key\"], settings_service)\n except (InvalidToken, TypeError, ValueError) as e:\n logger.error(f\"Could not decrypt API key. Please provide it manually. Error: {e}\")\n\n # Check if a custom API key was provided, update metadata if so\n if self.api_key:\n api_key = self.api_key\n self._save_embedding_metadata(\n kb_path=kb_path,\n embedding_model=embedding_model,\n api_key=api_key,\n )\n\n # Create vector store following Local DB component pattern\n await self._create_vector_store(df_source, config_list, embedding_model=embedding_model, api_key=api_key)\n\n # Save KB files (using File Component storage patterns)\n self._save_kb_files(kb_path, config_list)\n\n # Build metadata response\n meta: dict[str, Any] = {\n \"kb_id\": str(uuid.uuid4()),\n \"kb_name\": self.knowledge_base,\n \"rows\": len(df_source),\n \"column_metadata\": column_metadata,\n \"path\": str(kb_path),\n \"config_columns\": len(config_list),\n \"timestamp\": datetime.now(tz=timezone.utc).isoformat(),\n }\n\n # Set status message\n self.status = f\"✅ KB **{self.knowledge_base}** saved · {len(df_source)} chunks.\"\n\n return Data(data=meta)\n\n except (OSError, ValueError, RuntimeError, KeyError) as e:\n self.log(f\"Error in KB ingestion: {e}\")\n self.status = f\"❌ KB ingestion failed: {e}\"\n return Data(data={\"error\": str(e), \"kb_name\": self.knowledge_base})\n\n async def _get_api_key_variable(self, field_value: dict[str, Any]):\n async with session_scope() as db:\n if not self.user_id:\n msg = \"User ID is required for fetching global variables.\"\n raise ValueError(msg)\n current_user = await get_user_by_id(db, self.user_id)\n if not current_user:\n msg = f\"User with ID {self.user_id} not found.\"\n raise ValueError(msg)\n variable_service = get_variable_service()\n\n # Process the api_key field variable\n return await variable_service.get_variable(\n user_id=current_user.id,\n name=field_value[\"03_api_key\"],\n field=\"\",\n session=db,\n )\n\n async def update_build_config(\n self,\n build_config: dotdict,\n field_value: Any,\n field_name: str | None = None,\n ) -> dotdict:\n \"\"\"Update build configuration based on provider selection.\"\"\"\n # Create a new knowledge base\n if field_name == \"knowledge_base\":\n async with session_scope() as db:\n if not self.user_id:\n msg = \"User ID is required for fetching knowledge base list.\"\n raise ValueError(msg)\n current_user = await get_user_by_id(db, self.user_id)\n if not current_user:\n msg = f\"User with ID {self.user_id} not found.\"\n raise ValueError(msg)\n kb_user = current_user.username\n if isinstance(field_value, dict) and \"01_new_kb_name\" in field_value:\n # Validate the knowledge base name - Make sure it follows these rules:\n if not self.is_valid_collection_name(field_value[\"01_new_kb_name\"]):\n msg = f\"Invalid knowledge base name: {field_value['01_new_kb_name']}\"\n raise ValueError(msg)\n\n api_key = field_value.get(\"03_api_key\", None)\n with contextlib.suppress(Exception):\n # If the API key is a variable, resolve it\n api_key = await self._get_api_key_variable(field_value)\n\n # Make sure api_key is a string\n if not isinstance(api_key, str):\n msg = \"API key must be a string.\"\n raise ValueError(msg)\n\n # We need to test the API Key one time against the embedding model\n embed_model = self._build_embeddings(embedding_model=field_value[\"02_embedding_model\"], api_key=api_key)\n\n # Try to generate a dummy embedding to validate the API key without blocking the event loop\n try:\n await asyncio.wait_for(\n asyncio.to_thread(embed_model.embed_query, \"test\"),\n timeout=10,\n )\n except TimeoutError as e:\n msg = \"Embedding validation timed out. Please verify network connectivity and key.\"\n raise ValueError(msg) from e\n except Exception as e:\n msg = f\"Embedding validation failed: {e!s}\"\n raise ValueError(msg) from e\n\n # Create the new knowledge base directory\n kb_path = KNOWLEDGE_BASES_ROOT_PATH / kb_user / field_value[\"01_new_kb_name\"]\n kb_path.mkdir(parents=True, exist_ok=True)\n\n # Save the embedding metadata\n build_config[\"knowledge_base\"][\"value\"] = field_value[\"01_new_kb_name\"]\n self._save_embedding_metadata(\n kb_path=kb_path,\n embedding_model=field_value[\"02_embedding_model\"],\n api_key=api_key,\n )\n\n # Update the knowledge base options dynamically\n build_config[\"knowledge_base\"][\"options\"] = await get_knowledge_bases(\n KNOWLEDGE_BASES_ROOT_PATH,\n user_id=self.user_id,\n )\n\n # If the selected knowledge base is not available, reset it\n if build_config[\"knowledge_base\"][\"value\"] not in build_config[\"knowledge_base\"][\"options\"]:\n build_config[\"knowledge_base\"][\"value\"] = None\n\n return build_config\n" }, "column_config": { "_input_type": "TableInput", diff --git a/src/backend/base/langflow/initial_setup/starter_projects/Knowledge Retrieval.json b/src/backend/base/langflow/initial_setup/starter_projects/Knowledge Retrieval.json index ccd8793bc..c83495376 100644 --- a/src/backend/base/langflow/initial_setup/starter_projects/Knowledge Retrieval.json +++ b/src/backend/base/langflow/initial_setup/starter_projects/Knowledge Retrieval.json @@ -532,7 +532,7 @@ "last_updated": "2025-08-14T17:19:22.182Z", "legacy": false, "metadata": { - "code_hash": "ee2b66958f09", + "code_hash": "6fcf86be1aca", "module": "langflow.components.data.kb_retrieval.KBRetrievalComponent" }, "minimized": false, @@ -589,7 +589,7 @@ "show": true, "title_case": false, "type": "code", - "value": "import json\nfrom pathlib import Path\nfrom typing import Any\n\nfrom cryptography.fernet import InvalidToken\nfrom langchain_chroma import Chroma\nfrom loguru import logger\n\nfrom langflow.custom import Component\nfrom langflow.io import BoolInput, DropdownInput, IntInput, MessageTextInput, Output, SecretStrInput\nfrom langflow.schema.data import Data\nfrom langflow.schema.dataframe import DataFrame\nfrom langflow.services.auth.utils import decrypt_api_key\nfrom langflow.services.deps import get_settings_service\n\nsettings = get_settings_service().settings\nknowledge_directory = settings.knowledge_bases_dir\nif not knowledge_directory:\n msg = \"Knowledge bases directory is not set in the settings.\"\n raise ValueError(msg)\nKNOWLEDGE_BASES_ROOT_PATH = Path(knowledge_directory).expanduser()\n\n\nclass KBRetrievalComponent(Component):\n display_name = \"Knowledge Retrieval\"\n description = \"Search and retrieve data from knowledge.\"\n icon = \"database\"\n name = \"KBRetrieval\"\n\n inputs = [\n DropdownInput(\n name=\"knowledge_base\",\n display_name=\"Knowledge\",\n info=\"Select the knowledge to load data from.\",\n required=True,\n options=[\n str(d.name) for d in KNOWLEDGE_BASES_ROOT_PATH.iterdir() if not d.name.startswith(\".\") and d.is_dir()\n ]\n if KNOWLEDGE_BASES_ROOT_PATH.exists()\n else [],\n refresh_button=True,\n real_time_refresh=True,\n ),\n SecretStrInput(\n name=\"api_key\",\n display_name=\"Embedding Provider API Key\",\n info=\"API key for the embedding provider to generate embeddings.\",\n advanced=True,\n required=False,\n ),\n MessageTextInput(\n name=\"search_query\",\n display_name=\"Search Query\",\n info=\"Optional search query to filter knowledge base data.\",\n ),\n IntInput(\n name=\"top_k\",\n display_name=\"Top K Results\",\n info=\"Number of top results to return from the knowledge base.\",\n value=5,\n advanced=True,\n required=False,\n ),\n BoolInput(\n name=\"include_metadata\",\n display_name=\"Include Metadata\",\n info=\"Whether to include all metadata and embeddings in the output. If false, only content is returned.\",\n value=True,\n advanced=False,\n ),\n ]\n\n outputs = [\n Output(\n name=\"chroma_kb_data\",\n display_name=\"Results\",\n method=\"get_chroma_kb_data\",\n info=\"Returns the data from the selected knowledge base.\",\n ),\n ]\n\n def _get_knowledge_bases(self) -> list[str]:\n \"\"\"Retrieve a list of available knowledge bases.\n\n Returns:\n A list of knowledge base names.\n \"\"\"\n if not KNOWLEDGE_BASES_ROOT_PATH.exists():\n return []\n\n return [str(d.name) for d in KNOWLEDGE_BASES_ROOT_PATH.iterdir() if not d.name.startswith(\".\") and d.is_dir()]\n\n def update_build_config(self, build_config, field_value, field_name=None): # noqa: ARG002\n if field_name == \"knowledge_base\":\n # Update the knowledge base options dynamically\n build_config[\"knowledge_base\"][\"options\"] = self._get_knowledge_bases()\n\n # If the selected knowledge base is not available, reset it\n if build_config[\"knowledge_base\"][\"value\"] not in build_config[\"knowledge_base\"][\"options\"]:\n build_config[\"knowledge_base\"][\"value\"] = None\n\n return build_config\n\n def _get_kb_metadata(self, kb_path: Path) -> dict:\n \"\"\"Load and process knowledge base metadata.\"\"\"\n metadata: dict[str, Any] = {}\n metadata_file = kb_path / \"embedding_metadata.json\"\n if not metadata_file.exists():\n logger.warning(f\"Embedding metadata file not found at {metadata_file}\")\n return metadata\n\n try:\n with metadata_file.open(\"r\", encoding=\"utf-8\") as f:\n metadata = json.load(f)\n except json.JSONDecodeError:\n logger.error(f\"Error decoding JSON from {metadata_file}\")\n return {}\n\n # Decrypt API key if it exists\n if \"api_key\" in metadata and metadata.get(\"api_key\"):\n settings_service = get_settings_service()\n try:\n decrypted_key = decrypt_api_key(metadata[\"api_key\"], settings_service)\n metadata[\"api_key\"] = decrypted_key\n except (InvalidToken, TypeError, ValueError) as e:\n logger.error(f\"Could not decrypt API key. Please provide it manually. Error: {e}\")\n metadata[\"api_key\"] = None\n return metadata\n\n def _build_embeddings(self, metadata: dict):\n \"\"\"Build embedding model from metadata.\"\"\"\n provider = metadata.get(\"embedding_provider\")\n model = metadata.get(\"embedding_model\")\n api_key = metadata.get(\"api_key\")\n chunk_size = metadata.get(\"chunk_size\")\n\n # If user provided a key in the input, it overrides the stored one.\n if self.api_key and self.api_key.get_secret_value():\n api_key = self.api_key.get_secret_value()\n\n # Handle various providers\n if provider == \"OpenAI\":\n from langchain_openai import OpenAIEmbeddings\n\n if not api_key:\n msg = \"OpenAI API key is required. Provide it in the component's advanced settings.\"\n raise ValueError(msg)\n return OpenAIEmbeddings(\n model=model,\n api_key=api_key,\n chunk_size=chunk_size,\n )\n if provider == \"HuggingFace\":\n from langchain_huggingface import HuggingFaceEmbeddings\n\n return HuggingFaceEmbeddings(\n model=model,\n )\n if provider == \"Cohere\":\n from langchain_cohere import CohereEmbeddings\n\n if not api_key:\n msg = \"Cohere API key is required when using Cohere provider\"\n raise ValueError(msg)\n return CohereEmbeddings(\n model=model,\n cohere_api_key=api_key,\n )\n if provider == \"Custom\":\n # For custom embedding models, we would need additional configuration\n msg = \"Custom embedding models not yet supported\"\n raise NotImplementedError(msg)\n # Add other providers here if they become supported in ingest\n msg = f\"Embedding provider '{provider}' is not supported for retrieval.\"\n raise NotImplementedError(msg)\n\n def get_chroma_kb_data(self) -> DataFrame:\n \"\"\"Retrieve data from the selected knowledge base by reading the Chroma collection.\n\n Returns:\n A DataFrame containing the data rows from the knowledge base.\n \"\"\"\n kb_path = KNOWLEDGE_BASES_ROOT_PATH / self.knowledge_base\n\n metadata = self._get_kb_metadata(kb_path)\n if not metadata:\n msg = f\"Metadata not found for knowledge base: {self.knowledge_base}. Ensure it has been indexed.\"\n raise ValueError(msg)\n\n # Build the embedder for the knowledge base\n embedding_function = self._build_embeddings(metadata)\n\n # Load vector store\n chroma = Chroma(\n persist_directory=str(kb_path),\n embedding_function=embedding_function,\n collection_name=self.knowledge_base,\n )\n\n # If a search query is provided, perform a similarity search\n if self.search_query:\n # Use the search query to perform a similarity search\n logger.info(f\"Performing similarity search with query: {self.search_query}\")\n results = chroma.similarity_search_with_score(\n query=self.search_query or \"\",\n k=self.top_k,\n )\n else:\n results = chroma.similarity_search(\n query=self.search_query or \"\",\n k=self.top_k,\n )\n\n # For each result, make it a tuple to match the expected output format\n results = [(doc, 0) for doc in results] # Assign a dummy score of 0\n\n # If metadata is enabled, get embeddings for the results\n id_to_embedding = {}\n if self.include_metadata and results:\n doc_ids = [doc[0].metadata.get(\"_id\") for doc in results if doc[0].metadata.get(\"_id\")]\n\n # Only proceed if we have valid document IDs\n if doc_ids:\n # Access underlying client to get embeddings\n collection = chroma._client.get_collection(name=self.knowledge_base)\n embeddings_result = collection.get(where={\"_id\": {\"$in\": doc_ids}}, include=[\"embeddings\", \"metadatas\"])\n\n # Create a mapping from document ID to embedding\n for i, metadata in enumerate(embeddings_result.get(\"metadatas\", [])):\n if metadata and \"_id\" in metadata:\n id_to_embedding[metadata[\"_id\"]] = embeddings_result[\"embeddings\"][i]\n\n # Build output data based on include_metadata setting\n data_list = []\n for doc in results:\n if self.include_metadata:\n # Include all metadata, embeddings, and content\n kwargs = {\n \"content\": doc[0].page_content,\n **doc[0].metadata,\n }\n if self.search_query:\n kwargs[\"_score\"] = -1 * doc[1]\n kwargs[\"_embeddings\"] = id_to_embedding.get(doc[0].metadata.get(\"_id\"))\n else:\n # Only include content\n kwargs = {\n \"content\": doc[0].page_content,\n }\n\n data_list.append(Data(**kwargs))\n\n # Return the DataFrame containing the data\n return DataFrame(data=data_list)\n" + "value": "import json\nfrom pathlib import Path\nfrom typing import Any\n\nfrom cryptography.fernet import InvalidToken\nfrom langchain_chroma import Chroma\nfrom loguru import logger\nfrom pydantic import SecretStr\n\nfrom langflow.base.data.kb_utils import get_knowledge_bases\nfrom langflow.custom import Component\nfrom langflow.io import BoolInput, DropdownInput, IntInput, MessageTextInput, Output, SecretStrInput\nfrom langflow.schema.data import Data\nfrom langflow.schema.dataframe import DataFrame\nfrom langflow.services.auth.utils import decrypt_api_key\nfrom langflow.services.database.models.user.crud import get_user_by_id\nfrom langflow.services.deps import get_settings_service, session_scope\n\nsettings = get_settings_service().settings\nknowledge_directory = settings.knowledge_bases_dir\nif not knowledge_directory:\n msg = \"Knowledge bases directory is not set in the settings.\"\n raise ValueError(msg)\nKNOWLEDGE_BASES_ROOT_PATH = Path(knowledge_directory).expanduser()\n\n\nclass KBRetrievalComponent(Component):\n display_name = \"Knowledge Retrieval\"\n description = \"Search and retrieve data from knowledge.\"\n icon = \"database\"\n name = \"KBRetrieval\"\n\n inputs = [\n DropdownInput(\n name=\"knowledge_base\",\n display_name=\"Knowledge\",\n info=\"Select the knowledge to load data from.\",\n required=True,\n options=[],\n refresh_button=True,\n real_time_refresh=True,\n ),\n SecretStrInput(\n name=\"api_key\",\n display_name=\"Embedding Provider API Key\",\n info=\"API key for the embedding provider to generate embeddings.\",\n advanced=True,\n required=False,\n ),\n MessageTextInput(\n name=\"search_query\",\n display_name=\"Search Query\",\n info=\"Optional search query to filter knowledge base data.\",\n ),\n IntInput(\n name=\"top_k\",\n display_name=\"Top K Results\",\n info=\"Number of top results to return from the knowledge base.\",\n value=5,\n advanced=True,\n required=False,\n ),\n BoolInput(\n name=\"include_metadata\",\n display_name=\"Include Metadata\",\n info=\"Whether to include all metadata and embeddings in the output. If false, only content is returned.\",\n value=True,\n advanced=False,\n ),\n ]\n\n outputs = [\n Output(\n name=\"chroma_kb_data\",\n display_name=\"Results\",\n method=\"get_chroma_kb_data\",\n info=\"Returns the data from the selected knowledge base.\",\n ),\n ]\n\n async def update_build_config(self, build_config, field_value, field_name=None): # noqa: ARG002\n if field_name == \"knowledge_base\":\n # Update the knowledge base options dynamically\n build_config[\"knowledge_base\"][\"options\"] = await get_knowledge_bases(\n KNOWLEDGE_BASES_ROOT_PATH,\n user_id=self.user_id, # Use the user_id from the component context\n )\n\n # If the selected knowledge base is not available, reset it\n if build_config[\"knowledge_base\"][\"value\"] not in build_config[\"knowledge_base\"][\"options\"]:\n build_config[\"knowledge_base\"][\"value\"] = None\n\n return build_config\n\n def _get_kb_metadata(self, kb_path: Path) -> dict:\n \"\"\"Load and process knowledge base metadata.\"\"\"\n metadata: dict[str, Any] = {}\n metadata_file = kb_path / \"embedding_metadata.json\"\n if not metadata_file.exists():\n logger.warning(f\"Embedding metadata file not found at {metadata_file}\")\n return metadata\n\n try:\n with metadata_file.open(\"r\", encoding=\"utf-8\") as f:\n metadata = json.load(f)\n except json.JSONDecodeError:\n logger.error(f\"Error decoding JSON from {metadata_file}\")\n return {}\n\n # Decrypt API key if it exists\n if \"api_key\" in metadata and metadata.get(\"api_key\"):\n settings_service = get_settings_service()\n try:\n decrypted_key = decrypt_api_key(metadata[\"api_key\"], settings_service)\n metadata[\"api_key\"] = decrypted_key\n except (InvalidToken, TypeError, ValueError) as e:\n logger.error(f\"Could not decrypt API key. Please provide it manually. Error: {e}\")\n metadata[\"api_key\"] = None\n return metadata\n\n def _build_embeddings(self, metadata: dict):\n \"\"\"Build embedding model from metadata.\"\"\"\n runtime_api_key = self.api_key.get_secret_value() if isinstance(self.api_key, SecretStr) else self.api_key\n provider = metadata.get(\"embedding_provider\")\n model = metadata.get(\"embedding_model\")\n api_key = runtime_api_key or metadata.get(\"api_key\")\n chunk_size = metadata.get(\"chunk_size\")\n\n # Handle various providers\n if provider == \"OpenAI\":\n from langchain_openai import OpenAIEmbeddings\n\n if not api_key:\n msg = \"OpenAI API key is required. Provide it in the component's advanced settings.\"\n raise ValueError(msg)\n return OpenAIEmbeddings(\n model=model,\n api_key=api_key,\n chunk_size=chunk_size,\n )\n if provider == \"HuggingFace\":\n from langchain_huggingface import HuggingFaceEmbeddings\n\n return HuggingFaceEmbeddings(\n model=model,\n )\n if provider == \"Cohere\":\n from langchain_cohere import CohereEmbeddings\n\n if not api_key:\n msg = \"Cohere API key is required when using Cohere provider\"\n raise ValueError(msg)\n return CohereEmbeddings(\n model=model,\n cohere_api_key=api_key,\n )\n if provider == \"Custom\":\n # For custom embedding models, we would need additional configuration\n msg = \"Custom embedding models not yet supported\"\n raise NotImplementedError(msg)\n # Add other providers here if they become supported in ingest\n msg = f\"Embedding provider '{provider}' is not supported for retrieval.\"\n raise NotImplementedError(msg)\n\n async def get_chroma_kb_data(self) -> DataFrame:\n \"\"\"Retrieve data from the selected knowledge base by reading the Chroma collection.\n\n Returns:\n A DataFrame containing the data rows from the knowledge base.\n \"\"\"\n # Get the current user\n async with session_scope() as db:\n if not self.user_id:\n msg = \"User ID is required for fetching Knowledge Base data.\"\n raise ValueError(msg)\n current_user = await get_user_by_id(db, self.user_id)\n if not current_user:\n msg = f\"User with ID {self.user_id} not found.\"\n raise ValueError(msg)\n kb_user = current_user.username\n kb_path = KNOWLEDGE_BASES_ROOT_PATH / kb_user / self.knowledge_base\n\n metadata = self._get_kb_metadata(kb_path)\n if not metadata:\n msg = f\"Metadata not found for knowledge base: {self.knowledge_base}. Ensure it has been indexed.\"\n raise ValueError(msg)\n\n # Build the embedder for the knowledge base\n embedding_function = self._build_embeddings(metadata)\n\n # Load vector store\n chroma = Chroma(\n persist_directory=str(kb_path),\n embedding_function=embedding_function,\n collection_name=self.knowledge_base,\n )\n\n # If a search query is provided, perform a similarity search\n if self.search_query:\n # Use the search query to perform a similarity search\n logger.info(f\"Performing similarity search with query: {self.search_query}\")\n results = chroma.similarity_search_with_score(\n query=self.search_query or \"\",\n k=self.top_k,\n )\n else:\n results = chroma.similarity_search(\n query=self.search_query or \"\",\n k=self.top_k,\n )\n\n # For each result, make it a tuple to match the expected output format\n results = [(doc, 0) for doc in results] # Assign a dummy score of 0\n\n # If metadata is enabled, get embeddings for the results\n id_to_embedding = {}\n if self.include_metadata and results:\n doc_ids = [doc[0].metadata.get(\"_id\") for doc in results if doc[0].metadata.get(\"_id\")]\n\n # Only proceed if we have valid document IDs\n if doc_ids:\n # Access underlying client to get embeddings\n collection = chroma._client.get_collection(name=self.knowledge_base)\n embeddings_result = collection.get(where={\"_id\": {\"$in\": doc_ids}}, include=[\"embeddings\", \"metadatas\"])\n\n # Create a mapping from document ID to embedding\n for i, metadata in enumerate(embeddings_result.get(\"metadatas\", [])):\n if metadata and \"_id\" in metadata:\n id_to_embedding[metadata[\"_id\"]] = embeddings_result[\"embeddings\"][i]\n\n # Build output data based on include_metadata setting\n data_list = []\n for doc in results:\n if self.include_metadata:\n # Include all metadata, embeddings, and content\n kwargs = {\n \"content\": doc[0].page_content,\n **doc[0].metadata,\n }\n if self.search_query:\n kwargs[\"_score\"] = -1 * doc[1]\n kwargs[\"_embeddings\"] = id_to_embedding.get(doc[0].metadata.get(\"_id\"))\n else:\n # Only include content\n kwargs = {\n \"content\": doc[0].page_content,\n }\n\n data_list.append(Data(**kwargs))\n\n # Return the DataFrame containing the data\n return DataFrame(data=data_list)\n" }, "include_metadata": { "_input_type": "BoolInput", diff --git a/src/backend/base/langflow/initial_setup/starter_projects/News Aggregator.json b/src/backend/base/langflow/initial_setup/starter_projects/News Aggregator.json index 764d98adb..a2674c80e 100644 --- a/src/backend/base/langflow/initial_setup/starter_projects/News Aggregator.json +++ b/src/backend/base/langflow/initial_setup/starter_projects/News Aggregator.json @@ -1208,7 +1208,7 @@ "legacy": false, "lf_version": "1.4.3", "metadata": { - "code_hash": "d9af728ce02a", + "code_hash": "1bcc6faaaa62", "module": "langflow.components.processing.save_file.SaveToFileComponent" }, "minimized": false, @@ -1232,23 +1232,6 @@ "pinned": false, "template": { "_type": "Component", - "api_key": { - "_input_type": "SecretStrInput", - "advanced": true, - "display_name": "Langflow API Key", - "dynamic": false, - "info": "Langflow API key for authentication when saving the file.", - "input_types": [], - "load_from_db": true, - "name": "api_key", - "password": true, - "placeholder": "", - "required": false, - "show": true, - "title_case": false, - "type": "str", - "value": "" - }, "code": { "advanced": true, "dynamic": true, @@ -1265,7 +1248,7 @@ "show": true, "title_case": false, "type": "code", - "value": "import json\nfrom collections.abc import AsyncIterator, Iterator\nfrom pathlib import Path\nfrom typing import TYPE_CHECKING\n\nimport orjson\nimport pandas as pd\nfrom fastapi import UploadFile\nfrom fastapi.encoders import jsonable_encoder\n\nfrom langflow.api.v2.files import upload_user_file\nfrom langflow.custom import Component\nfrom langflow.io import DropdownInput, HandleInput, SecretStrInput, StrInput\nfrom langflow.schema import Data, DataFrame, Message\nfrom langflow.services.auth.utils import create_user_longterm_token, get_current_user\nfrom langflow.services.database.models.user.crud import get_user_by_id\nfrom langflow.services.deps import get_session, get_settings_service, get_storage_service\nfrom langflow.template.field.base import Output\n\nif TYPE_CHECKING:\n from langflow.services.database.models.user.model import User\n\n\nclass SaveToFileComponent(Component):\n display_name = \"Save File\"\n description = \"Save data to a local file in the selected format.\"\n documentation: str = \"https://docs.langflow.org/components-processing#save-file\"\n icon = \"save\"\n name = \"SaveToFile\"\n\n # File format options for different types\n DATA_FORMAT_CHOICES = [\"csv\", \"excel\", \"json\", \"markdown\"]\n MESSAGE_FORMAT_CHOICES = [\"txt\", \"json\", \"markdown\"]\n\n inputs = [\n HandleInput(\n name=\"input\",\n display_name=\"Input\",\n info=\"The input to save.\",\n dynamic=True,\n input_types=[\"Data\", \"DataFrame\", \"Message\"],\n required=True,\n ),\n StrInput(\n name=\"file_name\",\n display_name=\"File Name\",\n info=\"Name file will be saved as (without extension).\",\n required=True,\n ),\n DropdownInput(\n name=\"file_format\",\n display_name=\"File Format\",\n options=list(dict.fromkeys(DATA_FORMAT_CHOICES + MESSAGE_FORMAT_CHOICES)),\n info=\"Select the file format to save the input. If not provided, the default format will be used.\",\n value=\"\",\n advanced=True,\n ),\n SecretStrInput(\n name=\"api_key\",\n display_name=\"Langflow API Key\",\n info=\"Langflow API key for authentication when saving the file.\",\n required=False,\n advanced=True,\n ),\n ]\n\n outputs = [Output(display_name=\"File Path\", name=\"message\", method=\"save_to_file\")]\n\n async def save_to_file(self) -> Message:\n \"\"\"Save the input to a file and upload it, returning a confirmation message.\"\"\"\n # Validate inputs\n if not self.file_name:\n msg = \"File name must be provided.\"\n raise ValueError(msg)\n if not self._get_input_type():\n msg = \"Input type is not set.\"\n raise ValueError(msg)\n\n # Validate file format based on input type\n file_format = self.file_format or self._get_default_format()\n allowed_formats = (\n self.MESSAGE_FORMAT_CHOICES if self._get_input_type() == \"Message\" else self.DATA_FORMAT_CHOICES\n )\n if file_format not in allowed_formats:\n msg = f\"Invalid file format '{file_format}' for {self._get_input_type()}. Allowed: {allowed_formats}\"\n raise ValueError(msg)\n\n # Prepare file path\n file_path = Path(self.file_name).expanduser()\n if not file_path.parent.exists():\n file_path.parent.mkdir(parents=True, exist_ok=True)\n file_path = self._adjust_file_path_with_format(file_path, file_format)\n\n # Save the input to file based on type\n if self._get_input_type() == \"DataFrame\":\n confirmation = self._save_dataframe(self.input, file_path, file_format)\n elif self._get_input_type() == \"Data\":\n confirmation = self._save_data(self.input, file_path, file_format)\n elif self._get_input_type() == \"Message\":\n confirmation = await self._save_message(self.input, file_path, file_format)\n else:\n msg = f\"Unsupported input type: {self._get_input_type()}\"\n raise ValueError(msg)\n\n # Upload the saved file\n await self._upload_file(file_path)\n\n # Return the final file path and confirmation message\n final_path = Path.cwd() / file_path if not file_path.is_absolute() else file_path\n\n return Message(text=f\"{confirmation} at {final_path}\")\n\n def _get_input_type(self) -> str:\n \"\"\"Determine the input type based on the provided input.\"\"\"\n # Use exact type checking (type() is) instead of isinstance() to avoid inheritance issues.\n # Since Message inherits from Data, isinstance(message, Data) would return True for Message objects,\n # causing Message inputs to be incorrectly identified as Data type.\n if type(self.input) is DataFrame:\n return \"DataFrame\"\n if type(self.input) is Message:\n return \"Message\"\n if type(self.input) is Data:\n return \"Data\"\n msg = f\"Unsupported input type: {type(self.input)}\"\n raise ValueError(msg)\n\n def _get_default_format(self) -> str:\n \"\"\"Return the default file format based on input type.\"\"\"\n if self._get_input_type() == \"DataFrame\":\n return \"csv\"\n if self._get_input_type() == \"Data\":\n return \"json\"\n if self._get_input_type() == \"Message\":\n return \"json\"\n return \"json\" # Fallback\n\n def _adjust_file_path_with_format(self, path: Path, fmt: str) -> Path:\n \"\"\"Adjust the file path to include the correct extension.\"\"\"\n file_extension = path.suffix.lower().lstrip(\".\")\n if fmt == \"excel\":\n return Path(f\"{path}.xlsx\").expanduser() if file_extension not in [\"xlsx\", \"xls\"] else path\n return Path(f\"{path}.{fmt}\").expanduser() if file_extension != fmt else path\n\n async def _upload_file(self, file_path: Path) -> None:\n \"\"\"Upload the saved file using the upload_user_file service.\"\"\"\n if not file_path.exists():\n msg = f\"File not found: {file_path}\"\n raise FileNotFoundError(msg)\n\n with file_path.open(\"rb\") as f:\n async for db in get_session():\n # TODO: In 1.6, this may need to be removed or adjusted\n # Try to get the super user token, if possible\n current_user: User | None = None\n if self.api_key:\n current_user = await get_current_user(\n token=\"\",\n query_param=self.api_key,\n header_param=\"\",\n db=db,\n )\n else:\n user_id, _ = await create_user_longterm_token(db)\n current_user = await get_user_by_id(db, user_id)\n\n # Fail if the user is not found\n if not current_user:\n msg = \"User not found. Please provide a valid API key or ensure the user exists.\"\n raise ValueError(msg)\n\n await upload_user_file(\n file=UploadFile(filename=file_path.name, file=f, size=file_path.stat().st_size),\n session=db,\n current_user=current_user,\n storage_service=get_storage_service(),\n settings_service=get_settings_service(),\n )\n\n def _save_dataframe(self, dataframe: DataFrame, path: Path, fmt: str) -> str:\n \"\"\"Save a DataFrame to the specified file format.\"\"\"\n if fmt == \"csv\":\n dataframe.to_csv(path, index=False)\n elif fmt == \"excel\":\n dataframe.to_excel(path, index=False, engine=\"openpyxl\")\n elif fmt == \"json\":\n dataframe.to_json(path, orient=\"records\", indent=2)\n elif fmt == \"markdown\":\n path.write_text(dataframe.to_markdown(index=False), encoding=\"utf-8\")\n else:\n msg = f\"Unsupported DataFrame format: {fmt}\"\n raise ValueError(msg)\n return f\"DataFrame saved successfully as '{path}'\"\n\n def _save_data(self, data: Data, path: Path, fmt: str) -> str:\n \"\"\"Save a Data object to the specified file format.\"\"\"\n if fmt == \"csv\":\n pd.DataFrame(data.data).to_csv(path, index=False)\n elif fmt == \"excel\":\n pd.DataFrame(data.data).to_excel(path, index=False, engine=\"openpyxl\")\n elif fmt == \"json\":\n path.write_text(\n orjson.dumps(jsonable_encoder(data.data), option=orjson.OPT_INDENT_2).decode(\"utf-8\"), encoding=\"utf-8\"\n )\n elif fmt == \"markdown\":\n path.write_text(pd.DataFrame(data.data).to_markdown(index=False), encoding=\"utf-8\")\n else:\n msg = f\"Unsupported Data format: {fmt}\"\n raise ValueError(msg)\n return f\"Data saved successfully as '{path}'\"\n\n async def _save_message(self, message: Message, path: Path, fmt: str) -> str:\n \"\"\"Save a Message to the specified file format, handling async iterators.\"\"\"\n content = \"\"\n if message.text is None:\n content = \"\"\n elif isinstance(message.text, AsyncIterator):\n async for item in message.text:\n content += str(item) + \" \"\n content = content.strip()\n elif isinstance(message.text, Iterator):\n content = \" \".join(str(item) for item in message.text)\n else:\n content = str(message.text)\n\n if fmt == \"txt\":\n path.write_text(content, encoding=\"utf-8\")\n elif fmt == \"json\":\n path.write_text(json.dumps({\"message\": content}, indent=2), encoding=\"utf-8\")\n elif fmt == \"markdown\":\n path.write_text(f\"**Message:**\\n\\n{content}\", encoding=\"utf-8\")\n else:\n msg = f\"Unsupported Message format: {fmt}\"\n raise ValueError(msg)\n return f\"Message saved successfully as '{path}'\"\n" + "value": "import json\nfrom collections.abc import AsyncIterator, Iterator\nfrom pathlib import Path\n\nimport orjson\nimport pandas as pd\nfrom fastapi import UploadFile\nfrom fastapi.encoders import jsonable_encoder\n\nfrom langflow.api.v2.files import upload_user_file\nfrom langflow.custom import Component\nfrom langflow.io import DropdownInput, HandleInput, StrInput\nfrom langflow.schema import Data, DataFrame, Message\nfrom langflow.services.database.models.user.crud import get_user_by_id\nfrom langflow.services.deps import get_settings_service, get_storage_service, session_scope\nfrom langflow.template.field.base import Output\n\n\nclass SaveToFileComponent(Component):\n display_name = \"Save File\"\n description = \"Save data to a local file in the selected format.\"\n documentation: str = \"https://docs.langflow.org/components-processing#save-file\"\n icon = \"save\"\n name = \"SaveToFile\"\n\n # File format options for different types\n DATA_FORMAT_CHOICES = [\"csv\", \"excel\", \"json\", \"markdown\"]\n MESSAGE_FORMAT_CHOICES = [\"txt\", \"json\", \"markdown\"]\n\n inputs = [\n HandleInput(\n name=\"input\",\n display_name=\"Input\",\n info=\"The input to save.\",\n dynamic=True,\n input_types=[\"Data\", \"DataFrame\", \"Message\"],\n required=True,\n ),\n StrInput(\n name=\"file_name\",\n display_name=\"File Name\",\n info=\"Name file will be saved as (without extension).\",\n required=True,\n ),\n DropdownInput(\n name=\"file_format\",\n display_name=\"File Format\",\n options=list(dict.fromkeys(DATA_FORMAT_CHOICES + MESSAGE_FORMAT_CHOICES)),\n info=\"Select the file format to save the input. If not provided, the default format will be used.\",\n value=\"\",\n advanced=True,\n ),\n ]\n\n outputs = [Output(display_name=\"File Path\", name=\"message\", method=\"save_to_file\")]\n\n async def save_to_file(self) -> Message:\n \"\"\"Save the input to a file and upload it, returning a confirmation message.\"\"\"\n # Validate inputs\n if not self.file_name:\n msg = \"File name must be provided.\"\n raise ValueError(msg)\n if not self._get_input_type():\n msg = \"Input type is not set.\"\n raise ValueError(msg)\n\n # Validate file format based on input type\n file_format = self.file_format or self._get_default_format()\n allowed_formats = (\n self.MESSAGE_FORMAT_CHOICES if self._get_input_type() == \"Message\" else self.DATA_FORMAT_CHOICES\n )\n if file_format not in allowed_formats:\n msg = f\"Invalid file format '{file_format}' for {self._get_input_type()}. Allowed: {allowed_formats}\"\n raise ValueError(msg)\n\n # Prepare file path\n file_path = Path(self.file_name).expanduser()\n if not file_path.parent.exists():\n file_path.parent.mkdir(parents=True, exist_ok=True)\n file_path = self._adjust_file_path_with_format(file_path, file_format)\n\n # Save the input to file based on type\n if self._get_input_type() == \"DataFrame\":\n confirmation = self._save_dataframe(self.input, file_path, file_format)\n elif self._get_input_type() == \"Data\":\n confirmation = self._save_data(self.input, file_path, file_format)\n elif self._get_input_type() == \"Message\":\n confirmation = await self._save_message(self.input, file_path, file_format)\n else:\n msg = f\"Unsupported input type: {self._get_input_type()}\"\n raise ValueError(msg)\n\n # Upload the saved file\n await self._upload_file(file_path)\n\n # Return the final file path and confirmation message\n final_path = Path.cwd() / file_path if not file_path.is_absolute() else file_path\n\n return Message(text=f\"{confirmation} at {final_path}\")\n\n def _get_input_type(self) -> str:\n \"\"\"Determine the input type based on the provided input.\"\"\"\n # Use exact type checking (type() is) instead of isinstance() to avoid inheritance issues.\n # Since Message inherits from Data, isinstance(message, Data) would return True for Message objects,\n # causing Message inputs to be incorrectly identified as Data type.\n if type(self.input) is DataFrame:\n return \"DataFrame\"\n if type(self.input) is Message:\n return \"Message\"\n if type(self.input) is Data:\n return \"Data\"\n msg = f\"Unsupported input type: {type(self.input)}\"\n raise ValueError(msg)\n\n def _get_default_format(self) -> str:\n \"\"\"Return the default file format based on input type.\"\"\"\n if self._get_input_type() == \"DataFrame\":\n return \"csv\"\n if self._get_input_type() == \"Data\":\n return \"json\"\n if self._get_input_type() == \"Message\":\n return \"json\"\n return \"json\" # Fallback\n\n def _adjust_file_path_with_format(self, path: Path, fmt: str) -> Path:\n \"\"\"Adjust the file path to include the correct extension.\"\"\"\n file_extension = path.suffix.lower().lstrip(\".\")\n if fmt == \"excel\":\n return Path(f\"{path}.xlsx\").expanduser() if file_extension not in [\"xlsx\", \"xls\"] else path\n return Path(f\"{path}.{fmt}\").expanduser() if file_extension != fmt else path\n\n async def _upload_file(self, file_path: Path) -> None:\n \"\"\"Upload the saved file using the upload_user_file service.\"\"\"\n if not file_path.exists():\n msg = f\"File not found: {file_path}\"\n raise FileNotFoundError(msg)\n\n with file_path.open(\"rb\") as f:\n async with session_scope() as db:\n if not self.user_id:\n msg = \"User ID is required for file saving.\"\n raise ValueError(msg)\n current_user = await get_user_by_id(db, self.user_id)\n\n await upload_user_file(\n file=UploadFile(filename=file_path.name, file=f, size=file_path.stat().st_size),\n session=db,\n current_user=current_user,\n storage_service=get_storage_service(),\n settings_service=get_settings_service(),\n )\n\n def _save_dataframe(self, dataframe: DataFrame, path: Path, fmt: str) -> str:\n \"\"\"Save a DataFrame to the specified file format.\"\"\"\n if fmt == \"csv\":\n dataframe.to_csv(path, index=False)\n elif fmt == \"excel\":\n dataframe.to_excel(path, index=False, engine=\"openpyxl\")\n elif fmt == \"json\":\n dataframe.to_json(path, orient=\"records\", indent=2)\n elif fmt == \"markdown\":\n path.write_text(dataframe.to_markdown(index=False), encoding=\"utf-8\")\n else:\n msg = f\"Unsupported DataFrame format: {fmt}\"\n raise ValueError(msg)\n return f\"DataFrame saved successfully as '{path}'\"\n\n def _save_data(self, data: Data, path: Path, fmt: str) -> str:\n \"\"\"Save a Data object to the specified file format.\"\"\"\n if fmt == \"csv\":\n pd.DataFrame(data.data).to_csv(path, index=False)\n elif fmt == \"excel\":\n pd.DataFrame(data.data).to_excel(path, index=False, engine=\"openpyxl\")\n elif fmt == \"json\":\n path.write_text(\n orjson.dumps(jsonable_encoder(data.data), option=orjson.OPT_INDENT_2).decode(\"utf-8\"), encoding=\"utf-8\"\n )\n elif fmt == \"markdown\":\n path.write_text(pd.DataFrame(data.data).to_markdown(index=False), encoding=\"utf-8\")\n else:\n msg = f\"Unsupported Data format: {fmt}\"\n raise ValueError(msg)\n return f\"Data saved successfully as '{path}'\"\n\n async def _save_message(self, message: Message, path: Path, fmt: str) -> str:\n \"\"\"Save a Message to the specified file format, handling async iterators.\"\"\"\n content = \"\"\n if message.text is None:\n content = \"\"\n elif isinstance(message.text, AsyncIterator):\n async for item in message.text:\n content += str(item) + \" \"\n content = content.strip()\n elif isinstance(message.text, Iterator):\n content = \" \".join(str(item) for item in message.text)\n else:\n content = str(message.text)\n\n if fmt == \"txt\":\n path.write_text(content, encoding=\"utf-8\")\n elif fmt == \"json\":\n path.write_text(json.dumps({\"message\": content}, indent=2), encoding=\"utf-8\")\n elif fmt == \"markdown\":\n path.write_text(f\"**Message:**\\n\\n{content}\", encoding=\"utf-8\")\n else:\n msg = f\"Unsupported Message format: {fmt}\"\n raise ValueError(msg)\n return f\"Message saved successfully as '{path}'\"\n" }, "file_format": { "_input_type": "DropdownInput", diff --git a/src/backend/base/langflow/initial_setup/starter_projects/Nvidia Remix.json b/src/backend/base/langflow/initial_setup/starter_projects/Nvidia Remix.json index 558965faf..7f5f195ec 100644 --- a/src/backend/base/langflow/initial_setup/starter_projects/Nvidia Remix.json +++ b/src/backend/base/langflow/initial_setup/starter_projects/Nvidia Remix.json @@ -2518,7 +2518,7 @@ "legacy": false, "lf_version": "1.4.2", "metadata": { - "code_hash": "6839fa3cae99", + "code_hash": "b0a921d4ce11", "module": "langflow.components.agents.mcp_component.MCPToolsComponent" }, "minimized": false, @@ -2545,23 +2545,6 @@ "score": 0.003932426697386162, "template": { "_type": "Component", - "api_key": { - "_input_type": "SecretStrInput", - "advanced": true, - "display_name": "Langflow API Key", - "dynamic": false, - "info": "Langflow API key for authentication when fetching MCP servers and tools.", - "input_types": [], - "load_from_db": true, - "name": "api_key", - "password": true, - "placeholder": "", - "required": false, - "show": true, - "title_case": false, - "type": "str", - "value": "" - }, "code": { "advanced": true, "dynamic": true, @@ -2578,7 +2561,7 @@ "show": true, "title_case": false, "type": "code", - "value": "from __future__ import annotations\n\nimport asyncio\nimport uuid\nfrom typing import Any\n\nfrom langchain_core.tools import StructuredTool # noqa: TC002\n\nfrom langflow.api.v2.mcp import get_server\nfrom langflow.base.agents.utils import maybe_unflatten_dict, safe_cache_get, safe_cache_set\nfrom langflow.base.mcp.util import (\n MCPSseClient,\n MCPStdioClient,\n create_input_schema_from_json_schema,\n update_tools,\n)\nfrom langflow.custom.custom_component.component_with_cache import ComponentWithCache\nfrom langflow.inputs.inputs import InputTypes # noqa: TC001\nfrom langflow.io import DropdownInput, McpInput, MessageTextInput, Output, SecretStrInput\nfrom langflow.io.schema import flatten_schema, schema_to_langflow_inputs\nfrom langflow.logging import logger\nfrom langflow.schema.dataframe import DataFrame\nfrom langflow.schema.message import Message\n\n# Import get_server from the backend API\nfrom langflow.services.auth.utils import create_user_longterm_token, get_current_user\nfrom langflow.services.database.models.user.crud import get_user_by_id\nfrom langflow.services.deps import get_session, get_settings_service, get_storage_service\n\n\nclass MCPToolsComponent(ComponentWithCache):\n schema_inputs: list = []\n tools: list[StructuredTool] = []\n _not_load_actions: bool = False\n _tool_cache: dict = {}\n _last_selected_server: str | None = None # Cache for the last selected server\n\n def __init__(self, **data) -> None:\n super().__init__(**data)\n # Initialize cache keys to avoid CacheMiss when accessing them\n self._ensure_cache_structure()\n\n # Initialize clients with access to the component cache\n self.stdio_client: MCPStdioClient = MCPStdioClient(component_cache=self._shared_component_cache)\n self.sse_client: MCPSseClient = MCPSseClient(component_cache=self._shared_component_cache)\n\n def _ensure_cache_structure(self):\n \"\"\"Ensure the cache has the required structure.\"\"\"\n # Check if servers key exists and is not CacheMiss\n servers_value = safe_cache_get(self._shared_component_cache, \"servers\")\n if servers_value is None:\n safe_cache_set(self._shared_component_cache, \"servers\", {})\n\n # Check if last_selected_server key exists and is not CacheMiss\n last_server_value = safe_cache_get(self._shared_component_cache, \"last_selected_server\")\n if last_server_value is None:\n safe_cache_set(self._shared_component_cache, \"last_selected_server\", \"\")\n\n default_keys: list[str] = [\n \"code\",\n \"_type\",\n \"tool_mode\",\n \"tool_placeholder\",\n \"mcp_server\",\n \"tool\",\n ]\n\n display_name = \"MCP Tools\"\n description = \"Connect to an MCP server to use its tools.\"\n documentation: str = \"https://docs.langflow.org/mcp-client\"\n icon = \"Mcp\"\n name = \"MCPTools\"\n\n inputs = [\n McpInput(\n name=\"mcp_server\",\n display_name=\"MCP Server\",\n info=\"Select the MCP Server that will be used by this component\",\n real_time_refresh=True,\n ),\n DropdownInput(\n name=\"tool\",\n display_name=\"Tool\",\n options=[],\n value=\"\",\n info=\"Select the tool to execute\",\n show=False,\n required=True,\n real_time_refresh=True,\n ),\n MessageTextInput(\n name=\"tool_placeholder\",\n display_name=\"Tool Placeholder\",\n info=\"Placeholder for the tool\",\n value=\"\",\n show=False,\n tool_mode=False,\n ),\n SecretStrInput(\n name=\"api_key\",\n display_name=\"Langflow API Key\",\n info=\"Langflow API key for authentication when fetching MCP servers and tools.\",\n required=False,\n advanced=True,\n ),\n ]\n\n outputs = [\n Output(display_name=\"Response\", name=\"response\", method=\"build_output\"),\n ]\n\n async def _validate_schema_inputs(self, tool_obj) -> list[InputTypes]:\n \"\"\"Validate and process schema inputs for a tool.\"\"\"\n try:\n if not tool_obj or not hasattr(tool_obj, \"args_schema\"):\n msg = \"Invalid tool object or missing input schema\"\n raise ValueError(msg)\n\n flat_schema = flatten_schema(tool_obj.args_schema.schema())\n input_schema = create_input_schema_from_json_schema(flat_schema)\n if not input_schema:\n msg = f\"Empty input schema for tool '{tool_obj.name}'\"\n raise ValueError(msg)\n\n schema_inputs = schema_to_langflow_inputs(input_schema)\n if not schema_inputs:\n msg = f\"No input parameters defined for tool '{tool_obj.name}'\"\n logger.warning(msg)\n return []\n\n except Exception as e:\n msg = f\"Error validating schema inputs: {e!s}\"\n logger.exception(msg)\n raise ValueError(msg) from e\n else:\n return schema_inputs\n\n async def update_tool_list(self, mcp_server_value=None):\n # Accepts mcp_server_value as dict {name, config} or uses self.mcp_server\n mcp_server = mcp_server_value if mcp_server_value is not None else getattr(self, \"mcp_server\", None)\n server_name = None\n server_config_from_value = None\n if isinstance(mcp_server, dict):\n server_name = mcp_server.get(\"name\")\n server_config_from_value = mcp_server.get(\"config\")\n else:\n server_name = mcp_server\n if not server_name:\n self.tools = []\n return [], {\"name\": server_name, \"config\": server_config_from_value}\n\n # Use shared cache if available\n servers_cache = safe_cache_get(self._shared_component_cache, \"servers\", {})\n cached = servers_cache.get(server_name) if isinstance(servers_cache, dict) else None\n\n if cached is not None:\n self.tools = cached[\"tools\"]\n self.tool_names = cached[\"tool_names\"]\n self._tool_cache = cached[\"tool_cache\"]\n server_config_from_value = cached[\"config\"]\n return self.tools, {\"name\": server_name, \"config\": server_config_from_value}\n\n try:\n async for db in get_session():\n # TODO: In 1.6, this may need to be removed or adjusted\n # Try to get the super user token, if possible\n if self.api_key:\n current_user = await get_current_user(\n token=None,\n query_param=self.api_key,\n header_param=None,\n db=db,\n )\n else:\n user_id, _ = await create_user_longterm_token(db)\n current_user = await get_user_by_id(db, user_id)\n\n # Try to get server config from DB/API\n server_config = await get_server(\n server_name,\n current_user,\n db,\n storage_service=get_storage_service(),\n settings_service=get_settings_service(),\n )\n\n # If get_server returns empty but we have a config, use it\n if not server_config and server_config_from_value:\n server_config = server_config_from_value\n\n if not server_config:\n self.tools = []\n return [], {\"name\": server_name, \"config\": server_config}\n\n _, tool_list, tool_cache = await update_tools(\n server_name=server_name,\n server_config=server_config,\n mcp_stdio_client=self.stdio_client,\n mcp_sse_client=self.sse_client,\n )\n\n self.tool_names = [tool.name for tool in tool_list if hasattr(tool, \"name\")]\n self._tool_cache = tool_cache\n self.tools = tool_list\n # Cache the result using shared cache\n cache_data = {\n \"tools\": tool_list,\n \"tool_names\": self.tool_names,\n \"tool_cache\": tool_cache,\n \"config\": server_config,\n }\n\n # Safely update the servers cache\n current_servers_cache = safe_cache_get(self._shared_component_cache, \"servers\", {})\n if isinstance(current_servers_cache, dict):\n current_servers_cache[server_name] = cache_data\n safe_cache_set(self._shared_component_cache, \"servers\", current_servers_cache)\n\n return tool_list, {\"name\": server_name, \"config\": server_config}\n except (TimeoutError, asyncio.TimeoutError) as e:\n msg = f\"Timeout updating tool list: {e!s}\"\n logger.exception(msg)\n raise TimeoutError(msg) from e\n except Exception as e:\n msg = f\"Error updating tool list: {e!s}\"\n logger.exception(msg)\n raise ValueError(msg) from e\n\n async def update_build_config(self, build_config: dict, field_value: str, field_name: str | None = None) -> dict:\n \"\"\"Toggle the visibility of connection-specific fields based on the selected mode.\"\"\"\n try:\n if field_name == \"tool\":\n try:\n if len(self.tools) == 0:\n try:\n self.tools, build_config[\"mcp_server\"][\"value\"] = await self.update_tool_list()\n build_config[\"tool\"][\"options\"] = [tool.name for tool in self.tools]\n build_config[\"tool\"][\"placeholder\"] = \"Select a tool\"\n except (TimeoutError, asyncio.TimeoutError) as e:\n msg = f\"Timeout updating tool list: {e!s}\"\n logger.exception(msg)\n if not build_config[\"tools_metadata\"][\"show\"]:\n build_config[\"tool\"][\"show\"] = True\n build_config[\"tool\"][\"options\"] = []\n build_config[\"tool\"][\"value\"] = \"\"\n build_config[\"tool\"][\"placeholder\"] = \"Timeout on MCP server\"\n else:\n build_config[\"tool\"][\"show\"] = False\n except ValueError:\n if not build_config[\"tools_metadata\"][\"show\"]:\n build_config[\"tool\"][\"show\"] = True\n build_config[\"tool\"][\"options\"] = []\n build_config[\"tool\"][\"value\"] = \"\"\n build_config[\"tool\"][\"placeholder\"] = \"Error on MCP Server\"\n else:\n build_config[\"tool\"][\"show\"] = False\n\n if field_value == \"\":\n return build_config\n tool_obj = None\n for tool in self.tools:\n if tool.name == field_value:\n tool_obj = tool\n break\n if tool_obj is None:\n msg = f\"Tool {field_value} not found in available tools: {self.tools}\"\n logger.warning(msg)\n return build_config\n await self._update_tool_config(build_config, field_value)\n except Exception as e:\n build_config[\"tool\"][\"options\"] = []\n msg = f\"Failed to update tools: {e!s}\"\n raise ValueError(msg) from e\n else:\n return build_config\n elif field_name == \"mcp_server\":\n if not field_value:\n build_config[\"tool\"][\"show\"] = False\n build_config[\"tool\"][\"options\"] = []\n build_config[\"tool\"][\"value\"] = \"\"\n build_config[\"tool\"][\"placeholder\"] = \"\"\n build_config[\"tool_placeholder\"][\"tool_mode\"] = False\n self.remove_non_default_keys(build_config)\n return build_config\n\n build_config[\"tool_placeholder\"][\"tool_mode\"] = True\n\n current_server_name = field_value.get(\"name\") if isinstance(field_value, dict) else field_value\n _last_selected_server = safe_cache_get(self._shared_component_cache, \"last_selected_server\", \"\")\n\n # To avoid unnecessary updates, only proceed if the server has actually changed\n if (_last_selected_server in (current_server_name, \"\")) and build_config[\"tool\"][\"show\"]:\n return build_config\n\n # Determine if \"Tool Mode\" is active by checking if the tool dropdown is hidden.\n is_in_tool_mode = build_config[\"tools_metadata\"][\"show\"]\n safe_cache_set(self._shared_component_cache, \"last_selected_server\", current_server_name)\n\n # Check if tools are already cached for this server before clearing\n cached_tools = None\n if current_server_name:\n servers_cache = safe_cache_get(self._shared_component_cache, \"servers\", {})\n if isinstance(servers_cache, dict):\n cached = servers_cache.get(current_server_name)\n if cached is not None:\n cached_tools = cached[\"tools\"]\n self.tools = cached_tools\n self.tool_names = cached[\"tool_names\"]\n self._tool_cache = cached[\"tool_cache\"]\n\n # Only clear tools if we don't have cached tools for the current server\n if not cached_tools:\n self.tools = [] # Clear previous tools only if no cache\n\n self.remove_non_default_keys(build_config) # Clear previous tool inputs\n\n # Only show the tool dropdown if not in tool_mode\n if not is_in_tool_mode:\n build_config[\"tool\"][\"show\"] = True\n if cached_tools:\n # Use cached tools to populate options immediately\n build_config[\"tool\"][\"options\"] = [tool.name for tool in cached_tools]\n build_config[\"tool\"][\"placeholder\"] = \"Select a tool\"\n else:\n # Show loading state only when we need to fetch tools\n build_config[\"tool\"][\"placeholder\"] = \"Loading tools...\"\n build_config[\"tool\"][\"options\"] = []\n build_config[\"tool\"][\"value\"] = uuid.uuid4()\n else:\n # Keep the tool dropdown hidden if in tool_mode\n self._not_load_actions = True\n build_config[\"tool\"][\"show\"] = False\n\n elif field_name == \"tool_mode\":\n build_config[\"tool\"][\"placeholder\"] = \"\"\n build_config[\"tool\"][\"show\"] = not bool(field_value) and bool(build_config[\"mcp_server\"])\n self.remove_non_default_keys(build_config)\n self.tool = build_config[\"tool\"][\"value\"]\n if field_value:\n self._not_load_actions = True\n else:\n build_config[\"tool\"][\"value\"] = uuid.uuid4()\n build_config[\"tool\"][\"options\"] = []\n build_config[\"tool\"][\"show\"] = True\n build_config[\"tool\"][\"placeholder\"] = \"Loading tools...\"\n elif field_name == \"tools_metadata\":\n self._not_load_actions = False\n\n except Exception as e:\n msg = f\"Error in update_build_config: {e!s}\"\n logger.exception(msg)\n raise ValueError(msg) from e\n else:\n return build_config\n\n def get_inputs_for_all_tools(self, tools: list) -> dict:\n \"\"\"Get input schemas for all tools.\"\"\"\n inputs = {}\n for tool in tools:\n if not tool or not hasattr(tool, \"name\"):\n continue\n try:\n flat_schema = flatten_schema(tool.args_schema.schema())\n input_schema = create_input_schema_from_json_schema(flat_schema)\n langflow_inputs = schema_to_langflow_inputs(input_schema)\n inputs[tool.name] = langflow_inputs\n except (AttributeError, ValueError, TypeError, KeyError) as e:\n msg = f\"Error getting inputs for tool {getattr(tool, 'name', 'unknown')}: {e!s}\"\n logger.exception(msg)\n continue\n return inputs\n\n def remove_input_schema_from_build_config(\n self, build_config: dict, tool_name: str, input_schema: dict[list[InputTypes], Any]\n ):\n \"\"\"Remove the input schema for the tool from the build config.\"\"\"\n # Keep only schemas that don't belong to the current tool\n input_schema = {k: v for k, v in input_schema.items() if k != tool_name}\n # Remove all inputs from other tools\n for value in input_schema.values():\n for _input in value:\n if _input.name in build_config:\n build_config.pop(_input.name)\n\n def remove_non_default_keys(self, build_config: dict) -> None:\n \"\"\"Remove non-default keys from the build config.\"\"\"\n for key in list(build_config.keys()):\n if key not in self.default_keys:\n build_config.pop(key)\n\n async def _update_tool_config(self, build_config: dict, tool_name: str) -> None:\n \"\"\"Update tool configuration with proper error handling.\"\"\"\n if not self.tools:\n self.tools, build_config[\"mcp_server\"][\"value\"] = await self.update_tool_list()\n\n if not tool_name:\n return\n\n tool_obj = next((tool for tool in self.tools if tool.name == tool_name), None)\n if not tool_obj:\n msg = f\"Tool {tool_name} not found in available tools: {self.tools}\"\n self.remove_non_default_keys(build_config)\n build_config[\"tool\"][\"value\"] = \"\"\n logger.warning(msg)\n return\n\n try:\n # Store current values before removing inputs\n current_values = {}\n for key, value in build_config.items():\n if key not in self.default_keys and isinstance(value, dict) and \"value\" in value:\n current_values[key] = value[\"value\"]\n\n # Get all tool inputs and remove old ones\n input_schema_for_all_tools = self.get_inputs_for_all_tools(self.tools)\n self.remove_input_schema_from_build_config(build_config, tool_name, input_schema_for_all_tools)\n\n # Get and validate new inputs\n self.schema_inputs = await self._validate_schema_inputs(tool_obj)\n if not self.schema_inputs:\n msg = f\"No input parameters to configure for tool '{tool_name}'\"\n logger.info(msg)\n return\n\n # Add new inputs to build config\n for schema_input in self.schema_inputs:\n if not schema_input or not hasattr(schema_input, \"name\"):\n msg = \"Invalid schema input detected, skipping\"\n logger.warning(msg)\n continue\n\n try:\n name = schema_input.name\n input_dict = schema_input.to_dict()\n input_dict.setdefault(\"value\", None)\n input_dict.setdefault(\"required\", True)\n\n build_config[name] = input_dict\n\n # Preserve existing value if the parameter name exists in current_values\n if name in current_values:\n build_config[name][\"value\"] = current_values[name]\n\n except (AttributeError, KeyError, TypeError) as e:\n msg = f\"Error processing schema input {schema_input}: {e!s}\"\n logger.exception(msg)\n continue\n except ValueError as e:\n msg = f\"Schema validation error for tool {tool_name}: {e!s}\"\n logger.exception(msg)\n self.schema_inputs = []\n return\n except (AttributeError, KeyError, TypeError) as e:\n msg = f\"Error updating tool config: {e!s}\"\n logger.exception(msg)\n raise ValueError(msg) from e\n\n async def build_output(self) -> DataFrame:\n \"\"\"Build output with improved error handling and validation.\"\"\"\n try:\n self.tools, _ = await self.update_tool_list()\n if self.tool != \"\":\n # Set session context for persistent MCP sessions using Langflow session ID\n session_context = self._get_session_context()\n if session_context:\n self.stdio_client.set_session_context(session_context)\n self.sse_client.set_session_context(session_context)\n\n exec_tool = self._tool_cache[self.tool]\n tool_args = self.get_inputs_for_all_tools(self.tools)[self.tool]\n kwargs = {}\n for arg in tool_args:\n value = getattr(self, arg.name, None)\n if value is not None:\n if isinstance(value, Message):\n kwargs[arg.name] = value.text\n else:\n kwargs[arg.name] = value\n\n unflattened_kwargs = maybe_unflatten_dict(kwargs)\n\n output = await exec_tool.coroutine(**unflattened_kwargs)\n\n tool_content = []\n for item in output.content:\n item_dict = item.model_dump()\n tool_content.append(item_dict)\n return DataFrame(data=tool_content)\n return DataFrame(data=[{\"error\": \"You must select a tool\"}])\n except Exception as e:\n msg = f\"Error in build_output: {e!s}\"\n logger.exception(msg)\n raise ValueError(msg) from e\n\n def _get_session_context(self) -> str | None:\n \"\"\"Get the Langflow session ID for MCP session caching.\"\"\"\n # Try to get session ID from the component's execution context\n if hasattr(self, \"graph\") and hasattr(self.graph, \"session_id\"):\n session_id = self.graph.session_id\n # Include server name to ensure different servers get different sessions\n server_name = \"\"\n mcp_server = getattr(self, \"mcp_server\", None)\n if isinstance(mcp_server, dict):\n server_name = mcp_server.get(\"name\", \"\")\n elif mcp_server:\n server_name = str(mcp_server)\n return f\"{session_id}_{server_name}\" if session_id else None\n return None\n\n async def _get_tools(self):\n \"\"\"Get cached tools or update if necessary.\"\"\"\n mcp_server = getattr(self, \"mcp_server\", None)\n if not self._not_load_actions:\n tools, _ = await self.update_tool_list(mcp_server)\n return tools\n return []\n" + "value": "from __future__ import annotations\n\nimport asyncio\nimport uuid\nfrom typing import Any\n\nfrom langchain_core.tools import StructuredTool # noqa: TC002\n\nfrom langflow.api.v2.mcp import get_server\nfrom langflow.base.agents.utils import maybe_unflatten_dict, safe_cache_get, safe_cache_set\nfrom langflow.base.mcp.util import (\n MCPSseClient,\n MCPStdioClient,\n create_input_schema_from_json_schema,\n update_tools,\n)\nfrom langflow.custom.custom_component.component_with_cache import ComponentWithCache\nfrom langflow.inputs.inputs import InputTypes # noqa: TC001\nfrom langflow.io import DropdownInput, McpInput, MessageTextInput, Output\nfrom langflow.io.schema import flatten_schema, schema_to_langflow_inputs\nfrom langflow.logging import logger\nfrom langflow.schema.dataframe import DataFrame\nfrom langflow.schema.message import Message\n\n# Import get_server from the backend API\nfrom langflow.services.database.models.user.crud import get_user_by_id\nfrom langflow.services.deps import get_settings_service, get_storage_service, session_scope\n\n\nclass MCPToolsComponent(ComponentWithCache):\n schema_inputs: list = []\n tools: list[StructuredTool] = []\n _not_load_actions: bool = False\n _tool_cache: dict = {}\n _last_selected_server: str | None = None # Cache for the last selected server\n\n def __init__(self, **data) -> None:\n super().__init__(**data)\n # Initialize cache keys to avoid CacheMiss when accessing them\n self._ensure_cache_structure()\n\n # Initialize clients with access to the component cache\n self.stdio_client: MCPStdioClient = MCPStdioClient(component_cache=self._shared_component_cache)\n self.sse_client: MCPSseClient = MCPSseClient(component_cache=self._shared_component_cache)\n\n def _ensure_cache_structure(self):\n \"\"\"Ensure the cache has the required structure.\"\"\"\n # Check if servers key exists and is not CacheMiss\n servers_value = safe_cache_get(self._shared_component_cache, \"servers\")\n if servers_value is None:\n safe_cache_set(self._shared_component_cache, \"servers\", {})\n\n # Check if last_selected_server key exists and is not CacheMiss\n last_server_value = safe_cache_get(self._shared_component_cache, \"last_selected_server\")\n if last_server_value is None:\n safe_cache_set(self._shared_component_cache, \"last_selected_server\", \"\")\n\n default_keys: list[str] = [\n \"code\",\n \"_type\",\n \"tool_mode\",\n \"tool_placeholder\",\n \"mcp_server\",\n \"tool\",\n ]\n\n display_name = \"MCP Tools\"\n description = \"Connect to an MCP server to use its tools.\"\n documentation: str = \"https://docs.langflow.org/mcp-client\"\n icon = \"Mcp\"\n name = \"MCPTools\"\n\n inputs = [\n McpInput(\n name=\"mcp_server\",\n display_name=\"MCP Server\",\n info=\"Select the MCP Server that will be used by this component\",\n real_time_refresh=True,\n ),\n DropdownInput(\n name=\"tool\",\n display_name=\"Tool\",\n options=[],\n value=\"\",\n info=\"Select the tool to execute\",\n show=False,\n required=True,\n real_time_refresh=True,\n ),\n MessageTextInput(\n name=\"tool_placeholder\",\n display_name=\"Tool Placeholder\",\n info=\"Placeholder for the tool\",\n value=\"\",\n show=False,\n tool_mode=False,\n ),\n ]\n\n outputs = [\n Output(display_name=\"Response\", name=\"response\", method=\"build_output\"),\n ]\n\n async def _validate_schema_inputs(self, tool_obj) -> list[InputTypes]:\n \"\"\"Validate and process schema inputs for a tool.\"\"\"\n try:\n if not tool_obj or not hasattr(tool_obj, \"args_schema\"):\n msg = \"Invalid tool object or missing input schema\"\n raise ValueError(msg)\n\n flat_schema = flatten_schema(tool_obj.args_schema.schema())\n input_schema = create_input_schema_from_json_schema(flat_schema)\n if not input_schema:\n msg = f\"Empty input schema for tool '{tool_obj.name}'\"\n raise ValueError(msg)\n\n schema_inputs = schema_to_langflow_inputs(input_schema)\n if not schema_inputs:\n msg = f\"No input parameters defined for tool '{tool_obj.name}'\"\n logger.warning(msg)\n return []\n\n except Exception as e:\n msg = f\"Error validating schema inputs: {e!s}\"\n logger.exception(msg)\n raise ValueError(msg) from e\n else:\n return schema_inputs\n\n async def update_tool_list(self, mcp_server_value=None):\n # Accepts mcp_server_value as dict {name, config} or uses self.mcp_server\n mcp_server = mcp_server_value if mcp_server_value is not None else getattr(self, \"mcp_server\", None)\n server_name = None\n server_config_from_value = None\n if isinstance(mcp_server, dict):\n server_name = mcp_server.get(\"name\")\n server_config_from_value = mcp_server.get(\"config\")\n else:\n server_name = mcp_server\n if not server_name:\n self.tools = []\n return [], {\"name\": server_name, \"config\": server_config_from_value}\n\n # Use shared cache if available\n servers_cache = safe_cache_get(self._shared_component_cache, \"servers\", {})\n cached = servers_cache.get(server_name) if isinstance(servers_cache, dict) else None\n\n if cached is not None:\n self.tools = cached[\"tools\"]\n self.tool_names = cached[\"tool_names\"]\n self._tool_cache = cached[\"tool_cache\"]\n server_config_from_value = cached[\"config\"]\n return self.tools, {\"name\": server_name, \"config\": server_config_from_value}\n\n try:\n async with session_scope() as db:\n if not self.user_id:\n msg = \"User ID is required for fetching MCP tools.\"\n raise ValueError(msg)\n current_user = await get_user_by_id(db, self.user_id)\n\n # Try to get server config from DB/API\n server_config = await get_server(\n server_name,\n current_user,\n db,\n storage_service=get_storage_service(),\n settings_service=get_settings_service(),\n )\n\n # If get_server returns empty but we have a config, use it\n if not server_config and server_config_from_value:\n server_config = server_config_from_value\n\n if not server_config:\n self.tools = []\n return [], {\"name\": server_name, \"config\": server_config}\n\n _, tool_list, tool_cache = await update_tools(\n server_name=server_name,\n server_config=server_config,\n mcp_stdio_client=self.stdio_client,\n mcp_sse_client=self.sse_client,\n )\n\n self.tool_names = [tool.name for tool in tool_list if hasattr(tool, \"name\")]\n self._tool_cache = tool_cache\n self.tools = tool_list\n # Cache the result using shared cache\n cache_data = {\n \"tools\": tool_list,\n \"tool_names\": self.tool_names,\n \"tool_cache\": tool_cache,\n \"config\": server_config,\n }\n\n # Safely update the servers cache\n current_servers_cache = safe_cache_get(self._shared_component_cache, \"servers\", {})\n if isinstance(current_servers_cache, dict):\n current_servers_cache[server_name] = cache_data\n safe_cache_set(self._shared_component_cache, \"servers\", current_servers_cache)\n\n except (TimeoutError, asyncio.TimeoutError) as e:\n msg = f\"Timeout updating tool list: {e!s}\"\n logger.exception(msg)\n raise TimeoutError(msg) from e\n except Exception as e:\n msg = f\"Error updating tool list: {e!s}\"\n logger.exception(msg)\n raise ValueError(msg) from e\n else:\n return tool_list, {\"name\": server_name, \"config\": server_config}\n\n async def update_build_config(self, build_config: dict, field_value: str, field_name: str | None = None) -> dict:\n \"\"\"Toggle the visibility of connection-specific fields based on the selected mode.\"\"\"\n try:\n if field_name == \"tool\":\n try:\n if len(self.tools) == 0:\n try:\n self.tools, build_config[\"mcp_server\"][\"value\"] = await self.update_tool_list()\n build_config[\"tool\"][\"options\"] = [tool.name for tool in self.tools]\n build_config[\"tool\"][\"placeholder\"] = \"Select a tool\"\n except (TimeoutError, asyncio.TimeoutError) as e:\n msg = f\"Timeout updating tool list: {e!s}\"\n logger.exception(msg)\n if not build_config[\"tools_metadata\"][\"show\"]:\n build_config[\"tool\"][\"show\"] = True\n build_config[\"tool\"][\"options\"] = []\n build_config[\"tool\"][\"value\"] = \"\"\n build_config[\"tool\"][\"placeholder\"] = \"Timeout on MCP server\"\n else:\n build_config[\"tool\"][\"show\"] = False\n except ValueError:\n if not build_config[\"tools_metadata\"][\"show\"]:\n build_config[\"tool\"][\"show\"] = True\n build_config[\"tool\"][\"options\"] = []\n build_config[\"tool\"][\"value\"] = \"\"\n build_config[\"tool\"][\"placeholder\"] = \"Error on MCP Server\"\n else:\n build_config[\"tool\"][\"show\"] = False\n\n if field_value == \"\":\n return build_config\n tool_obj = None\n for tool in self.tools:\n if tool.name == field_value:\n tool_obj = tool\n break\n if tool_obj is None:\n msg = f\"Tool {field_value} not found in available tools: {self.tools}\"\n logger.warning(msg)\n return build_config\n await self._update_tool_config(build_config, field_value)\n except Exception as e:\n build_config[\"tool\"][\"options\"] = []\n msg = f\"Failed to update tools: {e!s}\"\n raise ValueError(msg) from e\n else:\n return build_config\n elif field_name == \"mcp_server\":\n if not field_value:\n build_config[\"tool\"][\"show\"] = False\n build_config[\"tool\"][\"options\"] = []\n build_config[\"tool\"][\"value\"] = \"\"\n build_config[\"tool\"][\"placeholder\"] = \"\"\n build_config[\"tool_placeholder\"][\"tool_mode\"] = False\n self.remove_non_default_keys(build_config)\n return build_config\n\n build_config[\"tool_placeholder\"][\"tool_mode\"] = True\n\n current_server_name = field_value.get(\"name\") if isinstance(field_value, dict) else field_value\n _last_selected_server = safe_cache_get(self._shared_component_cache, \"last_selected_server\", \"\")\n\n # To avoid unnecessary updates, only proceed if the server has actually changed\n if (_last_selected_server in (current_server_name, \"\")) and build_config[\"tool\"][\"show\"]:\n return build_config\n\n # Determine if \"Tool Mode\" is active by checking if the tool dropdown is hidden.\n is_in_tool_mode = build_config[\"tools_metadata\"][\"show\"]\n safe_cache_set(self._shared_component_cache, \"last_selected_server\", current_server_name)\n\n # Check if tools are already cached for this server before clearing\n cached_tools = None\n if current_server_name:\n servers_cache = safe_cache_get(self._shared_component_cache, \"servers\", {})\n if isinstance(servers_cache, dict):\n cached = servers_cache.get(current_server_name)\n if cached is not None:\n cached_tools = cached[\"tools\"]\n self.tools = cached_tools\n self.tool_names = cached[\"tool_names\"]\n self._tool_cache = cached[\"tool_cache\"]\n\n # Only clear tools if we don't have cached tools for the current server\n if not cached_tools:\n self.tools = [] # Clear previous tools only if no cache\n\n self.remove_non_default_keys(build_config) # Clear previous tool inputs\n\n # Only show the tool dropdown if not in tool_mode\n if not is_in_tool_mode:\n build_config[\"tool\"][\"show\"] = True\n if cached_tools:\n # Use cached tools to populate options immediately\n build_config[\"tool\"][\"options\"] = [tool.name for tool in cached_tools]\n build_config[\"tool\"][\"placeholder\"] = \"Select a tool\"\n else:\n # Show loading state only when we need to fetch tools\n build_config[\"tool\"][\"placeholder\"] = \"Loading tools...\"\n build_config[\"tool\"][\"options\"] = []\n build_config[\"tool\"][\"value\"] = uuid.uuid4()\n else:\n # Keep the tool dropdown hidden if in tool_mode\n self._not_load_actions = True\n build_config[\"tool\"][\"show\"] = False\n\n elif field_name == \"tool_mode\":\n build_config[\"tool\"][\"placeholder\"] = \"\"\n build_config[\"tool\"][\"show\"] = not bool(field_value) and bool(build_config[\"mcp_server\"])\n self.remove_non_default_keys(build_config)\n self.tool = build_config[\"tool\"][\"value\"]\n if field_value:\n self._not_load_actions = True\n else:\n build_config[\"tool\"][\"value\"] = uuid.uuid4()\n build_config[\"tool\"][\"options\"] = []\n build_config[\"tool\"][\"show\"] = True\n build_config[\"tool\"][\"placeholder\"] = \"Loading tools...\"\n elif field_name == \"tools_metadata\":\n self._not_load_actions = False\n\n except Exception as e:\n msg = f\"Error in update_build_config: {e!s}\"\n logger.exception(msg)\n raise ValueError(msg) from e\n else:\n return build_config\n\n def get_inputs_for_all_tools(self, tools: list) -> dict:\n \"\"\"Get input schemas for all tools.\"\"\"\n inputs = {}\n for tool in tools:\n if not tool or not hasattr(tool, \"name\"):\n continue\n try:\n flat_schema = flatten_schema(tool.args_schema.schema())\n input_schema = create_input_schema_from_json_schema(flat_schema)\n langflow_inputs = schema_to_langflow_inputs(input_schema)\n inputs[tool.name] = langflow_inputs\n except (AttributeError, ValueError, TypeError, KeyError) as e:\n msg = f\"Error getting inputs for tool {getattr(tool, 'name', 'unknown')}: {e!s}\"\n logger.exception(msg)\n continue\n return inputs\n\n def remove_input_schema_from_build_config(\n self, build_config: dict, tool_name: str, input_schema: dict[list[InputTypes], Any]\n ):\n \"\"\"Remove the input schema for the tool from the build config.\"\"\"\n # Keep only schemas that don't belong to the current tool\n input_schema = {k: v for k, v in input_schema.items() if k != tool_name}\n # Remove all inputs from other tools\n for value in input_schema.values():\n for _input in value:\n if _input.name in build_config:\n build_config.pop(_input.name)\n\n def remove_non_default_keys(self, build_config: dict) -> None:\n \"\"\"Remove non-default keys from the build config.\"\"\"\n for key in list(build_config.keys()):\n if key not in self.default_keys:\n build_config.pop(key)\n\n async def _update_tool_config(self, build_config: dict, tool_name: str) -> None:\n \"\"\"Update tool configuration with proper error handling.\"\"\"\n if not self.tools:\n self.tools, build_config[\"mcp_server\"][\"value\"] = await self.update_tool_list()\n\n if not tool_name:\n return\n\n tool_obj = next((tool for tool in self.tools if tool.name == tool_name), None)\n if not tool_obj:\n msg = f\"Tool {tool_name} not found in available tools: {self.tools}\"\n self.remove_non_default_keys(build_config)\n build_config[\"tool\"][\"value\"] = \"\"\n logger.warning(msg)\n return\n\n try:\n # Store current values before removing inputs\n current_values = {}\n for key, value in build_config.items():\n if key not in self.default_keys and isinstance(value, dict) and \"value\" in value:\n current_values[key] = value[\"value\"]\n\n # Get all tool inputs and remove old ones\n input_schema_for_all_tools = self.get_inputs_for_all_tools(self.tools)\n self.remove_input_schema_from_build_config(build_config, tool_name, input_schema_for_all_tools)\n\n # Get and validate new inputs\n self.schema_inputs = await self._validate_schema_inputs(tool_obj)\n if not self.schema_inputs:\n msg = f\"No input parameters to configure for tool '{tool_name}'\"\n logger.info(msg)\n return\n\n # Add new inputs to build config\n for schema_input in self.schema_inputs:\n if not schema_input or not hasattr(schema_input, \"name\"):\n msg = \"Invalid schema input detected, skipping\"\n logger.warning(msg)\n continue\n\n try:\n name = schema_input.name\n input_dict = schema_input.to_dict()\n input_dict.setdefault(\"value\", None)\n input_dict.setdefault(\"required\", True)\n\n build_config[name] = input_dict\n\n # Preserve existing value if the parameter name exists in current_values\n if name in current_values:\n build_config[name][\"value\"] = current_values[name]\n\n except (AttributeError, KeyError, TypeError) as e:\n msg = f\"Error processing schema input {schema_input}: {e!s}\"\n logger.exception(msg)\n continue\n except ValueError as e:\n msg = f\"Schema validation error for tool {tool_name}: {e!s}\"\n logger.exception(msg)\n self.schema_inputs = []\n return\n except (AttributeError, KeyError, TypeError) as e:\n msg = f\"Error updating tool config: {e!s}\"\n logger.exception(msg)\n raise ValueError(msg) from e\n\n async def build_output(self) -> DataFrame:\n \"\"\"Build output with improved error handling and validation.\"\"\"\n try:\n self.tools, _ = await self.update_tool_list()\n if self.tool != \"\":\n # Set session context for persistent MCP sessions using Langflow session ID\n session_context = self._get_session_context()\n if session_context:\n self.stdio_client.set_session_context(session_context)\n self.sse_client.set_session_context(session_context)\n\n exec_tool = self._tool_cache[self.tool]\n tool_args = self.get_inputs_for_all_tools(self.tools)[self.tool]\n kwargs = {}\n for arg in tool_args:\n value = getattr(self, arg.name, None)\n if value is not None:\n if isinstance(value, Message):\n kwargs[arg.name] = value.text\n else:\n kwargs[arg.name] = value\n\n unflattened_kwargs = maybe_unflatten_dict(kwargs)\n\n output = await exec_tool.coroutine(**unflattened_kwargs)\n\n tool_content = []\n for item in output.content:\n item_dict = item.model_dump()\n tool_content.append(item_dict)\n return DataFrame(data=tool_content)\n return DataFrame(data=[{\"error\": \"You must select a tool\"}])\n except Exception as e:\n msg = f\"Error in build_output: {e!s}\"\n logger.exception(msg)\n raise ValueError(msg) from e\n\n def _get_session_context(self) -> str | None:\n \"\"\"Get the Langflow session ID for MCP session caching.\"\"\"\n # Try to get session ID from the component's execution context\n if hasattr(self, \"graph\") and hasattr(self.graph, \"session_id\"):\n session_id = self.graph.session_id\n # Include server name to ensure different servers get different sessions\n server_name = \"\"\n mcp_server = getattr(self, \"mcp_server\", None)\n if isinstance(mcp_server, dict):\n server_name = mcp_server.get(\"name\", \"\")\n elif mcp_server:\n server_name = str(mcp_server)\n return f\"{session_id}_{server_name}\" if session_id else None\n return None\n\n async def _get_tools(self):\n \"\"\"Get cached tools or update if necessary.\"\"\"\n mcp_server = getattr(self, \"mcp_server\", None)\n if not self._not_load_actions:\n tools, _ = await self.update_tool_list(mcp_server)\n return tools\n return []\n" }, "mcp_server": { "_input_type": "McpInput", diff --git a/src/backend/tests/unit/components/data/test_kb_ingest.py b/src/backend/tests/unit/components/data/test_kb_ingest.py index 93be72a99..ad9c2afe3 100644 --- a/src/backend/tests/unit/components/data/test_kb_ingest.py +++ b/src/backend/tests/unit/components/data/test_kb_ingest.py @@ -1,8 +1,10 @@ import json -from unittest.mock import MagicMock, patch +import uuid +from unittest.mock import AsyncMock, MagicMock, patch import pandas as pd import pytest +from langflow.base.data.kb_utils import get_knowledge_bases from langflow.components.data.kb_ingest import KBIngestionComponent from langflow.schema.data import Data @@ -21,8 +23,43 @@ class TestKBIngestionComponent(ComponentTestBaseWithoutClient): with patch("langflow.components.data.kb_ingest.KNOWLEDGE_BASES_ROOT_PATH", tmp_path): yield + class MockUser: + def __init__(self, user_id): + self.id = user_id + self.username = "langflow" + @pytest.fixture - def default_kwargs(self, tmp_path): + def mock_user_data(self): + """Create mock user data that persists for the test function.""" + mock_uuid = uuid.uuid4() + mock_user = self.MockUser(mock_uuid) + return {"user_id": mock_uuid, "user": mock_user.username, "user_obj": mock_user} + + @pytest.fixture(autouse=True) + def setup_mocks(self, mock_user_data): + """Mock the component's user_id attribute and User object.""" + with ( + patch.object(KBIngestionComponent, "user_id", mock_user_data["user_id"]), + patch( + "langflow.components.data.kb_ingest.get_user_by_id", + new_callable=AsyncMock, + return_value=mock_user_data["user_obj"], + ), + patch( + "langflow.base.data.kb_utils.get_user_by_id", + new_callable=AsyncMock, + return_value=mock_user_data["user_obj"], + ), + ): + yield + + @pytest.fixture + def mock_user_id(self, mock_user_data): + """Get the mock user data.""" + return {"user_id": mock_user_data["user_id"], "user": mock_user_data["user"]} + + @pytest.fixture + def default_kwargs(self, tmp_path, mock_user_id): """Return default kwargs for component instantiation.""" # Create a sample DataFrame data_df = pd.DataFrame( @@ -38,8 +75,8 @@ class TestKBIngestionComponent(ComponentTestBaseWithoutClient): # Create knowledge base directory kb_name = "test_kb" - kb_path = tmp_path / kb_name - kb_path.mkdir(exist_ok=True) + kb_path = tmp_path / mock_user_id["user"] / kb_name + kb_path.mkdir(parents=True, exist_ok=True) # Create embedding metadata file metadata = { @@ -206,7 +243,7 @@ class TestKBIngestionComponent(ComponentTestBaseWithoutClient): assert "text" in metadata["summary"]["vectorized_columns"] assert "category" in metadata["summary"]["identifier_columns"] - def test_convert_df_to_data_objects(self, component_class, default_kwargs): + async def test_convert_df_to_data_objects(self, component_class, default_kwargs): """Test converting DataFrame to Data objects.""" component = component_class(**default_kwargs) data_df = default_kwargs["input_df"] @@ -218,7 +255,7 @@ class TestKBIngestionComponent(ComponentTestBaseWithoutClient): mock_chroma_instance.get.return_value = {"metadatas": []} mock_chroma.return_value = mock_chroma_instance - data_objects = component._convert_df_to_data_objects(data_df, config_list) + data_objects = await component._convert_df_to_data_objects(data_df, config_list) assert len(data_objects) == 2 assert all(isinstance(obj, Data) for obj in data_objects) @@ -230,7 +267,7 @@ class TestKBIngestionComponent(ComponentTestBaseWithoutClient): assert "category" in first_obj.data assert "_id" in first_obj.data - def test_convert_df_to_data_objects_no_duplicates(self, component_class, default_kwargs): + async def test_convert_df_to_data_objects_no_duplicates(self, component_class, default_kwargs): """Test converting DataFrame to Data objects with duplicate prevention.""" default_kwargs["allow_duplicates"] = False component = component_class(**default_kwargs) @@ -251,7 +288,7 @@ class TestKBIngestionComponent(ComponentTestBaseWithoutClient): mock_hash_obj.hexdigest.side_effect = [existing_hash, "different_hash"] mock_hash.return_value = mock_hash_obj - data_objects = component._convert_df_to_data_objects(data_df, config_list) + data_objects = await component._convert_df_to_data_objects(data_df, config_list) # Should only return one object (second row) since first is duplicate assert len(data_objects) == 1 @@ -274,7 +311,7 @@ class TestKBIngestionComponent(ComponentTestBaseWithoutClient): @patch("langflow.components.data.kb_ingest.json.loads") @patch("langflow.components.data.kb_ingest.decrypt_api_key") - def test_build_kb_info_success(self, mock_decrypt, mock_json_loads, component_class, default_kwargs): + async def test_build_kb_info_success(self, mock_decrypt, mock_json_loads, component_class, default_kwargs): """Test successful KB info building.""" component = component_class(**default_kwargs) @@ -287,7 +324,7 @@ class TestKBIngestionComponent(ComponentTestBaseWithoutClient): # Mock vector store creation with patch.object(component, "_create_vector_store"), patch.object(component, "_save_kb_files"): - result = component.build_kb_info() + result = await component.build_kb_info() assert isinstance(result, Data) assert "kb_id" in result.data @@ -295,32 +332,21 @@ class TestKBIngestionComponent(ComponentTestBaseWithoutClient): assert "rows" in result.data assert result.data["rows"] == 2 - def test_get_knowledge_bases(self, component_class, default_kwargs, tmp_path): + async def test_get_knowledge_bases(self, tmp_path, mock_user_id): """Test getting list of knowledge bases.""" - component = component_class(**default_kwargs) - # Create additional test directories - (tmp_path / "kb1").mkdir() - (tmp_path / "kb2").mkdir() - (tmp_path / ".hidden").mkdir() # Should be ignored + (tmp_path / mock_user_id["user"] / "kb1").mkdir(parents=True, exist_ok=True) + (tmp_path / mock_user_id["user"] / "kb2").mkdir(parents=True, exist_ok=True) + (tmp_path / mock_user_id["user"] / ".hidden").mkdir(parents=True, exist_ok=True) # Should be ignored - kb_list = component._get_knowledge_bases() + kb_list = await get_knowledge_bases(tmp_path, user_id=mock_user_id["user_id"]) assert "test_kb" in kb_list assert "kb1" in kb_list assert "kb2" in kb_list assert ".hidden" not in kb_list - @patch("langflow.components.data.kb_ingest.Path.exists") - def test_get_knowledge_bases_no_path(self, mock_exists, component_class, default_kwargs): - """Test getting knowledge bases when path doesn't exist.""" - component = component_class(**default_kwargs) - mock_exists.return_value = False - - kb_list = component._get_knowledge_bases() - assert kb_list == [] - - def test_update_build_config_new_kb(self, component_class, default_kwargs): + async def test_update_build_config_new_kb(self, component_class, default_kwargs): """Test updating build config for new knowledge base creation.""" component = component_class(**default_kwargs) @@ -329,26 +355,24 @@ class TestKBIngestionComponent(ComponentTestBaseWithoutClient): field_value = { "01_new_kb_name": "new_test_kb", "02_embedding_model": "sentence-transformers/all-MiniLM-L6-v2", - "03_api_key": None, + "03_api_key": "abc123", # Mock API key } # Mock embedding validation with ( patch.object(component, "_build_embeddings") as mock_build_emb, patch.object(component, "_save_embedding_metadata"), - patch.object(component, "_get_knowledge_bases") as mock_get_kbs, ): mock_embeddings = MagicMock() mock_embeddings.embed_query.return_value = [0.1, 0.2, 0.3] mock_build_emb.return_value = mock_embeddings - mock_get_kbs.return_value = ["new_test_kb"] - result = component.update_build_config(build_config, field_value, "knowledge_base") + result = await component.update_build_config(build_config, field_value, "knowledge_base") assert result["knowledge_base"]["value"] == "new_test_kb" assert "new_test_kb" in result["knowledge_base"]["options"] - def test_update_build_config_invalid_kb_name(self, component_class, default_kwargs): + async def test_update_build_config_invalid_kb_name(self, component_class, default_kwargs): """Test updating build config with invalid KB name.""" component = component_class(**default_kwargs) @@ -360,4 +384,4 @@ class TestKBIngestionComponent(ComponentTestBaseWithoutClient): } with pytest.raises(ValueError, match="Invalid knowledge base name"): - component.update_build_config(build_config, field_value, "knowledge_base") + await component.update_build_config(build_config, field_value, "knowledge_base") diff --git a/src/backend/tests/unit/components/data/test_kb_retrieval.py b/src/backend/tests/unit/components/data/test_kb_retrieval.py index ee72c7840..d40d2070a 100644 --- a/src/backend/tests/unit/components/data/test_kb_retrieval.py +++ b/src/backend/tests/unit/components/data/test_kb_retrieval.py @@ -1,10 +1,13 @@ import contextlib import json +import uuid from pathlib import Path -from unittest.mock import MagicMock, patch +from unittest.mock import AsyncMock, MagicMock, patch import pytest +from langflow.base.data.kb_utils import get_knowledge_bases from langflow.components.data.kb_retrieval import KBRetrievalComponent +from pydantic import SecretStr from tests.base import ComponentTestBaseWithoutClient @@ -21,13 +24,48 @@ class TestKBRetrievalComponent(ComponentTestBaseWithoutClient): with patch("langflow.components.data.kb_retrieval.KNOWLEDGE_BASES_ROOT_PATH", tmp_path): yield + class MockUser: + def __init__(self, user_id): + self.id = user_id + self.username = "langflow" + @pytest.fixture - def default_kwargs(self, tmp_path): + def mock_user_data(self): + """Create mock user data that persists for the test function.""" + mock_uuid = uuid.uuid4() + mock_user = self.MockUser(mock_uuid) + return {"user_id": mock_uuid, "user": mock_user.username, "user_obj": mock_user} + + @pytest.fixture(autouse=True) + def setup_mocks(self, mock_user_data): + """Mock the component's user_id attribute and User object.""" + with ( + patch.object(KBRetrievalComponent, "user_id", mock_user_data["user_id"]), + patch( + "langflow.components.data.kb_retrieval.get_user_by_id", + new_callable=AsyncMock, + return_value=mock_user_data["user_obj"], + ), + patch( + "langflow.base.data.kb_utils.get_user_by_id", + new_callable=AsyncMock, + return_value=mock_user_data["user_obj"], + ), + ): + yield + + @pytest.fixture + def mock_user_id(self, mock_user_data): + """Get the mock user data.""" + return {"user_id": mock_user_data["user_id"], "user": mock_user_data["user"]} + + @pytest.fixture + def default_kwargs(self, tmp_path, mock_user_id): """Return default kwargs for component instantiation.""" # Create knowledge base directory structure kb_name = "test_kb" - kb_path = tmp_path / kb_name - kb_path.mkdir(exist_ok=True) + kb_path = tmp_path / mock_user_id["user"] / kb_name + kb_path.mkdir(parents=True, exist_ok=True) # Create embedding metadata file metadata = { @@ -55,61 +93,50 @@ class TestKBRetrievalComponent(ComponentTestBaseWithoutClient): # This is a new component, so it doesn't exist in older versions return [] - def test_get_knowledge_bases(self, component_class, default_kwargs, tmp_path): + async def test_get_knowledge_bases(self, tmp_path, mock_user_id): """Test getting list of knowledge bases.""" - component = component_class(**default_kwargs) - # Create additional test directories - (tmp_path / "kb1").mkdir() - (tmp_path / "kb2").mkdir() - (tmp_path / ".hidden").mkdir() # Should be ignored + (tmp_path / mock_user_id["user"] / "kb1").mkdir(parents=True, exist_ok=True) + (tmp_path / mock_user_id["user"] / "kb2").mkdir(parents=True, exist_ok=True) + (tmp_path / mock_user_id["user"] / ".hidden").mkdir(parents=True, exist_ok=True) # Should be ignored - kb_list = component._get_knowledge_bases() + kb_list = await get_knowledge_bases(tmp_path, user_id=mock_user_id["user_id"]) assert "test_kb" in kb_list assert "kb1" in kb_list assert "kb2" in kb_list assert ".hidden" not in kb_list - @patch("langflow.components.data.kb_retrieval.Path.exists") - def test_get_knowledge_bases_no_path(self, mock_exists, component_class, default_kwargs): - """Test getting knowledge bases when path doesn't exist.""" - component = component_class(**default_kwargs) - mock_exists.return_value = False - - kb_list = component._get_knowledge_bases() - assert kb_list == [] - - def test_update_build_config(self, component_class, default_kwargs, tmp_path): + async def test_update_build_config(self, component_class, default_kwargs, tmp_path, mock_user_id): """Test updating build configuration.""" component = component_class(**default_kwargs) # Create additional KB directories - (tmp_path / "kb1").mkdir() - (tmp_path / "kb2").mkdir() + (tmp_path / mock_user_id["user"] / "kb1").mkdir(parents=True, exist_ok=True) + (tmp_path / mock_user_id["user"] / "kb2").mkdir(parents=True, exist_ok=True) build_config = {"knowledge_base": {"value": "test_kb", "options": []}} - result = component.update_build_config(build_config, None, "knowledge_base") + result = await component.update_build_config(build_config, None, "knowledge_base") assert "test_kb" in result["knowledge_base"]["options"] assert "kb1" in result["knowledge_base"]["options"] assert "kb2" in result["knowledge_base"]["options"] - def test_update_build_config_invalid_kb(self, component_class, default_kwargs): + async def test_update_build_config_invalid_kb(self, component_class, default_kwargs): """Test updating build config when selected KB is not available.""" component = component_class(**default_kwargs) build_config = {"knowledge_base": {"value": "nonexistent_kb", "options": ["test_kb"]}} - result = component.update_build_config(build_config, None, "knowledge_base") + result = await component.update_build_config(build_config, None, "knowledge_base") assert result["knowledge_base"]["value"] is None - def test_get_kb_metadata_success(self, component_class, default_kwargs): + def test_get_kb_metadata_success(self, component_class, default_kwargs, mock_user_id): """Test successful metadata loading.""" component = component_class(**default_kwargs) - kb_path = Path(default_kwargs["kb_root_path"]) / default_kwargs["knowledge_base"] + kb_path = Path(default_kwargs["kb_root_path"]) / mock_user_id["user"] / default_kwargs["knowledge_base"] with patch("langflow.components.data.kb_retrieval.decrypt_api_key") as mock_decrypt: mock_decrypt.return_value = "decrypted_key" @@ -120,21 +147,21 @@ class TestKBRetrievalComponent(ComponentTestBaseWithoutClient): assert metadata["embedding_model"] == "sentence-transformers/all-MiniLM-L6-v2" assert "chunk_size" in metadata - def test_get_kb_metadata_no_file(self, component_class, default_kwargs, tmp_path): + def test_get_kb_metadata_no_file(self, component_class, default_kwargs, tmp_path, mock_user_id): """Test metadata loading when file doesn't exist.""" component = component_class(**default_kwargs) - nonexistent_path = tmp_path / "nonexistent" - nonexistent_path.mkdir() + nonexistent_path = tmp_path / mock_user_id["user"] / "nonexistent" + nonexistent_path.mkdir(parents=True, exist_ok=True) metadata = component._get_kb_metadata(nonexistent_path) assert metadata == {} - def test_get_kb_metadata_json_error(self, component_class, default_kwargs, tmp_path): + def test_get_kb_metadata_json_error(self, component_class, default_kwargs, tmp_path, mock_user_id): """Test metadata loading with invalid JSON.""" component = component_class(**default_kwargs) - kb_path = tmp_path / "invalid_json_kb" - kb_path.mkdir() + kb_path = tmp_path / mock_user_id["user"] / "invalid_json_kb" + kb_path.mkdir(parents=True, exist_ok=True) # Create invalid JSON file (kb_path / "embedding_metadata.json").write_text("invalid json content") @@ -143,11 +170,11 @@ class TestKBRetrievalComponent(ComponentTestBaseWithoutClient): assert metadata == {} - def test_get_kb_metadata_decrypt_error(self, component_class, default_kwargs, tmp_path): + def test_get_kb_metadata_decrypt_error(self, component_class, default_kwargs, tmp_path, mock_user_id): """Test metadata loading with decryption error.""" component = component_class(**default_kwargs) - kb_path = tmp_path / "decrypt_error_kb" - kb_path.mkdir() + kb_path = tmp_path / mock_user_id["user"] / "decrypt_error_kb" + kb_path.mkdir(parents=True, exist_ok=True) # Create metadata with encrypted key metadata = { @@ -274,10 +301,8 @@ class TestKBRetrievalComponent(ComponentTestBaseWithoutClient): def test_build_embeddings_with_user_api_key(self, component_class, default_kwargs): """Test that user-provided API key overrides stored one.""" - # Create a mock secret input - - mock_secret = MagicMock() - mock_secret.get_secret_value.return_value = "user-provided-key" + # Use a real SecretStr object instead of a mock + mock_secret = SecretStr("user-provided-key") default_kwargs["api_key"] = mock_secret component = component_class(**default_kwargs) @@ -285,7 +310,7 @@ class TestKBRetrievalComponent(ComponentTestBaseWithoutClient): metadata = { "embedding_provider": "OpenAI", "embedding_model": "text-embedding-ada-002", - "api_key": "stored-key", + "api_key": "stored-key", # This should be overridden by the user-provided key "chunk_size": 1000, } @@ -295,14 +320,17 @@ class TestKBRetrievalComponent(ComponentTestBaseWithoutClient): component._build_embeddings(metadata) + # The user-provided key should override the stored key in metadata mock_openai.assert_called_once_with( - model="text-embedding-ada-002", api_key="user-provided-key", chunk_size=1000 + model="text-embedding-ada-002", + api_key="user-provided-key", # Should use the user-provided key, not "stored-key" + chunk_size=1000, ) - def test_get_chroma_kb_data_no_metadata(self, component_class, default_kwargs, tmp_path): + async def test_get_chroma_kb_data_no_metadata(self, component_class, default_kwargs, tmp_path, mock_user_id): """Test retrieving data when metadata is missing.""" # Remove metadata file - kb_path = tmp_path / default_kwargs["knowledge_base"] + kb_path = tmp_path / mock_user_id["user"] / default_kwargs["knowledge_base"] metadata_file = kb_path / "embedding_metadata.json" if metadata_file.exists(): metadata_file.unlink() @@ -310,7 +338,7 @@ class TestKBRetrievalComponent(ComponentTestBaseWithoutClient): component = component_class(**default_kwargs) with pytest.raises(ValueError, match="Metadata not found for knowledge base"): - component.get_chroma_kb_data() + await component.get_chroma_kb_data() def test_get_chroma_kb_data_path_construction(self, component_class, default_kwargs): """Test that get_chroma_kb_data constructs the correct paths.""" @@ -331,7 +359,7 @@ class TestKBRetrievalComponent(ComponentTestBaseWithoutClient): assert hasattr(component, "top_k") assert hasattr(component, "include_embeddings") - def test_get_chroma_kb_data_method_exists(self, component_class, default_kwargs): + async def test_get_chroma_kb_data_method_exists(self, component_class, default_kwargs): """Test that get_chroma_kb_data method exists and can be called.""" component = component_class(**default_kwargs) @@ -349,7 +377,7 @@ class TestKBRetrievalComponent(ComponentTestBaseWithoutClient): # This is a unit test focused on the component's internal logic with contextlib.suppress(Exception): - component.get_chroma_kb_data() + await component.get_chroma_kb_data() # Verify internal methods were called mock_get_metadata.assert_called_once()