Make the script more friendly to outside users and add docs. (#528)
* Add mini script

* update documentation

* quick docs change
vwxyzjn authored Jan 28, 2025
1 parent bc536c8 commit b12816f
Showing 12 changed files with 195 additions and 76 deletions.
82 changes: 6 additions & 76 deletions README.md
@@ -105,19 +105,19 @@ You can run the following commands for getting started:

```bash
# quick debugging run using 1 GPU
sh scripts/finetune_with_accelerate_config.sh 1 configs/train_configs/sft/mini.yaml
bash scripts/train/finetune/mini.sh
# train an 8B tulu3 model using 8 GPUs
sh scripts/finetune_with_accelerate_config.sh 8 configs/train_configs/tulu3/tulu3_sft.yaml
bash scripts/train/finetune/tulu_finetune_mix.sh
```


### Preference Tuning

```bash
# quick debugging run using 1 GPU
sh scripts/dpo_train_with_accelerate_config.sh 1 configs/train_configs/dpo/mini.yaml
bash scripts/train/dpo/mini.sh
# train an 8B tulu3 model using 8 GPUs
sh scripts/dpo_train_with_accelerate_config.sh 8 configs/train_configs/tulu3/tulu3_dpo_8b.yaml
bash scripts/train/dpo/tulu_preference_mix.sh
```


@@ -127,79 +127,9 @@ sh scripts/dpo_train_with_accelerate_config.sh 8 configs/train_configs/tulu3/tul
# quick debugging run using 2 GPUs (1 for inference, 1 for training)
# here we are using `HuggingFaceTB/SmolLM2-360M-Instruct`; it probably won't
# work well, but it's easy to test run and print stuff.
python open_instruct/ppo_vllm_thread_ray_gtrl.py \
--dataset_mixer '{"ai2-adapt-dev/gsm8k_math_ifeval_ground_truth_mixed": 1.0}' \
--dataset_train_splits train \
--dataset_eval_mixer '{"ai2-adapt-dev/gsm8k_math_ground_truth": 1.0}' \
--dataset_eval_splits test \
--max_token_length 2048 \
--max_prompt_token_length 2048 \
--response_length 2048 \
--model_name_or_path HuggingFaceTB/SmolLM2-360M-Instruct \
--reward_model_path HuggingFaceTB/SmolLM2-360M-Instruct \
--non_stop_penalty \
--stop_token eos \
--temperature 1.0 \
--ground_truths_key ground_truth \
--chat_template tulu \
--sft_messages_key messages \
--learning_rate 3e-7 \
--total_episodes 10000 \
--penalty_reward_value -10.0 \
--deepspeed_stage 3 \
--per_device_train_batch_size 2 \
--local_rollout_forward_batch_size 2 \
--local_mini_batch_size 32 \
--local_rollout_batch_size 32 \
--num_epochs 1 \
--actor_num_gpus_per_node 1 \
--vllm_tensor_parallel_size 1 \
--beta 0.05 \
--apply_verifiable_reward true \
--output_dir output/rlvr_1b \
--seed 3 \
--num_evals 3 \
--save_freq 100 \
--reward_model_multiplier 0.0 \
--gradient_checkpointing \
--with_tracking

bash scripts/train/rlvr/mini.sh
# train an 8B tulu3 model using 8 GPUs (1 for inference, 7 for training)
python open_instruct/ppo_vllm_thread_ray_gtrl.py \
--dataset_mixer '{"ai2-adapt-dev/gsm8k_math_ifeval_ground_truth_mixed": 1.0}' \
--dataset_train_splits train \
--dataset_eval_mixer '{"ai2-adapt-dev/gsm8k_math_ground_truth": 1.0}' \
--dataset_eval_splits test \
--max_token_length 2048 \
--max_prompt_token_length 2048 \
--response_length 2048 \
--model_name_or_path allenai/Llama-3.1-Tulu-3-8B-DPO \
--reward_model_path allenai/Llama-3.1-Tulu-3-8B-RM \
--non_stop_penalty \
--stop_token eos \
--temperature 1.0 \
--ground_truths_key ground_truth \
--chat_template tulu \
--sft_messages_key messages \
--learning_rate 3e-7 \
--total_episodes 10000000 \
--penalty_reward_value -10.0 \
--deepspeed_stage 3 \
--per_device_train_batch_size 2 \
--local_rollout_forward_batch_size 2 \
--local_mini_batch_size 32 \
--local_rollout_batch_size 32 \
--actor_num_gpus_per_node 7 \
--vllm_tensor_parallel_size 1 \
--beta 0.05 \
--apply_verifiable_reward true \
--output_dir output/rlvr_8b \
--seed 3 \
--num_evals 3 \
--save_freq 100 \
--reward_model_multiplier 0.0 \
--gradient_checkpointing \
--with_tracking
bash scripts/train/rlvr/tulu_rlvr.sh
```


1 change: 1 addition & 0 deletions open_instruct/dpo_tune_cache.py
@@ -993,6 +993,7 @@ def load_model():
if (
args.try_auto_save_to_beaker
and accelerator.is_main_process
and is_beaker_job()
and len(beaker_config.beaker_dataset_id_urls) > 0
and args.output_dir.rstrip("/") != "/output"
):
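The same `is_beaker_job()` check is added to all five training entry points touched by this commit. As a minimal sketch of how the condition reads after the change, using only the names visible in the hunk above — the body of the `if` is not shown in the diff, so the save call below is a hypothetical placeholder:

```python
# Sketch only: the auto-save-to-Beaker guard after this commit.
# `upload_to_beaker` is a hypothetical placeholder; the actual body of the
# `if` block lies outside the diff hunk.
if (
    args.try_auto_save_to_beaker
    and accelerator.is_main_process
    and is_beaker_job()  # new in this commit: skip the Beaker path when not running as a Beaker job
    and len(beaker_config.beaker_dataset_id_urls) > 0
    and args.output_dir.rstrip("/") != "/output"
):
    upload_to_beaker(args.output_dir)  # hypothetical placeholder call
```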
1 change: 1 addition & 0 deletions open_instruct/finetune.py
@@ -857,6 +857,7 @@ def main(args: FlatArguments):
if (
args.try_auto_save_to_beaker
and accelerator.is_main_process
and is_beaker_job()
and len(beaker_config.beaker_dataset_id_urls) > 0
and args.output_dir.rstrip("/") != "/output"
):
1 change: 1 addition & 0 deletions open_instruct/grpo_vllm_thread_ray_gtrl.py
@@ -1292,6 +1292,7 @@ def vllm_generate(
if (
args.try_auto_save_to_beaker
and self.rank == 0
and is_beaker_job()
and len(self.beaker_config.beaker_dataset_id_urls) > 0
and args.output_dir.rstrip("/") != "/output"
):
1 change: 1 addition & 0 deletions open_instruct/ppo_vllm_thread_ray.py
@@ -1307,6 +1307,7 @@ def vllm_generate(
if (
args.try_auto_save_to_beaker
and self.rank == 0
and is_beaker_job()
and len(self.beaker_config.beaker_dataset_id_urls) > 0
and args.output_dir.rstrip("/") != "/output"
):
1 change: 1 addition & 0 deletions open_instruct/ppo_vllm_thread_ray_gtrl.py
@@ -1399,6 +1399,7 @@ def vllm_generate(
if (
args.try_auto_save_to_beaker
and self.rank == 0
and is_beaker_job()
and len(self.beaker_config.beaker_dataset_id_urls) > 0
and args.output_dir.rstrip("/") != "/output"
):
23 changes: 23 additions & 0 deletions scripts/train/dpo/mini.sh
@@ -0,0 +1,23 @@
accelerate launch \
--mixed_precision bf16 \
--num_processes 1 \
open_instruct/dpo_tune_cache.py \
--model_name_or_path EleutherAI/pythia-14m \
--tokenizer_name EleutherAI/pythia-14m \
--use_slow_tokenizer False \
--use_flash_attn False \
--max_seq_length 1024 \
--per_device_train_batch_size 1 \
--gradient_accumulation_steps 4 \
--learning_rate 5e-07 \
--lr_scheduler_type linear \
--warmup_ratio 0.1 \
--weight_decay 0.0 \
--num_train_epochs 3 \
--output_dir output/dpo_pythia_14m/ \
--with_tracking \
--report_to wandb \
--logging_steps 1 \
--dataset_mixer_list allenai/tulu-3-wildchat-reused-on-policy-8b 100 \
--add_bos \
--seed 123
33 changes: 33 additions & 0 deletions scripts/train/dpo/tulu_preference_mix.sh
@@ -0,0 +1,33 @@
# per https://arxiv.org/pdf/2411.15124,
# the effective batch size is 32 using 8 GPUs, since we simulate it with:
# gradient_accumulation_steps 4 * per_device_train_batch_size 1 * num_processes 8 = 32
accelerate launch \
--mixed_precision bf16 \
--num_processes 8 \
--use_deepspeed \
--deepspeed_config_file configs/ds_configs/stage3_no_offloading_accelerate.conf \
--deepspeed_multinode_launcher standard \
open_instruct/dpo_tune_cache.py \
--model_name_or_path allenai/Llama-3.1-Tulu-3-8B-SFT \
--use_flash_attn \
--tokenizer_name allenai/Llama-3.1-Tulu-3-8B-SFT \
--max_seq_length 2048 \
--per_device_train_batch_size 1 \
--gradient_accumulation_steps 4 \
--learning_rate 5e-07 \
--lr_scheduler_type linear \
--warmup_ratio 0.1 \
--weight_decay 0.0 \
--num_train_epochs 1 \
--output_dir /output \
--with_tracking \
--report_to wandb \
--logging_steps 1 \
--model_revision main \
--gradient_checkpointing \
--dataset_mixer_list allenai/llama-3.1-tulu-3-8b-preference-mixture 1.0 \
--use_slow_tokenizer \
--use_lora False \
--dpo_loss_type dpo_norm \
--dpo_beta 5 \
--exp_name tulu-3-8b-dpo
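The comment at the top of this script spells out the effective-batch-size arithmetic. As a quick sanity check, here is a minimal standalone sketch of the same computation (not part of the repository), covering both this DPO mix and the SFT mix script later in this commit:

```python
# Minimal sketch (standalone, not part of the repo): the effective batch size of an
# `accelerate launch` run is grad-accum steps * per-device batch size * num processes.
def effective_batch_size(grad_accum_steps: int, per_device_batch: int, num_processes: int) -> int:
    return grad_accum_steps * per_device_batch * num_processes

assert effective_batch_size(4, 1, 8) == 32    # scripts/train/dpo/tulu_preference_mix.sh
assert effective_batch_size(16, 1, 8) == 128  # scripts/train/finetune/tulu_finetune_mix.sh
```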
25 changes: 25 additions & 0 deletions scripts/train/finetune/mini.sh
@@ -0,0 +1,25 @@
accelerate launch \
--mixed_precision bf16 \
--num_processes 1 \
open_instruct/finetune.py \
--model_name_or_path EleutherAI/pythia-14m \
--tokenizer_name EleutherAI/pythia-14m \
--use_slow_tokenizer False \
--use_flash_attn False \
--max_seq_length 1024 \
--per_device_train_batch_size 1 \
--gradient_accumulation_steps 4 \
--learning_rate 5e-06 \
--lr_scheduler_type linear \
--warmup_ratio 0.03 \
--weight_decay 0.0 \
--num_train_epochs 2 \
--output_dir output/sft_pythia_14m/ \
--with_tracking \
--report_to wandb \
--logging_steps 1 \
--reduce_loss sum \
--model_revision main \
--dataset_mixer_list allenai/tulu-3-sft-personas-algebra 100 \
--add_bos \
--seed 123
32 changes: 32 additions & 0 deletions scripts/train/finetune/tulu_finetune_mix.sh
@@ -0,0 +1,32 @@
# per https://arxiv.org/pdf/2411.15124,
# the effective batch size is 128 using 8 GPUs, since we simulate it with:
# gradient_accumulation_steps 16 * per_device_train_batch_size 1 * num_processes 8 = 128
accelerate launch \
--mixed_precision bf16 \
--num_processes 8 \
--use_deepspeed \
--deepspeed_config_file configs/ds_configs/stage3_no_offloading_accelerate.conf \
--deepspeed_multinode_launcher standard \
open_instruct/finetune.py \
--model_name_or_path meta-llama/Llama-3.1-8B \
--tokenizer_name meta-llama/Llama-3.1-8B \
--use_slow_tokenizer \
--use_flash_attn \
--max_seq_length 4096 \
--per_device_train_batch_size 1 \
--gradient_accumulation_steps 16 \
--learning_rate 5e-06 \
--lr_scheduler_type linear \
--warmup_ratio 0.03 \
--weight_decay 0.0 \
--num_train_epochs 2 \
--output_dir output/sft_8b \
--with_tracking \
--report_to wandb \
--logging_steps 1 \
--reduce_loss sum \
--model_revision main \
--dataset_mixer_list allenai/tulu-3-sft-mixture 100 \
--dataset_mix_dir output/sft_8b \
--exp_name tulu-3-8b-sft \
--seed 123
36 changes: 36 additions & 0 deletions scripts/train/rlvr/mini.sh
@@ -0,0 +1,36 @@
python open_instruct/ppo_vllm_thread_ray_gtrl.py \
--dataset_mixer '{"ai2-adapt-dev/gsm8k_math_ifeval_ground_truth_mixed": 1.0}' \
--dataset_train_splits train \
--dataset_eval_mixer '{"ai2-adapt-dev/gsm8k_math_ground_truth": 1.0}' \
--dataset_eval_splits test \
--max_token_length 2048 \
--max_prompt_token_length 2048 \
--response_length 2048 \
--model_name_or_path HuggingFaceTB/SmolLM2-360M-Instruct \
--reward_model_path HuggingFaceTB/SmolLM2-360M-Instruct \
--non_stop_penalty \
--stop_token eos \
--temperature 1.0 \
--ground_truths_key ground_truth \
--chat_template tulu \
--sft_messages_key messages \
--learning_rate 3e-7 \
--total_episodes 1000 \
--penalty_reward_value -10.0 \
--deepspeed_stage 3 \
--per_device_train_batch_size 2 \
--local_rollout_forward_batch_size 2 \
--local_mini_batch_size 32 \
--local_rollout_batch_size 32 \
--num_epochs 1 \
--actor_num_gpus_per_node 1 \
--vllm_tensor_parallel_size 1 \
--beta 0.05 \
--apply_verifiable_reward true \
--output_dir output/rlvr_1b \
--seed 3 \
--num_evals 3 \
--save_freq 100 \
--reward_model_multiplier 0.0 \
--gradient_checkpointing \
--with_tracking
35 changes: 35 additions & 0 deletions scripts/train/rlvr/tulu_rlvr.sh
@@ -0,0 +1,35 @@
python open_instruct/ppo_vllm_thread_ray_gtrl.py \
--dataset_mixer '{"ai2-adapt-dev/gsm8k_math_ifeval_ground_truth_mixed": 1.0}' \
--dataset_train_splits train \
--dataset_eval_mixer '{"ai2-adapt-dev/gsm8k_math_ground_truth": 1.0}' \
--dataset_eval_splits test \
--max_token_length 2048 \
--max_prompt_token_length 2048 \
--response_length 2048 \
--model_name_or_path allenai/Llama-3.1-Tulu-3-8B-DPO \
--reward_model_path allenai/Llama-3.1-Tulu-3-8B-RM \
--non_stop_penalty \
--stop_token eos \
--temperature 1.0 \
--ground_truths_key ground_truth \
--chat_template tulu \
--sft_messages_key messages \
--learning_rate 3e-7 \
--total_episodes 10000000 \
--penalty_reward_value -10.0 \
--deepspeed_stage 3 \
--per_device_train_batch_size 2 \
--local_rollout_forward_batch_size 2 \
--local_mini_batch_size 32 \
--local_rollout_batch_size 32 \
--actor_num_gpus_per_node 7 \
--vllm_tensor_parallel_size 1 \
--beta 0.05 \
--apply_verifiable_reward true \
--output_dir output/rlvr_8b \
--seed 3 \
--num_evals 3 \
--save_freq 100 \
--reward_model_multiplier 0.0 \
--gradient_checkpointing \
--with_tracking

