Skip to content

Commit

Permalink
Refactor model initialization and transform handling in AttackModel
Browse files Browse the repository at this point in the history
  • Loading branch information
spencerwooo committed Feb 11, 2025
1 parent e46e9a4 commit 7f01177
Showing 1 changed file with 46 additions and 18 deletions.
64 changes: 46 additions & 18 deletions torchattack/attack_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -94,8 +94,7 @@ def from_pretrained(
if from_timm:
import timm

model = timm.create_model(model_name, pretrained=True)
model = model.to(device).eval()
model = timm.create_model(model_name, pretrained=True).eval().to(device)
cfg = timm.data.resolve_data_config(model.pretrained_cfg)

# Construct normalization
Expand All @@ -116,28 +115,57 @@ def from_pretrained(
import torchvision.models as tv_models
import torchvision.transforms.functional as f

model = tv_models.get_model(name=model_name, weights='DEFAULT')
model = model.to(device).eval()
model = tv_models.get_model(model_name, weights='DEFAULT').eval().to(device)

# Resolve transforms from vision model weights
weight_id = str(tv_models.get_model_weights(name=model_name)['DEFAULT'])
cfg = tv_models.get_weight(weight_id).transforms()

# torchvision/transforms/_presets.py::ImageClassification
# Manually construct separated transform and normalize
def transform(x: Image.Image | torch.Tensor) -> torch.Tensor:
x = f.resize(
x,
cfg.resize_size,
interpolation=cfg.interpolation,
antialias=cfg.antialias,
)
x = f.center_crop(x, cfg.crop_size)
if not isinstance(x, torch.Tensor):
x = f.pil_to_tensor(x)
x = f.convert_image_dtype(x, torch.float)
return x # type: ignore[return-value]
# Source: torchvision/transforms/_presets.py::ImageClassification
# We do not import directly from torchvision.transforms._presets as it is
# declared private API and subject to change without warning.
class TvTransform(nn.Module):
def __init__( # type: ignore[no-any-unimported]
self,
crop_size: list[int],
resize_size: list[int],
interpolation: f.InterpolationMode = f.InterpolationMode.BILINEAR,
antialias: bool | None = True,
) -> None:
super().__init__()
self.crop_size = crop_size
self.resize_size = resize_size
self.interpolation = interpolation
self.antialias = antialias

def forward(self, x: Image.Image | torch.Tensor) -> torch.Tensor:
x = f.resize(
x,
self.resize_size,
interpolation=self.interpolation,
antialias=self.antialias,
)
x = f.center_crop(x, self.crop_size)
if not isinstance(x, torch.Tensor):
x = f.pil_to_tensor(x)
x = f.convert_image_dtype(x, torch.float)
return x # type: ignore[return-value]

def __repr__(self) -> str:
return (
f'{self.__class__.__name__}('
f'crop_size={self.crop_size}, '
f'resize_size={self.resize_size}, '
f'interpolation={self.interpolation})'
)

# Manually construct separated transform and normalize
transform = TvTransform(
crop_size=cfg.crop_size,
resize_size=cfg.resize_size,
interpolation=cfg.interpolation,
antialias=cfg.antialias,
)
normalize = t.Normalize(mean=cfg.mean, std=cfg.std)

return cls(model_name, device, model, transform, normalize)
Expand Down

0 comments on commit 7f01177

Please sign in to comment.