diff --git a/bindings/python/src/normalizers.rs b/bindings/python/src/normalizers.rs index e5b5cb324..feff9811c 100644 --- a/bindings/python/src/normalizers.rs +++ b/bindings/python/src/normalizers.rs @@ -553,7 +553,7 @@ impl tk::tokenizer::Normalizer for CustomNormalizer { Python::with_gil(|py| { let normalized = PyNormalizedStringRefMut::new(normalized); let py_normalized = self.inner.bind(py); - py_normalized.call_method("normalize", (normalized.get(),), None)?; + py_normalized.call_method("normalize", (normalized.get().clone(),), None)?; Ok(()) }) } diff --git a/bindings/python/src/pre_tokenizers.rs b/bindings/python/src/pre_tokenizers.rs index bac3284ad..1c43f7eb8 100644 --- a/bindings/python/src/pre_tokenizers.rs +++ b/bindings/python/src/pre_tokenizers.rs @@ -634,7 +634,7 @@ impl tk::tokenizer::PreTokenizer for CustomPreTokenizer { Python::with_gil(|py| { let pretok = PyPreTokenizedStringRefMut::new(sentence); let py_pretok = self.inner.bind(py); - py_pretok.call_method("pre_tokenize", (pretok.get(),), None)?; + py_pretok.call_method("pre_tokenize", (pretok.get().clone(),), None)?; Ok(()) }) } diff --git a/bindings/python/src/utils/mod.rs b/bindings/python/src/utils/mod.rs index 43352a7fa..21b3fc1e1 100644 --- a/bindings/python/src/utils/mod.rs +++ b/bindings/python/src/utils/mod.rs @@ -18,11 +18,11 @@ pub trait DestroyPtr { fn destroy(&mut self); } -pub struct RefMutGuard<'r, T: DestroyPtr + Clone> { +pub struct RefMutGuard<'r, T: DestroyPtr> { content: T, r: PhantomData<&'r mut T>, } -impl<T: DestroyPtr + Clone> RefMutGuard<'_, T> { +impl<T: DestroyPtr> RefMutGuard<'_, T> { pub fn new(content: T) -> Self { Self { content, @@ -30,12 +30,12 @@ impl<T: DestroyPtr + Clone> RefMutGuard<'_, T> { } } - pub fn get(&self) -> T { - self.content.clone() + pub fn get(&self) -> &T { + &self.content } } -impl<T: DestroyPtr + Clone> Drop for RefMutGuard<'_, T> { +impl<T: DestroyPtr> Drop for RefMutGuard<'_, T> { fn drop(&mut self) { self.content.destroy() } diff --git a/bindings/python/src/utils/normalization.rs b/bindings/python/src/utils/normalization.rs index 
4cb3c7ce8..b67dcff9c 100644 --- a/bindings/python/src/utils/normalization.rs +++ b/bindings/python/src/utils/normalization.rs @@ -396,7 +396,7 @@ impl DestroyPtr for PyNormalizedStringRefMut { } impl PyNormalizedStringRefMut { - pub fn new(normalized: &mut NormalizedString) -> RefMutGuard<Self> { + pub fn new(normalized: &mut NormalizedString) -> RefMutGuard<'_, Self> { RefMutGuard::new(Self { inner: RefMutContainer::new(normalized), }) diff --git a/bindings/python/src/utils/pretokenization.rs b/bindings/python/src/utils/pretokenization.rs index 88fdd19f5..70444aace 100644 --- a/bindings/python/src/utils/pretokenization.rs +++ b/bindings/python/src/utils/pretokenization.rs @@ -39,7 +39,7 @@ fn normalize(pretok: &mut PreTokenizedString, func: &Bound<'_, PyAny>) -> PyResu } else { ToPyResult(pretok.normalize(|normalized| { let norm = PyNormalizedStringRefMut::new(normalized); - func.call((norm.get(),), None)?; + func.call((norm.get().clone(),), None)?; Ok(()) })) .into() @@ -272,7 +272,7 @@ impl DestroyPtr for PyPreTokenizedStringRefMut { } impl PyPreTokenizedStringRefMut { - pub fn new(pretok: &mut tk::PreTokenizedString) -> RefMutGuard<Self> { + pub fn new(pretok: &mut tk::PreTokenizedString) -> RefMutGuard<'_, Self> { // SAFETY: This is safe because we return a RefMutGuard here. // The compiler will make sure the &mut stays valid as necessary. 
RefMutGuard::new(Self { diff --git a/tokenizers/src/utils/parallelism.rs b/tokenizers/src/utils/parallelism.rs index b955731d1..ea2fd331a 100644 --- a/tokenizers/src/utils/parallelism.rs +++ b/tokenizers/src/utils/parallelism.rs @@ -6,6 +6,7 @@ use rayon::iter::IterBridge; use rayon::prelude::*; use rayon_cond::CondIterator; use std::sync::atomic::AtomicBool; +use std::sync::atomic::AtomicU8; use std::sync::atomic::Ordering; // Re-export rayon current_num_threads @@ -14,10 +15,11 @@ pub use rayon::current_num_threads; pub const ENV_VARIABLE: &str = "TOKENIZERS_PARALLELISM"; static USED_PARALLELISM: AtomicBool = AtomicBool::new(false); +static PARALLELISM: AtomicU8 = AtomicU8::new(0); /// Check if the TOKENIZERS_PARALLELISM env variable has been explicitly set pub fn is_parallelism_configured() -> bool { - std::env::var(ENV_VARIABLE).is_ok() + std::env::var(ENV_VARIABLE).is_ok() || get_override_parallelism().is_some() } /// Check if at some point we used a parallel iterator @@ -25,8 +27,18 @@ pub fn has_parallelism_been_used() -> bool { USED_PARALLELISM.load(Ordering::SeqCst) } +/// Get internally set parallelism +fn get_override_parallelism() -> Option<bool> { + match PARALLELISM.load(Ordering::SeqCst) { + 0 => None, + 1 => Some(false), + 2 => Some(true), + _ => unreachable!(), + } +} + /// Get the currently set value for `TOKENIZERS_PARALLELISM` env variable -pub fn get_parallelism() -> bool { +fn get_env_parallelism() -> bool { match std::env::var(ENV_VARIABLE) { Ok(mut v) => { v.make_ascii_lowercase(); @@ -36,9 +48,17 @@ pub fn get_parallelism() -> bool { } } +pub fn get_parallelism() -> bool { + if let Some(parallel) = get_override_parallelism() { + parallel + } else { + get_env_parallelism() + } +} + /// Set the value for `TOKENIZERS_PARALLELISM` for the current process pub fn set_parallelism(val: bool) { - std::env::set_var(ENV_VARIABLE, if val { "true" } else { "false" }) + PARALLELISM.store(if val { 2 } else { 1 }, Ordering::SeqCst); } /// Allows to convert into 
an iterator that can be executed either parallelly or serially.