feat: Make knowledge bases user-stored and support global vars (#9458)
* feat: Make knowledge bases user-stored * [autofix.ci] apply automated fixes * [autofix.ci] apply automated fixes (attempt 2/3) * Fix ruff error * [autofix.ci] apply automated fixes * Reuse code * [autofix.ci] apply automated fixes * [autofix.ci] apply automated fixes (attempt 2/3) * Don't show options by default * [autofix.ci] apply automated fixes * Pass in the Langflow API key if set * [autofix.ci] apply automated fixes * Update files.py * [autofix.ci] apply automated fixes * Properly handle secret retrieval * [autofix.ci] apply automated fixes * Update src/backend/base/langflow/base/data/kb_utils.py Co-authored-by: Gabriel Luiz Freitas Almeida <gabriel@langflow.org> * Update src/backend/base/langflow/base/data/kb_utils.py Co-authored-by: Gabriel Luiz Freitas Almeida <gabriel@langflow.org> * Update src/backend/base/langflow/components/data/kb_ingest.py Co-authored-by: Gabriel Luiz Freitas Almeida <gabriel@langflow.org> * [autofix.ci] apply automated fixes * [autofix.ci] apply automated fixes (attempt 2/3) * Feedback from review * [autofix.ci] apply automated fixes * Fix other uses of incorrect user * [autofix.ci] apply automated fixes * [autofix.ci] apply automated fixes (attempt 2/3) * Feedback from review 2 * [autofix.ci] apply automated fixes * Update kb_ingest.py * [autofix.ci] apply automated fixes * Update tests * [autofix.ci] apply automated fixes * Update kb_ingest.py * [autofix.ci] apply automated fixes * Fix mypy issues * [autofix.ci] apply automated fixes * Update kb_utils.py * Update test_kb_ingest.py * Fix tests * [autofix.ci] apply automated fixes --------- Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com> Co-authored-by: Gabriel Luiz Freitas Almeida <gabriel@langflow.org>
This commit is contained in:
parent
2475e5a254
commit
59937ee9e7
13 changed files with 375 additions and 285 deletions
|
|
@ -9,6 +9,7 @@ from langchain_chroma import Chroma
|
|||
from loguru import logger
|
||||
from pydantic import BaseModel
|
||||
|
||||
from langflow.api.utils import CurrentActiveUser
|
||||
from langflow.services.deps import get_settings_service
|
||||
|
||||
router = APIRouter(tags=["Knowledge Bases"], prefix="/knowledge_bases")
|
||||
|
|
@ -290,17 +291,19 @@ def get_kb_metadata(kb_path: Path) -> dict:
|
|||
|
||||
@router.get("", status_code=HTTPStatus.OK)
|
||||
@router.get("/", status_code=HTTPStatus.OK)
|
||||
async def list_knowledge_bases() -> list[KnowledgeBaseInfo]:
|
||||
async def list_knowledge_bases(current_user: CurrentActiveUser) -> list[KnowledgeBaseInfo]:
|
||||
"""List all available knowledge bases."""
|
||||
try:
|
||||
kb_root_path = get_kb_root_path()
|
||||
kb_user = current_user.username
|
||||
kb_path = kb_root_path / kb_user
|
||||
|
||||
if not kb_root_path.exists():
|
||||
if not kb_path.exists():
|
||||
return []
|
||||
|
||||
knowledge_bases = []
|
||||
|
||||
for kb_dir in kb_root_path.iterdir():
|
||||
for kb_dir in kb_path.iterdir():
|
||||
if not kb_dir.is_dir() or kb_dir.name.startswith("."):
|
||||
continue
|
||||
|
||||
|
|
@ -340,11 +343,12 @@ async def list_knowledge_bases() -> list[KnowledgeBaseInfo]:
|
|||
|
||||
|
||||
@router.get("/{kb_name}", status_code=HTTPStatus.OK)
|
||||
async def get_knowledge_base(kb_name: str) -> KnowledgeBaseInfo:
|
||||
async def get_knowledge_base(kb_name: str, current_user: CurrentActiveUser) -> KnowledgeBaseInfo:
|
||||
"""Get detailed information about a specific knowledge base."""
|
||||
try:
|
||||
kb_root_path = get_kb_root_path()
|
||||
kb_path = kb_root_path / kb_name
|
||||
kb_user = current_user.username
|
||||
kb_path = kb_root_path / kb_user / kb_name
|
||||
|
||||
if not kb_path.exists() or not kb_path.is_dir():
|
||||
raise HTTPException(status_code=404, detail=f"Knowledge base '{kb_name}' not found")
|
||||
|
|
@ -374,11 +378,12 @@ async def get_knowledge_base(kb_name: str) -> KnowledgeBaseInfo:
|
|||
|
||||
|
||||
@router.delete("/{kb_name}", status_code=HTTPStatus.OK)
|
||||
async def delete_knowledge_base(kb_name: str) -> dict[str, str]:
|
||||
async def delete_knowledge_base(kb_name: str, current_user: CurrentActiveUser) -> dict[str, str]:
|
||||
"""Delete a specific knowledge base."""
|
||||
try:
|
||||
kb_root_path = get_kb_root_path()
|
||||
kb_path = kb_root_path / kb_name
|
||||
kb_user = current_user.username
|
||||
kb_path = kb_root_path / kb_user / kb_name
|
||||
|
||||
if not kb_path.exists() or not kb_path.is_dir():
|
||||
raise HTTPException(status_code=404, detail=f"Knowledge base '{kb_name}' not found")
|
||||
|
|
@ -396,15 +401,17 @@ async def delete_knowledge_base(kb_name: str) -> dict[str, str]:
|
|||
|
||||
@router.delete("", status_code=HTTPStatus.OK)
|
||||
@router.delete("/", status_code=HTTPStatus.OK)
|
||||
async def delete_knowledge_bases_bulk(request: BulkDeleteRequest) -> dict[str, object]:
|
||||
async def delete_knowledge_bases_bulk(request: BulkDeleteRequest, current_user: CurrentActiveUser) -> dict[str, object]:
|
||||
"""Delete multiple knowledge bases."""
|
||||
try:
|
||||
kb_root_path = get_kb_root_path()
|
||||
kb_user = current_user.username
|
||||
kb_user_path = kb_root_path / kb_user
|
||||
deleted_count = 0
|
||||
not_found_kbs = []
|
||||
|
||||
for kb_name in request.kb_names:
|
||||
kb_path = kb_root_path / kb_name
|
||||
kb_path = kb_user_path / kb_name
|
||||
|
||||
if not kb_path.exists() or not kb_path.is_dir():
|
||||
not_found_kbs.append(kb_name)
|
||||
|
|
|
|||
|
|
@ -123,7 +123,9 @@ async def upload_user_file(
|
|||
unique_filename = new_filename
|
||||
else:
|
||||
# For normal files, ensure unique name by appending a count if necessary
|
||||
stmt = select(UserFile).where(col(UserFile.name).like(f"{root_filename}%"))
|
||||
stmt = select(UserFile).where(
|
||||
col(UserFile.name).like(f"{root_filename}%"), UserFile.user_id == current_user.id
|
||||
)
|
||||
existing_files = await session.exec(stmt)
|
||||
files = existing_files.all() # Fetch all matching records
|
||||
|
||||
|
|
|
|||
|
|
@ -1,5 +1,10 @@
|
|||
import math
|
||||
from collections import Counter
|
||||
from pathlib import Path
|
||||
from uuid import UUID
|
||||
|
||||
from langflow.services.database.models.user.crud import get_user_by_id
|
||||
from langflow.services.deps import session_scope
|
||||
|
||||
|
||||
def compute_tfidf(documents: list[str], query_terms: list[str]) -> list[float]:
|
||||
|
|
@ -102,3 +107,31 @@ def compute_bm25(documents: list[str], query_terms: list[str], k1: float = 1.2,
|
|||
scores.append(doc_score)
|
||||
|
||||
return scores
|
||||
|
||||
|
||||
async def get_knowledge_bases(kb_root: Path, user_id: UUID | str) -> list[str]:
|
||||
"""Retrieve a list of available knowledge bases.
|
||||
|
||||
Returns:
|
||||
A list of knowledge base names.
|
||||
"""
|
||||
if not kb_root.exists():
|
||||
return []
|
||||
|
||||
# Get the current user
|
||||
async with session_scope() as db:
|
||||
if not user_id:
|
||||
msg = "User ID is required for fetching knowledge bases."
|
||||
raise ValueError(msg)
|
||||
user_id = UUID(user_id) if isinstance(user_id, str) else user_id
|
||||
current_user = await get_user_by_id(db, user_id)
|
||||
if not current_user:
|
||||
msg = f"User with ID {user_id} not found."
|
||||
raise ValueError(msg)
|
||||
kb_user = current_user.username
|
||||
kb_path = kb_root / kb_user
|
||||
|
||||
if not kb_path.exists():
|
||||
return []
|
||||
|
||||
return [str(d.name) for d in kb_path.iterdir() if not d.name.startswith(".") and d.is_dir()]
|
||||
|
|
|
|||
|
|
@ -16,16 +16,15 @@ from langflow.base.mcp.util import (
|
|||
)
|
||||
from langflow.custom.custom_component.component_with_cache import ComponentWithCache
|
||||
from langflow.inputs.inputs import InputTypes # noqa: TC001
|
||||
from langflow.io import DropdownInput, McpInput, MessageTextInput, Output, SecretStrInput
|
||||
from langflow.io import DropdownInput, McpInput, MessageTextInput, Output
|
||||
from langflow.io.schema import flatten_schema, schema_to_langflow_inputs
|
||||
from langflow.logging import logger
|
||||
from langflow.schema.dataframe import DataFrame
|
||||
from langflow.schema.message import Message
|
||||
|
||||
# Import get_server from the backend API
|
||||
from langflow.services.auth.utils import create_user_longterm_token, get_current_user
|
||||
from langflow.services.database.models.user.crud import get_user_by_id
|
||||
from langflow.services.deps import get_session, get_settings_service, get_storage_service
|
||||
from langflow.services.deps import get_settings_service, get_storage_service, session_scope
|
||||
|
||||
|
||||
class MCPToolsComponent(ComponentWithCache):
|
||||
|
|
@ -96,13 +95,6 @@ class MCPToolsComponent(ComponentWithCache):
|
|||
show=False,
|
||||
tool_mode=False,
|
||||
),
|
||||
SecretStrInput(
|
||||
name="api_key",
|
||||
display_name="Langflow API Key",
|
||||
info="Langflow API key for authentication when fetching MCP servers and tools.",
|
||||
required=False,
|
||||
advanced=True,
|
||||
),
|
||||
]
|
||||
|
||||
outputs = [
|
||||
|
|
@ -161,19 +153,11 @@ class MCPToolsComponent(ComponentWithCache):
|
|||
return self.tools, {"name": server_name, "config": server_config_from_value}
|
||||
|
||||
try:
|
||||
async for db in get_session():
|
||||
# TODO: In 1.6, this may need to be removed or adjusted
|
||||
# Try to get the super user token, if possible
|
||||
if self.api_key:
|
||||
current_user = await get_current_user(
|
||||
token=None,
|
||||
query_param=self.api_key,
|
||||
header_param=None,
|
||||
db=db,
|
||||
)
|
||||
else:
|
||||
user_id, _ = await create_user_longterm_token(db)
|
||||
current_user = await get_user_by_id(db, user_id)
|
||||
async with session_scope() as db:
|
||||
if not self.user_id:
|
||||
msg = "User ID is required for fetching MCP tools."
|
||||
raise ValueError(msg)
|
||||
current_user = await get_user_by_id(db, self.user_id)
|
||||
|
||||
# Try to get server config from DB/API
|
||||
server_config = await get_server(
|
||||
|
|
@ -184,39 +168,38 @@ class MCPToolsComponent(ComponentWithCache):
|
|||
settings_service=get_settings_service(),
|
||||
)
|
||||
|
||||
# If get_server returns empty but we have a config, use it
|
||||
if not server_config and server_config_from_value:
|
||||
server_config = server_config_from_value
|
||||
# If get_server returns empty but we have a config, use it
|
||||
if not server_config and server_config_from_value:
|
||||
server_config = server_config_from_value
|
||||
|
||||
if not server_config:
|
||||
self.tools = []
|
||||
return [], {"name": server_name, "config": server_config}
|
||||
if not server_config:
|
||||
self.tools = []
|
||||
return [], {"name": server_name, "config": server_config}
|
||||
|
||||
_, tool_list, tool_cache = await update_tools(
|
||||
server_name=server_name,
|
||||
server_config=server_config,
|
||||
mcp_stdio_client=self.stdio_client,
|
||||
mcp_sse_client=self.sse_client,
|
||||
)
|
||||
_, tool_list, tool_cache = await update_tools(
|
||||
server_name=server_name,
|
||||
server_config=server_config,
|
||||
mcp_stdio_client=self.stdio_client,
|
||||
mcp_sse_client=self.sse_client,
|
||||
)
|
||||
|
||||
self.tool_names = [tool.name for tool in tool_list if hasattr(tool, "name")]
|
||||
self._tool_cache = tool_cache
|
||||
self.tools = tool_list
|
||||
# Cache the result using shared cache
|
||||
cache_data = {
|
||||
"tools": tool_list,
|
||||
"tool_names": self.tool_names,
|
||||
"tool_cache": tool_cache,
|
||||
"config": server_config,
|
||||
}
|
||||
self.tool_names = [tool.name for tool in tool_list if hasattr(tool, "name")]
|
||||
self._tool_cache = tool_cache
|
||||
self.tools = tool_list
|
||||
# Cache the result using shared cache
|
||||
cache_data = {
|
||||
"tools": tool_list,
|
||||
"tool_names": self.tool_names,
|
||||
"tool_cache": tool_cache,
|
||||
"config": server_config,
|
||||
}
|
||||
|
||||
# Safely update the servers cache
|
||||
current_servers_cache = safe_cache_get(self._shared_component_cache, "servers", {})
|
||||
if isinstance(current_servers_cache, dict):
|
||||
current_servers_cache[server_name] = cache_data
|
||||
safe_cache_set(self._shared_component_cache, "servers", current_servers_cache)
|
||||
# Safely update the servers cache
|
||||
current_servers_cache = safe_cache_get(self._shared_component_cache, "servers", {})
|
||||
if isinstance(current_servers_cache, dict):
|
||||
current_servers_cache[server_name] = cache_data
|
||||
safe_cache_set(self._shared_component_cache, "servers", current_servers_cache)
|
||||
|
||||
return tool_list, {"name": server_name, "config": server_config}
|
||||
except (TimeoutError, asyncio.TimeoutError) as e:
|
||||
msg = f"Timeout updating tool list: {e!s}"
|
||||
logger.exception(msg)
|
||||
|
|
@ -225,6 +208,8 @@ class MCPToolsComponent(ComponentWithCache):
|
|||
msg = f"Error updating tool list: {e!s}"
|
||||
logger.exception(msg)
|
||||
raise ValueError(msg) from e
|
||||
else:
|
||||
return tool_list, {"name": server_name, "config": server_config}
|
||||
|
||||
async def update_build_config(self, build_config: dict, field_value: str, field_name: str | None = None) -> dict:
|
||||
"""Toggle the visibility of connection-specific fields based on the selected mode."""
|
||||
|
|
|
|||
|
|
@ -1,5 +1,7 @@
|
|||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import contextlib
|
||||
import hashlib
|
||||
import json
|
||||
import re
|
||||
|
|
@ -14,6 +16,7 @@ from cryptography.fernet import InvalidToken
|
|||
from langchain_chroma import Chroma
|
||||
from loguru import logger
|
||||
|
||||
from langflow.base.data.kb_utils import get_knowledge_bases
|
||||
from langflow.base.models.openai_constants import OPENAI_EMBEDDING_MODEL_NAMES
|
||||
from langflow.custom import Component
|
||||
from langflow.io import BoolInput, DataFrameInput, DropdownInput, IntInput, Output, SecretStrInput, StrInput, TableInput
|
||||
|
|
@ -21,7 +24,8 @@ from langflow.schema.data import Data
|
|||
from langflow.schema.dotdict import dotdict # noqa: TC001
|
||||
from langflow.schema.table import EditMode
|
||||
from langflow.services.auth.utils import decrypt_api_key, encrypt_api_key
|
||||
from langflow.services.deps import get_settings_service
|
||||
from langflow.services.database.models.user.crud import get_user_by_id
|
||||
from langflow.services.deps import get_settings_service, get_variable_service, session_scope
|
||||
|
||||
HUGGINGFACE_MODEL_NAMES = ["sentence-transformers/all-MiniLM-L6-v2", "sentence-transformers/all-mpnet-base-v2"]
|
||||
COHERE_MODEL_NAMES = ["embed-english-v3.0", "embed-multilingual-v3.0"]
|
||||
|
|
@ -43,6 +47,10 @@ class KBIngestionComponent(Component):
|
|||
icon = "database"
|
||||
name = "KBIngestion"
|
||||
|
||||
def __init__(self, *args, **kwargs) -> None:
|
||||
super().__init__(*args, **kwargs)
|
||||
self._cached_kb_path: Path | None = None
|
||||
|
||||
@dataclass
|
||||
class NewKnowledgeBaseInput:
|
||||
functionality: str = "create"
|
||||
|
|
@ -76,7 +84,7 @@ class KBIngestionComponent(Component):
|
|||
display_name="API Key",
|
||||
info="Provider API key for embedding model",
|
||||
required=True,
|
||||
load_from_db=True,
|
||||
load_from_db=False,
|
||||
),
|
||||
},
|
||||
},
|
||||
|
|
@ -91,11 +99,7 @@ class KBIngestionComponent(Component):
|
|||
display_name="Knowledge",
|
||||
info="Select the knowledge to load data from.",
|
||||
required=True,
|
||||
options=[
|
||||
str(d.name) for d in KNOWLEDGE_BASES_ROOT_PATH.iterdir() if not d.name.startswith(".") and d.is_dir()
|
||||
]
|
||||
if KNOWLEDGE_BASES_ROOT_PATH.exists()
|
||||
else [],
|
||||
options=[],
|
||||
refresh_button=True,
|
||||
dialog_inputs=asdict(NewKnowledgeBaseInput()),
|
||||
),
|
||||
|
|
@ -329,22 +333,23 @@ class KBIngestionComponent(Component):
|
|||
|
||||
return metadata
|
||||
|
||||
def _create_vector_store(
|
||||
async def _create_vector_store(
|
||||
self, df_source: pd.DataFrame, config_list: list[dict[str, Any]], embedding_model: str, api_key: str
|
||||
) -> None:
|
||||
"""Create vector store following Local DB component pattern."""
|
||||
try:
|
||||
# Set up vector store directory
|
||||
base_dir = self._get_kb_root()
|
||||
|
||||
vector_store_dir = base_dir / self.knowledge_base
|
||||
vector_store_dir = await self._kb_path()
|
||||
if not vector_store_dir:
|
||||
msg = "Knowledge base path is not set. Please create a new knowledge base first."
|
||||
raise ValueError(msg)
|
||||
vector_store_dir.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
# Create embeddings model
|
||||
embedding_function = self._build_embeddings(embedding_model, api_key)
|
||||
|
||||
# Convert DataFrame to Data objects (following Local DB pattern)
|
||||
data_objects = self._convert_df_to_data_objects(df_source, config_list)
|
||||
data_objects = await self._convert_df_to_data_objects(df_source, config_list)
|
||||
|
||||
# Create vector store
|
||||
chroma = Chroma(
|
||||
|
|
@ -367,16 +372,18 @@ class KBIngestionComponent(Component):
|
|||
except (OSError, ValueError, RuntimeError) as e:
|
||||
self.log(f"Error creating vector store: {e}")
|
||||
|
||||
def _convert_df_to_data_objects(self, df_source: pd.DataFrame, config_list: list[dict[str, Any]]) -> list[Data]:
|
||||
async def _convert_df_to_data_objects(
|
||||
self, df_source: pd.DataFrame, config_list: list[dict[str, Any]]
|
||||
) -> list[Data]:
|
||||
"""Convert DataFrame to Data objects for vector store."""
|
||||
data_objects: list[Data] = []
|
||||
|
||||
# Set up vector store directory
|
||||
base_dir = self._get_kb_root()
|
||||
kb_path = await self._kb_path()
|
||||
|
||||
# If we don't allow duplicates, we need to get the existing hashes
|
||||
chroma = Chroma(
|
||||
persist_directory=str(base_dir / self.knowledge_base),
|
||||
persist_directory=str(kb_path),
|
||||
collection_name=self.knowledge_base,
|
||||
)
|
||||
|
||||
|
|
@ -466,10 +473,34 @@ class KBIngestionComponent(Component):
|
|||
# Check allowed characters (condition 3)
|
||||
return re.match(r"^[a-zA-Z0-9_-]+$", name) is not None
|
||||
|
||||
async def _kb_path(self) -> Path | None:
    """Return the directory of the selected knowledge base for the current user.

    The path is ``<kb root> / <username> / <knowledge_base>``, where the root
    comes from ``self._get_kb_root()`` and the username is resolved from
    ``self.user_id``. The result is stored in ``self._cached_kb_path`` and
    returned directly on later calls.

    NOTE(review): the cached value is never invalidated here, so a later change
    of ``self.knowledge_base`` on the same instance would still return the
    first path — confirm instances are not reused across knowledge bases.

    Returns:
        The cached or newly computed knowledge base path.

    Raises:
        ValueError: If ``self.user_id`` is not set or does not match a user.
    """
    # Serve the cached path when one has already been computed (or injected).
    cached_path = getattr(self, "_cached_kb_path", None)
    if cached_path is not None:
        return cached_path

    # Fail fast before opening a database session when there is no user.
    if not self.user_id:
        msg = "User ID is required for fetching knowledge base path."
        raise ValueError(msg)

    # The session is only needed to resolve the username that names the folder.
    async with session_scope() as db:
        current_user = await get_user_by_id(db, self.user_id)
        if not current_user:
            msg = f"User with ID {self.user_id} not found."
            raise ValueError(msg)
        kb_user = current_user.username

    kb_root = self._get_kb_root()

    # Cache the result so later calls skip the user lookup.
    self._cached_kb_path = kb_root / kb_user / self.knowledge_base
    return self._cached_kb_path
|
||||
|
||||
# ---------------------------------------------------------------------
|
||||
# OUTPUT METHODS
|
||||
# ---------------------------------------------------------------------
|
||||
def build_kb_info(self) -> Data:
|
||||
async def build_kb_info(self) -> Data:
|
||||
"""Main ingestion routine → returns a dict with KB metadata."""
|
||||
try:
|
||||
# Get source DataFrame
|
||||
|
|
@ -479,11 +510,11 @@ class KBIngestionComponent(Component):
|
|||
config_list = self._validate_column_config(df_source)
|
||||
column_metadata = self._build_column_metadata(config_list, df_source)
|
||||
|
||||
# Prepare KB folder (using File Component patterns)
|
||||
kb_root = self._get_kb_root()
|
||||
kb_path = kb_root / self.knowledge_base
|
||||
|
||||
# Read the embedding info from the knowledge base folder
|
||||
kb_path = await self._kb_path()
|
||||
if not kb_path:
|
||||
msg = "Knowledge base path is not set. Please create a new knowledge base first."
|
||||
raise ValueError(msg)
|
||||
metadata_path = kb_path / "embedding_metadata.json"
|
||||
|
||||
# If the API key is not provided, try to read it from the metadata file
|
||||
|
|
@ -506,7 +537,7 @@ class KBIngestionComponent(Component):
|
|||
)
|
||||
|
||||
# Create vector store following Local DB component pattern
|
||||
self._create_vector_store(df_source, config_list, embedding_model=embedding_model, api_key=api_key)
|
||||
await self._create_vector_store(df_source, config_list, embedding_model=embedding_model, api_key=api_key)
|
||||
|
||||
# Save KB files (using File Component storage patterns)
|
||||
self._save_kb_files(kb_path, config_list)
|
||||
|
|
@ -532,40 +563,77 @@ class KBIngestionComponent(Component):
|
|||
self.status = f"❌ KB ingestion failed: {e}"
|
||||
return Data(data={"error": str(e), "kb_name": self.knowledge_base})
|
||||
|
||||
def _get_knowledge_bases(self) -> list[str]:
|
||||
"""Retrieve a list of available knowledge bases.
|
||||
async def _get_api_key_variable(self, field_value: dict[str, Any]):
|
||||
async with session_scope() as db:
|
||||
if not self.user_id:
|
||||
msg = "User ID is required for fetching global variables."
|
||||
raise ValueError(msg)
|
||||
current_user = await get_user_by_id(db, self.user_id)
|
||||
if not current_user:
|
||||
msg = f"User with ID {self.user_id} not found."
|
||||
raise ValueError(msg)
|
||||
variable_service = get_variable_service()
|
||||
|
||||
Returns:
|
||||
A list of knowledge base names.
|
||||
"""
|
||||
# Return the list of directories in the knowledge base root path
|
||||
kb_root_path = self._get_kb_root()
|
||||
# Process the api_key field variable
|
||||
return await variable_service.get_variable(
|
||||
user_id=current_user.id,
|
||||
name=field_value["03_api_key"],
|
||||
field="",
|
||||
session=db,
|
||||
)
|
||||
|
||||
if not kb_root_path.exists():
|
||||
return []
|
||||
|
||||
return [str(d.name) for d in kb_root_path.iterdir() if not d.name.startswith(".") and d.is_dir()]
|
||||
|
||||
def update_build_config(self, build_config: dotdict, field_value: Any, field_name: str | None = None) -> dotdict:
|
||||
async def update_build_config(
|
||||
self,
|
||||
build_config: dotdict,
|
||||
field_value: Any,
|
||||
field_name: str | None = None,
|
||||
) -> dotdict:
|
||||
"""Update build configuration based on provider selection."""
|
||||
# Create a new knowledge base
|
||||
if field_name == "knowledge_base":
|
||||
async with session_scope() as db:
|
||||
if not self.user_id:
|
||||
msg = "User ID is required for fetching knowledge base list."
|
||||
raise ValueError(msg)
|
||||
current_user = await get_user_by_id(db, self.user_id)
|
||||
if not current_user:
|
||||
msg = f"User with ID {self.user_id} not found."
|
||||
raise ValueError(msg)
|
||||
kb_user = current_user.username
|
||||
if isinstance(field_value, dict) and "01_new_kb_name" in field_value:
|
||||
# Validate the knowledge base name - Make sure it follows these rules:
|
||||
if not self.is_valid_collection_name(field_value["01_new_kb_name"]):
|
||||
msg = f"Invalid knowledge base name: {field_value['01_new_kb_name']}"
|
||||
raise ValueError(msg)
|
||||
|
||||
# We need to test the API Key one time against the embedding model
|
||||
embed_model = self._build_embeddings(
|
||||
embedding_model=field_value["02_embedding_model"], api_key=field_value["03_api_key"]
|
||||
)
|
||||
api_key = field_value.get("03_api_key", None)
|
||||
with contextlib.suppress(Exception):
|
||||
# If the API key is a variable, resolve it
|
||||
api_key = await self._get_api_key_variable(field_value)
|
||||
|
||||
# Try to generate a dummy embedding to validate the API key
|
||||
embed_model.embed_query("test")
|
||||
# Make sure api_key is a string
|
||||
if not isinstance(api_key, str):
|
||||
msg = "API key must be a string."
|
||||
raise ValueError(msg)
|
||||
|
||||
# We need to test the API Key one time against the embedding model
|
||||
embed_model = self._build_embeddings(embedding_model=field_value["02_embedding_model"], api_key=api_key)
|
||||
|
||||
# Try to generate a dummy embedding to validate the API key without blocking the event loop
|
||||
try:
|
||||
await asyncio.wait_for(
|
||||
asyncio.to_thread(embed_model.embed_query, "test"),
|
||||
timeout=10,
|
||||
)
|
||||
except TimeoutError as e:
|
||||
msg = "Embedding validation timed out. Please verify network connectivity and key."
|
||||
raise ValueError(msg) from e
|
||||
except Exception as e:
|
||||
msg = f"Embedding validation failed: {e!s}"
|
||||
raise ValueError(msg) from e
|
||||
|
||||
# Create the new knowledge base directory
|
||||
kb_path = KNOWLEDGE_BASES_ROOT_PATH / field_value["01_new_kb_name"]
|
||||
kb_path = KNOWLEDGE_BASES_ROOT_PATH / kb_user / field_value["01_new_kb_name"]
|
||||
kb_path.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
# Save the embedding metadata
|
||||
|
|
@ -573,11 +641,16 @@ class KBIngestionComponent(Component):
|
|||
self._save_embedding_metadata(
|
||||
kb_path=kb_path,
|
||||
embedding_model=field_value["02_embedding_model"],
|
||||
api_key=field_value["03_api_key"],
|
||||
api_key=api_key,
|
||||
)
|
||||
|
||||
# Update the knowledge base options dynamically
|
||||
build_config["knowledge_base"]["options"] = self._get_knowledge_bases()
|
||||
build_config["knowledge_base"]["options"] = await get_knowledge_bases(
|
||||
KNOWLEDGE_BASES_ROOT_PATH,
|
||||
user_id=self.user_id,
|
||||
)
|
||||
|
||||
# If the selected knowledge base is not available, reset it
|
||||
if build_config["knowledge_base"]["value"] not in build_config["knowledge_base"]["options"]:
|
||||
build_config["knowledge_base"]["value"] = None
|
||||
|
||||
|
|
|
|||
|
|
@ -5,13 +5,16 @@ from typing import Any
|
|||
from cryptography.fernet import InvalidToken
|
||||
from langchain_chroma import Chroma
|
||||
from loguru import logger
|
||||
from pydantic import SecretStr
|
||||
|
||||
from langflow.base.data.kb_utils import get_knowledge_bases
|
||||
from langflow.custom import Component
|
||||
from langflow.io import BoolInput, DropdownInput, IntInput, MessageTextInput, Output, SecretStrInput
|
||||
from langflow.schema.data import Data
|
||||
from langflow.schema.dataframe import DataFrame
|
||||
from langflow.services.auth.utils import decrypt_api_key
|
||||
from langflow.services.deps import get_settings_service
|
||||
from langflow.services.database.models.user.crud import get_user_by_id
|
||||
from langflow.services.deps import get_settings_service, session_scope
|
||||
|
||||
settings = get_settings_service().settings
|
||||
knowledge_directory = settings.knowledge_bases_dir
|
||||
|
|
@ -33,11 +36,7 @@ class KBRetrievalComponent(Component):
|
|||
display_name="Knowledge",
|
||||
info="Select the knowledge to load data from.",
|
||||
required=True,
|
||||
options=[
|
||||
str(d.name) for d in KNOWLEDGE_BASES_ROOT_PATH.iterdir() if not d.name.startswith(".") and d.is_dir()
|
||||
]
|
||||
if KNOWLEDGE_BASES_ROOT_PATH.exists()
|
||||
else [],
|
||||
options=[],
|
||||
refresh_button=True,
|
||||
real_time_refresh=True,
|
||||
),
|
||||
|
|
@ -79,21 +78,13 @@ class KBRetrievalComponent(Component):
|
|||
),
|
||||
]
|
||||
|
||||
def _get_knowledge_bases(self) -> list[str]:
|
||||
"""Retrieve a list of available knowledge bases.
|
||||
|
||||
Returns:
|
||||
A list of knowledge base names.
|
||||
"""
|
||||
if not KNOWLEDGE_BASES_ROOT_PATH.exists():
|
||||
return []
|
||||
|
||||
return [str(d.name) for d in KNOWLEDGE_BASES_ROOT_PATH.iterdir() if not d.name.startswith(".") and d.is_dir()]
|
||||
|
||||
def update_build_config(self, build_config, field_value, field_name=None): # noqa: ARG002
|
||||
async def update_build_config(self, build_config, field_value, field_name=None): # noqa: ARG002
|
||||
if field_name == "knowledge_base":
|
||||
# Update the knowledge base options dynamically
|
||||
build_config["knowledge_base"]["options"] = self._get_knowledge_bases()
|
||||
build_config["knowledge_base"]["options"] = await get_knowledge_bases(
|
||||
KNOWLEDGE_BASES_ROOT_PATH,
|
||||
user_id=self.user_id, # Use the user_id from the component context
|
||||
)
|
||||
|
||||
# If the selected knowledge base is not available, reset it
|
||||
if build_config["knowledge_base"]["value"] not in build_config["knowledge_base"]["options"]:
|
||||
|
|
@ -129,15 +120,12 @@ class KBRetrievalComponent(Component):
|
|||
|
||||
def _build_embeddings(self, metadata: dict):
|
||||
"""Build embedding model from metadata."""
|
||||
runtime_api_key = self.api_key.get_secret_value() if isinstance(self.api_key, SecretStr) else self.api_key
|
||||
provider = metadata.get("embedding_provider")
|
||||
model = metadata.get("embedding_model")
|
||||
api_key = metadata.get("api_key")
|
||||
api_key = runtime_api_key or metadata.get("api_key")
|
||||
chunk_size = metadata.get("chunk_size")
|
||||
|
||||
# If user provided a key in the input, it overrides the stored one.
|
||||
if self.api_key and self.api_key.get_secret_value():
|
||||
api_key = self.api_key.get_secret_value()
|
||||
|
||||
# Handle various providers
|
||||
if provider == "OpenAI":
|
||||
from langchain_openai import OpenAIEmbeddings
|
||||
|
|
@ -174,13 +162,23 @@ class KBRetrievalComponent(Component):
|
|||
msg = f"Embedding provider '{provider}' is not supported for retrieval."
|
||||
raise NotImplementedError(msg)
|
||||
|
||||
def get_chroma_kb_data(self) -> DataFrame:
|
||||
async def get_chroma_kb_data(self) -> DataFrame:
|
||||
"""Retrieve data from the selected knowledge base by reading the Chroma collection.
|
||||
|
||||
Returns:
|
||||
A DataFrame containing the data rows from the knowledge base.
|
||||
"""
|
||||
kb_path = KNOWLEDGE_BASES_ROOT_PATH / self.knowledge_base
|
||||
# Get the current user
|
||||
async with session_scope() as db:
|
||||
if not self.user_id:
|
||||
msg = "User ID is required for fetching Knowledge Base data."
|
||||
raise ValueError(msg)
|
||||
current_user = await get_user_by_id(db, self.user_id)
|
||||
if not current_user:
|
||||
msg = f"User with ID {self.user_id} not found."
|
||||
raise ValueError(msg)
|
||||
kb_user = current_user.username
|
||||
kb_path = KNOWLEDGE_BASES_ROOT_PATH / kb_user / self.knowledge_base
|
||||
|
||||
metadata = self._get_kb_metadata(kb_path)
|
||||
if not metadata:
|
||||
|
|
|
|||
|
|
@ -1,7 +1,6 @@
|
|||
import json
|
||||
from collections.abc import AsyncIterator, Iterator
|
||||
from pathlib import Path
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
import orjson
|
||||
import pandas as pd
|
||||
|
|
@ -10,16 +9,12 @@ from fastapi.encoders import jsonable_encoder
|
|||
|
||||
from langflow.api.v2.files import upload_user_file
|
||||
from langflow.custom import Component
|
||||
from langflow.io import DropdownInput, HandleInput, SecretStrInput, StrInput
|
||||
from langflow.io import DropdownInput, HandleInput, StrInput
|
||||
from langflow.schema import Data, DataFrame, Message
|
||||
from langflow.services.auth.utils import create_user_longterm_token, get_current_user
|
||||
from langflow.services.database.models.user.crud import get_user_by_id
|
||||
from langflow.services.deps import get_session, get_settings_service, get_storage_service
|
||||
from langflow.services.deps import get_settings_service, get_storage_service, session_scope
|
||||
from langflow.template.field.base import Output
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from langflow.services.database.models.user.model import User
|
||||
|
||||
|
||||
class SaveToFileComponent(Component):
|
||||
display_name = "Save File"
|
||||
|
|
@ -55,13 +50,6 @@ class SaveToFileComponent(Component):
|
|||
value="",
|
||||
advanced=True,
|
||||
),
|
||||
SecretStrInput(
|
||||
name="api_key",
|
||||
display_name="Langflow API Key",
|
||||
info="Langflow API key for authentication when saving the file.",
|
||||
required=False,
|
||||
advanced=True,
|
||||
),
|
||||
]
|
||||
|
||||
outputs = [Output(display_name="File Path", name="message", method="save_to_file")]
|
||||
|
|
@ -148,25 +136,11 @@ class SaveToFileComponent(Component):
|
|||
raise FileNotFoundError(msg)
|
||||
|
||||
with file_path.open("rb") as f:
|
||||
async for db in get_session():
|
||||
# TODO: In 1.6, this may need to be removed or adjusted
|
||||
# Try to get the super user token, if possible
|
||||
current_user: User | None = None
|
||||
if self.api_key:
|
||||
current_user = await get_current_user(
|
||||
token="",
|
||||
query_param=self.api_key,
|
||||
header_param="",
|
||||
db=db,
|
||||
)
|
||||
else:
|
||||
user_id, _ = await create_user_longterm_token(db)
|
||||
current_user = await get_user_by_id(db, user_id)
|
||||
|
||||
# Fail if the user is not found
|
||||
if not current_user:
|
||||
msg = "User not found. Please provide a valid API key or ensure the user exists."
|
||||
async with session_scope() as db:
|
||||
if not self.user_id:
|
||||
msg = "User ID is required for file saving."
|
||||
raise ValueError(msg)
|
||||
current_user = await get_user_by_id(db, self.user_id)
|
||||
|
||||
await upload_user_file(
|
||||
file=UploadFile(filename=file_path.name, file=f, size=file_path.stat().st_size),
|
||||
|
|
|
|||
File diff suppressed because one or more lines are too long
File diff suppressed because one or more lines are too long
File diff suppressed because one or more lines are too long
File diff suppressed because one or more lines are too long
|
|
@ -1,8 +1,10 @@
|
|||
import json
|
||||
from unittest.mock import MagicMock, patch
|
||||
import uuid
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
|
||||
import pandas as pd
|
||||
import pytest
|
||||
from langflow.base.data.kb_utils import get_knowledge_bases
|
||||
from langflow.components.data.kb_ingest import KBIngestionComponent
|
||||
from langflow.schema.data import Data
|
||||
|
||||
|
|
@ -21,8 +23,43 @@ class TestKBIngestionComponent(ComponentTestBaseWithoutClient):
|
|||
with patch("langflow.components.data.kb_ingest.KNOWLEDGE_BASES_ROOT_PATH", tmp_path):
|
||||
yield
|
||||
|
||||
class MockUser:
|
||||
def __init__(self, user_id):
|
||||
self.id = user_id
|
||||
self.username = "langflow"
|
||||
|
||||
@pytest.fixture
|
||||
def default_kwargs(self, tmp_path):
|
||||
def mock_user_data(self):
|
||||
"""Create mock user data that persists for the test function."""
|
||||
mock_uuid = uuid.uuid4()
|
||||
mock_user = self.MockUser(mock_uuid)
|
||||
return {"user_id": mock_uuid, "user": mock_user.username, "user_obj": mock_user}
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def setup_mocks(self, mock_user_data):
|
||||
"""Mock the component's user_id attribute and User object."""
|
||||
with (
|
||||
patch.object(KBIngestionComponent, "user_id", mock_user_data["user_id"]),
|
||||
patch(
|
||||
"langflow.components.data.kb_ingest.get_user_by_id",
|
||||
new_callable=AsyncMock,
|
||||
return_value=mock_user_data["user_obj"],
|
||||
),
|
||||
patch(
|
||||
"langflow.base.data.kb_utils.get_user_by_id",
|
||||
new_callable=AsyncMock,
|
||||
return_value=mock_user_data["user_obj"],
|
||||
),
|
||||
):
|
||||
yield
|
||||
|
||||
@pytest.fixture
|
||||
def mock_user_id(self, mock_user_data):
|
||||
"""Get the mock user data."""
|
||||
return {"user_id": mock_user_data["user_id"], "user": mock_user_data["user"]}
|
||||
|
||||
@pytest.fixture
|
||||
def default_kwargs(self, tmp_path, mock_user_id):
|
||||
"""Return default kwargs for component instantiation."""
|
||||
# Create a sample DataFrame
|
||||
data_df = pd.DataFrame(
|
||||
|
|
@ -38,8 +75,8 @@ class TestKBIngestionComponent(ComponentTestBaseWithoutClient):
|
|||
|
||||
# Create knowledge base directory
|
||||
kb_name = "test_kb"
|
||||
kb_path = tmp_path / kb_name
|
||||
kb_path.mkdir(exist_ok=True)
|
||||
kb_path = tmp_path / mock_user_id["user"] / kb_name
|
||||
kb_path.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
# Create embedding metadata file
|
||||
metadata = {
|
||||
|
|
@ -206,7 +243,7 @@ class TestKBIngestionComponent(ComponentTestBaseWithoutClient):
|
|||
assert "text" in metadata["summary"]["vectorized_columns"]
|
||||
assert "category" in metadata["summary"]["identifier_columns"]
|
||||
|
||||
def test_convert_df_to_data_objects(self, component_class, default_kwargs):
|
||||
async def test_convert_df_to_data_objects(self, component_class, default_kwargs):
|
||||
"""Test converting DataFrame to Data objects."""
|
||||
component = component_class(**default_kwargs)
|
||||
data_df = default_kwargs["input_df"]
|
||||
|
|
@ -218,7 +255,7 @@ class TestKBIngestionComponent(ComponentTestBaseWithoutClient):
|
|||
mock_chroma_instance.get.return_value = {"metadatas": []}
|
||||
mock_chroma.return_value = mock_chroma_instance
|
||||
|
||||
data_objects = component._convert_df_to_data_objects(data_df, config_list)
|
||||
data_objects = await component._convert_df_to_data_objects(data_df, config_list)
|
||||
|
||||
assert len(data_objects) == 2
|
||||
assert all(isinstance(obj, Data) for obj in data_objects)
|
||||
|
|
@ -230,7 +267,7 @@ class TestKBIngestionComponent(ComponentTestBaseWithoutClient):
|
|||
assert "category" in first_obj.data
|
||||
assert "_id" in first_obj.data
|
||||
|
||||
def test_convert_df_to_data_objects_no_duplicates(self, component_class, default_kwargs):
|
||||
async def test_convert_df_to_data_objects_no_duplicates(self, component_class, default_kwargs):
|
||||
"""Test converting DataFrame to Data objects with duplicate prevention."""
|
||||
default_kwargs["allow_duplicates"] = False
|
||||
component = component_class(**default_kwargs)
|
||||
|
|
@ -251,7 +288,7 @@ class TestKBIngestionComponent(ComponentTestBaseWithoutClient):
|
|||
mock_hash_obj.hexdigest.side_effect = [existing_hash, "different_hash"]
|
||||
mock_hash.return_value = mock_hash_obj
|
||||
|
||||
data_objects = component._convert_df_to_data_objects(data_df, config_list)
|
||||
data_objects = await component._convert_df_to_data_objects(data_df, config_list)
|
||||
|
||||
# Should only return one object (second row) since first is duplicate
|
||||
assert len(data_objects) == 1
|
||||
|
|
@ -274,7 +311,7 @@ class TestKBIngestionComponent(ComponentTestBaseWithoutClient):
|
|||
|
||||
@patch("langflow.components.data.kb_ingest.json.loads")
|
||||
@patch("langflow.components.data.kb_ingest.decrypt_api_key")
|
||||
def test_build_kb_info_success(self, mock_decrypt, mock_json_loads, component_class, default_kwargs):
|
||||
async def test_build_kb_info_success(self, mock_decrypt, mock_json_loads, component_class, default_kwargs):
|
||||
"""Test successful KB info building."""
|
||||
component = component_class(**default_kwargs)
|
||||
|
||||
|
|
@ -287,7 +324,7 @@ class TestKBIngestionComponent(ComponentTestBaseWithoutClient):
|
|||
|
||||
# Mock vector store creation
|
||||
with patch.object(component, "_create_vector_store"), patch.object(component, "_save_kb_files"):
|
||||
result = component.build_kb_info()
|
||||
result = await component.build_kb_info()
|
||||
|
||||
assert isinstance(result, Data)
|
||||
assert "kb_id" in result.data
|
||||
|
|
@ -295,32 +332,21 @@ class TestKBIngestionComponent(ComponentTestBaseWithoutClient):
|
|||
assert "rows" in result.data
|
||||
assert result.data["rows"] == 2
|
||||
|
||||
def test_get_knowledge_bases(self, component_class, default_kwargs, tmp_path):
|
||||
async def test_get_knowledge_bases(self, tmp_path, mock_user_id):
|
||||
"""Test getting list of knowledge bases."""
|
||||
component = component_class(**default_kwargs)
|
||||
|
||||
# Create additional test directories
|
||||
(tmp_path / "kb1").mkdir()
|
||||
(tmp_path / "kb2").mkdir()
|
||||
(tmp_path / ".hidden").mkdir() # Should be ignored
|
||||
(tmp_path / mock_user_id["user"] / "kb1").mkdir(parents=True, exist_ok=True)
|
||||
(tmp_path / mock_user_id["user"] / "kb2").mkdir(parents=True, exist_ok=True)
|
||||
(tmp_path / mock_user_id["user"] / ".hidden").mkdir(parents=True, exist_ok=True) # Should be ignored
|
||||
|
||||
kb_list = component._get_knowledge_bases()
|
||||
kb_list = await get_knowledge_bases(tmp_path, user_id=mock_user_id["user_id"])
|
||||
|
||||
assert "test_kb" in kb_list
|
||||
assert "kb1" in kb_list
|
||||
assert "kb2" in kb_list
|
||||
assert ".hidden" not in kb_list
|
||||
|
||||
@patch("langflow.components.data.kb_ingest.Path.exists")
|
||||
def test_get_knowledge_bases_no_path(self, mock_exists, component_class, default_kwargs):
|
||||
"""Test getting knowledge bases when path doesn't exist."""
|
||||
component = component_class(**default_kwargs)
|
||||
mock_exists.return_value = False
|
||||
|
||||
kb_list = component._get_knowledge_bases()
|
||||
assert kb_list == []
|
||||
|
||||
def test_update_build_config_new_kb(self, component_class, default_kwargs):
|
||||
async def test_update_build_config_new_kb(self, component_class, default_kwargs):
|
||||
"""Test updating build config for new knowledge base creation."""
|
||||
component = component_class(**default_kwargs)
|
||||
|
||||
|
|
@ -329,26 +355,24 @@ class TestKBIngestionComponent(ComponentTestBaseWithoutClient):
|
|||
field_value = {
|
||||
"01_new_kb_name": "new_test_kb",
|
||||
"02_embedding_model": "sentence-transformers/all-MiniLM-L6-v2",
|
||||
"03_api_key": None,
|
||||
"03_api_key": "abc123", # Mock API key
|
||||
}
|
||||
|
||||
# Mock embedding validation
|
||||
with (
|
||||
patch.object(component, "_build_embeddings") as mock_build_emb,
|
||||
patch.object(component, "_save_embedding_metadata"),
|
||||
patch.object(component, "_get_knowledge_bases") as mock_get_kbs,
|
||||
):
|
||||
mock_embeddings = MagicMock()
|
||||
mock_embeddings.embed_query.return_value = [0.1, 0.2, 0.3]
|
||||
mock_build_emb.return_value = mock_embeddings
|
||||
mock_get_kbs.return_value = ["new_test_kb"]
|
||||
|
||||
result = component.update_build_config(build_config, field_value, "knowledge_base")
|
||||
result = await component.update_build_config(build_config, field_value, "knowledge_base")
|
||||
|
||||
assert result["knowledge_base"]["value"] == "new_test_kb"
|
||||
assert "new_test_kb" in result["knowledge_base"]["options"]
|
||||
|
||||
def test_update_build_config_invalid_kb_name(self, component_class, default_kwargs):
|
||||
async def test_update_build_config_invalid_kb_name(self, component_class, default_kwargs):
|
||||
"""Test updating build config with invalid KB name."""
|
||||
component = component_class(**default_kwargs)
|
||||
|
||||
|
|
@ -360,4 +384,4 @@ class TestKBIngestionComponent(ComponentTestBaseWithoutClient):
|
|||
}
|
||||
|
||||
with pytest.raises(ValueError, match="Invalid knowledge base name"):
|
||||
component.update_build_config(build_config, field_value, "knowledge_base")
|
||||
await component.update_build_config(build_config, field_value, "knowledge_base")
|
||||
|
|
|
|||
|
|
@ -1,10 +1,13 @@
|
|||
import contextlib
|
||||
import json
|
||||
import uuid
|
||||
from pathlib import Path
|
||||
from unittest.mock import MagicMock, patch
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
|
||||
import pytest
|
||||
from langflow.base.data.kb_utils import get_knowledge_bases
|
||||
from langflow.components.data.kb_retrieval import KBRetrievalComponent
|
||||
from pydantic import SecretStr
|
||||
|
||||
from tests.base import ComponentTestBaseWithoutClient
|
||||
|
||||
|
|
@ -21,13 +24,48 @@ class TestKBRetrievalComponent(ComponentTestBaseWithoutClient):
|
|||
with patch("langflow.components.data.kb_retrieval.KNOWLEDGE_BASES_ROOT_PATH", tmp_path):
|
||||
yield
|
||||
|
||||
class MockUser:
|
||||
def __init__(self, user_id):
|
||||
self.id = user_id
|
||||
self.username = "langflow"
|
||||
|
||||
@pytest.fixture
|
||||
def default_kwargs(self, tmp_path):
|
||||
def mock_user_data(self):
|
||||
"""Create mock user data that persists for the test function."""
|
||||
mock_uuid = uuid.uuid4()
|
||||
mock_user = self.MockUser(mock_uuid)
|
||||
return {"user_id": mock_uuid, "user": mock_user.username, "user_obj": mock_user}
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def setup_mocks(self, mock_user_data):
|
||||
"""Mock the component's user_id attribute and User object."""
|
||||
with (
|
||||
patch.object(KBRetrievalComponent, "user_id", mock_user_data["user_id"]),
|
||||
patch(
|
||||
"langflow.components.data.kb_retrieval.get_user_by_id",
|
||||
new_callable=AsyncMock,
|
||||
return_value=mock_user_data["user_obj"],
|
||||
),
|
||||
patch(
|
||||
"langflow.base.data.kb_utils.get_user_by_id",
|
||||
new_callable=AsyncMock,
|
||||
return_value=mock_user_data["user_obj"],
|
||||
),
|
||||
):
|
||||
yield
|
||||
|
||||
@pytest.fixture
|
||||
def mock_user_id(self, mock_user_data):
|
||||
"""Get the mock user data."""
|
||||
return {"user_id": mock_user_data["user_id"], "user": mock_user_data["user"]}
|
||||
|
||||
@pytest.fixture
|
||||
def default_kwargs(self, tmp_path, mock_user_id):
|
||||
"""Return default kwargs for component instantiation."""
|
||||
# Create knowledge base directory structure
|
||||
kb_name = "test_kb"
|
||||
kb_path = tmp_path / kb_name
|
||||
kb_path.mkdir(exist_ok=True)
|
||||
kb_path = tmp_path / mock_user_id["user"] / kb_name
|
||||
kb_path.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
# Create embedding metadata file
|
||||
metadata = {
|
||||
|
|
@ -55,61 +93,50 @@ class TestKBRetrievalComponent(ComponentTestBaseWithoutClient):
|
|||
# This is a new component, so it doesn't exist in older versions
|
||||
return []
|
||||
|
||||
def test_get_knowledge_bases(self, component_class, default_kwargs, tmp_path):
|
||||
async def test_get_knowledge_bases(self, tmp_path, mock_user_id):
|
||||
"""Test getting list of knowledge bases."""
|
||||
component = component_class(**default_kwargs)
|
||||
|
||||
# Create additional test directories
|
||||
(tmp_path / "kb1").mkdir()
|
||||
(tmp_path / "kb2").mkdir()
|
||||
(tmp_path / ".hidden").mkdir() # Should be ignored
|
||||
(tmp_path / mock_user_id["user"] / "kb1").mkdir(parents=True, exist_ok=True)
|
||||
(tmp_path / mock_user_id["user"] / "kb2").mkdir(parents=True, exist_ok=True)
|
||||
(tmp_path / mock_user_id["user"] / ".hidden").mkdir(parents=True, exist_ok=True) # Should be ignored
|
||||
|
||||
kb_list = component._get_knowledge_bases()
|
||||
kb_list = await get_knowledge_bases(tmp_path, user_id=mock_user_id["user_id"])
|
||||
|
||||
assert "test_kb" in kb_list
|
||||
assert "kb1" in kb_list
|
||||
assert "kb2" in kb_list
|
||||
assert ".hidden" not in kb_list
|
||||
|
||||
@patch("langflow.components.data.kb_retrieval.Path.exists")
|
||||
def test_get_knowledge_bases_no_path(self, mock_exists, component_class, default_kwargs):
|
||||
"""Test getting knowledge bases when path doesn't exist."""
|
||||
component = component_class(**default_kwargs)
|
||||
mock_exists.return_value = False
|
||||
|
||||
kb_list = component._get_knowledge_bases()
|
||||
assert kb_list == []
|
||||
|
||||
def test_update_build_config(self, component_class, default_kwargs, tmp_path):
|
||||
async def test_update_build_config(self, component_class, default_kwargs, tmp_path, mock_user_id):
|
||||
"""Test updating build configuration."""
|
||||
component = component_class(**default_kwargs)
|
||||
|
||||
# Create additional KB directories
|
||||
(tmp_path / "kb1").mkdir()
|
||||
(tmp_path / "kb2").mkdir()
|
||||
(tmp_path / mock_user_id["user"] / "kb1").mkdir(parents=True, exist_ok=True)
|
||||
(tmp_path / mock_user_id["user"] / "kb2").mkdir(parents=True, exist_ok=True)
|
||||
|
||||
build_config = {"knowledge_base": {"value": "test_kb", "options": []}}
|
||||
|
||||
result = component.update_build_config(build_config, None, "knowledge_base")
|
||||
result = await component.update_build_config(build_config, None, "knowledge_base")
|
||||
|
||||
assert "test_kb" in result["knowledge_base"]["options"]
|
||||
assert "kb1" in result["knowledge_base"]["options"]
|
||||
assert "kb2" in result["knowledge_base"]["options"]
|
||||
|
||||
def test_update_build_config_invalid_kb(self, component_class, default_kwargs):
|
||||
async def test_update_build_config_invalid_kb(self, component_class, default_kwargs):
|
||||
"""Test updating build config when selected KB is not available."""
|
||||
component = component_class(**default_kwargs)
|
||||
|
||||
build_config = {"knowledge_base": {"value": "nonexistent_kb", "options": ["test_kb"]}}
|
||||
|
||||
result = component.update_build_config(build_config, None, "knowledge_base")
|
||||
result = await component.update_build_config(build_config, None, "knowledge_base")
|
||||
|
||||
assert result["knowledge_base"]["value"] is None
|
||||
|
||||
def test_get_kb_metadata_success(self, component_class, default_kwargs):
|
||||
def test_get_kb_metadata_success(self, component_class, default_kwargs, mock_user_id):
|
||||
"""Test successful metadata loading."""
|
||||
component = component_class(**default_kwargs)
|
||||
kb_path = Path(default_kwargs["kb_root_path"]) / default_kwargs["knowledge_base"]
|
||||
kb_path = Path(default_kwargs["kb_root_path"]) / mock_user_id["user"] / default_kwargs["knowledge_base"]
|
||||
|
||||
with patch("langflow.components.data.kb_retrieval.decrypt_api_key") as mock_decrypt:
|
||||
mock_decrypt.return_value = "decrypted_key"
|
||||
|
|
@ -120,21 +147,21 @@ class TestKBRetrievalComponent(ComponentTestBaseWithoutClient):
|
|||
assert metadata["embedding_model"] == "sentence-transformers/all-MiniLM-L6-v2"
|
||||
assert "chunk_size" in metadata
|
||||
|
||||
def test_get_kb_metadata_no_file(self, component_class, default_kwargs, tmp_path):
|
||||
def test_get_kb_metadata_no_file(self, component_class, default_kwargs, tmp_path, mock_user_id):
|
||||
"""Test metadata loading when file doesn't exist."""
|
||||
component = component_class(**default_kwargs)
|
||||
nonexistent_path = tmp_path / "nonexistent"
|
||||
nonexistent_path.mkdir()
|
||||
nonexistent_path = tmp_path / mock_user_id["user"] / "nonexistent"
|
||||
nonexistent_path.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
metadata = component._get_kb_metadata(nonexistent_path)
|
||||
|
||||
assert metadata == {}
|
||||
|
||||
def test_get_kb_metadata_json_error(self, component_class, default_kwargs, tmp_path):
|
||||
def test_get_kb_metadata_json_error(self, component_class, default_kwargs, tmp_path, mock_user_id):
|
||||
"""Test metadata loading with invalid JSON."""
|
||||
component = component_class(**default_kwargs)
|
||||
kb_path = tmp_path / "invalid_json_kb"
|
||||
kb_path.mkdir()
|
||||
kb_path = tmp_path / mock_user_id["user"] / "invalid_json_kb"
|
||||
kb_path.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
# Create invalid JSON file
|
||||
(kb_path / "embedding_metadata.json").write_text("invalid json content")
|
||||
|
|
@ -143,11 +170,11 @@ class TestKBRetrievalComponent(ComponentTestBaseWithoutClient):
|
|||
|
||||
assert metadata == {}
|
||||
|
||||
def test_get_kb_metadata_decrypt_error(self, component_class, default_kwargs, tmp_path):
|
||||
def test_get_kb_metadata_decrypt_error(self, component_class, default_kwargs, tmp_path, mock_user_id):
|
||||
"""Test metadata loading with decryption error."""
|
||||
component = component_class(**default_kwargs)
|
||||
kb_path = tmp_path / "decrypt_error_kb"
|
||||
kb_path.mkdir()
|
||||
kb_path = tmp_path / mock_user_id["user"] / "decrypt_error_kb"
|
||||
kb_path.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
# Create metadata with encrypted key
|
||||
metadata = {
|
||||
|
|
@ -274,10 +301,8 @@ class TestKBRetrievalComponent(ComponentTestBaseWithoutClient):
|
|||
|
||||
def test_build_embeddings_with_user_api_key(self, component_class, default_kwargs):
|
||||
"""Test that user-provided API key overrides stored one."""
|
||||
# Create a mock secret input
|
||||
|
||||
mock_secret = MagicMock()
|
||||
mock_secret.get_secret_value.return_value = "user-provided-key"
|
||||
# Use a real SecretStr object instead of a mock
|
||||
mock_secret = SecretStr("user-provided-key")
|
||||
|
||||
default_kwargs["api_key"] = mock_secret
|
||||
component = component_class(**default_kwargs)
|
||||
|
|
@ -285,7 +310,7 @@ class TestKBRetrievalComponent(ComponentTestBaseWithoutClient):
|
|||
metadata = {
|
||||
"embedding_provider": "OpenAI",
|
||||
"embedding_model": "text-embedding-ada-002",
|
||||
"api_key": "stored-key",
|
||||
"api_key": "stored-key", # This should be overridden by the user-provided key
|
||||
"chunk_size": 1000,
|
||||
}
|
||||
|
||||
|
|
@ -295,14 +320,17 @@ class TestKBRetrievalComponent(ComponentTestBaseWithoutClient):
|
|||
|
||||
component._build_embeddings(metadata)
|
||||
|
||||
# The user-provided key should override the stored key in metadata
|
||||
mock_openai.assert_called_once_with(
|
||||
model="text-embedding-ada-002", api_key="user-provided-key", chunk_size=1000
|
||||
model="text-embedding-ada-002",
|
||||
api_key="user-provided-key", # Should use the user-provided key, not "stored-key"
|
||||
chunk_size=1000,
|
||||
)
|
||||
|
||||
def test_get_chroma_kb_data_no_metadata(self, component_class, default_kwargs, tmp_path):
|
||||
async def test_get_chroma_kb_data_no_metadata(self, component_class, default_kwargs, tmp_path, mock_user_id):
|
||||
"""Test retrieving data when metadata is missing."""
|
||||
# Remove metadata file
|
||||
kb_path = tmp_path / default_kwargs["knowledge_base"]
|
||||
kb_path = tmp_path / mock_user_id["user"] / default_kwargs["knowledge_base"]
|
||||
metadata_file = kb_path / "embedding_metadata.json"
|
||||
if metadata_file.exists():
|
||||
metadata_file.unlink()
|
||||
|
|
@ -310,7 +338,7 @@ class TestKBRetrievalComponent(ComponentTestBaseWithoutClient):
|
|||
component = component_class(**default_kwargs)
|
||||
|
||||
with pytest.raises(ValueError, match="Metadata not found for knowledge base"):
|
||||
component.get_chroma_kb_data()
|
||||
await component.get_chroma_kb_data()
|
||||
|
||||
def test_get_chroma_kb_data_path_construction(self, component_class, default_kwargs):
|
||||
"""Test that get_chroma_kb_data constructs the correct paths."""
|
||||
|
|
@ -331,7 +359,7 @@ class TestKBRetrievalComponent(ComponentTestBaseWithoutClient):
|
|||
assert hasattr(component, "top_k")
|
||||
assert hasattr(component, "include_embeddings")
|
||||
|
||||
def test_get_chroma_kb_data_method_exists(self, component_class, default_kwargs):
|
||||
async def test_get_chroma_kb_data_method_exists(self, component_class, default_kwargs):
|
||||
"""Test that get_chroma_kb_data method exists and can be called."""
|
||||
component = component_class(**default_kwargs)
|
||||
|
||||
|
|
@ -349,7 +377,7 @@ class TestKBRetrievalComponent(ComponentTestBaseWithoutClient):
|
|||
|
||||
# This is a unit test focused on the component's internal logic
|
||||
with contextlib.suppress(Exception):
|
||||
component.get_chroma_kb_data()
|
||||
await component.get_chroma_kb_data()
|
||||
|
||||
# Verify internal methods were called
|
||||
mock_get_metadata.assert_called_once()
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue