Skip to content

Commit

Permalink
fix: encapsulate the pattern of obtaining &mut to disjoint elements o…
Browse files Browse the repository at this point in the history
…f Vec as UnsafeVec

It is never ok, for any code under any circumstances and any reason, to change a
&T to a &mut T unless the access happens to be mediated by UnsafeCell.
  • Loading branch information
sftse committed Oct 22, 2024
1 parent 9b77c05 commit ac96cba
Show file tree
Hide file tree
Showing 2 changed files with 83 additions and 19 deletions.
98 changes: 79 additions & 19 deletions tokenizers/src/models/bpe/trainer.rs
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,63 @@ use serde::{Deserialize, Serialize};
use std::cmp::Ordering;
use std::collections::{BinaryHeap, HashMap, HashSet};

mod unsafevec {
use std::cell::UnsafeCell;
use std::marker::PhantomData;

use super::Word;

pub(super) struct UnsafeVec<'v> {
value: UnsafeCell<*mut Word>,
len: usize,
_data: PhantomData<&'v mut Vec<Word>>,
}

// SAFETY: the only use of a &UnsafeVec is to pass it to
// fn get_unchecked_mut(), that function and its safety contract
// mandate that only disjoint indexes may be used
// if that unsafe contract is upheld there's no way to abuse a &UnsafeVec
// to obtain unsynchronized data access
unsafe impl<'v> Sync for UnsafeVec<'v> {}

impl<'v> UnsafeVec<'v> {
pub(super) fn new(words: &'v mut Vec<Word>) -> UnsafeVec<'v> {
let len = words.len();
UnsafeVec {
value: UnsafeCell::new(words.as_mut_ptr()),
len,
_data: PhantomData,
}
}
// SAFETY: per UnsafeVec, every call to get_unchecked_mut must be made with a unique index i
#[allow(clippy::mut_from_ref)]
pub(super) unsafe fn get_unchecked_mut(&self, i: isize) -> &mut Word {
assert!((i as usize) < self.len);
&mut *(*self.value.get()).offset(i)
}
}
// run with miri: cargo +nightly miri test test_unsafe_vec
#[test]
fn test_unsafe_vec() {
let mut v: Vec<_> = std::iter::repeat(Word::new()).take(100).collect();
let vec = UnsafeVec::new(&mut v);
// if these ranges overlap this is immediately detected by miri
let [first, second] = [0..50, 50..100];
std::thread::scope(|s| {
s.spawn(|| unsafe {
for i in first {
*vec.get_unchecked_mut(i) = Word::new();
}
});
s.spawn(|| unsafe {
for i in second {
*vec.get_unchecked_mut(i) = Word::new();
}
});
});
}
}

#[derive(Debug, Eq)]
struct Merge {
pair: Pair,
Expand Down Expand Up @@ -454,7 +511,7 @@ impl BpeTrainer {
// 3. Tokenize words
//
self.update_progress(&progress, word_counts.len(), "Tokenize words");
let (words, counts) =
let (mut words, counts) =
self.tokenize_words(word_counts, &mut word_to_id, &mut id_to_word, &progress);
self.finalize_progress(&progress, words.len());

Expand Down Expand Up @@ -529,24 +586,27 @@ impl BpeTrainer {
}
merges.push((top.pair, new_token_id));

// Merge the new pair in every words
let changes = top
.pos
.maybe_par_iter()
.flat_map(|&i| {
let word = &words[i] as *const _ as *mut Word;
// We can merge each of these words in parallel here because each position
// can be there only once (HashSet). So this is safe.
unsafe {
// let word: &mut Word = &mut (*word);
(*word)
.merge(top.pair.0, top.pair.1, new_token_id, max_token_length)
.into_iter()
.map(|c| (c, i))
.collect::<Vec<_>>()
}
})
.collect::<Vec<_>>();
let changes = {
let words = unsafevec::UnsafeVec::new(&mut words);

// Merge the new pair in every words
top.pos
.maybe_par_iter()
.flat_map(|&i| {
// We can merge each of these words in parallel here because each position
// can be there only once (HashSet). So this is safe.
// SAFETY: this really is a HashSet, making i's unique
let _: &HashSet<_> = &top.pos;
unsafe {
let word: &mut Word = words.get_unchecked_mut(i as isize);
word.merge(top.pair.0, top.pair.1, new_token_id, max_token_length)
.into_iter()
.map(|c| (c, i))
.collect::<Vec<_>>()
}
})
.collect::<Vec<_>>()
};

// Introduce new formed pairs
for ((pair, change), iw) in changes {
Expand Down
4 changes: 4 additions & 0 deletions tokenizers/src/utils/parallelism.rs
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,11 @@ pub fn get_parallelism() -> bool {
v.make_ascii_lowercase();
!matches!(v.as_ref(), "" | "off" | "false" | "f" | "no" | "n" | "0")
}
#[cfg(not(miri))]
Err(_) => true, // If we couldn't get the variable, we use the default
// FIXME: for now turn parallelism off under miri, otherwise complains about crossbeam-epoch
#[cfg(miri)]
Err(_) => false,
}
}

Expand Down

0 comments on commit ac96cba

Please sign in to comment.