Skip to content

Commit

Permalink
update CL doc (deepspeedai#1506)
Browse files Browse the repository at this point in the history
* update CL doc

* doc fix
  • Loading branch information
conglongli authored Oct 30, 2021
1 parent 163f568 commit 7f5a3ad
Show file tree
Hide file tree
Showing 3 changed files with 42 additions and 13 deletions.
4 changes: 2 additions & 2 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@ information [here](https://innovation.microsoft.com/en-us/exploring-ai-at-scale)
* Read more on how to [train large models with DeepSpeed](https://www.deepspeed.ai/tutorials/large-models-w-deepspeed/)
* [2021/08/18] [DeepSpeed powers 8x larger MoE model training with high performance](https://www.microsoft.com/en-us/research/blog/deepspeed-powers-8x-larger-moe-model-training-with-high-performance/)
* [Mixture of Experts (MoE) tutorial](https://www.deepspeed.ai/tutorials/mixture-of-experts/).
* [2021/08/16] [Curriculum learning: a regularization method for stable and 2.6x faster GPT-2 pre-training with 8x/4x larger batch size/learning rate](https://www.deepspeed.ai/tutorials/curriculum-learning/)
* [2021/08/16] [Curriculum learning: a regularization method for stable and 3.3x faster GPT-2 pre-training with 8x/4x larger batch size/learning rate](https://www.deepspeed.ai/tutorials/curriculum-learning/)
* [2021/05/24] [DeepSpeed: Accelerating large-scale model inference and training via system optimizations and compression](https://www.microsoft.com/en-us/research/blog/deepspeed-accelerating-large-scale-model-inference-and-training-via-system-optimizations-and-compression/)
* [2021/04/20] [1-bit LAMB: up to 4.6x less communication and 2.8x faster training, together with LAMB's convergence speed at large batch sizes](https://www.deepspeed.ai/tutorials/onebit-lamb/)
* [2021/04/19] [ZeRO-Infinity unlocks unprecedented model scale for deep learning training](https://www.microsoft.com/en-us/research/blog/zero-infinity-and-deepspeed-unlocking-unprecedented-model-scale-for-deep-learning-training/)
Expand Down Expand Up @@ -153,7 +153,7 @@ overview](https://www.deepspeed.ai/features/) for descriptions and usage.
* [Simplified Data Loader](https://www.deepspeed.ai/features/#simplified-data-loader)
* [Curriculum Learning](https://www.deepspeed.ai/tutorials/curriculum-learning/)
* A curriculum learning-based data pipeline that presents easier or simpler examples earlier during training
* Stable and 2.6x faster GPT-2 pre-training with 8x/4x larger batch size/learning rate while maintaining token-wise convergence speed
* Stable and 3.3x faster GPT-2 pre-training with 8x/4x larger batch size/learning rate while maintaining token-wise convergence speed
* Complementary to many other DeepSpeed features
* [Performance Analysis and Debugging](https://www.deepspeed.ai/features/#performance-analysis-and-debugging)
* [Mixture of Experts (MoE)](https://www.deepspeed.ai/tutorials/mixture-of-experts/)
Expand Down
47 changes: 38 additions & 9 deletions docs/_tutorials/curriculum-learning.md
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,11 @@
title: "Curriculum Learning: A Regularization Method for Efficient and Stable Billion-Scale GPT Model Pre-Training"
---

In this tutorial, we introduce DeepSpeed's curriculum learning-based data pipeline, which presents easier or simpler examples earlier during training. By enabling stable training with 8x/4x larger batch size/learning rate (whereas the baseline approach struggles with training divergence), we observe that curriculum learning (based on sequence length) provides stable and 2.6x faster GPT-2 pre-training (tested on 117M and 1.5B parameters), together with better token-wise convergence speed and zero-shot WikiText-103/LAMBADA evaluation results. In addition, since curriculum learning only affect the data pipeline, its benefit is complementary to many DeepSpeed features and other system optimization techniques. For example, curriculum learning is compatible with DeepSpeed's [ZeRO Redundancy Optimizer](/tutorials/zero/) and [ZeRO-Offload](/tutorials/zero-offload/), and Megatron-LM's Model Parallelism.
**Note:**
This tutorial is updated on 10/29/2021. Changes include: 1) A more detailed tuning strategy. 2) Pipeline parallelism support. 3) Token-based learning rate decay. 4) A new GPT-2 example at [github.com/microsoft/Megatron-DeepSpeed](https://github.com/microsoft/Megatron-DeepSpeed). See details below.
{: .notice--info}

In this tutorial, we introduce DeepSpeed's curriculum learning-based data pipeline, which presents easier or simpler examples earlier during training. By enabling stable training with 8x/4x larger batch size/learning rate (whereas the baseline approach struggles with training divergence), we observe that curriculum learning (based on sequence length) provides stable and 3.3x faster GPT-2 pre-training (tested on 117M and 1.5B parameters), together with better token-wise convergence speed and zero-shot WikiText-103/LAMBADA evaluation results. In addition, since curriculum learning only affect the data pipeline, its benefit is complementary to many DeepSpeed features and other system optimization techniques. For example, curriculum learning is compatible with DeepSpeed's [ZeRO Redundancy Optimizer](/tutorials/zero/), [ZeRO-Offload](/tutorials/zero-offload/), and [3D Parallelism](/tutorials/pipeline/).

To illustrate the benefits and usage of curriculum learning, we use the Megatron-LM GPT-2 pre-training task as example. For more details on this task, please refer to the [tutorial](/tutorials/megatron/). In addition, we also have a [paper](https://arxiv.org/abs/2108.06084) which provides the technical details including implementation and evaluations.

Expand Down Expand Up @@ -45,9 +49,9 @@ Curriculum learning can be used by setting the DeepSpeed configuration as the fo
```
To support curriculum learning, we add the following new parameters:

`curriculum_type` is the type of curriculum difficulty metric. Currently we support the `seqlen` metric which presents shorter sequences earlier in training. We implement this type of curriculum learning by passing an additional `curriculum_seqlen` argument to the model's forward function, and performing training data sequence truncation before the actual forward pass. We will describe how to implement this in the Megatron-LM GPT-2 pre-training example below.
`curriculum_type` is the type of curriculum difficulty metric. Currently we support the `seqlen` metric which presents shorter sequences earlier in training. We implement this type of curriculum learning by performing training data sequence truncation before the actual forward pass. We will describe how to implement this in the Megatron-LM GPT-2 pre-training example below.

`min_difficulty` is the starting difficulty level. For `seqlen` metric it means we start with sequence length as `min_difficulty`. We observe that lower `min_difficulty` usually provides better convergence speedup but with two caveats: First, sometimes (especially for large models) starting with too small difficulty level may lead to severe overfitting (e.g., training loss divergence or validation loss keeps jumping up and down) thus hurt the convergence. In such case it is recommended to try increasing the `min_difficulty`. Second, for `seqlen` metric it is recommended to set `min_difficulty` as multiple of 8 (for FP16 data) or 16 (for INT8 data) in order to enable [NVIDIA GPU's Tensor Core acceleration](https://developer.nvidia.com/blog/optimizing-gpu-performance-tensor-cores/).
`min_difficulty` is the starting difficulty level. For `seqlen` metric it means we start with sequence length as `min_difficulty`. We observe that lower `min_difficulty` usually provides better stability/convergence speed benefit but with two caveats: First, sometimes (especially for large models) starting with too small difficulty level may lead to severe overfitting (e.g., training loss divergence or validation perplexity fluctuations) thus hurt the convergence. Second, for `seqlen` metric it is recommended to set `min_difficulty` as multiple of 8 (for FP16 data) or 16 (for INT8 data) in order to enable [NVIDIA GPU's Tensor Core acceleration](https://developer.nvidia.com/blog/optimizing-gpu-performance-tensor-cores/). To tune this hyperparameter for `seqlen` metric, we recommend to start with `min_difficulty` at 8 (million-scale models) or 64 (billion-scale models), and then increase it if you observe divergence or validation perplexity fluctuations at the very beginning.

`max_difficulty` is the ending difficulty level. For `seqlen` metric it should be set as the full sequence length (e.g., 1024 for Megatron-LM GPT-2 pre-training).

Expand All @@ -65,9 +69,9 @@ For `fixed_linear` schedule there are two configurations:
}
```

The `total_curriculum_step` is the total number of steps for the curriculum learning. For `fixed_linear` schedule the difficulty level will linearly increase from `min_difficulty` to `max_difficulty` during the `total_curriculum_step` duration. This configuration needs to be tuned for each training task. We observe that too small and too large `total_curriculum_step` are both suboptimal: with too small `total_curriculum_step` curriculum learning might not be able to provide enough training stability benefit so the training might still diverge; with too large `total_curriculum_step` the model may overfit too much during curriculum learning on the easier/simpler training data thus hurt the overall convergence. We recommend to first set `total_curriculum_step` as 20% to 40% of the total training steps (note that if you increase the batch size for the curriculum learning-based training, you also need to reduce the total training steps correspondingly), then increase the `total_curriculum_step` if the training is not stable, or reduce the `total_curriculum_step` to test if convergence improves.
The `total_curriculum_step` is the total number of steps for the curriculum learning. For `fixed_linear` schedule the difficulty level will linearly increase from `min_difficulty` to `max_difficulty` during the `total_curriculum_step` duration. This configuration needs to be tuned for each training task. We observe that too small and too large `total_curriculum_step` are both suboptimal: with too small `total_curriculum_step` curriculum learning might not be able to provide enough training stability benefit so the training might still diverge; with too large `total_curriculum_step` the model may overfit too much during curriculum learning on the easier/simpler training data thus hurt the overall convergence. To tune this hyperparameter, we recommend performing a binary search to find the largest `total_curriculum_step` that does not have significant validation perplexity fluctuation during the first few multiples of LR warmup steps. The underlying rationale can be found in our [paper](https://arxiv.org/abs/2108.06084) Appendix A.1.

The `difficulty_step` configuration ensures that at anytime the difficulty level must be multiple of `difficulty_step`. We usually set it as 8 (for FP16 data) or 16 (for INT8 data) to enable [NVIDIA GPU's Tensor Core acceleration](https://developer.nvidia.com/blog/optimizing-gpu-performance-tensor-cores/). If this is unrelated to your training experiment, you can set it as 1.
The `difficulty_step` configuration ensures that at anytime the difficulty level must be multiple of `difficulty_step`. A smaller value is preferable since it gives more smooth curriculum and better stability. We usually set it as 8 (for FP16 data) or 16 (for INT8 data) to enable [NVIDIA GPU's Tensor Core acceleration](https://developer.nvidia.com/blog/optimizing-gpu-performance-tensor-cores/). If this is unrelated to your hardware, you can set it as 1.

### 1.2 fixed_root schedule
For `fixed_root` schedule there are three configurations:
Expand Down Expand Up @@ -98,10 +102,35 @@ The `difficulty` is a list of difficulty levels to be used during schedule. The

## 2. Curriculum learning for Megatron-LM GPT-2 pre-training

We provide example scripts under [DeepSpeedExamples/Megatron-LM-v1.1.5-ZeRO3/curriculum_learning/](https://github.com/microsoft/DeepSpeedExamples/tree/master/Megatron-LM-v1.1.5-ZeRO3/curriculum_learning). The `ds_train.sh` is the training script to run and it also includes the actual configurations we used for the experiments in our [paper](https://arxiv.org/abs/2108.06084).
**Watch out!**
After the update on 10/29/2021, now there are two curriculum learning examples for Megatron-LM GPT-2 pre-training. Both of them have some unique features and limitations. See details below.
{: .notice--warning}

We provide two curriculum learning examples for Megatron-LM GPT-2 pre-training:

The first one is at [Megatron-DeepSpeed/tree/main/examples/curriculum_learning](https://github.com/microsoft/Megatron-DeepSpeed/tree/main/examples/curriculum_learning). This integration is based on a newer Megatron-LM fork, and only this curriculum learning example supports pipeline parallelism. However, currently (10/29/2021) we haven't verified ZeRO-2 and ZeRO-3 on this fork. Overall, we highly recommend you to use this example if your model does not require ZeRO-2/3.

The second one is at [DeepSpeedExamples/Megatron-LM-v1.1.5-ZeRO3/curriculum_learning/](https://github.com/microsoft/DeepSpeedExamples/tree/master/Megatron-LM-v1.1.5-ZeRO3/curriculum_learning). This integration is based on an older Megatron-LM hard copy that we will eventually deprecate and this curriculum learning example does not support pipeline parallelism. We recommend you to ONLY use this example if your model requires ZeRO-2/3.

Besides the additional DeepSpeed curriculum learning json configurations described above, there are some other necessary changes on the user side to integrate curriculum learning:

### 2.1 Training data truncation

To enable seqlen-based curriculum learning, we need to add the functionality of training data truncation based on the given curriculum sequence length. For the case without pipeline parallelism, it is necessary to add a `curriculum_seqlen` argument in the model's forward pass and use it to perform training data sequence length truncation. For Megatron-LM GPT-2 pre-training, we implement this in `forward()` in [megatron/model/gpt2_model.py](https://github.com/microsoft/DeepSpeedExamples/blob/master/Megatron-LM-v1.1.5-ZeRO3/megatron/model/gpt2_model.py) and in `forward_step()` in [pretrain_gpt2.py](https://github.com/microsoft/DeepSpeedExamples/blob/master/Megatron-LM-v1.1.5-ZeRO3/pretrain_gpt2.py).

For the case with pipeline parallelism, due to DeepSpeed engine limitations we cannot inject the `curriculum_seqlen` argument in the forward pass. Instead, we create a duplicate of `deepspeed.runtime.data_pipeline.curriculum_scheduler` on the user side, and use it to retrieve the `curriculum_seqlen`. This implementation can be found in [megatron/training.py](https://github.com/microsoft/Megatron-DeepSpeed/blob/main/megatron/training.py).

### 2.2 Disable batch size warmup (--rampup-batch-size)
In our [paper](https://arxiv.org/abs/2108.06084) section 5.4 we demonstrate that curriculum learning (seqlen-based) provides much better training stability than the batch size warmup technique introduced by Open AI GPT-3. So when using CL you need to remove the `--rampup-batch-size` config in your training script. It's not recommended to use both CL and batch size warmup, because both of them will reduce the number of tokens in a batch. Another related change you might want is to increase your micro batch size, since without batch size warmup your batch size will be fixed now.

### 2.3 Token-based training termination

Because curriculum learning changes length of each sequence/sample during training, it is very hard/impossible to use number of steps/samples to terminate the training exactly at the desired number of tokens. Thus we add a `--train-tokens` config as an alternative accurate token-based termination. We recommend increase your original `--train-samples` or `--train-iters` to a large enough number (e.g., 3X of what you used for baseline), and set `--train-tokens` at the exact desired number of training tokens.

### 2.4 Token-based LR decay

Besides the additional DeepSpeed configurations, there are some other necessary changes on the user side to enable curriculum learning. First, it is necessary to add a `curriculum_seqlen` argument in the model's forward pass and use it to perform training data sequence length truncation. For Megatron-LM GPT-2 pre-training, we implement this in `forward()` in [DeepSpeedExamples/Megatron-LM-v1.1.5-ZeRO3/megatron/model/gpt2_model.py](https://github.com/microsoft/DeepSpeedExamples/blob/master/Megatron-LM-v1.1.5-ZeRO3/megatron/model/gpt2_model.py) and in `forward_step()` in [DeepSpeedExamples/Megatron-LM-v1.1.5-ZeRO3/pretrain_gpt2.py](https://github.com/microsoft/DeepSpeedExamples/blob/master/Megatron-LM-v1.1.5-ZeRO3/pretrain_gpt2.py).
Again because curriculum learning changes the number of tokens per batch, in our [paper](https://arxiv.org/abs/2108.06084) Appendix A.2 we show that it is also necessary to change the LR decay to token-based (to avoid decaying LR too fast). Thus we add a `--lr-decay-tokens` which will be the number of LR decay tokens. If previously you were using `--lr-decay-samples`, you can calculate your `--lr-decay-tokens` simply by multiplying the former by full seqlen (e.g. 1K for GPT-2 and 2K for GPT-3). If previously you were using `--lr-decay-iters`, you can calculate your `--lr-decay-tokens` by multiplying the former by full seqlen and the global batch size. Then you need to replace `--lr-decay-samples` or `--lr-decay-iters` with `--lr-decay-tokens` in your script.

Second, since there will be less tokens per step during curriculum learning, for curriculum-learning based training it requires more steps in order to reach the same number of training tokens as baseline. Thus in Megatron-LM we add a `--train-tokens` argument to terminate the training based on number of tokens. Then we usually set a long enough `--train-iters` (e.g., two times of baseline's total training step), and set the `--train-tokens` the same for baseline and curriculum-learning based training.
### 2.5 LR warmup adjustment

Third, again due to the less tokens per step during curriculum learning, we find that for curriculum-learning based training it is beneficial to increase the learning rate decay steps (otherwise the curriculum learning case will have faster token-wise learning rate decay than baseline). For `fixed_linear` schedule because we start from very short sequence length, the total number of tokens during the curriculum learning is roughly halved. Thus we usually just add half of `fixed_linear` schedule's `total_curriculum_step` to the Megatron-LM's `--lr-decay-iters`.
For LR warmup we don't change it to token-based, because doing so for curriculum learning means slowing down the LR warmup, which is both unnecessary and harmful. However, to avoid too fast warmup you may need to adjust your `--lr-warmup-samples` or `--lr-warmup-iters` from non-CL cases for various reasons (e.g., if you used `--rampup-batch-size` in non-CL case, for CL we don't use it so the number of samples per batch will be different at beginning). Assuming you want to use `X` tokens to warmup the LR (for OpenAI GPT-3 this was 375M tokens), then for curriculum learning case you shall set `--lr-warmup-samples` as `X` divided by the `min_difficulty`, or set `--lr-warmup-iters` as `X` divided by `min_difficulty * --global-batch-size`. This is a rough estimation based on that curriculum learning starts from seqlen `min_difficulty` and it won't increase too much during LR warmup.
Loading

0 comments on commit 7f5a3ad

Please sign in to comment.