Add safety comments (#1651)
* Unsafe comment for from_u32_unchecked

* Add safety comments and type assertion for HashSet parallel iteration

* Add safety comment for String splice

* fixes

* fmt

* pos
Manishearth authored Oct 29, 2024
1 parent 6ea7588 commit 5512a42
Showing 3 changed files with 32 additions and 6 deletions.
27 changes: 21 additions & 6 deletions tokenizers/src/models/bpe/trainer.rs
@@ -454,7 +454,7 @@ impl BpeTrainer {
         // 3. Tokenize words
         //
         self.update_progress(&progress, word_counts.len(), "Tokenize words");
-        let (words, counts) =
+        let (mut words, counts) =
             self.tokenize_words(word_counts, &mut word_to_id, &mut id_to_word, &progress);
         self.finalize_progress(&progress, words.len());
@@ -530,14 +530,29 @@ impl BpeTrainer {
             merges.push((top.pair, new_token_id));
 
             // Merge the new pair in every word
-            let changes = top
-                .pos
+            // Safety: This is just a type assertion; the code below may no longer be safe
+            // if the type of `pos` changes
+            let pos: &HashSet<usize> = &top.pos;
+
+            let words_len = words.len();
+            struct WordPtr(*mut Word);
+            // Safety: We do not actually use this for concurrent access to the same memory,
+            // only to different chunks within the same allocation.
+            unsafe impl Sync for WordPtr {}
+            let word_start = WordPtr(words.as_mut_ptr());
+
+            let changes = pos
                 .maybe_par_iter()
                 .flat_map(|&i| {
-                    let word = &words[i] as *const _ as *mut Word;
-                    // We can merge each of these words in parallel here because each position
-                    // can be there only once (HashSet). So this is safe.
+                    // Safety:
+                    // We are producing a valid pointer since we are indexing in bounds.
+                    //
+                    // We can access each `word` here in parallel because each position
+                    // can occur only once (pos is a HashSet).
                     unsafe {
+                        assert!(i < words_len);
+                        // This is words[i], but avoids going through &T (which would trigger UB)
+                        let word = word_start.0.add(i);
+                        // let word: &mut Word = &mut (*word);
                         (*word)
                             .merge(top.pair.0, top.pair.1, new_token_id, max_token_length)
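
The pattern above is worth seeing in isolation. Below is a minimal standalone sketch (hypothetical names; plain rayon stands in for `maybe_par_iter`): a `Sync` wrapper around a raw pointer lets parallel tasks write to disjoint elements of one slice, with disjointness guaranteed because the indices come from a `HashSet`.

    use rayon::prelude::*;
    use std::collections::HashSet;

    struct Ptr(*mut u64);
    // Safety: each parallel task below touches a distinct element, because a
    // HashSet cannot contain the same index twice.
    unsafe impl Sync for Ptr {}
    impl Ptr {
        fn at(&self, i: usize) -> *mut u64 {
            // Derive every element pointer from the same base pointer rather
            // than casting an `&values[i]` shared reference to *mut, which
            // would be UB to write through.
            unsafe { self.0.add(i) }
        }
    }

    fn bump_each(values: &mut [u64], positions: &HashSet<usize>) {
        let len = values.len();
        let base = Ptr(values.as_mut_ptr());
        positions.par_iter().for_each(|&i| {
            assert!(i < len);
            unsafe { *base.at(i) += 1 };
        });
    }

Keeping the pointer access behind `&self` means the closure only captures a shared reference to the wrapper, so a `Sync` impl alone is enough to make the closure `Send + Sync` for rayon.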
3 changes: 3 additions & 0 deletions tokenizers/src/pre_tokenizers/byte_level.rs
@@ -28,6 +28,9 @@ pub(crate) fn bytes_char() -> HashMap<u8, char> {
         }
     }
 
+    // Safety: cs contains all values from bs (between 0 and 255),
+    // and some values of the form 2⁸ + n, where n is between 0 and 255; these fall between 256 and 511.
+    // Both ranges are valid UTF-32 values (the space of valid code points is fully saturated until 0xD800).
     bs.into_iter()
         .zip(cs)
         .map(|(f, t)| (f, unsafe { std::char::from_u32_unchecked(t) }))
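
As a sanity check, here is a standalone sketch (hypothetical test, not part of the commit) that rebuilds the same byte-to-code-point mapping as `bytes_char` and confirms the range claim using the checked `char::from_u32`:

    fn check_bytes_char_range() {
        let mut bs: Vec<u32> = ('!' as u32..='~' as u32)
            .chain('¡' as u32..='¬' as u32)
            .chain('®' as u32..='ÿ' as u32)
            .collect();
        let mut cs = bs.clone();
        let mut n = 0;
        for b in 0..=255u32 {
            if !bs.contains(&b) {
                bs.push(b);
                cs.push(256 + n); // 2⁸ + n, so 256..=511 overall
                n += 1;
            }
        }
        for &t in &cs {
            // Everything stays far below the surrogate range at 0xD800, so
            // the unchecked conversion can never produce an invalid char.
            assert!(t < 0xD800 && char::from_u32(t).is_some());
        }
    }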
8 changes: 8 additions & 0 deletions tokenizers/src/tokenizer/normalizer.rs
@@ -411,8 +411,16 @@ impl NormalizedString {
             .collect::<String>();
 
         self.alignments.splice(n_range.clone(), alignments);
+
+        // This bounds check already happens above (`self.normalized[n_range.clone()]`), but future
+        // code could change to mutate `self` or `self.normalized` in the interim.
+        // Perform it again and hope the optimizer collapses it.
+        assert!(self.normalized.get(n_range.clone()).is_some());
         unsafe {
             self.normalized
+                // Safety: This is safe as long as we do not splice in the middle of a
+                // UTF-8 character, and we only add valid UTF-8 text. `normalized` is a String,
+                // so the latter is trivially true, and we assert the former above.
                 .as_mut_vec()
                 .splice(n_range, normalized.bytes());
         }
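
The same assert-then-splice pattern works on any plain String, because `str::get` returns `None` both when the range is out of bounds and when it does not fall on UTF-8 character boundaries. A standalone sketch (hypothetical helper, not from the commit):

    use std::ops::Range;

    fn splice_str(s: &mut String, range: Range<usize>, replacement: &str) {
        // One assert covers both preconditions of the unsafe byte-level
        // splice: the range is in bounds and lies on char boundaries.
        assert!(s.get(range.clone()).is_some());
        unsafe {
            // Sound: the removed bytes span whole characters (asserted above)
            // and the inserted bytes come from a &str, so they are valid UTF-8.
            s.as_mut_vec().splice(range, replacement.bytes());
        }
    }

    fn main() {
        let mut s = String::from("héllo");
        splice_str(&mut s, 1..3, "e"); // replaces the two-byte 'é'
        assert_eq!(s, "hello");
    }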
