@@ -6,6 +6,7 @@ use rayon::iter::IterBridge;
6
6
use rayon:: prelude:: * ;
7
7
use rayon_cond:: CondIterator ;
8
8
use std:: sync:: atomic:: AtomicBool ;
9
+ use std:: sync:: atomic:: AtomicU8 ;
9
10
use std:: sync:: atomic:: Ordering ;
10
11
11
12
// Re-export rayon current_num_threads
@@ -14,35 +15,54 @@ pub use rayon::current_num_threads;
14
15
pub const ENV_VARIABLE : & str = "TOKENIZERS_PARALLELISM" ;
15
16
16
17
static USED_PARALLELISM : AtomicBool = AtomicBool :: new ( false ) ;
18
+ static PARALLELISM : AtomicU8 = AtomicU8 :: new ( 0 ) ;
17
19
18
20
/// Check if the TOKENIZERS_PARALLELISM env variable has been explicitly set
19
21
pub fn is_parallelism_configured ( ) -> bool {
20
- std:: env:: var ( ENV_VARIABLE ) . is_ok ( )
22
+ std:: env:: var ( ENV_VARIABLE ) . is_ok ( ) || get_override_parallelism ( ) . is_some ( )
21
23
}
22
24
23
25
/// Check if at some point we used a parallel iterator
24
26
pub fn has_parallelism_been_used ( ) -> bool {
25
27
USED_PARALLELISM . load ( Ordering :: SeqCst )
26
28
}
27
29
30
+ /// Get internally set parallelism
31
+ fn get_override_parallelism ( ) -> Option < bool > {
32
+ match PARALLELISM . load ( Ordering :: SeqCst ) {
33
+ 0 => None ,
34
+ 1 => Some ( false ) ,
35
+ 2 => Some ( true ) ,
36
+ _ => unreachable ! ( ) ,
37
+ }
38
+ }
39
+
28
40
/// Get the currently set value for `TOKENIZERS_PARALLELISM` env variable
29
- pub fn get_parallelism ( ) -> bool {
41
+ fn get_env_parallelism ( ) -> bool {
30
42
match std:: env:: var ( ENV_VARIABLE ) {
31
43
Ok ( mut v) => {
32
44
v. make_ascii_lowercase ( ) ;
33
45
!matches ! ( v. as_ref( ) , "" | "off" | "false" | "f" | "no" | "n" | "0" )
34
46
}
35
- #[ cfg( not( miri) ) ]
36
47
Err ( _) => true , // If we couldn't get the variable, we use the default
37
- // FIXME: for now turn parallelism off under miri, otherwise complains about crossbeam-epoch
38
- #[ cfg( miri) ]
39
- Err ( _) => false ,
48
+ }
49
+ }
50
+
51
+ pub fn get_parallelism ( ) -> bool {
52
+ // FIXME: for now turn parallelism off under miri, otherwise complains about crossbeam-epoch
53
+ #[ cfg( miri) ]
54
+ return false ;
55
+ #[ cfg( not( miri) ) ]
56
+ if let Some ( parallel) = get_override_parallelism ( ) {
57
+ parallel
58
+ } else {
59
+ get_env_parallelism ( )
40
60
}
41
61
}
42
62
43
63
/// Set the value for `TOKENIZERS_PARALLELISM` for the current process
44
64
pub fn set_parallelism ( val : bool ) {
45
- std :: env :: set_var ( ENV_VARIABLE , if val { "true" } else { "false" } )
65
+ PARALLELISM . store ( if val { 2 } else { 1 } , Ordering :: SeqCst ) ;
46
66
}
47
67
48
68
/// Allows to convert into an iterator that can be executed either parallelly or serially.
0 commit comments