🐛 fix(vector_store.py): build Chroma settings if any of the chroma_server_ params are present in params
✨ feat(vectorstores.py): add new fields for Chroma vector store configuration
This commit is contained in:
parent
520bbc35b0
commit
35ec2e0867
2 changed files with 127 additions and 51 deletions
|
|
@ -170,6 +170,29 @@ def initialize_pinecone(class_object: Type[Pinecone], params: dict):
|
|||
|
||||
def initialize_chroma(class_object: Type[Chroma], params: dict):
|
||||
"""Initialize a ChromaDB object from the params"""
|
||||
# chroma_server_host: str | None = None,
|
||||
# chroma_server_http_port: str | None = None,
|
||||
# chroma_server_ssl_enabled: bool | None = False,
|
||||
# chroma_server_grpc_port: str | None = None,
|
||||
# chroma_server_cors_allow_origins: List[str] = [],
|
||||
# If any of the above params are in params, specially host and port,
|
||||
# we need to build the Chroma settings
|
||||
if ( # type: ignore
|
||||
"chroma_server_host" in params
|
||||
or "chroma_server_http_port" in params
|
||||
or "chroma_server_ssl_enabled" in params
|
||||
or "chroma_server_grpc_port" in params
|
||||
or "chroma_server_cors_allow_origins" in params
|
||||
):
|
||||
import chromadb
|
||||
|
||||
settings_params = {
|
||||
key: params[key]
|
||||
for key, value_ in params.items()
|
||||
if key.startswith("chroma_server_") and value_
|
||||
}
|
||||
chroma_settings = chromadb.config.Settings(**settings_params)
|
||||
params["client_settings"] = chroma_settings
|
||||
persist = params.pop("persist", False)
|
||||
if not docs_in_params(params):
|
||||
params.pop("documents", None)
|
||||
|
|
|
|||
|
|
@ -4,6 +4,52 @@ from langflow.template.field.base import TemplateField
|
|||
from langflow.template.frontend_node.base import FrontendNode
|
||||
|
||||
|
||||
BASIC_FIELDS = [
|
||||
"work_dir",
|
||||
"collection_name",
|
||||
"api_key",
|
||||
"location",
|
||||
"persist_directory",
|
||||
"persist",
|
||||
"weaviate_url",
|
||||
"index_name",
|
||||
"namespace",
|
||||
"folder_path",
|
||||
"table_name",
|
||||
"query_name",
|
||||
"supabase_url",
|
||||
"supabase_service_key",
|
||||
"mongodb_atlas_cluster_uri",
|
||||
"collection_name",
|
||||
"db_name",
|
||||
]
|
||||
ADVANCED_FIELDS = [
|
||||
"n_dim",
|
||||
"key",
|
||||
"prefix",
|
||||
"distance_func",
|
||||
"content_payload_key",
|
||||
"metadata_payload_key",
|
||||
"timeout",
|
||||
"host",
|
||||
"path",
|
||||
"url",
|
||||
"port",
|
||||
"https",
|
||||
"prefer_grpc",
|
||||
"grpc_port",
|
||||
"pinecone_api_key",
|
||||
"pinecone_env",
|
||||
"client_kwargs",
|
||||
"search_kwargs",
|
||||
"chroma_server_host",
|
||||
"chroma_server_http_port",
|
||||
"chroma_server_ssl_enabled",
|
||||
"chroma_server_grpc_port",
|
||||
"chroma_server_cors_allow_origins",
|
||||
]
|
||||
|
||||
|
||||
class VectorStoreFrontendNode(FrontendNode):
|
||||
def add_extra_fields(self) -> None:
|
||||
extra_fields: List[TemplateField] = []
|
||||
|
|
@ -45,16 +91,62 @@ class VectorStoreFrontendNode(FrontendNode):
|
|||
|
||||
elif self.template.type_name == "Chroma":
|
||||
# New bool field for persist parameter
|
||||
extra_field = TemplateField(
|
||||
name="persist",
|
||||
field_type="bool",
|
||||
required=False,
|
||||
show=True,
|
||||
advanced=False,
|
||||
value=False,
|
||||
display_name="Persist",
|
||||
)
|
||||
extra_fields.append(extra_field)
|
||||
chroma_fields = [
|
||||
TemplateField(
|
||||
name="persist",
|
||||
field_type="bool",
|
||||
required=False,
|
||||
show=True,
|
||||
advanced=False,
|
||||
value=False,
|
||||
display_name="Persist",
|
||||
),
|
||||
# chroma_server_grpc_port: str | None = None,
|
||||
TemplateField(
|
||||
name="chroma_server_host",
|
||||
field_type="str",
|
||||
required=False,
|
||||
show=True,
|
||||
advanced=True,
|
||||
display_name="Chroma Server Host",
|
||||
),
|
||||
TemplateField(
|
||||
name="chroma_server_http_port",
|
||||
field_type="str",
|
||||
required=False,
|
||||
show=True,
|
||||
advanced=True,
|
||||
display_name="Chroma Server HTTP Port",
|
||||
),
|
||||
TemplateField(
|
||||
name="chroma_server_ssl_enabled",
|
||||
field_type="bool",
|
||||
required=False,
|
||||
show=True,
|
||||
advanced=True,
|
||||
value=False,
|
||||
display_name="Chroma Server SSL Enabled",
|
||||
),
|
||||
TemplateField(
|
||||
name="chroma_server_grpc_port",
|
||||
field_type="str",
|
||||
required=False,
|
||||
show=True,
|
||||
advanced=True,
|
||||
display_name="Chroma Server GRPC Port",
|
||||
),
|
||||
TemplateField(
|
||||
name="chroma_server_cors_allow_origins",
|
||||
field_type="str",
|
||||
required=False,
|
||||
is_list=True,
|
||||
show=True,
|
||||
advanced=True,
|
||||
display_name="Chroma Server CORS Allow Origins",
|
||||
),
|
||||
]
|
||||
|
||||
extra_fields.extend(chroma_fields)
|
||||
elif self.template.type_name == "Pinecone":
|
||||
# add pinecone_api_key and pinecone_env
|
||||
extra_field = TemplateField(
|
||||
|
|
@ -208,45 +300,6 @@ class VectorStoreFrontendNode(FrontendNode):
|
|||
def format_field(field: TemplateField, name: Optional[str] = None) -> None:
|
||||
FrontendNode.format_field(field, name)
|
||||
# Define common field attributes
|
||||
basic_fields = [
|
||||
"work_dir",
|
||||
"collection_name",
|
||||
"api_key",
|
||||
"location",
|
||||
"persist_directory",
|
||||
"persist",
|
||||
"weaviate_url",
|
||||
"index_name",
|
||||
"namespace",
|
||||
"folder_path",
|
||||
"table_name",
|
||||
"query_name",
|
||||
"supabase_url",
|
||||
"supabase_service_key",
|
||||
"mongodb_atlas_cluster_uri",
|
||||
"collection_name",
|
||||
"db_name",
|
||||
]
|
||||
advanced_fields = [
|
||||
"n_dim",
|
||||
"key",
|
||||
"prefix",
|
||||
"distance_func",
|
||||
"content_payload_key",
|
||||
"metadata_payload_key",
|
||||
"timeout",
|
||||
"host",
|
||||
"path",
|
||||
"url",
|
||||
"port",
|
||||
"https",
|
||||
"prefer_grpc",
|
||||
"grpc_port",
|
||||
"pinecone_api_key",
|
||||
"pinecone_env",
|
||||
"client_kwargs",
|
||||
"search_kwargs",
|
||||
]
|
||||
|
||||
# Check and set field attributes
|
||||
if field.name == "texts":
|
||||
|
|
@ -269,7 +322,7 @@ class VectorStoreFrontendNode(FrontendNode):
|
|||
field.display_name = "Embedding"
|
||||
field.field_type = "Embeddings"
|
||||
|
||||
elif field.name in basic_fields:
|
||||
elif field.name in BASIC_FIELDS:
|
||||
field.show = True
|
||||
field.advanced = False
|
||||
if field.name == "api_key":
|
||||
|
|
@ -279,7 +332,7 @@ class VectorStoreFrontendNode(FrontendNode):
|
|||
field.value = ":memory:"
|
||||
field.placeholder = ":memory:"
|
||||
|
||||
elif field.name in advanced_fields:
|
||||
elif field.name in ADVANCED_FIELDS:
|
||||
field.show = True
|
||||
field.advanced = True
|
||||
if "key" in field.name:
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue