Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
83 changes: 81 additions & 2 deletions ds4.c
Original file line number Diff line number Diff line change
Expand Up @@ -8545,6 +8545,17 @@ typedef struct {
ds4_gpu_tensor *layer_attn_state_kv[DS4_MAX_LAYER];
ds4_gpu_tensor *layer_attn_state_score[DS4_MAX_LAYER];
ds4_gpu_tensor *layer_index_comp_cache[DS4_MAX_LAYER];
/* HISA: per-layer block representatives (mean-pool of HISA_BLOCK_SIZE
* consecutive index_comp rows). Sized n_blocks_max x 128 floats per
* layer, where n_blocks_max = ceil(layer_comp_cap / 128). Roughly
* 256 KB per layer at 256K ctx cap, allocated alongside the comp
* cache so the decode dispatch can switch into HISA whenever
* n_index_comp crosses the gate. */
ds4_gpu_tensor *layer_hisa_block_reps[DS4_MAX_LAYER];
/* HISA scratch: top-m block indices and block-scores, shared across
* layers (sized once for the largest layer's block count). */
ds4_gpu_tensor *hisa_sel_blocks;
ds4_gpu_tensor *hisa_block_scores;
ds4_gpu_tensor *layer_index_state_kv[DS4_MAX_LAYER];
ds4_gpu_tensor *layer_index_state_score[DS4_MAX_LAYER];

Expand Down Expand Up @@ -8825,6 +8836,11 @@ static void metal_graph_free(ds4_gpu_graph *g) {
for (uint32_t il = 0; il < DS4_N_LAYER; il++) {
ds4_gpu_tensor_free(g->layer_index_comp_cache[il]);
}
for (uint32_t il = 0; il < DS4_N_LAYER; il++) {
ds4_gpu_tensor_free(g->layer_hisa_block_reps[il]);
}
ds4_gpu_tensor_free(g->hisa_sel_blocks);
ds4_gpu_tensor_free(g->hisa_block_scores);
for (uint32_t il = 0; il < DS4_N_LAYER; il++) {
ds4_gpu_tensor_free(g->layer_index_state_kv[il]);
}
Expand Down Expand Up @@ -9279,6 +9295,28 @@ static bool metal_graph_alloc_raw_cap(
g->layer_index_comp_cache[il] = metal_graph_alloc_kv_cache_tensor(
managed_kv_cache,
(uint64_t)g->layer_comp_cap[il] * DS4_N_INDEXER_HEAD_DIM * sizeof(float));
/* HISA scratch: per-layer block reps plus the shared
* selection and score scratches. n_blocks_max =
* ceil(layer_comp_cap / 128). Small enough at any
* sensible ctx (~256 KB per layer at 256K cap) that we
* always allocate; the decode dispatch decides whether
* to actually use HISA based on n_index_comp at runtime. */
{
const uint32_t n_blocks_max =
(g->layer_comp_cap[il] + 127u) / 128u;
if (n_blocks_max > 0u) {
g->layer_hisa_block_reps[il] = ds4_gpu_tensor_alloc(
(uint64_t)n_blocks_max * DS4_N_INDEXER_HEAD_DIM * sizeof(float));
if (!g->hisa_sel_blocks) {
g->hisa_sel_blocks = ds4_gpu_tensor_alloc(
(uint64_t)128u * sizeof(uint32_t));
}
if (!g->hisa_block_scores) {
g->hisa_block_scores = ds4_gpu_tensor_alloc(
(uint64_t)n_blocks_max * sizeof(float));
}
}
}
g->layer_index_state_kv[il] = ds4_gpu_tensor_alloc(index_width * index_rows * sizeof(float));
g->layer_index_state_score[il] = ds4_gpu_tensor_alloc(index_width * index_rows * sizeof(float));
if (enable_mtp) {
Expand Down Expand Up @@ -10200,14 +10238,55 @@ static bool metal_graph_encode_decode_layer(
g->layer_n_index_comp[il],
&decode_index_stage_t0);
}
if (ok) ok = ds4_gpu_indexer_score_one_tensor(g->indexer_scores,
/* HISA hierarchical indexer dispatch. The flat indexer
* walks every compressed row; once that row count is
* large enough, swapping in HISA's coarse + refine pair
* is cheaper. A zero return from the launcher (missing
* allocation, bad arguments) falls through to the flat
* path so the existing behavior is the safe default.
*
* Gate threshold: 49152 rows is roughly 196K ctx at
* ratio 4. Below that the block-rep rebuild plus the
* top-m selection cost exceeds the refine savings, so
* we keep using the flat indexer. Top-m = 64 was the
* smallest value that preserved >99% top-K IoU vs flat
* across every KV dtype tested. Block size matches
* the kernel-side DS4_CUDA_HISA_BLOCK_SIZE constant. */
if (ok) {
const uint32_t n_index_comp = g->layer_n_index_comp[il];
bool used_hisa = false;
if (g->layer_hisa_block_reps[il] != NULL &&
g->hisa_sel_blocks != NULL &&
g->hisa_block_scores != NULL &&
n_index_comp >= 49152u) {
const uint32_t n_blocks = (n_index_comp + 127u) / 128u;
if (ds4_gpu_hisa_score_one_tensor(g->indexer_scores,
g->hisa_sel_blocks,
g->hisa_block_scores,
g->indexer_q,
g->indexer_weights,
g->layer_hisa_block_reps[il],
g->layer_index_comp_cache[il],
n_index_comp,
n_blocks,
n_blocks,
n_index_comp,
64u,
index_scale) != 0) {
used_hisa = true;
}
}
if (!used_hisa) {
ok = ds4_gpu_indexer_score_one_tensor(g->indexer_scores,
g->indexer_q,
g->indexer_weights,
g->layer_index_comp_cache[il],
g->layer_n_index_comp[il],
n_index_comp,
DS4_N_INDEXER_HEAD,
DS4_N_INDEXER_HEAD_DIM,
index_scale) != 0;
}
}
if (ok && decode_index_stage_profile) {
ok = metal_graph_indexer_stage_profile_boundary("decode_score",
il,
Expand Down
Loading