fix: updating components to access secret key with new method get_secret_value() (#4243)

* updating components to access secret key with new method

* [autofix.ci] apply automated fixes

* 📝 (text_embedder.py): Add error messages as constants for better error handling and readability
📝 (text_embedder.py): Improve error handling and validation for embedding model and text content
📝 (text_embedder.py): Ensure proper protocol for the base URL in the embedding model client
📝 (text_embedder.py): Validate the output of embeddings and handle exceptions with logging
📝 (text_embedder.py): Refactor code to encapsulate results in a Data object and update status accordingly
📝 (similarity.spec.ts): Add a delay before clicking the button to run the text output test for better synchronization

* uv fix

* [autofix.ci] apply automated fixes

* 📝 (text_embedder.py): refactor error messages to use inline strings for better readability and maintainability
🐛 (text_embedder.py): fix issue with extracting the first element from embeddings list to ensure correct data handling

---------

Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com>
Co-authored-by: Eric Hare <ericrhare@gmail.com>
This commit is contained in:
Cristhian Zanforlin Lousa 2024-10-22 20:21:58 -03:00 committed by GitHub
commit 9b1c382641
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
23 changed files with 70 additions and 40 deletions

View file

@ -91,7 +91,7 @@ class HuggingFaceInferenceAPIEmbeddingsComponent(LCEmbeddingsModel):
msg = "API Key is required for non-local inference endpoints"
raise ValueError(msg)
else:
api_key = SecretStr(self.api_key)
api_key = SecretStr(self.api_key).get_secret_value()
try:
return self.create_huggingface_embeddings(api_key, api_url, self.model_name)

View file

@ -46,7 +46,7 @@ class MistralAIEmbeddingsComponent(LCModelComponent):
msg = "Mistral API Key is required"
raise ValueError(msg)
api_key = SecretStr(self.mistral_api_key)
api_key = SecretStr(self.mistral_api_key).get_secret_value()
return MistralAIEmbeddings(
api_key=api_key,

View file

@ -1,3 +1,4 @@
import logging
from typing import TYPE_CHECKING
from langflow.custom import Component
@ -13,7 +14,6 @@ class TextEmbedderComponent(Component):
display_name: str = "Text Embedder"
description: str = "Generate embeddings for a given message using the specified embedding model."
icon = "binary"
inputs = [
HandleInput(
name="embedding_model",
@ -27,26 +27,55 @@ class TextEmbedderComponent(Component):
info="The message to generate embeddings for.",
),
]
outputs = [
Output(display_name="Embedding Data", name="embeddings", method="generate_embeddings"),
]
def generate_embeddings(self) -> Data:
embedding_model: Embeddings = self.embedding_model
message: Message = self.message
try:
embedding_model: Embeddings = self.embedding_model
message: Message = self.message
# Extract the text content from the message
text_content = message.text
# Validate embedding model
if not embedding_model:
msg = "Embedding model not provided"
raise ValueError(msg)
# Generate embeddings using the provided embedding model
embeddings = embedding_model.embed_documents([text_content])
# Extract the text content from the message
text_content = message.text if message and message.text else ""
if not text_content:
msg = "No text content found in message"
raise ValueError(msg)
# Assuming the embedding model returns a list of embeddings, we take the first one
embedding_vector = embeddings[0] if embeddings else []
# Check if the embedding model has the required attributes
if not hasattr(embedding_model, "client") or not embedding_model.client:
msg = "Embedding model client not properly initialized"
raise ValueError(msg)
# Ensure the base URL has proper protocol
if hasattr(embedding_model.client, "base_url"):
base_url = embedding_model.client.base_url
if not base_url.startswith(("http://", "https://")):
embedding_model.client.base_url = f"https://{base_url}"
# Generate embeddings using the provided embedding model
embeddings = embedding_model.embed_documents([text_content])
# Validate embeddings output
if not embeddings or not isinstance(embeddings, list):
msg = "Invalid embeddings generated"
raise ValueError(msg)
embedding_vector = embeddings[0]
except Exception as e:
logging.exception("Error generating embeddings")
# Return empty data with error status
error_data = Data(data={"text": "", "embeddings": [], "error": str(e)})
self.status = {"error": str(e)}
return error_data
# Create a Data object to encapsulate the results
result_data = Data(data={"text": text_content, "embeddings": embedding_vector})
self.status = {"text": text_content, "embeddings": embedding_vector}
return result_data

View file

@ -68,7 +68,7 @@ class AnthropicModelComponent(LCModelComponent):
try:
output = ChatAnthropic(
model=model,
anthropic_api_key=(SecretStr(anthropic_api_key) if anthropic_api_key else None),
anthropic_api_key=(SecretStr(anthropic_api_key).get_secret_value() if anthropic_api_key else None),
max_tokens_to_sample=max_tokens,
temperature=temperature,
anthropic_api_url=anthropic_api_url,

View file

@ -88,8 +88,8 @@ class QianfanChatEndpointComponent(LCModelComponent):
try:
output = QianfanChatEndpoint(
model=model,
qianfan_ak=SecretStr(qianfan_ak) if qianfan_ak else None,
qianfan_sk=SecretStr(qianfan_sk) if qianfan_sk else None,
qianfan_ak=SecretStr(qianfan_ak).get_secret_value() if qianfan_ak else None,
qianfan_sk=SecretStr(qianfan_sk).get_secret_value() if qianfan_sk else None,
top_p=top_p,
temperature=temperature,
penalty_score=penalty_score,

View file

@ -37,7 +37,7 @@ class CohereComponent(LCModelComponent):
cohere_api_key = self.cohere_api_key
temperature = self.temperature
api_key = SecretStr(cohere_api_key) if cohere_api_key else None
api_key = SecretStr(cohere_api_key).get_secret_value() if cohere_api_key else None
return ChatCohere(
temperature=temperature or 0.75,

View file

@ -80,5 +80,5 @@ class GoogleGenerativeAIComponent(LCModelComponent):
top_k=top_k or None,
top_p=top_p or None,
n=n or 1,
google_api_key=SecretStr(google_api_key),
google_api_key=SecretStr(google_api_key).get_secret_value(),
)

View file

@ -98,6 +98,6 @@ class GroqModel(LCModelComponent):
temperature=temperature,
base_url=groq_api_base,
n=n or 1,
api_key=SecretStr(groq_api_key),
api_key=SecretStr(groq_api_key).get_secret_value(),
streaming=stream,
)

View file

@ -77,7 +77,7 @@ class MistralAIModelComponent(LCModelComponent):
random_seed = self.random_seed
safe_mode = self.safe_mode
api_key = SecretStr(mistral_api_key) if mistral_api_key else None
api_key = SecretStr(mistral_api_key).get_secret_value() if mistral_api_key else None
return ChatMistralAI(
max_tokens=max_tokens or None,

View file

@ -95,7 +95,7 @@ class OpenAIModelComponent(LCModelComponent):
json_mode = bool(output_schema_dict) or self.json_mode
seed = self.seed
api_key = SecretStr(openai_api_key) if openai_api_key else None
api_key = SecretStr(openai_api_key).get_secret_value() if openai_api_key else None
output = ChatOpenAI(
max_tokens=max_tokens or None,
model_kwargs=model_kwargs,

File diff suppressed because one or more lines are too long

File diff suppressed because one or more lines are too long

File diff suppressed because one or more lines are too long

File diff suppressed because one or more lines are too long

File diff suppressed because one or more lines are too long

File diff suppressed because one or more lines are too long

File diff suppressed because one or more lines are too long

File diff suppressed because one or more lines are too long

File diff suppressed because one or more lines are too long

View file

@ -110,4 +110,4 @@ class AuthSettings(BaseSettings):
write_secret_to_file(secret_key_path, value)
logger.debug("Saved secret key")
return value if isinstance(value, SecretStr) else SecretStr(value)
return value if isinstance(value, SecretStr) else SecretStr(value).get_secret_value()

View file

@ -923,7 +923,6 @@
},
"node_modules/@clack/prompts/node_modules/is-unicode-supported": {
"version": "1.3.0",
"extraneous": true,
"inBundle": true,
"license": "MIT",
"engines": {

View file

@ -319,6 +319,8 @@ test("user must be able to check similarity between embedding texts", async ({
await textOutputInput.hover();
await page.mouse.up();
await page.waitForTimeout(3000);
await page.getByTestId("button_run_text output").click();
await page.waitForSelector("text=built successfully", { timeout: 30000 });