Skip to content

Commit

Permalink
Merge pull request #13 from CHAOYUHONG:main
Browse files Browse the repository at this point in the history
SSA DCT transform and num_spectrum not applied
  • Loading branch information
spencerwooo authored Nov 14, 2024
2 parents a49ae05 + 9049f56 commit 6aed62c
Showing 1 changed file with 24 additions and 11 deletions.
35 changes: 24 additions & 11 deletions src/torchattack/ssa.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,7 +68,6 @@ def forward(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
The perturbed images if successful. Shape: (N, C, H, W).
"""

g = torch.zeros_like(x)
delta = torch.zeros_like(x, requires_grad=True)

# If alpha is not given, set to eps / steps
Expand All @@ -77,19 +76,31 @@ def forward(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:

# Perform SSA
for _ in range(self.steps):
# Compute loss
outs = self.model(self.normalize(x + delta))
loss = self.lossfn(outs, y)
g = torch.zeros_like(x)

if self.targeted:
loss = -loss
for _ in range(self.num_spectrum):
# Frequency transformation (dct + idct)
x_adv = self.transform(x + delta)

# Compute gradient
loss.backward()
# Compute loss
outs = self.model(self.normalize(x_adv))
loss = self.lossfn(outs, y)

if self.targeted:
loss = -loss

# Compute gradient
loss.backward()

# Accumulate gradient
g += delta.grad

if delta.grad is None:
continue

# Average gradient over num_spectrum
g /= self.num_spectrum

# Apply momentum term
g = self.decay * g + delta.grad / torch.mean(
torch.abs(delta.grad), dim=(1, 2, 3), keepdim=True
Expand All @@ -106,17 +117,19 @@ def forward(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:

return x + delta

def transform(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
def transform(self, x: torch.Tensor) -> torch.Tensor:
b = x.shape[0]
gauss = torch.randn(b, 3, 224, 224) * self.eps # must be a multiple of 8

# H and W must be a multiple of 8
gauss = torch.randn(b, 3, 224, 224, device=x.device) * self.eps

x_dct = self._dct_2d(x + gauss)
mask = torch.rand_like(x) * 2 * self.rho + 1 - self.rho
x_idct = self._idct_2d(x_dct * mask)

return x_idct

def _dct(self, x, norm=None):
def _dct(self, x: torch.Tensor, norm: str | None = None) -> torch.Tensor:
"""
Discrete Cosine Transform, Type II (a.k.a. the DCT)
(This code is copied from https://github.com/yuyang-long/SSA/blob/master/dct.py)
Expand Down

0 comments on commit 6aed62c

Please sign in to comment.