Model Runtime (#1858)

Co-authored-by: StyleZhang <jasonapring2015@outlook.com>
Co-authored-by: Garfield Dai <dai.hai@foxmail.com>
Co-authored-by: chenhe <guchenhe@gmail.com>
Co-authored-by: jyong <jyong@dify.ai>
Co-authored-by: Joel <iamjoel007@gmail.com>
Co-authored-by: Yeuoly <admin@srmxy.cn>
This commit is contained in:
takatost 2024-01-02 23:42:00 +08:00 committed by GitHub
commit d069c668f8
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
807 changed files with 171310 additions and 23806 deletions

View file

@ -9,11 +9,12 @@ AZURE_OPENAI_API_KEY=
ANTHROPIC_API_KEY=
# Replicate API Key
REPLICATE_API_TOKEN=
REPLICATE_API_KEY=
# Hugging Face API Key
HUGGINGFACE_API_KEY=
HUGGINGFACE_ENDPOINT_URL=
HUGGINGFACE_TEXT_GEN_ENDPOINT_URL=
HUGGINGFACE_TEXT2TEXT_GEN_ENDPOINT_URL=
HUGGINGFACE_EMBEDDINGS_ENDPOINT_URL=
# Minimax Credentials
@ -44,7 +45,10 @@ CHATGLM_API_BASE=
# Xinference Credentials
XINFERENCE_SERVER_URL=
XINFERENCE_MODEL_UID=
XINFERENCE_GENERATION_MODEL_UID=
XINFERENCE_CHAT_MODEL_UID=
XINFERENCE_EMBEDDINGS_MODEL_UID=
XINFERENCE_RERANK_MODEL_UID=
# OpenLLM Credentials
OPENLLM_SERVER_URL=
@ -56,4 +60,7 @@ LOCALAI_SERVER_URL=
COHERE_API_KEY=
# Jina Credentials
JINA_API_KEY=
JINA_API_KEY=
# Mock Switch
MOCK_SWITCH=false

View file

@ -0,0 +1,68 @@
import anthropic
from anthropic import Anthropic
from anthropic.resources.completions import Completions
from anthropic.types import completion_create_params, Completion
from anthropic._types import NOT_GIVEN, NotGiven, Headers, Query, Body
from _pytest.monkeypatch import MonkeyPatch
from typing import List, Union, Literal, Any, Generator
from time import sleep
import pytest
import os
# Mocks are enabled only when MOCK_SWITCH is "true". The comparison is
# case-insensitive so that this module parses the switch the same way the
# huggingface, openai and xinference mock modules do.
MOCK = os.getenv('MOCK_SWITCH', 'false').lower() == 'true'
class MockAnthropicClass(object):
    """Canned stand-in for the anthropic completions endpoint.

    ``mocked_anthropic`` is installed in place of ``Completions.create`` by the
    ``setup_anthropic_mock`` fixture; the two static helpers build the fixed
    responses it returns.
    """

    @staticmethod
    def mocked_anthropic_chat_create_sync(model: str) -> Completion:
        """Return a single finished completion carrying a fixed text.

        :param model: model name echoed back in the response
        :return: a ``Completion`` with ``stop_reason='stop_sequence'``
        """
        return Completion(
            completion='hello, I\'m a chatbot from anthropic',
            model=model,
            stop_reason='stop_sequence'
        )

    @staticmethod
    def mocked_anthropic_chat_create_stream(model: str) -> Generator[Completion, None, None]:
        """Yield the fixed text one character at a time, then a terminal chunk.

        :param model: model name echoed back in every chunk
        :return: generator of ``Completion`` chunks; the last one has an empty
            text and ``stop_reason='stop_sequence'``
        """
        full_response_text = "hello, I'm a chatbot from anthropic"

        for character in full_response_text:
            sleep(0.1)
            yield Completion(
                completion=character,
                model=model,
                stop_reason=''
            )

        # Terminal chunk: no text, only the stop reason.
        sleep(0.1)
        yield Completion(
            completion='',
            model=model,
            stop_reason='stop_sequence'
        )

    def mocked_anthropic(self: Completions, *,
                         max_tokens_to_sample: int,
                         model: Union[str, Literal["claude-2.1", "claude-instant-1"]],
                         prompt: str,
                         stream: Union[bool, NotGiven] = False,
                         **kwargs: Any
                         ) -> Union[Completion, Generator[Completion, None, None]]:
        """Replacement for ``Completions.create``.

        ``stream`` is optional (defaulting to a non-streaming call) so that
        callers which omit it do not fail with ``TypeError`` under the mock.

        :param max_tokens_to_sample: accepted for signature compatibility, unused
        :param model: model name echoed back in the response
        :param prompt: accepted for signature compatibility, unused
        :param stream: truthy to get a chunk generator, falsy for one response
        :return: a ``Completion`` or a generator of ``Completion`` chunks
        :raises anthropic.AuthenticationError: when the client's API key is
            shorter than 18 characters
        """
        if len(self._client.api_key) < 18:
            # NOTE(review): depending on the installed anthropic SDK version,
            # AuthenticationError may require `response=` and `body=` keyword
            # arguments; if so this line raises TypeError instead - confirm.
            raise anthropic.AuthenticationError('Invalid API key')

        if stream:
            return MockAnthropicClass.mocked_anthropic_chat_create_stream(model=model)
        return MockAnthropicClass.mocked_anthropic_chat_create_sync(model=model)
@pytest.fixture
def setup_anthropic_mock(request, monkeypatch: MonkeyPatch):
    """Patch ``Completions.create`` with the anthropic mock while MOCK is on.

    When mocking is disabled the fixture is a no-op and the test talks to the
    real client.
    """
    if not MOCK:
        yield
        return

    monkeypatch.setattr(Completions, 'create', MockAnthropicClass.mocked_anthropic)
    yield
    # Restore the real method once the test has finished.
    monkeypatch.undo()

View file

@ -0,0 +1,127 @@
from google.generativeai import GenerativeModel
from google.generativeai.types import GenerateContentResponse
from google.generativeai.types.generation_types import BaseGenerateContentResponse
import google.generativeai.types.generation_types as generation_config_types
import google.generativeai.types.content_types as content_types
import google.generativeai.types.safety_types as safety_types
from google.generativeai.client import _ClientManager, configure
from google.ai import generativelanguage as glm
from typing import Generator, List
from _pytest.monkeypatch import MonkeyPatch
import pytest
# Module-level record of the API key most recently seen by the mocked
# `make_client`; the mocked `generate_content` reads it to reject short keys.
current_api_key = ''
class MockGoogleResponseClass(object):
    """Iterable that plays the role of a streamed Google generation response."""

    # Flipped to True on the instance right before the terminal chunk is yielded.
    _done = False

    def __iter__(self):
        """Yield one unfinished chunk per character of the text, then a finished one."""
        full_response_text = 'it\'s google!'

        for _ in full_response_text:
            yield GenerateContentResponse(
                done=False,
                iterator=None,
                result=glm.GenerateContentResponse({
                }),
                chunks=[]
            )

        self._done = True
        yield GenerateContentResponse(
            done=True,
            iterator=None,
            result=glm.GenerateContentResponse({
            }),
            chunks=[]
        )
class MockGoogleResponseCandidateClass:
    """Minimal candidate object exposing only a fixed ``finish_reason``."""

    finish_reason = 'stop'
class MockGoogleClass(object):
    """Replacements for google.generativeai entry points used by the tests.

    The members are installed onto ``GenerativeModel``,
    ``BaseGenerateContentResponse`` and ``_ClientManager`` by the
    ``setup_google_mock`` fixture.
    """

    @staticmethod
    def generate_content_sync() -> GenerateContentResponse:
        """Return a finished, empty ``GenerateContentResponse``."""
        return GenerateContentResponse(
            done=True,
            iterator=None,
            result=glm.GenerateContentResponse({
            }),
            chunks=[]
        )

    @staticmethod
    def generate_content_stream() -> Generator[GenerateContentResponse, None, None]:
        """Return an iterable of mocked streaming chunks."""
        return MockGoogleResponseClass()

    def generate_content(self: GenerativeModel,
                         contents: content_types.ContentsType,
                         *,
                         generation_config: generation_config_types.GenerationConfigType | None = None,
                         safety_settings: safety_types.SafetySettingOptions | None = None,
                         stream: bool = False,
                         **kwargs,
                         ) -> GenerateContentResponse:
        """Replacement for ``GenerativeModel.generate_content``.

        :param contents: accepted for signature compatibility, unused
        :param generation_config: accepted for signature compatibility, unused
        :param safety_settings: accepted for signature compatibility, unused
        :param stream: when true, return the streaming mock instead
        :return: a mocked response (or an iterable of them when streaming)
        :raises Exception: when the recorded API key is shorter than 16 characters
        """
        global current_api_key

        if len(current_api_key) < 16:
            raise Exception('Invalid API key')

        if stream:
            return MockGoogleClass.generate_content_stream()

        return MockGoogleClass.generate_content_sync()

    @property
    def generative_response_text(self) -> str:
        """Fixed text patched in as ``BaseGenerateContentResponse.text``."""
        return 'it\'s google!'

    @property
    def generative_response_candidates(self) -> List[MockGoogleResponseCandidateClass]:
        """Single mocked candidate patched in as ``BaseGenerateContentResponse.candidates``."""
        return [MockGoogleResponseCandidateClass()]

    def make_client(self: _ClientManager, name: str):
        """Replacement for ``_ClientManager.make_client``.

        Builds the service client without running its real ``__init__`` and
        records the configured API key in ``current_api_key``.

        :param name: client name; a ``_async`` suffix selects the async client class
        :return: the constructed (uninitialised) client instance
        """
        global current_api_key

        if name.endswith("_async"):
            name = name.split("_")[0]
            cls = getattr(glm, name.title() + "ServiceAsyncClient")
        else:
            cls = getattr(glm, name.title() + "ServiceClient")

        # Attempt to configure using defaults.
        if not self.client_config:
            configure()

        client_options = self.client_config.get("client_options", None)
        if client_options:
            current_api_key = client_options.api_key

        def nop(self, *args, **kwargs):
            pass

        # Temporarily neutralise the constructor so instantiating the client
        # does no real setup; always restore it, even if construction fails.
        original_init = cls.__init__
        cls.__init__ = nop
        try:
            client: glm.GenerativeServiceClient = cls(**self.client_config)
        finally:
            cls.__init__ = original_init

        # Return the client whether or not default metadata is configured;
        # previously the method fell through and returned None when it was.
        return client
@pytest.fixture
def setup_google_mock(request, monkeypatch: MonkeyPatch):
    """Install every Google mock for the duration of one test.

    Unlike the other provider fixtures in this package, the patches are applied
    unconditionally.
    """
    replacements = (
        (BaseGenerateContentResponse, "text", MockGoogleClass.generative_response_text),
        (BaseGenerateContentResponse, "candidates", MockGoogleClass.generative_response_candidates),
        (GenerativeModel, "generate_content", MockGoogleClass.generate_content),
        (_ClientManager, "make_client", MockGoogleClass.make_client),
    )
    for target, attribute, replacement in replacements:
        monkeypatch.setattr(target, attribute, replacement)

    yield

    monkeypatch.undo()

View file

@ -0,0 +1,21 @@
from tests.integration_tests.model_runtime.__mock.huggingface_chat import MockHuggingfaceChatClass
from huggingface_hub import InferenceClient
from _pytest.monkeypatch import MonkeyPatch
from typing import List, Dict, Any
import pytest
import os
# Mocking is switched on by setting MOCK_SWITCH to "true" (any letter case).
MOCK = os.environ.get('MOCK_SWITCH', 'false').lower() == 'true'


@pytest.fixture
def setup_huggingface_mock(request, monkeypatch: MonkeyPatch):
    """Patch ``InferenceClient.text_generation`` with the mock while MOCK is on."""
    if not MOCK:
        yield
        return

    monkeypatch.setattr(InferenceClient, "text_generation", MockHuggingfaceChatClass.text_generation)
    yield
    # Put the real method back after the test.
    monkeypatch.undo()

View file

@ -0,0 +1,54 @@
from huggingface_hub import InferenceClient
from huggingface_hub.inference._text_generation import TextGenerationResponse, TextGenerationStreamResponse, Details, StreamDetails, Token
from huggingface_hub.utils import BadRequestError
from typing import Literal, Optional, List, Generator, Union, Any
from _pytest.monkeypatch import MonkeyPatch
import re
class MockHuggingfaceChatClass(object):
    """Canned replacement for ``InferenceClient.text_generation``."""

    @staticmethod
    def generate_create_sync(model: str) -> TextGenerationResponse:
        """Return one complete generation with a fixed text and six dummy tokens.

        :param model: accepted for symmetry with the stream helper, unused
        """
        response = TextGenerationResponse(
            generated_text="You can call me Miku Miku o~e~o~",
            details=Details(
                finish_reason="length",
                generated_tokens=6,
                tokens=[
                    Token(id=0, text="You", logprob=0.0, special=False) for _ in range(0, 6)
                ]
            )
        )

        return response

    @staticmethod
    def generate_create_stream(model: str) -> Generator[TextGenerationStreamResponse, None, None]:
        """Yield the fixed text one character per stream response.

        :param model: accepted for symmetry with the sync helper, unused
        """
        full_text = "You can call me Miku Miku o~e~o~"

        for index, character in enumerate(full_text):
            response = TextGenerationStreamResponse(
                token=Token(id=index, text=character, logprob=0.0, special=False),
            )
            response.generated_text = character
            response.details = StreamDetails(finish_reason='stop_sequence', generated_tokens=1)

            yield response

    def text_generation(self: InferenceClient, prompt: str, *,
                        stream: bool = False,
                        model: Optional[str] = None,
                        **kwargs: Any
                        ) -> Union[TextGenerationResponse, Generator[TextGenerationStreamResponse, None, None]]:
        """Replacement for ``InferenceClient.text_generation``.

        ``stream`` defaults to ``False``. The earlier placeholder default
        (``...``) was truthy, so omitting ``stream`` wrongly selected the
        streaming path.

        :param prompt: accepted for signature compatibility, unused
        :param stream: when true, return a generator of stream responses
        :param model: model identifier; must not be ``None``
        :return: a full response, or a generator of stream responses
        :raises BadRequestError: when the authorization header does not look
            like a bearer token, or when ``model`` is ``None``
        """
        # check if key is valid
        if not re.match(r'Bearer\shf\-[a-zA-Z0-9]{16,}', self.headers['authorization']):
            raise BadRequestError('Invalid API key')

        if model is None:
            raise BadRequestError('Invalid model')

        if stream:
            return MockHuggingfaceChatClass.generate_create_stream(model)
        return MockHuggingfaceChatClass.generate_create_sync(model)

View file

@ -0,0 +1,63 @@
from tests.integration_tests.model_runtime.__mock.openai_completion import MockCompletionsClass
from tests.integration_tests.model_runtime.__mock.openai_chat import MockChatClass
from tests.integration_tests.model_runtime.__mock.openai_remote import MockModelClass
from tests.integration_tests.model_runtime.__mock.openai_moderation import MockModerationClass
from tests.integration_tests.model_runtime.__mock.openai_speech2text import MockSpeech2TextClass
from tests.integration_tests.model_runtime.__mock.openai_embeddings import MockEmbeddingsClass
from openai.resources.completions import Completions
from openai.resources.chat import Completions as ChatCompletions
from openai.resources.models import Models
from openai.resources.moderations import Moderations
from openai.resources.audio.transcriptions import Transcriptions
from openai.resources.embeddings import Embeddings
# import monkeypatch
from _pytest.monkeypatch import MonkeyPatch
from typing import Literal, Callable, List
import os
import pytest
def mock_openai(monkeypatch: MonkeyPatch, methods: List[Literal["completion", "chat", "remote", "moderation", "speech2text", "text_embedding"]]) -> Callable[[], None]:
    """
        Patch the requested OpenAI resource methods with their mocks.

        :param monkeypatch: pytest monkeypatch fixture
        :param methods: names of the endpoints to mock; unknown names are ignored
        :return: function that reverts every patch made through ``monkeypatch``
    """
    # method name -> (resource class, attribute to replace, mock implementation)
    patch_table = {
        "completion": (Completions, "create", MockCompletionsClass.completion_create),
        "chat": (ChatCompletions, "create", MockChatClass.chat_create),
        "remote": (Models, "list", MockModelClass.list),
        "moderation": (Moderations, "create", MockModerationClass.moderation_create),
        "speech2text": (Transcriptions, "create", MockSpeech2TextClass.speech2text_create),
        "text_embedding": (Embeddings, "create", MockEmbeddingsClass.create_embeddings),
    }

    for method_name, (resource, attribute, replacement) in patch_table.items():
        if method_name in methods:
            monkeypatch.setattr(resource, attribute, replacement)

    def unpatch() -> None:
        monkeypatch.undo()

    return unpatch
# Mocking is switched on by setting MOCK_SWITCH to "true" (any letter case).
MOCK = os.environ.get('MOCK_SWITCH', 'false').lower() == 'true'


@pytest.fixture
def setup_openai_mock(request, monkeypatch):
    """Mock the OpenAI endpoints named in the fixture's indirect parameter.

    Without a parameter no endpoint is patched. The fixture does nothing when
    MOCK is off.
    """
    methods = getattr(request, 'param', [])

    if not MOCK:
        yield
        return

    unpatch = mock_openai(monkeypatch, methods=methods)
    yield
    unpatch()

View file

@ -0,0 +1,235 @@
from openai import OpenAI
from openai.types import Completion as CompletionMessage
from openai._types import NotGiven, NOT_GIVEN
from openai.types.chat import ChatCompletion, ChatCompletionChunk, ChatCompletionMessageParam, \
ChatCompletionToolChoiceOptionParam, ChatCompletionToolParam, ChatCompletionMessageToolCall
from openai.types.chat.chat_completion_chunk import ChoiceDeltaToolCall, ChoiceDeltaFunctionCall,\
Choice, ChoiceDelta, ChoiceDeltaToolCallFunction
from openai.types.chat.chat_completion import Choice as _ChatCompletionChoice, ChatCompletion as _ChatCompletion
from openai.types.chat.chat_completion_message import FunctionCall, ChatCompletionMessage
from openai.types.chat.chat_completion_message_tool_call import Function
from openai.types.completion_usage import CompletionUsage
from openai.resources.chat.completions import Completions
from openai import AzureOpenAI
import openai.types.chat.completion_create_params as completion_create_params
# import monkeypatch
from typing import List, Any, Generator, Union, Optional, Literal
from time import time, sleep
from json import dumps, loads
from core.model_runtime.errors.invoke import InvokeAuthorizationError
import re
class MockChatClass(object):
    """Canned replacement for the OpenAI chat completions endpoint.

    ``chat_create`` is installed in place of ``Completions.create``; the static
    helpers build the fixed sync/stream responses and fabricate function/tool
    calls from the schemas supplied by the caller.
    """

    @staticmethod
    def generate_function_call(
        functions: List[completion_create_params.Function] | NotGiven = NOT_GIVEN,
    ) -> Optional[FunctionCall]:
        """Fabricate a call to the first function, filling its required parameters.

        Only ``string``, ``integer``, ``number`` and ``boolean`` parameters are
        filled; a string parameter with an ``enum`` takes the first enum value.

        :param functions: function definitions as passed to the API
        :return: the fabricated call, or ``None`` when no function is given or
            its parameter schema is not of type ``object``
        """
        if not functions:
            return None

        function: completion_create_params.Function = functions[0]
        function_name = function['name']

        # `parameters` and `required` are optional in a function definition;
        # treat them as empty rather than failing with KeyError.
        function_parameters = function.get('parameters') or {}
        if function_parameters.get('type') != 'object':
            return None

        function_parameters_properties = function_parameters.get('properties', {})
        function_parameters_required = function_parameters.get('required', [])

        parameters = {}
        for parameter_name, parameter in function_parameters_properties.items():
            if parameter_name not in function_parameters_required:
                continue

            parameter_type = parameter.get('type')
            if parameter_type == 'string':
                if 'enum' in parameter:
                    if len(parameter['enum']) == 0:
                        continue
                    parameters[parameter_name] = parameter['enum'][0]
                else:
                    parameters[parameter_name] = 'kawaii'
            elif parameter_type == 'integer':
                parameters[parameter_name] = 114514
            elif parameter_type == 'number':
                parameters[parameter_name] = 1919810.0
            elif parameter_type == 'boolean':
                parameters[parameter_name] = True

        return FunctionCall(name=function_name, arguments=dumps(parameters))

    @staticmethod
    def generate_tool_calls(
        tools: List[ChatCompletionToolParam] | NotGiven = NOT_GIVEN,
    ) -> Optional[List[ChatCompletionMessageToolCall]]:
        """Fabricate a single tool call for the first tool.

        :param tools: tool definitions as passed to the API
        :return: a one-element list, or ``None`` when there is no tool, the
            first tool is not a function, or no call could be fabricated
        """
        if not tools:
            return None

        tool: ChatCompletionToolParam = tools[0]
        # Inspect the selected tool, not the list (indexing the list by a
        # string key raised TypeError).
        if tool['type'] != 'function':
            return None

        function_call = MockChatClass.generate_function_call(functions=[tool['function']])
        if function_call is None:
            return None

        return [
            ChatCompletionMessageToolCall(
                id='sakurajima-mai',
                function=Function(
                    name=function_call.name,
                    arguments=function_call.arguments,
                ),
                type='function'
            )
        ]

    @staticmethod
    def mocked_openai_chat_create_sync(
        model: str,
        functions: List[completion_create_params.Function] | NotGiven = NOT_GIVEN,
        tools: List[ChatCompletionToolParam] | NotGiven = NOT_GIVEN,
    ) -> _ChatCompletion:
        """Return one complete chat completion with a fixed assistant message.

        A function call is fabricated from ``functions`` when possible;
        otherwise tool calls are fabricated from ``tools``.
        """
        tool_calls = []
        function_call = MockChatClass.generate_function_call(functions=functions)
        if not function_call:
            tool_calls = MockChatClass.generate_tool_calls(tools=tools)

        sleep(1)
        return _ChatCompletion(
            id='cmpl-3QJQa5jXJ5Z5X',
            choices=[
                _ChatCompletionChoice(
                    finish_reason='content_filter',
                    index=0,
                    message=ChatCompletionMessage(
                        content='elaina',
                        role='assistant',
                        function_call=function_call,
                        tool_calls=tool_calls
                    )
                )
            ],
            created=int(time()),
            model=model,
            object='chat.completion',
            system_fingerprint='',
            usage=CompletionUsage(
                prompt_tokens=2,
                completion_tokens=1,
                total_tokens=3,
            )
        )

    @staticmethod
    def mocked_openai_chat_create_stream(
        model: str,
        functions: List[completion_create_params.Function] | NotGiven = NOT_GIVEN,
        tools: List[ChatCompletionToolParam] | NotGiven = NOT_GIVEN,
    ) -> Generator[ChatCompletionChunk, None, None]:
        """Yield a fixed text one character per chunk, then a terminal chunk.

        The terminal chunk has empty content and carries the fabricated
        function call or tool call (if any) together with usage figures.
        """
        tool_calls = []
        function_call = MockChatClass.generate_function_call(functions=functions)
        if not function_call:
            tool_calls = MockChatClass.generate_tool_calls(tools=tools)

        full_text = "Hello, world!\n\n```python\nprint('Hello, world!')\n```"

        for character in full_text:
            sleep(0.1)
            yield ChatCompletionChunk(
                id='cmpl-3QJQa5jXJ5Z5X',
                choices=[
                    Choice(
                        delta=ChoiceDelta(
                            content=character,
                            role='assistant',
                        ),
                        finish_reason='content_filter',
                        index=0,
                    )
                ],
                created=int(time()),
                model=model,
                object='chat.completion.chunk',
                system_fingerprint='',
            )

        # Terminal chunk.
        sleep(0.1)
        yield ChatCompletionChunk(
            id='cmpl-3QJQa5jXJ5Z5X',
            choices=[
                Choice(
                    delta=ChoiceDelta(
                        content='',
                        function_call=ChoiceDeltaFunctionCall(
                            name=function_call.name,
                            arguments=function_call.arguments,
                        ) if function_call else None,
                        role='assistant',
                        tool_calls=[
                            ChoiceDeltaToolCall(
                                index=0,
                                id='misaka-mikoto',
                                function=ChoiceDeltaToolCallFunction(
                                    name=tool_calls[0].function.name,
                                    arguments=tool_calls[0].function.arguments,
                                ),
                                type='function'
                            )
                        ] if tool_calls else None
                    ),
                    finish_reason='function_call',
                    index=0,
                )
            ],
            created=int(time()),
            model=model,
            object='chat.completion.chunk',
            system_fingerprint='',
            usage=CompletionUsage(
                prompt_tokens=2,
                completion_tokens=17,
                total_tokens=19,
            ),
        )

    def chat_create(self: Completions, *,
                    messages: List[ChatCompletionMessageParam],
                    model: Union[str, Literal[
                        "gpt-4-1106-preview", "gpt-4-vision-preview", "gpt-4", "gpt-4-0314", "gpt-4-0613",
                        "gpt-4-32k", "gpt-4-32k-0314", "gpt-4-32k-0613",
                        "gpt-3.5-turbo-1106", "gpt-3.5-turbo", "gpt-3.5-turbo-16k", "gpt-3.5-turbo-0301",
                        "gpt-3.5-turbo-0613", "gpt-3.5-turbo-16k-0613"],
                    ],
                    functions: List[completion_create_params.Function] | NotGiven = NOT_GIVEN,
                    response_format: completion_create_params.ResponseFormat | NotGiven = NOT_GIVEN,
                    stream: Optional[Literal[False]] | NotGiven = NOT_GIVEN,
                    tools: List[ChatCompletionToolParam] | NotGiven = NOT_GIVEN,
                    **kwargs: Any,
                    ):
        """Replacement for the chat ``Completions.create``.

        Validates the client's base URL and, for known OpenAI/Azure model
        names, its API key, then returns the sync or streaming mock.

        :raises InvokeAuthorizationError: on an invalid base URL or API key
        """
        openai_models = [
            "gpt-4-1106-preview", "gpt-4-vision-preview", "gpt-4", "gpt-4-0314", "gpt-4-0613",
            "gpt-4-32k", "gpt-4-32k-0314", "gpt-4-32k-0613",
            "gpt-3.5-turbo-1106", "gpt-3.5-turbo", "gpt-3.5-turbo-16k", "gpt-3.5-turbo-0301",
            "gpt-3.5-turbo-0613", "gpt-3.5-turbo-16k-0613",
        ]
        azure_openai_models = [
            "gpt35", "gpt-4v", "gpt-35-turbo"
        ]
        if not re.match(r'^(https?):\/\/[^\s\/$.?#].[^\s]*$', str(self._client.base_url)):
            raise InvokeAuthorizationError('Invalid base url')
        if model in openai_models + azure_openai_models:
            # Exact type comparison is kept on purpose rather than isinstance.
            # NOTE(review): presumably because AzureOpenAI derives from OpenAI
            # in the SDK, so isinstance would apply the "sk-" rule to Azure too.
            if not re.match(r'sk-[a-zA-Z0-9]{24,}$', self._client.api_key) and type(self._client) is OpenAI:
                # sometime, provider use OpenAI compatible API will not have api key or have different api key format
                # so we only check if model is in openai_models
                raise InvokeAuthorizationError('Invalid api key')
            if len(self._client.api_key) < 18 and type(self._client) is AzureOpenAI:
                raise InvokeAuthorizationError('Invalid api key')
        if stream:
            return MockChatClass.mocked_openai_chat_create_stream(model=model, functions=functions, tools=tools)

        return MockChatClass.mocked_openai_chat_create_sync(model=model, functions=functions, tools=tools)

View file

@ -0,0 +1,121 @@
from openai import BadRequestError, OpenAI, AzureOpenAI
from openai.types import Completion as CompletionMessage
from openai._types import NotGiven, NOT_GIVEN
from openai.types.completion import CompletionChoice
from openai.types.completion_usage import CompletionUsage
from openai.resources.completions import Completions
# import monkeypatch
from typing import List, Any, Generator, Union, Optional, Literal
from time import time, sleep
from core.model_runtime.errors.invoke import InvokeAuthorizationError
import re
class MockCompletionsClass(object):
    """Canned replacement for the OpenAI (legacy) text completions endpoint."""

    @staticmethod
    def mocked_openai_completion_create_sync(
        model: str
    ) -> CompletionMessage:
        """Return one finished completion whose text is ``"mock"``."""
        sleep(1)

        only_choice = CompletionChoice(
            text="mock",
            index=0,
            logprobs=None,
            finish_reason="stop",
        )
        token_usage = CompletionUsage(
            prompt_tokens=2,
            completion_tokens=1,
            total_tokens=3,
        )
        return CompletionMessage(
            id="cmpl-3QJQa5jXJ5Z5X",
            object="text_completion",
            created=int(time()),
            model=model,
            system_fingerprint="",
            choices=[only_choice],
            usage=token_usage,
        )

    @staticmethod
    def mocked_openai_completion_create_stream(
        model: str
    ) -> Generator[CompletionMessage, None, None]:
        """Yield a fixed text one character per chunk, then a closing chunk with usage."""
        full_text = "Hello, world!\n\n```python\nprint('Hello, world!')\n```"

        # One chunk per character of the canned text.
        for character in full_text:
            sleep(0.1)
            yield CompletionMessage(
                id="cmpl-3QJQa5jXJ5Z5X",
                object="text_completion",
                created=int(time()),
                model=model,
                system_fingerprint="",
                choices=[
                    CompletionChoice(
                        text=character,
                        index=0,
                        logprobs=None,
                        finish_reason="content_filter"
                    )
                ],
            )

        # Closing chunk: empty text, "stop" reason and the usage figures.
        sleep(0.1)
        yield CompletionMessage(
            id="cmpl-3QJQa5jXJ5Z5X",
            object="text_completion",
            created=int(time()),
            model=model,
            system_fingerprint="",
            choices=[
                CompletionChoice(
                    text="",
                    index=0,
                    logprobs=None,
                    finish_reason="stop",
                )
            ],
            usage=CompletionUsage(
                prompt_tokens=2,
                completion_tokens=17,
                total_tokens=19,
            ),
        )

    def completion_create(self: Completions, *, model: Union[
            str, Literal["babbage-002", "davinci-002", "gpt-3.5-turbo-instruct",
                "text-davinci-003", "text-davinci-002", "text-davinci-001",
                "code-davinci-002", "text-curie-001", "text-babbage-001",
                "text-ada-001"],
        ],
        prompt: Union[str, List[str], List[int], List[List[int]], None],
        stream: Optional[Literal[False]] | NotGiven = NOT_GIVEN,
        **kwargs: Any
    ):
        """Replacement for the text ``Completions.create``.

        Checks the base URL, then (for known model names) the API key, then
        the prompt, and finally returns the sync or streaming mock.
        """
        openai_models = [
            "babbage-002", "davinci-002", "gpt-3.5-turbo-instruct", "text-davinci-003", "text-davinci-002", "text-davinci-001",
            "code-davinci-002", "text-curie-001", "text-babbage-001", "text-ada-001",
        ]
        azure_openai_models = [
            "gpt-35-turbo-instruct"
        ]

        base_url = self._client.base_url.__str__()
        if not re.match(r'^(https?):\/\/[^\s\/$.?#].[^\s]*$', base_url):
            raise InvokeAuthorizationError('Invalid base url')

        if model in openai_models + azure_openai_models:
            client_type = type(self._client)
            if not re.match(r'sk-[a-zA-Z0-9]{24,}$', self._client.api_key) and client_type == OpenAI:
                # sometime, provider use OpenAI compatible API will not have api key or have different api key format
                # so we only check if model is in openai_models
                raise InvokeAuthorizationError('Invalid api key')
            if len(self._client.api_key) < 18 and client_type == AzureOpenAI:
                raise InvokeAuthorizationError('Invalid api key')

        if not prompt:
            # NOTE(review): depending on the installed openai SDK version,
            # BadRequestError may require `response=` and `body=` keyword
            # arguments; if so this raises TypeError instead - confirm.
            raise BadRequestError('Invalid prompt')

        if not stream:
            return MockCompletionsClass.mocked_openai_completion_create_sync(model=model)
        return MockCompletionsClass.mocked_openai_completion_create_stream(model=model)

File diff suppressed because one or more lines are too long

View file

@ -0,0 +1,67 @@
from openai.resources.moderations import Moderations
from openai.types import ModerationCreateResponse
from openai.types.moderation import Moderation, Categories, CategoryScores
from openai._types import NotGiven, NOT_GIVEN
from typing import Union, List, Literal, Any
from core.model_runtime.errors.invoke import InvokeAuthorizationError
import re
class MockModerationClass(object):
    """Canned replacement for the OpenAI moderation endpoint."""

    def moderation_create(self: Moderations,*,
        input: Union[str, List[str]],
        model: Union[str, Literal["text-moderation-latest", "text-moderation-stable"]] | NotGiven = NOT_GIVEN,
        **kwargs: Any
    ) -> ModerationCreateResponse:
        """Replacement for ``Moderations.create``.

        A text is flagged when it contains the substring ``'kill'``. The
        response holds one result per input text, in input order.

        :param input: one text or a list of texts to moderate
        :param model: model name echoed back in the response
        :return: a ``ModerationCreateResponse`` with one ``Moderation`` per text
        :raises InvokeAuthorizationError: on an invalid base URL, or an API key
            shorter than 18 characters
        """
        if isinstance(input, str):
            input = [input]

        if not re.match(r'^(https?):\/\/[^\s\/$.?#].[^\s]*$', str(self._client.base_url)):
            raise InvokeAuthorizationError('Invalid base url')

        if len(self._client.api_key) < 18:
            raise InvokeAuthorizationError('Invalid API key')

        category_names = [
            'harassment', 'harassment/threatening', 'hate', 'hate/threatening',
            'self-harm', 'self-harm/instructions', 'self-harm/intent', 'sexual',
            'sexual/minors', 'violence', 'violence/graphic'
        ]

        # Accumulate across all inputs. The list used to be re-created on every
        # iteration, which dropped all but the last verdict and left it
        # undefined for an empty input.
        result = []
        for text in input:
            flagged = 'kill' in text
            # Category flags stay False in both cases; only the scores differ.
            moderation_categories = dict.fromkeys(category_names, False)
            moderation_categories_scores = dict.fromkeys(category_names, 1.0 if flagged else 0.0)
            result.append(Moderation(
                flagged=flagged,
                categories=Categories(**moderation_categories),
                category_scores=CategoryScores(**moderation_categories_scores)
            ))

        return ModerationCreateResponse(
            id='shiroii kuloko',
            model=model,
            results=result
        )

View file

@ -0,0 +1,22 @@
from openai.resources.models import Models
from openai.types.model import Model
from typing import List
from time import time
class MockModelClass(object):
    """
        Stand-in for the OpenAI ``Models`` resource; only ``list`` is provided.
    """

    def list(
        self,
        **kwargs,
    ) -> List[Model]:
        """Return a one-element list holding a fixed fine-tuned model entry."""
        fine_tuned_model = Model(
            id='ft:gpt-3.5-turbo-0613:personal::8GYJLPDQ',
            created=int(time()),
            object='model',
            owned_by='organization:org-123',
        )
        return [fine_tuned_model]

View file

@ -0,0 +1,30 @@
from openai.resources.audio.transcriptions import Transcriptions
from openai._types import NotGiven, NOT_GIVEN, FileTypes
from openai.types.audio.transcription import Transcription
from typing import Union, List, Literal, Any
from core.model_runtime.errors.invoke import InvokeAuthorizationError
import re
class MockSpeech2TextClass(object):
    """Canned replacement for the OpenAI audio transcription endpoint."""

    def speech2text_create(self: Transcriptions,
        *,
        file: FileTypes,
        model: Union[str, Literal["whisper-1"]],
        language: str | NotGiven = NOT_GIVEN,
        prompt: str | NotGiven = NOT_GIVEN,
        response_format: Literal["json", "text", "srt", "verbose_json", "vtt"] | NotGiven = NOT_GIVEN,
        temperature: float | NotGiven = NOT_GIVEN,
        **kwargs: Any
    ) -> Transcription:
        """Validate the client's base URL and API key, then return a fixed transcript.

        None of the request parameters influence the result.
        """
        base_url = self._client.base_url.__str__()
        url_is_valid = re.match(r'^(https?):\/\/[^\s\/$.?#].[^\s]*$', base_url)
        if not url_is_valid:
            raise InvokeAuthorizationError('Invalid base url')

        api_key = self._client.api_key
        if len(api_key) < 18:
            raise InvokeAuthorizationError('Invalid API key')

        transcript = '1, 2, 3, 4, 5, 6, 7, 8, 9, 10'
        return Transcription(text=transcript)

View file

@ -0,0 +1,142 @@
from xinference_client.client.restful.restful_client import Client, \
RESTfulChatModelHandle, RESTfulGenerateModelHandle, RESTfulChatglmCppChatModelHandle, \
RESTfulEmbeddingModelHandle, RESTfulRerankModelHandle
from xinference_client.types import Embedding, EmbeddingData, EmbeddingUsage
from requests.sessions import Session
from requests import Response
from requests.exceptions import ConnectionError
from typing import Union, List
from _pytest.monkeypatch import MonkeyPatch
import pytest
import os
import re
class MockXinferenceClass(object):
    """Replacements for xinference client calls and the HTTP GET they rely on."""

    def get_chat_model(self: Client, model_uid: str) -> Union[RESTfulChatglmCppChatModelHandle, RESTfulGenerateModelHandle, RESTfulChatModelHandle]:
        """Return the handle type matching ``model_uid``.

        Accepted uids are ``generate``, ``chat``, ``embedding`` and ``rerank``;
        anything else, or an invalid base URL, raises ``RuntimeError``.
        """
        if not re.match(r'https?:\/\/[^\s\/$.?#].[^\s]*$', self.base_url):
            raise RuntimeError('404 Not Found')

        handle_types = {
            'generate': RESTfulGenerateModelHandle,
            'chat': RESTfulChatModelHandle,
            'embedding': RESTfulEmbeddingModelHandle,
            'rerank': RESTfulRerankModelHandle,
        }
        handle_type = handle_types.get(model_uid)
        if handle_type is None:
            raise RuntimeError('404 Not Found')
        return handle_type(model_uid, base_url=self.base_url)

    def get(self: Session, url: str, **kwargs):
        """Answer model-description requests with a fixed JSON payload.

        NOTE(review): URLs without '/v1/models/' fall through and yield None
        here; the original indentation was not recoverable, so confirm that
        this matches the intended behaviour for other URLs.
        """
        if '/v1/models/' in url:
            # get model uid
            model_uid = url.split('/')[-1]
            looks_like_uuid = re.match(r'[a-f0-9]{8}-[a-f0-9]{4}-[a-f0-9]{4}-[a-f0-9]{4}-[a-f0-9]{12}', model_uid)
            if not looks_like_uuid and model_uid not in ['generate', 'chat', 'embedding', 'rerank']:
                raise ConnectionError('404 Not Found')

            # check if url is valid
            if not re.match(r'^(https?):\/\/[^\s\/$.?#].[^\s]*$', url):
                raise ConnectionError('404 Not Found')

            response = Response()
            response.status_code = 200
            response._content = b'''{
"model_type": "LLM",
"address": "127.0.0.1:43877",
"accelerators": [
"0",
"1"
],
"model_name": "chatglm3-6b",
"model_lang": [
"en"
],
"model_ability": [
"generate",
"chat"
],
"model_description": "latest chatglm3",
"model_format": "pytorch",
"model_size_in_billions": 7,
"quantization": "none",
"model_hub": "huggingface",
"revision": null,
"context_length": 2048,
"replica": 1
}'''
            return response

    def rerank(self: RESTfulRerankModelHandle, documents: List[str], query: str, top_n: int) -> dict:
        """Return the first ``top_n`` documents, each with relevance score 0.9.

        ``top_n`` falls back to 1 when it is ``None``.
        """
        # check if self._model_uid is a valid uuid
        uid_is_uuid = re.match(r'[a-f0-9]{8}-[a-f0-9]{4}-[a-f0-9]{4}-[a-f0-9]{4}-[a-f0-9]{12}', self._model_uid)
        if not uid_is_uuid and self._model_uid != 'rerank':
            raise RuntimeError('404 Not Found')

        if not re.match(r'^(https?):\/\/[^\s\/$.?#].[^\s]*$', self._base_url):
            raise RuntimeError('404 Not Found')

        if top_n is None:
            top_n = 1

        ranked = []
        for position, document in enumerate(documents[:top_n]):
            ranked.append({
                'index': position,
                'document': document,
                'relevance_score': 0.9
            })
        return {'results': ranked}

    def create_embedding(
        self: RESTfulGenerateModelHandle,
        input: Union[str, List[str]],
        **kwargs
    ) -> dict:
        """Return a constant 768-dimensional embedding for every input text."""
        # check if self._model_uid is a valid uuid
        uid_is_uuid = re.match(r'[a-f0-9]{8}-[a-f0-9]{4}-[a-f0-9]{4}-[a-f0-9]{4}-[a-f0-9]{12}', self._model_uid)
        if not uid_is_uuid and self._model_uid != 'embedding':
            raise RuntimeError('404 Not Found')

        texts = [input] if isinstance(input, str) else input
        ipt_len = len(texts)

        vectors = [
            EmbeddingData(
                index=position,
                object="embedding",
                embedding=[1919.810 for _ in range(768)]
            )
            for position in range(ipt_len)
        ]
        # Usage counts one token per input text.
        usage = EmbeddingUsage(
            prompt_tokens=ipt_len,
            total_tokens=ipt_len
        )

        return Embedding(
            object="list",
            model=self._model_uid,
            data=vectors,
            usage=usage
        )
# Mocking is switched on by setting MOCK_SWITCH to "true" (any letter case).
MOCK = os.environ.get('MOCK_SWITCH', 'false').lower() == 'true'


@pytest.fixture
def setup_xinference_mock(request, monkeypatch: MonkeyPatch):
    """Install the xinference mocks for one test while MOCK is on."""
    if not MOCK:
        yield
        return

    replacements = (
        (Client, 'get_model', MockXinferenceClass.get_chat_model),
        (Session, 'get', MockXinferenceClass.get),
        (RESTfulEmbeddingModelHandle, 'create_embedding', MockXinferenceClass.create_embedding),
        (RESTfulRerankModelHandle, 'rerank', MockXinferenceClass.rerank),
    )
    for target, attribute, replacement in replacements:
        monkeypatch.setattr(target, attribute, replacement)

    yield

    monkeypatch.undo()

View file

@ -0,0 +1,116 @@
import os
from typing import Generator
import pytest
from core.model_runtime.entities.message_entities import SystemPromptMessage, UserPromptMessage, AssistantPromptMessage
from core.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk, \
LLMResultChunkDelta
from core.model_runtime.errors.validate import CredentialsValidateFailedError
from core.model_runtime.model_providers.anthropic.llm.llm import AnthropicLargeLanguageModel
from tests.integration_tests.model_runtime.__mock.anthropic import setup_anthropic_mock
@pytest.mark.parametrize('setup_anthropic_mock', [['none']], indirect=True)
def test_validate_credentials(setup_anthropic_mock):
    """An invalid key must be rejected; the key from the environment must pass."""
    llm = AnthropicLargeLanguageModel()
    model_name = 'claude-instant-1'

    rejected_credentials = {
        'anthropic_api_key': 'invalid_key'
    }
    with pytest.raises(CredentialsValidateFailedError):
        llm.validate_credentials(model=model_name, credentials=rejected_credentials)

    accepted_credentials = {
        'anthropic_api_key': os.environ.get('ANTHROPIC_API_KEY')
    }
    llm.validate_credentials(model=model_name, credentials=accepted_credentials)
@pytest.mark.parametrize('setup_anthropic_mock', [['none']], indirect=True)
def test_invoke_model(setup_anthropic_mock):
    """A blocking invocation returns an ``LLMResult`` with non-empty content."""
    llm = AnthropicLargeLanguageModel()

    credentials = {
        'anthropic_api_key': os.environ.get('ANTHROPIC_API_KEY'),
        'anthropic_api_url': os.environ.get('ANTHROPIC_API_URL')
    }
    conversation = [
        SystemPromptMessage(
            content='You are a helpful AI assistant.',
        ),
        UserPromptMessage(
            content='Hello World!'
        )
    ]
    sampling = {
        'temperature': 0.0,
        'top_p': 1.0,
        'max_tokens_to_sample': 10
    }

    response = llm.invoke(
        model='claude-instant-1',
        credentials=credentials,
        prompt_messages=conversation,
        model_parameters=sampling,
        stop=['How'],
        stream=False,
        user="abc-123"
    )

    assert isinstance(response, LLMResult)
    assert len(response.message.content) > 0
@pytest.mark.parametrize('setup_anthropic_mock', [['none']], indirect=True)
def test_invoke_stream_model(setup_anthropic_mock):
    """A streaming invocation yields well-formed chunks.

    Content must be non-empty on every chunk that has no finish reason.
    """
    llm = AnthropicLargeLanguageModel()

    credentials = {
        'anthropic_api_key': os.environ.get('ANTHROPIC_API_KEY')
    }
    conversation = [
        SystemPromptMessage(
            content='You are a helpful AI assistant.',
        ),
        UserPromptMessage(
            content='Hello World!'
        )
    ]
    sampling = {
        'temperature': 0.0,
        'max_tokens_to_sample': 100
    }

    response = llm.invoke(
        model='claude-instant-1',
        credentials=credentials,
        prompt_messages=conversation,
        model_parameters=sampling,
        stream=True,
        user="abc-123"
    )

    assert isinstance(response, Generator)

    for chunk in response:
        assert isinstance(chunk, LLMResultChunk)
        assert isinstance(chunk.delta, LLMResultChunkDelta)
        assert isinstance(chunk.delta.message, AssistantPromptMessage)
        # The terminal chunk (finish reason set) is allowed to be empty.
        if chunk.delta.finish_reason is None:
            assert len(chunk.delta.message.content) > 0
def test_get_num_tokens():
    """Token count for the system + user prompt pair is expected to be 18."""
    llm = AnthropicLargeLanguageModel()

    credentials = {'anthropic_api_key': os.environ.get('ANTHROPIC_API_KEY')}
    messages = [
        SystemPromptMessage(content='You are a helpful AI assistant.'),
        UserPromptMessage(content='Hello World!'),
    ]

    token_count = llm.get_num_tokens(
        model='claude-instant-1',
        credentials=credentials,
        prompt_messages=messages,
    )

    assert token_count == 18

View file

@ -0,0 +1,23 @@
import os
import pytest
from core.model_runtime.errors.validate import CredentialsValidateFailedError
from core.model_runtime.model_providers.anthropic.anthropic import AnthropicProvider
from tests.integration_tests.model_runtime.__mock.anthropic import setup_anthropic_mock
@pytest.mark.parametrize('setup_anthropic_mock', [['none']], indirect=True)
def test_validate_provider_credentials(setup_anthropic_mock):
    """Empty credentials must fail; the environment's API key must validate."""
    anthropic_provider = AnthropicProvider()

    with pytest.raises(CredentialsValidateFailedError):
        anthropic_provider.validate_provider_credentials(credentials={})

    # Outside the raises-block: this call must complete without raising.
    env_credentials = {'anthropic_api_key': os.environ.get('ANTHROPIC_API_KEY')}
    anthropic_provider.validate_provider_credentials(credentials=env_credentials)

File diff suppressed because one or more lines are too long

View file

@ -0,0 +1,71 @@
import os
import pytest
from core.model_runtime.entities.text_embedding_entities import TextEmbeddingResult
from core.model_runtime.errors.validate import CredentialsValidateFailedError
from core.model_runtime.model_providers.azure_openai.text_embedding.text_embedding import AzureOpenAITextEmbeddingModel
from tests.integration_tests.model_runtime.__mock.openai import setup_openai_mock
@pytest.mark.parametrize('setup_openai_mock', [['text_embedding']], indirect=True)
def test_validate_credentials(setup_openai_mock):
    """Reject an invalid Azure OpenAI key, then accept the key from the environment."""
    embedding_model = AzureOpenAITextEmbeddingModel()
    deployment = 'embedding'

    bad_credentials = {
        'openai_api_base': os.environ.get('AZURE_OPENAI_API_BASE'),
        'openai_api_key': 'invalid_key',
        'base_model_name': 'text-embedding-ada-002',
    }
    with pytest.raises(CredentialsValidateFailedError):
        embedding_model.validate_credentials(model=deployment, credentials=bad_credentials)

    # Outside the raises-block: this call must complete without raising.
    good_credentials = {
        'openai_api_base': os.environ.get('AZURE_OPENAI_API_BASE'),
        'openai_api_key': os.environ.get('AZURE_OPENAI_API_KEY'),
        'base_model_name': 'text-embedding-ada-002',
    }
    embedding_model.validate_credentials(model=deployment, credentials=good_credentials)
@pytest.mark.parametrize('setup_openai_mock', [['text_embedding']], indirect=True)
def test_invoke_model(setup_openai_mock):
    """Embedding two texts returns two vectors and reports two total tokens."""
    embedding_model = AzureOpenAITextEmbeddingModel()

    credentials = {
        'openai_api_base': os.environ.get('AZURE_OPENAI_API_BASE'),
        'openai_api_key': os.environ.get('AZURE_OPENAI_API_KEY'),
        'base_model_name': 'text-embedding-ada-002',
    }
    inputs = ["hello", "world"]

    embedded = embedding_model.invoke(
        model='embedding',
        credentials=credentials,
        texts=inputs,
        user="abc-123",
    )

    assert isinstance(embedded, TextEmbeddingResult)
    assert len(embedded.embeddings) == 2
    assert embedded.usage.total_tokens == 2
def test_get_num_tokens():
    """Token counting for two one-word texts is expected to return 2."""
    embedding_model = AzureOpenAITextEmbeddingModel()

    credentials = {'base_model_name': 'text-embedding-ada-002'}
    inputs = ["hello", "world"]

    token_count = embedding_model.get_num_tokens(
        model='embedding',
        credentials=credentials,
        texts=inputs,
    )

    assert token_count == 2

View file

@ -0,0 +1,190 @@
import os
import pytest
from typing import Generator
from time import sleep
from core.model_runtime.entities.message_entities import AssistantPromptMessage, UserPromptMessage, SystemPromptMessage
from core.model_runtime.entities.model_entities import AIModelEntity
from core.model_runtime.entities.llm_entities import LLMResult, LLMResultChunkDelta, \
LLMResultChunk
from core.model_runtime.errors.validate import CredentialsValidateFailedError
from core.model_runtime.model_providers.baichuan.llm.llm import BaichuanLarguageModel
def test_predefined_models():
    """At least one predefined schema exists and the first is an AIModelEntity."""
    schemas = BaichuanLarguageModel().predefined_models()

    assert len(schemas) >= 1
    first_schema = schemas[0]
    assert isinstance(first_schema, AIModelEntity)
def test_validate_credentials_for_chat_model():
    """Reject invalid Baichuan keys, then accept the key pair from the environment."""
    sleep(3)

    llm = BaichuanLarguageModel()
    model_name = 'baichuan2-turbo'

    bad_credentials = {'api_key': 'invalid_key', 'secret_key': 'invalid_key'}
    with pytest.raises(CredentialsValidateFailedError):
        llm.validate_credentials(model=model_name, credentials=bad_credentials)

    # Outside the raises-block: this call must complete without raising.
    good_credentials = {
        'api_key': os.environ.get('BAICHUAN_API_KEY'),
        'secret_key': os.environ.get('BAICHUAN_SECRET_KEY'),
    }
    llm.validate_credentials(model=model_name, credentials=good_credentials)
def test_invoke_model():
    """Blocking invocation returns an LLMResult with content and token usage."""
    sleep(3)

    llm = BaichuanLarguageModel()

    credentials = {
        'api_key': os.environ.get('BAICHUAN_API_KEY'),
        'secret_key': os.environ.get('BAICHUAN_SECRET_KEY'),
    }
    messages = [UserPromptMessage(content='Hello World!')]
    parameters = {'temperature': 0.7, 'top_p': 1.0, 'top_k': 1}

    result = llm.invoke(
        model='baichuan2-turbo',
        credentials=credentials,
        prompt_messages=messages,
        model_parameters=parameters,
        stop=['you'],
        user="abc-123",
        stream=False,
    )

    assert isinstance(result, LLMResult)
    assert len(result.message.content) > 0
    assert result.usage.total_tokens > 0
def test_invoke_model_with_system_message():
    """Blocking invocation with a system prompt returns content and token usage."""
    sleep(3)

    llm = BaichuanLarguageModel()

    credentials = {
        'api_key': os.environ.get('BAICHUAN_API_KEY'),
        'secret_key': os.environ.get('BAICHUAN_SECRET_KEY'),
    }
    messages = [
        SystemPromptMessage(content='请记住你是Kasumi。'),
        UserPromptMessage(content='现在告诉我你是谁?'),
    ]
    parameters = {'temperature': 0.7, 'top_p': 1.0, 'top_k': 1}

    result = llm.invoke(
        model='baichuan2-turbo',
        credentials=credentials,
        prompt_messages=messages,
        model_parameters=parameters,
        stop=['you'],
        user="abc-123",
        stream=False,
    )

    assert isinstance(result, LLMResult)
    assert len(result.message.content) > 0
    assert result.usage.total_tokens > 0
def test_invoke_stream_model():
    """Streaming invocation yields well-typed chunks; non-final chunks carry text."""
    sleep(3)

    llm = BaichuanLarguageModel()

    credentials = {
        'api_key': os.environ.get('BAICHUAN_API_KEY'),
        'secret_key': os.environ.get('BAICHUAN_SECRET_KEY'),
    }
    messages = [UserPromptMessage(content='Hello World!')]
    parameters = {'temperature': 0.7, 'top_p': 1.0, 'top_k': 1}

    stream = llm.invoke(
        model='baichuan2-turbo',
        credentials=credentials,
        prompt_messages=messages,
        model_parameters=parameters,
        stop=['you'],
        stream=True,
        user="abc-123",
    )

    assert isinstance(stream, Generator)
    for piece in stream:
        assert isinstance(piece, LLMResultChunk)
        assert isinstance(piece.delta, LLMResultChunkDelta)
        assert isinstance(piece.delta.message, AssistantPromptMessage)
        # Content is only required while no finish reason has been set.
        if piece.delta.finish_reason is None:
            assert len(piece.delta.message.content) > 0
def test_invoke_with_search():
    """Stream a completion with `with_search_enhance` enabled and check the chunks.

    Every chunk must be a well-typed LLMResultChunk; chunks without a finish
    reason must carry text. The chunk texts are concatenated and the combined
    message is checked at the end.
    """
    sleep(3)

    model = BaichuanLarguageModel()

    response = model.invoke(
        model='baichuan2-turbo',
        credentials={
            'api_key': os.environ.get('BAICHUAN_API_KEY'),
            'secret_key': os.environ.get('BAICHUAN_SECRET_KEY')
        },
        prompt_messages=[
            UserPromptMessage(
                content='北京今天的天气怎么样'
            )
        ],
        model_parameters={
            'temperature': 0.7,
            'top_p': 1.0,
            'top_k': 1,
            'with_search_enhance': True,
        },
        stop=['you'],
        stream=True,
        user="abc-123"
    )

    assert isinstance(response, Generator)

    parts = []
    for chunk in response:
        assert isinstance(chunk, LLMResultChunk)
        assert isinstance(chunk.delta, LLMResultChunkDelta)
        assert isinstance(chunk.delta.message, AssistantPromptMessage)
        # Content is only required while the finish reason is still falsy.
        if not chunk.delta.finish_reason:
            assert len(chunk.delta.message.content) > 0
        parts.append(chunk.delta.message.content)
    total_message = ''.join(parts)

    # NOTE(review): the previous final check was `'' not in total_message`.
    # The empty string is a substring of every str, so that assertion was
    # always False and the test could never pass. The sentinel text that was
    # meant to be excluded is not visible here - TODO confirm and restore it.
    # Until then, require that the stream produced some text.
    assert len(total_message) > 0
def test_get_num_tokens():
    """Token count for a single 'Hello World!' user message is expected to be 9."""
    sleep(3)

    llm = BaichuanLarguageModel()

    credentials = {
        'api_key': os.environ.get('BAICHUAN_API_KEY'),
        'secret_key': os.environ.get('BAICHUAN_SECRET_KEY'),
    }
    messages = [UserPromptMessage(content='Hello World!')]

    token_count = llm.get_num_tokens(
        model='baichuan2-turbo',
        credentials=credentials,
        prompt_messages=messages,
        tools=[],
    )

    assert isinstance(token_count, int)
    assert token_count == 9

View file

@ -0,0 +1,23 @@
import os
import pytest
from core.model_runtime.errors.validate import CredentialsValidateFailedError
from core.model_runtime.model_providers.baichuan.baichuan import BaichuanProvider
def test_validate_provider_credentials():
    """A bogus Baichuan API key must fail; the environment's key must validate."""
    baichuan_provider = BaichuanProvider()

    bad_credentials = {'api_key': 'hahahaha'}
    with pytest.raises(CredentialsValidateFailedError):
        baichuan_provider.validate_provider_credentials(credentials=bad_credentials)

    # Outside the raises-block: this call must complete without raising.
    env_credentials = {'api_key': os.environ.get('BAICHUAN_API_KEY')}
    baichuan_provider.validate_provider_credentials(credentials=env_credentials)

View file

@ -0,0 +1,61 @@
import os
import pytest
from core.model_runtime.entities.text_embedding_entities import TextEmbeddingResult
from core.model_runtime.errors.validate import CredentialsValidateFailedError
from core.model_runtime.model_providers.baichuan.text_embedding.text_embedding import BaichuanTextEmbeddingModel
def test_validate_credentials():
    """Reject an invalid Baichuan embedding key, then accept the environment's key."""
    embedding_model = BaichuanTextEmbeddingModel()
    model_name = 'baichuan-text-embedding'

    bad_credentials = {'api_key': 'invalid_key'}
    with pytest.raises(CredentialsValidateFailedError):
        embedding_model.validate_credentials(model=model_name, credentials=bad_credentials)

    # Outside the raises-block: this call must complete without raising.
    good_credentials = {'api_key': os.environ.get('BAICHUAN_API_KEY')}
    embedding_model.validate_credentials(model=model_name, credentials=good_credentials)
def test_invoke_model():
    """Embedding two texts returns two vectors and reports six total tokens."""
    embedding_model = BaichuanTextEmbeddingModel()

    credentials = {'api_key': os.environ.get('BAICHUAN_API_KEY')}
    inputs = ["hello", "world"]

    embedded = embedding_model.invoke(
        model='baichuan-text-embedding',
        credentials=credentials,
        texts=inputs,
        user="abc-123",
    )

    assert isinstance(embedded, TextEmbeddingResult)
    assert len(embedded.embeddings) == 2
    assert embedded.usage.total_tokens == 6
def test_get_num_tokens():
    """Token counting for two one-word texts is expected to return 2."""
    embedding_model = BaichuanTextEmbeddingModel()

    credentials = {'api_key': os.environ.get('BAICHUAN_API_KEY')}
    inputs = ["hello", "world"]

    token_count = embedding_model.get_num_tokens(
        model='baichuan-text-embedding',
        credentials=credentials,
        texts=inputs,
    )

    assert token_count == 2

View file

@ -0,0 +1,287 @@
import os
import pytest
from typing import Generator
from core.model_runtime.entities.message_entities import AssistantPromptMessage, TextPromptMessageContent, UserPromptMessage, \
SystemPromptMessage, PromptMessageTool
from core.model_runtime.entities.model_entities import AIModelEntity
from core.model_runtime.entities.llm_entities import LLMResult, LLMResultChunkDelta, \
LLMResultChunk
from core.model_runtime.errors.validate import CredentialsValidateFailedError
from core.model_runtime.model_providers.chatglm.llm.llm import ChatGLMLargeLanguageModel
from tests.integration_tests.model_runtime.__mock.openai import setup_openai_mock
def test_predefined_models():
    """At least one predefined schema exists and the first is an AIModelEntity."""
    schemas = ChatGLMLargeLanguageModel().predefined_models()

    assert len(schemas) >= 1
    first_schema = schemas[0]
    assert isinstance(first_schema, AIModelEntity)
@pytest.mark.parametrize('setup_openai_mock', [['chat']], indirect=True)
def test_validate_credentials_for_chat_model(setup_openai_mock):
    """Reject an invalid ChatGLM API base, then accept the one from the environment."""
    llm = ChatGLMLargeLanguageModel()
    model_name = 'chatglm2-6b'

    bad_credentials = {'api_base': 'invalid_key'}
    with pytest.raises(CredentialsValidateFailedError):
        llm.validate_credentials(model=model_name, credentials=bad_credentials)

    # Outside the raises-block: this call must complete without raising.
    good_credentials = {'api_base': os.environ.get('CHATGLM_API_BASE')}
    llm.validate_credentials(model=model_name, credentials=good_credentials)
@pytest.mark.parametrize('setup_openai_mock', [['chat']], indirect=True)
def test_invoke_model(setup_openai_mock):
    """Blocking invocation returns an LLMResult with content and token usage."""
    llm = ChatGLMLargeLanguageModel()

    credentials = {'api_base': os.environ.get('CHATGLM_API_BASE')}
    messages = [
        SystemPromptMessage(content='You are a helpful AI assistant.'),
        UserPromptMessage(content='Hello World!'),
    ]
    parameters = {'temperature': 0.7, 'top_p': 1.0}

    result = llm.invoke(
        model='chatglm2-6b',
        credentials=credentials,
        prompt_messages=messages,
        model_parameters=parameters,
        stop=['you'],
        user="abc-123",
        stream=False,
    )

    assert isinstance(result, LLMResult)
    assert len(result.message.content) > 0
    assert result.usage.total_tokens > 0
@pytest.mark.parametrize('setup_openai_mock', [['chat']], indirect=True)
def test_invoke_stream_model(setup_openai_mock):
    """Streaming invocation yields well-typed chunks; non-final chunks carry text."""
    llm = ChatGLMLargeLanguageModel()

    credentials = {'api_base': os.environ.get('CHATGLM_API_BASE')}
    messages = [
        SystemPromptMessage(content='You are a helpful AI assistant.'),
        UserPromptMessage(content='Hello World!'),
    ]
    parameters = {'temperature': 0.7, 'top_p': 1.0}

    stream = llm.invoke(
        model='chatglm2-6b',
        credentials=credentials,
        prompt_messages=messages,
        model_parameters=parameters,
        stop=['you'],
        stream=True,
        user="abc-123",
    )

    assert isinstance(stream, Generator)
    for piece in stream:
        assert isinstance(piece, LLMResultChunk)
        assert isinstance(piece.delta, LLMResultChunkDelta)
        assert isinstance(piece.delta.message, AssistantPromptMessage)
        # Content is only required while no finish reason has been set.
        if piece.delta.finish_reason is None:
            assert len(piece.delta.message.content) > 0
@pytest.mark.parametrize('setup_openai_mock', [['chat']], indirect=True)
def test_invoke_stream_model_with_functions(setup_openai_mock):
    """Stream a chatglm3-6b completion with a tool and expect a tool call.

    The stream is consumed until the first chunk whose message carries tool
    calls; that chunk must name the `get_current_weather` function. Every
    chunk seen up to that point is type-checked, and chunks without a finish
    reason must carry text.
    """
    model = ChatGLMLargeLanguageModel()

    response = model.invoke(
        model='chatglm3-6b',
        credentials={
            'api_base': os.environ.get('CHATGLM_API_BASE')
        },
        prompt_messages=[
            SystemPromptMessage(
                content='你是一个天气机器人,你不知道今天的天气怎么样,你需要通过调用一个函数来获取天气信息。'
            ),
            UserPromptMessage(
                content='波士顿天气如何?'
            )
        ],
        model_parameters={
            'temperature': 0,
            'top_p': 1.0,
        },
        stop=['you'],
        user='abc-123',
        stream=True,
        tools=[
            PromptMessageTool(
                name='get_current_weather',
                description='Get the current weather in a given location',
                parameters={
                    "type": "object",
                    "properties": {
                        "location": {
                            "type": "string",
                            "description": "The city and state e.g. San Francisco, CA"
                        },
                        "unit": {
                            "type": "string",
                            "enum": ["celsius", "fahrenheit"]
                        }
                    },
                    "required": [
                        "location"
                    ]
                }
            )
        ]
    )

    assert isinstance(response, Generator)

    # First chunk that carries a tool call; stays None if the stream has none.
    # (The previous version also collected every chunk into an unused list.)
    call = None
    for chunk in response:
        assert isinstance(chunk, LLMResultChunk)
        assert isinstance(chunk.delta, LLMResultChunkDelta)
        assert isinstance(chunk.delta.message, AssistantPromptMessage)
        # Content is only required while no finish reason has been set.
        if chunk.delta.finish_reason is None:
            assert len(chunk.delta.message.content) > 0

        if chunk.delta.message.tool_calls and len(chunk.delta.message.tool_calls) > 0:
            call = chunk
            break

    assert call is not None
    assert call.delta.message.tool_calls[0].function.name == 'get_current_weather'
@pytest.mark.parametrize('setup_openai_mock', [['chat']], indirect=True)
def test_invoke_model_with_functions(setup_openai_mock):
    """Blocking chatglm3-6b invocation with a tool returns a get_current_weather call."""
    llm = ChatGLMLargeLanguageModel()

    credentials = {'api_base': os.environ.get('CHATGLM_API_BASE')}
    messages = [
        UserPromptMessage(content='What is the weather like in San Francisco?'),
    ]
    parameters = {'temperature': 0.7, 'top_p': 1.0}
    weather_tool = PromptMessageTool(
        name='get_current_weather',
        description='Get the current weather in a given location',
        parameters={
            "type": "object",
            "properties": {
                "location": {
                    "type": "string",
                    "description": "The city and state e.g. San Francisco, CA",
                },
                "unit": {
                    "type": "string",
                    "enum": ["c", "f"],
                },
            },
            "required": ["location"],
        },
    )

    result = llm.invoke(
        model='chatglm3-6b',
        credentials=credentials,
        prompt_messages=messages,
        model_parameters=parameters,
        stop=['you'],
        user='abc-123',
        stream=False,
        tools=[weather_tool],
    )

    assert isinstance(result, LLMResult)
    assert len(result.message.content) > 0
    assert result.usage.total_tokens > 0
    assert result.message.tool_calls[0].function.name == 'get_current_weather'
def test_get_num_tokens():
    """Token count is expected to be 77 with the weather tool and 21 without it."""
    llm = ChatGLMLargeLanguageModel()

    def build_messages():
        # Fresh message objects for each call, as in two independent invocations.
        return [
            SystemPromptMessage(content='You are a helpful AI assistant.'),
            UserPromptMessage(content='Hello World!'),
        ]

    weather_tool_parameters = {
        "type": "object",
        "properties": {
            "location": {
                "type": "string",
                "description": "The city and state e.g. San Francisco, CA",
            },
            "unit": {
                "type": "string",
                "enum": ["c", "f"],
            },
        },
        "required": ["location"],
    }

    with_tools = llm.get_num_tokens(
        model='chatglm2-6b',
        credentials={'api_base': os.environ.get('CHATGLM_API_BASE')},
        prompt_messages=build_messages(),
        tools=[
            PromptMessageTool(
                name='get_current_weather',
                description='Get the current weather in a given location',
                parameters=weather_tool_parameters,
            ),
        ],
    )
    assert isinstance(with_tools, int)
    assert with_tools == 77

    without_tools = llm.get_num_tokens(
        model='chatglm2-6b',
        credentials={'api_base': os.environ.get('CHATGLM_API_BASE')},
        prompt_messages=build_messages(),
    )
    assert isinstance(without_tools, int)
    assert without_tools == 21

View file

@ -0,0 +1,25 @@
import os
import pytest
from core.model_runtime.errors.validate import CredentialsValidateFailedError
from core.model_runtime.model_providers.chatglm.chatglm import ChatGLMProvider
from tests.integration_tests.model_runtime.__mock.openai import setup_openai_mock
@pytest.mark.parametrize('setup_openai_mock', [['chat']], indirect=True)
def test_validate_provider_credentials(setup_openai_mock):
    """A bogus ChatGLM API base must fail; the environment's base must validate."""
    chatglm_provider = ChatGLMProvider()

    bad_credentials = {'api_base': 'hahahaha'}
    with pytest.raises(CredentialsValidateFailedError):
        chatglm_provider.validate_provider_credentials(credentials=bad_credentials)

    # Outside the raises-block: this call must complete without raising.
    env_credentials = {'api_base': os.environ.get('CHATGLM_API_BASE')}
    chatglm_provider.validate_provider_credentials(credentials=env_credentials)

View file

@ -0,0 +1,21 @@
import os
import pytest
from core.model_runtime.errors.validate import CredentialsValidateFailedError
from core.model_runtime.model_providers.cohere.cohere import CohereProvider
def test_validate_provider_credentials():
    """Empty credentials must fail; the environment's Cohere key must validate."""
    cohere_provider = CohereProvider()

    with pytest.raises(CredentialsValidateFailedError):
        cohere_provider.validate_provider_credentials(credentials={})

    # Outside the raises-block: this call must complete without raising.
    env_credentials = {'api_key': os.environ.get('COHERE_API_KEY')}
    cohere_provider.validate_provider_credentials(credentials=env_credentials)

View file

@ -0,0 +1,51 @@
import os
import pytest
from core.model_runtime.entities.rerank_entities import RerankResult
from core.model_runtime.errors.validate import CredentialsValidateFailedError
from core.model_runtime.model_providers.cohere.rerank.rerank import CohereRerankModel
def test_validate_credentials():
    """Reject an invalid Cohere key, then accept the key from the environment."""
    reranker = CohereRerankModel()
    model_name = 'rerank-english-v2.0'

    bad_credentials = {'api_key': 'invalid_key'}
    with pytest.raises(CredentialsValidateFailedError):
        reranker.validate_credentials(model=model_name, credentials=bad_credentials)

    # Outside the raises-block: this call must complete without raising.
    good_credentials = {'api_key': os.environ.get('COHERE_API_KEY')}
    reranker.validate_credentials(model=model_name, credentials=good_credentials)
def test_invoke_model():
    """Reranking with a 0.8 threshold keeps only the second (Washington) document."""
    reranker = CohereRerankModel()

    credentials = {'api_key': os.environ.get('COHERE_API_KEY')}
    carson_city_doc = (
        "Carson City is the capital city of the American state of Nevada. At the 2010 United States "
        "Census, Carson City had a population of 55,274."
    )
    washington_doc = (
        "Washington, D.C. (also known as simply Washington or D.C., and officially as the District of Columbia) "
        "is the capital of the United States. It is a federal district. The President of the USA and many major "
        "national government offices are in the territory. This makes it the political center of the United "
        "States of America."
    )

    ranked = reranker.invoke(
        model='rerank-english-v2.0',
        credentials=credentials,
        query="What is the capital of the United States?",
        docs=[carson_city_doc, washington_doc],
        score_threshold=0.8,
    )

    assert isinstance(ranked, RerankResult)
    assert len(ranked.docs) == 1
    top_doc = ranked.docs[0]
    assert top_doc.index == 1
    assert top_doc.score >= 0.8

File diff suppressed because one or more lines are too long

View file

@ -0,0 +1,23 @@
import os
import pytest
from core.model_runtime.errors.validate import CredentialsValidateFailedError
from core.model_runtime.model_providers.google.google import GoogleProvider
from tests.integration_tests.model_runtime.__mock.google import setup_google_mock
@pytest.mark.parametrize('setup_google_mock', [['none']], indirect=True)
def test_validate_provider_credentials(setup_google_mock):
    """Empty credentials must fail; the environment's Google key must validate."""
    google_provider = GoogleProvider()

    with pytest.raises(CredentialsValidateFailedError):
        google_provider.validate_provider_credentials(credentials={})

    # Outside the raises-block: this call must complete without raising.
    env_credentials = {'google_api_key': os.environ.get('GOOGLE_API_KEY')}
    google_provider.validate_provider_credentials(credentials=env_credentials)

View file

@ -0,0 +1,304 @@
import os
from typing import Generator
import pytest
from core.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk, \
LLMResultChunkDelta
from core.model_runtime.entities.message_entities import UserPromptMessage, AssistantPromptMessage
from core.model_runtime.errors.validate import CredentialsValidateFailedError
from core.model_runtime.model_providers.huggingface_hub.llm.llm import HuggingfaceHubLargeLanguageModel
from tests.integration_tests.model_runtime.__mock.huggingface import setup_huggingface_mock
@pytest.mark.parametrize('setup_huggingface_mock', [['none']], indirect=True)
def test_hosted_inference_api_validate_credentials(setup_huggingface_mock):
    """Hosted API: an invalid token fails for a real and a fake model; the env token validates."""
    llm = HuggingfaceHubLargeLanguageModel()
    real_model = 'HuggingFaceH4/zephyr-7b-beta'

    with pytest.raises(CredentialsValidateFailedError):
        llm.validate_credentials(
            model=real_model,
            credentials={
                'huggingfacehub_api_type': 'hosted_inference_api',
                'huggingfacehub_api_token': 'invalid_key',
            },
        )

    with pytest.raises(CredentialsValidateFailedError):
        llm.validate_credentials(
            model='fake-model',
            credentials={
                'huggingfacehub_api_type': 'hosted_inference_api',
                'huggingfacehub_api_token': 'invalid_key',
            },
        )

    # Outside the raises-blocks: this call must complete without raising.
    good_credentials = {
        'huggingfacehub_api_type': 'hosted_inference_api',
        'huggingfacehub_api_token': os.environ.get('HUGGINGFACE_API_KEY'),
    }
    llm.validate_credentials(model=real_model, credentials=good_credentials)
@pytest.mark.parametrize('setup_huggingface_mock', [['none']], indirect=True)
def test_hosted_inference_api_invoke_model(setup_huggingface_mock):
    """Hosted API: blocking invocation returns an LLMResult with non-empty content."""
    llm = HuggingfaceHubLargeLanguageModel()

    credentials = {
        'huggingfacehub_api_type': 'hosted_inference_api',
        'huggingfacehub_api_token': os.environ.get('HUGGINGFACE_API_KEY'),
    }
    messages = [UserPromptMessage(content='Who are you?')]
    parameters = {'temperature': 1.0, 'top_k': 2, 'top_p': 0.5}

    result = llm.invoke(
        model='HuggingFaceH4/zephyr-7b-beta',
        credentials=credentials,
        prompt_messages=messages,
        model_parameters=parameters,
        stop=['How'],
        stream=False,
        user="abc-123",
    )

    assert isinstance(result, LLMResult)
    assert len(result.message.content) > 0
@pytest.mark.parametrize('setup_huggingface_mock', [['none']], indirect=True)
def test_hosted_inference_api_invoke_stream_model(setup_huggingface_mock):
    """Hosted API: streaming yields well-typed chunks; non-final chunks carry text."""
    llm = HuggingfaceHubLargeLanguageModel()

    credentials = {
        'huggingfacehub_api_type': 'hosted_inference_api',
        'huggingfacehub_api_token': os.environ.get('HUGGINGFACE_API_KEY'),
    }
    messages = [UserPromptMessage(content='Who are you?')]
    parameters = {'temperature': 1.0, 'top_k': 2, 'top_p': 0.5}

    stream = llm.invoke(
        model='HuggingFaceH4/zephyr-7b-beta',
        credentials=credentials,
        prompt_messages=messages,
        model_parameters=parameters,
        stop=['How'],
        stream=True,
        user="abc-123",
    )

    assert isinstance(stream, Generator)
    for piece in stream:
        assert isinstance(piece, LLMResultChunk)
        assert isinstance(piece.delta, LLMResultChunkDelta)
        assert isinstance(piece.delta.message, AssistantPromptMessage)
        # Content is only required while no finish reason has been set.
        if piece.delta.finish_reason is None:
            assert len(piece.delta.message.content) > 0
@pytest.mark.parametrize('setup_huggingface_mock', [['none']], indirect=True)
def test_inference_endpoints_text_generation_validate_credentials(setup_huggingface_mock):
    """Text-generation endpoint: an invalid token fails, the environment's token validates."""
    llm = HuggingfaceHubLargeLanguageModel()
    model_name = 'openchat/openchat_3.5'

    bad_credentials = {
        'huggingfacehub_api_type': 'inference_endpoints',
        'huggingfacehub_api_token': 'invalid_key',
        'huggingfacehub_endpoint_url': os.environ.get('HUGGINGFACE_TEXT_GEN_ENDPOINT_URL'),
        'task_type': 'text-generation',
    }
    with pytest.raises(CredentialsValidateFailedError):
        llm.validate_credentials(model=model_name, credentials=bad_credentials)

    # Outside the raises-block: this call must complete without raising.
    good_credentials = {
        'huggingfacehub_api_type': 'inference_endpoints',
        'huggingfacehub_api_token': os.environ.get('HUGGINGFACE_API_KEY'),
        'huggingfacehub_endpoint_url': os.environ.get('HUGGINGFACE_TEXT_GEN_ENDPOINT_URL'),
        'task_type': 'text-generation',
    }
    llm.validate_credentials(model=model_name, credentials=good_credentials)
@pytest.mark.parametrize('setup_huggingface_mock', [['none']], indirect=True)
def test_inference_endpoints_text_generation_invoke_model(setup_huggingface_mock):
    """Text-generation endpoint: blocking invocation returns non-empty content."""
    llm = HuggingfaceHubLargeLanguageModel()

    credentials = {
        'huggingfacehub_api_type': 'inference_endpoints',
        'huggingfacehub_api_token': os.environ.get('HUGGINGFACE_API_KEY'),
        'huggingfacehub_endpoint_url': os.environ.get('HUGGINGFACE_TEXT_GEN_ENDPOINT_URL'),
        'task_type': 'text-generation',
    }
    messages = [UserPromptMessage(content='Who are you?')]
    parameters = {'temperature': 1.0, 'top_k': 2, 'top_p': 0.5}

    result = llm.invoke(
        model='openchat/openchat_3.5',
        credentials=credentials,
        prompt_messages=messages,
        model_parameters=parameters,
        stop=['How'],
        stream=False,
        user="abc-123",
    )

    assert isinstance(result, LLMResult)
    assert len(result.message.content) > 0
@pytest.mark.parametrize('setup_huggingface_mock', [['none']], indirect=True)
def test_inference_endpoints_text_generation_invoke_stream_model(setup_huggingface_mock):
    """Text-generation endpoint: streaming yields well-typed chunks; non-final chunks carry text."""
    llm = HuggingfaceHubLargeLanguageModel()

    credentials = {
        'huggingfacehub_api_type': 'inference_endpoints',
        'huggingfacehub_api_token': os.environ.get('HUGGINGFACE_API_KEY'),
        'huggingfacehub_endpoint_url': os.environ.get('HUGGINGFACE_TEXT_GEN_ENDPOINT_URL'),
        'task_type': 'text-generation',
    }
    messages = [UserPromptMessage(content='Who are you?')]
    parameters = {'temperature': 1.0, 'top_k': 2, 'top_p': 0.5}

    stream = llm.invoke(
        model='openchat/openchat_3.5',
        credentials=credentials,
        prompt_messages=messages,
        model_parameters=parameters,
        stop=['How'],
        stream=True,
        user="abc-123",
    )

    assert isinstance(stream, Generator)
    for piece in stream:
        assert isinstance(piece, LLMResultChunk)
        assert isinstance(piece.delta, LLMResultChunkDelta)
        assert isinstance(piece.delta.message, AssistantPromptMessage)
        # Content is only required while no finish reason has been set.
        if piece.delta.finish_reason is None:
            assert len(piece.delta.message.content) > 0
@pytest.mark.parametrize('setup_huggingface_mock', [['none']], indirect=True)
def test_inference_endpoints_text2text_generation_validate_credentials(setup_huggingface_mock):
    """Text2text endpoint: an invalid token fails, the environment's token validates."""
    llm = HuggingfaceHubLargeLanguageModel()
    model_name = 'google/mt5-base'

    bad_credentials = {
        'huggingfacehub_api_type': 'inference_endpoints',
        'huggingfacehub_api_token': 'invalid_key',
        'huggingfacehub_endpoint_url': os.environ.get('HUGGINGFACE_TEXT2TEXT_GEN_ENDPOINT_URL'),
        'task_type': 'text2text-generation',
    }
    with pytest.raises(CredentialsValidateFailedError):
        llm.validate_credentials(model=model_name, credentials=bad_credentials)

    # Outside the raises-block: this call must complete without raising.
    good_credentials = {
        'huggingfacehub_api_type': 'inference_endpoints',
        'huggingfacehub_api_token': os.environ.get('HUGGINGFACE_API_KEY'),
        'huggingfacehub_endpoint_url': os.environ.get('HUGGINGFACE_TEXT2TEXT_GEN_ENDPOINT_URL'),
        'task_type': 'text2text-generation',
    }
    llm.validate_credentials(model=model_name, credentials=good_credentials)
@pytest.mark.parametrize('setup_huggingface_mock', [['none']], indirect=True)
def test_inference_endpoints_text2text_generation_invoke_model(setup_huggingface_mock):
    """Text2text endpoint: blocking invocation returns non-empty content."""
    llm = HuggingfaceHubLargeLanguageModel()

    credentials = {
        'huggingfacehub_api_type': 'inference_endpoints',
        'huggingfacehub_api_token': os.environ.get('HUGGINGFACE_API_KEY'),
        'huggingfacehub_endpoint_url': os.environ.get('HUGGINGFACE_TEXT2TEXT_GEN_ENDPOINT_URL'),
        'task_type': 'text2text-generation',
    }
    messages = [UserPromptMessage(content='Who are you?')]
    parameters = {'temperature': 1.0, 'top_k': 2, 'top_p': 0.5}

    result = llm.invoke(
        model='google/mt5-base',
        credentials=credentials,
        prompt_messages=messages,
        model_parameters=parameters,
        stop=['How'],
        stream=False,
        user="abc-123",
    )

    assert isinstance(result, LLMResult)
    assert len(result.message.content) > 0
@pytest.mark.parametrize('setup_huggingface_mock', [['none']], indirect=True)
def test_inference_endpoints_text2text_generation_invoke_stream_model(setup_huggingface_mock):
    """Text2text endpoint: streaming yields well-typed chunks; non-final chunks carry text."""
    llm = HuggingfaceHubLargeLanguageModel()

    credentials = {
        'huggingfacehub_api_type': 'inference_endpoints',
        'huggingfacehub_api_token': os.environ.get('HUGGINGFACE_API_KEY'),
        'huggingfacehub_endpoint_url': os.environ.get('HUGGINGFACE_TEXT2TEXT_GEN_ENDPOINT_URL'),
        'task_type': 'text2text-generation',
    }
    messages = [UserPromptMessage(content='Who are you?')]
    parameters = {'temperature': 1.0, 'top_k': 2, 'top_p': 0.5}

    stream = llm.invoke(
        model='google/mt5-base',
        credentials=credentials,
        prompt_messages=messages,
        model_parameters=parameters,
        stop=['How'],
        stream=True,
        user="abc-123",
    )

    assert isinstance(stream, Generator)
    for piece in stream:
        assert isinstance(piece, LLMResultChunk)
        assert isinstance(piece.delta, LLMResultChunkDelta)
        assert isinstance(piece.delta.message, AssistantPromptMessage)
        # Content is only required while no finish reason has been set.
        if piece.delta.finish_reason is None:
            assert len(piece.delta.message.content) > 0
def test_get_num_tokens():
    """Token count for a single 'Hello World!' user message is expected to be 7."""
    llm = HuggingfaceHubLargeLanguageModel()

    credentials = {
        'huggingfacehub_api_type': 'inference_endpoints',
        'huggingfacehub_api_token': os.environ.get('HUGGINGFACE_API_KEY'),
        'huggingfacehub_endpoint_url': os.environ.get('HUGGINGFACE_TEXT2TEXT_GEN_ENDPOINT_URL'),
        'task_type': 'text2text-generation',
    }
    messages = [UserPromptMessage(content='Hello World!')]

    token_count = llm.get_num_tokens(
        model='google/mt5-base',
        credentials=credentials,
        prompt_messages=messages,
    )

    assert token_count == 7

View file

@ -0,0 +1,120 @@
import os
import pytest
from core.model_runtime.entities.text_embedding_entities import TextEmbeddingResult
from core.model_runtime.errors.validate import CredentialsValidateFailedError
from core.model_runtime.model_providers.huggingface_hub.text_embedding.text_embedding import \
HuggingfaceHubTextEmbeddingModel
def test_hosted_inference_api_validate_credentials():
    """Hosted API embeddings: an invalid token fails, the environment's token validates."""
    embedding_model = HuggingfaceHubTextEmbeddingModel()
    model_name = 'facebook/bart-base'

    bad_credentials = {
        'huggingfacehub_api_type': 'hosted_inference_api',
        'huggingfacehub_api_token': 'invalid_key',
    }
    with pytest.raises(CredentialsValidateFailedError):
        embedding_model.validate_credentials(model=model_name, credentials=bad_credentials)

    # Outside the raises-block: this call must complete without raising.
    good_credentials = {
        'huggingfacehub_api_type': 'hosted_inference_api',
        'huggingfacehub_api_token': os.environ.get('HUGGINGFACE_API_KEY'),
    }
    embedding_model.validate_credentials(model=model_name, credentials=good_credentials)
def test_hosted_inference_api_invoke_model():
    """Hosted API embeddings: two texts give two vectors and two total tokens."""
    embedding_model = HuggingfaceHubTextEmbeddingModel()

    credentials = {
        'huggingfacehub_api_type': 'hosted_inference_api',
        'huggingfacehub_api_token': os.environ.get('HUGGINGFACE_API_KEY'),
    }
    inputs = ["hello", "world"]

    embedded = embedding_model.invoke(
        model='facebook/bart-base',
        credentials=credentials,
        texts=inputs,
    )

    assert isinstance(embedded, TextEmbeddingResult)
    assert len(embedded.embeddings) == 2
    assert embedded.usage.total_tokens == 2
def test_inference_endpoints_validate_credentials():
    """Endpoint embeddings: an invalid token fails, the environment's token validates."""
    embedding_model = HuggingfaceHubTextEmbeddingModel()
    model_name = 'all-MiniLM-L6-v2'

    bad_credentials = {
        'huggingfacehub_api_type': 'inference_endpoints',
        'huggingfacehub_api_token': 'invalid_key',
        'huggingface_namespace': 'Dify-AI',
        'huggingfacehub_endpoint_url': os.environ.get('HUGGINGFACE_EMBEDDINGS_ENDPOINT_URL'),
        'task_type': 'feature-extraction',
    }
    with pytest.raises(CredentialsValidateFailedError):
        embedding_model.validate_credentials(model=model_name, credentials=bad_credentials)

    # Outside the raises-block: this call must complete without raising.
    good_credentials = {
        'huggingfacehub_api_type': 'inference_endpoints',
        'huggingfacehub_api_token': os.environ.get('HUGGINGFACE_API_KEY'),
        'huggingface_namespace': 'Dify-AI',
        'huggingfacehub_endpoint_url': os.environ.get('HUGGINGFACE_EMBEDDINGS_ENDPOINT_URL'),
        'task_type': 'feature-extraction',
    }
    embedding_model.validate_credentials(model=model_name, credentials=good_credentials)
def test_inference_endpoints_invoke_model():
    """Endpoint embeddings: two texts give two vectors; usage reports zero total tokens."""
    embedding_model = HuggingfaceHubTextEmbeddingModel()

    credentials = {
        'huggingfacehub_api_type': 'inference_endpoints',
        'huggingfacehub_api_token': os.environ.get('HUGGINGFACE_API_KEY'),
        'huggingface_namespace': 'Dify-AI',
        'huggingfacehub_endpoint_url': os.environ.get('HUGGINGFACE_EMBEDDINGS_ENDPOINT_URL'),
        'task_type': 'feature-extraction',
    }
    inputs = ["hello", "world"]

    embedded = embedding_model.invoke(
        model='all-MiniLM-L6-v2',
        credentials=credentials,
        texts=inputs,
    )

    assert isinstance(embedded, TextEmbeddingResult)
    assert len(embedded.embeddings) == 2
    assert embedded.usage.total_tokens == 0
def test_get_num_tokens():
    """Token counting for two one-word texts is expected to return 2."""
    embedding_model = HuggingfaceHubTextEmbeddingModel()

    credentials = {
        'huggingfacehub_api_type': 'inference_endpoints',
        'huggingfacehub_api_token': os.environ.get('HUGGINGFACE_API_KEY'),
        'huggingface_namespace': 'Dify-AI',
        'huggingfacehub_endpoint_url': os.environ.get('HUGGINGFACE_EMBEDDINGS_ENDPOINT_URL'),
        'task_type': 'feature-extraction',
    }
    inputs = ["hello", "world"]

    token_count = embedding_model.get_num_tokens(
        model='all-MiniLM-L6-v2',
        credentials=credentials,
        texts=inputs,
    )

    assert token_count == 2

View file

@ -0,0 +1,23 @@
import os
import pytest
from core.model_runtime.errors.validate import CredentialsValidateFailedError
from core.model_runtime.model_providers.jina.jina import JinaProvider
def test_validate_provider_credentials():
    """A bogus Jina API key must fail; the environment's key must validate."""
    jina_provider = JinaProvider()

    bad_credentials = {'api_key': 'hahahaha'}
    with pytest.raises(CredentialsValidateFailedError):
        jina_provider.validate_provider_credentials(credentials=bad_credentials)

    # Outside the raises-block: this call must complete without raising.
    env_credentials = {'api_key': os.environ.get('JINA_API_KEY')}
    jina_provider.validate_provider_credentials(credentials=env_credentials)

View file

@ -0,0 +1,63 @@
import os
import pytest
from core.model_runtime.entities.text_embedding_entities import TextEmbeddingResult
from core.model_runtime.errors.validate import CredentialsValidateFailedError
from core.model_runtime.model_providers.jina.text_embedding.text_embedding import JinaTextEmbeddingModel
def test_validate_credentials():
    """Reject an invalid Jina key, then accept the key from the environment."""
    embedding_model = JinaTextEmbeddingModel()
    model_name = 'jina-embeddings-v2-base-en'

    bad_credentials = {'api_key': 'invalid_key'}
    with pytest.raises(CredentialsValidateFailedError):
        embedding_model.validate_credentials(model=model_name, credentials=bad_credentials)

    # Outside the raises-block: this call must complete without raising.
    good_credentials = {'api_key': os.environ.get('JINA_API_KEY')}
    embedding_model.validate_credentials(model=model_name, credentials=good_credentials)
def test_invoke_model():
    """Embedding two texts returns two vectors and reports six total tokens."""
    embedding_model = JinaTextEmbeddingModel()

    credentials = {'api_key': os.environ.get('JINA_API_KEY')}
    inputs = ["hello", "world"]

    embedded = embedding_model.invoke(
        model='jina-embeddings-v2-base-en',
        credentials=credentials,
        texts=inputs,
        user="abc-123",
    )

    assert isinstance(embedded, TextEmbeddingResult)
    assert len(embedded.embeddings) == 2
    assert embedded.usage.total_tokens == 6
def test_get_num_tokens():
    """Expect the Jina embedding model to count 6 tokens for "hello" and "world"."""
    embedding_model = JinaTextEmbeddingModel()
    credentials = {
        'api_key': os.environ.get('JINA_API_KEY'),
    }

    token_count = embedding_model.get_num_tokens(
        model='jina-embeddings-v2-base-en',
        credentials=credentials,
        texts=["hello", "world"],
    )

    assert token_count == 6

View file

@ -0,0 +1,4 @@
"""
Tests for the LocalAI embedding interface are temporarily unavailable
because we have not yet found a way to test it.
"""

View file

@ -0,0 +1,213 @@
import os
import pytest
from typing import Generator
from core.model_runtime.entities.message_entities import AssistantPromptMessage, TextPromptMessageContent, UserPromptMessage, \
SystemPromptMessage, PromptMessageTool
from core.model_runtime.entities.model_entities import ParameterRule
from core.model_runtime.entities.llm_entities import LLMResult, LLMResultChunkDelta, \
LLMResultChunk
from core.model_runtime.errors.validate import CredentialsValidateFailedError
from core.model_runtime.model_providers.localai.llm.llm import LocalAILarguageModel
def test_validate_credentials_for_chat_model():
    """Reject a malformed LocalAI server URL, then validate the URL from LOCALAI_SERVER_URL."""
    model_name = 'chinese-llama-2-7b'
    llm = LocalAILarguageModel()
    bad_credentials = {
        'server_url': 'hahahaha',
        'completion_type': 'completion',
    }
    good_credentials = {
        'server_url': os.environ.get('LOCALAI_SERVER_URL'),
        'completion_type': 'completion',
    }

    with pytest.raises(CredentialsValidateFailedError):
        llm.validate_credentials(model=model_name, credentials=bad_credentials)

    # Must not raise for the configured server.
    llm.validate_credentials(model=model_name, credentials=good_credentials)
def test_invoke_completion_model():
    """Blocking invoke in 'completion' mode must return an LLMResult with content and token usage."""
    llm = LocalAILarguageModel()
    credentials = {
        'server_url': os.environ.get('LOCALAI_SERVER_URL'),
        'completion_type': 'completion',
    }
    sampling = {
        'temperature': 0.7,
        'top_p': 1.0,
        'max_tokens': 10,
    }

    result = llm.invoke(
        model='chinese-llama-2-7b',
        credentials=credentials,
        prompt_messages=[UserPromptMessage(content='ping')],
        model_parameters=sampling,
        stop=[],
        user="abc-123",
        stream=False,
    )

    assert isinstance(result, LLMResult)
    assert len(result.message.content) > 0
    assert result.usage.total_tokens > 0
def test_invoke_chat_model():
    """Blocking invoke in 'chat_completion' mode must return an LLMResult with content and token usage."""
    llm = LocalAILarguageModel()
    credentials = {
        'server_url': os.environ.get('LOCALAI_SERVER_URL'),
        'completion_type': 'chat_completion',
    }
    sampling = {
        'temperature': 0.7,
        'top_p': 1.0,
        'max_tokens': 10,
    }

    result = llm.invoke(
        model='chinese-llama-2-7b',
        credentials=credentials,
        prompt_messages=[UserPromptMessage(content='ping')],
        model_parameters=sampling,
        stop=[],
        user="abc-123",
        stream=False,
    )

    assert isinstance(result, LLMResult)
    assert len(result.message.content) > 0
    assert result.usage.total_tokens > 0
def test_invoke_stream_completion_model():
    """Streaming invoke in 'completion' mode must yield well-formed LLMResultChunk objects."""
    llm = LocalAILarguageModel()
    credentials = {
        'server_url': os.environ.get('LOCALAI_SERVER_URL'),
        'completion_type': 'completion',
    }
    sampling = {
        'temperature': 0.7,
        'top_p': 1.0,
        'max_tokens': 10,
    }

    stream = llm.invoke(
        model='chinese-llama-2-7b',
        credentials=credentials,
        prompt_messages=[UserPromptMessage(content='Hello World!')],
        model_parameters=sampling,
        stop=['you'],
        stream=True,
        user="abc-123",
    )

    assert isinstance(stream, Generator)
    for piece in stream:
        assert isinstance(piece, LLMResultChunk)
        assert isinstance(piece.delta, LLMResultChunkDelta)
        assert isinstance(piece.delta.message, AssistantPromptMessage)
        # Only chunks without a finish reason are required to carry content.
        if piece.delta.finish_reason is None:
            assert len(piece.delta.message.content) > 0
def test_invoke_stream_chat_model():
    """Streaming invoke in 'chat_completion' mode must yield well-formed LLMResultChunk objects."""
    llm = LocalAILarguageModel()
    credentials = {
        'server_url': os.environ.get('LOCALAI_SERVER_URL'),
        'completion_type': 'chat_completion',
    }
    sampling = {
        'temperature': 0.7,
        'top_p': 1.0,
        'max_tokens': 10,
    }

    stream = llm.invoke(
        model='chinese-llama-2-7b',
        credentials=credentials,
        prompt_messages=[UserPromptMessage(content='Hello World!')],
        model_parameters=sampling,
        stop=['you'],
        stream=True,
        user="abc-123",
    )

    assert isinstance(stream, Generator)
    for piece in stream:
        assert isinstance(piece, LLMResultChunk)
        assert isinstance(piece.delta, LLMResultChunkDelta)
        assert isinstance(piece.delta.message, AssistantPromptMessage)
        # Only chunks without a finish reason are required to carry content.
        if piece.delta.finish_reason is None:
            assert len(piece.delta.message.content) > 0
def test_get_num_tokens():
    """Check LocalAI token counting with and without a tool definition.

    A system + user prompt plus one tool is expected to count 77 tokens;
    a lone user prompt is expected to count 10.
    """
    llm = LocalAILarguageModel()

    weather_tool = PromptMessageTool(
        name='get_current_weather',
        description='Get the current weather in a given location',
        parameters={
            "type": "object",
            "properties": {
                "location": {
                    "type": "string",
                    "description": "The city and state e.g. San Francisco, CA",
                },
                "unit": {
                    "type": "string",
                    "enum": ["c", "f"],
                },
            },
            "required": ["location"],
        },
    )

    count_with_tool = llm.get_num_tokens(
        model='????',
        credentials={
            'server_url': os.environ.get('LOCALAI_SERVER_URL'),
            'completion_type': 'chat_completion',
        },
        prompt_messages=[
            SystemPromptMessage(content='You are a helpful AI assistant.'),
            UserPromptMessage(content='Hello World!'),
        ],
        tools=[weather_tool],
    )
    assert isinstance(count_with_tool, int)
    assert count_with_tool == 77

    # Second call uses a freshly built credentials dict and no tools.
    count_plain = llm.get_num_tokens(
        model='????',
        credentials={
            'server_url': os.environ.get('LOCALAI_SERVER_URL'),
            'completion_type': 'chat_completion',
        },
        prompt_messages=[UserPromptMessage(content='Hello World!')],
    )
    assert isinstance(count_plain, int)
    assert count_plain == 10

View file

@ -0,0 +1,64 @@
import os
import pytest
from core.model_runtime.entities.text_embedding_entities import TextEmbeddingResult
from core.model_runtime.errors.validate import CredentialsValidateFailedError
from core.model_runtime.model_providers.minimax.text_embedding.text_embedding import MinimaxTextEmbeddingModel
def test_validate_credentials():
    """Reject an invalid Minimax API key for embo-01, then validate the configured key."""
    model_name = 'embo-01'
    embedding_model = MinimaxTextEmbeddingModel()

    with pytest.raises(CredentialsValidateFailedError):
        embedding_model.validate_credentials(
            model=model_name,
            credentials={
                'minimax_api_key': 'invalid_key',
                'minimax_group_id': os.environ.get('MINIMAX_GROUP_ID'),
            },
        )

    # Key and group id from the environment must validate without raising.
    embedding_model.validate_credentials(
        model=model_name,
        credentials={
            'minimax_api_key': os.environ.get('MINIMAX_API_KEY'),
            'minimax_group_id': os.environ.get('MINIMAX_GROUP_ID'),
        },
    )
def test_invoke_model():
    """Embed two texts with Minimax embo-01 and check type, embedding count and token usage (16)."""
    embedding_model = MinimaxTextEmbeddingModel()
    credentials = {
        'minimax_api_key': os.environ.get('MINIMAX_API_KEY'),
        'minimax_group_id': os.environ.get('MINIMAX_GROUP_ID'),
    }

    embedding_result = embedding_model.invoke(
        model='embo-01',
        credentials=credentials,
        texts=["hello", "world"],
        user="abc-123",
    )

    assert isinstance(embedding_result, TextEmbeddingResult)
    assert len(embedding_result.embeddings) == 2
    assert embedding_result.usage.total_tokens == 16
def test_get_num_tokens():
    """Expect the Minimax embedding model to count 2 tokens for "hello" and "world"."""
    embedding_model = MinimaxTextEmbeddingModel()
    credentials = {
        'minimax_api_key': os.environ.get('MINIMAX_API_KEY'),
        'minimax_group_id': os.environ.get('MINIMAX_GROUP_ID'),
    }

    token_count = embedding_model.get_num_tokens(
        model='embo-01',
        credentials=credentials,
        texts=["hello", "world"],
    )

    assert token_count == 2

View file

@ -0,0 +1,158 @@
import os
import pytest
from typing import Generator
from time import sleep
from core.model_runtime.entities.message_entities import AssistantPromptMessage, UserPromptMessage
from core.model_runtime.entities.model_entities import AIModelEntity
from core.model_runtime.entities.llm_entities import LLMResult, LLMResultChunkDelta, \
LLMResultChunk
from core.model_runtime.errors.validate import CredentialsValidateFailedError
from core.model_runtime.model_providers.minimax.llm.llm import MinimaxLargeLanguageModel
def test_predefined_models():
    """The Minimax LLM must expose at least one predefined schema, each an AIModelEntity."""
    schemas = MinimaxLargeLanguageModel().predefined_models()

    assert len(schemas) >= 1
    # Only the first entry is type-checked.
    first_schema = schemas[0]
    assert isinstance(first_schema, AIModelEntity)
def test_validate_credentials_for_chat_model():
    """Reject invalid Minimax credentials for abab5.5-chat, then validate the configured ones."""
    # Pause before hitting the API (presumably to stay under a rate limit - TODO confirm).
    sleep(3)

    model_name = 'abab5.5-chat'
    llm = MinimaxLargeLanguageModel()
    bad_credentials = {
        'minimax_api_key': 'invalid_key',
        'minimax_group_id': 'invalid_key',
    }
    good_credentials = {
        'minimax_api_key': os.environ.get('MINIMAX_API_KEY'),
        'minimax_group_id': os.environ.get('MINIMAX_GROUP_ID'),
    }

    with pytest.raises(CredentialsValidateFailedError):
        llm.validate_credentials(model=model_name, credentials=bad_credentials)

    llm.validate_credentials(model=model_name, credentials=good_credentials)
def test_invoke_model():
    """Blocking invoke of abab5-chat must return an LLMResult with content and token usage."""
    # Pause before hitting the API (presumably to stay under a rate limit - TODO confirm).
    sleep(3)

    llm = MinimaxLargeLanguageModel()
    credentials = {
        'minimax_api_key': os.environ.get('MINIMAX_API_KEY'),
        'minimax_group_id': os.environ.get('MINIMAX_GROUP_ID'),
    }
    sampling = {
        'temperature': 0.7,
        'top_p': 1.0,
        'top_k': 1,
    }

    result = llm.invoke(
        model='abab5-chat',
        credentials=credentials,
        prompt_messages=[UserPromptMessage(content='Hello World!')],
        model_parameters=sampling,
        stop=['you'],
        user="abc-123",
        stream=False,
    )

    assert isinstance(result, LLMResult)
    assert len(result.message.content) > 0
    assert result.usage.total_tokens > 0
def test_invoke_stream_model():
    """Streaming invoke of abab5.5-chat must yield well-formed LLMResultChunk objects."""
    # Pause before hitting the API (presumably to stay under a rate limit - TODO confirm).
    sleep(3)

    llm = MinimaxLargeLanguageModel()
    credentials = {
        'minimax_api_key': os.environ.get('MINIMAX_API_KEY'),
        'minimax_group_id': os.environ.get('MINIMAX_GROUP_ID'),
    }
    sampling = {
        'temperature': 0.7,
        'top_p': 1.0,
        'top_k': 1,
    }

    stream = llm.invoke(
        model='abab5.5-chat',
        credentials=credentials,
        prompt_messages=[UserPromptMessage(content='Hello World!')],
        model_parameters=sampling,
        stop=['you'],
        stream=True,
        user="abc-123",
    )

    assert isinstance(stream, Generator)
    for piece in stream:
        assert isinstance(piece, LLMResultChunk)
        assert isinstance(piece.delta, LLMResultChunkDelta)
        assert isinstance(piece.delta.message, AssistantPromptMessage)
        # Only chunks without a finish reason are required to carry content.
        if piece.delta.finish_reason is None:
            assert len(piece.delta.message.content) > 0
def test_invoke_with_search():
    """Stream abab5.5-chat with 'plugin_web_search' enabled and check the combined reply.

    The concatenated streamed content must contain the marker '参考资料'.
    """
    # Pause before hitting the API (presumably to stay under a rate limit - TODO confirm).
    sleep(3)

    llm = MinimaxLargeLanguageModel()
    credentials = {
        'minimax_api_key': os.environ.get('MINIMAX_API_KEY'),
        'minimax_group_id': os.environ.get('MINIMAX_GROUP_ID'),
    }
    sampling = {
        'temperature': 0.7,
        'top_p': 1.0,
        'top_k': 1,
        'plugin_web_search': True,
    }

    stream = llm.invoke(
        model='abab5.5-chat',
        credentials=credentials,
        prompt_messages=[UserPromptMessage(content='北京今天的天气怎么样')],
        model_parameters=sampling,
        stop=['you'],
        stream=True,
        user="abc-123",
    )

    assert isinstance(stream, Generator)
    combined_text = ''
    for piece in stream:
        assert isinstance(piece, LLMResultChunk)
        assert isinstance(piece.delta, LLMResultChunkDelta)
        assert isinstance(piece.delta.message, AssistantPromptMessage)
        combined_text += piece.delta.message.content
        # Chunks with a falsy finish reason are required to carry content.
        if not piece.delta.finish_reason:
            assert len(piece.delta.message.content) > 0

    assert '参考资料' in combined_text
def test_get_num_tokens():
    """Expect abab5.5-chat to count 30 tokens for a single 'Hello World!' user prompt."""
    # Pause before hitting the API (presumably to stay under a rate limit - TODO confirm).
    sleep(3)

    llm = MinimaxLargeLanguageModel()
    credentials = {
        'minimax_api_key': os.environ.get('MINIMAX_API_KEY'),
        'minimax_group_id': os.environ.get('MINIMAX_GROUP_ID'),
    }

    token_count = llm.get_num_tokens(
        model='abab5.5-chat',
        credentials=credentials,
        prompt_messages=[UserPromptMessage(content='Hello World!')],
        tools=[],
    )

    assert isinstance(token_count, int)
    assert token_count == 30

View file

@ -0,0 +1,25 @@
import os
import pytest
from core.model_runtime.errors.validate import CredentialsValidateFailedError
from core.model_runtime.model_providers.minimax.minimax import MinimaxProvider
def test_validate_provider_credentials():
    """Bogus Minimax key/group id must be rejected; the configured pair must not raise."""
    minimax_provider = MinimaxProvider()
    bogus_credentials = {
        'minimax_api_key': 'hahahaha',
        'minimax_group_id': '123',
    }
    env_credentials = {
        'minimax_api_key': os.environ.get('MINIMAX_API_KEY'),
        'minimax_group_id': os.environ.get('MINIMAX_GROUP_ID'),
    }

    with pytest.raises(CredentialsValidateFailedError):
        minimax_provider.validate_provider_credentials(credentials=bogus_credentials)

    minimax_provider.validate_provider_credentials(credentials=env_credentials)

File diff suppressed because one or more lines are too long

View file

@ -0,0 +1,55 @@
import os
import pytest
from core.model_runtime.errors.validate import CredentialsValidateFailedError
from core.model_runtime.model_providers.openai.moderation.moderation import OpenAIModerationModel
from tests.integration_tests.model_runtime.__mock.openai import setup_openai_mock
@pytest.mark.parametrize('setup_openai_mock', [['moderation']], indirect=True)
def test_validate_credentials(setup_openai_mock):
    """Reject an invalid OpenAI key for text-moderation-stable, then validate the configured one.

    The setup_openai_mock fixture is requested with the 'moderation' parameter.
    """
    model_name = 'text-moderation-stable'
    moderation_model = OpenAIModerationModel()

    with pytest.raises(CredentialsValidateFailedError):
        moderation_model.validate_credentials(
            model=model_name,
            credentials={'openai_api_key': 'invalid_key'},
        )

    moderation_model.validate_credentials(
        model=model_name,
        credentials={'openai_api_key': os.environ.get('OPENAI_API_KEY')},
    )
@pytest.mark.parametrize('setup_openai_mock', [['moderation']], indirect=True)
def test_invoke_model(setup_openai_mock):
    """Moderation must return False for "hello" and True for "i will kill you"."""
    moderation_model = OpenAIModerationModel()

    # (input text, expected moderation verdict), checked in this order.
    cases = (
        ("hello", False),
        ("i will kill you", True),
    )
    for text, expected_verdict in cases:
        verdict = moderation_model.invoke(
            model='text-moderation-stable',
            # Built per call so each invocation receives its own dict.
            credentials={
                'openai_api_key': os.environ.get('OPENAI_API_KEY')
            },
            text=text,
            user="abc-123",
        )
        assert isinstance(verdict, bool)
        assert verdict is expected_verdict

View file

@ -0,0 +1,23 @@
import os
import pytest
from core.model_runtime.errors.validate import CredentialsValidateFailedError
from core.model_runtime.model_providers.openai.openai import OpenAIProvider
from tests.integration_tests.model_runtime.__mock.openai import setup_openai_mock
@pytest.mark.parametrize('setup_openai_mock', [['chat']], indirect=True)
def test_validate_provider_credentials(setup_openai_mock):
    """Empty credentials must be rejected by the OpenAI provider; the configured key must not raise."""
    openai_provider = OpenAIProvider()
    env_credentials = {'openai_api_key': os.environ.get('OPENAI_API_KEY')}

    with pytest.raises(CredentialsValidateFailedError):
        openai_provider.validate_provider_credentials(credentials={})

    openai_provider.validate_provider_credentials(credentials=env_credentials)

View file

@ -0,0 +1,56 @@
import os
import pytest
from core.model_runtime.errors.validate import CredentialsValidateFailedError
from core.model_runtime.model_providers.openai.speech2text.speech2text import OpenAISpeech2TextModel
from tests.integration_tests.model_runtime.__mock.openai import setup_openai_mock
@pytest.mark.parametrize('setup_openai_mock', [['speech2text']], indirect=True)
def test_validate_credentials(setup_openai_mock):
    """Reject an invalid OpenAI key for whisper-1, then validate the configured one."""
    model_name = 'whisper-1'
    speech_model = OpenAISpeech2TextModel()

    with pytest.raises(CredentialsValidateFailedError):
        speech_model.validate_credentials(
            model=model_name,
            credentials={'openai_api_key': 'invalid_key'},
        )

    speech_model.validate_credentials(
        model=model_name,
        credentials={'openai_api_key': os.environ.get('OPENAI_API_KEY')},
    )
@pytest.mark.parametrize('setup_openai_mock', [['speech2text']], indirect=True)
def test_invoke_model(setup_openai_mock):
    """Transcribe assets/audio.mp3 with whisper-1 and compare against the expected text."""
    speech_model = OpenAISpeech2TextModel()

    # audio.mp3 lives in the 'assets' directory next to this file's parent directory.
    here = os.path.dirname(os.path.abspath(__file__))
    audio_file_path = os.path.join(os.path.dirname(here), 'assets', 'audio.mp3')

    # The model is invoked while the file handle is still open.
    with open(audio_file_path, 'rb') as audio_file:
        transcript = speech_model.invoke(
            model='whisper-1',
            credentials={
                'openai_api_key': os.environ.get('OPENAI_API_KEY')
            },
            file=audio_file,
            user="abc-123",
        )

        assert isinstance(transcript, str)
        assert transcript == '1, 2, 3, 4, 5, 6, 7, 8, 9, 10'

View file

@ -0,0 +1,67 @@
import os
import pytest
from core.model_runtime.entities.text_embedding_entities import TextEmbeddingResult
from core.model_runtime.errors.validate import CredentialsValidateFailedError
from core.model_runtime.model_providers.openai.text_embedding.text_embedding import OpenAITextEmbeddingModel
from tests.integration_tests.model_runtime.__mock.openai import setup_openai_mock
@pytest.mark.parametrize('setup_openai_mock', [['text_embedding']], indirect=True)
def test_validate_credentials(setup_openai_mock):
    """Reject an invalid OpenAI key for text-embedding-ada-002, then validate the configured one."""
    model_name = 'text-embedding-ada-002'
    embedding_model = OpenAITextEmbeddingModel()

    with pytest.raises(CredentialsValidateFailedError):
        embedding_model.validate_credentials(
            model=model_name,
            credentials={'openai_api_key': 'invalid_key'},
        )

    embedding_model.validate_credentials(
        model=model_name,
        credentials={'openai_api_key': os.environ.get('OPENAI_API_KEY')},
    )
@pytest.mark.parametrize('setup_openai_mock', [['text_embedding']], indirect=True)
def test_invoke_model(setup_openai_mock):
    """Embed two texts with text-embedding-ada-002; expect 2 embeddings and 2 total tokens."""
    embedding_model = OpenAITextEmbeddingModel()
    credentials = {
        'openai_api_key': os.environ.get('OPENAI_API_KEY'),
        'openai_api_base': 'https://api.openai.com',
    }

    embedding_result = embedding_model.invoke(
        model='text-embedding-ada-002',
        credentials=credentials,
        texts=["hello", "world"],
        user="abc-123",
    )

    assert isinstance(embedding_result, TextEmbeddingResult)
    assert len(embedding_result.embeddings) == 2
    assert embedding_result.usage.total_tokens == 2
def test_get_num_tokens():
    """Expect text-embedding-ada-002 to count 2 tokens for "hello" and "world"."""
    embedding_model = OpenAITextEmbeddingModel()
    credentials = {
        'openai_api_key': os.environ.get('OPENAI_API_KEY'),
        'openai_api_base': 'https://api.openai.com',
    }

    token_count = embedding_model.get_num_tokens(
        model='text-embedding-ada-002',
        credentials=credentials,
        texts=["hello", "world"],
    )

    assert token_count == 2

View file

@ -0,0 +1,181 @@
import os
from typing import Generator
import pytest
from core.model_runtime.entities.message_entities import AssistantPromptMessage, UserPromptMessage, \
SystemPromptMessage, PromptMessageTool
from core.model_runtime.entities.llm_entities import LLMResult, LLMResultChunkDelta, \
LLMResultChunk
from core.model_runtime.errors.validate import CredentialsValidateFailedError
from core.model_runtime.model_providers.openai_api_compatible.llm.llm import OAIAPICompatLargeLanguageModel
"""
Using Together.ai's OpenAI-compatible API as testing endpoint
"""
def test_validate_credentials():
    """Reject an invalid key on the Together chat endpoint, then validate TOGETHER_API_KEY."""
    model_name = 'mistralai/Mixtral-8x7B-Instruct-v0.1'
    chat_endpoint = 'https://api.together.xyz/v1/chat/completions'
    llm = OAIAPICompatLargeLanguageModel()

    with pytest.raises(CredentialsValidateFailedError):
        llm.validate_credentials(
            model=model_name,
            credentials={
                'api_key': 'invalid_key',
                'endpoint_url': chat_endpoint,
                'mode': 'chat',
            },
        )

    llm.validate_credentials(
        model=model_name,
        credentials={
            'api_key': os.environ.get('TOGETHER_API_KEY'),
            'endpoint_url': chat_endpoint,
            'mode': 'chat',
        },
    )
def test_invoke_model():
    """Blocking invoke in 'completion' mode must return an LLMResult with non-empty content."""
    llm = OAIAPICompatLargeLanguageModel()
    credentials = {
        'api_key': os.environ.get('TOGETHER_API_KEY'),
        'endpoint_url': 'https://api.together.xyz/v1/completions',
        'mode': 'completion',
    }
    conversation = [
        SystemPromptMessage(content='You are a helpful AI assistant.'),
        UserPromptMessage(content='Who are you?'),
    ]
    sampling = {
        'temperature': 1.0,
        'top_k': 2,
        'top_p': 0.5,
    }

    result = llm.invoke(
        model='mistralai/Mixtral-8x7B-Instruct-v0.1',
        credentials=credentials,
        prompt_messages=conversation,
        model_parameters=sampling,
        stop=['How'],
        stream=False,
        user="abc-123",
    )

    assert isinstance(result, LLMResult)
    assert len(result.message.content) > 0
def test_invoke_stream_model():
    """Streaming invoke in 'chat' mode must yield correctly typed LLMResultChunk objects."""
    llm = OAIAPICompatLargeLanguageModel()
    credentials = {
        'api_key': os.environ.get('TOGETHER_API_KEY'),
        'endpoint_url': 'https://api.together.xyz/v1/chat/completions',
        'mode': 'chat',
    }
    conversation = [
        SystemPromptMessage(content='You are a helpful AI assistant.'),
        UserPromptMessage(content='Who are you?'),
    ]
    sampling = {
        'temperature': 1.0,
        'top_k': 2,
        'top_p': 0.5,
    }

    stream = llm.invoke(
        model='mistralai/Mixtral-8x7B-Instruct-v0.1',
        credentials=credentials,
        prompt_messages=conversation,
        model_parameters=sampling,
        stop=['How'],
        stream=True,
        user="abc-123",
    )

    assert isinstance(stream, Generator)
    # Only the types are checked here; chunk content is not inspected.
    for piece in stream:
        assert isinstance(piece, LLMResultChunk)
        assert isinstance(piece.delta, LLMResultChunkDelta)
        assert isinstance(piece.delta.message, AssistantPromptMessage)
# using OpenAI's ChatGPT-3.5 as testing endpoint
def test_invoke_chat_model_with_tools():
    """Invoke gpt-3.5-turbo with a weather tool and expect at least one tool call in the reply."""
    llm = OAIAPICompatLargeLanguageModel()
    credentials = {
        'api_key': os.environ.get('OPENAI_API_KEY'),
        'endpoint_url': 'https://api.openai.com/v1/chat/completions',
        'mode': 'chat',
    }
    conversation = [
        SystemPromptMessage(content='You are a helpful AI assistant.'),
        UserPromptMessage(content="what's the weather today in London?"),
    ]
    weather_tool = PromptMessageTool(
        name='get_weather',
        description='Determine weather in my location',
        parameters={
            "type": "object",
            "properties": {
                "location": {
                    "type": "string",
                    "description": "The city and state e.g. San Francisco, CA",
                },
                "unit": {
                    "type": "string",
                    "enum": ["celsius", "fahrenheit"],
                },
            },
            "required": ["location"],
        },
    )

    reply = llm.invoke(
        model='gpt-3.5-turbo',
        credentials=credentials,
        prompt_messages=conversation,
        tools=[weather_tool],
        model_parameters={
            'temperature': 0.0,
            'max_tokens': 1024,
        },
        stream=False,
        user="abc-123",
    )

    assert isinstance(reply, LLMResult)
    assert isinstance(reply.message, AssistantPromptMessage)
    assert len(reply.message.tool_calls) > 0
def test_get_num_tokens():
    """Expect 21 tokens for a system prompt plus a 'Hello World!' user prompt."""
    llm = OAIAPICompatLargeLanguageModel()
    credentials = {
        'api_key': os.environ.get('OPENAI_API_KEY'),
        'endpoint_url': 'https://api.openai.com/v1/chat/completions',
    }
    conversation = [
        SystemPromptMessage(content='You are a helpful AI assistant.'),
        UserPromptMessage(content='Hello World!'),
    ]

    token_count = llm.get_num_tokens(
        model='mistralai/Mixtral-8x7B-Instruct-v0.1',
        credentials=credentials,
        prompt_messages=conversation,
    )

    assert isinstance(token_count, int)
    assert token_count == 21

View file

@ -0,0 +1,79 @@
import os
import pytest
from core.model_runtime.entities.text_embedding_entities import TextEmbeddingResult
from core.model_runtime.errors.validate import CredentialsValidateFailedError
from core.model_runtime.model_providers.openai_api_compatible.text_embedding.text_embedding import OAICompatEmbeddingModel
"""
Using OpenAI's API as testing endpoint
"""
def test_validate_credentials():
    """Reject an invalid key on the OpenAI embeddings endpoint, then validate OPENAI_API_KEY."""
    model_name = 'text-embedding-ada-002'
    embedding_model = OAICompatEmbeddingModel()
    bad_credentials = {
        'api_key': 'invalid_key',
        'endpoint_url': 'https://api.openai.com/v1/embeddings',
        'context_size': 8184,
        'max_chunks': 32,
    }
    good_credentials = {
        'api_key': os.environ.get('OPENAI_API_KEY'),
        'endpoint_url': 'https://api.openai.com/v1/embeddings',
        'context_size': 8184,
        'max_chunks': 32,
    }

    with pytest.raises(CredentialsValidateFailedError):
        embedding_model.validate_credentials(model=model_name, credentials=bad_credentials)

    embedding_model.validate_credentials(model=model_name, credentials=good_credentials)
def test_invoke_model():
    """Embed two texts through the OpenAI-compatible model; expect 2 embeddings and 2 total tokens."""
    embedding_model = OAICompatEmbeddingModel()
    credentials = {
        'api_key': os.environ.get('OPENAI_API_KEY'),
        'endpoint_url': 'https://api.openai.com/v1/embeddings',
        'context_size': 8184,
        'max_chunks': 32,
    }

    embedding_result = embedding_model.invoke(
        model='text-embedding-ada-002',
        credentials=credentials,
        texts=["hello", "world"],
        user="abc-123",
    )

    assert isinstance(embedding_result, TextEmbeddingResult)
    assert len(embedding_result.embeddings) == 2
    assert embedding_result.usage.total_tokens == 2
def test_get_num_tokens():
    """Expect the OpenAI-compatible embedding model to count 2 tokens for "hello" and "world"."""
    embedding_model = OAICompatEmbeddingModel()
    credentials = {
        'api_key': os.environ.get('OPENAI_API_KEY'),
        'endpoint_url': 'https://api.openai.com/v1/embeddings',
        'context_size': 8184,
        'max_chunks': 32,
    }

    token_count = embedding_model.get_num_tokens(
        model='text-embedding-ada-002',
        credentials=credentials,
        texts=["hello", "world"],
    )

    assert token_count == 2

View file

@ -0,0 +1,61 @@
import os
import pytest
from core.model_runtime.entities.text_embedding_entities import TextEmbeddingResult
from core.model_runtime.errors.validate import CredentialsValidateFailedError
from core.model_runtime.model_providers.openllm.text_embedding.text_embedding import OpenLLMTextEmbeddingModel
def test_validate_credentials():
    """Reject a corrupted OpenLLM server URL, then validate the configured one.

    The invalid URL is built by prefixing 'ww' to OPENLLM_SERVER_URL.
    """
    model = OpenLLMTextEmbeddingModel()

    # Default to '' so an unset variable yields the invalid URL 'ww' rather than
    # a TypeError from str + None raised inside the pytest.raises block.
    invalid_server_url = 'ww' + os.environ.get('OPENLLM_SERVER_URL', '')

    with pytest.raises(CredentialsValidateFailedError):
        model.validate_credentials(
            model='NOT IMPORTANT',
            credentials={
                'server_url': invalid_server_url,
            }
        )

    # The unmodified server URL must validate without raising.
    model.validate_credentials(
        model='NOT IMPORTANT',
        credentials={
            'server_url': os.environ.get('OPENLLM_SERVER_URL'),
        }
    )
def test_invoke_model():
    """Embed two texts with OpenLLM; expect 2 embeddings and a positive token count."""
    embedding_model = OpenLLMTextEmbeddingModel()
    credentials = {
        'server_url': os.environ.get('OPENLLM_SERVER_URL'),
    }

    embedding_result = embedding_model.invoke(
        model='NOT IMPORTANT',
        credentials=credentials,
        texts=["hello", "world"],
        user="abc-123",
    )

    assert isinstance(embedding_result, TextEmbeddingResult)
    assert len(embedding_result.embeddings) == 2
    assert embedding_result.usage.total_tokens > 0
def test_get_num_tokens():
    """Expect the OpenLLM embedding model to count 2 tokens for "hello" and "world"."""
    embedding_model = OpenLLMTextEmbeddingModel()
    credentials = {
        'server_url': os.environ.get('OPENLLM_SERVER_URL'),
    }

    token_count = embedding_model.get_num_tokens(
        model='NOT IMPORTANT',
        credentials=credentials,
        texts=["hello", "world"],
    )

    assert token_count == 2

View file

@ -0,0 +1,104 @@
import os
import pytest
from typing import Generator
from core.model_runtime.entities.message_entities import AssistantPromptMessage, UserPromptMessage
from core.model_runtime.entities.llm_entities import LLMResult, LLMResultChunkDelta, \
LLMResultChunk
from core.model_runtime.errors.validate import CredentialsValidateFailedError
from core.model_runtime.model_providers.openllm.llm.llm import OpenLLMLargeLanguageModel
def test_validate_credentials_for_chat_model():
    """Reject a malformed OpenLLM server URL, then validate the URL from OPENLLM_SERVER_URL."""
    model_name = 'NOT IMPORTANT'
    llm = OpenLLMLargeLanguageModel()

    with pytest.raises(CredentialsValidateFailedError):
        llm.validate_credentials(
            model=model_name,
            credentials={'server_url': 'invalid_key'},
        )

    llm.validate_credentials(
        model=model_name,
        credentials={'server_url': os.environ.get('OPENLLM_SERVER_URL')},
    )
def test_invoke_model():
    """Blocking OpenLLM invoke must return an LLMResult with content and token usage."""
    llm = OpenLLMLargeLanguageModel()
    credentials = {
        'server_url': os.environ.get('OPENLLM_SERVER_URL'),
    }
    sampling = {
        'temperature': 0.7,
        'top_p': 1.0,
        'top_k': 1,
    }

    result = llm.invoke(
        model='NOT IMPORTANT',
        credentials=credentials,
        prompt_messages=[UserPromptMessage(content='Hello World!')],
        model_parameters=sampling,
        stop=['you'],
        user="abc-123",
        stream=False,
    )

    assert isinstance(result, LLMResult)
    assert len(result.message.content) > 0
    assert result.usage.total_tokens > 0
def test_invoke_stream_model():
    """Streaming OpenLLM invoke must yield well-formed LLMResultChunk objects."""
    llm = OpenLLMLargeLanguageModel()
    credentials = {
        'server_url': os.environ.get('OPENLLM_SERVER_URL'),
    }
    sampling = {
        'temperature': 0.7,
        'top_p': 1.0,
        'top_k': 1,
    }

    stream = llm.invoke(
        model='NOT IMPORTANT',
        credentials=credentials,
        prompt_messages=[UserPromptMessage(content='Hello World!')],
        model_parameters=sampling,
        stop=['you'],
        stream=True,
        user="abc-123",
    )

    assert isinstance(stream, Generator)
    for piece in stream:
        assert isinstance(piece, LLMResultChunk)
        assert isinstance(piece.delta, LLMResultChunkDelta)
        assert isinstance(piece.delta.message, AssistantPromptMessage)
        # Only chunks without a finish reason are required to carry content.
        if piece.delta.finish_reason is None:
            assert len(piece.delta.message.content) > 0
def test_get_num_tokens():
    """Expect the OpenLLM LLM to count 3 tokens for a single 'Hello World!' user prompt."""
    llm = OpenLLMLargeLanguageModel()
    credentials = {
        'server_url': os.environ.get('OPENLLM_SERVER_URL'),
    }

    token_count = llm.get_num_tokens(
        model='NOT IMPORTANT',
        credentials=credentials,
        prompt_messages=[UserPromptMessage(content='Hello World!')],
        tools=[],
    )

    assert isinstance(token_count, int)
    assert token_count == 3

View file

@ -0,0 +1,119 @@
import os
from typing import Generator
import pytest
from core.model_runtime.entities.message_entities import SystemPromptMessage, UserPromptMessage, AssistantPromptMessage
from core.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk, \
LLMResultChunkDelta
from core.model_runtime.errors.validate import CredentialsValidateFailedError
from core.model_runtime.model_providers.replicate.llm.llm import ReplicateLargeLanguageModel
def test_validate_credentials():
    """Reject an invalid Replicate token for meta/llama-2-13b-chat, then validate the configured one."""
    model_name = 'meta/llama-2-13b-chat'
    model_version = 'f4e2de70d66816a838a89eeeb621910adffb0dd0baba3976c96980970978018d'
    llm = ReplicateLargeLanguageModel()

    with pytest.raises(CredentialsValidateFailedError):
        llm.validate_credentials(
            model=model_name,
            credentials={
                'replicate_api_token': 'invalid_key',
                'model_version': model_version,
            },
        )

    llm.validate_credentials(
        model=model_name,
        credentials={
            'replicate_api_token': os.environ.get('REPLICATE_API_KEY'),
            'model_version': model_version,
        },
    )
def test_invoke_model():
    """Blocking invoke of meta/llama-2-13b-chat must return an LLMResult with non-empty content."""
    llm = ReplicateLargeLanguageModel()
    credentials = {
        'replicate_api_token': os.environ.get('REPLICATE_API_KEY'),
        'model_version': 'f4e2de70d66816a838a89eeeb621910adffb0dd0baba3976c96980970978018d',
    }
    conversation = [
        SystemPromptMessage(content='You are a helpful AI assistant.'),
        UserPromptMessage(content='Who are you?'),
    ]
    sampling = {
        'temperature': 1.0,
        'top_k': 2,
        'top_p': 0.5,
    }

    result = llm.invoke(
        model='meta/llama-2-13b-chat',
        credentials=credentials,
        prompt_messages=conversation,
        model_parameters=sampling,
        stop=['How'],
        stream=False,
        user="abc-123",
    )

    assert isinstance(result, LLMResult)
    assert len(result.message.content) > 0
def test_invoke_stream_model():
    """Streaming invoke of mistralai/mixtral-8x7b-instruct-v0.1 must yield correctly typed chunks."""
    llm = ReplicateLargeLanguageModel()
    credentials = {
        'replicate_api_token': os.environ.get('REPLICATE_API_KEY'),
        'model_version': '2b56576fcfbe32fa0526897d8385dd3fb3d36ba6fd0dbe033c72886b81ade93e',
    }
    conversation = [
        SystemPromptMessage(content='You are a helpful AI assistant.'),
        UserPromptMessage(content='Who are you?'),
    ]
    sampling = {
        'temperature': 1.0,
        'top_k': 2,
        'top_p': 0.5,
    }

    stream = llm.invoke(
        model='mistralai/mixtral-8x7b-instruct-v0.1',
        credentials=credentials,
        prompt_messages=conversation,
        model_parameters=sampling,
        stop=['How'],
        stream=True,
        user="abc-123",
    )

    assert isinstance(stream, Generator)
    # Only the types are checked here; chunk content is not inspected.
    for piece in stream:
        assert isinstance(piece, LLMResultChunk)
        assert isinstance(piece.delta, LLMResultChunkDelta)
        assert isinstance(piece.delta.message, AssistantPromptMessage)
def test_get_num_tokens():
    """Expect 14 tokens for a system prompt plus a 'Hello World!' user prompt (empty model name)."""
    llm = ReplicateLargeLanguageModel()
    credentials = {
        'replicate_api_token': os.environ.get('REPLICATE_API_KEY'),
        'model_version': '2b56576fcfbe32fa0526897d8385dd3fb3d36ba6fd0dbe033c72886b81ade93e',
    }
    conversation = [
        SystemPromptMessage(content='You are a helpful AI assistant.'),
        UserPromptMessage(content='Hello World!'),
    ]

    token_count = llm.get_num_tokens(
        model='',
        credentials=credentials,
        prompt_messages=conversation,
    )

    assert token_count == 14

View file

@ -0,0 +1,151 @@
import os
import pytest
from core.model_runtime.entities.text_embedding_entities import TextEmbeddingResult
from core.model_runtime.errors.validate import CredentialsValidateFailedError
from core.model_runtime.model_providers.replicate.text_embedding.text_embedding import ReplicateEmbeddingModel
def test_validate_credentials_one():
    """Reject an invalid Replicate token for replicate/all-mpnet-base-v2, then validate the configured one."""
    model_name = 'replicate/all-mpnet-base-v2'
    model_version = 'b6b7585c9640cd7a9572c6e129c9549d79c9c31f0d3fdce7baac7c67ca38f305'
    embedding_model = ReplicateEmbeddingModel()

    with pytest.raises(CredentialsValidateFailedError):
        embedding_model.validate_credentials(
            model=model_name,
            credentials={
                'replicate_api_token': 'invalid_key',
                'model_version': model_version,
            },
        )

    embedding_model.validate_credentials(
        model=model_name,
        credentials={
            'replicate_api_token': os.environ.get('REPLICATE_API_KEY'),
            'model_version': model_version,
        },
    )
def test_validate_credentials_two():
    """Reject an invalid Replicate token for nateraw/bge-large-en-v1.5, then validate the configured one."""
    model_name = 'nateraw/bge-large-en-v1.5'
    model_version = '9cf9f015a9cb9c61d1a2610659cdac4a4ca222f2d3707a68517b18c198a9add1'
    embedding_model = ReplicateEmbeddingModel()

    with pytest.raises(CredentialsValidateFailedError):
        embedding_model.validate_credentials(
            model=model_name,
            credentials={
                'replicate_api_token': 'invalid_key',
                'model_version': model_version,
            },
        )

    embedding_model.validate_credentials(
        model=model_name,
        credentials={
            'replicate_api_token': os.environ.get('REPLICATE_API_KEY'),
            'model_version': model_version,
        },
    )
def test_invoke_model_one():
    """Embed two texts with nateraw/bge-large-en-v1.5; expect 2 embeddings and 2 total tokens."""
    embedding_model = ReplicateEmbeddingModel()
    credentials = {
        'replicate_api_token': os.environ.get('REPLICATE_API_KEY'),
        'model_version': '9cf9f015a9cb9c61d1a2610659cdac4a4ca222f2d3707a68517b18c198a9add1',
    }

    embedding_result = embedding_model.invoke(
        model='nateraw/bge-large-en-v1.5',
        credentials=credentials,
        texts=["hello", "world"],
        user="abc-123",
    )

    assert isinstance(embedding_result, TextEmbeddingResult)
    assert len(embedding_result.embeddings) == 2
    assert embedding_result.usage.total_tokens == 2
def test_invoke_model_two():
    """Embed two texts with andreasjansson/clip-features; expect 2 embeddings and 2 total tokens."""
    embedding_model = ReplicateEmbeddingModel()
    credentials = {
        'replicate_api_token': os.environ.get('REPLICATE_API_KEY'),
        'model_version': '75b33f253f7714a281ad3e9b28f63e3232d583716ef6718f2e46641077ea040a',
    }

    embedding_result = embedding_model.invoke(
        model='andreasjansson/clip-features',
        credentials=credentials,
        texts=["hello", "world"],
        user="abc-123",
    )

    assert isinstance(embedding_result, TextEmbeddingResult)
    assert len(embedding_result.embeddings) == 2
    assert embedding_result.usage.total_tokens == 2
def test_invoke_model_three():
    """Embed two texts with replicate/all-mpnet-base-v2; expect 2 embeddings and 2 total tokens."""
    embedding_model = ReplicateEmbeddingModel()
    credentials = {
        'replicate_api_token': os.environ.get('REPLICATE_API_KEY'),
        'model_version': 'b6b7585c9640cd7a9572c6e129c9549d79c9c31f0d3fdce7baac7c67ca38f305',
    }

    embedding_result = embedding_model.invoke(
        model='replicate/all-mpnet-base-v2',
        credentials=credentials,
        texts=["hello", "world"],
        user="abc-123",
    )

    assert isinstance(embedding_result, TextEmbeddingResult)
    assert len(embedding_result.embeddings) == 2
    assert embedding_result.usage.total_tokens == 2
def test_invoke_model_four():
    """Embed two texts with nateraw/jina-embeddings-v2-base-en; expect 2 embeddings and 2 total tokens."""
    embedding_model = ReplicateEmbeddingModel()
    credentials = {
        'replicate_api_token': os.environ.get('REPLICATE_API_KEY'),
        'model_version': 'f8367a1c072ba2bc28af549d1faeacfe9b88b3f0e475add7a75091dac507f79e',
    }

    embedding_result = embedding_model.invoke(
        model='nateraw/jina-embeddings-v2-base-en',
        credentials=credentials,
        texts=["hello", "world"],
        user="abc-123",
    )

    assert isinstance(embedding_result, TextEmbeddingResult)
    assert len(embedding_result.embeddings) == 2
    assert embedding_result.usage.total_tokens == 2
def test_get_num_tokens():
    """Expect the Replicate embedding model to count 2 tokens for "hello" and "world"."""
    embedding_model = ReplicateEmbeddingModel()
    credentials = {
        'replicate_api_token': os.environ.get('REPLICATE_API_KEY'),
        'model_version': 'f8367a1c072ba2bc28af549d1faeacfe9b88b3f0e475add7a75091dac507f79e',
    }

    token_count = embedding_model.get_num_tokens(
        model='nateraw/jina-embeddings-v2-base-en',
        credentials=credentials,
        texts=["hello", "world"],
    )

    assert token_count == 2

View file

@ -0,0 +1,114 @@
import os
from typing import Generator
import pytest
from core.model_runtime.entities.message_entities import SystemPromptMessage, UserPromptMessage, AssistantPromptMessage
from core.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk, \
LLMResultChunkDelta
from core.model_runtime.errors.validate import CredentialsValidateFailedError
from core.model_runtime.model_providers.spark.llm.llm import SparkLargeLanguageModel
def test_validate_credentials():
    """Check Spark credential validation with a rejected and an accepted set.

    A dict holding only a bogus ``app_id`` must raise
    ``CredentialsValidateFailedError``; the credentials read from the
    ``SPARK_*`` environment variables must validate without raising.
    """
    llm = SparkLargeLanguageModel()

    with pytest.raises(CredentialsValidateFailedError):
        bad_credentials = {'app_id': 'invalid_key'}
        llm.validate_credentials(model='spark-1.5', credentials=bad_credentials)

    env_credentials = {
        'app_id': os.environ.get('SPARK_APP_ID'),
        'api_secret': os.environ.get('SPARK_API_SECRET'),
        'api_key': os.environ.get('SPARK_API_KEY')
    }
    llm.validate_credentials(model='spark-1.5', credentials=env_credentials)
def test_invoke_model():
    """Run a blocking ``spark-1.5`` completion and expect a non-empty ``LLMResult``."""
    llm = SparkLargeLanguageModel()
    credentials = {
        'app_id': os.environ.get('SPARK_APP_ID'),
        'api_secret': os.environ.get('SPARK_API_SECRET'),
        'api_key': os.environ.get('SPARK_API_KEY')
    }
    question = UserPromptMessage(content='Who are you?')

    result = llm.invoke(
        model='spark-1.5',
        credentials=credentials,
        prompt_messages=[question],
        model_parameters={'temperature': 0.5, 'max_tokens': 10},
        stop=['How'],
        stream=False,
        user="abc-123"
    )

    assert isinstance(result, LLMResult)
    assert len(result.message.content) > 0
def test_invoke_stream_model():
    """Run a streaming ``spark-1.5`` completion and type-check every chunk.

    Each chunk must be an ``LLMResultChunk`` carrying an ``LLMResultChunkDelta``
    with an assistant message; the message content must be non-empty unless
    the chunk reports a finish reason.
    """
    llm = SparkLargeLanguageModel()
    credentials = {
        'app_id': os.environ.get('SPARK_APP_ID'),
        'api_secret': os.environ.get('SPARK_API_SECRET'),
        'api_key': os.environ.get('SPARK_API_KEY')
    }

    stream = llm.invoke(
        model='spark-1.5',
        credentials=credentials,
        prompt_messages=[UserPromptMessage(content='Hello World!')],
        model_parameters={'temperature': 0.5, 'max_tokens': 100},
        stream=True,
        user="abc-123"
    )

    assert isinstance(stream, Generator)
    for chunk in stream:
        assert isinstance(chunk, LLMResultChunk)
        assert isinstance(chunk.delta, LLMResultChunkDelta)
        assert isinstance(chunk.delta.message, AssistantPromptMessage)
        # Only chunks without a finish reason are required to carry content.
        if chunk.delta.finish_reason is None:
            assert len(chunk.delta.message.content) > 0
def test_get_num_tokens():
    """Check that ``spark-1.5`` counts 14 tokens for a system + user prompt pair."""
    llm = SparkLargeLanguageModel()
    credentials = {
        'app_id': os.environ.get('SPARK_APP_ID'),
        'api_secret': os.environ.get('SPARK_API_SECRET'),
        'api_key': os.environ.get('SPARK_API_KEY')
    }
    conversation = [
        SystemPromptMessage(content='You are a helpful AI assistant.'),
        UserPromptMessage(content='Hello World!'),
    ]

    token_count = llm.get_num_tokens(
        model='spark-1.5',
        credentials=credentials,
        prompt_messages=conversation
    )

    assert token_count == 14

View file

@ -0,0 +1,23 @@
import os
import pytest
from core.model_runtime.errors.validate import CredentialsValidateFailedError
from core.model_runtime.model_providers.spark.spark import SparkProvider
def test_validate_provider_credentials():
    """Check Spark provider-level validation.

    An empty credentials dict must raise ``CredentialsValidateFailedError``;
    the credentials read from the ``SPARK_*`` environment variables must be
    accepted without raising.
    """
    spark = SparkProvider()

    with pytest.raises(CredentialsValidateFailedError):
        spark.validate_provider_credentials(credentials={})

    env_credentials = {
        'app_id': os.environ.get('SPARK_APP_ID'),
        'api_secret': os.environ.get('SPARK_API_SECRET'),
        'api_key': os.environ.get('SPARK_API_KEY')
    }
    spark.validate_provider_credentials(credentials=env_credentials)

View file

@ -0,0 +1,82 @@
import logging
import os
from core.model_runtime.entities.model_entities import ModelType
from core.model_runtime.entities.provider_entities import SimpleProviderEntity, ProviderConfig, ProviderEntity
from core.model_runtime.model_providers.model_provider_factory import ModelProviderFactory, ModelProviderExtension
# Module-level logger named after this module; the tests in this file use it for debug output.
logger = logging.getLogger(__name__)
def test_get_providers():
    """Check that the factory lists at least one provider, typed as ``ProviderEntity``."""
    provider_entities = ModelProviderFactory().get_providers()

    # Log each discovered provider for debugging.
    for entity in provider_entities:
        logger.debug(entity)

    assert len(provider_entities) >= 1
    assert isinstance(provider_entities[0], ProviderEntity)
def test_get_models():
    """Exercise ``ModelProviderFactory.get_models`` with two different filters.

    First filters by ``ModelType.LLM`` and checks every returned model has
    that type; then filters by provider ``'openai'`` and checks exactly that
    one provider is returned.
    """
    def _openai_configs():
        # Build a fresh config list per call, as each query did originally.
        return [
            ProviderConfig(
                provider='openai',
                credentials={
                    'openai_api_key': os.environ.get('OPENAI_API_KEY')
                }
            )
        ]

    factory = ModelProviderFactory()

    llm_providers = factory.get_models(
        model_type=ModelType.LLM,
        provider_configs=_openai_configs()
    )
    logger.debug(llm_providers)

    assert len(llm_providers) >= 1
    assert isinstance(llm_providers[0], SimpleProviderEntity)

    # Every model of every returned provider must be an LLM.
    for entity in llm_providers:
        assert all(m.model_type == ModelType.LLM for m in entity.models)

    openai_providers = factory.get_models(
        provider='openai',
        provider_configs=_openai_configs()
    )

    assert len(openai_providers) == 1
    only_provider = openai_providers[0]
    assert isinstance(only_provider, SimpleProviderEntity)
    assert only_provider.provider == 'openai'
def test_provider_credentials_validate():
    """Validate the OpenAI key from ``OPENAI_API_KEY``; passes if nothing is raised."""
    openai_credentials = {
        'openai_api_key': os.environ.get('OPENAI_API_KEY')
    }
    ModelProviderFactory().provider_credentials_validate(
        provider='openai',
        credentials=openai_credentials
    )
def test__get_model_provider_map():
    """Check the factory's provider map is non-empty and has an ``'openai'`` extension entry."""
    provider_map = ModelProviderFactory()._get_model_provider_map()

    # Log each provider name and its instance for debugging.
    for provider_name, extension in provider_map.items():
        logger.debug(provider_name)
        logger.debug(extension.provider_instance)

    assert len(provider_map) >= 1
    assert isinstance(provider_map['openai'], ModelProviderExtension)

View file

@ -0,0 +1,107 @@
import os
from typing import Generator
import pytest
from core.model_runtime.entities.message_entities import SystemPromptMessage, UserPromptMessage, AssistantPromptMessage
from core.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk, \
LLMResultChunkDelta
from core.model_runtime.errors.validate import CredentialsValidateFailedError
from core.model_runtime.model_providers.tongyi.llm.llm import TongyiLargeLanguageModel
def test_validate_credentials():
    """Check Tongyi credential validation with a rejected and an accepted key.

    ``'invalid_key'`` must raise ``CredentialsValidateFailedError``; the key
    from ``TONGYI_DASHSCOPE_API_KEY`` must validate without raising.
    """
    llm = TongyiLargeLanguageModel()

    with pytest.raises(CredentialsValidateFailedError):
        llm.validate_credentials(
            model='qwen-turbo',
            credentials={'dashscope_api_key': 'invalid_key'}
        )

    env_credentials = {
        'dashscope_api_key': os.environ.get('TONGYI_DASHSCOPE_API_KEY')
    }
    llm.validate_credentials(model='qwen-turbo', credentials=env_credentials)
def test_invoke_model():
    """Run a blocking ``qwen-turbo`` completion and expect a non-empty ``LLMResult``."""
    llm = TongyiLargeLanguageModel()
    credentials = {
        'dashscope_api_key': os.environ.get('TONGYI_DASHSCOPE_API_KEY')
    }

    result = llm.invoke(
        model='qwen-turbo',
        credentials=credentials,
        prompt_messages=[UserPromptMessage(content='Who are you?')],
        model_parameters={'temperature': 0.5, 'max_tokens': 10},
        stop=['How'],
        stream=False,
        user="abc-123"
    )

    assert isinstance(result, LLMResult)
    assert len(result.message.content) > 0
def test_invoke_stream_model():
    """Run a streaming ``qwen-turbo`` completion and type-check every chunk.

    The message content of a chunk must be non-empty unless the chunk
    reports a finish reason.
    """
    llm = TongyiLargeLanguageModel()
    credentials = {
        'dashscope_api_key': os.environ.get('TONGYI_DASHSCOPE_API_KEY')
    }
    parameters = {
        'temperature': 0.5,
        'max_tokens': 100,
        'seed': 1234
    }

    stream = llm.invoke(
        model='qwen-turbo',
        credentials=credentials,
        prompt_messages=[UserPromptMessage(content='Hello World!')],
        model_parameters=parameters,
        stream=True,
        user="abc-123"
    )

    assert isinstance(stream, Generator)
    for chunk in stream:
        assert isinstance(chunk, LLMResultChunk)
        assert isinstance(chunk.delta, LLMResultChunkDelta)
        assert isinstance(chunk.delta.message, AssistantPromptMessage)
        # Only chunks without a finish reason are required to carry content.
        if chunk.delta.finish_reason is None:
            assert len(chunk.delta.message.content) > 0
def test_get_num_tokens():
    """Check that ``qwen-turbo`` counts 12 tokens for a system + user prompt pair."""
    llm = TongyiLargeLanguageModel()
    conversation = [
        SystemPromptMessage(content='You are a helpful AI assistant.'),
        UserPromptMessage(content='Hello World!'),
    ]

    token_count = llm.get_num_tokens(
        model='qwen-turbo',
        credentials={
            'dashscope_api_key': os.environ.get('TONGYI_DASHSCOPE_API_KEY')
        },
        prompt_messages=conversation
    )

    assert token_count == 12

View file

@ -0,0 +1,21 @@
import os
import pytest
from core.model_runtime.errors.validate import CredentialsValidateFailedError
from core.model_runtime.model_providers.tongyi.tongyi import TongyiProvider
def test_validate_provider_credentials():
    """Check Tongyi provider-level validation.

    An empty credentials dict must raise ``CredentialsValidateFailedError``;
    the key from ``TONGYI_DASHSCOPE_API_KEY`` must be accepted.
    """
    tongyi = TongyiProvider()

    with pytest.raises(CredentialsValidateFailedError):
        tongyi.validate_provider_credentials(credentials={})

    env_credentials = {
        'dashscope_api_key': os.environ.get('TONGYI_DASHSCOPE_API_KEY')
    }
    tongyi.validate_provider_credentials(credentials=env_credentials)

View file

@ -0,0 +1,271 @@
import os
import pytest
from typing import Generator
from time import sleep
from core.model_runtime.entities.message_entities import AssistantPromptMessage, UserPromptMessage, SystemPromptMessage
from core.model_runtime.entities.model_entities import AIModelEntity
from core.model_runtime.entities.llm_entities import LLMResult, LLMResultChunkDelta, \
LLMResultChunk
from core.model_runtime.errors.validate import CredentialsValidateFailedError
from core.model_runtime.model_providers.wenxin.llm.llm import ErnieBotLarguageModel
def test_predefined_models():
    """Check that the Wenxin LLM exposes at least one predefined ``AIModelEntity`` schema."""
    schemas = ErnieBotLarguageModel().predefined_models()

    assert len(schemas) >= 1
    assert isinstance(schemas[0], AIModelEntity)
def test_validate_credentials_for_chat_model():
    """Check ``ernie-bot`` credential validation with a rejected and an accepted pair.

    A bogus key/secret pair must raise ``CredentialsValidateFailedError``;
    the pair read from the ``WENXIN_*`` environment variables must validate.
    """
    # Fixed pause before touching the API (presumably rate limiting; confirm).
    sleep(3)
    llm = ErnieBotLarguageModel()

    with pytest.raises(CredentialsValidateFailedError):
        bogus = {
            'api_key': 'invalid_key',
            'secret_key': 'invalid_key'
        }
        llm.validate_credentials(model='ernie-bot', credentials=bogus)

    env_credentials = {
        'api_key': os.environ.get('WENXIN_API_KEY'),
        'secret_key': os.environ.get('WENXIN_SECRET_KEY')
    }
    llm.validate_credentials(model='ernie-bot', credentials=env_credentials)
def _check_blocking_invoke(model_name):
    """Run one blocking ``'Hello World!'`` completion against *model_name* and check it.

    Shared body of the per-model blocking-invoke tests below, which were
    four copy-pasted functions differing only in the model name. The name
    lacks the ``test_`` prefix, so pytest does not collect this helper.

    Asserts that the call returns an ``LLMResult`` with non-empty message
    content and a positive total token usage.
    """
    # Fixed pause before touching the API (presumably rate limiting; confirm).
    sleep(3)
    model = ErnieBotLarguageModel()

    response = model.invoke(
        model=model_name,
        credentials={
            'api_key': os.environ.get('WENXIN_API_KEY'),
            'secret_key': os.environ.get('WENXIN_SECRET_KEY')
        },
        prompt_messages=[
            UserPromptMessage(
                content='Hello World!'
            )
        ],
        model_parameters={
            'temperature': 0.7,
            'top_p': 1.0,
        },
        stop=['you'],
        user="abc-123",
        stream=False
    )

    assert isinstance(response, LLMResult)
    assert len(response.message.content) > 0
    assert response.usage.total_tokens > 0


def test_invoke_model_ernie_bot():
    """Blocking invoke of the ``ernie-bot`` model."""
    _check_blocking_invoke('ernie-bot')


def test_invoke_model_ernie_bot_turbo():
    """Blocking invoke of the ``ernie-bot-turbo`` model."""
    _check_blocking_invoke('ernie-bot-turbo')


def test_invoke_model_ernie_8k():
    """Blocking invoke of the ``ernie-bot-8k`` model."""
    _check_blocking_invoke('ernie-bot-8k')


def test_invoke_model_ernie_bot_4():
    """Blocking invoke of the ``ernie-bot-4`` model."""
    _check_blocking_invoke('ernie-bot-4')
def test_invoke_stream_model():
    """Run a streaming ``ernie-bot`` completion and type-check every chunk.

    The message content of a chunk must be non-empty unless the chunk
    reports a finish reason.
    """
    # Fixed pause before touching the API (presumably rate limiting; confirm).
    sleep(3)
    llm = ErnieBotLarguageModel()
    credentials = {
        'api_key': os.environ.get('WENXIN_API_KEY'),
        'secret_key': os.environ.get('WENXIN_SECRET_KEY')
    }

    stream = llm.invoke(
        model='ernie-bot',
        credentials=credentials,
        prompt_messages=[UserPromptMessage(content='Hello World!')],
        model_parameters={'temperature': 0.7, 'top_p': 1.0},
        stop=['you'],
        stream=True,
        user="abc-123"
    )

    assert isinstance(stream, Generator)
    for chunk in stream:
        assert isinstance(chunk, LLMResultChunk)
        assert isinstance(chunk.delta, LLMResultChunkDelta)
        assert isinstance(chunk.delta.message, AssistantPromptMessage)
        # Only chunks without a finish reason are required to carry content.
        if chunk.delta.finish_reason is None:
            assert len(chunk.delta.message.content) > 0
def test_invoke_model_with_system():
    """Check that ``ernie-bot`` honours a system prompt.

    The system message tells the model it is Kasumi and the user asks who
    it is; the lower-cased reply must contain ``'kasumi'``.
    """
    # Fixed pause before touching the API (presumably rate limiting; confirm).
    sleep(3)
    llm = ErnieBotLarguageModel()
    conversation = [
        SystemPromptMessage(content='你是Kasumi'),
        UserPromptMessage(content='你是谁?'),
    ]

    result = llm.invoke(
        model='ernie-bot',
        credentials={
            'api_key': os.environ.get('WENXIN_API_KEY'),
            'secret_key': os.environ.get('WENXIN_SECRET_KEY')
        },
        prompt_messages=conversation,
        model_parameters={'temperature': 0.7, 'top_p': 1.0},
        stop=['you'],
        stream=False,
        user="abc-123"
    )

    assert isinstance(result, LLMResult)
    reply = result.message.content.lower()
    assert 'kasumi' in reply
def test_invoke_with_search():
    """Stream an ``ernie-bot`` weather question with ``disable_search`` set to True.

    Every chunk is type-checked and must carry content unless it reports a
    finish reason. The concatenated reply must then contain one of the
    refusal markers ``'不'``, ``'抱歉'`` or ``'无法'``.
    """
    # Fixed pause before touching the API (presumably rate limiting; confirm).
    sleep(3)
    model = ErnieBotLarguageModel()

    response = model.invoke(
        model='ernie-bot',
        credentials={
            'api_key': os.environ.get('WENXIN_API_KEY'),
            'secret_key': os.environ.get('WENXIN_SECRET_KEY')
        },
        prompt_messages=[
            UserPromptMessage(
                content='北京今天的天气怎么样'
            )
        ],
        model_parameters={
            'temperature': 0.7,
            'top_p': 1.0,
            'disable_search': True,
        },
        stop=[],
        stream=True,
        user="abc-123"
    )

    assert isinstance(response, Generator)

    total_message = ''
    for chunk in response:
        assert isinstance(chunk, LLMResultChunk)
        assert isinstance(chunk.delta, LLMResultChunkDelta)
        assert isinstance(chunk.delta.message, AssistantPromptMessage)
        total_message += chunk.delta.message.content
        print(chunk.delta.message.content)
        assert len(chunk.delta.message.content) > 0 if not chunk.delta.finish_reason else True

    # The reply should contain a refusal such as 对不起 / 我不能 / 不支持, all of
    # which contain '不'. The previous check used the empty string '', which
    # is a substring of every string and made this assertion always pass.
    assert ('不' in total_message or '抱歉' in total_message or '无法' in total_message)
def test_get_num_tokens():
    """Check that ``ernie-bot`` counts 10 tokens for ``'Hello World!'`` with no tools."""
    # Fixed pause before touching the API (presumably rate limiting; confirm).
    sleep(3)
    llm = ErnieBotLarguageModel()
    credentials = {
        'api_key': os.environ.get('WENXIN_API_KEY'),
        'secret_key': os.environ.get('WENXIN_SECRET_KEY')
    }

    token_count = llm.get_num_tokens(
        model='ernie-bot',
        credentials=credentials,
        prompt_messages=[UserPromptMessage(content='Hello World!')],
        tools=[]
    )

    assert isinstance(token_count, int)
    assert token_count == 10

View file

@ -0,0 +1,25 @@
import os
import pytest
from core.model_runtime.errors.validate import CredentialsValidateFailedError
from core.model_runtime.model_providers.wenxin.wenxin import WenxinProvider
def test_validate_provider_credentials():
    """Check Wenxin provider-level validation.

    A made-up key/secret pair must raise ``CredentialsValidateFailedError``;
    the pair read from the ``WENXIN_*`` environment variables must be accepted.
    """
    wenxin = WenxinProvider()

    with pytest.raises(CredentialsValidateFailedError):
        bogus = {
            'api_key': 'hahahaha',
            'secret_key': 'hahahaha'
        }
        wenxin.validate_provider_credentials(credentials=bogus)

    env_credentials = {
        'api_key': os.environ.get('WENXIN_API_KEY'),
        'secret_key': os.environ.get('WENXIN_SECRET_KEY')
    }
    wenxin.validate_provider_credentials(credentials=env_credentials)

View file

@ -0,0 +1,68 @@
import os
import pytest
from core.model_runtime.entities.text_embedding_entities import TextEmbeddingResult
from core.model_runtime.errors.validate import CredentialsValidateFailedError
from core.model_runtime.model_providers.xinference.text_embedding.text_embedding import XinferenceTextEmbeddingModel
from tests.integration_tests.model_runtime.__mock.xinference import setup_xinference_mock, MOCK
@pytest.mark.parametrize('setup_xinference_mock', [['none']], indirect=True)
def test_validate_credentials(setup_xinference_mock):
    """Check Xinference embedding credential validation.

    A model UID prefixed with ``'www '`` must raise
    ``CredentialsValidateFailedError``; the UID taken unchanged from
    ``XINFERENCE_EMBEDDINGS_MODEL_UID`` must validate.
    """
    embedding_model = XinferenceTextEmbeddingModel()

    with pytest.raises(CredentialsValidateFailedError):
        mangled = {
            'server_url': os.environ.get('XINFERENCE_SERVER_URL'),
            'model_uid': 'www ' + os.environ.get('XINFERENCE_EMBEDDINGS_MODEL_UID')
        }
        embedding_model.validate_credentials(model='bge-base-en', credentials=mangled)

    env_credentials = {
        'server_url': os.environ.get('XINFERENCE_SERVER_URL'),
        'model_uid': os.environ.get('XINFERENCE_EMBEDDINGS_MODEL_UID')
    }
    embedding_model.validate_credentials(model='bge-base-en', credentials=env_credentials)
@pytest.mark.parametrize('setup_xinference_mock', [['none']], indirect=True)
def test_invoke_model(setup_xinference_mock):
    """Embed two texts with ``bge-base-en``.

    Expects one embedding per input text and a positive total token usage.
    """
    embedding_model = XinferenceTextEmbeddingModel()
    credentials = {
        'server_url': os.environ.get('XINFERENCE_SERVER_URL'),
        'model_uid': os.environ.get('XINFERENCE_EMBEDDINGS_MODEL_UID')
    }

    result = embedding_model.invoke(
        model='bge-base-en',
        credentials=credentials,
        texts=["hello", "world"],
        user="abc-123"
    )

    assert isinstance(result, TextEmbeddingResult)
    assert len(result.embeddings) == 2
    assert result.usage.total_tokens > 0
def test_get_num_tokens():
    """Check that the Xinference embedding model counts two tokens for two one-word texts."""
    embedding_model = XinferenceTextEmbeddingModel()
    credentials = {
        'server_url': os.environ.get('XINFERENCE_SERVER_URL'),
        'model_uid': os.environ.get('XINFERENCE_EMBEDDINGS_MODEL_UID')
    }

    token_count = embedding_model.get_num_tokens(
        model='bge-base-en',
        credentials=credentials,
        texts=["hello", "world"]
    )

    assert token_count == 2

View file

@ -0,0 +1,392 @@
import os
import pytest
from typing import Generator
from core.model_runtime.entities.message_entities import AssistantPromptMessage, TextPromptMessageContent, UserPromptMessage, \
SystemPromptMessage, PromptMessageTool
from core.model_runtime.entities.model_entities import AIModelEntity
from core.model_runtime.entities.llm_entities import LLMResult, LLMResultChunkDelta, \
LLMResultChunk
from core.model_runtime.errors.validate import CredentialsValidateFailedError
from core.model_runtime.model_providers.xinference.llm.llm import XinferenceAILargeLanguageModel
"""FOR MOCK FIXTURES, DO NOT REMOVE"""
from tests.integration_tests.model_runtime.__mock.openai import setup_openai_mock
from tests.integration_tests.model_runtime.__mock.xinference import setup_xinference_mock
@pytest.mark.parametrize('setup_openai_mock, setup_xinference_mock', [['chat', 'none']], indirect=True)
def test_validate_credentials_for_chat_model(setup_openai_mock, setup_xinference_mock):
    """Check Xinference chat-model credential validation.

    A model UID prefixed with ``'www '`` and an empty URL/UID pair must both
    raise ``CredentialsValidateFailedError``; the values taken unchanged from
    the environment must validate.
    """
    llm = XinferenceAILargeLanguageModel()

    with pytest.raises(CredentialsValidateFailedError):
        mangled = {
            'server_url': os.environ.get('XINFERENCE_SERVER_URL'),
            'model_uid': 'www ' + os.environ.get('XINFERENCE_CHAT_MODEL_UID')
        }
        llm.validate_credentials(model='ChatGLM3', credentials=mangled)

    with pytest.raises(CredentialsValidateFailedError):
        blank = {'server_url': '', 'model_uid': ''}
        llm.validate_credentials(model='aaaaa', credentials=blank)

    env_credentials = {
        'server_url': os.environ.get('XINFERENCE_SERVER_URL'),
        'model_uid': os.environ.get('XINFERENCE_CHAT_MODEL_UID')
    }
    llm.validate_credentials(model='ChatGLM3', credentials=env_credentials)
@pytest.mark.parametrize('setup_openai_mock, setup_xinference_mock', [['chat', 'none']], indirect=True)
def test_invoke_chat_model(setup_openai_mock, setup_xinference_mock):
    """Run a blocking ``ChatGLM3`` chat completion.

    Expects an ``LLMResult`` with non-empty content and positive token usage.
    """
    llm = XinferenceAILargeLanguageModel()
    credentials = {
        'server_url': os.environ.get('XINFERENCE_SERVER_URL'),
        'model_uid': os.environ.get('XINFERENCE_CHAT_MODEL_UID')
    }
    conversation = [
        SystemPromptMessage(content='You are a helpful AI assistant.'),
        UserPromptMessage(content='Hello World!'),
    ]

    result = llm.invoke(
        model='ChatGLM3',
        credentials=credentials,
        prompt_messages=conversation,
        model_parameters={'temperature': 0.7, 'top_p': 1.0},
        stop=['you'],
        user="abc-123",
        stream=False
    )

    assert isinstance(result, LLMResult)
    assert len(result.message.content) > 0
    assert result.usage.total_tokens > 0
@pytest.mark.parametrize('setup_openai_mock, setup_xinference_mock', [['chat', 'none']], indirect=True)
def test_invoke_stream_chat_model(setup_openai_mock, setup_xinference_mock):
    """Run a streaming ``ChatGLM3`` chat completion and type-check every chunk.

    The message content of a chunk must be non-empty unless the chunk
    reports a finish reason.
    """
    llm = XinferenceAILargeLanguageModel()
    credentials = {
        'server_url': os.environ.get('XINFERENCE_SERVER_URL'),
        'model_uid': os.environ.get('XINFERENCE_CHAT_MODEL_UID')
    }
    conversation = [
        SystemPromptMessage(content='You are a helpful AI assistant.'),
        UserPromptMessage(content='Hello World!'),
    ]

    stream = llm.invoke(
        model='ChatGLM3',
        credentials=credentials,
        prompt_messages=conversation,
        model_parameters={'temperature': 0.7, 'top_p': 1.0},
        stop=['you'],
        stream=True,
        user="abc-123"
    )

    assert isinstance(stream, Generator)
    for chunk in stream:
        assert isinstance(chunk, LLMResultChunk)
        assert isinstance(chunk.delta, LLMResultChunkDelta)
        assert isinstance(chunk.delta.message, AssistantPromptMessage)
        # Only chunks without a finish reason are required to carry content.
        if chunk.delta.finish_reason is None:
            assert len(chunk.delta.message.content) > 0
"""
Function calling in Xinference does not currently support stream mode
"""
# def test_invoke_stream_chat_model_with_functions():
# model = XinferenceAILargeLanguageModel()
# response = model.invoke(
# model='ChatGLM3-6b',
# credentials={
# 'server_url': os.environ.get('XINFERENCE_SERVER_URL'),
# 'model_type': 'text-generation',
# 'model_name': 'ChatGLM3',
# 'model_uid': os.environ.get('XINFERENCE_CHAT_MODEL_UID')
# },
# prompt_messages=[
# SystemPromptMessage(
# content='你是一个天气机器人,可以通过调用函数来获取天气信息',
# ),
# UserPromptMessage(
# content='波士顿天气如何?'
# )
# ],
# model_parameters={
# 'temperature': 0,
# 'top_p': 1.0,
# },
# stop=['you'],
# user='abc-123',
# stream=True,
# tools=[
# PromptMessageTool(
# name='get_current_weather',
# description='Get the current weather in a given location',
# parameters={
# "type": "object",
# "properties": {
# "location": {
# "type": "string",
# "description": "The city and state e.g. San Francisco, CA"
# },
# "unit": {
# "type": "string",
# "enum": ["celsius", "fahrenheit"]
# }
# },
# "required": [
# "location"
# ]
# }
# )
# ]
# )
# assert isinstance(response, Generator)
# call: LLMResultChunk = None
# chunks = []
# for chunk in response:
# chunks.append(chunk)
# assert isinstance(chunk, LLMResultChunk)
# assert isinstance(chunk.delta, LLMResultChunkDelta)
# assert isinstance(chunk.delta.message, AssistantPromptMessage)
# assert len(chunk.delta.message.content) > 0 if chunk.delta.finish_reason is None else True
# if chunk.delta.message.tool_calls and len(chunk.delta.message.tool_calls) > 0:
# call = chunk
# break
# assert call is not None
# assert call.delta.message.tool_calls[0].function.name == 'get_current_weather'
# def test_invoke_chat_model_with_functions():
# model = XinferenceAILargeLanguageModel()
# response = model.invoke(
# model='ChatGLM3-6b',
# credentials={
# 'server_url': os.environ.get('XINFERENCE_SERVER_URL'),
# 'model_type': 'text-generation',
# 'model_name': 'ChatGLM3',
# 'model_uid': os.environ.get('XINFERENCE_CHAT_MODEL_UID')
# },
# prompt_messages=[
# UserPromptMessage(
# content='What is the weather like in San Francisco?'
# )
# ],
# model_parameters={
# 'temperature': 0.7,
# 'top_p': 1.0,
# },
# stop=['you'],
# user='abc-123',
# stream=False,
# tools=[
# PromptMessageTool(
# name='get_current_weather',
# description='Get the current weather in a given location',
# parameters={
# "type": "object",
# "properties": {
# "location": {
# "type": "string",
# "description": "The city and state e.g. San Francisco, CA"
# },
# "unit": {
# "type": "string",
# "enum": [
# "c",
# "f"
# ]
# }
# },
# "required": [
# "location"
# ]
# }
# )
# ]
# )
# assert isinstance(response, LLMResult)
# assert len(response.message.content) > 0
# assert response.usage.total_tokens > 0
# assert response.message.tool_calls[0].function.name == 'get_current_weather'
@pytest.mark.parametrize('setup_openai_mock, setup_xinference_mock', [['completion', 'none']], indirect=True)
def test_validate_credentials_for_generation_model(setup_openai_mock, setup_xinference_mock):
    """Check Xinference generation-model credential validation.

    A model UID prefixed with ``'www '`` and an empty URL/UID pair must both
    raise ``CredentialsValidateFailedError``; the values taken unchanged from
    the environment must validate.
    """
    llm = XinferenceAILargeLanguageModel()

    with pytest.raises(CredentialsValidateFailedError):
        mangled = {
            'server_url': os.environ.get('XINFERENCE_SERVER_URL'),
            'model_uid': 'www ' + os.environ.get('XINFERENCE_GENERATION_MODEL_UID')
        }
        llm.validate_credentials(model='alapaca', credentials=mangled)

    with pytest.raises(CredentialsValidateFailedError):
        blank = {'server_url': '', 'model_uid': ''}
        llm.validate_credentials(model='alapaca', credentials=blank)

    env_credentials = {
        'server_url': os.environ.get('XINFERENCE_SERVER_URL'),
        'model_uid': os.environ.get('XINFERENCE_GENERATION_MODEL_UID')
    }
    llm.validate_credentials(model='alapaca', credentials=env_credentials)
@pytest.mark.parametrize('setup_openai_mock, setup_xinference_mock', [['completion', 'none']], indirect=True)
def test_invoke_generation_model(setup_openai_mock, setup_xinference_mock):
    """Run a blocking ``alapaca`` text completion.

    Expects an ``LLMResult`` with non-empty content and positive token usage.
    """
    llm = XinferenceAILargeLanguageModel()
    credentials = {
        'server_url': os.environ.get('XINFERENCE_SERVER_URL'),
        'model_uid': os.environ.get('XINFERENCE_GENERATION_MODEL_UID')
    }

    result = llm.invoke(
        model='alapaca',
        credentials=credentials,
        prompt_messages=[UserPromptMessage(content='the United States is')],
        model_parameters={'temperature': 0.7, 'top_p': 1.0},
        stop=['you'],
        user="abc-123",
        stream=False
    )

    assert isinstance(result, LLMResult)
    assert len(result.message.content) > 0
    assert result.usage.total_tokens > 0
@pytest.mark.parametrize('setup_openai_mock, setup_xinference_mock', [['completion', 'none']], indirect=True)
def test_invoke_stream_generation_model(setup_openai_mock, setup_xinference_mock):
    """Run a streaming ``alapaca`` text completion and type-check every chunk.

    The message content of a chunk must be non-empty unless the chunk
    reports a finish reason.
    """
    llm = XinferenceAILargeLanguageModel()
    credentials = {
        'server_url': os.environ.get('XINFERENCE_SERVER_URL'),
        'model_uid': os.environ.get('XINFERENCE_GENERATION_MODEL_UID')
    }

    stream = llm.invoke(
        model='alapaca',
        credentials=credentials,
        prompt_messages=[UserPromptMessage(content='the United States is')],
        model_parameters={'temperature': 0.7, 'top_p': 1.0},
        stop=['you'],
        stream=True,
        user="abc-123"
    )

    assert isinstance(stream, Generator)
    for chunk in stream:
        assert isinstance(chunk, LLMResultChunk)
        assert isinstance(chunk.delta, LLMResultChunkDelta)
        assert isinstance(chunk.delta.message, AssistantPromptMessage)
        # Only chunks without a finish reason are required to carry content.
        if chunk.delta.finish_reason is None:
            assert len(chunk.delta.message.content) > 0
def test_get_num_tokens():
    """Check Xinference LLM token counting with and without a tool definition.

    The same system + user prompt pair must count 77 tokens when a
    ``get_current_weather`` tool is supplied and 21 tokens without tools.
    """
    def _credentials():
        # Fresh dict per call, as each query built its own originally.
        return {
            'server_url': os.environ.get('XINFERENCE_SERVER_URL'),
            'model_uid': os.environ.get('XINFERENCE_GENERATION_MODEL_UID')
        }

    def _conversation():
        # Fresh message objects per call, as each query built its own originally.
        return [
            SystemPromptMessage(
                content='You are a helpful AI assistant.',
            ),
            UserPromptMessage(
                content='Hello World!'
            )
        ]

    llm = XinferenceAILargeLanguageModel()

    weather_schema = {
        "type": "object",
        "properties": {
            "location": {
                "type": "string",
                "description": "The city and state e.g. San Francisco, CA"
            },
            "unit": {
                "type": "string",
                "enum": [
                    "c",
                    "f"
                ]
            }
        },
        "required": [
            "location"
        ]
    }

    with_tool = llm.get_num_tokens(
        model='ChatGLM3',
        credentials=_credentials(),
        prompt_messages=_conversation(),
        tools=[
            PromptMessageTool(
                name='get_current_weather',
                description='Get the current weather in a given location',
                parameters=weather_schema
            )
        ]
    )

    assert isinstance(with_tool, int)
    assert with_tool == 77

    without_tool = llm.get_num_tokens(
        model='ChatGLM3',
        credentials=_credentials(),
        prompt_messages=_conversation(),
    )

    assert isinstance(without_tool, int)
    assert without_tool == 21

View file

@ -0,0 +1,53 @@
import os
import pytest
from core.model_runtime.entities.rerank_entities import RerankResult
from core.model_runtime.errors.validate import CredentialsValidateFailedError
from core.model_runtime.model_providers.xinference.rerank.rerank import XinferenceRerankModel
from tests.integration_tests.model_runtime.__mock.xinference import setup_xinference_mock, MOCK
@pytest.mark.parametrize('setup_xinference_mock', [['none']], indirect=True)
def test_validate_credentials(setup_xinference_mock):
    """Check Xinference rerank credential validation.

    The server URL ``'awdawdaw'`` must raise
    ``CredentialsValidateFailedError``; the URL from
    ``XINFERENCE_SERVER_URL`` must validate.
    """
    reranker = XinferenceRerankModel()

    with pytest.raises(CredentialsValidateFailedError):
        bad_url = {
            'server_url': 'awdawdaw',
            'model_uid': os.environ.get('XINFERENCE_RERANK_MODEL_UID')
        }
        reranker.validate_credentials(model='bge-reranker-base', credentials=bad_url)

    env_credentials = {
        'server_url': os.environ.get('XINFERENCE_SERVER_URL'),
        'model_uid': os.environ.get('XINFERENCE_RERANK_MODEL_UID')
    }
    reranker.validate_credentials(model='bge-reranker-base', credentials=env_credentials)
@pytest.mark.parametrize('setup_xinference_mock', [['none']], indirect=True)
def test_invoke_model(setup_xinference_mock):
    """Rerank three documents against a query with a 0.8 score threshold.

    Expects exactly one surviving document: the first one (index 0), with a
    score of at least 0.8.
    """
    reranker = XinferenceRerankModel()
    credentials = {
        'server_url': os.environ.get('XINFERENCE_SERVER_URL'),
        'model_uid': os.environ.get('XINFERENCE_RERANK_MODEL_UID')
    }
    candidates = [
        "Kasumi is a girl's name of Japanese origin meaning \"mist\".",
        "Her music is a kawaii bass, a mix of future bass, pop, and kawaii music ",
        "and she leads a team named PopiParty."
    ]

    result = reranker.invoke(
        model='bge-reranker-base',
        credentials=credentials,
        query="Who is Kasumi?",
        docs=candidates,
        score_threshold=0.8
    )

    assert isinstance(result, RerankResult)
    assert len(result.docs) == 1
    top_doc = result.docs[0]
    assert top_doc.index == 0
    assert top_doc.score >= 0.8

View file

@ -0,0 +1,106 @@
import os
from typing import Generator
import pytest
from core.model_runtime.entities.message_entities import SystemPromptMessage, UserPromptMessage, AssistantPromptMessage
from core.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk, \
LLMResultChunkDelta
from core.model_runtime.errors.validate import CredentialsValidateFailedError
from core.model_runtime.model_providers.zhipuai.llm.llm import ZhipuAILargeLanguageModel
def test_validate_credentials():
    """Check ZhipuAI LLM credential validation with a rejected and an accepted key.

    ``'invalid_key'`` must raise ``CredentialsValidateFailedError``; the key
    from ``ZHIPUAI_API_KEY`` must validate without raising.
    """
    llm = ZhipuAILargeLanguageModel()

    with pytest.raises(CredentialsValidateFailedError):
        llm.validate_credentials(
            model='chatglm_turbo',
            credentials={'api_key': 'invalid_key'}
        )

    env_credentials = {'api_key': os.environ.get('ZHIPUAI_API_KEY')}
    llm.validate_credentials(model='chatglm_turbo', credentials=env_credentials)
def test_invoke_model():
    """Run a blocking ``chatglm_turbo`` completion and expect a non-empty ``LLMResult``."""
    llm = ZhipuAILargeLanguageModel()
    credentials = {'api_key': os.environ.get('ZHIPUAI_API_KEY')}

    result = llm.invoke(
        model='chatglm_turbo',
        credentials=credentials,
        prompt_messages=[UserPromptMessage(content='Who are you?')],
        model_parameters={'temperature': 0.9, 'top_p': 0.7},
        stop=['How'],
        stream=False,
        user="abc-123"
    )

    assert isinstance(result, LLMResult)
    assert len(result.message.content) > 0
def test_invoke_stream_model():
    """Run a streaming ``chatglm_turbo`` completion and type-check every chunk.

    The message content of a chunk must be non-empty unless the chunk
    reports a finish reason.
    """
    llm = ZhipuAILargeLanguageModel()
    credentials = {'api_key': os.environ.get('ZHIPUAI_API_KEY')}

    stream = llm.invoke(
        model='chatglm_turbo',
        credentials=credentials,
        prompt_messages=[UserPromptMessage(content='Hello World!')],
        model_parameters={'temperature': 0.9, 'top_p': 0.7},
        stream=True,
        user="abc-123"
    )

    assert isinstance(stream, Generator)
    for chunk in stream:
        assert isinstance(chunk, LLMResultChunk)
        assert isinstance(chunk.delta, LLMResultChunkDelta)
        assert isinstance(chunk.delta.message, AssistantPromptMessage)
        # Only chunks without a finish reason are required to carry content.
        if chunk.delta.finish_reason is None:
            assert len(chunk.delta.message.content) > 0
def test_get_num_tokens():
    """Check that ``chatglm_turbo`` counts 14 tokens for a system + user prompt pair."""
    llm = ZhipuAILargeLanguageModel()
    conversation = [
        SystemPromptMessage(content='You are a helpful AI assistant.'),
        UserPromptMessage(content='Hello World!'),
    ]

    token_count = llm.get_num_tokens(
        model='chatglm_turbo',
        credentials={'api_key': os.environ.get('ZHIPUAI_API_KEY')},
        prompt_messages=conversation
    )

    assert token_count == 14

View file

@ -0,0 +1,20 @@
import os
import pytest
from core.model_runtime.errors.validate import CredentialsValidateFailedError
from core.model_runtime.model_providers.zhipuai.zhipuai import ZhipuaiProvider
def test_validate_provider_credentials():
    """Check ZhipuAI provider-level validation.

    An empty credentials dict must raise ``CredentialsValidateFailedError``;
    the key from ``ZHIPUAI_API_KEY`` must be accepted.
    """
    zhipuai = ZhipuaiProvider()

    with pytest.raises(CredentialsValidateFailedError):
        zhipuai.validate_provider_credentials(credentials={})

    env_credentials = {'api_key': os.environ.get('ZHIPUAI_API_KEY')}
    zhipuai.validate_provider_credentials(credentials=env_credentials)

View file

@ -0,0 +1,63 @@
import os
import pytest
from core.model_runtime.entities.text_embedding_entities import TextEmbeddingResult
from core.model_runtime.errors.validate import CredentialsValidateFailedError
from core.model_runtime.model_providers.zhipuai.text_embedding.text_embedding import ZhipuAITextEmbeddingModel
def test_validate_credentials():
    """Check ZhipuAI embedding credential validation with a rejected and an accepted key.

    ``'invalid_key'`` must raise ``CredentialsValidateFailedError``; the key
    from ``ZHIPUAI_API_KEY`` must validate without raising.
    """
    embedding_model = ZhipuAITextEmbeddingModel()

    with pytest.raises(CredentialsValidateFailedError):
        embedding_model.validate_credentials(
            model='text_embedding',
            credentials={'api_key': 'invalid_key'}
        )

    env_credentials = {'api_key': os.environ.get('ZHIPUAI_API_KEY')}
    embedding_model.validate_credentials(model='text_embedding', credentials=env_credentials)
def test_invoke_model():
    """Embed two texts with ZhipuAI ``text_embedding``.

    Expects one embedding per input text and a reported usage of two tokens.
    """
    embedding_model = ZhipuAITextEmbeddingModel()
    credentials = {'api_key': os.environ.get('ZHIPUAI_API_KEY')}

    result = embedding_model.invoke(
        model='text_embedding',
        credentials=credentials,
        texts=["hello", "world"],
        user="abc-123"
    )

    assert isinstance(result, TextEmbeddingResult)
    assert len(result.embeddings) == 2
    assert result.usage.total_tokens == 2
def test_get_num_tokens():
    """Check that ZhipuAI ``text_embedding`` counts two tokens for two one-word texts."""
    embedding_model = ZhipuAITextEmbeddingModel()
    credentials = {'api_key': os.environ.get('ZHIPUAI_API_KEY')}

    token_count = embedding_model.get_num_tokens(
        model='text_embedding',
        credentials=credentials,
        texts=["hello", "world"]
    )

    assert token_count == 2

View file

@ -1,57 +0,0 @@
import json
import os
from unittest.mock import patch, MagicMock
from core.model_providers.models.entity.model_params import ModelType
from core.model_providers.providers.azure_openai_provider import AzureOpenAIProvider
from core.model_providers.models.embedding.azure_openai_embedding import AzureOpenAIEmbedding
from models.provider import Provider, ProviderType, ProviderModel
def get_mock_provider():
    """Build a valid custom ``azure_openai`` ``Provider`` record with placeholder ids."""
    fields = {
        'id': 'provider_id',
        'tenant_id': 'tenant_id',
        'provider_name': 'azure_openai',
        'provider_type': ProviderType.CUSTOM.value,
        'encrypted_config': '',
        'is_valid': True,
    }
    return Provider(**fields)
def get_mock_azure_openai_embedding_model(mocker):
    """Return an ``AzureOpenAIEmbedding`` for ``text-embedding-ada-002`` backed by a mocked DB query.

    The API base and key are read from ``AZURE_OPENAI_API_BASE`` and
    ``AZURE_OPENAI_API_KEY`` (a ``KeyError`` is raised if either is unset).
    ``extensions.ext_database.db.session.query`` is patched so that
    ``query(...).filter(...).first()`` yields a ``ProviderModel`` carrying
    those credentials as JSON.

    :param mocker: object providing ``patch`` (presumably the pytest-mock fixture; confirm)
    """
    model_name = 'text-embedding-ada-002'
    api_base = os.environ['AZURE_OPENAI_API_BASE']
    api_key = os.environ['AZURE_OPENAI_API_KEY']
    provider = AzureOpenAIProvider(provider=get_mock_provider())

    config_json = json.dumps({
        'openai_api_base': api_base,
        'openai_api_key': api_key,
        'base_model_name': model_name
    })
    provider_model = ProviderModel(
        provider_name='azure_openai',
        model_name=model_name,
        model_type=ModelType.EMBEDDINGS.value,
        encrypted_config=config_json,
        is_valid=True,
    )

    query_stub = MagicMock()
    query_stub.filter.return_value.first.return_value = provider_model
    mocker.patch('extensions.ext_database.db.session.query', return_value=query_stub)

    return AzureOpenAIEmbedding(model_provider=provider, name=model_name)
def decrypt_side_effect(tenant_id, encrypted_openai_api_key):
    """Identity stand-in for token decryption: return the given key unchanged."""
    # tenant_id is accepted but not used.
    return encrypted_openai_api_key
@patch('core.helper.encrypter.decrypt_token', side_effect=decrypt_side_effect)
def test_embedding(mock_decrypt, mocker):
    """A query embedding must come back as a list of 1536 values."""
    client = get_mock_azure_openai_embedding_model(mocker).client
    vector = client.embed_query('test')
    assert isinstance(vector, list)
    assert len(vector) == 1536

View file

@ -1,136 +0,0 @@
import json
import os
from unittest.mock import patch, MagicMock
from core.model_providers.models.entity.model_params import ModelType
from core.model_providers.models.embedding.huggingface_embedding import HuggingfaceEmbedding
from core.model_providers.providers.huggingface_hub_provider import HuggingfaceHubProvider
from models.provider import Provider, ProviderType, ProviderModel
DEFAULT_MODEL_NAME = 'obrizum/all-MiniLM-L6-v2'
def get_mock_provider():
    """Construct a valid custom-type ``huggingface_hub`` Provider with an empty config."""
    custom_type = ProviderType.CUSTOM.value
    provider = Provider(
        id='provider_id',
        tenant_id='tenant_id',
        provider_name='huggingface_hub',
        provider_type=custom_type,
        encrypted_config='',
        is_valid=True,
    )
    return provider
def get_mock_embedding_model(model_name, huggingfacehub_api_type, mocker):
    """Return a HuggingfaceEmbedding backed by a mocked provider-model lookup.

    The API token and endpoint URL are read from ``HUGGINGFACE_API_KEY`` and
    ``HUGGINGFACE_EMBEDDINGS_ENDPOINT_URL``. The endpoint URL is only added
    to the credentials when the API type is ``inference_endpoints``.
    """
    api_token = os.environ['HUGGINGFACE_API_KEY']
    endpoint_url = os.environ['HUGGINGFACE_EMBEDDINGS_ENDPOINT_URL']
    hub_provider = HuggingfaceHubProvider(provider=get_mock_provider())

    credentials = {
        'huggingfacehub_api_type': huggingfacehub_api_type,
        'huggingfacehub_api_token': api_token,
        'task_type': 'feature-extraction'
    }
    uses_endpoint = huggingfacehub_api_type == 'inference_endpoints'
    if uses_endpoint:
        credentials['huggingfacehub_endpoint_url'] = endpoint_url

    provider_model = ProviderModel(
        provider_name='huggingface_hub',
        model_name=model_name,
        model_type=ModelType.EMBEDDINGS.value,
        encrypted_config=json.dumps(credentials),
        is_valid=True,
    )
    query_mock = MagicMock()
    query_mock.filter.return_value.first.return_value = provider_model
    mocker.patch('extensions.ext_database.db.session.query',
                 return_value=query_mock)

    return HuggingfaceEmbedding(model_provider=hub_provider, name=model_name)
def decrypt_side_effect(tenant_id, encrypted_api_key):
    """Identity stand-in for token decryption: return the given key unchanged."""
    # tenant_id is accepted but not used.
    return encrypted_api_key
@patch('core.helper.encrypter.decrypt_token', side_effect=decrypt_side_effect)
def test_hosted_inference_api_embed_documents(mock_decrypt, mocker):
    """Hosted inference API: two documents give two 384-value embeddings."""
    model = get_mock_embedding_model(DEFAULT_MODEL_NAME, 'hosted_inference_api', mocker)
    vectors = model.client.embed_documents(['test', 'test1'])
    assert isinstance(vectors, list)
    assert len(vectors) == 2
    assert len(vectors[0]) == 384
@patch('core.helper.encrypter.decrypt_token', side_effect=decrypt_side_effect)
def test_endpoint_url_inference_api_embed_documents(mock_decrypt, mocker):
    """Inference endpoint with a stubbed flat response: one vector per document."""
    model = get_mock_embedding_model('', 'inference_endpoints', mocker)

    # Stub the HTTP call with a JSON payload of two 3-value vectors.
    canned_response = json.dumps([[1, 2, 3], [4, 5, 6]]).encode('utf-8')
    mocker.patch(
        'core.third_party.langchain.embeddings.huggingface_hub_embedding.InferenceClient.post',
        return_value=canned_response,
    )

    vectors = model.client.embed_documents(['test', 'test1'])
    assert isinstance(vectors, list)
    assert len(vectors) == 2
    assert len(vectors[0]) == 3
@patch('core.helper.encrypter.decrypt_token', side_effect=decrypt_side_effect)
def test_endpoint_url_inference_api_embed_documents_two(mock_decrypt, mocker):
    """Inference endpoint with a stubbed nested response: still 2 results of length 3."""
    model = get_mock_embedding_model('', 'inference_endpoints', mocker)

    # Each document's entry wraps a 3x3 matrix in one extra list level.
    matrix = [[1, 2, 3], [4, 5, 6], [7, 8, 9]]
    canned_response = json.dumps([[matrix], [matrix]]).encode('utf-8')
    mocker.patch(
        'core.third_party.langchain.embeddings.huggingface_hub_embedding.InferenceClient.post',
        return_value=canned_response,
    )

    vectors = model.client.embed_documents(['test', 'test1'])
    assert isinstance(vectors, list)
    assert len(vectors) == 2
    assert len(vectors[0]) == 3
@patch('core.helper.encrypter.decrypt_token', side_effect=decrypt_side_effect)
def test_hosted_inference_api_embed_query(mock_decrypt, mocker):
    """Hosted inference API: a single query gives one 384-value embedding."""
    model = get_mock_embedding_model(DEFAULT_MODEL_NAME, 'hosted_inference_api', mocker)
    vector = model.client.embed_query('test')
    assert isinstance(vector, list)
    assert len(vector) == 384
@patch('core.helper.encrypter.decrypt_token', side_effect=decrypt_side_effect)
def test_endpoint_url_inference_api_embed_query(mock_decrypt, mocker):
    """Inference endpoint with a stubbed flat response: the query vector has 3 values."""
    model = get_mock_embedding_model('', 'inference_endpoints', mocker)

    canned_response = json.dumps([[1, 2, 3]]).encode('utf-8')
    mocker.patch(
        'core.third_party.langchain.embeddings.huggingface_hub_embedding.InferenceClient.post',
        return_value=canned_response,
    )

    vector = model.client.embed_query('test')
    assert isinstance(vector, list)
    assert len(vector) == 3
@patch('core.helper.encrypter.decrypt_token', side_effect=decrypt_side_effect)
def test_endpoint_url_inference_api_embed_query_two(mock_decrypt, mocker):
    """Inference endpoint with a stubbed nested response: the query vector has 3 values."""
    model = get_mock_embedding_model('', 'inference_endpoints', mocker)

    # The single entry wraps a 3x3 matrix in one extra list level.
    matrix = [[1, 2, 3], [4, 5, 6], [7, 8, 9]]
    canned_response = json.dumps([[matrix]]).encode('utf-8')
    mocker.patch(
        'core.third_party.langchain.embeddings.huggingface_hub_embedding.InferenceClient.post',
        return_value=canned_response,
    )

    vector = model.client.embed_query('test')
    assert isinstance(vector, list)
    assert len(vector) == 3

View file

@ -1,42 +0,0 @@
import json
import os
from unittest.mock import patch
from core.model_providers.models.embedding.jina_embedding import JinaEmbedding
from core.model_providers.providers.jina_provider import JinaProvider
from models.provider import Provider, ProviderType
def get_mock_provider(valid_api_key):
    """Construct a valid custom-type ``jina`` Provider whose config JSON holds the API key."""
    config_json = json.dumps({
        'api_key': valid_api_key
    })
    provider = Provider(
        id='provider_id',
        tenant_id='tenant_id',
        provider_name='jina',
        provider_type=ProviderType.CUSTOM.value,
        encrypted_config=config_json,
        is_valid=True,
    )
    return provider
def get_mock_embedding_model():
    """Return a JinaEmbedding for ``jina-embeddings-v2-small-en`` using ``JINA_API_KEY``."""
    model_name = 'jina-embeddings-v2-small-en'
    api_key = os.environ['JINA_API_KEY']
    jina_provider = JinaProvider(provider=get_mock_provider(api_key))
    return JinaEmbedding(model_provider=jina_provider, name=model_name)
def decrypt_side_effect(tenant_id, encrypted_api_key):
    """Identity stand-in for token decryption: return the given key unchanged."""
    # tenant_id is accepted but not used.
    return encrypted_api_key
@patch('core.helper.encrypter.decrypt_token', side_effect=decrypt_side_effect)
def test_embedding(mock_decrypt):
    """A query embedding must come back as a list of 512 values."""
    client = get_mock_embedding_model().client
    vector = client.embed_query('test')
    assert isinstance(vector, list)
    assert len(vector) == 512

View file

@ -1,61 +0,0 @@
import json
import os
from unittest.mock import patch, MagicMock
from core.model_providers.models.embedding.localai_embedding import LocalAIEmbedding
from core.model_providers.models.entity.model_params import ModelType
from core.model_providers.providers.localai_provider import LocalAIProvider
from models.provider import Provider, ProviderType, ProviderModel
def get_mock_provider():
    """Construct a valid custom-type ``localai`` Provider with an empty config."""
    custom_type = ProviderType.CUSTOM.value
    provider = Provider(
        id='provider_id',
        tenant_id='tenant_id',
        provider_name='localai',
        provider_type=custom_type,
        encrypted_config='',
        is_valid=True,
    )
    return provider
def get_mock_embedding_model(mocker):
    """Return a LocalAIEmbedding backed by a mocked provider-model lookup.

    The server URL is read from ``LOCALAI_SERVER_URL`` and stored as JSON on
    the ProviderModel that the patched ``db.session.query`` yields.
    """
    model_name = 'text-embedding-ada-002'
    server_url = os.environ['LOCALAI_SERVER_URL']
    localai_provider = LocalAIProvider(provider=get_mock_provider())

    stored_config = json.dumps({
        'server_url': server_url,
    })
    provider_model = ProviderModel(
        provider_name='localai',
        model_name=model_name,
        model_type=ModelType.EMBEDDINGS.value,
        encrypted_config=stored_config,
        is_valid=True,
    )

    query_mock = MagicMock()
    query_mock.filter.return_value.first.return_value = provider_model
    mocker.patch('extensions.ext_database.db.session.query', return_value=query_mock)

    return LocalAIEmbedding(model_provider=localai_provider, name=model_name)
def decrypt_side_effect(tenant_id, encrypted_api_key):
    """Identity stand-in for token decryption: return the given key unchanged."""
    # tenant_id is accepted but not used.
    return encrypted_api_key
@patch('core.helper.encrypter.decrypt_token', side_effect=decrypt_side_effect)
def test_embed_documents(mock_decrypt, mocker):
    """Two documents must yield a list of two embeddings."""
    client = get_mock_embedding_model(mocker).client
    vectors = client.embed_documents(['test', 'test1'])
    assert isinstance(vectors, list)
    assert len(vectors) == 2
@patch('core.helper.encrypter.decrypt_token', side_effect=decrypt_side_effect)
def test_embed_query(mock_decrypt, mocker):
    """A query embedding must come back as a list."""
    client = get_mock_embedding_model(mocker).client
    vector = client.embed_query('test')
    assert isinstance(vector, list)

View file

@ -1,44 +0,0 @@
import json
import os
from unittest.mock import patch
from core.model_providers.models.embedding.minimax_embedding import MinimaxEmbedding
from core.model_providers.providers.minimax_provider import MinimaxProvider
from models.provider import Provider, ProviderType
def get_mock_provider(valid_group_id, valid_api_key):
    """Construct a valid custom-type ``minimax`` Provider from a group id and API key."""
    config_json = json.dumps({
        'minimax_group_id': valid_group_id,
        'minimax_api_key': valid_api_key
    })
    provider = Provider(
        id='provider_id',
        tenant_id='tenant_id',
        provider_name='minimax',
        provider_type=ProviderType.CUSTOM.value,
        encrypted_config=config_json,
        is_valid=True,
    )
    return provider
def get_mock_embedding_model():
    """Return a MinimaxEmbedding for ``embo-01`` using MINIMAX_API_KEY / MINIMAX_GROUP_ID."""
    model_name = 'embo-01'
    api_key = os.environ['MINIMAX_API_KEY']
    group_id = os.environ['MINIMAX_GROUP_ID']
    minimax_provider = MinimaxProvider(provider=get_mock_provider(group_id, api_key))
    return MinimaxEmbedding(model_provider=minimax_provider, name=model_name)
def decrypt_side_effect(tenant_id, encrypted_api_key):
    """Identity stand-in for token decryption: return the given key unchanged."""
    # tenant_id is accepted but not used.
    return encrypted_api_key
@patch('core.helper.encrypter.decrypt_token', side_effect=decrypt_side_effect)
def test_embedding(mock_decrypt):
    """A query embedding must come back as a list of 1536 values."""
    client = get_mock_embedding_model().client
    vector = client.embed_query('test')
    assert isinstance(vector, list)
    assert len(vector) == 1536

View file

@ -1,40 +0,0 @@
import json
import os
from unittest.mock import patch
from core.model_providers.providers.openai_provider import OpenAIProvider
from core.model_providers.models.embedding.openai_embedding import OpenAIEmbedding
from models.provider import Provider, ProviderType
def get_mock_provider(valid_openai_api_key):
    """Construct a valid custom-type ``openai`` Provider whose config JSON holds the API key."""
    config_json = json.dumps({'openai_api_key': valid_openai_api_key})
    provider = Provider(
        id='provider_id',
        tenant_id='tenant_id',
        provider_name='openai',
        provider_type=ProviderType.CUSTOM.value,
        encrypted_config=config_json,
        is_valid=True,
    )
    return provider
def get_mock_openai_embedding_model():
    """Return an OpenAIEmbedding for ``text-embedding-ada-002`` using ``OPENAI_API_KEY``."""
    model_name = 'text-embedding-ada-002'
    api_key = os.environ['OPENAI_API_KEY']
    provider = OpenAIProvider(provider=get_mock_provider(api_key))
    return OpenAIEmbedding(model_provider=provider, name=model_name)
def decrypt_side_effect(tenant_id, encrypted_openai_api_key):
    """Identity stand-in for token decryption: return the given key unchanged."""
    # tenant_id is accepted but not used.
    return encrypted_openai_api_key
@patch('core.helper.encrypter.decrypt_token', side_effect=decrypt_side_effect)
def test_embedding(mock_decrypt):
    """A query embedding must come back as a list of 1536 values."""
    client = get_mock_openai_embedding_model().client
    vector = client.embed_query('test')
    assert isinstance(vector, list)
    assert len(vector) == 1536

View file

@ -1,63 +0,0 @@
import json
import os
from unittest.mock import patch, MagicMock
from core.model_providers.models.embedding.openllm_embedding import OpenLLMEmbedding
from core.model_providers.models.entity.model_params import ModelType
from core.model_providers.providers.openllm_provider import OpenLLMProvider
from models.provider import Provider, ProviderType, ProviderModel
def get_mock_provider():
    """Construct a valid custom-type ``openllm`` Provider with an empty config."""
    custom_type = ProviderType.CUSTOM.value
    provider = Provider(
        id='provider_id',
        tenant_id='tenant_id',
        provider_name='openllm',
        provider_type=custom_type,
        encrypted_config='',
        is_valid=True,
    )
    return provider
def get_mock_embedding_model(mocker):
    """Return an OpenLLMEmbedding backed by a mocked provider-model lookup.

    The server URL is read from ``OPENLLM_SERVER_URL`` and stored as JSON on
    the ProviderModel that the patched ``db.session.query`` yields.
    """
    model_name = 'facebook/opt-125m'
    server_url = os.environ['OPENLLM_SERVER_URL']
    openllm_provider = OpenLLMProvider(provider=get_mock_provider())

    stored_config = json.dumps({
        'server_url': server_url
    })
    provider_model = ProviderModel(
        provider_name='openllm',
        model_name=model_name,
        model_type=ModelType.EMBEDDINGS.value,
        encrypted_config=stored_config,
        is_valid=True,
    )

    query_mock = MagicMock()
    query_mock.filter.return_value.first.return_value = provider_model
    mocker.patch('extensions.ext_database.db.session.query', return_value=query_mock)

    return OpenLLMEmbedding(model_provider=openllm_provider, name=model_name)
def decrypt_side_effect(tenant_id, encrypted_api_key):
    """Identity stand-in for token decryption: return the given key unchanged."""
    # tenant_id is accepted but not used.
    return encrypted_api_key
@patch('core.helper.encrypter.decrypt_token', side_effect=decrypt_side_effect)
def test_embed_documents(mock_decrypt, mocker):
    """Two documents must yield two non-empty embeddings."""
    client = get_mock_embedding_model(mocker).client
    vectors = client.embed_documents(['test', 'test1'])
    assert isinstance(vectors, list)
    assert len(vectors) == 2
    assert len(vectors[0]) > 0
@patch('core.helper.encrypter.decrypt_token', side_effect=decrypt_side_effect)
def test_embed_query(mock_decrypt, mocker):
    """A query embedding must come back as a non-empty list."""
    client = get_mock_embedding_model(mocker).client
    vector = client.embed_query('test')
    assert isinstance(vector, list)
    assert len(vector) > 0

View file

@ -1,64 +0,0 @@
import json
import os
from unittest.mock import patch, MagicMock
from core.model_providers.models.embedding.replicate_embedding import ReplicateEmbedding
from core.model_providers.models.entity.model_params import ModelType
from core.model_providers.providers.replicate_provider import ReplicateProvider
from models.provider import Provider, ProviderType, ProviderModel
def get_mock_provider():
    """Construct a valid custom-type ``replicate`` Provider with an empty config."""
    custom_type = ProviderType.CUSTOM.value
    provider = Provider(
        id='provider_id',
        tenant_id='tenant_id',
        provider_name='replicate',
        provider_type=custom_type,
        encrypted_config='',
        is_valid=True,
    )
    return provider
def get_mock_embedding_model(mocker):
    """Return a ReplicateEmbedding backed by a mocked provider-model lookup.

    The API token is read from ``REPLICATE_API_TOKEN``; it and a pinned model
    version are stored as JSON on the ProviderModel that the patched
    ``db.session.query`` yields.
    """
    model_name = 'replicate/all-mpnet-base-v2'
    api_token = os.environ['REPLICATE_API_TOKEN']
    replicate_provider = ReplicateProvider(provider=get_mock_provider())

    stored_config = json.dumps({
        'replicate_api_token': api_token,
        'model_version': 'b6b7585c9640cd7a9572c6e129c9549d79c9c31f0d3fdce7baac7c67ca38f305'
    })
    provider_model = ProviderModel(
        provider_name='replicate',
        model_name=model_name,
        model_type=ModelType.EMBEDDINGS.value,
        encrypted_config=stored_config,
        is_valid=True,
    )

    query_mock = MagicMock()
    query_mock.filter.return_value.first.return_value = provider_model
    mocker.patch('extensions.ext_database.db.session.query', return_value=query_mock)

    return ReplicateEmbedding(model_provider=replicate_provider, name=model_name)
def decrypt_side_effect(tenant_id, encrypted_api_key):
    """Identity stand-in for token decryption: return the given key unchanged."""
    # tenant_id is accepted but not used.
    return encrypted_api_key
@patch('core.helper.encrypter.decrypt_token', side_effect=decrypt_side_effect)
def test_embed_documents(mock_decrypt, mocker):
    """Two documents must yield two 768-value embeddings."""
    client = get_mock_embedding_model(mocker).client
    vectors = client.embed_documents(['test', 'test1'])
    assert isinstance(vectors, list)
    assert len(vectors) == 2
    assert len(vectors[0]) == 768
@patch('core.helper.encrypter.decrypt_token', side_effect=decrypt_side_effect)
def test_embed_query(mock_decrypt, mocker):
    """A query embedding must come back as a list of 768 values."""
    client = get_mock_embedding_model(mocker).client
    vector = client.embed_query('test')
    assert isinstance(vector, list)
    assert len(vector) == 768

View file

@ -1,65 +0,0 @@
import json
import os
from unittest.mock import patch, MagicMock
from core.model_providers.models.embedding.xinference_embedding import XinferenceEmbedding
from core.model_providers.models.entity.model_params import ModelType
from core.model_providers.providers.xinference_provider import XinferenceProvider
from models.provider import Provider, ProviderType, ProviderModel
def get_mock_provider():
    """Construct a valid custom-type ``xinference`` Provider with an empty config."""
    custom_type = ProviderType.CUSTOM.value
    provider = Provider(
        id='provider_id',
        tenant_id='tenant_id',
        provider_name='xinference',
        provider_type=custom_type,
        encrypted_config='',
        is_valid=True,
    )
    return provider
def get_mock_embedding_model(mocker):
    """Return a XinferenceEmbedding backed by a mocked provider-model lookup.

    The server URL and model UID are read from ``XINFERENCE_SERVER_URL`` and
    ``XINFERENCE_MODEL_UID`` and stored as JSON on the ProviderModel that the
    patched ``db.session.query`` yields.
    """
    model_name = 'vicuna-v1.3'
    server_url = os.environ['XINFERENCE_SERVER_URL']
    model_uid = os.environ['XINFERENCE_MODEL_UID']
    xinference_provider = XinferenceProvider(provider=get_mock_provider())

    stored_config = json.dumps({
        'server_url': server_url,
        'model_uid': model_uid
    })
    provider_model = ProviderModel(
        provider_name='xinference',
        model_name=model_name,
        model_type=ModelType.EMBEDDINGS.value,
        encrypted_config=stored_config,
        is_valid=True,
    )

    query_mock = MagicMock()
    query_mock.filter.return_value.first.return_value = provider_model
    mocker.patch('extensions.ext_database.db.session.query', return_value=query_mock)

    return XinferenceEmbedding(model_provider=xinference_provider, name=model_name)
def decrypt_side_effect(tenant_id, encrypted_api_key):
    """Identity stand-in for token decryption: return the given key unchanged."""
    # tenant_id is accepted but not used.
    return encrypted_api_key
@patch('core.helper.encrypter.decrypt_token', side_effect=decrypt_side_effect)
def test_embed_documents(mock_decrypt, mocker):
    """Two documents must yield two 4096-value embeddings."""
    client = get_mock_embedding_model(mocker).client
    vectors = client.embed_documents(['test', 'test1'])
    assert isinstance(vectors, list)
    assert len(vectors) == 2
    assert len(vectors[0]) == 4096
@patch('core.helper.encrypter.decrypt_token', side_effect=decrypt_side_effect)
def test_embed_query(mock_decrypt, mocker):
    """A query embedding must come back as a list of 4096 values."""
    client = get_mock_embedding_model(mocker).client
    vector = client.embed_query('test')
    assert isinstance(vector, list)
    assert len(vector) == 4096

View file

@ -1,50 +0,0 @@
import json
import os
from unittest.mock import patch
from core.model_providers.models.embedding.zhipuai_embedding import ZhipuAIEmbedding
from core.model_providers.providers.zhipuai_provider import ZhipuAIProvider
from models.provider import Provider, ProviderType
def get_mock_provider(valid_api_key):
    """Construct a valid custom-type ``zhipuai`` Provider whose config JSON holds the API key."""
    config_json = json.dumps({
        'api_key': valid_api_key
    })
    provider = Provider(
        id='provider_id',
        tenant_id='tenant_id',
        provider_name='zhipuai',
        provider_type=ProviderType.CUSTOM.value,
        encrypted_config=config_json,
        is_valid=True,
    )
    return provider
def get_mock_embedding_model():
    """Return a ZhipuAIEmbedding for ``text_embedding`` using ``ZHIPUAI_API_KEY``."""
    model_name = 'text_embedding'
    api_key = os.environ['ZHIPUAI_API_KEY']
    zhipuai_provider = ZhipuAIProvider(provider=get_mock_provider(api_key))
    return ZhipuAIEmbedding(model_provider=zhipuai_provider, name=model_name)
def decrypt_side_effect(tenant_id, encrypted_api_key):
    """Identity stand-in for token decryption: return the given key unchanged."""
    # tenant_id is accepted but not used.
    return encrypted_api_key
@patch('core.helper.encrypter.decrypt_token', side_effect=decrypt_side_effect)
def test_embedding(mock_decrypt):
    """A query embedding must come back as a list of 1024 values."""
    client = get_mock_embedding_model().client
    vector = client.embed_query('test')
    assert isinstance(vector, list)
    assert len(vector) == 1024
@patch('core.helper.encrypter.decrypt_token', side_effect=decrypt_side_effect)
def test_doc_embedding(mock_decrypt):
    """Document embeddings come back as a list whose first vector has 1024 values."""
    client = get_mock_embedding_model().client
    vectors = client.embed_documents(['test', 'test2'])
    assert isinstance(vectors, list)
    assert len(vectors[0]) == 1024

View file

@ -1,62 +0,0 @@
import json
import os
from unittest.mock import patch
from langchain.schema import ChatGeneration, AIMessage
from core.model_providers.models.entity.message import PromptMessage, MessageType
from core.model_providers.models.entity.model_params import ModelKwargs
from core.model_providers.models.llm.anthropic_model import AnthropicModel
from core.model_providers.providers.anthropic_provider import AnthropicProvider
from models.provider import Provider, ProviderType
def get_mock_provider(valid_api_key):
    """Construct a valid custom-type ``anthropic`` Provider whose config JSON holds the API key."""
    config_json = json.dumps({'anthropic_api_key': valid_api_key})
    provider = Provider(
        id='provider_id',
        tenant_id='tenant_id',
        provider_name='anthropic',
        provider_type=ProviderType.CUSTOM.value,
        encrypted_config=config_json,
        is_valid=True,
    )
    return provider
def get_mock_model(model_name):
    """Return an AnthropicModel for ``model_name`` (max_tokens=10, temperature=0).

    The API key is read from ``ANTHROPIC_API_KEY``.
    """
    generation_kwargs = ModelKwargs(
        max_tokens=10,
        temperature=0
    )
    api_key = os.environ['ANTHROPIC_API_KEY']
    anthropic_provider = AnthropicProvider(provider=get_mock_provider(api_key))
    return AnthropicModel(
        model_provider=anthropic_provider,
        name=model_name,
        model_kwargs=generation_kwargs
    )
def decrypt_side_effect(tenant_id, encrypted_api_key):
    """Identity stand-in for token decryption: return the given key unchanged."""
    # tenant_id is accepted but not used.
    return encrypted_api_key
@patch('core.helper.encrypter.decrypt_token', side_effect=decrypt_side_effect)
def test_get_num_tokens(mock_decrypt):
    """The sample user question is expected to count as 6 tokens for claude-2."""
    model = get_mock_model('claude-2')
    prompt = [
        PromptMessage(type=MessageType.USER, content='Who is your manufacturer?')
    ]
    token_count = model.get_num_tokens(prompt)
    assert token_count == 6
@patch('core.helper.encrypter.decrypt_token', side_effect=decrypt_side_effect)
def test_run(mock_decrypt, mocker):
    """Running claude-2 on a simple prompt must produce non-empty content."""
    # Stub out the provider's update_last_used call.
    mocker.patch('core.model_providers.providers.base.BaseModelProvider.update_last_used', return_value=None)

    model = get_mock_model('claude-2')
    prompt = [PromptMessage(content='Human: 1 + 1=? \nAssistant: ')]
    response = model.run(prompt, stop=['\nHuman:'])
    assert len(response.content) > 0

View file

@ -1,87 +0,0 @@
import json
import os
from unittest.mock import patch, MagicMock
import pytest
from langchain.schema import ChatGeneration, AIMessage
from core.model_providers.models.entity.model_params import ModelKwargs, ModelType
from core.model_providers.models.llm.azure_openai_model import AzureOpenAIModel
from core.model_providers.models.entity.message import PromptMessage, MessageType
from core.model_providers.providers.azure_openai_provider import AzureOpenAIProvider
from models.provider import Provider, ProviderType, ProviderModel
def get_mock_provider():
    """Construct a valid custom-type ``azure_openai`` Provider with an empty config."""
    custom_type = ProviderType.CUSTOM.value
    provider = Provider(
        id='provider_id',
        tenant_id='tenant_id',
        provider_name='azure_openai',
        provider_type=custom_type,
        encrypted_config='',
        is_valid=True,
    )
    return provider
def get_mock_azure_openai_model(model_name, mocker):
    """Return an AzureOpenAIModel backed by a mocked provider-model lookup.

    Uses max_tokens=10 and temperature=0. The API base and key come from
    ``AZURE_OPENAI_API_BASE`` and ``AZURE_OPENAI_API_KEY``, and
    ``db.session.query`` is patched to yield a ProviderModel carrying them.
    """
    generation_kwargs = ModelKwargs(
        max_tokens=10,
        temperature=0
    )
    api_base = os.environ['AZURE_OPENAI_API_BASE']
    api_key = os.environ['AZURE_OPENAI_API_KEY']
    azure_provider = AzureOpenAIProvider(provider=get_mock_provider())

    stored_config = json.dumps({
        'openai_api_base': api_base,
        'openai_api_key': api_key,
        'base_model_name': model_name
    })
    provider_model = ProviderModel(
        provider_name='azure_openai',
        model_name=model_name,
        model_type=ModelType.TEXT_GENERATION.value,
        encrypted_config=stored_config,
        is_valid=True,
    )

    query_mock = MagicMock()
    query_mock.filter.return_value.first.return_value = provider_model
    mocker.patch('extensions.ext_database.db.session.query', return_value=query_mock)

    return AzureOpenAIModel(
        model_provider=azure_provider,
        name=model_name,
        model_kwargs=generation_kwargs
    )
def decrypt_side_effect(tenant_id, encrypted_openai_api_key):
    """Identity stand-in for token decryption: return the given key unchanged."""
    # tenant_id is accepted but not used.
    return encrypted_openai_api_key
@patch('core.helper.encrypter.decrypt_token', side_effect=decrypt_side_effect)
def test_get_num_tokens(mock_decrypt, mocker):
    """The sample prompt is expected to count as 6 tokens for text-davinci-003."""
    model = get_mock_azure_openai_model('text-davinci-003', mocker)
    prompt = [PromptMessage(content='you are a kindness Assistant.')]
    token_count = model.get_num_tokens(prompt)
    assert token_count == 6
@patch('core.helper.encrypter.decrypt_token', side_effect=decrypt_side_effect)
def test_chat_get_num_tokens(mock_decrypt, mocker):
    """The system + user message pair is expected to count as 22 tokens for gpt-35-turbo."""
    model = get_mock_azure_openai_model('gpt-35-turbo', mocker)
    conversation = [
        PromptMessage(type=MessageType.SYSTEM, content='you are a kindness Assistant.'),
        PromptMessage(type=MessageType.USER, content='Who is your manufacturer?')
    ]
    token_count = model.get_num_tokens(conversation)
    assert token_count == 22
@patch('core.helper.encrypter.decrypt_token', side_effect=decrypt_side_effect)
def test_run(mock_decrypt, mocker):
    """Running gpt-35-turbo on a y/n prompt must produce non-empty content."""
    # Stub out the provider's update_last_used call.
    mocker.patch('core.model_providers.providers.base.BaseModelProvider.update_last_used', return_value=None)

    model = get_mock_azure_openai_model('gpt-35-turbo', mocker)
    prompt = [PromptMessage(content='Human: Are you Human? you MUST only answer `y` or `n`? \nAssistant: ')]
    response = model.run(prompt, stop=['\nHuman:'])
    assert len(response.content) > 0

View file

@ -1,81 +0,0 @@
import json
import os
from unittest.mock import patch
from core.model_providers.models.entity.message import PromptMessage, MessageType
from core.model_providers.models.entity.model_params import ModelKwargs
from core.model_providers.models.llm.baichuan_model import BaichuanModel
from core.model_providers.providers.baichuan_provider import BaichuanProvider
from models.provider import Provider, ProviderType
def get_mock_provider(valid_api_key, valid_secret_key):
    """Construct a valid custom-type ``baichuan`` Provider from an API key and secret key."""
    config_json = json.dumps({
        'api_key': valid_api_key,
        'secret_key': valid_secret_key,
    })
    provider = Provider(
        id='provider_id',
        tenant_id='tenant_id',
        provider_name='baichuan',
        provider_type=ProviderType.CUSTOM.value,
        encrypted_config=config_json,
        is_valid=True,
    )
    return provider
def get_mock_model(model_name: str, streaming: bool = False):
    """Return a BaichuanModel for ``model_name`` with temperature 0.01.

    Credentials are read from ``BAICHUAN_API_KEY`` and ``BAICHUAN_SECRET_KEY``;
    ``streaming`` is passed through to the model constructor.
    """
    generation_kwargs = ModelKwargs(
        temperature=0.01,
    )
    api_key = os.environ['BAICHUAN_API_KEY']
    secret_key = os.environ['BAICHUAN_SECRET_KEY']
    baichuan_provider = BaichuanProvider(provider=get_mock_provider(api_key, secret_key))
    return BaichuanModel(
        model_provider=baichuan_provider,
        name=model_name,
        model_kwargs=generation_kwargs,
        streaming=streaming
    )
def decrypt_side_effect(tenant_id, encrypted_api_key):
    """Identity stand-in for token decryption: return the given key unchanged."""
    # tenant_id is accepted but not used.
    return encrypted_api_key
@patch('core.helper.encrypter.decrypt_token', side_effect=decrypt_side_effect)
def test_chat_get_num_tokens(mock_decrypt):
    """Token counting for a system + user message pair must be positive."""
    model = get_mock_model('baichuan2-53b')
    conversation = [
        PromptMessage(type=MessageType.SYSTEM, content='you are a kindness Assistant.'),
        PromptMessage(type=MessageType.USER, content='Who is your manufacturer?')
    ]
    token_count = model.get_num_tokens(conversation)
    assert token_count > 0
@patch('core.helper.encrypter.decrypt_token', side_effect=decrypt_side_effect)
def test_chat_run(mock_decrypt, mocker):
    """A non-streaming chat run must produce non-empty content."""
    # Stub out the provider's update_last_used call.
    mocker.patch('core.model_providers.providers.base.BaseModelProvider.update_last_used', return_value=None)

    model = get_mock_model('baichuan2-53b')
    conversation = [
        PromptMessage(type=MessageType.USER, content='Are you Human? you MUST only answer `y` or `n`?')
    ]
    response = model.run(conversation)
    assert len(response.content) > 0
@patch('core.helper.encrypter.decrypt_token', side_effect=decrypt_side_effect)
def test_chat_stream_run(mock_decrypt, mocker):
    """A chat run with ``streaming=True`` must produce non-empty content."""
    # Stub out the provider's update_last_used call.
    mocker.patch('core.model_providers.providers.base.BaseModelProvider.update_last_used', return_value=None)

    model = get_mock_model('baichuan2-53b', streaming=True)
    conversation = [
        PromptMessage(type=MessageType.USER, content='Are you Human? you MUST only answer `y` or `n`?')
    ]
    response = model.run(conversation)
    assert len(response.content) > 0

View file

@ -1,127 +0,0 @@
import json
import os
from unittest.mock import patch, MagicMock
from langchain.schema import Generation
from core.model_providers.models.entity.message import PromptMessage, MessageType
from core.model_providers.models.entity.model_params import ModelKwargs, ModelType
from core.model_providers.models.llm.huggingface_hub_model import HuggingfaceHubModel
from core.model_providers.providers.huggingface_hub_provider import HuggingfaceHubProvider
from models.provider import Provider, ProviderType, ProviderModel
def get_mock_provider():
    """Construct a valid custom-type ``huggingface_hub`` Provider with an empty config."""
    custom_type = ProviderType.CUSTOM.value
    provider = Provider(
        id='provider_id',
        tenant_id='tenant_id',
        provider_name='huggingface_hub',
        provider_type=custom_type,
        encrypted_config='',
        is_valid=True,
    )
    return provider
def get_mock_model(model_name, huggingfacehub_api_type, mocker):
    """Return a HuggingfaceHubModel backed by a mocked provider-model lookup.

    Uses max_tokens=10 and temperature=0.01. The API token and endpoint URL
    are read from ``HUGGINGFACE_API_KEY`` and ``HUGGINGFACE_ENDPOINT_URL``.
    The endpoint URL is only added to the credentials when the API type is
    ``inference_endpoints``.
    """
    generation_kwargs = ModelKwargs(
        max_tokens=10,
        temperature=0.01
    )
    api_token = os.environ['HUGGINGFACE_API_KEY']
    endpoint_url = os.environ['HUGGINGFACE_ENDPOINT_URL']
    hub_provider = HuggingfaceHubProvider(provider=get_mock_provider())

    credentials = {
        'huggingfacehub_api_type': huggingfacehub_api_type,
        'huggingfacehub_api_token': api_token
    }
    uses_endpoint = huggingfacehub_api_type == 'inference_endpoints'
    if uses_endpoint:
        credentials['huggingfacehub_endpoint_url'] = endpoint_url

    provider_model = ProviderModel(
        provider_name='huggingface_hub',
        model_name=model_name,
        model_type=ModelType.TEXT_GENERATION.value,
        encrypted_config=json.dumps(credentials),
        is_valid=True,
    )
    query_mock = MagicMock()
    query_mock.filter.return_value.first.return_value = provider_model
    mocker.patch('extensions.ext_database.db.session.query', return_value=query_mock)

    return HuggingfaceHubModel(
        model_provider=hub_provider,
        name=model_name,
        model_kwargs=generation_kwargs
    )
def decrypt_side_effect(tenant_id, encrypted_api_key):
    """Identity stand-in for token decryption: return the given key unchanged."""
    # tenant_id is accepted but not used.
    return encrypted_api_key
@patch('huggingface_hub.hf_api.ModelInfo')
@patch('core.helper.encrypter.decrypt_token', side_effect=decrypt_side_effect)
def test_hosted_inference_api_get_num_tokens(mock_decrypt, mock_model_info, mocker):
    """Hosted inference API: the sample question is expected to count as 5 tokens."""
    # ModelInfo and the LLM call are both stubbed.
    mock_model_info.return_value = MagicMock(pipeline_tag='text2text-generation')
    mocker.patch('langchain.llms.huggingface_hub.HuggingFaceHub._call', return_value="abc")

    model = get_mock_model('tiiuae/falcon-40b', 'hosted_inference_api', mocker)
    prompt = [
        PromptMessage(type=MessageType.USER, content='Who is your manufacturer?')
    ]
    token_count = model.get_num_tokens(prompt)
    assert token_count == 5
@patch('huggingface_hub.hf_api.ModelInfo')
@patch('core.helper.encrypter.decrypt_token', side_effect=decrypt_side_effect)
def test_inference_endpoints_get_num_tokens(mock_decrypt, mock_model_info, mocker):
    """Inference endpoints: the sample question is expected to count as 5 tokens."""
    # ModelInfo and the LLM call are both stubbed.
    mock_model_info.return_value = MagicMock(pipeline_tag='text2text-generation')
    mocker.patch('langchain.llms.huggingface_hub.HuggingFaceHub._call', return_value="abc")

    model = get_mock_model('', 'inference_endpoints', mocker)
    prompt = [
        PromptMessage(type=MessageType.USER, content='Who is your manufacturer?')
    ]
    token_count = model.get_num_tokens(prompt)
    assert token_count == 5
@patch('core.helper.encrypter.decrypt_token', side_effect=decrypt_side_effect)
def test_hosted_inference_api_run(mock_decrypt, mocker):
    """Hosted inference API run on a y/n prompt: the stripped answer must equal 'n'."""
    # Stub out the provider's update_last_used call.
    mocker.patch('core.model_providers.providers.base.BaseModelProvider.update_last_used', return_value=None)

    model = get_mock_model('google/flan-t5-base', 'hosted_inference_api', mocker)
    prompt = [PromptMessage(content='Human: Are you Really Human? you MUST only answer `y` or `n`? \nAssistant: ')]
    response = model.run(prompt, stop=['\nHuman:'])

    assert len(response.content) > 0
    assert response.content.strip() == 'n'
@patch('core.helper.encrypter.decrypt_token', side_effect=decrypt_side_effect)
def test_inference_endpoints_run(mock_decrypt, mocker):
    """Inference endpoints run on a yes/no question must produce non-empty content."""
    # Stub out the provider's update_last_used call.
    mocker.patch('core.model_providers.providers.base.BaseModelProvider.update_last_used', return_value=None)

    model = get_mock_model('', 'inference_endpoints', mocker)
    prompt = [PromptMessage(content='Answer the following yes/no question. Can you write a whole Haiku in a single tweet?')]
    response = model.run(prompt)
    assert len(response.content) > 0

View file

@ -1,68 +0,0 @@
import json
import os
from unittest.mock import patch, MagicMock
from core.model_providers.models.llm.localai_model import LocalAIModel
from core.model_providers.providers.localai_provider import LocalAIProvider
from core.model_providers.models.entity.message import PromptMessage
from core.model_providers.models.entity.model_params import ModelKwargs, ModelType
from models.provider import Provider, ProviderType, ProviderModel
def get_mock_provider(server_url):
    """Construct a valid custom-type ``localai`` Provider with an empty JSON-object config.

    ``server_url`` is accepted but not used when building the Provider.
    """
    empty_config = json.dumps({})
    provider = Provider(
        id='provider_id',
        tenant_id='tenant_id',
        provider_name='localai',
        provider_type=ProviderType.CUSTOM.value,
        encrypted_config=empty_config,
        is_valid=True,
    )
    return provider
def get_mock_model(model_name, mocker):
    """Return a LocalAIModel backed by a mocked provider-model lookup.

    Uses max_tokens=10 and temperature=0. The server URL is read from
    ``LOCALAI_SERVER_URL`` and stored, with completion type ``completion``,
    as JSON on the ProviderModel that the patched ``db.session.query`` yields.
    """
    generation_kwargs = ModelKwargs(
        max_tokens=10,
        temperature=0
    )
    server_url = os.environ['LOCALAI_SERVER_URL']

    stored_config = json.dumps({'server_url': server_url, 'completion_type': 'completion'})
    provider_model = ProviderModel(
        provider_name='localai',
        model_name=model_name,
        model_type=ModelType.TEXT_GENERATION.value,
        encrypted_config=stored_config,
        is_valid=True,
    )
    query_mock = MagicMock()
    query_mock.filter.return_value.first.return_value = provider_model
    mocker.patch('extensions.ext_database.db.session.query', return_value=query_mock)

    # The provider is built only after the query patch is in place.
    localai_provider = LocalAIProvider(provider=get_mock_provider(server_url))
    return LocalAIModel(
        model_provider=localai_provider,
        name=model_name,
        model_kwargs=generation_kwargs
    )
def decrypt_side_effect(tenant_id, encrypted_openai_api_key):
    """Identity stand-in for token decryption: return the given key unchanged."""
    # tenant_id is accepted but not used.
    return encrypted_openai_api_key
@patch('core.helper.encrypter.decrypt_token', side_effect=decrypt_side_effect)
def test_get_num_tokens(mock_decrypt, mocker):
    """Token counting for the sample prompt must be positive."""
    model = get_mock_model('ggml-gpt4all-j', mocker)
    prompt = [PromptMessage(content='you are a kindness Assistant.')]
    token_count = model.get_num_tokens(prompt)
    assert token_count > 0
@patch('core.helper.encrypter.decrypt_token', side_effect=decrypt_side_effect)
def test_run(mock_decrypt, mocker):
    """Running ggml-gpt4all-j on a y/n prompt must produce non-empty content."""
    # Stub out the provider's update_last_used call.
    mocker.patch('core.model_providers.providers.base.BaseModelProvider.update_last_used', return_value=None)

    model = get_mock_model('ggml-gpt4all-j', mocker)
    prompt = [PromptMessage(content='Human: Are you Human? you MUST only answer `y` or `n`? \nAssistant: ')]
    response = model.run(prompt, stop=['\nHuman:'])
    assert len(response.content) > 0

View file

@ -1,65 +0,0 @@
import json
import os
from unittest.mock import patch
from langchain.schema import ChatGeneration, AIMessage, Generation
from core.model_providers.models.entity.message import PromptMessage, MessageType
from core.model_providers.models.entity.model_params import ModelKwargs
from core.model_providers.models.llm.minimax_model import MinimaxModel
from core.model_providers.providers.minimax_provider import MinimaxProvider
from models.provider import Provider, ProviderType
def get_mock_provider(valid_group_id, valid_api_key):
    """Construct a valid custom-type ``minimax`` Provider from a group id and API key."""
    config_json = json.dumps({
        'minimax_group_id': valid_group_id,
        'minimax_api_key': valid_api_key
    })
    provider = Provider(
        id='provider_id',
        tenant_id='tenant_id',
        provider_name='minimax',
        provider_type=ProviderType.CUSTOM.value,
        encrypted_config=config_json,
        is_valid=True,
    )
    return provider
def get_mock_model(model_name):
    """Return a MinimaxModel for ``model_name`` (max_tokens=10, temperature=0.01).

    Credentials are read from ``MINIMAX_API_KEY`` and ``MINIMAX_GROUP_ID``.
    """
    generation_kwargs = ModelKwargs(
        max_tokens=10,
        temperature=0.01
    )
    api_key = os.environ['MINIMAX_API_KEY']
    group_id = os.environ['MINIMAX_GROUP_ID']
    minimax_provider = MinimaxProvider(provider=get_mock_provider(group_id, api_key))
    return MinimaxModel(
        model_provider=minimax_provider,
        name=model_name,
        model_kwargs=generation_kwargs
    )
def decrypt_side_effect(tenant_id, encrypted_api_key):
    """Identity stand-in for token decryption: return the given key unchanged."""
    # tenant_id is accepted but not used.
    return encrypted_api_key
@patch('core.helper.encrypter.decrypt_token', side_effect=decrypt_side_effect)
def test_get_num_tokens(mock_decrypt):
    """The sample user question is expected to count as 5 tokens for abab5.5-chat."""
    model = get_mock_model('abab5.5-chat')
    prompt = [
        PromptMessage(type=MessageType.USER, content='Who is your manufacturer?')
    ]
    token_count = model.get_num_tokens(prompt)
    assert token_count == 5
@patch('core.helper.encrypter.decrypt_token', side_effect=decrypt_side_effect)
def test_run(mock_decrypt, mocker):
    """Running a y/n prompt on 'abab5.5-chat' returns non-empty content."""
    # Replace BaseModelProvider.update_last_used with a stub returning None.
    mocker.patch(
        'core.model_providers.providers.base.BaseModelProvider.update_last_used',
        return_value=None,
    )
    model = get_mock_model('abab5.5-chat')
    prompts = [
        PromptMessage(content='Human: Are you a real Human? you MUST only answer `y` or `n`? \nAssistant: '),
    ]

    response = model.run(prompts, stop=['\nHuman:'])

    assert len(response.content) > 0

View file

@ -1,111 +0,0 @@
import json
import os
from unittest.mock import patch
from langchain.schema import Generation, ChatGeneration, AIMessage
from core.model_providers.providers.openai_provider import OpenAIProvider
from core.model_providers.models.entity.message import PromptMessage, MessageType, ImageMessageFile
from core.model_providers.models.entity.model_params import ModelKwargs
from core.model_providers.models.llm.openai_model import OpenAIModel
from models.provider import Provider, ProviderType
def get_mock_provider(valid_openai_api_key):
    """Build a valid custom 'openai' ``Provider`` record for the tests.

    The API key is serialized to JSON and stored in ``encrypted_config``
    under the ``openai_api_key`` key.
    """
    credentials = {'openai_api_key': valid_openai_api_key}
    return Provider(
        id='provider_id',
        tenant_id='tenant_id',
        provider_name='openai',
        provider_type=ProviderType.CUSTOM.value,
        encrypted_config=json.dumps(credentials),
        is_valid=True,
    )
def get_mock_openai_model(model_name):
    """Create an ``OpenAIModel`` called ``model_name`` from environment credentials.

    Reads ``OPENAI_API_KEY`` from ``os.environ`` (a ``KeyError`` is raised if
    it is unset) and wires it into a mock provider. The model is configured
    with ``max_tokens=10`` and ``temperature=0``.
    """
    model_kwargs = ModelKwargs(
        max_tokens=10,
        temperature=0
    )
    # The former ``model_name = model_name`` self-assignment was a no-op and
    # has been dropped; the parameter is used directly below.
    valid_openai_api_key = os.environ['OPENAI_API_KEY']
    openai_provider = OpenAIProvider(provider=get_mock_provider(valid_openai_api_key))
    return OpenAIModel(
        model_provider=openai_provider,
        name=model_name,
        model_kwargs=model_kwargs
    )
def decrypt_side_effect(tenant_id, encrypted_openai_api_key):
    """Pass-through stand-in for ``decrypt_token``: return the key unchanged.

    ``tenant_id`` is accepted but not used.
    """
    passthrough = encrypted_openai_api_key
    return passthrough
@patch('core.helper.encrypter.decrypt_token', side_effect=decrypt_side_effect)
def test_get_num_tokens(mock_decrypt):
    """One prompt on 'gpt-3.5-turbo-instruct' is expected to count as 6 tokens."""
    model = get_mock_openai_model('gpt-3.5-turbo-instruct')
    prompts = [PromptMessage(content='you are a kindness Assistant.')]

    token_count = model.get_num_tokens(prompts)

    assert token_count == 6
@patch('core.helper.encrypter.decrypt_token', side_effect=decrypt_side_effect)
def test_chat_get_num_tokens(mock_decrypt):
    """A system + user message pair on 'gpt-3.5-turbo' is expected to count as 22 tokens."""
    model = get_mock_openai_model('gpt-3.5-turbo')
    conversation = [
        PromptMessage(type=MessageType.SYSTEM, content='you are a kindness Assistant.'),
        PromptMessage(type=MessageType.USER, content='Who is your manufacturer?'),
    ]

    token_count = model.get_num_tokens(conversation)

    assert token_count == 22
@patch('core.helper.encrypter.decrypt_token', side_effect=decrypt_side_effect)
def test_vision_chat_get_num_tokens(mock_decrypt):
    """A prompt with one attached image on 'gpt-4-vision-preview' is expected to count as 77 tokens."""
    model = get_mock_openai_model('gpt-4-vision-preview')
    image = ImageMessageFile(
        data='https://upload.wikimedia.org/wikipedia/commons/0/00/1890s_Carlisle_Boarding_School_Graduates_PA.jpg',
    )
    prompts = [PromptMessage(content='Whats in first image?', files=[image])]

    token_count = model.get_num_tokens(prompts)

    assert token_count == 77
@patch('core.helper.encrypter.decrypt_token', side_effect=decrypt_side_effect)
def test_run(mock_decrypt, mocker):
    """'gpt-3.5-turbo-instruct' must answer the y/n prompt with exactly 'n'."""
    # Replace BaseModelProvider.update_last_used with a stub returning None.
    mocker.patch(
        'core.model_providers.providers.base.BaseModelProvider.update_last_used',
        return_value=None,
    )
    model = get_mock_openai_model('gpt-3.5-turbo-instruct')
    prompts = [
        PromptMessage(content='Human: Are you Human? you MUST only answer `y` or `n`? \nAssistant: '),
    ]

    response = model.run(prompts, stop=['\nHuman:'])

    assert len(response.content) > 0
    # Surrounding whitespace is ignored when comparing the answer.
    assert response.content.strip() == 'n'
@patch('core.helper.encrypter.decrypt_token', side_effect=decrypt_side_effect)
def test_chat_run(mock_decrypt, mocker):
    """Running a y/n prompt on 'gpt-3.5-turbo' returns non-empty content."""
    # Replace BaseModelProvider.update_last_used with a stub returning None.
    mocker.patch(
        'core.model_providers.providers.base.BaseModelProvider.update_last_used',
        return_value=None,
    )
    model = get_mock_openai_model('gpt-3.5-turbo')
    prompts = [
        PromptMessage(content='Human: Are you Human? you MUST only answer `y` or `n`? \nAssistant: '),
    ]

    response = model.run(prompts, stop=['\nHuman:'])

    assert len(response.content) > 0
@patch('core.helper.encrypter.decrypt_token', side_effect=decrypt_side_effect)
def test_vision_run(mock_decrypt, mocker):
    """Running an image prompt on 'gpt-4-vision-preview' returns non-empty content."""
    # Replace BaseModelProvider.update_last_used with a stub returning None.
    mocker.patch(
        'core.model_providers.providers.base.BaseModelProvider.update_last_used',
        return_value=None,
    )
    model = get_mock_openai_model('gpt-4-vision-preview')
    image = ImageMessageFile(
        data='https://upload.wikimedia.org/wikipedia/commons/0/00/1890s_Carlisle_Boarding_School_Graduates_PA.jpg',
    )
    prompts = [PromptMessage(content='Whats in first image?', files=[image])]

    # No stop sequences are passed for this run.
    response = model.run(prompts)

    assert len(response.content) > 0

View file

@ -1,72 +0,0 @@
import json
import os
from unittest.mock import patch, MagicMock
from core.model_providers.models.entity.message import PromptMessage, MessageType
from core.model_providers.models.entity.model_params import ModelKwargs, ModelType
from core.model_providers.models.llm.openllm_model import OpenLLMModel
from core.model_providers.providers.openllm_provider import OpenLLMProvider
from models.provider import Provider, ProviderType, ProviderModel
def get_mock_provider():
    """Build a valid custom 'openllm' ``Provider`` record with an empty config."""
    provider_fields = dict(
        id='provider_id',
        tenant_id='tenant_id',
        provider_name='openllm',
        provider_type=ProviderType.CUSTOM.value,
        encrypted_config='',
        is_valid=True,
    )
    return Provider(**provider_fields)
def get_mock_model(model_name, mocker):
    """Create an ``OpenLLMModel`` called ``model_name`` backed by a mocked DB query.

    Reads ``OPENLLM_SERVER_URL`` from ``os.environ`` (a ``KeyError`` is raised
    if it is unset). ``extensions.ext_database.db.session.query`` is patched
    through ``mocker`` so that ``query(...).filter(...).first()`` yields a
    ``ProviderModel`` whose JSON config holds that server URL. The model is
    configured with ``max_tokens=10`` and ``temperature=0.01``.
    """
    kwargs = ModelKwargs(max_tokens=10, temperature=0.01)
    server_url = os.environ['OPENLLM_SERVER_URL']
    provider = OpenLLMProvider(provider=get_mock_provider())

    # Fake query object: .filter(...).first() returns the stored model row.
    fake_query = MagicMock()
    stored_model = ProviderModel(
        provider_name='openllm',
        model_name=model_name,
        model_type=ModelType.TEXT_GENERATION.value,
        encrypted_config=json.dumps({'server_url': server_url}),
        is_valid=True,
    )
    fake_query.filter.return_value.first.return_value = stored_model
    mocker.patch('extensions.ext_database.db.session.query', return_value=fake_query)

    return OpenLLMModel(
        model_provider=provider,
        name=model_name,
        model_kwargs=kwargs,
    )
def decrypt_side_effect(tenant_id, encrypted_api_key):
    """Pass-through stand-in for ``decrypt_token``: return the key unchanged.

    ``tenant_id`` is accepted but not used.
    """
    passthrough = encrypted_api_key
    return passthrough
@patch('core.helper.encrypter.decrypt_token', side_effect=decrypt_side_effect)
def test_get_num_tokens(mock_decrypt, mocker):
    """A single user message on 'facebook/opt-125m' is expected to count as 5 tokens."""
    model = get_mock_model('facebook/opt-125m', mocker)
    user_message = PromptMessage(type=MessageType.USER, content='Who is your manufacturer?')

    token_count = model.get_num_tokens([user_message])

    assert token_count == 5
@patch('core.helper.encrypter.decrypt_token', side_effect=decrypt_side_effect)
def test_run(mock_decrypt, mocker):
    """Running a simple prompt on 'facebook/opt-125m' returns non-empty content."""
    # Replace BaseModelProvider.update_last_used with a stub returning None.
    mocker.patch(
        'core.model_providers.providers.base.BaseModelProvider.update_last_used',
        return_value=None,
    )
    model = get_mock_model('facebook/opt-125m', mocker)
    prompts = [PromptMessage(content='Human: who are you? \nAnswer: ')]

    # No stop sequences are passed for this run.
    response = model.run(prompts)

    assert len(response.content) > 0

View file

@ -1,75 +0,0 @@
import json
import os
from unittest.mock import patch, MagicMock
from langchain.schema import Generation
from core.model_providers.models.entity.message import PromptMessage, MessageType
from core.model_providers.models.entity.model_params import ModelKwargs, ModelType
from core.model_providers.models.llm.replicate_model import ReplicateModel
from core.model_providers.providers.replicate_provider import ReplicateProvider
from models.provider import Provider, ProviderType, ProviderModel
def get_mock_provider():
    """Build a valid custom 'replicate' ``Provider`` record with an empty config."""
    provider_fields = dict(
        id='provider_id',
        tenant_id='tenant_id',
        provider_name='replicate',
        provider_type=ProviderType.CUSTOM.value,
        encrypted_config='',
        is_valid=True,
    )
    return Provider(**provider_fields)
def get_mock_model(model_name, model_version, mocker):
    """Create a ``ReplicateModel`` called ``model_name`` backed by a mocked DB query.

    Reads ``REPLICATE_API_TOKEN`` from ``os.environ`` (a ``KeyError`` is raised
    if it is unset). ``extensions.ext_database.db.session.query`` is patched
    through ``mocker`` so that ``query(...).filter(...).first()`` yields a
    ``ProviderModel`` whose JSON config holds the API token and
    ``model_version``. The model is configured with ``max_tokens=10`` and
    ``temperature=0.01``.
    """
    kwargs = ModelKwargs(max_tokens=10, temperature=0.01)
    api_token = os.environ['REPLICATE_API_TOKEN']
    provider = ReplicateProvider(provider=get_mock_provider())

    # Fake query object: .filter(...).first() returns the stored model row.
    fake_query = MagicMock()
    model_config = {
        'replicate_api_token': api_token,
        'model_version': model_version,
    }
    stored_model = ProviderModel(
        provider_name='replicate',
        model_name=model_name,
        model_type=ModelType.TEXT_GENERATION.value,
        encrypted_config=json.dumps(model_config),
        is_valid=True,
    )
    fake_query.filter.return_value.first.return_value = stored_model
    mocker.patch('extensions.ext_database.db.session.query', return_value=fake_query)

    return ReplicateModel(
        model_provider=provider,
        name=model_name,
        model_kwargs=kwargs,
    )
def decrypt_side_effect(tenant_id, encrypted_api_key):
    """Pass-through stand-in for ``decrypt_token``: return the key unchanged.

    ``tenant_id`` is accepted but not used.
    """
    passthrough = encrypted_api_key
    return passthrough
@patch('core.helper.encrypter.decrypt_token', side_effect=decrypt_side_effect)
def test_get_num_tokens(mock_decrypt, mocker):
    """A single user message on 'a16z-infra/llama-2-13b-chat' is expected to count as 7 tokens."""
    model = get_mock_model(
        'a16z-infra/llama-2-13b-chat',
        '2a7f981751ec7fdf87b5b91ad4db53683a98082e9ff7bfd12c8cd5ea85980a52',
        mocker,
    )
    user_message = PromptMessage(type=MessageType.USER, content='Who is your manufacturer?')

    token_count = model.get_num_tokens([user_message])

    assert token_count == 7
@patch('core.helper.encrypter.decrypt_token', side_effect=decrypt_side_effect)
def test_run(mock_decrypt, mocker):
    """Running an arithmetic prompt on 'a16z-infra/llama-2-13b-chat' returns non-empty content."""
    # Replace BaseModelProvider.update_last_used with a stub returning None.
    mocker.patch(
        'core.model_providers.providers.base.BaseModelProvider.update_last_used',
        return_value=None,
    )
    model = get_mock_model(
        'a16z-infra/llama-2-13b-chat',
        '2a7f981751ec7fdf87b5b91ad4db53683a98082e9ff7bfd12c8cd5ea85980a52',
        mocker,
    )
    prompts = [PromptMessage(content='Human: 1+1=? \nAnswer: ')]

    # No stop sequences are passed for this run.
    response = model.run(prompts)

    assert len(response.content) > 0

View file

@ -1,70 +0,0 @@
import json
import os
from unittest.mock import patch
from langchain.schema import ChatGeneration, AIMessage, Generation
from core.model_providers.models.entity.message import PromptMessage, MessageType
from core.model_providers.models.entity.model_params import ModelKwargs
from core.model_providers.models.llm.minimax_model import MinimaxModel
from core.model_providers.models.llm.spark_model import SparkModel
from core.model_providers.providers.minimax_provider import MinimaxProvider
from core.model_providers.providers.spark_provider import SparkProvider
from models.provider import Provider, ProviderType
def get_mock_provider(valid_app_id, valid_api_key, valid_api_secret):
    """Build a valid custom 'spark' ``Provider`` record for the tests.

    The three credentials are serialized to JSON and stored in
    ``encrypted_config`` under ``app_id``, ``api_key`` and ``api_secret``.
    """
    credentials = {
        'app_id': valid_app_id,
        'api_key': valid_api_key,
        'api_secret': valid_api_secret,
    }
    return Provider(
        id='provider_id',
        tenant_id='tenant_id',
        provider_name='spark',
        provider_type=ProviderType.CUSTOM.value,
        encrypted_config=json.dumps(credentials),
        is_valid=True,
    )
def get_mock_model(model_name):
    """Create a ``SparkModel`` called ``model_name`` from environment credentials.

    Reads ``SPARK_APP_ID``, ``SPARK_API_KEY`` and ``SPARK_API_SECRET`` from
    ``os.environ`` (a ``KeyError`` is raised if any is unset) and wires them
    into a mock provider. The model is configured with ``max_tokens=10`` and
    ``temperature=0.01``.
    """
    kwargs = ModelKwargs(max_tokens=10, temperature=0.01)

    app_id = os.environ['SPARK_APP_ID']
    api_key = os.environ['SPARK_API_KEY']
    api_secret = os.environ['SPARK_API_SECRET']

    provider = SparkProvider(provider=get_mock_provider(app_id, api_key, api_secret))
    return SparkModel(
        model_provider=provider,
        name=model_name,
        model_kwargs=kwargs,
    )
def decrypt_side_effect(tenant_id, encrypted_api_key):
    """Pass-through stand-in for ``decrypt_token``: return the key unchanged.

    ``tenant_id`` is accepted but not used.
    """
    passthrough = encrypted_api_key
    return passthrough
@patch('core.helper.encrypter.decrypt_token', side_effect=decrypt_side_effect)
def test_get_num_tokens(mock_decrypt):
    """A single user message on the 'spark' model is expected to count as 6 tokens."""
    model = get_mock_model('spark')
    user_message = PromptMessage(type=MessageType.USER, content='Who is your manufacturer?')

    token_count = model.get_num_tokens([user_message])

    assert token_count == 6
@patch('core.helper.encrypter.decrypt_token', side_effect=decrypt_side_effect)
def test_run(mock_decrypt, mocker):
    """Running an arithmetic prompt on the 'spark' model returns non-empty content."""
    # Replace BaseModelProvider.update_last_used with a stub returning None.
    mocker.patch(
        'core.model_providers.providers.base.BaseModelProvider.update_last_used',
        return_value=None,
    )
    model = get_mock_model('spark')
    prompts = [PromptMessage(content='Human: 1 + 1=? \nAssistant: Integer answer is:')]

    response = model.run(prompts, stop=['\nHuman:'])

    assert len(response.content) > 0

Some files were not shown because too many files have changed in this diff Show more