-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathacb_mse.py
More file actions
50 lines (38 loc) · 1.93 KB
/
acb_mse.py
File metadata and controls
50 lines (38 loc) · 1.93 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
import torch
class ACBLoss(torch.nn.Module):
    """
    ACB-MSE loss: a weighted sum of two MSE terms, one over the pixels where the
    target is exactly zero and one over all remaining (non-zero) target pixels.
    Each term is averaged over its own pixel set, so the two classes contribute
    independently of how many pixels each contains.
    """

    def __init__(self, zero_weighting: float = 1, nonzero_weighting: float = 1):
        """
        Initializes the ACB-MSE Loss Function class with weighting coefficients.
        Args:
        - zero_weighting: a scalar weighting coefficient for the MSE loss of zero pixels
        - nonzero_weighting: a scalar weighting coefficient for the MSE loss of non-zero pixels
        """
        super().__init__()
        self.zero_weighting = zero_weighting
        self.nonzero_weighting = nonzero_weighting
        self.mse_loss = torch.nn.MSELoss(reduction='mean')

    def _masked_mse(self, reconstructed_image: torch.Tensor, target_image: torch.Tensor,
                    mask: torch.Tensor) -> torch.Tensor:
        """
        Mean squared error restricted to the elements selected by a boolean mask.
        Args:
        - reconstructed_image: predicted values
        - target_image: reference values, same shape as reconstructed_image
        - mask: boolean tensor selecting which elements participate
        Returns:
        - a scalar tensor; zero (still attached to the autograd graph) when the
          mask selects no elements
        """
        predicted = reconstructed_image[mask]
        expected = target_image[mask]
        if predicted.numel() == 0:
            # A mean over an empty set is NaN; an absent pixel class should simply
            # contribute nothing. Summing the empty selection yields a 0 tensor with
            # the input's dtype/device and keeps the result differentiable. Checking
            # numel (tensor metadata) also avoids a device sync, and unlike an isnan
            # check it does not hide genuine NaNs coming from the inputs.
            return predicted.sum()
        return self.mse_loss(predicted, expected)

    def forward(self, reconstructed_image: torch.Tensor, target_image: torch.Tensor) -> torch.Tensor:
        """
        Calculates the weighted mean squared error (MSE) loss between reconstructed_image and target_image.
        The loss for zero pixels in the target_image is weighted by zero_weighting, and the loss for non-zero
        pixels is weighted by nonzero_weighting. A pixel class that is absent from the target contributes 0.
        Args:
        - reconstructed_image: a tensor containing the reconstructed image, e.g. of shape (B, C, H, W)
        - target_image: a tensor of the same shape containing the target image
        Returns:
        - weighted_mse_loss: a scalar tensor containing the weighted MSE loss
        """
        zero_mask = (target_image == 0)
        # Note: NaN target values compare unequal to 0, so they land in the non-zero class.
        nonzero_mask = ~zero_mask

        zero_loss = self._masked_mse(reconstructed_image, target_image, zero_mask)
        nonzero_loss = self._masked_mse(reconstructed_image, target_image, nonzero_mask)

        weighted_mse_loss = (self.zero_weighting * zero_loss) + (self.nonzero_weighting * nonzero_loss)
        return weighted_mse_loss