From 17065dd083fdde890e5f1822f64bf7e81f600610 Mon Sep 17 00:00:00 2001 From: ogabrielluiz Date: Tue, 11 Jun 2024 23:01:46 -0300 Subject: [PATCH] refactor: Update GroqModelComponent to use BaseLanguageModel and langflow template --- .../langflow/components/models/GroqModel.py | 149 +++++++++--------- 1 file changed, 78 insertions(+), 71 deletions(-) diff --git a/src/backend/base/langflow/components/models/GroqModel.py b/src/backend/base/langflow/components/models/GroqModel.py index 91959e751..022aece47 100644 --- a/src/backend/base/langflow/components/models/GroqModel.py +++ b/src/backend/base/langflow/components/models/GroqModel.py @@ -1,95 +1,102 @@ from typing import Optional from langchain_groq import ChatGroq -from langflow.base.models.groq_constants import MODEL_NAMES from pydantic.v1 import SecretStr from langflow.base.constants import STREAM_INFO_TEXT +from langflow.base.models.groq_constants import MODEL_NAMES from langflow.base.models.model import LCModelComponent -from langflow.field_typing import Text +from langflow.field_typing import BaseLanguageModel, Text +from langflow.template import Input, Output -class GroqModel(LCModelComponent): +class GroqModelComponent(LCModelComponent): display_name: str = "Groq" description: str = "Generate text using Groq." icon = "Groq" - field_order = [ - "groq_api_key", - "model", - "max_output_tokens", - "temperature", - "top_k", - "top_p", - "n", - "input_value", - "system_message", - "stream", + inputs = [ + Input( + name="groq_api_key", + field_type=str, + display_name="Groq API Key", + info="API key for the Groq API.", + password=True, + ), + Input( + name="groq_api_base", + field_type=Optional[str], + display_name="Groq API Base", + advanced=True, + info="Base URL path for API requests, leave blank if not using a proxy or service emulator.", + ), + Input( + name="max_tokens", + field_type=Optional[int], + display_name="Max Output Tokens", + advanced=True, + info="The maximum number of tokens to generate.", + ), + Input( + name="temperature", + field_type=float, + display_name="Temperature", + info="Run inference with this temperature. Must be in the closed interval [0.0, 1.0].", + ), + Input( + name="n", + field_type=Optional[int], + display_name="N", + advanced=True, + info="Number of chat completions to generate for each prompt. Note that the API may not return the full n completions if duplicates are generated.", + ), + Input( + name="model_name", + field_type=str, + display_name="Model", + info="The name of the model to use. Supported examples: gemini-pro", + options=MODEL_NAMES, + ), + Input(name="input_value", field_type=str, display_name="Input", input_types=["Text", "Record", "Prompt"]), + Input(name="stream", field_type=bool, display_name="Stream", advanced=True, info=STREAM_INFO_TEXT), + Input( + name="system_message", + field_type=Optional[str], + display_name="System Message", + advanced=True, + info="System message to pass to the model.", + ), + ] + outputs = [ + Output(display_name="Text", name="text_output", method="text_response"), + Output(display_name="Language Model", name="model_output", method="model_response"), ] - def build_config(self): - return { - "groq_api_key": { - "display_name": "Groq API Key", - "info": "API key for the Groq API.", - "password": True, - }, - "groq_api_base": { - "display_name": "Groq API Base", - "info": "Base URL path for API requests, leave blank if not using a proxy or service emulator.", - "advanced": True, - }, - "max_tokens": { - "display_name": "Max Output Tokens", - "info": "The maximum number of tokens to generate.", - "advanced": True, - }, - "temperature": { - "display_name": "Temperature", - "info": "Run inference with this temperature. Must by in the closed interval [0.0, 1.0].", - }, - "n": { - "display_name": "N", - "info": "Number of chat completions to generate for each prompt. Note that the API may not return the full n completions if duplicates are generated.", - "advanced": True, - }, - "model_name": { - "display_name": "Model", - "info": "The name of the model to use. Supported examples: gemini-pro", - "options": MODEL_NAMES, - }, - "input_value": {"display_name": "Input", "info": "The input to the model."}, - "stream": { - "display_name": "Stream", - "info": STREAM_INFO_TEXT, - "advanced": True, - }, - "system_message": { - "display_name": "System Message", - "info": "System message to pass to the model.", - "advanced": True, - }, - } + def text_response(self) -> Text: + input_value = self.input_value + stream = self.stream + system_message = self.system_message + output = self.model_response() + result = self.get_chat_result(output, stream, input_value, system_message) + self.status = result + return result + + def model_response(self) -> BaseLanguageModel: + groq_api_key = self.groq_api_key + model_name = self.model_name + groq_api_base = self.groq_api_base or None + max_tokens = self.max_tokens + temperature = self.temperature + n = self.n or 1 + stream = self.stream - def build( - self, - groq_api_key: str, - model_name: str, - input_value: Text, - groq_api_base: Optional[str] = None, - max_tokens: Optional[int] = None, - temperature: float = 0.1, - n: Optional[int] = 1, - stream: bool = False, - system_message: Optional[str] = None, - ) -> Text: output = ChatGroq( model_name=model_name, max_tokens=max_tokens or None, # type: ignore temperature=temperature, groq_api_base=groq_api_base, - n=n or 1, + n=n, groq_api_key=SecretStr(groq_api_key), streaming=stream, ) - return self.get_chat_result(output, stream, input_value, system_message) + return output