refactor: Cohere and Mistral Embeddings, new Inputs/Outputs format
This commit is contained in:
parent
d287711f1b
commit
5d21466525
2 changed files with 77 additions and 79 deletions
|
|
@@ -1,38 +1,44 @@
|
|||
from typing import Optional
|
||||
|
||||
from langchain_community.embeddings.cohere import CohereEmbeddings
|
||||
|
||||
from langflow.custom import CustomComponent
|
||||
from langflow.base.models.model import LCModelComponent
|
||||
from langflow.field_typing import Embeddings
|
||||
from langflow.io import BoolInput, DictInput, DropdownInput, FloatInput, IntInput, Output, SecretStrInput, TextInput
|
||||
|
||||
|
||||
class CohereEmbeddingsComponent(LCModelComponent):
    """Langflow component that builds a Cohere text-embedding model.

    Uses the declarative Inputs/Outputs format: configuration fields are
    declared in ``inputs`` and the ``embeddings`` output is produced by
    :meth:`build_embeddings`.
    """

    display_name = "Cohere Embeddings"
    description = "Generate embeddings using Cohere models."
    icon = "Cohere"

    inputs = [
        SecretStrInput(name="cohere_api_key", display_name="Cohere API Key"),
        DropdownInput(
            name="model",
            display_name="Model",
            advanced=True,
            options=[
                "embed-english-v2.0",
                "embed-multilingual-v2.0",
                "embed-english-light-v2.0",
                "embed-multilingual-light-v2.0",
            ],
            value="embed-english-v2.0",
        ),
        TextInput(name="truncate", display_name="Truncate", advanced=True),
        IntInput(name="max_retries", display_name="Max Retries", value=3, advanced=True),
        TextInput(name="user_agent", display_name="User Agent", advanced=True, value="langchain"),
        FloatInput(name="request_timeout", display_name="Request Timeout", advanced=True),
    ]

    outputs = [
        Output(display_name="Embeddings", name="embeddings", method="build_embeddings"),
    ]

    def build_embeddings(self) -> Embeddings:
        """Instantiate a ``CohereEmbeddings`` model from the component's inputs.

        Returns:
            Embeddings: a configured Cohere embeddings client.
        """
        return CohereEmbeddings(
            cohere_api_key=self.cohere_api_key,
            model=self.model,
            truncate=self.truncate,
            max_retries=self.max_retries,
            user_agent=self.user_agent,
            # An unset FloatInput is falsy; pass None so the client library
            # falls back to its own default timeout instead of 0 seconds.
            request_timeout=self.request_timeout or None,
        )
|
||||
|
|
|
|||
|
|
@@ -1,64 +1,56 @@
|
|||
from langchain_mistralai.embeddings import MistralAIEmbeddings
|
||||
from pydantic.v1 import SecretStr
|
||||
|
||||
from langflow.custom import CustomComponent
|
||||
from langflow.base.models.model import LCModelComponent
|
||||
from langflow.field_typing import Embeddings
|
||||
from langflow.io import DropdownInput, IntInput, Output, SecretStrInput, TextInput
|
||||
|
||||
|
||||
class MistralAIEmbeddingsComponent(LCModelComponent):
    """Langflow component that builds a MistralAI text-embedding model.

    Uses the declarative Inputs/Outputs format: configuration fields are
    declared in ``inputs`` and the ``embeddings`` output is produced by
    :meth:`build_embeddings`.
    """

    display_name = "MistralAI Embeddings"
    description = "Generate embeddings using MistralAI models."
    icon = "MistralAI"

    inputs = [
        DropdownInput(
            name="model",
            display_name="Model",
            advanced=False,
            options=["mistral-embed"],
            value="mistral-embed",
        ),
        SecretStrInput(name="mistral_api_key", display_name="Mistral API Key"),
        IntInput(
            name="max_concurrent_requests",
            display_name="Max Concurrent Requests",
            advanced=True,
            value=64,
        ),
        IntInput(name="max_retries", display_name="Max Retries", advanced=True, value=5),
        IntInput(name="timeout", display_name="Request Timeout", advanced=True, value=120),
        TextInput(
            name="endpoint",
            display_name="API Endpoint",
            advanced=True,
            value="https://api.mistral.ai/v1/",
        ),
    ]

    outputs = [
        Output(display_name="Embeddings", name="embeddings", method="build_embeddings"),
    ]

    def build_embeddings(self) -> Embeddings:
        """Instantiate a ``MistralAIEmbeddings`` model from the component's inputs.

        Returns:
            Embeddings: a configured MistralAI embeddings client.

        Raises:
            ValueError: if no API key was provided.
        """
        if not self.mistral_api_key:
            raise ValueError("Mistral API Key is required")

        # langchain-mistralai expects the key wrapped in a pydantic SecretStr.
        api_key = SecretStr(self.mistral_api_key)

        return MistralAIEmbeddings(
            api_key=api_key,
            model=self.model,
            endpoint=self.endpoint,
            max_concurrent_requests=self.max_concurrent_requests,
            max_retries=self.max_retries,
            timeout=self.timeout,
        )
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue