Skip to content

Commit

Permalink
Complete implementation of DeCoWA
Browse files Browse the repository at this point in the history
  • Loading branch information
spencerwooo committed Sep 18, 2024
1 parent b12f97d commit 7694144
Showing 1 changed file with 51 additions and 26 deletions.
77 changes: 51 additions & 26 deletions src/torchattack/decowa.py
Original file line number Diff line number Diff line change
Expand Up @@ -89,7 +89,28 @@ def forward(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
# Perform warping
for _ in range(self.num_warping):
# Apply warping to perturbation
noise_map_hat = self._update_noise_map(x + delta, y)
noise_map = torch.rand([self.mesh_height - 2, self.mesh_width - 2, 2])
noise_map_hat = (noise_map - 0.5) * self.noise_scale

# Iterate for a round only
for _ in range(1):
noise_map_hat.requires_grad_()

vwt_x = self._vwt(x + delta, noise_map_hat)
outs = self.model(self.normalize(vwt_x))
loss = self.lossfn(outs, y)

if self.targeted:
loss = -loss

loss.backward()

if delta.grad is None:
continue

noise_map_hat.detach_()
noise_map_hat -= self.rho * noise_map_hat.grad

vwt_x = self._vwt(x + delta, noise_map_hat)

# Compute loss
Expand Down Expand Up @@ -127,17 +148,21 @@ def forward(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:

return x + delta

def _update_noise_map(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
with torch.no_grad():
noise_map = torch.rand([self.mesh_height - 2, self.mesh_width - 2, 2]) - 0.5
noise_map = noise_map * self.noise_scale

noise_map.requires_grad_()
vwt_x = self._vwt(x, noise_map)

def _vwt(self, x: torch.Tensor, noise_map_hat: torch.Tensor) -> torch.Tensor:
def _vwt(self, x: torch.Tensor, noise_map: torch.Tensor) -> torch.Tensor:
n, c, w, h = x.size()
grid_x = self._grid_points_2d(w, h, x.device)
xx = self._grid_points_2d(self.mesh_width, self.mesh_height, self.device)
yy = self._noisy_grid(self.mesh_width, self.mesh_height, noise_map, self.device)
tpbs = TPS(size=(h, w), device=self.device)
warped_grid_b = tpbs(xx[None, ...], yy[None, ...])
warped_grid_b = warped_grid_b.repeat(n, 1, 1, 1)
vwt_x = torch.grid_sampler_2d(
input=x,
grid=warped_grid_b,
interpolation_mode=0,
padding_mode=0,
align_corners=False,
)
return vwt_x

def _grid_points_2d(
self, width: int, height: int, device: torch.device
Expand All @@ -152,7 +177,7 @@ def _noisy_grid(
self, width: int, height: int, noise_map: torch.Tensor, device: torch.device
) -> torch.Tensor:
grid = self._grid_points_2d(width, height, device)
mod = torch.zeros((width * height, 2), device=device)
mod = torch.zeros([height, width, 2], device=device)
mod[1 : height - 1, 1 : width - 1, :] = noise_map
return grid + mod.reshape(-1, 2)

Expand All @@ -178,20 +203,20 @@ def __init__(self):
def forward(
self, x: torch.Tensor, y: torch.Tensor
) -> tuple[torch.Tensor, torch.Tensor]:
n, k_mat = x.shape[:2]
z_mat = torch.zeros(1, k_mat + 3, 2, device=x.device)
p_mat = torch.ones(n, k_mat, 3, device=x.device)
l_mat = torch.zeros(n, k_mat + 3, k_mat + 3, device=x.device)
n, k = x.shape[:2]
z_mat = torch.zeros(1, k + 3, 2, device=x.device)
p_mat = torch.ones(n, k, 3, device=x.device)
l_mat = torch.zeros(n, k + 3, k + 3, device=x.device)
k_mat = k_matrix(x, x)

p_mat[:, :, 1:] = x
z_mat[:, :k_mat, :] = y
l_mat[:, :k_mat, :k_mat] = k_mat
l_mat[:, :k_mat, k_mat:] = p_mat
l_mat[:, k_mat:, :k_mat] = p_mat.permute(0, 2, 1)
z_mat[:, :k, :] = y
l_mat[:, :k, :k] = k_mat
l_mat[:, :k, k:] = p_mat
l_mat[:, k:, :k] = p_mat.permute(0, 2, 1)

q_mat = torch.linalg.solve(l_mat, z_mat)
return q_mat[:, :k_mat], q_mat[:, k_mat:]
return q_mat[:, :k], q_mat[:, k:]


class TPS(nn.Module):
Expand All @@ -210,10 +235,10 @@ def __init__(self, size: tuple = (256, 256), device: torch.device | None = None)

def forward(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
h, w = self.size
w, a = self.tps(x, y)
w_mat, a_mat = self.tps(x, y)
u_mat = k_matrix(self.grid, x)
p_mat = p_matrix(self.grid)
grid = p_mat @ a + u_mat @ w
grid = p_mat @ a_mat + u_mat @ w_mat
return grid.view(-1, h, w, 2)


Expand All @@ -222,7 +247,7 @@ def forward(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:

run_attack(
attack=DeCoWA,
attack_cfg={'eps': 8 / 255, 'steps': 10},
model_name='resnet18',
victim_model_names=['resnet50', 'vgg13', 'densenet121'],
attack_cfg={'eps': 16 / 255, 'steps': 10},
model_name='resnet50',
victim_model_names=['resnet18', 'vit_b_16'],
)

0 comments on commit 7694144

Please sign in to comment.