diff --git a/tokenizers/src/models/bpe/trainer.rs b/tokenizers/src/models/bpe/trainer.rs index 2876f1ef5..81468ad2f 100644 --- a/tokenizers/src/models/bpe/trainer.rs +++ b/tokenizers/src/models/bpe/trainer.rs @@ -8,6 +8,63 @@ use serde::{Deserialize, Serialize}; use std::cmp::Ordering; use std::collections::{BinaryHeap, HashMap, HashSet}; +mod unsafevec { + use std::cell::UnsafeCell; + use std::marker::PhantomData; + + use super::Word; + + pub(super) struct UnsafeVec<'v> { + value: UnsafeCell<*mut Word>, + len: usize, + _data: PhantomData<&'v mut Vec>, + } + + // SAFETY: the only use of a &UnsafeVec is to pass it to + // fn get_unchecked_mut(), that function and its safety contract + // mandate that only disjoint indexes may be used + // if that unsafe contract is upheld there's no way to abuse a &UnsafeVec + // to obtain unsynchronized data access + unsafe impl<'v> Sync for UnsafeVec<'v> {} + + impl<'v> UnsafeVec<'v> { + pub(super) fn new(words: &'v mut Vec) -> UnsafeVec<'v> { + let len = words.len(); + UnsafeVec { + value: UnsafeCell::new(words.as_mut_ptr()), + len, + _data: PhantomData, + } + } + // SAFETY: per UnsafeVec, every call to get_unchecked_mut must be made with a unique index i + #[allow(clippy::mut_from_ref)] + pub(super) unsafe fn get_unchecked_mut(&self, i: isize) -> &mut Word { + assert!((i as usize) < self.len); + &mut *(*self.value.get()).offset(i) + } + } + // run with miri: cargo +nightly miri test test_unsafe_vec + #[test] + fn test_unsafe_vec() { + let mut v: Vec<_> = std::iter::repeat(Word::new()).take(100).collect(); + let vec = UnsafeVec::new(&mut v); + // if these ranges overlap this is immediately detected by miri + let [first, second] = [0..50, 50..100]; + std::thread::scope(|s| { + s.spawn(|| unsafe { + for i in first { + *vec.get_unchecked_mut(i) = Word::new(); + } + }); + s.spawn(|| unsafe { + for i in second { + *vec.get_unchecked_mut(i) = Word::new(); + } + }); + }); + } +} + #[derive(Debug, Eq)] struct Merge { pair: Pair, @@ -454,7 +511,7 @@ impl BpeTrainer { // 3. Tokenize words // self.update_progress(&progress, word_counts.len(), "Tokenize words"); - let (words, counts) = + let (mut words, counts) = self.tokenize_words(word_counts, &mut word_to_id, &mut id_to_word, &progress); self.finalize_progress(&progress, words.len()); @@ -529,24 +586,27 @@ impl BpeTrainer { } merges.push((top.pair, new_token_id)); - // Merge the new pair in every words - let changes = top - .pos - .maybe_par_iter() - .flat_map(|&i| { - let word = &words[i] as *const _ as *mut Word; - // We can merge each of these words in parallel here because each position - // can be there only once (HashSet). So this is safe. - unsafe { - // let word: &mut Word = &mut (*word); - (*word) - .merge(top.pair.0, top.pair.1, new_token_id, max_token_length) - .into_iter() - .map(|c| (c, i)) - .collect::>() - } - }) - .collect::>(); + let changes = { + let words = unsafevec::UnsafeVec::new(&mut words); + + // Merge the new pair in every words + top.pos + .maybe_par_iter() + .flat_map(|&i| { + // We can merge each of these words in parallel here because each position + // can be there only once (HashSet). So this is safe. + // SAFETY: this really is a HashSet, making i's unique + let _: &HashSet<_> = &top.pos; + unsafe { + let word: &mut Word = words.get_unchecked_mut(i as isize); + word.merge(top.pair.0, top.pair.1, new_token_id, max_token_length) + .into_iter() + .map(|c| (c, i)) + .collect::>() + } + }) + .collect::>() + }; // Introduce new formed pairs for ((pair, change), iw) in changes { diff --git a/tokenizers/src/utils/parallelism.rs b/tokenizers/src/utils/parallelism.rs index b955731d1..903f3b533 100644 --- a/tokenizers/src/utils/parallelism.rs +++ b/tokenizers/src/utils/parallelism.rs @@ -32,7 +32,11 @@ pub fn get_parallelism() -> bool { v.make_ascii_lowercase(); !matches!(v.as_ref(), "" | "off" | "false" | "f" | "no" | "n" | "0") } + #[cfg(not(miri))] Err(_) => true, // If we couldn't get the variable, we use the default + // FIXME: for now turn parallelism off under miri, otherwise complains about crossbeam-epoch + #[cfg(miri)] + Err(_) => false, } }