feat: add xAI integration (#6012)
* feat: add xAI integration * fix: implement file_names_mapping fixture in test_xai.py aligning with test standards * [autofix.ci] apply automated fixes * [autofix.ci] apply automated fixes (attempt 2/3) * fix: remove unused variable in test_xai * refactor: update input types and variable naming * [autofix.ci] apply automated fixes * Update src/backend/tests/unit/components/models/test_xai.py Co-authored-by: Christophe Bornet <cbornet@hotmail.com> * Update src/backend/tests/unit/components/models/test_xai.py Co-authored-by: Christophe Bornet <cbornet@hotmail.com> * Update src/backend/tests/unit/components/models/test_xai.py Co-authored-by: Christophe Bornet <cbornet@hotmail.com> * Update src/backend/tests/unit/components/models/test_xai.py Co-authored-by: Christophe Bornet <cbornet@hotmail.com> * Update src/backend/tests/unit/components/models/test_xai.py Co-authored-by: Christophe Bornet <cbornet@hotmail.com> * Update src/backend/tests/unit/components/models/test_xai.py Co-authored-by: Christophe Bornet <cbornet@hotmail.com> * Update src/backend/tests/unit/components/models/test_xai.py Co-authored-by: Christophe Bornet <cbornet@hotmail.com> * Update src/backend/tests/unit/components/models/test_xai.py Co-authored-by: Christophe Bornet <cbornet@hotmail.com> * [autofix.ci] apply automated fixes * Update xai.py * test: update test_xai to use MessageTextInput and base_url * fix: add missing component_class parameter to test_build_model_error --------- Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com> Co-authored-by: Christophe Bornet <cbornet@hotmail.com> Co-authored-by: Edwin Jose <edwin.jose@datastax.com>
This commit is contained in:
parent
770c1b3528
commit
e970cdbca3
7 changed files with 387 additions and 0 deletions
|
|
@ -19,6 +19,7 @@ from .openrouter import OpenRouterComponent
|
|||
from .perplexity import PerplexityComponent
|
||||
from .sambanova import SambaNovaComponent
|
||||
from .vertexai import ChatVertexAIComponent
|
||||
from .xai import XAIModelComponent
|
||||
|
||||
__all__ = [
|
||||
"AIMLModelComponent",
|
||||
|
|
@ -42,4 +43,5 @@ __all__ = [
|
|||
"PerplexityComponent",
|
||||
"QianfanChatEndpointComponent",
|
||||
"SambaNovaComponent",
|
||||
"XAIModelComponent",
|
||||
]
|
||||
|
|
|
|||
155
src/backend/base/langflow/components/models/xai.py
Normal file
155
src/backend/base/langflow/components/models/xai.py
Normal file
|
|
@ -0,0 +1,155 @@
|
|||
import requests
|
||||
from langchain_openai import ChatOpenAI
|
||||
from pydantic.v1 import SecretStr
|
||||
from typing_extensions import override
|
||||
|
||||
from langflow.base.models.model import LCModelComponent
|
||||
from langflow.field_typing import LanguageModel
|
||||
from langflow.field_typing.range_spec import RangeSpec
|
||||
from langflow.inputs import BoolInput, DictInput, DropdownInput, IntInput, MessageTextInput, SecretStrInput, SliderInput
|
||||
|
||||
XAI_DEFAULT_MODELS = ["grok-2-latest"]
|
||||
|
||||
|
||||
class XAIModelComponent(LCModelComponent):
|
||||
display_name = "xAI"
|
||||
description = "Generates text using xAI models like Grok."
|
||||
icon = "xAI"
|
||||
name = "xAIModel"
|
||||
|
||||
inputs = [
|
||||
*LCModelComponent._base_inputs,
|
||||
IntInput(
|
||||
name="max_tokens",
|
||||
display_name="Max Tokens",
|
||||
advanced=True,
|
||||
info="The maximum number of tokens to generate. Set to 0 for unlimited tokens.",
|
||||
range_spec=RangeSpec(min=0, max=128000),
|
||||
),
|
||||
DictInput(
|
||||
name="model_kwargs",
|
||||
display_name="Model Kwargs",
|
||||
advanced=True,
|
||||
info="Additional keyword arguments to pass to the model.",
|
||||
),
|
||||
BoolInput(
|
||||
name="json_mode",
|
||||
display_name="JSON Mode",
|
||||
advanced=True,
|
||||
info="If True, it will output JSON regardless of passing a schema.",
|
||||
),
|
||||
DropdownInput(
|
||||
name="model_name",
|
||||
display_name="Model Name",
|
||||
advanced=False,
|
||||
options=XAI_DEFAULT_MODELS,
|
||||
value=XAI_DEFAULT_MODELS[0],
|
||||
refresh_button=True,
|
||||
combobox=True,
|
||||
info="The xAI model to use",
|
||||
),
|
||||
MessageTextInput(
|
||||
name="base_url",
|
||||
display_name="xAI API Base",
|
||||
advanced=True,
|
||||
info="The base URL of the xAI API. Defaults to https://api.x.ai/v1",
|
||||
value="https://api.x.ai/v1",
|
||||
),
|
||||
SecretStrInput(
|
||||
name="api_key",
|
||||
display_name="xAI API Key",
|
||||
info="The xAI API Key to use for the model.",
|
||||
advanced=False,
|
||||
value="XAI_API_KEY",
|
||||
required=True,
|
||||
),
|
||||
SliderInput(
|
||||
name="temperature", display_name="Temperature", value=0.1, range_spec=RangeSpec(min=0, max=2, step=0.01)
|
||||
),
|
||||
IntInput(
|
||||
name="seed",
|
||||
display_name="Seed",
|
||||
info="The seed controls the reproducibility of the job.",
|
||||
advanced=True,
|
||||
value=1,
|
||||
),
|
||||
]
|
||||
|
||||
def get_models(self) -> list[str]:
|
||||
"""Fetch available models from xAI API."""
|
||||
if not self.api_key:
|
||||
return XAI_DEFAULT_MODELS
|
||||
|
||||
base_url = self.base_url or "https://api.x.ai/v1"
|
||||
url = f"{base_url}/language-models"
|
||||
headers = {"Authorization": f"Bearer {self.api_key}", "Accept": "application/json"}
|
||||
|
||||
try:
|
||||
response = requests.get(url, headers=headers, timeout=10)
|
||||
response.raise_for_status()
|
||||
data = response.json()
|
||||
|
||||
# Extract model IDs and any aliases
|
||||
models = set()
|
||||
for model in data.get("models", []):
|
||||
models.add(model["id"])
|
||||
models.update(model.get("aliases", []))
|
||||
|
||||
return sorted(models) if models else XAI_DEFAULT_MODELS
|
||||
except requests.RequestException as e:
|
||||
self.status = f"Error fetching models: {e}"
|
||||
return XAI_DEFAULT_MODELS
|
||||
|
||||
@override
|
||||
def update_build_config(self, build_config: dict, field_value: str, field_name: str | None = None):
|
||||
"""Update build configuration with fresh model list when key fields change."""
|
||||
if field_name in {"api_key", "base_url", "model_name"}:
|
||||
models = self.get_models()
|
||||
build_config["model_name"]["options"] = models
|
||||
return build_config
|
||||
|
||||
def build_model(self) -> LanguageModel:
|
||||
api_key = self.api_key
|
||||
temperature = self.temperature
|
||||
model_name: str = self.model_name
|
||||
max_tokens = self.max_tokens
|
||||
model_kwargs = self.model_kwargs or {}
|
||||
base_url = self.base_url or "https://api.x.ai/v1"
|
||||
json_mode = self.json_mode
|
||||
seed = self.seed
|
||||
|
||||
api_key = SecretStr(api_key).get_secret_value() if api_key else None
|
||||
|
||||
output = ChatOpenAI(
|
||||
max_tokens=max_tokens or None,
|
||||
model_kwargs=model_kwargs,
|
||||
model=model_name,
|
||||
base_url=base_url,
|
||||
api_key=api_key,
|
||||
temperature=temperature if temperature is not None else 0.1,
|
||||
seed=seed,
|
||||
)
|
||||
|
||||
if json_mode:
|
||||
output = output.bind(response_format={"type": "json_object"})
|
||||
|
||||
return output
|
||||
|
||||
def _get_exception_message(self, e: Exception):
|
||||
"""Get a message from an xAI exception.
|
||||
|
||||
Args:
|
||||
e (Exception): The exception to get the message from.
|
||||
|
||||
Returns:
|
||||
str: The message from the exception.
|
||||
"""
|
||||
try:
|
||||
from openai import BadRequestError
|
||||
except ImportError:
|
||||
return None
|
||||
if isinstance(e, BadRequestError):
|
||||
message = e.body.get("message")
|
||||
if message:
|
||||
return message
|
||||
return None
|
||||
198
src/backend/tests/unit/components/models/test_xai.py
Normal file
198
src/backend/tests/unit/components/models/test_xai.py
Normal file
|
|
@ -0,0 +1,198 @@
|
|||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import pytest
|
||||
from langflow.components.models import XAIModelComponent
|
||||
from langflow.custom import Component
|
||||
from langflow.custom.utils import build_custom_component_template
|
||||
from langflow.inputs import (
|
||||
BoolInput,
|
||||
DictInput,
|
||||
DropdownInput,
|
||||
IntInput,
|
||||
MessageTextInput,
|
||||
SecretStrInput,
|
||||
SliderInput,
|
||||
)
|
||||
|
||||
from tests.base import ComponentTestBaseWithoutClient
|
||||
|
||||
|
||||
class TestXAIComponent(ComponentTestBaseWithoutClient):
|
||||
@pytest.fixture
|
||||
def component_class(self):
|
||||
return XAIModelComponent
|
||||
|
||||
@pytest.fixture
|
||||
def default_kwargs(self):
|
||||
return {
|
||||
"temperature": 0.1,
|
||||
"max_tokens": 50,
|
||||
"api_key": "dummy-key",
|
||||
"model_name": "grok-2-latest",
|
||||
"model_kwargs": {},
|
||||
"base_url": "https://api.x.ai/v1",
|
||||
"seed": 42,
|
||||
}
|
||||
|
||||
@pytest.fixture
|
||||
def file_names_mapping(self):
|
||||
return []
|
||||
|
||||
def test_initialization(self, component_class):
|
||||
component = component_class()
|
||||
assert component.display_name == "xAI"
|
||||
assert component.description == "Generates text using xAI models like Grok."
|
||||
assert component.icon == "xAI"
|
||||
assert component.name == "xAIModel"
|
||||
|
||||
def test_template(self, default_kwargs):
|
||||
component = XAIModelComponent(**default_kwargs)
|
||||
comp = Component(_code=component._code)
|
||||
frontend_node, _ = build_custom_component_template(comp)
|
||||
assert isinstance(frontend_node, dict)
|
||||
assert "template" in frontend_node
|
||||
input_names = [inp["name"] for inp in frontend_node["template"].values() if isinstance(inp, dict)]
|
||||
expected_inputs = [
|
||||
"max_tokens",
|
||||
"model_kwargs",
|
||||
"json_mode",
|
||||
"model_name",
|
||||
"base_url",
|
||||
"api_key",
|
||||
"temperature",
|
||||
"seed",
|
||||
]
|
||||
for input_name in expected_inputs:
|
||||
assert input_name in input_names
|
||||
|
||||
def test_inputs(self):
|
||||
component = XAIModelComponent()
|
||||
inputs = component.inputs
|
||||
expected_inputs = {
|
||||
"max_tokens": IntInput,
|
||||
"model_kwargs": DictInput,
|
||||
"json_mode": BoolInput,
|
||||
"model_name": DropdownInput,
|
||||
"base_url": MessageTextInput,
|
||||
"api_key": SecretStrInput,
|
||||
"temperature": SliderInput,
|
||||
"seed": IntInput,
|
||||
}
|
||||
for name, input_type in expected_inputs.items():
|
||||
matching_inputs = [inp for inp in inputs if isinstance(inp, input_type) and inp.name == name]
|
||||
assert matching_inputs, f"Missing or incorrect input: {name}"
|
||||
if name == "model_name":
|
||||
input_field = matching_inputs[0]
|
||||
assert input_field.value == "grok-2-latest"
|
||||
assert input_field.refresh_button is True
|
||||
elif name == "temperature":
|
||||
input_field = matching_inputs[0]
|
||||
assert input_field.value == 0.1
|
||||
assert input_field.range_spec.min == 0
|
||||
assert input_field.range_spec.max == 2
|
||||
|
||||
def test_build_model(self, component_class, default_kwargs, mocker):
|
||||
component = component_class(**default_kwargs)
|
||||
component.temperature = 0.7
|
||||
component.max_tokens = 100
|
||||
component.api_key = "test-key"
|
||||
component.model_name = "grok-2-latest"
|
||||
component.model_kwargs = {}
|
||||
component.base_url = "https://api.x.ai/v1"
|
||||
component.seed = 1
|
||||
|
||||
mock_chat_openai = mocker.patch("langflow.components.models.xai.ChatOpenAI", return_value=MagicMock())
|
||||
model = component.build_model()
|
||||
mock_chat_openai.assert_called_once_with(
|
||||
max_tokens=100,
|
||||
model_kwargs={},
|
||||
model="grok-2-latest",
|
||||
base_url="https://api.x.ai/v1",
|
||||
api_key="test-key",
|
||||
temperature=0.7,
|
||||
seed=1,
|
||||
)
|
||||
assert model == mock_chat_openai.return_value
|
||||
|
||||
def test_get_models(self):
|
||||
component = XAIModelComponent()
|
||||
with patch("requests.get") as mock_get:
|
||||
mock_response = MagicMock()
|
||||
mock_response.json.return_value = {
|
||||
"models": [
|
||||
{"id": "grok-2-latest", "aliases": ["grok-2"]},
|
||||
{"id": "grok-1", "aliases": []},
|
||||
]
|
||||
}
|
||||
mock_get.return_value = mock_response
|
||||
|
||||
component.api_key = "test-key"
|
||||
models = component.get_models()
|
||||
assert sorted(models) == ["grok-1", "grok-2", "grok-2-latest"]
|
||||
mock_get.assert_called_once_with(
|
||||
"https://api.x.ai/v1/language-models",
|
||||
headers={
|
||||
"Authorization": "Bearer test-key",
|
||||
"Accept": "application/json",
|
||||
},
|
||||
timeout=10,
|
||||
)
|
||||
|
||||
def test_get_models_no_api_key(self):
|
||||
component = XAIModelComponent(api_key=None)
|
||||
models = component.get_models()
|
||||
assert models == ["grok-2-latest"]
|
||||
|
||||
def test_build_model_error(self, component_class, mocker):
|
||||
from openai import BadRequestError
|
||||
|
||||
component = component_class()
|
||||
component.api_key = "invalid-key"
|
||||
component.model_name = "grok-2-latest"
|
||||
component.temperature = 0.7
|
||||
component.max_tokens = 100
|
||||
component.model_kwargs = {}
|
||||
component.base_url = "https://api.x.ai/v1"
|
||||
component.seed = 1
|
||||
|
||||
mocker.patch(
|
||||
"langflow.components.models.xai.ChatOpenAI",
|
||||
side_effect=BadRequestError(
|
||||
message="Invalid API key",
|
||||
response=MagicMock(),
|
||||
body={"message": "Invalid API key provided"},
|
||||
),
|
||||
)
|
||||
with pytest.raises(BadRequestError) as exc_info:
|
||||
component.build_model()
|
||||
assert exc_info.value.body["message"] == "Invalid API key provided"
|
||||
|
||||
def test_json_mode(self, component_class, mocker):
|
||||
component = component_class()
|
||||
component.api_key = "test-key"
|
||||
component.json_mode = True
|
||||
component.temperature = 0.7
|
||||
component.max_tokens = 100
|
||||
component.model_name = "grok-2-latest"
|
||||
component.model_kwargs = {}
|
||||
component.base_url = "https://api.x.ai/v1"
|
||||
component.seed = 1
|
||||
|
||||
mock_instance = MagicMock()
|
||||
mock_bound_instance = MagicMock()
|
||||
mock_instance.bind.return_value = mock_bound_instance
|
||||
mocker.patch("langflow.components.models.xai.ChatOpenAI", return_value=mock_instance)
|
||||
|
||||
model = component.build_model()
|
||||
mock_instance.bind.assert_called_once_with(response_format={"type": "json_object"})
|
||||
assert model == mock_bound_instance
|
||||
|
||||
def test_update_build_config(self):
|
||||
component = XAIModelComponent()
|
||||
build_config = {"model_name": {"options": []}}
|
||||
|
||||
updated_config = component.update_build_config(build_config, "test-key", "api_key")
|
||||
assert "model_name" in updated_config
|
||||
|
||||
updated_config = component.update_build_config(build_config, "grok-2-latest", "model_name")
|
||||
assert "model_name" in updated_config
|
||||
Loading…
Add table
Add a link
Reference in a new issue