diff --git a/src/backend/base/langflow/components/vectorsearch/ChromaSearch.py b/src/backend/base/langflow/components/vectorsearch/ChromaSearch.py index 228e100e4..3d4687522 100644 --- a/src/backend/base/langflow/components/vectorsearch/ChromaSearch.py +++ b/src/backend/base/langflow/components/vectorsearch/ChromaSearch.py @@ -3,7 +3,6 @@ from typing import List, Optional import chromadb from chromadb.config import Settings from langchain_chroma import Chroma - from langflow.components.vectorstores.base.model import LCVectorStoreComponent from langflow.field_typing import Embeddings, Text from langflow.schema import Record @@ -104,10 +103,11 @@ class ChromaSearchComponent(LCVectorStoreComponent): client = chromadb.HttpClient(settings=chroma_settings) if index_directory: index_directory = self.resolve_path(index_directory) + vector_store = Chroma( embedding_function=embedding, collection_name=collection_name, - persist_directory=index_directory, + persist_directory=index_directory or None, client=client, ) diff --git a/src/backend/base/langflow/components/vectorstores/Chroma.py b/src/backend/base/langflow/components/vectorstores/Chroma.py index 6001b119c..f7080d2fc 100644 --- a/src/backend/base/langflow/components/vectorstores/Chroma.py +++ b/src/backend/base/langflow/components/vectorstores/Chroma.py @@ -7,7 +7,6 @@ from langchain_chroma import Chroma from langchain_core.embeddings import Embeddings from langchain_core.retrievers import BaseRetriever from langchain_core.vectorstores import VectorStore - from langflow.base.vectorstores.utils import chroma_collection_to_records from langflow.custom import CustomComponent from langflow.schema import Record @@ -107,7 +106,7 @@ class ChromaComponent(CustomComponent): index_directory = self.resolve_path(index_directory) chroma = Chroma( - persist_directory=index_directory, + persist_directory=index_directory or None, client=client, embedding_function=embedding, collection_name=collection_name, diff --git a/tests/test_endpoints.py b/tests/test_endpoints.py index d7c8a4ef2..1f1dcd5df 100644 --- a/tests/test_endpoints.py +++ b/tests/test_endpoints.py @@ -597,7 +597,7 @@ def test_successful_run_with_input_type_chat(client, starter_project, created_ap chat_input_outputs = [output for output in outputs_dict.get("outputs") if "ChatInput" in output.get("component_id")] assert len(chat_input_outputs) == 1 # Now we check if the input_value is correct - assert all([output.get("results").get("result") == "value1" for output in chat_input_outputs]), chat_input_outputs + assert all([output.get("results").get("text") == "value1" for output in chat_input_outputs]), chat_input_outputs def test_successful_run_with_input_type_any(client, starter_project, created_api_key): @@ -631,7 +631,7 @@ def test_successful_run_with_input_type_any(client, starter_project, created_api ] assert len(any_input_outputs) == 1 # Now we check if the input_value is correct - assert all([output.get("results").get("result") == "value1" for output in any_input_outputs]), any_input_outputs + assert all([output.get("results").get("text") == "value1" for output in any_input_outputs]), any_input_outputs @pytest.mark.api_key_required