diff --git a/python/cufinufft/cufinufft/_plan.py b/python/cufinufft/cufinufft/_plan.py index 76eb63471..c9bc3a326 100644 --- a/python/cufinufft/cufinufft/_plan.py +++ b/python/cufinufft/cufinufft/_plan.py @@ -83,20 +83,20 @@ def __init__(self, nufft_type, n_modes, n_trans=1, eps=1e-6, isign=None, self._plan = None # Setup type bound methods - self.dtype = np.dtype(dtype) + self._dtype = np.dtype(dtype) - if self.dtype == np.complex128: + if self._dtype == np.complex128: self._make_plan = _make_plan self._setpts = _set_pts self._exec_plan = _exec_plan self._destroy_plan = _destroy_plan - self.real_dtype = np.float64 - elif self.dtype == np.complex64: + self._real_dtype = np.float64 + elif self._dtype == np.complex64: self._make_plan = _make_planf self._setpts = _set_ptsf self._exec_plan = _exec_planf self._destroy_plan = _destroy_planf - self.real_dtype = np.float32 + self._real_dtype = np.float32 else: raise TypeError("Expected complex64 or complex128.") @@ -118,12 +118,12 @@ def __init__(self, nufft_type, n_modes, n_trans=1, eps=1e-6, isign=None, if dim not in [1, 2, 3]: raise ValueError("Only dimensions 1, 2, and 3 supported") - self.dim = dim - self.type = nufft_type - self.isign = isign - self.eps = float(eps) - self.n_modes = n_modes - self.n_trans = n_trans + self._dim = dim + self._type = nufft_type + self._isign = isign + self._eps = float(eps) + self._n_modes = n_modes + self._n_trans = n_trans self._maxbatch = 1 # TODO: optimize this one day # Get the default option values. @@ -146,6 +146,26 @@ def __init__(self, nufft_type, n_modes, n_trans=1, eps=1e-6, isign=None, # we want to keep around for life of instance. self._references = [] + @property + def type(self): + return self._type + + @property + def dtype(self): + return self._dtype + + @property + def dim(self): + return self._dim + + @property + def n_modes(self): + return self._n_modes + + @property + def n_trans(self): + return self._n_trans + @staticmethod def _default_opts(): """ @@ -174,15 +194,15 @@ def _init_plan(self): # We extend the mode tuple to 3D as needed, # and reorder from C/python ndarray.shape style input (nZ, nY, nX) # to the (F) order expected by the low level library (nX, nY, nZ). - _n_modes = self.n_modes[::-1] + (1,) * (3 - self.dim) + _n_modes = self._n_modes[::-1] + (1,) * (3 - self._dim) _n_modes = (c_int64 * 3)(*_n_modes) - ier = self._make_plan(self.type, - self.dim, + ier = self._make_plan(self._type, + self._dim, _n_modes, - self.isign, - self.n_trans, - self.eps, + self._isign, + self._n_trans, + self._eps, byref(self._plan), self._opts) @@ -209,20 +229,20 @@ def setpts(self, x, y=None, z=None, s=None, t=None, u=None): points (source for type 1, target for type 2). """ - _x = _ensure_array_type(x, "x", self.real_dtype) - _y = _ensure_array_type(y, "y", self.real_dtype) - _z = _ensure_array_type(z, "z", self.real_dtype) + _x = _ensure_array_type(x, "x", self._real_dtype) + _y = _ensure_array_type(y, "y", self._real_dtype) + _z = _ensure_array_type(z, "z", self._real_dtype) - _x, _y, _z = _ensure_valid_pts(_x, _y, _z, self.dim) + _x, _y, _z = _ensure_valid_pts(_x, _y, _z, self._dim) M = _compat.get_array_size(_x) - if self.type == 3: - _s = _ensure_array_type(s, "s", self.real_dtype) - _t = _ensure_array_type(t, "t", self.real_dtype) - _u = _ensure_array_type(u, "u", self.real_dtype) + if self._type == 3: + _s = _ensure_array_type(s, "s", self._real_dtype) + _t = _ensure_array_type(t, "t", self._real_dtype) + _u = _ensure_array_type(u, "u", self._real_dtype) - _s, _t, _u = _ensure_valid_pts(_s, _t, _u, self.dim) + _s, _t, _u = _ensure_valid_pts(_s, _t, _u, self._dim) N = _compat.get_array_size(_s) else: @@ -242,22 +262,22 @@ def setpts(self, x, y=None, z=None, s=None, t=None, u=None): # We will also store references to these arrays. # This keeps python from prematurely cleaning them up. self._references.append(_x) - if self.dim >= 2: + if self._dim >= 2: fpts_axes.insert(0, _compat.get_array_ptr(_y)) self._references.append(_y) - if self.dim >= 3: + if self._dim >= 3: fpts_axes.insert(0, _compat.get_array_ptr(_z)) self._references.append(_z) # Do the same for type 3 - if self.type == 3: + if self._type == 3: fpts_axes_t3 = [_compat.get_array_ptr(_s), None, None] self._references.append(_s) - if self.dim >= 2: + if self._dim >= 2: fpts_axes_t3.insert(0, _compat.get_array_ptr(_t)) self._references.append(_t) - if self.dim >= 3: + if self._dim >= 3: fpts_axes_t3.insert(0, _compat.get_array_ptr(_u)) self._references.append(_u) else: @@ -268,8 +288,8 @@ def setpts(self, x, y=None, z=None, s=None, t=None, u=None): M, *fpts_axes[:3], N, *fpts_axes_t3[:3]) - self.nj = M - self.nk = N + self._nj = M + self._nk = N if ier != 0: raise RuntimeError('Error setting non-uniform points.') @@ -297,37 +317,37 @@ def execute(self, data, out=None): The output array of the transform(s). """ - _data = _ensure_array_type(data, "data", self.dtype) - _out = _ensure_array_type(out, "out", self.dtype, output=True) + _data = _ensure_array_type(data, "data", self._dtype) + _out = _ensure_array_type(out, "out", self._dtype, output=True) - if self.type == 1: - req_data_shape = (self.n_trans, self.nj) - req_out_shape = self.n_modes - elif self.type == 2: - req_data_shape = (self.n_trans, *self.n_modes) - req_out_shape = (self.nj,) - elif self.type == 3: - req_data_shape = (self.n_trans, self.nj) - req_out_shape = (self.nk,) + if self._type == 1: + req_data_shape = (self._n_trans, self._nj) + req_out_shape = self._n_modes + elif self._type == 2: + req_data_shape = (self._n_trans, *self._n_modes) + req_out_shape = (self._nj,) + elif self._type == 3: + req_data_shape = (self._n_trans, self._nj) + req_out_shape = (self._nk,) _data, data_shape = _ensure_array_shape(_data, "data", req_data_shape, allow_reshape=True) - if self.type == 1: + if self._type == 1: batch_shape = data_shape[:-1] else: - batch_shape = data_shape[:-self.dim] + batch_shape = data_shape[:-self._dim] req_out_shape = batch_shape + req_out_shape if out is None: - _out = _compat.array_empty_like(_data, req_out_shape, dtype=self.dtype) + _out = _compat.array_empty_like(_data, req_out_shape, dtype=self._dtype) else: _out = _ensure_array_shape(_out, "out", req_out_shape) - if self.type in [1, 3]: + if self._type in [1, 3]: ier = self._exec_plan(self._plan, _compat.get_array_ptr(_data), _compat.get_array_ptr(_out)) - elif self.type == 2: + elif self._type == 2: ier = self._exec_plan(self._plan, _compat.get_array_ptr(_out), _compat.get_array_ptr(_data)) diff --git a/python/cufinufft/tests/test_basic.py b/python/cufinufft/tests/test_basic.py index f826f6ef2..01645781b 100644 --- a/python/cufinufft/tests/test_basic.py +++ b/python/cufinufft/tests/test_basic.py @@ -166,3 +166,33 @@ def test_opts(to_gpu, to_cpu, shape=(8, 8, 8), M=32, tol=1e-3): fk = to_cpu(fk_gpu) utils.verify_type1(k, c, fk, tol) + + +def test_cufinufft_plan_properties(): + nufft_type = 2 + n_modes = (8, 8) + n_trans = 2 + dtype = np.complex64 + + plan = Plan(nufft_type, n_modes, n_trans, dtype=dtype) + + assert plan.type == nufft_type + assert tuple(plan.n_modes) == n_modes + assert plan.dim == len(n_modes) + assert plan.n_trans == n_trans + assert plan.dtype == dtype + + with pytest.raises(AttributeError): + plan.type = 1 + + with pytest.raises(AttributeError): + plan.n_modes = (4, 4) + + with pytest.raises(AttributeError): + plan.dim = 1 + + with pytest.raises(AttributeError): + plan.n_trans = 1 + + with pytest.raises(AttributeError): + plan.dtype = np.float64 diff --git a/python/finufft/finufft/_interfaces.py b/python/finufft/finufft/_interfaces.py index bd6c17564..23d5ef9e8 100644 --- a/python/finufft/finufft/_interfaces.py +++ b/python/finufft/finufft/_interfaces.py @@ -144,19 +144,38 @@ def __init__(self,nufft_type,n_modes_or_dim,n_trans=1,eps=1e-6,isign=None,dtype= err_handler(ier) # set C++ side plan as inner_plan - self.inner_plan = plan + self._inner_plan = plan # set properties - self.type = nufft_type - self.dim = dim - self.n_modes = n_modes - self.n_trans = n_trans + self._type = nufft_type + self._dim = dim + self._n_modes = n_modes + self._n_trans = n_trans if is_single: - self.dtype = np.dtype("complex64") + self._dtype = np.dtype("complex64") else: - self.dtype = np.dtype("complex128") + self._dtype = np.dtype("complex128") + @property + def type(self): + return self._type + + @property + def dtype(self): + return self._dtype + + @property + def dim(self): + return self._dim + + @property + def n_modes(self): + return self._n_modes + + @property + def n_trans(self): + return self._n_trans ### setpts def setpts(self,x=None,y=None,z=None,s=None,t=None,u=None): @@ -187,7 +206,7 @@ def setpts(self,x=None,y=None,z=None,s=None,t=None,u=None): points (target for type 3). """ - real_dtype = _get_real_dtype(self.dtype) + real_dtype = _get_real_dtype(self._dtype) self._xj = _ensure_array_type(x, "x", real_dtype) self._yj = _ensure_array_type(y, "y", real_dtype) @@ -197,17 +216,17 @@ def setpts(self,x=None,y=None,z=None,s=None,t=None,u=None): self._u = _ensure_array_type(u, "u", real_dtype) # valid sizes - dim = self.dim - tp = self.type - (self.nj, self.nk) = valid_setpts(tp, dim, self._xj, self._yj, self._zj, self._s, self._t, self._u) + dim = self._dim + tp = self._type + (self._nj, self._nk) = valid_setpts(tp, dim, self._xj, self._yj, self._zj, self._s, self._t, self._u) # call set pts for single prec plan - if self.dim == 1: - ier = self._setpts(self.inner_plan, self.nj, self._xj, self._yj, self._zj, self.nk, self._s, self._t, self._u) - elif self.dim == 2: - ier = self._setpts(self.inner_plan, self.nj, self._yj, self._xj, self._zj, self.nk, self._t, self._s, self._u) - elif self.dim == 3: - ier = self._setpts(self.inner_plan, self.nj, self._zj, self._yj, self._xj, self.nk, self._u, self._t, self._s) + if self._dim == 1: + ier = self._setpts(self._inner_plan, self._nj, self._xj, self._yj, self._zj, self._nk, self._s, self._t, self._u) + elif self._dim == 2: + ier = self._setpts(self._inner_plan, self._nj, self._yj, self._xj, self._zj, self._nk, self._t, self._s, self._u) + elif self._dim == 3: + ier = self._setpts(self._inner_plan, self._nj, self._zj, self._yj, self._xj, self._nk, self._u, self._t, self._s) if ier != 0: err_handler(ier) @@ -234,17 +253,17 @@ def execute(self,data,out=None): complex[n_modes], complex[n_tr, n_modes], complex[M], or complex[n_tr, M]: The output array of the transform(s). """ - _data = _ensure_array_type(data, "data", self.dtype) - _out = _ensure_array_type(out, "out", self.dtype, output=True) + _data = _ensure_array_type(data, "data", self._dtype) + _out = _ensure_array_type(out, "out", self._dtype, output=True) - tp = self.type - n_trans = self.n_trans - nj = self.nj - nk = self.nk - dim = self.dim + tp = self._type + n_trans = self._n_trans + nj = self._nj + nk = self._nk + dim = self._dim if tp==1 or tp==2: - ms, mt, mu = [*self.n_modes, *([1]*(3-len(self.n_modes)))] + ms, mt, mu = [*self._n_modes, *([1]*(3-len(self._n_modes)))] # input shape and size check if tp==2: @@ -264,19 +283,19 @@ def execute(self,data,out=None): # allocate out if None if out is None: if tp==1: - _out = np.zeros([*data.shape[:-1], *self.n_modes[::-1]], dtype=self.dtype, order='C') + _out = np.zeros([*data.shape[:-1], *self._n_modes[::-1]], dtype=self._dtype, order='C') if tp==2: - _out = np.zeros([*data.shape[:-dim], nj], dtype=self.dtype, order='C') + _out = np.zeros([*data.shape[:-dim], nj], dtype=self._dtype, order='C') if tp==3: - _out = np.zeros([*data.shape[:-1], nk], dtype=self.dtype, order='C') + _out = np.zeros([*data.shape[:-1], nk], dtype=self._dtype, order='C') # call execute based on type and precision type if tp==1 or tp==3: - ier = self._execute(self.inner_plan, + ier = self._execute(self._inner_plan, _data.ctypes.data_as(c_void_p), _out.ctypes.data_as(c_void_p)) elif tp==2: - ier = self._execute(self.inner_plan, + ier = self._execute(self._inner_plan, _out.ctypes.data_as(c_void_p), _data.ctypes.data_as(c_void_p)) @@ -289,7 +308,7 @@ def execute(self,data,out=None): def __del__(self): destroy(self) - self.inner_plan = None + self._inner_plan = None ### End of Plan class definition @@ -491,8 +510,8 @@ def setkwopts(opt,**kwargs): ### destroy def destroy(plan): - if hasattr(plan, "inner_plan"): - ier = plan._destroy(plan.inner_plan) + if hasattr(plan, "_inner_plan"): + ier = plan._destroy(plan._inner_plan) if ier != 0: err_handler(ier) diff --git a/python/finufft/test/test_finufft_plan.py b/python/finufft/test/test_finufft_plan.py index 952a79753..a062c57e3 100644 --- a/python/finufft/test/test_finufft_plan.py +++ b/python/finufft/test/test_finufft_plan.py @@ -238,3 +238,34 @@ def test_finufft_plan_errors(): with pytest.raises(RuntimeError, match="transform type invalid"): plan = Plan(4, (8,)) + + +def test_finufft_plan_properties(): + # Make sure properties work properly + nufft_type = 2 + n_modes = (8, 8) + n_trans = 2 + dtype = np.complex64 + + plan = Plan(nufft_type, n_modes, n_trans, dtype=dtype) + + assert plan.type == nufft_type + assert tuple(plan.n_modes) == n_modes + assert plan.dim == len(n_modes) + assert plan.n_trans == n_trans + assert plan.dtype == dtype + + with pytest.raises(AttributeError): + plan.type = 1 + + with pytest.raises(AttributeError): + plan.n_modes = (4, 4) + + with pytest.raises(AttributeError): + plan.dim = 1 + + with pytest.raises(AttributeError): + plan.n_trans = 1 + + with pytest.raises(AttributeError): + plan.dtype = np.float64