Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -96,14 +96,33 @@ impl Qwen3AttentionLayer {
// Window of HBM-resident blocks: block_table[0..block_table.len()]
// covers logical positions [total - block_table.len(), total).
let window_start = total.saturating_sub(block_table.len());
// Always re-offload the ACTIVE (last) block on every call. Decode
// writes one new slot per step into the existing last block — its
// HBM contents change without disk_block_ids.len() growing, so the
// naive `last..total` range would skip it and the streaming kernel
// would read stale (zero-init) bytes for slots 1..15 → degenerate
// attention → "the the the" loop. Setting `start = last.min(total-1)`
// guarantees the active block is re-pushed every step.
let start = last.min(total - 1);
// Always re-offload the BOUNDARY block (one before `last`) on every
// call, in addition to all blocks in `last..total`. Two cases:
//
// (1) Decode case: `last == total` (no new block this step). Slots in
// the active block keep getting written one-per-step without
// `disk_block_ids.len()` growing. `start = total - 1` ensures the
// active block is re-pushed every step. Without this the streaming
// kernel reads stale (zero-init) bytes for the unwritten slots
// → degenerate attention → "the the the" loop.
//
// (2) Chunked-prefill boundary case (issue #31, follow-up to PR #37):
// `last < total` after a new chunk advanced `disk_block_ids`. The
// PREVIOUS chunk's last block (`last - 1`) typically has unwritten
// tail slots — `reshape_and_cache_flash` writes only the chunk's
// own token slots, so when chunk N ended mid-block it left the
// tail slots zero on disk after the post-chunk-N offload. Chunk
// N+1 fills those tail slots in HBM but the offload's `start =
// last` skipped re-pushing the boundary block, so disk's
// boundary-block tail stays permanently zeroed. Decode reads the
// full history from disk via `attend_layer_on_stream`, so the
// zeroed slots silently corrupt attention for the chunk-boundary
// positions (manifests as needle-in-haystack precision loss in
// long-context recall — see issue #31 differential tests).
//
// `last.saturating_sub(1).min(total - 1)` covers both cases at the
// cost of ~one extra D2H per layer per chunk (negligible).
let start = last.saturating_sub(1).min(total - 1);
for logical_pos in start..total {
if logical_pos < window_start {
// Issue #31: the slide-before-alloc loop in
Expand Down
67 changes: 49 additions & 18 deletions crates/spark-model/src/model/trait_impl/verify_b.rs
Original file line number Diff line number Diff line change
Expand Up @@ -214,24 +214,55 @@ impl TransformerModel {
let layer_type = self.config.layer_type(layer_idx);

if layer_type == LayerType::FullAttention {
// Attention: treat 2 tokens as 2 virtual sequences via
// decode_multi_seq. EmptyLayerState has no actual state.
let mut dummy_states: Vec<Box<dyn LayerState>> = (0..k)
.map(|_| layer.alloc_state(self.gpu.as_ref()))
.collect::<Result<_>>()?;
let mut refs: Vec<&mut (dyn LayerState + 'static)> =
dummy_states.iter_mut().map(|s| s.as_mut()).collect();
layer.decode_multi_seq(
hidden,
residual,
k,
&mut refs,
&mut kv_cache,
&seq_lens_vec,
&block_tables_vec,
&ctx,
stream,
)?;
if hss_engaged {
// HSS path: `decode_multi_seq` calls the production
// paged-decode kernel which reads K/V from HBM only
// (`meta.block_table`). Under HSS, HBM is capped to
// `cache_blocks_per_seq` blocks, so older context
// lives only on disk and is unreachable from the
// multi-Q kernel — Q/V attends only over the recent
// ~cap×bs tokens, missing the long-context history.
// The single-token `decode` path routes through the
// HSS orchestrator (`attend_layer_on_stream`) which
// reads the full history from disk. Fall back to
// `decode_batched` (N sequential single-token
// decodes via the orchestrator) at the cost of
// ~k× attention launches per verify step. Mirrors
// the SSM branch below which already uses
// decode_batched for the same correctness reason.
layer.decode_batched(
hidden,
residual,
k,
seq.layer_states[layer_idx].as_mut(),
&mut kv_cache,
seq.seq_len,
&mut seq.block_table,
&mut seq.disk_block_ids,
&mut seq.disk_last_offloaded_per_layer,
&ctx,
stream,
)?;
} else {
// Attention: treat 2 tokens as 2 virtual sequences via
// decode_multi_seq. EmptyLayerState has no actual state.
let mut dummy_states: Vec<Box<dyn LayerState>> = (0..k)
.map(|_| layer.alloc_state(self.gpu.as_ref()))
.collect::<Result<_>>()?;
let mut refs: Vec<&mut (dyn LayerState + 'static)> =
dummy_states.iter_mut().map(|s| s.as_mut()).collect();
layer.decode_multi_seq(
hidden,
residual,
k,
&mut refs,
&mut kv_cache,
&seq_lens_vec,
&block_tables_vec,
&ctx,
stream,
)?;
}
} else {
// SSM: process K=2 tokens for one sequence via decode_batched.
layer.decode_batched(
Expand Down
53 changes: 37 additions & 16 deletions crates/spark-model/src/model/trait_impl/verify_c.rs
Original file line number Diff line number Diff line change
Expand Up @@ -199,22 +199,43 @@ impl TransformerModel {
let layer_type = self.config.layer_type(layer_idx);

if layer_type == LayerType::FullAttention {
let mut dummy_states: Vec<Box<dyn LayerState>> = (0..k)
.map(|_| layer.alloc_state(self.gpu.as_ref()))
.collect::<Result<_>>()?;
let mut refs: Vec<&mut (dyn LayerState + 'static)> =
dummy_states.iter_mut().map(|s| s.as_mut()).collect();
layer.decode_multi_seq(
hidden,
residual,
k,
&mut refs,
&mut kv_cache,
&seq_lens_vec,
&block_tables_vec,
&ctx,
stream,
)?;
if hss_engaged {
// HSS path: decode_multi_seq's paged-decode kernel
// reads K/V from HBM only, missing the long-context
// history on disk. Fall back to decode_batched
// (sequential single-token decodes via the HSS
// orchestrator). See verify_b.rs for full rationale.
layer.decode_batched(
hidden,
residual,
k,
seq.layer_states[layer_idx].as_mut(),
&mut kv_cache,
seq.seq_len,
&mut seq.block_table,
&mut seq.disk_block_ids,
&mut seq.disk_last_offloaded_per_layer,
&ctx,
stream,
)?;
} else {
let mut dummy_states: Vec<Box<dyn LayerState>> = (0..k)
.map(|_| layer.alloc_state(self.gpu.as_ref()))
.collect::<Result<_>>()?;
let mut refs: Vec<&mut (dyn LayerState + 'static)> =
dummy_states.iter_mut().map(|s| s.as_mut()).collect();
layer.decode_multi_seq(
hidden,
residual,
k,
&mut refs,
&mut kv_cache,
&seq_lens_vec,
&block_tables_vec,
&ctx,
stream,
)?;
}
} else {
layer.decode_batched(
hidden,
Expand Down
53 changes: 37 additions & 16 deletions crates/spark-model/src/model/trait_impl/verify_c2.rs
Original file line number Diff line number Diff line change
Expand Up @@ -189,22 +189,43 @@ impl TransformerModel {
let layer_type = self.config.layer_type(layer_idx);

if layer_type == LayerType::FullAttention {
let mut dummy_states: Vec<Box<dyn LayerState>> = (0..k)
.map(|_| layer.alloc_state(self.gpu.as_ref()))
.collect::<Result<_>>()?;
let mut refs: Vec<&mut (dyn LayerState + 'static)> =
dummy_states.iter_mut().map(|s| s.as_mut()).collect();
layer.decode_multi_seq(
hidden,
residual,
k,
&mut refs,
&mut kv_cache,
&seq_lens_vec,
&block_tables_vec,
&ctx,
stream,
)?;
if hss_engaged {
// HSS path: decode_multi_seq's paged-decode kernel
// reads K/V from HBM only, missing the long-context
// history on disk. Fall back to decode_batched
// (sequential single-token decodes via the HSS
// orchestrator). See verify_b.rs for full rationale.
layer.decode_batched(
hidden,
residual,
k,
seq.layer_states[layer_idx].as_mut(),
&mut kv_cache,
seq.seq_len,
&mut seq.block_table,
&mut seq.disk_block_ids,
&mut seq.disk_last_offloaded_per_layer,
&ctx,
stream,
)?;
} else {
let mut dummy_states: Vec<Box<dyn LayerState>> = (0..k)
.map(|_| layer.alloc_state(self.gpu.as_ref()))
.collect::<Result<_>>()?;
let mut refs: Vec<&mut (dyn LayerState + 'static)> =
dummy_states.iter_mut().map(|s| s.as_mut()).collect();
layer.decode_multi_seq(
hidden,
residual,
k,
&mut refs,
&mut kv_cache,
&seq_lens_vec,
&block_tables_vec,
&ctx,
stream,
)?;
}
} else {
layer.decode_batched(
hidden,
Expand Down
53 changes: 37 additions & 16 deletions crates/spark-model/src/model/trait_impl/verify_d.rs
Original file line number Diff line number Diff line change
Expand Up @@ -195,22 +195,43 @@ impl TransformerModel {
let layer_type = self.config.layer_type(layer_idx);

if layer_type == LayerType::FullAttention {
let mut dummy_states: Vec<Box<dyn LayerState>> = (0..k)
.map(|_| layer.alloc_state(self.gpu.as_ref()))
.collect::<Result<_>>()?;
let mut refs: Vec<&mut (dyn LayerState + 'static)> =
dummy_states.iter_mut().map(|s| s.as_mut()).collect();
layer.decode_multi_seq(
hidden,
residual,
k,
&mut refs,
&mut kv_cache,
&seq_lens_vec,
&block_tables_vec,
&ctx,
stream,
)?;
if hss_engaged {
// HSS path: decode_multi_seq's paged-decode kernel
// reads K/V from HBM only, missing the long-context
// history on disk. Fall back to decode_batched
// (sequential single-token decodes via the HSS
// orchestrator). See verify_b.rs for full rationale.
layer.decode_batched(
hidden,
residual,
k,
seq.layer_states[layer_idx].as_mut(),
&mut kv_cache,
seq.seq_len,
&mut seq.block_table,
&mut seq.disk_block_ids,
&mut seq.disk_last_offloaded_per_layer,
&ctx,
stream,
)?;
} else {
let mut dummy_states: Vec<Box<dyn LayerState>> = (0..k)
.map(|_| layer.alloc_state(self.gpu.as_ref()))
.collect::<Result<_>>()?;
let mut refs: Vec<&mut (dyn LayerState + 'static)> =
dummy_states.iter_mut().map(|s| s.as_mut()).collect();
layer.decode_multi_seq(
hidden,
residual,
k,
&mut refs,
&mut kv_cache,
&seq_lens_vec,
&block_tables_vec,
&ctx,
stream,
)?;
}
} else {
layer.decode_batched(
hidden,
Expand Down
Loading