chore: Update persist_directory parameter to handle None value in ChromaSearch and Chroma components (#2157)

* chore: Update persist_directory parameter to handle None value in ChromaSearch and Chroma components

* 🐛 (test_endpoints.py): fix assertion to check for correct key name in output results for chat and any input types
This commit is contained in:
Gabriel Luiz Freitas Almeida 2024-06-13 06:52:32 -07:00 committed by GitHub
commit 34b6153fed
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
3 changed files with 5 additions and 6 deletions

View file

@ -3,7 +3,6 @@ from typing import List, Optional
import chromadb
from chromadb.config import Settings
from langchain_chroma import Chroma
from langflow.components.vectorstores.base.model import LCVectorStoreComponent
from langflow.field_typing import Embeddings, Text
from langflow.schema import Record
@ -104,10 +103,11 @@ class ChromaSearchComponent(LCVectorStoreComponent):
client = chromadb.HttpClient(settings=chroma_settings)
if index_directory:
index_directory = self.resolve_path(index_directory)
vector_store = Chroma(
embedding_function=embedding,
collection_name=collection_name,
persist_directory=index_directory,
persist_directory=index_directory or None,
client=client,
)

View file

@ -7,7 +7,6 @@ from langchain_chroma import Chroma
from langchain_core.embeddings import Embeddings
from langchain_core.retrievers import BaseRetriever
from langchain_core.vectorstores import VectorStore
from langflow.base.vectorstores.utils import chroma_collection_to_records
from langflow.custom import CustomComponent
from langflow.schema import Record
@ -107,7 +106,7 @@ class ChromaComponent(CustomComponent):
index_directory = self.resolve_path(index_directory)
chroma = Chroma(
persist_directory=index_directory,
persist_directory=index_directory or None,
client=client,
embedding_function=embedding,
collection_name=collection_name,

View file

@ -597,7 +597,7 @@ def test_successful_run_with_input_type_chat(client, starter_project, created_ap
chat_input_outputs = [output for output in outputs_dict.get("outputs") if "ChatInput" in output.get("component_id")]
assert len(chat_input_outputs) == 1
# Now we check if the input_value is correct
assert all([output.get("results").get("result") == "value1" for output in chat_input_outputs]), chat_input_outputs
assert all([output.get("results").get("text") == "value1" for output in chat_input_outputs]), chat_input_outputs
def test_successful_run_with_input_type_any(client, starter_project, created_api_key):
@ -631,7 +631,7 @@ def test_successful_run_with_input_type_any(client, starter_project, created_api
]
assert len(any_input_outputs) == 1
# Now we check if the input_value is correct
assert all([output.get("results").get("result") == "value1" for output in any_input_outputs]), any_input_outputs
assert all([output.get("results").get("text") == "value1" for output in any_input_outputs]), any_input_outputs
@pytest.mark.api_key_required