-
Notifications
You must be signed in to change notification settings - Fork 7
/
Copy pathcuda_gridsample.py
123 lines (93 loc) · 5.12 KB
/
cuda_gridsample.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
from torch.utils.cpp_extension import load
import torch
from pkg_resources import parse_version
# JIT-compile the custom CUDA extension that implements the second-order
# (double-backward) grid-sample kernels. Requires a working nvcc toolchain and
# the two source files to be resolvable from the current working directory.
gridsample_grad2 = load(name='gridsample_grad2', sources=['gridsample_cuda.cpp', 'gridsample_cuda.cu'], verbose=True)
def grid_sample_2d(input, grid, padding_mode='zeros', align_corners=True):
    """Bilinear 2D grid sampling that supports double backward (second-order grads).

    Same contract as ``torch.nn.functional.grid_sample`` with ``mode='bilinear'``,
    restricted to the padding modes the CUDA second-order kernel implements.
    """
    assert padding_mode in ('zeros', 'border')
    return _GridSample2dForward.apply(input, grid, padding_mode, align_corners)
def grid_sample_3d(input, grid, padding_mode='zeros', align_corners=True):
    """Bilinear 3D grid sampling that supports double backward (second-order grads).

    Same contract as ``torch.nn.functional.grid_sample`` with ``mode='bilinear'``,
    restricted to the padding modes the CUDA second-order kernel implements.
    """
    assert padding_mode in ('zeros', 'border')
    return _GridSample3dForward.apply(input, grid, padding_mode, align_corners)
# PyTorch 1.11 changed the aten grid-sampler backward ops to take an extra
# output_mask argument; '1.11.0a' makes 1.11 pre-release builds match too.
_use_pytorch_1_11_api = parse_version(torch.__version__) >= parse_version('1.11.0a')
class _GridSample2dForward(torch.autograd.Function):
@staticmethod
def forward(ctx, input, grid, padding_mode=0, align_corners=True):
assert input.ndim == 4
assert grid.ndim == 4
assert input.shape[0] == grid.shape[0]
assert grid.shape[3] == 2
output = torch.nn.functional.grid_sample(input=input, grid=grid, mode='bilinear',
padding_mode=padding_mode, align_corners=align_corners)
ctx.save_for_backward(input, grid)
ctx.padding_mode = ['zeros', 'border'].index(padding_mode)
ctx.align_corners = align_corners
return output
@staticmethod
def backward(ctx, grad_output):
input, grid = ctx.saved_tensors
grad_input, grad_grid = _GridSample2dBackward.apply(grad_output, input, grid, ctx.padding_mode, ctx.align_corners)
return grad_input, grad_grid, None, None
class _GridSample2dBackward(torch.autograd.Function):
    """First backward of 2D grid sampling, differentiable once more.

    ``forward`` evaluates the aten first-order backward kernel; ``backward``
    evaluates the second-order gradients via the custom CUDA extension.
    """

    @staticmethod
    def forward(ctx, grad_output, input, grid, padding_mode=0, align_corners=True):
        # padding_mode here is the integer code ('zeros'=0, 'border'=1)
        # already translated by _GridSample2dForward.
        op = torch._C._jit_get_operation('aten::grid_sampler_2d_backward')
        if isinstance(op, tuple):
            # Bug fix: newer PyTorch (>= 1.13) returns (op, overload_names);
            # calling the tuple directly raised TypeError.
            op = op[0]
        if _use_pytorch_1_11_api:
            # The 1.11+ kernel takes a mask selecting which grads to compute.
            output_mask = (ctx.needs_input_grad[1], ctx.needs_input_grad[2])
            grad_input, grad_grid = op(grad_output, input, grid, 0, padding_mode, align_corners, output_mask)
        else:
            grad_input, grad_grid = op(grad_output, input, grid, 0, padding_mode, align_corners)
        ctx.save_for_backward(grad_output, input, grid)
        ctx.padding_mode = padding_mode
        ctx.align_corners = align_corners
        return grad_input, grad_grid

    @staticmethod
    def backward(ctx, grad2_grad_input, grad2_grad_grid):
        grad_output, input, grid = ctx.saved_tensors
        # The second-order kernel is implemented for CUDA tensors only.
        assert grad_output.is_cuda and input.is_cuda and grid.is_cuda and grad2_grad_input.is_cuda and grad2_grad_grid.is_cuda
        out = gridsample_grad2.grad2_2d(grad2_grad_input, grad2_grad_grid, grad_output,
                                        input, grid, ctx.padding_mode, ctx.align_corners)
        grad_grad_output = out[0]
        grad_input = out[1]
        grad_grid = out[2]
        # The two trailing Nones match padding_mode and align_corners.
        return grad_grad_output, grad_input, grad_grid, None, None
