feat(tree-native): add tree-attention metadata + position_ids surface for tree training #3
Open
racky-scitix wants to merge 8 commits into
Conversation
…etadata dataclass Stage 0.2 of the tree-training native migration. Carries trie topology (cu_node_lens, node_parent), per-token positional ids (tree_position_ids), and the FA3 precomputed attention buffers from slime's TreeDataIterator through PackedSeqParams to the upcoming TETreeDotProductAttention and TreePackedRotaryEmbedding. PackedSeqParams.tree_metadata defaults to None so all existing callers are untouched; a TYPE_CHECKING-guarded import keeps the new dataclass off the import path until tree training is actually used. Refs NVIDIA#220.
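For orientation, a minimal sketch of the surface this commit describes, using the field names from the message above; tensor shapes, dtypes, and the contents of precomputed are illustrative assumptions, not the actual diff:

```python
# Hedged sketch only — names from the commit message, shapes/dtypes assumed.
from dataclasses import dataclass, field
from typing import Any, Dict, Optional

import torch


@dataclass
class TreeMetadata:
    cu_node_lens: torch.Tensor        # cumulative node lengths of the packed trie (assumed int32, [num_nodes + 1])
    node_parent: torch.Tensor         # parent index per trie node, -1 for a root (assumed int32, [num_nodes])
    tree_position_ids: torch.Tensor   # logical (tree-depth) position of each packed token (assumed int64, [total_tokens])
    padded_size: int = 0
    num_nodes: int = 0
    precomputed: Dict[str, Any] = field(default_factory=dict)  # opaque FA3 buffers reused across layers


@dataclass
class PackedSeqParams:
    # ...existing varlen fields (cu_seqlens_q, cu_seqlens_kv, max_seqlen_q, ...)...
    tree_metadata: Optional[TreeMetadata] = None  # None => existing callers see no behavior change
```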
Stage 1.2 of the tree-training native migration:
- TreePackedRotaryEmbedding (rotary_pos_embedding.py): subclass of RotaryEmbedding that gathers freqs by packed_seq_params.tree_metadata.tree_position_ids when present. apply_rotary_pos_emb (TE side) stays untouched — it is pure elementwise, so once freqs are reordered the rest of the RoPE path is correct without TE modifications.
- AttnBackend.tree (enums.py): new enum value to advertise the tree attention path.
- TETreeDotProductAttention + _FA3TreeAttnFunc (extensions/transformer_engine.py): subclass of TEDotProductAttention that asserts cp_size == 1, forces qkv_format='thd', and routes through the FA3 tree kernel (flash_attn_interface.flash_attn_tree_func) directly. The direct kernel call is a transitional shape: once scitix/TransformerEngine adds TreeFlashAttention as a real TE backend (Stage 1.1), the kernel call here can be replaced by a TE backend dispatch.
- get_gpt_layer_tree_spec (gpt_layer_specs.py): GPT layer spec that swaps core_attention for TETreeDotProductAttention and advertises AttnMaskType.arbitrary so the upstream block does not synthesize an unused mask.
Refs NVIDIA#220.
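A rough sketch of the freqs-gather idea in the first bullet; the packed_seq_params argument and return shapes are assumptions, only TreePackedRotaryEmbedding and tree_position_ids come from the commit:

```python
# Hedged sketch, not the actual diff: reorder rotary freqs so each packed token
# receives the embedding of its logical (tree) position.
from megatron.core.models.common.embeddings.rotary_pos_embedding import RotaryEmbedding


class TreePackedRotaryEmbedding(RotaryEmbedding):
    def forward(self, max_seq_len, offset=0, packed_seq_params=None):
        emb = super().forward(max_seq_len, offset)   # [max_seq_len, 1, 1, dim]
        tree_metadata = getattr(packed_seq_params, 'tree_metadata', None)
        if tree_metadata is None:
            return emb                               # plain varlen path unchanged
        # Gather: token i gets the frequency of tree_position_ids[i], so the
        # downstream apply_rotary_pos_emb needs no changes (it is elementwise).
        return emb[tree_metadata.tree_position_ids]
```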
Now that scitix/TransformerEngine ships TreeFlashAttention as a first-class backend in DotProductAttention's dispatch (scitix/TE commit deb33ef7), the Megatron-side tree wrapper no longer needs its own FA3 autograd wrapper or kernel call.
- Remove _FA3TreeAttnFunc (moved into TE's backends.py).
- TETreeDotProductAttention.forward now unpacks packed_seq_params.tree_metadata into tree_cu_node_lens / tree_node_parent / tree_precomputed kwargs and calls te.pytorch.DotProductAttention.forward directly (skipping the Megatron-side TEDotProductAttention.forward wrapper, which would drop the tree kwargs). TE's dispatch routes to TreeFlashAttention when those kwargs are present.
Prerequisite for Stage 2 CP support: CP-aware tree attention will land inside TE's TreeFlashAttention.forward (via TE's existing cp_group/cp_comm_type infrastructure), and this wrapper will not need to change. Refs NVIDIA#220.
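Roughly, the transitional wrapper described above might look like the sketch below; the tree_* kwarg names come from the commit message and exist only in the scitix/TransformerEngine fork, and the abbreviated signature is an assumption:

```python
# Hedged sketch, not the actual diff. Assumes scitix/TE's DotProductAttention
# accepts the tree_* kwargs named in the commit message.
import transformer_engine as te

from megatron.core.extensions.transformer_engine import TEDotProductAttention


class TETreeDotProductAttention(TEDotProductAttention):
    def forward(self, query, key, value, attention_mask=None, packed_seq_params=None, **kwargs):
        tm = packed_seq_params.tree_metadata
        # Call TE's DotProductAttention.forward directly: the Megatron-side wrapper
        # would drop the tree kwargs, and TE's dispatch picks TreeFlashAttention
        # when they are present.
        return te.pytorch.DotProductAttention.forward(
            self, query, key, value,
            qkv_format='thd',
            cu_seqlens_q=packed_seq_params.cu_seqlens_q,
            cu_seqlens_kv=packed_seq_params.cu_seqlens_kv,
            tree_cu_node_lens=tm.cu_node_lens,
            tree_node_parent=tm.node_parent,
            tree_precomputed=tm.precomputed,
        )
```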
…l tree classes Tree routing is now fully data-driven: the same model spec, the same TEDotProductAttention instance, and the same RotaryEmbedding instance handle both tree and varlen micro-batches. When PackedSeqParams.tree_metadata is populated, both Megatron and TE route to the tree path automatically.
Deleted:
- TETreeDotProductAttention (Megatron) — tree detection folded into TEDotProductAttention.forward as a ~15-line early return
- TreePackedRotaryEmbedding (Megatron) — tree-position gather folded into RotaryEmbedding.forward as a ~12-line suffix
- get_gpt_layer_tree_spec (Megatron) — no longer needed
- model_provider.py tree spec selection + MoE block spec post-process + RotaryEmbedding __class__ swap — all deleted
This eliminates the 'MoE decoder block spec post-processing hack' gap from the Known Gaps list entirely. Refs NVIDIA#220.
…ention TE now accepts tree_metadata as a direct kwarg on DotProductAttention.forward (TE commit ebfa49a6). Megatron's packed_seq_kwargs dict already includes tree_metadata (it is a PackedSeqParams dataclass field), so it flows through to TE transparently via **packed_seq_kwargs. Delete the 15-line tree detection + manual TE dispatch block from TEDotProductAttention.forward — zero tree-specific code remains in Megatron's attention wrapper. Refs NVIDIA#220.
Decouple RotaryEmbedding from tree-training: instead of reading tree_metadata.tree_position_ids directly, RotaryEmbedding.forward now checks the generic packed_seq_params.position_ids field. This follows the same pattern as MultimodalRotaryEmbedding accepting position_ids.
Changes:
- PackedSeqParams gains position_ids: Optional[Tensor] = None
- RotaryEmbedding.forward gathers emb by position_ids when present — no 'tree' concept in the code, just 'custom per-token positions'
- slime get_tree_batch sets both position_ids and tree_metadata on PackedSeqParams
- TreeMetadata docstring updated: tree_position_ids is now consumed indirectly via PackedSeqParams.position_ids
The position_ids field is reusable by any future feature needing non-sequential per-token RoPE positions. Refs NVIDIA#220.
…PackedSeqParams
Move per-token positional ids to the same level as other model.forward
args instead of hiding them inside PackedSeqParams:
- RotaryEmbedding.forward gains an explicit position_ids parameter
(same pattern as MultimodalRotaryEmbedding). When provided, gathers
emb by position_ids for non-sequential per-token RoPE.
- GPTModel._preprocess passes position_ids through to
self.rotary_pos_emb(..., position_ids=position_ids).
- slime model.py: all three model() call sites changed from
position_ids=None to position_ids=batch.get('position_ids'). For
varlen batches, get() returns None (no behavior change). For tree
batches, get_tree_batch already puts position_ids in the batch dict.
- PackedSeqParams.position_ids field removed — no longer needed.
- get_tree_batch: removed position_ids from PackedSeqParams construction.
Refs NVIDIA#220.
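In final form, the position_ids path reduces to a gather over the normal RoPE frequencies. The helper below is an illustrative condensation of that gather (the bounds check mirrors the defensive check mentioned in the PR summary), not the diff itself:

```python
# Hedged sketch of the gather RotaryEmbedding.forward performs when position_ids is given.
from typing import Optional

import torch


def gather_rope_freqs(emb: torch.Tensor, position_ids: Optional[torch.Tensor]) -> torch.Tensor:
    """emb: [rotary_seq_len, 1, 1, dim] frequencies from the normal RoPE path."""
    if position_ids is None:
        return emb                                    # default behavior unchanged
    flat = position_ids.flatten()                     # per-token logical (tree) positions
    assert int(flat.max()) < emb.shape[0], "position_ids exceed rotary_seq_len"
    return emb[flat]                                  # [total_tokens, 1, 1, dim]


# slime call sites (per the commit): model(..., position_ids=batch.get('position_ids')):
# None for varlen batches (no behavior change), tree positions for tree batches.
```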
Leftover from the tree-specific code removal in 8b600e4; net diff vs sirl-dev base for this file is now empty.
Summary
Adds the minimal Megatron-LM surface that the slime-trainer tree-training path needs to flow tree-attention topology and per-token positional ids from the data iterator down to TE's TreeFlashAttention backend. All changes are opt-in: existing callers see no behavior change because every new field defaults to None and the new AttnBackend.tree enum is only selected by slime's tree path.
Scope (5 files, +91 / -2)
- megatron/core/transformer/tree_metadata.py (new) — TreeMetadata dataclass carrying cu_node_lens, node_parent, tree_position_ids, padded_size, num_nodes, and an opaque precomputed dict reused across layers. Lightweight, no CUDA assumptions at definition time.
- megatron/core/packed_seq_params.py — adds PackedSeqParams.tree_metadata: Optional[TreeMetadata] = None. Slime's TreeDataIterator populates it when --enable-tree-training is on; all existing callers keep None.
- megatron/core/transformer/enums.py — adds AttnBackend.tree = 6 for data-driven dispatch (TEDotProductAttention detects tree_metadata and routes to TE's TreeFlashAttention; cp_size == 1 required).
- megatron/core/models/common/embeddings/rotary_pos_embedding.py — RotaryEmbedding.forward accepts an optional position_ids tensor and, when provided, returns emb[position_ids] so each packed token gets the RoPE frequency for its logical (tree) position. Includes a defensive bounds check against rotary_seq_len. Default behavior unchanged.
- megatron/core/models/gpt/gpt_model.py — threads the existing position_ids argument into self.rotary_pos_emb(...) so the new RoPE path receives it.
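To make the opt-in shape concrete, a hypothetical caller-side sketch; the tensor values, the single-sequence packing, and the commented model call are illustrative, and only the class/field names come from the scope above:

```python
# Hypothetical usage sketch of the opt-in surface; values are illustrative.
import torch

from megatron.core.packed_seq_params import PackedSeqParams
from megatron.core.transformer.tree_metadata import TreeMetadata

# A tiny trie: a 3-token shared prefix (node 0) with two children of lengths 2 and 3.
tree_md = TreeMetadata(
    cu_node_lens=torch.tensor([0, 3, 5, 8], dtype=torch.int32),
    node_parent=torch.tensor([-1, 0, 0], dtype=torch.int32),
    tree_position_ids=torch.tensor([0, 1, 2, 3, 4, 3, 4, 5]),  # logical depth per packed token
)

packed_seq_params = PackedSeqParams(
    qkv_format='thd',
    cu_seqlens_q=torch.tensor([0, 8], dtype=torch.int32),
    cu_seqlens_kv=torch.tensor([0, 8], dtype=torch.int32),
    tree_metadata=tree_md,   # leave as None for plain varlen behavior
)

# Per-token logical positions also go to the model, so RoPE can gather by them:
# model(tokens, position_ids=tree_md.tree_position_ids.unsqueeze(0), attention_mask=None,
#       labels=labels, packed_seq_params=packed_seq_params)
```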
Design notes
- No Megatron-side TreeFlashAttention / TETreeDotProductAttention classes; those have been removed. The branch now keeps the spec untouched and dispatches purely on whether PackedSeqParams.tree_metadata is populated.
- Generic position_ids surface: RoPE stays generic — it does not know about tree topology. Whatever produces tree_position_ids (slime's TreeDataIterator today) is responsible for getting them into GPTModel.forward(position_ids=...).
- Opt-in throughout: every new field is optional with a None default.
Nonedefault.Commit history
8 commits on top of
sirl-dev:The history is intentionally kept in narrative order rather than squashed —
the later refactor commits show how the surface was simplified down to the
small, additive net diff above. Squashing on merge is fine.
Test plan
- git diff origin/sirl-dev sirl-dev-tree-native --shortstat: 5 files, +91 / -2; no upstream-Megatron public APIs removed.
- transformer_engine.py diff vs sirl-dev is empty (the lone blank-line touch from earlier was reverted by the last commit).
- Functional validation runs in slime-trainer (vendors/megatron-lm gitlink at 1813b6948); precision + benchmark validation lives in slime-trainer docs/engineering/tree-attention-{benchmark,precision}.md.
- Review ask: whether RotaryEmbedding.forward's position_ids defaulting to None is the right shape for upstream consumers.

🤖 Generated with Claude Code