feat: Enhance AstraDB tool component with vector search and metadata filter (#6887)
* feat: Enhance AstraDB tool component with advanced configuration and semantic search * [autofix.ci] apply automated fixes * [autofix.ci] apply automated fixes (attempt 2/3) * Format and Lint * Format and Lint * refactor: Improve AstraDB tool component with code cleanup and documentation * [autofix.ci] apply automated fixes * Lint & Format * [autofix.ci] apply automated fixes * Add search_query description input * Format backend * [autofix.ci] apply automated fixes * Error message on Astra DB CQL Tool * [autofix.ci] apply automated fixes * [autofix.ci] apply automated fixes (attempt 2/3) * Format backend * Enhance AstraDB CQL Tool Component with new tools_params input and update filtering logic. Deprecate partition and clustering keys inputs. Introduce attribute_name for improved field mapping. * Add 'is_date' parameter to AstraDBToolComponent for date filtering and update filter logic to handle date values. * Revert "Format backend" This reverts commit 0f12efbd817d82087bc9b48af809e0384b1eb160. * [autofix.ci] apply automated fixes * [autofix.ci] apply automated fixes (attempt 2/3) * format backend * [autofix.ci] apply automated fixes * Implement timestamp parsing in AstraDB components and update filtering logic to utilize the new method. Rename 'is_date' to 'is_timestamp' for clarity in parameter definitions. --------- Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com> Co-authored-by: Edwin Jose <edwin.jose@datastax.com> Co-authored-by: Eric Hare <ericrhare@gmail.com>
This commit is contained in:
parent
8c2f6addf2
commit
4527c473be
2 changed files with 443 additions and 49 deletions
|
|
@ -1,4 +1,5 @@
|
|||
import os
|
||||
from datetime import datetime, timezone
|
||||
from typing import Any
|
||||
|
||||
from astrapy import Collection, DataAPIClient, Database
|
||||
|
|
@ -6,13 +7,15 @@ from langchain.pydantic_v1 import BaseModel, Field, create_model
|
|||
from langchain_core.tools import StructuredTool, Tool
|
||||
|
||||
from langflow.base.langchain_utilities.model import LCToolComponent
|
||||
from langflow.io import DictInput, IntInput, SecretStrInput, StrInput
|
||||
from langflow.io import BoolInput, DictInput, HandleInput, IntInput, SecretStrInput, StrInput, TableInput
|
||||
from langflow.logging import logger
|
||||
from langflow.schema import Data
|
||||
from langflow.schema.table import EditMode
|
||||
|
||||
|
||||
class AstraDBToolComponent(LCToolComponent):
|
||||
display_name: str = "Astra DB Tool"
|
||||
description: str = "Create a tool to get transactional data from DataStax Astra DB Collection"
|
||||
description: str = "Tool to run hybrid vector and metadata search on DataStax Astra DB Collection"
|
||||
documentation: str = "https://docs.langflow.org/Components/components-tools#astra-db-tool"
|
||||
icon: str = "AstraDB"
|
||||
|
||||
|
|
@ -20,19 +23,19 @@ class AstraDBToolComponent(LCToolComponent):
|
|||
StrInput(
|
||||
name="tool_name",
|
||||
display_name="Tool Name",
|
||||
info="The name of the tool.",
|
||||
info="The name of the tool to be passed to the LLM.",
|
||||
required=True,
|
||||
),
|
||||
StrInput(
|
||||
name="tool_description",
|
||||
display_name="Tool Description",
|
||||
info="The description of the tool.",
|
||||
info="Describe the tool to LLM. Add any information that can help the LLM to use the tool.",
|
||||
required=True,
|
||||
),
|
||||
StrInput(
|
||||
name="namespace",
|
||||
display_name="Namespace Name",
|
||||
info="The name of the namespace within Astra where the collection is be stored.",
|
||||
name="keyspace",
|
||||
display_name="Keyspace Name",
|
||||
info="The name of the keyspace within Astra where the collection is stored.",
|
||||
value="default_keyspace",
|
||||
advanced=True,
|
||||
),
|
||||
|
|
@ -59,16 +62,90 @@ class AstraDBToolComponent(LCToolComponent):
|
|||
StrInput(
|
||||
name="projection_attributes",
|
||||
display_name="Projection Attributes",
|
||||
info="Attributes to return separated by comma.",
|
||||
info="Attributes to be returned by the tool separated by comma.",
|
||||
required=True,
|
||||
value="*",
|
||||
advanced=True,
|
||||
),
|
||||
TableInput(
|
||||
name="tools_params_v2",
|
||||
display_name="Tools Parameters",
|
||||
info="Define the structure for the tool parameters. Describe the parameters "
|
||||
"in a way the LLM can understand how to use them.",
|
||||
required=False,
|
||||
table_schema=[
|
||||
{
|
||||
"name": "name",
|
||||
"display_name": "Name",
|
||||
"type": "str",
|
||||
"description": "Specify the name of the output field/parameter for the model.",
|
||||
"default": "field",
|
||||
"edit_mode": EditMode.INLINE,
|
||||
},
|
||||
{
|
||||
"name": "attribute_name",
|
||||
"display_name": "Attribute Name",
|
||||
"type": "str",
|
||||
"description": "Specify the attribute name to be filtered on the collection. "
|
||||
"Leave empty if the attribute name is the same as the name of the field.",
|
||||
"default": "",
|
||||
"edit_mode": EditMode.INLINE,
|
||||
},
|
||||
{
|
||||
"name": "description",
|
||||
"display_name": "Description",
|
||||
"type": "str",
|
||||
"description": "Describe the purpose of the output field.",
|
||||
"default": "description of field",
|
||||
"edit_mode": EditMode.POPOVER,
|
||||
},
|
||||
{
|
||||
"name": "metadata",
|
||||
"display_name": "Is Metadata",
|
||||
"type": "boolean",
|
||||
"edit_mode": EditMode.INLINE,
|
||||
"description": ("Indicate if the field is included in the metadata field."),
|
||||
"options": ["True", "False"],
|
||||
"default": "False",
|
||||
},
|
||||
{
|
||||
"name": "mandatory",
|
||||
"display_name": "Is Mandatory",
|
||||
"type": "boolean",
|
||||
"edit_mode": EditMode.INLINE,
|
||||
"description": ("Indicate if the field is mandatory."),
|
||||
"options": ["True", "False"],
|
||||
"default": "False",
|
||||
},
|
||||
{
|
||||
"name": "is_timestamp",
|
||||
"display_name": "Is Timestamp",
|
||||
"type": "boolean",
|
||||
"edit_mode": EditMode.INLINE,
|
||||
"description": ("Indicate if the field is a timestamp."),
|
||||
"options": ["True", "False"],
|
||||
"default": "False",
|
||||
},
|
||||
{
|
||||
"name": "operator",
|
||||
"display_name": "Operator",
|
||||
"type": "str",
|
||||
"description": "Set the operator for the field. "
|
||||
"https://docs.datastax.com/en/astra-db-serverless/api-reference/documents.html#operators",
|
||||
"default": "$eq",
|
||||
"options": ["$gt", "$gte", "$lt", "$lte", "$eq", "$ne", "$in", "$nin", "$exists", "$all", "$size"],
|
||||
"edit_mode": EditMode.INLINE,
|
||||
},
|
||||
],
|
||||
value=[],
|
||||
),
|
||||
DictInput(
|
||||
name="tool_params",
|
||||
info="Attributes to filter and description to the model. Add ! for mandatory (e.g: !customerId)",
|
||||
info="DEPRECATED: Attributes to filter and description to the model. "
|
||||
"Add ! for mandatory (e.g: !customerId)",
|
||||
display_name="Tool params",
|
||||
is_list=True,
|
||||
advanced=True,
|
||||
),
|
||||
DictInput(
|
||||
name="static_filters",
|
||||
|
|
@ -84,6 +161,29 @@ class AstraDBToolComponent(LCToolComponent):
|
|||
advanced=True,
|
||||
value=5,
|
||||
),
|
||||
BoolInput(
|
||||
name="use_search_query",
|
||||
display_name="Semantic Search",
|
||||
info="When this parameter is activated, the search query parameter will be used to search the collection.",
|
||||
advanced=False,
|
||||
value=False,
|
||||
),
|
||||
BoolInput(
|
||||
name="use_vectorize",
|
||||
display_name="Use Astra DB Vectorize",
|
||||
info="When this parameter is activated, Astra DB Vectorize method will be used to generate the embeddings.",
|
||||
advanced=False,
|
||||
value=False,
|
||||
),
|
||||
HandleInput(name="embedding", display_name="Embedding Model", input_types=["Embeddings"]),
|
||||
StrInput(
|
||||
name="semantic_search_instruction",
|
||||
display_name="Semantic Search Instruction",
|
||||
info="The instruction to use for the semantic search.",
|
||||
required=True,
|
||||
value="Search query to find relevant documents.",
|
||||
advanced=True,
|
||||
),
|
||||
]
|
||||
|
||||
_cached_client: DataAPIClient | None = None
|
||||
|
|
@ -94,12 +194,22 @@ class AstraDBToolComponent(LCToolComponent):
|
|||
if self._cached_collection:
|
||||
return self._cached_collection
|
||||
|
||||
cached_client = DataAPIClient(self.token)
|
||||
cached_db = cached_client.get_database(self.api_endpoint, namespace=self.namespace)
|
||||
self._cached_collection = cached_db.get_collection(self.collection_name)
|
||||
return self._cached_collection
|
||||
try:
|
||||
cached_client = DataAPIClient(self.token)
|
||||
cached_db = cached_client.get_database(self.api_endpoint, keyspace=self.keyspace)
|
||||
self._cached_collection = cached_db.get_collection(self.collection_name)
|
||||
except Exception as e:
|
||||
msg = f"Error building collection: {e}"
|
||||
raise ValueError(msg) from e
|
||||
else:
|
||||
return self._cached_collection
|
||||
|
||||
def create_args_schema(self) -> dict[str, BaseModel]:
|
||||
"""DEPRECATED: This method is deprecated. Please use create_args_schema_v2 instead.
|
||||
|
||||
It is kept only for backward compatibility.
|
||||
"""
|
||||
logger.warning("This is the old way to define the tool parameters. Please use the new way.")
|
||||
args: dict[str, tuple[Any, Field] | list[str]] = {}
|
||||
|
||||
for key in self.tool_params:
|
||||
|
|
@ -108,6 +218,31 @@ class AstraDBToolComponent(LCToolComponent):
|
|||
else: # Optional
|
||||
args[key] = (str | None, Field(description=self.tool_params[key], default=None))
|
||||
|
||||
if self.use_search_query:
|
||||
args["search_query"] = (
|
||||
str | None,
|
||||
Field(description="Search query to find relevant documents.", default=None),
|
||||
)
|
||||
|
||||
model = create_model("ToolInput", **args, __base__=BaseModel)
|
||||
return {"ToolInput": model}
|
||||
|
||||
def create_args_schema_v2(self) -> dict[str, BaseModel]:
    """Build the tool input schema from the v2 tool-parameters table.

    Every configured parameter becomes a string field on the generated
    pydantic model: mandatory parameters are required, the rest default to
    None. When semantic search is enabled, a required ``search_query`` field
    is appended, described by the configured semantic search instruction.

    Returns:
        dict[str, BaseModel]: Mapping with the single key "ToolInput"
        pointing at the dynamically created pydantic model.
    """
    schema_fields: dict[str, tuple[Any, Field] | list[str]] = {}

    for param in self.tools_params_v2:
        description = param["description"]
        if param["mandatory"]:
            field_spec = (str, Field(description=description))
        else:
            field_spec = (str | None, Field(description=description, default=None))
        schema_fields[param["name"]] = field_spec

    if self.use_search_query:
        schema_fields["search_query"] = (str, Field(description=self.semantic_search_instruction))

    return {"ToolInput": create_model("ToolInput", **schema_fields, __base__=BaseModel)}
