diff --git a/torch-neuronx/training/dp_bert_hf_pretrain/dp_bert_large_hf_pretrain_hdf5.py b/torch-neuronx/training/dp_bert_hf_pretrain/dp_bert_large_hf_pretrain_hdf5.py
index 74cf7ef..4e46b0e 100644
--- a/torch-neuronx/training/dp_bert_hf_pretrain/dp_bert_large_hf_pretrain_hdf5.py
+++ b/torch-neuronx/training/dp_bert_hf_pretrain/dp_bert_large_hf_pretrain_hdf5.py
@@ -69,6 +69,13 @@
 import inspect
 import requests
 import gc
+
+
+from torch_xla import __version__
+version_tuple = tuple(map(int, __version__.split(".")[:2]))
+is_pt21_plus = version_tuple >= (2,1)
+is_pt20 = version_tuple == (2,0)
+
 os.environ["NEURON_CC_FLAGS"] = os.environ.get('NEURON_CC_FLAGS', '') + " --model-type=transformer"
 
 # For PT autocast.
@@ -221,9 +228,12 @@ def sequence_length(self) -> int:
 
 
 def get_model(flags):
-    base_model = BertForPreTraining.from_pretrained('bert-large-uncased', force_download=True)
+    base_model = BertForPreTraining.from_pretrained('bert-large-uncased', use_safetensors=False)
     # medium BERT size L12_A12_H768. Large BERT L24_A16_H1024 causes OOM on GPU V100
     my_config = copy.deepcopy(base_model.config)
+    if flags.disable_dropout or flags.snapshot_step_list:
+        my_config.hidden_dropout_prob = 0.0
+        my_config.attention_probs_dropout_prob = 0.0
     if flags.debug:
         my_config.num_hidden_layers = 1
         my_config.num_attention_heads = 2
@@ -231,6 +241,16 @@ def get_model(flags):
     my_model = BertForPreTraining(my_config)
     return my_model
 
+def extract_mfu(num_layers, hidden_size, sequence_len, batch_size, average_throughput, world_size):
+    flops_per_seq = 12 * num_layers * hidden_size * sequence_len * (6 * hidden_size + sequence_len)
+    tflops_per_seq = flops_per_seq / 10**12
+    tflops_per_sec_per_worker = tflops_per_seq * average_throughput/world_size
+    if '--auto-cast=none' in os.getenv('NEURON_CC_FLAGS', default=''):
+        hw_tflops_per_worker = 760/32
+    else:
+        hw_tflops_per_worker = 3040/32
+    return tflops_per_sec_per_worker/hw_tflops_per_worker * 100
+
 # fix NVidia checkpoint param names to match HF
 def fix_ckpt_params(state_dict):
     keys = [k for k in state_dict.keys() if 'dense_act' in k]
@@ -335,6 +355,32 @@ def train_bert_hdf5(flags):
     }
 
     def train_loop_fn(model, optimizer, train_loader, epoch, global_step, training_ustep, running_loss):
+
+        # Add snapshot callback here in order to track total_steps
+        total_steps = 0
+        capture_steps = []
+        # Turn off snapshotting for all ranks/steps by default, and select ranks/steps in specified lists
+        def callback(name, addressable_device_index, execution_count):
+            return ''
+        # Enable snapshotting for ranks/steps specified in lists
+        if flags.snapshot_step_list:
+            if flags.snapshot_rank_list == "all":
+                capture_ranks = []  # empty list means all ranks
+            else:
+                capture_ranks = [int(i) for i in flags.snapshot_rank_list.split(",")]
+            if capture_ranks == [] or xm.get_ordinal() in capture_ranks:
+                capture_steps = [int(i) for i in flags.snapshot_step_list.split(",")]
+                if is_pt21_plus:
+                    print(f"Enabling snapshotting for rank {xm.get_ordinal()} and steps {capture_steps}")
+                def callback(name, addressable_device_index, execution_count):
+                    if total_steps in capture_steps:
+                        return 'inputs outputs'
+                    else:
+                        return ''
+        if is_pt21_plus:
+            import libneuronxla
+            libneuronxla.register_hlo_snapshot_callback(callback)
+
         max_grad_norm = 1.0
         running_loss_reduced_detached = torch.zeros(1, device=device)
         for i, data in enumerate(train_loader):
@@ -504,6 +550,7 @@ def _print_logs(running_loss_reduced_detached, total_norm):
             else:
                 chkpt_file = os.path.join(flags.output_dir,
"ckpt_{}.pt".format(global_step)) files_info = [f] + files + print('Checkpointing...', flush=True) model_to_save = model.module if hasattr(model, 'module') else model # unwrap model if needed (DDP) if flags.minimal_ckpt: @@ -526,32 +573,37 @@ def _print_logs(running_loss_reduced_detached, total_norm): if os.path.isfile(old_file): print('Keeping only {} checkpoints. Deleting {}'.format(flags.num_ckpts_to_keep, old_file)) os.remove(old_file) + if global_step >= flags.steps_this_run: xm.rendezvous("before_throughput_check") # avoid multi-node hang due to throughput threshold assert by root worker if is_root and not extract_graphs_only: + compile_time = 0.0 + compile_time_file="compile_time.txt" + if os.path.exists(compile_time_file): + with open(compile_time_file, "r") as f: + compile_time = float(f.readline()) # record aggregate & final statistics in the metrics file additional_data = { "Epoch": epoch, "Global step": global_step, "Microstep": training_ustep } average_throughput = round(sum(logger.throughputs)/len(logger.throughputs), 4) + model_flops_utilization = extract_mfu(len(model.bert.encoder.layer), model.bert.config.hidden_size, train_dataloader.dataset.sequence_length, flags.batch_size, average_throughput, world_size) + metric_data = [ + Metric("FinalLoss", final_loss, units="", additional=additional_data), + Metric("TimeToTrain", round(time_diff/60, 4), units="minutes", additional=additional_data), + Metric("Compile Time", compile_time, units="sec", additional=additional_data), + Metric("PeakThroughput", max(logger.throughputs), units="seq/s", additional=additional_data), + Metric("MFU", model_flops_utilization, units="%", additional=additional_data), + ] if(flags.expected_average_throughput > 0): derived_expected_throughput = (0.95*flags.expected_average_throughput) - metric_data = [ - Metric("FinalLoss", final_loss, units="", additional=additional_data), - Metric("TimeToTrain", round(time_diff/60, 4), units="minutes", additional=additional_data), - Metric("AverageThroughput", average_throughput, units="seq/s", expected=flags.expected_average_throughput, derived=(0.95*flags.expected_average_throughput) ,additional=additional_data), - Metric("PeakThroughput", max(logger.throughputs), units="seq/s", additional=additional_data) - ] + metric_data.append( + Metric("AverageThroughput", average_throughput, units="seq/s", expected=flags.expected_average_throughput, derived=derived_expected_throughput, additional=additional_data)) post_metrics(metric_data, parameters=parameters) - assert( average_throughput >= derived_expected_throughput), "Average throughput :{} is below derived expected threshold: {}".format(average_throughput, derived_expected_throughput) + assert(average_throughput >= derived_expected_throughput), "Average throughput :{} is below derived expected threshold: {}".format(average_throughput, derived_expected_throughput) else: - - metric_data = [ - Metric("FinalLoss", final_loss, units="", additional=additional_data), - Metric("TimeToTrain", round(time_diff/60, 4), units="minutes", additional=additional_data), - Metric("AverageThroughput", average_throughput, units="seq/s", additional=additional_data), - Metric("PeakThroughput", max(logger.throughputs), units="seq/s", additional=additional_data) - ] + metric_data.append( + Metric("AverageThroughput", average_throughput, units="seq/s", additional=additional_data)) post_metrics(metric_data, parameters=parameters) return del train_device_loader @@ -611,6 +663,11 @@ def _mp_fn(index, flags): parser.add_argument('--phase2', 
diff --git a/torch-neuronx/training/dp_bert_hf_pretrain/requirements.txt b/torch-neuronx/training/dp_bert_hf_pretrain/requirements.txt
index c8159ee..adc6711 100644
--- a/torch-neuronx/training/dp_bert_hf_pretrain/requirements.txt
+++ b/torch-neuronx/training/dp_bert_hf_pretrain/requirements.txt
@@ -1,12 +1,12 @@
 graphviz
-tensorboard==2.6
-transformers==4.26.0
+tensorboard==2.14
+transformers==4.44.0
 evaluate
 pillow
 pytest
 accelerate
-datasets >= 1.8.0
-sentencepiece != 0.1.92
+datasets==2.19.1
+sentencepiece==0.2.0
 h5py
-requests
-huggingface-hub<0.23
+requests==2.31.0
+huggingface-hub==0.24.5
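
With the version ranges (datasets >= 1.8.0, sentencepiece != 0.1.92, huggingface-hub<0.23) replaced by exact pins, a minimal sketch for confirming what actually resolved in the active environment (assumes the packages are already installed):

    # Minimal sketch: report whether installed versions match the new pins.
    from importlib.metadata import version

    pins = {"tensorboard": "2.14", "transformers": "4.44.0", "datasets": "2.19.1",
            "sentencepiece": "0.2.0", "requests": "2.31.0", "huggingface-hub": "0.24.5"}
    for pkg, want in pins.items():
        got = version(pkg)
        print(f"{pkg}: pinned {want}, installed {got}", "(OK)" if got.startswith(want) else "(MISMATCH)")
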
-z "$SLURM_NTASKS" ]; then CACHE_DIR=$HOME/neuron_cache/bert/`hostname` export NEURON_CC_FLAGS="--cache_dir=$CACHE_DIR" fi - export TRANSFORMERS_CACHE=$HOME/hf_cache/`hostname`/hub + export HF_HOME=/tmp/hf_cache/ + mkdir -p $HF_HOME + if [ -e $HOME/.cache/huggingface ]; then + rsync -av $HOME/.cache/huggingface/ $HF_HOME + fi # HF ver > 4.22: Move cache ahead of time to prevent multiple workers moving at the same time python -c "import transformers.utils as utils; utils.move_cache()" fi diff --git a/torch-neuronx/training/dp_bert_hf_pretrain/run_dp_bert_large_hf_pretrain_bf16_s128_lamb.sh b/torch-neuronx/training/dp_bert_hf_pretrain/run_dp_bert_large_hf_pretrain_bf16_s128_lamb.sh index be0d0d2..175b76f 100755 --- a/torch-neuronx/training/dp_bert_hf_pretrain/run_dp_bert_large_hf_pretrain_bf16_s128_lamb.sh +++ b/torch-neuronx/training/dp_bert_hf_pretrain/run_dp_bert_large_hf_pretrain_bf16_s128_lamb.sh @@ -53,7 +53,11 @@ if [ ! -z "$SLURM_NTASKS" ]; then CACHE_DIR=$HOME/neuron_cache/bert/`hostname` export NEURON_CC_FLAGS="--cache_dir=$CACHE_DIR" fi - export TRANSFORMERS_CACHE=$HOME/hf_cache/`hostname`/hub + export HF_HOME=/tmp/hf_cache/ + mkdir -p $HF_HOME + if [ -e $HOME/.cache/huggingface ]; then + rsync -av $HOME/.cache/huggingface/ $HF_HOME + fi # HF ver > 4.22: Move cache ahead of time to prevent multiple workers moving at the same time python -c "import transformers.utils as utils; utils.move_cache()" fi diff --git a/torch-neuronx/training/dp_bert_hf_pretrain/run_dp_bert_large_hf_pretrain_bf16_s512_lamb_phase2.sh b/torch-neuronx/training/dp_bert_hf_pretrain/run_dp_bert_large_hf_pretrain_bf16_s512_lamb_phase2.sh index e324ca8..ce493ae 100755 --- a/torch-neuronx/training/dp_bert_hf_pretrain/run_dp_bert_large_hf_pretrain_bf16_s512_lamb_phase2.sh +++ b/torch-neuronx/training/dp_bert_hf_pretrain/run_dp_bert_large_hf_pretrain_bf16_s512_lamb_phase2.sh @@ -60,7 +60,11 @@ if [ ! -z "$SLURM_NTASKS" ]; then CACHE_DIR=$HOME/neuron_cache/bert/`hostname` export NEURON_CC_FLAGS="--cache_dir=$CACHE_DIR" fi - export TRANSFORMERS_CACHE=$HOME/hf_cache/`hostname`/hub + export HF_HOME=/tmp/hf_cache/ + mkdir -p $HF_HOME + if [ -e $HOME/.cache/huggingface ]; then + rsync -av $HOME/.cache/huggingface/ $HF_HOME + fi # HF ver > 4.22: Move cache ahead of time to prevent multiple workers moving at the same time python -c "import transformers.utils as utils; utils.move_cache()" fi diff --git a/torch-neuronx/training/dp_bert_hf_pretrain/run_dp_bert_large_hf_pretrain_bf16_s512_phase2.sh b/torch-neuronx/training/dp_bert_hf_pretrain/run_dp_bert_large_hf_pretrain_bf16_s512_phase2.sh index 3af3eb2..d046089 100755 --- a/torch-neuronx/training/dp_bert_hf_pretrain/run_dp_bert_large_hf_pretrain_bf16_s512_phase2.sh +++ b/torch-neuronx/training/dp_bert_hf_pretrain/run_dp_bert_large_hf_pretrain_bf16_s512_phase2.sh @@ -60,7 +60,11 @@ if [ ! -z "$SLURM_NTASKS" ]; then CACHE_DIR=$HOME/neuron_cache/bert/`hostname` export NEURON_CC_FLAGS="--cache_dir=$CACHE_DIR" fi - export TRANSFORMERS_CACHE=$HOME/hf_cache/`hostname`/hub + export HF_HOME=/tmp/hf_cache/ + mkdir -p $HF_HOME + if [ -e $HOME/.cache/huggingface ]; then + rsync -av $HOME/.cache/huggingface/ $HF_HOME + fi # HF ver > 4.22: Move cache ahead of time to prevent multiple workers moving at the same time python -c "import transformers.utils as utils; utils.move_cache()" fi