feat: Enhance AstraDB tool component with vector search and metadata filter (#6887)

* feat: Enhance AstraDB tool component with advanced configuration and semantic search

* [autofix.ci] apply automated fixes

* [autofix.ci] apply automated fixes (attempt 2/3)

* Format and Lint

* Format and Lint

* refactor: Improve AstraDB tool component with code cleanup and documentation

* [autofix.ci] apply automated fixes

* Lint & Format

* [autofix.ci] apply automated fixes

* Add search_query description input

* Format backend

* [autofix.ci] apply automated fixes

* Error message on Astra DB CQL Tool

* [autofix.ci] apply automated fixes

* [autofix.ci] apply automated fixes (attempt 2/3)

* Format backend

* Enhance AstraDB CQL Tool Component with new tools_params input and update filtering logic. Deprecate partition and clustering keys inputs. Introduce attribute_name for improved field mapping.

* Add 'is_date' parameter to AstraDBToolComponent for date filtering and update filter logic to handle date values.

* Revert "Format backend"

This reverts commit 0f12efbd817d82087bc9b48af809e0384b1eb160.

* [autofix.ci] apply automated fixes

* [autofix.ci] apply automated fixes (attempt 2/3)

* format backend

* [autofix.ci] apply automated fixes

* Implement timestamp parsing in AstraDB components and update filtering logic to utilize the new method. Rename 'is_date' to 'is_timestamp' for clarity in parameter definitions.

---------

Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com>
Co-authored-by: Edwin Jose <edwin.jose@datastax.com>
Co-authored-by: Eric Hare <ericrhare@gmail.com>
This commit is contained in:
Samuel Matioli 2025-03-21 15:30:11 -03:00 committed by GitHub
commit 4527c473be
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
2 changed files with 443 additions and 49 deletions

View file

@ -1,4 +1,5 @@
import os
from datetime import datetime, timezone
from typing import Any
from astrapy import Collection, DataAPIClient, Database
@ -6,13 +7,15 @@ from langchain.pydantic_v1 import BaseModel, Field, create_model
from langchain_core.tools import StructuredTool, Tool
from langflow.base.langchain_utilities.model import LCToolComponent
from langflow.io import DictInput, IntInput, SecretStrInput, StrInput
from langflow.io import BoolInput, DictInput, HandleInput, IntInput, SecretStrInput, StrInput, TableInput
from langflow.logging import logger
from langflow.schema import Data
from langflow.schema.table import EditMode
class AstraDBToolComponent(LCToolComponent):
display_name: str = "Astra DB Tool"
description: str = "Create a tool to get transactional data from DataStax Astra DB Collection"
description: str = "Tool to run hybrid vector and metadata search on DataStax Astra DB Collection"
documentation: str = "https://docs.langflow.org/Components/components-tools#astra-db-tool"
icon: str = "AstraDB"
@ -20,19 +23,19 @@ class AstraDBToolComponent(LCToolComponent):
StrInput(
name="tool_name",
display_name="Tool Name",
info="The name of the tool.",
info="The name of the tool to be passed to the LLM.",
required=True,
),
StrInput(
name="tool_description",
display_name="Tool Description",
info="The description of the tool.",
info="Describe the tool to LLM. Add any information that can help the LLM to use the tool.",
required=True,
),
StrInput(
name="namespace",
display_name="Namespace Name",
info="The name of the namespace within Astra where the collection is be stored.",
name="keyspace",
display_name="Keyspace Name",
info="The name of the keyspace within Astra where the collection is stored.",
value="default_keyspace",
advanced=True,
),
@ -59,16 +62,90 @@ class AstraDBToolComponent(LCToolComponent):
StrInput(
name="projection_attributes",
display_name="Projection Attributes",
info="Attributes to return separated by comma.",
info="Attributes to be returned by the tool separated by comma.",
required=True,
value="*",
advanced=True,
),
TableInput(
name="tools_params_v2",
display_name="Tools Parameters",
info="Define the structure for the tool parameters. Describe the parameters "
"in a way the LLM can understand how to use them.",
required=False,
table_schema=[
{
"name": "name",
"display_name": "Name",
"type": "str",
"description": "Specify the name of the output field/parameter for the model.",
"default": "field",
"edit_mode": EditMode.INLINE,
},
{
"name": "attribute_name",
"display_name": "Attribute Name",
"type": "str",
"description": "Specify the attribute name to be filtered on the collection. "
"Leave empty if the attribute name is the same as the name of the field.",
"default": "",
"edit_mode": EditMode.INLINE,
},
{
"name": "description",
"display_name": "Description",
"type": "str",
"description": "Describe the purpose of the output field.",
"default": "description of field",
"edit_mode": EditMode.POPOVER,
},
{
"name": "metadata",
"display_name": "Is Metadata",
"type": "boolean",
"edit_mode": EditMode.INLINE,
"description": ("Indicate if the field is included in the metadata field."),
"options": ["True", "False"],
"default": "False",
},
{
"name": "mandatory",
"display_name": "Is Mandatory",
"type": "boolean",
"edit_mode": EditMode.INLINE,
"description": ("Indicate if the field is mandatory."),
"options": ["True", "False"],
"default": "False",
},
{
"name": "is_timestamp",
"display_name": "Is Timestamp",
"type": "boolean",
"edit_mode": EditMode.INLINE,
"description": ("Indicate if the field is a timestamp."),
"options": ["True", "False"],
"default": "False",
},
{
"name": "operator",
"display_name": "Operator",
"type": "str",
"description": "Set the operator for the field. "
"https://docs.datastax.com/en/astra-db-serverless/api-reference/documents.html#operators",
"default": "$eq",
"options": ["$gt", "$gte", "$lt", "$lte", "$eq", "$ne", "$in", "$nin", "$exists", "$all", "$size"],
"edit_mode": EditMode.INLINE,
},
],
value=[],
),
DictInput(
name="tool_params",
info="Attributes to filter and description to the model. Add ! for mandatory (e.g: !customerId)",
info="DEPRECATED: Attributes to filter and description to the model. "
"Add ! for mandatory (e.g: !customerId)",
display_name="Tool params",
is_list=True,
advanced=True,
),
DictInput(
name="static_filters",
@ -84,6 +161,29 @@ class AstraDBToolComponent(LCToolComponent):
advanced=True,
value=5,
),
BoolInput(
name="use_search_query",
display_name="Semantic Search",
info="When this parameter is activated, the search query parameter will be used to search the collection.",
advanced=False,
value=False,
),
BoolInput(
name="use_vectorize",
display_name="Use Astra DB Vectorize",
info="When this parameter is activated, Astra DB Vectorize method will be used to generate the embeddings.",
advanced=False,
value=False,
),
HandleInput(name="embedding", display_name="Embedding Model", input_types=["Embeddings"]),
StrInput(
name="semantic_search_instruction",
display_name="Semantic Search Instruction",
info="The instruction to use for the semantic search.",
required=True,
value="Search query to find relevant documents.",
advanced=True,
),
]
_cached_client: DataAPIClient | None = None
@ -94,12 +194,22 @@ class AstraDBToolComponent(LCToolComponent):
if self._cached_collection:
return self._cached_collection
cached_client = DataAPIClient(self.token)
cached_db = cached_client.get_database(self.api_endpoint, namespace=self.namespace)
self._cached_collection = cached_db.get_collection(self.collection_name)
return self._cached_collection
try:
cached_client = DataAPIClient(self.token)
cached_db = cached_client.get_database(self.api_endpoint, keyspace=self.keyspace)
self._cached_collection = cached_db.get_collection(self.collection_name)
except Exception as e:
msg = f"Error building collection: {e}"
raise ValueError(msg) from e
else:
return self._cached_collection
def create_args_schema(self) -> dict[str, BaseModel]:
"""DEPRECATED: This method is deprecated. Please use create_args_schema_v2 instead.
It is kept only for backward compatibility.
"""
logger.warning("This is the old way to define the tool parameters. Please use the new way.")
args: dict[str, tuple[Any, Field] | list[str]] = {}
for key in self.tool_params:
@ -108,6 +218,31 @@ class AstraDBToolComponent(LCToolComponent):
else: # Optional
args[key] = (str | None, Field(description=self.tool_params[key], default=None))
if self.use_search_query:
args["search_query"] = (
str | None,
Field(description="Search query to find relevant documents.", default=None),
)
model = create_model("ToolInput", **args, __base__=BaseModel)
return {"ToolInput": model}
def create_args_schema_v2(self) -> dict[str, BaseModel]:
"""Create the tool input schema using the new tool parameters configuration."""
args: dict[str, tuple[Any, Field] | list[str]] = {}
for tool_param in self.tools_params_v2:
if tool_param["mandatory"]:
args[tool_param["name"]] = (str, Field(description=tool_param["description"]))
else:
args[tool_param["name"]] = (str | None, Field(description=tool_param["description"], default=None))
if self.use_search_query:
args["search_query"] = (
str,
Field(description=self.semantic_search_instruction),
)
model = create_model("ToolInput", **args, __base__=BaseModel)
return {"ToolInput": model}
@ -117,7 +252,7 @@ class AstraDBToolComponent(LCToolComponent):
Returns:
Tool: The built Astra DB tool.
"""
schema_dict = self.create_args_schema()
schema_dict = self.create_args_schema() if len(self.tool_params.keys()) > 0 else self.create_args_schema_v2()
tool = StructuredTool.from_function(
name=self.tool_name,
@ -130,10 +265,18 @@ class AstraDBToolComponent(LCToolComponent):
return tool
def projection_args(self, input_str: str) -> dict:
def projection_args(self, input_str: str) -> dict | None:
"""Build the projection arguments for the AstraDB query."""
elements = input_str.split(",")
result = {}
if elements == ["*"]:
return None
# Force the projection to exclude the $vector field as it is not required by the tool
result["$vector"] = False
# Fields with ! as prefix should be removed from the projection
for element in elements:
if element.startswith("!"):
result[element[1:]] = False
@ -142,13 +285,127 @@ class AstraDBToolComponent(LCToolComponent):
return result
def parse_timestamp(self, timestamp_str: str) -> datetime:
    """Parse a timestamp string into a timezone-aware UTC datetime.

    Tries a fixed list of common formats in order; the first one that
    parses wins. Naive results are assumed to be UTC (per the tool's
    contract), aware results are converted to UTC.

    Args:
        timestamp_str (str): Input timestamp string.

    Returns:
        datetime: Timezone-aware datetime normalized to UTC.

    Raises:
        ValueError: If the timestamp matches none of the known formats.
    """
    # Common datetime formats to try, most specific ISO-like forms first.
    formats = [
        "%Y-%m-%d",  # 2024-03-21
        "%Y-%m-%dT%H:%M:%S",  # 2024-03-21T15:30:00
        "%Y-%m-%dT%H:%M:%S%z",  # 2024-03-21T15:30:00+0000
        "%Y-%m-%d %H:%M:%S",  # 2024-03-21 15:30:00
        "%d/%m/%Y",  # 21/03/2024
        "%Y/%m/%d",  # 2024/03/21
    ]

    for fmt in formats:
        try:
            date_obj = datetime.strptime(timestamp_str, fmt)
        except ValueError:
            continue
        # BUGFIX: the previous code called .astimezone() on the freshly
        # parsed value, which attaches the LOCAL timezone to naive
        # datetimes and made the tzinfo-is-None check below unreachable.
        # Naive timestamps are assumed to be UTC, as documented.
        if date_obj.tzinfo is None:
            date_obj = date_obj.replace(tzinfo=timezone.utc)
        # Normalize aware datetimes (e.g. parsed with %z) to UTC.
        return date_obj.astimezone(timezone.utc)

    msg = f"Could not parse date: {timestamp_str}"
    logger.error(msg)
    raise ValueError(msg)
def build_filter(self, args: dict, filter_settings: list) -> dict:
"""Build filter dictionary for AstraDB query.
Args:
args: Dictionary of arguments from the tool
filter_settings: List of filter settings from tools_params_v2
Returns:
Dictionary containing the filter conditions
"""
# Static filters act as the base; per-argument conditions are merged on top.
filters = {**self.static_filters}
for key, value in args.items():
# Skip search_query as it's handled separately
if key == "search_query":
continue
# Find the matching tools_params_v2 row for this argument, if any.
filter_setting = next((x for x in filter_settings if x["name"] == key), None)
# Args with no matching setting or a None value contribute no condition.
if filter_setting and value is not None:
# "attribute_name" overrides the collection field name; empty string falls back to the arg name.
field_name = filter_setting["attribute_name"] if filter_setting["attribute_name"] else key
# Metadata fields are addressed with a "metadata." prefix in the Data API filter.
filter_key = field_name if not filter_setting["metadata"] else f"metadata.{field_name}"
if filter_setting["operator"] == "$exists":
# $exists ignores the provided value; it only asserts presence.
filters[filter_key] = {**filters.get(filter_key, {}), filter_setting["operator"]: True}
elif filter_setting["operator"] in ["$in", "$nin", "$all"]:
# List operators: a comma-separated string is split into a list; other types pass through as-is.
filters[filter_key] = {
**filters.get(filter_key, {}),
filter_setting["operator"]: value.split(",") if isinstance(value, str) else value,
}
elif filter_setting["is_timestamp"] == True:  # noqa: E712
try:
# Timestamp fields are normalized via parse_timestamp before filtering.
# NOTE(review): $in/$nin/$all branches above take precedence, so list
# operators on timestamp fields are NOT parsed — confirm this is intended.
filters[filter_key] = {
**filters.get(filter_key, {}),
filter_setting["operator"]: self.parse_timestamp(value),
}
except ValueError as e:
msg = f"Error parsing timestamp: {e} - Use the prompt to specify the date in the correct format"
logger.error(msg)
raise ValueError(msg) from e
else:
# Default case: apply the configured comparison operator directly.
filters[filter_key] = {**filters.get(filter_key, {}), filter_setting["operator"]: value}
return filters
def run_model(self, **args) -> Data | list[Data]:
"""Run the query to get the data from the AstraDB collection."""
collection = self._build_collection()
results = collection.find(
({**args, **self.static_filters}),
projection=self.projection_args(self.projection_attributes),
limit=self.number_of_results,
)
sort = {}
# Build filters using the new method
filters = self.build_filter(args, self.tools_params_v2)
# Build the vector search on
if self.use_search_query and args["search_query"] is not None and args["search_query"] != "":
if self.use_vectorize:
sort["$vectorize"] = args["search_query"]
else:
if self.embedding is None:
msg = "Embedding model is not set. Please set the embedding model or use Astra DB Vectorize."
logger.error(msg)
raise ValueError(msg)
embedding_query = self.embedding.embed_query(args["search_query"])
sort["$vector"] = embedding_query
del args["search_query"]
find_options = {
"filter": filters,
"limit": self.number_of_results,
"sort": sort,
}
projection = self.projection_args(self.projection_attributes)
if projection and len(projection) > 0:
find_options["projection"] = projection
try:
results = collection.find(**find_options)
except Exception as e:
msg = f"Error on Astra DB Tool {self.tool_name} request: {e}"
logger.error(msg)
raise ValueError(msg) from e
logger.info(f"Tool {self.tool_name} executed`")
data: list[Data] = [Data(data=doc) for doc in results]
self.status = data

View file

@ -1,4 +1,6 @@
import json
import urllib
from datetime import datetime, timezone
from http import HTTPStatus
from typing import Any
@ -7,8 +9,10 @@ from langchain.pydantic_v1 import BaseModel, Field, create_model
from langchain_core.tools import StructuredTool, Tool
from langflow.base.langchain_utilities.model import LCToolComponent
from langflow.io import DictInput, IntInput, SecretStrInput, StrInput
from langflow.io import DictInput, IntInput, SecretStrInput, StrInput, TableInput
from langflow.logging import logger
from langflow.schema import Data
from langflow.schema.table import EditMode
class AstraDBCQLToolComponent(LCToolComponent):
@ -61,18 +65,85 @@ class AstraDBCQLToolComponent(LCToolComponent):
value="*",
advanced=True,
),
TableInput(
name="tools_params",
display_name="Tools Parameters",
info="Define the structure for the tool parameters. Describe the parameters "
"in a way the LLM can understand how to use them. Add the parameters "
"respecting the table schema (Partition Keys, Clustering Keys and Indexed Fields).",
required=False,
table_schema=[
{
"name": "name",
"display_name": "Name",
"type": "str",
"description": "Name of the field/parameter to be used by the model.",
"default": "field",
"edit_mode": EditMode.INLINE,
},
{
"name": "field_name",
"display_name": "Field Name",
"type": "str",
"description": "Specify the column name to be filtered on the table. "
"Leave empty if the attribute name is the same as the name of the field.",
"default": "",
"edit_mode": EditMode.INLINE,
},
{
"name": "description",
"display_name": "Description",
"type": "str",
"description": "Describe the purpose of the parameter.",
"default": "description of tool parameter",
"edit_mode": EditMode.POPOVER,
},
{
"name": "mandatory",
"display_name": "Is Mandatory",
"type": "boolean",
"edit_mode": EditMode.INLINE,
"description": ("Indicate if the field is mandatory."),
"options": ["True", "False"],
"default": "False",
},
{
"name": "is_timestamp",
"display_name": "Is Timestamp",
"type": "boolean",
"edit_mode": EditMode.INLINE,
"description": ("Indicate if the field is a timestamp."),
"options": ["True", "False"],
"default": "False",
},
{
"name": "operator",
"display_name": "Operator",
"type": "str",
"description": "Set the operator for the field. "
"https://docs.datastax.com/en/astra-db-serverless/api-reference/documents.html#operators",
"default": "$eq",
"options": ["$gt", "$gte", "$lt", "$lte", "$eq", "$ne", "$in", "$nin", "$exists", "$all", "$size"],
"edit_mode": EditMode.INLINE,
},
],
value=[],
),
DictInput(
name="partition_keys",
display_name="Partition Keys",
display_name="DEPRECATED: Partition Keys",
is_list=True,
info="Field name and description to the model",
required=True,
required=False,
advanced=True,
),
DictInput(
name="clustering_keys",
display_name="Clustering Keys",
display_name="DEPRECATED: Clustering Keys",
is_list=True,
info="Field name and description to the model",
required=False,
advanced=True,
),
DictInput(
name="static_filters",
@ -90,22 +161,84 @@ class AstraDBCQLToolComponent(LCToolComponent):
),
]
def parse_timestamp(self, timestamp_str: str) -> str:
    """Parse a timestamp string into Astra DB REST API format.

    Tries a fixed list of common formats in order; the first one that
    parses wins. Naive results are assumed to be UTC (per the tool's
    contract), aware results are converted to UTC.

    Args:
        timestamp_str (str): Input timestamp string.

    Returns:
        str: Formatted timestamp string in YYYY-MM-DDTHH:MI:SS.000Z format.

    Raises:
        ValueError: If the timestamp matches none of the known formats.
    """
    # Common datetime formats to try, most specific ISO-like forms first.
    formats = [
        "%Y-%m-%d",  # 2024-03-21
        "%Y-%m-%dT%H:%M:%S",  # 2024-03-21T15:30:00
        "%Y-%m-%dT%H:%M:%S%z",  # 2024-03-21T15:30:00+0000
        "%Y-%m-%d %H:%M:%S",  # 2024-03-21 15:30:00
        "%d/%m/%Y",  # 21/03/2024
        "%Y/%m/%d",  # 2024/03/21
    ]

    for fmt in formats:
        try:
            date_obj = datetime.strptime(timestamp_str, fmt)
        except ValueError:
            continue
        # BUGFIX: the previous code called .astimezone() on the freshly
        # parsed value, which attaches the LOCAL timezone to naive
        # datetimes and made the tzinfo-is-None check below unreachable.
        # Naive timestamps are assumed to be UTC, as documented.
        if date_obj.tzinfo is None:
            date_obj = date_obj.replace(tzinfo=timezone.utc)
        # Normalize aware datetimes (e.g. parsed with %z) to UTC and
        # render in the fixed REST API format with millisecond zeros.
        utc_date = date_obj.astimezone(timezone.utc)
        return utc_date.strftime("%Y-%m-%dT%H:%M:%S.000Z")

    msg = f"Could not parse date: {timestamp_str}"
    logger.error(msg)
    raise ValueError(msg)
def astra_rest(self, args):
headers = {"Accept": "application/json", "X-Cassandra-Token": f"{self.token}"}
astra_url = f"{self.api_endpoint}/api/rest/v2/keyspaces/{self.keyspace}/{self.table_name}/"
key = []
where = {}
# Partition keys are mandatory
key = [self.partition_keys[k] for k in self.partition_keys]
for param in self.tools_params:
field_name = param["field_name"] if param["field_name"] else param["name"]
field_value = None
# Clustering keys are optional
for k in self.clustering_keys:
if k in args:
key.append(args[k])
elif self.static_filters[k] is not None:
key.append(self.static_filters[k])
if field_name in self.static_filters:
field_value = self.static_filters[field_name]
elif param["name"] in args:
field_value = args[param["name"]]
url = f"{astra_url}{'/'.join(key)}?page-size={self.number_of_results}"
if field_value is None:
continue
if param["is_timestamp"] == True: # noqa: E712
try:
field_value = self.parse_timestamp(field_value)
except ValueError as e:
msg = f"Error parsing timestamp: {e} - Use the prompt to specify the date in the correct format"
logger.error(msg)
raise ValueError(msg) from e
if param["operator"] == "$exists":
where[field_name] = {**where.get(field_name, {}), param["operator"]: True}
elif param["operator"] in ["$in", "$nin", "$all"]:
where[field_name] = {
**where.get(field_name, {}),
param["operator"]: field_value.split(",") if isinstance(field_value, str) else field_value,
}
else:
where[field_name] = {**where.get(field_name, {}), param["operator"]: field_value}
url = f"{astra_url}?page-size={self.number_of_results}"
url += f"&where={json.dumps(where)}"
if self.projection_fields != "*":
url += f"&fields={urllib.parse.quote(self.projection_fields.replace(' ', ''))}"
@ -113,7 +246,9 @@ class AstraDBCQLToolComponent(LCToolComponent):
res = requests.request("GET", url=url, headers=headers, timeout=10)
if int(res.status_code) >= HTTPStatus.BAD_REQUEST:
return res.text
msg = f"Error on Astra DB CQL Tool {self.tool_name} request: {res.text}"
logger.error(msg)
raise ValueError(msg)
try:
res_data = res.json()
@ -124,18 +259,13 @@ class AstraDBCQLToolComponent(LCToolComponent):
def create_args_schema(self) -> dict[str, BaseModel]:
args: dict[str, tuple[Any, Field]] = {}
for key in self.partition_keys:
# Partition keys are mandatory is it doesn't have a static filter
if key not in self.static_filters:
args[key] = (str, Field(description=self.partition_keys[key]))
for key in self.clustering_keys:
# Partition keys are mandatory if has the exclamation mark and doesn't have a static filter
if key not in self.static_filters:
if key.startswith("!"): # Mandatory
args[key[1:]] = (str, Field(description=self.clustering_keys[key]))
else: # Optional
args[key] = (str | None, Field(description=self.clustering_keys[key], default=None))
for param in self.tools_params:
field_name = param["field_name"] if param["field_name"] else param["name"]
if field_name not in self.static_filters:
if param["mandatory"]:
args[param["name"]] = (str, Field(description=param["description"]))
else:
args[param["name"]] = (str | None, Field(description=param["description"], default=None))
model = create_model("ToolInput", **args, __base__=BaseModel)
return {"ToolInput": model}
@ -172,6 +302,13 @@ class AstraDBCQLToolComponent(LCToolComponent):
def run_model(self, **args) -> Data | list[Data]:
results = self.astra_rest(args)
data: list[Data] = [Data(data=doc) for doc in results]
data: list[Data] = []
if isinstance(results, list):
data = [Data(data=doc) for doc in results]
else:
self.status = results
return []
self.status = data
return results
return data