|
||||
|
||||
|
|
@ -117,7 +252,7 @@ class AstraDBToolComponent(LCToolComponent):
|
|||
Returns:
|
||||
Tool: The built Astra DB tool.
|
||||
"""
|
||||
schema_dict = self.create_args_schema()
|
||||
schema_dict = self.create_args_schema() if len(self.tool_params.keys()) > 0 else self.create_args_schema_v2()
|
||||
|
||||
tool = StructuredTool.from_function(
|
||||
name=self.tool_name,
|
||||
|
|
@ -130,10 +265,18 @@ class AstraDBToolComponent(LCToolComponent):
|
|||
|
||||
return tool
|
||||
|
||||
def projection_args(self, input_str: str) -> dict:
|
||||
def projection_args(self, input_str: str) -> dict | None:
|
||||
"""Build the projection arguments for the AstraDB query."""
|
||||
elements = input_str.split(",")
|
||||
result = {}
|
||||
|
||||
if elements == ["*"]:
|
||||
return None
|
||||
|
||||
# Force the projection to exclude the $vector field as it is not required by the tool
|
||||
result["$vector"] = False
|
||||
|
||||
# Fields with ! as prefix should be removed from the projection
|
||||
for element in elements:
|
||||
if element.startswith("!"):
|
||||
result[element[1:]] = False
|
||||
|
|
@ -142,13 +285,127 @@ class AstraDBToolComponent(LCToolComponent):
|
|||
|
||||
return result
|
||||
|
||||
def parse_timestamp(self, timestamp_str: str) -> datetime:
    """Parse a timestamp string into a timezone-aware UTC datetime.

    Tries a fixed list of common date/datetime formats in order and returns
    the first successful parse, normalized to UTC. Timestamps that carry no
    timezone information are interpreted as UTC (not machine-local time).

    Args:
        timestamp_str (str): Input timestamp string

    Returns:
        datetime: Timezone-aware datetime object in UTC

    Raises:
        ValueError: If the timestamp cannot be parsed
    """
    # Common datetime formats to try
    formats = [
        "%Y-%m-%d",  # 2024-03-21
        "%Y-%m-%dT%H:%M:%S",  # 2024-03-21T15:30:00
        "%Y-%m-%dT%H:%M:%S%z",  # 2024-03-21T15:30:00+0000
        "%Y-%m-%d %H:%M:%S",  # 2024-03-21 15:30:00
        "%d/%m/%Y",  # 21/03/2024
        "%Y/%m/%d",  # 2024/03/21
    ]

    for fmt in formats:
        try:
            # Parse naively first; calling .astimezone() here would attach the
            # machine-local zone to naive values and make the UTC fallback
            # below unreachable, yielding machine-dependent results.
            date_obj = datetime.strptime(timestamp_str, fmt)  # noqa: DTZ007
        except ValueError:
            continue

        # If the parsed date has no timezone info, assume UTC
        if date_obj.tzinfo is None:
            date_obj = date_obj.replace(tzinfo=timezone.utc)

        # Normalize timezone-aware values to UTC
        return date_obj.astimezone(timezone.utc)

    msg = f"Could not parse date: {timestamp_str}"
    logger.error(msg)
    raise ValueError(msg)
|
||||
|
||||
def build_filter(self, args: dict, filter_settings: list) -> dict:
    """Assemble the filter document for an Astra DB query.

    Starts from the component's static filters, then folds in every tool
    argument that has a matching entry in ``filter_settings``. Conditions on
    the same field are merged so several operators can target one attribute.

    Args:
        args: Dictionary of arguments from the tool
        filter_settings: List of filter settings from tools_params_v2
    Returns:
        Dictionary containing the filter conditions
    """
    filters = dict(self.static_filters)

    for arg_name, arg_value in args.items():
        # search_query drives the vector sort, not the metadata filter
        if arg_name == "search_query":
            continue

        setting = next((s for s in filter_settings if s["name"] == arg_name), None)
        if setting is None or arg_value is None:
            continue

        # Fall back to the parameter name when no attribute name was given
        target = setting["attribute_name"] if setting["attribute_name"] else arg_name
        filter_key = f"metadata.{target}" if setting["metadata"] else target
        operator = setting["operator"]

        if operator == "$exists":
            condition = True
        elif operator in ("$in", "$nin", "$all"):
            # Array operators accept comma-separated strings as lists
            condition = arg_value.split(",") if isinstance(arg_value, str) else arg_value
        elif setting["is_timestamp"] == True:  # noqa: E712
            try:
                condition = self.parse_timestamp(arg_value)
            except ValueError as e:
                msg = f"Error parsing timestamp: {e} - Use the prompt to specify the date in the correct format"
                logger.error(msg)
                raise ValueError(msg) from e
        else:
            condition = arg_value

        # Merge into any condition already present for this field
        filters[filter_key] = {**filters.get(filter_key, {}), operator: condition}

    return filters
|
||||
|
||||
def run_model(self, **args) -> Data | list[Data]:
|
||||
"""Run the query to get the data from the AstraDB collection."""
|
||||
collection = self._build_collection()
|
||||
results = collection.find(
|
||||
({**args, **self.static_filters}),
|
||||
projection=self.projection_args(self.projection_attributes),
|
||||
limit=self.number_of_results,
|
||||
)
|
||||
sort = {}
|
||||
|
||||
# Build filters using the new method
|
||||
filters = self.build_filter(args, self.tools_params_v2)
|
||||
|
||||
# Build the vector search sort clause when semantic search is enabled
|
||||
if self.use_search_query and args["search_query"] is not None and args["search_query"] != "":
|
||||
if self.use_vectorize:
|
||||
sort["$vectorize"] = args["search_query"]
|
||||
else:
|
||||
if self.embedding is None:
|
||||
msg = "Embedding model is not set. Please set the embedding model or use Astra DB Vectorize."
|
||||
logger.error(msg)
|
||||
raise ValueError(msg)
|
||||
embedding_query = self.embedding.embed_query(args["search_query"])
|
||||
sort["$vector"] = embedding_query
|
||||
del args["search_query"]
|
||||
|
||||
find_options = {
|
||||
"filter": filters,
|
||||
"limit": self.number_of_results,
|
||||
"sort": sort,
|
||||
}
|
||||
|
||||
projection = self.projection_args(self.projection_attributes)
|
||||
if projection and len(projection) > 0:
|
||||
find_options["projection"] = projection
|
||||
|
||||
try:
|
||||
results = collection.find(**find_options)
|
||||
except Exception as e:
|
||||
msg = f"Error on Astra DB Tool {self.tool_name} request: {e}"
|
||||
logger.error(msg)
|
||||
raise ValueError(msg) from e
|
||||
|
||||
logger.info(f"Tool {self.tool_name} executed")
|
||||
|
||||
data: list[Data] = [Data(data=doc) for doc in results]
|
||||
self.status = data
|
||||
|
|
|
|||
|
|
@ -1,4 +1,6 @@
|
|||
import json
|
||||
import urllib
|
||||
from datetime import datetime, timezone
|
||||
from http import HTTPStatus
|
||||
from typing import Any
|
||||
|
||||
|
|
@ -7,8 +9,10 @@ from langchain.pydantic_v1 import BaseModel, Field, create_model
|
|||
from langchain_core.tools import StructuredTool, Tool
|
||||
|
||||
from langflow.base.langchain_utilities.model import LCToolComponent
|
||||
from langflow.io import DictInput, IntInput, SecretStrInput, StrInput
|
||||
from langflow.io import DictInput, IntInput, SecretStrInput, StrInput, TableInput
|
||||
from langflow.logging import logger
|
||||
from langflow.schema import Data
|
||||
from langflow.schema.table import EditMode
|
||||
|
||||
|
||||
class AstraDBCQLToolComponent(LCToolComponent):
|
||||
|
|
@ -61,18 +65,85 @@ class AstraDBCQLToolComponent(LCToolComponent):
|
|||
value="*",
|
||||
advanced=True,
|
||||
),
|
||||
TableInput(
|
||||
name="tools_params",
|
||||
display_name="Tools Parameters",
|
||||
info="Define the structure for the tool parameters. Describe the parameters "
|
||||
"in a way the LLM can understand how to use them. Add the parameters "
|
||||
"respecting the table schema (Partition Keys, Clustering Keys and Indexed Fields).",
|
||||
required=False,
|
||||
table_schema=[
|
||||
{
|
||||
"name": "name",
|
||||
"display_name": "Name",
|
||||
"type": "str",
|
||||
"description": "Name of the field/parameter to be used by the model.",
|
||||
"default": "field",
|
||||
"edit_mode": EditMode.INLINE,
|
||||
},
|
||||
{
|
||||
"name": "field_name",
|
||||
"display_name": "Field Name",
|
||||
"type": "str",
|
||||
"description": "Specify the column name to be filtered on the table. "
|
||||
"Leave empty if the attribute name is the same as the name of the field.",
|
||||
"default": "",
|
||||
"edit_mode": EditMode.INLINE,
|
||||
},
|
||||
{
|
||||
"name": "description",
|
||||
"display_name": "Description",
|
||||
"type": "str",
|
||||
"description": "Describe the purpose of the parameter.",
|
||||
"default": "description of tool parameter",
|
||||
"edit_mode": EditMode.POPOVER,
|
||||
},
|
||||
{
|
||||
"name": "mandatory",
|
||||
"display_name": "Is Mandatory",
|
||||
"type": "boolean",
|
||||
"edit_mode": EditMode.INLINE,
|
||||
"description": ("Indicate if the field is mandatory."),
|
||||
"options": ["True", "False"],
|
||||
"default": "False",
|
||||
},
|
||||
{
|
||||
"name": "is_timestamp",
|
||||
"display_name": "Is Timestamp",
|
||||
"type": "boolean",
|
||||
"edit_mode": EditMode.INLINE,
|
||||
"description": ("Indicate if the field is a timestamp."),
|
||||
"options": ["True", "False"],
|
||||
"default": "False",
|
||||
},
|
||||
{
|
||||
"name": "operator",
|
||||
"display_name": "Operator",
|
||||
"type": "str",
|
||||
"description": "Set the operator for the field. "
|
||||
"https://docs.datastax.com/en/astra-db-serverless/api-reference/documents.html#operators",
|
||||
"default": "$eq",
|
||||
"options": ["$gt", "$gte", "$lt", "$lte", "$eq", "$ne", "$in", "$nin", "$exists", "$all", "$size"],
|
||||
"edit_mode": EditMode.INLINE,
|
||||
},
|
||||
],
|
||||
value=[],
|
||||
),
|
||||
DictInput(
|
||||
name="partition_keys",
|
||||
display_name="Partition Keys",
|
||||
display_name="DEPRECATED: Partition Keys",
|
||||
is_list=True,
|
||||
info="Field name and description to the model",
|
||||
required=True,
|
||||
required=False,
|
||||
advanced=True,
|
||||
),
|
||||
DictInput(
|
||||
name="clustering_keys",
|
||||
display_name="Clustering Keys",
|
||||
display_name="DEPRECATED: Clustering Keys",
|
||||
is_list=True,
|
||||
info="Field name and description to the model",
|
||||
required=False,
|
||||
advanced=True,
|
||||
),
|
||||
DictInput(
|
||||
name="static_filters",
|
||||
|
|
@ -90,22 +161,84 @@ class AstraDBCQLToolComponent(LCToolComponent):
|
|||
),
|
||||
]
|
||||
|
||||
def parse_timestamp(self, timestamp_str: str) -> str:
    """Parse a timestamp string into Astra DB REST API format.

    Tries a fixed list of common date/datetime formats in order and formats
    the first successful parse as a UTC timestamp string. Timestamps that
    carry no timezone information are interpreted as UTC (not machine-local
    time).

    Args:
        timestamp_str (str): Input timestamp string

    Returns:
        str: Formatted timestamp string in YYYY-MM-DDTHH:MI:SS.000Z format

    Raises:
        ValueError: If the timestamp cannot be parsed
    """
    # Common datetime formats to try
    formats = [
        "%Y-%m-%d",  # 2024-03-21
        "%Y-%m-%dT%H:%M:%S",  # 2024-03-21T15:30:00
        "%Y-%m-%dT%H:%M:%S%z",  # 2024-03-21T15:30:00+0000
        "%Y-%m-%d %H:%M:%S",  # 2024-03-21 15:30:00
        "%d/%m/%Y",  # 21/03/2024
        "%Y/%m/%d",  # 2024/03/21
    ]

    for fmt in formats:
        try:
            # Parse naively first; calling .astimezone() here would attach the
            # machine-local zone to naive values and make the UTC fallback
            # below unreachable, yielding machine-dependent results.
            date_obj = datetime.strptime(timestamp_str, fmt)  # noqa: DTZ007
        except ValueError:
            continue

        # If the parsed date has no timezone info, assume UTC
        if date_obj.tzinfo is None:
            date_obj = date_obj.replace(tzinfo=timezone.utc)

        # Convert to UTC and render with fixed .000 milliseconds and Z suffix
        return date_obj.astimezone(timezone.utc).strftime("%Y-%m-%dT%H:%M:%S.000Z")

    msg = f"Could not parse date: {timestamp_str}"
    logger.error(msg)
    raise ValueError(msg)
|
||||
|
||||
def astra_rest(self, args):
|
||||
headers = {"Accept": "application/json", "X-Cassandra-Token": f"{self.token}"}
|
||||
astra_url = f"{self.api_endpoint}/api/rest/v2/keyspaces/{self.keyspace}/{self.table_name}/"
|
||||
key = []
|
||||
where = {}
|
||||
|
||||
# Partition keys are mandatory
|
||||
key = [self.partition_keys[k] for k in self.partition_keys]
|
||||
for param in self.tools_params:
|
||||
field_name = param["field_name"] if param["field_name"] else param["name"]
|
||||
field_value = None
|
||||
|
||||
# Clustering keys are optional
|
||||
for k in self.clustering_keys:
|
||||
if k in args:
|
||||
key.append(args[k])
|
||||
elif self.static_filters[k] is not None:
|
||||
key.append(self.static_filters[k])
|
||||
if field_name in self.static_filters:
|
||||
field_value = self.static_filters[field_name]
|
||||
elif param["name"] in args:
|
||||
field_value = args[param["name"]]
|
||||
|
||||
url = f"{astra_url}{'/'.join(key)}?page-size={self.number_of_results}"
|
||||
if field_value is None:
|
||||
continue
|
||||
|
||||
if param["is_timestamp"] == True: # noqa: E712
|
||||
try:
|
||||
field_value = self.parse_timestamp(field_value)
|
||||
except ValueError as e:
|
||||
msg = f"Error parsing timestamp: {e} - Use the prompt to specify the date in the correct format"
|
||||
logger.error(msg)
|
||||
raise ValueError(msg) from e
|
||||
|
||||
if param["operator"] == "$exists":
|
||||
where[field_name] = {**where.get(field_name, {}), param["operator"]: True}
|
||||
elif param["operator"] in ["$in", "$nin", "$all"]:
|
||||
where[field_name] = {
|
||||
**where.get(field_name, {}),
|
||||
param["operator"]: field_value.split(",") if isinstance(field_value, str) else field_value,
|
||||
}
|
||||
else:
|
||||
where[field_name] = {**where.get(field_name, {}), param["operator"]: field_value}
|
||||
|
||||
url = f"{astra_url}?page-size={self.number_of_results}"
|
||||
url += f"&where={json.dumps(where)}"
|
||||
|
||||
if self.projection_fields != "*":
|
||||
url += f"&fields={urllib.parse.quote(self.projection_fields.replace(' ', ''))}"
|
||||
|
|
@ -113,7 +246,9 @@ class AstraDBCQLToolComponent(LCToolComponent):
|
|||
res = requests.request("GET", url=url, headers=headers, timeout=10)
|
||||
|
||||
if int(res.status_code) >= HTTPStatus.BAD_REQUEST:
|
||||
return res.text
|
||||
msg = f"Error on Astra DB CQL Tool {self.tool_name} request: {res.text}"
|
||||
logger.error(msg)
|
||||
raise ValueError(msg)
|
||||
|
||||
try:
|
||||
res_data = res.json()
|
||||
|
|
@ -124,18 +259,13 @@ class AstraDBCQLToolComponent(LCToolComponent):
|
|||
def create_args_schema(self) -> dict[str, BaseModel]:
|
||||
args: dict[str, tuple[Any, Field]] = {}
|
||||
|
||||
for key in self.partition_keys:
|
||||
# Partition keys are mandatory is it doesn't have a static filter
|
||||
if key not in self.static_filters:
|
||||
args[key] = (str, Field(description=self.partition_keys[key]))
|
||||
|
||||
for key in self.clustering_keys:
|
||||
# Partition keys are mandatory if has the exclamation mark and doesn't have a static filter
|
||||
if key not in self.static_filters:
|
||||
if key.startswith("!"): # Mandatory
|
||||
args[key[1:]] = (str, Field(description=self.clustering_keys[key]))
|
||||
else: # Optional
|
||||
args[key] = (str | None, Field(description=self.clustering_keys[key], default=None))
|
||||
for param in self.tools_params:
|
||||
field_name = param["field_name"] if param["field_name"] else param["name"]
|
||||
if field_name not in self.static_filters:
|
||||
if param["mandatory"]:
|
||||
args[param["name"]] = (str, Field(description=param["description"]))
|
||||
else:
|
||||
args[param["name"]] = (str | None, Field(description=param["description"], default=None))
|
||||
|
||||
model = create_model("ToolInput", **args, __base__=BaseModel)
|
||||
return {"ToolInput": model}
|
||||
|
|
@ -172,6 +302,13 @@ class AstraDBCQLToolComponent(LCToolComponent):
|
|||
|
||||
def run_model(self, **args) -> Data | list[Data]:
|
||||
results = self.astra_rest(args)
|
||||
data: list[Data] = [Data(data=doc) for doc in results]
|
||||
data: list[Data] = []
|
||||
|
||||
if isinstance(results, list):
|
||||
data = [Data(data=doc) for doc in results]
|
||||
else:
|
||||
self.status = results
|
||||
return []
|
||||
|
||||
self.status = data
|
||||
return results
|
||||
return data
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue