Skip to content

Commit 9e05d64

Browse files
Due to use of global variable as a cache of Dfti descriptor for 1D transforms,
1D calls need to use lock.
1 parent dc6647e commit 9e05d64

File tree

2 files changed

+76
-66
lines changed

2 files changed

+76
-66
lines changed

mkl_fft/_pydfti.pyx

Lines changed: 76 additions & 65 deletions
Original file line numberDiff line numberDiff line change
@@ -35,6 +35,9 @@ except ModuleNotFoundError:
3535

3636
from libc.string cimport memcpy
3737

38+
from threading import Lock
39+
_lock = Lock()
40+
3841
cdef extern from "Python.h":
3942
ctypedef int size_t
4043

@@ -289,18 +292,19 @@ def _fft1d_impl(x, n=None, axis=-1, overwrite_arg=False, direction=+1):
289292
in_place = 1
290293

291294
if in_place:
292-
if x_type is cnp.NPY_CDOUBLE:
293-
if dir_ < 0:
294-
status = cdouble_mkl_ifft1d_in(x_arr, n_, <int> axis_)
295-
else:
296-
status = cdouble_mkl_fft1d_in(x_arr, n_, <int> axis_)
297-
elif x_type is cnp.NPY_CFLOAT:
298-
if dir_ < 0:
299-
status = cfloat_mkl_ifft1d_in(x_arr, n_, <int> axis_)
295+
with _lock:
296+
if x_type is cnp.NPY_CDOUBLE:
297+
if dir_ < 0:
298+
status = cdouble_mkl_ifft1d_in(x_arr, n_, <int> axis_)
299+
else:
300+
status = cdouble_mkl_fft1d_in(x_arr, n_, <int> axis_)
301+
elif x_type is cnp.NPY_CFLOAT:
302+
if dir_ < 0:
303+
status = cfloat_mkl_ifft1d_in(x_arr, n_, <int> axis_)
304+
else:
305+
status = cfloat_mkl_fft1d_in(x_arr, n_, <int> axis_)
300306
else:
301-
status = cfloat_mkl_fft1d_in(x_arr, n_, <int> axis_)
302-
else:
303-
status = 1
307+
status = 1
304308

305309
if status:
306310
raise ValueError("Internal error, status={}".format(status))
@@ -318,36 +322,37 @@ def _fft1d_impl(x, n=None, axis=-1, overwrite_arg=False, direction=+1):
318322
f_arr = __allocate_result(x_arr, n_, axis_, f_type);
319323

320324
# call out-of-place FFT
321-
if f_type is cnp.NPY_CDOUBLE:
322-
if x_type is cnp.NPY_DOUBLE:
323-
if dir_ < 0:
324-
status = double_cdouble_mkl_ifft1d_out(
325-
x_arr, n_, <int> axis_, f_arr, ALL_HARMONICS)
326-
else:
327-
status = double_cdouble_mkl_fft1d_out(
328-
x_arr, n_, <int> axis_, f_arr, ALL_HARMONICS)
329-
elif x_type is cnp.NPY_CDOUBLE:
330-
if dir_ < 0:
331-
status = cdouble_cdouble_mkl_ifft1d_out(
332-
x_arr, n_, <int> axis_, f_arr)
333-
else:
334-
status = cdouble_cdouble_mkl_fft1d_out(
335-
x_arr, n_, <int> axis_, f_arr)
336-
else:
337-
if x_type is cnp.NPY_FLOAT:
338-
if dir_ < 0:
339-
status = float_cfloat_mkl_ifft1d_out(
340-
x_arr, n_, <int> axis_, f_arr, ALL_HARMONICS)
341-
else:
342-
status = float_cfloat_mkl_fft1d_out(
343-
x_arr, n_, <int> axis_, f_arr, ALL_HARMONICS)
344-
elif x_type is cnp.NPY_CFLOAT:
345-
if dir_ < 0:
346-
status = cfloat_cfloat_mkl_ifft1d_out(
347-
x_arr, n_, <int> axis_, f_arr)
348-
else:
349-
status = cfloat_cfloat_mkl_fft1d_out(
350-
x_arr, n_, <int> axis_, f_arr)
325+
with _lock:
326+
if f_type is cnp.NPY_CDOUBLE:
327+
if x_type is cnp.NPY_DOUBLE:
328+
if dir_ < 0:
329+
status = double_cdouble_mkl_ifft1d_out(
330+
x_arr, n_, <int> axis_, f_arr, ALL_HARMONICS)
331+
else:
332+
status = double_cdouble_mkl_fft1d_out(
333+
x_arr, n_, <int> axis_, f_arr, ALL_HARMONICS)
334+
elif x_type is cnp.NPY_CDOUBLE:
335+
if dir_ < 0:
336+
status = cdouble_cdouble_mkl_ifft1d_out(
337+
x_arr, n_, <int> axis_, f_arr)
338+
else:
339+
status = cdouble_cdouble_mkl_fft1d_out(
340+
x_arr, n_, <int> axis_, f_arr)
341+
else:
342+
if x_type is cnp.NPY_FLOAT:
343+
if dir_ < 0:
344+
status = float_cfloat_mkl_ifft1d_out(
345+
x_arr, n_, <int> axis_, f_arr, ALL_HARMONICS)
346+
else:
347+
status = float_cfloat_mkl_fft1d_out(
348+
x_arr, n_, <int> axis_, f_arr, ALL_HARMONICS)
349+
elif x_type is cnp.NPY_CFLOAT:
350+
if dir_ < 0:
351+
status = cfloat_cfloat_mkl_ifft1d_out(
352+
x_arr, n_, <int> axis_, f_arr)
353+
else:
354+
status = cfloat_cfloat_mkl_fft1d_out(
355+
x_arr, n_, <int> axis_, f_arr)
351356

