Skip to content

Commit

Permalink
Remove redundant randrot method
Browse files Browse the repository at this point in the history
  • Loading branch information
spencerwooo committed Feb 1, 2025
1 parent 79ee7fe commit fefe882
Show file tree
Hide file tree
Showing 2 changed files with 9 additions and 9 deletions.
1 change: 1 addition & 0 deletions torchattack/_attack.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,6 +71,7 @@ def __eq__(self, other: Any) -> bool:
'perceptual_criteria', # SSP
'sub_basis', # GeoDA
'generator', # BIA, CDA
'randrot', # BSR
]
for attr in eq_name_attrs:
if not (hasattr(self, attr) and hasattr(other, attr)):
Expand Down
17 changes: 8 additions & 9 deletions torchattack/bsr.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,15 +52,19 @@ def __init__(
self.steps = steps
self.alpha = alpha
self.decay = decay

self.num_scale = num_scale
self.num_block = num_block

# Declare random rotation transform
self.randrot = t.RandomRotation(
degrees=(-24, 24), interpolation=t.InterpolationMode.BILINEAR
)

self.clip_min = clip_min
self.clip_max = clip_max
self.targeted = targeted
self.lossfn = nn.CrossEntropyLoss()
self.randrot = t.RandomRotation(
degrees=(-24, 24), interpolation=t.InterpolationMode.BILINEAR
)

def forward(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
"""Perform BSR on a batch of images.
Expand Down Expand Up @@ -149,11 +153,6 @@ def _shuffle_single_dim(self, x: torch.Tensor, dim: int) -> list[torch.Tensor]:
random.shuffle(x_strips)
return x_strips

def _image_random_rot(self, x: torch.Tensor) -> torch.Tensor:
"""Apply a random rotation to the input tensor."""

return self.randrot(x)

def _bsr_shuffle(self, x: torch.Tensor) -> torch.Tensor:
"""Apply the BSR (Block Shuffle and Rotate) transformation to the input tensor.
Expand All @@ -177,7 +176,7 @@ def _bsr_shuffle(self, x: torch.Tensor) -> torch.Tensor:
# For each strip, apply random rotation and then shuffle along the second dim
rotated_strips = []
for x_strip in x_strips:
rotated = self._image_random_rot(x_strip)
rotated = self.randrot(x_strip)
shuffled = self._shuffle_single_dim(rotated, dim=d2)
rotated_strips.append(torch.cat(shuffled, dim=d2))

Expand Down

0 comments on commit fefe882

Please sign in to comment.