Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
220 changes: 220 additions & 0 deletions test/test_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -1201,6 +1201,226 @@ def test_forward_scriptability(self):
torch.jit.script(ops.DeformConv2d(in_channels=8, out_channels=8, kernel_size=3))


class TestDeformAttn:
    """CUDA tests for ``ops.deform_attn``.

    The operator is compared against :meth:`expected_fn`, a pure-PyTorch
    reference built on ``torch.nn.functional.grid_sample``. The tests also
    expect the operator to raise ``RuntimeError`` for non-contiguous inputs
    and for inputs whose shapes are inconsistent with each other.
    """

    dtype = torch.float64

    @staticmethod
    def _make_non_contiguous(*tensors):
        """Return same-shape, same-value copies of ``tensors`` that are not contiguous.

        Dims 1 and 2 are swapped, the result is materialized, and the dims are
        swapped back, so the logical shape is unchanged while the strides no
        longer describe a contiguous layout.
        """
        return tuple(t.transpose(1, 2).contiguous().transpose(1, 2) for t in tensors)

    def expected_fn(
        self,
        value: Tensor,
        spatial_shapes: Tensor,
        level_start_index: Tensor,
        sampling_loc: Tensor,
        attn_weight: Tensor,
    ) -> Tensor:
        """Reference implementation of deformable attention using ``grid_sample``.

        Args:
            value: ``(B, Nv, H, C)`` values of all levels concatenated along dim 1.
            spatial_shapes: ``(L, 2)`` integer ``(height, width)`` of every level.
            level_start_index: ``(L,)`` offset of every level along dim 1 of ``value``.
            sampling_loc: ``(B, Nq, H, L, P, 2)`` locations in ``[0, 1]``; they are
                rescaled to the ``[-1, 1]`` range expected by ``grid_sample``.
            attn_weight: ``(B, Nq, H, L, P)`` weight of every sampled point.

        Returns:
            Tensor of shape ``(B, Nq, H * C)``.
        """
        assert value.dim() == 4
        assert spatial_shapes.dim() == 2 and spatial_shapes.size(-1) == 2
        assert level_start_index.dim() == 1
        assert sampling_loc.dim() == 6
        assert attn_weight.dim() == 5

        B, _, H, C = value.shape
        L = spatial_shapes.shape[0]
        Nq = sampling_loc.shape[1]
        P = sampling_loc.shape[4]

        # Move the level metadata to the host once instead of calling ``.item()``
        # (one device sync each) three times per level.
        shapes = spatial_shapes.tolist()
        starts = level_start_index.tolist()

        out = torch.zeros(B, Nq, H, C, device=value.device, dtype=value.dtype)

        for lvl in range(L):
            Hl, Wl = shapes[lvl]
            start = starts[lvl]
            end = start + Hl * Wl

            # (B, Hl * Wl, H, C) -> (B * H, C, Hl, Wl): every head becomes its own image.
            value_lvl = value[:, start:end].reshape(B, Hl, Wl, H, C).permute(0, 3, 4, 1, 2).reshape(B * H, C, Hl, Wl)

            # (B, Nq, H, P, 2) -> (B * H, Nq, P, 2), then [0, 1] -> [-1, 1].
            grid_lvl = sampling_loc[:, :, :, lvl].permute(0, 2, 1, 3, 4).reshape(B * H, Nq, P, 2)
            grid_lvl = 2.0 * grid_lvl - 1.0

            # sampled: (B * H, C, Nq, P)
            sampled = torch.nn.functional.grid_sample(
                value_lvl, grid_lvl, mode="bilinear", padding_mode="zeros", align_corners=False
            )

            # (B, Nq, H, P) -> (B * H, 1, Nq, P) so it broadcasts over the channel dim.
            w_lvl = attn_weight[:, :, :, lvl].permute(0, 2, 1, 3).reshape(B * H, 1, Nq, P)

            # Weighted sum over the P sampled points: (B * H, C, Nq).
            contrib = (sampled * w_lvl).sum(dim=-1)

            out += contrib.view(B, H, C, Nq).permute(0, 3, 1, 2)

        return out.reshape(B, Nq, H * C)

    @needs_cuda
    @pytest.mark.parametrize("contiguous", (True, False))
    @pytest.mark.parametrize("batch_sz", (2,))
    @pytest.mark.parametrize("num_query", (7,))
    def test_forward(self, contiguous, batch_sz, num_query):
        """Forward output must match ``expected_fn``; non-contiguous inputs must raise."""
        device = "cuda"
        B, Nq = batch_sz, num_query
        H, C, L, P = 3, 4, 3, 4

        spatial_shapes = torch.tensor([[8, 6], [4, 4], [2, 3]], dtype=torch.long, device=device)
        level_sizes = spatial_shapes.prod(dim=1)
        # Exclusive prefix sum of the level sizes: [0, s0, s0 + s1, ...].
        level_start_index = torch.cat((level_sizes.new_zeros(1), level_sizes.cumsum(0)[:-1]))
        Nv = int(level_sizes.sum().item())

        # ``requires_grad_()`` is applied last so that every input is a leaf tensor.
        value = torch.randn(B, Nv, H, C, device=device, dtype=self.dtype, requires_grad=True)
        sampling_loc = (
            torch.rand(B, Nq, H, L, P, 2, device=device, dtype=self.dtype).clamp(0.05, 0.95).requires_grad_()
        )
        # Normalize jointly over all L * P sampling points of a (batch, query, head) triple.
        attn_weight = (
            torch.rand(B, Nq, H, L, P, device=device, dtype=self.dtype)
            .flatten(-2)
            .softmax(-1)
            .view(B, Nq, H, L, P)
            .requires_grad_()
        )

        if not contiguous:
            value, sampling_loc, attn_weight = self._make_non_contiguous(value, sampling_loc, attn_weight)
            with pytest.raises(RuntimeError):
                ops.deform_attn(value, spatial_shapes, level_start_index, sampling_loc, attn_weight, im2col_step=64)
            return

        res = ops.deform_attn(value, spatial_shapes, level_start_index, sampling_loc, attn_weight, im2col_step=64)
        expected = self.expected_fn(value, spatial_shapes, level_start_index, sampling_loc, attn_weight)

        torch.testing.assert_close(res.to(expected.dtype), expected, msg=f"\nres: \n{res}\nexpected: \n{expected}")

    @needs_cuda
    def test_wrong_sizes(self):
        """Inputs with mutually inconsistent shapes must raise ``RuntimeError``."""
        device = "cuda"
        B, Nq = 2, 5
        H, C, L, P = 2, 3, 2, 4

        spatial_shapes = torch.tensor([[6, 5], [4, 3]], dtype=torch.long, device=device)
        level_start_index = torch.tensor([0, 30], dtype=torch.long, device=device)
        Nv = int((spatial_shapes[:, 0] * spatial_shapes[:, 1]).sum().item())

        value = torch.randn(B, Nv, H, C, device=device, dtype=self.dtype, requires_grad=True)
        sampling_loc = torch.rand(B, Nq, H, L, P, 2, device=device, dtype=self.dtype, requires_grad=True)
        attn_weight = torch.rand(B, Nq, H, L, P, device=device, dtype=self.dtype, requires_grad=True)

        # The malformed inputs are built outside the ``raises`` blocks so that only
        # an error coming from the operator itself can satisfy the expectation.

        # One value row more than the levels account for.
        wrong_value = torch.randn(B, Nv + 1, H, C, device=device, dtype=self.dtype)
        with pytest.raises(RuntimeError):
            ops.deform_attn(wrong_value, spatial_shapes, level_start_index, sampling_loc, attn_weight)

        # A single level while the sampling locations describe L levels.
        wrong_spatial_shapes = torch.tensor([[6, 5]], dtype=torch.long, device=device)
        with pytest.raises(RuntimeError):
            ops.deform_attn(value, wrong_spatial_shapes, level_start_index[:1], sampling_loc, attn_weight)

        # One point fewer in the weights than in the sampling locations.
        wrong_attn = attn_weight[:, :, :, :, : P - 1]
        with pytest.raises(RuntimeError):
            ops.deform_attn(value, spatial_shapes, level_start_index, sampling_loc, wrong_attn)

        # Sampling locations without the trailing coordinate dimension.
        wrong_sampling = sampling_loc[..., 0]
        with pytest.raises(RuntimeError):
            ops.deform_attn(value, spatial_shapes, level_start_index, wrong_sampling, attn_weight)

    @needs_cuda
    @pytest.mark.parametrize("contiguous", (True, False))
    @pytest.mark.parametrize("batch_sz", (1,))
    @pytest.mark.parametrize("num_query", (3,))
    @pytest.mark.opcheck_only_one()
    def test_backward(self, contiguous, batch_sz, num_query):
        """Check gradients with ``gradcheck`` in eager and scripted mode."""
        device = "cuda"
        B, Nq = batch_sz, num_query
        H, C, L, P = 2, 2, 2, 3

        spatial_shapes = torch.tensor([[4, 3], [2, 2]], dtype=torch.long, device=device)
        level_start_index = torch.tensor([0, 12], dtype=torch.long, device=device)
        Nv = int((spatial_shapes[:, 0] * spatial_shapes[:, 1]).sum().item())

        value = torch.randn(B, Nv, H, C, device=device, dtype=self.dtype, requires_grad=True)
        # Clamp first and mark as requiring grad afterwards, so gradcheck
        # differentiates with respect to a leaf tensor.
        sampling_loc = torch.rand(B, Nq, H, L, P, 2, device=device, dtype=self.dtype).clamp(0.1, 0.9).requires_grad_()
        attn_weight = torch.rand(B, Nq, H, L, P, device=device, dtype=self.dtype, requires_grad=True)

        gradcheck_kwargs = {"nondet_tol": 1e-5, "fast_mode": True}

        def func(v_, sl_, lw_, samp_, w_):
            return ops.deform_attn(v_, sl_, lw_, samp_, w_, im2col_step=64)

        if not contiguous:
            value, sampling_loc, attn_weight = self._make_non_contiguous(value, sampling_loc, attn_weight)
            with pytest.raises(RuntimeError):
                gradcheck(
                    func,
                    (value, spatial_shapes, level_start_index, sampling_loc, attn_weight),
                    **gradcheck_kwargs,
                )
            return

        gradcheck(
            func,
            (value, spatial_shapes, level_start_index, sampling_loc, attn_weight),
            **gradcheck_kwargs,
        )

        @torch.jit.script
        def script_func(v_: Tensor, sl_: Tensor, lw_: Tensor, samp_: Tensor, w_: Tensor, step_: int) -> Tensor:
            return ops.deform_attn(v_, sl_, lw_, samp_, w_, im2col_step=step_)

        # The integer level metadata is closed over, so only the floating-point
        # inputs are passed to gradcheck here.
        gradcheck(
            lambda v, samp, w: script_func(v, spatial_shapes, level_start_index, samp, w, 64),
            (value, sampling_loc, attn_weight),
            **gradcheck_kwargs,
        )

    @needs_cuda
    def test_forward_scriptability(self):
        """``ops.deform_attn`` must be callable from TorchScript and keep its output shape."""

        @torch.jit.script
        def script_func(
            value: Tensor,
            spatial_shapes: Tensor,
            level_start_index: Tensor,
            sampling_loc: Tensor,
            attn_weight: Tensor,
            step: int,
        ) -> Tensor:
            return ops.deform_attn(
                value, spatial_shapes, level_start_index, sampling_loc, attn_weight, im2col_step=step
            )

        device = "cuda"
        B, Nq, H, C, L, P = 2, 3, 2, 4, 2, 4
        spatial_shapes = torch.tensor([[6, 5], [3, 4]], dtype=torch.long, device=device)
        level_start_index = torch.tensor([0, 30], dtype=torch.long, device=device)
        Nv = int((spatial_shapes[:, 0] * spatial_shapes[:, 1]).sum().item())
        value = torch.randn(B, Nv, H, C, dtype=torch.float, device=device)
        sampling_loc = torch.rand(B, Nq, H, L, P, 2, dtype=torch.float, device=device).clamp(0.1, 0.9)
        attn_weight = torch.rand(B, Nq, H, L, P, dtype=torch.float, device=device)

        out = script_func(value, spatial_shapes, level_start_index, sampling_loc, attn_weight, 64)
        assert out.shape == (B, Nq, H * C)


