Skip to content

Commit c6dfbc0

Browse files
committed
Support newer versions of PyTorch (v1.1X and v2)
1 parent 344645e commit c6dfbc0

File tree

2 files changed

+2
-2
lines changed

2 files changed

+2
-2
lines changed

torch_utils/ops/conv2d_gradfix.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -55,7 +55,7 @@ def _should_use_custom_op(input):
5555
return False
5656
if input.device.type != 'cuda':
5757
return False
58-
if any(torch.__version__.startswith(x) for x in ['1.7.', '1.8.', '1.9']):
58+
if any(torch.__version__.startswith(x) for x in ['1.7.', '1.8.', '1.9', '1.1', '2']):
5959
return True
6060
warnings.warn(f'conv2d_gradfix not supported on PyTorch {torch.__version__}. Falling back to torch.nn.functional.conv2d().')
6161
return False

torch_utils/ops/grid_sample_gradfix.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -37,7 +37,7 @@ def grid_sample(input, grid):
3737
def _should_use_custom_op():
3838
if not enabled:
3939
return False
40-
if any(torch.__version__.startswith(x) for x in ['1.7.', '1.8.', '1.9']):
40+
if any(torch.__version__.startswith(x) for x in ['1.7.', '1.8.', '1.9', '1.1', '2']):
4141
return True
4242
warnings.warn(f'grid_sample_gradfix not supported on PyTorch {torch.__version__}. Falling back to torch.nn.functional.grid_sample().')
4343
return False

0 commit comments

Comments
 (0)