Add HuggingFaceInferenceAPIEmbeddingsComponent class (#1431)

Adds a new HuggingFaceInferenceAPIEmbeddingsComponent component class

<img width="658" alt="image"
src="https://github.com/logspace-ai/langflow/assets/763757/175ff5ec-9ea4-4232-8009-63f022132c51">
This commit is contained in:
Gabriel Luiz Freitas Almeida 2024-02-15 10:45:42 -03:00 committed by GitHub
commit 620ef2b26a
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
2 changed files with 92 additions and 23 deletions

View file

@ -0,0 +1,42 @@
from langflow import CustomComponent
from typing import Optional, Dict
from langchain_community.embeddings.huggingface import HuggingFaceInferenceAPIEmbeddings
class HuggingFaceInferenceAPIEmbeddingsComponent(CustomComponent):
display_name = "HuggingFaceInferenceAPIEmbeddings"
description = "HuggingFace sentence_transformers embedding models, API version."
documentation = (
"https://github.com/huggingface/text-embeddings-inference"
)
def build_config(self):
return {
"api_key": {"display_name": "API Key", "password": True, "advanced": True},
"api_url": {"display_name": "API URL", "advanced": True},
"model_name": {"display_name": "Model Name"},
"cache_folder": {"display_name": "Cache Folder", "advanced": True},
"encode_kwargs": {"display_name": "Encode Kwargs", "advanced": True, "field_type": "dict"},
"model_kwargs": {"display_name": "Model Kwargs", "field_type": "dict", "advanced": True},
"multi_process": {"display_name": "Multi Process", "advanced": True},
}
def build(
self,
api_key: Optional[str] = "",
api_url: str = "http://localhost:8080",
model_name: str = "BAAI/bge-large-en-v1.5",
cache_folder: Optional[str] = None,
encode_kwargs: Optional[Dict] = {},
model_kwargs: Optional[Dict] = {},
multi_process: bool = False,
) -> HuggingFaceInferenceAPIEmbeddings:
return HuggingFaceInferenceAPIEmbeddings(
api_key=api_key,
api_url=api_url,
model_name=model_name,
cache_folder=cache_folder,
encode_kwargs=encode_kwargs,
model_kwargs=model_kwargs,
multi_process=multi_process,
)

View file

@ -15,7 +15,7 @@ class QdrantComponent(CustomComponent):
return {
"documents": {"display_name": "Documents"},
"embedding": {"display_name": "Embedding"},
"api_key": {"display_name": "API Key", "password": True},
"api_key": {"display_name": "API Key", "password": True, "advanced": True},
"collection_name": {"display_name": "Collection Name"},
"content_payload_key": {"display_name": "Content Payload Key", "advanced": True},
"distance_func": {"display_name": "Distance Function", "advanced": True},
@ -36,7 +36,7 @@ class QdrantComponent(CustomComponent):
def build(
self,
embedding: Embeddings,
documents: List[Document],
documents: Optional[Document] = None,
api_key: Optional[str] = None,
collection_name: Optional[str] = None,
content_payload_key: str = "page_content",
@ -44,7 +44,7 @@ class QdrantComponent(CustomComponent):
grpc_port: Optional[int] = 6334,
host: Optional[str] = None,
https: bool = False,
location: str = ":memory:",
location: Optional[str] = None,
metadata_payload_key: str = "metadata",
path: Optional[str] = None,
port: Optional[int] = 6333,
@ -54,23 +54,50 @@ class QdrantComponent(CustomComponent):
timeout: Optional[float] = None,
url: Optional[str] = None,
) -> Union[VectorStore, Qdrant, BaseRetriever]:
return Qdrant.from_documents(
documents=documents,
embedding=embedding,
api_key=api_key,
collection_name=collection_name,
content_payload_key=content_payload_key,
distance_func=distance_func,
grpc_port=grpc_port,
host=host,
https=https,
location=location,
metadata_payload_key=metadata_payload_key,
path=path,
port=port,
prefer_grpc=prefer_grpc,
prefix=prefix,
search_kwargs=search_kwargs,
timeout=timeout,
url=url,
)
if documents is None:
from qdrant_client import QdrantClient
client = QdrantClient(
location=location,
url=host,
port=port,
grpc_port=grpc_port,
https=https,
prefix=prefix,
timeout=timeout,
prefer_grpc=prefer_grpc,
metadata_payload_key=metadata_payload_key,
content_payload_key=content_payload_key,
api_key=api_key,
collection_name=collection_name,
host=host,
path=path,
)
vs = Qdrant(client=client,
collection_name=collection_name,
embeddings=embedding,
search_kwargs=search_kwargs,
distance_func=distance_func,
)
return vs
else:
vs = Qdrant.from_documents(
documents=documents,
embedding=embedding,
api_key=api_key,
collection_name=collection_name,
content_payload_key=content_payload_key,
distance_func=distance_func,
grpc_port=grpc_port,
host=host,
https=https,
location=location,
metadata_payload_key=metadata_payload_key,
path=path,
port=port,
prefer_grpc=prefer_grpc,
prefix=prefix,
search_kwargs=search_kwargs,
timeout=timeout,
url=url,
)
return vs