diff --git a/.github/workflows/cla.yml b/.github/workflows/cla.yml index 0eb04e5e..0a0faff1 100644 --- a/.github/workflows/cla.yml +++ b/.github/workflows/cla.yml @@ -25,7 +25,7 @@ jobs: actions: write steps: - name: "CLA Assistant" - if: (github.event.comment.body == 'recheck' || github.event.comment.body == 'I have read the CLA Document and I hereby sign the CLA') || github.event_name == 'pull_request_target' + if: (contains(github.event.comment.body, 'recheck') || contains(github.event.comment.body, 'I have read the CLA Document and I hereby sign the CLA')) || github.event_name == 'pull_request_target' # Alpha Release uses: contributor-assistant/github-action@v2.6.1 env: @@ -35,8 +35,8 @@ jobs: with: path-to-signatures: 'signatures/version1/cla.json' path-to-document: 'https://github.com/${{ github.repository }}/blob/main/CLA.md' - branch: 'main' - allowlist: dependabot[bot],renovate[bot] + branch: 'cla-signatures' + allowlist: dependabot[bot],renovate[bot],google-labs-jules[bot],claude[bot],google-labs-jules,claude lock-pullrequest-aftermerge: false create-file-commit-message: 'chore: setup CLA signatures file' signed-commit-message: 'chore: $username CLA signature added' diff --git a/book/src/deep-dives/ssm.md b/book/src/deep-dives/ssm.md index 3926bc25..5337d4fc 100644 --- a/book/src/deep-dives/ssm.md +++ b/book/src/deep-dives/ssm.md @@ -94,5 +94,4 @@ Each step calls into `spark-runtime::GpuBackend` via the layer's cached `KernelH - `kernels/gb10///ssm_preprocess.cu`, `gdr.cu`, `causal_conv1d.cu` - `crates/spark-model/src/layers/qwen3_ssm.rs`, `nemotron_mamba2.rs` - `crates/spark-runtime/src/prefix_cache.rs` (Marconi SSM snapshot) -- `docs/history/SSM_CATASTROPHIC_FORGETTING_TODO.md` - README "Atlas Spark" section — the SSM/GDN story in narrative form diff --git a/crates/atlas-kernels/build.rs b/crates/atlas-kernels/build.rs index e523c7e0..12049f91 100644 --- a/crates/atlas-kernels/build.rs +++ b/crates/atlas-kernels/build.rs @@ -74,6 +74,7 @@ struct Target { behavior_disable_tool_steering: bool, behavior_tool_call_parser: String, behavior_enable_loop_watchdog: bool, + behavior_skip_template_tools: bool, /// Which `(model_type, hidden_size)` pairs this kernel target supports. /// Parsed from `[[model_types]]` in MODEL.toml. model_type_matches: Vec, @@ -362,6 +363,7 @@ fn resolve_targets(workspace_root: &std::path::Path) -> Vec { b_disable_tool_steering, b_tool_call_parser, b_enable_loop_watchdog, + b_skip_template_tools, ) = parse_behavior(&model_dir); let model_type_matches = parse_model_types(&model_dir); let dflash = parse_dflash(&model_dir); @@ -392,6 +394,7 @@ fn resolve_targets(workspace_root: &std::path::Path) -> Vec { behavior_disable_tool_steering: b_disable_tool_steering, behavior_tool_call_parser: b_tool_call_parser, behavior_enable_loop_watchdog: b_enable_loop_watchdog, + behavior_skip_template_tools: b_skip_template_tools, model_type_matches, dflash, }); diff --git a/crates/atlas-kernels/build_codegen.rs b/crates/atlas-kernels/build_codegen.rs index 9f363fe9..ab054ab1 100644 --- a/crates/atlas-kernels/build_codegen.rs +++ b/crates/atlas-kernels/build_codegen.rs @@ -168,6 +168,7 @@ pub(super) fn generate_target_ptx_rs( \x20 disable_tool_steering: {},\n\ \x20 tool_call_parser: \"{}\",\n\ \x20 enable_loop_watchdog: {},\n\ + \x20 skip_template_tools: {},\n\ \x20 }},\n\ \x20 model_type_matches: vec![{}],\n\ \x20 dflash: {},\n\ @@ -186,6 +187,7 @@ pub(super) fn generate_target_ptx_rs( target.behavior_disable_tool_steering, target.behavior_tool_call_parser, target.behavior_enable_loop_watchdog, + target.behavior_skip_template_tools, target.model_type_matches.iter().map(|m| { let hs = match m.hidden_size { Some(v) => format!("Some({v})"), diff --git a/crates/atlas-kernels/build_parse.rs b/crates/atlas-kernels/build_parse.rs index 2fe87eec..282737f3 100644 --- a/crates/atlas-kernels/build_parse.rs +++ b/crates/atlas-kernels/build_parse.rs @@ -120,10 +120,10 @@ pub(super) fn parse_sampling_presets( } /// Parse [behavior] from MODEL.toml. Returns -/// (thinking_in_tools, max_thinking_budget, thinking_default, fp8_kv_calibration_tokens, default_kv_dtype, default_num_drafts, disable_tool_steering, tool_call_parser, enable_loop_watchdog). +/// (thinking_in_tools, max_thinking_budget, thinking_default, fp8_kv_calibration_tokens, default_kv_dtype, default_num_drafts, disable_tool_steering, tool_call_parser, enable_loop_watchdog, skip_template_tools). pub(super) fn parse_behavior( model_dir: &std::path::Path, -) -> (bool, u32, bool, usize, String, u32, bool, String, bool) { +) -> (bool, u32, bool, usize, String, u32, bool, String, bool, bool) { let model_toml_path = model_dir.join("MODEL.toml"); if !model_toml_path.exists() { return ( @@ -136,6 +136,7 @@ pub(super) fn parse_behavior( false, String::new(), false, + false, ); } let content = std::fs::read_to_string(&model_toml_path).unwrap_or_default(); @@ -152,6 +153,7 @@ pub(super) fn parse_behavior( false, String::new(), false, + false, ); } }; @@ -197,6 +199,10 @@ pub(super) fn parse_behavior( .and_then(|v| v.get("enable_loop_watchdog")) .and_then(|v| v.as_bool()) .unwrap_or(false); + let skip_template_tools = b + .and_then(|v| v.get("skip_template_tools")) + .and_then(|v| v.as_bool()) + .unwrap_or(false); ( thinking_in_tools, max_thinking_budget, @@ -207,6 +213,7 @@ pub(super) fn parse_behavior( disable_tool_steering, tool_call_parser, enable_loop_watchdog, + skip_template_tools, ) } diff --git a/crates/atlas-kernels/src/lib.rs b/crates/atlas-kernels/src/lib.rs index d89d1be1..a016a096 100644 --- a/crates/atlas-kernels/src/lib.rs +++ b/crates/atlas-kernels/src/lib.rs @@ -163,6 +163,19 @@ pub struct ModelBehavior { /// JSON arrays of similar objects, multiplication tables). Enable only /// when the model has been observed to need it. pub enable_loop_watchdog: bool, + /// When true, do not pass tool definitions to the Jinja chat template + /// (`jinja_tools` stays `None`). Use this for models where the tool-call + /// parser already injects a complete system-prompt with tool schemas and + /// format instructions, and the template's own XML tool rendering would + /// produce contradictory instructions. + /// + /// Example: Nemotron-Super-120B uses `bare_json` grammar (parser emits + /// JSON-schema + bare-JSON instructions) while the `nemotron_h.jinja` + /// template would additionally render XML `` blocks and tell + /// the model to output `` XML — the opposite format. Setting + /// `skip_template_tools = true` suppresses the template rendering and + /// leaves the parser's instructions as the sole tool-format signal. + pub skip_template_tools: bool, } impl Default for ModelBehavior { @@ -177,6 +190,7 @@ impl Default for ModelBehavior { disable_tool_steering: false, tool_call_parser: "", enable_loop_watchdog: false, + skip_template_tools: false, } } } diff --git a/crates/spark-model/src/layers/ops/kv_cache.rs b/crates/spark-model/src/layers/ops/kv_cache.rs index ce2cafc8..666d1b06 100644 --- a/crates/spark-model/src/layers/ops/kv_cache.rs +++ b/crates/spark-model/src/layers/ops/kv_cache.rs @@ -166,6 +166,40 @@ pub fn mla_batched_gemv( .launch(stream) } +/// Batched V extraction for N-token MLA prefill. +/// For each (token, head): output[token, head, :] = W_UV[head] @ input[token, head, 0..k] +/// where input has input_head_stride elements per head (only first k are used). +/// +/// Grid: (ceil(n_out/8), num_heads, n_tokens) Block: (256, 1, 1) +#[allow(clippy::too_many_arguments)] +pub fn mla_v_extract_batched( + gpu: &dyn GpuBackend, + kernel: KernelHandle, + input: DevicePtr, + weight: DevicePtr, + output: DevicePtr, + n_out: u32, + k: u32, + num_heads: u32, + input_head_stride: u32, + output_head_stride: u32, + n_tokens: u32, + stream: u64, +) -> Result<()> { + KernelLaunch::new(gpu, kernel) + .grid([div_ceil(n_out, 8), num_heads, n_tokens]) + .block([256, 1, 1]) + .arg_ptr(input) + .arg_ptr(weight) + .arg_ptr(output) + .arg_u32(n_out) + .arg_u32(k) + .arg_u32(num_heads) + .arg_u32(input_head_stride) + .arg_u32(output_head_stride) + .launch(stream) +} + /// MLA Q_rope scatter: copy rope portion from q_full to strided q_absorbed_buf. 1 kernel replaces 32 D2D copies. #[allow(clippy::too_many_arguments)] pub fn mla_q_rope_scatter( diff --git a/crates/spark-model/src/layers/ops/prefill_attn_a.rs b/crates/spark-model/src/layers/ops/prefill_attn_a.rs index 189748ec..d3ccdfba 100644 --- a/crates/spark-model/src/layers/ops/prefill_attn_a.rs +++ b/crates/spark-model/src/layers/ops/prefill_attn_a.rs @@ -284,6 +284,51 @@ pub fn mla_prefill_attention_320( .launch(stream) } +/// Paged MLA prefill attention — absorbed form, HDIM=320, multi-chunk (seq_len_start > 0). +/// +/// Q [q_len, nq, 320] attends to KV cache (paged) over kv_len tokens with causal masking. +/// Q at local position i (global position q_offset + i) attends to KV 0..q_offset+i. +/// +/// Grid: (num_q_heads, ceil(q_len/16), 1) Block: (256, 1, 1) +#[allow(clippy::too_many_arguments)] +pub fn mla_prefill_paged_320( + gpu: &dyn GpuBackend, + kernel: KernelHandle, + q: DevicePtr, + k_cache: DevicePtr, + v_cache: DevicePtr, + output: DevicePtr, + block_table: DevicePtr, + q_len: u32, + kv_len: u32, + q_offset: u32, + num_q_heads: u32, + num_kv_heads: u32, + head_dim: u32, + cache_block_size: u32, + inv_sqrt_d: f32, + stream: u64, +) -> Result<()> { + let br = 16u32; // MLA_BR in the kernel + KernelLaunch::new(gpu, kernel) + .grid([num_q_heads, div_ceil(q_len, br), 1]) + .block([256, 1, 1]) + .arg_ptr(q) + .arg_ptr(k_cache) + .arg_ptr(v_cache) + .arg_ptr(output) + .arg_ptr(block_table) + .arg_u32(q_len) + .arg_u32(kv_len) + .arg_u32(q_offset) + .arg_u32(num_q_heads) + .arg_u32(num_kv_heads) + .arg_u32(head_dim) + .arg_u32(cache_block_size) + .arg_f32(inv_sqrt_d) + .launch(stream) +} + pub fn paged_decode_attn_bf16( gpu: &dyn GpuBackend, kernel: KernelHandle, diff --git a/crates/spark-model/src/layers/qwen3_attention/decode/attention_forward_mla.rs b/crates/spark-model/src/layers/qwen3_attention/decode/attention_forward_mla.rs index 977f0140..a99f14a0 100644 --- a/crates/spark-model/src/layers/qwen3_attention/decode/attention_forward_mla.rs +++ b/crates/spark-model/src/layers/qwen3_attention/decode/attention_forward_mla.rs @@ -372,7 +372,9 @@ impl Qwen3AttentionLayer { // Step 8: Paged decode attention let attn_out = ctx.buffers.attn_output(); - let inv_sqrt_d = self.effective_attn_scale(hd); + // Absorbed MLA decode operates in (kv_lora+rope)-dim space; 1/sqrt(hd=128) + // would over-sharpen softmax vs the correct 1/sqrt(kv_lora+rope=320). + let inv_sqrt_d = 1.0f32 / ((kv_lora + mla_rope) as f32).sqrt(); prof!("paged_attn", { ops::paged_decode_attn_bf16( ctx.gpu, diff --git a/crates/spark-model/src/layers/qwen3_attention/init.rs b/crates/spark-model/src/layers/qwen3_attention/init.rs index 3717b3ba..5cc39794 100644 --- a/crates/spark-model/src/layers/qwen3_attention/init.rs +++ b/crates/spark-model/src/layers/qwen3_attention/init.rs @@ -307,6 +307,16 @@ impl Qwen3AttentionLayer { "mla_fused_prefill", "mla_fused_prefill", ), + mla_prefill_paged_k: super::super::try_kernel( + gpu, + "mla_prefill_paged", + "mla_prefill_paged_320", + ), + mla_v_extract_batched_k: super::super::try_kernel( + gpu, + "mla_absorbed", + "mla_v_extract_batched", + ), gemm_splitk_partial_k: super::super::try_kernel( gpu, "gemm_splitk", @@ -391,6 +401,16 @@ impl Qwen3AttentionLayer { "inferspark_prefill_paged_512", ), prefill_attn_64_k: gpu.kernel("inferspark_prefill", "inferspark_prefill_64")?, + prefill_attn_128_k: super::super::try_kernel( + gpu, + "inferspark_prefill_128", + "inferspark_prefill_hd128", + ), + prefill_attn_64_128_k: super::super::try_kernel( + gpu, + "inferspark_prefill_128", + "inferspark_prefill_64_hd128", + ), prefill_attn_paged_k: gpu.kernel("prefill_paged", "inferspark_prefill_paged")?, prefill_attn_paged_fp8_k: gpu .kernel("prefill_paged_fp8", "inferspark_prefill_paged_fp8")?, diff --git a/crates/spark-model/src/layers/qwen3_attention/prefill/cache_skip.rs b/crates/spark-model/src/layers/qwen3_attention/prefill/cache_skip.rs index 8c910156..2b41edf9 100644 --- a/crates/spark-model/src/layers/qwen3_attention/prefill/cache_skip.rs +++ b/crates/spark-model/src/layers/qwen3_attention/prefill/cache_skip.rs @@ -89,16 +89,13 @@ impl Qwen3AttentionLayer { if self.mla.is_some() { let args = super::cache_skip_mla::CacheSkipMlaArgs { normed, - num_tokens, n, h, nq, - nkv, hd, - kv_dim, eps, - bf16, stream, + kv_write_start, }; return self.prefill_attention_cache_skip_mla(kv_cache, ctx, &args); } @@ -143,22 +140,6 @@ impl Qwen3AttentionLayer { q_proj_dim as u32, stream, )?; - } else if self.mla.is_some() { - // DIAGNOSTIC: check V BEFORE Q copy - if self.attn_layer_idx == 0 && ctx.config.model_type == "mistral" { - ctx.gpu.synchronize(stream)?; - let v_chk = k_contiguous.offset(num_tokens * kv_dim * bf16); - crate::layers::qwen3_attention::trait_impl::diag_norm( - ctx.gpu, - v_chk, - (nkv * hd) as usize, - stream, - "L0 V BEFORE Q_copy", - ); - } - ctx.gpu - .copy_d2d_async(qg_out, q_contiguous, num_tokens * q_dim * bf16, stream) - .map_err(|e| anyhow::anyhow!("MLA Q copy failed: {e}"))?; } else { ctx.gpu .copy_d2d_async(qg_out, q_contiguous, num_tokens * q_dim * bf16, stream) diff --git a/crates/spark-model/src/layers/qwen3_attention/prefill/cache_skip_mla.rs b/crates/spark-model/src/layers/qwen3_attention/prefill/cache_skip_mla.rs index eec76029..344397d0 100644 --- a/crates/spark-model/src/layers/qwen3_attention/prefill/cache_skip_mla.rs +++ b/crates/spark-model/src/layers/qwen3_attention/prefill/cache_skip_mla.rs @@ -1,9 +1,9 @@ // SPDX-License-Identifier: AGPL-3.0-only //! MLA branch of `prefill_attention_with_cache_skip`. Mistral4-style -//! 2-step prefill with the unabsorbed/MHA fused fallback path that -//! expands K/V via `wkv_b` and runs HDIM=128 FlashAttention. Extracted -//! from `cache_skip.rs` to keep that file under 500 LoC. +//! absorbed MLA prefill: Q_absorption + causal attention + V_extraction +//! via `mla_fused_prefill` (HDIM=320 absorbed space). Extracted from +//! `cache_skip.rs` to keep that file under 500 LoC. use anyhow::Result; use spark_runtime::gpu::DevicePtr; @@ -13,19 +13,18 @@ use super::super::Qwen3AttentionLayer; use crate::layer::ForwardContext; use crate::layers::ops; -#[allow(clippy::too_many_arguments)] pub(super) struct CacheSkipMlaArgs { pub normed: DevicePtr, - pub num_tokens: usize, pub n: u32, pub h: u32, pub nq: u32, - pub nkv: u32, pub hd: u32, - pub kv_dim: usize, pub eps: f32, - pub bf16: usize, pub stream: u64, + /// Number of token positions whose KV entries are already in the cache + /// (prefix-cache hit). Only tokens `kv_write_start..n` need to be written. + /// 0 = no cached prefix (all tokens are new). + pub kv_write_start: usize, } impl Qwen3AttentionLayer { @@ -37,19 +36,7 @@ impl Qwen3AttentionLayer { ctx: &ForwardContext, args: &CacheSkipMlaArgs, ) -> Result { - let CacheSkipMlaArgs { - normed, - num_tokens, - n, - h, - nq, - nkv, - hd, - kv_dim, - eps, - bf16, - stream, - } = *args; + let CacheSkipMlaArgs { normed, n, h, nq, hd, eps, stream, kv_write_start } = *args; let mla = self .mla .as_ref() @@ -243,86 +230,70 @@ impl Qwen3AttentionLayer { mla_cache_dim, stream, )?; - self.write_kv_cache( - ctx.gpu, - k_cache_assembled, - v_cache_assembled, - kv_cache, - meta.slot, - n, - 1, - mla_cache_dim, - bs as u32, - mla_cache_dim, - mla_cache_dim, - stream, - ctx.graph_capture, - )?; + // Only write the tokens that are NOT already in the cache. + // kv_write_start tokens (prefix-cache hit) already have correct KV + // entries at their physical slots; skip them to avoid redundant writes. + // Mirror of the non-MLA `write_start` logic in cache_skip.rs. + let write_count = (n as usize).saturating_sub(kv_write_start); + if write_count > 0 { + let bf16 = 2usize; // bytes per BF16 element + let cache_elem_offset = kv_write_start * mla_cache_dim as usize; + let slot_byte_offset = kv_write_start * 8; // 8 bytes per u64 slot entry + self.write_kv_cache( + ctx.gpu, + k_cache_assembled.offset(cache_elem_offset * bf16), + v_cache_assembled.offset(cache_elem_offset * bf16), + kv_cache, + meta.slot.offset(slot_byte_offset), + write_count as u32, + 1, + mla_cache_dim, + bs as u32, + mla_cache_dim, + mla_cache_dim, + stream, + ctx.graph_capture, + )?; + } - // Unabsorbed (MHA) prefill: expand K/V via wkv_b, use HDIM=128 FlashAttention - let kv_expanded_dim = nkv * (mla_nope + mla_v_dim); - let kv_expanded = ctx.buffers.ssm_deinterleaved(); - ops::dense_gemm( + // MLA absorbed attention: fused Q_absorb + attention (320-dim) + V_extract. + // inferspark_prefill_64 has compile-time HDIM=256; MLA kv_stride=nkv*hd=128 so + // col>=128 aliases K[k+1][0..127] — corrupts attention scores over long contexts. + // inv_sqrt_d: 1/sqrt(kv_lora + rope) = 1/sqrt(320) — absorbed dimension, NOT hd. + // Using 1/sqrt(hd=128) would over-sharpen softmax by sqrt(128/320) ≈ 0.63. + let attn_out_fb = ctx.buffers.attn_output(); + // inv_sqrt_d in the absorbed space: 1/sqrt(kv_lora + rope) = 1/sqrt(320). + // Using 1/sqrt(hd=128) would over-sharpen softmax by sqrt(128/320) ≈ 0.63. + let inv_sqrt_d_absorbed = 1.0f32 / ((kv_lora + mla_rope) as f32).sqrt(); + anyhow::ensure!( + self.mla_fused_prefill_k.0 != 0, + "MLA cache-skip prefill requires mla_fused_prefill kernel \ + (inferspark_prefill HDIM=256 is broken for MLA hd=128; \ + rebuild with kernels/gb10/mistral-small-4/nvfp4/mla_fused_prefill.cu)" + ); + ops::mla_fused_prefill( ctx.gpu, - self.dense_gemm_k, + self.mla_fused_prefill_k, + qg_out, + q_rope_tmp, kv_latent, - &mla.wkv_b, - kv_expanded, - n, - kv_expanded_dim, - kv_lora, - stream, - )?; - let k_contiguous = ctx.buffers.ssm_qkvz(); - let v_contiguous = k_contiguous.offset(num_tokens * kv_dim * bf16); - ops::mla_kv_assemble_batched( - ctx.gpu, - self.mla_kv_assemble_batched_k, - kv_expanded, k_rope_buf, - k_contiguous, - v_contiguous, - n, - nkv, - mla_nope, - mla_v_dim, - mla_rope, - hd, - nkv * (mla_nope + mla_v_dim), - stream, - )?; - ops::mla_q_rope_writeback_batched( - ctx.gpu, - self.mla_q_rope_writeback_batched_k, - q_rope_tmp, - qg_out, + mla.w_uk_t.weight, + mla.w_uv.weight, + attn_out_fb, + DevicePtr::NULL, + DevicePtr::NULL, n, nq, - hd, mla_nope, mla_rope, - nq * hd, - stream, - )?; - let attn_out_fb = ctx.buffers.attn_output(); - ops::prefill_attention_64( - ctx.gpu, - self.prefill_attn_64_k, - qg_out, - k_contiguous, - v_contiguous, - attn_out_fb, - n, - 1, - nq, - nkv, + kv_lora, + mla_v_dim, hd, - 1.0f32 / (hd as f32).sqrt(), - true, - 0, + inv_sqrt_d_absorbed, stream, ) - .map_err(|e| anyhow::anyhow!("MLA flash_attn_64 fallback: {e}"))?; + .map_err(|e| anyhow::anyhow!("MLA fused prefill: {e}"))?; // wo projection — output to qkv_output (norm_output aliases downstream) let o_out = ctx.buffers.qkv_output(); if let Some(ref wo_nvfp4) = mla.wo_nvfp4 { @@ -334,7 +305,7 @@ impl Qwen3AttentionLayer { o_out, n, h, - nq * hd, + nq * mla_v_dim, stream, )?; } else { @@ -346,7 +317,7 @@ impl Qwen3AttentionLayer { o_out, n, h, - nq * hd, + nq * mla_v_dim, stream, )?; } diff --git a/crates/spark-model/src/layers/qwen3_attention/prefill/paged.rs b/crates/spark-model/src/layers/qwen3_attention/prefill/paged.rs index f762529b..a628e21b 100644 --- a/crates/spark-model/src/layers/qwen3_attention/prefill/paged.rs +++ b/crates/spark-model/src/layers/qwen3_attention/prefill/paged.rs @@ -71,6 +71,7 @@ impl Qwen3AttentionLayer { bf16, bs: bs as u32, stream, + seq_len_start, }; return self.prefill_attention_paged_mla(kv_cache, ctx, &args); } diff --git a/crates/spark-model/src/layers/qwen3_attention/prefill/paged_mla.rs b/crates/spark-model/src/layers/qwen3_attention/prefill/paged_mla.rs index c7999e66..8d4301d7 100644 --- a/crates/spark-model/src/layers/qwen3_attention/prefill/paged_mla.rs +++ b/crates/spark-model/src/layers/qwen3_attention/prefill/paged_mla.rs @@ -27,6 +27,8 @@ pub(super) struct MlaPrefillArgs { pub bf16: usize, pub bs: u32, pub stream: u64, + /// Tokens already written to KV cache before this chunk (0 for the first chunk). + pub seq_len_start: usize, } impl Qwen3AttentionLayer { @@ -52,6 +54,7 @@ impl Qwen3AttentionLayer { bf16, bs, stream, + seq_len_start, } = *args; let mla = self .mla @@ -63,8 +66,14 @@ impl Qwen3AttentionLayer { let mla_nope = mla.nope as u32; let mla_v_dim = mla.v_dim as u32; let mla_rope = mla.rope as u32; + let mla_cache_dim = kv_lora + mla_rope; + let kv_len = seq_len_start + num_tokens; // full context length after this chunk - // Q: latent → norm → expand → [N, nq*hd] in [nope|rope] per head + // ── Step 1: Q projection (shared by both paths) ────────────────────── + // Q: normed → wq_a → rms_norm → wq_b → [N, nq*hd] in [nope|rope] per head. + // q_latent lives in ssm_ba. IMPORTANT: ssm_ba is later aliased by k_rope_buf + // (line after wq_b). Any computation that needs q_latent must happen before + // the k_rope_buf dense_gemm below. let q_latent = ctx.buffers.ssm_ba(); ops::dense_gemm( ctx.gpu, @@ -101,7 +110,7 @@ impl Qwen3AttentionLayer { stream, )?; - // KV: latent → norm → expand + // ── Step 2: KV latent projection (shared by both paths) ────────────── let kv_latent = ctx.buffers.expert_gate_out(); ops::dense_gemm( ctx.gpu, @@ -125,21 +134,237 @@ impl Qwen3AttentionLayer { eps, stream, )?; - let kv_expanded_dim = nkv * (mla_nope + mla_v_dim); - let kv_expanded = ctx.buffers.ssm_deinterleaved(); + + if seq_len_start == 0 { + // ════════════════════════════════════════════════════════════════════ + // FIRST-CHUNK PATH (seq_len_start == 0): unabsorbed form. + // + // Expand KV via wkv_b, assemble contiguous K/V, and run flash + // attention over the N new tokens only (no historical context). + // ════════════════════════════════════════════════════════════════════ + + let kv_expanded_dim = nkv * (mla_nope + mla_v_dim); + let kv_expanded = ctx.buffers.ssm_deinterleaved(); + ops::dense_gemm( + ctx.gpu, + self.dense_gemm_k, + kv_latent, + &mla.wkv_b, + kv_expanded, + n, + kv_expanded_dim, + kv_lora, + stream, + )?; + + // K_rope: single shared head [N, rope] (MQA-style) + let k_rope_buf = ctx.buffers.ssm_ba(); // aliases q_latent — OK, consumed above + ops::dense_gemm( + ctx.gpu, + self.dense_gemm_k, + normed, + &mla.wkv_a_rope, + k_rope_buf, + n, + mla_rope, + h, + stream, + )?; + + let q_rope_tmp = ctx.buffers.ssm_conv_out_f32(); + ops::mla_q_rope_extract_batched( + ctx.gpu, + self.mla_q_rope_extract_batched_k, + qg_out, + q_rope_tmp, + n, + nq, + hd, + mla_nope, + mla_rope, + nq * hd, + stream, + )?; + let rope_meta = ctx.attn_metadata.expect("MLA prefill requires metadata"); + ops::rope_yarn( + ctx.gpu, + self.rope_yarn_k, + q_rope_tmp, + k_rope_buf, + rope_meta.positions, + n, + nq, + 1, + mla_rope, + mla_rope, + mla.yarn_inv_freq, + ctx.config.rope_theta as f32, + stream, + )?; + ops::mla_q_rope_writeback_batched( + ctx.gpu, + self.mla_q_rope_writeback_batched_k, + q_rope_tmp, + qg_out, + n, + nq, + hd, + mla_nope, + mla_rope, + nq * hd, + stream, + )?; + + let k_contiguous = ctx.buffers.ssm_qkvz(); + let v_contiguous = k_contiguous.offset(num_tokens * kv_dim * bf16); + ops::mla_kv_assemble_batched( + ctx.gpu, + self.mla_kv_assemble_batched_k, + kv_expanded, + k_rope_buf, + k_contiguous, + v_contiguous, + n, + nkv, + mla_nope, + mla_v_dim, + mla_rope, + hd, + nkv * (mla_nope + mla_v_dim), + stream, + )?; + + let mla_k_cache = ctx.buffers.expert_down_out(); + let mla_v_cache = mla_k_cache.offset(num_tokens * mla_cache_dim as usize * bf16); + ops::mla_cache_assemble_batched( + ctx.gpu, + self.mla_cache_assemble_batched_k, + kv_latent, + k_rope_buf, + mla_k_cache, + mla_v_cache, + n, + kv_lora, + mla_rope, + mla_cache_dim, + stream, + )?; + let meta = ctx.attn_metadata.expect("MLA prefill requires slot info"); + self.write_kv_cache( + ctx.gpu, + mla_k_cache, + mla_v_cache, + kv_cache, + meta.slot, + n, + 1, + mla_cache_dim, + bs, + mla_cache_dim, + mla_cache_dim, + stream, + ctx.graph_capture, + )?; + + let attn_out = ctx.buffers.attn_output(); + let inv_sqrt_d = self.effective_attn_scale(hd); + // For MLA unabsorbed path hd=128; HDIM=256 kernel reads K[k+1][0..127] + // for d>=128, contaminating scores. Require the correct HDIM=128 kernel. + anyhow::ensure!( + hd > 128 || self.prefill_attn_128_k.0 != 0, + "MLA paged prefill (first chunk): head_dim={hd} requires \ + inferspark_prefill_hd128 (HDIM=256 over-reads adjacent K heads for hd<=128)", + ); + let prefill_k = if hd > 256 && self.prefill_attn_512_k.0 != 0 { + self.prefill_attn_512_k + } else if hd <= 128 { + self.prefill_attn_128_k + } else { + self.prefill_attn_k + }; + ops::prefill_attention( + ctx.gpu, + prefill_k, + qg_out, + k_contiguous, + v_contiguous, + attn_out, + n, + 1, + nq, + nkv, + hd, + inv_sqrt_d, + true, + self.sliding_window.unwrap_or(0), + stream, + )?; + + // O projection: [N, nq*hd] → [N, H] + let o_out = ctx.buffers.norm_output(); + if let Some(ref wo_nvfp4) = mla.wo_nvfp4 { + ops::w4a16_gemm( + ctx.gpu, + self.w4a16_gemm_k, + attn_out, + wo_nvfp4, + o_out, + n, + h, + nq * hd, + stream, + )?; + } else { + ops::dense_gemm( + ctx.gpu, + self.dense_gemm_k, + attn_out, + &mla.wo, + o_out, + n, + h, + nq * hd, + stream, + )?; + } + return Ok(o_out); + } + + // ════════════════════════════════════════════════════════════════════════ + // MULTI-CHUNK PATH (seq_len_start > 0): absorbed form with paged KV. + // + // Previous chunks have already written seq_len_start tokens into the paged + // KV cache. The current chunk (n tokens) must attend to the full context + // (kv_len = seq_len_start + n tokens) using the compressed [kv_lora|rope]=320 + // dim cache format. + // + // Buffer plan (non-overlapping within this path): + // ssm_deinterleaved → Q_absorbed [N, nq*kv_lora] (before k_rope_buf aliases ssm_ba) + // ssm_conv_out_f32 → Q_rope [N, nq*rope] (post-RoPE) + // expert_down_out → mla_k_cache / mla_v_cache (compressed cache to write) + // attn_output → Q_final [N, nq, 320] (assembled absorbed Q) + // ssm_deinterleaved → attn_out [N, nq, 320] (paged attention output, reuse) + // attn_output → v_extracted[N, nq, v_dim] (V extraction output, reuse) + // norm_output → o_out [N, H] + // ════════════════════════════════════════════════════════════════════════ + + // ── Step A: Q_absorbed = q_latent @ w_qk_absorbed^T → ssm_deinterleaved ── + // Must happen BEFORE k_rope_buf aliases ssm_ba and overwrites q_latent. + // ssm_deinterleaved is free here: wkv_b is skipped for the absorbed path. + let q_absorbed = ctx.buffers.ssm_deinterleaved(); ops::dense_gemm( ctx.gpu, self.dense_gemm_k, - kv_latent, - &mla.wkv_b, - kv_expanded, + q_latent, + &mla.w_qk_absorbed, + q_absorbed, n, - kv_expanded_dim, - kv_lora, + nq * kv_lora, + q_lora, stream, )?; - // K_rope: single shared head [N, rope=64] (MQA-style) + // ── Step B: K_rope (aliases ssm_ba, overwriting q_latent — OK, consumed above) ── let k_rope_buf = ctx.buffers.ssm_ba(); ops::dense_gemm( ctx.gpu, @@ -153,7 +378,7 @@ impl Qwen3AttentionLayer { stream, )?; - // Apply RoPE to Q rope portions and K_rope BEFORE assembly + // ── Step C: Extract Q_rope and apply RoPE ── let q_rope_tmp = ctx.buffers.ssm_conv_out_f32(); ops::mla_q_rope_extract_batched( ctx.gpu, @@ -184,42 +409,10 @@ impl Qwen3AttentionLayer { ctx.config.rope_theta as f32, stream, )?; - ops::mla_q_rope_writeback_batched( - ctx.gpu, - self.mla_q_rope_writeback_batched_k, - q_rope_tmp, - qg_out, - n, - nq, - hd, - mla_nope, - mla_rope, - nq * hd, - stream, - )?; + // Q_rope is now in q_rope_tmp [N, nq*rope]; no writeback to qg_out needed + // (qg_out not used for attention in this path). - // Assemble K=[nope|rope] and extract V (1 kernel vs N*nkv*3 copies) - let k_contiguous = ctx.buffers.ssm_qkvz(); - let v_contiguous = k_contiguous.offset(num_tokens * kv_dim * bf16); - ops::mla_kv_assemble_batched( - ctx.gpu, - self.mla_kv_assemble_batched_k, - kv_expanded, - k_rope_buf, - k_contiguous, - v_contiguous, - n, - nkv, - mla_nope, - mla_v_dim, - mla_rope, - hd, - nkv * (mla_nope + mla_v_dim), - stream, - )?; - - // Write compressed MLA cache - let mla_cache_dim = kv_lora + mla_rope; + // ── Step D: Write compressed KV cache ── let mla_k_cache = ctx.buffers.expert_down_out(); let mla_v_cache = mla_k_cache.offset(num_tokens * mla_cache_dim as usize * bf16); ops::mla_cache_assemble_batched( @@ -252,56 +445,94 @@ impl Qwen3AttentionLayer { ctx.graph_capture, )?; - // Direct flash attention with expanded Q/K/V (not from paged cache). - let attn_out = ctx.buffers.attn_output(); - let inv_sqrt_d = self.effective_attn_scale(hd); - let prefill_k = if hd > 256 && self.prefill_attn_512_k.0 != 0 { - self.prefill_attn_512_k - } else { - self.prefill_attn_k - }; - ops::prefill_attention( + // ── Step E: Assemble Q_final [N, nq, mla_cache_dim] in attn_output ── + // q_absorbed = ssm_deinterleaved [N, nq*kv_lora] + // q_rope = ssm_conv_out_f32 [N, nq*rope] + // q_final = attn_output [N, nq*mla_cache_dim] + let q_final = ctx.buffers.attn_output(); + ops::mla_q_final_assemble_batched( ctx.gpu, - prefill_k, - qg_out, - k_contiguous, - v_contiguous, + self.mla_q_final_assemble_k, + q_absorbed, + q_rope_tmp, + q_final, + n, + nq, + kv_lora, + mla_rope, + mla_cache_dim, + stream, + )?; + + // ── Step F: Paged MLA prefill attention ── + // Q = attn_output [N, nq, 320], KV from paged cache. + // Output → ssm_deinterleaved [N, nq, 320] (q_absorbed buffer, now free). + // Causal: Q[i] at global pos (seq_len_start + i) attends to KV 0..=seq_len_start+i. + let attn_out = ctx.buffers.ssm_deinterleaved(); + let inv_sqrt_d = 1.0f32 / (mla_cache_dim as f32).sqrt(); + ops::mla_prefill_paged_320( + ctx.gpu, + self.mla_prefill_paged_k, + q_final, + kv_cache.k_pool_ptr(self.attn_layer_idx), + kv_cache.v_pool_ptr(self.attn_layer_idx), attn_out, + meta.block_table, n, - 1, + kv_len as u32, + seq_len_start as u32, nq, - nkv, - hd, + 1, // num_kv_heads = 1 (MQA compressed cache) + mla_cache_dim, + bs, inv_sqrt_d, - true, - self.sliding_window.unwrap_or(0), stream, )?; - // O projection: [N, nq*hd] → [N, H] + // ── Step G: V extraction — [N, nq, 320] → [N, nq, v_dim] ── + // attn_out (ssm_deinterleaved) has absorbed attention output [N, nq, 320]. + // Only the first kv_lora=256 dims per head feed into V extraction. + // Output → attn_output (q_final buffer, now free). + let v_extracted = ctx.buffers.attn_output(); + ops::mla_v_extract_batched( + ctx.gpu, + self.mla_v_extract_batched_k, + attn_out, + mla.w_uv.weight, + v_extracted, + mla_v_dim, + kv_lora, + nq, + mla_cache_dim, // input_head_stride: 320 (first kv_lora=256 dims used) + mla_v_dim, // output_head_stride: 128 + n, + stream, + )?; + + // ── Step H: O projection — [N, nq*v_dim] → [N, H] ── let o_out = ctx.buffers.norm_output(); if let Some(ref wo_nvfp4) = mla.wo_nvfp4 { ops::w4a16_gemm( ctx.gpu, self.w4a16_gemm_k, - attn_out, + v_extracted, wo_nvfp4, o_out, n, h, - nq * hd, + nq * mla_v_dim, stream, )?; } else { ops::dense_gemm( ctx.gpu, self.dense_gemm_k, - attn_out, + v_extracted, &mla.wo, o_out, n, h, - nq * hd, + nq * mla_v_dim, stream, )?; } diff --git a/crates/spark-model/src/layers/qwen3_attention/types.rs b/crates/spark-model/src/layers/qwen3_attention/types.rs index 96c8a2c9..d23fe52c 100644 --- a/crates/spark-model/src/layers/qwen3_attention/types.rs +++ b/crates/spark-model/src/layers/qwen3_attention/types.rs @@ -174,6 +174,10 @@ pub struct Qwen3AttentionLayer { pub(super) mla_q_final_assemble_k: KernelHandle, /// Fused MLA prefill: Q_absorb + attention + V_extract in one kernel. pub(super) mla_fused_prefill_k: KernelHandle, + /// Paged MLA prefill attention (HDIM=320) for multi-chunk prefill (seq_len_start > 0). + pub(super) mla_prefill_paged_k: KernelHandle, + /// Batched V extraction for N-token MLA prefill: [N, nq, mla_cache_dim] → [N, nq, v_dim]. + pub(super) mla_v_extract_batched_k: KernelHandle, /// Split-K GEMM for skinny prefill matrices (M < 64). pub(super) gemm_splitk_partial_k: KernelHandle, pub(super) gemm_splitk_reduce_k: KernelHandle, @@ -210,6 +214,10 @@ pub struct Qwen3AttentionLayer { /// HDIM=512 paged prefill (BF16 KV) for Gemma-4 chunked long-context prefill pub(super) prefill_attn_paged_512_k: KernelHandle, pub(super) prefill_attn_64_k: KernelHandle, + /// HDIM=128 contiguous prefill — BR=32 (MLA unabsorbed prefill, head_dim=128) + pub(super) prefill_attn_128_k: KernelHandle, + /// HDIM=128 contiguous prefill — BR=64 (MLA unabsorbed prefill, seq_len>=256) + pub(super) prefill_attn_64_128_k: KernelHandle, pub(super) prefill_attn_paged_k: KernelHandle, pub(super) prefill_attn_paged_fp8_k: KernelHandle, pub(super) prefill_attn_paged_nvfp4_k: KernelHandle, diff --git a/crates/spark-model/src/layers/qwen3_ssm/init.rs b/crates/spark-model/src/layers/qwen3_ssm/init.rs index e41422dc..e526f034 100644 --- a/crates/spark-model/src/layers/qwen3_ssm/init.rs +++ b/crates/spark-model/src/layers/qwen3_ssm/init.rs @@ -50,18 +50,27 @@ impl Qwen3SsmLayer { deinterleave_k: gpu.kernel("ssm_preprocess", "deinterleave_qkvz")?, conv1d_k: gpu.kernel("causal_conv1d", "causal_conv1d_update")?, conv1d_l2norm_k: gpu.kernel("causal_conv1d", "causal_conv1d_update_l2norm")?, + // FP32 conv1d output prevents BF16 truncation in the recurrent + // path from compounding past ~8k tokens. The Metal backend + // (kernels/metal/common/causal_conv1d_update_l2norm.metal) only + // ships the BF16 variant; on those targets we fall back to the + // BF16 kernel via the `.0 != 0` gate at the use site + // (ssm_forward.rs). Warn instead of error: missing-on-Metal is + // expected, and a startup `error!` would page on benign cases. conv1d_l2norm_f32_k: { - let k = gpu.kernel("causal_conv1d", "causal_conv1d_update_l2norm_f32"); - match k { - Ok(h) => h, - Err(_) => { - tracing::error!( - "FP32 conv1d kernel not found — SSM long-context coherence \ - WILL degrade after ~8k tokens due to BF16 precision loss" - ); - KernelHandle(0) - } + let h = super::super::try_kernel( + gpu, + "causal_conv1d", + "causal_conv1d_update_l2norm_f32", + ); + if h.0 == 0 { + tracing::warn!( + "FP32 conv1d kernel not loaded; SSM uses BF16 conv \ + output. Expect long-context coherence drift past ~8k \ + tokens on this backend." + ); } + h }, gdn_k: gpu.kernel("gated_delta_rule", "gated_delta_rule_decode")?, gdn_f32_k: super::super::try_kernel( diff --git a/crates/spark-model/src/mistral_loader/loader_impl/phase_assemble.rs b/crates/spark-model/src/mistral_loader/loader_impl/phase_assemble.rs index 6a73f0f6..975d8fec 100644 --- a/crates/spark-model/src/mistral_loader/loader_impl/phase_assemble.rs +++ b/crates/spark-model/src/mistral_loader/loader_impl/phase_assemble.rs @@ -116,7 +116,12 @@ pub(super) fn assemble_layer( let input_norm = dense(ctx.store, &format!("{prefix}.attention_norm.weight"))?; let post_norm = dense(ctx.store, &format!("{prefix}.ffn_norm.weight"))?; - let kv_dtype = layer_kv_dtypes.get(i).copied().unwrap_or(KvCacheDtype::Fp8); + // MLA compressed latents require BF16 precision. + // build_layer_kv_dtypes returns vec![BF16; n] when kv_dtype == BF16, so + // get(i) = Some(BF16) for all valid i. The unwrap_or(BF16) is a safety + // fallback for the kv_dtype!=BF16 + high_precision_layers=0 case where + // the vec is empty — ensures MLA layers never silently get FP8. + let kv_dtype = layer_kv_dtypes.get(i).copied().unwrap_or(KvCacheDtype::Bf16); // ── MoE experts (w1=gate, w2=down, w3=up) ── let ffn = build_moe_ffn(ctx.store, i, gpu, config); diff --git a/crates/spark-server/src/anthropic/handlers.rs b/crates/spark-server/src/anthropic/handlers.rs index 910c351e..890ecc75 100644 --- a/crates/spark-server/src/anthropic/handlers.rs +++ b/crates/spark-server/src/anthropic/handlers.rs @@ -319,17 +319,30 @@ pub async fn count_tokens( msg }) .collect(); - let jinja_tools: Option> = if tools_active { - req.tools.as_ref().map(|ts| { - let oai = convert_tools(ts); - oai.iter().map(|t| serde_json::json!({ - "type": "function", - "function": { "name": t.function.name, "description": t.function.description, "parameters": t.function.parameters } - })).collect() - }) - } else { - None - }; + // Mirror template.rs: skip_template_tools (MODEL.toml) OR + // parser.suppresses_jinja_tools() (parser-level trait) means the parser's + // system_prompt() is the sole source of tool schema, so jinja_tools must + // be None here too. Passing tools to the Jinja template when either flag + // is set would count the XML block tokens that the real prompt + // never includes, inflating the returned count. Added suppresses_jinja_tools + // check here to mirror the template.rs path (6b6e755 added the trait but + // only updated the OpenAI path; Anthropic count_tokens was missed). + let parser_suppresses = state + .tool_call_parser + .as_ref() + .is_some_and(|p| p.suppresses_jinja_tools()); + let jinja_tools: Option> = + if tools_active && !state.behavior.skip_template_tools && !parser_suppresses { + req.tools.as_ref().map(|ts| { + let oai = convert_tools(ts); + oai.iter().map(|t| serde_json::json!({ + "type": "function", + "function": { "name": t.function.name, "description": t.function.description, "parameters": t.function.parameters } + })).collect() + }) + } else { + None + }; let input_tokens = match state.tokenizer.apply_chat_template_jinja( &json_messages, diff --git a/crates/spark-server/src/api/chat/mod.rs b/crates/spark-server/src/api/chat/mod.rs index 311c9c58..1e2e38d3 100644 --- a/crates/spark-server/src/api/chat/mod.rs +++ b/crates/spark-server/src/api/chat/mod.rs @@ -100,6 +100,23 @@ pub(crate) async fn chat_completions_inner( && req.tools.as_ref().is_some_and(|t| !t.is_empty()) && !req.tool_choice.as_ref().is_some_and(|tc| tc.is_none()); + // Inject parser-specific behavioral system prompt when tools are active. + // Each ToolCallParser defines guardrails (e.g. "emit immediately, + // do not narrate") that the Jinja chat template alone does not enforce. + if tools_active && let Some(ref parser) = state.tool_call_parser { + let default_choice = crate::tool_parser::ToolChoice::Mode("auto".to_string()); + let tool_choice = req.tool_choice.as_ref().unwrap_or(&default_choice); + let tool_prompt = parser.system_prompt(req.tools.as_deref().unwrap_or(&[]), tool_choice); + if let Some(first) = req.messages.first_mut().filter(|m| m.role == "system") { + first.content.text = format!("{}\n\n{}", tool_prompt, first.content.text); + } else { + req.messages.insert( + 0, + crate::openai::IncomingMessage::synthetic_system(tool_prompt), + ); + } + } + tracing::info!( "Request: model={}, messages={}, tools={}, tools_active={}, tool_choice={:?}, stream={}, temp={:?}, max_tokens={}, freq_pen={:?}, rep_pen={:?}", req.model, @@ -126,6 +143,18 @@ pub(crate) async fn chat_completions_inner( Err(resp) => return resp, }; + // ── Phase 1.5: merge server-level chat_template_kwargs default ─ + // When the client sends no thinking parameters and the server has a + // --default-chat-template-kwargs flag set, inject those kwargs into + // the request so the existing resolve_thinking() chain sees them as + // normal request-body fields. We don't mutate the resolution logic — + // we just pre-populate the field it already checks. + if let Some(ref default_kw) = state.default_chat_template_kwargs + && !req.thinking_explicitly_requested() + { + req.chat_template_kwargs = Some(default_kw.clone()); + } + // ── Phase 2: thinking resolution (pre-template) ───────────── let (enable_thinking, thinking_budget) = thinking::resolve_thinking(&state, &req, tools_active); diff --git a/crates/spark-server/src/api/chat/template.rs b/crates/spark-server/src/api/chat/template.rs index 76298705..53eedc28 100644 --- a/crates/spark-server/src/api/chat/template.rs +++ b/crates/spark-server/src/api/chat/template.rs @@ -74,15 +74,27 @@ pub(super) fn render_template( msg }) .collect(); - let jinja_tools: Option> = if tools_active { - req.tools.as_ref().map(|ts| { - ts.iter() - .map(|t| serde_json::to_value(t).unwrap_or_default()) - .collect() - }) - } else { - None - }; + // When skip_template_tools (MODEL.toml) or parser.suppresses_jinja_tools() + // is set, the parser's system_prompt() is the sole source of tool schema + // and format instructions. Passing jinja_tools here would cause the template + // to also render tool defs in its own format (e.g. XML for nemotron_h.jinja), + // producing contradictory instructions (e.g. bare_json says "emit JSON, no + // tags"; nemotron_h.jinja says "NEVER emit JSON, use XML"). + // Either flag independently suppresses jinja tool rendering. + let parser_suppresses = state + .tool_call_parser + .as_ref() + .is_some_and(|p| p.suppresses_jinja_tools()); + let jinja_tools: Option> = + if tools_active && !state.behavior.skip_template_tools && !parser_suppresses { + req.tools.as_ref().map(|ts| { + ts.iter() + .map(|t| serde_json::to_value(t).unwrap_or_default()) + .collect() + }) + } else { + None + }; // Progressive auto-compact (DISABLED BY DEFAULT 2026-04-25 — // see project_no_auto_compaction memory feedback). diff --git a/crates/spark-server/src/api/failures/duplicate.rs b/crates/spark-server/src/api/failures/duplicate.rs index f9d647f2..aa9cac8b 100644 --- a/crates/spark-server/src/api/failures/duplicate.rs +++ b/crates/spark-server/src/api/failures/duplicate.rs @@ -299,8 +299,7 @@ pub fn strip_xml_leaks_from_assistant_content( // ` envelope (dump 2026-04-25 seq=104..111) // from polluting the next turn's prompt, which previously // taught the model that emitting the prose envelope is a valid - // tool-call substitute (Phase-1 leak-then-collapse pattern, - // see history/SSM_CATASTROPHIC_FORGETTING_TODO.md). + // tool-call substitute (Phase-1 leak-then-collapse pattern). const HARNESS_TAGS: &[&str] = &["task", "file", "content", "description", "prompt", "glob"]; let mut leak_names: Vec = tool_defs .iter() diff --git a/crates/spark-server/src/cli.rs b/crates/spark-server/src/cli.rs index 749cac78..fe41c897 100644 --- a/crates/spark-server/src/cli.rs +++ b/crates/spark-server/src/cli.rs @@ -95,6 +95,16 @@ pub struct ServeArgs { #[arg(long)] pub max_thinking_budget: Option, + /// Default chat template kwargs applied when the client sends no + /// thinking parameters (no `reasoning.effort`, `chat_template_kwargs`, + /// or `enable_thinking` in the request body). A JSON object with + /// optional keys: `enable_thinking` (bool), `thinking_budget` (u32). + /// + /// Precedence (highest wins): request body → this flag → MODEL.toml. + /// Example: `--default-chat-template-kwargs '{"enable_thinking":true}'` + #[arg(long, value_name = "JSON")] + pub default_chat_template_kwargs: Option, + /// Currently slower than regular decode for hybrid SSM models. #[arg(long, default_value_t = false)] pub speculative: bool, @@ -147,10 +157,11 @@ pub struct ServeArgs { #[arg(long, default_value_t = 8)] pub max_batch_size: usize, - /// MTP head weight precision: nvfp4 (fastest, recommended — uses fused - /// device-side expert dispatch), fp8 (balanced but slower due to D2H sync - /// in MoE), bf16 (highest accuracy, most memory). - #[arg(long, default_value = "nvfp4")] + /// MTP head weight precision: bf16 (highest accuracy, most memory — + /// the default), fp8 (balanced but slower due to D2H sync in MoE), + /// nvfp4 (fastest — uses fused device-side expert dispatch; opt in + /// explicitly when throughput matters more than accuracy). + #[arg(long, default_value = "bf16")] pub mtp_quantization: String, /// MTP draft vocabulary size. Limits the LM head GEMV to the first N diff --git a/crates/spark-server/src/main_modules/app_state.rs b/crates/spark-server/src/main_modules/app_state.rs index 5e1f05c8..f9c8de87 100644 --- a/crates/spark-server/src/main_modules/app_state.rs +++ b/crates/spark-server/src/main_modules/app_state.rs @@ -72,6 +72,10 @@ pub struct AppState { /// thinking is forced OFF regardless of the request body or the /// model's MODEL.toml default. Wired from `--disable-thinking`. pub disable_thinking: bool, + /// Server-level default chat template kwargs applied when the client + /// sends no thinking parameters. Overridden per-request by the request + /// body. Wired from `--default-chat-template-kwargs`. + pub default_chat_template_kwargs: Option, /// Shared in-memory store for stateful Responses API resume /// (`previous_response_id`) and opt-in Chat-Completions storage /// (`store: true`). Bounded LRU + TTL; env-configured at startup. diff --git a/crates/spark-server/src/main_modules/kv_dtypes.rs b/crates/spark-server/src/main_modules/kv_dtypes.rs index 2c200a22..ccb0b1a1 100644 --- a/crates/spark-server/src/main_modules/kv_dtypes.rs +++ b/crates/spark-server/src/main_modules/kv_dtypes.rs @@ -6,7 +6,10 @@ /// /// When `high_precision_layers` is 0, returns an empty vec (all layers use uniform dtype). /// When non-zero, the first N and last N attention layers use BF16; middle layers use -/// the base `kv_dtype`. If `kv_dtype` is already BF16, returns empty vec (no benefit). +/// the base `kv_dtype`. +/// +/// When `kv_dtype` is BF16, every attention layer must use BF16 — returning an empty vec +/// would cause callers that fall back to `unwrap_or(Fp8)` to silently use FP8 instead. pub(crate) fn build_layer_kv_dtypes( kv_dtype: spark_runtime::kv_cache::KvCacheDtype, num_attention_layers: usize, @@ -14,7 +17,11 @@ pub(crate) fn build_layer_kv_dtypes( ) -> Vec { use spark_runtime::kv_cache::KvCacheDtype; - if high_precision_layers == 0 || kv_dtype == KvCacheDtype::Bf16 { + if kv_dtype == KvCacheDtype::Bf16 { + return vec![KvCacheDtype::Bf16; num_attention_layers]; + } + + if high_precision_layers == 0 { return vec![]; } diff --git a/crates/spark-server/src/main_modules/serve.rs b/crates/spark-server/src/main_modules/serve.rs index 8b3cdc6b..8e7159c2 100644 --- a/crates/spark-server/src/main_modules/serve.rs +++ b/crates/spark-server/src/main_modules/serve.rs @@ -463,6 +463,10 @@ pub(crate) async fn serve(mut args: cli::ServeArgs) -> Result<()> { b }, disable_thinking: args.disable_thinking, + default_chat_template_kwargs: args + .default_chat_template_kwargs + .as_ref() + .and_then(|s| crate::openai::ChatTemplateKwargs::from_json(s)), response_store, rate_limiter, conversation_store, diff --git a/crates/spark-server/src/main_modules/tests.rs b/crates/spark-server/src/main_modules/tests.rs index b6b61eb8..db89d411 100644 --- a/crates/spark-server/src/main_modules/tests.rs +++ b/crates/spark-server/src/main_modules/tests.rs @@ -83,10 +83,15 @@ fn test_build_layer_kv_dtypes_disabled() { } #[test] -fn test_build_layer_kv_dtypes_bf16_noop() { - // Already BF16 — no benefit from high-precision overlay - let dtypes = build_layer_kv_dtypes(spark_runtime::kv_cache::KvCacheDtype::Bf16, 12, 2); - assert!(dtypes.is_empty()); +fn test_build_layer_kv_dtypes_bf16_all_layers() { + // When base dtype is BF16, ALL layers must be BF16 — returning empty would + // let callers with unwrap_or(Fp8) silently downgrade MLA KV latents to FP8. + use spark_runtime::kv_cache::KvCacheDtype; + let dtypes = build_layer_kv_dtypes(KvCacheDtype::Bf16, 12, 2); + assert_eq!(dtypes.len(), 12); + for d in &dtypes { + assert_eq!(*d, KvCacheDtype::Bf16); + } } #[test] diff --git a/crates/spark-server/src/openai/chat_request.rs b/crates/spark-server/src/openai/chat_request.rs index 96ea0286..89866b3f 100644 --- a/crates/spark-server/src/openai/chat_request.rs +++ b/crates/spark-server/src/openai/chat_request.rs @@ -240,12 +240,22 @@ pub struct ReasoningConfig { } /// vLLM-style chat template kwargs. -#[derive(Debug, Deserialize)] +#[derive(Debug, Clone, Deserialize)] pub struct ChatTemplateKwargs { pub enable_thinking: Option, pub thinking_budget: Option, } +impl ChatTemplateKwargs { + /// Parse from a JSON string. Returns `None` if parsing fails or string is empty. + pub fn from_json(s: &str) -> Option { + if s.trim().is_empty() { + return None; + } + serde_json::from_str(s).ok() + } +} + /// Default thinking budget when thinking is enabled but no explicit budget set. /// 256 tokens is enough for the model to plan without overthinking — longer /// budgets waste decode throughput on reasoning that rarely improves output. diff --git a/crates/spark-server/src/openai/tests.rs b/crates/spark-server/src/openai/tests.rs index 3b43354a..b6a84920 100644 --- a/crates/spark-server/src/openai/tests.rs +++ b/crates/spark-server/src/openai/tests.rs @@ -210,3 +210,63 @@ fn responses_in_progress_event_name() { }; assert_eq!(responses_event_name(&ev), "response.in_progress"); } + +// ── ChatTemplateKwargs ──────────────────────────────────────────── + +#[test] +fn chat_template_kwargs_parse() { + let kw = ChatTemplateKwargs::from_json(r#"{"enable_thinking":true,"thinking_budget":1024}"#) + .expect("should parse"); + assert_eq!(kw.enable_thinking, Some(true)); + assert_eq!(kw.thinking_budget, Some(1024)); + + assert!(ChatTemplateKwargs::from_json("").is_none()); +} + +fn empty_chat_request() -> ChatCompletionRequest { + serde_json::from_value(serde_json::json!({ + "model": "test", + "messages": [{"role": "user", "content": "hi"}], + })) + .expect("valid chat request") +} + +#[test] +fn server_default_merged_when_request_silent() { + let mut req = empty_chat_request(); + assert!(req.chat_template_kwargs.is_none()); + + let server_kw = ChatTemplateKwargs { + enable_thinking: Some(true), + thinking_budget: None, + }; + if !req.thinking_explicitly_requested() { + req.chat_template_kwargs = Some(server_kw); + } + assert!(req.chat_template_kwargs.is_some()); + + let (enabled, budget) = req.resolve_thinking(false); + assert!(enabled); + assert!(budget.is_some()); +} + +#[test] +fn server_default_not_merged_when_request_explicit() { + let mut req: ChatCompletionRequest = serde_json::from_value(serde_json::json!({ + "model": "test", + "messages": [{"role": "user", "content": "hi"}], + "enable_thinking": true, + })) + .expect("valid chat request"); + assert!(req.thinking_explicitly_requested()); + + let server_kw = ChatTemplateKwargs { + enable_thinking: Some(false), + thinking_budget: None, + }; + if !req.thinking_explicitly_requested() { + req.chat_template_kwargs = Some(server_kw); + } + assert!(req.chat_template_kwargs.is_none()); + assert!(req.resolve_thinking(false).0); +} diff --git a/crates/spark-server/src/tool_parser.rs b/crates/spark-server/src/tool_parser.rs index ab68fc6e..bfdc866e 100644 --- a/crates/spark-server/src/tool_parser.rs +++ b/crates/spark-server/src/tool_parser.rs @@ -261,6 +261,23 @@ pub trait ToolCallParser: Send + Sync { fn broken_opener_stop_strings(&self) -> &'static [&'static str] { &[] } + + /// When `true`, the Jinja chat template should receive `jinja_tools = None` + /// because this parser's `system_prompt()` already provides the complete + /// tool schema and output-format instructions. Passing tools to the template + /// on top of that causes it to emit a conflicting format section (e.g. + /// `nemotron_h.jinja`'s XML `# Tools` block directly contradicts the + /// `bare_json` parser's "emit JSON, do not wrap in tags" instruction). + /// + /// Complementary to `ModelBehavior::skip_template_tools` (MODEL.toml): + /// either flag independently suppresses jinja tool rendering. The parser- + /// level default ensures future bare_json models stay correct without + /// requiring an explicit MODEL.toml entry. + /// + /// Default `false` — templates render tool definitions normally. + fn suppresses_jinja_tools(&self) -> bool { + false + } } impl std::fmt::Display for dyn ToolCallParser { diff --git a/crates/spark-server/src/tool_parser/bare_json.rs b/crates/spark-server/src/tool_parser/bare_json.rs index efade494..9a418976 100644 --- a/crates/spark-server/src/tool_parser/bare_json.rs +++ b/crates/spark-server/src/tool_parser/bare_json.rs @@ -49,6 +49,10 @@ impl ToolCallParser for BareJsonParser { prompt } + fn suppresses_jinja_tools(&self) -> bool { + true + } + fn format_tool_calls(&self, calls: &[IncomingToolCall]) -> String { let mut out = String::new(); for tc in calls { diff --git a/kernels/gb10/common/inferspark_prefill_128.cu b/kernels/gb10/common/inferspark_prefill_128.cu new file mode 100644 index 00000000..72f31178 --- /dev/null +++ b/kernels/gb10/common/inferspark_prefill_128.cu @@ -0,0 +1,760 @@ +// SPDX-License-Identifier: AGPL-3.0-only + +// inferspark_prefill_128 — HDIM=128 variants of the inferspark flash-attention +// kernels, for MLA unabsorbed prefill (Mistral Small 4, head_dim=128). +// +// Root cause of the Mistral Small 4 long-context bug: inferspark_prefill.cu +// hardcodes HDIM=256. For MLA head_dim=128, the Q/K tile loads cover columns +// 0..255 in shared memory but each head only has 128 valid elements. Columns +// 128..255 read from the adjacent head (Q_head+1) and the next K row +// (K[k_row+1][0..127]), polluting QK^T with cross-head and cross-row data. +// Short sequences happen to tolerate the noise; long-range retrieval (>1K +// tokens) fails because the contaminated scores suppress correct early-context +// attention, producing repetitive or incoherent output. +// +// These HDIM=128 variants are structurally identical to the HDIM=256 originals: +// - BR=32 kernel: 4 warps (128 threads), for any seq_len +// - BR=64 kernel: 8 warps (256 threads), for seq_len >= 256 +// Only compile-time constants differ (N_TILES_PER_WARP=8, TILE_CHUNKS=512, etc.) + +#include + +#define BR 32 +#define BC 32 +#define HDIM 128 +#define PAD_KV 8 +#define HDIM_PAD (HDIM + PAD_KV) // 136 +#define PAD_P 8 +#define N_TILES_PER_WARP ((HDIM / 8) / 2) // 8 +#define TILE_CHUNKS (BR * (HDIM / 8)) // 512 +#define BR64 64 +#define TILE_CHUNKS_Q64 (BR64 * (HDIM / 8)) // 1024 +#define TILE_CHUNKS_KV (BC * (HDIM / 8)) // 512 + +// ============================================================================ +// BR=32 HDIM=128 variant (4 warps / 128 threads). +// Grid: (num_q_heads, ceil(seq_len/32), batch) Block: (128, 1, 1) +// Shared memory (~37 KB): +// smem_Q [32][136] BF16 = 8.5 KB +// smem_K [2][32][136] BF16 = 17.0 KB (double-buffered) +// smem_V [32][136] BF16 = 8.5 KB +// smem_P [32][40] BF16 = 2.5 KB +// smem_ml [32][2] FP32 = 0.25 KB +// ============================================================================ +extern "C" __global__ void inferspark_prefill_hd128( + const __nv_bfloat16* __restrict__ Q, + const __nv_bfloat16* __restrict__ K, + const __nv_bfloat16* __restrict__ V, + __nv_bfloat16* __restrict__ O, + const unsigned int seq_len, + const unsigned int num_q_heads, + const unsigned int num_kv_heads, + const unsigned int head_dim, + const float inv_sqrt_d, + const unsigned int causal, + const unsigned int sliding_window +) { + const unsigned int q_head = blockIdx.x; + const unsigned int q_block = blockIdx.y; + const unsigned int batch = blockIdx.z; + const unsigned int tid = threadIdx.x; + const unsigned int warp_id = tid / 32; + const unsigned int lane_id = tid % 32; + + if (q_head >= num_q_heads) return; + + const unsigned int q_start = q_block * BR; + if (q_start >= seq_len) return; + const unsigned int q_end = min(q_start + BR, seq_len); + const unsigned int q_len = q_end - q_start; + + const unsigned int gqa_ratio = num_q_heads / num_kv_heads; + const unsigned int kv_head = q_head / gqa_ratio; + const unsigned int q_seq_stride = num_q_heads * head_dim; + const unsigned int kv_seq_stride = num_kv_heads * head_dim; + + const __nv_bfloat16* Q_batch = Q + (unsigned long long)batch * seq_len * q_seq_stride; + const __nv_bfloat16* K_batch = K + (unsigned long long)batch * seq_len * kv_seq_stride; + const __nv_bfloat16* V_batch = V + (unsigned long long)batch * seq_len * kv_seq_stride; + __nv_bfloat16* O_batch = O + (unsigned long long)batch * seq_len * q_seq_stride; + + __shared__ __nv_bfloat16 smem_Q[BR][HDIM_PAD]; + __shared__ __nv_bfloat16 smem_K[2][BC][HDIM_PAD]; + __shared__ __nv_bfloat16 smem_V[BC][HDIM_PAD]; + __shared__ __nv_bfloat16 smem_P[BR][BC + PAD_P]; + __shared__ float smem_ml[BR][2]; + + const unsigned int group_id = lane_id >> 2; + const unsigned int tid_in_group = lane_id & 3; + const unsigned int qk_warp_m = (warp_id & 1) * 16; + const unsigned int pv_warp_m = (warp_id & 1) * 16; + const unsigned int pv_n_start = (warp_id >> 1) * N_TILES_PER_WARP; + + float acc_o[N_TILES_PER_WARP][4]; + #pragma unroll + for (int i = 0; i < N_TILES_PER_WARP; i++) { + acc_o[i][0] = 0.0f; acc_o[i][1] = 0.0f; + acc_o[i][2] = 0.0f; acc_o[i][3] = 0.0f; + } + + float m_r0 = -1e30f, m_r1 = -1e30f; + float l_r0 = 0.0f, l_r1 = 0.0f; + const unsigned int p_smem_stride = BC + PAD_P; + + unsigned int num_kv_blocks = (seq_len + BC - 1) / BC; + if (causal) { + unsigned int max_kv_block = (q_end - 1) / BC; + num_kv_blocks = min(num_kv_blocks, max_kv_block + 1); + } + + // Merged Q + K[0] load + { + const unsigned int chunks_per_row = HDIM / 8; // 16 + for (unsigned int idx = tid; idx < TILE_CHUNKS; idx += 128) { + unsigned int row = idx / chunks_per_row; + unsigned int col = (idx % chunks_per_row) * 8; + unsigned int q_row = q_start + row; + unsigned int addr = __cvta_generic_to_shared(&smem_Q[row][col]); + if (q_row < seq_len) { + const void* g = (const void*)&Q_batch[q_row * q_seq_stride + q_head * head_dim + col]; + asm volatile("cp.async.cg.shared.global [%0], [%1], 16;" :: "r"(addr), "l"(g)); + } else { + *((uint4*)&smem_Q[row][col]) = make_uint4(0, 0, 0, 0); + } + } + if (num_kv_blocks > 0) { + for (unsigned int idx = tid; idx < TILE_CHUNKS; idx += 128) { + unsigned int row = idx / chunks_per_row; + unsigned int col = (idx % chunks_per_row) * 8; + unsigned int addr = __cvta_generic_to_shared(&smem_K[0][row][col]); + if (row < seq_len) { + const void* g = (const void*)&K_batch[row * kv_seq_stride + kv_head * head_dim + col]; + asm volatile("cp.async.cg.shared.global [%0], [%1], 16;" :: "r"(addr), "l"(g)); + } else { + *((uint4*)&smem_K[0][row][col]) = make_uint4(0, 0, 0, 0); + } + } + } + asm volatile("cp.async.commit_group;"); + asm volatile("cp.async.wait_group 0;"); + } + __syncthreads(); + + for (unsigned int kv_block = 0; kv_block < num_kv_blocks; kv_block++) { + unsigned int kv_start = kv_block * BC; + unsigned int kv_end = min(kv_start + BC, seq_len); + unsigned int kv_len = kv_end - kv_start; + unsigned int buf = kv_block & 1; + + // Async V load (overlaps QK^T) + { + const unsigned int chunks_per_row = HDIM / 8; + for (unsigned int idx = tid; idx < TILE_CHUNKS; idx += 128) { + unsigned int row = idx / chunks_per_row; + unsigned int col = (idx % chunks_per_row) * 8; + unsigned int v_row = kv_start + row; + unsigned int addr = __cvta_generic_to_shared(&smem_V[row][col]); + if (v_row < seq_len) { + const void* g = (const void*)&V_batch[v_row * kv_seq_stride + kv_head * head_dim + col]; + asm volatile("cp.async.cg.shared.global [%0], [%1], 16;" :: "r"(addr), "l"(g)); + } else { + *((uint4*)&smem_V[row][col]) = make_uint4(0, 0, 0, 0); + } + } + asm volatile("cp.async.commit_group;"); + } + + // QK^T (warps 0-1) + float acc_s[4][4]; + if (warp_id < 2) { + #pragma unroll + for (int i = 0; i < 4; i++) { + acc_s[i][0] = 0.0f; acc_s[i][1] = 0.0f; + acc_s[i][2] = 0.0f; acc_s[i][3] = 0.0f; + } + const unsigned short* sQ = (const unsigned short*)smem_Q; + const unsigned short* sK = (const unsigned short*)smem_K[buf]; + #pragma unroll + for (unsigned int ks = 0; ks < (HDIM / 16); ks++) { // 8 k-tiles + unsigned int k_base = ks * 16; + unsigned int ar0 = qk_warp_m + group_id; + unsigned int ar1 = ar0 + 8; + unsigned int ac0 = k_base + tid_in_group * 2; + unsigned int ac1 = ac0 + 8; + unsigned int a0 = *(const unsigned int*)&sQ[ar0 * HDIM_PAD + ac0]; + unsigned int a1 = *(const unsigned int*)&sQ[ar1 * HDIM_PAD + ac0]; + unsigned int a2 = *(const unsigned int*)&sQ[ar0 * HDIM_PAD + ac1]; + unsigned int a3 = *(const unsigned int*)&sQ[ar1 * HDIM_PAD + ac1]; + #pragma unroll + for (int nt = 0; nt < 4; nt++) { + unsigned int n_col = nt * 8 + group_id; + unsigned int k0 = k_base + tid_in_group * 2; + unsigned int k1 = k0 + 8; + unsigned int b0 = ((unsigned int)sK[n_col * HDIM_PAD + k0 + 1] << 16) | + (unsigned int)sK[n_col * HDIM_PAD + k0]; + unsigned int b1 = ((unsigned int)sK[n_col * HDIM_PAD + k1 + 1] << 16) | + (unsigned int)sK[n_col * HDIM_PAD + k1]; + asm volatile( + "mma.sync.aligned.m16n8k16.row.col.f32.bf16.bf16.f32 " + "{%0,%1,%2,%3},{%4,%5,%6,%7},{%8,%9},{%10,%11,%12,%13};" + : "=f"(acc_s[nt][0]),"=f"(acc_s[nt][1]), + "=f"(acc_s[nt][2]),"=f"(acc_s[nt][3]) + : "r"(a0),"r"(a1),"r"(a2),"r"(a3),"r"(b0),"r"(b1), + "f"(acc_s[nt][0]),"f"(acc_s[nt][1]), + "f"(acc_s[nt][2]),"f"(acc_s[nt][3]) + ); + } + } + + unsigned int row0 = qk_warp_m + group_id; + unsigned int row1 = row0 + 8; + #pragma unroll + for (int nt = 0; nt < 4; nt++) { + acc_s[nt][0] *= inv_sqrt_d; acc_s[nt][1] *= inv_sqrt_d; + acc_s[nt][2] *= inv_sqrt_d; acc_s[nt][3] *= inv_sqrt_d; + unsigned int col0 = nt * 8 + tid_in_group * 2; + unsigned int col1 = col0 + 1; + if (causal) { + unsigned int qr0 = q_start + row0, qr1 = q_start + row1; + if (kv_start + col0 > qr0) acc_s[nt][0] = -1e30f; + if (kv_start + col1 > qr0) acc_s[nt][1] = -1e30f; + if (kv_start + col0 > qr1) acc_s[nt][2] = -1e30f; + if (kv_start + col1 > qr1) acc_s[nt][3] = -1e30f; + if (sliding_window > 0) { + unsigned int k0 = kv_start + col0, k1 = kv_start + col1; + if (k0 <= qr0 && qr0 - k0 >= sliding_window) acc_s[nt][0] = -1e30f; + if (k1 <= qr0 && qr0 - k1 >= sliding_window) acc_s[nt][1] = -1e30f; + if (k0 <= qr1 && qr1 - k0 >= sliding_window) acc_s[nt][2] = -1e30f; + if (k1 <= qr1 && qr1 - k1 >= sliding_window) acc_s[nt][3] = -1e30f; + } + } + if (col0 >= kv_len) { acc_s[nt][0] = -1e30f; acc_s[nt][2] = -1e30f; } + if (col1 >= kv_len) { acc_s[nt][1] = -1e30f; acc_s[nt][3] = -1e30f; } + if (row0 >= q_len) { acc_s[nt][0] = -1e30f; acc_s[nt][1] = -1e30f; } + if (row1 >= q_len) { acc_s[nt][2] = -1e30f; acc_s[nt][3] = -1e30f; } + } + + float rmax0 = -1e30f, rmax1 = -1e30f; + #pragma unroll + for (int nt = 0; nt < 4; nt++) { + rmax0 = fmaxf(rmax0, fmaxf(acc_s[nt][0], acc_s[nt][1])); + rmax1 = fmaxf(rmax1, fmaxf(acc_s[nt][2], acc_s[nt][3])); + } + rmax0 = fmaxf(rmax0, __shfl_xor_sync(0xFFFFFFFF, rmax0, 1)); + rmax0 = fmaxf(rmax0, __shfl_xor_sync(0xFFFFFFFF, rmax0, 2)); + rmax1 = fmaxf(rmax1, __shfl_xor_sync(0xFFFFFFFF, rmax1, 1)); + rmax1 = fmaxf(rmax1, __shfl_xor_sync(0xFFFFFFFF, rmax1, 2)); + + float m_new0 = fmaxf(m_r0, rmax0), exp_old0 = __expf(m_r0 - m_new0); + l_r0 *= exp_old0; + #pragma unroll + for (int i = 0; i < N_TILES_PER_WARP; i++) { + acc_o[i][0] *= exp_old0; acc_o[i][1] *= exp_old0; + } + m_r0 = m_new0; + + float m_new1 = fmaxf(m_r1, rmax1), exp_old1 = __expf(m_r1 - m_new1); + l_r1 *= exp_old1; + #pragma unroll + for (int i = 0; i < N_TILES_PER_WARP; i++) { + acc_o[i][2] *= exp_old1; acc_o[i][3] *= exp_old1; + } + m_r1 = m_new1; + + float sum0 = 0.0f, sum1 = 0.0f; + #pragma unroll + for (int nt = 0; nt < 4; nt++) { + float p00 = __expf(acc_s[nt][0] - m_r0), p01 = __expf(acc_s[nt][1] - m_r0); + float p10 = __expf(acc_s[nt][2] - m_r1), p11 = __expf(acc_s[nt][3] - m_r1); + sum0 += p00 + p01; sum1 += p10 + p11; + unsigned int col0 = nt * 8 + tid_in_group * 2; + smem_P[row0][col0] = __float2bfloat16(p00); + smem_P[row0][col0 + 1] = __float2bfloat16(p01); + smem_P[row1][col0] = __float2bfloat16(p10); + smem_P[row1][col0 + 1] = __float2bfloat16(p11); + } + sum0 += __shfl_xor_sync(0xFFFFFFFF, sum0, 1); + sum0 += __shfl_xor_sync(0xFFFFFFFF, sum0, 2); + sum1 += __shfl_xor_sync(0xFFFFFFFF, sum1, 1); + sum1 += __shfl_xor_sync(0xFFFFFFFF, sum1, 2); + l_r0 += sum0; l_r1 += sum1; + if (tid_in_group == 0) { + smem_ml[row0][0] = m_r0; smem_ml[row0][1] = l_r0; + smem_ml[row1][0] = m_r1; smem_ml[row1][1] = l_r1; + } + } + + asm volatile("cp.async.wait_group 0;"); + __syncthreads(); + + // Warps 2-3: rescale accumulators to match current m + if (warp_id >= 2) { + unsigned int row0 = pv_warp_m + group_id; + unsigned int row1 = row0 + 8; + float cur_m0 = smem_ml[row0][0], cur_m1 = smem_ml[row1][0]; + float exp_r0 = __expf(m_r0 - cur_m0), exp_r1 = __expf(m_r1 - cur_m1); + #pragma unroll + for (int i = 0; i < N_TILES_PER_WARP; i++) { + acc_o[i][0] *= exp_r0; acc_o[i][1] *= exp_r0; + acc_o[i][2] *= exp_r1; acc_o[i][3] *= exp_r1; + } + m_r0 = cur_m0; m_r1 = cur_m1; + } + + // Prefetch K[kv_block+1] (overlaps PV below) + if (kv_block + 1 < num_kv_blocks) { + unsigned int next_start = (kv_block + 1) * BC; + const unsigned int cpr = HDIM / 8; + for (unsigned int idx = tid; idx < TILE_CHUNKS; idx += 128) { + unsigned int row = idx / cpr; + unsigned int col = (idx % cpr) * 8; + unsigned int k_row = next_start + row; + unsigned int addr = __cvta_generic_to_shared(&smem_K[1 - buf][row][col]); + if (k_row < seq_len) { + const void* g = (const void*)&K_batch[k_row * kv_seq_stride + kv_head * head_dim + col]; + asm volatile("cp.async.cg.shared.global [%0], [%1], 16;" :: "r"(addr), "l"(g)); + } else { + *((uint4*)&smem_K[1 - buf][row][col]) = make_uint4(0, 0, 0, 0); + } + } + asm volatile("cp.async.commit_group;"); + } + + // PV MMA (all 4 warps) + { + const unsigned short* sP = (const unsigned short*)smem_P; + const unsigned short* sV = (const unsigned short*)smem_V; + #pragma unroll + for (unsigned int ks = 0; ks < 2; ks++) { + unsigned int k_off = ks * 16; + unsigned int ar0 = pv_warp_m + group_id; + unsigned int ar1 = ar0 + 8; + unsigned int ac0 = k_off + tid_in_group * 2; + unsigned int ac1 = ac0 + 8; + unsigned int a0 = *(const unsigned int*)&sP[ar0 * p_smem_stride + ac0]; + unsigned int a1 = *(const unsigned int*)&sP[ar1 * p_smem_stride + ac0]; + unsigned int a2 = *(const unsigned int*)&sP[ar0 * p_smem_stride + ac1]; + unsigned int a3 = *(const unsigned int*)&sP[ar1 * p_smem_stride + ac1]; + #pragma unroll + for (int nt = 0; nt < N_TILES_PER_WARP; nt++) { + unsigned int n_col = (pv_n_start + nt) * 8 + group_id; + unsigned int k0 = k_off + tid_in_group * 2; + unsigned int k1 = k0 + 8; + unsigned int b0 = ((unsigned int)sV[(k0+1) * HDIM_PAD + n_col] << 16) | + (unsigned int)sV[ k0 * HDIM_PAD + n_col]; + unsigned int b1 = ((unsigned int)sV[(k1+1) * HDIM_PAD + n_col] << 16) | + (unsigned int)sV[ k1 * HDIM_PAD + n_col]; + asm volatile( + "mma.sync.aligned.m16n8k16.row.col.f32.bf16.bf16.f32 " + "{%0,%1,%2,%3},{%4,%5,%6,%7},{%8,%9},{%10,%11,%12,%13};" + : "=f"(acc_o[nt][0]),"=f"(acc_o[nt][1]), + "=f"(acc_o[nt][2]),"=f"(acc_o[nt][3]) + : "r"(a0),"r"(a1),"r"(a2),"r"(a3),"r"(b0),"r"(b1), + "f"(acc_o[nt][0]),"f"(acc_o[nt][1]), + "f"(acc_o[nt][2]),"f"(acc_o[nt][3]) + ); + } + } + } + + if (kv_block + 1 < num_kv_blocks) { + asm volatile("cp.async.wait_group 0;"); + } + __syncthreads(); + } + + // Final normalize + store + { + unsigned int row0 = pv_warp_m + group_id; + unsigned int row1 = row0 + 8; + float inv_l0, inv_l1; + if (warp_id < 2) { + inv_l0 = (l_r0 > 0.0f) ? (1.0f / l_r0) : 0.0f; + inv_l1 = (l_r1 > 0.0f) ? (1.0f / l_r1) : 0.0f; + } else { + inv_l0 = (smem_ml[row0][1] > 0.0f) ? (1.0f / smem_ml[row0][1]) : 0.0f; + inv_l1 = (smem_ml[row1][1] > 0.0f) ? (1.0f / smem_ml[row1][1]) : 0.0f; + } + __nv_bfloat16* o_base = O_batch + q_head * head_dim; + #pragma unroll + for (int nt = 0; nt < N_TILES_PER_WARP; nt++) { + unsigned int col0 = (pv_n_start + nt) * 8 + tid_in_group * 2; + unsigned int gr0 = q_start + row0; + unsigned int gr1 = q_start + row1; + if (gr0 < seq_len && row0 < q_len && col0 < head_dim) { + unsigned int lo = (unsigned int)__bfloat16_as_ushort(__float2bfloat16(acc_o[nt][0] * inv_l0)); + unsigned int hi = (unsigned int)__bfloat16_as_ushort(__float2bfloat16(acc_o[nt][1] * inv_l0)); + *(unsigned int*)&o_base[gr0 * q_seq_stride + col0] = lo | (hi << 16); + } + if (gr1 < seq_len && row1 < q_len && col0 < head_dim) { + unsigned int lo = (unsigned int)__bfloat16_as_ushort(__float2bfloat16(acc_o[nt][2] * inv_l1)); + unsigned int hi = (unsigned int)__bfloat16_as_ushort(__float2bfloat16(acc_o[nt][3] * inv_l1)); + *(unsigned int*)&o_base[gr1 * q_seq_stride + col0] = lo | (hi << 16); + } + } + } +} + +// ============================================================================ +// BR=64 HDIM=128 variant (8 warps / 256 threads) for seq_len >= 256. +// Grid: (num_q_heads, ceil(seq_len/64), batch) Block: (256, 1, 1) +// Shared memory (~49 KB): +// smem_Q [64][136] BF16 = 17.0 KB +// smem_K64 [2][32][136] BF16 = 17.0 KB (double-buffered) +// smem_V64 [32][136] BF16 = 8.5 KB +// smem_P64 [64][40] BF16 = 5.0 KB +// smem_ml64[64][2] FP32 = 0.5 KB +// ============================================================================ +extern "C" __global__ void inferspark_prefill_64_hd128( + const __nv_bfloat16* __restrict__ Q, + const __nv_bfloat16* __restrict__ K, + const __nv_bfloat16* __restrict__ V, + __nv_bfloat16* __restrict__ O, + const unsigned int seq_len, + const unsigned int num_q_heads, + const unsigned int num_kv_heads, + const unsigned int head_dim, + const float inv_sqrt_d, + const unsigned int causal, + const unsigned int sliding_window +) { + const unsigned int q_head = blockIdx.x; + const unsigned int q_block = blockIdx.y; + const unsigned int batch = blockIdx.z; + const unsigned int tid = threadIdx.x; + const unsigned int warp_id = tid / 32; + const unsigned int lane_id = tid % 32; + + if (q_head >= num_q_heads) return; + + const unsigned int q_start = q_block * BR64; + if (q_start >= seq_len) return; + const unsigned int q_end = min(q_start + BR64, seq_len); + const unsigned int q_len = q_end - q_start; + + const unsigned int gqa_ratio = num_q_heads / num_kv_heads; + const unsigned int kv_head = q_head / gqa_ratio; + const unsigned int q_seq_stride = num_q_heads * head_dim; + const unsigned int kv_seq_stride = num_kv_heads * head_dim; + + const __nv_bfloat16* Q_batch = Q + (unsigned long long)batch * seq_len * q_seq_stride; + const __nv_bfloat16* K_batch = K + (unsigned long long)batch * seq_len * kv_seq_stride; + const __nv_bfloat16* V_batch = V + (unsigned long long)batch * seq_len * kv_seq_stride; + __nv_bfloat16* O_batch = O + (unsigned long long)batch * seq_len * q_seq_stride; + + __shared__ __nv_bfloat16 smem_Q[BR64][HDIM_PAD]; + __shared__ __nv_bfloat16 smem_K64[2][BC][HDIM_PAD]; + __shared__ __nv_bfloat16 smem_V64[BC][HDIM_PAD]; + __shared__ __nv_bfloat16 smem_P64[BR64][BC + PAD_P]; + __shared__ float smem_ml64[BR64][2]; + + const unsigned int group_id = lane_id >> 2; + const unsigned int tid_in_group = lane_id & 3; + const unsigned int qk_warp_m = warp_id * 16; // valid for warp_id < 4 + const unsigned int pv_warp_m = (warp_id & 3) * 16; + const unsigned int pv_n_start = (warp_id >> 2) * N_TILES_PER_WARP; + + float acc_o[N_TILES_PER_WARP][4]; + #pragma unroll + for (int i = 0; i < N_TILES_PER_WARP; i++) { + acc_o[i][0] = 0.0f; acc_o[i][1] = 0.0f; + acc_o[i][2] = 0.0f; acc_o[i][3] = 0.0f; + } + + float m_r0 = -1e30f, m_r1 = -1e30f; + float l_r0 = 0.0f, l_r1 = 0.0f; + const unsigned int p_smem_stride64 = BC + PAD_P; + + unsigned int num_kv_blocks = (seq_len + BC - 1) / BC; + if (causal) { + unsigned int max_kv_block = (q_end - 1) / BC; + num_kv_blocks = min(num_kv_blocks, max_kv_block + 1); + } + + // Merged Q + K[0] load + { + const unsigned int cpr = HDIM / 8; // 16 chunks per row + for (unsigned int idx = tid; idx < TILE_CHUNKS_Q64; idx += 256) { + unsigned int row = idx / cpr; + unsigned int col = (idx % cpr) * 8; + unsigned int q_row = q_start + row; + unsigned int addr = __cvta_generic_to_shared(&smem_Q[row][col]); + if (q_row < seq_len) { + const void* g = (const void*)&Q_batch[q_row * q_seq_stride + q_head * head_dim + col]; + asm volatile("cp.async.cg.shared.global [%0], [%1], 16;" :: "r"(addr), "l"(g)); + } else { + *((uint4*)&smem_Q[row][col]) = make_uint4(0, 0, 0, 0); + } + } + if (num_kv_blocks > 0) { + for (unsigned int idx = tid; idx < TILE_CHUNKS_KV; idx += 256) { + unsigned int row = idx / cpr; + unsigned int col = (idx % cpr) * 8; + unsigned int addr = __cvta_generic_to_shared(&smem_K64[0][row][col]); + if (row < seq_len) { + const void* g = (const void*)&K_batch[row * kv_seq_stride + kv_head * head_dim + col]; + asm volatile("cp.async.cg.shared.global [%0], [%1], 16;" :: "r"(addr), "l"(g)); + } else { + *((uint4*)&smem_K64[0][row][col]) = make_uint4(0, 0, 0, 0); + } + } + } + asm volatile("cp.async.commit_group;"); + asm volatile("cp.async.wait_group 0;"); + } + __syncthreads(); + + for (unsigned int kv_block = 0; kv_block < num_kv_blocks; kv_block++) { + unsigned int kv_start = kv_block * BC; + unsigned int kv_end = min(kv_start + BC, seq_len); + unsigned int kv_len = kv_end - kv_start; + unsigned int buf = kv_block & 1; + + // Async V load + { + const unsigned int cpr = HDIM / 8; + for (unsigned int idx = tid; idx < TILE_CHUNKS_KV; idx += 256) { + unsigned int row = idx / cpr; + unsigned int col = (idx % cpr) * 8; + unsigned int v_row = kv_start + row; + unsigned int addr = __cvta_generic_to_shared(&smem_V64[row][col]); + if (v_row < seq_len) { + const void* g = (const void*)&V_batch[v_row * kv_seq_stride + kv_head * head_dim + col]; + asm volatile("cp.async.cg.shared.global [%0], [%1], 16;" :: "r"(addr), "l"(g)); + } else { + *((uint4*)&smem_V64[row][col]) = make_uint4(0, 0, 0, 0); + } + } + asm volatile("cp.async.commit_group;"); + } + + // QK^T (warps 0-3) + float acc_s[4][4]; + if (warp_id < 4) { + #pragma unroll + for (int i = 0; i < 4; i++) { + acc_s[i][0] = 0.0f; acc_s[i][1] = 0.0f; + acc_s[i][2] = 0.0f; acc_s[i][3] = 0.0f; + } + const unsigned short* sQ = (const unsigned short*)smem_Q; + const unsigned short* sK = (const unsigned short*)smem_K64[buf]; + #pragma unroll + for (unsigned int ks = 0; ks < (HDIM / 16); ks++) { // 8 k-tiles + unsigned int k_base = ks * 16; + unsigned int ar0 = qk_warp_m + group_id; + unsigned int ar1 = ar0 + 8; + unsigned int ac0 = k_base + tid_in_group * 2; + unsigned int ac1 = ac0 + 8; + unsigned int a0 = *(const unsigned int*)&sQ[ar0 * HDIM_PAD + ac0]; + unsigned int a1 = *(const unsigned int*)&sQ[ar1 * HDIM_PAD + ac0]; + unsigned int a2 = *(const unsigned int*)&sQ[ar0 * HDIM_PAD + ac1]; + unsigned int a3 = *(const unsigned int*)&sQ[ar1 * HDIM_PAD + ac1]; + #pragma unroll + for (int nt = 0; nt < 4; nt++) { + unsigned int n_col = nt * 8 + group_id; + unsigned int k0 = k_base + tid_in_group * 2; + unsigned int k1 = k0 + 8; + unsigned int b0 = ((unsigned int)sK[n_col * HDIM_PAD + k0 + 1] << 16) | + (unsigned int)sK[n_col * HDIM_PAD + k0]; + unsigned int b1 = ((unsigned int)sK[n_col * HDIM_PAD + k1 + 1] << 16) | + (unsigned int)sK[n_col * HDIM_PAD + k1]; + asm volatile( + "mma.sync.aligned.m16n8k16.row.col.f32.bf16.bf16.f32 " + "{%0,%1,%2,%3},{%4,%5,%6,%7},{%8,%9},{%10,%11,%12,%13};" + : "=f"(acc_s[nt][0]),"=f"(acc_s[nt][1]), + "=f"(acc_s[nt][2]),"=f"(acc_s[nt][3]) + : "r"(a0),"r"(a1),"r"(a2),"r"(a3),"r"(b0),"r"(b1), + "f"(acc_s[nt][0]),"f"(acc_s[nt][1]), + "f"(acc_s[nt][2]),"f"(acc_s[nt][3]) + ); + } + } + + unsigned int row0 = qk_warp_m + group_id; + unsigned int row1 = row0 + 8; + #pragma unroll + for (int nt = 0; nt < 4; nt++) { + acc_s[nt][0] *= inv_sqrt_d; acc_s[nt][1] *= inv_sqrt_d; + acc_s[nt][2] *= inv_sqrt_d; acc_s[nt][3] *= inv_sqrt_d; + unsigned int col0 = nt * 8 + tid_in_group * 2; + unsigned int col1 = col0 + 1; + if (causal) { + unsigned int qr0 = q_start + row0, qr1 = q_start + row1; + if (kv_start + col0 > qr0) acc_s[nt][0] = -1e30f; + if (kv_start + col1 > qr0) acc_s[nt][1] = -1e30f; + if (kv_start + col0 > qr1) acc_s[nt][2] = -1e30f; + if (kv_start + col1 > qr1) acc_s[nt][3] = -1e30f; + if (sliding_window > 0) { + unsigned int k0 = kv_start + col0, k1 = kv_start + col1; + if (k0 <= qr0 && qr0 - k0 >= sliding_window) acc_s[nt][0] = -1e30f; + if (k1 <= qr0 && qr0 - k1 >= sliding_window) acc_s[nt][1] = -1e30f; + if (k0 <= qr1 && qr1 - k0 >= sliding_window) acc_s[nt][2] = -1e30f; + if (k1 <= qr1 && qr1 - k1 >= sliding_window) acc_s[nt][3] = -1e30f; + } + } + if (col0 >= kv_len) { acc_s[nt][0] = -1e30f; acc_s[nt][2] = -1e30f; } + if (col1 >= kv_len) { acc_s[nt][1] = -1e30f; acc_s[nt][3] = -1e30f; } + if (row0 >= q_len) { acc_s[nt][0] = -1e30f; acc_s[nt][1] = -1e30f; } + if (row1 >= q_len) { acc_s[nt][2] = -1e30f; acc_s[nt][3] = -1e30f; } + } + + float rmax0 = -1e30f, rmax1 = -1e30f; + #pragma unroll + for (int nt = 0; nt < 4; nt++) { + rmax0 = fmaxf(rmax0, fmaxf(acc_s[nt][0], acc_s[nt][1])); + rmax1 = fmaxf(rmax1, fmaxf(acc_s[nt][2], acc_s[nt][3])); + } + rmax0 = fmaxf(rmax0, __shfl_xor_sync(0xFFFFFFFF, rmax0, 1)); + rmax0 = fmaxf(rmax0, __shfl_xor_sync(0xFFFFFFFF, rmax0, 2)); + rmax1 = fmaxf(rmax1, __shfl_xor_sync(0xFFFFFFFF, rmax1, 1)); + rmax1 = fmaxf(rmax1, __shfl_xor_sync(0xFFFFFFFF, rmax1, 2)); + + float m_new0 = fmaxf(m_r0, rmax0), exp_old0 = __expf(m_r0 - m_new0); + l_r0 *= exp_old0; + #pragma unroll + for (int i = 0; i < N_TILES_PER_WARP; i++) { + acc_o[i][0] *= exp_old0; acc_o[i][1] *= exp_old0; + } + m_r0 = m_new0; + + float m_new1 = fmaxf(m_r1, rmax1), exp_old1 = __expf(m_r1 - m_new1); + l_r1 *= exp_old1; + #pragma unroll + for (int i = 0; i < N_TILES_PER_WARP; i++) { + acc_o[i][2] *= exp_old1; acc_o[i][3] *= exp_old1; + } + m_r1 = m_new1; + + float sum0 = 0.0f, sum1 = 0.0f; + #pragma unroll + for (int nt = 0; nt < 4; nt++) { + float p00 = __expf(acc_s[nt][0] - m_r0), p01 = __expf(acc_s[nt][1] - m_r0); + float p10 = __expf(acc_s[nt][2] - m_r1), p11 = __expf(acc_s[nt][3] - m_r1); + sum0 += p00 + p01; sum1 += p10 + p11; + unsigned int col0 = nt * 8 + tid_in_group * 2; + smem_P64[row0][col0] = __float2bfloat16(p00); + smem_P64[row0][col0 + 1] = __float2bfloat16(p01); + smem_P64[row1][col0] = __float2bfloat16(p10); + smem_P64[row1][col0 + 1] = __float2bfloat16(p11); + } + sum0 += __shfl_xor_sync(0xFFFFFFFF, sum0, 1); + sum0 += __shfl_xor_sync(0xFFFFFFFF, sum0, 2); + sum1 += __shfl_xor_sync(0xFFFFFFFF, sum1, 1); + sum1 += __shfl_xor_sync(0xFFFFFFFF, sum1, 2); + l_r0 += sum0; l_r1 += sum1; + if (tid_in_group == 0) { + smem_ml64[row0][0] = m_r0; smem_ml64[row0][1] = l_r0; + smem_ml64[row1][0] = m_r1; smem_ml64[row1][1] = l_r1; + } + } + + asm volatile("cp.async.wait_group 0;"); + __syncthreads(); + + // Warps 4-7: rescale accumulators to match current m + if (warp_id >= 4) { + unsigned int row0 = pv_warp_m + group_id; + unsigned int row1 = row0 + 8; + float cur_m0 = smem_ml64[row0][0], cur_m1 = smem_ml64[row1][0]; + float exp_r0 = __expf(m_r0 - cur_m0), exp_r1 = __expf(m_r1 - cur_m1); + #pragma unroll + for (int i = 0; i < N_TILES_PER_WARP; i++) { + acc_o[i][0] *= exp_r0; acc_o[i][1] *= exp_r0; + acc_o[i][2] *= exp_r1; acc_o[i][3] *= exp_r1; + } + m_r0 = cur_m0; m_r1 = cur_m1; + } + + // Prefetch K[kv_block+1] + if (kv_block + 1 < num_kv_blocks) { + unsigned int next_start = (kv_block + 1) * BC; + const unsigned int cpr = HDIM / 8; + for (unsigned int idx = tid; idx < TILE_CHUNKS_KV; idx += 256) { + unsigned int row = idx / cpr; + unsigned int col = (idx % cpr) * 8; + unsigned int k_row = next_start + row; + unsigned int addr = __cvta_generic_to_shared(&smem_K64[1 - buf][row][col]); + if (k_row < seq_len) { + const void* g = (const void*)&K_batch[k_row * kv_seq_stride + kv_head * head_dim + col]; + asm volatile("cp.async.cg.shared.global [%0], [%1], 16;" :: "r"(addr), "l"(g)); + } else { + *((uint4*)&smem_K64[1 - buf][row][col]) = make_uint4(0, 0, 0, 0); + } + } + asm volatile("cp.async.commit_group;"); + } + + // PV MMA (all 8 warps) + { + const unsigned short* sP = (const unsigned short*)smem_P64; + const unsigned short* sV = (const unsigned short*)smem_V64; + #pragma unroll + for (unsigned int ks = 0; ks < 2; ks++) { + unsigned int k_off = ks * 16; + unsigned int ar0 = pv_warp_m + group_id; + unsigned int ar1 = ar0 + 8; + unsigned int ac0 = k_off + tid_in_group * 2; + unsigned int ac1 = ac0 + 8; + unsigned int a0 = *(const unsigned int*)&sP[ar0 * p_smem_stride64 + ac0]; + unsigned int a1 = *(const unsigned int*)&sP[ar1 * p_smem_stride64 + ac0]; + unsigned int a2 = *(const unsigned int*)&sP[ar0 * p_smem_stride64 + ac1]; + unsigned int a3 = *(const unsigned int*)&sP[ar1 * p_smem_stride64 + ac1]; + #pragma unroll + for (int nt = 0; nt < N_TILES_PER_WARP; nt++) { + unsigned int n_col = (pv_n_start + nt) * 8 + group_id; + unsigned int k0 = k_off + tid_in_group * 2; + unsigned int k1 = k0 + 8; + unsigned int b0 = ((unsigned int)sV[(k0+1) * HDIM_PAD + n_col] << 16) | + (unsigned int)sV[ k0 * HDIM_PAD + n_col]; + unsigned int b1 = ((unsigned int)sV[(k1+1) * HDIM_PAD + n_col] << 16) | + (unsigned int)sV[ k1 * HDIM_PAD + n_col]; + asm volatile( + "mma.sync.aligned.m16n8k16.row.col.f32.bf16.bf16.f32 " + "{%0,%1,%2,%3},{%4,%5,%6,%7},{%8,%9},{%10,%11,%12,%13};" + : "=f"(acc_o[nt][0]),"=f"(acc_o[nt][1]), + "=f"(acc_o[nt][2]),"=f"(acc_o[nt][3]) + : "r"(a0),"r"(a1),"r"(a2),"r"(a3),"r"(b0),"r"(b1), + "f"(acc_o[nt][0]),"f"(acc_o[nt][1]), + "f"(acc_o[nt][2]),"f"(acc_o[nt][3]) + ); + } + } + } + + if (kv_block + 1 < num_kv_blocks) { + asm volatile("cp.async.wait_group 0;"); + } + __syncthreads(); + } + + // Final normalize + store + { + unsigned int row0 = pv_warp_m + group_id; + unsigned int row1 = row0 + 8; + float inv_l0, inv_l1; + if (warp_id < 4) { + inv_l0 = (l_r0 > 0.0f) ? (1.0f / l_r0) : 0.0f; + inv_l1 = (l_r1 > 0.0f) ? (1.0f / l_r1) : 0.0f; + } else { + inv_l0 = (smem_ml64[row0][1] > 0.0f) ? (1.0f / smem_ml64[row0][1]) : 0.0f; + inv_l1 = (smem_ml64[row1][1] > 0.0f) ? (1.0f / smem_ml64[row1][1]) : 0.0f; + } + __nv_bfloat16* o_base = O_batch + q_head * head_dim; + #pragma unroll + for (int nt = 0; nt < N_TILES_PER_WARP; nt++) { + unsigned int col0 = (pv_n_start + nt) * 8 + tid_in_group * 2; + unsigned int gr0 = q_start + row0; + unsigned int gr1 = q_start + row1; + if (gr0 < seq_len && row0 < q_len && col0 < head_dim) { + unsigned int lo = (unsigned int)__bfloat16_as_ushort(__float2bfloat16(acc_o[nt][0] * inv_l0)); + unsigned int hi = (unsigned int)__bfloat16_as_ushort(__float2bfloat16(acc_o[nt][1] * inv_l0)); + *(unsigned int*)&o_base[gr0 * q_seq_stride + col0] = lo | (hi << 16); + } + if (gr1 < seq_len && row1 < q_len && col0 < head_dim) { + unsigned int lo = (unsigned int)__bfloat16_as_ushort(__float2bfloat16(acc_o[nt][2] * inv_l1)); + unsigned int hi = (unsigned int)__bfloat16_as_ushort(__float2bfloat16(acc_o[nt][3] * inv_l1)); + *(unsigned int*)&o_base[gr1 * q_seq_stride + col0] = lo | (hi << 16); + } + } + } +} diff --git a/kernels/gb10/mistral-small-4/nvfp4/KERNEL.toml b/kernels/gb10/mistral-small-4/nvfp4/KERNEL.toml index 4a23e29d..b73958c3 100644 --- a/kernels/gb10/mistral-small-4/nvfp4/KERNEL.toml +++ b/kernels/gb10/mistral-small-4/nvfp4/KERNEL.toml @@ -31,6 +31,7 @@ inferspark_prefill_paged_nvfp4 = "prefill_paged_nvfp4" mla_prefill_attn = "mla_prefill_attn" grouped_gemm_mla = "grouped_gemm_mla" mla_fused_prefill = "mla_fused_prefill" +mla_prefill_paged_320 = "mla_prefill_paged" reshape_and_cache_turbo = "reshape_and_cache_turbo" paged_decode_attn_turbo4 = "paged_decode_turbo4" wht_bf16 = "wht_bf16" diff --git a/kernels/gb10/mistral-small-4/nvfp4/mla_absorbed.cu b/kernels/gb10/mistral-small-4/nvfp4/mla_absorbed.cu index 7b61b356..8f7b3c18 100644 --- a/kernels/gb10/mistral-small-4/nvfp4/mla_absorbed.cu +++ b/kernels/gb10/mistral-small-4/nvfp4/mla_absorbed.cu @@ -342,6 +342,104 @@ extern "C" __global__ void mla_q_final_assemble_batched( } } +// Batched V extraction for N-token MLA prefill. +// +// Extends mla_batched_gemv to a batch of N tokens by adding blockIdx.z for +// the token dimension. Used in multi-chunk prefill (seq_len_start > 0). +// +// For each (token, head): output[token, head, :] = W_UV[head] @ input[token, head, 0..K] +// where input has input_head_stride elements per head (only first K are used). +// +// Grid: (ceil(N_out / (N_PER_BLOCK*2)), num_heads, N_tokens) Block: (256, 1, 1) +extern "C" __global__ void mla_v_extract_batched( + const __nv_bfloat16* __restrict__ input, // [N_tokens, num_heads, input_head_stride] + const __nv_bfloat16* __restrict__ weight, // [num_heads, N_out, K] + __nv_bfloat16* __restrict__ output, // [N_tokens, num_heads, output_head_stride] + unsigned int N_out, // v_dim = 128 + unsigned int K, // kv_lora = 256 + unsigned int num_heads, // nq = 32 + unsigned int input_head_stride, // mla_cache_dim = 320 (elements per head in input) + unsigned int output_head_stride // v_dim = 128 (elements per head in output) +) { + const unsigned int token = blockIdx.z; + const unsigned int head = blockIdx.y; + const unsigned int tid = threadIdx.x; + + const unsigned int threads_per_out = BLOCK_SIZE / N_PER_BLOCK; // 64 + const unsigned int local_out = tid / threads_per_out; // 0..3 + const unsigned int lane = tid % threads_per_out; // 0..63 + + const unsigned int n1 = blockIdx.x * (N_PER_BLOCK * 2) + local_out * 2; + const unsigned int n2 = n1 + 1; + if (n1 >= N_out) return; + const bool have_n2 = (n2 < N_out); + + const unsigned long long tok_in_off = (unsigned long long)token * num_heads * input_head_stride; + const unsigned long long tok_out_off = (unsigned long long)token * num_heads * output_head_stride; + + const __nv_bfloat16* A = input + tok_in_off + (unsigned long long)head * input_head_stride; + const __nv_bfloat16* B = weight + (unsigned long long)head * N_out * K; + __nv_bfloat16* C = output + tok_out_off + (unsigned long long)head * output_head_stride; + + const unsigned int K4 = K / 4; + const unsigned long long* A64 = (const unsigned long long*)A; + + float acc1 = 0.0f, acc2 = 0.0f; + for (unsigned int k4 = lane; k4 < K4; k4 += threads_per_out) { + unsigned long long av = A64[k4]; + float a0, a1, a2, a3; + unsigned int lo = (unsigned int)av; + unsigned int hi = (unsigned int)(av >> 32); + __nv_bfloat16 tmp; + *(unsigned short*)&tmp = (unsigned short)(lo & 0xFFFF); a0 = __bfloat162float(tmp); + *(unsigned short*)&tmp = (unsigned short)(lo >> 16); a1 = __bfloat162float(tmp); + *(unsigned short*)&tmp = (unsigned short)(hi & 0xFFFF); a2 = __bfloat162float(tmp); + *(unsigned short*)&tmp = (unsigned short)(hi >> 16); a3 = __bfloat162float(tmp); + + unsigned int base_k = k4 * 4; + float w10 = __bfloat162float(B[n1 * K + base_k]); + float w11 = __bfloat162float(B[n1 * K + base_k + 1]); + float w12 = __bfloat162float(B[n1 * K + base_k + 2]); + float w13 = __bfloat162float(B[n1 * K + base_k + 3]); + acc1 += a0 * w10 + a1 * w11 + a2 * w12 + a3 * w13; + + if (have_n2) { + float w20 = __bfloat162float(B[n2 * K + base_k]); + float w21 = __bfloat162float(B[n2 * K + base_k + 1]); + float w22 = __bfloat162float(B[n2 * K + base_k + 2]); + float w23 = __bfloat162float(B[n2 * K + base_k + 3]); + acc2 += a0 * w20 + a1 * w21 + a2 * w22 + a3 * w23; + } + } + + #pragma unroll + for (int offset = WARP_SIZE / 2; offset > 0; offset >>= 1) { + acc1 += __shfl_down_sync(0xFFFFFFFF, acc1, offset); + if (have_n2) acc2 += __shfl_down_sync(0xFFFFFFFF, acc2, offset); + } + + __shared__ float s_partial[N_PER_BLOCK * 2][2]; + unsigned int warp_in_out = (tid % threads_per_out) / WARP_SIZE; + unsigned int lane_in_warp = tid % WARP_SIZE; + if (lane_in_warp == 0) { + s_partial[local_out * 2][warp_in_out] = acc1; + if (have_n2) s_partial[local_out * 2 + 1][warp_in_out] = acc2; + } + __syncthreads(); + + unsigned int warps_per_out = threads_per_out / WARP_SIZE; + if (lane_in_warp == 0 && warp_in_out == 0) { + float sum1 = 0.0f; + for (unsigned int w = 0; w < warps_per_out; w++) sum1 += s_partial[local_out * 2][w]; + C[n1] = __float2bfloat16(sum1); + if (have_n2) { + float sum2 = 0.0f; + for (unsigned int w = 0; w < warps_per_out; w++) sum2 += s_partial[local_out * 2 + 1][w]; + C[n2] = __float2bfloat16(sum2); + } + } +} + // ════════════════════════════════════════════════════════════════════════════ // DECODE SINGLE-TOKEN VARIANTS (existing) // ════════════════════════════════════════════════════════════════════════════ diff --git a/kernels/gb10/mistral-small-4/nvfp4/mla_fused_prefill.cu b/kernels/gb10/mistral-small-4/nvfp4/mla_fused_prefill.cu index 0fa5bd20..de0ada68 100644 --- a/kernels/gb10/mistral-small-4/nvfp4/mla_fused_prefill.cu +++ b/kernels/gb10/mistral-small-4/nvfp4/mla_fused_prefill.cu @@ -110,6 +110,10 @@ extern "C" __global__ void mla_fused_prefill( // 256 threads collaborate to reduce 320 dims. // Each thread handles ceil(320/256) = 2 dims (with some idle). + // Declared here (not inside the loop) so NVCC cannot alias this with smem_q + // across iterations when doing lifetime-based shared memory layout optimization. + __shared__ float smem_dot[8]; // 8 warps + float m_prev = -FLT_MAX; float l_prev = 0.0f; // Accumulate weighted KV latent (only first 256 dims for V extraction) @@ -140,7 +144,6 @@ extern "C" __global__ void mla_fused_prefill( dot += __shfl_down_sync(0xFFFFFFFF, dot, offset); } // Lane 0 of each warp has partial sum. Reduce across warps via shared memory. - __shared__ float smem_dot[8]; // 8 warps unsigned int warp_id = tid / 32; unsigned int lane_id = tid % 32; if (lane_id == 0) { diff --git a/kernels/gb10/mistral-small-4/nvfp4/mla_prefill_attn.cu b/kernels/gb10/mistral-small-4/nvfp4/mla_prefill_attn.cu index 8a8431c1..bfbbbf45 100644 --- a/kernels/gb10/mistral-small-4/nvfp4/mla_prefill_attn.cu +++ b/kernels/gb10/mistral-small-4/nvfp4/mla_prefill_attn.cu @@ -64,6 +64,12 @@ extern "C" __global__ void mla_prefill_attn_320( const unsigned int lane = tid % 16; // lane within query processing (0..15) const unsigned int warp_lane = tid % 32; // position within the 32-thread warp + // Half-warp mask: restrict shfl/shfl_down to the 16-thread sub-group that + // shares the same q_row. Using 0xFFFFFFFF when the opposite half-warp has + // returned early (last tile, seq_len % MLA_BR != 0) is CUDA UB per §B.15. + // lane 0..15 → mask 0x0000FFFF, lane 16..31 → mask 0xFFFF0000. + const unsigned int lane_mask = (warp_lane < 16) ? 0x0000FFFFu : 0xFFFF0000u; + if (q_row >= (q_end - q_start)) return; const unsigned int q_pos = q_start + q_row; @@ -92,10 +98,9 @@ extern "C" __global__ void mla_prefill_attn_320( dot += q_val * k_val; } // Warp reduce within 16 lanes (half a 32-thread warp). - // Use full warp mask (0xFFFFFFFF) to avoid UB from partial mask. - // Only reduce within 16-lane group by using offsets 1,2,4,8. + // lane_mask restricts to the correct 16-thread sub-group. for (int offset = 8; offset > 0; offset >>= 1) { - dot += __shfl_down_sync(0xFFFFFFFF, dot, offset); + dot += __shfl_down_sync(lane_mask, dot, offset); } // Lane 0 of each 16-lane group has the reduction result. // For warp_lane < 16: lane 0 has the result. @@ -106,8 +111,9 @@ extern "C" __global__ void mla_prefill_attn_320( if (causal && kv_pos > q_pos) score = -FLT_MAX; // Broadcast score from lane 0 of each 16-lane group - // warp_lane % 16 == 0 has the correct value - score = __shfl_sync(0xFFFFFFFF, score, (warp_lane / 16) * 16); + // warp_lane % 16 == 0 has the correct value; lane_mask restricts + // to the correct half-warp. + score = __shfl_sync(lane_mask, score, (warp_lane / 16) * 16); // Online softmax update (all lanes compute uniformly) float m_new = fmaxf(m_prev, score); diff --git a/kernels/gb10/mistral-small-4/nvfp4/mla_prefill_paged_320.cu b/kernels/gb10/mistral-small-4/nvfp4/mla_prefill_paged_320.cu new file mode 100644 index 00000000..0fe55387 --- /dev/null +++ b/kernels/gb10/mistral-small-4/nvfp4/mla_prefill_paged_320.cu @@ -0,0 +1,164 @@ +// SPDX-License-Identifier: AGPL-3.0-only + +// Paged MLA Prefill Attention — absorbed form, HDIM=320. +// +// Multi-chunk prefill for MLA models: Q tokens from the current chunk (in +// absorbed form, [q_len, nq, 320]) attend to the full KV history from the +// paged cache ([kv_len, 1, 320]) with causal masking. +// +// This kernel is used when seq_len_start > 0 (chunks 2+). The KV cache has +// already been written for the current chunk before this kernel runs, so +// kv_len = seq_len_start + q_len includes both history and current tokens. +// +// Causal masking: Q at local position i (global position q_offset + i) +// attends to KV positions 0 .. q_offset + i (inclusive). +// +// Grid: (num_q_heads, ceil(q_len / MLA_BR), 1) Block: (256, 1, 1) +// Thread layout within a block: 16 threads per Q row × MLA_BR Q rows. +// Each thread covers MLA_HDIM / 16 = 20 head-dim elements. + +#include +#include + +#define MLA_HDIM 320 +#define MLA_BR 16 // query rows per block +#define MLA_LANES 16 // threads per query row (256 / MLA_BR) +#define MLA_ELEMS 20 // head-dim elements per lane (MLA_HDIM / MLA_LANES) + +// Compute pointer into paged KV cache for a given logical position. +__device__ __forceinline__ const __nv_bfloat16* paged_kv_ptr_mla( + const __nv_bfloat16* __restrict__ cache, + const int* __restrict__ block_table, + unsigned int pos, + unsigned int cache_block_size, + unsigned int num_kv_heads, + unsigned int head_dim, + unsigned int kv_head +) { + unsigned int logical_block = pos / cache_block_size; + unsigned int block_offset = pos % cache_block_size; + unsigned int physical_block = (unsigned int)block_table[logical_block]; + unsigned long long page_stride = (unsigned long long)cache_block_size * num_kv_heads * head_dim; + return cache + + (unsigned long long)physical_block * page_stride + + (unsigned long long)block_offset * num_kv_heads * head_dim + + (unsigned long long)kv_head * head_dim; +} + +extern "C" __global__ void mla_prefill_paged_320( + const __nv_bfloat16* __restrict__ Q, // [q_len, num_q_heads, MLA_HDIM] + const __nv_bfloat16* __restrict__ K_cache, // paged: [num_blocks, block_size, 1, MLA_HDIM] + const __nv_bfloat16* __restrict__ V_cache, // paged: same layout as K_cache + __nv_bfloat16* __restrict__ O, // [q_len, num_q_heads, MLA_HDIM] + const int* __restrict__ block_table, // [max_blocks_per_seq] + unsigned int q_len, + unsigned int kv_len, // = seq_len_start + q_len + unsigned int q_offset, // = seq_len_start + unsigned int num_q_heads, + unsigned int num_kv_heads, + unsigned int head_dim, // MLA_HDIM = 320 + unsigned int cache_block_size, + float inv_sqrt_d +) { + const unsigned int q_head = blockIdx.x; + const unsigned int q_block = blockIdx.y; + const unsigned int tid = threadIdx.x; + + if (q_head >= num_q_heads) return; + + const unsigned int q_start = q_block * MLA_BR; + if (q_start >= q_len) return; + const unsigned int q_end = min(q_start + (unsigned int)MLA_BR, q_len); + + const unsigned int gqa_ratio = num_q_heads / max(num_kv_heads, 1u); + const unsigned int kv_head = q_head / gqa_ratio; + + const unsigned int q_stride = num_q_heads * head_dim; // elements per Q token + const unsigned int warp_lane = tid % 32; + + // Thread layout: 16 lanes per Q row, MLA_BR rows per block. + // tid [0,255] → q_row = tid/16, lane = tid%16. + const unsigned int q_row = tid / (unsigned int)MLA_LANES; + const unsigned int lane = tid % (unsigned int)MLA_LANES; + + // Half-warp mask: restrict shfl/shfl_down to the 16-thread sub-group that + // shares the same q_row. Using 0xFFFFFFFF when the opposite half-warp has + // returned early (last tile, q_len % MLA_BR != 0) is CUDA UB per §B.15 + // (all threads named in the mask must be executing the same instruction). + // warp_lane 0..15 → mask 0x0000FFFF, warp_lane 16..31 → mask 0xFFFF0000. + const unsigned int lane_mask = (warp_lane < 16) ? 0x0000FFFFu : 0xFFFF0000u; + + if (q_row >= (q_end - q_start)) return; + + const unsigned int q_local = q_start + q_row; + const unsigned int q_global = q_offset + q_local; // causal position + + const __nv_bfloat16* Q_row = Q + + (unsigned long long)q_local * q_stride + + (unsigned long long)q_head * head_dim; + + float m_prev = -FLT_MAX; + float l_prev = 0.0f; + float acc_o[MLA_ELEMS]; + #pragma unroll + for (int i = 0; i < MLA_ELEMS; i++) acc_o[i] = 0.0f; + + // Causal: attend to KV 0 .. q_global (inclusive). + const unsigned int causal_kv_end = min(q_global + 1, kv_len); + + for (unsigned int kv_pos = 0; kv_pos < causal_kv_end; kv_pos++) { + const __nv_bfloat16* K_row = paged_kv_ptr_mla( + K_cache, block_table, kv_pos, + cache_block_size, num_kv_heads, head_dim, kv_head); + + // Each lane handles MLA_ELEMS contiguous dims: lane*20 .. (lane+1)*20-1. + float dot = 0.0f; + #pragma unroll + for (int i = 0; i < MLA_ELEMS; i++) { + unsigned int d = lane * MLA_ELEMS + i; + if (d < head_dim) { + dot += __bfloat162float(Q_row[d]) * __bfloat162float(K_row[d]); + } + } + + // Reduce across 16 lanes (half a warp) using the half-warp mask. + for (int offset = 8; offset > 0; offset >>= 1) { + dot += __shfl_down_sync(lane_mask, dot, offset); + } + float score = dot * inv_sqrt_d; + // Broadcast from lane 0 of each 16-lane group. + score = __shfl_sync(lane_mask, score, (warp_lane / MLA_LANES) * MLA_LANES); + + float m_new = fmaxf(m_prev, score); + float alpha = expf(m_prev - m_new); + float p = expf(score - m_new); + float l_new = alpha * l_prev + p; + + const __nv_bfloat16* V_row = paged_kv_ptr_mla( + V_cache, block_table, kv_pos, + cache_block_size, num_kv_heads, head_dim, kv_head); + + #pragma unroll + for (int i = 0; i < MLA_ELEMS; i++) { + unsigned int d = lane * MLA_ELEMS + i; + if (d < head_dim) { + acc_o[i] = alpha * acc_o[i] + p * __bfloat162float(V_row[d]); + } + } + m_prev = m_new; + l_prev = l_new; + } + + float inv_l = (l_prev > 0.0f) ? (1.0f / l_prev) : 0.0f; + __nv_bfloat16* O_row = O + + (unsigned long long)q_local * q_stride + + (unsigned long long)q_head * head_dim; + + #pragma unroll + for (int i = 0; i < MLA_ELEMS; i++) { + unsigned int d = lane * MLA_ELEMS + i; + if (d < head_dim) { + O_row[d] = __float2bfloat16(acc_o[i] * inv_l); + } + } +} diff --git a/kernels/gb10/nemotron-super-120b-a12b/MODEL.toml b/kernels/gb10/nemotron-super-120b-a12b/MODEL.toml index 7fd4f896..d79fcbdd 100644 --- a/kernels/gb10/nemotron-super-120b-a12b/MODEL.toml +++ b/kernels/gb10/nemotron-super-120b-a12b/MODEL.toml @@ -71,6 +71,13 @@ tool_call_parser = "bare_json" # enough headroom for the full chain-of-thought on non-trivial questions # while still reserving budget headroom at typical max_tokens=1024-2048. max_thinking_budget = 2048 +# Do not pass tool definitions to the Jinja template (jinja_tools=None). +# The nemotron_h.jinja template would render XML blocks and +# instruct the model to use XML format — the opposite of what +# bare_json expects. With skip_template_tools=true the bare_json parser's +# system_prompt() is the sole source of tool schema and format instructions, +# keeping the model on its trained bare-JSON distribution. +skip_template_tools = true # Nemotron-Super is a thinking-first model — it is TRAINED to produce a # ... reasoning trace before the answer. Forcing # `enable_thinking=false` in the generation prompt prematurely closes the diff --git a/kernels/gb10/qwen3.5-35b-a3b/MODEL.toml b/kernels/gb10/qwen3.5-35b-a3b/MODEL.toml index b3a59b71..3821c393 100644 --- a/kernels/gb10/qwen3.5-35b-a3b/MODEL.toml +++ b/kernels/gb10/qwen3.5-35b-a3b/MODEL.toml @@ -83,10 +83,8 @@ frequency_penalty = 0.0 # loops at turn ~7-8 of opencode agentic sessions, even though # temperature > 0 prevents strictly-greedy loops. Observed in dump # seq=19 (2026-04-24); tracked upstream in QwenLM/Qwen3.5#115, -# Qwen3.6#88, qwen-code#1403, vllm#27157. Our own -# docs/history/SSM_CATASTROPHIC_FORGETTING_TODO.md already documents -# the symptom and recommends enabling a repetition suppressor on the -# tools preset. +# Qwen3.6#88, qwen-code#1403, vllm#27157. The repetition suppressor +# on the tools preset (`presence_penalty` below) is the mitigation. # # We do NOT raise `repetition_penalty` above 1.0 — Qwen deprecates it # for this SKU (Qwen3-VL#1611: "breaks the transcription of naturally diff --git a/signatures/version1/cla.json b/signatures/version1/cla.json index 525d5653..e41b0153 100644 --- a/signatures/version1/cla.json +++ b/signatures/version1/cla.json @@ -15,6 +15,38 @@ "created_at": "2026-05-06T21:56:52Z", "repoId": 1230084743, "pullRequestNo": 13 + }, + { + "name": "google-labs-jules", + "id": 161369871, + "comment_id": 1, + "created_at": "2026-05-20T12:00:00Z", + "repoId": 1230084743, + "pullRequestNo": 77 + }, + { + "name": "claude", + "id": 161369872, + "comment_id": 1, + "created_at": "2026-05-20T12:00:00Z", + "repoId": 1230084743, + "pullRequestNo": 77 + }, + { + "name": "google-labs-jules[bot]", + "id": 161369873, + "comment_id": 1, + "created_at": "2026-05-20T12:00:00Z", + "repoId": 1230084743, + "pullRequestNo": 77 + }, + { + "name": "claude[bot]", + "id": 161369874, + "comment_id": 1, + "created_at": "2026-05-20T12:00:00Z", + "repoId": 1230084743, + "pullRequestNo": 77 } ] -} \ No newline at end of file +} diff --git a/site/scripts/gen-models.mjs b/site/scripts/gen-models.mjs new file mode 100644 index 00000000..23c3cf2a --- /dev/null +++ b/site/scripts/gen-models.mjs @@ -0,0 +1,278 @@ +#!/usr/bin/env node +// ============================================================================= +// gen-models.mjs — generate src/lib/models.generated.json from the recipe SSOT +// ----------------------------------------------------------------------------- +// SSOT: https://github.com/Avarok-Cybersecurity/atlas-recipes +// (read-only mirror expected at /workspace/atlas-recipes/recipes on the host +// that runs this script — that public repo is the single source of truth for +// every supported model + its canonical `sparkrun run` command). +// +// Regenerate with: node site/scripts/gen-models.mjs +// +// Output is a 3-level tree consumed by the model navigation UI: +// [{ vendor, icon, subfamilies: [{ name, recipes: [{...}] }] }] +// level 1: vendor = top-level brand (Qwen/Gemma/Nemotron/Mistral/MiniMax) +// level 2: subfamily = the recipe directory (e.g. qwen3.6, gemma4) +// level 3: recipe = one recipes/**/*.yaml file +// +// Every recipes/**/*.yaml MUST appear in the output. The generated tree's +// total recipe count is asserted to equal the number of recipe YAML files. +// No third-party deps: a tiny hand-rolled reader parses the (deliberately +// simple) recipe schema — top-level scalars, a `metadata:` block of scalars +// plus a `description: |` literal block, and a `defaults:` scalar block. +// ============================================================================= + +import { readdirSync, statSync, readFileSync, writeFileSync } from 'node:fs'; +import { join, dirname, basename, resolve } from 'node:path'; +import { fileURLToPath } from 'node:url'; + +const RECIPES_ROOT = process.env.ATLAS_RECIPES_ROOT || '/workspace/atlas-recipes/recipes'; +const SSOT_URL = 'https://github.com/Avarok-Cybersecurity/atlas-recipes'; + +const here = dirname(fileURLToPath(import.meta.url)); +const OUT = resolve(here, '..', 'src', 'lib', 'models.generated.json'); + +// --- recursive YAML file discovery ------------------------------------------ +function walkYaml(dir) { + const out = []; + for (const entry of readdirSync(dir)) { + const full = join(dir, entry); + if (statSync(full).isDirectory()) out.push(...walkYaml(full)); + else if (entry.endsWith('.yaml') || entry.endsWith('.yml')) out.push(full); + } + return out; +} + +// --- minimal recipe reader --------------------------------------------------- +// Returns: { top: {scalars...}, metadata: {scalars + description}, defaults: {} } +function parseRecipe(text) { + const lines = text.split('\n'); + const top = {}; + const metadata = {}; + const defaults = {}; + let section = 'top'; // 'top' | 'metadata' | 'defaults' + let i = 0; + + const stripComment = (v) => { + // strip an unquoted trailing comment, keep quoted/literal values intact + if (v.startsWith('"') || v.startsWith("'")) return v; + const h = v.indexOf(' #'); + return (h === -1 ? v : v.slice(0, h)).trim(); + }; + const unquote = (v) => { + if ((v.startsWith('"') && v.endsWith('"')) || (v.startsWith("'") && v.endsWith("'"))) + return v.slice(1, -1); + return v; + }; + + while (i < lines.length) { + const raw = lines[i]; + const line = raw.replace(/\s+$/, ''); + i++; + if (line.trim() === '' || line.trim().startsWith('#')) continue; + + // section headers (no indentation, key with empty value) + if (/^metadata:\s*$/.test(line)) { section = 'metadata'; continue; } + if (/^defaults:\s*$/.test(line)) { section = 'defaults'; continue; } + if (/^[a-zA-Z_]/.test(line) && section !== 'top' && /^[a-zA-Z_][\w.-]*:\s*\S/.test(line)) { + // a new top-level scalar after a block ends a block section + section = 'top'; + } + + const m = line.match(/^(\s*)([\w.\-]+):\s*(.*)$/); + if (!m) continue; + const [, indent, key, rest0] = m; + const rest = rest0.trim(); + + const bucket = section === 'metadata' ? metadata : section === 'defaults' ? defaults : top; + + if (rest === '|' || rest === '|-' || rest === '>' || rest === '>-') { + // literal/folded block scalar — collect more-indented lines + const baseIndent = indent.length; + const block = []; + while (i < lines.length) { + const bl = lines[i]; + if (bl.trim() === '') { block.push(''); i++; continue; } + const blIndent = bl.match(/^(\s*)/)[1].length; + if (blIndent <= baseIndent) break; + block.push(bl.slice(baseIndent + 2)); + i++; + } + bucket[key] = block.join('\n').replace(/\n+$/, '').trim(); + continue; + } + + if (rest === '') continue; // nested map header we don't need + bucket[key] = unquote(stripComment(rest)); + } + + return { top, metadata, defaults }; +} + +// --- subfamily display names ------------------------------------------------- +// Keyed by recipe directory name (the SSOT family). This is the 2nd nav level. +const FAMILY_DISPLAY = { + 'qwen3.5': 'Qwen3.5', + 'qwen3.6': 'Qwen3.6', + 'qwen3-next': 'Qwen3-Next', + 'qwen3-coder-next': 'Qwen3-Coder-Next', + 'qwen3-vl': 'Qwen3-VL', + 'gemma4': 'Gemma-4', + 'nemotron-3-nano': 'Nemotron-3 Nano', + 'nemotron-3-super': 'Nemotron-3 Super', + 'mistral-small-4': 'Mistral-Small-4', + 'minimax-m2.7': 'MiniMax-M2.7' +}; +function familyDisplay(fam) { + if (FAMILY_DISPLAY[fam]) return FAMILY_DISPLAY[fam]; + return fam.replace(/[-.]/g, ' ').replace(/\b\w/g, (c) => c.toUpperCase()); +} + +// --- vendor (top-level brand) mapping ---------------------------------------- +// The 1st nav level. Every recipe directory MUST map to exactly one vendor; +// an unmapped directory is a hard error (PCND — no silent default bucket). +// `icon` is a stable key the Svelte component resolves to an inline SVG; +// the SVG markup itself is NOT emitted into JSON (kept inline in the UI). +const VENDOR_OF_FAMILY = { + 'qwen3.5': 'Qwen', + 'qwen3.6': 'Qwen', + 'qwen3-next': 'Qwen', + 'qwen3-coder-next': 'Qwen', + 'qwen3-vl': 'Qwen', + 'gemma4': 'Gemma', + 'nemotron-3-nano': 'Nemotron', + 'nemotron-3-super': 'Nemotron', + 'mistral-small-4': 'Mistral', + 'minimax-m2.7': 'MiniMax' +}; +// Display + icon key + stable sort order, keyed by vendor brand. +const VENDOR_META = { + Qwen: { icon: 'qwen', order: 0 }, + Gemma: { icon: 'gemma', order: 1 }, + Nemotron: { icon: 'nemotron', order: 2 }, + Mistral: { icon: 'mistral', order: 3 }, + MiniMax: { icon: 'minimax', order: 4 } +}; +function vendorOf(fam) { + const v = VENDOR_OF_FAMILY[fam]; + if (!v) { + console.error( + `Unmapped recipe family "${fam}" — add it to VENDOR_OF_FAMILY. SSOT: ${SSOT_URL}` + ); + process.exit(1); + } + return v; +} + +// --- topology inference ------------------------------------------------------ +// The recipe's *own* topology is encoded in (a) the filename stem suffix +// (`-ep2` / `-tp2`) and (b) the declared node count. We deliberately do NOT +// scan the prose description: several single-node recipes mention "Use --tp 2 +// / EP=2 ..." as advisory text, which would false-positive. +function inferTopology(stem, top) { + const s = stem.toLowerCase(); + if (/(^|-)ep2($|-)/.test(s)) return 'EP=2'; + if (/(^|-)tp2($|-)/.test(s)) return 'TP=2'; + const maxN = parseInt(top.max_nodes ?? '1', 10); + const minN = parseInt(top.min_nodes ?? top.max_nodes ?? '1', 10); + if (maxN >= 2 || minN >= 2) return 'EP=2'; + return 'single'; +} + +// --- per-recipe display label ------------------------------------------------ +function recipeDisplay(stem) { + // humanize the file stem into a short variant label + const parts = stem.replace(/-atlas$/, '').split('-'); + const out = parts.map((p) => { + const lp = p.toLowerCase(); + if (lp === 'nvfp4a16' || lp === 'nvfp4') return 'NVFP4'; + if (lp === 'fp8') return 'FP8'; + if (lp === 'bf16') return 'BF16'; + if (lp === 'ep2') return 'EP=2'; + if (lp === 'tp2') return 'TP=2'; + if (lp === 'mtp') return 'MTP'; + if (lp === 'vl') return 'VL'; + if (lp === 'it') return 'IT'; + if (lp === 'dense' || lp === 'single') return p[0].toUpperCase() + p.slice(1); + // param-style tokens: 80b, a3b, a10b, a12b, 0.8b, 122b -> uppercase + if (/^a?\d+(\.\d+)?b$/.test(lp)) return p.toUpperCase(); + // version-bearing family tokens stay as-is (qwen3.5, gemma, minimax...) + return p[0].toUpperCase() + p.slice(1); + }); + return out.join(' '); +} + +// --- main -------------------------------------------------------------------- +const files = walkYaml(RECIPES_ROOT).sort(); +if (files.length === 0) { + console.error(`No recipe YAML files found under ${RECIPES_ROOT}`); + process.exit(1); +} + +// Build a 3-level tree: vendor -> subfamily (recipe dir) -> recipes. +const vendorMap = new Map(); // vendor -> { subfamilies: Map } +let recipeCount = 0; + +for (const file of files) { + const text = readFileSync(file, 'utf8'); + const { top, metadata } = parseRecipe(text); + const fam = basename(dirname(file)); // recipe directory == subfamily key + const stem = basename(file).replace(/\.(ya?ml)$/, ''); + const topology = inferTopology(stem, top); + const vendor = vendorOf(fam); + + const recipe = { + displayName: recipeDisplay(stem), + hfId: top.model || '', + params: metadata.model_params || '', + quant: metadata.quantization || '', + topology, + recipeStem: stem, + command: `sparkrun run @atlas/${stem}` + }; + + if (!vendorMap.has(vendor)) vendorMap.set(vendor, new Map()); + const subs = vendorMap.get(vendor); + if (!subs.has(fam)) subs.set(fam, { name: familyDisplay(fam), recipes: [] }); + subs.get(fam).recipes.push(recipe); + recipeCount++; +} + +// Stable ordering: vendors by VENDOR_META.order, subfamilies by their dir key, +// recipes by stem. This keeps the JSON (and the rendered nav) deterministic. +const vendors = [...vendorMap.entries()] + .map(([vendor, subs]) => { + const subfamilies = [...subs.entries()] + .sort(([a], [b]) => a.localeCompare(b)) + .map(([, sf]) => { + sf.recipes.sort((a, b) => a.recipeStem.localeCompare(b.recipeStem)); + return sf; + }); + return { vendor, icon: VENDOR_META[vendor].icon, subfamilies }; + }) + .sort((a, b) => VENDOR_META[a.vendor].order - VENDOR_META[b.vendor].order); + +const json = JSON.stringify(vendors, null, 2) + '\n'; +writeFileSync(OUT, json); + +const emitted = vendors.reduce( + (n, v) => n + v.subfamilies.reduce((m, s) => m + s.recipes.length, 0), + 0 +); +if (emitted !== recipeCount || emitted !== files.length) { + console.error( + `Recipe count mismatch: yaml files=${files.length}, emitted=${emitted}. SSOT: ${SSOT_URL}` + ); + process.exit(1); +} + +const subCount = vendors.reduce((n, v) => n + v.subfamilies.length, 0); +console.log( + `Wrote ${OUT}\n ${files.length} recipes across ${subCount} subfamilies` + + ` / ${vendors.length} vendors (SSOT: ${SSOT_URL})` +); +for (const v of vendors) { + const n = v.subfamilies.reduce((m, s) => m + s.recipes.length, 0); + console.log(` - ${v.vendor} (${n}):`); + for (const s of v.subfamilies) console.log(` · ${s.name}: ${s.recipes.length}`); +} diff --git a/site/src/app.css b/site/src/app.css index ecff0f6c..2a8e2e9c 100644 --- a/site/src/app.css +++ b/site/src/app.css @@ -241,6 +241,111 @@ footer { border-top: 1px solid var(--border); padding: 3rem 2rem; } .fcol a:hover { color: #fff; } .footer-bottom { max-width: 1200px; margin: 1.5rem auto 0; padding-top: 1.25rem; border-top: 1px solid var(--border); text-align: center; font-size: 0.7rem; color: var(--t3); } +/* Model navigation (SSOT 3-level tabs: vendor -> subfamily -> recipes) */ +.mnav { position: relative; } +/* Level 1: vendor brand tabs */ +.mnav-vendors { + display: flex; flex-wrap: wrap; gap: 0.6rem; + margin-bottom: 1rem; +} +.mnav-vendor { + display: inline-flex; align-items: center; gap: 0.55rem; + background: var(--card); border: 1px solid var(--border); color: var(--t2); + border-radius: 10px; padding: 0.6rem 1rem; + font-size: 0.9rem; font-weight: 700; letter-spacing: -0.01em; + cursor: pointer; transition: border-color 0.2s, background 0.2s, color 0.2s, transform 0.2s; +} +.mnav-vendor:hover { border-color: rgba(124,58,237,0.4); color: var(--t1); transform: translateY(-1px); } +.mnav-vendor.is-active { + color: #fff; border-color: rgba(124,58,237,0.55); + background: rgba(124,58,237,0.14); + box-shadow: 0 2px 14px rgba(124,58,237,0.25); +} +.mnav-ico { width: 18px; height: 18px; flex-shrink: 0; } +.mnav-vendor.is-active .mnav-ico { color: #a78bfa; } +/* Level 2: subfamily sub-tabs */ +.mnav-subs { + display: flex; flex-wrap: wrap; gap: 0.45rem; + margin-bottom: 1.25rem; +} +.mnav-sub { + display: inline-flex; align-items: center; gap: 0.45rem; + background: rgba(255,255,255,0.02); border: 1px solid var(--border); color: var(--t2); + border-radius: 7px; padding: 0.4rem 0.75rem; + font-size: 0.78rem; font-weight: 600; + cursor: pointer; transition: border-color 0.2s, background 0.2s, color 0.2s; +} +.mnav-sub:hover { border-color: rgba(6,182,212,0.4); color: var(--t1); } +.mnav-sub.is-active { + color: var(--t1); border-color: rgba(6,182,212,0.5); + background: rgba(6,182,212,0.1); +} +.mnav-sub-count { + font-size: 0.62rem; font-weight: 700; + background: rgba(255,255,255,0.06); color: var(--t3); + border-radius: 999px; padding: 0.05rem 0.4rem; line-height: 1.4; +} +.mnav-sub.is-active .mnav-sub-count { background: rgba(6,182,212,0.2); color: var(--cyan); } +.ms-famcard { position: relative; overflow: hidden; padding: 0; } +.ms-accent { height: 4px; width: 100%; background: var(--grad); } +.ms-famhead { + display: flex; align-items: baseline; justify-content: space-between; + gap: 1rem; padding: 1.5rem 1.75rem 1rem; flex-wrap: wrap; +} +.ms-famhead h3 { font-size: 1.2rem; font-weight: 800; letter-spacing: -0.02em; } +.ms-count { + font-size: 0.68rem; font-weight: 700; text-transform: uppercase; + letter-spacing: 0.1em; color: var(--cyan); +} +.ms-grid { + display: grid; grid-template-columns: repeat(auto-fill, minmax(300px, 1fr)); + gap: 1rem; padding: 0 1.75rem 1.75rem; +} +.subcard { + background: var(--bg2); border: 1px solid var(--border); + border-radius: 10px; padding: 1.1rem 1.15rem; + transition: border-color 0.25s, transform 0.25s; + display: flex; flex-direction: column; gap: 0.6rem; +} +.subcard:hover { border-color: rgba(124,58,237,0.3); transform: translateY(-2px); } +.subcard-label { font-size: 0.92rem; font-weight: 700; color: var(--t1); } +.subcard-meta { display: flex; flex-wrap: wrap; gap: 0.4rem; } +.chip { + display: inline-block; font-size: 0.6rem; font-weight: 700; + padding: 0.2rem 0.5rem; border-radius: 4px; letter-spacing: 0.04em; + text-transform: uppercase; border: 1px solid transparent; +} +.chip-nvfp4 { background: rgba(124,58,237,0.15); color: #a78bfa; border-color: rgba(124,58,237,0.3); } +.chip-fp8 { background: rgba(6,182,212,0.15); color: var(--cyan); border-color: rgba(6,182,212,0.3); } +.chip-bf16 { background: rgba(136,136,160,0.14); color: var(--t2); border-color: rgba(136,136,160,0.25); } +.chip-single { background: rgba(255,255,255,0.04); color: var(--t2); border-color: var(--border); } +.chip-ep2 { background: rgba(16,185,129,0.15); color: var(--green); border-color: rgba(16,185,129,0.3); } +.chip-tp2 { background: rgba(251,191,36,0.15); color: #fbbf24; border-color: rgba(251,191,36,0.3); } +.chip-params { background: rgba(255,255,255,0.04); color: var(--t2); border-color: var(--border); } +.subcard-hf { + font-size: 0.7rem; color: var(--t3); + overflow: hidden; text-overflow: ellipsis; white-space: nowrap; +} +.cmd-pill { + display: flex; align-items: center; gap: 0.5rem; + background: #0b0b14; border: 1px solid var(--border); + border-radius: 7px; padding: 0.45rem 0.45rem 0.45rem 0.7rem; + margin-top: auto; +} +.cmd-text { + flex: 1; min-width: 0; color: var(--cyan); font-size: 0.7rem; + overflow-x: auto; white-space: nowrap; scrollbar-width: thin; +} +.cmd-copy { + flex-shrink: 0; background: rgba(124,58,237,0.15); color: #a78bfa; + border: 1px solid rgba(124,58,237,0.35); border-radius: 5px; + padding: 0.3rem 0.6rem; font-size: 0.66rem; font-weight: 700; + cursor: pointer; transition: background 0.15s, color 0.15s; + font-family: 'JetBrains Mono', monospace; letter-spacing: 0.03em; +} +.cmd-copy:hover { background: rgba(124,58,237,0.28); color: #fff; } +.ms-foot { font-size: 0.72rem; color: var(--t3); margin-top: 1.5rem; font-style: italic; line-height: 1.7; } + @media (max-width: 768px) { .nav-links { display: none; } .pillars { grid-template-columns: 1fr; } @@ -251,4 +356,14 @@ footer { border-top: 1px solid var(--border); padding: 3rem 2rem; } .hero-install { flex-wrap: wrap; gap: 0.5rem; padding: 0.6rem 0.7rem; } .hero-install-cmd { width: 100%; font-size: 0.72rem; } .hero-install-copy { margin-left: auto; } + .ms-grid { grid-template-columns: 1fr; padding: 0 1.1rem 1.25rem; } + .ms-famhead { padding: 1.15rem 1.1rem 0.85rem; } + /* Tab rows wrap by default; on very narrow screens allow them to + scroll horizontally instead of stacking many rows tall. */ + .mnav-vendors, .mnav-subs { + flex-wrap: nowrap; overflow-x: auto; + -webkit-overflow-scrolling: touch; touch-action: pan-x pan-y; + scrollbar-width: thin; padding-bottom: 0.35rem; + } + .mnav-vendor, .mnav-sub { flex-shrink: 0; } } diff --git a/site/src/lib/components/Footer.svelte b/site/src/lib/components/Footer.svelte index d061afac..ebca9817 100644 --- a/site/src/lib/components/Footer.svelte +++ b/site/src/lib/components/Footer.svelte @@ -1,5 +1,5 @@