Skip to content

Commit e6426a6

Browse files
phunglesonLukeMathWalker
authored andcommitted
argmin and argmax (#30)
1 parent 8e0e1cf commit e6426a6

File tree

2 files changed

+120
-0
lines changed

2 files changed

+120
-0
lines changed

src/quantile.rs

+84
Original file line numberDiff line numberDiff line change
@@ -182,6 +182,31 @@ where
182182
S: Data<Elem = A>,
183183
D: Dimension,
184184
{
185+
/// Finds the first index of the minimum value of the array.
186+
///
187+
/// Returns `None` if any of the pairwise orderings tested by the function
188+
/// are undefined. (For example, this occurs if there are any
189+
/// floating-point NaN values in the array.)
190+
///
191+
/// Returns `None` if the array is empty.
192+
///
193+
/// # Example
194+
///
195+
/// ```
196+
/// extern crate ndarray;
197+
/// extern crate ndarray_stats;
198+
///
199+
/// use ndarray::array;
200+
/// use ndarray_stats::QuantileExt;
201+
///
202+
/// let a = array![[1., 3., 5.],
203+
/// [2., 0., 6.]];
204+
/// assert_eq!(a.argmin(), Some((1, 1)));
205+
/// ```
206+
fn argmin(&self) -> Option<D::Pattern>
207+
where
208+
A: PartialOrd;
209+
185210
/// Finds the elementwise minimum of the array.
186211
///
187212
/// Returns `None` if any of the pairwise orderings tested by the function
@@ -203,6 +228,31 @@ where
203228
A: MaybeNan,
204229
A::NotNan: Ord;
205230

231+
/// Finds the first index of the maximum value of the array.
232+
///
233+
/// Returns `None` if any of the pairwise orderings tested by the function
234+
/// are undefined. (For example, this occurs if there are any
235+
/// floating-point NaN values in the array.)
236+
///
237+
/// Returns `None` if the array is empty.
238+
///
239+
/// # Example
240+
///
241+
/// ```
242+
/// extern crate ndarray;
243+
/// extern crate ndarray_stats;
244+
///
245+
/// use ndarray::array;
246+
/// use ndarray_stats::QuantileExt;
247+
///
248+
/// let a = array![[1., 3., 7.],
249+
/// [2., 5., 6.]];
250+
/// assert_eq!(a.argmax(), Some((0, 2)));
251+
/// ```
252+
fn argmax(&self) -> Option<D::Pattern>
253+
where
254+
A: PartialOrd;
255+
206256
/// Finds the elementwise maximum of the array.
207257
///
208258
/// Returns `None` if any of the pairwise orderings tested by the function
@@ -278,6 +328,23 @@ where
278328
S: Data<Elem = A>,
279329
D: Dimension,
280330
{
331+
fn argmin(&self) -> Option<D::Pattern>
332+
where
333+
A: PartialOrd,
334+
{
335+
let mut current_min = self.first()?;
336+
let mut current_pattern_min = D::zeros(self.ndim()).into_pattern();
337+
338+
for (pattern, elem) in self.indexed_iter() {
339+
if elem.partial_cmp(current_min)? == cmp::Ordering::Less {
340+
current_pattern_min = pattern;
341+
current_min = elem
342+
}
343+
}
344+
345+
Some(current_pattern_min)
346+
}
347+
281348
fn min(&self) -> Option<&A>
282349
where
283350
A: PartialOrd,
@@ -303,6 +370,23 @@ where
303370
}))
304371
}
305372

373+
fn argmax(&self) -> Option<D::Pattern>
374+
where
375+
A: PartialOrd,
376+
{
377+
let mut current_max = self.first()?;
378+
let mut current_pattern_max = D::zeros(self.ndim()).into_pattern();
379+
380+
for (pattern, elem) in self.indexed_iter() {
381+
if elem.partial_cmp(current_max)? == cmp::Ordering::Greater {
382+
current_pattern_max = pattern;
383+
current_max = elem
384+
}
385+
}
386+
387+
Some(current_pattern_max)
388+
}
389+
306390
fn max(&self) -> Option<&A>
307391
where
308392
A: PartialOrd,

tests/quantile.rs

+36
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,24 @@ use ndarray_stats::{
88
Quantile1dExt, QuantileExt,
99
};
1010

11+
#[test]
12+
fn test_argmin() {
13+
let a = array![[1, 5, 3], [2, 0, 6]];
14+
assert_eq!(a.argmin(), Some((1, 1)));
15+
16+
let a = array![[1., 5., 3.], [2., 0., 6.]];
17+
assert_eq!(a.argmin(), Some((1, 1)));
18+
19+
let a = array![[1., 5., 3.], [2., ::std::f64::NAN, 6.]];
20+
assert_eq!(a.argmin(), None);
21+
22+
let a = array![[1, 0, 3], [2, 0, 6]];
23+
assert_eq!(a.argmin(), Some((0, 1)));
24+
25+
let a: Array2<i32> = array![[], []];
26+
assert_eq!(a.argmin(), None);
27+
}
28+
1129
#[test]
1230
fn test_min() {
1331
let a = array![[1, 5, 3], [2, 0, 6]];
@@ -35,6 +53,24 @@ fn test_min_skipnan_all_nan() {
3553
assert!(a.min_skipnan().is_nan());
3654
}
3755

56+
#[test]
57+
fn test_argmax() {
58+
let a = array![[1, 5, 3], [2, 0, 6]];
59+
assert_eq!(a.argmax(), Some((1, 2)));
60+
61+
let a = array![[1., 5., 3.], [2., 0., 6.]];
62+
assert_eq!(a.argmax(), Some((1, 2)));
63+
64+
let a = array![[1., 5., 3.], [2., ::std::f64::NAN, 6.]];
65+
assert_eq!(a.argmax(), None);
66+
67+
let a = array![[1, 5, 6], [2, 0, 6]];
68+
assert_eq!(a.argmax(), Some((0, 2)));
69+
70+
let a: Array2<i32> = array![[], []];
71+
assert_eq!(a.argmax(), None);
72+
}
73+
3874
#[test]
3975
fn test_max() {
4076
let a = array![[1, 5, 7], [2, 0, 6]];

0 commit comments

Comments
 (0)