Skip to content

Commit

Permalink
Ignore <pad> when computing cross-entropy loss and sequence accuracy.
Browse files Browse the repository at this point in the history
Fixes clay-lab/transductions#64
  • Loading branch information
bdusell committed Jun 10, 2022
1 parent 5e650f4 commit 565423d
Show file tree
Hide file tree
Showing 2 changed files with 31 additions and 20 deletions.
16 changes: 10 additions & 6 deletions core/metrics/base_metric.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,13 +56,18 @@ class SequenceAccuracy(BaseMetric):
correct, the sequence scores 1.0; otherwise, it scores 0.0.
"""

def __init__(self, pad_token_id=None):
super().__init__()
self._pad = pad_token_id

def compute(self, prediction: Tensor, target: Tensor):
prediction = prediction.argmax(1)

correct = (prediction == target).prod(axis=1)
total = correct.shape[0]
correct = correct.sum()

correct_tokens = prediction == target
if self._pad is not None:
correct_tokens.logical_or_(target == self._pad)
correct_sequences = torch.all(correct_tokens, dim=1)
correct = correct_sequences.sum()
total = target.size(0)
return correct, total


Expand Down Expand Up @@ -91,7 +96,6 @@ def compute(self, prediction: Tensor, target: Tensor):
correct = prediction[:, self.n] == target[:, self.n]
total = correct.shape[0]
correct = correct.sum()

return correct, total


Expand Down
35 changes: 21 additions & 14 deletions core/trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -169,12 +169,14 @@ def train(self):

early_stoping = EarlyStopping(self._cfg.experiment.hyperparameters)

pad_idx = self._model._decoder.vocab.stoi["<pad>"]

# Metrics
seq_acc = SequenceAccuracy()
tok_acc = TokenAccuracy(self._dataset.target_field.vocab.stoi["<pad>"])
len_acc = LengthAccuracy(self._dataset.target_field.vocab.stoi["<pad>"])
seq_acc = SequenceAccuracy(pad_idx)
tok_acc = TokenAccuracy(pad_idx)
len_acc = LengthAccuracy(pad_idx)
first_acc = NthTokenAccuracy(n=1)
avg_loss = LossMetric(F.cross_entropy)
avg_loss = LossMetric(lambda p, t: F.cross_entropy(p, t, ignore_index=pad_idx))

meter = Meter([seq_acc, tok_acc, len_acc, first_acc, avg_loss])

Expand Down Expand Up @@ -226,7 +228,7 @@ def train(self):
meter(output, target)

# Compute average validation loss
val_loss = F.cross_entropy(output, target)
val_loss = F.cross_entropy(output, target, ignore_index=pad_idx)
V.set_postfix(val_loss="{:4.3f}".format(val_loss.item()))

meter.log(stage="val", step=epoch)
Expand Down Expand Up @@ -288,12 +290,14 @@ def eval(self, eval_cfg: DictConfig):
# Load checkpoint data
self._load_checkpoint(eval_cfg.checkpoint_dir)

pad_idx = self._dataset.target_field.vocab.stoi["<pad>"]

# Create meter
seq_acc = SequenceAccuracy()
tok_acc = TokenAccuracy(self._dataset.target_field.vocab.stoi["<pad>"])
len_acc = LengthAccuracy(self._dataset.target_field.vocab.stoi["<pad>"])
seq_acc = SequenceAccuracy(pad_idx)
tok_acc = TokenAccuracy(pad_idx)
len_acc = LengthAccuracy(pad_idx)
first_acc = NthTokenAccuracy(n=1)
avg_loss = LossMetric(F.cross_entropy)
avg_loss = LossMetric(lambda p, t: F.cross_entropy(p, t, ignore_index=pad_idx))

meter = Meter([seq_acc, tok_acc, len_acc, first_acc, avg_loss])

Expand Down Expand Up @@ -339,9 +343,11 @@ def arith_eval(self, eval_cfg: DictConfig):
# Load checkpoint data
self._load_checkpoint(eval_cfg.checkpoint_dir)

pad_idx = self._dataset.target_field.vocab.stoi["<pad>"]

# Create meter
seq_acc = SequenceAccuracy()
len_acc = LengthAccuracy(self._dataset.target_field.vocab.stoi["<pad>"])
seq_acc = SequenceAccuracy(pad_idx)
len_acc = LengthAccuracy(pad_idx)
object_acc = NthTokenAccuracy(n=5)

meter = Meter([seq_acc, len_acc, object_acc])
Expand Down Expand Up @@ -609,11 +615,12 @@ def fit_tpdn(self, tpdn_cfg: DictConfig):
tpdn.eval()
disp_loss = nn.CrossEntropyLoss()

pad_idx = self._dataset.target_field.vocab.stoi["<pad>"]
meter = Meter(
[
SequenceAccuracy(),
TokenAccuracy(self._dataset.target_field.vocab.stoi["<pad>"]),
LengthAccuracy(self._dataset.target_field.vocab.stoi["<pad>"]),
SequenceAccuracy(pad_idx),
TokenAccuracy(pad_idx),
LengthAccuracy(pad_idx),
NthTokenAccuracy(n=1),
NthTokenAccuracy(n=3),
NthTokenAccuracy(n=5),
Expand Down

0 comments on commit 565423d

Please sign in to comment.