diff --git a/.github/workflows/py_autofix.yml b/.github/workflows/py_autofix.yml index 25d7e4dab..453b78119 100644 --- a/.github/workflows/py_autofix.yml +++ b/.github/workflows/py_autofix.yml @@ -14,7 +14,7 @@ jobs: - uses: actions/checkout@v4 - name: "Setup Environment" uses: ./.github/actions/setup-uv - - run: uv run ruff check --fix-only . --ignore A005 + - run: uv run ruff check --fix-only . - run: uv run ruff format . --config pyproject.toml - uses: autofix-ci/action@551dded8c6cc8a1054039c8bc0b8b48c51dfc6ef - name: Minimize uv cache diff --git a/.github/workflows/style-check-py.yml b/.github/workflows/style-check-py.yml index 0157df43f..e997e07a7 100644 --- a/.github/workflows/style-check-py.yml +++ b/.github/workflows/style-check-py.yml @@ -24,6 +24,6 @@ jobs: - name: Register problem matcher run: echo "::add-matcher::.github/workflows/matchers/ruff.json" - name: Run Ruff Check - run: uv run --only-dev ruff check --output-format=github . --ignore A005 + run: uv run --only-dev ruff check --output-format=github . - name: Minimize uv cache run: uv cache prune --ci diff --git a/Makefile b/Makefile index 2e76dfa6c..657258ab8 100644 --- a/Makefile +++ b/Makefile @@ -198,7 +198,7 @@ fix_codespell: ## run codespell to fix spelling errors poetry run codespell --toml pyproject.toml --write format_backend: ## backend code formatters - @uv run ruff check . --fix --ignore EXE002 --ignore A005 + @uv run ruff check . --fix @uv run ruff format . --config pyproject.toml format_frontend: ## frontend code formatters diff --git a/pyproject.toml b/pyproject.toml index 7cdc279a8..d9e21462c 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -121,7 +121,7 @@ dev = [ "types-redis>=4.6.0.5", "ipykernel>=6.29.0", "mypy>=1.11.0", - "ruff>=0.9.1,<0.10", + "ruff>=0.9.7,<0.10", "httpx>=0.27.0", "pytest>=8.2.0", "types-requests>=2.32.0", diff --git a/src/backend/base/langflow/api/v1/endpoints.py b/src/backend/base/langflow/api/v1/endpoints.py index 3a6ff403a..7edab2735 100644 --- a/src/backend/base/langflow/api/v1/endpoints.py +++ b/src/backend/base/langflow/api/v1/endpoints.py @@ -580,14 +580,14 @@ async def experimental_run_flow( @router.post( - "/predict/{flow_id}", + "/predict/{_flow_id}", dependencies=[Depends(api_key_security)], ) @router.post( - "/process/{flow_id}", + "/process/{_flow_id}", dependencies=[Depends(api_key_security)], ) -async def process() -> None: +async def process(_flow_id) -> None: """Endpoint to process an input with a given flow_id.""" # Raise a depreciation warning logger.warning( diff --git a/src/backend/base/langflow/base/io/__init__.py b/src/backend/base/langflow/base/io/__init__.py index e69de29bb..dc9fd4c06 100644 --- a/src/backend/base/langflow/base/io/__init__.py +++ b/src/backend/base/langflow/base/io/__init__.py @@ -0,0 +1 @@ +# noqa: A005 diff --git a/src/backend/base/langflow/components/apify/apify_actor.py b/src/backend/base/langflow/components/apify/apify_actor.py index 7d6c6414a..1eb17051d 100644 --- a/src/backend/base/langflow/components/apify/apify_actor.py +++ b/src/backend/base/langflow/components/apify/apify_actor.py @@ -89,9 +89,9 @@ class ApifyActorsComponent(Component): def run_model(self) -> list[Data]: """Run the Actor and return node output.""" - _input = json.loads(self.run_input) + input_ = json.loads(self.run_input) fields = ApifyActorsComponent.parse_dataset_fields(self.dataset_fields) if self.dataset_fields else None - res = self._run_actor(self.actor_id, _input, fields=fields) + res = self._run_actor(self.actor_id, input_, fields=fields) if self.flatten_dataset: res = [ApifyActorsComponent.flatten(item) for item in res] data = [Data(data=item) for item in res] @@ -113,16 +113,16 @@ class ApifyActorsComponent(Component): properties = {"run_input": properties} # works from input schema - _info = [ + info_ = [ ( "JSON encoded as a string with input schema (STRICTLY FOLLOW JSON FORMAT AND SCHEMA):\n\n" f"{json.dumps(properties, separators=(',', ':'))}" ) ] if required: - _info.append("\n\nRequired fields:\n" + "\n".join(required)) + info_.append("\n\nRequired fields:\n" + "\n".join(required)) - info = "".join(_info) + info = "".join(info_) input_model_cls = ApifyActorsComponent.create_input_model_class(info) tool_cls = ApifyActorsComponent.create_tool_class(self, readme, input_model_cls, actor_id) diff --git a/src/backend/base/langflow/components/composio/composio_api.py b/src/backend/base/langflow/components/composio/composio_api.py index 8dfd7c249..fcd69e065 100644 --- a/src/backend/base/langflow/components/composio/composio_api.py +++ b/src/backend/base/langflow/components/composio/composio_api.py @@ -158,7 +158,7 @@ class ComposioAPIComponent(LCToolComponent): for item in data.get("items", []): for auth_scheme in item.get("auth_schemes", []): - if auth_scheme.get("mode") in ["OAUTH1", "OAUTH2"]: + if auth_scheme.get("mode") in {"OAUTH1", "OAUTH2"}: oauth_apps.append(item["key"].upper()) break except requests.RequestException as e: diff --git a/src/backend/base/langflow/components/data/api_request.py b/src/backend/base/langflow/components/data/api_request.py index 27d691725..56dcd1351 100644 --- a/src/backend/base/langflow/components/data/api_request.py +++ b/src/backend/base/langflow/components/data/api_request.py @@ -321,7 +321,7 @@ class APIRequestComponent(Component): elif field_name == "curl": field_config["advanced"] = not use_curl field_config["real_time_refresh"] = use_curl - elif field_name in ["body", "headers"]: + elif field_name in {"body", "headers"}: field_config["advanced"] = True # Always keep body and headers in advanced when use_curl is False else: field_config["advanced"] = use_curl @@ -359,7 +359,7 @@ class APIRequestComponent(Component): if field_name in common_fields: field_config["advanced"] = False elif field_name in body_fields: - field_config["advanced"] = method not in ["POST", "PUT", "PATCH"] + field_config["advanced"] = method not in {"POST", "PUT", "PATCH"} elif field_name in always_advanced_fields: field_config["advanced"] = True else: diff --git a/src/backend/base/langflow/components/helpers/batch_run.py b/src/backend/base/langflow/components/helpers/batch_run.py index 5aeb509a1..7e6468faa 100644 --- a/src/backend/base/langflow/components/helpers/batch_run.py +++ b/src/backend/base/langflow/components/helpers/batch_run.py @@ -1,5 +1,6 @@ from __future__ import annotations +import operator from typing import TYPE_CHECKING, Any from loguru import logger @@ -165,7 +166,7 @@ class BatchRunComponent(Component): ] # Sort by index to maintain order - responses_with_idx.sort(key=lambda x: x[0]) + responses_with_idx.sort(key=operator.itemgetter(0)) # Build the final data with enhanced metadata rows: list[dict[str, Any]] = [] diff --git a/src/backend/base/langflow/components/models/anthropic.py b/src/backend/base/langflow/components/models/anthropic.py index 3a963ccc8..feb1514af 100644 --- a/src/backend/base/langflow/components/models/anthropic.py +++ b/src/backend/base/langflow/components/models/anthropic.py @@ -138,7 +138,7 @@ class AnthropicModelComponent(LCModelComponent): return None def update_build_config(self, build_config: dotdict, field_value: Any, field_name: str | None = None): - if field_name in ("base_url", "model_name", "tool_model_enabled", "api_key") and field_value: + if field_name in {"base_url", "model_name", "tool_model_enabled", "api_key"} and field_value: try: if len(self.api_key) == 0: ids = ANTHROPIC_MODELS diff --git a/src/backend/base/langflow/components/models/google_generative_ai.py b/src/backend/base/langflow/components/models/google_generative_ai.py index 51a1f6762..cd8fa208b 100644 --- a/src/backend/base/langflow/components/models/google_generative_ai.py +++ b/src/backend/base/langflow/components/models/google_generative_ai.py @@ -129,7 +129,7 @@ class GoogleGenerativeAIComponent(LCModelComponent): return model_ids def update_build_config(self, build_config: dotdict, field_value: Any, field_name: str | None = None): - if field_name in ("base_url", "model_name", "tool_model_enabled", "api_key") and field_value: + if field_name in {"base_url", "model_name", "tool_model_enabled", "api_key"} and field_value: try: if len(self.api_key) == 0: ids = GOOGLE_GENERATIVE_AI_MODELS diff --git a/src/backend/base/langflow/components/models/groq.py b/src/backend/base/langflow/components/models/groq.py index ebaa27228..5d94df090 100644 --- a/src/backend/base/langflow/components/models/groq.py +++ b/src/backend/base/langflow/components/models/groq.py @@ -105,7 +105,7 @@ class GroqModel(LCModelComponent): return model_ids def update_build_config(self, build_config: dict, field_value: str, field_name: str | None = None): - if field_name in ("base_url", "model_name", "tool_model_enabled", "api_key") and field_value: + if field_name in {"base_url", "model_name", "tool_model_enabled", "api_key"} and field_value: try: if len(self.api_key) != 0: try: diff --git a/src/backend/base/langflow/components/models/nvidia.py b/src/backend/base/langflow/components/models/nvidia.py index 3b3063559..132a4df0c 100644 --- a/src/backend/base/langflow/components/models/nvidia.py +++ b/src/backend/base/langflow/components/models/nvidia.py @@ -78,7 +78,7 @@ class NVIDIAModelComponent(LCModelComponent): return [model.id for model in build_model.available_models] def update_build_config(self, build_config: dotdict, field_value: Any, field_name: str | None = None): - if field_name in ("base_url", "model_name", "tool_model_enabled", "api_key") and field_value: + if field_name in {"base_url", "model_name", "tool_model_enabled", "api_key"} and field_value: try: ids = self.get_models(self.tool_model_enabled) build_config["model_name"]["options"] = ids diff --git a/src/backend/base/langflow/components/nvidia/nvidia_ingest.py b/src/backend/base/langflow/components/nvidia/nvidia_ingest.py index 158339186..6f0c802a7 100644 --- a/src/backend/base/langflow/components/nvidia/nvidia_ingest.py +++ b/src/backend/base/langflow/components/nvidia/nvidia_ingest.py @@ -231,5 +231,5 @@ class NvidiaIngestComponent(Component): # image is not yet supported; skip if encountered self.log(f"Unsupported document type: {document_type}", name="NVIDIAIngestComponent") - self.status = data if data else "No data" + self.status = data or "No data" return data diff --git a/src/backend/base/langflow/components/processing/save_to_file.py b/src/backend/base/langflow/components/processing/save_to_file.py index 9494c185d..d595d74bf 100644 --- a/src/backend/base/langflow/components/processing/save_to_file.py +++ b/src/backend/base/langflow/components/processing/save_to_file.py @@ -87,7 +87,7 @@ class SaveToFileComponent(Component): build_config["data"]["show"] = field_value == "Data" build_config["message"]["show"] = field_value == "Message" - if field_value in ["DataFrame", "Data"]: + if field_value in {"DataFrame", "Data"}: build_config["file_format"]["options"] = self.DATA_FORMAT_CHOICES elif field_value == "Message": build_config["file_format"]["options"] = self.MESSAGE_FORMAT_CHOICES diff --git a/src/backend/base/langflow/components/tools/arxiv.py b/src/backend/base/langflow/components/tools/arxiv.py index 8f5cdf63d..dde27e4f3 100644 --- a/src/backend/base/langflow/components/tools/arxiv.py +++ b/src/backend/base/langflow/components/tools/arxiv.py @@ -112,7 +112,7 @@ class ArXivComponent(Component): # Validate URL scheme and host parsed_url = urlparse(url) - if parsed_url.scheme not in ("http", "https"): + if parsed_url.scheme not in {"http", "https"}: error_msg = f"Invalid URL scheme: {parsed_url.scheme}" raise ValueError(error_msg) if parsed_url.hostname != "export.arxiv.org": diff --git a/src/backend/base/langflow/components/vectorstores/astradb.py b/src/backend/base/langflow/components/vectorstores/astradb.py index b254c233d..0f7b16491 100644 --- a/src/backend/base/langflow/components/vectorstores/astradb.py +++ b/src/backend/base/langflow/components/vectorstores/astradb.py @@ -571,7 +571,7 @@ class AstraDBVectorStoreComponent(LCVectorStoreComponent): # Go over each possible provider and add metadata to configure in Astra DB Portal for provider in provider_options: # Skip Bring your own and Nvidia, automatically configured - if provider in ["Bring your own", "Nvidia"]: + if provider in {"Bring your own", "Nvidia"}: build_config["collection_name"]["dialog_inputs"]["fields"]["data"]["node"]["template"][ "embedding_generation_provider" ]["options_metadata"].append({"icon": self.get_provider_icon(provider_name=provider.lower())}) @@ -601,7 +601,7 @@ class AstraDBVectorStoreComponent(LCVectorStoreComponent): # If we retrieved options based on the token, show the dropdown build_config["collection_name"]["options"] = [col["name"] for col in collection_options] build_config["collection_name"]["options_metadata"] = [ - {k: v for k, v in col.items() if k not in ["name"]} for col in collection_options + {k: v for k, v in col.items() if k != "name"} for col in collection_options ] # Reset the selected collection @@ -620,7 +620,7 @@ class AstraDBVectorStoreComponent(LCVectorStoreComponent): # If we retrieved options based on the token, show the dropdown build_config["database_name"]["options"] = [db["name"] for db in database_options] build_config["database_name"]["options_metadata"] = [ - {k: v for k, v in db.items() if k not in ["name"]} for db in database_options + {k: v for k, v in db.items() if k != "name"} for db in database_options ] # Reset the selected database @@ -667,12 +667,8 @@ class AstraDBVectorStoreComponent(LCVectorStoreComponent): raise ValueError(msg) from e # Add the new database to the list of options - build_config["database_name"]["options"] = build_config["database_name"]["options"] + [ - field_value["new_database_name"] - ] - build_config["database_name"]["options_metadata"] = build_config["database_name"]["options_metadata"] + [ - {"status": "PENDING"} - ] + build_config["database_name"]["options"] += [field_value["new_database_name"]] + build_config["database_name"]["options_metadata"] += [{"status": "PENDING"}] return self.reset_collection_list(build_config) @@ -726,9 +722,9 @@ class AstraDBVectorStoreComponent(LCVectorStoreComponent): # Add the new collection to the list of options icon = "NVIDIA" if provider == "Nvidia" else "vectorstores" - build_config["collection_name"]["options_metadata"] = build_config["collection_name"][ - "options_metadata" - ] + [{"records": 0, "provider": provider, "icon": icon, "model": model}] + build_config["collection_name"]["options_metadata"] += [ + {"records": 0, "provider": provider, "icon": icon, "model": model} + ] return build_config @@ -748,7 +744,7 @@ class AstraDBVectorStoreComponent(LCVectorStoreComponent): return self.reset_build_config(build_config) # If this is the first execution of the component, reset and build database list - if first_run or field_name in ["token", "environment"]: + if first_run or field_name in {"token", "environment"}: return self.reset_database_list(build_config) # Refresh the collection name options @@ -790,7 +786,12 @@ class AstraDBVectorStoreComponent(LCVectorStoreComponent): # Add the new collection to the list of options build_config["collection_name"]["options"].append(field_value) build_config["collection_name"]["options_metadata"].append( - {"records": 0, "provider": None, "icon": "", "model": None} + { + "records": 0, + "provider": None, + "icon": "", + "model": None, + } ) # Ensure that autodetect collection is set to False, since its a new collection diff --git a/src/backend/base/langflow/components/youtube/video_details.py b/src/backend/base/langflow/components/youtube/video_details.py index 013d1d465..0e8332e92 100644 --- a/src/backend/base/langflow/components/youtube/video_details.py +++ b/src/backend/base/langflow/components/youtube/video_details.py @@ -230,7 +230,7 @@ class YouTubeVideoDetailsComponent(Component): thumb_cols = [col for col in video_df.columns if col.startswith("thumbnail_")] # Reorder columns based on what's included - ordered_cols = basic_cols[:] + ordered_cols = basic_cols.copy() if self.include_statistics: ordered_cols.extend([col for col in stat_cols if col in video_df.columns]) diff --git a/src/backend/base/langflow/graph/graph/utils.py b/src/backend/base/langflow/graph/graph/utils.py index 26650f343..d29dbc804 100644 --- a/src/backend/base/langflow/graph/graph/utils.py +++ b/src/backend/base/langflow/graph/graph/utils.py @@ -518,7 +518,7 @@ def layered_topological_sort( layers: list[list[str]] = [] visited = set() - cycle_counts = {vertex: 0 for vertex in vertices_ids} + cycle_counts = dict.fromkeys(vertices_ids, 0) current_layer = 0 # Process the first layer separately to avoid duplicates diff --git a/src/backend/base/langflow/helpers/flow.py b/src/backend/base/langflow/helpers/flow.py index c2b543769..6846c3428 100644 --- a/src/backend/base/langflow/helpers/flow.py +++ b/src/backend/base/langflow/helpers/flow.py @@ -322,7 +322,7 @@ def json_schema_from_flow(flow: Flow) -> dict: from langflow.graph.graph.base import Graph # Get the flow's data which contains the nodes and their configurations - flow_data = flow.data if flow.data else {} + flow_data = flow.data or {} graph = Graph.from_payload(flow_data) input_nodes = [vertex for vertex in graph.vertices if vertex.is_input] diff --git a/src/backend/base/langflow/initial_setup/starter_projects/Custom Component Maker.json b/src/backend/base/langflow/initial_setup/starter_projects/Custom Component Maker.json index e3d3e1396..1dd44db62 100644 --- a/src/backend/base/langflow/initial_setup/starter_projects/Custom Component Maker.json +++ b/src/backend/base/langflow/initial_setup/starter_projects/Custom Component Maker.json @@ -1972,7 +1972,7 @@ "show": true, "title_case": false, "type": "code", - "value": "from typing import Any\n\nimport requests\nfrom loguru import logger\n\nfrom langflow.base.models.anthropic_constants import ANTHROPIC_MODELS\nfrom langflow.base.models.model import LCModelComponent\nfrom langflow.field_typing import LanguageModel\nfrom langflow.field_typing.range_spec import RangeSpec\nfrom langflow.io import BoolInput, DropdownInput, IntInput, MessageTextInput, SecretStrInput, SliderInput\nfrom langflow.schema.dotdict import dotdict\n\n\nclass AnthropicModelComponent(LCModelComponent):\n display_name = \"Anthropic\"\n description = \"Generate text using Anthropic Chat&Completion LLMs with prefill support.\"\n icon = \"Anthropic\"\n name = \"AnthropicModel\"\n\n inputs = [\n *LCModelComponent._base_inputs,\n IntInput(\n name=\"max_tokens\",\n display_name=\"Max Tokens\",\n advanced=True,\n value=4096,\n info=\"The maximum number of tokens to generate. Set to 0 for unlimited tokens.\",\n ),\n DropdownInput(\n name=\"model_name\",\n display_name=\"Model Name\",\n options=ANTHROPIC_MODELS,\n refresh_button=True,\n value=ANTHROPIC_MODELS[0],\n combobox=True,\n ),\n SecretStrInput(\n name=\"api_key\",\n display_name=\"Anthropic API Key\",\n info=\"Your Anthropic API key.\",\n value=None,\n required=True,\n real_time_refresh=True,\n ),\n SliderInput(\n name=\"temperature\",\n display_name=\"Temperature\",\n value=0.1,\n info=\"Run inference with this temperature. Must by in the closed interval [0.0, 1.0].\",\n range_spec=RangeSpec(min=0, max=1, step=0.01),\n ),\n MessageTextInput(\n name=\"base_url\",\n display_name=\"Anthropic API URL\",\n info=\"Endpoint of the Anthropic API. Defaults to 'https://api.anthropic.com' if not specified.\",\n value=\"https://api.anthropic.com\",\n real_time_refresh=True,\n ),\n BoolInput(\n name=\"tool_model_enabled\",\n display_name=\"Enable Tool Models\",\n info=(\n \"Select if you want to use models that can work with tools. If yes, only those models will be shown.\"\n ),\n advanced=False,\n value=False,\n real_time_refresh=True,\n ),\n MessageTextInput(\n name=\"prefill\", display_name=\"Prefill\", info=\"Prefill text to guide the model's response.\", advanced=True\n ),\n ]\n\n def build_model(self) -> LanguageModel: # type: ignore[type-var]\n try:\n from langchain_anthropic.chat_models import ChatAnthropic\n except ImportError as e:\n msg = \"langchain_anthropic is not installed. Please install it with `pip install langchain_anthropic`.\"\n raise ImportError(msg) from e\n try:\n output = ChatAnthropic(\n model=self.model_name,\n anthropic_api_key=self.api_key,\n max_tokens_to_sample=self.max_tokens,\n temperature=self.temperature,\n anthropic_api_url=self.base_url,\n streaming=self.stream,\n )\n except Exception as e:\n msg = \"Could not connect to Anthropic API.\"\n raise ValueError(msg) from e\n\n return output\n\n def get_models(self, tool_model_enabled: bool | None = None) -> list[str]:\n try:\n import anthropic\n\n client = anthropic.Anthropic(api_key=self.api_key)\n models = client.models.list(limit=20).data\n model_ids = [model.id for model in models]\n except (ImportError, ValueError, requests.exceptions.RequestException) as e:\n logger.exception(f\"Error getting model names: {e}\")\n model_ids = ANTHROPIC_MODELS\n if tool_model_enabled:\n try:\n from langchain_anthropic.chat_models import ChatAnthropic\n except ImportError as e:\n msg = \"langchain_anthropic is not installed. Please install it with `pip install langchain_anthropic`.\"\n raise ImportError(msg) from e\n for model in model_ids:\n model_with_tool = ChatAnthropic(\n model=self.model_name,\n anthropic_api_key=self.api_key,\n anthropic_api_url=self.base_url,\n )\n if not self.supports_tool_calling(model_with_tool):\n model_ids.remove(model)\n return model_ids\n\n def _get_exception_message(self, exception: Exception) -> str | None:\n \"\"\"Get a message from an Anthropic exception.\n\n Args:\n exception (Exception): The exception to get the message from.\n\n Returns:\n str: The message from the exception.\n \"\"\"\n try:\n from anthropic import BadRequestError\n except ImportError:\n return None\n if isinstance(exception, BadRequestError):\n message = exception.body.get(\"error\", {}).get(\"message\")\n if message:\n return message\n return None\n\n def update_build_config(self, build_config: dotdict, field_value: Any, field_name: str | None = None):\n if field_name in (\"base_url\", \"model_name\", \"tool_model_enabled\", \"api_key\") and field_value:\n try:\n if len(self.api_key) == 0:\n ids = ANTHROPIC_MODELS\n else:\n try:\n ids = self.get_models(tool_model_enabled=self.tool_model_enabled)\n except (ImportError, ValueError, requests.exceptions.RequestException) as e:\n logger.exception(f\"Error getting model names: {e}\")\n ids = ANTHROPIC_MODELS\n build_config[\"model_name\"][\"options\"] = ids\n build_config[\"model_name\"][\"value\"] = ids[0]\n except Exception as e:\n msg = f\"Error getting model names: {e}\"\n raise ValueError(msg) from e\n return build_config\n" + "value": "from typing import Any\n\nimport requests\nfrom loguru import logger\n\nfrom langflow.base.models.anthropic_constants import ANTHROPIC_MODELS\nfrom langflow.base.models.model import LCModelComponent\nfrom langflow.field_typing import LanguageModel\nfrom langflow.field_typing.range_spec import RangeSpec\nfrom langflow.io import BoolInput, DropdownInput, IntInput, MessageTextInput, SecretStrInput, SliderInput\nfrom langflow.schema.dotdict import dotdict\n\n\nclass AnthropicModelComponent(LCModelComponent):\n display_name = \"Anthropic\"\n description = \"Generate text using Anthropic Chat&Completion LLMs with prefill support.\"\n icon = \"Anthropic\"\n name = \"AnthropicModel\"\n\n inputs = [\n *LCModelComponent._base_inputs,\n IntInput(\n name=\"max_tokens\",\n display_name=\"Max Tokens\",\n advanced=True,\n value=4096,\n info=\"The maximum number of tokens to generate. Set to 0 for unlimited tokens.\",\n ),\n DropdownInput(\n name=\"model_name\",\n display_name=\"Model Name\",\n options=ANTHROPIC_MODELS,\n refresh_button=True,\n value=ANTHROPIC_MODELS[0],\n combobox=True,\n ),\n SecretStrInput(\n name=\"api_key\",\n display_name=\"Anthropic API Key\",\n info=\"Your Anthropic API key.\",\n value=None,\n required=True,\n real_time_refresh=True,\n ),\n SliderInput(\n name=\"temperature\",\n display_name=\"Temperature\",\n value=0.1,\n info=\"Run inference with this temperature. Must by in the closed interval [0.0, 1.0].\",\n range_spec=RangeSpec(min=0, max=1, step=0.01),\n ),\n MessageTextInput(\n name=\"base_url\",\n display_name=\"Anthropic API URL\",\n info=\"Endpoint of the Anthropic API. Defaults to 'https://api.anthropic.com' if not specified.\",\n value=\"https://api.anthropic.com\",\n real_time_refresh=True,\n ),\n BoolInput(\n name=\"tool_model_enabled\",\n display_name=\"Enable Tool Models\",\n info=(\n \"Select if you want to use models that can work with tools. If yes, only those models will be shown.\"\n ),\n advanced=False,\n value=False,\n real_time_refresh=True,\n ),\n MessageTextInput(\n name=\"prefill\", display_name=\"Prefill\", info=\"Prefill text to guide the model's response.\", advanced=True\n ),\n ]\n\n def build_model(self) -> LanguageModel: # type: ignore[type-var]\n try:\n from langchain_anthropic.chat_models import ChatAnthropic\n except ImportError as e:\n msg = \"langchain_anthropic is not installed. Please install it with `pip install langchain_anthropic`.\"\n raise ImportError(msg) from e\n try:\n output = ChatAnthropic(\n model=self.model_name,\n anthropic_api_key=self.api_key,\n max_tokens_to_sample=self.max_tokens,\n temperature=self.temperature,\n anthropic_api_url=self.base_url,\n streaming=self.stream,\n )\n except Exception as e:\n msg = \"Could not connect to Anthropic API.\"\n raise ValueError(msg) from e\n\n return output\n\n def get_models(self, tool_model_enabled: bool | None = None) -> list[str]:\n try:\n import anthropic\n\n client = anthropic.Anthropic(api_key=self.api_key)\n models = client.models.list(limit=20).data\n model_ids = [model.id for model in models]\n except (ImportError, ValueError, requests.exceptions.RequestException) as e:\n logger.exception(f\"Error getting model names: {e}\")\n model_ids = ANTHROPIC_MODELS\n if tool_model_enabled:\n try:\n from langchain_anthropic.chat_models import ChatAnthropic\n except ImportError as e:\n msg = \"langchain_anthropic is not installed. Please install it with `pip install langchain_anthropic`.\"\n raise ImportError(msg) from e\n for model in model_ids:\n model_with_tool = ChatAnthropic(\n model=self.model_name,\n anthropic_api_key=self.api_key,\n anthropic_api_url=self.base_url,\n )\n if not self.supports_tool_calling(model_with_tool):\n model_ids.remove(model)\n return model_ids\n\n def _get_exception_message(self, exception: Exception) -> str | None:\n \"\"\"Get a message from an Anthropic exception.\n\n Args:\n exception (Exception): The exception to get the message from.\n\n Returns:\n str: The message from the exception.\n \"\"\"\n try:\n from anthropic import BadRequestError\n except ImportError:\n return None\n if isinstance(exception, BadRequestError):\n message = exception.body.get(\"error\", {}).get(\"message\")\n if message:\n return message\n return None\n\n def update_build_config(self, build_config: dotdict, field_value: Any, field_name: str | None = None):\n if field_name in {\"base_url\", \"model_name\", \"tool_model_enabled\", \"api_key\"} and field_value:\n try:\n if len(self.api_key) == 0:\n ids = ANTHROPIC_MODELS\n else:\n try:\n ids = self.get_models(tool_model_enabled=self.tool_model_enabled)\n except (ImportError, ValueError, requests.exceptions.RequestException) as e:\n logger.exception(f\"Error getting model names: {e}\")\n ids = ANTHROPIC_MODELS\n build_config[\"model_name\"][\"options\"] = ids\n build_config[\"model_name\"][\"value\"] = ids[0]\n except Exception as e:\n msg = f\"Error getting model names: {e}\"\n raise ValueError(msg) from e\n return build_config\n" }, "input_value": { "_input_type": "MessageInput", diff --git a/src/backend/base/langflow/initial_setup/starter_projects/Gmail Agent.json b/src/backend/base/langflow/initial_setup/starter_projects/Gmail Agent.json index dfecfc54b..d65ca00e1 100644 --- a/src/backend/base/langflow/initial_setup/starter_projects/Gmail Agent.json +++ b/src/backend/base/langflow/initial_setup/starter_projects/Gmail Agent.json @@ -7,7 +7,7 @@ "data": { "sourceHandle": { "dataType": "ChatInput", - "id": "ChatInput-3YnjI", + "id": "ChatInput-fifot", "name": "message", "output_types": [ "Message" @@ -15,19 +15,19 @@ }, "targetHandle": { "fieldName": "input_value", - "id": "Agent-PH3eS", + "id": "Agent-5rqMu", "inputTypes": [ "Message" ], "type": "str" } }, - "id": "reactflow__edge-ChatInput-3YnjI{œdataTypeœ:œChatInputœ,œidœ:œChatInput-3YnjIœ,œnameœ:œmessageœ,œoutput_typesœ:[œMessageœ]}-Agent-PH3eS{œfieldNameœ:œinput_valueœ,œidœ:œAgent-PH3eSœ,œinputTypesœ:[œMessageœ],œtypeœ:œstrœ}", + "id": "reactflow__edge-ChatInput-fifot{œdataTypeœ:œChatInputœ,œidœ:œChatInput-fifotœ,œnameœ:œmessageœ,œoutput_typesœ:[œMessageœ]}-Agent-5rqMu{œfieldNameœ:œinput_valueœ,œidœ:œAgent-5rqMuœ,œinputTypesœ:[œMessageœ],œtypeœ:œstrœ}", "selected": false, - "source": "ChatInput-3YnjI", - "sourceHandle": "{œdataTypeœ: œChatInputœ, œidœ: œChatInput-3YnjIœ, œnameœ: œmessageœ, œoutput_typesœ: [œMessageœ]}", - "target": "Agent-PH3eS", - "targetHandle": "{œfieldNameœ: œinput_valueœ, œidœ: œAgent-PH3eSœ, œinputTypesœ: [œMessageœ], œtypeœ: œstrœ}" + "source": "ChatInput-fifot", + "sourceHandle": "{œdataTypeœ: œChatInputœ, œidœ: œChatInput-fifotœ, œnameœ: œmessageœ, œoutput_typesœ: [œMessageœ]}", + "target": "Agent-5rqMu", + "targetHandle": "{œfieldNameœ: œinput_valueœ, œidœ: œAgent-5rqMuœ, œinputTypesœ: [œMessageœ], œtypeœ: œstrœ}" }, { "animated": false, @@ -35,7 +35,7 @@ "data": { "sourceHandle": { "dataType": "Agent", - "id": "Agent-PH3eS", + "id": "Agent-5rqMu", "name": "response", "output_types": [ "Message" @@ -43,7 +43,7 @@ }, "targetHandle": { "fieldName": "input_value", - "id": "ChatOutput-iLQcv", + "id": "ChatOutput-mXpv2", "inputTypes": [ "Data", "DataFrame", @@ -52,19 +52,20 @@ "type": "str" } }, - "id": "reactflow__edge-Agent-PH3eS{œdataTypeœ:œAgentœ,œidœ:œAgent-PH3eSœ,œnameœ:œresponseœ,œoutput_typesœ:[œMessageœ]}-ChatOutput-iLQcv{œfieldNameœ:œinput_valueœ,œidœ:œChatOutput-iLQcvœ,œinputTypesœ:[œDataœ,œDataFrameœ,œMessageœ],œtypeœ:œstrœ}", + "id": "reactflow__edge-Agent-5rqMu{œdataTypeœ:œAgentœ,œidœ:œAgent-5rqMuœ,œnameœ:œresponseœ,œoutput_typesœ:[œMessageœ]}-ChatOutput-mXpv2{œfieldNameœ:œinput_valueœ,œidœ:œChatOutput-mXpv2œ,œinputTypesœ:[œDataœ,œDataFrameœ,œMessageœ],œtypeœ:œstrœ}", "selected": false, - "source": "Agent-PH3eS", - "sourceHandle": "{œdataTypeœ: œAgentœ, œidœ: œAgent-PH3eSœ, œnameœ: œresponseœ, œoutput_typesœ: [œMessageœ]}", - "target": "ChatOutput-iLQcv", - "targetHandle": "{œfieldNameœ: œinput_valueœ, œidœ: œChatOutput-iLQcvœ, œinputTypesœ: [œDataœ, œDataFrameœ, œMessageœ], œtypeœ: œstrœ}" + "source": "Agent-5rqMu", + "sourceHandle": "{œdataTypeœ: œAgentœ, œidœ: œAgent-5rqMuœ, œnameœ: œresponseœ, œoutput_typesœ: [œMessageœ]}", + "target": "ChatOutput-mXpv2", + "targetHandle": "{œfieldNameœ: œinput_valueœ, œidœ: œChatOutput-mXpv2œ, œinputTypesœ: [œDataœ, œDataFrameœ, œMessageœ], œtypeœ: œstrœ}" }, { + "animated": false, "className": "", "data": { "sourceHandle": { "dataType": "ComposioAPI", - "id": "ComposioAPI-adjCJ", + "id": "ComposioAPI-Z0Iiy", "name": "tools", "output_types": [ "Tool" @@ -72,24 +73,24 @@ }, "targetHandle": { "fieldName": "tools", - "id": "Agent-PH3eS", + "id": "Agent-5rqMu", "inputTypes": [ "Tool" ], "type": "other" } }, - "id": "reactflow__edge-ComposioAPI-adjCJ{œdataTypeœ:œComposioAPIœ,œidœ:œComposioAPI-adjCJœ,œnameœ:œtoolsœ,œoutput_typesœ:[œToolœ]}-Agent-PH3eS{œfieldNameœ:œtoolsœ,œidœ:œAgent-PH3eSœ,œinputTypesœ:[œToolœ],œtypeœ:œotherœ}", - "source": "ComposioAPI-adjCJ", - "sourceHandle": "{œdataTypeœ: œComposioAPIœ, œidœ: œComposioAPI-adjCJœ, œnameœ: œtoolsœ, œoutput_typesœ: [œToolœ]}", - "target": "Agent-PH3eS", - "targetHandle": "{œfieldNameœ: œtoolsœ, œidœ: œAgent-PH3eSœ, œinputTypesœ: [œToolœ], œtypeœ: œotherœ}" + "id": "reactflow__edge-ComposioAPI-Z0Iiy{œdataTypeœ:œComposioAPIœ,œidœ:œComposioAPI-Z0Iiyœ,œnameœ:œtoolsœ,œoutput_typesœ:[œToolœ]}-Agent-5rqMu{œfieldNameœ:œtoolsœ,œidœ:œAgent-5rqMuœ,œinputTypesœ:[œToolœ],œtypeœ:œotherœ}", + "source": "ComposioAPI-Z0Iiy", + "sourceHandle": "{œdataTypeœ: œComposioAPIœ, œidœ: œComposioAPI-Z0Iiyœ, œnameœ: œtoolsœ, œoutput_typesœ: [œToolœ]}", + "target": "Agent-5rqMu", + "targetHandle": "{œfieldNameœ: œtoolsœ, œidœ: œAgent-5rqMuœ, œinputTypesœ: [œToolœ], œtypeœ: œotherœ}" } ], "nodes": [ { "data": { - "id": "Agent-PH3eS", + "id": "Agent-5rqMu", "node": { "base_classes": [ "Message" @@ -134,7 +135,7 @@ "icon": "bot", "key": "Agent", "legacy": false, - "lf_version": "1.1.5", + "lf_version": "1.2.0", "metadata": {}, "minimized": false, "output_types": [], @@ -741,7 +742,7 @@ "type": "Agent" }, "dragging": false, - "id": "Agent-PH3eS", + "id": "Agent-5rqMu", "measured": { "height": 624, "width": 320 @@ -755,7 +756,7 @@ }, { "data": { - "id": "ChatInput-3YnjI", + "id": "ChatInput-fifot", "node": { "base_classes": [ "Message" @@ -783,7 +784,7 @@ "icon": "MessagesSquare", "key": "ChatInput", "legacy": false, - "lf_version": "1.1.5", + "lf_version": "1.2.0", "metadata": {}, "minimized": true, "output_types": [], @@ -1052,7 +1053,7 @@ "type": "ChatInput" }, "dragging": false, - "id": "ChatInput-3YnjI", + "id": "ChatInput-fifot", "measured": { "height": 66, "width": 192 @@ -1066,7 +1067,7 @@ }, { "data": { - "id": "ChatOutput-iLQcv", + "id": "ChatOutput-mXpv2", "node": { "base_classes": [ "Message" @@ -1364,7 +1365,7 @@ "type": "ChatOutput" }, "dragging": false, - "id": "ChatOutput-iLQcv", + "id": "ChatOutput-mXpv2", "measured": { "height": 66, "width": 192 @@ -1378,7 +1379,7 @@ }, { "data": { - "id": "note-5GU6o", + "id": "note-Oh8JB", "node": { "description": "# Gmail Agent\nUsing this flow you can send emails, create drafts, fetch emails and more\n\n## Instructions\n\n1. Get Composio API Key\n - Visit https://app.composio.dev\n - Enter the key in the \"Composio API Key\" field\n\n2. Authenticate Gmail Account\n - Select Gmail App from the dropdown menu in the App Names field\n - Click the refresh button next to the App Name\n - Follow the Gmail authentication link\n - After authenticating, click refresh again\n - Verify that authentication status shows as successful\n\n3. Select Actions\n - Default actions (pre-selected):\n - GMAIL_SEND_EMAIL: Send emails directly\n - GMAIL_CREATE_EMAIL_DRAFT: Create draft emails\n - Select additional actions based on your needs\n\n4. Configure OpenAI\n - Enter your OpenAI API key in the Agent OpenAI API key field\n\n5. Run Agent\n Example prompts:\n - \"Send an email to johndoe@gmail.com wishing them Happy birthday!\"\n - \"Create a draft email about project updates\"", "display_name": "", @@ -1389,7 +1390,7 @@ }, "dragging": false, "height": 842, - "id": "note-5GU6o", + "id": "note-Oh8JB", "measured": { "height": 842, "width": 395 @@ -1407,7 +1408,7 @@ "data": { "description": "Use Composio toolset to run actions with your agent", "display_name": "Composio Tools", - "id": "ComposioAPI-adjCJ", + "id": "ComposioAPI-Z0Iiy", "node": { "base_classes": [ "Tool" @@ -1432,6 +1433,7 @@ "frozen": false, "icon": "Composio", "legacy": false, + "lf_version": "1.2.0", "metadata": {}, "minimized": false, "output_types": [], @@ -1737,7 +1739,7 @@ "show": true, "title_case": false, "type": "code", - "value": "# Standard library imports\nfrom collections.abc import Sequence\nfrom typing import Any\n\nimport requests\n\n# Third-party imports\nfrom composio.client.collections import AppAuthScheme\nfrom composio.client.exceptions import NoItemsFound\nfrom composio_langchain import Action, ComposioToolSet\nfrom langchain_core.tools import Tool\nfrom loguru import logger\n\n# Local imports\nfrom langflow.base.langchain_utilities.model import LCToolComponent\nfrom langflow.inputs import DropdownInput, LinkInput, MessageTextInput, MultiselectInput, SecretStrInput, StrInput\nfrom langflow.io import Output\n\n\nclass ComposioAPIComponent(LCToolComponent):\n display_name: str = \"Composio Tools\"\n description: str = \"Use Composio toolset to run actions with your agent\"\n name = \"ComposioAPI\"\n icon = \"Composio\"\n documentation: str = \"https://docs.composio.dev\"\n\n inputs = [\n # Basic configuration inputs\n MessageTextInput(name=\"entity_id\", display_name=\"Entity ID\", value=\"default\", advanced=True),\n SecretStrInput(\n name=\"api_key\",\n display_name=\"Composio API Key\",\n required=True,\n info=\"Refer to https://docs.composio.dev/faq/api_key/api_key\",\n real_time_refresh=True,\n ),\n DropdownInput(\n name=\"app_names\",\n display_name=\"App Name\",\n options=[],\n value=\"\",\n info=\"The app name to use. Please refresh after selecting app name\",\n refresh_button=True,\n required=True,\n ),\n # Authentication-related inputs (initially hidden)\n SecretStrInput(\n name=\"app_credentials\",\n display_name=\"App Credentials\",\n required=False,\n dynamic=True,\n show=False,\n info=\"Credentials for app authentication (API Key, Password, etc)\",\n load_from_db=False,\n ),\n MessageTextInput(\n name=\"username\",\n display_name=\"Username\",\n required=False,\n dynamic=True,\n show=False,\n info=\"Username for Basic authentication\",\n ),\n LinkInput(\n name=\"auth_link\",\n display_name=\"Authentication Link\",\n value=\"\",\n info=\"Click to authenticate with OAuth2\",\n dynamic=True,\n show=False,\n placeholder=\"Click to authenticate\",\n ),\n StrInput(\n name=\"auth_status\",\n display_name=\"Auth Status\",\n value=\"Not Connected\",\n info=\"Current authentication status\",\n dynamic=True,\n show=False,\n ),\n MultiselectInput(\n name=\"action_names\",\n display_name=\"Actions to use\",\n required=True,\n options=[],\n value=[],\n info=\"The actions to pass to agent to execute\",\n dynamic=True,\n show=False,\n ),\n ]\n\n outputs = [\n Output(name=\"tools\", display_name=\"Tools\", method=\"build_tool\"),\n ]\n\n def _check_for_authorization(self, app: str) -> str:\n \"\"\"Checks if the app is authorized.\n\n Args:\n app (str): The app name to check authorization for.\n\n Returns:\n str: The authorization status or URL.\n \"\"\"\n toolset = self._build_wrapper()\n entity = toolset.client.get_entity(id=self.entity_id)\n try:\n # Check if user is already connected\n entity.get_connection(app=app)\n except NoItemsFound:\n # Get auth scheme for the app\n auth_scheme = self._get_auth_scheme(app)\n return self._handle_auth_by_scheme(entity, app, auth_scheme)\n except Exception: # noqa: BLE001\n logger.exception(\"Authorization error\")\n return \"Error checking authorization\"\n else:\n return f\"{app} CONNECTED\"\n\n def _get_auth_scheme(self, app_name: str) -> AppAuthScheme:\n \"\"\"Get the primary auth scheme for an app.\n\n Args:\n app_name (str): The name of the app to get auth scheme for.\n\n Returns:\n AppAuthScheme: The auth scheme details.\n \"\"\"\n toolset = self._build_wrapper()\n try:\n return toolset.get_auth_scheme_for_app(app=app_name.lower())\n except Exception: # noqa: BLE001\n logger.exception(f\"Error getting auth scheme for {app_name}\")\n return None\n\n def _get_oauth_apps(self, api_key: str) -> list[str]:\n \"\"\"Fetch OAuth-enabled apps from Composio API.\n\n Args:\n api_key (str): The Composio API key.\n\n Returns:\n list[str]: A list containing OAuth-enabled app names.\n \"\"\"\n oauth_apps = []\n try:\n url = \"https://backend.composio.dev/api/v1/apps\"\n headers = {\"x-api-key\": api_key}\n params = {\n \"includeLocal\": \"true\",\n \"additionalFields\": \"auth_schemes\",\n \"sortBy\": \"alphabet\",\n }\n\n response = requests.get(url, headers=headers, params=params, timeout=20)\n data = response.json()\n\n for item in data.get(\"items\", []):\n for auth_scheme in item.get(\"auth_schemes\", []):\n if auth_scheme.get(\"mode\") in [\"OAUTH1\", \"OAUTH2\"]:\n oauth_apps.append(item[\"key\"].upper())\n break\n except requests.RequestException as e:\n logger.error(f\"Error fetching OAuth apps: {e}\")\n return []\n else:\n return oauth_apps\n\n def _handle_auth_by_scheme(self, entity: Any, app: str, auth_scheme: AppAuthScheme) -> str:\n \"\"\"Handle authentication based on the auth scheme.\n\n Args:\n entity (Any): The entity instance.\n app (str): The app name.\n auth_scheme (AppAuthScheme): The auth scheme details.\n\n Returns:\n str: The authentication status or URL.\n \"\"\"\n auth_mode = auth_scheme.auth_mode\n\n try:\n # First check if already connected\n entity.get_connection(app=app)\n except NoItemsFound:\n # If not connected, handle new connection based on auth mode\n if auth_mode == \"API_KEY\":\n if hasattr(self, \"app_credentials\") and self.app_credentials:\n try:\n entity.initiate_connection(\n app_name=app,\n auth_mode=\"API_KEY\",\n auth_config={\"api_key\": self.app_credentials},\n use_composio_auth=False,\n force_new_integration=True,\n )\n except Exception as e: # noqa: BLE001\n logger.error(f\"Error connecting with API Key: {e}\")\n return \"Invalid API Key\"\n else:\n return f\"{app} CONNECTED\"\n return \"Enter API Key\"\n\n if (\n auth_mode == \"BASIC\"\n and hasattr(self, \"username\")\n and hasattr(self, \"app_credentials\")\n and self.username\n and self.app_credentials\n ):\n try:\n entity.initiate_connection(\n app_name=app,\n auth_mode=\"BASIC\",\n auth_config={\"username\": self.username, \"password\": self.app_credentials},\n use_composio_auth=False,\n force_new_integration=True,\n )\n except Exception as e: # noqa: BLE001\n logger.error(f\"Error connecting with Basic Auth: {e}\")\n return \"Invalid credentials\"\n else:\n return f\"{app} CONNECTED\"\n elif auth_mode == \"BASIC\":\n return \"Enter Username and Password\"\n\n if auth_mode == \"OAUTH2\":\n try:\n return self._initiate_default_connection(entity, app)\n except Exception as e: # noqa: BLE001\n logger.error(f\"Error initiating OAuth2: {e}\")\n return \"OAuth2 initialization failed\"\n\n return \"Unsupported auth mode\"\n except Exception as e: # noqa: BLE001\n logger.error(f\"Error checking connection status: {e}\")\n return f\"Error: {e!s}\"\n else:\n return f\"{app} CONNECTED\"\n\n def _initiate_default_connection(self, entity: Any, app: str) -> str:\n connection = entity.initiate_connection(app_name=app, use_composio_auth=True, force_new_integration=True)\n return connection.redirectUrl\n\n def _get_connected_app_names_for_entity(self) -> list[str]:\n toolset = self._build_wrapper()\n connections = toolset.client.get_entity(id=self.entity_id).get_connections()\n return list({connection.appUniqueId for connection in connections})\n\n def _get_normalized_app_name(self) -> str:\n \"\"\"Get app name without connection status suffix.\n\n Returns:\n str: Normalized app name.\n \"\"\"\n return self.app_names.replace(\" ✅\", \"\").replace(\"_connected\", \"\")\n\n def update_build_config(self, build_config: dict, field_value: Any, field_name: str | None = None) -> dict: # noqa: ARG002\n # Update the available apps options from the API\n if hasattr(self, \"api_key\") and self.api_key != \"\":\n toolset = self._build_wrapper()\n build_config[\"app_names\"][\"options\"] = self._get_oauth_apps(api_key=self.api_key)\n\n # First, ensure all dynamic fields are hidden by default\n dynamic_fields = [\"app_credentials\", \"username\", \"auth_link\", \"auth_status\", \"action_names\"]\n for field in dynamic_fields:\n if field in build_config:\n if build_config[field][\"value\"] is None or build_config[field][\"value\"] == \"\":\n build_config[field][\"show\"] = False\n build_config[field][\"advanced\"] = True\n build_config[field][\"load_from_db\"] = False\n else:\n build_config[field][\"show\"] = True\n build_config[field][\"advanced\"] = False\n\n if field_name == \"app_names\" and (not hasattr(self, \"app_names\") or not self.app_names):\n build_config[\"auth_status\"][\"show\"] = True\n build_config[\"auth_status\"][\"value\"] = \"Please select an app first\"\n return build_config\n\n if field_name == \"app_names\" and hasattr(self, \"api_key\") and self.api_key != \"\":\n # app_name = self._get_normalized_app_name()\n app_name = self.app_names\n try:\n toolset = self._build_wrapper()\n entity = toolset.client.get_entity(id=self.entity_id)\n\n # Always show auth_status when app is selected\n build_config[\"auth_status\"][\"show\"] = True\n build_config[\"auth_status\"][\"advanced\"] = False\n\n try:\n # Check if already connected\n entity.get_connection(app=app_name)\n build_config[\"auth_status\"][\"value\"] = \"✅\"\n build_config[\"auth_link\"][\"show\"] = False\n # Show action selection for connected apps\n build_config[\"action_names\"][\"show\"] = True\n build_config[\"action_names\"][\"advanced\"] = False\n\n except NoItemsFound:\n # Get auth scheme and show relevant fields\n auth_scheme = self._get_auth_scheme(app_name)\n auth_mode = auth_scheme.auth_mode\n logger.info(f\"Auth mode for {app_name}: {auth_mode}\")\n\n if auth_mode == \"API_KEY\":\n build_config[\"app_credentials\"][\"show\"] = True\n build_config[\"app_credentials\"][\"advanced\"] = False\n build_config[\"app_credentials\"][\"display_name\"] = \"API Key\"\n build_config[\"auth_status\"][\"value\"] = \"Enter API Key\"\n\n elif auth_mode == \"BASIC\":\n build_config[\"username\"][\"show\"] = True\n build_config[\"username\"][\"advanced\"] = False\n build_config[\"app_credentials\"][\"show\"] = True\n build_config[\"app_credentials\"][\"advanced\"] = False\n build_config[\"app_credentials\"][\"display_name\"] = \"Password\"\n build_config[\"auth_status\"][\"value\"] = \"Enter Username and Password\"\n\n elif auth_mode == \"OAUTH2\":\n build_config[\"auth_link\"][\"show\"] = True\n build_config[\"auth_link\"][\"advanced\"] = False\n auth_url = self._initiate_default_connection(entity, app_name)\n build_config[\"auth_link\"][\"value\"] = auth_url\n build_config[\"auth_status\"][\"value\"] = \"Click link to authenticate\"\n\n else:\n build_config[\"auth_status\"][\"value\"] = \"Unsupported auth mode\"\n\n # Update action names if connected\n if build_config[\"auth_status\"][\"value\"] == \"✅\":\n all_action_names = [str(action).replace(\"Action.\", \"\") for action in Action.all()]\n app_action_names = [\n action_name\n for action_name in all_action_names\n if action_name.lower().startswith(app_name.lower() + \"_\")\n ]\n if build_config[\"action_names\"][\"options\"] != app_action_names:\n build_config[\"action_names\"][\"options\"] = app_action_names\n build_config[\"action_names\"][\"value\"] = [app_action_names[0]] if app_action_names else [\"\"]\n\n except Exception as e: # noqa: BLE001\n logger.error(f\"Error checking auth status: {e}, app: {app_name}\")\n build_config[\"auth_status\"][\"value\"] = f\"Error: {e!s}\"\n\n return build_config\n\n def build_tool(self) -> Sequence[Tool]:\n \"\"\"Build Composio tools based on selected actions.\n\n Returns:\n Sequence[Tool]: List of configured Composio tools.\n \"\"\"\n composio_toolset = self._build_wrapper()\n return composio_toolset.get_tools(actions=self.action_names)\n\n def _build_wrapper(self) -> ComposioToolSet:\n \"\"\"Build the Composio toolset wrapper.\n\n Returns:\n ComposioToolSet: The initialized toolset.\n\n Raises:\n ValueError: If the API key is not found or invalid.\n \"\"\"\n try:\n if not self.api_key:\n msg = \"Composio API Key is required\"\n raise ValueError(msg)\n return ComposioToolSet(api_key=self.api_key, entity_id=self.entity_id)\n except ValueError as e:\n logger.error(f\"Error building Composio wrapper: {e}\")\n msg = \"Please provide a valid Composio API Key in the component settings\"\n raise ValueError(msg) from e\n" + "value": "# Standard library imports\nfrom collections.abc import Sequence\nfrom typing import Any\n\nimport requests\n\n# Third-party imports\nfrom composio.client.collections import AppAuthScheme\nfrom composio.client.exceptions import NoItemsFound\nfrom composio_langchain import Action, ComposioToolSet\nfrom langchain_core.tools import Tool\nfrom loguru import logger\n\n# Local imports\nfrom langflow.base.langchain_utilities.model import LCToolComponent\nfrom langflow.inputs import DropdownInput, LinkInput, MessageTextInput, MultiselectInput, SecretStrInput, StrInput\nfrom langflow.io import Output\n\n\nclass ComposioAPIComponent(LCToolComponent):\n display_name: str = \"Composio Tools\"\n description: str = \"Use Composio toolset to run actions with your agent\"\n name = \"ComposioAPI\"\n icon = \"Composio\"\n documentation: str = \"https://docs.composio.dev\"\n\n inputs = [\n # Basic configuration inputs\n MessageTextInput(name=\"entity_id\", display_name=\"Entity ID\", value=\"default\", advanced=True),\n SecretStrInput(\n name=\"api_key\",\n display_name=\"Composio API Key\",\n required=True,\n info=\"Refer to https://docs.composio.dev/faq/api_key/api_key\",\n real_time_refresh=True,\n ),\n DropdownInput(\n name=\"app_names\",\n display_name=\"App Name\",\n options=[],\n value=\"\",\n info=\"The app name to use. Please refresh after selecting app name\",\n refresh_button=True,\n required=True,\n ),\n # Authentication-related inputs (initially hidden)\n SecretStrInput(\n name=\"app_credentials\",\n display_name=\"App Credentials\",\n required=False,\n dynamic=True,\n show=False,\n info=\"Credentials for app authentication (API Key, Password, etc)\",\n load_from_db=False,\n ),\n MessageTextInput(\n name=\"username\",\n display_name=\"Username\",\n required=False,\n dynamic=True,\n show=False,\n info=\"Username for Basic authentication\",\n ),\n LinkInput(\n name=\"auth_link\",\n display_name=\"Authentication Link\",\n value=\"\",\n info=\"Click to authenticate with OAuth2\",\n dynamic=True,\n show=False,\n placeholder=\"Click to authenticate\",\n ),\n StrInput(\n name=\"auth_status\",\n display_name=\"Auth Status\",\n value=\"Not Connected\",\n info=\"Current authentication status\",\n dynamic=True,\n show=False,\n ),\n MultiselectInput(\n name=\"action_names\",\n display_name=\"Actions to use\",\n required=True,\n options=[],\n value=[],\n info=\"The actions to pass to agent to execute\",\n dynamic=True,\n show=False,\n ),\n ]\n\n outputs = [\n Output(name=\"tools\", display_name=\"Tools\", method=\"build_tool\"),\n ]\n\n def _check_for_authorization(self, app: str) -> str:\n \"\"\"Checks if the app is authorized.\n\n Args:\n app (str): The app name to check authorization for.\n\n Returns:\n str: The authorization status or URL.\n \"\"\"\n toolset = self._build_wrapper()\n entity = toolset.client.get_entity(id=self.entity_id)\n try:\n # Check if user is already connected\n entity.get_connection(app=app)\n except NoItemsFound:\n # Get auth scheme for the app\n auth_scheme = self._get_auth_scheme(app)\n return self._handle_auth_by_scheme(entity, app, auth_scheme)\n except Exception: # noqa: BLE001\n logger.exception(\"Authorization error\")\n return \"Error checking authorization\"\n else:\n return f\"{app} CONNECTED\"\n\n def _get_auth_scheme(self, app_name: str) -> AppAuthScheme:\n \"\"\"Get the primary auth scheme for an app.\n\n Args:\n app_name (str): The name of the app to get auth scheme for.\n\n Returns:\n AppAuthScheme: The auth scheme details.\n \"\"\"\n toolset = self._build_wrapper()\n try:\n return toolset.get_auth_scheme_for_app(app=app_name.lower())\n except Exception: # noqa: BLE001\n logger.exception(f\"Error getting auth scheme for {app_name}\")\n return None\n\n def _get_oauth_apps(self, api_key: str) -> list[str]:\n \"\"\"Fetch OAuth-enabled apps from Composio API.\n\n Args:\n api_key (str): The Composio API key.\n\n Returns:\n list[str]: A list containing OAuth-enabled app names.\n \"\"\"\n oauth_apps = []\n try:\n url = \"https://backend.composio.dev/api/v1/apps\"\n headers = {\"x-api-key\": api_key}\n params = {\n \"includeLocal\": \"true\",\n \"additionalFields\": \"auth_schemes\",\n \"sortBy\": \"alphabet\",\n }\n\n response = requests.get(url, headers=headers, params=params, timeout=20)\n data = response.json()\n\n for item in data.get(\"items\", []):\n for auth_scheme in item.get(\"auth_schemes\", []):\n if auth_scheme.get(\"mode\") in {\"OAUTH1\", \"OAUTH2\"}:\n oauth_apps.append(item[\"key\"].upper())\n break\n except requests.RequestException as e:\n logger.error(f\"Error fetching OAuth apps: {e}\")\n return []\n else:\n return oauth_apps\n\n def _handle_auth_by_scheme(self, entity: Any, app: str, auth_scheme: AppAuthScheme) -> str:\n \"\"\"Handle authentication based on the auth scheme.\n\n Args:\n entity (Any): The entity instance.\n app (str): The app name.\n auth_scheme (AppAuthScheme): The auth scheme details.\n\n Returns:\n str: The authentication status or URL.\n \"\"\"\n auth_mode = auth_scheme.auth_mode\n\n try:\n # First check if already connected\n entity.get_connection(app=app)\n except NoItemsFound:\n # If not connected, handle new connection based on auth mode\n if auth_mode == \"API_KEY\":\n if hasattr(self, \"app_credentials\") and self.app_credentials:\n try:\n entity.initiate_connection(\n app_name=app,\n auth_mode=\"API_KEY\",\n auth_config={\"api_key\": self.app_credentials},\n use_composio_auth=False,\n force_new_integration=True,\n )\n except Exception as e: # noqa: BLE001\n logger.error(f\"Error connecting with API Key: {e}\")\n return \"Invalid API Key\"\n else:\n return f\"{app} CONNECTED\"\n return \"Enter API Key\"\n\n if (\n auth_mode == \"BASIC\"\n and hasattr(self, \"username\")\n and hasattr(self, \"app_credentials\")\n and self.username\n and self.app_credentials\n ):\n try:\n entity.initiate_connection(\n app_name=app,\n auth_mode=\"BASIC\",\n auth_config={\"username\": self.username, \"password\": self.app_credentials},\n use_composio_auth=False,\n force_new_integration=True,\n )\n except Exception as e: # noqa: BLE001\n logger.error(f\"Error connecting with Basic Auth: {e}\")\n return \"Invalid credentials\"\n else:\n return f\"{app} CONNECTED\"\n elif auth_mode == \"BASIC\":\n return \"Enter Username and Password\"\n\n if auth_mode == \"OAUTH2\":\n try:\n return self._initiate_default_connection(entity, app)\n except Exception as e: # noqa: BLE001\n logger.error(f\"Error initiating OAuth2: {e}\")\n return \"OAuth2 initialization failed\"\n\n return \"Unsupported auth mode\"\n except Exception as e: # noqa: BLE001\n logger.error(f\"Error checking connection status: {e}\")\n return f\"Error: {e!s}\"\n else:\n return f\"{app} CONNECTED\"\n\n def _initiate_default_connection(self, entity: Any, app: str) -> str:\n connection = entity.initiate_connection(app_name=app, use_composio_auth=True, force_new_integration=True)\n return connection.redirectUrl\n\n def _get_connected_app_names_for_entity(self) -> list[str]:\n toolset = self._build_wrapper()\n connections = toolset.client.get_entity(id=self.entity_id).get_connections()\n return list({connection.appUniqueId for connection in connections})\n\n def _get_normalized_app_name(self) -> str:\n \"\"\"Get app name without connection status suffix.\n\n Returns:\n str: Normalized app name.\n \"\"\"\n return self.app_names.replace(\" ✅\", \"\").replace(\"_connected\", \"\")\n\n def update_build_config(self, build_config: dict, field_value: Any, field_name: str | None = None) -> dict: # noqa: ARG002\n # Update the available apps options from the API\n if hasattr(self, \"api_key\") and self.api_key != \"\":\n toolset = self._build_wrapper()\n build_config[\"app_names\"][\"options\"] = self._get_oauth_apps(api_key=self.api_key)\n\n # First, ensure all dynamic fields are hidden by default\n dynamic_fields = [\"app_credentials\", \"username\", \"auth_link\", \"auth_status\", \"action_names\"]\n for field in dynamic_fields:\n if field in build_config:\n if build_config[field][\"value\"] is None or build_config[field][\"value\"] == \"\":\n build_config[field][\"show\"] = False\n build_config[field][\"advanced\"] = True\n build_config[field][\"load_from_db\"] = False\n else:\n build_config[field][\"show\"] = True\n build_config[field][\"advanced\"] = False\n\n if field_name == \"app_names\" and (not hasattr(self, \"app_names\") or not self.app_names):\n build_config[\"auth_status\"][\"show\"] = True\n build_config[\"auth_status\"][\"value\"] = \"Please select an app first\"\n return build_config\n\n if field_name == \"app_names\" and hasattr(self, \"api_key\") and self.api_key != \"\":\n # app_name = self._get_normalized_app_name()\n app_name = self.app_names\n try:\n toolset = self._build_wrapper()\n entity = toolset.client.get_entity(id=self.entity_id)\n\n # Always show auth_status when app is selected\n build_config[\"auth_status\"][\"show\"] = True\n build_config[\"auth_status\"][\"advanced\"] = False\n\n try:\n # Check if already connected\n entity.get_connection(app=app_name)\n build_config[\"auth_status\"][\"value\"] = \"✅\"\n build_config[\"auth_link\"][\"show\"] = False\n # Show action selection for connected apps\n build_config[\"action_names\"][\"show\"] = True\n build_config[\"action_names\"][\"advanced\"] = False\n\n except NoItemsFound:\n # Get auth scheme and show relevant fields\n auth_scheme = self._get_auth_scheme(app_name)\n auth_mode = auth_scheme.auth_mode\n logger.info(f\"Auth mode for {app_name}: {auth_mode}\")\n\n if auth_mode == \"API_KEY\":\n build_config[\"app_credentials\"][\"show\"] = True\n build_config[\"app_credentials\"][\"advanced\"] = False\n build_config[\"app_credentials\"][\"display_name\"] = \"API Key\"\n build_config[\"auth_status\"][\"value\"] = \"Enter API Key\"\n\n elif auth_mode == \"BASIC\":\n build_config[\"username\"][\"show\"] = True\n build_config[\"username\"][\"advanced\"] = False\n build_config[\"app_credentials\"][\"show\"] = True\n build_config[\"app_credentials\"][\"advanced\"] = False\n build_config[\"app_credentials\"][\"display_name\"] = \"Password\"\n build_config[\"auth_status\"][\"value\"] = \"Enter Username and Password\"\n\n elif auth_mode == \"OAUTH2\":\n build_config[\"auth_link\"][\"show\"] = True\n build_config[\"auth_link\"][\"advanced\"] = False\n auth_url = self._initiate_default_connection(entity, app_name)\n build_config[\"auth_link\"][\"value\"] = auth_url\n build_config[\"auth_status\"][\"value\"] = \"Click link to authenticate\"\n\n else:\n build_config[\"auth_status\"][\"value\"] = \"Unsupported auth mode\"\n\n # Update action names if connected\n if build_config[\"auth_status\"][\"value\"] == \"✅\":\n all_action_names = [str(action).replace(\"Action.\", \"\") for action in Action.all()]\n app_action_names = [\n action_name\n for action_name in all_action_names\n if action_name.lower().startswith(app_name.lower() + \"_\")\n ]\n if build_config[\"action_names\"][\"options\"] != app_action_names:\n build_config[\"action_names\"][\"options\"] = app_action_names\n build_config[\"action_names\"][\"value\"] = [app_action_names[0]] if app_action_names else [\"\"]\n\n except Exception as e: # noqa: BLE001\n logger.error(f\"Error checking auth status: {e}, app: {app_name}\")\n build_config[\"auth_status\"][\"value\"] = f\"Error: {e!s}\"\n\n return build_config\n\n def build_tool(self) -> Sequence[Tool]:\n \"\"\"Build Composio tools based on selected actions.\n\n Returns:\n Sequence[Tool]: List of configured Composio tools.\n \"\"\"\n composio_toolset = self._build_wrapper()\n return composio_toolset.get_tools(actions=self.action_names)\n\n def _build_wrapper(self) -> ComposioToolSet:\n \"\"\"Build the Composio toolset wrapper.\n\n Returns:\n ComposioToolSet: The initialized toolset.\n\n Raises:\n ValueError: If the API key is not found or invalid.\n \"\"\"\n try:\n if not self.api_key:\n msg = \"Composio API Key is required\"\n raise ValueError(msg)\n return ComposioToolSet(api_key=self.api_key, entity_id=self.entity_id)\n except ValueError as e:\n logger.error(f\"Error building Composio wrapper: {e}\")\n msg = \"Please provide a valid Composio API Key in the component settings\"\n raise ValueError(msg) from e\n" }, "entity_id": { "_input_type": "MessageTextInput", @@ -1792,7 +1794,7 @@ "type": "ComposioAPI" }, "dragging": false, - "id": "ComposioAPI-adjCJ", + "id": "ComposioAPI-Z0Iiy", "measured": { "height": 497, "width": 320 @@ -1806,17 +1808,16 @@ } ], "viewport": { - "x": 494.6705020244866, - "y": 423.6508642555026, - "zoom": 0.7202622571895975 + "x": 568.8302643946312, + "y": 91.93195183355544, + "zoom": 0.7104297128050097 } }, "description": "Interact with Gmail to send emails, create drafts, and fetch messages", - "endpoint_name": null, - "id": "6e5d7690-35da-4163-8c2f-9693ebb59f5c", - "is_component": false, - "last_tested_version": "1.2.0", "name": "Gmail Agent", + "endpoint_name": null, + "id": "0473161e-ca7e-413c-9113-e98a142313ed", + "is_component": false, "tags": [ "agents" ] diff --git a/src/backend/base/langflow/initial_setup/starter_projects/LoopTemplate.json b/src/backend/base/langflow/initial_setup/starter_projects/LoopTemplate.json index 8a2563619..9656f303f 100644 --- a/src/backend/base/langflow/initial_setup/starter_projects/LoopTemplate.json +++ b/src/backend/base/langflow/initial_setup/starter_projects/LoopTemplate.json @@ -281,7 +281,7 @@ "show": true, "title_case": false, "type": "code", - "value": "import urllib.request\nfrom urllib.parse import urlparse\nfrom xml.etree.ElementTree import Element\n\nfrom defusedxml.ElementTree import fromstring\n\nfrom langflow.custom import Component\nfrom langflow.io import DropdownInput, IntInput, MessageTextInput, Output\nfrom langflow.schema import Data\n\n\nclass ArXivComponent(Component):\n display_name = \"arXiv\"\n description = \"Search and retrieve papers from arXiv.org\"\n icon = \"arXiv\"\n\n inputs = [\n MessageTextInput(\n name=\"search_query\",\n display_name=\"Search Query\",\n info=\"The search query for arXiv papers (e.g., 'quantum computing')\",\n tool_mode=True,\n ),\n DropdownInput(\n name=\"search_type\",\n display_name=\"Search Field\",\n info=\"The field to search in\",\n options=[\"all\", \"title\", \"abstract\", \"author\", \"cat\"], # cat is for category\n value=\"all\",\n ),\n IntInput(\n name=\"max_results\",\n display_name=\"Max Results\",\n info=\"Maximum number of results to return\",\n value=10,\n ),\n ]\n\n outputs = [\n Output(display_name=\"Papers\", name=\"papers\", method=\"search_papers\"),\n ]\n\n def build_query_url(self) -> str:\n \"\"\"Build the arXiv API query URL.\"\"\"\n base_url = \"http://export.arxiv.org/api/query?\"\n\n # Build the search query\n search_query = f\"{self.search_type}:{self.search_query}\"\n\n # URL parameters\n params = {\n \"search_query\": search_query,\n \"max_results\": str(self.max_results),\n }\n\n # Convert params to URL query string\n query_string = \"&\".join([f\"{k}={urllib.parse.quote(str(v))}\" for k, v in params.items()])\n\n return base_url + query_string\n\n def parse_atom_response(self, response_text: str) -> list[dict]:\n \"\"\"Parse the Atom XML response from arXiv.\"\"\"\n # Parse XML safely using defusedxml\n root = fromstring(response_text)\n\n # Define namespace dictionary for XML parsing\n ns = {\"atom\": \"http://www.w3.org/2005/Atom\", \"arxiv\": \"http://arxiv.org/schemas/atom\"}\n\n papers = []\n # Process each entry (paper)\n for entry in root.findall(\"atom:entry\", ns):\n paper = {\n \"id\": self._get_text(entry, \"atom:id\", ns),\n \"title\": self._get_text(entry, \"atom:title\", ns),\n \"summary\": self._get_text(entry, \"atom:summary\", ns),\n \"published\": self._get_text(entry, \"atom:published\", ns),\n \"updated\": self._get_text(entry, \"atom:updated\", ns),\n \"authors\": [author.find(\"atom:name\", ns).text for author in entry.findall(\"atom:author\", ns)],\n \"arxiv_url\": self._get_link(entry, \"alternate\", ns),\n \"pdf_url\": self._get_link(entry, \"related\", ns),\n \"comment\": self._get_text(entry, \"arxiv:comment\", ns),\n \"journal_ref\": self._get_text(entry, \"arxiv:journal_ref\", ns),\n \"primary_category\": self._get_category(entry, ns),\n \"categories\": [cat.get(\"term\") for cat in entry.findall(\"atom:category\", ns)],\n }\n papers.append(paper)\n\n return papers\n\n def _get_text(self, element: Element, path: str, ns: dict) -> str | None:\n \"\"\"Safely extract text from an XML element.\"\"\"\n el = element.find(path, ns)\n return el.text.strip() if el is not None and el.text else None\n\n def _get_link(self, element: Element, rel: str, ns: dict) -> str | None:\n \"\"\"Get link URL based on relation type.\"\"\"\n for link in element.findall(\"atom:link\", ns):\n if link.get(\"rel\") == rel:\n return link.get(\"href\")\n return None\n\n def _get_category(self, element: Element, ns: dict) -> str | None:\n \"\"\"Get primary category.\"\"\"\n cat = element.find(\"arxiv:primary_category\", ns)\n return cat.get(\"term\") if cat is not None else None\n\n def search_papers(self) -> list[Data]:\n \"\"\"Search arXiv and return results.\"\"\"\n try:\n # Build the query URL\n url = self.build_query_url()\n\n # Validate URL scheme and host\n parsed_url = urlparse(url)\n if parsed_url.scheme not in (\"http\", \"https\"):\n error_msg = f\"Invalid URL scheme: {parsed_url.scheme}\"\n raise ValueError(error_msg)\n if parsed_url.hostname != \"export.arxiv.org\":\n error_msg = f\"Invalid host: {parsed_url.hostname}\"\n raise ValueError(error_msg)\n\n # Create a custom opener that only allows http/https schemes\n class RestrictedHTTPHandler(urllib.request.HTTPHandler):\n def http_open(self, req):\n return super().http_open(req)\n\n class RestrictedHTTPSHandler(urllib.request.HTTPSHandler):\n def https_open(self, req):\n return super().https_open(req)\n\n # Build opener with restricted handlers\n opener = urllib.request.build_opener(RestrictedHTTPHandler, RestrictedHTTPSHandler)\n urllib.request.install_opener(opener)\n\n # Make the request with validated URL using restricted opener\n response = opener.open(url)\n response_text = response.read().decode(\"utf-8\")\n\n # Parse the response\n papers = self.parse_atom_response(response_text)\n\n # Convert to Data objects\n results = [Data(data=paper) for paper in papers]\n self.status = results\n except (urllib.error.URLError, ValueError) as e:\n error_data = Data(data={\"error\": f\"Request error: {e!s}\"})\n self.status = error_data\n return [error_data]\n else:\n return results\n" + "value": "import urllib.request\nfrom urllib.parse import urlparse\nfrom xml.etree.ElementTree import Element\n\nfrom defusedxml.ElementTree import fromstring\n\nfrom langflow.custom import Component\nfrom langflow.io import DropdownInput, IntInput, MessageTextInput, Output\nfrom langflow.schema import Data\n\n\nclass ArXivComponent(Component):\n display_name = \"arXiv\"\n description = \"Search and retrieve papers from arXiv.org\"\n icon = \"arXiv\"\n\n inputs = [\n MessageTextInput(\n name=\"search_query\",\n display_name=\"Search Query\",\n info=\"The search query for arXiv papers (e.g., 'quantum computing')\",\n tool_mode=True,\n ),\n DropdownInput(\n name=\"search_type\",\n display_name=\"Search Field\",\n info=\"The field to search in\",\n options=[\"all\", \"title\", \"abstract\", \"author\", \"cat\"], # cat is for category\n value=\"all\",\n ),\n IntInput(\n name=\"max_results\",\n display_name=\"Max Results\",\n info=\"Maximum number of results to return\",\n value=10,\n ),\n ]\n\n outputs = [\n Output(display_name=\"Papers\", name=\"papers\", method=\"search_papers\"),\n ]\n\n def build_query_url(self) -> str:\n \"\"\"Build the arXiv API query URL.\"\"\"\n base_url = \"http://export.arxiv.org/api/query?\"\n\n # Build the search query\n search_query = f\"{self.search_type}:{self.search_query}\"\n\n # URL parameters\n params = {\n \"search_query\": search_query,\n \"max_results\": str(self.max_results),\n }\n\n # Convert params to URL query string\n query_string = \"&\".join([f\"{k}={urllib.parse.quote(str(v))}\" for k, v in params.items()])\n\n return base_url + query_string\n\n def parse_atom_response(self, response_text: str) -> list[dict]:\n \"\"\"Parse the Atom XML response from arXiv.\"\"\"\n # Parse XML safely using defusedxml\n root = fromstring(response_text)\n\n # Define namespace dictionary for XML parsing\n ns = {\"atom\": \"http://www.w3.org/2005/Atom\", \"arxiv\": \"http://arxiv.org/schemas/atom\"}\n\n papers = []\n # Process each entry (paper)\n for entry in root.findall(\"atom:entry\", ns):\n paper = {\n \"id\": self._get_text(entry, \"atom:id\", ns),\n \"title\": self._get_text(entry, \"atom:title\", ns),\n \"summary\": self._get_text(entry, \"atom:summary\", ns),\n \"published\": self._get_text(entry, \"atom:published\", ns),\n \"updated\": self._get_text(entry, \"atom:updated\", ns),\n \"authors\": [author.find(\"atom:name\", ns).text for author in entry.findall(\"atom:author\", ns)],\n \"arxiv_url\": self._get_link(entry, \"alternate\", ns),\n \"pdf_url\": self._get_link(entry, \"related\", ns),\n \"comment\": self._get_text(entry, \"arxiv:comment\", ns),\n \"journal_ref\": self._get_text(entry, \"arxiv:journal_ref\", ns),\n \"primary_category\": self._get_category(entry, ns),\n \"categories\": [cat.get(\"term\") for cat in entry.findall(\"atom:category\", ns)],\n }\n papers.append(paper)\n\n return papers\n\n def _get_text(self, element: Element, path: str, ns: dict) -> str | None:\n \"\"\"Safely extract text from an XML element.\"\"\"\n el = element.find(path, ns)\n return el.text.strip() if el is not None and el.text else None\n\n def _get_link(self, element: Element, rel: str, ns: dict) -> str | None:\n \"\"\"Get link URL based on relation type.\"\"\"\n for link in element.findall(\"atom:link\", ns):\n if link.get(\"rel\") == rel:\n return link.get(\"href\")\n return None\n\n def _get_category(self, element: Element, ns: dict) -> str | None:\n \"\"\"Get primary category.\"\"\"\n cat = element.find(\"arxiv:primary_category\", ns)\n return cat.get(\"term\") if cat is not None else None\n\n def search_papers(self) -> list[Data]:\n \"\"\"Search arXiv and return results.\"\"\"\n try:\n # Build the query URL\n url = self.build_query_url()\n\n # Validate URL scheme and host\n parsed_url = urlparse(url)\n if parsed_url.scheme not in {\"http\", \"https\"}:\n error_msg = f\"Invalid URL scheme: {parsed_url.scheme}\"\n raise ValueError(error_msg)\n if parsed_url.hostname != \"export.arxiv.org\":\n error_msg = f\"Invalid host: {parsed_url.hostname}\"\n raise ValueError(error_msg)\n\n # Create a custom opener that only allows http/https schemes\n class RestrictedHTTPHandler(urllib.request.HTTPHandler):\n def http_open(self, req):\n return super().http_open(req)\n\n class RestrictedHTTPSHandler(urllib.request.HTTPSHandler):\n def https_open(self, req):\n return super().https_open(req)\n\n # Build opener with restricted handlers\n opener = urllib.request.build_opener(RestrictedHTTPHandler, RestrictedHTTPSHandler)\n urllib.request.install_opener(opener)\n\n # Make the request with validated URL using restricted opener\n response = opener.open(url)\n response_text = response.read().decode(\"utf-8\")\n\n # Parse the response\n papers = self.parse_atom_response(response_text)\n\n # Convert to Data objects\n results = [Data(data=paper) for paper in papers]\n self.status = results\n except (urllib.error.URLError, ValueError) as e:\n error_data = Data(data={\"error\": f\"Request error: {e!s}\"})\n self.status = error_data\n return [error_data]\n else:\n return results\n" }, "max_results": { "_input_type": "IntInput", @@ -783,7 +783,7 @@ "show": true, "title_case": false, "type": "code", - "value": "from typing import Any\n\nimport requests\nfrom loguru import logger\n\nfrom langflow.base.models.anthropic_constants import ANTHROPIC_MODELS\nfrom langflow.base.models.model import LCModelComponent\nfrom langflow.field_typing import LanguageModel\nfrom langflow.field_typing.range_spec import RangeSpec\nfrom langflow.io import BoolInput, DropdownInput, IntInput, MessageTextInput, SecretStrInput, SliderInput\nfrom langflow.schema.dotdict import dotdict\n\n\nclass AnthropicModelComponent(LCModelComponent):\n display_name = \"Anthropic\"\n description = \"Generate text using Anthropic Chat&Completion LLMs with prefill support.\"\n icon = \"Anthropic\"\n name = \"AnthropicModel\"\n\n inputs = [\n *LCModelComponent._base_inputs,\n IntInput(\n name=\"max_tokens\",\n display_name=\"Max Tokens\",\n advanced=True,\n value=4096,\n info=\"The maximum number of tokens to generate. Set to 0 for unlimited tokens.\",\n ),\n DropdownInput(\n name=\"model_name\",\n display_name=\"Model Name\",\n options=ANTHROPIC_MODELS,\n refresh_button=True,\n value=ANTHROPIC_MODELS[0],\n combobox=True,\n ),\n SecretStrInput(\n name=\"api_key\",\n display_name=\"Anthropic API Key\",\n info=\"Your Anthropic API key.\",\n value=None,\n required=True,\n real_time_refresh=True,\n ),\n SliderInput(\n name=\"temperature\",\n display_name=\"Temperature\",\n value=0.1,\n info=\"Run inference with this temperature. Must by in the closed interval [0.0, 1.0].\",\n range_spec=RangeSpec(min=0, max=1, step=0.01),\n ),\n MessageTextInput(\n name=\"base_url\",\n display_name=\"Anthropic API URL\",\n info=\"Endpoint of the Anthropic API. Defaults to 'https://api.anthropic.com' if not specified.\",\n value=\"https://api.anthropic.com\",\n real_time_refresh=True,\n ),\n BoolInput(\n name=\"tool_model_enabled\",\n display_name=\"Enable Tool Models\",\n info=(\n \"Select if you want to use models that can work with tools. If yes, only those models will be shown.\"\n ),\n advanced=False,\n value=False,\n real_time_refresh=True,\n ),\n MessageTextInput(\n name=\"prefill\", display_name=\"Prefill\", info=\"Prefill text to guide the model's response.\", advanced=True\n ),\n ]\n\n def build_model(self) -> LanguageModel: # type: ignore[type-var]\n try:\n from langchain_anthropic.chat_models import ChatAnthropic\n except ImportError as e:\n msg = \"langchain_anthropic is not installed. Please install it with `pip install langchain_anthropic`.\"\n raise ImportError(msg) from e\n try:\n output = ChatAnthropic(\n model=self.model_name,\n anthropic_api_key=self.api_key,\n max_tokens_to_sample=self.max_tokens,\n temperature=self.temperature,\n anthropic_api_url=self.base_url,\n streaming=self.stream,\n )\n except Exception as e:\n msg = \"Could not connect to Anthropic API.\"\n raise ValueError(msg) from e\n\n return output\n\n def get_models(self, tool_model_enabled: bool | None = None) -> list[str]:\n try:\n import anthropic\n\n client = anthropic.Anthropic(api_key=self.api_key)\n models = client.models.list(limit=20).data\n model_ids = [model.id for model in models]\n except (ImportError, ValueError, requests.exceptions.RequestException) as e:\n logger.exception(f\"Error getting model names: {e}\")\n model_ids = ANTHROPIC_MODELS\n if tool_model_enabled:\n try:\n from langchain_anthropic.chat_models import ChatAnthropic\n except ImportError as e:\n msg = \"langchain_anthropic is not installed. Please install it with `pip install langchain_anthropic`.\"\n raise ImportError(msg) from e\n for model in model_ids:\n model_with_tool = ChatAnthropic(\n model=self.model_name,\n anthropic_api_key=self.api_key,\n anthropic_api_url=self.base_url,\n )\n if not self.supports_tool_calling(model_with_tool):\n model_ids.remove(model)\n return model_ids\n\n def _get_exception_message(self, exception: Exception) -> str | None:\n \"\"\"Get a message from an Anthropic exception.\n\n Args:\n exception (Exception): The exception to get the message from.\n\n Returns:\n str: The message from the exception.\n \"\"\"\n try:\n from anthropic import BadRequestError\n except ImportError:\n return None\n if isinstance(exception, BadRequestError):\n message = exception.body.get(\"error\", {}).get(\"message\")\n if message:\n return message\n return None\n\n def update_build_config(self, build_config: dotdict, field_value: Any, field_name: str | None = None):\n if field_name in (\"base_url\", \"model_name\", \"tool_model_enabled\", \"api_key\") and field_value:\n try:\n if len(self.api_key) == 0:\n ids = ANTHROPIC_MODELS\n else:\n try:\n ids = self.get_models(tool_model_enabled=self.tool_model_enabled)\n except (ImportError, ValueError, requests.exceptions.RequestException) as e:\n logger.exception(f\"Error getting model names: {e}\")\n ids = ANTHROPIC_MODELS\n build_config[\"model_name\"][\"options\"] = ids\n build_config[\"model_name\"][\"value\"] = ids[0]\n except Exception as e:\n msg = f\"Error getting model names: {e}\"\n raise ValueError(msg) from e\n return build_config\n" + "value": "from typing import Any\n\nimport requests\nfrom loguru import logger\n\nfrom langflow.base.models.anthropic_constants import ANTHROPIC_MODELS\nfrom langflow.base.models.model import LCModelComponent\nfrom langflow.field_typing import LanguageModel\nfrom langflow.field_typing.range_spec import RangeSpec\nfrom langflow.io import BoolInput, DropdownInput, IntInput, MessageTextInput, SecretStrInput, SliderInput\nfrom langflow.schema.dotdict import dotdict\n\n\nclass AnthropicModelComponent(LCModelComponent):\n display_name = \"Anthropic\"\n description = \"Generate text using Anthropic Chat&Completion LLMs with prefill support.\"\n icon = \"Anthropic\"\n name = \"AnthropicModel\"\n\n inputs = [\n *LCModelComponent._base_inputs,\n IntInput(\n name=\"max_tokens\",\n display_name=\"Max Tokens\",\n advanced=True,\n value=4096,\n info=\"The maximum number of tokens to generate. Set to 0 for unlimited tokens.\",\n ),\n DropdownInput(\n name=\"model_name\",\n display_name=\"Model Name\",\n options=ANTHROPIC_MODELS,\n refresh_button=True,\n value=ANTHROPIC_MODELS[0],\n combobox=True,\n ),\n SecretStrInput(\n name=\"api_key\",\n display_name=\"Anthropic API Key\",\n info=\"Your Anthropic API key.\",\n value=None,\n required=True,\n real_time_refresh=True,\n ),\n SliderInput(\n name=\"temperature\",\n display_name=\"Temperature\",\n value=0.1,\n info=\"Run inference with this temperature. Must by in the closed interval [0.0, 1.0].\",\n range_spec=RangeSpec(min=0, max=1, step=0.01),\n ),\n MessageTextInput(\n name=\"base_url\",\n display_name=\"Anthropic API URL\",\n info=\"Endpoint of the Anthropic API. Defaults to 'https://api.anthropic.com' if not specified.\",\n value=\"https://api.anthropic.com\",\n real_time_refresh=True,\n ),\n BoolInput(\n name=\"tool_model_enabled\",\n display_name=\"Enable Tool Models\",\n info=(\n \"Select if you want to use models that can work with tools. If yes, only those models will be shown.\"\n ),\n advanced=False,\n value=False,\n real_time_refresh=True,\n ),\n MessageTextInput(\n name=\"prefill\", display_name=\"Prefill\", info=\"Prefill text to guide the model's response.\", advanced=True\n ),\n ]\n\n def build_model(self) -> LanguageModel: # type: ignore[type-var]\n try:\n from langchain_anthropic.chat_models import ChatAnthropic\n except ImportError as e:\n msg = \"langchain_anthropic is not installed. Please install it with `pip install langchain_anthropic`.\"\n raise ImportError(msg) from e\n try:\n output = ChatAnthropic(\n model=self.model_name,\n anthropic_api_key=self.api_key,\n max_tokens_to_sample=self.max_tokens,\n temperature=self.temperature,\n anthropic_api_url=self.base_url,\n streaming=self.stream,\n )\n except Exception as e:\n msg = \"Could not connect to Anthropic API.\"\n raise ValueError(msg) from e\n\n return output\n\n def get_models(self, tool_model_enabled: bool | None = None) -> list[str]:\n try:\n import anthropic\n\n client = anthropic.Anthropic(api_key=self.api_key)\n models = client.models.list(limit=20).data\n model_ids = [model.id for model in models]\n except (ImportError, ValueError, requests.exceptions.RequestException) as e:\n logger.exception(f\"Error getting model names: {e}\")\n model_ids = ANTHROPIC_MODELS\n if tool_model_enabled:\n try:\n from langchain_anthropic.chat_models import ChatAnthropic\n except ImportError as e:\n msg = \"langchain_anthropic is not installed. Please install it with `pip install langchain_anthropic`.\"\n raise ImportError(msg) from e\n for model in model_ids:\n model_with_tool = ChatAnthropic(\n model=self.model_name,\n anthropic_api_key=self.api_key,\n anthropic_api_url=self.base_url,\n )\n if not self.supports_tool_calling(model_with_tool):\n model_ids.remove(model)\n return model_ids\n\n def _get_exception_message(self, exception: Exception) -> str | None:\n \"\"\"Get a message from an Anthropic exception.\n\n Args:\n exception (Exception): The exception to get the message from.\n\n Returns:\n str: The message from the exception.\n \"\"\"\n try:\n from anthropic import BadRequestError\n except ImportError:\n return None\n if isinstance(exception, BadRequestError):\n message = exception.body.get(\"error\", {}).get(\"message\")\n if message:\n return message\n return None\n\n def update_build_config(self, build_config: dotdict, field_value: Any, field_name: str | None = None):\n if field_name in {\"base_url\", \"model_name\", \"tool_model_enabled\", \"api_key\"} and field_value:\n try:\n if len(self.api_key) == 0:\n ids = ANTHROPIC_MODELS\n else:\n try:\n ids = self.get_models(tool_model_enabled=self.tool_model_enabled)\n except (ImportError, ValueError, requests.exceptions.RequestException) as e:\n logger.exception(f\"Error getting model names: {e}\")\n ids = ANTHROPIC_MODELS\n build_config[\"model_name\"][\"options\"] = ids\n build_config[\"model_name\"][\"value\"] = ids[0]\n except Exception as e:\n msg = f\"Error getting model names: {e}\"\n raise ValueError(msg) from e\n return build_config\n" }, "input_value": { "_input_type": "MessageInput", diff --git a/src/backend/base/langflow/initial_setup/starter_projects/Portfolio Website Code Generator.json b/src/backend/base/langflow/initial_setup/starter_projects/Portfolio Website Code Generator.json index c1e0a944f..670a5e4ee 100644 --- a/src/backend/base/langflow/initial_setup/starter_projects/Portfolio Website Code Generator.json +++ b/src/backend/base/langflow/initial_setup/starter_projects/Portfolio Website Code Generator.json @@ -753,7 +753,7 @@ "show": true, "title_case": false, "type": "code", - "value": "from typing import Any\n\nimport requests\nfrom loguru import logger\n\nfrom langflow.base.models.anthropic_constants import ANTHROPIC_MODELS\nfrom langflow.base.models.model import LCModelComponent\nfrom langflow.field_typing import LanguageModel\nfrom langflow.field_typing.range_spec import RangeSpec\nfrom langflow.io import BoolInput, DropdownInput, IntInput, MessageTextInput, SecretStrInput, SliderInput\nfrom langflow.schema.dotdict import dotdict\n\n\nclass AnthropicModelComponent(LCModelComponent):\n display_name = \"Anthropic\"\n description = \"Generate text using Anthropic Chat&Completion LLMs with prefill support.\"\n icon = \"Anthropic\"\n name = \"AnthropicModel\"\n\n inputs = [\n *LCModelComponent._base_inputs,\n IntInput(\n name=\"max_tokens\",\n display_name=\"Max Tokens\",\n advanced=True,\n value=4096,\n info=\"The maximum number of tokens to generate. Set to 0 for unlimited tokens.\",\n ),\n DropdownInput(\n name=\"model_name\",\n display_name=\"Model Name\",\n options=ANTHROPIC_MODELS,\n refresh_button=True,\n value=ANTHROPIC_MODELS[0],\n combobox=True,\n ),\n SecretStrInput(\n name=\"api_key\",\n display_name=\"Anthropic API Key\",\n info=\"Your Anthropic API key.\",\n value=None,\n required=True,\n real_time_refresh=True,\n ),\n SliderInput(\n name=\"temperature\",\n display_name=\"Temperature\",\n value=0.1,\n info=\"Run inference with this temperature. Must by in the closed interval [0.0, 1.0].\",\n range_spec=RangeSpec(min=0, max=1, step=0.01),\n ),\n MessageTextInput(\n name=\"base_url\",\n display_name=\"Anthropic API URL\",\n info=\"Endpoint of the Anthropic API. Defaults to 'https://api.anthropic.com' if not specified.\",\n value=\"https://api.anthropic.com\",\n real_time_refresh=True,\n ),\n BoolInput(\n name=\"tool_model_enabled\",\n display_name=\"Enable Tool Models\",\n info=(\n \"Select if you want to use models that can work with tools. If yes, only those models will be shown.\"\n ),\n advanced=False,\n value=False,\n real_time_refresh=True,\n ),\n MessageTextInput(\n name=\"prefill\", display_name=\"Prefill\", info=\"Prefill text to guide the model's response.\", advanced=True\n ),\n ]\n\n def build_model(self) -> LanguageModel: # type: ignore[type-var]\n try:\n from langchain_anthropic.chat_models import ChatAnthropic\n except ImportError as e:\n msg = \"langchain_anthropic is not installed. Please install it with `pip install langchain_anthropic`.\"\n raise ImportError(msg) from e\n try:\n output = ChatAnthropic(\n model=self.model_name,\n anthropic_api_key=self.api_key,\n max_tokens_to_sample=self.max_tokens,\n temperature=self.temperature,\n anthropic_api_url=self.base_url,\n streaming=self.stream,\n )\n except Exception as e:\n msg = \"Could not connect to Anthropic API.\"\n raise ValueError(msg) from e\n\n return output\n\n def get_models(self, tool_model_enabled: bool | None = None) -> list[str]:\n try:\n import anthropic\n\n client = anthropic.Anthropic(api_key=self.api_key)\n models = client.models.list(limit=20).data\n model_ids = [model.id for model in models]\n except (ImportError, ValueError, requests.exceptions.RequestException) as e:\n logger.exception(f\"Error getting model names: {e}\")\n model_ids = ANTHROPIC_MODELS\n if tool_model_enabled:\n try:\n from langchain_anthropic.chat_models import ChatAnthropic\n except ImportError as e:\n msg = \"langchain_anthropic is not installed. Please install it with `pip install langchain_anthropic`.\"\n raise ImportError(msg) from e\n for model in model_ids:\n model_with_tool = ChatAnthropic(\n model=self.model_name,\n anthropic_api_key=self.api_key,\n anthropic_api_url=self.base_url,\n )\n if not self.supports_tool_calling(model_with_tool):\n model_ids.remove(model)\n return model_ids\n\n def _get_exception_message(self, exception: Exception) -> str | None:\n \"\"\"Get a message from an Anthropic exception.\n\n Args:\n exception (Exception): The exception to get the message from.\n\n Returns:\n str: The message from the exception.\n \"\"\"\n try:\n from anthropic import BadRequestError\n except ImportError:\n return None\n if isinstance(exception, BadRequestError):\n message = exception.body.get(\"error\", {}).get(\"message\")\n if message:\n return message\n return None\n\n def update_build_config(self, build_config: dotdict, field_value: Any, field_name: str | None = None):\n if field_name in (\"base_url\", \"model_name\", \"tool_model_enabled\", \"api_key\") and field_value:\n try:\n if len(self.api_key) == 0:\n ids = ANTHROPIC_MODELS\n else:\n try:\n ids = self.get_models(tool_model_enabled=self.tool_model_enabled)\n except (ImportError, ValueError, requests.exceptions.RequestException) as e:\n logger.exception(f\"Error getting model names: {e}\")\n ids = ANTHROPIC_MODELS\n build_config[\"model_name\"][\"options\"] = ids\n build_config[\"model_name\"][\"value\"] = ids[0]\n except Exception as e:\n msg = f\"Error getting model names: {e}\"\n raise ValueError(msg) from e\n return build_config\n" + "value": "from typing import Any\n\nimport requests\nfrom loguru import logger\n\nfrom langflow.base.models.anthropic_constants import ANTHROPIC_MODELS\nfrom langflow.base.models.model import LCModelComponent\nfrom langflow.field_typing import LanguageModel\nfrom langflow.field_typing.range_spec import RangeSpec\nfrom langflow.io import BoolInput, DropdownInput, IntInput, MessageTextInput, SecretStrInput, SliderInput\nfrom langflow.schema.dotdict import dotdict\n\n\nclass AnthropicModelComponent(LCModelComponent):\n display_name = \"Anthropic\"\n description = \"Generate text using Anthropic Chat&Completion LLMs with prefill support.\"\n icon = \"Anthropic\"\n name = \"AnthropicModel\"\n\n inputs = [\n *LCModelComponent._base_inputs,\n IntInput(\n name=\"max_tokens\",\n display_name=\"Max Tokens\",\n advanced=True,\n value=4096,\n info=\"The maximum number of tokens to generate. Set to 0 for unlimited tokens.\",\n ),\n DropdownInput(\n name=\"model_name\",\n display_name=\"Model Name\",\n options=ANTHROPIC_MODELS,\n refresh_button=True,\n value=ANTHROPIC_MODELS[0],\n combobox=True,\n ),\n SecretStrInput(\n name=\"api_key\",\n display_name=\"Anthropic API Key\",\n info=\"Your Anthropic API key.\",\n value=None,\n required=True,\n real_time_refresh=True,\n ),\n SliderInput(\n name=\"temperature\",\n display_name=\"Temperature\",\n value=0.1,\n info=\"Run inference with this temperature. Must by in the closed interval [0.0, 1.0].\",\n range_spec=RangeSpec(min=0, max=1, step=0.01),\n ),\n MessageTextInput(\n name=\"base_url\",\n display_name=\"Anthropic API URL\",\n info=\"Endpoint of the Anthropic API. Defaults to 'https://api.anthropic.com' if not specified.\",\n value=\"https://api.anthropic.com\",\n real_time_refresh=True,\n ),\n BoolInput(\n name=\"tool_model_enabled\",\n display_name=\"Enable Tool Models\",\n info=(\n \"Select if you want to use models that can work with tools. If yes, only those models will be shown.\"\n ),\n advanced=False,\n value=False,\n real_time_refresh=True,\n ),\n MessageTextInput(\n name=\"prefill\", display_name=\"Prefill\", info=\"Prefill text to guide the model's response.\", advanced=True\n ),\n ]\n\n def build_model(self) -> LanguageModel: # type: ignore[type-var]\n try:\n from langchain_anthropic.chat_models import ChatAnthropic\n except ImportError as e:\n msg = \"langchain_anthropic is not installed. Please install it with `pip install langchain_anthropic`.\"\n raise ImportError(msg) from e\n try:\n output = ChatAnthropic(\n model=self.model_name,\n anthropic_api_key=self.api_key,\n max_tokens_to_sample=self.max_tokens,\n temperature=self.temperature,\n anthropic_api_url=self.base_url,\n streaming=self.stream,\n )\n except Exception as e:\n msg = \"Could not connect to Anthropic API.\"\n raise ValueError(msg) from e\n\n return output\n\n def get_models(self, tool_model_enabled: bool | None = None) -> list[str]:\n try:\n import anthropic\n\n client = anthropic.Anthropic(api_key=self.api_key)\n models = client.models.list(limit=20).data\n model_ids = [model.id for model in models]\n except (ImportError, ValueError, requests.exceptions.RequestException) as e:\n logger.exception(f\"Error getting model names: {e}\")\n model_ids = ANTHROPIC_MODELS\n if tool_model_enabled:\n try:\n from langchain_anthropic.chat_models import ChatAnthropic\n except ImportError as e:\n msg = \"langchain_anthropic is not installed. Please install it with `pip install langchain_anthropic`.\"\n raise ImportError(msg) from e\n for model in model_ids:\n model_with_tool = ChatAnthropic(\n model=self.model_name,\n anthropic_api_key=self.api_key,\n anthropic_api_url=self.base_url,\n )\n if not self.supports_tool_calling(model_with_tool):\n model_ids.remove(model)\n return model_ids\n\n def _get_exception_message(self, exception: Exception) -> str | None:\n \"\"\"Get a message from an Anthropic exception.\n\n Args:\n exception (Exception): The exception to get the message from.\n\n Returns:\n str: The message from the exception.\n \"\"\"\n try:\n from anthropic import BadRequestError\n except ImportError:\n return None\n if isinstance(exception, BadRequestError):\n message = exception.body.get(\"error\", {}).get(\"message\")\n if message:\n return message\n return None\n\n def update_build_config(self, build_config: dotdict, field_value: Any, field_name: str | None = None):\n if field_name in {\"base_url\", \"model_name\", \"tool_model_enabled\", \"api_key\"} and field_value:\n try:\n if len(self.api_key) == 0:\n ids = ANTHROPIC_MODELS\n else:\n try:\n ids = self.get_models(tool_model_enabled=self.tool_model_enabled)\n except (ImportError, ValueError, requests.exceptions.RequestException) as e:\n logger.exception(f\"Error getting model names: {e}\")\n ids = ANTHROPIC_MODELS\n build_config[\"model_name\"][\"options\"] = ids\n build_config[\"model_name\"][\"value\"] = ids[0]\n except Exception as e:\n msg = f\"Error getting model names: {e}\"\n raise ValueError(msg) from e\n return build_config\n" }, "input_value": { "_input_type": "MessageInput", @@ -1089,7 +1089,7 @@ "show": true, "title_case": false, "type": "code", - "value": "from typing import Any\n\nimport requests\nfrom loguru import logger\n\nfrom langflow.base.models.anthropic_constants import ANTHROPIC_MODELS\nfrom langflow.base.models.model import LCModelComponent\nfrom langflow.field_typing import LanguageModel\nfrom langflow.field_typing.range_spec import RangeSpec\nfrom langflow.io import BoolInput, DropdownInput, IntInput, MessageTextInput, SecretStrInput, SliderInput\nfrom langflow.schema.dotdict import dotdict\n\n\nclass AnthropicModelComponent(LCModelComponent):\n display_name = \"Anthropic\"\n description = \"Generate text using Anthropic Chat&Completion LLMs with prefill support.\"\n icon = \"Anthropic\"\n name = \"AnthropicModel\"\n\n inputs = [\n *LCModelComponent._base_inputs,\n IntInput(\n name=\"max_tokens\",\n display_name=\"Max Tokens\",\n advanced=True,\n value=4096,\n info=\"The maximum number of tokens to generate. Set to 0 for unlimited tokens.\",\n ),\n DropdownInput(\n name=\"model_name\",\n display_name=\"Model Name\",\n options=ANTHROPIC_MODELS,\n refresh_button=True,\n value=ANTHROPIC_MODELS[0],\n combobox=True,\n ),\n SecretStrInput(\n name=\"api_key\",\n display_name=\"Anthropic API Key\",\n info=\"Your Anthropic API key.\",\n value=None,\n required=True,\n real_time_refresh=True,\n ),\n SliderInput(\n name=\"temperature\",\n display_name=\"Temperature\",\n value=0.1,\n info=\"Run inference with this temperature. Must by in the closed interval [0.0, 1.0].\",\n range_spec=RangeSpec(min=0, max=1, step=0.01),\n ),\n MessageTextInput(\n name=\"base_url\",\n display_name=\"Anthropic API URL\",\n info=\"Endpoint of the Anthropic API. Defaults to 'https://api.anthropic.com' if not specified.\",\n value=\"https://api.anthropic.com\",\n real_time_refresh=True,\n ),\n BoolInput(\n name=\"tool_model_enabled\",\n display_name=\"Enable Tool Models\",\n info=(\n \"Select if you want to use models that can work with tools. If yes, only those models will be shown.\"\n ),\n advanced=False,\n value=False,\n real_time_refresh=True,\n ),\n MessageTextInput(\n name=\"prefill\", display_name=\"Prefill\", info=\"Prefill text to guide the model's response.\", advanced=True\n ),\n ]\n\n def build_model(self) -> LanguageModel: # type: ignore[type-var]\n try:\n from langchain_anthropic.chat_models import ChatAnthropic\n except ImportError as e:\n msg = \"langchain_anthropic is not installed. Please install it with `pip install langchain_anthropic`.\"\n raise ImportError(msg) from e\n try:\n output = ChatAnthropic(\n model=self.model_name,\n anthropic_api_key=self.api_key,\n max_tokens_to_sample=self.max_tokens,\n temperature=self.temperature,\n anthropic_api_url=self.base_url,\n streaming=self.stream,\n )\n except Exception as e:\n msg = \"Could not connect to Anthropic API.\"\n raise ValueError(msg) from e\n\n return output\n\n def get_models(self, tool_model_enabled: bool | None = None) -> list[str]:\n try:\n import anthropic\n\n client = anthropic.Anthropic(api_key=self.api_key)\n models = client.models.list(limit=20).data\n model_ids = [model.id for model in models]\n except (ImportError, ValueError, requests.exceptions.RequestException) as e:\n logger.exception(f\"Error getting model names: {e}\")\n model_ids = ANTHROPIC_MODELS\n if tool_model_enabled:\n try:\n from langchain_anthropic.chat_models import ChatAnthropic\n except ImportError as e:\n msg = \"langchain_anthropic is not installed. Please install it with `pip install langchain_anthropic`.\"\n raise ImportError(msg) from e\n for model in model_ids:\n model_with_tool = ChatAnthropic(\n model=self.model_name,\n anthropic_api_key=self.api_key,\n anthropic_api_url=self.base_url,\n )\n if not self.supports_tool_calling(model_with_tool):\n model_ids.remove(model)\n return model_ids\n\n def _get_exception_message(self, exception: Exception) -> str | None:\n \"\"\"Get a message from an Anthropic exception.\n\n Args:\n exception (Exception): The exception to get the message from.\n\n Returns:\n str: The message from the exception.\n \"\"\"\n try:\n from anthropic import BadRequestError\n except ImportError:\n return None\n if isinstance(exception, BadRequestError):\n message = exception.body.get(\"error\", {}).get(\"message\")\n if message:\n return message\n return None\n\n def update_build_config(self, build_config: dotdict, field_value: Any, field_name: str | None = None):\n if field_name in (\"base_url\", \"model_name\", \"tool_model_enabled\", \"api_key\") and field_value:\n try:\n if len(self.api_key) == 0:\n ids = ANTHROPIC_MODELS\n else:\n try:\n ids = self.get_models(tool_model_enabled=self.tool_model_enabled)\n except (ImportError, ValueError, requests.exceptions.RequestException) as e:\n logger.exception(f\"Error getting model names: {e}\")\n ids = ANTHROPIC_MODELS\n build_config[\"model_name\"][\"options\"] = ids\n build_config[\"model_name\"][\"value\"] = ids[0]\n except Exception as e:\n msg = f\"Error getting model names: {e}\"\n raise ValueError(msg) from e\n return build_config\n" + "value": "from typing import Any\n\nimport requests\nfrom loguru import logger\n\nfrom langflow.base.models.anthropic_constants import ANTHROPIC_MODELS\nfrom langflow.base.models.model import LCModelComponent\nfrom langflow.field_typing import LanguageModel\nfrom langflow.field_typing.range_spec import RangeSpec\nfrom langflow.io import BoolInput, DropdownInput, IntInput, MessageTextInput, SecretStrInput, SliderInput\nfrom langflow.schema.dotdict import dotdict\n\n\nclass AnthropicModelComponent(LCModelComponent):\n display_name = \"Anthropic\"\n description = \"Generate text using Anthropic Chat&Completion LLMs with prefill support.\"\n icon = \"Anthropic\"\n name = \"AnthropicModel\"\n\n inputs = [\n *LCModelComponent._base_inputs,\n IntInput(\n name=\"max_tokens\",\n display_name=\"Max Tokens\",\n advanced=True,\n value=4096,\n info=\"The maximum number of tokens to generate. Set to 0 for unlimited tokens.\",\n ),\n DropdownInput(\n name=\"model_name\",\n display_name=\"Model Name\",\n options=ANTHROPIC_MODELS,\n refresh_button=True,\n value=ANTHROPIC_MODELS[0],\n combobox=True,\n ),\n SecretStrInput(\n name=\"api_key\",\n display_name=\"Anthropic API Key\",\n info=\"Your Anthropic API key.\",\n value=None,\n required=True,\n real_time_refresh=True,\n ),\n SliderInput(\n name=\"temperature\",\n display_name=\"Temperature\",\n value=0.1,\n info=\"Run inference with this temperature. Must by in the closed interval [0.0, 1.0].\",\n range_spec=RangeSpec(min=0, max=1, step=0.01),\n ),\n MessageTextInput(\n name=\"base_url\",\n display_name=\"Anthropic API URL\",\n info=\"Endpoint of the Anthropic API. Defaults to 'https://api.anthropic.com' if not specified.\",\n value=\"https://api.anthropic.com\",\n real_time_refresh=True,\n ),\n BoolInput(\n name=\"tool_model_enabled\",\n display_name=\"Enable Tool Models\",\n info=(\n \"Select if you want to use models that can work with tools. If yes, only those models will be shown.\"\n ),\n advanced=False,\n value=False,\n real_time_refresh=True,\n ),\n MessageTextInput(\n name=\"prefill\", display_name=\"Prefill\", info=\"Prefill text to guide the model's response.\", advanced=True\n ),\n ]\n\n def build_model(self) -> LanguageModel: # type: ignore[type-var]\n try:\n from langchain_anthropic.chat_models import ChatAnthropic\n except ImportError as e:\n msg = \"langchain_anthropic is not installed. Please install it with `pip install langchain_anthropic`.\"\n raise ImportError(msg) from e\n try:\n output = ChatAnthropic(\n model=self.model_name,\n anthropic_api_key=self.api_key,\n max_tokens_to_sample=self.max_tokens,\n temperature=self.temperature,\n anthropic_api_url=self.base_url,\n streaming=self.stream,\n )\n except Exception as e:\n msg = \"Could not connect to Anthropic API.\"\n raise ValueError(msg) from e\n\n return output\n\n def get_models(self, tool_model_enabled: bool | None = None) -> list[str]:\n try:\n import anthropic\n\n client = anthropic.Anthropic(api_key=self.api_key)\n models = client.models.list(limit=20).data\n model_ids = [model.id for model in models]\n except (ImportError, ValueError, requests.exceptions.RequestException) as e:\n logger.exception(f\"Error getting model names: {e}\")\n model_ids = ANTHROPIC_MODELS\n if tool_model_enabled:\n try:\n from langchain_anthropic.chat_models import ChatAnthropic\n except ImportError as e:\n msg = \"langchain_anthropic is not installed. Please install it with `pip install langchain_anthropic`.\"\n raise ImportError(msg) from e\n for model in model_ids:\n model_with_tool = ChatAnthropic(\n model=self.model_name,\n anthropic_api_key=self.api_key,\n anthropic_api_url=self.base_url,\n )\n if not self.supports_tool_calling(model_with_tool):\n model_ids.remove(model)\n return model_ids\n\n def _get_exception_message(self, exception: Exception) -> str | None:\n \"\"\"Get a message from an Anthropic exception.\n\n Args:\n exception (Exception): The exception to get the message from.\n\n Returns:\n str: The message from the exception.\n \"\"\"\n try:\n from anthropic import BadRequestError\n except ImportError:\n return None\n if isinstance(exception, BadRequestError):\n message = exception.body.get(\"error\", {}).get(\"message\")\n if message:\n return message\n return None\n\n def update_build_config(self, build_config: dotdict, field_value: Any, field_name: str | None = None):\n if field_name in {\"base_url\", \"model_name\", \"tool_model_enabled\", \"api_key\"} and field_value:\n try:\n if len(self.api_key) == 0:\n ids = ANTHROPIC_MODELS\n else:\n try:\n ids = self.get_models(tool_model_enabled=self.tool_model_enabled)\n except (ImportError, ValueError, requests.exceptions.RequestException) as e:\n logger.exception(f\"Error getting model names: {e}\")\n ids = ANTHROPIC_MODELS\n build_config[\"model_name\"][\"options\"] = ids\n build_config[\"model_name\"][\"value\"] = ids[0]\n except Exception as e:\n msg = f\"Error getting model names: {e}\"\n raise ValueError(msg) from e\n return build_config\n" }, "input_value": { "_input_type": "MessageInput", diff --git a/src/backend/base/langflow/initial_setup/starter_projects/Vector Store RAG.json b/src/backend/base/langflow/initial_setup/starter_projects/Vector Store RAG.json index bb745e2e7..b71b500d0 100644 --- a/src/backend/base/langflow/initial_setup/starter_projects/Vector Store RAG.json +++ b/src/backend/base/langflow/initial_setup/starter_projects/Vector Store RAG.json @@ -3385,7 +3385,7 @@ "show": true, "title_case": false, "type": "code", - "value": "from collections import defaultdict\nfrom dataclasses import asdict, dataclass, field\n\nfrom astrapy import AstraDBAdmin, DataAPIClient, Database\nfrom astrapy.info import CollectionDescriptor\nfrom langchain_astradb import AstraDBVectorStore, CollectionVectorServiceOptions\n\nfrom langflow.base.vectorstores.model import LCVectorStoreComponent, check_cached_vector_store\nfrom langflow.helpers import docs_to_data\nfrom langflow.inputs import FloatInput, NestedDictInput\nfrom langflow.io import (\n BoolInput,\n DropdownInput,\n HandleInput,\n IntInput,\n SecretStrInput,\n StrInput,\n)\nfrom langflow.schema import Data\nfrom langflow.utils.version import get_version_info\n\n\nclass AstraDBVectorStoreComponent(LCVectorStoreComponent):\n display_name: str = \"Astra DB\"\n description: str = \"Ingest and search documents in Astra DB\"\n documentation: str = \"https://docs.datastax.com/en/langflow/astra-components.html\"\n name = \"AstraDB\"\n icon: str = \"AstraDB\"\n\n _cached_vector_store: AstraDBVectorStore | None = None\n\n @dataclass\n class NewDatabaseInput:\n functionality: str = \"create\"\n fields: dict[str, dict] = field(\n default_factory=lambda: {\n \"data\": {\n \"node\": {\n \"name\": \"create_database\",\n \"description\": \"\",\n \"display_name\": \"Create new database\",\n \"field_order\": [\"new_database_name\", \"cloud_provider\", \"region\"],\n \"template\": {\n \"new_database_name\": StrInput(\n name=\"new_database_name\",\n display_name=\"Name\",\n info=\"Name of the new database to create in Astra DB.\",\n required=True,\n ),\n \"cloud_provider\": DropdownInput(\n name=\"cloud_provider\",\n display_name=\"Cloud provider\",\n info=\"Cloud provider for the new database.\",\n options=[\"Amazon Web Services\", \"Google Cloud Platform\", \"Microsoft Azure\"],\n required=True,\n real_time_refresh=True,\n ),\n \"region\": DropdownInput(\n name=\"region\",\n display_name=\"Region\",\n info=\"Region for the new database.\",\n options=[],\n required=True,\n ),\n },\n },\n }\n }\n )\n\n @dataclass\n class NewCollectionInput:\n functionality: str = \"create\"\n fields: dict[str, dict] = field(\n default_factory=lambda: {\n \"data\": {\n \"node\": {\n \"name\": \"create_collection\",\n \"description\": \"\",\n \"display_name\": \"Create new collection\",\n \"field_order\": [\n \"new_collection_name\",\n \"embedding_generation_provider\",\n \"embedding_generation_model\",\n \"dimension\",\n ],\n \"template\": {\n \"new_collection_name\": StrInput(\n name=\"new_collection_name\",\n display_name=\"Name\",\n info=\"Name of the new collection to create in Astra DB.\",\n required=True,\n ),\n \"embedding_generation_provider\": DropdownInput(\n name=\"embedding_generation_provider\",\n display_name=\"Embedding generation method\",\n info=\"Provider to use for generating embeddings.\",\n real_time_refresh=True,\n required=True,\n options=[\"Bring your own\", \"Nvidia\"],\n ),\n \"embedding_generation_model\": DropdownInput(\n name=\"embedding_generation_model\",\n display_name=\"Embedding model\",\n info=\"Model to use for generating embeddings.\",\n required=True,\n options=[],\n ),\n \"dimension\": IntInput(\n name=\"dimension\",\n display_name=\"Dimensions (Required only for `Bring your own`)\",\n info=\"Dimensions of the embeddings to generate.\",\n required=False,\n value=1024,\n ),\n },\n },\n }\n }\n )\n\n inputs = [\n SecretStrInput(\n name=\"token\",\n display_name=\"Astra DB Application Token\",\n info=\"Authentication token for accessing Astra DB.\",\n value=\"ASTRA_DB_APPLICATION_TOKEN\",\n required=True,\n real_time_refresh=True,\n input_types=[],\n ),\n StrInput(\n name=\"environment\",\n display_name=\"Environment\",\n info=\"The environment for the Astra DB API Endpoint.\",\n advanced=True,\n real_time_refresh=True,\n ),\n DropdownInput(\n name=\"database_name\",\n display_name=\"Database\",\n info=\"The Database name for the Astra DB instance.\",\n required=True,\n refresh_button=True,\n real_time_refresh=True,\n dialog_inputs=asdict(NewDatabaseInput()),\n combobox=True,\n ),\n StrInput(\n name=\"api_endpoint\",\n display_name=\"Astra DB API Endpoint\",\n info=\"The API Endpoint for the Astra DB instance. Supercedes database selection.\",\n advanced=True,\n ),\n DropdownInput(\n name=\"collection_name\",\n display_name=\"Collection\",\n info=\"The name of the collection within Astra DB where the vectors will be stored.\",\n required=True,\n refresh_button=True,\n real_time_refresh=True,\n dialog_inputs=asdict(NewCollectionInput()),\n combobox=True,\n advanced=True,\n ),\n StrInput(\n name=\"keyspace\",\n display_name=\"Keyspace\",\n info=\"Optional keyspace within Astra DB to use for the collection.\",\n advanced=True,\n ),\n DropdownInput(\n name=\"embedding_choice\",\n display_name=\"Embedding Model or Astra Vectorize\",\n info=\"Choose an embedding model or use Astra Vectorize.\",\n options=[\"Embedding Model\", \"Astra Vectorize\"],\n value=\"Embedding Model\",\n advanced=True,\n real_time_refresh=True,\n ),\n HandleInput(\n name=\"embedding_model\",\n display_name=\"Embedding Model\",\n input_types=[\"Embeddings\"],\n info=\"Specify the Embedding Model. Not required for Astra Vectorize collections.\",\n required=False,\n ),\n *LCVectorStoreComponent.inputs,\n IntInput(\n name=\"number_of_results\",\n display_name=\"Number of Search Results\",\n info=\"Number of search results to return.\",\n advanced=True,\n value=4,\n ),\n DropdownInput(\n name=\"search_type\",\n display_name=\"Search Type\",\n info=\"Search type to use\",\n options=[\"Similarity\", \"Similarity with score threshold\", \"MMR (Max Marginal Relevance)\"],\n value=\"Similarity\",\n advanced=True,\n ),\n FloatInput(\n name=\"search_score_threshold\",\n display_name=\"Search Score Threshold\",\n info=\"Minimum similarity score threshold for search results. \"\n \"(when using 'Similarity with score threshold')\",\n value=0,\n advanced=True,\n ),\n NestedDictInput(\n name=\"advanced_search_filter\",\n display_name=\"Search Metadata Filter\",\n info=\"Optional dictionary of filters to apply to the search query.\",\n advanced=True,\n ),\n BoolInput(\n name=\"autodetect_collection\",\n display_name=\"Autodetect Collection\",\n info=\"Boolean flag to determine whether to autodetect the collection.\",\n advanced=True,\n value=True,\n ),\n StrInput(\n name=\"content_field\",\n display_name=\"Content Field\",\n info=\"Field to use as the text content field for the vector store.\",\n advanced=True,\n ),\n StrInput(\n name=\"deletion_field\",\n display_name=\"Deletion Based On Field\",\n info=\"When this parameter is provided, documents in the target collection with \"\n \"metadata field values matching the input metadata field value will be deleted \"\n \"before new data is loaded.\",\n advanced=True,\n ),\n BoolInput(\n name=\"ignore_invalid_documents\",\n display_name=\"Ignore Invalid Documents\",\n info=\"Boolean flag to determine whether to ignore invalid documents at runtime.\",\n advanced=True,\n ),\n NestedDictInput(\n name=\"astradb_vectorstore_kwargs\",\n display_name=\"AstraDBVectorStore Parameters\",\n info=\"Optional dictionary of additional parameters for the AstraDBVectorStore.\",\n advanced=True,\n ),\n ]\n\n @classmethod\n def map_cloud_providers(cls):\n # TODO: Programmatically fetch the regions for each cloud provider\n return {\n \"Amazon Web Services\": {\n \"id\": \"aws\",\n \"regions\": [\"us-east-2\", \"ap-south-1\", \"eu-west-1\"],\n },\n \"Google Cloud Platform\": {\n \"id\": \"gcp\",\n \"regions\": [\"us-east1\"],\n },\n \"Microsoft Azure\": {\n \"id\": \"azure\",\n \"regions\": [\"westus3\"],\n },\n }\n\n @classmethod\n def get_vectorize_providers(cls, token: str, environment: str | None = None, api_endpoint: str | None = None):\n try:\n # Get the admin object\n admin = AstraDBAdmin(token=token, environment=environment)\n db_admin = admin.get_database_admin(api_endpoint=api_endpoint)\n\n # Get the list of embedding providers\n embedding_providers = db_admin.find_embedding_providers().as_dict()\n\n vectorize_providers_mapping = {}\n # Map the provider display name to the provider key and models\n for provider_key, provider_data in embedding_providers[\"embeddingProviders\"].items():\n # Get the provider display name and models\n display_name = provider_data[\"displayName\"]\n models = [model[\"name\"] for model in provider_data[\"models\"]]\n\n # Build our mapping\n vectorize_providers_mapping[display_name] = [provider_key, models]\n\n # Sort the resulting dictionary\n return defaultdict(list, dict(sorted(vectorize_providers_mapping.items())))\n except Exception as e:\n msg = f\"Error fetching vectorize providers: {e}\"\n raise ValueError(msg) from e\n\n @classmethod\n async def create_database_api(\n cls,\n new_database_name: str,\n cloud_provider: str,\n region: str,\n token: str,\n environment: str | None = None,\n keyspace: str | None = None,\n ):\n client = DataAPIClient(token=token, environment=environment)\n\n # Get the admin object\n admin_client = client.get_admin(token=token)\n\n # Call the create database function\n return await admin_client.async_create_database(\n name=new_database_name,\n cloud_provider=cls.map_cloud_providers()[cloud_provider][\"id\"],\n region=region,\n keyspace=keyspace,\n wait_until_active=False,\n )\n\n @classmethod\n async def create_collection_api(\n cls,\n new_collection_name: str,\n token: str,\n api_endpoint: str,\n environment: str | None = None,\n keyspace: str | None = None,\n dimension: int | None = None,\n embedding_generation_provider: str | None = None,\n embedding_generation_model: str | None = None,\n ):\n # Create the data API client\n client = DataAPIClient(token=token, environment=environment)\n\n # Get the database object\n database = client.get_async_database(api_endpoint=api_endpoint, token=token)\n\n # Build vectorize options, if needed\n vectorize_options = None\n if not dimension:\n vectorize_options = CollectionVectorServiceOptions(\n provider=cls.get_vectorize_providers(\n token=token, environment=environment, api_endpoint=api_endpoint\n ).get(embedding_generation_provider, [None, []])[0],\n model_name=embedding_generation_model,\n )\n\n # Create the collection\n return await database.create_collection(\n name=new_collection_name,\n keyspace=keyspace,\n dimension=dimension,\n service=vectorize_options,\n )\n\n @classmethod\n def get_database_list_static(cls, token: str, environment: str | None = None):\n client = DataAPIClient(token=token, environment=environment)\n\n # Get the admin object\n admin_client = client.get_admin(token=token)\n\n # Get the list of databases\n db_list = list(admin_client.list_databases())\n\n # Set the environment properly\n env_string = \"\"\n if environment and environment != \"prod\":\n env_string = f\"-{environment}\"\n\n # Generate the api endpoint for each database\n db_info_dict = {}\n for db in db_list:\n try:\n # Get the API endpoint for the database\n api_endpoint = f\"https://{db.info.id}-{db.info.region}.apps.astra{env_string}.datastax.com\"\n\n # Get the number of collections\n try:\n num_collections = len(\n list(\n client.get_database(\n api_endpoint=api_endpoint, token=token, keyspace=db.info.keyspace\n ).list_collection_names(keyspace=db.info.keyspace)\n )\n )\n except Exception: # noqa: BLE001\n num_collections = 0\n if db.status != \"PENDING\":\n continue\n\n # Add the database to the dictionary\n db_info_dict[db.info.name] = {\n \"api_endpoint\": api_endpoint,\n \"collections\": num_collections,\n \"status\": db.status if db.status != \"ACTIVE\" else None,\n }\n except Exception: # noqa: BLE001, S110\n pass\n\n return db_info_dict\n\n def get_database_list(self):\n return self.get_database_list_static(token=self.token, environment=self.environment)\n\n @classmethod\n def get_api_endpoint_static(\n cls,\n token: str,\n environment: str | None = None,\n api_endpoint: str | None = None,\n database_name: str | None = None,\n ):\n # If the api_endpoint is set, return it\n if api_endpoint:\n return api_endpoint\n\n # Check if the database_name is like a url\n if database_name and database_name.startswith(\"https://\"):\n return database_name\n\n # If the database is not set, nothing we can do.\n if not database_name:\n return None\n\n # Grab the database object\n db = cls.get_database_list_static(token=token, environment=environment).get(database_name)\n if not db:\n return None\n\n # Otherwise, get the URL from the database list\n return db.get(\"api_endpoint\")\n\n def get_api_endpoint(self):\n return self.get_api_endpoint_static(\n token=self.token,\n environment=self.environment,\n api_endpoint=self.api_endpoint,\n database_name=self.database_name,\n )\n\n def get_keyspace(self):\n keyspace = self.keyspace\n\n if keyspace:\n return keyspace.strip()\n\n return None\n\n def get_database_object(self, api_endpoint: str | None = None):\n try:\n client = DataAPIClient(token=self.token, environment=self.environment)\n\n return client.get_database(\n api_endpoint=api_endpoint or self.get_api_endpoint(),\n token=self.token,\n keyspace=self.get_keyspace(),\n )\n except Exception as e:\n msg = f\"Error fetching database object: {e}\"\n raise ValueError(msg) from e\n\n def collection_data(self, collection_name: str, database: Database | None = None):\n try:\n if not database:\n client = DataAPIClient(token=self.token, environment=self.environment)\n\n database = client.get_database(\n api_endpoint=self.get_api_endpoint(),\n token=self.token,\n keyspace=self.get_keyspace(),\n )\n\n collection = database.get_collection(collection_name, keyspace=self.get_keyspace())\n\n return collection.estimated_document_count()\n except Exception as e: # noqa: BLE001\n self.log(f\"Error checking collection data: {e}\")\n\n return None\n\n def _initialize_database_options(self):\n try:\n return [\n {\n \"name\": name,\n \"status\": info[\"status\"],\n \"collections\": info[\"collections\"],\n \"api_endpoint\": info[\"api_endpoint\"],\n \"icon\": \"data\",\n }\n for name, info in self.get_database_list().items()\n ]\n except Exception as e:\n msg = f\"Error fetching database options: {e}\"\n raise ValueError(msg) from e\n\n @classmethod\n def get_provider_icon(cls, collection: CollectionDescriptor | None = None, provider_name: str | None = None) -> str:\n # Get the provider name from the collection\n provider_name = provider_name or (\n collection.options.vector.service.provider\n if collection and collection.options and collection.options.vector and collection.options.vector.service\n else None\n )\n\n # If there is no provider, use the vector store icon\n if not provider_name or provider_name == \"bring your own\":\n return \"vectorstores\"\n\n # Special case for certain models\n # TODO: Add more icons\n if provider_name == \"nvidia\":\n return \"NVIDIA\"\n if provider_name == \"openai\":\n return \"OpenAI\"\n\n # Title case on the provider for the icon if no special case\n return provider_name.title()\n\n def _initialize_collection_options(self, api_endpoint: str | None = None):\n # Nothing to generate if we don't have an API endpoint yet\n api_endpoint = api_endpoint or self.get_api_endpoint()\n if not api_endpoint:\n return []\n\n # Retrieve the database object\n database = self.get_database_object(api_endpoint=api_endpoint)\n\n # Get the list of collections\n collection_list = list(database.list_collections(keyspace=self.get_keyspace()))\n\n # Return the list of collections and metadata associated\n return [\n {\n \"name\": col.name,\n \"records\": self.collection_data(collection_name=col.name, database=database),\n \"provider\": (\n col.options.vector.service.provider if col.options.vector and col.options.vector.service else None\n ),\n \"icon\": self.get_provider_icon(collection=col),\n \"model\": (\n col.options.vector.service.model_name if col.options.vector and col.options.vector.service else None\n ),\n }\n for col in collection_list\n ]\n\n def reset_provider_options(self, build_config: dict):\n # Get the list of vectorize providers\n vectorize_providers = self.get_vectorize_providers(\n token=self.token,\n environment=self.environment,\n api_endpoint=build_config[\"api_endpoint\"][\"value\"],\n )\n\n # Append a special case for Bring your own\n vectorize_providers[\"Bring your own\"] = [None, [\"Bring your own\"]]\n\n # If the collection is set, allow user to see embedding options\n build_config[\"collection_name\"][\"dialog_inputs\"][\"fields\"][\"data\"][\"node\"][\"template\"][\n \"embedding_generation_provider\"\n ][\"options\"] = [\"Bring your own\", \"Nvidia\", *[key for key in vectorize_providers if key != \"Nvidia\"]]\n\n # For all not Bring your own or Nvidia providers, add metadata saying configure in Astra DB Portal\n provider_options = build_config[\"collection_name\"][\"dialog_inputs\"][\"fields\"][\"data\"][\"node\"][\"template\"][\n \"embedding_generation_provider\"\n ][\"options\"]\n\n # Go over each possible provider and add metadata to configure in Astra DB Portal\n for provider in provider_options:\n # Skip Bring your own and Nvidia, automatically configured\n if provider in [\"Bring your own\", \"Nvidia\"]:\n build_config[\"collection_name\"][\"dialog_inputs\"][\"fields\"][\"data\"][\"node\"][\"template\"][\n \"embedding_generation_provider\"\n ][\"options_metadata\"].append({\"icon\": self.get_provider_icon(provider_name=provider.lower())})\n continue\n\n # Add metadata to configure in Astra DB Portal\n build_config[\"collection_name\"][\"dialog_inputs\"][\"fields\"][\"data\"][\"node\"][\"template\"][\n \"embedding_generation_provider\"\n ][\"options_metadata\"].append({\" \": \"Configure in Astra DB Portal\"})\n\n # And allow the user to see the models based on a selected provider\n embedding_provider = build_config[\"collection_name\"][\"dialog_inputs\"][\"fields\"][\"data\"][\"node\"][\"template\"][\n \"embedding_generation_provider\"\n ][\"value\"]\n\n # Set the options for the embedding model based on the provider\n build_config[\"collection_name\"][\"dialog_inputs\"][\"fields\"][\"data\"][\"node\"][\"template\"][\n \"embedding_generation_model\"\n ][\"options\"] = vectorize_providers.get(embedding_provider, [[], []])[1]\n\n return build_config\n\n def reset_collection_list(self, build_config: dict):\n # Get the list of options we have based on the token provided\n collection_options = self._initialize_collection_options(api_endpoint=build_config[\"api_endpoint\"][\"value\"])\n\n # If we retrieved options based on the token, show the dropdown\n build_config[\"collection_name\"][\"options\"] = [col[\"name\"] for col in collection_options]\n build_config[\"collection_name\"][\"options_metadata\"] = [\n {k: v for k, v in col.items() if k not in [\"name\"]} for col in collection_options\n ]\n\n # Reset the selected collection\n if build_config[\"collection_name\"][\"value\"] not in build_config[\"collection_name\"][\"options\"]:\n build_config[\"collection_name\"][\"value\"] = \"\"\n\n # If we have a database, collection name should not be advanced\n build_config[\"collection_name\"][\"advanced\"] = not build_config[\"database_name\"][\"value\"]\n\n return build_config\n\n def reset_database_list(self, build_config: dict):\n # Get the list of options we have based on the token provided\n database_options = self._initialize_database_options()\n\n # If we retrieved options based on the token, show the dropdown\n build_config[\"database_name\"][\"options\"] = [db[\"name\"] for db in database_options]\n build_config[\"database_name\"][\"options_metadata\"] = [\n {k: v for k, v in db.items() if k not in [\"name\"]} for db in database_options\n ]\n\n # Reset the selected database\n if build_config[\"database_name\"][\"value\"] not in build_config[\"database_name\"][\"options\"]:\n build_config[\"database_name\"][\"value\"] = \"\"\n build_config[\"api_endpoint\"][\"value\"] = \"\"\n build_config[\"collection_name\"][\"advanced\"] = True\n\n # If we have a token, database name should not be advanced\n build_config[\"database_name\"][\"advanced\"] = not build_config[\"token\"][\"value\"]\n\n return build_config\n\n def reset_build_config(self, build_config: dict):\n # Reset the list of databases we have based on the token provided\n build_config[\"database_name\"][\"options\"] = []\n build_config[\"database_name\"][\"options_metadata\"] = []\n build_config[\"database_name\"][\"value\"] = \"\"\n build_config[\"database_name\"][\"advanced\"] = True\n build_config[\"api_endpoint\"][\"value\"] = \"\"\n\n # Reset the list of collections and metadata associated\n build_config[\"collection_name\"][\"options\"] = []\n build_config[\"collection_name\"][\"options_metadata\"] = []\n build_config[\"collection_name\"][\"value\"] = \"\"\n build_config[\"collection_name\"][\"advanced\"] = True\n\n return build_config\n\n async def update_build_config(self, build_config: dict, field_value: str, field_name: str | None = None):\n # Callback for database creation\n if field_name == \"database_name\" and isinstance(field_value, dict) and \"new_database_name\" in field_value:\n try:\n await self.create_database_api(\n new_database_name=field_value[\"new_database_name\"],\n token=self.token,\n keyspace=self.get_keyspace(),\n environment=self.environment,\n cloud_provider=field_value[\"cloud_provider\"],\n region=field_value[\"region\"],\n )\n except Exception as e:\n msg = f\"Error creating database: {e}\"\n raise ValueError(msg) from e\n\n # Add the new database to the list of options\n build_config[\"database_name\"][\"options\"] = build_config[\"database_name\"][\"options\"] + [\n field_value[\"new_database_name\"]\n ]\n build_config[\"database_name\"][\"options_metadata\"] = build_config[\"database_name\"][\"options_metadata\"] + [\n {\"status\": \"PENDING\"}\n ]\n\n return self.reset_collection_list(build_config)\n\n # This is the callback required to update the list of regions for a cloud provider\n if field_name == \"database_name\" and isinstance(field_value, dict) and \"new_database_name\" not in field_value:\n cloud_provider = field_value[\"cloud_provider\"]\n build_config[\"database_name\"][\"dialog_inputs\"][\"fields\"][\"data\"][\"node\"][\"template\"][\"region\"][\n \"options\"\n ] = self.map_cloud_providers()[cloud_provider][\"regions\"]\n\n return build_config\n\n # Callback for the creation of collections\n if field_name == \"collection_name\" and isinstance(field_value, dict) and \"new_collection_name\" in field_value:\n try:\n # Get the dimension if its a BYO provider\n dimension = (\n field_value[\"dimension\"]\n if field_value[\"embedding_generation_provider\"] == \"Bring your own\"\n else None\n )\n\n # Create the collection\n await self.create_collection_api(\n new_collection_name=field_value[\"new_collection_name\"],\n token=self.token,\n api_endpoint=build_config[\"api_endpoint\"][\"value\"],\n environment=self.environment,\n keyspace=self.get_keyspace(),\n dimension=dimension,\n embedding_generation_provider=field_value[\"embedding_generation_provider\"],\n embedding_generation_model=field_value[\"embedding_generation_model\"],\n )\n except Exception as e:\n msg = f\"Error creating collection: {e}\"\n raise ValueError(msg) from e\n\n # Add the new collection to the list of options\n build_config[\"collection_name\"][\"value\"] = field_value[\"new_collection_name\"]\n build_config[\"collection_name\"][\"options\"].append(field_value[\"new_collection_name\"])\n\n # Get the provider and model for the new collection\n generation_provider = field_value[\"embedding_generation_provider\"]\n provider = generation_provider if generation_provider != \"Bring your own\" else None\n generation_model = field_value[\"embedding_generation_model\"]\n model = generation_model if generation_model and generation_model != \"Bring your own\" else None\n\n # Set the embedding choice\n build_config[\"embedding_choice\"][\"value\"] = \"Astra Vectorize\" if provider else \"Embedding Model\"\n build_config[\"embedding_model\"][\"advanced\"] = bool(provider)\n\n # Add the new collection to the list of options\n icon = \"NVIDIA\" if provider == \"Nvidia\" else \"vectorstores\"\n build_config[\"collection_name\"][\"options_metadata\"] = build_config[\"collection_name\"][\n \"options_metadata\"\n ] + [{\"records\": 0, \"provider\": provider, \"icon\": icon, \"model\": model}]\n\n return build_config\n\n # Callback to update the model list based on the embedding provider\n if (\n field_name == \"collection_name\"\n and isinstance(field_value, dict)\n and \"new_collection_name\" not in field_value\n ):\n return self.reset_provider_options(build_config)\n\n # When the component first executes, this is the update refresh call\n first_run = field_name == \"collection_name\" and not field_value and not build_config[\"database_name\"][\"options\"]\n\n # If the token has not been provided, simply return the empty build config\n if not self.token:\n return self.reset_build_config(build_config)\n\n # If this is the first execution of the component, reset and build database list\n if first_run or field_name in [\"token\", \"environment\"]:\n return self.reset_database_list(build_config)\n\n # Refresh the collection name options\n if field_name == \"database_name\" and not isinstance(field_value, dict):\n # If missing, refresh the database options\n if field_value not in build_config[\"database_name\"][\"options\"]:\n build_config = await self.update_build_config(build_config, field_value=self.token, field_name=\"token\")\n build_config[\"database_name\"][\"value\"] = \"\"\n else:\n # Find the position of the selected database to align with metadata\n index_of_name = build_config[\"database_name\"][\"options\"].index(field_value)\n\n # Initializing database condition\n pending = build_config[\"database_name\"][\"options_metadata\"][index_of_name][\"status\"] == \"PENDING\"\n if pending:\n return self.update_build_config(build_config, field_value=self.token, field_name=\"token\")\n\n # Set the API endpoint based on the selected database\n build_config[\"api_endpoint\"][\"value\"] = build_config[\"database_name\"][\"options_metadata\"][\n index_of_name\n ][\"api_endpoint\"]\n\n # Reset the provider options\n build_config = self.reset_provider_options(build_config)\n\n # Reset the list of collections we have based on the token provided\n return self.reset_collection_list(build_config)\n\n # Hide embedding model option if opriona_metadata provider is not null\n if field_name == \"collection_name\" and not isinstance(field_value, dict):\n # Assume we will be autodetecting the collection:\n build_config[\"autodetect_collection\"][\"value\"] = True\n\n # Reload the collection list\n build_config = self.reset_collection_list(build_config)\n\n # Set the options for collection name to be the field value if its a new collection\n if field_value and field_value not in build_config[\"collection_name\"][\"options\"]:\n # Add the new collection to the list of options\n build_config[\"collection_name\"][\"options\"].append(field_value)\n build_config[\"collection_name\"][\"options_metadata\"].append(\n {\"records\": 0, \"provider\": None, \"icon\": \"\", \"model\": None}\n )\n\n # Ensure that autodetect collection is set to False, since its a new collection\n build_config[\"autodetect_collection\"][\"value\"] = False\n\n # If nothing is selected, can't detect provider - return\n if not field_value:\n return build_config\n\n # Find the position of the selected collection to align with metadata\n index_of_name = build_config[\"collection_name\"][\"options\"].index(field_value)\n value_of_provider = build_config[\"collection_name\"][\"options_metadata\"][index_of_name][\"provider\"]\n\n # If we were able to determine the Vectorize provider, set it accordingly\n if value_of_provider:\n build_config[\"embedding_model\"][\"advanced\"] = True\n build_config[\"embedding_choice\"][\"value\"] = \"Astra Vectorize\"\n else:\n build_config[\"embedding_model\"][\"advanced\"] = False\n build_config[\"embedding_choice\"][\"value\"] = \"Embedding Model\"\n\n return build_config\n\n return build_config\n\n @check_cached_vector_store\n def build_vector_store(self):\n try:\n from langchain_astradb import AstraDBVectorStore\n except ImportError as e:\n msg = (\n \"Could not import langchain Astra DB integration package. \"\n \"Please install it with `pip install langchain-astradb`.\"\n )\n raise ImportError(msg) from e\n\n # Get the embedding model and additional params\n embedding_params = (\n {\"embedding\": self.embedding_model}\n if self.embedding_model and self.embedding_choice == \"Embedding Model\"\n else {}\n )\n\n # Get the additional parameters\n additional_params = self.astradb_vectorstore_kwargs or {}\n\n # Get Langflow version and platform information\n __version__ = get_version_info()[\"version\"]\n langflow_prefix = \"\"\n # if os.getenv(\"AWS_EXECUTION_ENV\") == \"AWS_ECS_FARGATE\": # TODO: More precise way of detecting\n # langflow_prefix = \"ds-\"\n\n # Get the database object\n database = self.get_database_object()\n autodetect = self.collection_name in database.list_collection_names() and self.autodetect_collection\n\n # Bundle up the auto-detect parameters\n autodetect_params = {\n \"autodetect_collection\": autodetect,\n \"content_field\": (\n self.content_field\n if self.content_field and embedding_params\n else (\n \"page_content\"\n if embedding_params\n and self.collection_data(collection_name=self.collection_name, database=database) == 0\n else None\n )\n ),\n \"ignore_invalid_documents\": self.ignore_invalid_documents,\n }\n\n # Attempt to build the Vector Store object\n try:\n vector_store = AstraDBVectorStore(\n # Astra DB Authentication Parameters\n token=self.token,\n api_endpoint=database.api_endpoint,\n namespace=database.keyspace,\n collection_name=self.collection_name,\n environment=self.environment,\n # Astra DB Usage Tracking Parameters\n ext_callers=[(f\"{langflow_prefix}langflow\", __version__)],\n # Astra DB Vector Store Parameters\n **autodetect_params,\n **embedding_params,\n **additional_params,\n )\n except Exception as e:\n msg = f\"Error initializing AstraDBVectorStore: {e}\"\n raise ValueError(msg) from e\n\n # Add documents to the vector store\n self._add_documents_to_vector_store(vector_store)\n\n return vector_store\n\n def _add_documents_to_vector_store(self, vector_store) -> None:\n documents = []\n for _input in self.ingest_data or []:\n if isinstance(_input, Data):\n documents.append(_input.to_lc_document())\n else:\n msg = \"Vector Store Inputs must be Data objects.\"\n raise TypeError(msg)\n\n if documents and self.deletion_field:\n self.log(f\"Deleting documents where {self.deletion_field}\")\n try:\n database = self.get_database_object()\n collection = database.get_collection(self.collection_name, keyspace=database.keyspace)\n delete_values = list({doc.metadata[self.deletion_field] for doc in documents})\n self.log(f\"Deleting documents where {self.deletion_field} matches {delete_values}.\")\n collection.delete_many({f\"metadata.{self.deletion_field}\": {\"$in\": delete_values}})\n except Exception as e:\n msg = f\"Error deleting documents from AstraDBVectorStore based on '{self.deletion_field}': {e}\"\n raise ValueError(msg) from e\n\n if documents:\n self.log(f\"Adding {len(documents)} documents to the Vector Store.\")\n try:\n vector_store.add_documents(documents)\n except Exception as e:\n msg = f\"Error adding documents to AstraDBVectorStore: {e}\"\n raise ValueError(msg) from e\n else:\n self.log(\"No documents to add to the Vector Store.\")\n\n def _map_search_type(self) -> str:\n search_type_mapping = {\n \"Similarity with score threshold\": \"similarity_score_threshold\",\n \"MMR (Max Marginal Relevance)\": \"mmr\",\n }\n\n return search_type_mapping.get(self.search_type, \"similarity\")\n\n def _build_search_args(self):\n query = self.search_query if isinstance(self.search_query, str) and self.search_query.strip() else None\n\n if query:\n args = {\n \"query\": query,\n \"search_type\": self._map_search_type(),\n \"k\": self.number_of_results,\n \"score_threshold\": self.search_score_threshold,\n }\n elif self.advanced_search_filter:\n args = {\n \"n\": self.number_of_results,\n }\n else:\n return {}\n\n filter_arg = self.advanced_search_filter or {}\n if filter_arg:\n args[\"filter\"] = filter_arg\n\n return args\n\n def search_documents(self, vector_store=None) -> list[Data]:\n vector_store = vector_store or self.build_vector_store()\n\n self.log(f\"Search input: {self.search_query}\")\n self.log(f\"Search type: {self.search_type}\")\n self.log(f\"Number of results: {self.number_of_results}\")\n\n try:\n search_args = self._build_search_args()\n except Exception as e:\n msg = f\"Error in AstraDBVectorStore._build_search_args: {e}\"\n raise ValueError(msg) from e\n\n if not search_args:\n self.log(\"No search input or filters provided. Skipping search.\")\n return []\n\n docs = []\n search_method = \"search\" if \"query\" in search_args else \"metadata_search\"\n\n try:\n self.log(f\"Calling vector_store.{search_method} with args: {search_args}\")\n docs = getattr(vector_store, search_method)(**search_args)\n except Exception as e:\n msg = f\"Error performing {search_method} in AstraDBVectorStore: {e}\"\n raise ValueError(msg) from e\n\n self.log(f\"Retrieved documents: {len(docs)}\")\n\n data = docs_to_data(docs)\n self.log(f\"Converted documents to data: {len(data)}\")\n self.status = data\n\n return data\n\n def get_retriever_kwargs(self):\n search_args = self._build_search_args()\n\n return {\n \"search_type\": self._map_search_type(),\n \"search_kwargs\": search_args,\n }\n" + "value": "from collections import defaultdict\nfrom dataclasses import asdict, dataclass, field\n\nfrom astrapy import AstraDBAdmin, DataAPIClient, Database\nfrom astrapy.info import CollectionDescriptor\nfrom langchain_astradb import AstraDBVectorStore, CollectionVectorServiceOptions\n\nfrom langflow.base.vectorstores.model import LCVectorStoreComponent, check_cached_vector_store\nfrom langflow.helpers import docs_to_data\nfrom langflow.inputs import FloatInput, NestedDictInput\nfrom langflow.io import (\n BoolInput,\n DropdownInput,\n HandleInput,\n IntInput,\n SecretStrInput,\n StrInput,\n)\nfrom langflow.schema import Data\nfrom langflow.utils.version import get_version_info\n\n\nclass AstraDBVectorStoreComponent(LCVectorStoreComponent):\n display_name: str = \"Astra DB\"\n description: str = \"Ingest and search documents in Astra DB\"\n documentation: str = \"https://docs.datastax.com/en/langflow/astra-components.html\"\n name = \"AstraDB\"\n icon: str = \"AstraDB\"\n\n _cached_vector_store: AstraDBVectorStore | None = None\n\n @dataclass\n class NewDatabaseInput:\n functionality: str = \"create\"\n fields: dict[str, dict] = field(\n default_factory=lambda: {\n \"data\": {\n \"node\": {\n \"name\": \"create_database\",\n \"description\": \"\",\n \"display_name\": \"Create new database\",\n \"field_order\": [\"new_database_name\", \"cloud_provider\", \"region\"],\n \"template\": {\n \"new_database_name\": StrInput(\n name=\"new_database_name\",\n display_name=\"Name\",\n info=\"Name of the new database to create in Astra DB.\",\n required=True,\n ),\n \"cloud_provider\": DropdownInput(\n name=\"cloud_provider\",\n display_name=\"Cloud provider\",\n info=\"Cloud provider for the new database.\",\n options=[\"Amazon Web Services\", \"Google Cloud Platform\", \"Microsoft Azure\"],\n required=True,\n real_time_refresh=True,\n ),\n \"region\": DropdownInput(\n name=\"region\",\n display_name=\"Region\",\n info=\"Region for the new database.\",\n options=[],\n required=True,\n ),\n },\n },\n }\n }\n )\n\n @dataclass\n class NewCollectionInput:\n functionality: str = \"create\"\n fields: dict[str, dict] = field(\n default_factory=lambda: {\n \"data\": {\n \"node\": {\n \"name\": \"create_collection\",\n \"description\": \"\",\n \"display_name\": \"Create new collection\",\n \"field_order\": [\n \"new_collection_name\",\n \"embedding_generation_provider\",\n \"embedding_generation_model\",\n \"dimension\",\n ],\n \"template\": {\n \"new_collection_name\": StrInput(\n name=\"new_collection_name\",\n display_name=\"Name\",\n info=\"Name of the new collection to create in Astra DB.\",\n required=True,\n ),\n \"embedding_generation_provider\": DropdownInput(\n name=\"embedding_generation_provider\",\n display_name=\"Embedding generation method\",\n info=\"Provider to use for generating embeddings.\",\n real_time_refresh=True,\n required=True,\n options=[\"Bring your own\", \"Nvidia\"],\n ),\n \"embedding_generation_model\": DropdownInput(\n name=\"embedding_generation_model\",\n display_name=\"Embedding model\",\n info=\"Model to use for generating embeddings.\",\n required=True,\n options=[],\n ),\n \"dimension\": IntInput(\n name=\"dimension\",\n display_name=\"Dimensions (Required only for `Bring your own`)\",\n info=\"Dimensions of the embeddings to generate.\",\n required=False,\n value=1024,\n ),\n },\n },\n }\n }\n )\n\n inputs = [\n SecretStrInput(\n name=\"token\",\n display_name=\"Astra DB Application Token\",\n info=\"Authentication token for accessing Astra DB.\",\n value=\"ASTRA_DB_APPLICATION_TOKEN\",\n required=True,\n real_time_refresh=True,\n input_types=[],\n ),\n StrInput(\n name=\"environment\",\n display_name=\"Environment\",\n info=\"The environment for the Astra DB API Endpoint.\",\n advanced=True,\n real_time_refresh=True,\n ),\n DropdownInput(\n name=\"database_name\",\n display_name=\"Database\",\n info=\"The Database name for the Astra DB instance.\",\n required=True,\n refresh_button=True,\n real_time_refresh=True,\n dialog_inputs=asdict(NewDatabaseInput()),\n combobox=True,\n ),\n StrInput(\n name=\"api_endpoint\",\n display_name=\"Astra DB API Endpoint\",\n info=\"The API Endpoint for the Astra DB instance. Supercedes database selection.\",\n advanced=True,\n ),\n DropdownInput(\n name=\"collection_name\",\n display_name=\"Collection\",\n info=\"The name of the collection within Astra DB where the vectors will be stored.\",\n required=True,\n refresh_button=True,\n real_time_refresh=True,\n dialog_inputs=asdict(NewCollectionInput()),\n combobox=True,\n advanced=True,\n ),\n StrInput(\n name=\"keyspace\",\n display_name=\"Keyspace\",\n info=\"Optional keyspace within Astra DB to use for the collection.\",\n advanced=True,\n ),\n DropdownInput(\n name=\"embedding_choice\",\n display_name=\"Embedding Model or Astra Vectorize\",\n info=\"Choose an embedding model or use Astra Vectorize.\",\n options=[\"Embedding Model\", \"Astra Vectorize\"],\n value=\"Embedding Model\",\n advanced=True,\n real_time_refresh=True,\n ),\n HandleInput(\n name=\"embedding_model\",\n display_name=\"Embedding Model\",\n input_types=[\"Embeddings\"],\n info=\"Specify the Embedding Model. Not required for Astra Vectorize collections.\",\n required=False,\n ),\n *LCVectorStoreComponent.inputs,\n IntInput(\n name=\"number_of_results\",\n display_name=\"Number of Search Results\",\n info=\"Number of search results to return.\",\n advanced=True,\n value=4,\n ),\n DropdownInput(\n name=\"search_type\",\n display_name=\"Search Type\",\n info=\"Search type to use\",\n options=[\"Similarity\", \"Similarity with score threshold\", \"MMR (Max Marginal Relevance)\"],\n value=\"Similarity\",\n advanced=True,\n ),\n FloatInput(\n name=\"search_score_threshold\",\n display_name=\"Search Score Threshold\",\n info=\"Minimum similarity score threshold for search results. \"\n \"(when using 'Similarity with score threshold')\",\n value=0,\n advanced=True,\n ),\n NestedDictInput(\n name=\"advanced_search_filter\",\n display_name=\"Search Metadata Filter\",\n info=\"Optional dictionary of filters to apply to the search query.\",\n advanced=True,\n ),\n BoolInput(\n name=\"autodetect_collection\",\n display_name=\"Autodetect Collection\",\n info=\"Boolean flag to determine whether to autodetect the collection.\",\n advanced=True,\n value=True,\n ),\n StrInput(\n name=\"content_field\",\n display_name=\"Content Field\",\n info=\"Field to use as the text content field for the vector store.\",\n advanced=True,\n ),\n StrInput(\n name=\"deletion_field\",\n display_name=\"Deletion Based On Field\",\n info=\"When this parameter is provided, documents in the target collection with \"\n \"metadata field values matching the input metadata field value will be deleted \"\n \"before new data is loaded.\",\n advanced=True,\n ),\n BoolInput(\n name=\"ignore_invalid_documents\",\n display_name=\"Ignore Invalid Documents\",\n info=\"Boolean flag to determine whether to ignore invalid documents at runtime.\",\n advanced=True,\n ),\n NestedDictInput(\n name=\"astradb_vectorstore_kwargs\",\n display_name=\"AstraDBVectorStore Parameters\",\n info=\"Optional dictionary of additional parameters for the AstraDBVectorStore.\",\n advanced=True,\n ),\n ]\n\n @classmethod\n def map_cloud_providers(cls):\n # TODO: Programmatically fetch the regions for each cloud provider\n return {\n \"Amazon Web Services\": {\n \"id\": \"aws\",\n \"regions\": [\"us-east-2\", \"ap-south-1\", \"eu-west-1\"],\n },\n \"Google Cloud Platform\": {\n \"id\": \"gcp\",\n \"regions\": [\"us-east1\"],\n },\n \"Microsoft Azure\": {\n \"id\": \"azure\",\n \"regions\": [\"westus3\"],\n },\n }\n\n @classmethod\n def get_vectorize_providers(cls, token: str, environment: str | None = None, api_endpoint: str | None = None):\n try:\n # Get the admin object\n admin = AstraDBAdmin(token=token, environment=environment)\n db_admin = admin.get_database_admin(api_endpoint=api_endpoint)\n\n # Get the list of embedding providers\n embedding_providers = db_admin.find_embedding_providers().as_dict()\n\n vectorize_providers_mapping = {}\n # Map the provider display name to the provider key and models\n for provider_key, provider_data in embedding_providers[\"embeddingProviders\"].items():\n # Get the provider display name and models\n display_name = provider_data[\"displayName\"]\n models = [model[\"name\"] for model in provider_data[\"models\"]]\n\n # Build our mapping\n vectorize_providers_mapping[display_name] = [provider_key, models]\n\n # Sort the resulting dictionary\n return defaultdict(list, dict(sorted(vectorize_providers_mapping.items())))\n except Exception as e:\n msg = f\"Error fetching vectorize providers: {e}\"\n raise ValueError(msg) from e\n\n @classmethod\n async def create_database_api(\n cls,\n new_database_name: str,\n cloud_provider: str,\n region: str,\n token: str,\n environment: str | None = None,\n keyspace: str | None = None,\n ):\n client = DataAPIClient(token=token, environment=environment)\n\n # Get the admin object\n admin_client = client.get_admin(token=token)\n\n # Call the create database function\n return await admin_client.async_create_database(\n name=new_database_name,\n cloud_provider=cls.map_cloud_providers()[cloud_provider][\"id\"],\n region=region,\n keyspace=keyspace,\n wait_until_active=False,\n )\n\n @classmethod\n async def create_collection_api(\n cls,\n new_collection_name: str,\n token: str,\n api_endpoint: str,\n environment: str | None = None,\n keyspace: str | None = None,\n dimension: int | None = None,\n embedding_generation_provider: str | None = None,\n embedding_generation_model: str | None = None,\n ):\n # Create the data API client\n client = DataAPIClient(token=token, environment=environment)\n\n # Get the database object\n database = client.get_async_database(api_endpoint=api_endpoint, token=token)\n\n # Build vectorize options, if needed\n vectorize_options = None\n if not dimension:\n vectorize_options = CollectionVectorServiceOptions(\n provider=cls.get_vectorize_providers(\n token=token, environment=environment, api_endpoint=api_endpoint\n ).get(embedding_generation_provider, [None, []])[0],\n model_name=embedding_generation_model,\n )\n\n # Create the collection\n return await database.create_collection(\n name=new_collection_name,\n keyspace=keyspace,\n dimension=dimension,\n service=vectorize_options,\n )\n\n @classmethod\n def get_database_list_static(cls, token: str, environment: str | None = None):\n client = DataAPIClient(token=token, environment=environment)\n\n # Get the admin object\n admin_client = client.get_admin(token=token)\n\n # Get the list of databases\n db_list = list(admin_client.list_databases())\n\n # Set the environment properly\n env_string = \"\"\n if environment and environment != \"prod\":\n env_string = f\"-{environment}\"\n\n # Generate the api endpoint for each database\n db_info_dict = {}\n for db in db_list:\n try:\n # Get the API endpoint for the database\n api_endpoint = f\"https://{db.info.id}-{db.info.region}.apps.astra{env_string}.datastax.com\"\n\n # Get the number of collections\n try:\n num_collections = len(\n list(\n client.get_database(\n api_endpoint=api_endpoint, token=token, keyspace=db.info.keyspace\n ).list_collection_names(keyspace=db.info.keyspace)\n )\n )\n except Exception: # noqa: BLE001\n num_collections = 0\n if db.status != \"PENDING\":\n continue\n\n # Add the database to the dictionary\n db_info_dict[db.info.name] = {\n \"api_endpoint\": api_endpoint,\n \"collections\": num_collections,\n \"status\": db.status if db.status != \"ACTIVE\" else None,\n }\n except Exception: # noqa: BLE001, S110\n pass\n\n return db_info_dict\n\n def get_database_list(self):\n return self.get_database_list_static(token=self.token, environment=self.environment)\n\n @classmethod\n def get_api_endpoint_static(\n cls,\n token: str,\n environment: str | None = None,\n api_endpoint: str | None = None,\n database_name: str | None = None,\n ):\n # If the api_endpoint is set, return it\n if api_endpoint:\n return api_endpoint\n\n # Check if the database_name is like a url\n if database_name and database_name.startswith(\"https://\"):\n return database_name\n\n # If the database is not set, nothing we can do.\n if not database_name:\n return None\n\n # Grab the database object\n db = cls.get_database_list_static(token=token, environment=environment).get(database_name)\n if not db:\n return None\n\n # Otherwise, get the URL from the database list\n return db.get(\"api_endpoint\")\n\n def get_api_endpoint(self):\n return self.get_api_endpoint_static(\n token=self.token,\n environment=self.environment,\n api_endpoint=self.api_endpoint,\n database_name=self.database_name,\n )\n\n def get_keyspace(self):\n keyspace = self.keyspace\n\n if keyspace:\n return keyspace.strip()\n\n return None\n\n def get_database_object(self, api_endpoint: str | None = None):\n try:\n client = DataAPIClient(token=self.token, environment=self.environment)\n\n return client.get_database(\n api_endpoint=api_endpoint or self.get_api_endpoint(),\n token=self.token,\n keyspace=self.get_keyspace(),\n )\n except Exception as e:\n msg = f\"Error fetching database object: {e}\"\n raise ValueError(msg) from e\n\n def collection_data(self, collection_name: str, database: Database | None = None):\n try:\n if not database:\n client = DataAPIClient(token=self.token, environment=self.environment)\n\n database = client.get_database(\n api_endpoint=self.get_api_endpoint(),\n token=self.token,\n keyspace=self.get_keyspace(),\n )\n\n collection = database.get_collection(collection_name, keyspace=self.get_keyspace())\n\n return collection.estimated_document_count()\n except Exception as e: # noqa: BLE001\n self.log(f\"Error checking collection data: {e}\")\n\n return None\n\n def _initialize_database_options(self):\n try:\n return [\n {\n \"name\": name,\n \"status\": info[\"status\"],\n \"collections\": info[\"collections\"],\n \"api_endpoint\": info[\"api_endpoint\"],\n \"icon\": \"data\",\n }\n for name, info in self.get_database_list().items()\n ]\n except Exception as e:\n msg = f\"Error fetching database options: {e}\"\n raise ValueError(msg) from e\n\n @classmethod\n def get_provider_icon(cls, collection: CollectionDescriptor | None = None, provider_name: str | None = None) -> str:\n # Get the provider name from the collection\n provider_name = provider_name or (\n collection.options.vector.service.provider\n if collection and collection.options and collection.options.vector and collection.options.vector.service\n else None\n )\n\n # If there is no provider, use the vector store icon\n if not provider_name or provider_name == \"bring your own\":\n return \"vectorstores\"\n\n # Special case for certain models\n # TODO: Add more icons\n if provider_name == \"nvidia\":\n return \"NVIDIA\"\n if provider_name == \"openai\":\n return \"OpenAI\"\n\n # Title case on the provider for the icon if no special case\n return provider_name.title()\n\n def _initialize_collection_options(self, api_endpoint: str | None = None):\n # Nothing to generate if we don't have an API endpoint yet\n api_endpoint = api_endpoint or self.get_api_endpoint()\n if not api_endpoint:\n return []\n\n # Retrieve the database object\n database = self.get_database_object(api_endpoint=api_endpoint)\n\n # Get the list of collections\n collection_list = list(database.list_collections(keyspace=self.get_keyspace()))\n\n # Return the list of collections and metadata associated\n return [\n {\n \"name\": col.name,\n \"records\": self.collection_data(collection_name=col.name, database=database),\n \"provider\": (\n col.options.vector.service.provider if col.options.vector and col.options.vector.service else None\n ),\n \"icon\": self.get_provider_icon(collection=col),\n \"model\": (\n col.options.vector.service.model_name if col.options.vector and col.options.vector.service else None\n ),\n }\n for col in collection_list\n ]\n\n def reset_provider_options(self, build_config: dict):\n # Get the list of vectorize providers\n vectorize_providers = self.get_vectorize_providers(\n token=self.token,\n environment=self.environment,\n api_endpoint=build_config[\"api_endpoint\"][\"value\"],\n )\n\n # Append a special case for Bring your own\n vectorize_providers[\"Bring your own\"] = [None, [\"Bring your own\"]]\n\n # If the collection is set, allow user to see embedding options\n build_config[\"collection_name\"][\"dialog_inputs\"][\"fields\"][\"data\"][\"node\"][\"template\"][\n \"embedding_generation_provider\"\n ][\"options\"] = [\"Bring your own\", \"Nvidia\", *[key for key in vectorize_providers if key != \"Nvidia\"]]\n\n # For all not Bring your own or Nvidia providers, add metadata saying configure in Astra DB Portal\n provider_options = build_config[\"collection_name\"][\"dialog_inputs\"][\"fields\"][\"data\"][\"node\"][\"template\"][\n \"embedding_generation_provider\"\n ][\"options\"]\n\n # Go over each possible provider and add metadata to configure in Astra DB Portal\n for provider in provider_options:\n # Skip Bring your own and Nvidia, automatically configured\n if provider in {\"Bring your own\", \"Nvidia\"}:\n build_config[\"collection_name\"][\"dialog_inputs\"][\"fields\"][\"data\"][\"node\"][\"template\"][\n \"embedding_generation_provider\"\n ][\"options_metadata\"].append({\"icon\": self.get_provider_icon(provider_name=provider.lower())})\n continue\n\n # Add metadata to configure in Astra DB Portal\n build_config[\"collection_name\"][\"dialog_inputs\"][\"fields\"][\"data\"][\"node\"][\"template\"][\n \"embedding_generation_provider\"\n ][\"options_metadata\"].append({\" \": \"Configure in Astra DB Portal\"})\n\n # And allow the user to see the models based on a selected provider\n embedding_provider = build_config[\"collection_name\"][\"dialog_inputs\"][\"fields\"][\"data\"][\"node\"][\"template\"][\n \"embedding_generation_provider\"\n ][\"value\"]\n\n # Set the options for the embedding model based on the provider\n build_config[\"collection_name\"][\"dialog_inputs\"][\"fields\"][\"data\"][\"node\"][\"template\"][\n \"embedding_generation_model\"\n ][\"options\"] = vectorize_providers.get(embedding_provider, [[], []])[1]\n\n return build_config\n\n def reset_collection_list(self, build_config: dict):\n # Get the list of options we have based on the token provided\n collection_options = self._initialize_collection_options(api_endpoint=build_config[\"api_endpoint\"][\"value\"])\n\n # If we retrieved options based on the token, show the dropdown\n build_config[\"collection_name\"][\"options\"] = [col[\"name\"] for col in collection_options]\n build_config[\"collection_name\"][\"options_metadata\"] = [\n {k: v for k, v in col.items() if k != \"name\"} for col in collection_options\n ]\n\n # Reset the selected collection\n if build_config[\"collection_name\"][\"value\"] not in build_config[\"collection_name\"][\"options\"]:\n build_config[\"collection_name\"][\"value\"] = \"\"\n\n # If we have a database, collection name should not be advanced\n build_config[\"collection_name\"][\"advanced\"] = not build_config[\"database_name\"][\"value\"]\n\n return build_config\n\n def reset_database_list(self, build_config: dict):\n # Get the list of options we have based on the token provided\n database_options = self._initialize_database_options()\n\n # If we retrieved options based on the token, show the dropdown\n build_config[\"database_name\"][\"options\"] = [db[\"name\"] for db in database_options]\n build_config[\"database_name\"][\"options_metadata\"] = [\n {k: v for k, v in db.items() if k != \"name\"} for db in database_options\n ]\n\n # Reset the selected database\n if build_config[\"database_name\"][\"value\"] not in build_config[\"database_name\"][\"options\"]:\n build_config[\"database_name\"][\"value\"] = \"\"\n build_config[\"api_endpoint\"][\"value\"] = \"\"\n build_config[\"collection_name\"][\"advanced\"] = True\n\n # If we have a token, database name should not be advanced\n build_config[\"database_name\"][\"advanced\"] = not build_config[\"token\"][\"value\"]\n\n return build_config\n\n def reset_build_config(self, build_config: dict):\n # Reset the list of databases we have based on the token provided\n build_config[\"database_name\"][\"options\"] = []\n build_config[\"database_name\"][\"options_metadata\"] = []\n build_config[\"database_name\"][\"value\"] = \"\"\n build_config[\"database_name\"][\"advanced\"] = True\n build_config[\"api_endpoint\"][\"value\"] = \"\"\n\n # Reset the list of collections and metadata associated\n build_config[\"collection_name\"][\"options\"] = []\n build_config[\"collection_name\"][\"options_metadata\"] = []\n build_config[\"collection_name\"][\"value\"] = \"\"\n build_config[\"collection_name\"][\"advanced\"] = True\n\n return build_config\n\n async def update_build_config(self, build_config: dict, field_value: str, field_name: str | None = None):\n # Callback for database creation\n if field_name == \"database_name\" and isinstance(field_value, dict) and \"new_database_name\" in field_value:\n try:\n await self.create_database_api(\n new_database_name=field_value[\"new_database_name\"],\n token=self.token,\n keyspace=self.get_keyspace(),\n environment=self.environment,\n cloud_provider=field_value[\"cloud_provider\"],\n region=field_value[\"region\"],\n )\n except Exception as e:\n msg = f\"Error creating database: {e}\"\n raise ValueError(msg) from e\n\n # Add the new database to the list of options\n build_config[\"database_name\"][\"options\"] += [field_value[\"new_database_name\"]]\n build_config[\"database_name\"][\"options_metadata\"] += [{\"status\": \"PENDING\"}]\n\n return self.reset_collection_list(build_config)\n\n # This is the callback required to update the list of regions for a cloud provider\n if field_name == \"database_name\" and isinstance(field_value, dict) and \"new_database_name\" not in field_value:\n cloud_provider = field_value[\"cloud_provider\"]\n build_config[\"database_name\"][\"dialog_inputs\"][\"fields\"][\"data\"][\"node\"][\"template\"][\"region\"][\n \"options\"\n ] = self.map_cloud_providers()[cloud_provider][\"regions\"]\n\n return build_config\n\n # Callback for the creation of collections\n if field_name == \"collection_name\" and isinstance(field_value, dict) and \"new_collection_name\" in field_value:\n try:\n # Get the dimension if its a BYO provider\n dimension = (\n field_value[\"dimension\"]\n if field_value[\"embedding_generation_provider\"] == \"Bring your own\"\n else None\n )\n\n # Create the collection\n await self.create_collection_api(\n new_collection_name=field_value[\"new_collection_name\"],\n token=self.token,\n api_endpoint=build_config[\"api_endpoint\"][\"value\"],\n environment=self.environment,\n keyspace=self.get_keyspace(),\n dimension=dimension,\n embedding_generation_provider=field_value[\"embedding_generation_provider\"],\n embedding_generation_model=field_value[\"embedding_generation_model\"],\n )\n except Exception as e:\n msg = f\"Error creating collection: {e}\"\n raise ValueError(msg) from e\n\n # Add the new collection to the list of options\n build_config[\"collection_name\"][\"value\"] = field_value[\"new_collection_name\"]\n build_config[\"collection_name\"][\"options\"].append(field_value[\"new_collection_name\"])\n\n # Get the provider and model for the new collection\n generation_provider = field_value[\"embedding_generation_provider\"]\n provider = generation_provider if generation_provider != \"Bring your own\" else None\n generation_model = field_value[\"embedding_generation_model\"]\n model = generation_model if generation_model and generation_model != \"Bring your own\" else None\n\n # Set the embedding choice\n build_config[\"embedding_choice\"][\"value\"] = \"Astra Vectorize\" if provider else \"Embedding Model\"\n build_config[\"embedding_model\"][\"advanced\"] = bool(provider)\n\n # Add the new collection to the list of options\n icon = \"NVIDIA\" if provider == \"Nvidia\" else \"vectorstores\"\n build_config[\"collection_name\"][\"options_metadata\"] += [\n {\"records\": 0, \"provider\": provider, \"icon\": icon, \"model\": model}\n ]\n\n return build_config\n\n # Callback to update the model list based on the embedding provider\n if (\n field_name == \"collection_name\"\n and isinstance(field_value, dict)\n and \"new_collection_name\" not in field_value\n ):\n return self.reset_provider_options(build_config)\n\n # When the component first executes, this is the update refresh call\n first_run = field_name == \"collection_name\" and not field_value and not build_config[\"database_name\"][\"options\"]\n\n # If the token has not been provided, simply return the empty build config\n if not self.token:\n return self.reset_build_config(build_config)\n\n # If this is the first execution of the component, reset and build database list\n if first_run or field_name in {\"token\", \"environment\"}:\n return self.reset_database_list(build_config)\n\n # Refresh the collection name options\n if field_name == \"database_name\" and not isinstance(field_value, dict):\n # If missing, refresh the database options\n if field_value not in build_config[\"database_name\"][\"options\"]:\n build_config = await self.update_build_config(build_config, field_value=self.token, field_name=\"token\")\n build_config[\"database_name\"][\"value\"] = \"\"\n else:\n # Find the position of the selected database to align with metadata\n index_of_name = build_config[\"database_name\"][\"options\"].index(field_value)\n\n # Initializing database condition\n pending = build_config[\"database_name\"][\"options_metadata\"][index_of_name][\"status\"] == \"PENDING\"\n if pending:\n return self.update_build_config(build_config, field_value=self.token, field_name=\"token\")\n\n # Set the API endpoint based on the selected database\n build_config[\"api_endpoint\"][\"value\"] = build_config[\"database_name\"][\"options_metadata\"][\n index_of_name\n ][\"api_endpoint\"]\n\n # Reset the provider options\n build_config = self.reset_provider_options(build_config)\n\n # Reset the list of collections we have based on the token provided\n return self.reset_collection_list(build_config)\n\n # Hide embedding model option if opriona_metadata provider is not null\n if field_name == \"collection_name\" and not isinstance(field_value, dict):\n # Assume we will be autodetecting the collection:\n build_config[\"autodetect_collection\"][\"value\"] = True\n\n # Reload the collection list\n build_config = self.reset_collection_list(build_config)\n\n # Set the options for collection name to be the field value if its a new collection\n if field_value and field_value not in build_config[\"collection_name\"][\"options\"]:\n # Add the new collection to the list of options\n build_config[\"collection_name\"][\"options\"].append(field_value)\n build_config[\"collection_name\"][\"options_metadata\"].append(\n {\n \"records\": 0,\n \"provider\": None,\n \"icon\": \"\",\n \"model\": None,\n }\n )\n\n # Ensure that autodetect collection is set to False, since its a new collection\n build_config[\"autodetect_collection\"][\"value\"] = False\n\n # If nothing is selected, can't detect provider - return\n if not field_value:\n return build_config\n\n # Find the position of the selected collection to align with metadata\n index_of_name = build_config[\"collection_name\"][\"options\"].index(field_value)\n value_of_provider = build_config[\"collection_name\"][\"options_metadata\"][index_of_name][\"provider\"]\n\n # If we were able to determine the Vectorize provider, set it accordingly\n if value_of_provider:\n build_config[\"embedding_model\"][\"advanced\"] = True\n build_config[\"embedding_choice\"][\"value\"] = \"Astra Vectorize\"\n else:\n build_config[\"embedding_model\"][\"advanced\"] = False\n build_config[\"embedding_choice\"][\"value\"] = \"Embedding Model\"\n\n return build_config\n\n return build_config\n\n @check_cached_vector_store\n def build_vector_store(self):\n try:\n from langchain_astradb import AstraDBVectorStore\n except ImportError as e:\n msg = (\n \"Could not import langchain Astra DB integration package. \"\n \"Please install it with `pip install langchain-astradb`.\"\n )\n raise ImportError(msg) from e\n\n # Get the embedding model and additional params\n embedding_params = (\n {\"embedding\": self.embedding_model}\n if self.embedding_model and self.embedding_choice == \"Embedding Model\"\n else {}\n )\n\n # Get the additional parameters\n additional_params = self.astradb_vectorstore_kwargs or {}\n\n # Get Langflow version and platform information\n __version__ = get_version_info()[\"version\"]\n langflow_prefix = \"\"\n # if os.getenv(\"AWS_EXECUTION_ENV\") == \"AWS_ECS_FARGATE\": # TODO: More precise way of detecting\n # langflow_prefix = \"ds-\"\n\n # Get the database object\n database = self.get_database_object()\n autodetect = self.collection_name in database.list_collection_names() and self.autodetect_collection\n\n # Bundle up the auto-detect parameters\n autodetect_params = {\n \"autodetect_collection\": autodetect,\n \"content_field\": (\n self.content_field\n if self.content_field and embedding_params\n else (\n \"page_content\"\n if embedding_params\n and self.collection_data(collection_name=self.collection_name, database=database) == 0\n else None\n )\n ),\n \"ignore_invalid_documents\": self.ignore_invalid_documents,\n }\n\n # Attempt to build the Vector Store object\n try:\n vector_store = AstraDBVectorStore(\n # Astra DB Authentication Parameters\n token=self.token,\n api_endpoint=database.api_endpoint,\n namespace=database.keyspace,\n collection_name=self.collection_name,\n environment=self.environment,\n # Astra DB Usage Tracking Parameters\n ext_callers=[(f\"{langflow_prefix}langflow\", __version__)],\n # Astra DB Vector Store Parameters\n **autodetect_params,\n **embedding_params,\n **additional_params,\n )\n except Exception as e:\n msg = f\"Error initializing AstraDBVectorStore: {e}\"\n raise ValueError(msg) from e\n\n # Add documents to the vector store\n self._add_documents_to_vector_store(vector_store)\n\n return vector_store\n\n def _add_documents_to_vector_store(self, vector_store) -> None:\n documents = []\n for _input in self.ingest_data or []:\n if isinstance(_input, Data):\n documents.append(_input.to_lc_document())\n else:\n msg = \"Vector Store Inputs must be Data objects.\"\n raise TypeError(msg)\n\n if documents and self.deletion_field:\n self.log(f\"Deleting documents where {self.deletion_field}\")\n try:\n database = self.get_database_object()\n collection = database.get_collection(self.collection_name, keyspace=database.keyspace)\n delete_values = list({doc.metadata[self.deletion_field] for doc in documents})\n self.log(f\"Deleting documents where {self.deletion_field} matches {delete_values}.\")\n collection.delete_many({f\"metadata.{self.deletion_field}\": {\"$in\": delete_values}})\n except Exception as e:\n msg = f\"Error deleting documents from AstraDBVectorStore based on '{self.deletion_field}': {e}\"\n raise ValueError(msg) from e\n\n if documents:\n self.log(f\"Adding {len(documents)} documents to the Vector Store.\")\n try:\n vector_store.add_documents(documents)\n except Exception as e:\n msg = f\"Error adding documents to AstraDBVectorStore: {e}\"\n raise ValueError(msg) from e\n else:\n self.log(\"No documents to add to the Vector Store.\")\n\n def _map_search_type(self) -> str:\n search_type_mapping = {\n \"Similarity with score threshold\": \"similarity_score_threshold\",\n \"MMR (Max Marginal Relevance)\": \"mmr\",\n }\n\n return search_type_mapping.get(self.search_type, \"similarity\")\n\n def _build_search_args(self):\n query = self.search_query if isinstance(self.search_query, str) and self.search_query.strip() else None\n\n if query:\n args = {\n \"query\": query,\n \"search_type\": self._map_search_type(),\n \"k\": self.number_of_results,\n \"score_threshold\": self.search_score_threshold,\n }\n elif self.advanced_search_filter:\n args = {\n \"n\": self.number_of_results,\n }\n else:\n return {}\n\n filter_arg = self.advanced_search_filter or {}\n if filter_arg:\n args[\"filter\"] = filter_arg\n\n return args\n\n def search_documents(self, vector_store=None) -> list[Data]:\n vector_store = vector_store or self.build_vector_store()\n\n self.log(f\"Search input: {self.search_query}\")\n self.log(f\"Search type: {self.search_type}\")\n self.log(f\"Number of results: {self.number_of_results}\")\n\n try:\n search_args = self._build_search_args()\n except Exception as e:\n msg = f\"Error in AstraDBVectorStore._build_search_args: {e}\"\n raise ValueError(msg) from e\n\n if not search_args:\n self.log(\"No search input or filters provided. Skipping search.\")\n return []\n\n docs = []\n search_method = \"search\" if \"query\" in search_args else \"metadata_search\"\n\n try:\n self.log(f\"Calling vector_store.{search_method} with args: {search_args}\")\n docs = getattr(vector_store, search_method)(**search_args)\n except Exception as e:\n msg = f\"Error performing {search_method} in AstraDBVectorStore: {e}\"\n raise ValueError(msg) from e\n\n self.log(f\"Retrieved documents: {len(docs)}\")\n\n data = docs_to_data(docs)\n self.log(f\"Converted documents to data: {len(data)}\")\n self.status = data\n\n return data\n\n def get_retriever_kwargs(self):\n search_args = self._build_search_args()\n\n return {\n \"search_type\": self._map_search_type(),\n \"search_kwargs\": search_args,\n }\n" }, "collection_name": { "_input_type": "DropdownInput", @@ -3405,16 +3405,35 @@ ], "name": "create_collection", "template": { - "new_collection_name": { - "_input_type": "StrInput", + "dimension": { + "_input_type": "IntInput", "advanced": false, - "display_name": "Name", + "display_name": "Dimensions (Required only for `Bring your own`)", "dynamic": false, - "info": "Name of the new collection to create in Astra DB.", + "info": "Dimensions of the embeddings to generate.", "list": false, "list_add_label": "Add More", - "load_from_db": false, - "name": "new_collection_name", + "name": "dimension", + "placeholder": "", + "required": false, + "show": true, + "title_case": false, + "tool_mode": false, + "trace_as_metadata": true, + "type": "int", + "value": "" + }, + "embedding_generation_model": { + "_input_type": "DropdownInput", + "advanced": false, + "combobox": false, + "dialog_inputs": {}, + "display_name": "Embedding model", + "dynamic": false, + "info": "Model to use for generating embeddings.", + "name": "embedding_generation_model", + "options": [], + "options_metadata": [], "placeholder": "", "required": true, "show": true, @@ -3448,17 +3467,16 @@ "type": "str", "value": "" }, - "embedding_generation_model": { - "_input_type": "DropdownInput", + "new_collection_name": { + "_input_type": "StrInput", "advanced": false, - "combobox": false, - "dialog_inputs": {}, - "display_name": "Embedding model", + "display_name": "Name", "dynamic": false, - "info": "Model to use for generating embeddings.", - "name": "embedding_generation_model", - "options": [], - "options_metadata": [], + "info": "Name of the new collection to create in Astra DB.", + "list": false, + "list_add_label": "Add More", + "load_from_db": false, + "name": "new_collection_name", "placeholder": "", "required": true, "show": true, @@ -3467,24 +3485,6 @@ "trace_as_metadata": true, "type": "str", "value": "" - }, - "dimension": { - "_input_type": "IntInput", - "advanced": false, - "display_name": "Dimensions (Required only for `Bring your own`)", - "dynamic": false, - "info": "Dimensions of the embeddings to generate.", - "list": false, - "list_add_label": "Add More", - "name": "dimension", - "placeholder": "", - "required": false, - "show": true, - "title_case": false, - "tool_mode": false, - "trace_as_metadata": true, - "type": "int", - "value": "" } } } @@ -3545,25 +3545,6 @@ ], "name": "create_database", "template": { - "new_database_name": { - "_input_type": "StrInput", - "advanced": false, - "display_name": "Name", - "dynamic": false, - "info": "Name of the new database to create in Astra DB.", - "list": false, - "list_add_label": "Add More", - "load_from_db": false, - "name": "new_database_name", - "placeholder": "", - "required": true, - "show": true, - "title_case": false, - "tool_mode": false, - "trace_as_metadata": true, - "type": "str", - "value": "" - }, "cloud_provider": { "_input_type": "DropdownInput", "advanced": false, @@ -3589,6 +3570,25 @@ "type": "str", "value": "" }, + "new_database_name": { + "_input_type": "StrInput", + "advanced": false, + "display_name": "Name", + "dynamic": false, + "info": "Name of the new database to create in Astra DB.", + "list": false, + "list_add_label": "Add More", + "load_from_db": false, + "name": "new_database_name", + "placeholder": "", + "required": true, + "show": true, + "title_case": false, + "tool_mode": false, + "trace_as_metadata": true, + "type": "str", + "value": "" + }, "region": { "_input_type": "DropdownInput", "advanced": false, @@ -4083,7 +4083,7 @@ "show": true, "title_case": false, "type": "code", - "value": "from collections import defaultdict\nfrom dataclasses import asdict, dataclass, field\n\nfrom astrapy import AstraDBAdmin, DataAPIClient, Database\nfrom astrapy.info import CollectionDescriptor\nfrom langchain_astradb import AstraDBVectorStore, CollectionVectorServiceOptions\n\nfrom langflow.base.vectorstores.model import LCVectorStoreComponent, check_cached_vector_store\nfrom langflow.helpers import docs_to_data\nfrom langflow.inputs import FloatInput, NestedDictInput\nfrom langflow.io import (\n BoolInput,\n DropdownInput,\n HandleInput,\n IntInput,\n SecretStrInput,\n StrInput,\n)\nfrom langflow.schema import Data\nfrom langflow.utils.version import get_version_info\n\n\nclass AstraDBVectorStoreComponent(LCVectorStoreComponent):\n display_name: str = \"Astra DB\"\n description: str = \"Ingest and search documents in Astra DB\"\n documentation: str = \"https://docs.datastax.com/en/langflow/astra-components.html\"\n name = \"AstraDB\"\n icon: str = \"AstraDB\"\n\n _cached_vector_store: AstraDBVectorStore | None = None\n\n @dataclass\n class NewDatabaseInput:\n functionality: str = \"create\"\n fields: dict[str, dict] = field(\n default_factory=lambda: {\n \"data\": {\n \"node\": {\n \"name\": \"create_database\",\n \"description\": \"\",\n \"display_name\": \"Create new database\",\n \"field_order\": [\"new_database_name\", \"cloud_provider\", \"region\"],\n \"template\": {\n \"new_database_name\": StrInput(\n name=\"new_database_name\",\n display_name=\"Name\",\n info=\"Name of the new database to create in Astra DB.\",\n required=True,\n ),\n \"cloud_provider\": DropdownInput(\n name=\"cloud_provider\",\n display_name=\"Cloud provider\",\n info=\"Cloud provider for the new database.\",\n options=[\"Amazon Web Services\", \"Google Cloud Platform\", \"Microsoft Azure\"],\n required=True,\n real_time_refresh=True,\n ),\n \"region\": DropdownInput(\n name=\"region\",\n display_name=\"Region\",\n info=\"Region for the new database.\",\n options=[],\n required=True,\n ),\n },\n },\n }\n }\n )\n\n @dataclass\n class NewCollectionInput:\n functionality: str = \"create\"\n fields: dict[str, dict] = field(\n default_factory=lambda: {\n \"data\": {\n \"node\": {\n \"name\": \"create_collection\",\n \"description\": \"\",\n \"display_name\": \"Create new collection\",\n \"field_order\": [\n \"new_collection_name\",\n \"embedding_generation_provider\",\n \"embedding_generation_model\",\n \"dimension\",\n ],\n \"template\": {\n \"new_collection_name\": StrInput(\n name=\"new_collection_name\",\n display_name=\"Name\",\n info=\"Name of the new collection to create in Astra DB.\",\n required=True,\n ),\n \"embedding_generation_provider\": DropdownInput(\n name=\"embedding_generation_provider\",\n display_name=\"Embedding generation method\",\n info=\"Provider to use for generating embeddings.\",\n real_time_refresh=True,\n required=True,\n options=[\"Bring your own\", \"Nvidia\"],\n ),\n \"embedding_generation_model\": DropdownInput(\n name=\"embedding_generation_model\",\n display_name=\"Embedding model\",\n info=\"Model to use for generating embeddings.\",\n required=True,\n options=[],\n ),\n \"dimension\": IntInput(\n name=\"dimension\",\n display_name=\"Dimensions (Required only for `Bring your own`)\",\n info=\"Dimensions of the embeddings to generate.\",\n required=False,\n value=1024,\n ),\n },\n },\n }\n }\n )\n\n inputs = [\n SecretStrInput(\n name=\"token\",\n display_name=\"Astra DB Application Token\",\n info=\"Authentication token for accessing Astra DB.\",\n value=\"ASTRA_DB_APPLICATION_TOKEN\",\n required=True,\n real_time_refresh=True,\n input_types=[],\n ),\n StrInput(\n name=\"environment\",\n display_name=\"Environment\",\n info=\"The environment for the Astra DB API Endpoint.\",\n advanced=True,\n real_time_refresh=True,\n ),\n DropdownInput(\n name=\"database_name\",\n display_name=\"Database\",\n info=\"The Database name for the Astra DB instance.\",\n required=True,\n refresh_button=True,\n real_time_refresh=True,\n dialog_inputs=asdict(NewDatabaseInput()),\n combobox=True,\n ),\n StrInput(\n name=\"api_endpoint\",\n display_name=\"Astra DB API Endpoint\",\n info=\"The API Endpoint for the Astra DB instance. Supercedes database selection.\",\n advanced=True,\n ),\n DropdownInput(\n name=\"collection_name\",\n display_name=\"Collection\",\n info=\"The name of the collection within Astra DB where the vectors will be stored.\",\n required=True,\n refresh_button=True,\n real_time_refresh=True,\n dialog_inputs=asdict(NewCollectionInput()),\n combobox=True,\n advanced=True,\n ),\n StrInput(\n name=\"keyspace\",\n display_name=\"Keyspace\",\n info=\"Optional keyspace within Astra DB to use for the collection.\",\n advanced=True,\n ),\n DropdownInput(\n name=\"embedding_choice\",\n display_name=\"Embedding Model or Astra Vectorize\",\n info=\"Choose an embedding model or use Astra Vectorize.\",\n options=[\"Embedding Model\", \"Astra Vectorize\"],\n value=\"Embedding Model\",\n advanced=True,\n real_time_refresh=True,\n ),\n HandleInput(\n name=\"embedding_model\",\n display_name=\"Embedding Model\",\n input_types=[\"Embeddings\"],\n info=\"Specify the Embedding Model. Not required for Astra Vectorize collections.\",\n required=False,\n ),\n *LCVectorStoreComponent.inputs,\n IntInput(\n name=\"number_of_results\",\n display_name=\"Number of Search Results\",\n info=\"Number of search results to return.\",\n advanced=True,\n value=4,\n ),\n DropdownInput(\n name=\"search_type\",\n display_name=\"Search Type\",\n info=\"Search type to use\",\n options=[\"Similarity\", \"Similarity with score threshold\", \"MMR (Max Marginal Relevance)\"],\n value=\"Similarity\",\n advanced=True,\n ),\n FloatInput(\n name=\"search_score_threshold\",\n display_name=\"Search Score Threshold\",\n info=\"Minimum similarity score threshold for search results. \"\n \"(when using 'Similarity with score threshold')\",\n value=0,\n advanced=True,\n ),\n NestedDictInput(\n name=\"advanced_search_filter\",\n display_name=\"Search Metadata Filter\",\n info=\"Optional dictionary of filters to apply to the search query.\",\n advanced=True,\n ),\n BoolInput(\n name=\"autodetect_collection\",\n display_name=\"Autodetect Collection\",\n info=\"Boolean flag to determine whether to autodetect the collection.\",\n advanced=True,\n value=True,\n ),\n StrInput(\n name=\"content_field\",\n display_name=\"Content Field\",\n info=\"Field to use as the text content field for the vector store.\",\n advanced=True,\n ),\n StrInput(\n name=\"deletion_field\",\n display_name=\"Deletion Based On Field\",\n info=\"When this parameter is provided, documents in the target collection with \"\n \"metadata field values matching the input metadata field value will be deleted \"\n \"before new data is loaded.\",\n advanced=True,\n ),\n BoolInput(\n name=\"ignore_invalid_documents\",\n display_name=\"Ignore Invalid Documents\",\n info=\"Boolean flag to determine whether to ignore invalid documents at runtime.\",\n advanced=True,\n ),\n NestedDictInput(\n name=\"astradb_vectorstore_kwargs\",\n display_name=\"AstraDBVectorStore Parameters\",\n info=\"Optional dictionary of additional parameters for the AstraDBVectorStore.\",\n advanced=True,\n ),\n ]\n\n @classmethod\n def map_cloud_providers(cls):\n # TODO: Programmatically fetch the regions for each cloud provider\n return {\n \"Amazon Web Services\": {\n \"id\": \"aws\",\n \"regions\": [\"us-east-2\", \"ap-south-1\", \"eu-west-1\"],\n },\n \"Google Cloud Platform\": {\n \"id\": \"gcp\",\n \"regions\": [\"us-east1\"],\n },\n \"Microsoft Azure\": {\n \"id\": \"azure\",\n \"regions\": [\"westus3\"],\n },\n }\n\n @classmethod\n def get_vectorize_providers(cls, token: str, environment: str | None = None, api_endpoint: str | None = None):\n try:\n # Get the admin object\n admin = AstraDBAdmin(token=token, environment=environment)\n db_admin = admin.get_database_admin(api_endpoint=api_endpoint)\n\n # Get the list of embedding providers\n embedding_providers = db_admin.find_embedding_providers().as_dict()\n\n vectorize_providers_mapping = {}\n # Map the provider display name to the provider key and models\n for provider_key, provider_data in embedding_providers[\"embeddingProviders\"].items():\n # Get the provider display name and models\n display_name = provider_data[\"displayName\"]\n models = [model[\"name\"] for model in provider_data[\"models\"]]\n\n # Build our mapping\n vectorize_providers_mapping[display_name] = [provider_key, models]\n\n # Sort the resulting dictionary\n return defaultdict(list, dict(sorted(vectorize_providers_mapping.items())))\n except Exception as e:\n msg = f\"Error fetching vectorize providers: {e}\"\n raise ValueError(msg) from e\n\n @classmethod\n async def create_database_api(\n cls,\n new_database_name: str,\n cloud_provider: str,\n region: str,\n token: str,\n environment: str | None = None,\n keyspace: str | None = None,\n ):\n client = DataAPIClient(token=token, environment=environment)\n\n # Get the admin object\n admin_client = client.get_admin(token=token)\n\n # Call the create database function\n return await admin_client.async_create_database(\n name=new_database_name,\n cloud_provider=cls.map_cloud_providers()[cloud_provider][\"id\"],\n region=region,\n keyspace=keyspace,\n wait_until_active=False,\n )\n\n @classmethod\n async def create_collection_api(\n cls,\n new_collection_name: str,\n token: str,\n api_endpoint: str,\n environment: str | None = None,\n keyspace: str | None = None,\n dimension: int | None = None,\n embedding_generation_provider: str | None = None,\n embedding_generation_model: str | None = None,\n ):\n # Create the data API client\n client = DataAPIClient(token=token, environment=environment)\n\n # Get the database object\n database = client.get_async_database(api_endpoint=api_endpoint, token=token)\n\n # Build vectorize options, if needed\n vectorize_options = None\n if not dimension:\n vectorize_options = CollectionVectorServiceOptions(\n provider=cls.get_vectorize_providers(\n token=token, environment=environment, api_endpoint=api_endpoint\n ).get(embedding_generation_provider, [None, []])[0],\n model_name=embedding_generation_model,\n )\n\n # Create the collection\n return await database.create_collection(\n name=new_collection_name,\n keyspace=keyspace,\n dimension=dimension,\n service=vectorize_options,\n )\n\n @classmethod\n def get_database_list_static(cls, token: str, environment: str | None = None):\n client = DataAPIClient(token=token, environment=environment)\n\n # Get the admin object\n admin_client = client.get_admin(token=token)\n\n # Get the list of databases\n db_list = list(admin_client.list_databases())\n\n # Set the environment properly\n env_string = \"\"\n if environment and environment != \"prod\":\n env_string = f\"-{environment}\"\n\n # Generate the api endpoint for each database\n db_info_dict = {}\n for db in db_list:\n try:\n # Get the API endpoint for the database\n api_endpoint = f\"https://{db.info.id}-{db.info.region}.apps.astra{env_string}.datastax.com\"\n\n # Get the number of collections\n try:\n num_collections = len(\n list(\n client.get_database(\n api_endpoint=api_endpoint, token=token, keyspace=db.info.keyspace\n ).list_collection_names(keyspace=db.info.keyspace)\n )\n )\n except Exception: # noqa: BLE001\n num_collections = 0\n if db.status != \"PENDING\":\n continue\n\n # Add the database to the dictionary\n db_info_dict[db.info.name] = {\n \"api_endpoint\": api_endpoint,\n \"collections\": num_collections,\n \"status\": db.status if db.status != \"ACTIVE\" else None,\n }\n except Exception: # noqa: BLE001, S110\n pass\n\n return db_info_dict\n\n def get_database_list(self):\n return self.get_database_list_static(token=self.token, environment=self.environment)\n\n @classmethod\n def get_api_endpoint_static(\n cls,\n token: str,\n environment: str | None = None,\n api_endpoint: str | None = None,\n database_name: str | None = None,\n ):\n # If the api_endpoint is set, return it\n if api_endpoint:\n return api_endpoint\n\n # Check if the database_name is like a url\n if database_name and database_name.startswith(\"https://\"):\n return database_name\n\n # If the database is not set, nothing we can do.\n if not database_name:\n return None\n\n # Grab the database object\n db = cls.get_database_list_static(token=token, environment=environment).get(database_name)\n if not db:\n return None\n\n # Otherwise, get the URL from the database list\n return db.get(\"api_endpoint\")\n\n def get_api_endpoint(self):\n return self.get_api_endpoint_static(\n token=self.token,\n environment=self.environment,\n api_endpoint=self.api_endpoint,\n database_name=self.database_name,\n )\n\n def get_keyspace(self):\n keyspace = self.keyspace\n\n if keyspace:\n return keyspace.strip()\n\n return None\n\n def get_database_object(self, api_endpoint: str | None = None):\n try:\n client = DataAPIClient(token=self.token, environment=self.environment)\n\n return client.get_database(\n api_endpoint=api_endpoint or self.get_api_endpoint(),\n token=self.token,\n keyspace=self.get_keyspace(),\n )\n except Exception as e:\n msg = f\"Error fetching database object: {e}\"\n raise ValueError(msg) from e\n\n def collection_data(self, collection_name: str, database: Database | None = None):\n try:\n if not database:\n client = DataAPIClient(token=self.token, environment=self.environment)\n\n database = client.get_database(\n api_endpoint=self.get_api_endpoint(),\n token=self.token,\n keyspace=self.get_keyspace(),\n )\n\n collection = database.get_collection(collection_name, keyspace=self.get_keyspace())\n\n return collection.estimated_document_count()\n except Exception as e: # noqa: BLE001\n self.log(f\"Error checking collection data: {e}\")\n\n return None\n\n def _initialize_database_options(self):\n try:\n return [\n {\n \"name\": name,\n \"status\": info[\"status\"],\n \"collections\": info[\"collections\"],\n \"api_endpoint\": info[\"api_endpoint\"],\n \"icon\": \"data\",\n }\n for name, info in self.get_database_list().items()\n ]\n except Exception as e:\n msg = f\"Error fetching database options: {e}\"\n raise ValueError(msg) from e\n\n @classmethod\n def get_provider_icon(cls, collection: CollectionDescriptor | None = None, provider_name: str | None = None) -> str:\n # Get the provider name from the collection\n provider_name = provider_name or (\n collection.options.vector.service.provider\n if collection and collection.options and collection.options.vector and collection.options.vector.service\n else None\n )\n\n # If there is no provider, use the vector store icon\n if not provider_name or provider_name == \"bring your own\":\n return \"vectorstores\"\n\n # Special case for certain models\n # TODO: Add more icons\n if provider_name == \"nvidia\":\n return \"NVIDIA\"\n if provider_name == \"openai\":\n return \"OpenAI\"\n\n # Title case on the provider for the icon if no special case\n return provider_name.title()\n\n def _initialize_collection_options(self, api_endpoint: str | None = None):\n # Nothing to generate if we don't have an API endpoint yet\n api_endpoint = api_endpoint or self.get_api_endpoint()\n if not api_endpoint:\n return []\n\n # Retrieve the database object\n database = self.get_database_object(api_endpoint=api_endpoint)\n\n # Get the list of collections\n collection_list = list(database.list_collections(keyspace=self.get_keyspace()))\n\n # Return the list of collections and metadata associated\n return [\n {\n \"name\": col.name,\n \"records\": self.collection_data(collection_name=col.name, database=database),\n \"provider\": (\n col.options.vector.service.provider if col.options.vector and col.options.vector.service else None\n ),\n \"icon\": self.get_provider_icon(collection=col),\n \"model\": (\n col.options.vector.service.model_name if col.options.vector and col.options.vector.service else None\n ),\n }\n for col in collection_list\n ]\n\n def reset_provider_options(self, build_config: dict):\n # Get the list of vectorize providers\n vectorize_providers = self.get_vectorize_providers(\n token=self.token,\n environment=self.environment,\n api_endpoint=build_config[\"api_endpoint\"][\"value\"],\n )\n\n # Append a special case for Bring your own\n vectorize_providers[\"Bring your own\"] = [None, [\"Bring your own\"]]\n\n # If the collection is set, allow user to see embedding options\n build_config[\"collection_name\"][\"dialog_inputs\"][\"fields\"][\"data\"][\"node\"][\"template\"][\n \"embedding_generation_provider\"\n ][\"options\"] = [\"Bring your own\", \"Nvidia\", *[key for key in vectorize_providers if key != \"Nvidia\"]]\n\n # For all not Bring your own or Nvidia providers, add metadata saying configure in Astra DB Portal\n provider_options = build_config[\"collection_name\"][\"dialog_inputs\"][\"fields\"][\"data\"][\"node\"][\"template\"][\n \"embedding_generation_provider\"\n ][\"options\"]\n\n # Go over each possible provider and add metadata to configure in Astra DB Portal\n for provider in provider_options:\n # Skip Bring your own and Nvidia, automatically configured\n if provider in [\"Bring your own\", \"Nvidia\"]:\n build_config[\"collection_name\"][\"dialog_inputs\"][\"fields\"][\"data\"][\"node\"][\"template\"][\n \"embedding_generation_provider\"\n ][\"options_metadata\"].append({\"icon\": self.get_provider_icon(provider_name=provider.lower())})\n continue\n\n # Add metadata to configure in Astra DB Portal\n build_config[\"collection_name\"][\"dialog_inputs\"][\"fields\"][\"data\"][\"node\"][\"template\"][\n \"embedding_generation_provider\"\n ][\"options_metadata\"].append({\" \": \"Configure in Astra DB Portal\"})\n\n # And allow the user to see the models based on a selected provider\n embedding_provider = build_config[\"collection_name\"][\"dialog_inputs\"][\"fields\"][\"data\"][\"node\"][\"template\"][\n \"embedding_generation_provider\"\n ][\"value\"]\n\n # Set the options for the embedding model based on the provider\n build_config[\"collection_name\"][\"dialog_inputs\"][\"fields\"][\"data\"][\"node\"][\"template\"][\n \"embedding_generation_model\"\n ][\"options\"] = vectorize_providers.get(embedding_provider, [[], []])[1]\n\n return build_config\n\n def reset_collection_list(self, build_config: dict):\n # Get the list of options we have based on the token provided\n collection_options = self._initialize_collection_options(api_endpoint=build_config[\"api_endpoint\"][\"value\"])\n\n # If we retrieved options based on the token, show the dropdown\n build_config[\"collection_name\"][\"options\"] = [col[\"name\"] for col in collection_options]\n build_config[\"collection_name\"][\"options_metadata\"] = [\n {k: v for k, v in col.items() if k not in [\"name\"]} for col in collection_options\n ]\n\n # Reset the selected collection\n if build_config[\"collection_name\"][\"value\"] not in build_config[\"collection_name\"][\"options\"]:\n build_config[\"collection_name\"][\"value\"] = \"\"\n\n # If we have a database, collection name should not be advanced\n build_config[\"collection_name\"][\"advanced\"] = not build_config[\"database_name\"][\"value\"]\n\n return build_config\n\n def reset_database_list(self, build_config: dict):\n # Get the list of options we have based on the token provided\n database_options = self._initialize_database_options()\n\n # If we retrieved options based on the token, show the dropdown\n build_config[\"database_name\"][\"options\"] = [db[\"name\"] for db in database_options]\n build_config[\"database_name\"][\"options_metadata\"] = [\n {k: v for k, v in db.items() if k not in [\"name\"]} for db in database_options\n ]\n\n # Reset the selected database\n if build_config[\"database_name\"][\"value\"] not in build_config[\"database_name\"][\"options\"]:\n build_config[\"database_name\"][\"value\"] = \"\"\n build_config[\"api_endpoint\"][\"value\"] = \"\"\n build_config[\"collection_name\"][\"advanced\"] = True\n\n # If we have a token, database name should not be advanced\n build_config[\"database_name\"][\"advanced\"] = not build_config[\"token\"][\"value\"]\n\n return build_config\n\n def reset_build_config(self, build_config: dict):\n # Reset the list of databases we have based on the token provided\n build_config[\"database_name\"][\"options\"] = []\n build_config[\"database_name\"][\"options_metadata\"] = []\n build_config[\"database_name\"][\"value\"] = \"\"\n build_config[\"database_name\"][\"advanced\"] = True\n build_config[\"api_endpoint\"][\"value\"] = \"\"\n\n # Reset the list of collections and metadata associated\n build_config[\"collection_name\"][\"options\"] = []\n build_config[\"collection_name\"][\"options_metadata\"] = []\n build_config[\"collection_name\"][\"value\"] = \"\"\n build_config[\"collection_name\"][\"advanced\"] = True\n\n return build_config\n\n async def update_build_config(self, build_config: dict, field_value: str, field_name: str | None = None):\n # Callback for database creation\n if field_name == \"database_name\" and isinstance(field_value, dict) and \"new_database_name\" in field_value:\n try:\n await self.create_database_api(\n new_database_name=field_value[\"new_database_name\"],\n token=self.token,\n keyspace=self.get_keyspace(),\n environment=self.environment,\n cloud_provider=field_value[\"cloud_provider\"],\n region=field_value[\"region\"],\n )\n except Exception as e:\n msg = f\"Error creating database: {e}\"\n raise ValueError(msg) from e\n\n # Add the new database to the list of options\n build_config[\"database_name\"][\"options\"] = build_config[\"database_name\"][\"options\"] + [\n field_value[\"new_database_name\"]\n ]\n build_config[\"database_name\"][\"options_metadata\"] = build_config[\"database_name\"][\"options_metadata\"] + [\n {\"status\": \"PENDING\"}\n ]\n\n return self.reset_collection_list(build_config)\n\n # This is the callback required to update the list of regions for a cloud provider\n if field_name == \"database_name\" and isinstance(field_value, dict) and \"new_database_name\" not in field_value:\n cloud_provider = field_value[\"cloud_provider\"]\n build_config[\"database_name\"][\"dialog_inputs\"][\"fields\"][\"data\"][\"node\"][\"template\"][\"region\"][\n \"options\"\n ] = self.map_cloud_providers()[cloud_provider][\"regions\"]\n\n return build_config\n\n # Callback for the creation of collections\n if field_name == \"collection_name\" and isinstance(field_value, dict) and \"new_collection_name\" in field_value:\n try:\n # Get the dimension if its a BYO provider\n dimension = (\n field_value[\"dimension\"]\n if field_value[\"embedding_generation_provider\"] == \"Bring your own\"\n else None\n )\n\n # Create the collection\n await self.create_collection_api(\n new_collection_name=field_value[\"new_collection_name\"],\n token=self.token,\n api_endpoint=build_config[\"api_endpoint\"][\"value\"],\n environment=self.environment,\n keyspace=self.get_keyspace(),\n dimension=dimension,\n embedding_generation_provider=field_value[\"embedding_generation_provider\"],\n embedding_generation_model=field_value[\"embedding_generation_model\"],\n )\n except Exception as e:\n msg = f\"Error creating collection: {e}\"\n raise ValueError(msg) from e\n\n # Add the new collection to the list of options\n build_config[\"collection_name\"][\"value\"] = field_value[\"new_collection_name\"]\n build_config[\"collection_name\"][\"options\"].append(field_value[\"new_collection_name\"])\n\n # Get the provider and model for the new collection\n generation_provider = field_value[\"embedding_generation_provider\"]\n provider = generation_provider if generation_provider != \"Bring your own\" else None\n generation_model = field_value[\"embedding_generation_model\"]\n model = generation_model if generation_model and generation_model != \"Bring your own\" else None\n\n # Set the embedding choice\n build_config[\"embedding_choice\"][\"value\"] = \"Astra Vectorize\" if provider else \"Embedding Model\"\n build_config[\"embedding_model\"][\"advanced\"] = bool(provider)\n\n # Add the new collection to the list of options\n icon = \"NVIDIA\" if provider == \"Nvidia\" else \"vectorstores\"\n build_config[\"collection_name\"][\"options_metadata\"] = build_config[\"collection_name\"][\n \"options_metadata\"\n ] + [{\"records\": 0, \"provider\": provider, \"icon\": icon, \"model\": model}]\n\n return build_config\n\n # Callback to update the model list based on the embedding provider\n if (\n field_name == \"collection_name\"\n and isinstance(field_value, dict)\n and \"new_collection_name\" not in field_value\n ):\n return self.reset_provider_options(build_config)\n\n # When the component first executes, this is the update refresh call\n first_run = field_name == \"collection_name\" and not field_value and not build_config[\"database_name\"][\"options\"]\n\n # If the token has not been provided, simply return the empty build config\n if not self.token:\n return self.reset_build_config(build_config)\n\n # If this is the first execution of the component, reset and build database list\n if first_run or field_name in [\"token\", \"environment\"]:\n return self.reset_database_list(build_config)\n\n # Refresh the collection name options\n if field_name == \"database_name\" and not isinstance(field_value, dict):\n # If missing, refresh the database options\n if field_value not in build_config[\"database_name\"][\"options\"]:\n build_config = await self.update_build_config(build_config, field_value=self.token, field_name=\"token\")\n build_config[\"database_name\"][\"value\"] = \"\"\n else:\n # Find the position of the selected database to align with metadata\n index_of_name = build_config[\"database_name\"][\"options\"].index(field_value)\n\n # Initializing database condition\n pending = build_config[\"database_name\"][\"options_metadata\"][index_of_name][\"status\"] == \"PENDING\"\n if pending:\n return self.update_build_config(build_config, field_value=self.token, field_name=\"token\")\n\n # Set the API endpoint based on the selected database\n build_config[\"api_endpoint\"][\"value\"] = build_config[\"database_name\"][\"options_metadata\"][\n index_of_name\n ][\"api_endpoint\"]\n\n # Reset the provider options\n build_config = self.reset_provider_options(build_config)\n\n # Reset the list of collections we have based on the token provided\n return self.reset_collection_list(build_config)\n\n # Hide embedding model option if opriona_metadata provider is not null\n if field_name == \"collection_name\" and not isinstance(field_value, dict):\n # Assume we will be autodetecting the collection:\n build_config[\"autodetect_collection\"][\"value\"] = True\n\n # Reload the collection list\n build_config = self.reset_collection_list(build_config)\n\n # Set the options for collection name to be the field value if its a new collection\n if field_value and field_value not in build_config[\"collection_name\"][\"options\"]:\n # Add the new collection to the list of options\n build_config[\"collection_name\"][\"options\"].append(field_value)\n build_config[\"collection_name\"][\"options_metadata\"].append(\n {\"records\": 0, \"provider\": None, \"icon\": \"\", \"model\": None}\n )\n\n # Ensure that autodetect collection is set to False, since its a new collection\n build_config[\"autodetect_collection\"][\"value\"] = False\n\n # If nothing is selected, can't detect provider - return\n if not field_value:\n return build_config\n\n # Find the position of the selected collection to align with metadata\n index_of_name = build_config[\"collection_name\"][\"options\"].index(field_value)\n value_of_provider = build_config[\"collection_name\"][\"options_metadata\"][index_of_name][\"provider\"]\n\n # If we were able to determine the Vectorize provider, set it accordingly\n if value_of_provider:\n build_config[\"embedding_model\"][\"advanced\"] = True\n build_config[\"embedding_choice\"][\"value\"] = \"Astra Vectorize\"\n else:\n build_config[\"embedding_model\"][\"advanced\"] = False\n build_config[\"embedding_choice\"][\"value\"] = \"Embedding Model\"\n\n return build_config\n\n return build_config\n\n @check_cached_vector_store\n def build_vector_store(self):\n try:\n from langchain_astradb import AstraDBVectorStore\n except ImportError as e:\n msg = (\n \"Could not import langchain Astra DB integration package. \"\n \"Please install it with `pip install langchain-astradb`.\"\n )\n raise ImportError(msg) from e\n\n # Get the embedding model and additional params\n embedding_params = (\n {\"embedding\": self.embedding_model}\n if self.embedding_model and self.embedding_choice == \"Embedding Model\"\n else {}\n )\n\n # Get the additional parameters\n additional_params = self.astradb_vectorstore_kwargs or {}\n\n # Get Langflow version and platform information\n __version__ = get_version_info()[\"version\"]\n langflow_prefix = \"\"\n # if os.getenv(\"AWS_EXECUTION_ENV\") == \"AWS_ECS_FARGATE\": # TODO: More precise way of detecting\n # langflow_prefix = \"ds-\"\n\n # Get the database object\n database = self.get_database_object()\n autodetect = self.collection_name in database.list_collection_names() and self.autodetect_collection\n\n # Bundle up the auto-detect parameters\n autodetect_params = {\n \"autodetect_collection\": autodetect,\n \"content_field\": (\n self.content_field\n if self.content_field and embedding_params\n else (\n \"page_content\"\n if embedding_params\n and self.collection_data(collection_name=self.collection_name, database=database) == 0\n else None\n )\n ),\n \"ignore_invalid_documents\": self.ignore_invalid_documents,\n }\n\n # Attempt to build the Vector Store object\n try:\n vector_store = AstraDBVectorStore(\n # Astra DB Authentication Parameters\n token=self.token,\n api_endpoint=database.api_endpoint,\n namespace=database.keyspace,\n collection_name=self.collection_name,\n environment=self.environment,\n # Astra DB Usage Tracking Parameters\n ext_callers=[(f\"{langflow_prefix}langflow\", __version__)],\n # Astra DB Vector Store Parameters\n **autodetect_params,\n **embedding_params,\n **additional_params,\n )\n except Exception as e:\n msg = f\"Error initializing AstraDBVectorStore: {e}\"\n raise ValueError(msg) from e\n\n # Add documents to the vector store\n self._add_documents_to_vector_store(vector_store)\n\n return vector_store\n\n def _add_documents_to_vector_store(self, vector_store) -> None:\n documents = []\n for _input in self.ingest_data or []:\n if isinstance(_input, Data):\n documents.append(_input.to_lc_document())\n else:\n msg = \"Vector Store Inputs must be Data objects.\"\n raise TypeError(msg)\n\n if documents and self.deletion_field:\n self.log(f\"Deleting documents where {self.deletion_field}\")\n try:\n database = self.get_database_object()\n collection = database.get_collection(self.collection_name, keyspace=database.keyspace)\n delete_values = list({doc.metadata[self.deletion_field] for doc in documents})\n self.log(f\"Deleting documents where {self.deletion_field} matches {delete_values}.\")\n collection.delete_many({f\"metadata.{self.deletion_field}\": {\"$in\": delete_values}})\n except Exception as e:\n msg = f\"Error deleting documents from AstraDBVectorStore based on '{self.deletion_field}': {e}\"\n raise ValueError(msg) from e\n\n if documents:\n self.log(f\"Adding {len(documents)} documents to the Vector Store.\")\n try:\n vector_store.add_documents(documents)\n except Exception as e:\n msg = f\"Error adding documents to AstraDBVectorStore: {e}\"\n raise ValueError(msg) from e\n else:\n self.log(\"No documents to add to the Vector Store.\")\n\n def _map_search_type(self) -> str:\n search_type_mapping = {\n \"Similarity with score threshold\": \"similarity_score_threshold\",\n \"MMR (Max Marginal Relevance)\": \"mmr\",\n }\n\n return search_type_mapping.get(self.search_type, \"similarity\")\n\n def _build_search_args(self):\n query = self.search_query if isinstance(self.search_query, str) and self.search_query.strip() else None\n\n if query:\n args = {\n \"query\": query,\n \"search_type\": self._map_search_type(),\n \"k\": self.number_of_results,\n \"score_threshold\": self.search_score_threshold,\n }\n elif self.advanced_search_filter:\n args = {\n \"n\": self.number_of_results,\n }\n else:\n return {}\n\n filter_arg = self.advanced_search_filter or {}\n if filter_arg:\n args[\"filter\"] = filter_arg\n\n return args\n\n def search_documents(self, vector_store=None) -> list[Data]:\n vector_store = vector_store or self.build_vector_store()\n\n self.log(f\"Search input: {self.search_query}\")\n self.log(f\"Search type: {self.search_type}\")\n self.log(f\"Number of results: {self.number_of_results}\")\n\n try:\n search_args = self._build_search_args()\n except Exception as e:\n msg = f\"Error in AstraDBVectorStore._build_search_args: {e}\"\n raise ValueError(msg) from e\n\n if not search_args:\n self.log(\"No search input or filters provided. Skipping search.\")\n return []\n\n docs = []\n search_method = \"search\" if \"query\" in search_args else \"metadata_search\"\n\n try:\n self.log(f\"Calling vector_store.{search_method} with args: {search_args}\")\n docs = getattr(vector_store, search_method)(**search_args)\n except Exception as e:\n msg = f\"Error performing {search_method} in AstraDBVectorStore: {e}\"\n raise ValueError(msg) from e\n\n self.log(f\"Retrieved documents: {len(docs)}\")\n\n data = docs_to_data(docs)\n self.log(f\"Converted documents to data: {len(data)}\")\n self.status = data\n\n return data\n\n def get_retriever_kwargs(self):\n search_args = self._build_search_args()\n\n return {\n \"search_type\": self._map_search_type(),\n \"search_kwargs\": search_args,\n }\n" + "value": "from collections import defaultdict\nfrom dataclasses import asdict, dataclass, field\n\nfrom astrapy import AstraDBAdmin, DataAPIClient, Database\nfrom astrapy.info import CollectionDescriptor\nfrom langchain_astradb import AstraDBVectorStore, CollectionVectorServiceOptions\n\nfrom langflow.base.vectorstores.model import LCVectorStoreComponent, check_cached_vector_store\nfrom langflow.helpers import docs_to_data\nfrom langflow.inputs import FloatInput, NestedDictInput\nfrom langflow.io import (\n BoolInput,\n DropdownInput,\n HandleInput,\n IntInput,\n SecretStrInput,\n StrInput,\n)\nfrom langflow.schema import Data\nfrom langflow.utils.version import get_version_info\n\n\nclass AstraDBVectorStoreComponent(LCVectorStoreComponent):\n display_name: str = \"Astra DB\"\n description: str = \"Ingest and search documents in Astra DB\"\n documentation: str = \"https://docs.datastax.com/en/langflow/astra-components.html\"\n name = \"AstraDB\"\n icon: str = \"AstraDB\"\n\n _cached_vector_store: AstraDBVectorStore | None = None\n\n @dataclass\n class NewDatabaseInput:\n functionality: str = \"create\"\n fields: dict[str, dict] = field(\n default_factory=lambda: {\n \"data\": {\n \"node\": {\n \"name\": \"create_database\",\n \"description\": \"\",\n \"display_name\": \"Create new database\",\n \"field_order\": [\"new_database_name\", \"cloud_provider\", \"region\"],\n \"template\": {\n \"new_database_name\": StrInput(\n name=\"new_database_name\",\n display_name=\"Name\",\n info=\"Name of the new database to create in Astra DB.\",\n required=True,\n ),\n \"cloud_provider\": DropdownInput(\n name=\"cloud_provider\",\n display_name=\"Cloud provider\",\n info=\"Cloud provider for the new database.\",\n options=[\"Amazon Web Services\", \"Google Cloud Platform\", \"Microsoft Azure\"],\n required=True,\n real_time_refresh=True,\n ),\n \"region\": DropdownInput(\n name=\"region\",\n display_name=\"Region\",\n info=\"Region for the new database.\",\n options=[],\n required=True,\n ),\n },\n },\n }\n }\n )\n\n @dataclass\n class NewCollectionInput:\n functionality: str = \"create\"\n fields: dict[str, dict] = field(\n default_factory=lambda: {\n \"data\": {\n \"node\": {\n \"name\": \"create_collection\",\n \"description\": \"\",\n \"display_name\": \"Create new collection\",\n \"field_order\": [\n \"new_collection_name\",\n \"embedding_generation_provider\",\n \"embedding_generation_model\",\n \"dimension\",\n ],\n \"template\": {\n \"new_collection_name\": StrInput(\n name=\"new_collection_name\",\n display_name=\"Name\",\n info=\"Name of the new collection to create in Astra DB.\",\n required=True,\n ),\n \"embedding_generation_provider\": DropdownInput(\n name=\"embedding_generation_provider\",\n display_name=\"Embedding generation method\",\n info=\"Provider to use for generating embeddings.\",\n real_time_refresh=True,\n required=True,\n options=[\"Bring your own\", \"Nvidia\"],\n ),\n \"embedding_generation_model\": DropdownInput(\n name=\"embedding_generation_model\",\n display_name=\"Embedding model\",\n info=\"Model to use for generating embeddings.\",\n required=True,\n options=[],\n ),\n \"dimension\": IntInput(\n name=\"dimension\",\n display_name=\"Dimensions (Required only for `Bring your own`)\",\n info=\"Dimensions of the embeddings to generate.\",\n required=False,\n value=1024,\n ),\n },\n },\n }\n }\n )\n\n inputs = [\n SecretStrInput(\n name=\"token\",\n display_name=\"Astra DB Application Token\",\n info=\"Authentication token for accessing Astra DB.\",\n value=\"ASTRA_DB_APPLICATION_TOKEN\",\n required=True,\n real_time_refresh=True,\n input_types=[],\n ),\n StrInput(\n name=\"environment\",\n display_name=\"Environment\",\n info=\"The environment for the Astra DB API Endpoint.\",\n advanced=True,\n real_time_refresh=True,\n ),\n DropdownInput(\n name=\"database_name\",\n display_name=\"Database\",\n info=\"The Database name for the Astra DB instance.\",\n required=True,\n refresh_button=True,\n real_time_refresh=True,\n dialog_inputs=asdict(NewDatabaseInput()),\n combobox=True,\n ),\n StrInput(\n name=\"api_endpoint\",\n display_name=\"Astra DB API Endpoint\",\n info=\"The API Endpoint for the Astra DB instance. Supercedes database selection.\",\n advanced=True,\n ),\n DropdownInput(\n name=\"collection_name\",\n display_name=\"Collection\",\n info=\"The name of the collection within Astra DB where the vectors will be stored.\",\n required=True,\n refresh_button=True,\n real_time_refresh=True,\n dialog_inputs=asdict(NewCollectionInput()),\n combobox=True,\n advanced=True,\n ),\n StrInput(\n name=\"keyspace\",\n display_name=\"Keyspace\",\n info=\"Optional keyspace within Astra DB to use for the collection.\",\n advanced=True,\n ),\n DropdownInput(\n name=\"embedding_choice\",\n display_name=\"Embedding Model or Astra Vectorize\",\n info=\"Choose an embedding model or use Astra Vectorize.\",\n options=[\"Embedding Model\", \"Astra Vectorize\"],\n value=\"Embedding Model\",\n advanced=True,\n real_time_refresh=True,\n ),\n HandleInput(\n name=\"embedding_model\",\n display_name=\"Embedding Model\",\n input_types=[\"Embeddings\"],\n info=\"Specify the Embedding Model. Not required for Astra Vectorize collections.\",\n required=False,\n ),\n *LCVectorStoreComponent.inputs,\n IntInput(\n name=\"number_of_results\",\n display_name=\"Number of Search Results\",\n info=\"Number of search results to return.\",\n advanced=True,\n value=4,\n ),\n DropdownInput(\n name=\"search_type\",\n display_name=\"Search Type\",\n info=\"Search type to use\",\n options=[\"Similarity\", \"Similarity with score threshold\", \"MMR (Max Marginal Relevance)\"],\n value=\"Similarity\",\n advanced=True,\n ),\n FloatInput(\n name=\"search_score_threshold\",\n display_name=\"Search Score Threshold\",\n info=\"Minimum similarity score threshold for search results. \"\n \"(when using 'Similarity with score threshold')\",\n value=0,\n advanced=True,\n ),\n NestedDictInput(\n name=\"advanced_search_filter\",\n display_name=\"Search Metadata Filter\",\n info=\"Optional dictionary of filters to apply to the search query.\",\n advanced=True,\n ),\n BoolInput(\n name=\"autodetect_collection\",\n display_name=\"Autodetect Collection\",\n info=\"Boolean flag to determine whether to autodetect the collection.\",\n advanced=True,\n value=True,\n ),\n StrInput(\n name=\"content_field\",\n display_name=\"Content Field\",\n info=\"Field to use as the text content field for the vector store.\",\n advanced=True,\n ),\n StrInput(\n name=\"deletion_field\",\n display_name=\"Deletion Based On Field\",\n info=\"When this parameter is provided, documents in the target collection with \"\n \"metadata field values matching the input metadata field value will be deleted \"\n \"before new data is loaded.\",\n advanced=True,\n ),\n BoolInput(\n name=\"ignore_invalid_documents\",\n display_name=\"Ignore Invalid Documents\",\n info=\"Boolean flag to determine whether to ignore invalid documents at runtime.\",\n advanced=True,\n ),\n NestedDictInput(\n name=\"astradb_vectorstore_kwargs\",\n display_name=\"AstraDBVectorStore Parameters\",\n info=\"Optional dictionary of additional parameters for the AstraDBVectorStore.\",\n advanced=True,\n ),\n ]\n\n @classmethod\n def map_cloud_providers(cls):\n # TODO: Programmatically fetch the regions for each cloud provider\n return {\n \"Amazon Web Services\": {\n \"id\": \"aws\",\n \"regions\": [\"us-east-2\", \"ap-south-1\", \"eu-west-1\"],\n },\n \"Google Cloud Platform\": {\n \"id\": \"gcp\",\n \"regions\": [\"us-east1\"],\n },\n \"Microsoft Azure\": {\n \"id\": \"azure\",\n \"regions\": [\"westus3\"],\n },\n }\n\n @classmethod\n def get_vectorize_providers(cls, token: str, environment: str | None = None, api_endpoint: str | None = None):\n try:\n # Get the admin object\n admin = AstraDBAdmin(token=token, environment=environment)\n db_admin = admin.get_database_admin(api_endpoint=api_endpoint)\n\n # Get the list of embedding providers\n embedding_providers = db_admin.find_embedding_providers().as_dict()\n\n vectorize_providers_mapping = {}\n # Map the provider display name to the provider key and models\n for provider_key, provider_data in embedding_providers[\"embeddingProviders\"].items():\n # Get the provider display name and models\n display_name = provider_data[\"displayName\"]\n models = [model[\"name\"] for model in provider_data[\"models\"]]\n\n # Build our mapping\n vectorize_providers_mapping[display_name] = [provider_key, models]\n\n # Sort the resulting dictionary\n return defaultdict(list, dict(sorted(vectorize_providers_mapping.items())))\n except Exception as e:\n msg = f\"Error fetching vectorize providers: {e}\"\n raise ValueError(msg) from e\n\n @classmethod\n async def create_database_api(\n cls,\n new_database_name: str,\n cloud_provider: str,\n region: str,\n token: str,\n environment: str | None = None,\n keyspace: str | None = None,\n ):\n client = DataAPIClient(token=token, environment=environment)\n\n # Get the admin object\n admin_client = client.get_admin(token=token)\n\n # Call the create database function\n return await admin_client.async_create_database(\n name=new_database_name,\n cloud_provider=cls.map_cloud_providers()[cloud_provider][\"id\"],\n region=region,\n keyspace=keyspace,\n wait_until_active=False,\n )\n\n @classmethod\n async def create_collection_api(\n cls,\n new_collection_name: str,\n token: str,\n api_endpoint: str,\n environment: str | None = None,\n keyspace: str | None = None,\n dimension: int | None = None,\n embedding_generation_provider: str | None = None,\n embedding_generation_model: str | None = None,\n ):\n # Create the data API client\n client = DataAPIClient(token=token, environment=environment)\n\n # Get the database object\n database = client.get_async_database(api_endpoint=api_endpoint, token=token)\n\n # Build vectorize options, if needed\n vectorize_options = None\n if not dimension:\n vectorize_options = CollectionVectorServiceOptions(\n provider=cls.get_vectorize_providers(\n token=token, environment=environment, api_endpoint=api_endpoint\n ).get(embedding_generation_provider, [None, []])[0],\n model_name=embedding_generation_model,\n )\n\n # Create the collection\n return await database.create_collection(\n name=new_collection_name,\n keyspace=keyspace,\n dimension=dimension,\n service=vectorize_options,\n )\n\n @classmethod\n def get_database_list_static(cls, token: str, environment: str | None = None):\n client = DataAPIClient(token=token, environment=environment)\n\n # Get the admin object\n admin_client = client.get_admin(token=token)\n\n # Get the list of databases\n db_list = list(admin_client.list_databases())\n\n # Set the environment properly\n env_string = \"\"\n if environment and environment != \"prod\":\n env_string = f\"-{environment}\"\n\n # Generate the api endpoint for each database\n db_info_dict = {}\n for db in db_list:\n try:\n # Get the API endpoint for the database\n api_endpoint = f\"https://{db.info.id}-{db.info.region}.apps.astra{env_string}.datastax.com\"\n\n # Get the number of collections\n try:\n num_collections = len(\n list(\n client.get_database(\n api_endpoint=api_endpoint, token=token, keyspace=db.info.keyspace\n ).list_collection_names(keyspace=db.info.keyspace)\n )\n )\n except Exception: # noqa: BLE001\n num_collections = 0\n if db.status != \"PENDING\":\n continue\n\n # Add the database to the dictionary\n db_info_dict[db.info.name] = {\n \"api_endpoint\": api_endpoint,\n \"collections\": num_collections,\n \"status\": db.status if db.status != \"ACTIVE\" else None,\n }\n except Exception: # noqa: BLE001, S110\n pass\n\n return db_info_dict\n\n def get_database_list(self):\n return self.get_database_list_static(token=self.token, environment=self.environment)\n\n @classmethod\n def get_api_endpoint_static(\n cls,\n token: str,\n environment: str | None = None,\n api_endpoint: str | None = None,\n database_name: str | None = None,\n ):\n # If the api_endpoint is set, return it\n if api_endpoint:\n return api_endpoint\n\n # Check if the database_name is like a url\n if database_name and database_name.startswith(\"https://\"):\n return database_name\n\n # If the database is not set, nothing we can do.\n if not database_name:\n return None\n\n # Grab the database object\n db = cls.get_database_list_static(token=token, environment=environment).get(database_name)\n if not db:\n return None\n\n # Otherwise, get the URL from the database list\n return db.get(\"api_endpoint\")\n\n def get_api_endpoint(self):\n return self.get_api_endpoint_static(\n token=self.token,\n environment=self.environment,\n api_endpoint=self.api_endpoint,\n database_name=self.database_name,\n )\n\n def get_keyspace(self):\n keyspace = self.keyspace\n\n if keyspace:\n return keyspace.strip()\n\n return None\n\n def get_database_object(self, api_endpoint: str | None = None):\n try:\n client = DataAPIClient(token=self.token, environment=self.environment)\n\n return client.get_database(\n api_endpoint=api_endpoint or self.get_api_endpoint(),\n token=self.token,\n keyspace=self.get_keyspace(),\n )\n except Exception as e:\n msg = f\"Error fetching database object: {e}\"\n raise ValueError(msg) from e\n\n def collection_data(self, collection_name: str, database: Database | None = None):\n try:\n if not database:\n client = DataAPIClient(token=self.token, environment=self.environment)\n\n database = client.get_database(\n api_endpoint=self.get_api_endpoint(),\n token=self.token,\n keyspace=self.get_keyspace(),\n )\n\n collection = database.get_collection(collection_name, keyspace=self.get_keyspace())\n\n return collection.estimated_document_count()\n except Exception as e: # noqa: BLE001\n self.log(f\"Error checking collection data: {e}\")\n\n return None\n\n def _initialize_database_options(self):\n try:\n return [\n {\n \"name\": name,\n \"status\": info[\"status\"],\n \"collections\": info[\"collections\"],\n \"api_endpoint\": info[\"api_endpoint\"],\n \"icon\": \"data\",\n }\n for name, info in self.get_database_list().items()\n ]\n except Exception as e:\n msg = f\"Error fetching database options: {e}\"\n raise ValueError(msg) from e\n\n @classmethod\n def get_provider_icon(cls, collection: CollectionDescriptor | None = None, provider_name: str | None = None) -> str:\n # Get the provider name from the collection\n provider_name = provider_name or (\n collection.options.vector.service.provider\n if collection and collection.options and collection.options.vector and collection.options.vector.service\n else None\n )\n\n # If there is no provider, use the vector store icon\n if not provider_name or provider_name == \"bring your own\":\n return \"vectorstores\"\n\n # Special case for certain models\n # TODO: Add more icons\n if provider_name == \"nvidia\":\n return \"NVIDIA\"\n if provider_name == \"openai\":\n return \"OpenAI\"\n\n # Title case on the provider for the icon if no special case\n return provider_name.title()\n\n def _initialize_collection_options(self, api_endpoint: str | None = None):\n # Nothing to generate if we don't have an API endpoint yet\n api_endpoint = api_endpoint or self.get_api_endpoint()\n if not api_endpoint:\n return []\n\n # Retrieve the database object\n database = self.get_database_object(api_endpoint=api_endpoint)\n\n # Get the list of collections\n collection_list = list(database.list_collections(keyspace=self.get_keyspace()))\n\n # Return the list of collections and metadata associated\n return [\n {\n \"name\": col.name,\n \"records\": self.collection_data(collection_name=col.name, database=database),\n \"provider\": (\n col.options.vector.service.provider if col.options.vector and col.options.vector.service else None\n ),\n \"icon\": self.get_provider_icon(collection=col),\n \"model\": (\n col.options.vector.service.model_name if col.options.vector and col.options.vector.service else None\n ),\n }\n for col in collection_list\n ]\n\n def reset_provider_options(self, build_config: dict):\n # Get the list of vectorize providers\n vectorize_providers = self.get_vectorize_providers(\n token=self.token,\n environment=self.environment,\n api_endpoint=build_config[\"api_endpoint\"][\"value\"],\n )\n\n # Append a special case for Bring your own\n vectorize_providers[\"Bring your own\"] = [None, [\"Bring your own\"]]\n\n # If the collection is set, allow user to see embedding options\n build_config[\"collection_name\"][\"dialog_inputs\"][\"fields\"][\"data\"][\"node\"][\"template\"][\n \"embedding_generation_provider\"\n ][\"options\"] = [\"Bring your own\", \"Nvidia\", *[key for key in vectorize_providers if key != \"Nvidia\"]]\n\n # For all not Bring your own or Nvidia providers, add metadata saying configure in Astra DB Portal\n provider_options = build_config[\"collection_name\"][\"dialog_inputs\"][\"fields\"][\"data\"][\"node\"][\"template\"][\n \"embedding_generation_provider\"\n ][\"options\"]\n\n # Go over each possible provider and add metadata to configure in Astra DB Portal\n for provider in provider_options:\n # Skip Bring your own and Nvidia, automatically configured\n if provider in {\"Bring your own\", \"Nvidia\"}:\n build_config[\"collection_name\"][\"dialog_inputs\"][\"fields\"][\"data\"][\"node\"][\"template\"][\n \"embedding_generation_provider\"\n ][\"options_metadata\"].append({\"icon\": self.get_provider_icon(provider_name=provider.lower())})\n continue\n\n # Add metadata to configure in Astra DB Portal\n build_config[\"collection_name\"][\"dialog_inputs\"][\"fields\"][\"data\"][\"node\"][\"template\"][\n \"embedding_generation_provider\"\n ][\"options_metadata\"].append({\" \": \"Configure in Astra DB Portal\"})\n\n # And allow the user to see the models based on a selected provider\n embedding_provider = build_config[\"collection_name\"][\"dialog_inputs\"][\"fields\"][\"data\"][\"node\"][\"template\"][\n \"embedding_generation_provider\"\n ][\"value\"]\n\n # Set the options for the embedding model based on the provider\n build_config[\"collection_name\"][\"dialog_inputs\"][\"fields\"][\"data\"][\"node\"][\"template\"][\n \"embedding_generation_model\"\n ][\"options\"] = vectorize_providers.get(embedding_provider, [[], []])[1]\n\n return build_config\n\n def reset_collection_list(self, build_config: dict):\n # Get the list of options we have based on the token provided\n collection_options = self._initialize_collection_options(api_endpoint=build_config[\"api_endpoint\"][\"value\"])\n\n # If we retrieved options based on the token, show the dropdown\n build_config[\"collection_name\"][\"options\"] = [col[\"name\"] for col in collection_options]\n build_config[\"collection_name\"][\"options_metadata\"] = [\n {k: v for k, v in col.items() if k != \"name\"} for col in collection_options\n ]\n\n # Reset the selected collection\n if build_config[\"collection_name\"][\"value\"] not in build_config[\"collection_name\"][\"options\"]:\n build_config[\"collection_name\"][\"value\"] = \"\"\n\n # If we have a database, collection name should not be advanced\n build_config[\"collection_name\"][\"advanced\"] = not build_config[\"database_name\"][\"value\"]\n\n return build_config\n\n def reset_database_list(self, build_config: dict):\n # Get the list of options we have based on the token provided\n database_options = self._initialize_database_options()\n\n # If we retrieved options based on the token, show the dropdown\n build_config[\"database_name\"][\"options\"] = [db[\"name\"] for db in database_options]\n build_config[\"database_name\"][\"options_metadata\"] = [\n {k: v for k, v in db.items() if k != \"name\"} for db in database_options\n ]\n\n # Reset the selected database\n if build_config[\"database_name\"][\"value\"] not in build_config[\"database_name\"][\"options\"]:\n build_config[\"database_name\"][\"value\"] = \"\"\n build_config[\"api_endpoint\"][\"value\"] = \"\"\n build_config[\"collection_name\"][\"advanced\"] = True\n\n # If we have a token, database name should not be advanced\n build_config[\"database_name\"][\"advanced\"] = not build_config[\"token\"][\"value\"]\n\n return build_config\n\n def reset_build_config(self, build_config: dict):\n # Reset the list of databases we have based on the token provided\n build_config[\"database_name\"][\"options\"] = []\n build_config[\"database_name\"][\"options_metadata\"] = []\n build_config[\"database_name\"][\"value\"] = \"\"\n build_config[\"database_name\"][\"advanced\"] = True\n build_config[\"api_endpoint\"][\"value\"] = \"\"\n\n # Reset the list of collections and metadata associated\n build_config[\"collection_name\"][\"options\"] = []\n build_config[\"collection_name\"][\"options_metadata\"] = []\n build_config[\"collection_name\"][\"value\"] = \"\"\n build_config[\"collection_name\"][\"advanced\"] = True\n\n return build_config\n\n async def update_build_config(self, build_config: dict, field_value: str, field_name: str | None = None):\n # Callback for database creation\n if field_name == \"database_name\" and isinstance(field_value, dict) and \"new_database_name\" in field_value:\n try:\n await self.create_database_api(\n new_database_name=field_value[\"new_database_name\"],\n token=self.token,\n keyspace=self.get_keyspace(),\n environment=self.environment,\n cloud_provider=field_value[\"cloud_provider\"],\n region=field_value[\"region\"],\n )\n except Exception as e:\n msg = f\"Error creating database: {e}\"\n raise ValueError(msg) from e\n\n # Add the new database to the list of options\n build_config[\"database_name\"][\"options\"] += [field_value[\"new_database_name\"]]\n build_config[\"database_name\"][\"options_metadata\"] += [{\"status\": \"PENDING\"}]\n\n return self.reset_collection_list(build_config)\n\n # This is the callback required to update the list of regions for a cloud provider\n if field_name == \"database_name\" and isinstance(field_value, dict) and \"new_database_name\" not in field_value:\n cloud_provider = field_value[\"cloud_provider\"]\n build_config[\"database_name\"][\"dialog_inputs\"][\"fields\"][\"data\"][\"node\"][\"template\"][\"region\"][\n \"options\"\n ] = self.map_cloud_providers()[cloud_provider][\"regions\"]\n\n return build_config\n\n # Callback for the creation of collections\n if field_name == \"collection_name\" and isinstance(field_value, dict) and \"new_collection_name\" in field_value:\n try:\n # Get the dimension if its a BYO provider\n dimension = (\n field_value[\"dimension\"]\n if field_value[\"embedding_generation_provider\"] == \"Bring your own\"\n else None\n )\n\n # Create the collection\n await self.create_collection_api(\n new_collection_name=field_value[\"new_collection_name\"],\n token=self.token,\n api_endpoint=build_config[\"api_endpoint\"][\"value\"],\n environment=self.environment,\n keyspace=self.get_keyspace(),\n dimension=dimension,\n embedding_generation_provider=field_value[\"embedding_generation_provider\"],\n embedding_generation_model=field_value[\"embedding_generation_model\"],\n )\n except Exception as e:\n msg = f\"Error creating collection: {e}\"\n raise ValueError(msg) from e\n\n # Add the new collection to the list of options\n build_config[\"collection_name\"][\"value\"] = field_value[\"new_collection_name\"]\n build_config[\"collection_name\"][\"options\"].append(field_value[\"new_collection_name\"])\n\n # Get the provider and model for the new collection\n generation_provider = field_value[\"embedding_generation_provider\"]\n provider = generation_provider if generation_provider != \"Bring your own\" else None\n generation_model = field_value[\"embedding_generation_model\"]\n model = generation_model if generation_model and generation_model != \"Bring your own\" else None\n\n # Set the embedding choice\n build_config[\"embedding_choice\"][\"value\"] = \"Astra Vectorize\" if provider else \"Embedding Model\"\n build_config[\"embedding_model\"][\"advanced\"] = bool(provider)\n\n # Add the new collection to the list of options\n icon = \"NVIDIA\" if provider == \"Nvidia\" else \"vectorstores\"\n build_config[\"collection_name\"][\"options_metadata\"] += [\n {\"records\": 0, \"provider\": provider, \"icon\": icon, \"model\": model}\n ]\n\n return build_config\n\n # Callback to update the model list based on the embedding provider\n if (\n field_name == \"collection_name\"\n and isinstance(field_value, dict)\n and \"new_collection_name\" not in field_value\n ):\n return self.reset_provider_options(build_config)\n\n # When the component first executes, this is the update refresh call\n first_run = field_name == \"collection_name\" and not field_value and not build_config[\"database_name\"][\"options\"]\n\n # If the token has not been provided, simply return the empty build config\n if not self.token:\n return self.reset_build_config(build_config)\n\n # If this is the first execution of the component, reset and build database list\n if first_run or field_name in {\"token\", \"environment\"}:\n return self.reset_database_list(build_config)\n\n # Refresh the collection name options\n if field_name == \"database_name\" and not isinstance(field_value, dict):\n # If missing, refresh the database options\n if field_value not in build_config[\"database_name\"][\"options\"]:\n build_config = await self.update_build_config(build_config, field_value=self.token, field_name=\"token\")\n build_config[\"database_name\"][\"value\"] = \"\"\n else:\n # Find the position of the selected database to align with metadata\n index_of_name = build_config[\"database_name\"][\"options\"].index(field_value)\n\n # Initializing database condition\n pending = build_config[\"database_name\"][\"options_metadata\"][index_of_name][\"status\"] == \"PENDING\"\n if pending:\n return self.update_build_config(build_config, field_value=self.token, field_name=\"token\")\n\n # Set the API endpoint based on the selected database\n build_config[\"api_endpoint\"][\"value\"] = build_config[\"database_name\"][\"options_metadata\"][\n index_of_name\n ][\"api_endpoint\"]\n\n # Reset the provider options\n build_config = self.reset_provider_options(build_config)\n\n # Reset the list of collections we have based on the token provided\n return self.reset_collection_list(build_config)\n\n # Hide embedding model option if opriona_metadata provider is not null\n if field_name == \"collection_name\" and not isinstance(field_value, dict):\n # Assume we will be autodetecting the collection:\n build_config[\"autodetect_collection\"][\"value\"] = True\n\n # Reload the collection list\n build_config = self.reset_collection_list(build_config)\n\n # Set the options for collection name to be the field value if its a new collection\n if field_value and field_value not in build_config[\"collection_name\"][\"options\"]:\n # Add the new collection to the list of options\n build_config[\"collection_name\"][\"options\"].append(field_value)\n build_config[\"collection_name\"][\"options_metadata\"].append(\n {\n \"records\": 0,\n \"provider\": None,\n \"icon\": \"\",\n \"model\": None,\n }\n )\n\n # Ensure that autodetect collection is set to False, since its a new collection\n build_config[\"autodetect_collection\"][\"value\"] = False\n\n # If nothing is selected, can't detect provider - return\n if not field_value:\n return build_config\n\n # Find the position of the selected collection to align with metadata\n index_of_name = build_config[\"collection_name\"][\"options\"].index(field_value)\n value_of_provider = build_config[\"collection_name\"][\"options_metadata\"][index_of_name][\"provider\"]\n\n # If we were able to determine the Vectorize provider, set it accordingly\n if value_of_provider:\n build_config[\"embedding_model\"][\"advanced\"] = True\n build_config[\"embedding_choice\"][\"value\"] = \"Astra Vectorize\"\n else:\n build_config[\"embedding_model\"][\"advanced\"] = False\n build_config[\"embedding_choice\"][\"value\"] = \"Embedding Model\"\n\n return build_config\n\n return build_config\n\n @check_cached_vector_store\n def build_vector_store(self):\n try:\n from langchain_astradb import AstraDBVectorStore\n except ImportError as e:\n msg = (\n \"Could not import langchain Astra DB integration package. \"\n \"Please install it with `pip install langchain-astradb`.\"\n )\n raise ImportError(msg) from e\n\n # Get the embedding model and additional params\n embedding_params = (\n {\"embedding\": self.embedding_model}\n if self.embedding_model and self.embedding_choice == \"Embedding Model\"\n else {}\n )\n\n # Get the additional parameters\n additional_params = self.astradb_vectorstore_kwargs or {}\n\n # Get Langflow version and platform information\n __version__ = get_version_info()[\"version\"]\n langflow_prefix = \"\"\n # if os.getenv(\"AWS_EXECUTION_ENV\") == \"AWS_ECS_FARGATE\": # TODO: More precise way of detecting\n # langflow_prefix = \"ds-\"\n\n # Get the database object\n database = self.get_database_object()\n autodetect = self.collection_name in database.list_collection_names() and self.autodetect_collection\n\n # Bundle up the auto-detect parameters\n autodetect_params = {\n \"autodetect_collection\": autodetect,\n \"content_field\": (\n self.content_field\n if self.content_field and embedding_params\n else (\n \"page_content\"\n if embedding_params\n and self.collection_data(collection_name=self.collection_name, database=database) == 0\n else None\n )\n ),\n \"ignore_invalid_documents\": self.ignore_invalid_documents,\n }\n\n # Attempt to build the Vector Store object\n try:\n vector_store = AstraDBVectorStore(\n # Astra DB Authentication Parameters\n token=self.token,\n api_endpoint=database.api_endpoint,\n namespace=database.keyspace,\n collection_name=self.collection_name,\n environment=self.environment,\n # Astra DB Usage Tracking Parameters\n ext_callers=[(f\"{langflow_prefix}langflow\", __version__)],\n # Astra DB Vector Store Parameters\n **autodetect_params,\n **embedding_params,\n **additional_params,\n )\n except Exception as e:\n msg = f\"Error initializing AstraDBVectorStore: {e}\"\n raise ValueError(msg) from e\n\n # Add documents to the vector store\n self._add_documents_to_vector_store(vector_store)\n\n return vector_store\n\n def _add_documents_to_vector_store(self, vector_store) -> None:\n documents = []\n for _input in self.ingest_data or []:\n if isinstance(_input, Data):\n documents.append(_input.to_lc_document())\n else:\n msg = \"Vector Store Inputs must be Data objects.\"\n raise TypeError(msg)\n\n if documents and self.deletion_field:\n self.log(f\"Deleting documents where {self.deletion_field}\")\n try:\n database = self.get_database_object()\n collection = database.get_collection(self.collection_name, keyspace=database.keyspace)\n delete_values = list({doc.metadata[self.deletion_field] for doc in documents})\n self.log(f\"Deleting documents where {self.deletion_field} matches {delete_values}.\")\n collection.delete_many({f\"metadata.{self.deletion_field}\": {\"$in\": delete_values}})\n except Exception as e:\n msg = f\"Error deleting documents from AstraDBVectorStore based on '{self.deletion_field}': {e}\"\n raise ValueError(msg) from e\n\n if documents:\n self.log(f\"Adding {len(documents)} documents to the Vector Store.\")\n try:\n vector_store.add_documents(documents)\n except Exception as e:\n msg = f\"Error adding documents to AstraDBVectorStore: {e}\"\n raise ValueError(msg) from e\n else:\n self.log(\"No documents to add to the Vector Store.\")\n\n def _map_search_type(self) -> str:\n search_type_mapping = {\n \"Similarity with score threshold\": \"similarity_score_threshold\",\n \"MMR (Max Marginal Relevance)\": \"mmr\",\n }\n\n return search_type_mapping.get(self.search_type, \"similarity\")\n\n def _build_search_args(self):\n query = self.search_query if isinstance(self.search_query, str) and self.search_query.strip() else None\n\n if query:\n args = {\n \"query\": query,\n \"search_type\": self._map_search_type(),\n \"k\": self.number_of_results,\n \"score_threshold\": self.search_score_threshold,\n }\n elif self.advanced_search_filter:\n args = {\n \"n\": self.number_of_results,\n }\n else:\n return {}\n\n filter_arg = self.advanced_search_filter or {}\n if filter_arg:\n args[\"filter\"] = filter_arg\n\n return args\n\n def search_documents(self, vector_store=None) -> list[Data]:\n vector_store = vector_store or self.build_vector_store()\n\n self.log(f\"Search input: {self.search_query}\")\n self.log(f\"Search type: {self.search_type}\")\n self.log(f\"Number of results: {self.number_of_results}\")\n\n try:\n search_args = self._build_search_args()\n except Exception as e:\n msg = f\"Error in AstraDBVectorStore._build_search_args: {e}\"\n raise ValueError(msg) from e\n\n if not search_args:\n self.log(\"No search input or filters provided. Skipping search.\")\n return []\n\n docs = []\n search_method = \"search\" if \"query\" in search_args else \"metadata_search\"\n\n try:\n self.log(f\"Calling vector_store.{search_method} with args: {search_args}\")\n docs = getattr(vector_store, search_method)(**search_args)\n except Exception as e:\n msg = f\"Error performing {search_method} in AstraDBVectorStore: {e}\"\n raise ValueError(msg) from e\n\n self.log(f\"Retrieved documents: {len(docs)}\")\n\n data = docs_to_data(docs)\n self.log(f\"Converted documents to data: {len(data)}\")\n self.status = data\n\n return data\n\n def get_retriever_kwargs(self):\n search_args = self._build_search_args()\n\n return {\n \"search_type\": self._map_search_type(),\n \"search_kwargs\": search_args,\n }\n" }, "collection_name": { "_input_type": "DropdownInput", @@ -4103,16 +4103,35 @@ ], "name": "create_collection", "template": { - "new_collection_name": { - "_input_type": "StrInput", + "dimension": { + "_input_type": "IntInput", "advanced": false, - "display_name": "Name", + "display_name": "Dimensions (Required only for `Bring your own`)", "dynamic": false, - "info": "Name of the new collection to create in Astra DB.", + "info": "Dimensions of the embeddings to generate.", "list": false, "list_add_label": "Add More", - "load_from_db": false, - "name": "new_collection_name", + "name": "dimension", + "placeholder": "", + "required": false, + "show": true, + "title_case": false, + "tool_mode": false, + "trace_as_metadata": true, + "type": "int", + "value": "" + }, + "embedding_generation_model": { + "_input_type": "DropdownInput", + "advanced": false, + "combobox": false, + "dialog_inputs": {}, + "display_name": "Embedding model", + "dynamic": false, + "info": "Model to use for generating embeddings.", + "name": "embedding_generation_model", + "options": [], + "options_metadata": [], "placeholder": "", "required": true, "show": true, @@ -4146,17 +4165,16 @@ "type": "str", "value": "" }, - "embedding_generation_model": { - "_input_type": "DropdownInput", + "new_collection_name": { + "_input_type": "StrInput", "advanced": false, - "combobox": false, - "dialog_inputs": {}, - "display_name": "Embedding model", + "display_name": "Name", "dynamic": false, - "info": "Model to use for generating embeddings.", - "name": "embedding_generation_model", - "options": [], - "options_metadata": [], + "info": "Name of the new collection to create in Astra DB.", + "list": false, + "list_add_label": "Add More", + "load_from_db": false, + "name": "new_collection_name", "placeholder": "", "required": true, "show": true, @@ -4165,24 +4183,6 @@ "trace_as_metadata": true, "type": "str", "value": "" - }, - "dimension": { - "_input_type": "IntInput", - "advanced": false, - "display_name": "Dimensions (Required only for `Bring your own`)", - "dynamic": false, - "info": "Dimensions of the embeddings to generate.", - "list": false, - "list_add_label": "Add More", - "name": "dimension", - "placeholder": "", - "required": false, - "show": true, - "title_case": false, - "tool_mode": false, - "trace_as_metadata": true, - "type": "int", - "value": "" } } } @@ -4243,25 +4243,6 @@ ], "name": "create_database", "template": { - "new_database_name": { - "_input_type": "StrInput", - "advanced": false, - "display_name": "Name", - "dynamic": false, - "info": "Name of the new database to create in Astra DB.", - "list": false, - "list_add_label": "Add More", - "load_from_db": false, - "name": "new_database_name", - "placeholder": "", - "required": true, - "show": true, - "title_case": false, - "tool_mode": false, - "trace_as_metadata": true, - "type": "str", - "value": "" - }, "cloud_provider": { "_input_type": "DropdownInput", "advanced": false, @@ -4287,6 +4268,25 @@ "type": "str", "value": "" }, + "new_database_name": { + "_input_type": "StrInput", + "advanced": false, + "display_name": "Name", + "dynamic": false, + "info": "Name of the new database to create in Astra DB.", + "list": false, + "list_add_label": "Add More", + "load_from_db": false, + "name": "new_database_name", + "placeholder": "", + "required": true, + "show": true, + "title_case": false, + "tool_mode": false, + "trace_as_metadata": true, + "type": "str", + "value": "" + }, "region": { "_input_type": "DropdownInput", "advanced": false, diff --git a/src/backend/base/langflow/initial_setup/starter_projects/Youtube Analysis.json b/src/backend/base/langflow/initial_setup/starter_projects/Youtube Analysis.json index a3107898e..f0c9d0a46 100644 --- a/src/backend/base/langflow/initial_setup/starter_projects/Youtube Analysis.json +++ b/src/backend/base/langflow/initial_setup/starter_projects/Youtube Analysis.json @@ -377,7 +377,7 @@ "show": true, "title_case": false, "type": "code", - "value": "from __future__ import annotations\n\nfrom typing import TYPE_CHECKING, Any\n\nfrom loguru import logger\n\nfrom langflow.custom import Component\nfrom langflow.io import (\n BoolInput,\n DataFrameInput,\n HandleInput,\n MessageTextInput,\n MultilineInput,\n Output,\n)\nfrom langflow.schema import DataFrame\n\nif TYPE_CHECKING:\n from langchain_core.runnables import Runnable\n\n\nclass BatchRunComponent(Component):\n display_name = \"Batch Run\"\n description = (\n \"Runs a language model over each row of a DataFrame's text column and returns a new \"\n \"DataFrame with three columns: '**text_input**' (the original text), \"\n \"'**model_response**' (the model's response),and '**batch_index**' (the processing order).\"\n )\n icon = \"List\"\n beta = True\n\n inputs = [\n HandleInput(\n name=\"model\",\n display_name=\"Language Model\",\n info=\"Connect the 'Language Model' output from your LLM component here.\",\n input_types=[\"LanguageModel\"],\n required=True,\n ),\n MultilineInput(\n name=\"system_message\",\n display_name=\"System Message\",\n info=\"Multi-line system instruction for all rows in the DataFrame.\",\n required=False,\n ),\n DataFrameInput(\n name=\"df\",\n display_name=\"DataFrame\",\n info=\"The DataFrame whose column (specified by 'column_name') we'll treat as text messages.\",\n required=True,\n ),\n MessageTextInput(\n name=\"column_name\",\n display_name=\"Column Name\",\n info=\"The name of the DataFrame column to treat as text messages. Default='text'.\",\n value=\"text\",\n required=True,\n advanced=True,\n ),\n BoolInput(\n name=\"enable_metadata\",\n display_name=\"Enable Metadata\",\n info=\"If True, add metadata to the output DataFrame.\",\n value=True,\n required=False,\n advanced=True,\n ),\n ]\n\n outputs = [\n Output(\n display_name=\"Batch Results\",\n name=\"batch_results\",\n method=\"run_batch\",\n info=\"A DataFrame with columns: 'text_input', 'model_response', 'batch_index', and 'metadata'.\",\n ),\n ]\n\n def _create_base_row(self, text_input: str = \"\", model_response: str = \"\", batch_index: int = -1) -> dict[str, Any]:\n \"\"\"Create a base row with optional metadata.\"\"\"\n return {\n \"text_input\": text_input,\n \"model_response\": model_response,\n \"batch_index\": batch_index,\n }\n\n def _add_metadata(\n self, row: dict[str, Any], *, success: bool = True, system_msg: str = \"\", error: str | None = None\n ) -> None:\n \"\"\"Add metadata to a row if enabled.\"\"\"\n if not self.enable_metadata:\n return\n\n if success:\n row[\"metadata\"] = {\n \"has_system_message\": bool(system_msg),\n \"input_length\": len(row[\"text_input\"]),\n \"response_length\": len(row[\"model_response\"]),\n \"processing_status\": \"success\",\n }\n else:\n row[\"metadata\"] = {\n \"error\": error,\n \"processing_status\": \"failed\",\n }\n\n async def run_batch(self) -> DataFrame:\n \"\"\"Process each row in df[column_name] with the language model asynchronously.\n\n Returns:\n DataFrame: A new DataFrame containing:\n - text_input: The original input text\n - model_response: The model's response\n - batch_index: The processing order\n - metadata: Additional processing information\n\n Raises:\n ValueError: If the specified column is not found in the DataFrame\n TypeError: If the model is not compatible or input types are wrong\n \"\"\"\n model: Runnable = self.model\n system_msg = self.system_message or \"\"\n df: DataFrame = self.df\n col_name = self.column_name or \"text\"\n\n # Validate inputs first\n if not isinstance(df, DataFrame):\n msg = f\"Expected DataFrame input, got {type(df)}\"\n raise TypeError(msg)\n\n if col_name not in df.columns:\n msg = f\"Column '{col_name}' not found in the DataFrame. Available columns: {', '.join(df.columns)}\"\n raise ValueError(msg)\n\n try:\n # Convert the specified column to a list of strings\n user_texts = df[col_name].astype(str).tolist()\n total_rows = len(user_texts)\n\n logger.info(f\"Processing {total_rows} rows with batch run\")\n\n # Prepare the batch of conversations\n conversations = [\n [{\"role\": \"system\", \"content\": system_msg}, {\"role\": \"user\", \"content\": text}]\n if system_msg\n else [{\"role\": \"user\", \"content\": text}]\n for text in user_texts\n ]\n\n # Configure the model with project info and callbacks\n model = model.with_config(\n {\n \"run_name\": self.display_name,\n \"project_name\": self.get_project_name(),\n \"callbacks\": self.get_langchain_callbacks(),\n }\n )\n\n # Process batches and track progress\n responses_with_idx = [\n (idx, response)\n for idx, response in zip(\n range(len(conversations)), await model.abatch(list(conversations)), strict=True\n )\n ]\n\n # Sort by index to maintain order\n responses_with_idx.sort(key=lambda x: x[0])\n\n # Build the final data with enhanced metadata\n rows: list[dict[str, Any]] = []\n for idx, response in responses_with_idx:\n resp_text = response.content if hasattr(response, \"content\") else str(response)\n row = self._create_base_row(\n text_input=user_texts[idx],\n model_response=resp_text,\n batch_index=idx,\n )\n self._add_metadata(row, success=True, system_msg=system_msg)\n rows.append(row)\n\n # Log progress\n if (idx + 1) % max(1, total_rows // 10) == 0:\n logger.info(f\"Processed {idx + 1}/{total_rows} rows\")\n\n logger.info(\"Batch processing completed successfully\")\n return DataFrame(rows)\n\n except (KeyError, AttributeError) as e:\n # Handle data structure and attribute access errors\n logger.error(f\"Data processing error: {e!s}\")\n error_row = self._create_base_row()\n self._add_metadata(error_row, success=False, error=str(e))\n return DataFrame([error_row])\n" + "value": "from __future__ import annotations\n\nimport operator\nfrom typing import TYPE_CHECKING, Any\n\nfrom loguru import logger\n\nfrom langflow.custom import Component\nfrom langflow.io import (\n BoolInput,\n DataFrameInput,\n HandleInput,\n MessageTextInput,\n MultilineInput,\n Output,\n)\nfrom langflow.schema import DataFrame\n\nif TYPE_CHECKING:\n from langchain_core.runnables import Runnable\n\n\nclass BatchRunComponent(Component):\n display_name = \"Batch Run\"\n description = (\n \"Runs a language model over each row of a DataFrame's text column and returns a new \"\n \"DataFrame with three columns: '**text_input**' (the original text), \"\n \"'**model_response**' (the model's response),and '**batch_index**' (the processing order).\"\n )\n icon = \"List\"\n beta = True\n\n inputs = [\n HandleInput(\n name=\"model\",\n display_name=\"Language Model\",\n info=\"Connect the 'Language Model' output from your LLM component here.\",\n input_types=[\"LanguageModel\"],\n required=True,\n ),\n MultilineInput(\n name=\"system_message\",\n display_name=\"System Message\",\n info=\"Multi-line system instruction for all rows in the DataFrame.\",\n required=False,\n ),\n DataFrameInput(\n name=\"df\",\n display_name=\"DataFrame\",\n info=\"The DataFrame whose column (specified by 'column_name') we'll treat as text messages.\",\n required=True,\n ),\n MessageTextInput(\n name=\"column_name\",\n display_name=\"Column Name\",\n info=\"The name of the DataFrame column to treat as text messages. Default='text'.\",\n value=\"text\",\n required=True,\n advanced=True,\n ),\n BoolInput(\n name=\"enable_metadata\",\n display_name=\"Enable Metadata\",\n info=\"If True, add metadata to the output DataFrame.\",\n value=True,\n required=False,\n advanced=True,\n ),\n ]\n\n outputs = [\n Output(\n display_name=\"Batch Results\",\n name=\"batch_results\",\n method=\"run_batch\",\n info=\"A DataFrame with columns: 'text_input', 'model_response', 'batch_index', and 'metadata'.\",\n ),\n ]\n\n def _create_base_row(self, text_input: str = \"\", model_response: str = \"\", batch_index: int = -1) -> dict[str, Any]:\n \"\"\"Create a base row with optional metadata.\"\"\"\n return {\n \"text_input\": text_input,\n \"model_response\": model_response,\n \"batch_index\": batch_index,\n }\n\n def _add_metadata(\n self, row: dict[str, Any], *, success: bool = True, system_msg: str = \"\", error: str | None = None\n ) -> None:\n \"\"\"Add metadata to a row if enabled.\"\"\"\n if not self.enable_metadata:\n return\n\n if success:\n row[\"metadata\"] = {\n \"has_system_message\": bool(system_msg),\n \"input_length\": len(row[\"text_input\"]),\n \"response_length\": len(row[\"model_response\"]),\n \"processing_status\": \"success\",\n }\n else:\n row[\"metadata\"] = {\n \"error\": error,\n \"processing_status\": \"failed\",\n }\n\n async def run_batch(self) -> DataFrame:\n \"\"\"Process each row in df[column_name] with the language model asynchronously.\n\n Returns:\n DataFrame: A new DataFrame containing:\n - text_input: The original input text\n - model_response: The model's response\n - batch_index: The processing order\n - metadata: Additional processing information\n\n Raises:\n ValueError: If the specified column is not found in the DataFrame\n TypeError: If the model is not compatible or input types are wrong\n \"\"\"\n model: Runnable = self.model\n system_msg = self.system_message or \"\"\n df: DataFrame = self.df\n col_name = self.column_name or \"text\"\n\n # Validate inputs first\n if not isinstance(df, DataFrame):\n msg = f\"Expected DataFrame input, got {type(df)}\"\n raise TypeError(msg)\n\n if col_name not in df.columns:\n msg = f\"Column '{col_name}' not found in the DataFrame. Available columns: {', '.join(df.columns)}\"\n raise ValueError(msg)\n\n try:\n # Convert the specified column to a list of strings\n user_texts = df[col_name].astype(str).tolist()\n total_rows = len(user_texts)\n\n logger.info(f\"Processing {total_rows} rows with batch run\")\n\n # Prepare the batch of conversations\n conversations = [\n [{\"role\": \"system\", \"content\": system_msg}, {\"role\": \"user\", \"content\": text}]\n if system_msg\n else [{\"role\": \"user\", \"content\": text}]\n for text in user_texts\n ]\n\n # Configure the model with project info and callbacks\n model = model.with_config(\n {\n \"run_name\": self.display_name,\n \"project_name\": self.get_project_name(),\n \"callbacks\": self.get_langchain_callbacks(),\n }\n )\n\n # Process batches and track progress\n responses_with_idx = [\n (idx, response)\n for idx, response in zip(\n range(len(conversations)), await model.abatch(list(conversations)), strict=True\n )\n ]\n\n # Sort by index to maintain order\n responses_with_idx.sort(key=operator.itemgetter(0))\n\n # Build the final data with enhanced metadata\n rows: list[dict[str, Any]] = []\n for idx, response in responses_with_idx:\n resp_text = response.content if hasattr(response, \"content\") else str(response)\n row = self._create_base_row(\n text_input=user_texts[idx],\n model_response=resp_text,\n batch_index=idx,\n )\n self._add_metadata(row, success=True, system_msg=system_msg)\n rows.append(row)\n\n # Log progress\n if (idx + 1) % max(1, total_rows // 10) == 0:\n logger.info(f\"Processed {idx + 1}/{total_rows} rows\")\n\n logger.info(\"Batch processing completed successfully\")\n return DataFrame(rows)\n\n except (KeyError, AttributeError) as e:\n # Handle data structure and attribute access errors\n logger.error(f\"Data processing error: {e!s}\")\n error_row = self._create_base_row()\n self._add_metadata(error_row, success=False, error=str(e))\n return DataFrame([error_row])\n" }, "column_name": { "_input_type": "StrInput", diff --git a/src/backend/base/langflow/io/__init__.py b/src/backend/base/langflow/io/__init__.py index e7946e3b4..05fbd061e 100644 --- a/src/backend/base/langflow/io/__init__.py +++ b/src/backend/base/langflow/io/__init__.py @@ -1,3 +1,4 @@ +# noqa: A005 from langflow.inputs import ( BoolInput, CodeInput, diff --git a/src/backend/base/langflow/logging/__init__.py b/src/backend/base/langflow/logging/__init__.py index cc1b11cb1..8f8ed22c9 100644 --- a/src/backend/base/langflow/logging/__init__.py +++ b/src/backend/base/langflow/logging/__init__.py @@ -1,3 +1,4 @@ +# noqa: A005 from .logger import configure, logger from .setup import disable_logging, enable_logging diff --git a/src/backend/base/langflow/serialization/serialization.py b/src/backend/base/langflow/serialization/serialization.py index 3441d80ba..c5f22de80 100644 --- a/src/backend/base/langflow/serialization/serialization.py +++ b/src/backend/base/langflow/serialization/serialization.py @@ -145,7 +145,7 @@ def _serialize_numpy_type(obj: Any, max_length: int | None, max_items: int | Non if np.issubdtype(obj.dtype, np.bool_): return bool(obj) if np.issubdtype(obj.dtype, np.complexfloating): - return complex(cast(complex, obj)) + return complex(cast("complex", obj)) if np.issubdtype(obj.dtype, np.str_): return _serialize_str(str(obj), max_length, max_items) if np.issubdtype(obj.dtype, np.bytes_) and hasattr(obj, "tobytes"): @@ -209,7 +209,7 @@ def _serialize_dispatcher(obj: Any, max_length: int | None, max_items: int | Non if np.issubdtype(obj.dtype, np.bool_): return bool(obj) if np.issubdtype(obj.dtype, np.complexfloating): - return complex(cast(complex, obj)) + return complex(cast("complex", obj)) if np.issubdtype(obj.dtype, np.str_): return str(obj) if np.issubdtype(obj.dtype, np.bytes_) and hasattr(obj, "tobytes"): diff --git a/src/backend/base/langflow/services/job_queue/service.py b/src/backend/base/langflow/services/job_queue/service.py index e042a9892..3d2a437a3 100644 --- a/src/backend/base/langflow/services/job_queue/service.py +++ b/src/backend/base/langflow/services/job_queue/service.py @@ -203,7 +203,7 @@ class JobQueueService(Service): return logger.info(f"Commencing cleanup for job_id {job_id}") - main_queue, event_manager, task = self._queues[job_id] + main_queue, _event_manager, task = self._queues[job_id] # Cancel the associated task if it is still running. if task and not task.done(): diff --git a/src/backend/base/langflow/services/socket/__init__.py b/src/backend/base/langflow/services/socket/__init__.py index e69de29bb..dc9fd4c06 100644 --- a/src/backend/base/langflow/services/socket/__init__.py +++ b/src/backend/base/langflow/services/socket/__init__.py @@ -0,0 +1 @@ +# noqa: A005 diff --git a/src/backend/base/langflow/services/telemetry/service.py b/src/backend/base/langflow/services/telemetry/service.py index fbf874bc9..15de1600e 100644 --- a/src/backend/base/langflow/services/telemetry/service.py +++ b/src/backend/base/langflow/services/telemetry/service.py @@ -93,7 +93,7 @@ class TelemetryService(Service): def _get_langflow_desktop(self) -> bool: # Coerce to bool, could be 1, 0, True, False, "1", "0", "True", "False" - return str(os.getenv("LANGFLOW_DESKTOP", "False")).lower() in ("1", "true") + return str(os.getenv("LANGFLOW_DESKTOP", "False")).lower() in {"1", "true"} async def log_package_version(self) -> None: python_version = ".".join(platform.python_version().split(".")[:2]) diff --git a/src/backend/tests/unit/api/v1/test_endpoints.py b/src/backend/tests/unit/api/v1/test_endpoints.py index 282026e7e..fed5bb35b 100644 --- a/src/backend/tests/unit/api/v1/test_endpoints.py +++ b/src/backend/tests/unit/api/v1/test_endpoints.py @@ -57,7 +57,7 @@ async def test_update_component_outputs(client: AsyncClient, logged_in_headers: async def test_update_component_model_name_options(client: AsyncClient, logged_in_headers: dict): """Test that model_name options are updated when selecting a provider.""" component = AgentComponent() - component_node, cc_instance = build_custom_component_template( + component_node, _cc_instance = build_custom_component_template( component, ) diff --git a/src/backend/tests/unit/components/vectorstores/test_chroma_vector_store_component.py b/src/backend/tests/unit/components/vectorstores/test_chroma_vector_store_component.py index 5775233d8..061c417a1 100644 --- a/src/backend/tests/unit/components/vectorstores/test_chroma_vector_store_component.py +++ b/src/backend/tests/unit/components/vectorstores/test_chroma_vector_store_component.py @@ -271,9 +271,9 @@ class TestChromaVectorStoreComponent(ComponentTestBaseWithoutClient): assert isinstance(data_obj, Data) assert "id" in data_obj.data assert "text" in data_obj.data - assert data_obj.data["text"] in ["Document 1", "Document 2"] + assert data_obj.data["text"] in {"Document 1", "Document 2"} assert "metadata_field" in data_obj.data - assert data_obj.data["metadata_field"] in ["value1", "value2"] + assert data_obj.data["metadata_field"] in {"value1", "value2"} def test_chroma_collection_to_data_without_metadata( self, component_class: type[ChromaVectorStoreComponent], default_kwargs: dict[str, Any] @@ -300,7 +300,7 @@ class TestChromaVectorStoreComponent(ComponentTestBaseWithoutClient): assert isinstance(data_obj, Data) assert "id" in data_obj.data assert "text" in data_obj.data - assert data_obj.data["text"] in ["Simple document 1", "Simple document 2"] + assert data_obj.data["text"] in {"Simple document 1", "Simple document 2"} def test_chroma_collection_to_data_empty_collection( self, component_class: type[ChromaVectorStoreComponent], default_kwargs: dict[str, Any] diff --git a/src/backend/tests/unit/graph/graph/test_utils.py b/src/backend/tests/unit/graph/graph/test_utils.py index fc695ea45..99d0c7004 100644 --- a/src/backend/tests/unit/graph/graph/test_utils.py +++ b/src/backend/tests/unit/graph/graph/test_utils.py @@ -869,7 +869,7 @@ def test_get_sorted_vertices_with_unconnected_graph(): predecessor_map = {vertex: data["predecessors"] for vertex, data in graph_dict.items()} def is_input_vertex(vertex_id: str) -> bool: - return vertex_id in ["A"] + return vertex_id == "A" def get_vertex_predecessors(vertex_id: str) -> list[str]: return predecessor_map[vertex_id] diff --git a/src/backend/tests/unit/serialization/test_serialization.py b/src/backend/tests/unit/serialization/test_serialization.py index e6cf46d78..e8110906c 100644 --- a/src/backend/tests/unit/serialization/test_serialization.py +++ b/src/backend/tests/unit/serialization/test_serialization.py @@ -263,13 +263,13 @@ class TestSerializationHypothesis: assert isinstance(serialize(np.uint64(42)), int) # Test floats - assert serialize(np.float64(3.14)) == 3.14 - assert isinstance(serialize(np.float64(3.14)), float) + assert serialize(np.float64(math.pi)) == math.pi + assert isinstance(serialize(np.float64(math.pi)), float) # Test float32 (need to account for precision differences) - float32_val = serialize(np.float32(3.14)) + float32_val = serialize(np.float32(math.pi)) assert isinstance(float32_val, float) - assert abs(float32_val - 3.14) < 1e-6 # Check if close enough + assert abs(float32_val - math.pi) < 1e-6 # Check if close enough # Test bool assert serialize(np.bool_(True)) is True # noqa: FBT003 diff --git a/uv.lock b/uv.lock index 25d56c54e..a6e51eb22 100644 --- a/uv.lock +++ b/uv.lock @@ -4544,7 +4544,7 @@ dev = [ { name = "pytest-xdist", specifier = ">=3.6.0" }, { name = "requests", specifier = ">=2.32.0" }, { name = "respx", specifier = ">=0.21.1" }, - { name = "ruff", specifier = ">=0.9.1,<0.10" }, + { name = "ruff", specifier = ">=0.9.7,<0.10" }, { name = "types-aiofiles", specifier = ">=24.1.0.20240626" }, { name = "types-google-cloud-ndb", specifier = ">=2.2.0.0" }, { name = "types-markdown", specifier = ">=3.7.0.20240822" },