diff --git a/src/backend/base/langflow/components/models/__init__.py b/src/backend/base/langflow/components/models/__init__.py index 1d214e939..b986b6f16 100644 --- a/src/backend/base/langflow/components/models/__init__.py +++ b/src/backend/base/langflow/components/models/__init__.py @@ -19,6 +19,7 @@ from .openrouter import OpenRouterComponent from .perplexity import PerplexityComponent from .sambanova import SambaNovaComponent from .vertexai import ChatVertexAIComponent +from .xai import XAIModelComponent __all__ = [ "AIMLModelComponent", @@ -42,4 +43,5 @@ __all__ = [ "PerplexityComponent", "QianfanChatEndpointComponent", "SambaNovaComponent", + "XAIModelComponent", ] diff --git a/src/backend/base/langflow/components/models/xai.py b/src/backend/base/langflow/components/models/xai.py new file mode 100644 index 000000000..bb6b04bee --- /dev/null +++ b/src/backend/base/langflow/components/models/xai.py @@ -0,0 +1,155 @@ +import requests +from langchain_openai import ChatOpenAI +from pydantic.v1 import SecretStr +from typing_extensions import override + +from langflow.base.models.model import LCModelComponent +from langflow.field_typing import LanguageModel +from langflow.field_typing.range_spec import RangeSpec +from langflow.inputs import BoolInput, DictInput, DropdownInput, IntInput, MessageTextInput, SecretStrInput, SliderInput + +XAI_DEFAULT_MODELS = ["grok-2-latest"] + + +class XAIModelComponent(LCModelComponent): + display_name = "xAI" + description = "Generates text using xAI models like Grok." + icon = "xAI" + name = "xAIModel" + + inputs = [ + *LCModelComponent._base_inputs, + IntInput( + name="max_tokens", + display_name="Max Tokens", + advanced=True, + info="The maximum number of tokens to generate. Set to 0 for unlimited tokens.", + range_spec=RangeSpec(min=0, max=128000), + ), + DictInput( + name="model_kwargs", + display_name="Model Kwargs", + advanced=True, + info="Additional keyword arguments to pass to the model.", + ), + BoolInput( + name="json_mode", + display_name="JSON Mode", + advanced=True, + info="If True, it will output JSON regardless of passing a schema.", + ), + DropdownInput( + name="model_name", + display_name="Model Name", + advanced=False, + options=XAI_DEFAULT_MODELS, + value=XAI_DEFAULT_MODELS[0], + refresh_button=True, + combobox=True, + info="The xAI model to use", + ), + MessageTextInput( + name="base_url", + display_name="xAI API Base", + advanced=True, + info="The base URL of the xAI API. Defaults to https://api.x.ai/v1", + value="https://api.x.ai/v1", + ), + SecretStrInput( + name="api_key", + display_name="xAI API Key", + info="The xAI API Key to use for the model.", + advanced=False, + value="XAI_API_KEY", + required=True, + ), + SliderInput( + name="temperature", display_name="Temperature", value=0.1, range_spec=RangeSpec(min=0, max=2, step=0.01) + ), + IntInput( + name="seed", + display_name="Seed", + info="The seed controls the reproducibility of the job.", + advanced=True, + value=1, + ), + ] + + def get_models(self) -> list[str]: + """Fetch available models from xAI API.""" + if not self.api_key: + return XAI_DEFAULT_MODELS + + base_url = self.base_url or "https://api.x.ai/v1" + url = f"{base_url}/language-models" + headers = {"Authorization": f"Bearer {self.api_key}", "Accept": "application/json"} + + try: + response = requests.get(url, headers=headers, timeout=10) + response.raise_for_status() + data = response.json() + + # Extract model IDs and any aliases + models = set() + for model in data.get("models", []): + models.add(model["id"]) + models.update(model.get("aliases", [])) + + return sorted(models) if models else XAI_DEFAULT_MODELS + except requests.RequestException as e: + self.status = f"Error fetching models: {e}" + return XAI_DEFAULT_MODELS + + @override + def update_build_config(self, build_config: dict, field_value: str, field_name: str | None = None): + """Update build configuration with fresh model list when key fields change.""" + if field_name in {"api_key", "base_url", "model_name"}: + models = self.get_models() + build_config["model_name"]["options"] = models + return build_config + + def build_model(self) -> LanguageModel: + api_key = self.api_key + temperature = self.temperature + model_name: str = self.model_name + max_tokens = self.max_tokens + model_kwargs = self.model_kwargs or {} + base_url = self.base_url or "https://api.x.ai/v1" + json_mode = self.json_mode + seed = self.seed + + api_key = SecretStr(api_key).get_secret_value() if api_key else None + + output = ChatOpenAI( + max_tokens=max_tokens or None, + model_kwargs=model_kwargs, + model=model_name, + base_url=base_url, + api_key=api_key, + temperature=temperature if temperature is not None else 0.1, + seed=seed, + ) + + if json_mode: + output = output.bind(response_format={"type": "json_object"}) + + return output + + def _get_exception_message(self, e: Exception): + """Get a message from an xAI exception. + + Args: + e (Exception): The exception to get the message from. + + Returns: + str: The message from the exception. + """ + try: + from openai import BadRequestError + except ImportError: + return None + if isinstance(e, BadRequestError): + message = e.body.get("message") + if message: + return message + return None diff --git a/src/backend/tests/unit/components/models/test_xai.py b/src/backend/tests/unit/components/models/test_xai.py new file mode 100644 index 000000000..7ba969bcc --- /dev/null +++ b/src/backend/tests/unit/components/models/test_xai.py @@ -0,0 +1,198 @@ +from unittest.mock import MagicMock, patch + +import pytest +from langflow.components.models import XAIModelComponent +from langflow.custom import Component +from langflow.custom.utils import build_custom_component_template +from langflow.inputs import ( + BoolInput, + DictInput, + DropdownInput, + IntInput, + MessageTextInput, + SecretStrInput, + SliderInput, +) + +from tests.base import ComponentTestBaseWithoutClient + + +class TestXAIComponent(ComponentTestBaseWithoutClient): + @pytest.fixture + def component_class(self): + return XAIModelComponent + + @pytest.fixture + def default_kwargs(self): + return { + "temperature": 0.1, + "max_tokens": 50, + "api_key": "dummy-key", + "model_name": "grok-2-latest", + "model_kwargs": {}, + "base_url": "https://api.x.ai/v1", + "seed": 42, + } + + @pytest.fixture + def file_names_mapping(self): + return [] + + def test_initialization(self, component_class): + component = component_class() + assert component.display_name == "xAI" + assert component.description == "Generates text using xAI models like Grok." + assert component.icon == "xAI" + assert component.name == "xAIModel" + + def test_template(self, default_kwargs): + component = XAIModelComponent(**default_kwargs) + comp = Component(_code=component._code) + frontend_node, _ = build_custom_component_template(comp) + assert isinstance(frontend_node, dict) + assert "template" in frontend_node + input_names = [inp["name"] for inp in frontend_node["template"].values() if isinstance(inp, dict)] + expected_inputs = [ + "max_tokens", + "model_kwargs", + "json_mode", + "model_name", + "base_url", + "api_key", + "temperature", + "seed", + ] + for input_name in expected_inputs: + assert input_name in input_names + + def test_inputs(self): + component = XAIModelComponent() + inputs = component.inputs + expected_inputs = { + "max_tokens": IntInput, + "model_kwargs": DictInput, + "json_mode": BoolInput, + "model_name": DropdownInput, + "base_url": MessageTextInput, + "api_key": SecretStrInput, + "temperature": SliderInput, + "seed": IntInput, + } + for name, input_type in expected_inputs.items(): + matching_inputs = [inp for inp in inputs if isinstance(inp, input_type) and inp.name == name] + assert matching_inputs, f"Missing or incorrect input: {name}" + if name == "model_name": + input_field = matching_inputs[0] + assert input_field.value == "grok-2-latest" + assert input_field.refresh_button is True + elif name == "temperature": + input_field = matching_inputs[0] + assert input_field.value == 0.1 + assert input_field.range_spec.min == 0 + assert input_field.range_spec.max == 2 + + def test_build_model(self, component_class, default_kwargs, mocker): + component = component_class(**default_kwargs) + component.temperature = 0.7 + component.max_tokens = 100 + component.api_key = "test-key" + component.model_name = "grok-2-latest" + component.model_kwargs = {} + component.base_url = "https://api.x.ai/v1" + component.seed = 1 + + mock_chat_openai = mocker.patch("langflow.components.models.xai.ChatOpenAI", return_value=MagicMock()) + model = component.build_model() + mock_chat_openai.assert_called_once_with( + max_tokens=100, + model_kwargs={}, + model="grok-2-latest", + base_url="https://api.x.ai/v1", + api_key="test-key", + temperature=0.7, + seed=1, + ) + assert model == mock_chat_openai.return_value + + def test_get_models(self): + component = XAIModelComponent() + with patch("requests.get") as mock_get: + mock_response = MagicMock() + mock_response.json.return_value = { + "models": [ + {"id": "grok-2-latest", "aliases": ["grok-2"]}, + {"id": "grok-1", "aliases": []}, + ] + } + mock_get.return_value = mock_response + + component.api_key = "test-key" + models = component.get_models() + assert sorted(models) == ["grok-1", "grok-2", "grok-2-latest"] + mock_get.assert_called_once_with( + "https://api.x.ai/v1/language-models", + headers={ + "Authorization": "Bearer test-key", + "Accept": "application/json", + }, + timeout=10, + ) + + def test_get_models_no_api_key(self): + component = XAIModelComponent(api_key=None) + models = component.get_models() + assert models == ["grok-2-latest"] + + def test_build_model_error(self, component_class, mocker): + from openai import BadRequestError + + component = component_class() + component.api_key = "invalid-key" + component.model_name = "grok-2-latest" + component.temperature = 0.7 + component.max_tokens = 100 + component.model_kwargs = {} + component.base_url = "https://api.x.ai/v1" + component.seed = 1 + + mocker.patch( + "langflow.components.models.xai.ChatOpenAI", + side_effect=BadRequestError( + message="Invalid API key", + response=MagicMock(), + body={"message": "Invalid API key provided"}, + ), + ) + with pytest.raises(BadRequestError) as exc_info: + component.build_model() + assert exc_info.value.body["message"] == "Invalid API key provided" + + def test_json_mode(self, component_class, mocker): + component = component_class() + component.api_key = "test-key" + component.json_mode = True + component.temperature = 0.7 + component.max_tokens = 100 + component.model_name = "grok-2-latest" + component.model_kwargs = {} + component.base_url = "https://api.x.ai/v1" + component.seed = 1 + + mock_instance = MagicMock() + mock_bound_instance = MagicMock() + mock_instance.bind.return_value = mock_bound_instance + mocker.patch("langflow.components.models.xai.ChatOpenAI", return_value=mock_instance) + + model = component.build_model() + mock_instance.bind.assert_called_once_with(response_format={"type": "json_object"}) + assert model == mock_bound_instance + + def test_update_build_config(self): + component = XAIModelComponent() + build_config = {"model_name": {"options": []}} + + updated_config = component.update_build_config(build_config, "test-key", "api_key") + assert "model_name" in updated_config + + updated_config = component.update_build_config(build_config, "grok-2-latest", "model_name") + assert "model_name" in updated_config diff --git a/src/frontend/src/icons/xAI/index.tsx b/src/frontend/src/icons/xAI/index.tsx new file mode 100644 index 000000000..219e33b94 --- /dev/null +++ b/src/frontend/src/icons/xAI/index.tsx @@ -0,0 +1,10 @@ +import { useDarkStore } from "@/stores/darkStore"; +import React, { forwardRef } from "react"; +import XAISVG from "./xAIIcon.jsx"; + +export const XAIIcon = forwardRef>( + (props, ref) => { + const isdark = useDarkStore((state) => state.dark).toString(); + return ; + }, +); diff --git a/src/frontend/src/icons/xAI/xAIIcon.jsx b/src/frontend/src/icons/xAI/xAIIcon.jsx new file mode 100644 index 000000000..24d398d16 --- /dev/null +++ b/src/frontend/src/icons/xAI/xAIIcon.jsx @@ -0,0 +1,19 @@ +import { stringToBool } from "@/utils/utils"; + +const XAISVG = (props) => ( + + Grok + + +); + +export default XAISVG; diff --git a/src/frontend/src/icons/xAI/xai.svg b/src/frontend/src/icons/xAI/xai.svg new file mode 100644 index 000000000..536e71390 --- /dev/null +++ b/src/frontend/src/icons/xAI/xai.svg @@ -0,0 +1 @@ +Grok \ No newline at end of file diff --git a/src/frontend/src/utils/styleUtils.ts b/src/frontend/src/utils/styleUtils.ts index 9783a6b2f..586c61f0d 100644 --- a/src/frontend/src/utils/styleUtils.ts +++ b/src/frontend/src/utils/styleUtils.ts @@ -316,6 +316,7 @@ import SvgWolfram from "../icons/Wolfram/Wolfram"; import { HackerNewsIcon } from "../icons/hackerNews"; import { MistralIcon } from "../icons/mistral"; import { SupabaseIcon } from "../icons/supabase"; +import { XAIIcon } from "../icons/xAI"; import { iconsType } from "../types/components"; export const BG_NOISE = "url(data:image/png;base64,iVBORw0KGgoAAAANSUhEUgAAADIAAAAyCAMAAAAp4XiDAAAAUVBMVEWFhYWDg4N3d3dtbW17e3t1dXWBgYGHh4d5eXlzc3OLi4ubm5uVlZWPj4+NjY19fX2JiYl/f39ra2uRkZGZmZlpaWmXl5dvb29xcXGTk5NnZ2c8TV1mAAAAG3RSTlNAQEBAQEBAQEBAQEBAQEBAQEBAQEBAQEBAQEAvEOwtAAAFVklEQVR4XpWWB67c2BUFb3g557T/hRo9/WUMZHlgr4Bg8Z4qQgQJlHI4A8SzFVrapvmTF9O7dmYRFZ60YiBhJRCgh1FYhiLAmdvX0CzTOpNE77ME0Zty/nWWzchDtiqrmQDeuv3powQ5ta2eN0FY0InkqDD73lT9c9lEzwUNqgFHs9VQce3TVClFCQrSTfOiYkVJQBmpbq2L6iZavPnAPcoU0dSw0SUTqz/GtrGuXfbyyBniKykOWQWGqwwMA7QiYAxi+IlPdqo+hYHnUt5ZPfnsHJyNiDtnpJyayNBkF6cWoYGAMY92U2hXHF/C1M8uP/ZtYdiuj26UdAdQQSXQErwSOMzt/XWRWAz5GuSBIkwG1H3FabJ2OsUOUhGC6tK4EMtJO0ttC6IBD3kM0ve0tJwMdSfjZo+EEISaeTr9P3wYrGjXqyC1krcKdhMpxEnt5JetoulscpyzhXN5FRpuPHvbeQaKxFAEB6EN+cYN6xD7RYGpXpNndMmZgM5Dcs3YSNFDHUo2LGfZuukSWyUYirJAdYbF3MfqEKmjM+I2EfhA94iG3L7uKrR+GdWD73ydlIB+6hgref1QTlmgmbM3/LeX5GI1Ux1RWpgxpLuZ2+I+IjzZ8wqE4nilvQdkUdfhzI5QDWy+kw5Wgg2pGpeEVeCCA7b85BO3F9DzxB3cdqvBzWcmzbyMiqhzuYqtHRVG2y4x+KOlnyqla8AoWWpuBoYRxzXrfKuILl6SfiWCbjxoZJUaCBj1CjH7GIaDbc9kqBY3W/Rgjda1iqQcOJu2WW+76pZC9QG7M00dffe9hNnseupFL53r8F7YHSwJWUKP2q+k7RdsxyOB11n0xtOvnW4irMMFNV4H0uqwS5ExsmP9AxbDTc9JwgneAT5vTiUSm1E7BSflSt3bfa1tv8Di3R8n3Af7MNWzs49hmauE2wP+ttrq+AsWpFG2awvsuOqbipWHgtuvuaAE+A1Z/7gC9hesnr+7wqCwG8c5yAg3AL1fm8T9AZtp/bbJGwl1pNrE7RuOX7PeMRUERVaPpEs+yqeoSmuOlokqw49pgomjLeh7icHNlG19yjs6XXOMedYm5xH2YxpV2tc0Ro2jJfxC50ApuxGob7lMsxfTbeUv07TyYxpeLucEH1gNd4IKH2LAg5TdVhlCafZvpskfncCfx8pOhJzd76bJWeYFnFciwcYfubRc12Ip/ppIhA1/mSZ/RxjFDrJC5xifFjJpY2Xl5zXdguFqYyTR1zSp1Y9p+tktDYYSNflcxI0iyO4TPBdlRcpeqjK/piF5bklq77VSEaA+z8qmJTFzIWiitbnzR794USKBUaT0NTEsVjZqLaFVqJoPN9ODG70IPbfBHKK+/q/AWR0tJzYHRULOa4MP+W/HfGadZUbfw177G7j/OGbIs8TahLyynl4X4RinF793Oz+BU0saXtUHrVBFT/DnA3ctNPoGbs4hRIjTok8i+algT1lTHi4SxFvONKNrgQFAq2/gFnWMXgwffgYMJpiKYkmW3tTg3ZQ9Jq+f8XN+A5eeUKHWvJWJ2sgJ1Sop+wwhqFVijqWaJhwtD8MNlSBeWNNWTa5Z5kPZw5+LbVT99wqTdx29lMUH4OIG/D86ruKEauBjvH5xy6um/Sfj7ei6UUVk4AIl3MyD4MSSTOFgSwsH/QJWaQ5as7ZcmgBZkzjjU1UrQ74ci1gWBCSGHtuV1H2mhSnO3Wp/3fEV5a+4wz//6qy8JxjZsmxxy5+4w9CDNJY09T072iKG0EnOS0arEYgXqYnXcYHwjTtUNAcMelOd4xpkoqiTYICWFq0JSiPfPDQdnt+4/wuqcXY47QILbgAAAABJRU5ErkJggg==)"; @@ -689,6 +690,7 @@ export const nodeIconsLucide: iconsType = { OpenAI: OpenAiIcon, OpenRouter: OpenRouterIcon, DeepSeek: DeepSeekIcon, + xAI: XAIIcon, OpenAIEmbeddings: OpenAiIcon, Pinecone: PineconeIcon, Qdrant: QDrantIcon,