class _GridSample3dForward(torch.autograd.Function):
@staticmethod
def forward(ctx, input, grid, padding_mode=0, align_corners=True):
assert input.ndim == 5
assert grid.ndim == 5
assert input.shape[0] == grid.shape[0]
assert grid.shape[4] == 3
output = torch.nn.functional.grid_sample(input=input, grid=grid, mode='bilinear',
padding_mode=padding_mode, align_corners=align_corners)
ctx.save_for_backward(input, grid)
ctx.padding_mode = ['zeros', 'border'].index(padding_mode)
ctx.align_corners = align_corners
return output
@staticmethod
def backward(ctx, grad_output):
input, grid = ctx.saved_tensors
grad_input, grad_grid = _GridSample3dBackward.apply(grad_output, input, grid, ctx.padding_mode, ctx.align_corners)
return grad_input, grad_grid, None, None
class _GridSample3dBackward(torch.autograd.Function):
    """First backward of 3D grid sampling, differentiable once more.

    ``forward`` evaluates the aten first-order backward kernel; ``backward``
    evaluates the second-order gradients via the custom CUDA extension.
    """

    @staticmethod
    def forward(ctx, grad_output, input, grid, padding_mode=0, align_corners=True):
        # padding_mode here is the integer code ('zeros'=0, 'border'=1)
        # already translated by _GridSample3dForward.
        op = torch._C._jit_get_operation('aten::grid_sampler_3d_backward')
        if isinstance(op, tuple):
            # Bug fix: newer PyTorch (>= 1.13) returns (op, overload_names);
            # calling the tuple directly raised TypeError.
            op = op[0]
        if _use_pytorch_1_11_api:
            # The 1.11+ kernel takes a mask selecting which grads to compute.
            output_mask = (ctx.needs_input_grad[1], ctx.needs_input_grad[2])
            grad_input, grad_grid = op(grad_output, input, grid, 0, padding_mode, align_corners, output_mask)
        else:
            grad_input, grad_grid = op(grad_output, input, grid, 0, padding_mode, align_corners)
        ctx.save_for_backward(grad_output, input, grid)
        ctx.padding_mode = padding_mode
        ctx.align_corners = align_corners
        return grad_input, grad_grid

    @staticmethod
    def backward(ctx, grad2_grad_input, grad2_grad_grid):
        grad_output, input, grid = ctx.saved_tensors
        # The second-order kernel is implemented for CUDA tensors only.
        assert grad_output.is_cuda and input.is_cuda and grid.is_cuda and grad2_grad_input.is_cuda and grad2_grad_grid.is_cuda
        out = gridsample_grad2.grad2_3d(grad2_grad_input, grad2_grad_grid, grad_output,
                                        input, grid, ctx.padding_mode, ctx.align_corners)
        grad_grad_output = out[0]
        grad_input = out[1]
        grad_grid = out[2]
        # The two trailing Nones match padding_mode and align_corners.
        return grad_grad_output, grad_input, grad_grid, None, None