Skip to content

Commit

Permalink
AttackModel model wrapper
Browse files Browse the repository at this point in the history
  • Loading branch information
spencerwooo committed Aug 20, 2024
1 parent d9621c3 commit 33c4a8b
Show file tree
Hide file tree
Showing 3 changed files with 193 additions and 72 deletions.
2 changes: 2 additions & 0 deletions src/torchattack/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
from torchattack.nifgsm import NIFGSM
from torchattack.pgd import PGD
from torchattack.pgdl2 import PGDL2
from torchattack.pna_patchout import PNAPatchOut
from torchattack.sinifgsm import SINIFGSM
from torchattack.tifgsm import TIFGSM
from torchattack.vmifgsm import VMIFGSM
Expand All @@ -26,6 +27,7 @@
'NIFGSM',
'PGD',
'PGDL2',
'PNAPatchOut',
'SINIFGSM',
'TIFGSM',
'VMIFGSM',
Expand Down
189 changes: 153 additions & 36 deletions src/torchattack/eval.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,7 @@
from contextlib import suppress
from typing import Any
from typing import Any, Callable, Self

import torch
import torchvision as tv

from torchattack.dataset import NIPSLoader
import torch.nn as nn


class FoolingRateMetric:
Expand Down Expand Up @@ -40,12 +37,145 @@ def compute_clean_accuracy(self) -> torch.Tensor:
return self.clean_count / self.total_count


class AttackModel:
"""A wrapper class for a pretrained model used for adversarial attacks.
Intented to be instantiated with `AttackModel.from_pretrained(pretrained_model_name)` from either
`torchvision.models` or `timm`. The model is loaded and attributes including `transform` and `normalize` are
attached based on the model's configuration.
Attributes:
model_name (str): The name of the model.
device (torch.device): The device on which the model is loaded.
model (nn.Module): The pretrained model itself.
transform (Callable): The transformation function applied to input images.
normalize (Callable): The normalization function applied to input images.
Example:
>>> model = AttackModel.from_pretrained('resnet50', device='cuda')
>>> model
AttackModel(model_name=resnet50, device=cuda, transform=Compose(...), normalize=Normalize(...))
>>> model.transform
Compose(
Resize(size=(256, 256), interpolation=PIL.Image.BILINEAR)
CenterCrop(size=(224, 224))
ToTensor()
)
>>> model.normalize
Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
>>> model.model
ResNet(
(conv1): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
...
)
"""

def __init__(
self,
model_name: str,
device: torch.device,
model: nn.Module,
transform: Callable,
normalize: Callable,
) -> None:
self.model_name = model_name
self.device = device
self.model = model
self.transform = transform
self.normalize = normalize

@classmethod
def from_pretrained(
cls,
model_name: str,
device: torch.device,
from_timm: bool = False,
) -> Self:
"""
Loads a pretrained model and initializes an AttackModel instance.
Args:
model_name: The name of the model to load.
device: The device on which to load the model.
from_timm: Whether to load the model from the timm library. Defaults to False.
Returns:
AttackModel: An instance of the AttackModel class initialized with the pretrained model.
"""

import torchvision.transforms as t

if not from_timm:
try:
import torchvision.models as tv_models

model = tv_models.get_model(name=model_name, weights='DEFAULT')

# resolve transforms from vision model weights
weight_id = str(tv_models.get_model_weights(name=model_name)['DEFAULT'])
cfg = tv_models.get_weight(weight_id).transforms()

# construct transform and normalize
transform = t.Compose(
[
t.Resize(
cfg.resize_size,
interpolation=cfg.interpolation,
antialias=True,
),
t.CenterCrop(cfg.crop_size),
t.ToTensor(),
]
)
normalize = t.Normalize(mean=cfg.mean, std=cfg.std)

except ValueError:
print('Model not found in torchvision.models, falling back to timm.')
from_timm = True

else:
import timm

model = timm.create_model(model_name, pretrained=True)
cfg = timm.data.resolve_data_config(model.pretrained_cfg)

# create normalization
normalize = (
t.Normalize(mean=cfg['mean'], std=cfg['std'])
if 'mean' in cfg and 'std' in cfg
# if no mean and std is given, use default ImageNet values
else t.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
)
# create a transform based on the cfg
transform = timm.data.create_transform(**cfg, is_training=False)
# remove the Normalize from the transforms.Compose if there is one
transform.transforms = [
tr for tr in transform.transforms if not isinstance(tr, t.Normalize)
]

model = model.to(device).eval()
return cls(model_name, device, model, transform, normalize)

def forward(self, x):
return self.model(x)

def __call__(self, x):
return self.forward(x)

def __repr__(self):
return (
f'{self.__class__.__name__}(model_name={self.model_name}, device={self.device}, '
f'transform={self.transform}, normalize={self.normalize})'
)


def run_attack(
attack: Any,
attack_cfg: dict | None = None,
model_name: str = 'resnet50',
max_samples: int = 100,
batch_size: int = 16,
from_timm: bool = False,
) -> None:
"""Helper function to run attacks in `__main__`.
Expand All @@ -60,38 +190,23 @@ def run_attack(
attack_cfg: A dict of keyword arguments passed to the attack class.
model_name: The torchvision model to attack. Defaults to "resnet50".
max_samples: Max number of samples to attack. Defaults to 100.
batch_size: Batch size for the dataloader. Defaults to 16.
from_timm: Use timm to load the model. Defaults to True.
"""

