From a16df4b2bf06e7eb224f829f90042f07616f8430 Mon Sep 17 00:00:00 2001 From: Thiago Araujo Date: Sat, 29 Mar 2025 05:57:33 -0300 Subject: [PATCH] feat: Add watsonx embedding component (#7292) * Add draft watsonx component * feat: improved logic for embedding * Add small changes to the ibm embedding component * Fix icon and logger mode --------- Co-authored-by: Thiago Araujo Co-authored-by: Giovanni-Galatro Co-authored-by: Jordan Frazier <122494242+jordanrfrazier@users.noreply.github.com> Co-authored-by: galatro --- .../components/embeddings/__init__.py | 2 + .../langflow/components/embeddings/watsonx.py | 136 ++++++++++++++++++ 2 files changed, 138 insertions(+) create mode 100644 src/backend/base/langflow/components/embeddings/watsonx.py diff --git a/src/backend/base/langflow/components/embeddings/__init__.py b/src/backend/base/langflow/components/embeddings/__init__.py index e0ab27f16..fd8213e2d 100644 --- a/src/backend/base/langflow/components/embeddings/__init__.py +++ b/src/backend/base/langflow/components/embeddings/__init__.py @@ -14,6 +14,7 @@ from .openai import OpenAIEmbeddingsComponent from .similarity import EmbeddingSimilarityComponent from .text_embedder import TextEmbedderComponent from .vertexai import VertexAIEmbeddingsComponent +from .watsonx import WatsonxEmbeddingsComponent __all__ = [ "AIMLEmbeddingsComponent", @@ -32,4 +33,5 @@ __all__ = [ "OpenAIEmbeddingsComponent", "TextEmbedderComponent", "VertexAIEmbeddingsComponent", + "WatsonxEmbeddingsComponent", ] diff --git a/src/backend/base/langflow/components/embeddings/watsonx.py b/src/backend/base/langflow/components/embeddings/watsonx.py new file mode 100644 index 000000000..9685fef78 --- /dev/null +++ b/src/backend/base/langflow/components/embeddings/watsonx.py @@ -0,0 +1,136 @@ +import logging +from typing import Any + +import requests +from ibm_watsonx_ai import APIClient, Credentials +from ibm_watsonx_ai.metanames import EmbedTextParamsMetaNames +from langchain_ibm import WatsonxEmbeddings 
from pydantic.v1 import SecretStr

from langflow.base.embeddings.model import LCEmbeddingsModel
from langflow.field_typing import Embeddings
from langflow.io import BoolInput, DropdownInput, IntInput, SecretStrInput, StrInput
from langflow.schema.dotdict import dotdict

# NOTE(review): basicConfig configures the *root* logger at import time, which
# affects the whole host process. Kept for compatibility with the original
# module behavior -- consider removing it and relying on the application's
# logging setup instead.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)


class WatsonxEmbeddingsComponent(LCEmbeddingsModel):
    """Langflow component that builds a ``WatsonxEmbeddings`` model for IBM watsonx.ai."""

    display_name = "IBM watsonx.ai Embeddings"
    description = "Generate embeddings using IBM watsonx.ai models."
    icon = "WatsonxAI"
    name = "WatsonxEmbeddingsComponent"

    # models present in all the regions; used as a fallback when the model
    # list cannot be fetched from the selected endpoint
    _default_models = [
        "sentence-transformers/all-minilm-l12-v2",
        "ibm/slate-125m-english-rtrvr-v2",
        "ibm/slate-30m-english-rtrvr-v2",
        "intfloat/multilingual-e5-large",
    ]

    inputs = [
        DropdownInput(
            name="url",
            display_name="watsonx API Endpoint",
            info="The base URL of the API.",
            value=None,
            options=[
                "https://us-south.ml.cloud.ibm.com",
                "https://eu-de.ml.cloud.ibm.com",
                "https://eu-gb.ml.cloud.ibm.com",
                "https://au-syd.ml.cloud.ibm.com",
                "https://jp-tok.ml.cloud.ibm.com",
                "https://ca-tor.ml.cloud.ibm.com",
            ],
            # changing the endpoint triggers update_build_config so the model
            # dropdown can be repopulated
            real_time_refresh=True,
        ),
        StrInput(
            name="project_id",
            display_name="watsonx project id",
        ),
        SecretStrInput(
            name="api_key",
            display_name="API Key",
            info="The API Key to use for the model.",
            required=True,
        ),
        DropdownInput(
            name="model_name",
            display_name="Model Name",
            # options are filled in by update_build_config once a URL is chosen
            options=[],
            value=None,
            dynamic=True,
            required=True,
        ),
        IntInput(
            name="truncate_input_tokens",
            display_name="Truncate Input Tokens",
            advanced=True,
            value=200,
        ),
        BoolInput(
            name="input_text",
            display_name="Include the original text in the output",
            value=True,
            advanced=True,
        ),
    ]

    @staticmethod
    def fetch_models(base_url: str) -> list[str]:
        """Fetch the embedding model ids available at a watsonx.ai endpoint.

        Args:
            base_url: Base URL of the regional watsonx.ai API.

        Returns:
            The sorted model ids reported by the endpoint. If the request
            fails, the response cannot be parsed, or the endpoint reports no
            models, a fresh copy of the default model list is returned
            instead, so the result is never empty.
        """
        # Hand out a copy: the caller stores the result as dropdown options,
        # and the shared class-level defaults must not be mutated through it.
        fallback = list(WatsonxEmbeddingsComponent._default_models)
        try:
            endpoint = f"{base_url}/ml/v1/foundation_model_specs"
            params = {
                "version": "2024-09-16",
                "filters": "function_embedding,!lifecycle_withdrawn:and",
            }
            response = requests.get(endpoint, params=params, timeout=10)
            response.raise_for_status()
            data = response.json()
            models = sorted(model["model_id"] for model in data.get("resources", []))
        except Exception:
            # Best effort: any network/HTTP/parsing problem falls back to the
            # defaults so the UI still has something to offer.
            logger.exception("Error fetching models")
            return fallback

        if not models:
            logger.warning("No embedding models returned by %s; using default models", base_url)
            return fallback
        return models

    def update_build_config(self, build_config: dotdict, field_value: Any, field_name: str | None = None):
        """Refresh the model dropdown when the API endpoint URL changes.

        Args:
            build_config: The component's build configuration; mutated in place.
            field_value: The new value of the field that changed.
            field_name: Name of the field that changed. Only ``"url"`` is handled.

        Returns:
            The (possibly updated) ``build_config``.
        """
        logger.debug(
            "Updating build config. Field name: %s, Field value: %s",
            field_name,
            field_value,
        )

        if field_name == "url" and field_value:
            try:
                models = self.fetch_models(base_url=build_config.url.value)
                build_config.model_name.options = models
                # Keep the user's selection when the new endpoint still offers
                # it; only fall back to the first model when it is gone.
                selected = build_config.model_name.value
                if selected and selected not in models:
                    build_config.model_name.value = models[0] if models else None
                logger.info(
                    "Updated model options: %s models found in %s",
                    len(models),
                    build_config.url.value,
                )
            except Exception:
                # A failed refresh must not break the UI update; leave the
                # config as it is and report the problem.
                logger.exception("Error updating model options.")

        return build_config

    def build_embeddings(self) -> Embeddings:
        """Create the ``WatsonxEmbeddings`` instance from the component inputs.

        Returns:
            A ``WatsonxEmbeddings`` model bound to an ``APIClient`` built from
            the configured API key and endpoint URL.
        """
        credentials = Credentials(
            api_key=SecretStr(self.api_key).get_secret_value(),
            url=self.url,
        )

        api_client = APIClient(credentials)

        params = {
            EmbedTextParamsMetaNames.TRUNCATE_INPUT_TOKENS: self.truncate_input_tokens,
            EmbedTextParamsMetaNames.RETURN_OPTIONS: {"input_text": self.input_text},
        }

        return WatsonxEmbeddings(
            model_id=self.model_name,
            params=params,
            watsonx_client=api_client,
            project_id=self.project_id,
        )