Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 6 additions & 0 deletions src/dibble/bootstrap.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@
from dibble.services.outcome_store import SQLiteOutcomeStore
from dibble.services.strand_store import SQLiteStrandStore
from dibble.services.generation_engine import GenerationEngine
from dibble.services.surplus_practice_cache import SurplusPracticeCache
from dibble.services.generation_mode_calibration import GenerationModeCalibrator
from dibble.services.generated_content_store import SQLiteGeneratedContentStore
from dibble.services.knowledge_component_store import SQLiteKnowledgeComponentStore
Expand Down Expand Up @@ -258,12 +259,17 @@ def build_application_services(
strategy_signal_service=learner_strategy_signal_service,
within_session_adaptation_service=within_session_adaptation_service,
)
surplus_practice_cache = SurplusPracticeCache(
generated_content_store=generated_content_store,
cache_ttl_seconds=settings.generation_cache_ttl_seconds,
)
generation_engine = GenerationEngine(
retriever=plugins.retriever,
router=router_plugin,
provider=plugins.provider,
validator=plugins.validator,
generated_content_store=generated_content_store,
surplus_practice_cache=surplus_practice_cache,
cache_ttl_seconds=settings.generation_cache_ttl_seconds,
)
misconception_remediation_outcome_signal_service = (
Expand Down
74 changes: 74 additions & 0 deletions src/dibble/services/generation_engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@
from dibble.services.generation_modes import build_generation_mode_plan
from dibble.services.protocols import GeneratedContentStore
from dibble.services.runtime_telemetry import log_runtime_event
from dibble.services.surplus_practice_cache import SurplusPracticeCache

logger = logging.getLogger(__name__)

Expand All @@ -44,6 +45,7 @@ def __init__(
validator: ValidatorPlugin,
moderation_service: ContentModerationService | None = None,
generated_content_store: GeneratedContentStore | None = None,
surplus_practice_cache: SurplusPracticeCache | None = None,
cache_ttl_seconds: int = 3600,
time_provider=monotonic,
) -> None:
Expand All @@ -53,6 +55,7 @@ def __init__(
self.validator = validator
self.moderation_service = moderation_service or ContentModerationService()
self.generated_content_store = generated_content_store
self.surplus_practice_cache = surplus_practice_cache
self.cache_ttl_seconds = max(0, cache_ttl_seconds)
self.time_provider = time_provider

Expand Down Expand Up @@ -89,6 +92,10 @@ def generate(
route=route.model_dump(mode="json"),
grounding=[item.model_dump(mode="json") for item in grounding],
)
surplus = self._pop_surplus(profile, request)
if surplus is not None:
return surplus.response

cache_key = self._cache_key(profile, request, route, grounding)
cached = self._get_cached_content(cache_key=cache_key)
if cached is not None:
Expand All @@ -104,6 +111,7 @@ def generate(
return cached.response

started_at = self.time_provider()
surplus_blocks: list[GeneratedBlock] = []
request_moderation = self.moderation_service.moderate_request(request)
if request_moderation.status == "flagged":
blocks = self._moderation_fallback_blocks(
Expand All @@ -125,6 +133,7 @@ def generate(
else:
blocks = self.provider.generate(profile, request, route, grounding)
blocks = normalize_generated_blocks(blocks)
blocks, surplus_blocks = self._split_surplus(blocks)
moderation = self.moderation_service.moderate_blocks(blocks)
if moderation.status == "flagged":
original_blocks = len(blocks)
Expand Down Expand Up @@ -158,6 +167,7 @@ def generate(
),
)
self._store_generated_content(cache_key=cache_key, content=content)
self._cache_surplus(surplus_blocks, blocks, content, profile, request)
log_runtime_event(
logger,
logging.DEBUG,
Expand All @@ -178,6 +188,31 @@ def stream_generate(
) -> Iterator[GenerationStreamEvent]:
grounding = self._safe_retrieve(profile, request)
route = self.router.route(profile, request)

surplus = self._pop_surplus(profile, request)
if surplus is not None:
yield GenerationStreamEvent(
event="start",
student_id=profile.student_id,
route=surplus.response.route,
grounding=surplus.response.grounding,
)
for chunk in self._stream_cached_blocks(surplus.response.blocks):
yield GenerationStreamEvent(
event="delta",
student_id=profile.student_id,
chunk=chunk,
)
yield GenerationStreamEvent(
event="complete",
student_id=profile.student_id,
route=surplus.response.route,
grounding=surplus.response.grounding,
validation_issues=surplus.response.validation_issues,
response=surplus.response,
)
return

cache_key = self._cache_key(profile, request, route, grounding)
cached = self._get_cached_content(cache_key=cache_key)
if cached is not None:
Expand Down Expand Up @@ -213,6 +248,7 @@ def stream_generate(
return

started_at = self.time_provider()
surplus_blocks: list[GeneratedBlock] = []
request_moderation = self.moderation_service.moderate_request(request)
if request_moderation.status == "flagged":
blocks = self._moderation_fallback_blocks(
Expand Down Expand Up @@ -264,6 +300,7 @@ def stream_generate(
blocks = normalize_generated_blocks(
[block_buffers[index] for index in sorted(block_buffers)]
)
blocks, surplus_blocks = self._split_surplus(blocks)
moderation = self.moderation_service.moderate_blocks(blocks)
if moderation.status == "flagged":
original_blocks = len(blocks)
Expand Down Expand Up @@ -308,6 +345,7 @@ def stream_generate(
),
)
self._store_generated_content(cache_key=cache_key, content=content)
self._cache_surplus(surplus_blocks, blocks, content, profile, request)
log_runtime_event(
logger,
logging.DEBUG,
Expand All @@ -330,6 +368,42 @@ def stream_generate(
response=content.response,
)

def _pop_surplus(
    self, profile: LearnerProfile, request: GenerationRequest
) -> GeneratedContent | None:
    """Return a previously cached surplus-practice response, if one exists.

    Looks up (and removes) the surplus entry keyed by this learner and the
    current learning session. Returns ``None`` when no surplus cache is
    configured; otherwise the result is whatever the cache's ``pop_surplus``
    yields (presumably ``None`` on a miss — behavior lives in the cache).
    """
    cache = self.surplus_practice_cache
    if cache is not None:
        return cache.pop_surplus(
            student_id=profile.student_id,
            learning_session_id=request.learning_session_id,
        )
    return None

def _split_surplus(
    self, blocks: list[GeneratedBlock]
) -> tuple[list[GeneratedBlock], list[GeneratedBlock]]:
    """Partition generated blocks into (deliverable, surplus) lists.

    When no surplus cache is configured every block is delivered as-is and
    the surplus list is empty; otherwise the split is delegated to
    ``SurplusPracticeCache.split_practice_blocks``.
    """
    if self.surplus_practice_cache is not None:
        return SurplusPracticeCache.split_practice_blocks(blocks)
    return blocks, []

def _cache_surplus(
    self,
    surplus_blocks: list[GeneratedBlock],
    delivery_blocks: list[GeneratedBlock],
    content: GeneratedContent,
    profile: LearnerProfile,
    request: GenerationRequest,
) -> None:
    """Store surplus practice blocks for reuse by a later request.

    No-op when there is no configured cache or nothing surplus to keep.
    The non-practice delivery blocks are stored alongside the surplus so
    that a later response can be assembled around them.
    """
    cache = self.surplus_practice_cache
    if cache is None:
        return
    if not surplus_blocks:
        return
    accompanying = [
        block for block in delivery_blocks if block.kind != "practice_problem"
    ]
    cache.cache_surplus(
        surplus_blocks=surplus_blocks,
        non_practice_blocks=accompanying,
        parent_content=content,
        profile=profile,
        request=request,
    )

def _build_response(
self,
profile: LearnerProfile,
Expand Down
Loading
Loading