Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Support newer versions of PyTorch (v1.1X and v2) #299

Open
wants to merge 3 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions torch_utils/ops/conv2d_gradfix.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
import warnings
import contextlib
import torch
from pkg_resources import parse_version

# pylint: disable=redefined-builtin
# pylint: disable=arguments-differ
Expand All @@ -21,6 +22,7 @@

enabled = False # Enable the custom op by setting this to true.
weight_gradients_disabled = False # Forcefully disable computation of gradients with respect to the weights.
_use_pytorch_1_11_api = parse_version(torch.__version__) >= parse_version('1.11.0a') # Allow prerelease builds of 1.11

@contextlib.contextmanager
def no_weight_gradients():
Expand Down Expand Up @@ -48,6 +50,9 @@ def _should_use_custom_op(input):
assert isinstance(input, torch.Tensor)
if (not enabled) or (not torch.backends.cudnn.enabled):
return False
if _use_pytorch_1_11_api:
# The work-around code doesn't work on PyTorch 1.11.0 onwards
return False
if input.device.type != 'cuda':
return False
if any(torch.__version__.startswith(x) for x in ['1.7.', '1.8.', '1.9']):
Expand Down
13 changes: 11 additions & 2 deletions torch_utils/ops/grid_sample_gradfix.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@

import warnings
import torch
from pkg_resources import parse_version

# pylint: disable=redefined-builtin
# pylint: disable=arguments-differ
Expand All @@ -21,6 +22,8 @@
#----------------------------------------------------------------------------

enabled = False # Enable the custom op by setting this to true.
_use_pytorch_1_11_api = parse_version(torch.__version__) >= parse_version('1.11.0a') # Allow prerelease builds of 1.11
_use_pytorch_1_12_api = parse_version(torch.__version__) >= parse_version('1.12.0a') # Allow prerelease builds of 1.12

#----------------------------------------------------------------------------

Expand All @@ -34,7 +37,7 @@ def grid_sample(input, grid):
def _should_use_custom_op():
if not enabled:
return False
if any(torch.__version__.startswith(x) for x in ['1.7.', '1.8.', '1.9']):
if any(torch.__version__.startswith(x) for x in ['1.7.', '1.8.', '1.9', '1.1', '2']):
return True
warnings.warn(f'grid_sample_gradfix not supported on PyTorch {torch.__version__}. Falling back to torch.nn.functional.grid_sample().')
return False
Expand Down Expand Up @@ -62,7 +65,13 @@ class _GridSample2dBackward(torch.autograd.Function):
@staticmethod
def forward(ctx, grad_output, input, grid):
op = torch._C._jit_get_operation('aten::grid_sampler_2d_backward')
grad_input, grad_grid = op(grad_output, input, grid, 0, 0, False)
if _use_pytorch_1_12_api:
op = op[0]
if _use_pytorch_1_11_api:
output_mask = (ctx.needs_input_grad[1], ctx.needs_input_grad[2])
grad_input, grad_grid = op(grad_output, input, grid, 0, 0, False, output_mask)
else:
grad_input, grad_grid = op(grad_output, input, grid, 0, 0, False)
ctx.save_for_backward(grid)
return grad_input, grad_grid

Expand Down