Skip to content

Commit

Permalink
DeCoWA docstrings
Browse files Browse the repository at this point in the history
  • Loading branch information
spencerwooo committed Oct 21, 2024
1 parent c6fac5c commit d297edd
Showing 1 changed file with 53 additions and 0 deletions.
53 changes: 53 additions & 0 deletions src/torchattack/decowa.py
Original file line number Diff line number Diff line change
Expand Up @@ -149,6 +149,19 @@ def forward(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
return x + delta

def _vwt(self, x: torch.Tensor, noise_map: torch.Tensor) -> torch.Tensor:
"""
Performs Vanilla Warping Transformation (VWT) on the input images using the
noise map, i.e., computes a grid of control points and applies TPS to warp the
image according to these points.
Args:
x: The input tensor of shape (N, C, H, W).
noise_map: The noise map tensor of shape (H, W, 2) used to warp the grid.
Returns:
The transformed tensor of the same shape as the input tensor `x`.
"""

n, _, w, h = x.size()

xx = self._grid_points_2d(self.mesh_width, self.mesh_height, self.device)
Expand All @@ -171,6 +184,20 @@ def _vwt(self, x: torch.Tensor, noise_map: torch.Tensor) -> torch.Tensor:
def _grid_points_2d(
self, width: int, height: int, device: torch.device
) -> torch.Tensor:
"""Helper function to generate 2d grid points for the TPS transformation.
Creates a regular grid of points in the range [-1, 1] x [-1, 1] with `width`
points in the x-direction and `height` points in the y-direction. The grid is
then reshaped to a 2D tensor of shape (width * height, 2).
Args:
width: The number of points in the x-direction.
height: The number of points in the y-direction. device: The device.
Returns:
The grid of points of shape (width * height, 2).
"""

x = torch.linspace(-1, 1, width, device=device)
y = torch.linspace(-1, 1, height, device=device)

Expand All @@ -182,27 +209,43 @@ def _grid_points_2d(
def _noisy_grid(
self, width: int, height: int, noise_map: torch.Tensor, device: torch.device
) -> torch.Tensor:
"""Creates a perturbed version of the grid by adding random noise.
Args:
width: The number of points in the x-direction.
height: The number of points in the y-direction.
noise_map: The noise map tensor of shape (H, W, 2) used to warp the grid.
device: The device.
Returns:
The grid of points of shape (width * height, 2).
"""

grid = self._grid_points_2d(width, height, device)
mod = torch.zeros([height, width, 2], device=device)
mod[1 : height - 1, 1 : width - 1, :] = noise_map
return grid + mod.reshape(-1, 2)


def k_matrix(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
"""Computes the kernel matrix, which measures the deformation energy."""
eps = 1e-9
d2 = torch.pow(x[:, :, None, :] - y[:, None, :, :], 2).sum(-1)
k_mat = d2 * torch.log(d2 + eps)
return k_mat


def p_matrix(x: torch.Tensor) -> torch.Tensor:
"""Constructs the P matrix, which includes affine transformation components."""
n, k = x.shape[:2]
p_mat = torch.ones(n, k, 3, device=x.device)
p_mat[:, :, 1:] = x
return p_mat


class TPSCoeffs(nn.Module):
"""Computes the coefficients for the TPS transformation based on control points."""

def __init__(self):
super().__init__()

Expand All @@ -226,6 +269,16 @@ def forward(


class TPS(nn.Module):
"""Thin Plate Spline transformation.
Applies the TPS transformation to the image grid, warping it according to the
computed coefficients.
Args:
size: The size of the grid. Defaults to (256, 256).
device: Device to use for tensors. Defaults to cuda if available.
"""

def __init__(self, size: tuple = (256, 256), device: torch.device | None = None):
super().__init__()
h, w = size
Expand Down

0 comments on commit d297edd

Please sign in to comment.