refactor: Update VertexAIEmbeddingsComponent to use new Inputs/Outputs format

This commit is contained in:
Cezar Vasconcelos 2024-06-19 21:04:46 +00:00
commit 4a201d478c

View file

@ -1,76 +1,101 @@
from typing import List, Optional
from langflow.custom import CustomComponent
from langflow.base.models.model import LCModelComponent
from langflow.field_typing import Embeddings
from langflow.io import BoolInput, DictInput, FileInput, FloatInput, IntInput, Output, TextInput
class VertexAIEmbeddingsComponent(LCModelComponent):
    """Langflow component that generates embeddings using Google Cloud VertexAI models.

    NOTE(review): the diff showed two consecutive class headers (the removed
    ``CustomComponent`` base and the new ``LCModelComponent`` base); only the
    post-refactor header is kept, per the commit title.
    """

    display_name = "VertexAI Embeddings"
    description = "Generate embeddings using Google Cloud VertexAI models."
    icon = "VertexAI"
def build_config(self):
return {
"credentials": {
"display_name": "Credentials",
"value": "",
"file_types": [".json"],
"field_type": "file",
},
"instance": {
"display_name": "instance",
"advanced": True,
"field_type": "dict",
},
"location": {
"display_name": "Location",
"value": "us-central1",
"advanced": True,
},
"max_output_tokens": {"display_name": "Max Output Tokens", "value": 128},
"max_retries": {
"display_name": "Max Retries",
"value": 6,
"advanced": True,
},
"model_name": {
"display_name": "Model Name",
"value": "textembedding-gecko",
},
"n": {"display_name": "N", "value": 1, "advanced": True},
"project": {"display_name": "Project", "advanced": True},
"request_parallelism": {
"display_name": "Request Parallelism",
"value": 5,
"advanced": True,
},
"stop": {"display_name": "Stop", "advanced": True},
"streaming": {
"display_name": "Streaming",
"value": False,
"advanced": True,
},
"temperature": {"display_name": "Temperature", "value": 0.0},
"top_k": {"display_name": "Top K", "value": 40, "advanced": True},
"top_p": {"display_name": "Top P", "value": 0.95, "advanced": True},
}
# Declarative input schema for the new Inputs/Outputs component format.
# Field types mirror the legacy build_config: file -> FileInput,
# dict -> DictInput, numbers -> Int/FloatInput, flags -> BoolInput.
inputs = [
    FileInput(name="credentials", display_name="Credentials", value="", file_types=["json"]),
    DictInput(name="instance", display_name="Instance", advanced=True),
    TextInput(name="location", display_name="Location", value="us-central1", advanced=True),
    IntInput(name="max_output_tokens", display_name="Max Output Tokens", value=128),
    IntInput(name="max_retries", display_name="Max Retries", value=6, advanced=True),
    TextInput(name="model_name", display_name="Model Name", value="textembedding-gecko"),
    IntInput(name="n", display_name="N", value=1, advanced=True),
    TextInput(name="project", display_name="Project", advanced=True),
    IntInput(name="request_parallelism", display_name="Request Parallelism", value=5, advanced=True),
    TextInput(name="stop", display_name="Stop", advanced=True),
    BoolInput(name="streaming", display_name="Streaming", value=False, advanced=True),
    FloatInput(name="temperature", display_name="Temperature", value=0.0),
    IntInput(name="top_k", display_name="Top K", value=40, advanced=True),
    FloatInput(name="top_p", display_name="Top P", value=0.95, advanced=True),
]
# NOTE(review): legacy `build` signature from the pre-refactor CustomComponent
# API. Its body does not appear in this diff view — it was replaced by the
# `outputs` declaration and `build_embeddings`. Kept verbatim as the removed
# old interface; do not treat this as a live definition.
def build(
    self,
    instance: Optional[str] = None,
    credentials: Optional[str] = None,
    location: str = "us-central1",
    max_output_tokens: int = 128,
    max_retries: int = 6,
    model_name: str = "textembedding-gecko",
    n: int = 1,
    project: Optional[str] = None,
    request_parallelism: int = 5,
    stop: Optional[List[str]] = None,
    streaming: bool = False,
    temperature: float = 0.0,
    top_k: int = 40,
    top_p: float = 0.95,
) -> Embeddings:
# Single output: the Embeddings object produced by `build_embeddings`.
outputs = [Output(name="embeddings", display_name="Embeddings", method="build_embeddings")]
def build_embeddings(self) -> Embeddings:
    """Build and return a ``VertexAIEmbeddings`` instance from this component's inputs.

    Reads the field values declared in ``inputs`` (credentials, location,
    model_name, sampling parameters, ...) off ``self`` and forwards them to
    the langchain wrapper.

    Returns:
        Embeddings: a configured ``VertexAIEmbeddings`` object.

    Raises:
        ImportError: if the optional ``langchain-google-vertexai`` package is
            not installed.
    """
    try:
        from langchain_google_vertexai import VertexAIEmbeddings
    except ImportError as e:
        # NOTE(review): the original error message is elided by the diff hunk
        # marker at this point; this follows the usual langflow wording —
        # confirm against the full file.
        raise ImportError(
            "Please install the langchain-google-vertexai package to use the VertexAI Embeddings component."
        ) from e

    # The diff residue contained BOTH the old local-variable kwargs and the new
    # `self.*` kwargs in one call (duplicate keyword arguments — a SyntaxError);
    # only the post-refactor `self.*` form is kept.
    return VertexAIEmbeddings(
        instance=self.instance,
        credentials=self.credentials,
        location=self.location,
        max_output_tokens=self.max_output_tokens,
        max_retries=self.max_retries,
        model_name=self.model_name,
        n=self.n,
        project=self.project,
        request_parallelism=self.request_parallelism,
        stop=self.stop,
        streaming=self.streaming,
        temperature=self.temperature,
        top_k=self.top_k,
        top_p=self.top_p,
    )