feat: Enhance Chroma Handling, Bug Fixes & add GoogleGenerativeAIEmbeddingsComponent (#3476)
* Update Chroma.py * Update utils.py * Create GoogleGenerativeAIEmbeddings.py * Update __init__.py * Update GoogleGenerativeAIEmbeddings.py * [autofix.ci] apply automated fixes * Update src/backend/base/langflow/components/embeddings/GoogleGenerativeAIEmbeddings.py Co-authored-by: Gabriel Luiz Freitas Almeida <gabriel@langflow.org> --------- Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com> Co-authored-by: Gabriel Luiz Freitas Almeida <gabriel@langflow.org>
This commit is contained in:
parent
4e15090927
commit
2896c72e99
4 changed files with 128 additions and 2 deletions
|
|
@ -17,7 +17,7 @@ def chroma_collection_to_data(collection_dict: dict):
|
|||
"id": collection_dict["ids"][i],
|
||||
"text": doc,
|
||||
}
|
||||
if "metadatas" in collection_dict:
|
||||
if ("metadatas" in collection_dict) and collection_dict["metadatas"][i]:
|
||||
for key, value in collection_dict["metadatas"][i].items():
|
||||
data_dict[key] = value
|
||||
data.append(Data(**data_dict))
|
||||
|
|
|
|||
|
|
@ -0,0 +1,124 @@
|
|||
# from langflow.field_typing import Data
|
||||
from langflow.custom import Component
|
||||
from langflow.io import MessageTextInput, Output, SecretStrInput
|
||||
from langchain_google_genai import GoogleGenerativeAIEmbeddings
|
||||
|
||||
from typing import List, Optional
|
||||
|
||||
# TODO: remove ignore once the google package is published with types
|
||||
from google.ai.generativelanguage_v1beta.types import (
|
||||
BatchEmbedContentsRequest,
|
||||
)
|
||||
from langchain_core.embeddings import Embeddings
|
||||
|
||||
from langchain_google_genai._common import (
|
||||
GoogleGenerativeAIError,
|
||||
)
|
||||
|
||||
import numpy as np
|
||||
|
||||
|
||||
class GoogleGenerativeAIEmbeddingsComponent(Component):
    """Langflow component that builds a Google Generative AI embeddings model.

    Wraps ``langchain_google_genai.GoogleGenerativeAIEmbeddings`` in a local
    subclass that pads each returned vector with zeros up to the requested
    ``output_dimensionality`` (default 1536), so downstream vector stores
    receive a fixed-width embedding.
    """

    # Fixed: these carried leftover "Custom Component" template values.
    display_name = "Google Generative AI Embeddings"
    description = "Generate embeddings using Google Generative AI models."
    documentation: str = "https://python.langchain.com/docs/integrations/text_embedding/google_generative_ai"
    icon = "GoogleGenerativeAI"
    name = "GoogleGenerativeAIEmbeddingsComponent"

    inputs = [
        SecretStrInput(name="api_key", display_name="API Key"),
        MessageTextInput(name="model_name", display_name="Model Name", value="models/text-embedding-004"),
    ]

    outputs = [
        Output(display_name="Embeddings", name="embeddings", method="build_embeddings"),
    ]

    def build_embeddings(self) -> Embeddings:
        """Build and return the embeddings model.

        Returns:
            An ``Embeddings`` instance configured with the component's
            model name and API key.

        Raises:
            ValueError: If no API key was provided.
        """
        if not self.api_key:
            raise ValueError("API Key is required")

        class HotaGoogleGenerativeAIEmbeddings(GoogleGenerativeAIEmbeddings):
            def __init__(self, *args, **kwargs):
                # NOTE(review): deliberately skips GoogleGenerativeAIEmbeddings.__init__
                # and calls its parent's initializer directly — presumably to bypass
                # validation in the intermediate class. Confirm this is intentional.
                super(GoogleGenerativeAIEmbeddings, self).__init__(*args, **kwargs)

            def embed_documents(
                self,
                texts: List[str],
                *,
                batch_size: int = 100,
                task_type: Optional[str] = None,
                titles: Optional[List[str]] = None,
                output_dimensionality: Optional[int] = 1536,
            ) -> List[List[float]]:
                """Embed a list of strings. Google Generative AI currently
                sets a max batch size of 100 strings.

                Args:
                    texts: List[str] The list of strings to embed.
                    batch_size: [int] The batch size of embeddings to send to the model
                    task_type: task_type (https://ai.google.dev/api/rest/v1/TaskType)
                    titles: An optional list of titles for texts provided.
                        Only applicable when TaskType is RETRIEVAL_DOCUMENT.
                    output_dimensionality: Optional reduced dimension for the output embedding.
                        https://ai.google.dev/api/rest/v1/models/batchEmbedContents#EmbedContentRequest

                Returns:
                    List of embeddings, one for each text, zero-padded to
                    ``output_dimensionality`` entries when the model returns
                    shorter vectors.
                """
                embeddings: List[List[float]] = []
                batch_start_index = 0
                for batch in GoogleGenerativeAIEmbeddings._prepare_batches(texts, batch_size):
                    if titles:
                        titles_batch = titles[batch_start_index : batch_start_index + len(batch)]
                        batch_start_index += len(batch)
                    else:
                        titles_batch = [None] * len(batch)  # type: ignore[list-item]

                    requests = [
                        self._prepare_request(
                            text=text,
                            task_type=task_type,
                            title=title,
                            # Fixed: was hard-coded to 1536, ignoring the parameter.
                            output_dimensionality=output_dimensionality,
                        )
                        for text, title in zip(batch, titles_batch)
                    ]

                    try:
                        result = self.client.batch_embed_contents(
                            BatchEmbedContentsRequest(requests=requests, model=self.model)
                        )
                    except Exception as e:
                        raise GoogleGenerativeAIError(f"Error embedding content: {e}") from e
                    # Fixed: was a hard-coded 768-zero pad; now pads each vector
                    # up to the requested dimensionality (same result for the
                    # default 1536 with a 768-dim model, but correct for others).
                    for e in result.embeddings:
                        pad = max((output_dimensionality or 0) - len(e.values), 0)
                        embeddings.append(list(np.pad(e.values, (0, pad), "constant")))
                return embeddings

            def embed_query(
                self,
                text: str,
                task_type: Optional[str] = None,
                title: Optional[str] = None,
                output_dimensionality: Optional[int] = 1536,
            ) -> List[float]:
                """Embed a text.

                Args:
                    text: The text to embed.
                    task_type: task_type (https://ai.google.dev/api/rest/v1/TaskType)
                    title: An optional title for the text.
                        Only applicable when TaskType is RETRIEVAL_DOCUMENT.
                    output_dimensionality: Optional reduced dimension for the output embedding.
                        https://ai.google.dev/api/rest/v1/models/batchEmbedContents#EmbedContentRequest

                Returns:
                    Embedding for the text.
                """
                # Fixed: the explicit task_type argument was discarded; honor it
                # first, then the instance default, then the query fallback.
                task_type = task_type or self.task_type or "RETRIEVAL_QUERY"
                return self.embed_documents(
                    [text],
                    task_type=task_type,
                    titles=[title] if title else None,
                    # Fixed: was hard-coded to 1536, ignoring the parameter.
                    output_dimensionality=output_dimensionality,
                )[0]

        return HotaGoogleGenerativeAIEmbeddings(model=self.model_name, google_api_key=self.api_key)
