From ca674e1f6ab81779d97cccc3854556d0e9f138c5 Mon Sep 17 00:00:00 2001
From: Mario
Date: Mon, 1 Apr 2019 09:43:03 +0200
Subject: [PATCH] update to solve issue #57

---
 train.py | 153 ++++++++++++++++++++++++++++---------------------------
 1 file changed, 78 insertions(+), 75 deletions(-)

diff --git a/train.py b/train.py
index 4a6ba8c7f..94e4960cd 100644
--- a/train.py
+++ b/train.py
@@ -245,81 +245,84 @@ def validate(val_loader, encoder, decoder, criterion):
     references = list()  # references (true captions) for calculating BLEU-4 score
     hypotheses = list()  # hypotheses (predictions)
 
-    # Batches
-    for i, (imgs, caps, caplens, allcaps) in enumerate(val_loader):
-
-        # Move to device, if available
-        imgs = imgs.to(device)
-        caps = caps.to(device)
-        caplens = caplens.to(device)
-
-        # Forward prop.
-        if encoder is not None:
-            imgs = encoder(imgs)
-        scores, caps_sorted, decode_lengths, alphas, sort_ind = decoder(imgs, caps, caplens)
-
-        # Since we decoded starting with <start>, the targets are all words after <start>, up to <end>
-        targets = caps_sorted[:, 1:]
-
-        # Remove timesteps that we didn't decode at, or are pads
-        # pack_padded_sequence is an easy trick to do this
-        scores_copy = scores.clone()
-        scores, _ = pack_padded_sequence(scores, decode_lengths, batch_first=True)
-        targets, _ = pack_padded_sequence(targets, decode_lengths, batch_first=True)
-
-        # Calculate loss
-        loss = criterion(scores, targets)
-
-        # Add doubly stochastic attention regularization
-        loss += alpha_c * ((1. - alphas.sum(dim=1)) ** 2).mean()
-
-        # Keep track of metrics
-        losses.update(loss.item(), sum(decode_lengths))
-        top5 = accuracy(scores, targets, 5)
-        top5accs.update(top5, sum(decode_lengths))
-        batch_time.update(time.time() - start)
-
-        start = time.time()
-
-        if i % print_freq == 0:
-            print('Validation: [{0}/{1}]\t'
-                  'Batch Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t'
-                  'Loss {loss.val:.4f} ({loss.avg:.4f})\t'
-                  'Top-5 Accuracy {top5.val:.3f} ({top5.avg:.3f})\t'.format(i, len(val_loader), batch_time=batch_time,
-                                                                            loss=losses, top5=top5accs))
-
-        # Store references (true captions), and hypothesis (prediction) for each image
-        # If for n images, we have n hypotheses, and references a, b, c... for each image, we need -
-        # references = [[ref1a, ref1b, ref1c], [ref2a, ref2b], ...], hypotheses = [hyp1, hyp2, ...]
-
-        # References
-        allcaps = allcaps[sort_ind]  # because images were sorted in the decoder
-        for j in range(allcaps.shape[0]):
-            img_caps = allcaps[j].tolist()
-            img_captions = list(
-                map(lambda c: [w for w in c if w not in {word_map['<start>'], word_map['<pad>']}],
-                    img_caps))  # remove <start> and pads
-            references.append(img_captions)
-
-        # Hypotheses
-        _, preds = torch.max(scores_copy, dim=2)
-        preds = preds.tolist()
-        temp_preds = list()
-        for j, p in enumerate(preds):
-            temp_preds.append(preds[j][:decode_lengths[j]])  # remove pads
-        preds = temp_preds
-        hypotheses.extend(preds)
-
-        assert len(references) == len(hypotheses)
-
-    # Calculate BLEU-4 scores
-    bleu4 = corpus_bleu(references, hypotheses)
-
-    print(
-        '\n * LOSS - {loss.avg:.3f}, TOP-5 ACCURACY - {top5.avg:.3f}, BLEU-4 - {bleu}\n'.format(
-            loss=losses,
-            top5=top5accs,
-            bleu=bleu4))
+    # explicitly disable gradient calculation to avoid CUDA memory error
+    # solves the issue #57
+    with torch.no_grad():
+        # Batches
+        for i, (imgs, caps, caplens, allcaps) in enumerate(val_loader):
+
+            # Move to device, if available
+            imgs = imgs.to(device)
+            caps = caps.to(device)
+            caplens = caplens.to(device)
+
+            # Forward prop.
+            if encoder is not None:
+                imgs = encoder(imgs)
+            scores, caps_sorted, decode_lengths, alphas, sort_ind = decoder(imgs, caps, caplens)
+
+            # Since we decoded starting with <start>, the targets are all words after <start>, up to <end>
+            targets = caps_sorted[:, 1:]
+
+            # Remove timesteps that we didn't decode at, or are pads
+            # pack_padded_sequence is an easy trick to do this
+            scores_copy = scores.clone()
+            scores, _ = pack_padded_sequence(scores, decode_lengths, batch_first=True)
+            targets, _ = pack_padded_sequence(targets, decode_lengths, batch_first=True)
+
+            # Calculate loss
+            loss = criterion(scores, targets)
+
+            # Add doubly stochastic attention regularization
+            loss += alpha_c * ((1. - alphas.sum(dim=1)) ** 2).mean()
+
+            # Keep track of metrics
+            losses.update(loss.item(), sum(decode_lengths))
+            top5 = accuracy(scores, targets, 5)
+            top5accs.update(top5, sum(decode_lengths))
+            batch_time.update(time.time() - start)
+
+            start = time.time()
+
+            if i % print_freq == 0:
+                print('Validation: [{0}/{1}]\t'
+                      'Batch Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t'
+                      'Loss {loss.val:.4f} ({loss.avg:.4f})\t'
+                      'Top-5 Accuracy {top5.val:.3f} ({top5.avg:.3f})\t'.format(i, len(val_loader), batch_time=batch_time,
+                                                                                loss=losses, top5=top5accs))
+
+            # Store references (true captions), and hypothesis (prediction) for each image
+            # If for n images, we have n hypotheses, and references a, b, c... for each image, we need -
+            # references = [[ref1a, ref1b, ref1c], [ref2a, ref2b], ...], hypotheses = [hyp1, hyp2, ...]
+
+            # References
+            allcaps = allcaps[sort_ind]  # because images were sorted in the decoder
+            for j in range(allcaps.shape[0]):
+                img_caps = allcaps[j].tolist()
+                img_captions = list(
+                    map(lambda c: [w for w in c if w not in {word_map['<start>'], word_map['<pad>']}],
+                        img_caps))  # remove <start> and pads
+                references.append(img_captions)
+
+            # Hypotheses
+            _, preds = torch.max(scores_copy, dim=2)
+            preds = preds.tolist()
+            temp_preds = list()
+            for j, p in enumerate(preds):
+                temp_preds.append(preds[j][:decode_lengths[j]])  # remove pads
+            preds = temp_preds
+            hypotheses.extend(preds)
+
+            assert len(references) == len(hypotheses)
+
+        # Calculate BLEU-4 scores
+        bleu4 = corpus_bleu(references, hypotheses)
+
+        print(
+            '\n * LOSS - {loss.avg:.3f}, TOP-5 ACCURACY - {top5.avg:.3f}, BLEU-4 - {bleu}\n'.format(
+                loss=losses,
+                top5=top5accs,
+                bleu=bleu4))
 
     return bleu4
 
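The change itself is small: the existing validation loop is re-indented one level under torch.no_grad(), so autograd never builds a computation graph for the forward passes done during evaluation. Without it, every validation batch keeps its intermediate activations alive for a backward pass that never happens, which is what exhausted CUDA memory in issue #57. Below is a minimal sketch of the same pattern in isolation; the names (evaluate, model, criterion, val_loader, device) are illustrative placeholders, not the ones used in train.py.

    import torch

    def evaluate(model, val_loader, criterion, device):
        """Run one validation pass without building an autograd graph."""
        model.eval()  # switch dropout / batch-norm layers to eval behaviour
        total_loss, n_batches = 0.0, 0
        with torch.no_grad():  # no graph is recorded, so GPU memory stays flat
            for inputs, targets in val_loader:
                inputs, targets = inputs.to(device), targets.to(device)
                outputs = model(inputs)
                total_loss += criterion(outputs, targets).item()
                n_batches += 1
        return total_loss / max(n_batches, 1)

torch.no_grad can also be applied as a decorator on validate() itself for the same effect; wrapping the loop, as this patch does, keeps the diff local to the function body and leaves the function signature untouched.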