Skip to content

Commit

Permalink
Add support for BIA
Browse files Browse the repository at this point in the history
  • Loading branch information
spencerwooo committed Nov 26, 2024
1 parent 066ac8f commit 5351149
Show file tree
Hide file tree
Showing 6 changed files with 104 additions and 5 deletions.
4 changes: 4 additions & 0 deletions torchattack/__init__.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
from torchattack.admix import Admix
from torchattack.attack_model import AttackModel
from torchattack.bia import BIA
from torchattack.cda import CDA
from torchattack.create_attack import create_attack
from torchattack.decowa import DeCoWA
from torchattack.deepfool import DeepFool
Expand Down Expand Up @@ -30,6 +32,8 @@
'AttackModel',
# All supported attacks
'Admix',
'BIA',
'CDA',
'DeCoWA',
'DeepFool',
'DIFGSM',
Expand Down
91 changes: 91 additions & 0 deletions torchattack/bia.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,91 @@
import torch

from torchattack.generative._inference import GenerativeAttack
from torchattack.generative._weights import Weights, WeightsEnum
from torchattack.generative.resnet_generator import ResNetGenerator


class BIAWeights(WeightsEnum):
RESNET152 = Weights(
url='https://github.com/spencerwooo/torchattack/releases/download/v1.0-weights/bia_resnet152_0.pth',
)
RESNET152_RN = Weights(
url='https://github.com/spencerwooo/torchattack/releases/download/v1.0-weights/bia_resnet152_rn_0.pth',
)
RESNET152_DA = Weights(
url='https://github.com/spencerwooo/torchattack/releases/download/v1.0-weights/bia_resnet152_da_0.pth',
)
DENSENET169 = Weights(
url='https://github.com/spencerwooo/torchattack/releases/download/v1.0-weights/bia_densenet169_0.pth',
)
DENSENET169_RN = Weights(
url='https://github.com/spencerwooo/torchattack/releases/download/v1.0-weights/bia_densenet169_rn_0.pth',
)
DENSENET169_DA = Weights(
url='https://github.com/spencerwooo/torchattack/releases/download/v1.0-weights/bia_densenet169_da_0.pth',
)
VGG16 = Weights(
url='https://github.com/spencerwooo/torchattack/releases/download/v1.0-weights/bia_vgg16_0.pth',
)
VGG16_RN = Weights(
url='https://github.com/spencerwooo/torchattack/releases/download/v1.0-weights/bia_vgg16_rn_0.pth',
)
VGG16_DA = Weights(
url='https://github.com/spencerwooo/torchattack/releases/download/v1.0-weights/bia_vgg16_da_0.pth',
)
VGG19 = Weights(
url='https://github.com/spencerwooo/torchattack/releases/download/v1.0-weights/bia_vgg19_0.pth',
)
VGG19_RN = Weights(
url='https://github.com/spencerwooo/torchattack/releases/download/v1.0-weights/bia_vgg19_rn_0.pth',
)
VGG19_DA = Weights(
url='https://github.com/spencerwooo/torchattack/releases/download/v1.0-weights/bia_vgg19_da_0.pth',
)
DEFAULT = RESNET152_DA


class BIA(GenerativeAttack):
"""Beyond ImageNet Attack (BIA).
From the paper 'Beyond ImageNet Attack: Towards Crafting Adversarial Examples for
Black-box Domains', https://arxiv.org/abs/2201.11528
Args:
device: Device to use for tensors. Defaults to cuda if available. eps: The
maximum perturbation. Defaults to 10/255. weights: Pretrained weights for the
generator. Either import and use the enum,
or use its name. Defaults to BIAWeights.DEFAULT.
checkpoint_path: Path to a custom checkpoint. Defaults to None. clip_min:
Minimum value for clipping. Defaults to 0.0. clip_max: Maximum value for
clipping. Defaults to 1.0.
"""

def __init__(
self,
device: torch.device | None = None,
eps: float = 10 / 255,
weights: BIAWeights | str | None = BIAWeights.DEFAULT,
checkpoint_path: str | None = None,
clip_min: float = 0.0,
clip_max: float = 1.0,
) -> None:
super().__init__(device, eps, weights, checkpoint_path, clip_min, clip_max)

def _init_generator(self) -> ResNetGenerator:
generator = ResNetGenerator()
# Prioritize checkpoint path over provided weights enum
if self.checkpoint_path is not None:
generator.load_state_dict(torch.load(self.checkpoint_path))
else:
# Verify and load weights from enum if checkpoint path is not provided
self.weights: BIAWeights = BIAWeights.verify(self.weights)
if self.weights is not None:
generator.load_state_dict(self.weights.get_state_dict(check_hash=True))
return generator.eval().to(self.device)


if __name__ == '__main__':
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
attack = BIA(device, weights='VGG19_DA')
print(attack)
6 changes: 4 additions & 2 deletions torchattack/cda.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,9 @@ class CDA(GenerativeAttack):
Args:
device: Device to use for tensors. Defaults to cuda if available.
eps: The maximum perturbation. Defaults to 10/255.
weights: Pretrained weights for the generator. Defaults to CDAWeights.DEFAULT.
weights: Pretrained weights for the generator. Either import and use the enum,
or use its name. Defaults to CDAWeights.DEFAULT.
checkpoint_path: Path to a custom checkpoint. Defaults to None.
clip_min: Minimum value for clipping. Defaults to 0.0.
clip_max: Maximum value for clipping. Defaults to 1.0.
"""
Expand Down Expand Up @@ -61,5 +63,5 @@ def _init_generator(self) -> ResNetGenerator:

if __name__ == '__main__':
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
attack = CDA(device, eps=8 / 255, weights='VGG19_IMAGENET1K')
attack = CDA(device, weights='VGG19_IMAGENET1K')
print(attack)
2 changes: 2 additions & 0 deletions torchattack/create_attack.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,8 @@ def create_attack(
if not hasattr(torchattack, attack_name):
raise ValueError(f"Attack '{attack_name}' is not supported within torchattack.")
attacker_cls: Attack = getattr(torchattack, attack_name)
if attack_name in ['BIA', 'CDA']:
return attacker_cls(device, **attack_cfg)
return attacker_cls(model, normalize, device, **attack_cfg)


Expand Down
4 changes: 2 additions & 2 deletions torchattack/eval/runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -116,7 +116,7 @@ def run_attack(

parser = argparse.ArgumentParser(description='Run an attack on a model.')
parser.add_argument('--attack', type=str, required=True)
parser.add_argument('--eps', type=str, default='16/255')
parser.add_argument('--eps', type=float, default=16)
parser.add_argument('--model-name', type=str, default='resnet50')
parser.add_argument('--victim-model-names', type=str, nargs='+', default=None)
parser.add_argument('--dataset-root', type=str, default='datasets/nips2017')
Expand All @@ -126,7 +126,7 @@ def run_attack(

run_attack(
attack=args.attack,
attack_cfg={'eps': float(args.eps)},
attack_cfg={'eps': args.eps / 255},
model_name=args.model_name,
victim_model_names=args.victim_model_names,
dataset_root=args.dataset_root,
Expand Down
2 changes: 1 addition & 1 deletion torchattack/generative/_inference.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ def __init__(
# Initialize the generator and its weights
self.generator = self._init_generator()

def forward(self, x: torch.Tensor) -> torch.Tensor:
def forward(self, x: torch.Tensor, *args: Any, **kwargs: Any) -> torch.Tensor:
"""Perform the generative attack via generator inference on a batch of images.
Args:
Expand Down

0 comments on commit 5351149

Please sign in to comment.