Commit
Update attack runner and metric tracker
spencerwooo committed May 8, 2024
1 parent c35f8f2 commit 4aeea11
Showing 2 changed files with 63 additions and 24 deletions.
6 changes: 5 additions & 1 deletion src/torchattack/deepfool.py
@@ -178,4 +178,8 @@ def _atleast_kd(self, x: torch.Tensor, k: int) -> torch.Tensor:
 if __name__ == '__main__':
     from torchattack.runner import run_attack
 
-    run_attack(DeepFool, attack_cfg={'steps': 50, 'overshoot': 0.02}, model='resnet152')
+    run_attack(
+        DeepFool,
+        attack_cfg={'steps': 50, 'overshoot': 0.02},
+        model_name='resnet152',
+    )
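The renamed keyword arguments here (`model_name`, `max_samples`) match the updated `run_attack` signature in the runner diff below, where `attack_cfg` also becomes optional. A minimal sketch of what that permits — hypothetical, since it assumes `DeepFool` supplies defaults for its own hyperparameters, which this diff does not show:

    if __name__ == '__main__':
        from torchattack.runner import run_attack

        # Assumes DeepFool defines defaults for `steps` and `overshoot`;
        # the runner substitutes an empty dict when attack_cfg is omitted.
        run_attack(DeepFool)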
81 changes: 58 additions & 23 deletions src/torchattack/runner.py
@@ -1,16 +1,50 @@
 from contextlib import suppress
+from typing import Any, Optional
 
 import torch
 import torchvision as tv
 
 from torchattack.dataset import NIPSLoader
 
 
+class FoolingRateMetric:
+    """Fooling rate metric tracker."""
+
+    def __init__(self) -> None:
+        self.total_count = 0
+        self.clean_count = 0
+        self.adv_count = 0
+
+    def update(
+        self, labels: torch.Tensor, clean_logits: torch.Tensor, adv_logits: torch.Tensor
+    ) -> None:
+        """Update metric tracker during attack progress.
+
+        Args:
+            labels (torch.Tensor): Ground truth labels.
+            clean_logits (torch.Tensor): Prediction logits for clean samples.
+            adv_logits (torch.Tensor): Prediction logits for adversarial samples.
+        """
+
+        self.total_count += labels.numel()
+        self.clean_count += (clean_logits.argmax(dim=1) == labels).sum().item()
+        self.adv_count += (adv_logits.argmax(dim=1) == labels).sum().item()
+
+    def compute_fooling_rate(self) -> torch.Tensor:
+        return (self.clean_count - self.adv_count) / self.clean_count
+
+    def compute_adv_accuracy(self) -> torch.Tensor:
+        return self.adv_count / self.total_count
+
+    def compute_clean_accuracy(self) -> torch.Tensor:
+        return self.clean_count / self.total_count
+
+
 def run_attack(
-    attack,
-    attack_cfg,
-    model: str = 'resnet50',
-    samples: int = 100,
+    attack: Any,
+    attack_cfg: Optional[dict] = None,
+    model_name: str = 'resnet50',
+    max_samples: int = 100,
     batch_size: int = 8,
 ) -> None:
     """Helper function to run attacks in `__main__`.
@@ -24,23 +58,23 @@ def run_attack(
     Args:
         attack: The attack class to initialize.
         attack_cfg: A dict of keyword arguments passed to the attack class.
-        model: The torchvision model to attack. Defaults to "resnet50".
-        samples: Max number of samples to attack. Defaults to 100.
+        model_name: The torchvision model to attack. Defaults to "resnet50".
+        max_samples: Max number of samples to attack. Defaults to 100.
     """
 
     # Try to import rich for progress bar
     with suppress(ImportError):
         from rich import print
         from rich.progress import track
 
+    if attack_cfg is None:
+        attack_cfg = {}
+
     # Set up model, transform, and normalize
     device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
-    model = tv.models.get_model(name=model, weights='DEFAULT').to(device).eval()
+    model = tv.models.get_model(name=model_name, weights='DEFAULT').to(device).eval()
     transform = tv.transforms.Compose(
-        [
-            tv.transforms.Resize([224]),
-            tv.transforms.ToTensor(),
-        ]
+        [tv.transforms.Resize([224]), tv.transforms.ToTensor()]
     )
     normalize = tv.transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])

@@ -49,11 +83,11 @@ def run_attack(
         root='datasets/nips2017',
         batch_size=batch_size,
         transform=transform,
-        max_samples=samples,
+        max_samples=max_samples,
     )
 
     # Set up attack and trackers
-    total, acc_clean, acc_adv = len(dataloader.dataset), 0, 0  # type: ignore
+    frm = FoolingRateMetric()
     attacker = attack(model=model, normalize=normalize, device=device, **attack_cfg)
     print(attacker)

@@ -64,23 +98,24 @@ def run_attack(
         print('Running attack ... (install `rich` for progress bar)')
 
     # Run attack over the dataset (100 images by default)
-    for i, (images, labels, _) in enumerate(dataloader):
-        images, labels = images.to(device), labels.to(device)
+    for i, (x, y, _) in enumerate(dataloader):
+        x, y = x.to(device), y.to(device)
 
         # Adversarial images are created here
-        adv_images = attacker(images, labels)
+        advs = attacker(x, y)
 
         # Track accuracy
-        clean_outs = model(normalize(images)).argmax(dim=1)
-        adv_outs = model(normalize(adv_images)).argmax(dim=1)
-
-        acc_clean += (clean_outs == labels).sum().item()
-        acc_adv += (adv_outs == labels).sum().item()
+        cln_outs = model(normalize(x))
+        adv_outs = model(normalize(advs))
+        frm.update(y, cln_outs, adv_outs)
 
         # Save one batch of adversarial images
         if i == 1:
-            saved_imgs = adv_images.detach().cpu().mul(255).to(torch.uint8)
+            saved_imgs = advs.detach().cpu().mul(255).to(torch.uint8)
             img_grid = tv.utils.make_grid(saved_imgs, nrow=4)
             tv.io.write_png(img_grid, 'adv_batch.png')
 
-    print(f'Accuracy (clean vs adversarial): {acc_clean / total} vs {acc_adv / total}')
+    # Print results
+    print(f'Clean accuracy: {frm.compute_clean_accuracy():.2%}')
+    print(f'Adversarial accuracy: {frm.compute_adv_accuracy():.2%}')
+    print(f'Fooling rate: {frm.compute_fooling_rate():.2%}')
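For reference, a minimal sketch of how the new FoolingRateMetric behaves on hand-made logits (the tensors and values here are illustrative, not from the commit). The fooling rate is the fraction of originally-correct predictions that the attack flips: (clean_count - adv_count) / clean_count.

    import torch

    from torchattack.runner import FoolingRateMetric

    # Toy batch: 4 samples, 3 classes. The clean model gets all 4 right;
    # the attack flips 3 of the 4 predictions.
    labels = torch.tensor([0, 1, 2, 0])
    clean_logits = torch.tensor(
        [[5.0, 1.0, 1.0], [1.0, 5.0, 1.0], [1.0, 1.0, 5.0], [5.0, 1.0, 1.0]]
    )  # argmax: [0, 1, 2, 0] -> 4/4 correct
    adv_logits = torch.tensor(
        [[1.0, 5.0, 1.0], [5.0, 1.0, 1.0], [1.0, 1.0, 5.0], [1.0, 1.0, 5.0]]
    )  # argmax: [1, 0, 2, 2] -> only 1/4 still correct

    frm = FoolingRateMetric()
    frm.update(labels, clean_logits, adv_logits)

    print(frm.compute_clean_accuracy())  # 4 / 4 = 1.0
    print(frm.compute_adv_accuracy())    # 1 / 4 = 0.25
    print(frm.compute_fooling_rate())    # (4 - 1) / 4 = 0.75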
