Skip to content

Commit

Permalink
Fix SSP repr, add docs
Browse files Browse the repository at this point in the history
  • Loading branch information
spencerwooo committed May 9, 2024
1 parent 17c3ab2 commit 962d23d
Show file tree
Hide file tree
Showing 3 changed files with 39 additions and 4 deletions.
4 changes: 3 additions & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@ attack = FGSM(model, normalize, device)
attack = MIFGSM(model, normalize, device, eps=0.03, steps=10, decay=1.0)
```

Check out [`torchattack.utils.run_attack`](src/torchattack/utils.py) for a simple example.
Check out [`torchattack.runner.run_attack`](src/torchattack/runner.py) for a simple example.

## Attacks

Expand All @@ -66,6 +66,7 @@ Others:
| :------: | :---------------------: | :------------------------------------------------------------------------------------------------------ | :--------------------- |
| DeepFool | $\ell_2$ | [DeepFool: A Simple and Accurate Method to Fool Deep Neural Networks](https://arxiv.org/abs/1511.04599) | `torchattack.DeepFool` |
| GeoDA | $\ell_\infty$, $\ell_2$ | [GeoDA: A Geometric Framework for Black-box Adversarial Attacks](https://arxiv.org/abs/2003.06468) | `torchattack.GeoDA` |
| SSP | $\ell_\infty$ | [A Self-supervised Approach for Adversarial Robustness](https://arxiv.org/abs/2006.04924) | `torchattack.SSP` |

## Development

Expand All @@ -75,6 +76,7 @@ python -m venv .venv
source .venv/bin/activate

# Install deps with dev extras
python -m pip install -r requirements.txt
python -m pip install -e '.[dev]'
```

Expand Down
8 changes: 6 additions & 2 deletions src/torchattack/runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,17 +70,21 @@ def run_attack(
if attack_cfg is None:
attack_cfg = {}

# Set up model, transform, and normalize
# Setup model
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = tv.models.get_model(name=model_name, weights='DEFAULT').to(device).eval()

# Setup transforms and normalization
transform = tv.transforms.Compose(
[
tv.transforms.Resize([256]),
tv.transforms.CenterCrop(224),
tv.transforms.ToTensor(),
]
)
normalize = tv.transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
normalize = tv.transforms.Normalize(
mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)
)

# Set up dataloader
dataloader = NIPSLoader(
Expand Down
31 changes: 30 additions & 1 deletion src/torchattack/ssp.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
class PerceptualCriteria(nn.Module):
def __init__(self, ssp_layer: int) -> None:
super().__init__()
self.ssp_layer = ssp_layer

# Use pretrained VGG16 for perceptual loss
vgg16 = tv.models.vgg16(weights='DEFAULT')
Expand All @@ -27,12 +28,30 @@ def __repr__(self) -> str:


class SSP(Attack):
"""The Self-supervised (SSP) attack.
From the paper 'A Self-supervised Approach for Adversarial Robustness'
https://arxiv.org/abs/2006.04924
Args:
model: The model to attack.
normalize: A transform to normalize images.
device: Device to use for tensors. Defaults to cuda if available.
eps: The maximum perturbation. Defaults to 8/255.
steps: Number of steps. Defaults to 100.
alpha: Step size, `eps / steps` if None. Defaults to None.
ssp_layer: The VGG layer to use for the perceptual loss. Defaults to 16.
clip_min: Minimum value for clipping. Defaults to 0.0.
clip_max: Maximum value for clipping. Defaults to 1.0.
targeted: Targeted attack if True. Defaults to False.
"""

def __init__(
self,
model: nn.Module,
normalize: Callable[[torch.Tensor], torch.Tensor] | None,
device: torch.device | None = None,
eps: float = 16 / 255,
eps: float = 8 / 255,
steps: int = 100,
alpha: float | None = None,
ssp_layer: int = 16,
Expand All @@ -54,6 +73,16 @@ def __init__(
self.perceptual_criteria = PerceptualCriteria(ssp_layer).to(device)

def forward(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
"""Perform SSP on a batch of images.
Args:
x: A batch of images. Shape: (N, C, H, W).
y: A batch of labels, not required. Shape: (N).
Returns:
The perturbed images if successful. Shape: (N, C, H, W).
"""

delta = torch.randn_like(x, requires_grad=True)

# If alpha is not given, set to eps / steps
Expand Down

0 comments on commit 962d23d

Please sign in to comment.