From 01f749f503c01e2698fdf77287480892b6d07610 Mon Sep 17 00:00:00 2001 From: sgrvinod Date: Mon, 4 Jun 2018 08:20:03 +0530 Subject: [PATCH] added comments --- README.md | 10 ++--- caption.py | 29 ++++++++++++-- datasets.py | 11 ++++++ models.py | 106 ++++++++++++++++++++++++++++++++++++++++++---------- train.py | 67 ++++++++++++++++++++++++--------- utils.py | 77 +++++++++++++++++++++++++++++++++++++- 6 files changed, 254 insertions(+), 46 deletions(-) diff --git a/README.md b/README.md index fd7534f98..eacd6ac31 100644 --- a/README.md +++ b/README.md @@ -1,6 +1,6 @@ -This is a **PyTorch Tutorial to Image Captioning**. +This is a **[PyTorch](https://pytorch.org) Tutorial to Image Captioning**. -This is the first of a series of tutorials I plan to write about _implementing_ cool models on your own with the amazing [PyTorch](https://pytorch.org) library. +This is the first in a series of tutorials I plan to write about _implementing_ cool models on your own with the amazing PyTorch library. Basic knowledge of PyTorch, convolutional and recurrent neural networks is assumed. @@ -24,9 +24,9 @@ I'm using `PyTorch 0.4` in `Python 3.6`. **To build a model that can generate a descriptive caption for an image we provide it.** -In the interest of keeping things simple, let's choose to implement the [_Show, Attend, and Tell_](https://arxiv.org/abs/1502.03044) paper. This is by no means the current state-of-the-art, but is still pretty darn amazing. +In the interest of keeping things simple, let's implement the [_Show, Attend, and Tell_](https://arxiv.org/abs/1502.03044) paper. This is by no means the current state-of-the-art, but is still pretty darn amazing. -**This model learns _where_ to look.** +This model learns _where_ to look. As you generate a caption, word by word, you can see the the model's gaze shifting across the image. @@ -78,7 +78,7 @@ There are more examples at the [end of the tutorial](https://github.com/sgrvinod # Overview -In this section, I will present a broad overview of this model. I don't really get into the _minutiae_ here - feel free to skip to the implementation section and commented code for details. +In this section, I will present a broad overview of this model. If you're already familiar with it, you can skip straight to the implementation section or the commented code. ### Encoder diff --git a/caption.py b/caption.py index e3f593b64..145499014 100644 --- a/caption.py +++ b/caption.py @@ -14,6 +14,21 @@ def caption_image_beam_search(encoder, decoder, image_path, word_map, beam_size=3): + """ + Reads an image and captions it with beam search. 
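+
+     Beam search keeps the k highest-scoring partial captions at every decode-step: each of the k
+     live hypotheses is expanded by every word in the vocabulary, and only the k best expansions
+     survive to the next step. A minimal, self-contained sketch of that bookkeeping (tensor names
+     and sizes are illustrative only, not the exact variables used below):
+
+         import torch
+         import torch.nn.functional as F
+         k, vocab_size = 3, 5000                                   # beam size, vocabulary size (example values)
+         top_k_scores = torch.zeros(k, 1)                          # cumulative log-probs of the k live hypotheses
+         scores = F.log_softmax(torch.randn(k, vocab_size), dim=1) # log-probs for the current decode-step
+         scores = top_k_scores.expand_as(scores) + scores          # accumulate log-probs
+         top_k_scores, top_k_words = scores.view(-1).topk(k)       # best k of all k * vocab_size expansions
+         prev_word_inds = top_k_words // vocab_size                # which hypothesis each winner extends
+         next_word_inds = top_k_words % vocab_size                 # which word extends it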
+ + :param encoder: encoder model + :param decoder: decoder model + :param image_path: path to image + :param word_map: word map + :param beam_size: number of sequences to consider at each decode-step + :return: caption, weights for visualization + """ + + k = beam_size + vocab_size = len(word_map) + + # Read image and process img = imread(image_path) if len(img.shape) == 2: img = img[:, :, np.newaxis] @@ -27,9 +42,6 @@ def caption_image_beam_search(encoder, decoder, image_path, word_map, beam_size= transform = transforms.Compose([normalize]) image = transform(img) # (3, 256, 256) - k = beam_size - vocab_size = len(word_map) - # Encode image = image.unsqueeze(0) # (1, 3, 256, 256) encoder_out = encoder(image) # (1, enc_image_size, enc_image_size, encoder_dim) @@ -136,6 +148,17 @@ def caption_image_beam_search(encoder, decoder, image_path, word_map, beam_size= def visualize_att(image_path, seq, alphas, rev_word_map, smooth=True): + """ + Visualizes caption with weights at every word. + + Adapted from paper authors' repo: https://github.com/kelvinxu/arctic-captions/blob/master/alpha_visualization.ipynb + + :param image_path: path to image that has been captioned + :param seq: caption + :param alphas: weights + :param rev_word_map: reverse word mapping, i.e. ix2word + :param smooth: smooth weights? + """ image = Image.open(image_path) image = image.resize([14 * 24, 14 * 24], Image.LANCZOS) diff --git a/datasets.py b/datasets.py index 3b5b08b01..8ca079fd5 100644 --- a/datasets.py +++ b/datasets.py @@ -6,7 +6,18 @@ class CaptionDataset(Dataset): + """ + A PyTorch Dataset class to be used in a PyTorch DataLoader to create batches. + """ + def __init__(self, data_folder, data_name, split, transform=None): + """ + + :param data_folder: folder where data files are stored + :param data_name: base name of processed datasets + :param split: split, one of 'TRAIN', 'VAL', or 'TEST' + :param transform: image transform pipeline + """ self.split = split assert self.split in {'TRAIN', 'VAL', 'TEST'} diff --git a/models.py b/models.py index 64214337d..4837cd8b8 100644 --- a/models.py +++ b/models.py @@ -6,26 +6,43 @@ class Encoder(nn.Module): + """ + Encoder. + """ + def __init__(self, encoded_image_size=14): super(Encoder, self).__init__() self.enc_image_size = encoded_image_size - resnet = torchvision.models.resnet101(pretrained=True) + resnet = torchvision.models.resnet101(pretrained=True) # pretrained ImageNet ResNet-101 + + # Remove linear and pool layers (since we're not doing classification) modules = list(resnet.children())[:-2] self.resnet = nn.Sequential(*modules) + # Resize image to fixed size to allow input images of variable size self.adaptive_pool = nn.AdaptiveAvgPool2d((encoded_image_size, encoded_image_size)) self.fine_tune() def forward(self, images): - # images.shape = (batch_size, 3, image_size, image_size) + """ + Forward propagation. + + :param images: images, a tensor of dimensions (batch_size, 3, image_size, image_size) + :return: encoded images + """ out = self.resnet(images) # (batch_size, 2048, image_size/32, image_size/32) out = self.adaptive_pool(out) # (batch_size, 2048, encoded_image_size, encoded_image_size) out = out.permute(0, 2, 3, 1) # (batch_size, encoded_image_size, encoded_image_size, 2048) return out def fine_tune(self, fine_tune=True): + """ + Allow or prevent the computation of gradients for convolutional blocks 2 through 4 of the encoder. + + :param fine_tune: Allow? 
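+
+     A hypothetical usage sketch (outside this class):
+
+         encoder = Encoder()
+         encoder.fine_tune(False)   # freeze the whole ResNet, e.g. while only the decoder is trained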
+ """ for p in self.resnet.parameters(): p.requires_grad = False # If fine-tuning, only fine-tune convolutional blocks 2 through 4 @@ -35,18 +52,31 @@ def fine_tune(self, fine_tune=True): class Attention(nn.Module): + """ + Attention Network. + """ + def __init__(self, encoder_dim, decoder_dim, attention_dim): + """ + :param encoder_dim: feature size of encoded images + :param decoder_dim: size of decoder's RNN + :param attention_dim: size of the attention network + """ super(Attention, self).__init__() - self.encoder_att = nn.Linear(encoder_dim, attention_dim) - self.decoder_att = nn.Linear(decoder_dim, attention_dim) - self.full_att = nn.Linear(attention_dim, 1) + self.encoder_att = nn.Linear(encoder_dim, attention_dim) # linear layer to transform encoded image + self.decoder_att = nn.Linear(decoder_dim, attention_dim) # linear layer to transform decoder's output + self.full_att = nn.Linear(attention_dim, 1) # linear layer to calculate values to be softmax-ed self.relu = nn.ReLU() - self.softmax = nn.Softmax(dim=1) + self.softmax = nn.Softmax(dim=1) # softmax layer to calculate weights def forward(self, encoder_out, decoder_hidden): - # encoder_out.shape = (batch_size, num_pixels, encoder_dim) - # decoder_hidden.shape = (batch_size, decoder_dim) + """ + Forward propagation. + :param encoder_out: encoded images, a tensor of dimension (batch_size, num_pixels, encoder_dim) + :param decoder_hidden: previous decoder output, a tensor of dimension (batch_size, decoder_dim) + :return: attention weighted encoding, weights + """ att1 = self.encoder_att(encoder_out) # (batch_size, num_pixels, attention_dim) att2 = self.decoder_att(decoder_hidden) # (batch_size, attention_dim) att = self.full_att(self.relu(att1 + att2.unsqueeze(1))).squeeze(2) # (batch_size, num_pixels) @@ -57,8 +87,21 @@ def forward(self, encoder_out, decoder_hidden): class DecoderWithAttention(nn.Module): + """ + Decoder. 
+ """ + def __init__(self, attention_dim, embed_dim, decoder_dim, vocab_size, decoder_layers=1, encoder_dim=2048, dropout=0.5): + """ + :param attention_dim: size of attention network + :param embed_dim: embedding size + :param decoder_dim: size of decoder's RNN + :param vocab_size: size of vocabulary + :param decoder_layers: number of layers in the decoder + :param encoder_dim: feature size of encoded images + :param dropout: dropout + """ super(DecoderWithAttention, self).__init__() self.encoder_dim = encoder_dim @@ -69,40 +112,65 @@ def __init__(self, attention_dim, embed_dim, decoder_dim, vocab_size, decoder_la self.decoder_layers = decoder_layers self.dropout = dropout - self.attention = Attention(encoder_dim, decoder_dim, attention_dim) + self.attention = Attention(encoder_dim, decoder_dim, attention_dim) # attention network - self.embedding = nn.Embedding(vocab_size, embed_dim) + self.embedding = nn.Embedding(vocab_size, embed_dim) # embedding layer self.dropout = nn.Dropout(p=self.dropout) - self.decode_step = nn.LSTMCell(embed_dim + encoder_dim, decoder_dim, decoder_layers) - self.init_h = nn.Linear(encoder_dim, decoder_dim) - self.init_c = nn.Linear(encoder_dim, decoder_dim) - self.f_beta = nn.Linear(decoder_dim, encoder_dim) + self.decode_step = nn.LSTMCell(embed_dim + encoder_dim, decoder_dim, decoder_layers) # decoding LSTMCell + self.init_h = nn.Linear(encoder_dim, decoder_dim) # linear layer to find initial hidden state of LSTMCell + self.init_c = nn.Linear(encoder_dim, decoder_dim) # linear layer to find initial cell state of LSTMCell + self.f_beta = nn.Linear(decoder_dim, encoder_dim) # linear layer to create a sigmoid-activated gate self.sigmoid = nn.Sigmoid() - self.fc = nn.Linear(decoder_dim, vocab_size) - self.init_weights() + self.fc = nn.Linear(decoder_dim, vocab_size) # linear layer to find scores over vocabulary + self.init_weights() # initialize some layers with the uniform distribution def init_weights(self): + """ + Initializes some parameters with values from the uniform distribution, for easier convergence. + """ self.embedding.weight.data.uniform_(-0.1, 0.1) self.fc.bias.data.fill_(0) self.fc.weight.data.uniform_(-0.1, 0.1) def load_pretrained_embeddings(self, embeddings): + """ + Loads embedding layer with pre-trained embeddings. + + :param embeddings: pre-trained embeddings + """ self.embedding.weight = nn.Parameter(embeddings) def fine_tune_embeddings(self, fine_tune=True): + """ + Allow fine-tuning of embedding layer? (Only makes sense to not-allow if using pre-trained embeddings). + + :param fine_tune: Allow? + """ for p in self.embedding.parameters(): p.requires_grad = fine_tune def init_hidden_state(self, encoder_out): + """ + Creates the initial hidden and cell states for the decoder's LSTM based on the encoded images. + + :param encoder_out: encoded images, a tensor of dimension (batch_size, num_pixels, encoder_dim) + :return: hidden state, cell state + """ mean_encoder_out = encoder_out.mean(dim=1) h = self.init_h(mean_encoder_out) # (batch_size, decoder_dim) c = self.init_c(mean_encoder_out) return h, c def forward(self, encoder_out, encoded_captions, caption_lengths): - # encoder_out.shape = (batch_size, image_size, image_size, encoder_dim), image_size being the pixel width/height - # encoded_captions.shape = (batch_size, max_caption_length) - # caption_lengths.shape = (batch_size, 1) + """ + Forward propagation. 
+ + :param encoder_out: encoded images, a tensor of dimension (batch_size, enc_image_size, enc_image_size, encoder_dim) + :param encoded_captions: encoded captions, a tensor of dimension (batch_size, max_caption_length) + :param caption_lengths: caption lengths, a tensor of dimension (batch_size, 1) + :return: scores for vocabulary, sorted encoded captions, decode lengths, weights, sort indices + """ + batch_size = encoder_out.size(0) encoder_dim = encoder_out.size(-1) vocab_size = self.vocab_size diff --git a/train.py b/train.py index 336485127..a07f2df79 100644 --- a/train.py +++ b/train.py @@ -40,12 +40,18 @@ def main(): + """ + Training and validation. + """ + global best_bleu4, epochs_since_improvement, checkpoint, start_epoch, fine_tune_encoder, data_name, word_map + # Read word map word_map_file = os.path.join(data_folder, 'WORDMAP_' + data_name + '.json') with open(word_map_file, 'r') as j: word_map = json.load(j) + # Initialize / load checkpoint if checkpoint is None: decoder = DecoderWithAttention(attention_dim=attention_dim, embed_dim=emb_dim, @@ -76,7 +82,7 @@ def main(): # Move to GPU, if available decoder = decoder.to(device) - encoder = encoder.to(device) if encoder is not None else None + encoder = encoder.to(device) # Loss function criterion = nn.CrossEntropyLoss().to(device) @@ -85,15 +91,16 @@ def main(): normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]) train_loader = torch.utils.data.DataLoader( - hdf5Dataset(data_folder, data_name, 'TRAIN', transform=transforms.Compose([normalize])), + CaptionDataset(data_folder, data_name, 'TRAIN', transform=transforms.Compose([normalize])), batch_size=batch_size, shuffle=True, num_workers=workers, pin_memory=True) val_loader = torch.utils.data.DataLoader( - hdf5Dataset(data_folder, data_name, 'VAL', transform=transforms.Compose([normalize])), + CaptionDataset(data_folder, data_name, 'VAL', transform=transforms.Compose([normalize])), batch_size=batch_size, shuffle=True, num_workers=workers, pin_memory=True) + # Epochs for epoch in range(start_epoch, epochs): - # Decay learning rate if there is no improvement for 3 consecutive epochs, and terminate training after 8 + # Decay learning rate if there is no improvement for 8 consecutive epochs, and terminate training after 20 if epochs_since_improvement == 20: break if epochs_since_improvement > 0 and epochs_since_improvement % 8 == 0: @@ -101,6 +108,7 @@ def main(): if fine_tune_encoder: adjust_learning_rate(encoder_optimizer, 0.8) + # One epoch's training train(train_loader=train_loader, encoder=encoder, decoder=decoder, @@ -109,32 +117,41 @@ def main(): decoder_optimizer=decoder_optimizer, epoch=epoch) + # One epoch's validation recent_bleu4 = validate(val_loader=val_loader, encoder=encoder, decoder=decoder, criterion=criterion) + # Check if there was an improvement is_best = recent_bleu4 > best_bleu4 best_bleu4 = max(recent_bleu4, best_bleu4) - if not is_best: epochs_since_improvement += 1 print("\nEpochs since last improvement: %d\n" % (epochs_since_improvement,)) else: epochs_since_improvement = 0 + # Save checkpoint save_checkpoint(data_name, epoch, epochs_since_improvement, encoder, decoder, encoder_optimizer, decoder_optimizer, recent_bleu4, is_best) def train(train_loader, encoder, decoder, criterion, encoder_optimizer, decoder_optimizer, epoch): """ - Perform one epoch's training. + Performs one epoch's training. 
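+
+     For every batch: forward prop through the encoder and decoder, compute the cross-entropy loss on
+     the un-padded targets, back-prop, optionally clip gradients, and update the weights. One way to
+     score only the real (non-padded) time-steps is to pack the sequences first (a sketch assuming
+     pack_padded_sequence; not necessarily the exact lines used below):
+
+         from torch.nn.utils.rnn import pack_padded_sequence
+         targets = caps_sorted[:, 1:]                                             # everything after <start>
+         scores_packed = pack_padded_sequence(scores, decode_lengths, batch_first=True).data
+         targets_packed = pack_padded_sequence(targets, decode_lengths, batch_first=True).data
+         loss = criterion(scores_packed, targets_packed)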
+ + :param train_loader: DataLoader for training data + :param encoder: encoder model + :param decoder: decoder model + :param criterion: loss layer + :param encoder_optimizer: optimizer to update encoder's weights (if fine-tuning) + :param decoder_optimizer: optimizer to update decoder's weights + :param epoch: epoch number """ - decoder.train() # train mode - if encoder is not None: - encoder.train() + decoder.train() # train mode (dropout and batchnorm is used) + encoder.train() batch_time = AverageMeter() # forward prop. + back prop. time data_time = AverageMeter() # data loading time @@ -143,16 +160,17 @@ def train(train_loader, encoder, decoder, criterion, encoder_optimizer, decoder_ start = time.time() + # Batches for i, (imgs, caps, caplens) in enumerate(train_loader): data_time.update(time.time() - start) + # Move to GPU, if available imgs = imgs.to(device) caps = caps.to(device) caplens = caplens.to(device) # Forward prop. - if encoder is not None: - imgs = encoder(imgs) + imgs = encoder(imgs) scores, caps_sorted, decode_lengths, alphas, sort_ind = decoder(imgs, caps, caplens) # Since we decoded starting with , the targets are all words after , up to @@ -173,14 +191,15 @@ def train(train_loader, encoder, decoder, criterion, encoder_optimizer, decoder_ decoder_optimizer.zero_grad() if encoder_optimizer is not None: encoder_optimizer.zero_grad() - loss.backward() + # Clip gradients if grad_clip is not None: clip_gradient(decoder_optimizer, grad_clip) if encoder_optimizer is not None: clip_gradient(encoder_optimizer, grad_clip) + # Update weights decoder_optimizer.step() if encoder_optimizer is not None: encoder_optimizer.step() @@ -193,6 +212,7 @@ def train(train_loader, encoder, decoder, criterion, encoder_optimizer, decoder_ start = time.time() + # Print status if i % print_freq == 0: print('Epoch: [{0}][{1}/{2}]\t' 'Batch Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t' @@ -205,21 +225,32 @@ def train(train_loader, encoder, decoder, criterion, encoder_optimizer, decoder_ def validate(val_loader, encoder, decoder, criterion): - decoder.eval() # eval mode + """ + Performs one epoch's validation. + + :param val_loader: DataLoader for validation data. + :param encoder: encoder model + :param decoder: decoder model + :param criterion: loss layer + :return: BLEU-4 score + """ + decoder.eval() # eval mode (no dropout or batchnorm) if encoder is not None: encoder.eval() - batch_time = AverageMeter() # forward prop. 
+ gradient descent time this batch - losses = AverageMeter() # loss this batch - top5accs = AverageMeter() # top5 accuracy this batch + batch_time = AverageMeter() + losses = AverageMeter() + top5accs = AverageMeter() start = time.time() references = list() # references (true captions) for calculating BLEU-4 score hypotheses = list() # hypotheses (predictions) + # Batches for i, (imgs, caps, caplens, allcaps) in enumerate(val_loader): + # Move to device, if available imgs = imgs.to(device) caps = caps.to(device) caplens = caplens.to(device) @@ -269,7 +300,7 @@ def validate(val_loader, encoder, decoder, criterion): img_caps = allcaps[j].tolist() img_captions = list( map(lambda c: [w for w in c if w not in {word_map[''], word_map['']}], - img_caps)) # remove , , and pads + img_caps)) # remove and pads references.append(img_captions) # Hypotheses @@ -277,7 +308,7 @@ def validate(val_loader, encoder, decoder, criterion): preds = preds.tolist() temp_preds = list() for j, p in enumerate(preds): - temp_preds.append(preds[j][:decode_lengths[j]]) # remove and pads + temp_preds.append(preds[j][:decode_lengths[j]]) # remove pads preds = temp_preds hypotheses.extend(preds) diff --git a/utils.py b/utils.py index 73e618617..121242d0d 100644 --- a/utils.py +++ b/utils.py @@ -11,11 +11,25 @@ def create_input_files(dataset, karpathy_json_path, image_folder, captions_per_image, min_word_freq, output_folder, max_len=100): + """ + Creates input files for training, validation, and test data. + + :param dataset: name of dataset, one of 'coco', 'flickr8k', 'flickr30k' + :param karpathy_json_path: path of Karpathy JSON file with splits and captions + :param image_folder: folder with downloaded images + :param captions_per_image: number of captions to sample per image + :param min_word_freq: words occuring less frequently than this threshold are binned as s + :param output_folder: folder to save files + :param max_len: don't sample captions longer than this length + """ + assert dataset in {'coco', 'flickr8k', 'flickr30k'} + # Read Karpathy JSON with open(karpathy_json_path, 'r') as j: data = json.load(j) + # Read image paths and captions for each image train_image_paths = [] train_image_captions = [] val_image_paths = [] @@ -27,6 +41,7 @@ def create_input_files(dataset, karpathy_json_path, image_folder, captions_per_i for img in data['images']: captions = [] for c in img['sentences']: + # Update word frequency word_freq.update(c['tokens']) if len(c['tokens']) <= max_len: captions.append(c['tokens']) @@ -47,10 +62,12 @@ def create_input_files(dataset, karpathy_json_path, image_folder, captions_per_i test_image_paths.append(path) test_image_captions.append(captions) + # Sanity check assert len(train_image_paths) == len(train_image_captions) assert len(val_image_paths) == len(val_image_captions) assert len(test_image_paths) == len(test_image_captions) + # Create word map words = [w for w in word_freq.keys() if word_freq[w] > min_word_freq] word_map = {k: v + 1 for v, k in enumerate(words)} word_map[''] = len(word_map) + 1 @@ -58,18 +75,24 @@ def create_input_files(dataset, karpathy_json_path, image_folder, captions_per_i word_map[''] = len(word_map) + 1 word_map[''] = 0 + # Create a base/root name for all output files base_filename = dataset + '_' + str(captions_per_image) + '_cap_per_img_' + str(min_word_freq) + '_min_word_freq' + # Save word map to a JSON with open(os.path.join(output_folder, 'WORDMAP_' + base_filename + '.json'), 'w') as j: json.dump(word_map, j) + # Sample captions for each image, save images to 
HDF5 file, and captions and their lengths to JSON files seed(123) for impaths, imcaps, split in [(train_image_paths, train_image_captions, 'TRAIN'), (val_image_paths, val_image_captions, 'VAL'), (test_image_paths, test_image_captions, 'TEST')]: with h5py.File(os.path.join(output_folder, split + '_IMAGES_' + base_filename + '.hdf5'), 'a') as h: + # Make a note of the number of captions we are sampling per image h.attrs['captions_per_image'] = captions_per_image + + # Create dataset inside HDF5 file to store images images = h.create_dataset('images', (len(impaths), 3, 256, 256), dtype='uint8') print("\nReading %s images and captions, storing to file...\n" % split) @@ -78,12 +101,17 @@ def create_input_files(dataset, karpathy_json_path, image_folder, captions_per_i caplens = [] for i, path in enumerate(tqdm(impaths)): + + # Sample captions if len(imcaps[i]) < captions_per_image: captions = imcaps[i] + [choice(imcaps[i]) for _ in range(captions_per_image - len(imcaps[i]))] else: captions = sample(imcaps[i], k=captions_per_image) + + # Sanity check assert len(captions) == captions_per_image + # Read images img = imread(impaths[i]) if len(img.shape) == 2: img = img[:, :, np.newaxis] @@ -93,19 +121,24 @@ def create_input_files(dataset, karpathy_json_path, image_folder, captions_per_i assert img.shape == (3, 256, 256) assert np.max(img) <= 255 + # Save image to HDF5 file images[i] = img for j, c in enumerate(captions): + # Encode captions enc_c = [word_map['']] + [word_map.get(word, word_map['']) for word in c] + [ word_map['']] + [word_map['']] * (max_len - len(c)) + # Find caption lengths c_len = len(c) + 2 enc_captions.append(enc_c) caplens.append(c_len) + # Sanity check assert images.shape[0] * captions_per_image == len(enc_captions) == len(caplens) + # Save encoded captions and their lengths to JSON files with open(os.path.join(output_folder, split + '_CAPTIONS_' + base_filename + '.json'), 'w') as j: json.dump(enc_captions, j) @@ -114,11 +147,25 @@ def create_input_files(dataset, karpathy_json_path, image_folder, captions_per_i def init_embedding(embeddings): + """ + Fills embedding tensor with values from the uniform distribution. + + :param embeddings: embedding tensor + """ bias = np.sqrt(3.0 / embeddings.size(1)) torch.nn.init.uniform_(embeddings, -bias, bias) def load_embeddings(emb_file, word_map): + """ + Creates an embedding tensor for the specified word map, for loading into the model. + + :param emb_file: file containing embeddings (stored in GloVe format) + :param word_map: word map + :return: embeddings in the same order as the words in the word map, dimension of embeddings + """ + + # Find embedding dimension with open(emb_file, 'r') as f: emb_dim = len(f.readline().split(' ')) - 1 @@ -136,7 +183,7 @@ def load_embeddings(emb_file, word_map): emb_word = line[0] embedding = list(map(lambda t: float(t), filter(lambda n: n and not n.isspace(), line[1:]))) - # If embeddings are to be limited to train_vocab, ignore word if not in train_vocab + # Ignore word if not in train_vocab if emb_word not in vocab: continue @@ -146,6 +193,12 @@ def load_embeddings(emb_file, word_map): def clip_gradient(optimizer, grad_clip): + """ + Clips gradients computed during backpropagation to avoid explosion of gradients. 
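+
+     The clipping is element-wise: every gradient component is clamped to [-grad_clip, grad_clip]
+     (see the clamp_ call below). If norm-based clipping is preferred instead, PyTorch's built-in
+     helper could be used, e.g. (an alternative, not what this function does):
+
+         torch.nn.utils.clip_grad_norm_(decoder.parameters(), max_norm=grad_clip)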
+ + :param optimizer: optimizer with the gradients to be clipped + :param grad_clip: clip value + """ for group in optimizer.param_groups: for param in group['params']: if param.grad is not None: @@ -154,6 +207,19 @@ def clip_gradient(optimizer, grad_clip): def save_checkpoint(data_name, epoch, epochs_since_improvement, encoder, decoder, encoder_optimizer, decoder_optimizer, bleu4, is_best): + """ + Saves model checkpoint. + + :param data_name: base name of processed dataset + :param epoch: epoch number + :param epochs_since_improvement: number of epochs since last improvement in BLEU-4 score + :param encoder: encoder model + :param decoder: decoder model + :param encoder_optimizer: optimizer to update encoder's weights, if fine-tuning + :param decoder_optimizer: optimizer to update decoder's weights + :param bleu4: validation BLEU-4 score for this epoch + :param is_best: is this checkpoint the best so far? + """ state = {'epoch': epoch, 'epochs_since_improvement': epochs_since_improvement, 'bleu-4': bleu4, @@ -163,6 +229,7 @@ def save_checkpoint(data_name, epoch, epochs_since_improvement, encoder, decoder 'decoder_optimizer': decoder_optimizer} filename = 'checkpoint_' + data_name + '.pth.tar' torch.save(state, filename) + # If this checkpoint is the best so far, store a copy so it doesn't get overwritten by a worse checkpoint if is_best: torch.save(state, 'BEST_' + filename) @@ -191,6 +258,9 @@ def update(self, val, n=1): def adjust_learning_rate(optimizer, shrink_factor): """ Shrinks learning rate by a specified factor. + + :param optimizer: optimizer whose learning rate must be shrunk. + :param shrink_factor: factor in interval (0, 1) to multiply learning rate with. """ print("\nDECAYING learning rate.") @@ -202,6 +272,11 @@ def adjust_learning_rate(optimizer, shrink_factor): def accuracy(scores, targets, k): """ Computes top-k accuracy, from predicted and true labels. + + :param scores: scores from the model + :param targets: true labels + :param k: k in top-k accuracy + :return: top-k accuracy """ batch_size = targets.size(0)