From d606a4dac3b1b250b2e4333c758f8f7edbeaf892 Mon Sep 17 00:00:00 2001 From: vinicius Mello <45274355+vmellos@users.noreply.github.com> Date: Thu, 8 Aug 2024 16:41:42 -0400 Subject: [PATCH] feat: add huggingface endpoint retry (#3236) * feat: added retry when calling huggingface endpoint * chore: added default value to retry input * [autofix.ci] apply automated fixes --------- Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com> --- .../components/models/HuggingFaceModel.py | 27 ++++++++++++++----- 1 file changed, 20 insertions(+), 7 deletions(-) diff --git a/src/backend/base/langflow/components/models/HuggingFaceModel.py b/src/backend/base/langflow/components/models/HuggingFaceModel.py index 313d44001..069d63d18 100644 --- a/src/backend/base/langflow/components/models/HuggingFaceModel.py +++ b/src/backend/base/langflow/components/models/HuggingFaceModel.py @@ -1,9 +1,11 @@ +from tenacity import retry, stop_after_attempt, wait_fixed + from langchain_community.chat_models.huggingface import ChatHuggingFace from langchain_community.llms.huggingface_endpoint import HuggingFaceEndpoint from langflow.base.models.model import LCModelComponent from langflow.field_typing import LanguageModel -from langflow.io import DictInput, DropdownInput, SecretStrInput, StrInput +from langflow.io import DictInput, DropdownInput, SecretStrInput, StrInput, IntInput class HuggingFaceEndpointsComponent(LCModelComponent): @@ -26,8 +28,24 @@ class HuggingFaceEndpointsComponent(LCModelComponent): ), SecretStrInput(name="huggingfacehub_api_token", display_name="API token", password=True), DictInput(name="model_kwargs", display_name="Model Keyword Arguments", advanced=True), + IntInput(name="retry_attempts", display_name="Retry Attempts", value=1), ] + def create_huggingface_endpoint(self, endpoint_url, task, huggingfacehub_api_token, model_kwargs): + @retry(stop=stop_after_attempt(self.retry_attempts), wait=wait_fixed(2)) + def _attempt_create(): + try: + return HuggingFaceEndpoint( # type: ignore + endpoint_url=endpoint_url, + task=task, + huggingfacehub_api_token=huggingfacehub_api_token, + model_kwargs=model_kwargs, + ) + except Exception as e: + raise ValueError("Could not connect to HuggingFace Endpoints API.") from e + + return _attempt_create() + def build_model(self) -> LanguageModel: # type: ignore[type-var] endpoint_url = self.endpoint_url task = self.task @@ -35,12 +53,7 @@ class HuggingFaceEndpointsComponent(LCModelComponent): model_kwargs = self.model_kwargs or {} try: - llm = HuggingFaceEndpoint( # type: ignore - endpoint_url=endpoint_url, - task=task, - huggingfacehub_api_token=huggingfacehub_api_token, - model_kwargs=model_kwargs, - ) + llm = self.create_huggingface_endpoint(endpoint_url, task, huggingfacehub_api_token, model_kwargs) except Exception as e: raise ValueError("Could not connect to HuggingFace Endpoints API.") from e