diff --git a/torchattack/create_attack.py b/torchattack/create_attack.py index a989ae3..18fdee6 100644 --- a/torchattack/create_attack.py +++ b/torchattack/create_attack.py @@ -7,6 +7,9 @@ from torchattack._attack import Attack from torchattack.attack_model import AttackModel +_generative_attacks = ['BIA', 'CDA'] +_non_eps_attacks = ['GeoDA', 'DeepFool'] + def attack_warn(message: str) -> None: from warnings import warn @@ -20,13 +23,15 @@ def create_attack( normalize: Callable[[torch.Tensor], torch.Tensor] | None = None, device: torch.device | None = None, eps: float | None = None, + weights: str | None = None, + checkpoint_path: str | None = None, attack_cfg: dict[str, Any] | None = None, ) -> Attack: """Create a torchattack instance based on the provided attack name and config. Args: attack_name: The name of the attack to create. - model: The model to be attacked. + model: The model to be attacked. Defaults to None. normalize: The normalization function specific to the model. Defaults to None. device: The device on which the attack will be executed. Defaults to None. eps: The epsilon value for the attack. Defaults to None. @@ -54,13 +59,24 @@ def create_attack( f"by the 'eps' argument value ({eps}), which MAY NOT be intended." ) attack_cfg['eps'] = eps - if attack_name in ['GeoDA', 'DeepFool'] and 'eps' in attack_cfg: + if attack_name in _non_eps_attacks and 'eps' in attack_cfg: attack_warn(f"parameter 'eps' is invalid in {attack_name} and will be ignored.") attack_cfg.pop('eps', None) + if attack_name not in _generative_attacks and ( + weights is not None or checkpoint_path is not None + ): + attack_warn( + f'weights and checkpoint_path are only used for generative attacks, ' + f"and will be ignored for '{attack_name}'." + ) + attack_cfg.pop('weights', None) + attack_cfg.pop('checkpoint_path', None) if not hasattr(torchattack, attack_name): raise ValueError(f"Attack '{attack_name}' is not supported within torchattack.") attacker_cls: Attack = getattr(torchattack, attack_name) - if attack_name in ['BIA', 'CDA']: + if attack_name in _generative_attacks: + attack_cfg['weights'] = weights + attack_cfg['checkpoint_path'] = checkpoint_path return attacker_cls(device, **attack_cfg) return attacker_cls(model, normalize, device, **attack_cfg)