Encode special tokens (#1437)
* add doc in the code

* add option to skip special tokens

* nits

* add api dummy for now

* Fmt.

* Fix fmt.

* Fix the stub.

* add a test

* add a test in python

* style it

* nits

* add getter and setters

* stub

* update python test

* fmt

* last nit

---------

Co-authored-by: Nicolas Patry <[email protected]>
ArthurZucker and Narsil authored Jan 19, 2024
1 parent 888dd4b commit 6a77d48
Showing 5 changed files with 173 additions and 0 deletions.
12 changes: 12 additions & 0 deletions bindings/python/py_src/tokenizers/__init__.pyi
@@ -836,6 +836,18 @@ class Tokenizer:
Returns:
A :obj:`List` of :class:`~tokenizers.Encoding`: The encoded batch
"""
pass
@property
def encode_special_tokens(self):
"""
Modifies the tokenizer in order to split (or not) the special tokens
during encoding, i.e. whether they are matched as whole special tokens
or treated as regular text.
Args:
value (:obj:`bool`):
Whether to split the special tokens like regular text during encoding
"""
pass
@staticmethod
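As a quick illustration of the new property declared in the stub above, here is a minimal usage sketch, assuming a build of the Python bindings that includes this commit (any tokenizer works; t5-base is simply the checkpoint used in the test further down):

    from tokenizers import Tokenizer

    tokenizer = Tokenizer.from_pretrained("t5-base")

    # The attribute defaults to False: special tokens are matched as a whole.
    assert tokenizer.encode_special_tokens is False

    # The setter flips the behaviour so special tokens are split like regular text.
    tokenizer.encode_special_tokens = True
    assert tokenizer.encode_special_tokens is True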
19 changes: 19 additions & 0 deletions bindings/python/src/tokenizer.rs
@@ -1109,6 +1109,25 @@ impl PyTokenizer {
self.tokenizer.id_to_token(id)
}

/// Modifies the tokenizer in order to split (or not) the special tokens
/// during encoding, i.e. whether they are matched as whole special tokens
/// or treated as regular text.
///
/// Args:
/// value (:obj:`bool`):
/// Whether to split the special tokens like regular text during encoding
///
#[setter]
fn set_encode_special_tokens(&mut self, value: bool) {
self.tokenizer.set_encode_special_tokens(value);
}
/// Get the value of the `encode_special_tokens` attribute
///
/// Returns:
/// :obj:`bool`: the tokenizer's encode_special_tokens attribute
#[getter]
fn get_encode_special_tokens(&self) -> bool {
self.tokenizer.get_encode_special_tokens()
}
/// Add the given tokens to the vocabulary
///
/// The given tokens are added only if they don't already exist in the vocabulary.
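Worth noting alongside the binding above: encode_special_tokens is independent of the add_special_tokens argument of encode. The argument controls whether the post-processor adds its own special tokens around the output (e.g. a trailing </s> for T5), while the new attribute controls whether special tokens already present in the input text are matched as a whole or split like regular text. A short sketch of the two knobs side by side (the token name mirrors the Python test below):

    from tokenizers import Tokenizer

    tokenizer = Tokenizer.from_pretrained("t5-base")
    tokenizer.add_special_tokens(["<end_of_text>"])

    # add_special_tokens=False: nothing is appended by the post-processor,
    # but "<end_of_text>" in the input is still matched as one special token.
    kept = tokenizer.encode("Hey there<end_of_text>", add_special_tokens=False)

    # encode_special_tokens=True: the special token in the input is now split
    # into regular sub-tokens ("<", "end", "_", "of", "_", "text", ">").
    tokenizer.encode_special_tokens = True
    split = tokenizer.encode("Hey there<end_of_text>", add_special_tokens=False)

    print(kept.tokens)
    print(split.tokens)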
31 changes: 31 additions & 0 deletions bindings/python/tests/bindings/test_tokenizer.py
@@ -457,3 +457,34 @@ def test_unigram_byte_fallback(self):
output = tokenizer.encode("A sentence 🤗")
assert output.ids == [1, 10, 2, 3, 4, 5, 10, 6, 7, 8, 9]
assert output.tokens == ["A", " ", "sen", "te", "n", "ce", " ", "<0xF0>", "<0x9F>", "<0xA4>", "<0x97>"]

def test_encode_special_tokens(self):

[Inline comment by ArthurZucker (author), Sep 27, 2024: "this test does not work anymore"]

tokenizer = Tokenizer.from_pretrained("t5-base")
tokenizer.add_tokens(["<eot>"])
tokenizer.add_special_tokens(["<end_of_text>"])
output = tokenizer.encode("Hey there<end_of_text> dear<eot>friend!", add_special_tokens=False)
assert output.tokens == ["▁Hey", "▁there", "<end_of_text>", "▁dear", "<eot>", "▁friend", "!"]

tokenizer.encode_special_tokens = True
assert tokenizer.encode_special_tokens is True

output = tokenizer.encode("Hey there<end_of_text> dear<eot>friend!", add_special_tokens=False)
assert output.tokens == [
"▁Hey",
"▁there",
"<",
"end",
"_",
"of",
"_",
"text",
">",
"▁dear",
"<eot>",
"▁friend",
"!",
]

tokenizer.add_tokens(["of_text>"])
output = tokenizer.encode("Hey there<end_of_text> dear<eot>friend!", add_special_tokens=False)
assert output.tokens == ["▁Hey", "▁there", "<", "end", "_", "of_text>", "▁dear", "<eot>", "▁friend", "!"]
101 changes: 101 additions & 0 deletions tokenizers/src/tokenizer/added_vocabulary.rs
@@ -160,6 +160,9 @@ pub(super) struct AddedVocabulary {
split_trie: MatchingSet,
/// A RegexSet containing all the normalized patterns used to split on AddedTokens
split_normalized_trie: MatchingSet,

/// Whether or not special tokens should be split when encoding. This is equivalent to ignoring them.
encode_special_tokens: bool,
}

impl AddedVocabulary {
@@ -180,6 +183,7 @@ impl AddedVocabulary {
special_tokens_set: HashSet::new(),
split_trie: (trie, vec![]),
split_normalized_trie: (normalized_trie, vec![]),
encode_special_tokens: false,
}
}
/// Size of the additional vocabulary
@@ -214,6 +218,15 @@ impl AddedVocabulary {
.or_else(|| model.id_to_token(id))
}

/// Set whether or not special tokens should be split (i.e. ignored as special) when encoding
pub fn set_encode_special_tokens(&mut self, value: bool) {
self.encode_special_tokens = value;
}

pub fn get_encode_special_tokens(&self) -> bool {
self.encode_special_tokens
}

/// Check if a token is a special token
pub fn is_special_token(&self, token: &str) -> bool {
self.special_tokens_set.contains(token)
@@ -356,6 +369,12 @@ impl AddedVocabulary {
let aho_id = mat.pattern();
let id = split_re.1[aho_id];
let added_token = &self.added_tokens_map_r.get(&id).unwrap();

if self.encode_special_tokens && self.special_tokens_set.contains(&added_token.content)
{
continue;
}

if added_token.single_word {
let start_space = start == 0 || !ends_with_word(&sentence[..start]);
let stop_space = stop == sentence.len() || !starts_with_word(&sentence[stop..]);
@@ -436,6 +455,18 @@ impl AddedVocabulary {
.split(|_, sequence| Ok(self.split_with_indices(sequence, &self.split_trie)))
.expect("AddedVocabulary bad split");

// <s> normalized = False
// "I read a book <s>Hey" -> "I read a book", " <s>", "Hey"

// </s> normalized = True -> "▁</s>"
// "I read a book</s>Hey" -> "I read a book</s>Hey"

// Day normalized = True -> "Day"
// "I read a book monday" -> "I read a book monday"

// [DAY] normalized = False -> "[DAY]"
// "I read a [DAY] monday" -> "I read a ", "[DAY]", " monday"
// 2. Then extract the normalized tokens from the normalized pieces of the string
pretokenized
.split(|_, mut sequence| {
Expand All @@ -444,6 +475,14 @@ impl AddedVocabulary {
})
.expect("AddedVocabulary bad split");

// ["I read a book", " <s>", "Hey"] -> ["▁I read a book", "▁ <s>", "▁Hey"]
// ["▁I read a book", "▁ <s>", "▁Hey"] -> [.., "▁ ", "<s>", "▁Hey"]

// </s> normalized = True -> "▁</s>"
// "I read a book</s>Hey" -> ["▁I read a book", "<","/","s",">", "Hey"]

// "I read a " "[DAY]", "book monday" -> "i read a " "[day]", "book monday"

pretokenized
}
}
@@ -880,4 +919,66 @@ mod tests {
]
);
}

#[test]
fn test_encode_special_tokens() {
let model = ModelMock::new(&[]);
let mut vocab = AddedVocabulary::new();
let normalizer = Lowercase;

vocab.add_tokens(
&[
AddedToken::from("<mask>", true)
.lstrip(true)
.rstrip(true)
.single_word(true),
AddedToken::from("ask>", false),
AddedToken::from("<pad>", true),
],
&model,
Some(&normalizer),
);
vocab.set_encode_special_tokens(true);

let result = vocab.extract_and_normalize(
Some(&normalizer),
"Hi <mask> there\t<mask>\t<mask>\u{2000} <pad> <mask><pad><pad>",
);

assert_eq!(
simplify_output(&result),
vec![
("hi <m", None),
("ask>", Some(vec![1])),
(" there\t<m", None),
("ask>", Some(vec![1])),
("\t<m", None),
("ask>", Some(vec![1])),
("\u{2000} <pad> <m", None),
("ask>", Some(vec![1])),
("<pad><pad>", None)
]
);

vocab.set_encode_special_tokens(false);

let result = vocab.extract_and_normalize(
Some(&normalizer),
"Hi <mask> there\t<mask>\t<mask>\u{2000} <pad> <mask><pad><pad>",
);
assert_eq!(
simplify_output(&result),
vec![
("hi", None),
(" <mask> ", Some(vec![0])),
("there", None),
("\t<mask>\t", Some(vec![0])),
("<mask>\u{2000} ", Some(vec![0])),
("<pad>", Some(vec![2])),
(" <mask>", Some(vec![0])),
("<pad>", Some(vec![2])),
("<pad>", Some(vec![2]))
]
);
}
}
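To make the control flow in added_vocabulary.rs easier to follow, here is a toy, self-contained Python sketch of the skip that this commit adds to the matching loop. The names (added_tokens, find_matches, split_on_added_tokens) are invented for illustration only; the real implementation walks an Aho-Corasick trie and runs in two passes (non-normalized, then normalized) as shown above:

    import re

    # Hypothetical stand-in for the added vocabulary: content -> (id, is_special)
    added_tokens = {"<mask>": (0, True), "<eot>": (1, False)}

    def find_matches(sentence):
        # Yield (start, stop, content) for every added-token occurrence.
        pattern = "|".join(re.escape(tok) for tok in added_tokens)
        for m in re.finditer(pattern, sentence):
            yield m.start(), m.end(), m.group()

    def split_on_added_tokens(sentence, encode_special_tokens=False):
        pieces, cursor = [], 0
        for start, stop, content in find_matches(sentence):
            token_id, special = added_tokens[content]
            if encode_special_tokens and special:
                # The new flag: skip the match, leaving the special token inside
                # the surrounding plain-text piece so the model tokenizes it as text.
                continue
            if start > cursor:
                pieces.append((sentence[cursor:start], None))
            pieces.append((content, token_id))
            cursor = stop
        if cursor < len(sentence):
            pieces.append((sentence[cursor:], None))
        return pieces

    print(split_on_added_tokens("Hi <mask> there<eot>"))
    # [('Hi ', None), ('<mask>', 0), (' there', None), ('<eot>', 1)]
    print(split_on_added_tokens("Hi <mask> there<eot>", encode_special_tokens=True))
    # [('Hi <mask> there', None), ('<eot>', 1)]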
10 changes: 10 additions & 0 deletions tokenizers/src/tokenizer/mod.rs
@@ -685,6 +685,16 @@ where
self.added_vocabulary.id_to_token(id, &self.model)
}

/// Set whether or not the special tokens of the added vocabulary should be split (i.e. ignored) during encoding
pub fn set_encode_special_tokens(&mut self, value: bool) {
self.added_vocabulary.set_encode_special_tokens(value);
}

/// Get whether the special tokens of the added vocabulary are split during encoding
pub fn get_encode_special_tokens(&self) -> bool {
self.added_vocabulary.get_encode_special_tokens()
}

/// Encode a single sequence
fn encode_single_sequence(
&self,
