Add HISA hierarchical indexer for long-context decode#258
Open
TheTom wants to merge 1 commit into
Open
Conversation
c49488a to
7132b83
Compare
The indexer score scan walks every compressed row at decode-token cost O(n_comp). At 256K context that is ~65K rows per layer per token and the scan starts to outweigh the actual sparse attention behind it. HISA (arxiv 2603.28458) replaces the flat scan with a two-stage walk: score block representatives (mean of 128 consecutive rows) coarsely, pick the top-m blocks, then refine inside those blocks. Cost drops to O(n_blocks + m * 128). Output uses the same per-row scores layout (non-candidate rows are -INF) so the downstream top-K kernel runs unchanged. CUDA kernels live alongside the existing indexer kernels; Metal stubs return zero so the backend falls through to the flat indexer. The per-layer block_reps buffers and the shared sel_blocks / block_scores scratches are allocated alongside the comp cache; the cost is small (roughly 256 KB per layer at 256K ctx cap, ~6 MB across all layers). Dispatch gates on n_index_comp >= 49152 (about 196K context at ratio 4). Below that the rebuild and top-m fixed costs exceed the refine savings; rough numbers on a GB10 Spark with Qwen3.6-A3B IQ2XXS: 64K (n_comp ~16K) HISA -4.7% (gate skips it, flat indexer runs) 128K (n_comp ~32K) HISA -1.8% (gate skips it, flat indexer runs) 256K (n_comp ~65K) HISA +2.3% turbo3+comp, +2.7% fp8, +1.7% turbo4 Perplexity at 64 scored tokens is unchanged (107.19 with and without HISA) and the >99% top-K IoU claim from the paper holds across all KV dtypes tested. The implementation is v1 simple: block reps are recomputed in full on every dispatch. An incremental update covering only the last partial block on each compressor emit is a straightforward follow-up.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
Add this suggestion to a batch that can be applied as a single commit.This suggestion is invalid because no changes were made to the code.Suggestions cannot be applied while the pull request is closed.Suggestions cannot be applied while viewing a subset of changes.Only one suggestion per line can be applied in a batch.Add this suggestion to a batch that can be applied as a single commit.Applying suggestions on deleted lines is not supported.You must change the existing code in this line in order to create a valid suggestion.Outdated suggestions cannot be applied.This suggestion has been applied or marked resolved.Suggestions cannot be applied from pending reviews.Suggestions cannot be applied on multi-line comments.Suggestions cannot be applied while the pull request is queued to merge.Suggestion cannot be applied right now. Please check back later.
Adds a HISA-style hierarchical indexer (arxiv 2603.28458) to the
decode-token path of the existing per-layer compressed indexer. The
flat indexer walks every compressed row at decode-token, which becomes
the dominant decode cost at long context. HISA replaces it with a
block-coarse pass over mean-pooled block representatives, a top-m block
selection, and a token-refine pass restricted to the selected blocks.
Output uses the same per-row scores layout (non-candidate rows are
-INF) so the existing top-K kernel runs unchanged. No new flags;the dispatch decides at runtime which indexer to use.
Runtime gate
HISA fires when
n_index_comp >= 49152(roughly 196K context at ratio4). Below that the block-rep rebuild plus the top-m selection cost
exceed the refine savings, so the dispatch routes through the existing
flat path and behavior is identical to main.
Bench
GB10 (ASUS Ascent, sm_121, 128 GB), Qwen3.6-A3B IQ2XXS,
ds4-bench --backend cuda --kv-cache turbo3 --comp-cache turbo3with the inline-dequant comp_kv path on. Raw
--csvfrom this branch is atspeed-bench/hisa/gb10_spark.csv:The 64K row confirms zero regression when the gate keeps HISA off;
the 256K row is the long-context point where HISA replaces the flat
scan. Parent-commit baseline at the same 256K +
turbo3 --comp-cache turbo3config measuredgen_tps = 7.47on the same session, so theon/off delta at 256K is
+1.9%. Companion before/after CSVs acrossfp8,turbo4, and the canonical--gen-tokens 128 --step-incr 16384sweep are queued and will be added to
speed-bench/hisa/.prefill_tpsis unchanged at both ctx points; HISA is a decode-tokenoptimization and the prefill batched-attention path is untouched.
Quality
Teacher-forced PPL on the same model and prompt at 64 scored tokens:
Identical to the parent-commit baseline at the same configuration. The
HISA paper's >99% top-K IoU vs flat held in every configuration tested
on this session.
What's in the diff
ds4_cuda.cu: five kernels (block-rep mean-pool, block scores, top-mselection, refine scores, scores init) plus two launchers
(
ds4_gpu_hisa_block_rep_update_tensor,ds4_gpu_hisa_score_one_tensor).Block-size and head-dim constants are added to the existing
DS4_CUDA_*enum. Gate threshold and top-m count live with thedispatch in
ds4.c.ds4_gpu.h: API declarations.ds4_metal.m: stubs that return zero so the Metal backend falls backto the flat indexer; the Metal port is deferred to a follow-up PR.
ds4.c: graph state (layer_hisa_block_reps[],hisa_sel_blocks,hisa_block_scores), allocation alongside the comp cache, free, andthe decode-token dispatch site that routes through HISA when
n_index_compis over the gate.speed-bench/hisa/: raw--csvoutput from this branch plus aREADME describing the runs and the queued follow-up sweep.
Memory cost
block_repsisceil(layer_comp_cap / 128) * 128floats per layer,plus a shared
sel_blocks[128] uint32and oneblock_scores[n_blocks_max] floatscratch. At 256K ctx cap that is roughly 256 KB per layer andabout 6 MB across all layers, negligible against the comp cache itself.
Implementation notes
block-scores in the coarse stage and 64 selected blocks for refine
(top-m = 64, recommended by the paper).
visible block per the HISA recency rule.
indexer_score_one_direct_kernel(per-head ReLU dot, per-headweight, scale), so quality matches paper expectations.
update covering only the last partial block on each compressor
emit is the natural follow-up; the rebuild cost at 256K is already
a small fraction of the refine savings.
Tests
Mac Metal exercises the flat-indexer fallback (HISA launchers stub to
zero) so the Metal build and kernel checks cover the unchanged path.
GB10 long-context is the gate that actually exercises HISA at runtime
since the fact-recall prompt drives
n_index_comppast 49152 on thedeeper layers; it runs through the new dispatch path rather than the
flat indexer.
Status
Draft for a first review pass. CUDA-only; the Metal stubs intentionally
return zero so the existing flat indexer continues to handle Metal.
The full per-dtype before/after CSV sweep is queued and will land in
speed-bench/hisa/before this leaves draft.Related: #243 (TurboQuant+ 3-bit KV cache). This PR's bench is taken
with
--kv-cache turbo3 --comp-cache turbo3on a build that includes#243; HISA itself is dtype-agnostic since the indexer operates on the
float
index_compcache, unchanged across KV dtypes.