From c67f5aefa229b7305e1ac4f9ef29c6ad92e49d29 Mon Sep 17 00:00:00 2001
From: Manish Goregaokar
Date: Wed, 23 Oct 2024 07:04:11 -0700
Subject: [PATCH] fixes

---
 tokenizers/src/models/bpe/trainer.rs | 24 +++++++++++++++++-------
 1 file changed, 17 insertions(+), 7 deletions(-)

diff --git a/tokenizers/src/models/bpe/trainer.rs b/tokenizers/src/models/bpe/trainer.rs
index cb907bb8d..621a9b25b 100644
--- a/tokenizers/src/models/bpe/trainer.rs
+++ b/tokenizers/src/models/bpe/trainer.rs
@@ -454,7 +454,7 @@ impl BpeTrainer {
         // 3. Tokenize words
         //
         self.update_progress(&progress, word_counts.len(), "Tokenize words");
-        let (words, counts) =
+        let (mut words, counts) =
             self.tokenize_words(word_counts, &mut word_to_id, &mut id_to_word, &progress);
         self.finalize_progress(&progress, words.len());
 
@@ -533,17 +533,27 @@ impl BpeTrainer {
             // Safety: This is just a type assertion, the code below may no longer be safe
             // if the type of `pos` changes
             let ref pos: HashSet<usize> = top.pos;
+
+
+            let words_len = words.len();
+            struct WordPtr(*mut Word);
+            // Safety: We do not actually use this for concurrent access to the same memory,
+            // only to different chunks within the same allocation.
+            unsafe impl Sync for WordPtr {}
+            let word_start = WordPtr(words.as_mut_ptr());
+
             let changes = pos
                 .maybe_par_iter()
                 .flat_map(|&i| {
-                    // Safety: Accessing this Vec overall as an &T whilst parts of it
-                    // are being mutated is *probably* safe, see https://github.com/rust-lang/unsafe-code-guidelines/issues/412
-                    // and related issues. If not, we can always use raw pointers here.
-                    let word = &words[i] as *const _ as *mut Word;
                     // Safety:
-                    // We can acces each `word` here in parallel because each position
-                    // can be there only once (pos is a HashSet). So this is safe.
+                    // We are producing a valid pointer since we are indexing in bounds
+                    //
+                    // We can access each `word` here in parallel because each position
+                    // can be there only once (pos is a HashSet).
                     unsafe {
+                        assert!(i < words_len);
+                        // This is words[i], but avoids needing to go through &T (which triggers UB)
+                        let word = word_start.0.add(i);
                         // let word: &mut Word = &mut (*word);
                         (*word)
                             .merge(top.pair.0, top.pair.1, new_token_id, max_token_length)
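
Note (not part of the patch itself): the following is a minimal standalone sketch of the pattern the second hunk adopts, namely wrapping the Vec's base pointer in a newtype with an `unsafe impl Sync` so that parallel workers can write to disjoint indices without ever forming an `&T` to memory that other workers are mutating. It uses rayon's `par_iter()` directly in place of the crate's `maybe_par_iter()` helper, and the names `SlicePtr`, `SlicePtr::at`, and `bump_disjoint` are illustrative, not taken from the tokenizers codebase.

// Cargo.toml: rayon = "1"
use std::collections::HashSet;

use rayon::prelude::*;

/// Newtype so the raw base pointer can cross the closure's `Send + Sync` bounds.
struct SlicePtr<T>(*mut T);

// Safety: the pointer is only used to reach *disjoint* elements of a single
// allocation; no element is ever accessed by two workers at once.
unsafe impl<T> Sync for SlicePtr<T> {}

impl<T> SlicePtr<T> {
    /// Pointer to element `i`.
    ///
    /// Safety: `i` must be in bounds of the allocation, and the caller must
    /// not create overlapping accesses to the same element.
    unsafe fn at(&self, i: usize) -> *mut T {
        // Offsetting from the base pointer (instead of taking `&vec[i]`)
        // avoids materializing a shared reference into data being mutated.
        unsafe { self.0.add(i) }
    }
}

/// Increment `data[i]` for every `i` in `pos`, in parallel.
fn bump_disjoint(data: &mut [u32], pos: &HashSet<usize>) {
    let len = data.len();
    let base = SlicePtr(data.as_mut_ptr());
    pos.par_iter().for_each(|&i| {
        // Each index occurs at most once in the HashSet, so writes are disjoint.
        assert!(i < len);
        unsafe { *base.at(i) += 1 };
    });
}

fn main() {
    let mut data = vec![0u32; 8];
    let pos: HashSet<usize> = [1, 3, 5].into_iter().collect();
    bump_disjoint(&mut data, &pos);
    assert_eq!(data, vec![0u32, 1, 0, 1, 0, 1, 0, 0]);
}

One deliberate difference from the patch: the access goes through the `at` method rather than the raw `.0` field, so the closure captures the whole `Sync` wrapper instead of just the bare `*mut T` under Rust 2021's disjoint closure captures, which keeps the closure `Send + Sync`.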