diff --git a/src/dibble/bootstrap.py b/src/dibble/bootstrap.py index d09ed61..05ae688 100644 --- a/src/dibble/bootstrap.py +++ b/src/dibble/bootstrap.py @@ -29,6 +29,7 @@ from dibble.services.outcome_store import SQLiteOutcomeStore from dibble.services.strand_store import SQLiteStrandStore from dibble.services.generation_engine import GenerationEngine +from dibble.services.surplus_practice_cache import SurplusPracticeCache from dibble.services.generation_mode_calibration import GenerationModeCalibrator from dibble.services.generated_content_store import SQLiteGeneratedContentStore from dibble.services.knowledge_component_store import SQLiteKnowledgeComponentStore @@ -258,12 +259,17 @@ def build_application_services( strategy_signal_service=learner_strategy_signal_service, within_session_adaptation_service=within_session_adaptation_service, ) + surplus_practice_cache = SurplusPracticeCache( + generated_content_store=generated_content_store, + cache_ttl_seconds=settings.generation_cache_ttl_seconds, + ) generation_engine = GenerationEngine( retriever=plugins.retriever, router=router_plugin, provider=plugins.provider, validator=plugins.validator, generated_content_store=generated_content_store, + surplus_practice_cache=surplus_practice_cache, cache_ttl_seconds=settings.generation_cache_ttl_seconds, ) misconception_remediation_outcome_signal_service = ( diff --git a/src/dibble/services/generation_engine.py b/src/dibble/services/generation_engine.py index 00bda1b..f34ad82 100644 --- a/src/dibble/services/generation_engine.py +++ b/src/dibble/services/generation_engine.py @@ -31,6 +31,7 @@ from dibble.services.generation_modes import build_generation_mode_plan from dibble.services.protocols import GeneratedContentStore from dibble.services.runtime_telemetry import log_runtime_event +from dibble.services.surplus_practice_cache import SurplusPracticeCache logger = logging.getLogger(__name__) @@ -44,6 +45,7 @@ def __init__( validator: ValidatorPlugin, moderation_service: 
ContentModerationService | None = None, generated_content_store: GeneratedContentStore | None = None, + surplus_practice_cache: SurplusPracticeCache | None = None, cache_ttl_seconds: int = 3600, time_provider=monotonic, ) -> None: @@ -53,6 +55,7 @@ def __init__( self.validator = validator self.moderation_service = moderation_service or ContentModerationService() self.generated_content_store = generated_content_store + self.surplus_practice_cache = surplus_practice_cache self.cache_ttl_seconds = max(0, cache_ttl_seconds) self.time_provider = time_provider @@ -89,6 +92,10 @@ def generate( route=route.model_dump(mode="json"), grounding=[item.model_dump(mode="json") for item in grounding], ) + surplus = self._pop_surplus(profile, request) + if surplus is not None: + return surplus.response + cache_key = self._cache_key(profile, request, route, grounding) cached = self._get_cached_content(cache_key=cache_key) if cached is not None: @@ -104,6 +111,7 @@ def generate( return cached.response started_at = self.time_provider() + surplus_blocks: list[GeneratedBlock] = [] request_moderation = self.moderation_service.moderate_request(request) if request_moderation.status == "flagged": blocks = self._moderation_fallback_blocks( @@ -125,6 +133,7 @@ def generate( else: blocks = self.provider.generate(profile, request, route, grounding) blocks = normalize_generated_blocks(blocks) + blocks, surplus_blocks = self._split_surplus(blocks) moderation = self.moderation_service.moderate_blocks(blocks) if moderation.status == "flagged": original_blocks = len(blocks) @@ -158,6 +167,7 @@ def generate( ), ) self._store_generated_content(cache_key=cache_key, content=content) + self._cache_surplus(surplus_blocks, blocks, content, profile, request) log_runtime_event( logger, logging.DEBUG, @@ -178,6 +188,31 @@ def stream_generate( ) -> Iterator[GenerationStreamEvent]: grounding = self._safe_retrieve(profile, request) route = self.router.route(profile, request) + + surplus = 
self._pop_surplus(profile, request) + if surplus is not None: + yield GenerationStreamEvent( + event="start", + student_id=profile.student_id, + route=surplus.response.route, + grounding=surplus.response.grounding, + ) + for chunk in self._stream_cached_blocks(surplus.response.blocks): + yield GenerationStreamEvent( + event="delta", + student_id=profile.student_id, + chunk=chunk, + ) + yield GenerationStreamEvent( + event="complete", + student_id=profile.student_id, + route=surplus.response.route, + grounding=surplus.response.grounding, + validation_issues=surplus.response.validation_issues, + response=surplus.response, + ) + return + cache_key = self._cache_key(profile, request, route, grounding) cached = self._get_cached_content(cache_key=cache_key) if cached is not None: @@ -213,6 +248,7 @@ def stream_generate( return started_at = self.time_provider() + surplus_blocks: list[GeneratedBlock] = [] request_moderation = self.moderation_service.moderate_request(request) if request_moderation.status == "flagged": blocks = self._moderation_fallback_blocks( @@ -264,6 +300,7 @@ def stream_generate( blocks = normalize_generated_blocks( [block_buffers[index] for index in sorted(block_buffers)] ) + blocks, surplus_blocks = self._split_surplus(blocks) moderation = self.moderation_service.moderate_blocks(blocks) if moderation.status == "flagged": original_blocks = len(blocks) @@ -308,6 +345,7 @@ def stream_generate( ), ) self._store_generated_content(cache_key=cache_key, content=content) + self._cache_surplus(surplus_blocks, blocks, content, profile, request) log_runtime_event( logger, logging.DEBUG, @@ -330,6 +368,42 @@ def stream_generate( response=content.response, ) + def _pop_surplus( + self, profile: LearnerProfile, request: GenerationRequest + ) -> GeneratedContent | None: + if self.surplus_practice_cache is None: + return None + return self.surplus_practice_cache.pop_surplus( + student_id=profile.student_id, + learning_session_id=request.learning_session_id, + ) + + 
def _split_surplus(
    self, blocks: list[GeneratedBlock]
) -> tuple[list[GeneratedBlock], list[GeneratedBlock]]:
    """Partition *blocks* into delivery and surplus practice blocks.

    Pass-through (everything delivered, nothing surplus) when no surplus
    cache has been configured on this engine.
    """
    cache = self.surplus_practice_cache
    if cache is None:
        return blocks, []
    return SurplusPracticeCache.split_practice_blocks(blocks)

def _cache_surplus(
    self,
    surplus_blocks: list[GeneratedBlock],
    delivery_blocks: list[GeneratedBlock],
    content: GeneratedContent,
    profile: LearnerProfile,
    request: GenerationRequest,
) -> None:
    """Persist surplus practice blocks so the next request is served instantly.

    No-op when there is no surplus or no cache configured. The non-practice
    delivery blocks (e.g. the summary) are forwarded so each cached surplus
    entry keeps its surrounding context.
    """
    cache = self.surplus_practice_cache
    if cache is None or not surplus_blocks:
        return
    context_blocks = [
        block for block in delivery_blocks if block.kind != "practice_problem"
    ]
    cache.cache_surplus(
        surplus_blocks=surplus_blocks,
        non_practice_blocks=context_blocks,
        parent_content=content,
        profile=profile,
        request=request,
    )
class SurplusPracticeCache:
    """Split multi-question practice responses and cache the surplus.

    When a single generation contains several ``practice_problem`` blocks,
    only the first is delivered immediately; the remaining blocks are stored
    as standalone ``GeneratedContent`` entries (keyed by
    ``surplus:{student}:{session}:{index}``) so the next continue request can
    be answered instantly, buying time before the next LLM generation.
    """

    def __init__(
        self,
        generated_content_store: GeneratedContentStore,
        cache_ttl_seconds: int = 3600,
    ) -> None:
        self.store = generated_content_store
        # Negative TTLs are clamped to zero, which disables caching entirely.
        self.cache_ttl_seconds = max(0, cache_ttl_seconds)

    # ------------------------------------------------------------------
    # Public API
    # ------------------------------------------------------------------

    @staticmethod
    def split_practice_blocks(
        blocks: list[GeneratedBlock],
    ) -> tuple[list[GeneratedBlock], list[GeneratedBlock]]:
        """Separate *blocks* into delivery blocks and surplus practice blocks.

        Returns ``(delivery, surplus)`` where *delivery* contains every
        non-practice block plus the **first** ``practice_problem`` block,
        preserved in their original order, and *surplus* contains any
        remaining ``practice_problem`` blocks in order of appearance.
        With at most one practice block the input list is returned
        unchanged alongside an empty surplus.
        """
        delivery: list[GeneratedBlock] = []
        surplus: list[GeneratedBlock] = []
        seen_practice = False
        for block in blocks:
            if block.kind == "practice_problem":
                if seen_practice:
                    surplus.append(block)
                    continue
                seen_practice = True
            # Keep the original interleaving (e.g. instruction, Q1, summary)
            # rather than regrouping non-practice blocks ahead of the question.
            delivery.append(block)
        if not surplus:
            return blocks, []
        return delivery, surplus

    def cache_surplus(
        self,
        *,
        surplus_blocks: list[GeneratedBlock],
        non_practice_blocks: list[GeneratedBlock],
        parent_content: GeneratedContent,
        profile: LearnerProfile,
        request: GenerationRequest,
    ) -> int:
        """Store each surplus practice block as a separate cache entry.

        Each entry pairs one surplus block with the original non-practice
        blocks (e.g. the summary) so the learner still sees context when
        the surplus is served later.

        Returns the number of surplus entries stored; 0 when there is
        nothing to store or caching is disabled (TTL <= 0).
        """
        if not surplus_blocks or self.cache_ttl_seconds <= 0:
            return 0

        now = datetime.now(timezone.utc)
        expires_at = now + timedelta(seconds=self.cache_ttl_seconds)
        stored = 0

        for index, practice_block in enumerate(surplus_blocks):
            cache_key = self._surplus_cache_key(
                student_id=profile.student_id,
                learning_session_id=request.learning_session_id,
                sequence_index=index,
            )
            # Each entry needs its own generation_id so store rows do not
            # collide with the parent entry or with each other.
            generation_id = str(uuid4())
            response = parent_content.response.model_copy(
                update={
                    "blocks": [*non_practice_blocks, practice_block],
                    "generation_id": generation_id,
                }
            )
            request_context = {
                **parent_content.request_context,
                "is_surplus_practice": True,
                "is_predictive_warm": True,
                "source_generation_id": parent_content.generation_id,
                "surplus_sequence_index": index,
            }
            content = GeneratedContent(
                generation_id=generation_id,
                student_id=profile.student_id,
                content_type=parent_content.content_type,
                request_context=request_context,
                # Surplus entries are fresh content, never a cache hit.
                quality=parent_content.quality.model_copy(
                    update={"cache_hit": False}
                ),
                response=response,
                created_at=now,
                expires_at=expires_at,
            )
            self.store.upsert(cache_key=cache_key, content=content)
            stored += 1

        logger.debug(
            "Cached %d surplus practice blocks for student %s (session %s)",
            stored,
            profile.student_id,
            request.learning_session_id,
        )
        return stored

    def pop_surplus(
        self,
        *,
        student_id: UUID,
        learning_session_id: str | None,
    ) -> GeneratedContent | None:
        """Retrieve and expire the next surplus practice block, if any.

        Consuming an entry immediately expires it and promotes the remaining
        entries down one slot so the next pop again reads index 0.
        """
        cache_key = self._surplus_cache_key(
            student_id=student_id,
            learning_session_id=learning_session_id,
            sequence_index=0,
        )
        content = self.store.get_fresh(cache_key=cache_key)
        if content is None:
            return None

        # Expire the entry so it is not served again.
        # NOTE(review): assumes store.refresh matches the row by its
        # identifiers and persists the new expires_at — confirm store contract.
        expired = content.model_copy(
            update={"expires_at": datetime.now(timezone.utc)}
        )
        self.store.refresh(content=expired)

        # Promote sequence_index=1 -> 0 so the next pop finds it.
        self._promote_surplus(
            student_id=student_id,
            learning_session_id=learning_session_id,
        )

        logger.debug(
            "Popped surplus practice block %s for student %s",
            content.generation_id,
            student_id,
        )
        return content

    # ------------------------------------------------------------------
    # Internal helpers
    # ------------------------------------------------------------------

    def _promote_surplus(
        self,
        *,
        student_id: UUID,
        learning_session_id: str | None,
    ) -> None:
        """Shift surplus entries down by one so index 1 becomes index 0."""
        # One timestamp for the whole promotion pass; recomputing per slot
        # buys nothing since every expiry just needs to be <= "now".
        now = datetime.now(timezone.utc)
        index = 1
        while True:
            old_key = self._surplus_cache_key(
                student_id=student_id,
                learning_session_id=learning_session_id,
                sequence_index=index,
            )
            entry = self.store.get_fresh(cache_key=old_key)
            if entry is None:
                break
            # Expire the old slot before re-storing the promoted copy.
            self.store.refresh(
                content=entry.model_copy(update={"expires_at": now})
            )
            # Re-store at index - 1 with a fresh generation_id to avoid
            # unique-constraint conflicts with the expired row.
            new_key = self._surplus_cache_key(
                student_id=student_id,
                learning_session_id=learning_session_id,
                sequence_index=index - 1,
            )
            new_gen_id = str(uuid4())
            new_context = dict(entry.request_context)
            new_context["surplus_sequence_index"] = index - 1
            promoted = entry.model_copy(
                update={
                    "generation_id": new_gen_id,
                    "request_context": new_context,
                    "response": entry.response.model_copy(
                        update={"generation_id": new_gen_id}
                    ),
                }
            )
            self.store.upsert(cache_key=new_key, content=promoted)
            index += 1

    @staticmethod
    def _surplus_cache_key(
        *,
        student_id: UUID,
        learning_session_id: str | None,
        sequence_index: int,
    ) -> str:
        # A missing session id collapses to the literal "none" bucket.
        session = learning_session_id or "none"
        return f"surplus:{student_id}:{session}:{sequence_index}"
@pytest.fixture()
def store(tmp_path):
    """Fresh on-disk generated-content store for each test."""
    db_path = str(tmp_path / "surplus.db")
    ensure_database(db_path)
    return SQLiteGeneratedContentStore(create_connection(db_path))


@pytest.fixture()
def cache(store):
    """SurplusPracticeCache backed by the per-test store."""
    return SurplusPracticeCache(store, cache_ttl_seconds=3600)


def _make_practice(title: str = "Q1") -> GeneratedBlock:
    """Build a multiple-choice practice block titled *title*."""
    interaction = MultipleChoiceInteraction(
        prompt="What is 2+2?",
        options=[
            MultipleChoiceOption(option_id="A", label="A", body="3"),
            MultipleChoiceOption(option_id="B", label="B", body="4"),
        ],
        correct_option_id="B",
    )
    return GeneratedBlock(
        kind="practice_problem",
        title=title,
        body="Solve this.",
        interaction=interaction,
    )


def _make_summary() -> GeneratedBlock:
    """Build the non-practice context block delivered alongside questions."""
    return GeneratedBlock(kind="summary", title="Summary", body="Context here.")


def _make_parent(
    blocks: list[GeneratedBlock],
    student_id: UUID = STUDENT_ID,
) -> GeneratedContent:
    """Build a parent GeneratedContent wrapping *blocks* for *student_id*."""
    decision = AdaptiveRouteDecision(
        intervention_type=InterventionType.reteach,
        delivery_mode=DeliveryMode.generated,
        scaffolding_level="medium",
        reasons=["test"],
    )
    meta = GenerationMetadata(
        quality_score=0.8, validation_passed=True, grounding_count=0
    )
    generation_id = str(uuid4())
    created = datetime.now(timezone.utc)
    return GeneratedContent(
        generation_id=generation_id,
        student_id=student_id,
        content_type="practice_problem",
        request_context={
            "target_kc_ids": ["KC-1"],
            "target_lo_ids": ["LO-1"],
            "learning_session_id": "session-1",
        },
        response=GenerationResponse(
            student_id=student_id,
            route=decision,
            blocks=blocks,
            curriculum_context=["fractions"],
            grounding=[],
            safety_notes=[],
            generation_id=generation_id,
            generation_metadata=meta,
        ),
        quality=meta,
        created_at=created,
        expires_at=created + timedelta(hours=1),
    )


def _make_request(student_id: UUID = STUDENT_ID) -> GenerationRequest:
    return GenerationRequest(
        student_id=student_id,
        target_kc_ids=["KC-1"],
        learning_session_id="session-1",
    )


def _make_profile(student_id: UUID = STUDENT_ID) -> LearnerProfile:
    return LearnerProfile(student_id=student_id, grade_level="7")


def _practice_titles(content: GeneratedContent) -> list[str]:
    """Titles of the practice_problem blocks in *content*, in order."""
    return [
        block.title
        for block in content.response.blocks
        if block.kind == "practice_problem"
    ]


# ------------------------------------------------------------------
# split_practice_blocks
# ------------------------------------------------------------------


class TestSplitPracticeBlocks:
    def test_single_practice_unchanged(self):
        original = [_make_summary(), _make_practice("Q1")]
        delivery, surplus = SurplusPracticeCache.split_practice_blocks(original)
        assert delivery == original
        assert surplus == []

    def test_two_practice_splits(self):
        original = [_make_summary(), _make_practice("Q1"), _make_practice("Q2")]
        delivery, surplus = SurplusPracticeCache.split_practice_blocks(original)
        assert len(delivery) == 2  # summary + Q1
        assert delivery[0].kind == "summary"
        assert delivery[1].title == "Q1"
        assert len(surplus) == 1
        assert surplus[0].title == "Q2"

    def test_three_practice_splits(self):
        original = [
            _make_summary(),
            _make_practice("Q1"),
            _make_practice("Q2"),
            _make_practice("Q3"),
        ]
        delivery, surplus = SurplusPracticeCache.split_practice_blocks(original)
        assert len(delivery) == 2
        assert len(surplus) == 2
        assert surplus[0].title == "Q2"
        assert surplus[1].title == "Q3"

    def test_no_practice_unchanged(self):
        original = [
            _make_summary(),
            GeneratedBlock(kind="instruction", title="I", body="text"),
        ]
        delivery, surplus = SurplusPracticeCache.split_practice_blocks(original)
        assert delivery == original
        assert surplus == []

    def test_only_practice_no_summary(self):
        original = [_make_practice("Q1"), _make_practice("Q2")]
        delivery, surplus = SurplusPracticeCache.split_practice_blocks(original)
        assert len(delivery) == 1
        assert delivery[0].title == "Q1"
        assert len(surplus) == 1


# ------------------------------------------------------------------
# cache_surplus + pop_surplus
# ------------------------------------------------------------------


class TestCacheAndPop:
    def test_cache_and_pop_returns_surplus(self, cache):
        context = _make_summary()
        parent = _make_parent([context, _make_practice("Q1")])

        cache.cache_surplus(
            surplus_blocks=[_make_practice("Q2")],
            non_practice_blocks=[context],
            parent_content=parent,
            profile=_make_profile(),
            request=_make_request(),
        )

        popped = cache.pop_surplus(
            student_id=STUDENT_ID,
            learning_session_id="session-1",
        )
        assert popped is not None
        assert popped.generation_id != parent.generation_id
        assert _practice_titles(popped) == ["Q2"]

    def test_pop_returns_none_when_empty(self, cache):
        popped = cache.pop_surplus(
            student_id=STUDENT_ID,
            learning_session_id="session-1",
        )
        assert popped is None

    def test_pop_consumes_entry(self, cache):
        context = _make_summary()
        parent = _make_parent([context, _make_practice("Q1")])

        cache.cache_surplus(
            surplus_blocks=[_make_practice("Q2")],
            non_practice_blocks=[context],
            parent_content=parent,
            profile=_make_profile(),
            request=_make_request(),
        )

        first = cache.pop_surplus(
            student_id=STUDENT_ID, learning_session_id="session-1"
        )
        assert first is not None
        second = cache.pop_surplus(
            student_id=STUDENT_ID, learning_session_id="session-1"
        )
        assert second is None

    def test_multiple_surplus_served_in_order(self, cache):
        context = _make_summary()
        parent = _make_parent([context, _make_practice("Q1")])

        cache.cache_surplus(
            surplus_blocks=[_make_practice("Q2"), _make_practice("Q3")],
            non_practice_blocks=[context],
            parent_content=parent,
            profile=_make_profile(),
            request=_make_request(),
        )

        first = cache.pop_surplus(
            student_id=STUDENT_ID, learning_session_id="session-1"
        )
        assert first is not None
        assert _practice_titles(first)[0] == "Q2"

        second = cache.pop_surplus(
            student_id=STUDENT_ID, learning_session_id="session-1"
        )
        assert second is not None
        assert _practice_titles(second)[0] == "Q3"

        third = cache.pop_surplus(
            student_id=STUDENT_ID, learning_session_id="session-1"
        )
        assert third is None


# ------------------------------------------------------------------
# Invalidation piggyback
# ------------------------------------------------------------------


class TestInvalidation:
    def test_surplus_has_predictive_warm_flag(self, cache):
        context = _make_summary()
        parent = _make_parent([context, _make_practice("Q1")])

        cache.cache_surplus(
            surplus_blocks=[_make_practice("Q2")],
            non_practice_blocks=[context],
            parent_content=parent,
            profile=_make_profile(),
            request=_make_request(),
        )

        popped = cache.pop_surplus(
            student_id=STUDENT_ID, learning_session_id="session-1"
        )
        assert popped is not None
        flags = popped.request_context
        assert flags.get("is_predictive_warm") is True
        assert flags.get("is_surplus_practice") is True


# ------------------------------------------------------------------
# Session isolation
# ------------------------------------------------------------------


class TestSessionIsolation:
    def test_different_session_not_served(self, cache):
        context = _make_summary()
        parent = _make_parent([context, _make_practice("Q1")])

        cache.cache_surplus(
            surplus_blocks=[_make_practice("Q2")],
            non_practice_blocks=[context],
            parent_content=parent,
            profile=_make_profile(),
            request=_make_request(),
        )

        popped = cache.pop_surplus(
            student_id=STUDENT_ID,
            learning_session_id="different-session",
        )
        assert popped is None


# ------------------------------------------------------------------
# TTL
# ------------------------------------------------------------------


class TestTTL:
    def test_zero_ttl_does_not_cache(self, store):
        no_cache = SurplusPracticeCache(store, cache_ttl_seconds=0)
        context = _make_summary()
        parent = _make_parent([context, _make_practice("Q1")])

        stored = no_cache.cache_surplus(
            surplus_blocks=[_make_practice("Q2")],
            non_practice_blocks=[context],
            parent_content=parent,
            profile=_make_profile(),
            request=_make_request(),
        )
        assert stored == 0