|
| 1 | +use crate::core::{ColourModel, Image}; |
| 2 | +use crate::core::{PixelBound}; |
| 3 | +use crate::processing::*; |
| 4 | +use ndarray::prelude::*; |
| 5 | +use num_traits::cast::{FromPrimitive}; |
| 6 | +use num_traits::cast::{ToPrimitive}; |
| 7 | +use ndarray_stats::QuantileExt; |
| 8 | +use ndarray_stats::HistogramExt; |
| 9 | +use ndarray_stats::histogram::{Grid, Bins, Edges}; |
| 10 | +use num_traits::{Num, NumAssignOps}; |
| 11 | +use std::marker::PhantomData; |
| 12 | + |
| 13 | +// Development |
| 14 | +#[cfg(test)] |
| 15 | +use assert_approx_eq::assert_approx_eq; |
| 16 | +#[cfg(test)] |
| 17 | +use noisy_float::types::n64; |
| 18 | + |
| 19 | + |
| 20 | +/// Runs the Otsu Thresholding algorithm on a type T |
| 21 | +pub trait ThresholdOtsuExt<T> { |
| 22 | + /// Output type, this is different as Otsu outputs a binary image |
| 23 | + type Output; |
| 24 | + |
| 25 | + /// Run the Otsu threshold detection algorithm with the |
| 26 | + /// given parameters. Due to Otsu being specified as working |
| 27 | + /// on greyscale images all current implementations |
| 28 | + /// assume a single channel image returning an error otherwise. |
| 29 | + fn threshold_otsu(&self) -> Result<Self::Output, Error>; |
| 30 | +} |
| 31 | + |
| 32 | +/// Runs the Mean Thresholding algorithm on a type T |
| 33 | +pub trait ThresholdMeanExt<T> { |
| 34 | + /// Output type, this is different as Otsu outputs a binary image |
| 35 | + type Output; |
| 36 | + |
| 37 | + /// Run the Otsu threshold detection algorithm with the |
| 38 | + /// given parameters. Due to Otsu being specified as working |
| 39 | + /// on greyscale images all current implementations |
| 40 | + /// assume a single channel image returning an error otherwise. |
| 41 | + fn threshold_mean(&self) -> Result<Self::Output, Error>; |
| 42 | +} |
| 43 | + |
| 44 | +impl<T, C> ThresholdOtsuExt<T> for Image<T, C> |
| 45 | +where |
| 46 | + Image<T, C>: Clone, |
| 47 | + T: Copy + Clone + Ord + Num + NumAssignOps + ToPrimitive + FromPrimitive + PixelBound, |
| 48 | + C: ColourModel, |
| 49 | +{ |
| 50 | + type Output = Image<bool, C>; |
| 51 | + |
| 52 | + fn threshold_otsu(&self) -> Result<Self::Output, Error> { |
| 53 | + let data = self.data.threshold_otsu()?; |
| 54 | + Ok(Self::Output { |
| 55 | + data, |
| 56 | + model: PhantomData, |
| 57 | + }) |
| 58 | + } |
| 59 | +} |
| 60 | + |
| 61 | +impl<T> ThresholdOtsuExt<T> for Array3<T> |
| 62 | +where |
| 63 | + T: Copy + Clone + Ord + Num + NumAssignOps + ToPrimitive + FromPrimitive |
| 64 | +{ |
| 65 | + type Output = Array3<bool>; |
| 66 | + |
| 67 | + fn threshold_otsu(&self) -> Result<Self::Output, Error> { |
| 68 | + if self.shape()[2] > 1 { |
| 69 | + Err(Error::ChannelDimensionMismatch) |
| 70 | + } else { |
| 71 | + let value = calculate_threshold_otsu(&self)?; |
| 72 | + let mask = apply_threshold(self, value); |
| 73 | + Ok(mask) |
| 74 | + } |
| 75 | + } |
| 76 | +} |
| 77 | + |
| 78 | +/// |
| 79 | +/// Calculates Otsu's threshold |
| 80 | +/// Works per channel, but currently |
| 81 | +/// assumes grayscale (see the error above if number of channels is > 1 |
| 82 | +/// i.e. single channel; otherwise we need to output all 3 threshold values). |
| 83 | +/// Todo: Add optional nbins |
| 84 | +/// |
| 85 | +fn calculate_threshold_otsu<T>(mat: &Array3<T>) -> Result<f64, Error> |
| 86 | +where |
| 87 | + T: Copy + Clone + Ord + Num + NumAssignOps + ToPrimitive + FromPrimitive |
| 88 | +{ |
| 89 | + let mut threshold = 0.0; |
| 90 | + let n_bins = 255; |
| 91 | + for c in mat.axis_iter(Axis(2)) { |
| 92 | + let scale_factor = (n_bins) as f64 |
| 93 | + /(c.max().unwrap().to_f64().unwrap()); |
| 94 | + let edges_vec: Vec<u8> = (0..n_bins).collect(); |
| 95 | + let grid = Grid::from(vec![Bins::new(Edges::from(edges_vec))]); |
| 96 | + |
| 97 | + // get the histogram |
| 98 | + let flat = Array::from_iter(c.iter()).insert_axis(Axis(1)); |
| 99 | + let flat2 = flat.mapv( |
| 100 | + |x| ((*x).to_f64().unwrap() * scale_factor).to_u8().unwrap() |
| 101 | + ); |
| 102 | + let hist = flat2.histogram(grid); |
| 103 | + // Straight out of wikipedia: |
| 104 | + let counts = hist.counts(); |
| 105 | + let total = counts.sum().to_f64().unwrap(); |
| 106 | + let counts = Array::from_iter(counts.iter()); |
| 107 | + // NOTE: Could use the cdf generation for skimage-esque implementation |
| 108 | + // which entails a cumulative sum of the standard histogram |
| 109 | + let mut sum_b = 0.0; |
| 110 | + let mut weight_b = 0.0; |
| 111 | + let mut maximum = 0.0; |
| 112 | + let mut level = 0.0; |
| 113 | + let mut sum_intensity = 0.0; |
| 114 | + for (index, count) in counts.indexed_iter(){ |
| 115 | + sum_intensity += (index as f64) * (*count).to_f64().unwrap(); |
| 116 | + } |
| 117 | + for (index, count) in counts.indexed_iter(){ |
| 118 | + weight_b = weight_b + count.to_f64().unwrap(); |
| 119 | + sum_b = sum_b + (index as f64)* count.to_f64().unwrap(); |
| 120 | + let weight_f = total - weight_b; |
| 121 | + if (weight_b > 0.0) && (weight_f > 0.0){ |
| 122 | + let mean_f = (sum_intensity - sum_b) / weight_f; |
| 123 | + let val = weight_b * weight_f |
| 124 | + * ((sum_b / weight_b) - mean_f) |
| 125 | + * ((sum_b / weight_b) - mean_f); |
| 126 | + if val > maximum{ |
| 127 | + level = 1.0 + (index as f64); |
| 128 | + maximum = val; |
| 129 | + } |
| 130 | + } |
| 131 | + } |
| 132 | + threshold = level as f64 / scale_factor; |
| 133 | + } |
| 134 | + Ok(threshold) |
| 135 | +} |
| 136 | + |
| 137 | +impl<T, C> ThresholdMeanExt<T> for Image<T, C> |
| 138 | +where |
| 139 | + Image<T, C>: Clone, |
| 140 | + T: Copy + Clone + Ord + Num + NumAssignOps + ToPrimitive + FromPrimitive + PixelBound, |
| 141 | + C: ColourModel, |
| 142 | +{ |
| 143 | + type Output = Image<bool, C>; |
| 144 | + |
| 145 | + fn threshold_mean(&self) -> Result<Self::Output, Error> { |
| 146 | + let data = self.data.threshold_mean()?; |
| 147 | + Ok(Self::Output { |
| 148 | + data, |
| 149 | + model: PhantomData, |
| 150 | + }) |
| 151 | + } |
| 152 | +} |
| 153 | + |
| 154 | +impl<T> ThresholdMeanExt<T> for Array3<T> |
| 155 | +where |
| 156 | + T: Copy + Clone + Ord + Num + NumAssignOps + ToPrimitive + FromPrimitive |
| 157 | +{ |
| 158 | + type Output = Array3<bool>; |
| 159 | + |
| 160 | + fn threshold_mean(&self) -> Result<Self::Output, Error> { |
| 161 | + if self.shape()[2] > 1 { |
| 162 | + Err(Error::ChannelDimensionMismatch) |
| 163 | + } else { |
| 164 | + let value = calculate_threshold_mean(&self)?; |
| 165 | + let mask = apply_threshold(self, value); |
| 166 | + Ok(mask) |
| 167 | + } |
| 168 | + } |
| 169 | +} |
| 170 | + |
| 171 | +fn calculate_threshold_mean<T>(array: &Array3<T>) -> Result<f64, Error> |
| 172 | +where |
| 173 | + T: Copy + Clone + Num + NumAssignOps + ToPrimitive + FromPrimitive |
| 174 | +{ |
| 175 | + Ok(array.sum().to_f64().unwrap() / array.len() as f64) |
| 176 | +} |
| 177 | + |
| 178 | + |
| 179 | +fn apply_threshold<T>(data: &Array3<T>, threshold: f64) -> Array3<bool> |
| 180 | +where |
| 181 | + T: Copy + Clone + Num + NumAssignOps + ToPrimitive + FromPrimitive, |
| 182 | +{ |
| 183 | + let result = data.mapv(|x| x.to_f64().unwrap() >= threshold); |
| 184 | + result |
| 185 | +} |
| 186 | + |
| 187 | + |
| 188 | +#[cfg(test)] |
| 189 | +mod tests { |
| 190 | + use super::*; |
| 191 | + use ndarray::arr3; |
| 192 | + |
| 193 | + #[test] |
| 194 | + fn threshold_apply_threshold() { |
| 195 | + let data = arr3(&[ |
| 196 | + [[0.2], [0.4], [0.0]], |
| 197 | + [[0.7], [0.5], [0.8]], |
| 198 | + [[0.1], [0.6], [0.0]], |
| 199 | + ]); |
| 200 | + |
| 201 | + let expected = arr3(&[ |
| 202 | + [[false], [false], [false]], |
| 203 | + [[true], [true], [true]], |
| 204 | + [[false], [true], [false]], |
| 205 | + ]); |
| 206 | + |
| 207 | + let result = apply_threshold(&data, 0.5); |
| 208 | + |
| 209 | + assert_eq!(result, expected); |
| 210 | + } |
| 211 | + |
| 212 | + #[test] |
| 213 | + fn threshold_calculate_threshold_otsu_ints() { |
| 214 | + let data = arr3(&[ |
| 215 | + [[2], [4], [0]], |
| 216 | + [[7], [5], [8]], |
| 217 | + [[1], [6], [0]], |
| 218 | + ]); |
| 219 | + let result = calculate_threshold_otsu(&data).unwrap(); |
| 220 | + println!("Done {}", result); |
| 221 | + |
| 222 | + // Calculated using Python's skimage.filters.threshold_otsu |
| 223 | + // on int input array. Float array returns 2.0156... |
| 224 | + let expected = 2.0; |
| 225 | + |
| 226 | + assert_approx_eq!(result, expected, 5e-1); |
| 227 | + } |
| 228 | + |
| 229 | + #[test] |
| 230 | + fn threshold_calculate_threshold_otsu_floats() { |
| 231 | + let data = arr3(&[ |
| 232 | + [[n64(2.0)], [n64(4.0)], [n64(0.0)]], |
| 233 | + [[n64(7.0)], [n64(5.0)], [n64(8.0)]], |
| 234 | + [[n64(1.0)], [n64(6.0)], [n64(0.0)]], |
| 235 | + ]); |
| 236 | + |
| 237 | + let result = calculate_threshold_otsu(&data).unwrap(); |
| 238 | + |
| 239 | + // Calculated using Python's skimage.filters.threshold_otsu |
| 240 | + // on int input array. Float array returns 2.0156... |
| 241 | + let expected = 2.0156; |
| 242 | + |
| 243 | + assert_approx_eq!(result, expected, 5e-1); |
| 244 | + } |
| 245 | + |
| 246 | + #[test] |
| 247 | + fn threshold_calculate_threshold_mean_ints() { |
| 248 | + let data = arr3(&[ |
| 249 | + [[4], [4], [4]], |
| 250 | + [[5], [5], [5]], |
| 251 | + [[6], [6], [6]], |
| 252 | + ]); |
| 253 | + |
| 254 | + let result = calculate_threshold_mean(&data).unwrap(); |
| 255 | + let expected = 5.0; |
| 256 | + |
| 257 | + assert_approx_eq!(result, expected, 1e-16); |
| 258 | + } |
| 259 | + |
| 260 | + #[test] |
| 261 | + fn threshold_calculate_threshold_mean_floats() { |
| 262 | + let data = arr3(&[ |
| 263 | + [[4.0], [4.0], [4.0]], |
| 264 | + [[5.0], [5.0], [5.0]], |
| 265 | + [[6.0], [6.0], [6.0]], |
| 266 | + ]); |
| 267 | + |
| 268 | + let result = calculate_threshold_mean(&data).unwrap(); |
| 269 | + let expected = 5.0; |
| 270 | + |
| 271 | + assert_approx_eq!(result, expected, 1e-16); |
| 272 | + } |
| 273 | +} |
0 commit comments