Skip to content

Commit

Permalink
Remove FFT from stride incorrect ops (pytorch#145080)
Browse files Browse the repository at this point in the history
I gotta say, the FFT implementation is completely insane, there's gotta be a better way to do this than repeatedly inplace restriding the output tensor. Anyway, this is a faithful translation of both the MKL and cuFFT paths to Python.

Fixes pytorch#135087

Signed-off-by: Edward Z. Yang <[email protected]>
Pull Request resolved: pytorch#145080
Approved by: https://github.com/Skylion007, https://github.com/albanD
ghstack dependencies: pytorch#145530
  • Loading branch information
ezyang authored and pytorchmergebot committed Jan 27, 2025
1 parent b75afa2 commit 87fdadd
Show file tree
Hide file tree
Showing 4 changed files with 123 additions and 74 deletions.
15 changes: 0 additions & 15 deletions test/functorch/test_aotdispatch.py
Original file line number Diff line number Diff line change
Expand Up @@ -6509,21 +6509,6 @@ def _test_fn(fn, check_backward=True):
"linalg.householder_product",
decorator=unittest.skipIf(IS_MACOS and IS_X86, "flaky"),
),
# many complex operators incorrect striding, metadata
xfail("fft.fft", ""),
xfail("fft.hfft2", ""),
xfail("fft.hfft", ""),
xfail("fft.hfftn", ""),
xfail("fft.ifft", ""),
xfail("fft.ihfft2", ""),
xfail("fft.ihfft", ""),
xfail("fft.ihfftn", ""),
xfail("fft.irfft2", ""),
xfail("fft.irfft", ""),
xfail("fft.irfftn", ""),
xfail("fft.rfft2", ""),
xfail("fft.rfft", ""),
xfail("fft.rfftn", ""),
xfail("stft", ""), # Cannot call sizes() on tensor with symbolic sizes/strides
}

Expand Down
22 changes: 0 additions & 22 deletions test/test_proxy_tensor.py
Original file line number Diff line number Diff line change
Expand Up @@ -2014,24 +2014,6 @@ def f(t):
xfail('unique_consecutive', ''), # aten.unique_consecutive.default - couldn't find symbolic meta function/decomposition

xfail('max_pool2d_with_indices_backward', ''), # Expected a value of type 'List[int]' for argument 'kernel_size' but...

# many complex operators incorrect striding, metadata
xfail('fft.fft', ''),
xfail('fft.hfft2', ''),
xfail('fft.hfft', ''),
xfail('fft.hfftn', ''),
xfail('fft.ifft', ''),
xfail('fft.ihfft2', ''),
xfail('fft.ihfft', ''),
xfail('fft.ihfftn', ''),
xfail('fft.ihfft2', ''),
xfail('fft.irfft2', ''),
xfail('fft.irfft', ''),
xfail('fft.irfftn', ''),
xfail('fft.rfft2', ''),
xfail('fft.rfft', ''),
xfail('fft.rfftn', ''),
xfail('stft', '')
}
symbolic_tensor_segfaults = {
skip('nn.functional.batch_norm') # Segfault??
Expand All @@ -2058,10 +2040,6 @@ def f(t):
xfail('angle', ''),
xfail('argmax', ''),
xfail('argmin', ''),
xfail('fft.fft2', ''),
xfail('fft.fftn', ''),
xfail('fft.ifft2', ''),
xfail('fft.ifftn', ''),
xfail('gather', ''),
xfail('linalg.pinv', ''),
xfail('linalg.pinv', 'hermitian'),
Expand Down
152 changes: 123 additions & 29 deletions torch/_meta_registrations.py
Original file line number Diff line number Diff line change
Expand Up @@ -223,8 +223,12 @@ def logcumsumexp(self, dim):
return torch.empty_like(self).contiguous()


# Stride-related code from _exec_fft in aten/src/ATen/native/cuda/SpectralOps.cpp
def _exec_fft(out, self, out_sizes, dim, forward):
# Stride-related code from _exec_fft in aten/src/ATen/native/mkl/SpectralOps.cpp
# and aten/src/ATen/native/cuda/SpectralOps.cpp
#
# Although the actual FFT launch is different, all the permuting code appears
# to be the same
def _exec_fft(out, self, out_sizes, dim, *, forward):
ndim = self.ndim
signal_ndim = len(dim)
batch_dims = ndim - signal_ndim
Expand Down Expand Up @@ -258,12 +262,12 @@ def _exec_fft(out, self, out_sizes, dim, forward):

batch_size = input.size(0)
batched_sizes[0] = batch_size
batched_out_sizes = batched_sizes
batched_out_sizes = list(batched_sizes)
for i in range(len(dim)):
batched_out_sizes[i + 1] = out_sizes[dim[i]]
out = out.reshape(batched_out_sizes)
out.resize_(batched_out_sizes, memory_format=torch.contiguous_format)

# Reshaping to original batch shape and inverting the dimension permutation
# Inplace reshaping to original batch shape and inverting the dimension permutation
out_strides = [0 for _ in range(ndim)]
batch_numel = 1
i = batch_dims - 1
Expand All @@ -273,44 +277,102 @@ def _exec_fft(out, self, out_sizes, dim, forward):
i -= 1
for i in range(batch_dims, ndim):
out_strides[dim_permute[i]] = out.stride(1 + (i - batch_dims))
return out.as_strided(out_sizes, out_strides, out.storage_offset())
out.as_strided_(out_sizes, out_strides, out.storage_offset())

return out


def _sort_dims(self: Tensor, dim: list[int], exclude_last: bool = False):
sorted_dims = list(dim)
self_strides = self.stride()
sorted_dims[: len(sorted_dims) - int(exclude_last)].sort(
key=lambda i: self_strides[i]
)
return sorted_dims


# See _fft_c2c_cufft in aten/src/ATen/native/cuda/SpectralOps.cpp
# and _fft_c2c_mkl in aten/src/ATen/native/mkl/SpectralOps.cpp
@register_meta([aten._fft_c2c.default, aten._fft_c2c.out])
@out_wrapper()
def meta_fft_c2c(self, dim, normalization, forward):
assert self.dtype.is_complex
torch._check(self.dtype.is_complex)
if not dim:
return self.clone()

out_sizes = self.shape
output = self.new_empty(out_sizes)
sorted_dims = _sort_dims(self, dim)
out = self.new_empty(self.size())
return _exec_fft(out, self, self.size(), sorted_dims, forward=forward)

if not dim:
return output

sorted_dims = dim[:]
self_strides = self.stride()
sorted_dims.sort(key=lambda x: self_strides[x], reverse=True)
output = _exec_fft(output, self, out_sizes, sorted_dims, forward)
cufft_max_ndim = 3

return output

def use_optimized_cufft_path(dim: list[int]):
    # Mirrors the cuFFT dispatch check: the optimized single-transform path is
    # unusable when more than `cufft_max_ndim` dims are transformed, or when
    # the first two transformed dims are dims 0 and 1.
    too_many_dims = len(dim) > cufft_max_ndim
    starts_with_batch_dims = len(dim) >= 2 and dim[0] == 0 and dim[1] == 1
    return not (too_many_dims or starts_with_batch_dims)


@register_meta([aten._fft_r2c.default, aten._fft_r2c.out])
@out_wrapper()
def meta_fft_r2c(self, dim, normalization, onesided):
assert self.dtype.is_floating_point
output_sizes = list(self.size())
torch._check(self.dtype.is_floating_point)
input_sizes = list(self.size())
out_sizes = list(input_sizes)
last_dim = dim[-1]
last_dim_halfsize = input_sizes[last_dim] // 2 + 1
onesided_sizes = list(input_sizes)
onesided_sizes[last_dim] = last_dim_halfsize

if onesided:
last_dim = dim[-1]
last_dim_halfsize = (output_sizes[last_dim] // 2) + 1
output_sizes[last_dim] = last_dim_halfsize
out_sizes[last_dim] = last_dim_halfsize

return self.new_empty(
output_sizes, dtype=utils.corresponding_complex_dtype(self.dtype)
)
if device_hint(self) == "cuda":
# _fft_r2c_cufft in aten/src/ATen/native/cuda/SpectralOps.cpp
output = self.new_empty(
out_sizes, dtype=utils.corresponding_complex_dtype(self.dtype)
)

working_tensor = self
if use_optimized_cufft_path(dim):
_exec_fft(output, working_tensor, out_sizes, dim, forward=True)
else:
# First do the R2C transform on the last dimension
target_sizes = out_sizes if len(dim) == 1 else onesided_sizes
_exec_fft(output, working_tensor, target_sizes, [last_dim], forward=True)
if len(dim) > 1:
working_tensor = self.new_empty(
out_sizes, dtype=utils.corresponding_complex_dtype(self.dtype)
)

# Then any remaining C2C transforms
sorted_dims = dim[:-1]
while sorted_dims:
output, working_tensor = working_tensor, output
strides = working_tensor.stride()
sorted_dims.sort(
key=lambda i: strides[i], reverse=True
) # NB reverse! Not sure if this is og bug
max_dims = min(cufft_max_ndim, len(sorted_dims))
last_dims = sorted_dims[len(sorted_dims) - max_dims :]
_exec_fft(
output, working_tensor, onesided_sizes, last_dims, forward=True
)
sorted_dims = sorted_dims[: len(sorted_dims) - max_dims]

if not onesided:
if output.size(last_dim) != out_sizes[last_dim]:
working_tensor.resize_(out_sizes, memory_format=torch.contiguous_format)
output = working_tensor

return output

else:
return self.new_empty(
out_sizes, dtype=utils.corresponding_complex_dtype(self.dtype)
)


@register_meta(aten.randperm.generator_out)
Expand Down Expand Up @@ -375,11 +437,43 @@ def meta_rand_default(size, *, dtype=None, layout=None, device=None, pin_memory=

@register_meta([aten._fft_c2r.default, aten._fft_c2r.out])
@out_wrapper()
def meta_fft_c2r(self, dim, normalization, lastdim):
assert self.dtype.is_complex
output_sizes = list(self.size())
output_sizes[dim[-1]] = lastdim
return self.new_empty(output_sizes, dtype=toRealValueType(self.dtype))
def meta_fft_c2r(self: Tensor, dim: list[int], normalization: int, lastdim: int):
    """Meta kernel for ``aten._fft_c2r`` (complex-to-real inverse FFT).

    Translates the output-size/stride behaviour of ``_fft_c2r_cufft``
    (aten/src/ATen/native/cuda/SpectralOps.cpp) and ``_fft_c2r_mkl``
    (aten/src/ATen/native/mkl/SpectralOps.cpp).

    Args:
        self: complex input tensor.
        dim: dimensions to transform; the last entry is the halved real dim.
        normalization: fft_norm_mode value (does not affect shapes here).
        lastdim: logical size of the last transformed dim in the real output.
    """
    torch._check(self.dtype.is_complex)

    if device_hint(self) == "cuda":
        out_sizes = list(self.size())
        out_sizes[dim[-1]] = lastdim

        output = self.new_empty(out_sizes, dtype=toRealValueType(self.dtype))

        if use_optimized_cufft_path(dim):
            return _exec_fft(
                output,
                self.clone(memory_format=torch.contiguous_format),
                out_sizes,
                dim,
                forward=False,
            )
        else:
            # First complete any C2C transforms
            if len(dim) > 1:
                # fft_norm_mode::none. BUG FIX: the C++ path runs this C2C
                # step as an inverse transform, so forward must be False; the
                # previous code passed `lastdim` (a truthy size) positionally
                # into the `forward` parameter.
                temp = meta_fft_c2c(self, dim[:-1], 0, False)
            else:
                temp = self.clone(memory_format=torch.contiguous_format)
            return _exec_fft(output, temp, out_sizes, [dim[-1]], forward=False)

    else:
        # MKL path: fold any leading C2C transforms into the input first, then
        # perform a single C2R transform over the last dimension.
        working = self  # renamed from `input` to avoid shadowing the builtin
        if len(dim) > 1:
            c2c_dims = dim[:-1]
            working = meta_fft_c2c(self, c2c_dims, normalization, forward=False)
            dim = dim[-1:]

        out_sizes = list(working.size())
        out_sizes[dim[-1]] = lastdim
        out = self.new_empty(out_sizes, dtype=toRealValueType(self.dtype))
        return _exec_fft(out, working, out_sizes, dim, forward=False)


@register_meta(aten.copy_.default)
Expand Down
8 changes: 0 additions & 8 deletions torch/_subclasses/fake_impls.py
Original file line number Diff line number Diff line change
Expand Up @@ -222,14 +222,6 @@ def non_kwarg_to(fake_mode, func, *args, **kwargs):


def stride_incorrect_op(op):
if op.namespace not in ("aten", "prims"):
return False
if op is aten._fft_c2c.default:
return False

op_name = op.name()
if "fft" in op_name:
return True
return False


Expand Down

0 comments on commit 87fdadd

Please sign in to comment.