🐛 fix(vector_store.py): build Chroma settings if any of the chroma_server_ params are present in params

 feat(vectorstores.py): add new fields for Chroma vector store configuration
This commit is contained in:
Gabriel Luiz Freitas Almeida 2023-08-04 16:56:36 -03:00
commit 35ec2e0867
2 changed files with 127 additions and 51 deletions

View file

@ -170,6 +170,29 @@ def initialize_pinecone(class_object: Type[Pinecone], params: dict):
def initialize_chroma(class_object: Type[Chroma], params: dict):
"""Initialize a ChromaDB object from the params"""
# chroma_server_host: str | None = None,
# chroma_server_http_port: str | None = None,
# chroma_server_ssl_enabled: bool | None = False,
# chroma_server_grpc_port: str | None = None,
# chroma_server_cors_allow_origins: List[str] = [],
# If any of the above params are in params, specially host and port,
# we need to build the Chroma settings
if ( # type: ignore
"chroma_server_host" in params
or "chroma_server_http_port" in params
or "chroma_server_ssl_enabled" in params
or "chroma_server_grpc_port" in params
or "chroma_server_cors_allow_origins" in params
):
import chromadb
settings_params = {
key: params[key]
for key, value_ in params.items()
if key.startswith("chroma_server_") and value_
}
chroma_settings = chromadb.config.Settings(**settings_params)
params["client_settings"] = chroma_settings
persist = params.pop("persist", False)
if not docs_in_params(params):
params.pop("documents", None)

View file

@ -4,6 +4,52 @@ from langflow.template.field.base import TemplateField
from langflow.template.frontend_node.base import FrontendNode
BASIC_FIELDS = [
"work_dir",
"collection_name",
"api_key",
"location",
"persist_directory",
"persist",
"weaviate_url",
"index_name",
"namespace",
"folder_path",
"table_name",
"query_name",
"supabase_url",
"supabase_service_key",
"mongodb_atlas_cluster_uri",
"collection_name",
"db_name",
]
ADVANCED_FIELDS = [
"n_dim",
"key",
"prefix",
"distance_func",
"content_payload_key",
"metadata_payload_key",
"timeout",
"host",
"path",
"url",
"port",
"https",
"prefer_grpc",
"grpc_port",
"pinecone_api_key",
"pinecone_env",
"client_kwargs",
"search_kwargs",
"chroma_server_host",
"chroma_server_http_port",
"chroma_server_ssl_enabled",
"chroma_server_grpc_port",
"chroma_server_cors_allow_origins",
]
class VectorStoreFrontendNode(FrontendNode):
def add_extra_fields(self) -> None:
extra_fields: List[TemplateField] = []
@ -45,16 +91,62 @@ class VectorStoreFrontendNode(FrontendNode):
elif self.template.type_name == "Chroma":
# New bool field for persist parameter
extra_field = TemplateField(
name="persist",
field_type="bool",
required=False,
show=True,
advanced=False,
value=False,
display_name="Persist",
)
extra_fields.append(extra_field)
chroma_fields = [
TemplateField(
name="persist",
field_type="bool",
required=False,
show=True,
advanced=False,
value=False,
display_name="Persist",
),
# chroma_server_grpc_port: str | None = None,
TemplateField(
name="chroma_server_host",
field_type="str",
required=False,
show=True,
advanced=True,
display_name="Chroma Server Host",
),
TemplateField(
name="chroma_server_http_port",
field_type="str",
required=False,
show=True,
advanced=True,
display_name="Chroma Server HTTP Port",
),
TemplateField(
name="chroma_server_ssl_enabled",
field_type="bool",
required=False,
show=True,
advanced=True,
value=False,
display_name="Chroma Server SSL Enabled",
),
TemplateField(
name="chroma_server_grpc_port",
field_type="str",
required=False,
show=True,
advanced=True,
display_name="Chroma Server GRPC Port",
),
TemplateField(
name="chroma_server_cors_allow_origins",
field_type="str",
required=False,
is_list=True,
show=True,
advanced=True,
display_name="Chroma Server CORS Allow Origins",
),
]
extra_fields.extend(chroma_fields)
elif self.template.type_name == "Pinecone":
# add pinecone_api_key and pinecone_env
extra_field = TemplateField(
@ -208,45 +300,6 @@ class VectorStoreFrontendNode(FrontendNode):
def format_field(field: TemplateField, name: Optional[str] = None) -> None:
FrontendNode.format_field(field, name)
# Define common field attributes
basic_fields = [
"work_dir",
"collection_name",
"api_key",
"location",
"persist_directory",
"persist",
"weaviate_url",
"index_name",
"namespace",
"folder_path",
"table_name",
"query_name",
"supabase_url",
"supabase_service_key",
"mongodb_atlas_cluster_uri",
"collection_name",
"db_name",
]
advanced_fields = [
"n_dim",
"key",
"prefix",
"distance_func",
"content_payload_key",
"metadata_payload_key",
"timeout",
"host",
"path",
"url",
"port",
"https",
"prefer_grpc",
"grpc_port",
"pinecone_api_key",
"pinecone_env",
"client_kwargs",
"search_kwargs",
]
# Check and set field attributes
if field.name == "texts":
@ -269,7 +322,7 @@ class VectorStoreFrontendNode(FrontendNode):
field.display_name = "Embedding"
field.field_type = "Embeddings"
elif field.name in basic_fields:
elif field.name in BASIC_FIELDS:
field.show = True
field.advanced = False
if field.name == "api_key":
@ -279,7 +332,7 @@ class VectorStoreFrontendNode(FrontendNode):
field.value = ":memory:"
field.placeholder = ":memory:"
elif field.name in advanced_fields:
elif field.name in ADVANCED_FIELDS:
field.show = True
field.advanced = True
if "key" in field.name: