diff --git a/test/test_ops.py b/test/test_ops.py index d2cf8d29181..85f74550939 100644 --- a/test/test_ops.py +++ b/test/test_ops.py @@ -1201,6 +1201,226 @@ def test_forward_scriptability(self): torch.jit.script(ops.DeformConv2d(in_channels=8, out_channels=8, kernel_size=3)) +class TestDeformAttn: + dtype = torch.float64 + + def expected_fn( + self, + value: Tensor, + spatial_shapes: Tensor, + level_start_index: Tensor, + sampling_loc: Tensor, + attn_weight: Tensor, + ) -> Tensor: + assert value.dim() == 4 + assert spatial_shapes.dim() == 2 and spatial_shapes.size(-1) == 2 + assert level_start_index.dim() == 1 + assert sampling_loc.dim() == 6 + assert attn_weight.dim() == 5 + + B, Nv, H, C = value.shape + L = spatial_shapes.shape[0] + Nq = sampling_loc.shape[1] + P = sampling_loc.shape[4] + + out = torch.zeros(B, Nq, H, C, device=value.device, dtype=value.dtype) + + for lvl in range(L): + Hl = int(spatial_shapes[lvl, 0].item()) + Wl = int(spatial_shapes[lvl, 1].item()) + start = int(level_start_index[lvl].item()) + end = start + Hl * Wl + + value_slice = value[:, start:end, :, :].view(B, Hl, Wl, H, C).permute(0, 3, 4, 1, 2).contiguous() + value_lvl = value_slice.view(B * H, C, Hl, Wl) + + grid_lvl = sampling_loc[:, :, :, lvl, :, :].permute(0, 2, 1, 3, 4).contiguous().view(B * H, Nq, P, 2) + + grid_lvl = 2.0 * grid_lvl - 1.0 + + sampled = torch.nn.functional.grid_sample( + value_lvl, grid_lvl, mode="bilinear", padding_mode="zeros", align_corners=False + ) + + w_lvl = attn_weight[:, :, :, lvl, :].permute(0, 2, 1, 3).contiguous().view(B * H, 1, Nq, P) + + contrib = (sampled * w_lvl).sum(dim=-1) + + out += contrib.view(B, H, C, Nq).permute(0, 3, 1, 2) + + return out.reshape(B, Nq, H * C) + + @needs_cuda + @pytest.mark.parametrize("contiguous", (True, False)) + @pytest.mark.parametrize("batch_sz", (2,)) + @pytest.mark.parametrize("num_query", (7,)) + def test_forward(self, contiguous, batch_sz, num_query): + device = "cuda" + B = batch_sz + Nq = num_query + H = 3 + C = 4 + L = 3 + P = 4 + + spatial_shapes = torch.tensor([[8, 6], [4, 4], [2, 3]], dtype=torch.long, device=device) + level_start_index = torch.empty(L, dtype=torch.long, device=device) + level_start_index[0] = 0 + for l in range(1, L): + level_start_index[l] = level_start_index[l - 1] + int(spatial_shapes[l - 1, 0] * spatial_shapes[l - 1, 1]) + Nv = int((spatial_shapes[:, 0] * spatial_shapes[:, 1]).sum().item()) + + value = torch.randn(B, Nv, H, C, device=device, dtype=self.dtype, requires_grad=True) + sampling_loc = torch.rand(B, Nq, H, L, P, 2, device=device, dtype=self.dtype, requires_grad=True) + sampling_loc = sampling_loc.clamp(0.05, 0.95) + attn_weight = torch.rand(B, Nq, H, L, P, device=device, dtype=self.dtype, requires_grad=True) + attn_weight = attn_weight.softmax(-1) + attn_weight = attn_weight.softmax(-2) + + if not contiguous: + if value.numel() > 0: + value = value.permute(0, 2, 1, 3).contiguous().permute(0, 2, 1, 3) + if sampling_loc.numel() > 0: + sampling_loc = sampling_loc.permute(0, 2, 1, 3, 4, 5).contiguous().permute(0, 2, 1, 3, 4, 5) + if attn_weight.numel() > 0: + attn_weight = attn_weight.permute(0, 2, 1, 3, 4).contiguous().permute(0, 2, 1, 3, 4) + + if not contiguous: + with pytest.raises(RuntimeError): + res = ops.deform_attn( + value, spatial_shapes, level_start_index, sampling_loc, attn_weight, im2col_step=64 + ) + return + res = ops.deform_attn(value, spatial_shapes, level_start_index, sampling_loc, attn_weight, im2col_step=64) + + expected = self.expected_fn(value, spatial_shapes, level_start_index, 
sampling_loc, attn_weight) + + torch.testing.assert_close(res.to(expected.dtype), expected, msg=f"\nres: \n{res}\nexpected: \n{expected}") + + @needs_cuda + def test_wrong_sizes(self): + device = "cuda" + B = 2 + Nq = 5 + H = 2 + C = 3 + L = 2 + P = 4 + + spatial_shapes = torch.tensor([[6, 5], [4, 3]], dtype=torch.long, device=device) + level_start_index = torch.tensor([0, 30], dtype=torch.long, device=device) + Nv = int((spatial_shapes[:, 0] * spatial_shapes[:, 1]).sum().item()) + + value = torch.randn(B, Nv, H, C, device=device, dtype=self.dtype, requires_grad=True) + sampling_loc = torch.rand(B, Nq, H, L, P, 2, device=device, dtype=self.dtype, requires_grad=True) + attn_weight = torch.rand(B, Nq, H, L, P, device=device, dtype=self.dtype, requires_grad=True) + + with pytest.raises(RuntimeError): + wrong_value = torch.randn(B, Nv + 1, H, C, device=device, dtype=self.dtype) + ops.deform_attn(wrong_value, spatial_shapes, level_start_index, sampling_loc, attn_weight) + + with pytest.raises(RuntimeError): + wrong_spatial_shapes = torch.tensor([[6, 5]], dtype=torch.long, device=device) + ops.deform_attn(value, wrong_spatial_shapes, level_start_index[:1], sampling_loc, attn_weight) + + with pytest.raises(RuntimeError): + wrong_attn = attn_weight[:, :, :, :, : P - 1] + ops.deform_attn(value, spatial_shapes, level_start_index, sampling_loc, wrong_attn) + + with pytest.raises(RuntimeError): + wrong_sampling = sampling_loc[..., 0] + ops.deform_attn(value, spatial_shapes, level_start_index, wrong_sampling, attn_weight) + + @needs_cuda + @pytest.mark.parametrize("contiguous", (True, False)) + @pytest.mark.parametrize("batch_sz", (1,)) + @pytest.mark.parametrize("num_query", (3,)) + @pytest.mark.opcheck_only_one() + def test_backward(self, contiguous, batch_sz, num_query): + device = "cuda" + B = batch_sz + Nq = num_query + H = 2 + C = 2 + L = 2 + P = 3 + + spatial_shapes = torch.tensor([[4, 3], [2, 2]], dtype=torch.long, device=device) + level_start_index = torch.tensor([0, 12], dtype=torch.long, device=device) + Nv = int((spatial_shapes[:, 0] * spatial_shapes[:, 1]).sum().item()) + + value = torch.randn(B, Nv, H, C, device=device, dtype=self.dtype, requires_grad=True) + sampling_loc = torch.rand(B, Nq, H, L, P, 2, device=device, dtype=self.dtype, requires_grad=True).clamp( + 0.1, 0.9 + ) + attn_weight = torch.rand(B, Nq, H, L, P, device=device, dtype=self.dtype, requires_grad=True) + + if not contiguous: + if value.numel() > 0: + value = value.permute(0, 2, 1, 3).contiguous().permute(0, 2, 1, 3) + if sampling_loc.numel() > 0: + sampling_loc = sampling_loc.permute(0, 2, 1, 3, 4, 5).contiguous().permute(0, 2, 1, 3, 4, 5) + if attn_weight.numel() > 0: + attn_weight = attn_weight.permute(0, 2, 1, 3, 4).contiguous().permute(0, 2, 1, 3, 4) + + def func(v_, sl_, lw_, samp_, w_): + return ops.deform_attn(v_, sl_, lw_, samp_, w_, im2col_step=64) + + if not contiguous: + with pytest.raises(RuntimeError): + gradcheck( + func, + (value, spatial_shapes, level_start_index, sampling_loc, attn_weight), + nondet_tol=1e-5, + fast_mode=True, + ) + return + gradcheck( + func, + (value, spatial_shapes, level_start_index, sampling_loc, attn_weight), + nondet_tol=1e-5, + fast_mode=True, + ) + + @torch.jit.script + def script_func(v_: Tensor, sl_: Tensor, lw_: Tensor, samp_: Tensor, w_: Tensor, step_: int) -> Tensor: + return ops.deform_attn(v_, sl_, lw_, samp_, w_, im2col_step=step_) + + gradcheck( + lambda v, samp, w: script_func(v, spatial_shapes, level_start_index, samp, w, 64), + (value, sampling_loc, attn_weight), + 
nondet_tol=1e-5,
+            fast_mode=True,
+        )
+
+    @needs_cuda
+    def test_forward_scriptability(self):
+        @torch.jit.script
+        def script_func(
+            value: Tensor,
+            spatial_shapes: Tensor,
+            level_start_index: Tensor,
+            sampling_loc: Tensor,
+            attn_weight: Tensor,
+            step: int,
+        ) -> Tensor:
+            return ops.deform_attn(
+                value, spatial_shapes, level_start_index, sampling_loc, attn_weight, im2col_step=step
+            )
+
+        device = "cuda"
+        B, Nq, H, C, L, P = 2, 3, 2, 4, 2, 4
+        spatial_shapes = torch.tensor([[6, 5], [3, 4]], dtype=torch.long, device=device)
+        level_start_index = torch.tensor([0, 30], dtype=torch.long, device=device)
+        Nv = int((spatial_shapes[:, 0] * spatial_shapes[:, 1]).sum().item())
+        value = torch.randn(B, Nv, H, C, dtype=torch.float, device=device)
+        sampling_loc = torch.rand(B, Nq, H, L, P, 2, dtype=torch.float, device=device).clamp(0.1, 0.9)
+        attn_weight = torch.rand(B, Nq, H, L, P, dtype=torch.float, device=device)
+
+        out = script_func(value, spatial_shapes, level_start_index, sampling_loc, attn_weight, 64)
+        assert out.shape == (B, Nq, H * C)
+
+
+
 # NS: Remove me once backward is implemented for MPS
 def xfail_if_mps(x):
     mps_xfail_param = pytest.param("mps", marks=(pytest.mark.needs_mps, pytest.mark.xfail))
diff --git a/torchvision/_meta_registrations.py b/torchvision/_meta_registrations.py
index f75bfb77a7f..6264e573272 100644
--- a/torchvision/_meta_registrations.py
+++ b/torchvision/_meta_registrations.py
@@ -223,3 +223,36 @@ def meta_deform_conv2d_backward(
     grad_mask = mask.new_empty(mask.shape)
     grad_bias = bias.new_empty(bias.shape)
     return grad_input, grad_weight, grad_offset, grad_mask, grad_bias
+
+
+@register_meta("deform_attn")
+def meta_deform_attn(
+    value,
+    spatial_shapes,
+    level_start_index,
+    sampling_loc,
+    attn_weight,
+    im2col_step,
+):
+    batch_size = value.shape[0]
+    num_query = sampling_loc.shape[1]
+    num_heads = value.shape[2]
+    channels = value.shape[3]
+    # Output shape [batch_size, num_query, num_heads * channels]
+    return value.new_empty((batch_size, num_query, num_heads * channels))
+
+
+@register_meta("_deform_attn_backward")
+def meta_deform_attn_backward(
+    value,
+    spatial_shapes,
+    level_start_index,
+    sampling_loc,
+    attn_weight,
+    grad_output,
+    im2col_step,
+):
+    grad_value = value.new_empty(value.shape)
+    grad_sampling_loc = sampling_loc.new_empty(sampling_loc.shape)
+    grad_attn_weight = attn_weight.new_empty(attn_weight.shape)
+    return grad_value, grad_sampling_loc, grad_attn_weight
diff --git a/torchvision/csrc/ops/autocast/deform_attn_kernel.cpp b/torchvision/csrc/ops/autocast/deform_attn_kernel.cpp
new file mode 100644
index 00000000000..71473a4e2fd
--- /dev/null
+++ b/torchvision/csrc/ops/autocast/deform_attn_kernel.cpp
@@ -0,0 +1,39 @@
+#include "../deform_attn.h"
+
+#include <ATen/autocast_mode.h>
+#include <torch/library.h>
+#include <torch/types.h>
+
+namespace vision {
+namespace ops {
+
+namespace {
+
+at::Tensor deform_attn_autocast(
+    const at::Tensor& value,
+    const at::Tensor& spatial_shapes,
+    const at::Tensor& level_start_index,
+    const at::Tensor& sampling_loc,
+    const at::Tensor& attn_weight,
+    const int64_t im2col_step) {
+  c10::impl::ExcludeDispatchKeyGuard no_autocast(c10::DispatchKey::Autocast);
+  return deform_attn(
+             value,
+             spatial_shapes,
+             level_start_index,
+             sampling_loc,
+             attn_weight,
+             im2col_step)
+      .to(value.scalar_type());
+}
+
+} // namespace
+
+TORCH_LIBRARY_IMPL(torchvision, Autocast, m) {
+  m.impl(
+      TORCH_SELECTIVE_NAME("torchvision::deform_attn"),
+      TORCH_FN(deform_attn_autocast));
+}
+
+} // namespace ops
+} // namespace vision
diff --git
a/torchvision/csrc/ops/autograd/deform_attn_kernel.cpp b/torchvision/csrc/ops/autograd/deform_attn_kernel.cpp
new file mode 100644
index 00000000000..79144782014
--- /dev/null
+++ b/torchvision/csrc/ops/autograd/deform_attn_kernel.cpp
@@ -0,0 +1,166 @@
+#include "../deform_attn.h"
+
+#include <torch/autograd.h>
+#include <torch/types.h>
+
+#include <tuple>
+
+namespace vision {
+namespace ops {
+
+namespace {
+
+class DeformAttnFunction
+    : public torch::autograd::Function<DeformAttnFunction> {
+ public:
+  static torch::autograd::variable_list forward(
+      torch::autograd::AutogradContext* ctx,
+      const torch::autograd::Variable& value,
+      const torch::autograd::Variable& spatial_shapes,
+      const torch::autograd::Variable& level_start_index,
+      const torch::autograd::Variable& sampling_loc,
+      const torch::autograd::Variable& attn_weight,
+      int64_t im2col_step) {
+    at::AutoDispatchBelowADInplaceOrView g;
+    auto output = deform_attn(
+        value,
+        spatial_shapes,
+        level_start_index,
+        sampling_loc,
+        attn_weight,
+        im2col_step);
+
+    ctx->save_for_backward(
+        {value, spatial_shapes, level_start_index, sampling_loc, attn_weight});
+    ctx->saved_data["im2col_step"] = im2col_step;
+
+    return {
+        output,
+    };
+  }
+
+  static torch::autograd::variable_list backward(
+      torch::autograd::AutogradContext* ctx,
+      const torch::autograd::variable_list& grad_output) {
+    auto saved = ctx->get_saved_variables();
+    auto value = saved[0];
+    auto spatial_shapes = saved[1];
+    auto level_start_index = saved[2];
+    auto sampling_loc = saved[3];
+    auto attn_weight = saved[4];
+
+    auto im2col_step =
+        static_cast<int64_t>(ctx->saved_data["im2col_step"].toInt());
+
+    auto grads = detail::_deform_attn_backward(
+        value,
+        spatial_shapes,
+        level_start_index,
+        sampling_loc,
+        attn_weight,
+        grad_output[0],
+        im2col_step);
+    auto grad_value = std::get<0>(grads);
+    auto grad_sampling_loc = std::get<1>(grads);
+    auto grad_attn_weight = std::get<2>(grads);
+
+    return {
+        grad_value,
+        torch::autograd::Variable(),
+        torch::autograd::Variable(),
+        grad_sampling_loc,
+        grad_attn_weight,
+        torch::autograd::Variable(),
+    };
+  }
+};
+
+class DeformAttnBackwardFunction
+    : public torch::autograd::Function<DeformAttnBackwardFunction> {
+ public:
+  static torch::autograd::variable_list forward(
+      torch::autograd::AutogradContext* ctx,
+      const torch::autograd::Variable& value,
+      const torch::autograd::Variable& spatial_shapes,
+      const torch::autograd::Variable& level_start_index,
+      const torch::autograd::Variable& sampling_loc,
+      const torch::autograd::Variable& attn_weight,
+      const torch::autograd::Variable& grad_output,
+      int64_t im2col_step) {
+    at::AutoDispatchBelowADInplaceOrView g;
+    auto result = detail::_deform_attn_backward(
+        value,
+        spatial_shapes,
+        level_start_index,
+        sampling_loc,
+        attn_weight,
+        grad_output,
+        im2col_step);
+
+    auto grad_value = std::get<0>(result);
+    auto grad_sampling_loc = std::get<1>(result);
+    auto grad_attn_weight = std::get<2>(result);
+
+    return {
+        grad_value,
+        grad_sampling_loc,
+        grad_attn_weight,
+    };
+  }
+
+  static torch::autograd::variable_list backward(
+      torch::autograd::AutogradContext* ctx,
+      const torch::autograd::variable_list& grad_output) {
+    TORCH_CHECK(0, "double backwards on deform_attn not supported");
+  }
+};
+
+at::Tensor deform_attn_autograd(
+    const at::Tensor& value,
+    const at::Tensor& spatial_shapes,
+    const at::Tensor& level_start_index,
+    const at::Tensor& sampling_loc,
+    const at::Tensor& attn_weight,
+    int64_t im2col_step) {
+  return DeformAttnFunction::apply(
+      value,
+      spatial_shapes,
+      level_start_index,
+      sampling_loc,
+      attn_weight,
+      im2col_step)[0];
+}
+
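+// Note: wrapping the backward kernel in its own autograd::Function (below)
+// registers an Autograd implementation for torchvision::_deform_attn_backward
+// as well; its backward() raises, so an attempted double backward through
+// deform_attn fails with a clear error instead of silently returning wrong
+// gradients. This mirrors the existing deform_conv2d autograd kernel.
+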
+std::tuple<at::Tensor, at::Tensor, at::Tensor> deform_attn_backward_autograd(
+    const at::Tensor& value,
+    const at::Tensor& spatial_shapes,
+    const at::Tensor& level_start_index,
+    const at::Tensor& sampling_loc,
+    const at::Tensor& attn_weight,
+    const at::Tensor& grad_output,
+    int64_t im2col_step) {
+  auto result = DeformAttnBackwardFunction::apply(
+      value,
+      spatial_shapes,
+      level_start_index,
+      sampling_loc,
+      attn_weight,
+      grad_output,
+      im2col_step);
+
+  return {result[0], result[1], result[2]};
+}
+
+} // namespace
+
+TORCH_LIBRARY_IMPL(torchvision, Autograd, m) {
+  m.impl(
+      TORCH_SELECTIVE_NAME("torchvision::deform_attn"),
+      TORCH_FN(deform_attn_autograd));
+  m.impl(
+      TORCH_SELECTIVE_NAME("torchvision::_deform_attn_backward"),
+      TORCH_FN(deform_attn_backward_autograd));
+}
+
+} // namespace ops
+} // namespace vision
diff --git a/torchvision/csrc/ops/cpu/deform_attn_kernel.cpp b/torchvision/csrc/ops/cpu/deform_attn_kernel.cpp
new file mode 100644
index 00000000000..92a7c0e42fc
--- /dev/null
+++ b/torchvision/csrc/ops/cpu/deform_attn_kernel.cpp
@@ -0,0 +1,40 @@
+#include <ATen/ATen.h>
+#include <torch/library.h>
+
+namespace vision {
+namespace ops {
+
+namespace {
+at::Tensor ms_deform_attn_forward_kernel(
+    const at::Tensor& value,
+    const at::Tensor& spatial_shapes,
+    const at::Tensor& level_start_index,
+    const at::Tensor& sampling_loc,
+    const at::Tensor& attn_weight,
+    const int64_t im2col_step) {
+  TORCH_CHECK(false, "Deformable attention is only supported on CUDA for now.");
+}
+
+std::tuple<at::Tensor, at::Tensor, at::Tensor> ms_deform_attn_backward_kernel(
+    const at::Tensor& value,
+    const at::Tensor& spatial_shapes,
+    const at::Tensor& level_start_index,
+    const at::Tensor& sampling_loc,
+    const at::Tensor& attn_weight,
+    const at::Tensor& grad_output,
+    const int64_t im2col_step) {
+  TORCH_CHECK(false, "Deformable attention is only supported on CUDA for now.");
+}
+
+} // namespace
+
+TORCH_LIBRARY_IMPL(torchvision, CPU, m) {
+  m.impl(
+      TORCH_SELECTIVE_NAME("torchvision::deform_attn"),
+      TORCH_FN(ms_deform_attn_forward_kernel));
+  m.impl(
+      TORCH_SELECTIVE_NAME("torchvision::_deform_attn_backward"),
+      TORCH_FN(ms_deform_attn_backward_kernel));
+}
+} // namespace ops
+} // namespace vision
diff --git a/torchvision/csrc/ops/cuda/deform_attn_kernel.cu b/torchvision/csrc/ops/cuda/deform_attn_kernel.cu
new file mode 100644
index 00000000000..6826f64e4fe
--- /dev/null
+++ b/torchvision/csrc/ops/cuda/deform_attn_kernel.cu
@@ -0,0 +1,1729 @@
+// Copyright (c) Meta Platforms, Inc. and affiliates.
+//
+// This software may be used and distributed in accordance with
+// the terms of the DINOv3 License Agreement.
+
+/*!
+**************************************************************************
+* Deformable DETR
+* Copyright (c) 2020 SenseTime. All Rights Reserved.
+* Licensed under the Apache License, Version 2.0 [see LICENSE for details] +************************************************************************** +* Modified from DCN (https://github.com/msracver/Deformable-ConvNets) +* Copyright (c) 2018 Microsoft +************************************************************************** +*/ + +#include +#include +#include + +#include +#include +#include +#include + +#include + +namespace vision { +namespace ops { + +namespace { + +#define CUDA_KERNEL_LOOP(i, n) \ + for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < (n); \ + i += blockDim.x * gridDim.x) + +const int CUDA_NUM_THREADS = 1024; +inline int GET_BLOCKS(const int N, const int num_threads) { + return (N + num_threads - 1) / num_threads; +} + +template +__device__ scalar_t ms_deform_attn_im2col_bilinear( + const scalar_t*& bottom_data, + const int& height, + const int& width, + const int& nheads, + const int& channels, + const scalar_t& h, + const scalar_t& w, + const int& m, + const int& c) { + const int h_low = floor(h); + const int w_low = floor(w); + const int h_high = h_low + 1; + const int w_high = w_low + 1; + + const scalar_t lh = h - h_low; + const scalar_t lw = w - w_low; + const scalar_t hh = 1 - lh, hw = 1 - lw; + + const int w_stride = nheads * channels; + const int h_stride = width * w_stride; + const int h_low_ptr_offset = h_low * h_stride; + const int h_high_ptr_offset = h_low_ptr_offset + h_stride; + const int w_low_ptr_offset = w_low * w_stride; + const int w_high_ptr_offset = w_low_ptr_offset + w_stride; + const int base_ptr = m * channels + c; + + scalar_t v1 = 0; + if (h_low >= 0 && w_low >= 0) { + const int ptr1 = h_low_ptr_offset + w_low_ptr_offset + base_ptr; + v1 = bottom_data[ptr1]; + } + scalar_t v2 = 0; + if (h_low >= 0 && w_high <= width - 1) { + const int ptr2 = h_low_ptr_offset + w_high_ptr_offset + base_ptr; + v2 = bottom_data[ptr2]; + } + scalar_t v3 = 0; + if (h_high <= height - 1 && w_low >= 0) { + const int ptr3 = h_high_ptr_offset + w_low_ptr_offset + base_ptr; + v3 = bottom_data[ptr3]; + } + scalar_t v4 = 0; + if (h_high <= height - 1 && w_high <= width - 1) { + const int ptr4 = h_high_ptr_offset + w_high_ptr_offset + base_ptr; + v4 = bottom_data[ptr4]; + } + + const scalar_t w1 = hh * hw, w2 = hh * lw, w3 = lh * hw, w4 = lh * lw; + + const scalar_t val = (w1 * v1 + w2 * v2 + w3 * v3 + w4 * v4); + return val; +} + +template +__device__ void ms_deform_attn_col2im_bilinear( + const scalar_t*& bottom_data, + const int& height, + const int& width, + const int& nheads, + const int& channels, + const scalar_t& h, + const scalar_t& w, + const int& m, + const int& c, + const scalar_t& top_grad, + const scalar_t& attn_weight, + scalar_t*& grad_value, + scalar_t* grad_sampling_loc, + scalar_t* grad_attn_weight) { + const int h_low = floor(h); + const int w_low = floor(w); + const int h_high = h_low + 1; + const int w_high = w_low + 1; + + const scalar_t lh = h - h_low; + const scalar_t lw = w - w_low; + const scalar_t hh = 1 - lh, hw = 1 - lw; + + const int w_stride = nheads * channels; + const int h_stride = width * w_stride; + const int h_low_ptr_offset = h_low * h_stride; + const int h_high_ptr_offset = h_low_ptr_offset + h_stride; + const int w_low_ptr_offset = w_low * w_stride; + const int w_high_ptr_offset = w_low_ptr_offset + w_stride; + const int base_ptr = m * channels + c; + + const scalar_t w1 = hh * hw, w2 = hh * lw, w3 = lh * hw, w4 = lh * lw; + const scalar_t top_grad_value = top_grad * attn_weight; + scalar_t grad_h_weight = 0, grad_w_weight = 
0; + + scalar_t v1 = 0; + if (h_low >= 0 && w_low >= 0) { + const int ptr1 = h_low_ptr_offset + w_low_ptr_offset + base_ptr; + v1 = bottom_data[ptr1]; + grad_h_weight -= hw * v1; + grad_w_weight -= hh * v1; + atomicAdd(grad_value + ptr1, w1 * top_grad_value); + } + scalar_t v2 = 0; + if (h_low >= 0 && w_high <= width - 1) { + const int ptr2 = h_low_ptr_offset + w_high_ptr_offset + base_ptr; + v2 = bottom_data[ptr2]; + grad_h_weight -= lw * v2; + grad_w_weight += hh * v2; + atomicAdd(grad_value + ptr2, w2 * top_grad_value); + } + scalar_t v3 = 0; + if (h_high <= height - 1 && w_low >= 0) { + const int ptr3 = h_high_ptr_offset + w_low_ptr_offset + base_ptr; + v3 = bottom_data[ptr3]; + grad_h_weight += hw * v3; + grad_w_weight -= lh * v3; + atomicAdd(grad_value + ptr3, w3 * top_grad_value); + } + scalar_t v4 = 0; + if (h_high <= height - 1 && w_high <= width - 1) { + const int ptr4 = h_high_ptr_offset + w_high_ptr_offset + base_ptr; + v4 = bottom_data[ptr4]; + grad_h_weight += lw * v4; + grad_w_weight += lh * v4; + atomicAdd(grad_value + ptr4, w4 * top_grad_value); + } + + const scalar_t val = (w1 * v1 + w2 * v2 + w3 * v3 + w4 * v4); + *grad_attn_weight = top_grad * val; + *grad_sampling_loc = width * grad_w_weight * top_grad_value; + *(grad_sampling_loc + 1) = height * grad_h_weight * top_grad_value; +} + +template +__device__ void ms_deform_attn_col2im_bilinear_gm( + const scalar_t*& bottom_data, + const int& height, + const int& width, + const int& nheads, + const int& channels, + const scalar_t& h, + const scalar_t& w, + const int& m, + const int& c, + const scalar_t& top_grad, + const scalar_t& attn_weight, + scalar_t*& grad_value, + scalar_t* grad_sampling_loc, + scalar_t* grad_attn_weight) { + const int h_low = floor(h); + const int w_low = floor(w); + const int h_high = h_low + 1; + const int w_high = w_low + 1; + + const scalar_t lh = h - h_low; + const scalar_t lw = w - w_low; + const scalar_t hh = 1 - lh, hw = 1 - lw; + + const int w_stride = nheads * channels; + const int h_stride = width * w_stride; + const int h_low_ptr_offset = h_low * h_stride; + const int h_high_ptr_offset = h_low_ptr_offset + h_stride; + const int w_low_ptr_offset = w_low * w_stride; + const int w_high_ptr_offset = w_low_ptr_offset + w_stride; + const int base_ptr = m * channels + c; + + const scalar_t w1 = hh * hw, w2 = hh * lw, w3 = lh * hw, w4 = lh * lw; + const scalar_t top_grad_value = top_grad * attn_weight; + scalar_t grad_h_weight = 0, grad_w_weight = 0; + + scalar_t v1 = 0; + if (h_low >= 0 && w_low >= 0) { + const int ptr1 = h_low_ptr_offset + w_low_ptr_offset + base_ptr; + v1 = bottom_data[ptr1]; + grad_h_weight -= hw * v1; + grad_w_weight -= hh * v1; + atomicAdd(grad_value + ptr1, w1 * top_grad_value); + } + scalar_t v2 = 0; + if (h_low >= 0 && w_high <= width - 1) { + const int ptr2 = h_low_ptr_offset + w_high_ptr_offset + base_ptr; + v2 = bottom_data[ptr2]; + grad_h_weight -= lw * v2; + grad_w_weight += hh * v2; + atomicAdd(grad_value + ptr2, w2 * top_grad_value); + } + scalar_t v3 = 0; + if (h_high <= height - 1 && w_low >= 0) { + const int ptr3 = h_high_ptr_offset + w_low_ptr_offset + base_ptr; + v3 = bottom_data[ptr3]; + grad_h_weight += hw * v3; + grad_w_weight -= lh * v3; + atomicAdd(grad_value + ptr3, w3 * top_grad_value); + } + scalar_t v4 = 0; + if (h_high <= height - 1 && w_high <= width - 1) { + const int ptr4 = h_high_ptr_offset + w_high_ptr_offset + base_ptr; + v4 = bottom_data[ptr4]; + grad_h_weight += lw * v4; + grad_w_weight += lh * v4; + atomicAdd(grad_value + ptr4, w4 * 
top_grad_value); + } + + const scalar_t val = (w1 * v1 + w2 * v2 + w3 * v3 + w4 * v4); + atomicAdd(grad_attn_weight, top_grad * val); + atomicAdd(grad_sampling_loc, width * grad_w_weight * top_grad_value); + atomicAdd(grad_sampling_loc + 1, height * grad_h_weight * top_grad_value); +} + +template +__global__ void ms_deformable_im2col_gpu_kernel( + const int n, + const scalar_t* data_value, + const int64_t* data_spatial_shapes, + const int64_t* data_level_start_index, + const scalar_t* data_sampling_loc, + const scalar_t* data_attn_weight, + const int batch_size, + const int spatial_size, + const int num_heads, + const int channels, + const int num_levels, + const int num_query, + const int num_point, + scalar_t* data_col) { + CUDA_KERNEL_LOOP(index, n) { + int _temp = index; + const int c_col = _temp % channels; + _temp /= channels; + const int sampling_index = _temp; + const int m_col = _temp % num_heads; + _temp /= num_heads; + _temp /= num_query; + const int b_col = _temp; + + scalar_t* data_col_ptr = data_col + index; + int data_weight_ptr = sampling_index * num_levels * num_point; + int data_loc_w_ptr = data_weight_ptr << 1; + const int qid_stride = num_heads * channels; + const int data_value_ptr_init_offset = b_col * spatial_size * qid_stride; + scalar_t col = 0; + + for (int l_col = 0; l_col < num_levels; ++l_col) { + const int level_start_id = data_level_start_index[l_col]; + const int spatial_h_ptr = l_col << 1; + const int spatial_h = data_spatial_shapes[spatial_h_ptr]; + const int spatial_w = data_spatial_shapes[spatial_h_ptr + 1]; + const scalar_t* data_value_ptr = data_value + + (data_value_ptr_init_offset + level_start_id * qid_stride); + for (int p_col = 0; p_col < num_point; ++p_col) { + const scalar_t loc_w = data_sampling_loc[data_loc_w_ptr]; + const scalar_t loc_h = data_sampling_loc[data_loc_w_ptr + 1]; + const scalar_t weight = data_attn_weight[data_weight_ptr]; + + const scalar_t h_im = loc_h * spatial_h - 0.5; + const scalar_t w_im = loc_w * spatial_w - 0.5; + + if (h_im > -1 && w_im > -1 && h_im < spatial_h && w_im < spatial_w) { + col += ms_deform_attn_im2col_bilinear( + data_value_ptr, + spatial_h, + spatial_w, + num_heads, + channels, + h_im, + w_im, + m_col, + c_col) * + weight; + } + + data_weight_ptr += 1; + data_loc_w_ptr += 2; + } + } + *data_col_ptr = col; + } +} + +template +__global__ void ms_deformable_col2im_gpu_kernel_shm_blocksize_aware_reduce_v1( + const int n, + const scalar_t* grad_col, + const scalar_t* data_value, + const int64_t* data_spatial_shapes, + const int64_t* data_level_start_index, + const scalar_t* data_sampling_loc, + const scalar_t* data_attn_weight, + const int batch_size, + const int spatial_size, + const int num_heads, + const int channels, + const int num_levels, + const int num_query, + const int num_point, + scalar_t* grad_value, + scalar_t* grad_sampling_loc, + scalar_t* grad_attn_weight) { + CUDA_KERNEL_LOOP(index, n) { + __shared__ scalar_t cache_grad_sampling_loc[blockSize * 2]; + __shared__ scalar_t cache_grad_attn_weight[blockSize]; + unsigned int tid = threadIdx.x; + int _temp = index; + const int c_col = _temp % channels; + _temp /= channels; + const int sampling_index = _temp; + const int m_col = _temp % num_heads; + _temp /= num_heads; + _temp /= num_query; + const int b_col = _temp; + + const scalar_t top_grad = grad_col[index]; + + int data_weight_ptr = sampling_index * num_levels * num_point; + int data_loc_w_ptr = data_weight_ptr << 1; + const int grad_sampling_ptr = data_weight_ptr; + grad_sampling_loc += 
grad_sampling_ptr << 1; + grad_attn_weight += grad_sampling_ptr; + const int grad_weight_stride = 1; + const int grad_loc_stride = 2; + const int qid_stride = num_heads * channels; + const int data_value_ptr_init_offset = b_col * spatial_size * qid_stride; + + for (int l_col = 0; l_col < num_levels; ++l_col) { + const int level_start_id = data_level_start_index[l_col]; + const int spatial_h_ptr = l_col << 1; + const int spatial_h = data_spatial_shapes[spatial_h_ptr]; + const int spatial_w = data_spatial_shapes[spatial_h_ptr + 1]; + const int value_ptr_offset = + data_value_ptr_init_offset + level_start_id * qid_stride; + const scalar_t* data_value_ptr = data_value + value_ptr_offset; + scalar_t* grad_value_ptr = grad_value + value_ptr_offset; + + for (int p_col = 0; p_col < num_point; ++p_col) { + const scalar_t loc_w = data_sampling_loc[data_loc_w_ptr]; + const scalar_t loc_h = data_sampling_loc[data_loc_w_ptr + 1]; + const scalar_t weight = data_attn_weight[data_weight_ptr]; + + const scalar_t h_im = loc_h * spatial_h - 0.5; + const scalar_t w_im = loc_w * spatial_w - 0.5; + *(cache_grad_sampling_loc + (threadIdx.x << 1)) = 0; + *(cache_grad_sampling_loc + ((threadIdx.x << 1) + 1)) = 0; + *(cache_grad_attn_weight + threadIdx.x) = 0; + if (h_im > -1 && w_im > -1 && h_im < spatial_h && w_im < spatial_w) { + ms_deform_attn_col2im_bilinear( + data_value_ptr, + spatial_h, + spatial_w, + num_heads, + channels, + h_im, + w_im, + m_col, + c_col, + top_grad, + weight, + grad_value_ptr, + cache_grad_sampling_loc + (threadIdx.x << 1), + cache_grad_attn_weight + threadIdx.x); + } + + __syncthreads(); + if (tid == 0) { + scalar_t _grad_w = cache_grad_sampling_loc[0], + _grad_h = cache_grad_sampling_loc[1], + _grad_a = cache_grad_attn_weight[0]; + int sid = 2; + for (unsigned int tid = 1; tid < blockSize; ++tid) { + _grad_w += cache_grad_sampling_loc[sid]; + _grad_h += cache_grad_sampling_loc[sid + 1]; + _grad_a += cache_grad_attn_weight[tid]; + sid += 2; + } + + *grad_sampling_loc = _grad_w; + *(grad_sampling_loc + 1) = _grad_h; + *grad_attn_weight = _grad_a; + } + __syncthreads(); + + data_weight_ptr += 1; + data_loc_w_ptr += 2; + grad_attn_weight += grad_weight_stride; + grad_sampling_loc += grad_loc_stride; + } + } + } +} + +template +__global__ void ms_deformable_col2im_gpu_kernel_shm_blocksize_aware_reduce_v2( + const int n, + const scalar_t* grad_col, + const scalar_t* data_value, + const int64_t* data_spatial_shapes, + const int64_t* data_level_start_index, + const scalar_t* data_sampling_loc, + const scalar_t* data_attn_weight, + const int batch_size, + const int spatial_size, + const int num_heads, + const int channels, + const int num_levels, + const int num_query, + const int num_point, + scalar_t* grad_value, + scalar_t* grad_sampling_loc, + scalar_t* grad_attn_weight) { + CUDA_KERNEL_LOOP(index, n) { + __shared__ scalar_t cache_grad_sampling_loc[blockSize * 2]; + __shared__ scalar_t cache_grad_attn_weight[blockSize]; + unsigned int tid = threadIdx.x; + int _temp = index; + const int c_col = _temp % channels; + _temp /= channels; + const int sampling_index = _temp; + const int m_col = _temp % num_heads; + _temp /= num_heads; + _temp /= num_query; + const int b_col = _temp; + + const scalar_t top_grad = grad_col[index]; + + int data_weight_ptr = sampling_index * num_levels * num_point; + int data_loc_w_ptr = data_weight_ptr << 1; + const int grad_sampling_ptr = data_weight_ptr; + grad_sampling_loc += grad_sampling_ptr << 1; + grad_attn_weight += grad_sampling_ptr; + const int 
grad_weight_stride = 1; + const int grad_loc_stride = 2; + const int qid_stride = num_heads * channels; + const int data_value_ptr_init_offset = b_col * spatial_size * qid_stride; + + for (int l_col = 0; l_col < num_levels; ++l_col) { + const int level_start_id = data_level_start_index[l_col]; + const int spatial_h_ptr = l_col << 1; + const int spatial_h = data_spatial_shapes[spatial_h_ptr]; + const int spatial_w = data_spatial_shapes[spatial_h_ptr + 1]; + const int value_ptr_offset = + data_value_ptr_init_offset + level_start_id * qid_stride; + const scalar_t* data_value_ptr = data_value + value_ptr_offset; + scalar_t* grad_value_ptr = grad_value + value_ptr_offset; + + for (int p_col = 0; p_col < num_point; ++p_col) { + const scalar_t loc_w = data_sampling_loc[data_loc_w_ptr]; + const scalar_t loc_h = data_sampling_loc[data_loc_w_ptr + 1]; + const scalar_t weight = data_attn_weight[data_weight_ptr]; + + const scalar_t h_im = loc_h * spatial_h - 0.5; + const scalar_t w_im = loc_w * spatial_w - 0.5; + *(cache_grad_sampling_loc + (threadIdx.x << 1)) = 0; + *(cache_grad_sampling_loc + ((threadIdx.x << 1) + 1)) = 0; + *(cache_grad_attn_weight + threadIdx.x) = 0; + if (h_im > -1 && w_im > -1 && h_im < spatial_h && w_im < spatial_w) { + ms_deform_attn_col2im_bilinear( + data_value_ptr, + spatial_h, + spatial_w, + num_heads, + channels, + h_im, + w_im, + m_col, + c_col, + top_grad, + weight, + grad_value_ptr, + cache_grad_sampling_loc + (threadIdx.x << 1), + cache_grad_attn_weight + threadIdx.x); + } + + __syncthreads(); + + for (unsigned int s = blockSize / 2; s > 0; s >>= 1) { + if (tid < s) { + const unsigned int xid1 = tid << 1; + const unsigned int xid2 = (tid + s) << 1; + cache_grad_attn_weight[tid] += cache_grad_attn_weight[tid + s]; + cache_grad_sampling_loc[xid1] += cache_grad_sampling_loc[xid2]; + cache_grad_sampling_loc[xid1 + 1] += + cache_grad_sampling_loc[xid2 + 1]; + } + __syncthreads(); + } + + if (tid == 0) { + *grad_sampling_loc = cache_grad_sampling_loc[0]; + *(grad_sampling_loc + 1) = cache_grad_sampling_loc[1]; + *grad_attn_weight = cache_grad_attn_weight[0]; + } + __syncthreads(); + + data_weight_ptr += 1; + data_loc_w_ptr += 2; + grad_attn_weight += grad_weight_stride; + grad_sampling_loc += grad_loc_stride; + } + } + } +} + +template +__global__ void ms_deformable_col2im_gpu_kernel_shm_reduce_v1( + const int n, + const scalar_t* grad_col, + const scalar_t* data_value, + const int64_t* data_spatial_shapes, + const int64_t* data_level_start_index, + const scalar_t* data_sampling_loc, + const scalar_t* data_attn_weight, + const int batch_size, + const int spatial_size, + const int num_heads, + const int channels, + const int num_levels, + const int num_query, + const int num_point, + scalar_t* grad_value, + scalar_t* grad_sampling_loc, + scalar_t* grad_attn_weight) { + CUDA_KERNEL_LOOP(index, n) { + extern __shared__ int _s[]; + scalar_t* cache_grad_sampling_loc = (scalar_t*)_s; + scalar_t* cache_grad_attn_weight = cache_grad_sampling_loc + 2 * blockDim.x; + unsigned int tid = threadIdx.x; + int _temp = index; + const int c_col = _temp % channels; + _temp /= channels; + const int sampling_index = _temp; + const int m_col = _temp % num_heads; + _temp /= num_heads; + _temp /= num_query; + const int b_col = _temp; + + const scalar_t top_grad = grad_col[index]; + + int data_weight_ptr = sampling_index * num_levels * num_point; + int data_loc_w_ptr = data_weight_ptr << 1; + const int grad_sampling_ptr = data_weight_ptr; + grad_sampling_loc += grad_sampling_ptr << 1; + 
grad_attn_weight += grad_sampling_ptr; + const int grad_weight_stride = 1; + const int grad_loc_stride = 2; + const int qid_stride = num_heads * channels; + const int data_value_ptr_init_offset = b_col * spatial_size * qid_stride; + + for (int l_col = 0; l_col < num_levels; ++l_col) { + const int level_start_id = data_level_start_index[l_col]; + const int spatial_h_ptr = l_col << 1; + const int spatial_h = data_spatial_shapes[spatial_h_ptr]; + const int spatial_w = data_spatial_shapes[spatial_h_ptr + 1]; + const int value_ptr_offset = + data_value_ptr_init_offset + level_start_id * qid_stride; + const scalar_t* data_value_ptr = data_value + value_ptr_offset; + scalar_t* grad_value_ptr = grad_value + value_ptr_offset; + + for (int p_col = 0; p_col < num_point; ++p_col) { + const scalar_t loc_w = data_sampling_loc[data_loc_w_ptr]; + const scalar_t loc_h = data_sampling_loc[data_loc_w_ptr + 1]; + const scalar_t weight = data_attn_weight[data_weight_ptr]; + + const scalar_t h_im = loc_h * spatial_h - 0.5; + const scalar_t w_im = loc_w * spatial_w - 0.5; + *(cache_grad_sampling_loc + (threadIdx.x << 1)) = 0; + *(cache_grad_sampling_loc + ((threadIdx.x << 1) + 1)) = 0; + *(cache_grad_attn_weight + threadIdx.x) = 0; + if (h_im > -1 && w_im > -1 && h_im < spatial_h && w_im < spatial_w) { + ms_deform_attn_col2im_bilinear( + data_value_ptr, + spatial_h, + spatial_w, + num_heads, + channels, + h_im, + w_im, + m_col, + c_col, + top_grad, + weight, + grad_value_ptr, + cache_grad_sampling_loc + (threadIdx.x << 1), + cache_grad_attn_weight + threadIdx.x); + } + + __syncthreads(); + if (tid == 0) { + scalar_t _grad_w = cache_grad_sampling_loc[0], + _grad_h = cache_grad_sampling_loc[1], + _grad_a = cache_grad_attn_weight[0]; + int sid = 2; + for (unsigned int tid = 1; tid < blockDim.x; ++tid) { + _grad_w += cache_grad_sampling_loc[sid]; + _grad_h += cache_grad_sampling_loc[sid + 1]; + _grad_a += cache_grad_attn_weight[tid]; + sid += 2; + } + + *grad_sampling_loc = _grad_w; + *(grad_sampling_loc + 1) = _grad_h; + *grad_attn_weight = _grad_a; + } + __syncthreads(); + + data_weight_ptr += 1; + data_loc_w_ptr += 2; + grad_attn_weight += grad_weight_stride; + grad_sampling_loc += grad_loc_stride; + } + } + } +} + +template +__global__ void ms_deformable_col2im_gpu_kernel_shm_reduce_v2( + const int n, + const scalar_t* grad_col, + const scalar_t* data_value, + const int64_t* data_spatial_shapes, + const int64_t* data_level_start_index, + const scalar_t* data_sampling_loc, + const scalar_t* data_attn_weight, + const int batch_size, + const int spatial_size, + const int num_heads, + const int channels, + const int num_levels, + const int num_query, + const int num_point, + scalar_t* grad_value, + scalar_t* grad_sampling_loc, + scalar_t* grad_attn_weight) { + CUDA_KERNEL_LOOP(index, n) { + extern __shared__ int _s[]; + scalar_t* cache_grad_sampling_loc = (scalar_t*)_s; + scalar_t* cache_grad_attn_weight = cache_grad_sampling_loc + 2 * blockDim.x; + unsigned int tid = threadIdx.x; + int _temp = index; + const int c_col = _temp % channels; + _temp /= channels; + const int sampling_index = _temp; + const int m_col = _temp % num_heads; + _temp /= num_heads; + _temp /= num_query; + const int b_col = _temp; + + const scalar_t top_grad = grad_col[index]; + + int data_weight_ptr = sampling_index * num_levels * num_point; + int data_loc_w_ptr = data_weight_ptr << 1; + const int grad_sampling_ptr = data_weight_ptr; + grad_sampling_loc += grad_sampling_ptr << 1; + grad_attn_weight += grad_sampling_ptr; + const int 
grad_weight_stride = 1; + const int grad_loc_stride = 2; + const int qid_stride = num_heads * channels; + const int data_value_ptr_init_offset = b_col * spatial_size * qid_stride; + + for (int l_col = 0; l_col < num_levels; ++l_col) { + const int level_start_id = data_level_start_index[l_col]; + const int spatial_h_ptr = l_col << 1; + const int spatial_h = data_spatial_shapes[spatial_h_ptr]; + const int spatial_w = data_spatial_shapes[spatial_h_ptr + 1]; + const int value_ptr_offset = + data_value_ptr_init_offset + level_start_id * qid_stride; + const scalar_t* data_value_ptr = data_value + value_ptr_offset; + scalar_t* grad_value_ptr = grad_value + value_ptr_offset; + + for (int p_col = 0; p_col < num_point; ++p_col) { + const scalar_t loc_w = data_sampling_loc[data_loc_w_ptr]; + const scalar_t loc_h = data_sampling_loc[data_loc_w_ptr + 1]; + const scalar_t weight = data_attn_weight[data_weight_ptr]; + + const scalar_t h_im = loc_h * spatial_h - 0.5; + const scalar_t w_im = loc_w * spatial_w - 0.5; + *(cache_grad_sampling_loc + (threadIdx.x << 1)) = 0; + *(cache_grad_sampling_loc + ((threadIdx.x << 1) + 1)) = 0; + *(cache_grad_attn_weight + threadIdx.x) = 0; + if (h_im > -1 && w_im > -1 && h_im < spatial_h && w_im < spatial_w) { + ms_deform_attn_col2im_bilinear( + data_value_ptr, + spatial_h, + spatial_w, + num_heads, + channels, + h_im, + w_im, + m_col, + c_col, + top_grad, + weight, + grad_value_ptr, + cache_grad_sampling_loc + (threadIdx.x << 1), + cache_grad_attn_weight + threadIdx.x); + } + + __syncthreads(); + + for (unsigned int s = blockDim.x / 2, spre = blockDim.x; s > 0; + s >>= 1, spre >>= 1) { + if (tid < s) { + const unsigned int xid1 = tid << 1; + const unsigned int xid2 = (tid + s) << 1; + cache_grad_attn_weight[tid] += cache_grad_attn_weight[tid + s]; + cache_grad_sampling_loc[xid1] += cache_grad_sampling_loc[xid2]; + cache_grad_sampling_loc[xid1 + 1] += + cache_grad_sampling_loc[xid2 + 1]; + if (tid + (s << 1) < spre) { + cache_grad_attn_weight[tid] += + cache_grad_attn_weight[tid + (s << 1)]; + cache_grad_sampling_loc[xid1] += + cache_grad_sampling_loc[xid2 + (s << 1)]; + cache_grad_sampling_loc[xid1 + 1] += + cache_grad_sampling_loc[xid2 + 1 + (s << 1)]; + } + } + __syncthreads(); + } + + if (tid == 0) { + *grad_sampling_loc = cache_grad_sampling_loc[0]; + *(grad_sampling_loc + 1) = cache_grad_sampling_loc[1]; + *grad_attn_weight = cache_grad_attn_weight[0]; + } + __syncthreads(); + + data_weight_ptr += 1; + data_loc_w_ptr += 2; + grad_attn_weight += grad_weight_stride; + grad_sampling_loc += grad_loc_stride; + } + } + } +} + +template +__global__ void ms_deformable_col2im_gpu_kernel_shm_reduce_v2_multi_blocks( + const int n, + const scalar_t* grad_col, + const scalar_t* data_value, + const int64_t* data_spatial_shapes, + const int64_t* data_level_start_index, + const scalar_t* data_sampling_loc, + const scalar_t* data_attn_weight, + const int batch_size, + const int spatial_size, + const int num_heads, + const int channels, + const int num_levels, + const int num_query, + const int num_point, + scalar_t* grad_value, + scalar_t* grad_sampling_loc, + scalar_t* grad_attn_weight) { + CUDA_KERNEL_LOOP(index, n) { + extern __shared__ int _s[]; + scalar_t* cache_grad_sampling_loc = (scalar_t*)_s; + scalar_t* cache_grad_attn_weight = cache_grad_sampling_loc + 2 * blockDim.x; + unsigned int tid = threadIdx.x; + int _temp = index; + const int c_col = _temp % channels; + _temp /= channels; + const int sampling_index = _temp; + const int m_col = _temp % num_heads; + _temp /= 
num_heads; + _temp /= num_query; + const int b_col = _temp; + + const scalar_t top_grad = grad_col[index]; + + int data_weight_ptr = sampling_index * num_levels * num_point; + int data_loc_w_ptr = data_weight_ptr << 1; + const int grad_sampling_ptr = data_weight_ptr; + grad_sampling_loc += grad_sampling_ptr << 1; + grad_attn_weight += grad_sampling_ptr; + const int grad_weight_stride = 1; + const int grad_loc_stride = 2; + const int qid_stride = num_heads * channels; + const int data_value_ptr_init_offset = b_col * spatial_size * qid_stride; + + for (int l_col = 0; l_col < num_levels; ++l_col) { + const int level_start_id = data_level_start_index[l_col]; + const int spatial_h_ptr = l_col << 1; + const int spatial_h = data_spatial_shapes[spatial_h_ptr]; + const int spatial_w = data_spatial_shapes[spatial_h_ptr + 1]; + const int value_ptr_offset = + data_value_ptr_init_offset + level_start_id * qid_stride; + const scalar_t* data_value_ptr = data_value + value_ptr_offset; + scalar_t* grad_value_ptr = grad_value + value_ptr_offset; + + for (int p_col = 0; p_col < num_point; ++p_col) { + const scalar_t loc_w = data_sampling_loc[data_loc_w_ptr]; + const scalar_t loc_h = data_sampling_loc[data_loc_w_ptr + 1]; + const scalar_t weight = data_attn_weight[data_weight_ptr]; + + const scalar_t h_im = loc_h * spatial_h - 0.5; + const scalar_t w_im = loc_w * spatial_w - 0.5; + *(cache_grad_sampling_loc + (threadIdx.x << 1)) = 0; + *(cache_grad_sampling_loc + ((threadIdx.x << 1) + 1)) = 0; + *(cache_grad_attn_weight + threadIdx.x) = 0; + if (h_im > -1 && w_im > -1 && h_im < spatial_h && w_im < spatial_w) { + ms_deform_attn_col2im_bilinear( + data_value_ptr, + spatial_h, + spatial_w, + num_heads, + channels, + h_im, + w_im, + m_col, + c_col, + top_grad, + weight, + grad_value_ptr, + cache_grad_sampling_loc + (threadIdx.x << 1), + cache_grad_attn_weight + threadIdx.x); + } + + __syncthreads(); + + for (unsigned int s = blockDim.x / 2, spre = blockDim.x; s > 0; + s >>= 1, spre >>= 1) { + if (tid < s) { + const unsigned int xid1 = tid << 1; + const unsigned int xid2 = (tid + s) << 1; + cache_grad_attn_weight[tid] += cache_grad_attn_weight[tid + s]; + cache_grad_sampling_loc[xid1] += cache_grad_sampling_loc[xid2]; + cache_grad_sampling_loc[xid1 + 1] += + cache_grad_sampling_loc[xid2 + 1]; + if (tid + (s << 1) < spre) { + cache_grad_attn_weight[tid] += + cache_grad_attn_weight[tid + (s << 1)]; + cache_grad_sampling_loc[xid1] += + cache_grad_sampling_loc[xid2 + (s << 1)]; + cache_grad_sampling_loc[xid1 + 1] += + cache_grad_sampling_loc[xid2 + 1 + (s << 1)]; + } + } + __syncthreads(); + } + + if (tid == 0) { + atomicAdd(grad_sampling_loc, cache_grad_sampling_loc[0]); + atomicAdd(grad_sampling_loc + 1, cache_grad_sampling_loc[1]); + atomicAdd(grad_attn_weight, cache_grad_attn_weight[0]); + } + __syncthreads(); + + data_weight_ptr += 1; + data_loc_w_ptr += 2; + grad_attn_weight += grad_weight_stride; + grad_sampling_loc += grad_loc_stride; + } + } + } +} + +template +__global__ void ms_deformable_col2im_gpu_kernel_gm( + const int n, + const scalar_t* grad_col, + const scalar_t* data_value, + const int64_t* data_spatial_shapes, + const int64_t* data_level_start_index, + const scalar_t* data_sampling_loc, + const scalar_t* data_attn_weight, + const int batch_size, + const int spatial_size, + const int num_heads, + const int channels, + const int num_levels, + const int num_query, + const int num_point, + scalar_t* grad_value, + scalar_t* grad_sampling_loc, + scalar_t* grad_attn_weight) { + CUDA_KERNEL_LOOP(index, n) 
{ + int _temp = index; + const int c_col = _temp % channels; + _temp /= channels; + const int sampling_index = _temp; + const int m_col = _temp % num_heads; + _temp /= num_heads; + _temp /= num_query; + const int b_col = _temp; + + const scalar_t top_grad = grad_col[index]; + + int data_weight_ptr = sampling_index * num_levels * num_point; + int data_loc_w_ptr = data_weight_ptr << 1; + const int grad_sampling_ptr = data_weight_ptr; + grad_sampling_loc += grad_sampling_ptr << 1; + grad_attn_weight += grad_sampling_ptr; + const int grad_weight_stride = 1; + const int grad_loc_stride = 2; + const int qid_stride = num_heads * channels; + const int data_value_ptr_init_offset = b_col * spatial_size * qid_stride; + + for (int l_col = 0; l_col < num_levels; ++l_col) { + const int level_start_id = data_level_start_index[l_col]; + const int spatial_h_ptr = l_col << 1; + const int spatial_h = data_spatial_shapes[spatial_h_ptr]; + const int spatial_w = data_spatial_shapes[spatial_h_ptr + 1]; + const int value_ptr_offset = + data_value_ptr_init_offset + level_start_id * qid_stride; + const scalar_t* data_value_ptr = data_value + value_ptr_offset; + scalar_t* grad_value_ptr = grad_value + value_ptr_offset; + + for (int p_col = 0; p_col < num_point; ++p_col) { + const scalar_t loc_w = data_sampling_loc[data_loc_w_ptr]; + const scalar_t loc_h = data_sampling_loc[data_loc_w_ptr + 1]; + const scalar_t weight = data_attn_weight[data_weight_ptr]; + + const scalar_t h_im = loc_h * spatial_h - 0.5; + const scalar_t w_im = loc_w * spatial_w - 0.5; + if (h_im > -1 && w_im > -1 && h_im < spatial_h && w_im < spatial_w) { + ms_deform_attn_col2im_bilinear_gm( + data_value_ptr, + spatial_h, + spatial_w, + num_heads, + channels, + h_im, + w_im, + m_col, + c_col, + top_grad, + weight, + grad_value_ptr, + grad_sampling_loc, + grad_attn_weight); + } + data_weight_ptr += 1; + data_loc_w_ptr += 2; + grad_attn_weight += grad_weight_stride; + grad_sampling_loc += grad_loc_stride; + } + } + } +} + +template +void ms_deformable_im2col_cuda( + cudaStream_t stream, + const scalar_t* data_value, + const int64_t* data_spatial_shapes, + const int64_t* data_level_start_index, + const scalar_t* data_sampling_loc, + const scalar_t* data_attn_weight, + const int batch_size, + const int spatial_size, + const int num_heads, + const int channels, + const int num_levels, + const int num_query, + const int num_point, + scalar_t* data_col) { + const int num_kernels = batch_size * num_query * num_heads * channels; + const int num_actual_kernels = batch_size * num_query * num_heads * channels; + const int num_threads = CUDA_NUM_THREADS; + ms_deformable_im2col_gpu_kernel + <<>>( + num_kernels, + data_value, + data_spatial_shapes, + data_level_start_index, + data_sampling_loc, + data_attn_weight, + batch_size, + spatial_size, + num_heads, + channels, + num_levels, + num_query, + num_point, + data_col); + + cudaError_t err = cudaGetLastError(); + if (err != cudaSuccess) { + printf("error in ms_deformable_im2col_cuda: %s\n", cudaGetErrorString(err)); + } +} + +template +void ms_deformable_col2im_cuda( + cudaStream_t stream, + const scalar_t* grad_col, + const scalar_t* data_value, + const int64_t* data_spatial_shapes, + const int64_t* data_level_start_index, + const scalar_t* data_sampling_loc, + const scalar_t* data_attn_weight, + const int batch_size, + const int spatial_size, + const int num_heads, + const int channels, + const int num_levels, + const int num_query, + const int num_point, + scalar_t* grad_value, + scalar_t* grad_sampling_loc, + 
scalar_t* grad_attn_weight) { + const int num_threads = + (channels > CUDA_NUM_THREADS) ? CUDA_NUM_THREADS : channels; + const int num_kernels = batch_size * num_query * num_heads * channels; + const int num_actual_kernels = batch_size * num_query * num_heads * channels; + if (channels > 1024) { + if ((channels & 1023) == 0) { + ms_deformable_col2im_gpu_kernel_shm_reduce_v2_multi_blocks + <<>>( + num_kernels, + grad_col, + data_value, + data_spatial_shapes, + data_level_start_index, + data_sampling_loc, + data_attn_weight, + batch_size, + spatial_size, + num_heads, + channels, + num_levels, + num_query, + num_point, + grad_value, + grad_sampling_loc, + grad_attn_weight); + } else { + ms_deformable_col2im_gpu_kernel_gm + <<>>( + num_kernels, + grad_col, + data_value, + data_spatial_shapes, + data_level_start_index, + data_sampling_loc, + data_attn_weight, + batch_size, + spatial_size, + num_heads, + channels, + num_levels, + num_query, + num_point, + grad_value, + grad_sampling_loc, + grad_attn_weight); + } + } else { + switch (channels) { + case 1: + ms_deformable_col2im_gpu_kernel_shm_blocksize_aware_reduce_v1< + scalar_t, + 1> + <<>>( + num_kernels, + grad_col, + data_value, + data_spatial_shapes, + data_level_start_index, + data_sampling_loc, + data_attn_weight, + batch_size, + spatial_size, + num_heads, + channels, + num_levels, + num_query, + num_point, + grad_value, + grad_sampling_loc, + grad_attn_weight); + break; + case 2: + ms_deformable_col2im_gpu_kernel_shm_blocksize_aware_reduce_v1< + scalar_t, + 2> + <<>>( + num_kernels, + grad_col, + data_value, + data_spatial_shapes, + data_level_start_index, + data_sampling_loc, + data_attn_weight, + batch_size, + spatial_size, + num_heads, + channels, + num_levels, + num_query, + num_point, + grad_value, + grad_sampling_loc, + grad_attn_weight); + break; + case 4: + ms_deformable_col2im_gpu_kernel_shm_blocksize_aware_reduce_v1< + scalar_t, + 4> + <<>>( + num_kernels, + grad_col, + data_value, + data_spatial_shapes, + data_level_start_index, + data_sampling_loc, + data_attn_weight, + batch_size, + spatial_size, + num_heads, + channels, + num_levels, + num_query, + num_point, + grad_value, + grad_sampling_loc, + grad_attn_weight); + break; + case 8: + ms_deformable_col2im_gpu_kernel_shm_blocksize_aware_reduce_v1< + scalar_t, + 8> + <<>>( + num_kernels, + grad_col, + data_value, + data_spatial_shapes, + data_level_start_index, + data_sampling_loc, + data_attn_weight, + batch_size, + spatial_size, + num_heads, + channels, + num_levels, + num_query, + num_point, + grad_value, + grad_sampling_loc, + grad_attn_weight); + break; + case 16: + ms_deformable_col2im_gpu_kernel_shm_blocksize_aware_reduce_v1< + scalar_t, + 16> + <<>>( + num_kernels, + grad_col, + data_value, + data_spatial_shapes, + data_level_start_index, + data_sampling_loc, + data_attn_weight, + batch_size, + spatial_size, + num_heads, + channels, + num_levels, + num_query, + num_point, + grad_value, + grad_sampling_loc, + grad_attn_weight); + break; + case 32: + ms_deformable_col2im_gpu_kernel_shm_blocksize_aware_reduce_v1< + scalar_t, + 32> + <<>>( + num_kernels, + grad_col, + data_value, + data_spatial_shapes, + data_level_start_index, + data_sampling_loc, + data_attn_weight, + batch_size, + spatial_size, + num_heads, + channels, + num_levels, + num_query, + num_point, + grad_value, + grad_sampling_loc, + grad_attn_weight); + break; + case 64: + ms_deformable_col2im_gpu_kernel_shm_blocksize_aware_reduce_v2< + scalar_t, + 64> + <<>>( + num_kernels, + grad_col, + data_value, + 
data_spatial_shapes, + data_level_start_index, + data_sampling_loc, + data_attn_weight, + batch_size, + spatial_size, + num_heads, + channels, + num_levels, + num_query, + num_point, + grad_value, + grad_sampling_loc, + grad_attn_weight); + break; + case 128: + ms_deformable_col2im_gpu_kernel_shm_blocksize_aware_reduce_v2< + scalar_t, + 128> + <<>>( + num_kernels, + grad_col, + data_value, + data_spatial_shapes, + data_level_start_index, + data_sampling_loc, + data_attn_weight, + batch_size, + spatial_size, + num_heads, + channels, + num_levels, + num_query, + num_point, + grad_value, + grad_sampling_loc, + grad_attn_weight); + break; + case 256: + ms_deformable_col2im_gpu_kernel_shm_blocksize_aware_reduce_v2< + scalar_t, + 256> + <<>>( + num_kernels, + grad_col, + data_value, + data_spatial_shapes, + data_level_start_index, + data_sampling_loc, + data_attn_weight, + batch_size, + spatial_size, + num_heads, + channels, + num_levels, + num_query, + num_point, + grad_value, + grad_sampling_loc, + grad_attn_weight); + break; + case 512: + ms_deformable_col2im_gpu_kernel_shm_blocksize_aware_reduce_v2< + scalar_t, + 512> + <<>>( + num_kernels, + grad_col, + data_value, + data_spatial_shapes, + data_level_start_index, + data_sampling_loc, + data_attn_weight, + batch_size, + spatial_size, + num_heads, + channels, + num_levels, + num_query, + num_point, + grad_value, + grad_sampling_loc, + grad_attn_weight); + break; + case 1024: + ms_deformable_col2im_gpu_kernel_shm_blocksize_aware_reduce_v2< + scalar_t, + 1024> + <<>>( + num_kernels, + grad_col, + data_value, + data_spatial_shapes, + data_level_start_index, + data_sampling_loc, + data_attn_weight, + batch_size, + spatial_size, + num_heads, + channels, + num_levels, + num_query, + num_point, + grad_value, + grad_sampling_loc, + grad_attn_weight); + break; + default: + if (channels < 64) { + ms_deformable_col2im_gpu_kernel_shm_reduce_v1 + <<>>( + num_kernels, + grad_col, + data_value, + data_spatial_shapes, + data_level_start_index, + data_sampling_loc, + data_attn_weight, + batch_size, + spatial_size, + num_heads, + channels, + num_levels, + num_query, + num_point, + grad_value, + grad_sampling_loc, + grad_attn_weight); + } else { + ms_deformable_col2im_gpu_kernel_shm_reduce_v2 + <<>>( + num_kernels, + grad_col, + data_value, + data_spatial_shapes, + data_level_start_index, + data_sampling_loc, + data_attn_weight, + batch_size, + spatial_size, + num_heads, + channels, + num_levels, + num_query, + num_point, + grad_value, + grad_sampling_loc, + grad_attn_weight); + } + } + } + cudaError_t err = cudaGetLastError(); + if (err != cudaSuccess) { + printf("error in ms_deformable_col2im_cuda: %s\n", cudaGetErrorString(err)); + } +} + +at::Tensor ms_deform_attn_forward_kernel( + const at::Tensor& value, + const at::Tensor& spatial_shapes, + const at::Tensor& level_start_index, + const at::Tensor& sampling_loc, + const at::Tensor& attn_weight, + const int64_t im2col_step) { + // Basic tensor properties + TORCH_CHECK(value.is_contiguous(), "value tensor has to be contiguous"); + TORCH_CHECK( + spatial_shapes.is_contiguous(), + "spatial_shapes tensor has to be contiguous"); + TORCH_CHECK( + level_start_index.is_contiguous(), + "level_start_index tensor has to be contiguous"); + TORCH_CHECK( + sampling_loc.is_contiguous(), "sampling_loc tensor has to be contiguous"); + TORCH_CHECK( + attn_weight.is_contiguous(), "attn_weight tensor has to be contiguous"); + + TORCH_CHECK(value.is_cuda(), "value must be a CUDA tensor"); + TORCH_CHECK(spatial_shapes.is_cuda(), 
"spatial_shapes must be a CUDA tensor"); + TORCH_CHECK( + level_start_index.is_cuda(), "level_start_index must be a CUDA tensor"); + TORCH_CHECK(sampling_loc.is_cuda(), "sampling_loc must be a CUDA tensor"); + TORCH_CHECK(attn_weight.is_cuda(), "attn_weight must be a CUDA tensor"); + + // Rank checks + TORCH_CHECK( + value.dim() == 4, "value must be a 4D tensor of shape [B, Nv, H, C]"); + TORCH_CHECK( + spatial_shapes.dim() == 2 && spatial_shapes.size(1) == 2, + "spatial_shapes must be a 2D tensor of shape [L, 2]"); + TORCH_CHECK( + level_start_index.dim() == 1, + "level_start_index must be a 1D tensor of shape [L]"); + TORCH_CHECK( + sampling_loc.dim() == 6, + "sampling_loc must be a 6D tensor of shape [B, Nq, H, L, P, 2]"); + TORCH_CHECK( + sampling_loc.size(5) == 2, + "sampling_loc last dimension must be 2 (normalized x, y)"); + TORCH_CHECK( + attn_weight.dim() == 5, + "attn_weight must be a 5D tensor of shape [B, Nq, H, L, P]"); + + const auto batch = value.size(0); + const auto spatial_size = value.size(1); + const auto num_heads = value.size(2); + const auto channels = value.size(3); + + const auto num_levels = spatial_shapes.size(0); + + const auto num_query = sampling_loc.size(1); + const auto num_point = sampling_loc.size(4); + + // Cross-shape consistency checks + TORCH_CHECK( + batch == sampling_loc.size(0), + "batch size mismatch between value (B) and sampling_loc (B)"); + TORCH_CHECK( + num_heads == sampling_loc.size(2), + "num_heads mismatch between value (H) and sampling_loc (H)"); + TORCH_CHECK( + num_heads == attn_weight.size(2), + "num_heads mismatch between value (H) and attn_weight (H)"); + + TORCH_CHECK( + num_levels == sampling_loc.size(3), + "L mismatch: spatial_shapes.size(0) must equal sampling_loc.size(3)"); + TORCH_CHECK( + level_start_index.size(0) == num_levels, + "level_start_index.size(0) must equal spatial_shapes.size(0) (L)"); + + TORCH_CHECK( + attn_weight.size(0) == sampling_loc.size(0) && + attn_weight.size(1) == sampling_loc.size(1) && + attn_weight.size(2) == sampling_loc.size(2) && + attn_weight.size(3) == sampling_loc.size(3) && + attn_weight.size(4) == sampling_loc.size(4), + "attn_weight must match sampling_loc in [B, Nq, H, L, P]"); + + // Nv must equal sum_{l}(H_l * W_l) from spatial_shapes + const auto expected_spatial_size = + (spatial_shapes.prod(1)).sum().item(); + TORCH_CHECK( + spatial_size == expected_spatial_size, + "value.size(1) (Nv) must equal sum over levels of (H_l * W_l)"); + + // im2col_step validation + const auto im2col_step_ = std::min(batch, im2col_step); + TORCH_CHECK(im2col_step_ > 0, "im2col_step must be > 0"); + TORCH_CHECK( + batch % im2col_step_ == 0, "batch must be divisible by im2col_step"); + + auto output = + at::zeros({batch, num_query, num_heads, channels}, value.options()); + + const auto batch_n = im2col_step_; + auto output_n = output.view( + {batch / im2col_step_, batch_n, num_query, num_heads, channels}); + auto per_value_size = spatial_size * num_heads * channels; + auto per_sample_loc_size = num_query * num_heads * num_levels * num_point * 2; + auto per_attn_weight_size = num_query * num_heads * num_levels * num_point; + + for (int n = 0; n < batch / im2col_step_; ++n) { + auto columns = output_n.select(0, n); + AT_DISPATCH_FLOATING_TYPES( + value.scalar_type(), "ms_deform_attn_forward_cuda", ([&] { + ms_deformable_im2col_cuda( + at::cuda::getCurrentCUDAStream(), + value.data_ptr() + n * im2col_step_ * per_value_size, + spatial_shapes.data_ptr(), + level_start_index.data_ptr(), + sampling_loc.data_ptr() + + n 
+
+std::tuple<at::Tensor, at::Tensor, at::Tensor> ms_deform_attn_backward_kernel(
+    const at::Tensor& value,
+    const at::Tensor& spatial_shapes,
+    const at::Tensor& level_start_index,
+    const at::Tensor& sampling_loc,
+    const at::Tensor& attn_weight,
+    const at::Tensor& grad_output,
+    const int64_t im2col_step) {
+  // Basic tensor properties
+  TORCH_CHECK(value.is_contiguous(), "value tensor has to be contiguous");
+  TORCH_CHECK(
+      spatial_shapes.is_contiguous(),
+      "spatial_shapes tensor has to be contiguous");
+  TORCH_CHECK(
+      level_start_index.is_contiguous(),
+      "level_start_index tensor has to be contiguous");
+  TORCH_CHECK(
+      sampling_loc.is_contiguous(), "sampling_loc tensor has to be contiguous");
+  TORCH_CHECK(
+      attn_weight.is_contiguous(), "attn_weight tensor has to be contiguous");
+  TORCH_CHECK(
+      grad_output.is_contiguous(), "grad_output tensor has to be contiguous");
+
+  TORCH_CHECK(value.is_cuda(), "value must be a CUDA tensor");
+  TORCH_CHECK(spatial_shapes.is_cuda(), "spatial_shapes must be a CUDA tensor");
+  TORCH_CHECK(
+      level_start_index.is_cuda(), "level_start_index must be a CUDA tensor");
+  TORCH_CHECK(sampling_loc.is_cuda(), "sampling_loc must be a CUDA tensor");
+  TORCH_CHECK(attn_weight.is_cuda(), "attn_weight must be a CUDA tensor");
+  TORCH_CHECK(grad_output.is_cuda(), "grad_output must be a CUDA tensor");
+
+  // Rank checks mirroring forward
+  TORCH_CHECK(
+      value.dim() == 4, "value must be a 4D tensor of shape [B, Nv, H, C]");
+  TORCH_CHECK(
+      spatial_shapes.dim() == 2 && spatial_shapes.size(1) == 2,
+      "spatial_shapes must be a 2D tensor of shape [L, 2]");
+  TORCH_CHECK(
+      level_start_index.dim() == 1,
+      "level_start_index must be a 1D tensor of shape [L]");
+  TORCH_CHECK(
+      sampling_loc.dim() == 6,
+      "sampling_loc must be a 6D tensor of shape [B, Nq, H, L, P, 2]");
+  TORCH_CHECK(
+      sampling_loc.size(5) == 2,
+      "sampling_loc last dimension must be 2 (normalized x, y)");
+  TORCH_CHECK(
+      attn_weight.dim() == 5,
+      "attn_weight must be a 5D tensor of shape [B, Nq, H, L, P]");
+
+  const auto batch = value.size(0);
+  const auto spatial_size = value.size(1);
+  const auto num_heads = value.size(2);
+  const auto channels = value.size(3);
+
+  const auto num_levels = spatial_shapes.size(0);
+
+  const auto num_query = sampling_loc.size(1);
+  const auto num_point = sampling_loc.size(4);
+
+  // Cross-shape consistency checks
+  TORCH_CHECK(
+      batch == sampling_loc.size(0),
+      "batch size mismatch between value (B) and sampling_loc (B)");
+  TORCH_CHECK(
+      num_heads == sampling_loc.size(2),
+      "num_heads mismatch between value (H) and sampling_loc (H)");
+  TORCH_CHECK(
+      num_heads == attn_weight.size(2),
+      "num_heads mismatch between value (H) and attn_weight (H)");
+  TORCH_CHECK(
+      num_levels == sampling_loc.size(3),
+      "L mismatch: spatial_shapes.size(0) must equal sampling_loc.size(3)");
+  TORCH_CHECK(
+      level_start_index.size(0) == num_levels,
+      "level_start_index.size(0) must equal spatial_shapes.size(0) (L)");
+  TORCH_CHECK(
+      attn_weight.size(0) == sampling_loc.size(0) &&
+          attn_weight.size(1) == sampling_loc.size(1) &&
+          attn_weight.size(2) == sampling_loc.size(2) &&
+          attn_weight.size(3) == sampling_loc.size(3) &&
+          attn_weight.size(4) == sampling_loc.size(4),
+      "attn_weight must match sampling_loc in [B, Nq, H, L, P]");
+
+  // Nv must equal sum_{l}(H_l * W_l) from spatial_shapes
+  const auto expected_spatial_size =
+      (spatial_shapes.prod(1)).sum().item<int64_t>();
+  TORCH_CHECK(
+      spatial_size == expected_spatial_size,
+      "value.size(1) (Nv) must equal sum over levels of (H_l * W_l)");
+
+  // grad_output should have the same total elements as [B, Nq, H, C]
+  const auto expected_grad_elems = batch * num_query * num_heads * channels;
+  TORCH_CHECK(
+      grad_output.numel() == expected_grad_elems,
+      "grad_output must contain B * Nq * H * C elements");
+
+  // im2col_step validation
+  const auto im2col_step_ = std::min(batch, im2col_step);
+  TORCH_CHECK(im2col_step_ > 0, "im2col_step must be > 0");
+  TORCH_CHECK(
+      batch % im2col_step_ == 0, "batch must be divisible by im2col_step");
+
+  auto grad_value = at::zeros_like(value);
+  auto grad_sampling_loc = at::zeros_like(sampling_loc);
+  auto grad_attn_weight = at::zeros_like(attn_weight);
+
+  const auto batch_n = im2col_step_;
+  auto per_value_size = spatial_size * num_heads * channels;
+  auto per_sample_loc_size = num_query * num_heads * num_levels * num_point * 2;
+  auto per_attn_weight_size = num_query * num_heads * num_levels * num_point;
+  auto grad_output_n = grad_output.view(
+      {batch / im2col_step_, batch_n, num_query, num_heads, channels});
+
+  for (int n = 0; n < batch / im2col_step_; ++n) {
+    auto grad_output_g = grad_output_n.select(0, n);
+    AT_DISPATCH_FLOATING_TYPES(
+        value.scalar_type(), "ms_deform_attn_backward_cuda", ([&] {
+          ms_deformable_col2im_cuda(
+              at::cuda::getCurrentCUDAStream(),
+              grad_output_g.data_ptr<scalar_t>(),
+              value.data_ptr<scalar_t>() + n * im2col_step_ * per_value_size,
+              spatial_shapes.data_ptr<int64_t>(),
+              level_start_index.data_ptr<int64_t>(),
+              sampling_loc.data_ptr<scalar_t>() +
+                  n * im2col_step_ * per_sample_loc_size,
+              attn_weight.data_ptr<scalar_t>() +
+                  n * im2col_step_ * per_attn_weight_size,
+              batch_n,
+              spatial_size,
+              num_heads,
+              channels,
+              num_levels,
+              num_query,
+              num_point,
+              grad_value.data_ptr<scalar_t>() +
+                  n * im2col_step_ * per_value_size,
+              grad_sampling_loc.data_ptr<scalar_t>() +
+                  n * im2col_step_ * per_sample_loc_size,
+              grad_attn_weight.data_ptr<scalar_t>() +
+                  n * im2col_step_ * per_attn_weight_size);
+        }));
+  }
+
+  return {grad_value, grad_sampling_loc, grad_attn_weight};
+}
+
+} // namespace
+
+TORCH_LIBRARY_IMPL(torchvision, CUDA, m) {
+  m.impl(
+      TORCH_SELECTIVE_NAME("torchvision::deform_attn"),
+      TORCH_FN(ms_deform_attn_forward_kernel));
+  m.impl(
+      TORCH_SELECTIVE_NAME("torchvision::_deform_attn_backward"),
+      TORCH_FN(ms_deform_attn_backward_kernel));
+}
+
+} // namespace ops
+} // namespace vision
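
The backward kernel allocates its outputs with `at::zeros_like`, so the three returned gradients mirror `value`, `sampling_loc`, and `attn_weight` exactly, while `grad_output` only needs to carry `B * Nq * H * C` elements (the flattened forward result). A short sketch of this shape contract, for illustration only; the sizes below are made up:

    import torch

    # Hypothetical sizes, only to spell out the shape contract of
    # torchvision::_deform_attn_backward as checked by the kernel above.
    B, Nv, H, C, Nq, L, P = 2, 64, 4, 8, 10, 2, 4
    grad_output = torch.zeros(B, Nq, H * C)        # same layout as the forward output
    grad_value_shape = (B, Nv, H, C)               # zeros_like(value)
    grad_sampling_loc_shape = (B, Nq, H, L, P, 2)  # zeros_like(sampling_loc)
    grad_attn_weight_shape = (B, Nq, H, L, P)      # zeros_like(attn_weight)
    assert grad_output.numel() == B * Nq * H * C   # enforced by TORCH_CHECK
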
diff --git a/torchvision/csrc/ops/deform_attn.cpp b/torchvision/csrc/ops/deform_attn.cpp
new file mode 100644
index 00000000000..64e29bd4363
--- /dev/null
+++ b/torchvision/csrc/ops/deform_attn.cpp
@@ -0,0 +1,106 @@
+#include "deform_attn.h"
+
+#include <ATen/core/dispatch/Dispatcher.h>
+#include <torch/library.h>
+#include <torch/types.h>
+
+namespace vision {
+namespace ops {
+
+at::Tensor deform_attn(
+    const at::Tensor& value,
+    const at::Tensor& spatial_shapes,
+    const at::Tensor& level_start_index,
+    const at::Tensor& sampling_loc,
+    const at::Tensor& attn_weight,
+    const int64_t im2col_step) {
+  C10_LOG_API_USAGE_ONCE("torchvision.csrc.ops.deform_attn.deform_attn");
+  static auto op = c10::Dispatcher::singleton()
+                       .findSchemaOrThrow("torchvision::deform_attn", "")
+                       .typed<decltype(deform_attn)>();
+  return op.call(
+      value,
+      spatial_shapes,
+      level_start_index,
+      sampling_loc,
+      attn_weight,
+      im2col_step);
+}
+
+at::Tensor deform_attn_symint(
+    const at::Tensor& value,
+    const at::Tensor& spatial_shapes,
+    const at::Tensor& level_start_index,
+    const at::Tensor& sampling_loc,
+    const at::Tensor& attn_weight,
+    const c10::SymInt im2col_step) {
+  C10_LOG_API_USAGE_ONCE("torchvision.csrc.ops.deform_attn.deform_attn");
+  static auto op = c10::Dispatcher::singleton()
+                       .findSchemaOrThrow("torchvision::deform_attn", "")
+                       .typed<decltype(deform_attn_symint)>();
+  return op.call(
+      value,
+      spatial_shapes,
+      level_start_index,
+      sampling_loc,
+      attn_weight,
+      im2col_step);
+}
+
+namespace detail {
+
+std::tuple<at::Tensor, at::Tensor, at::Tensor> _deform_attn_backward(
+    const at::Tensor& value,
+    const at::Tensor& spatial_shapes,
+    const at::Tensor& level_start_index,
+    const at::Tensor& sampling_loc,
+    const at::Tensor& attn_weight,
+    const at::Tensor& grad_output,
+    int64_t im2col_step) {
+  static auto op =
+      c10::Dispatcher::singleton()
+          .findSchemaOrThrow("torchvision::_deform_attn_backward", "")
+          .typed<decltype(_deform_attn_backward)>();
+  return op.call(
+      value,
+      spatial_shapes,
+      level_start_index,
+      sampling_loc,
+      attn_weight,
+      grad_output,
+      im2col_step);
+}
+
+std::tuple<at::Tensor, at::Tensor, at::Tensor> _deform_attn_backward_symint(
+    const at::Tensor& value,
+    const at::Tensor& spatial_shapes,
+    const at::Tensor& level_start_index,
+    const at::Tensor& sampling_loc,
+    const at::Tensor& attn_weight,
+    const at::Tensor& grad_output,
+    c10::SymInt im2col_step) {
+  static auto op =
+      c10::Dispatcher::singleton()
+          .findSchemaOrThrow("torchvision::_deform_attn_backward", "")
+          .typed<decltype(_deform_attn_backward_symint)>();
+  return op.call(
+      value,
+      spatial_shapes,
+      level_start_index,
+      sampling_loc,
+      attn_weight,
+      grad_output,
+      im2col_step);
+}
+
+} // namespace detail
+
+TORCH_LIBRARY_FRAGMENT(torchvision, m) {
+  m.def(TORCH_SELECTIVE_SCHEMA(
+      "torchvision::deform_attn(Tensor value, Tensor spatial_shapes, Tensor level_start_index, Tensor sampling_loc, Tensor attn_weight, SymInt im2col_step) -> Tensor"));
+  m.def(TORCH_SELECTIVE_SCHEMA(
+      "torchvision::_deform_attn_backward(Tensor value, Tensor spatial_shapes, Tensor level_start_index, Tensor sampling_loc, Tensor attn_weight, Tensor grad_output, SymInt im2col_step) -> (Tensor, Tensor, Tensor)"));
+}
+
+} // namespace ops
+} // namespace vision
diff --git a/torchvision/csrc/ops/deform_attn.h b/torchvision/csrc/ops/deform_attn.h
new file mode 100644
index 00000000000..bf284099754
--- /dev/null
+++ b/torchvision/csrc/ops/deform_attn.h
@@ -0,0 +1,48 @@
+#pragma once
+
+#include <ATen/ATen.h>
+#include "../macros.h"
+
+namespace vision {
+namespace ops {
+
+VISION_API at::Tensor deform_attn(
+    const at::Tensor& value,
+    const at::Tensor& spatial_shapes,
+    const at::Tensor& level_start_index,
+    const at::Tensor& sampling_loc,
+    const at::Tensor& attn_weight,
+    const int64_t im2col_step);
+
+VISION_API at::Tensor deform_attn_symint(
+    const at::Tensor& value,
+    const at::Tensor& spatial_shapes,
+    const at::Tensor& level_start_index,
+    const at::Tensor& sampling_loc,
+    const at::Tensor& attn_weight,
+    const c10::SymInt im2col_step);
+
+namespace detail {
+
+std::tuple<at::Tensor, at::Tensor, at::Tensor> _deform_attn_backward(
+    const at::Tensor& value,
+    const at::Tensor& spatial_shapes,
+    const at::Tensor& level_start_index,
+    const at::Tensor& sampling_loc,
+    const at::Tensor& attn_weight,
+    const at::Tensor& grad_output,
+    int64_t im2col_step);
+
+std::tuple<at::Tensor, at::Tensor, at::Tensor> _deform_attn_backward_symint(
+    const at::Tensor& value,
+    const at::Tensor& spatial_shapes,
+    const at::Tensor& level_start_index,
+    const at::Tensor& sampling_loc,
+    const at::Tensor& attn_weight,
+    const at::Tensor& grad_output,
+    c10::SymInt im2col_step);
+
+} // namespace detail
+
+} // namespace ops
+} // namespace vision
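
Once the schemas in `TORCH_LIBRARY_FRAGMENT` above are registered (i.e., once the compiled extension is loaded), the ops resolve on the `torch.ops.torchvision` namespace; the Python wrapper added below simply forwards to the public one. A quick sanity check, assuming a torchvision build that includes this patch:

    import torch
    import torchvision  # importing torchvision loads the extension and registers the ops

    # Both schemas defined above should now be discoverable.
    print(torch.ops.torchvision.deform_attn)            # public forward op
    print(torch.ops.torchvision._deform_attn_backward)  # private backward op
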
diff --git a/torchvision/csrc/ops/ops.h b/torchvision/csrc/ops/ops.h
index 77995e44197..39e880a8365 100644
--- a/torchvision/csrc/ops/ops.h
+++ b/torchvision/csrc/ops/ops.h
@@ -1,5 +1,6 @@
 #pragma once
 
+#include "deform_attn.h"
 #include "deform_conv2d.h"
 #include "nms.h"
 #include "ps_roi_align.h"
diff --git a/torchvision/ops/__init__.py b/torchvision/ops/__init__.py
index 827505b842d..cd883409b4f 100644
--- a/torchvision/ops/__init__.py
+++ b/torchvision/ops/__init__.py
@@ -13,6 +13,7 @@
     remove_small_boxes,
 )
 from .ciou_loss import complete_box_iou_loss
+from .deform_attn import deform_attn
 from .deform_conv import deform_conv2d, DeformConv2d
 from .diou_loss import distance_box_iou_loss
 from .drop_block import drop_block2d, drop_block3d, DropBlock2d, DropBlock3d
@@ -34,6 +35,7 @@
     "masks_to_boxes",
     "deform_conv2d",
     "DeformConv2d",
+    "deform_attn",
     "nms",
     "batched_nms",
     "remove_small_boxes",
diff --git a/torchvision/ops/deform_attn.py b/torchvision/ops/deform_attn.py
new file mode 100644
index 00000000000..6502a2932a2
--- /dev/null
+++ b/torchvision/ops/deform_attn.py
@@ -0,0 +1,65 @@
+import torch
+from torch import Tensor
+from torchvision.extension import _assert_has_ops
+
+from ..utils import _log_api_usage_once
+
+
+def deform_attn(
+    value: Tensor,
+    spatial_shapes: Tensor,
+    level_start_index: Tensor,
+    sampling_loc: Tensor,
+    attn_weight: Tensor,
+    im2col_step: int = 64,
+) -> Tensor:
+    r"""
+    Performs Deformable Attention, as described in
+    `Deformable DETR: Deformable Transformers for End-to-End Object Detection
+    <https://arxiv.org/abs/2010.04159>`__.
+
+    Args:
+        value (Tensor[batch_size, num_value, num_heads, channels]): input value tensor
+        spatial_shapes (Tensor[num_levels, 2]): spatial shapes (H, W) for each feature level
+        level_start_index (Tensor[num_levels]): starting index of each feature level in the flattened value tensor
+        sampling_loc (Tensor[batch_size, num_query, num_heads, num_levels, num_points, 2]):
+            sampling locations in normalized coordinates [0, 1]
+        attn_weight (Tensor[batch_size, num_query, num_heads, num_levels, num_points]):
+            attention weights for each sampling point
+        im2col_step (int): step size for the im2col operation, used to reduce memory usage.
+            Default: 64
+
+    Returns:
+        Tensor[batch_size, num_query, num_heads * channels]: result of deformable attention
+
+    Examples::
+        >>> batch_size, num_query, num_heads, channels = 2, 100, 8, 32
+        >>> num_levels, num_points = 4, 4
+        >>> # Create value tensor (flattened spatial dimensions across all levels)
+        >>> num_value = 1024 + 256 + 64 + 16  # sum of H*W for each level
+        >>> value = torch.rand(batch_size, num_value, num_heads, channels, device="cuda")
+        >>> # Spatial shapes for 4 feature levels
+        >>> spatial_shapes = torch.tensor([[32, 32], [16, 16], [8, 8], [4, 4]], dtype=torch.long, device="cuda")
+        >>> # Starting indices for each level in the flattened value tensor
+        >>> level_start_index = torch.tensor([0, 1024, 1280, 1344], dtype=torch.long, device="cuda")
+        >>> # Sampling locations (normalized coordinates in [0, 1])
+        >>> sampling_loc = torch.rand(batch_size, num_query, num_heads, num_levels, num_points, 2, device="cuda")
+        >>> # Attention weights (should sum to 1 across num_levels * num_points)
+        >>> attn_weight = torch.rand(batch_size, num_query, num_heads, num_levels, num_points, device="cuda")
+        >>> attn_weight = attn_weight.softmax(-1).softmax(-2)  # normalize
+        >>> out = deform_attn(value, spatial_shapes, level_start_index, sampling_loc, attn_weight)
+        >>> print(out.shape)
+        >>> # returns
+        >>>  torch.Size([2, 100, 256])
+    """
+    if not torch.jit.is_scripting() and not torch.jit.is_tracing():
+        _log_api_usage_once(deform_attn)
+    _assert_has_ops()
+
+    return torch.ops.torchvision.deform_attn(
+        value,
+        spatial_shapes,
+        level_start_index,
+        sampling_loc,
+        attn_weight,
+        im2col_step,
+    )
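
A minimal end-to-end sketch of the new Python entry point, assuming a CUDA build of torchvision that includes this patch (with the autograd binding registered elsewhere in the PR). The shapes are arbitrary as long as `value.shape[1]` equals the sum of `H_l * W_l` over levels:

    import torch
    from torchvision.ops import deform_attn

    device = "cuda"
    spatial_shapes = torch.tensor([[8, 6], [4, 4]], dtype=torch.long, device=device)
    level_start_index = torch.tensor([0, 48], dtype=torch.long, device=device)  # cumulative H_l * W_l
    B, Nv, H, C, Nq, L, P = 2, 64, 4, 8, 10, 2, 4  # Nv = 8*6 + 4*4

    value = torch.rand(B, Nv, H, C, device=device, requires_grad=True)
    sampling_loc = torch.rand(B, Nq, H, L, P, 2, device=device, requires_grad=True)
    attn_weight = torch.rand(B, Nq, H, L, P, device=device, requires_grad=True)

    out = deform_attn(value, spatial_shapes, level_start_index, sampling_loc, attn_weight)
    out.sum().backward()

    assert out.shape == (B, Nq, H * C)
    assert value.grad.shape == value.shape  # gradients flow to all three tensor inputs
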