# NS: Remove me once backward is implemented for MPS
def xfail_if_mps(x):
mps_xfail_param = pytest.param("mps", marks=(pytest.mark.needs_mps, pytest.mark.xfail))
Expand Down
33 changes: 33 additions & 0 deletions torchvision/_meta_registrations.py
Original file line number Diff line number Diff line change
Expand Up @@ -223,3 +223,36 @@ def meta_deform_conv2d_backward(
grad_mask = mask.new_empty(mask.shape)
grad_bias = bias.new_empty(bias.shape)
return grad_input, grad_weight, grad_offset, grad_mask, grad_bias


@register_meta("deform_attn")
def meta_deform_attn(value, spatial_shapes, level_start_index, sampling_loc, attn_weight, im2col_step):
    """Meta kernel for ``deform_attn``: allocate the output without computing it.

    The result has shape ``[batch_size, num_query, num_heads * channels]``, where
    ``batch_size``, ``num_heads`` and ``channels`` are dims 0, 2 and 3 of ``value``
    and ``num_query`` is dim 1 of ``sampling_loc``. It is created with
    ``value.new_empty``, so it inherits the dtype and device of ``value``.
    The remaining arguments are accepted but not inspected.
    """
    out_shape = (value.size(0), sampling_loc.size(1), value.size(2) * value.size(3))
    return value.new_empty(out_shape)


@register_meta("_deform_attn_backward")
def meta_deform_attn_backward(
    value, spatial_shapes, level_start_index, sampling_loc, attn_weight, grad_output, im2col_step
):
    """Meta kernel for ``_deform_attn_backward``: allocate the three gradients.

    Returns a tuple ``(grad_value, grad_sampling_loc, grad_attn_weight)`` of
    uninitialized tensors, each created with ``new_empty`` from the matching
    input and given that input's shape. The remaining arguments are accepted
    but not inspected.
    """
    differentiable_inputs = (value, sampling_loc, attn_weight)
    return tuple(tensor.new_empty(tensor.shape) for tensor in differentiable_inputs)
39 changes: 39 additions & 0 deletions torchvision/csrc/ops/autocast/deform_attn_kernel.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,39 @@
#include "../deform_attn.h"

#include <ATen/autocast_mode.h>
#include <torch/library.h>
#include <torch/types.h>

namespace vision {
namespace ops {

namespace {

// Autocast wrapper for torchvision::deform_attn.
//
// Autocast is excluded for the nested call, and the floating-point inputs
// (value, sampling_loc, attn_weight) are cast to float32 first. The operator
// therefore always sees a single, consistent floating dtype, even when
// autocast produced inputs of different precisions. at::autocast::cast only
// touches tensors that are eligible for autocasting; spatial_shapes and
// level_start_index hold integer metadata and are forwarded unchanged.
//
// The result is converted back to the dtype of `value`, so callers get an
// output in the precision they passed in.
at::Tensor deform_attn_autocast(
    const at::Tensor& value,
    const at::Tensor& spatial_shapes,
    const at::Tensor& level_start_index,
    const at::Tensor& sampling_loc,
    const at::Tensor& attn_weight,
    const int64_t im2col_step) {
  c10::impl::ExcludeDispatchKeyGuard no_autocast(c10::DispatchKey::Autocast);
  return deform_attn(
             at::autocast::cast(at::kFloat, value),
             spatial_shapes,
             level_start_index,
             at::autocast::cast(at::kFloat, sampling_loc),
             at::autocast::cast(at::kFloat, attn_weight),
             im2col_step)
      .to(value.scalar_type());
}

} // namespace

// Register deform_attn_autocast as the implementation of
// torchvision::deform_attn for the Autocast dispatch key.
TORCH_LIBRARY_IMPL(torchvision, Autocast, m) {
m.impl(
TORCH_SELECTIVE_NAME("torchvision::deform_attn"),
TORCH_FN(deform_attn_autocast));
}

} // namespace ops
} // namespace vision
Loading