Skip to content

Commit

Permalink
Fix SSP loss backprop issue
Browse files Browse the repository at this point in the history
  • Loading branch information
spencerwooo committed Dec 10, 2024
1 parent 3ddb226 commit 95d174b
Showing 1 changed file with 4 additions and 2 deletions.
6 changes: 4 additions & 2 deletions torchattack/ssp.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,8 +22,10 @@ def __init__(self, ssp_layer: int) -> None:
self.loss_fn = nn.MSELoss()

def forward(self, x: torch.Tensor, xadv: torch.Tensor) -> torch.Tensor:
loss = self.loss_fn(self.perceptual_model(x), self.perceptual_model(xadv))
return torch.tensor(loss, dtype=torch.float32)
x_outs = self.perceptual_model(x)
xadv_outs = self.perceptual_model(xadv)
loss: torch.Tensor = self.loss_fn(x_outs, xadv_outs)
return loss

def __repr__(self) -> str:
return f'{self.__class__.__name__}(ssp_layer={self.ssp_layer})'
Expand Down

0 comments on commit 95d174b

Please sign in to comment.