fix: allow dynamic updating when hosted anywhere (#5999)

* fix: allow dynamic updating when hosted anywhere

* [autofix.ci] apply automated fixes

* Continue speed improvements

* [autofix.ci] apply automated fixes

* [autofix.ci] apply automated fixes (attempt 2/3)

* Fix autodetect with new collection

* [autofix.ci] apply automated fixes

* [autofix.ci] apply automated fixes (attempt 2/3)

---------

Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com>
Co-authored-by: Gabriel Luiz Freitas Almeida <gabriel@langflow.org>
This commit is contained in:
Eric Hare 2025-01-30 04:38:56 -08:00 committed by GitHub
commit c007a5fffc
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
2 changed files with 164 additions and 56 deletions

View file

@ -132,6 +132,12 @@ class AstraDBVectorStoreComponent(LCVectorStoreComponent):
real_time_refresh=True,
combobox=True,
),
StrInput(
name="d_api_endpoint",
display_name="Database API Endpoint",
info="The API Endpoint for the Astra DB instance. Supercedes database selection.",
advanced=True,
),
DropdownInput(
name="collection_name",
display_name="Collection",
@ -194,6 +200,13 @@ class AstraDBVectorStoreComponent(LCVectorStoreComponent):
info="Optional dictionary of filters to apply to the search query.",
advanced=True,
),
BoolInput(
name="autodetect_collection",
display_name="Autodetect Collection",
info="Boolean flag to determine whether to autodetect the collection.",
advanced=True,
value=True,
),
StrInput(
name="content_field",
display_name="Content Field",
@ -330,8 +343,13 @@ class AstraDBVectorStoreComponent(LCVectorStoreComponent):
cls,
token: str,
environment: str | None = None,
api_endpoint: str | None = None,
database_name: str | None = None,
):
# If the api_endpoint is set, return it
if api_endpoint:
return api_endpoint
# Check if the database_name is like a url
if database_name and database_name.startswith("https://"):
return database_name
@ -343,10 +361,11 @@ class AstraDBVectorStoreComponent(LCVectorStoreComponent):
# Otherwise, get the URL from the database list
return cls.get_database_list_static(token=token, environment=environment).get(database_name).get("api_endpoint")
def get_api_endpoint(self):
def get_api_endpoint(self, *, api_endpoint: str | None = None):
return self.get_api_endpoint_static(
token=self.token,
environment=self.environment,
api_endpoint=api_endpoint or self.d_api_endpoint,
database_name=self.api_endpoint,
)
@ -358,12 +377,12 @@ class AstraDBVectorStoreComponent(LCVectorStoreComponent):
return None
def get_database_object(self):
def get_database_object(self, api_endpoint: str | None = None):
try:
client = DataAPIClient(token=self.token, environment=self.environment)
return client.get_database(
api_endpoint=self.get_api_endpoint(),
api_endpoint=self.get_api_endpoint(api_endpoint=api_endpoint),
token=self.token,
keyspace=self.get_keyspace(),
)
@ -372,21 +391,6 @@ class AstraDBVectorStoreComponent(LCVectorStoreComponent):
return None
def collection_exists(self):
try:
client = DataAPIClient(token=self.token, environment=self.environment)
database = client.get_database(
api_endpoint=self.get_api_endpoint(),
token=self.token,
keyspace=self.get_keyspace(),
)
return self.collection_name in list(database.list_collection_names(keyspace=self.get_keyspace()))
except Exception as e: # noqa: BLE001
self.log(f"Error getting collection status: {e}")
return False
def collection_data(self, collection_name: str, database: Database | None = None):
try:
if not database:
@ -436,15 +440,20 @@ class AstraDBVectorStoreComponent(LCVectorStoreComponent):
def _initialize_database_options(self):
try:
return [
{"name": name, "collections": info["collections"]} for name, info in self.get_database_list().items()
{
"name": name,
"collections": info["collections"],
"api_endpoint": info["api_endpoint"],
}
for name, info in self.get_database_list().items()
]
except Exception as e: # noqa: BLE001
self.log(f"Error fetching databases: {e}")
return []
def _initialize_collection_options(self):
database = self.get_database_object()
def _initialize_collection_options(self, api_endpoint: str | None = None):
database = self.get_database_object(api_endpoint=api_endpoint)
if database is None:
return []
@ -475,13 +484,12 @@ class AstraDBVectorStoreComponent(LCVectorStoreComponent):
return []
def update_build_config(self, build_config: dict, field_value: str, field_name: str | None = None):
# Define variables for common database conditions a user may experience
is_hosted = os.getenv("LANGFLOW_HOST") is not None
no_databases = "options" not in build_config["api_endpoint"] or not build_config["api_endpoint"]["options"]
no_api_endpoint = not build_config["api_endpoint"]["value"]
# TODO: Remove special astra flags when overlays are out
# TODO: Better targeting of this field
dslf = os.getenv("AWS_EXECUTION_ENV") == "AWS_ECS_FARGATE"
# Refresh the database name options
if not is_hosted and (field_name in ["token", "environment"] or (no_databases and no_api_endpoint)):
if not dslf and (field_name in ["token", "environment"] or not build_config["api_endpoint"]["options"]):
# Get the list of options we have based on the token provided
database_options = self._initialize_database_options()
@ -491,26 +499,14 @@ class AstraDBVectorStoreComponent(LCVectorStoreComponent):
build_config["collection_name"]["value"] = ""
# Scenario #1: We have database options from the provided token
if database_options:
# Reset the selected database
build_config["api_endpoint"]["name"] = "Database"
build_config["api_endpoint"]["display_name"] = "Database"
build_config["api_endpoint"]["value"] = ""
build_config["api_endpoint"]["name"] = "Database"
# If we retrieved options based on the token, show the dropdown
build_config["api_endpoint"]["options"] = [db["name"] for db in database_options]
build_config["api_endpoint"]["options_metadata"] = [
{k: v for k, v in db.items() if k not in ["name"]} for db in database_options
]
# Scenario #2: We have no options from the provided token
else:
# Fallback to an API Endpoint if we couldn't retrieve options
build_config["api_endpoint"]["value"] = ""
build_config["api_endpoint"]["name"] = "API Endpoint"
build_config["api_endpoint"]["display_name"] = "Astra DB API Endpoint"
# If we didn't retrieve options based on the token, show the text input
if "options" in build_config["api_endpoint"]:
del build_config["api_endpoint"]["options"]
# If we retrieved options based on the token, show the dropdown
build_config["api_endpoint"]["options"] = [db["name"] for db in database_options]
build_config["api_endpoint"]["options_metadata"] = [
{k: v for k, v in db.items() if k not in ["name"]} for db in database_options
]
# Get list of regions for a given cloud provider
"""
@ -530,8 +526,21 @@ class AstraDBVectorStoreComponent(LCVectorStoreComponent):
# Reset the selected collection
build_config["collection_name"]["value"] = ""
# Set the underlying api endpoint value of the database
if field_value in build_config["api_endpoint"]["options"]:
index_of_name = build_config["api_endpoint"]["options"].index(field_value)
build_config["d_api_endpoint"]["value"] = build_config["api_endpoint"]["options_metadata"][
index_of_name
]["api_endpoint"]
else:
build_config["d_api_endpoint"]["value"] = ""
# Reload the list of collections and metadata associated
collection_options = self._initialize_collection_options()
collection_options = self._initialize_collection_options(
api_endpoint=build_config["d_api_endpoint"]["value"]
)
# If we have collections, show the dropdown
build_config["collection_name"]["options"] = [col["name"] for col in collection_options]
build_config["collection_name"]["options_metadata"] = [
{k: v for k, v in col.items() if k not in ["name"]} for col in collection_options
@ -540,14 +549,32 @@ class AstraDBVectorStoreComponent(LCVectorStoreComponent):
# Hide embedding model option if the options_metadata provider is not null
if field_name == "collection_name" and field_value:
# Set the options for collection name to be the field value if it's a new collection
if not is_hosted and field_value not in build_config["collection_name"]["options"]:
if not dslf and field_value not in build_config["collection_name"]["options"]:
# Add the new collection to the list of options
build_config["collection_name"]["options"].append(field_value)
build_config["collection_name"]["options_metadata"].append(
{"records": 0, "provider": None, "icon": "", "model": None}
)
# Ensure that autodetect collection is set to False
build_config["autodetect_collection"]["value"] = False
else:
build_config["autodetect_collection"]["value"] = True
# Find location of the name in the options list
index_of_name = build_config["collection_name"]["options"].index(field_value)
# Return if not found
if index_of_name == -1:
return build_config
# Check if the number of records is 0
if build_config["collection_name"]["options_metadata"][index_of_name]["records"] == 0:
build_config["autodetect_collection"]["value"] = False
else:
build_config["autodetect_collection"]["value"] = True
# Get the provider value of the selected collection
value_of_provider = build_config["collection_name"]["options_metadata"][index_of_name]["provider"]
# If we were able to determine the Vectorize provider, set it accordingly
@ -608,23 +635,29 @@ class AstraDBVectorStoreComponent(LCVectorStoreComponent):
else {}
)
# Get the additional parameters
additional_params = self.astradb_vectorstore_kwargs or {}
# Get Langflow version and platform information
__version__ = get_version_info()["version"]
langflow_prefix = ""
if os.getenv("LANGFLOW_HOST") is not None:
if os.getenv("AWS_EXECUTION_ENV") == "AWS_ECS_FARGATE": # TODO: More precise way of detecting
langflow_prefix = "ds-"
# Get the database object
database = self.get_database_object(api_endpoint=self.d_api_endpoint)
autodetect = self.collection_name in database.list_collection_names() and self.autodetect_collection
# Bundle up the auto-detect parameters
autodetect_params = {
"autodetect_collection": self.collection_exists(), # TODO: May want to expose this option
"autodetect_collection": autodetect,
"content_field": (
self.content_field
if self.content_field and embedding_params
else (
"page_content"
if embedding_params and self.collection_data(collection_name=self.collection_name) == 0
if embedding_params
and self.collection_data(collection_name=self.collection_name, database=database) == 0
else None
)
),
@ -636,8 +669,8 @@ class AstraDBVectorStoreComponent(LCVectorStoreComponent):
vector_store = AstraDBVectorStore(
# Astra DB Authentication Parameters
token=self.token,
api_endpoint=self.get_api_endpoint(),
namespace=self.get_keyspace(),
api_endpoint=database.api_endpoint,
namespace=database.keyspace,
collection_name=self.collection_name,
environment=self.environment,
# Astra DB Usage Tracking Parameters
@ -651,6 +684,7 @@ class AstraDBVectorStoreComponent(LCVectorStoreComponent):
msg = f"Error initializing AstraDBVectorStore: {e}"
raise ValueError(msg) from e
# Add documents to the vector store
self._add_documents_to_vector_store(vector_store)
return vector_store
@ -667,8 +701,8 @@ class AstraDBVectorStoreComponent(LCVectorStoreComponent):
if documents and self.deletion_field:
self.log(f"Deleting documents where {self.deletion_field}")
try:
database = self.get_database_object()
collection = database.get_collection(self.collection_name, keyspace=self.get_keyspace())
database = self.get_database_object(api_endpoint=self.d_api_endpoint)
collection = database.get_collection(self.collection_name, keyspace=database.keyspace)
delete_values = list({doc.metadata[self.deletion_field] for doc in documents})
self.log(f"Deleting documents where {self.deletion_field} matches {delete_values}.")
collection.delete_many({f"metadata.{self.deletion_field}": {"$in": delete_values}})

File diff suppressed because one or more lines are too long