From 4a201d478c29df0dae8f5755838f0ba75f359937 Mon Sep 17 00:00:00 2001 From: Cezar Vasconcelos Date: Wed, 19 Jun 2024 21:04:46 +0000 Subject: [PATCH] refactor: Update VertexAIEmbeddingsComponent to use new Inputs/Outputs format --- .../embeddings/VertexAIEmbeddings.py | 181 ++++++++++-------- 1 file changed, 103 insertions(+), 78 deletions(-) diff --git a/src/backend/base/langflow/components/embeddings/VertexAIEmbeddings.py b/src/backend/base/langflow/components/embeddings/VertexAIEmbeddings.py index 1a3eedc7c..30f1eda02 100644 --- a/src/backend/base/langflow/components/embeddings/VertexAIEmbeddings.py +++ b/src/backend/base/langflow/components/embeddings/VertexAIEmbeddings.py @@ -1,76 +1,101 @@ from typing import List, Optional -from langflow.custom import CustomComponent +from langflow.base.models.model import LCModelComponent from langflow.field_typing import Embeddings +from langflow.io import BoolInput, DictInput, FileInput, FloatInput, IntInput, Output, TextInput -class VertexAIEmbeddingsComponent(CustomComponent): +class VertexAIEmbeddingsComponent(LCModelComponent): display_name = "VertexAI Embeddings" description = "Generate embeddings using Google Cloud VertexAI models." + icon = "VertexAI" - def build_config(self): - return { - "credentials": { - "display_name": "Credentials", - "value": "", - "file_types": [".json"], - "field_type": "file", - }, - "instance": { - "display_name": "instance", - "advanced": True, - "field_type": "dict", - }, - "location": { - "display_name": "Location", - "value": "us-central1", - "advanced": True, - }, - "max_output_tokens": {"display_name": "Max Output Tokens", "value": 128}, - "max_retries": { - "display_name": "Max Retries", - "value": 6, - "advanced": True, - }, - "model_name": { - "display_name": "Model Name", - "value": "textembedding-gecko", - }, - "n": {"display_name": "N", "value": 1, "advanced": True}, - "project": {"display_name": "Project", "advanced": True}, - "request_parallelism": { - "display_name": "Request Parallelism", - "value": 5, - "advanced": True, - }, - "stop": {"display_name": "Stop", "advanced": True}, - "streaming": { - "display_name": "Streaming", - "value": False, - "advanced": True, - }, - "temperature": {"display_name": "Temperature", "value": 0.0}, - "top_k": {"display_name": "Top K", "value": 40, "advanced": True}, - "top_p": {"display_name": "Top P", "value": 0.95, "advanced": True}, - } + inputs = [ + FileInput( + name="credentials", + display_name="Credentials", + value="", + file_types=["json"], # Removed the dot + ), + DictInput( + name="instance", + display_name="Instance", + advanced=True, + ), + TextInput( + name="location", + display_name="Location", + value="us-central1", + advanced=True, + ), + IntInput( + name="max_output_tokens", + display_name="Max Output Tokens", + value=128, + ), + IntInput( + name="max_retries", + display_name="Max Retries", + value=6, + advanced=True, + ), + TextInput( + name="model_name", + display_name="Model Name", + value="textembedding-gecko", + ), + IntInput( + name="n", + display_name="N", + value=1, + advanced=True, + ), + TextInput( + name="project", + display_name="Project", + advanced=True, + ), + IntInput( + name="request_parallelism", + display_name="Request Parallelism", + value=5, + advanced=True, + ), + TextInput( + name="stop", + display_name="Stop", + advanced=True, + ), + BoolInput( + name="streaming", + display_name="Streaming", + value=False, + advanced=True, + ), + FloatInput( + name="temperature", + display_name="Temperature", + value=0.0, + ), + IntInput( + name="top_k", + display_name="Top K", + value=40, + advanced=True, + ), + FloatInput( + name="top_p", + display_name="Top P", + value=0.95, + advanced=True, + ), + ] - def build( - self, - instance: Optional[str] = None, - credentials: Optional[str] = None, - location: str = "us-central1", - max_output_tokens: int = 128, - max_retries: int = 6, - model_name: str = "textembedding-gecko", - n: int = 1, - project: Optional[str] = None, - request_parallelism: int = 5, - stop: Optional[List[str]] = None, - streaming: bool = False, - temperature: float = 0.0, - top_k: int = 40, - top_p: float = 0.95, - ) -> Embeddings: + outputs = [ + Output(display_name="Embeddings", name="embeddings", method="build_embeddings"), + ] + + def build_embeddings(self) -> Embeddings: try: from langchain_google_vertexai import VertexAIEmbeddings except ImportError: @@ -79,18 +104,18 @@ class VertexAIEmbeddingsComponent(CustomComponent): ) return VertexAIEmbeddings( - instance=instance, - credentials=credentials, - location=location, - max_output_tokens=max_output_tokens, - max_retries=max_retries, - model_name=model_name, - n=n, - project=project, - request_parallelism=request_parallelism, - stop=stop, - streaming=streaming, - temperature=temperature, - top_k=top_k, - top_p=top_p, + instance=self.instance, + credentials=self.credentials, + location=self.location, + max_output_tokens=self.max_output_tokens, + max_retries=self.max_retries, + model_name=self.model_name, + n=self.n, + project=self.project, + request_parallelism=self.request_parallelism, + stop=self.stop, + streaming=self.streaming, + temperature=self.temperature, + top_k=self.top_k, + top_p=self.top_p, )