📝 (Chroma.py): Update imports and class inheritance

🐛 (Chroma.py): Fix method calls and variable names to align with changes in imports and class structure
🔧 (Chroma.py): Refactor code to handle exceptions and imports more gracefully
This commit is contained in:
ogabrielluiz 2024-06-17 14:24:25 -03:00
commit f3cb8c81b0

View file

@ -1,22 +1,20 @@
from copy import deepcopy
from typing import List
from typing import TYPE_CHECKING, List
import chromadb
from chromadb.config import Settings
from langchain.vectorstores import Chroma
from langchain.schema import BaseRetriever
from langflow.base.vectorstores.utils import chroma_collection_to_data
from langflow.custom import Component
from langflow.inputs import BoolInput, IntInput, StrInput, HandleInput, DropdownInput
from langflow.schema import Data
from langflow.template import Output
from langflow.field_typing import Embeddings
from langflow.helpers.data import docs_to_data
from langchain_chroma.vectorstores import Chroma
from loguru import logger
class ChromaVectorStoreComponent(Component):
from langflow.base.vectorstores.utils import chroma_collection_to_data
from langflow.components.vectorstores.base.model import LCVectorStoreComponent
from langflow.inputs import BoolInput, DropdownInput, HandleInput, IntInput, StrInput
from langflow.schema import Data
if TYPE_CHECKING:
from langchain_chroma import Chroma
class ChromaVectorStoreComponent(LCVectorStoreComponent):
"""
Chroma Vector Store with search capabilities
"""
@ -45,13 +43,9 @@ class ChromaVectorStoreComponent(Component):
name="vector_store_inputs",
display_name="Vector Store Inputs",
input_types=["Document", "Data"],
is_list=True
),
HandleInput(
name="embedding",
display_name="Embedding",
input_types=["Embeddings"]
is_list=True,
),
HandleInput(name="embedding", display_name="Embedding", input_types=["Embeddings"]),
StrInput(
name="chroma_server_cors_allow_origins",
display_name="Server CORS Allow Origins",
@ -105,42 +99,25 @@ class ChromaVectorStoreComponent(Component):
advanced=True,
value=4,
),
]
outputs = [
Output(
display_name="Vector Store",
name="vector_store",
method="build_vector_store",
),
Output(
display_name="Base Retriever",
name="base_retriever",
method="build_base_retriever",
),
Output(
display_name="Search Results",
name="search_results",
method="search_documents",
IntInput(
name="limit",
display_name="Limit",
advanced=True,
info="Limit the number of records to compare when Allow Duplicates is False.",
),
]
def build_vector_store(self) -> Chroma:
"""
Builds the Vector Store object.
"""
return self._build_chroma()
def build_base_retriever(self) -> BaseRetriever:
"""
Builds the BaseRetriever object.
"""
return self._build_chroma()
def _build_chroma(self) -> Chroma:
def build_vector_store(self) -> "Chroma":
"""
Builds the Chroma object.
"""
try:
from chromadb import Client
from langchain_chroma import Chroma
except ImportError:
raise ImportError(
"Could not import Chroma integration package. " "Please install it with `pip install langchain-chroma`."
)
# Chroma settings
chroma_settings = None
client = None
@ -152,7 +129,7 @@ class ChromaVectorStoreComponent(Component):
chroma_server_grpc_port=self.chroma_server_grpc_port or None,
chroma_server_ssl_enabled=self.chroma_server_ssl_enabled,
)
client = chromadb.Client(settings=chroma_settings)
client = Client(settings=chroma_settings)
# Check persist_directory and expand it if it is a relative path
if self.persist_directory is not None:
@ -170,17 +147,17 @@ class ChromaVectorStoreComponent(Component):
if self.add_to_vector_store:
self._add_documents_to_vector_store(chroma)
self.status = chroma_collection_to_data(chroma.get())
self.status = chroma_collection_to_data(chroma.get(self.limit))
return chroma
def _add_documents_to_vector_store(self, chroma: Chroma) -> None:
def _add_documents_to_vector_store(self, vector_store: "Chroma") -> None:
"""
Adds documents to the Vector Store.
"""
if self.allow_duplicates:
stored_data = []
else:
stored_data = chroma_collection_to_data(chroma.get())
stored_data = chroma_collection_to_data(vector_store.get(self.limit))
_stored_documents_without_id = []
for value in deepcopy(stored_data):
del value.id
@ -196,7 +173,7 @@ class ChromaVectorStoreComponent(Component):
if documents and self.embedding is not None:
logger.debug(f"Adding {len(documents)} documents to the Vector Store.")
chroma.add_documents(documents)
vector_store.add_documents(documents)
else:
logger.debug("No documents to add to the Vector Store.")
@ -213,26 +190,7 @@ class ChromaVectorStoreComponent(Component):
logger.debug(f"Search type: {self.search_type}")
logger.debug(f"Number of results: {self.number_of_results}")
if self.search_input and isinstance(self.search_input, str) and self.search_input.strip():
if self.search_type == "Similarity":
docs = vector_store.similarity_search(
query=self.search_input,
k=self.number_of_results,
)
elif self.search_type == "MMR":
docs = vector_store.max_marginal_relevance_search(
query=self.search_input,
k=self.number_of_results,
)
else:
raise ValueError(f"Invalid search type: {self.search_type}")
logger.debug(f"Retrieved documents: {len(docs)}")
data = docs_to_data(docs)
logger.debug(f"Converted documents to data: {len(data)}")
self.status = data
return data
else:
logger.debug("No search input provided. Skipping search.")
return []
search_results = self.search_with_vector_store(
self.input_value, self.search_type, vector_store, k=self.number_of_results
)
return search_results