feat(model_providers): Support deepseek for Azure AI Foundry (#13267)

Signed-off-by: -LAN- <laipz8200@outlook.com>
This commit is contained in:
-LAN- 2025-02-06 12:45:48 +08:00 committed by GitHub
commit 87763fc234
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
3 changed files with 24 additions and 58 deletions

View file

@ -1,9 +1,9 @@
import logging
from collections.abc import Generator
from collections.abc import Generator, Sequence
from typing import Any, Optional, Union
from azure.ai.inference import ChatCompletionsClient
from azure.ai.inference.models import StreamingChatCompletionsUpdate
from azure.ai.inference.models import StreamingChatCompletionsUpdate, SystemMessage, UserMessage
from azure.core.credentials import AzureKeyCredential
from azure.core.exceptions import (
ClientAuthenticationError,
@ -60,10 +60,10 @@ class AzureAIStudioLargeLanguageModel(LargeLanguageModel):
self,
model: str,
credentials: dict,
prompt_messages: list[PromptMessage],
prompt_messages: Sequence[PromptMessage],
model_parameters: dict,
tools: Optional[list[PromptMessageTool]] = None,
stop: Optional[list[str]] = None,
tools: Optional[Sequence[PromptMessageTool]] = None,
stop: Optional[Sequence[str]] = None,
stream: bool = True,
user: Optional[str] = None,
) -> Union[LLMResult, Generator]:
@ -82,8 +82,8 @@ class AzureAIStudioLargeLanguageModel(LargeLanguageModel):
"""
if not self.client:
endpoint = credentials.get("endpoint")
api_key = credentials.get("api_key")
endpoint = str(credentials.get("endpoint"))
api_key = str(credentials.get("api_key"))
self.client = ChatCompletionsClient(endpoint=endpoint, credential=AzureKeyCredential(api_key))
messages = [{"role": msg.role.value, "content": msg.content} for msg in prompt_messages]
@ -94,6 +94,7 @@ class AzureAIStudioLargeLanguageModel(LargeLanguageModel):
"temperature": model_parameters.get("temperature", 0),
"top_p": model_parameters.get("top_p", 1),
"stream": stream,
"model": model,
}
if stop:
@ -255,10 +256,16 @@ class AzureAIStudioLargeLanguageModel(LargeLanguageModel):
:return:
"""
try:
endpoint = credentials.get("endpoint")
api_key = credentials.get("api_key")
endpoint = str(credentials.get("endpoint"))
api_key = str(credentials.get("api_key"))
client = ChatCompletionsClient(endpoint=endpoint, credential=AzureKeyCredential(api_key))
client.get_model_info()
client.complete(
messages=[
SystemMessage(content="I say 'ping', you say 'pong'"),
UserMessage(content="ping"),
],
model=model,
)
except Exception as ex:
raise CredentialsValidateFailedError(str(ex))