Skip to content

Commit

Permalink
Simplify torchattack model creation
Browse files Browse the repository at this point in the history
  • Loading branch information
spencerwooo committed Sep 14, 2024
1 parent 409a44a commit a0703f3
Showing 1 changed file with 44 additions and 40 deletions.
84 changes: 44 additions & 40 deletions src/torchattack/eval.py
Original file line number Diff line number Diff line change
Expand Up @@ -110,39 +110,11 @@ def from_pretrained(

import torchvision.transforms as t

if not from_timm:
try:
import torchvision.models as tv_models

model = tv_models.get_model(name=model_name, weights='DEFAULT')

# Resolve transforms from vision model weights
weight_id = str(tv_models.get_model_weights(name=model_name)['DEFAULT'])
cfg = tv_models.get_weight(weight_id).transforms()

# torchvision/transforms/_presets.py::ImageClassification
# Manually construct separated transform and normalize
transform = t.Compose(
[
t.Resize(
cfg.resize_size,
interpolation=cfg.interpolation,
antialias=cfg.antialias,
),
t.CenterCrop(cfg.crop_size),
t.ToTensor(),
]
)
normalize = t.Normalize(mean=cfg.mean, std=cfg.std)

except ValueError:
print('Model not found in torchvision.models, falling back to timm.')
from_timm = True

else:
if from_timm:
import timm

model = timm.create_model(model_name, pretrained=True)
model = model.to(device).eval()
cfg = timm.data.resolve_data_config(model.pretrained_cfg)

# Construct normalization
Expand All @@ -155,8 +127,40 @@ def from_pretrained(
tr for tr in transform.transforms if not isinstance(tr, t.Normalize)
]

model = model.to(device).eval()
return cls(model_name, device, model, transform, normalize)
return cls(model_name, device, model, transform, normalize)

# If the model is not specified to be load from timm, try loading from
# `torchvision.models` first, then fall back to timm if the model is not found.
try:
import torchvision.models as tv_models

model = tv_models.get_model(name=model_name, weights='DEFAULT')
model = model.to(device).eval()

# Resolve transforms from vision model weights
weight_id = str(tv_models.get_model_weights(name=model_name)['DEFAULT'])
cfg = tv_models.get_weight(weight_id).transforms()

# torchvision/transforms/_presets.py::ImageClassification
# Manually construct separated transform and normalize
transform = t.Compose(
[
t.Resize(
cfg.resize_size,
interpolation=cfg.interpolation,
antialias=cfg.antialias,
),
t.CenterCrop(cfg.crop_size),
t.ToTensor(),
]
)
normalize = t.Normalize(mean=cfg.mean, std=cfg.std)

return cls(model_name, device, model, transform, normalize)

except ValueError:
print('Model not found in torchvision.models, falling back to timm.')
return cls.from_pretrained(model_name, device, from_timm=True)

def forward(self, x):
return self.model(x)
Expand All @@ -166,8 +170,11 @@ def __call__(self, x):

def __repr__(self):
return (
f'{self.__class__.__name__}(model_name={self.model_name}, device={self.device}, '
f'transform={self.transform}, normalize={self.normalize})'
f'{self.__class__.__name__}('
f'model_name={self.model_name}, '
f'device={self.device}, '
f'transform={self.transform}, '
f'normalize={self.normalize})'
)


Expand Down Expand Up @@ -268,15 +275,12 @@ def run_attack(

# Print results
cln_acc, adv_acc, fr = frm.compute()
print(
f'Surrogate ({model_name}): {cln_acc:.2%} / {adv_acc:.2%} '
f'(Fooling rate: {fr:.2%})'
)
print(f'Surrogate ({model_name}): {cln_acc=:.2%}, {adv_acc=:.2%} ({fr=:.2%})')

if victim_model_names:
for vmodel, vfrm in zip(victim_models, victim_frms):
vcln_acc, vadv_acc, vfr = vfrm.compute()
print(
f'Victim ({vmodel.model_name}): {vcln_acc:.2%} / {vadv_acc:.2%} '
f'(Fooling rate: {vfr:.2%})'
f'Victim ({vmodel.model_name}): cln_acc={vcln_acc:.2%}, '
f'adv_acc={vadv_acc:.2%} (fr={vfr:.2%})'
)

0 comments on commit a0703f3

Please sign in to comment.