diff --git a/Experiments/NLP/distillation/README.md b/Experiments/NLP/distillation/README.md
new file mode 100644
index 0000000..b52ebd0
--- /dev/null
+++ b/Experiments/NLP/distillation/README.md
@@ -0,0 +1,163 @@
+# Distil*
+
+This folder contains the original code used to train Distil* as well as examples showcasing how to use DistilBERT, DistilRoBERTa and DistilGPT2.
+
+## What is Distil*
+
+Distil* is a class of compressed models that started with DistilBERT. DistilBERT stands for Distilled-BERT. DistilBERT is a small, fast, cheap and light Transformer model based on the BERT architecture. It has 40% fewer parameters than `bert-base-uncased` and runs 60% faster while preserving 97% of BERT's performance as measured on the GLUE language understanding benchmark. DistilBERT is trained using knowledge distillation, a technique that compresses a large model (the teacher) into a smaller model (the student). By distilling BERT, we obtain a smaller Transformer model that bears many similarities to the original BERT model while being lighter, smaller and faster to run. DistilBERT is thus an interesting option for putting large pre-trained Transformer models into production.
+
+We have applied the same method to other Transformer architectures and released the weights:
+- GPT2: on the [WikiText-103](https://blog.einstein.ai/the-wikitext-long-term-dependency-language-modeling-dataset/) benchmark, GPT2 reaches a test-set perplexity of 16.3, compared to 21.1 for **DistilGPT2** (after fine-tuning on the train set).
+- RoBERTa: **DistilRoBERTa** reaches 95% of `RoBERTa-base`'s performance on GLUE while being twice as fast and 35% smaller.
+- German BERT: **German DistilBERT** reaches 99% of `bert-base-german-dbmdz-cased`'s performance on German NER (CoNLL-2003).
+- Multilingual BERT: **DistilmBERT** reaches 92% of Multilingual BERT's performance on XNLI while being twice as fast and 25% smaller. The model supports the 104 languages listed [here](https://github.com/google-research/bert/blob/master/multilingual.md#list-of-languages).
+
+For more information on DistilBERT, please refer to our [NeurIPS workshop paper](https://arxiv.org/abs/1910.01108).
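+
+As a quick sanity check of the size claims above, you can compare the parameter counts of the student and the teacher directly (a minimal sketch, assuming `transformers` is installed and the checkpoints can be downloaded):
+
+```python
+from transformers import AutoModel
+
+teacher = AutoModel.from_pretrained("bert-base-uncased")
+student = AutoModel.from_pretrained("distilbert-base-uncased")
+
+# DistilBERT should report roughly 40% fewer parameters than its teacher.
+print(f"bert-base-uncased:       {teacher.num_parameters():,} parameters")
+print(f"distilbert-base-uncased: {student.num_parameters():,} parameters")
+```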
+
+Here are the results on the dev sets of GLUE:
+
+| Model | Macro-score | CoLA | MNLI | MRPC | QNLI | QQP | RTE | SST-2 | STS-B | WNLI |
+| :---: | :---: | :---: | :---: | :---: | :---: | :---: | :---: | :---: | :---: | :---: |
+| BERT-base-uncased | **79.5** | 56.3 | 84.7 | 88.6 | 91.8 | 89.6 | 69.3 | 92.7 | 89.0 | 53.5 |
+| DistilBERT-base-uncased | **77.0** | 51.3 | 82.1 | 87.5 | 89.2 | 88.5 | 59.9 | 91.3 | 86.9 | 56.3 |
+| BERT-base-cased | **78.2** | 58.2 | 83.9 | 87.8 | 91.0 | 89.2 | 66.1 | 91.7 | 89.2 | 46.5 |
+| DistilBERT-base-cased | **75.9** | 47.2 | 81.5 | 85.6 | 88.2 | 87.8 | 60.6 | 90.4 | 85.5 | 56.3 |
+| --- | --- | --- | --- | --- | --- | --- | --- | --- | --- | --- |
+| RoBERTa-base (reported) | **83.2**/**86.4**<sup>2</sup> | 63.6 | 87.6 | 90.2 | 92.8 | 91.9 | 78.7 | 94.8 | 91.2 | 57.7<sup>3</sup> |
+| DistilRoBERTa<sup>1</sup> | **79.0**/**82.3**<sup>2</sup> | 59.3 | 84.0 | 86.6 | 90.8 | 89.4 | 67.9 | 92.5 | 88.3 | 52.1 |
+
+<sup>1</sup> We did not use the MNLI checkpoint for fine-tuning but directly performed transfer learning on the pre-trained DistilRoBERTa.
+
+<sup>2</sup> Macro-score computed without WNLI.
+
+<sup>3</sup> We computed this score ourselves for completeness.
+
+Here are the results on the *test* sets for 6 of the languages available in XNLI. The results are computed in the zero-shot setting (the model is trained on the English portion and evaluated on the target-language portion):
+
+| Model | English | Spanish | Chinese | German | Arabic | Urdu |
+| :---: | :---: | :---: | :---: | :---: | :---: | :---: |
+| mBERT base cased (computed) | 82.1 | 74.6 | 69.1 | 72.3 | 66.4 | 58.5 |
+| mBERT base uncased (reported) | 81.4 | 74.3 | 63.8 | 70.5 | 62.1 | 58.3 |
+| DistilmBERT | 78.2 | 69.1 | 64.0 | 66.3 | 59.1 | 54.7 |
+
+## Setup
+
+This part of the library has only been tested with Python 3.6+. There are a few specific dependencies to install before launching a distillation; you can install them with `pip install -r requirements.txt`.
+
+**Important note:** The training scripts have been updated to support PyTorch v1.2.0 (there are breaking changes compared to v1.1.0).
+
+
+## How to use DistilBERT
+
+Transformers includes the following pre-trained Distil* models, covering English, German and a multilingual variant:
+
+- `distilbert-base-uncased`: DistilBERT English language model pretrained on the same data used to pretrain BERT (the concatenation of the Toronto Book Corpus and full English Wikipedia), using distillation with the supervision of the `bert-base-uncased` version of BERT. The model has 6 layers, a hidden dimension of 768 and 12 heads, for 66M parameters in total.
+- `distilbert-base-uncased-distilled-squad`: A version of `distilbert-base-uncased` fine-tuned with (a second step of) knowledge distillation on SQuAD 1.0. This model reaches an F1 score of 86.9 on the dev set (for comparison, `bert-base-uncased` reaches an F1 score of 88.5).
+- `distilbert-base-cased`: DistilBERT English language model pretrained on the same data used to pretrain BERT (the concatenation of the Toronto Book Corpus and full English Wikipedia), using distillation with the supervision of the `bert-base-cased` version of BERT. The model has 6 layers, a hidden dimension of 768 and 12 heads, for 65M parameters in total.
+- `distilbert-base-cased-distilled-squad`: A version of `distilbert-base-cased` fine-tuned with (a second step of) knowledge distillation on SQuAD 1.0. This model reaches an F1 score of 87.1 on the dev set (for comparison, `bert-base-cased` reaches an F1 score of 88.7).
+- `distilbert-base-german-cased`: DistilBERT German language model pretrained on half of the data used to pretrain BERT, using distillation with the supervision of the `bert-base-german-dbmdz-cased` version of German DBMDZ BERT. For NER tasks the model reaches an F1 score of 83.49 on the CoNLL-2003 test set (for comparison, `bert-base-german-dbmdz-cased` reaches an F1 score of 84.52), and an F1 score of 85.23 on the GermEval 2014 test set (`bert-base-german-dbmdz-cased` reaches an F1 score of 86.89).
+- `distilgpt2`: DistilGPT2 English language model pretrained with the supervision of `gpt2` (the smallest version of GPT2) on [OpenWebTextCorpus](https://skylion007.github.io/OpenWebTextCorpus/), a reproduction of OpenAI's WebText dataset. The model has 6 layers, a hidden dimension of 768 and 12 heads, for 82M parameters in total (compared to 124M parameters for GPT2). On average, DistilGPT2 is twice as fast as GPT2.
+- `distilroberta-base`: DistilRoBERTa English language model pretrained with the supervision of `roberta-base` solely on [OpenWebTextCorpus](https://skylion007.github.io/OpenWebTextCorpus/), a reproduction of OpenAI's WebText dataset (roughly 4 times less training data than the teacher RoBERTa was trained on). The model has 6 layers, a hidden dimension of 768 and 12 heads, for 82M parameters in total (compared to 125M parameters for RoBERTa-base). On average, DistilRoBERTa is twice as fast as RoBERTa-base.
+- `distilbert-base-multilingual-cased`: DistilmBERT multilingual model pretrained with the supervision of `bert-base-multilingual-cased` on the concatenation of Wikipedia in 104 different languages. The model supports the 104 languages listed [here](https://github.com/google-research/bert/blob/master/multilingual.md#list-of-languages). The model has 6 layers, a hidden dimension of 768 and 12 heads, for 134M parameters in total (compared to 177M parameters for mBERT-base). On average, DistilmBERT is twice as fast as mBERT-base.
+
+Using DistilBERT is very similar to using BERT. DistilBERT shares BERT's `bert-base-uncased` tokenizer, although we also expose it under the `DistilBertTokenizer` name for consistent naming across the library's models.
+
+```python
+import torch
+from transformers import DistilBertModel, DistilBertTokenizer
+
+tokenizer = DistilBertTokenizer.from_pretrained('distilbert-base-cased')
+model = DistilBertModel.from_pretrained('distilbert-base-cased')
+
+input_ids = torch.tensor(tokenizer.encode("Hello, my dog is cute")).unsqueeze(0)
+outputs = model(input_ids)
+last_hidden_states = outputs[0]  # The last hidden-state is the first element of the output tuple
+```
+
+Similarly, using the other Distil* models simply consists of calling the base classes with a different pretrained checkpoint:
+- DistilBERT uncased: `model = DistilBertModel.from_pretrained('distilbert-base-uncased')`
+- DistilGPT2: `model = GPT2Model.from_pretrained('distilgpt2')`
+- DistilRoBERTa: `model = RobertaModel.from_pretrained('distilroberta-base')`
+- DistilmBERT: `model = DistilBertModel.from_pretrained('distilbert-base-multilingual-cased')`
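+
+The SQuAD-distilled checkpoints listed above can also be used directly for extractive question answering, for example through the `pipeline` API (a minimal sketch; the question/context pair is made up for illustration):
+
+```python
+from transformers import pipeline
+
+qa = pipeline("question-answering", model="distilbert-base-cased-distilled-squad")
+prediction = qa(
+    question="How many parameters does DistilBERT have?",
+    context="DistilBERT has 6 layers and 66 million parameters, 40% fewer than bert-base-uncased.",
+)
+print(prediction["answer"], prediction["score"])
+```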
+
+
+## How to train Distil*
+
+In the following, we will explain how you can train DistilBERT.
+
+### A. Preparing the data
+
+The weights we release were trained on a concatenation of the Toronto Book Corpus and English Wikipedia (the same training data as the English version of BERT).
+
+To avoid processing the data several times, we do it once and for all before training. From now on, we will suppose that you have a text file `dump.txt` containing one sequence per line (a sequence being composed of one or several coherent sentences).
+
+First, we binarize the data, i.e. tokenize the data and convert each token to an index in our model's vocabulary.
+
+```bash
+python scripts/binarized_data.py \
+    --file_path data/dump.txt \
+    --tokenizer_type bert \
+    --tokenizer_name bert-base-uncased \
+    --dump_file data/binarized_text
+```
+
+Our implementation of the masked language modeling loss follows [XLM](https://github.com/facebookresearch/XLM)'s and smooths the masking probability with a factor that puts more emphasis on rare words. We therefore count the occurrences of each token in the data:
+
+```bash
+python scripts/token_counts.py \
+    --data_file data/binarized_text.bert-base-uncased.pickle \
+    --token_counts_dump data/token_counts.bert-base-uncased.pickle \
+    --vocab_size 30522
+```
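+
+The token counts are used to skew the MLM masking distribution towards rare tokens, following XLM. Below is a minimal sketch of the idea; the exact smoothing exponent and the handling of special tokens live in `train.py`, so treat the values and variable names as illustrative:
+
+```python
+import pickle
+
+import numpy as np
+
+with open("data/token_counts.bert-base-uncased.pickle", "rb") as fp:
+    counts = pickle.load(fp)  # per-token occurrence counts over the binarized corpus
+
+smoothing = 0.7  # illustrative smoothing exponent
+token_probs = np.maximum(counts, 1) ** -smoothing  # rare tokens get relatively more masking mass
+token_probs = token_probs / token_probs.sum()      # normalized distribution used to sample masked positions
+```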
+### B. Training
+
+Training with distillation is really simple once you have pre-processed the data:
+
+```bash
+python train.py \
+    --student_type distilbert \
+    --student_config training_configs/distilbert-base-uncased.json \
+    --teacher_type bert \
+    --teacher_name bert-base-uncased \
+    --alpha_ce 5.0 --alpha_mlm 2.0 --alpha_cos 1.0 --alpha_clm 0.0 --mlm \
+    --freeze_pos_embs \
+    --dump_path serialization_dir/my_first_training \
+    --data_file data/binarized_text.bert-base-uncased.pickle \
+    --token_counts data/token_counts.bert-base-uncased.pickle \
+    --force # overwrites the `dump_path` if it already exists.
+```
+
+The `--alpha_ce`, `--alpha_mlm`, `--alpha_cos` and `--alpha_clm` flags weight the components of the training objective implemented in `distiller.py`: the soft-target distillation loss, the masked language modeling loss, the hidden-state cosine alignment loss, and the causal language modeling loss (the latter is used instead of the MLM loss when distilling GPT2-style students).
+
+By default, this launches training on a single GPU (even if more are available on the cluster). Other parameters are available on the command line; please look in `train.py` or run `python train.py --help` to list them.
+
+We highly encourage you to use distributed training for training DistilBERT, as the training corpus is quite large. Here's an example that runs distributed training on a single node with 4 GPUs:
+
+```bash
+export NODE_RANK=0
+export N_NODES=1
+
+export N_GPU_NODE=4
+export WORLD_SIZE=4
+export MASTER_PORT=<AN_OPEN_PORT>
+export MASTER_ADDR=<I.P.>
+
+pkill -f 'python -u train.py'
+
+python -m torch.distributed.launch \
+    --nproc_per_node=$N_GPU_NODE \
+    --nnodes=$N_NODES \
+    --node_rank $NODE_RANK \
+    --master_addr $MASTER_ADDR \
+    --master_port $MASTER_PORT \
+    train.py \
+        --force \
+        --n_gpu $WORLD_SIZE \
+        --student_type distilbert \
+        --student_config training_configs/distilbert-base-uncased.json \
+        --teacher_type bert \
+        --teacher_name bert-base-uncased \
+        --alpha_ce 0.33 --alpha_mlm 0.33 --alpha_cos 0.33 --alpha_clm 0.0 --mlm \
+        --freeze_pos_embs \
+        --dump_path serialization_dir/my_first_training \
+        --data_file data/binarized_text.bert-base-uncased.pickle \
+        --token_counts data/token_counts.bert-base-uncased.pickle
+```
+
+**Tips:** Starting the distillation with a good initialization of the student weights is crucial to reach decent performance. In our experiments, we initialized the student from a few layers of the teacher (BERT) itself! Please refer to `scripts/extract.py` and `scripts/extract_distilbert.py` to create a valid initialization checkpoint, and use the `--student_pretrained_weights` argument to use this initialization for the distilled training.
diff --git a/Experiments/NLP/distillation/distiller.py b/Experiments/NLP/distillation/distiller.py
new file mode 100644
index 0000000..963af97
--- /dev/null
+++ b/Experiments/NLP/distillation/distiller.py
@@ -0,0 +1,601 @@
+# coding=utf-8
+# Copyright 2019-present, the HuggingFace Inc. team and Facebook, Inc.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""The distiller to distil the student.
+Adapted in part from Facebook, Inc XLM model (https://github.com/facebookresearch/XLM) +""" + +import math +import os +import time + +import psutil +import torch +from grouped_batch_sampler import GroupedBatchSampler, create_lengths_groups +from lm_seqs_dataset import LmSeqsDataset +from torch import nn +from torch.optim import AdamW +from torch.utils.data import BatchSampler, DataLoader, RandomSampler +from torch.utils.data.distributed import DistributedSampler +from tqdm import tqdm + +from transformers import get_linear_schedule_with_warmup +from utils import logger + + +try: + from torch.utils.tensorboard import SummaryWriter +except ImportError: + from tensorboardX import SummaryWriter + + +class Distiller: + def __init__( + self, params: dict, dataset: LmSeqsDataset, token_probs: torch.tensor, student: nn.Module, teacher: nn.Module + ): + logger.info("Initializing Distiller") + self.params = params + self.dump_path = params.dump_path + self.multi_gpu = params.multi_gpu + self.fp16 = params.fp16 + + self.student = student + self.teacher = teacher + + self.student_config = student.config + self.vocab_size = student.config.vocab_size + + if params.n_gpu <= 1: + sampler = RandomSampler(dataset) + else: + sampler = DistributedSampler(dataset) + + if params.group_by_size: + groups = create_lengths_groups(lengths=dataset.lengths, k=params.max_model_input_size) + sampler = GroupedBatchSampler(sampler=sampler, group_ids=groups, batch_size=params.batch_size) + else: + sampler = BatchSampler(sampler=sampler, batch_size=params.batch_size, drop_last=False) + + self.dataloader = DataLoader(dataset=dataset, batch_sampler=sampler, collate_fn=dataset.batch_sequences) + + self.temperature = params.temperature + assert self.temperature > 0.0 + + self.alpha_ce = params.alpha_ce + self.alpha_mlm = params.alpha_mlm + self.alpha_clm = params.alpha_clm + self.alpha_mse = params.alpha_mse + self.alpha_cos = params.alpha_cos + + self.mlm = params.mlm + if self.mlm: + logger.info("Using MLM loss for LM step.") + self.mlm_mask_prop = params.mlm_mask_prop + assert 0.0 <= self.mlm_mask_prop <= 1.0 + assert params.word_mask + params.word_keep + params.word_rand == 1.0 + self.pred_probs = torch.FloatTensor([params.word_mask, params.word_keep, params.word_rand]) + self.pred_probs = self.pred_probs.to(f"cuda:{params.local_rank}") if params.n_gpu > 0 else self.pred_probs + self.token_probs = token_probs.to(f"cuda:{params.local_rank}") if params.n_gpu > 0 else token_probs + if self.fp16: + self.pred_probs = self.pred_probs.half() + self.token_probs = self.token_probs.half() + else: + logger.info("Using CLM loss for LM step.") + + self.epoch = 0 + self.n_iter = 0 + self.n_total_iter = 0 + self.n_sequences_epoch = 0 + self.total_loss_epoch = 0 + self.last_loss = 0 + self.last_loss_ce = 0 + self.last_loss_mlm = 0 + self.last_loss_clm = 0 + if self.alpha_mse > 0.0: + self.last_loss_mse = 0 + if self.alpha_cos > 0.0: + self.last_loss_cos = 0 + self.last_log = 0 + + self.ce_loss_fct = nn.KLDivLoss(reduction="batchmean") + self.lm_loss_fct = nn.CrossEntropyLoss(ignore_index=-100) + if self.alpha_mse > 0.0: + self.mse_loss_fct = nn.MSELoss(reduction="sum") + if self.alpha_cos > 0.0: + self.cosine_loss_fct = nn.CosineEmbeddingLoss(reduction="mean") + + logger.info("--- Initializing model optimizer") + assert params.gradient_accumulation_steps >= 1 + self.num_steps_epoch = len(self.dataloader) + num_train_optimization_steps = ( + int(self.num_steps_epoch / params.gradient_accumulation_steps * params.n_epoch) + 1 + ) + + 
no_decay = ["bias", "LayerNorm.weight"] + optimizer_grouped_parameters = [ + { + "params": [ + p for n, p in student.named_parameters() if not any(nd in n for nd in no_decay) and p.requires_grad + ], + "weight_decay": params.weight_decay, + }, + { + "params": [ + p for n, p in student.named_parameters() if any(nd in n for nd in no_decay) and p.requires_grad + ], + "weight_decay": 0.0, + }, + ] + logger.info( + "------ Number of trainable parameters (student): %i" + % sum([p.numel() for p in self.student.parameters() if p.requires_grad]) + ) + logger.info("------ Number of parameters (student): %i" % sum([p.numel() for p in self.student.parameters()])) + self.optimizer = AdamW( + optimizer_grouped_parameters, lr=params.learning_rate, eps=params.adam_epsilon, betas=(0.9, 0.98) + ) + + warmup_steps = math.ceil(num_train_optimization_steps * params.warmup_prop) + self.scheduler = get_linear_schedule_with_warmup( + self.optimizer, num_warmup_steps=warmup_steps, num_training_steps=num_train_optimization_steps + ) + + if self.fp16: + try: + from apex import amp + except ImportError: + raise ImportError("Please install apex from https://www.github.com/nvidia/apex to use fp16 training.") + logger.info(f"Using fp16 training: {self.params.fp16_opt_level} level") + self.student, self.optimizer = amp.initialize( + self.student, self.optimizer, opt_level=self.params.fp16_opt_level + ) + self.teacher = self.teacher.half() + + if self.multi_gpu: + if self.fp16: + from apex.parallel import DistributedDataParallel + + logger.info("Using apex.parallel.DistributedDataParallel for distributed training.") + self.student = DistributedDataParallel(self.student) + else: + from torch.nn.parallel import DistributedDataParallel + + logger.info("Using nn.parallel.DistributedDataParallel for distributed training.") + self.student = DistributedDataParallel( + self.student, + device_ids=[params.local_rank], + output_device=params.local_rank, + find_unused_parameters=True, + ) + + self.is_master = params.is_master + if self.is_master: + logger.info("--- Initializing Tensorboard") + self.tensorboard = SummaryWriter(log_dir=os.path.join(self.dump_path, "log", "train")) + self.tensorboard.add_text(tag="config/training", text_string=str(self.params), global_step=0) + self.tensorboard.add_text(tag="config/student", text_string=str(self.student_config), global_step=0) + + def prepare_batch_mlm(self, batch): + """ + Prepare the batch: from the token_ids and the lengths, compute the attention mask and the masked label for MLM. + + Input: + ------ + batch: `Tuple` + token_ids: `torch.tensor(bs, seq_length)` - The token ids for each of the sequence. It is padded. + lengths: `torch.tensor(bs)` - The lengths of each of the sequences in the batch. + + Output: + ------- + token_ids: `torch.tensor(bs, seq_length)` - The token ids after the modifications for MLM. + attn_mask: `torch.tensor(bs, seq_length)` - The attention mask for the self-attention. + mlm_labels: `torch.tensor(bs, seq_length)` - The masked language modeling labels. There is a -100 where there is nothing to predict. 
+ """ + token_ids, lengths = batch + token_ids, lengths = self.round_batch(x=token_ids, lengths=lengths) + assert token_ids.size(0) == lengths.size(0) + + attn_mask = torch.arange(token_ids.size(1), dtype=torch.long, device=lengths.device) < lengths[:, None] + + bs, max_seq_len = token_ids.size() + mlm_labels = token_ids.new(token_ids.size()).copy_(token_ids) + + x_prob = self.token_probs[token_ids.flatten()] + n_tgt = math.ceil(self.mlm_mask_prop * lengths.sum().item()) + tgt_ids = torch.multinomial(x_prob / x_prob.sum(), n_tgt, replacement=False) + pred_mask = torch.zeros( + bs * max_seq_len, dtype=torch.bool, device=token_ids.device + ) # previously `dtype=torch.uint8`, cf pytorch 1.2.0 compatibility + pred_mask[tgt_ids] = 1 + pred_mask = pred_mask.view(bs, max_seq_len) + + pred_mask[token_ids == self.params.special_tok_ids["pad_token"]] = 0 + + # mask a number of words == 0 [8] (faster with fp16) + if self.fp16: + n1 = pred_mask.sum().item() + if n1 > 8: + pred_mask = pred_mask.view(-1) + n2 = max(n1 % 8, 8 * (n1 // 8)) + if n2 != n1: + pred_mask[torch.nonzero(pred_mask).view(-1)[: n1 - n2]] = 0 + pred_mask = pred_mask.view(bs, max_seq_len) + assert pred_mask.sum().item() % 8 == 0, pred_mask.sum().item() + + _token_ids_real = token_ids[pred_mask] + _token_ids_rand = _token_ids_real.clone().random_(self.vocab_size) + _token_ids_mask = _token_ids_real.clone().fill_(self.params.special_tok_ids["mask_token"]) + probs = torch.multinomial(self.pred_probs, len(_token_ids_real), replacement=True) + _token_ids = ( + _token_ids_mask * (probs == 0).long() + + _token_ids_real * (probs == 1).long() + + _token_ids_rand * (probs == 2).long() + ) + token_ids = token_ids.masked_scatter(pred_mask, _token_ids) + + mlm_labels[~pred_mask] = -100 # previously `mlm_labels[1-pred_mask] = -1`, cf pytorch 1.2.0 compatibility + + # sanity checks + assert 0 <= token_ids.min() <= token_ids.max() < self.vocab_size + + return token_ids, attn_mask, mlm_labels + + def prepare_batch_clm(self, batch): + """ + Prepare the batch: from the token_ids and the lengths, compute the attention mask and the labels for CLM. + + Input: + ------ + batch: `Tuple` + token_ids: `torch.tensor(bs, seq_length)` - The token ids for each of the sequence. It is padded. + lengths: `torch.tensor(bs)` - The lengths of each of the sequences in the batch. + + Output: + ------- + token_ids: `torch.tensor(bs, seq_length)` - The token ids after the modifications for MLM. + attn_mask: `torch.tensor(bs, seq_length)` - The attention mask for the self-attention. + clm_labels: `torch.tensor(bs, seq_length)` - The causal language modeling labels. There is a -100 where there is nothing to predict. + """ + token_ids, lengths = batch + token_ids, lengths = self.round_batch(x=token_ids, lengths=lengths) + assert token_ids.size(0) == lengths.size(0) + + attn_mask = torch.arange(token_ids.size(1), dtype=torch.long, device=lengths.device) < lengths[:, None] + clm_labels = token_ids.new(token_ids.size()).copy_(token_ids) + clm_labels[~attn_mask] = -100 # previously `clm_labels[1-attn_mask] = -1`, cf pytorch 1.2.0 compatibility + + # sanity checks + assert 0 <= token_ids.min() <= token_ids.max() < self.vocab_size + + return token_ids, attn_mask, clm_labels + + def round_batch(self, x: torch.tensor, lengths: torch.tensor): + """ + For float16 only. + Sub-sample sentences in a batch, and add padding, so that each dimension is a multiple of 8. + + Input: + ------ + x: `torch.tensor(bs, seq_length)` - The token ids. 
+ lengths: `torch.tensor(bs, seq_length)` - The lengths of each of the sequence in the batch. + + Output: + ------- + x: `torch.tensor(new_bs, new_seq_length)` - The updated token ids. + lengths: `torch.tensor(new_bs, new_seq_length)` - The updated lengths. + """ + if not self.fp16 or len(lengths) < 8: + return x, lengths + + # number of sentences == 0 [8] + bs1 = len(lengths) + bs2 = 8 * (bs1 // 8) + assert bs2 > 0 and bs2 % 8 == 0 + if bs1 != bs2: + idx = torch.randperm(bs1)[:bs2] + lengths = lengths[idx] + slen = lengths.max().item() + x = x[idx, :slen] + else: + idx = None + + # sequence length == 0 [8] + ml1 = x.size(1) + if ml1 % 8 != 0: + pad = 8 - (ml1 % 8) + ml2 = ml1 + pad + if self.mlm: + pad_id = self.params.special_tok_ids["pad_token"] + else: + pad_id = self.params.special_tok_ids["unk_token"] + padding_tensor = torch.zeros(bs2, pad, dtype=torch.long, device=x.device).fill_(pad_id) + x = torch.cat([x, padding_tensor], 1) + assert x.size() == (bs2, ml2) + + assert x.size(0) % 8 == 0 + assert x.size(1) % 8 == 0 + return x, lengths + + def train(self): + """ + The real training loop. + """ + if self.is_master: + logger.info("Starting training") + self.last_log = time.time() + self.student.train() + self.teacher.eval() + + for _ in range(self.params.n_epoch): + if self.is_master: + logger.info(f"--- Starting epoch {self.epoch}/{self.params.n_epoch-1}") + if self.multi_gpu: + torch.distributed.barrier() + + iter_bar = tqdm(self.dataloader, desc="-Iter", disable=self.params.local_rank not in [-1, 0]) + for batch in iter_bar: + if self.params.n_gpu > 0: + batch = tuple(t.to(f"cuda:{self.params.local_rank}") for t in batch) + + if self.mlm: + token_ids, attn_mask, lm_labels = self.prepare_batch_mlm(batch=batch) + else: + token_ids, attn_mask, lm_labels = self.prepare_batch_clm(batch=batch) + self.step(input_ids=token_ids, attention_mask=attn_mask, lm_labels=lm_labels) + + iter_bar.update() + iter_bar.set_postfix( + {"Last_loss": f"{self.last_loss:.2f}", "Avg_cum_loss": f"{self.total_loss_epoch/self.n_iter:.2f}"} + ) + iter_bar.close() + + if self.is_master: + logger.info(f"--- Ending epoch {self.epoch}/{self.params.n_epoch-1}") + self.end_epoch() + + if self.is_master: + logger.info("Save very last checkpoint as `pytorch_model.bin`.") + self.save_checkpoint(checkpoint_name="pytorch_model.bin") + logger.info("Training is finished") + + def step(self, input_ids: torch.tensor, attention_mask: torch.tensor, lm_labels: torch.tensor): + """ + One optimization step: forward of student AND teacher, backward on the loss (for gradient accumulation), + and possibly a parameter update (depending on the gradient accumulation). + + Input: + ------ + input_ids: `torch.tensor(bs, seq_length)` - The token ids. + attention_mask: `torch.tensor(bs, seq_length)` - The attention mask for self attention. + lm_labels: `torch.tensor(bs, seq_length)` - The language modeling labels (mlm labels for MLM and clm labels for CLM). 
+ """ + if self.mlm: + student_outputs = self.student( + input_ids=input_ids, attention_mask=attention_mask + ) # (bs, seq_length, voc_size) + with torch.no_grad(): + teacher_outputs = self.teacher( + input_ids=input_ids, attention_mask=attention_mask + ) # (bs, seq_length, voc_size) + else: + student_outputs = self.student(input_ids=input_ids, attention_mask=None) # (bs, seq_length, voc_size) + with torch.no_grad(): + teacher_outputs = self.teacher(input_ids=input_ids, attention_mask=None) # (bs, seq_length, voc_size) + s_logits, s_hidden_states = student_outputs["logits"], student_outputs["hidden_states"] + t_logits, t_hidden_states = teacher_outputs["logits"], teacher_outputs["hidden_states"] + assert s_logits.size() == t_logits.size() + + # https://github.com/peterliht/knowledge-distillation-pytorch/blob/master/model/net.py#L100 + # https://github.com/peterliht/knowledge-distillation-pytorch/issues/2 + if self.params.restrict_ce_to_mask: + mask = (lm_labels > -1).unsqueeze(-1).expand_as(s_logits) # (bs, seq_length, voc_size) + else: + mask = attention_mask.unsqueeze(-1).expand_as(s_logits) # (bs, seq_length, voc_size) + s_logits_slct = torch.masked_select(s_logits, mask) # (bs * seq_length * voc_size) modulo the 1s in mask + s_logits_slct = s_logits_slct.view(-1, s_logits.size(-1)) # (bs * seq_length, voc_size) modulo the 1s in mask + t_logits_slct = torch.masked_select(t_logits, mask) # (bs * seq_length * voc_size) modulo the 1s in mask + t_logits_slct = t_logits_slct.view(-1, s_logits.size(-1)) # (bs * seq_length, voc_size) modulo the 1s in mask + assert t_logits_slct.size() == s_logits_slct.size() + + loss_ce = ( + self.ce_loss_fct( + nn.functional.log_softmax(s_logits_slct / self.temperature, dim=-1), + nn.functional.softmax(t_logits_slct / self.temperature, dim=-1), + ) + * (self.temperature) ** 2 + ) + loss = self.alpha_ce * loss_ce + + if self.alpha_mlm > 0.0: + loss_mlm = self.lm_loss_fct(s_logits.view(-1, s_logits.size(-1)), lm_labels.view(-1)) + loss += self.alpha_mlm * loss_mlm + if self.alpha_clm > 0.0: + shift_logits = s_logits[..., :-1, :].contiguous() + shift_labels = lm_labels[..., 1:].contiguous() + loss_clm = self.lm_loss_fct(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1)) + loss += self.alpha_clm * loss_clm + + if self.alpha_mse > 0.0: + loss_mse = self.mse_loss_fct(s_logits_slct, t_logits_slct) / s_logits_slct.size( + 0 + ) # Reproducing batchmean reduction + loss += self.alpha_mse * loss_mse + if self.alpha_cos > 0.0: + s_hidden_states = s_hidden_states[-1] # (bs, seq_length, dim) + t_hidden_states = t_hidden_states[-1] # (bs, seq_length, dim) + mask = attention_mask.unsqueeze(-1).expand_as(s_hidden_states) # (bs, seq_length, dim) + assert s_hidden_states.size() == t_hidden_states.size() + dim = s_hidden_states.size(-1) + + s_hidden_states_slct = torch.masked_select(s_hidden_states, mask) # (bs * seq_length * dim) + s_hidden_states_slct = s_hidden_states_slct.view(-1, dim) # (bs * seq_length, dim) + t_hidden_states_slct = torch.masked_select(t_hidden_states, mask) # (bs * seq_length * dim) + t_hidden_states_slct = t_hidden_states_slct.view(-1, dim) # (bs * seq_length, dim) + + target = s_hidden_states_slct.new(s_hidden_states_slct.size(0)).fill_(1) # (bs * seq_length,) + loss_cos = self.cosine_loss_fct(s_hidden_states_slct, t_hidden_states_slct, target) + loss += self.alpha_cos * loss_cos + + self.total_loss_epoch += loss.item() + self.last_loss = loss.item() + self.last_loss_ce = loss_ce.item() + if self.alpha_mlm > 0.0: + self.last_loss_mlm = 
loss_mlm.item() + if self.alpha_clm > 0.0: + self.last_loss_clm = loss_clm.item() + if self.alpha_mse > 0.0: + self.last_loss_mse = loss_mse.item() + if self.alpha_cos > 0.0: + self.last_loss_cos = loss_cos.item() + + self.optimize(loss) + + self.n_sequences_epoch += input_ids.size(0) + + def optimize(self, loss): + """ + Normalization on the loss (gradient accumulation or distributed training), followed by + backward pass on the loss, possibly followed by a parameter update (depending on the gradient accumulation). + Also update the metrics for tensorboard. + """ + # Check for NaN + if (loss != loss).data.any(): + logger.error("NaN detected") + exit() + + if self.multi_gpu: + loss = loss.mean() + if self.params.gradient_accumulation_steps > 1: + loss = loss / self.params.gradient_accumulation_steps + + if self.fp16: + from apex import amp + + with amp.scale_loss(loss, self.optimizer) as scaled_loss: + scaled_loss.backward() + else: + loss.backward() + + self.iter() + if self.n_iter % self.params.gradient_accumulation_steps == 0: + if self.fp16: + nn.utils.clip_grad_norm_(amp.master_params(self.optimizer), self.params.max_grad_norm) + else: + nn.utils.clip_grad_norm_(self.student.parameters(), self.params.max_grad_norm) + self.optimizer.step() + self.optimizer.zero_grad() + self.scheduler.step() + + def iter(self): + """ + Update global counts, write to tensorboard and save checkpoint. + """ + self.n_iter += 1 + self.n_total_iter += 1 + + if self.n_total_iter % self.params.log_interval == 0: + self.log_tensorboard() + self.last_log = time.time() + if self.n_total_iter % self.params.checkpoint_interval == 0: + self.save_checkpoint() + + def log_tensorboard(self): + """ + Log into tensorboard. Only by the master process. + """ + if not self.is_master: + return + + for param_name, param in self.student.named_parameters(): + self.tensorboard.add_scalar( + tag="parameter_mean/" + param_name, scalar_value=param.data.mean(), global_step=self.n_total_iter + ) + self.tensorboard.add_scalar( + tag="parameter_std/" + param_name, scalar_value=param.data.std(), global_step=self.n_total_iter + ) + if param.grad is None: + continue + self.tensorboard.add_scalar( + tag="grad_mean/" + param_name, scalar_value=param.grad.data.mean(), global_step=self.n_total_iter + ) + self.tensorboard.add_scalar( + tag="grad_std/" + param_name, scalar_value=param.grad.data.std(), global_step=self.n_total_iter + ) + + self.tensorboard.add_scalar( + tag="losses/cum_avg_loss_epoch", + scalar_value=self.total_loss_epoch / self.n_iter, + global_step=self.n_total_iter, + ) + self.tensorboard.add_scalar(tag="losses/loss", scalar_value=self.last_loss, global_step=self.n_total_iter) + self.tensorboard.add_scalar( + tag="losses/loss_ce", scalar_value=self.last_loss_ce, global_step=self.n_total_iter + ) + if self.alpha_mlm > 0.0: + self.tensorboard.add_scalar( + tag="losses/loss_mlm", scalar_value=self.last_loss_mlm, global_step=self.n_total_iter + ) + if self.alpha_clm > 0.0: + self.tensorboard.add_scalar( + tag="losses/loss_clm", scalar_value=self.last_loss_clm, global_step=self.n_total_iter + ) + if self.alpha_mse > 0.0: + self.tensorboard.add_scalar( + tag="losses/loss_mse", scalar_value=self.last_loss_mse, global_step=self.n_total_iter + ) + if self.alpha_cos > 0.0: + self.tensorboard.add_scalar( + tag="losses/loss_cos", scalar_value=self.last_loss_cos, global_step=self.n_total_iter + ) + self.tensorboard.add_scalar( + tag="learning_rate/lr", scalar_value=self.scheduler.get_lr()[0], global_step=self.n_total_iter + ) + + 
self.tensorboard.add_scalar( + tag="global/memory_usage", + scalar_value=psutil.virtual_memory()._asdict()["used"] / 1_000_000, + global_step=self.n_total_iter, + ) + self.tensorboard.add_scalar( + tag="global/speed", scalar_value=time.time() - self.last_log, global_step=self.n_total_iter + ) + + def end_epoch(self): + """ + Finally arrived at the end of epoch (full pass on dataset). + Do some tensorboard logging and checkpoint saving. + """ + logger.info(f"{self.n_sequences_epoch} sequences have been trained during this epoch.") + + if self.is_master: + self.save_checkpoint(checkpoint_name=f"model_epoch_{self.epoch}.pth") + self.tensorboard.add_scalar( + tag="epoch/loss", scalar_value=self.total_loss_epoch / self.n_iter, global_step=self.epoch + ) + + self.epoch += 1 + self.n_sequences_epoch = 0 + self.n_iter = 0 + self.total_loss_epoch = 0 + + def save_checkpoint(self, checkpoint_name: str = "checkpoint.pth"): + """ + Save the current state. Only by the master process. + """ + if not self.is_master: + return + mdl_to_save = self.student.module if hasattr(self.student, "module") else self.student + mdl_to_save.config.save_pretrained(self.dump_path) + state_dict = mdl_to_save.state_dict() + torch.save(state_dict, os.path.join(self.dump_path, checkpoint_name)) diff --git a/Experiments/NLP/distillation/grouped_batch_sampler.py b/Experiments/NLP/distillation/grouped_batch_sampler.py new file mode 100644 index 0000000..e25def7 --- /dev/null +++ b/Experiments/NLP/distillation/grouped_batch_sampler.py @@ -0,0 +1,108 @@ +# coding=utf-8 +# Copyright 2019-present, the HuggingFace Inc. team and Facebook, Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Adapted from PyTorch Vision (https://github.com/pytorch/vision/blob/master/references/detection/group_by_aspect_ratio.py)""" + +import bisect +import copy +from collections import defaultdict + +import numpy as np +from torch.utils.data import BatchSampler, Sampler + +from utils import logger + + +def _quantize(x, bins): + bins = copy.deepcopy(bins) + bins = sorted(bins) + quantized = [bisect.bisect_right(bins, y) for y in x] + return quantized + + +def create_lengths_groups(lengths, k=0): + bins = np.arange(start=3, stop=k, step=4).tolist() if k > 0 else [10] + groups = _quantize(lengths, bins) + # count number of elements per group + counts = np.unique(groups, return_counts=True)[1] + fbins = [0] + bins + [np.inf] + logger.info("Using {} as bins for aspect lengths quantization".format(fbins)) + logger.info("Count of instances per bin: {}".format(counts)) + return groups + + +class GroupedBatchSampler(BatchSampler): + """ + Wraps another sampler to yield a mini-batch of indices. + It enforces that the batch only contain elements from the same group. + It also tries to provide mini-batches which follows an ordering which is + as close as possible to the ordering from the original sampler. + Arguments: + sampler (Sampler): Base sampler. 
+ group_ids (list[int]): If the sampler produces indices in range [0, N), + `group_ids` must be a list of `N` ints which contains the group id of each sample. + The group ids must be a continuous set of integers starting from + 0, i.e. they must be in the range [0, num_groups). + batch_size (int): Size of mini-batch. + """ + + def __init__(self, sampler, group_ids, batch_size): + if not isinstance(sampler, Sampler): + raise TypeError( + "sampler should be an instance of torch.utils.data.Sampler, but got sampler={}".format(sampler) + ) + self.sampler = sampler + self.group_ids = group_ids + self.batch_size = batch_size + + def __iter__(self): + buffer_per_group = defaultdict(list) + samples_per_group = defaultdict(list) + + num_batches = 0 + for idx in self.sampler: + group_id = self.group_ids[idx] + buffer_per_group[group_id].append(idx) + samples_per_group[group_id].append(idx) + if len(buffer_per_group[group_id]) == self.batch_size: + yield buffer_per_group[group_id] # TODO + num_batches += 1 + del buffer_per_group[group_id] + assert len(buffer_per_group[group_id]) < self.batch_size + + # now we have run out of elements that satisfy + # the group criteria, let's return the remaining + # elements so that the size of the sampler is + # deterministic + expected_num_batches = len(self) + num_remaining = expected_num_batches - num_batches + if num_remaining > 0: + # for the remaining batches, group the batches by similar lengths + batch_idx = [] + for group_id, idxs in sorted(buffer_per_group.items(), key=lambda x: x[0]): + batch_idx.extend(idxs) + if len(batch_idx) >= self.batch_size: + yield batch_idx[: self.batch_size] + batch_idx = batch_idx[self.batch_size :] + num_remaining -= 1 + if len(batch_idx) > 0: + yield batch_idx + num_remaining -= 1 + assert num_remaining == 0 + + def __len__(self): + """ + Return the number of mini-batches rather than the number of samples. + """ + return (len(self.sampler) + self.batch_size - 1) // self.batch_size diff --git a/Experiments/NLP/distillation/lm_seqs_dataset.py b/Experiments/NLP/distillation/lm_seqs_dataset.py new file mode 100644 index 0000000..647c8f4 --- /dev/null +++ b/Experiments/NLP/distillation/lm_seqs_dataset.py @@ -0,0 +1,167 @@ +# coding=utf-8 +# Copyright 2019-present, the HuggingFace Inc. team and Facebook, Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Dataset to distilled models +adapted in part from Facebook, Inc XLM model (https://github.com/facebookresearch/XLM) +""" + +import numpy as np +import torch +from torch.utils.data import Dataset + +from utils import logger + + +class LmSeqsDataset(Dataset): + """Custom Dataset wrapping language modeling sequences. + + Each sample will be retrieved by indexing the list of token_ids and their corresponding lengths. 
+ + Input: + ------ + params: `NameSpace` parameters + data: `List[np.array[int]] + """ + + def __init__(self, params, data): + self.params = params + + self.token_ids = np.array(data) + self.lengths = np.array([len(t) for t in data]) + + self.check() + self.remove_long_sequences() + self.remove_empty_sequences() + self.remove_unknown_sequences() + self.check() + self.print_statistics() + + def __getitem__(self, index): + return (self.token_ids[index], self.lengths[index]) + + def __len__(self): + return len(self.lengths) + + def check(self): + """ + Some sanity checks + """ + assert len(self.token_ids) == len(self.lengths) + assert all(self.lengths[i] == len(self.token_ids[i]) for i in range(len(self.lengths))) + + def remove_long_sequences(self): + """ + Sequences that are too long are split by chunk of max_model_input_size. + """ + max_len = self.params.max_model_input_size + indices = self.lengths > max_len + logger.info(f"Splitting {sum(indices)} too long sequences.") + + def divide_chunks(l, n): + return [l[i : i + n] for i in range(0, len(l), n)] + + new_tok_ids = [] + new_lengths = [] + if self.params.mlm: + cls_id, sep_id = self.params.special_tok_ids["cls_token"], self.params.special_tok_ids["sep_token"] + else: + cls_id, sep_id = self.params.special_tok_ids["bos_token"], self.params.special_tok_ids["eos_token"] + + for seq_, len_ in zip(self.token_ids, self.lengths): + assert (seq_[0] == cls_id) and (seq_[-1] == sep_id), seq_ + if len_ <= max_len: + new_tok_ids.append(seq_) + new_lengths.append(len_) + else: + sub_seqs = [] + for sub_s in divide_chunks(seq_, max_len - 2): + if sub_s[0] != cls_id: + sub_s = np.insert(sub_s, 0, cls_id) + if sub_s[-1] != sep_id: + sub_s = np.insert(sub_s, len(sub_s), sep_id) + assert len(sub_s) <= max_len + assert (sub_s[0] == cls_id) and (sub_s[-1] == sep_id), sub_s + sub_seqs.append(sub_s) + + new_tok_ids.extend(sub_seqs) + new_lengths.extend([len(l) for l in sub_seqs]) + + self.token_ids = np.array(new_tok_ids) + self.lengths = np.array(new_lengths) + + def remove_empty_sequences(self): + """ + Too short sequences are simply removed. This could be tuned. + """ + init_size = len(self) + indices = self.lengths > 11 + self.token_ids = self.token_ids[indices] + self.lengths = self.lengths[indices] + new_size = len(self) + logger.info(f"Remove {init_size - new_size} too short (<=11 tokens) sequences.") + + def remove_unknown_sequences(self): + """ + Remove sequences with a (too) high level of unknown tokens. + """ + if "unk_token" not in self.params.special_tok_ids: + return + else: + unk_token_id = self.params.special_tok_ids["unk_token"] + init_size = len(self) + unk_occs = np.array([np.count_nonzero(a == unk_token_id) for a in self.token_ids]) + indices = (unk_occs / self.lengths) < 0.5 + self.token_ids = self.token_ids[indices] + self.lengths = self.lengths[indices] + new_size = len(self) + logger.info(f"Remove {init_size - new_size} sequences with a high level of unknown tokens (50%).") + + def print_statistics(self): + """ + Print some statistics on the corpus. Only the master process. 
+ """ + if not self.params.is_master: + return + logger.info(f"{len(self)} sequences") + # data_len = sum(self.lengths) + # nb_unique_tokens = len(Counter(list(chain(*self.token_ids)))) + # logger.info(f'{data_len} tokens ({nb_unique_tokens} unique)') + + # unk_idx = self.params.special_tok_ids['unk_token'] + # nb_unknown = sum([(t==unk_idx).sum() for t in self.token_ids]) + # logger.info(f'{nb_unknown} unknown tokens (covering {100*nb_unknown/data_len:.2f}% of the data)') + + def batch_sequences(self, batch): + """ + Do the padding and transform into torch.tensor. + """ + token_ids = [t[0] for t in batch] + lengths = [t[1] for t in batch] + assert len(token_ids) == len(lengths) + + # Max for paddings + max_seq_len_ = max(lengths) + + # Pad token ids + if self.params.mlm: + pad_idx = self.params.special_tok_ids["pad_token"] + else: + pad_idx = self.params.special_tok_ids["unk_token"] + tk_ = [list(t.astype(int)) + [pad_idx] * (max_seq_len_ - len(t)) for t in token_ids] + assert len(tk_) == len(token_ids) + assert all(len(t) == max_seq_len_ for t in tk_) + + tk_t = torch.tensor(tk_) # (bs, max_seq_len_) + lg_t = torch.tensor(lengths) # (bs) + return tk_t, lg_t diff --git a/Experiments/NLP/distillation/requirements.txt b/Experiments/NLP/distillation/requirements.txt new file mode 100644 index 0000000..4a2ed78 --- /dev/null +++ b/Experiments/NLP/distillation/requirements.txt @@ -0,0 +1,7 @@ +transformers + +gitpython==3.1.41 +tensorboard>=1.14.0 +tensorboardX==1.8 +psutil==5.6.6 +scipy>=1.4.1 diff --git a/Experiments/NLP/distillation/run_squad_w_distillation.py b/Experiments/NLP/distillation/run_squad_w_distillation.py new file mode 100644 index 0000000..a1150f6 --- /dev/null +++ b/Experiments/NLP/distillation/run_squad_w_distillation.py @@ -0,0 +1,877 @@ +# coding=utf-8 +# Copyright 2018 The Google AI Language Team Authors and The HuggingFace Inc. team. +# Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+"""This is the exact same script as `examples/question-answering/run_squad.py` (as of 2020, January 8th) with an additional and optional step of distillation.""" + +import argparse +import glob +import logging +import os +import random +import timeit + +import numpy as np +import torch +from torch import nn +from torch.utils.data import DataLoader, RandomSampler, SequentialSampler +from torch.utils.data.distributed import DistributedSampler +from tqdm import tqdm, trange + +import transformers +from transformers import ( + WEIGHTS_NAME, + AdamW, + BertConfig, + BertForQuestionAnswering, + BertTokenizer, + DistilBertConfig, + DistilBertForQuestionAnswering, + DistilBertTokenizer, + RobertaConfig, + RobertaForQuestionAnswering, + RobertaTokenizer, + XLMConfig, + XLMForQuestionAnswering, + XLMTokenizer, + XLNetConfig, + XLNetForQuestionAnswering, + XLNetTokenizer, + get_linear_schedule_with_warmup, + squad_convert_examples_to_features, +) +from transformers.data.metrics.squad_metrics import ( + compute_predictions_log_probs, + compute_predictions_logits, + squad_evaluate, +) +from transformers.data.processors.squad import SquadResult, SquadV1Processor, SquadV2Processor +from transformers.trainer_utils import is_main_process + + +try: + from torch.utils.tensorboard import SummaryWriter +except ImportError: + from tensorboardX import SummaryWriter + + +logger = logging.getLogger(__name__) + + +MODEL_CLASSES = { + "bert": (BertConfig, BertForQuestionAnswering, BertTokenizer), + "xlnet": (XLNetConfig, XLNetForQuestionAnswering, XLNetTokenizer), + "xlm": (XLMConfig, XLMForQuestionAnswering, XLMTokenizer), + "distilbert": (DistilBertConfig, DistilBertForQuestionAnswering, DistilBertTokenizer), + "roberta": (RobertaConfig, RobertaForQuestionAnswering, RobertaTokenizer), +} + + +def set_seed(args): + random.seed(args.seed) + np.random.seed(args.seed) + torch.manual_seed(args.seed) + if args.n_gpu > 0: + torch.cuda.manual_seed_all(args.seed) + + +def to_list(tensor): + return tensor.detach().cpu().tolist() + + +def train(args, train_dataset, model, tokenizer, teacher=None): + """Train the model""" + if args.local_rank in [-1, 0]: + tb_writer = SummaryWriter() + + args.train_batch_size = args.per_gpu_train_batch_size * max(1, args.n_gpu) + train_sampler = RandomSampler(train_dataset) if args.local_rank == -1 else DistributedSampler(train_dataset) + train_dataloader = DataLoader(train_dataset, sampler=train_sampler, batch_size=args.train_batch_size) + + if args.max_steps > 0: + t_total = args.max_steps + args.num_train_epochs = args.max_steps // (len(train_dataloader) // args.gradient_accumulation_steps) + 1 + else: + t_total = len(train_dataloader) // args.gradient_accumulation_steps * args.num_train_epochs + + # Prepare optimizer and schedule (linear warmup and decay) + no_decay = ["bias", "LayerNorm.weight"] + optimizer_grouped_parameters = [ + { + "params": [p for n, p in model.named_parameters() if not any(nd in n for nd in no_decay)], + "weight_decay": args.weight_decay, + }, + {"params": [p for n, p in model.named_parameters() if any(nd in n for nd in no_decay)], "weight_decay": 0.0}, + ] + optimizer = AdamW(optimizer_grouped_parameters, lr=args.learning_rate, eps=args.adam_epsilon) + scheduler = get_linear_schedule_with_warmup( + optimizer, num_warmup_steps=args.warmup_steps, num_training_steps=t_total + ) + + # Check if saved optimizer or scheduler states exist + if os.path.isfile(os.path.join(args.model_name_or_path, "optimizer.pt")) and os.path.isfile( + os.path.join(args.model_name_or_path, 
"scheduler.pt") + ): + # Load in optimizer and scheduler states + optimizer.load_state_dict(torch.load(os.path.join(args.model_name_or_path, "optimizer.pt"))) + scheduler.load_state_dict(torch.load(os.path.join(args.model_name_or_path, "scheduler.pt"))) + + if args.fp16: + try: + from apex import amp + except ImportError: + raise ImportError("Please install apex from https://www.github.com/nvidia/apex to use fp16 training.") + + model, optimizer = amp.initialize(model, optimizer, opt_level=args.fp16_opt_level) + + # multi-gpu training (should be after apex fp16 initialization) + if args.n_gpu > 1: + model = nn.DataParallel(model) + + # Distributed training (should be after apex fp16 initialization) + if args.local_rank != -1: + model = nn.parallel.DistributedDataParallel( + model, device_ids=[args.local_rank], output_device=args.local_rank, find_unused_parameters=True + ) + + # Train! + logger.info("***** Running training *****") + logger.info(" Num examples = %d", len(train_dataset)) + logger.info(" Num Epochs = %d", args.num_train_epochs) + logger.info(" Instantaneous batch size per GPU = %d", args.per_gpu_train_batch_size) + logger.info( + " Total train batch size (w. parallel, distributed & accumulation) = %d", + args.train_batch_size + * args.gradient_accumulation_steps + * (torch.distributed.get_world_size() if args.local_rank != -1 else 1), + ) + logger.info(" Gradient Accumulation steps = %d", args.gradient_accumulation_steps) + logger.info(" Total optimization steps = %d", t_total) + + global_step = 1 + epochs_trained = 0 + steps_trained_in_current_epoch = 0 + # Check if continuing training from a checkpoint + if os.path.exists(args.model_name_or_path): + try: + # set global_step to global_step of last saved checkpoint from model path + checkpoint_suffix = args.model_name_or_path.split("-")[-1].split("/")[0] + global_step = int(checkpoint_suffix) + epochs_trained = global_step // (len(train_dataloader) // args.gradient_accumulation_steps) + steps_trained_in_current_epoch = global_step % (len(train_dataloader) // args.gradient_accumulation_steps) + + logger.info(" Continuing training from checkpoint, will skip to saved global_step") + logger.info(" Continuing training from epoch %d", epochs_trained) + logger.info(" Continuing training from global step %d", global_step) + logger.info(" Will skip the first %d steps in the first epoch", steps_trained_in_current_epoch) + except ValueError: + logger.info(" Starting fine-tuning.") + + tr_loss, logging_loss = 0.0, 0.0 + model.zero_grad() + train_iterator = trange( + epochs_trained, int(args.num_train_epochs), desc="Epoch", disable=args.local_rank not in [-1, 0] + ) + # Added here for reproducibility + set_seed(args) + + for _ in train_iterator: + epoch_iterator = tqdm(train_dataloader, desc="Iteration", disable=args.local_rank not in [-1, 0]) + for step, batch in enumerate(epoch_iterator): + # Skip past any already trained steps if resuming training + if steps_trained_in_current_epoch > 0: + steps_trained_in_current_epoch -= 1 + continue + + model.train() + if teacher is not None: + teacher.eval() + batch = tuple(t.to(args.device) for t in batch) + + inputs = { + "input_ids": batch[0], + "attention_mask": batch[1], + "start_positions": batch[3], + "end_positions": batch[4], + } + if args.model_type != "distilbert": + inputs["token_type_ids"] = None if args.model_type == "xlm" else batch[2] + if args.model_type in ["xlnet", "xlm"]: + inputs.update({"cls_index": batch[5], "p_mask": batch[6]}) + if args.version_2_with_negative: + 
inputs.update({"is_impossible": batch[7]}) + outputs = model(**inputs) + loss, start_logits_stu, end_logits_stu = outputs + + # Distillation loss + if teacher is not None: + if "token_type_ids" not in inputs: + inputs["token_type_ids"] = None if args.teacher_type == "xlm" else batch[2] + with torch.no_grad(): + start_logits_tea, end_logits_tea = teacher( + input_ids=inputs["input_ids"], + token_type_ids=inputs["token_type_ids"], + attention_mask=inputs["attention_mask"], + ) + assert start_logits_tea.size() == start_logits_stu.size() + assert end_logits_tea.size() == end_logits_stu.size() + + loss_fct = nn.KLDivLoss(reduction="batchmean") + loss_start = loss_fct( + nn.functional.log_softmax(start_logits_stu / args.temperature, dim=-1), + nn.functional.softmax(start_logits_tea / args.temperature, dim=-1), + ) * (args.temperature**2) + loss_end = loss_fct( + nn.functional.log_softmax(end_logits_stu / args.temperature, dim=-1), + nn.functional.softmax(end_logits_tea / args.temperature, dim=-1), + ) * (args.temperature**2) + loss_ce = (loss_start + loss_end) / 2.0 + + loss = args.alpha_ce * loss_ce + args.alpha_squad * loss + + if args.n_gpu > 1: + loss = loss.mean() # mean() to average on multi-gpu parallel (not distributed) training + if args.gradient_accumulation_steps > 1: + loss = loss / args.gradient_accumulation_steps + + if args.fp16: + with amp.scale_loss(loss, optimizer) as scaled_loss: + scaled_loss.backward() + else: + loss.backward() + + tr_loss += loss.item() + if (step + 1) % args.gradient_accumulation_steps == 0: + if args.fp16: + nn.utils.clip_grad_norm_(amp.master_params(optimizer), args.max_grad_norm) + else: + nn.utils.clip_grad_norm_(model.parameters(), args.max_grad_norm) + + optimizer.step() + scheduler.step() # Update learning rate schedule + model.zero_grad() + global_step += 1 + + # Log metrics + if args.local_rank in [-1, 0] and args.logging_steps > 0 and global_step % args.logging_steps == 0: + # Only evaluate when single GPU otherwise metrics may not average well + if args.local_rank == -1 and args.evaluate_during_training: + results = evaluate(args, model, tokenizer) + for key, value in results.items(): + tb_writer.add_scalar("eval_{}".format(key), value, global_step) + tb_writer.add_scalar("lr", scheduler.get_lr()[0], global_step) + tb_writer.add_scalar("loss", (tr_loss - logging_loss) / args.logging_steps, global_step) + logging_loss = tr_loss + + if args.local_rank in [-1, 0] and args.save_steps > 0 and global_step % args.save_steps == 0: + # Save model checkpoint + output_dir = os.path.join(args.output_dir, "checkpoint-{}".format(global_step)) + if not os.path.exists(output_dir): + os.makedirs(output_dir) + model_to_save = ( + model.module if hasattr(model, "module") else model + ) # Take care of distributed/parallel training + model_to_save.save_pretrained(output_dir) + tokenizer.save_pretrained(output_dir) + + torch.save(args, os.path.join(output_dir, "training_args.bin")) + logger.info("Saving model checkpoint to %s", output_dir) + + torch.save(optimizer.state_dict(), os.path.join(output_dir, "optimizer.pt")) + torch.save(scheduler.state_dict(), os.path.join(output_dir, "scheduler.pt")) + logger.info("Saving optimizer and scheduler states to %s", output_dir) + + if args.max_steps > 0 and global_step > args.max_steps: + epoch_iterator.close() + break + if args.max_steps > 0 and global_step > args.max_steps: + train_iterator.close() + break + + if args.local_rank in [-1, 0]: + tb_writer.close() + + return global_step, tr_loss / global_step + + +def 
evaluate(args, model, tokenizer, prefix=""): + dataset, examples, features = load_and_cache_examples(args, tokenizer, evaluate=True, output_examples=True) + + if not os.path.exists(args.output_dir) and args.local_rank in [-1, 0]: + os.makedirs(args.output_dir) + + args.eval_batch_size = args.per_gpu_eval_batch_size * max(1, args.n_gpu) + + # Note that DistributedSampler samples randomly + eval_sampler = SequentialSampler(dataset) + eval_dataloader = DataLoader(dataset, sampler=eval_sampler, batch_size=args.eval_batch_size) + + # multi-gpu evaluate + if args.n_gpu > 1 and not isinstance(model, nn.DataParallel): + model = nn.DataParallel(model) + + # Eval! + logger.info("***** Running evaluation {} *****".format(prefix)) + logger.info(" Num examples = %d", len(dataset)) + logger.info(" Batch size = %d", args.eval_batch_size) + + all_results = [] + start_time = timeit.default_timer() + + for batch in tqdm(eval_dataloader, desc="Evaluating"): + model.eval() + batch = tuple(t.to(args.device) for t in batch) + + with torch.no_grad(): + inputs = {"input_ids": batch[0], "attention_mask": batch[1]} + if args.model_type != "distilbert": + inputs["token_type_ids"] = None if args.model_type == "xlm" else batch[2] # XLM don't use segment_ids + example_indices = batch[3] + if args.model_type in ["xlnet", "xlm"]: + inputs.update({"cls_index": batch[4], "p_mask": batch[5]}) + + outputs = model(**inputs) + + for i, example_index in enumerate(example_indices): + eval_feature = features[example_index.item()] + unique_id = int(eval_feature.unique_id) + + output = [to_list(output[i]) for output in outputs] + + # Some models (XLNet, XLM) use 5 arguments for their predictions, while the other "simpler" + # models only use two. + if len(output) >= 5: + start_logits = output[0] + start_top_index = output[1] + end_logits = output[2] + end_top_index = output[3] + cls_logits = output[4] + + result = SquadResult( + unique_id, + start_logits, + end_logits, + start_top_index=start_top_index, + end_top_index=end_top_index, + cls_logits=cls_logits, + ) + + else: + start_logits, end_logits = output + result = SquadResult(unique_id, start_logits, end_logits) + + all_results.append(result) + + evalTime = timeit.default_timer() - start_time + logger.info(" Evaluation done in total %f secs (%f sec per example)", evalTime, evalTime / len(dataset)) + + # Compute predictions + output_prediction_file = os.path.join(args.output_dir, "predictions_{}.json".format(prefix)) + output_nbest_file = os.path.join(args.output_dir, "nbest_predictions_{}.json".format(prefix)) + + if args.version_2_with_negative: + output_null_log_odds_file = os.path.join(args.output_dir, "null_odds_{}.json".format(prefix)) + else: + output_null_log_odds_file = None + + if args.model_type in ["xlnet", "xlm"]: + # XLNet uses a more complex post-processing procedure + predictions = compute_predictions_log_probs( + examples, + features, + all_results, + args.n_best_size, + args.max_answer_length, + output_prediction_file, + output_nbest_file, + output_null_log_odds_file, + model.config.start_n_top, + model.config.end_n_top, + args.version_2_with_negative, + tokenizer, + args.verbose_logging, + ) + else: + predictions = compute_predictions_logits( + examples, + features, + all_results, + args.n_best_size, + args.max_answer_length, + args.do_lower_case, + output_prediction_file, + output_nbest_file, + output_null_log_odds_file, + args.verbose_logging, + args.version_2_with_negative, + args.null_score_diff_threshold, + tokenizer, + ) + + # Compute the F1 and exact 
scores. + results = squad_evaluate(examples, predictions) + return results + + +def load_and_cache_examples(args, tokenizer, evaluate=False, output_examples=False): + if args.local_rank not in [-1, 0] and not evaluate: + # Make sure only the first process in distributed training process the dataset, and the others will use the cache + torch.distributed.barrier() + + # Load data features from cache or dataset file + input_file = args.predict_file if evaluate else args.train_file + cached_features_file = os.path.join( + os.path.dirname(input_file), + "cached_distillation_{}_{}_{}".format( + "dev" if evaluate else "train", + list(filter(None, args.model_name_or_path.split("/"))).pop(), + str(args.max_seq_length), + ), + ) + if os.path.exists(cached_features_file) and not args.overwrite_cache: + logger.info("Loading features from cached file %s", cached_features_file) + features_and_dataset = torch.load(cached_features_file) + + try: + features, dataset, examples = ( + features_and_dataset["features"], + features_and_dataset["dataset"], + features_and_dataset["examples"], + ) + except KeyError: + raise DeprecationWarning( + "You seem to be loading features from an older version of this script please delete the " + "file %s in order for it to be created again" % cached_features_file + ) + else: + logger.info("Creating features from dataset file at %s", input_file) + processor = SquadV2Processor() if args.version_2_with_negative else SquadV1Processor() + if evaluate: + examples = processor.get_dev_examples(args.data_dir, filename=args.predict_file) + else: + examples = processor.get_train_examples(args.data_dir, filename=args.train_file) + + features, dataset = squad_convert_examples_to_features( + examples=examples, + tokenizer=tokenizer, + max_seq_length=args.max_seq_length, + doc_stride=args.doc_stride, + max_query_length=args.max_query_length, + is_training=not evaluate, + return_dataset="pt", + threads=args.threads, + ) + + if args.local_rank in [-1, 0]: + logger.info("Saving features into cached file %s", cached_features_file) + torch.save({"features": features, "dataset": dataset, "examples": examples}, cached_features_file) + + if args.local_rank == 0 and not evaluate: + # Make sure only the first process in distributed training process the dataset, and the others will use the cache + torch.distributed.barrier() + + if output_examples: + return dataset, examples, features + return dataset + + +def main(): + parser = argparse.ArgumentParser() + + # Required parameters + parser.add_argument( + "--model_type", + default=None, + type=str, + required=True, + help="Model type selected in the list: " + ", ".join(MODEL_CLASSES.keys()), + ) + parser.add_argument( + "--model_name_or_path", + default=None, + type=str, + required=True, + help="Path to pretrained model or model identifier from huggingface.co/models", + ) + parser.add_argument( + "--output_dir", + default=None, + type=str, + required=True, + help="The output directory where the model checkpoints and predictions will be written.", + ) + + # Distillation parameters (optional) + parser.add_argument( + "--teacher_type", + default=None, + type=str, + help=( + "Teacher type. Teacher tokenizer and student (model) tokenizer must output the same tokenization. Only for" + " distillation." + ), + ) + parser.add_argument( + "--teacher_name_or_path", + default=None, + type=str, + help="Path to the already SQuAD fine-tuned teacher model. 
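`load_and_cache_examples()` above follows a simple build-once/cache pattern: the SQuAD features are serialized with `torch.save` under a name derived from the split, the model identifier and `max_seq_length`, and reloaded on later runs unless `--overwrite_cache` is passed. A stripped-down sketch of that pattern (the file name and the cheap placeholder build step are hypothetical):

```python
import os

import torch


def build_or_load(build_fn, cache_file, overwrite=False):
    """Return build_fn()'s output, reusing a torch.save'd cache when it already exists."""
    if os.path.exists(cache_file) and not overwrite:
        return torch.load(cache_file)
    obj = build_fn()
    torch.save(obj, cache_file)
    return obj


# Hypothetical usage mirroring the script's naming scheme.
cache_file = "cached_distillation_train_distilbert-base-uncased_384"
payload = build_or_load(lambda: {"features": [], "dataset": None, "examples": []}, cache_file)
```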
Only for distillation.", + ) + parser.add_argument( + "--alpha_ce", default=0.5, type=float, help="Distillation loss linear weight. Only for distillation." + ) + parser.add_argument( + "--alpha_squad", default=0.5, type=float, help="True SQuAD loss linear weight. Only for distillation." + ) + parser.add_argument( + "--temperature", default=2.0, type=float, help="Distillation temperature. Only for distillation." + ) + + # Other parameters + parser.add_argument( + "--data_dir", + default=None, + type=str, + help="The input data dir. Should contain the .json files for the task. " + + "If no data dir or train/predict files are specified, will run with tensorflow_datasets.", + ) + parser.add_argument( + "--train_file", + default=None, + type=str, + help="The input training file. If a data dir is specified, will look for the file there. " + + "If no data dir or train/predict files are specified, will run with tensorflow_datasets.", + ) + parser.add_argument( + "--predict_file", + default=None, + type=str, + help="The input evaluation file. If a data dir is specified, will look for the file there. " + + "If no data dir or train/predict files are specified, will run with tensorflow_datasets.", + ) + parser.add_argument( + "--config_name", default="", type=str, help="Pretrained config name or path if not the same as model_name" + ) + parser.add_argument( + "--tokenizer_name", + default="", + type=str, + help="Pretrained tokenizer name or path if not the same as model_name", + ) + parser.add_argument( + "--cache_dir", + default="", + type=str, + help="Where do you want to store the pre-trained models downloaded from huggingface.co", + ) + + parser.add_argument( + "--version_2_with_negative", + action="store_true", + help="If true, the SQuAD examples contain some that do not have an answer.", + ) + parser.add_argument( + "--null_score_diff_threshold", + type=float, + default=0.0, + help="If null_score - best_non_null is greater than the threshold predict null.", + ) + + parser.add_argument( + "--max_seq_length", + default=384, + type=int, + help=( + "The maximum total input sequence length after WordPiece tokenization. Sequences " + "longer than this will be truncated, and sequences shorter than this will be padded." + ), + ) + parser.add_argument( + "--doc_stride", + default=128, + type=int, + help="When splitting up a long document into chunks, how much stride to take between chunks.", + ) + parser.add_argument( + "--max_query_length", + default=64, + type=int, + help=( + "The maximum number of tokens for the question. Questions longer than this will " + "be truncated to this length." + ), + ) + parser.add_argument("--do_train", action="store_true", help="Whether to run training.") + parser.add_argument("--do_eval", action="store_true", help="Whether to run eval on the dev set.") + parser.add_argument( + "--evaluate_during_training", action="store_true", help="Run evaluation during training at each logging step." + ) + parser.add_argument( + "--do_lower_case", action="store_true", help="Set this flag if you are using an uncased model." + ) + + parser.add_argument("--per_gpu_train_batch_size", default=8, type=int, help="Batch size per GPU/CPU for training.") + parser.add_argument( + "--per_gpu_eval_batch_size", default=8, type=int, help="Batch size per GPU/CPU for evaluation."
+ ) + parser.add_argument("--learning_rate", default=5e-5, type=float, help="The initial learning rate for Adam.") + parser.add_argument( + "--gradient_accumulation_steps", + type=int, + default=1, + help="Number of updates steps to accumulate before performing a backward/update pass.", + ) + parser.add_argument("--weight_decay", default=0.0, type=float, help="Weight decay if we apply some.") + parser.add_argument("--adam_epsilon", default=1e-8, type=float, help="Epsilon for Adam optimizer.") + parser.add_argument("--max_grad_norm", default=1.0, type=float, help="Max gradient norm.") + parser.add_argument( + "--num_train_epochs", default=3.0, type=float, help="Total number of training epochs to perform." + ) + parser.add_argument( + "--max_steps", + default=-1, + type=int, + help="If > 0: set total number of training steps to perform. Override num_train_epochs.", + ) + parser.add_argument("--warmup_steps", default=0, type=int, help="Linear warmup over warmup_steps.") + parser.add_argument( + "--n_best_size", + default=20, + type=int, + help="The total number of n-best predictions to generate in the nbest_predictions.json output file.", + ) + parser.add_argument( + "--max_answer_length", + default=30, + type=int, + help=( + "The maximum length of an answer that can be generated. This is needed because the start " + "and end predictions are not conditioned on one another." + ), + ) + parser.add_argument( + "--verbose_logging", + action="store_true", + help=( + "If true, all of the warnings related to data processing will be printed. " + "A number of warnings are expected for a normal SQuAD evaluation." + ), + ) + + parser.add_argument("--logging_steps", type=int, default=50, help="Log every X updates steps.") + parser.add_argument("--save_steps", type=int, default=50, help="Save checkpoint every X updates steps.") + parser.add_argument( + "--eval_all_checkpoints", + action="store_true", + help="Evaluate all checkpoints starting with the same prefix as model_name ending and ending with step number", + ) + parser.add_argument("--no_cuda", action="store_true", help="Whether not to use CUDA when available") + parser.add_argument( + "--overwrite_output_dir", action="store_true", help="Overwrite the content of the output directory" + ) + parser.add_argument( + "--overwrite_cache", action="store_true", help="Overwrite the cached training and evaluation sets" + ) + parser.add_argument("--seed", type=int, default=42, help="random seed for initialization") + + parser.add_argument("--local_rank", type=int, default=-1, help="local_rank for distributed training on gpus") + parser.add_argument( + "--fp16", + action="store_true", + help="Whether to use 16-bit (mixed) precision (through NVIDIA apex) instead of 32-bit", + ) + parser.add_argument( + "--fp16_opt_level", + type=str, + default="O1", + help=( + "For fp16: Apex AMP optimization level selected in ['O0', 'O1', 'O2', and 'O3']. 
" + "See details at https://nvidia.github.io/apex/amp.html" + ), + ) + parser.add_argument("--server_ip", type=str, default="", help="Can be used for distant debugging.") + parser.add_argument("--server_port", type=str, default="", help="Can be used for distant debugging.") + + parser.add_argument("--threads", type=int, default=1, help="multiple threads for converting example to features") + args = parser.parse_args() + + if ( + os.path.exists(args.output_dir) + and os.listdir(args.output_dir) + and args.do_train + and not args.overwrite_output_dir + ): + raise ValueError( + "Output directory ({}) already exists and is not empty. Use --overwrite_output_dir to overcome.".format( + args.output_dir + ) + ) + + # Setup distant debugging if needed + if args.server_ip and args.server_port: + # Distant debugging - see https://code.visualstudio.com/docs/python/debugging#_attach-to-a-local-script + import ptvsd + + print("Waiting for debugger attach") + ptvsd.enable_attach(address=(args.server_ip, args.server_port), redirect_output=True) + ptvsd.wait_for_attach() + + # Setup CUDA, GPU & distributed training + if args.local_rank == -1 or args.no_cuda: + device = torch.device("cuda" if torch.cuda.is_available() and not args.no_cuda else "cpu") + args.n_gpu = 0 if args.no_cuda else torch.cuda.device_count() + else: # Initializes the distributed backend which will take care of synchronizing nodes/GPUs + torch.cuda.set_device(args.local_rank) + device = torch.device("cuda", args.local_rank) + torch.distributed.init_process_group(backend="nccl") + args.n_gpu = 1 + args.device = device + + # Setup logging + logging.basicConfig( + format="%(asctime)s - %(levelname)s - %(name)s - %(message)s", + datefmt="%m/%d/%Y %H:%M:%S", + level=logging.INFO if args.local_rank in [-1, 0] else logging.WARN, + ) + logger.warning( + "Process rank: %s, device: %s, n_gpu: %s, distributed training: %s, 16-bits training: %s", + args.local_rank, + device, + args.n_gpu, + bool(args.local_rank != -1), + args.fp16, + ) + # Set the verbosity to info of the Transformers logger (on main process only): + if is_main_process(args.local_rank): + transformers.utils.logging.set_verbosity_info() + transformers.utils.logging.enable_default_handler() + transformers.utils.logging.enable_explicit_format() + # Set seed + set_seed(args) + + # Load pretrained model and tokenizer + if args.local_rank not in [-1, 0]: + # Make sure only the first process in distributed training will download model & vocab + torch.distributed.barrier() + + args.model_type = args.model_type.lower() + config_class, model_class, tokenizer_class = MODEL_CLASSES[args.model_type] + config = config_class.from_pretrained( + args.config_name if args.config_name else args.model_name_or_path, + cache_dir=args.cache_dir if args.cache_dir else None, + ) + tokenizer = tokenizer_class.from_pretrained( + args.tokenizer_name if args.tokenizer_name else args.model_name_or_path, + do_lower_case=args.do_lower_case, + cache_dir=args.cache_dir if args.cache_dir else None, + ) + model = model_class.from_pretrained( + args.model_name_or_path, + from_tf=bool(".ckpt" in args.model_name_or_path), + config=config, + cache_dir=args.cache_dir if args.cache_dir else None, + ) + + if args.teacher_type is not None: + assert args.teacher_name_or_path is not None + assert args.alpha_ce > 0.0 + assert args.alpha_ce + args.alpha_squad > 0.0 + assert args.teacher_type != "distilbert", "We constraint teachers not to be of type DistilBERT." 
+ teacher_config_class, teacher_model_class, _ = MODEL_CLASSES[args.teacher_type] + teacher_config = teacher_config_class.from_pretrained( + args.teacher_name_or_path, cache_dir=args.cache_dir if args.cache_dir else None + ) + teacher = teacher_model_class.from_pretrained( + args.teacher_name_or_path, config=teacher_config, cache_dir=args.cache_dir if args.cache_dir else None + ) + teacher.to(args.device) + else: + teacher = None + + if args.local_rank == 0: + # Make sure only the first process in distributed training will download model & vocab + torch.distributed.barrier() + + model.to(args.device) + + logger.info("Training/evaluation parameters %s", args) + + # Before we do anything with models, we want to ensure that we get fp16 execution of torch.einsum if args.fp16 is set. + # Otherwise it'll default to "promote" mode, and we'll get fp32 operations. Note that running `--fp16_opt_level="O2"` will + # remove the need for this code, but it is still valid. + if args.fp16: + try: + import apex + + apex.amp.register_half_function(torch, "einsum") + except ImportError: + raise ImportError("Please install apex from https://www.github.com/nvidia/apex to use fp16 training.") + + # Training + if args.do_train: + train_dataset = load_and_cache_examples(args, tokenizer, evaluate=False, output_examples=False) + global_step, tr_loss = train(args, train_dataset, model, tokenizer, teacher=teacher) + logger.info(" global_step = %s, average loss = %s", global_step, tr_loss) + + # Save the trained model and the tokenizer + if args.do_train and (args.local_rank == -1 or torch.distributed.get_rank() == 0): + logger.info("Saving model checkpoint to %s", args.output_dir) + # Save a trained model, configuration and tokenizer using `save_pretrained()`. + # They can then be reloaded using `from_pretrained()` + model_to_save = ( + model.module if hasattr(model, "module") else model + ) # Take care of distributed/parallel training + model_to_save.save_pretrained(args.output_dir) + tokenizer.save_pretrained(args.output_dir) + + # Good practice: save your training arguments together with the trained model + torch.save(args, os.path.join(args.output_dir, "training_args.bin")) + + # Load a trained model and vocabulary that you have fine-tuned + model = model_class.from_pretrained(args.output_dir) + tokenizer = tokenizer_class.from_pretrained(args.output_dir, do_lower_case=args.do_lower_case) + model.to(args.device) + + # Evaluation - we can ask to evaluate all the checkpoints (sub-directories) in a directory + results = {} + if args.do_eval and args.local_rank in [-1, 0]: + if args.do_train: + logger.info("Loading checkpoints saved during training for evaluation") + checkpoints = [args.output_dir] + if args.eval_all_checkpoints: + checkpoints = [ + os.path.dirname(c) for c in sorted(glob.glob(args.output_dir + "/**/" + WEIGHTS_NAME, recursive=True)) + ] + + logger.info("Evaluate the following checkpoints: %s", checkpoints) + + for checkpoint in checkpoints: + # Reload the model + global_step = checkpoint.split("-")[-1] if len(checkpoints) > 1 else "" + model = model_class.from_pretrained(checkpoint) + model.to(args.device) + + # Evaluate + result = evaluate(args, model, tokenizer, prefix=global_step) + + result = {k + ("_{}".format(global_step) if global_step else ""): v for k, v in result.items()} + results.update(result) + + logger.info("Results: {}".format(results)) + + return results + + +if __name__ == "__main__": + main() diff --git a/Experiments/NLP/distillation/scripts/binarized_data.py 
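Once `main()` above has run with `--do_train`, the fine-tuned student is written to `--output_dir` with `save_pretrained` and can be reloaded like any other Transformers checkpoint. A small usage sketch (the output path is a placeholder for whatever was passed as `--output_dir`):

```python
from transformers import pipeline

# Placeholder path: the directory written by --output_dir above.
qa = pipeline("question-answering", model="./distilbert_squad_output", tokenizer="./distilbert_squad_output")

print(qa(question="What dataset was used for fine-tuning?", context="The student was fine-tuned on SQuAD with a second step of knowledge distillation."))
```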
b/Experiments/NLP/distillation/scripts/binarized_data.py new file mode 100644 index 0000000..3fc3214 --- /dev/null +++ b/Experiments/NLP/distillation/scripts/binarized_data.py @@ -0,0 +1,97 @@ +# coding=utf-8 +# Copyright 2019-present, the HuggingFace Inc. team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" +Preprocessing script before distillation. +""" + +import argparse +import logging +import pickle +import random +import time + +import numpy as np + +from transformers import BertTokenizer, GPT2Tokenizer, RobertaTokenizer + + +logging.basicConfig( + format="%(asctime)s - %(levelname)s - %(name)s - %(message)s", datefmt="%m/%d/%Y %H:%M:%S", level=logging.INFO +) +logger = logging.getLogger(__name__) + + +def main(): + parser = argparse.ArgumentParser( + description="Preprocess the data to avoid re-doing it several times by (tokenization + token_to_ids)." + ) + parser.add_argument("--file_path", type=str, default="data/dump.txt", help="The path to the data.") + parser.add_argument("--tokenizer_type", type=str, default="bert", choices=["bert", "roberta", "gpt2"]) + parser.add_argument("--tokenizer_name", type=str, default="bert-base-uncased", help="The tokenizer to use.") + parser.add_argument("--dump_file", type=str, default="data/dump", help="The dump file prefix.") + args = parser.parse_args() + + logger.info(f"Loading Tokenizer ({args.tokenizer_name})") + if args.tokenizer_type == "bert": + tokenizer = BertTokenizer.from_pretrained(args.tokenizer_name) + bos = tokenizer.special_tokens_map["cls_token"] # `[CLS]` + sep = tokenizer.special_tokens_map["sep_token"] # `[SEP]` + elif args.tokenizer_type == "roberta": + tokenizer = RobertaTokenizer.from_pretrained(args.tokenizer_name) + bos = tokenizer.special_tokens_map["cls_token"] # `` + sep = tokenizer.special_tokens_map["sep_token"] # `` + elif args.tokenizer_type == "gpt2": + tokenizer = GPT2Tokenizer.from_pretrained(args.tokenizer_name) + bos = tokenizer.special_tokens_map["bos_token"] # `<|endoftext|>` + sep = tokenizer.special_tokens_map["eos_token"] # `<|endoftext|>` + + logger.info(f"Loading text from {args.file_path}") + with open(args.file_path, "r", encoding="utf8") as fp: + data = fp.readlines() + + logger.info("Start encoding") + logger.info(f"{len(data)} examples to process.") + + rslt = [] + iter = 0 + interval = 10000 + start = time.time() + for text in data: + text = f"{bos} {text.strip()} {sep}" + token_ids = tokenizer.encode(text, add_special_tokens=False) + rslt.append(token_ids) + + iter += 1 + if iter % interval == 0: + end = time.time() + logger.info(f"{iter} examples processed. 
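For reference, `binarized_data.py` ends by pickling a plain Python list of token-id sequences (stored as `np.uint16` arrays whenever the vocabulary fits in 16 bits, as shown just below). A quick way to sanity-check the dump it produces, assuming the script's default output name:

```python
import pickle

with open("data/dump.bert-base-uncased.pickle", "rb") as fp:
    sequences = pickle.load(fp)

print(len(sequences))     # one entry per line of the input file
print(sequences[0][:10])  # first ten token ids of the first sequence
```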
- {(end-start):.2f}s/{interval}expl") + start = time.time() + logger.info("Finished binarization") + logger.info(f"{len(data)} examples processed.") + + dp_file = f"{args.dump_file}.{args.tokenizer_name}.pickle" + vocab_size = tokenizer.vocab_size + if vocab_size < (1 << 16): + rslt_ = [np.uint16(d) for d in rslt] + else: + rslt_ = [np.int32(d) for d in rslt] + random.shuffle(rslt_) + logger.info(f"Dump to {dp_file}") + with open(dp_file, "wb") as handle: + pickle.dump(rslt_, handle, protocol=pickle.HIGHEST_PROTOCOL) + + +if __name__ == "__main__": + main() diff --git a/Experiments/NLP/distillation/scripts/extract.py b/Experiments/NLP/distillation/scripts/extract.py new file mode 100644 index 0000000..c45821d --- /dev/null +++ b/Experiments/NLP/distillation/scripts/extract.py @@ -0,0 +1,106 @@ +# coding=utf-8 +# Copyright 2019-present, the HuggingFace Inc. team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" +Preprocessing script before training the distilled model. +Specific to RoBERTa -> DistilRoBERTa and GPT2 -> DistilGPT2. +""" + +import argparse + +import torch + +from transformers import GPT2LMHeadModel, RobertaForMaskedLM + + +if __name__ == "__main__": + parser = argparse.ArgumentParser( + description=( + "Extraction some layers of the full RobertaForMaskedLM or GPT2LMHeadModel for Transfer Learned" + " Distillation" + ) + ) + parser.add_argument("--model_type", default="roberta", choices=["roberta", "gpt2"]) + parser.add_argument("--model_name", default="roberta-large", type=str) + parser.add_argument("--dump_checkpoint", default="serialization_dir/tf_roberta_048131723.pth", type=str) + parser.add_argument("--vocab_transform", action="store_true") + args = parser.parse_args() + + if args.model_type == "roberta": + model = RobertaForMaskedLM.from_pretrained(args.model_name) + prefix = "roberta" + elif args.model_type == "gpt2": + model = GPT2LMHeadModel.from_pretrained(args.model_name) + prefix = "transformer" + + state_dict = model.state_dict() + compressed_sd = {} + + # Embeddings # + if args.model_type == "gpt2": + for param_name in ["wte.weight", "wpe.weight"]: + compressed_sd[f"{prefix}.{param_name}"] = state_dict[f"{prefix}.{param_name}"] + else: + for w in ["word_embeddings", "position_embeddings", "token_type_embeddings"]: + param_name = f"{prefix}.embeddings.{w}.weight" + compressed_sd[param_name] = state_dict[param_name] + for w in ["weight", "bias"]: + param_name = f"{prefix}.embeddings.LayerNorm.{w}" + compressed_sd[param_name] = state_dict[param_name] + + # Transformer Blocks # + std_idx = 0 + for teacher_idx in [0, 2, 4, 7, 9, 11]: + if args.model_type == "gpt2": + for layer in ["ln_1", "attn.c_attn", "attn.c_proj", "ln_2", "mlp.c_fc", "mlp.c_proj"]: + for w in ["weight", "bias"]: + compressed_sd[f"{prefix}.h.{std_idx}.{layer}.{w}"] = state_dict[ + f"{prefix}.h.{teacher_idx}.{layer}.{w}" + ] + compressed_sd[f"{prefix}.h.{std_idx}.attn.bias"] = state_dict[f"{prefix}.h.{teacher_idx}.attn.bias"] + else: + for layer in [ + "attention.self.query", + "attention.self.key", 
+ "attention.self.value", + "attention.output.dense", + "attention.output.LayerNorm", + "intermediate.dense", + "output.dense", + "output.LayerNorm", + ]: + for w in ["weight", "bias"]: + compressed_sd[f"{prefix}.encoder.layer.{std_idx}.{layer}.{w}"] = state_dict[ + f"{prefix}.encoder.layer.{teacher_idx}.{layer}.{w}" + ] + std_idx += 1 + + # Language Modeling Head ###s + if args.model_type == "roberta": + for layer in ["lm_head.decoder.weight", "lm_head.bias"]: + compressed_sd[f"{layer}"] = state_dict[f"{layer}"] + if args.vocab_transform: + for w in ["weight", "bias"]: + compressed_sd[f"lm_head.dense.{w}"] = state_dict[f"lm_head.dense.{w}"] + compressed_sd[f"lm_head.layer_norm.{w}"] = state_dict[f"lm_head.layer_norm.{w}"] + elif args.model_type == "gpt2": + for w in ["weight", "bias"]: + compressed_sd[f"{prefix}.ln_f.{w}"] = state_dict[f"{prefix}.ln_f.{w}"] + compressed_sd["lm_head.weight"] = state_dict["lm_head.weight"] + + print(f"N layers selected for distillation: {std_idx}") + print(f"Number of params transferred for distillation: {len(compressed_sd.keys())}") + + print(f"Save transferred checkpoint to {args.dump_checkpoint}.") + torch.save(compressed_sd, args.dump_checkpoint) diff --git a/Experiments/NLP/distillation/scripts/extract_distilbert.py b/Experiments/NLP/distillation/scripts/extract_distilbert.py new file mode 100644 index 0000000..8637970 --- /dev/null +++ b/Experiments/NLP/distillation/scripts/extract_distilbert.py @@ -0,0 +1,96 @@ +# coding=utf-8 +# Copyright 2019-present, the HuggingFace Inc. team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" +Preprocessing script before training DistilBERT. +Specific to BERT -> DistilBERT. 
+""" + +import argparse + +import torch + +from transformers import BertForMaskedLM + + +if __name__ == "__main__": + parser = argparse.ArgumentParser( + description=( + "Extraction some layers of the full BertForMaskedLM or RObertaForMaskedLM for Transfer Learned" + " Distillation" + ) + ) + parser.add_argument("--model_type", default="bert", choices=["bert"]) + parser.add_argument("--model_name", default="bert-base-uncased", type=str) + parser.add_argument("--dump_checkpoint", default="serialization_dir/tf_bert-base-uncased_0247911.pth", type=str) + parser.add_argument("--vocab_transform", action="store_true") + args = parser.parse_args() + + if args.model_type == "bert": + model = BertForMaskedLM.from_pretrained(args.model_name) + prefix = "bert" + else: + raise ValueError('args.model_type should be "bert".') + + state_dict = model.state_dict() + compressed_sd = {} + + for w in ["word_embeddings", "position_embeddings"]: + compressed_sd[f"distilbert.embeddings.{w}.weight"] = state_dict[f"{prefix}.embeddings.{w}.weight"] + for w in ["weight", "bias"]: + compressed_sd[f"distilbert.embeddings.LayerNorm.{w}"] = state_dict[f"{prefix}.embeddings.LayerNorm.{w}"] + + std_idx = 0 + for teacher_idx in [0, 2, 4, 7, 9, 11]: + for w in ["weight", "bias"]: + compressed_sd[f"distilbert.transformer.layer.{std_idx}.attention.q_lin.{w}"] = state_dict[ + f"{prefix}.encoder.layer.{teacher_idx}.attention.self.query.{w}" + ] + compressed_sd[f"distilbert.transformer.layer.{std_idx}.attention.k_lin.{w}"] = state_dict[ + f"{prefix}.encoder.layer.{teacher_idx}.attention.self.key.{w}" + ] + compressed_sd[f"distilbert.transformer.layer.{std_idx}.attention.v_lin.{w}"] = state_dict[ + f"{prefix}.encoder.layer.{teacher_idx}.attention.self.value.{w}" + ] + + compressed_sd[f"distilbert.transformer.layer.{std_idx}.attention.out_lin.{w}"] = state_dict[ + f"{prefix}.encoder.layer.{teacher_idx}.attention.output.dense.{w}" + ] + compressed_sd[f"distilbert.transformer.layer.{std_idx}.sa_layer_norm.{w}"] = state_dict[ + f"{prefix}.encoder.layer.{teacher_idx}.attention.output.LayerNorm.{w}" + ] + + compressed_sd[f"distilbert.transformer.layer.{std_idx}.ffn.lin1.{w}"] = state_dict[ + f"{prefix}.encoder.layer.{teacher_idx}.intermediate.dense.{w}" + ] + compressed_sd[f"distilbert.transformer.layer.{std_idx}.ffn.lin2.{w}"] = state_dict[ + f"{prefix}.encoder.layer.{teacher_idx}.output.dense.{w}" + ] + compressed_sd[f"distilbert.transformer.layer.{std_idx}.output_layer_norm.{w}"] = state_dict[ + f"{prefix}.encoder.layer.{teacher_idx}.output.LayerNorm.{w}" + ] + std_idx += 1 + + compressed_sd["vocab_projector.weight"] = state_dict["cls.predictions.decoder.weight"] + compressed_sd["vocab_projector.bias"] = state_dict["cls.predictions.bias"] + if args.vocab_transform: + for w in ["weight", "bias"]: + compressed_sd[f"vocab_transform.{w}"] = state_dict[f"cls.predictions.transform.dense.{w}"] + compressed_sd[f"vocab_layer_norm.{w}"] = state_dict[f"cls.predictions.transform.LayerNorm.{w}"] + + print(f"N layers selected for distillation: {std_idx}") + print(f"Number of params transferred for distillation: {len(compressed_sd.keys())}") + + print(f"Save transferred checkpoint to {args.dump_checkpoint}.") + torch.save(compressed_sd, args.dump_checkpoint) diff --git a/Experiments/NLP/distillation/scripts/token_counts.py b/Experiments/NLP/distillation/scripts/token_counts.py new file mode 100644 index 0000000..2f80bf3 --- /dev/null +++ b/Experiments/NLP/distillation/scripts/token_counts.py @@ -0,0 +1,57 @@ +# coding=utf-8 +# Copyright 
2019-present, the HuggingFace Inc. team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" +Preprocessing script before training the distilled model. +""" + +import argparse +import logging +import pickle +from collections import Counter + + +logging.basicConfig( + format="%(asctime)s - %(levelname)s - %(name)s - %(message)s", datefmt="%m/%d/%Y %H:%M:%S", level=logging.INFO +) +logger = logging.getLogger(__name__) + +if __name__ == "__main__": + parser = argparse.ArgumentParser( + description="Token Counts for smoothing the masking probabilities in MLM (cf XLM/word2vec)" + ) + parser.add_argument( + "--data_file", type=str, default="data/dump.bert-base-uncased.pickle", help="The binarized dataset." + ) + parser.add_argument( + "--token_counts_dump", type=str, default="data/token_counts.bert-base-uncased.pickle", help="The dump file." + ) + parser.add_argument("--vocab_size", default=30522, type=int) + args = parser.parse_args() + + logger.info(f"Loading data from {args.data_file}") + with open(args.data_file, "rb") as fp: + data = pickle.load(fp) + + logger.info("Counting occurrences for MLM.") + counter = Counter() + for tk_ids in data: + counter.update(tk_ids) + counts = [0] * args.vocab_size + for k, v in counter.items(): + counts[k] = v + + logger.info(f"Dump to {args.token_counts_dump}") + with open(args.token_counts_dump, "wb") as handle: + pickle.dump(counts, handle, protocol=pickle.HIGHEST_PROTOCOL) diff --git a/Experiments/NLP/distillation/train.py b/Experiments/NLP/distillation/train.py new file mode 100644 index 0000000..15d98ac --- /dev/null +++ b/Experiments/NLP/distillation/train.py @@ -0,0 +1,325 @@ +# coding=utf-8 +# Copyright 2019-present, the HuggingFace Inc. team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" +Training the distilled model. +Supported architectures include: BERT -> DistilBERT, RoBERTa -> DistilRoBERTa, GPT2 -> DistilGPT2. 
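The counts written by `token_counts.py` above are consumed by `train.py` below, where they are turned into masking probabilities with the `--mlm_smoothing` exponent (0.7 by default): rarer tokens get proportionally more chance of being selected, and special tokens are excluded. A small worked example of that computation with made-up counts:

```python
import numpy as np

# Made-up counts for a 6-token vocabulary; index 0 plays the role of a special token.
counts = np.array([0, 1_000_000, 10_000, 1_000, 10, 1])

smoothing = 0.7                                    # --mlm_smoothing default
token_probs = np.maximum(counts, 1) ** -smoothing  # same formula as in train.py below
token_probs[0] = 0.0                               # never predict special tokens
token_probs = token_probs / token_probs.sum()      # normalized here only to read it as a distribution

print(token_probs.round(4))
```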
+""" + +import argparse +import json +import os +import pickle +import shutil + +import numpy as np +import torch +from distiller import Distiller +from lm_seqs_dataset import LmSeqsDataset + +from transformers import ( + BertConfig, + BertForMaskedLM, + BertTokenizer, + DistilBertConfig, + DistilBertForMaskedLM, + DistilBertTokenizer, + GPT2Config, + GPT2LMHeadModel, + GPT2Tokenizer, + RobertaConfig, + RobertaForMaskedLM, + RobertaTokenizer, +) +from utils import git_log, init_gpu_params, logger, set_seed + + +MODEL_CLASSES = { + "distilbert": (DistilBertConfig, DistilBertForMaskedLM, DistilBertTokenizer), + "roberta": (RobertaConfig, RobertaForMaskedLM, RobertaTokenizer), + "bert": (BertConfig, BertForMaskedLM, BertTokenizer), + "gpt2": (GPT2Config, GPT2LMHeadModel, GPT2Tokenizer), +} + + +def sanity_checks(args): + """ + A bunch of args sanity checks to perform even starting... + """ + assert (args.mlm and args.alpha_mlm > 0.0) or (not args.mlm and args.alpha_mlm == 0.0) + assert (args.alpha_mlm > 0.0 and args.alpha_clm == 0.0) or (args.alpha_mlm == 0.0 and args.alpha_clm > 0.0) + if args.mlm: + assert os.path.isfile(args.token_counts) + assert (args.student_type in ["roberta", "distilbert"]) and (args.teacher_type in ["roberta", "bert"]) + else: + assert (args.student_type in ["gpt2"]) and (args.teacher_type in ["gpt2"]) + + assert args.teacher_type == args.student_type or ( + args.student_type == "distilbert" and args.teacher_type == "bert" + ) + assert os.path.isfile(args.student_config) + if args.student_pretrained_weights is not None: + assert os.path.isfile(args.student_pretrained_weights) + + if args.freeze_token_type_embds: + assert args.student_type in ["roberta"] + + assert args.alpha_ce >= 0.0 + assert args.alpha_mlm >= 0.0 + assert args.alpha_clm >= 0.0 + assert args.alpha_mse >= 0.0 + assert args.alpha_cos >= 0.0 + assert args.alpha_ce + args.alpha_mlm + args.alpha_clm + args.alpha_mse + args.alpha_cos > 0.0 + + +def freeze_pos_embeddings(student, args): + if args.student_type == "roberta": + student.roberta.embeddings.position_embeddings.weight.requires_grad = False + elif args.student_type == "gpt2": + student.transformer.wpe.weight.requires_grad = False + + +def freeze_token_type_embeddings(student, args): + if args.student_type == "roberta": + student.roberta.embeddings.token_type_embeddings.weight.requires_grad = False + + +def main(): + parser = argparse.ArgumentParser(description="Training") + parser.add_argument("--force", action="store_true", help="Overwrite dump_path if it already exists.") + + parser.add_argument( + "--dump_path", type=str, required=True, help="The output directory (log, checkpoints, parameters, etc.)" + ) + parser.add_argument( + "--data_file", + type=str, + required=True, + help="The binarized file (tokenized + tokens_to_ids) and grouped by sequence.", + ) + + parser.add_argument( + "--student_type", + type=str, + choices=["distilbert", "roberta", "gpt2"], + required=True, + help="The student type (DistilBERT, RoBERTa).", + ) + parser.add_argument("--student_config", type=str, required=True, help="Path to the student configuration.") + parser.add_argument( + "--student_pretrained_weights", default=None, type=str, help="Load student initialization checkpoint." + ) + + parser.add_argument( + "--teacher_type", choices=["bert", "roberta", "gpt2"], required=True, help="Teacher type (BERT, RoBERTa)." 
+ ) + parser.add_argument("--teacher_name", type=str, required=True, help="The teacher model.") + + parser.add_argument("--temperature", default=2.0, type=float, help="Temperature for the softmax temperature.") + parser.add_argument( + "--alpha_ce", default=0.5, type=float, help="Linear weight for the distillation loss. Must be >=0." + ) + parser.add_argument( + "--alpha_mlm", + default=0.0, + type=float, + help="Linear weight for the MLM loss. Must be >=0. Should be used in conjunction with `mlm` flag.", + ) + parser.add_argument("--alpha_clm", default=0.5, type=float, help="Linear weight for the CLM loss. Must be >=0.") + parser.add_argument("--alpha_mse", default=0.0, type=float, help="Linear weight of the MSE loss. Must be >=0.") + parser.add_argument( + "--alpha_cos", default=0.0, type=float, help="Linear weight of the cosine embedding loss. Must be >=0." + ) + + parser.add_argument( + "--mlm", action="store_true", help="The LM step: MLM or CLM. If `mlm` is True, the MLM is used over CLM." + ) + parser.add_argument( + "--mlm_mask_prop", + default=0.15, + type=float, + help="Proportion of tokens for which we need to make a prediction.", + ) + parser.add_argument("--word_mask", default=0.8, type=float, help="Proportion of tokens to mask out.") + parser.add_argument("--word_keep", default=0.1, type=float, help="Proportion of tokens to keep.") + parser.add_argument("--word_rand", default=0.1, type=float, help="Proportion of tokens to randomly replace.") + parser.add_argument( + "--mlm_smoothing", + default=0.7, + type=float, + help="Smoothing parameter to emphasize more rare tokens (see XLM, similar to word2vec).", + ) + parser.add_argument("--token_counts", type=str, help="The token counts in the data_file for MLM.") + + parser.add_argument( + "--restrict_ce_to_mask", + action="store_true", + help="If true, compute the distillation loss only the [MLM] prediction distribution.", + ) + parser.add_argument( + "--freeze_pos_embs", + action="store_true", + help="Freeze positional embeddings during distillation. For student_type in ['roberta', 'gpt2'] only.", + ) + parser.add_argument( + "--freeze_token_type_embds", + action="store_true", + help="Freeze token type embeddings during distillation if existent. For student_type in ['roberta'] only.", + ) + + parser.add_argument("--n_epoch", type=int, default=3, help="Number of pass on the whole dataset.") + parser.add_argument("--batch_size", type=int, default=5, help="Batch size (for each process).") + parser.add_argument( + "--group_by_size", + action="store_false", + help="If true, group sequences that have similar length into the same batch. 
Default is true.", + ) + + parser.add_argument( + "--gradient_accumulation_steps", + type=int, + default=50, + help="Gradient accumulation for larger training batches.", + ) + parser.add_argument("--warmup_prop", default=0.05, type=float, help="Linear warmup proportion.") + parser.add_argument("--weight_decay", default=0.0, type=float, help="Weight decay if we apply some.") + parser.add_argument("--learning_rate", default=5e-4, type=float, help="The initial learning rate for Adam.") + parser.add_argument("--adam_epsilon", default=1e-6, type=float, help="Epsilon for Adam optimizer.") + parser.add_argument("--max_grad_norm", default=5.0, type=float, help="Max gradient norm.") + parser.add_argument("--initializer_range", default=0.02, type=float, help="Random initialization range.") + + parser.add_argument( + "--fp16", + action="store_true", + help="Whether to use 16-bit (mixed) precision (through NVIDIA apex) instead of 32-bit", + ) + parser.add_argument( + "--fp16_opt_level", + type=str, + default="O1", + help=( + "For fp16: Apex AMP optimization level selected in ['O0', 'O1', 'O2', and 'O3']. " + "See details at https://nvidia.github.io/apex/amp.html" + ), + ) + parser.add_argument("--n_gpu", type=int, default=1, help="Number of GPUs in the node.") + parser.add_argument("--local_rank", type=int, default=-1, help="Distributed training - Local rank") + parser.add_argument("--seed", type=int, default=56, help="Random seed") + + parser.add_argument("--log_interval", type=int, default=500, help="Tensorboard logging interval.") + parser.add_argument("--checkpoint_interval", type=int, default=4000, help="Checkpoint interval.") + args = parser.parse_args() + sanity_checks(args) + + # ARGS # + init_gpu_params(args) + set_seed(args) + if args.is_master: + if os.path.exists(args.dump_path): + if not args.force: + raise ValueError( + f"Serialization dir {args.dump_path} already exists, but you have not specified whether to overwrite" + " it. Use `--force` if you want to overwrite it." + ) + else: + shutil.rmtree(args.dump_path) + + if not os.path.exists(args.dump_path): + os.makedirs(args.dump_path) + logger.info(f"Experiment will be dumped and logged in {args.dump_path}") + + # SAVE PARAMS # + logger.info(f"Param: {args}") + with open(os.path.join(args.dump_path, "parameters.json"), "w") as f: + json.dump(vars(args), f, indent=4) + git_log(args.dump_path) + + student_config_class, student_model_class, _ = MODEL_CLASSES[args.student_type] + teacher_config_class, teacher_model_class, teacher_tokenizer_class = MODEL_CLASSES[args.teacher_type] + + # TOKENIZER # + tokenizer = teacher_tokenizer_class.from_pretrained(args.teacher_name) + special_tok_ids = {} + for tok_name, tok_symbol in tokenizer.special_tokens_map.items(): + idx = tokenizer.all_special_tokens.index(tok_symbol) + special_tok_ids[tok_name] = tokenizer.all_special_ids[idx] + logger.info(f"Special tokens {special_tok_ids}") + args.special_tok_ids = special_tok_ids + args.max_model_input_size = tokenizer.max_model_input_sizes[args.teacher_name] + + # DATA LOADER # + logger.info(f"Loading data from {args.data_file}") + with open(args.data_file, "rb") as fp: + data = pickle.load(fp) + + if args.mlm: + logger.info(f"Loading token counts from {args.token_counts} (already pre-computed)") + with open(args.token_counts, "rb") as fp: + counts = pickle.load(fp) + + token_probs = np.maximum(counts, 1) ** -args.mlm_smoothing + for idx in special_tok_ids.values(): + token_probs[idx] = 0.0 # do not predict special tokens + token_probs =
torch.from_numpy(token_probs) + else: + token_probs = None + + train_lm_seq_dataset = LmSeqsDataset(params=args, data=data) + logger.info("Data loader created.") + + # STUDENT # + logger.info(f"Loading student config from {args.student_config}") + stu_architecture_config = student_config_class.from_pretrained(args.student_config) + stu_architecture_config.output_hidden_states = True + + if args.student_pretrained_weights is not None: + logger.info(f"Loading pretrained weights from {args.student_pretrained_weights}") + student = student_model_class.from_pretrained(args.student_pretrained_weights, config=stu_architecture_config) + else: + student = student_model_class(stu_architecture_config) + + if args.n_gpu > 0: + student.to(f"cuda:{args.local_rank}") + logger.info("Student loaded.") + + # TEACHER # + teacher = teacher_model_class.from_pretrained(args.teacher_name, output_hidden_states=True) + if args.n_gpu > 0: + teacher.to(f"cuda:{args.local_rank}") + logger.info(f"Teacher loaded from {args.teacher_name}.") + + # FREEZING # + if args.freeze_pos_embs: + freeze_pos_embeddings(student, args) + if args.freeze_token_type_embds: + freeze_token_type_embeddings(student, args) + + # SANITY CHECKS # + assert student.config.vocab_size == teacher.config.vocab_size + assert student.config.hidden_size == teacher.config.hidden_size + assert student.config.max_position_embeddings == teacher.config.max_position_embeddings + if args.mlm: + assert token_probs.size(0) == stu_architecture_config.vocab_size + + # DISTILLER # + torch.cuda.empty_cache() + distiller = Distiller( + params=args, dataset=train_lm_seq_dataset, token_probs=token_probs, student=student, teacher=teacher + ) + distiller.train() + logger.info("Let's go get some drinks.") + + +if __name__ == "__main__": + main() diff --git a/Experiments/NLP/distillation/training_configs/distilbert-base-cased.json b/Experiments/NLP/distillation/training_configs/distilbert-base-cased.json new file mode 100644 index 0000000..d4f524d --- /dev/null +++ b/Experiments/NLP/distillation/training_configs/distilbert-base-cased.json @@ -0,0 +1,15 @@ +{ + "activation": "gelu", + "attention_dropout": 0.1, + "dim": 768, + "dropout": 0.1, + "hidden_dim": 3072, + "initializer_range": 0.02, + "max_position_embeddings": 512, + "n_heads": 12, + "n_layers": 6, + "sinusoidal_pos_embds": true, + "tie_weights_": true, + "vocab_size": 28996 + } + \ No newline at end of file diff --git a/Experiments/NLP/distillation/training_configs/distilbert-base-multilingual-cased.json b/Experiments/NLP/distillation/training_configs/distilbert-base-multilingual-cased.json new file mode 100644 index 0000000..f76e7fe --- /dev/null +++ b/Experiments/NLP/distillation/training_configs/distilbert-base-multilingual-cased.json @@ -0,0 +1,15 @@ +{ + "activation": "gelu", + "attention_dropout": 0.1, + "dim": 768, + "dropout": 0.1, + "hidden_dim": 3072, + "initializer_range": 0.02, + "max_position_embeddings": 512, + "n_heads": 12, + "n_layers": 6, + "sinusoidal_pos_embds": true, + "tie_weights_": true, + "vocab_size": 119547 + } + \ No newline at end of file diff --git a/Experiments/NLP/distillation/training_configs/distilbert-base-uncased.json b/Experiments/NLP/distillation/training_configs/distilbert-base-uncased.json new file mode 100644 index 0000000..15d1e7f --- /dev/null +++ b/Experiments/NLP/distillation/training_configs/distilbert-base-uncased.json @@ -0,0 +1,15 @@ +{ + "activation": "gelu", + "attention_dropout": 0.1, + "dim": 768, + "dropout": 0.1, + "hidden_dim": 3072, + "initializer_range": 
0.02, + "max_position_embeddings": 512, + "n_heads": 12, + "n_layers": 6, + "sinusoidal_pos_embds": true, + "tie_weights_": true, + "vocab_size": 30522 + } + \ No newline at end of file diff --git a/Experiments/NLP/distillation/training_configs/distilgpt2.json b/Experiments/NLP/distillation/training_configs/distilgpt2.json new file mode 100644 index 0000000..9820ac9 --- /dev/null +++ b/Experiments/NLP/distillation/training_configs/distilgpt2.json @@ -0,0 +1,9 @@ +{ + "initializer_range": 0.02, + "layer_norm_epsilon": 0.00001, + "n_embd": 768, + "n_head": 12, + "n_layer": 6, + "n_positions": 1024, + "vocab_size": 50257 +} \ No newline at end of file diff --git a/Experiments/NLP/distillation/training_configs/distilroberta-base.json b/Experiments/NLP/distillation/training_configs/distilroberta-base.json new file mode 100644 index 0000000..2d90ef6 --- /dev/null +++ b/Experiments/NLP/distillation/training_configs/distilroberta-base.json @@ -0,0 +1,14 @@ +{ + "vocab_size": 50265, + "hidden_size": 768, + "num_hidden_layers": 6, + "num_attention_heads": 12, + "intermediate_size": 3072, + "hidden_act": "gelu", + "hidden_dropout_prob": 0.1, + "attention_probs_dropout_prob": 0.1, + "max_position_embeddings": 514, + "type_vocab_size": 1, + "initializer_range": 0.02, + "layer_norm_eps": 0.00001 +} \ No newline at end of file diff --git a/Experiments/NLP/distillation/utils.py b/Experiments/NLP/distillation/utils.py new file mode 100644 index 0000000..e86d259 --- /dev/null +++ b/Experiments/NLP/distillation/utils.py @@ -0,0 +1,134 @@ +# coding=utf-8 +# Copyright 2019-present, the HuggingFace Inc. team and Facebook, Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Utils to train DistilBERT +adapted in part from Facebook, Inc XLM model (https://github.com/facebookresearch/XLM) +""" + +import json +import logging +import os +import socket + +import git +import numpy as np +import torch + + +logging.basicConfig( + format="%(asctime)s - %(levelname)s - %(name)s - PID: %(process)d - %(message)s", + datefmt="%m/%d/%Y %H:%M:%S", + level=logging.INFO, +) +logger = logging.getLogger(__name__) + + +def git_log(folder_path: str): + """ + Log commit info. + """ + repo = git.Repo(search_parent_directories=True) + repo_infos = { + "repo_id": str(repo), + "repo_sha": str(repo.head.object.hexsha), + "repo_branch": str(repo.active_branch), + } + + with open(os.path.join(folder_path, "git_log.json"), "w") as f: + json.dump(repo_infos, f, indent=4) + + +def init_gpu_params(params): + """ + Handle single and multi-GPU / multi-node. 
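The JSON files above are the student architectures that `train.py` loads via `from_pretrained` on the config class. As a small sketch of what that amounts to for the uncased DistilBERT student (path as in this repo; assumes `transformers` is installed and the config file is present):

```python
from transformers import DistilBertConfig, DistilBertForMaskedLM

config = DistilBertConfig.from_pretrained("training_configs/distilbert-base-uncased.json")
config.output_hidden_states = True  # train.py sets this before building the student

student = DistilBertForMaskedLM(config)  # randomly initialized 6-layer student
print(f"{sum(p.numel() for p in student.parameters()) / 1e6:.1f}M parameters")
```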
+ """ + if params.n_gpu <= 0: + params.local_rank = 0 + params.master_port = -1 + params.is_master = True + params.multi_gpu = False + return + + assert torch.cuda.is_available() + + logger.info("Initializing GPUs") + if params.n_gpu > 1: + assert params.local_rank != -1 + + params.world_size = int(os.environ["WORLD_SIZE"]) + params.n_gpu_per_node = int(os.environ["N_GPU_NODE"]) + params.global_rank = int(os.environ["RANK"]) + + # number of nodes / node ID + params.n_nodes = params.world_size // params.n_gpu_per_node + params.node_id = params.global_rank // params.n_gpu_per_node + params.multi_gpu = True + + assert params.n_nodes == int(os.environ["N_NODES"]) + assert params.node_id == int(os.environ["NODE_RANK"]) + + # local job (single GPU) + else: + assert params.local_rank == -1 + + params.n_nodes = 1 + params.node_id = 0 + params.local_rank = 0 + params.global_rank = 0 + params.world_size = 1 + params.n_gpu_per_node = 1 + params.multi_gpu = False + + # sanity checks + assert params.n_nodes >= 1 + assert 0 <= params.node_id < params.n_nodes + assert 0 <= params.local_rank <= params.global_rank < params.world_size + assert params.world_size == params.n_nodes * params.n_gpu_per_node + + # define whether this is the master process / if we are in multi-node distributed mode + params.is_master = params.node_id == 0 and params.local_rank == 0 + params.multi_node = params.n_nodes > 1 + + # summary + PREFIX = f"--- Global rank: {params.global_rank} - " + logger.info(PREFIX + "Number of nodes: %i" % params.n_nodes) + logger.info(PREFIX + "Node ID : %i" % params.node_id) + logger.info(PREFIX + "Local rank : %i" % params.local_rank) + logger.info(PREFIX + "World size : %i" % params.world_size) + logger.info(PREFIX + "GPUs per node : %i" % params.n_gpu_per_node) + logger.info(PREFIX + "Master : %s" % str(params.is_master)) + logger.info(PREFIX + "Multi-node : %s" % str(params.multi_node)) + logger.info(PREFIX + "Multi-GPU : %s" % str(params.multi_gpu)) + logger.info(PREFIX + "Hostname : %s" % socket.gethostname()) + + # set GPU device + torch.cuda.set_device(params.local_rank) + + # initialize multi-GPU + if params.multi_gpu: + logger.info("Initializing PyTorch distributed") + torch.distributed.init_process_group( + init_method="env://", + backend="nccl", + ) + + +def set_seed(args): + """ + Set the random seed. + """ + np.random.seed(args.seed) + torch.manual_seed(args.seed) + if args.n_gpu > 0: + torch.cuda.manual_seed_all(args.seed)