refactor: Update VertexAIEmbeddingsComponent to use new Inputs/Outputs format
This commit is contained in:
parent
5d21466525
commit
4a201d478c
1 changed files with 103 additions and 78 deletions
|
|
@ -1,76 +1,101 @@
|
|||
from typing import List, Optional
|
||||
|
||||
from langflow.custom import CustomComponent
|
||||
from langflow.base.models.model import LCModelComponent
|
||||
from langflow.field_typing import Embeddings
|
||||
from langflow.io import BoolInput, DictInput, FileInput, FloatInput, IntInput, Output, TextInput
|
||||
|
||||
|
||||
class VertexAIEmbeddingsComponent(CustomComponent):
|
||||
class VertexAIEmbeddingsComponent(LCModelComponent):
|
||||
display_name = "VertexAI Embeddings"
|
||||
description = "Generate embeddings using Google Cloud VertexAI models."
|
||||
icon = "VertexAI"
|
||||
|
||||
def build_config(self):
|
||||
return {
|
||||
"credentials": {
|
||||
"display_name": "Credentials",
|
||||
"value": "",
|
||||
"file_types": [".json"],
|
||||
"field_type": "file",
|
||||
},
|
||||
"instance": {
|
||||
"display_name": "instance",
|
||||
"advanced": True,
|
||||
"field_type": "dict",
|
||||
},
|
||||
"location": {
|
||||
"display_name": "Location",
|
||||
"value": "us-central1",
|
||||
"advanced": True,
|
||||
},
|
||||
"max_output_tokens": {"display_name": "Max Output Tokens", "value": 128},
|
||||
"max_retries": {
|
||||
"display_name": "Max Retries",
|
||||
"value": 6,
|
||||
"advanced": True,
|
||||
},
|
||||
"model_name": {
|
||||
"display_name": "Model Name",
|
||||
"value": "textembedding-gecko",
|
||||
},
|
||||
"n": {"display_name": "N", "value": 1, "advanced": True},
|
||||
"project": {"display_name": "Project", "advanced": True},
|
||||
"request_parallelism": {
|
||||
"display_name": "Request Parallelism",
|
||||
"value": 5,
|
||||
"advanced": True,
|
||||
},
|
||||
"stop": {"display_name": "Stop", "advanced": True},
|
||||
"streaming": {
|
||||
"display_name": "Streaming",
|
||||
"value": False,
|
||||
"advanced": True,
|
||||
},
|
||||
"temperature": {"display_name": "Temperature", "value": 0.0},
|
||||
"top_k": {"display_name": "Top K", "value": 40, "advanced": True},
|
||||
"top_p": {"display_name": "Top P", "value": 0.95, "advanced": True},
|
||||
}
|
||||
inputs = [
|
||||
FileInput(
|
||||
name="credentials",
|
||||
display_name="Credentials",
|
||||
value="",
|
||||
file_types=["json"], # Removed the dot
|
||||
),
|
||||
DictInput(
|
||||
name="instance",
|
||||
display_name="Instance",
|
||||
advanced=True,
|
||||
),
|
||||
TextInput(
|
||||
name="location",
|
||||
display_name="Location",
|
||||
value="us-central1",
|
||||
advanced=True,
|
||||
),
|
||||
IntInput(
|
||||
name="max_output_tokens",
|
||||
display_name="Max Output Tokens",
|
||||
value=128,
|
||||
),
|
||||
IntInput(
|
||||
name="max_retries",
|
||||
display_name="Max Retries",
|
||||
value=6,
|
||||
advanced=True,
|
||||
),
|
||||
TextInput(
|
||||
name="model_name",
|
||||
display_name="Model Name",
|
||||
value="textembedding-gecko",
|
||||
),
|
||||
IntInput(
|
||||
name="n",
|
||||
display_name="N",
|
||||
value=1,
|
||||
advanced=True,
|
||||
),
|
||||
TextInput(
|
||||
name="project",
|
||||
display_name="Project",
|
||||
advanced=True,
|
||||
),
|
||||
IntInput(
|
||||
name="request_parallelism",
|
||||
display_name="Request Parallelism",
|
||||
value=5,
|
||||
advanced=True,
|
||||
),
|
||||
TextInput(
|
||||
name="stop",
|
||||
display_name="Stop",
|
||||
advanced=True,
|
||||
),
|
||||
BoolInput(
|
||||
name="streaming",
|
||||
display_name="Streaming",
|
||||
value=False,
|
||||
advanced=True,
|
||||
),
|
||||
FloatInput(
|
||||
name="temperature",
|
||||
display_name="Temperature",
|
||||
value=0.0,
|
||||
),
|
||||
IntInput(
|
||||
name="top_k",
|
||||
display_name="Top K",
|
||||
value=40,
|
||||
advanced=True,
|
||||
),
|
||||
FloatInput(
|
||||
name="top_p",
|
||||
display_name="Top P",
|
||||
value=0.95,
|
||||
advanced=True,
|
||||
),
|
||||
]
|
||||
|
||||
def build(
|
||||
self,
|
||||
instance: Optional[str] = None,
|
||||
credentials: Optional[str] = None,
|
||||
location: str = "us-central1",
|
||||
max_output_tokens: int = 128,
|
||||
max_retries: int = 6,
|
||||
model_name: str = "textembedding-gecko",
|
||||
n: int = 1,
|
||||
project: Optional[str] = None,
|
||||
request_parallelism: int = 5,
|
||||
stop: Optional[List[str]] = None,
|
||||
streaming: bool = False,
|
||||
temperature: float = 0.0,
|
||||
top_k: int = 40,
|
||||
top_p: float = 0.95,
|
||||
) -> Embeddings:
|
||||
outputs = [
|
||||
Output(display_name="Embeddings", name="embeddings", method="build_embeddings"),
|
||||
]
|
||||
|
||||
def build_embeddings(self) -> Embeddings:
|
||||
try:
|
||||
from langchain_google_vertexai import VertexAIEmbeddings
|
||||
except ImportError:
|
||||
|
|
@ -79,18 +104,18 @@ class VertexAIEmbeddingsComponent(CustomComponent):
|
|||
)
|
||||
|
||||
return VertexAIEmbeddings(
|
||||
instance=instance,
|
||||
credentials=credentials,
|
||||
location=location,
|
||||
max_output_tokens=max_output_tokens,
|
||||
max_retries=max_retries,
|
||||
model_name=model_name,
|
||||
n=n,
|
||||
project=project,
|
||||
request_parallelism=request_parallelism,
|
||||
stop=stop,
|
||||
streaming=streaming,
|
||||
temperature=temperature,
|
||||
top_k=top_k,
|
||||
top_p=top_p,
|
||||
instance=self.instance,
|
||||
credentials=self.credentials,
|
||||
location=self.location,
|
||||
max_output_tokens=self.max_output_tokens,
|
||||
max_retries=self.max_retries,
|
||||
model_name=self.model_name,
|
||||
n=self.n,
|
||||
project=self.project,
|
||||
request_parallelism=self.request_parallelism,
|
||||
stop=self.stop,
|
||||
streaming=self.streaming,
|
||||
temperature=self.temperature,
|
||||
top_k=self.top_k,
|
||||
top_p=self.top_p,
|
||||
)
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue