Skip to content

Commit

Permalink
Add PyTorch implementation of Label smoothing
Browse files Browse the repository at this point in the history
  • Loading branch information
YeonwooSung committed Jul 17, 2022
1 parent 6bec5dd commit bcb3cc2
Showing 1 changed file with 75 additions and 0 deletions.
75 changes: 75 additions & 0 deletions MachineLearning/DeepLearning/LossFunction/src/label_smoothing.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,75 @@
import torch
import numpy as np
import torch.nn as nn
import torch.nn.functional as F
from torch.autograd import Variable
from torch.nn.modules.loss import _WeightedLoss



class LabelSmoothingLoss(nn.Module):
def __init__(self, classes, smoothing=0.0, dim=-1, weight = None):
"""if smoothing == 0, it's one-hot method
if 0 < smoothing < 1, it's smooth method
"""
super(LabelSmoothingLoss, self).__init__()
self.confidence = 1.0 - smoothing
self.smoothing = smoothing
self.weight = weight
self.cls = classes
self.dim = dim

def forward(self, pred, target):
assert 0 <= self.smoothing < 1
pred = pred.log_softmax(dim=self.dim)

if self.weight is not None:
pred = pred * self.weight.unsqueeze(0)

with torch.no_grad():
true_dist = torch.zeros_like(pred)
true_dist.fill_(self.smoothing / (self.cls - 1))
true_dist.scatter_(1, target.data.unsqueeze(1), self.confidence)
return torch.mean(torch.sum(-true_dist * pred, dim=self.dim))



class LabelSmoothing(nn.Module):
"""NLL loss with label smoothing.
"""
def __init__(self, smoothing=0.0):
"""Constructor for the LabelSmoothing module.
:param smoothing: label smoothing factor
"""
super(LabelSmoothing, self).__init__()
self.confidence = 1.0 - smoothing
self.smoothing = smoothing

def forward(self, x, target):
logprobs = nn.functional.log_softmax(x, dim=-1)
nll_loss = -logprobs.gather(dim=-1, index=target.unsqueeze(1))
nll_loss = nll_loss.squeeze(1)
smooth_loss = -logprobs.mean(dim=-1)
loss = self.confidence * nll_loss + self.smoothing * smooth_loss
return loss.mean()


if __name__=="__main__":
# Label smoothing (1)
crit = LabelSmoothingLoss(smoothing=0.3)
predict = torch.FloatTensor([[0, 0.2, 0.7, 0.1, 0],
[0, 0.9, 0.2, 0.2, 1],
[1, 0.2, 0.7, 0.9, 1]])

v = crit(Variable(predict),
Variable(torch.LongTensor([2, 1, 0])))
print(v)

# Label smoothing (2)
crit = LabelSmoothing(smoothing=0.3)
predict = torch.FloatTensor([[0, 0.2, 0.7, 0.1, 0],
[0, 0.9, 0.2, 0.2, 1],
[1, 0.2, 0.7, 0.9, 1]])
v = crit(Variable(predict),
Variable(torch.LongTensor([2, 1, 0])))
print(v)

0 comments on commit bcb3cc2

Please sign in to comment.