From e005064f461c0f8f8bfc67ea6be05ad1dd3076f8 Mon Sep 17 00:00:00 2001 From: spencerwooo Date: Mon, 26 Feb 2024 11:52:21 +0800 Subject: [PATCH] refactor: dataset paths and transform --- .gitignore | 2 +- src/torchattack/dataset.py | 6 +++--- src/torchattack/utils.py | 25 ++++++++++++------------- 3 files changed, 16 insertions(+), 17 deletions(-) diff --git a/.gitignore b/.gitignore index 22b8fa4..3c2158a 100644 --- a/.gitignore +++ b/.gitignore @@ -1,5 +1,5 @@ # Dataset files -data/ +datasets/ # PDM files .pdm-python diff --git a/src/torchattack/dataset.py b/src/torchattack/dataset.py index 772c5b2..3f83f53 100644 --- a/src/torchattack/dataset.py +++ b/src/torchattack/dataset.py @@ -101,7 +101,7 @@ class NIPSLoader(DataLoader): def __init__( self, - path: str | None, + root: str | None, image_root: str | None = None, pairs_path: str | None = None, batch_size: int = 1, @@ -115,8 +115,8 @@ def __init__( super().__init__( dataset=NIPSDataset( - image_root=image_root if image_root else f'{path}/images', - pairs_path=pairs_path if pairs_path else f'{path}/images.csv', + image_root=image_root if image_root else f'{root}/images', + pairs_path=pairs_path if pairs_path else f'{root}/images.csv', transform=transform, max_samples=max_samples, ), diff --git a/src/torchattack/utils.py b/src/torchattack/utils.py index 8ad2ebe..20959be 100644 --- a/src/torchattack/utils.py +++ b/src/torchattack/utils.py @@ -26,27 +26,26 @@ def run_attack(attack, attack_cfg, model='resnet50', samples=100, batch_size=8) from rich import print from rich.progress import track - # Set up model and dataloader + # Set up model, transform, and normalize device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') model = tv.models.get_model(name=model, weights='DEFAULT').to(device).eval() + transform = tv.transforms.Compose( + [ + tv.transforms.Resize([224]), + tv.transforms.ToTensor(), + ] + ) + normalize = tv.transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]) + + # Set up dataloader dataloader = NIPSLoader( - path='data/nips2017', + root='datasets/nips2017', batch_size=batch_size, - transform=tv.transforms.Compose( - [ - tv.transforms.Resize([232]), - tv.transforms.CenterCrop([224]), - tv.transforms.ToTensor(), - ] - ), + transform=transform, max_samples=samples, ) # Set up attack and trackers - normalize = tv.transforms.Normalize( - mean=[0.485, 0.456, 0.406], - std=[0.229, 0.224, 0.225], - ) total, acc_clean, acc_adv = len(dataloader.dataset), 0, 0 # type: ignore attacker = attack(model=model, normalize=normalize, device=device, **attack_cfg) print(attacker)