diff --git a/.gitignore b/.gitignore
index 3c2158a..685f77c 100644
--- a/.gitignore
+++ b/.gitignore
@@ -1,9 +1,8 @@
 # Dataset files
 datasets/
 
-# PDM files
-.pdm-python
-pdm.lock
+# Lockfiles
+uv.lock
 
 # Byte-compiled / optimized / DLL files
 __pycache__/
diff --git a/pyproject.toml b/pyproject.toml
index 674ac41..72a8b90 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -12,7 +12,7 @@ dependencies = [
     "scipy>=1.10.1",
 ]
 dynamic = ["version"]
-optional-dependencies = { dev = ["mypy", "rich"] }
+optional-dependencies = { dev = ["mypy", "rich", "timm"] }
 
 [build-system]
 requires = ["setuptools"]
diff --git a/src/torchattack/__init__.py b/src/torchattack/__init__.py
index 922e5ad..140c7eb 100644
--- a/src/torchattack/__init__.py
+++ b/src/torchattack/__init__.py
@@ -15,6 +15,7 @@
 from torchattack.ssp import SSP
 from torchattack.tgr import TGR
 from torchattack.tifgsm import TIFGSM
+from torchattack.vdc import VDC
 from torchattack.vmifgsm import VMIFGSM
 from torchattack.vnifgsm import VNIFGSM
 
@@ -38,6 +39,7 @@
     'SSP',
     'TGR',
     'TIFGSM',
+    'VDC',
     'VMIFGSM',
     'VNIFGSM',
 ]
diff --git a/src/torchattack/base.py b/src/torchattack/base.py
index 56fc562..7fdbebe 100644
--- a/src/torchattack/base.py
+++ b/src/torchattack/base.py
@@ -30,7 +30,7 @@ def __repr__(self) -> str:
         def repr_map(k, v):
             if isinstance(v, float):
                 return f'{k}={v:.3f}'
-            if k in ['model', 'normalize', 'feature_layer']:
+            if k in ['model', 'normalize', 'feature_layer', 'hooks']:
                 return f'{k}={v.__class__.__name__}'
             if isinstance(v, torch.Tensor):
                 return f'{k}={v.shape}'
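
For context, the `base.py` change above makes `__repr__` summarize the new `hooks` attribute by its class name instead of dumping every handle. A minimal illustration of the rendered output (not part of the patch):

```python
# Illustrative only: how the updated repr_map renders a `hooks` attribute.
hooks: list = []  # e.g. a list of torch.utils.hooks.RemovableHandle objects
print(f'hooks={hooks.__class__.__name__}')  # prints "hooks=list" instead of the full contents
```
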
diff --git a/src/torchattack/eval.py b/src/torchattack/eval.py
index 025ba94..06ac657 100644
--- a/src/torchattack/eval.py
+++ b/src/torchattack/eval.py
@@ -159,7 +159,10 @@ def from_pretrained(
             return cls(model_name, device, model, transform, normalize)
 
         except ValueError:
-            print('Model not found in torchvision.models, falling back to timm.')
+            print(
+                f'Warning: Model `{model_name}` not found in torchvision.models, '
+                'falling back to loading weights from timm.'
+            )
             return cls.from_pretrained(model_name, device, from_timm=True)
 
     def forward(self, x):
diff --git a/src/torchattack/pna_patchout.py b/src/torchattack/pna_patchout.py
index 443ad6f..f592939 100644
--- a/src/torchattack/pna_patchout.py
+++ b/src/torchattack/pna_patchout.py
@@ -95,7 +95,9 @@ def __init__(
         self.targeted = targeted
         self.lossfn = nn.CrossEntropyLoss()
 
+        # Register hooks
         if self.pna_skip:
+            self.hooks: list[torch.utils.hooks.RemovableHandle] = []
             self._register_vit_model_hook()
 
         # Set default image size and number of patches for PatchOut
@@ -143,6 +145,9 @@ def forward(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
             delta.grad.detach_()
             delta.grad.zero_()
 
+        for hook in self.hooks:
+            hook.remove()
+
         return x + delta
 
     def _register_vit_model_hook(self):
@@ -162,7 +167,8 @@ def attn_drop_mask_grad(
         # Register backward hook for layers specified in _supported_vit_cfg
         for layer in self._supported_vit_cfg[self.model_name]:
             module = rgetattr(self.model, layer)
-            module.register_backward_hook(drop_hook_func)
+            hook = module.register_backward_hook(drop_hook_func)
+            self.hooks.append(hook)
 
     def _apply_patch_out(self, delta: torch.Tensor, seed: int) -> torch.Tensor:
         delta_mask = torch.zeros_like(delta)
diff --git a/src/torchattack/tgr.py b/src/torchattack/tgr.py
index c928764..70a0031 100644
--- a/src/torchattack/tgr.py
+++ b/src/torchattack/tgr.py
@@ -61,6 +61,8 @@ def __init__(
         self.targeted = targeted
         self.lossfn = nn.CrossEntropyLoss()
 
+        # Register hooks
+        self.hooks: list[torch.utils.hooks.RemovableHandle] = []
         self._register_tgr_model_hooks()
 
     def forward(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
@@ -110,6 +112,9 @@ def forward(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
             delta.grad.detach_()
             delta.grad.zero_()
 
+        for hook in self.hooks:
+            hook.remove()
+
         return x + delta
 
     def _register_tgr_model_hooks(self):
@@ -348,10 +353,11 @@ def mlp_tgr(
 
         assert self.model_name in _supported_vit_cfg
 
-        for hook, layers in _supported_vit_cfg[self.model_name]:
+        for hook_func, layers in _supported_vit_cfg[self.model_name]:
             for layer in layers:
                 module = rgetattr(self.model, layer)
-                module.register_backward_hook(hook)
+                hook = module.register_backward_hook(hook_func)
+                self.hooks.append(hook)
 
 
 if __name__ == '__main__':
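
The `pna_patchout.py` and `tgr.py` changes above share one pattern: every `RemovableHandle` returned by `register_backward_hook` is now kept in `self.hooks`, and all hooks are removed at the end of `forward`, so repeated attacks do not leave stale hooks attached to the surrogate model. A minimal standalone sketch of that lifecycle (module and hook are illustrative):

```python
import torch
import torch.nn as nn

def noop_hook(module, grad_in, grad_out):
    return grad_in  # pass gradients through unchanged

model = nn.Sequential(nn.Linear(4, 4), nn.Linear(4, 1))
hooks = [m.register_backward_hook(noop_hook) for m in model]  # keep every handle

loss = model(torch.randn(2, 4)).sum()
loss.backward()  # hooks fire during the backward pass

for hook in hooks:
    hook.remove()  # model is back to its original, hook-free state
```
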
diff --git a/src/torchattack/vdc.py b/src/torchattack/vdc.py
new file mode 100644
index 0000000..63eadc2
--- /dev/null
+++ b/src/torchattack/vdc.py
@@ -0,0 +1,709 @@
+import importlib.util
+from functools import partial
+from typing import Callable
+
+import numpy as np
+import torch
+import torch.nn as nn
+
+from torchattack._rgetattr import rgetattr
+from torchattack.base import Attack
+
+
+class VDC(Attack):
+    """VDC (Virtual Dense Connection) attack for ViTs.
+
+    From the paper: 'Improving the Adversarial Transferability of Vision Transformers
+    with Virtual Dense Connection'
+    https://ojs.aaai.org/index.php/AAAI/article/view/28541
+
+    Args:
+        model: The model to attack.
+        model_name: The name of the model. Supported models are:
+            * 'vit_base_patch16_224'
+            * 'deit_base_distilled_patch16_224'
+            * 'pit_b_224'
+            * 'visformer_small'
+        normalize: A transform to normalize images.
+        device: Device to use for tensors. Defaults to cuda if available.
+        eps: The maximum perturbation. Defaults to 8/255.
+        steps: Number of steps. Defaults to 10.
+        alpha: Step size, `eps / steps` if None. Defaults to None.
+        decay: Momentum decay factor. Defaults to 1.0.
+        sample_num_batches: Number of batches to sample. Defaults to 130.
+        lambd: Lambda value for VDC. Defaults to 0.1.
+        clip_min: Minimum value for clipping. Defaults to 0.0.
+        clip_max: Maximum value for clipping. Defaults to 1.0.
+        targeted: Targeted attack if True. Defaults to False.
+    """
+
+    def __init__(
+        self,
+        model: nn.Module,
+        model_name: str,
+        normalize: Callable[[torch.Tensor], torch.Tensor] | None,
+        device: torch.device | None = None,
+        eps: float = 8 / 255,
+        steps: int = 10,
+        alpha: float | None = None,
+        decay: float = 1.0,
+        sample_num_batches: int = 130,
+        lambd: float = 0.1,
+        clip_min: float = 0.0,
+        clip_max: float = 1.0,
+        targeted: bool = False,
+    ):
+        # Check if timm is installed
+        if importlib.util.find_spec('timm') is None:
+            raise ModuleNotFoundError('VDC requires timm to be installed')
+
+        super().__init__(normalize, device)
+
+        self.model = model
+        self.model_name = model_name
+        self.eps = eps
+        self.steps = steps
+        self.alpha = alpha
+        self.decay = decay
+
+        self.sample_num_batches = sample_num_batches
+        self.lambd = lambd
+
+        # Default (3, 224, 224) image with ViT-B/16 16x16 patches
+        self.max_num_batches = int((224 / 16) ** 2)
+        self.crop_length = 16
+
+        self.clip_min = clip_min
+        self.clip_max = clip_max
+        self.targeted = targeted
+        self.lossfn = nn.CrossEntropyLoss()
+
+        # Global hooks and attack stage state for VDC
+        self.stage: list[np.ndarray] = []
+        self.hooks: list[torch.utils.hooks.RemovableHandle] = []
+
+        assert self.sample_num_batches <= self.max_num_batches
+
+    def forward(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
+        """Perform VDC on a batch of images.
+
+        Args:
+            x: A batch of images. Shape: (N, C, H, W).
+            y: A batch of labels. Shape: (N).
+
+        Returns:
+            The perturbed images if successful. Shape: (N, C, H, W).
+        """
+
+        g = torch.zeros_like(x)
+        delta = torch.zeros_like(x, requires_grad=True)
+
+        # If alpha is not given, set to eps / steps
+        if self.alpha is None:
+            self.alpha = self.eps / self.steps
+
+        class GradientRecorder:
+            """Gradient recorder for attention and MLP blocks."""
+
+            def __init__(self):
+                self.grad_records = []
+                self.grad_additions = []
+
+        # Perform VDC
+        for _ in range(self.steps):
+            # Initialize gradient recorders
+            self.attn_recorder = GradientRecorder()
+            self.mlp_recorder = GradientRecorder()
+
+            # Stage 1: Record gradients
+            self.current_mlp_block = 0
+            self.current_attn_block = 0
+            self._register_model_hooks(add_grad_mode=False)
+
+            # Compute loss
+            outs = self.model(self.normalize(x + delta))
+            loss = self.lossfn(outs, y)
+
+            if self.targeted:
+                loss = -loss
+
+            # Compute gradient
+            loss.backward()
+
+            if delta.grad is None:
+                continue
+
+            # Zero out gradient
+            delta.grad.detach_()
+            delta.grad.zero_()
+
+            for hook in self.hooks:
+                hook.remove()
+
+            # Stage 2: Update gradients by adding recorded gradients
+            self.current_mlp_block = 0
+            self.current_attn_block = 0
+            self._register_model_hooks(add_grad_mode=True)
+
+            # Compute loss 2nd time
+            outs = self.model(self.normalize(x + delta))
+            loss = self.lossfn(outs, y)
+
+            if self.targeted:
+                loss = -loss
+
+            # Compute gradient 2nd time
+            loss.backward()
+
+            if delta.grad is None:
+                continue
+
+            # Apply momentum term
+            g = self.decay * g + delta.grad / torch.mean(
+                torch.abs(delta.grad), dim=(1, 2, 3), keepdim=True
+            )
+
+            # Update delta
+            delta.data = delta.data + self.alpha * g.sign()
+            delta.data = torch.clamp(delta.data, -self.eps, self.eps)
+            delta.data = torch.clamp(x + delta.data, self.clip_min, self.clip_max) - x
+
+            # Zero out gradient
+            delta.grad.detach_()
+            delta.grad.zero_()
+
+            for hook in self.hooks:
+                hook.remove()
+
+        return x + delta
+
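
A minimal usage sketch for the class so far, assuming `timm` is installed and that the `Attack` base class makes instances callable (the normalization values below are the standard ImageNet statistics; the `__main__` block at the end of the file shows the repo's own `run_attack` entry point):

```python
import timm
import torch
from torchvision.transforms import Normalize

from torchattack import VDC

model = timm.create_model('vit_base_patch16_224', pretrained=True).eval()
normalize = Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225))

attack = VDC(model, 'vit_base_patch16_224', normalize)
x = torch.rand(1, 3, 224, 224)  # images in [0, 1]
y = torch.tensor([0])           # ground-truth labels
x_adv = attack(x, y)
```
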
+    def _register_model_hooks(self, add_grad_mode: bool = False) -> None:
+        """Register hooks to either record or add gradients during the backward pass.
+
+        Args:
+            add_grad_mode: If False, register hooks to record gradients. If True,
+                register hooks to modify the gradients by adding pre-recorded
+                gradients during the backward pass.
+        """
+
+        def mlp_record_vit_stage(
+            module: torch.nn.Module,
+            grad_in: tuple[torch.Tensor, ...],
+            grad_out: tuple[torch.Tensor, ...],
+            gamma: float,
+        ) -> tuple[torch.Tensor, ...]:
+            mask = torch.ones_like(grad_in[0]) * gamma
+            out_grad = mask * grad_in[0][:]
+            # ablation
+            grad_record = (
+                grad_in[0].data.cpu().numpy() * 0.1 * (0.5 ** (self.current_mlp_block))
+            )
+            # grad_record = grad_in[0].data.cpu().numpy()
+            if self.current_mlp_block == 0:
+                grad_add = np.zeros_like(grad_record)
+                # ablation
+                grad_add[:, 0, :] = self.norm_list[:, 0, :] * 0.1 * (0.5)
+                # grad_add[:,0,:] = self.norm[:,0,:]
+                self.mlp_recorder.grad_additions.append(grad_add)
+                self.mlp_recorder.grad_records.append(grad_record + grad_add)
+            else:
+                self.mlp_recorder.grad_additions.append(
+                    self.mlp_recorder.grad_records[-1]
+                )
+                total_mlp = self.mlp_recorder.grad_records[-1] + grad_record
+                self.mlp_recorder.grad_records.append(total_mlp)
+            self.current_mlp_block += 1
+            return (out_grad, grad_in[1], grad_in[2])
+
+        def mlp_add_vit(
+            module: torch.nn.Module,
+            grad_in: tuple[torch.Tensor, ...],
+            grad_out: tuple[torch.Tensor, ...],
+            gamma: float,
+        ) -> tuple[torch.Tensor, ...]:
+            # grad_record = grad_in[0].data.cpu().numpy()
+            mask = torch.ones_like(grad_in[0]) * gamma
+            # mask_0 = torch.zeros_like(grad_in[0])
+            out_grad = mask * grad_in[0][:]
+            # out_grad = torch.where(
+            #     grad_in[0][:] > 0, mask * grad_in[0][:], mask_0 * grad_in[0][:]
+            # )
+            out_grad += torch.tensor(
+                self.mlp_recorder.grad_additions[self.current_mlp_block],
+                device=grad_in[0].device,
+            )
+            self.current_mlp_block += 1
+            return (out_grad, grad_in[1], grad_in[2])
+
+        def attn_record_vit_stage(
+            module: torch.nn.Module,
+            grad_in: tuple[torch.Tensor, ...],
+            grad_out: tuple[torch.Tensor, ...],
+            gamma: float,
+        ) -> tuple[torch.Tensor, ...]:
+            mask = torch.ones_like(grad_in[0]) * gamma
+            out_grad = mask * grad_in[0][:]
+            grad_record = (
+                grad_in[0].data.cpu().numpy() * 0.1 * (0.5 ** (self.current_attn_block))
+            )
+            # grad_record = grad_in[0].data.cpu().numpy()
+            if self.current_attn_block == 0:
+                self.attn_recorder.grad_additions.append(np.zeros_like(grad_record))
+                self.attn_recorder.grad_records.append(grad_record)
+            else:
+                self.attn_recorder.grad_additions.append(
+                    self.attn_recorder.grad_records[-1]
+                )
+                total_attn = self.attn_recorder.grad_records[-1] + grad_record
+                self.attn_recorder.grad_records.append(total_attn)
+
+            self.current_attn_block += 1
+            return (out_grad,)
+
+        def attn_add_vit(
+            module: torch.nn.Module,
+            grad_in: tuple[torch.Tensor, ...],
+            grad_out: tuple[torch.Tensor, ...],
+            gamma: float,
+        ) -> tuple[torch.Tensor, ...]:
+            # grad_record = grad_in[0].data.cpu().numpy()
+            # mask_0 = torch.zeros_like(grad_in[0])
+            mask = torch.ones_like(grad_in[0]) * gamma
+            out_grad = mask * grad_in[0][:]
+            # out_grad = torch.where(
+            #     grad_in[0][:] > 0, mask * grad_in[0][:], mask_0 * grad_in[0][:]
+            # )
+            out_grad += torch.tensor(
+                self.attn_recorder.grad_additions[self.current_attn_block],
+                device=grad_in[0].device,
+            )
+            self.current_attn_block += 1
+            return (out_grad,)
+
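
The four closures above implement the two-pass scheme driven by `forward`: in stage 1 the `*_record_*` hooks stash exponentially decayed copies of each block's `grad_in`, and in stage 2 the `*_add_*` hooks inject the accumulated gradients of earlier blocks back into the backward pass, emulating dense skip connections. A toy sketch of the same record-then-add mechanism on a single activation (illustrative, not the attack itself):

```python
import torch
import torch.nn as nn

records: list[torch.Tensor] = []

def record_hook(module, grad_in, grad_out):
    records.append(grad_in[0].detach() * 0.1)  # stage 1: stash a decayed copy
    return grad_in

def add_hook(module, grad_in, grad_out):
    return (grad_in[0] + records[0],)  # stage 2: re-inject the recorded gradient

act = nn.ReLU()
x = torch.randn(3, requires_grad=True)

handle = act.register_backward_hook(record_hook)
act(x).sum().backward()
handle.remove()

x.grad = None
handle = act.register_backward_hook(add_hook)
act(x).sum().backward()  # the gradient reaching x now includes the recorded term
handle.remove()
```
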
+        def norm_record_vit(
+            module: torch.nn.Module,
+            grad_in: tuple[torch.Tensor, ...],
+            grad_out: tuple[torch.Tensor, ...],
+            gamma: float,
+        ) -> tuple[torch.Tensor, ...]:
+            grad_record = grad_in[0].data.cpu().numpy()
+            # mask = torch.ones_like(grad_in[0]) * gamma
+            self.norm_list = grad_record
+            return grad_in
+
+        # pit
+        def pool_record_pit(
+            module: torch.nn.Module,
+            grad_in: tuple[torch.Tensor, ...],
+            grad_out: tuple[torch.Tensor, ...],
+            gamma: float,
+        ) -> tuple[torch.Tensor, ...]:
+            grad_add = grad_in[0].data
+            b, c, h, w = grad_add.shape
+            grad_add = grad_add.reshape((b, c, h * w)).transpose(1, 2)
+            self.stage.append(grad_add.cpu().numpy())
+            return grad_in
+
+        def mlp_record_pit_stage(
+            module: torch.nn.Module,
+            grad_in: tuple[torch.Tensor, ...],
+            grad_out: tuple[torch.Tensor, ...],
+            gamma: float,
+        ) -> tuple[torch.Tensor, ...]:
+            mask = torch.ones_like(grad_in[0]) * gamma
+            out_grad = mask * grad_in[0][:]
+            if self.current_mlp_block < 4:
+                grad_record = (
+                    grad_in[0].data.cpu().numpy()
+                    * 0.03
+                    * (0.5 ** (self.current_mlp_block))
+                )
+                if self.current_mlp_block == 0:
+                    grad_add = np.zeros_like(grad_record)
+                    grad_add[:, 0, :] = self.norm_list[:, 0, :] * 0.03 * (0.5)
+                    self.mlp_recorder.grad_additions.append(grad_add)
+                    self.mlp_recorder.grad_records.append(grad_record + grad_add)
+                else:
+                    self.mlp_recorder.grad_additions.append(
+                        self.mlp_recorder.grad_records[-1]
+                    )
+                    total_mlp = self.mlp_recorder.grad_records[-1] + grad_record
+                    self.mlp_recorder.grad_records.append(total_mlp)
+            elif self.current_mlp_block < 10:
+                grad_record = (
+                    grad_in[0].data.cpu().numpy()
+                    * 0.03
+                    * (0.5 ** (self.current_mlp_block))
+                )
+                if self.current_mlp_block == 4:
+                    grad_add = np.zeros_like(grad_record)
+                    grad_add[:, 1:, :] = self.stage[0] * 0.03 * (0.5)
+                    self.mlp_recorder.grad_additions.append(grad_add)
+                    self.mlp_recorder.grad_records.append(grad_record + grad_add)
+                else:
+                    self.mlp_recorder.grad_additions.append(
+                        self.mlp_recorder.grad_records[-1]
+                    )
+                    # total_mlp = self.mlp_rec.record[-1] + grad_record
+                    total_mlp = self.mlp_recorder.grad_records[-1]
+                    self.mlp_recorder.grad_records.append(total_mlp)
+            else:
+                grad_record = (
+                    grad_in[0].data.cpu().numpy()
+                    * 0.03
+                    * (0.5 ** (self.current_mlp_block))
+                )
+                if self.current_mlp_block == 10:
+                    grad_add = np.zeros_like(grad_record)
+                    grad_add[:, 1:, :] = self.stage[1] * 0.03 * (0.5)
+                    self.mlp_recorder.grad_additions.append(grad_add)
+                    self.mlp_recorder.grad_records.append(grad_record + grad_add)
+                else:
+                    self.mlp_recorder.grad_additions.append(
+                        self.mlp_recorder.grad_records[-1]
+                    )
+                    # total_mlp = self.mlp_rec.record[-1] + grad_record
+                    total_mlp = self.mlp_recorder.grad_records[-1]
+                    self.mlp_recorder.grad_records.append(total_mlp)
+            self.current_mlp_block += 1
+
+            return (out_grad, grad_in[1], grad_in[2])
+
+        def mlp_add_pit(
+            module: torch.nn.Module,
+            grad_in: tuple[torch.Tensor, ...],
+            grad_out: tuple[torch.Tensor, ...],
+            gamma: float,
+        ) -> tuple[torch.Tensor, ...]:
+            # grad_record = grad_in[0].data.cpu().numpy()
+            mask = torch.ones_like(grad_in[0]) * gamma
+            out_grad = mask * grad_in[0][:]
+            out_grad += torch.tensor(
+                self.mlp_recorder.grad_additions[self.current_mlp_block],
+                device=grad_in[0].device,
+            )
+            self.current_mlp_block += 1
+            return (out_grad, grad_in[1], grad_in[2])
+
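
`pool_record_pit` above converts the pooled feature-map gradient from image layout to token layout so it can later be added to transformer-block gradients. A quick shape check with illustrative dimensions:

```python
import torch

g = torch.randn(2, 256, 14, 14)  # (B, C, H, W) gradient at a PiT pooling layer
tokens = g.reshape(2, 256, 14 * 14).transpose(1, 2)
print(tokens.shape)  # torch.Size([2, 196, 256]), i.e. the (B, H*W, C) token layout
```
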
+        def attn_record_pit_stage(
+            module: torch.nn.Module,
+            grad_in: tuple[torch.Tensor, ...],
+            grad_out: tuple[torch.Tensor, ...],
+            gamma: float,
+        ) -> tuple[torch.Tensor, ...]:
+            mask = torch.ones_like(grad_in[0]) * gamma
+            out_grad = mask * grad_in[0][:]
+            if self.current_attn_block < 4:
+                grad_record = (
+                    grad_in[0].data.cpu().numpy()
+                    * 0.03
+                    * (0.5 ** (self.current_attn_block))
+                )
+                if self.current_attn_block == 0:
+                    self.attn_recorder.grad_additions.append(np.zeros_like(grad_record))
+                    self.attn_recorder.grad_records.append(grad_record)
+                else:
+                    self.attn_recorder.grad_additions.append(
+                        self.attn_recorder.grad_records[-1]
+                    )
+                    total_attn = self.attn_recorder.grad_records[-1] + grad_record
+                    self.attn_recorder.grad_records.append(total_attn)
+            elif self.current_attn_block < 10:
+                grad_record = (
+                    grad_in[0].data.cpu().numpy()
+                    * 0.03
+                    * (0.5 ** (self.current_attn_block))
+                )
+                if self.current_attn_block == 4:
+                    self.attn_recorder.grad_additions.append(np.zeros_like(grad_record))
+                    self.attn_recorder.grad_records.append(grad_record)
+                else:
+                    self.attn_recorder.grad_additions.append(
+                        self.attn_recorder.grad_records[-1]
+                    )
+                    # total_attn = self.attn_rec.record[-1] + grad_record
+                    total_attn = self.attn_recorder.grad_records[-1]
+                    self.attn_recorder.grad_records.append(total_attn)
+            else:
+                grad_record = (
+                    grad_in[0].data.cpu().numpy()
+                    * 0.03
+                    * (0.5 ** (self.current_attn_block))
+                )
+                if self.current_attn_block == 10:
+                    self.attn_recorder.grad_additions.append(np.zeros_like(grad_record))
+                    self.attn_recorder.grad_records.append(grad_record)
+                else:
+                    self.attn_recorder.grad_additions.append(
+                        self.attn_recorder.grad_records[-1]
+                    )
+                    # total_attn = self.attn_rec.record[-1] + grad_record
+                    total_attn = self.attn_recorder.grad_records[-1]
+                    self.attn_recorder.grad_records.append(total_attn)
+            self.current_attn_block += 1
+            return (out_grad,)
+
+        def attn_add_pit(
+            module: torch.nn.Module,
+            grad_in: tuple[torch.Tensor, ...],
+            grad_out: tuple[torch.Tensor, ...],
+            gamma: float,
+        ) -> tuple[torch.Tensor, ...]:
+            # grad_record = grad_in[0].data.cpu().numpy()
+            mask = torch.ones_like(grad_in[0]) * gamma
+            out_grad = mask * grad_in[0][:]
+            out_grad += torch.tensor(
+                self.attn_recorder.grad_additions[self.current_attn_block],
+                device=grad_in[0].device,
+            )
+            self.current_attn_block += 1
+            return (out_grad,)
+
+        def norm_record_pit(
+            module: torch.nn.Module,
+            grad_in: tuple[torch.Tensor, ...],
+            grad_out: tuple[torch.Tensor, ...],
+            gamma: float,
+        ) -> tuple[torch.Tensor, ...]:
+            grad_record = grad_in[0].data.cpu().numpy()
+            # mask = torch.ones_like(grad_in[0]) * gamma
+            self.norm_list = grad_record
+            return grad_in
+
+        # visformer
+        def pool_record_vis(
+            module: torch.nn.Module,
+            grad_in: tuple[torch.Tensor, ...],
+            grad_out: tuple[torch.Tensor, ...],
+            gamma: float,
+        ) -> tuple[torch.Tensor, ...]:
+            grad_add = grad_in[0].data
+            # B,C,H,W = grad_add.shape
+            # grad_add = grad_add.reshape((B,C,H*W)).transpose(1,2)
+            self.stage.append(grad_add.cpu().numpy())
+            return grad_in
+
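
Across all three architectures the recorders follow the same recurrence: block `i` contributes a decayed copy `grad_in[0] * c * 0.5**i` (with `c` = 0.1 or 0.03), `grad_additions[i]` holds the running sum from blocks `0..i-1`, and `grad_records[i]` extends that sum. A toy numpy version of the bookkeeping (illustrative values):

```python
import numpy as np

c = 0.1
grads = [np.ones(3) for _ in range(4)]  # stand-ins for per-block grad_in[0]

records, additions = [], []
for i, g in enumerate(grads):
    decayed = g * c * 0.5**i
    additions.append(records[-1] if records else np.zeros_like(decayed))
    records.append((records[-1] if records else 0) + decayed)

print(additions[2])  # sum of decayed gradients from blocks 0 and 1: [0.15 0.15 0.15]
```
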
+        def mlp_record_vis_stage(
+            module: torch.nn.Module,
+            grad_in: tuple[torch.Tensor, ...],
+            grad_out: tuple[torch.Tensor, ...],
+            gamma: float,
+        ) -> tuple[torch.Tensor, ...]:
+            mask = torch.ones_like(grad_in[0]) * gamma
+            out_grad = mask * grad_in[0][:]
+            if self.current_mlp_block < 4:
+                grad_record = (
+                    grad_in[0].data.cpu().numpy()
+                    * 0.1
+                    * (0.5 ** (self.current_mlp_block))
+                )
+                if self.current_mlp_block == 0:
+                    grad_add = np.zeros_like(grad_record)
+                    grad_add[:, 0, :] = self.norm_list[:, 0, :] * 0.1 * (0.5)
+                    self.mlp_recorder.grad_additions.append(grad_add)
+                    self.mlp_recorder.grad_records.append(grad_record + grad_add)
+                else:
+                    self.mlp_recorder.grad_additions.append(
+                        self.mlp_recorder.grad_records[-1]
+                    )
+                    total_mlp = self.mlp_recorder.grad_records[-1] + grad_record
+                    self.mlp_recorder.grad_records.append(total_mlp)
+            else:
+                grad_record = (
+                    grad_in[0].data.cpu().numpy()
+                    * 0.1
+                    * (0.5 ** (self.current_mlp_block))
+                )
+                if self.current_mlp_block == 4:
+                    grad_add = np.zeros_like(grad_record)
+                    # grad_add[:,1:,:] = self.stage[0]* 0.1*(0.5)
+                    self.mlp_recorder.grad_additions.append(grad_add)
+                    self.mlp_recorder.grad_records.append(grad_record + grad_add)
+                else:
+                    self.mlp_recorder.grad_additions.append(
+                        self.mlp_recorder.grad_records[-1]
+                    )
+                    total_mlp = self.mlp_recorder.grad_records[-1] + grad_record
+                    self.mlp_recorder.grad_records.append(total_mlp)
+
+            self.current_mlp_block += 1
+
+            return (out_grad, grad_in[1], grad_in[2])
+
+        def mlp_add_vis(
+            module: torch.nn.Module,
+            grad_in: tuple[torch.Tensor, ...],
+            grad_out: tuple[torch.Tensor, ...],
+            gamma: float,
+        ) -> tuple[torch.Tensor, ...]:
+            # grad_record = grad_in[0].data.cpu().numpy()
+            mask = torch.ones_like(grad_in[0]) * gamma
+            out_grad = mask * grad_in[0][:]
+            out_grad += torch.tensor(
+                self.mlp_recorder.grad_additions[self.current_mlp_block],
+                device=grad_in[0].device,
+            )
+            self.current_mlp_block += 1
+            return (out_grad, grad_in[1], grad_in[2])
+
+        def norm_record_vis(
+            module: torch.nn.Module,
+            grad_in: tuple[torch.Tensor, ...],
+            grad_out: tuple[torch.Tensor, ...],
+            gamma: float,
+        ) -> tuple[torch.Tensor, ...]:
+            grad_record = grad_in[0].data.cpu().numpy()
+            # mask = torch.ones_like(grad_in[0]) * gamma
+            self.norm_list = grad_record
+            return grad_in
+
+        def attn_record_vis_stage(
+            module: torch.nn.Module,
+            grad_in: tuple[torch.Tensor, ...],
+            grad_out: tuple[torch.Tensor, ...],
+            gamma: float,
+        ) -> tuple[torch.Tensor, ...]:
+            mask = torch.ones_like(grad_in[0]) * gamma
+            out_grad = mask * grad_in[0][:]
+            if self.current_attn_block < 4:
+                grad_record = (
+                    grad_in[0].data.cpu().numpy()
+                    * 0.1
+                    * (0.5 ** (self.current_attn_block))
+                )
+                if self.current_attn_block == 0:
+                    self.attn_recorder.grad_additions.append(np.zeros_like(grad_record))
+                    self.attn_recorder.grad_records.append(grad_record)
+                else:
+                    self.attn_recorder.grad_additions.append(
+                        self.attn_recorder.grad_records[-1]
+                    )
+                    total_attn = self.attn_recorder.grad_records[-1] + grad_record
+                    self.attn_recorder.grad_records.append(total_attn)
+            else:
+                grad_record = (
+                    grad_in[0].data.cpu().numpy()
+                    * 0.1
+                    * (0.5 ** (self.current_attn_block))
+                )
+                if self.current_attn_block == 4:
+                    self.attn_recorder.grad_additions.append(np.zeros_like(grad_record))
+                    self.attn_recorder.grad_records.append(grad_record)
+                else:
+                    self.attn_recorder.grad_additions.append(
+                        self.attn_recorder.grad_records[-1]
+                    )
+                    total_attn = self.attn_recorder.grad_records[-1] + grad_record
+                    self.attn_recorder.grad_records.append(total_attn)
+
+            self.current_attn_block += 1
+            return (out_grad,)
+
+        def attn_add_vis(
+            module: torch.nn.Module,
+            grad_in: tuple[torch.Tensor, ...],
+            grad_out: tuple[torch.Tensor, ...],
+            gamma: float,
+        ) -> tuple[torch.Tensor, ...]:
+            # grad_record = grad_in[0].data.cpu().numpy()
+            mask = torch.ones_like(grad_in[0]) * gamma
+            out_grad = mask * grad_in[0][:]
+            out_grad += torch.tensor(
+                self.attn_recorder.grad_additions[self.current_attn_block],
+                device=grad_in[0].device,
+            )
+            self.current_attn_block += 1
+            return (out_grad,)
+
+        # vit
+        mlp_record_func_vit = partial(mlp_record_vit_stage, gamma=1.0)
+        norm_record_func_vit = partial(norm_record_vit, gamma=1.0)
+        mlp_add_func_vit = partial(mlp_add_vit, gamma=0.5)
+        attn_record_func_vit = partial(attn_record_vit_stage, gamma=1.0)
+        attn_add_func_vit = partial(attn_add_vit, gamma=0.25)
+
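
`functools.partial` is used here to pre-bind `gamma` so each closure matches the three-argument `(module, grad_in, grad_out)` signature that `register_backward_hook` expects. A standalone illustration:

```python
from functools import partial

import torch
import torch.nn as nn

def scale_hook(module, grad_in, grad_out, gamma: float):
    # Scale every incoming gradient by gamma
    return tuple(g * gamma if g is not None else None for g in grad_in)

act = nn.ReLU()
hook = act.register_backward_hook(partial(scale_hook, gamma=0.5))  # now a 3-arg callable

x = torch.randn(3, requires_grad=True)
act(x).sum().backward()
hook.remove()
```
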
+        # pit
+        attn_record_func_pit = partial(attn_record_pit_stage, gamma=1.0)
+        mlp_record_func_pit = partial(mlp_record_pit_stage, gamma=1.0)
+        norm_record_func_pit = partial(norm_record_pit, gamma=1.0)
+        pool_record_func_pit = partial(pool_record_pit, gamma=1.0)
+        attn_add_func_pit = partial(attn_add_pit, gamma=0.25)
+        mlp_add_func_pit = partial(mlp_add_pit, gamma=0.5)
+        # mlp_add_func_pit = partial(mlp_add_pit, gamma=0.75)
+
+        # visformer
+        attn_record_func_vis = partial(attn_record_vis_stage, gamma=1.0)
+        mlp_record_func_vis = partial(mlp_record_vis_stage, gamma=1.0)
+        norm_record_func_vis = partial(norm_record_vis, gamma=1.0)
+        pool_record_func_vis = partial(pool_record_vis, gamma=1.0)
+        attn_add_func_vis = partial(attn_add_vis, gamma=0.25)
+        mlp_add_func_vis = partial(mlp_add_vis, gamma=0.5)
+
+        # fmt: off
+        # Register hooks for supported models.
+        # * Gradient RECORD mode hooks:
+        record_grad_cfg = {
+            'vit_base_patch16_224': [
+                (norm_record_func_vit, ['norm']),
+                (mlp_record_func_vit, [f'blocks.{i}.norm2' for i in range(12)]),
+                (attn_record_func_vit, [f'blocks.{i}.attn.attn_drop' for i in range(12)]),
+            ],
+            'deit_base_distilled_patch16_224': [
+                (norm_record_func_vit, ['norm']),
+                (mlp_record_func_vit, [f'blocks.{i}.norm2' for i in range(12)]),
+                (attn_record_func_vit, [f'blocks.{i}.attn.attn_drop' for i in range(12)]),
+            ],
+            'pit_b_224': [
+                (norm_record_func_pit, ['norm']),
+                (attn_record_func_pit, [f'transformers.{tid}.blocks.{i}.attn.attn_drop' for tid, bid in enumerate([3, 6, 4]) for i in range(bid)]),
+                (mlp_record_func_pit, [f'transformers.{tid}.blocks.{i}.norm2' for tid, bid in enumerate([3, 6, 4]) for i in range(bid)]),
+                (pool_record_func_pit, ['transformers.1.pool', 'transformers.2.pool']),
+            ],
+            'visformer_small': [
+                (norm_record_func_vis, ['norm']),
+                (attn_record_func_vis, [f'stage2.{i}.attn.attn_drop' for i in range(4)] + [f'stage3.{i}.attn.attn_drop' for i in range(4)]),
+                (mlp_record_func_vis, [f'stage2.{i}.norm2' for i in range(4)] + [f'stage3.{i}.norm2' for i in range(4)]),
+                (pool_record_func_vis, ['patch_embed2', 'patch_embed3']),
+            ],
+        }
+        # * Gradient ADD mode hooks:
+        add_grad_cfg = {
+            'vit_base_patch16_224': [
+                (mlp_add_func_vit, [f'blocks.{i}.norm2' for i in range(12)]),
+                (attn_add_func_vit, [f'blocks.{i}.attn.attn_drop' for i in range(12)]),
+            ],
+            'deit_base_distilled_patch16_224': [
+                (mlp_add_func_vit, [f'blocks.{i}.norm2' for i in range(12)]),
+                (attn_add_func_vit, [f'blocks.{i}.attn.attn_drop' for i in range(12)]),
+            ],
+            'pit_b_224': [
+                (attn_add_func_pit, [f'transformers.{tid}.blocks.{i}.attn.attn_drop' for tid, bid in enumerate([3, 6, 4]) for i in range(bid)]),
+                (mlp_add_func_pit, [f'transformers.{tid}.blocks.{i}.norm2' for tid, bid in enumerate([3, 6, 4]) for i in range(bid)]),
+            ],
+            'visformer_small': [
+                (attn_add_func_vis, [f'stage2.{i}.attn.attn_drop' for i in range(4)] + [f'stage3.{i}.attn.attn_drop' for i in range(4)]),
+                (mlp_add_func_vis, [f'stage2.{i}.norm2' for i in range(4)] + [f'stage3.{i}.norm2' for i in range(4)]),
+            ],
+        }
+        # fmt: on
+
+        activated_vit_cfg = add_grad_cfg if add_grad_mode else record_grad_cfg
+        assert self.model_name in activated_vit_cfg
+
+        for hook_func, layers in activated_vit_cfg[self.model_name]:
+            for layer in layers:
+                module = rgetattr(self.model, layer)
+                hook = module.register_backward_hook(hook_func)
+                self.hooks.append(hook)
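
`rgetattr` resolves the dot-separated layer paths used in both config dicts, e.g. `rgetattr(model, 'blocks.0.norm2')`. The helper lives in `torchattack._rgetattr`; a minimal sketch of such a recursive `getattr` (the actual implementation may differ):

```python
import functools

def rgetattr(obj, attr: str, *args):
    """getattr that follows dot-separated paths, e.g. 'blocks.0.norm2'."""
    def _getattr(obj, name):
        return getattr(obj, name, *args)
    return functools.reduce(_getattr, attr.split('.'), obj)
```
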
+
+
+if __name__ == '__main__':
+    from torchattack.eval import run_attack
+
+    run_attack(
+        VDC,
+        attack_cfg={'model_name': 'pit_b_224'},
+        model_name='pit_b_224',
+        victim_model_names=['cait_s24_224', 'visformer_small'],
+        from_timm=True,
+    )