Convert word counts to u64 #1433

Merged · 2 commits · Feb 6, 2024
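
The change itself is mechanical: every word and pair count that was a u32 becomes a u64, across the Rust crate and the Python bindings. The motivation is presumably headroom: a u32 count saturates just above 4.29 billion, which a large enough training corpus can exceed for a frequent word. A minimal sketch of the limit this removes (illustrative only, not code from this PR):

fn main() {
    // u32 counts top out at 4_294_967_295.
    let count: u32 = u32::MAX;
    assert_eq!(count.checked_add(1), None); // one more occurrence would overflow

    // The same tally fits comfortably in a u64.
    let count = u64::from(u32::MAX) + 1;
    assert_eq!(count, 4_294_967_296);
}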
12 changes: 6 additions & 6 deletions bindings/python/src/trainers.rs
@@ -183,12 +183,12 @@ impl PyBpeTrainer {
     }

     #[getter]
-    fn get_min_frequency(self_: PyRef<Self>) -> u32 {
+    fn get_min_frequency(self_: PyRef<Self>) -> u64 {
         getter!(self_, BpeTrainer, min_frequency)
     }

     #[setter]
-    fn set_min_frequency(self_: PyRef<Self>, freq: u32) {
+    fn set_min_frequency(self_: PyRef<Self>, freq: u64) {
         setter!(self_, BpeTrainer, min_frequency, freq);
     }

@@ -397,12 +397,12 @@ impl PyWordPieceTrainer {
     }

     #[getter]
-    fn get_min_frequency(self_: PyRef<Self>) -> u32 {
+    fn get_min_frequency(self_: PyRef<Self>) -> u64 {
         getter!(self_, WordPieceTrainer, min_frequency())
     }

     #[setter]
-    fn set_min_frequency(self_: PyRef<Self>, freq: u32) {
+    fn set_min_frequency(self_: PyRef<Self>, freq: u64) {
         setter!(self_, WordPieceTrainer, @set_min_frequency, freq);
     }

@@ -589,12 +589,12 @@ impl PyWordLevelTrainer {
     }

     #[getter]
-    fn get_min_frequency(self_: PyRef<Self>) -> u32 {
+    fn get_min_frequency(self_: PyRef<Self>) -> u64 {
         getter!(self_, WordLevelTrainer, min_frequency)
     }

     #[setter]
-    fn set_min_frequency(self_: PyRef<Self>, freq: u32) {
+    fn set_min_frequency(self_: PyRef<Self>, freq: u64) {
         setter!(self_, WordLevelTrainer, min_frequency, freq);
     }

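Nothing changes for Python callers: pyo3 converts u64 to and from Python's arbitrary-precision int exactly as it did u32, just with more headroom, and an out-of-range value raises OverflowError at the boundary. A minimal sketch of the getter/setter pattern these bindings follow (the Trainer class here is a hypothetical stand-in, not the library's):

use pyo3::prelude::*;

#[pyclass]
struct Trainer {
    min_frequency: u64,
}

#[pymethods]
impl Trainer {
    // pyo3 converts the returned u64 into a Python int.
    #[getter]
    fn get_min_frequency(&self) -> u64 {
        self.min_frequency
    }

    // A Python int above u64::MAX raises OverflowError during extraction.
    #[setter]
    fn set_min_frequency(&mut self, freq: u64) {
        self.min_frequency = freq;
    }
}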
40 changes: 20 additions & 20 deletions tokenizers/src/models/bpe/trainer.rs
@@ -11,7 +11,7 @@ use std::collections::{BinaryHeap, HashMap, HashSet};
 #[derive(Debug, Eq)]
 struct Merge {
     pair: Pair,
-    count: u32,
+    count: u64,
     pos: HashSet<usize>,
 }
 impl PartialEq for Merge {
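
The count field decides the order in which merge candidates come off the trainer's priority queue, so the widened type flows into the Ord impl touched in the next hunk. A self-contained sketch of the pattern, assuming (as the surrounding code suggests) a max-heap keyed on count:

use std::cmp::Ordering;
use std::collections::BinaryHeap;

#[derive(PartialEq, Eq)]
struct Merge {
    count: u64,
}

impl Ord for Merge {
    fn cmp(&self, other: &Self) -> Ordering {
        self.count.cmp(&other.count) // highest count pops first
    }
}

impl PartialOrd for Merge {
    fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
        Some(self.cmp(other))
    }
}

fn main() {
    let mut queue = BinaryHeap::new();
    queue.push(Merge { count: 3 });
    queue.push(Merge { count: u64::from(u32::MAX) + 1 }); // would not fit in a u32
    assert_eq!(queue.pop().unwrap().count, 4_294_967_296);
}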
@@ -36,7 +36,7 @@ impl Ord for Merge {
 }

 struct Config {
-    min_frequency: u32,
+    min_frequency: u64,
     vocab_size: usize,
     show_progress: bool,
     special_tokens: Vec<AddedToken>,
@@ -79,7 +79,7 @@ impl BpeTrainerBuilder {

     /// Set the expected minimum frequency
     #[must_use]
-    pub fn min_frequency(mut self, frequency: u32) -> Self {
+    pub fn min_frequency(mut self, frequency: u64) -> Self {
         self.config.min_frequency = frequency;
         self
     }
@@ -176,7 +176,7 @@ impl BpeTrainerBuilder {
 #[derive(Debug, Clone, PartialEq, Serialize, Deserialize, Eq)]
 pub struct BpeTrainer {
     /// The minimum frequency a pair must have to produce a merge operation
-    pub min_frequency: u32,
+    pub min_frequency: u64,
     /// The target vocabulary size
     pub vocab_size: usize,
     /// Whether to show progress while training
@@ -195,7 +195,7 @@ pub struct BpeTrainer {
     /// An optional parameter to limit the max length of any single token
     pub max_token_length: Option<usize>,

-    words: HashMap<String, u32>,
+    words: HashMap<String, u64>,
 }

 impl Default for BpeTrainer {
@@ -205,7 +205,7 @@ impl Default for BpeTrainer {
 }

 impl BpeTrainer {
-    pub fn new(min_frequency: u32, vocab_size: usize) -> Self {
+    pub fn new(min_frequency: u64, vocab_size: usize) -> Self {
         Self {
             min_frequency,
             vocab_size,
@@ -263,7 +263,7 @@ impl BpeTrainer {
     /// Compute the initial alphabet and limit it if relevant
     fn compute_alphabet(
         &self,
-        wc: &HashMap<String, u32>,
+        wc: &HashMap<String, u64>,
         w2id: &mut HashMap<String, u32>,
         id2w: &mut Vec<String>,
     ) {
@@ -322,13 +322,13 @@ impl BpeTrainer {
     /// Tokenize words and add subwords to the vocabulary when relevant
     fn tokenize_words(
         &self,
-        wc: &HashMap<String, u32>,
+        wc: &HashMap<String, u64>,
         w2id: &mut HashMap<String, u32>,
         id2w: &mut Vec<String>,
         p: &Option<ProgressBar>,
-    ) -> (Vec<Word>, Vec<u32>) {
+    ) -> (Vec<Word>, Vec<u64>) {
         let mut words: Vec<Word> = Vec::with_capacity(wc.len());
-        let mut counts: Vec<u32> = Vec::with_capacity(wc.len());
+        let mut counts: Vec<u64> = Vec::with_capacity(wc.len());

         for (word, count) in wc {
             let mut current_word = Word::new();
@@ -373,7 +373,7 @@ impl BpeTrainer {
     fn count_pairs(
         &self,
         words: &[Word],
-        counts: &[u32],
+        counts: &[u64],
         p: &Option<ProgressBar>,
     ) -> (HashMap<Pair, i32>, HashMap<Pair, HashSet<usize>>) {
         words
@@ -431,7 +431,7 @@ impl BpeTrainer {

     pub fn do_train(
         &self,
-        word_counts: &HashMap<String, u32>,
+        word_counts: &HashMap<String, u64>,
         model: &mut BPE,
     ) -> Result<Vec<AddedToken>> {
         let mut word_to_id: HashMap<String, u32> = HashMap::with_capacity(self.vocab_size);
@@ -470,7 +470,7 @@ impl BpeTrainer {
             if count > 0 {
                 queue.push(Merge {
                     pair,
-                    count: count as u32,
+                    count: count as u64,
                     pos,
                 });
             }
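One subtlety: count here comes from count_pairs, which returns HashMap<Pair, i32>, so this is an i32-to-u64 cast rather than a plain widening. The if count > 0 guard is what keeps it safe, since casting a negative i32 to u64 wraps to an enormous value. Illustration (not code from this PR):

fn main() {
    // Guarded, as in the trainer: only positive counts reach the cast.
    let count: i32 = 42;
    if count > 0 {
        assert_eq!(count as u64, 42);
    }

    // Unguarded, a negative count would wrap around.
    let negative: i32 = -1;
    assert_eq!(negative as u64, u64::MAX);
}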
@@ -493,8 +493,8 @@ impl BpeTrainer {
             }

             let mut top = queue.pop().unwrap();
-            if top.count != pair_counts[&top.pair] as u32 {
-                top.count = pair_counts[&top.pair] as u32;
+            if top.count != pair_counts[&top.pair] as u64 {
+                top.count = pair_counts[&top.pair] as u64;
                 queue.push(top);
                 continue;
             }
@@ -573,7 +573,7 @@ impl BpeTrainer {
             if count > 0 {
                 queue.push(Merge {
                     pair,
-                    count: count as u32,
+                    count: count as u64,
                     pos,
                 });
             }
@@ -632,7 +632,7 @@ impl Trainer for BpeTrainer {
         S: AsRef<str> + Send,
         F: Fn(&str) -> Result<Vec<String>> + Sync,
     {
-        let words: Result<HashMap<String, u32>> = iterator
+        let words: Result<HashMap<String, u64>> = iterator
             .maybe_par_bridge()
             .map(|sequence| {
                 let words = process(sequence.as_ref())?;
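This feed implementation is where the u64 counts originate: each input sequence is processed into words and tallied into a map. A simplified sequential sketch of that accumulation (the real code runs it through maybe_par_bridge for optional parallelism, and whitespace splitting stands in for the process callback):

use std::collections::HashMap;

fn count_words<'a>(sequences: impl Iterator<Item = &'a str>) -> HashMap<String, u64> {
    let mut words: HashMap<String, u64> = HashMap::new();
    for sequence in sequences {
        // Stand-in for `process`, which applies the tokenizer's pre-processing.
        for word in sequence.split_whitespace() {
            *words.entry(word.to_owned()).or_insert(0) += 1;
        }
    }
    words
}

fn main() {
    let counts = count_words(["roses are red", "violets are blue"].into_iter());
    assert_eq!(counts["are"], 2);
}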
@@ -665,7 +665,7 @@ mod tests {

     #[test]
     fn test_train() {
-        let word_counts: HashMap<String, u32> = [
+        let word_counts: HashMap<String, u64> = [
             ("roses".into(), 1),
             ("are".into(), 2),
             ("red".into(), 1),
@@ -744,7 +744,7 @@ mod tests {
         */

         let max_token_length = 16;
-        let long_word_counts: HashMap<String, u32> = [
+        let long_word_counts: HashMap<String, u64> = [
             ("singlelongtokenwithoutcasechange", 2),
             ("singleLongTokenWithCamelCaseChange", 2),
             ("Longsingletokenwithpunctu@t!onwithin", 2),
@@ -784,7 +784,7 @@ mod tests {
         // directly compares tokens with known expected values.
         // maybe unstable depending on specific settings or changes.
         */
-        let long_word_counts: HashMap<String, u32> = [
+        let long_word_counts: HashMap<String, u64> = [
             ("sin", 2),
             ("Sin", 2),
             ("Lon", 2),
12 changes: 6 additions & 6 deletions tokenizers/src/models/wordlevel/trainer.rs
@@ -10,7 +10,7 @@ use std::collections::HashMap;
 pub struct WordLevelTrainer {
     /// The minimum frequency a word must have to be part of the vocabulary
     #[builder(default = "0")]
-    pub min_frequency: u32,
+    pub min_frequency: u64,
     /// The target vocabulary size
     #[builder(default = "30_000")]
     pub vocab_size: usize,
@@ -22,7 +22,7 @@ pub struct WordLevelTrainer {
     pub special_tokens: Vec<AddedToken>,

     #[builder(default, private)]
-    words: HashMap<String, u32>,
+    words: HashMap<String, u64>,
 }

 impl Default for WordLevelTrainer {
@@ -38,14 +38,14 @@ impl WordLevelTrainer {

     fn do_train(
         &self,
-        word_counts: &HashMap<String, u32>,
+        word_counts: &HashMap<String, u64>,
         model: &mut WordLevel,
     ) -> Result<Vec<AddedToken>> {
         let mut ordered_counts = word_counts.iter().collect::<Vec<_>>();

         //sort the word counts first by inverse counts and then by word, in order
         //to keep the sorting deterministic in case of equal counts
-        let cmp = |l: &(&String, &u32), r: &(&String, &u32)| -> Ordering {
+        let cmp = |l: &(&String, &u64), r: &(&String, &u64)| -> Ordering {
             let count_comp: Ordering = l.1.cmp(r.1);
             if count_comp != Ordering::Equal {
                 return count_comp.reverse();
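The comparator orders by descending count and breaks ties alphabetically, so the resulting vocabulary is deterministic across runs even though HashMap iteration order is not. A self-contained sketch of the same ordering over (String, u64) entries:

use std::cmp::Ordering;
use std::collections::HashMap;

fn main() {
    let word_counts: HashMap<String, u64> =
        [("are".into(), 2), ("blue".into(), 2), ("red".into(), 1)]
            .into_iter()
            .collect();

    let mut ordered: Vec<_> = word_counts.iter().collect();
    // Inverse (descending) by count, then ascending by word on ties.
    ordered.sort_by(|l, r| match l.1.cmp(r.1) {
        Ordering::Equal => l.0.cmp(r.0),
        count_comp => count_comp.reverse(),
    });

    let words: Vec<&str> = ordered.iter().map(|(w, _)| w.as_str()).collect();
    assert_eq!(words, ["are", "blue", "red"]);
}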
@@ -100,7 +100,7 @@ impl Trainer for WordLevelTrainer {
         S: AsRef<str> + Send,
         F: Fn(&str) -> Result<Vec<String>> + Sync,
     {
-        let words: Result<HashMap<String, u32>> = iterator
+        let words: Result<HashMap<String, u64>> = iterator
             .maybe_par_bridge()
             .map(|sequence| {
                 let words = process(sequence.as_ref())?;
@@ -132,7 +132,7 @@ mod tests {

     #[test]
     fn test_train() {
-        let word_counts: HashMap<String, u32> = [
+        let word_counts: HashMap<String, u64> = [
             ("the".into(), 25),
             ("roses".into(), 22),
             ("are".into(), 24),
6 changes: 3 additions & 3 deletions tokenizers/src/models/wordpiece/trainer.rs
@@ -26,7 +26,7 @@ impl WordPieceTrainerBuilder {

     /// Set the expected minimum frequency
     #[must_use]
-    pub fn min_frequency(mut self, frequency: u32) -> Self {
+    pub fn min_frequency(mut self, frequency: u64) -> Self {
         self.bpe_trainer_builder = self.bpe_trainer_builder.min_frequency(frequency);
         self
     }
@@ -94,11 +94,11 @@ pub struct WordPieceTrainer {
 }

 impl WordPieceTrainer {
-    pub fn min_frequency(&self) -> u32 {
+    pub fn min_frequency(&self) -> u64 {
         self.bpe_trainer.min_frequency
     }

-    pub fn set_min_frequency(&mut self, freq: u32) {
+    pub fn set_min_frequency(&mut self, freq: u64) {
         self.bpe_trainer.min_frequency = freq;
     }