Commits (133, all authored by 3outeille)
7488385  create transformer_backend folder with debug run (Aug 28, 2025)
39a3b34  add hf config (Aug 28, 2025)
ea7c594  can now register train spec for hf model (Aug 28, 2025)
5f0adf5  can now switch with different flavors using HF Llama modeling (Aug 28, 2025)
7c3795c  it is now working up to apply_ac (Aug 28, 2025)
3fb2bf8  now working up to init_weights (Sep 6, 2025)
25daeca  fix mapping when convert_to_hf_config + add breaking test to ensure p… (Sep 6, 2025)
3e67f2c  define own apply_ac for transformer backend instead of reusing llama3 (Sep 8, 2025)
8c5c0ae  HF model without any parallelism now train (but grad_norm is high) (Sep 9, 2025)
4ae9560  a bit cleaner way to get passed args (Sep 10, 2025)
9be95f9  now same number of params + same attention backend but noticed highe… (Sep 10, 2025)
bf91447  fix seed and deterministic (Sep 11, 2025)
4c2fc0b  fix torch deterministic for HF modeling that was producing Nans (Sep 11, 2025)
9bffa38  HF model now numerically stable compared to TT (given a fixed attent… (Sep 15, 2025)
40d84cc  handling the is_hf_initialized flag in patch (Sep 15, 2025)
bd3f332  refactor HF transformer model args (Sep 16, 2025)
249be92  wrapper model class to avoid transformers to be explicit in train.py (Sep 16, 2025)
e2d4ada  add better testing script with reference log for later sanity check (Sep 16, 2025)
4b498a9  no need to fill passed args (Sep 16, 2025)
eb403d5  can now handle multiple HF modeling (Sep 16, 2025)
a0d67a7  handle pref logits accessing inside HF model wrapper (Sep 16, 2025)
ea05552  isolate HF patch for llama in another file (Sep 16, 2025)
adefa2c  find hacky way to pass HF model.name through CLI (Sep 16, 2025)
a235863  more granularity of logging when doing parameter breakdown (Sep 17, 2025)
fc43dc8  add __repr__ to HFTransformerModelArgs for better debugging logs (Sep 17, 2025)
23ae378  HF deepseek v3 is now training (Sep 17, 2025)
2573be4  refactor to make it clear which args comes from which parts (Sep 17, 2025)
46ae0a3  fix refactor and simplify things (Sep 18, 2025)
b33d575  hacky way to switch flavors for now (Sep 18, 2025)
007f005  hf deepseek train while matching same param counts as tt deepseek (Sep 18, 2025)
dd2b04c  wtf deepseek q_proj weight init differ ??? (Sep 22, 2025)
9abdae3  deepseek now has same weight init in HF & TT. Reasons was rng_state w… (Sep 22, 2025)
f9e90bc  adapt mfu to handle moe (Sep 22, 2025)
ba5d6d1  beginning parallelism by setting tests (Sep 23, 2025)
338a250  better compare_distributed_run test (Sep 24, 2025)
36a5673  add seed + deterministic to compare_distributed_run (Sep 24, 2025)
ed892a2  better extract and compare metrics (Sep 24, 2025)
1c1452f  refactor to introduce slurm (Sep 24, 2025)
5e4911f  error handling with subprocess (Sep 24, 2025)
4891a47  FSDP for llama in 1D works (Sep 24, 2025)
9e260a0  better formatting of compare_distributed_run + display min/max grad_n… (Sep 24, 2025)
a604bee  make FSDP work in a cleaner way (mapping instead of renaming) (Sep 25, 2025)
0b38d0d  Improve logging in compare_distributed_run (Sep 26, 2025)
025a86f  PP for llama in 1D works (Sep 26, 2025)
590737f  simplify PP logic by flattening the named_children hierarchy. This wi… (Sep 28, 2025)
1a9af68  TP now works in 1D (Sep 28, 2025)
e6b9ff5  add test filtering in compare distributed run (Sep 28, 2025)
a4cb8c3  dont generate EP config if model is not a MoE (Sep 28, 2025)
12c0c47  disable torch.utils.deterministic.fill_uninitialized_memory for Moe d… (Sep 28, 2025)
13edc66  CP is now supported (Sep 29, 2025)
52250fb  some cleaning (Sep 29, 2025)
c523ede  cleaner way to make create_causal_mask = None (Sep 29, 2025)
f9f5c66  uniformize llama and moe args passing (Sep 29, 2025)
5a875b6  cleaning code (Sep 29, 2025)
e4d963c  fix same global_batch_size across training + fix float32 for test (ev… (Sep 30, 2025)
957cc4a  refactor compare_distributed_run to make it slurm compatible (Sep 30, 2025)
a317c53  breaking test (Oct 1, 2025)
d2f80a2  refactor test (Oct 4, 2025)
6454e40  fix running job to slurm (Oct 5, 2025)
b99a4d2  finally have a better testing xp with slurm (Oct 5, 2025)
218f400  now everything works (1D/2D/3D/4D). need to fix correctness with PP (Oct 9, 2025)
bb080ad  fix and uniformize weight init of llama-like model + various fix (Oct 14, 2025)
3168f9e  support moe init and fix with moe layer (TP for lora layers) (Oct 15, 2025)
a9a65b7  begin TP + EP with MoE model (Oct 15, 2025)
b4a1b88  cleaning (Oct 15, 2025)
5f1075b  add small example scripts (Oct 15, 2025)
81f1855  Merge branch 'main' into 3outeille/transformers_backend (Oct 17, 2025)
c35ccfc  fix all the merge issues (Oct 20, 2025)
d5ce2e9  get rid of hf patches files and put it in hf_transformer_args (Oct 20, 2025)
8d46723  remove eos_id + refactor Optional[int] to comply with torchtitan conv… (Oct 20, 2025)
087f841  move torch.utils.deterministic.fill_uninitialized_memory = False to u… (Oct 20, 2025)
937c68d  remove test_template for base_config instead (Oct 20, 2025)
4f2b357  separate args &model + dont extract loss metrics -1.0 when double PP … (Oct 20, 2025)
154289d  use recent refactoring for flops computation for dense and moe model (Oct 21, 2025)
1b2cfd7  fix tie_embedding (Oct 21, 2025)
0f2c51e  remove pad_token_id=None (Oct 21, 2025)
4c8b4b7  make it clearer about args (Oct 21, 2025)
c61271e  remove local testing scripts (Oct 21, 2025)
a848545  fix linting (Oct 21, 2025)
9488a16  create CI jobs to guard (Oct 21, 2025)
5be438b  Merge branch 'main' into 3outeille/transformers_backend (Oct 29, 2025)
e8a1757  update the way we register_train_spec (Oct 29, 2025)
141c377  relative path for qwen3_fsdp2_tp2_pp2.toml (Oct 29, 2025)
a67e971  dont use os.environ, use debugmodel or debugmodel_moe (Oct 29, 2025)
060befe  refactor args to make it clearer (Oct 30, 2025)
3425b12  add README (Oct 31, 2025)
7b0ee5d  add requirements.txt (Oct 31, 2025)
3e2222c  fix linting (Oct 31, 2025)
70c348d  fix bug related to training with different seq_len than max_seq_len (Nov 1, 2025)
af0a1cb  decouple MoE logic to another PR (Nov 1, 2025)
980a92b  update experiments README (Nov 3, 2025)
06b6f24  update README to confirm torch.compile support (Nov 3, 2025)
a70c4c4  custom job_config (Nov 4, 2025)
42884cd  remove unecessary change in train_spec (Nov 4, 2025)
4fa0874  rename file to comply with torchtitan style (Nov 4, 2025)
8ffa7f4  reuse ac form torchtitan (Nov 4, 2025)
ff21c2b  reuse ddp from torchtitan (Nov 4, 2025)
0a43a8a  reuse compile from torchtitan llama3 (Nov 4, 2025)
8026bc7  reuse compile from torchtitan (Nov 4, 2025)
cd4042f  update parallelize with main (Nov 4, 2025)
0700bdb  remove moe ep tp for now (Nov 4, 2025)
767f71d  fix SequenceParallel for q and k norm (Nov 5, 2025)
7f71f88  job_config.training will always have seq_len (Nov 5, 2025)
7e63a82  fix loading weights in PP by using Module Dict (Nov 7, 2025)
04fb8eb  clean reference qwen config (Nov 13, 2025)
0d80f62  error out if no layer_idx (Nov 13, 2025)
09f0c94  reuse pipeline from torchtitan (Nov 13, 2025)
78d26ff  use c4 test for integration_tests (Nov 13, 2025)
5243795  fix ci (Nov 13, 2025)
84af768  Merge branch 'main' of github.com:huggingface/torchtitan into 3outeil… (Nov 13, 2025)
fe691b8  fix linting (Nov 13, 2025)
5d5ce2b  fix head dims in flops counting (Nov 14, 2025)
6ace9f4  propose an alternative to passing name (Nov 14, 2025)
97cd6fe  fix linting (Nov 14, 2025)
5f1695f  bump transformers version from 4.55.4 to 4.57.1 (Nov 14, 2025)
2d2b612  change qwen3 config name (Nov 18, 2025)
a2ea2ef  reuse fsdp from llama3. Moe will be handle in another PR (Nov 18, 2025)
47fb2ea  clean logging (Nov 18, 2025)
20308d3  move TitanDenseModelArgs to args (Nov 18, 2025)
019f2cc  clean (Nov 18, 2025)
fc93b4f  fix integration tests (Nov 18, 2025)
f9e8e11  rename integration test file (Nov 18, 2025)
83b0437  update README (Nov 18, 2025)
fb978dd  revert accidental changes linting (Nov 18, 2025)
71ff098  typo in naming (Nov 18, 2025)
663a415  refactor (Nov 18, 2025)
3dbe6fa  revert the way we select HF modeling in config (Nov 18, 2025)
9be95da  Revert "reuse pipeline from torchtitan" (Nov 19, 2025)
c0c273c  pass deterministic.fill_uninitialized_memory to HF model (Nov 19, 2025)
4c50a00  fix linting (Nov 19, 2025)
5b8d38c  fix integration tests (Nov 19, 2025)
57bb8dd  fix minor stuff (Nov 20, 2025)
1bbb3a8  Merge branch 'main' into 3outeille/transformers_backend (Nov 21, 2025)
1 change: 1 addition & 0 deletions torchtitan/experiments/__init__.py
@@ -7,3 +7,4 @@
import torchtitan.experiments.llama4 # noqa: F401
import torchtitan.experiments.qwen3
import torchtitan.experiments.simple_fsdp # noqa: F401
import torchtitan.experiments.transformers_backend # noqa: F401
63 changes: 63 additions & 0 deletions torchtitan/experiments/transformers_backend/__init__.py
@@ -0,0 +1,63 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import dataclasses

from torchtitan.components.loss import build_cross_entropy_loss
from torchtitan.components.lr_scheduler import build_lr_schedulers
from torchtitan.components.optimizer import build_optimizers
from torchtitan.datasets.hf_datasets import build_hf_dataloader
from torchtitan.components.tokenizer import build_hf_tokenizer

from torchtitan.models.llama3 import pipeline_llama
from torchtitan.protocols.train_spec import register_train_spec, TrainSpec

from .infra.parallelize_hf_transformers import parallelize_hf_transformers
from .model.hf_transformers_args import HFTransformerModelArgs

from transformers.models.llama.modeling_llama import LlamaForCausalLM


__all__ = [
    "HFTransformerModelArgs",
    "LlamaForCausalLM", #TODO(3outeille): later use AutoModelForCausalLM
    "hf_transformers_configs",
]


flavors = {
    "debug": HFTransformerModelArgs(
        dim=1,
        n_layers=6,
        n_heads=16,
        rope_theta=500000,
    ),
    "medium": HFTransformerModelArgs(
        dim=40,
        n_layers=24,
        n_heads=32,
        rope_theta=500000,
    ),
    "full": HFTransformerModelArgs(),
}

hf_train_spec = TrainSpec(
    name="hf_auto_model",
    model_cls=LlamaForCausalLM,
    model_args=flavors,
    parallelize_fn=parallelize_hf_transformers,
    pipelining_fn=pipeline_llama,
    build_optimizers_fn=build_optimizers,
    build_lr_schedulers_fn=build_lr_schedulers,
    build_dataloader_fn=build_hf_dataloader,
    build_tokenizer_fn=build_hf_tokenizer,
    build_loss_fn=build_cross_entropy_loss,
)

# Register the same train spec under multiple model names
register_train_spec(hf_train_spec)
register_train_spec(dataclasses.replace(hf_train_spec, name="meta-llama/Llama-3.2-3B"))
register_train_spec(dataclasses.replace(hf_train_spec, name="meta-llama/Llama-3.2-1B"))
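
For orientation, a minimal sketch of how one of the specs registered above could be looked up at runtime. It assumes the get_train_spec helper that torchtitan.protocols.train_spec exposes alongside register_train_spec; the spec name and flavor mirror the [model].name and [model].flavor fields of the job configs further down.

# Sketch only (assumes get_train_spec from torchtitan.protocols.train_spec):
# resolve a registered spec by name, then pick one flavor's model args.
from torchtitan.protocols.train_spec import get_train_spec

spec = get_train_spec("meta-llama/Llama-3.2-1B")  # corresponds to [model].name
model_args = spec.model_args["medium"]            # corresponds to [model].flavor
print(spec.model_cls.__name__, model_args)        # LlamaForCausalLM, HFTransformerModelArgs(...)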
@@ -0,0 +1,62 @@
[job]
dump_folder = "./outputs"
description = "Llama 3 debugmodel baseline (torchtitan llama3 backend, no parallelism)"
print_args = false
use_for_integration_test = true

[profiling]
enable_profiling = false
save_traces_folder = "profile_trace"
profile_freq = 10
enable_memory_snapshot = false
save_memory_snapshot_folder = "memory_snapshot"

[metrics]
log_freq = 1
disable_color_printing = false
enable_tensorboard = false
save_tb_folder = "tb"
enable_wandb = false

[model]
name = "llama3"
flavor = "debugmodel"
tokenizer_path = "/fsx/ferdinandmom/ferdinand-hf/huggingface/torchtitan/tests/assets/tokenizer"

[optimizer]
name = "AdamW"
lr = 8e-4
eps = 1e-8

[lr_scheduler]
warmup_steps = 2
decay_ratio = 0.8
decay_type = "linear"
min_lr_factor = 0.0

[training]
local_batch_size = 8
seq_len = 2048
max_norm = 1.0
steps = 10
compile = false
dataset = "c4_test"
dataset_path = "/fsx/ferdinandmom/ferdinand-hf/huggingface/torchtitan/tests/assets/c4_test"

[parallelism]
data_parallel_replicate_degree = 1
data_parallel_shard_degree = 1
tensor_parallel_degree = 1
pipeline_parallel_degree = 1
context_parallel_degree = 1
expert_parallel_degree = 1

[checkpoint]
enable_checkpoint = false

[activation_checkpoint]
mode = "selective"
selective_ac_option = '2'

[validation]
enabled = false
@@ -0,0 +1,62 @@
[job]
dump_folder = "./outputs"
description = "Llama 3.2 1B debug training via the HF transformers backend (medium flavor, no parallelism)"
print_args = false
use_for_integration_test = true

[profiling]
enable_profiling = false
save_traces_folder = "profile_trace"
profile_freq = 10
enable_memory_snapshot = false
save_memory_snapshot_folder = "memory_snapshot"

[metrics]
log_freq = 1
disable_color_printing = false
enable_tensorboard = false
save_tb_folder = "tb"
enable_wandb = false

[model]
name = "meta-llama/Llama-3.2-1B"
flavor = "medium"
tokenizer_path = "/fsx/ferdinandmom/ferdinand-hf/huggingface/torchtitan/tests/assets/tokenizer"

[optimizer]
name = "AdamW"
lr = 8e-4
eps = 1e-8

[lr_scheduler]
warmup_steps = 2
decay_ratio = 0.8
decay_type = "linear"
min_lr_factor = 0.0

[training]
local_batch_size = 8
seq_len = 2048
max_norm = 1.0
steps = 10
compile = false
dataset = "c4_test"
dataset_path = "/fsx/ferdinandmom/ferdinand-hf/huggingface/torchtitan/tests/assets/c4_test"

[parallelism]
data_parallel_replicate_degree = 1
data_parallel_shard_degree = 1
tensor_parallel_degree = 1
pipeline_parallel_degree = 1
context_parallel_degree = 1
expert_parallel_degree = 1

[checkpoint]
enable_checkpoint = false

[activation_checkpoint]
mode = "selective"
selective_ac_option = '2'

[validation]
enabled = false
@@ -0,0 +1,65 @@
# FSDP-only configuration for a 2-GPU setup.
# Model is sharded across GPUs.

[job]
dump_folder = "./outputs"
description = "Llama 3 debug training with FSDP on 2 GPUs"
print_args = false
use_for_integration_test = true

[profiling]
enable_profiling = false
save_traces_folder = "profile_trace"
profile_freq = 10
enable_memory_snapshot = false
save_memory_snapshot_folder = "memory_snapshot"

[metrics]
log_freq = 1
disable_color_printing = false
enable_tensorboard = false
save_tb_folder = "tb"
enable_wandb = false

[model]
name = "llama3"
flavor = "debugmodel"
tokenizer_path = "/fsx/ferdinandmom/ferdinand-hf/huggingface/torchtitan/tests/assets/tokenizer"

[optimizer]
name = "AdamW"
lr = 8e-4
eps = 1e-8

[lr_scheduler]
warmup_steps = 2
decay_ratio = 0.8
decay_type = "linear"
min_lr_factor = 0.0

[training]
local_batch_size = 8
seq_len = 2048
max_norm = 1.0
steps = 10
compile = false
dataset = "c4_test"
dataset_path = "/fsx/ferdinandmom/ferdinand-hf/huggingface/torchtitan/tests/assets/c4_test"

[parallelism]
data_parallel_replicate_degree = 1
data_parallel_shard_degree = 2
tensor_parallel_degree = 1
pipeline_parallel_degree = 1
context_parallel_degree = 1
expert_parallel_degree = 1

[checkpoint]
enable_checkpoint = false

[activation_checkpoint]
mode = "selective"
selective_ac_option = '2'

[validation]
enabled = false
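
The comment at the top of this config notes that the model is sharded across the 2 GPUs. As a rough illustration only (not the PR's parallelize_hf_transformers), the snippet below shows the FSDP2 fully_shard pattern that data_parallel_shard_degree = 2 maps onto; it assumes a recent PyTorch where fully_shard is exported from torch.distributed.fsdp, and is meant to be launched with torchrun --nproc_per_node=2.

# Minimal FSDP2 sketch, illustration only; run under torchrun --nproc_per_node=2.
import os

import torch
from torch.distributed.device_mesh import init_device_mesh
from torch.distributed.fsdp import fully_shard  # FSDP2 API in recent PyTorch

torch.cuda.set_device(int(os.environ["LOCAL_RANK"]))
mesh = init_device_mesh("cuda", (2,), mesh_dim_names=("dp_shard",))

# A toy stand-in for the transformer: shard each block, then the root module,
# following the per-block wrapping pattern torchtitan's FSDP application uses.
model = torch.nn.Sequential(torch.nn.Linear(1024, 1024), torch.nn.Linear(1024, 1024)).cuda()
for block in model:
    fully_shard(block, mesh=mesh)  # each rank keeps half of this block's parameters
fully_shard(model, mesh=mesh)

x = torch.randn(8, 1024, device="cuda")
model(x).sum().backward()          # gradients are reduce-scattered across the 2 ranks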