From 7f01177ee6d1902dbdc191495df4f69dcb7d0fdf Mon Sep 17 00:00:00 2001
From: spencerwooo
Date: Tue, 11 Feb 2025 15:19:13 +0800
Subject: [PATCH] Refactor model initialization and transform handling in AttackModel

---
Review notes (this section is ignored by `git am`):
- The recorded patch had its line breaks and indentation collapsed. The layout
  below is re-derived from the hunk line counts and the Python block
  structure; verify the context lines against the target file before applying.
- TvTransform.__repr__ now also reports `antialias`, and TvTransform gained
  docstrings. Hunk line counts and the diffstat are updated to match; the
  commit id and the `index` blob ids still refer to the original commit.

 torchattack/attack_model.py | 74 ++++++++++++++++++++++++++++---------
 1 file changed, 56 insertions(+), 18 deletions(-)

diff --git a/torchattack/attack_model.py b/torchattack/attack_model.py
index 3ad599e..f361e0b 100644
--- a/torchattack/attack_model.py
+++ b/torchattack/attack_model.py
@@ -94,8 +94,7 @@ def from_pretrained(
         if from_timm:
             import timm
 
-            model = timm.create_model(model_name, pretrained=True)
-            model = model.to(device).eval()
+            model = timm.create_model(model_name, pretrained=True).eval().to(device)
             cfg = timm.data.resolve_data_config(model.pretrained_cfg)
 
             # Construct normalization
@@ -116,28 +115,67 @@ def from_pretrained(
             import torchvision.models as tv_models
             import torchvision.transforms.functional as f
 
-            model = tv_models.get_model(name=model_name, weights='DEFAULT')
-            model = model.to(device).eval()
+            model = tv_models.get_model(model_name, weights='DEFAULT').eval().to(device)
 
             # Resolve transforms from vision model weights
             weight_id = str(tv_models.get_model_weights(name=model_name)['DEFAULT'])
             cfg = tv_models.get_weight(weight_id).transforms()
 
-            # torchvision/transforms/_presets.py::ImageClassification
-            # Manually construct separated transform and normalize
-            def transform(x: Image.Image | torch.Tensor) -> torch.Tensor:
-                x = f.resize(
-                    x,
-                    cfg.resize_size,
-                    interpolation=cfg.interpolation,
-                    antialias=cfg.antialias,
-                )
-                x = f.center_crop(x, cfg.crop_size)
-                if not isinstance(x, torch.Tensor):
-                    x = f.pil_to_tensor(x)
-                x = f.convert_image_dtype(x, torch.float)
-                return x  # type: ignore[return-value]
+            # Source: torchvision/transforms/_presets.py::ImageClassification
+            # We do not import directly from torchvision.transforms._presets as it is
+            # declared private API and subject to change without warning.
+            class TvTransform(nn.Module):
+                """Resize, center-crop, and convert an image to a float tensor.
+
+                Normalization is not part of this module; it is constructed
+                separately as ``normalize`` right after this class.
+                """
+
+                def __init__(  # type: ignore[no-any-unimported]
+                    self,
+                    crop_size: list[int],
+                    resize_size: list[int],
+                    interpolation: f.InterpolationMode = f.InterpolationMode.BILINEAR,
+                    antialias: bool | None = True,
+                ) -> None:
+                    super().__init__()
+                    self.crop_size = crop_size
+                    self.resize_size = resize_size
+                    self.interpolation = interpolation
+                    self.antialias = antialias
+
+                def forward(self, x: Image.Image | torch.Tensor) -> torch.Tensor:
+                    """Resize and center-crop ``x``, then return it as a float tensor."""
+                    x = f.resize(
+                        x,
+                        self.resize_size,
+                        interpolation=self.interpolation,
+                        antialias=self.antialias,
+                    )
+                    x = f.center_crop(x, self.crop_size)
+                    if not isinstance(x, torch.Tensor):
+                        x = f.pil_to_tensor(x)
+                    x = f.convert_image_dtype(x, torch.float)
+                    return x  # type: ignore[return-value]
+
+                def __repr__(self) -> str:
+                    # Report every constructor argument so that two transforms
+                    # differing only in `antialias` remain distinguishable.
+                    return (
+                        f'{self.__class__.__name__}('
+                        f'crop_size={self.crop_size}, '
+                        f'resize_size={self.resize_size}, '
+                        f'interpolation={self.interpolation}, '
+                        f'antialias={self.antialias})'
+                    )
 
+            # Manually construct separated transform and normalize
+            transform = TvTransform(
+                crop_size=cfg.crop_size,
+                resize_size=cfg.resize_size,
+                interpolation=cfg.interpolation,
+                antialias=cfg.antialias,
+            )
             normalize = t.Normalize(mean=cfg.mean, std=cfg.std)
 
         return cls(model_name, device, model, transform, normalize)