diff --git a/areal/api/cli_args.py b/areal/api/cli_args.py index 7c53445722..7b57c5af3e 100644 --- a/areal/api/cli_args.py +++ b/areal/api/cli_args.py @@ -1170,10 +1170,42 @@ class TrainEngineConfig: metadata={"help": "peft method type. Only LoRA is supported for now."}, ) - # Tree training - enable_tree_training: bool = field( - default=False, - metadata={"help": "Enable tree training with flex attention module."}, + # Tree training (str, not Literal: OmegaConf.structured rejects Literal here) + tree_training_mode: str = field( + default="disabled", + metadata={ + "help": ( + "Tree training mode. " + "'sparse' enables tree training with Flex Attention module (flex attention), " + "'dta' enables Dynamic Tree Attention (dynamic tree training), " + "'disabled' disables tree training." + ), + "choices": ["disabled", "sparse", "dta"], + }, + ) + dta_block_size: int = field( + default=2048, + metadata={ + "help": ( + "Block size for Dynamic Tree Attention. " + "Set to -1 to disable block-size limit. " + "Only effective when tree_training_mode='dta'." + ) + }, + ) + packing_algorithm: str = field( + default="ffd", + metadata={ + "help": ( + "Trajectory packing across data-parallel ranks during distributed rollout " + "(``redistribute_trajectories``). " + "'ffd' / 'kk' balance by total sequence length; 'dta' uses DTA DFS-order " + "n_tree_tokens. " + "Not to be confused with ``mb_spec.packing_algorithm``, which only " + "controls micro-batch formation (ffd/kk) during training." + ), + "choices": ["ffd", "kk", "dta"], + }, ) # Scheduling @@ -1246,6 +1278,23 @@ def __post_init__(self): "memory_efficient_load is for loading pretrained weights on CPU, " "but init_from_scratch creates a model without loading any weights." ) + valid_tree_modes = {"disabled", "sparse", "dta"} + if self.tree_training_mode not in valid_tree_modes: + raise ValueError( + f"tree_training_mode must be one of {valid_tree_modes}, got '{self.tree_training_mode}'" + ) + valid_rollout_packing = {"ffd", "kk", "dta"} + if self.packing_algorithm not in valid_rollout_packing: + raise ValueError( + f"packing_algorithm (rollout) must be one of {valid_rollout_packing}, " + f"got '{self.packing_algorithm}'" + ) + if self.tree_training_mode == "dta": + if self.dta_block_size == 0 or self.dta_block_size < -1: + raise ValueError( + f"dta_block_size must be -1 or a positive integer when tree_training_mode='dta', got {self.dta_block_size}." + ) + if self._version not in ("v1", "v2"): raise ValueError( f"_version must be either 'v1' or 'v2', got '{self._version}'" @@ -1635,6 +1684,22 @@ def __post_init__(self): "Please set `actor.use_decoupled_loss=false` in your configuration." ) + if self.packing_algorithm == "dta": + for norm_name in ["adv_norm", "reward_norm"]: + norm_config = getattr(self, norm_name) + if norm_config is not None: + if ( + norm_config.mean_level == "group" + or norm_config.std_level == "group" + ): + raise ValueError( + f"{norm_name} uses 'group' level normalization, which is incompatible " + "with packing_algorithm='dta'. DTA requires sequence-level independence, " + "but 'group' normalization relies on contiguous group slices. Please use " + "'batch' level normalization or set packing_algorithm='ffd'. " + "(Group-level support for DTA will be provided in a future release.)" + ) + super().__post_init__() diff --git a/areal/engine/fsdp_engine.py b/areal/engine/fsdp_engine.py index b17bcf6967..10cbb4678b 100644 --- a/areal/engine/fsdp_engine.py +++ b/areal/engine/fsdp_engine.py @@ -262,9 +262,14 @@ def __init__(self, config: TrainEngineConfig): self.dp_rank: int self.is_offload: bool = False + self.tree_training_mode: str = self.config.tree_training_mode + if self.tree_training_mode == "dta": + raise ValueError( + "tree_training_mode='dta' is only supported by ArchonEngine. " + "Please use Archon backend or set tree_training_mode to 'disabled'/'sparse'." + ) self._offload_depth: int = 0 self._per_layer_optim_wrapper: PerLayerOptimWrapper | None = None - self.enable_tree_training: bool = self.config.enable_tree_training @classmethod def from_pretrained( @@ -384,7 +389,7 @@ def initialize(self, addr: str | None, ft_spec: FinetuneSpec, *args, **kwargs): # Create device model self._create_device_model() - if self.enable_tree_training and self.parallel_helper.sp_size > 1: + if self.tree_training_mode == "sparse" and self.parallel_helper.sp_size > 1: raise ValueError( "Tree training currently cannot be enabled with sp_size > 1." ) @@ -395,7 +400,7 @@ def initialize(self, addr: str | None, ft_spec: FinetuneSpec, *args, **kwargs): shard_vision_across_sp=self.config.fsdp.shard_vision_across_sp, ) # Monkey patch: replace attention's forward() with tree attention. - patch_fsdp_for_tree_training(enable=self.enable_tree_training) + patch_fsdp_for_tree_training(enable=self.tree_training_mode == "sparse") if self.config.use_lora: self._apply_peft_wrapper() @@ -733,7 +738,7 @@ def forward_backward_batch( # module_fsdp.py reads these keys from the **kwargs that transformers # forwards through. tree_attn_keys: list[str] = [] - if self.enable_tree_training and ctx.trie_node is not None: + if self.tree_training_mode == "sparse" and ctx.trie_node is not None: padded_size = mb_item.padded_to_length assert padded_size is not None tree_kwargs = build_tree_attn_kwargs( @@ -881,8 +886,8 @@ def process_output(logits: torch.Tensor, ctx_dict: dict[str, Any]) -> None: self.forward_backward_batch(mb_list, process_output, forward_only=True) # Step 4: Aggregate and reorder outputs - if self.enable_tree_training: - result = merge_packed_tree_results(outputs, batch_size) + if self.tree_training_mode == "sparse": + return merge_packed_tree_results(outputs, batch_size) else: result = reorder_and_pad_outputs( outputs, output_seqlens, mb_list, aggregate_fn @@ -1794,7 +1799,7 @@ def _prepare_mb_list(self, input_: dict[str, Any]) -> MicroBatchList: input_ = input_.copy() # Tree training path - if self.enable_tree_training: + if self.tree_training_mode == "sparse": mb_list = build_packed_tree_batch( input_, mb_spec=self.config.mb_spec, @@ -2063,12 +2068,12 @@ def _compute_logprobs_and_loss( if local_weight == 0: return logits.mean() * 0.0 - if self.config.is_critic and self.enable_tree_training: + if self.config.is_critic and self.tree_training_mode == "sparse": raise NotImplementedError( "Tree training with critic model is not supported yet." ) if not self.config.is_critic: - if self.enable_tree_training: + if self.tree_training_mode == "sparse": # Handle dummy trie (empty tree for DP synchronization) # When trie has no sequences, return zero loss with grad connection if ctx.trie_node is None or not ctx.trie_node.all_sequence_ids: @@ -2126,12 +2131,12 @@ def _compute_forward_result( ctx: FSDPTrainContext, ) -> torch.Tensor | dict[int, torch.Tensor]: """Compute forward output (logprobs or values).""" - if self.config.is_critic and self.enable_tree_training: + if self.config.is_critic and self.tree_training_mode == "sparse": raise NotImplementedError( "Tree training with critic model is not supported yet." ) if not self.config.is_critic: - if self.enable_tree_training: + if self.tree_training_mode == "sparse": result = _gather_packed_tree_logprobs( logits, ctx.trie_node, diff --git a/areal/engine/megatron_engine.py b/areal/engine/megatron_engine.py index f972fd0834..c653883b30 100644 --- a/areal/engine/megatron_engine.py +++ b/areal/engine/megatron_engine.py @@ -197,8 +197,13 @@ def __init__(self, config: TrainEngineConfig): self.seed: int = 0 self.own_global_group: bool = False self.is_offload: bool = False + self.tree_training_mode: str = self.config.tree_training_mode + if self.tree_training_mode == "dta": + raise ValueError( + "tree_training_mode='dta' is only supported by ArchonEngine. " + "Please use Archon backend or set tree_training_mode to 'disabled'/'sparse'." + ) self._offload_depth: int = 0 - self.enable_tree_training: bool = self.config.enable_tree_training # FP8 configuration self.fp8_config = self.mcore_config.fp8_config self.enable_fp8: bool = self.fp8_config is not None @@ -331,7 +336,7 @@ def initialize(self, addr: str | None, ft_spec: FinetuneSpec, *args, **kwargs): self.tokenizer = load_hf_tokenizer(self.config.path) with patch_bridge_for_tree_training( - self.enable_tree_training and self.bridge_cls == "mbridge" + self.tree_training_mode == "sparse" and self.bridge_cls == "mbridge" ): self.bridge = self._build_hf_mcore_bridge() @@ -530,7 +535,7 @@ def _build_hf_mcore_bridge(self): ) elif self.bridge_cls == "megatron-bridge": - if self.enable_tree_training: + if self.tree_training_mode == "sparse": raise NotImplementedError( "Tree training is not supported with bridge_type='megatron-bridge'." ) @@ -819,7 +824,7 @@ def forward_step(batch_iter, model): # save_for_backward() which can only save torch.Tensor objects; # BlockMask is recreated inside PytorchFlexAttention.forward(). tree_attn_keys: list[str] = [] - if self.enable_tree_training: + if self.tree_training_mode == "sparse": trie_node = mb_input.padded_mb.get("trie_node", None) # Ensure trie_node is also in orig_mb for _compute_logprobs_and_loss if trie_node is not None and "trie_node" not in mb_input.orig_mb: @@ -1046,7 +1051,7 @@ def process_output(output: torch.Tensor, inputs: dict[str, Any]) -> None: # Step 4: Aggregate, reorder, and broadcast outputs res = None if mpu.is_pipeline_last_stage(): - if self.enable_tree_training: + if self.tree_training_mode == "sparse": res = merge_packed_tree_results(outputs, batch_size) else: res = reorder_and_pad_outputs( @@ -1926,7 +1931,7 @@ def _prepare_mb_list(self, input_: dict[str, Any]) -> MicroBatchList: pp_size = self.parallel_strategy.pipeline_parallel_size cp_size = self.parallel_strategy.context_parallel_size tp_size = self.parallel_strategy.tensor_parallel_size - if self.enable_tree_training: + if self.tree_training_mode == "sparse": assert cp_size == 1, ( "Context parallelism is not supported in tree training." ) @@ -2036,12 +2041,12 @@ def _compute_logprobs_and_loss( if local_weight == 0: return output.mean() * 0.0 - if self.config.is_critic and self.enable_tree_training: + if self.config.is_critic and self.tree_training_mode == "sparse": raise NotImplementedError( "Tree training with critic model is not supported yet." ) if not self.config.is_critic: - if self.enable_tree_training: + if self.tree_training_mode == "sparse": # Handle dummy trie (empty tree for DP synchronization) # When trie has no sequences, return zero loss with grad connection trie_node = inputs.get("trie_node") @@ -2144,12 +2149,12 @@ def _compute_forward_result( output: torch.Tensor, inputs: dict[str, Any], ) -> torch.Tensor | dict[int, torch.Tensor]: - if self.config.is_critic and self.enable_tree_training: + if self.config.is_critic and self.tree_training_mode == "sparse": raise NotImplementedError( "Tree training with critic model is not supported yet." ) if not self.config.is_critic: - if self.enable_tree_training: + if self.tree_training_mode == "sparse": logprobs = _gather_packed_tree_logprobs( output, inputs["trie_node"], diff --git a/areal/experimental/dta/allocation.py b/areal/experimental/dta/allocation.py new file mode 100644 index 0000000000..b86cb762cf --- /dev/null +++ b/areal/experimental/dta/allocation.py @@ -0,0 +1,123 @@ +# SPDX-License-Identifier: Apache-2.0 + +from dataclasses import dataclass +from types import SimpleNamespace +from typing import Any + +import torch + +from areal.experimental.dta.dp import DTAPartitionResult, LB_by_DFS_and_TM +from areal.utils.data import extract_single_valid_token_sequence + + +class _TreeTokenOnlyTimeModel: + def pred(self, stats: dict[str, Any]) -> float: + return float(stats["n_tree_tokens"]) + + +@dataclass(slots=True) +class DTAMetrics: + n_tokens: float + n_tree_tokens_before_allocation: float + n_tree_tokens_after_allocation: float + compression_ratio_before_allocation: float + compression_ratio_after_allocation: float + + def to_stats(self) -> dict[str, float]: + return { + "dta/n_tokens": self.n_tokens, + "dta/n_tree_tokens_before_allocation": self.n_tree_tokens_before_allocation, + "dta/n_tree_tokens_after_allocation": self.n_tree_tokens_after_allocation, + "dta/compression_ratio_before_allocation": self.compression_ratio_before_allocation, + "dta/compression_ratio_after_allocation": self.compression_ratio_after_allocation, + } + + +@dataclass(slots=True) +class DTAAllocationResult: + items: list[dict[str, Any]] + group_indices: list[list[int]] + metrics: DTAMetrics + + +def _extract_token_sequences( + trajectories: list[dict[str, Any]], +) -> list[torch.Tensor]: + token_seqs: list[torch.Tensor] = [] + for idx, trajectory in enumerate(trajectories): + try: + seq = extract_single_valid_token_sequence(trajectory) + except (TypeError, ValueError) as err: + raise ValueError( + f"Invalid trajectory format at index {idx} for DTA partitioning." + ) from err + token_seqs.append(seq) + return token_seqs + + +def allocate_dta_trajectories( + trajectories: list[dict[str, Any]], n_groups: int +) -> DTAAllocationResult: + """Prepare sequence-level DTA trajectories and allocate them across DP groups.""" + from areal.utils.data import unpack_groups_to_sequences + + items = unpack_groups_to_sequences(trajectories) + token_seqs = _extract_token_sequences(items) + config = SimpleNamespace(K=n_groups, mode="backward", block_size=None) + partition = LB_by_DFS_and_TM(token_seqs, _TreeTokenOnlyTimeModel(), config) + return DTAAllocationResult( + items=items, + group_indices=partition.bins, + metrics=_compute_dta_metrics_from_partition(partition), + ) + + +def split_dta_allocation( + allocation: DTAAllocationResult, +) -> list[list[dict[str, Any]]]: + return [ + [allocation.items[idx] for idx in group_indices] + for group_indices in allocation.group_indices + ] + + +def _compute_dta_metrics_from_partition(partition: DTAPartitionResult) -> DTAMetrics: + all_stats = partition.token_trie.get_stats(mode="backward") + n_total_tokens = float(all_stats["n_tokens"]) + n_tree_tokens_before = float(all_stats["n_tree_tokens"]) + + n_tree_tokens_after = 0.0 + for leaf_group in partition.leaf_bins: + if not leaf_group: + continue + group_trie = partition.compressed_trie.get_subtrie(set(leaf_group)) + group_stats = group_trie.get_stats(mode="backward") + n_tree_tokens_after += float(group_stats["n_tree_tokens"]) + + return _make_dta_metrics( + n_total_tokens=n_total_tokens, + n_tree_tokens_before=n_tree_tokens_before, + n_tree_tokens_after=n_tree_tokens_after, + ) + + +def _make_dta_metrics( + n_total_tokens: float, n_tree_tokens_before: float, n_tree_tokens_after: float +) -> DTAMetrics: + compression_ratio_before = ( + n_total_tokens / n_tree_tokens_before + if n_tree_tokens_before > 0 + else float("nan") + ) + compression_ratio_after = ( + n_total_tokens / n_tree_tokens_after + if n_tree_tokens_after > 0 + else float("nan") + ) + return DTAMetrics( + n_tokens=n_total_tokens, + n_tree_tokens_before_allocation=n_tree_tokens_before, + n_tree_tokens_after_allocation=n_tree_tokens_after, + compression_ratio_before_allocation=compression_ratio_before, + compression_ratio_after_allocation=compression_ratio_after, + ) diff --git a/areal/experimental/dta/dp.py b/areal/experimental/dta/dp.py new file mode 100644 index 0000000000..2c621de35d --- /dev/null +++ b/areal/experimental/dta/dp.py @@ -0,0 +1,170 @@ +# SPDX-License-Identifier: Apache-2.0 + +# This code is adapted with minor modifications from +# https://github.com/Whisper-6/DynamicTreeAttn/blob/main/data_parallel.py. +# Special thanks to Yuchen Yang for significant contributions to the load-balanced data parallel partitioning algorithm. +from dataclasses import dataclass +from types import SimpleNamespace + +from areal.experimental.dta.token_trie import TokenTrie +from areal.experimental.dta.trie import CompressedTrie + + +@dataclass(slots=True) +class DTAPartitionResult: + """DTA partition result plus trie state reusable by callers.""" + + bins: list[list[int]] + leaf_bins: list[list[int]] + token_trie: TokenTrie + compressed_trie: CompressedTrie + + +def LB_by_n_tokens(token_seqs, K): + bins = [[] for _ in range(K)] + bin_lens = [0] * K + seq_indices = sorted(range(len(token_seqs)), key=lambda i: -len(token_seqs[i])) + for i in seq_indices: + min_bin = min(range(K), key=lambda j: bin_lens[j]) + bins[min_bin].append(i) + bin_lens[min_bin] += len(token_seqs[i]) + return bins + + +def pred_time( + compressed_trie, time_model, mode: str, block_size: int | None = None +) -> float: + stats = compressed_trie.get_stats(mode, block_size) + return time_model.pred(stats) + + +def get_original_bins( + token_trie: TokenTrie, leaf_bins: list[list[int]] +) -> list[list[int]]: + bins = [[] for _ in range(len(leaf_bins))] + for bucket_idx, leaf_bucket in enumerate(leaf_bins): + for leaf_idx in leaf_bucket: + attach_lists = token_trie.attach_lists[leaf_idx] + for attach, _ in attach_lists: + original_seq_idx = attach["_sequence_batch_id"] + bins[bucket_idx].append(original_seq_idx) + return bins + + +def LB_by_TM(token_seqs, time_model, config: SimpleNamespace): + token_trie = TokenTrie(token_seqs) + n_leaf_seqs = len(token_trie.inputs) + compressed_trie = CompressedTrie(token_trie.lens, token_trie.lcp_lens) + + K = config.K + leaf_bins = [[] for _ in range(K)] + bin_times = [0.0] * K + + for i in range(n_leaf_seqs): + min_bin = min(range(K), key=lambda j: bin_times[j]) + leaf_bins[min_bin].append(i) + bin_compressed_trie = compressed_trie.get_subtrie(set(leaf_bins[min_bin])) + bin_times[min_bin] = pred_time( + bin_compressed_trie, time_model, config.mode, config.block_size + ) + + bins = get_original_bins(token_trie, leaf_bins) + return bins + + +def try_divide( + compressed_trie, + n_seqs, + config: SimpleNamespace, + divL, + divR, + time_model, + cost_limit: float, +) -> list[list[int]] | None: + K = config.K + divs = [] + + start = 0 + while start < n_seqs: + divs.append(start) + if len(divs) > K: + break + L = max(divL[len(divs)] - 1, start) + R = divR[len(divs)] - 1 + while L < R: + mid = (L + R + 1) // 2 + cur_subtrie = compressed_trie.get_subtrie(set(range(start, mid + 1))) + est_time = pred_time( + cur_subtrie, time_model, config.mode, config.block_size + ) + if est_time <= cost_limit: + L = mid + else: + R = mid - 1 + start = L + 1 + + return divs + + +def LB_by_DFS_and_TM( + token_seqs, time_model, config: SimpleNamespace +) -> DTAPartitionResult: + token_trie = TokenTrie(token_seqs) + n_leaf_seqs = len(token_trie.inputs) + compressed_trie = CompressedTrie(token_trie.lens, token_trie.lcp_lens) + K = config.K + if n_leaf_seqs == 0: + leaf_bins = [[] for _ in range(K)] + return DTAPartitionResult( + bins=[[] for _ in range(K)], + leaf_bins=leaf_bins, + token_trie=token_trie, + compressed_trie=compressed_trie, + ) + + R = float(pred_time(compressed_trie, time_model, config.mode, config.block_size)) + L = R / K + eps = R * 1e-4 + + divL = [0] * (K + 1) + # Maintain a valid initial partition boundary so K==1 (L==R) does not + # skip the search and accidentally produce empty bins. + divR = [0] + [n_leaf_seqs] * K + + while R - L > eps: + mid = (L + R) / 2.0 + divs = try_divide( + compressed_trie, n_leaf_seqs, config, divL, divR, time_model, mid + ) + if len(divs) <= K: + R = mid + divR[: len(divs)] = divs + else: + L = mid + eps + divL = divs[: K + 1] + + leaf_bins = [list(range(divR[i], divR[i + 1])) for i in range(K)] + bins = get_original_bins(token_trie, leaf_bins) + return DTAPartitionResult( + bins=bins, + leaf_bins=leaf_bins, + token_trie=token_trie, + compressed_trie=compressed_trie, + ) + + +# -------- Test -------- + + +def eval(token_seqs, bins, time_model, config: SimpleNamespace): + total_time = 0.0 + max_time = 0.0 + for bucket in bins: + token_trie = TokenTrie([token_seqs[i] for i in bucket]) + compressed_trie = CompressedTrie(token_trie.lens, token_trie.lcp_lens) + bucket_pred_time = pred_time( + compressed_trie, time_model, config.mode, config.block_size + ) + total_time += bucket_pred_time + max_time = max(max_time, bucket_pred_time) + return total_time, max_time diff --git a/areal/experimental/dta/rollout.py b/areal/experimental/dta/rollout.py new file mode 100644 index 0000000000..c641eb235a --- /dev/null +++ b/areal/experimental/dta/rollout.py @@ -0,0 +1,13 @@ +# SPDX-License-Identifier: Apache-2.0 + +from __future__ import annotations + +from typing import Any + +from areal.infra.rpc.rtensor import RTensor +from areal.utils.data import unpack_groups_to_sequences + + +def prepare_dta_rollout_batch(batch: list[dict[str, Any]]) -> list[dict[str, Any]]: + """Normalize rollout trajectories into sequence-level items for DTA.""" + return unpack_groups_to_sequences(RTensor.localize(batch)) diff --git a/areal/experimental/dta/runner.py b/areal/experimental/dta/runner.py new file mode 100644 index 0000000000..feb4f456d4 --- /dev/null +++ b/areal/experimental/dta/runner.py @@ -0,0 +1,778 @@ +# SPDX-License-Identifier: Apache-2.0 + +# The following code is adapted with minor modifications from +# https://github.com/Whisper-6/DynamicTreeAttn/blob/main/tree_training_engine.py. +# Special thanks to Yuchen Yang for outstanding contributions to core DTA algorithms +# and optimizations, including chunked backpropagation and cut tail features. + +from bisect import bisect_left, bisect_right +from math import ceil +from typing import NoReturn + +import torch +import torch.nn.functional as F +from transformers.cache_utils import DynamicCache + +from areal.utils.functional import gather_logprobs, gather_logprobs_entropy +from areal.utils.logging import getLogger + +NO_BLOCK_SIZE_LIMIT = int(1e9) + + +def _get_forkpos(lens, lcp_lens, block_size: int | None) -> list: + """ + Compute all fork positions that DTARunner's stack must track. + + Fork positions are token indices where: + 1) Sequences diverge (longest common prefix boundaries) + 2) Block boundaries for long sequences to reduce memory usage + + Returns a sorted list of unique fork positions. + """ + + forkpos_list = [] + + # 1. Fork positions induced by branching (LCP boundaries) + for lcp in lcp_lens: + if lcp > 0: + forkpos_list.append(lcp - 1) + + # 2. Fork positions induced by block segmentation + if block_size is not None: + for i in range(len(lens)): + start = 0 if i == len(lcp_lens) else lcp_lens[i] + end = lens[i] + + pop_len = end - start + n_blocks = ceil(pop_len / block_size) + block_size_actual = ceil(pop_len / n_blocks) + + for b in range(n_blocks): + pop_start = max(end - (b + 1) * block_size_actual, start) + if pop_start > 0: + forkpos_list.append(pop_start - 1) + + forkpos_list = list(set(forkpos_list)) + forkpos_list.sort() + + return forkpos_list + + +class DTARunner: + """ + Engine for backward computation over sequences with shared prefixes. + + DTARunner stores only necessary KV caches, logits at fork + positions, log-probs, and entropy to efficiently compute gradients + for multiple sequences while saving memory. + + Supports block-wise popping to reduce GPU memory peak. + """ + + def __init__( + self, + model_config, + device, + dtype: torch.dtype, + max_seq_len: int, + forward_only: bool = False, + is_critic: bool = False, + ): + """ + Initialize DTARunner with model config, device and buffer sizes. + + Buffers for tokens, logprobs, entropy and KV caches are preallocated + to max_seq_len. + """ + self.model = None + self.device = device + self.dtype = dtype + self.max_seq_len = max_seq_len + self.is_critic = is_critic + + # ------------------------------------------------------------------------ + # Initialize static stack buffers + # ------------------------------------------------------------------------ + self.cur_len = 0 + + # Token buffer + self.tokens = torch.zeros((max_seq_len), device=self.device, dtype=torch.long) + + if self.is_critic: + # Value buffer for critic + self.values = torch.zeros( + (max_seq_len), device=self.device, dtype=torch.float32 + ) + if not forward_only: + self.grad_values = torch.zeros( + (max_seq_len), device=self.device, dtype=dtype + ) + else: + # Entropy buffer + if not forward_only: + self.entropy = torch.zeros( + (max_seq_len), device=self.device, dtype=torch.float32 + ) + self.grad_entropy = torch.zeros( + (max_seq_len), device=self.device, dtype=dtype + ) + + # Logprob buffer + self.logprobs = torch.zeros( + (max_seq_len), device=self.device, dtype=torch.float32 + ) + if not forward_only: + self.grad_logprobs = torch.zeros( + (max_seq_len), device=self.device, dtype=dtype + ) + + # Fork position logits buffer (store logits only at fork positions, others are None) + self.forkpos_list = [] # List of all fork positions + self.forkpos_logits: list[torch.Tensor | None] = [ + None + ] * max_seq_len # Logits at fork positions for computing logprobs + if not forward_only: + self.grad_forkpos_logits: list[torch.Tensor | None] = [ + None + ] * max_seq_len # Gradients of logits at fork positions + + # Attachments buffer + self.attachs = [] # List of sequences retained in the stack, including (attachments, length) + + # KV cache buffers + self.n_layers = model_config.num_hidden_layers + n_kv_heads = model_config.num_key_value_heads + # Compatible with Qwen2.5 and Qwen3 series + head_dim = ( + model_config.head_dim + if hasattr(model_config, "head_dim") + else model_config.hidden_size // model_config.num_attention_heads + ) + + kv_buffer_shape = (1, n_kv_heads, max_seq_len, head_dim) + + self.kv_cache = ( + [ + torch.zeros(kv_buffer_shape, device=self.device, dtype=dtype) + for _ in range(self.n_layers) + ], + [ + torch.zeros(kv_buffer_shape, device=self.device, dtype=dtype) + for _ in range(self.n_layers) + ], + ) + + if not forward_only: + self.grad_kv = ( + [ + torch.zeros(kv_buffer_shape, device=self.device, dtype=dtype) + for _ in range(self.n_layers) + ], + [ + torch.zeros(kv_buffer_shape, device=self.device, dtype=dtype) + for _ in range(self.n_layers) + ], + ) + + self.ret_logprobs = [] + + self._dta_log = getLogger("DTA") + + def _dta_fail(self, message: str) -> NoReturn: + text = f"[DTA] {message}" + self._dta_log.error("%s", text) + raise RuntimeError(text) + + def get_forkpos(self, start: int, end: int) -> list[int]: + """ + Yield fork positions within the interval [start, end). + + Uses binary search on precomputed forkpos_list. + """ + + left = bisect_left(self.forkpos_list, start) + right = bisect_right(self.forkpos_list, end - 1) + yield from self.forkpos_list[left:right] + + @torch.no_grad() + def push_forward_only( + self, + new_tokens: torch.LongTensor, + attach_list: list[tuple[dict, int]], + ): + """ + Push new tokens into the stack with their attachments. + + Builds cache (KV, logprobs) up to cache_len. + Updates logprobs for the previous token. + + Used in inference mode only. + """ + + B = new_tokens.numel() + if self.cur_len + B > self.max_seq_len: + self._dta_fail( + "Exceeds max_seq_len: " + f"cur_len={self.cur_len}, new_tokens={B}, max={self.max_seq_len}" + ) + if B == 0: + for attachment, length in attach_list: + seq_id = attachment["_sequence_batch_id"] + if length == 0: + self.returns[seq_id] = torch.empty( + 0, device=self.device, dtype=torch.float32 + ) + else: + logprobs = self.logprobs[: length - 1] + self.returns[seq_id] = logprobs.clone() + return + + start, end = self.cur_len, self.cur_len + B + + # ------------------------------------------------------------- + # 1. Build prefix cache from existing KV + # ------------------------------------------------------------- + prefix_cache = DynamicCache() + for layer_idx in range(self.n_layers): + prefix_cache.update( + self.kv_cache[0][layer_idx][:, :, :start, :], + self.kv_cache[1][layer_idx][:, :, :start, :], + layer_idx=layer_idx, + ) + + # ------------------------------------------------------------- + # 2. Forward + # ------------------------------------------------------------- + out = self.model( + new_tokens.unsqueeze(0), + past_key_values=prefix_cache, + use_cache=True, + ) + + # Compute logprobs and entropy for new tokens + logits = out.logits # [1, B, vocab] or [1, B, 1] + + # ------------------------------------------------------------- + # 3. Write tokens, computed logprobs/values, and KV cache into stack + # ------------------------------------------------------------- + + # Write tokens into stack + self.tokens[start:end] = new_tokens + + # Write KV cache into stack + new_cache = out.past_key_values + for layer_idx, layer in enumerate(new_cache.layers): + self.kv_cache[0][layer_idx][:, :, start:end, :] = layer.keys[ + :, :, start:end, : + ] + self.kv_cache[1][layer_idx][:, :, start:end, :] = layer.values[ + :, :, start:end, : + ] + + if self.is_critic: + values = logits.squeeze(0).squeeze(-1) + self.values[start:end] = values + + # ------------------------------------------------------------- + # 4. Store values for sequences ending in attach_list + # ------------------------------------------------------------- + for attachment, length in attach_list: + seq_id = attachment["_sequence_batch_id"] + if length == 0: + self.returns[seq_id] = torch.empty( + 0, device=self.device, dtype=torch.float32 + ) + continue + self.returns[seq_id] = self.values[:length].clone() + else: + logprobs = gather_logprobs( + logits=logits, + labels=new_tokens[1:].unsqueeze(0), + ) + + # Write logprobs into stack + self.logprobs[start : end - 1] = logprobs.squeeze(0) + # Fill the logprob of the first token using self.forkpos_logits[start] + if start > 0: + pre_logits = self.forkpos_logits[start - 1].float() + first_token = new_tokens[0].item() + pre_logprob = F.log_softmax(pre_logits, dim=-1)[first_token].item() + self.logprobs[start - 1] = pre_logprob + + # Write logits into stack (fork positions only) + forkpos_slice = self.get_forkpos(start, end) + for i in forkpos_slice: + self.forkpos_logits[i] = logits[0, i - start].detach().clone() + + # ------------------------------------------------------------- + # 4. Store logprobs for sequences ending in attach_list + # ------------------------------------------------------------- + for attachment, length in attach_list: + seq_id = attachment["_sequence_batch_id"] + if length == 0: + self.returns[seq_id] = torch.empty( + 0, device=self.device, dtype=torch.float32 + ) + continue + logprobs = self.logprobs[: length - 1] + self.returns[seq_id] = logprobs.clone() + + self.cur_len += B + + def build_cache(self, start: int, end: int): + """ + Build KV cache, logprobs and entropy for tokens in [start, end). + Uses the existing prefix cache [0, start). + """ + + # Build prefix cache from existing KV + prefix_cache = DynamicCache() + for layer_idx in range(self.n_layers): + prefix_cache.update( + self.kv_cache[0][layer_idx][:, :, :start, :], + self.kv_cache[1][layer_idx][:, :, :start, :], + layer_idx=layer_idx, + ) + + # Forward pass to compute new KV + out = self.model( + self.tokens[start:end].unsqueeze(0), + past_key_values=prefix_cache, + use_cache=True, + ) + + # Compute logprobs & entropy for new tokens + logits = out.logits # [1, B, vocab] or [1, B, 1] + + # Write new KV cache into stack + new_cache = out.past_key_values + for layer_idx, layer in enumerate(new_cache.layers): + self.kv_cache[0][layer_idx][:, :, start:end, :] = layer.keys[ + :, :, start:end, : + ] + self.kv_cache[1][layer_idx][:, :, start:end, :] = layer.values[ + :, :, start:end, : + ] + + if self.is_critic: + values = logits.squeeze(0).squeeze(-1) + self.values[start:end] = values + else: + logprobs, entropy = gather_logprobs_entropy( + logits=logits, + labels=self.tokens[start + 1 : end].unsqueeze(0), + ) + self.logprobs[start : end - 1] = logprobs.squeeze(0) + self.entropy[start:end] = entropy.squeeze(0) + + # Write logits into stack (fork positions only) + forkpos_slice = self.get_forkpos(start, end) + for i in forkpos_slice: + self.forkpos_logits[i] = logits[0, i - start].detach().clone() + + @torch.no_grad() + def push( + self, + new_tokens: torch.LongTensor, + attachs: list[tuple[dict, int]], + cache_len: int, + ): + """ + Push new tokens into the stack with their attachments. + + Builds cache (KV, logprobs, entropy) up to cache_len. + Updates logprobs for the previous token. + """ + + B = new_tokens.numel() + if self.cur_len + B > self.max_seq_len: + self._dta_fail( + "Exceeds max_seq_len: " + f"cur_len={self.cur_len}, new_tokens={B}, max={self.max_seq_len}" + ) + + start, end = self.cur_len, self.cur_len + B + + # Add attachments + for attachment, length in attachs: + self.attachs.append((attachment, length)) + + # Write tokens + self.tokens[start:end] = new_tokens + + # Build prefix cache (KV & logprobs/entropy) if needed + if start < cache_len: + self.build_cache(start, cache_len) + + # Update the previous token's logprob. + if not self.is_critic and start > 0: + pre_logits = self.forkpos_logits[start - 1].float() + first_token = new_tokens[0].item() + pre_logprob = F.log_softmax(pre_logits, dim=-1)[first_token].item() + self.logprobs[start - 1] = pre_logprob + + self.cur_len = end + + def pop(self, start: int, loss_fn) -> float: + """ + Pop tokens from position `start` to the current end. + + Computes gradients for the popped tokens and accumulates them + into the stack's KV, logprobs, entropy, and fork position logits buffers. + + Args: + start: The starting token index to pop from. + loss_fn: Callable that computes the loss for a sequence segment. + + Returns: + The total loss computed over sequences ending within the popped segment. + """ + if not (0 <= start < self.cur_len): + self._dta_fail(f"Invalid pop start: start={start}, cur_len={self.cur_len}") + + end = self.cur_len + _ = end - start + + tokens_to_pop = self.tokens[start:end] + + # --------------------------------------------------------------------------------- + # 1. Gather prefix KV (with requires_grad=True) + # --------------------------------------------------------------------------------- + prefix_cache = DynamicCache() + prefix_kv = [] + + for layer_idx in range(self.n_layers): + k = ( + self.kv_cache[0][layer_idx][:, :, :start, :] + .detach() + .requires_grad_(True) + ) + v = ( + self.kv_cache[1][layer_idx][:, :, :start, :] + .detach() + .requires_grad_(True) + ) + prefix_cache.update(k, v, layer_idx=layer_idx) + prefix_kv.append((k, v)) + + # --------------------------------------------------------------------------------- + # 2. Forward pass on tokens_to_pop (builds computation graph) + # --------------------------------------------------------------------------------- + out = self.model( + tokens_to_pop.unsqueeze(0), past_key_values=prefix_cache, use_cache=True + ) + + logits = out.logits + block_cache = out.past_key_values + + # --------------------------------------------------------------------------------- + # 3. Compute suffix logprobs & entropy or values + # --------------------------------------------------------------------------------- + if self.is_critic: + suf_values = logits.squeeze(0).squeeze(-1) + else: + suf_logprobs, suf_entropy = gather_logprobs_entropy( + logits=logits, labels=tokens_to_pop[1:].unsqueeze(0) + ) + suf_entropy = suf_entropy.squeeze(0) + suf_logprobs = suf_logprobs.squeeze(0) + + # Compute logprob for connection to previous token if exists + if start > 0: + mid_logits = ( + self.forkpos_logits[start - 1].float().detach().requires_grad_(True) + ) + mid_label = self.tokens[start].item() + mid_logprob = F.log_softmax(mid_logits, dim=-1)[mid_label].unsqueeze(0) + + # --------------------------------------------------------------------------------- + # 4. Compute loss for sequences ending in this block + # --------------------------------------------------------------------------------- + + # Gather attachs for sequences ending in this block + attachs_in_block = [ + (att, length) for att, length in self.attachs if start < length <= end + ] + + if attachs_in_block: + if self.is_critic: + if start > 0: + pre_values = self.values[:start].detach().requires_grad_(True) + values = torch.cat([pre_values, suf_values], dim=0) + else: + values = suf_values + + # Compute loss + loss = 0.0 + for attachment, length in attachs_in_block: + if length == 0: + continue + loss += loss_fn(values[:length], attachment) + else: + # Concatenate full logprobs and entropy, with requires_grad=True + if start > 0: + pre_entropy = self.entropy[:start].detach().requires_grad_(True) + entropys = torch.cat([pre_entropy, suf_entropy], dim=0) + if start > 1: + pre_logprobs = ( + self.logprobs[: start - 1].detach().requires_grad_(True) + ) + logprobs = torch.cat( + [pre_logprobs, mid_logprob, suf_logprobs], dim=0 + ) + else: + logprobs = torch.cat([mid_logprob, suf_logprobs], dim=0) + else: + entropys = suf_entropy + logprobs = suf_logprobs + + # Compute loss + loss = 0.0 + for attachment, length in attachs_in_block: + if length == 0: + continue + loss += loss_fn( + logprobs[: length - 1], entropys[:length], attachment + ) + + # --------------------------------------------------------------------------------- + # 5. Backward with gradient injection from popped tokens + # (to KV, logprobs, entropy, forkpos-logits) + # --------------------------------------------------------------------------------- + roots, grads = [], [] + + # Loss gradient + if attachs_in_block: + roots.append(loss) + grads.append(torch.tensor(1.0, device=self.device, dtype=loss.dtype)) + + # KV gradients from popped tokens + for layer_idx, layer in enumerate(block_cache.layers): + k = layer.keys[:, :, start:end, :] + v = layer.values[:, :, start:end, :] + roots.extend([k, v]) + grads.extend( + [ + self.grad_kv[0][layer_idx][:, :, start:end, :], + self.grad_kv[1][layer_idx][:, :, start:end, :], + ] + ) + + if self.is_critic: + roots.append(suf_values) + grads.append(self.grad_values[start:end]) + else: + # Logprobs & entropy gradients from popped tokens + roots.extend([suf_logprobs, suf_entropy]) + grads.extend( + [self.grad_logprobs[start : end - 1], self.grad_entropy[start:end]] + ) + if start > 0: + roots.append(mid_logprob) + grad_mid_logprob = self.grad_logprobs[start - 1].unsqueeze(0) + grads.append(grad_mid_logprob) + + # Fork position logits gradients + forkpos_slice = self.get_forkpos(start, end) + for i in forkpos_slice: + if self.grad_forkpos_logits[i] is not None: + fork_logits = logits[0, i - start] + roots.append(fork_logits) + grads.append(self.grad_forkpos_logits[i]) + + # roots: loss, (KV, logprobs, entropy, forkpos logits) in tokens_to_pop + torch.autograd.backward(roots, grads) + + # --------------------------------------------------------------------------------- + # 6. Accumulate gradients to prefix cache (KV, logprobs, entropy, forkpos-logits) + # --------------------------------------------------------------------------------- + + # gradients to prefix KV + for layer_idx, (k, v) in enumerate(prefix_kv): + if k.grad is not None: + self.grad_kv[0][layer_idx][:, :, :start, :] += k.grad + if v.grad is not None: + self.grad_kv[1][layer_idx][:, :, :start, :] += v.grad + + if start > 0: + if self.is_critic: + if attachs_in_block and pre_values.grad is not None: + self.grad_values[:start] += pre_values.grad + else: + # gradients to forkpos logits + if mid_logits.grad is not None: + if self.grad_forkpos_logits[start - 1] is None: + self.grad_forkpos_logits[start - 1] = mid_logits.grad.clone() + else: + self.grad_forkpos_logits[start - 1] += mid_logits.grad + if attachs_in_block: + # gradients to prefix logprobs & entropy + if pre_entropy.grad is not None: + self.grad_entropy[:start] += pre_entropy.grad + if start > 1 and pre_logprobs.grad is not None: + self.grad_logprobs[: start - 1] += pre_logprobs.grad + + # --------------------------------------------------------------------------------- + # 7. Cleanup: truncate and clear buffers + # --------------------------------------------------------------------------------- + + self.attachs = [ + (att, length) for att, length in self.attachs if length <= start + ] + + for layer_idx in range(self.n_layers): + self.grad_kv[0][layer_idx][:, :, start:end, :].zero_() + self.grad_kv[1][layer_idx][:, :, start:end, :].zero_() + + if self.is_critic: + self.grad_values[start:end].zero_() + else: + self.grad_logprobs[0 if start == 0 else start - 1 : end - 1].zero_() + self.grad_entropy[start:end].zero_() + + forkpos_slice = self.get_forkpos(start, end) + for i in forkpos_slice: + self.forkpos_logits[i] = None + self.grad_forkpos_logits[i] = None + + self.cur_len = start + + return loss.item() if attachs_in_block else 0.0 + + def pop_byblock(self, start: int, block_size: int, loss_fn) -> float: + """ + Pop tokens from [start, cur_len) in blocks to reduce peak GPU memory usage. + + Tokens are popped in reverse block order, calling `pop()` on each block. + + Args: + start: The starting token index to pop from. + block_size: Maximum block size for each pop to control memory usage. + loss_fn: Callable to compute loss for a sequence segment. + + Returns: + Total loss over all popped blocks. + """ + end = self.cur_len + length = end - start + n_blocks = ceil(length / block_size) + block_size_actual = ceil(length / n_blocks) + + loss = 0.0 + for b in range(n_blocks): + pop_start = max(end - (b + 1) * block_size_actual, start) + loss += self.pop(pop_start, loss_fn) + + return loss + + @torch.no_grad() + def forward(self, model, token_trie): + """ + Perform backward pass over all sequences in a TokenTrie. + Compute logprobs for each sequence. + The sequence ID is identified by attachment['_sequence_batch_id'], which TokenTrie automatically adds. + + Args: + token_trie: TokenTrie containing input sequences and attachs. + + Returns: + List of logprob tensors for each sequence in the TokenTrie. + """ + + self.model = model + self.returns = [None] * token_trie.n_sequences + + inputs, attach_lists, lcp_lens = ( + token_trie.inputs, + token_trie.attach_lists, + token_trie.lcp_lens, + ) + + if not self.is_critic: + self.forkpos_list = _get_forkpos(None, lcp_lens, None) + + for i in range(len(inputs)): + input_ids = inputs[i].to(self.device) + attach_list = attach_lists[i] + _ = input_ids.size(0) + + # Pop diverged branch from previous sequence + if i > 0: + self.cur_len = lcp_lens[i - 1] + + # Push new tokens + new_tokens = input_ids[self.cur_len :] + + self.push_forward_only(new_tokens, attach_list) + + self.cur_len = 0 + if not self.is_critic: + self.forkpos_logits = [None] * self.max_seq_len # Clear forkpos_logits + + return self.returns + + def backward(self, model, token_trie, loss_fn, block_size: int) -> float: + """ + Perform backward pass over all sequences in a TokenTrie. + + Processes sequences in lexicographic order, popping diverged + branches (block-wise) and pushing new tokens. + + Args: + token_trie: TokenTrie containing input sequences and attachs. + block_size: Maximum block size for popping to control GPU memory. + Use -1 for no block-size limit. + loss_fn: Callable to compute per-sequence loss. + Returns: + Total loss accumulated over all sequences. + """ + + self.model = model + if block_size == -1: + block_size = NO_BLOCK_SIZE_LIMIT + + total_loss = 0.0 + + inputs, attach_lists, lcp_lens = ( + token_trie.inputs, + token_trie.attach_lists, + token_trie.lcp_lens, + ) + + # Precompute fork positions and block boundaries + lens = [ids.size(0) for ids in inputs] + if not self.is_critic: + self.forkpos_list = _get_forkpos(lens, lcp_lens, block_size) + + # Process each sequence + for i in range(len(inputs)): + input_ids = inputs[i].to(self.device) + attach_list = attach_lists[i] + _ = input_ids.size(0) + + # Pop diverged branch from previous sequence + if i > 0: + lcp = lcp_lens[i - 1] + if lcp < self.cur_len: + total_loss += self.pop_byblock(lcp, block_size, loss_fn) + + # Push new tokens + new_tokens = input_ids[self.cur_len :] + + # Determine cache length to build (optimize for next pop) + next_anchor = lcp_lens[i] if i < len(inputs) - 1 else 0 + B = new_tokens.numel() + next_pop_len = self.cur_len + B - next_anchor + + if next_pop_len > block_size: + n_blocks = ceil(next_pop_len / block_size) + block_size_actual = ceil(next_pop_len / n_blocks) + next_anchor = self.cur_len + B - block_size_actual + + self.push(new_tokens, attach_list, next_anchor) + + # Final pop for remaining tokens + if self.cur_len > 0: + total_loss += self.pop_byblock(0, block_size, loss_fn) + + return total_loss diff --git a/areal/experimental/dta/token_trie.py b/areal/experimental/dta/token_trie.py new file mode 100644 index 0000000000..3d143e8c08 --- /dev/null +++ b/areal/experimental/dta/token_trie.py @@ -0,0 +1,123 @@ +# SPDX-License-Identifier: Apache-2.0 + +# The following code is adapted with minor modifications from +# https://github.com/Whisper-6/DynamicTreeAttn/blob/main/token_trie.py. +# Special thanks to Yuchen Yang for outstanding contributions to the optimized DFS order. + +import torch + +from areal.experimental.dta.trie import CompressedTrie, _get_stats + + +def _lcp_torch(a: torch.Tensor, b: torch.Tensor) -> int: + """Compute the length of the longest common prefix of two 1D tensors.""" + L = min(a.numel(), b.numel()) + eq = a[:L] == b[:L] + return L if eq.all() else int((~eq).to(torch.int32).argmax().item()) + + +def _leafization(input_ids: list[torch.LongTensor], attachs: list[dict]): + """ + Args: + input_ids: List of token tensors, sorted in lexicographic order. + attachs: List of dicts, each storing loss-related config for one token tensor. + + Merge fully overlapping prefixes and compute the `lcp_lens` list. + """ + + # Compute adjacent LCP lengths and validate lexicographic ordering. + lcp_lens = [] + for i in range(len(input_ids) - 1): + seq_L, seq_R = input_ids[i], input_ids[i + 1] + lcp = _lcp_torch(seq_L, seq_R) + L = min(seq_L.numel(), seq_R.numel()) + if lcp < L and seq_L[lcp] > seq_R[lcp]: + raise ValueError("input_ids not sorted in lexicographic order.") + lcp_lens.append(lcp) + + # Merge fully overlapping prefixes by keeping only the longest sequence. + input_ids_leafed = [] + attach_lists = [] + lcp_lens_leafed = [] + + fork = -1 + for i in range(len(input_ids)): + if i == len(input_ids) - 1 or lcp_lens[i] < min( + input_ids[i].numel(), input_ids[i + 1].numel() + ): + input_ids_leafed.append(input_ids[i]) + if i < len(input_ids) - 1: + lcp_lens_leafed.append(lcp_lens[i]) + attach_list = [] + for k in range(fork + 1, i + 1): + attach_list.append((attachs[k], input_ids[k].numel())) + attach_lists.append(attach_list) + fork = i + + return input_ids_leafed, attach_lists, lcp_lens_leafed + + +class TokenTrie: + def __init__( + self, + inputs: list[torch.LongTensor], + attachs: list[dict] | None = None, + sorted: bool = False, + ): + if attachs is not None: + if len(inputs) != len(attachs): + raise ValueError("Length of inputs and attachs must match.") + else: + attachs = [{} for _ in range(len(inputs))] + + # Attach the original sequence index to each attachment dict. + for seq_id in range(len(inputs)): + attachs[seq_id]["_sequence_batch_id"] = seq_id + + # -------- sort by lexicographical order of input_ids -------- + if not sorted: + pairs = list(zip(inputs, attachs)) + pairs.sort(key=lambda x: x[0].tolist()) + inputs_sorted, attachs_sorted = [p[0] for p in pairs], [p[1] for p in pairs] + else: + inputs_sorted, attachs_sorted = inputs, attachs + + # -------- leafization -------- + self.inputs, self.attach_lists, self.lcp_lens = _leafization( + inputs_sorted, attachs_sorted + ) + self.lens = [len(ids) for ids in self.inputs] + + # -------- stats -------- + self.n_sequences = len(inputs) + self.n_tokens = sum(len(ids) for ids in inputs) + + def get_stats(self, mode: str, block_size: int | None = None): + stats = _get_stats(self.lens, self.lcp_lens, mode, block_size) + stats["n_sequences"] = self.n_sequences + stats["n_tokens"] = self.n_tokens + return stats + + def permute(self, order): + self.inputs = [self.inputs[i] for i in order] + self.attach_lists = [self.attach_lists[i] for i in order] + self.lens = [self.lens[i] for i in order] + self.lcp_lens = [ + _lcp_torch(self.inputs[i], self.inputs[i + 1]) + for i in range(len(self.inputs) - 1) + ] + + def forward_permute(self): + compressed_trie = CompressedTrie(self.lens, self.lcp_lens) + order, _, _ = compressed_trie.get_order_forward() + self.permute(order) + + def backward_permute(self): + compressed_trie = CompressedTrie(self.lens, self.lcp_lens) + order, _, _ = compressed_trie.get_order_backward() + self.permute(order) + + def random_permute(self): + compressed_trie = CompressedTrie(self.lens, self.lcp_lens) + order = compressed_trie.get_order_random() + self.permute(order) diff --git a/areal/experimental/dta/tree_time_model.py b/areal/experimental/dta/tree_time_model.py new file mode 100644 index 0000000000..d39c856a5f --- /dev/null +++ b/areal/experimental/dta/tree_time_model.py @@ -0,0 +1,56 @@ +# SPDX-License-Identifier: Apache-2.0 + +# The following code is adapted with minor modifications from +# https://github.com/Whisper-6/DynamicTreeAttn/blob/main/tree_time_model.py. +import numpy as np +from scipy.optimize import nnls + + +class TreeTimeModel: + MIN_N_DATA_POINTS = 16 + MAX_N_DATA_POINTS = 1024 + + def __init__(self): + # T = c_0 * n_leaf_sequences + c_1 * n_tree_tokens + c_2 * n_f1_tokens + c_3 * sum_prefix_len + c_4 * sum_depth + self.coeffs = None + self.data = [] + + def fit(self): + X, Y = [], [] + for stats in self.data: + # X.append([0, stats["n_tree_tokens"], 0, 0, 0]) + X.append( + [ + stats["n_leaf_sequences"], + stats["n_tree_tokens"], + stats.get("n_f1_tokens", 0), + stats["sum_prefix_len"], + stats["sum_depth"], + ] + ) + Y.append(stats["time"]) + + X, Y = np.array(X), np.array(Y) + self.coeffs, _ = nnls(X, Y) + + T_pred = X @ self.coeffs + mse = np.mean((T_pred - Y) ** 2) + return mse + + def add_data(self, data): + self.data.extend(data) + if len(self.data) > self.MAX_N_DATA_POINTS: + self.data = self.data[-self.MAX_N_DATA_POINTS :] + if len(self.data) >= self.MIN_N_DATA_POINTS: + self.fit() + + def pred(self, stats): + if self.coeffs is None: + return stats["n_tree_tokens"] + return ( + self.coeffs[0] * stats["n_leaf_sequences"] + + self.coeffs[1] * stats["n_tree_tokens"] + + self.coeffs[2] * stats.get("n_f1_tokens", 0) + + self.coeffs[3] * stats["sum_prefix_len"] + + self.coeffs[4] * stats["sum_depth"] + ) diff --git a/areal/experimental/dta/trie.py b/areal/experimental/dta/trie.py new file mode 100644 index 0000000000..637f944eef --- /dev/null +++ b/areal/experimental/dta/trie.py @@ -0,0 +1,274 @@ +# SPDX-License-Identifier: Apache-2.0 + +# The following code is adapted with minor modifications from +# https://github.com/Whisper-6/DynamicTreeAttn/blob/main/trie.py. +import random +from dataclasses import dataclass, field +from math import ceil + + +def _get_stats( + lens: list[int], lcp_lens: list[int], mode: str, block_size: int | None = None +) -> dict: + n_tree_tokens = sum(lens) - sum(lcp_lens) + sum_depth = 0 + for i in range(len(lens)): + start = lcp_lens[i - 1] if i > 0 else 0 + end = lens[i] + sum_depth += (start + end - 1) * (end - start) // 2 + + if mode == "forward": + sum_prefix_len = sum(lcp_lens) + + return { + "n_leaf_sequences": len(lens), + "n_tree_tokens": n_tree_tokens, + "sum_prefix_len": sum_prefix_len, + "sum_depth": sum_depth, + } + + elif mode == "backward": + sum_prefix_len = 0 + n_f1_tokens = 0 + for i in range(len(lens)): + start = lcp_lens[i] if i < len(lcp_lens) else 0 + end = lens[i] + pop_len = end - start + f1_start = lcp_lens[i - 1] if i > 0 else 0 + + if block_size is None or pop_len <= block_size: + f1_end = start + sum_prefix_len += start + else: + n_blocks = ceil(pop_len / block_size) + block_size_actual = ceil(pop_len / n_blocks) + f1_end = end - block_size_actual + for b in range(n_blocks): + pop_start = max(end - (b + 1) * block_size_actual, start) + sum_prefix_len += pop_start + + n_f1_tokens += max(f1_end - f1_start, 0) + + return { + "n_leaf_sequences": len(lens), + "n_tree_tokens": n_tree_tokens, + "sum_prefix_len": sum_prefix_len, + "sum_depth": sum_depth, + "n_f1_tokens": n_f1_tokens, + } + + else: + raise ValueError(f"Unsupported mode: {mode}") + + +@dataclass(slots=True) +class CTNode: + """Node in the compressed trie.""" + + depth: int = 0 # Depth of this node. + seq_id: int = -1 # Sequence index; -1 indicates an internal node. + chain_tail_depth: int = 0 # Tail depth of the prioritized chain. + child_ids: list[int] = field(default_factory=list) # IDs of child nodes. + + +class CompressedTrie: + """Compressed trie used to plan traversal order.""" + + def __init__(self, lens: list[int], lcp_lens: list[int]): + """ + Initialize the compressed trie. + + Args: + lens: Length of each sequence, sorted in lexicographic order. + lcp_lens: LCP length between adjacent sequences, where + len(lcp_lens) == max(len(lens) - 1, 0). An empty `lens` + produces a degenerate trie that contains only the root node. + """ + expected_lcp = max(len(lens) - 1, 0) + if len(lcp_lens) != expected_lcp: + raise ValueError( + f"len(lcp_lens) must be {expected_lcp}, got {len(lcp_lens)}" + ) + + self.nodes: list[CTNode] = [] # Stores all trie nodes. + self._build(lens, lcp_lens) + + self.lca_depth = None + self.order = None + self.lens = None + self.lcp_lens = None + + def _new_node(self, depth: int, seq_id: int = -1) -> int: + """Create a new node and return its ID.""" + self.nodes.append(CTNode(depth=depth, seq_id=seq_id)) + return len(self.nodes) - 1 + + def _build(self, lens: list[int], lcp_lens: list[int]): + """Build the compressed trie.""" + + n_seqs = len(lens) + # Create the root node. + root_id = self._new_node(depth=0, seq_id=-1) + stack = [(root_id, 0)] # Stack entries are (node_id, depth). + nodes = self.nodes + + for seq_id in range(n_seqs): + len_i = lens[seq_id] + lcp = lcp_lens[seq_id - 1] if seq_id > 0 else 0 + + if len(stack) >= 2: + while stack[-2][1] > lcp: + # Pop a child node and connect it to its parent. + child_id = stack.pop()[0] + parent_id = stack[-1][0] + nodes[parent_id].child_ids.append(child_id) + + child_id = stack.pop()[0] + if stack[-1][1] < lcp: + lcp_node_id = self._new_node(depth=lcp, seq_id=-1) + stack.append((lcp_node_id, lcp)) + parent_id = stack[-1][0] + nodes[parent_id].child_ids.append(child_id) + else: + if stack[-1][1] < lcp: + lcp_node_id = self._new_node(depth=lcp, seq_id=-1) + stack.append((lcp_node_id, lcp)) + + # Create a new leaf node. + parent_id = stack[-1][0] + cur_node_id = self._new_node(depth=len_i, seq_id=seq_id) + stack.append((cur_node_id, len_i)) + + while len(stack) >= 2: + child_id = stack.pop()[0] + parent_id = stack[-1][0] + nodes[parent_id].child_ids.append(child_id) + + def dfs_chain(self, node_id: int, child_order_func) -> int: + """Compute `chain_tail_depth` for each node.""" + node = self.nodes[node_id] + + # Leaf node. + if node.seq_id != -1: + node.chain_tail_depth = node.depth + return + + for child_id in node.child_ids: + self.dfs_chain(child_id, child_order_func) + + child_ids = child_order_func(node_id) + if not child_ids: + # Only reachable for the root of an empty trie. The value never + # propagates anywhere since the subtree carries no leaves. + node.chain_tail_depth = node.depth + return + node.chain_tail_depth = self.nodes[child_ids[0]].chain_tail_depth + + def dfs_get_lens(self, node_id: int, seq_set: set[int]): + node = self.nodes[node_id] + + if node.seq_id != -1: + if node.seq_id in seq_set: + self.lens.append(node.depth) + self.lcp_lens.append(self.lca_depth) + self.lca_depth = node.depth + return + + for child_id in node.child_ids: + self.lca_depth = min(self.lca_depth, node.depth) + self.dfs_get_lens(child_id, seq_set) + + def get_lens(self, seq_set: set[int]): + self.lens = [] + self.lcp_lens = [] + self.lca_depth = 0 + self.dfs_get_lens(0, seq_set) + return self.lens, self.lcp_lens[1:] + + def dfs_get_order(self, node_id: int, child_order_func): + node = self.nodes[node_id] + + # Leaf node: record the sequence index. + if node.seq_id != -1: + self.order.append(node.seq_id) + self.lens.append(node.depth) + self.lcp_lens.append(self.lca_depth) + self.lca_depth = node.depth + return + + # Get child traversal order from the given strategy. + child_ids = child_order_func(node_id) + + # Recursively traverse children. + for child_id in child_ids: + self.lca_depth = min(self.lca_depth, node.depth) + self.dfs_get_order(child_id, child_order_func) + + def _get_child_order_forward(self, node_id: int) -> list[int]: + node = self.nodes[node_id] + return sorted( + node.child_ids, key=lambda child_id: self.nodes[child_id].chain_tail_depth + ) + + def _get_child_order_backward(self, node_id: int) -> list[int]: + node = self.nodes[node_id] + return sorted( + node.child_ids, + key=lambda child_id: ( + 1 if self.nodes[child_id].child_ids else 0, + self.nodes[child_id].chain_tail_depth, + ), + ) + + def _get_child_order_random( + self, node_id: int, seed: int | None = None + ) -> list[int]: + node = self.nodes[node_id] + child_ids = node.child_ids.copy() + + if seed is not None: + local_random = random.Random(seed) + local_random.shuffle(child_ids) + else: + random.shuffle(child_ids) + + return child_ids + + def get_order(self, child_order_func): + """Get sequence order from DFS with a custom child-order strategy.""" + self.dfs_chain(0, child_order_func) + self.order = [] + self.lens = [] + self.lcp_lens = [] + self.lca_depth = 0 + self.dfs_get_order(0, child_order_func) + + def get_order_forward(self): + """Get sequence order from DFS using main-Ld-priority traversal.""" + self.get_order(self._get_child_order_forward) + return self.order, self.lens, self.lcp_lens[1:] + + def get_order_backward(self): + """Get sequence order from DFS for backward-style pop traversal.""" + self.get_order(self._get_child_order_backward) + return self.order[::-1], self.lens[::-1], self.lcp_lens[1:][::-1] + + def get_order_random(self, seed: int | None = None): + """Get sequence order from DFS after randomizing child edges.""" + self.get_order(lambda node_id: self._get_child_order_random(node_id, seed)) + return self.order + + def get_stats(self, mode: str, block_size: int | None = None) -> dict: + """Get traversal stats for this compressed trie.""" + if mode == "forward": + _, lens, lcp_lens = self.get_order_forward() + elif mode == "backward": + _, lens, lcp_lens = self.get_order_backward() + else: + raise ValueError(f"Unsupported mode: {mode}") + return _get_stats(lens, lcp_lens, mode, block_size) + + def get_subtrie(self, seq_set: set[int]) -> "CompressedTrie": + """Build a compressed subtrie containing the selected leaf sequence IDs.""" + lens, lcp_lens = self.get_lens(seq_set) + return CompressedTrie(lens, lcp_lens) diff --git a/areal/experimental/dta/wrapper.py b/areal/experimental/dta/wrapper.py new file mode 100644 index 0000000000..dfa8d44eab --- /dev/null +++ b/areal/experimental/dta/wrapper.py @@ -0,0 +1,361 @@ +# SPDX-License-Identifier: Apache-2.0 + +from __future__ import annotations + +import math +from collections.abc import Callable +from types import SimpleNamespace +from typing import TYPE_CHECKING, Any, Protocol + +import torch +import torch.distributed as dist +from torch.distributed.optim import ZeroRedundancyOptimizer +from transformers.cache_utils import DynamicCache + +from areal.api.cli_args import MicroBatchSpec +from areal.engine.core.train_engine import ( + compute_total_loss_weight, + reorder_and_pad_outputs, +) +from areal.experimental.dta.runner import DTARunner +from areal.experimental.dta.token_trie import TokenTrie +from areal.utils.data import ( + MicroBatchList, + amend_position_ids, + pack_tensor_dict, + split_batch, + split_padded_tensor_dict_into_mb_list, +) + +if TYPE_CHECKING: + from areal.experimental.engine.archon_engine import ArchonEngine + + +class KVCacheModel(Protocol): + """Structural contract for DTA-compatible models.""" + + def forward( + self, + tokens: torch.LongTensor, + past_key_values: DynamicCache | None = None, + use_cache: bool = True, + ) -> SimpleNamespace: ... + + +class DTAWrapper: + """DTA adapter wrapped around ArchonEngine's batch-level APIs. + + DTA is adapted at the train_batch boundary instead of the generic runner + boundary. The runner contract exposes dense logits to process_output_fn, + while DTA intentionally keeps backward activation bounded by block_size + tokens inside DTARunner. + """ + + def validate_compatibility(self) -> None: + """Validate DTA-only constraints before model creation.""" + config = self.engine.config + parallel_dims = self.engine.parallel_dims + model_type = getattr(self.engine.model_config, "model_type", "") + + if model_type in {"qwen3_5", "qwen3_5_text", "qwen3_5_moe", "qwen3_5_moe_text"}: + raise ValueError( + "DTA requires model-level KV-cache support via " + "forward(..., past_key_values=..., use_cache=True). " + f"model_type={model_type!r} is not supported because Qwen3.5 " + "hybrid linear_attention layers do not implement DTA-compatible " + "cache state." + ) + + if config.gradient_checkpointing: + raise ValueError( + "ArchonEngine: gradient_checkpointing=True is incompatible with " + "tree_training_mode='dta'. Disable gradient_checkpointing for DTA." + ) + + if ( + parallel_dims.pp_enabled + or parallel_dims.cp_enabled + or parallel_dims.tp_enabled + or parallel_dims.ep_enabled + or parallel_dims.etp_enabled + ): + raise ValueError( + "DTA currently supports only data parallelism. " + "Found unsupported parallel dimensions enabled among " + "{pp, cp, tp, ep, etp}. " + f"Current sizes: pp={parallel_dims.pp}, cp={parallel_dims.cp}, " + f"tp={parallel_dims.tp}, ep={parallel_dims.ep}, etp={parallel_dims.etp}." + ) + + def __init__( + self, + engine: ArchonEngine, + ) -> None: + self.engine = engine + self.device = engine.device + self.block_size = engine.config.dta_block_size + self.is_critic = engine.config.is_critic + self.runner = DTARunner( + model_config=engine.model_config, + device=engine.device, + dtype=getattr(torch, engine.config.dtype), + max_seq_len=engine.config.mb_spec.max_tokens_per_mb, + is_critic=engine.config.is_critic, + ) + + @property + def model(self) -> KVCacheModel: + return self.engine.model + + def apply_zero1(self) -> None: + """Apply DTA's Zero1 full-replica model setup.""" + model_args = getattr(self.engine.model, "model_args", None) + if getattr(model_args, "enable_weight_tying", False): + output = getattr(self.engine.model, "output", None) + tok_embeddings = getattr(self.engine.model, "tok_embeddings", None) + if output is not None and tok_embeddings is not None: + output.weight = tok_embeddings.weight + self.engine.model_parts = [self.engine.model] + + def create_optimizer(self) -> torch.optim.Optimizer: + """Create DTA's Zero1 optimizer.""" + assert self.engine.optimizer_config is not None + optimizer_config = self.engine.optimizer_config + common_kwargs: dict[str, object] = { + "lr": optimizer_config.lr, + "weight_decay": optimizer_config.weight_decay, + } + if optimizer_config.type == "adam": + return ZeroRedundancyOptimizer( + self.engine._get_all_parameters(), + optimizer_class=torch.optim.AdamW, + process_group=self.engine.data_parallel_group, + betas=(optimizer_config.beta1, optimizer_config.beta2), + eps=optimizer_config.eps, + fused=True, + **common_kwargs, + ) + if optimizer_config.type == "sgd": + return ZeroRedundancyOptimizer( + self.engine._get_all_parameters(), + optimizer_class=torch.optim.SGD, + process_group=self.engine.data_parallel_group, + **common_kwargs, + ) + raise ValueError( + f"Unsupported optimizer type for Zero1: {optimizer_config.type}" + ) + + def clip_grad_norm(self) -> float: + """Clip gradients for DTA's Zero1 full-replica training path.""" + assert self.engine.optimizer_config is not None + grads = [ + p.grad for p in self.engine._get_all_parameters() if p.grad is not None + ] + if not grads: + return 0.0 + + device = grads[0].device + total_sq = torch.zeros((), device=device, dtype=torch.float32) + for grad in grads: + total_sq += grad.detach().float().pow(2).sum() + + total_norm = total_sq.sqrt() + total_norm_value = float(total_norm) + if not math.isfinite(total_norm_value): + return total_norm_value + + clip_coef = ( + self.engine.optimizer_config.gradient_clipping / (total_norm + 1e-6) + ).clamp(max=1.0) + for grad in grads: + grad.mul_(clip_coef.to(device=grad.device, dtype=grad.dtype)) + return total_norm_value + + def prepare_mb_list(self, input_: dict[str, Any]) -> MicroBatchList: + """Build one sequence per microbatch for DTARunner.""" + input_ = amend_position_ids(input_) + n_seqs = input_["input_ids"].shape[0] + mb_spec = MicroBatchSpec.new( + self.engine.config.mb_spec, + n_mbs=n_seqs, + granularity=1, + n_mbs_divisor=1, + max_tokens_per_mb=self.engine.config.mb_spec.max_tokens_per_mb, + ) + # Keep DTA per-rank independent: one sequence per microbatch, no + # cross-rank synced microbatch-count alignment. + mb_list = split_padded_tensor_dict_into_mb_list(input_, mb_spec, sync_mbs=False) + assert len(mb_list.mbs) == n_seqs, ( + f"DTA requires one microbatch per sequence, " + f"expected {n_seqs} microbatches but got {len(mb_list.mbs)}." + ) + return mb_list + + @torch.no_grad() + def forward_batch( + self, + input_: list[dict[str, Any]] | dict[str, Any], + output_seqlens: list[int] | None = None, + aggregate_fn: Callable[[list[torch.Tensor]], torch.Tensor] = torch.cat, + ) -> torch.Tensor | list[torch.Tensor]: + """Forward pass through DTA, matching ArchonEngine.forward_batch.""" + engine = self.engine + assert engine._initialized + + input_batched, meta = engine._normalize_batch_input(input_) + + cu_seqlens = pack_tensor_dict(input_batched)["cu_seqlens"] + inferred_seqlens = (cu_seqlens[1:] - cu_seqlens[:-1]).cpu().numpy().tolist() + if meta is not None: + assert isinstance(input_, list) + if output_seqlens is not None and output_seqlens != inferred_seqlens: + raise ValueError( + f"output_seqlens mismatch for list input: " + f"given {output_seqlens}, " + f"inferred {inferred_seqlens} from attention_mask valid lengths." + ) + output_seqlens = inferred_seqlens + elif output_seqlens is None: + output_seqlens = inferred_seqlens + assert output_seqlens is not None + + mb_list = engine._prepare_mb_list(input_batched).to(engine.device) + engine.logger.info("tree_training_mode='dta' in forward_batch") + input_ids_list = self._extract_input_ids_list_from_mb_list(mb_list) + input_data = [{} for _ in input_ids_list] + trie = TokenTrie(input_ids_list, input_data, sorted=False) + trie.forward_permute() + + outputs = self.runner.forward(model=self.model, token_trie=trie) + if not self.is_critic: + outputs = [ + torch.cat([x, x.new_zeros((1, *x.shape[1:]))], dim=0) for x in outputs + ] + res = reorder_and_pad_outputs(outputs, output_seqlens, mb_list, aggregate_fn) + if meta is None: + return res + return split_batch(res, meta) + + @staticmethod + def _extract_input_ids(mb_input: dict[str, Any]) -> torch.Tensor: + if "input_ids" not in mb_input: + raise ValueError("DTA expects `input_ids` in micro-batch input.") + input_ids = mb_input["input_ids"] + if not torch.is_tensor(input_ids) or input_ids.ndim != 1: + raise ValueError( + "DTA expects packed 1D `input_ids` in micro-batch input, " + f"got {type(input_ids)} with ndim=" + f"{getattr(input_ids, 'ndim', 'N/A')}." + ) + return input_ids + + def _extract_input_ids_list_from_mb_list(self, mb_list: Any) -> list[torch.Tensor]: + input_ids_list: list[torch.Tensor] = [] + for mb_item in mb_list: + input_ids_list.append(self._extract_input_ids(mb_item.orig_mb)) + return input_ids_list + + def train_batch( + self, + input_: list[dict[str, Any]] | dict[str, Any], + loss_fn: Callable[..., torch.Tensor], + loss_weight_fn: Callable[[dict[str, Any]], torch.Tensor], + return_loss: bool = False, + ) -> dict[str, float]: + """Train on a batch using DTA's block-wise backward implementation.""" + engine = self.engine + assert engine._initialized + engine.optimizer_zero_grad() + + input_batched, _ = engine._normalize_batch_input(input_) + mb_list = engine._prepare_mb_list(input_batched).to(engine.device) + total_loss_weight = compute_total_loss_weight( + mb_list, loss_weight_fn, engine.data_parallel_group + ) + + engine.logger.info("tree_training_mode='dta' in train_batch") + engine.logger.info(f"total_loss_weight: {total_loss_weight}") + + dta_loss = self._backward_with_scaled_loss( + mb_list=mb_list, + loss_fn=loss_fn, + loss_weight_fn=loss_weight_fn, + total_loss_weight=total_loss_weight, + ) + engine.logger.info(f"DTA backward loss: {dta_loss}") + + for parameter in engine._get_all_parameters(): + if parameter.grad is not None: + dist.all_reduce(parameter.grad, group=engine.data_parallel_group) + result = engine.optimizer_step() + if return_loss: + result["loss"] = dta_loss + return result + + def _backward_with_scaled_loss( + self, + mb_list: MicroBatchList, + loss_fn: Callable[..., torch.Tensor], + loss_weight_fn: Callable[[dict[str, Any]], torch.Tensor], + total_loss_weight: torch.Tensor, + ) -> float: + input_ids_list = self._extract_input_ids_list_from_mb_list(mb_list) + per_seq_input_data: list[dict[str, Any]] = [] + for idx, mb_item in enumerate(mb_list): + _, ctx = self.engine._prepare_mb_inputs(mb_item) + mb_input = ctx.mb_input + # Keep backward input source aligned with forward input source. + self._extract_input_ids(mb_input) + if mb_input["input_ids"].shape != input_ids_list[idx].shape: + raise ValueError( + "DTA expects `ctx.mb_input['input_ids']` to align with " + "`mb_item.orig_mb['input_ids']` for each micro-batch." + ) + loss_scale = loss_weight_fn(ctx.mb_input) / total_loss_weight + if isinstance(loss_scale, torch.Tensor): + loss_scale = loss_scale.item() + per_seq_input_data.append({"original": mb_input, "scale": loss_scale}) + + if self.is_critic: + + def scaled_loss_fn( + values: torch.Tensor, + seq_input_data: dict[str, Any], + **extra_kwargs: Any, + ) -> torch.Tensor: + loss_val = loss_fn( + values, + seq_input_data["original"], + **extra_kwargs, + ) + return loss_val * seq_input_data["scale"] + else: + + def scaled_loss_fn( + logprobs: torch.Tensor, + entropy: torch.Tensor, + seq_input_data: dict[str, Any], + **extra_kwargs: Any, + ) -> torch.Tensor: + # Keep current behavior: DTA engine expects one extra position. + logprobs = torch.cat([logprobs, logprobs.new_zeros(1)], dim=0) + loss_val = loss_fn( + logprobs, + entropy, + seq_input_data["original"], + **extra_kwargs, + ) + return loss_val * seq_input_data["scale"] + + trie = TokenTrie(input_ids_list, per_seq_input_data, sorted=False) + trie.backward_permute() + + return float( + self.runner.backward( + model=self.model, + token_trie=trie, + block_size=self.block_size, + loss_fn=scaled_loss_fn, + ) + ) diff --git a/areal/experimental/engine/archon_engine.py b/areal/experimental/engine/archon_engine.py index 98e2645a2e..25560b4dc9 100644 --- a/areal/experimental/engine/archon_engine.py +++ b/areal/experimental/engine/archon_engine.py @@ -45,6 +45,7 @@ reorder_and_pad_outputs, ) from areal.engine.fsdp_utils.grad import fsdp2_clip_grad_norm +from areal.experimental.dta.wrapper import DTAWrapper from areal.experimental.engine.archon_checkpoint import ( load_from_dcp, load_model_from_hf, @@ -151,8 +152,9 @@ def __init__(self, config: TrainEngineConfig): # Configuration (immutable after init) self.config = config self.optimizer_config = config.optimizer - self.enable_tree_training = config.enable_tree_training - + self.tree_training_mode = config.tree_training_mode + if self.tree_training_mode == "dta": + self.dta_wrapper: DTAWrapper # Model Configuration (loaded during __init__) self.model_config: PretrainedConfig = AutoConfig.from_pretrained( pretrained_model_name_or_path=self.config.path, @@ -310,6 +312,10 @@ def initialize(self, addr: str | None, ft_spec: FinetuneSpec, *args, **kwargs): self._create_device_model() self.state_dict_adapter = self._create_state_dict_adapter() + if self.tree_training_mode == "dta": + self.dta_wrapper = DTAWrapper(self) + self.dta_wrapper.validate_compatibility() + self.logger.info(f"DTA Wrapper created on device {self.device}") self.param_dtype = getattr(torch, self.config.dtype) @@ -343,7 +349,7 @@ def initialize(self, addr: str | None, ft_spec: FinetuneSpec, *args, **kwargs): config=self.config, parallel_dims=self.parallel_dims, model_config=self.model_config, - enable_tree_training=self.enable_tree_training, + tree_training_mode=self.tree_training_mode, logger=self.logger, ) @@ -368,6 +374,10 @@ def initialize(self, addr: str | None, ft_spec: FinetuneSpec, *args, **kwargs): validate_fp8_shard_alignment(parts) self._materialize_and_load_weights() + if self.tree_training_mode == "dta": + # ``to_empty``/state-dict loading can recreate Parameter objects, so + # re-apply DTA's tied-weight setup before the optimizer captures params. + self.dta_wrapper.apply_zero1() self._create_optimizer(ft_spec) self.runner = create_runner( @@ -482,16 +492,19 @@ def optimizer_step(self): assert self.optimizer_config is not None assert self.lr_scheduler is not None - grad_norm = fsdp2_clip_grad_norm( - self._get_all_parameters(), - max_norm=self.optimizer_config.gradient_clipping, - fsdp_group=self.data_parallel_group, - tp_group=self._tp_group, - pp_group=self.parallel_dims.get_group("pp") - if self.parallel_dims.pp_enabled - else None, - offload_params=self.config.archon.offload_params, - ) + if self.tree_training_mode == "dta": + grad_norm = self.dta_wrapper.clip_grad_norm() + else: + grad_norm = fsdp2_clip_grad_norm( + self._get_all_parameters(), + max_norm=self.optimizer_config.gradient_clipping, + fsdp_group=self.data_parallel_group, + tp_group=self._tp_group, + pp_group=self.parallel_dims.get_group("pp") + if self.parallel_dims.pp_enabled + else None, + offload_params=self.config.archon.offload_params, + ) if not math.isfinite(grad_norm): self.optimizer_zero_grad() @@ -527,9 +540,18 @@ def train_batch( input_: list[dict[str, Any]] | dict[str, Any], loss_fn: Callable[..., torch.Tensor], loss_weight_fn: Callable[[dict[str, Any]], torch.Tensor], + return_loss: bool = False, ) -> dict[str, float]: """Train on a batch of data.""" assert self._initialized + if self.tree_training_mode == "dta": + return self.dta_wrapper.train_batch( + input_, + loss_fn, + loss_weight_fn, + return_loss=return_loss, + ) + self.optimizer_zero_grad() input_batched, _ = self._normalize_batch_input(input_) @@ -540,11 +562,20 @@ def train_batch( mb_list, loss_weight_fn, self.data_parallel_group ) + losses: list[torch.Tensor] = [] + def process_output( logits: torch.Tensor, ctx_dict: dict[str, Any] ) -> torch.Tensor: ctx = ArchonTrainContext(**ctx_dict) - return self._compute_logprobs_and_loss( + # Non-DTA uses FSDP2. We multiply by DP size before backward here: + # _compute_logprobs_and_loss() applies `loss_multiplier` in + # `local_weight / total_loss_weight * loss_multiplier`. + # The matching gradient divide is in PyTorch FSDP2's + # _fsdp_collectives.foreach_reduce(): _get_gradient_divide_factors() + # chooses ReduceOp.AVG for fp32/bf16, or _div_if_needed() for the + # SUM fallback. + loss = self._compute_logprobs_and_loss( logits, ctx, loss_fn, @@ -552,12 +583,27 @@ def process_output( total_loss_weight, loss_multiplier=self.data_parallel_world_size, ) + if return_loss: + losses.append(loss.detach()) + return loss self.forward_backward_batch(mb_list, process_output, forward_only=False) - stats = self.optimizer_step() - stats["num_micro_batches"] = len(mb_list.mbs) - return stats + result = self.optimizer_step() + result["num_micro_batches"] = len(mb_list.mbs) + if return_loss: + if losses: + # `losses` contain the training-scaled objective above: + # loss_i * (w_i / W_total) * dp_world_size + # This division is only for the returned metric; FSDP2 already + # applies the gradient divide during reduce-scatter. + local_loss = float(torch.stack(losses).sum().item()) / float( + self.data_parallel_world_size + ) + else: + local_loss = float("nan") + result["loss"] = local_loss + return result @torch.no_grad() def eval_batch( @@ -610,22 +656,28 @@ def forward_batch( ) -> torch.Tensor | list[torch.Tensor]: """Forward pass without gradient computation.""" assert self._initialized + if self.tree_training_mode == "dta": + return self.dta_wrapper.forward_batch( + input_, + output_seqlens=output_seqlens, + aggregate_fn=aggregate_fn, + ) input_batched, meta = self._normalize_batch_input(input_) + cu_seqlens = pack_tensor_dict(input_batched)["cu_seqlens"] + inferred_seqlens = (cu_seqlens[1:] - cu_seqlens[:-1]).cpu().numpy().tolist() if meta is not None: assert isinstance(input_, list) - inferred_seqlens = [d["attention_mask"].shape[-1] for d in input_] if output_seqlens is not None and output_seqlens != inferred_seqlens: raise ValueError( f"output_seqlens mismatch for list input: " f"given {output_seqlens}, " - f"inferred {inferred_seqlens} from attention_mask shapes." + f"inferred {inferred_seqlens} from attention_mask valid lengths." ) output_seqlens = inferred_seqlens - cu_seqlens = pack_tensor_dict(input_batched)["cu_seqlens"] - if output_seqlens is None: - output_seqlens = (cu_seqlens[1:] - cu_seqlens[:-1]).cpu().numpy().tolist() + elif output_seqlens is None: + output_seqlens = inferred_seqlens assert output_seqlens is not None batch_size = len(output_seqlens) @@ -643,7 +695,7 @@ def process_output( if self.pp_has_last_stage: assert outputs is not None - if self.enable_tree_training: + if self.tree_training_mode == "sparse": res = merge_packed_tree_results(outputs, batch_size) else: res = reorder_and_pad_outputs( @@ -948,6 +1000,10 @@ def _apply_parallelism( enable_compile: bool, ) -> None: """Apply parallelism using parallelize_fn.""" + if self.tree_training_mode == "dta": + self.dta_wrapper.apply_zero1() + return + self.spec.parallelize_fn( model=self.model, parallel_dims=self.parallel_dims, @@ -971,7 +1027,7 @@ def _prepare_mb_inputs( # Tree training: labels are derived from trie structure, not torch.roll. # (Tree input_ids is 1D packed format, so roll would be wrong anyway.) - if self.enable_tree_training: + if self.tree_training_mode == "sparse": assert trie_node is not None ctx = ArchonTrainContext( mb_input=mb_item.orig_mb, @@ -1095,7 +1151,7 @@ def _create_model_structure(self) -> nn.Module: """Create model structure on meta device without loading weights.""" # Use tree attention type when tree training is enabled attn_type = self.config.archon.attn_type - if self.enable_tree_training: + if self.tree_training_mode == "sparse": if attn_type != "tree": self.logger.warning( f"Tree training enabled, overriding attn_type '{self.config.archon.attn_type}' -> 'tree'" @@ -1156,9 +1212,12 @@ def _create_optimizer(self, ft_spec: FinetuneSpec): tik = time.perf_counter() - self.optimizer = create_optimizer( - self._get_all_parameters(), self.optimizer_config - ) + if self.tree_training_mode == "dta": + self.optimizer = self.dta_wrapper.create_optimizer() + else: + self.optimizer = create_optimizer( + self._get_all_parameters(), self.optimizer_config + ) self.lr_scheduler = create_lr_scheduler( self.optimizer, self.optimizer_config, ft_spec.total_train_steps ) @@ -1179,7 +1238,7 @@ def _prepare_mb_list(self, input_: dict[str, Any]) -> MicroBatchList: # Tree training path # Note: CP/PP incompatibility is validated in initialize(). - if self.enable_tree_training: + if self.tree_training_mode == "sparse": mb_list = build_packed_tree_batch( input_, mb_spec=self.config.mb_spec, @@ -1193,31 +1252,33 @@ def _prepare_mb_list(self, input_: dict[str, Any]) -> MicroBatchList: ) return mb_list - input_ = amend_position_ids(input_) - - # Pipeline parallelism requires n_microbatches >= num_total_stages - if self.parallel_dims.pp_enabled: - pp_size = self.parallel_dims.pp - stages_per_rank = len(self.pp_stages) - num_total_stages = pp_size * stages_per_rank - n_seqs = input_["attention_mask"].shape[0] - if n_seqs < num_total_stages: - raise RuntimeError( - f"Pipeline parallelism requires at least {num_total_stages} " - f"sequences (pp_size={pp_size} * stages_per_rank=" - f"{stages_per_rank}), but got {n_seqs}. " - f"Increase batch size or reduce PP degree/stages." - ) - min_n_mbs = num_total_stages - mb_spec = MicroBatchSpec.new( - self.config.mb_spec, - n_mbs=max(min_n_mbs, self.config.mb_spec.n_mbs or 1), - n_mbs_divisor=pp_size, - ) + if self.tree_training_mode == "dta": + mb_list = self.dta_wrapper.prepare_mb_list(input_) else: - mb_spec = self.config.mb_spec + input_ = amend_position_ids(input_) + # Pipeline parallelism requires n_microbatches >= num_total_stages. + if self.parallel_dims.pp_enabled: + pp_size = self.parallel_dims.pp + stages_per_rank = len(self.pp_stages) + num_total_stages = pp_size * stages_per_rank + n_seqs = input_["attention_mask"].shape[0] + if n_seqs < num_total_stages: + raise RuntimeError( + f"Pipeline parallelism requires at least {num_total_stages} " + f"sequences (pp_size={pp_size} * stages_per_rank=" + f"{stages_per_rank}), but got {n_seqs}. " + f"Increase batch size or reduce PP degree/stages." + ) + min_n_mbs = num_total_stages + mb_spec = MicroBatchSpec.new( + self.config.mb_spec, + n_mbs=max(min_n_mbs, self.config.mb_spec.n_mbs or 1), + n_mbs_divisor=pp_size, + ) + else: + mb_spec = self.config.mb_spec + mb_list = split_padded_tensor_dict_into_mb_list(input_, mb_spec) - mb_list = split_padded_tensor_dict_into_mb_list(input_, mb_spec) mb_list.mbs = [pack_tensor_dict(mb) for mb in mb_list.mbs] # LCM ensures page-aligned memory and exact CP slicing without extra padding. @@ -1304,7 +1365,7 @@ def _gather_actor_train_outputs( ctx: ArchonTrainContext, ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor] | None: """Compute (logprobs, entropy, vocab_min, vocab_max) for actor training.""" - if self.enable_tree_training: + if self.tree_training_mode == "sparse": # Handle dummy trie (empty tree for DP synchronization) if ctx.trie_node is None or not ctx.trie_node.all_sequence_ids: return None @@ -1347,7 +1408,7 @@ def _gather_actor_forward_output( ctx: ArchonTrainContext, ) -> torch.Tensor | dict[int, torch.Tensor]: """Compute actor logprobs for forward-only path.""" - if self.enable_tree_training: + if self.tree_training_mode == "sparse": assert ctx.trie_node is not None return _gather_packed_tree_logprobs( logits, diff --git a/areal/experimental/engine/archon_utils.py b/areal/experimental/engine/archon_utils.py index cf9734587f..006967108e 100644 --- a/areal/experimental/engine/archon_utils.py +++ b/areal/experimental/engine/archon_utils.py @@ -258,11 +258,11 @@ def force_pad_to_maximum( config: TrainEngineConfig, parallel_dims: ArchonParallelDims, enable_compile: bool, - enable_tree_training: bool, + tree_training_mode: str, logger: logging.Logger, ) -> None: """Force ``config.pad_to_maximum = True`` when compile, PP, or tree training - requires it. Also validates tree training constraints. + requires it. Also validates sparse tree training constraints. """ # Force pad_to_maximum when compile is enabled to avoid dynamic shape issues if enable_compile and not config.pad_to_maximum: @@ -280,8 +280,8 @@ def force_pad_to_maximum( ) config.pad_to_maximum = True - # Tree training constraints - if enable_tree_training: + # Sparse tree training constraints + if tree_training_mode == "sparse": if config.is_critic: raise NotImplementedError( "Tree training with critic model is not supported yet." @@ -309,7 +309,7 @@ def prepare_training_config( config: TrainEngineConfig, parallel_dims: ArchonParallelDims, model_config: PretrainedConfig, - enable_tree_training: bool, + tree_training_mode: str, logger: logging.Logger, ) -> tuple[ActivationCheckpointConfig | None, bool]: """Build and validate all training configs before parallelism setup. @@ -345,7 +345,7 @@ def prepare_training_config( config=config, parallel_dims=parallel_dims, enable_compile=enable_compile, - enable_tree_training=enable_tree_training, + tree_training_mode=tree_training_mode, logger=logger, ) diff --git a/areal/experimental/models/archon/attention/sdpa.py b/areal/experimental/models/archon/attention/sdpa.py index 96588f5bae..6abccba4e4 100644 --- a/areal/experimental/models/archon/attention/sdpa.py +++ b/areal/experimental/models/archon/attention/sdpa.py @@ -16,8 +16,10 @@ def create_block_causal_mask_2d( - cu_seqlens: torch.Tensor, - seq_len: int, + cu_seqlens_q: torch.Tensor, + cu_seqlens_k: torch.Tensor, + q_len: int, + k_len: int, device: torch.device, dtype: torch.dtype, ) -> torch.Tensor: @@ -28,35 +30,71 @@ def create_block_causal_mask_2d( - Across sequences: no attention allowed Args: - cu_seqlens: Cumulative sequence lengths, shape [num_seqs + 1]. - For example, [0, 3, 5, 7] means 3 sequences with lengths 3, 2, 2. - seq_len: Total sequence length (should equal cu_seqlens[-1]). + cu_seqlens_q: Query cumulative sequence lengths, shape [num_seqs + 1]. + cu_seqlens_k: Key cumulative sequence lengths, shape [num_seqs + 1]. + For self-attention, ``cu_seqlens_q == cu_seqlens_k``. + For KV-cache attention, K can be longer than Q for each sequence. + q_len: Total number of query tokens. + k_len: Total number of key/value tokens. device: Target device for the mask tensor. dtype: Target dtype (float mask with 0.0 and -inf). Returns: - Attention mask of shape [seq_len, seq_len]. + Attention mask of shape [q_len, k_len]. 0.0 = attend, -inf = mask out. - Example for cu_seqlens=[0, 3, 5, 7]:: - - [ 0, -inf, -inf, | -inf, -inf, | -inf, -inf] - [ 0, 0 , -inf, | -inf, -inf, | -inf, -inf] - [ 0, 0 , 0 , | -inf, -inf, | -inf, -inf] - [-inf, -inf, -inf, | 0, -inf, | -inf, -inf] - [-inf, -inf, -inf, | 0, 0 , | -inf, -inf] - [-inf, -inf, -inf, | -inf, -inf, | 0, -inf] - [-inf, -inf, -inf, | -inf, -inf, | 0, 0 ] + Examples: + 1) Standard packed self-attention (q_len == k_len):: + + cu_q = cu_k = [0, 3] + # sequence length = 3 + # allowed key positions per query row: + # q0 -> [0] + # q1 -> [0, 1] + # q2 -> [0, 1, 2] + + 2) Right-aligned KV-cache attention (q_len < k_len):: + + cu_q = [0, 4] + cu_k = [0, 6] + # q tokens correspond to the last 4 positions in key timeline. + # Local alignment offset = k_seq_len - q_seq_len = 2 + # allowed key positions per query row: + # q0 -> [0, 1, 2] + # q1 -> [0, 1, 2, 3] + # q2 -> [0, 1, 2, 3, 4] + # q3 -> [0, 1, 2, 3, 4, 5] """ - positions = torch.arange(seq_len, device=device) - seq_ids = torch.searchsorted(cu_seqlens, positions, side="right") - 1 + if cu_seqlens_q.numel() != cu_seqlens_k.numel(): + raise ValueError( + "cu_seqlens_q and cu_seqlens_k must have same number of sequences, " + f"got {cu_seqlens_q.numel()} vs {cu_seqlens_k.numel()}." + ) + + q_positions = torch.arange(q_len, device=device) + k_positions = torch.arange(k_len, device=device) + cu_q = cu_seqlens_q.to(device) + cu_k = cu_seqlens_k.to(device) + q_seq_ids = torch.searchsorted(cu_q, q_positions, side="right") - 1 + k_seq_ids = torch.searchsorted(cu_k, k_positions, side="right") - 1 + + # Query/key must belong to the same packed sequence. + same_seq = q_seq_ids.unsqueeze(1) == k_seq_ids.unsqueeze(0) - # same_seq: query and key must be in the same sequence - # causal: key position <= query position - same_seq = seq_ids.unsqueeze(1) == seq_ids.unsqueeze(0) - causal = positions.unsqueeze(0) <= positions.unsqueeze(1) + # Sequence-local token indices. + q_local = q_positions - cu_q[q_seq_ids] + k_local = k_positions - cu_k[k_seq_ids] - mask = torch.full((seq_len, seq_len), float("-inf"), device=device, dtype=dtype) + # Right-align Q inside K for KV-cache style attention: + # q_abs = (k_seq_len - q_seq_len) + q_local. + q_seq_lens = cu_q[q_seq_ids + 1] - cu_q[q_seq_ids] + k_seq_lens = cu_k[q_seq_ids + 1] - cu_k[q_seq_ids] + right_offset = (k_seq_lens - q_seq_lens).clamp(min=0) + + # Causal condition: key local index <= aligned query absolute index. + causal = k_local.unsqueeze(0) <= (q_local + right_offset).unsqueeze(1) + + mask = torch.full((q_len, k_len), float("-inf"), device=device, dtype=dtype) mask = mask.masked_fill(same_seq & causal, 0.0) return mask @@ -93,6 +131,7 @@ def forward( cu_seqlens: torch.Tensor, max_seqlen: int, tree_attn_meta: TreeAttentionMeta | None = None, + cu_seqlens_k: torch.Tensor | None = None, ) -> torch.Tensor: """Compute attention with block-diagonal causal mask. @@ -101,17 +140,29 @@ def forward( k: Key tensor, shape [batch, heads, seq_len, head_dim] v: Value tensor, shape [batch, heads, seq_len, head_dim] scale: Optional scale factor for attention scores. - cu_seqlens: Cumulative sequence lengths, shape [num_seqs + 1]. + cu_seqlens: Query cumulative sequence lengths, shape [num_seqs + 1]. max_seqlen: Maximum sequence length (unused, for API compatibility). tree_attn_meta: Unused. Accepted for interface compatibility with TreeAttentionWrapper. + cu_seqlens_k: Optional key cumulative sequence lengths. If not set, + defaults to ``cu_seqlens``. Returns: Attention output, shape [batch, heads, seq_len, head_dim] """ - seq_len = q.shape[2] + q_len = q.shape[2] + k_len = k.shape[2] + if cu_seqlens_k is None: + cu_seqlens_k = cu_seqlens.clone() # TODO: Mask should be precomputed and passed in, not computed here. - attn_mask = create_block_causal_mask_2d(cu_seqlens, seq_len, q.device, q.dtype) + attn_mask = create_block_causal_mask_2d( + cu_seqlens_q=cu_seqlens, + cu_seqlens_k=cu_seqlens_k, + q_len=q_len, + k_len=k_len, + device=q.device, + dtype=q.dtype, + ) with sdpa_kernel(self.sdpa_backends, set_priority=True): return F.scaled_dot_product_attention( diff --git a/areal/experimental/models/archon/attention/varlen.py b/areal/experimental/models/archon/attention/varlen.py index 75fad6ac6b..697f29aa0b 100644 --- a/areal/experimental/models/archon/attention/varlen.py +++ b/areal/experimental/models/archon/attention/varlen.py @@ -276,6 +276,7 @@ def forward( cu_seqlens: torch.Tensor, max_seqlen: int, tree_attn_meta: TreeAttentionMeta | None = None, + cu_seqlens_k: torch.Tensor | None = None, ) -> torch.Tensor: """Compute attention with varlen_attn. @@ -308,15 +309,20 @@ def forward( v_3d = v.squeeze(0).transpose(0, 1).contiguous() # Ensure cu_seqlens is int32 (required by flash_attn) - cu_seqlens_i32 = cu_seqlens.to(torch.int32) + cu_seqlens_q_i32 = cu_seqlens.to(torch.int32) + + if cu_seqlens_k is None: + cu_seqlens_k_i32 = cu_seqlens_q_i32.clone() + else: + cu_seqlens_k_i32 = cu_seqlens_k.to(torch.int32) # Call varlen_attn (self-attention: q and k have same cu_seqlens) out = varlen_attn( q_3d, k_3d, v_3d, - cu_seqlens_i32, - cu_seqlens_i32, + cu_seqlens_q_i32, + cu_seqlens_k_i32, max_seqlen, max_seqlen, is_causal=True, diff --git a/areal/experimental/models/archon/qwen2/model/model.py b/areal/experimental/models/archon/qwen2/model/model.py index c8a36f884f..f82b59b1ef 100644 --- a/areal/experimental/models/archon/qwen2/model/model.py +++ b/areal/experimental/models/archon/qwen2/model/model.py @@ -4,11 +4,14 @@ from __future__ import annotations +from types import SimpleNamespace + import torch import torch.distributed as dist import torch.nn.functional as F from torch import nn from torch.distributed import ProcessGroup +from transformers.cache_utils import DynamicCache from areal.experimental.models.archon.attention import ( SDPAWrapper, @@ -28,6 +31,8 @@ gather_seq_scatter_heads, ) +LayerKVCache = tuple[torch.Tensor, torch.Tensor] + class RMSNorm(nn.Module): """RMSNorm with float32 intermediate computation.""" @@ -131,7 +136,9 @@ def forward( cu_seqlens: torch.Tensor, max_seqlen: int, tree_attn_meta: TreeAttentionMeta | None = None, - ) -> torch.Tensor: + past_layer_kv: LayerKVCache | None = None, + use_cache: bool = False, + ) -> torch.Tensor | tuple[torch.Tensor, LayerKVCache]: bs, seqlen, _ = x.shape xq, xk, xv = self.wq(x), self.wk(x), self.wv(x) @@ -177,6 +184,15 @@ def forward( xk = xk.transpose(1, 2) xv = xv.transpose(1, 2) + cu_seqlens_k = cu_seqlens.clone() + + if past_layer_kv is not None: + past_k, past_v = past_layer_kv + xk = torch.cat([past_k, xk], dim=2) + xv = torch.cat([past_v, xv], dim=2) + cu_seqlens_k += past_k.shape[2] + cu_seqlens_k[0] = 0 + output = self.packed_attn( xq, xk, @@ -185,6 +201,7 @@ def forward( cu_seqlens=cu_seqlens, max_seqlen=max_seqlen, tree_attn_meta=tree_attn_meta, + cu_seqlens_k=cu_seqlens_k, ) output = output.transpose(1, 2).contiguous() @@ -197,7 +214,12 @@ def forward( seqlen = output.shape[1] output = output.view(bs, seqlen, -1) - return self.wo(output) + output = self.wo(output) + + if use_cache: + # Return full K/V because Qwen2Model rebuilds DynamicCache from scratch + return output, (xk, xv) + return output class FeedForward(nn.Module): @@ -246,16 +268,27 @@ def forward( cu_seqlens: torch.Tensor, max_seqlen: int, tree_attn_meta: TreeAttentionMeta | None = None, - ) -> torch.Tensor: - x = x + self.attention( + past_layer_kv: LayerKVCache | None = None, + use_cache: bool = False, + ) -> torch.Tensor | tuple[torch.Tensor, LayerKVCache]: + attn_result = self.attention( self.attention_norm(x), rope_cache, positions, cu_seqlens, max_seqlen, tree_attn_meta=tree_attn_meta, + past_layer_kv=past_layer_kv, + use_cache=use_cache, ) + if use_cache: + attn_out, layer_kv = attn_result + else: + attn_out = attn_result + x = x + attn_out x = x + self.feed_forward(self.ffn_norm(x)) + if use_cache: + return x, layer_kv return x def init_weights(self): @@ -342,11 +375,36 @@ def init_buffers(self, buffer_device: torch.device | str): def forward( self, tokens: torch.Tensor, - positions: torch.Tensor, - cu_seqlens: torch.Tensor, - max_seqlen: int | torch.Tensor, + positions: torch.Tensor | None = None, + cu_seqlens: torch.Tensor | None = None, + max_seqlen: int | torch.Tensor | None = None, tree_attn_meta: TreeAttentionMeta | None = None, - ) -> torch.Tensor: + past_key_values: DynamicCache | None = None, + use_cache: bool = False, + ) -> torch.Tensor | SimpleNamespace: + if past_key_values is not None: + if ( + positions is not None + or cu_seqlens is not None + or max_seqlen is not None + ): + raise ValueError( + "When past_key_values is provided, positions/cu_seqlens/max_seqlen " + "must be None and are inferred internally." + ) + past_len = past_key_values.get_seq_length() + seq_len = tokens.shape[1] + positions = torch.arange( + past_len, + past_len + seq_len, + dtype=torch.long, + device=tokens.device, + ).unsqueeze(0) + cu_seqlens = torch.tensor( + [0, tokens.shape[1]], dtype=torch.int32, device=tokens.device + ) + max_seqlen = int(tokens.shape[1]) + past_len + # When pipeline parallelism enabled, cu_seqlens is [1, B+1] if cu_seqlens.ndim == 2: cu_seqlens = cu_seqlens.squeeze(0) @@ -357,15 +415,30 @@ def forward( h = self.tok_embeddings(tokens) if self.tok_embeddings else tokens - for layer in self.layers.values(): - h = layer( + if use_cache: + next_cache = DynamicCache() + + for layer_idx, layer in enumerate(self.layers.values()): + layer_past = None + if past_key_values is not None and layer_idx < len(past_key_values.layers): + layer_entry = past_key_values.layers[layer_idx] + layer_past = (layer_entry.keys, layer_entry.values) + + layer_out = layer( h, self.rope_cache, positions, cu_seqlens, max_seqlen, tree_attn_meta=tree_attn_meta, + past_layer_kv=layer_past, + use_cache=use_cache, ) + if use_cache: + h, layer_kv = layer_out + next_cache.update(layer_kv[0], layer_kv[1], layer_idx=layer_idx) + else: + h = layer_out h = self.norm(h) if self.norm else h @@ -373,6 +446,8 @@ def forward( output = self.score(h) if self.score else h else: output = self.output(h) if self.output else h + if use_cache: + return SimpleNamespace(logits=output, past_key_values=next_cache) return output diff --git a/areal/experimental/models/archon/qwen3/model/model.py b/areal/experimental/models/archon/qwen3/model/model.py index 062326be17..373067aa94 100644 --- a/areal/experimental/models/archon/qwen3/model/model.py +++ b/areal/experimental/models/archon/qwen3/model/model.py @@ -4,12 +4,15 @@ from __future__ import annotations +from types import SimpleNamespace + import torch import torch.distributed as dist import torch.nn.functional as F from torch import nn from torch.distributed import ProcessGroup from torch.distributed.tensor import DTensor +from transformers.cache_utils import DynamicCache from areal.experimental.models.archon.attention import ( SDPAWrapper, @@ -30,6 +33,8 @@ gather_seq_scatter_heads, ) +LayerKVCache = tuple[torch.Tensor, torch.Tensor] + def maybe_to_local(x: torch.Tensor) -> torch.Tensor: """Convert DTensor to local tensor if needed.""" @@ -195,7 +200,9 @@ def forward( cu_seqlens: torch.Tensor, max_seqlen: int, tree_attn_meta: TreeAttentionMeta | None = None, - ) -> torch.Tensor: + past_layer_kv: LayerKVCache | None = None, + use_cache: bool = False, + ) -> torch.Tensor | tuple[torch.Tensor, LayerKVCache]: bs, seqlen, _ = x.shape xq, xk, xv = self.wq(x), self.wk(x), self.wv(x) @@ -251,6 +258,15 @@ def forward( xk = xk.transpose(1, 2) xv = xv.transpose(1, 2) + cu_seqlens_k = cu_seqlens.clone() + + if past_layer_kv is not None: + past_k, past_v = past_layer_kv + xk = torch.cat([past_k, xk], dim=2) + xv = torch.cat([past_v, xv], dim=2) + cu_seqlens_k += past_k.shape[2] + cu_seqlens_k[0] = 0 + output = self.packed_attn( xq, xk, @@ -259,6 +275,7 @@ def forward( cu_seqlens=cu_seqlens, max_seqlen=max_seqlen, tree_attn_meta=tree_attn_meta, + cu_seqlens_k=cu_seqlens_k, ) output = output.transpose(1, 2).contiguous() @@ -271,7 +288,12 @@ def forward( seqlen = output.shape[1] output = output.view(bs, seqlen, -1) - return self.wo(output) + output = self.wo(output) + + if use_cache: + # Return full K/V because Qwen3Model rebuilds DynamicCache from scratch + return output, (xk, xv) + return output class FeedForward(nn.Module): @@ -334,19 +356,30 @@ def forward( cu_seqlens: torch.Tensor, max_seqlen: int, tree_attn_meta: TreeAttentionMeta | None = None, - ) -> torch.Tensor: - x = x + self.attention( + past_layer_kv: LayerKVCache | None = None, + use_cache: bool = False, + ) -> torch.Tensor | tuple[torch.Tensor, LayerKVCache]: + attn_result = self.attention( self.attention_norm(x), rope_cache, positions, cu_seqlens, max_seqlen, tree_attn_meta=tree_attn_meta, + past_layer_kv=past_layer_kv, + use_cache=use_cache, ) + if use_cache: + attn_out, layer_kv = attn_result + else: + attn_out = attn_result + x = x + attn_out if self.moe_enabled: x = x + self.moe(self.ffn_norm(x)) else: x = x + self.feed_forward(self.ffn_norm(x)) + if use_cache: + return x, layer_kv return x def init_weights(self): @@ -456,11 +489,36 @@ def init_buffers(self, buffer_device: torch.device | str): def forward( self, tokens: torch.Tensor, - positions: torch.Tensor, - cu_seqlens: torch.Tensor, - max_seqlen: int | torch.Tensor, + positions: torch.Tensor | None = None, + cu_seqlens: torch.Tensor | None = None, + max_seqlen: int | torch.Tensor | None = None, tree_attn_meta: TreeAttentionMeta | None = None, - ) -> torch.Tensor: + past_key_values: DynamicCache | None = None, + use_cache: bool = False, + ) -> torch.Tensor | SimpleNamespace: + if past_key_values is not None: + if ( + positions is not None + or cu_seqlens is not None + or max_seqlen is not None + ): + raise ValueError( + "When past_key_values is provided, positions/cu_seqlens/max_seqlen " + "must be None and are inferred internally." + ) + past_len = past_key_values.get_seq_length() + seq_len = tokens.shape[1] + positions = torch.arange( + past_len, + past_len + seq_len, + dtype=torch.long, + device=tokens.device, + ).unsqueeze(0) + cu_seqlens = torch.tensor( + [0, tokens.shape[1]], dtype=torch.int32, device=tokens.device + ) + max_seqlen = int(tokens.shape[1]) + past_len + # When pipeline parallelism enabled, cu_seqlens is [1, B+1] if cu_seqlens.ndim == 2: cu_seqlens = cu_seqlens.squeeze(0) @@ -471,15 +529,30 @@ def forward( h = self.tok_embeddings(tokens) if self.tok_embeddings else tokens - for layer in self.layers.values(): - h = layer( + if use_cache: + next_cache = DynamicCache() + + for layer_idx, layer in enumerate(self.layers.values()): + layer_past = None + if past_key_values is not None and layer_idx < len(past_key_values.layers): + layer_entry = past_key_values.layers[layer_idx] + layer_past = (layer_entry.keys, layer_entry.values) + + layer_out = layer( h, self.rope_cache, positions, cu_seqlens, max_seqlen, tree_attn_meta=tree_attn_meta, + past_layer_kv=layer_past, + use_cache=use_cache, ) + if use_cache: + h, layer_kv = layer_out + next_cache.update(layer_kv[0], layer_kv[1], layer_idx=layer_idx) + else: + h = layer_out h = self.norm(h) if self.norm else h @@ -487,6 +560,8 @@ def forward( output = self.score(h) if self.score else h else: output = self.output(h) if self.output else h + if use_cache: + return SimpleNamespace(logits=output, past_key_values=next_cache) return output diff --git a/areal/infra/controller/train_controller.py b/areal/infra/controller/train_controller.py index 3007569be6..1b2ca23c7f 100644 --- a/areal/infra/controller/train_controller.py +++ b/areal/infra/controller/train_controller.py @@ -20,12 +20,12 @@ ) from areal.api.alloc_mode import ModelAllocation from areal.api.cli_args import PerfTracerConfig, TrainEngineConfig +from areal.infra.dp_allocation import AllocationInput, allocate_trajectories from areal.infra.rpc.rtensor import RTensor, flatten_shard_ids from areal.infra.utils.concurrent import run_async_task from areal.utils import logging, stats_tracker from areal.utils.data import make_dummy_eval_item from areal.utils.network import find_free_ports -from areal.utils.seqpack import balanced_greedy_partition from .rollout_callback import RolloutCallback from .rollout_controller import RolloutController @@ -58,68 +58,57 @@ def _is_tensor_like(obj: Any) -> bool: ) -def _item_weight(d: dict[str, Any]) -> int: - attn_mask = d.get("attention_mask") - if isinstance(attn_mask, torch.Tensor): - return int(attn_mask.sum().item()) - if isinstance(attn_mask, RTensor): - return attn_mask.data.numel() - # Fallback: first tensor's numel - for v in d.values(): - if isinstance(v, RTensor): - return v.data.numel() - if isinstance(v, torch.Tensor) and v.ndim >= 2: - return v.numel() - return 1 +def _controller_allocation_algorithm(packing_algorithm: str) -> str: + if packing_algorithm in {"dta", "ffd_equal"}: + return packing_algorithm + # TODO(agent): This preserves the controller dispatch behavior introduced by + # e97f2a0c (#1017), where tensor-like eval inputs were split with + # balanced_greedy_partition instead of actor.packing_algorithm. It does not + # honor actor.packing_algorithm="ffd"/"kk" literally; fix this with an + # explicit controller dispatch config or a safe ffd/kk controller policy. + return "ffd_equal" def _dispatch_tensors( item_list: list[dict[str, Any]], dp_size: int, group_size: int = 1, + packing_algorithm: str = "ffd_equal", ) -> tuple[list[list[dict[str, Any]]], list[list[int]]]: """Partition trajectories across DP groups by balanced token count. Args: + packing_algorithm: controller dispatch allocation algorithm. ``"ffd_equal"`` + is the default because controller dispatch historically used + ``balanced_greedy_partition`` for tensor-like eval inputs. External + rollout config values ``"ffd"`` and ``"kk"`` currently keep this + equal-cardinality behavior; see TODO in + ``_controller_allocation_algorithm``. group_size: number of consecutive items that form an atomic dispatch unit (e.g. 2 for chosen/rejected RW pairs). Groups are never split across DP ranks. ``group_size=1`` degenerates to per-item partitioning. """ - n = len(item_list) - if n % group_size != 0: - raise ValueError( - f"item count ({n}) must be divisible by group_size ({group_size})" + allocation = allocate_trajectories( + AllocationInput( + items=item_list, + n_groups=dp_size, + algorithm=_controller_allocation_algorithm(packing_algorithm), + group_size=group_size, ) - - token_weights = [_item_weight(d) for d in item_list] - n_groups = n // group_size - - group_weights = [ - sum(token_weights[g * group_size + k] for k in range(group_size)) - for g in range(n_groups) + ) + if allocation.metrics is not None: + stats_tracker.scalar(**allocation.metrics.to_stats()) + splits = [ + [allocation.items[idx] for idx in group_indices] + for group_indices in allocation.group_indices ] - gpart = balanced_greedy_partition(group_weights, K=dp_size) - - group_indices: list[list[int]] = [] - splits: list[list[dict[str, Any]]] = [] - for gidxs in gpart: - item_idxs: list[int] = [] - items: list[dict[str, Any]] = [] - for g in gidxs: - for k in range(group_size): - idx = g * group_size + k - item_idxs.append(idx) - items.append(item_list[idx]) - group_indices.append(item_idxs) - splits.append(items) - assert all(len(s) % group_size == 0 for s in splits), ( f"Post-dispatch invariant violated: shard sizes " f"{[len(s) for s in splits]} not all divisible by group_size={group_size}" ) - return splits, group_indices + return splits, allocation.group_indices def _pad_eval_batch( @@ -535,7 +524,10 @@ def _split(item: Any) -> list[Any]: if _is_tensor_like(item): if group_indices is None: splits, group_indices = _dispatch_tensors( - item, dp_size, group_size=group_size + item, + dp_size, + group_size=group_size, + packing_algorithm=self.config.packing_algorithm, ) return splits return [[item[i] for i in idxs] for idxs in group_indices] diff --git a/areal/infra/dist_rollout.py b/areal/infra/dist_rollout.py index 28e363066f..0a41a2b6ff 100644 --- a/areal/infra/dist_rollout.py +++ b/areal/infra/dist_rollout.py @@ -8,14 +8,15 @@ from torchdata.stateful_dataloader import StatefulDataLoader from areal.api import InferenceEngine, TrainEngine, WorkflowLike +from areal.infra.dp_allocation import AllocationInput, allocate_trajectories from areal.infra.platforms import current_platform +from areal.utils import stats_tracker from areal.utils.data import ( all_gather_tensor_container, broadcast_tensor_container, split_and_unpad_tensor, tensor_container_to, ) -from areal.utils.seqpack import get_allocate_fn @dataclass @@ -24,6 +25,7 @@ class RedistributedData: data: list[dict[str, Any]] rank: int group_indices: list[list[int]] + dta_metrics: Any | None = None def redistribute_trajectories( @@ -45,7 +47,9 @@ def redistribute_trajectories( group : dist.ProcessGroup, optional The process group for communication. If None, uses the default group. packing_algorithm : str, optional - Packing algorithm to use ("ffd" or "kk"). Default is "ffd". + How to pack trajectories across data-parallel ranks: ``"ffd"`` or ``"kk"`` + balance by total sequence length; ``"dta"`` uses DTA DFS-order partitioning + with ``n_tree_tokens`` as cost. Default ``"ffd"``. Returns ------- @@ -64,9 +68,6 @@ def redistribute_trajectories( for traj_list in all_gathered: all_data.extend(traj_list) - # Compute sequence lengths for load balancing - seqlens = [d["attention_mask"].sum().item() for d in all_data] - # Remove pad positions from each trajectory (split_and_unpad_tensor # auto-derives trim lengths from attention_mask when traj_seqlens=None) all_data = [ @@ -76,21 +77,24 @@ def redistribute_trajectories( for d in all_data ] - allocate_fn = get_allocate_fn(packing_algorithm) - # Allocate trajectories to ranks using the configured packing algorithm - # No capacity limit leads to balanced partition across this group - group_indices = allocate_fn( - seqlens, capacity=int(1e12), min_groups=dist.get_world_size(group) + n_groups = dist.get_world_size(group) + allocation = allocate_trajectories( + AllocationInput( + items=all_data, + n_groups=n_groups, + algorithm=packing_algorithm, + ) ) - local_indices = group_indices[dist.get_rank(group=group)] # Select assigned trajectories for this rank (no concatenation — deferred to train side) - data = [all_data[i] for i in local_indices] + local_indices = allocation.group_indices[dist.get_rank(group=group)] + data = [allocation.items[i] for i in local_indices] return RedistributedData( - all_data=all_data, + all_data=allocation.items, data=data, rank=dist.get_rank(group=group), - group_indices=group_indices, + group_indices=allocation.group_indices, + dta_metrics=allocation.metrics, ) @@ -122,22 +126,32 @@ def _broadcast_and_redistribute_trajectories( list[dict[str, Any]] Redistributed and broadcast batch available on all ranks (list of trajs) """ + rollout_packing = self.train_engine.config.packing_algorithm + if trajectories is not None: - config = getattr(self.train_engine, "config", None) - mb_spec = getattr(config, "mb_spec", None) - packing_algorithm = getattr(mb_spec, "packing_algorithm", "ffd") redist = redistribute_trajectories( trajectories, group=self.train_engine.data_parallel_group, - packing_algorithm=packing_algorithm, + packing_algorithm=rollout_packing, ) batch = redist.data + dta_metrics_payload = [redist.dta_metrics] else: batch = None + dta_metrics_payload = [None] current_platform.synchronize() dist.barrier(group=self.train_engine.cpu_group) + dist.broadcast_object_list( + dta_metrics_payload, + src=self.train_engine.current_data_parallel_head(), + group=self.train_engine.context_and_model_parallel_group, + ) + dta_metrics = dta_metrics_payload[0] + if dta_metrics is not None: + stats_tracker.scalar(**dta_metrics.to_stats()) + batch = broadcast_tensor_container( batch, src_rank=self.train_engine.current_data_parallel_head(), diff --git a/areal/infra/dp_allocation.py b/areal/infra/dp_allocation.py new file mode 100644 index 0000000000..ba698eab89 --- /dev/null +++ b/areal/infra/dp_allocation.py @@ -0,0 +1,172 @@ +# SPDX-License-Identifier: Apache-2.0 + +from dataclasses import dataclass +from typing import Any, Protocol + +import torch + +from areal.infra.rpc.rtensor import RTensor +from areal.utils.seqpack import balanced_greedy_partition, ffd_allocate, kk_allocate + + +@dataclass(slots=True) +class AllocationInput: + """DP allocation request for trajectory-like items. + + ``group_size`` binds adjacent items into an atomic allocation unit before + cost-based algorithms run. ``capacity`` is used by capacity-style algorithms + (``ffd`` and ``kk``); controller dispatch uses ``ffd_equal`` to preserve + equal item counts across DP groups. + """ + + items: list[dict[str, Any]] + n_groups: int + algorithm: str + group_size: int = 1 + capacity: int = int(1e12) + + +@dataclass(slots=True) +class AllocationOutput: + """DP allocation result for trajectory-like items. + + ``group_indices`` always index ``items`` on this output object. DTA may + normalize grouped trajectories into sequence-level items before allocating. + """ + + items: list[dict[str, Any]] + group_indices: list[list[int]] + metrics: Any | None = None + + +@dataclass(slots=True) +class _AtomicUnit: + indices: list[int] + cost: int + + +def _item_weight(item: dict[str, Any]) -> int: + attn_mask = item.get("attention_mask") + if isinstance(attn_mask, torch.Tensor): + return int(attn_mask.sum().item()) + if isinstance(attn_mask, RTensor): + return attn_mask.data.numel() + for value in item.values(): + if isinstance(value, RTensor): + return value.data.numel() + if isinstance(value, torch.Tensor) and value.ndim >= 2: + return value.numel() + return 1 + + +def _contains_rtensor(obj: Any) -> bool: + if isinstance(obj, RTensor): + return True + if isinstance(obj, dict): + return any(_contains_rtensor(value) for value in obj.values()) + if isinstance(obj, (list, tuple)): + return any(_contains_rtensor(item) for item in obj) + return False + + +def _make_atomic_units( + items: list[dict[str, Any]], group_size: int +) -> list[_AtomicUnit]: + if group_size <= 0: + raise ValueError(f"group_size must be positive, got {group_size}.") + if len(items) % group_size != 0: + raise ValueError( + f"item count ({len(items)}) must be divisible by group_size ({group_size})" + ) + + units: list[_AtomicUnit] = [] + for group_start in range(0, len(items), group_size): + indices = list(range(group_start, group_start + group_size)) + cost = sum(_item_weight(items[idx]) for idx in indices) + units.append(_AtomicUnit(indices=indices, cost=cost)) + return units + + +def _ffd_allocate(req: AllocationInput) -> AllocationOutput: + units = _make_atomic_units(req.items, req.group_size) + costs = [unit.cost for unit in units] + unit_groups = ffd_allocate(costs, capacity=req.capacity, min_groups=req.n_groups) + group_indices = [ + [idx for unit_idx in unit_group for idx in units[unit_idx].indices] + for unit_group in unit_groups + ] + return AllocationOutput(items=req.items, group_indices=group_indices) + + +def _kk_allocate(req: AllocationInput) -> AllocationOutput: + units = _make_atomic_units(req.items, req.group_size) + costs = [unit.cost for unit in units] + unit_groups = kk_allocate(costs, capacity=req.capacity, min_groups=req.n_groups) + group_indices = [ + [idx for unit_idx in unit_group for idx in units[unit_idx].indices] + for unit_group in unit_groups + ] + return AllocationOutput(items=req.items, group_indices=group_indices) + + +def _ffd_equal_allocate(req: AllocationInput) -> AllocationOutput: + units = _make_atomic_units(req.items, req.group_size) + costs = [unit.cost for unit in units] + unit_groups = balanced_greedy_partition(costs, K=req.n_groups) + group_indices = [ + [idx for unit_idx in unit_group for idx in units[unit_idx].indices] + for unit_group in unit_groups + ] + return AllocationOutput(items=req.items, group_indices=group_indices) + + +def _dta_allocate(req: AllocationInput) -> AllocationOutput: + if req.group_size != 1: + raise ValueError( + "packing_algorithm='dta' is incompatible with group_size > 1. " + "DTA requires sequence-level independence." + ) + from areal.experimental.dta.allocation import allocate_dta_trajectories + + items = req.items + if _contains_rtensor(items): + # TODO(agent): This controller-side localization can become a bottleneck. + items = RTensor.localize(items) + + dta_allocation = allocate_dta_trajectories(items, n_groups=req.n_groups) + return AllocationOutput( + items=dta_allocation.items, + group_indices=dta_allocation.group_indices, + metrics=dta_allocation.metrics, + ) + + +class _TrajectoryAllocateFn(Protocol): + def __call__(self, req: AllocationInput) -> AllocationOutput: ... + + +_TRAJECTORY_ALLOCATE_FNS: dict[str, _TrajectoryAllocateFn] = { + "ffd": _ffd_allocate, + "kk": _kk_allocate, + "ffd_equal": _ffd_equal_allocate, + "dta": _dta_allocate, +} + + +def get_dp_allocate_fn(algorithm: str) -> _TrajectoryAllocateFn: + """Return the DP allocation adapter for a rollout packing algorithm.""" + try: + return _TRAJECTORY_ALLOCATE_FNS[algorithm] + except KeyError as err: + raise ValueError( + f"Unknown trajectory packing algorithm '{algorithm}'. " + f"Supported algorithms: {sorted(_TRAJECTORY_ALLOCATE_FNS)}" + ) from err + + +def allocate_trajectories(req: AllocationInput) -> AllocationOutput: + """Allocate trajectory-like items across data-parallel groups. + + ``group_indices`` in the returned object always index ``AllocationOutput.items``. + """ + return get_dp_allocate_fn(req.algorithm)(req) diff --git a/areal/models/tree_attn/module_archon.py b/areal/models/tree_attn/module_archon.py index d039228e69..49140bf109 100644 --- a/areal/models/tree_attn/module_archon.py +++ b/areal/models/tree_attn/module_archon.py @@ -93,6 +93,7 @@ def forward( cu_seqlens: torch.Tensor, max_seqlen: int, tree_attn_meta: TreeAttentionMeta | None = None, + cu_seqlens_k: torch.Tensor | None = None, ) -> torch.Tensor: """Compute tree attention. @@ -107,6 +108,8 @@ def forward( kept for API compatibility with VarlenAttentionWrapper). tree_attn_meta: Tree attention metadata containing either a BlockMask (flex attention) or TreeAttentionData (Triton). + cu_seqlens_k: Unused. Accepted for interface compatibility with + VarlenAttentionWrapper. Returns: Attention output, shape [batch, heads, seq_len, head_dim] diff --git a/areal/trainer/ppo/actor.py b/areal/trainer/ppo/actor.py index 07944a31a5..12efd03fde 100644 --- a/areal/trainer/ppo/actor.py +++ b/areal/trainer/ppo/actor.py @@ -366,6 +366,18 @@ def _ppo_update(self, data: dict[str, Any]) -> None: class PPOActorController(TrainController): + def _prepare_rollout_batch( + self, batch: list[dict[str, Any]] + ) -> list[dict[str, Any]]: + if self.config.packing_algorithm != "dta": + return batch + from areal.experimental.dta.rollout import prepare_dta_rollout_batch + + return prepare_dta_rollout_batch(batch) + + def prepare_batch(self, *args, **kwargs) -> list[dict[str, Any]]: + return self._prepare_rollout_batch(super().prepare_batch(*args, **kwargs)) + def compute_logp(self, *args, **kwargs): return self._custom_function_call( "compute_logp", *args, rpc_meta={"broadcast": True}, **kwargs diff --git a/areal/utils/data.py b/areal/utils/data.py index 09368e4ef2..0dab956268 100644 --- a/areal/utils/data.py +++ b/areal/utils/data.py @@ -47,6 +47,62 @@ def get_batch_size(data: dict[str, Any]) -> int: return 0 +def extract_valid_token_sequences( + input_ids_batch: torch.Tensor, + attention_mask: torch.Tensor, +) -> tuple[list[torch.Tensor], int]: + """Extract unpadded token sequences from a [B, S] batch.""" + if not (torch.is_tensor(input_ids_batch) and torch.is_tensor(attention_mask)): + raise TypeError("input_ids_batch and attention_mask must be torch.Tensor.") + if input_ids_batch.ndim != 2 or attention_mask.ndim != 2: + raise ValueError("input_ids_batch and attention_mask must be rank-2 tensors.") + if input_ids_batch.shape != attention_mask.shape: + raise ValueError( + "input_ids_batch and attention_mask must have identical shapes." + ) + + max_seq_len = 0 + input_ids_list: list[torch.Tensor] = [] + for i in range(input_ids_batch.shape[0]): + valid_length = int(attention_mask[i].sum().item()) + max_seq_len = max(max_seq_len, valid_length) + input_ids_list.append(input_ids_batch[i, :valid_length]) + return input_ids_list, max_seq_len + + +def extract_single_valid_token_sequence( + trajectory: dict[str, Any], +) -> torch.Tensor: + """Extract one unpadded token sequence from a trajectory dict. + + Raises + ------ + ValueError + If required fields are missing, malformed, or trajectory batch size is not 1. + """ + if "input_ids" not in trajectory or "attention_mask" not in trajectory: + raise ValueError( + "trajectory must contain both 'input_ids' and 'attention_mask'." + ) + + input_ids = trajectory["input_ids"] + attention_mask = trajectory["attention_mask"] + seqs, _ = extract_valid_token_sequences(input_ids, attention_mask) + if len(seqs) != 1: + raise ValueError( + f"trajectory must contain exactly one sequence, got {len(seqs)}." + ) + return seqs[0] + + +def get_total_valid_tokens(trajectory: dict[str, Any]) -> int: + """Return total valid token count inferred from attention_mask when available.""" + attention_mask = trajectory.get("attention_mask") + if torch.is_tensor(attention_mask): + return int(attention_mask.sum().item()) + return 0 + + def reorder_list(xs: Sequence, indices: list[int]) -> list: assert len(set(indices)) == len(xs) return [xs[i] for i in indices] @@ -357,6 +413,46 @@ def split_and_unpad_tensor( return result +def unpack_groups_to_sequences( + item_list: list[dict[str, Any]], +) -> list[dict[str, Any]]: + """Flatten grouped trajectories into fully independent sequence-level dicts. + + For example, if an item in item_list has shape [8, seq_len, ...] for 8 samples + (group_size=8), it will be split into 8 separate dictionaries, each with + shape [1, seq_len, ...]. This is required for algorithms like DTA that operate + on individual sequences rather than groups. + + Args: + item_list: List of trajectory dictionaries. + + Returns: + A new list where every dictionary represents a single sequence. + """ + flat_item_list = [] + for item in item_list: + attn_mask = item.get("attention_mask") + if ( + attn_mask is not None + and isinstance(attn_mask, torch.Tensor) + and attn_mask.ndim >= 2 + ): + n_seqs = attn_mask.shape[0] + if n_seqs > 1: + splits = split_and_unpad_tensor( + item, n_trajs=n_seqs, traj_group_sizes=1 + ) + if isinstance(splits, list): + flat_item_list.extend(splits) + else: + flat_item_list.append(splits) + else: + flat_item_list.append(item) + else: + flat_item_list.append(item) + return flat_item_list + + @dataclass class TrajBatchMeta: """Metadata for reversing concat_batch: traj counts, group sizes, seqlens.""" @@ -446,9 +542,13 @@ def unpack_sequence( def allocate_balanced_mbs(mb_spec: MicroBatchSpec, lens: list[int]) -> list[list[int]]: - """Allocate sequences into balanced micro-batches using the configured algorithm. + """Allocate sequence costs into balanced micro-batch groups. - The packing algorithm is determined by ``mb_spec.packing_algorithm``: + This is the low-level cost-packing path used by micro-batch materialization. + It operates on integer costs and is separate from trajectory-level allocation + in ``areal.infra.dp_allocation``. + + The cost-packing algorithm is determined by ``mb_spec.packing_algorithm``: - ``"ffd"`` (default): First Fit Decreasing — fast greedy heuristic. - ``"kk"``: Karmarkar-Karp — produces more balanced partitions at a slight computational cost. @@ -697,6 +797,7 @@ def split_padded_tensor_dict_into_mb_list( data: dict[str, Any], mb_spec: MicroBatchSpec, group: dist.ProcessGroup | None = None, + sync_mbs: bool = True, ) -> MicroBatchList: """Split a padded dict of tensors into micro-batches based on the attention mask. @@ -704,6 +805,7 @@ def split_padded_tensor_dict_into_mb_list( data (Dict): Dictionary containing padded tensors. mb_spec (MicroBatchSpec): Specification for micro-batch splitting. group (Optional[dist.ProcessGroup]): Process group for distributed synchronization. + sync_mbs: Whether to synchronize the number of micro-batches across ranks. Returns: MicroBatchList: A structure containing the split micro-batches and metadata. @@ -748,7 +850,10 @@ def split_padded_tensor_dict_into_mb_list( not_to_split[key] = value # split - group_indices = allocate_balanced_mbs_synced(mb_spec, input_lens, group=group) + if sync_mbs: + group_indices = allocate_balanced_mbs_synced(mb_spec, input_lens, group=group) + else: + group_indices = allocate_balanced_mbs(mb_spec, input_lens) group_indices = [ seqpack.flat2d( [list(range(i * granularity, (i + 1) * granularity)) for i in group_index] diff --git a/areal/utils/logging.py b/areal/utils/logging.py index 94e5839a6e..11aff6bf15 100644 --- a/areal/utils/logging.py +++ b/areal/utils/logging.py @@ -87,6 +87,7 @@ "TreeAttentionCore": "light_cyan", "TreeAttentionConstants": "light_cyan", "TreeAttentionViz": "light_cyan", + "DTA": "light_cyan", # Checkpoint - blue (infrastructure) "Saver": "blue", "AsyncCheckpoint": "blue", diff --git a/areal/utils/seqpack.py b/areal/utils/seqpack.py index 175461771f..80a87b0f2b 100644 --- a/areal/utils/seqpack.py +++ b/areal/utils/seqpack.py @@ -158,20 +158,26 @@ def reorder_to_balanced_batches( # Packing Algorithm Registry # ============================================================================= -# Supported packing algorithm names (used in MicroBatchSpec.packing_algorithm) +# Supported cost-packing algorithm names used by MicroBatchSpec.packing_algorithm. +# This registry operates on integer costs and intentionally does not include DTA; +# trajectory-level allocation lives in areal.infra.dp_allocation. PACKING_ALGORITHM_FFD = "ffd" PACKING_ALGORITHM_KK = "kk" -PACKING_ALGORITHMS = {PACKING_ALGORITHM_FFD, PACKING_ALGORITHM_KK} +PACKING_ALGORITHMS = { + PACKING_ALGORITHM_FFD, + PACKING_ALGORITHM_KK, +} def get_allocate_fn(algorithm: str = PACKING_ALGORITHM_FFD): - """Return the allocation function for the given algorithm name. + """Return a cost allocator for micro-batch sequence packing. Args: algorithm: One of ``"ffd"`` or ``"kk"``. Returns: - The corresponding allocation function (``ffd_allocate`` or ``kk_allocate``). + A function with signature ``(values, capacity, min_groups, + n_groups_divisor=1) -> group_indices``. Raises: ValueError: If the algorithm name is not recognized. diff --git a/docs/en/cli_reference.md b/docs/en/cli_reference.md index 63fad21047..6eddb7a9bb 100644 --- a/docs/en/cli_reference.md +++ b/docs/en/cli_reference.md @@ -370,7 +370,9 @@ Configuration for PPO actor model, a subclass of a TrainEngine. | `lora_alpha` | integer | `16` | lora alpha | | `target_modules` | list of string | **Required** | lora target_modules. | | `peft_type` | string | `"lora"` | peft method type. Only LoRA is supported for now. | -| `enable_tree_training` | boolean | `False` | Enable tree training with flex attention module. | +| `tree_training_mode` | string | `"disabled"` | Tree training mode. 'sparse' enables tree training with Flex Attention module (flex attention), 'dta' enables Dynamic Tree Attention (dynamic tree training), 'disabled' disables tree training. **Choices:** `disabled`, `sparse`, `dta` | +| `dta_block_size` | integer | `2048` | Block size for Dynamic Tree Attention. Set to -1 to disable block-size limit. Only effective when tree_training_mode='dta'. | +| `packing_algorithm` | string | `"ffd"` | Trajectory packing across data-parallel ranks during distributed rollout (`redistribute_trajectories`). 'ffd' / 'kk' balance by total sequence length; 'dta' uses DTA DFS-order n_tree_tokens. Not to be confused with `mb_spec.packing_algorithm`, which only controls micro-batch formation (ffd/kk) during training. **Choices:** `ffd`, `kk`, `dta` | | `scheduling_spec` | `tuple` | **Required** | Train engine schedule specs. Can accept 1 or 2 SchedulingSpec: if 1 spec provided, it's used for both worker and engine, engine is embedded in the worker; if 2 specs provided, first one is for worker, second one is for engine. Currently only used by the TrainController. | | `backend` | string | **Required** | Backend and parallelism strategy. Must include an explicit backend prefix, e.g. 'fsdp:d4', 'megatron:d4t2p2', 'archon:d2'. Required. | | `_version` | string | `"v1"` | Train controller implementation version. Use 'v1' for legacy TrainController, 'v2' for GatewayTrainController. **Choices:** `v1`, `v2` | @@ -443,7 +445,9 @@ Configuration for PPO critic model, a subclass of a TrainEngine. | `lora_alpha` | integer | `16` | lora alpha | | `target_modules` | list of string | **Required** | lora target_modules. | | `peft_type` | string | `"lora"` | peft method type. Only LoRA is supported for now. | -| `enable_tree_training` | boolean | `False` | Enable tree training with flex attention module. | +| `tree_training_mode` | string | `"disabled"` | Tree training mode. 'sparse' enables tree training with Flex Attention module (flex attention), 'dta' enables Dynamic Tree Attention (dynamic tree training), 'disabled' disables tree training. **Choices:** `disabled`, `sparse`, `dta` | +| `dta_block_size` | integer | `2048` | Block size for Dynamic Tree Attention. Set to -1 to disable block-size limit. Only effective when tree_training_mode='dta'. | +| `packing_algorithm` | string | `"ffd"` | Trajectory packing across data-parallel ranks during distributed rollout (`redistribute_trajectories`). 'ffd' / 'kk' balance by total sequence length; 'dta' uses DTA DFS-order n_tree_tokens. Not to be confused with `mb_spec.packing_algorithm`, which only controls micro-batch formation (ffd/kk) during training. **Choices:** `ffd`, `kk`, `dta` | | `scheduling_spec` | `tuple` | **Required** | Train engine schedule specs. Can accept 1 or 2 SchedulingSpec: if 1 spec provided, it's used for both worker and engine, engine is embedded in the worker; if 2 specs provided, first one is for worker, second one is for engine. Currently only used by the TrainController. | | `backend` | string | **Required** | Backend and parallelism strategy. Must include an explicit backend prefix, e.g. 'fsdp:d4', 'megatron:d4t2p2', 'archon:d2'. Required. | | `_version` | string | `"v1"` | Train controller implementation version. Use 'v1' for legacy TrainController, 'v2' for GatewayTrainController. **Choices:** `v1`, `v2` | @@ -490,7 +494,9 @@ Core configuration for model training, including optimization and backend settin | `lora_alpha` | integer | `16` | lora alpha | | `target_modules` | list of string | **Required** | lora target_modules. | | `peft_type` | string | `"lora"` | peft method type. Only LoRA is supported for now. | -| `enable_tree_training` | boolean | `False` | Enable tree training with flex attention module. | +| `tree_training_mode` | string | `"disabled"` | Tree training mode. 'sparse' enables tree training with Flex Attention module (flex attention), 'dta' enables Dynamic Tree Attention (dynamic tree training), 'disabled' disables tree training. **Choices:** `disabled`, `sparse`, `dta` | +| `dta_block_size` | integer | `2048` | Block size for Dynamic Tree Attention. Set to -1 to disable block-size limit. Only effective when tree_training_mode='dta'. | +| `packing_algorithm` | string | `"ffd"` | Trajectory packing across data-parallel ranks during distributed rollout (`redistribute_trajectories`). 'ffd' / 'kk' balance by total sequence length; 'dta' uses DTA DFS-order n_tree_tokens. Not to be confused with `mb_spec.packing_algorithm`, which only controls micro-batch formation (ffd/kk) during training. **Choices:** `ffd`, `kk`, `dta` | | `scheduling_spec` | `tuple` | **Required** | Train engine schedule specs. Can accept 1 or 2 SchedulingSpec: if 1 spec provided, it's used for both worker and engine, engine is embedded in the worker; if 2 specs provided, first one is for worker, second one is for engine. Currently only used by the TrainController. | | `backend` | string | **Required** | Backend and parallelism strategy. Must include an explicit backend prefix, e.g. 'fsdp:d4', 'megatron:d4t2p2', 'archon:d2'. Required. | | `_version` | string | `"v1"` | Train controller implementation version. Use 'v1' for legacy TrainController, 'v2' for GatewayTrainController. **Choices:** `v1`, `v2` | @@ -996,7 +1002,9 @@ fields. | `lora_alpha` | integer | `16` | lora alpha | | `target_modules` | list of string | **Required** | lora target_modules. | | `peft_type` | string | `"lora"` | peft method type. Only LoRA is supported for now. | -| `enable_tree_training` | boolean | `False` | Enable tree training with flex attention module. | +| `tree_training_mode` | string | `"disabled"` | Tree training mode. 'sparse' enables tree training with Flex Attention module (flex attention), 'dta' enables Dynamic Tree Attention (dynamic tree training), 'disabled' disables tree training. **Choices:** `disabled`, `sparse`, `dta` | +| `dta_block_size` | integer | `2048` | Block size for Dynamic Tree Attention. Set to -1 to disable block-size limit. Only effective when tree_training_mode='dta'. | +| `packing_algorithm` | string | `"ffd"` | Trajectory packing across data-parallel ranks during distributed rollout (`redistribute_trajectories`). 'ffd' / 'kk' balance by total sequence length; 'dta' uses DTA DFS-order n_tree_tokens. Not to be confused with `mb_spec.packing_algorithm`, which only controls micro-batch formation (ffd/kk) during training. **Choices:** `ffd`, `kk`, `dta` | | `scheduling_spec` | `tuple` | **Required** | Train engine schedule specs. Can accept 1 or 2 SchedulingSpec: if 1 spec provided, it's used for both worker and engine, engine is embedded in the worker; if 2 specs provided, first one is for worker, second one is for engine. Currently only used by the TrainController. | | `backend` | string | **Required** | Backend and parallelism strategy. Must include an explicit backend prefix, e.g. 'fsdp:d4', 'megatron:d4t2p2', 'archon:d2'. Required. | | `_version` | string | `"v1"` | Train controller implementation version. Use 'v1' for legacy TrainController, 'v2' for GatewayTrainController. **Choices:** `v1`, `v2` | @@ -1261,7 +1269,9 @@ Configuration class: TeacherConfig | `lora_alpha` | integer | `16` | lora alpha | | `target_modules` | list of string | **Required** | lora target_modules. | | `peft_type` | string | `"lora"` | peft method type. Only LoRA is supported for now. | -| `enable_tree_training` | boolean | `False` | Enable tree training with flex attention module. | +| `tree_training_mode` | string | `"disabled"` | Tree training mode. 'sparse' enables tree training with Flex Attention module (flex attention), 'dta' enables Dynamic Tree Attention (dynamic tree training), 'disabled' disables tree training. **Choices:** `disabled`, `sparse`, `dta` | +| `dta_block_size` | integer | `2048` | Block size for Dynamic Tree Attention. Set to -1 to disable block-size limit. Only effective when tree_training_mode='dta'. | +| `packing_algorithm` | string | `"ffd"` | Trajectory packing across data-parallel ranks during distributed rollout (`redistribute_trajectories`). 'ffd' / 'kk' balance by total sequence length; 'dta' uses DTA DFS-order n_tree_tokens. Not to be confused with `mb_spec.packing_algorithm`, which only controls micro-batch formation (ffd/kk) during training. **Choices:** `ffd`, `kk`, `dta` | | `scheduling_spec` | `tuple` | **Required** | Train engine schedule specs. Can accept 1 or 2 SchedulingSpec: if 1 spec provided, it's used for both worker and engine, engine is embedded in the worker; if 2 specs provided, first one is for worker, second one is for engine. Currently only used by the TrainController. | | `backend` | string | **Required** | Backend and parallelism strategy. Must include an explicit backend prefix, e.g. 'fsdp:d4', 'megatron:d4t2p2', 'archon:d2'. Required. | | `_version` | string | `"v1"` | Train controller implementation version. Use 'v1' for legacy TrainController, 'v2' for GatewayTrainController. **Choices:** `v1`, `v2` | diff --git a/docs/en/reference/tree_training.md b/docs/en/reference/tree_training.md index 4ce5587e18..c97d3427df 100644 --- a/docs/en/reference/tree_training.md +++ b/docs/en/reference/tree_training.md @@ -32,11 +32,11 @@ FLOPs by up to **10x** and achieves up to **7x** acceleration. ### Enabling Tree Training -Enable tree training via the `enable_tree_training` option in `TrainEngineConfig`: +Enable tree training via the `tree_training_mode` option in `TrainEngineConfig`: ```yaml actor: - enable_tree_training: true + tree_training_mode: sparse pad_to_maximum: true # Must be set to true for tree training mb_spec: max_tokens_per_mb: 8192 # Must be set for tree training @@ -44,15 +44,41 @@ actor: ### Required Configuration -| Parameter | Type | Required | Description | -| --------------------------- | ---- | -------- | ---------------------------------- | -| `enable_tree_training` | bool | Yes | Enable tree-based sequence packing | -| `pad_to_maximum` | bool | Yes | Must be `true` for tree training | -| `mb_spec.max_tokens_per_mb` | int | Yes | Max tokens per tree (must be set) | +| Parameter | Type | Required | Description | +| --------------------------- | ---- | -------- | ------------------------------------------------ | +| `tree_training_mode` | str | Yes | `sparse` for sparse tree training, `dta` for DTA | +| `pad_to_maximum` | bool | Yes | Must be `true` for sparse tree training | +| `mb_spec.max_tokens_per_mb` | int | Yes | Max tokens per tree (must be set) | NOTE: When tree training is enabled `max_tokens_per_mb` must be a multiple of `BLOCK_SIZE` (128). +### Dynamic Tree Attention + +Dynamic Tree Attention (DTA) is enabled with `tree_training_mode: dta`. For rollout +redistribution across data-parallel ranks, set the train engine `packing_algorithm` to +`dta`: + +```yaml +actor: + tree_training_mode: dta + packing_algorithm: dta +``` + +This option is separate from `mb_spec.packing_algorithm`. The rollout-level +`packing_algorithm` controls trajectory allocation across data-parallel ranks, while +`mb_spec.packing_algorithm` controls micro-batch formation inside a training step. DTA +allocation is implemented under `areal/experimental/dta` and is exposed through the +trajectory-level allocation wrapper in `areal.infra.dp_allocation`, so core rollout code +uses the same allocation interface as `ffd` and `kk`. + +DTA allocation flattens grouped rollout trajectories to sequence-level items before +partitioning. It reports `dta/*` metrics such as tree tokens before and after allocation +and compression ratios. The sequence-level preparation, allocation validation, and +metric computation live in `areal.experimental.dta.allocation`; shared rollout and +controller code only call the experimental helper and keep algorithm-specific details +out of the core data path. + ## Implementation ### Tree Building Process diff --git a/docs/zh/cli_reference.md b/docs/zh/cli_reference.md index 79cd7e754c..ad1133097b 100644 --- a/docs/zh/cli_reference.md +++ b/docs/zh/cli_reference.md @@ -368,7 +368,9 @@ Configuration for PPO actor model, a subclass of a TrainEngine. | `lora_alpha` | integer | `16` | lora alpha | | `target_modules` | list of string | **Required** | lora target_modules. | | `peft_type` | string | `"lora"` | peft method type. Only LoRA is supported for now. | -| `enable_tree_training` | boolean | `False` | Enable tree training with flex attention module. | +| `tree_training_mode` | string | `"disabled"` | Tree training mode. 'sparse' enables tree training with Flex Attention module (flex attention), 'dta' enables Dynamic Tree Attention (dynamic tree training), 'disabled' disables tree training. **Choices:** `disabled`, `sparse`, `dta` | +| `dta_block_size` | integer | `2048` | Block size for Dynamic Tree Attention. Set to -1 to disable block-size limit. Only effective when tree_training_mode='dta'. | +| `packing_algorithm` | string | `"ffd"` | Trajectory packing across data-parallel ranks during distributed rollout (`redistribute_trajectories`). 'ffd' / 'kk' balance by total sequence length; 'dta' uses DTA DFS-order n_tree_tokens. Not to be confused with `mb_spec.packing_algorithm`, which only controls micro-batch formation (ffd/kk) during training. **Choices:** `ffd`, `kk`, `dta` | | `scheduling_spec` | `tuple` | **Required** | Train engine schedule specs. Can accept 1 or 2 SchedulingSpec: if 1 spec provided, it's used for both worker and engine, engine is embedded in the worker; if 2 specs provided, first one is for worker, second one is for engine. Currently only used by the TrainController. | | `backend` | string | **Required** | Backend and parallelism strategy. Must include an explicit backend prefix, e.g. 'fsdp:d4', 'megatron:d4t2p2', 'archon:d2'. Required. | | `_version` | string | `"v1"` | Train controller implementation version. Use 'v1' for legacy TrainController, 'v2' for GatewayTrainController. **Choices:** `v1`, `v2` | @@ -441,7 +443,9 @@ Configuration for PPO critic model, a subclass of a TrainEngine. | `lora_alpha` | integer | `16` | lora alpha | | `target_modules` | list of string | **Required** | lora target_modules. | | `peft_type` | string | `"lora"` | peft method type. Only LoRA is supported for now. | -| `enable_tree_training` | boolean | `False` | Enable tree training with flex attention module. | +| `tree_training_mode` | string | `"disabled"` | Tree training mode. 'sparse' enables tree training with Flex Attention module (flex attention), 'dta' enables Dynamic Tree Attention (dynamic tree training), 'disabled' disables tree training. **Choices:** `disabled`, `sparse`, `dta` | +| `dta_block_size` | integer | `2048` | Block size for Dynamic Tree Attention. Set to -1 to disable block-size limit. Only effective when tree_training_mode='dta'. | +| `packing_algorithm` | string | `"ffd"` | Trajectory packing across data-parallel ranks during distributed rollout (`redistribute_trajectories`). 'ffd' / 'kk' balance by total sequence length; 'dta' uses DTA DFS-order n_tree_tokens. Not to be confused with `mb_spec.packing_algorithm`, which only controls micro-batch formation (ffd/kk) during training. **Choices:** `ffd`, `kk`, `dta` | | `scheduling_spec` | `tuple` | **Required** | Train engine schedule specs. Can accept 1 or 2 SchedulingSpec: if 1 spec provided, it's used for both worker and engine, engine is embedded in the worker; if 2 specs provided, first one is for worker, second one is for engine. Currently only used by the TrainController. | | `backend` | string | **Required** | Backend and parallelism strategy. Must include an explicit backend prefix, e.g. 'fsdp:d4', 'megatron:d4t2p2', 'archon:d2'. Required. | | `_version` | string | `"v1"` | Train controller implementation version. Use 'v1' for legacy TrainController, 'v2' for GatewayTrainController. **Choices:** `v1`, `v2` | @@ -488,7 +492,9 @@ Core configuration for model training, including optimization and backend settin | `lora_alpha` | integer | `16` | lora alpha | | `target_modules` | list of string | **Required** | lora target_modules. | | `peft_type` | string | `"lora"` | peft method type. Only LoRA is supported for now. | -| `enable_tree_training` | boolean | `False` | Enable tree training with flex attention module. | +| `tree_training_mode` | string | `"disabled"` | Tree training mode. 'sparse' enables tree training with Flex Attention module (flex attention), 'dta' enables Dynamic Tree Attention (dynamic tree training), 'disabled' disables tree training. **Choices:** `disabled`, `sparse`, `dta` | +| `dta_block_size` | integer | `2048` | Block size for Dynamic Tree Attention. Set to -1 to disable block-size limit. Only effective when tree_training_mode='dta'. | +| `packing_algorithm` | string | `"ffd"` | Trajectory packing across data-parallel ranks during distributed rollout (`redistribute_trajectories`). 'ffd' / 'kk' balance by total sequence length; 'dta' uses DTA DFS-order n_tree_tokens. Not to be confused with `mb_spec.packing_algorithm`, which only controls micro-batch formation (ffd/kk) during training. **Choices:** `ffd`, `kk`, `dta` | | `scheduling_spec` | `tuple` | **Required** | Train engine schedule specs. Can accept 1 or 2 SchedulingSpec: if 1 spec provided, it's used for both worker and engine, engine is embedded in the worker; if 2 specs provided, first one is for worker, second one is for engine. Currently only used by the TrainController. | | `backend` | string | **Required** | Backend and parallelism strategy. Must include an explicit backend prefix, e.g. 'fsdp:d4', 'megatron:d4t2p2', 'archon:d2'. Required. | | `_version` | string | `"v1"` | Train controller implementation version. Use 'v1' for legacy TrainController, 'v2' for GatewayTrainController. **Choices:** `v1`, `v2` | @@ -994,7 +1000,9 @@ fields. | `lora_alpha` | integer | `16` | lora alpha | | `target_modules` | list of string | **Required** | lora target_modules. | | `peft_type` | string | `"lora"` | peft method type. Only LoRA is supported for now. | -| `enable_tree_training` | boolean | `False` | Enable tree training with flex attention module. | +| `tree_training_mode` | string | `"disabled"` | Tree training mode. 'sparse' enables tree training with Flex Attention module (flex attention), 'dta' enables Dynamic Tree Attention (dynamic tree training), 'disabled' disables tree training. **Choices:** `disabled`, `sparse`, `dta` | +| `dta_block_size` | integer | `2048` | Block size for Dynamic Tree Attention. Set to -1 to disable block-size limit. Only effective when tree_training_mode='dta'. | +| `packing_algorithm` | string | `"ffd"` | Trajectory packing across data-parallel ranks during distributed rollout (`redistribute_trajectories`). 'ffd' / 'kk' balance by total sequence length; 'dta' uses DTA DFS-order n_tree_tokens. Not to be confused with `mb_spec.packing_algorithm`, which only controls micro-batch formation (ffd/kk) during training. **Choices:** `ffd`, `kk`, `dta` | | `scheduling_spec` | `tuple` | **Required** | Train engine schedule specs. Can accept 1 or 2 SchedulingSpec: if 1 spec provided, it's used for both worker and engine, engine is embedded in the worker; if 2 specs provided, first one is for worker, second one is for engine. Currently only used by the TrainController. | | `backend` | string | **Required** | Backend and parallelism strategy. Must include an explicit backend prefix, e.g. 'fsdp:d4', 'megatron:d4t2p2', 'archon:d2'. Required. | | `_version` | string | `"v1"` | Train controller implementation version. Use 'v1' for legacy TrainController, 'v2' for GatewayTrainController. **Choices:** `v1`, `v2` | @@ -1259,7 +1267,9 @@ Configuration class: TeacherConfig | `lora_alpha` | integer | `16` | lora alpha | | `target_modules` | list of string | **Required** | lora target_modules. | | `peft_type` | string | `"lora"` | peft method type. Only LoRA is supported for now. | -| `enable_tree_training` | boolean | `False` | Enable tree training with flex attention module. | +| `tree_training_mode` | string | `"disabled"` | Tree training mode. 'sparse' enables tree training with Flex Attention module (flex attention), 'dta' enables Dynamic Tree Attention (dynamic tree training), 'disabled' disables tree training. **Choices:** `disabled`, `sparse`, `dta` | +| `dta_block_size` | integer | `2048` | Block size for Dynamic Tree Attention. Set to -1 to disable block-size limit. Only effective when tree_training_mode='dta'. | +| `packing_algorithm` | string | `"ffd"` | Trajectory packing across data-parallel ranks during distributed rollout (`redistribute_trajectories`). 'ffd' / 'kk' balance by total sequence length; 'dta' uses DTA DFS-order n_tree_tokens. Not to be confused with `mb_spec.packing_algorithm`, which only controls micro-batch formation (ffd/kk) during training. **Choices:** `ffd`, `kk`, `dta` | | `scheduling_spec` | `tuple` | **Required** | Train engine schedule specs. Can accept 1 or 2 SchedulingSpec: if 1 spec provided, it's used for both worker and engine, engine is embedded in the worker; if 2 specs provided, first one is for worker, second one is for engine. Currently only used by the TrainController. | | `backend` | string | **Required** | Backend and parallelism strategy. Must include an explicit backend prefix, e.g. 'fsdp:d4', 'megatron:d4t2p2', 'archon:d2'. Required. | | `_version` | string | `"v1"` | Train controller implementation version. Use 'v1' for legacy TrainController, 'v2' for GatewayTrainController. **Choices:** `v1`, `v2` | diff --git a/docs/zh/reference/tree_training.md b/docs/zh/reference/tree_training.md index 5b052d8322..87b895ac07 100644 --- a/docs/zh/reference/tree_training.md +++ b/docs/zh/reference/tree_training.md @@ -27,11 +27,11 @@ ### 启用树训练 -通过 `TrainEngineConfig` 中的 `enable_tree_training` 选项启用树训练: +通过 `TrainEngineConfig` 中的 `tree_training_mode` 选项启用树训练: ```yaml actor: - enable_tree_training: true + tree_training_mode: sparse pad_to_maximum: true # 树训练必须设为 true mb_spec: max_tokens_per_mb: 8192 # 树训练必须设置 @@ -39,14 +39,35 @@ actor: ### 必需配置 -| 参数 | 类型 | 必需 | 描述 | -| --------------------------- | ---- | ---- | --------------------------------- | -| `enable_tree_training` | bool | 是 | 启用基于树的序列打包 | -| `pad_to_maximum` | bool | 是 | 树训练必须设为 `true` | -| `mb_spec.max_tokens_per_mb` | int | 是 | 每棵树的最大 token 数(必须设置) | +| 参数 | 类型 | 必需 | 描述 | +| --------------------------- | ---- | ---- | --------------------------------------- | +| `tree_training_mode` | str | 是 | `sparse` 启用稀疏树训练,`dta` 启用 DTA | +| `pad_to_maximum` | bool | 是 | 稀疏树训练必须设为 `true` | +| `mb_spec.max_tokens_per_mb` | int | 是 | 每棵树的最大 token 数(必须设置) | 注意:启用树训练时,`max_tokens_per_mb` 必须是 `BLOCK_SIZE`(128)的倍数。 +### Dynamic Tree Attention + +Dynamic Tree Attention(DTA)通过 `tree_training_mode: dta` 启用。若需要在 rollout redistribution +阶段按 DTA 逻辑在数据并行 rank 间重新分配轨迹,请将训练引擎的 `packing_algorithm` 设为 `dta`: + +```yaml +actor: + tree_training_mode: dta + packing_algorithm: dta +``` + +这个选项和 `mb_spec.packing_algorithm` 不同。rollout 级别的 `packing_algorithm` 控制数据并行 rank +间的轨迹分配,而 `mb_spec.packing_algorithm` 只控制训练 step 内的 micro-batch 构造。DTA 分配实现位于 +`areal/experimental/dta`,并通过 `areal.infra.dp_allocation` 中的 trajectory-level allocation +wrapper 暴露, 因此核心 rollout 代码和 `ffd`、`kk` 使用同一个 allocation 接口。 + +DTA 分配会先把 grouped rollout trajectories 展平成 sequence-level items,再进行分区。它会记录 `dta/*` +指标,包括分配前后的 tree token 数和压缩率。sequence-level 准备、allocation validation 和指标计算都放在 +`areal.experimental.dta.allocation` 中;共享的 rollout 和 controller 代码只调用 experimental +helper, 避免把算法专属细节放进核心数据路径。 + ## 实现 ### 树构建过程 diff --git a/examples/math/gsm8k_ppo_dta.yaml b/examples/math/gsm8k_ppo_dta.yaml new file mode 100644 index 0000000000..a644317ac1 --- /dev/null +++ b/examples/math/gsm8k_ppo_dta.yaml @@ -0,0 +1,209 @@ +experiment_name: gsm8k-ppo-dta +trial_name: trial0 + +seed: 1 +enable_offload: false +total_train_epochs: 10 +tokenizer_path: ${actor.path} + +cluster: + n_nodes: 1 + n_gpus_per_node: 8 + fileroot: /tmp/areal/experiments + name_resolve: + type: nfs + nfs_record_root: /tmp/areal/name_resolve + + +scheduler: + type: null + +rollout: + backend: "sglang:d4p1t1" + experiment_name: ${experiment_name} + trial_name: ${trial_name} + max_concurrent_rollouts: 256 + queue_size: null + consumer_batch_size: ${train_dataset.batch_size} + max_head_offpolicyness: 2 + enable_rollout_tracing: false + scheduling_spec: ${actor.scheduling_spec} + fileroot: ${cluster.fileroot} + tokenizer_path: ${tokenizer_path} + dump_to_file: true + +gconfig: + n_samples: 4 + min_new_tokens: 0 + max_new_tokens: 1024 + greedy: false + temperature: 1.0 + +actor: + backend: "archon:d4p1t1" + experiment_name: ${experiment_name} + trial_name: ${trial_name} + path: Qwen/Qwen2.5-1.5B-Instruct + init_from_scratch: false + disable_dropout: true + gradient_checkpointing: false + dtype: bfloat16 + archon: + attn_type: varlen + mb_spec: + max_tokens_per_mb: 10240 + optimizer: + type: adam + lr: 1.70e-5 + weight_decay: 0.017 + beta1: 0.9 + beta2: 0.999 + eps: 1e-8 + lr_scheduler_type: constant + gradient_clipping: 1.0 + warmup_steps_proportion: 0.001 + eps_clip: 0.4 + temperature: ${gconfig.temperature} + reward_scaling: 10.0 + reward_bias: -0.5 + kl_ctl: 0.0 + ppo_n_minibatches: 1 + recompute_logprob: true + use_decoupled_loss: true + rejection_sampling: + metric: ratio + upper: 5.0 + adv_norm: + mean_level: batch + std_level: batch + max_new_tokens: ${gconfig.max_new_tokens} + tree_training_mode: dta + packing_algorithm: dta + scheduling_spec: + - task_type: worker + port_count: 2 + gpu: 1 + mem: 32 + cmd: python3 -m areal.infra.rpc.rpc_server + env_vars: {} + +critic: + backend: ${actor.backend} + is_critic: true + experiment_name: ${experiment_name} + trial_name: ${trial_name} + path: ${actor.path} + init_from_scratch: false + disable_dropout: true + gradient_checkpointing: false + dtype: ${actor.dtype} + archon: + attn_type: varlen + eps_clip: 0.5 + ppo_n_minibatches: ${actor.ppo_n_minibatches} + mb_spec: + max_tokens_per_mb: 10240 + optimizer: ${actor.optimizer} + tree_training_mode: dta + packing_algorithm: ${actor.packing_algorithm} + scheduling_strategy: + type: colocation + target: actor + scheduling_spec: ${actor.scheduling_spec} + +ref: + backend: ${actor.backend} + experiment_name: ${experiment_name} + trial_name: ${trial_name} + path: ${actor.path} + init_from_scratch: false + disable_dropout: true + gradient_checkpointing: false + dtype: ${actor.dtype} + archon: + attn_type: varlen + mb_spec: + max_tokens_per_mb: 10240 + optimizer: null + tree_training_mode: dta + packing_algorithm: ${actor.packing_algorithm} + scheduling_strategy: + type: colocation + target: actor + scheduling_spec: ${actor.scheduling_spec} + +# SGLang +sglang: + model_path: ${actor.path} + random_seed: ${seed} + skip_tokenizer_init: true + dtype: ${actor.dtype} + max_running_requests: null + context_length: 32768 + mem_fraction_static: 0.8 + +vllm: + model: ${actor.path} + seed: ${seed} + skip_tokenizer_init: false + dtype: ${actor.dtype} + max_model_len: 32768 + gpu_memory_utilization: 0.9 + +# datasets +train_dataset: + batch_size: 256 + shuffle: true + pin_memory: true + num_workers: 4 + path: openai/gsm8k + type: rl + max_length: 1024 + +valid_dataset: + batch_size: 256 + pin_memory: true + num_workers: 4 + path: openai/gsm8k + type: rl + +# Utilities +saver: + experiment_name: ${experiment_name} + trial_name: ${trial_name} + fileroot: ${cluster.fileroot} + freq_epochs: 1 + freq_steps: null + freq_secs: null + +recover: + mode: disabled + experiment_name: ${experiment_name} + trial_name: ${trial_name} + fileroot: ${cluster.fileroot} + freq_epochs: 1 + freq_steps: null + freq_secs: 3600 + +evaluator: + experiment_name: ${experiment_name} + trial_name: ${trial_name} + fileroot: ${cluster.fileroot} + freq_epochs: 1 + freq_steps: null + freq_secs: null + +stats_logger: + experiment_name: ${experiment_name} + trial_name: ${trial_name} + fileroot: ${cluster.fileroot} + wandb: + mode: disabled + +perf_tracer: + experiment_name: ${experiment_name} + trial_name: ${trial_name} + fileroot: ${cluster.fileroot} + enabled: false + session_tracer: + enabled: false diff --git a/examples/tau2/README.md b/examples/tau2/README.md index 2ba8ce596c..26d8086c83 100644 --- a/examples/tau2/README.md +++ b/examples/tau2/README.md @@ -151,12 +151,23 @@ For reward curves of experiments on a larger scale, please refer to the `generated/` directory under `cluster.fileroot`. You can analyze these for debugging and evaluation. -1. **Tree training**: The configs enable `enable_tree_training=true` by default, which +1. **Tree training**: The configs use `tree_training_mode=sparse` by default, which optimizes training by sharing prefix computations across rollouts with the same prompt. This option can largely accelerate training but will possibly increase GPU memory usage if `actor.mb_spec.max_tokens_per_mb` is large. And this setting may cause instability during the training of the MoE model. + To try Dynamic Tree Attention (DTA) with the Archon Tau2 configs, keep the existing + config file and pass overrides instead of using a separate DTA-specific config: + + ```bash + python3 examples/tau2/train.py \ + --config examples/tau2/config_1.7b_airline.yaml \ + actor.tree_training_mode=dta \ + actor.packing_algorithm=dta \ + actor.gradient_checkpointing=false + ``` + ## Customization We have released the training data and a trained model from this pipeline. You can use diff --git a/examples/tau2/config_1.7b_airline.yaml b/examples/tau2/config_1.7b_airline.yaml index 95a8f8d8e6..f29f81d5c5 100644 --- a/examples/tau2/config_1.7b_airline.yaml +++ b/examples/tau2/config_1.7b_airline.yaml @@ -86,7 +86,7 @@ actor: std_level: batch max_new_tokens: ${gconfig.max_new_tokens} pad_to_maximum: true - enable_tree_training: true + tree_training_mode: sparse scheduling_spec: - task_type: worker port_count: 2 diff --git a/examples/tau2/config_235b_moe_airline.yaml b/examples/tau2/config_235b_moe_airline.yaml index b134e92294..5ea68dd7ce 100644 --- a/examples/tau2/config_235b_moe_airline.yaml +++ b/examples/tau2/config_235b_moe_airline.yaml @@ -124,7 +124,7 @@ actor: std_unbiased: true eps: 1.0e-05 max_new_tokens: ${gconfig.max_new_tokens} - enable_tree_training: false + tree_training_mode: disabled scheduling_spec: - task_type: worker port_count: 2 @@ -158,7 +158,7 @@ ref: max_tokens_per_mb: 32768 n_mbs_divisor: 1 optimizer: null - enable_tree_training: false + tree_training_mode: disabled # Tau2 environment configuration econfig: diff --git a/examples/tau2/config_30b_moe_airline.yaml b/examples/tau2/config_30b_moe_airline.yaml index 07b201fc36..a8ee9d5e1c 100644 --- a/examples/tau2/config_30b_moe_airline.yaml +++ b/examples/tau2/config_30b_moe_airline.yaml @@ -124,7 +124,7 @@ actor: std_unbiased: true eps: 1.0e-05 max_new_tokens: ${gconfig.max_new_tokens} - enable_tree_training: false + tree_training_mode: disabled scheduling_spec: - task_type: worker port_count: 2 @@ -158,7 +158,7 @@ ref: max_tokens_per_mb: 32768 n_mbs_divisor: 1 optimizer: null - enable_tree_training: false + tree_training_mode: disabled # Tau2 environment configuration econfig: diff --git a/examples/tau2/config_8b_airline.yaml b/examples/tau2/config_8b_airline.yaml index 77a366984a..6b0232d4c8 100644 --- a/examples/tau2/config_8b_airline.yaml +++ b/examples/tau2/config_8b_airline.yaml @@ -86,7 +86,7 @@ actor: std_level: batch max_new_tokens: ${gconfig.max_new_tokens} pad_to_maximum: true - enable_tree_training: true + tree_training_mode: sparse scheduling_spec: - task_type: worker port_count: 2 diff --git a/examples/tau2/reward_dta.png b/examples/tau2/reward_dta.png new file mode 100644 index 0000000000..528f6a8e87 Binary files /dev/null and b/examples/tau2/reward_dta.png differ diff --git a/tests/experimental/archon/README.md b/tests/experimental/archon/README.md new file mode 100644 index 0000000000..a984edbdbd --- /dev/null +++ b/tests/experimental/archon/README.md @@ -0,0 +1,90 @@ +# Archon 测试说明 + +## `test_dta.py` 简介 + +`test_dta.py` 主要验证 Archon 的 DTA 路径,包括: + +- `forward_batch` 冒烟检查 +- `train_batch` 冒烟检查 +- 与 FSDP 的数值一致性对比 + +## 测试函数说明 + +- `test_engine_is_initialized`:检查引擎能否正常初始化,并确认 DTA 开关状态正确。 +- `test_forward_batch_runs`:只验证 Archon 的 `forward_batch` 在 DTA 开启时可正常跑通。 +- `test_train_batch_runs`:只验证 Archon 的 `train_batch` 在 DTA 开启时可正常跑通并返回结果。 +- `test_forward_batch_matches_fsdp`:对比 Archon 与 FSDP 的 `forward_batch` + 输出,检查形状和数值误差是否在可接受范围内。 +- `test_train_batch_matches_fsdp`:对比 Archon 与 FSDP + 一次训练步后的梯度范数和参数更新量,检查训练信号一致性。很难强对齐,建议观察 grad_norm 和 delta_norm 是否对齐。 + +## 输入数据格式 + +通过 `--dta-data` 传入一个 `.pt` 文件,内容要求: + +- 类型是 `list[torch.Tensor]` +- 每个元素是 1-D token 序列(不做 padding) + +示例: + +```python +[ + torch.tensor([101, 2023, 2003, 1037, 3231], dtype=torch.long), + torch.tensor([101, 2064, 2017, 2393, 1029], dtype=torch.long), +] +``` + +## 参数说明 + +- `--dta-data PATH`:DTA 数据文件路径;不传会跳过 DTA 测试 +- `--dta-limit INT`:最多使用前 N 条序列,`-1` 表示全部使用 +- `--max-tokens-per-mb INT`:单条序列token 上限(用于序列/微批控制) +- `--no-dta`:关闭 DTA +- `--use-hf`:model 使用 HuggingFace 模型路径分支,即去掉 archon 包装 +- `--model-path PATH`:模型路径(与 `--use-hf` 搭配) + +## 用法示例(`python -m pytest`) + +只跑 DTA 测试: + +```bash +python -m pytest -v -s tests/experimental/archon/test_dta.py \ + --dta-data /path/to/dta_samples.pt +``` + +限制样本数量(快速迭代): + +```bash +python -m pytest -v -s tests/experimental/archon/test_dta.py \ + --dta-data /path/to/dta_samples.pt \ + --dta-limit 16 +``` + +调整 token 上限: + +```bash +python -m pytest -v -s tests/experimental/archon/test_dta.py \ + --dta-data /path/to/dta_samples.pt \ + --max-tokens-per-mb 4096 +``` + +使用 HF 模型路径: + +```bash +python -m pytest -v -s tests/experimental/archon/test_dta.py \ + --dta-data /path/to/dta_samples.pt \ + --use-hf \ + --model-path /path/to/model +``` + +按函数精确运行(`::`): + +```bash +python -m pytest -v -s tests/experimental/archon/test_dta.py::test_forward_batch_runs \ + --dta-data /path/to/dta_samples.pt +``` + +```bash +python -m pytest -v -s tests/experimental/archon/test_dta.py::test_train_batch_matches_fsdp \ + --dta-data /path/to/dta_samples.pt +``` diff --git a/tests/experimental/archon/conftest.py b/tests/experimental/archon/conftest.py index 3782ff2b28..18c094938f 100644 --- a/tests/experimental/archon/conftest.py +++ b/tests/experimental/archon/conftest.py @@ -14,6 +14,7 @@ import sys import types from pathlib import Path +from types import SimpleNamespace import pytest import torch @@ -42,6 +43,61 @@ collect_ignore_glob.extend(["test_qwen3_5*.py", "test_hf_parity_qwen3_5*.py"]) +def pytest_addoption(parser): + parser.addoption( + "--dta-data", + type=str, + default=None, + help="Path to .pt file with DTA token sequences (list[Tensor]).", + ) + parser.addoption( + "--no-dta", + action="store_true", + default=False, + help="Disable DTA.", + ) + parser.addoption( + "--max-tokens-per-mb", + type=int, + default=5596, + help="Cap sequence length and set mb_spec.max_tokens_per_mb for archon tests.", + ) + parser.addoption( + "--dta-limit", + type=int, + default=-1, + help="Use at most N sequences from --dta-data; -1 keeps all sequences.", + ) + parser.addoption( + "--use-hf", + action="store_true", + default=False, + help="Use HuggingFace model for Archon DTA tests.", + ) + parser.addoption( + "--model-path", + type=str, + default="/storage/openpsi/models/Qwen__Qwen2.5-0.5B-Instruct/", + help="Path to model.", + ) + + +@pytest.fixture(scope="module") +def archon_test_config(request) -> SimpleNamespace: + """Expose archon runtime config to tests/fixtures.""" + Ans = SimpleNamespace( + max_tokens_per_mb=int(request.config.getoption("--max-tokens-per-mb")), + tree_training_mode=( + "disabled" if request.config.getoption("--no-dta") else "dta" + ), + dta_data=request.config.getoption("--dta-data"), + dta_limit=int(request.config.getoption("--dta-limit")), + use_hf=request.config.getoption("--use-hf"), + model_path=request.config.getoption("--model-path"), + ) + return Ans + + def pytest_collection_modifyitems(config, items): """Skip archon tests based on version requirements.""" if _TORCH_VERSION >= _MIN_TORCH_VERSION: diff --git a/tests/experimental/archon/utils.py b/tests/experimental/archon/utils.py index 9c31f49383..f47d2d3a67 100644 --- a/tests/experimental/archon/utils.py +++ b/tests/experimental/archon/utils.py @@ -2,13 +2,16 @@ import os import subprocess +from collections.abc import Callable from dataclasses import dataclass +from types import SimpleNamespace from typing import Any import pytest import torch import torch.distributed as dist from datasets import load_dataset +from torch.distributed.tensor import DTensor from transformers import AutoModelForCausalLM from areal.api import FinetuneSpec, ParallelStrategy @@ -54,6 +57,15 @@ "create_grpo_batch", "DualEngineFixture", "dual_engines", + "create_archon_engine", + "create_fsdp_engine", + "destroy_test_engine", + "create_dta_batch", + "load_pt_batch", + "dta_dummy_loss_fn", + "dta_loss_weight_fn", + "snapshot_module_parameters", + "strip_wrapper_prefixes", ] @@ -72,6 +84,11 @@ def get_model_path_for_type(model_type: str) -> str | None: DATASET_PATH = get_dataset_path("/storage/openpsi/data/gsm8k", "openai/gsm8k") +def strip_wrapper_prefixes(name: str) -> str: + """Drop wrapper-generated path segments from parameter names.""" + return name.replace("._checkpoint_wrapped_module", "").replace("._orig_mod", "") + + @dataclass class ComparisonMetrics: """Metrics for comparing two tensors.""" @@ -468,3 +485,262 @@ def dual_engines(): fixture.setup() yield fixture fixture.teardown() + + +# ============================================================================= +# DTA Engine Testing Utilities +# ============================================================================= + + +def create_dta_batch( + batch_size: int = 4, + seq_len: int = 64, + shared_prefix_len: int = 20, + vocab_size: int = 151936, + device: torch.device | None = None, +) -> dict[str, Any]: + """Build a synthetic batch whose sequences share a common prefix. + + Returns a dict compatible with ``ArchonEngine.train_batch`` (GRPO-style + fields included so the default loss path works). + + Args: + batch_size: Number of sequences. + seq_len: Length of each sequence. + shared_prefix_len: Length of the common prefix across all sequences. + vocab_size: Vocabulary size for random token generation. + device: Target device (defaults to current platform device). + """ + if device is None: + device = torch.device(current_platform.device_type) + + torch.manual_seed(42) + + prefix = torch.randint(100, vocab_size - 100, (shared_prefix_len,)) + rows = [] + for _ in range(batch_size): + suffix = torch.randint(100, vocab_size - 100, (seq_len - shared_prefix_len,)) + rows.append(torch.cat([prefix, suffix])) + input_ids = torch.stack(rows).to(device) + + attention_mask = torch.ones_like(input_ids) + loss_mask = torch.ones(batch_size, seq_len, device=device) + loss_mask[:, :10] = 0.0 + + logprobs = torch.randn(batch_size, seq_len, device=device) * 0.5 - 2.0 + old_logprobs = logprobs.clone() + + return { + "input_ids": input_ids, + "attention_mask": attention_mask, + "loss_mask": loss_mask, + "logprobs": logprobs, + "old_logprobs": old_logprobs, + "advantages": torch.randn(batch_size, seq_len, device=device), + "rewards": torch.randint(0, 2, (batch_size,), device=device).float(), + "values": torch.zeros(batch_size, seq_len, device=device), + "prox_logp": old_logprobs.clone(), + } + + +def load_pt_batch( + test_config: Any, + prompt_ratio: float = 0.3, + device: torch.device | None = None, +) -> dict[str, Any]: + """Load all token sequences from a ``.pt`` file at full length. + + Each ``.pt`` file contains ``list[Tensor]`` where every tensor is a 1-D + ``int64`` sequence with no padding. All sequences are kept at their + original length and right-padded to the longest one. + + GRPO fields (``loss_mask``, ``logprobs``, ``advantages``, …) are filled + with synthetic values so the batch works with ``train_batch``. + + Args: + test_config: Test config carrying ``dta_data``, ``max_tokens_per_mb``, and optional ``dta_limit``. + prompt_ratio: Fraction of each sequence treated as prompt (loss_mask=0). + device: Target device (defaults to current platform device). + """ + if device is None: + device = torch.device(current_platform.device_type) + # print(f"loadbatch on device: {device}") + + pt_path = str(test_config.dta_data) + assert pt_path is not None, "dta_data is required but got None" + seqs: list[torch.Tensor] = torch.load( + pt_path, map_location="cpu", weights_only=True + ) + assert isinstance(seqs, list) and len(seqs) > 0, ( + f"Expected list[Tensor], got {type(seqs)}" + ) + dta_limit = int(getattr(test_config, "dta_limit", -1)) + if dta_limit >= 0: + seqs = seqs[:dta_limit] + assert len(seqs) > 0, "No sequences available after applying dta_limit." + + bs = len(seqs) + max_tokens_per_mb = int(test_config.max_tokens_per_mb) + lengths = [min(s.numel(), max_tokens_per_mb) for s in seqs] + padded_len = max(lengths) + + input_ids = torch.zeros(bs, padded_len, dtype=torch.long) + attention_mask = torch.zeros(bs, padded_len, dtype=torch.long) + loss_mask = torch.zeros(bs, padded_len) + + for i, (s, length) in enumerate(zip(seqs, lengths)): + input_ids[i, :length] = s[:length] + attention_mask[i, :length] = 1 + prompt_len = max(1, int(length * prompt_ratio)) + loss_mask[i, prompt_len:length] = 1.0 + + input_ids = input_ids.to(device) + attention_mask = attention_mask.to(device) + loss_mask = loss_mask.to(device) + + logprobs = torch.randn(bs, padded_len, device=device) * 0.5 - 2.0 + old_logprobs = logprobs.clone() + + return { + "input_ids": input_ids, + "attention_mask": attention_mask, + "loss_mask": loss_mask, + "logprobs": logprobs, + "old_logprobs": old_logprobs, + "advantages": torch.randn(bs, padded_len, device=device), + "rewards": torch.randint(0, 2, (bs,), device=device).float(), + "values": torch.zeros(bs, padded_len, device=device), + "prox_logp": old_logprobs.clone(), + } + + +def dta_dummy_loss_fn(logprobs, entropy, input_data, **kwargs): + """Minimal loss for DTA smoke tests.""" + loss_mask = input_data.get("loss_mask") + if loss_mask is None: + return -logprobs.sum() + min_len = min(logprobs.shape[-1], loss_mask.shape[-1]) + logprobs = logprobs[..., :min_len] + loss_mask = loss_mask[..., :min_len] + return -(logprobs * loss_mask).sum() / loss_mask.sum().clamp(min=1) + + +def dta_loss_weight_fn(input_data): + """Loss weight function for DTA smoke tests.""" + lm = input_data.get("loss_mask") + if lm is not None: + return lm.sum() + return torch.tensor(1.0) + + +def snapshot_module_parameters( + module: torch.nn.Module, + to_cpu: bool = False, + param_filter: Callable[[str, torch.nn.Parameter], bool] | None = None, +) -> dict[str, torch.Tensor]: + """Snapshot (clone) selected named parameters for later delta comparisons. + + This is intentionally lightweight to reuse the same comparison pattern + across tests (similar to how `test_grpo.py` compares weight deltas). + """ + snapshots: dict[str, torch.Tensor] = {} + for name, param in module.named_parameters(): + if param_filter is not None and not param_filter(name, param): + continue + t = param.full_tensor() if isinstance(param, DTensor) else param + t = t.detach().clone() + if to_cpu: + t = t.cpu() + snapshots[name] = t + return snapshots + + +def create_archon_engine( + test_config: SimpleNamespace, + model_path: str | None = None, +) -> ArchonLMEngine: + """Create and initialize a single Archon engine for tests.""" + setup_distributed_environment() + model_path = model_path or MODEL_PATHS["qwen2"] + world_size = dist.get_world_size() if dist.is_initialized() else 1 + parallel_strategy = ParallelStrategy(data_parallel_size=world_size) + ft_spec = FinetuneSpec(total_train_epochs=1, dataset_size=4, train_batch_size=4) + max_tokens_per_mb = int(test_config.max_tokens_per_mb) + + config = create_engine_config( + model_path, + "archon_dta" if test_config.tree_training_mode == "dta" else "archon", + ) + config.mb_spec = MicroBatchSpec.new( + config.mb_spec, max_tokens_per_mb=max_tokens_per_mb + ) + config.tree_training_mode = test_config.tree_training_mode + if os.environ.get("AREAL_DISABLE_TORCH_COMPILE", "").lower() in ( + "1", + "true", + "yes", + ): + config.archon.enable_compile = False + config.path = test_config.model_path + + engine = ArchonLMEngine(config) + engine.create_process_group(parallel_strategy=parallel_strategy) + engine.initialize(addr=None, ft_spec=ft_spec) + + if test_config.use_hf: + # Clean up original engine.model to avoid memory leaks (显存残留) + if hasattr(engine, "model") and engine.model is not None: + try: + # Call .cpu() + del + torch.cuda.empty_cache for safety + engine.model.cpu() + except Exception: + pass + del engine.model + import gc + + gc.collect() + if torch.cuda.is_available(): + torch.cuda.empty_cache() + + # Use the traditional HuggingFace transformer model for DTA smoke tests + from transformers import AutoModelForCausalLM + + engine.model = AutoModelForCausalLM.from_pretrained( + model_path, + torch_dtype=torch.bfloat16, + device_map=torch.device(current_platform.device_type), + ) + + return engine + + +def create_fsdp_engine( + test_config: SimpleNamespace, + model_path: str | None = None, +) -> FSDPLMEngine: + """Create and initialize a single FSDP engine for tests.""" + setup_distributed_environment() + model_path = model_path or MODEL_PATHS["qwen2"] + world_size = dist.get_world_size() if dist.is_initialized() else 1 + parallel_strategy = ParallelStrategy(data_parallel_size=world_size) + ft_spec = FinetuneSpec(total_train_epochs=1, dataset_size=4, train_batch_size=4) + max_tokens_per_mb = int(test_config.max_tokens_per_mb) + + config = create_engine_config(model_path, "fsdp") + config.mb_spec = MicroBatchSpec.new( + config.mb_spec, max_tokens_per_mb=max_tokens_per_mb + ) + config.path = test_config.model_path + + engine = FSDPLMEngine(config) + engine.create_process_group(parallel_strategy=parallel_strategy) + engine.initialize(addr=None, ft_spec=ft_spec) + return engine + + +def destroy_test_engine(engine: FSDPLMEngine | ArchonLMEngine | None) -> None: + """Destroy a test engine and tear down the process group.""" + if engine is not None: + engine.destroy() + if dist.is_initialized(): + dist.destroy_process_group() diff --git a/tests/experimental/dta/engine_step_case.py b/tests/experimental/dta/engine_step_case.py new file mode 100644 index 0000000000..3ef9e85ddb --- /dev/null +++ b/tests/experimental/dta/engine_step_case.py @@ -0,0 +1,94 @@ +"""Shared test-case config for torchrun-backed DTA engine-step tests.""" + +from __future__ import annotations + +import json +import os +from dataclasses import asdict, dataclass +from pathlib import Path + + +@dataclass +class EngineStepCase: + mode: str + dtype: str + payload_path: str + sequence_data_path: str + n_gpus: int = 2 + nnodes: int = 1 + master_addr: str = "localhost" + master_port: int | None = None + local_model_path: str = "/storage/openpsi/models/Qwen__Qwen3-0.6B/" + hf_id: str = "Qwen/Qwen3-0.6B" + max_tokens_per_mb: int = 5120 + dta_block_size: int = 512 + gradient_checkpointing: bool = True + optimizer_type: str = "adam" + lr: float = 1.0e-4 + cot_system_prompt_length: int = 1000 + cot_thinking_token_length: int = 500 + cot_response_token_length: int = 200 + cot_turns: int = 17 + sequence_seed: int = 1234 + grad_rtol: float = 2.0e-3 + grad_atol: float = 2.0e-5 + grad_norm_rtol: float = 1.0e-3 + grad_norm_atol: float = 2.0e-5 + forward_rtol: float = 2.0e-3 + forward_atol: float = 2.0e-5 + + def __post_init__(self) -> None: + if self.mode not in {"baseline", "dta"}: + raise ValueError(f"mode must be 'baseline' or 'dta', got {self.mode}") + if self.dtype not in {"float32", "bfloat16"}: + raise ValueError(f"dtype must be 'float32' or 'bfloat16', got {self.dtype}") + if self.optimizer_type not in {"adam", "sgd"}: + raise ValueError( + f"optimizer_type must be 'adam' or 'sgd', got {self.optimizer_type}" + ) + if self.n_gpus <= 0: + raise ValueError(f"n_gpus must be positive, got {self.n_gpus}") + if self.nnodes <= 0: + raise ValueError(f"nnodes must be positive, got {self.nnodes}") + if self.cot_turns <= 0: + raise ValueError(f"cot_turns must be positive, got {self.cot_turns}") + + @property + def dataset_size(self) -> int: + return self.cot_turns + + def cot_sequence_metadata(self) -> dict[str, int]: + return { + "cot_system_prompt_length": self.cot_system_prompt_length, + "cot_thinking_token_length": self.cot_thinking_token_length, + "cot_response_token_length": self.cot_response_token_length, + "cot_turns": self.cot_turns, + "sequence_seed": self.sequence_seed, + } + + def resolve_model_path(self) -> str: + if os.path.exists(self.local_model_path): + return self.local_model_path + + from huggingface_hub import snapshot_download + + return snapshot_download( + repo_id=self.hf_id, + ignore_patterns=["*.gguf", "*.ggml", "consolidated*"], + ) + + def dump(self, path: Path) -> None: + data = asdict(self) + data["dataset_size"] = self.dataset_size + path.write_text(json.dumps(data, indent=2, sort_keys=True)) + + @classmethod + def load(cls, path: Path) -> EngineStepCase: + data = json.loads(path.read_text()) + data.pop("dataset_size", None) + if "tensor_rtol" in data: + data.setdefault("grad_rtol", data.pop("tensor_rtol")) + data.pop("update_rtol", None) + data.pop("update_atol", None) + data.pop("adam_update_grad_floor", None) + return cls(**data) diff --git a/tests/experimental/dta/sequence_data.py b/tests/experimental/dta/sequence_data.py new file mode 100644 index 0000000000..d18936c0bd --- /dev/null +++ b/tests/experimental/dta/sequence_data.py @@ -0,0 +1,54 @@ +"""Synthetic token sequence builders for DTA tests.""" + +from __future__ import annotations + +import torch + + +def _token_span(length: int, vocab_size: int) -> torch.Tensor: + tokens = torch.randint(low=0, high=vocab_size, size=(length,), dtype=torch.long) + return tokens + + +def build_cot_token_sequences( + vocab_size: int, + system_prompt_length: int, + thinking_token_length: int, + response_token_length: int, + turns: int, +) -> list[torch.Tensor]: + """Generate multi-turn synthetic CoT-like token sequences, where each turn + accumulates all previous responses in the context. + + Logic: + - The first turn consists of: system prompt + thinking tokens + response tokens. + - Each subsequent turn is: the previous "history" (system prompt + all prior responses) + + new thinking tokens + new response tokens. + - The thinking_tokens and response_tokens are sampled independently for each turn. + - The history grows with every turn as the new response is appended. + + Each output sequence starts with the shared system prompt but includes progressively longer + "conversation context" as prior responses are accumulated. This structure is designed + for Dynamic Token Alignment (DTA) engine tests to challenge trie construction with both + common and incremental prefixes. + """ + if turns <= 0: + raise ValueError(f"turns must be positive, got {turns}") + + history = _token_span(system_prompt_length, vocab_size) + + sequences: list[torch.Tensor] = [] + for turn_idx in range(turns): + thinking_tokens = _token_span(thinking_token_length, vocab_size) + response_tokens = _token_span(response_token_length, vocab_size) + sequences.append( + torch.cat( + [ + history, + thinking_tokens, + response_tokens, + ] + ) + ) + history = torch.cat([history, response_tokens]) + return sequences diff --git a/tests/experimental/dta/test_allocation.py b/tests/experimental/dta/test_allocation.py new file mode 100644 index 0000000000..b179600bc5 --- /dev/null +++ b/tests/experimental/dta/test_allocation.py @@ -0,0 +1,120 @@ +# SPDX-License-Identifier: Apache-2.0 + +from __future__ import annotations + +import pytest +import torch + +from areal.infra.controller.train_controller import TrainController, _dispatch_tensors +from areal.infra.dp_allocation import AllocationInput, allocate_trajectories +from areal.trainer.ppo.actor import PPOActorController + + +def _make_rw_pair( + pair_idx: int, chosen_len: int = 3, rejected_len: int = 2 +) -> tuple[dict[str, object], dict[str, object]]: + chosen: dict[str, object] = { + "input_ids": torch.full((1, chosen_len), pair_idx * 2 + 1, dtype=torch.long), + "attention_mask": torch.ones((1, chosen_len), dtype=torch.bool), + "meta": {"pair": pair_idx, "role": "chosen"}, + } + rejected: dict[str, object] = { + "input_ids": torch.full((1, rejected_len), pair_idx * 2 + 2, dtype=torch.long), + "attention_mask": torch.ones((1, rejected_len), dtype=torch.bool), + "meta": {"pair": pair_idx, "role": "rejected"}, + } + return chosen, rejected + + +def _build_rw_batch(n_pairs: int) -> list[dict[str, object]]: + items: list[dict[str, object]] = [] + for pair_idx in range(n_pairs): + chosen, rejected = _make_rw_pair(pair_idx) + items.extend([chosen, rejected]) + return items + + +def test_dta_rejects_group_size_greater_than_one() -> None: + items = _build_rw_batch(n_pairs=2) + + with pytest.raises(ValueError, match="DTA requires sequence-level independence"): + _dispatch_tensors( + items, + dp_size=2, + group_size=2, + packing_algorithm="dta", + ) + + with pytest.raises(ValueError, match="DTA requires sequence-level independence"): + allocate_trajectories( + AllocationInput( + items=items, + n_groups=2, + algorithm="dta", + group_size=2, + ) + ) + + +def test_allocate_trajectories_dta_flattens_grouped_rollouts() -> None: + items: list[dict[str, torch.Tensor]] = [ + { + "input_ids": torch.tensor([[1, 2, 3], [1, 2, 4]]), + "attention_mask": torch.tensor([[1, 1, 1], [1, 1, 1]], dtype=torch.bool), + }, + { + "input_ids": torch.tensor([[5, 6, 0]]), + "attention_mask": torch.tensor([[1, 1, 0]], dtype=torch.bool), + }, + ] + + allocation = allocate_trajectories( + AllocationInput(items=items, n_groups=2, algorithm="dta") + ) + + assert len(allocation.items) == 3 + assert len(allocation.group_indices) == 2 + flat_indices = [idx for group in allocation.group_indices for idx in group] + assert len(flat_indices) == len(allocation.items) + assert len(set(flat_indices)) == len(allocation.items) + assert sorted(flat_indices) == list(range(len(allocation.items))) + assert allocation.metrics is not None + stats = allocation.metrics.to_stats() + assert stats["dta/n_tokens"] == 8.0 + assert stats["dta/n_tree_tokens_before_allocation"] == 6.0 + assert stats["dta/n_tree_tokens_after_allocation"] == 6.0 + assert "dta/n_tree_tokens_after_allocation" in stats + + +def test_ppo_actor_prepare_batch_dta_flattens_grouped_rollouts(monkeypatch) -> None: + controller = object.__new__(PPOActorController) + controller.config = type("Config", (), {"packing_algorithm": "dta"})() + batch: list[dict[str, torch.Tensor]] = [ + { + "input_ids": torch.tensor([[1, 2, 3], [1, 2, 4]]), + "attention_mask": torch.tensor([[1, 1, 1], [1, 1, 1]], dtype=torch.bool), + "loss_mask": torch.ones((2, 3), dtype=torch.float32), + } + ] + monkeypatch.setattr( + TrainController, + "prepare_batch", + lambda self, *args, **kwargs: batch, + ) + + prepared = controller.prepare_batch(object(), workflow=object(), workflow_kwargs={}) + + assert len(prepared) == 2 + assert all(item["input_ids"].shape[0] == 1 for item in prepared) + torch.testing.assert_close( + prepared[0]["input_ids"], + torch.tensor([[1, 2, 3]]), + rtol=0, + atol=0, + ) + torch.testing.assert_close( + prepared[1]["input_ids"], + torch.tensor([[1, 2, 4]]), + rtol=0, + atol=0, + ) diff --git a/tests/experimental/dta/test_engine_step.py b/tests/experimental/dta/test_engine_step.py new file mode 100644 index 0000000000..d095ce8a65 --- /dev/null +++ b/tests/experimental/dta/test_engine_step.py @@ -0,0 +1,250 @@ +"""Torchrun-backed DTA engine step tests.""" + +from __future__ import annotations + +import subprocess +from dataclasses import replace +from pathlib import Path +from types import SimpleNamespace +from typing import Any + +import pytest +import torch +from transformers import AutoConfig + +from tests.experimental.dta.engine_step_case import EngineStepCase +from tests.experimental.dta.sequence_data import build_cot_token_sequences + +from areal.api.cli_args import MicroBatchSpec +from areal.experimental.dta import wrapper as dta_wrapper +from areal.infra.platforms import current_platform +from areal.utils.network import find_free_ports + +RUNNER = "tests/experimental/dta/torchrun/run_engine_step.py" + +_CUDA_AVAILABLE = torch.cuda.is_available() + + +def _run_engine_step(case: EngineStepCase) -> dict[str, Any]: + if case.master_port is None: + case = replace(case, master_port=find_free_ports(1)[0]) + payload_path = Path(case.payload_path) + case_config = payload_path.with_suffix(".case.json") + case.dump(case_config) + cmd = [ + "torchrun", + f"--nproc_per_node={case.n_gpus}", + f"--nnodes={case.nnodes}", + f"--master-addr={case.master_addr}", + f"--master_port={case.master_port}", + RUNNER, + f"--case-config={case_config}", + ] + try: + subprocess.run(cmd, check=True, capture_output=True, text=True) + except subprocess.CalledProcessError as exc: + raise AssertionError( + f"torchrun failed for mode={case.mode}, dtype={case.dtype}\n" + f"STDOUT:\n{exc.stdout}\nSTDERR:\n{exc.stderr}" + ) from exc + + return torch.load(payload_path, map_location="cpu", weights_only=False) + + +def _save_cot_sequence_data(case: EngineStepCase) -> None: + model_path = case.resolve_model_path() + model_config = AutoConfig.from_pretrained(model_path, trust_remote_code=True) + vocab_size = int(getattr(model_config, "vocab_size")) + + rng_state = torch.random.get_rng_state() + torch.manual_seed(case.sequence_seed) + try: + sequences = build_cot_token_sequences( + vocab_size, + system_prompt_length=case.cot_system_prompt_length, + thinking_token_length=case.cot_thinking_token_length, + response_token_length=case.cot_response_token_length, + turns=case.cot_turns, + ) + finally: + torch.random.set_rng_state(rng_state) + + sequence_data_path = Path(case.sequence_data_path) + sequence_data_path.parent.mkdir(parents=True, exist_ok=True) + torch.save( + { + "vocab_size": vocab_size, + "sequence_metadata": case.cot_sequence_metadata(), + "sequences": sequences, + }, + sequence_data_path, + ) + + +def _assert_finite_payload(payload: dict[str, Any]) -> None: + assert float(payload["stats"]["update_successful"]) == 1.0 + grad_norm = torch.tensor(float(payload["stats"]["grad_norm"])) + torch.testing.assert_close(grad_norm, grad_norm) + for group_name in ("grads",): + group = payload[group_name] + assert group, f"{payload['mode']} produced no {group_name}" + for name, tensor in group.items(): + assert torch.isfinite(tensor).all().item(), ( + f"{payload['mode']} {group_name} {name} non-finite" + ) + + +def _assert_tensor_groups_elementwise_close( + baseline: dict[str, torch.Tensor], + dta: dict[str, torch.Tensor], + *, + group_name: str, + rtol: float, + atol: float, +) -> int: + baseline_names = set(baseline) + dta_names = set(dta) + assert baseline_names == dta_names, ( + f"{group_name} parameter-name mismatch: " + f"baseline_only={sorted(baseline_names - dta_names)[:16]}, " + f"dta_only={sorted(dta_names - baseline_names)[:16]}" + ) + + for name in sorted(baseline_names): + b_tensor = baseline[name] + d_tensor = dta[name] + assert b_tensor.shape == d_tensor.shape, ( + f"{group_name} shape mismatch for {name}: " + f"baseline={tuple(b_tensor.shape)}, dta={tuple(d_tensor.shape)}" + ) + torch.testing.assert_close( + d_tensor, + b_tensor, + rtol=rtol, + atol=atol, + msg=lambda msg, n=name: f"{group_name} tensor mismatch for {n}: {msg}", + ) + + return len(baseline_names) + + +class TinyEngineConfig: + mb_spec = MicroBatchSpec(n_mbs=1, max_tokens_per_mb=32) + + +def test_dta_prepare_mb_list_creates_one_microbatch_per_sequence() -> None: + """DTA keeps sequence-level independence when building micro-batches.""" + wrapper = object.__new__(dta_wrapper.DTAWrapper) + wrapper.engine = SimpleNamespace(config=TinyEngineConfig()) + + batch = { + "input_ids": torch.tensor( + [ + [11, 12, 13, 0], + [21, 22, 0, 0], + [31, 32, 33, 34], + ], + dtype=torch.long, + ), + "attention_mask": torch.tensor( + [ + [1, 1, 1, 0], + [1, 1, 0, 0], + [1, 1, 1, 1], + ], + dtype=torch.bool, + ), + "loss_mask": torch.ones((3, 4), dtype=torch.float32), + } + + mb_list = dta_wrapper.DTAWrapper.prepare_mb_list(wrapper, batch) + + assert len(mb_list.mbs) == batch["input_ids"].shape[0] + assert mb_list.group_lens == [3, 2, 4] + for mb in mb_list.mbs: + assert mb["input_ids"].shape[0] == 1 + assert mb["attention_mask"].shape[0] == 1 + + +@pytest.mark.skipif(not _CUDA_AVAILABLE, reason="CUDA not available") +@pytest.mark.multi_gpu +@pytest.mark.slow +def test_dta_engine_fp32_grad_match_baseline_and_adam_step_succeeds( + tmp_path: Path, +): + """Compare fp32 forward logprobs and gradients between dense Archon and DTA.""" + if current_platform.device_count() < 2: + pytest.skip("This test requires 2 GPUs") + + sequence_data_path = tmp_path / "engine_step_sequences.pt" + baseline_case = EngineStepCase( + mode="baseline", + dtype="float32", + payload_path=str(tmp_path / "baseline.pt"), + sequence_data_path=str(sequence_data_path), + ) + _save_cot_sequence_data(baseline_case) + dta_case = replace( + baseline_case, + mode="dta", + payload_path=str(tmp_path / "dta.pt"), + gradient_checkpointing=False, + ) + baseline = _run_engine_step(baseline_case) + dta = _run_engine_step(dta_case) + + _assert_finite_payload(baseline) + _assert_finite_payload(dta) + + assert baseline["forward_logprobs"].shape == dta["forward_logprobs"].shape + forward_mask = baseline["forward_loss_mask"] + torch.testing.assert_close( + dta["forward_logprobs"][forward_mask], + baseline["forward_logprobs"][forward_mask], + rtol=baseline_case.forward_rtol, + atol=baseline_case.forward_atol, + ) + + torch.testing.assert_close( + torch.tensor(float(dta["stats"]["grad_norm"])), + torch.tensor(float(baseline["stats"]["grad_norm"])), + rtol=baseline_case.grad_norm_rtol, + atol=baseline_case.grad_norm_atol, + ) + torch.testing.assert_close( + torch.tensor(float(dta["stats"]["global_loss"])), + torch.tensor(float(baseline["stats"]["global_loss"])), + rtol=baseline_case.forward_rtol, + atol=baseline_case.forward_atol, + ) + assert ( + _assert_tensor_groups_elementwise_close( + baseline["grads"], + dta["grads"], + group_name="grads", + rtol=baseline_case.grad_rtol, + atol=baseline_case.grad_atol, + ) + > 0 + ) + + +@pytest.mark.skipif(not _CUDA_AVAILABLE, reason="CUDA not available") +@pytest.mark.multi_gpu +@pytest.mark.slow +def test_dta_engine_bf16_train_step_smoke(tmp_path: Path): + """Smoke check that the DTA engine runs one bf16 train step.""" + if current_platform.device_count() < 2: + pytest.skip("This test requires 2 GPUs") + + case = EngineStepCase( + mode="dta", + dtype="bfloat16", + payload_path=str(tmp_path / "dta_bf16.pt"), + sequence_data_path=str(tmp_path / "bf16_engine_step_sequences.pt"), + gradient_checkpointing=False, + ) + _save_cot_sequence_data(case) + payload = _run_engine_step(case) + + _assert_finite_payload(payload) diff --git a/tests/experimental/dta/test_zero1.py b/tests/experimental/dta/test_zero1.py new file mode 100644 index 0000000000..6abf854063 --- /dev/null +++ b/tests/experimental/dta/test_zero1.py @@ -0,0 +1,66 @@ +# SPDX-License-Identifier: Apache-2.0 + +from __future__ import annotations + +from types import SimpleNamespace + +import torch +from torch import nn + +from areal.api.cli_args import OptimizerConfig +from areal.experimental.dta import wrapper as dta_wrapper + + +class TinyTiedModel(nn.Module): + def __init__(self) -> None: + super().__init__() + self.model_args = SimpleNamespace(enable_weight_tying=True) + self.tok_embeddings = nn.Embedding(4, 3) + self.output = nn.Linear(3, 4, bias=False) + + +class TinyDTAWrapper: + def __init__(self, model: nn.Module) -> None: + self.engine = SimpleNamespace(model=model, model_parts=[]) + + +def test_apply_zero1_ties_embedding_and_output_weight() -> None: + model = TinyTiedModel() + + dta_wrapper.DTAWrapper.apply_zero1(TinyDTAWrapper(model)) + + assert model.output.weight is model.tok_embeddings.weight + param_names = dict(model.named_parameters()) + assert "tok_embeddings.weight" in param_names + assert "output.weight" not in param_names + loss = model.output(model.tok_embeddings(torch.tensor([0, 1]))).sum() + loss.backward() + assert model.output.weight.grad is model.tok_embeddings.weight.grad + + +def test_create_zero1_optimizer_receives_tied_parameter_once(monkeypatch) -> None: + model = TinyTiedModel() + dta_wrapper.DTAWrapper.apply_zero1(TinyDTAWrapper(model)) + captured: dict[str, object] = {} + + class FakeZeroRedundancyOptimizer: + def __init__(self, params, **kwargs) -> None: + captured["params"] = list(params) + captured["kwargs"] = kwargs + + monkeypatch.setattr( + dta_wrapper, "ZeroRedundancyOptimizer", FakeZeroRedundancyOptimizer + ) + + wrapper = TinyDTAWrapper(model) + wrapper.engine.optimizer_config = OptimizerConfig(type="adam", lr=1e-3) + wrapper.engine.data_parallel_group = object() + wrapper.engine._get_all_parameters = lambda: list(model.parameters()) + + optimizer = dta_wrapper.DTAWrapper.create_optimizer( + wrapper, + ) + + assert isinstance(optimizer, FakeZeroRedundancyOptimizer) + assert captured["params"].count(model.tok_embeddings.weight) == 1 + assert captured["params"].count(model.output.weight) == 1 diff --git a/tests/experimental/dta/torchrun/run_engine_step.py b/tests/experimental/dta/torchrun/run_engine_step.py new file mode 100644 index 0000000000..c06d073647 --- /dev/null +++ b/tests/experimental/dta/torchrun/run_engine_step.py @@ -0,0 +1,306 @@ +#!/usr/bin/env python3 +"""Run one real Archon engine train step for DTA comparison tests. + +The pytest side launches this script twice with torchrun: once for regular +Archon DP/FSDP, once for DTA Zero1. Rank 0 writes full CPU gradient and +parameter-update tensors for pytest to compare element-wise. +""" + +from __future__ import annotations + +import argparse +import functools +import os +from dataclasses import asdict +from pathlib import Path +from typing import Any + +import torch +import torch.distributed as dist +from torch.distributed.tensor import DTensor + +from tests.experimental.dta.engine_step_case import EngineStepCase + +from areal.api import FinetuneSpec, ParallelStrategy +from areal.api.cli_args import MicroBatchSpec, OptimizerConfig, TrainEngineConfig +from areal.experimental.engine.archon_engine import ArchonLMEngine +from areal.infra.platforms import current_platform +from areal.utils.data import concat_batch +from areal.utils.functional import ppo_actor_loss_fn + + +def _build_trajectory(seq: torch.Tensor) -> dict[str, torch.Tensor]: + input_ids = seq.unsqueeze(0).contiguous() + attention_mask = torch.ones_like(input_ids, dtype=torch.bool) + loss_mask = torch.ones_like(input_ids, dtype=torch.float32) + loss_mask[:, 0] = 0.0 + loss_mask[:, -1] = 0.0 + return { + "input_ids": input_ids, + "attention_mask": attention_mask, + "loss_mask": loss_mask, + } + + +def build_local_batch(sequences: list[torch.Tensor], dp_rank: int, dp_world_size: int): + local_sequences = [ + seq for idx, seq in enumerate(sequences) if idx % dp_world_size == dp_rank + ] + if not local_sequences: + raise RuntimeError( + f"No local sequences for dp_rank={dp_rank}, dp_world_size={dp_world_size}" + ) + trajectories = [_build_trajectory(seq) for seq in local_sequences] + batch, _ = concat_batch(trajectories) + return batch + + +def _load_cot_sequences( + case: EngineStepCase, + *, + vocab_size: int, +) -> list[torch.Tensor]: + path = Path(case.sequence_data_path) + if not path.exists(): + raise FileNotFoundError( + "DTA engine-step COT sequence data must be generated by pytest before " + f"torchrun starts: {path}" + ) + sequence_metadata = case.cot_sequence_metadata() + payload = torch.load(path, map_location="cpu", weights_only=False) + if int(payload["vocab_size"]) != vocab_size: + raise ValueError( + "Saved DTA engine-step sequence vocab mismatch: " + f"saved={payload['vocab_size']}, current={vocab_size}" + ) + if payload["sequence_metadata"] != sequence_metadata: + raise ValueError( + "Saved DTA engine-step COT sequence metadata mismatch: " + f"saved={payload['sequence_metadata']}, current={sequence_metadata}" + ) + return payload["sequences"] + + +def dta_step_loss_fn( + logprobs: torch.Tensor, + entropy: torch.Tensor, + input_data: dict[str, Any], + **_: Any, +) -> torch.Tensor: + loss_mask = input_data["loss_mask"].to(device=logprobs.device).bool() + if logprobs.ndim != 1 or entropy.ndim != 1 or loss_mask.ndim != 1: + raise ValueError( + "DTA engine-step loss expects 1D packed tensors: " + f"logprobs={tuple(logprobs.shape)}, " + f"entropy={tuple(entropy.shape)}, " + f"loss_mask={tuple(loss_mask.shape)}." + ) + if logprobs.shape != entropy.shape or logprobs.shape != loss_mask.shape: + raise ValueError( + "DTA engine-step loss shape mismatch: " + f"logprobs={tuple(logprobs.shape)}, " + f"entropy={tuple(entropy.shape)}, " + f"loss_mask={tuple(loss_mask.shape)}." + ) + + # Keep the comparison deterministic and local to this engine-step test: + # old/prox logprobs are fixed baselines, while current logprobs and entropy + # both contribute gradients. + old_logprobs = logprobs.detach() + advantages = torch.ones_like(logprobs) + loss_fn = functools.partial( + ppo_actor_loss_fn, + old_logprobs=old_logprobs, + proximal_logprobs=old_logprobs, + advantages=advantages, + eps_clip=0.2, + eps_clip_higher=None, + c_clip=None, + rejection_sampling=None, + importance_sampling_level="token", + cu_seqlens=None, + ) + ppo_loss, _ = loss_fn(logprobs=logprobs, loss_mask=loss_mask) + entropy_bonus = torch.where(loss_mask, entropy, entropy.new_zeros(())).sum() + entropy_bonus = entropy_bonus / loss_mask.count_nonzero().clamp(min=1) + return ppo_loss - 0.01 * entropy_bonus + + +def dta_step_loss_weight_fn(input_data: dict[str, Any]) -> torch.Tensor: + loss_mask = input_data["loss_mask"] + return loss_mask.float().sum().to(loss_mask.device) + + +def _strip_wrapper_prefixes(name: str) -> str: + return name.replace("._checkpoint_wrapped_module", "").replace("._orig_mod", "") + + +def _full_tensor(tensor: torch.Tensor) -> torch.Tensor: + if isinstance(tensor, DTensor): + tensor = tensor.full_tensor() + return tensor.detach().float().cpu().clone() + + +def _canonical_name(engine: ArchonLMEngine, name: str, tensor: torch.Tensor) -> str: + name = _strip_wrapper_prefixes(name) + adapter = engine.state_dict_adapter + if adapter is None: + return name + mapped = adapter.convert_single_to_hf(name, tensor) + if mapped: + return mapped[0][0] + return name + + +def _snapshot_grad_tensors(engine: ArchonLMEngine) -> dict[str, torch.Tensor]: + tensors: dict[str, torch.Tensor] = {} + for raw_name, param in engine._get_model_name_parameters(): + if not param.requires_grad or param.grad is None: + continue + grad = _full_tensor(param.grad) + name = _canonical_name(engine, raw_name, grad) + tensors[name] = grad + return tensors + + +def create_config( + *, + model_path: str, + mode: str, + dtype: str, + dp_size: int, + max_tokens_per_mb: int, + dta_block_size: int, + gradient_checkpointing: bool, + optimizer_type: str, + lr: float, +) -> TrainEngineConfig: + config = TrainEngineConfig( + path=model_path, + experiment_name=f"dta_engine_step_{mode}", + trial_name="torchrun", + backend=f"archon:d{dp_size}", + dtype=dtype, + mb_spec=MicroBatchSpec( + n_mbs=1, + granularity=1, + max_tokens_per_mb=max_tokens_per_mb, + ), + pad_to_maximum=False, + gradient_checkpointing=gradient_checkpointing, + init_from_scratch=False, + tree_training_mode="dta" if mode == "dta" else "disabled", + dta_block_size=dta_block_size, + optimizer=OptimizerConfig( + type=optimizer_type, + lr=lr, + weight_decay=0.0, + lr_scheduler_type="constant", + warmup_steps_proportion=0.0, + gradient_clipping=1.0e9, + ), + ) + config.archon.attn_type = "sdpa" + return config + + +def run_engine_step(case: EngineStepCase) -> dict[str, Any]: + torch.manual_seed(1234) + torch.cuda.manual_seed_all(1234) + + local_rank = int(os.environ["LOCAL_RANK"]) + current_platform.set_device(local_rank) + current_platform.set_numa_affinity(local_rank) + + world_size = int(os.environ["WORLD_SIZE"]) + model_path = case.resolve_model_path() + config = create_config( + model_path=model_path, + mode=case.mode, + dtype=case.dtype, + dp_size=world_size, + max_tokens_per_mb=case.max_tokens_per_mb, + dta_block_size=case.dta_block_size, + gradient_checkpointing=case.gradient_checkpointing, + optimizer_type=case.optimizer_type, + lr=case.lr, + ) + + engine = ArchonLMEngine(config) + try: + engine.create_process_group( + parallel_strategy=ParallelStrategy(data_parallel_size=world_size) + ) + ft_spec = FinetuneSpec( + total_train_epochs=1, + dataset_size=case.dataset_size, + train_batch_size=case.dataset_size, + ) + engine.initialize(addr=None, ft_spec=ft_spec) + engine.train() + + vocab_size = int(getattr(engine.model_config, "vocab_size")) + sequences = _load_cot_sequences( + case, + vocab_size=vocab_size, + ) + batch = build_local_batch( + sequences, + dp_rank=engine.data_parallel_rank, + dp_world_size=engine.data_parallel_world_size, + ) + + forward_logprobs = engine.forward_batch(batch).detach().float().cpu() + forward_loss_mask = batch["loss_mask"].detach().bool().cpu() + + stats = engine.train_batch( + batch, + dta_step_loss_fn, + dta_step_loss_weight_fn, + return_loss=True, + ) + global_loss = torch.tensor( + float(stats["loss"]), device=engine.device, dtype=torch.float32 + ) + dist.all_reduce(global_loss, group=engine.data_parallel_group) + stats["global_loss"] = float(global_loss.item()) + current_platform.synchronize() + grads = _snapshot_grad_tensors(engine) + + return { + "case": asdict(case), + "mode": case.mode, + "dtype": case.dtype, + "world_size": world_size, + "model_path": model_path, + "local_model_path": case.local_model_path, + "hf_id": case.hf_id, + "sequence_data_path": case.sequence_data_path, + "stats": stats, + "grads": grads, + "forward_logprobs": forward_logprobs, + "forward_loss_mask": forward_loss_mask, + "num_sequences": len(sequences), + "sequence_lengths": [int(seq.numel()) for seq in sequences], + } + finally: + engine.destroy() + + +def parse_args() -> argparse.Namespace: + parser = argparse.ArgumentParser() + parser.add_argument("--case-config", required=True) + return parser.parse_args() + + +def main() -> None: + args = parse_args() + case = EngineStepCase.load(Path(args.case_config)) + if int(os.environ.get("RANK", "0")) == 0: + torch.save(run_engine_step(case), case.payload_path) + else: + run_engine_step(case) + + +if __name__ == "__main__": + main() diff --git a/tests/test_eval_dispatch.py b/tests/test_eval_dispatch.py index 30c40706fa..34fdf64c52 100644 --- a/tests/test_eval_dispatch.py +++ b/tests/test_eval_dispatch.py @@ -13,6 +13,7 @@ _dispatch_tensors, _pad_eval_batch, ) +from areal.infra.dp_allocation import AllocationInput, allocate_trajectories from areal.trainer.rw.rw_engine import ( RWController, RWEngine, @@ -274,12 +275,39 @@ def test_dispatch_group_size_not_divisible_raises(self): with pytest.raises(ValueError, match="divisible by group_size"): _dispatch_tensors(items, dp_size=2, group_size=2) + def test_allocate_trajectories_ffd_equal_preserves_group_size(self): + items = _build_rw_batch(n_pairs=4) + allocation = allocate_trajectories( + AllocationInput( + items=items, + n_groups=2, + algorithm="ffd_equal", + group_size=2, + ) + ) + + assert allocation.items is items + assert len(allocation.group_indices) == 2 + for indices in allocation.group_indices: + assert len(indices) % 2 == 0 + for idx in range(0, len(indices), 2): + assert indices[idx + 1] == indices[idx] + 1 + def test_dispatch_group_size_1_unchanged(self): items = [_make_item(i, seqlen=i + 1) for i in range(8)] _, indices_default = _dispatch_tensors(items, dp_size=4) _, indices_gs1 = _dispatch_tensors(items, dp_size=4, group_size=1) assert indices_default == indices_gs1 + def test_dispatch_default_is_explicit_ffd_equal(self): + items = [_make_item(i, seqlen=i + 1) for i in range(8)] + _, indices_default = _dispatch_tensors(items, dp_size=4) + _, indices_equal = _dispatch_tensors( + items, dp_size=4, packing_algorithm="ffd_equal" + ) + assert indices_default == indices_equal + assert [len(indices) for indices in indices_equal] == [2, 2, 2, 2] + @pytest.mark.parametrize( "dp_size, group_size, n_items", [ diff --git a/tests/test_examples.py b/tests/test_examples.py index 2acb62f18a..68684280cd 100644 --- a/tests/test_examples.py +++ b/tests/test_examples.py @@ -868,7 +868,7 @@ def test_tau2(tmp_path_factory): "econfig.max_steps=3", # Limit steps for faster testing f"econfig.user_llm_base_url={user_llm_base_url}", "econfig.user_llm=openai/self-hosted-qwen3", - "actor.enable_tree_training=false", # Disable tree training for simpler test + "actor.tree_training_mode=disabled", # Disable tree training for simpler test "scheduler.type=local", "stats_logger.wandb.mode=disabled", timeout=600, diff --git a/tests/test_tree_training.py b/tests/test_tree_training.py index e88c358033..0485586d1a 100644 --- a/tests/test_tree_training.py +++ b/tests/test_tree_training.py @@ -166,7 +166,7 @@ def _check_nan_params(params: dict[str, torch.Tensor], label: str) -> list[str]: def _create_engine( engine_type: str, - enable_tree_training: bool = False, + tree_training_mode: str = "disabled", port: str = "7777", experiment_name: str = "test", max_tokens_per_mb: int = 256, @@ -194,7 +194,7 @@ def _create_engine( path=MODEL_PATH, mb_spec=MicroBatchSpec(**mb_spec_kwargs), optimizer=OptimizerConfig(), - enable_tree_training=enable_tree_training, + tree_training_mode=tree_training_mode, pad_to_maximum=True, ) @@ -245,7 +245,7 @@ def test_tree_training_forward(engine_type, tree_attn_backend): inputs = mock_tree_input() tree_engine = _create_engine( engine_type, - enable_tree_training=True, + tree_training_mode="sparse", port="7778", ) tree_engine.eval() @@ -347,7 +347,7 @@ def loss_weight_fn(input_data): inputs = mock_tree_input() tree_engine = _create_engine( engine_type, - enable_tree_training=True, + tree_training_mode="sparse", port="7778", experiment_name="test_tree", ) diff --git a/tests/test_utils.py b/tests/test_utils.py index 0c82b47517..3fe6cae288 100644 --- a/tests/test_utils.py +++ b/tests/test_utils.py @@ -73,3 +73,28 @@ def test_micro_batch_split(mock_padded_data, n_mbs, max_tokens_per_mb, n_mbs_div assert torch.allclose(x, packed_data[key]) y = pad_and_stack_tensors_along_first_dim(xs) assert torch.allclose(mock_padded_data[key], y) + + +def test_micro_batch_split_n_mbs_equal_batch_size_uses_allocator( + mock_padded_data, monkeypatch +): + calls = 0 + original_allocate = __import__( + "areal.utils.data", fromlist=["allocate_balanced_mbs_synced"] + ).allocate_balanced_mbs_synced + + def track_allocate(*args, **kwargs): + nonlocal calls + calls += 1 + return original_allocate(*args, **kwargs) + + monkeypatch.setattr("areal.utils.data.allocate_balanced_mbs_synced", track_allocate) + + bs = mock_padded_data["attention_mask"].shape[0] + mb_spec = MicroBatchSpec(n_mbs=bs, max_tokens_per_mb=100) + + split_result = split_padded_tensor_dict_into_mb_list(mock_padded_data, mb_spec) + + assert calls == 1 + assert len(split_result.mbs) == bs + assert all(mb["input_ids"].shape[0] == 1 for mb in split_result.mbs)