Skip to content
Open
Show file tree
Hide file tree
Changes from 20 commits
Commits
Show all changes
50 commits
Select commit Hold shift + click to select a range
7d317a3
fix(streaming): signal finish_reason=length when a tool-loop guard ca…
tbraun96 May 23, 2026
d1ac545
fix(streaming): cancel scheduler when a loop guard fires, not just su…
tbraun96 May 23, 2026
7da054d
fix(streaming): detect & cancel in-think `<tool_call>` leak (Qwen3.6 …
tbraun96 May 23, 2026
1bb82ed
pre-refactor: Phase 2c precision + watchdog baseline
tbraun96 May 24, 2026
6097910
phase-A: remove all 13 always-on prompt injections
tbraun96 May 24, 2026
060ffe1
phase-B: vLLM-style stop-string holdback + per-request RepetitionDete…
tbraun96 May 24, 2026
bfa4666
hotfix(watchdog,budget): tune from phaseAB live opencode test
tbraun96 May 24, 2026
755db59
hotfix-2: cap orphan-tool-call suppression streak (kill at 256 tokens)
tbraun96 May 24, 2026
f46d9f4
phase-C-2 (part 1/2): extract pre-sample logits pipeline scaffold
tbraun96 May 24, 2026
622e8df
hotfix-2b: orphan-suppression check must run on every token, not only…
tbraun96 May 24, 2026
795d8d1
hotfix-3: content-loop watchdog must run on MTP path too
tbraun96 May 24, 2026
c6650d3
phase-C-2 (part 2/2): wire LogitsProcessor pipeline into all decode p…
tbraun96 May 24, 2026
d3adea7
realfix2: speculative xgrammar advance between verify positions (K>=2)
tbraun96 May 24, 2026
29a0a7c
phase-2c day-1: KV cache sweep — 18 configs across dgx1+dgx2, KV dime…
tbraun96 May 24, 2026
3580978
phase-2c day-2: kernel bisect infrastructure + 3 NEGATIVE bisects
tbraun96 May 24, 2026
ea44fe6
phase-2c day-3: NVFP4 weight checkpoint test — BREAKTHROUGH
tbraun96 May 24, 2026
72644aa
hotfix: stuck-in-tool-body watchdog (NVFP4 doom-loop fix)
tbraun96 May 24, 2026
303cbab
phase-2c day-3 audit: causal-pathway map of FP8 → NVFP4 dispatch points
tbraun96 May 24, 2026
0467d1e
phase-2c day-3 Bug #1 attempt: REVERT — kernel infrastructure not ready
tbraun96 May 24, 2026
8d2cc87
fp8-merge73: native FP8 SSM + byte-exact streaming + PR 73 qwen3_xml …
tbraun96 May 25, 2026
e99159d
grammar: qwen3_coder body uses any_text (matches XML wire format)
tbraun96 May 25, 2026
2a2500c
validator: reject empty 'command'/'cmd'/'script' for shell tools
tbraun96 May 25, 2026
eaaa269
tool_handlers: soft-pass empty-required-string validation errors
tbraun96 May 25, 2026
49bad35
kernel/moe_fp8: two-level FP32 accumulation (DeepGEMM pattern)
tbraun96 May 25, 2026
4fa47b6
grammar+sampler: Tier-0 EBNF + Tier-1 byte-counter mask for tool params
tbraun96 May 26, 2026
6f9d595
mission-12h: Tier-2 strict path/cmd validators + final mission report
tbraun96 May 26, 2026
8c296ea
fp8-drift: o_proj W8A8 N/K fix + GPU dequant kernel + BF16 MoE path
tbraun96 May 28, 2026
d03197c
opencode-fix: relative-path validator + WS-mask newline exclusion + d…
tbraun96 May 28, 2026
6608824
diag: complete per-step logit dump (ATLAS_LOGIT_DUMP)
tbraun96 May 28, 2026
25f8bbe
fp8-prefix-cache: fix exact-hit SSM double-advance (Marconi snapshot)
tbraun96 May 29, 2026
367846f
mtp+moe: fix MTP 0%-accept (fp32-residual dtype bug) + BF16 router (t…
tbraun96 May 29, 2026
d7a4da8
tool-parser: server-side write-path drift recovery (ATLAS_WRITE_PATH_…
tbraun96 May 29, 2026
b0779b9
residual: remove FP32-residual feature — BF16 residual stream always
tbraun96 May 29, 2026
9922d4c
fix(tool-recovery): recover FP8-drifted file-write tool calls (3 modes)
tbraun96 May 29, 2026
d0b95f1
fix(moe): route MTP K=2/K=3 verify through BF16 path when experts are…
tbraun96 May 30, 2026
d2eb167
test(toml_repair): add r105/r110/r4 TOML-shape probe tests
tbraun96 May 31, 2026
dc6ea50
debug(kernels): ATLAS_DEBUG_SYNC_KERNELS + ATLAS_DEBUG_NO_GRAPH diagn…
tbraun96 May 31, 2026
2db83dc
fix(attn): multi-seq O-proj BF16 branch for ATLAS_FP8_DEQUANT_ATTN_TO…
tbraun96 May 31, 2026
bb2b53f
chore(docker): fast-layer build helper Dockerfile.fencesalvage
tbraun96 May 31, 2026
f7525bd
feat(quant): ATLAS_FP8_DEQUANT_LAYERS — selective per-layer BF16 dequant
tbraun96 May 31, 2026
a970624
feat(agentic): BW1 bash-wandering / content-completeness watchdog
tbraun96 May 31, 2026
fd688ab
bench: opencode harness evidence trail (FP8 drift / BF16 lever / BW1 …
tbraun96 Jun 1, 2026
68c3c50
fix(tool-salvage): guard EOF-fence slice panic; repair C-style // com…
tbraun96 Jun 1, 2026
4521dc7
fix(loop-detect): tool calls are progress — stop spinning-detector ki…
tbraun96 Jun 1, 2026
c487bc4
fix(prefix-cache): recompute SSM over [snap_tok,total) when snapshot …
tbraun96 Jun 1, 2026
7e8e2d6
prefix-cache: ATLAS_NO_MARCONI_EXACT diagnostic gate + partial-hit re…
tbraun96 Jun 1, 2026
3d43e2f
webserver_ok F1-F5: bound runaway via post-think content cap + watchd…
tbraun96 Jun 2, 2026
bc9f694
fix(qwen3.6-fp8): rep_penalty 1.1->1.0 on sampler presets + tool-JSON…
tbraun96 Jun 3, 2026
52244ab
fix(qwen3.6-fp8): 10/10 webserver_ok MTP-on — delete tool-call band-a…
tbraun96 Jun 3, 2026
0ff94b5
perf(qwen3.6): phase-2 decode profiling — host-path stage timing + sp…
tbraun96 Jun 4, 2026
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
161 changes: 161 additions & 0 deletions bench/nemotron_hf_ref.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,161 @@
#!/usr/bin/env python3
"""HF-transformers GPU oracle for the Nemotron-3-Nano per-layer divergence hunt.

Loads nvidia/NVIDIA-Nemotron-3-Nano-30B-A3B-NVFP4, MANUALLY dequantizes every
NVFP4 weight to BF16 (transformers 5.8 has no NVFP4 backend in this image),
loads the dequantized state_dict into a fresh NemotronH model, feeds the EXACT
chat-rendered token IDs, and captures per-block hidden states + final norm +
logits in headerless little-endian f32 .bin -- the format the Atlas
ATLAS_NEMO_DUMP hook writes -- so the comparator can diff 1:1.

The dequantized BF16 graph run through `torch_forward` is the canonical
"intended math" oracle (no fused Triton kernels, no custom CUDA).

Env: MODEL (local snapshot path), OUT, PROMPT.
"""
import glob
import json
import os
import pathlib
import sys

sys.path.insert(0, os.path.dirname(os.path.abspath(__file__)))
import mamba_ssm_stub # noqa: F401 installs pure-torch mamba_ssm + forces torch path
Comment thread
github-code-quality[bot] marked this conversation as resolved.
Fixed

import numpy as np
import torch
from safetensors import safe_open
from transformers import AutoConfig, AutoModelForCausalLM, AutoTokenizer

MODEL = os.environ["MODEL"]
OUT = pathlib.Path(os.environ.get("OUT", "/out"))
PROMPT = os.environ.get("PROMPT", "Please count from 1 to 30. Output every number.")
OUT.mkdir(parents=True, exist_ok=True)

# E2M1 FP4 code -> float value (sign-magnitude, 16 codes).
_E2M1 = torch.tensor(
[0.0, 0.5, 1.0, 1.5, 2.0, 3.0, 4.0, 6.0,
-0.0, -0.5, -1.0, -1.5, -2.0, -3.0, -4.0, -6.0],
dtype=torch.float32,
)


def dequant_nvfp4(packed, wscale, wscale2, group_size=16):
"""packed uint8 [O, K/2] -> bf16 [O, K]. wscale fp8e4m3 [O, K/16].
value = E2M1[nibble] * fp8(wscale) * f32(wscale2)."""
O, Khalf = packed.shape
K = Khalf * 2
lo = (packed & 0x0F).to(torch.long)
hi = ((packed >> 4) & 0x0F).to(torch.long)
codes = torch.empty(O, K, dtype=torch.long)
codes[:, 0::2] = lo
codes[:, 1::2] = hi
vals = _E2M1.to(codes.device)[codes] # [O, K] f32
s = wscale.to(torch.float32) # [O, K/16]
s = s.repeat_interleave(group_size, dim=1) # [O, K]
vals = vals * s * float(wscale2)
return vals.to(torch.bfloat16)


def load_dequant_state_dict():
"""Read every safetensors shard; dequant NVFP4 triples to a single
bf16 `weight`; pass dense tensors through; drop *_scale* sidecars."""
files = sorted(glob.glob(os.path.join(MODEL, "model-*.safetensors")))
raw = {}
for f in files:
with safe_open(f, "pt") as sf:
for k in sf.keys():
raw[k] = sf.get_tensor(k)
sd = {}
quant_bases = set()
for k in raw:
if k.endswith(".weight_scale"):
quant_bases.add(k[: -len(".weight_scale")])
for base in quant_bases:
w = raw[base + ".weight"]
ws = raw[base + ".weight_scale"]
ws2 = raw[base + ".weight_scale_2"]
sd[base + ".weight"] = dequant_nvfp4(w, ws, ws2)
skip_suffix = (".weight_scale", ".weight_scale_2", ".input_scale")
for k, v in raw.items():
if any(k.endswith(s) for s in skip_suffix):
continue
if k.endswith(".weight") and k[: -len(".weight")] in quant_bases:
continue # already dequantized above
sd[k] = v
return sd


def save_f32(name, t):
arr = t.detach().to(torch.float32).cpu().numpy().astype("<f4").ravel()
(OUT / name).write_bytes(arr.tobytes())
return arr


def main():
tok = AutoTokenizer.from_pretrained(MODEL, trust_remote_code=True)
msgs = [{"role": "user", "content": PROMPT}]
ids = tok.apply_chat_template(msgs, add_generation_prompt=True, tokenize=True)
if hasattr(ids, "input_ids"):
ids = ids["input_ids"]
if ids and isinstance(ids[0], list):
ids = ids[0]
ids = [int(x) for x in ids]
print("PROMPT_TOKEN_COUNT:", len(ids))
print("PROMPT_TOKEN_IDS:", ids)
(OUT / "token_ids.json").write_text(json.dumps(ids))
rendered = tok.apply_chat_template(msgs, add_generation_prompt=True, tokenize=False)
(OUT / "rendered_prompt.txt").write_text(rendered)
print("RENDERED_PROMPT_REPR:", repr(rendered))

print("Dequantizing NVFP4 -> BF16 state_dict ...", flush=True)
sd = load_dequant_state_dict()
print(f" state_dict: {len(sd)} tensors")

cfg = AutoConfig.from_pretrained(MODEL, trust_remote_code=True)
print("Building empty BF16 model ...", flush=True)
with torch.device("meta"):
model = AutoModelForCausalLM.from_config(cfg, trust_remote_code=True)
model = model.to_empty(device="cpu")
# `backbone.` prefix in checkpoint matches NemotronH module tree.
missing, unexpected = model.load_state_dict(sd, strict=False, assign=True)
miss = [m for m in missing if "rotary" not in m and "inv_freq" not in m]
print(f" missing={len(miss)} unexpected={len(unexpected)}")
if miss[:8]:
print(" missing sample:", miss[:8])
if unexpected[:8]:
print(" unexpected sample:", unexpected[:8])
model = model.to(device="cuda", dtype=torch.bfloat16)
model.eval()

input_ids = torch.tensor([ids], device="cuda")
with torch.no_grad():
out = model(input_ids=input_ids, output_hidden_states=True, use_cache=False)

hs = out.hidden_states
print("NUM_HIDDEN_STATES:", len(hs), "(embed + 52 layers)")
last = -1
emb = save_f32("hf_embed.bin", hs[0][0, last])
print(f"hf_embed: norm={np.linalg.norm(emb):.4f}")
for i in range(1, len(hs)):
arr = save_f32(f"hf_L{i-1}.bin", hs[i][0, last])
if i - 1 < 4 or i - 1 >= len(hs) - 5:
print(f"hf_L{i-1}: norm={np.linalg.norm(arr):.4f}")

final_hidden = hs[-1][0, last].to(torch.bfloat16)
norm_f = model.backbone.norm_f
fn = norm_f(final_hidden.unsqueeze(0)).squeeze(0)
fn_arr = save_f32("hf_final_norm.bin", fn)
print(f"hf_final_norm: norm={np.linalg.norm(fn_arr):.4f}")

logits = out.logits[0, last]
save_f32("hf_logits.bin", logits)
top = torch.topk(logits.float(), 10)
top_list = [(int(i), float(v)) for i, v in zip(top.indices, top.values)]
print("HF_TOP10_LOGITS:", top_list)
(OUT / "top10.json").write_text(json.dumps(top_list))
print("DONE ->", OUT)


if __name__ == "__main__":
main()
158 changes: 158 additions & 0 deletions bench/phase2c-kv-sweep/CAUSAL-PATHWAY-AUDIT.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,158 @@
# Causal-Pathway Audit: FP8 → NVFP4 Double-Quantization in Atlas

**Subject:** `Qwen/Qwen3.6-35B-A3B-FP8` on Atlas (branch `fix/in-think-tool-call-leak`)
**Date:** 2026-05-24
**Mode:** read-only forensic audit

---

## 1. TL;DR

Atlas DOES dispatch canonical FP8 kernels for routed MoE experts and full-attention
QKVO of `Qwen3.6-35B-A3B-FP8` via `set_fp8_experts` / `set_fp8_weights`. The
"nvfp4" tag in `Selected kernel target: (sm_121, qwen3.6-35b-a3b, nvfp4)` is the
**kernel-bundle name**, not a runtime selector — FP8 ptr-table kernels live
inside that bundle (`kernels/gb10/qwen3.6-35b-a3b/nvfp4/moe_w4a16_grouped_gemm.cu:1182`).

The `quantize_to_nvfp4` lines in the boot log come from the subset of weights
that DO take the bad path. Findings, ranked by severity:

| # | Severity | Site | Impact |
|---|----------|------|--------|
| 1 | **HIGH** | SSM `in_proj_qkv` + `out_proj` decode: FP8 → BF16 → **NVFP4** | every linear-attn layer × every decode token; aligns with `project_qwen36_drift_gdn_clean.md` post-norm cliff and "L31-L39 deep-layer regression" memory entries |
| 2 | **MED** | MoE shared-expert weights duplicated as NVFP4 even when native-FP8 shared expert is wired | wasted memory + risk of wrong pointer in fallback paths (forward_batched k=1) |
| 3 | **MED** | MoE `gate` (router) projection: FP8 → BF16 → **NVFP4** | routing-decision noise; matches `project_qwen36_drift_moe_smoking_gun.md` "MoE expert routing diverges 8/8→3/8" |
| 4 | **LOW** | LM head: passed BF16-shaped pointer through `quantize_to_nvfp4` without FP8 dequant if checkpoint stores `lm_head` as FP8 (latent — most FP8 checkpoints leave lm_head as BF16) | catastrophic if triggered; usually dormant |
| 5 | **LOW** | MTP head `quantize_to_nvfp4` chain (BF16→NVFP4) for all projections regardless of FP8 native availability | only relevant with `--mtp-quantization nvfp4` |

Routed MoE experts and full-attention QKVO are NOT in the wrong path.

**Single highest-leverage fix:** Bypass the SSM NVFP4 round-trip in
`weight_loader/qwen35/load_layers/linear_attn_arms.rs` for `Fp8Dequanted`
(NVFP4 is built unconditionally at lines 176-192; the parallel FP8 prefill copy
at 213-235 is decode-blind).

---

## 2. Chain of decisions for one Qwen3.6-35B-A3B-FP8 request

```
[boot]
spark-server/main_modules/serve.rs:68 ptx_for_config("qwen3_6_moe", 2048) → nvfp4 bundle
spark-server/main_modules/serve.rs:83 "Selected kernel target ... nvfp4 (90 modules)" ← kernel-bundle name, NOT runtime quant
factory/build.rs:98 loader.load_layers(...) → qwen35::load_layers
[weight_loader/qwen35/load_layers.rs]
:70 detect_nvfp4_variant → returns Fp8Dequanted (nvfp4_detect.rs:51)
:76 quant_format = QuantFormat::Fp8
:81 native_fp8 = true
:128 ATLAS_FORCE_NVFP4_MOE? → false so skip_nvfp4_experts = true
:139 load_moe_qwen35(..., skip_routed_experts=true)
└─ ssm_qwen35.rs:75 load_moe_qwen35:
:89 gate = dense(...) ← FP8 byte ptr survives untouched here
:151 gate_nvfp4 = quantize_to_nvfp4(&moe_weights.gate, ...) ← BUG #3: lm_head router BF16→NVFP4, but gate ptr is FP8 bytes
:184 shared_expert = load_expert(...) → variant==Fp8Dequanted →
quantized_from_fp8 → BF16 → NVFP4 ← BUG #2 (NVFP4 shared_expert built, but unused)
:188 for e in routed: experts.push(NULL) ← OK
:160 MoeLayer::new(...)
:183 load_moe_qwen35_fp8_experts ← OK, routed experts FP8
:197 load_fp8_block_scaled_as_fp8weight(shared_expert/{gate,up,down}_proj) ← OK, FP8 shared expert
:215 moe_layer.set_fp8_experts(&fp8_experts, shared_fp8, gpu) ← FP8 path enabled
─ FullAttention layers ─
:226 LayerType::FullAttention if native_fp8 => ← FP8 attention arm taken
:255 load_qkvo_tp(load_fp8_proj) ← FP8 QKVO loaded zero-copy
:298 layer.set_fp8_weights(...) ← FP8 path enabled for full attn
─ LinearAttention (SSM) layers — 30 of 40 layers ─
:344 build_linear_attention_nvfp4 ← BUG #1: name says nvfp4 unconditionally
└─ linear_attn_arms.rs:147 load_ssm_qwen35 (Fp8Dequanted) → dense_auto → BF16 (good so far)
:176-190 quantize_to_nvfp4(qkvz_dense | out_proj) ← BUG #1 fires (decode path will be NVFP4)
:203-235 if Fp8Dequanted { bf16_to_fp8(...) for prefill_only } ← FP8 prefill path built in PARALLEL,
but only `set_fp8_prefill_only_weights`
installs it — decode still NVFP4
[factory/build.rs]
:101 lm_head = loader.load_lm_head(store, &config) → qwen35.rs:56 dense("lm_head.weight")
← BUG #4: dense, not dense_auto.
If checkpoint stores FP8 lm_head, raw FP8 bytes are passed downstream as BF16
:144 skip_lm_head_quantization() = false for qwen3.6
:148 quantize_to_nvfp4(&lm_head, ...) ← treats whatever the pointer is as BF16
:102 loader.load_mtp_weights_multi
:120 effective_mtp_quant = checkpoint ignores mtp.* ? Bf16 : mtp_quant ← OK-ish gate
[per-decode-step runtime]
layers/moe/forward_prefill.rs:27 self.fp8_gate_weight_ptrs.is_some() ? FP8 path : NVFP4 path
← chooses FP8 for routed experts (correct)
layers/moe/forward.rs:217 (same) ← decode FP8 path taken (correct)
Qwen3SsmLayer::forward_decode uses qkvz_nvfp4 / out_proj_nvfp4 ← BUG #1: SSM decode is always NVFP4
```

---

## 3. Per-bug table

| Bug | File:line | Currently dispatched | Canonical kernel that exists | Difficulty | Expected impact |
|-----|-----------|---------------------|------------------------------|-----------|-----------------|
| **#1 SSM decode NVFP4** | `weight_loader/qwen35/load_layers/linear_attn_arms.rs:176-190` | `quantize_to_nvfp4(qkvz_dense, ...)` and `out_proj` → NVFP4 used for ALL paths except a parallel FP8-prefill-only override | `load_fp8_block_scaled_as_fp8weight` + `w8a16_gemv` / `w8a16_gemm` already used by `build_linear_attention_fp8` (file marks itself "currently unused", `load_layers.rs:343` always takes `_nvfp4` arm) | **Medium**: build_linear_attention_fp8 already exists at line 24 of same file. Need to: (a) flip dispatch to fp8 arm when `variant==Fp8Dequanted`, (b) verify `set_fp8_weights` on `Qwen3SsmLayer` reaches the SSM decode/verify GEMV path. | High. Memory `project_qwen36_phase2b_softmax_expf.md` already attributes deep-layer regression to FP8-KV+NVFP4-weight noise on out_proj. Eliminating the BF16→NVFP4 step on `out_proj` should match the existing `ATLAS_GDN_BF16_WEIGHTS=1` benefit but with FP8 precision (lower memory). |
| **#2 dead NVFP4 shared-expert** | `weight_map/ssm_qwen35.rs:184` + `loaders_moe.rs:32-60` | Loaded via `quantized_from_fp8` → BF16 → NVFP4 then never consumed (forward_prefill_fp8 + forward.rs use `fp8_shared_expert`) | n/a — fix is to elide the load when `native_fp8 && !force_nvfp4_moe` | **Trivial**: thread a `skip_shared_expert: bool` like `skip_routed_experts` already does. | Saves a few hundred MB and one source of quant noise if any fallback path ever consults the NVFP4 shared expert. |
| **#3 router gate FP8→NVFP4** | `weight_loader/qwen35/load_layers.rs:151-159` + `ssm_qwen35.rs:89` | `dense(...)` returns raw FP8 bytes; `quantize_to_nvfp4` then treats them as BF16 if the checkpoint stores `gate.weight` as FP8. For Qwen3.6 FP8 the gate is typically **BF16** in checkpoint (small enough to leave alone) — verify with `WeightDtype` of `mlp.gate.weight`. If it IS FP8 in checkpoint, this is silently miscomputing the entire routing distribution. | `dense_auto` (already exists) — would correctly dequant FP8 to BF16 before NVFP4 quantization. | **Trivial**: swap `dense` → `dense_auto`. | If the gate is FP8 in checkpoint, this is the root of the routing divergence reported in `project_qwen36_drift_moe_smoking_gun.md` (8/8→3/8 expert overlap collapse). Even if gate is BF16 in this checkpoint, applying `dense_auto` everywhere is the defensive fix. |
| **#4 lm_head FP8 latent corruption** | `weight_loader/qwen35.rs:60-68` + `factory/build.rs:148` | `dense(store, "lm_head.weight")` returns raw pointer regardless of dtype; if the checkpoint's `lm_head` is FP8E4M3 + has `weight_scale_inv`, the FP8 bytes are pumped into `quantize_to_nvfp4` as if BF16 — catastrophic top-token corruption. | `dense_auto` (already exists). | **Trivial**: swap `dense` → `dense_auto` in `load_lm_head` for all FP8-capable loaders (qwen35, qwen3, qwen35_dense). | Latent. Qwen3.6 FP8 typically ships lm_head as BF16; if it does, no current impact. But it's a silent footgun for the next FP8 checkpoint that quantizes lm_head. |
| **#5 MTP head BF16→NVFP4** | `layers/mtp_head/new.rs:44-53` + `:81-118` | Every MTP projection goes through `quantize_to_nvfp4` when `quant==Nvfp4`. For FP8 source weights this is the same FP8→BF16→NVFP4 chain as the main model. | An `MtpQuantization::Fp8` variant exists (mtp_head.rs:29, line 180 in new.rs) but the user must opt in via `--mtp-quantization fp8`. | **Low**: default `--mtp-quantization` based on `native_fp8`. | Only matters for users who turn on `--speculative`. Spec drafts produced from a doubly-quantized MTP head would amplify Bug #1's drift into wholesale rejection. |
| **Dead code** | `weight_map/ssm_qwen35.rs:262-279` | `_shared_fp8 = ...` is computed and discarded (sigil `_`). The caller at `qwen35/load_layers.rs:197-214` re-loads the same tensors. | n/a — duplicate I/O. | Trivial: return shared_fp8 from `load_moe_qwen35_fp8_experts` or delete the bind. | None on correctness; ~2× the FP8 shared-expert load time. |

---

## 4. Recommended fix order

1. **Bug #1** (SSM decode NVFP4). Highest expected quality lift. The plumbing
(`build_linear_attention_fp8`, `Qwen3SsmLayer::set_fp8_weights`,
`qkvz_fp8 + out_fp8` SSM GEMV) is **already implemented**; it's behind a
dead-coded gate (`load_layers.rs:334-342` comment: "permanently
short-circuited"). Re-enabling it is a one-line dispatch flip plus removal
of the parallel `_nvfp4` build in `build_linear_attention_nvfp4`. Verify
against `ATLAS_GDN_BF16_WEIGHTS=1` numerics — they should match or beat the
BF16 dense path while saving ~2× weight memory.
2. **Bug #3** (`dense` → `dense_auto` on router gate). Even if the gate is
BF16 in *this* checkpoint, this defensive fix prevents the silent FP8-as-BF16
misread that has already been observed bites elsewhere
(cf. `project_qwen36_numerical_drift_2026_05_23.md`).
3. **Bug #4** (LM head `dense_auto`). Same one-line fix, eliminates a class of
latent corruption for any future FP8 checkpoint that quantizes the head.
4. **Bug #2** (dead NVFP4 shared-expert). Memory savings, cleanliness.
5. **Bug #5** (default `--mtp-quantization` to `fp8` when `native_fp8`). Only
after Bug #1 — otherwise the MTP head still feeds an NVFP4-corrupted target.

---

## 5. Non-quant red flags noticed in passing

- `weight_map/ssm_qwen35.rs:263` — `_shared_fp8` discarded (Dead code).
- `linear_attn_arms.rs:267-269` — installs FP8 prefill weights via
`set_fp8_prefill_only_weights`. Decode is never routed to FP8 SSM weights
even though the FP8 buffers exist (lines 215-232).
- `attention_arms.rs:83` collapses `Standard | Fp8Dequanted | Bf16Raw` to one
arm. Unreachable for Qwen3.5 (peeled off at `load_layers.rs:226`), so it's
load-bearing only for non-FP8 checkpoints. Brittle: a future caller that
enters with `Fp8Dequanted && !native_fp8` (e.g. an `ATLAS_FORCE_NVFP4_MOE`
variant that spilled into attention) would read FP8 bytes as BF16. Add a
`debug_assert!`.
- KV cache: `attn_layer_dtypes` is independent of weight quant. The FP8-KV
cliff at L35–L39 in `project_qwen36_phase2b_softmax_expf.md` interacts
with Bug #1: SSM-decode NVFP4 noise compounds with FP8-KV rounding at deep
layers, so fixing Bug #1 should reduce but not eliminate that regression.
- `factory/build.rs:144-160` always BF16→NVFP4 the LM head when
`!skip_lm_head_quantization()`. Atlas has `gemv_fp8w`; an FP8 lm_head
could stay FP8.

---

## Appendix: Canonical FP8 kernels (cross-reference)

| Kernel | File | Used by |
|--------|------|---------|
| `moe_fp8_grouped_gemm_ptrtable_t/k/v2_k` | `kernels/gb10/qwen3.6-35b-a3b/nvfp4/moe_w4a16_grouped_gemm.cu:1182` | `forward_prefill_fp8.rs`, `forward_batched.rs`, `forward_k2/k3.rs` |
| `moe_expert_gate_up_shared_fp8` / `moe_expert_silu_down_shared_fp8` | `kernels/gb10/common/` | `forward.rs:217-260` |
| `w8a16_gemv` / `w8a16_gemm` (FP8w × BF16a) | `kernels/gb10/common/gemv_fp8w` | `Qwen3AttentionLayer::set_fp8_weights`, FP8 shared expert |
| `bf16_to_fp8` | `kernels/gb10/common/w4a16` | `linear_attn_arms.rs:213` — SSM prefill-only |
| `fp8_gemm_n128` | `kernels/gb10/common/` | SSM prefill (Bug #1 partial mitigation) |

Asymmetry: every full-attention layer and every routed MoE expert already runs
on canonical FP8. Only SSM layers (30/40), the router gate, the dead NVFP4
shared expert, and any FP8 lm_head sit wrong. Fix Bug #1 and the runtime
profile changes from "10 FP8 + 30 NVFP4-from-FP8 layers" to "40 FP8 layers".
Loading
Loading