Skip to content

Commit

Permalink
small fixed
Browse files Browse the repository at this point in the history
  • Loading branch information
ArthurZucker committed Oct 5, 2024
1 parent 0475c05 commit 167ecde
Show file tree
Hide file tree
Showing 2 changed files with 12 additions and 6 deletions.
8 changes: 5 additions & 3 deletions tokenizers/src/models/unigram/serialization.rs
Original file line number Diff line number Diff line change
Expand Up @@ -78,12 +78,14 @@ impl<'de> Visitor<'de> for UnigramVisitor {

#[cfg(test)]
mod test {
use crate::AddedVocabulary;

use super::*;

#[test]
fn test_serialization() {
let vocab = vec![("<unk>".to_string(), 0.0), ("a".to_string(), -0.5)];
let model = Unigram::from(vocab, Some(0), false).unwrap();
let model = Unigram::from(vocab, Some(0), false, &AddedVocabulary::default()).unwrap();

let data = serde_json::to_string(&model).unwrap();
let reconstructed = serde_json::from_str(&data).unwrap();
Expand All @@ -94,7 +96,7 @@ mod test {
#[test]
fn test_serialization_unk_id_not_zero() {
let vocab = vec![("a".to_string(), -0.5), ("<unk>".to_string(), 0.0)];
let model = Unigram::from(vocab, Some(1), false).unwrap();
let model = Unigram::from(vocab, Some(1), false, &AddedVocabulary::default()).unwrap();

let data = serde_json::to_string(&model).unwrap();
let reconstructed = serde_json::from_str(&data).unwrap();
Expand All @@ -105,7 +107,7 @@ mod test {
#[test]
fn test_serialization_no_unk_id() {
let vocab = vec![("a".to_string(), -0.5)];
let model = Unigram::from(vocab, None, false).unwrap();
let model = Unigram::from(vocab, None, false, &AddedVocabulary::default()).unwrap();

let data = serde_json::to_string(&model).unwrap();
let reconstructed = serde_json::from_str(&data).unwrap();
Expand Down
10 changes: 7 additions & 3 deletions tokenizers/src/models/unigram/trainer.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ use crate::models::unigram::{lattice::Lattice, model::Unigram};
use crate::tokenizer::{AddedToken, Result, Trainer};
use crate::utils::parallelism::*;
use crate::utils::progress::{ProgressBar, ProgressStyle};
use crate::AddedVocabulary;
use log::debug;
use serde::{Deserialize, Serialize};
use std::cmp::Reverse;
Expand Down Expand Up @@ -182,6 +183,7 @@ impl UnigramTrainer {
special_tokens.into_iter().chain(pieces).collect(),
unk_id,
model.byte_fallback(),
&AddedVocabulary::default(),
)
}

Expand Down Expand Up @@ -567,7 +569,8 @@ impl UnigramTrainer {
if required_chars.len() as u32 > self.vocab_size {
return Err(Box::new(UnigramTrainerError::VocabularyTooSmall));
}
let mut new_model = Unigram::from(pieces.clone(), Some(0), false)?;
let mut new_model =
Unigram::from(pieces.clone(), Some(0), false, &AddedVocabulary::default())?;
loop {
// Sub-EM iteration.
for _iter in 0..self.n_sub_iterations {
Expand All @@ -576,7 +579,8 @@ impl UnigramTrainer {

// Executes M step.
pieces = self.run_m_step(&pieces, &expected);
new_model = Unigram::from(pieces.clone(), Some(0), false)?;
new_model =
Unigram::from(pieces.clone(), Some(0), false, &AddedVocabulary::default())?;

// Useful comment for checking compatibility with spm
debug!(
Expand All @@ -600,7 +604,7 @@ impl UnigramTrainer {

// Prunes pieces.
pieces = self.prune_sentence_pieces(&new_model, &pieces, &sentences);
new_model = Unigram::from(pieces.clone(), Some(0), false)?;
new_model = Unigram::from(pieces.clone(), Some(0), false, &AddedVocabulary::default())?;
}
self.finalize_progress(&progress, expected_updates);

Expand Down

0 comments on commit 167ecde

Please sign in to comment.