feat: add EmbeddingModelComponent for generating embeddings (#7204)
* feat: add EmbeddingModelComponent for generating embeddings
  - Introduced a new EmbeddingModelComponent to generate embeddings using a specified provider, starting with OpenAI.
  - Updated the __init__.py file to include the new component in the exports.
  - The component includes various input fields for configuration, such as provider selection, model name, API key, and additional parameters.
* Update the tests.
---------
Co-authored-by: Edwin Jose <edwin.jose@datastax.com>
This commit is contained in:
parent
0ba8f6559c
commit
eff2a30489
4 changed files with 203 additions and 0 deletions
|
|
@ -3,6 +3,7 @@ from .astra_vectorize import AstraVectorizeComponent
|
|||
from .azure_openai import AzureOpenAIEmbeddingsComponent
|
||||
from .cloudflare import CloudflareWorkersAIEmbeddingsComponent
|
||||
from .cohere import CohereEmbeddingsComponent
|
||||
from .embedding_model import EmbeddingModelComponent
|
||||
from .google_generative_ai import GoogleGenerativeAIEmbeddingsComponent
|
||||
from .huggingface_inference_api import HuggingFaceInferenceAPIEmbeddingsComponent
|
||||
from .lmstudioembeddings import LMStudioEmbeddingsComponent
|
||||
|
|
@ -21,6 +22,7 @@ __all__ = [
|
|||
"AzureOpenAIEmbeddingsComponent",
|
||||
"CloudflareWorkersAIEmbeddingsComponent",
|
||||
"CohereEmbeddingsComponent",
|
||||
"EmbeddingModelComponent",
|
||||
"EmbeddingSimilarityComponent",
|
||||
"GoogleGenerativeAIEmbeddingsComponent",
|
||||
"HuggingFaceInferenceAPIEmbeddingsComponent",
|
||||
|
|
|
|||
|
|
@ -0,0 +1,112 @@
|
|||
from typing import Any
|
||||
|
||||
from langchain_openai import OpenAIEmbeddings
|
||||
|
||||
from langflow.base.embeddings.model import LCEmbeddingsModel
|
||||
from langflow.base.models.openai_constants import OPENAI_EMBEDDING_MODEL_NAMES
|
||||
from langflow.field_typing import Embeddings
|
||||
from langflow.io import (
|
||||
BoolInput,
|
||||
DictInput,
|
||||
DropdownInput,
|
||||
FloatInput,
|
||||
IntInput,
|
||||
MessageTextInput,
|
||||
SecretStrInput,
|
||||
)
|
||||
from langflow.schema.dotdict import dotdict
|
||||
|
||||
|
||||
class EmbeddingModelComponent(LCEmbeddingsModel):
    """Component that builds an embeddings client for a selected provider.

    Only the "OpenAI" provider is handled in this class; any other provider
    value causes ``build_embeddings`` to raise ``ValueError``.
    """

    display_name = "Embedding Model"
    description = "Generate embeddings using a specified provider."
    icon = "binary"
    name = "EmbeddingModel"
    category = "embeddings"

    inputs = [
        # Provider selector; "OpenAI" is the single option and the default.
        DropdownInput(
            name="provider",
            display_name="Model Provider",
            info="Select the embedding model provider",
            options=["OpenAI"],
            options_metadata=[{"icon": "OpenAI"}],
            value="OpenAI",
            real_time_refresh=True,
        ),
        # Model selector; defaults to the first known OpenAI embedding model name.
        DropdownInput(
            name="model",
            display_name="Model Name",
            info="Select the embedding model to use",
            options=OPENAI_EMBEDDING_MODEL_NAMES,
            value=OPENAI_EMBEDDING_MODEL_NAMES[0],
        ),
        # Required credential; build_embeddings also rejects an empty value.
        SecretStrInput(
            name="api_key",
            display_name="OpenAI API Key",
            info="Model Provider API key",
            required=True,
            show=True,
            real_time_refresh=True,
        ),
        # Everything below is flagged as an advanced setting.
        MessageTextInput(
            name="api_base",
            display_name="API Base URL",
            info="Base URL for the API. Leave empty for default.",
            advanced=True,
        ),
        IntInput(
            name="dimensions",
            display_name="Dimensions",
            info="The number of dimensions the resulting output embeddings should have. "
            "Only supported by certain models.",
            advanced=True,
        ),
        IntInput(
            name="chunk_size",
            display_name="Chunk Size",
            advanced=True,
            value=1000,
        ),
        FloatInput(
            name="request_timeout",
            display_name="Request Timeout",
            advanced=True,
        ),
        IntInput(
            name="max_retries",
            display_name="Max Retries",
            advanced=True,
            value=3,
        ),
        BoolInput(
            name="show_progress_bar",
            display_name="Show Progress Bar",
            advanced=True,
        ),
        DictInput(
            name="model_kwargs",
            display_name="Model Kwargs",
            info="Additional keyword arguments to pass to the model.",
            advanced=True,
        ),
    ]

    def build_embeddings(self) -> Embeddings:
        """Create the embeddings client for the configured provider.

        Returns:
            An ``OpenAIEmbeddings`` instance configured from this component's inputs.

        Raises:
            ValueError: If the provider is not "OpenAI", or if it is "OpenAI"
                and no API key has been supplied.
        """
        selected_provider = self.provider

        # Collect every input up front. Falsy optional values (empty string, 0)
        # are converted to None, and missing model kwargs become an empty dict.
        client_options = {
            "model": self.model,
            "api_key": self.api_key,
            "base_url": self.api_base or None,
            "dimensions": self.dimensions or None,
            "chunk_size": self.chunk_size,
            "timeout": self.request_timeout or None,
            "max_retries": self.max_retries,
            "show_progress_bar": self.show_progress_bar,
            "model_kwargs": self.model_kwargs or {},
        }

        if selected_provider != "OpenAI":
            msg = f"Unknown provider: {selected_provider}"
            raise ValueError(msg)

        if not client_options["api_key"]:
            msg = "OpenAI API key is required when using OpenAI provider"
            raise ValueError(msg)

        return OpenAIEmbeddings(**client_options)

    def update_build_config(self, build_config: dotdict, field_value: Any, field_name: str | None = None) -> dotdict:
        """Adjust the build config when the provider field is set to "OpenAI".

        Args:
            build_config: Field configuration, indexed by field name.
            field_value: New value of the field that changed.
            field_name: Name of the field that changed, if any.

        Returns:
            The same ``build_config`` object, updated in place when the change
            is ``provider == "OpenAI"`` and otherwise untouched.
        """
        if field_name != "provider" or field_value != "OpenAI":
            return build_config

        # (field, attribute, new value) triples applied in order.
        openai_overrides = (
            ("model", "options", OPENAI_EMBEDDING_MODEL_NAMES),
            ("model", "value", OPENAI_EMBEDDING_MODEL_NAMES[0]),
            ("api_key", "display_name", "OpenAI API Key"),
            ("api_base", "display_name", "OpenAI API Base URL"),
        )
        for target_field, attribute, new_value in openai_overrides:
            build_config[target_field][attribute] = new_value

        return build_config
|
||||
0
src/backend/tests/unit/components/embeddings/__init__.py
Normal file
0
src/backend/tests/unit/components/embeddings/__init__.py
Normal file
|
|
@ -0,0 +1,89 @@
|
|||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import pytest
|
||||
from langflow.base.models.openai_constants import OPENAI_EMBEDDING_MODEL_NAMES
|
||||
from langflow.components.embeddings.embedding_model import EmbeddingModelComponent
|
||||
|
||||
from tests.base import ComponentTestBaseWithClient
|
||||
|
||||
|
||||
@pytest.mark.usefixtures("client")
class TestEmbeddingModelComponent(ComponentTestBaseWithClient):
    """Unit tests for ``EmbeddingModelComponent``."""

    @pytest.fixture
    def component_class(self):
        """Provide the component class under test."""
        return EmbeddingModelComponent

    @pytest.fixture
    def default_kwargs(self):
        """Provide the keyword arguments used to instantiate the component."""
        return {
            "provider": "OpenAI",
            "model": "text-embedding-3-small",
            "api_key": "test-api-key",
            "chunk_size": 1000,
            "max_retries": 3,
            "show_progress_bar": False,
        }

    @pytest.fixture
    def file_names_mapping(self):
        """Provide the version-specific file names mapping (none defined, so this yields ``None``)."""

    async def test_update_build_config_openai(self, component_class, default_kwargs):
        """Choosing the OpenAI provider fills in model options and OpenAI-specific labels."""
        initial_config = {
            "model": {"options": [], "value": ""},
            "api_key": {"display_name": "API Key"},
            "api_base": {"display_name": "API Base URL"},
        }
        component = component_class(**default_kwargs)

        result = component.update_build_config(initial_config, "OpenAI", "provider")

        model_field = result["model"]
        assert model_field["options"] == OPENAI_EMBEDDING_MODEL_NAMES
        assert model_field["value"] == OPENAI_EMBEDDING_MODEL_NAMES[0]
        assert result["api_key"]["display_name"] == "OpenAI API Key"
        assert result["api_base"]["display_name"] == "OpenAI API Base URL"

    @patch("langflow.components.embeddings.embedding_model.OpenAIEmbeddings")
    async def test_build_embeddings_openai(self, mock_openai_embeddings, component_class, default_kwargs):
        """``build_embeddings`` forwards the configured values to ``OpenAIEmbeddings``."""
        # Stand-in object returned by the patched constructor.
        fake_client = MagicMock()
        mock_openai_embeddings.return_value = fake_client

        # Instantiate the component, then set each attribute explicitly.
        component = component_class(**default_kwargs)
        attribute_values = {
            "provider": "OpenAI",
            "model": "text-embedding-3-small",
            "api_key": "test-key",
            "chunk_size": 1000,
            "max_retries": 3,
            "show_progress_bar": False,
        }
        for attr_name, attr_value in attribute_values.items():
            setattr(component, attr_name, attr_value)

        built = component.build_embeddings()

        # The constructor must have been called exactly once with these keyword arguments.
        expected_call = {
            "model": "text-embedding-3-small",
            "dimensions": None,
            "base_url": None,
            "api_key": "test-key",
            "chunk_size": 1000,
            "max_retries": 3,
            "timeout": None,
            "show_progress_bar": False,
            "model_kwargs": {},
        }
        mock_openai_embeddings.assert_called_once_with(**expected_call)
        assert built == fake_client

    async def test_build_embeddings_openai_missing_api_key(self, component_class, default_kwargs):
        """A missing API key with the OpenAI provider raises ``ValueError``."""
        component = component_class(**default_kwargs)
        component.provider = "OpenAI"
        component.api_key = None

        expected_message = "OpenAI API key is required when using OpenAI provider"
        with pytest.raises(ValueError, match=expected_message):
            component.build_embeddings()

    async def test_build_embeddings_unknown_provider(self, component_class, default_kwargs):
        """An unrecognised provider name raises ``ValueError``."""
        component = component_class(**default_kwargs)
        component.provider = "Unknown"

        expected_message = "Unknown provider: Unknown"
        with pytest.raises(ValueError, match=expected_message):
            component.build_embeddings()
|
||||
Loading…
Add table
Add a link
Reference in a new issue