Refactor PNA and TGR hook cfg declarations
spencerwooo committed Nov 20, 2024
1 parent a028bf0 commit ce1016d
Showing 2 changed files with 16 additions and 36 deletions.
46 changes: 13 additions & 33 deletions torchattack/pna_patchout.py
@@ -36,36 +36,6 @@ class PNAPatchOut(Attack):
         targeted: Targeted attack if True. Defaults to False.
     """
 
-    # fmt: off
-    _supported_vit_cfg = {
-        'vit_base_patch16_224': [
-            f'blocks.{i}.attn.attn_drop' for i in range(12)
-        ],
-        'deit_base_distilled_patch16_224': [
-            f'blocks.{i}.attn.attn_drop' for i in range(12)
-        ],
-        'pit_b_224': [
-            f'transformers.{tid}.blocks.{i}.attn.attn_drop'
-            for tid, bid in enumerate([3, 6, 4])
-            for i in range(bid)
-        ],
-        'cait_s24_224': [
-            # Regular blocks
-            f'blocks.{i}.attn.attn_drop' for i in range(24)
-        ] + [
-            # Token-only block
-            f'blocks_token_only.{i}.attn.attn_drop' for i in range(2)
-        ],
-        'visformer_small': [
-            # Stage 2 blocks
-            f'stage2.{i}.attn.attn_drop' for i in range(4)
-        ] + [
-            # Stage 3 blocks
-            f'stage3.{i}.attn.attn_drop' for i in range(4)
-        ],
-    }
-    # fmt: on
-
     def __init__(
         self,
         model: nn.Module | AttackModel,
@@ -184,10 +154,20 @@ def attn_drop_mask_grad(
 
         drop_hook_func = partial(attn_drop_mask_grad, gamma=0)
 
-        assert self.model_name in self._supported_vit_cfg
+        # fmt: off
+        supported_vit_cfg = {
+            'vit_base_patch16_224': [f'blocks.{i}.attn.attn_drop' for i in range(12)],
+            'deit_base_distilled_patch16_224': [f'blocks.{i}.attn.attn_drop' for i in range(12)],
+            'pit_b_224': [f'transformers.{tid}.blocks.{i}.attn.attn_drop' for tid, bid in enumerate([3, 6, 4]) for i in range(bid)],
+            'cait_s24_224': [f'blocks.{i}.attn.attn_drop' for i in range(24)] + [f'blocks_token_only.{i}.attn.attn_drop' for i in range(2)],
+            'visformer_small': [f'stage2.{i}.attn.attn_drop' for i in range(4)] + [f'stage3.{i}.attn.attn_drop' for i in range(4)],
+        }
+        # fmt: on
+
+        assert self.model_name in supported_vit_cfg
 
-        # Register backward hook for layers specified in _supported_vit_cfg
-        for layer in self._supported_vit_cfg[self.model_name]:
+        # Register backward hook for layers specified in supported_vit_cfg
+        for layer in supported_vit_cfg[self.model_name]:
             module = rgetattr(self.model, layer)
             hook = module.register_backward_hook(drop_hook_func)
             self.hooks.append(hook)
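For context on the hunk above: the config maps each model name to a list of dotted layer paths, each of which is resolved to a live submodule and given a gradient-masking backward hook. Moving the dict from a class attribute into the method keeps it next to the loop that consumes it. Below is a minimal runnable sketch of the pattern; the toy model, the rgetattr helper, and the hook body are illustrative stand-ins, not torchattack's actual implementations.

# Sketch only: hypothetical stand-ins for rgetattr and the PNA-style
# attn_drop gradient hook (gamma=0 zeroes the gradient through the layer).
import functools

import torch
import torch.nn as nn


def rgetattr(obj, path: str):
    # Resolve a dotted path like 'blocks.0.attn' via repeated getattr.
    return functools.reduce(getattr, path.split('.'), obj)


def attn_drop_mask_grad(module, grad_in, grad_out, gamma: float):
    # Backward hook: scale every incoming gradient by gamma.
    return tuple(g * gamma if g is not None else None for g in grad_in)


class Block(nn.Module):
    def __init__(self):
        super().__init__()
        self.attn = nn.Linear(8, 8)

    def forward(self, x):
        return self.attn(x)


class ToyViT(nn.Module):
    def __init__(self):
        super().__init__()
        self.blocks = nn.Sequential(Block(), Block())

    def forward(self, x):
        return self.blocks(x).sum()


model = ToyViT()
drop_hook_func = functools.partial(attn_drop_mask_grad, gamma=0)

# Same shape as the refactored config: model name -> dotted layer paths.
supported_vit_cfg = {'toy_vit': [f'blocks.{i}.attn' for i in range(2)]}

hooks = []
for layer in supported_vit_cfg['toy_vit']:
    module = rgetattr(model, layer)
    hooks.append(module.register_backward_hook(drop_hook_func))

model(torch.randn(1, 8)).backward()  # gradients into each 'attn' are masked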
6 changes: 3 additions & 3 deletions torchattack/tgr.py
@@ -304,7 +304,7 @@ def mlp_tgr(
         mlp_tgr_hook = partial(mlp_tgr, gamma=0.5)
 
         # fmt: off
-        _supported_vit_cfg = {
+        supported_vit_cfg = {
             'vit_base_patch16_224': [
                 (attn_tgr_hook, [f'blocks.{i}.attn.attn_drop' for i in range(12)]),
                 (v_tgr_hook, [f'blocks.{i}.attn.qkv' for i in range(12)]),
@@ -341,9 +341,9 @@ def mlp_tgr(
         }
         # fmt: on
 
-        assert self.model_name in _supported_vit_cfg
+        assert self.model_name in supported_vit_cfg
 
-        for hook_func, layers in _supported_vit_cfg[self.model_name]:
+        for hook_func, layers in supported_vit_cfg[self.model_name]:
             for layer in layers:
                 module = rgetattr(self.model, layer)
                 hook = module.register_backward_hook(hook_func)
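The tgr.py config has a different shape from the PNA one: each model maps to (hook, layer-paths) pairs, so attention, QKV, and MLP layers can each carry their own hook and gamma. A compact, self-contained sketch of that nested registration loop, again with hypothetical helpers and gammas (the real attack defines its own hook functions):

# Sketch only: (hook_func, layers) pairs let each hook carry its own gamma.
# Helpers, toy model, and gamma values are hypothetical.
import functools

import torch
import torch.nn as nn


def rgetattr(obj, path: str):
    return functools.reduce(getattr, path.split('.'), obj)


def scale_grad(module, grad_in, grad_out, gamma: float):
    return tuple(g * gamma if g is not None else None for g in grad_in)


class Block(nn.Module):
    def __init__(self):
        super().__init__()
        self.attn = nn.Linear(8, 8)
        self.mlp = nn.Linear(8, 8)

    def forward(self, x):
        return self.mlp(self.attn(x))


class ToyViT(nn.Module):
    def __init__(self):
        super().__init__()
        self.blocks = nn.Sequential(Block(), Block())

    def forward(self, x):
        return self.blocks(x).sum()


model = ToyViT()
attn_tgr_hook = functools.partial(scale_grad, gamma=0.25)  # hypothetical gamma
mlp_tgr_hook = functools.partial(scale_grad, gamma=0.5)

supported_vit_cfg = {
    'toy_vit': [
        (attn_tgr_hook, [f'blocks.{i}.attn' for i in range(2)]),
        (mlp_tgr_hook, [f'blocks.{i}.mlp' for i in range(2)]),
    ],
}

hooks = []
for hook_func, layers in supported_vit_cfg['toy_vit']:
    for layer in layers:
        module = rgetattr(model, layer)
        hooks.append(module.register_backward_hook(hook_func))

model(torch.randn(1, 8)).backward()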
