Merge pull request #19 from spencerwooo:eval-module
Refactor for separate evaluation module
spencerwooo authored Nov 21, 2024
2 parents ce1016d + 2426a2e commit a784446
Showing 29 changed files with 97 additions and 88 deletions.
6 changes: 3 additions & 3 deletions README.md
@@ -35,8 +35,8 @@ python -m pip install git+https://gitee.com/spencerwoo/torchattack

```python
import torch
-from torchattack import FGSM, MIFGSM
-from torchattack.eval import AttackModel
+from torchattack import AttackModel, FGSM, MIFGSM

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

@@ -51,7 +51,7 @@ attack = FGSM(model, normalize, device)
attack = MIFGSM(model, normalize, device, eps=0.03, steps=10, decay=1.0)
```

-Check out [`torchattack.eval.run_attack`](src/torchattack/eval.py) for a simple example.
+Check out [`torchattack.eval.runner`](torchattack/eval/runner.py) for a quick example.

## Attacks

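To make the README snippet above concrete: a minimal sketch of applying the constructed attack to one batch. The input tensors are placeholders, and the assumption that attack instances are callable on `(inputs, labels)` follows the runner's use of `attacker(x, y)` further down this diff.

```python
import torch
from torchattack import AttackModel, MIFGSM

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = AttackModel.from_pretrained('resnet50', device=device)
attack = MIFGSM(model, model.normalize, device, eps=0.03, steps=10, decay=1.0)

# Placeholder batch: random RGB images in [0, 1] with random ImageNet labels
x = torch.rand(4, 3, 224, 224, device=device)
y = torch.randint(0, 1000, (4,), device=device)

x_adv = attack(x, y)  # adversarial examples with the same shape as x
```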
5 changes: 5 additions & 0 deletions pyproject.toml
@@ -38,6 +38,11 @@ quote-style = "single"
no_implicit_optional = true
check_untyped_defs = true
ignore_missing_imports = true # Used as torchvision does not ship type hints

+[tool.uv]
+dev-dependencies = [
+    "pytest>=8.3.3",
+]
# disallow_any_unimported = true
# disallow_untyped_defs = true
# warn_return_any = true
4 changes: 4 additions & 0 deletions torchattack/__init__.py
@@ -1,4 +1,5 @@
from torchattack.admix import Admix
+from torchattack.attack_model import AttackModel
from torchattack.decowa import DeCoWA
from torchattack.deepfool import DeepFool
from torchattack.difgsm import DIFGSM
@@ -22,6 +23,9 @@
__version__ = '1.0.0'

__all__ = [
+    # Optional but recommended model wrapper
+    'AttackModel',
+    # All supported attacks
    'Admix',
    'DeCoWA',
    'DeepFool',
File renamed without changes.
2 changes: 1 addition & 1 deletion torchattack/admix.py
@@ -3,8 +3,8 @@
import torch
import torch.nn as nn

+from torchattack._attack import Attack
from torchattack.attack_model import AttackModel
-from torchattack.base import Attack


class Admix(Attack):
18 changes: 8 additions & 10 deletions torchattack/attack_model.py
@@ -7,18 +7,16 @@
class AttackModel:
    """A wrapper class for a pretrained model used for adversarial attacks.

-    Intended to be instantiated with
-    `AttackModel.from_pretrained(pretrained_model_name)` from either
-    `torchvision.models` or `timm`. The model is loaded and attributes including
-    `transform`, `normalize`, and `model_name` are attached based on the model's
-    configuration.
+    Intended to be instantiated with `AttackModel.from_pretrained(<MODEL_NAME>)` from
+    either `torchvision.models` or `timm`. The model is loaded and attributes including
+    `model_name`, `transform`, and `normalize` are attached based on the model's config.

    Attributes:
-        model_name (str): The name of the model.
-        device (torch.device): The device on which the model is loaded.
-        model (nn.Module): The pretrained model itself.
-        transform (Callable): The transformation function applied to input images.
-        normalize (Callable): The normalization function applied to input images.
+        model_name: The name of the model.
+        device: The device on which the model is loaded.
+        model: The pretrained model itself.
+        transform: The transformation function applied to input images.
+        normalize: The normalization function applied to input images.

    Example:
        >>> model = AttackModel.from_pretrained('resnet50', device='cuda')
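A hedged usage sketch of the wrapper documented above; the image file is hypothetical, and passing a PIL image to `model.transform` is an assumption based on the attribute's description.

```python
import torch
from PIL import Image

from torchattack import AttackModel

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = AttackModel.from_pretrained('resnet50', device=device)

img = Image.open('example.jpg').convert('RGB')  # hypothetical input image
x = model.transform(img).unsqueeze(0).to(device)  # per-model resize/crop/to-tensor
logits = model.model(model.normalize(x))  # normalize is kept separate so attacks can work in [0, 1]
pred = logits.argmax(dim=1)
```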
2 changes: 1 addition & 1 deletion torchattack/decowa.py
@@ -3,8 +3,8 @@
import torch
import torch.nn as nn

+from torchattack._attack import Attack
from torchattack.attack_model import AttackModel
-from torchattack.base import Attack


class DeCoWA(Attack):
2 changes: 1 addition & 1 deletion torchattack/deepfool.py
@@ -3,8 +3,8 @@
import torch
import torch.nn as nn

+from torchattack._attack import Attack
from torchattack.attack_model import AttackModel
-from torchattack.base import Attack


class DeepFool(Attack):
2 changes: 1 addition & 1 deletion torchattack/difgsm.py
@@ -4,8 +4,8 @@
import torch.nn as nn
import torch.nn.functional as f

+from torchattack._attack import Attack
from torchattack.attack_model import AttackModel
-from torchattack.base import Attack


class DIFGSM(Attack):
5 changes: 5 additions & 0 deletions torchattack/eval/__init__.py
@@ -0,0 +1,5 @@
+from torchattack.eval.dataset import NIPSDataset, NIPSLoader
+from torchattack.eval.metric import FoolingRateMetric
+from torchattack.eval.runner import run_attack

+__all__ = ['run_attack', 'FoolingRateMetric', 'NIPSDataset', 'NIPSLoader']
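As a brief sketch of the new import surface; the dataset root and sample counts are assumptions, and the loader's keyword arguments mirror how the runner constructs it later in this diff.

```python
from torchattack import AttackModel
from torchattack.eval import NIPSLoader

model = AttackModel.from_pretrained('resnet50', device='cpu')
dataloader = NIPSLoader(
    root='datasets/nips2017',  # assumed local copy of the NIPS 2017 dataset
    batch_size=8,
    transform=model.transform,
    max_samples=16,
)
for x, y, _ in dataloader:
    ...  # x: image batch, y: labels; third item unused here, as in the runner
```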
File renamed without changes.
38 changes: 38 additions & 0 deletions torchattack/eval/metric.py
@@ -0,0 +1,38 @@
+import torch
+
+
+class FoolingRateMetric:
+    """Fooling rate metric tracker."""
+
+    def __init__(self) -> None:
+        self.total_count = torch.tensor(0)
+        self.clean_count = torch.tensor(0)
+        self.adv_count = torch.tensor(0)
+
+    def update(
+        self, labels: torch.Tensor, clean_logits: torch.Tensor, adv_logits: torch.Tensor
+    ) -> None:
+        """Update metric tracker during attack progress.
+
+        Args:
+            labels: Ground truth labels.
+            clean_logits: Prediction logits for clean samples.
+            adv_logits: Prediction logits for adversarial samples.
+        """
+
+        self.total_count += labels.numel()
+        self.clean_count += (clean_logits.argmax(dim=1) == labels).sum().item()
+        self.adv_count += (adv_logits.argmax(dim=1) == labels).sum().item()
+
+    def compute(self) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
+        """Compute the fooling rate and related metrics.
+
+        Returns:
+            A tuple of torch.Tensors containing the clean accuracy, adversarial
+            accuracy, and fooling rate computed, respectively.
+        """
+        return (
+            self.clean_count / self.total_count,
+            self.adv_count / self.total_count,
+            (self.clean_count - self.adv_count) / self.clean_count,
+        )
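A self-contained usage sketch for the new metric; the logits are random stand-ins for real model outputs.

```python
import torch

from torchattack.eval import FoolingRateMetric

frm = FoolingRateMetric()
labels = torch.tensor([3, 7, 1, 0])
clean_logits = torch.randn(4, 10)  # stand-in for model(normalize(x))
adv_logits = torch.randn(4, 10)  # stand-in for model(normalize(advs))
frm.update(labels, clean_logits, adv_logits)

# Clean accuracy, adversarial accuracy, and fooling rate, in that order;
# the fooling rate divides by the clean-correct count, so it is only
# meaningful once at least one clean prediction is correct.
clean_acc, adv_acc, fooling_rate = frm.compute()
```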
69 changes: 14 additions & 55 deletions torchattack/eval.py → torchattack/eval/runner.py
@@ -3,50 +3,15 @@
import torch

from torchattack.attack_model import AttackModel


-class FoolingRateMetric:
-    """Fooling rate metric tracker."""
-
-    def __init__(self) -> None:
-        self.total_count = torch.tensor(0)
-        self.clean_count = torch.tensor(0)
-        self.adv_count = torch.tensor(0)
-
-    def update(
-        self, labels: torch.Tensor, clean_logits: torch.Tensor, adv_logits: torch.Tensor
-    ) -> None:
-        """Update metric tracker during attack progress.
-
-        Args:
-            labels: Ground truth labels.
-            clean_logits: Prediction logits for clean samples.
-            adv_logits: Prediction logits for adversarial samples.
-        """
-
-        self.total_count += labels.numel()
-        self.clean_count += (clean_logits.argmax(dim=1) == labels).sum().item()
-        self.adv_count += (adv_logits.argmax(dim=1) == labels).sum().item()
-
-    def compute(self) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
-        """Compute the fooling rate and related metrics.
-
-        Returns:
-            A tuple of torch.Tensors containing the clean accuracy, adversarial
-            accuracy, and fooling rate computed, respectively.
-        """
-        return (
-            self.clean_count / self.total_count,
-            self.adv_count / self.total_count,
-            (self.clean_count - self.adv_count) / self.clean_count,
-        )
+from torchattack.eval.metric import FoolingRateMetric


def run_attack(
    attack: Any,
    attack_cfg: dict | None = None,
    model_name: str = 'resnet50',
    victim_model_names: list[str] | None = None,
+    dataset_root: str = 'datasets/nips2017',
    max_samples: int = 100,
    batch_size: int = 16,
    from_timm: bool = False,
@@ -63,16 +28,16 @@ def run_attack(
        attack_cfg: A dict of keyword arguments passed to the attack class.
        model_name: The surrogate model to attack. Defaults to "resnet50".
        victim_model_names: A list of the victim black-box models to attack. Defaults to None.
+        dataset_root: Root directory of the dataset. Defaults to "datasets/nips2017".
        max_samples: Max number of samples to attack. Defaults to 100.
        batch_size: Batch size for the dataloader. Defaults to 16.
        from_timm: Use timm to load the model. Defaults to True.
    """

-    import torchvision as tv
    from rich import print
    from rich.progress import track

-    from torchattack.dataset import NIPSLoader
+    from torchattack.eval.dataset import NIPSLoader

    if attack_cfg is None:
        attack_cfg = {}
@@ -84,7 +49,7 @@

    # Set up dataloader
    dataloader = NIPSLoader(
-        root='datasets/nips2017',
+        root=dataset_root,
        batch_size=batch_size,
        transform=transform,
        max_samples=max_samples,
@@ -93,15 +58,7 @@

    # Set up attack and trackers
    frm = FoolingRateMetric()
-    attacker = attack(
-        # Pass the original PyTorch model instead of the wrapped one if the attack
-        # requires access to the model's intermediate layers or other attributes that
-        # are not exposed by the AttackModel wrapper.
-        model=model,
-        normalize=normalize,
-        device=device,
-        **attack_cfg,
-    )
+    attacker = attack(model=model, normalize=normalize, device=device, **attack_cfg)
    print(attacker)

    # Setup victim models if provided
@@ -113,7 +70,7 @@
        victim_frms = [FoolingRateMetric() for _ in victim_model_names]

    # Run attack over the dataset (100 images by default)
-    for i, (x, y, _) in enumerate(dataloader):
+    for _i, (x, y, _) in enumerate(dataloader):
        x, y = x.to(device), y.to(device)

        # Adversarial images are created here
@@ -124,11 +81,13 @@
        adv_outs = model(normalize(advs))
        frm.update(y, cln_outs, adv_outs)

-        # Save first batch of adversarial examples
-        if i == 0:
-            saved_imgs = advs.detach().cpu().mul(255).to(torch.uint8)
-            img_grid = tv.utils.make_grid(saved_imgs, nrow=4)
-            tv.io.write_png(img_grid, 'adv_batch_0.png')
+        # *Save first batch of adversarial examples
+        # if _i == 0:
+        #     import torchvision as tv
+
+        #     saved_imgs = advs.detach().cpu().mul(255).to(torch.uint8)
+        #     img_grid = tv.utils.make_grid(saved_imgs, nrow=4)
+        #     tv.io.write_png(img_grid, 'adv_batch_0.png')

        # Track transfer fooling rates if victim models are provided
        if victim_model_names:
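With the refactor in place, a call to the relocated runner might look like the following; the victim model names are illustrative, and the `eps`/`steps` values mirror the MIFGSM example in the README.

```python
from torchattack import MIFGSM
from torchattack.eval import run_attack

run_attack(
    attack=MIFGSM,
    attack_cfg={'eps': 0.03, 'steps': 10},
    model_name='resnet50',
    victim_model_names=['vgg16', 'densenet121'],  # illustrative victim models
    dataset_root='datasets/nips2017',  # now configurable rather than hard-coded
    max_samples=100,
    batch_size=16,
)
```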
2 changes: 1 addition & 1 deletion torchattack/fgsm.py
@@ -3,8 +3,8 @@
import torch
import torch.nn as nn

+from torchattack._attack import Attack
from torchattack.attack_model import AttackModel
-from torchattack.base import Attack


class FGSM(Attack):
2 changes: 1 addition & 1 deletion torchattack/fia.py
@@ -3,8 +3,8 @@
import torch
import torch.nn as nn

+from torchattack._attack import Attack
from torchattack.attack_model import AttackModel
-from torchattack.base import Attack


class FIA(Attack):
2 changes: 1 addition & 1 deletion torchattack/geoda.py
@@ -7,8 +7,8 @@
import torch
import torch.nn as nn

+from torchattack._attack import Attack
from torchattack.attack_model import AttackModel
-from torchattack.base import Attack

with suppress(ImportError):
    from rich import print
2 changes: 1 addition & 1 deletion torchattack/mifgsm.py
@@ -3,8 +3,8 @@
import torch
import torch.nn as nn

+from torchattack._attack import Attack
from torchattack.attack_model import AttackModel
-from torchattack.base import Attack


class MIFGSM(Attack):
2 changes: 1 addition & 1 deletion torchattack/nifgsm.py
@@ -3,8 +3,8 @@
import torch
import torch.nn as nn

+from torchattack._attack import Attack
from torchattack.attack_model import AttackModel
-from torchattack.base import Attack


class NIFGSM(Attack):
2 changes: 1 addition & 1 deletion torchattack/pgd.py
@@ -3,8 +3,8 @@
import torch
import torch.nn as nn

+from torchattack._attack import Attack
from torchattack.attack_model import AttackModel
-from torchattack.base import Attack


class PGD(Attack):
2 changes: 1 addition & 1 deletion torchattack/pgdl2.py
@@ -3,8 +3,8 @@
import torch
import torch.nn as nn

+from torchattack._attack import Attack
from torchattack.attack_model import AttackModel
-from torchattack.base import Attack

EPS_FOR_DIVISION = 1e-12

2 changes: 1 addition & 1 deletion torchattack/pna_patchout.py
@@ -4,9 +4,9 @@
import torch
import torch.nn as nn

+from torchattack._attack import Attack
from torchattack._rgetattr import rgetattr
from torchattack.attack_model import AttackModel
-from torchattack.base import Attack


class PNAPatchOut(Attack):
2 changes: 1 addition & 1 deletion torchattack/sinifgsm.py
@@ -3,8 +3,8 @@
import torch
import torch.nn as nn

+from torchattack._attack import Attack
from torchattack.attack_model import AttackModel
-from torchattack.base import Attack


class SINIFGSM(Attack):
2 changes: 1 addition & 1 deletion torchattack/ssa.py
@@ -4,8 +4,8 @@
import torch
import torch.nn as nn

+from torchattack._attack import Attack
from torchattack.attack_model import AttackModel
-from torchattack.base import Attack


class SSA(Attack):
2 changes: 1 addition & 1 deletion torchattack/ssp.py
@@ -4,8 +4,8 @@
import torch.nn as nn
import torchvision as tv

+from torchattack._attack import Attack
from torchattack.attack_model import AttackModel
-from torchattack.base import Attack


class PerceptualCriteria(nn.Module):
(Remaining changed files not shown.)