from contextlib import suppress
from typing import Any, Optional

import torch


class FoolingRateMetric:
    """Fooling rate metric tracker.

    Accumulates, over successive batches, how many samples the model
    classifies correctly before (clean) and after (adversarial) the attack,
    and derives clean accuracy, adversarial accuracy, and the fooling rate.
    """

    def __init__(self) -> None:
        self.total_count = 0  # total number of samples seen
        self.clean_count = 0  # clean samples classified correctly
        self.adv_count = 0  # adversarial samples still classified correctly

    def update(
        self, labels: torch.Tensor, clean_logits: torch.Tensor, adv_logits: torch.Tensor
    ) -> None:
        """Update metric tracker during attack progress.

        Args:
            labels (torch.Tensor): Ground truth labels, shape (batch,).
            clean_logits (torch.Tensor): Prediction logits for clean samples,
                shape (batch, num_classes).
            adv_logits (torch.Tensor): Prediction logits for adversarial
                samples, shape (batch, num_classes).
        """

        self.total_count += labels.numel()
        self.clean_count += (clean_logits.argmax(dim=1) == labels).sum().item()
        self.adv_count += (adv_logits.argmax(dim=1) == labels).sum().item()

    def compute_fooling_rate(self) -> float:
        """Fraction of clean-correct samples that the attack flips.

        Returns 0.0 when no clean sample was classified correctly, to avoid
        a ZeroDivisionError on degenerate runs.
        """
        if self.clean_count == 0:
            return 0.0
        return (self.clean_count - self.adv_count) / self.clean_count

    def compute_adv_accuracy(self) -> float:
        """Accuracy of the model on adversarial samples (0.0 if no samples)."""
        if self.total_count == 0:
            return 0.0
        return self.adv_count / self.total_count

    def compute_clean_accuracy(self) -> float:
        """Accuracy of the model on clean samples (0.0 if no samples)."""
        if self.total_count == 0:
            return 0.0
        return self.clean_count / self.total_count


def run_attack(
    attack: Any,
    attack_cfg: Optional[dict] = None,
    model_name: str = 'resnet50',
    max_samples: int = 100,
    batch_size: int = 8,
) -> None:
    """Helper function to run attacks in `__main__`.

    Loads a pretrained torchvision model, runs the given attack over the
    NIPS 2017 adversarial-competition dataset, saves one batch of adversarial
    images to ``adv_batch.png``, and prints clean accuracy, adversarial
    accuracy, and fooling rate.

    Args:
        attack: The attack class to initialize.
        attack_cfg: A dict of keyword arguments passed to the attack class.
        model_name: The torchvision model to attack. Defaults to "resnet50".
        max_samples: Max number of samples to attack. Defaults to 100.
        batch_size: Batch size of the dataloader. Defaults to 8.
    """

    # Heavy / project-specific dependencies are imported lazily so that
    # importing this module (e.g. for FoolingRateMetric) does not require
    # torchvision or the dataset package to be present.
    import torchvision as tv

    from torchattack.dataset import NIPSLoader

    # Try to import rich for progress bar
    with suppress(ImportError):
        from rich import print
        from rich.progress import track

    if attack_cfg is None:
        attack_cfg = {}

    # Set up model, transform, and normalize
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    model = tv.models.get_model(name=model_name, weights='DEFAULT').to(device).eval()
    transform = tv.transforms.Compose(
        [tv.transforms.Resize([224]), tv.transforms.ToTensor()]
    )
    normalize = tv.transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])

    # Set up dataloader
    dataloader = NIPSLoader(
        root='datasets/nips2017',
        batch_size=batch_size,
        transform=transform,
        max_samples=max_samples,
    )

    # Set up attack and trackers
    frm = FoolingRateMetric()
    attacker = attack(model=model, normalize=normalize, device=device, **attack_cfg)
    print(attacker)

    # Wrap the dataloader in rich's progress tracker when rich is installed;
    # `track` is only bound if the import above succeeded.
    try:
        dataloader = track(dataloader, description='Attacking')  # type: ignore
    except NameError:
        print('Running attack ... (install `rich` for progress bar)')

    # Run attack over the dataset (100 images by default)
    for i, (x, y, _) in enumerate(dataloader):
        x, y = x.to(device), y.to(device)

        # Adversarial images are created here
        advs = attacker(x, y)

        # Track accuracy
        cln_outs = model(normalize(x))
        adv_outs = model(normalize(advs))
        frm.update(y, cln_outs, adv_outs)

        # Save one batch of adversarial images
        if i == 1:
            saved_imgs = advs.detach().cpu().mul(255).to(torch.uint8)
            img_grid = tv.utils.make_grid(saved_imgs, nrow=4)
            tv.io.write_png(img_grid, 'adv_batch.png')

    # Print results
    print(f'Clean accuracy: {frm.compute_clean_accuracy():.2%}')
    print(f'Adversarial accuracy: {frm.compute_adv_accuracy():.2%}')
    print(f'Fooling rate: {frm.compute_fooling_rate():.2%}')