Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

The SSP (self-supervised) attack #11

Merged
merged 3 commits into from
May 9, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,7 @@ Gradient-based attacks:
| :--------: | :-----------: | :------------------------------------------------------------------------------------------------------------------------- | :--------------------- |
| FGSM | $\ell_\infty$ | [Explaining and Harnessing Adversarial Examples](https://arxiv.org/abs/1412.6572) | `torchattack.FGSM` |
| PGD | $\ell_\infty$ | [Towards Deep Learning Models Resistant to Adversarial Attacks](https://arxiv.org/abs/1706.06083) | `torchattack.PGD` |
| PGD | $\ell_2$ | [Towards Deep Learning Models Resistant to Adversarial Attacks](https://arxiv.org/abs/1706.06083) | `torchattack.PGDL2` |
| PGD (L2) | $\ell_2$ | [Towards Deep Learning Models Resistant to Adversarial Attacks](https://arxiv.org/abs/1706.06083) | `torchattack.PGDL2` |
| MI-FGSM | $\ell_\infty$ | [Boosting Adversarial Attacks with Momentum](https://arxiv.org/abs/1710.06081) | `torchattack.MIFGSM` |
| DI-FGSM | $\ell_\infty$ | [Improving Transferability of Adversarial Examples with Input Diversity](https://arxiv.org/abs/1803.06978) | `torchattack.DIFGSM` |
| TI-FGSM | $\ell_\infty$ | [Evading Defenses to Transferable Adversarial Examples by Translation-Invariant Attacks](https://arxiv.org/abs/1904.02884) | `torchattack.TIFGSM` |
Expand Down
2 changes: 1 addition & 1 deletion src/torchattack/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
from torchattack.vmifgsm import VMIFGSM
from torchattack.vnifgsm import VNIFGSM

__version__ = '0.4.0'
__version__ = '0.5.0'

__all__ = [
'Admix',
Expand Down
14 changes: 9 additions & 5 deletions src/torchattack/runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,7 @@ def run_attack(
attack_cfg: dict | None = None,
model_name: str = 'resnet50',
max_samples: int = 100,
batch_size: int = 8,
batch_size: int = 16,
) -> None:
"""Helper function to run attacks in `__main__`.

Expand Down Expand Up @@ -74,7 +74,11 @@ def run_attack(
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = tv.models.get_model(name=model_name, weights='DEFAULT').to(device).eval()
transform = tv.transforms.Compose(
[tv.transforms.Resize([224]), tv.transforms.ToTensor()]
[
tv.transforms.Resize([256]),
tv.transforms.CenterCrop(224),
tv.transforms.ToTensor(),
]
)
normalize = tv.transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])

Expand Down Expand Up @@ -109,11 +113,11 @@ def run_attack(
adv_outs = model(normalize(advs))
frm.update(y, cln_outs, adv_outs)

# Save one batch of adversarial images
if i == 1:
# Save first batch of adversarial images
if i == 0:
saved_imgs = advs.detach().cpu().mul(255).to(torch.uint8)
img_grid = tv.utils.make_grid(saved_imgs, nrow=4)
tv.io.write_png(img_grid, 'adv_batch.png')
tv.io.write_png(img_grid, 'adv_batch_0.png')

# Print results
print(f'Clean accuracy: {frm.compute_clean_accuracy():.2%}')
Expand Down
88 changes: 88 additions & 0 deletions src/torchattack/ssp.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,88 @@
from typing import Callable

import torch
import torch.nn as nn
import torchvision as tv

from torchattack.base import Attack


class PerceptualCriteria(nn.Module):
def __init__(self, ssp_layer: int) -> None:
super().__init__()

# Use pretrained VGG16 for perceptual loss
vgg16 = tv.models.vgg16(weights='DEFAULT')

# Initialize perceptual model and loss function
self.perceptual_model = nn.Sequential(*list(vgg16.features))[:ssp_layer]
self.perceptual_model.eval()
self.loss_fn = nn.MSELoss()

def forward(self, x: torch.Tensor, xadv: torch.Tensor) -> torch.Tensor:
return self.loss_fn(self.perceptual_model(x), self.perceptual_model(xadv))

def __repr__(self) -> str:
return f'{self.__class__.__name__}(ssp_layer={self.ssp_layer})'


class SSP(Attack):
def __init__(
self,
model: nn.Module,
normalize: Callable[[torch.Tensor], torch.Tensor] | None,
device: torch.device | None = None,
eps: float = 16 / 255,
steps: int = 100,
alpha: float | None = None,
ssp_layer: int = 16,
clip_min: float = 0.0,
clip_max: float = 1.0,
targeted: bool = False,
) -> None:
super().__init__(normalize, device)

self.model = model
self.eps = eps
self.steps = steps
self.alpha = alpha
self.ssp_layer = ssp_layer
self.clip_min = clip_min
self.clip_max = clip_max
self.targeted = targeted

self.perceptual_criteria = PerceptualCriteria(ssp_layer).to(device)

def forward(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
delta = torch.randn_like(x, requires_grad=True)

# If alpha is not given, set to eps / steps
if self.alpha is None:
self.alpha = self.eps / self.steps

for _ in range(self.steps):
xadv = x + delta
loss = self.perceptual_criteria(self.normalize(x), self.normalize(xadv))
loss.backward()

if delta.grad is None:
continue

# Update delta
g = delta.grad.data.sign()

delta.data = delta.data + self.alpha * g
delta.data = torch.clamp(delta.data, -self.eps, self.eps)
delta.data = torch.clamp(x + delta.data, self.clip_min, self.clip_max) - x

# Zero out gradient
delta.grad.detach_()
delta.grad.zero_()

return x + delta


if __name__ == '__main__':
from torchattack.runner import run_attack

run_attack(SSP, attack_cfg={'eps': 16 / 255, 'ssp_layer': 16})
Loading