
Conversation


@3outeille 3outeille commented Sep 6, 2025

Context

This PR enables:

  • Llama-like HF models to work with 4D parallelism: FSDP, CP, TP, PP (and the combinations between them). The following models were tested:
    • meta-llama/Llama-3.2-1B
    • microsoft/phi-2
    • Qwen/Qwen2.5-7B
    • mistralai/Mistral-7B-v0.1
    • ByteDance-Seed/Seed-Coder-8B-Instruct
    • Qwen/Qwen3-4B-Instruct-2507
    • arcee-ai/AFM-4.5B
    • ibm-granite/granite-3b-code-base-2k
    • baidu/ERNIE-4.5-0.3B-Base-PT
    • kyutai/helium-1-preview-2b
    • allenai/OLMo-7B-hf
    • mistralai/Ministral-8B-Instruct-2410
  • Patching HF model weight initialisation. Without this, the loss and grad_norm start very high (a minimal sketch of the idea follows this list)
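A minimal sketch of the init-patching idea, assuming a simple per-module re-initialisation; the patch_init_weights_ helper below is illustrative only, not the PR's actual code:

```python
# Hypothetical illustration of the init-patching idea: re-initialise Linear and
# Embedding weights with a small std so loss/grad_norm start at sane values.
import torch.nn as nn

def patch_init_weights_(model: nn.Module, std: float = 0.02) -> None:
    for module in model.modules():
        if isinstance(module, nn.Linear):
            nn.init.normal_(module.weight, mean=0.0, std=std)
            if module.bias is not None:
                nn.init.zeros_(module.bias)
        elif isinstance(module, nn.Embedding):
            nn.init.normal_(module.weight, mean=0.0, std=std)
```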

Usage

  • Config: torchtitan/torchtitan/experiments/transformers_backend/configs/qwen3_fsdp2_tp2_pp2.toml
...
[model]
- name = "llama3"
+ name = "Qwen/Qwen3-4B-Instruct-2507" 
flavor = "debugmodel"
hf_assets_path = "./tests/assets/tokenizer"
...
  • Train: LOG_RANK=7 ./torchtitan/torchtitan/experiments/transformers_backend/run_train.sh

Testing methodology

  • Following the converging.md guidelines, I am comparing the baseline FSDP=2 vs FSDP=2 & <other //-ism>.
  • More precisely, the test_hf_integration.py script generates the following results layout (a sketch of how the per-step diff could be extracted follows the tree):
    results/
        |_ meta-llama
            |_ Llama-3.2-1B
                |_ debugmodel/
                    |_ seed_checkpoint/
                        |_ config.toml
                        |_ seed.slurm
                        |_ step-0/
                           |_ ....
                    |_ fsdp2_tp1_cp1_pp1/
                        |_ config.toml
                        |_ nd_parallelism.slurm
                        |_ nd_parallelism.log
                    |_ fsdp2_tp2_cp1_pp1/
                        |_ config.toml
                        |_ nd_parallelism.slurm
                        |_ nd_parallelism.log
                        |_ diff_baseline_vs_nd_parallelism.log
                    |_ fsdp2_tp1_cp1_pp2/
                        |_ config.toml
                        |_ nd_parallelism.slurm
                        |_ nd_parallelism.log
                        |_ diff_baseline_vs_nd_parallelism.log
                    |_ fsdp2_tp1_cp2_pp1/
                        |_ config.toml
                        |_ nd_parallelism.slurm
                        |_ nd_parallelism.log
                        |_ diff_baseline_vs_nd_parallelism.log
                    |_ fsdp2_tp1_cp2_pp2/
                        |_ config.toml
                        |_ nd_parallelism.slurm
                        |_ nd_parallelism.log
                        |_ diff_baseline_vs_nd_parallelism.log
                |_ full/
                ...
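For illustration, here is a minimal sketch of the per-step loss/grad_norm comparison behind diff_baseline_vs_nd_parallelism.log. The log-line regex and CLI below are assumptions about the metrics format, not the actual implementation in test_hf_integration.py:

```python
# Hypothetical sketch: compare per-step loss/grad_norm between two training logs,
# in the spirit of diff_baseline_vs_nd_parallelism.log.
import re
import sys

# Assumed log format, e.g. "step: 10  loss: 8.1234  ...  grad_norm: 1.2345"
PATTERN = re.compile(r"step:\s*(\d+).*?loss:\s*([\d.eE+-]+).*?grad_norm:\s*([\d.eE+-]+)")

def parse(path):
    metrics = {}
    with open(path) as f:
        for line in f:
            m = PATTERN.search(line)
            if m:
                metrics[int(m.group(1))] = (float(m.group(2)), float(m.group(3)))
    return metrics

def diff(baseline_log, nd_log):
    base, nd = parse(baseline_log), parse(nd_log)
    for step in sorted(base.keys() & nd.keys()):
        if base[step] != nd[step]:
            print(f"step {step}: baseline={base[step]} vs nd={nd[step]}")

if __name__ == "__main__":
    diff(sys.argv[1], sys.argv[2])
```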
  • Here is the grid-search script used to test the HF modeling (a listing of the parallelism combinations it covers follows the script):
#!/usr/bin/bash
model_names=(
     "meta-llama/Llama-3.2-1B"
     "microsoft/phi-2" 
     "Qwen/Qwen2.5-7B"
     "mistralai/Mistral-7B-v0.1"
     "ByteDance-Seed/Seed-Coder-8B-Instruct"
     "Qwen/Qwen3-4B-Instruct-2507" 
     "arcee-ai/AFM-4.5B" 
     "ibm-granite/granite-3b-code-base-2k" 
     "baidu/ERNIE-4.5-0.3B-Base-PT" 
     "kyutai/helium-1-preview-2b" 
     "allenai/OLMo-7B-hf"
     "mistralai/Ministral-8B-Instruct-2410" 
)

for model_name in "${model_names[@]}"; do
    rm -rf slurm_results/${model_name}

    python test_hf_integration.py create_configs --model_name "$model_name" --out_dir slurm_results --flavor debugmodel
    python test_hf_integration.py submit_jobs --inp_dir slurm_results/${model_name}/debugmodel/seed_checkpoint --qos high
    while [ ! -f slurm_results/${model_name}/debugmodel/seed_checkpoint/status.txt ] || [ "$(cat slurm_results/${model_name}/debugmodel/seed_checkpoint/status.txt)" != "completed" ]; do
        echo "Waiting for seed checkpoint from ${model_name} to complete ..."
        sleep 1
    done
    python test_hf_integration.py submit_jobs --inp_dir slurm_results/${model_name}/debugmodel --qos high
    echo "================"
done
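For readability, these are the per-model runs the grid produces, matching the directory names in the results tree above (illustrative listing only, not the actual create_configs logic):

```python
# Illustrative only: run names matching the results tree above.
PARALLELISM_COMBOS = [
    "fsdp2_tp1_cp1_pp1",  # baseline: FSDP=2 only
    "fsdp2_tp2_cp1_pp1",  # + TP=2
    "fsdp2_tp1_cp1_pp2",  # + PP=2
    "fsdp2_tp1_cp2_pp1",  # + CP=2
    "fsdp2_tp1_cp2_pp2",  # + CP=2 and PP=2
]
```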

Further tasks

  • MoE (handled in PR Add transformer backend (MoE) clean #3)
    • Missing build_optimizers_with_moe_load_balancing support for MoE
    • Missing TP/PP/EP support for MoE
  • When using the HF modeling, in the FSDP=2 vs FSDP=2 + PP=2 test the loss and grad_norm are not bitwise matching (though they still converge), while they do match with the Torchtitan modeling (issue tracked in Fix pp convergence to be bitwise #4)
  • Add convergence tests to CI using a tiny model + gloo backend (once PP is bitwise matching)
  • The HF modeling has lower MFU than the Torchtitan modeling
  • NOTE: set torch._dynamo.config.cache_size_limit = 128 to avoid graph recompilation when using torch.compile with activation checkpointing (a minimal snippet follows)
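A minimal snippet for the note above, assuming the setting is applied once before the model is built and compiled:

```python
# Raise the dynamo recompile cache limit so torch.compile does not hit the limit
# when combined with activation checkpointing (value taken from the note above).
import torch._dynamo.config as dynamo_config

dynamo_config.cache_size_limit = 128
```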

@3outeille 3outeille force-pushed the 3outeille/transformers_backend branch from 5ae8455 to fe691b8 on November 13, 2025 10:54
@3outeille

@wwwjn addressed all the issues mentioned. There is one last point I want to address (cf. here), which I think provides a better user experience

@3outeille 3outeille requested a review from tianyu-l November 14, 2025 10:23
@3outeille 3outeille force-pushed the 3outeille/transformers_backend branch from bcf5355 to c0c273c on November 19, 2025 11:25
tianyu-l pushed a commit to pytorch/torchtitan that referenced this pull request Nov 20, 2025
# Context
Reference PR: huggingface#1

This PR enables:
- Llama-like HF models to work with 4D parallelism: FSDP, CP, TP, PP
(and the combinations between them). The following models were tested:
  - `meta-llama/Llama-3.2-1B`
  - `microsoft/phi-2`
  - `Qwen/Qwen2.5-7B`
  - `mistralai/Mistral-7B-v0.1`
  - `ByteDance-Seed/Seed-Coder-8B-Instruct`
  - `Qwen/Qwen3-4B-Instruct-2507`
  - `arcee-ai/AFM-4.5B`
  - `ibm-granite/granite-3b-code-base-2k`
  - `baidu/ERNIE-4.5-0.3B-Base-PT`
  - `kyutai/helium-1-preview-2b`
  - `allenai/OLMo-7B-hf`
  - `mistralai/Ministral-8B-Instruct-2410`
- Patching HF model weight initialisation. Without this, the `loss` and `grad_norm` start very high

# Usage

- Requirements: `transformers==4.57.1`
- Config:
`torchtitan/torchtitan/experiments/transformers_backend/configs/qwen3.toml`
```diff
...
[model]
- name = "llama3"
+ name = "transformers_backend"
flavor = "debugmodel"
hf_assets_path = "./tests/assets/tokenizer"

+[hf_transformers]
+model = "Qwen/Qwen3-4B-Instruct-2507"
...
```
- Train: `LOG_RANK=7
CONFIG_FILE=<YOUR_PATH>/torchtitan/experiments/transformers_backend/configs/qwen3.toml
./run_train.sh
--job.custom_config_module=torchtitan.experiments.transformers_backend.job_config
--compile.enable`
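
A quick, optional sanity check before launching, using only the public transformers API (the printed attributes are standard HF config fields):

```python
# Optional pre-flight check: confirm the pinned transformers version and that the
# model id from [hf_transformers] resolves to a valid HF config.
import transformers
from transformers import AutoConfig

print("transformers", transformers.__version__)  # expected: 4.57.1
cfg = AutoConfig.from_pretrained("Qwen/Qwen3-4B-Instruct-2507")
print(type(cfg).__name__, cfg.hidden_size, cfg.num_hidden_layers)
```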

<img width="1334" height="453" alt="image"
src="https://github.com/user-attachments/assets/da459448-027b-4af9-8176-6a3e433a272c"
/>

# Testing methodology

<img width="2672" height="2018" alt="image"
src="https://github.com/user-attachments/assets/66d8689d-7ede-47e3-b389-d4fc1bdd70f7"
/>

- Following the
[converging.md](https://github.com/pytorch/torchtitan/blob/main/docs/converging.md)
guidelines, I am comparing the baseline `FSDP=2` vs `FSDP=2 & <other
//-ism>`
- More precisely, `test_hf_integration.py` generates the following results layout:

```bash
    results/
        |_ meta-llama
            |_ Llama-3.2-1B
                |_ debugmodel/
                    |_ seed_checkpoint/
                        |_ config.toml
                        |_ seed.slurm
                        |_ step-0/
                           |_ ....
                    |_ fsdp2_tp1_cp1_pp1/
                        |_ config.toml
                        |_ nd_parallelism.slurm
                        |_ nd_parallelism.log
                    |_ fsdp2_tp2_cp1_pp1/
                        |_ config.toml
                        |_ nd_parallelism.slurm
                        |_ nd_parallelism.log
                        |_ diff_baseline_vs_nd_parallelism.log
                    |_ fsdp2_tp1_cp1_pp2/
                        |_ config.toml
                        |_ nd_parallelism.slurm
                        |_ nd_parallelism.log
                        |_ diff_baseline_vs_nd_parallelism.log
                    |_ fsdp2_tp1_cp2_pp1/
                        |_ config.toml
                        |_ nd_parallelism.slurm
                        |_ nd_parallelism.log
                        |_ diff_baseline_vs_nd_parallelism.log
                    |_ fsdp2_tp1_cp2_pp2/
                        |_ config.toml
                        |_ nd_parallelism.slurm
                        |_ nd_parallelism.log
                        |_ diff_baseline_vs_nd_parallelism.log
                |_ full/
                ...
```
- Here is the grid-search script used to test the HF modeling:
```shell
#!/usr/bin/bash
model_names=(
     "meta-llama/Llama-3.2-1B"
     "microsoft/phi-2" 
     "Qwen/Qwen2.5-7B"
     "mistralai/Mistral-7B-v0.1"
     "ByteDance-Seed/Seed-Coder-8B-Instruct"
     "Qwen/Qwen3-4B-Instruct-2507" 
     "arcee-ai/AFM-4.5B" 
     "ibm-granite/granite-3b-code-base-2k" 
     "baidu/ERNIE-4.5-0.3B-Base-PT" 
     "kyutai/helium-1-preview-2b" 
     "allenai/OLMo-7B-hf"
     "mistralai/Ministral-8B-Instruct-2410" 
)

for model_name in "${model_names[@]}"; do
    rm -rf slurm_results/${model_name}

    python test_hf_integration.py create_configs --model_name "$model_name" --out_dir slurm_results --flavor debugmodel
    python test_hf_integration.py submit_jobs --inp_dir slurm_results/${model_name}/debugmodel/seed_checkpoint --qos high
    while [ ! -f slurm_results/${model_name}/debugmodel/seed_checkpoint/status.txt ] || [ "$(cat slurm_results/${model_name}/debugmodel/seed_checkpoint/status.txt)" != "completed" ]; do
        echo "Waiting for seed checkpoint from ${model_name} to complete ..."
        sleep 1
    done
    python test_hf_integration.py submit_jobs --inp_dir slurm_results/${model_name}/debugmodel --qos high
    echo "================"
done
```

# Further tasks

- MoE (handled in PR huggingface#3)
	- Missing `build_optimizers_with_moe_load_balancing` support for MoE
	- Missing TP/PP/EP support for MoE
- When using the HF modeling, in the `FSDP=2 vs FSDP=2 + PP=2` test the `loss` and `grad_norm` are not bitwise matching (though they still converge), while they do match with the Torchtitan modeling (issue tracked in huggingface#4)
- Add convergence tests to CI using a tiny model + gloo backend (once PP is bitwise matching)
- The HF modeling has lower MFU than the Torchtitan modeling
- NOTE: set `torch._dynamo.config.cache_size_limit = 128` to avoid graph recompilation when using `torch.compile` with activation checkpointing