refactor: Update AmazonBedrockEmbeddingsComponent to use new Inputs/Outputs format

This commit is contained in:
Cezar Vasconcelos 2024-06-19 21:09:54 +00:00
commit a180589f79

View file

@ -1,42 +1,50 @@
from typing import Optional
from langchain_community.embeddings import BedrockEmbeddings
from langchain_core.embeddings import Embeddings
from langflow.custom import CustomComponent
from langflow.base.models.model import LCModelComponent
from langflow.field_typing import Embeddings
from langflow.io import DropdownInput, Output, SecretStrInput, TextInput
class AmazonBedrockEmeddingsComponent(CustomComponent):
class AmazonBedrockEmbeddingsComponent(LCModelComponent):
display_name: str = "Amazon Bedrock Embeddings"
description: str = "Generate embeddings using Amazon Bedrock models."
documentation = "https://python.langchain.com/docs/modules/data_connection/text_embedding/integrations/bedrock"
icon = "Amazon"
def build_config(self):
return {
"model_id": {
"display_name": "Model Id",
"options": ["amazon.titan-embed-text-v1"],
},
"credentials_profile_name": {"display_name": "Credentials Profile Name"},
"endpoint_url": {"display_name": "Bedrock Endpoint URL"},
"region_name": {"display_name": "AWS Region"},
"code": {"show": False},
}
inputs = [
DropdownInput(
name="model_id",
display_name="Model Id",
options=["amazon.titan-embed-text-v1"],
value="amazon.titan-embed-text-v1",
),
TextInput(
name="credentials_profile_name",
display_name="Credentials Profile Name",
),
TextInput(
name="endpoint_url",
display_name="Bedrock Endpoint URL",
),
TextInput(
name="region_name",
display_name="AWS Region",
),
]
def build(
self,
model_id: str = "amazon.titan-embed-text-v1",
credentials_profile_name: Optional[str] = None,
endpoint_url: Optional[str] = None,
region_name: Optional[str] = None,
) -> Embeddings:
outputs = [
Output(display_name="Embeddings", name="embeddings", method="build_embeddings"),
]
def build_embeddings(self) -> Embeddings:
try:
output = BedrockEmbeddings(
credentials_profile_name=credentials_profile_name,
model_id=model_id,
endpoint_url=endpoint_url,
region_name=region_name,
credentials_profile_name=self.credentials_profile_name,
model_id=self.model_id,
endpoint_url=self.endpoint_url,
region_name=self.region_name,
) # type: ignore
except Exception as e:
raise ValueError("Could not connect to AmazonBedrock API.") from e
raise ValueError("Could not connect to Amazon Bedrock API.") from e
return output