
Add safety comments #1651

Merged Oct 29, 2024 (6 commits; changes shown from 4 commits)
28 changes: 22 additions & 6 deletions tokenizers/src/models/bpe/trainer.rs
@@ -454,7 +454,7 @@ impl BpeTrainer {
// 3. Tokenize words
//
self.update_progress(&progress, word_counts.len(), "Tokenize words");
let (words, counts) =
let (mut words, counts) =
self.tokenize_words(word_counts, &mut word_to_id, &mut id_to_word, &progress);
self.finalize_progress(&progress, words.len());

@@ -530,14 +530,30 @@ impl BpeTrainer {
merges.push((top.pair, new_token_id));

// Merge the new pair in every word
let changes = top
.pos
// Safety: This is just a type assertion, the code below may no longer be safe
// if the type of `pos` changes
let ref pos: HashSet<usize> = top.pos;


let words_len = words.len();
struct WordPtr(*mut Word);
// Safety: We do not actually use this for concurrent access to the same memory,
// only to different chunks within the same allocation.
unsafe impl Sync for WordPtr {}
let word_start = WordPtr(words.as_mut_ptr());

let changes = pos
.maybe_par_iter()
.flat_map(|&i| {
let word = &words[i] as *const _ as *mut Word;
// We can merge each of these words in parallel here because each position
// can be there only once (HashSet). So this is safe.
// Safety:
// We are producing a valid pointer since we are indexing in bounds
//
// We can access each `word` here in parallel because each position
// can be there only once (pos is a HashSet).
unsafe {
assert!(i < words_len);
// This is words[i], but avoids needing to go through &T (which triggers UB)
Collaborator:

FMI, why is &T UB?

Contributor (author):

It's UB to mutate data that currently has an active &T. This persists through raw pointers in some ways (and not in others).
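To illustrate the point in the reply: a minimal sketch (not the PR's code) of the sound pattern, where the write pointer is derived from `as_mut_ptr()` (which comes from a `&mut` and so retains write permission), rather than casting a shared `&words[i]` to `*mut` and writing through it, which is UB under Rust's aliasing rules.

```rust
// Hypothetical sketch: mutate elements through a pointer derived from
// `as_mut_ptr()`, never through a pointer cast from a shared reference.
fn bump_each(v: &mut Vec<u64>) {
    let base = v.as_mut_ptr(); // derived from a &mut, so writes are permitted
    for i in 0..v.len() {
        // Safety: `i` is in bounds and no reference to the elements is live
        // while we write.
        unsafe { *base.add(i) += 1 };
    }
}

fn main() {
    let mut v = vec![1, 2, 3];
    bump_each(&mut v);
    assert_eq!(v, vec![2, 3, 4]);
}
```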

let word = word_start.0.add(i);
// let word: &mut Word = &mut (*word);
(*word)
.merge(top.pair.0, top.pair.1, new_token_id, max_token_length)
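The `WordPtr`/`unsafe impl Sync` idiom above can be sketched in isolation. The PR uses `maybe_par_iter` (rayon); this hypothetical miniature uses `std::thread::scope` instead, but the safety argument is the same: the raw base pointer is shared across threads via a `Sync` newtype, and soundness rests entirely on each worker writing a distinct index (the role the `HashSet` of positions plays in the real code).

```rust
use std::thread;

// Raw pointers are not `Sync`; a newtype with a manual `unsafe impl Sync`
// lets threads share the base pointer. This is sound only because every
// thread below writes a disjoint element of the same allocation.
struct Ptr(*mut u64);
unsafe impl Sync for Ptr {}

fn parallel_double(v: &mut Vec<u64>) {
    let len = v.len();
    let base = Ptr(v.as_mut_ptr());
    thread::scope(|s| {
        for i in 0..len {
            let base = &base;
            s.spawn(move || {
                // Safety: `i < len`, and each thread touches a distinct index,
                // so no two writes alias.
                unsafe { *base.0.add(i) *= 2 };
            });
        }
    });
}

fn main() {
    let mut v = vec![1u64, 2, 3];
    parallel_double(&mut v);
    assert_eq!(v, vec![2, 4, 6]);
}
```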
3 changes: 3 additions & 0 deletions tokenizers/src/pre_tokenizers/byte_level.rs
@@ -28,6 +28,9 @@ pub(crate) fn bytes_char() -> HashMap<u8, char> {
}
}

// Safety: cs contains all values from bs (between 0 and 255),
// and some values of the form 2⁸ + n, where n is between 0 and 255; those fall between 256 and 511.
// Both ranges are valid UTF-32 scalar values (the space is fully saturated up to 0xD800, the first surrogate)
bs.into_iter()
.zip(cs)
.map(|(f, t)| (f, unsafe { std::char::from_u32_unchecked(t) }))
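The comment's claim can be checked mechanically rather than taken on faith: every code point the mapping can produce (0–255 for bytes mapped to themselves, 256–511 for the 2⁸ + n shifts) sits well below 0xD800, so the checked `char::from_u32` succeeds on the whole range and the unchecked variant is sound. A small verification sketch:

```rust
// Verify the safety comment's range argument for `from_u32_unchecked`:
// all possible targets are valid char values, below the surrogate block.
fn mapping_targets_are_valid() -> bool {
    // 0..=255: bytes mapped to themselves; 256..=511: values 2^8 + n.
    (0u32..=511).all(|t| char::from_u32(t).is_some()) && 511 < 0xD800
}

fn main() {
    assert!(mapping_targets_are_valid());
}
```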
8 changes: 8 additions & 0 deletions tokenizers/src/tokenizer/normalizer.rs
@@ -411,8 +411,16 @@ impl NormalizedString {
.collect::<String>();

self.alignments.splice(n_range.clone(), alignments);

// This bounds check already happens above (`self.normalized[n_range.clone()]`), but future
// code could change to mutate `self` or `self.normalized` in the interim.
// Perform it again and hope the optimizer collapses it.
assert!(self.normalized.get(n_range.clone()).is_some());
unsafe {
self.normalized
// Safety: This is safe as long as we do not splice across a
// UTF-8 character, and we only add UTF-8 text. `normalized` is a String
// so the latter is trivially true, and we assert for the former above.
.as_mut_vec()
.splice(n_range, normalized.bytes());
}
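The assert-then-splice pattern above generalizes: re-assert the invariant (the range is in bounds and lies on UTF-8 character boundaries) immediately before the unsafe byte-level splice, so a future refactor that breaks it panics instead of silently producing an invalid `String`. A hypothetical standalone helper (not the tokenizers API) showing the same shape:

```rust
use std::ops::Range;

// Hypothetical sketch of the PR's pattern: a guarded byte-level splice.
fn splice_str(s: &mut String, range: Range<usize>, replacement: &str) {
    // `get` returns None if the range is out of bounds or would split a
    // UTF-8 character, mirroring the assert in the PR.
    assert!(s.get(range.clone()).is_some());
    unsafe {
        // Safety: the range lies on char boundaries (asserted above) and the
        // inserted bytes come from a &str, so the result stays valid UTF-8.
        s.as_mut_vec().splice(range, replacement.bytes());
    }
}

fn main() {
    let mut s = String::from("hello");
    splice_str(&mut s, 1..3, "EY"); // replaces "el"
    assert_eq!(s, "hEYlo");
}
```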