@@ -6,6 +6,7 @@ use rayon::iter::IterBridge;
66use rayon:: prelude:: * ;
77use rayon_cond:: CondIterator ;
88use std:: sync:: atomic:: AtomicBool ;
9+ use std:: sync:: atomic:: AtomicU8 ;
910use std:: sync:: atomic:: Ordering ;
1011
1112// Re-export rayon current_num_threads
@@ -14,35 +15,54 @@ pub use rayon::current_num_threads;
/// Name of the environment variable read to decide whether parallelism is enabled.
pub const ENV_VARIABLE: &str = "TOKENIZERS_PARALLELISM";

/// Flag read by `has_parallelism_been_used`; starts out as `false`.
static USED_PARALLELISM: AtomicBool = AtomicBool::new(false);
/// In-process parallelism override: `0` = not set, `1` = disabled, `2` = enabled.
static PARALLELISM: AtomicU8 = AtomicU8::new(0);
1719
1820/// Check if the TOKENIZERS_PARALLELISM env variable has been explicitly set
1921pub fn is_parallelism_configured ( ) -> bool {
20- std:: env:: var ( ENV_VARIABLE ) . is_ok ( )
22+ std:: env:: var ( ENV_VARIABLE ) . is_ok ( ) || get_override_parallelism ( ) . is_some ( )
2123}
2224
2325/// Check if at some point we used a parallel iterator
2426pub fn has_parallelism_been_used ( ) -> bool {
2527 USED_PARALLELISM . load ( Ordering :: SeqCst )
2628}
2729
30+ /// Get internally set parallelism
31+ fn get_override_parallelism ( ) -> Option < bool > {
32+ match PARALLELISM . load ( Ordering :: SeqCst ) {
33+ 0 => None ,
34+ 1 => Some ( false ) ,
35+ 2 => Some ( true ) ,
36+ _ => unreachable ! ( ) ,
37+ }
38+ }
39+
2840/// Get the currently set value for `TOKENIZERS_PARALLELISM` env variable
29- pub fn get_parallelism ( ) -> bool {
41+ fn get_env_parallelism ( ) -> bool {
3042 match std:: env:: var ( ENV_VARIABLE ) {
3143 Ok ( mut v) => {
3244 v. make_ascii_lowercase ( ) ;
3345 !matches ! ( v. as_ref( ) , "" | "off" | "false" | "f" | "no" | "n" | "0" )
3446 }
35- #[ cfg( not( miri) ) ]
3647 Err ( _) => true , // If we couldn't get the variable, we use the default
37- // FIXME: for now turn parallelism off under miri, otherwise complains about crossbeam-epoch
38- #[ cfg( miri) ]
39- Err ( _) => false ,
48+ }
49+ }
50+
51+ pub fn get_parallelism ( ) -> bool {
52+ // FIXME: for now turn parallelism off under miri, otherwise complains about crossbeam-epoch
53+ #[ cfg( miri) ]
54+ return false ;
55+ #[ cfg( not( miri) ) ]
56+ if let Some ( parallel) = get_override_parallelism ( ) {
57+ parallel
58+ } else {
59+ get_env_parallelism ( )
4060 }
4161}
4262
4363/// Set the value for `TOKENIZERS_PARALLELISM` for the current process
4464pub fn set_parallelism ( val : bool ) {
45- std :: env :: set_var ( ENV_VARIABLE , if val { "true" } else { "false" } )
65+ PARALLELISM . store ( if val { 2 } else { 1 } , Ordering :: SeqCst ) ;
4666}
4767
4868/// Allows to convert into an iterator that can be executed either parallelly or serially.
0 commit comments