@@ -14,34 +14,50 @@ use num_traits::Float;
14
14
15
15
use std:: fmt;
16
16
use std:: ops;
17
+ use std:: sync:: atomic:: { AtomicBool , Ordering } ;
18
+ use std:: sync:: Arc ;
17
19
18
20
pub type Matrix < T > = ndarray:: Array2 < T > ;
19
21
20
22
pub trait LapJVCost : Float + ops:: AddAssign + ops:: SubAssign + std:: fmt:: Debug { }
21
23
impl < T > LapJVCost for T where T : Float + ops:: AddAssign + ops:: SubAssign + std:: fmt:: Debug { }
22
24
25
+ #[ derive( Debug , Copy , Clone ) ]
26
+ pub enum ErrorKind {
27
+ Msg ( & ' static str ) ,
28
+ Cancelled ,
29
+ }
30
+
23
31
#[ derive( Debug ) ]
24
- pub struct LapJVError ( & ' static str ) ;
32
+ pub struct LapJVError {
33
+ kind : ErrorKind ,
34
+ }
25
35
26
- impl std :: fmt :: Display for LapJVError {
27
- fn fmt ( & self , f : & mut fmt :: Formatter ) -> Result < ( ) , fmt :: Error > {
28
- write ! ( f , "{}" , self . 0 )
36
+ impl LapJVError {
37
+ pub fn kind ( & self ) -> ErrorKind {
38
+ self . kind
29
39
}
30
40
}
31
41
32
- impl std:: error:: Error for LapJVError {
33
- fn description ( & self ) -> & str {
34
- self . 0
42
+ impl std:: fmt:: Display for LapJVError {
43
+ fn fmt ( & self , f : & mut fmt:: Formatter ) -> Result < ( ) , fmt:: Error > {
44
+ match self . kind {
45
+ ErrorKind :: Msg ( string) => write ! ( f, "{}" , string) ,
46
+ ErrorKind :: Cancelled => write ! ( f, "cancelled" ) ,
47
+ }
35
48
}
36
49
}
37
50
51
+ impl std:: error:: Error for LapJVError { }
52
+
38
53
pub struct LapJV < ' a , T : ' a > {
39
54
costs : & ' a Matrix < T > ,
40
55
dim : usize ,
41
56
free_rows : Vec < usize > ,
42
57
v : Vec < T > ,
43
58
in_col : Vec < usize > ,
44
59
in_row : Vec < usize > ,
60
+ cancellation : Cancellation ,
45
61
}
46
62
47
63
/// Solve LAP problem given cost matrix
65
81
. fold ( T :: zero ( ) , |acc, i| acc + input[ ( i, row[ i] ) ] )
66
82
}
67
83
84
+ #[ derive( Clone ) ]
85
+ pub struct Cancellation ( Arc < AtomicBool > ) ;
86
+
87
+ impl Cancellation {
88
+ pub fn cancel ( & self ) {
89
+ self . 0 . store ( true , Ordering :: SeqCst )
90
+ }
91
+
92
+ pub fn is_cancelled ( & self ) -> bool {
93
+ self . 0 . load ( Ordering :: SeqCst )
94
+ }
95
+ }
96
+
68
97
/// Solve LAP problem given cost matrix
69
98
/// This is an implementation of the LAPJV algorithm described in:
70
99
/// R. Jonker, A. Volgenant. A Shortest Augmenting Path Algorithm for
@@ -80,24 +109,39 @@ where
80
109
let v = Vec :: with_capacity ( dim) ;
81
110
let in_row = vec ! [ 0 ; dim] ;
82
111
let in_col = Vec :: with_capacity ( dim) ;
112
+ let cancellation = Cancellation ( Default :: default ( ) ) ;
83
113
Self {
84
114
costs,
85
115
dim,
86
116
free_rows,
87
117
v,
88
118
in_col,
89
119
in_row,
120
+ cancellation
90
121
}
91
122
}
92
123
124
+ /// Returns a `Cancellation` token which can be cancelled from another thread.
125
+ pub fn cancellation ( & self ) -> Cancellation {
126
+ self . cancellation . clone ( )
127
+ }
128
+
129
+ fn check_cancelled ( & self ) -> Result < ( ) , LapJVError > {
130
+ if self . cancellation . is_cancelled ( ) {
131
+ return Err ( LapJVError { kind : ErrorKind :: Cancelled } ) ;
132
+ }
133
+ Ok ( ( ) )
134
+ }
135
+
93
136
pub fn solve ( mut self ) -> Result < ( Vec < usize > , Vec < usize > ) , LapJVError > {
94
137
if self . costs . dim ( ) . 0 != self . costs . dim ( ) . 1 {
95
- return Err ( LapJVError ( "Input error: matrix is not square" ) ) ;
138
+ return Err ( LapJVError { kind : ErrorKind :: Msg ( "Input error: matrix is not square" ) } ) ;
96
139
}
97
140
self . ccrrt_dense ( ) ;
98
141
99
142
let mut i = 0 ;
100
143
while !self . free_rows . is_empty ( ) && i < 2 {
144
+ self . check_cancelled ( ) ?;
101
145
self . carr_dense ( ) ;
102
146
i += 1 ;
103
147
}
@@ -228,6 +272,8 @@ where
228
272
for freerow in free_rows {
229
273
trace ! ( "looking at freerow={}" , freerow) ;
230
274
275
+ self . check_cancelled ( ) ?;
276
+
231
277
let mut i = std:: usize:: MAX ;
232
278
let mut k = 0 ;
233
279
let mut j = self . find_path_dense ( freerow, & mut pred) ;
@@ -238,7 +284,7 @@ where
238
284
std:: mem:: swap ( & mut j, & mut self . in_row [ i] ) ;
239
285
k += 1 ;
240
286
if k > dim {
241
- return Err ( LapJVError ( "Error: ca_dense will not finish" ) ) ;
287
+ return Err ( LapJVError { kind : ErrorKind :: Msg ( "Error: ca_dense will not finish" ) } ) ;
242
288
}
243
289
}
244
290
}
@@ -415,6 +461,17 @@ mod tests {
415
461
assert_eq ! ( result. 1 , vec![ 1 , 2 , 0 ] ) ;
416
462
}
417
463
464
+ #[ test]
465
+ fn cancellation ( ) {
466
+ let m = Matrix :: from_shape_vec ( ( 3 , 3 ) , vec ! [ 1.0 , 2.0 , 3.0 , 4.0 , 5.0 , 6.0 , 7.0 , 8.0 , 9.0 ] )
467
+ . unwrap ( ) ;
468
+ let lapjv = LapJV :: new ( & m) ;
469
+ let cancellation = lapjv. cancellation ( ) ;
470
+ cancellation. cancel ( ) ;
471
+ let result = lapjv. solve ( ) ;
472
+ assert ! ( matches!( result, Err ( LapJVError { kind: ErrorKind :: Cancelled } ) ) ) ;
473
+ }
474
+
418
475
#[ test]
419
476
fn test_solve_random10 ( ) {
420
477
let ( m, result) = solve_random10 ( ) ;
0 commit comments