diff --git a/src/backend/base/langflow/components/vectorstores/Chroma.py b/src/backend/base/langflow/components/vectorstores/Chroma.py index e93257932..1db41a159 100644 --- a/src/backend/base/langflow/components/vectorstores/Chroma.py +++ b/src/backend/base/langflow/components/vectorstores/Chroma.py @@ -1,22 +1,20 @@ from copy import deepcopy -from typing import List +from typing import TYPE_CHECKING, List -import chromadb from chromadb.config import Settings -from langchain.vectorstores import Chroma -from langchain.schema import BaseRetriever - -from langflow.base.vectorstores.utils import chroma_collection_to_data -from langflow.custom import Component -from langflow.inputs import BoolInput, IntInput, StrInput, HandleInput, DropdownInput -from langflow.schema import Data -from langflow.template import Output -from langflow.field_typing import Embeddings -from langflow.helpers.data import docs_to_data - +from langchain_chroma.vectorstores import Chroma from loguru import logger -class ChromaVectorStoreComponent(Component): +from langflow.base.vectorstores.utils import chroma_collection_to_data +from langflow.components.vectorstores.base.model import LCVectorStoreComponent +from langflow.inputs import BoolInput, DropdownInput, HandleInput, IntInput, StrInput +from langflow.schema import Data + +if TYPE_CHECKING: + from langchain_chroma import Chroma + + +class ChromaVectorStoreComponent(LCVectorStoreComponent): """ Chroma Vector Store with search capabilities """ @@ -45,13 +43,9 @@ class ChromaVectorStoreComponent(Component): name="vector_store_inputs", display_name="Vector Store Inputs", input_types=["Document", "Data"], - is_list=True - ), - HandleInput( - name="embedding", - display_name="Embedding", - input_types=["Embeddings"] + is_list=True, ), + HandleInput(name="embedding", display_name="Embedding", input_types=["Embeddings"]), StrInput( name="chroma_server_cors_allow_origins", display_name="Server CORS Allow Origins", @@ -105,42 +99,25 @@ class ChromaVectorStoreComponent(Component): advanced=True, value=4, ), - ] - - outputs = [ - Output( - display_name="Vector Store", - name="vector_store", - method="build_vector_store", - ), - Output( - display_name="Base Retriever", - name="base_retriever", - method="build_base_retriever", - ), - Output( - display_name="Search Results", - name="search_results", - method="search_documents", + IntInput( + name="limit", + display_name="Limit", + advanced=True, + info="Limit the number of records to compare when Allow Duplicates is False.", ), ] - def build_vector_store(self) -> Chroma: - """ - Builds the Vector Store object. - """ - return self._build_chroma() - - def build_base_retriever(self) -> BaseRetriever: - """ - Builds the BaseRetriever object. - """ - return self._build_chroma() - - def _build_chroma(self) -> Chroma: + def build_vector_store(self) -> "Chroma": """ Builds the Chroma object. """ + try: + from chromadb import Client + from langchain_chroma import Chroma + except ImportError: + raise ImportError( + "Could not import Chroma integration package. " "Please install it with `pip install langchain-chroma`." + ) # Chroma settings chroma_settings = None client = None @@ -152,7 +129,7 @@ class ChromaVectorStoreComponent(Component): chroma_server_grpc_port=self.chroma_server_grpc_port or None, chroma_server_ssl_enabled=self.chroma_server_ssl_enabled, ) - client = chromadb.Client(settings=chroma_settings) + client = Client(settings=chroma_settings) # Check persist_directory and expand it if it is a relative path if self.persist_directory is not None: @@ -170,17 +147,17 @@ class ChromaVectorStoreComponent(Component): if self.add_to_vector_store: self._add_documents_to_vector_store(chroma) - self.status = chroma_collection_to_data(chroma.get()) + self.status = chroma_collection_to_data(chroma.get(self.limit)) return chroma - def _add_documents_to_vector_store(self, chroma: Chroma) -> None: + def _add_documents_to_vector_store(self, vector_store: "Chroma") -> None: """ Adds documents to the Vector Store. """ if self.allow_duplicates: stored_data = [] else: - stored_data = chroma_collection_to_data(chroma.get()) + stored_data = chroma_collection_to_data(vector_store.get(self.limit)) _stored_documents_without_id = [] for value in deepcopy(stored_data): del value.id @@ -196,7 +173,7 @@ class ChromaVectorStoreComponent(Component): if documents and self.embedding is not None: logger.debug(f"Adding {len(documents)} documents to the Vector Store.") - chroma.add_documents(documents) + vector_store.add_documents(documents) else: logger.debug("No documents to add to the Vector Store.") @@ -213,26 +190,7 @@ class ChromaVectorStoreComponent(Component): logger.debug(f"Search type: {self.search_type}") logger.debug(f"Number of results: {self.number_of_results}") - if self.search_input and isinstance(self.search_input, str) and self.search_input.strip(): - if self.search_type == "Similarity": - docs = vector_store.similarity_search( - query=self.search_input, - k=self.number_of_results, - ) - elif self.search_type == "MMR": - docs = vector_store.max_marginal_relevance_search( - query=self.search_input, - k=self.number_of_results, - ) - else: - raise ValueError(f"Invalid search type: {self.search_type}") - - logger.debug(f"Retrieved documents: {len(docs)}") - - data = docs_to_data(docs) - logger.debug(f"Converted documents to data: {len(data)}") - self.status = data - return data - else: - logger.debug("No search input provided. Skipping search.") - return [] + search_results = self.search_with_vector_store( + self.input_value, self.search_type, vector_store, k=self.number_of_results + ) + return search_results