From 58b51d1c562ccb31b6177f0ac0ca2fa58aca0e77 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Nicol=C3=B2=20Boschi?= Date: Mon, 29 Jul 2024 23:18:21 +0200 Subject: [PATCH] feat: bedrock access with aws access key (#3032) * feat: bedrock access with aws access key * [autofix.ci] apply automated fixes * chore: ignore type import in AmazonBedrockEmbeddings.py and AmazonBedrockModel.py --------- Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com> Co-authored-by: Gabriel Luiz Freitas Almeida --- .../embeddings/AmazonBedrockEmbeddings.py | 53 +++++++++++++------ .../components/models/AmazonBedrockModel.py | 47 +++++++++++----- .../langflow/services/settings/constants.py | 2 + .../end-to-end/dropdownComponent.spec.ts | 20 ------- .../end-to-end/keyPairListComponent.spec.ts | 9 ---- 5 files changed, 71 insertions(+), 60 deletions(-) diff --git a/src/backend/base/langflow/components/embeddings/AmazonBedrockEmbeddings.py b/src/backend/base/langflow/components/embeddings/AmazonBedrockEmbeddings.py index 5e91e801a..299073510 100644 --- a/src/backend/base/langflow/components/embeddings/AmazonBedrockEmbeddings.py +++ b/src/backend/base/langflow/components/embeddings/AmazonBedrockEmbeddings.py @@ -2,6 +2,7 @@ from langchain_community.embeddings import BedrockEmbeddings from langflow.base.models.model import LCModelComponent from langflow.field_typing import Embeddings +from langflow.inputs import SecretStrInput from langflow.io import DropdownInput, MessageTextInput, Output @@ -19,18 +20,15 @@ class AmazonBedrockEmbeddingsComponent(LCModelComponent): options=["amazon.titan-embed-text-v1"], value="amazon.titan-embed-text-v1", ), + SecretStrInput(name="aws_access_key", display_name="Access Key"), + SecretStrInput(name="aws_secret_key", display_name="Secret Key"), MessageTextInput( name="credentials_profile_name", display_name="Credentials Profile Name", + advanced=True, ), - MessageTextInput( - name="endpoint_url", - display_name="Bedrock Endpoint URL", - ), - MessageTextInput( - name="region_name", - display_name="AWS Region", - ), + MessageTextInput(name="region_name", display_name="Region Name", value="us-east-1"), + MessageTextInput(name="endpoint_url", display_name=" Endpoint URL", advanced=True), ] outputs = [ @@ -38,13 +36,34 @@ class AmazonBedrockEmbeddingsComponent(LCModelComponent): ] def build_embeddings(self) -> Embeddings: - try: - output = BedrockEmbeddings( - credentials_profile_name=self.credentials_profile_name, - model_id=self.model_id, - endpoint_url=self.endpoint_url, - region_name=self.region_name, - ) # type: ignore - except Exception as e: - raise ValueError("Could not connect to Amazon Bedrock API.") from e + if self.aws_access_key: + import boto3 # type: ignore + + session = boto3.Session( + aws_access_key_id=self.aws_access_key, + aws_secret_access_key=self.aws_secret_key, + ) + elif self.credentials_profile_name: + import boto3 + + session = boto3.Session(profile_name=self.credentials_profile_name) + else: + import boto3 + + session = boto3.Session() + + client_params = {} + if self.endpoint_url: + client_params["endpoint_url"] = self.endpoint_url + if self.region_name: + client_params["region_name"] = self.region_name + + boto3_client = session.client("bedrock-runtime", **client_params) + output = BedrockEmbeddings( + credentials_profile_name=self.credentials_profile_name, + client=boto3_client, + model_id=self.model_id, + endpoint_url=self.endpoint_url, + region_name=self.region_name, + ) # type: ignore return output diff --git a/src/backend/base/langflow/components/models/AmazonBedrockModel.py b/src/backend/base/langflow/components/models/AmazonBedrockModel.py index deee5798a..b34f7ce7f 100644 --- a/src/backend/base/langflow/components/models/AmazonBedrockModel.py +++ b/src/backend/base/langflow/components/models/AmazonBedrockModel.py @@ -2,7 +2,7 @@ from langchain_aws import ChatBedrock from langflow.base.models.model import LCModelComponent from langflow.field_typing import LanguageModel -from langflow.inputs import MessageTextInput +from langflow.inputs import MessageTextInput, SecretStrInput from langflow.io import DictInput, DropdownInput @@ -51,27 +51,46 @@ class AmazonBedrockComponent(LCModelComponent): ], value="anthropic.claude-3-haiku-20240307-v1:0", ), - MessageTextInput(name="credentials_profile_name", display_name="Credentials Profile Name"), + SecretStrInput(name="aws_access_key", display_name="Access Key"), + SecretStrInput(name="aws_secret_key", display_name="Secret Key"), + MessageTextInput(name="credentials_profile_name", display_name="Credentials Profile Name", advanced=True), MessageTextInput(name="region_name", display_name="Region Name", value="us-east-1"), DictInput(name="model_kwargs", display_name="Model Kwargs", advanced=True, is_list=True), MessageTextInput(name="endpoint_url", display_name="Endpoint URL", advanced=True), ] def build_model(self) -> LanguageModel: # type: ignore[type-var] - model_id = self.model_id - credentials_profile_name = self.credentials_profile_name - region_name = self.region_name - model_kwargs = self.model_kwargs - endpoint_url = self.endpoint_url - stream = self.stream + if self.aws_access_key: + import boto3 # type: ignore + + session = boto3.Session( + aws_access_key_id=self.aws_access_key, + aws_secret_access_key=self.aws_secret_key, + ) + elif self.credentials_profile_name: + import boto3 + + session = boto3.Session(profile_name=self.credentials_profile_name) + else: + import boto3 + + session = boto3.Session() + + client_params = {} + if self.endpoint_url: + client_params["endpoint_url"] = self.endpoint_url + if self.region_name: + client_params["region_name"] = self.region_name + + boto3_client = session.client("bedrock-runtime", **client_params) try: output = ChatBedrock( # type: ignore - credentials_profile_name=credentials_profile_name, - model_id=model_id, - region_name=region_name, - model_kwargs=model_kwargs, - endpoint_url=endpoint_url, - streaming=stream, + client=boto3_client, + model_id=self.model_id, + region_name=self.region_name, + model_kwargs=self.model_kwargs, + endpoint_url=self.endpoint_url, + streaming=self.stream, ) except Exception as e: raise ValueError("Could not connect to AmazonBedrock API.") from e diff --git a/src/backend/base/langflow/services/settings/constants.py b/src/backend/base/langflow/services/settings/constants.py index 256030183..8f7a97cbe 100644 --- a/src/backend/base/langflow/services/settings/constants.py +++ b/src/backend/base/langflow/services/settings/constants.py @@ -22,4 +22,6 @@ VARIABLES_TO_GET_FROM_ENVIRONMENT = [ "VECTARA_CUSTOMER_ID", "VECTARA_CORPUS_ID", "VECTARA_API_KEY", + "AWS_ACCESS_KEY_ID", + "AWS_SECRET_ACCESS_KEY", ] diff --git a/src/frontend/tests/end-to-end/dropdownComponent.spec.ts b/src/frontend/tests/end-to-end/dropdownComponent.spec.ts index 8cf60c9d7..1243a339e 100644 --- a/src/frontend/tests/end-to-end/dropdownComponent.spec.ts +++ b/src/frontend/tests/end-to-end/dropdownComponent.spec.ts @@ -80,16 +80,6 @@ test("dropDownComponent", async ({ page }) => { expect(false).toBeTruthy(); } - await page.locator('//*[@id="showcredentials_profile_name"]').click(); - expect( - await page.locator('//*[@id="showcredentials_profile_name"]').isChecked(), - ).toBeFalsy(); - - await page.locator('//*[@id="showcredentials_profile_name"]').click(); - expect( - await page.locator('//*[@id="showcredentials_profile_name"]').isChecked(), - ).toBeTruthy(); - await page.locator('//*[@id="showregion_name"]').click(); expect( await page.locator('//*[@id="showregion_name"]').isChecked(), @@ -110,16 +100,6 @@ test("dropDownComponent", async ({ page }) => { await page.locator('//*[@id="showmodel_id"]').isChecked(), ).toBeTruthy(); - await page.locator('//*[@id="showcredentials_profile_name"]').click(); - expect( - await page.locator('//*[@id="showcredentials_profile_name"]').isChecked(), - ).toBeFalsy(); - - await page.locator('//*[@id="showcredentials_profile_name"]').click(); - expect( - await page.locator('//*[@id="showcredentials_profile_name"]').isChecked(), - ).toBeTruthy(); - await page.locator('//*[@id="showregion_name"]').click(); expect( await page.locator('//*[@id="showregion_name"]').isChecked(), diff --git a/src/frontend/tests/end-to-end/keyPairListComponent.spec.ts b/src/frontend/tests/end-to-end/keyPairListComponent.spec.ts index 43b0700dd..9f2f9e63a 100644 --- a/src/frontend/tests/end-to-end/keyPairListComponent.spec.ts +++ b/src/frontend/tests/end-to-end/keyPairListComponent.spec.ts @@ -88,10 +88,6 @@ test("KeypairListComponent", async ({ page }) => { await page.getByTestId("more-options-modal").click(); await page.getByTestId("edit-button-modal").click(); - await page.locator('//*[@id="showcredentials_profile_name"]').click(); - expect( - await page.locator('//*[@id="showcredentials_profile_name"]').isChecked(), - ).toBeFalsy(); await page.getByText("Close").last().click(); const plusButtonLocator = page.locator('//*[@id="plusbtn0"]'); @@ -103,11 +99,6 @@ test("KeypairListComponent", async ({ page }) => { await page.getByTestId("more-options-modal").click(); await page.getByTestId("edit-button-modal").click(); - await page.locator('//*[@id="showcredentials_profile_name"]').click(); - expect( - await page.locator('//*[@id="showcredentials_profile_name"]').isChecked(), - ).toBeTruthy(); - await page.locator('//*[@id="editNodekeypair0"]').click(); await page.locator('//*[@id="editNodekeypair0"]').fill("testtesttesttest");