diff --git a/tokenizers/src/utils/parallelism.rs b/tokenizers/src/utils/parallelism.rs index b955731d1..ea2fd331a 100644 --- a/tokenizers/src/utils/parallelism.rs +++ b/tokenizers/src/utils/parallelism.rs @@ -6,6 +6,7 @@ use rayon::iter::IterBridge; use rayon::prelude::*; use rayon_cond::CondIterator; use std::sync::atomic::AtomicBool; +use std::sync::atomic::AtomicU8; use std::sync::atomic::Ordering; // Re-export rayon current_num_threads @@ -14,10 +15,11 @@ pub use rayon::current_num_threads; pub const ENV_VARIABLE: &str = "TOKENIZERS_PARALLELISM"; static USED_PARALLELISM: AtomicBool = AtomicBool::new(false); +static PARALLELISM: AtomicU8 = AtomicU8::new(0); /// Check if the TOKENIZERS_PARALLELISM env variable has been explicitly set pub fn is_parallelism_configured() -> bool { - std::env::var(ENV_VARIABLE).is_ok() + std::env::var(ENV_VARIABLE).is_ok() || get_override_parallelism().is_some() } /// Check if at some point we used a parallel iterator @@ -25,8 +27,18 @@ pub fn has_parallelism_been_used() -> bool { USED_PARALLELISM.load(Ordering::SeqCst) } +/// Get internally set parallelism +fn get_override_parallelism() -> Option { + match PARALLELISM.load(Ordering::SeqCst) { + 0 => None, + 1 => Some(false), + 2 => Some(true), + _ => unreachable!(), + } +} + /// Get the currently set value for `TOKENIZERS_PARALLELISM` env variable -pub fn get_parallelism() -> bool { +fn get_env_parallelism() -> bool { match std::env::var(ENV_VARIABLE) { Ok(mut v) => { v.make_ascii_lowercase(); @@ -36,9 +48,17 @@ pub fn get_parallelism() -> bool { } } +pub fn get_parallelism() -> bool { + if let Some(parallel) = get_override_parallelism() { + parallel + } else { + get_env_parallelism() + } +} + /// Set the value for `TOKENIZERS_PARALLELISM` for the current process pub fn set_parallelism(val: bool) { - std::env::set_var(ENV_VARIABLE, if val { "true" } else { "false" }) + PARALLELISM.store(if val { 2 } else { 1 }, Ordering::SeqCst); } /// Allows to convert into an iterator that can be executed either parallelly or serially.