fix: allow dynamic updating when hosted anywhere (#5999)
* fix: allow dynamic updating when hosted anywhere * [autofix.ci] apply automated fixes * Continue speed improvements * [autofix.ci] apply automated fixes * [autofix.ci] apply automated fixes (attempt 2/3) * Fix autodetect with new collection * [autofix.ci] apply automated fixes * [autofix.ci] apply automated fixes (attempt 2/3) --------- Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com> Co-authored-by: Gabriel Luiz Freitas Almeida <gabriel@langflow.org>
This commit is contained in:
parent
30c2fc159f
commit
c007a5fffc
2 changed files with 164 additions and 56 deletions
|
|
@ -132,6 +132,12 @@ class AstraDBVectorStoreComponent(LCVectorStoreComponent):
|
|||
real_time_refresh=True,
|
||||
combobox=True,
|
||||
),
|
||||
StrInput(
|
||||
name="d_api_endpoint",
|
||||
display_name="Database API Endpoint",
|
||||
info="The API Endpoint for the Astra DB instance. Supercedes database selection.",
|
||||
advanced=True,
|
||||
),
|
||||
DropdownInput(
|
||||
name="collection_name",
|
||||
display_name="Collection",
|
||||
|
|
@ -194,6 +200,13 @@ class AstraDBVectorStoreComponent(LCVectorStoreComponent):
|
|||
info="Optional dictionary of filters to apply to the search query.",
|
||||
advanced=True,
|
||||
),
|
||||
BoolInput(
|
||||
name="autodetect_collection",
|
||||
display_name="Autodetect Collection",
|
||||
info="Boolean flag to determine whether to autodetect the collection.",
|
||||
advanced=True,
|
||||
value=True,
|
||||
),
|
||||
StrInput(
|
||||
name="content_field",
|
||||
display_name="Content Field",
|
||||
|
|
@ -330,8 +343,13 @@ class AstraDBVectorStoreComponent(LCVectorStoreComponent):
|
|||
cls,
|
||||
token: str,
|
||||
environment: str | None = None,
|
||||
api_endpoint: str | None = None,
|
||||
database_name: str | None = None,
|
||||
):
|
||||
# If the api_endpoint is set, return it
|
||||
if api_endpoint:
|
||||
return api_endpoint
|
||||
|
||||
# Check if the database_name is like a url
|
||||
if database_name and database_name.startswith("https://"):
|
||||
return database_name
|
||||
|
|
@ -343,10 +361,11 @@ class AstraDBVectorStoreComponent(LCVectorStoreComponent):
|
|||
# Otherwise, get the URL from the database list
|
||||
return cls.get_database_list_static(token=token, environment=environment).get(database_name).get("api_endpoint")
|
||||
|
||||
def get_api_endpoint(self):
|
||||
def get_api_endpoint(self, *, api_endpoint: str | None = None):
|
||||
return self.get_api_endpoint_static(
|
||||
token=self.token,
|
||||
environment=self.environment,
|
||||
api_endpoint=api_endpoint or self.d_api_endpoint,
|
||||
database_name=self.api_endpoint,
|
||||
)
|
||||
|
||||
|
|
@ -358,12 +377,12 @@ class AstraDBVectorStoreComponent(LCVectorStoreComponent):
|
|||
|
||||
return None
|
||||
|
||||
def get_database_object(self):
|
||||
def get_database_object(self, api_endpoint: str | None = None):
|
||||
try:
|
||||
client = DataAPIClient(token=self.token, environment=self.environment)
|
||||
|
||||
return client.get_database(
|
||||
api_endpoint=self.get_api_endpoint(),
|
||||
api_endpoint=self.get_api_endpoint(api_endpoint=api_endpoint),
|
||||
token=self.token,
|
||||
keyspace=self.get_keyspace(),
|
||||
)
|
||||
|
|
@ -372,21 +391,6 @@ class AstraDBVectorStoreComponent(LCVectorStoreComponent):
|
|||
|
||||
return None
|
||||
|
||||
def collection_exists(self):
|
||||
try:
|
||||
client = DataAPIClient(token=self.token, environment=self.environment)
|
||||
database = client.get_database(
|
||||
api_endpoint=self.get_api_endpoint(),
|
||||
token=self.token,
|
||||
keyspace=self.get_keyspace(),
|
||||
)
|
||||
|
||||
return self.collection_name in list(database.list_collection_names(keyspace=self.get_keyspace()))
|
||||
except Exception as e: # noqa: BLE001
|
||||
self.log(f"Error getting collection status: {e}")
|
||||
|
||||
return False
|
||||
|
||||
def collection_data(self, collection_name: str, database: Database | None = None):
|
||||
try:
|
||||
if not database:
|
||||
|
|
@ -436,15 +440,20 @@ class AstraDBVectorStoreComponent(LCVectorStoreComponent):
|
|||
def _initialize_database_options(self):
|
||||
try:
|
||||
return [
|
||||
{"name": name, "collections": info["collections"]} for name, info in self.get_database_list().items()
|
||||
{
|
||||
"name": name,
|
||||
"collections": info["collections"],
|
||||
"api_endpoint": info["api_endpoint"],
|
||||
}
|
||||
for name, info in self.get_database_list().items()
|
||||
]
|
||||
except Exception as e: # noqa: BLE001
|
||||
self.log(f"Error fetching databases: {e}")
|
||||
|
||||
return []
|
||||
|
||||
def _initialize_collection_options(self):
|
||||
database = self.get_database_object()
|
||||
def _initialize_collection_options(self, api_endpoint: str | None = None):
|
||||
database = self.get_database_object(api_endpoint=api_endpoint)
|
||||
if database is None:
|
||||
return []
|
||||
|
||||
|
|
@ -475,13 +484,12 @@ class AstraDBVectorStoreComponent(LCVectorStoreComponent):
|
|||
return []
|
||||
|
||||
def update_build_config(self, build_config: dict, field_value: str, field_name: str | None = None):
|
||||
# Define variables for common database conditions a user may experience
|
||||
is_hosted = os.getenv("LANGFLOW_HOST") is not None
|
||||
no_databases = "options" not in build_config["api_endpoint"] or not build_config["api_endpoint"]["options"]
|
||||
no_api_endpoint = not build_config["api_endpoint"]["value"]
|
||||
# TODO: Remove special astra flags when overlays are out
|
||||
# TODO: Better targeting of this field
|
||||
dslf = os.getenv("AWS_EXECUTION_ENV") == "AWS_ECS_FARGATE"
|
||||
|
||||
# Refresh the database name options
|
||||
if not is_hosted and (field_name in ["token", "environment"] or (no_databases and no_api_endpoint)):
|
||||
if not dslf and (field_name in ["token", "environment"] or not build_config["api_endpoint"]["options"]):
|
||||
# Get the list of options we have based on the token provided
|
||||
database_options = self._initialize_database_options()
|
||||
|
||||
|
|
@ -491,26 +499,14 @@ class AstraDBVectorStoreComponent(LCVectorStoreComponent):
|
|||
build_config["collection_name"]["value"] = ""
|
||||
|
||||
# Scenario #1: We have database options from the provided token
|
||||
if database_options:
|
||||
# Reset the selected database
|
||||
build_config["api_endpoint"]["name"] = "Database"
|
||||
build_config["api_endpoint"]["display_name"] = "Database"
|
||||
build_config["api_endpoint"]["value"] = ""
|
||||
build_config["api_endpoint"]["name"] = "Database"
|
||||
|
||||
# If we retrieved options based on the token, show the dropdown
|
||||
build_config["api_endpoint"]["options"] = [db["name"] for db in database_options]
|
||||
build_config["api_endpoint"]["options_metadata"] = [
|
||||
{k: v for k, v in db.items() if k not in ["name"]} for db in database_options
|
||||
]
|
||||
# Scenario #2: We have no options from the provided token
|
||||
else:
|
||||
# Fallback to an API Endpoint if we couldn't retrieve options
|
||||
build_config["api_endpoint"]["value"] = ""
|
||||
build_config["api_endpoint"]["name"] = "API Endpoint"
|
||||
build_config["api_endpoint"]["display_name"] = "Astra DB API Endpoint"
|
||||
|
||||
# If we didn't retrieve options based on the token, show the text input
|
||||
if "options" in build_config["api_endpoint"]:
|
||||
del build_config["api_endpoint"]["options"]
|
||||
# If we retrieved options based on the token, show the dropdown
|
||||
build_config["api_endpoint"]["options"] = [db["name"] for db in database_options]
|
||||
build_config["api_endpoint"]["options_metadata"] = [
|
||||
{k: v for k, v in db.items() if k not in ["name"]} for db in database_options
|
||||
]
|
||||
|
||||
# Get list of regions for a given cloud provider
|
||||
"""
|
||||
|
|
@ -530,8 +526,21 @@ class AstraDBVectorStoreComponent(LCVectorStoreComponent):
|
|||
# Reset the selected collection
|
||||
build_config["collection_name"]["value"] = ""
|
||||
|
||||
# Set the underlying api endpoint value of the database
|
||||
if field_value in build_config["api_endpoint"]["options"]:
|
||||
index_of_name = build_config["api_endpoint"]["options"].index(field_value)
|
||||
build_config["d_api_endpoint"]["value"] = build_config["api_endpoint"]["options_metadata"][
|
||||
index_of_name
|
||||
]["api_endpoint"]
|
||||
else:
|
||||
build_config["d_api_endpoint"]["value"] = ""
|
||||
|
||||
# Reload the list of collections and metadata associated
|
||||
collection_options = self._initialize_collection_options()
|
||||
collection_options = self._initialize_collection_options(
|
||||
api_endpoint=build_config["d_api_endpoint"]["value"]
|
||||
)
|
||||
|
||||
# If we have collections, show the dropdown
|
||||
build_config["collection_name"]["options"] = [col["name"] for col in collection_options]
|
||||
build_config["collection_name"]["options_metadata"] = [
|
||||
{k: v for k, v in col.items() if k not in ["name"]} for col in collection_options
|
||||
|
|
@ -540,14 +549,32 @@ class AstraDBVectorStoreComponent(LCVectorStoreComponent):
|
|||
# Hide embedding model option if opriona_metadata provider is not null
|
||||
if field_name == "collection_name" and field_value:
|
||||
# Set the options for collection name to be the field value if its a new collection
|
||||
if not is_hosted and field_value not in build_config["collection_name"]["options"]:
|
||||
if not dslf and field_value not in build_config["collection_name"]["options"]:
|
||||
# Add the new collection to the list of options
|
||||
build_config["collection_name"]["options"].append(field_value)
|
||||
build_config["collection_name"]["options_metadata"].append(
|
||||
{"records": 0, "provider": None, "icon": "", "model": None}
|
||||
)
|
||||
|
||||
# Ensure that autodetect collection is set to False
|
||||
build_config["autodetect_collection"]["value"] = False
|
||||
else:
|
||||
build_config["autodetect_collection"]["value"] = True
|
||||
|
||||
# Find location of the name in the options list
|
||||
index_of_name = build_config["collection_name"]["options"].index(field_value)
|
||||
|
||||
# Return if not found
|
||||
if index_of_name == -1:
|
||||
return build_config
|
||||
|
||||
# Check if the number of records is 0
|
||||
if build_config["collection_name"]["options_metadata"][index_of_name]["records"] == 0:
|
||||
build_config["autodetect_collection"]["value"] = False
|
||||
else:
|
||||
build_config["autodetect_collection"]["value"] = True
|
||||
|
||||
# Get the provider value of the selected collection
|
||||
value_of_provider = build_config["collection_name"]["options_metadata"][index_of_name]["provider"]
|
||||
|
||||
# If we were able to determine the Vectorize provider, set it accordingly
|
||||
|
|
@ -608,23 +635,29 @@ class AstraDBVectorStoreComponent(LCVectorStoreComponent):
|
|||
else {}
|
||||
)
|
||||
|
||||
# Get the additional parameters
|
||||
additional_params = self.astradb_vectorstore_kwargs or {}
|
||||
|
||||
# Get Langflow version and platform information
|
||||
__version__ = get_version_info()["version"]
|
||||
langflow_prefix = ""
|
||||
if os.getenv("LANGFLOW_HOST") is not None:
|
||||
if os.getenv("AWS_EXECUTION_ENV") == "AWS_ECS_FARGATE": # TODO: More precise way of detecting
|
||||
langflow_prefix = "ds-"
|
||||
|
||||
# Get the database object
|
||||
database = self.get_database_object(api_endpoint=self.d_api_endpoint)
|
||||
autodetect = self.collection_name in database.list_collection_names() and self.autodetect_collection
|
||||
|
||||
# Bundle up the auto-detect parameters
|
||||
autodetect_params = {
|
||||
"autodetect_collection": self.collection_exists(), # TODO: May want to expose this option
|
||||
"autodetect_collection": autodetect,
|
||||
"content_field": (
|
||||
self.content_field
|
||||
if self.content_field and embedding_params
|
||||
else (
|
||||
"page_content"
|
||||
if embedding_params and self.collection_data(collection_name=self.collection_name) == 0
|
||||
if embedding_params
|
||||
and self.collection_data(collection_name=self.collection_name, database=database) == 0
|
||||
else None
|
||||
)
|
||||
),
|
||||
|
|
@ -636,8 +669,8 @@ class AstraDBVectorStoreComponent(LCVectorStoreComponent):
|
|||
vector_store = AstraDBVectorStore(
|
||||
# Astra DB Authentication Parameters
|
||||
token=self.token,
|
||||
api_endpoint=self.get_api_endpoint(),
|
||||
namespace=self.get_keyspace(),
|
||||
api_endpoint=database.api_endpoint,
|
||||
namespace=database.keyspace,
|
||||
collection_name=self.collection_name,
|
||||
environment=self.environment,
|
||||
# Astra DB Usage Tracking Parameters
|
||||
|
|
@ -651,6 +684,7 @@ class AstraDBVectorStoreComponent(LCVectorStoreComponent):
|
|||
msg = f"Error initializing AstraDBVectorStore: {e}"
|
||||
raise ValueError(msg) from e
|
||||
|
||||
# Add documents to the vector store
|
||||
self._add_documents_to_vector_store(vector_store)
|
||||
|
||||
return vector_store
|
||||
|
|
@ -667,8 +701,8 @@ class AstraDBVectorStoreComponent(LCVectorStoreComponent):
|
|||
if documents and self.deletion_field:
|
||||
self.log(f"Deleting documents where {self.deletion_field}")
|
||||
try:
|
||||
database = self.get_database_object()
|
||||
collection = database.get_collection(self.collection_name, keyspace=self.get_keyspace())
|
||||
database = self.get_database_object(api_endpoint=self.d_api_endpoint)
|
||||
collection = database.get_collection(self.collection_name, keyspace=database.keyspace)
|
||||
delete_values = list({doc.metadata[self.deletion_field] for doc in documents})
|
||||
self.log(f"Deleting documents where {self.deletion_field} matches {delete_values}.")
|
||||
collection.delete_many({f"metadata.{self.deletion_field}": {"$in": delete_values}})
|
||||
|
|
|
|||
File diff suppressed because one or more lines are too long
Loading…
Add table
Add a link
Reference in a new issue