From 7ca1d8459603de564ed56a84f405624b049ef67b Mon Sep 17 00:00:00 2001 From: Cezar Vasconcelos <97035956+vasconceloscezar@users.noreply.github.com> Date: Wed, 24 Jul 2024 12:56:20 -0300 Subject: [PATCH] feat: add method and refresh button to fetch Groq models (#2902) * feat: add method and refresh button to fetch Groq models * [autofix.ci] apply automated fixes --------- Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com> --- .../langflow/components/models/GroqModel.py | 29 +++++++++++++++++-- 1 file changed, 27 insertions(+), 2 deletions(-) diff --git a/src/backend/base/langflow/components/models/GroqModel.py b/src/backend/base/langflow/components/models/GroqModel.py index 7bad1bcf1..2e6d2df1e 100644 --- a/src/backend/base/langflow/components/models/GroqModel.py +++ b/src/backend/base/langflow/components/models/GroqModel.py @@ -1,7 +1,8 @@ +import requests +from typing import List from langchain_groq import ChatGroq from pydantic.v1 import SecretStr -from langflow.base.models.groq_constants import MODEL_NAMES from langflow.base.models.model import LCModelComponent from langflow.field_typing import LanguageModel from langflow.io import DropdownInput, FloatInput, IntInput, MessageTextInput, SecretStrInput @@ -24,6 +25,7 @@ class GroqModel(LCModelComponent): display_name="Groq API Base", info="Base URL path for API requests, leave blank if not using a proxy or service emulator.", advanced=True, + value="https://api.groq.com", ), IntInput( name="max_tokens", @@ -47,10 +49,33 @@ class GroqModel(LCModelComponent): name="model_name", display_name="Model", info="The name of the model to use.", - options=MODEL_NAMES, + options=[], + refresh_button=True, ), ] + def get_models(self) -> List[str]: + api_key = self.groq_api_key + base_url = self.groq_api_base or "https://api.groq.com" + url = f"{base_url}/openai/v1/models" + + headers = {"Authorization": f"Bearer {api_key}", "Content-Type": "application/json"} + + try: + response = requests.get(url, headers=headers) + response.raise_for_status() + model_list = response.json() + return [model["id"] for model in model_list.get("data", [])] + except requests.RequestException as e: + self.status = f"Error fetching models: {str(e)}" + return [] + + def update_build_config(self, build_config: dict, field_value: str, field_name: str | None = None): + if field_name == "groq_api_key" or field_name == "groq_api_base" or field_name == "model_name": + models = self.get_models() + build_config["model_name"]["options"] = models + return build_config + def build_model(self) -> LanguageModel: # type: ignore[type-var] groq_api_key = self.groq_api_key model_name = self.model_name