Skip to content

Commit e862eb1

Browse files
authored
Merge pull request #8 from deliveroo/master
Merge changes from Deliveroo
2 parents 80e0e69 + ba76ff5 commit e862eb1

File tree

1 file changed

+66
-9
lines changed

1 file changed

+66
-9
lines changed

src/lib.rs

+66-9
Original file line numberDiff line numberDiff line change
@@ -14,34 +14,50 @@ use num_traits::Float;
1414

1515
use std::fmt;
1616
use std::ops;
17+
use std::sync::atomic::{AtomicBool, Ordering};
18+
use std::sync::Arc;
1719

1820
pub type Matrix<T> = ndarray::Array2<T>;
1921

2022
pub trait LapJVCost: Float + ops::AddAssign + ops::SubAssign + std::fmt::Debug {}
2123
impl<T> LapJVCost for T where T: Float + ops::AddAssign + ops::SubAssign + std::fmt::Debug {}
2224

25+
#[derive(Debug, Copy, Clone)]
26+
pub enum ErrorKind {
27+
Msg(&'static str),
28+
Cancelled,
29+
}
30+
2331
#[derive(Debug)]
24-
pub struct LapJVError(&'static str);
32+
pub struct LapJVError {
33+
kind: ErrorKind,
34+
}
2535

26-
impl std::fmt::Display for LapJVError {
27-
fn fmt(&self, f: &mut fmt::Formatter) -> Result<(), fmt::Error> {
28-
write!(f, "{}", self.0)
36+
impl LapJVError {
37+
pub fn kind(&self) -> ErrorKind {
38+
self.kind
2939
}
3040
}
3141

32-
impl std::error::Error for LapJVError {
33-
fn description(&self) -> &str {
34-
self.0
42+
impl std::fmt::Display for LapJVError {
43+
fn fmt(&self, f: &mut fmt::Formatter) -> Result<(), fmt::Error> {
44+
match self.kind {
45+
ErrorKind::Msg(string) => write!(f, "{}", string),
46+
ErrorKind::Cancelled => write!(f, "cancelled"),
47+
}
3548
}
3649
}
3750

51+
impl std::error::Error for LapJVError {}
52+
3853
pub struct LapJV<'a, T: 'a> {
3954
costs: &'a Matrix<T>,
4055
dim: usize,
4156
free_rows: Vec<usize>,
4257
v: Vec<T>,
4358
in_col: Vec<usize>,
4459
in_row: Vec<usize>,
60+
cancellation: Cancellation,
4561
}
4662

4763
/// Solve LAP problem given cost matrix
@@ -65,6 +81,19 @@ where
6581
.fold(T::zero(), |acc, i| acc + input[(i, row[i])])
6682
}
6783

84+
#[derive(Clone)]
85+
pub struct Cancellation(Arc<AtomicBool>);
86+
87+
impl Cancellation {
88+
pub fn cancel(&self) {
89+
self.0.store(true, Ordering::SeqCst)
90+
}
91+
92+
pub fn is_cancelled(&self) -> bool {
93+
self.0.load(Ordering::SeqCst)
94+
}
95+
}
96+
6897
/// Solve LAP problem given cost matrix
6998
/// This is an implementation of the LAPJV algorithm described in:
7099
/// R. Jonker, A. Volgenant. A Shortest Augmenting Path Algorithm for
@@ -80,24 +109,39 @@ where
80109
let v = Vec::with_capacity(dim);
81110
let in_row = vec![0; dim];
82111
let in_col = Vec::with_capacity(dim);
112+
let cancellation = Cancellation(Default::default());
83113
Self {
84114
costs,
85115
dim,
86116
free_rows,
87117
v,
88118
in_col,
89119
in_row,
120+
cancellation
90121
}
91122
}
92123

124+
/// Returns a `Cancellation` token which can be cancelled from another thread.
125+
pub fn cancellation(&self) -> Cancellation {
126+
self.cancellation.clone()
127+
}
128+
129+
fn check_cancelled(&self) -> Result<(), LapJVError> {
130+
if self.cancellation.is_cancelled() {
131+
return Err(LapJVError { kind: ErrorKind::Cancelled });
132+
}
133+
Ok(())
134+
}
135+
93136
pub fn solve(mut self) -> Result<(Vec<usize>, Vec<usize>), LapJVError> {
94137
if self.costs.dim().0 != self.costs.dim().1 {
95-
return Err(LapJVError("Input error: matrix is not square"));
138+
return Err(LapJVError { kind: ErrorKind::Msg("Input error: matrix is not square") } );
96139
}
97140
self.ccrrt_dense();
98141

99142
let mut i = 0;
100143
while !self.free_rows.is_empty() && i < 2 {
144+
self.check_cancelled()?;
101145
self.carr_dense();
102146
i += 1;
103147
}
@@ -228,6 +272,8 @@ where
228272
for freerow in free_rows {
229273
trace!("looking at freerow={}", freerow);
230274

275+
self.check_cancelled()?;
276+
231277
let mut i = std::usize::MAX;
232278
let mut k = 0;
233279
let mut j = self.find_path_dense(freerow, &mut pred);
@@ -238,7 +284,7 @@ where
238284
std::mem::swap(&mut j, &mut self.in_row[i]);
239285
k += 1;
240286
if k > dim {
241-
return Err(LapJVError("Error: ca_dense will not finish"));
287+
return Err(LapJVError { kind: ErrorKind::Msg("Error: ca_dense will not finish") });
242288
}
243289
}
244290
}
@@ -415,6 +461,17 @@ mod tests {
415461
assert_eq!(result.1, vec![1, 2, 0]);
416462
}
417463

464+
#[test]
465+
fn cancellation() {
466+
let m = Matrix::from_shape_vec((3, 3), vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0])
467+
.unwrap();
468+
let lapjv = LapJV::new(&m);
469+
let cancellation = lapjv.cancellation();
470+
cancellation.cancel();
471+
let result = lapjv.solve();
472+
assert!(matches!(result, Err(LapJVError { kind: ErrorKind::Cancelled })));
473+
}
474+
418475
#[test]
419476
fn test_solve_random10() {
420477
let (m, result) = solve_random10();

0 commit comments

Comments
 (0)