
Commit

update to solve issue #57
kmario23 committed Apr 1, 2019
1 parent b3b8263 commit ca674e1
Showing 1 changed file with 78 additions and 75 deletions.
train.py: 153 changes (78 additions, 75 deletions)
@@ -245,81 +245,84 @@ def validate(val_loader, encoder, decoder, criterion):
     references = list()  # references (true captions) for calculating BLEU-4 score
     hypotheses = list()  # hypotheses (predictions)
 
-    # Batches
-    for i, (imgs, caps, caplens, allcaps) in enumerate(val_loader):
-
-        # Move to device, if available
-        imgs = imgs.to(device)
-        caps = caps.to(device)
-        caplens = caplens.to(device)
-
-        # Forward prop.
-        if encoder is not None:
-            imgs = encoder(imgs)
-        scores, caps_sorted, decode_lengths, alphas, sort_ind = decoder(imgs, caps, caplens)
-
-        # Since we decoded starting with <start>, the targets are all words after <start>, up to <end>
-        targets = caps_sorted[:, 1:]
-
-        # Remove timesteps that we didn't decode at, or are pads
-        # pack_padded_sequence is an easy trick to do this
-        scores_copy = scores.clone()
-        scores, _ = pack_padded_sequence(scores, decode_lengths, batch_first=True)
-        targets, _ = pack_padded_sequence(targets, decode_lengths, batch_first=True)
-
-        # Calculate loss
-        loss = criterion(scores, targets)
-
-        # Add doubly stochastic attention regularization
-        loss += alpha_c * ((1. - alphas.sum(dim=1)) ** 2).mean()
-
-        # Keep track of metrics
-        losses.update(loss.item(), sum(decode_lengths))
-        top5 = accuracy(scores, targets, 5)
-        top5accs.update(top5, sum(decode_lengths))
-        batch_time.update(time.time() - start)
-
-        start = time.time()
-
-        if i % print_freq == 0:
-            print('Validation: [{0}/{1}]\t'
-                  'Batch Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t'
-                  'Loss {loss.val:.4f} ({loss.avg:.4f})\t'
-                  'Top-5 Accuracy {top5.val:.3f} ({top5.avg:.3f})\t'.format(i, len(val_loader), batch_time=batch_time,
-                                                                            loss=losses, top5=top5accs))
-
-        # Store references (true captions), and hypothesis (prediction) for each image
-        # If for n images, we have n hypotheses, and references a, b, c... for each image, we need -
-        # references = [[ref1a, ref1b, ref1c], [ref2a, ref2b], ...], hypotheses = [hyp1, hyp2, ...]
-
-        # References
-        allcaps = allcaps[sort_ind]  # because images were sorted in the decoder
-        for j in range(allcaps.shape[0]):
-            img_caps = allcaps[j].tolist()
-            img_captions = list(
-                map(lambda c: [w for w in c if w not in {word_map['<start>'], word_map['<pad>']}],
-                    img_caps))  # remove <start> and pads
-            references.append(img_captions)
-
-        # Hypotheses
-        _, preds = torch.max(scores_copy, dim=2)
-        preds = preds.tolist()
-        temp_preds = list()
-        for j, p in enumerate(preds):
-            temp_preds.append(preds[j][:decode_lengths[j]])  # remove pads
-        preds = temp_preds
-        hypotheses.extend(preds)
-
-        assert len(references) == len(hypotheses)
-
-    # Calculate BLEU-4 scores
-    bleu4 = corpus_bleu(references, hypotheses)
-
-    print(
-        '\n * LOSS - {loss.avg:.3f}, TOP-5 ACCURACY - {top5.avg:.3f}, BLEU-4 - {bleu}\n'.format(
-            loss=losses,
-            top5=top5accs,
-            bleu=bleu4))
+    # explicitly disable gradient calculation to avoid CUDA memory error
+    # solves the issue #57
+    with torch.no_grad():
+        # Batches
+        for i, (imgs, caps, caplens, allcaps) in enumerate(val_loader):
+
+            # Move to device, if available
+            imgs = imgs.to(device)
+            caps = caps.to(device)
+            caplens = caplens.to(device)
+
+            # Forward prop.
+            if encoder is not None:
+                imgs = encoder(imgs)
+            scores, caps_sorted, decode_lengths, alphas, sort_ind = decoder(imgs, caps, caplens)
+
+            # Since we decoded starting with <start>, the targets are all words after <start>, up to <end>
+            targets = caps_sorted[:, 1:]
+
+            # Remove timesteps that we didn't decode at, or are pads
+            # pack_padded_sequence is an easy trick to do this
+            scores_copy = scores.clone()
+            scores, _ = pack_padded_sequence(scores, decode_lengths, batch_first=True)
+            targets, _ = pack_padded_sequence(targets, decode_lengths, batch_first=True)
+
+            # Calculate loss
+            loss = criterion(scores, targets)
+
+            # Add doubly stochastic attention regularization
+            loss += alpha_c * ((1. - alphas.sum(dim=1)) ** 2).mean()
+
+            # Keep track of metrics
+            losses.update(loss.item(), sum(decode_lengths))
+            top5 = accuracy(scores, targets, 5)
+            top5accs.update(top5, sum(decode_lengths))
+            batch_time.update(time.time() - start)
+
+            start = time.time()
+
+            if i % print_freq == 0:
+                print('Validation: [{0}/{1}]\t'
+                      'Batch Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t'
+                      'Loss {loss.val:.4f} ({loss.avg:.4f})\t'
+                      'Top-5 Accuracy {top5.val:.3f} ({top5.avg:.3f})\t'.format(i, len(val_loader), batch_time=batch_time,
+                                                                                loss=losses, top5=top5accs))
+
+            # Store references (true captions), and hypothesis (prediction) for each image
+            # If for n images, we have n hypotheses, and references a, b, c... for each image, we need -
+            # references = [[ref1a, ref1b, ref1c], [ref2a, ref2b], ...], hypotheses = [hyp1, hyp2, ...]
+
+            # References
+            allcaps = allcaps[sort_ind]  # because images were sorted in the decoder
+            for j in range(allcaps.shape[0]):
+                img_caps = allcaps[j].tolist()
+                img_captions = list(
+                    map(lambda c: [w for w in c if w not in {word_map['<start>'], word_map['<pad>']}],
+                        img_caps))  # remove <start> and pads
+                references.append(img_captions)
+
+            # Hypotheses
+            _, preds = torch.max(scores_copy, dim=2)
+            preds = preds.tolist()
+            temp_preds = list()
+            for j, p in enumerate(preds):
+                temp_preds.append(preds[j][:decode_lengths[j]])  # remove pads
+            preds = temp_preds
+            hypotheses.extend(preds)
+
+            assert len(references) == len(hypotheses)
+
+        # Calculate BLEU-4 scores
+        bleu4 = corpus_bleu(references, hypotheses)
+
+        print(
+            '\n * LOSS - {loss.avg:.3f}, TOP-5 ACCURACY - {top5.avg:.3f}, BLEU-4 - {bleu}\n'.format(
+                loss=losses,
+                top5=top5accs,
+                bleu=bleu4))
 
     return bleu4
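The change wraps the whole validation pass in torch.no_grad(), so PyTorch records no autograd graph for these batches and the intermediate activations that would otherwise be retained for a backward pass are freed immediately; that is what resolves the CUDA out-of-memory error in issue #57. A minimal sketch of the pattern, using a stand-in linear model rather than this repository's encoder/decoder:

import torch

model = torch.nn.Linear(10, 2)   # stand-in model, not the tutorial's encoder/decoder
x = torch.randn(4, 10)

# Outside no_grad, the forward pass builds an autograd graph and keeps
# intermediate activations in (GPU) memory until the graph is freed.
y = model(x)
print(y.requires_grad)  # True

# Inside no_grad, no graph is recorded, so validation batches do not
# accumulate activation memory.
with torch.no_grad():
    y = model(x)
print(y.requires_grad)  # False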

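The "easy trick" mentioned in the comments is that pack_padded_sequence flattens a padded (batch, time, ...) tensor down to only the valid timesteps of each sequence, so the loss never sees pads. A self-contained sketch with toy shapes (the sizes and the pad index are illustrative, not taken from train.py):

import torch
import torch.nn.functional as F
from torch.nn.utils.rnn import pack_padded_sequence

scores = torch.randn(2, 4, 5)             # (batch, max_time, vocab_size)
targets = torch.tensor([[3, 1, 2, 0],
                        [4, 2, 0, 0]])    # 0 stands in for <pad>
lengths = [4, 2]                          # valid timesteps per row, sorted descending

# Packing keeps the first lengths[i] timesteps of row i and flattens them:
# scores -> (4 + 2, 5), targets -> (4 + 2,)
packed_scores = pack_padded_sequence(scores, lengths, batch_first=True).data
packed_targets = pack_padded_sequence(targets, lengths, batch_first=True).data

loss = F.cross_entropy(packed_scores, packed_targets)  # pads never contribute

The diff tuple-unpacks the returned PackedSequence (scores, _ = ...), which works on the older PyTorch versions this tutorial targeted; on recent versions the .data attribute used above is the safe access.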
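The doubly stochastic attention term (from Show, Attend and Tell) penalizes each spatial position's attention weights for not summing to 1 across decode timesteps, encouraging the decoder to attend to every part of the image over the course of the caption. A worked toy computation with an invented alphas tensor:

import torch

alpha_c = 1.0                          # regularization weight (hyperparameter)
# alphas: (batch, decode_timesteps, num_pixels) attention weights
alphas = torch.tensor([[[0.7, 0.3],
                        [0.2, 0.8]]])  # 1 item, 2 timesteps, 2 pixels

per_pixel = alphas.sum(dim=1)          # sums over timesteps: [[0.9, 1.1]]

# Squared deviation from 1, averaged: ((0.1)**2 + (0.1)**2) / 2 = 0.01
penalty = alpha_c * ((1. - per_pixel) ** 2).mean()
print(penalty)                         # tensor(0.0100)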

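The references/hypotheses layout described in the comments is exactly what NLTK's corpus_bleu expects (assuming, as the call suggests, that corpus_bleu here is NLTK's): one list of reference token lists per image and one hypothesis token list per image, aligned by index. A toy example with made-up captions:

from nltk.translate.bleu_score import corpus_bleu

# One inner list per image; multiple references per image are allowed.
references = [
    [['a', 'dog', 'runs', 'on', 'grass'],
     ['the', 'dog', 'is', 'running', 'outside']],   # image 1
    [['a', 'cat', 'sleeps', 'on', 'the', 'mat']],   # image 2
]
# Exactly one hypothesis per image, aligned with references by index.
hypotheses = [
    ['a', 'dog', 'runs', 'on', 'grass'],
    ['a', 'cat', 'sleeps', 'on', 'a', 'mat'],
]

# Default weights (0.25, 0.25, 0.25, 0.25) make this BLEU-4.
bleu4 = corpus_bleu(references, hypotheses)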