Refactor VDC attack with rgetattr
spencerwooo committed Nov 20, 2024
1 parent a22e733 commit 0ddf5e8
Showing 5 changed files with 385 additions and 298 deletions.
2 changes: 2 additions & 0 deletions src/torchattack/__init__.py
@@ -15,6 +15,7 @@
 from torchattack.ssp import SSP
 from torchattack.tgr import TGR
 from torchattack.tifgsm import TIFGSM
+from torchattack.vdc import VDC
 from torchattack.vmifgsm import VMIFGSM
 from torchattack.vnifgsm import VNIFGSM

@@ -38,6 +39,7 @@
     'SSP',
     'TGR',
     'TIFGSM',
+    'VDC',
     'VMIFGSM',
     'VNIFGSM',
 ]
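With both hunks applied, VDC becomes importable from the package root. A hypothetical usage sketch, assuming VDC follows the same constructor convention as the package's other attacks (its actual signature may differ, so check the class before relying on this):

```python
import torch
from torchattack import VDC  # newly exported in this commit

# Hypothetical setup: assumes VDC shares the (model, normalize, device)
# constructor convention of torchattack's other attacks.
# model = ...       # a ViT-style classifier
# normalize = ...   # the dataset's normalization transform
# attack = VDC(model, normalize, device=torch.device('cuda'))
# x_adv = attack(x, y)
```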
2 changes: 1 addition & 1 deletion src/torchattack/base.py
@@ -30,7 +30,7 @@ def __repr__(self) -> str:
         def repr_map(k, v):
             if isinstance(v, float):
                 return f'{k}={v:.3f}'
-            if k in ['model', 'normalize', 'feature_layer']:
+            if k in ['model', 'normalize', 'feature_layer', 'hooks']:
                 return f'{k}={v.__class__.__name__}'
             if isinstance(v, torch.Tensor):
                 return f'{k}={v.shape}'
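Adding 'hooks' to this list keeps the attack's repr compact: a list of RemovableHandle objects now renders as its class name instead of dumping every handle's own repr. A standalone sketch of the same mapping (the trailing fallback branch is assumed, since the diff truncates the function):

```python
import torch

def repr_map(k, v):
    if isinstance(v, float):
        return f'{k}={v:.3f}'
    if k in ['model', 'normalize', 'feature_layer', 'hooks']:
        return f'{k}={v.__class__.__name__}'
    if isinstance(v, torch.Tensor):
        return f'{k}={v.shape}'
    return f'{k}={v}'  # assumed fallback; the diff cuts off before this line

print(repr_map('eps', 8 / 255))  # eps=0.031
print(repr_map('hooks', []))     # hooks=list
```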
8 changes: 7 additions & 1 deletion src/torchattack/pna_patchout.py
@@ -95,7 +95,9 @@ def __init__(
         self.targeted = targeted
         self.lossfn = nn.CrossEntropyLoss()

+        # Register hooks
         if self.pna_skip:
+            self.hooks: list[torch.utils.hooks.RemovableHandle] = []
             self._register_vit_model_hook()

         # Set default image size and number of patches for PatchOut
@@ -143,6 +145,9 @@ def forward(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
             delta.grad.detach_()
             delta.grad.zero_()

+        for hook in self.hooks:
+            hook.remove()
+
         return x + delta

     def _register_vit_model_hook(self):
@@ -162,7 +167,8 @@ def attn_drop_mask_grad(
         # Register backward hook for layers specified in _supported_vit_cfg
         for layer in self._supported_vit_cfg[self.model_name]:
             module = rgetattr(self.model, layer)
-            module.register_backward_hook(drop_hook_func)
+            hook = module.register_backward_hook(drop_hook_func)
+            self.hooks.append(hook)

     def _apply_patch_out(self, delta: torch.Tensor, seed: int) -> torch.Tensor:
         delta_mask = torch.zeros_like(delta)
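The pattern this file (and tgr.py below) now follows: keep every RemovableHandle returned by register_backward_hook, then detach them once the attack finishes, so repeated forward() calls do not stack duplicate hooks. A minimal self-contained sketch of that lifecycle (the toy model and hook body are illustrative, not the attack's actual gradient masking):

```python
import torch
import torch.nn as nn

# Toy stand-in for the attacked model.
model = nn.Sequential(nn.Linear(4, 8), nn.ReLU(), nn.Linear(8, 1))

hooks: list[torch.utils.hooks.RemovableHandle] = []

def drop_hook_func(module, grad_input, grad_output):
    # Placeholder behavior: scale the gradients flowing into the module.
    return tuple(g * 0.5 if g is not None else None for g in grad_input)

# Register each hook and remember its handle.
for layer in (model[0], model[2]):
    hooks.append(layer.register_backward_hook(drop_hook_func))

x = torch.randn(2, 4, requires_grad=True)
model(x).sum().backward()  # hooks fire during this backward pass

# Clean up, exactly as the refactored attacks now do after forward().
for hook in hooks:
    hook.remove()
```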
10 changes: 8 additions & 2 deletions src/torchattack/tgr.py
@@ -61,6 +61,8 @@ def __init__(
         self.targeted = targeted
         self.lossfn = nn.CrossEntropyLoss()

+        # Register hooks
+        self.hooks: list[torch.utils.hooks.RemovableHandle] = []
         self._register_tgr_model_hooks()

     def forward(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
@@ -110,6 +112,9 @@ def forward(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
             delta.grad.detach_()
             delta.grad.zero_()

+        for hook in self.hooks:
+            hook.remove()
+
         return x + delta

     def _register_tgr_model_hooks(self):
@@ -348,10 +353,11 @@ def mlp_tgr(

         assert self.model_name in _supported_vit_cfg

-        for hook, layers in _supported_vit_cfg[self.model_name]:
+        for hook_func, layers in _supported_vit_cfg[self.model_name]:
             for layer in layers:
                 module = rgetattr(self.model, layer)
-                module.register_backward_hook(hook)
+                hook = module.register_backward_hook(hook_func)
+                self.hooks.append(hook)


 if __name__ == '__main__':
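rgetattr, which both refactored files use to look up modules by dotted path, is the usual recursive-getattr recipe. A sketch of such a helper, assuming the common functools.reduce formulation (torchattack's own implementation may differ in detail):

```python
import functools

def rgetattr(obj, attr, *default):
    """Resolve a dotted attribute path, e.g. rgetattr(model, 'blocks.0.attn')."""
    def _getattr(obj, name):
        return getattr(obj, name, *default)
    return functools.reduce(_getattr, attr.split('.'), obj)

# This also walks through nn.ModuleList indices, since getattr(module_list, '0')
# resolves against the module's _modules registry.
```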
661 changes: 367 additions & 294 deletions src/torchattack/vdc.py
(diff not loaded on the page; the per-file counts above are derived from the 385-addition, 298-deletion total)
