Merge pull request #18 from spencerwooo:vdc-attack
[New Attack] VDC Attack (Virtual Dense Connection)
spencerwooo authored Nov 20, 2024
2 parents 6aed62c + 0ddf5e8 · commit 1acabbd
Showing 8 changed files with 734 additions and 9 deletions.
5 changes: 2 additions & 3 deletions .gitignore
@@ -1,9 +1,8 @@
 # Dataset files
 datasets/
 
-# PDM files
-.pdm-python
-pdm.lock
+# Lockfiles
+uv.lock
 
 # Byte-compiled / optimized / DLL files
 __pycache__/
2 changes: 1 addition & 1 deletion pyproject.toml
@@ -12,7 +12,7 @@ dependencies = [
     "scipy>=1.10.1",
 ]
 dynamic = ["version"]
-optional-dependencies = { dev = ["mypy", "rich"] }
+optional-dependencies = { dev = ["mypy", "rich", "timm"] }
 
 [build-system]
 requires = ["setuptools"]
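With timm added to the dev extra, a local development setup can pull it in through the optional dependency group, e.g. with an editable install along the lines of pip install -e '.[dev]' (command is illustrative; the repository may prefer a different workflow such as uv, given the uv.lock entry above).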
2 changes: 2 additions & 0 deletions src/torchattack/__init__.py
@@ -15,6 +15,7 @@
 from torchattack.ssp import SSP
 from torchattack.tgr import TGR
 from torchattack.tifgsm import TIFGSM
+from torchattack.vdc import VDC
 from torchattack.vmifgsm import VMIFGSM
 from torchattack.vnifgsm import VNIFGSM
 
@@ -38,6 +39,7 @@
     'SSP',
     'TGR',
     'TIFGSM',
+    'VDC',
     'VMIFGSM',
     'VNIFGSM',
 ]
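Once merged, the new attack is importable from the package root alongside the existing attacks. The snippet below is a hypothetical usage sketch, not code from this PR: the AttackModel wrapper (from eval.py), its .model and .normalize attributes, and the VDC(model, normalize, device) constructor convention are assumptions carried over from the rest of the library, so the actual signature should be checked against src/torchattack/vdc.py.

import torch

from torchattack import VDC
from torchattack.eval import AttackModel  # wrapper name assumed from eval.py

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# VDC is a ViT-oriented transferable attack, so load a ViT surrogate via timm.
vit = AttackModel.from_pretrained('vit_base_patch16_224', device, from_timm=True)

attack = VDC(vit.model, vit.normalize, device)  # constructor assumed, see vdc.py
# x: a batch of images scaled to [0, 1]; y: ground-truth labels
# x_adv = attack(x, y)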
2 changes: 1 addition & 1 deletion src/torchattack/base.py
@@ -30,7 +30,7 @@ def __repr__(self) -> str:
         def repr_map(k, v):
             if isinstance(v, float):
                 return f'{k}={v:.3f}'
-            if k in ['model', 'normalize', 'feature_layer']:
+            if k in ['model', 'normalize', 'feature_layer', 'hooks']:
                 return f'{k}={v.__class__.__name__}'
             if isinstance(v, torch.Tensor):
                 return f'{k}={v.shape}'
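repr_map is the helper behind each attack's __repr__; adding 'hooks' to its list keeps the new collection of hook handles from being dumped verbatim and prints its class name instead. A standalone sketch of the same idea (illustrative only, not the library's exact code):

def repr_map(k, v):
    if isinstance(v, float):
        return f'{k}={v:.3f}'
    # Bulky attributes (the model, transforms, and now the hook handle list)
    # are collapsed to their class name.
    if k in ['model', 'normalize', 'feature_layer', 'hooks']:
        return f'{k}={v.__class__.__name__}'
    return f'{k}={v}'

print(repr_map('hooks', []))     # hooks=list
print(repr_map('eps', 8 / 255))  # eps=0.031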
5 changes: 4 additions & 1 deletion src/torchattack/eval.py
@@ -159,7 +159,10 @@ def from_pretrained(
             return cls(model_name, device, model, transform, normalize)
 
         except ValueError:
-            print('Model not found in torchvision.models, falling back to timm.')
+            print(
+                f'Warning: Model `{model_name}` not found in torchvision.models, '
+                'falling back to loading weights from timm.'
+            )
             return cls.from_pretrained(model_name, device, from_timm=True)
 
     def forward(self, x):
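The reworded message fires on the torchvision-to-timm fallback path: torchvision.models raises ValueError for an unknown model name, and from_pretrained then retries with from_timm=True. A hypothetical call that would take that path (wrapper name assumed, and timm — now in the dev extra — must be installed):

import torch

from torchattack.eval import AttackModel  # name assumed from this file

device = torch.device('cpu')

# Not a torchvision model, so the ValueError branch prints the new warning
# and falls back to loading the weights from timm.
model = AttackModel.from_pretrained('deit_base_distilled_patch16_224', device)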
8 changes: 7 additions & 1 deletion src/torchattack/pna_patchout.py
@@ -95,7 +95,9 @@ def __init__(
         self.targeted = targeted
         self.lossfn = nn.CrossEntropyLoss()
 
+        # Register hooks
         if self.pna_skip:
+            self.hooks: list[torch.utils.hooks.RemovableHandle] = []
             self._register_vit_model_hook()
 
         # Set default image size and number of patches for PatchOut
@@ -143,6 +145,9 @@ def forward(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
             delta.grad.detach_()
             delta.grad.zero_()
 
+        for hook in self.hooks:
+            hook.remove()
+
         return x + delta
 
     def _register_vit_model_hook(self):
@@ -162,7 +167,8 @@ def attn_drop_mask_grad(
         # Register backward hook for layers specified in _supported_vit_cfg
         for layer in self._supported_vit_cfg[self.model_name]:
             module = rgetattr(self.model, layer)
-            module.register_backward_hook(drop_hook_func)
+            hook = module.register_backward_hook(drop_hook_func)
+            self.hooks.append(hook)
 
     def _apply_patch_out(self, delta: torch.Tensor, seed: int) -> torch.Tensor:
         delta_mask = torch.zeros_like(delta)
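The recurring pattern in pna_patchout.py (and tgr.py below) is to keep the RemovableHandle that register_backward_hook returns and call .remove() once the attack finishes, so the surrogate model is left without dangling hooks. A minimal standalone illustration of that PyTorch idiom on a toy module (using the non-deprecated register_full_backward_hook; the library itself calls register_backward_hook):

import torch
import torch.nn as nn

model = nn.Sequential(nn.Linear(4, 4), nn.ReLU(), nn.Linear(4, 1))

def scale_grad(module, grad_input, grad_output):
    # A backward hook may return a modified grad_input tuple.
    return tuple(g * 0.5 if g is not None else None for g in grad_input)

hooks: list[torch.utils.hooks.RemovableHandle] = []
for layer in (model[0], model[2]):
    hooks.append(layer.register_full_backward_hook(scale_grad))

loss = model(torch.randn(2, 4)).sum()
loss.backward()  # hooks fire here and halve the propagated gradients

# Detach the hooks so later backward passes are unaffected.
for hook in hooks:
    hook.remove()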
10 changes: 8 additions & 2 deletions src/torchattack/tgr.py
@@ -61,6 +61,8 @@ def __init__(
         self.targeted = targeted
         self.lossfn = nn.CrossEntropyLoss()
 
+        # Register hooks
+        self.hooks: list[torch.utils.hooks.RemovableHandle] = []
         self._register_tgr_model_hooks()
 
     def forward(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
@@ -110,6 +112,9 @@ def forward(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
             delta.grad.detach_()
             delta.grad.zero_()
 
+        for hook in self.hooks:
+            hook.remove()
+
         return x + delta
 
     def _register_tgr_model_hooks(self):
@@ -348,10 +353,11 @@ def mlp_tgr(
 
         assert self.model_name in _supported_vit_cfg
 
-        for hook, layers in _supported_vit_cfg[self.model_name]:
+        for hook_func, layers in _supported_vit_cfg[self.model_name]:
             for layer in layers:
                 module = rgetattr(self.model, layer)
-                module.register_backward_hook(hook)
+                hook = module.register_backward_hook(hook_func)
+                self.hooks.append(hook)
 
 
 if __name__ == '__main__':
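tgr.py selects different hook functions for different layers from a per-architecture config, which is why the loop variable is renamed to hook_func: the name hook is then free for the handle returned by register_backward_hook. A small self-contained sketch of that config-driven registration (rgetattr is re-implemented here for illustration; the library uses its own helper, and the hook body is a no-op):

import functools

import torch.nn as nn
from torch.utils.hooks import RemovableHandle

def rgetattr(obj, path):
    # Resolve a dotted attribute path such as 'blocks.0.attn'.
    return functools.reduce(getattr, path.split('.'), obj)

def passthrough_hook(module, grad_input, grad_output):
    return grad_input

class ToyBlock(nn.Module):
    def __init__(self):
        super().__init__()
        self.attn = nn.Linear(4, 4)
        self.mlp = nn.Linear(4, 4)

model = ToyBlock()
cfg = [(passthrough_hook, ['attn', 'mlp'])]  # (hook function, layer paths)

hooks: list[RemovableHandle] = []
for hook_func, layers in cfg:
    for layer in layers:
        module = rgetattr(model, layer)
        hook = module.register_full_backward_hook(hook_func)
        hooks.append(hook)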
(The eighth changed file, presumably src/torchattack/vdc.py with the new VDC attack implementation, accounts for the bulk of the 734 added lines and is not rendered in this view.)
