1414import warnings
1515
1616import torch
17- import torch .nn .functional as F
1817from torch .nn .modules .loss import _Loss
1918
2019from monai .networks import one_hot
@@ -68,8 +67,6 @@ def forward(self, y_pred: torch.Tensor, y_true: torch.Tensor) -> torch.Tensor:
6867 y_pred = y_pred [:, 1 :]
6968 y_true = y_true [:, 1 :]
7069
71- # Clip predictions
72- y_pred = torch .clamp (y_pred , self .epsilon , 1.0 - self .epsilon )
7370 axis = list (range (2 , len (y_pred .shape )))
7471
7572 # Calculate true positives (tp), false negatives (fn) and false positives (fp)
@@ -169,7 +166,7 @@ def forward(self, y_pred: torch.Tensor, y_true: torch.Tensor) -> torch.Tensor:
169166 back_ce = (1.0 - self .delta ) * torch .pow (1.0 - y_pred [:, 0 ], self .gamma ) * cross_entropy [:, 0 ]
170167 # (B, C-1, H, W)
171168 fore_ce = self .delta * cross_entropy [:, 1 :]
172-
169+
173170 loss = torch .cat ([back_ce .unsqueeze (1 ), fore_ce ], dim = 1 ) # (B, C, H, W)
174171
175172 # Apply reduction
@@ -276,21 +273,22 @@ def forward(self, y_pred: torch.Tensor, y_true: torch.Tensor) -> torch.Tensor:
276273 if y_true .shape [1 ] != 1 :
277274 y_true = y_true .unsqueeze (1 )
278275 y_true = one_hot (y_true , num_classes = n_pred_ch )
279-
276+
280277 # Ensure y_true has the same shape as y_pred_act
281278 if y_true .shape != y_pred_act .shape :
282- # This can happen if y_true is (B, H, W) and y_pred is (B, 1, H, W) after sigmoid
279+ # This can happen if y_true is (B, H, W) and y_pred is (B, 1, H, W) after sigmoid
283280 if y_true .shape [1 ] != y_pred_act .shape [1 ] and y_true .ndim == y_pred_act .ndim - 1 :
284- y_true = y_true .unsqueeze (1 ) # Add channel dim
285-
286- if y_true .shape != y_pred_act .shape :
287- raise ValueError (f"ground truth has different shape ({ y_true .shape } ) from input ({ y_pred_act .shape } ) " \
288- f"after activations/one-hot" )
281+ y_true = y_true .unsqueeze (1 ) # Add channel dim
289282
283+ if y_true .shape != y_pred_act .shape :
284+ raise ValueError (
285+ f"ground truth has different shape ({ y_true .shape } ) from input ({ y_pred_act .shape } ) "
286+ f"after activations/one-hot"
287+ )
290288
291289 f_loss = self .asy_focal_loss (y_pred_act , y_true )
292290 t_loss = self .asy_focal_tversky_loss (y_pred_act , y_true )
293291
294292 loss : torch .Tensor = self .lambda_focal * f_loss + (1 - self .lambda_focal ) * t_loss
295293
296- return loss
294+ return loss
0 commit comments