refactor: Update GroqModelComponent to use BaseLanguageModel and langflow template

This commit is contained in:
ogabrielluiz 2024-06-11 23:01:46 -03:00
commit 17065dd083

View file

@ -1,95 +1,102 @@
from typing import Optional
from langchain_groq import ChatGroq
from langflow.base.models.groq_constants import MODEL_NAMES
from pydantic.v1 import SecretStr
from langflow.base.constants import STREAM_INFO_TEXT
from langflow.base.models.groq_constants import MODEL_NAMES
from langflow.base.models.model import LCModelComponent
from langflow.field_typing import Text
from langflow.field_typing import BaseLanguageModel, Text
from langflow.template import Input, Output
class GroqModel(LCModelComponent):
class GroqModelComponent(LCModelComponent):
display_name: str = "Groq"
description: str = "Generate text using Groq."
icon = "Groq"
field_order = [
"groq_api_key",
"model",
"max_output_tokens",
"temperature",
"top_k",
"top_p",
"n",
"input_value",
"system_message",
"stream",
inputs = [
Input(
name="groq_api_key",
field_type=str,
display_name="Groq API Key",
info="API key for the Groq API.",
password=True,
),
Input(
name="groq_api_base",
field_type=Optional[str],
display_name="Groq API Base",
advanced=True,
info="Base URL path for API requests, leave blank if not using a proxy or service emulator.",
),
Input(
name="max_tokens",
field_type=Optional[int],
display_name="Max Output Tokens",
advanced=True,
info="The maximum number of tokens to generate.",
),
Input(
name="temperature",
field_type=float,
display_name="Temperature",
info="Run inference with this temperature. Must be in the closed interval [0.0, 1.0].",
),
Input(
name="n",
field_type=Optional[int],
display_name="N",
advanced=True,
info="Number of chat completions to generate for each prompt. Note that the API may not return the full n completions if duplicates are generated.",
),
Input(
name="model_name",
field_type=str,
display_name="Model",
info="The name of the model to use. Supported examples: gemini-pro",
options=MODEL_NAMES,
),
Input(name="input_value", field_type=str, display_name="Input", input_types=["Text", "Record", "Prompt"]),
Input(name="stream", field_type=bool, display_name="Stream", advanced=True, info=STREAM_INFO_TEXT),
Input(
name="system_message",
field_type=Optional[str],
display_name="System Message",
advanced=True,
info="System message to pass to the model.",
),
]
outputs = [
Output(display_name="Text", name="text_output", method="text_response"),
Output(display_name="Language Model", name="model_output", method="model_response"),
]
def build_config(self):
return {
"groq_api_key": {
"display_name": "Groq API Key",
"info": "API key for the Groq API.",
"password": True,
},
"groq_api_base": {
"display_name": "Groq API Base",
"info": "Base URL path for API requests, leave blank if not using a proxy or service emulator.",
"advanced": True,
},
"max_tokens": {
"display_name": "Max Output Tokens",
"info": "The maximum number of tokens to generate.",
"advanced": True,
},
"temperature": {
"display_name": "Temperature",
"info": "Run inference with this temperature. Must by in the closed interval [0.0, 1.0].",
},
"n": {
"display_name": "N",
"info": "Number of chat completions to generate for each prompt. Note that the API may not return the full n completions if duplicates are generated.",
"advanced": True,
},
"model_name": {
"display_name": "Model",
"info": "The name of the model to use. Supported examples: gemini-pro",
"options": MODEL_NAMES,
},
"input_value": {"display_name": "Input", "info": "The input to the model."},
"stream": {
"display_name": "Stream",
"info": STREAM_INFO_TEXT,
"advanced": True,
},
"system_message": {
"display_name": "System Message",
"info": "System message to pass to the model.",
"advanced": True,
},
}
def text_response(self) -> Text:
input_value = self.input_value
stream = self.stream
system_message = self.system_message
output = self.model_response()
result = self.get_chat_result(output, stream, input_value, system_message)
self.status = result
return result
def model_response(self) -> BaseLanguageModel:
groq_api_key = self.groq_api_key
model_name = self.model_name
groq_api_base = self.groq_api_base or None
max_tokens = self.max_tokens
temperature = self.temperature
n = self.n or 1
stream = self.stream
def build(
self,
groq_api_key: str,
model_name: str,
input_value: Text,
groq_api_base: Optional[str] = None,
max_tokens: Optional[int] = None,
temperature: float = 0.1,
n: Optional[int] = 1,
stream: bool = False,
system_message: Optional[str] = None,
) -> Text:
output = ChatGroq(
model_name=model_name,
max_tokens=max_tokens or None, # type: ignore
temperature=temperature,
groq_api_base=groq_api_base,
n=n or 1,
n=n,
groq_api_key=SecretStr(groq_api_key),
streaming=stream,
)
return self.get_chat_result(output, stream, input_value, system_message)
return output