Skip to content

Commit

Permalink
Add support for transfer eval
Browse files Browse the repository at this point in the history
  • Loading branch information
spencerwooo committed Aug 21, 2024
1 parent b03c16c commit 2644789
Show file tree
Hide file tree
Showing 3 changed files with 73 additions and 28 deletions.
87 changes: 61 additions & 26 deletions src/torchattack/eval.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,20 +27,23 @@ def update(
self.clean_count += (clean_logits.argmax(dim=1) == labels).sum().item()
self.adv_count += (adv_logits.argmax(dim=1) == labels).sum().item()

def compute_fooling_rate(self) -> torch.Tensor:
return (self.clean_count - self.adv_count) / self.clean_count
def compute_cln_acc(self) -> torch.Tensor:
    """Return the accuracy on clean samples, i.e. `clean_count / total_count`."""
    correct, seen = self.clean_count, self.total_count
    return correct / seen

def compute_adv_accuracy(self) -> torch.Tensor:
def compute_adv_acc(self) -> torch.Tensor:
    """Return the accuracy on adversarial examples, i.e. `adv_count / total_count`."""
    still_correct, seen = self.adv_count, self.total_count
    return still_correct / seen

def compute_clean_accuracy(self) -> torch.Tensor:
return self.clean_count / self.total_count
def compute_fr(self) -> torch.Tensor:
    """Return the fooling rate.

    This is the drop in correct predictions from clean to adversarial inputs, relative to the number of
    correct clean predictions: `(clean_count - adv_count) / clean_count`.
    """
    lost_correct = self.clean_count - self.adv_count
    return lost_correct / self.clean_count


class AttackModel:
"""A wrapper class for a pretrained model used for adversarial attacks.
Intented to be instantiated with `AttackModel.from_pretrained(pretrained_model_name)` from either
Intended to be instantiated with `AttackModel.from_pretrained(pretrained_model_name)` from either
`torchvision.models` or `timm`. The model is loaded and attributes including `transform` and `normalize` are
attached based on the model's configuration.
Expand All @@ -57,7 +60,7 @@ class AttackModel:
AttackModel(model_name=resnet50, device=cuda, transform=Compose(...), normalize=Normalize(...))
>>> model.transform
Compose(
Resize(size=(256, 256), interpolation=PIL.Image.BILINEAR)
Resize(size=[256], interpolation=bilinear, max_size=None, antialias=True)
CenterCrop(size=(224, 224))
ToTensor()
)
Expand Down Expand Up @@ -111,17 +114,18 @@ def from_pretrained(

model = tv_models.get_model(name=model_name, weights='DEFAULT')

# resolve transforms from vision model weights
# Resolve transforms from vision model weights
weight_id = str(tv_models.get_model_weights(name=model_name)['DEFAULT'])
cfg = tv_models.get_weight(weight_id).transforms()

# construct transform and normalize
# torchvision/transforms/_presets.py::ImageClassification
# Manually construct separated transform and normalize
transform = t.Compose(
[
t.Resize(
cfg.resize_size,
interpolation=cfg.interpolation,
antialias=True,
antialias=cfg.antialias,
),
t.CenterCrop(cfg.crop_size),
t.ToTensor(),
Expand All @@ -139,16 +143,12 @@ def from_pretrained(
model = timm.create_model(model_name, pretrained=True)
cfg = timm.data.resolve_data_config(model.pretrained_cfg)

# create normalization
normalize = (
t.Normalize(mean=cfg['mean'], std=cfg['std'])
if 'mean' in cfg and 'std' in cfg
# if no mean and std is given, use default ImageNet values
else t.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
)
# create a transform based on the cfg
# Construct normalization
normalize = t.Normalize(mean=cfg['mean'], std=cfg['std'])

# Create a transform based on the model pretrained cfg
transform = timm.data.create_transform(**cfg, is_training=False)
# remove the Normalize from the transforms.Compose if there is one
# Remove the Normalize from composed transform if there is one
transform.transforms = [
tr for tr in transform.transforms if not isinstance(tr, t.Normalize)
]
Expand All @@ -173,6 +173,7 @@ def run_attack(
attack: Any,
attack_cfg: dict | None = None,
model_name: str = 'resnet50',
victim_model_names: list[str] | None = None,
max_samples: int = 100,
batch_size: int = 16,
from_timm: bool = False,
Expand All @@ -187,7 +188,8 @@ def run_attack(
Args:
attack: The attack class to initialize.
attack_cfg: A dict of keyword arguments passed to the attack class.
model_name: The torchvision model to attack. Defaults to "resnet50".
model_name: The surrogate model to attack. Defaults to "resnet50".
victim_model_names: A list of the victim black-box models to attack. Defaults to None.
max_samples: Max number of samples to attack. Defaults to 100.
batch_size: Batch size for the dataloader. Defaults to 16.
from_timm: Use timm to load the model. Defaults to False.
Expand Down Expand Up @@ -219,15 +221,23 @@ def run_attack(
# Set up attack and trackers
frm = FoolingRateMetric()
attacker = attack(
# Pass the original PyTorch model instead of the wrapped one if the attack
# requires access to the model's internal / intermediate / middle layers.
# Pass the original PyTorch model instead of the wrapped one if the attack requires access to the model's
# intermediate layers or other attributes that are not exposed by the AttackModel wrapper.
model=model.model,
normalize=normalize,
device=device,
**attack_cfg,
)
print(attacker)

# Set up victim models if provided
if victim_model_names:
victim_models = [
AttackModel.from_pretrained(name, device, from_timm)
for name in victim_model_names
]
victim_frms = [FoolingRateMetric() for _ in victim_model_names]

# Run attack over the dataset (100 images by default)
for i, (x, y, _) in enumerate(dataloader):
x, y = x.to(device), y.to(device)
Expand All @@ -240,13 +250,38 @@ def run_attack(
adv_outs = model(normalize(advs))
frm.update(y, cln_outs, adv_outs)

# Save first batch of adversarial images
# Save first batch of adversarial examples
if i == 0:
saved_imgs = advs.detach().cpu().mul(255).to(torch.uint8)
img_grid = tv.utils.make_grid(saved_imgs, nrow=4)
tv.io.write_png(img_grid, 'adv_batch_0.png')

# Track transfer fooling rates if victim models are provided
if victim_model_names:
for _, (vmodel, vfrm) in enumerate(
zip(victim_models, victim_frms, strict=True)
):
v_cln_outs = vmodel(normalize(x))
v_adv_outs = vmodel(normalize(advs))
vfrm.update(y, v_cln_outs, v_adv_outs)

# Print results
print(f'Clean accuracy: {frm.compute_clean_accuracy():.2%}')
print(f'Adversarial accuracy: {frm.compute_adv_accuracy():.2%}')
print(f'Fooling rate: {frm.compute_fooling_rate():.2%}')
cln_acc, adv_acc, fr = (
frm.compute_cln_acc(),
frm.compute_adv_acc(),
frm.compute_fr(),
)
print(
f'Surrogate ({model_name}): {cln_acc:.2%} / {adv_acc:.2%} (Fooling rate: {fr:.2%})'
)

if victim_model_names:
for vmodel, vfrm in zip(victim_models, victim_frms, strict=True):
vcln_acc, vadv_acc, vfr = (
vfrm.compute_cln_acc(),
vfrm.compute_adv_acc(),
vfrm.compute_fr(),
)
print(
f'Victim ({vmodel.model_name}): {vcln_acc:.2%} / {vadv_acc:.2%} (Fooling rate: {vfr:.2%})'
)
7 changes: 6 additions & 1 deletion src/torchattack/mifgsm.py
Original file line number Diff line number Diff line change
Expand Up @@ -103,4 +103,9 @@ def forward(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
if __name__ == '__main__':
from torchattack.eval import run_attack

run_attack(MIFGSM, {'eps': 8 / 255, 'steps': 10})
run_attack(
attack=MIFGSM,
attack_cfg={'eps': 8 / 255, 'steps': 10},
model_name='resnet18',
victim_model_names=['resnet50', 'vgg13', 'densenet121'],
)
7 changes: 6 additions & 1 deletion src/torchattack/pgd.py
Original file line number Diff line number Diff line change
Expand Up @@ -106,4 +106,9 @@ def forward(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
if __name__ == '__main__':
from torchattack.eval import run_attack

run_attack(PGD, {'eps': 8 / 255, 'steps': 20, 'random_start': True})
run_attack(
attack=PGD,
attack_cfg={'eps': 8 / 255, 'steps': 20, 'random_start': True},
model_name='resnet18',
victim_model_names=['resnet50', 'vgg13', 'densenet121'],
)

0 comments on commit 2644789

Please sign in to comment.