mnemosyne/tests/test_context_assembler.py
Joey Yakimowich-Payne a13719f754 feat: add object store with semantic segmentation
Object-addressed memory: segment messages into semantic objects,
embed with sentence-transformers, store in pgvector-backed store,
and reassemble context via goal-aware retrieval.

Ultraworked with [Sisyphus](https://github.com/code-yeongyu/oh-my-opencode)

Co-authored-by: Sisyphus <clio-agent@sisyphuslabs.ai>
2026-03-13 11:41:04 -06:00

408 lines
14 KiB
Python

"""Tests for the ContextAssembler micro-fault and context assembly module."""
from __future__ import annotations
import asyncio
from unittest.mock import AsyncMock, MagicMock, patch
import pytest
from mnemosyne.context_assembler import (
ContextAssembler,
ContextWindow,
MicroFaultResult,
_estimate_tokens,
)
from mnemosyne.fidelity import FidelityLevel, FidelityManager, make_object
from mnemosyne.object_store import (
DummyEmbedder,
InMemoryBackend,
ObjectStore,
StoredObject,
)
# ── Fixtures ─────────────────────────────────────────────────
@pytest.fixture
def object_store():
    """Build an ObjectStore wired to an InMemoryBackend and a DummyEmbedder."""
    return ObjectStore(backend=InMemoryBackend(), embedder=DummyEmbedder())
@pytest.fixture
def mock_helper_llm():
    """Provide an AsyncMock HelperLLM whose answer_micro_fault yields a fixed reply."""
    canned_reply = "The auth library used is passport.js with JWT strategy."
    fake_helper = AsyncMock()
    # Awaiting answer_micro_fault(...) resolves to canned_reply.
    fake_helper.answer_micro_fault = AsyncMock(return_value=canned_reply)
    return fake_helper
@pytest.fixture
def assembler(object_store, mock_helper_llm):
    """ContextAssembler over the real ObjectStore fixture, using the mocked HelperLLM."""
    context_assembler = ContextAssembler(
        object_store=object_store,
        helper_llm=mock_helper_llm,
    )
    return context_assembler
@pytest.fixture
def assembler_no_helper(object_store):
    """ContextAssembler with helper_llm explicitly set to None, exercising fallback mode."""
    no_helper = None
    return ContextAssembler(object_store=object_store, helper_llm=no_helper)
async def _seed_objects(store: ObjectStore, session_id: str = "sess_test") -> list[StoredObject]:
    """Store three ``file_context`` objects in *store* and return them in insertion order.

    Index 0 is the auth middleware, index 1 the DB schema and index 2 the
    server config. All are recorded with ``source_tool="Read"``.
    """
    specs = [
        {
            "content": (
                "Authentication uses passport.js with JWT strategy. "
                "The secret key is loaded from AUTH_SECRET env var. "
                "Token expiry is set to 24 hours."
            ),
            "source_key": "src/auth/middleware.ts",
            "stub": "[file_context | auth middleware | passport.js JWT]",
            "key_entities": ["passport.js", "JWT", "AUTH_SECRET"],
        },
        {
            "content": (
                "Database schema has users, posts, and comments tables. "
                "Users table has id, email, password_hash, created_at columns. "
                "Posts reference users via author_id foreign key."
            ),
            "source_key": "schema.sql",
            "stub": "[file_context | DB schema | users, posts, comments]",
            "key_entities": ["users", "posts", "comments", "schema.sql"],
        },
        {
            "content": (
                "Server runs on port 3000 by default. "
                "Configuration loaded from .env file. "
                "CORS enabled for localhost:5173 in development."
            ),
            "source_key": "src/config.ts",
            "stub": "[file_context | server config | port 3000, CORS]",
            "key_entities": ["port 3000", "CORS", ".env"],
        },
    ]
    seeded: list[StoredObject] = []
    # Store sequentially so the returned list order matches `specs`.
    for spec in specs:
        seeded.append(
            await store.store_object(
                session_id=session_id,
                object_type="file_context",
                source_tool="Read",
                **spec,
            )
        )
    return seeded
# ── MicroFaultResult dataclass ───────────────────────────────
def test_micro_fault_result_fields():
    """Every constructor argument of MicroFaultResult is exposed as an attribute."""
    expected = {
        "answer": "The answer is 42.",
        "sources": ["obj_abc", "obj_def"],
        "answer_tokens": 5,
        "avoided_tokens": 500,
        "latency_ms": 12.5,
    }
    result = MicroFaultResult(**expected)
    for attr_name, attr_value in expected.items():
        assert getattr(result, attr_name) == attr_value, attr_name
def test_micro_fault_result_token_savings():
    """avoided_tokens, relative to the total, should show a 99%+ saving over a full restore."""
    result = MicroFaultResult(
        answer="short answer",
        sources=["obj1"],
        answer_tokens=3,
        avoided_tokens=997,
        latency_ms=10.0,
    )
    # Share of the would-be cost (avoided + answer) that the micro-fault skipped.
    denominator = result.avoided_tokens + result.answer_tokens
    assert result.avoided_tokens / denominator > 0.99
# ── ContextWindow dataclass ──────────────────────────────────
def test_context_window_fields():
    """ContextWindow keeps the objects list, token total and pressure zone it was given."""
    entries = [("obj1", 0, "full content"), ("obj2", 2, "compact summary")]
    window = ContextWindow(objects=entries, total_tokens=150, pressure_zone="NORMAL")
    observed = (len(window.objects), window.total_tokens, window.pressure_zone)
    assert observed == (2, 150, "NORMAL")
# ── handle_micro_fault with HelperLLM ────────────────────────
@pytest.mark.asyncio
async def test_micro_fault_with_helper(assembler, object_store, mock_helper_llm):
    """With a HelperLLM, a micro-fault searches, asks the helper, and returns its answer."""
    question = "What auth library is used?"
    await _seed_objects(object_store, "sess_test")

    result = await assembler.handle_micro_fault(session_id="sess_test", question=question)

    assert isinstance(result, MicroFaultResult)
    assert result.answer == "The auth library used is passport.js with JWT strategy."
    assert len(result.sources) > 0
    assert result.answer_tokens > 0
    assert result.avoided_tokens > 0
    assert result.latency_ms >= 0

    # The helper must have been consulted exactly once, with the question
    # and at least one piece of retrieved content.
    helper_method = mock_helper_llm.answer_micro_fault
    helper_method.assert_called_once()
    passed = helper_method.call_args.kwargs
    assert passed["question"] == question
    assert len(passed["relevant_contents"]) > 0
@pytest.mark.asyncio
async def test_micro_fault_with_scope(assembler, object_store, mock_helper_llm):
    """A scope hint is accepted alongside the question and the fault still resolves.

    NOTE(review): the intent is that the scope feeds into the search query;
    this test only checks that sources are found and the helper is called once.
    """
    await _seed_objects(object_store, "sess_test")
    result = await assembler.handle_micro_fault(
        session_id="sess_test", question="What port?", scope="config files"
    )
    assert isinstance(result, MicroFaultResult)
    assert len(result.sources) > 0
    mock_helper_llm.answer_micro_fault.assert_called_once()
@pytest.mark.asyncio
async def test_micro_fault_with_max_tokens(assembler, object_store, mock_helper_llm):
    """Micro-fault should pass max_tokens through to the HelperLLM.

    The helper call is asserted before its ``call_args`` are inspected.
    ``call_args`` is ``None`` when the mock was never called, so skipping this
    check would surface as an opaque ``AttributeError`` rather than a clear
    assertion failure.
    """
    await _seed_objects(object_store, "sess_test")
    await assembler.handle_micro_fault(
        session_id="sess_test",
        question="What is the DB schema?",
        max_tokens=100,
    )
    helper_method = mock_helper_llm.answer_micro_fault
    helper_method.assert_called_once()
    assert helper_method.call_args.kwargs["max_tokens"] == 100
# ── handle_micro_fault fallback (no HelperLLM) ──────────────
@pytest.mark.asyncio
async def test_micro_fault_fallback_no_helper(assembler_no_helper, object_store):
    """Without a HelperLLM, the micro-fault falls back to a summary-based answer."""
    seeded = await _seed_objects(object_store, "sess_test")
    # Move the auth object to level 2 with a summary so the fallback has text to show.
    auth_obj = seeded[0]
    await object_store.update_fidelity(auth_obj.id, 2, summary="Auth uses passport.js JWT")

    result = await assembler_no_helper.handle_micro_fault(
        session_id="sess_test", question="What auth library is used?"
    )

    assert isinstance(result, MicroFaultResult)
    assert "HelperLLM unavailable" in result.answer
    assert len(result.sources) > 0
    assert result.answer_tokens > 0
@pytest.mark.asyncio
async def test_micro_fault_fallback_uses_stub_when_no_summary(assembler_no_helper, object_store):
    """Fallback should use stub when no summaries are available."""
    objects = await _seed_objects(object_store, "sess_test")
    result = await assembler_no_helper.handle_micro_fault(
        session_id="sess_test",
        question="What auth library is used?",
    )
    assert isinstance(result, MicroFaultResult)
    assert "HelperLLM unavailable" in result.answer
    assert len(result.sources) > 0
    # No summaries were set, so the fallback text has to come from the stubs.
    # NOTE(review): assumes the fallback embeds stub text verbatim; confirm
    # against ContextAssembler's fallback formatting.
    seeded_stubs = [obj.stub for obj in objects]
    assert any(stub in result.answer for stub in seeded_stubs), result.answer
# ── handle_micro_fault with no search results ────────────────
@pytest.mark.asyncio
async def test_micro_fault_no_results(assembler):
    """An unseeded session yields a 'not found' answer with no sources and no savings."""
    # Nothing is stored for this session, so the search has no candidates.
    result = await assembler.handle_micro_fault(
        session_id="sess_empty", question="What is the meaning of life?"
    )
    assert isinstance(result, MicroFaultResult)
    assert "No relevant content found" in result.answer
    assert result.sources == []
    assert result.avoided_tokens == 0
# ── Micro-fault records access on consulted objects ──────────
@pytest.mark.asyncio
async def test_micro_fault_records_access(assembler, object_store):
    """Micro-fault should record a micro access on every consulted object.

    Checked via ``micro_fault_count``, which should be incremented by
    ``record_fault(is_micro=True)`` in the implementation.
    """
    # Only the seeding side effect matters; the returned objects are not needed.
    await _seed_objects(object_store, "sess_test")
    result = await assembler.handle_micro_fault(
        session_id="sess_test",
        question="What auth library is used?",
    )
    # At least one source should have been consulted.
    assert len(result.sources) > 0
    # Each consulted object must still exist and show a non-zero micro-fault count.
    for source_id in result.sources:
        obj = await object_store.get(source_id)
        assert obj is not None
        assert obj.micro_fault_count > 0
# ── assemble_context ─────────────────────────────────────────
@pytest.mark.asyncio
async def test_assemble_context_all_l0(assembler, object_store):
    """When every registered object is at L0, each block carries its full content."""
    seeded = await _seed_objects(object_store, "sess_test")

    manager = FidelityManager(window_size=200_000)
    for stored_obj in seeded:
        tracked = make_object(
            object_type=stored_obj.object_type,
            content_full=stored_obj.content_full,
            stub=stored_obj.stub,
        )
        # Reuse the store's ID so manager entries line up with stored objects.
        tracked.id = stored_obj.id
        manager.register_object(tracked)

    blocks = await assembler.assemble_context(
        session_id="sess_test", fidelity_manager=manager, current_turn=1
    )

    assert len(blocks) == 3
    for entry in blocks:
        assert entry["fidelity"] == 0  # L0
        # Each seeded content string is well over 50 characters, so this
        # distinguishes full content from a stub or summary.
        assert len(entry["content"]) > 50
        assert entry["tokens"] > 0
@pytest.mark.asyncio
async def test_assemble_context_mixed_fidelity(assembler, object_store):
    """assemble_context should respect per-object fidelity levels (L0, L2, L3)."""
    objects = await _seed_objects(object_store, "sess_test")
    # Set a different fidelity level for each object in the store.
    await object_store.update_fidelity(
        objects[0].id,
        0,  # L0: full
    )
    await object_store.update_fidelity(
        objects[1].id,
        2,
        summary="DB has users, posts, comments tables",  # L2: compact
    )
    await object_store.update_fidelity(
        objects[2].id,
        3,  # L3: stub
    )
    fm = FidelityManager(window_size=200_000)
    for obj in objects:
        fm_obj = make_object(
            object_type=obj.object_type,
            content_full=obj.content_full,
            stub=obj.stub,
        )
        fm_obj.id = obj.id
        # Mirror the store's fidelity (and compact summary, if any) into the manager.
        stored = await object_store.get(obj.id)
        # Fail with a targeted message rather than an AttributeError on None.
        assert stored is not None, f"seeded object {obj.id} missing from store"
        fm_obj.current_fidelity = FidelityLevel(stored.current_fidelity)
        if stored.summary_compact:
            fm_obj.summary_compact = stored.summary_compact
        fm.register_object(fm_obj)
    blocks = await assembler.assemble_context(
        session_id="sess_test",
        fidelity_manager=fm,
        current_turn=5,
    )
    assert len(blocks) == 3
    fidelities = {b["object_id"]: b["fidelity"] for b in blocks}
    assert fidelities[objects[0].id] == 0  # L0
    assert fidelities[objects[1].id] == 2  # L2
    assert fidelities[objects[2].id] == 3  # L3
@pytest.mark.asyncio
async def test_assemble_context_excludes_evicted(assembler, object_store):
    """assemble_context should exclude L4 (evicted) objects."""
    objects = await _seed_objects(object_store, "sess_test")
    # Evict the DB-schema object (level 4); the other two stay at their defaults.
    await object_store.update_fidelity(objects[1].id, 4)
    fm = FidelityManager(window_size=200_000)
    for obj in objects:
        fm_obj = make_object(
            object_type=obj.object_type,
            content_full=obj.content_full,
            stub=obj.stub,
        )
        fm_obj.id = obj.id
        stored = await object_store.get(obj.id)
        # Fail with a targeted message rather than an AttributeError on None.
        assert stored is not None, f"seeded object {obj.id} missing from store"
        # Clamp to 4 (L4, evicted) before building the FidelityLevel.
        fm_obj.current_fidelity = FidelityLevel(min(stored.current_fidelity, 4))
        fm.register_object(fm_obj)
    blocks = await assembler.assemble_context(
        session_id="sess_test",
        fidelity_manager=fm,
        current_turn=5,
    )
    # Only 2 objects should be in context (one was evicted).
    block_ids = {b["object_id"] for b in blocks}
    assert objects[1].id not in block_ids
    assert len(blocks) == 2
# ── _estimate_tokens helper ──────────────────────────────────
def test_estimate_tokens_none():
    """A None input is counted as zero tokens."""
    estimated = _estimate_tokens(None)
    assert estimated == 0
def test_estimate_tokens_empty():
    """An empty string still counts as one token, the minimum estimate."""
    estimated = _estimate_tokens("")
    assert estimated == 1
def test_estimate_tokens_normal():
    """400 characters map to 100 tokens under the ~4-chars-per-token heuristic."""
    chars_per_token = 4
    sample = "a" * (100 * chars_per_token)
    assert _estimate_tokens(sample) == len(sample) // chars_per_token