Commit

Initial implementation of PNAPatchOut
spencerwooo committed Aug 19, 2024
1 parent 91021eb commit b82325d
Showing 2 changed files with 188 additions and 1 deletion.
180 changes: 180 additions & 0 deletions src/torchattack/pna_patchout.py
@@ -0,0 +1,180 @@
import importlib.util
from functools import partial
from typing import Callable

import torch
import torch.nn as nn

from torchattack._rgetattr import rgetattr
from torchattack.base import Attack


class PNAPatchOut(Attack):
"""PNA-PatchOut attack for ViTs (Pay no attention & PatchOut)."""

def __init__(
self,
model: nn.Module,
model_name: str,
normalize: Callable[[torch.Tensor], torch.Tensor] | None,
device: torch.device | None = None,
eps: float = 8 / 255,
steps: int = 10,
alpha: float | None = None,
decay: float = 1.0,
pna_skip: bool = True,
pna_patchout: bool = True,
clip_min: float = 0.0,
clip_max: float = 1.0,
targeted: bool = False,
):
        # Check that timm is installed, as the supported ViT variants are timm models
        if importlib.util.find_spec('timm') is None:
            raise ModuleNotFoundError(
                'timm is required for PNAPatchOut but is not installed.'
            )

super().__init__(normalize, device)

self.model = model
self.model_name = model_name
self.eps = eps
self.steps = steps
self.alpha = alpha
self.decay = decay
self.pna_skip = pna_skip
self.pna_patchout = pna_patchout
self.clip_min = clip_min
self.clip_max = clip_max
self.targeted = targeted
self.lossfn = nn.CrossEntropyLoss()

if self.pna_skip:
self._register_vit_model_hook()

# Set default image size and number of patches for PatchOut
self.image_size = 224
self.crop_len = 16
self.max_num_patches = int((self.image_size / self.crop_len) ** 2)
if self.pna_patchout:
self.sample_num_patches = 130
else:
self.sample_num_patches = self.max_num_patches

def forward(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
g = torch.zeros_like(x)
delta = torch.zeros_like(x, requires_grad=True)

# If alpha is not given, set to eps / steps
if self.alpha is None:
self.alpha = self.eps / self.steps

for i in range(self.steps):
patched_out_delta = self._apply_patch_out(delta, seed=i)
outs = self.model(self.normalize(x + patched_out_delta))
loss = self.lossfn(outs, y)

if self.targeted:
loss = -loss

# Compute gradient
loss.backward()

if delta.grad is None:
continue

# Apply momentum term
g = self.decay * g + delta.grad / torch.mean(
torch.abs(delta.grad), dim=(1, 2, 3), keepdim=True
)

# Update delta
delta.data = delta.data + self.alpha * g.sign()
delta.data = torch.clamp(delta.data, -self.eps, self.eps)
delta.data = torch.clamp(x + delta.data, self.clip_min, self.clip_max) - x

# Zero out gradient
delta.grad.detach_()
delta.grad.zero_()

return x + delta

    def _register_vit_model_hook(self):
        def attn_drop_mask_grad(
            module: torch.nn.Module,
            grad_in: tuple[torch.Tensor, ...],
            grad_out: tuple[torch.Tensor, ...],
            gamma: float,
        ) -> tuple[torch.Tensor, ...]:
            # Scale the gradient flowing back through the attention dropout
            # layer by gamma; with gamma=0 the attention branch is skipped in
            # the backward pass (the "Pay No Attention" trick).
            mask = torch.ones_like(grad_in[0]) * gamma
            return (mask * grad_in[0],)

        # Drop attention gradients entirely
        drop_hook_func = partial(attn_drop_mask_grad, gamma=0)

# fmt: off
_supported_vit_cfg = {
'vit_base_patch16_224': [
f'blocks.{i}.attn.attn_drop' for i in range(12)
],
'deit_base_distilled_patch16_224': [
f'blocks.{i}.attn.attn_drop' for i in range(12)
],
'pit_b_224': [
# First transformer block
f'transformers.{0}.blocks.{i}.attn.attn_drop' for i in range(3)
] + [
# Second transformer block
f'transformers.{1}.blocks.{i-3}.attn.attn_drop' for i in range(3, 9)
] + [
# Third transformer block
f'transformers.{2}.blocks.{i-9}.attn.attn_drop' for i in range(9, 13)
],
'cait_s24_224': [
# Regular blocks
f'blocks.{i}.attn.attn_drop' for i in range(24)
] + [
# Token-only block
f'blocks_token_only.{i}.attn.attn_drop' for i in range(0, 2)
],
'visformer_small': [
# Stage 2 blocks
f'stage2.{i}.attn.attn_drop' for i in range(4)
] + [
# Stage 3 blocks
f'stage3.{i}.attn.attn_drop' for i in range(4)
],
}
# fmt: on
assert self.model_name in _supported_vit_cfg, f'{self.model_name} not supported'

# Register backward hook for layers specified in _supported_vit_cfg
for layer in _supported_vit_cfg[self.model_name]:
module = rgetattr(self.model, layer)
module.register_backward_hook(drop_hook_func)

def _apply_patch_out(self, delta: torch.Tensor, seed: int) -> torch.Tensor:
delta_mask = torch.zeros_like(delta)
grid_num_axis = int(self.image_size / self.crop_len)

        # Seed the RNG with the current step index so that patch sampling is
        # random per step yet reproducible across runs
        torch.manual_seed(seed)
        ids = torch.randperm(self.max_num_patches)[: self.sample_num_patches]

        # Map flat patch indices to (row, col) coordinates on the patch grid
        rows, cols = ids // grid_num_axis, ids % grid_num_axis
for r, c in zip(rows, cols, strict=True):
delta_mask[
...,
r * self.crop_len : (r + 1) * self.crop_len,
c * self.crop_len : (c + 1) * self.crop_len,
] = 1

# Apply mask to delta
return delta_mask * delta


if __name__ == '__main__':
from torchattack.runner import run_attack

run_attack(
PNAPatchOut,
attack_cfg={'model_name': 'vit_base_patch16_224'},
model_name='vit_base_patch16_224',
)
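
For reference, here is a minimal usage sketch of the new attack outside of the bundled runner. It is an illustrative assumption rather than part of this commit: the direct call to attack.forward(...), the normalization built from timm's model.default_cfg, and the dummy inputs are all placeholders.

import timm
import torch

from torchattack.pna_patchout import PNAPatchOut

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = timm.create_model('vit_base_patch16_224', pretrained=True).to(device).eval()

# timm exposes the model's normalization constants in its default config
mean = torch.tensor(model.default_cfg['mean'], device=device).view(1, 3, 1, 1)
std = torch.tensor(model.default_cfg['std'], device=device).view(1, 3, 1, 1)

def normalize(x: torch.Tensor) -> torch.Tensor:
    return (x - mean) / std

attack = PNAPatchOut(model, 'vit_base_patch16_224', normalize, device=device)

x = torch.rand(4, 3, 224, 224, device=device)    # dummy batch of images in [0, 1]
y = torch.randint(0, 1000, (4,), device=device)  # dummy ImageNet labels
x_adv = attack.forward(x, y)                     # adversarial examples, same shape as x
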
9 changes: 8 additions & 1 deletion src/torchattack/runner.py
@@ -72,7 +72,14 @@ def run_attack(

# Setup model
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
-    model = tv.models.get_model(name=model_name, weights='DEFAULT').to(device).eval()
+    try:
+        model = (
+            tv.models.get_model(name=model_name, weights='DEFAULT').to(device).eval()
+        )
+    except ValueError:
+        import timm
+
+        model = timm.create_model(model_name, pretrained=True).to(device).eval()

# Setup transforms and normalization
transform = tv.transforms.Compose(
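
To illustrate the fallback introduced above, a short sketch with example model names: torchvision's get_model raises ValueError for names it does not recognize, which is what routes timm-only models (such as the ViT variants supported by PNAPatchOut) through timm.create_model.

import timm
import torchvision as tv

# A torchvision model name resolves directly...
resnet = tv.models.get_model(name='resnet50', weights='DEFAULT')

# ...while a timm-only name raises ValueError in torchvision and is
# therefore created through timm instead.
try:
    model = tv.models.get_model(name='pit_b_224', weights='DEFAULT')
except ValueError:
    model = timm.create_model('pit_b_224', pretrained=True)
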
