feat: add huggingface endpoint retry (#3236)
* feat: added retry when calling huggingface endpoint * chore: added default value to retry input * [autofix.ci] apply automated fixes --------- Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com>
This commit is contained in:
parent
e42b6bdb94
commit
d606a4dac3
1 changed files with 20 additions and 7 deletions
|
|
@ -1,9 +1,11 @@
|
|||
from tenacity import retry, stop_after_attempt, wait_fixed
|
||||
|
||||
from langchain_community.chat_models.huggingface import ChatHuggingFace
|
||||
from langchain_community.llms.huggingface_endpoint import HuggingFaceEndpoint
|
||||
|
||||
from langflow.base.models.model import LCModelComponent
|
||||
from langflow.field_typing import LanguageModel
|
||||
from langflow.io import DictInput, DropdownInput, SecretStrInput, StrInput
|
||||
from langflow.io import DictInput, DropdownInput, SecretStrInput, StrInput, IntInput
|
||||
|
||||
|
||||
class HuggingFaceEndpointsComponent(LCModelComponent):
|
||||
|
|
@ -26,8 +28,24 @@ class HuggingFaceEndpointsComponent(LCModelComponent):
|
|||
),
|
||||
SecretStrInput(name="huggingfacehub_api_token", display_name="API token", password=True),
|
||||
DictInput(name="model_kwargs", display_name="Model Keyword Arguments", advanced=True),
|
||||
IntInput(name="retry_attempts", display_name="Retry Attempts", value=1),
|
||||
]
|
||||
|
||||
def create_huggingface_endpoint(self, endpoint_url, task, huggingfacehub_api_token, model_kwargs):
|
||||
@retry(stop=stop_after_attempt(self.retry_attempts), wait=wait_fixed(2))
|
||||
def _attempt_create():
|
||||
try:
|
||||
return HuggingFaceEndpoint( # type: ignore
|
||||
endpoint_url=endpoint_url,
|
||||
task=task,
|
||||
huggingfacehub_api_token=huggingfacehub_api_token,
|
||||
model_kwargs=model_kwargs,
|
||||
)
|
||||
except Exception as e:
|
||||
raise ValueError("Could not connect to HuggingFace Endpoints API.") from e
|
||||
|
||||
return _attempt_create()
|
||||
|
||||
def build_model(self) -> LanguageModel: # type: ignore[type-var]
|
||||
endpoint_url = self.endpoint_url
|
||||
task = self.task
|
||||
|
|
@ -35,12 +53,7 @@ class HuggingFaceEndpointsComponent(LCModelComponent):
|
|||
model_kwargs = self.model_kwargs or {}
|
||||
|
||||
try:
|
||||
llm = HuggingFaceEndpoint( # type: ignore
|
||||
endpoint_url=endpoint_url,
|
||||
task=task,
|
||||
huggingfacehub_api_token=huggingfacehub_api_token,
|
||||
model_kwargs=model_kwargs,
|
||||
)
|
||||
llm = self.create_huggingface_endpoint(endpoint_url, task, huggingfacehub_api_token, model_kwargs)
|
||||
except Exception as e:
|
||||
raise ValueError("Could not connect to HuggingFace Endpoints API.") from e
|
||||
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue