diff --git a/src/backend/base/langflow/components/embeddings/__init__.py b/src/backend/base/langflow/components/embeddings/__init__.py index f2cb48295..0daa25324 100644 --- a/src/backend/base/langflow/components/embeddings/__init__.py +++ b/src/backend/base/langflow/components/embeddings/__init__.py @@ -3,6 +3,7 @@ from .astra_vectorize import AstraVectorizeComponent from .azure_openai import AzureOpenAIEmbeddingsComponent from .cloudflare import CloudflareWorkersAIEmbeddingsComponent from .cohere import CohereEmbeddingsComponent +from .embedding_model import EmbeddingModelComponent from .google_generative_ai import GoogleGenerativeAIEmbeddingsComponent from .huggingface_inference_api import HuggingFaceInferenceAPIEmbeddingsComponent from .lmstudioembeddings import LMStudioEmbeddingsComponent @@ -21,6 +22,7 @@ __all__ = [ "AzureOpenAIEmbeddingsComponent", "CloudflareWorkersAIEmbeddingsComponent", "CohereEmbeddingsComponent", + "EmbeddingModelComponent", "EmbeddingSimilarityComponent", "GoogleGenerativeAIEmbeddingsComponent", "HuggingFaceInferenceAPIEmbeddingsComponent", diff --git a/src/backend/base/langflow/components/embeddings/embedding_model.py b/src/backend/base/langflow/components/embeddings/embedding_model.py new file mode 100644 index 000000000..bf668532f --- /dev/null +++ b/src/backend/base/langflow/components/embeddings/embedding_model.py @@ -0,0 +1,112 @@ +from typing import Any + +from langchain_openai import OpenAIEmbeddings + +from langflow.base.embeddings.model import LCEmbeddingsModel +from langflow.base.models.openai_constants import OPENAI_EMBEDDING_MODEL_NAMES +from langflow.field_typing import Embeddings +from langflow.io import ( + BoolInput, + DictInput, + DropdownInput, + FloatInput, + IntInput, + MessageTextInput, + SecretStrInput, +) +from langflow.schema.dotdict import dotdict + + +class EmbeddingModelComponent(LCEmbeddingsModel): + display_name = "Embedding Model" + description = "Generate embeddings using a specified provider." + icon = "binary" + name = "EmbeddingModel" + category = "embeddings" + + inputs = [ + DropdownInput( + name="provider", + display_name="Model Provider", + options=["OpenAI"], + value="OpenAI", + info="Select the embedding model provider", + real_time_refresh=True, + options_metadata=[{"icon": "OpenAI"}], + ), + DropdownInput( + name="model", + display_name="Model Name", + options=OPENAI_EMBEDDING_MODEL_NAMES, + value=OPENAI_EMBEDDING_MODEL_NAMES[0], + info="Select the embedding model to use", + ), + SecretStrInput( + name="api_key", + display_name="OpenAI API Key", + info="Model Provider API key", + required=True, + show=True, + real_time_refresh=True, + ), + MessageTextInput( + name="api_base", + display_name="API Base URL", + info="Base URL for the API. Leave empty for default.", + advanced=True, + ), + IntInput( + name="dimensions", + display_name="Dimensions", + info="The number of dimensions the resulting output embeddings should have. Only supported by certain models.", + advanced=True, + ), + IntInput(name="chunk_size", display_name="Chunk Size", advanced=True, value=1000), + FloatInput(name="request_timeout", display_name="Request Timeout", advanced=True), + IntInput(name="max_retries", display_name="Max Retries", advanced=True, value=3), + BoolInput(name="show_progress_bar", display_name="Show Progress Bar", advanced=True), + DictInput( + name="model_kwargs", + display_name="Model Kwargs", + advanced=True, + info="Additional keyword arguments to pass to the model.", + ), + ] + + def build_embeddings(self) -> Embeddings: + provider = self.provider + model = self.model + api_key = self.api_key + api_base = self.api_base + dimensions = self.dimensions + chunk_size = self.chunk_size + request_timeout = self.request_timeout + max_retries = self.max_retries + show_progress_bar = self.show_progress_bar + model_kwargs = self.model_kwargs or {} + + if provider == "OpenAI": + if not api_key: + msg = "OpenAI API key is required when using OpenAI provider" + raise ValueError(msg) + return OpenAIEmbeddings( + model=model, + dimensions=dimensions or None, + base_url=api_base or None, + api_key=api_key, + chunk_size=chunk_size, + max_retries=max_retries, + timeout=request_timeout or None, + show_progress_bar=show_progress_bar, + model_kwargs=model_kwargs, + ) + msg = f"Unknown provider: {provider}" + raise ValueError(msg) + + def update_build_config(self, build_config: dotdict, field_value: Any, field_name: str | None = None) -> dotdict: + if field_name == "provider" and field_value == "OpenAI": + build_config["model"]["options"] = OPENAI_EMBEDDING_MODEL_NAMES + build_config["model"]["value"] = OPENAI_EMBEDDING_MODEL_NAMES[0] + build_config["api_key"]["display_name"] = "OpenAI API Key" + build_config["api_base"]["display_name"] = "OpenAI API Base URL" + return build_config diff --git a/src/backend/tests/unit/components/embeddings/__init__.py b/src/backend/tests/unit/components/embeddings/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/src/backend/tests/unit/components/embeddings/test_embedding_model_component.py b/src/backend/tests/unit/components/embeddings/test_embedding_model_component.py new file mode 100644 index 000000000..308a00a2d --- /dev/null +++ b/src/backend/tests/unit/components/embeddings/test_embedding_model_component.py @@ -0,0 +1,89 @@ +from unittest.mock import MagicMock, patch + +import pytest +from langflow.base.models.openai_constants import OPENAI_EMBEDDING_MODEL_NAMES +from langflow.components.embeddings.embedding_model import EmbeddingModelComponent + +from tests.base import ComponentTestBaseWithClient + + +@pytest.mark.usefixtures("client") +class TestEmbeddingModelComponent(ComponentTestBaseWithClient): + @pytest.fixture + def component_class(self): + return EmbeddingModelComponent + + @pytest.fixture + def default_kwargs(self): + return { + "provider": "OpenAI", + "model": "text-embedding-3-small", + "api_key": "test-api-key", + "chunk_size": 1000, + "max_retries": 3, + "show_progress_bar": False, + } + + @pytest.fixture + def file_names_mapping(self): + """Return the file names mapping for version-specific files.""" + + async def test_update_build_config_openai(self, component_class, default_kwargs): + component = component_class(**default_kwargs) + build_config = { + "model": {"options": [], "value": ""}, + "api_key": {"display_name": "API Key"}, + "api_base": {"display_name": "API Base URL"}, + } + updated_config = component.update_build_config(build_config, "OpenAI", "provider") + assert updated_config["model"]["options"] == OPENAI_EMBEDDING_MODEL_NAMES + assert updated_config["model"]["value"] == OPENAI_EMBEDDING_MODEL_NAMES[0] + assert updated_config["api_key"]["display_name"] == "OpenAI API Key" + assert updated_config["api_base"]["display_name"] == "OpenAI API Base URL" + + @patch("langflow.components.embeddings.embedding_model.OpenAIEmbeddings") + async def test_build_embeddings_openai(self, mock_openai_embeddings, component_class, default_kwargs): + # Setup mock + mock_instance = MagicMock() + mock_openai_embeddings.return_value = mock_instance + + # Create and configure the component + component = component_class(**default_kwargs) + component.provider = "OpenAI" + component.model = "text-embedding-3-small" + component.api_key = "test-key" + component.chunk_size = 1000 + component.max_retries = 3 + component.show_progress_bar = False + + # Build the embeddings + embeddings = component.build_embeddings() + + # Verify the OpenAIEmbeddings was called with the correct parameters + mock_openai_embeddings.assert_called_once_with( + model="text-embedding-3-small", + dimensions=None, + base_url=None, + api_key="test-key", + chunk_size=1000, + max_retries=3, + timeout=None, + show_progress_bar=False, + model_kwargs={}, + ) + assert embeddings == mock_instance + + async def test_build_embeddings_openai_missing_api_key(self, component_class, default_kwargs): + component = component_class(**default_kwargs) + component.provider = "OpenAI" + component.api_key = None + + with pytest.raises(ValueError, match="OpenAI API key is required when using OpenAI provider"): + component.build_embeddings() + + async def test_build_embeddings_unknown_provider(self, component_class, default_kwargs): + component = component_class(**default_kwargs) + component.provider = "Unknown" + + with pytest.raises(ValueError, match="Unknown provider: Unknown"): + component.build_embeddings()