refactor: Update HuggingFaceEmbeddingsComponent and HuggingFaceInferenceAPIEmbeddingsComponent to use new Inputs/Outputs format

This commit is contained in:
Cezar Vasconcelos 2024-06-19 21:12:07 +00:00
commit f93bdf8cd8
2 changed files with 50 additions and 51 deletions

View file

@@ -2,10 +2,12 @@ from typing import Dict, Optional
from langchain_community.embeddings.huggingface import HuggingFaceEmbeddings

from langflow.base.models.model import LCModelComponent
from langflow.field_typing import Embeddings
from langflow.io import BoolInput, DictInput, TextInput, Output


class HuggingFaceEmbeddingsComponent(LCModelComponent):
    """Generate embeddings locally with a HuggingFace sentence-transformers model.

    Uses the new declarative Inputs/Outputs component format: each entry in
    `inputs` is exposed as an attribute on the instance (e.g. `self.model_name`),
    and the `Output` entry wires the UI "Embeddings" port to `build_embeddings`.
    """

    display_name = "Hugging Face Embeddings"
    description = "Generate embeddings using HuggingFace models."
    # NOTE(review): the documentation URL was elided by the diff view between
    # these parentheses — confirm the exact value against the repository.
    documentation = (
        "https://python.langchain.com/docs/modules/data_connection/text_embedding/integrations/sentence_transformers"
    )
    icon = "HuggingFace"

    inputs = [
        TextInput(name="cache_folder", display_name="Cache Folder", advanced=True),
        DictInput(name="encode_kwargs", display_name="Encode Kwargs", advanced=True),
        DictInput(name="model_kwargs", display_name="Model Kwargs", advanced=True),
        TextInput(name="model_name", display_name="Model Name", value="sentence-transformers/all-mpnet-base-v2"),
        BoolInput(name="multi_process", display_name="Multi Process", advanced=True),
    ]

    outputs = [
        Output(display_name="Embeddings", name="embeddings", method="build_embeddings"),
    ]

    def build_embeddings(self) -> Embeddings:
        """Build and return the HuggingFaceEmbeddings instance from the component inputs."""
        # Input values are surfaced as attributes by the component base class.
        return HuggingFaceEmbeddings(
            cache_folder=self.cache_folder,
            encode_kwargs=self.encode_kwargs,
            model_kwargs=self.model_kwargs,
            model_name=self.model_name,
            multi_process=self.multi_process,
        )

View file

@@ -3,42 +3,43 @@ from typing import Dict, Optional
from langchain_community.embeddings.huggingface import HuggingFaceInferenceAPIEmbeddings
from pydantic.v1.types import SecretStr

from langflow.base.models.model import LCModelComponent
from langflow.field_typing import Embeddings
from langflow.io import BoolInput, DictInput, FloatInput, Output, SecretStrInput, TextInput


class HuggingFaceInferenceAPIEmbeddingsComponent(LCModelComponent):
    """Generate embeddings through a Hugging Face Inference API / TEI endpoint.

    Uses the new declarative Inputs/Outputs component format: each entry in
    `inputs` becomes an instance attribute (e.g. `self.api_url`), and the
    `Output` entry wires the UI "Embeddings" port to `build_embeddings`.
    """

    display_name = "Hugging Face API Embeddings"
    description = "Generate embeddings using Hugging Face Inference API models."
    documentation = "https://github.com/huggingface/text-embeddings-inference"
    icon = "HuggingFace"

    inputs = [
        SecretStrInput(name="api_key", display_name="API Key", advanced=True),
        TextInput(name="api_url", display_name="API URL", advanced=True, value="http://localhost:8080"),
        TextInput(name="model_name", display_name="Model Name", value="BAAI/bge-large-en-v1.5"),
        TextInput(name="cache_folder", display_name="Cache Folder", advanced=True),
        DictInput(name="encode_kwargs", display_name="Encode Kwargs", advanced=True),
        DictInput(name="model_kwargs", display_name="Model Kwargs", advanced=True),
        BoolInput(name="multi_process", display_name="Multi Process", advanced=True),
    ]

    outputs = [
        Output(display_name="Embeddings", name="embeddings", method="build_embeddings"),
    ]

    def build_embeddings(self) -> Embeddings:
        """Build the HuggingFaceInferenceAPIEmbeddings instance from the component inputs.

        Raises:
            ValueError: if no API key was provided.
        """
        if not self.api_key:
            raise ValueError("API Key is required")
        # Wrap the key so it is not leaked in reprs/logs downstream.
        api_key = SecretStr(self.api_key)
        return HuggingFaceInferenceAPIEmbeddings(
            api_key=api_key,
            api_url=self.api_url,
            model_name=self.model_name,
            cache_folder=self.cache_folder,
            encode_kwargs=self.encode_kwargs,
            model_kwargs=self.model_kwargs,
            multi_process=self.multi_process,
        )