feat: add xAI integration (#6012)

* feat: add xAI integration

* fix: implement file_names_mapping fixture in test_xai.py aligning with test standards

* [autofix.ci] apply automated fixes

* [autofix.ci] apply automated fixes (attempt 2/3)

* fix: remove unused variable in test_xai

* refactor: update input types and variable naming

* [autofix.ci] apply automated fixes

* Update src/backend/tests/unit/components/models/test_xai.py

Co-authored-by: Christophe Bornet <cbornet@hotmail.com>

* Update src/backend/tests/unit/components/models/test_xai.py

Co-authored-by: Christophe Bornet <cbornet@hotmail.com>

* Update src/backend/tests/unit/components/models/test_xai.py

Co-authored-by: Christophe Bornet <cbornet@hotmail.com>

* Update src/backend/tests/unit/components/models/test_xai.py

Co-authored-by: Christophe Bornet <cbornet@hotmail.com>

* Update src/backend/tests/unit/components/models/test_xai.py

Co-authored-by: Christophe Bornet <cbornet@hotmail.com>

* Update src/backend/tests/unit/components/models/test_xai.py

Co-authored-by: Christophe Bornet <cbornet@hotmail.com>

* Update src/backend/tests/unit/components/models/test_xai.py

Co-authored-by: Christophe Bornet <cbornet@hotmail.com>

* Update src/backend/tests/unit/components/models/test_xai.py

Co-authored-by: Christophe Bornet <cbornet@hotmail.com>

* [autofix.ci] apply automated fixes

* Update xai.py

* test: update test_xai to use MessageTextInput and base_url

* fix: add missing component_class parameter to test_build_model_error

---------

Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com>
Co-authored-by: Christophe Bornet <cbornet@hotmail.com>
Co-authored-by: Edwin Jose <edwin.jose@datastax.com>
This commit is contained in:
Raphael Valdetaro 2025-02-19 11:28:46 -03:00 committed by GitHub
commit e970cdbca3
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
7 changed files with 387 additions and 0 deletions

View file

@ -19,6 +19,7 @@ from .openrouter import OpenRouterComponent
from .perplexity import PerplexityComponent
from .sambanova import SambaNovaComponent
from .vertexai import ChatVertexAIComponent
from .xai import XAIModelComponent
__all__ = [
"AIMLModelComponent",
@ -42,4 +43,5 @@ __all__ = [
"PerplexityComponent",
"QianfanChatEndpointComponent",
"SambaNovaComponent",
"XAIModelComponent",
]

View file

@ -0,0 +1,155 @@
import requests
from langchain_openai import ChatOpenAI
from pydantic.v1 import SecretStr
from typing_extensions import override
from langflow.base.models.model import LCModelComponent
from langflow.field_typing import LanguageModel
from langflow.field_typing.range_spec import RangeSpec
from langflow.inputs import BoolInput, DictInput, DropdownInput, IntInput, MessageTextInput, SecretStrInput, SliderInput
XAI_DEFAULT_MODELS = ["grok-2-latest"]
class XAIModelComponent(LCModelComponent):
display_name = "xAI"
description = "Generates text using xAI models like Grok."
icon = "xAI"
name = "xAIModel"
inputs = [
*LCModelComponent._base_inputs,
IntInput(
name="max_tokens",
display_name="Max Tokens",
advanced=True,
info="The maximum number of tokens to generate. Set to 0 for unlimited tokens.",
range_spec=RangeSpec(min=0, max=128000),
),
DictInput(
name="model_kwargs",
display_name="Model Kwargs",
advanced=True,
info="Additional keyword arguments to pass to the model.",
),
BoolInput(
name="json_mode",
display_name="JSON Mode",
advanced=True,
info="If True, it will output JSON regardless of passing a schema.",
),
DropdownInput(
name="model_name",
display_name="Model Name",
advanced=False,
options=XAI_DEFAULT_MODELS,
value=XAI_DEFAULT_MODELS[0],
refresh_button=True,
combobox=True,
info="The xAI model to use",
),
MessageTextInput(
name="base_url",
display_name="xAI API Base",
advanced=True,
info="The base URL of the xAI API. Defaults to https://api.x.ai/v1",
value="https://api.x.ai/v1",
),
SecretStrInput(
name="api_key",
display_name="xAI API Key",
info="The xAI API Key to use for the model.",
advanced=False,
value="XAI_API_KEY",
required=True,
),
SliderInput(
name="temperature", display_name="Temperature", value=0.1, range_spec=RangeSpec(min=0, max=2, step=0.01)
),
IntInput(
name="seed",
display_name="Seed",
info="The seed controls the reproducibility of the job.",
advanced=True,
value=1,
),
]
def get_models(self) -> list[str]:
"""Fetch available models from xAI API."""
if not self.api_key:
return XAI_DEFAULT_MODELS
base_url = self.base_url or "https://api.x.ai/v1"
url = f"{base_url}/language-models"
headers = {"Authorization": f"Bearer {self.api_key}", "Accept": "application/json"}
try:
response = requests.get(url, headers=headers, timeout=10)
response.raise_for_status()
data = response.json()
# Extract model IDs and any aliases
models = set()
for model in data.get("models", []):
models.add(model["id"])
models.update(model.get("aliases", []))
return sorted(models) if models else XAI_DEFAULT_MODELS
except requests.RequestException as e:
self.status = f"Error fetching models: {e}"
return XAI_DEFAULT_MODELS
@override
def update_build_config(self, build_config: dict, field_value: str, field_name: str | None = None):
"""Update build configuration with fresh model list when key fields change."""
if field_name in {"api_key", "base_url", "model_name"}:
models = self.get_models()
build_config["model_name"]["options"] = models
return build_config
def build_model(self) -> LanguageModel:
api_key = self.api_key
temperature = self.temperature
model_name: str = self.model_name
max_tokens = self.max_tokens
model_kwargs = self.model_kwargs or {}
base_url = self.base_url or "https://api.x.ai/v1"
json_mode = self.json_mode
seed = self.seed
api_key = SecretStr(api_key).get_secret_value() if api_key else None
output = ChatOpenAI(
max_tokens=max_tokens or None,
model_kwargs=model_kwargs,
model=model_name,
base_url=base_url,
api_key=api_key,
temperature=temperature if temperature is not None else 0.1,
seed=seed,
)
if json_mode:
output = output.bind(response_format={"type": "json_object"})
return output
def _get_exception_message(self, e: Exception):
"""Get a message from an xAI exception.
Args:
e (Exception): The exception to get the message from.
Returns:
str: The message from the exception.
"""
try:
from openai import BadRequestError
except ImportError:
return None
if isinstance(e, BadRequestError):
message = e.body.get("message")
if message:
return message
return None

View file

@ -0,0 +1,198 @@
from unittest.mock import MagicMock, patch
import pytest
from langflow.components.models import XAIModelComponent
from langflow.custom import Component
from langflow.custom.utils import build_custom_component_template
from langflow.inputs import (
BoolInput,
DictInput,
DropdownInput,
IntInput,
MessageTextInput,
SecretStrInput,
SliderInput,
)
from tests.base import ComponentTestBaseWithoutClient
class TestXAIComponent(ComponentTestBaseWithoutClient):
@pytest.fixture
def component_class(self):
return XAIModelComponent
@pytest.fixture
def default_kwargs(self):
return {
"temperature": 0.1,
"max_tokens": 50,
"api_key": "dummy-key",
"model_name": "grok-2-latest",
"model_kwargs": {},
"base_url": "https://api.x.ai/v1",
"seed": 42,
}
@pytest.fixture
def file_names_mapping(self):
return []
def test_initialization(self, component_class):
component = component_class()
assert component.display_name == "xAI"
assert component.description == "Generates text using xAI models like Grok."
assert component.icon == "xAI"
assert component.name == "xAIModel"
def test_template(self, default_kwargs):
component = XAIModelComponent(**default_kwargs)
comp = Component(_code=component._code)
frontend_node, _ = build_custom_component_template(comp)
assert isinstance(frontend_node, dict)
assert "template" in frontend_node
input_names = [inp["name"] for inp in frontend_node["template"].values() if isinstance(inp, dict)]
expected_inputs = [
"max_tokens",
"model_kwargs",
"json_mode",
"model_name",
"base_url",
"api_key",
"temperature",
"seed",
]
for input_name in expected_inputs:
assert input_name in input_names
def test_inputs(self):
component = XAIModelComponent()
inputs = component.inputs
expected_inputs = {
"max_tokens": IntInput,
"model_kwargs": DictInput,
"json_mode": BoolInput,
"model_name": DropdownInput,
"base_url": MessageTextInput,
"api_key": SecretStrInput,
"temperature": SliderInput,
"seed": IntInput,
}
for name, input_type in expected_inputs.items():
matching_inputs = [inp for inp in inputs if isinstance(inp, input_type) and inp.name == name]
assert matching_inputs, f"Missing or incorrect input: {name}"
if name == "model_name":
input_field = matching_inputs[0]
assert input_field.value == "grok-2-latest"
assert input_field.refresh_button is True
elif name == "temperature":
input_field = matching_inputs[0]
assert input_field.value == 0.1
assert input_field.range_spec.min == 0
assert input_field.range_spec.max == 2
def test_build_model(self, component_class, default_kwargs, mocker):
component = component_class(**default_kwargs)
component.temperature = 0.7
component.max_tokens = 100
component.api_key = "test-key"
component.model_name = "grok-2-latest"
component.model_kwargs = {}
component.base_url = "https://api.x.ai/v1"
component.seed = 1
mock_chat_openai = mocker.patch("langflow.components.models.xai.ChatOpenAI", return_value=MagicMock())
model = component.build_model()
mock_chat_openai.assert_called_once_with(
max_tokens=100,
model_kwargs={},
model="grok-2-latest",
base_url="https://api.x.ai/v1",
api_key="test-key",
temperature=0.7,
seed=1,
)
assert model == mock_chat_openai.return_value
def test_get_models(self):
component = XAIModelComponent()
with patch("requests.get") as mock_get:
mock_response = MagicMock()
mock_response.json.return_value = {
"models": [
{"id": "grok-2-latest", "aliases": ["grok-2"]},
{"id": "grok-1", "aliases": []},
]
}
mock_get.return_value = mock_response
component.api_key = "test-key"
models = component.get_models()
assert sorted(models) == ["grok-1", "grok-2", "grok-2-latest"]
mock_get.assert_called_once_with(
"https://api.x.ai/v1/language-models",
headers={
"Authorization": "Bearer test-key",
"Accept": "application/json",
},
timeout=10,
)
def test_get_models_no_api_key(self):
component = XAIModelComponent(api_key=None)
models = component.get_models()
assert models == ["grok-2-latest"]
def test_build_model_error(self, component_class, mocker):
from openai import BadRequestError
component = component_class()
component.api_key = "invalid-key"
component.model_name = "grok-2-latest"
component.temperature = 0.7
component.max_tokens = 100
component.model_kwargs = {}
component.base_url = "https://api.x.ai/v1"
component.seed = 1
mocker.patch(
"langflow.components.models.xai.ChatOpenAI",
side_effect=BadRequestError(
message="Invalid API key",
response=MagicMock(),
body={"message": "Invalid API key provided"},
),
)
with pytest.raises(BadRequestError) as exc_info:
component.build_model()
assert exc_info.value.body["message"] == "Invalid API key provided"
def test_json_mode(self, component_class, mocker):
component = component_class()
component.api_key = "test-key"
component.json_mode = True
component.temperature = 0.7
component.max_tokens = 100
component.model_name = "grok-2-latest"
component.model_kwargs = {}
component.base_url = "https://api.x.ai/v1"
component.seed = 1
mock_instance = MagicMock()
mock_bound_instance = MagicMock()
mock_instance.bind.return_value = mock_bound_instance
mocker.patch("langflow.components.models.xai.ChatOpenAI", return_value=mock_instance)
model = component.build_model()
mock_instance.bind.assert_called_once_with(response_format={"type": "json_object"})
assert model == mock_bound_instance
def test_update_build_config(self):
component = XAIModelComponent()
build_config = {"model_name": {"options": []}}
updated_config = component.update_build_config(build_config, "test-key", "api_key")
assert "model_name" in updated_config
updated_config = component.update_build_config(build_config, "grok-2-latest", "model_name")
assert "model_name" in updated_config

View file

@ -0,0 +1,10 @@
import { useDarkStore } from "@/stores/darkStore";
import React, { forwardRef } from "react";
import XAISVG from "./xAIIcon.jsx";
export const XAIIcon = forwardRef<SVGSVGElement, React.PropsWithChildren<{}>>(
(props, ref) => {
const isdark = useDarkStore((state) => state.dark).toString();
return <XAISVG ref={ref} isdark={isdark} {...props} />;
},
);

View file

@ -0,0 +1,19 @@
import { stringToBool } from "@/utils/utils";
const XAISVG = (props) => (
<svg
{...props}
fill={stringToBool(props.isdark) ? "#ffffff" : "#0A0A0A"}
fillRule="evenodd"
height="1em"
style={{ flex: "none", lineHeight: 1 }}
viewBox="0 0 24 24"
width="1em"
xmlns="http://www.w3.org/2000/svg"
>
<title>Grok</title>
<path d="M6.469 8.776L16.512 23h-4.464L2.005 8.776H6.47zm-.004 7.9l2.233 3.164L6.467 23H2l4.465-6.324zM22 2.582V23h-3.659V7.764L22 2.582zM22 1l-9.952 14.095-2.233-3.163L17.533 1H22z" />
</svg>
);
export default XAISVG;

View file

@ -0,0 +1 @@
<svg fill="currentColor" fill-rule="evenodd" height="1em" style="flex:none;line-height:1" viewBox="0 0 24 24" width="1em" xmlns="http://www.w3.org/2000/svg"><title>Grok</title><path d="M6.469 8.776L16.512 23h-4.464L2.005 8.776H6.47zm-.004 7.9l2.233 3.164L6.467 23H2l4.465-6.324zM22 2.582V23h-3.659V7.764L22 2.582zM22 1l-9.952 14.095-2.233-3.163L17.533 1H22z"></path></svg>

After

Width:  |  Height:  |  Size: 372 B

View file

@ -316,6 +316,7 @@ import SvgWolfram from "../icons/Wolfram/Wolfram";
import { HackerNewsIcon } from "../icons/hackerNews";
import { MistralIcon } from "../icons/mistral";
import { SupabaseIcon } from "../icons/supabase";
import { XAIIcon } from "../icons/xAI";
import { iconsType } from "../types/components";
export const BG_NOISE =
"url(data:image/png;base64,iVBORw0KGgoAAAANSUhEUgAAADIAAAAyCAMAAAAp4XiDAAAAUVBMVEWFhYWDg4N3d3dtbW17e3t1dXWBgYGHh4d5eXlzc3OLi4ubm5uVlZWPj4+NjY19fX2JiYl/f39ra2uRkZGZmZlpaWmXl5dvb29xcXGTk5NnZ2c8TV1mAAAAG3RSTlNAQEBAQEBAQEBAQEBAQEBAQEBAQEBAQEBAQEAvEOwtAAAFVklEQVR4XpWWB67c2BUFb3g557T/hRo9/WUMZHlgr4Bg8Z4qQgQJlHI4A8SzFVrapvmTF9O7dmYRFZ60YiBhJRCgh1FYhiLAmdvX0CzTOpNE77ME0Zty/nWWzchDtiqrmQDeuv3powQ5ta2eN0FY0InkqDD73lT9c9lEzwUNqgFHs9VQce3TVClFCQrSTfOiYkVJQBmpbq2L6iZavPnAPcoU0dSw0SUTqz/GtrGuXfbyyBniKykOWQWGqwwMA7QiYAxi+IlPdqo+hYHnUt5ZPfnsHJyNiDtnpJyayNBkF6cWoYGAMY92U2hXHF/C1M8uP/ZtYdiuj26UdAdQQSXQErwSOMzt/XWRWAz5GuSBIkwG1H3FabJ2OsUOUhGC6tK4EMtJO0ttC6IBD3kM0ve0tJwMdSfjZo+EEISaeTr9P3wYrGjXqyC1krcKdhMpxEnt5JetoulscpyzhXN5FRpuPHvbeQaKxFAEB6EN+cYN6xD7RYGpXpNndMmZgM5Dcs3YSNFDHUo2LGfZuukSWyUYirJAdYbF3MfqEKmjM+I2EfhA94iG3L7uKrR+GdWD73ydlIB+6hgref1QTlmgmbM3/LeX5GI1Ux1RWpgxpLuZ2+I+IjzZ8wqE4nilvQdkUdfhzI5QDWy+kw5Wgg2pGpeEVeCCA7b85BO3F9DzxB3cdqvBzWcmzbyMiqhzuYqtHRVG2y4x+KOlnyqla8AoWWpuBoYRxzXrfKuILl6SfiWCbjxoZJUaCBj1CjH7GIaDbc9kqBY3W/Rgjda1iqQcOJu2WW+76pZC9QG7M00dffe9hNnseupFL53r8F7YHSwJWUKP2q+k7RdsxyOB11n0xtOvnW4irMMFNV4H0uqwS5ExsmP9AxbDTc9JwgneAT5vTiUSm1E7BSflSt3bfa1tv8Di3R8n3Af7MNWzs49hmauE2wP+ttrq+AsWpFG2awvsuOqbipWHgtuvuaAE+A1Z/7gC9hesnr+7wqCwG8c5yAg3AL1fm8T9AZtp/bbJGwl1pNrE7RuOX7PeMRUERVaPpEs+yqeoSmuOlokqw49pgomjLeh7icHNlG19yjs6XXOMedYm5xH2YxpV2tc0Ro2jJfxC50ApuxGob7lMsxfTbeUv07TyYxpeLucEH1gNd4IKH2LAg5TdVhlCafZvpskfncCfx8pOhJzd76bJWeYFnFciwcYfubRc12Ip/ppIhA1/mSZ/RxjFDrJC5xifFjJpY2Xl5zXdguFqYyTR1zSp1Y9p+tktDYYSNflcxI0iyO4TPBdlRcpeqjK/piF5bklq77VSEaA+z8qmJTFzIWiitbnzR794USKBUaT0NTEsVjZqLaFVqJoPN9ODG70IPbfBHKK+/q/AWR0tJzYHRULOa4MP+W/HfGadZUbfw177G7j/OGbIs8TahLyynl4X4RinF793Oz+BU0saXtUHrVBFT/DnA3ctNPoGbs4hRIjTok8i+algT1lTHi4SxFvONKNrgQFAq2/gFnWMXgwffgYMJpiKYkmW3tTg3ZQ9Jq+f8XN+A5eeUKHWvJWJ2sgJ1Sop+wwhqFVijqWaJhwtD8MNlSBeWNNWTa5Z5kPZw5+LbVT99wqTdx29lMUH4OIG/D86ruKEauBjvH5xy6um/Sfj7ei6UUVk4AIl3MyD4MSSTOFgSwsH/QJWaQ5as7ZcmgBZkzjjU1UrQ74ci1gWBCSGHtuV1H2mhSnO3Wp/3fEV5a+4wz//6qy8JxjZsmxxy5+4w9CDNJY09T072iKG0EnOS0arEYgXqYnXcYHwjTtUNAcMelOd4xpkoqiTYICWFq0JSiPfPDQdnt+4/wuqcXY47QILbgAAAABJRU5ErkJggg==)";
@ -689,6 +690,7 @@ export const nodeIconsLucide: iconsType = {
OpenAI: OpenAiIcon,
OpenRouter: OpenRouterIcon,
DeepSeek: DeepSeekIcon,
xAI: XAIIcon,
OpenAIEmbeddings: OpenAiIcon,
Pinecone: PineconeIcon,
Qdrant: QDrantIcon,