feat: Enhance HuggingFaceEndpointsComponent with additional parameters (#3846)
* Update HuggingFaceInferenceAPIEmbeddings.py update to use the Inference API from Hugging Face * Enhance HuggingFaceEndpointsComponent with additional parameters - Add FloatInput for top_p, typical_p, temperature, and repetition_penalty - Update create_huggingface_endpoint and build_model methods to include new parameters - Set default values and info for new inputs TODO: Need to update the package `from langchain_community.llms.huggingface_endpoint import HuggingFaceEndpoint` since it is deprecated. * Updated HuggingFaceModel Solving Lint Error Updated HuggingFaceModel Solving Lint Error * Update HuggingFaceModel.py Added Inference Endpoint as an input from user to support custom inference endpoints * Update HuggingFaceModel.py paper references removed
This commit is contained in:
parent
aa2578370b
commit
1caba1cd7c
2 changed files with 113 additions and 20 deletions
|
|
@ -1,4 +1,5 @@
|
|||
from urllib.parse import urlparse
|
||||
from tenacity import retry, stop_after_attempt, wait_fixed
|
||||
|
||||
import requests
|
||||
from langchain_community.embeddings.huggingface import HuggingFaceInferenceAPIEmbeddings
|
||||
|
|
@ -27,7 +28,7 @@ class HuggingFaceInferenceAPIEmbeddingsComponent(LCEmbeddingsModel):
|
|||
name="inference_endpoint",
|
||||
display_name="Inference Endpoint",
|
||||
required=True,
|
||||
value="http://localhost:8080",
|
||||
value="https://api-inference.huggingface.co/models/",
|
||||
info="Custom inference endpoint URL.",
|
||||
),
|
||||
MessageTextInput(
|
||||
|
|
@ -61,24 +62,32 @@ class HuggingFaceInferenceAPIEmbeddingsComponent(LCEmbeddingsModel):
|
|||
# returning True to solve linting error
|
||||
return True
|
||||
|
||||
def get_api_url(self) -> str:
|
||||
if "huggingface" in self.inference_endpoint.lower():
|
||||
return f"{self.inference_endpoint}{self.model_name}"
|
||||
else:
|
||||
return self.inference_endpoint
|
||||
|
||||
@retry(stop=stop_after_attempt(3), wait=wait_fixed(2))
|
||||
def create_huggingface_embeddings(
|
||||
self, api_key: SecretStr, api_url: str, model_name: str
|
||||
) -> HuggingFaceInferenceAPIEmbeddings:
|
||||
return HuggingFaceInferenceAPIEmbeddings(api_key=api_key, api_url=api_url, model_name=model_name)
|
||||
|
||||
def build_embeddings(self) -> Embeddings:
|
||||
if not self.inference_endpoint:
|
||||
raise ValueError("Inference endpoint is required")
|
||||
api_url = self.get_api_url()
|
||||
|
||||
self.validate_inference_endpoint(self.inference_endpoint)
|
||||
is_local_url = api_url.startswith(("http://localhost", "http://127.0.0.1"))
|
||||
|
||||
# Check if the inference endpoint is local
|
||||
is_local_url = self.inference_endpoint.startswith(("http://localhost", "http://127.0.0.1"))
|
||||
|
||||
# Use a dummy key for local URLs if no key is provided.
|
||||
# Refer https://python.langchain.com/v0.2/api_reference/community/embeddings/langchain_community.embeddings.huggingface.HuggingFaceInferenceAPIEmbeddings.html
|
||||
if not self.api_key and is_local_url:
|
||||
self.validate_inference_endpoint(api_url)
|
||||
api_key = SecretStr("DummyAPIKeyForLocalDeployment")
|
||||
elif not self.api_key:
|
||||
raise ValueError("API Key is required for non-local inference endpoints")
|
||||
else:
|
||||
api_key = SecretStr(self.api_key)
|
||||
|
||||
return HuggingFaceInferenceAPIEmbeddings(
|
||||
api_key=api_key, api_url=self.inference_endpoint, model_name=self.model_name
|
||||
)
|
||||
try:
|
||||
return self.create_huggingface_embeddings(api_key, api_url, self.model_name)
|
||||
except Exception as e:
|
||||
raise ValueError("Could not connect to HuggingFace Inference API.") from e
|
||||
|
|
|
|||
|
|
@ -1,9 +1,11 @@
|
|||
from tenacity import retry, stop_after_attempt, wait_fixed
|
||||
from langchain_community.llms.huggingface_endpoint import HuggingFaceEndpoint
|
||||
|
||||
# TODO: langchain_community.llms.huggingface_endpoint is deprecated. Need to update to langchain_huggingface, but there is a dependency on langchain_core 0.3.0
|
||||
from langflow.base.models.model import LCModelComponent
|
||||
from langflow.field_typing import LanguageModel
|
||||
from langflow.io import DictInput, DropdownInput, SecretStrInput, StrInput, IntInput
|
||||
from langflow.io import DictInput, DropdownInput, SecretStrInput, StrInput, IntInput, FloatInput
|
||||
from typing import Any, Dict, Optional
|
||||
|
||||
|
||||
class HuggingFaceEndpointsComponent(LCModelComponent):
|
||||
|
|
@ -18,22 +20,81 @@ class HuggingFaceEndpointsComponent(LCModelComponent):
|
|||
display_name="Model ID",
|
||||
value="openai-community/gpt2",
|
||||
),
|
||||
StrInput(
|
||||
name="inference_endpoint",
|
||||
display_name="Inference Endpoint",
|
||||
value="https://api-inference.huggingface.co/models/",
|
||||
info="Custom inference endpoint URL.",
|
||||
),
|
||||
DropdownInput(
|
||||
name="task",
|
||||
display_name="Task",
|
||||
options=["text2text-generation", "text-generation", "summarization", "translation"],
|
||||
value="text-generation",
|
||||
advanced=True,
|
||||
info="The task to call the model with. Should be a task that returns `generated_text` or `summary_text`.",
|
||||
),
|
||||
SecretStrInput(name="huggingfacehub_api_token", display_name="API Token", password=True),
|
||||
DictInput(name="model_kwargs", display_name="Model Keyword Arguments", advanced=True),
|
||||
IntInput(name="retry_attempts", display_name="Retry Attempts", value=1, advanced=True),
|
||||
IntInput(
|
||||
name="max_new_tokens", display_name="Max New Tokens", value=512, info="Maximum number of generated tokens"
|
||||
),
|
||||
IntInput(
|
||||
name="top_k",
|
||||
display_name="Top K",
|
||||
advanced=True,
|
||||
info="The number of highest probability vocabulary tokens to keep for top-k-filtering",
|
||||
),
|
||||
FloatInput(
|
||||
name="top_p",
|
||||
display_name="Top P",
|
||||
value=0.95,
|
||||
advanced=True,
|
||||
info="If set to < 1, only the smallest set of most probable tokens with probabilities that add up to `top_p` or higher are kept for generation",
|
||||
),
|
||||
FloatInput(
|
||||
name="typical_p",
|
||||
display_name="Typical P",
|
||||
value=0.95,
|
||||
advanced=True,
|
||||
info="Typical Decoding mass.",
|
||||
),
|
||||
FloatInput(
|
||||
name="temperature",
|
||||
display_name="Temperature",
|
||||
value=0.8,
|
||||
advanced=True,
|
||||
info="The value used to module the logits distribution",
|
||||
),
|
||||
FloatInput(
|
||||
name="repetition_penalty",
|
||||
display_name="Repetition Penalty",
|
||||
advanced=True,
|
||||
info="The parameter for repetition penalty. 1.0 means no penalty.",
|
||||
),
|
||||
]
|
||||
|
||||
def get_api_url(self) -> str:
|
||||
if "huggingface" in self.inference_endpoint.lower():
|
||||
return f"{self.inference_endpoint}{self.model_id}"
|
||||
else:
|
||||
return self.inference_endpoint
|
||||
|
||||
def create_huggingface_endpoint(
|
||||
self, model_id: str, task: str, huggingfacehub_api_token: str, model_kwargs: dict
|
||||
self,
|
||||
model_id: str,
|
||||
task: Optional[str],
|
||||
huggingfacehub_api_token: Optional[str],
|
||||
model_kwargs: Dict[str, Any],
|
||||
max_new_tokens: int,
|
||||
top_k: Optional[int],
|
||||
top_p: float,
|
||||
typical_p: Optional[float],
|
||||
temperature: Optional[float],
|
||||
repetition_penalty: Optional[float],
|
||||
) -> HuggingFaceEndpoint:
|
||||
retry_attempts = self.retry_attempts # Access the retry attempts input
|
||||
endpoint_url = f"https://api-inference.huggingface.co/models/{model_id}"
|
||||
retry_attempts = self.retry_attempts
|
||||
endpoint_url = self.get_api_url()
|
||||
|
||||
@retry(stop=stop_after_attempt(retry_attempts), wait=wait_fixed(2))
|
||||
def _attempt_create():
|
||||
|
|
@ -42,18 +103,41 @@ class HuggingFaceEndpointsComponent(LCModelComponent):
|
|||
task=task,
|
||||
huggingfacehub_api_token=huggingfacehub_api_token,
|
||||
model_kwargs=model_kwargs,
|
||||
max_new_tokens=max_new_tokens,
|
||||
top_k=top_k,
|
||||
top_p=top_p,
|
||||
typical_p=self.typical_p,
|
||||
temperature=temperature,
|
||||
repetition_penalty=repetition_penalty,
|
||||
)
|
||||
|
||||
return _attempt_create()
|
||||
|
||||
def build_model(self) -> LanguageModel: # type: ignore[type-var]
|
||||
def build_model(self) -> LanguageModel:
|
||||
model_id = self.model_id
|
||||
task = self.task
|
||||
task = self.task or None
|
||||
huggingfacehub_api_token = self.huggingfacehub_api_token
|
||||
model_kwargs = self.model_kwargs or {}
|
||||
max_new_tokens = self.max_new_tokens
|
||||
top_k = self.top_k or None
|
||||
top_p = self.top_p
|
||||
typical_p = self.typical_p or None
|
||||
temperature = self.temperature or 0.8
|
||||
repetition_penalty = self.repetition_penalty or None
|
||||
|
||||
try:
|
||||
llm = self.create_huggingface_endpoint(model_id, task, huggingfacehub_api_token, model_kwargs)
|
||||
llm = self.create_huggingface_endpoint(
|
||||
model_id=model_id,
|
||||
task=task,
|
||||
huggingfacehub_api_token=huggingfacehub_api_token,
|
||||
model_kwargs=model_kwargs,
|
||||
max_new_tokens=max_new_tokens,
|
||||
top_k=top_k,
|
||||
top_p=top_p,
|
||||
typical_p=typical_p,
|
||||
temperature=temperature,
|
||||
repetition_penalty=repetition_penalty,
|
||||
)
|
||||
except Exception as e:
|
||||
raise ValueError("Could not connect to HuggingFace Endpoints API.") from e
|
||||
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue