refactor: Cohere and Mistral Embeddings, new Inputs/Outputs format

This commit is contained in:
Cezar Vasconcelos 2024-06-19 20:59:12 +00:00
commit 5d21466525
2 changed files with 77 additions and 79 deletions

View file

@ -1,38 +1,44 @@
from typing import Optional
from langchain_community.embeddings.cohere import CohereEmbeddings
from langflow.custom import CustomComponent
from langflow.base.models.model import LCModelComponent
from langflow.field_typing import Embeddings
from langflow.io import BoolInput, DictInput, DropdownInput, FloatInput, IntInput, Output, SecretStrInput, TextInput
# NOTE(review): two class headers in a row. The surrounding text looks like a flattened diff
# with its -/+ markers and indentation stripped; presumably the CustomComponent header is the
# pre-refactor line and the LCModelComponent header is its replacement - confirm.
class CohereEmbeddingsComponent(CustomComponent):
class CohereEmbeddingsComponent(LCModelComponent):
# Display metadata for the component.
display_name = "Cohere Embeddings"
description = "Generate embeddings using Cohere models."
icon = "Cohere"
# Declarative input fields. build_embeddings() reads each one back as self.<name>.
inputs = [
# The API key is the only field not flagged advanced=True.
SecretStrInput(name="cohere_api_key", display_name="Cohere API Key"),
# The selectable models are limited to the four v2.0 identifiers listed here.
DropdownInput(
name="model",
display_name="Model",
advanced=True,
options=[
"embed-english-v2.0",
"embed-multilingual-v2.0",
"embed-english-light-v2.0",
"embed-multilingual-light-v2.0",
],
value="embed-english-v2.0",
),
# truncate is declared without an explicit value=.
TextInput(name="truncate", display_name="Truncate", advanced=True),
# The 3-retry and "langchain" defaults match the defaults of the legacy build() signature.
IntInput(name="max_retries", display_name="Max Retries", value=3, advanced=True),
TextInput(name="user_agent", display_name="User Agent", advanced=True, value="langchain"),
# request_timeout is declared without an explicit value=.
FloatInput(name="request_timeout", display_name="Request Timeout", advanced=True),
]
def build_config(self):
    """Return the per-field configuration mapping for the Cohere component.

    The mapping is keyed by field name. The API key entry is flagged as a
    password, the model entry carries a default model name, and every entry
    except the API key is flagged ``advanced``.

    Returns:
        dict: field name -> dict of display options for that field.
    """
    config = {
        "cohere_api_key": dict(display_name="Cohere API Key", password=True),
        "model": dict(display_name="Model", default="embed-english-v2.0", advanced=True),
    }
    # The remaining fields share one shape: a label plus the advanced flag.
    advanced_labels = (
        ("truncate", "Truncate"),
        ("max_retries", "Max Retries"),
        ("user_agent", "User Agent"),
        ("request_timeout", "Request Timeout"),
    )
    for field_name, label in advanced_labels:
        config[field_name] = dict(display_name=label, advanced=True)
    return config
# Declares a single output named "embeddings", bound by name to the build_embeddings method.
outputs = [
Output(display_name="Embeddings", name="embeddings", method="build_embeddings"),
]
# NOTE(review): presumably the pre-refactor entry point (removed side of the diff) - confirm.
# It receives every setting as a keyword parameter rather than reading it from self, and
# forwards all six values to CohereEmbeddings unchanged.
def build(
self,
request_timeout: Optional[float] = None,
cohere_api_key: str = "",
max_retries: int = 3,
model: str = "embed-english-v2.0",
truncate: Optional[str] = None,
user_agent: str = "langchain",
) -> CohereEmbeddings:
return CohereEmbeddings( # type: ignore
max_retries=max_retries,
user_agent=user_agent,
request_timeout=request_timeout,
cohere_api_key=cohere_api_key,
model=model,
truncate=truncate,
# NOTE(review): this call has no closing parenthesis within this method's own lines; the
# next ")" in the file comes only after the build_embeddings call that follows.
def build_embeddings(self) -> Embeddings:
    """Build a Cohere embeddings client from this component's input values.

    Reads ``cohere_api_key``, ``model``, ``truncate``, ``max_retries``,
    ``user_agent`` and ``request_timeout`` from ``self`` and passes them to
    ``CohereEmbeddings``.

    Returns:
        The configured ``CohereEmbeddings`` instance.
    """
    return CohereEmbeddings(
        cohere_api_key=self.cohere_api_key,
        model=self.model,
        # A falsy value (e.g. an unset text field) is forwarded as None, the
        # default the legacy build() signature used for truncate, and the same
        # normalisation already applied to request_timeout below.
        truncate=self.truncate or None,
        max_retries=self.max_retries,
        user_agent=self.user_agent,
        # Falsy timeouts (unset or 0) are forwarded as None.
        request_timeout=self.request_timeout or None,
    )

View file

@ -1,64 +1,56 @@
from langchain_mistralai.embeddings import MistralAIEmbeddings
from pydantic.v1 import SecretStr
from langflow.custom import CustomComponent
from langflow.base.models.model import LCModelComponent
from langflow.field_typing import Embeddings
from langflow.io import DropdownInput, IntInput, Output, SecretStrInput, TextInput
# NOTE(review): two class headers in a row. The surrounding text looks like a flattened diff
# with its -/+ markers and indentation stripped; presumably the CustomComponent header is the
# pre-refactor line and the LCModelComponent header is its replacement - confirm.
class MistralAIEmbeddingsComponent(CustomComponent):
class MistralAIEmbeddingsComponent(LCModelComponent):
# Display metadata for the component.
display_name = "MistralAI Embeddings"
description = "Generate embeddings using MistralAI models."
icon = "MistralAI"
def build_config(self):
    """Return the per-field configuration mapping for the MistralAI component.

    The model and API key entries are flagged ``advanced=False``; the four
    tuning entries (concurrency, retries, timeout, endpoint) are flagged
    ``advanced=True`` and each carries a default ``value``.

    Returns:
        dict: field name -> dict of display options for that field.
    """
    primary = {
        "model": dict(
            display_name="Model",
            advanced=False,
            options=["mistral-embed"],
            value="mistral-embed",
        ),
        "mistral_api_key": dict(
            display_name="Mistral API Key",
            password=True,
            advanced=False,
        ),
    }
    # (field name, label, default value) for the entries flagged advanced.
    tuning_defaults = [
        ("max_concurrent_requests", "Max Concurrent Requests", 64),
        ("max_retries", "Max Retries", 5),
        ("timeout", "Request Timeout", 120),
        ("endpoint", "API Endpoint", "https://api.mistral.ai/v1/"),
    ]
    tuning = {
        field_name: dict(display_name=label, advanced=True, value=default)
        for field_name, label, default in tuning_defaults
    }
    return {**primary, **tuning}
# Declarative input fields. build_embeddings() reads each one back as self.<name>.
# The names and default values mirror the entries returned by build_config().
inputs = [
# "mistral-embed" is both the only option and the default.
DropdownInput(
name="model",
display_name="Model",
advanced=False,
options=["mistral-embed"],
value="mistral-embed",
),
# Declared without a default; build_embeddings() raises ValueError when it is falsy.
SecretStrInput(name="mistral_api_key", display_name="Mistral API Key"),
IntInput(
name="max_concurrent_requests",
display_name="Max Concurrent Requests",
advanced=True,
value=64,
),
IntInput(name="max_retries", display_name="Max Retries", advanced=True, value=5),
# NOTE(review): the unit of the 120 default is not stated here - presumably seconds; confirm.
IntInput(name="timeout", display_name="Request Timeout", advanced=True, value=120),
TextInput(
name="endpoint",
display_name="API Endpoint",
advanced=True,
value="https://api.mistral.ai/v1/",
),
]
# NOTE(review): presumably the pre-refactor entry point (removed side of the diff) - confirm.
# It receives every setting as a parameter rather than reading it from self.
def build(
self,
mistral_api_key: str,
model: str = "mistral-embed",
max_concurrent_requests: int = 64,
max_retries: int = 5,
timeout: int = 120,
endpoint: str = "https://api.mistral.ai/v1/",
) -> Embeddings:
# A non-empty key is wrapped in SecretStr; an empty or missing key leaves api_key as None
# instead of raising.
if mistral_api_key:
api_key = SecretStr(mistral_api_key)
else:
api_key = None
# NOTE(review): no return statement appears within this method's own lines; presumably the
# MistralAIEmbeddings(...) keyword lines further down that use the bare names model,
# endpoint, max_concurrent_requests, max_retries and timeout belonged to it - confirm.
# Declares a single output named "embeddings", bound by name to the build_embeddings method.
outputs = [
Output(display_name="Embeddings", name="embeddings", method="build_embeddings"),
]
# Builds a MistralAIEmbeddings instance from the input values stored on self.
# Raises ValueError("Mistral API Key is required") when self.mistral_api_key is falsy.
def build_embeddings(self) -> Embeddings:
if not self.mistral_api_key:
raise ValueError("Mistral API Key is required")
# The key is wrapped in pydantic's SecretStr before being handed to the client.
api_key = SecretStr(self.mistral_api_key)
return MistralAIEmbeddings(
api_key=api_key,
# NOTE(review): the next five keywords use bare names that are not defined in this method
# and repeat keywords that are set again from self.* below. Presumably they are removed-side
# diff lines from the legacy build(), whose parameters carry these names - confirm.
model=model,
endpoint=endpoint,
max_concurrent_requests=max_concurrent_requests,
max_retries=max_retries,
timeout=timeout,
# These keywords read the component's own input values.
model=self.model,
endpoint=self.endpoint,
max_concurrent_requests=self.max_concurrent_requests,
max_retries=self.max_retries,
timeout=self.timeout,
)