📝 (Chroma.py): Update imports and class inheritance
🐛 (Chroma.py): Fix method calls and variable names to align with changes in imports and class structure 🔧 (Chroma.py): Refactor code to handle exceptions and imports more gracefully
This commit is contained in:
parent
7818e55146
commit
f3cb8c81b0
1 changed files with 36 additions and 78 deletions
|
|
@ -1,22 +1,20 @@
|
|||
from copy import deepcopy
|
||||
from typing import List
|
||||
from typing import TYPE_CHECKING, List
|
||||
|
||||
import chromadb
|
||||
from chromadb.config import Settings
|
||||
from langchain.vectorstores import Chroma
|
||||
from langchain.schema import BaseRetriever
|
||||
|
||||
from langflow.base.vectorstores.utils import chroma_collection_to_data
|
||||
from langflow.custom import Component
|
||||
from langflow.inputs import BoolInput, IntInput, StrInput, HandleInput, DropdownInput
|
||||
from langflow.schema import Data
|
||||
from langflow.template import Output
|
||||
from langflow.field_typing import Embeddings
|
||||
from langflow.helpers.data import docs_to_data
|
||||
|
||||
from langchain_chroma.vectorstores import Chroma
|
||||
from loguru import logger
|
||||
|
||||
class ChromaVectorStoreComponent(Component):
|
||||
from langflow.base.vectorstores.utils import chroma_collection_to_data
|
||||
from langflow.components.vectorstores.base.model import LCVectorStoreComponent
|
||||
from langflow.inputs import BoolInput, DropdownInput, HandleInput, IntInput, StrInput
|
||||
from langflow.schema import Data
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from langchain_chroma import Chroma
|
||||
|
||||
|
||||
class ChromaVectorStoreComponent(LCVectorStoreComponent):
|
||||
"""
|
||||
Chroma Vector Store with search capabilities
|
||||
"""
|
||||
|
|
@ -45,13 +43,9 @@ class ChromaVectorStoreComponent(Component):
|
|||
name="vector_store_inputs",
|
||||
display_name="Vector Store Inputs",
|
||||
input_types=["Document", "Data"],
|
||||
is_list=True
|
||||
),
|
||||
HandleInput(
|
||||
name="embedding",
|
||||
display_name="Embedding",
|
||||
input_types=["Embeddings"]
|
||||
is_list=True,
|
||||
),
|
||||
HandleInput(name="embedding", display_name="Embedding", input_types=["Embeddings"]),
|
||||
StrInput(
|
||||
name="chroma_server_cors_allow_origins",
|
||||
display_name="Server CORS Allow Origins",
|
||||
|
|
@ -105,42 +99,25 @@ class ChromaVectorStoreComponent(Component):
|
|||
advanced=True,
|
||||
value=4,
|
||||
),
|
||||
]
|
||||
|
||||
outputs = [
|
||||
Output(
|
||||
display_name="Vector Store",
|
||||
name="vector_store",
|
||||
method="build_vector_store",
|
||||
),
|
||||
Output(
|
||||
display_name="Base Retriever",
|
||||
name="base_retriever",
|
||||
method="build_base_retriever",
|
||||
),
|
||||
Output(
|
||||
display_name="Search Results",
|
||||
name="search_results",
|
||||
method="search_documents",
|
||||
IntInput(
|
||||
name="limit",
|
||||
display_name="Limit",
|
||||
advanced=True,
|
||||
info="Limit the number of records to compare when Allow Duplicates is False.",
|
||||
),
|
||||
]
|
||||
|
||||
def build_vector_store(self) -> Chroma:
|
||||
"""
|
||||
Builds the Vector Store object.
|
||||
"""
|
||||
return self._build_chroma()
|
||||
|
||||
def build_base_retriever(self) -> BaseRetriever:
|
||||
"""
|
||||
Builds the BaseRetriever object.
|
||||
"""
|
||||
return self._build_chroma()
|
||||
|
||||
def _build_chroma(self) -> Chroma:
|
||||
def build_vector_store(self) -> "Chroma":
|
||||
"""
|
||||
Builds the Chroma object.
|
||||
"""
|
||||
try:
|
||||
from chromadb import Client
|
||||
from langchain_chroma import Chroma
|
||||
except ImportError:
|
||||
raise ImportError(
|
||||
"Could not import Chroma integration package. " "Please install it with `pip install langchain-chroma`."
|
||||
)
|
||||
# Chroma settings
|
||||
chroma_settings = None
|
||||
client = None
|
||||
|
|
@ -152,7 +129,7 @@ class ChromaVectorStoreComponent(Component):
|
|||
chroma_server_grpc_port=self.chroma_server_grpc_port or None,
|
||||
chroma_server_ssl_enabled=self.chroma_server_ssl_enabled,
|
||||
)
|
||||
client = chromadb.Client(settings=chroma_settings)
|
||||
client = Client(settings=chroma_settings)
|
||||
|
||||
# Check persist_directory and expand it if it is a relative path
|
||||
if self.persist_directory is not None:
|
||||
|
|
@ -170,17 +147,17 @@ class ChromaVectorStoreComponent(Component):
|
|||
if self.add_to_vector_store:
|
||||
self._add_documents_to_vector_store(chroma)
|
||||
|
||||
self.status = chroma_collection_to_data(chroma.get())
|
||||
self.status = chroma_collection_to_data(chroma.get(self.limit))
|
||||
return chroma
|
||||
|
||||
def _add_documents_to_vector_store(self, chroma: Chroma) -> None:
|
||||
def _add_documents_to_vector_store(self, vector_store: "Chroma") -> None:
|
||||
"""
|
||||
Adds documents to the Vector Store.
|
||||
"""
|
||||
if self.allow_duplicates:
|
||||
stored_data = []
|
||||
else:
|
||||
stored_data = chroma_collection_to_data(chroma.get())
|
||||
stored_data = chroma_collection_to_data(vector_store.get(self.limit))
|
||||
_stored_documents_without_id = []
|
||||
for value in deepcopy(stored_data):
|
||||
del value.id
|
||||
|
|
@ -196,7 +173,7 @@ class ChromaVectorStoreComponent(Component):
|
|||
|
||||
if documents and self.embedding is not None:
|
||||
logger.debug(f"Adding {len(documents)} documents to the Vector Store.")
|
||||
chroma.add_documents(documents)
|
||||
vector_store.add_documents(documents)
|
||||
else:
|
||||
logger.debug("No documents to add to the Vector Store.")
|
||||
|
||||
|
|
@ -213,26 +190,7 @@ class ChromaVectorStoreComponent(Component):
|
|||
logger.debug(f"Search type: {self.search_type}")
|
||||
logger.debug(f"Number of results: {self.number_of_results}")
|
||||
|
||||
if self.search_input and isinstance(self.search_input, str) and self.search_input.strip():
|
||||
if self.search_type == "Similarity":
|
||||
docs = vector_store.similarity_search(
|
||||
query=self.search_input,
|
||||
k=self.number_of_results,
|
||||
)
|
||||
elif self.search_type == "MMR":
|
||||
docs = vector_store.max_marginal_relevance_search(
|
||||
query=self.search_input,
|
||||
k=self.number_of_results,
|
||||
)
|
||||
else:
|
||||
raise ValueError(f"Invalid search type: {self.search_type}")
|
||||
|
||||
logger.debug(f"Retrieved documents: {len(docs)}")
|
||||
|
||||
data = docs_to_data(docs)
|
||||
logger.debug(f"Converted documents to data: {len(data)}")
|
||||
self.status = data
|
||||
return data
|
||||
else:
|
||||
logger.debug("No search input provided. Skipping search.")
|
||||
return []
|
||||
search_results = self.search_with_vector_store(
|
||||
self.input_value, self.search_type, vector_store, k=self.number_of_results
|
||||
)
|
||||
return search_results
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue