forked from pytorch/torchtitan
Add transformers backend (Dense model only) #1
Status: Open. 3outeille wants to merge 133 commits into main from 3outeille/transformers_backend.
Commits (133, all authored by 3outeille):

7488385 create transformer_backend folder with debug run
39a3b34 add hf config
ea7c594 can now register train spec for hf model
5f0adf5 can now switch with different flavors using HF Llama modeling
7c3795c it is now working up to apply_ac
3fb2bf8 now working up to init_weights
25daeca fix mapping when convert_to_hf_config + add breaking test to ensure p…
3e67f2c define own apply_ac for transformer backend instead of reusing llama3
8c5c0ae HF model without any parallelism now train (but grad_norm is high)
4ae9560 a bit cleaner way to get passed args
9be95f9 now same number of params + same attention backend but noticed highe…
bf91447 fix seed and deterministic
4c2fc0b fix torch deterministic for HF modeling that was producing Nans
9bffa38 HF model now numerically stable compared to TT (given a fixed attent…
40d84cc handling the is_hf_initialized flag in patch
bd3f332 refactor HF transformer model args
249be92 wrapper model class to avoid transformers to be explicit in train.py
e2d4ada add better testing script with reference log for later sanity check
4b498a9 no need to fill passed args
eb403d5 can now handle multiple HF modeling
a0d67a7 handle pref logits accessing inside HF model wrapper
ea05552 isolate HF patch for llama in another file
adefa2c find hacky way to pass HF model.name through CLI
a235863 more granularity of logging when doing parameter breakdown
fc43dc8 add __repr__ to HFTransformerModelArgs for better debugging logs
23ae378 HF deepseek v3 is now training
2573be4 refactor to make it clear which args comes from which parts
46ae0a3 fix refactor and simplify things
b33d575 hacky way to switch flavors for now
007f005 hf deepseek train while matching same param counts as tt deepseek
dd2b04c wtf deepseek q_proj weight init differ ???
9abdae3 deepseek now has same weight init in HF & TT. Reasons was rng_state w…
f9e90bc adapt mfu to handle moe
ba5d6d1 beginning parallelism by setting tests
338a250 better compare_distributed_run test
36a5673 add seed + deterministic to compare_distributed_run
ed892a2 better extract and compare metrics
1c1452f refactor to introduce slurm
5e4911f error handling with subprocess
4891a47 FSDP for llama in 1D works
9e260a0 better formatting of compare_distributed_run + display min/max grad_n…
a604bee make FSDP work in a cleaner way (mapping instead of renaming)
0b38d0d Improve logging in compare_distributed_run
025a86f PP for llama in 1D works
590737f simplify PP logic by flattening the named_children hierarchy. This wi…
1a9af68 TP now works in 1D
e6b9ff5 add test filtering in compare distributed run
a4cb8c3 dont generate EP config if model is not a MoE
12c0c47 disable torch.utils.deterministic.fill_uninitialized_memory for Moe d…
13edc66 CP is now supported
52250fb some cleaning
c523ede cleaner way to make create_causal_mask = None
f9f5c66 uniformize llama and moe args passing
5a875b6 cleaning code
e4d963c fix same global_batch_size across training + fix float32 for test (ev…
957cc4a refactor compare_distributed_run to make it slurm compatible
a317c53 breaking test
d2f80a2 refactor test
6454e40 fix running job to slurm
b99a4d2 finally have a better testing xp with slurm
218f400 now everything works (1D/2D/3D/4D). need to fix correctness with PP
bb080ad fix and uniformize weight init of llama-like model + various fix
3168f9e support moe init and fix with moe layer (TP for lora layers)
a9a65b7 begin TP + EP with MoE model
b4a1b88 cleaning
5f1075b add small example scripts
81f1855 Merge branch 'main' into 3outeille/transformers_backend
c35ccfc fix all the merge issues
d5ce2e9 get rid of hf patches files and put it in hf_transformer_args
8d46723 remove eos_id + refactor Optional[int] to comply with torchtitan conv…
087f841 move torch.utils.deterministic.fill_uninitialized_memory = False to u…
937c68d remove test_template for base_config instead
4f2b357 separate args &model + dont extract loss metrics -1.0 when double PP …
154289d use recent refactoring for flops computation for dense and moe model
1b2cfd7 fix tie_embedding
0f2c51e remove pad_token_id=None
4c8b4b7 make it clearer about args
c61271e remove local testing scripts
a848545 fix linting
9488a16 create CI jobs to guard
5be438b Merge branch 'main' into 3outeille/transformers_backend
e8a1757 update the way we register_train_spec
141c377 relative path for qwen3_fsdp2_tp2_pp2.toml
a67e971 dont use os.environ, use debugmodel or debugmodel_moe
060befe refactor args to make it clearer
3425b12 add README
7b0ee5d add requirements.txt
3e2222c fix linting
70c348d fix bug related to training with different seq_len than max_seq_len
af0a1cb decouple MoE logic to another PR
980a92b update experiments README
06b6f24 update README to confirm torch.compile support
a70c4c4 custom job_config
42884cd remove unecessary change in train_spec
4fa0874 rename file to comply with torchtitan style
8ffa7f4 reuse ac form torchtitan
ff21c2b reuse ddp from torchtitan
0a43a8a reuse compile from torchtitan llama3
8026bc7 reuse compile from torchtitan
cd4042f update parallelize with main
0700bdb remove moe ep tp for now
767f71d fix SequenceParallel for q and k norm
7f71f88 job_config.training will always have seq_len
7e63a82 fix loading weights in PP by using Module Dict
04fb8eb clean reference qwen config
0d80f62 error out if no layer_idx
09f0c94 reuse pipeline from torchtitan
78d26ff use c4 test for integration_tests
5243795 fix ci
84af768 Merge branch 'main' of github.com:huggingface/torchtitan into 3outeil…
fe691b8 fix linting
5d5ce2b fix head dims in flops counting
6ace9f4 propose an alternative to passing name
97cd6fe fix linting
5f1695f bump transformers version from 4.55.4 to 4.57.1
2d2b612 change qwen3 config name
a2ea2ef reuse fsdp from llama3. Moe will be handle in another PR
47fb2ea clean logging
20308d3 move TitanDenseModelArgs to args
019f2cc clean
fc93b4f fix integration tests
f9e8e11 rename integration test file
83b0437 update README
fb978dd revert accidental changes linting
71ff098 typo in naming
663a415 refactor
3dbe6fa revert the way we select HF modeling in config
9be95da Revert "reuse pipeline from torchtitan"
c0c273c pass deterministic.fill_uninitialized_memory to HF model
4c50a00 fix linting
5b8d38c fix integration tests
57bb8dd fix minor stuff
1bbb3a8 Merge branch 'main' into 3outeille/transformers_backend
Files changed:
New file (63 additions), registering the HF-backed train specs:

```python
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import dataclasses

from torchtitan.components.loss import build_cross_entropy_loss
from torchtitan.components.lr_scheduler import build_lr_schedulers
from torchtitan.components.optimizer import build_optimizers
from torchtitan.datasets.hf_datasets import build_hf_dataloader
from torchtitan.components.tokenizer import build_hf_tokenizer

from torchtitan.models.llama3 import pipeline_llama
from torchtitan.protocols.train_spec import register_train_spec, TrainSpec

from .infra.parallelize_hf_transformers import parallelize_hf_transformers
from .model.hf_transformers_args import HFTransformerModelArgs

from transformers.models.llama.modeling_llama import LlamaForCausalLM


__all__ = [
    "HFTransformerModelArgs",
    "LlamaForCausalLM",  # TODO(3outeille): later use AutoModelForCausalLM
    "hf_transformers_configs",
]


flavors = {
    "debug": HFTransformerModelArgs(
        dim=1,
        n_layers=6,
        n_heads=16,
        rope_theta=500000,
    ),
    "medium": HFTransformerModelArgs(
        dim=40,
        n_layers=24,
        n_heads=32,
        rope_theta=500000,
    ),
    "full": HFTransformerModelArgs(),
}


hf_train_spec = TrainSpec(
    name="hf_auto_model",
    model_cls=LlamaForCausalLM,
    model_args=flavors,
    parallelize_fn=parallelize_hf_transformers,
    pipelining_fn=pipeline_llama,
    build_optimizers_fn=build_optimizers,
    build_lr_schedulers_fn=build_lr_schedulers,
    build_dataloader_fn=build_hf_dataloader,
    build_tokenizer_fn=build_hf_tokenizer,
    build_loss_fn=build_cross_entropy_loss,
)


# Register multiple train_specs under the same name
register_train_spec(hf_train_spec)
register_train_spec(dataclasses.replace(hf_train_spec, name="meta-llama/Llama-3.2-3B"))
register_train_spec(dataclasses.replace(hf_train_spec, name="meta-llama/Llama-3.2-1B"))
```
torchtitan/experiments/transformers_backend/configs/debug_1_gpu.toml (new file, 62 additions):

```toml
[job]
dump_folder = "./outputs"
description = "Llama 3 debug training with FSDP on 2 GPUs"
print_args = false
use_for_integration_test = true

[profiling]
enable_profiling = false
save_traces_folder = "profile_trace"
profile_freq = 10
enable_memory_snapshot = false
save_memory_snapshot_folder = "memory_snapshot"

[metrics]
log_freq = 1
disable_color_printing = false
enable_tensorboard = false
save_tb_folder = "tb"
enable_wandb = false

[model]
name = "llama3"
flavor = "debugmodel"
tokenizer_path = "/fsx/ferdinandmom/ferdinand-hf/huggingface/torchtitan/tests/assets/tokenizer"

[optimizer]
name = "AdamW"
lr = 8e-4
eps = 1e-8

[lr_scheduler]
warmup_steps = 2
decay_ratio = 0.8
decay_type = "linear"
min_lr_factor = 0.0

[training]
local_batch_size = 8
seq_len = 2048
max_norm = 1.0
steps = 10
compile = false
dataset = "c4_test"
dataset_path = "/fsx/ferdinandmom/ferdinand-hf/huggingface/torchtitan/tests/assets/c4_test"

[parallelism]
data_parallel_replicate_degree = 1
data_parallel_shard_degree = 1
tensor_parallel_degree = 1
pipeline_parallel_degree = 1
context_parallel_degree = 1
expert_parallel_degree = 1

[checkpoint]
enable_checkpoint = false

[activation_checkpoint]
mode = "selective"
selective_ac_option = '2'

[validation]
enabled = false
```
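A quick sanity check on these numbers (simple arithmetic, not code from the PR):

```python
# Tokens consumed by this debug run, with all parallelism degrees at 1
# (a single rank): tokens per step is just local_batch_size * seq_len.
local_batch_size, seq_len, steps = 8, 2048, 10

tokens_per_step = local_batch_size * seq_len  # 16_384
total_tokens = tokens_per_step * steps        # 163_840 over the 10-step run
print(tokens_per_step, total_tokens)
```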
torchtitan/experiments/transformers_backend/configs/debug_1_gpu_hf.toml (new file, 62 additions). Identical to debug_1_gpu.toml except for the [model] section: name = "meta-llama/Llama-3.2-1B" resolves to one of the specs registered above, and flavor = "medium" selects the HFTransformerModelArgs(dim=40, n_layers=24, n_heads=32) preset:

```toml
[job]
dump_folder = "./outputs"
description = "Llama 3 debug training with FSDP on 2 GPUs"
print_args = false
use_for_integration_test = true

[profiling]
enable_profiling = false
save_traces_folder = "profile_trace"
profile_freq = 10
enable_memory_snapshot = false
save_memory_snapshot_folder = "memory_snapshot"

[metrics]
log_freq = 1
disable_color_printing = false
enable_tensorboard = false
save_tb_folder = "tb"
enable_wandb = false

[model]
name = "meta-llama/Llama-3.2-1B"
flavor = "medium"
tokenizer_path = "/fsx/ferdinandmom/ferdinand-hf/huggingface/torchtitan/tests/assets/tokenizer"

[optimizer]
name = "AdamW"
lr = 8e-4
eps = 1e-8

[lr_scheduler]
warmup_steps = 2
decay_ratio = 0.8
decay_type = "linear"
min_lr_factor = 0.0

[training]
local_batch_size = 8
seq_len = 2048
max_norm = 1.0
steps = 10
compile = false
dataset = "c4_test"
dataset_path = "/fsx/ferdinandmom/ferdinand-hf/huggingface/torchtitan/tests/assets/c4_test"

[parallelism]
data_parallel_replicate_degree = 1
data_parallel_shard_degree = 1
tensor_parallel_degree = 1
pipeline_parallel_degree = 1
context_parallel_degree = 1
expert_parallel_degree = 1

[checkpoint]
enable_checkpoint = false

[activation_checkpoint]
mode = "selective"
selective_ac_option = '2'

[validation]
enabled = false
```
torchtitan/experiments/transformers_backend/configs/debug_fsdp_2_gpu.toml (new file, 65 additions):

```toml
# FSDP-only configuration for a 2-GPU setup.
# Model is sharded across GPUs.

[job]
dump_folder = "./outputs"
description = "Llama 3 debug training with FSDP on 2 GPUs"
print_args = false
use_for_integration_test = true

[profiling]
enable_profiling = false
save_traces_folder = "profile_trace"
profile_freq = 10
enable_memory_snapshot = false
save_memory_snapshot_folder = "memory_snapshot"

[metrics]
log_freq = 1
disable_color_printing = false
enable_tensorboard = false
save_tb_folder = "tb"
enable_wandb = false

[model]
name = "llama3"
flavor = "debugmodel"
tokenizer_path = "/fsx/ferdinandmom/ferdinand-hf/huggingface/torchtitan/tests/assets/tokenizer"

[optimizer]
name = "AdamW"
lr = 8e-4
eps = 1e-8

[lr_scheduler]
warmup_steps = 2
decay_ratio = 0.8
decay_type = "linear"
min_lr_factor = 0.0

[training]
local_batch_size = 8
seq_len = 2048
max_norm = 1.0
steps = 10
compile = false
dataset = "c4_test"
dataset_path = "/fsx/ferdinandmom/ferdinand-hf/huggingface/torchtitan/tests/assets/c4_test"

[parallelism]
data_parallel_replicate_degree = 1
data_parallel_shard_degree = 2
tensor_parallel_degree = 1
pipeline_parallel_degree = 1
context_parallel_degree = 1
expert_parallel_degree = 1

[checkpoint]
enable_checkpoint = false

[activation_checkpoint]
mode = "selective"
selective_ac_option = '2'

[validation]
enabled = false
```
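A minimal sketch (my illustration, not torchtitan's actual validation code) of the invariant these degrees must satisfy: their product has to match the number of GPUs the job is launched on. This assumes torchtitan's convention that expert parallelism reuses data-parallel ranks rather than adding an independent dimension, so expert_parallel_degree is left out of the product:

```python
# Minimal sketch: the parallelism degrees in the TOML above must multiply
# out to the launch world size (expert parallelism excluded, see lead-in).
import math

degrees = {
    "data_parallel_replicate_degree": 1,
    "data_parallel_shard_degree": 2,  # the FSDP axis in this config
    "tensor_parallel_degree": 1,
    "pipeline_parallel_degree": 1,
    "context_parallel_degree": 1,
}

world_size = math.prod(degrees.values())
assert world_size == 2, "debug_fsdp_2_gpu.toml expects a 2-GPU launch"
print(f"required GPUs: {world_size}")
```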