diff --git a/Cargo.toml b/Cargo.toml index f4a3c3ff..4a9d8217 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -15,12 +15,11 @@ source = ["fftw-sys/source"] intel-mkl = ["fftw-sys/intel-mkl"] [dependencies] -num-traits = "0.1.37" -num-complex = "0.1.37" -lazy_static = "0.2.2" -derive-new = "0.4" -ndarray = "0.10" -ndarray-linalg = "0.7" +num-traits = "0.1" +num-complex = "0.1" +lazy_static = "0.2" +derive-new = "0.5" +ndarray = "0.11" procedurals = "0.2" [dependencies.fftw-sys] diff --git a/examples/r2r.rs b/examples/r2r.rs deleted file mode 100644 index 8d7128be..00000000 --- a/examples/r2r.rs +++ /dev/null @@ -1,18 +0,0 @@ -extern crate fftw; -extern crate fftw_sys as ffi; - -use fftw::*; - -fn main() { - let n = 128; - // Create a pair of array for out-place transform of FFTW - let mut pair = r2hc_1d(n).to_pair().unwrap(); - // Initialize to `cos(x)` in bficient space - pair.b.as_view_mut()[1] = 1.0; - // execute rDCT - pair.exec_backward(); - - for val in pair.a.as_slice().iter() { - println!("{}", val); - } -} diff --git a/src/array.rs b/src/array.rs index 8d390ddf..d67d16c9 100644 --- a/src/array.rs +++ b/src/array.rs @@ -1,6 +1,6 @@ -use super::{c32, c64, FFTW_MUTEX}; use error::*; use ffi; +use types::*; use ndarray::*; use num_traits::Zero; @@ -147,9 +147,7 @@ impl DerefMut for AlignedVec { impl Drop for AlignedVec { fn drop(&mut self) { - let lock = FFTW_MUTEX.lock().expect("Cannot get lock"); - unsafe { ffi::fftw_free(self.data as *mut c_void) }; - drop(lock); + excall! { ffi::fftw_free(self.data as *mut c_void) }; } } @@ -159,9 +157,7 @@ where { /// Create array with `fftw_malloc` (`fftw_free` is automatically called when the arrya is `Drop`-ed) pub fn new(n: usize) -> Self { - let lock = FFTW_MUTEX.lock().expect("Cannot get lock"); - let ptr = unsafe { T::alloc(n) }; - drop(lock); + let ptr = excall! { T::alloc(n) }; let mut vec = AlignedVec { n: n, data: ptr }; for v in vec.iter_mut() { *v = T::zero(); diff --git a/src/c2c.rs b/src/c2c.rs deleted file mode 100644 index 3e471bea..00000000 --- a/src/c2c.rs +++ /dev/null @@ -1,42 +0,0 @@ -use super::*; -use super::array::*; -use super::error::*; -use super::pair::{Pair, ToPair}; -use super::plan::*; -use super::traits::*; - -use ndarray::*; -use ndarray_linalg::Scalar; - -/// Setting for 1-dimensional C2C transform -#[derive(Debug, Clone, Copy, new)] -pub struct C2C1D { - n: usize, - sign: Sign, - flag: Flag, -} - -/// Utility function to generage 1-dimensional C2C setting -pub fn c2c_1d(n: usize) -> C2C1D { - C2C1D { - n, - sign: Sign::Forward, - flag: Flag::Measure, - } -} - -impl ToPair for C2C1D { - type Dim = Ix1; - fn to_pair(&self) -> Result> { - let mut a = AlignedVec::new(self.n); - let mut b = AlignedVec::new(self.n); - let forward = unsafe { T::c2c_1d(self.n, &mut a, &mut b, self.sign, self.flag) }; - let backward = unsafe { T::c2c_1d(self.n, &mut b, &mut a, -self.sign, self.flag) }; - Pair { - a: AlignedArray::from_vec(a), - b: AlignedArray::from_vec(b), - forward: Plan::with_factor(forward, Scalar::from_f64(1.0 / self.n as f64)), - backward: Plan::new(backward), - }.null_checked() - } -} diff --git a/src/error.rs b/src/error.rs index a617f45c..bc5111b4 100644 --- a/src/error.rs +++ b/src/error.rs @@ -1,17 +1,14 @@ use ndarray::ShapeError; -pub use ndarray_linalg::error::{MemoryContError, StrideError}; pub type Result = ::std::result::Result; -use super::nae::NAEInputMismatchError; +use super::plan::InputMismatchError; #[derive(Debug, IntoEnum)] pub enum Error { InvalidPlanError(InvalidPlanError), ShapeError(ShapeError), - StrideError(StrideError), - MemoryContError(MemoryContError), - NAEInputMismatchError(NAEInputMismatchError), + InputMismatchError(InputMismatchError), } #[derive(Debug, new)] diff --git a/src/lib.rs b/src/lib.rs index c59f88e9..cae306c2 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -5,94 +5,28 @@ extern crate lazy_static; #[macro_use] extern crate procedurals; +extern crate fftw_sys as ffi; + extern crate ndarray; extern crate num_complex; extern crate num_traits; -// XXX For ndarray_linalg::Scalar -// Will be removed if the following PR to num-complex is merged -// https://github.com/rust-num/num/pull/338 -extern crate ndarray_linalg; - -extern crate fftw_sys as ffi; - -pub mod pair; -pub mod r2r; -pub mod r2c; -pub mod c2c; -pub mod array; -pub mod error; -pub mod plan; -pub mod nae; -pub mod traits; - -pub use ffi::fftw_complex as c64; -pub use ffi::fftwf_complex as c32; - -pub use c2c::*; -pub use pair::*; -pub use r2c::*; -pub use r2r::*; -pub use traits::*; use std::sync::Mutex; + lazy_static! { pub static ref FFTW_MUTEX: Mutex<()> = Mutex::new(()); } -#[repr(i32)] -#[derive(Debug, Copy, Clone, PartialEq, Eq, Hash)] -pub enum Sign { - Forward = -1, - Backward = 1, -} - -impl ::std::ops::Neg for Sign { - type Output = Sign; - fn neg(self) -> Self::Output { - match self { - Sign::Forward => Sign::Backward, - Sign::Backward => Sign::Forward, - } - } -} - -#[derive(Debug, Clone, Copy, Eq, PartialEq)] -pub enum Flag { - Measure, - DestroyInput, - Unaligned, - ConserveMemory, - Exhausive, - PreserveInput, - Patient, - Estimate, - WisdowmOnly, - Mixed(u32), -} - -impl Into for Flag { - fn into(self) -> u32 { - use Flag::*; - match self { - Measure => 0, - DestroyInput => 1 << 0, - Unaligned => 1 << 1, - ConserveMemory => 1 << 2, - Exhausive => 1 << 3, - PreserveInput => 1 << 4, - Patient => 1 << 5, - Estimate => 1 << 6, - WisdowmOnly => 1 << 21, - Mixed(u) => u, - } +/// Exclusive call of FFTW interface. +macro_rules! excall { + ($call:expr) => { + { + let _lock = $crate::FFTW_MUTEX.lock().expect("Cannot get lock"); + unsafe { $call } } -} +}} // excall! -impl ::std::ops::BitOr for Flag { - type Output = Self; - fn bitor(self, rhs: Self) -> Self { - let lhs: u32 = self.into(); - let rhs: u32 = rhs.into(); - Flag::Mixed(lhs | rhs) - } -} +pub mod array; +pub mod error; +pub mod types; +pub mod plan; diff --git a/src/nae.rs b/src/nae.rs deleted file mode 100644 index 12559f94..00000000 --- a/src/nae.rs +++ /dev/null @@ -1,350 +0,0 @@ -use super::*; -use error::*; -use ffi::*; - -#[derive(Debug)] -pub struct C2CPlan { - plan: C::Plan, - alignment: Alignment, -} - -impl Drop for C2CPlan { - fn drop(&mut self) { - C::destroy_plan(self.plan); - } -} - -impl> C2CPlan { - pub fn new( - shape: &[usize], - in_: &mut [C], - out: &mut [C], - sign: Sign, - flag: Flag, - ) -> Result { - Ok(Self { - plan: C::plan_c2c(&shape.to_cint(), in_, out, sign, flag)?, - alignment: Alignment::new(in_, out), - }) - } - pub fn c2c(&mut self, in_: &mut [C], out: &mut [C]) -> Result<()> { - self.alignment.check(in_, out)?; - C::exec_c2c(self.plan, in_, out); - Ok(()) - } -} - -#[derive(Debug)] -pub struct C2RPlan -where - C: FFTW, - R: FFTW, -{ - plan: C::Plan, - alignment: Alignment, -} - -impl Drop for C2RPlan -where - C: FFTW, - R: FFTW, -{ - fn drop(&mut self) { - R::destroy_plan(self.plan); - } -} - -impl C2RPlan -where - C: FFTW, - R: FFTW, -{ - pub fn new(shape: &[usize], in_: &mut [C], out: &mut [R], flag: Flag) -> Result { - Ok(Self { - plan: C::plan_c2r(&shape.to_cint(), in_, out, flag)?, - alignment: Alignment::new(in_, out), - }) - } - pub fn c2r(&mut self, in_: &mut [C], out: &mut [R]) -> Result<()> { - self.alignment.check(in_, out)?; - C::exec_c2r(self.plan, in_, out); - Ok(()) - } -} - -#[derive(Debug)] -pub struct R2CPlan -where - C: FFTW, - R: FFTW, -{ - plan: C::Plan, - alignment: Alignment, -} - -impl Drop for R2CPlan -where - C: FFTW, - R: FFTW, -{ - fn drop(&mut self) { - R::destroy_plan(self.plan); - } -} - -impl R2CPlan -where - C: FFTW, - R: FFTW, -{ - pub fn new(shape: &[usize], in_: &mut [R], out: &mut [C], flag: Flag) -> Result { - Ok(Self { - plan: C::plan_r2c(&shape.to_cint(), in_, out, flag)?, - alignment: Alignment::new(in_, out), - }) - } - pub fn r2c(&mut self, in_: &mut [R], out: &mut [C]) -> Result<()> { - self.alignment.check(in_, out)?; - C::exec_r2c(self.plan, in_, out); - Ok(()) - } -} - -pub type R2RKind = ffi::fftw_r2r_kind; - -#[derive(Debug)] -pub struct R2RPlan { - plan: R::Plan, - alignment: Alignment, -} - -impl Drop for R2RPlan { - fn drop(&mut self) { - R::destroy_plan(self.plan); - } -} - -impl> R2RPlan { - pub fn new( - shape: &[usize], - in_: &mut [R], - out: &mut [R], - kinds: &[R2RKind], - flag: Flag, - ) -> Result { - Ok(Self { - plan: R::plan_r2r(&shape.to_cint(), in_, out, kinds, flag)?, - alignment: Alignment::new(in_, out), - }) - } - pub fn r2c(&mut self, in_: &mut [R], out: &mut [R]) -> Result<()> { - self.alignment.check(in_, out)?; - R::exec_r2r(self.plan, in_, out); - Ok(()) - } -} - -pub trait Plan: Sized { - fn check_null(self) -> Result; -} - -impl Plan for fftw_plan { - fn check_null(self) -> Result { - if self.is_null() { - Err(InvalidPlanError::new().into()) - } else { - Ok(self) - } - } -} - -impl Plan for fftwf_plan { - fn check_null(self) -> Result { - if self.is_null() { - Err(InvalidPlanError::new().into()) - } else { - Ok(self) - } - } -} - -/// Switch `fftw_*` and `fftwf_*` -pub trait FFTW { - type Plan: Plan + Copy; - type Real; - type Complex; - fn destroy_plan(Self::Plan); - fn print_plan(Self::Plan); - fn plan_c2c( - shape: &[i32], - in_: &mut [Self::Complex], - out: &mut [Self::Complex], - Sign, - Flag, - ) -> Result; - fn plan_c2r( - shape: &[i32], - in_: &mut [Self::Complex], - out: &mut [Self::Real], - Flag, - ) -> Result; - fn plan_r2c( - shape: &[i32], - in_: &mut [Self::Real], - out: &mut [Self::Complex], - Flag, - ) -> Result; - fn plan_r2r( - shape: &[i32], - in_: &mut [Self::Real], - out: &mut [Self::Real], - &[R2RKind], - Flag, - ) -> Result; - fn exec_c2c(p: Self::Plan, in_: &mut [Self::Complex], out: &mut [Self::Complex]); - fn exec_c2r(p: Self::Plan, in_: &mut [Self::Complex], out: &mut [Self::Real]); - fn exec_r2c(p: Self::Plan, in_: &mut [Self::Real], out: &mut [Self::Complex]); - fn exec_r2r(p: Self::Plan, in_: &mut [Self::Real], out: &mut [Self::Real]); - fn alignment_of(&[T]) -> i32; -} - -macro_rules! excall { - ($call:expr) => { - let _lock = FFTW_MUTEX.lock().expect("Cannot get lock"); - unsafe { $call } - } -} - -macro_rules! impl_fftw { ($scalar:ty) => { -impl FFTW for $scalar { - type Real = f64; - type Complex = c64; - type Plan = fftw_plan; - fn destroy_plan(p: Self::Plan) { - excall!{ fftw_destroy_plan(p) }; - } - fn print_plan(p: Self::Plan) { - excall!{ fftw_print_plan(p) }; - } - fn plan_c2c(shape: &[i32], in_: &mut [Self::Complex], out: &mut [Self::Complex], sign: Sign, flag: Flag) -> Result { - excall!{ fftw_plan_dft(shape.len() as i32, shape.as_ptr(), in_.as_mut_ptr(), out.as_mut_ptr(), sign as i32, flag.into()).check_null() } - } - fn plan_c2r(shape: &[i32], in_: &mut [Self::Complex], out: &mut [Self::Real], flag: Flag) -> Result { - excall!{ fftw_plan_dft_c2r(shape.len() as i32, shape.as_ptr(), in_.as_mut_ptr(), out.as_mut_ptr(), flag.into()).check_null() } - } - fn plan_r2c(shape: &[i32], in_: &mut [Self::Real], out: &mut [Self::Complex], flag: Flag) -> Result { - excall!{ fftw_plan_dft_r2c(shape.len() as i32, shape.as_ptr(), in_.as_mut_ptr(), out.as_mut_ptr(), flag.into()).check_null() } - } - fn plan_r2r(shape: &[i32], in_: &mut [Self::Real], out: &mut [Self::Real], kinds: &[R2RKind], flag: Flag) -> Result { - excall!{ fftw_plan_r2r(shape.len() as i32, shape.as_ptr(), in_.as_mut_ptr(), out.as_mut_ptr(), kinds.as_ptr(), flag.into()).check_null() } - } - fn exec_c2c(p: Self::Plan, in_: &mut [Self::Complex], out: &mut [Self::Complex]) { - unsafe { fftw_execute_dft(p, in_.as_mut_ptr(), out.as_mut_ptr()) }; - } - fn exec_c2r(p: Self::Plan, in_: &mut [Self::Complex], out: &mut [Self::Real]) { - unsafe { fftw_execute_dft_c2r(p, in_.as_mut_ptr(), out.as_mut_ptr()) }; - } - fn exec_r2c(p: Self::Plan, in_: &mut [Self::Real], out: &mut [Self::Complex]) { - unsafe { fftw_execute_dft_r2c(p, in_.as_mut_ptr(), out.as_mut_ptr()) }; - } - fn exec_r2r(p: Self::Plan, in_: &mut [Self::Real], out: &mut [Self::Real]) { - unsafe { fftw_execute_r2r(p, in_.as_mut_ptr(), out.as_mut_ptr()) }; - } - fn alignment_of(s: &[T]) -> i32 { - unsafe { fftw_alignment_of(s.as_ptr() as *mut _) } - } -} -}} // impl_fftw - -impl_fftw!(f64); -impl_fftw!(c64); - -macro_rules! impl_fftwf { ($scalar:ty) => { -impl FFTW for $scalar { - type Real = f32; - type Complex = c32; - type Plan = fftwf_plan; - fn destroy_plan(p: Self::Plan) { - excall!{ fftwf_destroy_plan(p) }; - } - fn print_plan(p: Self::Plan) { - excall!{ fftwf_print_plan(p) }; - } - fn plan_c2c(shape: &[i32], in_: &mut [Self::Complex], out: &mut [Self::Complex], sign: Sign, flag: Flag) -> Result { - excall!{ fftwf_plan_dft(shape.len() as i32, shape.as_ptr(), in_.as_mut_ptr(), out.as_mut_ptr(), sign as i32, flag.into()).check_null() } - } - fn plan_c2r(shape: &[i32], in_: &mut [Self::Complex], out: &mut [Self::Real], flag: Flag) -> Result { - excall!{ fftwf_plan_dft_c2r(shape.len() as i32, shape.as_ptr(), in_.as_mut_ptr(), out.as_mut_ptr(), flag.into()).check_null() } - } - fn plan_r2c(shape: &[i32], in_: &mut [Self::Real], out: &mut [Self::Complex], flag: Flag) -> Result { - excall!{ fftwf_plan_dft_r2c(shape.len() as i32, shape.as_ptr(), in_.as_mut_ptr(), out.as_mut_ptr(), flag.into()).check_null() } - } - fn plan_r2r(shape: &[i32], in_: &mut [Self::Real], out: &mut [Self::Real], kinds: &[R2RKind], flag: Flag) -> Result { - excall!{ fftwf_plan_r2r(shape.len() as i32, shape.as_ptr(), in_.as_mut_ptr(), out.as_mut_ptr(), kinds.as_ptr(), flag.into()).check_null() } - } - fn exec_c2c(p: Self::Plan, in_: &mut [Self::Complex], out: &mut [Self::Complex]) { - unsafe { fftwf_execute_dft(p, in_.as_mut_ptr(), out.as_mut_ptr()) }; - } - fn exec_c2r(p: Self::Plan, in_: &mut [Self::Complex], out: &mut [Self::Real]) { - unsafe { fftwf_execute_dft_c2r(p, in_.as_mut_ptr(), out.as_mut_ptr()) }; - } - fn exec_r2c(p: Self::Plan, in_: &mut [Self::Real], out: &mut [Self::Complex]) { - unsafe { fftwf_execute_dft_r2c(p, in_.as_mut_ptr(), out.as_mut_ptr()) }; - } - fn exec_r2r(p: Self::Plan, in_: &mut [Self::Real], out: &mut [Self::Real]) { - unsafe { fftwf_execute_r2r(p, in_.as_mut_ptr(), out.as_mut_ptr()) }; - } - fn alignment_of(s: &[T]) -> i32 { - unsafe { fftwf_alignment_of(s.as_ptr() as *mut _) } - } -} -}} // impl_fftwf - -impl_fftwf!(f32); -impl_fftwf!(c32); - -#[derive(Debug)] -pub struct NAEInputMismatchError { - origin: Alignment, - args: Alignment, -} - -#[derive(Debug, Clone, Copy, Eq, PartialEq)] -struct Alignment { - in_: i32, - out: i32, - n_in_: usize, - n_out: usize, -} - -impl Alignment { - fn new(in_: &[A], out: &[B]) -> Self { - Self { - in_: A::alignment_of(in_), - out: B::alignment_of(out), - n_in_: in_.len(), - n_out: out.len(), - } - } - fn check(&self, in_: &[A], out: &[B]) -> Result<()> { - let args = Self::new(in_, out); - if *self != args { - Err(NAEInputMismatchError { - origin: *self, - args, - }.into()) - } else { - Ok(()) - } - } -} - -trait ToCInt { - fn to_cint(&self) -> Vec; -} - -impl ToCInt for [usize] { - fn to_cint(&self) -> Vec { - self.iter().map(|&x| x as i32).collect() - } -} diff --git a/src/pair.rs b/src/pair.rs deleted file mode 100644 index f2c4f5ae..00000000 --- a/src/pair.rs +++ /dev/null @@ -1,98 +0,0 @@ -//! Safe-interface corresponding to out-place transform - -use super::array::*; -use super::error::*; -use super::plan::*; - -use ndarray::*; -use ndarray_linalg::Scalar; - -/// Safe-interface corresponding to out-place transform -/// -/// FFTW interface modifies an array in `fftw_execute` function -/// which does not takes the array as its arguments. -/// It is not compatible to the programing model of safe Rust. -/// `Pair` interface composes the array and plan to manage -/// mutability in the safe Rust way. -#[derive(Debug)] -pub struct Pair -where - A: Scalar + AlignedAllocable, - B: Scalar + AlignedAllocable, -{ - pub a: AlignedArray, - pub b: AlignedArray, - pub(crate) forward: Plan, - pub(crate) backward: Plan, -} - -impl Pair -where - A: Scalar + AlignedAllocable, - B: Scalar + AlignedAllocable, -{ - /// Execute `Pair::forward` with `ndarray::ArrayView` - pub fn forward_array<'a, 'b>( - &'a mut self, - input: ArrayView<'b, A, D>, - ) -> ArrayViewMut<'a, B, D> { - self.a.as_view_mut().assign(&input); - self.exec_forward(); - self.b.as_view_mut() - } - - /// Execute `Pair::backward` with `ndarray::ArrayView` - pub fn backward_array<'a, 'b>( - &'a mut self, - input: ArrayView<'b, B, D>, - ) -> ArrayViewMut<'a, A, D> { - self.b.as_view_mut().assign(&input); - self.exec_backward(); - self.a.as_view_mut() - } - - /// Executes copy the input to `a`, forward transform, - /// and returns the result `b` as a reference - pub fn forward(&mut self, input: &[A]) -> &mut [B] { - self.a.copy_from_slice(input); - self.exec_forward(); - self.b.as_slice_mut() - } - - /// Execute copy to pair, forward transform, - /// and returns a reference of the result. - pub fn backward(&mut self, input: &[B]) -> &mut [A] { - self.b.copy_from_slice(input); - self.exec_backward(); - self.a.as_slice_mut() - } - - /// Execute a forward transform (`a` to `b`) - pub fn exec_forward(&mut self) { - unsafe { self.forward.execute() } - self.forward.normalize(self.b.as_slice_mut()); - } - - /// Execute a backward transform (`b` to `a`) - pub fn exec_backward(&mut self) { - unsafe { self.backward.execute() } - self.backward.normalize(self.a.as_slice_mut()); - } - - pub(crate) fn null_checked(self) -> Result { - self.forward.check_null()?; - self.backward.check_null()?; - Ok(self) - } -} - -/// Create a `Pair` from a setting struct e.g. `R2C1D`. -pub trait ToPair -where - A: Scalar + AlignedAllocable, - B: Scalar + AlignedAllocable, -{ - type Dim: Dimension; - /// Generate `Pair` from a setting struct - fn to_pair(&self) -> Result>; -} diff --git a/src/plan.rs b/src/plan.rs index f0341f2e..dd1f2983 100644 --- a/src/plan.rs +++ b/src/plan.rs @@ -1,299 +1,256 @@ -use super::{Flag, R2R_KIND, Sign, c32, c64, FFTW_MUTEX}; -use super::array::AlignedVec; -use super::error::*; -use ffi; +use error::*; +use ffi::*; +use types::*; -use ndarray_linalg::Scalar; -use std::os::raw::c_void; -use std::ptr::null; +use std::marker::PhantomData; -#[derive(Debug)] -pub struct Plan { - p: RawPlan, - factor: Option, -} +pub type Plan64 = fftw_plan; +pub type Plan32 = fftwf_plan; -impl Plan { - pub fn new(p: RawPlan) -> Self { - Self { p, factor: None } - } +pub type C2CPlan64 = Plan; +pub type C2CPlan32 = Plan; +pub type R2CPlan64 = Plan; +pub type R2CPlan32 = Plan; +pub type C2RPlan64 = Plan; +pub type C2RPlan32 = Plan; - pub fn with_factor(p: RawPlan, f: T::Real) -> Self { - Self { p, factor: Some(f) } - } +pub trait PlanSpec: Clone + Copy { + fn validate(self) -> Result; + fn destroy(self); + fn print(self); +} - pub unsafe fn execute(&self) { - self.p.execute() - } +pub struct Plan { + plan: Plan, + alignment: Alignment, + phantom: PhantomData<(A, B)>, +} - fn get_factor(&self) -> Option<&T::Real> { - self.factor.as_ref() +impl Drop for Plan { + fn drop(&mut self) { + self.plan.destroy(); } +} - pub fn check_null(&self) -> Result<()> { - if self.p.is_null() { - Err(InvalidPlanError::new().into()) - } else { - Ok(()) - } - } +pub trait C2CPlan: Sized { + type Complex; - pub fn normalize(&self, array: &mut [T]) { - if let Some(n) = self.get_factor() { - for val in array.iter_mut() { - *val = val.mul_real(*n); - } - } - } -} + /// Create new plan + fn new( + shape: &[usize], + in_: &mut [Self::Complex], + out: &mut [Self::Complex], + sign: Sign, + flag: Flag, + ) -> Result; -#[derive(Debug)] -pub enum RawPlan { - _64(ffi::fftw_plan), - _32(ffi::fftwf_plan), + /// Execute complex-to-complex transform + fn c2c(&mut self, in_: &mut [Self::Complex], out: &mut [Self::Complex]) -> Result<()>; } -impl RawPlan { - /// Execute FFT saved in the plan - /// - /// This is unsafe because rewrite the array saved in the plan. - pub unsafe fn execute(&self) { - if self.is_null() { - panic!("Plan is NULL"); - } - match *self { - RawPlan::_64(p) => ffi::fftw_execute(p), - RawPlan::_32(p) => ffi::fftwf_execute(p), - } - } +pub trait R2CPlan: Sized { + type Real; + type Complex; - /// Check if the plan is NULL - pub fn is_null(&self) -> bool { - let p = match *self { - RawPlan::_64(p) => p as *const c_void, - RawPlan::_32(p) => p as *const c_void, - }; - p == null() - } -} + /// Create new plan + fn new( + shape: &[usize], + in_: &mut [Self::Real], + out: &mut [Self::Complex], + flag: Flag, + ) -> Result; -impl Drop for RawPlan { - fn drop(&mut self) { - if self.is_null() { - // TODO warning - return; - } - let _lock = FFTW_MUTEX.lock().expect("Cannot get lock"); - unsafe { - match *self { - RawPlan::_64(p) => ffi::fftw_destroy_plan(p), - RawPlan::_32(p) => ffi::fftwf_destroy_plan(p), - } - } - } + /// Execute real-to-complex transform + fn r2c(&mut self, in_: &mut [Self::Real], out: &mut [Self::Complex]) -> Result<()>; } -pub trait R2R: Sized { - unsafe fn r2r_1d( - n: usize, - in_: &mut AlignedVec, - out: &mut AlignedVec, - R2R_KIND, - Flag, - ) -> RawPlan; - unsafe fn r2r_2d( - n0: usize, - n1: usize, - in_: &mut AlignedVec, - out: &mut AlignedVec, - R2R_KIND, - R2R_KIND, - Flag, - ) -> RawPlan; - unsafe fn r2r_3d( - n0: usize, - n1: usize, - n2: usize, - in_: &mut AlignedVec, - out: &mut AlignedVec, - R2R_KIND, - R2R_KIND, - R2R_KIND, - Flag, - ) -> RawPlan; -} -pub trait C2C: Sized { - unsafe fn c2c_1d( - n: usize, - in_: &mut AlignedVec, - out: &mut AlignedVec, - Sign, - Flag, - ) -> RawPlan; - unsafe fn c2c_2d( - n0: usize, - n1: usize, - in_: &mut AlignedVec, - out: &mut AlignedVec, - Sign, - Flag, - ) -> RawPlan; - unsafe fn c2c_3d( - n0: usize, - n1: usize, - n2: usize, - in_: &mut AlignedVec, - out: &mut AlignedVec, - Sign, - Flag, - ) -> RawPlan; +pub trait C2RPlan: Sized { + type Real; + type Complex; + + /// Create new plan + fn new( + shape: &[usize], + in_: &mut [Self::Complex], + out: &mut [Self::Real], + flag: Flag, + ) -> Result; + + /// Execute complex-to-real transform + fn c2r(&mut self, in_: &mut [Self::Complex], out: &mut [Self::Real]) -> Result<()>; } -pub trait R2C { - type Real: Sized; - type Complex: Sized; - unsafe fn r2c_1d( - n: usize, - in_: &mut AlignedVec, - out: &mut AlignedVec, - Flag, - ) -> RawPlan; - unsafe fn c2r_1d( - n: usize, - in_: &mut AlignedVec, - out: &mut AlignedVec, - Flag, - ) -> RawPlan; - unsafe fn r2c_2d( - n0: usize, - n1: usize, - in_: &mut AlignedVec, - out: &mut AlignedVec, - Flag, - ) -> RawPlan; - unsafe fn c2r_2d( - n0: usize, - n1: usize, - in_: &mut AlignedVec, - out: &mut AlignedVec, - Flag, - ) -> RawPlan; - unsafe fn r2c_3d( - n0: usize, - n1: usize, - n2: usize, - in_: &mut AlignedVec, - out: &mut AlignedVec, - Flag, - ) -> RawPlan; - unsafe fn c2r_3d( - n0: usize, - n1: usize, - n2: usize, - in_: &mut AlignedVec, - out: &mut AlignedVec, - Flag, - ) -> RawPlan; +macro_rules! impl_c2c { ($C:ty, $Plan:ty; $plan:ident, $exec:ident) => { +impl C2CPlan for Plan<$C, $C, $Plan> { + type Complex = $C; + fn new( + shape: &[usize], + in_: &mut [Self::Complex], + out: &mut [Self::Complex], + sign: Sign, + flag: Flag, + ) -> Result { + let plan = excall!{ $plan( + shape.len() as i32, + shape.to_cint().as_mut_ptr() as *mut _, + in_.as_mut_ptr(), + out.as_mut_ptr(), + sign as i32, flag.into()) + }.validate()?; + Ok(Self { + plan, + alignment: Alignment::new(in_, out), + phantom: PhantomData, + }) + } + fn c2c(&mut self, in_: &mut [Self::Complex], out: &mut [Self::Complex]) -> Result<()> { + self.alignment.check(in_, out)?; + unsafe { $exec(self.plan, in_.as_mut_ptr(), out.as_mut_ptr()) }; + Ok(()) + } } +}} // impl_c2c! -macro_rules! impl_plan_create { - ($bit:ident, $float:ty, $complex:ty, - $r2r_1d:ident, $r2c_1d:ident, $c2r_1d:ident, $c2c_1d:ident, - $r2r_2d:ident, $r2c_2d:ident, $c2r_2d:ident, $c2c_2d:ident, - $r2r_3d:ident, $r2c_3d:ident, $c2r_3d:ident, $c2c_3d:ident) => { +impl_c2c!(c64, Plan64; fftw_plan_dft, fftw_execute_dft); +impl_c2c!(c32, Plan32; fftwf_plan_dft, fftwf_execute_dft); -impl R2R for $float { - unsafe fn r2r_1d(n: usize, in_: &mut AlignedVec, out: &mut AlignedVec, kind: R2R_KIND, flag: Flag) -> RawPlan { - let _lock = FFTW_MUTEX.lock().expect("Cannot get lock"); - RawPlan::$bit(ffi::$r2r_1d(n as i32, in_.as_mut_ptr(), out.as_mut_ptr(), kind, flag.into())) - } - unsafe fn r2r_2d(n0: usize, n1: usize, in_: &mut AlignedVec, out: &mut AlignedVec, k0: R2R_KIND, k1: R2R_KIND, flag: Flag) -> RawPlan { - let _lock = FFTW_MUTEX.lock().expect("Cannot get lock"); - RawPlan::$bit(ffi::$r2r_2d(n0 as i32, n1 as i32, in_.as_mut_ptr(), out.as_mut_ptr(), k0, k1, flag.into())) +macro_rules! impl_r2c { ($R:ty, $C:ty, $Plan:ty; $plan:ident, $exec:ident) => { +impl R2CPlan for Plan<$R, $C, $Plan> { + type Real = $R; + type Complex = $C; + fn new( + shape: &[usize], + in_: &mut [Self::Real], + out: &mut [Self::Complex], + flag: Flag, + ) -> Result { + let plan = excall!{ $plan( + shape.len() as i32, + shape.to_cint().as_mut_ptr() as *mut _, + in_.as_mut_ptr(), + out.as_mut_ptr(), + flag.into()) + }.validate()?; + Ok(Self { + plan, + alignment: Alignment::new(in_, out), + phantom: PhantomData, + }) } - unsafe fn r2r_3d(n0: usize, n1: usize, n2: usize, in_: &mut AlignedVec, out: &mut AlignedVec, k0: R2R_KIND, k1: R2R_KIND, k2: R2R_KIND, flag: Flag) -> RawPlan { - let _lock = FFTW_MUTEX.lock().expect("Cannot get lock"); - RawPlan::$bit(ffi::$r2r_3d(n0 as i32, n1 as i32, n2 as i32, in_.as_mut_ptr(), out.as_mut_ptr(), k0, k1, k2, flag.into())) + fn r2c(&mut self, in_: &mut [Self::Real], out: &mut [Self::Complex]) -> Result<()> { + self.alignment.check(in_, out)?; + unsafe { $exec(self.plan, in_.as_mut_ptr(), out.as_mut_ptr()) }; + Ok(()) } } +}} // impl_r2c! -impl C2C for $complex { - unsafe fn c2c_1d(n: usize, i: &mut AlignedVec, o: &mut AlignedVec, s : Sign, f: Flag) -> RawPlan { - let _lock = FFTW_MUTEX.lock().expect("Cannot get lock"); - RawPlan::$bit(ffi::$c2c_1d(n as i32, i.as_mut_ptr(), o.as_mut_ptr(), s as i32, f.into())) - } - unsafe fn c2c_2d(n0: usize, n1: usize, i: &mut AlignedVec, o: &mut AlignedVec, s : Sign, f: Flag) -> RawPlan { - let _lock = FFTW_MUTEX.lock().expect("Cannot get lock"); - RawPlan::$bit(ffi::$c2c_2d(n0 as i32, n1 as i32, i.as_mut_ptr(), o.as_mut_ptr(), s as i32, f.into())) +impl_r2c!(f64, c64, Plan64; fftw_plan_dft_r2c, fftw_execute_dft_r2c); +impl_r2c!(f32, c32, Plan32; fftwf_plan_dft_r2c, fftwf_execute_dft_r2c); + +macro_rules! impl_c2r { ($R:ty, $C:ty, $Plan:ty; $plan:ident, $exec:ident) => { +impl C2RPlan for Plan<$C, $R, $Plan> { + type Real = $R; + type Complex = $C; + fn new( + shape: &[usize], + in_: &mut [Self::Complex], + out: &mut [Self::Real], + flag: Flag, + ) -> Result { + let plan = excall!{ $plan( + shape.len() as i32, + shape.to_cint().as_mut_ptr() as *mut _, + in_.as_mut_ptr(), + out.as_mut_ptr(), + flag.into()) + }.validate()?; + Ok(Self { + plan, + alignment: Alignment::new(in_, out), + phantom: PhantomData, + }) } - unsafe fn c2c_3d(n0: usize, n1: usize, n2: usize, i: &mut AlignedVec, o: &mut AlignedVec, s : Sign, f: Flag) -> RawPlan { - let _lock = FFTW_MUTEX.lock().expect("Cannot get lock"); - RawPlan::$bit(ffi::$c2c_3d(n0 as i32, n1 as i32, n2 as i32, i.as_mut_ptr(), o.as_mut_ptr(), s as i32, f.into())) + fn c2r(&mut self, in_: &mut [Self::Complex], out: &mut [Self::Real]) -> Result<()> { + self.alignment.check(in_, out)?; + unsafe { $exec(self.plan, in_.as_mut_ptr(), out.as_mut_ptr()) }; + Ok(()) } } +}} // impl_c2r! -impl R2C for ($float, $complex) { - type Real = $float; - type Complex = $complex; - unsafe fn r2c_1d(n: usize, i: &mut AlignedVec, o: &mut AlignedVec, f: Flag) -> RawPlan { - let _lock = FFTW_MUTEX.lock().expect("Cannot get lock"); - RawPlan::$bit(ffi::$r2c_1d(n as i32, i.as_mut_ptr(), o.as_mut_ptr(), f.into())) - } - unsafe fn c2r_1d(n: usize, i: &mut AlignedVec, o: &mut AlignedVec, f: Flag) -> RawPlan { - let _lock = FFTW_MUTEX.lock().expect("Cannot get lock"); - RawPlan::$bit(ffi::$c2r_1d(n as i32, i.as_mut_ptr(), o.as_mut_ptr(), f.into())) +impl_c2r!(f64, c64, Plan64; fftw_plan_dft_c2r, fftw_execute_dft_c2r); +impl_c2r!(f32, c32, Plan32; fftwf_plan_dft_c2r, fftwf_execute_dft_c2r); + +macro_rules! impl_plan_spec { + ($Plan:ty; $destroy_plan:ident, $print_plan:ident) => { +impl PlanSpec for $Plan { + fn validate(self) -> Result { + if self.is_null() { + Err(InvalidPlanError::new().into()) + } else { + Ok(self) + } } - unsafe fn r2c_2d(n0: usize, n1: usize, i: &mut AlignedVec, o: &mut AlignedVec, f: Flag) -> RawPlan { - let _lock = FFTW_MUTEX.lock().expect("Cannot get lock"); - RawPlan::$bit(ffi::$r2c_2d(n0 as i32, n1 as i32, i.as_mut_ptr(), o.as_mut_ptr(), f.into())) + fn destroy(self) { + excall!{ $destroy_plan(self) } } - unsafe fn c2r_2d(n0: usize, n1: usize, i: &mut AlignedVec, o: &mut AlignedVec, f: Flag) -> RawPlan { - let _lock = FFTW_MUTEX.lock().expect("Cannot get lock"); - RawPlan::$bit(ffi::$c2r_2d(n0 as i32, n1 as i32, i.as_mut_ptr(), o.as_mut_ptr(), f.into())) + fn print(self) { + excall!{ $print_plan(self) } } - unsafe fn r2c_3d(n0: usize, n1: usize, n2: usize, i: &mut AlignedVec, o: &mut AlignedVec, f: Flag) -> RawPlan { - let _lock = FFTW_MUTEX.lock().expect("Cannot get lock"); - RawPlan::$bit(ffi::$r2c_3d(n0 as i32, n1 as i32, n2 as i32, i.as_mut_ptr(), o.as_mut_ptr(), f.into())) +} +}} // impl_plan_spec! + +impl_plan_spec!(Plan64; fftw_destroy_plan, fftw_print_plan); +impl_plan_spec!(Plan32; fftwf_destroy_plan, fftwf_print_plan); + +#[derive(Debug, Clone, Copy, Eq, PartialEq)] +struct Alignment { + in_: i32, + out: i32, + n_in_: usize, + n_out: usize, +} + +fn alignment_of(a: &[T]) -> i32 { + unsafe { fftw_alignment_of(a.as_ptr() as *mut _) } +} + +impl Alignment { + fn new(in_: &[A], out: &[B]) -> Self { + Self { + in_: alignment_of(in_), + out: alignment_of(out), + n_in_: in_.len(), + n_out: out.len(), + } } - unsafe fn c2r_3d(n0: usize, n1: usize, n2: usize, i: &mut AlignedVec, o: &mut AlignedVec, f: Flag) -> RawPlan { - let _lock = FFTW_MUTEX.lock().expect("Cannot get lock"); - RawPlan::$bit(ffi::$c2r_3d(n0 as i32, n1 as i32, n2 as i32, i.as_mut_ptr(), o.as_mut_ptr(), f.into())) + + fn check(&self, in_: &[A], out: &[B]) -> Result<()> { + let args = Self::new(in_, out); + if *self != args { + Err(InputMismatchError { + origin: *self, + args, + }.into()) + } else { + Ok(()) + } } } -}} // impl_plan_create +#[derive(Debug)] +pub struct InputMismatchError { + origin: Alignment, + args: Alignment, +} -impl_plan_create!( - _64, - f64, - c64, - fftw_plan_r2r_1d, - fftw_plan_dft_r2c_1d, - fftw_plan_dft_c2r_1d, - fftw_plan_dft_1d, - fftw_plan_r2r_2d, - fftw_plan_dft_r2c_2d, - fftw_plan_dft_c2r_2d, - fftw_plan_dft_2d, - fftw_plan_r2r_3d, - fftw_plan_dft_r2c_3d, - fftw_plan_dft_c2r_3d, - fftw_plan_dft_3d -); -impl_plan_create!( - _32, - f32, - c32, - fftwf_plan_r2r_1d, - fftwf_plan_dft_r2c_1d, - fftwf_plan_dft_c2r_1d, - fftwf_plan_dft_1d, - fftwf_plan_r2r_2d, - fftwf_plan_dft_r2c_2d, - fftwf_plan_dft_c2r_2d, - fftwf_plan_dft_2d, - fftwf_plan_r2r_3d, - fftwf_plan_dft_r2c_3d, - fftwf_plan_dft_c2r_3d, - fftwf_plan_dft_3d -); +trait ToCInt { + fn to_cint(&self) -> Vec; +} + +impl ToCInt for [usize] { + fn to_cint(&self) -> Vec { + self.iter().map(|&x| x as i32).collect() + } +} diff --git a/src/r2c.rs b/src/r2c.rs deleted file mode 100644 index 4de3ea31..00000000 --- a/src/r2c.rs +++ /dev/null @@ -1,45 +0,0 @@ -use super::Flag; -use super::array::*; -use super::error::*; -use super::pair::{Pair, ToPair}; -use super::plan::*; -use super::traits::*; - -use ndarray::*; -use ndarray_linalg::Scalar; - -/// Setting for 1-dimensional R2C transform -#[derive(Debug, Clone, Copy, new)] -pub struct R2C1D { - n: usize, - flag: Flag, -} - -/// Utility function to generage 1-dimensional R2C setting -pub fn r2c_1d(n: usize) -> R2C1D { - R2C1D { - n, - flag: Flag::Measure, - } -} - -impl ToPair for R2C1D -where - (R, C): R2C, - R: FFTWReal, - C: FFTWComplex, -{ - type Dim = Ix1; - fn to_pair(&self) -> Result> { - let mut a = AlignedVec::::new(self.n); - let mut b = AlignedVec::::new(self.n / 2 + 1); - let forward = unsafe { <(R, C) as R2C>::r2c_1d(self.n, &mut a, &mut b, self.flag) }; - let backward = unsafe { <(R, C) as R2C>::c2r_1d(self.n, &mut b, &mut a, self.flag) }; - Pair { - a: AlignedArray::from_vec(a), - b: AlignedArray::from_vec(b), - forward: Plan::with_factor(forward, Scalar::from_f64(1.0 / self.n as f64)), - backward: Plan::new(backward), - }.null_checked() - } -} diff --git a/src/r2r.rs b/src/r2r.rs deleted file mode 100644 index d07add4e..00000000 --- a/src/r2r.rs +++ /dev/null @@ -1,74 +0,0 @@ -use super::Flag; -use super::array::*; -use super::error::*; -use super::pair::*; -use super::plan::*; -use super::traits::*; - -pub use ffi::fftw_r2r_kind as R2R_KIND; - -use ndarray::*; -use ndarray_linalg::Scalar; - -fn forward(kind: R2R_KIND) -> R2R_KIND { - match kind { - R2R_KIND::FFTW_R2HC => R2R_KIND::FFTW_R2HC, - R2R_KIND::FFTW_HC2R => R2R_KIND::FFTW_R2HC, - R2R_KIND::FFTW_DHT => R2R_KIND::FFTW_DHT, - R2R_KIND::FFTW_REDFT00 => R2R_KIND::FFTW_REDFT00, - R2R_KIND::FFTW_REDFT01 => R2R_KIND::FFTW_REDFT10, - R2R_KIND::FFTW_REDFT10 => R2R_KIND::FFTW_REDFT10, - R2R_KIND::FFTW_REDFT11 => R2R_KIND::FFTW_REDFT11, - R2R_KIND::FFTW_RODFT00 => R2R_KIND::FFTW_RODFT00, - R2R_KIND::FFTW_RODFT01 => R2R_KIND::FFTW_RODFT10, - R2R_KIND::FFTW_RODFT10 => R2R_KIND::FFTW_RODFT10, - R2R_KIND::FFTW_RODFT11 => R2R_KIND::FFTW_RODFT11, - } -} - -fn backward(kind: R2R_KIND) -> R2R_KIND { - match kind { - R2R_KIND::FFTW_R2HC => R2R_KIND::FFTW_HC2R, - R2R_KIND::FFTW_HC2R => R2R_KIND::FFTW_HC2R, - R2R_KIND::FFTW_DHT => R2R_KIND::FFTW_DHT, - R2R_KIND::FFTW_REDFT00 => R2R_KIND::FFTW_REDFT00, - R2R_KIND::FFTW_REDFT01 => R2R_KIND::FFTW_REDFT01, - R2R_KIND::FFTW_REDFT10 => R2R_KIND::FFTW_REDFT01, - R2R_KIND::FFTW_REDFT11 => R2R_KIND::FFTW_REDFT11, - R2R_KIND::FFTW_RODFT00 => R2R_KIND::FFTW_RODFT00, - R2R_KIND::FFTW_RODFT01 => R2R_KIND::FFTW_RODFT01, - R2R_KIND::FFTW_RODFT10 => R2R_KIND::FFTW_RODFT01, - R2R_KIND::FFTW_RODFT11 => R2R_KIND::FFTW_RODFT11, - } -} - -#[derive(Debug, Clone, Copy, new)] -pub struct R2R1D { - n: usize, - kind: R2R_KIND, - flag: Flag, -} - -pub fn r2hc_1d(n: usize) -> R2R1D { - R2R1D { - n: n, - kind: R2R_KIND::FFTW_R2HC, - flag: Flag::Measure, - } -} - -impl ToPair for R2R1D { - type Dim = Ix1; - fn to_pair(&self) -> Result> { - let mut a = AlignedVec::new(self.n); - let mut b = AlignedVec::new(self.n); - let forward = unsafe { T::r2r_1d(self.n, &mut a, &mut b, forward(self.kind), self.flag) }; - let backward = unsafe { T::r2r_1d(self.n, &mut b, &mut a, backward(self.kind), self.flag) }; - Pair { - a: AlignedArray::from_vec(a), - b: AlignedArray::from_vec(b), - forward: Plan::with_factor(forward, Scalar::from_f64(1.0 / self.n as f64)), - backward: Plan::new(backward), - }.null_checked() - } -} diff --git a/src/traits.rs b/src/traits.rs deleted file mode 100644 index 70948d37..00000000 --- a/src/traits.rs +++ /dev/null @@ -1,14 +0,0 @@ -use super::*; -use array::AlignedAllocable; -use plan::*; - -use ndarray_linalg::Scalar; -use num_traits::*; - -pub trait FFTWReal: Scalar + R2R + AlignedAllocable + Zero {} -pub trait FFTWComplex: Scalar + C2C + AlignedAllocable + Zero {} - -impl FFTWReal for f32 {} -impl FFTWReal for f64 {} -impl FFTWComplex for c32 {} -impl FFTWComplex for c64 {} diff --git a/src/types.rs b/src/types.rs new file mode 100644 index 00000000..5f22609e --- /dev/null +++ b/src/types.rs @@ -0,0 +1,64 @@ +use ffi; + +pub use ffi::fftw_complex as c64; +pub use ffi::fftwf_complex as c32; + +pub type R2RKind = ffi::fftw_r2r_kind; + +#[repr(i32)] +#[derive(Debug, Copy, Clone, PartialEq, Eq, Hash)] +pub enum Sign { + Forward = -1, + Backward = 1, +} + +impl ::std::ops::Neg for Sign { + type Output = Sign; + fn neg(self) -> Self::Output { + match self { + Sign::Forward => Sign::Backward, + Sign::Backward => Sign::Forward, + } + } +} + +#[derive(Debug, Clone, Copy, Eq, PartialEq)] +pub enum Flag { + Measure, + DestroyInput, + Unaligned, + ConserveMemory, + Exhausive, + PreserveInput, + Patient, + Estimate, + WisdowmOnly, + Mixed(u32), +} + +impl Into for Flag { + fn into(self) -> u32 { + use self::Flag::*; + match self { + Measure => 0, + DestroyInput => 1 << 0, + Unaligned => 1 << 1, + ConserveMemory => 1 << 2, + Exhausive => 1 << 3, + PreserveInput => 1 << 4, + Patient => 1 << 5, + Estimate => 1 << 6, + WisdowmOnly => 1 << 21, + Mixed(u) => u, + } + } +} + +impl ::std::ops::BitOr for Flag { + type Output = Self; + fn bitor(self, rhs: Self) -> Self { + let lhs: u32 = self.into(); + let rhs: u32 = rhs.into(); + Flag::Mixed(lhs | rhs) + } +} diff --git a/tests/c2c.rs b/tests/c2c.rs index 50751d47..f218184c 100644 --- a/tests/c2c.rs +++ b/tests/c2c.rs @@ -1,72 +1,57 @@ extern crate fftw; -extern crate ndarray; -#[macro_use] -extern crate ndarray_linalg; extern crate num_traits; -use fftw::*; -use ndarray::*; -use ndarray_linalg::*; +use fftw::plan::*; +use fftw::types::*; +use num_traits::Zero; /// Check successive forward and backward transform equals to the identity -fn test_identity(mut pair: Pair, rtol: C::Real) { - let a: Array1 = random(pair.a.dim()); - println!("a = {:?}", &a); - let b = pair.forward_array(a.view()).to_owned(); - println!("b = {:?}", &b); - let a2 = pair.backward_array(b.view()); - println!("a2 = {:?}", &a2); - assert_close_l2!(&a2, &a, rtol); -} - -/// Check successive forward and backward transform equals to the identity -fn test_forward(mut pair: Pair, rtol: C::Real) { - let n = pair.a.dim(); - let pi = ::std::f64::consts::PI; - let a: Array1 = - Array::from_iter((0..n).map(|i| Scalar::from_f64((2.0 * pi * i as f64 / n as f64).cos()))); - println!("a = {:?}", &a); - let b = pair.forward_array(a.view()).to_owned(); - println!("b = {:?}", &b); - // cos(x) = (exp(ix) + exp(-ix))/2 - let mut ans: Array1 = Array::zeros(b.len()); - ans[1] = Scalar::from_f64(0.5); - ans[n - 1] = Scalar::from_f64(0.5); - assert_close_l2!(&b, &ans, rtol); -} - -mod c2c_64 { - use super::*; - const N: usize = 32; - const RTOL: f64 = 1e-7; - - #[test] - fn identity() { - let pair: Pair = c2c_1d(N).to_pair().unwrap(); - test_identity(pair, RTOL); +#[test] +fn c2c2c_identity() { + let n = 32; + let mut a = vec![c64::zero(); n]; + let mut b = vec![c64::zero(); n]; + let mut plan: C2CPlan64 = + C2CPlan::new(&[n], &mut a, &mut b, Sign::Forward, Flag::Measure).unwrap(); + for i in 0..n { + a[i] = c64::new(1.0, 0.0); } - - #[test] - fn forward() { - let pair: Pair = c2c_1d(N).to_pair().unwrap(); - test_forward(pair, RTOL); + plan.c2c(&mut a, &mut b).unwrap(); + plan.c2c(&mut b, &mut a).unwrap(); + for v in a.iter() { + let ans = c64::new(n as f64, 0.0); + let dif = (v - ans).norm(); + if dif > 1e-7 { + panic!("Large difference: v={}, dif={}", v, dif); + } } } -mod c2c_32 { - use super::*; - const N: usize = 32; - const RTOL: f32 = 1e-4; - - #[test] - fn identity() { - let pair: Pair = c2c_1d(N).to_pair().unwrap(); - test_identity(pair, RTOL); +/// Check cos transform +#[test] +fn c2c_cos() { + let n = 32; + let mut a = vec![c64::zero(); n]; + let mut b = vec![c64::zero(); n]; + let mut plan: C2CPlan64 = + C2CPlan::new(&[n], &mut a, &mut b, Sign::Forward, Flag::Measure).unwrap(); + let pi = ::std::f64::consts::PI; + for i in 0..n { + a[i] = c64::new((2.0 * pi * i as f64 / n as f64).cos(), 0.0); } - - #[test] - fn forward() { - let pair: Pair = c2c_1d(N).to_pair().unwrap(); - test_forward(pair, RTOL); + plan.c2c(&mut a, &mut b).unwrap(); + for (i, v) in b.iter().enumerate() { + let ans = if i == 1 || i == n - 1 { + 0.5 * n as f64 + } else { + 0.0 + }; + let dif = (v - ans).norm(); + if dif > 1e-7 { + panic!( + "Large difference: v={}, ans={}, dif={}, i={}", + v, ans, dif, i + ); + } } } diff --git a/tests/nae.rs b/tests/nae.rs deleted file mode 100644 index b9b5f470..00000000 --- a/tests/nae.rs +++ /dev/null @@ -1,76 +0,0 @@ -extern crate fftw; -extern crate num_traits; - -use fftw::*; -use num_traits::Zero; - -/// Check successive forward and backward transform equals to the identity -#[test] -fn nae_c2c2c_identity() { - let n = 32; - let mut a = vec![c64::zero(); n]; - let mut b = vec![c64::zero(); n]; - let mut plan = nae::C2CPlan::new(&[n], &mut a, &mut b, Sign::Forward, Flag::Measure).unwrap(); - for i in 0..n { - a[i] = c64::new(1.0, 0.0); - } - plan.c2c(&mut a, &mut b).unwrap(); - plan.c2c(&mut b, &mut a).unwrap(); - for v in a.iter() { - let ans = c64::new(n as f64, 0.0); - let dif = (v - ans).norm(); - if dif > 1e-7 { - panic!("Large difference: v={}, dif={}", v, dif); - } - } -} - -/// Check cos transform -#[test] -fn nae_c2c_cos() { - let n = 32; - let mut a = vec![c64::zero(); n]; - let mut b = vec![c64::zero(); n]; - let mut plan = nae::C2CPlan::new(&[n], &mut a, &mut b, Sign::Forward, Flag::Measure).unwrap(); - let pi = ::std::f64::consts::PI; - for i in 0..n { - a[i] = c64::new((2.0 * pi * i as f64 / n as f64).cos(), 0.0); - } - plan.c2c(&mut a, &mut b).unwrap(); - for (i, v) in b.iter().enumerate() { - let ans = if i == 1 || i == n - 1 { - 0.5 * n as f64 - } else { - 0.0 - }; - let dif = (v - ans).norm(); - if dif > 1e-7 { - panic!( - "Large difference: v={}, ans={}, dif={}, i={}", - v, ans, dif, i - ); - } - } -} - -/// Check successive forward and backward transform equals to the identity -#[test] -fn nae_c2r2c_identity() { - let n = 32; - let mut a = vec![c64::zero(); n / 2 + 1]; - let mut b = vec![0.0; n]; - let mut c2r = nae::C2RPlan::new(&[n], &mut a, &mut b, Flag::Measure).unwrap(); - let mut r2c = nae::R2CPlan::new(&[n], &mut b, &mut a, Flag::Measure).unwrap(); - for i in 0..(n / 2 + 1) { - a[i] = c64::new(1.0, 0.0); - } - c2r.c2r(&mut a, &mut b).unwrap(); - r2c.r2c(&mut b, &mut a).unwrap(); - for v in a.iter() { - let ans = c64::new(n as f64, 0.0); - let dif = (v - ans).norm(); - if dif > 1e-7 { - panic!("Large difference: v={}, dif={}", v, dif); - } - } -} diff --git a/tests/r2c.rs b/tests/r2c.rs index 3c48551e..6bf5a234 100644 --- a/tests/r2c.rs +++ b/tests/r2c.rs @@ -1,78 +1,28 @@ extern crate fftw; -extern crate ndarray; -#[macro_use] -extern crate ndarray_linalg; extern crate num_traits; -use fftw::*; -use ndarray::*; -use ndarray_linalg::*; - -/// Check successive forward and backward transformation conserves. -fn test_identity(mut pair: Pair, rtol: R::Real) -where - R: FFTWReal, - C: FFTWComplex, -{ - let a: Array1 = random(pair.a.dim()); - println!("a = {:?}", &a); - let b = pair.forward_array(a.view()).to_owned(); - println!("b = {:?}", &b); - let a2 = pair.backward_array(b.view()); - println!("a2 = {:?}", &a2); - assert_close_l2!(&a2, &a, rtol); -} - -/// Check `cos(k_0 x)` is transformed `b[1] = 1.0 + 0.0i` -fn test_forward(mut pair: Pair, rtol: C::Real) -where - R: FFTWReal, - C: FFTWComplex, -{ - let n = pair.a.dim(); - let pi = ::std::f64::consts::PI; - let a: Array1 = - Array::from_iter((0..n).map(|i| Scalar::from_f64((2.0 * pi * i as f64 / n as f64).cos()))); - println!("a = {:?}", &a); - let b = pair.forward_array(a.view()).to_owned(); - println!("b = {:?}", &b); - let mut ans: Array1 = Array::zeros(b.len()); - ans[1] = Scalar::from_f64(0.5); // cos(x) = 0.5*exp(ix) + c.c. - assert_close_l2!(&b, &ans, rtol); -} - -mod r2c_64 { - use super::*; - const N: usize = 32; - const RTOL: f64 = 1e-7; - - #[test] - fn identity() { - let pair: Pair = r2c_1d(N).to_pair().unwrap(); - test_identity(pair, RTOL); +use fftw::plan::*; +use fftw::types::*; +use num_traits::Zero; + +/// Check successive forward and backward transform equals to the identity +#[test] +fn c2r2c_identity() { + let n = 32; + let mut a = vec![c64::zero(); n / 2 + 1]; + let mut b = vec![0.0; n]; + let mut c2r: C2RPlan64 = C2RPlan::new(&[n], &mut a, &mut b, Flag::Measure).unwrap(); + let mut r2c: R2CPlan64 = R2CPlan::new(&[n], &mut b, &mut a, Flag::Measure).unwrap(); + for i in 0..(n / 2 + 1) { + a[i] = c64::new(1.0, 0.0); } - - #[test] - fn forward() { - let pair: Pair = r2c_1d(N).to_pair().unwrap(); - test_forward(pair, RTOL); - } -} - -mod r2c_32 { - use super::*; - const N: usize = 32; - const RTOL: f32 = 1e-4; - - #[test] - fn identity() { - let pair: Pair = r2c_1d(N).to_pair().unwrap(); - test_identity(pair, RTOL); - } - - #[test] - fn forward() { - let pair: Pair = r2c_1d(N).to_pair().unwrap(); - test_forward(pair, RTOL); + c2r.c2r(&mut a, &mut b).unwrap(); + r2c.r2c(&mut b, &mut a).unwrap(); + for v in a.iter() { + let ans = c64::new(n as f64, 0.0); + let dif = (v - ans).norm(); + if dif > 1e-7 { + panic!("Large difference: v={}, dif={}", v, dif); + } } } diff --git a/tests/r2r.rs b/tests/r2r.rs deleted file mode 100644 index 906abc33..00000000 --- a/tests/r2r.rs +++ /dev/null @@ -1,80 +0,0 @@ -extern crate fftw; -extern crate ndarray; -#[macro_use] -extern crate ndarray_linalg; -extern crate num_traits; - -use fftw::*; -use ndarray::*; -use ndarray_linalg::*; - -/// Check successive forward and backward transformation conserves. -fn test_identity(mut pair: Pair, rtol: R::Real) -where - R: FFTWReal, -{ - let a: Array1 = random(pair.a.dim()); - println!("a = {:?}", &a); - let b = pair.forward_array(a.view()).to_owned(); - println!("b = {:?}", &b); - let a2 = pair.backward_array(b.view()); - println!("a2 = {:?}", &a2); - assert_close_l2!(&a2, &a, rtol); -} - -/// Check `cos(k_0 x)` is transformed `b[1] = 1.0 + 0.0i` -fn test_forward(mut pair: Pair, rtol: R::Real) -where - R: FFTWReal, -{ - let n = pair.a.dim(); - let pi = ::std::f64::consts::PI; - let a: Array1 = - Array::from_iter((0..n).map(|i| Scalar::from_f64((2.0 * pi * i as f64 / n as f64).cos()))); - println!("a = {:?}", &a); - let b = pair.forward_array(a.view()).to_owned(); - println!("b = {:?}", &b); - let mut ans: Array1 = Array::zeros(b.len()); - ans[1] = Scalar::from_f64(0.5); // cos(x) = 0.5*exp(ix) + c.c. - assert_close_l2!(&b, &ans, rtol); -} - -mod r2r_64 { - use super::*; - const N: usize = 32; - const RTOL: f64 = 1e-7; - - #[cfg_attr(feature = "intel-mkl", should_panic)] - #[test] - fn r2hc_identity() { - let pair: Pair = r2hc_1d(N).to_pair().unwrap(); - test_identity(pair, RTOL); - } - - #[cfg_attr(feature = "intel-mkl", should_panic)] - #[test] - fn r2hc_forward() { - let pair: Pair = r2hc_1d(N).to_pair().unwrap(); - test_forward(pair, RTOL); - } -} - -mod r2r_32 { - use super::*; - const N: usize = 32; - const RTOL: f32 = 1e-4; - - #[cfg_attr(feature = "intel-mkl", should_panic)] - #[test] - fn r2hc_identity() { - let pair: Pair = r2hc_1d(N).to_pair().unwrap(); - test_identity(pair, RTOL); - } - - #[cfg_attr(feature = "intel-mkl", should_panic)] - #[test] - fn r2hc_forward() { - let pair: Pair = r2hc_1d(N).to_pair().unwrap(); - test_forward(pair, RTOL); - } -}