diff --git a/src/torchattack/dataset.py b/src/torchattack/dataset.py index 2f21fae..772c5b2 100644 --- a/src/torchattack/dataset.py +++ b/src/torchattack/dataset.py @@ -107,7 +107,7 @@ def __init__( batch_size: int = 1, shuffle: bool = False, num_workers: int = 4, - normalize: Callable[[torch.Tensor], torch.Tensor] | None = None, + transform: Callable[[torch.Tensor], torch.Tensor] | None = None, max_samples: int | None = None, ): # Specifing a custom image root directory is useful when evaluating @@ -117,7 +117,7 @@ def __init__( dataset=NIPSDataset( image_root=image_root if image_root else f'{path}/images', pairs_path=pairs_path if pairs_path else f'{path}/images.csv', - transform=normalize, + transform=transform, max_samples=max_samples, ), batch_size=batch_size, diff --git a/src/torchattack/utils.py b/src/torchattack/utils.py index 823d08e..8ad2ebe 100644 --- a/src/torchattack/utils.py +++ b/src/torchattack/utils.py @@ -32,7 +32,13 @@ def run_attack(attack, attack_cfg, model='resnet50', samples=100, batch_size=8) dataloader = NIPSLoader( path='data/nips2017', batch_size=batch_size, - transform=tv.transforms.Resize(size=224, antialias=True), + transform=tv.transforms.Compose( + [ + tv.transforms.Resize([232]), + tv.transforms.CenterCrop([224]), + tv.transforms.ToTensor(), + ] + ), max_samples=samples, ) @@ -42,7 +48,7 @@ def run_attack(attack, attack_cfg, model='resnet50', samples=100, batch_size=8) std=[0.229, 0.224, 0.225], ) total, acc_clean, acc_adv = len(dataloader.dataset), 0, 0 # type: ignore - attacker = attack(model=model, transform=normalize, device=device, **attack_cfg) + attacker = attack(model=model, normalize=normalize, device=device, **attack_cfg) print(attacker) # Wrap dataloader with rich.progress.track if available