diff --git a/crates/spark-model/src/layers/qwen3_attention/decode/high_speed_swap.rs b/crates/spark-model/src/layers/qwen3_attention/decode/high_speed_swap.rs index 301d2e60..dc6a2247 100644 --- a/crates/spark-model/src/layers/qwen3_attention/decode/high_speed_swap.rs +++ b/crates/spark-model/src/layers/qwen3_attention/decode/high_speed_swap.rs @@ -96,14 +96,33 @@ impl Qwen3AttentionLayer { // Window of HBM-resident blocks: block_table[0..block_table.len()] // covers logical positions [total - block_table.len(), total). let window_start = total.saturating_sub(block_table.len()); - // Always re-offload the ACTIVE (last) block on every call. Decode - // writes one new slot per step into the existing last block — its - // HBM contents change without disk_block_ids.len() growing, so the - // naive `last..total` range would skip it and the streaming kernel - // would read stale (zero-init) bytes for slots 1..15 → degenerate - // attention → "the the the" loop. Setting `start = last.min(total-1)` - // guarantees the active block is re-pushed every step. - let start = last.min(total - 1); + // Always re-offload the BOUNDARY block (one before `last`) on every + // call, in addition to all blocks in `last..total`. Two cases: + // + // (1) Decode case: `last == total` (no new block this step). Slots in + // the active block keep getting written one-per-step without + // `disk_block_ids.len()` growing. `start = total - 1` ensures the + // active block is re-pushed every step. Without this the streaming + // kernel reads stale (zero-init) bytes for the unwritten slots + // → degenerate attention → "the the the" loop. + // + // (2) Chunked-prefill boundary case (issue #31, follow-up to PR #37): + // `last < total` after a new chunk advanced `disk_block_ids`. The + // PREVIOUS chunk's last block (`last - 1`) typically has unwritten + // tail slots — `reshape_and_cache_flash` writes only the chunk's + // own token slots, so when chunk N ended mid-block it left the + // tail slots zero on disk after the post-chunk-N offload. Chunk + // N+1 fills those tail slots in HBM but the offload's `start = + // last` skipped re-pushing the boundary block, so disk's + // boundary-block tail stays permanently zeroed. Decode reads the + // full history from disk via `attend_layer_on_stream`, so the + // zeroed slots silently corrupt attention for the chunk-boundary + // positions (manifests as needle-in-haystack precision loss in + // long-context recall — see issue #31 differential tests). + // + // `last.saturating_sub(1).min(total - 1)` covers both cases at the + // cost of ~one extra D2H per layer per chunk (negligible). + let start = last.saturating_sub(1).min(total - 1); for logical_pos in start..total { if logical_pos < window_start { // Issue #31: the slide-before-alloc loop in diff --git a/crates/spark-model/src/model/trait_impl/verify_b.rs b/crates/spark-model/src/model/trait_impl/verify_b.rs index f3fdd613..5ffd2fdd 100644 --- a/crates/spark-model/src/model/trait_impl/verify_b.rs +++ b/crates/spark-model/src/model/trait_impl/verify_b.rs @@ -214,24 +214,55 @@ impl TransformerModel { let layer_type = self.config.layer_type(layer_idx); if layer_type == LayerType::FullAttention { - // Attention: treat 2 tokens as 2 virtual sequences via - // decode_multi_seq. EmptyLayerState has no actual state. - let mut dummy_states: Vec> = (0..k) - .map(|_| layer.alloc_state(self.gpu.as_ref())) - .collect::>()?; - let mut refs: Vec<&mut (dyn LayerState + 'static)> = - dummy_states.iter_mut().map(|s| s.as_mut()).collect(); - layer.decode_multi_seq( - hidden, - residual, - k, - &mut refs, - &mut kv_cache, - &seq_lens_vec, - &block_tables_vec, - &ctx, - stream, - )?; + if hss_engaged { + // HSS path: `decode_multi_seq` calls the production + // paged-decode kernel which reads K/V from HBM only + // (`meta.block_table`). Under HSS, HBM is capped to + // `cache_blocks_per_seq` blocks, so older context + // lives only on disk and is unreachable from the + // multi-Q kernel — Q/V attends only over the recent + // ~cap×bs tokens, missing the long-context history. + // The single-token `decode` path routes through the + // HSS orchestrator (`attend_layer_on_stream`) which + // reads the full history from disk. Fall back to + // `decode_batched` (N sequential single-token + // decodes via the orchestrator) at the cost of + // ~k× attention launches per verify step. Mirrors + // the SSM branch below which already uses + // decode_batched for the same correctness reason. + layer.decode_batched( + hidden, + residual, + k, + seq.layer_states[layer_idx].as_mut(), + &mut kv_cache, + seq.seq_len, + &mut seq.block_table, + &mut seq.disk_block_ids, + &mut seq.disk_last_offloaded_per_layer, + &ctx, + stream, + )?; + } else { + // Attention: treat 2 tokens as 2 virtual sequences via + // decode_multi_seq. EmptyLayerState has no actual state. + let mut dummy_states: Vec> = (0..k) + .map(|_| layer.alloc_state(self.gpu.as_ref())) + .collect::>()?; + let mut refs: Vec<&mut (dyn LayerState + 'static)> = + dummy_states.iter_mut().map(|s| s.as_mut()).collect(); + layer.decode_multi_seq( + hidden, + residual, + k, + &mut refs, + &mut kv_cache, + &seq_lens_vec, + &block_tables_vec, + &ctx, + stream, + )?; + } } else { // SSM: process K=2 tokens for one sequence via decode_batched. layer.decode_batched( diff --git a/crates/spark-model/src/model/trait_impl/verify_c.rs b/crates/spark-model/src/model/trait_impl/verify_c.rs index 61bc66d8..6268a94c 100644 --- a/crates/spark-model/src/model/trait_impl/verify_c.rs +++ b/crates/spark-model/src/model/trait_impl/verify_c.rs @@ -199,22 +199,43 @@ impl TransformerModel { let layer_type = self.config.layer_type(layer_idx); if layer_type == LayerType::FullAttention { - let mut dummy_states: Vec> = (0..k) - .map(|_| layer.alloc_state(self.gpu.as_ref())) - .collect::>()?; - let mut refs: Vec<&mut (dyn LayerState + 'static)> = - dummy_states.iter_mut().map(|s| s.as_mut()).collect(); - layer.decode_multi_seq( - hidden, - residual, - k, - &mut refs, - &mut kv_cache, - &seq_lens_vec, - &block_tables_vec, - &ctx, - stream, - )?; + if hss_engaged { + // HSS path: decode_multi_seq's paged-decode kernel + // reads K/V from HBM only, missing the long-context + // history on disk. Fall back to decode_batched + // (sequential single-token decodes via the HSS + // orchestrator). See verify_b.rs for full rationale. + layer.decode_batched( + hidden, + residual, + k, + seq.layer_states[layer_idx].as_mut(), + &mut kv_cache, + seq.seq_len, + &mut seq.block_table, + &mut seq.disk_block_ids, + &mut seq.disk_last_offloaded_per_layer, + &ctx, + stream, + )?; + } else { + let mut dummy_states: Vec> = (0..k) + .map(|_| layer.alloc_state(self.gpu.as_ref())) + .collect::>()?; + let mut refs: Vec<&mut (dyn LayerState + 'static)> = + dummy_states.iter_mut().map(|s| s.as_mut()).collect(); + layer.decode_multi_seq( + hidden, + residual, + k, + &mut refs, + &mut kv_cache, + &seq_lens_vec, + &block_tables_vec, + &ctx, + stream, + )?; + } } else { layer.decode_batched( hidden, diff --git a/crates/spark-model/src/model/trait_impl/verify_c2.rs b/crates/spark-model/src/model/trait_impl/verify_c2.rs index c6a4a429..4c5dcfd3 100644 --- a/crates/spark-model/src/model/trait_impl/verify_c2.rs +++ b/crates/spark-model/src/model/trait_impl/verify_c2.rs @@ -189,22 +189,43 @@ impl TransformerModel { let layer_type = self.config.layer_type(layer_idx); if layer_type == LayerType::FullAttention { - let mut dummy_states: Vec> = (0..k) - .map(|_| layer.alloc_state(self.gpu.as_ref())) - .collect::>()?; - let mut refs: Vec<&mut (dyn LayerState + 'static)> = - dummy_states.iter_mut().map(|s| s.as_mut()).collect(); - layer.decode_multi_seq( - hidden, - residual, - k, - &mut refs, - &mut kv_cache, - &seq_lens_vec, - &block_tables_vec, - &ctx, - stream, - )?; + if hss_engaged { + // HSS path: decode_multi_seq's paged-decode kernel + // reads K/V from HBM only, missing the long-context + // history on disk. Fall back to decode_batched + // (sequential single-token decodes via the HSS + // orchestrator). See verify_b.rs for full rationale. + layer.decode_batched( + hidden, + residual, + k, + seq.layer_states[layer_idx].as_mut(), + &mut kv_cache, + seq.seq_len, + &mut seq.block_table, + &mut seq.disk_block_ids, + &mut seq.disk_last_offloaded_per_layer, + &ctx, + stream, + )?; + } else { + let mut dummy_states: Vec> = (0..k) + .map(|_| layer.alloc_state(self.gpu.as_ref())) + .collect::>()?; + let mut refs: Vec<&mut (dyn LayerState + 'static)> = + dummy_states.iter_mut().map(|s| s.as_mut()).collect(); + layer.decode_multi_seq( + hidden, + residual, + k, + &mut refs, + &mut kv_cache, + &seq_lens_vec, + &block_tables_vec, + &ctx, + stream, + )?; + } } else { layer.decode_batched( hidden, diff --git a/crates/spark-model/src/model/trait_impl/verify_d.rs b/crates/spark-model/src/model/trait_impl/verify_d.rs index b76358ca..67c22bcc 100644 --- a/crates/spark-model/src/model/trait_impl/verify_d.rs +++ b/crates/spark-model/src/model/trait_impl/verify_d.rs @@ -195,22 +195,43 @@ impl TransformerModel { let layer_type = self.config.layer_type(layer_idx); if layer_type == LayerType::FullAttention { - let mut dummy_states: Vec> = (0..k) - .map(|_| layer.alloc_state(self.gpu.as_ref())) - .collect::>()?; - let mut refs: Vec<&mut (dyn LayerState + 'static)> = - dummy_states.iter_mut().map(|s| s.as_mut()).collect(); - layer.decode_multi_seq( - hidden, - residual, - k, - &mut refs, - &mut kv_cache, - &seq_lens_vec, - &block_tables_vec, - &ctx, - stream, - )?; + if hss_engaged { + // HSS path: decode_multi_seq's paged-decode kernel + // reads K/V from HBM only, missing the long-context + // history on disk. Fall back to decode_batched + // (sequential single-token decodes via the HSS + // orchestrator). See verify_b.rs for full rationale. + layer.decode_batched( + hidden, + residual, + k, + seq.layer_states[layer_idx].as_mut(), + &mut kv_cache, + seq.seq_len, + &mut seq.block_table, + &mut seq.disk_block_ids, + &mut seq.disk_last_offloaded_per_layer, + &ctx, + stream, + )?; + } else { + let mut dummy_states: Vec> = (0..k) + .map(|_| layer.alloc_state(self.gpu.as_ref())) + .collect::>()?; + let mut refs: Vec<&mut (dyn LayerState + 'static)> = + dummy_states.iter_mut().map(|s| s.as_mut()).collect(); + layer.decode_multi_seq( + hidden, + residual, + k, + &mut refs, + &mut kv_cache, + &seq_lens_vec, + &block_tables_vec, + &ctx, + stream, + )?; + } } else { layer.decode_batched( hidden,