# Try to import rich for progress bar
with suppress(ImportError):
from rich import print
from rich.progress import track
import torchvision as tv
from rich import print
from rich.progress import track

from torchattack.dataset import NIPSLoader

if attack_cfg is None:
attack_cfg = {}

# Setup model
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
try:
model = (
tv.models.get_model(name=model_name, weights='DEFAULT').to(device).eval()
)
except ValueError:
import timm

model = timm.create_model(model_name, pretrained=True).to(device).eval()

# Setup transforms and normalization
transform = tv.transforms.Compose(
[
tv.transforms.Resize([256]),
tv.transforms.CenterCrop(224),
tv.transforms.ToTensor(),
]
)
normalize = tv.transforms.Normalize(
mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)
)
model = AttackModel.from_pretrained(model_name, device, from_timm)
transform, normalize = model.transform, model.normalize

# Set up dataloader
dataloader = NIPSLoader(
Expand All @@ -100,18 +215,20 @@ def run_attack(
transform=transform,
max_samples=max_samples,
)
dataloader = track(dataloader, description='Attacking')

# Set up attack and trackers
frm = FoolingRateMetric()
attacker = attack(model=model, normalize=normalize, device=device, **attack_cfg)
attacker = attack(
# Pass the original PyTorch model instead of the wrapped one if the attack
# requires access to the model's internal / intermediate / middle layers.
model=model.model,
normalize=normalize,
device=device,
**attack_cfg,
)
print(attacker)

# Wrap dataloader with rich.progress.track if available
try:
dataloader = track(dataloader, description='Attacking') # type: ignore
except NameError:
print('Running attack ... (install `rich` for progress bar)')

# Run attack over the dataset (100 images by default)
for i, (x, y, _) in enumerate(dataloader):
x, y = x.to(device), y.to(device)
Expand Down
74 changes: 38 additions & 36 deletions src/torchattack/pna_patchout.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,41 @@
class PNAPatchOut(Attack):
"""PNA-PatchOut attack for ViTs (Pay no attention & PatchOut)."""

# fmt: off
_supported_vit_cfg = {
'vit_base_patch16_224': [
f'blocks.{i}.attn.attn_drop' for i in range(12)
],
'deit_base_distilled_patch16_224': [
f'blocks.{i}.attn.attn_drop' for i in range(12)
],
'pit_b_224': [
# First transformer block
f'transformers.{0}.blocks.{i}.attn.attn_drop' for i in range(3)
] + [
# Second transformer block
f'transformers.{1}.blocks.{i-3}.attn.attn_drop' for i in range(3, 9)
] + [
# Third transformer block
f'transformers.{2}.blocks.{i-9}.attn.attn_drop' for i in range(9, 13)
],
'cait_s24_224': [
# Regular blocks
f'blocks.{i}.attn.attn_drop' for i in range(24)
] + [
# Token-only block
f'blocks_token_only.{i}.attn.attn_drop' for i in range(0, 2)
],
'visformer_small': [
# Stage 2 blocks
f'stage2.{i}.attn.attn_drop' for i in range(4)
] + [
# Stage 3 blocks
f'stage3.{i}.attn.attn_drop' for i in range(4)
],
}
# fmt: on

def __init__(
self,
model: nn.Module,
Expand Down Expand Up @@ -108,44 +143,10 @@ def attn_drop_mask_grad(

drop_hook_func = partial(attn_drop_mask_grad, gamma=0)

# fmt: off
_supported_vit_cfg = {
'vit_base_patch16_224': [
f'blocks.{i}.attn.attn_drop' for i in range(12)
],
'deit_base_distilled_patch16_224': [
f'blocks.{i}.attn.attn_drop' for i in range(12)
],
'pit_b_224': [
# First transformer block
f'transformers.{0}.blocks.{i}.attn.attn_drop' for i in range(3)
] + [
# Second transformer block
f'transformers.{1}.blocks.{i-3}.attn.attn_drop' for i in range(3, 9)
] + [
# Third transformer block
f'transformers.{2}.blocks.{i-9}.attn.attn_drop' for i in range(9, 13)
],
'cait_s24_224': [
# Regular blocks
f'blocks.{i}.attn.attn_drop' for i in range(24)
] + [
# Token-only block
f'blocks_token_only.{i}.attn.attn_drop' for i in range(0, 2)
],
'visformer_small': [
# Stage 2 blocks
f'stage2.{i}.attn.attn_drop' for i in range(4)
] + [
# Stage 3 blocks
f'stage3.{i}.attn.attn_drop' for i in range(4)
],
}
# fmt: on
assert self.model_name in _supported_vit_cfg, f'{self.model_name} not supported'
assert self.model_name in self._supported_vit_cfg

# Register backward hook for layers specified in _supported_vit_cfg
for layer in _supported_vit_cfg[self.model_name]:
for layer in self._supported_vit_cfg[self.model_name]:
module = rgetattr(self.model, layer)
module.register_backward_hook(drop_hook_func)

Expand Down Expand Up @@ -177,4 +178,5 @@ def _apply_patch_out(self, delta: torch.Tensor, seed: int) -> torch.Tensor:
PNAPatchOut,
attack_cfg={'model_name': 'vit_base_patch16_224'},
model_name='vit_base_patch16_224',
from_timm=True,
)

0 comments on commit 33c4a8b

Please sign in to comment.