Skip to content

Commit 97def32

Browse files
jturner314LukeMathWalker
authored andcommitted
Implement approx traits for ArrayBase (#581)
* Implement approx traits for ArrayBase * Feature-gate approx trait implementations * Use Zip::all where possible * Mark `all_close` as deprecated * Fix implementation. * Fix issues with conditional execution based on activated features * Remove all_close from all doc tests * Update guide for NumPy users * Replace all_close with abs_diff_eq in tests (currently failing) * Fix typo, pin 0.3.2 to get latest changes * Allow comparison between arrays with different ownership properties * Fix assertions * Fix test * Move tests from all_close to approx * Move tests from all_close to approx * Impl approx traits for differing element types * Fix unused import warning * Remove duplicate type parameter * Fix link in docs * Fix tests * Fix formatting * Remove unnecessary &
1 parent 7733b7c commit 97def32

15 files changed

+274
-68
lines changed

Cargo.toml

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -35,6 +35,8 @@ itertools = { version = "0.8.0", default-features = false }
3535

3636
rayon = { version = "1.0.3", optional = true }
3737

38+
approx = { version = "0.3.2", optional = true }
39+
3840
# Use via the `blas` crate feature!
3941
cblas-sys = { version = "0.1.4", optional = true, default-features = false }
4042
blas-src = { version = "0.2.0", optional = true, default-features = false }
@@ -47,8 +49,8 @@ serde = { version = "1.0", optional = true }
4749
defmac = "0.2"
4850
quickcheck = { version = "0.8", default-features = false }
4951
rawpointer = "0.1"
52+
approx = "0.3.2"
5053
itertools = { version = "0.8.0", default-features = false, features = ["use_std"] }
51-
approx = "0.3"
5254

5355
[features]
5456
# Enable blas usage
@@ -63,7 +65,7 @@ test-blas-openblas-sys = ["blas"]
6365
test = ["test-blas-openblas-sys"]
6466

6567
# This feature is used for docs
66-
docs = ["serde-1", "rayon"]
68+
docs = ["approx", "serde-1", "rayon"]
6769

6870
[profile.release]
6971
[profile.bench]

src/array_approx.rs

Lines changed: 139 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,139 @@
1+
use crate::imp_prelude::*;
2+
use crate::Zip;
3+
use approx::{AbsDiffEq, RelativeEq, UlpsEq};
4+
5+
/// **Requires crate feature `"approx"`**
6+
impl<A, B, S, S2, D> AbsDiffEq<ArrayBase<S2, D>> for ArrayBase<S, D>
7+
where
8+
A: AbsDiffEq<B>,
9+
A::Epsilon: Clone,
10+
S: Data<Elem = A>,
11+
S2: Data<Elem = B>,
12+
D: Dimension,
13+
{
14+
type Epsilon = A::Epsilon;
15+
16+
fn default_epsilon() -> A::Epsilon {
17+
A::default_epsilon()
18+
}
19+
20+
fn abs_diff_eq(&self, other: &ArrayBase<S2, D>, epsilon: A::Epsilon) -> bool {
21+
if self.shape() != other.shape() {
22+
return false;
23+
}
24+
Zip::from(self)
25+
.and(other)
26+
.all(|a, b| A::abs_diff_eq(a, b, epsilon.clone()))
27+
}
28+
}
29+
30+
/// **Requires crate feature `"approx"`**
31+
impl<A, B, S, S2, D> RelativeEq<ArrayBase<S2, D>> for ArrayBase<S, D>
32+
where
33+
A: RelativeEq<B>,
34+
A::Epsilon: Clone,
35+
S: Data<Elem = A>,
36+
S2: Data<Elem = B>,
37+
D: Dimension,
38+
{
39+
fn default_max_relative() -> A::Epsilon {
40+
A::default_max_relative()
41+
}
42+
43+
fn relative_eq(
44+
&self,
45+
other: &ArrayBase<S2, D>,
46+
epsilon: A::Epsilon,
47+
max_relative: A::Epsilon,
48+
) -> bool {
49+
if self.shape() != other.shape() {
50+
return false;
51+
}
52+
Zip::from(self)
53+
.and(other)
54+
.all(|a, b| A::relative_eq(a, b, epsilon.clone(), max_relative.clone()))
55+
}
56+
}
57+
58+
/// **Requires crate feature `"approx"`**
59+
impl<A, B, S, S2, D> UlpsEq<ArrayBase<S2, D>> for ArrayBase<S, D>
60+
where
61+
A: UlpsEq<B>,
62+
A::Epsilon: Clone,
63+
S: Data<Elem = A>,
64+
S2: Data<Elem = B>,
65+
D: Dimension,
66+
{
67+
fn default_max_ulps() -> u32 {
68+
A::default_max_ulps()
69+
}
70+
71+
fn ulps_eq(&self, other: &ArrayBase<S2, D>, epsilon: A::Epsilon, max_ulps: u32) -> bool {
72+
if self.shape() != other.shape() {
73+
return false;
74+
}
75+
Zip::from(self)
76+
.and(other)
77+
.all(|a, b| A::ulps_eq(a, b, epsilon.clone(), max_ulps))
78+
}
79+
}
80+
81+
#[cfg(test)]
82+
mod tests {
83+
use crate::prelude::*;
84+
use approx::{
85+
assert_abs_diff_eq, assert_abs_diff_ne, assert_relative_eq, assert_relative_ne,
86+
assert_ulps_eq, assert_ulps_ne,
87+
};
88+
89+
#[test]
90+
fn abs_diff_eq() {
91+
let a: Array2<f32> = array![[0., 2.], [-0.000010001, 100000000.]];
92+
let mut b: Array2<f32> = array![[0., 1.], [-0.000010002, 100000001.]];
93+
assert_abs_diff_ne!(a, b);
94+
b[(0, 1)] = 2.;
95+
assert_abs_diff_eq!(a, b);
96+
97+
// Check epsilon.
98+
assert_abs_diff_eq!(array![0.0f32], array![1e-40f32], epsilon = 1e-40f32);
99+
assert_abs_diff_ne!(array![0.0f32], array![1e-40f32], epsilon = 1e-41f32);
100+
101+
// Make sure we can compare different shapes without failure.
102+
let c = array![[1., 2.]];
103+
assert_abs_diff_ne!(a, c);
104+
}
105+
106+
#[test]
107+
fn relative_eq() {
108+
let a: Array2<f32> = array![[1., 2.], [-0.000010001, 100000000.]];
109+
let mut b: Array2<f32> = array![[1., 1.], [-0.000010002, 100000001.]];
110+
assert_relative_ne!(a, b);
111+
b[(0, 1)] = 2.;
112+
assert_relative_eq!(a, b);
113+
114+
// Check epsilon.
115+
assert_relative_eq!(array![0.0f32], array![1e-40f32], epsilon = 1e-40f32);
116+
assert_relative_ne!(array![0.0f32], array![1e-40f32], epsilon = 1e-41f32);
117+
118+
// Make sure we can compare different shapes without failure.
119+
let c = array![[1., 2.]];
120+
assert_relative_ne!(a, c);
121+
}
122+
123+
#[test]
124+
fn ulps_eq() {
125+
let a: Array2<f32> = array![[1., 2.], [-0.000010001, 100000000.]];
126+
let mut b: Array2<f32> = array![[1., 1.], [-0.000010002, 100000001.]];
127+
assert_ulps_ne!(a, b);
128+
b[(0, 1)] = 2.;
129+
assert_ulps_eq!(a, b);
130+
131+
// Check epsilon.
132+
assert_ulps_eq!(array![0.0f32], array![1e-40f32], epsilon = 1e-40f32);
133+
assert_ulps_ne!(array![0.0f32], array![1e-40f32], epsilon = 1e-41f32);
134+
135+
// Make sure we can compare different shapes without failure.
136+
let c = array![[1., 2.]];
137+
assert_ulps_ne!(a, c);
138+
}
139+
}

src/doc/ndarray_for_numpy_users/mod.rs

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -473,7 +473,7 @@
473473
//!
474474
//! </td><td>
475475
//!
476-
//! [`a.all_close(&b, 1e-8)`][.all_close()]
476+
//! [`a.abs_diff_eq(&b, 1e-8)`][.abs_diff_eq()]
477477
//!
478478
//! </td><td>
479479
//!
@@ -557,7 +557,7 @@
557557
//! `a[:,4]` | [`a.column(4)`][.column()] or [`a.column_mut(4)`][.column_mut()] | view (or mutable view) of column 4 in a 2-D array
558558
//! `a.shape[0] == a.shape[1]` | [`a.is_square()`][.is_square()] | check if the array is square
559559
//!
560-
//! [.all_close()]: ../../struct.ArrayBase.html#method.all_close
560+
//! [.abs_diff_eq()]: ../../struct.ArrayBase.html#impl-AbsDiffEq<ArrayBase<S2%2C%20D>>
561561
//! [ArcArray]: ../../type.ArcArray.html
562562
//! [arr2()]: ../../fn.arr2.html
563563
//! [array!]: ../../macro.array.html

src/geomspace.rs

Lines changed: 8 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -106,21 +106,24 @@ where
106106
#[cfg(test)]
107107
mod tests {
108108
use super::geomspace;
109-
use crate::{arr1, Array1};
110109

111110
#[test]
111+
#[cfg(feature = "approx")]
112112
fn valid() {
113+
use approx::assert_abs_diff_eq;
114+
use crate::{arr1, Array1};
115+
113116
let array: Array1<_> = geomspace(1e0, 1e3, 4).collect();
114-
assert!(array.all_close(&arr1(&[1e0, 1e1, 1e2, 1e3]), 1e-5));
117+
assert_abs_diff_eq!(array, arr1(&[1e0, 1e1, 1e2, 1e3]), epsilon = 1e-12);
115118

116119
let array: Array1<_> = geomspace(1e3, 1e0, 4).collect();
117-
assert!(array.all_close(&arr1(&[1e3, 1e2, 1e1, 1e0]), 1e-5));
120+
assert_abs_diff_eq!(array, arr1(&[1e3, 1e2, 1e1, 1e0]), epsilon = 1e-12);
118121

119122
let array: Array1<_> = geomspace(-1e3, -1e0, 4).collect();
120-
assert!(array.all_close(&arr1(&[-1e3, -1e2, -1e1, -1e0]), 1e-5));
123+
assert_abs_diff_eq!(array, arr1(&[-1e3, -1e2, -1e1, -1e0]), epsilon = 1e-12);
121124

122125
let array: Array1<_> = geomspace(-1e0, -1e3, 4).collect();
123-
assert!(array.all_close(&arr1(&[-1e0, -1e1, -1e2, -1e3]), 1e-5));
126+
assert_abs_diff_eq!(array, arr1(&[-1e0, -1e1, -1e2, -1e3]), epsilon = 1e-12);
124127
}
125128

126129
#[test]

src/impl_constructors.rs

Lines changed: 10 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -111,13 +111,16 @@ impl<S, A> ArrayBase<S, Ix1>
111111
/// **Panics** if the length is greater than `isize::MAX`.
112112
///
113113
/// ```rust
114+
/// use approx::assert_abs_diff_eq;
114115
/// use ndarray::{Array, arr1};
115116
///
117+
/// # #[cfg(feature = "approx")] {
116118
/// let array = Array::logspace(10.0, 0.0, 3.0, 4);
117-
/// assert!(array.all_close(&arr1(&[1e0, 1e1, 1e2, 1e3]), 1e-5));
119+
/// assert_abs_diff_eq!(array, arr1(&[1e0, 1e1, 1e2, 1e3]));
118120
///
119121
/// let array = Array::logspace(-10.0, 3.0, 0.0, 4);
120-
/// assert!(array.all_close(&arr1(&[-1e3, -1e2, -1e1, -1e0]), 1e-5));
122+
/// assert_abs_diff_eq!(array, arr1(&[-1e3, -1e2, -1e1, -1e0]));
123+
/// # }
121124
/// ```
122125
pub fn logspace(base: A, start: A, end: A, n: usize) -> Self
123126
where
@@ -136,13 +139,16 @@ impl<S, A> ArrayBase<S, Ix1>
136139
/// **Panics** if `n` is greater than `isize::MAX`.
137140
///
138141
/// ```rust
142+
/// use approx::assert_abs_diff_eq;
139143
/// use ndarray::{Array, arr1};
140144
///
145+
/// # #[cfg(feature = "approx")] {
141146
/// let array = Array::geomspace(1e0, 1e3, 4);
142-
/// assert!(array.all_close(&arr1(&[1e0, 1e1, 1e2, 1e3]), 1e-5));
147+
/// assert_abs_diff_eq!(array, arr1(&[1e0, 1e1, 1e2, 1e3]), epsilon = 1e-12);
143148
///
144149
/// let array = Array::geomspace(-1e3, -1e0, 4);
145-
/// assert!(array.all_close(&arr1(&[-1e3, -1e2, -1e1, -1e0]), 1e-5));
150+
/// assert_abs_diff_eq!(array, arr1(&[-1e3, -1e2, -1e1, -1e0]), epsilon = 1e-12);
151+
/// # }
146152
/// ```
147153
pub fn geomspace(start: A, end: A, n: usize) -> Self
148154
where

src/impl_methods.rs

Lines changed: 8 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -2039,15 +2039,20 @@ where
20392039
/// Elements are visited in arbitrary order.
20402040
///
20412041
/// ```
2042+
/// use approx::assert_abs_diff_eq;
20422043
/// use ndarray::arr2;
20432044
///
2045+
/// # #[cfg(feature = "approx")] {
20442046
/// let mut a = arr2(&[[ 0., 1.],
20452047
/// [-1., 2.]]);
20462048
/// a.mapv_inplace(f32::exp);
2047-
/// assert!(
2048-
/// a.all_close(&arr2(&[[1.00000, 2.71828],
2049-
/// [0.36788, 7.38906]]), 1e-5)
2049+
/// assert_abs_diff_eq!(
2050+
/// a,
2051+
/// arr2(&[[1.00000, 2.71828],
2052+
/// [0.36788, 7.38906]]),
2053+
/// epsilon = 1e-5,
20502054
/// );
2055+
/// # }
20512056
/// ```
20522057
pub fn mapv_inplace<F>(&mut self, mut f: F)
20532058
where S: DataMut,

src/lib.rs

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -68,6 +68,9 @@
6868
//! - `rayon`
6969
//! - Optional, compatible with Rust stable
7070
//! - Enables parallel iterators, parallelized methods and [`par_azip!`].
71+
//! - `approx`
72+
//! - Optional, compatible with Rust stable
73+
//! - Enables implementations of traits from the [`approx`] crate.
7174
//! - `blas`
7275
//! - Optional and experimental, compatible with Rust stable
7376
//! - Enable transparent BLAS support for matrix multiplication.
@@ -90,6 +93,9 @@ extern crate serde;
9093
#[cfg(feature="rayon")]
9194
extern crate rayon;
9295

96+
#[cfg(feature="approx")]
97+
extern crate approx;
98+
9399
#[cfg(feature="blas")]
94100
extern crate cblas_sys;
95101
#[cfg(feature="blas")]
@@ -146,6 +152,8 @@ mod aliases;
146152
mod arraytraits;
147153
#[cfg(feature = "serde-1")]
148154
mod array_serde;
155+
#[cfg(feature = "approx")]
156+
mod array_approx;
149157
mod arrayformat;
150158
mod data_traits;
151159

src/logspace.rs

Lines changed: 8 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -96,21 +96,24 @@ where
9696
#[cfg(test)]
9797
mod tests {
9898
use super::logspace;
99-
use crate::{arr1, Array1};
10099

101100
#[test]
101+
#[cfg(feature = "approx")]
102102
fn valid() {
103+
use approx::assert_abs_diff_eq;
104+
use crate::{arr1, Array1};
105+
103106
let array: Array1<_> = logspace(10.0, 0.0, 3.0, 4).collect();
104-
assert!(array.all_close(&arr1(&[1e0, 1e1, 1e2, 1e3]), 1e-5));
107+
assert_abs_diff_eq!(array, arr1(&[1e0, 1e1, 1e2, 1e3]));
105108

106109
let array: Array1<_> = logspace(10.0, 3.0, 0.0, 4).collect();
107-
assert!(array.all_close(&arr1(&[1e3, 1e2, 1e1, 1e0]), 1e-5));
110+
assert_abs_diff_eq!(array, arr1(&[1e3, 1e2, 1e1, 1e0]));
108111

109112
let array: Array1<_> = logspace(-10.0, 3.0, 0.0, 4).collect();
110-
assert!(array.all_close(&arr1(&[-1e3, -1e2, -1e1, -1e0]), 1e-5));
113+
assert_abs_diff_eq!(array, arr1(&[-1e3, -1e2, -1e1, -1e0]));
111114

112115
let array: Array1<_> = logspace(-10.0, 0.0, 3.0, 4).collect();
113-
assert!(array.all_close(&arr1(&[-1e0, -1e1, -1e2, -1e3]), 1e-5));
116+
assert_abs_diff_eq!(array, arr1(&[-1e0, -1e1, -1e2, -1e3]));
114117
}
115118

116119
#[test]

src/numeric/impl_numeric.rs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -306,6 +306,7 @@ impl<A, S, D> ArrayBase<S, D>
306306
/// If their shapes disagree, `rhs` is broadcast to the shape of `self`.
307307
///
308308
/// **Panics** if broadcasting to the same shape isn’t possible.
309+
#[deprecated(note="Use `abs_diff_eq` - it requires the `approx` crate feature", since="0.13")]
309310
pub fn all_close<S2, E>(&self, rhs: &ArrayBase<S2, E>, tol: A) -> bool
310311
where A: Float,
311312
S2: Data<Elem=A>,

tests/array-construct.rs

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -21,14 +21,17 @@ fn test_dimension_zero() {
2121
}
2222

2323
#[test]
24+
#[cfg(feature = "approx")]
2425
fn test_arc_into_owned() {
26+
use approx::assert_abs_diff_ne;
27+
2528
let a = Array2::from_elem((5, 5), 1.).into_shared();
2629
let mut b = a.clone();
2730
b.fill(0.);
2831
let mut c = b.into_owned();
2932
c.fill(2.);
3033
// test that they are unshared
31-
assert!(!a.all_close(&c, 0.01));
34+
assert_abs_diff_ne!(a, c, epsilon = 0.01);
3235
}
3336

3437
#[test]

0 commit comments

Comments
 (0)