Initial support for generative attacks
spencerwooo committed Nov 26, 2024
1 parent 6e3838d commit 066ac8f
Showing 8 changed files with 307 additions and 7 deletions.
2 changes: 1 addition & 1 deletion README.md
@@ -52,7 +52,7 @@ transform, normalize = model.transform, model.normalize
# Additionally, to explicitly specify where to load the pretrained model from (timm or torchvision),
# prepend the model name with 'timm/' or 'tv/' respectively, or use the `from_timm` argument, e.g.
vit_b16 = AttackModel.from_pretrained(model_name='timm/vit_base_patch16_224', device=device)
-inception_v3 = AttackModel.from_pretrained(model_name='tv/inception_v3', device=device)
+inv_v3 = AttackModel.from_pretrained(model_name='tv/inception_v3', device=device)
pit_b = AttackModel.from_pretrained(model_name='pit_b_224', device=device, from_timm=True)
```

18 changes: 13 additions & 5 deletions torchattack/_attack.py
@@ -12,14 +12,22 @@ class Attack(ABC):

def __init__(
self,
-        model: nn.Module | AttackModel,
+        model: nn.Module | AttackModel | None,
normalize: Callable[[torch.Tensor], torch.Tensor] | None,
device: torch.device | None,
) -> None:
super().__init__()

-        # If model is an AttackModel, use the model attribute
-        self.model = model.model if isinstance(model, AttackModel) else model
+        self.model = (
+            # If model is an AttackModel, use the model attribute
+            model.model
+            if isinstance(model, AttackModel)
+            # If model is a nn.Module, use the model itself
+            else model
+            if model is not None
+            # Otherwise, use an empty nn.Sequential acting as a dummy model
+            else nn.Sequential()
+        )

# Set device to given or defaults to cuda if available
is_cuda = torch.cuda.is_available()
@@ -29,7 +37,7 @@ def __init__(
self.normalize = normalize if normalize else lambda x: x

@abstractmethod
-    def forward(self, *args: Any, **kwds: Any):
+    def forward(self, *args: Any, **kwds: Any) -> Any:
pass

def __call__(self, *args: Any, **kwds: Any) -> Any:
Expand All @@ -41,7 +49,7 @@ def __repr__(self) -> str:
def repr_map(k, v):
if isinstance(v, float):
return f'{k}={v:.3f}'
-            if k in ['model', 'normalize', 'feature_layer', 'hooks']:
+            if k in ['model', 'normalize', 'feature_layer', 'hooks', 'generator']:
return f'{k}={v.__class__.__name__}'
if isinstance(v, torch.Tensor):
return f'{k}={v.shape}'
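With `model` now allowed to be `None`, attacks that need no surrogate model (such as the generative attacks added in this commit) fall back to an empty `nn.Sequential()` dummy. A minimal sketch of that fallback, using a hypothetical no-op subclass that is not part of this commit:

```python
import torch.nn as nn

from torchattack._attack import Attack


class NoOpAttack(Attack):
    """Hypothetical subclass used only to illustrate the model=None fallback."""

    def __init__(self, device=None):
        super().__init__(model=None, normalize=None, device=device)

    def forward(self, x):
        return x


atk = NoOpAttack()
assert isinstance(atk.model, nn.Sequential)   # dummy model ...
assert len(list(atk.model.children())) == 0   # ... which is empty
assert atk.normalize(42) == 42                # identity normalize fallback
```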
65 changes: 65 additions & 0 deletions torchattack/cda.py
@@ -0,0 +1,65 @@
import torch

from torchattack.generative._inference import GenerativeAttack
from torchattack.generative._weights import Weights, WeightsEnum
from torchattack.generative.resnet_generator import ResNetGenerator


class CDAWeights(WeightsEnum):
RESNET152_IMAGENET1K = Weights(
url='https://github.com/spencerwooo/torchattack/releases/download/v1.0-weights/cda_res152_imagenet_0_rl.pth',
)
INCEPTION_V3_IMAGENET1K = Weights(
url='https://github.com/spencerwooo/torchattack/releases/download/v1.0-weights/cda_incv3_imagenet_0_rl.pth',
)
VGG16_IMAGENET1K = Weights(
url='https://github.com/spencerwooo/torchattack/releases/download/v1.0-weights/cda_vgg16_imagenet_0_rl.pth',
)
VGG19_IMAGENET1K = Weights(
url='https://github.com/spencerwooo/torchattack/releases/download/v1.0-weights/cda_vgg19_imagenet_0_rl.pth',
)
DEFAULT = RESNET152_IMAGENET1K


class CDA(GenerativeAttack):
"""Cross-domain Attack (CDA).
From the paper 'Cross-Domain Transferability of Adversarial Perturbations',
https://arxiv.org/abs/1905.11736
Args:
device: Device to use for tensors. Defaults to cuda if available.
eps: The maximum perturbation. Defaults to 10/255.
weights: Pretrained weights for the generator. Defaults to CDAWeights.DEFAULT.
clip_min: Minimum value for clipping. Defaults to 0.0.
clip_max: Maximum value for clipping. Defaults to 1.0.
"""

def __init__(
self,
device: torch.device | None = None,
eps: float = 10 / 255,
weights: CDAWeights | str | None = CDAWeights.DEFAULT,
checkpoint_path: str | None = None,
clip_min: float = 0.0,
clip_max: float = 1.0,
) -> None:
super().__init__(device, eps, weights, checkpoint_path, clip_min, clip_max)

def _init_generator(self) -> ResNetGenerator:
generator = ResNetGenerator()
# Prioritize checkpoint path over provided weights enum
if self.checkpoint_path is not None:
generator.load_state_dict(torch.load(self.checkpoint_path))
else:
# Verify and load weights from enum if checkpoint path is not provided
self.weights: CDAWeights = CDAWeights.verify(self.weights)
if self.weights is not None:
generator.load_state_dict(self.weights.get_state_dict(check_hash=True))
return generator.eval().to(self.device)


if __name__ == '__main__':
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
attack = CDA(device, eps=8 / 255, weights='VGG19_IMAGENET1K')
print(attack)
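A short usage sketch (not part of the commit): the attack is a single forward pass through the pretrained generator, with the perturbation clamped to the eps-ball around the input and to `[clip_min, clip_max]`. It assumes the default CDAWeights checkpoint can be downloaded.

```python
import torch

from torchattack.cda import CDA

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
attack = CDA(device, eps=10 / 255)  # uses CDAWeights.DEFAULT (ResNet-152 generator)

x = torch.rand(4, 3, 224, 224, device=device)  # stand-in for a batch of images in [0, 1]
x_adv = attack(x)

assert x_adv.shape == x.shape
assert (x_adv - x).abs().max().item() <= 10 / 255 + 1e-6
assert x_adv.min().item() >= 0.0 and x_adv.max().item() <= 1.0
```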
2 changes: 1 addition & 1 deletion torchattack/create_attack.py
@@ -16,7 +16,7 @@ def attack_warn(message: str) -> None:

def create_attack(
attack_name: str,
-    model: nn.Module | AttackModel,
+    model: nn.Module | AttackModel | None = None,
normalize: Callable[[torch.Tensor], torch.Tensor] | None = None,
device: torch.device | None = None,
eps: float | None = None,
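Because generative attacks need no surrogate model, `create_attack` can now be called without one. A hedged sketch, assuming the factory resolves the attack name `'CDA'` (the registry itself is not shown in this hunk):

```python
import torch

from torchattack.create_attack import create_attack

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# 'CDA' as a registered attack name is an assumption; only the signature change is shown above.
attack = create_attack('CDA', device=device, eps=10 / 255)
x_adv = attack(torch.rand(4, 3, 224, 224, device=device))
```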
0 changes: 0 additions & 0 deletions torchattack/generative/__init__.py
Empty file.
49 changes: 49 additions & 0 deletions torchattack/generative/_inference.py
@@ -0,0 +1,49 @@
from abc import abstractmethod
from typing import Any

import torch

from torchattack._attack import Attack
from torchattack.generative._weights import WeightsEnum


class GenerativeAttack(Attack):
def __init__(
self,
device: torch.device | None = None,
eps: float = 10 / 255,
weights: WeightsEnum | str | None = None,
checkpoint_path: str | None = None,
clip_min: float = 0.0,
clip_max: float = 1.0,
) -> None:
# Generative attacks do not require specifying model and normalize.
super().__init__(model=None, normalize=None, device=device)

self.eps = eps
self.weights = weights
self.checkpoint_path = checkpoint_path
self.clip_min = clip_min
self.clip_max = clip_max

# Initialize the generator and its weights
self.generator = self._init_generator()

def forward(self, x: torch.Tensor) -> torch.Tensor:
"""Perform the generative attack via generator inference on a batch of images.
Args:
x: A batch of images. Shape: (N, C, H, W).
Returns:
The perturbed images if successful. Shape: (N, C, H, W).
"""

x_unrestricted = self.generator(x)
delta = torch.clamp(x_unrestricted - x, -self.eps, self.eps)
x_adv = torch.clamp(x + delta, self.clip_min, self.clip_max)
return x_adv

@abstractmethod
def _init_generator(self, *args: Any, **kwds: Any) -> Any:
pass
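For intuition on the clamping in `forward`: the raw generator output is unrestricted, so the perturbation is first clipped elementwise to `[-eps, eps]` and the perturbed image then clipped to `[clip_min, clip_max]`. A self-contained sketch of just that projection step (no generator weights involved):

```python
import torch

eps, clip_min, clip_max = 10 / 255, 0.0, 1.0

x = torch.rand(2, 3, 8, 8)                      # clean images in [0, 1]
x_unrestricted = x + 0.5 * torch.randn_like(x)  # stand-in for generator(x)

delta = torch.clamp(x_unrestricted - x, -eps, eps)
x_adv = torch.clamp(x + delta, clip_min, clip_max)

assert (x_adv - x).abs().max().item() <= eps + 1e-6
assert x_adv.min().item() >= clip_min and x_adv.max().item() <= clip_max
```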
34 changes: 34 additions & 0 deletions torchattack/generative/_weights.py
@@ -0,0 +1,34 @@
from dataclasses import dataclass
from enum import Enum
from typing import Any, Mapping

from torch.hub import load_state_dict_from_url


@dataclass
class Weights:
url: str


class WeightsEnum(Enum):
@classmethod
def verify(cls, obj: Any) -> Any:
if obj is not None:
if type(obj) is str:
obj = cls[obj.replace(cls.__name__ + '.', '')]
elif not isinstance(obj, cls):
raise TypeError(
f'Invalid Weight class provided; expected {cls.__name__} '
f'but received {obj.__class__.__name__}.'
)
return obj

def get_state_dict(self, *args: Any, **kwargs: Any) -> Mapping[str, Any]:
return load_state_dict_from_url(self.url, *args, **kwargs)

def __repr__(self) -> str:
return f'{self.__class__.__name__}.{self._name_}'

@property
def url(self):
return self.value.url
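A small sketch of how the weights enum is meant to be used, mirroring torchvision's weights-enum pattern; `MyWeights` and its URL are placeholders, not part of this commit:

```python
from torchattack.generative._weights import Weights, WeightsEnum


class MyWeights(WeightsEnum):
    # Hypothetical generator weights, for illustration only.
    GENERATOR_V1 = Weights(url='https://example.com/generator_v1.pth')
    DEFAULT = GENERATOR_V1  # alias, same pattern as CDAWeights.DEFAULT


w = MyWeights.verify('GENERATOR_V1')   # a string resolves to the enum member
assert w is MyWeights.GENERATOR_V1
assert MyWeights.verify(None) is None  # None passes through untouched
print(w.url)                           # the checkpoint URL backing get_state_dict()
# w.get_state_dict(check_hash=False) would download the checkpoint via torch.hub.
```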
144 changes: 144 additions & 0 deletions torchattack/generative/resnet_generator.py
@@ -0,0 +1,144 @@
import torch
import torch.nn as nn

# Base number of feature maps in the generator
ngf = 64


class ResNetGenerator(nn.Module):
def __init__(self, inception=False):
"""Generator network (ResNet).
Args:
inception: if True crop layer will be added to go from 3x300x300 to
3x299x299. Defaults to False.
"""

super(ResNetGenerator, self).__init__()

        # Input size: 3 x n x n
self.inception = inception
self.block1 = nn.Sequential(
nn.ReflectionPad2d(3),
nn.Conv2d(3, ngf, kernel_size=7, padding=0, bias=False),
nn.BatchNorm2d(ngf),
nn.ReLU(True),
)

        # Input size: ngf x n x n
self.block2 = nn.Sequential(
nn.Conv2d(ngf, ngf * 2, kernel_size=3, stride=2, padding=1, bias=False),
nn.BatchNorm2d(ngf * 2),
nn.ReLU(True),
)

        # Input size: (ngf * 2) x n/2 x n/2
self.block3 = nn.Sequential(
nn.Conv2d(ngf * 2, ngf * 4, kernel_size=3, stride=2, padding=1, bias=False),
nn.BatchNorm2d(ngf * 4),
nn.ReLU(True),
)

        # Input size: (ngf * 4) x n/4 x n/4
# Residual Blocks: 6
self.resblock1 = ResidualBlock(ngf * 4)
self.resblock2 = ResidualBlock(ngf * 4)
self.resblock3 = ResidualBlock(ngf * 4)
self.resblock4 = ResidualBlock(ngf * 4)
self.resblock5 = ResidualBlock(ngf * 4)
self.resblock6 = ResidualBlock(ngf * 4)

        # Input size: (ngf * 4) x n/4 x n/4
self.upsampl1 = nn.Sequential(
nn.ConvTranspose2d(
ngf * 4,
ngf * 2,
kernel_size=3,
stride=2,
padding=1,
output_padding=1,
bias=False,
),
nn.BatchNorm2d(ngf * 2),
nn.ReLU(True),
)

        # Input size: (ngf * 2) x n/2 x n/2
self.upsampl2 = nn.Sequential(
nn.ConvTranspose2d(
ngf * 2,
ngf,
kernel_size=3,
stride=2,
padding=1,
output_padding=1,
bias=False,
),
nn.BatchNorm2d(ngf),
nn.ReLU(True),
)

        # Input size: ngf x n x n
self.blockf = nn.Sequential(
nn.ReflectionPad2d(3), nn.Conv2d(ngf, 3, kernel_size=7, padding=0)
)

self.crop = nn.ConstantPad2d((0, -1, -1, 0), 0)

def forward(self, input):
x = self.block1(input)
x = self.block2(x)
x = self.block3(x)
x = self.resblock1(x)
x = self.resblock2(x)
x = self.resblock3(x)
x = self.resblock4(x)
x = self.resblock5(x)
x = self.resblock6(x)
x = self.upsampl1(x)
x = self.upsampl2(x)
x = self.blockf(x)
if self.inception:
x = self.crop(x)
return (torch.tanh(x) + 1) / 2 # Output range [0 1]


class ResidualBlock(nn.Module):
def __init__(self, num_filters):
super(ResidualBlock, self).__init__()
self.block = nn.Sequential(
nn.ReflectionPad2d(1),
nn.Conv2d(
in_channels=num_filters,
out_channels=num_filters,
kernel_size=3,
stride=1,
padding=0,
bias=False,
),
nn.BatchNorm2d(num_filters),
nn.ReLU(True),
nn.Dropout(0.5),
nn.ReflectionPad2d(1),
nn.Conv2d(
in_channels=num_filters,
out_channels=num_filters,
kernel_size=3,
stride=1,
padding=0,
bias=False,
),
nn.BatchNorm2d(num_filters),
)

def forward(self, x):
residual = self.block(x)
return x + residual


if __name__ == '__main__':
net_g = ResNetGenerator()
test_sample = torch.rand(1, 3, 32, 32)
print('Generator output size:', net_g(test_sample).size())
params = sum(p.numel() for p in net_g.parameters() if p.requires_grad)
print('Generator params:', params)
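A complementary check (not in the commit) for the inception-sized path: with `inception=True`, the final `ConstantPad2d((0, -1, -1, 0), 0)` crop trims a 300x300 output to 299x299, matching Inception-v3's expected input size.

```python
import torch

from torchattack.generative.resnet_generator import ResNetGenerator

net_incv3 = ResNetGenerator(inception=True)
x = torch.rand(1, 3, 300, 300)
with torch.no_grad():
    out = net_incv3(x)
print('Inception-sized output:', out.size())  # expected: torch.Size([1, 3, 299, 299])
```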
