fix: updating components to access secret key with new method get_secret_value() (#4243)
* updating components to access secret key with new method * [autofix.ci] apply automated fixes * 📝 (text_embedder.py): Add error messages as constants for better error handling and readability 📝 (text_embedder.py): Improve error handling and validation for embedding model and text content 📝 (text_embedder.py): Ensure proper protocol for the base URL in the embedding model client 📝 (text_embedder.py): Validate the output of embeddings and handle exceptions with logging 📝 (text_embedder.py): Refactor code to encapsulate results in a Data object and update status accordingly 📝 (similarity.spec.ts): Add a delay before clicking the button to run the text output test for better synchronization * uv fix * [autofix.ci] apply automated fixes * 📝 (text_embedder.py): refactor error messages to use inline strings for better readability and maintainability 🐛 (text_embedder.py): fix issue with extracting the first element from embeddings list to ensure correct data handling --------- Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com> Co-authored-by: Eric Hare <ericrhare@gmail.com>
This commit is contained in:
parent
f44aca5b41
commit
9b1c382641
23 changed files with 70 additions and 40 deletions
|
|
@ -91,7 +91,7 @@ class HuggingFaceInferenceAPIEmbeddingsComponent(LCEmbeddingsModel):
|
|||
msg = "API Key is required for non-local inference endpoints"
|
||||
raise ValueError(msg)
|
||||
else:
|
||||
api_key = SecretStr(self.api_key)
|
||||
api_key = SecretStr(self.api_key).get_secret_value()
|
||||
|
||||
try:
|
||||
return self.create_huggingface_embeddings(api_key, api_url, self.model_name)
|
||||
|
|
|
|||
|
|
@ -46,7 +46,7 @@ class MistralAIEmbeddingsComponent(LCModelComponent):
|
|||
msg = "Mistral API Key is required"
|
||||
raise ValueError(msg)
|
||||
|
||||
api_key = SecretStr(self.mistral_api_key)
|
||||
api_key = SecretStr(self.mistral_api_key).get_secret_value()
|
||||
|
||||
return MistralAIEmbeddings(
|
||||
api_key=api_key,
|
||||
|
|
|
|||
|
|
@ -1,3 +1,4 @@
|
|||
import logging
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
from langflow.custom import Component
|
||||
|
|
@ -13,7 +14,6 @@ class TextEmbedderComponent(Component):
|
|||
display_name: str = "Text Embedder"
|
||||
description: str = "Generate embeddings for a given message using the specified embedding model."
|
||||
icon = "binary"
|
||||
|
||||
inputs = [
|
||||
HandleInput(
|
||||
name="embedding_model",
|
||||
|
|
@ -27,26 +27,55 @@ class TextEmbedderComponent(Component):
|
|||
info="The message to generate embeddings for.",
|
||||
),
|
||||
]
|
||||
|
||||
outputs = [
|
||||
Output(display_name="Embedding Data", name="embeddings", method="generate_embeddings"),
|
||||
]
|
||||
|
||||
def generate_embeddings(self) -> Data:
|
||||
embedding_model: Embeddings = self.embedding_model
|
||||
message: Message = self.message
|
||||
try:
|
||||
embedding_model: Embeddings = self.embedding_model
|
||||
message: Message = self.message
|
||||
|
||||
# Extract the text content from the message
|
||||
text_content = message.text
|
||||
# Validate embedding model
|
||||
if not embedding_model:
|
||||
msg = "Embedding model not provided"
|
||||
raise ValueError(msg)
|
||||
|
||||
# Generate embeddings using the provided embedding model
|
||||
embeddings = embedding_model.embed_documents([text_content])
|
||||
# Extract the text content from the message
|
||||
text_content = message.text if message and message.text else ""
|
||||
if not text_content:
|
||||
msg = "No text content found in message"
|
||||
raise ValueError(msg)
|
||||
|
||||
# Assuming the embedding model returns a list of embeddings, we take the first one
|
||||
embedding_vector = embeddings[0] if embeddings else []
|
||||
# Check if the embedding model has the required attributes
|
||||
if not hasattr(embedding_model, "client") or not embedding_model.client:
|
||||
msg = "Embedding model client not properly initialized"
|
||||
raise ValueError(msg)
|
||||
|
||||
# Ensure the base URL has proper protocol
|
||||
if hasattr(embedding_model.client, "base_url"):
|
||||
base_url = embedding_model.client.base_url
|
||||
if not base_url.startswith(("http://", "https://")):
|
||||
embedding_model.client.base_url = f"https://{base_url}"
|
||||
|
||||
# Generate embeddings using the provided embedding model
|
||||
embeddings = embedding_model.embed_documents([text_content])
|
||||
|
||||
# Validate embeddings output
|
||||
if not embeddings or not isinstance(embeddings, list):
|
||||
msg = "Invalid embeddings generated"
|
||||
raise ValueError(msg)
|
||||
|
||||
embedding_vector = embeddings[0]
|
||||
|
||||
except Exception as e:
|
||||
logging.exception("Error generating embeddings")
|
||||
# Return empty data with error status
|
||||
error_data = Data(data={"text": "", "embeddings": [], "error": str(e)})
|
||||
self.status = {"error": str(e)}
|
||||
return error_data
|
||||
|
||||
# Create a Data object to encapsulate the results
|
||||
result_data = Data(data={"text": text_content, "embeddings": embedding_vector})
|
||||
|
||||
self.status = {"text": text_content, "embeddings": embedding_vector}
|
||||
return result_data
|
||||
|
|
|
|||
|
|
@ -68,7 +68,7 @@ class AnthropicModelComponent(LCModelComponent):
|
|||
try:
|
||||
output = ChatAnthropic(
|
||||
model=model,
|
||||
anthropic_api_key=(SecretStr(anthropic_api_key) if anthropic_api_key else None),
|
||||
anthropic_api_key=(SecretStr(anthropic_api_key).get_secret_value() if anthropic_api_key else None),
|
||||
max_tokens_to_sample=max_tokens,
|
||||
temperature=temperature,
|
||||
anthropic_api_url=anthropic_api_url,
|
||||
|
|
|
|||
|
|
@ -88,8 +88,8 @@ class QianfanChatEndpointComponent(LCModelComponent):
|
|||
try:
|
||||
output = QianfanChatEndpoint(
|
||||
model=model,
|
||||
qianfan_ak=SecretStr(qianfan_ak) if qianfan_ak else None,
|
||||
qianfan_sk=SecretStr(qianfan_sk) if qianfan_sk else None,
|
||||
qianfan_ak=SecretStr(qianfan_ak).get_secret_value() if qianfan_ak else None,
|
||||
qianfan_sk=SecretStr(qianfan_sk).get_secret_value() if qianfan_sk else None,
|
||||
top_p=top_p,
|
||||
temperature=temperature,
|
||||
penalty_score=penalty_score,
|
||||
|
|
|
|||
|
|
@ -37,7 +37,7 @@ class CohereComponent(LCModelComponent):
|
|||
cohere_api_key = self.cohere_api_key
|
||||
temperature = self.temperature
|
||||
|
||||
api_key = SecretStr(cohere_api_key) if cohere_api_key else None
|
||||
api_key = SecretStr(cohere_api_key).get_secret_value() if cohere_api_key else None
|
||||
|
||||
return ChatCohere(
|
||||
temperature=temperature or 0.75,
|
||||
|
|
|
|||
|
|
@ -80,5 +80,5 @@ class GoogleGenerativeAIComponent(LCModelComponent):
|
|||
top_k=top_k or None,
|
||||
top_p=top_p or None,
|
||||
n=n or 1,
|
||||
google_api_key=SecretStr(google_api_key),
|
||||
google_api_key=SecretStr(google_api_key).get_secret_value(),
|
||||
)
|
||||
|
|
|
|||
|
|
@ -98,6 +98,6 @@ class GroqModel(LCModelComponent):
|
|||
temperature=temperature,
|
||||
base_url=groq_api_base,
|
||||
n=n or 1,
|
||||
api_key=SecretStr(groq_api_key),
|
||||
api_key=SecretStr(groq_api_key).get_secret_value(),
|
||||
streaming=stream,
|
||||
)
|
||||
|
|
|
|||
|
|
@ -77,7 +77,7 @@ class MistralAIModelComponent(LCModelComponent):
|
|||
random_seed = self.random_seed
|
||||
safe_mode = self.safe_mode
|
||||
|
||||
api_key = SecretStr(mistral_api_key) if mistral_api_key else None
|
||||
api_key = SecretStr(mistral_api_key).get_secret_value() if mistral_api_key else None
|
||||
|
||||
return ChatMistralAI(
|
||||
max_tokens=max_tokens or None,
|
||||
|
|
|
|||
|
|
@ -95,7 +95,7 @@ class OpenAIModelComponent(LCModelComponent):
|
|||
json_mode = bool(output_schema_dict) or self.json_mode
|
||||
seed = self.seed
|
||||
|
||||
api_key = SecretStr(openai_api_key) if openai_api_key else None
|
||||
api_key = SecretStr(openai_api_key).get_secret_value() if openai_api_key else None
|
||||
output = ChatOpenAI(
|
||||
max_tokens=max_tokens or None,
|
||||
model_kwargs=model_kwargs,
|
||||
|
|
|
|||
File diff suppressed because one or more lines are too long
File diff suppressed because one or more lines are too long
File diff suppressed because one or more lines are too long
File diff suppressed because one or more lines are too long
File diff suppressed because one or more lines are too long
File diff suppressed because one or more lines are too long
File diff suppressed because one or more lines are too long
File diff suppressed because one or more lines are too long
File diff suppressed because one or more lines are too long
File diff suppressed because one or more lines are too long
|
|
@ -110,4 +110,4 @@ class AuthSettings(BaseSettings):
|
|||
write_secret_to_file(secret_key_path, value)
|
||||
logger.debug("Saved secret key")
|
||||
|
||||
return value if isinstance(value, SecretStr) else SecretStr(value)
|
||||
return value if isinstance(value, SecretStr) else SecretStr(value).get_secret_value()
|
||||
|
|
|
|||
1
src/frontend/package-lock.json
generated
1
src/frontend/package-lock.json
generated
|
|
@ -923,7 +923,6 @@
|
|||
},
|
||||
"node_modules/@clack/prompts/node_modules/is-unicode-supported": {
|
||||
"version": "1.3.0",
|
||||
"extraneous": true,
|
||||
"inBundle": true,
|
||||
"license": "MIT",
|
||||
"engines": {
|
||||
|
|
|
|||
|
|
@ -319,6 +319,8 @@ test("user must be able to check similarity between embedding texts", async ({
|
|||
await textOutputInput.hover();
|
||||
await page.mouse.up();
|
||||
|
||||
await page.waitForTimeout(3000);
|
||||
|
||||
await page.getByTestId("button_run_text output").click();
|
||||
|
||||
await page.waitForSelector("text=built successfully", { timeout: 30000 });
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue