From 35ec2e086709b5ebf4bb338ddb1a5d6aee88520e Mon Sep 17 00:00:00 2001 From: Gabriel Luiz Freitas Almeida Date: Fri, 4 Aug 2023 16:56:36 -0300 Subject: [PATCH] =?UTF-8?q?=F0=9F=90=9B=20fix(vector=5Fstore.py):=20build?= =?UTF-8?q?=20Chroma=20settings=20if=20any=20of=20the=20chroma=5Fserver=5F?= =?UTF-8?q?=20params=20are=20present=20in=20params=20=E2=9C=A8=20feat(vect?= =?UTF-8?q?orstores.py):=20add=20new=20fields=20for=20Chroma=20vector=20st?= =?UTF-8?q?ore=20configuration?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- .../interface/initialize/vector_store.py | 23 +++ .../template/frontend_node/vectorstores.py | 155 ++++++++++++------ 2 files changed, 127 insertions(+), 51 deletions(-) diff --git a/src/backend/langflow/interface/initialize/vector_store.py b/src/backend/langflow/interface/initialize/vector_store.py index d4bdb0155..c616d9b87 100644 --- a/src/backend/langflow/interface/initialize/vector_store.py +++ b/src/backend/langflow/interface/initialize/vector_store.py @@ -170,6 +170,29 @@ def initialize_pinecone(class_object: Type[Pinecone], params: dict): def initialize_chroma(class_object: Type[Chroma], params: dict): """Initialize a ChromaDB object from the params""" + # chroma_server_host: str | None = None, + # chroma_server_http_port: str | None = None, + # chroma_server_ssl_enabled: bool | None = False, + # chroma_server_grpc_port: str | None = None, + # chroma_server_cors_allow_origins: List[str] = [], + # If any of the above params are in params, specially host and port, + # we need to build the Chroma settings + if ( # type: ignore + "chroma_server_host" in params + or "chroma_server_http_port" in params + or "chroma_server_ssl_enabled" in params + or "chroma_server_grpc_port" in params + or "chroma_server_cors_allow_origins" in params + ): + import chromadb + + settings_params = { + key: params[key] + for key, value_ in params.items() + if key.startswith("chroma_server_") and value_ + } + chroma_settings = chromadb.config.Settings(**settings_params) + params["client_settings"] = chroma_settings persist = params.pop("persist", False) if not docs_in_params(params): params.pop("documents", None) diff --git a/src/backend/langflow/template/frontend_node/vectorstores.py b/src/backend/langflow/template/frontend_node/vectorstores.py index 53a840b80..23c293437 100644 --- a/src/backend/langflow/template/frontend_node/vectorstores.py +++ b/src/backend/langflow/template/frontend_node/vectorstores.py @@ -4,6 +4,52 @@ from langflow.template.field.base import TemplateField from langflow.template.frontend_node.base import FrontendNode +BASIC_FIELDS = [ + "work_dir", + "collection_name", + "api_key", + "location", + "persist_directory", + "persist", + "weaviate_url", + "index_name", + "namespace", + "folder_path", + "table_name", + "query_name", + "supabase_url", + "supabase_service_key", + "mongodb_atlas_cluster_uri", + "collection_name", + "db_name", +] +ADVANCED_FIELDS = [ + "n_dim", + "key", + "prefix", + "distance_func", + "content_payload_key", + "metadata_payload_key", + "timeout", + "host", + "path", + "url", + "port", + "https", + "prefer_grpc", + "grpc_port", + "pinecone_api_key", + "pinecone_env", + "client_kwargs", + "search_kwargs", + "chroma_server_host", + "chroma_server_http_port", + "chroma_server_ssl_enabled", + "chroma_server_grpc_port", + "chroma_server_cors_allow_origins", +] + + class VectorStoreFrontendNode(FrontendNode): def add_extra_fields(self) -> None: extra_fields: List[TemplateField] = [] @@ -45,16 +91,62 @@ class VectorStoreFrontendNode(FrontendNode): elif self.template.type_name == "Chroma": # New bool field for persist parameter - extra_field = TemplateField( - name="persist", - field_type="bool", - required=False, - show=True, - advanced=False, - value=False, - display_name="Persist", - ) - extra_fields.append(extra_field) + chroma_fields = [ + TemplateField( + name="persist", + field_type="bool", + required=False, + show=True, + advanced=False, + value=False, + display_name="Persist", + ), + # chroma_server_grpc_port: str | None = None, + TemplateField( + name="chroma_server_host", + field_type="str", + required=False, + show=True, + advanced=True, + display_name="Chroma Server Host", + ), + TemplateField( + name="chroma_server_http_port", + field_type="str", + required=False, + show=True, + advanced=True, + display_name="Chroma Server HTTP Port", + ), + TemplateField( + name="chroma_server_ssl_enabled", + field_type="bool", + required=False, + show=True, + advanced=True, + value=False, + display_name="Chroma Server SSL Enabled", + ), + TemplateField( + name="chroma_server_grpc_port", + field_type="str", + required=False, + show=True, + advanced=True, + display_name="Chroma Server GRPC Port", + ), + TemplateField( + name="chroma_server_cors_allow_origins", + field_type="str", + required=False, + is_list=True, + show=True, + advanced=True, + display_name="Chroma Server CORS Allow Origins", + ), + ] + + extra_fields.extend(chroma_fields) elif self.template.type_name == "Pinecone": # add pinecone_api_key and pinecone_env extra_field = TemplateField( @@ -208,45 +300,6 @@ class VectorStoreFrontendNode(FrontendNode): def format_field(field: TemplateField, name: Optional[str] = None) -> None: FrontendNode.format_field(field, name) # Define common field attributes - basic_fields = [ - "work_dir", - "collection_name", - "api_key", - "location", - "persist_directory", - "persist", - "weaviate_url", - "index_name", - "namespace", - "folder_path", - "table_name", - "query_name", - "supabase_url", - "supabase_service_key", - "mongodb_atlas_cluster_uri", - "collection_name", - "db_name", - ] - advanced_fields = [ - "n_dim", - "key", - "prefix", - "distance_func", - "content_payload_key", - "metadata_payload_key", - "timeout", - "host", - "path", - "url", - "port", - "https", - "prefer_grpc", - "grpc_port", - "pinecone_api_key", - "pinecone_env", - "client_kwargs", - "search_kwargs", - ] # Check and set field attributes if field.name == "texts": @@ -269,7 +322,7 @@ class VectorStoreFrontendNode(FrontendNode): field.display_name = "Embedding" field.field_type = "Embeddings" - elif field.name in basic_fields: + elif field.name in BASIC_FIELDS: field.show = True field.advanced = False if field.name == "api_key": @@ -279,7 +332,7 @@ class VectorStoreFrontendNode(FrontendNode): field.value = ":memory:" field.placeholder = ":memory:" - elif field.name in advanced_fields: + elif field.name in ADVANCED_FIELDS: field.show = True field.advanced = True if "key" in field.name: