Track caller versions in Astra DB, fix dynamic inputs (#5016)

This commit is contained in:
Eric Hare 2024-12-03 06:16:09 -08:00 committed by GitHub
commit 0dc37bb98e
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
2 changed files with 47 additions and 65 deletions

View file

@ -21,6 +21,7 @@ from langflow.io import (
StrInput,
)
from langflow.schema import Data
from langflow.utils.version import get_version_info
class AstraDBVectorStoreComponent(LCVectorStoreComponent):
@ -98,6 +99,7 @@ class AstraDBVectorStoreComponent(LCVectorStoreComponent):
value="ASTRA_DB_APPLICATION_TOKEN",
required=True,
advanced=os.getenv("ASTRA_ENHANCED", "false").lower() == "true",
real_time_refresh=True,
),
SecretStrInput(
name="api_endpoint",
@ -105,14 +107,15 @@ class AstraDBVectorStoreComponent(LCVectorStoreComponent):
info="API endpoint URL for the Astra DB service.",
value="ASTRA_DB_API_ENDPOINT",
required=True,
real_time_refresh=True,
),
DropdownInput(
name="collection_name",
display_name="Collection",
info="The name of the collection within Astra DB where the vectors will be stored.",
required=True,
real_time_refresh=True,
refresh_button=True,
real_time_refresh=True,
options=["+ Create new collection"],
value="+ Create new collection",
),
@ -318,6 +321,13 @@ class AstraDBVectorStoreComponent(LCVectorStoreComponent):
return self.VECTORIZE_PROVIDERS_MAPPING
def get_collection_choice(self):
collection_name = self.collection_name
if collection_name == "+ Create new collection":
return self.collection_name_new
return collection_name
def get_collection_options(self):
client = DataAPIClient(token=self.token)
@ -326,7 +336,7 @@ class AstraDBVectorStoreComponent(LCVectorStoreComponent):
token=self.token,
)
collection = database.get_collection(self.collection_name)
collection = database.get_collection(self.get_collection_choice())
# Only get the options if the collection exists
try:
@ -342,23 +352,19 @@ class AstraDBVectorStoreComponent(LCVectorStoreComponent):
# Refresh the collection name options
build_config["collection_name"]["options"] = self._initialize_collection_options()
# If the collection name is set to "+ Create new collection", show the advanced options
# If the collection name is set to "+ Create new collection", show embedding choice
if field_name == "collection_name" and field_value == "+ Create new collection":
build_config["embedding_choice"]["advanced"] = False
build_config["embedding_choice"]["value"] = "Embedding Model"
build_config["embedding_model"]["advanced"] = False
build_config["collection_name_new"]["advanced"] = False
build_config["collection_name_new"]["required"] = True
new_parameter = HandleInput(
name="embedding_model",
display_name="Embedding Model",
input_types=["Embeddings"],
info="Allows an embedding model configuration.",
).to_dict()
self.insert_in_dict(build_config, "collection_name_new", {"embedding_model": new_parameter})
# But if it's not, hide embedding choice
elif field_name == "collection_name" and field_value != "+ Create new collection":
build_config["embedding_choice"]["advanced"] = True
build_config["collection_name_new"]["advanced"] = True
build_config["collection_name_new"]["required"] = False
build_config["collection_name_new"]["value"] = ""
@ -366,61 +372,33 @@ class AstraDBVectorStoreComponent(LCVectorStoreComponent):
# Get the collection options
collection_options = self.get_collection_options()
# If the collection options are available, show the advanced options
# If the collection options are available (DB exists), show the advanced options
if collection_options:
build_config["embedding_choice"]["advanced"] = True
if collection_options.service:
for input_field in [
"embedding_provider",
"z_01_model_parameters",
"z_02_api_key_name",
"z_03_provider_api_key",
"z_04_authentication",
]:
build_config[input_field]["advanced"] = False
self.del_fields(
build_config,
[
"embedding_provider",
"model",
"z_01_model_parameters",
"z_02_api_key_name",
"z_03_provider_api_key",
"z_04_authentication",
],
)
build_config["embedding_model"]["advanced"] = True
build_config["embedding_provider"]["advanced"] = True
build_config["embedding_choice"]["value"] = "Astra Vectorize"
build_config["embedding_provider"]["value"] = collection_options.service.provider
build_config["model"]["value"] = collection_options.service.model_name
build_config["z_01_model_parameters"]["value"] = collection_options.service.parameters
if collection_options.service.authentication:
build_config["z_02_api_key_name"]["value"] = collection_options.service.authentication.get(
"providerKey"
)
build_config["z_03_provider_api_key"]["value"] = collection_options.service.authentication.get(
"apiKey"
)
build_config["z_04_authentication"]["value"] = collection_options.service.authentication
else:
for input_field in [
"z_01_model_parameters",
"z_02_api_key_name",
"z_03_provider_api_key",
"z_04_authentication",
]:
build_config[input_field]["advanced"] = True
build_config["embedding_model"]["advanced"] = False
build_config["embedding_provider"]["advanced"] = False
build_config["embedding_choice"]["value"] = "Embedding Model"
new_parameter = HandleInput(
name="embedding_model",
display_name="Embedding Model",
input_types=["Embeddings"],
info="Allows an embedding model configuration.",
).to_dict()
self.insert_in_dict(build_config, "embedding_choice", {"embedding_model": new_parameter})
elif field_name == "embedding_choice":
if field_value == "Astra Vectorize":
self.del_fields(build_config, ["embedding_model"])
build_config["embedding_model"]["advanced"] = True
# Update the providers mapping
vectorize_providers = self.update_providers_mapping()
@ -436,6 +414,8 @@ class AstraDBVectorStoreComponent(LCVectorStoreComponent):
self.insert_in_dict(build_config, "embedding_choice", {"embedding_provider": new_parameter})
else:
build_config["embedding_model"]["advanced"] = False
self.del_fields(
build_config,
[
@ -448,15 +428,6 @@ class AstraDBVectorStoreComponent(LCVectorStoreComponent):
],
)
new_parameter = HandleInput(
name="embedding_model",
display_name="Embedding Model",
input_types=["Embeddings"],
info="Allows an embedding model configuration.",
).to_dict()
self.insert_in_dict(build_config, "embedding_choice", {"embedding_model": new_parameter})
elif field_name == "embedding_provider":
self.del_fields(
build_config,
@ -615,7 +586,9 @@ class AstraDBVectorStoreComponent(LCVectorStoreComponent):
# Ensure collection_options and its nested attributes are handled safely
authentication = getattr(self, "z_04_authentication", {}) or (
collection_options.service.authentication if collection_options and collection_options.service else {}
collection_options.service.authentication
if collection_options and collection_options.service and collection_options.service.authentication
else {}
)
# Build the vectorize options dictionary
@ -663,12 +636,18 @@ class AstraDBVectorStoreComponent(LCVectorStoreComponent):
"collection_embedding_api_key": dict_options.get("collection_embedding_api_key"),
}
# Get Langflow version and platform information
__version__ = get_version_info()["version"]
langflow_prefix = ""
if os.getenv("ASTRA_ENHANCED", "false").lower() == "true":
langflow_prefix = "ds-"
try:
vector_store = AstraDBVectorStore(
token=self.token,
api_endpoint=self.api_endpoint,
namespace=self.keyspace or None,
collection_name=getattr(self, "collection_name_new", None) or self.collection_name,
collection_name=self.get_collection_choice(),
autodetect_collection=autodetect,
environment=(
parse_api_endpoint(getattr(self, "api_endpoint", None)).environment
@ -687,6 +666,7 @@ class AstraDBVectorStoreComponent(LCVectorStoreComponent):
collection_indexing_policy=orjson.dumps(self.collection_indexing_policy)
if self.collection_indexing_policy
else None,
ext_callers=[(f"{langflow_prefix}langflow", __version__)],
**embedding_dict,
)
except Exception as e:

File diff suppressed because one or more lines are too long