diff --git a/src/backend/base/langflow/components/embeddings/CohereEmbeddings.py b/src/backend/base/langflow/components/embeddings/CohereEmbeddings.py index 23d855f40..f5cb7a3b4 100644 --- a/src/backend/base/langflow/components/embeddings/CohereEmbeddings.py +++ b/src/backend/base/langflow/components/embeddings/CohereEmbeddings.py @@ -1,38 +1,44 @@ -from typing import Optional - from langchain_community.embeddings.cohere import CohereEmbeddings -from langflow.custom import CustomComponent +from langflow.base.models.model import LCModelComponent +from langflow.field_typing import Embeddings +from langflow.io import BoolInput, DictInput, DropdownInput, FloatInput, IntInput, Output, SecretStrInput, TextInput -class CohereEmbeddingsComponent(CustomComponent): +class CohereEmbeddingsComponent(LCModelComponent): display_name = "Cohere Embeddings" description = "Generate embeddings using Cohere models." + icon = "Cohere" + inputs = [ + SecretStrInput(name="cohere_api_key", display_name="Cohere API Key"), + DropdownInput( + name="model", + display_name="Model", + advanced=True, + options=[ + "embed-english-v2.0", + "embed-multilingual-v2.0", + "embed-english-light-v2.0", + "embed-multilingual-light-v2.0", + ], + value="embed-english-v2.0", + ), + TextInput(name="truncate", display_name="Truncate", advanced=True), + IntInput(name="max_retries", display_name="Max Retries", value=3, advanced=True), + TextInput(name="user_agent", display_name="User Agent", advanced=True, value="langchain"), + FloatInput(name="request_timeout", display_name="Request Timeout", advanced=True), + ] - def build_config(self): - return { - "cohere_api_key": {"display_name": "Cohere API Key", "password": True}, - "model": {"display_name": "Model", "default": "embed-english-v2.0", "advanced": True}, - "truncate": {"display_name": "Truncate", "advanced": True}, - "max_retries": {"display_name": "Max Retries", "advanced": True}, - "user_agent": {"display_name": "User Agent", "advanced": True}, - "request_timeout": {"display_name": "Request Timeout", "advanced": True}, - } + outputs = [ + Output(display_name="Embeddings", name="embeddings", method="build_embeddings"), + ] - def build( - self, - request_timeout: Optional[float] = None, - cohere_api_key: str = "", - max_retries: int = 3, - model: str = "embed-english-v2.0", - truncate: Optional[str] = None, - user_agent: str = "langchain", - ) -> CohereEmbeddings: - return CohereEmbeddings( # type: ignore - max_retries=max_retries, - user_agent=user_agent, - request_timeout=request_timeout, - cohere_api_key=cohere_api_key, - model=model, - truncate=truncate, + def build_embeddings(self) -> Embeddings: + return CohereEmbeddings( + cohere_api_key=self.cohere_api_key, + model=self.model, + truncate=self.truncate, + max_retries=self.max_retries, + user_agent=self.user_agent, + request_timeout=self.request_timeout or None, ) diff --git a/src/backend/base/langflow/components/embeddings/MistalAIEmbeddings.py b/src/backend/base/langflow/components/embeddings/MistalAIEmbeddings.py index d24c9fb30..145aed76d 100644 --- a/src/backend/base/langflow/components/embeddings/MistalAIEmbeddings.py +++ b/src/backend/base/langflow/components/embeddings/MistalAIEmbeddings.py @@ -1,64 +1,56 @@ from langchain_mistralai.embeddings import MistralAIEmbeddings from pydantic.v1 import SecretStr -from langflow.custom import CustomComponent +from langflow.base.models.model import LCModelComponent from langflow.field_typing import Embeddings +from langflow.io import DropdownInput, IntInput, Output, SecretStrInput, TextInput -class MistralAIEmbeddingsComponent(CustomComponent): +class MistralAIEmbeddingsComponent(LCModelComponent): display_name = "MistralAI Embeddings" description = "Generate embeddings using MistralAI models." + icon = "MistralAI" - def build_config(self): - return { - "model": { - "display_name": "Model", - "advanced": False, - "options": ["mistral-embed"], - "value": "mistral-embed", - }, - "mistral_api_key": { - "display_name": "Mistral API Key", - "password": True, - "advanced": False, - }, - "max_concurrent_requests": { - "display_name": "Max Concurrent Requests", - "advanced": True, - "value": 64, - }, - "max_retries": { - "display_name": "Max Retries", - "advanced": True, - "value": 5, - }, - "timeout": { - "display_name": "Request Timeout", - "advanced": True, - "value": 120, - }, - "endpoint": {"display_name": "API Endpoint", "advanced": True, "value": "https://api.mistral.ai/v1/"}, - } + inputs = [ + DropdownInput( + name="model", + display_name="Model", + advanced=False, + options=["mistral-embed"], + value="mistral-embed", + ), + SecretStrInput(name="mistral_api_key", display_name="Mistral API Key"), + IntInput( + name="max_concurrent_requests", + display_name="Max Concurrent Requests", + advanced=True, + value=64, + ), + IntInput(name="max_retries", display_name="Max Retries", advanced=True, value=5), + IntInput(name="timeout", display_name="Request Timeout", advanced=True, value=120), + TextInput( + name="endpoint", + display_name="API Endpoint", + advanced=True, + value="https://api.mistral.ai/v1/", + ), + ] - def build( - self, - mistral_api_key: str, - model: str = "mistral-embed", - max_concurrent_requests: int = 64, - max_retries: int = 5, - timeout: int = 120, - endpoint: str = "https://api.mistral.ai/v1/", - ) -> Embeddings: - if mistral_api_key: - api_key = SecretStr(mistral_api_key) - else: - api_key = None + outputs = [ + Output(display_name="Embeddings", name="embeddings", method="build_embeddings"), + ] + + def build_embeddings(self) -> Embeddings: + if not self.mistral_api_key: + raise ValueError("Mistral API Key is required") + + api_key = SecretStr(self.mistral_api_key) return MistralAIEmbeddings( api_key=api_key, - model=model, - endpoint=endpoint, - max_concurrent_requests=max_concurrent_requests, - max_retries=max_retries, - timeout=timeout, + model=self.model, + endpoint=self.endpoint, + max_concurrent_requests=self.max_concurrent_requests, + max_retries=self.max_retries, + timeout=self.timeout, )