feat: add huggingface endpoint retry (#3236)

* feat: added retry when calling huggingface endpoint

* chore: added default value to retry input

* [autofix.ci] apply automated fixes

---------

Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com>
This commit is contained in:
vinicius Mello 2024-08-08 16:41:42 -04:00 committed by GitHub
commit d606a4dac3
No known key found for this signature in database
GPG key ID: B5690EEEBB952194

View file

@ -1,9 +1,11 @@
from tenacity import retry, stop_after_attempt, wait_fixed
from langchain_community.chat_models.huggingface import ChatHuggingFace
from langchain_community.llms.huggingface_endpoint import HuggingFaceEndpoint
from langflow.base.models.model import LCModelComponent
from langflow.field_typing import LanguageModel
from langflow.io import DictInput, DropdownInput, SecretStrInput, StrInput
from langflow.io import DictInput, DropdownInput, SecretStrInput, StrInput, IntInput
class HuggingFaceEndpointsComponent(LCModelComponent):
@ -26,8 +28,24 @@ class HuggingFaceEndpointsComponent(LCModelComponent):
),
SecretStrInput(name="huggingfacehub_api_token", display_name="API token", password=True),
DictInput(name="model_kwargs", display_name="Model Keyword Arguments", advanced=True),
IntInput(name="retry_attempts", display_name="Retry Attempts", value=1),
]
def create_huggingface_endpoint(self, endpoint_url, task, huggingfacehub_api_token, model_kwargs):
@retry(stop=stop_after_attempt(self.retry_attempts), wait=wait_fixed(2))
def _attempt_create():
try:
return HuggingFaceEndpoint( # type: ignore
endpoint_url=endpoint_url,
task=task,
huggingfacehub_api_token=huggingfacehub_api_token,
model_kwargs=model_kwargs,
)
except Exception as e:
raise ValueError("Could not connect to HuggingFace Endpoints API.") from e
return _attempt_create()
def build_model(self) -> LanguageModel: # type: ignore[type-var]
endpoint_url = self.endpoint_url
task = self.task
@ -35,12 +53,7 @@ class HuggingFaceEndpointsComponent(LCModelComponent):
model_kwargs = self.model_kwargs or {}
try:
llm = HuggingFaceEndpoint( # type: ignore
endpoint_url=endpoint_url,
task=task,
huggingfacehub_api_token=huggingfacehub_api_token,
model_kwargs=model_kwargs,
)
llm = self.create_huggingface_endpoint(endpoint_url, task, huggingfacehub_api_token, model_kwargs)
except Exception as e:
raise ValueError("Could not connect to HuggingFace Endpoints API.") from e