refactor: make StructuredOutput tests to use MockLanguageModel (#5563)
* feat: add MockLanguageModel for testing language model interactions - Introduced a new mock implementation of BaseLanguageModel to facilitate unit testing. - Implemented methods to simulate responses and handle message processing. - Added functionality to generate mock responses based on input messages. - Ensured compatibility with existing language model interfaces for seamless integration in tests. * refactor: simplify test_structured_output_component by replacing MagicMock with MockLanguageModel - Removed extensive mock implementations of BaseLanguageModel in tests. - Replaced instances of MagicMock with a new MockLanguageModel for better clarity and maintainability. - Streamlined test cases for structured output generation and error handling. - Ensured compatibility with existing test structure while enhancing readability. * fix: rename utils.py to useful.py to avoid namespace conflict --------- Co-authored-by: italojohnny <italojohnnydosanjos@gmail.com>
This commit is contained in:
parent
7084433adc
commit
4ea144eba9
2 changed files with 71 additions and 64 deletions
|
|
@ -1,61 +1,21 @@
|
|||
import re
|
||||
from unittest.mock import MagicMock, patch
|
||||
from unittest.mock import patch
|
||||
|
||||
import pytest
|
||||
from langchain_core.language_models import BaseLanguageModel
|
||||
from langflow.components.helpers.structured_output import StructuredOutputComponent
|
||||
from langflow.helpers.base_model import build_model_from_schema
|
||||
from langflow.inputs.inputs import TableInput
|
||||
from langflow.schema.data import Data
|
||||
from pydantic import BaseModel
|
||||
from typing_extensions import override
|
||||
|
||||
from tests.unit.useful import MockLanguageModel
|
||||
|
||||
|
||||
class TestStructuredOutputComponent:
|
||||
# Ensure that the structured output is successfully generated with the correct BaseModel instance returned by
|
||||
# the mock function
|
||||
def test_successful_structured_output_generation_with_patch_with_config(self):
|
||||
class MockLanguageModel(BaseLanguageModel):
|
||||
@override
|
||||
def with_structured_output(self, *args, **kwargs):
|
||||
return self
|
||||
|
||||
@override
|
||||
def with_config(self, *args, **kwargs):
|
||||
return self
|
||||
|
||||
@override
|
||||
def invoke(self, *args, **kwargs):
|
||||
return self
|
||||
|
||||
@override
|
||||
def generate_prompt(self, *args, **kwargs):
|
||||
raise NotImplementedError
|
||||
|
||||
@override
|
||||
async def agenerate_prompt(self, *args, **kwargs):
|
||||
raise NotImplementedError
|
||||
|
||||
@override
|
||||
def predict(self, *args, **kwargs):
|
||||
raise NotImplementedError
|
||||
|
||||
@override
|
||||
def predict_messages(self, *args, **kwargs):
|
||||
raise NotImplementedError
|
||||
|
||||
@override
|
||||
async def apredict(self, *args, **kwargs):
|
||||
raise NotImplementedError
|
||||
|
||||
@override
|
||||
async def apredict_messages(self, *args, **kwargs):
|
||||
raise NotImplementedError
|
||||
|
||||
def mock_get_chat_result(runnable, input_value, config): # noqa: ARG001
|
||||
class MockBaseModel(BaseModel):
|
||||
@override
|
||||
def model_dump(self, **kwargs):
|
||||
def model_dump(self, **kwargs): # noqa: ARG002
|
||||
return {"field": "value"}
|
||||
|
||||
return MockBaseModel()
|
||||
|
|
@ -73,7 +33,6 @@ class TestStructuredOutputComponent:
|
|||
assert isinstance(result, Data)
|
||||
assert result.data == {"field": "value"}
|
||||
|
||||
# Raises ValueError when the language model does not support structured output
|
||||
def test_raises_value_error_for_unsupported_language_model(self):
|
||||
# Mocking an incompatible language model
|
||||
class MockLanguageModel:
|
||||
|
|
@ -91,7 +50,6 @@ class TestStructuredOutputComponent:
|
|||
with pytest.raises(TypeError, match=re.escape("Language model does not support structured output.")):
|
||||
component.build_structured_output()
|
||||
|
||||
# Correctly builds the output model from the provided schema
|
||||
def test_correctly_builds_output_model(self):
|
||||
# Setup
|
||||
component = StructuredOutputComponent()
|
||||
|
|
@ -129,7 +87,6 @@ class TestStructuredOutputComponent:
|
|||
output_model = build_model_from_schema(schema)
|
||||
assert isinstance(output_model, type)
|
||||
|
||||
# Properly handles multiple outputs when 'multiple' is set to True
|
||||
def test_handles_multiple_outputs(self):
|
||||
# Setup
|
||||
component = StructuredOutputComponent()
|
||||
|
|
@ -170,7 +127,7 @@ class TestStructuredOutputComponent:
|
|||
|
||||
def test_empty_output_schema(self):
|
||||
component = StructuredOutputComponent(
|
||||
llm=MagicMock(),
|
||||
llm=MockLanguageModel(),
|
||||
input_value="Test input",
|
||||
schema_name="EmptySchema",
|
||||
output_schema=[],
|
||||
|
|
@ -182,7 +139,7 @@ class TestStructuredOutputComponent:
|
|||
|
||||
def test_invalid_output_schema_type(self):
|
||||
component = StructuredOutputComponent(
|
||||
llm=MagicMock(),
|
||||
llm=MockLanguageModel(),
|
||||
input_value="Test input",
|
||||
schema_name="InvalidSchema",
|
||||
output_schema=[{"name": "field", "type": "invalid_type", "description": "Invalid field"}],
|
||||
|
|
@ -200,8 +157,7 @@ class TestStructuredOutputComponent:
|
|||
class ParentModel(BaseModel):
|
||||
parent: ChildModel = ChildModel()
|
||||
|
||||
mock_llm = MagicMock()
|
||||
mock_llm.with_structured_output.return_value = mock_llm
|
||||
mock_llm = MockLanguageModel()
|
||||
mock_get_chat_result.return_value = ParentModel(parent=ChildModel(child="value"))
|
||||
|
||||
component = StructuredOutputComponent(
|
||||
|
|
@ -233,7 +189,7 @@ class TestStructuredOutputComponent:
|
|||
mock_get_chat_result.return_value = MockBaseModel(field="value")
|
||||
|
||||
component = StructuredOutputComponent(
|
||||
llm=MagicMock(),
|
||||
llm=MockLanguageModel(),
|
||||
input_value=large_input,
|
||||
schema_name="LargeInputSchema",
|
||||
output_schema=[{"name": "field", "type": "str", "description": "A test field"}],
|
||||
|
|
@ -244,15 +200,3 @@ class TestStructuredOutputComponent:
|
|||
assert isinstance(result, Data)
|
||||
assert result.data == {"field": "value"}
|
||||
mock_get_chat_result.assert_called_once()
|
||||
|
||||
def test_invalid_llm_config(self):
|
||||
component = StructuredOutputComponent(
|
||||
llm="invalid_llm", # Not a proper LLM instance
|
||||
input_value="Test input",
|
||||
schema_name="InvalidLLMSchema",
|
||||
output_schema=[{"name": "field", "type": "str", "description": "A test field"}],
|
||||
multiple=False,
|
||||
)
|
||||
|
||||
with pytest.raises(TypeError, match=re.escape("Language model does not support structured output.")):
|
||||
component.build_structured_output()
|
||||
|
|
|
|||
63
src/backend/tests/unit/useful.py
Normal file
63
src/backend/tests/unit/useful.py
Normal file
|
|
@ -0,0 +1,63 @@
|
|||
from unittest.mock import MagicMock
|
||||
|
||||
from langchain_core.language_models import BaseLanguageModel
|
||||
from typing_extensions import override
|
||||
|
||||
|
||||
class MockLanguageModel(BaseLanguageModel):
|
||||
"""A mock language model for testing purposes."""
|
||||
|
||||
def __init__(self, response_generator=None):
|
||||
"""Initialize the mock model with an optional response generator function."""
|
||||
super().__init__()
|
||||
# Use object's __dict__ to bypass pydantic validation
|
||||
object.__setattr__(self, "_response_generator", response_generator or (lambda msg: f"Response for {msg}"))
|
||||
|
||||
@override
|
||||
def with_config(self, *args, **kwargs):
|
||||
return self
|
||||
|
||||
@override
|
||||
def with_structured_output(self, *args, **kwargs):
|
||||
return self
|
||||
|
||||
@override
|
||||
async def abatch(self, messages, *args, **kwargs):
|
||||
if not messages:
|
||||
return []
|
||||
# If message is a list of dicts (chat format), get the last user message
|
||||
responses = []
|
||||
for msg_list in messages:
|
||||
content = msg_list[-1]["content"] if isinstance(msg_list, list) else msg_list
|
||||
mock_response = MagicMock()
|
||||
mock_response.content = self._response_generator(content)
|
||||
responses.append(mock_response)
|
||||
return responses
|
||||
|
||||
@override
|
||||
def invoke(self, *args, **kwargs):
|
||||
return self
|
||||
|
||||
@override
|
||||
def generate_prompt(self, *args, **kwargs):
|
||||
raise NotImplementedError
|
||||
|
||||
@override
|
||||
async def agenerate_prompt(self, *args, **kwargs):
|
||||
raise NotImplementedError
|
||||
|
||||
@override
|
||||
def predict(self, *args, **kwargs):
|
||||
raise NotImplementedError
|
||||
|
||||
@override
|
||||
def predict_messages(self, *args, **kwargs):
|
||||
raise NotImplementedError
|
||||
|
||||
@override
|
||||
async def apredict(self, *args, **kwargs):
|
||||
raise NotImplementedError
|
||||
|
||||
@override
|
||||
async def apredict_messages(self, *args, **kwargs):
|
||||
raise NotImplementedError
|
||||
Loading…
Add table
Add a link
Reference in a new issue