feat: Make knowledge bases user-stored and support global vars (#9458)

* feat: Make knowledge bases user-stored

* [autofix.ci] apply automated fixes

* [autofix.ci] apply automated fixes (attempt 2/3)

* Fix ruff error

* [autofix.ci] apply automated fixes

* Reuse code

* [autofix.ci] apply automated fixes

* [autofix.ci] apply automated fixes (attempt 2/3)

* Don't show options by default

* [autofix.ci] apply automated fixes

* Pass in the Langflow API key if set

* [autofix.ci] apply automated fixes

* Update files.py

* [autofix.ci] apply automated fixes

* Properly handle secret retrieval

* [autofix.ci] apply automated fixes

* Update src/backend/base/langflow/base/data/kb_utils.py

Co-authored-by: Gabriel Luiz Freitas Almeida <gabriel@langflow.org>

* Update src/backend/base/langflow/base/data/kb_utils.py

Co-authored-by: Gabriel Luiz Freitas Almeida <gabriel@langflow.org>

* Update src/backend/base/langflow/components/data/kb_ingest.py

Co-authored-by: Gabriel Luiz Freitas Almeida <gabriel@langflow.org>

* [autofix.ci] apply automated fixes

* [autofix.ci] apply automated fixes (attempt 2/3)

* Feedback from review

* [autofix.ci] apply automated fixes

* Fix other uses of incorrect user

* [autofix.ci] apply automated fixes

* [autofix.ci] apply automated fixes (attempt 2/3)

* Feedback from review 2

* [autofix.ci] apply automated fixes

* Update kb_ingest.py

* [autofix.ci] apply automated fixes

* Update tests

* [autofix.ci] apply automated fixes

* Update kb_ingest.py

* [autofix.ci] apply automated fixes

* Fix mypy issues

* [autofix.ci] apply automated fixes

* Update kb_utils.py

* Update test_kb_ingest.py

* Fix tests

* [autofix.ci] apply automated fixes

---------

Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com>
Co-authored-by: Gabriel Luiz Freitas Almeida <gabriel@langflow.org>
This commit is contained in:
Eric Hare 2025-08-21 16:30:54 -07:00 committed by GitHub
commit 59937ee9e7
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
13 changed files with 375 additions and 285 deletions

View file

@ -9,6 +9,7 @@ from langchain_chroma import Chroma
from loguru import logger
from pydantic import BaseModel
from langflow.api.utils import CurrentActiveUser
from langflow.services.deps import get_settings_service
router = APIRouter(tags=["Knowledge Bases"], prefix="/knowledge_bases")
@ -290,17 +291,19 @@ def get_kb_metadata(kb_path: Path) -> dict:
@router.get("", status_code=HTTPStatus.OK)
@router.get("/", status_code=HTTPStatus.OK)
async def list_knowledge_bases() -> list[KnowledgeBaseInfo]:
async def list_knowledge_bases(current_user: CurrentActiveUser) -> list[KnowledgeBaseInfo]:
"""List all available knowledge bases."""
try:
kb_root_path = get_kb_root_path()
kb_user = current_user.username
kb_path = kb_root_path / kb_user
if not kb_root_path.exists():
if not kb_path.exists():
return []
knowledge_bases = []
for kb_dir in kb_root_path.iterdir():
for kb_dir in kb_path.iterdir():
if not kb_dir.is_dir() or kb_dir.name.startswith("."):
continue
@ -340,11 +343,12 @@ async def list_knowledge_bases() -> list[KnowledgeBaseInfo]:
@router.get("/{kb_name}", status_code=HTTPStatus.OK)
async def get_knowledge_base(kb_name: str) -> KnowledgeBaseInfo:
async def get_knowledge_base(kb_name: str, current_user: CurrentActiveUser) -> KnowledgeBaseInfo:
"""Get detailed information about a specific knowledge base."""
try:
kb_root_path = get_kb_root_path()
kb_path = kb_root_path / kb_name
kb_user = current_user.username
kb_path = kb_root_path / kb_user / kb_name
if not kb_path.exists() or not kb_path.is_dir():
raise HTTPException(status_code=404, detail=f"Knowledge base '{kb_name}' not found")
@ -374,11 +378,12 @@ async def get_knowledge_base(kb_name: str) -> KnowledgeBaseInfo:
@router.delete("/{kb_name}", status_code=HTTPStatus.OK)
async def delete_knowledge_base(kb_name: str) -> dict[str, str]:
async def delete_knowledge_base(kb_name: str, current_user: CurrentActiveUser) -> dict[str, str]:
"""Delete a specific knowledge base."""
try:
kb_root_path = get_kb_root_path()
kb_path = kb_root_path / kb_name
kb_user = current_user.username
kb_path = kb_root_path / kb_user / kb_name
if not kb_path.exists() or not kb_path.is_dir():
raise HTTPException(status_code=404, detail=f"Knowledge base '{kb_name}' not found")
@ -396,15 +401,17 @@ async def delete_knowledge_base(kb_name: str) -> dict[str, str]:
@router.delete("", status_code=HTTPStatus.OK)
@router.delete("/", status_code=HTTPStatus.OK)
async def delete_knowledge_bases_bulk(request: BulkDeleteRequest) -> dict[str, object]:
async def delete_knowledge_bases_bulk(request: BulkDeleteRequest, current_user: CurrentActiveUser) -> dict[str, object]:
"""Delete multiple knowledge bases."""
try:
kb_root_path = get_kb_root_path()
kb_user = current_user.username
kb_user_path = kb_root_path / kb_user
deleted_count = 0
not_found_kbs = []
for kb_name in request.kb_names:
kb_path = kb_root_path / kb_name
kb_path = kb_user_path / kb_name
if not kb_path.exists() or not kb_path.is_dir():
not_found_kbs.append(kb_name)

View file

@ -123,7 +123,9 @@ async def upload_user_file(
unique_filename = new_filename
else:
# For normal files, ensure unique name by appending a count if necessary
stmt = select(UserFile).where(col(UserFile.name).like(f"{root_filename}%"))
stmt = select(UserFile).where(
col(UserFile.name).like(f"{root_filename}%"), UserFile.user_id == current_user.id
)
existing_files = await session.exec(stmt)
files = existing_files.all() # Fetch all matching records

View file

@ -1,5 +1,10 @@
import math
from collections import Counter
from pathlib import Path
from uuid import UUID
from langflow.services.database.models.user.crud import get_user_by_id
from langflow.services.deps import session_scope
def compute_tfidf(documents: list[str], query_terms: list[str]) -> list[float]:
@ -102,3 +107,31 @@ def compute_bm25(documents: list[str], query_terms: list[str], k1: float = 1.2,
scores.append(doc_score)
return scores
async def get_knowledge_bases(kb_root: Path, user_id: UUID | str) -> list[str]:
"""Retrieve a list of available knowledge bases.
Returns:
A list of knowledge base names.
"""
if not kb_root.exists():
return []
# Get the current user
async with session_scope() as db:
if not user_id:
msg = "User ID is required for fetching knowledge bases."
raise ValueError(msg)
user_id = UUID(user_id) if isinstance(user_id, str) else user_id
current_user = await get_user_by_id(db, user_id)
if not current_user:
msg = f"User with ID {user_id} not found."
raise ValueError(msg)
kb_user = current_user.username
kb_path = kb_root / kb_user
if not kb_path.exists():
return []
return [str(d.name) for d in kb_path.iterdir() if not d.name.startswith(".") and d.is_dir()]

View file

@ -16,16 +16,15 @@ from langflow.base.mcp.util import (
)
from langflow.custom.custom_component.component_with_cache import ComponentWithCache
from langflow.inputs.inputs import InputTypes # noqa: TC001
from langflow.io import DropdownInput, McpInput, MessageTextInput, Output, SecretStrInput
from langflow.io import DropdownInput, McpInput, MessageTextInput, Output
from langflow.io.schema import flatten_schema, schema_to_langflow_inputs
from langflow.logging import logger
from langflow.schema.dataframe import DataFrame
from langflow.schema.message import Message
# Import get_server from the backend API
from langflow.services.auth.utils import create_user_longterm_token, get_current_user
from langflow.services.database.models.user.crud import get_user_by_id
from langflow.services.deps import get_session, get_settings_service, get_storage_service
from langflow.services.deps import get_settings_service, get_storage_service, session_scope
class MCPToolsComponent(ComponentWithCache):
@ -96,13 +95,6 @@ class MCPToolsComponent(ComponentWithCache):
show=False,
tool_mode=False,
),
SecretStrInput(
name="api_key",
display_name="Langflow API Key",
info="Langflow API key for authentication when fetching MCP servers and tools.",
required=False,
advanced=True,
),
]
outputs = [
@ -161,19 +153,11 @@ class MCPToolsComponent(ComponentWithCache):
return self.tools, {"name": server_name, "config": server_config_from_value}
try:
async for db in get_session():
# TODO: In 1.6, this may need to be removed or adjusted
# Try to get the super user token, if possible
if self.api_key:
current_user = await get_current_user(
token=None,
query_param=self.api_key,
header_param=None,
db=db,
)
else:
user_id, _ = await create_user_longterm_token(db)
current_user = await get_user_by_id(db, user_id)
async with session_scope() as db:
if not self.user_id:
msg = "User ID is required for fetching MCP tools."
raise ValueError(msg)
current_user = await get_user_by_id(db, self.user_id)
# Try to get server config from DB/API
server_config = await get_server(
@ -184,39 +168,38 @@ class MCPToolsComponent(ComponentWithCache):
settings_service=get_settings_service(),
)
# If get_server returns empty but we have a config, use it
if not server_config and server_config_from_value:
server_config = server_config_from_value
# If get_server returns empty but we have a config, use it
if not server_config and server_config_from_value:
server_config = server_config_from_value
if not server_config:
self.tools = []
return [], {"name": server_name, "config": server_config}
if not server_config:
self.tools = []
return [], {"name": server_name, "config": server_config}
_, tool_list, tool_cache = await update_tools(
server_name=server_name,
server_config=server_config,
mcp_stdio_client=self.stdio_client,
mcp_sse_client=self.sse_client,
)
_, tool_list, tool_cache = await update_tools(
server_name=server_name,
server_config=server_config,
mcp_stdio_client=self.stdio_client,
mcp_sse_client=self.sse_client,
)
self.tool_names = [tool.name for tool in tool_list if hasattr(tool, "name")]
self._tool_cache = tool_cache
self.tools = tool_list
# Cache the result using shared cache
cache_data = {
"tools": tool_list,
"tool_names": self.tool_names,
"tool_cache": tool_cache,
"config": server_config,
}
self.tool_names = [tool.name for tool in tool_list if hasattr(tool, "name")]
self._tool_cache = tool_cache
self.tools = tool_list
# Cache the result using shared cache
cache_data = {
"tools": tool_list,
"tool_names": self.tool_names,
"tool_cache": tool_cache,
"config": server_config,
}
# Safely update the servers cache
current_servers_cache = safe_cache_get(self._shared_component_cache, "servers", {})
if isinstance(current_servers_cache, dict):
current_servers_cache[server_name] = cache_data
safe_cache_set(self._shared_component_cache, "servers", current_servers_cache)
# Safely update the servers cache
current_servers_cache = safe_cache_get(self._shared_component_cache, "servers", {})
if isinstance(current_servers_cache, dict):
current_servers_cache[server_name] = cache_data
safe_cache_set(self._shared_component_cache, "servers", current_servers_cache)
return tool_list, {"name": server_name, "config": server_config}
except (TimeoutError, asyncio.TimeoutError) as e:
msg = f"Timeout updating tool list: {e!s}"
logger.exception(msg)
@ -225,6 +208,8 @@ class MCPToolsComponent(ComponentWithCache):
msg = f"Error updating tool list: {e!s}"
logger.exception(msg)
raise ValueError(msg) from e
else:
return tool_list, {"name": server_name, "config": server_config}
async def update_build_config(self, build_config: dict, field_value: str, field_name: str | None = None) -> dict:
"""Toggle the visibility of connection-specific fields based on the selected mode."""

View file

@ -1,5 +1,7 @@
from __future__ import annotations
import asyncio
import contextlib
import hashlib
import json
import re
@ -14,6 +16,7 @@ from cryptography.fernet import InvalidToken
from langchain_chroma import Chroma
from loguru import logger
from langflow.base.data.kb_utils import get_knowledge_bases
from langflow.base.models.openai_constants import OPENAI_EMBEDDING_MODEL_NAMES
from langflow.custom import Component
from langflow.io import BoolInput, DataFrameInput, DropdownInput, IntInput, Output, SecretStrInput, StrInput, TableInput
@ -21,7 +24,8 @@ from langflow.schema.data import Data
from langflow.schema.dotdict import dotdict # noqa: TC001
from langflow.schema.table import EditMode
from langflow.services.auth.utils import decrypt_api_key, encrypt_api_key
from langflow.services.deps import get_settings_service
from langflow.services.database.models.user.crud import get_user_by_id
from langflow.services.deps import get_settings_service, get_variable_service, session_scope
HUGGINGFACE_MODEL_NAMES = ["sentence-transformers/all-MiniLM-L6-v2", "sentence-transformers/all-mpnet-base-v2"]
COHERE_MODEL_NAMES = ["embed-english-v3.0", "embed-multilingual-v3.0"]
@ -43,6 +47,10 @@ class KBIngestionComponent(Component):
icon = "database"
name = "KBIngestion"
def __init__(self, *args, **kwargs) -> None:
super().__init__(*args, **kwargs)
self._cached_kb_path: Path | None = None
@dataclass
class NewKnowledgeBaseInput:
functionality: str = "create"
@ -76,7 +84,7 @@ class KBIngestionComponent(Component):
display_name="API Key",
info="Provider API key for embedding model",
required=True,
load_from_db=True,
load_from_db=False,
),
},
},
@ -91,11 +99,7 @@ class KBIngestionComponent(Component):
display_name="Knowledge",
info="Select the knowledge to load data from.",
required=True,
options=[
str(d.name) for d in KNOWLEDGE_BASES_ROOT_PATH.iterdir() if not d.name.startswith(".") and d.is_dir()
]
if KNOWLEDGE_BASES_ROOT_PATH.exists()
else [],
options=[],
refresh_button=True,
dialog_inputs=asdict(NewKnowledgeBaseInput()),
),
@ -329,22 +333,23 @@ class KBIngestionComponent(Component):
return metadata
def _create_vector_store(
async def _create_vector_store(
self, df_source: pd.DataFrame, config_list: list[dict[str, Any]], embedding_model: str, api_key: str
) -> None:
"""Create vector store following Local DB component pattern."""
try:
# Set up vector store directory
base_dir = self._get_kb_root()
vector_store_dir = base_dir / self.knowledge_base
vector_store_dir = await self._kb_path()
if not vector_store_dir:
msg = "Knowledge base path is not set. Please create a new knowledge base first."
raise ValueError(msg)
vector_store_dir.mkdir(parents=True, exist_ok=True)
# Create embeddings model
embedding_function = self._build_embeddings(embedding_model, api_key)
# Convert DataFrame to Data objects (following Local DB pattern)
data_objects = self._convert_df_to_data_objects(df_source, config_list)
data_objects = await self._convert_df_to_data_objects(df_source, config_list)
# Create vector store
chroma = Chroma(
@ -367,16 +372,18 @@ class KBIngestionComponent(Component):
except (OSError, ValueError, RuntimeError) as e:
self.log(f"Error creating vector store: {e}")
def _convert_df_to_data_objects(self, df_source: pd.DataFrame, config_list: list[dict[str, Any]]) -> list[Data]:
async def _convert_df_to_data_objects(
self, df_source: pd.DataFrame, config_list: list[dict[str, Any]]
) -> list[Data]:
"""Convert DataFrame to Data objects for vector store."""
data_objects: list[Data] = []
# Set up vector store directory
base_dir = self._get_kb_root()
kb_path = await self._kb_path()
# If we don't allow duplicates, we need to get the existing hashes
chroma = Chroma(
persist_directory=str(base_dir / self.knowledge_base),
persist_directory=str(kb_path),
collection_name=self.knowledge_base,
)
@ -466,10 +473,34 @@ class KBIngestionComponent(Component):
# Check allowed characters (condition 3)
return re.match(r"^[a-zA-Z0-9_-]+$", name) is not None
async def _kb_path(self) -> Path | None:
    """Return the directory of the currently selected knowledge base.

    The path is built as ``<kb root> / <username> / <knowledge_base>``, where
    the username is looked up from ``self.user_id``. The result is stored in
    ``self._cached_kb_path`` so later calls can skip the database lookup.

    Returns:
        The path to the selected knowledge base directory.

    Raises:
        ValueError: If ``self.user_id`` is not set or does not match an
            existing user.
    """
    cached_path = getattr(self, "_cached_kb_path", None)
    # Reuse the cache only while it still points at the selected knowledge
    # base; otherwise a changed ``knowledge_base`` would keep resolving to the
    # previously selected directory.
    if cached_path is not None and cached_path.name == self.knowledge_base:
        return cached_path

    # Fail fast on a missing user ID instead of opening a session first.
    if not self.user_id:
        msg = "User ID is required for fetching knowledge base path."
        raise ValueError(msg)

    async with session_scope() as db:
        current_user = await get_user_by_id(db, self.user_id)
        if not current_user:
            msg = f"User with ID {self.user_id} not found."
            raise ValueError(msg)
        kb_user = current_user.username

    kb_root = self._get_kb_root()

    # Cache the result for subsequent calls on this instance.
    self._cached_kb_path = kb_root / kb_user / self.knowledge_base
    return self._cached_kb_path
# ---------------------------------------------------------------------
# OUTPUT METHODS
# ---------------------------------------------------------------------
def build_kb_info(self) -> Data:
async def build_kb_info(self) -> Data:
"""Main ingestion routine → returns a dict with KB metadata."""
try:
# Get source DataFrame
@ -479,11 +510,11 @@ class KBIngestionComponent(Component):
config_list = self._validate_column_config(df_source)
column_metadata = self._build_column_metadata(config_list, df_source)
# Prepare KB folder (using File Component patterns)
kb_root = self._get_kb_root()
kb_path = kb_root / self.knowledge_base
# Read the embedding info from the knowledge base folder
kb_path = await self._kb_path()
if not kb_path:
msg = "Knowledge base path is not set. Please create a new knowledge base first."
raise ValueError(msg)
metadata_path = kb_path / "embedding_metadata.json"
# If the API key is not provided, try to read it from the metadata file
@ -506,7 +537,7 @@ class KBIngestionComponent(Component):
)
# Create vector store following Local DB component pattern
self._create_vector_store(df_source, config_list, embedding_model=embedding_model, api_key=api_key)
await self._create_vector_store(df_source, config_list, embedding_model=embedding_model, api_key=api_key)
# Save KB files (using File Component storage patterns)
self._save_kb_files(kb_path, config_list)
@ -532,40 +563,77 @@ class KBIngestionComponent(Component):
self.status = f"❌ KB ingestion failed: {e}"
return Data(data={"error": str(e), "kb_name": self.knowledge_base})
def _get_knowledge_bases(self) -> list[str]:
"""Retrieve a list of available knowledge bases.
async def _get_api_key_variable(self, field_value: dict[str, Any]):
async with session_scope() as db:
if not self.user_id:
msg = "User ID is required for fetching global variables."
raise ValueError(msg)
current_user = await get_user_by_id(db, self.user_id)
if not current_user:
msg = f"User with ID {self.user_id} not found."
raise ValueError(msg)
variable_service = get_variable_service()
Returns:
A list of knowledge base names.
"""
# Return the list of directories in the knowledge base root path
kb_root_path = self._get_kb_root()
# Process the api_key field variable
return await variable_service.get_variable(
user_id=current_user.id,
name=field_value["03_api_key"],
field="",
session=db,
)
if not kb_root_path.exists():
return []
return [str(d.name) for d in kb_root_path.iterdir() if not d.name.startswith(".") and d.is_dir()]
def update_build_config(self, build_config: dotdict, field_value: Any, field_name: str | None = None) -> dotdict:
async def update_build_config(
self,
build_config: dotdict,
field_value: Any,
field_name: str | None = None,
) -> dotdict:
"""Update build configuration based on provider selection."""
# Create a new knowledge base
if field_name == "knowledge_base":
async with session_scope() as db:
if not self.user_id:
msg = "User ID is required for fetching knowledge base list."
raise ValueError(msg)
current_user = await get_user_by_id(db, self.user_id)
if not current_user:
msg = f"User with ID {self.user_id} not found."
raise ValueError(msg)
kb_user = current_user.username
if isinstance(field_value, dict) and "01_new_kb_name" in field_value:
# Validate the knowledge base name - Make sure it follows these rules:
if not self.is_valid_collection_name(field_value["01_new_kb_name"]):
msg = f"Invalid knowledge base name: {field_value['01_new_kb_name']}"
raise ValueError(msg)
# We need to test the API Key one time against the embedding model
embed_model = self._build_embeddings(
embedding_model=field_value["02_embedding_model"], api_key=field_value["03_api_key"]
)
api_key = field_value.get("03_api_key", None)
with contextlib.suppress(Exception):
# If the API key is a variable, resolve it
api_key = await self._get_api_key_variable(field_value)
# Try to generate a dummy embedding to validate the API key
embed_model.embed_query("test")
# Make sure api_key is a string
if not isinstance(api_key, str):
msg = "API key must be a string."
raise ValueError(msg)
# We need to test the API Key one time against the embedding model
embed_model = self._build_embeddings(embedding_model=field_value["02_embedding_model"], api_key=api_key)
# Try to generate a dummy embedding to validate the API key without blocking the event loop
try:
await asyncio.wait_for(
asyncio.to_thread(embed_model.embed_query, "test"),
timeout=10,
)
except TimeoutError as e:
msg = "Embedding validation timed out. Please verify network connectivity and key."
raise ValueError(msg) from e
except Exception as e:
msg = f"Embedding validation failed: {e!s}"
raise ValueError(msg) from e
# Create the new knowledge base directory
kb_path = KNOWLEDGE_BASES_ROOT_PATH / field_value["01_new_kb_name"]
kb_path = KNOWLEDGE_BASES_ROOT_PATH / kb_user / field_value["01_new_kb_name"]
kb_path.mkdir(parents=True, exist_ok=True)
# Save the embedding metadata
@ -573,11 +641,16 @@ class KBIngestionComponent(Component):
self._save_embedding_metadata(
kb_path=kb_path,
embedding_model=field_value["02_embedding_model"],
api_key=field_value["03_api_key"],
api_key=api_key,
)
# Update the knowledge base options dynamically
build_config["knowledge_base"]["options"] = self._get_knowledge_bases()
build_config["knowledge_base"]["options"] = await get_knowledge_bases(
KNOWLEDGE_BASES_ROOT_PATH,
user_id=self.user_id,
)
# If the selected knowledge base is not available, reset it
if build_config["knowledge_base"]["value"] not in build_config["knowledge_base"]["options"]:
build_config["knowledge_base"]["value"] = None

View file

@ -5,13 +5,16 @@ from typing import Any
from cryptography.fernet import InvalidToken
from langchain_chroma import Chroma
from loguru import logger
from pydantic import SecretStr
from langflow.base.data.kb_utils import get_knowledge_bases
from langflow.custom import Component
from langflow.io import BoolInput, DropdownInput, IntInput, MessageTextInput, Output, SecretStrInput
from langflow.schema.data import Data
from langflow.schema.dataframe import DataFrame
from langflow.services.auth.utils import decrypt_api_key
from langflow.services.deps import get_settings_service
from langflow.services.database.models.user.crud import get_user_by_id
from langflow.services.deps import get_settings_service, session_scope
settings = get_settings_service().settings
knowledge_directory = settings.knowledge_bases_dir
@ -33,11 +36,7 @@ class KBRetrievalComponent(Component):
display_name="Knowledge",
info="Select the knowledge to load data from.",
required=True,
options=[
str(d.name) for d in KNOWLEDGE_BASES_ROOT_PATH.iterdir() if not d.name.startswith(".") and d.is_dir()
]
if KNOWLEDGE_BASES_ROOT_PATH.exists()
else [],
options=[],
refresh_button=True,
real_time_refresh=True,
),
@ -79,21 +78,13 @@ class KBRetrievalComponent(Component):
),
]
def _get_knowledge_bases(self) -> list[str]:
"""Retrieve a list of available knowledge bases.
Returns:
A list of knowledge base names.
"""
if not KNOWLEDGE_BASES_ROOT_PATH.exists():
return []
return [str(d.name) for d in KNOWLEDGE_BASES_ROOT_PATH.iterdir() if not d.name.startswith(".") and d.is_dir()]
def update_build_config(self, build_config, field_value, field_name=None): # noqa: ARG002
async def update_build_config(self, build_config, field_value, field_name=None): # noqa: ARG002
if field_name == "knowledge_base":
# Update the knowledge base options dynamically
build_config["knowledge_base"]["options"] = self._get_knowledge_bases()
build_config["knowledge_base"]["options"] = await get_knowledge_bases(
KNOWLEDGE_BASES_ROOT_PATH,
user_id=self.user_id, # Use the user_id from the component context
)
# If the selected knowledge base is not available, reset it
if build_config["knowledge_base"]["value"] not in build_config["knowledge_base"]["options"]:
@ -129,15 +120,12 @@ class KBRetrievalComponent(Component):
def _build_embeddings(self, metadata: dict):
"""Build embedding model from metadata."""
runtime_api_key = self.api_key.get_secret_value() if isinstance(self.api_key, SecretStr) else self.api_key
provider = metadata.get("embedding_provider")
model = metadata.get("embedding_model")
api_key = metadata.get("api_key")
api_key = runtime_api_key or metadata.get("api_key")
chunk_size = metadata.get("chunk_size")
# If user provided a key in the input, it overrides the stored one.
if self.api_key and self.api_key.get_secret_value():
api_key = self.api_key.get_secret_value()
# Handle various providers
if provider == "OpenAI":
from langchain_openai import OpenAIEmbeddings
@ -174,13 +162,23 @@ class KBRetrievalComponent(Component):
msg = f"Embedding provider '{provider}' is not supported for retrieval."
raise NotImplementedError(msg)
def get_chroma_kb_data(self) -> DataFrame:
async def get_chroma_kb_data(self) -> DataFrame:
"""Retrieve data from the selected knowledge base by reading the Chroma collection.
Returns:
A DataFrame containing the data rows from the knowledge base.
"""
kb_path = KNOWLEDGE_BASES_ROOT_PATH / self.knowledge_base
# Get the current user
async with session_scope() as db:
if not self.user_id:
msg = "User ID is required for fetching Knowledge Base data."
raise ValueError(msg)
current_user = await get_user_by_id(db, self.user_id)
if not current_user:
msg = f"User with ID {self.user_id} not found."
raise ValueError(msg)
kb_user = current_user.username
kb_path = KNOWLEDGE_BASES_ROOT_PATH / kb_user / self.knowledge_base
metadata = self._get_kb_metadata(kb_path)
if not metadata:

