-
Notifications
You must be signed in to change notification settings - Fork 65
fix(coherence): Debugging to Get Qwen Working In Agentic Coding #90
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Open
tbraun96
wants to merge
50
commits into
main
Choose a base branch
from
fix/in-think-tool-call-leak
base: main
Could not load branches
Branch not found: {{ refName }}
Loading
Could not load tags
Nothing to show
Loading
Are you sure you want to change the base?
Some commits from the old base branch may be removed from the timeline,
and old review comments may become outdated.
+145,840
−7,073
Open
Changes from 20 commits
Commits
Show all changes
50 commits
Select commit
Hold shift + click to select a range
7d317a3
fix(streaming): signal finish_reason=length when a tool-loop guard ca…
tbraun96 d1ac545
fix(streaming): cancel scheduler when a loop guard fires, not just su…
tbraun96 7da054d
fix(streaming): detect & cancel in-think `<tool_call>` leak (Qwen3.6 …
tbraun96 1bb82ed
pre-refactor: Phase 2c precision + watchdog baseline
tbraun96 6097910
phase-A: remove all 13 always-on prompt injections
tbraun96 060ffe1
phase-B: vLLM-style stop-string holdback + per-request RepetitionDete…
tbraun96 bfa4666
hotfix(watchdog,budget): tune from phaseAB live opencode test
tbraun96 755db59
hotfix-2: cap orphan-tool-call suppression streak (kill at 256 tokens)
tbraun96 f46d9f4
phase-C-2 (part 1/2): extract pre-sample logits pipeline scaffold
tbraun96 622e8df
hotfix-2b: orphan-suppression check must run on every token, not only…
tbraun96 795d8d1
hotfix-3: content-loop watchdog must run on MTP path too
tbraun96 c6650d3
phase-C-2 (part 2/2): wire LogitsProcessor pipeline into all decode p…
tbraun96 d3adea7
realfix2: speculative xgrammar advance between verify positions (K>=2)
tbraun96 29a0a7c
phase-2c day-1: KV cache sweep — 18 configs across dgx1+dgx2, KV dime…
tbraun96 3580978
phase-2c day-2: kernel bisect infrastructure + 3 NEGATIVE bisects
tbraun96 ea44fe6
phase-2c day-3: NVFP4 weight checkpoint test — BREAKTHROUGH
tbraun96 72644aa
hotfix: stuck-in-tool-body watchdog (NVFP4 doom-loop fix)
tbraun96 303cbab
phase-2c day-3 audit: causal-pathway map of FP8 → NVFP4 dispatch points
tbraun96 0467d1e
phase-2c day-3 Bug #1 attempt: REVERT — kernel infrastructure not ready
tbraun96 8d2cc87
fp8-merge73: native FP8 SSM + byte-exact streaming + PR 73 qwen3_xml …
tbraun96 e99159d
grammar: qwen3_coder body uses any_text (matches XML wire format)
tbraun96 2a2500c
validator: reject empty 'command'/'cmd'/'script' for shell tools
tbraun96 eaaa269
tool_handlers: soft-pass empty-required-string validation errors
tbraun96 49bad35
kernel/moe_fp8: two-level FP32 accumulation (DeepGEMM pattern)
tbraun96 4fa47b6
grammar+sampler: Tier-0 EBNF + Tier-1 byte-counter mask for tool params
tbraun96 6f9d595
mission-12h: Tier-2 strict path/cmd validators + final mission report
tbraun96 8c296ea
fp8-drift: o_proj W8A8 N/K fix + GPU dequant kernel + BF16 MoE path
tbraun96 d03197c
opencode-fix: relative-path validator + WS-mask newline exclusion + d…
tbraun96 6608824
diag: complete per-step logit dump (ATLAS_LOGIT_DUMP)
tbraun96 25f8bbe
fp8-prefix-cache: fix exact-hit SSM double-advance (Marconi snapshot)
tbraun96 367846f
mtp+moe: fix MTP 0%-accept (fp32-residual dtype bug) + BF16 router (t…
tbraun96 d7a4da8
tool-parser: server-side write-path drift recovery (ATLAS_WRITE_PATH_…
tbraun96 b0779b9
residual: remove FP32-residual feature — BF16 residual stream always
tbraun96 9922d4c
fix(tool-recovery): recover FP8-drifted file-write tool calls (3 modes)
tbraun96 d0b95f1
fix(moe): route MTP K=2/K=3 verify through BF16 path when experts are…
tbraun96 d2eb167
test(toml_repair): add r105/r110/r4 TOML-shape probe tests
tbraun96 dc6ea50
debug(kernels): ATLAS_DEBUG_SYNC_KERNELS + ATLAS_DEBUG_NO_GRAPH diagn…
tbraun96 2db83dc
fix(attn): multi-seq O-proj BF16 branch for ATLAS_FP8_DEQUANT_ATTN_TO…
tbraun96 bb2b53f
chore(docker): fast-layer build helper Dockerfile.fencesalvage
tbraun96 f7525bd
feat(quant): ATLAS_FP8_DEQUANT_LAYERS — selective per-layer BF16 dequant
tbraun96 a970624
feat(agentic): BW1 bash-wandering / content-completeness watchdog
tbraun96 fd688ab
bench: opencode harness evidence trail (FP8 drift / BF16 lever / BW1 …
tbraun96 68c3c50
fix(tool-salvage): guard EOF-fence slice panic; repair C-style // com…
tbraun96 4521dc7
fix(loop-detect): tool calls are progress — stop spinning-detector ki…
tbraun96 c487bc4
fix(prefix-cache): recompute SSM over [snap_tok,total) when snapshot …
tbraun96 7e8e2d6
prefix-cache: ATLAS_NO_MARCONI_EXACT diagnostic gate + partial-hit re…
tbraun96 3d43e2f
webserver_ok F1-F5: bound runaway via post-think content cap + watchd…
tbraun96 bc9f694
fix(qwen3.6-fp8): rep_penalty 1.1->1.0 on sampler presets + tool-JSON…
tbraun96 52244ab
fix(qwen3.6-fp8): 10/10 webserver_ok MTP-on — delete tool-call band-a…
tbraun96 0ff94b5
perf(qwen3.6): phase-2 decode profiling — host-path stage timing + sp…
tbraun96 File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,161 @@ | ||
| #!/usr/bin/env python3 | ||
| """HF-transformers GPU oracle for the Nemotron-3-Nano per-layer divergence hunt. | ||
|
|
||
| Loads nvidia/NVIDIA-Nemotron-3-Nano-30B-A3B-NVFP4, MANUALLY dequantizes every | ||
| NVFP4 weight to BF16 (transformers 5.8 has no NVFP4 backend in this image), | ||
| loads the dequantized state_dict into a fresh NemotronH model, feeds the EXACT | ||
| chat-rendered token IDs, and captures per-block hidden states + final norm + | ||
| logits in headerless little-endian f32 .bin -- the format the Atlas | ||
| ATLAS_NEMO_DUMP hook writes -- so the comparator can diff 1:1. | ||
|
|
||
| The dequantized BF16 graph run through `torch_forward` is the canonical | ||
| "intended math" oracle (no fused Triton kernels, no custom CUDA). | ||
|
|
||
| Env: MODEL (local snapshot path), OUT, PROMPT. | ||
| """ | ||
| import glob | ||
| import json | ||
| import os | ||
| import pathlib | ||
| import sys | ||
|
|
||
| sys.path.insert(0, os.path.dirname(os.path.abspath(__file__))) | ||
| import mamba_ssm_stub # noqa: F401 installs pure-torch mamba_ssm + forces torch path | ||
|
|
||
| import numpy as np | ||
| import torch | ||
| from safetensors import safe_open | ||
| from transformers import AutoConfig, AutoModelForCausalLM, AutoTokenizer | ||
|
|
||
| MODEL = os.environ["MODEL"] | ||
| OUT = pathlib.Path(os.environ.get("OUT", "/out")) | ||
| PROMPT = os.environ.get("PROMPT", "Please count from 1 to 30. Output every number.") | ||
| OUT.mkdir(parents=True, exist_ok=True) | ||
|
|
||
| # E2M1 FP4 code -> float value (sign-magnitude, 16 codes). | ||
| _E2M1 = torch.tensor( | ||
| [0.0, 0.5, 1.0, 1.5, 2.0, 3.0, 4.0, 6.0, | ||
| -0.0, -0.5, -1.0, -1.5, -2.0, -3.0, -4.0, -6.0], | ||
| dtype=torch.float32, | ||
| ) | ||
|
|
||
|
|
||
| def dequant_nvfp4(packed, wscale, wscale2, group_size=16): | ||
| """packed uint8 [O, K/2] -> bf16 [O, K]. wscale fp8e4m3 [O, K/16]. | ||
| value = E2M1[nibble] * fp8(wscale) * f32(wscale2).""" | ||
| O, Khalf = packed.shape | ||
| K = Khalf * 2 | ||
| lo = (packed & 0x0F).to(torch.long) | ||
| hi = ((packed >> 4) & 0x0F).to(torch.long) | ||
| codes = torch.empty(O, K, dtype=torch.long) | ||
| codes[:, 0::2] = lo | ||
| codes[:, 1::2] = hi | ||
| vals = _E2M1.to(codes.device)[codes] # [O, K] f32 | ||
| s = wscale.to(torch.float32) # [O, K/16] | ||
| s = s.repeat_interleave(group_size, dim=1) # [O, K] | ||
| vals = vals * s * float(wscale2) | ||
| return vals.to(torch.bfloat16) | ||
|
|
||
|
|
||
| def load_dequant_state_dict(): | ||
| """Read every safetensors shard; dequant NVFP4 triples to a single | ||
| bf16 `weight`; pass dense tensors through; drop *_scale* sidecars.""" | ||
| files = sorted(glob.glob(os.path.join(MODEL, "model-*.safetensors"))) | ||
| raw = {} | ||
| for f in files: | ||
| with safe_open(f, "pt") as sf: | ||
| for k in sf.keys(): | ||
| raw[k] = sf.get_tensor(k) | ||
| sd = {} | ||
| quant_bases = set() | ||
| for k in raw: | ||
| if k.endswith(".weight_scale"): | ||
| quant_bases.add(k[: -len(".weight_scale")]) | ||
| for base in quant_bases: | ||
| w = raw[base + ".weight"] | ||
| ws = raw[base + ".weight_scale"] | ||
| ws2 = raw[base + ".weight_scale_2"] | ||
| sd[base + ".weight"] = dequant_nvfp4(w, ws, ws2) | ||
| skip_suffix = (".weight_scale", ".weight_scale_2", ".input_scale") | ||
| for k, v in raw.items(): | ||
| if any(k.endswith(s) for s in skip_suffix): | ||
| continue | ||
| if k.endswith(".weight") and k[: -len(".weight")] in quant_bases: | ||
| continue # already dequantized above | ||
| sd[k] = v | ||
| return sd | ||
|
|
||
|
|
||
| def save_f32(name, t): | ||
| arr = t.detach().to(torch.float32).cpu().numpy().astype("<f4").ravel() | ||
| (OUT / name).write_bytes(arr.tobytes()) | ||
| return arr | ||
|
|
||
|
|
||
| def main(): | ||
| tok = AutoTokenizer.from_pretrained(MODEL, trust_remote_code=True) | ||
| msgs = [{"role": "user", "content": PROMPT}] | ||
| ids = tok.apply_chat_template(msgs, add_generation_prompt=True, tokenize=True) | ||
| if hasattr(ids, "input_ids"): | ||
| ids = ids["input_ids"] | ||
| if ids and isinstance(ids[0], list): | ||
| ids = ids[0] | ||
| ids = [int(x) for x in ids] | ||
| print("PROMPT_TOKEN_COUNT:", len(ids)) | ||
| print("PROMPT_TOKEN_IDS:", ids) | ||
| (OUT / "token_ids.json").write_text(json.dumps(ids)) | ||
| rendered = tok.apply_chat_template(msgs, add_generation_prompt=True, tokenize=False) | ||
| (OUT / "rendered_prompt.txt").write_text(rendered) | ||
| print("RENDERED_PROMPT_REPR:", repr(rendered)) | ||
|
|
||
| print("Dequantizing NVFP4 -> BF16 state_dict ...", flush=True) | ||
| sd = load_dequant_state_dict() | ||
| print(f" state_dict: {len(sd)} tensors") | ||
|
|
||
| cfg = AutoConfig.from_pretrained(MODEL, trust_remote_code=True) | ||
| print("Building empty BF16 model ...", flush=True) | ||
| with torch.device("meta"): | ||
| model = AutoModelForCausalLM.from_config(cfg, trust_remote_code=True) | ||
| model = model.to_empty(device="cpu") | ||
| # `backbone.` prefix in checkpoint matches NemotronH module tree. | ||
| missing, unexpected = model.load_state_dict(sd, strict=False, assign=True) | ||
| miss = [m for m in missing if "rotary" not in m and "inv_freq" not in m] | ||
| print(f" missing={len(miss)} unexpected={len(unexpected)}") | ||
| if miss[:8]: | ||
| print(" missing sample:", miss[:8]) | ||
| if unexpected[:8]: | ||
| print(" unexpected sample:", unexpected[:8]) | ||
| model = model.to(device="cuda", dtype=torch.bfloat16) | ||
| model.eval() | ||
|
|
||
| input_ids = torch.tensor([ids], device="cuda") | ||
| with torch.no_grad(): | ||
| out = model(input_ids=input_ids, output_hidden_states=True, use_cache=False) | ||
|
|
||
| hs = out.hidden_states | ||
| print("NUM_HIDDEN_STATES:", len(hs), "(embed + 52 layers)") | ||
| last = -1 | ||
| emb = save_f32("hf_embed.bin", hs[0][0, last]) | ||
| print(f"hf_embed: norm={np.linalg.norm(emb):.4f}") | ||
| for i in range(1, len(hs)): | ||
| arr = save_f32(f"hf_L{i-1}.bin", hs[i][0, last]) | ||
| if i - 1 < 4 or i - 1 >= len(hs) - 5: | ||
| print(f"hf_L{i-1}: norm={np.linalg.norm(arr):.4f}") | ||
|
|
||
| final_hidden = hs[-1][0, last].to(torch.bfloat16) | ||
| norm_f = model.backbone.norm_f | ||
| fn = norm_f(final_hidden.unsqueeze(0)).squeeze(0) | ||
| fn_arr = save_f32("hf_final_norm.bin", fn) | ||
| print(f"hf_final_norm: norm={np.linalg.norm(fn_arr):.4f}") | ||
|
|
||
| logits = out.logits[0, last] | ||
| save_f32("hf_logits.bin", logits) | ||
| top = torch.topk(logits.float(), 10) | ||
| top_list = [(int(i), float(v)) for i, v in zip(top.indices, top.values)] | ||
| print("HF_TOP10_LOGITS:", top_list) | ||
| (OUT / "top10.json").write_text(json.dumps(top_list)) | ||
| print("DONE ->", OUT) | ||
|
|
||
|
|
||
| if __name__ == "__main__": | ||
| main() | ||
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,158 @@ | ||
| # Causal-Pathway Audit: FP8 → NVFP4 Double-Quantization in Atlas | ||
|
|
||
| **Subject:** `Qwen/Qwen3.6-35B-A3B-FP8` on Atlas (branch `fix/in-think-tool-call-leak`) | ||
| **Date:** 2026-05-24 | ||
| **Mode:** read-only forensic audit | ||
|
|
||
| --- | ||
|
|
||
| ## 1. TL;DR | ||
|
|
||
| Atlas DOES dispatch canonical FP8 kernels for routed MoE experts and full-attention | ||
| QKVO of `Qwen3.6-35B-A3B-FP8` via `set_fp8_experts` / `set_fp8_weights`. The | ||
| "nvfp4" tag in `Selected kernel target: (sm_121, qwen3.6-35b-a3b, nvfp4)` is the | ||
| **kernel-bundle name**, not a runtime selector — FP8 ptr-table kernels live | ||
| inside that bundle (`kernels/gb10/qwen3.6-35b-a3b/nvfp4/moe_w4a16_grouped_gemm.cu:1182`). | ||
|
|
||
| The `quantize_to_nvfp4` lines in the boot log come from the subset of weights | ||
| that DO take the bad path. Findings, ranked by severity: | ||
|
|
||
| | # | Severity | Site | Impact | | ||
| |---|----------|------|--------| | ||
| | 1 | **HIGH** | SSM `in_proj_qkv` + `out_proj` decode: FP8 → BF16 → **NVFP4** | every linear-attn layer × every decode token; aligns with `project_qwen36_drift_gdn_clean.md` post-norm cliff and "L31-L39 deep-layer regression" memory entries | | ||
| | 2 | **MED** | MoE shared-expert weights duplicated as NVFP4 even when native-FP8 shared expert is wired | wasted memory + risk of wrong pointer in fallback paths (forward_batched k=1) | | ||
| | 3 | **MED** | MoE `gate` (router) projection: FP8 → BF16 → **NVFP4** | routing-decision noise; matches `project_qwen36_drift_moe_smoking_gun.md` "MoE expert routing diverges 8/8→3/8" | | ||
| | 4 | **LOW** | LM head: passed BF16-shaped pointer through `quantize_to_nvfp4` without FP8 dequant if checkpoint stores `lm_head` as FP8 (latent — most FP8 checkpoints leave lm_head as BF16) | catastrophic if triggered; usually dormant | | ||
| | 5 | **LOW** | MTP head `quantize_to_nvfp4` chain (BF16→NVFP4) for all projections regardless of FP8 native availability | only relevant with `--mtp-quantization nvfp4` | | ||
|
|
||
| Routed MoE experts and full-attention QKVO are NOT in the wrong path. | ||
|
|
||
| **Single highest-leverage fix:** Bypass the SSM NVFP4 round-trip in | ||
| `weight_loader/qwen35/load_layers/linear_attn_arms.rs` for `Fp8Dequanted` | ||
| (NVFP4 is built unconditionally at lines 176-192; the parallel FP8 prefill copy | ||
| at 213-235 is decode-blind). | ||
|
|
||
| --- | ||
|
|
||
| ## 2. Chain of decisions for one Qwen3.6-35B-A3B-FP8 request | ||
|
|
||
| ``` | ||
| [boot] | ||
| spark-server/main_modules/serve.rs:68 ptx_for_config("qwen3_6_moe", 2048) → nvfp4 bundle | ||
| spark-server/main_modules/serve.rs:83 "Selected kernel target ... nvfp4 (90 modules)" ← kernel-bundle name, NOT runtime quant | ||
| factory/build.rs:98 loader.load_layers(...) → qwen35::load_layers | ||
| [weight_loader/qwen35/load_layers.rs] | ||
| :70 detect_nvfp4_variant → returns Fp8Dequanted (nvfp4_detect.rs:51) | ||
| :76 quant_format = QuantFormat::Fp8 | ||
| :81 native_fp8 = true | ||
| :128 ATLAS_FORCE_NVFP4_MOE? → false so skip_nvfp4_experts = true | ||
| :139 load_moe_qwen35(..., skip_routed_experts=true) | ||
| └─ ssm_qwen35.rs:75 load_moe_qwen35: | ||
| :89 gate = dense(...) ← FP8 byte ptr survives untouched here | ||
| :151 gate_nvfp4 = quantize_to_nvfp4(&moe_weights.gate, ...) ← BUG #3: lm_head router BF16→NVFP4, but gate ptr is FP8 bytes | ||
| :184 shared_expert = load_expert(...) → variant==Fp8Dequanted → | ||
| quantized_from_fp8 → BF16 → NVFP4 ← BUG #2 (NVFP4 shared_expert built, but unused) | ||
| :188 for e in routed: experts.push(NULL) ← OK | ||
| :160 MoeLayer::new(...) | ||
| :183 load_moe_qwen35_fp8_experts ← OK, routed experts FP8 | ||
| :197 load_fp8_block_scaled_as_fp8weight(shared_expert/{gate,up,down}_proj) ← OK, FP8 shared expert | ||
| :215 moe_layer.set_fp8_experts(&fp8_experts, shared_fp8, gpu) ← FP8 path enabled | ||
| ─ FullAttention layers ─ | ||
| :226 LayerType::FullAttention if native_fp8 => ← FP8 attention arm taken | ||
| :255 load_qkvo_tp(load_fp8_proj) ← FP8 QKVO loaded zero-copy | ||
| :298 layer.set_fp8_weights(...) ← FP8 path enabled for full attn | ||
| ─ LinearAttention (SSM) layers — 30 of 40 layers ─ | ||
| :344 build_linear_attention_nvfp4 ← BUG #1: name says nvfp4 unconditionally | ||
| └─ linear_attn_arms.rs:147 load_ssm_qwen35 (Fp8Dequanted) → dense_auto → BF16 (good so far) | ||
| :176-190 quantize_to_nvfp4(qkvz_dense | out_proj) ← BUG #1 fires (decode path will be NVFP4) | ||
| :203-235 if Fp8Dequanted { bf16_to_fp8(...) for prefill_only } ← FP8 prefill path built in PARALLEL, | ||
| but only `set_fp8_prefill_only_weights` | ||
| installs it — decode still NVFP4 | ||
| [factory/build.rs] | ||
| :101 lm_head = loader.load_lm_head(store, &config) → qwen35.rs:56 dense("lm_head.weight") | ||
| ← BUG #4: dense, not dense_auto. | ||
| If checkpoint stores FP8 lm_head, raw FP8 bytes are passed downstream as BF16 | ||
| :144 skip_lm_head_quantization() = false for qwen3.6 | ||
| :148 quantize_to_nvfp4(&lm_head, ...) ← treats whatever the pointer is as BF16 | ||
| :102 loader.load_mtp_weights_multi | ||
| :120 effective_mtp_quant = checkpoint ignores mtp.* ? Bf16 : mtp_quant ← OK-ish gate | ||
| [per-decode-step runtime] | ||
| layers/moe/forward_prefill.rs:27 self.fp8_gate_weight_ptrs.is_some() ? FP8 path : NVFP4 path | ||
| ← chooses FP8 for routed experts (correct) | ||
| layers/moe/forward.rs:217 (same) ← decode FP8 path taken (correct) | ||
| Qwen3SsmLayer::forward_decode uses qkvz_nvfp4 / out_proj_nvfp4 ← BUG #1: SSM decode is always NVFP4 | ||
| ``` | ||
|
|
||
| --- | ||
|
|
||
| ## 3. Per-bug table | ||
|
|
||
| | Bug | File:line | Currently dispatched | Canonical kernel that exists | Difficulty | Expected impact | | ||
| |-----|-----------|---------------------|------------------------------|-----------|-----------------| | ||
| | **#1 SSM decode NVFP4** | `weight_loader/qwen35/load_layers/linear_attn_arms.rs:176-190` | `quantize_to_nvfp4(qkvz_dense, ...)` and `out_proj` → NVFP4 used for ALL paths except a parallel FP8-prefill-only override | `load_fp8_block_scaled_as_fp8weight` + `w8a16_gemv` / `w8a16_gemm` already used by `build_linear_attention_fp8` (file marks itself "currently unused", `load_layers.rs:343` always takes `_nvfp4` arm) | **Medium**: build_linear_attention_fp8 already exists at line 24 of same file. Need to: (a) flip dispatch to fp8 arm when `variant==Fp8Dequanted`, (b) verify `set_fp8_weights` on `Qwen3SsmLayer` reaches the SSM decode/verify GEMV path. | High. Memory `project_qwen36_phase2b_softmax_expf.md` already attributes deep-layer regression to FP8-KV+NVFP4-weight noise on out_proj. Eliminating the BF16→NVFP4 step on `out_proj` should match the existing `ATLAS_GDN_BF16_WEIGHTS=1` benefit but with FP8 precision (lower memory). | | ||
| | **#2 dead NVFP4 shared-expert** | `weight_map/ssm_qwen35.rs:184` + `loaders_moe.rs:32-60` | Loaded via `quantized_from_fp8` → BF16 → NVFP4 then never consumed (forward_prefill_fp8 + forward.rs use `fp8_shared_expert`) | n/a — fix is to elide the load when `native_fp8 && !force_nvfp4_moe` | **Trivial**: thread a `skip_shared_expert: bool` like `skip_routed_experts` already does. | Saves a few hundred MB and one source of quant noise if any fallback path ever consults the NVFP4 shared expert. | | ||
| | **#3 router gate FP8→NVFP4** | `weight_loader/qwen35/load_layers.rs:151-159` + `ssm_qwen35.rs:89` | `dense(...)` returns raw FP8 bytes; `quantize_to_nvfp4` then treats them as BF16 if the checkpoint stores `gate.weight` as FP8. For Qwen3.6 FP8 the gate is typically **BF16** in checkpoint (small enough to leave alone) — verify with `WeightDtype` of `mlp.gate.weight`. If it IS FP8 in checkpoint, this is silently miscomputing the entire routing distribution. | `dense_auto` (already exists) — would correctly dequant FP8 to BF16 before NVFP4 quantization. | **Trivial**: swap `dense` → `dense_auto`. | If the gate is FP8 in checkpoint, this is the root of the routing divergence reported in `project_qwen36_drift_moe_smoking_gun.md` (8/8→3/8 expert overlap collapse). Even if gate is BF16 in this checkpoint, applying `dense_auto` everywhere is the defensive fix. | | ||
| | **#4 lm_head FP8 latent corruption** | `weight_loader/qwen35.rs:60-68` + `factory/build.rs:148` | `dense(store, "lm_head.weight")` returns raw pointer regardless of dtype; if the checkpoint's `lm_head` is FP8E4M3 + has `weight_scale_inv`, the FP8 bytes are pumped into `quantize_to_nvfp4` as if BF16 — catastrophic top-token corruption. | `dense_auto` (already exists). | **Trivial**: swap `dense` → `dense_auto` in `load_lm_head` for all FP8-capable loaders (qwen35, qwen3, qwen35_dense). | Latent. Qwen3.6 FP8 typically ships lm_head as BF16; if it does, no current impact. But it's a silent footgun for the next FP8 checkpoint that quantizes lm_head. | | ||
| | **#5 MTP head BF16→NVFP4** | `layers/mtp_head/new.rs:44-53` + `:81-118` | Every MTP projection goes through `quantize_to_nvfp4` when `quant==Nvfp4`. For FP8 source weights this is the same FP8→BF16→NVFP4 chain as the main model. | An `MtpQuantization::Fp8` variant exists (mtp_head.rs:29, line 180 in new.rs) but the user must opt in via `--mtp-quantization fp8`. | **Low**: default `--mtp-quantization` based on `native_fp8`. | Only matters for users who turn on `--speculative`. Spec drafts produced from a doubly-quantized MTP head would amplify Bug #1's drift into wholesale rejection. | | ||
| | **Dead code** | `weight_map/ssm_qwen35.rs:262-279` | `_shared_fp8 = ...` is computed and discarded (sigil `_`). The caller at `qwen35/load_layers.rs:197-214` re-loads the same tensors. | n/a — duplicate I/O. | Trivial: return shared_fp8 from `load_moe_qwen35_fp8_experts` or delete the bind. | None on correctness; ~2× the FP8 shared-expert load time. | | ||
|
|
||
| --- | ||
|
|
||
| ## 4. Recommended fix order | ||
|
|
||
| 1. **Bug #1** (SSM decode NVFP4). Highest expected quality lift. The plumbing | ||
| (`build_linear_attention_fp8`, `Qwen3SsmLayer::set_fp8_weights`, | ||
| `qkvz_fp8 + out_fp8` SSM GEMV) is **already implemented**; it's behind a | ||
| dead-coded gate (`load_layers.rs:334-342` comment: "permanently | ||
| short-circuited"). Re-enabling it is a one-line dispatch flip plus removal | ||
| of the parallel `_nvfp4` build in `build_linear_attention_nvfp4`. Verify | ||
| against `ATLAS_GDN_BF16_WEIGHTS=1` numerics — they should match or beat the | ||
| BF16 dense path while saving ~2× weight memory. | ||
| 2. **Bug #3** (`dense` → `dense_auto` on router gate). Even if the gate is | ||
| BF16 in *this* checkpoint, this defensive fix prevents the silent FP8-as-BF16 | ||
| misread that has already been observed bites elsewhere | ||
| (cf. `project_qwen36_numerical_drift_2026_05_23.md`). | ||
| 3. **Bug #4** (LM head `dense_auto`). Same one-line fix, eliminates a class of | ||
| latent corruption for any future FP8 checkpoint that quantizes the head. | ||
| 4. **Bug #2** (dead NVFP4 shared-expert). Memory savings, cleanliness. | ||
| 5. **Bug #5** (default `--mtp-quantization` to `fp8` when `native_fp8`). Only | ||
| after Bug #1 — otherwise the MTP head still feeds an NVFP4-corrupted target. | ||
|
|
||
| --- | ||
|
|
||
| ## 5. Non-quant red flags noticed in passing | ||
|
|
||
| - `weight_map/ssm_qwen35.rs:263` — `_shared_fp8` discarded (Dead code). | ||
| - `linear_attn_arms.rs:267-269` — installs FP8 prefill weights via | ||
| `set_fp8_prefill_only_weights`. Decode is never routed to FP8 SSM weights | ||
| even though the FP8 buffers exist (lines 215-232). | ||
| - `attention_arms.rs:83` collapses `Standard | Fp8Dequanted | Bf16Raw` to one | ||
| arm. Unreachable for Qwen3.5 (peeled off at `load_layers.rs:226`), so it's | ||
| load-bearing only for non-FP8 checkpoints. Brittle: a future caller that | ||
| enters with `Fp8Dequanted && !native_fp8` (e.g. an `ATLAS_FORCE_NVFP4_MOE` | ||
| variant that spilled into attention) would read FP8 bytes as BF16. Add a | ||
| `debug_assert!`. | ||
| - KV cache: `attn_layer_dtypes` is independent of weight quant. The FP8-KV | ||
| cliff at L35–L39 in `project_qwen36_phase2b_softmax_expf.md` interacts | ||
| with Bug #1: SSM-decode NVFP4 noise compounds with FP8-KV rounding at deep | ||
| layers, so fixing Bug #1 should reduce but not eliminate that regression. | ||
| - `factory/build.rs:144-160` always BF16→NVFP4 the LM head when | ||
| `!skip_lm_head_quantization()`. Atlas has `gemv_fp8w`; an FP8 lm_head | ||
| could stay FP8. | ||
|
|
||
| --- | ||
|
|
||
| ## Appendix: Canonical FP8 kernels (cross-reference) | ||
|
|
||
| | Kernel | File | Used by | | ||
| |--------|------|---------| | ||
| | `moe_fp8_grouped_gemm_ptrtable_t/k/v2_k` | `kernels/gb10/qwen3.6-35b-a3b/nvfp4/moe_w4a16_grouped_gemm.cu:1182` | `forward_prefill_fp8.rs`, `forward_batched.rs`, `forward_k2/k3.rs` | | ||
| | `moe_expert_gate_up_shared_fp8` / `moe_expert_silu_down_shared_fp8` | `kernels/gb10/common/` | `forward.rs:217-260` | | ||
| | `w8a16_gemv` / `w8a16_gemm` (FP8w × BF16a) | `kernels/gb10/common/gemv_fp8w` | `Qwen3AttentionLayer::set_fp8_weights`, FP8 shared expert | | ||
| | `bf16_to_fp8` | `kernels/gb10/common/w4a16` | `linear_attn_arms.rs:213` — SSM prefill-only | | ||
| | `fp8_gemm_n128` | `kernels/gb10/common/` | SSM prefill (Bug #1 partial mitigation) | | ||
|
|
||
| Asymmetry: every full-attention layer and every routed MoE expert already runs | ||
| on canonical FP8. Only SSM layers (30/40), the router gate, the dead NVFP4 | ||
| shared expert, and any FP8 lm_head sit wrong. Fix Bug #1 and the runtime | ||
| profile changes from "10 FP8 + 30 NVFP4-from-FP8 layers" to "40 FP8 layers". |
Oops, something went wrong.
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
Uh oh!
There was an error while loading. Please reload this page.