From fd24c27949b81c38b88defb57c65b8ebebc1826e Mon Sep 17 00:00:00 2001
From: Stephen Roller
Date: Wed, 17 Jan 2024 07:03:32 +0000
Subject: [PATCH] More spots needed to compile

---
 bindings/python/src/trainers.rs            | 12 ++++++------
 tokenizers/src/models/wordlevel/trainer.rs | 12 ++++++------
 2 files changed, 12 insertions(+), 12 deletions(-)

diff --git a/bindings/python/src/trainers.rs b/bindings/python/src/trainers.rs
index 1c1c9310a..707dc7230 100644
--- a/bindings/python/src/trainers.rs
+++ b/bindings/python/src/trainers.rs
@@ -183,12 +183,12 @@ impl PyBpeTrainer {
     }
 
     #[getter]
-    fn get_min_frequency(self_: PyRef<Self>) -> u32 {
+    fn get_min_frequency(self_: PyRef<Self>) -> u64 {
         getter!(self_, BpeTrainer, min_frequency)
     }
 
     #[setter]
-    fn set_min_frequency(self_: PyRef<Self>, freq: u32) {
+    fn set_min_frequency(self_: PyRef<Self>, freq: u64) {
         setter!(self_, BpeTrainer, min_frequency, freq);
     }
 
@@ -397,12 +397,12 @@ impl PyWordPieceTrainer {
     }
 
     #[getter]
-    fn get_min_frequency(self_: PyRef<Self>) -> u32 {
+    fn get_min_frequency(self_: PyRef<Self>) -> u64 {
         getter!(self_, WordPieceTrainer, min_frequency())
     }
 
     #[setter]
-    fn set_min_frequency(self_: PyRef<Self>, freq: u32) {
+    fn set_min_frequency(self_: PyRef<Self>, freq: u64) {
         setter!(self_, WordPieceTrainer, @set_min_frequency, freq);
     }
 
@@ -589,12 +589,12 @@ impl PyWordLevelTrainer {
     }
 
     #[getter]
-    fn get_min_frequency(self_: PyRef<Self>) -> u32 {
+    fn get_min_frequency(self_: PyRef<Self>) -> u64 {
         getter!(self_, WordLevelTrainer, min_frequency)
     }
 
     #[setter]
-    fn set_min_frequency(self_: PyRef<Self>, freq: u32) {
+    fn set_min_frequency(self_: PyRef<Self>, freq: u64) {
         setter!(self_, WordLevelTrainer, min_frequency, freq);
     }
 
diff --git a/tokenizers/src/models/wordlevel/trainer.rs b/tokenizers/src/models/wordlevel/trainer.rs
index d4048b15d..c52ad08d7 100644
--- a/tokenizers/src/models/wordlevel/trainer.rs
+++ b/tokenizers/src/models/wordlevel/trainer.rs
@@ -10,7 +10,7 @@ use std::collections::HashMap;
 pub struct WordLevelTrainer {
     /// The minimum frequency a word must have to be part of the vocabulary
     #[builder(default = "0")]
-    pub min_frequency: u32,
+    pub min_frequency: u64,
     /// The target vocabulary size
     #[builder(default = "30_000")]
     pub vocab_size: usize,
@@ -22,7 +22,7 @@ pub struct WordLevelTrainer {
     pub special_tokens: Vec<AddedToken>,
 
     #[builder(default, private)]
-    words: HashMap<String, u32>,
+    words: HashMap<String, u64>,
 }
 
 impl Default for WordLevelTrainer {
@@ -38,14 +38,14 @@ impl WordLevelTrainer {
 
     fn do_train(
         &self,
-        word_counts: &HashMap<String, u32>,
+        word_counts: &HashMap<String, u64>,
         model: &mut WordLevel,
     ) -> Result<Vec<AddedToken>> {
         let mut ordered_counts = word_counts.iter().collect::<Vec<_>>();
 
         //sort the word counts first by inverse counts and then by word, in order
         //to keep the sorting deterministic in case of equal counts
-        let cmp = |l: &(&String, &u32), r: &(&String, &u32)| -> Ordering {
+        let cmp = |l: &(&String, &u64), r: &(&String, &u64)| -> Ordering {
             let count_comp: Ordering = l.1.cmp(r.1);
             if count_comp != Ordering::Equal {
                 return count_comp.reverse();
@@ -100,7 +100,7 @@ impl Trainer for WordLevelTrainer {
         S: AsRef<str> + Send,
         F: Fn(&str) -> Result<Vec<String>> + Sync,
     {
-        let words: Result<HashMap<String, u32>> = iterator
+        let words: Result<HashMap<String, u64>> = iterator
             .maybe_par_bridge()
             .map(|sequence| {
                 let words = process(sequence.as_ref())?;
@@ -132,7 +132,7 @@ mod tests {
 
     #[test]
     fn test_train() {
-        let word_counts: HashMap<String, u32> = [
+        let word_counts: HashMap<String, u64> = [
            ("the".into(), 25),
            ("roses".into(), 22),
            ("are".into(), 24),
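
-- 
Note (not part of the patch): a minimal sketch of why widening the count
type from u32 to u64 matters. The corpus figures below are illustrative
assumptions, not measurements; only the `HashMap<String, u64>` signature
comes from the diff above.

// Word counts accumulated over a web-scale corpus can exceed u32::MAX
// (4_294_967_295). A very common word like "the" makes up roughly 5% of
// English tokens, so a ~100B-token corpus yields ~5 billion occurrences
// of that single word, which no longer fits in a u32.
use std::collections::HashMap;

fn main() {
    // Hypothetical corpus-scale count for a single common word.
    let the_count: u64 = 5_000_000_000;

    // This value does not fit in a u32: accumulating it in u32 would
    // panic in debug builds or silently wrap in release builds.
    assert!(the_count > u32::MAX as u64);

    // With u64 values, the map matches the patched signature
    // `do_train(&self, word_counts: &HashMap<String, u64>, ...)`.
    let mut word_counts: HashMap<String, u64> = HashMap::new();
    *word_counts.entry("the".to_string()).or_insert(0) += the_count;
    assert_eq!(word_counts["the"], 5_000_000_000);
}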