feat(experimental): enable DTA training for Archon DP #1391
Code Review
This pull request introduces Dynamic Tree Attention (DTA) as a new tree training mode, replacing the boolean enable_tree_training flag with a multi-option tree_training_mode string. It adds the areal/experimental/dta module, integrates DTA into the Archon engine via a DTAWrapper, and updates the attention mechanisms and the Qwen2/Qwen3 models to support KV-cache attention. The review feedback flags potential runtime AttributeError failures in the Qwen2/Qwen3 models and the DTA runner: the code reads a .layers attribute from DynamicCache, which is absent in transformers releases that store cached states in key_cache and value_cache lists.
    past_len = 0
    if len(past_key_values.layers) > 0:
        past_len = int(past_key_values.layers[0].keys.shape[2])
The code assumes past_key_values exposes a .layers attribute. In transformers releases where transformers.cache_utils.DynamicCache stores key/value states in key_cache and value_cache lists, there is no .layers attribute and this raises an AttributeError at runtime. Use get_seq_length() instead, which does not depend on the internal layout.
Suggested change:
    - past_len = 0
    - if len(past_key_values.layers) > 0:
    -     past_len = int(past_key_values.layers[0].keys.shape[2])
    + past_len = past_key_values.get_seq_length()
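Whether .layers or key_cache is present depends on the installed transformers release, so a layout-agnostic length lookup is one way to sidestep the question. The sketch below is an illustration, not code from this PR: cached_seq_len is a hypothetical helper, and FakeCache and the SimpleNamespace objects are stand-ins for a real cache so that the example runs without transformers installed.

```python
from types import SimpleNamespace


def cached_seq_len(cache) -> int:
    """Return the number of cached positions, whichever layout the cache uses.

    Hypothetical helper for illustration; it is not part of transformers.
    """
    if cache is None:
        return 0
    # Prefer the public accessor when the object provides one.
    if hasattr(cache, "get_seq_length"):
        return int(cache.get_seq_length())
    # Layered layout: one object per layer carrying .keys / .values.
    layers = getattr(cache, "layers", None)
    if layers is not None:
        if len(layers) == 0 or getattr(layers[0], "keys", None) is None:
            return 0
        return int(layers[0].keys.shape[2])
    # List layout: parallel key_cache / value_cache lists.
    keys = getattr(cache, "key_cache", [])
    return int(keys[0].shape[2]) if len(keys) > 0 else 0


class FakeCache:
    """Stand-in exposing only the public accessor."""

    def get_seq_length(self):
        return 7


# Stand-in tensors: only .shape is needed, laid out as (batch, heads, seq, dim).
tensor = SimpleNamespace(shape=(1, 2, 5, 8))
layered = SimpleNamespace(layers=[SimpleNamespace(keys=tensor, values=tensor)])
flat = SimpleNamespace(key_cache=[tensor], value_cache=[tensor])
empty = SimpleNamespace(layers=[])

lengths = [cached_seq_len(c) for c in (FakeCache(), layered, flat, empty, None)]
```

When the cache object provides get_seq_length(), as the suggestion above assumes, the fallbacks are never reached; they only matter for cache-like objects that lack the accessor.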
    if past_key_values is not None and layer_idx < len(past_key_values.layers):
        layer_entry = past_key_values.layers[layer_idx]
        layer_past = (layer_entry.keys, layer_entry.values)
The code assumes past_key_values exposes a .layers attribute. In transformers releases where DynamicCache stores its states in key_cache and value_cache lists, accessing .layers raises an AttributeError at runtime.
Suggested change:
    - if past_key_values is not None and layer_idx < len(past_key_values.layers):
    -     layer_entry = past_key_values.layers[layer_idx]
    -     layer_past = (layer_entry.keys, layer_entry.values)
    + if past_key_values is not None and layer_idx < len(past_key_values):
    +     layer_past = (past_key_values.key_cache[layer_idx], past_key_values.value_cache[layer_idx])
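Both the original code and the suggested change are tied to one cache layout. A per-layer accessor that checks which attributes exist could serve either one. This is a minimal sketch under that assumption: layer_kv is a hypothetical name, and the SimpleNamespace objects stand in for real cache instances holding tensors.

```python
from types import SimpleNamespace


def layer_kv(cache, layer_idx):
    """Return (keys, values) for one layer, or None if nothing is cached there.

    Hypothetical helper for illustration; it is not part of transformers.
    """
    if cache is None:
        return None
    layers = getattr(cache, "layers", None)
    if layers is not None:
        # Layered layout: one object per layer carrying .keys and .values.
        if layer_idx >= len(layers):
            return None
        entry = layers[layer_idx]
        if getattr(entry, "keys", None) is None:
            return None
        return entry.keys, entry.values
    # List layout: parallel key_cache / value_cache lists.
    keys = getattr(cache, "key_cache", [])
    if layer_idx >= len(keys):
        return None
    return keys[layer_idx], cache.value_cache[layer_idx]


# Strings stand in for tensors; only the lookup logic is exercised here.
layered = SimpleNamespace(layers=[SimpleNamespace(keys="k0", values="v0")])
flat = SimpleNamespace(key_cache=["k0"], value_cache=["v0"])

layer_past = layer_kv(layered, 0)
```

Returning None for a missing layer mirrors the guard in the reviewed snippet, where layer_past is only assigned when the cache exists and already holds that layer.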
    past_len = 0
    if len(past_key_values.layers) > 0:
        past_len = int(past_key_values.layers[0].keys.shape[2])
The code assumes past_key_values exposes a .layers attribute. In transformers releases where transformers.cache_utils.DynamicCache stores key/value states in key_cache and value_cache lists, there is no .layers attribute and this raises an AttributeError at runtime. Use get_seq_length() instead, which does not depend on the internal layout.
Suggested change:
    - past_len = 0
    - if len(past_key_values.layers) > 0:
    -     past_len = int(past_key_values.layers[0].keys.shape[2])
    + past_len = past_key_values.get_seq_length()
    if past_key_values is not None and layer_idx < len(past_key_values.layers):
        layer_entry = past_key_values.layers[layer_idx]
        layer_past = (layer_entry.keys, layer_entry.values)
The code assumes past_key_values exposes a .layers attribute. In transformers releases where DynamicCache stores its states in key_cache and value_cache lists, accessing .layers raises an AttributeError at runtime.
Suggested change:
    - if past_key_values is not None and layer_idx < len(past_key_values.layers):
    -     layer_entry = past_key_values.layers[layer_idx]
    -     layer_past = (layer_entry.keys, layer_entry.values)
    + if past_key_values is not None and layer_idx < len(past_key_values):
    +     layer_past = (past_key_values.key_cache[layer_idx], past_key_values.value_cache[layer_idx])
    new_cache = out.past_key_values
    for layer_idx, layer in enumerate(new_cache.layers):
        self.kv_cache[0][layer_idx][:, :, start:end, :] = layer.keys[
            :, :, start:end, :
        ]
        self.kv_cache[1][layer_idx][:, :, start:end, :] = layer.values[
            :, :, start:end, :
        ]
The code assumes out.past_key_values exposes a .layers attribute. In transformers releases where DynamicCache stores its states in key_cache and value_cache lists, accessing .layers raises an AttributeError at runtime.
Suggested change:
    - new_cache = out.past_key_values
    - for layer_idx, layer in enumerate(new_cache.layers):
    -     self.kv_cache[0][layer_idx][:, :, start:end, :] = layer.keys[
    -         :, :, start:end, :
    -     ]
    -     self.kv_cache[1][layer_idx][:, :, start:end, :] = layer.values[
    -         :, :, start:end, :
    -     ]
    + new_cache = out.past_key_values
    + for layer_idx in range(len(new_cache)):
    +     self.kv_cache[0][layer_idx][:, :, start:end, :] = new_cache.key_cache[layer_idx][
    +         :, :, start:end, :
    +     ]
    +     self.kv_cache[1][layer_idx][:, :, start:end, :] = new_cache.value_cache[layer_idx][
    +         :, :, start:end, :
    +     ]
    for layer_idx, layer in enumerate(block_cache.layers):
        k = layer.keys[:, :, start:end, :]
        v = layer.values[:, :, start:end, :]
        roots.extend([k, v])
        grads.extend(
            [
                self.grad_kv[0][layer_idx][:, :, start:end, :],
                self.grad_kv[1][layer_idx][:, :, start:end, :],
            ]
        )
The code assumes block_cache exposes a .layers attribute. In transformers releases where DynamicCache stores its states in key_cache and value_cache lists, accessing .layers raises an AttributeError at runtime.
Suggested change:
    - for layer_idx, layer in enumerate(block_cache.layers):
    -     k = layer.keys[:, :, start:end, :]
    -     v = layer.values[:, :, start:end, :]
    -     roots.extend([k, v])
    -     grads.extend(
    -         [
    -             self.grad_kv[0][layer_idx][:, :, start:end, :],
    -             self.grad_kv[1][layer_idx][:, :, start:end, :],
    -         ]
    -     )
    + for layer_idx in range(len(block_cache)):
    +     k = block_cache.key_cache[layer_idx][:, :, start:end, :]
    +     v = block_cache.value_cache[layer_idx][:, :, start:end, :]
    +     roots.extend([k, v])
    +     grads.extend(
    +         [
    +             self.grad_kv[0][layer_idx][:, :, start:end, :],
    +             self.grad_kv[1][layer_idx][:, :, start:end, :],
    +         ]
    +     )
72acd4b to 98e179e
Add a Dynamic Tree Attention path for Archon data-parallel training so shared-prefix rollout trajectories can be trained with block-wise backward while keeping unsupported engines explicit.
Key changes:
- Add tree_training_mode=dta and rollout-level DTA allocation config
- Route Archon train and forward batches through the DTA wrapper
- Add trie, allocation, runner, Zero1, and KV-cache model support
- Report DTA allocation metrics during distributed rollout
- Add examples, docs, and torchrun regression coverage against baseline DP
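To make the shared-prefix idea concrete, the sketch below merges a few token sequences into a trie and compares the number of trie nodes with the total token count. It is a generic illustration rather than the trie added by this PR: TrieNode, build_trie, count_nodes, and the token ids are all hypothetical.

```python
from dataclasses import dataclass, field


@dataclass
class TrieNode:
    children: dict = field(default_factory=dict)  # token id -> TrieNode
    count: int = 0  # number of sequences passing through this node


def build_trie(sequences):
    """Insert every token sequence; sequences with a common prefix share nodes."""
    root = TrieNode()
    for seq in sequences:
        node = root
        for tok in seq:
            node = node.children.setdefault(tok, TrieNode())
            node.count += 1
    return root


def count_nodes(node):
    """Count token nodes below `node` (the root itself holds no token)."""
    return sum(1 + count_nodes(child) for child in node.children.values())


# Three rollouts that share the prefix [1, 2]; two of them also share [1, 2, 3].
rollouts = [
    [1, 2, 3, 4],
    [1, 2, 3, 5, 6],
    [1, 2, 7],
]
root = build_trie(rollouts)
total_tokens = sum(len(seq) for seq in rollouts)
unique_tokens = count_nodes(root)
print(f"{total_tokens} tokens across rollouts, {unique_tokens} trie nodes")
```

In this toy input the three rollouts hold 12 tokens but only 7 distinct trie positions, which is the kind of redundancy a tree-structured training path can avoid recomputing.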
98e179e to 35508ea
Description
Add a Dynamic Tree Attention path for Archon data-parallel training so shared-prefix rollout trajectories can be trained with block-wise backward while unsupported engines remain explicit.
Key changes:
- Add tree_training_mode=dta and rollout-level DTA allocation config
Related Issue
N/A
Type of Change
Checklist
pre-commit run --all-files
main
/review-pr command
/create-pr
Breaking Change Details (if applicable):
N/A
Additional Context
Validation run:
DTA test result:
DTA test commands: