feat: bedrock access with aws access key (#3032)
* feat: bedrock access with aws access key * [autofix.ci] apply automated fixes * chore: ignore type import in AmazonBedrockEmbeddings.py and AmazonBedrockModel.py --------- Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com> Co-authored-by: Gabriel Luiz Freitas Almeida <gabriel@langflow.org>
This commit is contained in:
parent
7fa3d33d8a
commit
58b51d1c56
5 changed files with 71 additions and 60 deletions
|
|
@ -2,6 +2,7 @@ from langchain_community.embeddings import BedrockEmbeddings
|
|||
|
||||
from langflow.base.models.model import LCModelComponent
|
||||
from langflow.field_typing import Embeddings
|
||||
from langflow.inputs import SecretStrInput
|
||||
from langflow.io import DropdownInput, MessageTextInput, Output
|
||||
|
||||
|
||||
|
|
@ -19,18 +20,15 @@ class AmazonBedrockEmbeddingsComponent(LCModelComponent):
|
|||
options=["amazon.titan-embed-text-v1"],
|
||||
value="amazon.titan-embed-text-v1",
|
||||
),
|
||||
SecretStrInput(name="aws_access_key", display_name="Access Key"),
|
||||
SecretStrInput(name="aws_secret_key", display_name="Secret Key"),
|
||||
MessageTextInput(
|
||||
name="credentials_profile_name",
|
||||
display_name="Credentials Profile Name",
|
||||
advanced=True,
|
||||
),
|
||||
MessageTextInput(
|
||||
name="endpoint_url",
|
||||
display_name="Bedrock Endpoint URL",
|
||||
),
|
||||
MessageTextInput(
|
||||
name="region_name",
|
||||
display_name="AWS Region",
|
||||
),
|
||||
MessageTextInput(name="region_name", display_name="Region Name", value="us-east-1"),
|
||||
MessageTextInput(name="endpoint_url", display_name=" Endpoint URL", advanced=True),
|
||||
]
|
||||
|
||||
outputs = [
|
||||
|
|
@ -38,13 +36,34 @@ class AmazonBedrockEmbeddingsComponent(LCModelComponent):
|
|||
]
|
||||
|
||||
def build_embeddings(self) -> Embeddings:
|
||||
try:
|
||||
output = BedrockEmbeddings(
|
||||
credentials_profile_name=self.credentials_profile_name,
|
||||
model_id=self.model_id,
|
||||
endpoint_url=self.endpoint_url,
|
||||
region_name=self.region_name,
|
||||
) # type: ignore
|
||||
except Exception as e:
|
||||
raise ValueError("Could not connect to Amazon Bedrock API.") from e
|
||||
if self.aws_access_key:
|
||||
import boto3 # type: ignore
|
||||
|
||||
session = boto3.Session(
|
||||
aws_access_key_id=self.aws_access_key,
|
||||
aws_secret_access_key=self.aws_secret_key,
|
||||
)
|
||||
elif self.credentials_profile_name:
|
||||
import boto3
|
||||
|
||||
session = boto3.Session(profile_name=self.credentials_profile_name)
|
||||
else:
|
||||
import boto3
|
||||
|
||||
session = boto3.Session()
|
||||
|
||||
client_params = {}
|
||||
if self.endpoint_url:
|
||||
client_params["endpoint_url"] = self.endpoint_url
|
||||
if self.region_name:
|
||||
client_params["region_name"] = self.region_name
|
||||
|
||||
boto3_client = session.client("bedrock-runtime", **client_params)
|
||||
output = BedrockEmbeddings(
|
||||
credentials_profile_name=self.credentials_profile_name,
|
||||
client=boto3_client,
|
||||
model_id=self.model_id,
|
||||
endpoint_url=self.endpoint_url,
|
||||
region_name=self.region_name,
|
||||
) # type: ignore
|
||||
return output
|
||||
|
|
|
|||
|
|
@ -2,7 +2,7 @@ from langchain_aws import ChatBedrock
|
|||
|
||||
from langflow.base.models.model import LCModelComponent
|
||||
from langflow.field_typing import LanguageModel
|
||||
from langflow.inputs import MessageTextInput
|
||||
from langflow.inputs import MessageTextInput, SecretStrInput
|
||||
from langflow.io import DictInput, DropdownInput
|
||||
|
||||
|
||||
|
|
@ -51,27 +51,46 @@ class AmazonBedrockComponent(LCModelComponent):
|
|||
],
|
||||
value="anthropic.claude-3-haiku-20240307-v1:0",
|
||||
),
|
||||
MessageTextInput(name="credentials_profile_name", display_name="Credentials Profile Name"),
|
||||
SecretStrInput(name="aws_access_key", display_name="Access Key"),
|
||||
SecretStrInput(name="aws_secret_key", display_name="Secret Key"),
|
||||
MessageTextInput(name="credentials_profile_name", display_name="Credentials Profile Name", advanced=True),
|
||||
MessageTextInput(name="region_name", display_name="Region Name", value="us-east-1"),
|
||||
DictInput(name="model_kwargs", display_name="Model Kwargs", advanced=True, is_list=True),
|
||||
MessageTextInput(name="endpoint_url", display_name="Endpoint URL", advanced=True),
|
||||
]
|
||||
|
||||
def build_model(self) -> LanguageModel: # type: ignore[type-var]
|
||||
model_id = self.model_id
|
||||
credentials_profile_name = self.credentials_profile_name
|
||||
region_name = self.region_name
|
||||
model_kwargs = self.model_kwargs
|
||||
endpoint_url = self.endpoint_url
|
||||
stream = self.stream
|
||||
if self.aws_access_key:
|
||||
import boto3 # type: ignore
|
||||
|
||||
session = boto3.Session(
|
||||
aws_access_key_id=self.aws_access_key,
|
||||
aws_secret_access_key=self.aws_secret_key,
|
||||
)
|
||||
elif self.credentials_profile_name:
|
||||
import boto3
|
||||
|
||||
session = boto3.Session(profile_name=self.credentials_profile_name)
|
||||
else:
|
||||
import boto3
|
||||
|
||||
session = boto3.Session()
|
||||
|
||||
client_params = {}
|
||||
if self.endpoint_url:
|
||||
client_params["endpoint_url"] = self.endpoint_url
|
||||
if self.region_name:
|
||||
client_params["region_name"] = self.region_name
|
||||
|
||||
boto3_client = session.client("bedrock-runtime", **client_params)
|
||||
try:
|
||||
output = ChatBedrock( # type: ignore
|
||||
credentials_profile_name=credentials_profile_name,
|
||||
model_id=model_id,
|
||||
region_name=region_name,
|
||||
model_kwargs=model_kwargs,
|
||||
endpoint_url=endpoint_url,
|
||||
streaming=stream,
|
||||
client=boto3_client,
|
||||
model_id=self.model_id,
|
||||
region_name=self.region_name,
|
||||
model_kwargs=self.model_kwargs,
|
||||
endpoint_url=self.endpoint_url,
|
||||
streaming=self.stream,
|
||||
)
|
||||
except Exception as e:
|
||||
raise ValueError("Could not connect to AmazonBedrock API.") from e
|
||||
|
|
|
|||
|
|
@ -22,4 +22,6 @@ VARIABLES_TO_GET_FROM_ENVIRONMENT = [
|
|||
"VECTARA_CUSTOMER_ID",
|
||||
"VECTARA_CORPUS_ID",
|
||||
"VECTARA_API_KEY",
|
||||
"AWS_ACCESS_KEY_ID",
|
||||
"AWS_SECRET_ACCESS_KEY",
|
||||
]
|
||||
|
|
|
|||
|
|
@ -80,16 +80,6 @@ test("dropDownComponent", async ({ page }) => {
|
|||
expect(false).toBeTruthy();
|
||||
}
|
||||
|
||||
await page.locator('//*[@id="showcredentials_profile_name"]').click();
|
||||
expect(
|
||||
await page.locator('//*[@id="showcredentials_profile_name"]').isChecked(),
|
||||
).toBeFalsy();
|
||||
|
||||
await page.locator('//*[@id="showcredentials_profile_name"]').click();
|
||||
expect(
|
||||
await page.locator('//*[@id="showcredentials_profile_name"]').isChecked(),
|
||||
).toBeTruthy();
|
||||
|
||||
await page.locator('//*[@id="showregion_name"]').click();
|
||||
expect(
|
||||
await page.locator('//*[@id="showregion_name"]').isChecked(),
|
||||
|
|
@ -110,16 +100,6 @@ test("dropDownComponent", async ({ page }) => {
|
|||
await page.locator('//*[@id="showmodel_id"]').isChecked(),
|
||||
).toBeTruthy();
|
||||
|
||||
await page.locator('//*[@id="showcredentials_profile_name"]').click();
|
||||
expect(
|
||||
await page.locator('//*[@id="showcredentials_profile_name"]').isChecked(),
|
||||
).toBeFalsy();
|
||||
|
||||
await page.locator('//*[@id="showcredentials_profile_name"]').click();
|
||||
expect(
|
||||
await page.locator('//*[@id="showcredentials_profile_name"]').isChecked(),
|
||||
).toBeTruthy();
|
||||
|
||||
await page.locator('//*[@id="showregion_name"]').click();
|
||||
expect(
|
||||
await page.locator('//*[@id="showregion_name"]').isChecked(),
|
||||
|
|
|
|||
|
|
@ -88,10 +88,6 @@ test("KeypairListComponent", async ({ page }) => {
|
|||
await page.getByTestId("more-options-modal").click();
|
||||
await page.getByTestId("edit-button-modal").click();
|
||||
|
||||
await page.locator('//*[@id="showcredentials_profile_name"]').click();
|
||||
expect(
|
||||
await page.locator('//*[@id="showcredentials_profile_name"]').isChecked(),
|
||||
).toBeFalsy();
|
||||
await page.getByText("Close").last().click();
|
||||
|
||||
const plusButtonLocator = page.locator('//*[@id="plusbtn0"]');
|
||||
|
|
@ -103,11 +99,6 @@ test("KeypairListComponent", async ({ page }) => {
|
|||
await page.getByTestId("more-options-modal").click();
|
||||
await page.getByTestId("edit-button-modal").click();
|
||||
|
||||
await page.locator('//*[@id="showcredentials_profile_name"]').click();
|
||||
expect(
|
||||
await page.locator('//*[@id="showcredentials_profile_name"]').isChecked(),
|
||||
).toBeTruthy();
|
||||
|
||||
await page.locator('//*[@id="editNodekeypair0"]').click();
|
||||
await page.locator('//*[@id="editNodekeypair0"]').fill("testtesttesttest");
|
||||
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue