diff --git a/torch_utils/ops/conv2d_gradfix.py b/torch_utils/ops/conv2d_gradfix.py
index e95e10d0b..bb70f1ea6 100755
--- a/torch_utils/ops/conv2d_gradfix.py
+++ b/torch_utils/ops/conv2d_gradfix.py
@@ -12,6 +12,7 @@
 import warnings
 import contextlib
 import torch
+from pkg_resources import parse_version
 
 # pylint: disable=redefined-builtin
 # pylint: disable=arguments-differ
@@ -21,6 +22,7 @@
 enabled = False # Enable the custom op by setting this to true.
 weight_gradients_disabled = False # Forcefully disable computation of gradients with respect to the weights.
+_use_pytorch_1_11_api = parse_version(torch.__version__) >= parse_version('1.11.0a') # Allow prerelease builds of 1.11
 
 @contextlib.contextmanager
 def no_weight_gradients():
@@ -48,6 +50,9 @@ def _should_use_custom_op(input):
     assert isinstance(input, torch.Tensor)
     if (not enabled) or (not torch.backends.cudnn.enabled):
         return False
+    if _use_pytorch_1_11_api:
+        # The work-around code doesn't work on PyTorch 1.11.0 onwards
+        return False
     if input.device.type != 'cuda':
         return False
     if any(torch.__version__.startswith(x) for x in ['1.7.', '1.8.', '1.9']):
diff --git a/torch_utils/ops/grid_sample_gradfix.py b/torch_utils/ops/grid_sample_gradfix.py
index ca6b3413e..98b5b97b6 100755
--- a/torch_utils/ops/grid_sample_gradfix.py
+++ b/torch_utils/ops/grid_sample_gradfix.py
@@ -13,6 +13,7 @@
 
 import warnings
 import torch
+from pkg_resources import parse_version
 
 # pylint: disable=redefined-builtin
 # pylint: disable=arguments-differ
@@ -21,6 +22,8 @@
 #----------------------------------------------------------------------------
 
 enabled = False # Enable the custom op by setting this to true.
+_use_pytorch_1_11_api = parse_version(torch.__version__) >= parse_version('1.11.0a') # Allow prerelease builds of 1.11
+_use_pytorch_1_12_api = parse_version(torch.__version__) >= parse_version('1.12.0a') # Allow prerelease builds of 1.12
 
 #----------------------------------------------------------------------------
 
@@ -34,7 +37,7 @@ def grid_sample(input, grid):
 def _should_use_custom_op():
     if not enabled:
         return False
-    if any(torch.__version__.startswith(x) for x in ['1.7.', '1.8.', '1.9']):
+    if any(torch.__version__.startswith(x) for x in ['1.7.', '1.8.', '1.9', '1.1', '2']):
         return True
     warnings.warn(f'grid_sample_gradfix not supported on PyTorch {torch.__version__}. Falling back to torch.nn.functional.grid_sample().')
     return False
@@ -62,7 +65,13 @@
 class _GridSample2dBackward(torch.autograd.Function):
     @staticmethod
     def forward(ctx, grad_output, input, grid):
         op = torch._C._jit_get_operation('aten::grid_sampler_2d_backward')
-        grad_input, grad_grid = op(grad_output, input, grid, 0, 0, False)
+        if _use_pytorch_1_12_api:
+            op = op[0]
+        if _use_pytorch_1_11_api:
+            output_mask = (ctx.needs_input_grad[1], ctx.needs_input_grad[2])
+            grad_input, grad_grid = op(grad_output, input, grid, 0, 0, False, output_mask)
+        else:
+            grad_input, grad_grid = op(grad_output, input, grid, 0, 0, False)
         ctx.save_for_backward(grid)
         return grad_input, grad_grid
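
For reference, a minimal sketch of the version gating this patch introduces. It reuses only the two threshold strings from the diff ('1.11.0a' and '1.12.0a') and reports which code path each module would take at runtime; the printed labels and variable names are illustrative, not part of the patch.

    # Sketch of the patch's version checks; prerelease builds of 1.11/1.12 pass the gate.
    import torch
    from pkg_resources import parse_version

    use_pytorch_1_11_api = parse_version(torch.__version__) >= parse_version('1.11.0a')
    use_pytorch_1_12_api = parse_version(torch.__version__) >= parse_version('1.12.0a')

    print(f'PyTorch {torch.__version__}')
    # conv2d_gradfix: the custom-op work-around is skipped entirely on 1.11 onwards.
    print('conv2d_gradfix custom op:', 'skipped (1.11+ API)' if use_pytorch_1_11_api else 'eligible')
    # grid_sample_gradfix: 1.12+ unpacks the (op, overload_names) tuple; 1.11+ passes output_mask.
    if use_pytorch_1_12_api:
        path = 'tuple op + output_mask'
    elif use_pytorch_1_11_api:
        path = 'output_mask argument'
    else:
        path = 'legacy 6-argument call'
    print('grid_sampler_2d_backward call path:', path)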