352357
if (status):
353358
raise ValueError("Internal error occurred, status={}".format(status))
@@ -399,18 +404,19 @@ def _rrfft1d_impl(x, n=None, axis=-1, overwrite_arg=False, direction=+1):
399404
in_place = 1
400405

401406
if in_place:
402-
if x_type is cnp.NPY_DOUBLE:
403-
if dir_ < 0:
404-
status = double_mkl_irfft_in(x_arr, n_, <int> axis_)
405-
else:
406-
status = double_mkl_rfft_in(x_arr, n_, <int> axis_)
407-
elif x_type is cnp.NPY_FLOAT:
408-
if dir_ < 0:
409-
status = float_mkl_irfft_in(x_arr, n_, <int> axis_)
407+
with _lock:
408+
if x_type is cnp.NPY_DOUBLE:
409+
if dir_ < 0:
410+
status = double_mkl_irfft_in(x_arr, n_, <int> axis_)
411+
else:
412+
status = double_mkl_rfft_in(x_arr, n_, <int> axis_)
413+
elif x_type is cnp.NPY_FLOAT:
414+
if dir_ < 0:
415+
status = float_mkl_irfft_in(x_arr, n_, <int> axis_)
416+
else:
417+
status = float_mkl_rfft_in(x_arr, n_, <int> axis_)
410418
else:
411-
status = float_mkl_rfft_in(x_arr, n_, <int> axis_)
412-
else:
413-
status = 1
419+
status = 1
414420

415421
if status:
416422
raise ValueError("Internal error, status={}".format(status))
@@ -426,16 +432,17 @@ def _rrfft1d_impl(x, n=None, axis=-1, overwrite_arg=False, direction=+1):
426432
f_arr = __allocate_result(x_arr, n_, axis_, x_type);
427433

428434
# call out-of-place FFT
429-
if x_type is cnp.NPY_DOUBLE:
430-
if dir_ < 0:
431-
status = double_double_mkl_irfft_out(x_arr, n_, <int> axis_, f_arr)
432-
else:
433-
status = double_double_mkl_rfft_out(x_arr, n_, <int> axis_, f_arr)
434-
else:
435-
if dir_ < 0:
436-
status = float_float_mkl_irfft_out(x_arr, n_, <int> axis_, f_arr)
435+
with _lock:
436+
if x_type is cnp.NPY_DOUBLE:
437+
if dir_ < 0:
438+
status = double_double_mkl_irfft_out(x_arr, n_, <int> axis_, f_arr)
439+
else:
440+
status = double_double_mkl_rfft_out(x_arr, n_, <int> axis_, f_arr)
437441
else:
438-
status = float_float_mkl_rfft_out(x_arr, n_, <int> axis_, f_arr)
442+
if dir_ < 0:
443+
status = float_float_mkl_irfft_out(x_arr, n_, <int> axis_, f_arr)
444+
else:
445+
status = float_float_mkl_rfft_out(x_arr, n_, <int> axis_, f_arr)
439446

440447
if (status):
441448
raise ValueError("Internal error occurred, status={}".format(status))
@@ -487,9 +494,11 @@ def _rc_fft1d_impl(x, n=None, axis=-1, overwrite_arg=False):
487494

488495
# call out-of-place FFT
489496
if x_type is cnp.NPY_FLOAT:
490-
status = float_cfloat_mkl_fft1d_out(x_arr, n_, <int> axis_, f_arr, HALF_HARMONICS)
497+
with _lock:
498+
status = float_cfloat_mkl_fft1d_out(x_arr, n_, <int> axis_, f_arr, HALF_HARMONICS)
491499
else:
492-
status = double_cdouble_mkl_fft1d_out(x_arr, n_, <int> axis_, f_arr, HALF_HARMONICS)
500+
with _lock:
501+
status = double_cdouble_mkl_fft1d_out(x_arr, n_, <int> axis_, f_arr, HALF_HARMONICS)
493502

494503
if (status):
495504
raise ValueError("Internal error occurred, with status={}".format(status))
@@ -563,9 +572,11 @@ def _rc_ifft1d_impl(x, n=None, axis=-1, overwrite_arg=False):
563572

564573
# call out-of-place FFT
565574
if x_type is cnp.NPY_CFLOAT:
566-
status = cfloat_float_mkl_irfft_out(x_arr, n_, <int> axis_, f_arr)
575+
with _lock:
576+
status = cfloat_float_mkl_irfft_out(x_arr, n_, <int> axis_, f_arr)
567577
else:
568-
status = cdouble_double_mkl_irfft_out(x_arr, n_, <int> axis_, f_arr)
578+
with _lock:
579+
status = cdouble_double_mkl_irfft_out(x_arr, n_, <int> axis_, f_arr)
569580

570581
if (status):
571582
raise ValueError("Internal error occurred, status={}".format(status))

mkl_fft/setup.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -51,7 +51,6 @@ def configuration(parent_package='',top_path=None):
5151

5252
config.add_extension(
5353
name = '_pydfti',
54-
# module_name = 'mkl_fft._pydfti',
5554
sources = [
5655
join(wdir, 'mklfft.c.src'),
5756
join(wdir, 'multi_iter.c'),

0 commit comments

Comments
 (0)