refactor(experimental): consolidate DTA Archon integration #1391

Open
ezoicoder wants to merge 1 commit into areal-project:main from ezoicoder:feat/zero1-dta-archon-dp

@ezoicoder ezoicoder commented Jun 5, 2026

Description

Consolidates Dynamic Tree Attention (DTA) support into the experimental Archon
path. This adds DTA allocation, trie construction, runner/wrapper integration,
rollout preparation, examples, documentation, and focused regression coverage.

This update also fixes DTA microbatch construction so one-sequence-per-microbatch
batches stay per-rank independent instead of being forced through cross-rank
microbatch-count synchronization. The torchrun DTA case uses 17 turns so two
ranks receive uneven local sequence counts.
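
To illustrate why 17 turns exercise the uneven case, here is a minimal sketch. It assumes sequences are dealt to ranks round-robin and that each sequence becomes its own microbatch; the helper names `split_round_robin` and `microbatches_per_rank` are hypothetical, and the DTA allocator in this PR may partition differently.

```python
def split_round_robin(num_sequences: int, world_size: int) -> list[list[int]]:
    """Deal sequence indices to ranks round-robin (an assumed policy, for illustration)."""
    return [list(range(rank, num_sequences, world_size)) for rank in range(world_size)]


def microbatches_per_rank(num_sequences: int, world_size: int) -> list[int]:
    """With one sequence per microbatch, the microbatch count equals the local sequence count."""
    return [len(local) for local in split_round_robin(num_sequences, world_size)]


counts = microbatches_per_rank(num_sequences=17, world_size=2)
print(counts)  # [9, 8]
```

Since the two ranks end up with different counts, any step that forces a shared microbatch count would presumably have to pad or re-split one rank's batches, which is the synchronization this change avoids.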

The latest update adds an end-to-end global loss comparison for the DTA
engine-step test. The torchrun runner now all-reduces each rank's returned local
loss contribution into stats["global_loss"], and the pytest comparison checks
baseline Archon DP against DTA.
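
The global loss reduction can be pictured with a small single-process simulation. It assumes each rank returns its summed token loss divided by the global token count, so that a sum all-reduce yields the global mean. That normalization, and the names `local_contribution` and `simulated_all_reduce_sum`, are assumptions for illustration rather than the PR's actual code; in a real torchrun job the reduction would go through `torch.distributed.all_reduce`.

```python
import math


def local_contribution(token_losses: list[float], global_token_count: int) -> float:
    """Rank-local share of the global mean loss (assumed normalization)."""
    return sum(token_losses) / global_token_count


def simulated_all_reduce_sum(per_rank_values: list[float]) -> float:
    """Stand-in for a SUM all-reduce: every rank would end up with this same total."""
    return math.fsum(per_rank_values)


# Two ranks with uneven token counts.
rank_losses = [[0.5, 1.5, 1.0], [2.0, 1.0]]
total_tokens = sum(len(losses) for losses in rank_losses)

stats = {
    "global_loss": simulated_all_reduce_sum(
        [local_contribution(losses, total_tokens) for losses in rank_losses]
    )
}

# The reduced value matches the mean over all tokens, however they were split.
reference = sum(sum(losses) for losses in rank_losses) / total_tokens
assert math.isclose(stats["global_loss"], reference)
```

Under this assumption, comparing `stats["global_loss"]` between the baseline and DTA runs does not depend on how sequences were distributed across ranks.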

Related Issue

N/A

Type of Change

  • Bug fix
  • New feature
  • Breaking change
  • Documentation update
  • Refactoring
  • Performance improvement
  • Test coverage improvement

Checklist

  • I have read the Contributing Guide
  • Pre-commit hooks pass (pre-commit run --all-files)
  • Relevant tests pass; new tests added for new functionality
  • Documentation updated if applicable
  • Branch is up to date with main
  • Self-reviewed via /review-pr command
  • This PR was created by a coding agent via /create-pr
  • This PR is a breaking change

Breaking Change Details (if applicable):

N/A

Additional Context

Validation run:

source .venv/bin/activate && pre-commit run --all-files
source /data/jiarui/dta/proxy/network_setup.sh && uv run pytest tests/experimental/dta/test_engine_step.py -k fp32_grad_match -q

Targeted torchrun result:

1 passed, 2 deselected in 128.49s (0:02:08)

Skipped the full repository test and docs build suites in this pass due to scope
and runtime; the targeted DTA engine-step regression was run on the multi-GPU path.

@gemini-code-assist gemini-code-assist Bot left a comment

Code Review

This pull request introduces Dynamic Tree Attention (DTA) as a new tree training mode, replacing the boolean enable_tree_training flag with a multi-option tree_training_mode string. It adds the areal/experimental/dta module, integrates DTA into the Archon engine via a DTAWrapper, and updates attention mechanisms and Qwen2/Qwen3 models to support KV-cache attention. The review feedback highlights critical runtime AttributeError risks across the Qwen2/Qwen3 models and the DTA runner, where the code incorrectly assumes DynamicCache has a .layers attribute instead of using its standard key_cache and value_cache structures.

Important

The consumer version of Gemini Code Assist on GitHub is being sunset. Starting June 18, 2026, new organization installations will be blocked, and all code review activity will officially cease on July 17, 2026.
For more details on the timeline and next steps, please review the Help Documentation.

Comment on lines +399 to +401
past_len = 0
if len(past_key_values.layers) > 0:
    past_len = int(past_key_values.layers[0].keys.shape[2])

high

The code assumes past_key_values has a .layers attribute. However, the standard transformers.cache_utils.DynamicCache class stores key/value states in key_cache and value_cache lists and does not have a .layers attribute. This will cause an AttributeError at runtime. Use get_seq_length() instead.

Suggested change
- past_len = 0
- if len(past_key_values.layers) > 0:
-     past_len = int(past_key_values.layers[0].keys.shape[2])
+ past_len = past_key_values.get_seq_length()
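
Which of these attributes a cache object exposes can differ between transformers releases, so a lookup that tolerates both layouts may be safer than hard-coding either one. The sketch below is only an illustration: `cached_seq_len` is a hypothetical helper, and the stand-in objects merely mimic the attribute shapes discussed in this thread so the example runs without transformers installed.

```python
from types import SimpleNamespace


def cached_seq_len(cache) -> int:
    """Return the cached sequence length, preferring the public accessor."""
    if cache is None:
        return 0
    # Preferred: the accessor suggested in the review.
    if hasattr(cache, "get_seq_length"):
        return int(cache.get_seq_length())
    # Fallback 1: per-layer objects holding `.keys` / `.values` tensors.
    layers = getattr(cache, "layers", None)
    if layers:
        return int(layers[0].keys.shape[2])
    # Fallback 2: parallel `key_cache` / `value_cache` lists.
    key_cache = getattr(cache, "key_cache", None)
    if key_cache:
        return int(key_cache[0].shape[2])
    return 0


# Stand-ins with shape (batch, heads, seq_len, head_dim); not real transformers classes.
tensor = SimpleNamespace(shape=(1, 2, 7, 4))
layered = SimpleNamespace(layers=[SimpleNamespace(keys=tensor, values=tensor)])
listed = SimpleNamespace(key_cache=[tensor], value_cache=[tensor])

print(cached_seq_len(layered), cached_seq_len(listed), cached_seq_len(None))  # 7 7 0
```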

Comment on lines +431 to +433
if past_key_values is not None and layer_idx < len(past_key_values.layers):
    layer_entry = past_key_values.layers[layer_idx]
    layer_past = (layer_entry.keys, layer_entry.values)

high

The code assumes past_key_values has a .layers attribute. Standard DynamicCache uses key_cache and value_cache lists. Accessing .layers will cause an AttributeError at runtime.

Suggested change
- if past_key_values is not None and layer_idx < len(past_key_values.layers):
-     layer_entry = past_key_values.layers[layer_idx]
-     layer_past = (layer_entry.keys, layer_entry.values)
+ if past_key_values is not None and layer_idx < len(past_key_values):
+     layer_past = (past_key_values.key_cache[layer_idx], past_key_values.value_cache[layer_idx])
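
A per-layer lookup can also be written to accept either cache layout rather than committing to one. This is a hedged sketch, not the PR's code: `layer_kv` is a hypothetical helper, and the stand-in objects below only imitate the two attribute layouts named in this comment.

```python
from types import SimpleNamespace


def layer_kv(cache, layer_idx: int):
    """Return (keys, values) for one layer, or None when that layer is not cached."""
    if cache is None:
        return None
    # Layout A: a list of per-layer entries with `.keys` and `.values`.
    layers = getattr(cache, "layers", None)
    if layers is not None:
        if layer_idx >= len(layers):
            return None
        entry = layers[layer_idx]
        return entry.keys, entry.values
    # Layout B: parallel `key_cache` and `value_cache` lists.
    key_cache = getattr(cache, "key_cache", None)
    value_cache = getattr(cache, "value_cache", None)
    if key_cache is not None and value_cache is not None and layer_idx < len(key_cache):
        return key_cache[layer_idx], value_cache[layer_idx]
    return None


# Strings stand in for tensors so the example has no dependencies.
layered = SimpleNamespace(layers=[SimpleNamespace(keys="k0", values="v0")])
listed = SimpleNamespace(key_cache=["k0"], value_cache=["v0"])

print(layer_kv(layered, 0), layer_kv(listed, 0), layer_kv(listed, 1))
# ('k0', 'v0') ('k0', 'v0') None
```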

Comment on lines +513 to +515
past_len = 0
if len(past_key_values.layers) > 0:
    past_len = int(past_key_values.layers[0].keys.shape[2])

high

The code assumes past_key_values has a .layers attribute. However, the standard transformers.cache_utils.DynamicCache class stores key/value states in key_cache and value_cache lists and does not have a .layers attribute. This will cause an AttributeError at runtime. Use get_seq_length() instead.

Suggested change
- past_len = 0
- if len(past_key_values.layers) > 0:
-     past_len = int(past_key_values.layers[0].keys.shape[2])
+ past_len = past_key_values.get_seq_length()

Comment on lines +549 to +551
if past_key_values is not None and layer_idx < len(past_key_values.layers):
    layer_entry = past_key_values.layers[layer_idx]
    layer_past = (layer_entry.keys, layer_entry.values)

high

The code assumes past_key_values has a .layers attribute. Standard DynamicCache uses key_cache and value_cache lists. Accessing .layers will cause an AttributeError at runtime.

Suggested change
- if past_key_values is not None and layer_idx < len(past_key_values.layers):
-     layer_entry = past_key_values.layers[layer_idx]
-     layer_past = (layer_entry.keys, layer_entry.values)
+ if past_key_values is not None and layer_idx < len(past_key_values):
+     layer_past = (past_key_values.key_cache[layer_idx], past_key_values.value_cache[layer_idx])

Comment on lines +263 to +270
new_cache = out.past_key_values
for layer_idx, layer in enumerate(new_cache.layers):
    self.kv_cache[0][layer_idx][:, :, start:end, :] = layer.keys[
        :, :, start:end, :
    ]
    self.kv_cache[1][layer_idx][:, :, start:end, :] = layer.values[
        :, :, start:end, :
    ]

high

The code assumes out.past_key_values has a .layers attribute. However, standard DynamicCache uses key_cache and value_cache lists. Accessing .layers will cause an AttributeError at runtime.

Suggested change
- new_cache = out.past_key_values
- for layer_idx, layer in enumerate(new_cache.layers):
-     self.kv_cache[0][layer_idx][:, :, start:end, :] = layer.keys[
-         :, :, start:end, :
-     ]
-     self.kv_cache[1][layer_idx][:, :, start:end, :] = layer.values[
-         :, :, start:end, :
-     ]
+ new_cache = out.past_key_values
+ for layer_idx in range(len(new_cache)):
+     self.kv_cache[0][layer_idx][:, :, start:end, :] = new_cache.key_cache[layer_idx][
+         :, :, start:end, :
+     ]
+     self.kv_cache[1][layer_idx][:, :, start:end, :] = new_cache.value_cache[layer_idx][
+         :, :, start:end, :
+     ]

Comment on lines +550 to +559
for layer_idx, layer in enumerate(block_cache.layers):
    k = layer.keys[:, :, start:end, :]
    v = layer.values[:, :, start:end, :]
    roots.extend([k, v])
    grads.extend(
        [
            self.grad_kv[0][layer_idx][:, :, start:end, :],
            self.grad_kv[1][layer_idx][:, :, start:end, :],
        ]
    )

high

The code assumes block_cache has a .layers attribute. However, standard DynamicCache uses key_cache and value_cache lists. Accessing .layers will cause an AttributeError at runtime.

Suggested change
- for layer_idx, layer in enumerate(block_cache.layers):
-     k = layer.keys[:, :, start:end, :]
-     v = layer.values[:, :, start:end, :]
-     roots.extend([k, v])
-     grads.extend(
-         [
-             self.grad_kv[0][layer_idx][:, :, start:end, :],
-             self.grad_kv[1][layer_idx][:, :, start:end, :],
-         ]
-     )
+ for layer_idx in range(len(block_cache)):
+     k = block_cache.key_cache[layer_idx][:, :, start:end, :]
+     v = block_cache.value_cache[layer_idx][:, :, start:end, :]
+     roots.extend([k, v])
+     grads.extend(
+         [
+             self.grad_kv[0][layer_idx][:, :, start:end, :],
+             self.grad_kv[1][layer_idx][:, :, start:end, :],
+         ]
+     )

@ezoicoder ezoicoder force-pushed the feat/zero1-dta-archon-dp branch 2 times, most recently from e0c0e7e to 5fdea9c, June 5, 2026 17:00

Integrate the Dynamic Tree Attention training path with Archon DP while keeping unsupported engines explicit.

Key changes:

- Add DTA trie, runner, allocation, rollout, and Zero1 wrapper utilities

- Route Archon tree_training_mode='dta' through DTA-specific batch handling

- Add DTA examples, docs, and distributed engine-step coverage

@ezoicoder ezoicoder force-pushed the feat/zero1-dta-archon-dp branch from 5fdea9c to 72acd4b, June 6, 2026 02:11