feat: update Cohere embedding model to use langchain_cohere, added support to dynamically load latest embedding models, improved error handling (#6034)
* update cohere model * Update src/backend/base/langflow/components/embeddings/cohere.py Co-authored-by: Gabriel Luiz Freitas Almeida <gabriel@langflow.org> --------- Co-authored-by: Gabriel Luiz Freitas Almeida <gabriel@langflow.org>
This commit is contained in:
parent
5f63ca0d24
commit
e89edc3c1c
2 changed files with 47 additions and 12 deletions
|
|
@ -109,7 +109,7 @@ dependencies = [
|
|||
"uv>=0.5.7",
|
||||
"ag2>=0.1.0",
|
||||
"scrapegraph-py>=1.10.2",
|
||||
"pydantic-ai>=0.0.19"
|
||||
"pydantic-ai>=0.0.19",
|
||||
]
|
||||
|
||||
[tool.uv.sources]
|
||||
|
|
|
|||
|
|
@ -1,9 +1,14 @@
|
|||
from langchain_community.embeddings.cohere import CohereEmbeddings
|
||||
from typing import Any
|
||||
|
||||
import cohere
|
||||
from langchain_cohere import CohereEmbeddings
|
||||
|
||||
from langflow.base.models.model import LCModelComponent
|
||||
from langflow.field_typing import Embeddings
|
||||
from langflow.io import DropdownInput, FloatInput, IntInput, MessageTextInput, Output, SecretStrInput
|
||||
|
||||
HTTP_STATUS_OK = 200
|
||||
|
||||
|
||||
class CohereEmbeddingsComponent(LCModelComponent):
|
||||
display_name = "Cohere Embeddings"
|
||||
|
|
@ -12,9 +17,9 @@ class CohereEmbeddingsComponent(LCModelComponent):
|
|||
name = "CohereEmbeddings"
|
||||
|
||||
inputs = [
|
||||
SecretStrInput(name="cohere_api_key", display_name="Cohere API Key", required=True),
|
||||
SecretStrInput(name="api_key", display_name="Cohere API Key", required=True, real_time_refresh=True),
|
||||
DropdownInput(
|
||||
name="model",
|
||||
name="model_name",
|
||||
display_name="Model",
|
||||
advanced=False,
|
||||
options=[
|
||||
|
|
@ -24,6 +29,8 @@ class CohereEmbeddingsComponent(LCModelComponent):
|
|||
"embed-multilingual-light-v2.0",
|
||||
],
|
||||
value="embed-english-v2.0",
|
||||
refresh_button=True,
|
||||
combobox=True,
|
||||
),
|
||||
MessageTextInput(name="truncate", display_name="Truncate", advanced=True),
|
||||
IntInput(name="max_retries", display_name="Max Retries", value=3, advanced=True),
|
||||
|
|
@ -36,11 +43,39 @@ class CohereEmbeddingsComponent(LCModelComponent):
|
|||
]
|
||||
|
||||
def build_embeddings(self) -> Embeddings:
|
||||
return CohereEmbeddings(
|
||||
cohere_api_key=self.cohere_api_key,
|
||||
model=self.model,
|
||||
truncate=self.truncate,
|
||||
max_retries=self.max_retries,
|
||||
user_agent=self.user_agent,
|
||||
request_timeout=self.request_timeout or None,
|
||||
)
|
||||
data = None
|
||||
try:
|
||||
data = CohereEmbeddings(
|
||||
cohere_api_key=self.api_key,
|
||||
model=self.model_name,
|
||||
truncate=self.truncate,
|
||||
max_retries=self.max_retries,
|
||||
user_agent=self.user_agent,
|
||||
request_timeout=self.request_timeout or None,
|
||||
)
|
||||
except Exception as e:
|
||||
msg = (
|
||||
"Unable to create Cohere Embeddings. ",
|
||||
"Please verify the API key and model parameters, and try again.",
|
||||
)
|
||||
raise ValueError(msg) from e
|
||||
# added status if not the return data would be serialised to create the status
|
||||
return data
|
||||
|
||||
def get_model(self):
|
||||
try:
|
||||
co = cohere.ClientV2(self.api_key)
|
||||
response = co.models.list(endpoint="embed")
|
||||
models = response.models
|
||||
return [model.name for model in models]
|
||||
except Exception as e:
|
||||
msg = f"Failed to fetch Cohere models. Error: {e}"
|
||||
raise ValueError(msg) from e
|
||||
|
||||
async def update_build_config(self, build_config: dict, field_value: Any, field_name: str | None = None):
|
||||
if field_name in {"model_name", "api_key"}:
|
||||
if build_config.get("api_key", {}).get("value", None):
|
||||
build_config["model_name"]["options"] = self.get_model()
|
||||
else:
|
||||
build_config["model_name"]["options"] = field_value
|
||||
return build_config
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue