diff --git a/src/backend/langflow/template/frontend_node/vectorstores.py b/src/backend/langflow/template/frontend_node/vectorstores.py index d04936a8b..0d6fb2467 100644 --- a/src/backend/langflow/template/frontend_node/vectorstores.py +++ b/src/backend/langflow/template/frontend_node/vectorstores.py @@ -6,6 +6,7 @@ from langflow.template.frontend_node.base import FrontendNode class VectorStoreFrontendNode(FrontendNode): def add_extra_fields(self) -> None: + extra_field = None if self.template.type_name == "Weaviate": extra_field = TemplateField( name="weaviate_url", @@ -18,6 +19,18 @@ class VectorStoreFrontendNode(FrontendNode): value="http://localhost:8080", ) + elif self.template.type_name == "Chroma": + # New bool field for persist parameter + extra_field = TemplateField( + name="persist", + field_type="bool", + required=False, + show=True, + advanced=False, + value=True, + display_name="Persist", + ) + if extra_field is not None: self.template.add_field(extra_field) def add_extra_base_classes(self) -> None: @@ -27,7 +40,14 @@ class VectorStoreFrontendNode(FrontendNode): def format_field(field: TemplateField, name: Optional[str] = None) -> None: FrontendNode.format_field(field, name) # Define common field attributes - basic_fields = ["work_dir", "collection_name", "api_key", "location"] + basic_fields = [ + "work_dir", + "collection_name", + "api_key", + "location", + "persist_directory", + "persist", + ] advanced_fields = [ "n_dim", "key", @@ -78,5 +98,3 @@ class VectorStoreFrontendNode(FrontendNode): field.advanced = True if "key" in field.name: field.password = False - # TODO: Weaviate requires weaviate_url to be passed as it is not part of - # the class or from_texts method. We need the add_extra_fields to fix this