diff --git a/Experiments/NLP/distillation/README.md b/Experiments/NLP/distillation/README.md
new file mode 100644
index 0000000..b52ebd0
--- /dev/null
+++ b/Experiments/NLP/distillation/README.md
@@ -0,0 +1,163 @@
+# Distil*
+
+This folder contains the original code used to train Distil* as well as examples showcasing how to use DistilBERT, DistilRoBERTa and DistilGPT2.
+
+## What is Distil*
+
+Distil* is a class of compressed models that started with DistilBERT. DistilBERT stands for Distilled-BERT. DistilBERT is a small, fast, cheap and light Transformer model based on the BERT architecture. It has 40% fewer parameters than `bert-base-uncased` and runs 60% faster while preserving 97% of BERT's performance as measured on the GLUE language understanding benchmark. DistilBERT is trained using knowledge distillation, a technique that compresses a large model (the teacher) into a smaller model (the student). By distilling BERT, we obtain a smaller Transformer model that bears many similarities with the original BERT model while being lighter, smaller and faster to run. DistilBERT is thus an interesting option for putting large-scale trained Transformer models into production.
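+
+The training code in this folder (`distiller.py`, `train.py`) implements this teacher/student setup. As a quick illustration of the core idea, here is a minimal sketch of a temperature-scaled distillation objective; the function name, the temperature of 2.0 and the loss weights below are illustrative assumptions, not the exact recipe used to train the released checkpoints:
+
+```python
+import torch.nn.functional as F
+
+
+def distillation_loss(student_logits, teacher_logits, labels, temperature=2.0, alpha_ce=0.5, alpha_task=0.5):
+    """Sketch of a distillation objective: soft-target KL divergence + hard-label cross-entropy."""
+    # Soften both distributions with the temperature, then match them with KL divergence.
+    soft_student = F.log_softmax(student_logits / temperature, dim=-1)
+    soft_teacher = F.softmax(teacher_logits / temperature, dim=-1)
+    loss_kd = F.kl_div(soft_student, soft_teacher, reduction="batchmean") * temperature**2
+    # Standard supervised loss on the hard labels.
+    loss_task = F.cross_entropy(student_logits, labels)
+    return alpha_ce * loss_kd + alpha_task * loss_task
+```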
+
+We have applied the same method to other Transformer architectures and released the weights:
+- GPT2: on the [WikiText-103](https://blog.einstein.ai/the-wikitext-long-term-dependency-language-modeling-dataset/) benchmark, GPT2 reaches a perplexity on the test set of 16.3 compared to 21.1 for **DistilGPT2** (after fine-tuning on the train set).
+- RoBERTa: **DistilRoBERTa** reaches 95% of `RoBERTa-base`'s performance on GLUE while being twice as fast and 35% smaller.
+- German BERT: **German DistilBERT** reaches 99% of `bert-base-german-dbmdz-cased`'s performance on German NER (CoNLL-2003).
+- Multilingual BERT: **DistilmBERT** reaches 92% of Multilingual BERT's performance on XNLI while being twice as fast and 25% smaller. The model supports 104 languages listed [here](https://github.com/google-research/bert/blob/master/multilingual.md#list-of-languages).
+
+For more information on DistilBERT, please refer to our [NeurIPS workshop paper](https://arxiv.org/abs/1910.01108).
+
+Here are the results on the dev sets of GLUE:
+
+| Model | Macro-score | CoLA | MNLI | MRPC | QNLI | QQP | RTE | SST-2| STS-B| WNLI |
+| :---: | :---: | :---:| :---:| :---:| :---:| :---:| :---:| :---:| :---:| :---: |
+| BERT-base-uncased | **79.5** | 56.3 | 84.7 | 88.6 | 91.8 | 89.6 | 69.3 | 92.7 | 89.0 | 53.5 |
+| DistilBERT-base-uncased | **77.0** | 51.3 | 82.1 | 87.5 | 89.2 | 88.5 | 59.9 | 91.3 | 86.9 | 56.3 |
+| BERT-base-cased | **78.2** | 58.2 | 83.9 | 87.8 | 91.0 | 89.2 | 66.1 | 91.7 | 89.2 | 46.5 |
+| DistilBERT-base-cased | **75.9** | 47.2 | 81.5 | 85.6 | 88.2 | 87.8 | 60.6 | 90.4 | 85.5 | 56.3 |
+| --- | --- | --- | --- | --- | --- | --- | --- | --- | --- | --- |
+| RoBERTa-base (reported) | **83.2**/**86.4**² | 63.6 | 87.6 | 90.2 | 92.8 | 91.9 | 78.7 | 94.8 | 91.2 | 57.7³ |
+| DistilRoBERTa¹ | **79.0**/**82.3**² | 59.3 | 84.0 | 86.6 | 90.8 | 89.4 | 67.9 | 92.5 | 88.3 | 52.1 |
+
+¹ We did not use the MNLI checkpoint for fine-tuning but directly performed transfer learning on the pre-trained DistilRoBERTa.
+
+² Macro-score computed without WNLI.
+
+³ We computed this score ourselves for completeness.
+
+Here are the results on the *test* sets for 6 of the languages available in XNLI. The results are computed in the zero-shot setting (trained on the English portion and evaluated on the target language portion):
+
+| Model | English | Spanish | Chinese | German | Arabic | Urdu |
+| :---: | :---: | :---: | :---: | :---: | :---: | :---:|
+| mBERT base cased (computed) | 82.1 | 74.6 | 69.1 | 72.3 | 66.4 | 58.5 |
+| mBERT base uncased (reported)| 81.4 | 74.3 | 63.8 | 70.5 | 62.1 | 58.3 |
+| DistilmBERT | 78.2 | 69.1 | 64.0 | 66.3 | 59.1 | 54.7 |
+
+## Setup
+
+This part of the library has only been tested with Python 3.6+. There are a few specific dependencies to install before launching a distillation; you can install them with the command `pip install -r requirements.txt`.
+
+**Important note:** The training scripts have been updated to support PyTorch v1.2.0 (there are breaking changes compared to v1.1.0).
+
+
+## How to use DistilBERT
+
+Transformers includes the following pre-trained Distil* models, covering English, German and multilingual checkpoints:
+
+- `distilbert-base-uncased`: DistilBERT English language model pretrained on the same data used to pretrain BERT (concatenation of the Toronto Book Corpus and full English Wikipedia) using distillation with the supervision of the `bert-base-uncased` version of BERT. The model has 6 layers, a hidden size of 768 and 12 heads, totaling 66M parameters.
+- `distilbert-base-uncased-distilled-squad`: A version of `distilbert-base-uncased` fine-tuned using (a second step of) knowledge distillation on SQuAD 1.0. This model reaches an F1 score of 86.9 on the dev set (for comparison, the `bert-base-uncased` version of BERT reaches an F1 score of 88.5).
+- `distilbert-base-cased`: DistilBERT English language model pretrained on the same data used to pretrain BERT (concatenation of the Toronto Book Corpus and full English Wikipedia) using distillation with the supervision of the `bert-base-cased` version of BERT. The model has 6 layers, a hidden size of 768 and 12 heads, totaling 65M parameters.
+- `distilbert-base-cased-distilled-squad`: A version of `distilbert-base-cased` fine-tuned using (a second step of) knowledge distillation on SQuAD 1.0. This model reaches an F1 score of 87.1 on the dev set (for comparison, the `bert-base-cased` version of BERT reaches an F1 score of 88.7).
+- `distilbert-base-german-cased`: DistilBERT German language model pretrained on half of the data used to pretrain BERT, using distillation with the supervision of the `bert-base-german-dbmdz-cased` version of German DBMDZ BERT. On NER tasks the model reaches an F1 score of 83.49 on the CoNLL-2003 test set (for comparison, `bert-base-german-dbmdz-cased` reaches an F1 score of 84.52), and an F1 score of 85.23 on the GermEval 2014 test set (`bert-base-german-dbmdz-cased` reaches an F1 score of 86.89).
+- `distilgpt2`: DistilGPT2 English language model pretrained with the supervision of `gpt2` (the smallest version of GPT2) on [OpenWebTextCorpus](https://skylion007.github.io/OpenWebTextCorpus/), a reproduction of OpenAI's WebText dataset. The model has 6 layers, a hidden size of 768 and 12 heads, totaling 82M parameters (compared to 124M parameters for GPT2). On average, DistilGPT2 is twice as fast as GPT2.
+- `distilroberta-base`: DistilRoBERTa English language model pretrained with the supervision of `roberta-base` solely on [OpenWebTextCorpus](https://skylion007.github.io/OpenWebTextCorpus/), a reproduction of OpenAI's WebText dataset (roughly 4 times less training data than the teacher RoBERTa was trained on). The model has 6 layers, a hidden size of 768 and 12 heads, totaling 82M parameters (compared to 125M parameters for RoBERTa-base). On average, DistilRoBERTa is twice as fast as RoBERTa-base.
+- `distilbert-base-multilingual-cased`: DistilmBERT multilingual model pretrained with the supervision of `bert-base-multilingual-cased` on the concatenation of Wikipedia in 104 different languages. The model supports the 104 languages listed [here](https://github.com/google-research/bert/blob/master/multilingual.md#list-of-languages). The model has 6 layers, a hidden size of 768 and 12 heads, totaling 134M parameters (compared to 177M parameters for mBERT-base). On average, DistilmBERT is twice as fast as mBERT-base.
+
+Using DistilBERT is very similar to using BERT. DistilBERT shares the same tokenizer as BERT's `bert-base-uncased`, even though we expose this tokenizer under the `DistilBertTokenizer` name to keep naming consistent across the library's models.
+
+```python
+import torch
+from transformers import DistilBertModel, DistilBertTokenizer
+
+tokenizer = DistilBertTokenizer.from_pretrained('distilbert-base-cased')
+model = DistilBertModel.from_pretrained('distilbert-base-cased')
+
+input_ids = torch.tensor(tokenizer.encode("Hello, my dog is cute")).unsqueeze(0)
+outputs = model(input_ids)
+last_hidden_states = outputs[0] # The last hidden-state is the first element of the output tuple
+```
+
+Similarly, using the other Distil* models simply amounts to calling the base classes with a different pretrained checkpoint:
+- DistilBERT uncased: `model = DistilBertModel.from_pretrained('distilbert-base-uncased')`
+- DistilGPT2: `model = GPT2Model.from_pretrained('distilgpt2')`
+- DistilRoBERTa: `model = RobertaModel.from_pretrained('distilroberta-base')`
+- DistilmBERT: `model = DistilBertModel.from_pretrained('distilbert-base-multilingual-cased')`
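+
+The checkpoints distilled on SQuAD can be used for extractive question answering in the same way. The snippet below is a usage sketch; the example question, context and the greedy span selection are illustrative and not part of this repository:
+
+```python
+import torch
+from transformers import DistilBertForQuestionAnswering, DistilBertTokenizer
+
+tokenizer = DistilBertTokenizer.from_pretrained('distilbert-base-cased-distilled-squad')
+model = DistilBertForQuestionAnswering.from_pretrained('distilbert-base-cased-distilled-squad')
+
+question, context = "What does DistilBERT distill?", "DistilBERT is distilled from the BERT model."
+inputs = tokenizer.encode_plus(question, context, return_tensors='pt')
+start_logits, end_logits = model(input_ids=inputs['input_ids'], attention_mask=inputs['attention_mask'])[:2]
+
+# Greedily pick the most likely answer span and decode it back to text.
+start = torch.argmax(start_logits)
+end = torch.argmax(end_logits) + 1
+answer = tokenizer.decode(inputs['input_ids'][0][start:end])
+```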
+
+
+## How to train Distil*
+
+In the following, we will explain how you can train DistilBERT.
+
+### A. Preparing the data
+
+The weights we release are trained using a concatenation of Toronto Book Corpus and English Wikipedia (same training data as the English version of BERT).
+
+To avoid processing the data several times, we do it once and for all before training. From now on, we will assume that you have a text file `dump.txt` which contains one sequence per line (a sequence being composed of one or several coherent sentences).
+
+First, we will binarize the data, i.e. tokenize the data and convert each token into an index in our model's vocabulary.
+
+```bash
+python scripts/binarized_data.py \
+ --file_path data/dump.txt \
+ --tokenizer_type bert \
+ --tokenizer_name bert-base-uncased \
+ --dump_file data/binarized_text
+```
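+
+Conceptually, this binarization step boils down to the sketch below (simplified; refer to `scripts/binarized_data.py` for the actual logic and output format, which this sketch only approximates):
+
+```python
+import pickle
+
+from transformers import BertTokenizer
+
+tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")
+
+# Tokenize every line of the dump and map the tokens to vocabulary indices.
+sequences = []
+with open("data/dump.txt", encoding="utf-8") as f:
+    for line in f:
+        line = line.strip()
+        if line:
+            # `encode` adds the special tokens ([CLS]/[SEP]) that the training scripts expect.
+            sequences.append(tokenizer.encode(line))
+
+with open("data/binarized_text.bert-base-uncased.pickle", "wb") as f:
+    pickle.dump(sequences, f)
+```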
+
+Our implementation of the masked language modeling loss follows [XLM](https://github.com/facebookresearch/XLM)'s and smooths the masking probability with a factor that puts more emphasis on rare words. We therefore count the occurrences of each token in the data:
+
+```bash
+python scripts/token_counts.py \
+ --data_file data/binarized_text.bert-base-uncased.pickle \
+ --token_counts_dump data/token_counts.bert-base-uncased.pickle \
+ --vocab_size 30522
+```
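+
+In essence, this step produces a vocabulary-sized table of occurrence counts that the distiller later turns into a smoothed masking distribution (the `token_probs` passed to the `Distiller`). A simplified sketch, assuming the binarized pickle from the previous step:
+
+```python
+import pickle
+
+VOCAB_SIZE = 30522  # bert-base-uncased vocabulary size
+
+with open("data/binarized_text.bert-base-uncased.pickle", "rb") as f:
+    data = pickle.load(f)
+
+# One counter cell per vocabulary entry: rare tokens end up with low counts and
+# therefore receive a relatively higher masking probability after smoothing.
+counts = [0] * VOCAB_SIZE
+for sequence in data:
+    for token_id in sequence:
+        counts[token_id] += 1
+
+with open("data/token_counts.bert-base-uncased.pickle", "wb") as f:
+    pickle.dump(counts, f)
+```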
+
+### B. Training
+
+Training with distillation is really simple once you have pre-processed the data:
+
+```bash
+python train.py \
+ --student_type distilbert \
+ --student_config training_configs/distilbert-base-uncased.json \
+ --teacher_type bert \
+ --teacher_name bert-base-uncased \
+ --alpha_ce 5.0 --alpha_mlm 2.0 --alpha_cos 1.0 --alpha_clm 0.0 --mlm \
+ --freeze_pos_embs \
+ --dump_path serialization_dir/my_first_training \
+ --data_file data/binarized_text.bert-base-uncased.pickle \
+ --token_counts data/token_counts.bert-base-uncased.pickle \
+ --force # overwrites the `dump_path` if it already exists.
+```
+
+By default, this will launch training on a single GPU (even if more are available on the cluster). Other parameters are available on the command line; look in `train.py` or run `python train.py --help` to list them.
+
+We highly encourage you to use distributed training for training DistilBERT as the training corpus is quite large. Here's an example that runs distributed training on a single node with 4 GPUs:
+
+```bash
+export NODE_RANK=0
+export N_NODES=1
+
+export N_GPU_NODE=4
+export WORLD_SIZE=4
+export MASTER_PORT=
+export MASTER_ADDR=
+
+pkill -f 'python -u train.py'
+
+python -m torch.distributed.launch \
+ --nproc_per_node=$N_GPU_NODE \
+ --nnodes=$N_NODES \
+ --node_rank $NODE_RANK \
+ --master_addr $MASTER_ADDR \
+ --master_port $MASTER_PORT \
+ train.py \
+ --force \
+ --n_gpu $WORLD_SIZE \
+ --student_type distilbert \
+ --student_config training_configs/distilbert-base-uncased.json \
+ --teacher_type bert \
+ --teacher_name bert-base-uncased \
+ --alpha_ce 0.33 --alpha_mlm 0.33 --alpha_cos 0.33 --alpha_clm 0.0 --mlm \
+ --freeze_pos_embs \
+ --dump_path serialization_dir/my_first_training \
+ --data_file data/binarized_text.bert-base-uncased.pickle \
+ --token_counts data/token_counts.bert-base-uncased.pickle
+```
+
+**Tip:** Starting the distilled training with a good initialization of the model weights is crucial to reach decent performance. In our experiments, we initialized our model from a few layers of the teacher (BERT) itself! Please refer to `scripts/extract.py` and `scripts/extract_distilbert.py` to create a valid initialization checkpoint, and use the `--student_pretrained_weights` argument to use this initialization for the distilled training!
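+
+At a high level, those extraction scripts copy the teacher's embeddings and a subset of its Transformer layers into a checkpoint whose parameter names match the student. The sketch below only illustrates that idea: the chosen layer indices and the naive name handling are assumptions for illustration, and the real mapping (including the renaming to the DistilBERT architecture) lives in `scripts/extract_distilbert.py`:
+
+```python
+import torch
+
+from transformers import BertForMaskedLM
+
+# Load the teacher and keep only the weights the student should start from.
+teacher = BertForMaskedLM.from_pretrained("bert-base-uncased")
+teacher_state = teacher.state_dict()
+
+kept_layers = [0, 2, 4, 6, 8, 10]  # illustrative: one teacher layer out of two for a 6-layer student
+student_init = {}
+for name, param in teacher_state.items():
+    # Always keep the embeddings and the MLM head.
+    if "embeddings" in name or name.startswith("cls."):
+        student_init[name] = param
+        continue
+    # Keep encoder parameters only for the selected layers, renumbered to consecutive indices 0..5.
+    for new_idx, old_idx in enumerate(kept_layers):
+        prefix = f"bert.encoder.layer.{old_idx}."
+        if name.startswith(prefix):
+            student_init[name.replace(prefix, f"bert.encoder.layer.{new_idx}.")] = param
+
+torch.save(student_init, "extracted_student_init.pth")
+```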
diff --git a/Experiments/NLP/distillation/distiller.py b/Experiments/NLP/distillation/distiller.py
new file mode 100644
index 0000000..963af97
--- /dev/null
+++ b/Experiments/NLP/distillation/distiller.py
@@ -0,0 +1,601 @@
+# coding=utf-8
+# Copyright 2019-present, the HuggingFace Inc. team and Facebook, Inc.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""The distiller to distil the student.
+Adapted in part from Facebook, Inc XLM model (https://github.com/facebookresearch/XLM)
+"""
+
+import math
+import os
+import time
+
+import psutil
+import torch
+from grouped_batch_sampler import GroupedBatchSampler, create_lengths_groups
+from lm_seqs_dataset import LmSeqsDataset
+from torch import nn
+from torch.optim import AdamW
+from torch.utils.data import BatchSampler, DataLoader, RandomSampler
+from torch.utils.data.distributed import DistributedSampler
+from tqdm import tqdm
+
+from transformers import get_linear_schedule_with_warmup
+from utils import logger
+
+
+try:
+ from torch.utils.tensorboard import SummaryWriter
+except ImportError:
+ from tensorboardX import SummaryWriter
+
+
+class Distiller:
+ def __init__(
+ self, params: dict, dataset: LmSeqsDataset, token_probs: torch.tensor, student: nn.Module, teacher: nn.Module
+ ):
+ logger.info("Initializing Distiller")
+ self.params = params
+ self.dump_path = params.dump_path
+ self.multi_gpu = params.multi_gpu
+ self.fp16 = params.fp16
+
+ self.student = student
+ self.teacher = teacher
+
+ self.student_config = student.config
+ self.vocab_size = student.config.vocab_size
+
+ if params.n_gpu <= 1:
+ sampler = RandomSampler(dataset)
+ else:
+ sampler = DistributedSampler(dataset)
+
+ if params.group_by_size:
+ groups = create_lengths_groups(lengths=dataset.lengths, k=params.max_model_input_size)
+ sampler = GroupedBatchSampler(sampler=sampler, group_ids=groups, batch_size=params.batch_size)
+ else:
+ sampler = BatchSampler(sampler=sampler, batch_size=params.batch_size, drop_last=False)
+
+ self.dataloader = DataLoader(dataset=dataset, batch_sampler=sampler, collate_fn=dataset.batch_sequences)
+
+ self.temperature = params.temperature
+ assert self.temperature > 0.0
+
+ self.alpha_ce = params.alpha_ce
+ self.alpha_mlm = params.alpha_mlm
+ self.alpha_clm = params.alpha_clm
+ self.alpha_mse = params.alpha_mse
+ self.alpha_cos = params.alpha_cos
+
+ self.mlm = params.mlm
+ if self.mlm:
+ logger.info("Using MLM loss for LM step.")
+ self.mlm_mask_prop = params.mlm_mask_prop
+ assert 0.0 <= self.mlm_mask_prop <= 1.0
+ assert params.word_mask + params.word_keep + params.word_rand == 1.0
+ self.pred_probs = torch.FloatTensor([params.word_mask, params.word_keep, params.word_rand])
+ self.pred_probs = self.pred_probs.to(f"cuda:{params.local_rank}") if params.n_gpu > 0 else self.pred_probs
+ self.token_probs = token_probs.to(f"cuda:{params.local_rank}") if params.n_gpu > 0 else token_probs
+ if self.fp16:
+ self.pred_probs = self.pred_probs.half()
+ self.token_probs = self.token_probs.half()
+ else:
+ logger.info("Using CLM loss for LM step.")
+
+ self.epoch = 0
+ self.n_iter = 0
+ self.n_total_iter = 0
+ self.n_sequences_epoch = 0
+ self.total_loss_epoch = 0
+ self.last_loss = 0
+ self.last_loss_ce = 0
+ self.last_loss_mlm = 0
+ self.last_loss_clm = 0
+ if self.alpha_mse > 0.0:
+ self.last_loss_mse = 0
+ if self.alpha_cos > 0.0:
+ self.last_loss_cos = 0
+ self.last_log = 0
+
+ self.ce_loss_fct = nn.KLDivLoss(reduction="batchmean")
+ self.lm_loss_fct = nn.CrossEntropyLoss(ignore_index=-100)
+ if self.alpha_mse > 0.0:
+ self.mse_loss_fct = nn.MSELoss(reduction="sum")
+ if self.alpha_cos > 0.0:
+ self.cosine_loss_fct = nn.CosineEmbeddingLoss(reduction="mean")
+
+ logger.info("--- Initializing model optimizer")
+ assert params.gradient_accumulation_steps >= 1
+ self.num_steps_epoch = len(self.dataloader)
+ num_train_optimization_steps = (
+ int(self.num_steps_epoch / params.gradient_accumulation_steps * params.n_epoch) + 1
+ )
+
+ no_decay = ["bias", "LayerNorm.weight"]
+ optimizer_grouped_parameters = [
+ {
+ "params": [
+ p for n, p in student.named_parameters() if not any(nd in n for nd in no_decay) and p.requires_grad
+ ],
+ "weight_decay": params.weight_decay,
+ },
+ {
+ "params": [
+ p for n, p in student.named_parameters() if any(nd in n for nd in no_decay) and p.requires_grad
+ ],
+ "weight_decay": 0.0,
+ },
+ ]
+ logger.info(
+ "------ Number of trainable parameters (student): %i"
+ % sum([p.numel() for p in self.student.parameters() if p.requires_grad])
+ )
+ logger.info("------ Number of parameters (student): %i" % sum([p.numel() for p in self.student.parameters()]))
+ self.optimizer = AdamW(
+ optimizer_grouped_parameters, lr=params.learning_rate, eps=params.adam_epsilon, betas=(0.9, 0.98)
+ )
+
+ warmup_steps = math.ceil(num_train_optimization_steps * params.warmup_prop)
+ self.scheduler = get_linear_schedule_with_warmup(
+ self.optimizer, num_warmup_steps=warmup_steps, num_training_steps=num_train_optimization_steps
+ )
+
+ if self.fp16:
+ try:
+ from apex import amp
+ except ImportError:
+ raise ImportError("Please install apex from https://www.github.com/nvidia/apex to use fp16 training.")
+ logger.info(f"Using fp16 training: {self.params.fp16_opt_level} level")
+ self.student, self.optimizer = amp.initialize(
+ self.student, self.optimizer, opt_level=self.params.fp16_opt_level
+ )
+ self.teacher = self.teacher.half()
+
+ if self.multi_gpu:
+ if self.fp16:
+ from apex.parallel import DistributedDataParallel
+
+ logger.info("Using apex.parallel.DistributedDataParallel for distributed training.")
+ self.student = DistributedDataParallel(self.student)
+ else:
+ from torch.nn.parallel import DistributedDataParallel
+
+ logger.info("Using nn.parallel.DistributedDataParallel for distributed training.")
+ self.student = DistributedDataParallel(
+ self.student,
+ device_ids=[params.local_rank],
+ output_device=params.local_rank,
+ find_unused_parameters=True,
+ )
+
+ self.is_master = params.is_master
+ if self.is_master:
+ logger.info("--- Initializing Tensorboard")
+ self.tensorboard = SummaryWriter(log_dir=os.path.join(self.dump_path, "log", "train"))
+ self.tensorboard.add_text(tag="config/training", text_string=str(self.params), global_step=0)
+ self.tensorboard.add_text(tag="config/student", text_string=str(self.student_config), global_step=0)
+
+ def prepare_batch_mlm(self, batch):
+ """
+ Prepare the batch: from the token_ids and the lengths, compute the attention mask and the masked label for MLM.
+
+ Input:
+ ------
+ batch: `Tuple`
+ token_ids: `torch.tensor(bs, seq_length)` - The token ids for each of the sequence. It is padded.
+ lengths: `torch.tensor(bs)` - The lengths of each of the sequences in the batch.
+
+ Output:
+ -------
+ token_ids: `torch.tensor(bs, seq_length)` - The token ids after the modifications for MLM.
+ attn_mask: `torch.tensor(bs, seq_length)` - The attention mask for the self-attention.
+ mlm_labels: `torch.tensor(bs, seq_length)` - The masked language modeling labels. There is a -100 where there is nothing to predict.
+ """
+ token_ids, lengths = batch
+ token_ids, lengths = self.round_batch(x=token_ids, lengths=lengths)
+ assert token_ids.size(0) == lengths.size(0)
+
+ attn_mask = torch.arange(token_ids.size(1), dtype=torch.long, device=lengths.device) < lengths[:, None]
+
+ bs, max_seq_len = token_ids.size()
+ mlm_labels = token_ids.new(token_ids.size()).copy_(token_ids)
+
+ x_prob = self.token_probs[token_ids.flatten()]
+ n_tgt = math.ceil(self.mlm_mask_prop * lengths.sum().item())
+ tgt_ids = torch.multinomial(x_prob / x_prob.sum(), n_tgt, replacement=False)
+ pred_mask = torch.zeros(
+ bs * max_seq_len, dtype=torch.bool, device=token_ids.device
+ ) # previously `dtype=torch.uint8`, cf pytorch 1.2.0 compatibility
+ pred_mask[tgt_ids] = 1
+ pred_mask = pred_mask.view(bs, max_seq_len)
+
+ pred_mask[token_ids == self.params.special_tok_ids["pad_token"]] = 0
+
+        # with fp16, keep the number of masked tokens a multiple of 8 (tensor cores are faster on multiples of 8)
+ if self.fp16:
+ n1 = pred_mask.sum().item()
+ if n1 > 8:
+ pred_mask = pred_mask.view(-1)
+ n2 = max(n1 % 8, 8 * (n1 // 8))
+ if n2 != n1:
+ pred_mask[torch.nonzero(pred_mask).view(-1)[: n1 - n2]] = 0
+ pred_mask = pred_mask.view(bs, max_seq_len)
+ assert pred_mask.sum().item() % 8 == 0, pred_mask.sum().item()
+
+ _token_ids_real = token_ids[pred_mask]
+ _token_ids_rand = _token_ids_real.clone().random_(self.vocab_size)
+ _token_ids_mask = _token_ids_real.clone().fill_(self.params.special_tok_ids["mask_token"])
+ probs = torch.multinomial(self.pred_probs, len(_token_ids_real), replacement=True)
+ _token_ids = (
+ _token_ids_mask * (probs == 0).long()
+ + _token_ids_real * (probs == 1).long()
+ + _token_ids_rand * (probs == 2).long()
+ )
+ token_ids = token_ids.masked_scatter(pred_mask, _token_ids)
+
+ mlm_labels[~pred_mask] = -100 # previously `mlm_labels[1-pred_mask] = -1`, cf pytorch 1.2.0 compatibility
+
+ # sanity checks
+ assert 0 <= token_ids.min() <= token_ids.max() < self.vocab_size
+
+ return token_ids, attn_mask, mlm_labels
+
+ def prepare_batch_clm(self, batch):
+ """
+ Prepare the batch: from the token_ids and the lengths, compute the attention mask and the labels for CLM.
+
+ Input:
+ ------
+ batch: `Tuple`
+ token_ids: `torch.tensor(bs, seq_length)` - The token ids for each of the sequence. It is padded.
+ lengths: `torch.tensor(bs)` - The lengths of each of the sequences in the batch.
+
+ Output:
+ -------
+            token_ids: `torch.tensor(bs, seq_length)` - The token ids after the modifications for CLM.
+ attn_mask: `torch.tensor(bs, seq_length)` - The attention mask for the self-attention.
+ clm_labels: `torch.tensor(bs, seq_length)` - The causal language modeling labels. There is a -100 where there is nothing to predict.
+ """
+ token_ids, lengths = batch
+ token_ids, lengths = self.round_batch(x=token_ids, lengths=lengths)
+ assert token_ids.size(0) == lengths.size(0)
+
+ attn_mask = torch.arange(token_ids.size(1), dtype=torch.long, device=lengths.device) < lengths[:, None]
+ clm_labels = token_ids.new(token_ids.size()).copy_(token_ids)
+ clm_labels[~attn_mask] = -100 # previously `clm_labels[1-attn_mask] = -1`, cf pytorch 1.2.0 compatibility
+
+ # sanity checks
+ assert 0 <= token_ids.min() <= token_ids.max() < self.vocab_size
+
+ return token_ids, attn_mask, clm_labels
+
+ def round_batch(self, x: torch.tensor, lengths: torch.tensor):
+ """
+ For float16 only.
+ Sub-sample sentences in a batch, and add padding, so that each dimension is a multiple of 8.
+
+ Input:
+ ------
+ x: `torch.tensor(bs, seq_length)` - The token ids.
+            lengths: `torch.tensor(bs)` - The lengths of each of the sequences in the batch.
+
+ Output:
+ -------
+ x: `torch.tensor(new_bs, new_seq_length)` - The updated token ids.
+            lengths: `torch.tensor(new_bs)` - The updated lengths.
+ """
+ if not self.fp16 or len(lengths) < 8:
+ return x, lengths
+
+        # make the number of sentences in the batch a multiple of 8
+ bs1 = len(lengths)
+ bs2 = 8 * (bs1 // 8)
+ assert bs2 > 0 and bs2 % 8 == 0
+ if bs1 != bs2:
+ idx = torch.randperm(bs1)[:bs2]
+ lengths = lengths[idx]
+ slen = lengths.max().item()
+ x = x[idx, :slen]
+ else:
+ idx = None
+
+        # pad the sequence length up to a multiple of 8
+ ml1 = x.size(1)
+ if ml1 % 8 != 0:
+ pad = 8 - (ml1 % 8)
+ ml2 = ml1 + pad
+ if self.mlm:
+ pad_id = self.params.special_tok_ids["pad_token"]
+ else:
+ pad_id = self.params.special_tok_ids["unk_token"]
+ padding_tensor = torch.zeros(bs2, pad, dtype=torch.long, device=x.device).fill_(pad_id)
+ x = torch.cat([x, padding_tensor], 1)
+ assert x.size() == (bs2, ml2)
+
+ assert x.size(0) % 8 == 0
+ assert x.size(1) % 8 == 0
+ return x, lengths
+
+ def train(self):
+ """
+ The real training loop.
+ """
+ if self.is_master:
+ logger.info("Starting training")
+ self.last_log = time.time()
+ self.student.train()
+ self.teacher.eval()
+
+ for _ in range(self.params.n_epoch):
+ if self.is_master:
+ logger.info(f"--- Starting epoch {self.epoch}/{self.params.n_epoch-1}")
+ if self.multi_gpu:
+ torch.distributed.barrier()
+
+ iter_bar = tqdm(self.dataloader, desc="-Iter", disable=self.params.local_rank not in [-1, 0])
+ for batch in iter_bar:
+ if self.params.n_gpu > 0:
+ batch = tuple(t.to(f"cuda:{self.params.local_rank}") for t in batch)
+
+ if self.mlm:
+ token_ids, attn_mask, lm_labels = self.prepare_batch_mlm(batch=batch)
+ else:
+ token_ids, attn_mask, lm_labels = self.prepare_batch_clm(batch=batch)
+ self.step(input_ids=token_ids, attention_mask=attn_mask, lm_labels=lm_labels)
+
+ iter_bar.update()
+ iter_bar.set_postfix(
+ {"Last_loss": f"{self.last_loss:.2f}", "Avg_cum_loss": f"{self.total_loss_epoch/self.n_iter:.2f}"}
+ )
+ iter_bar.close()
+
+ if self.is_master:
+ logger.info(f"--- Ending epoch {self.epoch}/{self.params.n_epoch-1}")
+ self.end_epoch()
+
+ if self.is_master:
+ logger.info("Save very last checkpoint as `pytorch_model.bin`.")
+ self.save_checkpoint(checkpoint_name="pytorch_model.bin")
+ logger.info("Training is finished")
+
+ def step(self, input_ids: torch.tensor, attention_mask: torch.tensor, lm_labels: torch.tensor):
+ """
+ One optimization step: forward of student AND teacher, backward on the loss (for gradient accumulation),
+ and possibly a parameter update (depending on the gradient accumulation).
+
+ Input:
+ ------
+ input_ids: `torch.tensor(bs, seq_length)` - The token ids.
+ attention_mask: `torch.tensor(bs, seq_length)` - The attention mask for self attention.
+ lm_labels: `torch.tensor(bs, seq_length)` - The language modeling labels (mlm labels for MLM and clm labels for CLM).
+ """
+ if self.mlm:
+ student_outputs = self.student(
+ input_ids=input_ids, attention_mask=attention_mask
+ ) # (bs, seq_length, voc_size)
+ with torch.no_grad():
+ teacher_outputs = self.teacher(
+ input_ids=input_ids, attention_mask=attention_mask
+ ) # (bs, seq_length, voc_size)
+ else:
+ student_outputs = self.student(input_ids=input_ids, attention_mask=None) # (bs, seq_length, voc_size)
+ with torch.no_grad():
+ teacher_outputs = self.teacher(input_ids=input_ids, attention_mask=None) # (bs, seq_length, voc_size)
+ s_logits, s_hidden_states = student_outputs["logits"], student_outputs["hidden_states"]
+ t_logits, t_hidden_states = teacher_outputs["logits"], teacher_outputs["hidden_states"]
+ assert s_logits.size() == t_logits.size()
+
+ # https://github.com/peterliht/knowledge-distillation-pytorch/blob/master/model/net.py#L100
+ # https://github.com/peterliht/knowledge-distillation-pytorch/issues/2
+ if self.params.restrict_ce_to_mask:
+ mask = (lm_labels > -1).unsqueeze(-1).expand_as(s_logits) # (bs, seq_length, voc_size)
+ else:
+ mask = attention_mask.unsqueeze(-1).expand_as(s_logits) # (bs, seq_length, voc_size)
+ s_logits_slct = torch.masked_select(s_logits, mask) # (bs * seq_length * voc_size) modulo the 1s in mask
+ s_logits_slct = s_logits_slct.view(-1, s_logits.size(-1)) # (bs * seq_length, voc_size) modulo the 1s in mask
+ t_logits_slct = torch.masked_select(t_logits, mask) # (bs * seq_length * voc_size) modulo the 1s in mask
+ t_logits_slct = t_logits_slct.view(-1, s_logits.size(-1)) # (bs * seq_length, voc_size) modulo the 1s in mask
+ assert t_logits_slct.size() == s_logits_slct.size()
+
+ loss_ce = (
+ self.ce_loss_fct(
+ nn.functional.log_softmax(s_logits_slct / self.temperature, dim=-1),
+ nn.functional.softmax(t_logits_slct / self.temperature, dim=-1),
+ )
+ * (self.temperature) ** 2
+ )
+ loss = self.alpha_ce * loss_ce
+
+ if self.alpha_mlm > 0.0:
+ loss_mlm = self.lm_loss_fct(s_logits.view(-1, s_logits.size(-1)), lm_labels.view(-1))
+ loss += self.alpha_mlm * loss_mlm
+ if self.alpha_clm > 0.0:
+ shift_logits = s_logits[..., :-1, :].contiguous()
+ shift_labels = lm_labels[..., 1:].contiguous()
+ loss_clm = self.lm_loss_fct(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1))
+ loss += self.alpha_clm * loss_clm
+
+ if self.alpha_mse > 0.0:
+ loss_mse = self.mse_loss_fct(s_logits_slct, t_logits_slct) / s_logits_slct.size(
+ 0
+ ) # Reproducing batchmean reduction
+ loss += self.alpha_mse * loss_mse
+ if self.alpha_cos > 0.0:
+ s_hidden_states = s_hidden_states[-1] # (bs, seq_length, dim)
+ t_hidden_states = t_hidden_states[-1] # (bs, seq_length, dim)
+ mask = attention_mask.unsqueeze(-1).expand_as(s_hidden_states) # (bs, seq_length, dim)
+ assert s_hidden_states.size() == t_hidden_states.size()
+ dim = s_hidden_states.size(-1)
+
+ s_hidden_states_slct = torch.masked_select(s_hidden_states, mask) # (bs * seq_length * dim)
+ s_hidden_states_slct = s_hidden_states_slct.view(-1, dim) # (bs * seq_length, dim)
+ t_hidden_states_slct = torch.masked_select(t_hidden_states, mask) # (bs * seq_length * dim)
+ t_hidden_states_slct = t_hidden_states_slct.view(-1, dim) # (bs * seq_length, dim)
+
+ target = s_hidden_states_slct.new(s_hidden_states_slct.size(0)).fill_(1) # (bs * seq_length,)
+ loss_cos = self.cosine_loss_fct(s_hidden_states_slct, t_hidden_states_slct, target)
+ loss += self.alpha_cos * loss_cos
+
+ self.total_loss_epoch += loss.item()
+ self.last_loss = loss.item()
+ self.last_loss_ce = loss_ce.item()
+ if self.alpha_mlm > 0.0:
+ self.last_loss_mlm = loss_mlm.item()
+ if self.alpha_clm > 0.0:
+ self.last_loss_clm = loss_clm.item()
+ if self.alpha_mse > 0.0:
+ self.last_loss_mse = loss_mse.item()
+ if self.alpha_cos > 0.0:
+ self.last_loss_cos = loss_cos.item()
+
+ self.optimize(loss)
+
+ self.n_sequences_epoch += input_ids.size(0)
+
+ def optimize(self, loss):
+ """
+ Normalization on the loss (gradient accumulation or distributed training), followed by
+ backward pass on the loss, possibly followed by a parameter update (depending on the gradient accumulation).
+ Also update the metrics for tensorboard.
+ """
+ # Check for NaN
+ if (loss != loss).data.any():
+ logger.error("NaN detected")
+ exit()
+
+ if self.multi_gpu:
+ loss = loss.mean()
+ if self.params.gradient_accumulation_steps > 1:
+ loss = loss / self.params.gradient_accumulation_steps
+
+ if self.fp16:
+ from apex import amp
+
+ with amp.scale_loss(loss, self.optimizer) as scaled_loss:
+ scaled_loss.backward()
+ else:
+ loss.backward()
+
+ self.iter()
+ if self.n_iter % self.params.gradient_accumulation_steps == 0:
+ if self.fp16:
+ nn.utils.clip_grad_norm_(amp.master_params(self.optimizer), self.params.max_grad_norm)
+ else:
+ nn.utils.clip_grad_norm_(self.student.parameters(), self.params.max_grad_norm)
+ self.optimizer.step()
+ self.optimizer.zero_grad()
+ self.scheduler.step()
+
+ def iter(self):
+ """
+ Update global counts, write to tensorboard and save checkpoint.
+ """
+ self.n_iter += 1
+ self.n_total_iter += 1
+
+ if self.n_total_iter % self.params.log_interval == 0:
+ self.log_tensorboard()
+ self.last_log = time.time()
+ if self.n_total_iter % self.params.checkpoint_interval == 0:
+ self.save_checkpoint()
+
+ def log_tensorboard(self):
+ """
+ Log into tensorboard. Only by the master process.
+ """
+ if not self.is_master:
+ return
+
+ for param_name, param in self.student.named_parameters():
+ self.tensorboard.add_scalar(
+ tag="parameter_mean/" + param_name, scalar_value=param.data.mean(), global_step=self.n_total_iter
+ )
+ self.tensorboard.add_scalar(
+ tag="parameter_std/" + param_name, scalar_value=param.data.std(), global_step=self.n_total_iter
+ )
+ if param.grad is None:
+ continue
+ self.tensorboard.add_scalar(
+ tag="grad_mean/" + param_name, scalar_value=param.grad.data.mean(), global_step=self.n_total_iter
+ )
+ self.tensorboard.add_scalar(
+ tag="grad_std/" + param_name, scalar_value=param.grad.data.std(), global_step=self.n_total_iter
+ )
+
+ self.tensorboard.add_scalar(
+ tag="losses/cum_avg_loss_epoch",
+ scalar_value=self.total_loss_epoch / self.n_iter,
+ global_step=self.n_total_iter,
+ )
+ self.tensorboard.add_scalar(tag="losses/loss", scalar_value=self.last_loss, global_step=self.n_total_iter)
+ self.tensorboard.add_scalar(
+ tag="losses/loss_ce", scalar_value=self.last_loss_ce, global_step=self.n_total_iter
+ )
+ if self.alpha_mlm > 0.0:
+ self.tensorboard.add_scalar(
+ tag="losses/loss_mlm", scalar_value=self.last_loss_mlm, global_step=self.n_total_iter
+ )
+ if self.alpha_clm > 0.0:
+ self.tensorboard.add_scalar(
+ tag="losses/loss_clm", scalar_value=self.last_loss_clm, global_step=self.n_total_iter
+ )
+ if self.alpha_mse > 0.0:
+ self.tensorboard.add_scalar(
+ tag="losses/loss_mse", scalar_value=self.last_loss_mse, global_step=self.n_total_iter
+ )
+ if self.alpha_cos > 0.0:
+ self.tensorboard.add_scalar(
+ tag="losses/loss_cos", scalar_value=self.last_loss_cos, global_step=self.n_total_iter
+ )
+ self.tensorboard.add_scalar(
+ tag="learning_rate/lr", scalar_value=self.scheduler.get_lr()[0], global_step=self.n_total_iter
+ )
+
+ self.tensorboard.add_scalar(
+ tag="global/memory_usage",
+ scalar_value=psutil.virtual_memory()._asdict()["used"] / 1_000_000,
+ global_step=self.n_total_iter,
+ )
+ self.tensorboard.add_scalar(
+ tag="global/speed", scalar_value=time.time() - self.last_log, global_step=self.n_total_iter
+ )
+
+ def end_epoch(self):
+ """
+ Finally arrived at the end of epoch (full pass on dataset).
+ Do some tensorboard logging and checkpoint saving.
+ """
+ logger.info(f"{self.n_sequences_epoch} sequences have been trained during this epoch.")
+
+ if self.is_master:
+ self.save_checkpoint(checkpoint_name=f"model_epoch_{self.epoch}.pth")
+ self.tensorboard.add_scalar(
+ tag="epoch/loss", scalar_value=self.total_loss_epoch / self.n_iter, global_step=self.epoch
+ )
+
+ self.epoch += 1
+ self.n_sequences_epoch = 0
+ self.n_iter = 0
+ self.total_loss_epoch = 0
+
+ def save_checkpoint(self, checkpoint_name: str = "checkpoint.pth"):
+ """
+ Save the current state. Only by the master process.
+ """
+ if not self.is_master:
+ return
+ mdl_to_save = self.student.module if hasattr(self.student, "module") else self.student
+ mdl_to_save.config.save_pretrained(self.dump_path)
+ state_dict = mdl_to_save.state_dict()
+ torch.save(state_dict, os.path.join(self.dump_path, checkpoint_name))
diff --git a/Experiments/NLP/distillation/grouped_batch_sampler.py b/Experiments/NLP/distillation/grouped_batch_sampler.py
new file mode 100644
index 0000000..e25def7
--- /dev/null
+++ b/Experiments/NLP/distillation/grouped_batch_sampler.py
@@ -0,0 +1,108 @@
+# coding=utf-8
+# Copyright 2019-present, the HuggingFace Inc. team and Facebook, Inc.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""Adapted from PyTorch Vision (https://github.com/pytorch/vision/blob/master/references/detection/group_by_aspect_ratio.py)"""
+
+import bisect
+import copy
+from collections import defaultdict
+
+import numpy as np
+from torch.utils.data import BatchSampler, Sampler
+
+from utils import logger
+
+
+def _quantize(x, bins):
+ bins = copy.deepcopy(bins)
+ bins = sorted(bins)
+ quantized = [bisect.bisect_right(bins, y) for y in x]
+ return quantized
+
+
+def create_lengths_groups(lengths, k=0):
+ bins = np.arange(start=3, stop=k, step=4).tolist() if k > 0 else [10]
+ groups = _quantize(lengths, bins)
+ # count number of elements per group
+ counts = np.unique(groups, return_counts=True)[1]
+ fbins = [0] + bins + [np.inf]
+    logger.info("Using {} as bins for sequence lengths quantization".format(fbins))
+ logger.info("Count of instances per bin: {}".format(counts))
+ return groups
+
+
+class GroupedBatchSampler(BatchSampler):
+ """
+ Wraps another sampler to yield a mini-batch of indices.
+    It enforces that the batch only contains elements from the same group.
+    It also tries to provide mini-batches which follow an ordering which is
+ as close as possible to the ordering from the original sampler.
+ Arguments:
+ sampler (Sampler): Base sampler.
+ group_ids (list[int]): If the sampler produces indices in range [0, N),
+ `group_ids` must be a list of `N` ints which contains the group id of each sample.
+ The group ids must be a continuous set of integers starting from
+ 0, i.e. they must be in the range [0, num_groups).
+ batch_size (int): Size of mini-batch.
+ """
+
+ def __init__(self, sampler, group_ids, batch_size):
+ if not isinstance(sampler, Sampler):
+ raise TypeError(
+ "sampler should be an instance of torch.utils.data.Sampler, but got sampler={}".format(sampler)
+ )
+ self.sampler = sampler
+ self.group_ids = group_ids
+ self.batch_size = batch_size
+
+ def __iter__(self):
+ buffer_per_group = defaultdict(list)
+ samples_per_group = defaultdict(list)
+
+ num_batches = 0
+ for idx in self.sampler:
+ group_id = self.group_ids[idx]
+ buffer_per_group[group_id].append(idx)
+ samples_per_group[group_id].append(idx)
+ if len(buffer_per_group[group_id]) == self.batch_size:
+ yield buffer_per_group[group_id] # TODO
+ num_batches += 1
+ del buffer_per_group[group_id]
+ assert len(buffer_per_group[group_id]) < self.batch_size
+
+ # now we have run out of elements that satisfy
+ # the group criteria, let's return the remaining
+ # elements so that the size of the sampler is
+ # deterministic
+ expected_num_batches = len(self)
+ num_remaining = expected_num_batches - num_batches
+ if num_remaining > 0:
+ # for the remaining batches, group the batches by similar lengths
+ batch_idx = []
+ for group_id, idxs in sorted(buffer_per_group.items(), key=lambda x: x[0]):
+ batch_idx.extend(idxs)
+ if len(batch_idx) >= self.batch_size:
+ yield batch_idx[: self.batch_size]
+ batch_idx = batch_idx[self.batch_size :]
+ num_remaining -= 1
+ if len(batch_idx) > 0:
+ yield batch_idx
+ num_remaining -= 1
+ assert num_remaining == 0
+
+ def __len__(self):
+ """
+ Return the number of mini-batches rather than the number of samples.
+ """
+ return (len(self.sampler) + self.batch_size - 1) // self.batch_size
diff --git a/Experiments/NLP/distillation/lm_seqs_dataset.py b/Experiments/NLP/distillation/lm_seqs_dataset.py
new file mode 100644
index 0000000..647c8f4
--- /dev/null
+++ b/Experiments/NLP/distillation/lm_seqs_dataset.py
@@ -0,0 +1,167 @@
+# coding=utf-8
+# Copyright 2019-present, the HuggingFace Inc. team and Facebook, Inc.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""Dataset to distilled models
+adapted in part from Facebook, Inc XLM model (https://github.com/facebookresearch/XLM)
+"""
+
+import numpy as np
+import torch
+from torch.utils.data import Dataset
+
+from utils import logger
+
+
+class LmSeqsDataset(Dataset):
+ """Custom Dataset wrapping language modeling sequences.
+
+ Each sample will be retrieved by indexing the list of token_ids and their corresponding lengths.
+
+ Input:
+ ------
+ params: `NameSpace` parameters
+        data: `List[np.array[int]]`
+ """
+
+ def __init__(self, params, data):
+ self.params = params
+
+ self.token_ids = np.array(data)
+ self.lengths = np.array([len(t) for t in data])
+
+ self.check()
+ self.remove_long_sequences()
+ self.remove_empty_sequences()
+ self.remove_unknown_sequences()
+ self.check()
+ self.print_statistics()
+
+ def __getitem__(self, index):
+ return (self.token_ids[index], self.lengths[index])
+
+ def __len__(self):
+ return len(self.lengths)
+
+ def check(self):
+ """
+ Some sanity checks
+ """
+ assert len(self.token_ids) == len(self.lengths)
+ assert all(self.lengths[i] == len(self.token_ids[i]) for i in range(len(self.lengths)))
+
+ def remove_long_sequences(self):
+ """
+        Sequences that are too long are split into chunks of max_model_input_size.
+ """
+ max_len = self.params.max_model_input_size
+ indices = self.lengths > max_len
+ logger.info(f"Splitting {sum(indices)} too long sequences.")
+
+ def divide_chunks(l, n):
+ return [l[i : i + n] for i in range(0, len(l), n)]
+
+ new_tok_ids = []
+ new_lengths = []
+ if self.params.mlm:
+ cls_id, sep_id = self.params.special_tok_ids["cls_token"], self.params.special_tok_ids["sep_token"]
+ else:
+ cls_id, sep_id = self.params.special_tok_ids["bos_token"], self.params.special_tok_ids["eos_token"]
+
+ for seq_, len_ in zip(self.token_ids, self.lengths):
+ assert (seq_[0] == cls_id) and (seq_[-1] == sep_id), seq_
+ if len_ <= max_len:
+ new_tok_ids.append(seq_)
+ new_lengths.append(len_)
+ else:
+ sub_seqs = []
+ for sub_s in divide_chunks(seq_, max_len - 2):
+ if sub_s[0] != cls_id:
+ sub_s = np.insert(sub_s, 0, cls_id)
+ if sub_s[-1] != sep_id:
+ sub_s = np.insert(sub_s, len(sub_s), sep_id)
+ assert len(sub_s) <= max_len
+ assert (sub_s[0] == cls_id) and (sub_s[-1] == sep_id), sub_s
+ sub_seqs.append(sub_s)
+
+ new_tok_ids.extend(sub_seqs)
+ new_lengths.extend([len(l) for l in sub_seqs])
+
+ self.token_ids = np.array(new_tok_ids)
+ self.lengths = np.array(new_lengths)
+
+ def remove_empty_sequences(self):
+ """
+ Too short sequences are simply removed. This could be tuned.
+ """
+ init_size = len(self)
+ indices = self.lengths > 11
+ self.token_ids = self.token_ids[indices]
+ self.lengths = self.lengths[indices]
+ new_size = len(self)
+ logger.info(f"Remove {init_size - new_size} too short (<=11 tokens) sequences.")
+
+ def remove_unknown_sequences(self):
+ """
+ Remove sequences with a (too) high level of unknown tokens.
+ """
+ if "unk_token" not in self.params.special_tok_ids:
+ return
+ else:
+ unk_token_id = self.params.special_tok_ids["unk_token"]
+ init_size = len(self)
+ unk_occs = np.array([np.count_nonzero(a == unk_token_id) for a in self.token_ids])
+ indices = (unk_occs / self.lengths) < 0.5
+ self.token_ids = self.token_ids[indices]
+ self.lengths = self.lengths[indices]
+ new_size = len(self)
+        logger.info(f"Removed {init_size - new_size} sequences with a high level (>=50%) of unknown tokens.")
+
+ def print_statistics(self):
+ """
+ Print some statistics on the corpus. Only the master process.
+ """
+ if not self.params.is_master:
+ return
+ logger.info(f"{len(self)} sequences")
+ # data_len = sum(self.lengths)
+ # nb_unique_tokens = len(Counter(list(chain(*self.token_ids))))
+ # logger.info(f'{data_len} tokens ({nb_unique_tokens} unique)')
+
+ # unk_idx = self.params.special_tok_ids['unk_token']
+ # nb_unknown = sum([(t==unk_idx).sum() for t in self.token_ids])
+ # logger.info(f'{nb_unknown} unknown tokens (covering {100*nb_unknown/data_len:.2f}% of the data)')
+
+ def batch_sequences(self, batch):
+ """
+ Do the padding and transform into torch.tensor.
+ """
+ token_ids = [t[0] for t in batch]
+ lengths = [t[1] for t in batch]
+ assert len(token_ids) == len(lengths)
+
+ # Max for paddings
+ max_seq_len_ = max(lengths)
+
+ # Pad token ids
+ if self.params.mlm:
+ pad_idx = self.params.special_tok_ids["pad_token"]
+ else:
+ pad_idx = self.params.special_tok_ids["unk_token"]
+ tk_ = [list(t.astype(int)) + [pad_idx] * (max_seq_len_ - len(t)) for t in token_ids]
+ assert len(tk_) == len(token_ids)
+ assert all(len(t) == max_seq_len_ for t in tk_)
+
+ tk_t = torch.tensor(tk_) # (bs, max_seq_len_)
+ lg_t = torch.tensor(lengths) # (bs)
+ return tk_t, lg_t
diff --git a/Experiments/NLP/distillation/requirements.txt b/Experiments/NLP/distillation/requirements.txt
new file mode 100644
index 0000000..4a2ed78
--- /dev/null
+++ b/Experiments/NLP/distillation/requirements.txt
@@ -0,0 +1,7 @@
+transformers
+
+gitpython==3.1.41
+tensorboard>=1.14.0
+tensorboardX==1.8
+psutil==5.6.6
+scipy>=1.4.1
diff --git a/Experiments/NLP/distillation/run_squad_w_distillation.py b/Experiments/NLP/distillation/run_squad_w_distillation.py
new file mode 100644
index 0000000..a1150f6
--- /dev/null
+++ b/Experiments/NLP/distillation/run_squad_w_distillation.py
@@ -0,0 +1,877 @@
+# coding=utf-8
+# Copyright 2018 The Google AI Language Team Authors and The HuggingFace Inc. team.
+# Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""This is the exact same script as `examples/question-answering/run_squad.py` (as of 2020, January 8th) with an additional and optional step of distillation."""
+
+import argparse
+import glob
+import logging
+import os
+import random
+import timeit
+
+import numpy as np
+import torch
+from torch import nn
+from torch.utils.data import DataLoader, RandomSampler, SequentialSampler
+from torch.utils.data.distributed import DistributedSampler
+from tqdm import tqdm, trange
+
+import transformers
+from transformers import (
+ WEIGHTS_NAME,
+ AdamW,
+ BertConfig,
+ BertForQuestionAnswering,
+ BertTokenizer,
+ DistilBertConfig,
+ DistilBertForQuestionAnswering,
+ DistilBertTokenizer,
+ RobertaConfig,
+ RobertaForQuestionAnswering,
+ RobertaTokenizer,
+ XLMConfig,
+ XLMForQuestionAnswering,
+ XLMTokenizer,
+ XLNetConfig,
+ XLNetForQuestionAnswering,
+ XLNetTokenizer,
+ get_linear_schedule_with_warmup,
+ squad_convert_examples_to_features,
+)
+from transformers.data.metrics.squad_metrics import (
+ compute_predictions_log_probs,
+ compute_predictions_logits,
+ squad_evaluate,
+)
+from transformers.data.processors.squad import SquadResult, SquadV1Processor, SquadV2Processor
+from transformers.trainer_utils import is_main_process
+
+
+try:
+ from torch.utils.tensorboard import SummaryWriter
+except ImportError:
+ from tensorboardX import SummaryWriter
+
+
+logger = logging.getLogger(__name__)
+
+
+MODEL_CLASSES = {
+ "bert": (BertConfig, BertForQuestionAnswering, BertTokenizer),
+ "xlnet": (XLNetConfig, XLNetForQuestionAnswering, XLNetTokenizer),
+ "xlm": (XLMConfig, XLMForQuestionAnswering, XLMTokenizer),
+ "distilbert": (DistilBertConfig, DistilBertForQuestionAnswering, DistilBertTokenizer),
+ "roberta": (RobertaConfig, RobertaForQuestionAnswering, RobertaTokenizer),
+}
+
+
+def set_seed(args):
+ random.seed(args.seed)
+ np.random.seed(args.seed)
+ torch.manual_seed(args.seed)
+ if args.n_gpu > 0:
+ torch.cuda.manual_seed_all(args.seed)
+
+
+def to_list(tensor):
+ return tensor.detach().cpu().tolist()
+
+
+def train(args, train_dataset, model, tokenizer, teacher=None):
+ """Train the model"""
+ if args.local_rank in [-1, 0]:
+ tb_writer = SummaryWriter()
+
+ args.train_batch_size = args.per_gpu_train_batch_size * max(1, args.n_gpu)
+ train_sampler = RandomSampler(train_dataset) if args.local_rank == -1 else DistributedSampler(train_dataset)
+ train_dataloader = DataLoader(train_dataset, sampler=train_sampler, batch_size=args.train_batch_size)
+
+ if args.max_steps > 0:
+ t_total = args.max_steps
+ args.num_train_epochs = args.max_steps // (len(train_dataloader) // args.gradient_accumulation_steps) + 1
+ else:
+ t_total = len(train_dataloader) // args.gradient_accumulation_steps * args.num_train_epochs
+
+ # Prepare optimizer and schedule (linear warmup and decay)
+ no_decay = ["bias", "LayerNorm.weight"]
+ optimizer_grouped_parameters = [
+ {
+ "params": [p for n, p in model.named_parameters() if not any(nd in n for nd in no_decay)],
+ "weight_decay": args.weight_decay,
+ },
+ {"params": [p for n, p in model.named_parameters() if any(nd in n for nd in no_decay)], "weight_decay": 0.0},
+ ]
+ optimizer = AdamW(optimizer_grouped_parameters, lr=args.learning_rate, eps=args.adam_epsilon)
+ scheduler = get_linear_schedule_with_warmup(
+ optimizer, num_warmup_steps=args.warmup_steps, num_training_steps=t_total
+ )
+
+ # Check if saved optimizer or scheduler states exist
+ if os.path.isfile(os.path.join(args.model_name_or_path, "optimizer.pt")) and os.path.isfile(
+ os.path.join(args.model_name_or_path, "scheduler.pt")
+ ):
+ # Load in optimizer and scheduler states
+ optimizer.load_state_dict(torch.load(os.path.join(args.model_name_or_path, "optimizer.pt")))
+ scheduler.load_state_dict(torch.load(os.path.join(args.model_name_or_path, "scheduler.pt")))
+
+ if args.fp16:
+ try:
+ from apex import amp
+ except ImportError:
+ raise ImportError("Please install apex from https://www.github.com/nvidia/apex to use fp16 training.")
+
+ model, optimizer = amp.initialize(model, optimizer, opt_level=args.fp16_opt_level)
+
+ # multi-gpu training (should be after apex fp16 initialization)
+ if args.n_gpu > 1:
+ model = nn.DataParallel(model)
+
+ # Distributed training (should be after apex fp16 initialization)
+ if args.local_rank != -1:
+ model = nn.parallel.DistributedDataParallel(
+ model, device_ids=[args.local_rank], output_device=args.local_rank, find_unused_parameters=True
+ )
+
+ # Train!
+ logger.info("***** Running training *****")
+ logger.info(" Num examples = %d", len(train_dataset))
+ logger.info(" Num Epochs = %d", args.num_train_epochs)
+ logger.info(" Instantaneous batch size per GPU = %d", args.per_gpu_train_batch_size)
+ logger.info(
+ " Total train batch size (w. parallel, distributed & accumulation) = %d",
+ args.train_batch_size
+ * args.gradient_accumulation_steps
+ * (torch.distributed.get_world_size() if args.local_rank != -1 else 1),
+ )
+ logger.info(" Gradient Accumulation steps = %d", args.gradient_accumulation_steps)
+ logger.info(" Total optimization steps = %d", t_total)
+
+ global_step = 1
+ epochs_trained = 0
+ steps_trained_in_current_epoch = 0
+ # Check if continuing training from a checkpoint
+ if os.path.exists(args.model_name_or_path):
+ try:
+ # set global_step to global_step of last saved checkpoint from model path
+ checkpoint_suffix = args.model_name_or_path.split("-")[-1].split("/")[0]
+ global_step = int(checkpoint_suffix)
+ epochs_trained = global_step // (len(train_dataloader) // args.gradient_accumulation_steps)
+ steps_trained_in_current_epoch = global_step % (len(train_dataloader) // args.gradient_accumulation_steps)
+
+ logger.info(" Continuing training from checkpoint, will skip to saved global_step")
+ logger.info(" Continuing training from epoch %d", epochs_trained)
+ logger.info(" Continuing training from global step %d", global_step)
+ logger.info(" Will skip the first %d steps in the first epoch", steps_trained_in_current_epoch)
+ except ValueError:
+ logger.info(" Starting fine-tuning.")
+
+ tr_loss, logging_loss = 0.0, 0.0
+ model.zero_grad()
+ train_iterator = trange(
+ epochs_trained, int(args.num_train_epochs), desc="Epoch", disable=args.local_rank not in [-1, 0]
+ )
+ # Added here for reproducibility
+ set_seed(args)
+
+ for _ in train_iterator:
+ epoch_iterator = tqdm(train_dataloader, desc="Iteration", disable=args.local_rank not in [-1, 0])
+ for step, batch in enumerate(epoch_iterator):
+ # Skip past any already trained steps if resuming training
+ if steps_trained_in_current_epoch > 0:
+ steps_trained_in_current_epoch -= 1
+ continue
+
+ model.train()
+ if teacher is not None:
+ teacher.eval()
+ batch = tuple(t.to(args.device) for t in batch)
+
+ inputs = {
+ "input_ids": batch[0],
+ "attention_mask": batch[1],
+ "start_positions": batch[3],
+ "end_positions": batch[4],
+ }
+ if args.model_type != "distilbert":
+ inputs["token_type_ids"] = None if args.model_type == "xlm" else batch[2]
+ if args.model_type in ["xlnet", "xlm"]:
+ inputs.update({"cls_index": batch[5], "p_mask": batch[6]})
+ if args.version_2_with_negative:
+ inputs.update({"is_impossible": batch[7]})
+ outputs = model(**inputs)
+ loss, start_logits_stu, end_logits_stu = outputs
+
+ # Distillation loss
+ if teacher is not None:
+ if "token_type_ids" not in inputs:
+ inputs["token_type_ids"] = None if args.teacher_type == "xlm" else batch[2]
+ with torch.no_grad():
+ start_logits_tea, end_logits_tea = teacher(
+ input_ids=inputs["input_ids"],
+ token_type_ids=inputs["token_type_ids"],
+ attention_mask=inputs["attention_mask"],
+ )
+ assert start_logits_tea.size() == start_logits_stu.size()
+ assert end_logits_tea.size() == end_logits_stu.size()
+
+ loss_fct = nn.KLDivLoss(reduction="batchmean")
+ loss_start = loss_fct(
+ nn.functional.log_softmax(start_logits_stu / args.temperature, dim=-1),
+ nn.functional.softmax(start_logits_tea / args.temperature, dim=-1),
+ ) * (args.temperature**2)
+ loss_end = loss_fct(
+ nn.functional.log_softmax(end_logits_stu / args.temperature, dim=-1),
+ nn.functional.softmax(end_logits_tea / args.temperature, dim=-1),
+ ) * (args.temperature**2)
+ loss_ce = (loss_start + loss_end) / 2.0
+
+ loss = args.alpha_ce * loss_ce + args.alpha_squad * loss
+
+ if args.n_gpu > 1:
+ loss = loss.mean() # mean() to average on multi-gpu parallel (not distributed) training
+ if args.gradient_accumulation_steps > 1:
+ loss = loss / args.gradient_accumulation_steps
+
+ if args.fp16:
+ with amp.scale_loss(loss, optimizer) as scaled_loss:
+ scaled_loss.backward()
+ else:
+ loss.backward()
+
+ tr_loss += loss.item()
+ if (step + 1) % args.gradient_accumulation_steps == 0:
+ if args.fp16:
+ nn.utils.clip_grad_norm_(amp.master_params(optimizer), args.max_grad_norm)
+ else:
+ nn.utils.clip_grad_norm_(model.parameters(), args.max_grad_norm)
+
+ optimizer.step()
+ scheduler.step() # Update learning rate schedule
+ model.zero_grad()
+ global_step += 1
+
+ # Log metrics
+ if args.local_rank in [-1, 0] and args.logging_steps > 0 and global_step % args.logging_steps == 0:
+ # Only evaluate when single GPU otherwise metrics may not average well
+ if args.local_rank == -1 and args.evaluate_during_training:
+ results = evaluate(args, model, tokenizer)
+ for key, value in results.items():
+ tb_writer.add_scalar("eval_{}".format(key), value, global_step)
+ tb_writer.add_scalar("lr", scheduler.get_lr()[0], global_step)
+ tb_writer.add_scalar("loss", (tr_loss - logging_loss) / args.logging_steps, global_step)
+ logging_loss = tr_loss
+
+ if args.local_rank in [-1, 0] and args.save_steps > 0 and global_step % args.save_steps == 0:
+ # Save model checkpoint
+ output_dir = os.path.join(args.output_dir, "checkpoint-{}".format(global_step))
+ if not os.path.exists(output_dir):
+ os.makedirs(output_dir)
+ model_to_save = (
+ model.module if hasattr(model, "module") else model
+ ) # Take care of distributed/parallel training
+ model_to_save.save_pretrained(output_dir)
+ tokenizer.save_pretrained(output_dir)
+
+ torch.save(args, os.path.join(output_dir, "training_args.bin"))
+ logger.info("Saving model checkpoint to %s", output_dir)
+
+ torch.save(optimizer.state_dict(), os.path.join(output_dir, "optimizer.pt"))
+ torch.save(scheduler.state_dict(), os.path.join(output_dir, "scheduler.pt"))
+ logger.info("Saving optimizer and scheduler states to %s", output_dir)
+
+ if args.max_steps > 0 and global_step > args.max_steps:
+ epoch_iterator.close()
+ break
+ if args.max_steps > 0 and global_step > args.max_steps:
+ train_iterator.close()
+ break
+
+ if args.local_rank in [-1, 0]:
+ tb_writer.close()
+
+ return global_step, tr_loss / global_step
+
+
+def evaluate(args, model, tokenizer, prefix=""):
+ dataset, examples, features = load_and_cache_examples(args, tokenizer, evaluate=True, output_examples=True)
+
+ if not os.path.exists(args.output_dir) and args.local_rank in [-1, 0]:
+ os.makedirs(args.output_dir)
+
+ args.eval_batch_size = args.per_gpu_eval_batch_size * max(1, args.n_gpu)
+
+ # Note that DistributedSampler samples randomly
+ eval_sampler = SequentialSampler(dataset)
+ eval_dataloader = DataLoader(dataset, sampler=eval_sampler, batch_size=args.eval_batch_size)
+
+ # multi-gpu evaluate
+ if args.n_gpu > 1 and not isinstance(model, nn.DataParallel):
+ model = nn.DataParallel(model)
+
+ # Eval!
+ logger.info("***** Running evaluation {} *****".format(prefix))
+ logger.info(" Num examples = %d", len(dataset))
+ logger.info(" Batch size = %d", args.eval_batch_size)
+
+ all_results = []
+ start_time = timeit.default_timer()
+
+ for batch in tqdm(eval_dataloader, desc="Evaluating"):
+ model.eval()
+ batch = tuple(t.to(args.device) for t in batch)
+
+ with torch.no_grad():
+ inputs = {"input_ids": batch[0], "attention_mask": batch[1]}
+ if args.model_type != "distilbert":
+ inputs["token_type_ids"] = None if args.model_type == "xlm" else batch[2] # XLM don't use segment_ids
+ example_indices = batch[3]
+ if args.model_type in ["xlnet", "xlm"]:
+ inputs.update({"cls_index": batch[4], "p_mask": batch[5]})
+
+ outputs = model(**inputs)
+
+ for i, example_index in enumerate(example_indices):
+ eval_feature = features[example_index.item()]
+ unique_id = int(eval_feature.unique_id)
+
+ output = [to_list(output[i]) for output in outputs]
+
+ # Some models (XLNet, XLM) use 5 arguments for their predictions, while the other "simpler"
+ # models only use two.
+ if len(output) >= 5:
+ start_logits = output[0]
+ start_top_index = output[1]
+ end_logits = output[2]
+ end_top_index = output[3]
+ cls_logits = output[4]
+
+ result = SquadResult(
+ unique_id,
+ start_logits,
+ end_logits,
+ start_top_index=start_top_index,
+ end_top_index=end_top_index,
+ cls_logits=cls_logits,
+ )
+
+ else:
+ start_logits, end_logits = output
+ result = SquadResult(unique_id, start_logits, end_logits)
+
+ all_results.append(result)
+
+ evalTime = timeit.default_timer() - start_time
+ logger.info(" Evaluation done in total %f secs (%f sec per example)", evalTime, evalTime / len(dataset))
+
+ # Compute predictions
+ output_prediction_file = os.path.join(args.output_dir, "predictions_{}.json".format(prefix))
+ output_nbest_file = os.path.join(args.output_dir, "nbest_predictions_{}.json".format(prefix))
+
+ if args.version_2_with_negative:
+ output_null_log_odds_file = os.path.join(args.output_dir, "null_odds_{}.json".format(prefix))
+ else:
+ output_null_log_odds_file = None
+
+ if args.model_type in ["xlnet", "xlm"]:
+ # XLNet uses a more complex post-processing procedure
+ predictions = compute_predictions_log_probs(
+ examples,
+ features,
+ all_results,
+ args.n_best_size,
+ args.max_answer_length,
+ output_prediction_file,
+ output_nbest_file,
+ output_null_log_odds_file,
+ model.config.start_n_top,
+ model.config.end_n_top,
+ args.version_2_with_negative,
+ tokenizer,
+ args.verbose_logging,
+ )
+ else:
+ predictions = compute_predictions_logits(
+ examples,
+ features,
+ all_results,
+ args.n_best_size,
+ args.max_answer_length,
+ args.do_lower_case,
+ output_prediction_file,
+ output_nbest_file,
+ output_null_log_odds_file,
+ args.verbose_logging,
+ args.version_2_with_negative,
+ args.null_score_diff_threshold,
+ tokenizer,
+ )
+
+ # Compute the F1 and exact scores.
+ results = squad_evaluate(examples, predictions)
+ return results
+
+
+def load_and_cache_examples(args, tokenizer, evaluate=False, output_examples=False):
+ if args.local_rank not in [-1, 0] and not evaluate:
+ # Make sure only the first process in distributed training processes the dataset; the others will use the cache
+ torch.distributed.barrier()
+
+ # Load data features from cache or dataset file
+ input_file = args.predict_file if evaluate else args.train_file
+ cached_features_file = os.path.join(
+ os.path.dirname(input_file),
+ "cached_distillation_{}_{}_{}".format(
+ "dev" if evaluate else "train",
+ list(filter(None, args.model_name_or_path.split("/"))).pop(),
+ str(args.max_seq_length),
+ ),
+ )
+ if os.path.exists(cached_features_file) and not args.overwrite_cache:
+ logger.info("Loading features from cached file %s", cached_features_file)
+ features_and_dataset = torch.load(cached_features_file)
+
+ try:
+ features, dataset, examples = (
+ features_and_dataset["features"],
+ features_and_dataset["dataset"],
+ features_and_dataset["examples"],
+ )
+ except KeyError:
+ raise DeprecationWarning(
+ "You seem to be loading features from an older version of this script please delete the "
+ "file %s in order for it to be created again" % cached_features_file
+ )
+ else:
+ logger.info("Creating features from dataset file at %s", input_file)
+ processor = SquadV2Processor() if args.version_2_with_negative else SquadV1Processor()
+ if evaluate:
+ examples = processor.get_dev_examples(args.data_dir, filename=args.predict_file)
+ else:
+ examples = processor.get_train_examples(args.data_dir, filename=args.train_file)
+
+ features, dataset = squad_convert_examples_to_features(
+ examples=examples,
+ tokenizer=tokenizer,
+ max_seq_length=args.max_seq_length,
+ doc_stride=args.doc_stride,
+ max_query_length=args.max_query_length,
+ is_training=not evaluate,
+ return_dataset="pt",
+ threads=args.threads,
+ )
+
+ if args.local_rank in [-1, 0]:
+ logger.info("Saving features into cached file %s", cached_features_file)
+ torch.save({"features": features, "dataset": dataset, "examples": examples}, cached_features_file)
+
+ if args.local_rank == 0 and not evaluate:
+ # Make sure only the first process in distributed training processes the dataset; the others will use the cache
+ torch.distributed.barrier()
+
+ if output_examples:
+ return dataset, examples, features
+ return dataset
+
+
+def main():
+ parser = argparse.ArgumentParser()
+
+ # Required parameters
+ parser.add_argument(
+ "--model_type",
+ default=None,
+ type=str,
+ required=True,
+ help="Model type selected in the list: " + ", ".join(MODEL_CLASSES.keys()),
+ )
+ parser.add_argument(
+ "--model_name_or_path",
+ default=None,
+ type=str,
+ required=True,
+ help="Path to pretrained model or model identifier from huggingface.co/models",
+ )
+ parser.add_argument(
+ "--output_dir",
+ default=None,
+ type=str,
+ required=True,
+ help="The output directory where the model checkpoints and predictions will be written.",
+ )
+
+ # Distillation parameters (optional)
+ parser.add_argument(
+ "--teacher_type",
+ default=None,
+ type=str,
+ help=(
+ "Teacher type. Teacher tokenizer and student (model) tokenizer must output the same tokenization. Only for"
+ " distillation."
+ ),
+ )
+ parser.add_argument(
+ "--teacher_name_or_path",
+ default=None,
+ type=str,
+ help="Path to the already SQuAD fine-tuned teacher model. Only for distillation.",
+ )
+ parser.add_argument(
+ "--alpha_ce", default=0.5, type=float, help="Distillation loss linear weight. Only for distillation."
+ )
+ parser.add_argument(
+ "--alpha_squad", default=0.5, type=float, help="True SQuAD loss linear weight. Only for distillation."
+ )
+ parser.add_argument(
+ "--temperature", default=2.0, type=float, help="Distillation temperature. Only for distillation."
+ )
+
+ # Other parameters
+ parser.add_argument(
+ "--data_dir",
+ default=None,
+ type=str,
+ help="The input data dir. Should contain the .json files for the task."
+ + "If no data dir or train/predict files are specified, will run with tensorflow_datasets.",
+ )
+ parser.add_argument(
+ "--train_file",
+ default=None,
+ type=str,
+ help="The input training file. If a data dir is specified, will look for the file there"
+ + "If no data dir or train/predict files are specified, will run with tensorflow_datasets.",
+ )
+ parser.add_argument(
+ "--predict_file",
+ default=None,
+ type=str,
+ help="The input evaluation file. If a data dir is specified, will look for the file there"
+ + "If no data dir or train/predict files are specified, will run with tensorflow_datasets.",
+ )
+ parser.add_argument(
+ "--config_name", default="", type=str, help="Pretrained config name or path if not the same as model_name"
+ )
+ parser.add_argument(
+ "--tokenizer_name",
+ default="",
+ type=str,
+ help="Pretrained tokenizer name or path if not the same as model_name",
+ )
+ parser.add_argument(
+ "--cache_dir",
+ default="",
+ type=str,
+ help="Where do you want to store the pre-trained models downloaded from huggingface.co",
+ )
+
+ parser.add_argument(
+ "--version_2_with_negative",
+ action="store_true",
+ help="If true, the SQuAD examples contain some that do not have an answer.",
+ )
+ parser.add_argument(
+ "--null_score_diff_threshold",
+ type=float,
+ default=0.0,
+ help="If null_score - best_non_null is greater than the threshold predict null.",
+ )
+
+ parser.add_argument(
+ "--max_seq_length",
+ default=384,
+ type=int,
+ help=(
+ "The maximum total input sequence length after WordPiece tokenization. Sequences "
+ "longer than this will be truncated, and sequences shorter than this will be padded."
+ ),
+ )
+ parser.add_argument(
+ "--doc_stride",
+ default=128,
+ type=int,
+ help="When splitting up a long document into chunks, how much stride to take between chunks.",
+ )
+ parser.add_argument(
+ "--max_query_length",
+ default=64,
+ type=int,
+ help=(
+ "The maximum number of tokens for the question. Questions longer than this will "
+ "be truncated to this length."
+ ),
+ )
+ parser.add_argument("--do_train", action="store_true", help="Whether to run training.")
+ parser.add_argument("--do_eval", action="store_true", help="Whether to run eval on the dev set.")
+ parser.add_argument(
+ "--evaluate_during_training", action="store_true", help="Rul evaluation during training at each logging step."
+ )
+ parser.add_argument(
+ "--do_lower_case", action="store_true", help="Set this flag if you are using an uncased model."
+ )
+
+ parser.add_argument("--per_gpu_train_batch_size", default=8, type=int, help="Batch size per GPU/CPU for training.")
+ parser.add_argument(
+ "--per_gpu_eval_batch_size", default=8, type=int, help="Batch size per GPU/CPU for evaluation."
+ )
+ parser.add_argument("--learning_rate", default=5e-5, type=float, help="The initial learning rate for Adam.")
+ parser.add_argument(
+ "--gradient_accumulation_steps",
+ type=int,
+ default=1,
+ help="Number of updates steps to accumulate before performing a backward/update pass.",
+ )
+ parser.add_argument("--weight_decay", default=0.0, type=float, help="Weight decay if we apply some.")
+ parser.add_argument("--adam_epsilon", default=1e-8, type=float, help="Epsilon for Adam optimizer.")
+ parser.add_argument("--max_grad_norm", default=1.0, type=float, help="Max gradient norm.")
+ parser.add_argument(
+ "--num_train_epochs", default=3.0, type=float, help="Total number of training epochs to perform."
+ )
+ parser.add_argument(
+ "--max_steps",
+ default=-1,
+ type=int,
+ help="If > 0: set total number of training steps to perform. Override num_train_epochs.",
+ )
+ parser.add_argument("--warmup_steps", default=0, type=int, help="Linear warmup over warmup_steps.")
+ parser.add_argument(
+ "--n_best_size",
+ default=20,
+ type=int,
+ help="The total number of n-best predictions to generate in the nbest_predictions.json output file.",
+ )
+ parser.add_argument(
+ "--max_answer_length",
+ default=30,
+ type=int,
+ help=(
+ "The maximum length of an answer that can be generated. This is needed because the start "
+ "and end predictions are not conditioned on one another."
+ ),
+ )
+ parser.add_argument(
+ "--verbose_logging",
+ action="store_true",
+ help=(
+ "If true, all of the warnings related to data processing will be printed. "
+ "A number of warnings are expected for a normal SQuAD evaluation."
+ ),
+ )
+
+ parser.add_argument("--logging_steps", type=int, default=50, help="Log every X updates steps.")
+ parser.add_argument("--save_steps", type=int, default=50, help="Save checkpoint every X updates steps.")
+ parser.add_argument(
+ "--eval_all_checkpoints",
+ action="store_true",
+ help="Evaluate all checkpoints starting with the same prefix as model_name ending and ending with step number",
+ )
+ parser.add_argument("--no_cuda", action="store_true", help="Whether not to use CUDA when available")
+ parser.add_argument(
+ "--overwrite_output_dir", action="store_true", help="Overwrite the content of the output directory"
+ )
+ parser.add_argument(
+ "--overwrite_cache", action="store_true", help="Overwrite the cached training and evaluation sets"
+ )
+ parser.add_argument("--seed", type=int, default=42, help="random seed for initialization")
+
+ parser.add_argument("--local_rank", type=int, default=-1, help="local_rank for distributed training on gpus")
+ parser.add_argument(
+ "--fp16",
+ action="store_true",
+ help="Whether to use 16-bit (mixed) precision (through NVIDIA apex) instead of 32-bit",
+ )
+ parser.add_argument(
+ "--fp16_opt_level",
+ type=str,
+ default="O1",
+ help=(
+ "For fp16: Apex AMP optimization level selected in ['O0', 'O1', 'O2', and 'O3']. "
+ "See details at https://nvidia.github.io/apex/amp.html"
+ ),
+ )
+ parser.add_argument("--server_ip", type=str, default="", help="Can be used for distant debugging.")
+ parser.add_argument("--server_port", type=str, default="", help="Can be used for distant debugging.")
+
+ parser.add_argument("--threads", type=int, default=1, help="multiple threads for converting example to features")
+ args = parser.parse_args()
+
+ if (
+ os.path.exists(args.output_dir)
+ and os.listdir(args.output_dir)
+ and args.do_train
+ and not args.overwrite_output_dir
+ ):
+ raise ValueError(
+ "Output directory ({}) already exists and is not empty. Use --overwrite_output_dir to overcome.".format(
+ args.output_dir
+ )
+ )
+
+ # Setup distant debugging if needed
+ if args.server_ip and args.server_port:
+ # Distant debugging - see https://code.visualstudio.com/docs/python/debugging#_attach-to-a-local-script
+ import ptvsd
+
+ print("Waiting for debugger attach")
+ ptvsd.enable_attach(address=(args.server_ip, args.server_port), redirect_output=True)
+ ptvsd.wait_for_attach()
+
+ # Setup CUDA, GPU & distributed training
+ if args.local_rank == -1 or args.no_cuda:
+ device = torch.device("cuda" if torch.cuda.is_available() and not args.no_cuda else "cpu")
+ args.n_gpu = 0 if args.no_cuda else torch.cuda.device_count()
+ else: # Initializes the distributed backend which will take care of synchronizing nodes/GPUs
+ torch.cuda.set_device(args.local_rank)
+ device = torch.device("cuda", args.local_rank)
+ torch.distributed.init_process_group(backend="nccl")
+ args.n_gpu = 1
+ args.device = device
+
+ # Setup logging
+ logging.basicConfig(
+ format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
+ datefmt="%m/%d/%Y %H:%M:%S",
+ level=logging.INFO if args.local_rank in [-1, 0] else logging.WARN,
+ )
+ logger.warning(
+ "Process rank: %s, device: %s, n_gpu: %s, distributed training: %s, 16-bits training: %s",
+ args.local_rank,
+ device,
+ args.n_gpu,
+ bool(args.local_rank != -1),
+ args.fp16,
+ )
+ # Set the verbosity to info of the Transformers logger (on main process only):
+ if is_main_process(args.local_rank):
+ transformers.utils.logging.set_verbosity_info()
+ transformers.utils.logging.enable_default_handler()
+ transformers.utils.logging.enable_explicit_format()
+ # Set seed
+ set_seed(args)
+
+ # Load pretrained model and tokenizer
+ if args.local_rank not in [-1, 0]:
+ # Make sure only the first process in distributed training will download model & vocab
+ torch.distributed.barrier()
+
+ args.model_type = args.model_type.lower()
+ config_class, model_class, tokenizer_class = MODEL_CLASSES[args.model_type]
+ config = config_class.from_pretrained(
+ args.config_name if args.config_name else args.model_name_or_path,
+ cache_dir=args.cache_dir if args.cache_dir else None,
+ )
+ tokenizer = tokenizer_class.from_pretrained(
+ args.tokenizer_name if args.tokenizer_name else args.model_name_or_path,
+ do_lower_case=args.do_lower_case,
+ cache_dir=args.cache_dir if args.cache_dir else None,
+ )
+ model = model_class.from_pretrained(
+ args.model_name_or_path,
+ from_tf=bool(".ckpt" in args.model_name_or_path),
+ config=config,
+ cache_dir=args.cache_dir if args.cache_dir else None,
+ )
+
+ if args.teacher_type is not None:
+ assert args.teacher_name_or_path is not None
+ assert args.alpha_ce > 0.0
+ assert args.alpha_ce + args.alpha_squad > 0.0
+ assert args.teacher_type != "distilbert", "We constraint teachers not to be of type DistilBERT."
+ teacher_config_class, teacher_model_class, _ = MODEL_CLASSES[args.teacher_type]
+ teacher_config = teacher_config_class.from_pretrained(
+ args.teacher_name_or_path, cache_dir=args.cache_dir if args.cache_dir else None
+ )
+ teacher = teacher_model_class.from_pretrained(
+ args.teacher_name_or_path, config=teacher_config, cache_dir=args.cache_dir if args.cache_dir else None
+ )
+ teacher.to(args.device)
+ else:
+ teacher = None
+
+ if args.local_rank == 0:
+ # Make sure only the first process in distributed training will download model & vocab
+ torch.distributed.barrier()
+
+ model.to(args.device)
+
+ logger.info("Training/evaluation parameters %s", args)
+
+ # Before we do anything with models, we want to ensure that we get fp16 execution of torch.einsum if args.fp16 is set.
+ # Otherwise it'll default to "promote" mode, and we'll get fp32 operations. Note that running `--fp16_opt_level="O2"` will
+ # remove the need for this code, but it is still valid.
+ if args.fp16:
+ try:
+ import apex
+
+ apex.amp.register_half_function(torch, "einsum")
+ except ImportError:
+ raise ImportError("Please install apex from https://www.github.com/nvidia/apex to use fp16 training.")
+
+ # Training
+ if args.do_train:
+ train_dataset = load_and_cache_examples(args, tokenizer, evaluate=False, output_examples=False)
+ global_step, tr_loss = train(args, train_dataset, model, tokenizer, teacher=teacher)
+ logger.info(" global_step = %s, average loss = %s", global_step, tr_loss)
+
+ # Save the trained model and the tokenizer
+ if args.do_train and (args.local_rank == -1 or torch.distributed.get_rank() == 0):
+ logger.info("Saving model checkpoint to %s", args.output_dir)
+ # Save a trained model, configuration and tokenizer using `save_pretrained()`.
+ # They can then be reloaded using `from_pretrained()`
+ model_to_save = (
+ model.module if hasattr(model, "module") else model
+ ) # Take care of distributed/parallel training
+ model_to_save.save_pretrained(args.output_dir)
+ tokenizer.save_pretrained(args.output_dir)
+
+ # Good practice: save your training arguments together with the trained model
+ torch.save(args, os.path.join(args.output_dir, "training_args.bin"))
+
+ # Load a trained model and vocabulary that you have fine-tuned
+ model = model_class.from_pretrained(args.output_dir)
+ tokenizer = tokenizer_class.from_pretrained(args.output_dir, do_lower_case=args.do_lower_case)
+ model.to(args.device)
+
+ # Evaluation - we can ask to evaluate all the checkpoints (sub-directories) in a directory
+ results = {}
+ if args.do_eval and args.local_rank in [-1, 0]:
+ if args.do_train:
+ logger.info("Loading checkpoints saved during training for evaluation")
+ checkpoints = [args.output_dir]
+ if args.eval_all_checkpoints:
+ checkpoints = [
+ os.path.dirname(c) for c in sorted(glob.glob(args.output_dir + "/**/" + WEIGHTS_NAME, recursive=True))
+ ]
+
+ logger.info("Evaluate the following checkpoints: %s", checkpoints)
+
+ for checkpoint in checkpoints:
+ # Reload the model
+ global_step = checkpoint.split("-")[-1] if len(checkpoints) > 1 else ""
+ model = model_class.from_pretrained(checkpoint)
+ model.to(args.device)
+
+ # Evaluate
+ result = evaluate(args, model, tokenizer, prefix=global_step)
+
+ result = {k + ("_{}".format(global_step) if global_step else ""): v for k, v in result.items()}
+ results.update(result)
+
+ logger.info("Results: {}".format(results))
+
+ return results
+
+
+if __name__ == "__main__":
+ main()
diff --git a/Experiments/NLP/distillation/scripts/binarized_data.py b/Experiments/NLP/distillation/scripts/binarized_data.py
new file mode 100644
index 0000000..3fc3214
--- /dev/null
+++ b/Experiments/NLP/distillation/scripts/binarized_data.py
@@ -0,0 +1,97 @@
+# coding=utf-8
+# Copyright 2019-present, the HuggingFace Inc. team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""
+Preprocessing script before distillation.
+"""
+
+import argparse
+import logging
+import pickle
+import random
+import time
+
+import numpy as np
+
+from transformers import BertTokenizer, GPT2Tokenizer, RobertaTokenizer
+
+
+logging.basicConfig(
+ format="%(asctime)s - %(levelname)s - %(name)s - %(message)s", datefmt="%m/%d/%Y %H:%M:%S", level=logging.INFO
+)
+logger = logging.getLogger(__name__)
+
+
+def main():
+ parser = argparse.ArgumentParser(
+ description="Preprocess the data to avoid re-doing it several times by (tokenization + token_to_ids)."
+ )
+ parser.add_argument("--file_path", type=str, default="data/dump.txt", help="The path to the data.")
+ parser.add_argument("--tokenizer_type", type=str, default="bert", choices=["bert", "roberta", "gpt2"])
+ parser.add_argument("--tokenizer_name", type=str, default="bert-base-uncased", help="The tokenizer to use.")
+ parser.add_argument("--dump_file", type=str, default="data/dump", help="The dump file prefix.")
+ args = parser.parse_args()
+
+ logger.info(f"Loading Tokenizer ({args.tokenizer_name})")
+ if args.tokenizer_type == "bert":
+ tokenizer = BertTokenizer.from_pretrained(args.tokenizer_name)
+ bos = tokenizer.special_tokens_map["cls_token"] # `[CLS]`
+ sep = tokenizer.special_tokens_map["sep_token"] # `[SEP]`
+ elif args.tokenizer_type == "roberta":
+ tokenizer = RobertaTokenizer.from_pretrained(args.tokenizer_name)
+ bos = tokenizer.special_tokens_map["cls_token"] # ``
+ sep = tokenizer.special_tokens_map["sep_token"] # ``
+ elif args.tokenizer_type == "gpt2":
+ tokenizer = GPT2Tokenizer.from_pretrained(args.tokenizer_name)
+ bos = tokenizer.special_tokens_map["bos_token"] # `<|endoftext|>`
+ sep = tokenizer.special_tokens_map["eos_token"] # `<|endoftext|>`
+
+ logger.info(f"Loading text from {args.file_path}")
+ with open(args.file_path, "r", encoding="utf8") as fp:
+ data = fp.readlines()
+
+ logger.info("Start encoding")
+ logger.info(f"{len(data)} examples to process.")
+
+ rslt = []
+ iter_count = 0
+ interval = 10000
+ start = time.time()
+ for text in data:
+ text = f"{bos} {text.strip()} {sep}"
+ token_ids = tokenizer.encode(text, add_special_tokens=False)
+ rslt.append(token_ids)
+
+ iter_count += 1
+ if iter_count % interval == 0:
+ end = time.time()
+ logger.info(f"{iter_count} examples processed. - {(end-start):.2f}s/{interval}expl")
+ start = time.time()
+ logger.info("Finished binarization")
+ logger.info(f"{len(data)} examples processed.")
+
+ dp_file = f"{args.dump_file}.{args.tokenizer_name}.pickle"
+ vocab_size = tokenizer.vocab_size
+ if vocab_size < (1 << 16):
+ rslt_ = [np.uint16(d) for d in rslt]
+ else:
+ rslt_ = [np.int32(d) for d in rslt]
+ random.shuffle(rslt_)
+ logger.info(f"Dump to {dp_file}")
+ with open(dp_file, "wb") as handle:
+ pickle.dump(rslt_, handle, protocol=pickle.HIGHEST_PROTOCOL)
+
+
+if __name__ == "__main__":
+ main()
diff --git a/Experiments/NLP/distillation/scripts/extract.py b/Experiments/NLP/distillation/scripts/extract.py
new file mode 100644
index 0000000..c45821d
--- /dev/null
+++ b/Experiments/NLP/distillation/scripts/extract.py
@@ -0,0 +1,106 @@
+# coding=utf-8
+# Copyright 2019-present, the HuggingFace Inc. team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""
+Preprocessing script before training the distilled model.
+Specific to RoBERTa -> DistilRoBERTa and GPT2 -> DistilGPT2.
+"""
+
+import argparse
+
+import torch
+
+from transformers import GPT2LMHeadModel, RobertaForMaskedLM
+
+
+if __name__ == "__main__":
+ parser = argparse.ArgumentParser(
+ description=(
+ "Extraction some layers of the full RobertaForMaskedLM or GPT2LMHeadModel for Transfer Learned"
+ " Distillation"
+ )
+ )
+ parser.add_argument("--model_type", default="roberta", choices=["roberta", "gpt2"])
+ parser.add_argument("--model_name", default="roberta-large", type=str)
+ parser.add_argument("--dump_checkpoint", default="serialization_dir/tf_roberta_048131723.pth", type=str)
+ parser.add_argument("--vocab_transform", action="store_true")
+ args = parser.parse_args()
+
+ if args.model_type == "roberta":
+ model = RobertaForMaskedLM.from_pretrained(args.model_name)
+ prefix = "roberta"
+ elif args.model_type == "gpt2":
+ model = GPT2LMHeadModel.from_pretrained(args.model_name)
+ prefix = "transformer"
+
+ state_dict = model.state_dict()
+ compressed_sd = {}
+
+ # Embeddings #
+ if args.model_type == "gpt2":
+ for param_name in ["wte.weight", "wpe.weight"]:
+ compressed_sd[f"{prefix}.{param_name}"] = state_dict[f"{prefix}.{param_name}"]
+ else:
+ for w in ["word_embeddings", "position_embeddings", "token_type_embeddings"]:
+ param_name = f"{prefix}.embeddings.{w}.weight"
+ compressed_sd[param_name] = state_dict[param_name]
+ for w in ["weight", "bias"]:
+ param_name = f"{prefix}.embeddings.LayerNorm.{w}"
+ compressed_sd[param_name] = state_dict[param_name]
+
+ # Transformer Blocks #
+ std_idx = 0
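+ # Initialize the 6-layer student from a subset of the teacher's layers: teacher layers
+ # 0, 2, 4, 7, 9, 11 (of a 12-layer model) are copied into student layers 0-5.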
+ for teacher_idx in [0, 2, 4, 7, 9, 11]:
+ if args.model_type == "gpt2":
+ for layer in ["ln_1", "attn.c_attn", "attn.c_proj", "ln_2", "mlp.c_fc", "mlp.c_proj"]:
+ for w in ["weight", "bias"]:
+ compressed_sd[f"{prefix}.h.{std_idx}.{layer}.{w}"] = state_dict[
+ f"{prefix}.h.{teacher_idx}.{layer}.{w}"
+ ]
+ compressed_sd[f"{prefix}.h.{std_idx}.attn.bias"] = state_dict[f"{prefix}.h.{teacher_idx}.attn.bias"]
+ else:
+ for layer in [
+ "attention.self.query",
+ "attention.self.key",
+ "attention.self.value",
+ "attention.output.dense",
+ "attention.output.LayerNorm",
+ "intermediate.dense",
+ "output.dense",
+ "output.LayerNorm",
+ ]:
+ for w in ["weight", "bias"]:
+ compressed_sd[f"{prefix}.encoder.layer.{std_idx}.{layer}.{w}"] = state_dict[
+ f"{prefix}.encoder.layer.{teacher_idx}.{layer}.{w}"
+ ]
+ std_idx += 1
+
+ # Language Modeling Head #
+ if args.model_type == "roberta":
+ for layer in ["lm_head.decoder.weight", "lm_head.bias"]:
+ compressed_sd[f"{layer}"] = state_dict[f"{layer}"]
+ if args.vocab_transform:
+ for w in ["weight", "bias"]:
+ compressed_sd[f"lm_head.dense.{w}"] = state_dict[f"lm_head.dense.{w}"]
+ compressed_sd[f"lm_head.layer_norm.{w}"] = state_dict[f"lm_head.layer_norm.{w}"]
+ elif args.model_type == "gpt2":
+ for w in ["weight", "bias"]:
+ compressed_sd[f"{prefix}.ln_f.{w}"] = state_dict[f"{prefix}.ln_f.{w}"]
+ compressed_sd["lm_head.weight"] = state_dict["lm_head.weight"]
+
+ print(f"N layers selected for distillation: {std_idx}")
+ print(f"Number of params transferred for distillation: {len(compressed_sd.keys())}")
+
+ print(f"Save transferred checkpoint to {args.dump_checkpoint}.")
+ torch.save(compressed_sd, args.dump_checkpoint)
diff --git a/Experiments/NLP/distillation/scripts/extract_distilbert.py b/Experiments/NLP/distillation/scripts/extract_distilbert.py
new file mode 100644
index 0000000..8637970
--- /dev/null
+++ b/Experiments/NLP/distillation/scripts/extract_distilbert.py
@@ -0,0 +1,96 @@
+# coding=utf-8
+# Copyright 2019-present, the HuggingFace Inc. team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""
+Preprocessing script before training DistilBERT.
+Specific to BERT -> DistilBERT.
+"""
+
+import argparse
+
+import torch
+
+from transformers import BertForMaskedLM
+
+
+if __name__ == "__main__":
+ parser = argparse.ArgumentParser(
+ description=(
+ "Extraction some layers of the full BertForMaskedLM or RObertaForMaskedLM for Transfer Learned"
+ " Distillation"
+ )
+ )
+ parser.add_argument("--model_type", default="bert", choices=["bert"])
+ parser.add_argument("--model_name", default="bert-base-uncased", type=str)
+ parser.add_argument("--dump_checkpoint", default="serialization_dir/tf_bert-base-uncased_0247911.pth", type=str)
+ parser.add_argument("--vocab_transform", action="store_true")
+ args = parser.parse_args()
+
+ if args.model_type == "bert":
+ model = BertForMaskedLM.from_pretrained(args.model_name)
+ prefix = "bert"
+ else:
+ raise ValueError('args.model_type should be "bert".')
+
+ state_dict = model.state_dict()
+ compressed_sd = {}
+
+ for w in ["word_embeddings", "position_embeddings"]:
+ compressed_sd[f"distilbert.embeddings.{w}.weight"] = state_dict[f"{prefix}.embeddings.{w}.weight"]
+ for w in ["weight", "bias"]:
+ compressed_sd[f"distilbert.embeddings.LayerNorm.{w}"] = state_dict[f"{prefix}.embeddings.LayerNorm.{w}"]
+
+ std_idx = 0
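+ # Copy teacher layers 0, 2, 4, 7, 9, 11 into student layers 0-5, renaming BERT parameters
+ # (attention.self.query/key/value, intermediate/output.dense, LayerNorms) to their
+ # DistilBERT equivalents (q_lin/k_lin/v_lin, ffn.lin1/lin2, sa_layer_norm/output_layer_norm).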
+ for teacher_idx in [0, 2, 4, 7, 9, 11]:
+ for w in ["weight", "bias"]:
+ compressed_sd[f"distilbert.transformer.layer.{std_idx}.attention.q_lin.{w}"] = state_dict[
+ f"{prefix}.encoder.layer.{teacher_idx}.attention.self.query.{w}"
+ ]
+ compressed_sd[f"distilbert.transformer.layer.{std_idx}.attention.k_lin.{w}"] = state_dict[
+ f"{prefix}.encoder.layer.{teacher_idx}.attention.self.key.{w}"
+ ]
+ compressed_sd[f"distilbert.transformer.layer.{std_idx}.attention.v_lin.{w}"] = state_dict[
+ f"{prefix}.encoder.layer.{teacher_idx}.attention.self.value.{w}"
+ ]
+
+ compressed_sd[f"distilbert.transformer.layer.{std_idx}.attention.out_lin.{w}"] = state_dict[
+ f"{prefix}.encoder.layer.{teacher_idx}.attention.output.dense.{w}"
+ ]
+ compressed_sd[f"distilbert.transformer.layer.{std_idx}.sa_layer_norm.{w}"] = state_dict[
+ f"{prefix}.encoder.layer.{teacher_idx}.attention.output.LayerNorm.{w}"
+ ]
+
+ compressed_sd[f"distilbert.transformer.layer.{std_idx}.ffn.lin1.{w}"] = state_dict[
+ f"{prefix}.encoder.layer.{teacher_idx}.intermediate.dense.{w}"
+ ]
+ compressed_sd[f"distilbert.transformer.layer.{std_idx}.ffn.lin2.{w}"] = state_dict[
+ f"{prefix}.encoder.layer.{teacher_idx}.output.dense.{w}"
+ ]
+ compressed_sd[f"distilbert.transformer.layer.{std_idx}.output_layer_norm.{w}"] = state_dict[
+ f"{prefix}.encoder.layer.{teacher_idx}.output.LayerNorm.{w}"
+ ]
+ std_idx += 1
+
+ compressed_sd["vocab_projector.weight"] = state_dict["cls.predictions.decoder.weight"]
+ compressed_sd["vocab_projector.bias"] = state_dict["cls.predictions.bias"]
+ if args.vocab_transform:
+ for w in ["weight", "bias"]:
+ compressed_sd[f"vocab_transform.{w}"] = state_dict[f"cls.predictions.transform.dense.{w}"]
+ compressed_sd[f"vocab_layer_norm.{w}"] = state_dict[f"cls.predictions.transform.LayerNorm.{w}"]
+
+ print(f"N layers selected for distillation: {std_idx}")
+ print(f"Number of params transferred for distillation: {len(compressed_sd.keys())}")
+
+ print(f"Save transferred checkpoint to {args.dump_checkpoint}.")
+ torch.save(compressed_sd, args.dump_checkpoint)
diff --git a/Experiments/NLP/distillation/scripts/token_counts.py b/Experiments/NLP/distillation/scripts/token_counts.py
new file mode 100644
index 0000000..2f80bf3
--- /dev/null
+++ b/Experiments/NLP/distillation/scripts/token_counts.py
@@ -0,0 +1,57 @@
+# coding=utf-8
+# Copyright 2019-present, the HuggingFace Inc. team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""
+Preprocessing script before training the distilled model.
+"""
+
+import argparse
+import logging
+import pickle
+from collections import Counter
+
+
+logging.basicConfig(
+ format="%(asctime)s - %(levelname)s - %(name)s - %(message)s", datefmt="%m/%d/%Y %H:%M:%S", level=logging.INFO
+)
+logger = logging.getLogger(__name__)
+
+if __name__ == "__main__":
+ parser = argparse.ArgumentParser(
+ description="Token Counts for smoothing the masking probabilities in MLM (cf XLM/word2vec)"
+ )
+ parser.add_argument(
+ "--data_file", type=str, default="data/dump.bert-base-uncased.pickle", help="The binarized dataset."
+ )
+ parser.add_argument(
+ "--token_counts_dump", type=str, default="data/token_counts.bert-base-uncased.pickle", help="The dump file."
+ )
+ parser.add_argument("--vocab_size", default=30522, type=int)
+ args = parser.parse_args()
+
+ logger.info(f"Loading data from {args.data_file}")
+ with open(args.data_file, "rb") as fp:
+ data = pickle.load(fp)
+
+ logger.info("Counting occurrences for MLM.")
+ counter = Counter()
+ for tk_ids in data:
+ counter.update(tk_ids)
+ counts = [0] * args.vocab_size
+ for k, v in counter.items():
+ counts[k] = v
+
+ logger.info(f"Dump to {args.token_counts_dump}")
+ with open(args.token_counts_dump, "wb") as handle:
+ pickle.dump(counts, handle, protocol=pickle.HIGHEST_PROTOCOL)
diff --git a/Experiments/NLP/distillation/train.py b/Experiments/NLP/distillation/train.py
new file mode 100644
index 0000000..15d98ac
--- /dev/null
+++ b/Experiments/NLP/distillation/train.py
@@ -0,0 +1,325 @@
+# coding=utf-8
+# Copyright 2019-present, the HuggingFace Inc. team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""
+Training the distilled model.
+Supported architectures include: BERT -> DistilBERT, RoBERTa -> DistilRoBERTa, GPT2 -> DistilGPT2.
+"""
+
+import argparse
+import json
+import os
+import pickle
+import shutil
+
+import numpy as np
+import torch
+from distiller import Distiller
+from lm_seqs_dataset import LmSeqsDataset
+
+from transformers import (
+ BertConfig,
+ BertForMaskedLM,
+ BertTokenizer,
+ DistilBertConfig,
+ DistilBertForMaskedLM,
+ DistilBertTokenizer,
+ GPT2Config,
+ GPT2LMHeadModel,
+ GPT2Tokenizer,
+ RobertaConfig,
+ RobertaForMaskedLM,
+ RobertaTokenizer,
+)
+from utils import git_log, init_gpu_params, logger, set_seed
+
+
+MODEL_CLASSES = {
+ "distilbert": (DistilBertConfig, DistilBertForMaskedLM, DistilBertTokenizer),
+ "roberta": (RobertaConfig, RobertaForMaskedLM, RobertaTokenizer),
+ "bert": (BertConfig, BertForMaskedLM, BertTokenizer),
+ "gpt2": (GPT2Config, GPT2LMHeadModel, GPT2Tokenizer),
+}
+
+
+def sanity_checks(args):
+ """
+ A bunch of args sanity checks to perform before even starting...
+ """
+ assert (args.mlm and args.alpha_mlm > 0.0) or (not args.mlm and args.alpha_mlm == 0.0)
+ assert (args.alpha_mlm > 0.0 and args.alpha_clm == 0.0) or (args.alpha_mlm == 0.0 and args.alpha_clm > 0.0)
+ if args.mlm:
+ assert os.path.isfile(args.token_counts)
+ assert (args.student_type in ["roberta", "distilbert"]) and (args.teacher_type in ["roberta", "bert"])
+ else:
+ assert (args.student_type in ["gpt2"]) and (args.teacher_type in ["gpt2"])
+
+ assert args.teacher_type == args.student_type or (
+ args.student_type == "distilbert" and args.teacher_type == "bert"
+ )
+ assert os.path.isfile(args.student_config)
+ if args.student_pretrained_weights is not None:
+ assert os.path.isfile(args.student_pretrained_weights)
+
+ if args.freeze_token_type_embds:
+ assert args.student_type in ["roberta"]
+
+ assert args.alpha_ce >= 0.0
+ assert args.alpha_mlm >= 0.0
+ assert args.alpha_clm >= 0.0
+ assert args.alpha_mse >= 0.0
+ assert args.alpha_cos >= 0.0
+ assert args.alpha_ce + args.alpha_mlm + args.alpha_clm + args.alpha_mse + args.alpha_cos > 0.0
+
+
+def freeze_pos_embeddings(student, args):
+ if args.student_type == "roberta":
+ student.roberta.embeddings.position_embeddings.weight.requires_grad = False
+ elif args.student_type == "gpt2":
+ student.transformer.wpe.weight.requires_grad = False
+
+
+def freeze_token_type_embeddings(student, args):
+ if args.student_type == "roberta":
+ student.roberta.embeddings.token_type_embeddings.weight.requires_grad = False
+
+
+def main():
+ parser = argparse.ArgumentParser(description="Training")
+ parser.add_argument("--force", action="store_true", help="Overwrite dump_path if it already exists.")
+
+ parser.add_argument(
+ "--dump_path", type=str, required=True, help="The output directory (log, checkpoints, parameters, etc.)"
+ )
+ parser.add_argument(
+ "--data_file",
+ type=str,
+ required=True,
+ help="The binarized file (tokenized + tokens_to_ids) and grouped by sequence.",
+ )
+
+ parser.add_argument(
+ "--student_type",
+ type=str,
+ choices=["distilbert", "roberta", "gpt2"],
+ required=True,
+ help="The student type (DistilBERT, RoBERTa).",
+ )
+ parser.add_argument("--student_config", type=str, required=True, help="Path to the student configuration.")
+ parser.add_argument(
+ "--student_pretrained_weights", default=None, type=str, help="Load student initialization checkpoint."
+ )
+
+ parser.add_argument(
+ "--teacher_type", choices=["bert", "roberta", "gpt2"], required=True, help="Teacher type (BERT, RoBERTa)."
+ )
+ parser.add_argument("--teacher_name", type=str, required=True, help="The teacher model.")
+
+ parser.add_argument("--temperature", default=2.0, type=float, help="Temperature for the softmax temperature.")
+ parser.add_argument(
+ "--alpha_ce", default=0.5, type=float, help="Linear weight for the distillation loss. Must be >=0."
+ )
+ parser.add_argument(
+ "--alpha_mlm",
+ default=0.0,
+ type=float,
+ help="Linear weight for the MLM loss. Must be >=0. Should be used in conjunction with `mlm` flag.",
+ )
+ parser.add_argument("--alpha_clm", default=0.5, type=float, help="Linear weight for the CLM loss. Must be >=0.")
+ parser.add_argument("--alpha_mse", default=0.0, type=float, help="Linear weight of the MSE loss. Must be >=0.")
+ parser.add_argument(
+ "--alpha_cos", default=0.0, type=float, help="Linear weight of the cosine embedding loss. Must be >=0."
+ )
+
+ parser.add_argument(
+ "--mlm", action="store_true", help="The LM step: MLM or CLM. If `mlm` is True, the MLM is used over CLM."
+ )
+ parser.add_argument(
+ "--mlm_mask_prop",
+ default=0.15,
+ type=float,
+ help="Proportion of tokens for which we need to make a prediction.",
+ )
+ parser.add_argument("--word_mask", default=0.8, type=float, help="Proportion of tokens to mask out.")
+ parser.add_argument("--word_keep", default=0.1, type=float, help="Proportion of tokens to keep.")
+ parser.add_argument("--word_rand", default=0.1, type=float, help="Proportion of tokens to randomly replace.")
+ parser.add_argument(
+ "--mlm_smoothing",
+ default=0.7,
+ type=float,
+ help="Smoothing parameter to emphasize more rare tokens (see XLM, similar to word2vec).",
+ )
+ parser.add_argument("--token_counts", type=str, help="The token counts in the data_file for MLM.")
+
+ parser.add_argument(
+ "--restrict_ce_to_mask",
+ action="store_true",
+ help="If true, compute the distillation loss only the [MLM] prediction distribution.",
+ )
+ parser.add_argument(
+ "--freeze_pos_embs",
+ action="store_true",
+ help="Freeze positional embeddings during distillation. For student_type in ['roberta', 'gpt2'] only.",
+ )
+ parser.add_argument(
+ "--freeze_token_type_embds",
+ action="store_true",
+ help="Freeze token type embeddings during distillation if existent. For student_type in ['roberta'] only.",
+ )
+
+ parser.add_argument("--n_epoch", type=int, default=3, help="Number of pass on the whole dataset.")
+ parser.add_argument("--batch_size", type=int, default=5, help="Batch size (for each process).")
+ parser.add_argument(
+ "--group_by_size",
+ action="store_false",
+ help="If true, group sequences that have similar length into the same batch. Default is true.",
+ )
+
+ parser.add_argument(
+ "--gradient_accumulation_steps",
+ type=int,
+ default=50,
+ help="Gradient accumulation for larger training batches.",
+ )
+ parser.add_argument("--warmup_prop", default=0.05, type=float, help="Linear warmup proportion.")
+ parser.add_argument("--weight_decay", default=0.0, type=float, help="Weight decay if we apply some.")
+ parser.add_argument("--learning_rate", default=5e-4, type=float, help="The initial learning rate for Adam.")
+ parser.add_argument("--adam_epsilon", default=1e-6, type=float, help="Epsilon for Adam optimizer.")
+ parser.add_argument("--max_grad_norm", default=5.0, type=float, help="Max gradient norm.")
+ parser.add_argument("--initializer_range", default=0.02, type=float, help="Random initialization range.")
+
+ parser.add_argument(
+ "--fp16",
+ action="store_true",
+ help="Whether to use 16-bit (mixed) precision (through NVIDIA apex) instead of 32-bit",
+ )
+ parser.add_argument(
+ "--fp16_opt_level",
+ type=str,
+ default="O1",
+ help=(
+ "For fp16: Apex AMP optimization level selected in ['O0', 'O1', 'O2', and 'O3']. "
+ "See details at https://nvidia.github.io/apex/amp.html"
+ ),
+ )
+ parser.add_argument("--n_gpu", type=int, default=1, help="Number of GPUs in the node.")
+ parser.add_argument("--local_rank", type=int, default=-1, help="Distributed training - Local rank")
+ parser.add_argument("--seed", type=int, default=56, help="Random seed")
+
+ parser.add_argument("--log_interval", type=int, default=500, help="Tensorboard logging interval.")
+ parser.add_argument("--checkpoint_interval", type=int, default=4000, help="Checkpoint interval.")
+ args = parser.parse_args()
+ sanity_checks(args)
+
+ # ARGS #
+ init_gpu_params(args)
+ set_seed(args)
+ if args.is_master:
+ if os.path.exists(args.dump_path):
+ if not args.force:
+ raise ValueError(
+ f"Serialization dir {args.dump_path} already exists, but you have not precised wheter to overwrite"
+ " itUse `--force` if you want to overwrite it"
+ )
+ else:
+ shutil.rmtree(args.dump_path)
+
+ if not os.path.exists(args.dump_path):
+ os.makedirs(args.dump_path)
+ logger.info(f"Experiment will be dumped and logged in {args.dump_path}")
+
+ # SAVE PARAMS #
+ logger.info(f"Param: {args}")
+ with open(os.path.join(args.dump_path, "parameters.json"), "w") as f:
+ json.dump(vars(args), f, indent=4)
+ git_log(args.dump_path)
+
+ student_config_class, student_model_class, _ = MODEL_CLASSES[args.student_type]
+ teacher_config_class, teacher_model_class, teacher_tokenizer_class = MODEL_CLASSES[args.teacher_type]
+
+ # TOKENIZER #
+ tokenizer = teacher_tokenizer_class.from_pretrained(args.teacher_name)
+ special_tok_ids = {}
+ for tok_name, tok_symbol in tokenizer.special_tokens_map.items():
+ idx = tokenizer.all_special_tokens.index(tok_symbol)
+ special_tok_ids[tok_name] = tokenizer.all_special_ids[idx]
+ logger.info(f"Special tokens {special_tok_ids}")
+ args.special_tok_ids = special_tok_ids
+ args.max_model_input_size = tokenizer.max_model_input_sizes[args.teacher_name]
+
+ # DATA LOADER #
+ logger.info(f"Loading data from {args.data_file}")
+ with open(args.data_file, "rb") as fp:
+ data = pickle.load(fp)
+
+ if args.mlm:
+ logger.info(f"Loading token counts from {args.token_counts} (already pre-computed)")
+ with open(args.token_counts, "rb") as fp:
+ counts = pickle.load(fp)
+
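+ # Weight each token by count**(-mlm_smoothing) (cf. XLM) so that masking emphasizes rarer tokens;
+ # special tokens are given probability 0 below and are never selected for prediction.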
+ token_probs = np.maximum(counts, 1) ** -args.mlm_smoothing
+ for idx in special_tok_ids.values():
+ token_probs[idx] = 0.0 # do not predict special tokens
+ token_probs = torch.from_numpy(token_probs)
+ else:
+ token_probs = None
+
+ train_lm_seq_dataset = LmSeqsDataset(params=args, data=data)
+ logger.info("Data loader created.")
+
+ # STUDENT #
+ logger.info(f"Loading student config from {args.student_config}")
+ stu_architecture_config = student_config_class.from_pretrained(args.student_config)
+ stu_architecture_config.output_hidden_states = True
+
+ if args.student_pretrained_weights is not None:
+ logger.info(f"Loading pretrained weights from {args.student_pretrained_weights}")
+ student = student_model_class.from_pretrained(args.student_pretrained_weights, config=stu_architecture_config)
+ else:
+ student = student_model_class(stu_architecture_config)
+
+ if args.n_gpu > 0:
+ student.to(f"cuda:{args.local_rank}")
+ logger.info("Student loaded.")
+
+ # TEACHER #
+ teacher = teacher_model_class.from_pretrained(args.teacher_name, output_hidden_states=True)
+ if args.n_gpu > 0:
+ teacher.to(f"cuda:{args.local_rank}")
+ logger.info(f"Teacher loaded from {args.teacher_name}.")
+
+ # FREEZING #
+ if args.freeze_pos_embs:
+ freeze_pos_embeddings(student, args)
+ if args.freeze_token_type_embds:
+ freeze_token_type_embeddings(student, args)
+
+ # SANITY CHECKS #
+ assert student.config.vocab_size == teacher.config.vocab_size
+ assert student.config.hidden_size == teacher.config.hidden_size
+ assert student.config.max_position_embeddings == teacher.config.max_position_embeddings
+ if args.mlm:
+ assert token_probs.size(0) == stu_architecture_config.vocab_size
+
+ # DISTILLER #
+ torch.cuda.empty_cache()
+ distiller = Distiller(
+ params=args, dataset=train_lm_seq_dataset, token_probs=token_probs, student=student, teacher=teacher
+ )
+ distiller.train()
+ logger.info("Let's go get some drinks.")
+
+
+if __name__ == "__main__":
+ main()
diff --git a/Experiments/NLP/distillation/training_configs/distilbert-base-cased.json b/Experiments/NLP/distillation/training_configs/distilbert-base-cased.json
new file mode 100644
index 0000000..d4f524d
--- /dev/null
+++ b/Experiments/NLP/distillation/training_configs/distilbert-base-cased.json
@@ -0,0 +1,15 @@
+{
+ "activation": "gelu",
+ "attention_dropout": 0.1,
+ "dim": 768,
+ "dropout": 0.1,
+ "hidden_dim": 3072,
+ "initializer_range": 0.02,
+ "max_position_embeddings": 512,
+ "n_heads": 12,
+ "n_layers": 6,
+ "sinusoidal_pos_embds": true,
+ "tie_weights_": true,
+ "vocab_size": 28996
+ }
+
\ No newline at end of file
diff --git a/Experiments/NLP/distillation/training_configs/distilbert-base-multilingual-cased.json b/Experiments/NLP/distillation/training_configs/distilbert-base-multilingual-cased.json
new file mode 100644
index 0000000..f76e7fe
--- /dev/null
+++ b/Experiments/NLP/distillation/training_configs/distilbert-base-multilingual-cased.json
@@ -0,0 +1,15 @@
+{
+ "activation": "gelu",
+ "attention_dropout": 0.1,
+ "dim": 768,
+ "dropout": 0.1,
+ "hidden_dim": 3072,
+ "initializer_range": 0.02,
+ "max_position_embeddings": 512,
+ "n_heads": 12,
+ "n_layers": 6,
+ "sinusoidal_pos_embds": true,
+ "tie_weights_": true,
+ "vocab_size": 119547
+ }
+
\ No newline at end of file
diff --git a/Experiments/NLP/distillation/training_configs/distilbert-base-uncased.json b/Experiments/NLP/distillation/training_configs/distilbert-base-uncased.json
new file mode 100644
index 0000000..15d1e7f
--- /dev/null
+++ b/Experiments/NLP/distillation/training_configs/distilbert-base-uncased.json
@@ -0,0 +1,15 @@
+{
+ "activation": "gelu",
+ "attention_dropout": 0.1,
+ "dim": 768,
+ "dropout": 0.1,
+ "hidden_dim": 3072,
+ "initializer_range": 0.02,
+ "max_position_embeddings": 512,
+ "n_heads": 12,
+ "n_layers": 6,
+ "sinusoidal_pos_embds": true,
+ "tie_weights_": true,
+ "vocab_size": 30522
+ }
+
\ No newline at end of file
diff --git a/Experiments/NLP/distillation/training_configs/distilgpt2.json b/Experiments/NLP/distillation/training_configs/distilgpt2.json
new file mode 100644
index 0000000..9820ac9
--- /dev/null
+++ b/Experiments/NLP/distillation/training_configs/distilgpt2.json
@@ -0,0 +1,9 @@
+{
+ "initializer_range": 0.02,
+ "layer_norm_epsilon": 0.00001,
+ "n_embd": 768,
+ "n_head": 12,
+ "n_layer": 6,
+ "n_positions": 1024,
+ "vocab_size": 50257
+}
\ No newline at end of file
diff --git a/Experiments/NLP/distillation/training_configs/distilroberta-base.json b/Experiments/NLP/distillation/training_configs/distilroberta-base.json
new file mode 100644
index 0000000..2d90ef6
--- /dev/null
+++ b/Experiments/NLP/distillation/training_configs/distilroberta-base.json
@@ -0,0 +1,14 @@
+{
+ "vocab_size": 50265,
+ "hidden_size": 768,
+ "num_hidden_layers": 6,
+ "num_attention_heads": 12,
+ "intermediate_size": 3072,
+ "hidden_act": "gelu",
+ "hidden_dropout_prob": 0.1,
+ "attention_probs_dropout_prob": 0.1,
+ "max_position_embeddings": 514,
+ "type_vocab_size": 1,
+ "initializer_range": 0.02,
+ "layer_norm_eps": 0.00001
+}
\ No newline at end of file
diff --git a/Experiments/NLP/distillation/utils.py b/Experiments/NLP/distillation/utils.py
new file mode 100644
index 0000000..e86d259
--- /dev/null
+++ b/Experiments/NLP/distillation/utils.py
@@ -0,0 +1,134 @@
+# coding=utf-8
+# Copyright 2019-present, the HuggingFace Inc. team and Facebook, Inc.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""Utils to train DistilBERT
+adapted in part from Facebook, Inc.'s XLM model (https://github.com/facebookresearch/XLM)
+"""
+
+import json
+import logging
+import os
+import socket
+
+import git
+import numpy as np
+import torch
+
+
+logging.basicConfig(
+ format="%(asctime)s - %(levelname)s - %(name)s - PID: %(process)d - %(message)s",
+ datefmt="%m/%d/%Y %H:%M:%S",
+ level=logging.INFO,
+)
+logger = logging.getLogger(__name__)
+
+
+def git_log(folder_path: str):
+ """
+ Log commit info.
+ """
+ repo = git.Repo(search_parent_directories=True)
+ repo_infos = {
+ "repo_id": str(repo),
+ "repo_sha": str(repo.head.object.hexsha),
+ "repo_branch": str(repo.active_branch),
+ }
+
+ with open(os.path.join(folder_path, "git_log.json"), "w") as f:
+ json.dump(repo_infos, f, indent=4)
+
+
+def init_gpu_params(params):
+ """
+ Handle single and multi-GPU / multi-node.
+ """
+ if params.n_gpu <= 0:
+ params.local_rank = 0
+ params.master_port = -1
+ params.is_master = True
+ params.multi_gpu = False
+ return
+
+ assert torch.cuda.is_available()
+
+ logger.info("Initializing GPUs")
+ if params.n_gpu > 1:
+ assert params.local_rank != -1
+
+ params.world_size = int(os.environ["WORLD_SIZE"])
+ params.n_gpu_per_node = int(os.environ["N_GPU_NODE"])
+ params.global_rank = int(os.environ["RANK"])
+
+ # number of nodes / node ID
+ params.n_nodes = params.world_size // params.n_gpu_per_node
+ params.node_id = params.global_rank // params.n_gpu_per_node
+ params.multi_gpu = True
+
+ assert params.n_nodes == int(os.environ["N_NODES"])
+ assert params.node_id == int(os.environ["NODE_RANK"])
+
+ # local job (single GPU)
+ else:
+ assert params.local_rank == -1
+
+ params.n_nodes = 1
+ params.node_id = 0
+ params.local_rank = 0
+ params.global_rank = 0
+ params.world_size = 1
+ params.n_gpu_per_node = 1
+ params.multi_gpu = False
+
+ # sanity checks
+ assert params.n_nodes >= 1
+ assert 0 <= params.node_id < params.n_nodes
+ assert 0 <= params.local_rank <= params.global_rank < params.world_size
+ assert params.world_size == params.n_nodes * params.n_gpu_per_node
+
+ # define whether this is the master process / if we are in multi-node distributed mode
+ params.is_master = params.node_id == 0 and params.local_rank == 0
+ params.multi_node = params.n_nodes > 1
+
+ # summary
+ PREFIX = f"--- Global rank: {params.global_rank} - "
+ logger.info(PREFIX + "Number of nodes: %i" % params.n_nodes)
+ logger.info(PREFIX + "Node ID : %i" % params.node_id)
+ logger.info(PREFIX + "Local rank : %i" % params.local_rank)
+ logger.info(PREFIX + "World size : %i" % params.world_size)
+ logger.info(PREFIX + "GPUs per node : %i" % params.n_gpu_per_node)
+ logger.info(PREFIX + "Master : %s" % str(params.is_master))
+ logger.info(PREFIX + "Multi-node : %s" % str(params.multi_node))
+ logger.info(PREFIX + "Multi-GPU : %s" % str(params.multi_gpu))
+ logger.info(PREFIX + "Hostname : %s" % socket.gethostname())
+
+ # set GPU device
+ torch.cuda.set_device(params.local_rank)
+
+ # initialize multi-GPU
+ if params.multi_gpu:
+ logger.info("Initializing PyTorch distributed")
+ torch.distributed.init_process_group(
+ init_method="env://",
+ backend="nccl",
+ )
+
+
+def set_seed(args):
+ """
+ Set the random seed.
+ """
+ np.random.seed(args.seed)
+ torch.manual_seed(args.seed)
+ if args.n_gpu > 0:
+ torch.cuda.manual_seed_all(args.seed)