Skip to content

Commit

Permalink
Attack runner use create_attack function
Browse files Browse the repository at this point in the history
  • Loading branch information
spencerwooo committed Nov 23, 2024
1 parent c2a2d81 commit b836831
Show file tree
Hide file tree
Showing 2 changed files with 46 additions and 14 deletions.
7 changes: 3 additions & 4 deletions torchattack/create_attack.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,11 +22,10 @@ def create_attack(
if 'eps' in attack_cfg:
print('Warning: `eps` in `attack_cfg` will be overwritten.')
attack_cfg['eps'] = eps
if attack_name in ['GeoDA', 'DeepFool']:
if 'eps' in attack_cfg:
print(f'Warning: `eps` is invalid in `{attack_name}` and will be ignored.')
if attack_name in ['GeoDA', 'DeepFool'] and 'eps' in attack_cfg:
print(f'Warning: `eps` is invalid in `{attack_name}` and will be ignored.')
attack_cfg.pop('eps', None)
attacker_cls = getattr(torchattack, attack_name)
attacker_cls: Attack = getattr(torchattack, attack_name)
return attacker_cls(model, normalize, device, **attack_cfg)


Expand Down
53 changes: 43 additions & 10 deletions torchattack/eval/runner.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,5 @@
from typing import Any

import torch

from torchattack.attack_model import AttackModel
from torchattack.eval.metric import FoolingRateMetric


def run_attack(
attack: Any,
Expand All @@ -13,7 +8,7 @@ def run_attack(
victim_model_names: list[str] | None = None,
dataset_root: str = 'datasets/nips2017',
max_samples: int = 100,
batch_size: int = 16,
batch_size: int = 4,
from_timm: bool = False,
) -> None:
"""Helper function to run attacks in `__main__`.
Expand All @@ -24,7 +19,7 @@ def run_attack(
>>> run_attack(attack=FGSM, attack_cfg=cfg)
Args:
attack: The attack class to initialize.
attack: The attack class to initialize, either by name or class instance.
attack_cfg: A dict of keyword arguments passed to the attack class.
model_name: The surrogate model to attack. Defaults to "resnet50".
victim_model_names: A list of the victim black-box models to attack. Defaults to None.
Expand All @@ -34,10 +29,13 @@ def run_attack(
from_timm: Use timm to load the model. Defaults to True.
"""

import torch
from rich import print
from rich.progress import track

from torchattack import AttackModel, create_attack
from torchattack.eval.dataset import NIPSLoader
from torchattack.eval.metric import FoolingRateMetric

if attack_cfg is None:
attack_cfg = {}
Expand All @@ -58,7 +56,16 @@ def run_attack(

# Set up attack and trackers
frm = FoolingRateMetric()
attacker = attack(model=model, normalize=normalize, device=device, **attack_cfg)
if isinstance(attack, str):
attacker = create_attack(
attack_name=attack,
model=model,
normalize=normalize,
device=device,
attack_cfg=attack_cfg,
)
else:
attacker = attack(model, normalize, device, **attack_cfg)
print(attacker)

# Setup victim models if provided
Expand Down Expand Up @@ -92,8 +99,8 @@ def run_attack(
# Track transfer fooling rates if victim models are provided
if victim_model_names:
for _, (vmodel, vfrm) in enumerate(zip(victim_models, victim_frms)):
v_cln_outs = vmodel(normalize(x))
v_adv_outs = vmodel(normalize(advs))
v_cln_outs = vmodel(vmodel.normalize(x))
v_adv_outs = vmodel(vmodel.normalize(advs))
vfrm.update(y, v_cln_outs, v_adv_outs)

# Print results
Expand All @@ -107,3 +114,29 @@ def run_attack(
f'Victim ({vmodel.model_name}): cln_acc={vcln_acc:.2%}, '
f'adv_acc={vadv_acc:.2%} (fr={vfr:.2%})'
)


if __name__ == '__main__':
import argparse

parser = argparse.ArgumentParser(description='Run an attack on a model.')
parser.add_argument('--attack', type=str, required=True)
parser.add_argument('--eps', type=str, default='16/255')
parser.add_argument('--model-name', type=str, default='resnet50')
parser.add_argument('--victim-model-names', type=str, nargs='+', default=None)
parser.add_argument('--dataset-root', type=str, default='datasets/nips2017')
parser.add_argument('--max-samples', type=int, default=100)
parser.add_argument('--batch-size', type=int, default=4)
parser.add_argument('--from-timm', action='store_true')
args = parser.parse_args()

run_attack(
attack=args.attack,
attack_cfg={'eps': float(args.eps)},
model_name=args.model_name,
victim_model_names=args.victim_model_names,
dataset_root=args.dataset_root,
max_samples=args.max_samples,
batch_size=args.batch_size,
from_timm=args.from_timm,
)

0 comments on commit b836831

Please sign in to comment.