From 80ca4716a098c7afe6fe4f40bbc84b4c5a61c5cb Mon Sep 17 00:00:00 2001 From: spencerwooo Date: Thu, 14 Nov 2024 13:01:05 +0800 Subject: [PATCH 1/6] VDC with support for vanilla ViT-B/16 --- src/torchattack/eval.py | 5 +- src/torchattack/vdc.py | 301 ++++++++++++++++++++++++++++++++++++++++ 2 files changed, 305 insertions(+), 1 deletion(-) create mode 100644 src/torchattack/vdc.py diff --git a/src/torchattack/eval.py b/src/torchattack/eval.py index 025ba94..06ac657 100644 --- a/src/torchattack/eval.py +++ b/src/torchattack/eval.py @@ -159,7 +159,10 @@ def from_pretrained( return cls(model_name, device, model, transform, normalize) except ValueError: - print('Model not found in torchvision.models, falling back to timm.') + print( + f'Warning: Model `{model_name}` not found in torchvision.models, ' + 'falling back to loading weights from timm.' + ) return cls.from_pretrained(model_name, device, from_timm=True) def forward(self, x): diff --git a/src/torchattack/vdc.py b/src/torchattack/vdc.py new file mode 100644 index 0000000..ae0ce03 --- /dev/null +++ b/src/torchattack/vdc.py @@ -0,0 +1,301 @@ +import importlib.util +from functools import partial +from typing import Callable + +import numpy as np +import torch +import torch.nn as nn + +# from torchattack._rgetattr import rgetattr +from torchattack.base import Attack + + +class VDC(Attack): + """VDC (Virtual Dense Connection) attack for ViTs. + + From the paper: 'Improving the Adversarial Transferability of Vision Transformers + with Virtual Dense Connection' + https://ojs.aaai.org/index.php/AAAI/article/view/28541 + + Args: + model: The model to attack. + model_name: The name of the model. + normalize: A transform to normalize images. + device: Device to use for tensors. Defaults to cuda if available. + eps: The maximum perturbation. Defaults to 8/255. + steps: Number of steps. Defaults to 10. + alpha: Step size, `eps / steps` if None. Defaults to None. + decay: Momentum decay factor. Defaults to 1.0. + clip_min: Minimum value for clipping. Defaults to 0.0. + clip_max: Maximum value for clipping. Defaults to 1.0. + targeted: Targeted attack if True. Defaults to False. 
+ """ + + def __init__( + self, + model: nn.Module, + model_name: str, + normalize: Callable[[torch.Tensor], torch.Tensor] | None, + device: torch.device | None = None, + eps: float = 8 / 255, + steps: int = 10, + alpha: float | None = None, + decay: float = 1.0, + sample_num_batches: int = 130, + lambd: float = 0.1, + clip_min: float = 0.0, + clip_max: float = 1.0, + targeted: bool = False, + ): + # Check if timm is installed + importlib.util.find_spec('timm') + + super().__init__(normalize, device) + + self.model = model + self.model_name = model_name + self.eps = eps + self.steps = steps + self.alpha = alpha + self.decay = decay + + self.sample_num_batches = sample_num_batches + self.lambd = lambd + + # Default (3, 224, 224) image with ViT-B/16 16x16 patches + self.max_num_batches = int((224 / 16) ** 2) + self.crop_length = 16 + + self.clip_min = clip_min + self.clip_max = clip_max + self.targeted = targeted + self.lossfn = nn.CrossEntropyLoss() + + self.record_grad = [] + self.record_grad_mlp = [] + ############### + self.attn_record = [] + self.mlp_record = [] + self.attn_add = [] + self.mlp_add = [] + self.norm_list = [] + self.stage = [] + self.attn_block = 0 + self.mlp_block = 0 + self.hooks = [] + self.skip_record = [] + self.skip_add = [] + self.skip_block = 0 + + assert self.sample_num_batches <= self.max_num_batches + + # self._register_model_hooks() + + def forward(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor: + """Perform VDC on a batch of images. + + Args: + x: A batch of images. Shape: (N, C, H, W). + y: A batch of labels. Shape: (N). + + Returns: + The perturbed images if successful. Shape: (N, C, H, W). + """ + + g = torch.zeros_like(x) + delta = torch.zeros_like(x, requires_grad=True) + + # If alpha is not given, set to eps / steps + if self.alpha is None: + self.alpha = self.eps / self.steps + + # Perform VDC + for _ in range(self.steps): + self.attn_record = [] + self.attn_add = [] + + self.mlp_record = [] + self.mlp_add = [] + + self.skip_record = [] + self.skip_add = [] + + self.mlp_block = 0 + self.attn_block = 0 + self.skip_block = 0 + self._register_model_hooks(add=False) + + # Compute loss + outs = self.model(self.normalize(x + delta)) + loss = self.lossfn(outs, y) + + if self.targeted: + loss = -loss + + # Compute gradient + loss.backward() + + if delta.grad is None: + continue + + # Zero out gradient + delta.grad.detach_() + delta.grad.zero_() + + for hook in self.hooks: + hook.remove() + + self.mlp_block = 0 + self.attn_block = 0 + self.skip_block = 0 + self._register_model_hooks(add=True) + + # Compute loss 2nd time + outs = self.model(self.normalize(x + delta)) + loss = self.lossfn(outs, y) + + if self.targeted: + loss = -loss + + # Compute gradient 2nd time + loss.backward() + + if delta.grad is None: + continue + + # Apply momentum term + g = self.decay * g + delta.grad / torch.mean( + torch.abs(delta.grad), dim=(1, 2, 3), keepdim=True + ) + + # Update delta + delta.data = delta.data + self.alpha * g.sign() + delta.data = torch.clamp(delta.data, -self.eps, self.eps) + delta.data = torch.clamp(x + delta.data, self.clip_min, self.clip_max) - x + + # Zero out gradient + delta.grad.detach_() + delta.grad.zero_() + + for hook in self.hooks: + hook.remove() + + return x + delta + + def _register_model_hooks(self, add: bool = False): + def mlp_record_vit_stage(module, grad_in, grad_out, gamma): + mask = torch.ones_like(grad_in[0]) * gamma + out_grad = mask * grad_in[0][:] + # ablation + grad_record = ( + grad_in[0].data.cpu().numpy() * 0.1 * (0.5 ** 
(self.mlp_block)) + ) + # grad_record = grad_in[0].data.cpu().numpy() + if self.mlp_block == 0: + grad_add = np.zeros_like(grad_record) + # ablation + grad_add[:, 0, :] = self.norm_list[:, 0, :] * 0.1 * (0.5) + # grad_add[:,0,:] = self.norm[:,0,:] + self.mlp_add.append(grad_add) + self.mlp_record.append(grad_record + grad_add) + else: + self.mlp_add.append(self.mlp_record[-1]) + total_mlp = self.mlp_record[-1] + grad_record + self.mlp_record.append(total_mlp) + self.mlp_block += 1 + return (out_grad, grad_in[1], grad_in[2]) + + def mlp_add_vit(module, grad_in, grad_out, gamma): + grad_record = grad_in[0].data.cpu().numpy() + mask = torch.ones_like(grad_in[0]) * gamma + # mask_0 = torch.zeros_like(grad_in[0]) + out_grad = mask * grad_in[0][:] + # out_grad = torch.where(grad_in[0][:] > 0, mask * grad_in[0][:], mask_0 * grad_in[0][:]) + out_grad += torch.tensor(self.mlp_add[self.mlp_block]).cuda() + self.mlp_block += 1 + return (out_grad, grad_in[1], grad_in[2]) + + def attn_record_vit_stage(module, grad_in, grad_out, gamma): + mask = torch.ones_like(grad_in[0]) * gamma + out_grad = mask * grad_in[0][:] + grad_record = ( + grad_in[0].data.cpu().numpy() * 0.1 * (0.5 ** (self.attn_block)) + ) + # grad_record = grad_in[0].data.cpu().numpy() + if self.attn_block == 0: + self.attn_add.append(np.zeros_like(grad_record)) + self.attn_record.append(grad_record) + else: + self.attn_add.append(self.attn_record[-1]) + total_attn = self.attn_record[-1] + grad_record + self.attn_record.append(total_attn) + + self.attn_block += 1 + return (out_grad,) + + def attn_add_vit(module, grad_in, grad_out, gamma): + grad_record = grad_in[0].data.cpu().numpy() + # mask_0 = torch.zeros_like(grad_in[0]) + mask = torch.ones_like(grad_in[0]) * gamma + out_grad = mask * grad_in[0][:] + # out_grad = torch.where(grad_in[0][:] > 0, mask * grad_in[0][:], mask_0 * grad_in[0][:]) + out_grad += torch.tensor(self.attn_add[self.attn_block]).cuda() + self.attn_block += 1 + return (out_grad,) + + def norm_record_vit(module, grad_in, grad_out, gamma): + grad_record = grad_in[0].data.cpu().numpy() + mask = torch.ones_like(grad_in[0]) * gamma + self.norm_list = grad_record + return grad_in + + # vit + mlp_record_func_vit = partial(mlp_record_vit_stage, gamma=1.0) + norm_record_func_vit = partial(norm_record_vit, gamma=1.0) + mlp_add_func_vit = partial(mlp_add_vit, gamma=0.5) + attn_record_func_vit = partial(attn_record_vit_stage, gamma=1.0) + attn_add_func_vit = partial(attn_add_vit, gamma=0.25) + + if not add: + if self.model_name in [ + 'vit_base_patch16_224', + 'deit_base_distilled_patch16_224', + ]: + hook = self.model.norm.register_backward_hook(norm_record_func_vit) + self.hooks.append(hook) + for i in range(12): + hook = self.model.blocks[i].norm2.register_backward_hook( + mlp_record_func_vit + ) + self.hooks.append(hook) + hook = self.model.blocks[i].attn.attn_drop.register_backward_hook( + attn_record_func_vit + ) + self.hooks.append(hook) + else: + if self.model_name in [ + 'vit_base_patch16_224', + 'deit_base_distilled_patch16_224', + ]: + for i in range(12): + hook = self.model.blocks[i].norm2.register_backward_hook( + mlp_add_func_vit + ) + self.hooks.append(hook) + hook = self.model.blocks[i].attn.attn_drop.register_backward_hook( + attn_add_func_vit + ) + self.hooks.append(hook) + + +if __name__ == '__main__': + from torchattack.eval import run_attack + + run_attack( + VDC, + attack_cfg={'model_name': 'vit_base_patch16_224'}, + model_name='vit_base_patch16_224', + victim_model_names=['cait_s24_224', 'visformer_small'], + 
batch_size=4, + from_timm=False, + ) From 2e8375b553221ecd6a342fa617c8c86d0853c8e9 Mon Sep 17 00:00:00 2001 From: spencerwooo Date: Tue, 19 Nov 2024 12:49:11 +0800 Subject: [PATCH 2/6] Remove unused global variables --- src/torchattack/vdc.py | 20 +++----------------- 1 file changed, 3 insertions(+), 17 deletions(-) diff --git a/src/torchattack/vdc.py b/src/torchattack/vdc.py index ae0ce03..3f9da62 100644 --- a/src/torchattack/vdc.py +++ b/src/torchattack/vdc.py @@ -71,26 +71,11 @@ def __init__( self.targeted = targeted self.lossfn = nn.CrossEntropyLoss() - self.record_grad = [] - self.record_grad_mlp = [] - ############### - self.attn_record = [] - self.mlp_record = [] - self.attn_add = [] - self.mlp_add = [] - self.norm_list = [] - self.stage = [] - self.attn_block = 0 - self.mlp_block = 0 + # Global hooks for VDC self.hooks = [] - self.skip_record = [] - self.skip_add = [] - self.skip_block = 0 assert self.sample_num_batches <= self.max_num_batches - # self._register_model_hooks() - def forward(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor: """Perform VDC on a batch of images. @@ -296,6 +281,7 @@ def norm_record_vit(module, grad_in, grad_out, gamma): attack_cfg={'model_name': 'vit_base_patch16_224'}, model_name='vit_base_patch16_224', victim_model_names=['cait_s24_224', 'visformer_small'], + max_samples=12, batch_size=4, - from_timm=False, + from_timm=True, ) From ed6b10eb1742601c6e2fec50c98ed779aff4c77b Mon Sep 17 00:00:00 2001 From: spencerwooo Date: Tue, 19 Nov 2024 13:22:27 +0800 Subject: [PATCH 3/6] Add `timm` to dev deps --- .gitignore | 5 ++--- pyproject.toml | 2 +- 2 files changed, 3 insertions(+), 4 deletions(-) diff --git a/.gitignore b/.gitignore index 3c2158a..685f77c 100644 --- a/.gitignore +++ b/.gitignore @@ -1,9 +1,8 @@ # Dataset files datasets/ -# PDM files -.pdm-python -pdm.lock +# Lockfiles +uv.lock # Byte-compiled / optimized / DLL files __pycache__/ diff --git a/pyproject.toml b/pyproject.toml index 674ac41..72a8b90 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -12,7 +12,7 @@ dependencies = [ "scipy>=1.10.1", ] dynamic = ["version"] -optional-dependencies = { dev = ["mypy", "rich"] } +optional-dependencies = { dev = ["mypy", "rich", "timm"] } [build-system] requires = ["setuptools"] From 0cd927dae19a341cda7c85516623f1f1b3d94e0f Mon Sep 17 00:00:00 2001 From: spencerwooo Date: Tue, 19 Nov 2024 17:07:18 +0800 Subject: [PATCH 4/6] Add partial support for PiT and Visformer --- src/torchattack/vdc.py | 367 ++++++++++++++++++++++++++++++++++++++++- 1 file changed, 361 insertions(+), 6 deletions(-) diff --git a/src/torchattack/vdc.py b/src/torchattack/vdc.py index 3f9da62..bdfef70 100644 --- a/src/torchattack/vdc.py +++ b/src/torchattack/vdc.py @@ -71,7 +71,8 @@ def __init__( self.targeted = targeted self.lossfn = nn.CrossEntropyLoss() - # Global hooks for VDC + # Global hooks and attack stage state for VDC + self.stage = [] self.hooks = [] assert self.sample_num_batches <= self.max_num_batches @@ -108,7 +109,7 @@ def forward(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor: self.mlp_block = 0 self.attn_block = 0 self.skip_block = 0 - self._register_model_hooks(add=False) + self._register_model_hooks(grad_add_hook=False) # Compute loss outs = self.model(self.normalize(x + delta)) @@ -133,7 +134,7 @@ def forward(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor: self.mlp_block = 0 self.attn_block = 0 self.skip_block = 0 - self._register_model_hooks(add=True) + self._register_model_hooks(grad_add_hook=True) # Compute loss 2nd time outs = 
self.model(self.normalize(x + delta)) @@ -167,7 +168,15 @@ def forward(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor: return x + delta - def _register_model_hooks(self, add: bool = False): + def _register_model_hooks(self, grad_add_hook: bool = False): + """Register hooks to either record or add gradients during the backward pass. + + Args: + grad_add_hook: If False, register hooks to record gradients. If True, + register hooks to modify the gradients by adding pre-recorded gradients + during the backward pass. + """ + def mlp_record_vit_stage(module, grad_in, grad_out, gamma): mask = torch.ones_like(grad_in[0]) * gamma out_grad = mask * grad_in[0][:] @@ -234,14 +243,246 @@ def norm_record_vit(module, grad_in, grad_out, gamma): self.norm_list = grad_record return grad_in + # pit + def pool_record_pit(module, grad_in, grad_out, gamma): + grad_add = grad_in[0].data + B, C, H, W = grad_add.shape + grad_add = grad_add.reshape((B, C, H * W)).transpose(1, 2) + self.stage.append(grad_add.cpu().numpy()) + return grad_in + + def mlp_record_pit_stage(module, grad_in, grad_out, gamma): + mask = torch.ones_like(grad_in[0]) * gamma + out_grad = mask * grad_in[0][:] + if self.mlp_block < 4: + grad_record = ( + grad_in[0].data.cpu().numpy() * 0.03 * (0.5 ** (self.mlp_block)) + ) + if self.mlp_block == 0: + grad_add = np.zeros_like(grad_record) + grad_add[:, 0, :] = self.norm_list[:, 0, :] * 0.03 * (0.5) + self.mlp_add.append(grad_add) + self.mlp_record.append(grad_record + grad_add) + else: + self.mlp_add.append(self.mlp_record[-1]) + total_mlp = self.mlp_record[-1] + grad_record + self.mlp_record.append(total_mlp) + elif self.mlp_block < 10: + grad_record = ( + grad_in[0].data.cpu().numpy() * 0.03 * (0.5 ** (self.mlp_block)) + ) + if self.mlp_block == 4: + grad_add = np.zeros_like(grad_record) + grad_add[:, 1:, :] = self.stage[0] * 0.03 * (0.5) + self.mlp_add.append(grad_add) + self.mlp_record.append(grad_record + grad_add) + else: + self.mlp_add.append(self.mlp_record[-1]) + # total_mlp = self.mlp_record[-1] + grad_record + total_mlp = self.mlp_record[-1] + self.mlp_record.append(total_mlp) + else: + grad_record = ( + grad_in[0].data.cpu().numpy() * 0.03 * (0.5 ** (self.mlp_block)) + ) + if self.mlp_block == 10: + grad_add = np.zeros_like(grad_record) + grad_add[:, 1:, :] = self.stage[1] * 0.03 * (0.5) + self.mlp_add.append(grad_add) + self.mlp_record.append(grad_record + grad_add) + else: + self.mlp_add.append(self.mlp_record[-1]) + # total_mlp = self.mlp_record[-1] + grad_record + total_mlp = self.mlp_record[-1] + self.mlp_record.append(total_mlp) + self.mlp_block += 1 + + return (out_grad, grad_in[1], grad_in[2]) + + def mlp_add_pit(module, grad_in, grad_out, gamma): + grad_record = grad_in[0].data.cpu().numpy() + mask = torch.ones_like(grad_in[0]) * gamma + out_grad = mask * grad_in[0][:] + out_grad += torch.tensor(self.mlp_add[self.mlp_block]).cuda() + self.mlp_block += 1 + return (out_grad, grad_in[1], grad_in[2]) + + def attn_record_pit_stage(module, grad_in, grad_out, gamma): + mask = torch.ones_like(grad_in[0]) * gamma + out_grad = mask * grad_in[0][:] + if self.attn_block < 4: + grad_record = ( + grad_in[0].data.cpu().numpy() * 0.03 * (0.5 ** (self.attn_block)) + ) + if self.attn_block == 0: + self.attn_add.append(np.zeros_like(grad_record)) + self.attn_record.append(grad_record) + else: + self.attn_add.append(self.attn_record[-1]) + total_attn = self.attn_record[-1] + grad_record + self.attn_record.append(total_attn) + elif self.attn_block < 10: + grad_record = ( + 
grad_in[0].data.cpu().numpy() * 0.03 * (0.5 ** (self.attn_block)) + ) + if self.attn_block == 4: + self.attn_add.append(np.zeros_like(grad_record)) + self.attn_record.append(grad_record) + else: + self.attn_add.append(self.attn_record[-1]) + # total_attn = self.attn_record[-1] + grad_record + total_attn = self.attn_record[-1] + self.attn_record.append(total_attn) + else: + grad_record = ( + grad_in[0].data.cpu().numpy() * 0.03 * (0.5 ** (self.attn_block)) + ) + if self.attn_block == 10: + self.attn_add.append(np.zeros_like(grad_record)) + self.attn_record.append(grad_record) + else: + self.attn_add.append(self.attn_record[-1]) + # total_attn = self.attn_record[-1] + grad_record + total_attn = self.attn_record[-1] + self.attn_record.append(total_attn) + self.attn_block += 1 + return (out_grad,) + + def attn_add_pit(module, grad_in, grad_out, gamma): + grad_record = grad_in[0].data.cpu().numpy() + mask = torch.ones_like(grad_in[0]) * gamma + out_grad = mask * grad_in[0][:] + out_grad += torch.tensor(self.attn_add[self.attn_block]).cuda() + self.attn_block += 1 + return (out_grad,) + + def norm_record_pit(module, grad_in, grad_out, gamma): + grad_record = grad_in[0].data.cpu().numpy() + mask = torch.ones_like(grad_in[0]) * gamma + self.norm_list = grad_record + return grad_in + + #################################################### + # visformer + def pool_record_vis(module, grad_in, grad_out, gamma): + grad_add = grad_in[0].data + # B,C,H,W = grad_add.shape + # grad_add = grad_add.reshape((B,C,H*W)).transpose(1,2) + self.stage.append(grad_add.cpu().numpy()) + return grad_in + + def mlp_record_vis_stage(module, grad_in, grad_out, gamma): + mask = torch.ones_like(grad_in[0]) * gamma + out_grad = mask * grad_in[0][:] + if self.mlp_block < 4: + grad_record = ( + grad_in[0].data.cpu().numpy() * 0.1 * (0.5 ** (self.mlp_block)) + ) + if self.mlp_block == 0: + grad_add = np.zeros_like(grad_record) + grad_add[:, 0, :] = self.norm_list[:, 0, :] * 0.1 * (0.5) + self.mlp_add.append(grad_add) + self.mlp_record.append(grad_record + grad_add) + else: + self.mlp_add.append(self.mlp_record[-1]) + total_mlp = self.mlp_record[-1] + grad_record + self.mlp_record.append(total_mlp) + else: + grad_record = ( + grad_in[0].data.cpu().numpy() * 0.1 * (0.5 ** (self.mlp_block)) + ) + if self.mlp_block == 4: + grad_add = np.zeros_like(grad_record) + # grad_add[:,1:,:] = self.stage[0]* 0.1*(0.5) + self.mlp_add.append(grad_add) + self.mlp_record.append(grad_record + grad_add) + else: + self.mlp_add.append(self.mlp_record[-1]) + total_mlp = self.mlp_record[-1] + grad_record + self.mlp_record.append(total_mlp) + + self.mlp_block += 1 + + return (out_grad, grad_in[1], grad_in[2]) + + def mlp_add_vis(module, grad_in, grad_out, gamma): + grad_record = grad_in[0].data.cpu().numpy() + mask = torch.ones_like(grad_in[0]) * gamma + out_grad = mask * grad_in[0][:] + out_grad += torch.tensor(self.mlp_add[self.mlp_block]).cuda() + self.mlp_block += 1 + return (out_grad, grad_in[1], grad_in[2]) + + def norm_record_vis(module, grad_in, grad_out, gamma): + grad_record = grad_in[0].data.cpu().numpy() + mask = torch.ones_like(grad_in[0]) * gamma + self.norm_list = grad_record + return grad_in + + def attn_record_vis_stage(module, grad_in, grad_out, gamma): + mask = torch.ones_like(grad_in[0]) * gamma + out_grad = mask * grad_in[0][:] + if self.attn_block < 4: + grad_record = ( + grad_in[0].data.cpu().numpy() * 0.1 * (0.5 ** (self.attn_block)) + ) + if self.attn_block == 0: + self.attn_add.append(np.zeros_like(grad_record)) + 
self.attn_record.append(grad_record) + else: + self.attn_add.append(self.attn_record[-1]) + total_attn = self.attn_record[-1] + grad_record + self.attn_record.append(total_attn) + else: + grad_record = ( + grad_in[0].data.cpu().numpy() * 0.1 * (0.5 ** (self.attn_block)) + ) + if self.attn_block == 4: + self.attn_add.append(np.zeros_like(grad_record)) + self.attn_record.append(grad_record) + else: + self.attn_add.append(self.attn_record[-1]) + total_attn = self.attn_record[-1] + grad_record + self.attn_record.append(total_attn) + + self.attn_block += 1 + return (out_grad,) + + def attn_add_vis(module, grad_in, grad_out, gamma): + grad_record = grad_in[0].data.cpu().numpy() + mask = torch.ones_like(grad_in[0]) * gamma + out_grad = mask * grad_in[0][:] + out_grad += torch.tensor(self.attn_add[self.attn_block]).cuda() + self.attn_block += 1 + return (out_grad,) + + ########### # vit mlp_record_func_vit = partial(mlp_record_vit_stage, gamma=1.0) norm_record_func_vit = partial(norm_record_vit, gamma=1.0) mlp_add_func_vit = partial(mlp_add_vit, gamma=0.5) attn_record_func_vit = partial(attn_record_vit_stage, gamma=1.0) attn_add_func_vit = partial(attn_add_vit, gamma=0.25) - - if not add: + ########### + # pit + attn_record_func_pit = partial(attn_record_pit_stage, gamma=1.0) + mlp_record_func_pit = partial(mlp_record_pit_stage, gamma=1.0) + norm_record_func_pit = partial(norm_record_pit, gamma=1.0) + pool_record_func_pit = partial(pool_record_pit, gamma=1.0) + attn_add_func_pit = partial(attn_add_pit, gamma=0.25) + mlp_add_func_pit = partial(mlp_add_pit, gamma=0.5) + # mlp_add_func_pit = partial(mlp_add_pit, gamma=0.75) + + ########### + # visformer + attn_record_func_vis = partial(attn_record_vis_stage, gamma=1.0) + mlp_record_func_vis = partial(mlp_record_vis_stage, gamma=1.0) + norm_record_func_vis = partial(norm_record_vis, gamma=1.0) + pool_record_func_vis = partial(pool_record_vis, gamma=1.0) + attn_add_func_vis = partial(attn_add_vis, gamma=0.25) + mlp_add_func_vis = partial(mlp_add_vis, gamma=0.5) + + if not grad_add_hook: if self.model_name in [ 'vit_base_patch16_224', 'deit_base_distilled_patch16_224', @@ -257,6 +498,74 @@ def norm_record_vit(module, grad_in, grad_out, gamma): attn_record_func_vit ) self.hooks.append(hook) + elif self.model_name == 'pit_b_224': + hook = self.model.norm.register_backward_hook(norm_record_func_pit) + self.hooks.append(hook) + for block_ind in range(13): + if block_ind < 3: + transformer_ind = 0 + used_block_ind = block_ind + elif block_ind < 9 and block_ind >= 3: + transformer_ind = 1 + used_block_ind = block_ind - 3 + elif block_ind < 13 and block_ind >= 9: + transformer_ind = 2 + used_block_ind = block_ind - 9 + hook = ( + self.model.transformers[transformer_ind] + .blocks[used_block_ind] + .attn.attn_drop.register_backward_hook(attn_record_func_pit) + ) + self.hooks.append(hook) + # hook = self.model.transformers[transformer_ind].blocks[used_block_ind].mlp.register_backward_hook(mlp_record_func_pit) + hook = ( + self.model.transformers[transformer_ind] + .blocks[used_block_ind] + .norm2.register_backward_hook(mlp_record_func_pit) + ) + self.hooks.append(hook) + # TODO: module `pool` is non-existent in pit_b_224, causing the + # following hook register to fail. 
+ hook = self.model.transformers[0].pool.register_backward_hook( + pool_record_func_pit + ) + self.hooks.append(hook) + hook = self.model.transformers[1].pool.register_backward_hook( + pool_record_func_pit + ) + self.hooks.append(hook) + elif self.model_name == 'visformer_small': + hook = self.model.norm.register_backward_hook(norm_record_func_vis) + self.hooks.append(hook) + for block_ind in range(8): + if block_ind < 4: + hook = self.model.stage2[ + block_ind + ].attn.attn_drop.register_backward_hook(attn_record_func_vis) + self.hooks.append(hook) + # hook = self.model.stage2[block_ind].mlp.register_backward_hook(mlp_record_func_vis) + hook = self.model.stage2[ + block_ind + ].norm2.register_backward_hook(mlp_record_func_vis) + self.hooks.append(hook) + elif block_ind >= 4: + hook = self.model.stage3[ + block_ind - 4 + ].attn.attn_drop.register_backward_hook(attn_record_func_vis) + self.hooks.append(hook) + # hook = self.model.stage3[block_ind-4].mlp.register_backward_hook(mlp_record_func_vis) + hook = self.model.stage3[ + block_ind - 4 + ].norm2.register_backward_hook(mlp_record_func_vis) + self.hooks.append(hook) + hook = self.model.patch_embed3.register_backward_hook( + pool_record_func_vis + ) + self.hooks.append(hook) + hook = self.model.patch_embed2.register_backward_hook( + pool_record_func_vis + ) + self.hooks.append(hook) else: if self.model_name in [ 'vit_base_patch16_224', @@ -271,6 +580,52 @@ def norm_record_vit(module, grad_in, grad_out, gamma): attn_add_func_vit ) self.hooks.append(hook) + elif self.model_name == 'pit_b_224': + for block_ind in range(13): + if block_ind < 3: + transformer_ind = 0 + used_block_ind = block_ind + elif block_ind < 9 and block_ind >= 3: + transformer_ind = 1 + used_block_ind = block_ind - 3 + elif block_ind < 13 and block_ind >= 9: + transformer_ind = 2 + used_block_ind = block_ind - 9 + hook = ( + self.model.transformers[transformer_ind] + .blocks[used_block_ind] + .attn.attn_drop.register_backward_hook(attn_add_func_pit) + ) + self.hooks.append(hook) + # hook = self.model.transformers[transformer_ind].blocks[used_block_ind].mlp.register_backward_hook(mlp_add_func_pit) + hook = ( + self.model.transformers[transformer_ind] + .blocks[used_block_ind] + .norm2.register_backward_hook(mlp_add_func_pit) + ) + self.hooks.append(hook) + elif self.model_name == 'visformer_small': + for block_ind in range(8): + if block_ind < 4: + hook = self.model.stage2[ + block_ind + ].attn.attn_drop.register_backward_hook(attn_add_func_vis) + self.hooks.append(hook) + # hook = self.model.stage2[block_ind].mlp.register_backward_hook(mlp_add_func_vis) + hook = self.model.stage2[ + block_ind + ].norm2.register_backward_hook(mlp_add_func_vis) + self.hooks.append(hook) + elif block_ind >= 4: + hook = self.model.stage3[ + block_ind - 4 + ].attn.attn_drop.register_backward_hook(attn_add_func_vis) + self.hooks.append(hook) + # hook = self.model.stage3[block_ind-4].mlp.register_backward_hook(mlp_add_func_vis) + hook = self.model.stage3[ + block_ind - 4 + ].norm2.register_backward_hook(mlp_add_func_vis) + self.hooks.append(hook) if __name__ == '__main__': From a22e7334b021881a3b182489018ab0162960ed72 Mon Sep 17 00:00:00 2001 From: spencerwooo Date: Wed, 20 Nov 2024 10:36:25 +0800 Subject: [PATCH 5/6] Fix PiT pooling layer index --- src/torchattack/vdc.py | 38 ++++++++++++++++---------------------- 1 file changed, 16 insertions(+), 22 deletions(-) diff --git a/src/torchattack/vdc.py b/src/torchattack/vdc.py index bdfef70..cbc3af6 100644 --- a/src/torchattack/vdc.py +++ 
b/src/torchattack/vdc.py @@ -200,7 +200,7 @@ def mlp_record_vit_stage(module, grad_in, grad_out, gamma): return (out_grad, grad_in[1], grad_in[2]) def mlp_add_vit(module, grad_in, grad_out, gamma): - grad_record = grad_in[0].data.cpu().numpy() + # grad_record = grad_in[0].data.cpu().numpy() mask = torch.ones_like(grad_in[0]) * gamma # mask_0 = torch.zeros_like(grad_in[0]) out_grad = mask * grad_in[0][:] @@ -228,7 +228,7 @@ def attn_record_vit_stage(module, grad_in, grad_out, gamma): return (out_grad,) def attn_add_vit(module, grad_in, grad_out, gamma): - grad_record = grad_in[0].data.cpu().numpy() + # grad_record = grad_in[0].data.cpu().numpy() # mask_0 = torch.zeros_like(grad_in[0]) mask = torch.ones_like(grad_in[0]) * gamma out_grad = mask * grad_in[0][:] @@ -239,15 +239,15 @@ def attn_add_vit(module, grad_in, grad_out, gamma): def norm_record_vit(module, grad_in, grad_out, gamma): grad_record = grad_in[0].data.cpu().numpy() - mask = torch.ones_like(grad_in[0]) * gamma + # mask = torch.ones_like(grad_in[0]) * gamma self.norm_list = grad_record return grad_in # pit def pool_record_pit(module, grad_in, grad_out, gamma): grad_add = grad_in[0].data - B, C, H, W = grad_add.shape - grad_add = grad_add.reshape((B, C, H * W)).transpose(1, 2) + b, c, h, w = grad_add.shape + grad_add = grad_add.reshape((b, c, h * w)).transpose(1, 2) self.stage.append(grad_add.cpu().numpy()) return grad_in @@ -300,7 +300,7 @@ def mlp_record_pit_stage(module, grad_in, grad_out, gamma): return (out_grad, grad_in[1], grad_in[2]) def mlp_add_pit(module, grad_in, grad_out, gamma): - grad_record = grad_in[0].data.cpu().numpy() + # grad_record = grad_in[0].data.cpu().numpy() mask = torch.ones_like(grad_in[0]) * gamma out_grad = mask * grad_in[0][:] out_grad += torch.tensor(self.mlp_add[self.mlp_block]).cuda() @@ -349,7 +349,7 @@ def attn_record_pit_stage(module, grad_in, grad_out, gamma): return (out_grad,) def attn_add_pit(module, grad_in, grad_out, gamma): - grad_record = grad_in[0].data.cpu().numpy() + # grad_record = grad_in[0].data.cpu().numpy() mask = torch.ones_like(grad_in[0]) * gamma out_grad = mask * grad_in[0][:] out_grad += torch.tensor(self.attn_add[self.attn_block]).cuda() @@ -358,7 +358,7 @@ def attn_add_pit(module, grad_in, grad_out, gamma): def norm_record_pit(module, grad_in, grad_out, gamma): grad_record = grad_in[0].data.cpu().numpy() - mask = torch.ones_like(grad_in[0]) * gamma + # mask = torch.ones_like(grad_in[0]) * gamma self.norm_list = grad_record return grad_in @@ -406,7 +406,7 @@ def mlp_record_vis_stage(module, grad_in, grad_out, gamma): return (out_grad, grad_in[1], grad_in[2]) def mlp_add_vis(module, grad_in, grad_out, gamma): - grad_record = grad_in[0].data.cpu().numpy() + # grad_record = grad_in[0].data.cpu().numpy() mask = torch.ones_like(grad_in[0]) * gamma out_grad = mask * grad_in[0][:] out_grad += torch.tensor(self.mlp_add[self.mlp_block]).cuda() @@ -415,7 +415,7 @@ def mlp_add_vis(module, grad_in, grad_out, gamma): def norm_record_vis(module, grad_in, grad_out, gamma): grad_record = grad_in[0].data.cpu().numpy() - mask = torch.ones_like(grad_in[0]) * gamma + # mask = torch.ones_like(grad_in[0]) * gamma self.norm_list = grad_record return grad_in @@ -449,21 +449,20 @@ def attn_record_vis_stage(module, grad_in, grad_out, gamma): return (out_grad,) def attn_add_vis(module, grad_in, grad_out, gamma): - grad_record = grad_in[0].data.cpu().numpy() + # grad_record = grad_in[0].data.cpu().numpy() mask = torch.ones_like(grad_in[0]) * gamma out_grad = mask * grad_in[0][:] out_grad += 
torch.tensor(self.attn_add[self.attn_block]).cuda() self.attn_block += 1 return (out_grad,) - ########### # vit mlp_record_func_vit = partial(mlp_record_vit_stage, gamma=1.0) norm_record_func_vit = partial(norm_record_vit, gamma=1.0) mlp_add_func_vit = partial(mlp_add_vit, gamma=0.5) attn_record_func_vit = partial(attn_record_vit_stage, gamma=1.0) attn_add_func_vit = partial(attn_add_vit, gamma=0.25) - ########### + # pit attn_record_func_pit = partial(attn_record_pit_stage, gamma=1.0) mlp_record_func_pit = partial(mlp_record_pit_stage, gamma=1.0) @@ -473,7 +472,6 @@ def attn_add_vis(module, grad_in, grad_out, gamma): mlp_add_func_pit = partial(mlp_add_pit, gamma=0.5) # mlp_add_func_pit = partial(mlp_add_pit, gamma=0.75) - ########### # visformer attn_record_func_vis = partial(attn_record_vis_stage, gamma=1.0) mlp_record_func_vis = partial(mlp_record_vis_stage, gamma=1.0) @@ -524,13 +522,11 @@ def attn_add_vis(module, grad_in, grad_out, gamma): .norm2.register_backward_hook(mlp_record_func_pit) ) self.hooks.append(hook) - # TODO: module `pool` is non-existent in pit_b_224, causing the - # following hook register to fail. - hook = self.model.transformers[0].pool.register_backward_hook( + hook = self.model.transformers[1].pool.register_backward_hook( pool_record_func_pit ) self.hooks.append(hook) - hook = self.model.transformers[1].pool.register_backward_hook( + hook = self.model.transformers[2].pool.register_backward_hook( pool_record_func_pit ) self.hooks.append(hook) @@ -633,10 +629,8 @@ def attn_add_vis(module, grad_in, grad_out, gamma): run_attack( VDC, - attack_cfg={'model_name': 'vit_base_patch16_224'}, - model_name='vit_base_patch16_224', + attack_cfg={'model_name': 'pit_b_224'}, + model_name='pit_b_224', victim_model_names=['cait_s24_224', 'visformer_small'], - max_samples=12, - batch_size=4, from_timm=True, ) From 0ddf5e8cc476b54bfafe9e384be4d9ee60dff5bf Mon Sep 17 00:00:00 2001 From: spencerwooo Date: Wed, 20 Nov 2024 13:19:54 +0800 Subject: [PATCH 6/6] Refactor VDC attack with rgetattr --- src/torchattack/__init__.py | 2 + src/torchattack/base.py | 2 +- src/torchattack/pna_patchout.py | 8 +- src/torchattack/tgr.py | 10 +- src/torchattack/vdc.py | 661 ++++++++++++++++++-------------- 5 files changed, 385 insertions(+), 298 deletions(-) diff --git a/src/torchattack/__init__.py b/src/torchattack/__init__.py index 922e5ad..140c7eb 100644 --- a/src/torchattack/__init__.py +++ b/src/torchattack/__init__.py @@ -15,6 +15,7 @@ from torchattack.ssp import SSP from torchattack.tgr import TGR from torchattack.tifgsm import TIFGSM +from torchattack.vdc import VDC from torchattack.vmifgsm import VMIFGSM from torchattack.vnifgsm import VNIFGSM @@ -38,6 +39,7 @@ 'SSP', 'TGR', 'TIFGSM', + 'VDC', 'VMIFGSM', 'VNIFGSM', ] diff --git a/src/torchattack/base.py b/src/torchattack/base.py index 56fc562..7fdbebe 100644 --- a/src/torchattack/base.py +++ b/src/torchattack/base.py @@ -30,7 +30,7 @@ def __repr__(self) -> str: def repr_map(k, v): if isinstance(v, float): return f'{k}={v:.3f}' - if k in ['model', 'normalize', 'feature_layer']: + if k in ['model', 'normalize', 'feature_layer', 'hooks']: return f'{k}={v.__class__.__name__}' if isinstance(v, torch.Tensor): return f'{k}={v.shape}' diff --git a/src/torchattack/pna_patchout.py b/src/torchattack/pna_patchout.py index 443ad6f..f592939 100644 --- a/src/torchattack/pna_patchout.py +++ b/src/torchattack/pna_patchout.py @@ -95,7 +95,9 @@ def __init__( self.targeted = targeted self.lossfn = nn.CrossEntropyLoss() + # Register hooks if self.pna_skip: + 
self.hooks: list[torch.utils.hooks.RemovableHandle] = [] self._register_vit_model_hook() # Set default image size and number of patches for PatchOut @@ -143,6 +145,9 @@ def forward(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor: delta.grad.detach_() delta.grad.zero_() + for hook in self.hooks: + hook.remove() + return x + delta def _register_vit_model_hook(self): @@ -162,7 +167,8 @@ def attn_drop_mask_grad( # Register backward hook for layers specified in _supported_vit_cfg for layer in self._supported_vit_cfg[self.model_name]: module = rgetattr(self.model, layer) - module.register_backward_hook(drop_hook_func) + hook = module.register_backward_hook(drop_hook_func) + self.hooks.append(hook) def _apply_patch_out(self, delta: torch.Tensor, seed: int) -> torch.Tensor: delta_mask = torch.zeros_like(delta) diff --git a/src/torchattack/tgr.py b/src/torchattack/tgr.py index c928764..70a0031 100644 --- a/src/torchattack/tgr.py +++ b/src/torchattack/tgr.py @@ -61,6 +61,8 @@ def __init__( self.targeted = targeted self.lossfn = nn.CrossEntropyLoss() + # Register hooks + self.hooks: list[torch.utils.hooks.RemovableHandle] = [] self._register_tgr_model_hooks() def forward(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor: @@ -110,6 +112,9 @@ def forward(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor: delta.grad.detach_() delta.grad.zero_() + for hook in self.hooks: + hook.remove() + return x + delta def _register_tgr_model_hooks(self): @@ -348,10 +353,11 @@ def mlp_tgr( assert self.model_name in _supported_vit_cfg - for hook, layers in _supported_vit_cfg[self.model_name]: + for hook_func, layers in _supported_vit_cfg[self.model_name]: for layer in layers: module = rgetattr(self.model, layer) - module.register_backward_hook(hook) + hook = module.register_backward_hook(hook_func) + self.hooks.append(hook) if __name__ == '__main__': diff --git a/src/torchattack/vdc.py b/src/torchattack/vdc.py index cbc3af6..63eadc2 100644 --- a/src/torchattack/vdc.py +++ b/src/torchattack/vdc.py @@ -6,7 +6,7 @@ import torch import torch.nn as nn -# from torchattack._rgetattr import rgetattr +from torchattack._rgetattr import rgetattr from torchattack.base import Attack @@ -19,13 +19,19 @@ class VDC(Attack): Args: model: The model to attack. - model_name: The name of the model. + model_name: The name of the model. Supported models are: + * 'vit_base_patch16_224' + * 'deit_base_distilled_patch16_224' + * 'pit_b_224' + * 'visformer_small' normalize: A transform to normalize images. device: Device to use for tensors. Defaults to cuda if available. eps: The maximum perturbation. Defaults to 8/255. steps: Number of steps. Defaults to 10. alpha: Step size, `eps / steps` if None. Defaults to None. decay: Momentum decay factor. Defaults to 1.0. + sample_num_batches: Number of batches to sample. Defaults to 130. + lambd: Lambda value for VDC. Defaults to 0.1. clip_min: Minimum value for clipping. Defaults to 0.0. clip_max: Maximum value for clipping. Defaults to 1.0. targeted: Targeted attack if True. Defaults to False. 
@@ -72,8 +78,8 @@ def __init__( self.lossfn = nn.CrossEntropyLoss() # Global hooks and attack stage state for VDC - self.stage = [] - self.hooks = [] + self.stage: list[np.ndarray] = [] + self.hooks: list[torch.utils.hooks.RemovableHandle] = [] assert self.sample_num_batches <= self.max_num_batches @@ -95,21 +101,23 @@ def forward(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor: if self.alpha is None: self.alpha = self.eps / self.steps - # Perform VDC - for _ in range(self.steps): - self.attn_record = [] - self.attn_add = [] + class GradientRecorder: + """Gradient recorder for attention and MLP blocks.""" - self.mlp_record = [] - self.mlp_add = [] + def __init__(self): + self.grad_records = [] + self.grad_additions = [] - self.skip_record = [] - self.skip_add = [] + # Perform VDC + for _ in range(self.steps): + # Initialize gradient recorders + self.attn_recorder = GradientRecorder() + self.mlp_recorder = GradientRecorder() - self.mlp_block = 0 - self.attn_block = 0 - self.skip_block = 0 - self._register_model_hooks(grad_add_hook=False) + # Stage 1: Record gradients + self.current_mlp_block = 0 + self.current_attn_block = 0 + self._register_model_hooks(add_grad_mode=False) # Compute loss outs = self.model(self.normalize(x + delta)) @@ -131,10 +139,10 @@ def forward(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor: for hook in self.hooks: hook.remove() - self.mlp_block = 0 - self.attn_block = 0 - self.skip_block = 0 - self._register_model_hooks(grad_add_hook=True) + # Stage 2: Update gradients by adding recorded gradients + self.current_mlp_block = 0 + self.current_attn_block = 0 + self._register_model_hooks(add_grad_mode=True) # Compute loss 2nd time outs = self.model(self.normalize(x + delta)) @@ -168,7 +176,7 @@ def forward(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor: return x + delta - def _register_model_hooks(self, grad_add_hook: bool = False): + def _register_model_hooks(self, add_grad_mode: bool = False) -> None: """Register hooks to either record or add gradients during the backward pass. Args: @@ -177,283 +185,433 @@ def _register_model_hooks(self, grad_add_hook: bool = False): during the backward pass. 
""" - def mlp_record_vit_stage(module, grad_in, grad_out, gamma): + def mlp_record_vit_stage( + module: torch.nn.Module, + grad_in: tuple[torch.Tensor, ...], + grad_out: tuple[torch.Tensor, ...], + gamma: float, + ) -> tuple[torch.Tensor, ...]: mask = torch.ones_like(grad_in[0]) * gamma out_grad = mask * grad_in[0][:] # ablation grad_record = ( - grad_in[0].data.cpu().numpy() * 0.1 * (0.5 ** (self.mlp_block)) + grad_in[0].data.cpu().numpy() * 0.1 * (0.5 ** (self.current_mlp_block)) ) # grad_record = grad_in[0].data.cpu().numpy() - if self.mlp_block == 0: + if self.current_mlp_block == 0: grad_add = np.zeros_like(grad_record) # ablation grad_add[:, 0, :] = self.norm_list[:, 0, :] * 0.1 * (0.5) # grad_add[:,0,:] = self.norm[:,0,:] - self.mlp_add.append(grad_add) - self.mlp_record.append(grad_record + grad_add) + self.mlp_recorder.grad_additions.append(grad_add) + self.mlp_recorder.grad_records.append(grad_record + grad_add) else: - self.mlp_add.append(self.mlp_record[-1]) - total_mlp = self.mlp_record[-1] + grad_record - self.mlp_record.append(total_mlp) - self.mlp_block += 1 + self.mlp_recorder.grad_additions.append( + self.mlp_recorder.grad_records[-1] + ) + total_mlp = self.mlp_recorder.grad_records[-1] + grad_record + self.mlp_recorder.grad_records.append(total_mlp) + self.current_mlp_block += 1 return (out_grad, grad_in[1], grad_in[2]) - def mlp_add_vit(module, grad_in, grad_out, gamma): + def mlp_add_vit( + module: torch.nn.Module, + grad_in: tuple[torch.Tensor, ...], + grad_out: tuple[torch.Tensor, ...], + gamma: float, + ) -> tuple[torch.Tensor, ...]: # grad_record = grad_in[0].data.cpu().numpy() mask = torch.ones_like(grad_in[0]) * gamma # mask_0 = torch.zeros_like(grad_in[0]) out_grad = mask * grad_in[0][:] - # out_grad = torch.where(grad_in[0][:] > 0, mask * grad_in[0][:], mask_0 * grad_in[0][:]) - out_grad += torch.tensor(self.mlp_add[self.mlp_block]).cuda() - self.mlp_block += 1 + # out_grad = torch.where( + # grad_in[0][:] > 0, mask * grad_in[0][:], mask_0 * grad_in[0][:] + # ) + out_grad += torch.tensor( + self.mlp_recorder.grad_additions[self.current_mlp_block], + device=grad_in[0].device, + ) + self.current_mlp_block += 1 return (out_grad, grad_in[1], grad_in[2]) - def attn_record_vit_stage(module, grad_in, grad_out, gamma): + def attn_record_vit_stage( + module: torch.nn.Module, + grad_in: tuple[torch.Tensor, ...], + grad_out: tuple[torch.Tensor, ...], + gamma: float, + ) -> tuple[torch.Tensor, ...]: mask = torch.ones_like(grad_in[0]) * gamma out_grad = mask * grad_in[0][:] grad_record = ( - grad_in[0].data.cpu().numpy() * 0.1 * (0.5 ** (self.attn_block)) + grad_in[0].data.cpu().numpy() * 0.1 * (0.5 ** (self.current_attn_block)) ) # grad_record = grad_in[0].data.cpu().numpy() - if self.attn_block == 0: - self.attn_add.append(np.zeros_like(grad_record)) - self.attn_record.append(grad_record) + if self.current_attn_block == 0: + self.attn_recorder.grad_additions.append(np.zeros_like(grad_record)) + self.attn_recorder.grad_records.append(grad_record) else: - self.attn_add.append(self.attn_record[-1]) - total_attn = self.attn_record[-1] + grad_record - self.attn_record.append(total_attn) + self.attn_recorder.grad_additions.append( + self.attn_recorder.grad_records[-1] + ) + total_attn = self.attn_recorder.grad_records[-1] + grad_record + self.attn_recorder.grad_records.append(total_attn) - self.attn_block += 1 + self.current_attn_block += 1 return (out_grad,) - def attn_add_vit(module, grad_in, grad_out, gamma): + def attn_add_vit( + module: torch.nn.Module, + grad_in: 
tuple[torch.Tensor, ...], + grad_out: tuple[torch.Tensor, ...], + gamma: float, + ) -> tuple[torch.Tensor, ...]: # grad_record = grad_in[0].data.cpu().numpy() # mask_0 = torch.zeros_like(grad_in[0]) mask = torch.ones_like(grad_in[0]) * gamma out_grad = mask * grad_in[0][:] - # out_grad = torch.where(grad_in[0][:] > 0, mask * grad_in[0][:], mask_0 * grad_in[0][:]) - out_grad += torch.tensor(self.attn_add[self.attn_block]).cuda() - self.attn_block += 1 + # out_grad = torch.where( + # grad_in[0][:] > 0, mask * grad_in[0][:], mask_0 * grad_in[0][:] + # ) + out_grad += torch.tensor( + self.attn_recorder.grad_additions[self.current_attn_block], + device=grad_in[0].device, + ) + self.current_attn_block += 1 return (out_grad,) - def norm_record_vit(module, grad_in, grad_out, gamma): + def norm_record_vit( + module: torch.nn.Module, + grad_in: tuple[torch.Tensor, ...], + grad_out: tuple[torch.Tensor, ...], + gamma: float, + ) -> tuple[torch.Tensor, ...]: grad_record = grad_in[0].data.cpu().numpy() # mask = torch.ones_like(grad_in[0]) * gamma self.norm_list = grad_record return grad_in # pit - def pool_record_pit(module, grad_in, grad_out, gamma): + def pool_record_pit( + module: torch.nn.Module, + grad_in: tuple[torch.Tensor, ...], + grad_out: tuple[torch.Tensor, ...], + gamma: float, + ) -> tuple[torch.Tensor, ...]: grad_add = grad_in[0].data b, c, h, w = grad_add.shape grad_add = grad_add.reshape((b, c, h * w)).transpose(1, 2) self.stage.append(grad_add.cpu().numpy()) return grad_in - def mlp_record_pit_stage(module, grad_in, grad_out, gamma): + def mlp_record_pit_stage( + module: torch.nn.Module, + grad_in: tuple[torch.Tensor, ...], + grad_out: tuple[torch.Tensor, ...], + gamma: float, + ) -> tuple[torch.Tensor, ...]: mask = torch.ones_like(grad_in[0]) * gamma out_grad = mask * grad_in[0][:] - if self.mlp_block < 4: + if self.current_mlp_block < 4: grad_record = ( - grad_in[0].data.cpu().numpy() * 0.03 * (0.5 ** (self.mlp_block)) + grad_in[0].data.cpu().numpy() + * 0.03 + * (0.5 ** (self.current_mlp_block)) ) - if self.mlp_block == 0: + if self.current_mlp_block == 0: grad_add = np.zeros_like(grad_record) grad_add[:, 0, :] = self.norm_list[:, 0, :] * 0.03 * (0.5) - self.mlp_add.append(grad_add) - self.mlp_record.append(grad_record + grad_add) + self.mlp_recorder.grad_additions.append(grad_add) + self.mlp_recorder.grad_records.append(grad_record + grad_add) else: - self.mlp_add.append(self.mlp_record[-1]) - total_mlp = self.mlp_record[-1] + grad_record - self.mlp_record.append(total_mlp) - elif self.mlp_block < 10: + self.mlp_recorder.grad_additions.append( + self.mlp_recorder.grad_records[-1] + ) + total_mlp = self.mlp_recorder.grad_records[-1] + grad_record + self.mlp_recorder.grad_records.append(total_mlp) + elif self.current_mlp_block < 10: grad_record = ( - grad_in[0].data.cpu().numpy() * 0.03 * (0.5 ** (self.mlp_block)) + grad_in[0].data.cpu().numpy() + * 0.03 + * (0.5 ** (self.current_mlp_block)) ) - if self.mlp_block == 4: + if self.current_mlp_block == 4: grad_add = np.zeros_like(grad_record) grad_add[:, 1:, :] = self.stage[0] * 0.03 * (0.5) - self.mlp_add.append(grad_add) - self.mlp_record.append(grad_record + grad_add) + self.mlp_recorder.grad_additions.append(grad_add) + self.mlp_recorder.grad_records.append(grad_record + grad_add) else: - self.mlp_add.append(self.mlp_record[-1]) - # total_mlp = self.mlp_record[-1] + grad_record - total_mlp = self.mlp_record[-1] - self.mlp_record.append(total_mlp) + self.mlp_recorder.grad_additions.append( + self.mlp_recorder.grad_records[-1] + ) + # 
total_mlp = self.mlp_rec.record[-1] + grad_record + total_mlp = self.mlp_recorder.grad_records[-1] + self.mlp_recorder.grad_records.append(total_mlp) else: grad_record = ( - grad_in[0].data.cpu().numpy() * 0.03 * (0.5 ** (self.mlp_block)) + grad_in[0].data.cpu().numpy() + * 0.03 + * (0.5 ** (self.current_mlp_block)) ) - if self.mlp_block == 10: + if self.current_mlp_block == 10: grad_add = np.zeros_like(grad_record) grad_add[:, 1:, :] = self.stage[1] * 0.03 * (0.5) - self.mlp_add.append(grad_add) - self.mlp_record.append(grad_record + grad_add) + self.mlp_recorder.grad_additions.append(grad_add) + self.mlp_recorder.grad_records.append(grad_record + grad_add) else: - self.mlp_add.append(self.mlp_record[-1]) - # total_mlp = self.mlp_record[-1] + grad_record - total_mlp = self.mlp_record[-1] - self.mlp_record.append(total_mlp) - self.mlp_block += 1 + self.mlp_recorder.grad_additions.append( + self.mlp_recorder.grad_records[-1] + ) + # total_mlp = self.mlp_rec.record[-1] + grad_record + total_mlp = self.mlp_recorder.grad_records[-1] + self.mlp_recorder.grad_records.append(total_mlp) + self.current_mlp_block += 1 return (out_grad, grad_in[1], grad_in[2]) - def mlp_add_pit(module, grad_in, grad_out, gamma): + def mlp_add_pit( + module: torch.nn.Module, + grad_in: tuple[torch.Tensor, ...], + grad_out: tuple[torch.Tensor, ...], + gamma: float, + ) -> tuple[torch.Tensor, ...]: # grad_record = grad_in[0].data.cpu().numpy() mask = torch.ones_like(grad_in[0]) * gamma out_grad = mask * grad_in[0][:] - out_grad += torch.tensor(self.mlp_add[self.mlp_block]).cuda() - self.mlp_block += 1 + out_grad += torch.tensor( + self.mlp_recorder.grad_additions[self.current_mlp_block], + device=grad_in[0].device, + ) + self.current_mlp_block += 1 return (out_grad, grad_in[1], grad_in[2]) - def attn_record_pit_stage(module, grad_in, grad_out, gamma): + def attn_record_pit_stage( + module: torch.nn.Module, + grad_in: tuple[torch.Tensor, ...], + grad_out: tuple[torch.Tensor, ...], + gamma: float, + ) -> tuple[torch.Tensor, ...]: mask = torch.ones_like(grad_in[0]) * gamma out_grad = mask * grad_in[0][:] - if self.attn_block < 4: + if self.current_attn_block < 4: grad_record = ( - grad_in[0].data.cpu().numpy() * 0.03 * (0.5 ** (self.attn_block)) + grad_in[0].data.cpu().numpy() + * 0.03 + * (0.5 ** (self.current_attn_block)) ) - if self.attn_block == 0: - self.attn_add.append(np.zeros_like(grad_record)) - self.attn_record.append(grad_record) + if self.current_attn_block == 0: + self.attn_recorder.grad_additions.append(np.zeros_like(grad_record)) + self.attn_recorder.grad_records.append(grad_record) else: - self.attn_add.append(self.attn_record[-1]) - total_attn = self.attn_record[-1] + grad_record - self.attn_record.append(total_attn) - elif self.attn_block < 10: + self.attn_recorder.grad_additions.append( + self.attn_recorder.grad_records[-1] + ) + total_attn = self.attn_recorder.grad_records[-1] + grad_record + self.attn_recorder.grad_records.append(total_attn) + elif self.current_attn_block < 10: grad_record = ( - grad_in[0].data.cpu().numpy() * 0.03 * (0.5 ** (self.attn_block)) + grad_in[0].data.cpu().numpy() + * 0.03 + * (0.5 ** (self.current_attn_block)) ) - if self.attn_block == 4: - self.attn_add.append(np.zeros_like(grad_record)) - self.attn_record.append(grad_record) + if self.current_attn_block == 4: + self.attn_recorder.grad_additions.append(np.zeros_like(grad_record)) + self.attn_recorder.grad_records.append(grad_record) else: - self.attn_add.append(self.attn_record[-1]) - # total_attn = self.attn_record[-1] + 
grad_record - total_attn = self.attn_record[-1] - self.attn_record.append(total_attn) + self.attn_recorder.grad_additions.append( + self.attn_recorder.grad_records[-1] + ) + # total_attn = self.attn_rec.record[-1] + grad_record + total_attn = self.attn_recorder.grad_records[-1] + self.attn_recorder.grad_records.append(total_attn) else: grad_record = ( - grad_in[0].data.cpu().numpy() * 0.03 * (0.5 ** (self.attn_block)) + grad_in[0].data.cpu().numpy() + * 0.03 + * (0.5 ** (self.current_attn_block)) ) - if self.attn_block == 10: - self.attn_add.append(np.zeros_like(grad_record)) - self.attn_record.append(grad_record) + if self.current_attn_block == 10: + self.attn_recorder.grad_additions.append(np.zeros_like(grad_record)) + self.attn_recorder.grad_records.append(grad_record) else: - self.attn_add.append(self.attn_record[-1]) - # total_attn = self.attn_record[-1] + grad_record - total_attn = self.attn_record[-1] - self.attn_record.append(total_attn) - self.attn_block += 1 + self.attn_recorder.grad_additions.append( + self.attn_recorder.grad_records[-1] + ) + # total_attn = self.attn_rec.record[-1] + grad_record + total_attn = self.attn_recorder.grad_records[-1] + self.attn_recorder.grad_records.append(total_attn) + self.current_attn_block += 1 return (out_grad,) - def attn_add_pit(module, grad_in, grad_out, gamma): + def attn_add_pit( + module: torch.nn.Module, + grad_in: tuple[torch.Tensor, ...], + grad_out: tuple[torch.Tensor, ...], + gamma: float, + ) -> tuple[torch.Tensor, ...]: # grad_record = grad_in[0].data.cpu().numpy() mask = torch.ones_like(grad_in[0]) * gamma out_grad = mask * grad_in[0][:] - out_grad += torch.tensor(self.attn_add[self.attn_block]).cuda() - self.attn_block += 1 + out_grad += torch.tensor( + self.attn_recorder.grad_additions[self.current_attn_block], + device=grad_in[0].device, + ) + self.current_attn_block += 1 return (out_grad,) - def norm_record_pit(module, grad_in, grad_out, gamma): + def norm_record_pit( + module: torch.nn.Module, + grad_in: tuple[torch.Tensor, ...], + grad_out: tuple[torch.Tensor, ...], + gamma: float, + ) -> tuple[torch.Tensor, ...]: grad_record = grad_in[0].data.cpu().numpy() # mask = torch.ones_like(grad_in[0]) * gamma self.norm_list = grad_record return grad_in - #################################################### # visformer - def pool_record_vis(module, grad_in, grad_out, gamma): + def pool_record_vis( + module: torch.nn.Module, + grad_in: tuple[torch.Tensor, ...], + grad_out: tuple[torch.Tensor, ...], + gamma: float, + ) -> tuple[torch.Tensor, ...]: grad_add = grad_in[0].data # B,C,H,W = grad_add.shape # grad_add = grad_add.reshape((B,C,H*W)).transpose(1,2) self.stage.append(grad_add.cpu().numpy()) return grad_in - def mlp_record_vis_stage(module, grad_in, grad_out, gamma): + def mlp_record_vis_stage( + module: torch.nn.Module, + grad_in: tuple[torch.Tensor, ...], + grad_out: tuple[torch.Tensor, ...], + gamma: float, + ) -> tuple[torch.Tensor, ...]: mask = torch.ones_like(grad_in[0]) * gamma out_grad = mask * grad_in[0][:] - if self.mlp_block < 4: + if self.current_mlp_block < 4: grad_record = ( - grad_in[0].data.cpu().numpy() * 0.1 * (0.5 ** (self.mlp_block)) + grad_in[0].data.cpu().numpy() + * 0.1 + * (0.5 ** (self.current_mlp_block)) ) - if self.mlp_block == 0: + if self.current_mlp_block == 0: grad_add = np.zeros_like(grad_record) grad_add[:, 0, :] = self.norm_list[:, 0, :] * 0.1 * (0.5) - self.mlp_add.append(grad_add) - self.mlp_record.append(grad_record + grad_add) + self.mlp_recorder.grad_additions.append(grad_add) + 
+                    self.mlp_recorder.grad_records.append(grad_record + grad_add)
                 else:
-                    self.mlp_add.append(self.mlp_record[-1])
-                    total_mlp = self.mlp_record[-1] + grad_record
-                    self.mlp_record.append(total_mlp)
+                    self.mlp_recorder.grad_additions.append(
+                        self.mlp_recorder.grad_records[-1]
+                    )
+                    total_mlp = self.mlp_recorder.grad_records[-1] + grad_record
+                    self.mlp_recorder.grad_records.append(total_mlp)
             else:
                 grad_record = (
-                    grad_in[0].data.cpu().numpy() * 0.1 * (0.5 ** (self.mlp_block))
+                    grad_in[0].data.cpu().numpy()
+                    * 0.1
+                    * (0.5 ** (self.current_mlp_block))
                 )
-                if self.mlp_block == 4:
+                if self.current_mlp_block == 4:
                     grad_add = np.zeros_like(grad_record)
                     # grad_add[:,1:,:] = self.stage[0]* 0.1*(0.5)
-                    self.mlp_add.append(grad_add)
-                    self.mlp_record.append(grad_record + grad_add)
+                    self.mlp_recorder.grad_additions.append(grad_add)
+                    self.mlp_recorder.grad_records.append(grad_record + grad_add)
                 else:
-                    self.mlp_add.append(self.mlp_record[-1])
-                    total_mlp = self.mlp_record[-1] + grad_record
-                    self.mlp_record.append(total_mlp)
+                    self.mlp_recorder.grad_additions.append(
+                        self.mlp_recorder.grad_records[-1]
+                    )
+                    total_mlp = self.mlp_recorder.grad_records[-1] + grad_record
+                    self.mlp_recorder.grad_records.append(total_mlp)
-            self.mlp_block += 1
+            self.current_mlp_block += 1
             return (out_grad, grad_in[1], grad_in[2])

-        def mlp_add_vis(module, grad_in, grad_out, gamma):
+        def mlp_add_vis(
+            module: torch.nn.Module,
+            grad_in: tuple[torch.Tensor, ...],
+            grad_out: tuple[torch.Tensor, ...],
+            gamma: float,
+        ) -> tuple[torch.Tensor, ...]:
             # grad_record = grad_in[0].data.cpu().numpy()
             mask = torch.ones_like(grad_in[0]) * gamma
             out_grad = mask * grad_in[0][:]
-            out_grad += torch.tensor(self.mlp_add[self.mlp_block]).cuda()
-            self.mlp_block += 1
+            out_grad += torch.tensor(
+                self.mlp_recorder.grad_additions[self.current_mlp_block],
+                device=grad_in[0].device,
+            )
+            self.current_mlp_block += 1
             return (out_grad, grad_in[1], grad_in[2])

-        def norm_record_vis(module, grad_in, grad_out, gamma):
+        def norm_record_vis(
+            module: torch.nn.Module,
+            grad_in: tuple[torch.Tensor, ...],
+            grad_out: tuple[torch.Tensor, ...],
+            gamma: float,
+        ) -> tuple[torch.Tensor, ...]:
             grad_record = grad_in[0].data.cpu().numpy()
             # mask = torch.ones_like(grad_in[0]) * gamma
             self.norm_list = grad_record
             return grad_in

-        def attn_record_vis_stage(module, grad_in, grad_out, gamma):
+        def attn_record_vis_stage(
+            module: torch.nn.Module,
+            grad_in: tuple[torch.Tensor, ...],
+            grad_out: tuple[torch.Tensor, ...],
+            gamma: float,
+        ) -> tuple[torch.Tensor, ...]:
             mask = torch.ones_like(grad_in[0]) * gamma
             out_grad = mask * grad_in[0][:]
-            if self.attn_block < 4:
+            if self.current_attn_block < 4:
                 grad_record = (
-                    grad_in[0].data.cpu().numpy() * 0.1 * (0.5 ** (self.attn_block))
+                    grad_in[0].data.cpu().numpy()
+                    * 0.1
+                    * (0.5 ** (self.current_attn_block))
                 )
-                if self.attn_block == 0:
-                    self.attn_add.append(np.zeros_like(grad_record))
-                    self.attn_record.append(grad_record)
+                if self.current_attn_block == 0:
+                    self.attn_recorder.grad_additions.append(np.zeros_like(grad_record))
+                    self.attn_recorder.grad_records.append(grad_record)
                 else:
-                    self.attn_add.append(self.attn_record[-1])
-                    total_attn = self.attn_record[-1] + grad_record
-                    self.attn_record.append(total_attn)
+                    self.attn_recorder.grad_additions.append(
+                        self.attn_recorder.grad_records[-1]
+                    )
+                    total_attn = self.attn_recorder.grad_records[-1] + grad_record
+                    self.attn_recorder.grad_records.append(total_attn)
             else:
                 grad_record = (
-                    grad_in[0].data.cpu().numpy() * 0.1 * (0.5 ** (self.attn_block))
+                    grad_in[0].data.cpu().numpy()
+                    * 0.1
+                    * (0.5 ** (self.current_attn_block))
                 )
-                if self.attn_block == 4:
-                    self.attn_add.append(np.zeros_like(grad_record))
-                    self.attn_record.append(grad_record)
+                if self.current_attn_block == 4:
+                    self.attn_recorder.grad_additions.append(np.zeros_like(grad_record))
+                    self.attn_recorder.grad_records.append(grad_record)
                 else:
-                    self.attn_add.append(self.attn_record[-1])
-                    total_attn = self.attn_record[-1] + grad_record
-                    self.attn_record.append(total_attn)
+                    self.attn_recorder.grad_additions.append(
+                        self.attn_recorder.grad_records[-1]
+                    )
+                    total_attn = self.attn_recorder.grad_records[-1] + grad_record
+                    self.attn_recorder.grad_records.append(total_attn)
-            self.attn_block += 1
+            self.current_attn_block += 1
             return (out_grad,)

-        def attn_add_vis(module, grad_in, grad_out, gamma):
+        def attn_add_vis(
+            module: torch.nn.Module,
+            grad_in: tuple[torch.Tensor, ...],
+            grad_out: tuple[torch.Tensor, ...],
+            gamma: float,
+        ) -> tuple[torch.Tensor, ...]:
             # grad_record = grad_in[0].data.cpu().numpy()
             mask = torch.ones_like(grad_in[0]) * gamma
             out_grad = mask * grad_in[0][:]
-            out_grad += torch.tensor(self.attn_add[self.attn_block]).cuda()
-            self.attn_block += 1
+            out_grad += torch.tensor(
+                self.attn_recorder.grad_additions[self.current_attn_block],
+                device=grad_in[0].device,
+            )
+            self.current_attn_block += 1
             return (out_grad,)

         # vit
@@ -480,148 +638,63 @@ def attn_add_vis(module, grad_in, grad_out, gamma):
         attn_add_func_vis = partial(attn_add_vis, gamma=0.25)
         mlp_add_func_vis = partial(mlp_add_vis, gamma=0.5)

-        if not grad_add_hook:
-            if self.model_name in [
-                'vit_base_patch16_224',
-                'deit_base_distilled_patch16_224',
-            ]:
-                hook = self.model.norm.register_backward_hook(norm_record_func_vit)
-                self.hooks.append(hook)
-                for i in range(12):
-                    hook = self.model.blocks[i].norm2.register_backward_hook(
-                        mlp_record_func_vit
-                    )
-                    self.hooks.append(hook)
-                    hook = self.model.blocks[i].attn.attn_drop.register_backward_hook(
-                        attn_record_func_vit
-                    )
-                    self.hooks.append(hook)
-            elif self.model_name == 'pit_b_224':
-                hook = self.model.norm.register_backward_hook(norm_record_func_pit)
-                self.hooks.append(hook)
-                for block_ind in range(13):
-                    if block_ind < 3:
-                        transformer_ind = 0
-                        used_block_ind = block_ind
-                    elif block_ind < 9 and block_ind >= 3:
-                        transformer_ind = 1
-                        used_block_ind = block_ind - 3
-                    elif block_ind < 13 and block_ind >= 9:
-                        transformer_ind = 2
-                        used_block_ind = block_ind - 9
-                    hook = (
-                        self.model.transformers[transformer_ind]
-                        .blocks[used_block_ind]
-                        .attn.attn_drop.register_backward_hook(attn_record_func_pit)
-                    )
-                    self.hooks.append(hook)
-                    # hook = self.model.transformers[transformer_ind].blocks[used_block_ind].mlp.register_backward_hook(mlp_record_func_pit)
-                    hook = (
-                        self.model.transformers[transformer_ind]
-                        .blocks[used_block_ind]
-                        .norm2.register_backward_hook(mlp_record_func_pit)
-                    )
-                    self.hooks.append(hook)
-                hook = self.model.transformers[1].pool.register_backward_hook(
-                    pool_record_func_pit
-                )
-                self.hooks.append(hook)
-                hook = self.model.transformers[2].pool.register_backward_hook(
-                    pool_record_func_pit
-                )
-                self.hooks.append(hook)
-            elif self.model_name == 'visformer_small':
-                hook = self.model.norm.register_backward_hook(norm_record_func_vis)
-                self.hooks.append(hook)
-                for block_ind in range(8):
-                    if block_ind < 4:
-                        hook = self.model.stage2[
-                            block_ind
-                        ].attn.attn_drop.register_backward_hook(attn_record_func_vis)
-                        self.hooks.append(hook)
-                        # hook = self.model.stage2[block_ind].mlp.register_backward_hook(mlp_record_func_vis)
-                        hook = self.model.stage2[
-                            block_ind
-                        ].norm2.register_backward_hook(mlp_record_func_vis)
-                        self.hooks.append(hook)
-                    elif block_ind >= 4:
-                        hook = self.model.stage3[
-                            block_ind - 4
-                        ].attn.attn_drop.register_backward_hook(attn_record_func_vis)
-                        self.hooks.append(hook)
-                        # hook = self.model.stage3[block_ind-4].mlp.register_backward_hook(mlp_record_func_vis)
-                        hook = self.model.stage3[
-                            block_ind - 4
-                        ].norm2.register_backward_hook(mlp_record_func_vis)
-                        self.hooks.append(hook)
-                hook = self.model.patch_embed3.register_backward_hook(
-                    pool_record_func_vis
-                )
-                self.hooks.append(hook)
-                hook = self.model.patch_embed2.register_backward_hook(
-                    pool_record_func_vis
-                )
+        # fmt: off
+        # Register hooks for supported models.
+        # * Gradient RECORD mode hooks:
+        record_grad_cfg = {
+            'vit_base_patch16_224': [
+                (norm_record_func_vit, ['norm']),
+                (mlp_record_func_vit, [f'blocks.{i}.norm2' for i in range(12)]),
+                (attn_record_func_vit, [f'blocks.{i}.attn.attn_drop' for i in range(12)]),
+            ],
+            'deit_base_distilled_patch16_224': [
+                (norm_record_func_vit, ['norm']),
+                (mlp_record_func_vit, [f'blocks.{i}.norm2' for i in range(12)]),
+                (attn_record_func_vit, [f'blocks.{i}.attn.attn_drop' for i in range(12)]),
+            ],
+            'pit_b_224': [
+                (norm_record_func_pit, ['norm']),
+                (attn_record_func_pit, [f'transformers.{tid}.blocks.{i}.attn.attn_drop' for tid, bid in enumerate([3, 6, 4]) for i in range(bid)]),
+                (mlp_record_func_pit, [f'transformers.{tid}.blocks.{i}.norm2' for tid, bid in enumerate([3, 6, 4]) for i in range(bid)]),
+                (pool_record_func_pit, ['transformers.1.pool', 'transformers.2.pool']),
+            ],
+            'visformer_small': [
+                (norm_record_func_vis, ['norm']),
+                (attn_record_func_vis, [f'stage2.{i}.attn.attn_drop' for i in range(4)] + [f'stage3.{i}.attn.attn_drop' for i in range(4)]),
+                (mlp_record_func_vis, [f'stage2.{i}.norm2' for i in range(4)] + [f'stage3.{i}.norm2' for i in range(4)]),
+                (pool_record_func_vis, ['patch_embed2', 'patch_embed3']),
+            ],
+        }
+        # * Gradient ADD mode hooks:
+        add_grad_cfg = {
+            'vit_base_patch16_224': [
+                (mlp_add_func_vit, [f'blocks.{i}.norm2' for i in range(12)]),
+                (attn_add_func_vit, [f'blocks.{i}.attn.attn_drop' for i in range(12)]),
+            ],
+            'deit_base_distilled_patch16_224': [
+                (mlp_add_func_vit, [f'blocks.{i}.norm2' for i in range(12)]),
+                (attn_add_func_vit, [f'blocks.{i}.attn.attn_drop' for i in range(12)]),
+            ],
+            'pit_b_224': [
+                (attn_add_func_pit, [f'transformers.{tid}.blocks.{i}.attn.attn_drop' for tid, bid in enumerate([3, 6, 4]) for i in range(bid)]),
+                (mlp_add_func_pit, [f'transformers.{tid}.blocks.{i}.norm2' for tid, bid in enumerate([3, 6, 4]) for i in range(bid)]),
+            ],
+            'visformer_small': [
+                (attn_add_func_vis, [f'stage2.{i}.attn.attn_drop' for i in range(4)] + [f'stage3.{i}.attn.attn_drop' for i in range(4)]),
+                (mlp_add_func_vis, [f'stage2.{i}.norm2' for i in range(4)] + [f'stage3.{i}.norm2' for i in range(4)]),
+            ],
+        }
+        # fmt: on
+
+        activated_vit_cfg = add_grad_cfg if add_grad_mode else record_grad_cfg
+        assert self.model_name in activated_vit_cfg
+
+        for hook_func, layers in activated_vit_cfg[self.model_name]:
+            for layer in layers:
+                module = rgetattr(self.model, layer)
+                hook = module.register_backward_hook(hook_func)
                 self.hooks.append(hook)
-            else:
-                if self.model_name in [
-                    'vit_base_patch16_224',
-                    'deit_base_distilled_patch16_224',
-                ]:
-                    for i in range(12):
-                        hook = self.model.blocks[i].norm2.register_backward_hook(
-                            mlp_add_func_vit
-                        )
-                        self.hooks.append(hook)
-                        hook = self.model.blocks[i].attn.attn_drop.register_backward_hook(
-                            attn_add_func_vit
-                        )
-                        self.hooks.append(hook)
-                elif self.model_name == 'pit_b_224':
-                    for block_ind in range(13):
-                        if block_ind < 3:
-                            transformer_ind = 0
-                            used_block_ind = block_ind
-                        elif block_ind < 9 and block_ind >= 3:
-                            transformer_ind = 1
-                            used_block_ind = block_ind - 3
-                        elif block_ind < 13 and block_ind >= 9:
-                            transformer_ind = 2
-                            used_block_ind = block_ind - 9
-                        hook = (
-                            self.model.transformers[transformer_ind]
-                            .blocks[used_block_ind]
-                            .attn.attn_drop.register_backward_hook(attn_add_func_pit)
-                        )
-                        self.hooks.append(hook)
-                        # hook = self.model.transformers[transformer_ind].blocks[used_block_ind].mlp.register_backward_hook(mlp_add_func_pit)
-                        hook = (
-                            self.model.transformers[transformer_ind]
-                            .blocks[used_block_ind]
-                            .norm2.register_backward_hook(mlp_add_func_pit)
-                        )
-                        self.hooks.append(hook)
-                elif self.model_name == 'visformer_small':
-                    for block_ind in range(8):
-                        if block_ind < 4:
-                            hook = self.model.stage2[
-                                block_ind
-                            ].attn.attn_drop.register_backward_hook(attn_add_func_vis)
-                            self.hooks.append(hook)
-                            # hook = self.model.stage2[block_ind].mlp.register_backward_hook(mlp_add_func_vis)
-                            hook = self.model.stage2[
-                                block_ind
-                            ].norm2.register_backward_hook(mlp_add_func_vis)
-                            self.hooks.append(hook)
-                        elif block_ind >= 4:
-                            hook = self.model.stage3[
-                                block_ind - 4
-                            ].attn.attn_drop.register_backward_hook(attn_add_func_vis)
-                            self.hooks.append(hook)
-                            # hook = self.model.stage3[block_ind-4].mlp.register_backward_hook(mlp_add_func_vis)
-                            hook = self.model.stage3[
-                                block_ind - 4
-                            ].norm2.register_backward_hook(mlp_add_func_vis)
-                            self.hooks.append(hook)


 if __name__ == '__main__':
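The hunk above replaces the per-model if/elif registration with two lookup tables of (hook function, dotted layer path) pairs that are resolved at registration time by `rgetattr`. The `enumerate([3, 6, 4])` comprehensions for `pit_b_224` expand to the same 13 attention and MLP layer paths (3, 6, and 4 blocks across the three transformer stages) that the removed `transformer_ind`/`used_block_ind` bookkeeping computed. For reference, a recursive-getattr helper of this kind is commonly written as in the sketch below; this assumes the same dotted-lookup semantics as `torchattack._rgetattr.rgetattr` and is not the repository's actual implementation.

    import functools

    def rgetattr(obj, path: str):
        # Resolve a dotted attribute path such as 'blocks.0.attn.attn_drop' by
        # applying getattr one segment at a time. Indexing into nn.Sequential or
        # nn.ModuleList works because their children are registered under string
        # keys ('0', '1', ...).
        return functools.reduce(getattr, path.split('.'), obj)

    # Example: rgetattr(model, 'blocks.0.norm2').register_backward_hook(mlp_record_func_vit)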