From 4527c473beae8f75b131ef0fceda92fafb604d82 Mon Sep 17 00:00:00 2001 From: Samuel Matioli <101875785+smatiolids@users.noreply.github.com> Date: Fri, 21 Mar 2025 15:30:11 -0300 Subject: [PATCH] feat: Enhance AstraDB tool component with vector search and metadata filter (#6887) * feat: Enhance AstraDB tool component with advanced configuration and semantic search * [autofix.ci] apply automated fixes * [autofix.ci] apply automated fixes (attempt 2/3) * Format and Lint * Format and Lint * refactor: Improve AstraDB tool component with code cleanup and documentation * [autofix.ci] apply automated fixes * Lint & Format * [autofix.ci] apply automated fixes * Add search_query description input * Format backend * [autofix.ci] apply automated fixes * Error message on Astra DB CQL Tool * [autofix.ci] apply automated fixes * [autofix.ci] apply automated fixes (attempt 2/3) * Format backend * Enhance AstraDB CQL Tool Component with new tools_params input and update filtering logic. Deprecate partition and clustering keys inputs. Introduce attribute_name for improved field mapping. * Add 'is_date' parameter to AstraDBToolComponent for date filtering and update filter logic to handle date values. * Revert "Format backend" This reverts commit 0f12efbd817d82087bc9b48af809e0384b1eb160. * [autofix.ci] apply automated fixes * [autofix.ci] apply automated fixes (attempt 2/3) * format backend * [autofix.ci] apply automated fixes * Implement timestamp parsing in AstraDB components and update filtering logic to utilize the new method. Rename 'is_date' to 'is_timestamp' for clarity in parameter definitions. --------- Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com> Co-authored-by: Edwin Jose Co-authored-by: Eric Hare --- .../base/langflow/components/tools/astradb.py | 297 ++++++++++++++++-- .../langflow/components/tools/astradb_cql.py | 195 ++++++++++-- 2 files changed, 443 insertions(+), 49 deletions(-) diff --git a/src/backend/base/langflow/components/tools/astradb.py b/src/backend/base/langflow/components/tools/astradb.py index 0376f2d64..157cdeda5 100644 --- a/src/backend/base/langflow/components/tools/astradb.py +++ b/src/backend/base/langflow/components/tools/astradb.py @@ -1,4 +1,5 @@ import os +from datetime import datetime, timezone from typing import Any from astrapy import Collection, DataAPIClient, Database @@ -6,13 +7,15 @@ from langchain.pydantic_v1 import BaseModel, Field, create_model from langchain_core.tools import StructuredTool, Tool from langflow.base.langchain_utilities.model import LCToolComponent -from langflow.io import DictInput, IntInput, SecretStrInput, StrInput +from langflow.io import BoolInput, DictInput, HandleInput, IntInput, SecretStrInput, StrInput, TableInput +from langflow.logging import logger from langflow.schema import Data +from langflow.schema.table import EditMode class AstraDBToolComponent(LCToolComponent): display_name: str = "Astra DB Tool" - description: str = "Create a tool to get transactional data from DataStax Astra DB Collection" + description: str = "Tool to run hybrid vector and metadata search on DataStax Astra DB Collection" documentation: str = "https://docs.langflow.org/Components/components-tools#astra-db-tool" icon: str = "AstraDB" @@ -20,19 +23,19 @@ class AstraDBToolComponent(LCToolComponent): StrInput( name="tool_name", display_name="Tool Name", - info="The name of the tool.", + info="The name of the tool to be passed to the LLM.", required=True, ), StrInput( name="tool_description", display_name="Tool Description", - info="The description of the tool.", + info="Describe the tool to LLM. Add any information that can help the LLM to use the tool.", required=True, ), StrInput( - name="namespace", - display_name="Namespace Name", - info="The name of the namespace within Astra where the collection is be stored.", + name="keyspace", + display_name="Keyspace Name", + info="The name of the keyspace within Astra where the collection is stored.", value="default_keyspace", advanced=True, ), @@ -59,16 +62,90 @@ class AstraDBToolComponent(LCToolComponent): StrInput( name="projection_attributes", display_name="Projection Attributes", - info="Attributes to return separated by comma.", + info="Attributes to be returned by the tool separated by comma.", required=True, value="*", advanced=True, ), + TableInput( + name="tools_params_v2", + display_name="Tools Parameters", + info="Define the structure for the tool parameters. Describe the parameters " + "in a way the LLM can understand how to use them.", + required=False, + table_schema=[ + { + "name": "name", + "display_name": "Name", + "type": "str", + "description": "Specify the name of the output field/parameter for the model.", + "default": "field", + "edit_mode": EditMode.INLINE, + }, + { + "name": "attribute_name", + "display_name": "Attribute Name", + "type": "str", + "description": "Specify the attribute name to be filtered on the collection. " + "Leave empty if the attribute name is the same as the name of the field.", + "default": "", + "edit_mode": EditMode.INLINE, + }, + { + "name": "description", + "display_name": "Description", + "type": "str", + "description": "Describe the purpose of the output field.", + "default": "description of field", + "edit_mode": EditMode.POPOVER, + }, + { + "name": "metadata", + "display_name": "Is Metadata", + "type": "boolean", + "edit_mode": EditMode.INLINE, + "description": ("Indicate if the field is included in the metadata field."), + "options": ["True", "False"], + "default": "False", + }, + { + "name": "mandatory", + "display_name": "Is Mandatory", + "type": "boolean", + "edit_mode": EditMode.INLINE, + "description": ("Indicate if the field is mandatory."), + "options": ["True", "False"], + "default": "False", + }, + { + "name": "is_timestamp", + "display_name": "Is Timestamp", + "type": "boolean", + "edit_mode": EditMode.INLINE, + "description": ("Indicate if the field is a timestamp."), + "options": ["True", "False"], + "default": "False", + }, + { + "name": "operator", + "display_name": "Operator", + "type": "str", + "description": "Set the operator for the field. " + "https://docs.datastax.com/en/astra-db-serverless/api-reference/documents.html#operators", + "default": "$eq", + "options": ["$gt", "$gte", "$lt", "$lte", "$eq", "$ne", "$in", "$nin", "$exists", "$all", "$size"], + "edit_mode": EditMode.INLINE, + }, + ], + value=[], + ), DictInput( name="tool_params", - info="Attributes to filter and description to the model. Add ! for mandatory (e.g: !customerId)", + info="DEPRECATED: Attributes to filter and description to the model. " + "Add ! for mandatory (e.g: !customerId)", display_name="Tool params", is_list=True, + advanced=True, ), DictInput( name="static_filters", @@ -84,6 +161,29 @@ class AstraDBToolComponent(LCToolComponent): advanced=True, value=5, ), + BoolInput( + name="use_search_query", + display_name="Semantic Search", + info="When this parameter is activated, the search query parameter will be used to search the collection.", + advanced=False, + value=False, + ), + BoolInput( + name="use_vectorize", + display_name="Use Astra DB Vectorize", + info="When this parameter is activated, Astra DB Vectorize method will be used to generate the embeddings.", + advanced=False, + value=False, + ), + HandleInput(name="embedding", display_name="Embedding Model", input_types=["Embeddings"]), + StrInput( + name="semantic_search_instruction", + display_name="Semantic Search Instruction", + info="The instruction to use for the semantic search.", + required=True, + value="Search query to find relevant documents.", + advanced=True, + ), ] _cached_client: DataAPIClient | None = None @@ -94,12 +194,22 @@ class AstraDBToolComponent(LCToolComponent): if self._cached_collection: return self._cached_collection - cached_client = DataAPIClient(self.token) - cached_db = cached_client.get_database(self.api_endpoint, namespace=self.namespace) - self._cached_collection = cached_db.get_collection(self.collection_name) - return self._cached_collection + try: + cached_client = DataAPIClient(self.token) + cached_db = cached_client.get_database(self.api_endpoint, keyspace=self.keyspace) + self._cached_collection = cached_db.get_collection(self.collection_name) + except Exception as e: + msg = f"Error building collection: {e}" + raise ValueError(msg) from e + else: + return self._cached_collection def create_args_schema(self) -> dict[str, BaseModel]: + """DEPRECATED: This method is deprecated. Please use create_args_schema_v2 instead. + + It is keep only for backward compatibility. + """ + logger.warning("This is the old way to define the tool parameters. Please use the new way.") args: dict[str, tuple[Any, Field] | list[str]] = {} for key in self.tool_params: @@ -108,6 +218,31 @@ class AstraDBToolComponent(LCToolComponent): else: # Optional args[key] = (str | None, Field(description=self.tool_params[key], default=None)) + if self.use_search_query: + args["search_query"] = ( + str | None, + Field(description="Search query to find relevant documents.", default=None), + ) + + model = create_model("ToolInput", **args, __base__=BaseModel) + return {"ToolInput": model} + + def create_args_schema_v2(self) -> dict[str, BaseModel]: + """Create the tool input schema using the new tool parameters configuration.""" + args: dict[str, tuple[Any, Field] | list[str]] = {} + + for tool_param in self.tools_params_v2: + if tool_param["mandatory"]: + args[tool_param["name"]] = (str, Field(description=tool_param["description"])) + else: + args[tool_param["name"]] = (str | None, Field(description=tool_param["description"], default=None)) + + if self.use_search_query: + args["search_query"] = ( + str, + Field(description=self.semantic_search_instruction), + ) + model = create_model("ToolInput", **args, __base__=BaseModel) return {"ToolInput": model} @@ -117,7 +252,7 @@ class AstraDBToolComponent(LCToolComponent): Returns: Tool: The built Astra DB tool. """ - schema_dict = self.create_args_schema() + schema_dict = self.create_args_schema() if len(self.tool_params.keys()) > 0 else self.create_args_schema_v2() tool = StructuredTool.from_function( name=self.tool_name, @@ -130,10 +265,18 @@ class AstraDBToolComponent(LCToolComponent): return tool - def projection_args(self, input_str: str) -> dict: + def projection_args(self, input_str: str) -> dict | None: + """Build the projection arguments for the AstraDB query.""" elements = input_str.split(",") result = {} + if elements == ["*"]: + return None + + # Force the projection to exclude the $vector field as it is not required by the tool + result["$vector"] = False + + # Fields with ! as prefix should be removed from the projection for element in elements: if element.startswith("!"): result[element[1:]] = False @@ -142,13 +285,127 @@ class AstraDBToolComponent(LCToolComponent): return result + def parse_timestamp(self, timestamp_str: str) -> datetime: + """Parse a timestamp string into Astra DB REST API format. + + Args: + timestamp_str (str): Input timestamp string + + Returns: + datetime: Datetime object + + Raises: + ValueError: If the timestamp cannot be parsed + """ + # Common datetime formats to try + formats = [ + "%Y-%m-%d", # 2024-03-21 + "%Y-%m-%dT%H:%M:%S", # 2024-03-21T15:30:00 + "%Y-%m-%dT%H:%M:%S%z", # 2024-03-21T15:30:00+0000 + "%Y-%m-%d %H:%M:%S", # 2024-03-21 15:30:00 + "%d/%m/%Y", # 21/03/2024 + "%Y/%m/%d", # 2024/03/21 + ] + + for fmt in formats: + try: + # Parse the date string + date_obj = datetime.strptime(timestamp_str, fmt).astimezone() + + # If the parsed date has no timezone info, assume UTC + if date_obj.tzinfo is None: + date_obj = date_obj.replace(tzinfo=timezone.utc) + + # Convert to UTC and format + return date_obj.astimezone(timezone.utc) + + except ValueError: + continue + + msg = f"Could not parse date: {timestamp_str}" + logger.error(msg) + raise ValueError(msg) + + def build_filter(self, args: dict, filter_settings: list) -> dict: + """Build filter dictionary for AstraDB query. + + Args: + args: Dictionary of arguments from the tool + filter_settings: List of filter settings from tools_params_v2 + Returns: + Dictionary containing the filter conditions + """ + filters = {**self.static_filters} + + for key, value in args.items(): + # Skip search_query as it's handled separately + if key == "search_query": + continue + + filter_setting = next((x for x in filter_settings if x["name"] == key), None) + if filter_setting and value is not None: + field_name = filter_setting["attribute_name"] if filter_setting["attribute_name"] else key + filter_key = field_name if not filter_setting["metadata"] else f"metadata.{field_name}" + if filter_setting["operator"] == "$exists": + filters[filter_key] = {**filters.get(filter_key, {}), filter_setting["operator"]: True} + elif filter_setting["operator"] in ["$in", "$nin", "$all"]: + filters[filter_key] = { + **filters.get(filter_key, {}), + filter_setting["operator"]: value.split(",") if isinstance(value, str) else value, + } + elif filter_setting["is_timestamp"] == True: # noqa: E712 + try: + filters[filter_key] = { + **filters.get(filter_key, {}), + filter_setting["operator"]: self.parse_timestamp(value), + } + except ValueError as e: + msg = f"Error parsing timestamp: {e} - Use the prompt to specify the date in the correct format" + logger.error(msg) + raise ValueError(msg) from e + else: + filters[filter_key] = {**filters.get(filter_key, {}), filter_setting["operator"]: value} + return filters + def run_model(self, **args) -> Data | list[Data]: + """Run the query to get the data from the AstraDB collection.""" collection = self._build_collection() - results = collection.find( - ({**args, **self.static_filters}), - projection=self.projection_args(self.projection_attributes), - limit=self.number_of_results, - ) + sort = {} + + # Build filters using the new method + filters = self.build_filter(args, self.tools_params_v2) + + # Build the vector search on + if self.use_search_query and args["search_query"] is not None and args["search_query"] != "": + if self.use_vectorize: + sort["$vectorize"] = args["search_query"] + else: + if self.embedding is None: + msg = "Embedding model is not set. Please set the embedding model or use Astra DB Vectorize." + logger.error(msg) + raise ValueError(msg) + embedding_query = self.embedding.embed_query(args["search_query"]) + sort["$vector"] = embedding_query + del args["search_query"] + + find_options = { + "filter": filters, + "limit": self.number_of_results, + "sort": sort, + } + + projection = self.projection_args(self.projection_attributes) + if projection and len(projection) > 0: + find_options["projection"] = projection + + try: + results = collection.find(**find_options) + except Exception as e: + msg = f"Error on Astra DB Tool {self.tool_name} request: {e}" + logger.error(msg) + raise ValueError(msg) from e + + logger.info(f"Tool {self.tool_name} executed`") data: list[Data] = [Data(data=doc) for doc in results] self.status = data diff --git a/src/backend/base/langflow/components/tools/astradb_cql.py b/src/backend/base/langflow/components/tools/astradb_cql.py index 652f4db52..1408a7b9c 100644 --- a/src/backend/base/langflow/components/tools/astradb_cql.py +++ b/src/backend/base/langflow/components/tools/astradb_cql.py @@ -1,4 +1,6 @@ +import json import urllib +from datetime import datetime, timezone from http import HTTPStatus from typing import Any @@ -7,8 +9,10 @@ from langchain.pydantic_v1 import BaseModel, Field, create_model from langchain_core.tools import StructuredTool, Tool from langflow.base.langchain_utilities.model import LCToolComponent -from langflow.io import DictInput, IntInput, SecretStrInput, StrInput +from langflow.io import DictInput, IntInput, SecretStrInput, StrInput, TableInput +from langflow.logging import logger from langflow.schema import Data +from langflow.schema.table import EditMode class AstraDBCQLToolComponent(LCToolComponent): @@ -61,18 +65,85 @@ class AstraDBCQLToolComponent(LCToolComponent): value="*", advanced=True, ), + TableInput( + name="tools_params", + display_name="Tools Parameters", + info="Define the structure for the tool parameters. Describe the parameters " + "in a way the LLM can understand how to use them. Add the parameters " + "respecting the table schema (Partition Keys, Clustering Keys and Indexed Fields).", + required=False, + table_schema=[ + { + "name": "name", + "display_name": "Name", + "type": "str", + "description": "Name of the field/parameter to be used by the model.", + "default": "field", + "edit_mode": EditMode.INLINE, + }, + { + "name": "field_name", + "display_name": "Field Name", + "type": "str", + "description": "Specify the column name to be filtered on the table. " + "Leave empty if the attribute name is the same as the name of the field.", + "default": "", + "edit_mode": EditMode.INLINE, + }, + { + "name": "description", + "display_name": "Description", + "type": "str", + "description": "Describe the purpose of the parameter.", + "default": "description of tool parameter", + "edit_mode": EditMode.POPOVER, + }, + { + "name": "mandatory", + "display_name": "Is Mandatory", + "type": "boolean", + "edit_mode": EditMode.INLINE, + "description": ("Indicate if the field is mandatory."), + "options": ["True", "False"], + "default": "False", + }, + { + "name": "is_timestamp", + "display_name": "Is Timestamp", + "type": "boolean", + "edit_mode": EditMode.INLINE, + "description": ("Indicate if the field is a timestamp."), + "options": ["True", "False"], + "default": "False", + }, + { + "name": "operator", + "display_name": "Operator", + "type": "str", + "description": "Set the operator for the field. " + "https://docs.datastax.com/en/astra-db-serverless/api-reference/documents.html#operators", + "default": "$eq", + "options": ["$gt", "$gte", "$lt", "$lte", "$eq", "$ne", "$in", "$nin", "$exists", "$all", "$size"], + "edit_mode": EditMode.INLINE, + }, + ], + value=[], + ), DictInput( name="partition_keys", - display_name="Partition Keys", + display_name="DEPRECATED: Partition Keys", is_list=True, info="Field name and description to the model", - required=True, + required=False, + advanced=True, ), DictInput( name="clustering_keys", - display_name="Clustering Keys", + display_name="DEPRECATED: Clustering Keys", is_list=True, info="Field name and description to the model", + required=False, + advanced=True, ), DictInput( name="static_filters", @@ -90,22 +161,84 @@ class AstraDBCQLToolComponent(LCToolComponent): ), ] + def parse_timestamp(self, timestamp_str: str) -> str: + """Parse a timestamp string into Astra DB REST API format. + + Args: + timestamp_str (str): Input timestamp string + + Returns: + str: Formatted timestamp string in YYYY-MM-DDTHH:MI:SS.000Z format + + Raises: + ValueError: If the timestamp cannot be parsed + """ + # Common datetime formats to try + formats = [ + "%Y-%m-%d", # 2024-03-21 + "%Y-%m-%dT%H:%M:%S", # 2024-03-21T15:30:00 + "%Y-%m-%dT%H:%M:%S%z", # 2024-03-21T15:30:00+0000 + "%Y-%m-%d %H:%M:%S", # 2024-03-21 15:30:00 + "%d/%m/%Y", # 21/03/2024 + "%Y/%m/%d", # 2024/03/21 + ] + + for fmt in formats: + try: + # Parse the date string + date_obj = datetime.strptime(timestamp_str, fmt).astimezone() + + # If the parsed date has no timezone info, assume UTC + if date_obj.tzinfo is None: + date_obj = date_obj.replace(tzinfo=timezone.utc) + + # Convert to UTC and format + utc_date = date_obj.astimezone(timezone.utc) + return utc_date.strftime("%Y-%m-%dT%H:%M:%S.000Z") + except ValueError: + continue + + msg = f"Could not parse date: {timestamp_str}" + logger.error(msg) + raise ValueError(msg) + def astra_rest(self, args): headers = {"Accept": "application/json", "X-Cassandra-Token": f"{self.token}"} astra_url = f"{self.api_endpoint}/api/rest/v2/keyspaces/{self.keyspace}/{self.table_name}/" - key = [] + where = {} - # Partition keys are mandatory - key = [self.partition_keys[k] for k in self.partition_keys] + for param in self.tools_params: + field_name = param["field_name"] if param["field_name"] else param["name"] + field_value = None - # Clustering keys are optional - for k in self.clustering_keys: - if k in args: - key.append(args[k]) - elif self.static_filters[k] is not None: - key.append(self.static_filters[k]) + if field_name in self.static_filters: + field_value = self.static_filters[field_name] + elif param["name"] in args: + field_value = args[param["name"]] - url = f"{astra_url}{'/'.join(key)}?page-size={self.number_of_results}" + if field_value is None: + continue + + if param["is_timestamp"] == True: # noqa: E712 + try: + field_value = self.parse_timestamp(field_value) + except ValueError as e: + msg = f"Error parsing timestamp: {e} - Use the prompt to specify the date in the correct format" + logger.error(msg) + raise ValueError(msg) from e + + if param["operator"] == "$exists": + where[field_name] = {**where.get(field_name, {}), param["operator"]: True} + elif param["operator"] in ["$in", "$nin", "$all"]: + where[field_name] = { + **where.get(field_name, {}), + param["operator"]: field_value.split(",") if isinstance(field_value, str) else field_value, + } + else: + where[field_name] = {**where.get(field_name, {}), param["operator"]: field_value} + + url = f"{astra_url}?page-size={self.number_of_results}" + url += f"&where={json.dumps(where)}" if self.projection_fields != "*": url += f"&fields={urllib.parse.quote(self.projection_fields.replace(' ', ''))}" @@ -113,7 +246,9 @@ class AstraDBCQLToolComponent(LCToolComponent): res = requests.request("GET", url=url, headers=headers, timeout=10) if int(res.status_code) >= HTTPStatus.BAD_REQUEST: - return res.text + msg = f"Error on Astra DB CQL Tool {self.tool_name} request: {res.text}" + logger.error(msg) + raise ValueError(msg) try: res_data = res.json() @@ -124,18 +259,13 @@ class AstraDBCQLToolComponent(LCToolComponent): def create_args_schema(self) -> dict[str, BaseModel]: args: dict[str, tuple[Any, Field]] = {} - for key in self.partition_keys: - # Partition keys are mandatory is it doesn't have a static filter - if key not in self.static_filters: - args[key] = (str, Field(description=self.partition_keys[key])) - - for key in self.clustering_keys: - # Partition keys are mandatory if has the exclamation mark and doesn't have a static filter - if key not in self.static_filters: - if key.startswith("!"): # Mandatory - args[key[1:]] = (str, Field(description=self.clustering_keys[key])) - else: # Optional - args[key] = (str | None, Field(description=self.clustering_keys[key], default=None)) + for param in self.tools_params: + field_name = param["field_name"] if param["field_name"] else param["name"] + if field_name not in self.static_filters: + if param["mandatory"]: + args[param["name"]] = (str, Field(description=param["description"])) + else: + args[param["name"]] = (str | None, Field(description=param["description"], default=None)) model = create_model("ToolInput", **args, __base__=BaseModel) return {"ToolInput": model} @@ -172,6 +302,13 @@ class AstraDBCQLToolComponent(LCToolComponent): def run_model(self, **args) -> Data | list[Data]: results = self.astra_rest(args) - data: list[Data] = [Data(data=doc) for doc in results] + data: list[Data] = [] + + if isinstance(results, list): + data = [Data(data=doc) for doc in results] + else: + self.status = results + return [] + self.status = data - return results + return data