From 3ddcb2d6940ed859f1ee03d8ff8f29d7fa97a5df Mon Sep 17 00:00:00 2001
From: Stephen Roller <stephenroller@users.noreply.github.com>
Date: Wed, 17 Jan 2024 06:52:06 +0000
Subject: [PATCH] Convert word counts to u64

---
 tokenizers/src/models/bpe/trainer.rs       | 40 +++++++++++-----------
 tokenizers/src/models/wordpiece/trainer.rs |  6 ++--
 2 files changed, 23 insertions(+), 23 deletions(-)

diff --git a/tokenizers/src/models/bpe/trainer.rs b/tokenizers/src/models/bpe/trainer.rs
index 3821cdab4..303fdbc81 100644
--- a/tokenizers/src/models/bpe/trainer.rs
+++ b/tokenizers/src/models/bpe/trainer.rs
@@ -11,7 +11,7 @@ use std::collections::{BinaryHeap, HashMap, HashSet};
 #[derive(Debug, Eq)]
 struct Merge {
     pair: Pair,
-    count: u32,
+    count: u64,
     pos: HashSet<usize>,
 }
 impl PartialEq for Merge {
@@ -36,7 +36,7 @@ impl Ord for Merge {
 }
 
 struct Config {
-    min_frequency: u32,
+    min_frequency: u64,
     vocab_size: usize,
     show_progress: bool,
     special_tokens: Vec<AddedToken>,
@@ -79,7 +79,7 @@ impl BpeTrainerBuilder {
 
     /// Set the expected minimum frequency
     #[must_use]
-    pub fn min_frequency(mut self, frequency: u32) -> Self {
+    pub fn min_frequency(mut self, frequency: u64) -> Self {
         self.config.min_frequency = frequency;
         self
     }
@@ -176,7 +176,7 @@ impl BpeTrainerBuilder {
 #[derive(Debug, Clone, PartialEq, Serialize, Deserialize, Eq)]
 pub struct BpeTrainer {
     /// The minimum frequency a pair must have to produce a merge operation
-    pub min_frequency: u32,
+    pub min_frequency: u64,
     /// The target vocabulary size
     pub vocab_size: usize,
     /// Whether to show progress while training
@@ -195,7 +195,7 @@ pub struct BpeTrainer {
     /// An optional parameter to limit the max length of any single token
     pub max_token_length: Option<usize>,
 
-    words: HashMap<String, u32>,
+    words: HashMap<String, u64>,
 }
 
 impl Default for BpeTrainer {
@@ -205,7 +205,7 @@
 }
 
 impl BpeTrainer {
-    pub fn new(min_frequency: u32, vocab_size: usize) -> Self {
+    pub fn new(min_frequency: u64, vocab_size: usize) -> Self {
         Self {
             min_frequency,
             vocab_size,
@@ -263,7 +263,7 @@ impl BpeTrainer {
     /// Compute the initial alphabet and limit it if relevant
     fn compute_alphabet(
        &self,
-        wc: &HashMap<String, u32>,
+        wc: &HashMap<String, u64>,
        w2id: &mut HashMap<String, u32>,
        id2w: &mut Vec<String>,
     ) {
@@ -322,13 +322,13 @@ impl BpeTrainer {
     /// Tokenize words and add subwords to the vocabulary when relevant
     fn tokenize_words(
         &self,
-        wc: &HashMap<String, u32>,
+        wc: &HashMap<String, u64>,
         w2id: &mut HashMap<String, u32>,
         id2w: &mut Vec<String>,
         p: &Option<ProgressBar>,
-    ) -> (Vec<Word>, Vec<u32>) {
+    ) -> (Vec<Word>, Vec<u64>) {
         let mut words: Vec<Word> = Vec::with_capacity(wc.len());
-        let mut counts: Vec<u32> = Vec::with_capacity(wc.len());
+        let mut counts: Vec<u64> = Vec::with_capacity(wc.len());
 
         for (word, count) in wc {
             let mut current_word = Word::new();
@@ -373,7 +373,7 @@ impl BpeTrainer {
     fn count_pairs(
         &self,
         words: &[Word],
-        counts: &[u32],
+        counts: &[u64],
         p: &Option<ProgressBar>,
     ) -> (HashMap<Pair, i32>, HashMap<Pair, HashSet<usize>>) {
         words
@@ -431,7 +431,7 @@
 
     pub fn do_train(
         &self,
-        word_counts: &HashMap<String, u32>,
+        word_counts: &HashMap<String, u64>,
         model: &mut BPE,
     ) -> Result<Vec<AddedToken>> {
         let mut word_to_id: HashMap<String, u32> = HashMap::with_capacity(self.vocab_size);
@@ -470,7 +470,7 @@
             if count > 0 {
                 queue.push(Merge {
                     pair,
-                    count: count as u32,
+                    count: count as u64,
                     pos,
                 });
             }
@@ -493,8 +493,8 @@
             }
 
             let mut top = queue.pop().unwrap();
-            if top.count != pair_counts[&top.pair] as u32 {
-                top.count = pair_counts[&top.pair] as u32;
+            if top.count != pair_counts[&top.pair] as u64 {
+                top.count = pair_counts[&top.pair] as u64;
                 queue.push(top);
                 continue;
             }
@@ -573,7 +573,7 @@
             if count > 0 {
                 queue.push(Merge {
                     pair,
-                    count: count as u32,
+                    count: count as u64,
                     pos,
                 });
             }
@@ -632,7 +632,7 @@ impl Trainer for BpeTrainer {
         S: AsRef<str> + Send,
         F: Fn(&str) -> Result<Vec<String>> + Sync,
     {
-        let words: Result<HashMap<String, u32>> = iterator
+        let words: Result<HashMap<String, u64>> = iterator
             .maybe_par_bridge()
             .map(|sequence| {
                 let words = process(sequence.as_ref())?;
@@ -665,7 +665,7 @@ mod tests {
 
     #[test]
     fn test_train() {
-        let word_counts: HashMap<String, u32> = [
+        let word_counts: HashMap<String, u64> = [
             ("roses".into(), 1),
             ("are".into(), 2),
             ("red".into(), 1),
@@ -744,7 +744,7 @@ mod tests {
         */
 
         let max_token_length = 16;
-        let long_word_counts: HashMap<String, u32> = [
+        let long_word_counts: HashMap<String, u64> = [
             ("singlelongtokenwithoutcasechange", 2),
             ("singleLongTokenWithCamelCaseChange", 2),
             ("Longsingletokenwithpunctu@t!onwithin", 2),
@@ -784,7 +784,7 @@
         // directly compares tokens with known expected values.
         // maybe unstable depending on specific settings or changes.
         */
-        let long_word_counts: HashMap<String, u32> = [
+        let long_word_counts: HashMap<String, u64> = [
             ("sin", 2),
             ("Sin", 2),
             ("Lon", 2),
diff --git a/tokenizers/src/models/wordpiece/trainer.rs b/tokenizers/src/models/wordpiece/trainer.rs
index 1adcc2be4..58a5abc8f 100644
--- a/tokenizers/src/models/wordpiece/trainer.rs
+++ b/tokenizers/src/models/wordpiece/trainer.rs
@@ -26,7 +26,7 @@ impl WordPieceTrainerBuilder {
 
     /// Set the expected minimum frequency
     #[must_use]
-    pub fn min_frequency(mut self, frequency: u32) -> Self {
+    pub fn min_frequency(mut self, frequency: u64) -> Self {
         self.bpe_trainer_builder = self.bpe_trainer_builder.min_frequency(frequency);
         self
     }
@@ -94,11 +94,11 @@ pub struct WordPieceTrainer {
 }
 
 impl WordPieceTrainer {
-    pub fn min_frequency(&self) -> u32 {
+    pub fn min_frequency(&self) -> u64 {
         self.bpe_trainer.min_frequency
     }
 
-    pub fn set_min_frequency(&mut self, freq: u32) {
+    pub fn set_min_frequency(&mut self, freq: u64) {
         self.bpe_trainer.min_frequency = freq;
     }