Skip to content

Commit fdaf03a

Browse files
authored
Merge pull request #15 from jmetz/otsu
Added working otsu and mean thresholding
2 parents f917fe4 + b96faf3 commit fdaf03a

File tree

4 files changed

+279
-0
lines changed

4 files changed

+279
-0
lines changed

CHANGELOG.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44

55
### Added
66
* Padding strategies (`NoPadding`, `ConstantPadding`, `ZeroPadding`)
7+
* Threshold module with Otsu and Mean threshold algorithms
78

89
### Changed
910
* Integrated Padding strategies into convolutions

Cargo.toml

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,3 +19,5 @@ num-traits = "0.2"
1919
[dev-dependencies]
2020
ndarray-rand = "0.9.0"
2121
rand = "0.6.5"
22+
assert_approx_eq = "1.1.0"
23+
noisy_float = "0.1.11"

src/processing/mod.rs

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,12 +8,15 @@ pub mod filter;
88
pub mod kernels;
99
/// Sobel operator for edge detection
1010
pub mod sobel;
11+
/// Thresholding functions
12+
pub mod threshold;
1113

1214
pub use canny::*;
1315
pub use conv::*;
1416
pub use filter::*;
1517
pub use kernels::*;
1618
pub use sobel::*;
19+
pub use threshold::*;
1720

1821
/// Common error type for image processing algorithms
1922
#[derive(Copy, Clone, Eq, PartialEq, Ord, PartialOrd, Hash, Debug)]

src/processing/threshold.rs

