Commit

fixes

Manishearth committed Oct 23, 2024
1 parent 4715a27 commit c67f5ae
Showing 1 changed file with 17 additions and 7 deletions.
tokenizers/src/models/bpe/trainer.rs (24 changes: 17 additions & 7 deletions)
@@ -454,7 +454,7 @@ impl BpeTrainer {
         // 3. Tokenize words
         //
         self.update_progress(&progress, word_counts.len(), "Tokenize words");
-        let (words, counts) =
+        let (mut words, counts) =
             self.tokenize_words(word_counts, &mut word_to_id, &mut id_to_word, &progress);
         self.finalize_progress(&progress, words.len());

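The first hunk only changes `words` to a mutable binding; the second hunk needs that in order to call `Vec::as_mut_ptr`, which takes `&mut self`. A minimal illustration of the requirement (a sketch; the names are illustrative, not from the commit):

    fn main() {
        // With an immutable binding, taking a mutable raw pointer is rejected:
        // let v = vec![1, 2, 3];
        // let p = v.as_mut_ptr(); // error[E0596]: cannot borrow `v` as mutable

        // With a `mut` binding it compiles:
        let mut v = vec![1, 2, 3];
        let p: *mut i32 = v.as_mut_ptr();
        let _ = p;
    }
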
@@ -533,17 +533,27 @@ impl BpeTrainer {
             // Safety: This is just a type assertion, the code below may no longer be safe
             // if the type of `pos` changes
             let ref pos: HashSet<usize> = top.pos;
 
+            let words_len = words.len();
+            struct WordPtr(*mut Word);
+            // Safety: We do not actually use this for concurrent access to the same memory,
+            // only to different chunks within the same allocation.
+            unsafe impl Sync for WordPtr {}
+            let word_start = WordPtr(words.as_mut_ptr());
+
             let changes = pos
                 .maybe_par_iter()
                 .flat_map(|&i| {
-                    // Safety: Accessing this Vec overall as an &T whilst parts of it
-                    // are being mutated is *probably* safe, see https://github.com/rust-lang/unsafe-code-guidelines/issues/412
-                    // and related issues. If not, we can always use raw pointers here.
-                    let word = &words[i] as *const _ as *mut Word;
                     // Safety:
-                    // We can acces each `word` here in parallel because each position
-                    // can be there only once (pos is a HashSet). So this is safe.
+                    // We are producing a valid pointer since we are indexing in bounds
+                    //
+                    // We can access each `word` here in parallel because each position
+                    // can be there only once (pos is a HashSet).
                     unsafe {
+                        assert!(i < words_len);
+                        // This is words[i], but avoids needing to go through &T (which triggers UB)
+                        let word = word_start.0.add(i);
+                        // let word: &mut Word = &mut (*word);
                         (*word)
                             .merge(top.pair.0, top.pair.1, new_token_id, max_token_length)
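
The second hunk is the substance of the fix. The old code manufactured a write pointer from a shared reference (`&words[i] as *const _ as *mut Word`) and mutated through it from parallel workers; the removed comment itself only claimed this was "*probably* safe" (citing rust-lang/unsafe-code-guidelines#412), and the new comment flags going through `&T` as UB. The fix derives every element pointer from `words.as_mut_ptr()` instead, wraps the base pointer in a newtype with an `unsafe impl Sync` so the parallel closure can capture it, and asserts the index is in bounds before the `add(i)`. Below is a minimal, self-contained sketch of the same pattern, with caveats: it assumes plain rayon (the real code goes through tokenizers' `maybe_par_iter`), and `Ptr`, `scale_at_positions`, and the `u64` payload are illustrative stand-ins rather than code from the repository:

    use std::collections::HashSet;

    use rayon::prelude::*; // assumed dependency; the real code gates parallelism behind `maybe_par_iter`

    // Same shape as the commit: share one raw pointer to the start of the buffer
    // across threads and let each iteration write to a distinct index. `pos`
    // being a HashSet is what guarantees the indices are disjoint.
    fn scale_at_positions(values: &mut [u64], pos: &HashSet<usize>, factor: u64) {
        let len = values.len();

        struct Ptr(*mut u64);
        // Safety: threads only ever write to disjoint elements of one allocation.
        unsafe impl Sync for Ptr {}
        let start = Ptr(values.as_mut_ptr());

        pos.par_iter().for_each(|&i| {
            // Borrow the whole wrapper rather than its field, so the closure
            // stays `Send + Sync` under edition-2021 precise-capture rules.
            let start = &start;
            // Safety: the assert keeps the pointer arithmetic in bounds, and
            // `pos` being a HashSet means no two iterations share an index, so
            // the writes never alias. Crucially, the pointer is derived from
            // `as_mut_ptr`, not from casting a shared reference.
            unsafe {
                assert!(i < len);
                let elem = start.0.add(i); // the address of values[i]; no `&T` is created
                *elem *= factor;
            }
        });
    }

    fn main() {
        let mut values = vec![1u64, 2, 3, 4, 5];
        let pos: HashSet<usize> = [0usize, 2, 4].into_iter().collect();
        scale_at_positions(&mut values, &pos, 10);
        assert_eq!(values, vec![10, 2, 30, 4, 50]);
    }

Note the shape of the iteration: it parallelizes over `pos` (only the words that actually contain the pair being merged) rather than over all of `words`, which is presumably why a borrow-checked `words.par_iter_mut()` was not a fit here, and why the disjointness argument rests entirely on `pos` being a `HashSet`.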
