diff --git a/tokenizers/src/models/unigram/model.rs b/tokenizers/src/models/unigram/model.rs
index defc7d93d..4ad2f9a79 100644
--- a/tokenizers/src/models/unigram/model.rs
+++ b/tokenizers/src/models/unigram/model.rs
@@ -6,10 +6,13 @@ use super::{
 use crate::tokenizer::{Model, Result, Token};
 use crate::utils::cache::Cache;
 
-use std::collections::HashMap;
 use std::convert::TryInto;
 use std::fs::read_to_string;
 use std::path::{Path, PathBuf};
+use std::{
+    collections::{HashMap, HashSet},
+    sync::Arc,
+};
 
 type TokenMap = HashMap<String, u32>;
 type Vocab = Vec<(String, f64)>;
@@ -28,6 +31,7 @@ pub struct Unigram {
     fuse_unk: bool,
     is_optimized: bool,
     byte_fallback: bool,
+    pub special_tokens: Option<Arc<HashSet<String>>>,
 }
 impl PartialEq for Unigram {
     fn eq(&self, other: &Self) -> bool {
@@ -52,6 +56,7 @@ impl Clone for Unigram {
             fuse_unk: self.fuse_unk,
             is_optimized: self.is_optimized,
             byte_fallback: self.byte_fallback,
+            special_tokens: self.special_tokens.as_ref().map(Arc::clone),
         }
     }
 }
@@ -114,6 +119,10 @@ impl Unigram {
 
         let mut min_score = f64::INFINITY;
         for (id, (token, score)) in vocab.iter().enumerate() {
+            // Skip empty tokens so they never enter the trie or the id map.
+            if token.is_empty() {
+                continue;
+            }
             token_to_ids.insert(token.to_string(), id as u32);
             let bytes: Vec<u8> = token.bytes().collect();
             builder.push(&bytes);
@@ -137,6 +145,7 @@ impl Unigram {
             cache: Cache::default(),
             is_optimized,
             byte_fallback,
+            special_tokens: None,
         })
     }
 
diff --git a/tokenizers/src/tokenizer/added_vocabulary.rs b/tokenizers/src/tokenizer/added_vocabulary.rs
index a0c2f4542..c45883930 100644
--- a/tokenizers/src/tokenizer/added_vocabulary.rs
+++ b/tokenizers/src/tokenizer/added_vocabulary.rs
@@ -385,6 +385,10 @@ impl AddedVocabulary {
             if self.encode_special_tokens
                 && self.special_tokens_set.contains(&added_token.content)
             {
+                println!(
+                    "Found a match, but encode_special_tokens is enabled; skipping: {:?}",
+                    added_token
+                );
                 continue;
             }
 
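
For context, a minimal sketch of what the `Option<Arc<HashSet<String>>>` choice above buys: wrapping the set in an `Arc` means cloning the model (via `self.special_tokens.as_ref().map(Arc::clone)` in the hunk) copies a pointer rather than the whole set. `Model`, `set_special_tokens`, and `is_special` below are hypothetical stand-ins for illustration, not APIs introduced by this patch.

```rust
use std::collections::HashSet;
use std::sync::Arc;

/// Minimal stand-in for the `Unigram` model above; only the new field is kept.
#[derive(Clone, Default)]
struct Model {
    special_tokens: Option<Arc<HashSet<String>>>,
}

impl Model {
    /// Hypothetical setter (not part of the diff): allocate the set once,
    /// behind an `Arc`, so later clones share it instead of copying it.
    fn set_special_tokens(&mut self, tokens: HashSet<String>) {
        self.special_tokens = Some(Arc::new(tokens));
    }

    /// Hypothetical lookup a tokenization path could consult.
    fn is_special(&self, token: &str) -> bool {
        self.special_tokens
            .as_ref()
            .map_or(false, |set| set.contains(token))
    }
}

fn main() {
    let mut model = Model::default();
    model.set_special_tokens(HashSet::from(["<s>".to_string(), "</s>".to_string()]));

    // `derive(Clone)` clones the field exactly like the diff's manual
    // `self.special_tokens.as_ref().map(Arc::clone)`: a pointer copy.
    let cloned = model.clone();
    assert!(cloned.is_special("<s>"));
    assert!(!cloned.is_special("hello"));

    // Both models point at the same allocation.
    assert_eq!(Arc::strong_count(model.special_tokens.as_ref().unwrap()), 2);
}
```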