feat: add xAI integration (#6012)

* feat: add xAI integration

* fix: implement file_names_mapping fixture in test_xai.py aligning with test standards

* [autofix.ci] apply automated fixes

* [autofix.ci] apply automated fixes (attempt 2/3)

* fix: remove unused variable in test_xai

* refactor: update input types and variable naming

* [autofix.ci] apply automated fixes

* Update src/backend/tests/unit/components/models/test_xai.py

Co-authored-by: Christophe Bornet <cbornet@hotmail.com>

* Update src/backend/tests/unit/components/models/test_xai.py

Co-authored-by: Christophe Bornet <cbornet@hotmail.com>

* Update src/backend/tests/unit/components/models/test_xai.py

Co-authored-by: Christophe Bornet <cbornet@hotmail.com>

* Update src/backend/tests/unit/components/models/test_xai.py

Co-authored-by: Christophe Bornet <cbornet@hotmail.com>

* Update src/backend/tests/unit/components/models/test_xai.py

Co-authored-by: Christophe Bornet <cbornet@hotmail.com>

* Update src/backend/tests/unit/components/models/test_xai.py

Co-authored-by: Christophe Bornet <cbornet@hotmail.com>

* Update src/backend/tests/unit/components/models/test_xai.py

Co-authored-by: Christophe Bornet <cbornet@hotmail.com>

* Update src/backend/tests/unit/components/models/test_xai.py

Co-authored-by: Christophe Bornet <cbornet@hotmail.com>

* [autofix.ci] apply automated fixes

* Update xai.py

* test: update test_xai to use MessageTextInput and base_url

* fix: add missing component_class parameter to test_build_model_error

---------

Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com>
Co-authored-by: Christophe Bornet <cbornet@hotmail.com>
Co-authored-by: Edwin Jose <edwin.jose@datastax.com>
This commit is contained in:
Raphael Valdetaro 2025-02-19 11:28:46 -03:00 committed by GitHub
commit e970cdbca3
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
7 changed files with 387 additions and 0 deletions

View file

@ -19,6 +19,7 @@ from .openrouter import OpenRouterComponent
from .perplexity import PerplexityComponent
from .sambanova import SambaNovaComponent
from .vertexai import ChatVertexAIComponent
from .xai import XAIModelComponent
__all__ = [
"AIMLModelComponent",
@ -42,4 +43,5 @@ __all__ = [
"PerplexityComponent",
"QianfanChatEndpointComponent",
"SambaNovaComponent",
"XAIModelComponent",
]

View file

@ -0,0 +1,155 @@
import requests
from langchain_openai import ChatOpenAI
from pydantic.v1 import SecretStr
from typing_extensions import override
from langflow.base.models.model import LCModelComponent
from langflow.field_typing import LanguageModel
from langflow.field_typing.range_spec import RangeSpec
from langflow.inputs import BoolInput, DictInput, DropdownInput, IntInput, MessageTextInput, SecretStrInput, SliderInput
XAI_DEFAULT_MODELS = ["grok-2-latest"]
class XAIModelComponent(LCModelComponent):
display_name = "xAI"
description = "Generates text using xAI models like Grok."
icon = "xAI"
name = "xAIModel"
inputs = [
*LCModelComponent._base_inputs,
IntInput(
name="max_tokens",
display_name="Max Tokens",
advanced=True,
info="The maximum number of tokens to generate. Set to 0 for unlimited tokens.",
range_spec=RangeSpec(min=0, max=128000),
),
DictInput(
name="model_kwargs",
display_name="Model Kwargs",
advanced=True,
info="Additional keyword arguments to pass to the model.",
),
BoolInput(
name="json_mode",
display_name="JSON Mode",
advanced=True,
info="If True, it will output JSON regardless of passing a schema.",
),
DropdownInput(
name="model_name",
display_name="Model Name",
advanced=False,
options=XAI_DEFAULT_MODELS,
value=XAI_DEFAULT_MODELS[0],
refresh_button=True,
combobox=True,
info="The xAI model to use",
),
MessageTextInput(
name="base_url",
display_name="xAI API Base",
advanced=True,
info="The base URL of the xAI API. Defaults to https://api.x.ai/v1",
value="https://api.x.ai/v1",
),
SecretStrInput(
name="api_key",
display_name="xAI API Key",
info="The xAI API Key to use for the model.",
advanced=False,
value="XAI_API_KEY",
required=True,
),
SliderInput(
name="temperature", display_name="Temperature", value=0.1, range_spec=RangeSpec(min=0, max=2, step=0.01)
),
IntInput(
name="seed",
display_name="Seed",
info="The seed controls the reproducibility of the job.",
advanced=True,
value=1,
),
]
def get_models(self) -> list[str]:
"""Fetch available models from xAI API."""
if not self.api_key:
return XAI_DEFAULT_MODELS
base_url = self.base_url or "https://api.x.ai/v1"
url = f"{base_url}/language-models"
headers = {"Authorization": f"Bearer {self.api_key}", "Accept": "application/json"}
try:
response = requests.get(url, headers=headers, timeout=10)
response.raise_for_status()
data = response.json()
# Extract model IDs and any aliases
models = set()
for model in data.get("models", []):
models.add(model["id"])
models.update(model.get("aliases", []))
return sorted(models) if models else XAI_DEFAULT_MODELS
except requests.RequestException as e:
self.status = f"Error fetching models: {e}"
return XAI_DEFAULT_MODELS
@override
def update_build_config(self, build_config: dict, field_value: str, field_name: str | None = None):
"""Update build configuration with fresh model list when key fields change."""
if field_name in {"api_key", "base_url", "model_name"}:
models = self.get_models()
build_config["model_name"]["options"] = models
return build_config
def build_model(self) -> LanguageModel:
api_key = self.api_key
temperature = self.temperature
model_name: str = self.model_name
max_tokens = self.max_tokens
model_kwargs = self.model_kwargs or {}
base_url = self.base_url or "https://api.x.ai/v1"
json_mode = self.json_mode
seed = self.seed
api_key = SecretStr(api_key).get_secret_value() if api_key else None
output = ChatOpenAI(
max_tokens=max_tokens or None,
model_kwargs=model_kwargs,
model=model_name,
base_url=base_url,
api_key=api_key,
temperature=temperature if temperature is not None else 0.1,
seed=seed,
)
if json_mode:
output = output.bind(response_format={"type": "json_object"})
return output
def _get_exception_message(self, e: Exception):
"""Get a message from an xAI exception.
Args:
e (Exception): The exception to get the message from.
Returns:
str: The message from the exception.
"""
try:
from openai import BadRequestError
except ImportError:
return None
if isinstance(e, BadRequestError):
message = e.body.get("message")
if message:
return message
return None