Convert word counts to u64
stephenroller committed Jan 17, 2024
1 parent 888dd4b commit 3ddcb2d
Showing 2 changed files with 23 additions and 23 deletions.
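
The commit widens every word-count and frequency type in the BPE and WordPiece trainers from u32 to u64. The motivation is not stated in the commit message, but a plausible reading (an assumption here) is overflow headroom: u32 tops out at 4,294,967,295, a count a frequent token can exceed on web-scale corpora. A minimal sketch of the failure mode the wider type avoids:

```rust
// Hypothetical illustration, not part of the commit: counts above
// u32::MAX (4_294_967_295) cannot be represented in the old type.
fn main() {
    let huge_count: u64 = 5_000_000_000; // plausible for a common token at web scale
    assert!(u32::try_from(huge_count).is_err()); // would have wrapped or panicked as u32
    let widened = huge_count * 2; // still comfortably within u64 range
    println!("{widened}");
}
```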
40 changes: 20 additions & 20 deletions tokenizers/src/models/bpe/trainer.rs
@@ -11,7 +11,7 @@ use std::collections::{BinaryHeap, HashMap, HashSet};
 #[derive(Debug, Eq)]
 struct Merge {
     pair: Pair,
-    count: u32,
+    count: u64,
     pos: HashSet<usize>,
 }
 impl PartialEq for Merge {
@@ -36,7 +36,7 @@ impl Ord for Merge {
 }
 
 struct Config {
-    min_frequency: u32,
+    min_frequency: u64,
     vocab_size: usize,
     show_progress: bool,
     special_tokens: Vec<AddedToken>,
@@ -79,7 +79,7 @@ impl BpeTrainerBuilder {
 
     /// Set the expected minimum frequency
    #[must_use]
-    pub fn min_frequency(mut self, frequency: u32) -> Self {
+    pub fn min_frequency(mut self, frequency: u64) -> Self {
         self.config.min_frequency = frequency;
         self
     }
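
For callers the builder surface is unchanged apart from the parameter type. A minimal usage sketch, assuming the crate's public `BpeTrainer::builder()` path; integer literals simply infer as u64, while code that stored the frequency in an explicit u32 variable must widen it:

```rust
use tokenizers::models::bpe::BpeTrainer;

fn main() {
    // `2` now infers as u64; an explicit `let f: u32 = 2` would need u64::from(f).
    let trainer = BpeTrainer::builder()
        .min_frequency(2)
        .vocab_size(30_000)
        .build();
    let _ = trainer;
}
```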
@@ -176,7 +176,7 @@ impl BpeTrainerBuilder {
 #[derive(Debug, Clone, PartialEq, Serialize, Deserialize, Eq)]
 pub struct BpeTrainer {
     /// The minimum frequency a pair must have to produce a merge operation
-    pub min_frequency: u32,
+    pub min_frequency: u64,
     /// The target vocabulary size
     pub vocab_size: usize,
     /// Whether to show progress while training
@@ -195,7 +195,7 @@ pub struct BpeTrainer {
     /// An optional parameter to limit the max length of any single token
     pub max_token_length: Option<usize>,
 
-    words: HashMap<String, u32>,
+    words: HashMap<String, u64>,
 }
 
 impl Default for BpeTrainer {
@@ -205,7 +205,7 @@ impl Default for BpeTrainer {
 }
 
 impl BpeTrainer {
-    pub fn new(min_frequency: u32, vocab_size: usize) -> Self {
+    pub fn new(min_frequency: u64, vocab_size: usize) -> Self {
         Self {
             min_frequency,
             vocab_size,
@@ -263,7 +263,7 @@ impl BpeTrainer {
     /// Compute the initial alphabet and limit it if relevant
     fn compute_alphabet(
         &self,
-        wc: &HashMap<String, u32>,
+        wc: &HashMap<String, u64>,
         w2id: &mut HashMap<String, u32>,
         id2w: &mut Vec<String>,
     ) {
@@ -322,13 +322,13 @@ impl BpeTrainer {
     /// Tokenize words and add subwords to the vocabulary when relevant
     fn tokenize_words(
         &self,
-        wc: &HashMap<String, u32>,
+        wc: &HashMap<String, u64>,
         w2id: &mut HashMap<String, u32>,
         id2w: &mut Vec<String>,
         p: &Option<ProgressBar>,
-    ) -> (Vec<Word>, Vec<u32>) {
+    ) -> (Vec<Word>, Vec<u64>) {
         let mut words: Vec<Word> = Vec::with_capacity(wc.len());
-        let mut counts: Vec<u32> = Vec::with_capacity(wc.len());
+        let mut counts: Vec<u64> = Vec::with_capacity(wc.len());
 
         for (word, count) in wc {
             let mut current_word = Word::new();
@@ -373,7 +373,7 @@ impl BpeTrainer {
     fn count_pairs(
         &self,
         words: &[Word],
-        counts: &[u32],
+        counts: &[u64],
         p: &Option<ProgressBar>,
     ) -> (HashMap<Pair, i32>, HashMap<Pair, HashSet<usize>>) {
         words
@@ -431,7 +431,7 @@ impl BpeTrainer {
 
     pub fn do_train(
         &self,
-        word_counts: &HashMap<String, u32>,
+        word_counts: &HashMap<String, u64>,
         model: &mut BPE,
     ) -> Result<Vec<AddedToken>> {
         let mut word_to_id: HashMap<String, u32> = HashMap::with_capacity(self.vocab_size);
@@ -470,7 +470,7 @@ impl BpeTrainer {
             if count > 0 {
                 queue.push(Merge {
                     pair,
-                    count: count as u32,
+                    count: count as u64,
                     pos,
                 });
             }
@@ -493,8 +493,8 @@ impl BpeTrainer {
             }
 
             let mut top = queue.pop().unwrap();
-            if top.count != pair_counts[&top.pair] as u32 {
-                top.count = pair_counts[&top.pair] as u32;
+            if top.count != pair_counts[&top.pair] as u64 {
+                top.count = pair_counts[&top.pair] as u64;
                 queue.push(top);
                 continue;
             }
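
The hunk above is the trainer's lazy priority-queue update: a popped Merge whose cached count no longer matches pair_counts is refreshed and re-pushed rather than the heap being rebuilt on every count change. A self-contained sketch of the same pattern with simplified types (the real code uses Pair and carries a HashSet<usize> of positions):

```rust
use std::collections::{BinaryHeap, HashMap};

// Simplified stand-in for the trainer's Merge type.
#[derive(PartialEq, Eq)]
struct Merge {
    pair: (u32, u32),
    count: u64,
}

impl Ord for Merge {
    fn cmp(&self, other: &Self) -> std::cmp::Ordering {
        self.count.cmp(&other.count) // BinaryHeap pops the largest count first
    }
}

impl PartialOrd for Merge {
    fn partial_cmp(&self, other: &Self) -> Option<std::cmp::Ordering> {
        Some(self.cmp(other))
    }
}

fn main() {
    // True pair counts, which earlier merges may have decremented since the push.
    let mut pair_counts: HashMap<(u32, u32), i32> = HashMap::new();
    pair_counts.insert((0, 1), 3);

    let mut queue: BinaryHeap<Merge> = BinaryHeap::new();
    queue.push(Merge { pair: (0, 1), count: 10 }); // stale cached count

    while let Some(mut top) = queue.pop() {
        // Lazy invalidation: refresh the stale entry and re-push it.
        if top.count != pair_counts[&top.pair] as u64 {
            top.count = pair_counts[&top.pair] as u64;
            queue.push(top);
            continue;
        }
        println!("apply merge {:?} (count {})", top.pair, top.count);
        break;
    }
}
```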
@@ -573,7 +573,7 @@ impl BpeTrainer {
                 if count > 0 {
                     queue.push(Merge {
                         pair,
-                        count: count as u32,
+                        count: count as u64,
                         pos,
                     });
                 }
@@ -632,7 +632,7 @@ impl Trainer for BpeTrainer {
         S: AsRef<str> + Send,
         F: Fn(&str) -> Result<Vec<String>> + Sync,
     {
-        let words: Result<HashMap<String, u32>> = iterator
+        let words: Result<HashMap<String, u64>> = iterator
             .maybe_par_bridge()
             .map(|sequence| {
                 let words = process(sequence.as_ref())?;
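
This feed path folds processed words into the HashMap<String, u64> that do_train consumes. A minimal sketch of that accumulation, assuming the standard entry API; the real code additionally merges per-thread maps via maybe_par_bridge, and whitespace splitting stands in for the trainer's `process` callback:

```rust
use std::collections::HashMap;

fn main() {
    let sequences = ["roses are red", "violets are blue"];
    let mut words: HashMap<String, u64> = HashMap::new();
    for sequence in sequences {
        // Illustrative stand-in for `process`, which tokenizes a sequence.
        for word in sequence.split_whitespace() {
            *words.entry(word.to_owned()).or_insert(0) += 1;
        }
    }
    assert_eq!(words["are"], 2);
}
```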
@@ -665,7 +665,7 @@ mod tests {
 
     #[test]
     fn test_train() {
-        let word_counts: HashMap<String, u32> = [
+        let word_counts: HashMap<String, u64> = [
             ("roses".into(), 1),
             ("are".into(), 2),
             ("red".into(), 1),
@@ -744,7 +744,7 @@ mod tests {
         */
 
         let max_token_length = 16;
-        let long_word_counts: HashMap<String, u32> = [
+        let long_word_counts: HashMap<String, u64> = [
             ("singlelongtokenwithoutcasechange", 2),
             ("singleLongTokenWithCamelCaseChange", 2),
             ("Longsingletokenwithpunctu@t!onwithin", 2),
@@ -784,7 +784,7 @@ mod tests {
         // directly compares tokens with known expected values.
         // maybe unstable depending on specific settings or changes.
         */
-        let long_word_counts: HashMap<String, u32> = [
+        let long_word_counts: HashMap<String, u64> = [
             ("sin", 2),
             ("Sin", 2),
             ("Lon", 2),
6 changes: 3 additions & 3 deletions tokenizers/src/models/wordpiece/trainer.rs
@@ -26,7 +26,7 @@ impl WordPieceTrainerBuilder {
 
     /// Set the expected minimum frequency
     #[must_use]
-    pub fn min_frequency(mut self, frequency: u32) -> Self {
+    pub fn min_frequency(mut self, frequency: u64) -> Self {
         self.bpe_trainer_builder = self.bpe_trainer_builder.min_frequency(frequency);
         self
     }
@@ -94,11 +94,11 @@ pub struct WordPieceTrainer {
 }
 
 impl WordPieceTrainer {
-    pub fn min_frequency(&self) -> u32 {
+    pub fn min_frequency(&self) -> u64 {
         self.bpe_trainer.min_frequency
     }
 
-    pub fn set_min_frequency(&mut self, freq: u32) {
+    pub fn set_min_frequency(&mut self, freq: u64) {
         self.bpe_trainer.min_frequency = freq;
     }
 
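
The WordPiece trainer wraps a BpeTrainer (the bpe_trainer field seen in the diff) and forwards min_frequency straight to it, so widening the BPE field forces these signatures to widen too. A stripped-down sketch of that delegation, with both types reduced to the relevant field:

```rust
// Reduced stand-ins: only the min_frequency plumbing from the diff is shown.
struct BpeTrainer {
    min_frequency: u64,
}

struct WordPieceTrainer {
    bpe_trainer: BpeTrainer,
}

impl WordPieceTrainer {
    fn min_frequency(&self) -> u64 {
        self.bpe_trainer.min_frequency
    }

    fn set_min_frequency(&mut self, freq: u64) {
        self.bpe_trainer.min_frequency = freq;
    }
}

fn main() {
    let mut t = WordPieceTrainer { bpe_trainer: BpeTrainer { min_frequency: 2 } };
    t.set_min_frequency(5);
    assert_eq!(t.min_frequency(), 5);
}
```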
