add option to skip special tokens #1419

Closed
Wants to merge 16 commits
12 changes: 12 additions & 0 deletions bindings/python/py_src/tokenizers/__init__.pyi
@@ -836,6 +836,18 @@ class Tokenizer:
Returns:
A :obj:`List` of :class:`~tokenizers.Encoding`: The encoded batch
"""
pass
@property
def encode_special_tokens(self):
"""
Set whether the special tokens should be split during encoding:
when :obj:`True`, special tokens are treated as regular text and
split by the model instead of being matched whole.
Args:
value (:obj:`bool`):
Whether to split the special tokens during encoding
"""
pass
@staticmethod
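A minimal usage sketch of the new property from the Python side (a sketch, not part of the diff: it assumes network access so `Tokenizer.from_pretrained` can fetch t5-base, and the token strings in the comments mirror the test further down):

from tokenizers import Tokenizer

tokenizer = Tokenizer.from_pretrained("t5-base")
tokenizer.add_special_tokens(["<end_of_text>"])

# Default behavior: the special token is matched whole.
output = tokenizer.encode("Hi<end_of_text>", add_special_tokens=False)
# e.g. ["▁Hi", "<end_of_text>"]

# With the new flag, the special token is split like regular text.
tokenizer.encode_special_tokens = True
output = tokenizer.encode("Hi<end_of_text>", add_special_tokens=False)
# e.g. ["▁Hi", "<", "end", "_", "of", "_", "text", ">"]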
19 changes: 19 additions & 0 deletions bindings/python/src/tokenizer.rs
@@ -1110,6 +1110,25 @@ impl PyTokenizer {
self.tokenizer.id_to_token(id)
}

/// Set whether the special tokens should be split during encoding:
/// when `true`, special tokens are treated as regular text and
/// split by the model instead of being matched whole.
///
/// Args:
/// value (:obj:`bool`):
/// Whether to split the special tokens during encoding
///
#[setter]
fn set_encode_special_tokens(&mut self, value: bool) {
self.tokenizer.set_encode_special_tokens(value);
}

/// Get the value of the `encode_special_tokens` attribute
///
/// Returns:
/// :obj:`bool`: the tokenizer's encode_special_tokens attribute
#[getter]
fn get_encode_special_tokens(&self) -> bool {
self.tokenizer.get_encode_special_tokens()
}

/// Add the given tokens to the vocabulary
///
/// The given tokens are added only if they don't already exist in the vocabulary.
31 changes: 31 additions & 0 deletions bindings/python/tests/bindings/test_tokenizer.py
@@ -457,3 +457,34 @@ def test_unigram_byte_fallback(self):
output = tokenizer.encode("A sentence 🤗")
assert output.ids == [1, 10, 2, 3, 4, 5, 10, 6, 7, 8, 9]
assert output.tokens == ["A", " ", "sen", "te", "n", "ce", " ", "<0xF0>", "<0x9F>", "<0xA4>", "<0x97>"]

def test_encode_special_tokens(self):
tokenizer = Tokenizer.from_pretrained("t5-base")
tokenizer.add_tokens(["<eot>"])
tokenizer.add_special_tokens(["<end_of_text>"])
output = tokenizer.encode("Hey there<end_of_text> dear<eot>friend!", add_special_tokens=False)
assert output.tokens == ["▁Hey", "▁there", "<end_of_text>", "▁dear", "<eot>", "▁friend", "!"]

tokenizer.encode_special_tokens = True
assert tokenizer.encode_special_tokens is True

output = tokenizer.encode("Hey there<end_of_text> dear<eot>friend!", add_special_tokens=False)
assert output.tokens == [
"▁Hey",
"▁there",
"<",
"end",
"_",
"of",
"_",
"text",
">",
"▁dear",
"<eot>",
"▁friend",
"!",
]

tokenizer.add_tokens(["of_text>"])
output = tokenizer.encode("Hey there<end_of_text> dear<eot>friend!", add_special_tokens=False)
assert output.tokens == ["▁Hey", "▁there", "<", "end", "_", "of_text>", "▁dear", "<eot>", "▁friend", "!"]
81 changes: 81 additions & 0 deletions tokenizers/src/tokenizer/added_vocabulary.rs
@@ -160,6 +160,9 @@ pub(super) struct AddedVocabulary {
split_trie: MatchingSet,
/// A RegexSet containing all the normalized patterns used to split on AddedTokens
split_normalized_trie: MatchingSet,

/// Whether or not special tokens should be split when encoding. This is
/// equivalent to ignoring them: they are left in place and encoded by the
/// model like regular text.
encode_special_tokens: bool,
}

impl AddedVocabulary {
@@ -180,6 +183,7 @@ impl AddedVocabulary {
special_tokens_set: HashSet::new(),
split_trie: (trie, vec![]),
split_normalized_trie: (normalized_trie, vec![]),
encode_special_tokens: false,
}
}
/// Size of the additional vocabulary
@@ -214,6 +218,15 @@ impl AddedVocabulary {
.or_else(|| model.id_to_token(id))
}

/// Set whether or not special tokens should be split when encoding
pub fn set_encode_special_tokens(&mut self, value: bool) {
self.encode_special_tokens = value;
}

/// Get whether or not special tokens are split when encoding
pub fn get_encode_special_tokens(&self) -> bool {
self.encode_special_tokens
}

/// Check if a token is a special token
pub fn is_special_token(&self, token: &str) -> bool {
self.special_tokens_set.contains(token)
Expand Down Expand Up @@ -356,6 +369,12 @@ impl AddedVocabulary {
let aho_id = mat.pattern();
let id = split_re.1[aho_id];
let added_token = &self.added_tokens_map_r.get(&id).unwrap();

// `encode_special_tokens` is enabled: skip this match so the special
// token stays in the text and is split by the model like any other
// content.
if self.encode_special_tokens && self.special_tokens_set.contains(&added_token.content)
{
continue;
}

if added_token.single_word {
let start_space = start == 0 || !ends_with_word(&sentence[..start]);
let stop_space = stop == sentence.len() || !starts_with_word(&sentence[stop..]);
@@ -900,4 +919,66 @@ mod tests {
]
);
}

#[test]
fn test_encode_special_tokens() {
let model = ModelMock::new(&[]);
let mut vocab = AddedVocabulary::new();
let normalizer = Lowercase;

vocab.add_tokens(
&[
AddedToken::from("<mask>", true)
.lstrip(true)
.rstrip(true)
.single_word(true),
AddedToken::from("ask>", false),
AddedToken::from("<pad>", true),
],
&model,
Some(&normalizer),
);
vocab.set_encode_special_tokens(true);

let result = vocab.extract_and_normalize(
Some(&normalizer),
"Hi <mask> there\t<mask>\t<mask>\u{2000} <pad> <mask><pad><pad>",
);

assert_eq!(
simplify_output(&result),
vec![
("hi <m", None),
("ask>", Some(vec![1])),
(" there\t<m", None),
("ask>", Some(vec![1])),
("\t<m", None),
("ask>", Some(vec![1])),
("\u{2000} <pad> <m", None),
("ask>", Some(vec![1])),
("<pad><pad>", None)
]
);

vocab.set_encode_special_tokens(false);

let result = vocab.extract_and_normalize(
Some(&normalizer),
"Hi <mask> there\t<mask>\t<mask>\u{2000} <pad> <mask><pad><pad>",
);
assert_eq!(
simplify_output(&result),
vec![
("hi", None),
(" <mask> ", Some(vec![0])),
("there", None),
("\t<mask>\t", Some(vec![0])),
("<mask>\u{2000} ", Some(vec![0])),
("<pad>", Some(vec![2])),
(" <mask>", Some(vec![0])),
("<pad>", Some(vec![2])),
("<pad>", Some(vec![2]))
]
);
}
}
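A simplified Python counterpart of the overlap case above (a hypothetical sketch: it assumes hub access for t5-base, drops the Lowercase normalizer and the lstrip/rstrip options, and the exact subword pieces depend on the model):

from tokenizers import Tokenizer

tokenizer = Tokenizer.from_pretrained("t5-base")
tokenizer.add_special_tokens(["<mask>"])
tokenizer.add_tokens(["ask>"])
tokenizer.encode_special_tokens = True

# "<mask>" itself is skipped as a special token, but its tail still
# matches the regular added token "ask>", as in the Rust test above.
output = tokenizer.encode("Hi <mask>", add_special_tokens=False)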
10 changes: 10 additions & 0 deletions tokenizers/src/tokenizer/mod.rs
@@ -685,6 +685,16 @@ where
self.added_vocabulary.id_to_token(id, &self.model)
}

/// Set whether the added vocabulary should split special tokens during encoding
pub fn set_encode_special_tokens(&mut self, value: bool) {
self.added_vocabulary.set_encode_special_tokens(value);
}

/// Get whether special tokens are split during encoding
pub fn get_encode_special_tokens(&self) -> bool {
self.added_vocabulary.get_encode_special_tokens()
}

/// Encode a single sequence
fn encode_single_sequence(
&self,