Skip to content

Commit

Permalink
Unsound call of set_var (#1664)
Browse files Browse the repository at this point in the history
* refactor: lift cloning to caller

* refactor: do not elide lifetimes as in Rust 2018

* fix: unsound use of env::set_var, was breaking stdlib change to make unsafe

It is generally not safe to set env variables. The correct way to set a config
value that needs to be overridden is to hold a copy internal to the library and
only read from the environment.
  • Loading branch information
sftse authored Oct 25, 2024
1 parent a8738a9 commit 6ea7588
Show file tree
Hide file tree
Showing 6 changed files with 33 additions and 13 deletions.
2 changes: 1 addition & 1 deletion bindings/python/src/normalizers.rs
Original file line number Diff line number Diff line change
Expand Up @@ -553,7 +553,7 @@ impl tk::tokenizer::Normalizer for CustomNormalizer {
Python::with_gil(|py| {
let normalized = PyNormalizedStringRefMut::new(normalized);
let py_normalized = self.inner.bind(py);
py_normalized.call_method("normalize", (normalized.get(),), None)?;
py_normalized.call_method("normalize", (normalized.get().clone(),), None)?;
Ok(())
})
}
Expand Down
2 changes: 1 addition & 1 deletion bindings/python/src/pre_tokenizers.rs
Original file line number Diff line number Diff line change
Expand Up @@ -634,7 +634,7 @@ impl tk::tokenizer::PreTokenizer for CustomPreTokenizer {
Python::with_gil(|py| {
let pretok = PyPreTokenizedStringRefMut::new(sentence);
let py_pretok = self.inner.bind(py);
py_pretok.call_method("pre_tokenize", (pretok.get(),), None)?;
py_pretok.call_method("pre_tokenize", (pretok.get().clone(),), None)?;
Ok(())
})
}
Expand Down
10 changes: 5 additions & 5 deletions bindings/python/src/utils/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -18,24 +18,24 @@ pub trait DestroyPtr {
fn destroy(&mut self);
}

pub struct RefMutGuard<'r, T: DestroyPtr + Clone> {
pub struct RefMutGuard<'r, T: DestroyPtr> {
content: T,
r: PhantomData<&'r mut T>,
}
impl<T: DestroyPtr + Clone> RefMutGuard<'_, T> {
impl<T: DestroyPtr> RefMutGuard<'_, T> {
pub fn new(content: T) -> Self {
Self {
content,
r: PhantomData,
}
}

pub fn get(&self) -> T {
self.content.clone()
pub fn get(&self) -> &T {
&self.content
}
}

impl<T: DestroyPtr + Clone> Drop for RefMutGuard<'_, T> {
impl<T: DestroyPtr> Drop for RefMutGuard<'_, T> {
fn drop(&mut self) {
self.content.destroy()
}
Expand Down
2 changes: 1 addition & 1 deletion bindings/python/src/utils/normalization.rs
Original file line number Diff line number Diff line change
Expand Up @@ -396,7 +396,7 @@ impl DestroyPtr for PyNormalizedStringRefMut {
}

impl PyNormalizedStringRefMut {
pub fn new(normalized: &mut NormalizedString) -> RefMutGuard<Self> {
pub fn new(normalized: &mut NormalizedString) -> RefMutGuard<'_, Self> {
RefMutGuard::new(Self {
inner: RefMutContainer::new(normalized),
})
Expand Down
4 changes: 2 additions & 2 deletions bindings/python/src/utils/pretokenization.rs
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@ fn normalize(pretok: &mut PreTokenizedString, func: &Bound<'_, PyAny>) -> PyResu
} else {
ToPyResult(pretok.normalize(|normalized| {
let norm = PyNormalizedStringRefMut::new(normalized);
func.call((norm.get(),), None)?;
func.call((norm.get().clone(),), None)?;
Ok(())
}))
.into()
Expand Down Expand Up @@ -272,7 +272,7 @@ impl DestroyPtr for PyPreTokenizedStringRefMut {
}

impl PyPreTokenizedStringRefMut {
pub fn new(pretok: &mut tk::PreTokenizedString) -> RefMutGuard<Self> {
pub fn new(pretok: &mut tk::PreTokenizedString) -> RefMutGuard<'_, Self> {
// SAFETY: This is safe because we return a RefMutGuard here.
// The compiler will make sure the &mut stays valid as necessary.
RefMutGuard::new(Self {
Expand Down
26 changes: 23 additions & 3 deletions tokenizers/src/utils/parallelism.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ use rayon::iter::IterBridge;
use rayon::prelude::*;
use rayon_cond::CondIterator;
use std::sync::atomic::AtomicBool;
use std::sync::atomic::AtomicU8;
use std::sync::atomic::Ordering;

// Re-export rayon current_num_threads
Expand All @@ -14,19 +15,30 @@ pub use rayon::current_num_threads;
pub const ENV_VARIABLE: &str = "TOKENIZERS_PARALLELISM";

static USED_PARALLELISM: AtomicBool = AtomicBool::new(false);
static PARALLELISM: AtomicU8 = AtomicU8::new(0);

/// Check if the TOKENIZERS_PARALLELISM env variable has been explicitly set
pub fn is_parallelism_configured() -> bool {
std::env::var(ENV_VARIABLE).is_ok()
std::env::var(ENV_VARIABLE).is_ok() || get_override_parallelism().is_some()
}

/// Check if at some point we used a parallel iterator
pub fn has_parallelism_been_used() -> bool {
USED_PARALLELISM.load(Ordering::SeqCst)
}

/// Get internally set parallelism
fn get_override_parallelism() -> Option<bool> {
match PARALLELISM.load(Ordering::SeqCst) {
0 => None,
1 => Some(false),
2 => Some(true),
_ => unreachable!(),
}
}

/// Get the currently set value for `TOKENIZERS_PARALLELISM` env variable
pub fn get_parallelism() -> bool {
fn get_env_parallelism() -> bool {
match std::env::var(ENV_VARIABLE) {
Ok(mut v) => {
v.make_ascii_lowercase();
Expand All @@ -36,9 +48,17 @@ pub fn get_parallelism() -> bool {
}
}

pub fn get_parallelism() -> bool {
if let Some(parallel) = get_override_parallelism() {
parallel
} else {
get_env_parallelism()
}
}

/// Set the value for `TOKENIZERS_PARALLELISM` for the current process
pub fn set_parallelism(val: bool) {
std::env::set_var(ENV_VARIABLE, if val { "true" } else { "false" })
PARALLELISM.store(if val { 2 } else { 1 }, Ordering::SeqCst);
}

/// Allows to convert into an iterator that can be executed either parallelly or serially.
Expand Down

0 comments on commit 6ea7588

Please sign in to comment.