@@ -281,6 +281,37 @@ __cached_notinplace_@DftiCompute_MODE@_@MKL_IN_TYPE@_@MKL_OUT_TYPE@(
281
281
}
282
282
/**end repeat**/
283
283
284
+ inline npy_intp
285
+ compute_distance(npy_intp *x_strides, npy_intp *x_shape, npy_intp x_itemsize, int x_rank, int i1, int i2) {
286
+ npy_intp st1, st2;
287
+ npy_intp sh1 = x_shape[i1], sh2 = x_shape[i2];
288
+ npy_intp min_s;
289
+ if (sh1 > 1 && sh2 > 1) {
290
+ st1 = x_strides[i1];
291
+ st2 = x_strides[i2];
292
+ min_s = (st1 > st2) ? st2 : st1;
293
+
294
+ return min_s;
295
+ } else {
296
+ int i;
297
+ npy_intp max_s;
298
+ max_s = x_itemsize;
299
+ for(i=0; i < x_rank; i++) {
300
+ if (x_shape[i] > 1) {
301
+ if (max_s < x_strides[i]) max_s = x_strides[i];
302
+ }
303
+ }
304
+ min_s = max_s;
305
+ for(i=i1; i <= i2; i++) {
306
+ if (x_shape[i] > 1) {
307
+ if (min_s > x_strides[i]) min_s = x_strides[i];
308
+ }
309
+ }
310
+ }
311
+
312
+ return min_s;
313
+ }
314
+
284
315
static NPY_INLINE int
285
316
compute_strides_and_distances(
286
317
PyArrayObject *x,
@@ -315,11 +346,9 @@ compute_strides_and_distances(
315
346
npy_intp char_dist = 0;
316
347
*num_fft_transfs = _to_mkl_long (x_size / x_shape[axis]);
317
348
if (axis == 0) {
318
- npy_intp s1 = x_strides[1], s2 = x_strides[x_rank-1];
319
- char_dist = (s1 > s2) ? s2 : s1;
349
+ char_dist = compute_distance(x_strides, x_shape, x_itemsize, x_rank, 1, x_rank-1);
320
350
} else {
321
- npy_intp s1 = x_strides[0], s2 = x_strides[x_rank-2];
322
- char_dist = (s1 > s2) ? s2 : s1;
351
+ char_dist = compute_distance(x_strides, x_shape, x_itemsize, x_rank, 0, x_rank-2);
323
352
}
324
353
325
354
*vec_dist = _to_mkl_long (char_dist / x_itemsize);
@@ -375,17 +404,11 @@ compute_strides_and_distances_inout(
375
404
npy_intp char_dist_in = 0, char_dist_out = 0;
376
405
*num_fft_transfs = _to_mkl_long (x_size / x_shape[axis]);
377
406
if (axis == 0) {
378
- npy_intp s1 = x_strides[1], s2 = x_strides[x_rank-1];
379
- char_dist_in = (s1 > s2) ? s2 : s1;
380
-
381
- s1 = y_strides[1]; s2 = y_strides[x_rank-1];
382
- char_dist_out = (s1 > s2) ? s2 : s1;
407
+ char_dist_in = compute_distance(x_strides, x_shape, x_itemsize, x_rank, 1, x_rank-1);
408
+ char_dist_out = compute_distance(y_strides, y_shape, y_itemsize, x_rank, 1, x_rank-1);
383
409
} else {
384
- npy_intp s1 = x_strides[0], s2 = x_strides[x_rank-2];
385
- char_dist_in = (s1 > s2) ? s2 : s1;
386
-
387
- s1 = y_strides[0]; s2 = y_strides[x_rank-2];
388
- char_dist_out = (s1 > s2) ? s2 : s1;
410
+ char_dist_in = compute_distance(x_strides, x_shape, x_itemsize, x_rank, 0, x_rank-2);
411
+ char_dist_out = compute_distance(y_strides, y_shape, y_itemsize, x_rank, 0, x_rank-2);
389
412
}
390
413
*vec_dist_in = _to_mkl_long (char_dist_in / x_itemsize);
391
414
*vec_dist_out = _to_mkl_long (char_dist_out / y_itemsize);
0 commit comments