refactor: Update ChromaComponent build method to allow duplicates in the Vector Store

This commit is contained in:
ogabrielluiz 2024-06-10 14:14:18 -03:00
commit 24e8da5086

View file

@ -1,3 +1,4 @@
from copy import deepcopy
from typing import List, Optional, Union
import chromadb
@ -6,6 +7,7 @@ from langchain_chroma import Chroma
from langchain_core.embeddings import Embeddings
from langchain_core.retrievers import BaseRetriever
from langchain_core.vectorstores import VectorStore
from langflow.base.vectorstores.utils import chroma_collection_to_records
from langflow.custom import CustomComponent
from langflow.schema import Record
@ -48,6 +50,11 @@ class ChromaComponent(CustomComponent):
"display_name": "Server SSL Enabled",
"advanced": True,
},
"allow_duplicates": {
"display_name": "Allow Duplicates",
"advanced": True,
"info": "If false, will not add documents that are already in the Vector Store.",
},
}
def build(
@ -61,6 +68,7 @@ class ChromaComponent(CustomComponent):
chroma_server_host: Optional[str] = None,
chroma_server_http_port: Optional[int] = None,
chroma_server_grpc_port: Optional[int] = None,
allow_duplicates: bool = False,
) -> Union[VectorStore, BaseRetriever]:
"""
Builds the Vector Store or BaseRetriever object.
@ -75,6 +83,7 @@ class ChromaComponent(CustomComponent):
- chroma_server_host (Optional[str]): The host for the Chroma server.
- chroma_server_http_port (Optional[int]): The HTTP port for the Chroma server.
- chroma_server_grpc_port (Optional[int]): The gRPC port for the Chroma server.
- allow_duplicates (bool): Whether to allow duplicates in the Vector Store.
Returns:
- Union[VectorStore, BaseRetriever]: The Vector Store or BaseRetriever object.
@ -93,35 +102,34 @@ class ChromaComponent(CustomComponent):
)
client = chromadb.HttpClient(settings=chroma_settings)
# If documents are provided, we need to create a Chroma instance using .from_documents
# Check index_directory and expand it if it is a relative path
if index_directory is not None:
index_directory = self.resolve_path(index_directory)
chroma = Chroma(
persist_directory=index_directory,
client=client,
embedding_function=embedding,
collection_name=collection_name,
)
if allow_duplicates:
stored_records = []
else:
stored_records = chroma_collection_to_records(chroma.get())
_stored_documents_without_id = []
for record in deepcopy(stored_records):
del record.id
_stored_documents_without_id.append(record)
documents = []
for _input in inputs or []:
if isinstance(_input, Record):
documents.append(_input.to_lc_document())
if _input not in _stored_documents_without_id:
documents.append(_input.to_lc_document())
else:
documents.append(_input)
if documents is not None and embedding is not None:
if len(documents) == 0:
raise ValueError("If documents are provided, there must be at least one document.")
chroma = Chroma.from_documents(
documents=documents, # type: ignore
persist_directory=index_directory,
collection_name=collection_name,
embedding=embedding,
client=client,
)
else:
chroma = Chroma(
persist_directory=index_directory,
client=client,
embedding_function=embedding,
)
raise ValueError("Inputs must be a Record objects.")
store = chroma.get()
self.status = chroma_collection_to_records(store)
if documents and embedding is not None:
chroma.add_documents(documents)
self.status = stored_records
return chroma