fix: add config options and error handling to MistralAI component (#6131)
* ✨ (mistral.py): Add new input parameters to MistralAIModelComponent for better customization and control over the Mistral model configuration ♻️ (mistral.py): Refactor build_model method to improve readability and maintainability by using try-except block for error handling and updating parameter names for better clarity * [autofix.ci] apply automated fixes * ♻️ (mistral.py): refactor MistralAIModelComponent class to improve code readability by formatting the IntInput and BoolInput sections for better organization and clarity. --------- Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com>
This commit is contained in:
parent
f7db8eecf2
commit
f4715407b8
1 changed files with 62 additions and 34 deletions
|
|
@ -46,41 +46,69 @@ class MistralAIModelComponent(LCModelComponent):
|
|||
display_name="Mistral API Key",
|
||||
info="The Mistral API Key to use for the Mistral model.",
|
||||
advanced=False,
|
||||
required=True,
|
||||
value="MISTRAL_API_KEY",
|
||||
),
|
||||
FloatInput(
|
||||
name="temperature",
|
||||
display_name="Temperature",
|
||||
advanced=False,
|
||||
value=0.5,
|
||||
),
|
||||
IntInput(
|
||||
name="max_retries",
|
||||
display_name="Max Retries",
|
||||
advanced=True,
|
||||
value=5,
|
||||
),
|
||||
IntInput(
|
||||
name="timeout",
|
||||
display_name="Timeout",
|
||||
advanced=True,
|
||||
value=60,
|
||||
),
|
||||
IntInput(
|
||||
name="max_concurrent_requests",
|
||||
display_name="Max Concurrent Requests",
|
||||
advanced=True,
|
||||
value=3,
|
||||
),
|
||||
FloatInput(
|
||||
name="top_p",
|
||||
display_name="Top P",
|
||||
advanced=True,
|
||||
value=1,
|
||||
),
|
||||
IntInput(
|
||||
name="random_seed",
|
||||
display_name="Random Seed",
|
||||
value=1,
|
||||
advanced=True,
|
||||
),
|
||||
BoolInput(
|
||||
name="safe_mode",
|
||||
display_name="Safe Mode",
|
||||
advanced=True,
|
||||
value=False,
|
||||
),
|
||||
FloatInput(name="temperature", display_name="Temperature", advanced=False, value=0.5),
|
||||
IntInput(name="max_retries", display_name="Max Retries", advanced=True, value=5),
|
||||
IntInput(name="timeout", display_name="Timeout", advanced=True, value=60),
|
||||
IntInput(name="max_concurrent_requests", display_name="Max Concurrent Requests", advanced=True, value=3),
|
||||
FloatInput(name="top_p", display_name="Top P", advanced=True, value=1),
|
||||
IntInput(name="random_seed", display_name="Random Seed", value=1, advanced=True),
|
||||
BoolInput(name="safe_mode", display_name="Safe Mode", advanced=True),
|
||||
]
|
||||
|
||||
def build_model(self) -> LanguageModel: # type: ignore[type-var]
|
||||
mistral_api_key = self.api_key
|
||||
temperature = self.temperature
|
||||
model_name = self.model_name
|
||||
max_tokens = self.max_tokens
|
||||
mistral_api_base = self.mistral_api_base or "https://api.mistral.ai/v1"
|
||||
max_retries = self.max_retries
|
||||
timeout = self.timeout
|
||||
max_concurrent_requests = self.max_concurrent_requests
|
||||
top_p = self.top_p
|
||||
random_seed = self.random_seed
|
||||
safe_mode = self.safe_mode
|
||||
|
||||
api_key = SecretStr(mistral_api_key).get_secret_value() if mistral_api_key else None
|
||||
|
||||
return ChatMistralAI(
|
||||
max_tokens=max_tokens or None,
|
||||
model_name=model_name,
|
||||
endpoint=mistral_api_base,
|
||||
api_key=api_key,
|
||||
temperature=temperature,
|
||||
max_retries=max_retries,
|
||||
timeout=timeout,
|
||||
max_concurrent_requests=max_concurrent_requests,
|
||||
top_p=top_p,
|
||||
random_seed=random_seed,
|
||||
safe_mode=safe_mode,
|
||||
)
|
||||
try:
|
||||
return ChatMistralAI(
|
||||
model_name=self.model_name,
|
||||
mistral_api_key=SecretStr(self.api_key).get_secret_value() if self.api_key else None,
|
||||
endpoint=self.mistral_api_base or "https://api.mistral.ai/v1",
|
||||
max_tokens=self.max_tokens or None,
|
||||
temperature=self.temperature,
|
||||
max_retries=self.max_retries,
|
||||
timeout=self.timeout,
|
||||
max_concurrent_requests=self.max_concurrent_requests,
|
||||
top_p=self.top_p,
|
||||
random_seed=self.random_seed,
|
||||
safe_mode=self.safe_mode,
|
||||
streaming=self.stream,
|
||||
)
|
||||
except Exception as e:
|
||||
msg = "Could not connect to MistralAI API."
|
||||
raise ValueError(msg) from e
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue