Skip to content

Commit 88a932a

Browse files
BUGFIX: Issue #21
When computing parameters for multiple equidistant datasets call to MKL's FFT functions, during computation of distances between datasets there was an implicit assumption that strides are all positive. The code was caught off-guards for arrays with a zero stride, associated with a unit shape, like one formed by a[np.newaxis] Added a test.
1 parent 6fb6e0b commit 88a932a

File tree

2 files changed

+59
-14
lines changed

2 files changed

+59
-14
lines changed

mkl_fft/src/mklfft.c.src

Lines changed: 37 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -281,6 +281,37 @@ __cached_notinplace_@DftiCompute_MODE@_@MKL_IN_TYPE@_@MKL_OUT_TYPE@(
281281
}
282282
/**end repeat**/
283283

284+
inline npy_intp
285+
compute_distance(npy_intp *x_strides, npy_intp *x_shape, npy_intp x_itemsize, int x_rank, int i1, int i2) {
286+
npy_intp st1, st2;
287+
npy_intp sh1 = x_shape[i1], sh2 = x_shape[i2];
288+
npy_intp min_s;
289+
if (sh1 > 1 && sh2 > 1) {
290+
st1 = x_strides[i1];
291+
st2 = x_strides[i2];
292+
min_s = (st1 > st2) ? st2 : st1;
293+
294+
return min_s;
295+
} else {
296+
int i;
297+
npy_intp max_s;
298+
max_s = x_itemsize;
299+
for(i=0; i < x_rank; i++) {
300+
if (x_shape[i] > 1) {
301+
if (max_s < x_strides[i]) max_s = x_strides[i];
302+
}
303+
}
304+
min_s = max_s;
305+
for(i=i1; i <= i2; i++) {
306+
if (x_shape[i] > 1) {
307+
if (min_s > x_strides[i]) min_s = x_strides[i];
308+
}
309+
}
310+
}
311+
312+
return min_s;
313+
}
314+
284315
static NPY_INLINE int
285316
compute_strides_and_distances(
286317
PyArrayObject *x,
@@ -315,11 +346,9 @@ compute_strides_and_distances(
315346
npy_intp char_dist = 0;
316347
*num_fft_transfs = _to_mkl_long (x_size / x_shape[axis]);
317348
if (axis == 0) {
318-
npy_intp s1 = x_strides[1], s2 = x_strides[x_rank-1];
319-
char_dist = (s1 > s2) ? s2 : s1;
349+
char_dist = compute_distance(x_strides, x_shape, x_itemsize, x_rank, 1, x_rank-1);
320350
} else {
321-
npy_intp s1 = x_strides[0], s2 = x_strides[x_rank-2];
322-
char_dist = (s1 > s2) ? s2 : s1;
351+
char_dist = compute_distance(x_strides, x_shape, x_itemsize, x_rank, 0, x_rank-2);
323352
}
324353

325354
*vec_dist = _to_mkl_long (char_dist / x_itemsize);
@@ -375,17 +404,11 @@ compute_strides_and_distances_inout(
375404
npy_intp char_dist_in = 0, char_dist_out = 0;
376405
*num_fft_transfs = _to_mkl_long (x_size / x_shape[axis]);
377406
if (axis == 0) {
378-
npy_intp s1 = x_strides[1], s2 = x_strides[x_rank-1];
379-
char_dist_in = (s1 > s2) ? s2 : s1;
380-
381-
s1 = y_strides[1]; s2 = y_strides[x_rank-1];
382-
char_dist_out = (s1 > s2) ? s2 : s1;
407+
char_dist_in = compute_distance(x_strides, x_shape, x_itemsize, x_rank, 1, x_rank-1);
408+
char_dist_out = compute_distance(y_strides, y_shape, y_itemsize, x_rank, 1, x_rank-1);
383409
} else {
384-
npy_intp s1 = x_strides[0], s2 = x_strides[x_rank-2];
385-
char_dist_in = (s1 > s2) ? s2 : s1;
386-
387-
s1 = y_strides[0]; s2 = y_strides[x_rank-2];
388-
char_dist_out = (s1 > s2) ? s2 : s1;
410+
char_dist_in = compute_distance(x_strides, x_shape, x_itemsize, x_rank, 0, x_rank-2);
411+
char_dist_out = compute_distance(y_strides, y_shape, y_itemsize, x_rank, 0, x_rank-2);
389412
}
390413
*vec_dist_in = _to_mkl_long (char_dist_in / x_itemsize);
391414
*vec_dist_out = _to_mkl_long (char_dist_out / y_itemsize);

mkl_fft/tests/test_fft1d.py

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -303,6 +303,28 @@ def test_array4(self):
303303
f2 = mkl_fft.fft(f1, axis = ax)
304304
assert_allclose(f2, x, atol=2e-15)
305305

306+
307+
def test_array5(self):
308+
"""Inputs with zero strides are handled correctly"""
309+
z = self.az3
310+
z1 = z[np.newaxis]
311+
f1 = mkl_fft.fft(z1, axis=-1)
312+
f2 = mkl_fft.fft(z1.reshape(z1.shape), axis=-1)
313+
assert_allclose(f1, f2, atol=2e-15)
314+
z1 = z[:, np.newaxis]
315+
f1 = mkl_fft.fft(z1, axis=-1)
316+
f2 = mkl_fft.fft(z1.reshape(z1.shape), axis=-1)
317+
assert_allclose(f1, f2, atol=2e-15)
318+
z1 = z[:, :, np.newaxis]
319+
f1 = mkl_fft.fft(z1, axis=-1)
320+
f2 = mkl_fft.fft(z1.reshape(z1.shape), axis=-1)
321+
assert_allclose(f1, f2, atol=2e-15)
322+
z1 = z[:, :, :, np.newaxis]
323+
f1 = mkl_fft.fft(z1, axis=-1)
324+
f2 = mkl_fft.fft(z1.reshape(z1.shape), axis=-1)
325+
assert_allclose(f1, f2, atol=2e-15)
326+
327+
306328
class Test_mklfft_rfft(TestCase):
307329
def setUp(self):
308330
rnd.seed(1234567)

0 commit comments

Comments
 (0)