Added MistralAI LLM and Embedding (#1865)

* Added Mistral AI LLM Model and added MistralAI embeddings.

* Added tracking to the new MistralAIEmbeddings.py and MistalModel.py

* Added MistralAI model specs and added the correct models list to the Mistral Model LLM component.

* Actually added the MistralAI model specs this time.
This commit is contained in:
h-arnold 2024-05-09 19:54:16 +01:00 committed by GitHub
commit 12d1a9a426
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
5 changed files with 450 additions and 154 deletions

View file

@ -0,0 +1,69 @@
from typing import List, Optional
from pydantic.v1 import SecretStr
from langchain_mistralai.embeddings import MistralAIEmbeddings
from langflow.interface.custom.custom_component import CustomComponent
from langflow.field_typing import Embeddings, NestedDict
class MistralAIEmbeddingsComponent(CustomComponent):
display_name = "MistralAI Embeddings"
description = "Generate embeddings using MistralAI models."
def build_config(self):
return {
"model": {
"display_name": "Model",
"advanced": False,
"options": ["mistral-embed"],
"value": "mistral-embed",
},
"mistral_api_key": {
"display_name": "Mistral API Key",
"password": True,
"advanced": False,
},
"max_concurrent_requests": {
"display_name": "Max Concurrent Requests",
"advanced": True,
"value": 64,
},
"max_retries": {
"display_name": "Max Retries",
"advanced": True,
"value": 5,
},
"timeout": {
"display_name": "Request Timeout",
"advanced": True,
"value": 120,
},
"endpoint": {
"display_name": "API Endpoint",
"advanced": True,
"value": "https://api.mistral.ai/v1/"
}
}
def build(
self,
mistral_api_key: str,
model: str = "mistral-embed",
max_concurrent_requests: int = 64,
max_retries: int = 5,
timeout: int = 120,
endpoint: str = "https://api.mistral.ai/v1/"
) -> Embeddings:
if mistral_api_key:
api_key = SecretStr(mistral_api_key)
else:
api_key = None
return MistralAIEmbeddings(
api_key=api_key,
model=model,
endpoint=endpoint,
max_concurrent_requests=max_concurrent_requests,
max_retries=max_retries,
timeout=timeout
)

View file

@ -0,0 +1,87 @@
from typing import Optional
from langchain_mistralai import ChatMistralAI
from pydantic.v1 import SecretStr
from langflow.custom import CustomComponent
from langflow.field_typing import BaseLanguageModel
class MistralAIModelComponent(CustomComponent):
display_name: str = "MistralAI"
description: str = "Generate text using MistralAI LLMs."
icon = "MistralAI"
field_order = [
"model",
"mistral_api_key",
"max_tokens",
"temperature",
"mistral_api_base",
]
def build_config(self):
return {
"model": {
"display_name": "Model Name",
"options": [
"open-mistral-7b",
"open-mixtral-8x7b",
"open-mixtral-8x22b",
"mistral-small-latest",
"mistral-medium-latest",
"mistral-large-latest"
],
"info": "Name of the model to use.",
"required": True,
"value": "open-mistral-7b",
},
"mistral_api_key": {
"display_name": "Mistral API Key",
"required": True,
"password": True,
"info": "Your Mistral API key.",
},
"max_tokens": {
"display_name": "Max Tokens",
"field_type": "int",
"advanced": True,
"value": 256,
},
"temperature": {
"display_name": "Temperature",
"field_type": "float",
"value": 0.1,
},
"mistral_api_base": {
"display_name": "Mistral API Base",
"advanced": True,
"info": "Endpoint of the Mistral API. Defaults to 'https://api.mistral.ai' if not specified.",
},
"code": {"show": False},
}
def build(
self,
model: str,
mistral_api_key: Optional[str] = None,
max_tokens: Optional[int] = None,
temperature: Optional[float] = None,
mistral_api_base: Optional[str] = None,
) -> BaseLanguageModel:
# Set default API endpoint if not provided
if not mistral_api_base:
mistral_api_base = "https://api.mistral.ai"
try:
output = ChatMistralAI(
model=model,
api_key=(SecretStr(mistral_api_key) if mistral_api_key else None),
max_tokens=max_tokens,
temperature=temperature,
base_url=mistral_api_base,
)
except Exception as e:
raise ValueError("Could not connect to Mistral API.") from e
return output

View file

@ -0,0 +1,111 @@
from typing import Optional
from langchain_mistralai import ChatMistralAI
from pydantic.v1 import SecretStr
from langflow.base.constants import STREAM_INFO_TEXT
from langflow.base.models.model import LCModelComponent
from langflow.field_typing import NestedDict, Text
class MistralAIModelComponent(LCModelComponent):
display_name = "MistralAI"
description = "Generates text using MistralAI LLMs."
icon = "MistralAI"
field_order = [
"max_tokens",
"model_kwargs",
"model_name",
"mistral_api_base",
"mistral_api_key",
"temperature",
"input_value",
"system_message",
"stream",
]
def build_config(self):
return {
"input_value": {"display_name": "Input"},
"max_tokens": {
"display_name": "Max Tokens",
"advanced": True,
},
"model_kwargs": {
"display_name": "Model Kwargs",
"advanced": True,
},
"model_name": {
"display_name": "Model Name",
"advanced": False,
"options": [
"open-mistral-7b",
"open-mixtral-8x7b",
"open-mixtral-8x22b",
"mistral-small-latest",
"mistral-medium-latest",
"mistral-large-latest"
],
"value": "open-mistral-7b",
},
"mistral_api_base": {
"display_name": "Mistral API Base",
"advanced": True,
"info": (
"The base URL of the Mistral API. Defaults to https://api.mistral.ai.\n\n"
"You can change this to use other APIs like JinaChat, LocalAI and Prem."
),
},
"mistral_api_key": {
"display_name": "Mistral API Key",
"info": "The Mistral API Key to use for the Mistral model.",
"advanced": False,
"password": True,
},
"temperature": {
"display_name": "Temperature",
"advanced": False,
"value": 0.1,
},
"stream": {
"display_name": "Stream",
"info": STREAM_INFO_TEXT,
"advanced": True,
},
"system_message": {
"display_name": "System Message",
"info": "System message to pass to the model.",
"advanced": True,
},
}
def build(
self,
input_value: Text,
mistral_api_key: str,
temperature: float,
model_name: str,
max_tokens: Optional[int] = 256,
model_kwargs: NestedDict = {},
mistral_api_base: Optional[str] = None,
stream: bool = False,
system_message: Optional[str] = None,
) -> Text:
if not mistral_api_base:
mistral_api_base = "https://api.mistral.ai"
if mistral_api_key:
api_key = SecretStr(mistral_api_key)
else:
api_key = None
chat_model = ChatMistralAI(
max_tokens=max_tokens,
model_kwargs=model_kwargs,
model=model_name,
base_url=mistral_api_base,
api_key=api_key,
temperature=temperature,
)
return self.get_chat_result(chat_model, stream, input_value, system_message)