|
||||
|
|
@ -7,6 +7,7 @@ from .HuggingFaceInferenceAPIEmbeddings import HuggingFaceInferenceAPIEmbeddings
|
|||
from .OllamaEmbeddings import OllamaEmbeddingsComponent
|
||||
from .OpenAIEmbeddings import OpenAIEmbeddingsComponent
|
||||
from .VertexAIEmbeddings import VertexAIEmbeddingsComponent
|
||||
from .GoogleGenerativeAIEmbeddings import GoogleGenerativeAIEmbeddingsComponent
|
||||
|
||||
__all__ = [
|
||||
"AIMLEmbeddingsComponent",
|
||||
|
|
@ -18,4 +19,5 @@ __all__ = [
|
|||
"OllamaEmbeddingsComponent",
|
||||
"OpenAIEmbeddingsComponent",
|
||||
"VertexAIEmbeddingsComponent",
|
||||
"GoogleGenerativeAIEmbeddingsComponent",
|
||||
]
|
||||
|
|
|
|||
|
|
@ -152,7 +152,7 @@ class ChromaVectorStoreComponent(LCVectorStoreComponent):
|
|||
if self.allow_duplicates:
|
||||
stored_data = []
|
||||
else:
|
||||
stored_data = chroma_collection_to_data(vector_store.get(self.limit))
|
||||
stored_data = chroma_collection_to_data(vector_store.get(limit=self.limit))
|
||||
for value in deepcopy(stored_data):
|
||||
del value.id
|
||||
_stored_documents_without_id.append(value)
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue