From 4ea144eba997514554bf50a2cb17db1a5bc6a15e Mon Sep 17 00:00:00 2001 From: Gabriel Luiz Freitas Almeida Date: Tue, 7 Jan 2025 12:48:16 -0300 Subject: [PATCH] refactor: make StructuredOutput tests use MockLanguageModel (#5563) * feat: add MockLanguageModel for testing language model interactions - Introduced a new mock implementation of BaseLanguageModel to facilitate unit testing. - Implemented methods to simulate responses and handle message processing. - Added functionality to generate mock responses based on input messages. - Ensured compatibility with existing language model interfaces for seamless integration in tests. * refactor: simplify test_structured_output_component by replacing MagicMock with MockLanguageModel - Removed extensive mock implementations of BaseLanguageModel in tests. - Replaced instances of MagicMock with a new MockLanguageModel for better clarity and maintainability. - Streamlined test cases for structured output generation and error handling. - Ensured compatibility with existing test structure while enhancing readability. 
* fix: rename utils.py to useful.py to avoid namespace conflict --------- Co-authored-by: italojohnny --- .../test_structured_output_component.py | 72 +++---------------- src/backend/tests/unit/useful.py | 63 ++++++++++++++++ 2 files changed, 71 insertions(+), 64 deletions(-) create mode 100644 src/backend/tests/unit/useful.py diff --git a/src/backend/tests/unit/components/helpers/test_structured_output_component.py b/src/backend/tests/unit/components/helpers/test_structured_output_component.py index 30af49ea7..14d028235 100644 --- a/src/backend/tests/unit/components/helpers/test_structured_output_component.py +++ b/src/backend/tests/unit/components/helpers/test_structured_output_component.py @@ -1,61 +1,21 @@ import re -from unittest.mock import MagicMock, patch +from unittest.mock import patch import pytest -from langchain_core.language_models import BaseLanguageModel from langflow.components.helpers.structured_output import StructuredOutputComponent from langflow.helpers.base_model import build_model_from_schema from langflow.inputs.inputs import TableInput from langflow.schema.data import Data from pydantic import BaseModel -from typing_extensions import override + +from tests.unit.useful import MockLanguageModel class TestStructuredOutputComponent: - # Ensure that the structured output is successfully generated with the correct BaseModel instance returned by - # the mock function def test_successful_structured_output_generation_with_patch_with_config(self): - class MockLanguageModel(BaseLanguageModel): - @override - def with_structured_output(self, *args, **kwargs): - return self - - @override - def with_config(self, *args, **kwargs): - return self - - @override - def invoke(self, *args, **kwargs): - return self - - @override - def generate_prompt(self, *args, **kwargs): - raise NotImplementedError - - @override - async def agenerate_prompt(self, *args, **kwargs): - raise NotImplementedError - - @override - def predict(self, *args, **kwargs): - raise 
NotImplementedError - - @override - def predict_messages(self, *args, **kwargs): - raise NotImplementedError - - @override - async def apredict(self, *args, **kwargs): - raise NotImplementedError - - @override - async def apredict_messages(self, *args, **kwargs): - raise NotImplementedError - def mock_get_chat_result(runnable, input_value, config): # noqa: ARG001 class MockBaseModel(BaseModel): - @override - def model_dump(self, **kwargs): + def model_dump(self, **kwargs): # noqa: ARG002 return {"field": "value"} return MockBaseModel() @@ -73,7 +33,6 @@ class TestStructuredOutputComponent: assert isinstance(result, Data) assert result.data == {"field": "value"} - # Raises ValueError when the language model does not support structured output def test_raises_value_error_for_unsupported_language_model(self): # Mocking an incompatible language model class MockLanguageModel: @@ -91,7 +50,6 @@ class TestStructuredOutputComponent: with pytest.raises(TypeError, match=re.escape("Language model does not support structured output.")): component.build_structured_output() - # Correctly builds the output model from the provided schema def test_correctly_builds_output_model(self): # Setup component = StructuredOutputComponent() @@ -129,7 +87,6 @@ class TestStructuredOutputComponent: output_model = build_model_from_schema(schema) assert isinstance(output_model, type) - # Properly handles multiple outputs when 'multiple' is set to True def test_handles_multiple_outputs(self): # Setup component = StructuredOutputComponent() @@ -170,7 +127,7 @@ class TestStructuredOutputComponent: def test_empty_output_schema(self): component = StructuredOutputComponent( - llm=MagicMock(), + llm=MockLanguageModel(), input_value="Test input", schema_name="EmptySchema", output_schema=[], @@ -182,7 +139,7 @@ class TestStructuredOutputComponent: def test_invalid_output_schema_type(self): component = StructuredOutputComponent( - llm=MagicMock(), + llm=MockLanguageModel(), input_value="Test input", 
schema_name="InvalidSchema", output_schema=[{"name": "field", "type": "invalid_type", "description": "Invalid field"}], @@ -200,8 +157,7 @@ class TestStructuredOutputComponent: class ParentModel(BaseModel): parent: ChildModel = ChildModel() - mock_llm = MagicMock() - mock_llm.with_structured_output.return_value = mock_llm + mock_llm = MockLanguageModel() mock_get_chat_result.return_value = ParentModel(parent=ChildModel(child="value")) component = StructuredOutputComponent( @@ -233,7 +189,7 @@ class TestStructuredOutputComponent: mock_get_chat_result.return_value = MockBaseModel(field="value") component = StructuredOutputComponent( - llm=MagicMock(), + llm=MockLanguageModel(), input_value=large_input, schema_name="LargeInputSchema", output_schema=[{"name": "field", "type": "str", "description": "A test field"}], @@ -244,15 +200,3 @@ class TestStructuredOutputComponent: assert isinstance(result, Data) assert result.data == {"field": "value"} mock_get_chat_result.assert_called_once() - - def test_invalid_llm_config(self): - component = StructuredOutputComponent( - llm="invalid_llm", # Not a proper LLM instance - input_value="Test input", - schema_name="InvalidLLMSchema", - output_schema=[{"name": "field", "type": "str", "description": "A test field"}], - multiple=False, - ) - - with pytest.raises(TypeError, match=re.escape("Language model does not support structured output.")): - component.build_structured_output() diff --git a/src/backend/tests/unit/useful.py b/src/backend/tests/unit/useful.py new file mode 100644 index 000000000..b2a7fa29c --- /dev/null +++ b/src/backend/tests/unit/useful.py @@ -0,0 +1,63 @@ +from unittest.mock import MagicMock + +from langchain_core.language_models import BaseLanguageModel +from typing_extensions import override + + +class MockLanguageModel(BaseLanguageModel): + """A mock language model for testing purposes.""" + + def __init__(self, response_generator=None): + """Initialize the mock model with an optional response generator 
function.""" + super().__init__() + # Call object.__setattr__ directly to bypass pydantic validation + object.__setattr__(self, "_response_generator", response_generator or (lambda msg: f"Response for {msg}")) + + @override + def with_config(self, *args, **kwargs): + return self + + @override + def with_structured_output(self, *args, **kwargs): + return self + + @override + async def abatch(self, messages, *args, **kwargs): + if not messages: + return [] + # For a list of message dicts (chat format), use the "content" of its last entry; otherwise use the item as-is + responses = [] + for msg_list in messages: + content = msg_list[-1]["content"] if isinstance(msg_list, list) else msg_list + mock_response = MagicMock() + mock_response.content = self._response_generator(content) + responses.append(mock_response) + return responses + + @override + def invoke(self, *args, **kwargs): + return self + + @override + def generate_prompt(self, *args, **kwargs): + raise NotImplementedError + + @override + async def agenerate_prompt(self, *args, **kwargs): + raise NotImplementedError + + @override + def predict(self, *args, **kwargs): + raise NotImplementedError + + @override + def predict_messages(self, *args, **kwargs): + raise NotImplementedError + + @override + async def apredict(self, *args, **kwargs): + raise NotImplementedError + + @override + async def apredict_messages(self, *args, **kwargs): + raise NotImplementedError