Skip to content

Commit b4e0fcc

Browse files
pre-commit-ci[bot]ytl0623
authored andcommitted
[pre-commit.ci] auto fixes from pre-commit.com hooks
for more information, see https://pre-commit.ci Signed-off-by: ytl0623 <[email protected]>
1 parent ad83444 commit b4e0fcc

File tree

1 file changed

+10
-12
lines changed

1 file changed

+10
-12
lines changed

monai/losses/unified_focal_loss.py

Lines changed: 10 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,6 @@
1414
import warnings
1515

1616
import torch
17-
import torch.nn.functional as F
1817
from torch.nn.modules.loss import _Loss
1918

2019
from monai.networks import one_hot
@@ -68,8 +67,6 @@ def forward(self, y_pred: torch.Tensor, y_true: torch.Tensor) -> torch.Tensor:
6867
y_pred = y_pred[:, 1:]
6968
y_true = y_true[:, 1:]
7069

71-
# Clip predictions
72-
y_pred = torch.clamp(y_pred, self.epsilon, 1.0 - self.epsilon)
7370
axis = list(range(2, len(y_pred.shape)))
7471

7572
# Calculate true positives (tp), false negatives (fn) and false positives (fp)
@@ -169,7 +166,7 @@ def forward(self, y_pred: torch.Tensor, y_true: torch.Tensor) -> torch.Tensor:
169166
back_ce = (1.0 - self.delta) * torch.pow(1.0 - y_pred[:, 0], self.gamma) * cross_entropy[:, 0]
170167
# (B, C-1, H, W)
171168
fore_ce = self.delta * cross_entropy[:, 1:]
172-
169+
173170
loss = torch.cat([back_ce.unsqueeze(1), fore_ce], dim=1) # (B, C, H, W)
174171

175172
# Apply reduction
@@ -276,21 +273,22 @@ def forward(self, y_pred: torch.Tensor, y_true: torch.Tensor) -> torch.Tensor:
276273
if y_true.shape[1] != 1:
277274
y_true = y_true.unsqueeze(1)
278275
y_true = one_hot(y_true, num_classes=n_pred_ch)
279-
276+
280277
# Ensure y_true has the same shape as y_pred_act
281278
if y_true.shape != y_pred_act.shape:
282-
# This can happen if y_true is (B, H, W) and y_pred is (B, 1, H, W) after sigmoid
279+
# This can happen if y_true is (B, H, W) and y_pred is (B, 1, H, W) after sigmoid
283280
if y_true.shape[1] != y_pred_act.shape[1] and y_true.ndim == y_pred_act.ndim - 1:
284-
y_true = y_true.unsqueeze(1) # Add channel dim
285-
286-
if y_true.shape != y_pred_act.shape:
287-
raise ValueError(f"ground truth has different shape ({y_true.shape}) from input ({y_pred_act.shape}) " \
288-
f"after activations/one-hot")
281+
y_true = y_true.unsqueeze(1) # Add channel dim
289282

283+
if y_true.shape != y_pred_act.shape:
284+
raise ValueError(
285+
f"ground truth has different shape ({y_true.shape}) from input ({y_pred_act.shape}) "
286+
f"after activations/one-hot"
287+
)
290288

291289
f_loss = self.asy_focal_loss(y_pred_act, y_true)
292290
t_loss = self.asy_focal_tversky_loss(y_pred_act, y_true)
293291

294292
loss: torch.Tensor = self.lambda_focal * f_loss + (1 - self.lambda_focal) * t_loss
295293

296-
return loss
294+
return loss

0 commit comments

Comments
 (0)