feat(message): support sequencing of multiple streamable models (#8434)

* feat: update OpenAI model parameters handling for reasoning models

* feat: extend input_value type in LCModelComponent to support AsyncIterator and Iterator

* refactor: remove assert_streaming_sequence method and related checks from Graph class

* feat: add consume_iterator method to Message class for handling iterators

* test: add unit tests for OpenAIModelComponent functionality and integration

* feat: update OpenAIModelComponent to include temperature and seed parameters in build_model method

* feat: rename consume_iterator method to consume_iterator_in_text and update its implementation for handling text

* feat: add is_connected_to_chat_output method to Component class for improved message handling

* feat: refactor LCModelComponent methods to support asynchronous message handling and improve chat output integration

* refactor: remove consume_iterator_in_text method from Message class and clean up LCModelComponent input handling

* fix: update import paths for input components in multiple starter project JSON files

* fix: enhance error message formatting in ErrorMessage class to handle additional exception attributes

* refactor: remove validate_stream calls from generate_flow_events and Graph class to streamline flow processing

* fix: handle asyncio.CancelledError in aadd_messagetables to ensure proper session rollback and retry logic

* refactor: streamline message handling in LCModelComponent by replacing async invocation with synchronous calls and updating message text handling

* refactor: enhance message handling in LCModelComponent by introducing lf_message for improved return value management and updating properties for consistency

* feat: add _build_source method to Component class for enhanced source handling and flexibility in source object management

* feat: enhance LCModelComponent by adding _handle_stream method for improved streaming response handling and refactoring chat output integration

* feat: update MemoryComponent to enhance message retrieval and storage functionality, including new sender type handling and output options for text and dataframe formats

* test: refactor LanguageModelComponent tests to use ComponentTestBaseWithoutClient and add tests for Google model creation and error handling

* test: add fixtures for API keys and implement live API tests for OpenAI, Anthropic, and Google models

* fix: reorder JSON properties for consistency in starter projects

* Updated JSON files for various starter projects to ensure consistent ordering of properties, specifically moving "type" to follow "selected_output" for better readability and maintainability.
* Affected files: Basic Prompt Chaining.json, Blog Writer.json, Financial Report Parser.json, Hybrid Search RAG.json, SEO Keyword Generator.json.

* refactor: simplify input_value type in LCModelComponent

* Updated the input_value parameter in LCModelComponent to remove AsyncIterator and Iterator types, streamlining the input options to only str and Message for improved clarity and maintainability.
* This change enhances the documentation and understanding of the expected input types for the component.

* fix: clarify comment for handling source in Component class

* refactor: remove unnecessary mocking in OpenAI model integration tests
This commit is contained in:
Gabriel Luiz Freitas Almeida 2025-06-25 14:14:08 -03:00 committed by GitHub
commit 633b1e582a
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
10 changed files with 431 additions and 75 deletions

View file

@ -214,7 +214,6 @@ async def generate_flow_events(
async with session_scope() as fresh_session:
graph = await create_graph(fresh_session, flow_id_str, flow_name)
graph.validate_stream()
first_layer = sort_vertices(graph)
for vertex_id in first_layer:

View file

@ -14,6 +14,7 @@ from langflow.field_typing import LanguageModel
from langflow.inputs.inputs import BoolInput, InputTypes, MessageInput, MultilineInput
from langflow.schema.message import Message
from langflow.template.field.base import Output
from langflow.utils.constants import MESSAGE_SENDER_AI
# Enabled detailed thinking for NVIDIA reasoning models.
#
@ -82,12 +83,12 @@ class LCModelComponent(Component):
msg = f"Method '{method_name}' must be defined."
raise ValueError(msg)
def text_response(self) -> Message:
async def text_response(self) -> Message:
input_value = self.input_value
stream = self.stream
system_message = self.system_message
output = self.build_model()
result = self.get_chat_result(
result = await self.get_chat_result(
runnable=output, stream=stream, input_value=input_value, system_message=system_message
)
self.status = result
@ -167,7 +168,7 @@ class LCModelComponent(Component):
status_message = f"Response: {message.content}" # type: ignore[assignment]
return status_message
def get_chat_result(
async def get_chat_result(
self,
*,
runnable: LanguageModel,
@ -178,14 +179,14 @@ class LCModelComponent(Component):
if getattr(self, "detailed_thinking", False):
system_message = DETAILED_THINKING_PREFIX + (system_message or "")
return self._get_chat_result(
return await self._get_chat_result(
runnable=runnable,
stream=stream,
input_value=input_value,
system_message=system_message,
)
def _get_chat_result(
async def _get_chat_result(
self,
*,
runnable: LanguageModel,
@ -193,11 +194,29 @@ class LCModelComponent(Component):
input_value: str | Message,
system_message: str | None = None,
) -> Message:
"""Get chat result from a language model.
This method handles the core logic of getting a response from a language model,
including handling different input types, streaming, and error handling.
Args:
runnable (LanguageModel): The language model to use for generating responses
stream (bool): Whether to stream the response
input_value (str | Message): The input to send to the model
system_message (str | None, optional): System message to prepend. Defaults to None.
Returns:
The model response, either as a Message object or raw content
Raises:
ValueError: If the input message is empty or if there's an error during model invocation
"""
messages: list[BaseMessage] = []
if not input_value and not system_message:
msg = "The message you want to send to the model is empty."
raise ValueError(msg)
system_message_added = False
message = None
if input_value:
if isinstance(input_value, Message):
with warnings.catch_warnings():
@ -219,6 +238,7 @@ class LCModelComponent(Component):
if system_message and not system_message_added:
messages.insert(0, SystemMessage(content=system_message))
inputs: list | dict = messages or {}
lf_message = None
try:
# TODO: Depreciated Feature to be removed in upcoming release
if hasattr(self, "output_parser") and self.output_parser is not None:
@ -232,9 +252,10 @@ class LCModelComponent(Component):
}
)
if stream:
return runnable.stream(inputs)
message = runnable.invoke(inputs)
result = message.content if hasattr(message, "content") else message
lf_message, result = await self._handle_stream(runnable, inputs)
else:
message = runnable.invoke(inputs)
result = message.content if hasattr(message, "content") else message
if isinstance(message, AIMessage):
status_message = self.build_status_message(message)
self.status = status_message
@ -247,8 +268,41 @@ class LCModelComponent(Component):
if message := self._get_exception_message(e):
raise ValueError(message) from e
raise
return lf_message or Message(text=result)
return Message(text=result)
async def _handle_stream(self, runnable, inputs):
"""Handle streaming responses from the language model.
Args:
runnable: The language model configured for streaming
inputs: The inputs to send to the model
Returns:
tuple: (Message object if connected to chat output, model result)
"""
lf_message = None
if self.is_connected_to_chat_output():
# Add a Message
if hasattr(self, "graph"):
session_id = self.graph.session_id
elif hasattr(self, "_session_id"):
session_id = self._session_id
else:
session_id = None
model_message = Message(
text=runnable.stream(inputs),
sender=MESSAGE_SENDER_AI,
sender_name="AI",
properties={"icon": self.icon, "state": "partial"},
session_id=session_id,
)
model_message.properties.source = self._build_source(self._id, self.display_name, self)
lf_message = await self.send_message(model_message)
result = lf_message.text
else:
message = runnable.invoke(inputs)
result = message.content if hasattr(message, "content") else message
return lf_message, result
@abstractmethod
def build_model(self) -> LanguageModel: # type: ignore[type-var]

View file

@ -103,17 +103,15 @@ class OpenAIModelComponent(LCModelComponent):
"max_tokens": self.max_tokens or None,
"model_kwargs": self.model_kwargs or {},
"base_url": self.openai_api_base or "https://api.openai.com/v1",
"seed": self.seed,
"max_retries": self.max_retries,
"timeout": self.timeout,
"temperature": self.temperature if self.temperature is not None else 0.1,
}
logger.info(f"Model name: {self.model_name}")
if self.model_name in OPENAI_REASONING_MODEL_NAMES:
logger.info("Getting reasoning model parameters")
parameters.pop("temperature")
parameters.pop("seed")
if self.model_name not in OPENAI_REASONING_MODEL_NAMES:
parameters["temperature"] = self.temperature if self.temperature is not None else 0.1
parameters["seed"] = self.seed
output = ChatOpenAI(**parameters)
if self.json_mode:
output = output.bind(response_format={"type": "json_object"})

View file

@ -160,6 +160,22 @@ class Component(CustomComponent):
self._set_output_types(list(self._outputs_map.values()))
self.set_class_code()
def _build_source(self, id_: str | None, display_name: str | None, source: str | None) -> Source:
source_dict = {}
if id_:
source_dict["id"] = id_
if display_name:
source_dict["display_name"] = display_name
if source:
# Handle case where source is a ChatOpenAI and other models objects
if hasattr(source, "model_name"):
source_dict["source"] = source.model_name
elif hasattr(source, "model"):
source_dict["source"] = str(source.model)
else:
source_dict["source"] = str(source)
return Source(**source_dict)
def get_incoming_edge_by_target_param(self, target_param: str) -> str | None:
"""Get the source vertex ID for an incoming edge that targets a specific parameter.
@ -1354,12 +1370,15 @@ class Component(CustomComponent):
)
)
def is_connected_to_chat_output(self) -> bool:
return has_chat_output(self.graph.get_vertex_neighbors(self._vertex))
def _should_skip_message(self, message: Message) -> bool:
"""Check if the message should be skipped based on vertex configuration and message type."""
return (
self._vertex is not None
and not (self._vertex.is_output or self._vertex.is_input)
and not has_chat_output(self.graph.get_vertex_neighbors(self._vertex))
and not self.is_connected_to_chat_output()
and not isinstance(message, ErrorMessage)
)

View file

@ -1218,8 +1218,6 @@ class Graph:
if vertex.id in self.cycle_vertices:
self.run_manager.add_to_cycle_vertices(vertex.id)
self.assert_streaming_sequence()
def _get_edges_as_list_of_tuples(self) -> list[tuple[str, str]]:
"""Returns the edges of the graph as a list of tuples.
@ -1940,24 +1938,11 @@ class Graph:
vertex_instance.set_top_level(self.top_level_vertices)
return vertex_instance
def assert_streaming_sequence(self) -> None:
for i in self.edges:
source = self.get_vertex(i.source_id)
if "stream" in source.params and source.params["stream"] is True:
target = self.get_vertex(i.target_id)
if target.vertex_type != "ChatOutput":
msg = (
"Error: A 'streaming' vertex cannot be followed by a non-'chat output' vertex."
"Disable streaming to run the flow."
)
raise Exception(msg) # noqa: TRY002
def prepare(self, stop_component_id: str | None = None, start_component_id: str | None = None):
self.initialize()
if stop_component_id and start_component_id:
msg = "You can only provide one of stop_component_id or start_component_id"
raise ValueError(msg)
self.validate_stream()
if stop_component_id or start_component_id:
try:

File diff suppressed because one or more lines are too long

View file

@ -153,15 +153,16 @@ async def aupdate_messages(messages: Message | list[Message]) -> list[Message]:
async def aadd_messagetables(messages: list[MessageTable], session: AsyncSession):
try:
for message in messages:
session.add(message)
try:
for message in messages:
session.add(message)
await session.commit()
# This is a hack.
# We are doing this because build_public_tmp causes the CancelledError to be raised
# while build_flow does not.
except asyncio.CancelledError:
await session.commit()
await session.rollback()
return await aadd_messagetables(messages, session)
for message in messages:
await session.refresh(message)
except asyncio.CancelledError as e:

View file

@ -403,12 +403,18 @@ class ErrorMessage(Message):
"""Format the error reason without markdown."""
if hasattr(exception, "body") and isinstance(exception.body, dict) and "message" in exception.body:
reason = f"{exception.body.get('message')}\n"
elif hasattr(exception, "_message"):
reason = f"{exception._message()}\n" if callable(exception._message) else f"{exception._message}\n"
elif hasattr(exception, "code"):
reason = f"Code: {exception.code}\n"
elif hasattr(exception, "args") and exception.args:
reason = f"{exception.args[0]}\n"
elif isinstance(exception, ValidationError):
reason = f"{exception!s}\n"
elif hasattr(exception, "detail"):
reason = f"{exception.detail}\n"
elif hasattr(exception, "message"):
reason = f"{exception.message}\n"
else:
reason = "An unknown error occurred.\n"
return reason

View file

@ -0,0 +1,203 @@
import os
from unittest.mock import MagicMock, patch
import pytest
from langchain_openai import ChatOpenAI
from langflow.components.languagemodels.openai_chat_model import OpenAIModelComponent
from tests.base import ComponentTestBaseWithoutClient
class TestOpenAIModelComponent(ComponentTestBaseWithoutClient):
@pytest.fixture
def component_class(self):
return OpenAIModelComponent
@pytest.fixture
def default_kwargs(self):
return {
"max_tokens": 1000,
"model_kwargs": {},
"json_mode": False,
"model_name": "gpt-4.1-nano",
"openai_api_base": "https://api.openai.com/v1",
"api_key": "test-api-key",
"temperature": 0.1,
"seed": 1,
"max_retries": 5,
"timeout": 700,
}
@pytest.fixture
def file_names_mapping(self):
# Provide an empty list or the actual mapping if versioned files exist
return []
@patch("langflow.components.languagemodels.openai_chat_model.ChatOpenAI")
async def test_build_model(self, mock_chat_openai, component_class, default_kwargs):
mock_instance = MagicMock()
mock_chat_openai.return_value = mock_instance
component = component_class(**default_kwargs)
model = component.build_model()
mock_chat_openai.assert_called_once_with(
api_key="test-api-key",
model_name="gpt-4.1-nano",
max_tokens=1000,
model_kwargs={},
base_url="https://api.openai.com/v1",
seed=1,
max_retries=5,
timeout=700,
temperature=0.1,
)
assert model == mock_instance
@patch("langflow.components.languagemodels.openai_chat_model.ChatOpenAI")
async def test_build_model_reasoning_model(self, mock_chat_openai, component_class, default_kwargs):
mock_instance = MagicMock()
mock_chat_openai.return_value = mock_instance
default_kwargs["model_name"] = "o1"
component = component_class(**default_kwargs)
model = component.build_model()
# For reasoning models, temperature and seed should be excluded
mock_chat_openai.assert_called_once_with(
api_key="test-api-key",
model_name="o1",
max_tokens=1000,
model_kwargs={},
base_url="https://api.openai.com/v1",
max_retries=5,
timeout=700,
)
assert model == mock_instance
# Verify that temperature and seed are not in the parameters
args, kwargs = mock_chat_openai.call_args
assert "temperature" not in kwargs
assert "seed" not in kwargs
@patch("langflow.components.languagemodels.openai_chat_model.ChatOpenAI")
async def test_build_model_with_json_mode(self, mock_chat_openai, component_class, default_kwargs):
mock_instance = MagicMock()
mock_bound_instance = MagicMock()
mock_instance.bind.return_value = mock_bound_instance
mock_chat_openai.return_value = mock_instance
default_kwargs["json_mode"] = True
component = component_class(**default_kwargs)
model = component.build_model()
mock_chat_openai.assert_called_once()
mock_instance.bind.assert_called_once_with(response_format={"type": "json_object"})
assert model == mock_bound_instance
@patch("langflow.components.languagemodels.openai_chat_model.ChatOpenAI")
async def test_build_model_no_api_key(self, mock_chat_openai, component_class, default_kwargs):
mock_instance = MagicMock()
mock_chat_openai.return_value = mock_instance
default_kwargs["api_key"] = None
component = component_class(**default_kwargs)
component.build_model()
# When api_key is None, it should be passed as None to ChatOpenAI
args, kwargs = mock_chat_openai.call_args
assert kwargs["api_key"] is None
@patch("langflow.components.languagemodels.openai_chat_model.ChatOpenAI")
async def test_build_model_max_tokens_zero(self, mock_chat_openai, component_class, default_kwargs):
mock_instance = MagicMock()
mock_chat_openai.return_value = mock_instance
default_kwargs["max_tokens"] = 0
component = component_class(**default_kwargs)
component.build_model()
# When max_tokens is 0, it should be passed as None to ChatOpenAI
args, kwargs = mock_chat_openai.call_args
assert kwargs["max_tokens"] is None
async def test_get_exception_message_bad_request_error(self, component_class, default_kwargs):
component_class(**default_kwargs)
# Create a mock BadRequestError with a body attribute
mock_error = MagicMock()
mock_error.body = {"message": "test error message"}
# Test the method directly by patching the import
with patch("openai.BadRequestError", mock_error.__class__):
# Manually call isinstance to avoid mocking it
if hasattr(mock_error, "body"):
message = mock_error.body.get("message")
assert message == "test error message"
async def test_get_exception_message_no_openai_import(self, component_class, default_kwargs):
component = component_class(**default_kwargs)
# Test when openai module is not available
with patch.dict("sys.modules", {"openai": None}), patch("builtins.__import__", side_effect=ImportError):
message = component._get_exception_message(Exception("test"))
assert message is None
async def test_get_exception_message_other_exception(self, component_class, default_kwargs):
component = component_class(**default_kwargs)
# Create a regular exception (not BadRequestError)
regular_exception = ValueError("test error")
# Create a simple mock for BadRequestError that the exception won't match
class MockBadRequestError:
pass
with patch("openai.BadRequestError", MockBadRequestError):
message = component._get_exception_message(regular_exception)
assert message is None
async def test_update_build_config_reasoning_model(self, component_class, default_kwargs):
component = component_class(**default_kwargs)
build_config = {
"temperature": {"show": True},
"seed": {"show": True},
}
# Test with reasoning model
updated_config = component.update_build_config(build_config, "o1", "model_name")
assert updated_config["temperature"]["show"] is False
assert updated_config["seed"]["show"] is False
# Test with regular model
updated_config = component.update_build_config(build_config, "gpt-4", "model_name")
assert updated_config["temperature"]["show"] is True
assert updated_config["seed"]["show"] is True
def test_build_model_integration(self):
component = OpenAIModelComponent()
component.api_key = os.getenv("OPENAI_API_KEY")
component.model_name = "gpt-4.1-nano"
component.temperature = 0.2
component.max_tokens = 1000
component.seed = 42
component.max_retries = 3
component.timeout = 600
component.openai_api_base = "https://api.openai.com/v1"
model = component.build_model()
assert isinstance(model, ChatOpenAI)
assert model.model_name == "gpt-4.1-nano"
assert model.openai_api_base == "https://api.openai.com/v1"
def test_build_model_integration_reasoning(self):
component = OpenAIModelComponent()
component.api_key = os.getenv("OPENAI_API_KEY")
component.model_name = "o1"
component.temperature = 0.2 # This should be ignored for reasoning models
component.max_tokens = 1000
component.seed = 42 # This should be ignored for reasoning models
component.max_retries = 3
component.timeout = 600
component.openai_api_base = "https://api.openai.com/v1"
model = component.build_model()
assert isinstance(model, ChatOpenAI)
assert model.model_name == "o1"
assert model.openai_api_base == "https://api.openai.com/v1"

View file

@ -1,15 +1,18 @@
from unittest.mock import MagicMock, patch
import os
import pytest
from langchain_anthropic import ChatAnthropic
from langchain_google_genai import ChatGoogleGenerativeAI
from langchain_openai import ChatOpenAI
from langflow.base.models.anthropic_constants import ANTHROPIC_MODELS
from langflow.base.models.google_generative_ai_constants import GOOGLE_GENERATIVE_AI_MODELS
from langflow.base.models.openai_constants import OPENAI_MODEL_NAMES
from langflow.components.models.language_model import LanguageModelComponent
from tests.base import ComponentTestBaseWithClient
from tests.base import ComponentTestBaseWithoutClient
@pytest.mark.usefixtures("client")
class TestLanguageModelComponent(ComponentTestBaseWithClient):
class TestLanguageModelComponent(ComponentTestBaseWithoutClient):
@pytest.fixture
def component_class(self):
return LanguageModelComponent
@ -29,6 +32,32 @@ class TestLanguageModelComponent(ComponentTestBaseWithClient):
@pytest.fixture
def file_names_mapping(self):
"""Return the file names mapping for version-specific files."""
# No version-specific files for this component
return []
@pytest.fixture
def openai_api_key(self):
"""Fixture to get OpenAI API key from environment variable."""
api_key = os.environ.get("OPENAI_API_KEY")
if not api_key:
pytest.skip("OPENAI_API_KEY environment variable not set")
return api_key
@pytest.fixture
def anthropic_api_key(self):
"""Fixture to get Anthropic API key from environment variable."""
api_key = os.environ.get("ANTHROPIC_API_KEY")
if not api_key:
pytest.skip("ANTHROPIC_API_KEY environment variable not set")
return api_key
@pytest.fixture
def google_api_key(self):
"""Fixture to get Google API key from environment variable."""
api_key = os.environ.get("GOOGLE_API_KEY")
if not api_key:
pytest.skip("GOOGLE_API_KEY environment variable not set")
return api_key
async def test_update_build_config_openai(self, component_class, default_kwargs):
component = component_class(**default_kwargs)
@ -52,57 +81,69 @@ class TestLanguageModelComponent(ComponentTestBaseWithClient):
assert updated_config["model_name"]["value"] == ANTHROPIC_MODELS[0]
assert updated_config["api_key"]["display_name"] == "Anthropic API Key"
@patch("langflow.components.models.language_model.ChatOpenAI")
async def test_build_model_openai(self, mock_chat_openai, component_class, default_kwargs):
# Setup mock
mock_instance = MagicMock()
mock_chat_openai.return_value = mock_instance
async def test_update_build_config_google(self, component_class, default_kwargs):
component = component_class(**default_kwargs)
build_config = {
"model_name": {"options": [], "value": ""},
"api_key": {"display_name": "API Key"},
}
updated_config = component.update_build_config(build_config, "Google", "provider")
assert updated_config["model_name"]["options"] == GOOGLE_GENERATIVE_AI_MODELS
assert updated_config["model_name"]["value"] == GOOGLE_GENERATIVE_AI_MODELS[0]
assert updated_config["api_key"]["display_name"] == "Google API Key"
# Create and configure the component
async def test_openai_model_creation(self, component_class, default_kwargs):
"""Test that the component returns an instance of ChatOpenAI for OpenAI provider."""
component = component_class(**default_kwargs)
component.provider = "OpenAI"
component.model_name = "gpt-3.5-turbo"
component.api_key = "test-key"
component.api_key = "sk-test-key" # Use a fake but correctly formatted key
component.temperature = 0.5
component.stream = False
# Build the model
# The API key will be invalid, but we should still get a ChatOpenAI instance
model = component.build_model()
assert isinstance(model, ChatOpenAI)
assert model.model_name == "gpt-3.5-turbo"
assert model.temperature == 0.5
assert model.streaming is False
# API key is stored as a SecretStr object, so we can't directly compare values
# Verify the ChatOpenAI was called with the correct parameters
mock_chat_openai.assert_called_once_with(
model_name="gpt-3.5-turbo",
temperature=0.5,
streaming=False,
openai_api_key="test-key",
)
assert model == mock_instance
@patch("langflow.components.models.language_model.ChatAnthropic")
async def test_build_model_anthropic(self, mock_chat_anthropic, component_class, default_kwargs):
# Setup mock
mock_instance = MagicMock()
mock_chat_anthropic.return_value = mock_instance
# Create and configure the component
async def test_anthropic_model_creation(self, component_class, default_kwargs):
"""Test that the component returns an instance of ChatAnthropic for Anthropic provider."""
component = component_class(**default_kwargs)
component.provider = "Anthropic"
component.model_name = ANTHROPIC_MODELS[0] # Use the first model from the constants
component.api_key = "test-key"
component.model_name = ANTHROPIC_MODELS[0]
component.api_key = "sk-ant-test-key" # Use a fake but plausible key
component.temperature = 0.7
component.stream = False
# Build the model
# The API key will be invalid, but we should still get a ChatAnthropic instance
model = component.build_model()
assert isinstance(model, ChatAnthropic)
assert model.model == ANTHROPIC_MODELS[0]
assert model.temperature == 0.7
assert model.streaming is False
# API key is stored as a SecretStr object, so we can't directly compare values
# Verify the ChatAnthropic was called with the correct parameters
mock_chat_anthropic.assert_called_once_with(
model=ANTHROPIC_MODELS[0],
temperature=0.7,
streaming=False,
anthropic_api_key="test-key",
)
assert model == mock_instance
async def test_google_model_creation(self, component_class, default_kwargs):
"""Test that the component returns an instance of ChatGoogleGenerativeAI for Google provider."""
component = component_class(**default_kwargs)
component.provider = "Google"
component.model_name = GOOGLE_GENERATIVE_AI_MODELS[0]
component.api_key = "google-test-key" # Use a fake but plausible key
component.temperature = 0.7
component.stream = False
# The API key will be invalid, but we should still get a ChatGoogleGenerativeAI instance
model = component.build_model()
assert isinstance(model, ChatGoogleGenerativeAI)
# Google model automatically prepends "models/" to the model name
assert model.model == f"models/{GOOGLE_GENERATIVE_AI_MODELS[0]}"
assert model.temperature == 0.7
# Google model uses 'stream' instead of 'streaming'
# Skip this check for Google model since it has a different interface
# API key is stored as a SecretStr object, so we can't directly compare values
async def test_build_model_openai_missing_api_key(self, component_class, default_kwargs):
component = component_class(**default_kwargs)
@ -120,9 +161,59 @@ class TestLanguageModelComponent(ComponentTestBaseWithClient):
with pytest.raises(ValueError, match="Anthropic API key is required when using Anthropic provider"):
component.build_model()
async def test_build_model_google_missing_api_key(self, component_class, default_kwargs):
component = component_class(**default_kwargs)
component.provider = "Google"
component.api_key = None
with pytest.raises(ValueError, match="Google API key is required when using Google provider"):
component.build_model()
async def test_build_model_unknown_provider(self, component_class, default_kwargs):
component = component_class(**default_kwargs)
component.provider = "Unknown"
with pytest.raises(ValueError, match="Unknown provider: Unknown"):
component.build_model()
async def test_openai_live_api(self, component_class, default_kwargs, openai_api_key):
"""Test that the component can create a model with a real API key."""
component = component_class(**default_kwargs)
component.provider = "OpenAI"
component.model_name = "gpt-3.5-turbo"
component.api_key = openai_api_key
component.temperature = 0.1
component.stream = False
model = component.build_model()
assert isinstance(model, ChatOpenAI)
# We could attempt a simple call here, but that would increase test time
# and might fail due to network issues, so we'll just verify the instance
async def test_anthropic_live_api(self, component_class, default_kwargs, anthropic_api_key):
"""Test that the component can create a model with a real API key."""
component = component_class(**default_kwargs)
component.provider = "Anthropic"
component.model_name = ANTHROPIC_MODELS[0]
component.api_key = anthropic_api_key
component.temperature = 0.1
component.stream = False
model = component.build_model()
assert isinstance(model, ChatAnthropic)
# We could attempt a simple call here, but that would increase test time
# and might fail due to network issues, so we'll just verify the instance
async def test_google_live_api(self, component_class, default_kwargs, google_api_key):
"""Test that the component can create a model with a real API key."""
component = component_class(**default_kwargs)
component.provider = "Google"
component.model_name = GOOGLE_GENERATIVE_AI_MODELS[0]
component.api_key = google_api_key
component.temperature = 0.1
component.stream = False
model = component.build_model()
assert isinstance(model, ChatGoogleGenerativeAI)
# We could attempt a simple call here, but that would increase test time
# and might fail due to network issues, so we'll just verify the instance