"""Tests for the ContextAssembler micro-fault and context assembly module.""" from __future__ import annotations import asyncio from unittest.mock import AsyncMock, MagicMock, patch import pytest from mnemosyne.context_assembler import ( ContextAssembler, ContextWindow, MicroFaultResult, _estimate_tokens, ) from mnemosyne.fidelity import FidelityLevel, FidelityManager, make_object from mnemosyne.object_store import ( DummyEmbedder, InMemoryBackend, ObjectStore, StoredObject, ) # ── Fixtures ───────────────────────────────────────────────── @pytest.fixture def object_store(): """Create an ObjectStore with InMemoryBackend and DummyEmbedder.""" backend = InMemoryBackend() embedder = DummyEmbedder() return ObjectStore(backend=backend, embedder=embedder) @pytest.fixture def mock_helper_llm(): """Create a mock HelperLLM with answer_micro_fault returning a canned answer.""" helper = AsyncMock() helper.answer_micro_fault = AsyncMock( return_value="The auth library used is passport.js with JWT strategy." ) return helper @pytest.fixture def assembler(object_store, mock_helper_llm): """Create a ContextAssembler with real ObjectStore and mock HelperLLM.""" return ContextAssembler(object_store=object_store, helper_llm=mock_helper_llm) @pytest.fixture def assembler_no_helper(object_store): """Create a ContextAssembler without HelperLLM (fallback mode).""" return ContextAssembler(object_store=object_store, helper_llm=None) async def _seed_objects(store: ObjectStore, session_id: str = "sess_test") -> list[StoredObject]: """Seed the store with a few test objects and return them.""" obj1 = await store.store_object( session_id=session_id, content="Authentication uses passport.js with JWT strategy. " "The secret key is loaded from AUTH_SECRET env var. " "Token expiry is set to 24 hours.", object_type="file_context", source_tool="Read", source_key="src/auth/middleware.ts", stub="[file_context | auth middleware | passport.js JWT]", key_entities=["passport.js", "JWT", "AUTH_SECRET"], ) obj2 = await store.store_object( session_id=session_id, content="Database schema has users, posts, and comments tables. " "Users table has id, email, password_hash, created_at columns. " "Posts reference users via author_id foreign key.", object_type="file_context", source_tool="Read", source_key="schema.sql", stub="[file_context | DB schema | users, posts, comments]", key_entities=["users", "posts", "comments", "schema.sql"], ) obj3 = await store.store_object( session_id=session_id, content="Server runs on port 3000 by default. " "Configuration loaded from .env file. " "CORS enabled for localhost:5173 in development.", object_type="file_context", source_tool="Read", source_key="src/config.ts", stub="[file_context | server config | port 3000, CORS]", key_entities=["port 3000", "CORS", ".env"], ) return [obj1, obj2, obj3] # ── MicroFaultResult dataclass ─────────────────────────────── def test_micro_fault_result_fields(): """MicroFaultResult should store all expected fields.""" result = MicroFaultResult( answer="The answer is 42.", sources=["obj_abc", "obj_def"], answer_tokens=5, avoided_tokens=500, latency_ms=12.5, ) assert result.answer == "The answer is 42." assert result.sources == ["obj_abc", "obj_def"] assert result.answer_tokens == 5 assert result.avoided_tokens == 500 assert result.latency_ms == 12.5 def test_micro_fault_result_token_savings(): """avoided_tokens should represent tokens saved vs full restore.""" result = MicroFaultResult( answer="short answer", sources=["obj1"], answer_tokens=3, avoided_tokens=997, latency_ms=10.0, ) # The savings ratio: avoided / (avoided + answer_tokens) savings_ratio = result.avoided_tokens / (result.avoided_tokens + result.answer_tokens) assert savings_ratio > 0.99 # 99%+ savings # ── ContextWindow dataclass ────────────────────────────────── def test_context_window_fields(): """ContextWindow should store objects, total_tokens, and pressure_zone.""" window = ContextWindow( objects=[("obj1", 0, "full content"), ("obj2", 2, "compact summary")], total_tokens=150, pressure_zone="NORMAL", ) assert len(window.objects) == 2 assert window.total_tokens == 150 assert window.pressure_zone == "NORMAL" # ── handle_micro_fault with HelperLLM ──────────────────────── @pytest.mark.asyncio async def test_micro_fault_with_helper(assembler, object_store, mock_helper_llm): """Micro-fault with HelperLLM should search, call helper, return answer.""" await _seed_objects(object_store, "sess_test") result = await assembler.handle_micro_fault( session_id="sess_test", question="What auth library is used?", ) assert isinstance(result, MicroFaultResult) assert result.answer == "The auth library used is passport.js with JWT strategy." assert len(result.sources) > 0 assert result.answer_tokens > 0 assert result.avoided_tokens > 0 assert result.latency_ms >= 0 # Verify HelperLLM was called mock_helper_llm.answer_micro_fault.assert_called_once() call_kwargs = mock_helper_llm.answer_micro_fault.call_args assert call_kwargs.kwargs["question"] == "What auth library is used?" assert len(call_kwargs.kwargs["relevant_contents"]) > 0 @pytest.mark.asyncio async def test_micro_fault_with_scope(assembler, object_store, mock_helper_llm): """Micro-fault with scope hint should incorporate it into the search query.""" await _seed_objects(object_store, "sess_test") result = await assembler.handle_micro_fault( session_id="sess_test", question="What port?", scope="config files", ) assert isinstance(result, MicroFaultResult) assert len(result.sources) > 0 mock_helper_llm.answer_micro_fault.assert_called_once() @pytest.mark.asyncio async def test_micro_fault_with_max_tokens(assembler, object_store, mock_helper_llm): """Micro-fault should pass max_tokens to the HelperLLM.""" await _seed_objects(object_store, "sess_test") await assembler.handle_micro_fault( session_id="sess_test", question="What is the DB schema?", max_tokens=100, ) call_kwargs = mock_helper_llm.answer_micro_fault.call_args assert call_kwargs.kwargs["max_tokens"] == 100 # ── handle_micro_fault fallback (no HelperLLM) ────────────── @pytest.mark.asyncio async def test_micro_fault_fallback_no_helper(assembler_no_helper, object_store): """Without HelperLLM, micro-fault should return summaries as fallback.""" objects = await _seed_objects(object_store, "sess_test") # Set a summary on one object so fallback has something to show await object_store.update_fidelity(objects[0].id, 2, summary="Auth uses passport.js JWT") result = await assembler_no_helper.handle_micro_fault( session_id="sess_test", question="What auth library is used?", ) assert isinstance(result, MicroFaultResult) assert "HelperLLM unavailable" in result.answer assert len(result.sources) > 0 assert result.answer_tokens > 0 @pytest.mark.asyncio async def test_micro_fault_fallback_uses_stub_when_no_summary(assembler_no_helper, object_store): """Fallback should use stub when no summaries are available.""" await _seed_objects(object_store, "sess_test") result = await assembler_no_helper.handle_micro_fault( session_id="sess_test", question="What auth library is used?", ) assert isinstance(result, MicroFaultResult) assert "HelperLLM unavailable" in result.answer # Should contain stub content since no summaries exist assert len(result.sources) > 0 # ── handle_micro_fault with no search results ──────────────── @pytest.mark.asyncio async def test_micro_fault_no_results(assembler): """Micro-fault with no matching objects should return a 'not found' message.""" # Don't seed any objects — empty store result = await assembler.handle_micro_fault( session_id="sess_empty", question="What is the meaning of life?", ) assert isinstance(result, MicroFaultResult) assert "No relevant content found" in result.answer assert result.sources == [] assert result.avoided_tokens == 0 # ── Micro-fault records access on consulted objects ────────── @pytest.mark.asyncio async def test_micro_fault_records_access(assembler, object_store): """Micro-fault should call record_fault(is_micro=True) on consulted objects.""" objects = await _seed_objects(object_store, "sess_test") result = await assembler.handle_micro_fault( session_id="sess_test", question="What auth library is used?", ) # At least one source should have been consulted assert len(result.sources) > 0 # Check that micro_fault_count was incremented on consulted objects for source_id in result.sources: obj = await object_store.get(source_id) assert obj is not None assert obj.micro_fault_count > 0 # ── assemble_context ───────────────────────────────────────── @pytest.mark.asyncio async def test_assemble_context_all_l0(assembler, object_store): """assemble_context with all L0 objects should include full content.""" objects = await _seed_objects(object_store, "sess_test") fm = FidelityManager(window_size=200_000) for obj in objects: fm_obj = make_object( object_type=obj.object_type, content_full=obj.content_full, stub=obj.stub, ) fm_obj.id = obj.id # match IDs fm.register_object(fm_obj) blocks = await assembler.assemble_context( session_id="sess_test", fidelity_manager=fm, current_turn=1, ) assert len(blocks) == 3 for block in blocks: assert block["fidelity"] == 0 # L0 assert len(block["content"]) > 50 # full content assert block["tokens"] > 0 @pytest.mark.asyncio async def test_assemble_context_mixed_fidelity(assembler, object_store): """assemble_context should respect per-object fidelity levels.""" objects = await _seed_objects(object_store, "sess_test") # Set different fidelity levels in the store await object_store.update_fidelity( objects[0].id, 0, # L0: full ) await object_store.update_fidelity( objects[1].id, 2, summary="DB has users, posts, comments tables", # L2: compact ) await object_store.update_fidelity( objects[2].id, 3, # L3: stub ) fm = FidelityManager(window_size=200_000) for obj in objects: fm_obj = make_object( object_type=obj.object_type, content_full=obj.content_full, stub=obj.stub, ) fm_obj.id = obj.id # Set fidelity to match what we set in the store stored = await object_store.get(obj.id) fm_obj.current_fidelity = FidelityLevel(stored.current_fidelity) if stored.summary_compact: fm_obj.summary_compact = stored.summary_compact fm.register_object(fm_obj) blocks = await assembler.assemble_context( session_id="sess_test", fidelity_manager=fm, current_turn=5, ) assert len(blocks) == 3 fidelities = {b["object_id"]: b["fidelity"] for b in blocks} assert fidelities[objects[0].id] == 0 # L0 assert fidelities[objects[1].id] == 2 # L2 assert fidelities[objects[2].id] == 3 # L3 @pytest.mark.asyncio async def test_assemble_context_excludes_evicted(assembler, object_store): """assemble_context should exclude L4 (evicted) objects.""" objects = await _seed_objects(object_store, "sess_test") # Evict one object await object_store.update_fidelity(objects[1].id, 4) fm = FidelityManager(window_size=200_000) for obj in objects: fm_obj = make_object( object_type=obj.object_type, content_full=obj.content_full, stub=obj.stub, ) fm_obj.id = obj.id stored = await object_store.get(obj.id) fm_obj.current_fidelity = FidelityLevel(min(stored.current_fidelity, 4)) fm.register_object(fm_obj) blocks = await assembler.assemble_context( session_id="sess_test", fidelity_manager=fm, current_turn=5, ) # Only 2 objects should be in context (one was evicted) block_ids = {b["object_id"] for b in blocks} assert objects[1].id not in block_ids assert len(blocks) == 2 # ── _estimate_tokens helper ────────────────────────────────── def test_estimate_tokens_none(): """_estimate_tokens(None) should return 0.""" assert _estimate_tokens(None) == 0 def test_estimate_tokens_empty(): """_estimate_tokens('') should return 1 (minimum).""" assert _estimate_tokens("") == 1 def test_estimate_tokens_normal(): """_estimate_tokens should use ~4 chars per token heuristic.""" text = "a" * 400 assert _estimate_tokens(text) == 100