Lines changed: 273 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,273 @@
1+
use crate::core::{ColourModel, Image};
2+
use crate::core::{PixelBound};
3+
use crate::processing::*;
4+
use ndarray::prelude::*;
5+
use num_traits::cast::{FromPrimitive};
6+
use num_traits::cast::{ToPrimitive};
7+
use ndarray_stats::QuantileExt;
8+
use ndarray_stats::HistogramExt;
9+
use ndarray_stats::histogram::{Grid, Bins, Edges};
10+
use num_traits::{Num, NumAssignOps};
11+
use std::marker::PhantomData;
12+
13+
// Development
14+
#[cfg(test)]
15+
use assert_approx_eq::assert_approx_eq;
16+
#[cfg(test)]
17+
use noisy_float::types::n64;
18+
19+
20+
/// Runs the Otsu Thresholding algorithm on a type T
21+
pub trait ThresholdOtsuExt<T> {
22+
/// Output type, this is different as Otsu outputs a binary image
23+
type Output;
24+
25+
/// Run the Otsu threshold detection algorithm with the
26+
/// given parameters. Due to Otsu being specified as working
27+
/// on greyscale images all current implementations
28+
/// assume a single channel image returning an error otherwise.
29+
fn threshold_otsu(&self) -> Result<Self::Output, Error>;
30+
}
31+
32+
/// Runs the Mean Thresholding algorithm on a type T
33+
pub trait ThresholdMeanExt<T> {
34+
/// Output type, this is different as Otsu outputs a binary image
35+
type Output;
36+
37+
/// Run the Otsu threshold detection algorithm with the
38+
/// given parameters. Due to Otsu being specified as working
39+
/// on greyscale images all current implementations
40+
/// assume a single channel image returning an error otherwise.
41+
fn threshold_mean(&self) -> Result<Self::Output, Error>;
42+
}
43+
44+
impl<T, C> ThresholdOtsuExt<T> for Image<T, C>
45+
where
46+
Image<T, C>: Clone,
47+
T: Copy + Clone + Ord + Num + NumAssignOps + ToPrimitive + FromPrimitive + PixelBound,
48+
C: ColourModel,
49+
{
50+
type Output = Image<bool, C>;
51+
52+
fn threshold_otsu(&self) -> Result<Self::Output, Error> {
53+
let data = self.data.threshold_otsu()?;
54+
Ok(Self::Output {
55+
data,
56+
model: PhantomData,
57+
})
58+
}
59+
}
60+
61+
impl<T> ThresholdOtsuExt<T> for Array3<T>
62+
where
63+
T: Copy + Clone + Ord + Num + NumAssignOps + ToPrimitive + FromPrimitive
64+
{
65+
type Output = Array3<bool>;
66+
67+
fn threshold_otsu(&self) -> Result<Self::Output, Error> {
68+
if self.shape()[2] > 1 {
69+
Err(Error::ChannelDimensionMismatch)
70+
} else {
71+
let value = calculate_threshold_otsu(&self)?;
72+
let mask = apply_threshold(self, value);
73+
Ok(mask)
74+
}
75+
}
76+
}
77+
78+
///
79+
/// Calculates Otsu's threshold
80+
/// Works per channel, but currently
81+
/// assumes grayscale (see the error above if number of channels is > 1
82+
/// i.e. single channel; otherwise we need to output all 3 threshold values).
83+
/// Todo: Add optional nbins
84+
///
85+
fn calculate_threshold_otsu<T>(mat: &Array3<T>) -> Result<f64, Error>
86+
where
87+
T: Copy + Clone + Ord + Num + NumAssignOps + ToPrimitive + FromPrimitive
88+
{
89+
let mut threshold = 0.0;
90+
let n_bins = 255;
91+
for c in mat.axis_iter(Axis(2)) {
92+
let scale_factor = (n_bins) as f64
93+
/(c.max().unwrap().to_f64().unwrap());
94+
let edges_vec: Vec<u8> = (0..n_bins).collect();
95+
let grid = Grid::from(vec![Bins::new(Edges::from(edges_vec))]);
96+
97+
// get the histogram
98+
let flat = Array::from_iter(c.iter()).insert_axis(Axis(1));
99+
let flat2 = flat.mapv(
100+
|x| ((*x).to_f64().unwrap() * scale_factor).to_u8().unwrap()
101+
);
102+
let hist = flat2.histogram(grid);
103+
// Straight out of wikipedia:
104+
let counts = hist.counts();
105+
let total = counts.sum().to_f64().unwrap();
106+
let counts = Array::from_iter(counts.iter());
107+
// NOTE: Could use the cdf generation for skimage-esque implementation
108+
// which entails a cumulative sum of the standard histogram
109+
let mut sum_b = 0.0;
110+
let mut weight_b = 0.0;
111+
let mut maximum = 0.0;
112+
let mut level = 0.0;
113+
let mut sum_intensity = 0.0;
114+
for (index, count) in counts.indexed_iter(){
115+
sum_intensity += (index as f64) * (*count).to_f64().unwrap();
116+
}
117+
for (index, count) in counts.indexed_iter(){
118+
weight_b = weight_b + count.to_f64().unwrap();
119+
sum_b = sum_b + (index as f64)* count.to_f64().unwrap();
120+
let weight_f = total - weight_b;
121+
if (weight_b > 0.0) && (weight_f > 0.0){
122+
let mean_f = (sum_intensity - sum_b) / weight_f;
123+
let val = weight_b * weight_f
124+
* ((sum_b / weight_b) - mean_f)
125+
* ((sum_b / weight_b) - mean_f);
126+
if val > maximum{
127+
level = 1.0 + (index as f64);
128+
maximum = val;
129+
}
130+
}
131+
}
132+
threshold = level as f64 / scale_factor;
133+
}
134+
Ok(threshold)
135+
}
136+
137+
impl<T, C> ThresholdMeanExt<T> for Image<T, C>
138+
where
139+
Image<T, C>: Clone,
140+
T: Copy + Clone + Ord + Num + NumAssignOps + ToPrimitive + FromPrimitive + PixelBound,
141+
C: ColourModel,
142+
{
143+
type Output = Image<bool, C>;
144+
145+
fn threshold_mean(&self) -> Result<Self::Output, Error> {
146+
let data = self.data.threshold_mean()?;
147+
Ok(Self::Output {
148+
data,
149+
model: PhantomData,
150+
})
151+
}
152+
}
153+
154+
impl<T> ThresholdMeanExt<T> for Array3<T>
155+
where
156+
T: Copy + Clone + Ord + Num + NumAssignOps + ToPrimitive + FromPrimitive
157+
{
158+
type Output = Array3<bool>;
159+
160+
fn threshold_mean(&self) -> Result<Self::Output, Error> {
161+
if self.shape()[2] > 1 {
162+
Err(Error::ChannelDimensionMismatch)
163+
} else {
164+
let value = calculate_threshold_mean(&self)?;
165+
let mask = apply_threshold(self, value);
166+
Ok(mask)
167+
}
168+
}
169+
}
170+
171+
fn calculate_threshold_mean<T>(array: &Array3<T>) -> Result<f64, Error>
172+
where
173+
T: Copy + Clone + Num + NumAssignOps + ToPrimitive + FromPrimitive
174+
{
175+
Ok(array.sum().to_f64().unwrap() / array.len() as f64)
176+
}
177+
178+
179+
fn apply_threshold<T>(data: &Array3<T>, threshold: f64) -> Array3<bool>
180+
where
181+
T: Copy + Clone + Num + NumAssignOps + ToPrimitive + FromPrimitive,
182+
{
183+
let result = data.mapv(|x| x.to_f64().unwrap() >= threshold);
184+
result
185+
}
186+
187+
188+
#[cfg(test)]
189+
mod tests {
190+
use super::*;
191+
use ndarray::arr3;
192+
193+
#[test]
194+
fn threshold_apply_threshold() {
195+
let data = arr3(&[
196+
[[0.2], [0.4], [0.0]],
197+
[[0.7], [0.5], [0.8]],
198+
[[0.1], [0.6], [0.0]],
199+
]);
200+
201+
let expected = arr3(&[
202+
[[false], [false], [false]],
203+
[[true], [true], [true]],
204+
[[false], [true], [false]],
205+
]);
206+
207+
let result = apply_threshold(&data, 0.5);
208+
209+
assert_eq!(result, expected);
210+
}
211+
212+
#[test]
213+
fn threshold_calculate_threshold_otsu_ints() {
214+
let data = arr3(&[
215+
[[2], [4], [0]],
216+
[[7], [5], [8]],
217+
[[1], [6], [0]],
218+
]);
219+
let result = calculate_threshold_otsu(&data).unwrap();
220+
println!("Done {}", result);
221+
222+
// Calculated using Python's skimage.filters.threshold_otsu
223+
// on int input array. Float array returns 2.0156...
224+
let expected = 2.0;
225+
226+
assert_approx_eq!(result, expected, 5e-1);
227+
}
228+
229+
#[test]
230+
fn threshold_calculate_threshold_otsu_floats() {
231+
let data = arr3(&[
232+
[[n64(2.0)], [n64(4.0)], [n64(0.0)]],
233+
[[n64(7.0)], [n64(5.0)], [n64(8.0)]],
234+
[[n64(1.0)], [n64(6.0)], [n64(0.0)]],
235+
]);
236+
237+
let result = calculate_threshold_otsu(&data).unwrap();
238+
239+
// Calculated using Python's skimage.filters.threshold_otsu
240+
// on int input array. Float array returns 2.0156...
241+
let expected = 2.0156;
242+
243+
assert_approx_eq!(result, expected, 5e-1);
244+
}
245+
246+
#[test]
247+
fn threshold_calculate_threshold_mean_ints() {
248+
let data = arr3(&[
249+
[[4], [4], [4]],
250+
[[5], [5], [5]],
251+
[[6], [6], [6]],
252+
]);
253+
254+
let result = calculate_threshold_mean(&data).unwrap();
255+
let expected = 5.0;
256+
257+
assert_approx_eq!(result, expected, 1e-16);
258+
}
259+
260+
#[test]
261+
fn threshold_calculate_threshold_mean_floats() {
262+
let data = arr3(&[
263+
[[4.0], [4.0], [4.0]],
264+
[[5.0], [5.0], [5.0]],
265+
[[6.0], [6.0], [6.0]],
266+
]);
267+
268+
let result = calculate_threshold_mean(&data).unwrap();
269+
let expected = 5.0;
270+
271+
assert_approx_eq!(result, expected, 1e-16);
272+
}
273+
}

0 commit comments

Comments
 (0)