Skip to content

Commit e6a0e93

Browse files
committed
Make geomspace return None instead of panicking
This helps to make the user more aware of the constraints on `geomspace` and allows recovering in case of bad input.
1 parent 97def32 commit e6a0e93

File tree

2 files changed

+30
-31
lines changed

2 files changed

+30
-31
lines changed

src/geomspace.rs

Lines changed: 18 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -71,21 +71,17 @@ impl<F> ExactSizeIterator for Geomspace<F> where Geomspace<F>: Iterator {}
7171
///
7272
/// Iterator element type is `F`, where `F` must be either `f32` or `f64`.
7373
///
74-
/// **Panics** if the interval `[a, b]` contains zero (including the end points).
74+
/// Returns `None` if `start` and `end` have different signs or if either one
75+
/// is zero. Conceptually, this means that in order to obtain a `Some` result,
76+
/// `end / start` must be positive.
7577
#[inline]
76-
pub fn geomspace<F>(a: F, b: F, n: usize) -> Geomspace<F>
78+
pub fn geomspace<F>(a: F, b: F, n: usize) -> Option<Geomspace<F>>
7779
where
7880
F: Float,
7981
{
80-
assert!(
81-
a != F::zero() && b != F::zero(),
82-
"Start and/or end of geomspace cannot be zero.",
83-
);
84-
assert!(
85-
a.is_sign_negative() == b.is_sign_negative(),
86-
"Logarithmic interval cannot cross 0."
87-
);
88-
82+
if a == F::zero() || b == F::zero() || a.is_sign_negative() != b.is_sign_negative() {
83+
return None;
84+
}
8985
let log_a = a.abs().ln();
9086
let log_b = b.abs().ln();
9187
let step = if n > 1 {
@@ -94,13 +90,13 @@ where
9490
} else {
9591
F::zero()
9692
};
97-
Geomspace {
93+
Some(Geomspace {
9894
sign: a.signum(),
9995
start: log_a,
10096
step: step,
10197
index: 0,
10298
len: n,
103-
}
99+
})
104100
}
105101

106102
#[cfg(test)]
@@ -113,22 +109,22 @@ mod tests {
113109
use approx::assert_abs_diff_eq;
114110
use crate::{arr1, Array1};
115111

116-
let array: Array1<_> = geomspace(1e0, 1e3, 4).collect();
112+
let array: Array1<_> = geomspace(1e0, 1e3, 4).unwrap().collect();
117113
assert_abs_diff_eq!(array, arr1(&[1e0, 1e1, 1e2, 1e3]), epsilon = 1e-12);
118114

119-
let array: Array1<_> = geomspace(1e3, 1e0, 4).collect();
115+
let array: Array1<_> = geomspace(1e3, 1e0, 4).unwrap().collect();
120116
assert_abs_diff_eq!(array, arr1(&[1e3, 1e2, 1e1, 1e0]), epsilon = 1e-12);
121117

122-
let array: Array1<_> = geomspace(-1e3, -1e0, 4).collect();
118+
let array: Array1<_> = geomspace(-1e3, -1e0, 4).unwrap().collect();
123119
assert_abs_diff_eq!(array, arr1(&[-1e3, -1e2, -1e1, -1e0]), epsilon = 1e-12);
124120

125-
let array: Array1<_> = geomspace(-1e0, -1e3, 4).collect();
121+
let array: Array1<_> = geomspace(-1e0, -1e3, 4).unwrap().collect();
126122
assert_abs_diff_eq!(array, arr1(&[-1e0, -1e1, -1e2, -1e3]), epsilon = 1e-12);
127123
}
128124

129125
#[test]
130126
fn iter_forward() {
131-
let mut iter = geomspace(1.0f64, 1e3, 4);
127+
let mut iter = geomspace(1.0f64, 1e3, 4).unwrap();
132128

133129
assert!(iter.size_hint() == (4, Some(4)));
134130

@@ -143,7 +139,7 @@ mod tests {
143139

144140
#[test]
145141
fn iter_backward() {
146-
let mut iter = geomspace(1.0f64, 1e3, 4);
142+
let mut iter = geomspace(1.0f64, 1e3, 4).unwrap();
147143

148144
assert!(iter.size_hint() == (4, Some(4)));
149145

@@ -157,20 +153,17 @@ mod tests {
157153
}
158154

159155
#[test]
160-
#[should_panic]
161156
fn zero_lower() {
162-
geomspace(0.0, 1.0, 4);
157+
assert!(geomspace(0.0, 1.0, 4).is_none());
163158
}
164159

165160
#[test]
166-
#[should_panic]
167161
fn zero_upper() {
168-
geomspace(1.0, 0.0, 4);
162+
assert!(geomspace(1.0, 0.0, 4).is_none());
169163
}
170164

171165
#[test]
172-
#[should_panic]
173166
fn zero_included() {
174-
geomspace(-1.0, 1.0, 4);
167+
assert!(geomspace(-1.0, 1.0, 4).is_none());
175168
}
176169
}

src/impl_constructors.rs

Lines changed: 12 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -133,28 +133,34 @@ impl<S, A> ArrayBase<S, Ix1>
133133
/// end]` with `n` elements geometrically spaced. `A` must be a floating
134134
/// point type.
135135
///
136-
/// The interval can be either all positive or all negative; however, it
137-
/// cannot contain 0 (including the end points).
136+
/// Returns `None` if `start` and `end` have different signs or if either
137+
/// one is zero. Conceptually, this means that in order to obtain a `Some`
138+
/// result, `end / start` must be positive.
138139
///
139140
/// **Panics** if `n` is greater than `isize::MAX`.
140141
///
141142
/// ```rust
142143
/// use approx::assert_abs_diff_eq;
143144
/// use ndarray::{Array, arr1};
144145
///
146+
/// # fn example() -> Option<()> {
145147
/// # #[cfg(feature = "approx")] {
146-
/// let array = Array::geomspace(1e0, 1e3, 4);
148+
/// let array = Array::geomspace(1e0, 1e3, 4)?;
147149
/// assert_abs_diff_eq!(array, arr1(&[1e0, 1e1, 1e2, 1e3]), epsilon = 1e-12);
148150
///
149-
/// let array = Array::geomspace(-1e3, -1e0, 4);
151+
/// let array = Array::geomspace(-1e3, -1e0, 4)?;
150152
/// assert_abs_diff_eq!(array, arr1(&[-1e3, -1e2, -1e1, -1e0]), epsilon = 1e-12);
151153
/// # }
154+
/// # Some(())
155+
/// # }
156+
/// #
157+
/// # fn main() { example().unwrap() }
152158
/// ```
153-
pub fn geomspace(start: A, end: A, n: usize) -> Self
159+
pub fn geomspace(start: A, end: A, n: usize) -> Option<Self>
154160
where
155161
A: Float,
156162
{
157-
Self::from_vec(to_vec(geomspace::geomspace(start, end, n)))
163+
Some(Self::from_vec(to_vec(geomspace::geomspace(start, end, n)?)))
158164
}
159165
}
160166

0 commit comments

Comments
 (0)