Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
99 commits
Select commit Hold shift + click to select a range
7ef12dd
docs(tests): update SINGLE_GPU_RESULTS with full investigation findings
claude May 11, 2026
a1899dd
feat(spark-server): inject tool-call parser system prompt (#57)
aceangel3k May 13, 2026
13009b6
fix(mla): use HDIM=128 kernel in cache-skip MLA prefill path
claude May 15, 2026
952ca76
docs(tests): correct Nemotron tool-calling and system_prompt() findings
claude May 15, 2026
f892f58
fix(nemotron-super): skip contradictory template tool injection for b…
claude May 15, 2026
f6161c1
fix(mistral-loader): MLA loader defaults kv_dtype to Bf16, not Fp8
claude May 15, 2026
7ce0a27
fix(mla-prefill): replace HDIM=256 kernel with mla_fused_prefill for …
claude May 15, 2026
3f673d4
fix(mla): replace HDIM=256 flash attn with mla_fused_prefill for Mist…
claude May 15, 2026
1d07183
fix(mla-prefill): add HDIM=128 flash-attn kernels for Mistral Small 4…
claude May 15, 2026
bc25fbc
fix(mla): use HDIM=128 kernel for Mistral Small 4 unabsorbed prefill
claude May 16, 2026
77438b7
fix(mla-prefill): guard against silent HDIM=256 fallback for MLA hd<=128
claude May 16, 2026
eed6190
fix(mla): use absorbed-space scale 1/sqrt(320) in all MLA attention p…
claude May 16, 2026
b274150
fix(mla): cross-chunk paged attention for Mistral MLA multi-chunk pre…
claude May 17, 2026
345c3b2
fix(mla): move smem_dot declaration outside kv_pos loop in mla_fused_…
claude May 17, 2026
427104f
fix: BF16 KV cache not applied to MLA layers + SSM pool over-allocation
claude May 18, 2026
767af31
fix(anthropic): count_tokens inflates token count for skip_template_t…
claude May 18, 2026
24375ac
fix(tests): update kv_dtypes test to match BF16-hardening behavior
claude May 18, 2026
08214f9
docs(tests): 2026-05-19 verification — all P0/P1 bugs confirmed fixed
claude May 19, 2026
0f72e45
docs(tests): 2026-05-20 independent re-verification of all spec_ssm f…
claude May 20, 2026
3145c08
docs(tests): 2026-05-20 final investigation audit — all fixes verifie…
claude May 20, 2026
6b6e755
feat(tool-parser): add suppresses_jinja_tools() parser-level trait me…
claude May 20, 2026
5721593
docs(tests): 2026-05-21 re-audit — all fixes confirmed, suppresses_ji…
claude May 21, 2026
22ae45f
docs(tests): 2026-05-21 independent re-investigation — all fixes veri…
claude May 21, 2026
394cb97
docs(tests): 2026-05-21 final verification — all fixes traced to exac…
claude May 21, 2026
2993894
fix(anthropic): count_tokens missing suppresses_jinja_tools() check +…
claude May 21, 2026
ac64e99
docs(tests): add 2026-05-22 independent verification session to SINGL…
claude May 22, 2026
2bf1da8
docs(tests): add 2026-05-22 second verification — all P0/P1/P2 fixes …
claude May 22, 2026
2f9c5f4
docs(tests): add 2026-05-22 third verification — inferspark_prefill_6…
claude May 22, 2026
465add3
docs(tests): add 2026-05-23 verification — buffer sizing, dimensions,…
claude May 23, 2026
1f07ff8
site: sparkrun quickstart + SSOT-driven model carousel (#68)
tbraun96 May 18, 2026
67285fa
site: 3-level tabbed model nav (vendor/sub-family/recipe) + X links +…
tbraun96 May 18, 2026
edb944e
fix: default --mtp-quantization to bf16 (#61)
tbraun96 May 19, 2026
f37bc10
fix(ssm): use try_kernel + warn for optional FP32 conv1d (#76)
pragmaxim May 20, 2026
dce4d35
fix(ci): fix CLA bot signature recognition and writing failures (#77)
tbraun96 May 20, 2026
0bc4afc
docs: remove dangling SSM_CATASTROPHIC_FORGETTING_TODO.md refs (#75)
pragmaxim May 20, 2026
40e81da
spark-server: add --default-chat-template-kwargs CLI flag (#79)
Marker689 May 21, 2026
3b848cc
refactor(mla): remove unreachable else-if branch in cache_skip.rs
claude May 23, 2026
91ce063
docs(tests): add 2026-05-23 re-investigation — YaRN misdiagnosis conf…
claude May 23, 2026
7265f7f
docs(tests): add 2026-05-24 re-investigation — all fixes confirmed, Y…
claude May 24, 2026
0114743
docs(tests): add 2026-05-24 second-pass verification — all findings c…
claude May 24, 2026
59a55d5
docs(tests): third-pass verification — main vs spec_ssm cross-check (…
claude May 24, 2026
fd2e919
docs(tests): fourth-pass verification — all fixes confirmed at HEAD 5…
claude May 25, 2026
426f7c8
docs(tests): fifth-pass verification — all fixes confirmed at HEAD fd…
claude May 25, 2026
0948d48
docs(tests): sixth-pass verification — all fixes confirmed at HEAD 42…
claude May 25, 2026
5af74d6
docs(tests): seventh-pass verification — all fixes confirmed at HEAD …
claude May 25, 2026
080ef06
docs(tests): eighth-pass verification — all fixes confirmed at HEAD 5…
claude May 25, 2026
a82ba4a
docs(tests): ninth-pass verification — cross-branch main-vs-spec_ssm …
claude May 26, 2026
8a285cb
docs(tests): tenth-pass verification — session context reconciliation…
claude May 27, 2026
74c6d4c
docs(tests): eleventh-pass verification — warp-reduction correctness …
claude May 27, 2026
e7de0f4
fix(mla-prefill): respect kv_write_start in cache_skip MLA KV cache w…
claude May 27, 2026
7fe0788
docs(tests): twelfth-pass verification — kv_write_start MLA cache fix…
claude May 27, 2026
ba5f40f
docs(tests): thirteenth-pass independent audit — all P1/P2/P3 fixes r…
claude May 27, 2026
0b89988
fix(mla_prefill_attn): use half-warp masks to avoid CUDA warp-sync UB
claude May 27, 2026
ebe5b36
fix(mla_prefill_paged_320): use half-warp masks to eliminate CUDA war…
claude May 28, 2026
b2b51f9
docs(tests): sixteenth-pass independent audit — all P1/P2/P3 fixes re…
claude May 28, 2026
1885142
docs(tests): seventeenth-pass independent audit — all P1/P2/P3 fixes …
claude May 28, 2026
2664d14
docs(tests): eighteenth-pass independent audit — all P1/P2/P3 fixes r…
claude May 28, 2026
617bc6e
docs(tests): nineteenth-pass independent audit — all P1/P2/P3 fixes r…
claude May 29, 2026
bda98c5
docs(tests): twentieth-pass independent audit — all P1/P2/P3 fixes re…
claude May 29, 2026
fd1fb9d
docs(tests): twenty-first-pass independent audit — all P1/P2/P3 fixes…
claude May 29, 2026
2d6e810
docs(tests): twenty-second-pass independent audit — all P1/P2/P3 fixe…
claude May 29, 2026
4002085
docs(tests): twenty-third-pass independent audit — all P1/P2/P3 fixes…
claude May 29, 2026
cbb2c08
docs(tests): twenty-fourth-pass audit — P1 root cause re-traced via g…
claude May 29, 2026
3d675ee
docs(tests): twenty-fifth-pass — fresh independent P1/P2/P3 investiga…
claude May 29, 2026
f349662
docs(tests): twenty-sixth-pass — fresh independent P1/P2/P3 investiga…
claude May 29, 2026
9e07ef9
docs(tests): twenty-seventh-pass — fresh session cold-start audit, al…
claude May 29, 2026
d4d222e
docs(tests): twenty-eighth-pass — full audit at 9e07ef9, kv_dtypes no…
claude May 30, 2026
4597624
docs(tests): twenty-ninth-pass — fresh audit at d4d222e, all fixes re…
claude May 30, 2026
eb54c20
docs(tests): thirtieth-pass — full audit at 4597624, all fixes re-ver…
claude May 30, 2026
1f84817
docs(tests): thirty-first-pass — cross-branch audit confirms all P1/P…
claude May 30, 2026
149e2c7
docs(tests): thirty-second-pass — full end-to-end audit, all P1/P2/P3…
claude May 30, 2026
6b742af
docs(tests): correct summary table — Mistral P1 root cause is BF16 KV…
claude May 30, 2026
7dd9233
docs(tests): thirty-fourth-pass — fresh cold-start audit, all P1/P2/P…
claude May 30, 2026
82e8803
docs(tests): thirty-fifth-pass — independent P1/P2/P3 audit, all fixe…
claude May 30, 2026
0e885ef
docs(tests): thirty-sixth-pass — independent P1/P2/P3 cold-start audi…
claude May 30, 2026
782a5a5
docs(tests): thirty-seventh-pass — full cold-start audit per task spe…
claude May 31, 2026
a893a79
docs(tests): thirty-eighth-pass — kernel dead-code audit, warp-mask r…
claude May 31, 2026
9d8f1f4
docs(tests): thirty-ninth-pass — complete independent end-to-end veri…
claude May 31, 2026
5069e3a
docs(tests): fortieth-pass — post-compaction re-verification at HEAD …
claude May 31, 2026
bfc9fe0
docs(tests): forty-first-pass — fresh cold-start re-investigation, al…
claude May 31, 2026
162f9a3
docs(tests): forty-second-pass — independent cold-start audit, all P1…
claude May 31, 2026
5f2c7b0
docs(tests): forty-third-pass — independent P1/P2/P3 cold-start audit…
claude May 31, 2026
525bceb
docs(tests): forty-fourth-pass — full independent P1/P2/P3 cold-start…
claude May 31, 2026
1b63a0c
docs(tests): forty-fifth-pass — cold-start audit traces git diff, all…
claude May 31, 2026
91bae54
docs(tests): forty-sixth-pass — 2026-06-01 cold-start audit, all P1/P…
claude Jun 1, 2026
591fd0e
docs(tests): forty-seventh-pass — fresh cross-branch audit, all P1/P2…
claude Jun 1, 2026
c2e7c09
docs(tests): forty-eighth-pass — fresh cold-start audit, doc inaccura…
claude Jun 1, 2026
bae5408
docs(tests): forty-ninth-pass — independent cold-start investigation,…
claude Jun 1, 2026
322a28b
docs(tests): fiftieth-pass — cold-start audit traces git diff, all P1…
claude Jun 1, 2026
3b8a01a
docs(tests): fifty-first-pass — final session verification, header co…
claude Jun 1, 2026
dd3894d
docs(tests): fifty-second-pass — independent June 2026 audit, all P1/…
claude Jun 1, 2026
bed92bf
docs(tests): fifty-third-pass — 2026-06-01 independent investigation,…
claude Jun 1, 2026
286a9f2
docs(tests): fifty-fourth-pass — P1/P2/P3 audit; correct YaRN misattr…
claude Jun 2, 2026
ec8c377
docs(tests): fifty-fifth-pass — independent cold-start audit, all P1/…
claude Jun 2, 2026
a634442
docs(tests): 2026-06-02 independent audit — all P1/P2/P3 fixes confirmed
claude Jun 2, 2026
266b333
docs(tests): fifty-sixth-pass — fresh cold-start audit, all P1/P2/P3 …
claude Jun 2, 2026
e403918
docs(tests): fifty-seventh-pass — fresh cold-start audit, all P1/P2/P…
claude Jun 2, 2026
47ca5b9
docs(tests): 2026-06-02 audit — add item 13 (warp-mask UB fix) to act…
claude Jun 2, 2026
7a0ed4e
docs(tests): 2026-06-02 final independent audit — all P1/P2/P3 fixes …
claude Jun 2, 2026
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 3 additions & 3 deletions .github/workflows/cla.yml
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ jobs:
actions: write
steps:
- name: "CLA Assistant"
if: (github.event.comment.body == 'recheck' || github.event.comment.body == 'I have read the CLA Document and I hereby sign the CLA') || github.event_name == 'pull_request_target'
if: (contains(github.event.comment.body, 'recheck') || contains(github.event.comment.body, 'I have read the CLA Document and I hereby sign the CLA')) || github.event_name == 'pull_request_target'
# Alpha Release
uses: contributor-assistant/github-action@v2.6.1
env:
Expand All @@ -35,8 +35,8 @@ jobs:
with:
path-to-signatures: 'signatures/version1/cla.json'
path-to-document: 'https://github.com/${{ github.repository }}/blob/main/CLA.md'
branch: 'main'
allowlist: dependabot[bot],renovate[bot]
branch: 'cla-signatures'
allowlist: dependabot[bot],renovate[bot],google-labs-jules[bot],claude[bot],google-labs-jules,claude
lock-pullrequest-aftermerge: false
create-file-commit-message: 'chore: setup CLA signatures file'
signed-commit-message: 'chore: $username CLA signature added'
Expand Down
1 change: 0 additions & 1 deletion book/src/deep-dives/ssm.md
Original file line number Diff line number Diff line change
Expand Up @@ -94,5 +94,4 @@ Each step calls into `spark-runtime::GpuBackend` via the layer's cached `KernelH
- `kernels/gb10/<model>/<quant>/ssm_preprocess.cu`, `gdr.cu`, `causal_conv1d.cu`
- `crates/spark-model/src/layers/qwen3_ssm.rs`, `nemotron_mamba2.rs`
- `crates/spark-runtime/src/prefix_cache.rs` (Marconi SSM snapshot)
- `docs/history/SSM_CATASTROPHIC_FORGETTING_TODO.md`
- README "Atlas Spark" section — the SSM/GDN story in narrative form
3 changes: 3 additions & 0 deletions crates/atlas-kernels/build.rs
Original file line number Diff line number Diff line change
Expand Up @@ -74,6 +74,7 @@ struct Target {
behavior_disable_tool_steering: bool,
behavior_tool_call_parser: String,
behavior_enable_loop_watchdog: bool,
behavior_skip_template_tools: bool,
/// Which `(model_type, hidden_size)` pairs this kernel target supports.
/// Parsed from `[[model_types]]` in MODEL.toml.
model_type_matches: Vec<ModelTypeMatch>,
Expand Down Expand Up @@ -362,6 +363,7 @@ fn resolve_targets(workspace_root: &std::path::Path) -> Vec<Target> {
b_disable_tool_steering,
b_tool_call_parser,
b_enable_loop_watchdog,
b_skip_template_tools,
) = parse_behavior(&model_dir);
let model_type_matches = parse_model_types(&model_dir);
let dflash = parse_dflash(&model_dir);
Expand Down Expand Up @@ -392,6 +394,7 @@ fn resolve_targets(workspace_root: &std::path::Path) -> Vec<Target> {
behavior_disable_tool_steering: b_disable_tool_steering,
behavior_tool_call_parser: b_tool_call_parser,
behavior_enable_loop_watchdog: b_enable_loop_watchdog,
behavior_skip_template_tools: b_skip_template_tools,
model_type_matches,
dflash,
});
Expand Down
2 changes: 2 additions & 0 deletions crates/atlas-kernels/build_codegen.rs
Original file line number Diff line number Diff line change
Expand Up @@ -168,6 +168,7 @@ pub(super) fn generate_target_ptx_rs(
\x20 disable_tool_steering: {},\n\
\x20 tool_call_parser: \"{}\",\n\
\x20 enable_loop_watchdog: {},\n\
\x20 skip_template_tools: {},\n\
\x20 }},\n\
\x20 model_type_matches: vec![{}],\n\
\x20 dflash: {},\n\
Expand All @@ -186,6 +187,7 @@ pub(super) fn generate_target_ptx_rs(
target.behavior_disable_tool_steering,
target.behavior_tool_call_parser,
target.behavior_enable_loop_watchdog,
target.behavior_skip_template_tools,
target.model_type_matches.iter().map(|m| {
let hs = match m.hidden_size {
Some(v) => format!("Some({v})"),
Expand Down
11 changes: 9 additions & 2 deletions crates/atlas-kernels/build_parse.rs
Original file line number Diff line number Diff line change
Expand Up @@ -120,10 +120,10 @@ pub(super) fn parse_sampling_presets(
}

/// Parse [behavior] from MODEL.toml. Returns
/// (thinking_in_tools, max_thinking_budget, thinking_default, fp8_kv_calibration_tokens, default_kv_dtype, default_num_drafts, disable_tool_steering, tool_call_parser, enable_loop_watchdog).
/// (thinking_in_tools, max_thinking_budget, thinking_default, fp8_kv_calibration_tokens, default_kv_dtype, default_num_drafts, disable_tool_steering, tool_call_parser, enable_loop_watchdog, skip_template_tools).
pub(super) fn parse_behavior(
model_dir: &std::path::Path,
) -> (bool, u32, bool, usize, String, u32, bool, String, bool) {
) -> (bool, u32, bool, usize, String, u32, bool, String, bool, bool) {
let model_toml_path = model_dir.join("MODEL.toml");
if !model_toml_path.exists() {
return (
Expand All @@ -136,6 +136,7 @@ pub(super) fn parse_behavior(
false,
String::new(),
false,
false,
);
}
let content = std::fs::read_to_string(&model_toml_path).unwrap_or_default();
Expand All @@ -152,6 +153,7 @@ pub(super) fn parse_behavior(
false,
String::new(),
false,
false,
);
}
};
Expand Down Expand Up @@ -197,6 +199,10 @@ pub(super) fn parse_behavior(
.and_then(|v| v.get("enable_loop_watchdog"))
.and_then(|v| v.as_bool())
.unwrap_or(false);
let skip_template_tools = b
.and_then(|v| v.get("skip_template_tools"))
.and_then(|v| v.as_bool())
.unwrap_or(false);
(
thinking_in_tools,
max_thinking_budget,
Expand All @@ -207,6 +213,7 @@ pub(super) fn parse_behavior(
disable_tool_steering,
tool_call_parser,
enable_loop_watchdog,
skip_template_tools,
)
}

Expand Down
14 changes: 14 additions & 0 deletions crates/atlas-kernels/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -163,6 +163,19 @@ pub struct ModelBehavior {
/// JSON arrays of similar objects, multiplication tables). Enable only
/// when the model has been observed to need it.
pub enable_loop_watchdog: bool,
/// When true, do not pass tool definitions to the Jinja chat template
/// (`jinja_tools` stays `None`). Use this for models where the tool-call
/// parser already injects a complete system-prompt with tool schemas and
/// format instructions, and the template's own XML tool rendering would
/// produce contradictory instructions.
///
/// Example: Nemotron-Super-120B uses `bare_json` grammar (parser emits
/// JSON-schema + bare-JSON instructions) while the `nemotron_h.jinja`
/// template would additionally render XML `<function>` blocks and tell
/// the model to output `<tool_call>` XML — the opposite format. Setting
/// `skip_template_tools = true` suppresses the template rendering and
/// leaves the parser's instructions as the sole tool-format signal.
pub skip_template_tools: bool,
}

impl Default for ModelBehavior {
Expand All @@ -177,6 +190,7 @@ impl Default for ModelBehavior {
disable_tool_steering: false,
tool_call_parser: "",
enable_loop_watchdog: false,
skip_template_tools: false,
}
}
}
Expand Down
34 changes: 34 additions & 0 deletions crates/spark-model/src/layers/ops/kv_cache.rs
Original file line number Diff line number Diff line change
Expand Up @@ -166,6 +166,40 @@ pub fn mla_batched_gemv(
.launch(stream)
}

/// Launch the batched V-extraction kernel for N-token MLA prefill.
///
/// Kernel contract, per (token, head) pair:
/// `output[token, head, :] = W_UV[head] @ input[token, head, 0..k]`.
/// Every head spans `input_head_stride` elements of `input`, but only the
/// first `k` of them are read.
///
/// Launch shape: grid `(ceil(n_out/8), num_heads, n_tokens)`, block `(256, 1, 1)`.
///
/// Note that `n_tokens` is not forwarded as a kernel argument; it only sets
/// the z extent of the grid.
#[allow(clippy::too_many_arguments)]
pub fn mla_v_extract_batched(
    gpu: &dyn GpuBackend,
    kernel: KernelHandle,
    input: DevicePtr,
    weight: DevicePtr,
    output: DevicePtr,
    n_out: u32,
    k: u32,
    num_heads: u32,
    input_head_stride: u32,
    output_head_stride: u32,
    n_tokens: u32,
    stream: u64,
) -> Result<()> {
    // x: output elements in groups of 8, y: one slot per head, z: one slot per token.
    let grid_dims = [div_ceil(n_out, 8), num_heads, n_tokens];

    // The builder is kept as a single expression; argument order must match
    // the kernel's parameter list exactly.
    KernelLaunch::new(gpu, kernel)
        .grid(grid_dims)
        .block([256, 1, 1])
        // Device buffers.
        .arg_ptr(input)
        .arg_ptr(weight)
        .arg_ptr(output)
        // Problem dimensions and per-head strides.
        .arg_u32(n_out)
        .arg_u32(k)
        .arg_u32(num_heads)
        .arg_u32(input_head_stride)
        .arg_u32(output_head_stride)
        .launch(stream)
}

/// MLA Q_rope scatter: copy rope portion from q_full to strided q_absorbed_buf. 1 kernel replaces 32 D2D copies.
#[allow(clippy::too_many_arguments)]
pub fn mla_q_rope_scatter(
Expand Down
45 changes: 45 additions & 0 deletions crates/spark-model/src/layers/ops/prefill_attn_a.rs
Original file line number Diff line number Diff line change
Expand Up @@ -284,6 +284,51 @@ pub fn mla_prefill_attention_320(
.launch(stream)
}

/// Launch paged MLA prefill attention in absorbed form (HDIM=320) for the
/// multi-chunk case (seq_len_start > 0).
///
/// `q` is laid out as `[q_len, nq, 320]` and attends to the paged KV cache
/// over `kv_len` tokens with causal masking: the query at local position `i`
/// (global position `q_offset + i`) attends to KV positions `0..q_offset+i`.
///
/// Launch shape: grid `(num_q_heads, ceil(q_len/16), 1)`, block `(256, 1, 1)`.
#[allow(clippy::too_many_arguments)]
pub fn mla_prefill_paged_320(
    gpu: &dyn GpuBackend,
    kernel: KernelHandle,
    q: DevicePtr,
    k_cache: DevicePtr,
    v_cache: DevicePtr,
    output: DevicePtr,
    block_table: DevicePtr,
    q_len: u32,
    kv_len: u32,
    q_offset: u32,
    num_q_heads: u32,
    num_kv_heads: u32,
    head_dim: u32,
    cache_block_size: u32,
    inv_sqrt_d: f32,
    stream: u64,
) -> Result<()> {
    // Query rows handled per thread block; mirrors MLA_BR in the kernel.
    const MLA_BR: u32 = 16;

    // One grid column per query head, one grid row per tile of MLA_BR queries.
    let q_tiles = div_ceil(q_len, MLA_BR);
    let grid_dims = [num_q_heads, q_tiles, 1];

    // The builder is kept as a single expression; argument order must match
    // the kernel's parameter list exactly.
    KernelLaunch::new(gpu, kernel)
        .grid(grid_dims)
        .block([256, 1, 1])
        // Device buffers.
        .arg_ptr(q)
        .arg_ptr(k_cache)
        .arg_ptr(v_cache)
        .arg_ptr(output)
        .arg_ptr(block_table)
        // Sequence geometry.
        .arg_u32(q_len)
        .arg_u32(kv_len)
        .arg_u32(q_offset)
        // Head layout and paging parameters.
        .arg_u32(num_q_heads)
        .arg_u32(num_kv_heads)
        .arg_u32(head_dim)
        .arg_u32(cache_block_size)
        // Softmax scale.
        .arg_f32(inv_sqrt_d)
        .launch(stream)
}

pub fn paged_decode_attn_bf16(
gpu: &dyn GpuBackend,
kernel: KernelHandle,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -372,7 +372,9 @@ impl Qwen3AttentionLayer {

// Step 8: Paged decode attention
let attn_out = ctx.buffers.attn_output();
let inv_sqrt_d = self.effective_attn_scale(hd);
// Absorbed MLA decode operates in (kv_lora+rope)-dim space; 1/sqrt(hd=128)
// would over-sharpen softmax vs the correct 1/sqrt(kv_lora+rope=320).
let inv_sqrt_d = 1.0f32 / ((kv_lora + mla_rope) as f32).sqrt();
prof!("paged_attn", {
ops::paged_decode_attn_bf16(
ctx.gpu,
Expand Down
20 changes: 20 additions & 0 deletions crates/spark-model/src/layers/qwen3_attention/init.rs
Original file line number Diff line number Diff line change
Expand Up @@ -307,6 +307,16 @@ impl Qwen3AttentionLayer {
"mla_fused_prefill",
"mla_fused_prefill",
),
mla_prefill_paged_k: super::super::try_kernel(
gpu,
"mla_prefill_paged",
"mla_prefill_paged_320",
),
mla_v_extract_batched_k: super::super::try_kernel(
gpu,
"mla_absorbed",
"mla_v_extract_batched",
),
gemm_splitk_partial_k: super::super::try_kernel(
gpu,
"gemm_splitk",
Expand Down Expand Up @@ -391,6 +401,16 @@ impl Qwen3AttentionLayer {
"inferspark_prefill_paged_512",
),
prefill_attn_64_k: gpu.kernel("inferspark_prefill", "inferspark_prefill_64")?,
prefill_attn_128_k: super::super::try_kernel(
gpu,
"inferspark_prefill_128",
"inferspark_prefill_hd128",
),
prefill_attn_64_128_k: super::super::try_kernel(
gpu,
"inferspark_prefill_128",
"inferspark_prefill_64_hd128",
),
prefill_attn_paged_k: gpu.kernel("prefill_paged", "inferspark_prefill_paged")?,
prefill_attn_paged_fp8_k: gpu
.kernel("prefill_paged_fp8", "inferspark_prefill_paged_fp8")?,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -89,16 +89,13 @@ impl Qwen3AttentionLayer {
if self.mla.is_some() {
let args = super::cache_skip_mla::CacheSkipMlaArgs {
normed,
num_tokens,
n,
h,
nq,
nkv,
hd,
kv_dim,
eps,
bf16,
stream,
kv_write_start,
};
return self.prefill_attention_cache_skip_mla(kv_cache, ctx, &args);
}
Expand Down Expand Up @@ -143,22 +140,6 @@ impl Qwen3AttentionLayer {
q_proj_dim as u32,
stream,
)?;
} else if self.mla.is_some() {
// DIAGNOSTIC: check V BEFORE Q copy
if self.attn_layer_idx == 0 && ctx.config.model_type == "mistral" {
ctx.gpu.synchronize(stream)?;
let v_chk = k_contiguous.offset(num_tokens * kv_dim * bf16);
crate::layers::qwen3_attention::trait_impl::diag_norm(
ctx.gpu,
v_chk,
(nkv * hd) as usize,
stream,
"L0 V BEFORE Q_copy",
);
}
ctx.gpu
.copy_d2d_async(qg_out, q_contiguous, num_tokens * q_dim * bf16, stream)
.map_err(|e| anyhow::anyhow!("MLA Q copy failed: {e}"))?;
} else {
ctx.gpu
.copy_d2d_async(qg_out, q_contiguous, num_tokens * q_dim * bf16, stream)
Expand Down
Loading
Loading