diff --git a/tokenizers/src/tokenizer/added_vocabulary.rs b/tokenizers/src/tokenizer/added_vocabulary.rs
index 603044fab..5462fd23e 100644
--- a/tokenizers/src/tokenizer/added_vocabulary.rs
+++ b/tokenizers/src/tokenizer/added_vocabulary.rs
@@ -370,7 +370,8 @@ impl AddedVocabulary {
             let id = split_re.1[aho_id];
             let added_token = &self.added_tokens_map_r.get(&id).unwrap();
 
-            if self.encode_special_tokens && self.special_tokens_set.contains(&added_token.content) {
+            if self.encode_special_tokens && self.special_tokens_set.contains(&added_token.content)
+            {
                 continue;
             }
 
@@ -918,4 +919,66 @@ mod tests {
             ]
         );
     }
+
+    #[test]
+    fn test_encode_special_tokens() {
+        let model = ModelMock::new(&[]);
+        let mut vocab = AddedVocabulary::new();
+        let normalizer = Lowercase;
+
+        vocab.add_tokens(
+            &[
+                AddedToken::from("<mask>", true)
+                    .lstrip(true)
+                    .rstrip(true)
+                    .single_word(true),
+                AddedToken::from("ask>", false),
+                AddedToken::from("<pad>", true),
+            ],
+            &model,
+            Some(&normalizer),
+        );
+        vocab.set_encode_special_tokens(true);
+
+        let result = vocab.extract_and_normalize(
+            Some(&normalizer),
+            "Hi <mask> there\t<mask>\t<mask>\u{2000} <pad> <mask><pad><pad>",
+        );
+
+        assert_eq!(
+            simplify_output(&result),
+            vec![
+                ("hi <m", None),
+                ("ask>", Some(vec![1])),
+                (" there\t<m", None),
+                ("ask>", Some(vec![1])),
+                ("\t<m", None),
+                ("ask>", Some(vec![1])),
+                ("\u{2000} <pad> <m", None),
+                ("ask>", Some(vec![1])),
+                ("<pad><pad>", None)
+            ]
+        );
+
+        vocab.set_encode_special_tokens(false);
+
+        let result = vocab.extract_and_normalize(
+            Some(&normalizer),
+            "Hi <mask> there\t<mask>\t<mask>\u{2000} <pad> <mask><pad><pad>",
+        );
+        assert_eq!(
+            simplify_output(&result),
+            vec![
+                ("hi", None),
+                (" <mask> ", Some(vec![0])),
+                ("there", None),
+                ("\t<mask>\t", Some(vec![0])),
+                ("<mask>\u{2000} ", Some(vec![0])),
+                ("<pad>", Some(vec![2])),
+                (" <mask>", Some(vec![0])),
+                ("<pad>", Some(vec![2])),
+                ("<pad>", Some(vec![2]))
+            ]
+        );
+    }
 }
diff --git a/tokenizers/src/tokenizer/mod.rs b/tokenizers/src/tokenizer/mod.rs
index ce9c4852a..b02a33e79 100644
--- a/tokenizers/src/tokenizer/mod.rs
+++ b/tokenizers/src/tokenizer/mod.rs
@@ -684,13 +684,13 @@ where
     pub fn id_to_token(&self, id: u32) -> Option<String> {
         self.added_vocabulary.id_to_token(id, &self.model)
     }
-
+
     /// set the added bocab's splitting scheme
-    pub fn set_encode_special_tokens(&mut self, value:bool){
+    pub fn set_encode_special_tokens(&mut self, value: bool) {
         self.added_vocabulary.set_encode_special_tokens(value);
     }
 
-    /// Get added token value
+    /// Get added token value
     pub fn get_encode_special_tokens(&mut self) -> &bool {
         self.added_vocabulary.get_encode_special_tokens()
    }
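
For readers of this patch, a minimal usage sketch of the new toggle (not part of the diff itself): it relies on the crate's public `Tokenizer` wrapper, which derefs to the `TokenizerImpl` where `set_encode_special_tokens` is defined, and assumes a hypothetical `tokenizer.json` whose added vocabulary registers `<mask>` as a special token. With the flag on, special tokens are skipped during added-token matching and encoded by the underlying model like ordinary text; with it off (the default), they are extracted whole.

use tokenizers::{Result, Tokenizer};

fn main() -> Result<()> {
    // Hypothetical serialized tokenizer; any file whose added vocabulary
    // contains "<mask>" as a special token would show the same contrast.
    let mut tokenizer = Tokenizer::from_file("tokenizer.json")?;

    // Default behavior: "<mask>" is matched as one added token and kept whole.
    let whole = tokenizer.encode("Hi <mask>", false)?;

    // With the toggle on, "<mask>" is no longer extracted as a special token
    // and is tokenized by the underlying model like ordinary text.
    tokenizer.set_encode_special_tokens(true);
    let split = tokenizer.encode("Hi <mask>", false)?;

    println!("default: {:?}", whole.get_tokens());
    println!("split:   {:?}", split.get_tokens());
    Ok(())
}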