Skip to content

Commit

Permalink
Fix custom ops bug for pytorch 1.12 and onwards
Browse files Browse the repository at this point in the history
Adapt to newer _jit_get_operation API that changed in
pytorch/pytorch#76814

for #188, #193
  • Loading branch information
jannehellsten authored and woctezuma committed Jan 10, 2024
1 parent a71fc60 commit 344645e
Showing 1 changed file with 3 additions and 0 deletions.
3 changes: 3 additions & 0 deletions torch_utils/ops/grid_sample_gradfix.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@

enabled = False # Enable the custom op by setting this to true.
_use_pytorch_1_11_api = parse_version(torch.__version__) >= parse_version('1.11.0a') # Allow prerelease builds of 1.11
_use_pytorch_1_12_api = parse_version(torch.__version__) >= parse_version('1.12.0a') # Allow prerelease builds of 1.12

#----------------------------------------------------------------------------

Expand Down Expand Up @@ -64,6 +65,8 @@ class _GridSample2dBackward(torch.autograd.Function):
@staticmethod
def forward(ctx, grad_output, input, grid):
op = torch._C._jit_get_operation('aten::grid_sampler_2d_backward')
if _use_pytorch_1_12_api:
op = op[0]
if _use_pytorch_1_11_api:
output_mask = (ctx.needs_input_grad[1], ctx.needs_input_grad[2])
grad_input, grad_grid = op(grad_output, input, grid, 0, 0, False, output_mask)
Expand Down

0 comments on commit 344645e

Please sign in to comment.