From f31ce2f7c4689c60a95eda5d0bb7442a50b37c46 Mon Sep 17 00:00:00 2001 From: spencerwooo Date: Wed, 8 May 2024 21:44:14 +0800 Subject: [PATCH] Make mypy happy --- src/torchattack/admix.py | 8 ++++---- src/torchattack/dataset.py | 16 ++++++++-------- src/torchattack/runner.py | 10 +++++----- 3 files changed, 17 insertions(+), 17 deletions(-) diff --git a/src/torchattack/admix.py b/src/torchattack/admix.py index 15adc48..860653e 100644 --- a/src/torchattack/admix.py +++ b/src/torchattack/admix.py @@ -103,13 +103,13 @@ def forward(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor: grad = torch.autograd.grad(loss, x_admixs)[0] # Split gradients and compute mean - grads = torch.tensor_split(grad, 5, dim=0) - grads = [g * s for g, s in zip(grads, scales, strict=True)] + split_grads = torch.tensor_split(grad, 5, dim=0) + grads = [g * s for g, s in zip(split_grads, scales, strict=True)] grad = torch.mean(torch.stack(grads), dim=0) # Gather gradients - grads = torch.tensor_split(grad, self.size) - grad = torch.sum(torch.stack(grads), dim=0) + split_grads = torch.tensor_split(grad, self.size) + grad = torch.sum(torch.stack(split_grads), dim=0) # Apply momentum term g = self.decay * g + grad / torch.mean( diff --git a/src/torchattack/dataset.py b/src/torchattack/dataset.py index 3f83f53..1d5dcb8 100644 --- a/src/torchattack/dataset.py +++ b/src/torchattack/dataset.py @@ -16,7 +16,7 @@ def __init__( self, image_root: str, pairs_path: str, - transform: Callable[[torch.Tensor], torch.Tensor] | None = None, + transform: Callable[[torch.Tensor | Image.Image], torch.Tensor] | None = None, max_samples: int | None = None, ) -> None: """Initialize the NIPS 2017 Adversarial Learning Challenge dataset. @@ -73,10 +73,10 @@ def __getitem__(self, index: int) -> tuple[Any, int, str]: name = self.names[index] label = int(self.labels[index]) - 1 - image = Image.open(f'{self.image_root}/{name}.png').convert('RGB') - # image = np.array(image, dtype=np.uint8) - # image = torch.from_numpy(image).permute((2, 0, 1)).contiguous().float().div(255) - image = self.transform(image) if self.transform else image + pil_image = Image.open(f'{self.image_root}/{name}.png').convert('RGB') + # np_image = np.array(pil_image, dtype=np.uint8) + # image = torch.from_numpy(np_image).permute((2, 0, 1)).contiguous().float().div(255) + image = self.transform(pil_image) if self.transform else pil_image return image, label, name @@ -89,9 +89,9 @@ class NIPSLoader(DataLoader): >>> from torchvision.transforms import transforms >>> from torchattack.dataset import NIPSLoader - >>> transform = transforms.Resize([224]) + >>> transform = transforms.Compose([transforms.Resize([224]), transforms.ToTensor()]) >>> dataloader = NIPSLoader( - >>> path="data/nips2017", batch_size=16, transform=transform + >>> path="data/nips2017", batch_size=16, transform=transform, max_samples=100 >>> ) You can specify a custom image root directory and CSV file location by @@ -107,7 +107,7 @@ def __init__( batch_size: int = 1, shuffle: bool = False, num_workers: int = 4, - transform: Callable[[torch.Tensor], torch.Tensor] | None = None, + transform: Callable[[torch.Tensor | Image.Image], torch.Tensor] | None = None, max_samples: int | None = None, ): # Specifing a custom image root directory is useful when evaluating diff --git a/src/torchattack/runner.py b/src/torchattack/runner.py index 2d9cf1c..134359b 100644 --- a/src/torchattack/runner.py +++ b/src/torchattack/runner.py @@ -1,5 +1,5 @@ from contextlib import suppress -from typing import Any, Optional +from typing import Any import torch import torchvision as tv @@ -11,9 +11,9 @@ class FoolingRateMetric: """Fooling rate metric tracker.""" def __init__(self) -> None: - self.total_count = 0 - self.clean_count = 0 - self.adv_count = 0 + self.total_count = torch.tensor(0) + self.clean_count = torch.tensor(0) + self.adv_count = torch.tensor(0) def update( self, labels: torch.Tensor, clean_logits: torch.Tensor, adv_logits: torch.Tensor @@ -42,7 +42,7 @@ def compute_clean_accuracy(self) -> torch.Tensor: def run_attack( attack: Any, - attack_cfg: Optional[dict] = None, + attack_cfg: dict | None = None, model_name: str = 'resnet50', max_samples: int = 100, batch_size: int = 8,