Object-addressed memory: segment messages into semantic objects, embed with sentence-transformers, store in pgvector-backed store, and reassemble context via goal-aware retrieval. Ultraworked with [Sisyphus](https://github.com/code-yeongyu/oh-my-opencode) Co-authored-by: Sisyphus <clio-agent@sisyphuslabs.ai>
408 lines
14 KiB
Python
408 lines
14 KiB
Python
"""Tests for the ContextAssembler micro-fault and context assembly module."""
|
|
|
|
from __future__ import annotations
|
|
|
|
import asyncio
|
|
from unittest.mock import AsyncMock, MagicMock, patch
|
|
|
|
import pytest
|
|
|
|
from mnemosyne.context_assembler import (
|
|
ContextAssembler,
|
|
ContextWindow,
|
|
MicroFaultResult,
|
|
_estimate_tokens,
|
|
)
|
|
from mnemosyne.fidelity import FidelityLevel, FidelityManager, make_object
|
|
from mnemosyne.object_store import (
|
|
DummyEmbedder,
|
|
InMemoryBackend,
|
|
ObjectStore,
|
|
StoredObject,
|
|
)
|
|
|
|
|
|
# ── Fixtures ─────────────────────────────────────────────────
|
|
|
|
|
|
@pytest.fixture
def object_store():
    """Provide an ObjectStore built on an InMemoryBackend and a DummyEmbedder."""
    # Backend is constructed before the embedder, matching keyword order.
    return ObjectStore(backend=InMemoryBackend(), embedder=DummyEmbedder())
|
|
|
|
|
|
@pytest.fixture
def mock_helper_llm():
    """Provide an AsyncMock helper whose answer_micro_fault yields a fixed answer."""
    canned_answer = "The auth library used is passport.js with JWT strategy."
    llm = AsyncMock()
    # Explicit child mock so awaiting answer_micro_fault returns the canned text.
    llm.answer_micro_fault = AsyncMock(return_value=canned_answer)
    return llm
|
|
|
|
|
|
@pytest.fixture
def assembler(object_store, mock_helper_llm):
    """Provide a ContextAssembler wired to the store fixture and the mocked helper."""
    wired = ContextAssembler(
        object_store=object_store,
        helper_llm=mock_helper_llm,
    )
    return wired
|
|
|
|
|
|
@pytest.fixture
def assembler_no_helper(object_store):
    """Provide a ContextAssembler constructed with helper_llm=None (fallback mode)."""
    helperless = ContextAssembler(
        object_store=object_store,
        helper_llm=None,
    )
    return helperless
|
|
|
|
|
|
async def _seed_objects(store: ObjectStore, session_id: str = "sess_test") -> list[StoredObject]:
    """Store three file_context test objects in *store* and return them in insertion order."""
    # Per-object keyword arguments; session_id is shared and supplied at call time.
    specs = (
        {
            "content": (
                "Authentication uses passport.js with JWT strategy. "
                "The secret key is loaded from AUTH_SECRET env var. "
                "Token expiry is set to 24 hours."
            ),
            "object_type": "file_context",
            "source_tool": "Read",
            "source_key": "src/auth/middleware.ts",
            "stub": "[file_context | auth middleware | passport.js JWT]",
            "key_entities": ["passport.js", "JWT", "AUTH_SECRET"],
        },
        {
            "content": (
                "Database schema has users, posts, and comments tables. "
                "Users table has id, email, password_hash, created_at columns. "
                "Posts reference users via author_id foreign key."
            ),
            "object_type": "file_context",
            "source_tool": "Read",
            "source_key": "schema.sql",
            "stub": "[file_context | DB schema | users, posts, comments]",
            "key_entities": ["users", "posts", "comments", "schema.sql"],
        },
        {
            "content": (
                "Server runs on port 3000 by default. "
                "Configuration loaded from .env file. "
                "CORS enabled for localhost:5173 in development."
            ),
            "object_type": "file_context",
            "source_tool": "Read",
            "source_key": "src/config.ts",
            "stub": "[file_context | server config | port 3000, CORS]",
            "key_entities": ["port 3000", "CORS", ".env"],
        },
    )

    seeded = []
    for spec in specs:
        # Awaited one at a time so objects are stored in the listed order.
        seeded.append(await store.store_object(session_id=session_id, **spec))
    return seeded
|
|
|
|
|
|
# ── MicroFaultResult dataclass ───────────────────────────────
|
|
|
|
|
|
def test_micro_fault_result_fields():
    """Each value passed to MicroFaultResult should be readable back from its attribute."""
    result = MicroFaultResult(
        answer="The answer is 42.",
        sources=["obj_abc", "obj_def"],
        answer_tokens=5,
        avoided_tokens=500,
        latency_ms=12.5,
    )

    # Compare all five attributes in one shot against fresh literals.
    observed = (
        result.answer,
        result.sources,
        result.answer_tokens,
        result.avoided_tokens,
        result.latency_ms,
    )
    assert observed == ("The answer is 42.", ["obj_abc", "obj_def"], 5, 500, 12.5)
|
|
|
|
|
|
def test_micro_fault_result_token_savings():
    """avoided_tokens relative to answer_tokens should express a 99%+ saving here."""
    result = MicroFaultResult(
        answer="short answer",
        sources=["obj1"],
        answer_tokens=3,
        avoided_tokens=997,
        latency_ms=10.0,
    )

    saved = result.avoided_tokens
    spent = result.answer_tokens
    # Savings ratio is avoided / (avoided + answer_tokens); 997 / 1000 here.
    assert saved / (saved + spent) > 0.99
|
|
|
|
|
|
# ── ContextWindow dataclass ──────────────────────────────────
|
|
|
|
|
|
def test_context_window_fields():
    """ContextWindow should expose the objects, total_tokens and pressure_zone it was given."""
    entries = [("obj1", 0, "full content"), ("obj2", 2, "compact summary")]
    window = ContextWindow(objects=entries, total_tokens=150, pressure_zone="NORMAL")

    summary = (len(window.objects), window.total_tokens, window.pressure_zone)
    assert summary == (2, 150, "NORMAL")
|
|
|
|
|
|
# ── handle_micro_fault with HelperLLM ────────────────────────
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_micro_fault_with_helper(assembler, object_store, mock_helper_llm):
    """Micro-fault with HelperLLM should search, await the helper, and return its answer.

    Seeds the store, raises a micro-fault, then checks the returned
    MicroFaultResult and the keyword arguments the helper was awaited with.
    """
    await _seed_objects(object_store, "sess_test")

    result = await assembler.handle_micro_fault(
        session_id="sess_test",
        question="What auth library is used?",
    )

    assert isinstance(result, MicroFaultResult)
    assert result.answer == "The auth library used is passport.js with JWT strategy."
    assert len(result.sources) > 0
    assert result.answer_tokens > 0
    assert result.avoided_tokens > 0
    assert result.latency_ms >= 0

    # Use the await-aware assertions: assert_called_once() would also pass if
    # the coroutine had been created but never awaited.
    helper_call = mock_helper_llm.answer_micro_fault
    helper_call.assert_awaited_once()
    awaited_kwargs = helper_call.await_args.kwargs
    assert awaited_kwargs["question"] == "What auth library is used?"
    assert len(awaited_kwargs["relevant_contents"]) > 0
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_micro_fault_with_scope(assembler, object_store, mock_helper_llm):
    """Micro-fault given a scope hint should still yield sources and await the helper once.

    Only the result type, non-empty sources, and a single helper await are
    asserted; how the scope affects the search query is not inspected here.
    """
    await _seed_objects(object_store, "sess_test")

    result = await assembler.handle_micro_fault(
        session_id="sess_test",
        question="What port?",
        scope="config files",
    )

    assert isinstance(result, MicroFaultResult)
    assert len(result.sources) > 0
    # Await-aware check: assert_called_once() would pass even if the helper
    # coroutine were never awaited.
    mock_helper_llm.answer_micro_fault.assert_awaited_once()
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_micro_fault_with_max_tokens(assembler, object_store, mock_helper_llm):
    """The max_tokens given to handle_micro_fault should reach the HelperLLM call."""
    await _seed_objects(object_store, "sess_test")

    request = {
        "session_id": "sess_test",
        "question": "What is the DB schema?",
        "max_tokens": 100,
    }
    await assembler.handle_micro_fault(**request)

    # Inspect the keyword arguments of the most recent helper invocation.
    forwarded = mock_helper_llm.answer_micro_fault.call_args.kwargs
    assert forwarded["max_tokens"] == 100
|
|
|
|
|
|
# ── handle_micro_fault fallback (no HelperLLM) ──────────────
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_micro_fault_fallback_no_helper(assembler_no_helper, object_store):
    """With no HelperLLM configured, a micro-fault should return the fallback answer."""
    seeded = await _seed_objects(object_store, "sess_test")

    # Give the first object a summary so the fallback has something to show.
    first_id = seeded[0].id
    await object_store.update_fidelity(first_id, 2, summary="Auth uses passport.js JWT")

    result = await assembler_no_helper.handle_micro_fault(
        session_id="sess_test", question="What auth library is used?"
    )

    assert isinstance(result, MicroFaultResult)
    assert "HelperLLM unavailable" in result.answer
    assert len(result.sources) > 0
    assert result.answer_tokens > 0
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_micro_fault_fallback_uses_stub_when_no_summary(assembler_no_helper, object_store):
    """Fallback without any stored summaries should still answer and cite sources."""
    await _seed_objects(object_store, "sess_test")

    result = await assembler_no_helper.handle_micro_fault(
        session_id="sess_test", question="What auth library is used?"
    )

    assert isinstance(result, MicroFaultResult)
    assert "HelperLLM unavailable" in result.answer
    # No summaries were set, so the fallback is expected to draw on stubs;
    # only the presence of sources is asserted, not the stub text itself.
    assert len(result.sources) > 0
|
|
|
|
|
|
# ── handle_micro_fault with no search results ────────────────
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_micro_fault_no_results(assembler):
    """A micro-fault against an unseeded store should report that nothing was found."""
    # Nothing is seeded, so the store is empty for this session.
    query = {"session_id": "sess_empty", "question": "What is the meaning of life?"}
    result = await assembler.handle_micro_fault(**query)

    assert isinstance(result, MicroFaultResult)
    assert "No relevant content found" in result.answer
    assert result.sources == []
    assert result.avoided_tokens == 0
|
|
|
|
|
|
# ── Micro-fault records access on consulted objects ──────────
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_micro_fault_records_access(assembler, object_store):
    """Micro-fault should bump micro_fault_count on every object it reports as a source."""
    # The seeded objects are looked up again by id below, so the return
    # value of _seed_objects is not needed here.
    await _seed_objects(object_store, "sess_test")

    result = await assembler.handle_micro_fault(
        session_id="sess_test",
        question="What auth library is used?",
    )

    # At least one source should have been consulted.
    assert len(result.sources) > 0

    # Every consulted object must still exist and show a recorded micro-fault.
    for source_id in result.sources:
        obj = await object_store.get(source_id)
        assert obj is not None
        assert obj.micro_fault_count > 0
|
|
|
|
|
|
# ── assemble_context ─────────────────────────────────────────
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_assemble_context_all_l0(assembler, object_store):
    """When every object is at L0, each assembled block should carry full content."""
    seeded = await _seed_objects(object_store, "sess_test")

    manager = FidelityManager(window_size=200_000)
    for stored in seeded:
        tracked = make_object(
            object_type=stored.object_type,
            content_full=stored.content_full,
            stub=stored.stub,
        )
        # Reuse the store's id so manager and store refer to the same object.
        tracked.id = stored.id
        manager.register_object(tracked)

    blocks = await assembler.assemble_context(
        session_id="sess_test", fidelity_manager=manager, current_turn=1
    )

    assert len(blocks) == 3
    for entry in blocks:
        assert entry["fidelity"] == 0  # L0
        assert len(entry["content"]) > 50  # full content
        assert entry["tokens"] > 0
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_assemble_context_mixed_fidelity(assembler, object_store):
    """Assembled blocks should report the fidelity level assigned to each object."""
    seeded = await _seed_objects(object_store, "sess_test")

    # Assign a different fidelity level to each object in the store.
    await object_store.update_fidelity(seeded[0].id, 0)  # L0: full
    await object_store.update_fidelity(
        seeded[1].id, 2, summary="DB has users, posts, comments tables"
    )  # L2: compact
    await object_store.update_fidelity(seeded[2].id, 3)  # L3: stub

    manager = FidelityManager(window_size=200_000)
    for original in seeded:
        tracked = make_object(
            object_type=original.object_type,
            content_full=original.content_full,
            stub=original.stub,
        )
        tracked.id = original.id
        # Mirror the store's current fidelity (and summary, if any) in the manager.
        refreshed = await object_store.get(original.id)
        tracked.current_fidelity = FidelityLevel(refreshed.current_fidelity)
        if refreshed.summary_compact:
            tracked.summary_compact = refreshed.summary_compact
        manager.register_object(tracked)

    blocks = await assembler.assemble_context(
        session_id="sess_test", fidelity_manager=manager, current_turn=5
    )

    assert len(blocks) == 3
    level_by_id = {entry["object_id"]: entry["fidelity"] for entry in blocks}
    # Expected levels in seeding order: L0, L2, L3.
    for original, expected_level in zip(seeded, (0, 2, 3)):
        assert level_by_id[original.id] == expected_level
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_assemble_context_excludes_evicted(assembler, object_store):
    """An object set to L4 (evicted) should be absent from the assembled context."""
    seeded = await _seed_objects(object_store, "sess_test")
    evicted_id = seeded[1].id

    # Evict the second seeded object.
    await object_store.update_fidelity(evicted_id, 4)

    manager = FidelityManager(window_size=200_000)
    for original in seeded:
        tracked = make_object(
            object_type=original.object_type,
            content_full=original.content_full,
            stub=original.stub,
        )
        tracked.id = original.id
        refreshed = await object_store.get(original.id)
        # Cap at 4 before converting to a FidelityLevel.
        capped_level = min(refreshed.current_fidelity, 4)
        tracked.current_fidelity = FidelityLevel(capped_level)
        manager.register_object(tracked)

    blocks = await assembler.assemble_context(
        session_id="sess_test", fidelity_manager=manager, current_turn=5
    )

    # Only the two non-evicted objects should remain in context.
    present_ids = {entry["object_id"] for entry in blocks}
    assert evicted_id not in present_ids
    assert len(blocks) == 2
|
|
|
|
|
|
# ── _estimate_tokens helper ──────────────────────────────────
|
|
|
|
|
|
def test_estimate_tokens_none():
    """Passing None to _estimate_tokens should give a count of 0."""
    token_count = _estimate_tokens(None)
    assert token_count == 0
|
|
|
|
|
|
def test_estimate_tokens_empty():
    """Passing an empty string to _estimate_tokens should give the minimum count of 1."""
    token_count = _estimate_tokens("")
    assert token_count == 1
|
|
|
|
|
|
def test_estimate_tokens_normal():
    """A 400-character string should be estimated at 100 tokens (~4 chars per token)."""
    assert _estimate_tokens("a" * 400) == 100
|