-
Notifications
You must be signed in to change notification settings - Fork 5
/
Copy pathFocalLabelSmoothing.py
49 lines (42 loc) · 1.57 KB
/
FocalLabelSmoothing.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
import torch
import torch.nn as nn
import torch.nn.functional as F
class FocalLoss(nn.Module):
    """Binary focal loss computed from raw logits.

    Down-weights easy examples via the (1 - p_t)**gamma modulating factor
    so that training concentrates on hard, misclassified examples.

    Args:
        reduction: 'sum' or 'mean' reduces the per-element loss; 'none'
            (or any unrecognized value) returns it unreduced.
        alpha: scalar weight multiplied into every element.
        gamma: focusing exponent applied to (1 - p_t).
    """

    def __init__(self, reduction='none', alpha=1, gamma=2):
        super().__init__()
        self.reduction = reduction
        self.alpha = alpha
        self.gamma = gamma

    def forward(self, inputs, targets):
        # Per-element cross-entropy on logits; exp(-ce) recovers p_t, the
        # probability the model assigns to the true class of each element.
        ce = F.binary_cross_entropy_with_logits(inputs, targets, reduction='none')
        p_t = torch.exp(-ce)
        focal = self.alpha * (1. - p_t) ** self.gamma * ce
        if self.reduction == 'sum':
            return focal.sum()
        if self.reduction == 'mean':
            return focal.mean()
        # 'none' and any unknown reduction both yield the unreduced tensor,
        # matching the original fall-through behavior.
        return focal
class SmoothFocalLoss(nn.Module):
    """Focal loss with label smoothing for binary targets.

    Targets are pulled toward 0.5 by ``t -> t*(1-s) + 0.5*s`` before being
    fed to an unreduced :class:`FocalLoss`; the reduction is applied here.

    Args:
        reduction: 'sum' or 'mean' reduces the loss; 'none' (or any
            unrecognized value) returns it per-element.
        alpha: scalar weight forwarded to the inner FocalLoss.
        gamma: focusing exponent forwarded to the inner FocalLoss.
        smoothing: label-smoothing factor in [0, 1).

    Raises:
        ValueError: if ``smoothing`` is outside [0, 1).
    """

    def __init__(self, reduction='none', alpha=1, gamma=2, smoothing=0.0):
        super().__init__()
        # Fail fast with a real exception: the original relied on `assert`
        # inside _smooth, which is stripped under `python -O` and only
        # fired on the first forward pass.
        if not 0 <= smoothing < 1:
            raise ValueError(f"smoothing must be in [0, 1), got {smoothing}")
        self.reduction = reduction
        # Inner loss always returns per-element values; reduction is ours.
        self.focal_loss = FocalLoss(reduction='none', alpha=alpha, gamma=gamma)
        self.smoothing = smoothing

    @staticmethod
    def _smooth(targets: torch.Tensor, smoothing: float = 0.0) -> torch.Tensor:
        """Return targets smoothed toward 0.5: ``t*(1-s) + 0.5*s``."""
        assert 0 <= smoothing < 1
        # no_grad: smoothing is a fixed label transform, not part of the
        # computation graph.
        with torch.no_grad():
            targets = targets * (1.0 - smoothing) + 0.5 * smoothing
        return targets

    def forward(self, inputs, targets):
        targets = SmoothFocalLoss._smooth(targets, self.smoothing)
        loss = self.focal_loss(inputs, targets)
        if self.reduction == 'sum':
            loss = loss.sum()
        elif self.reduction == 'mean':
            loss = loss.mean()
        # 'none' / unknown reductions: return the per-element tensor as-is.
        return loss