Skip to content

Commit

Permalink
Complete tests for create_attack and others
Browse files Browse the repository at this point in the history
  • Loading branch information
spencerwooo committed Nov 25, 2024
1 parent b836831 commit f5db6b0
Show file tree
Hide file tree
Showing 4 changed files with 178 additions and 2 deletions.
30 changes: 30 additions & 0 deletions tests/test_attacks.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
GeoDA,
PNAPatchOut,
)
from torchattack.attack_model import AttackModel


def run_attack_test(attack_cls, device, model, x, y):
Expand Down Expand Up @@ -72,3 +73,32 @@ def test_cnn_attacks(attack_cls, device, resnet50_model, data):
def test_vit_attacks(attack_cls, device, vitb16_model, data):
x, y = data(vitb16_model.transform)
run_attack_test(attack_cls, device, vitb16_model, x, y)


@pytest.mark.parametrize(
'model_name',
[
'deit_base_distilled_patch16_224',
'pit_b_224',
'cait_s24_224',
'visformer_small',
],
)
def test_tgr_attack_all_supported_models(device, model_name, data):
model = AttackModel.from_pretrained(model_name, device, from_timm=True)
x, y = data(model.transform)
run_attack_test(TGR, device, model, x, y)


@pytest.mark.parametrize(
'model_name',
[
'deit_base_distilled_patch16_224',
'pit_b_224',
'visformer_small',
],
)
def test_vdc_attack_all_supported_models(device, model_name, data):
model = AttackModel.from_pretrained(model_name, device, from_timm=True)
x, y = data(model.transform)
run_attack_test(VDC, device, model, x, y)
71 changes: 71 additions & 0 deletions tests/test_create_attack.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,3 +70,74 @@ def test_create_vit_attack_same_as_imported(
created_attacker = create_attack(attack_name, vitb16_model)
expected_attacker = expected(vitb16_model)
assert created_attacker == expected_attacker


def test_create_attack_with_eps(device, resnet50_model):
eps = 0.3
attack_cfg = {}
attacker = create_attack(
attack_name='FGSM',
model=resnet50_model,
normalize=resnet50_model.normalize,
device=device,
eps=eps,
attack_cfg=attack_cfg,
)
assert attacker.eps == eps


def test_create_attack_with_attack_cfg_eps(device, resnet50_model):
attack_cfg = {'eps': 0.1}
attacker = create_attack(
attack_name='FGSM',
model=resnet50_model,
normalize=resnet50_model.normalize,
device=device,
attack_cfg=attack_cfg,
)
assert attacker.eps == attack_cfg['eps']


def test_create_attack_with_both_eps_and_attack_cfg(device, resnet50_model):
eps = 0.3
attack_cfg = {'eps': 0.1}
# with pytest.warns(
# UserWarning,
# match="'eps' in 'attack_cfg' (0.1) will be overwritten by the 'eps' argument value (0.3), which MAY NOT be intended.",
# ):
attacker = create_attack(
attack_name='FGSM',
model=resnet50_model,
normalize=resnet50_model.normalize,
device=device,
eps=eps,
attack_cfg=attack_cfg,
)
assert attacker.eps == eps


def test_create_attack_with_invalid_eps(device, resnet50_model):
eps = 0.3
with pytest.warns(
UserWarning, match="parameter 'eps' is invalid in DeepFool and will be ignored."
):
attacker = create_attack(
attack_name='DeepFool',
model=resnet50_model,
normalize=resnet50_model.normalize,
device=device,
eps=eps,
)
assert 'eps' not in attacker.__dict__


def test_create_attack_with_invalid_attack_name(device, resnet50_model):
with pytest.raises(
ValueError, match="Attack 'InvalidAttack' is not supported within torchattack."
):
create_attack(
attack_name='InvalidAttack',
model=resnet50_model,
normalize=resnet50_model.normalize,
device=device,
)
41 changes: 41 additions & 0 deletions tests/test_metric.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,41 @@
import pytest
import torch

from torchattack.eval import FoolingRateMetric


@pytest.fixture()
def metric():
return FoolingRateMetric()


def test_initial_state(metric):
assert metric.total_count.item() == 0
assert metric.clean_count.item() == 0
assert metric.adv_count.item() == 0


def test_update(metric):
labels = torch.tensor([0, 1, 2])
clean_logits = torch.tensor([[0.9, 0.1, 0.0], [0.1, 0.8, 0.1], [0.2, 0.2, 0.6]])
adv_logits = torch.tensor([[0.1, 0.8, 0.1], [0.2, 0.6, 0.2], [0.9, 0.1, 0.0]])

metric.update(labels, clean_logits, adv_logits)

assert metric.total_count.item() == 3
assert metric.clean_count.item() == 3 # all clean samples are correctly classified
assert metric.adv_count.item() == 1 # only the 2nd sample is correctly classified


def test_compute(metric):
labels = torch.tensor([0, 1, 2])
clean_logits = torch.tensor([[0.9, 0.1, 0.0], [0.1, 0.8, 0.1], [0.2, 0.2, 0.6]])
adv_logits = torch.tensor([[0.1, 0.8, 0.1], [0.2, 0.6, 0.2], [0.9, 0.1, 0.0]])

metric.update(labels, clean_logits, adv_logits)
clean_acc, adv_acc, fooling_rate = metric.compute()

assert clean_acc.item() == pytest.approx(3 / 3)
assert adv_acc.item() == pytest.approx(1 / 3)
# fooling_rate = (clean_count - adv_count) / clean_count
assert fooling_rate.item() == pytest.approx((3 - 1) / 3)
38 changes: 36 additions & 2 deletions torchattack/create_attack.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,12 @@
from torchattack.attack_model import AttackModel


def attack_warn(message: str) -> None:
from warnings import warn

warn(message, category=UserWarning, stacklevel=2)


def create_attack(
attack_name: str,
model: nn.Module | AttackModel,
Expand All @@ -16,15 +22,43 @@ def create_attack(
eps: float | None = None,
attack_cfg: dict[str, Any] | None = None,
) -> Attack:
"""Create a torchattack instance based on the provided attack name and config.
Args:
attack_name: The name of the attack to create.
model: The model to be attacked.
normalize: The normalization function specific to the model. Defaults to None.
device: The device on which the attack will be executed. Defaults to None.
eps: The epsilon value for the attack. Defaults to None.
attack_cfg: Additional config parameters for the attack. Defaults to None.
Returns:
Attack: An instance of the specified attack.
Raises:
ValueError: If the specified attack name is not supported within torchattack.
Notes:
- If `eps` is provided and also present in `attack_cfg`, a warning will be
issued and the value in `attack_cfg` will be overwritten.
- For certain attacks like 'GeoDA' and 'DeepFool', the `eps` parameter is
invalid and will be ignored if present in `attack_cfg`.
"""

if attack_cfg is None:
attack_cfg = {}
if eps is not None:
if 'eps' in attack_cfg:
print('Warning: `eps` in `attack_cfg` will be overwritten.')
attack_warn(
f"'eps' in 'attack_cfg' ({attack_cfg['eps']}) will be overwritten "
f"by the 'eps' argument value ({eps}), which MAY NOT be intended."
)
attack_cfg['eps'] = eps
if attack_name in ['GeoDA', 'DeepFool'] and 'eps' in attack_cfg:
print(f'Warning: `eps` is invalid in `{attack_name}` and will be ignored.')
attack_warn(f"parameter 'eps' is invalid in {attack_name} and will be ignored.")
attack_cfg.pop('eps', None)
if not hasattr(torchattack, attack_name):
raise ValueError(f"Attack '{attack_name}' is not supported within torchattack.")
attacker_cls: Attack = getattr(torchattack, attack_name)
return attacker_cls(model, normalize, device, **attack_cfg)

Expand Down

0 comments on commit f5db6b0

Please sign in to comment.