Skip to content

Commit

Permalink
Support newer versions of PyTorch (v1.1X and v2)
Browse files Browse the repository at this point in the history
  • Loading branch information
woctezuma committed Jan 10, 2024
1 parent 344645e commit c6dfbc0
Show file tree
Hide file tree
Showing 2 changed files with 2 additions and 2 deletions.
2 changes: 1 addition & 1 deletion torch_utils/ops/conv2d_gradfix.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,7 @@ def _should_use_custom_op(input):
return False
if input.device.type != 'cuda':
return False
if any(torch.__version__.startswith(x) for x in ['1.7.', '1.8.', '1.9']):
if any(torch.__version__.startswith(x) for x in ['1.7.', '1.8.', '1.9', '1.1', '2']):
return True
warnings.warn(f'conv2d_gradfix not supported on PyTorch {torch.__version__}. Falling back to torch.nn.functional.conv2d().')
return False
Expand Down
2 changes: 1 addition & 1 deletion torch_utils/ops/grid_sample_gradfix.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@ def grid_sample(input, grid):
def _should_use_custom_op():
if not enabled:
return False
if any(torch.__version__.startswith(x) for x in ['1.7.', '1.8.', '1.9']):
if any(torch.__version__.startswith(x) for x in ['1.7.', '1.8.', '1.9', '1.1', '2']):
return True
warnings.warn(f'grid_sample_gradfix not supported on PyTorch {torch.__version__}. Falling back to torch.nn.functional.grid_sample().')
return False
Expand Down

0 comments on commit c6dfbc0

Please sign in to comment.