From 169f6703b081ce4257fcd05899f3ed631353d6dd Mon Sep 17 00:00:00 2001 From: spencerwooo Date: Tue, 26 Nov 2024 19:41:45 +0800 Subject: [PATCH] Add weights/checkpoint_path args to create_attack --- torchattack/create_attack.py | 22 +++++++++++++++++++--- 1 file changed, 19 insertions(+), 3 deletions(-) diff --git a/torchattack/create_attack.py b/torchattack/create_attack.py index a989ae3..18fdee6 100644 --- a/torchattack/create_attack.py +++ b/torchattack/create_attack.py @@ -7,6 +7,9 @@ from torchattack._attack import Attack from torchattack.attack_model import AttackModel +_generative_attacks = ['BIA', 'CDA'] +_non_eps_attacks = ['GeoDA', 'DeepFool'] + def attack_warn(message: str) -> None: from warnings import warn @@ -20,13 +23,15 @@ def create_attack( normalize: Callable[[torch.Tensor], torch.Tensor] | None = None, device: torch.device | None = None, eps: float | None = None, + weights: str | None = None, + checkpoint_path: str | None = None, attack_cfg: dict[str, Any] | None = None, ) -> Attack: """Create a torchattack instance based on the provided attack name and config. Args: attack_name: The name of the attack to create. - model: The model to be attacked. + model: The model to be attacked. Defaults to None. normalize: The normalization function specific to the model. Defaults to None. device: The device on which the attack will be executed. Defaults to None. eps: The epsilon value for the attack. Defaults to None. @@ -54,13 +59,24 @@ def create_attack( f"by the 'eps' argument value ({eps}), which MAY NOT be intended." ) attack_cfg['eps'] = eps - if attack_name in ['GeoDA', 'DeepFool'] and 'eps' in attack_cfg: + if attack_name in _non_eps_attacks and 'eps' in attack_cfg: attack_warn(f"parameter 'eps' is invalid in {attack_name} and will be ignored.") attack_cfg.pop('eps', None) + if attack_name not in _generative_attacks and ( + weights is not None or checkpoint_path is not None + ): + attack_warn( + f'weights and checkpoint_path are only used for generative attacks, ' + f"and will be ignored for '{attack_name}'." + ) + attack_cfg.pop('weights', None) + attack_cfg.pop('checkpoint_path', None) if not hasattr(torchattack, attack_name): raise ValueError(f"Attack '{attack_name}' is not supported within torchattack.") attacker_cls: Attack = getattr(torchattack, attack_name) - if attack_name in ['BIA', 'CDA']: + if attack_name in _generative_attacks: + attack_cfg['weights'] = weights + attack_cfg['checkpoint_path'] = checkpoint_path return attacker_cls(device, **attack_cfg) return attacker_cls(model, normalize, device, **attack_cfg)