chore: Update persist_directory parameter to handle None value in ChromaSearch and Chroma components (#2157)
* chore: Update persist_directory parameter to handle None value in ChromaSearch and Chroma components
* 🐛 (test_endpoints.py): fix assertion to check for correct key name in output results for chat and any input types
This commit is contained in:
parent
84df4fd8e4
commit
34b6153fed
3 changed files with 5 additions and 6 deletions
|
|
@ -3,7 +3,6 @@ from typing import List, Optional
|
|||
import chromadb
|
||||
from chromadb.config import Settings
|
||||
from langchain_chroma import Chroma
|
||||
|
||||
from langflow.components.vectorstores.base.model import LCVectorStoreComponent
|
||||
from langflow.field_typing import Embeddings, Text
|
||||
from langflow.schema import Record
|
||||
|
|
@ -104,10 +103,11 @@ class ChromaSearchComponent(LCVectorStoreComponent):
|
|||
client = chromadb.HttpClient(settings=chroma_settings)
|
||||
if index_directory:
|
||||
index_directory = self.resolve_path(index_directory)
|
||||
|
||||
vector_store = Chroma(
|
||||
embedding_function=embedding,
|
||||
collection_name=collection_name,
|
||||
persist_directory=index_directory,
|
||||
persist_directory=index_directory or None,
|
||||
client=client,
|
||||
)
|
||||
|
||||
|
|
|
|||
|
|
@ -7,7 +7,6 @@ from langchain_chroma import Chroma
|
|||
from langchain_core.embeddings import Embeddings
|
||||
from langchain_core.retrievers import BaseRetriever
|
||||
from langchain_core.vectorstores import VectorStore
|
||||
|
||||
from langflow.base.vectorstores.utils import chroma_collection_to_records
|
||||
from langflow.custom import CustomComponent
|
||||
from langflow.schema import Record
|
||||
|
|
@ -107,7 +106,7 @@ class ChromaComponent(CustomComponent):
|
|||
index_directory = self.resolve_path(index_directory)
|
||||
|
||||
chroma = Chroma(
|
||||
persist_directory=index_directory,
|
||||
persist_directory=index_directory or None,
|
||||
client=client,
|
||||
embedding_function=embedding,
|
||||
collection_name=collection_name,
|
||||
|
|
|
|||
|
|
@ -597,7 +597,7 @@ def test_successful_run_with_input_type_chat(client, starter_project, created_ap
|
|||
chat_input_outputs = [output for output in outputs_dict.get("outputs") if "ChatInput" in output.get("component_id")]
|
||||
assert len(chat_input_outputs) == 1
|
||||
# Now we check if the input_value is correct
|
||||
assert all([output.get("results").get("result") == "value1" for output in chat_input_outputs]), chat_input_outputs
|
||||
assert all([output.get("results").get("text") == "value1" for output in chat_input_outputs]), chat_input_outputs
|
||||
|
||||
|
||||
def test_successful_run_with_input_type_any(client, starter_project, created_api_key):
|
||||
|
|
@ -631,7 +631,7 @@ def test_successful_run_with_input_type_any(client, starter_project, created_api
|
|||
]
|
||||
assert len(any_input_outputs) == 1
|
||||
# Now we check if the input_value is correct
|
||||
assert all([output.get("results").get("result") == "value1" for output in any_input_outputs]), any_input_outputs
|
||||
assert all([output.get("results").get("text") == "value1" for output in any_input_outputs]), any_input_outputs
|
||||
|
||||
|
||||
@pytest.mark.api_key_required
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue