Skip to content

Commit

Permalink
Add weights/checkpoint_path args to create_attack
Browse files Browse the repository at this point in the history
  • Loading branch information
spencerwooo committed Nov 26, 2024
1 parent 5351149 commit 169f670
Showing 1 changed file with 19 additions and 3 deletions.
22 changes: 19 additions & 3 deletions torchattack/create_attack.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,9 @@
from torchattack._attack import Attack
from torchattack.attack_model import AttackModel

_generative_attacks = ['BIA', 'CDA']
_non_eps_attacks = ['GeoDA', 'DeepFool']


def attack_warn(message: str) -> None:
from warnings import warn
Expand All @@ -20,13 +23,15 @@ def create_attack(
normalize: Callable[[torch.Tensor], torch.Tensor] | None = None,
device: torch.device | None = None,
eps: float | None = None,
weights: str | None = None,
checkpoint_path: str | None = None,
attack_cfg: dict[str, Any] | None = None,
) -> Attack:
"""Create a torchattack instance based on the provided attack name and config.
Args:
attack_name: The name of the attack to create.
model: The model to be attacked.
model: The model to be attacked. Defaults to None.
normalize: The normalization function specific to the model. Defaults to None.
device: The device on which the attack will be executed. Defaults to None.
eps: The epsilon value for the attack. Defaults to None.
Expand Down Expand Up @@ -54,13 +59,24 @@ def create_attack(
f"by the 'eps' argument value ({eps}), which MAY NOT be intended."
)
attack_cfg['eps'] = eps
if attack_name in ['GeoDA', 'DeepFool'] and 'eps' in attack_cfg:
if attack_name in _non_eps_attacks and 'eps' in attack_cfg:
attack_warn(f"parameter 'eps' is invalid in {attack_name} and will be ignored.")
attack_cfg.pop('eps', None)
if attack_name not in _generative_attacks and (
weights is not None or checkpoint_path is not None
):
attack_warn(
f'weights and checkpoint_path are only used for generative attacks, '
f"and will be ignored for '{attack_name}'."
)
attack_cfg.pop('weights', None)
attack_cfg.pop('checkpoint_path', None)
if not hasattr(torchattack, attack_name):
raise ValueError(f"Attack '{attack_name}' is not supported within torchattack.")
attacker_cls: Attack = getattr(torchattack, attack_name)
if attack_name in ['BIA', 'CDA']:
if attack_name in _generative_attacks:
attack_cfg['weights'] = weights
attack_cfg['checkpoint_path'] = checkpoint_path
return attacker_cls(device, **attack_cfg)
return attacker_cls(model, normalize, device, **attack_cfg)

Expand Down

0 comments on commit 169f670

Please sign in to comment.