View file

@ -1,7 +1,6 @@
import json
from collections.abc import AsyncIterator, Iterator
from pathlib import Path
from typing import TYPE_CHECKING
import orjson
import pandas as pd
@ -10,16 +9,12 @@ from fastapi.encoders import jsonable_encoder
from langflow.api.v2.files import upload_user_file
from langflow.custom import Component
from langflow.io import DropdownInput, HandleInput, SecretStrInput, StrInput
from langflow.io import DropdownInput, HandleInput, StrInput
from langflow.schema import Data, DataFrame, Message
from langflow.services.auth.utils import create_user_longterm_token, get_current_user
from langflow.services.database.models.user.crud import get_user_by_id
from langflow.services.deps import get_session, get_settings_service, get_storage_service
from langflow.services.deps import get_settings_service, get_storage_service, session_scope
from langflow.template.field.base import Output
if TYPE_CHECKING:
from langflow.services.database.models.user.model import User
class SaveToFileComponent(Component):
display_name = "Save File"
@ -55,13 +50,6 @@ class SaveToFileComponent(Component):
value="",
advanced=True,
),
SecretStrInput(
name="api_key",
display_name="Langflow API Key",
info="Langflow API key for authentication when saving the file.",
required=False,
advanced=True,
),
]
outputs = [Output(display_name="File Path", name="message", method="save_to_file")]
@ -148,25 +136,11 @@ class SaveToFileComponent(Component):
raise FileNotFoundError(msg)
with file_path.open("rb") as f:
async for db in get_session():
# TODO: In 1.6, this may need to be removed or adjusted
# Try to get the super user token, if possible
current_user: User | None = None
if self.api_key:
current_user = await get_current_user(
token="",
query_param=self.api_key,
header_param="",
db=db,
)
else:
user_id, _ = await create_user_longterm_token(db)
current_user = await get_user_by_id(db, user_id)
# Fail if the user is not found
if not current_user:
msg = "User not found. Please provide a valid API key or ensure the user exists."
async with session_scope() as db:
if not self.user_id:
msg = "User ID is required for file saving."
raise ValueError(msg)
current_user = await get_user_by_id(db, self.user_id)
await upload_user_file(
file=UploadFile(filename=file_path.name, file=f, size=file_path.stat().st_size),

File diff suppressed because one or more lines are too long

File diff suppressed because one or more lines are too long

File diff suppressed because one or more lines are too long

File diff suppressed because one or more lines are too long

View file

@ -1,8 +1,10 @@
import json
from unittest.mock import MagicMock, patch
import uuid
from unittest.mock import AsyncMock, MagicMock, patch
import pandas as pd
import pytest
from langflow.base.data.kb_utils import get_knowledge_bases
from langflow.components.data.kb_ingest import KBIngestionComponent
from langflow.schema.data import Data
@ -21,8 +23,43 @@ class TestKBIngestionComponent(ComponentTestBaseWithoutClient):
with patch("langflow.components.data.kb_ingest.KNOWLEDGE_BASES_ROOT_PATH", tmp_path):
yield
class MockUser:
def __init__(self, user_id):
self.id = user_id
self.username = "langflow"
@pytest.fixture
def default_kwargs(self, tmp_path):
def mock_user_data(self):
"""Create mock user data that persists for the test function."""
mock_uuid = uuid.uuid4()
mock_user = self.MockUser(mock_uuid)
return {"user_id": mock_uuid, "user": mock_user.username, "user_obj": mock_user}
@pytest.fixture(autouse=True)
def setup_mocks(self, mock_user_data):
"""Mock the component's user_id attribute and User object."""
with (
patch.object(KBIngestionComponent, "user_id", mock_user_data["user_id"]),
patch(
"langflow.components.data.kb_ingest.get_user_by_id",
new_callable=AsyncMock,
return_value=mock_user_data["user_obj"],
),
patch(
"langflow.base.data.kb_utils.get_user_by_id",
new_callable=AsyncMock,
return_value=mock_user_data["user_obj"],
),
):
yield
@pytest.fixture
def mock_user_id(self, mock_user_data):
"""Get the mock user data."""
return {"user_id": mock_user_data["user_id"], "user": mock_user_data["user"]}
@pytest.fixture
def default_kwargs(self, tmp_path, mock_user_id):
"""Return default kwargs for component instantiation."""
# Create a sample DataFrame
data_df = pd.DataFrame(
@ -38,8 +75,8 @@ class TestKBIngestionComponent(ComponentTestBaseWithoutClient):
# Create knowledge base directory
kb_name = "test_kb"
kb_path = tmp_path / kb_name
kb_path.mkdir(exist_ok=True)
kb_path = tmp_path / mock_user_id["user"] / kb_name
kb_path.mkdir(parents=True, exist_ok=True)
# Create embedding metadata file
metadata = {
@ -206,7 +243,7 @@ class TestKBIngestionComponent(ComponentTestBaseWithoutClient):
assert "text" in metadata["summary"]["vectorized_columns"]
assert "category" in metadata["summary"]["identifier_columns"]
def test_convert_df_to_data_objects(self, component_class, default_kwargs):
async def test_convert_df_to_data_objects(self, component_class, default_kwargs):
"""Test converting DataFrame to Data objects."""
component = component_class(**default_kwargs)
data_df = default_kwargs["input_df"]
@ -218,7 +255,7 @@ class TestKBIngestionComponent(ComponentTestBaseWithoutClient):
mock_chroma_instance.get.return_value = {"metadatas": []}
mock_chroma.return_value = mock_chroma_instance
data_objects = component._convert_df_to_data_objects(data_df, config_list)
data_objects = await component._convert_df_to_data_objects(data_df, config_list)
assert len(data_objects) == 2
assert all(isinstance(obj, Data) for obj in data_objects)
@ -230,7 +267,7 @@ class TestKBIngestionComponent(ComponentTestBaseWithoutClient):
assert "category" in first_obj.data
assert "_id" in first_obj.data
def test_convert_df_to_data_objects_no_duplicates(self, component_class, default_kwargs):
async def test_convert_df_to_data_objects_no_duplicates(self, component_class, default_kwargs):
"""Test converting DataFrame to Data objects with duplicate prevention."""
default_kwargs["allow_duplicates"] = False
component = component_class(**default_kwargs)
@ -251,7 +288,7 @@ class TestKBIngestionComponent(ComponentTestBaseWithoutClient):
mock_hash_obj.hexdigest.side_effect = [existing_hash, "different_hash"]
mock_hash.return_value = mock_hash_obj
data_objects = component._convert_df_to_data_objects(data_df, config_list)
data_objects = await component._convert_df_to_data_objects(data_df, config_list)
# Should only return one object (second row) since first is duplicate
assert len(data_objects) == 1
@ -274,7 +311,7 @@ class TestKBIngestionComponent(ComponentTestBaseWithoutClient):
@patch("langflow.components.data.kb_ingest.json.loads")
@patch("langflow.components.data.kb_ingest.decrypt_api_key")
def test_build_kb_info_success(self, mock_decrypt, mock_json_loads, component_class, default_kwargs):
async def test_build_kb_info_success(self, mock_decrypt, mock_json_loads, component_class, default_kwargs):
"""Test successful KB info building."""
component = component_class(**default_kwargs)
@ -287,7 +324,7 @@ class TestKBIngestionComponent(ComponentTestBaseWithoutClient):
# Mock vector store creation
with patch.object(component, "_create_vector_store"), patch.object(component, "_save_kb_files"):
result = component.build_kb_info()
result = await component.build_kb_info()
assert isinstance(result, Data)
assert "kb_id" in result.data
@ -295,32 +332,21 @@ class TestKBIngestionComponent(ComponentTestBaseWithoutClient):
assert "rows" in result.data
assert result.data["rows"] == 2
async def test_get_knowledge_bases(self, tmp_path, mock_user_id):
    """Test getting list of knowledge bases.

    Knowledge bases live in a per-user directory, so the extra directories
    are created under ``tmp_path / <username>`` and listed through the
    ``get_knowledge_bases`` helper with the mocked user's id.
    """
    user_root = tmp_path / mock_user_id["user"]

    # Create additional test directories
    (user_root / "kb1").mkdir(parents=True, exist_ok=True)
    (user_root / "kb2").mkdir(parents=True, exist_ok=True)
    (user_root / ".hidden").mkdir(parents=True, exist_ok=True)  # Should be ignored

    kb_list = await get_knowledge_bases(tmp_path, user_id=mock_user_id["user_id"])

    # NOTE(review): "test_kb" is not created in this test; presumably it comes
    # from the default_kwargs fixture -- confirm it is instantiated here.
    assert "test_kb" in kb_list
    assert "kb1" in kb_list
    assert "kb2" in kb_list
    assert ".hidden" not in kb_list
@patch("langflow.components.data.kb_ingest.Path.exists")
def test_get_knowledge_bases_no_path(self, mock_exists, component_class, default_kwargs):
    """Test getting knowledge bases when path doesn't exist.

    ``mock_exists`` is the patched ``Path.exists``; it is forced to report
    that nothing exists before the component is asked for its list.
    """
    # NOTE(review): relies on the component exposing `_get_knowledge_bases`;
    # confirm that method is still present on the component under test.
    instance = component_class(**default_kwargs)
    mock_exists.return_value = False

    found = instance._get_knowledge_bases()

    assert found == []
def test_update_build_config_new_kb(self, component_class, default_kwargs):
async def test_update_build_config_new_kb(self, component_class, default_kwargs):
"""Test updating build config for new knowledge base creation."""
component = component_class(**default_kwargs)
@ -329,26 +355,24 @@ class TestKBIngestionComponent(ComponentTestBaseWithoutClient):
field_value = {
"01_new_kb_name": "new_test_kb",
"02_embedding_model": "sentence-transformers/all-MiniLM-L6-v2",
"03_api_key": None,
"03_api_key": "abc123", # Mock API key
}
# Mock embedding validation
with (
patch.object(component, "_build_embeddings") as mock_build_emb,
patch.object(component, "_save_embedding_metadata"),
patch.object(component, "_get_knowledge_bases") as mock_get_kbs,
):
mock_embeddings = MagicMock()
mock_embeddings.embed_query.return_value = [0.1, 0.2, 0.3]
mock_build_emb.return_value = mock_embeddings
mock_get_kbs.return_value = ["new_test_kb"]
result = component.update_build_config(build_config, field_value, "knowledge_base")
result = await component.update_build_config(build_config, field_value, "knowledge_base")
assert result["knowledge_base"]["value"] == "new_test_kb"
assert "new_test_kb" in result["knowledge_base"]["options"]
def test_update_build_config_invalid_kb_name(self, component_class, default_kwargs):
async def test_update_build_config_invalid_kb_name(self, component_class, default_kwargs):
"""Test updating build config with invalid KB name."""
component = component_class(**default_kwargs)
@ -360,4 +384,4 @@ class TestKBIngestionComponent(ComponentTestBaseWithoutClient):
}
with pytest.raises(ValueError, match="Invalid knowledge base name"):
component.update_build_config(build_config, field_value, "knowledge_base")
await component.update_build_config(build_config, field_value, "knowledge_base")

View file

@ -1,10 +1,13 @@
import contextlib
import json
import uuid
from pathlib import Path
from unittest.mock import MagicMock, patch
from unittest.mock import AsyncMock, MagicMock, patch
import pytest
from langflow.base.data.kb_utils import get_knowledge_bases
from langflow.components.data.kb_retrieval import KBRetrievalComponent
from pydantic import SecretStr
from tests.base import ComponentTestBaseWithoutClient
@ -21,13 +24,48 @@ class TestKBRetrievalComponent(ComponentTestBaseWithoutClient):
with patch("langflow.components.data.kb_retrieval.KNOWLEDGE_BASES_ROOT_PATH", tmp_path):
yield
class MockUser:
    """Minimal stand-in for a user record: an id plus a fixed username."""

    def __init__(self, user_id):
        # The username is constant; only the id varies per instance.
        self.username = "langflow"
        self.id = user_id
@pytest.fixture
def default_kwargs(self, tmp_path):
def mock_user_data(self):
    """Build a mock user and return its id, username and object in one dict."""
    generated_id = uuid.uuid4()
    fake_user = self.MockUser(generated_id)
    return {
        "user_id": generated_id,
        "user": fake_user.username,
        "user_obj": fake_user,
    }
@pytest.fixture(autouse=True)
def setup_mocks(self, mock_user_data):
    """Patch the component's user_id and both user lookups for the test's duration.

    ``get_user_by_id`` is replaced in the retrieval component module and in
    ``kb_utils`` with an ``AsyncMock`` that returns the mock user object.
    """
    fake_user = mock_user_data["user_obj"]

    user_id_patch = patch.object(KBRetrievalComponent, "user_id", mock_user_data["user_id"])
    retrieval_lookup_patch = patch(
        "langflow.components.data.kb_retrieval.get_user_by_id",
        new_callable=AsyncMock,
        return_value=fake_user,
    )
    utils_lookup_patch = patch(
        "langflow.base.data.kb_utils.get_user_by_id",
        new_callable=AsyncMock,
        return_value=fake_user,
    )

    # Patchers only take effect on entry, so entering them together here
    # keeps all three active until the test finishes.
    with user_id_patch, retrieval_lookup_patch, utils_lookup_patch:
        yield
@pytest.fixture
def mock_user_id(self, mock_user_data):
    """Expose only the user id and username from the mock user data."""
    exposed_keys = ("user_id", "user")
    return {key: mock_user_data[key] for key in exposed_keys}
@pytest.fixture
def default_kwargs(self, tmp_path, mock_user_id):
"""Return default kwargs for component instantiation."""
# Create knowledge base directory structure
kb_name = "test_kb"
kb_path = tmp_path / kb_name
kb_path.mkdir(exist_ok=True)
kb_path = tmp_path / mock_user_id["user"] / kb_name
kb_path.mkdir(parents=True, exist_ok=True)
# Create embedding metadata file
metadata = {
@ -55,61 +93,50 @@ class TestKBRetrievalComponent(ComponentTestBaseWithoutClient):
# This is a new component, so it doesn't exist in older versions
return []
async def test_get_knowledge_bases(self, tmp_path, mock_user_id):
    """Test getting list of knowledge bases.

    Knowledge bases live in a per-user directory, so the extra directories
    are created under ``tmp_path / <username>`` and listed through the
    ``get_knowledge_bases`` helper with the mocked user's id.
    """
    user_root = tmp_path / mock_user_id["user"]

    # Create additional test directories
    (user_root / "kb1").mkdir(parents=True, exist_ok=True)
    (user_root / "kb2").mkdir(parents=True, exist_ok=True)
    (user_root / ".hidden").mkdir(parents=True, exist_ok=True)  # Should be ignored

    kb_list = await get_knowledge_bases(tmp_path, user_id=mock_user_id["user_id"])

    # NOTE(review): "test_kb" is not created in this test; presumably it comes
    # from the default_kwargs fixture -- confirm it is instantiated here.
    assert "test_kb" in kb_list
    assert "kb1" in kb_list
    assert "kb2" in kb_list
    assert ".hidden" not in kb_list
@patch("langflow.components.data.kb_retrieval.Path.exists")
def test_get_knowledge_bases_no_path(self, mock_exists, component_class, default_kwargs):
    """Test getting knowledge bases when path doesn't exist.

    ``mock_exists`` is the patched ``Path.exists``; it is forced to report
    that nothing exists before the component is asked for its list.
    """
    # NOTE(review): relies on the component exposing `_get_knowledge_bases`;
    # confirm that method is still present on the component under test.
    instance = component_class(**default_kwargs)
    mock_exists.return_value = False

    found = instance._get_knowledge_bases()

    assert found == []
async def test_update_build_config(self, component_class, default_kwargs, tmp_path, mock_user_id):
    """Test updating build configuration.

    After extra knowledge bases are created in the mocked user's directory,
    refreshing the ``knowledge_base`` field must list all of them as options.
    """
    component = component_class(**default_kwargs)

    # Create additional KB directories
    user_root = tmp_path / mock_user_id["user"]
    (user_root / "kb1").mkdir(parents=True, exist_ok=True)
    (user_root / "kb2").mkdir(parents=True, exist_ok=True)

    build_config = {"knowledge_base": {"value": "test_kb", "options": []}}

    # update_build_config is awaited; the result is indexed as a mapping below.
    result = await component.update_build_config(build_config, None, "knowledge_base")

    options = result["knowledge_base"]["options"]
    assert "test_kb" in options
    assert "kb1" in options
    assert "kb2" in options
async def test_update_build_config_invalid_kb(self, component_class, default_kwargs):
    """Test updating build config when selected KB is not available.

    The configured value names a knowledge base the test never creates; the
    refreshed config is expected to reset that value to ``None``.
    """
    component = component_class(**default_kwargs)

    build_config = {"knowledge_base": {"value": "nonexistent_kb", "options": ["test_kb"]}}

    # update_build_config is awaited; the result is indexed as a mapping below.
    result = await component.update_build_config(build_config, None, "knowledge_base")

    assert result["knowledge_base"]["value"] is None
def test_get_kb_metadata_success(self, component_class, default_kwargs):
def test_get_kb_metadata_success(self, component_class, default_kwargs, mock_user_id):
"""Test successful metadata loading."""
component = component_class(**default_kwargs)
kb_path = Path(default_kwargs["kb_root_path"]) / default_kwargs["knowledge_base"]
kb_path = Path(default_kwargs["kb_root_path"]) / mock_user_id["user"] / default_kwargs["knowledge_base"]
with patch("langflow.components.data.kb_retrieval.decrypt_api_key") as mock_decrypt:
mock_decrypt.return_value = "decrypted_key"
@ -120,21 +147,21 @@ class TestKBRetrievalComponent(ComponentTestBaseWithoutClient):
assert metadata["embedding_model"] == "sentence-transformers/all-MiniLM-L6-v2"
assert "chunk_size" in metadata
def test_get_kb_metadata_no_file(self, component_class, default_kwargs, tmp_path, mock_user_id):
    """Test metadata loading when file doesn't exist.

    The knowledge base directory itself is created in the mocked user's
    folder; no metadata file is written into it, and the loader is expected
    to return an empty dict.
    """
    component = component_class(**default_kwargs)

    # The directory exists; only the metadata file inside it is missing.
    empty_kb_path = tmp_path / mock_user_id["user"] / "nonexistent"
    empty_kb_path.mkdir(parents=True, exist_ok=True)

    metadata = component._get_kb_metadata(empty_kb_path)

    assert metadata == {}
def test_get_kb_metadata_json_error(self, component_class, default_kwargs, tmp_path):
def test_get_kb_metadata_json_error(self, component_class, default_kwargs, tmp_path, mock_user_id):
"""Test metadata loading with invalid JSON."""
component = component_class(**default_kwargs)
kb_path = tmp_path / "invalid_json_kb"
kb_path.mkdir()
kb_path = tmp_path / mock_user_id["user"] / "invalid_json_kb"
kb_path.mkdir(parents=True, exist_ok=True)
# Create invalid JSON file
(kb_path / "embedding_metadata.json").write_text("invalid json content")
@ -143,11 +170,11 @@ class TestKBRetrievalComponent(ComponentTestBaseWithoutClient):
assert metadata == {}
def test_get_kb_metadata_decrypt_error(self, component_class, default_kwargs, tmp_path):
def test_get_kb_metadata_decrypt_error(self, component_class, default_kwargs, tmp_path, mock_user_id):
"""Test metadata loading with decryption error."""
component = component_class(**default_kwargs)
kb_path = tmp_path / "decrypt_error_kb"
kb_path.mkdir()
kb_path = tmp_path / mock_user_id["user"] / "decrypt_error_kb"
kb_path.mkdir(parents=True, exist_ok=True)
# Create metadata with encrypted key
metadata = {
@ -274,10 +301,8 @@ class TestKBRetrievalComponent(ComponentTestBaseWithoutClient):
def test_build_embeddings_with_user_api_key(self, component_class, default_kwargs):
"""Test that user-provided API key overrides stored one."""
# Create a mock secret input
mock_secret = MagicMock()
mock_secret.get_secret_value.return_value = "user-provided-key"
# Use a real SecretStr object instead of a mock
mock_secret = SecretStr("user-provided-key")
default_kwargs["api_key"] = mock_secret
component = component_class(**default_kwargs)
@ -285,7 +310,7 @@ class TestKBRetrievalComponent(ComponentTestBaseWithoutClient):
metadata = {
"embedding_provider": "OpenAI",
"embedding_model": "text-embedding-ada-002",
"api_key": "stored-key",
"api_key": "stored-key", # This should be overridden by the user-provided key
"chunk_size": 1000,
}
@ -295,14 +320,17 @@ class TestKBRetrievalComponent(ComponentTestBaseWithoutClient):
component._build_embeddings(metadata)
# The user-provided key should override the stored key in metadata
mock_openai.assert_called_once_with(
model="text-embedding-ada-002", api_key="user-provided-key", chunk_size=1000
model="text-embedding-ada-002",
api_key="user-provided-key", # Should use the user-provided key, not "stored-key"
chunk_size=1000,
)
def test_get_chroma_kb_data_no_metadata(self, component_class, default_kwargs, tmp_path):
async def test_get_chroma_kb_data_no_metadata(self, component_class, default_kwargs, tmp_path, mock_user_id):
"""Test retrieving data when metadata is missing."""
# Remove metadata file
kb_path = tmp_path / default_kwargs["knowledge_base"]
kb_path = tmp_path / mock_user_id["user"] / default_kwargs["knowledge_base"]
metadata_file = kb_path / "embedding_metadata.json"
if metadata_file.exists():
metadata_file.unlink()
@ -310,7 +338,7 @@ class TestKBRetrievalComponent(ComponentTestBaseWithoutClient):
component = component_class(**default_kwargs)
with pytest.raises(ValueError, match="Metadata not found for knowledge base"):
component.get_chroma_kb_data()
await component.get_chroma_kb_data()
def test_get_chroma_kb_data_path_construction(self, component_class, default_kwargs):
"""Test that get_chroma_kb_data constructs the correct paths."""
@ -331,7 +359,7 @@ class TestKBRetrievalComponent(ComponentTestBaseWithoutClient):
assert hasattr(component, "top_k")
assert hasattr(component, "include_embeddings")
def test_get_chroma_kb_data_method_exists(self, component_class, default_kwargs):
async def test_get_chroma_kb_data_method_exists(self, component_class, default_kwargs):
"""Test that get_chroma_kb_data method exists and can be called."""
component = component_class(**default_kwargs)
@ -349,7 +377,7 @@ class TestKBRetrievalComponent(ComponentTestBaseWithoutClient):
# This is a unit test focused on the component's internal logic
with contextlib.suppress(Exception):
component.get_chroma_kb_data()
await component.get_chroma_kb_data()
# Verify internal methods were called
mock_get_metadata.assert_called_once()