feat(message): support sequencing of multiple streamable models (#8434)
* feat: update OpenAI model parameters handling for reasoning models * feat: extend input_value type in LCModelComponent to support AsyncIterator and Iterator * refactor: remove assert_streaming_sequence method and related checks from Graph class * feat: add consume_iterator method to Message class for handling iterators * test: add unit tests for OpenAIModelComponent functionality and integration * feat: update OpenAIModelComponent to include temperature and seed parameters in build_model method * feat: rename consume_iterator method to consume_iterator_in_text and update its implementation for handling text * feat: add is_connected_to_chat_output method to Component class for improved message handling * feat: refactor LCModelComponent methods to support asynchronous message handling and improve chat output integration * refactor: remove consume_iterator_in_text method from Message class and clean up LCModelComponent input handling * fix: update import paths for input components in multiple starter project JSON files * fix: enhance error message formatting in ErrorMessage class to handle additional exception attributes * refactor: remove validate_stream calls from generate_flow_events and Graph class to streamline flow processing * fix: handle asyncio.CancelledError in aadd_messagetables to ensure proper session rollback and retry logic * refactor: streamline message handling in LCModelComponent by replacing async invocation with synchronous calls and updating message text handling * refactor: enhance message handling in LCModelComponent by introducing lf_message for improved return value management and updating properties for consistency * feat: add _build_source method to Component class for enhanced source handling and flexibility in source object management * feat: enhance LCModelComponent by adding _handle_stream method for improved streaming response handling and refactoring chat output integration * feat: update MemoryComponent to enhance message retrieval and storage functionality, including new sender type handling and output options for text and dataframe formats * test: refactor LanguageModelComponent tests to use ComponentTestBaseWithoutClient and add tests for Google model creation and error handling * test: add fixtures for API keys and implement live API tests for OpenAI, Anthropic, and Google models * fix: reorder JSON properties for consistency in starter projects * Updated JSON files for various starter projects to ensure consistent ordering of properties, specifically moving "type" to follow "selected_output" for better readability and maintainability. * Affected files: Basic Prompt Chaining.json, Blog Writer.json, Financial Report Parser.json, Hybrid Search RAG.json, SEO Keyword Generator.json. * refactor: simplify input_value type in LCModelComponent * Updated the input_value parameter in LCModelComponent to remove AsyncIterator and Iterator types, streamlining the input options to only str and Message for improved clarity and maintainability. * This change enhances the documentation and understanding of the expected input types for the component. * fix: clarify comment for handling source in Component class * refactor: remove unnecessary mocking in OpenAI model integration tests
This commit is contained in:
parent
38d5885fa3
commit
633b1e582a
10 changed files with 431 additions and 75 deletions
|
|
@ -214,7 +214,6 @@ async def generate_flow_events(
|
|||
async with session_scope() as fresh_session:
|
||||
graph = await create_graph(fresh_session, flow_id_str, flow_name)
|
||||
|
||||
graph.validate_stream()
|
||||
first_layer = sort_vertices(graph)
|
||||
|
||||
for vertex_id in first_layer:
|
||||
|
|
|
|||
|
|
@ -14,6 +14,7 @@ from langflow.field_typing import LanguageModel
|
|||
from langflow.inputs.inputs import BoolInput, InputTypes, MessageInput, MultilineInput
|
||||
from langflow.schema.message import Message
|
||||
from langflow.template.field.base import Output
|
||||
from langflow.utils.constants import MESSAGE_SENDER_AI
|
||||
|
||||
# Enabled detailed thinking for NVIDIA reasoning models.
|
||||
#
|
||||
|
|
@ -82,12 +83,12 @@ class LCModelComponent(Component):
|
|||
msg = f"Method '{method_name}' must be defined."
|
||||
raise ValueError(msg)
|
||||
|
||||
def text_response(self) -> Message:
|
||||
async def text_response(self) -> Message:
|
||||
input_value = self.input_value
|
||||
stream = self.stream
|
||||
system_message = self.system_message
|
||||
output = self.build_model()
|
||||
result = self.get_chat_result(
|
||||
result = await self.get_chat_result(
|
||||
runnable=output, stream=stream, input_value=input_value, system_message=system_message
|
||||
)
|
||||
self.status = result
|
||||
|
|
@ -167,7 +168,7 @@ class LCModelComponent(Component):
|
|||
status_message = f"Response: {message.content}" # type: ignore[assignment]
|
||||
return status_message
|
||||
|
||||
def get_chat_result(
|
||||
async def get_chat_result(
|
||||
self,
|
||||
*,
|
||||
runnable: LanguageModel,
|
||||
|
|
@ -178,14 +179,14 @@ class LCModelComponent(Component):
|
|||
if getattr(self, "detailed_thinking", False):
|
||||
system_message = DETAILED_THINKING_PREFIX + (system_message or "")
|
||||
|
||||
return self._get_chat_result(
|
||||
return await self._get_chat_result(
|
||||
runnable=runnable,
|
||||
stream=stream,
|
||||
input_value=input_value,
|
||||
system_message=system_message,
|
||||
)
|
||||
|
||||
def _get_chat_result(
|
||||
async def _get_chat_result(
|
||||
self,
|
||||
*,
|
||||
runnable: LanguageModel,
|
||||
|
|
@ -193,11 +194,29 @@ class LCModelComponent(Component):
|
|||
input_value: str | Message,
|
||||
system_message: str | None = None,
|
||||
) -> Message:
|
||||
"""Get chat result from a language model.
|
||||
|
||||
This method handles the core logic of getting a response from a language model,
|
||||
including handling different input types, streaming, and error handling.
|
||||
|
||||
Args:
|
||||
runnable (LanguageModel): The language model to use for generating responses
|
||||
stream (bool): Whether to stream the response
|
||||
input_value (str | Message): The input to send to the model
|
||||
system_message (str | None, optional): System message to prepend. Defaults to None.
|
||||
|
||||
Returns:
|
||||
The model response, either as a Message object or raw content
|
||||
|
||||
Raises:
|
||||
ValueError: If the input message is empty or if there's an error during model invocation
|
||||
"""
|
||||
messages: list[BaseMessage] = []
|
||||
if not input_value and not system_message:
|
||||
msg = "The message you want to send to the model is empty."
|
||||
raise ValueError(msg)
|
||||
system_message_added = False
|
||||
message = None
|
||||
if input_value:
|
||||
if isinstance(input_value, Message):
|
||||
with warnings.catch_warnings():
|
||||
|
|
@ -219,6 +238,7 @@ class LCModelComponent(Component):
|
|||
if system_message and not system_message_added:
|
||||
messages.insert(0, SystemMessage(content=system_message))
|
||||
inputs: list | dict = messages or {}
|
||||
lf_message = None
|
||||
try:
|
||||
# TODO: Depreciated Feature to be removed in upcoming release
|
||||
if hasattr(self, "output_parser") and self.output_parser is not None:
|
||||
|
|
@ -232,9 +252,10 @@ class LCModelComponent(Component):
|
|||
}
|
||||
)
|
||||
if stream:
|
||||
return runnable.stream(inputs)
|
||||
message = runnable.invoke(inputs)
|
||||
result = message.content if hasattr(message, "content") else message
|
||||
lf_message, result = await self._handle_stream(runnable, inputs)
|
||||
else:
|
||||
message = runnable.invoke(inputs)
|
||||
result = message.content if hasattr(message, "content") else message
|
||||
if isinstance(message, AIMessage):
|
||||
status_message = self.build_status_message(message)
|
||||
self.status = status_message
|
||||
|
|
@ -247,8 +268,41 @@ class LCModelComponent(Component):
|
|||
if message := self._get_exception_message(e):
|
||||
raise ValueError(message) from e
|
||||
raise
|
||||
return lf_message or Message(text=result)
|
||||
|
||||
return Message(text=result)
|
||||
async def _handle_stream(self, runnable, inputs):
|
||||
"""Handle streaming responses from the language model.
|
||||
|
||||
Args:
|
||||
runnable: The language model configured for streaming
|
||||
inputs: The inputs to send to the model
|
||||
|
||||
Returns:
|
||||
tuple: (Message object if connected to chat output, model result)
|
||||
"""
|
||||
lf_message = None
|
||||
if self.is_connected_to_chat_output():
|
||||
# Add a Message
|
||||
if hasattr(self, "graph"):
|
||||
session_id = self.graph.session_id
|
||||
elif hasattr(self, "_session_id"):
|
||||
session_id = self._session_id
|
||||
else:
|
||||
session_id = None
|
||||
model_message = Message(
|
||||
text=runnable.stream(inputs),
|
||||
sender=MESSAGE_SENDER_AI,
|
||||
sender_name="AI",
|
||||
properties={"icon": self.icon, "state": "partial"},
|
||||
session_id=session_id,
|
||||
)
|
||||
model_message.properties.source = self._build_source(self._id, self.display_name, self)
|
||||
lf_message = await self.send_message(model_message)
|
||||
result = lf_message.text
|
||||
else:
|
||||
message = runnable.invoke(inputs)
|
||||
result = message.content if hasattr(message, "content") else message
|
||||
return lf_message, result
|
||||
|
||||
@abstractmethod
|
||||
def build_model(self) -> LanguageModel: # type: ignore[type-var]
|
||||
|
|
|
|||
|
|
@ -103,17 +103,15 @@ class OpenAIModelComponent(LCModelComponent):
|
|||
"max_tokens": self.max_tokens or None,
|
||||
"model_kwargs": self.model_kwargs or {},
|
||||
"base_url": self.openai_api_base or "https://api.openai.com/v1",
|
||||
"seed": self.seed,
|
||||
"max_retries": self.max_retries,
|
||||
"timeout": self.timeout,
|
||||
"temperature": self.temperature if self.temperature is not None else 0.1,
|
||||
}
|
||||
|
||||
logger.info(f"Model name: {self.model_name}")
|
||||
if self.model_name in OPENAI_REASONING_MODEL_NAMES:
|
||||
logger.info("Getting reasoning model parameters")
|
||||
parameters.pop("temperature")
|
||||
parameters.pop("seed")
|
||||
if self.model_name not in OPENAI_REASONING_MODEL_NAMES:
|
||||
parameters["temperature"] = self.temperature if self.temperature is not None else 0.1
|
||||
parameters["seed"] = self.seed
|
||||
|
||||
output = ChatOpenAI(**parameters)
|
||||
if self.json_mode:
|
||||
output = output.bind(response_format={"type": "json_object"})
|
||||
|
|
|
|||
|
|
@ -160,6 +160,22 @@ class Component(CustomComponent):
|
|||
self._set_output_types(list(self._outputs_map.values()))
|
||||
self.set_class_code()
|
||||
|
||||
def _build_source(self, id_: str | None, display_name: str | None, source: str | None) -> Source:
|
||||
source_dict = {}
|
||||
if id_:
|
||||
source_dict["id"] = id_
|
||||
if display_name:
|
||||
source_dict["display_name"] = display_name
|
||||
if source:
|
||||
# Handle case where source is a ChatOpenAI and other models objects
|
||||
if hasattr(source, "model_name"):
|
||||
source_dict["source"] = source.model_name
|
||||
elif hasattr(source, "model"):
|
||||
source_dict["source"] = str(source.model)
|
||||
else:
|
||||
source_dict["source"] = str(source)
|
||||
return Source(**source_dict)
|
||||
|
||||
def get_incoming_edge_by_target_param(self, target_param: str) -> str | None:
|
||||
"""Get the source vertex ID for an incoming edge that targets a specific parameter.
|
||||
|
||||
|
|
@ -1354,12 +1370,15 @@ class Component(CustomComponent):
|
|||
)
|
||||
)
|
||||
|
||||
def is_connected_to_chat_output(self) -> bool:
|
||||
return has_chat_output(self.graph.get_vertex_neighbors(self._vertex))
|
||||
|
||||
def _should_skip_message(self, message: Message) -> bool:
|
||||
"""Check if the message should be skipped based on vertex configuration and message type."""
|
||||
return (
|
||||
self._vertex is not None
|
||||
and not (self._vertex.is_output or self._vertex.is_input)
|
||||
and not has_chat_output(self.graph.get_vertex_neighbors(self._vertex))
|
||||
and not self.is_connected_to_chat_output()
|
||||
and not isinstance(message, ErrorMessage)
|
||||
)
|
||||
|
||||
|
|
|
|||
|
|
@ -1218,8 +1218,6 @@ class Graph:
|
|||
if vertex.id in self.cycle_vertices:
|
||||
self.run_manager.add_to_cycle_vertices(vertex.id)
|
||||
|
||||
self.assert_streaming_sequence()
|
||||
|
||||
def _get_edges_as_list_of_tuples(self) -> list[tuple[str, str]]:
|
||||
"""Returns the edges of the graph as a list of tuples.
|
||||
|
||||
|
|
@ -1940,24 +1938,11 @@ class Graph:
|
|||
vertex_instance.set_top_level(self.top_level_vertices)
|
||||
return vertex_instance
|
||||
|
||||
def assert_streaming_sequence(self) -> None:
|
||||
for i in self.edges:
|
||||
source = self.get_vertex(i.source_id)
|
||||
if "stream" in source.params and source.params["stream"] is True:
|
||||
target = self.get_vertex(i.target_id)
|
||||
if target.vertex_type != "ChatOutput":
|
||||
msg = (
|
||||
"Error: A 'streaming' vertex cannot be followed by a non-'chat output' vertex."
|
||||
"Disable streaming to run the flow."
|
||||
)
|
||||
raise Exception(msg) # noqa: TRY002
|
||||
|
||||
def prepare(self, stop_component_id: str | None = None, start_component_id: str | None = None):
|
||||
self.initialize()
|
||||
if stop_component_id and start_component_id:
|
||||
msg = "You can only provide one of stop_component_id or start_component_id"
|
||||
raise ValueError(msg)
|
||||
self.validate_stream()
|
||||
|
||||
if stop_component_id or start_component_id:
|
||||
try:
|
||||
|
|
|
|||
File diff suppressed because one or more lines are too long
|
|
@ -153,15 +153,16 @@ async def aupdate_messages(messages: Message | list[Message]) -> list[Message]:
|
|||
|
||||
async def aadd_messagetables(messages: list[MessageTable], session: AsyncSession):
|
||||
try:
|
||||
for message in messages:
|
||||
session.add(message)
|
||||
try:
|
||||
for message in messages:
|
||||
session.add(message)
|
||||
await session.commit()
|
||||
# This is a hack.
|
||||
# We are doing this because build_public_tmp causes the CancelledError to be raised
|
||||
# while build_flow does not.
|
||||
except asyncio.CancelledError:
|
||||
await session.commit()
|
||||
await session.rollback()
|
||||
return await aadd_messagetables(messages, session)
|
||||
for message in messages:
|
||||
await session.refresh(message)
|
||||
except asyncio.CancelledError as e:
|
||||
|
|
|
|||
|
|
@ -403,12 +403,18 @@ class ErrorMessage(Message):
|
|||
"""Format the error reason without markdown."""
|
||||
if hasattr(exception, "body") and isinstance(exception.body, dict) and "message" in exception.body:
|
||||
reason = f"{exception.body.get('message')}\n"
|
||||
elif hasattr(exception, "_message"):
|
||||
reason = f"{exception._message()}\n" if callable(exception._message) else f"{exception._message}\n"
|
||||
elif hasattr(exception, "code"):
|
||||
reason = f"Code: {exception.code}\n"
|
||||
elif hasattr(exception, "args") and exception.args:
|
||||
reason = f"{exception.args[0]}\n"
|
||||
elif isinstance(exception, ValidationError):
|
||||
reason = f"{exception!s}\n"
|
||||
elif hasattr(exception, "detail"):
|
||||
reason = f"{exception.detail}\n"
|
||||
elif hasattr(exception, "message"):
|
||||
reason = f"{exception.message}\n"
|
||||
else:
|
||||
reason = "An unknown error occurred.\n"
|
||||
return reason
|
||||
|
|
|
|||
|
|
@ -0,0 +1,203 @@
|
|||
import os
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import pytest
|
||||
from langchain_openai import ChatOpenAI
|
||||
from langflow.components.languagemodels.openai_chat_model import OpenAIModelComponent
|
||||
|
||||
from tests.base import ComponentTestBaseWithoutClient
|
||||
|
||||
|
||||
class TestOpenAIModelComponent(ComponentTestBaseWithoutClient):
|
||||
@pytest.fixture
|
||||
def component_class(self):
|
||||
return OpenAIModelComponent
|
||||
|
||||
@pytest.fixture
|
||||
def default_kwargs(self):
|
||||
return {
|
||||
"max_tokens": 1000,
|
||||
"model_kwargs": {},
|
||||
"json_mode": False,
|
||||
"model_name": "gpt-4.1-nano",
|
||||
"openai_api_base": "https://api.openai.com/v1",
|
||||
"api_key": "test-api-key",
|
||||
"temperature": 0.1,
|
||||
"seed": 1,
|
||||
"max_retries": 5,
|
||||
"timeout": 700,
|
||||
}
|
||||
|
||||
@pytest.fixture
|
||||
def file_names_mapping(self):
|
||||
# Provide an empty list or the actual mapping if versioned files exist
|
||||
return []
|
||||
|
||||
@patch("langflow.components.languagemodels.openai_chat_model.ChatOpenAI")
|
||||
async def test_build_model(self, mock_chat_openai, component_class, default_kwargs):
|
||||
mock_instance = MagicMock()
|
||||
mock_chat_openai.return_value = mock_instance
|
||||
component = component_class(**default_kwargs)
|
||||
model = component.build_model()
|
||||
|
||||
mock_chat_openai.assert_called_once_with(
|
||||
api_key="test-api-key",
|
||||
model_name="gpt-4.1-nano",
|
||||
max_tokens=1000,
|
||||
model_kwargs={},
|
||||
base_url="https://api.openai.com/v1",
|
||||
seed=1,
|
||||
max_retries=5,
|
||||
timeout=700,
|
||||
temperature=0.1,
|
||||
)
|
||||
assert model == mock_instance
|
||||
|
||||
@patch("langflow.components.languagemodels.openai_chat_model.ChatOpenAI")
|
||||
async def test_build_model_reasoning_model(self, mock_chat_openai, component_class, default_kwargs):
|
||||
mock_instance = MagicMock()
|
||||
mock_chat_openai.return_value = mock_instance
|
||||
default_kwargs["model_name"] = "o1"
|
||||
component = component_class(**default_kwargs)
|
||||
model = component.build_model()
|
||||
|
||||
# For reasoning models, temperature and seed should be excluded
|
||||
mock_chat_openai.assert_called_once_with(
|
||||
api_key="test-api-key",
|
||||
model_name="o1",
|
||||
max_tokens=1000,
|
||||
model_kwargs={},
|
||||
base_url="https://api.openai.com/v1",
|
||||
max_retries=5,
|
||||
timeout=700,
|
||||
)
|
||||
assert model == mock_instance
|
||||
|
||||
# Verify that temperature and seed are not in the parameters
|
||||
args, kwargs = mock_chat_openai.call_args
|
||||
assert "temperature" not in kwargs
|
||||
assert "seed" not in kwargs
|
||||
|
||||
@patch("langflow.components.languagemodels.openai_chat_model.ChatOpenAI")
|
||||
async def test_build_model_with_json_mode(self, mock_chat_openai, component_class, default_kwargs):
|
||||
mock_instance = MagicMock()
|
||||
mock_bound_instance = MagicMock()
|
||||
mock_instance.bind.return_value = mock_bound_instance
|
||||
mock_chat_openai.return_value = mock_instance
|
||||
|
||||
default_kwargs["json_mode"] = True
|
||||
component = component_class(**default_kwargs)
|
||||
model = component.build_model()
|
||||
|
||||
mock_chat_openai.assert_called_once()
|
||||
mock_instance.bind.assert_called_once_with(response_format={"type": "json_object"})
|
||||
assert model == mock_bound_instance
|
||||
|
||||
@patch("langflow.components.languagemodels.openai_chat_model.ChatOpenAI")
|
||||
async def test_build_model_no_api_key(self, mock_chat_openai, component_class, default_kwargs):
|
||||
mock_instance = MagicMock()
|
||||
mock_chat_openai.return_value = mock_instance
|
||||
default_kwargs["api_key"] = None
|
||||
component = component_class(**default_kwargs)
|
||||
component.build_model()
|
||||
|
||||
# When api_key is None, it should be passed as None to ChatOpenAI
|
||||
args, kwargs = mock_chat_openai.call_args
|
||||
assert kwargs["api_key"] is None
|
||||
|
||||
@patch("langflow.components.languagemodels.openai_chat_model.ChatOpenAI")
|
||||
async def test_build_model_max_tokens_zero(self, mock_chat_openai, component_class, default_kwargs):
|
||||
mock_instance = MagicMock()
|
||||
mock_chat_openai.return_value = mock_instance
|
||||
default_kwargs["max_tokens"] = 0
|
||||
component = component_class(**default_kwargs)
|
||||
component.build_model()
|
||||
|
||||
# When max_tokens is 0, it should be passed as None to ChatOpenAI
|
||||
args, kwargs = mock_chat_openai.call_args
|
||||
assert kwargs["max_tokens"] is None
|
||||
|
||||
async def test_get_exception_message_bad_request_error(self, component_class, default_kwargs):
|
||||
component_class(**default_kwargs)
|
||||
|
||||
# Create a mock BadRequestError with a body attribute
|
||||
mock_error = MagicMock()
|
||||
mock_error.body = {"message": "test error message"}
|
||||
|
||||
# Test the method directly by patching the import
|
||||
with patch("openai.BadRequestError", mock_error.__class__):
|
||||
# Manually call isinstance to avoid mocking it
|
||||
if hasattr(mock_error, "body"):
|
||||
message = mock_error.body.get("message")
|
||||
assert message == "test error message"
|
||||
|
||||
async def test_get_exception_message_no_openai_import(self, component_class, default_kwargs):
|
||||
component = component_class(**default_kwargs)
|
||||
|
||||
# Test when openai module is not available
|
||||
with patch.dict("sys.modules", {"openai": None}), patch("builtins.__import__", side_effect=ImportError):
|
||||
message = component._get_exception_message(Exception("test"))
|
||||
assert message is None
|
||||
|
||||
async def test_get_exception_message_other_exception(self, component_class, default_kwargs):
|
||||
component = component_class(**default_kwargs)
|
||||
|
||||
# Create a regular exception (not BadRequestError)
|
||||
regular_exception = ValueError("test error")
|
||||
|
||||
# Create a simple mock for BadRequestError that the exception won't match
|
||||
class MockBadRequestError:
|
||||
pass
|
||||
|
||||
with patch("openai.BadRequestError", MockBadRequestError):
|
||||
message = component._get_exception_message(regular_exception)
|
||||
assert message is None
|
||||
|
||||
async def test_update_build_config_reasoning_model(self, component_class, default_kwargs):
|
||||
component = component_class(**default_kwargs)
|
||||
build_config = {
|
||||
"temperature": {"show": True},
|
||||
"seed": {"show": True},
|
||||
}
|
||||
|
||||
# Test with reasoning model
|
||||
updated_config = component.update_build_config(build_config, "o1", "model_name")
|
||||
assert updated_config["temperature"]["show"] is False
|
||||
assert updated_config["seed"]["show"] is False
|
||||
|
||||
# Test with regular model
|
||||
updated_config = component.update_build_config(build_config, "gpt-4", "model_name")
|
||||
assert updated_config["temperature"]["show"] is True
|
||||
assert updated_config["seed"]["show"] is True
|
||||
|
||||
def test_build_model_integration(self):
|
||||
component = OpenAIModelComponent()
|
||||
component.api_key = os.getenv("OPENAI_API_KEY")
|
||||
component.model_name = "gpt-4.1-nano"
|
||||
component.temperature = 0.2
|
||||
component.max_tokens = 1000
|
||||
component.seed = 42
|
||||
component.max_retries = 3
|
||||
component.timeout = 600
|
||||
component.openai_api_base = "https://api.openai.com/v1"
|
||||
|
||||
model = component.build_model()
|
||||
assert isinstance(model, ChatOpenAI)
|
||||
assert model.model_name == "gpt-4.1-nano"
|
||||
assert model.openai_api_base == "https://api.openai.com/v1"
|
||||
|
||||
def test_build_model_integration_reasoning(self):
|
||||
component = OpenAIModelComponent()
|
||||
component.api_key = os.getenv("OPENAI_API_KEY")
|
||||
component.model_name = "o1"
|
||||
component.temperature = 0.2 # This should be ignored for reasoning models
|
||||
component.max_tokens = 1000
|
||||
component.seed = 42 # This should be ignored for reasoning models
|
||||
component.max_retries = 3
|
||||
component.timeout = 600
|
||||
component.openai_api_base = "https://api.openai.com/v1"
|
||||
|
||||
model = component.build_model()
|
||||
assert isinstance(model, ChatOpenAI)
|
||||
assert model.model_name == "o1"
|
||||
assert model.openai_api_base == "https://api.openai.com/v1"
|
||||
|
|
@ -1,15 +1,18 @@
|
|||
from unittest.mock import MagicMock, patch
|
||||
import os
|
||||
|
||||
import pytest
|
||||
from langchain_anthropic import ChatAnthropic
|
||||
from langchain_google_genai import ChatGoogleGenerativeAI
|
||||
from langchain_openai import ChatOpenAI
|
||||
from langflow.base.models.anthropic_constants import ANTHROPIC_MODELS
|
||||
from langflow.base.models.google_generative_ai_constants import GOOGLE_GENERATIVE_AI_MODELS
|
||||
from langflow.base.models.openai_constants import OPENAI_MODEL_NAMES
|
||||
from langflow.components.models.language_model import LanguageModelComponent
|
||||
|
||||
from tests.base import ComponentTestBaseWithClient
|
||||
from tests.base import ComponentTestBaseWithoutClient
|
||||
|
||||
|
||||
@pytest.mark.usefixtures("client")
|
||||
class TestLanguageModelComponent(ComponentTestBaseWithClient):
|
||||
class TestLanguageModelComponent(ComponentTestBaseWithoutClient):
|
||||
@pytest.fixture
|
||||
def component_class(self):
|
||||
return LanguageModelComponent
|
||||
|
|
@ -29,6 +32,32 @@ class TestLanguageModelComponent(ComponentTestBaseWithClient):
|
|||
@pytest.fixture
|
||||
def file_names_mapping(self):
|
||||
"""Return the file names mapping for version-specific files."""
|
||||
# No version-specific files for this component
|
||||
return []
|
||||
|
||||
@pytest.fixture
|
||||
def openai_api_key(self):
|
||||
"""Fixture to get OpenAI API key from environment variable."""
|
||||
api_key = os.environ.get("OPENAI_API_KEY")
|
||||
if not api_key:
|
||||
pytest.skip("OPENAI_API_KEY environment variable not set")
|
||||
return api_key
|
||||
|
||||
@pytest.fixture
|
||||
def anthropic_api_key(self):
|
||||
"""Fixture to get Anthropic API key from environment variable."""
|
||||
api_key = os.environ.get("ANTHROPIC_API_KEY")
|
||||
if not api_key:
|
||||
pytest.skip("ANTHROPIC_API_KEY environment variable not set")
|
||||
return api_key
|
||||
|
||||
@pytest.fixture
|
||||
def google_api_key(self):
|
||||
"""Fixture to get Google API key from environment variable."""
|
||||
api_key = os.environ.get("GOOGLE_API_KEY")
|
||||
if not api_key:
|
||||
pytest.skip("GOOGLE_API_KEY environment variable not set")
|
||||
return api_key
|
||||
|
||||
async def test_update_build_config_openai(self, component_class, default_kwargs):
|
||||
component = component_class(**default_kwargs)
|
||||
|
|
@ -52,57 +81,69 @@ class TestLanguageModelComponent(ComponentTestBaseWithClient):
|
|||
assert updated_config["model_name"]["value"] == ANTHROPIC_MODELS[0]
|
||||
assert updated_config["api_key"]["display_name"] == "Anthropic API Key"
|
||||
|
||||
@patch("langflow.components.models.language_model.ChatOpenAI")
|
||||
async def test_build_model_openai(self, mock_chat_openai, component_class, default_kwargs):
|
||||
# Setup mock
|
||||
mock_instance = MagicMock()
|
||||
mock_chat_openai.return_value = mock_instance
|
||||
async def test_update_build_config_google(self, component_class, default_kwargs):
|
||||
component = component_class(**default_kwargs)
|
||||
build_config = {
|
||||
"model_name": {"options": [], "value": ""},
|
||||
"api_key": {"display_name": "API Key"},
|
||||
}
|
||||
updated_config = component.update_build_config(build_config, "Google", "provider")
|
||||
assert updated_config["model_name"]["options"] == GOOGLE_GENERATIVE_AI_MODELS
|
||||
assert updated_config["model_name"]["value"] == GOOGLE_GENERATIVE_AI_MODELS[0]
|
||||
assert updated_config["api_key"]["display_name"] == "Google API Key"
|
||||
|
||||
# Create and configure the component
|
||||
async def test_openai_model_creation(self, component_class, default_kwargs):
|
||||
"""Test that the component returns an instance of ChatOpenAI for OpenAI provider."""
|
||||
component = component_class(**default_kwargs)
|
||||
component.provider = "OpenAI"
|
||||
component.model_name = "gpt-3.5-turbo"
|
||||
component.api_key = "test-key"
|
||||
component.api_key = "sk-test-key" # Use a fake but correctly formatted key
|
||||
component.temperature = 0.5
|
||||
component.stream = False
|
||||
|
||||
# Build the model
|
||||
# The API key will be invalid, but we should still get a ChatOpenAI instance
|
||||
model = component.build_model()
|
||||
assert isinstance(model, ChatOpenAI)
|
||||
assert model.model_name == "gpt-3.5-turbo"
|
||||
assert model.temperature == 0.5
|
||||
assert model.streaming is False
|
||||
# API key is stored as a SecretStr object, so we can't directly compare values
|
||||
|
||||
# Verify the ChatOpenAI was called with the correct parameters
|
||||
mock_chat_openai.assert_called_once_with(
|
||||
model_name="gpt-3.5-turbo",
|
||||
temperature=0.5,
|
||||
streaming=False,
|
||||
openai_api_key="test-key",
|
||||
)
|
||||
assert model == mock_instance
|
||||
|
||||
@patch("langflow.components.models.language_model.ChatAnthropic")
|
||||
async def test_build_model_anthropic(self, mock_chat_anthropic, component_class, default_kwargs):
|
||||
# Setup mock
|
||||
mock_instance = MagicMock()
|
||||
mock_chat_anthropic.return_value = mock_instance
|
||||
|
||||
# Create and configure the component
|
||||
async def test_anthropic_model_creation(self, component_class, default_kwargs):
|
||||
"""Test that the component returns an instance of ChatAnthropic for Anthropic provider."""
|
||||
component = component_class(**default_kwargs)
|
||||
component.provider = "Anthropic"
|
||||
component.model_name = ANTHROPIC_MODELS[0] # Use the first model from the constants
|
||||
component.api_key = "test-key"
|
||||
component.model_name = ANTHROPIC_MODELS[0]
|
||||
component.api_key = "sk-ant-test-key" # Use a fake but plausible key
|
||||
component.temperature = 0.7
|
||||
component.stream = False
|
||||
|
||||
# Build the model
|
||||
# The API key will be invalid, but we should still get a ChatAnthropic instance
|
||||
model = component.build_model()
|
||||
assert isinstance(model, ChatAnthropic)
|
||||
assert model.model == ANTHROPIC_MODELS[0]
|
||||
assert model.temperature == 0.7
|
||||
assert model.streaming is False
|
||||
# API key is stored as a SecretStr object, so we can't directly compare values
|
||||
|
||||
# Verify the ChatAnthropic was called with the correct parameters
|
||||
mock_chat_anthropic.assert_called_once_with(
|
||||
model=ANTHROPIC_MODELS[0],
|
||||
temperature=0.7,
|
||||
streaming=False,
|
||||
anthropic_api_key="test-key",
|
||||
)
|
||||
assert model == mock_instance
|
||||
async def test_google_model_creation(self, component_class, default_kwargs):
|
||||
"""Test that the component returns an instance of ChatGoogleGenerativeAI for Google provider."""
|
||||
component = component_class(**default_kwargs)
|
||||
component.provider = "Google"
|
||||
component.model_name = GOOGLE_GENERATIVE_AI_MODELS[0]
|
||||
component.api_key = "google-test-key" # Use a fake but plausible key
|
||||
component.temperature = 0.7
|
||||
component.stream = False
|
||||
|
||||
# The API key will be invalid, but we should still get a ChatGoogleGenerativeAI instance
|
||||
model = component.build_model()
|
||||
assert isinstance(model, ChatGoogleGenerativeAI)
|
||||
# Google model automatically prepends "models/" to the model name
|
||||
assert model.model == f"models/{GOOGLE_GENERATIVE_AI_MODELS[0]}"
|
||||
assert model.temperature == 0.7
|
||||
# Google model uses 'stream' instead of 'streaming'
|
||||
# Skip this check for Google model since it has a different interface
|
||||
# API key is stored as a SecretStr object, so we can't directly compare values
|
||||
|
||||
async def test_build_model_openai_missing_api_key(self, component_class, default_kwargs):
|
||||
component = component_class(**default_kwargs)
|
||||
|
|
@ -120,9 +161,59 @@ class TestLanguageModelComponent(ComponentTestBaseWithClient):
|
|||
with pytest.raises(ValueError, match="Anthropic API key is required when using Anthropic provider"):
|
||||
component.build_model()
|
||||
|
||||
async def test_build_model_google_missing_api_key(self, component_class, default_kwargs):
|
||||
component = component_class(**default_kwargs)
|
||||
component.provider = "Google"
|
||||
component.api_key = None
|
||||
|
||||
with pytest.raises(ValueError, match="Google API key is required when using Google provider"):
|
||||
component.build_model()
|
||||
|
||||
async def test_build_model_unknown_provider(self, component_class, default_kwargs):
|
||||
component = component_class(**default_kwargs)
|
||||
component.provider = "Unknown"
|
||||
|
||||
with pytest.raises(ValueError, match="Unknown provider: Unknown"):
|
||||
component.build_model()
|
||||
|
||||
async def test_openai_live_api(self, component_class, default_kwargs, openai_api_key):
|
||||
"""Test that the component can create a model with a real API key."""
|
||||
component = component_class(**default_kwargs)
|
||||
component.provider = "OpenAI"
|
||||
component.model_name = "gpt-3.5-turbo"
|
||||
component.api_key = openai_api_key
|
||||
component.temperature = 0.1
|
||||
component.stream = False
|
||||
|
||||
model = component.build_model()
|
||||
assert isinstance(model, ChatOpenAI)
|
||||
# We could attempt a simple call here, but that would increase test time
|
||||
# and might fail due to network issues, so we'll just verify the instance
|
||||
|
||||
async def test_anthropic_live_api(self, component_class, default_kwargs, anthropic_api_key):
|
||||
"""Test that the component can create a model with a real API key."""
|
||||
component = component_class(**default_kwargs)
|
||||
component.provider = "Anthropic"
|
||||
component.model_name = ANTHROPIC_MODELS[0]
|
||||
component.api_key = anthropic_api_key
|
||||
component.temperature = 0.1
|
||||
component.stream = False
|
||||
|
||||
model = component.build_model()
|
||||
assert isinstance(model, ChatAnthropic)
|
||||
# We could attempt a simple call here, but that would increase test time
|
||||
# and might fail due to network issues, so we'll just verify the instance
|
||||
|
||||
async def test_google_live_api(self, component_class, default_kwargs, google_api_key):
|
||||
"""Test that the component can create a model with a real API key."""
|
||||
component = component_class(**default_kwargs)
|
||||
component.provider = "Google"
|
||||
component.model_name = GOOGLE_GENERATIVE_AI_MODELS[0]
|
||||
component.api_key = google_api_key
|
||||
component.temperature = 0.1
|
||||
component.stream = False
|
||||
|
||||
model = component.build_model()
|
||||
assert isinstance(model, ChatGoogleGenerativeAI)
|
||||
# We could attempt a simple call here, but that would increase test time
|
||||
# and might fail due to network issues, so we'll just verify the instance
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue