Added MistralAI LLM and Embedding (#1865)
* Added Mistral AI LLM Model and added MistralAI embeddings. * Added tracking to the new MistralAIEmbeddings.py and MistalModel.py * Added MistralAI model specs and added the correct models list to the Mistral Model LLM component. * Actually added the MistralAI model specs this time.
This commit is contained in:
parent
561765c74c
commit
12d1a9a426
5 changed files with 450 additions and 154 deletions
|
|
@ -0,0 +1,69 @@
|
|||
from typing import List, Optional
|
||||
from pydantic.v1 import SecretStr
|
||||
|
||||
from langchain_mistralai.embeddings import MistralAIEmbeddings
|
||||
from langflow.interface.custom.custom_component import CustomComponent
|
||||
from langflow.field_typing import Embeddings, NestedDict
|
||||
|
||||
class MistralAIEmbeddingsComponent(CustomComponent):
|
||||
display_name = "MistralAI Embeddings"
|
||||
description = "Generate embeddings using MistralAI models."
|
||||
|
||||
def build_config(self):
|
||||
return {
|
||||
"model": {
|
||||
"display_name": "Model",
|
||||
"advanced": False,
|
||||
"options": ["mistral-embed"],
|
||||
"value": "mistral-embed",
|
||||
},
|
||||
"mistral_api_key": {
|
||||
"display_name": "Mistral API Key",
|
||||
"password": True,
|
||||
"advanced": False,
|
||||
},
|
||||
"max_concurrent_requests": {
|
||||
"display_name": "Max Concurrent Requests",
|
||||
"advanced": True,
|
||||
"value": 64,
|
||||
},
|
||||
"max_retries": {
|
||||
"display_name": "Max Retries",
|
||||
"advanced": True,
|
||||
"value": 5,
|
||||
},
|
||||
"timeout": {
|
||||
"display_name": "Request Timeout",
|
||||
"advanced": True,
|
||||
"value": 120,
|
||||
},
|
||||
"endpoint": {
|
||||
"display_name": "API Endpoint",
|
||||
"advanced": True,
|
||||
"value": "https://api.mistral.ai/v1/"
|
||||
}
|
||||
}
|
||||
|
||||
def build(
|
||||
self,
|
||||
mistral_api_key: str,
|
||||
model: str = "mistral-embed",
|
||||
max_concurrent_requests: int = 64,
|
||||
max_retries: int = 5,
|
||||
timeout: int = 120,
|
||||
endpoint: str = "https://api.mistral.ai/v1/"
|
||||
) -> Embeddings:
|
||||
if mistral_api_key:
|
||||
api_key = SecretStr(mistral_api_key)
|
||||
else:
|
||||
api_key = None
|
||||
|
||||
return MistralAIEmbeddings(
|
||||
api_key=api_key,
|
||||
model=model,
|
||||
endpoint=endpoint,
|
||||
max_concurrent_requests=max_concurrent_requests,
|
||||
max_retries=max_retries,
|
||||
timeout=timeout
|
||||
)
|
||||
|
||||
|
|
@ -0,0 +1,87 @@
|
|||
from typing import Optional
|
||||
|
||||
from langchain_mistralai import ChatMistralAI
|
||||
from pydantic.v1 import SecretStr
|
||||
|
||||
from langflow.custom import CustomComponent
|
||||
from langflow.field_typing import BaseLanguageModel
|
||||
|
||||
|
||||
class MistralAIModelComponent(CustomComponent):
|
||||
display_name: str = "MistralAI"
|
||||
description: str = "Generate text using MistralAI LLMs."
|
||||
icon = "MistralAI"
|
||||
|
||||
field_order = [
|
||||
"model",
|
||||
"mistral_api_key",
|
||||
"max_tokens",
|
||||
"temperature",
|
||||
"mistral_api_base",
|
||||
]
|
||||
|
||||
def build_config(self):
|
||||
return {
|
||||
"model": {
|
||||
"display_name": "Model Name",
|
||||
"options": [
|
||||
"open-mistral-7b",
|
||||
"open-mixtral-8x7b",
|
||||
"open-mixtral-8x22b",
|
||||
"mistral-small-latest",
|
||||
"mistral-medium-latest",
|
||||
"mistral-large-latest"
|
||||
],
|
||||
"info": "Name of the model to use.",
|
||||
"required": True,
|
||||
"value": "open-mistral-7b",
|
||||
},
|
||||
"mistral_api_key": {
|
||||
"display_name": "Mistral API Key",
|
||||
"required": True,
|
||||
"password": True,
|
||||
"info": "Your Mistral API key.",
|
||||
},
|
||||
"max_tokens": {
|
||||
"display_name": "Max Tokens",
|
||||
"field_type": "int",
|
||||
"advanced": True,
|
||||
"value": 256,
|
||||
},
|
||||
"temperature": {
|
||||
"display_name": "Temperature",
|
||||
"field_type": "float",
|
||||
"value": 0.1,
|
||||
},
|
||||
"mistral_api_base": {
|
||||
"display_name": "Mistral API Base",
|
||||
"advanced": True,
|
||||
"info": "Endpoint of the Mistral API. Defaults to 'https://api.mistral.ai' if not specified.",
|
||||
},
|
||||
"code": {"show": False},
|
||||
}
|
||||
|
||||
def build(
|
||||
self,
|
||||
model: str,
|
||||
mistral_api_key: Optional[str] = None,
|
||||
max_tokens: Optional[int] = None,
|
||||
temperature: Optional[float] = None,
|
||||
mistral_api_base: Optional[str] = None,
|
||||
) -> BaseLanguageModel:
|
||||
# Set default API endpoint if not provided
|
||||
if not mistral_api_base:
|
||||
mistral_api_base = "https://api.mistral.ai"
|
||||
|
||||
try:
|
||||
output = ChatMistralAI(
|
||||
model=model,
|
||||
api_key=(SecretStr(mistral_api_key) if mistral_api_key else None),
|
||||
max_tokens=max_tokens,
|
||||
temperature=temperature,
|
||||
base_url=mistral_api_base,
|
||||
)
|
||||
except Exception as e:
|
||||
raise ValueError("Could not connect to Mistral API.") from e
|
||||
|
||||
return output
|
||||
111
src/backend/base/langflow/components/models/MistralModel.py
Normal file
111
src/backend/base/langflow/components/models/MistralModel.py
Normal file
|
|
@ -0,0 +1,111 @@
|
|||
from typing import Optional
|
||||
|
||||
from langchain_mistralai import ChatMistralAI
|
||||
from pydantic.v1 import SecretStr
|
||||
|
||||
from langflow.base.constants import STREAM_INFO_TEXT
|
||||
from langflow.base.models.model import LCModelComponent
|
||||
from langflow.field_typing import NestedDict, Text
|
||||
|
||||
|
||||
class MistralAIModelComponent(LCModelComponent):
|
||||
display_name = "MistralAI"
|
||||
description = "Generates text using MistralAI LLMs."
|
||||
icon = "MistralAI"
|
||||
|
||||
field_order = [
|
||||
"max_tokens",
|
||||
"model_kwargs",
|
||||
"model_name",
|
||||
"mistral_api_base",
|
||||
"mistral_api_key",
|
||||
"temperature",
|
||||
"input_value",
|
||||
"system_message",
|
||||
"stream",
|
||||
]
|
||||
|
||||
def build_config(self):
|
||||
return {
|
||||
"input_value": {"display_name": "Input"},
|
||||
"max_tokens": {
|
||||
"display_name": "Max Tokens",
|
||||
"advanced": True,
|
||||
},
|
||||
"model_kwargs": {
|
||||
"display_name": "Model Kwargs",
|
||||
"advanced": True,
|
||||
},
|
||||
"model_name": {
|
||||
"display_name": "Model Name",
|
||||
"advanced": False,
|
||||
"options": [
|
||||
"open-mistral-7b",
|
||||
"open-mixtral-8x7b",
|
||||
"open-mixtral-8x22b",
|
||||
"mistral-small-latest",
|
||||
"mistral-medium-latest",
|
||||
"mistral-large-latest"
|
||||
],
|
||||
"value": "open-mistral-7b",
|
||||
},
|
||||
"mistral_api_base": {
|
||||
"display_name": "Mistral API Base",
|
||||
"advanced": True,
|
||||
"info": (
|
||||
"The base URL of the Mistral API. Defaults to https://api.mistral.ai.\n\n"
|
||||
"You can change this to use other APIs like JinaChat, LocalAI and Prem."
|
||||
),
|
||||
},
|
||||
"mistral_api_key": {
|
||||
"display_name": "Mistral API Key",
|
||||
"info": "The Mistral API Key to use for the Mistral model.",
|
||||
"advanced": False,
|
||||
"password": True,
|
||||
},
|
||||
"temperature": {
|
||||
"display_name": "Temperature",
|
||||
"advanced": False,
|
||||
"value": 0.1,
|
||||
},
|
||||
"stream": {
|
||||
"display_name": "Stream",
|
||||
"info": STREAM_INFO_TEXT,
|
||||
"advanced": True,
|
||||
},
|
||||
"system_message": {
|
||||
"display_name": "System Message",
|
||||
"info": "System message to pass to the model.",
|
||||
"advanced": True,
|
||||
},
|
||||
}
|
||||
|
||||
def build(
|
||||
self,
|
||||
input_value: Text,
|
||||
mistral_api_key: str,
|
||||
temperature: float,
|
||||
model_name: str,
|
||||
max_tokens: Optional[int] = 256,
|
||||
model_kwargs: NestedDict = {},
|
||||
mistral_api_base: Optional[str] = None,
|
||||
stream: bool = False,
|
||||
system_message: Optional[str] = None,
|
||||
) -> Text:
|
||||
if not mistral_api_base:
|
||||
mistral_api_base = "https://api.mistral.ai"
|
||||
if mistral_api_key:
|
||||
api_key = SecretStr(mistral_api_key)
|
||||
else:
|
||||
api_key = None
|
||||
|
||||
chat_model = ChatMistralAI(
|
||||
max_tokens=max_tokens,
|
||||
model_kwargs=model_kwargs,
|
||||
model=model_name,
|
||||
base_url=mistral_api_base,
|
||||
api_key=api_key,
|
||||
temperature=temperature,
|
||||
)
|
||||
|
||||
return self.get_chat_result(chat_model, stream, input_value, system_message)
|
||||
Loading…
Add table
Add a link
Reference in a new issue