"""Tests for the PostgreSQL + pgvector backend. Unit tests use mocked asyncpg connections to verify SQL generation and data conversion without requiring a database. Integration tests require a running PostgreSQL instance (Docker on port 5433) and are skipped automatically when unavailable. """ from __future__ import annotations import asyncio import json import os import socket import uuid from datetime import datetime, timezone from typing import Any from unittest.mock import AsyncMock, MagicMock, patch import numpy as np import pytest from mnemosyne.object_store import StoredObject, _estimate_tokens from mnemosyne.pgvector_backend import ( PgVectorBackend, _parse_jsonb, _parse_timestamp, _row_to_stored_object, ) class _MockPool: """A mock asyncpg pool that properly supports async context manager on acquire().""" def __init__(self, conn: AsyncMock): self._conn = conn def acquire(self): return _MockAcquire(self._conn) class _MockAcquire: """Async context manager returned by pool.acquire().""" def __init__(self, conn: AsyncMock): self._conn = conn async def __aenter__(self): return self._conn async def __aexit__(self, *args): return False def _make_mock_pool(conn: AsyncMock) -> _MockPool: """Create a mock pool with proper async context manager support.""" return _MockPool(conn) # ── Helpers ────────────────────────────────────────────────── def _make_stored_object( session_id: str = "test-session", content: str = "Test content for pgvector backend", *, object_type: str = "file_context", source_tool: str | None = "Read", source_key: str | None = None, stub: str | None = None, embedding: list[float] | None = None, object_id: str | None = None, ) -> StoredObject: """Create a StoredObject with sensible defaults for testing.""" oid = object_id or uuid.uuid4().hex return StoredObject( id=oid, session_id=session_id, object_type=object_type, source_tool=source_tool, source_key=source_key, content_full=content, summary_detailed=None, summary_compact=None, stub=stub or f"{object_type}: test object", tokens_l0=_estimate_tokens(content), tokens_l3=_estimate_tokens(stub or f"{object_type}: test object"), embedding=embedding or [0.1] * 384, created_at="2025-01-01T00:00:00+00:00", last_accessed="2025-01-01T00:00:00+00:00", ) def _make_mock_row( obj: StoredObject | None = None, *, session_external_id: str = "test-session", similarity: float | None = None, ) -> MagicMock: """Create a mock asyncpg.Record from a StoredObject.""" if obj is None: obj = _make_stored_object() row = MagicMock() row_data: dict[str, Any] = { "id": uuid.UUID(obj.id) if len(obj.id) == 32 else uuid.uuid4(), "session_id": uuid.uuid4(), "session_external_id": session_external_id, "object_type": obj.object_type, "source_tool": obj.source_tool, "source_key": obj.source_key, "content_full": obj.content_full, "summary_detailed": obj.summary_detailed, "summary_compact": obj.summary_compact, "stub": obj.stub, "losses_l1": obj.losses_l1, "losses_l2": obj.losses_l2, "can_answer_l1": obj.can_answer_l1, "can_answer_l2": obj.can_answer_l2, "fault_when": obj.fault_when, "key_entities": obj.key_entities, "tags": obj.tags, "current_fidelity": obj.current_fidelity, "pinned": obj.pinned, "tokens_l0": obj.tokens_l0, "tokens_l1": obj.tokens_l1, "tokens_l2": obj.tokens_l2, "tokens_l3": obj.tokens_l3, "source_turn_start": obj.source_turn_start, "source_turn_end": obj.source_turn_end, "embedding": np.array(obj.embedding, dtype=np.float32) if obj.embedding else None, "created_at": datetime(2025, 1, 1, tzinfo=timezone.utc), "last_accessed": datetime(2025, 1, 1, tzinfo=timezone.utc), "access_count": obj.access_count, "fault_count": obj.fault_count, "micro_fault_count": obj.micro_fault_count, } if similarity is not None: row_data["similarity"] = similarity row.__getitem__ = lambda self, key: row_data[key] return row def _pg_available() -> bool: """Check if PostgreSQL is reachable on port 5433.""" try: with socket.create_connection(("localhost", 5433), timeout=1): return True except (OSError, ConnectionRefusedError): return False # ── Unit Tests: Data Conversion ────────────────────────────── class TestParseJsonb: def test_none_returns_empty_list(self): assert _parse_jsonb(None) == [] def test_list_passthrough(self): assert _parse_jsonb(["a", "b", "c"]) == ["a", "b", "c"] def test_list_converts_to_strings(self): assert _parse_jsonb([1, 2, 3]) == ["1", "2", "3"] def test_json_string(self): assert _parse_jsonb('["x", "y"]') == ["x", "y"] def test_invalid_json_string(self): assert _parse_jsonb("not json") == [] def test_empty_list(self): assert _parse_jsonb([]) == [] class TestParseTimestamp: def test_iso_format(self): dt = _parse_timestamp("2025-01-01T00:00:00+00:00") assert dt.year == 2025 assert dt.tzinfo is not None def test_empty_string_returns_now(self): dt = _parse_timestamp("") assert dt.tzinfo is not None # Should be close to now diff = abs((datetime.now(timezone.utc) - dt).total_seconds()) assert diff < 5 def test_naive_timestamp_gets_utc(self): dt = _parse_timestamp("2025-06-15T12:00:00") assert dt.tzinfo == timezone.utc def test_invalid_returns_now(self): dt = _parse_timestamp("not-a-date") assert dt.tzinfo is not None class TestRowToStoredObject: def test_basic_conversion(self): obj = _make_stored_object() row = _make_mock_row(obj) result = _row_to_stored_object(row) assert result.object_type == "file_context" assert result.content_full == obj.content_full assert result.stub == obj.stub assert result.session_id == "test-session" assert result.current_fidelity == 0 assert result.pinned is False def test_embedding_conversion(self): obj = _make_stored_object(embedding=[0.5] * 384) row = _make_mock_row(obj) result = _row_to_stored_object(row) assert len(result.embedding) == 384 assert abs(result.embedding[0] - 0.5) < 1e-6 def test_jsonb_fields_parsed(self): obj = _make_stored_object() obj.losses_l1 = ["detail_a", "detail_b"] obj.key_entities = ["src/main.py", "Config"] row = _make_mock_row(obj) result = _row_to_stored_object(row) assert result.losses_l1 == ["detail_a", "detail_b"] assert result.key_entities == ["src/main.py", "Config"] def test_tags_conversion(self): obj = _make_stored_object() obj.tags = ["auth", "middleware"] row = _make_mock_row(obj) result = _row_to_stored_object(row) assert result.tags == ["auth", "middleware"] def test_null_embedding(self): obj = _make_stored_object(embedding=[]) row_data: dict[str, Any] = { "id": uuid.uuid4(), "session_id": uuid.uuid4(), "session_external_id": "test-session", "object_type": obj.object_type, "source_tool": obj.source_tool, "source_key": obj.source_key, "content_full": obj.content_full, "summary_detailed": obj.summary_detailed, "summary_compact": obj.summary_compact, "stub": obj.stub, "losses_l1": [], "losses_l2": [], "can_answer_l1": [], "can_answer_l2": [], "fault_when": [], "key_entities": [], "tags": [], "current_fidelity": 0, "pinned": False, "tokens_l0": obj.tokens_l0, "tokens_l1": None, "tokens_l2": None, "tokens_l3": obj.tokens_l3, "source_turn_start": None, "source_turn_end": None, "embedding": None, "created_at": datetime(2025, 1, 1, tzinfo=timezone.utc), "last_accessed": datetime(2025, 1, 1, tzinfo=timezone.utc), "access_count": 0, "fault_count": 0, "micro_fault_count": 0, } row = MagicMock() row.__getitem__ = lambda self, key: row_data[key] result = _row_to_stored_object(row) assert result.embedding == [] # ── Unit Tests: Backend Methods (Mocked DB) ───────────────── class TestPgVectorBackendInit: def test_default_config(self): backend = PgVectorBackend() assert backend._host == "localhost" assert backend._port == 5433 assert backend._database == "mnemosyne" assert backend._pool is None def test_custom_config(self): backend = PgVectorBackend( host="db.example.com", port=5432, database="mydb", user="myuser", password="secret", min_connections=5, max_connections=20, ) assert backend._host == "db.example.com" assert backend._port == 5432 assert backend._min_connections == 5 assert backend._max_connections == 20 def test_get_pool_raises_when_not_connected(self): backend = PgVectorBackend() with pytest.raises(RuntimeError, match="not connected"): backend._get_pool() class TestPgVectorBackendStore: """Test store() with mocked pool.""" async def test_store_calls_execute(self): backend = PgVectorBackend() mock_conn = AsyncMock() mock_conn.execute = AsyncMock() mock_conn.fetchrow = AsyncMock(return_value={"id": uuid.uuid4()}) backend._pool = _make_mock_pool(mock_conn) obj = _make_stored_object(object_id=uuid.uuid4().hex) await backend.store(obj) # Should have called execute for the INSERT (store) and fetchrow for session upsert assert mock_conn.execute.called or mock_conn.fetchrow.called async def test_store_creates_session_if_needed(self): backend = PgVectorBackend() session_uuid = uuid.uuid4() mock_conn = AsyncMock() mock_conn.fetchrow = AsyncMock(return_value={"id": session_uuid}) mock_conn.execute = AsyncMock() backend._pool = _make_mock_pool(mock_conn) obj = _make_stored_object(object_id=uuid.uuid4().hex) await backend.store(obj) # Session should be cached after creation assert "test-session" in backend._session_cache class TestPgVectorBackendGet: """Test get() with mocked pool.""" async def test_get_returns_none_when_not_found(self): backend = PgVectorBackend() mock_conn = AsyncMock() mock_conn.fetchrow = AsyncMock(return_value=None) backend._pool = _make_mock_pool(mock_conn) result = await backend.get(uuid.uuid4().hex) assert result is None async def test_get_returns_stored_object(self): backend = PgVectorBackend() obj = _make_stored_object(object_id=uuid.uuid4().hex) mock_row = _make_mock_row(obj) mock_conn = AsyncMock() mock_conn.fetchrow = AsyncMock(return_value=mock_row) backend._pool = _make_mock_pool(mock_conn) result = await backend.get(obj.id) assert result is not None assert result.content_full == obj.content_full class TestPgVectorBackendGetBySession: """Test get_by_session() with mocked pool.""" async def test_returns_empty_for_unknown_session(self): backend = PgVectorBackend() mock_conn = AsyncMock() mock_conn.fetchrow = AsyncMock(return_value=None) backend._pool = _make_mock_pool(mock_conn) result = await backend.get_by_session("nonexistent") assert result == [] class TestPgVectorBackendUpdateFidelity: """Test update_fidelity() with mocked pool.""" async def test_update_fidelity_basic(self): backend = PgVectorBackend() mock_conn = AsyncMock() mock_conn.execute = AsyncMock() backend._pool = _make_mock_pool(mock_conn) await backend.update_fidelity(uuid.uuid4().hex, 1) assert mock_conn.execute.called async def test_update_fidelity_with_summary_l1(self): backend = PgVectorBackend() mock_conn = AsyncMock() mock_conn.execute = AsyncMock() backend._pool = _make_mock_pool(mock_conn) await backend.update_fidelity( uuid.uuid4().hex, 1, summary="A summary", losses=["lost detail"] ) # Verify the SQL includes summary_detailed and losses_l1 call_args = mock_conn.execute.call_args query = call_args[0][0] assert "summary_detailed" in query assert "losses_l1" in query async def test_update_fidelity_with_summary_l2(self): backend = PgVectorBackend() mock_conn = AsyncMock() mock_conn.execute = AsyncMock() backend._pool = _make_mock_pool(mock_conn) await backend.update_fidelity(uuid.uuid4().hex, 2, summary="Compact", losses=["more lost"]) call_args = mock_conn.execute.call_args query = call_args[0][0] assert "summary_compact" in query assert "losses_l2" in query class TestPgVectorBackendSearchByEmbedding: """Test search_by_embedding() with mocked pool.""" async def test_returns_empty_for_unknown_session(self): backend = PgVectorBackend() mock_conn = AsyncMock() mock_conn.fetchrow = AsyncMock(return_value=None) backend._pool = _make_mock_pool(mock_conn) result = await backend.search_by_embedding("nonexistent", [0.1] * 384) assert result == [] class TestPgVectorBackendSearchByText: """Test search_by_text() with mocked pool.""" async def test_returns_empty_for_unknown_session(self): backend = PgVectorBackend() mock_conn = AsyncMock() mock_conn.fetchrow = AsyncMock(return_value=None) backend._pool = _make_mock_pool(mock_conn) result = await backend.search_by_text("nonexistent", "test query") assert result == [] class TestPgVectorBackendDeleteSession: """Test delete_session() with mocked pool.""" async def test_returns_zero_for_unknown_session(self): backend = PgVectorBackend() mock_conn = AsyncMock() mock_conn.fetchrow = AsyncMock(return_value=None) backend._pool = _make_mock_pool(mock_conn) result = await backend.delete_session("nonexistent") assert result == 0 class TestPgVectorBackendGetBySourceKey: """Test get_by_source_key() with mocked pool.""" async def test_returns_none_for_unknown_session(self): backend = PgVectorBackend() mock_conn = AsyncMock() mock_conn.fetchrow = AsyncMock(return_value=None) backend._pool = _make_mock_pool(mock_conn) result = await backend.get_by_source_key("nonexistent", "src/main.py") assert result is None class TestPgVectorBackendHealthCheck: """Test health_check().""" async def test_returns_false_when_not_connected(self): backend = PgVectorBackend() assert await backend.health_check() is False async def test_returns_true_with_healthy_pool(self): backend = PgVectorBackend() mock_conn = AsyncMock() mock_conn.fetchval = AsyncMock(return_value=1) backend._pool = _make_mock_pool(mock_conn) assert await backend.health_check() is True async def test_returns_false_on_error(self): backend = PgVectorBackend() mock_conn = AsyncMock() mock_conn.fetchval = AsyncMock(side_effect=Exception("connection lost")) backend._pool = _make_mock_pool(mock_conn) assert await backend.health_check() is False class TestPgVectorBackendSessionCache: """Test session ID caching behavior.""" async def test_session_cache_populated_on_ensure(self): backend = PgVectorBackend() session_uuid = uuid.uuid4() mock_conn = AsyncMock() mock_conn.fetchrow = AsyncMock(return_value={"id": session_uuid}) backend._pool = _make_mock_pool(mock_conn) result = await backend.ensure_session("my-session", "claude-opus-4") assert result == session_uuid assert backend._session_cache["my-session"] == session_uuid async def test_session_cache_avoids_db_lookup(self): backend = PgVectorBackend() session_uuid = uuid.uuid4() backend._session_cache["cached-session"] = session_uuid # Pool shouldn't be needed since cache is populated backend._pool = _make_mock_pool(AsyncMock()) result = await backend._resolve_session_id("cached-session") assert result == session_uuid async def test_close_clears_cache(self): backend = PgVectorBackend() backend._session_cache["test"] = uuid.uuid4() mock_pool = AsyncMock() mock_pool.close = AsyncMock() backend._pool = mock_pool await backend.close() assert len(backend._session_cache) == 0 assert backend._pool is None # ── Integration Tests (require Docker PostgreSQL) ──────────── _skip_no_pg = pytest.mark.skipif( not _pg_available(), reason="PostgreSQL not available on localhost:5433 (run: docker compose up -d)", ) @_skip_no_pg class TestPgVectorBackendIntegration: """Integration tests against a real PostgreSQL instance. These tests require Docker PostgreSQL running on port 5433 with the schema from sql/init.sql applied. Start with: cd ~/Projects/contextmanager && docker compose up -d """ @pytest.fixture async def backend(self): """Create a connected backend and clean up after test.""" b = PgVectorBackend( host="localhost", port=5433, database="mnemosyne", user="mnemosyne", password="mnemosyne_dev", min_connections=1, max_connections=3, ) await b.connect() # Create a unique session for this test test_session = f"integration-test-{uuid.uuid4().hex[:8]}" yield b, test_session # Cleanup: delete test session data try: session_uuid = await b._resolve_session_id(test_session) if session_uuid is not None: pool = b._get_pool() async with pool.acquire() as conn: await conn.execute( "DELETE FROM semantic_objects WHERE session_id = $1", session_uuid, ) await conn.execute( "DELETE FROM sessions WHERE id = $1", session_uuid, ) except Exception: pass await b.close() async def test_health_check(self, backend): b, _ = backend assert await b.health_check() is True async def test_ensure_session(self, backend): b, test_session = backend session_uuid = await b.ensure_session(test_session, "claude-opus-4") assert isinstance(session_uuid, uuid.UUID) # Second call should return same UUID session_uuid2 = await b.ensure_session(test_session, "claude-opus-4") assert session_uuid == session_uuid2 async def test_store_and_get(self, backend): b, test_session = backend await b.ensure_session(test_session) obj_id = uuid.uuid4().hex obj = _make_stored_object( session_id=test_session, content="Integration test content", object_id=obj_id, ) await b.store(obj) retrieved = await b.get(obj_id) assert retrieved is not None assert retrieved.content_full == "Integration test content" assert retrieved.session_id == test_session assert retrieved.object_type == "file_context" async def test_store_upsert(self, backend): b, test_session = backend await b.ensure_session(test_session) obj_id = uuid.uuid4().hex obj = _make_stored_object( session_id=test_session, content="Original content", object_id=obj_id, ) await b.store(obj) # Update the same object obj.content_full = "Updated content" await b.store(obj) retrieved = await b.get(obj_id) assert retrieved is not None assert retrieved.content_full == "Updated content" async def test_get_by_session(self, backend): b, test_session = backend await b.ensure_session(test_session) # Store 3 objects for i in range(3): obj = _make_stored_object( session_id=test_session, content=f"Content {i}", object_id=uuid.uuid4().hex, ) await b.store(obj) results = await b.get_by_session(test_session) assert len(results) == 3 async def test_get_by_session_fidelity_filter(self, backend): b, test_session = backend await b.ensure_session(test_session) # Store object at fidelity 0 obj0 = _make_stored_object( session_id=test_session, content="Fidelity 0", object_id=uuid.uuid4().hex, ) await b.store(obj0) # Store object at fidelity 4 (evicted) obj4 = _make_stored_object( session_id=test_session, content="Fidelity 4", object_id=uuid.uuid4().hex, ) obj4.current_fidelity = 4 await b.store(obj4) # Default: include all all_results = await b.get_by_session(test_session, fidelity_max=4) assert len(all_results) == 2 # Exclude evicted active_results = await b.get_by_session(test_session, fidelity_max=3) assert len(active_results) == 1 assert active_results[0].content_full == "Fidelity 0" async def test_update_fidelity(self, backend): b, test_session = backend await b.ensure_session(test_session) obj_id = uuid.uuid4().hex obj = _make_stored_object( session_id=test_session, content="Will be degraded", object_id=obj_id, ) await b.store(obj) await b.update_fidelity( obj_id, 1, summary="Detailed summary", losses=["lost some detail"], ) retrieved = await b.get(obj_id) assert retrieved is not None assert retrieved.current_fidelity == 1 assert retrieved.summary_detailed == "Detailed summary" async def test_search_by_embedding(self, backend): b, test_session = backend await b.ensure_session(test_session) # Store objects with different embeddings rng = np.random.default_rng(42) for i in range(5): vec = rng.standard_normal(384).astype(np.float32) vec = vec / np.linalg.norm(vec) obj = _make_stored_object( session_id=test_session, content=f"Embedding test {i}", object_id=uuid.uuid4().hex, embedding=vec.tolist(), ) await b.store(obj) # Search with the first object's embedding query_vec = rng.standard_normal(384).astype(np.float32) query_vec = query_vec / np.linalg.norm(query_vec) results = await b.search_by_embedding(test_session, query_vec.tolist(), limit=3) assert len(results) <= 3 # Results should be (StoredObject, float) tuples for obj, score in results: assert isinstance(obj, StoredObject) assert isinstance(score, float) assert -1.0 <= score <= 1.0 # Scores should be in descending order scores = [s for _, s in results] assert scores == sorted(scores, reverse=True) async def test_search_by_text(self, backend): b, test_session = backend await b.ensure_session(test_session) obj = _make_stored_object( session_id=test_session, content="The authentication middleware validates JWT tokens and checks expiration dates", object_id=uuid.uuid4().hex, ) await b.store(obj) obj2 = _make_stored_object( session_id=test_session, content="Database connection pooling configuration for PostgreSQL", object_id=uuid.uuid4().hex, ) await b.store(obj2) # Search for auth-related content results = await b.search_by_text(test_session, "authentication JWT tokens") assert len(results) >= 1 assert any("authentication" in r.content_full for r in results) async def test_delete_session(self, backend): b, test_session = backend await b.ensure_session(test_session) # Store some objects for i in range(3): obj = _make_stored_object( session_id=test_session, content=f"Delete test {i}", object_id=uuid.uuid4().hex, ) await b.store(obj) count = await b.delete_session(test_session) assert count == 3 # Verify they're gone results = await b.get_by_session(test_session) assert len(results) == 0 async def test_get_by_source_key(self, backend): b, test_session = backend await b.ensure_session(test_session) obj = _make_stored_object( session_id=test_session, content="File content", source_key="src/main.py", object_id=uuid.uuid4().hex, ) await b.store(obj) result = await b.get_by_source_key(test_session, "src/main.py") assert result is not None assert result.source_key == "src/main.py" # Non-existent key result2 = await b.get_by_source_key(test_session, "nonexistent.py") assert result2 is None async def test_get_by_source_key_returns_most_recent(self, backend): b, test_session = backend await b.ensure_session(test_session) # Store two objects with same source_key obj1 = _make_stored_object( session_id=test_session, content="Old version", source_key="src/config.py", object_id=uuid.uuid4().hex, ) obj1.created_at = "2025-01-01T00:00:00+00:00" await b.store(obj1) obj2 = _make_stored_object( session_id=test_session, content="New version", source_key="src/config.py", object_id=uuid.uuid4().hex, ) obj2.created_at = "2025-06-01T00:00:00+00:00" await b.store(obj2) result = await b.get_by_source_key(test_session, "src/config.py") assert result is not None assert result.content_full == "New version" async def test_session_isolation(self, backend): b, test_session = backend other_session = f"other-{uuid.uuid4().hex[:8]}" await b.ensure_session(test_session) await b.ensure_session(other_session) # Store in test_session obj = _make_stored_object( session_id=test_session, content="Session A content", object_id=uuid.uuid4().hex, ) await b.store(obj) # Store in other_session obj2 = _make_stored_object( session_id=other_session, content="Session B content", object_id=uuid.uuid4().hex, ) await b.store(obj2) # Each session should only see its own objects results_a = await b.get_by_session(test_session) results_b = await b.get_by_session(other_session) assert len(results_a) == 1 assert results_a[0].content_full == "Session A content" assert len(results_b) == 1 assert results_b[0].content_full == "Session B content" # Cleanup other session try: session_uuid = await b._resolve_session_id(other_session) if session_uuid: pool = b._get_pool() async with pool.acquire() as conn: await conn.execute( "DELETE FROM semantic_objects WHERE session_id = $1", session_uuid, ) await conn.execute( "DELETE FROM sessions WHERE id = $1", session_uuid, ) except Exception: pass