diff --git a/src/mnemosyne/gateway.py b/src/mnemosyne/gateway.py index 92e31aa..af7fed0 100644 --- a/src/mnemosyne/gateway.py +++ b/src/mnemosyne/gateway.py @@ -1469,10 +1469,19 @@ def create_app( min_evict_size=min_evict_size, ) _bytes_saved += ingest.bytes_saved - if ingest.new_count > 0 or ingest.compacted_count > 0: + if ingest.physical_tail_deleted > 0: + session._segmented_objects = [ + obj + for obj in session._segmented_objects + if obj.turn_end < ingest.deleted_physical_start + ] + + if ingest.new_count > 0 or ingest.compacted_count > 0 or ingest.physical_tail_deleted > 0: parts = [] if ingest.new_count: parts.append(f"+{ingest.new_count} msgs") + if ingest.physical_tail_deleted: + parts.append(f"undo removed {ingest.physical_tail_deleted} msgs") if ingest.compacted_count: parts.append(f"{ingest.compacted_count} evicted") if ingest.bytes_saved: @@ -1486,17 +1495,21 @@ def create_app( file=sys.stderr, ) - # 1b. Segment new messages into semantic objects and store in ObjectStore - # Phase 4d: admission control gates each object before storage + # 1b. Segment only newly ingested physical messages into semantic objects + # and store them in ObjectStore. Re-segmenting the full history here + # would balloon object counts across turns. try: - with Timer(session.benchmark.latency["segmentation"]): - segmented = session.segmenter.segment_incremental( - ms.messages, - session._segmented_objects, - start_turn=0, - ) - new_count = len(segmented) - len(session._segmented_objects) - if new_count > 0: + if ingest.new_count > 0: + new_physical_messages = ms.messages[ + ingest.new_physical_start : ingest.new_physical_start + ingest.new_count + ] + with Timer(session.benchmark.latency["segmentation"]): + segmented = session.segmenter.segment_incremental( + new_physical_messages, + session._segmented_objects, + start_turn=ingest.new_physical_start, + ) + new_count = len(segmented) - len(session._segmented_objects) new_objects = segmented[len(session._segmented_objects) :] admitted_count = 0 rejected_count = 0 diff --git a/src/mnemosyne/message_store.py b/src/mnemosyne/message_store.py index 43bc5a3..eea0cb7 100644 --- a/src/mnemosyne/message_store.py +++ b/src/mnemosyne/message_store.py @@ -79,9 +79,13 @@ def _fingerprint(msg: dict) -> str: @dataclass class IngestResult: """Result of ingesting a new turn's messages.""" + new_count: int = 0 + new_physical_start: int = 0 mutations_detected: int = 0 deletions_detected: int = 0 + physical_tail_deleted: int = 0 + deleted_physical_start: int = 0 compacted_count: int = 0 bytes_saved: int = 0 @@ -89,8 +93,7 @@ class IngestResult: class MessageStore: """Pichay's compacted conversation history for a session.""" - def __init__(self, session_id: str, page_store: PageStore, - log_path: Path | None = None): + def __init__(self, session_id: str, page_store: PageStore, log_path: Path | None = None): self.session_id = session_id self.page_store = page_store self.log_path = log_path @@ -124,10 +127,16 @@ class MessageStore: return json.dumps(content, default=str)[:limit] return str(content)[:limit] - def _log_violation(self, kind: str, index: int, msg: dict | None, - expected_fp: str, actual_fp: str, - old_msg: dict | None = None, - deleted_msgs: list[dict] | None = None) -> None: + def _log_violation( + self, + kind: str, + index: int, + msg: dict | None, + expected_fp: str, + actual_fp: str, + old_msg: dict | None = None, + deleted_msgs: list[dict] | None = None, + ) -> None: """Log append-only violations to file for later analysis.""" if self.log_path is None: return @@ -188,20 +197,26 @@ class MessageStore: self._turn += 1 result = IngestResult() client_known = len(self._client_fps) + first_mutation_index: int | None = None # ── Detect mutations in known client messages ──────────── check_limit = min(client_known, len(incoming)) for i in range(check_limit): fp = _fingerprint(incoming[i]) if fp != self._client_fps[i]: + if first_mutation_index is None: + first_mutation_index = i result.mutations_detected += 1 self.total_mutations += 1 # Look up physical message via mapping for comparison phys_idx = self._client_to_physical[i] old_msg = self._messages[phys_idx] if phys_idx < len(self._messages) else None self._log_violation( - "mutation", i, incoming[i], - self._client_fps[i], fp, + "mutation", + i, + incoming[i], + self._client_fps[i], + fp, old_msg=old_msg, ) print( @@ -210,8 +225,28 @@ class MessageStore: f"got {fp[:32]}{_RESET}", file=sys.stderr, ) - # Update CLIENT fingerprint only — physical store unchanged - self._client_fps[i] = fp + + # Tail mutation: rebuild physical/client history from first changed index onward. + if first_mutation_index is not None: + if first_mutation_index < len(self._client_to_physical): + physical_start = self._client_to_physical[first_mutation_index] + else: + physical_start = len(self._messages) + physical_start = max(0, min(physical_start, len(self._messages))) + removed = len(self._messages) - physical_start + self._messages = self._messages[:physical_start] + self._fingerprints = self._fingerprints[:physical_start] + self._client_fps = self._client_fps[:first_mutation_index] + self._client_to_physical = self._client_to_physical[:first_mutation_index] + client_known = len(self._client_fps) + result.physical_tail_deleted = max(result.physical_tail_deleted, removed) + result.deleted_physical_start = physical_start + print( + f" {_DIM}[{self.session_id}] CLIENT TAIL MUTATION APPLIED: " + f"rebuilt history from client index {first_mutation_index} " + f"(removed {removed} physical msgs){_RESET}", + file=sys.stderr, + ) # ── Detect client deletions (compaction) ───────────────── if len(incoming) < client_known: @@ -219,6 +254,7 @@ class MessageStore: result.deletions_detected = deleted self.total_deletions += deleted self.total_client_deletions_absorbed += deleted + undo_applied = False # Log what the client is deleting (from physical store via mapping) deleted_physical = [] for ci in range(len(incoming), client_known): @@ -226,24 +262,49 @@ class MessageStore: if pi < len(self._messages): deleted_physical.append(self._messages[pi]) self._log_violation( - "deletion", client_known, None, - f"expected_{client_known}", f"got_{len(incoming)}", + "deletion", + client_known, + None, + f"expected_{client_known}", + f"got_{len(incoming)}", deleted_msgs=deleted_physical, ) - print( - f" {_DIM}[{self.session_id}] CLIENT DELETION ABSORBED: " - f"{deleted} messages dropped by client, " - f"physical store unchanged ({len(self._messages)} msgs){_RESET}", - file=sys.stderr, - ) + deleted_indices = self._client_to_physical[len(incoming) : client_known] + if deleted_indices: + valid_deleted = sorted( + {pi for pi in deleted_indices if 0 <= pi < len(self._messages)} + ) + if valid_deleted: + tail_start = valid_deleted[0] + removed = len(self._messages) - tail_start + self._messages = self._messages[:tail_start] + self._fingerprints = self._fingerprints[:tail_start] + result.physical_tail_deleted = removed + result.deleted_physical_start = tail_start + undo_applied = True + print( + f" {_DIM}[{self.session_id}] CLIENT UNDO APPLIED: " + f"removed {removed} tail messages from physical store{_RESET}", + file=sys.stderr, + ) + + if not undo_applied: + print( + f" {_DIM}[{self.session_id}] CLIENT DELETION ABSORBED: " + f"{deleted} messages dropped by client, " + f"physical store unchanged ({len(self._messages)} msgs){_RESET}", + file=sys.stderr, + ) + # Truncate CLIENT tracking only — physical store stays intact - self._client_fps = self._client_fps[:len(incoming)] - self._client_to_physical = self._client_to_physical[:len(incoming)] + self._client_fps = self._client_fps[: len(incoming)] + self._client_to_physical = self._client_to_physical[: len(incoming)] # ── Extract and append new messages ────────────────────── new_start = min(client_known, len(incoming)) new_messages = incoming[new_start:] result.new_count = len(new_messages) + result.new_physical_start = len(self._messages) if new_messages: # Deep copy new messages so we own them @@ -256,7 +317,7 @@ class MessageStore: _strip_cache_control(msg) # Track in physical store and client mapping - phys_start = len(self._messages) + phys_start = result.new_physical_start for j, msg in enumerate(new_messages): fp = _fingerprint(msg) self._fingerprints.append(fp)