
feat(tree-native): add tree-attention metadata + position_ids surface for tree training #3

Open

racky-scitix wants to merge 8 commits into sirl-dev from sirl-dev-tree-native

Conversation

@racky-scitix
Collaborator

Summary

Adds the minimal Megatron-LM surface that the slime-trainer tree-training path
needs to flow tree-attention topology and per-token positional ids from the
data iterator down to TE's TreeFlashAttention backend. All changes are
opt-in: existing callers see no behavior change because every new field
defaults to None and the new AttnBackend.tree enum is only selected by
slime's tree path.

Scope (5 files, +91 / -2)

  • megatron/core/transformer/tree_metadata.py (new) — TreeMetadata
    dataclass carrying cu_node_lens, node_parent, tree_position_ids,
    padded_size, num_nodes, and an opaque precomputed dict reused across
    layers. Lightweight, with no CUDA assumptions at definition time; see
    the sketch after this list.
  • megatron/core/packed_seq_params.py — adds
    PackedSeqParams.tree_metadata: Optional[TreeMetadata] = None. Slime's
    TreeDataIterator populates it when --enable-tree-training is on; all
    existing callers keep None.
  • megatron/core/transformer/enums.py — adds AttnBackend.tree = 6 to
    advertise the tree path for data-driven dispatch (TE's
    DotProductAttention dispatch detects tree_metadata and routes to
    TreeFlashAttention; cp_size == 1 is required).
  • megatron/core/models/common/embeddings/rotary_pos_embedding.py —
    RotaryEmbedding.forward accepts an optional position_ids tensor and,
    when provided, returns emb[position_ids] so each packed token gets the
    RoPE frequency for its logical (tree) position. Includes a defensive
    bounds check against rotary_seq_len. Default behavior unchanged.
  • megatron/core/models/gpt/gpt_model.py — threads the existing
    position_ids argument into self.rotary_pos_emb(...) so the new RoPE
    path receives it.
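
For concreteness, a minimal sketch of the two data-side pieces,
reconstructed from the field list above (field names match the diff;
annotations, defaults, and comments are assumptions):

```python
# --- megatron/core/transformer/tree_metadata.py (sketch) ---
from dataclasses import dataclass, field
from typing import Any, Dict, Optional

from torch import Tensor


@dataclass
class TreeMetadata:
    """Tree-attention topology for one packed micro-batch."""

    cu_node_lens: Optional[Tensor] = None       # cumulative node lengths, cu_seqlens-style
    node_parent: Optional[Tensor] = None        # parent index per trie node
    tree_position_ids: Optional[Tensor] = None  # logical (tree) position per packed token
    padded_size: Optional[int] = None
    num_nodes: Optional[int] = None
    # Opaque buffers precomputed once by the data iterator, reused across layers.
    precomputed: Dict[str, Any] = field(default_factory=dict)


# --- megatron/core/packed_seq_params.py (sketch of the one-field addition) ---
@dataclass
class PackedSeqParams:
    # ...existing varlen fields (qkv_format, cu_seqlens_q, ...) unchanged...
    # The real import is TYPE_CHECKING-guarded (see the first commit message
    # below) so TreeMetadata stays off the default import path.
    tree_metadata: Optional[TreeMetadata] = None
```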

Design notes

  • Data-driven dispatch. Earlier iterations introduced spec-level
    TreeFlashAttention / TETreeDotProductAttention classes; those have
    been removed. The branch now keeps the spec untouched and dispatches
    purely on whether PackedSeqParams.tree_metadata is populated.
  • Generic position_ids surface. RoPE stays generic — it does not
    know about tree topology. Whatever produces tree_position_ids (slime's
    TreeDataIterator today) is responsible for passing them into
    GPTModel.forward(position_ids=...); see the sketch after this list.
  • No public API removed. No upstream-Megatron public contract changes;
    every new field is optional with a None default.
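
A hedged sketch of the two mechanisms (signatures simplified;
_compute_emb is a hypothetical stand-in for the existing frequency
computation in rotary_pos_embedding.py):

```python
import torch

def rotary_forward(self, max_seq_len: int, position_ids: torch.Tensor = None) -> torch.Tensor:
    # Sketch of the RotaryEmbedding.forward extension described above.
    emb = self._compute_emb(max_seq_len)  # [max_seq_len, 1, 1, dim]
    if position_ids is not None:
        # Defensive bounds check against rotary_seq_len, then gather so each
        # packed token gets the frequency of its logical (tree) position.
        assert int(position_ids.max()) < emb.size(0), "position_ids exceed rotary_seq_len"
        emb = emb[position_ids.view(-1)]
    return emb

def uses_tree_path(packed_seq_params) -> bool:
    # Data-driven dispatch in one predicate: the layer spec is untouched, and
    # both Megatron and TE branch purely on whether tree_metadata is populated.
    return getattr(packed_seq_params, "tree_metadata", None) is not None
```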

Commit history

8 commits on top of sirl-dev:

5b385ae58 feat(tree-native): add Megatron PackedSeqParams.tree_metadata + TreeMetadata dataclass
5ba1a419b feat(tree-native): add Megatron tree-attention extensions
252c31867 refactor(tree-native): slim TETreeDotProductAttention to TE delegation
b2dbf0024 refactor(tree-native): data-driven tree dispatch, eliminate spec-level tree classes
8b600e448 refactor(tree-native): remove tree-specific code from TEDotProductAttention
fd65546a0 refactor(tree-native): generic position_ids on PackedSeqParams for RoPE
7a6d6e2ae refactor(tree-native): position_ids flows through model.forward, not PackedSeqParams
1813b6948 refactor(tree-native): drop stray blank line in transformer_engine.py

The history is intentionally kept in narrative order rather than squashed —
the later refactor commits show how the surface was simplified down to the
small, additive net diff above. Squashing on merge is fine.

Test plan

  • git diff origin/sirl-dev sirl-dev-tree-native --shortstat reports
    5 files changed, +91 / -2; no upstream-Megatron public APIs removed.
  • Net transformer_engine.py diff vs sirl-dev is empty (the lone
    blank-line touch from earlier was reverted by the last commit).
  • Pinned and consumed by slime-trainer's tree-native lane
    (vendors/megatron-lm gitlink at 1813b6948); precision +
    benchmark validation lives in slime-trainer
    docs/engineering/tree-attention-{benchmark,precision}.md.
  • Reviewer to confirm RotaryEmbedding.forward position_ids defaulting
    to None is the right shape for upstream consumers.

🤖 Generated with Claude Code

MaoChouHJM and others added 8 commits on April 22, 2026:
feat(tree-native): add Megatron PackedSeqParams.tree_metadata + TreeMetadata dataclass

Stage 0.2 of the tree-training native migration. Carries trie topology
(cu_node_lens, node_parent), per-token positional ids
(tree_position_ids), and the FA3 precomputed attention buffers from
slime's TreeDataIterator through PackedSeqParams to the upcoming
TETreeDotProductAttention and TreePackedRotaryEmbedding.

PackedSeqParams.tree_metadata defaults to None so all existing callers
are untouched; TYPE_CHECKING-guarded import keeps the new dataclass off
the import path until tree training is actually used.

Refs NVIDIA#220.
feat(tree-native): add Megatron tree-attention extensions

Stage 1.2 of the tree-training native migration:

- TreePackedRotaryEmbedding (rotary_pos_embedding.py): subclass of
  RotaryEmbedding that gathers freqs by packed_seq_params.tree_metadata
  .tree_position_ids when present. apply_rotary_pos_emb (TE side) stays
  untouched — it is pure elementwise so once freqs are reordered the
  rest of the rope path is correct without TE modifications.
- AttnBackend.tree (enums.py): new enum value to advertise the tree
  attention path.
- TETreeDotProductAttention + _FA3TreeAttnFunc (extensions/transformer_engine.py):
  subclass of TEDotProductAttention that asserts cp_size == 1, forces
  qkv_format='thd', and routes through the FA3 tree kernel
  (flash_attn_interface.flash_attn_tree_func) directly. The direct kernel
  call is a transitional shape: once scitix/TransformerEngine adds
  TreeFlashAttention as a real TE backend (Stage 1.1), the kernel call
  here can be replaced by a TE backend dispatch.
- get_gpt_layer_tree_spec (gpt_layer_specs.py): GPT layer spec that swaps
  core_attention for TETreeDotProductAttention and advertises
  AttnMaskType.arbitrary so the upstream block does not synthesize an
  unused mask.

Refs NVIDIA#220.
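
A hedged reconstruction of the since-deleted subclass this message
describes (the forward signature is simplified; only the gather is the
point):

```python
from megatron.core.models.common.embeddings.rotary_pos_embedding import RotaryEmbedding


class TreePackedRotaryEmbedding(RotaryEmbedding):
    """Hedged reconstruction; deleted again in b2dbf0024."""

    def forward(self, max_seq_len, packed_seq_params=None):
        freqs = super().forward(max_seq_len)
        tm = getattr(packed_seq_params, "tree_metadata", None)
        if tm is not None:
            # Gather frequencies by logical tree position. apply_rotary_pos_emb
            # (TE side) stays untouched: it is pure elementwise, so reordering
            # freqs here fixes the whole RoPE path without TE changes.
            freqs = freqs[tm.tree_position_ids]
        return freqs
```
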
refactor(tree-native): slim TETreeDotProductAttention to TE delegation

Now that scitix/TransformerEngine ships TreeFlashAttention as a
first-class backend in DotProductAttention's dispatch (scitix/TE commit
deb33ef7), the Megatron-side tree wrapper no longer needs its own FA3
autograd wrapper or kernel call.

- Remove _FA3TreeAttnFunc (moved into TE's backends.py).
- TETreeDotProductAttention.forward now unpacks
  packed_seq_params.tree_metadata into tree_cu_node_lens /
  tree_node_parent / tree_precomputed kwargs and calls
  te.pytorch.DotProductAttention.forward directly (skipping the
  Megatron-side TEDotProductAttention.forward wrapper, which would
  drop the tree kwargs). TE's dispatch routes to TreeFlashAttention
  when those kwargs are present.

Prerequisite for Stage 2 CP support: CP-aware tree attention will land
inside TE's TreeFlashAttention.forward (via TE's existing
cp_group/cp_comm_type infrastructure), and this wrapper will not need
to change.

Refs NVIDIA#220.
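
The delegation shape described above, as a hedged sketch (the tree_*
kwarg names are from this message and exist only in the scitix TE fork;
argument plumbing is simplified):

```python
import transformer_engine as te

from megatron.core.extensions.transformer_engine import TEDotProductAttention


class TETreeDotProductAttention(TEDotProductAttention):
    """Hedged reconstruction; deleted later in b2dbf0024."""

    def forward(self, query, key, value, attention_mask, packed_seq_params=None, **kwargs):
        tm = packed_seq_params.tree_metadata
        # Skip Megatron's TEDotProductAttention.forward wrapper (it would drop
        # the tree kwargs) and call TE's forward directly; TE's dispatch picks
        # TreeFlashAttention when these kwargs are present (scitix/TE deb33ef7).
        return te.pytorch.DotProductAttention.forward(
            self, query, key, value, attention_mask,
            tree_cu_node_lens=tm.cu_node_lens,
            tree_node_parent=tm.node_parent,
            tree_precomputed=tm.precomputed,
            **kwargs,
        )
```
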
refactor(tree-native): data-driven tree dispatch, eliminate spec-level tree classes

Tree routing is now fully data-driven: same model spec, same
TEDotProductAttention instance, same RotaryEmbedding instance handle
both tree and varlen micro-batches. When PackedSeqParams.tree_metadata
is populated, both Megatron and TE route to the tree path automatically.

Deleted:
- TETreeDotProductAttention (Megatron) — tree detection folded into
  TEDotProductAttention.forward as a ~15-line early return
- TreePackedRotaryEmbedding (Megatron) — tree-position gather folded
  into RotaryEmbedding.forward as a ~12-line suffix
- get_gpt_layer_tree_spec (Megatron) — no longer needed
- model_provider.py tree spec selection + MoE block spec post-process
  + RotaryEmbedding __class__ swap — all deleted

This eliminates the 'MoE decoder block spec post-processing hack' gap
from the Known Gaps list entirely.

Refs NVIDIA#220.
refactor(tree-native): remove tree-specific code from TEDotProductAttention

TE now accepts tree_metadata as a direct kwarg on DotProductAttention
.forward (TE commit ebfa49a6). Megatron's packed_seq_kwargs dict
already includes tree_metadata (it is a PackedSeqParams dataclass
field), so it flows through to TE transparently via **packed_seq_kwargs.

Delete the 15-line tree detection + manual TE dispatch block from
TEDotProductAttention.forward — zero tree-specific code remains in
Megatron's attention wrapper.

Refs NVIDIA#220.
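
Why it flows transparently: Megatron builds its TE kwargs from the
PackedSeqParams dataclass fields, so a new field rides along for free.
A sketch of that reduction (helper name is hypothetical; the real
kwarg-building code differs in detail):

```python
from dataclasses import fields

def packed_seq_kwargs(packed_seq_params) -> dict:
    # Hypothetical reduction of Megatron's existing kwarg-building step:
    # every PackedSeqParams field becomes a TE kwarg, so tree_metadata reaches
    # te.pytorch.DotProductAttention.forward with zero tree-specific code in
    # Megatron's attention wrapper.
    return {f.name: getattr(packed_seq_params, f.name) for f in fields(packed_seq_params)}

# Inside TEDotProductAttention.forward (sketch):
#   core_attn_out = super().forward(q, k, v, **packed_seq_kwargs(packed_seq_params))
```
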
refactor(tree-native): generic position_ids on PackedSeqParams for RoPE

Decouple RotaryEmbedding from tree-training: instead of reading
tree_metadata.tree_position_ids directly, RotaryEmbedding.forward now
checks the generic packed_seq_params.position_ids field. This follows
the same pattern as MultimodalRotaryEmbedding accepting position_ids.

Changes:
- PackedSeqParams gains position_ids: Optional[Tensor] = None
- RotaryEmbedding.forward gathers emb by position_ids when present —
  no 'tree' concept in the code, just 'custom per-token positions'
- slime get_tree_batch sets both position_ids and tree_metadata on
  PackedSeqParams
- TreeMetadata docstring updated: tree_position_ids is now consumed
  indirectly via PackedSeqParams.position_ids

The position_ids field is reusable by any future feature needing
non-sequential per-token RoPE positions.

Refs NVIDIA#220.
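
The intermediate shape from this commit (superseded by the next one),
as a hedged sketch with a hypothetical helper name:

```python
from torch import Tensor

def gather_custom_positions(emb: Tensor, packed_seq_params) -> Tensor:
    # Hypothetical helper showing this commit's gather: no 'tree' concept,
    # just custom per-token positions read off PackedSeqParams.position_ids.
    pos = getattr(packed_seq_params, "position_ids", None)
    if pos is not None:
        emb = emb[pos.view(-1)]  # per-token RoPE frequencies
    return emb
```
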
refactor(tree-native): position_ids flows through model.forward, not PackedSeqParams

Move per-token positional ids to the same level as other model.forward
args instead of hiding them inside PackedSeqParams:

- RotaryEmbedding.forward gains an explicit position_ids parameter
  (same pattern as MultimodalRotaryEmbedding). When provided, gathers
  emb by position_ids for non-sequential per-token RoPE.
- GPTModel._preprocess passes position_ids through to
  self.rotary_pos_emb(..., position_ids=position_ids).
- slime model.py: all three model() call sites changed from
  position_ids=None to position_ids=batch.get('position_ids'). For
  varlen batches, get() returns None (no behavior change). For tree
  batches, get_tree_batch already puts position_ids in the batch dict.
- PackedSeqParams.position_ids field removed — no longer needed.
- get_tree_batch: removed position_ids from PackedSeqParams construction.

Refs NVIDIA#220.
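
The call-site pattern this describes, sketched with assumed batch-dict
keys (only position_ids is the point; the other kwargs are
illustrative):

```python
def run_forward(model, batch, packed_seq_params):
    # Before this commit: position_ids=None at all three slime call sites.
    # After: dict.get returns None for varlen batches (no behavior change),
    # while tree batches carry the per-token positions set by get_tree_batch.
    return model(
        input_ids=batch["tokens"],              # key name assumed
        position_ids=batch.get("position_ids"),
        attention_mask=None,
        labels=batch.get("labels"),             # key name assumed
        packed_seq_params=packed_seq_params,
    )
```
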
refactor(tree-native): drop stray blank line in transformer_engine.py

Leftover from the tree-specific code removal in 8b600e4; net diff
vs sirl-dev base for this file is now empty.