diff --git a/bindings/python/py_src/tokenizers/models/__init__.pyi b/bindings/python/py_src/tokenizers/models/__init__.pyi index b46f32f25..955b9a163 100644 --- a/bindings/python/py_src/tokenizers/models/__init__.pyi +++ b/bindings/python/py_src/tokenizers/models/__init__.pyi @@ -112,6 +112,9 @@ class BPE(Model): byte_fallback (:obj:`bool`, `optional`): Whether to use spm byte-fallback trick (defaults to False) + + ignore_merges (:obj:`bool`, `optional`): + Whether or not to match tokens with the vocab before using merges. """ def __init__( self, @@ -124,6 +127,7 @@ class BPE(Model): end_of_word_suffix=None, fuse_unk=None, byte_fallback=False, + ignore_merges=False, ): pass diff --git a/bindings/python/src/models.rs b/bindings/python/src/models.rs index 0c7bafe42..8fce02c94 100644 --- a/bindings/python/src/models.rs +++ b/bindings/python/src/models.rs @@ -253,6 +253,9 @@ impl PyModel { /// /// byte_fallback (:obj:`bool`, `optional`): /// Whether to use spm byte-fallback trick (defaults to False) +/// +/// ignore_merges (:obj:`bool`, `optional`): +/// Whether or not to match tokens with the vocab before using merges. 
#[pyclass(extends=PyModel, module = "tokenizers.models", name = "BPE")] pub struct PyBPE {} @@ -279,6 +282,7 @@ impl PyBPE { "end_of_word_suffix" => builder = builder.end_of_word_suffix(value.extract()?), "fuse_unk" => builder = builder.fuse_unk(value.extract()?), "byte_fallback" => builder = builder.byte_fallback(value.extract()?), + "ignore_merges" => builder = builder.ignore_merges(value.extract()?), _ => println!("Ignored unknown kwarg option {}", key), }; } @@ -396,11 +400,19 @@ impl PyBPE { fn set_byte_fallback(self_: PyRef, byte_fallback: bool) { setter!(self_, BPE, byte_fallback, byte_fallback); } + #[getter] + fn get_ignore_merges(self_: PyRef) -> bool { + getter!(self_, BPE, ignore_merges) + } + #[setter] + fn set_ignore_merges(self_: PyRef, ignore_merges: bool) { + setter!(self_, BPE, ignore_merges, ignore_merges); + } #[new] #[pyo3( signature = (vocab=None, merges=None, **kwargs), - text_signature = "(self, vocab=None, merges=None, cache_capacity=None, dropout=None, unk_token=None, continuing_subword_prefix=None, end_of_word_suffix=None, fuse_unk=None, byte_fallback=False)")] + text_signature = "(self, vocab=None, merges=None, cache_capacity=None, dropout=None, unk_token=None, continuing_subword_prefix=None, end_of_word_suffix=None, fuse_unk=None, byte_fallback=False, ignore_merges=False)")] fn new( py: Python<'_>, vocab: Option, diff --git a/tokenizers/src/models/bpe/model.rs b/tokenizers/src/models/bpe/model.rs index 2fc9915ca..618f42b47 100644 --- a/tokenizers/src/models/bpe/model.rs +++ b/tokenizers/src/models/bpe/model.rs @@ -28,6 +28,7 @@ struct Config { end_of_word_suffix: Option, fuse_unk: bool, byte_fallback: bool, + ignore_merges: bool, } /// A `BpeBuilder` can be used to create a `BPE` model with a custom configuration. 
@@ -49,6 +50,7 @@ impl Default for BpeBuilder { end_of_word_suffix: None, fuse_unk: false, byte_fallback: false, + ignore_merges: false, }, } } @@ -123,6 +125,12 @@ impl BpeBuilder { self.config.byte_fallback = byte_fallback; self } + /// Set the `ignore_merges` option. + #[must_use] + pub fn ignore_merges(mut self, ignore_merges: bool) -> Self { + self.config.ignore_merges = ignore_merges; + self + } /// Returns a `BPE` model that uses the `BpeBuilder`'s configuration. pub fn build(mut self) -> Result { @@ -190,6 +198,7 @@ impl BpeBuilder { end_of_word_suffix: self.config.end_of_word_suffix, fuse_unk: self.config.fuse_unk, byte_fallback: self.config.byte_fallback, + ignore_merges: self.config.ignore_merges, }) } } @@ -219,6 +228,8 @@ pub struct BPE { /// Byte fallback from sentence pieces, instead of UNK, uses `"<0x00>"` /// for each byte in the unk token pub byte_fallback: bool, + /// Whether or not to directly output words if they are part of the vocab. + pub ignore_merges: bool, } impl std::fmt::Debug for BPE { @@ -232,6 +243,7 @@ impl std::fmt::Debug for BPE { .field("byte_fallback", &self.byte_fallback) .field("vocab", &self.vocab.len()) .field("merges", &self.merges.len()) + .field("ignore_merges", &self.ignore_merges) .finish() } } @@ -258,6 +270,7 @@ impl Clone for BPE { end_of_word_suffix: self.end_of_word_suffix.clone(), fuse_unk: self.fuse_unk, byte_fallback: self.byte_fallback, + ignore_merges: self.ignore_merges, } } } @@ -448,15 +461,19 @@ impl BPE { fn tokenize_with_cache(&self, sequence: &str) -> Result> { if let Some(ref hit) = self.cache.as_ref().and_then(|c| c.get(sequence)) { - Ok(self.word_to_tokens(hit).collect()) - } else { - let word = self.merge_word(sequence)?; - let ret = self.word_to_tokens(&word).collect(); - if let Some(ref cache) = self.cache { - cache.set(sequence.to_owned(), word); + return Ok(self.word_to_tokens(hit).collect()); + } + if self.ignore_merges { + if let Some(id) = self.vocab.get(sequence) { + return 
Ok(vec![Token::new(*id, sequence.to_string().clone(), (0, 0))]); } - Ok(ret) } + let word = self.merge_word(sequence)?; + let ret = self.word_to_tokens(&word).collect(); + if let Some(ref cache) = self.cache { + cache.set(sequence.to_owned(), word); + } + Ok(ret) } } @@ -862,4 +879,97 @@ mod tests { let tokens = bpe.tokenize("\n").unwrap(); assert_eq!(tokens, vec![Token::new(1u32, "<0x0A>".into(), (0, 1)),]); } + + #[test] + fn test_ignore_merges() { + // With ignore_merges, a sequence found verbatim in the vocab is emitted + // as a single token, bypassing the merge loop entirely. + let vocab: Vocab = [ + (".:.:".into(), 0), + ("Ġbelirtilen".into(), 1), + (".".into(), 2), + (":".into(), 3), + ("bel".into(), 4), + ("irtilen".into(), 5), + ("Ġ".into(), 6), + (".:".into(), 7), + ("belirtilen".into(), 8), + (".:.".into(), 9), + ("be".into(), 10), + ("l".into(), 11), + ("ir".into(), 12), + ("ti".into(), 13), + ("en".into(), 14), + ("irtil".into(), 15), + ("irti".into(), 16), + ("i".into(), 17), + ("r".into(), 18), + ("t".into(), 19), + ("b".into(), 20), + ("e".into(), 21), + ("n".into(), 22), + ] + .iter() + .cloned() + .collect(); + let mut bpe = BpeBuilder::default() + .vocab_and_merges( + vocab, + vec![ + (".".into(), ":".into()), + ("b".into(), "e".into()), + ("be".into(), "l".into()), + ("i".into(), "r".into()), + ("t".into(), "i".into()), + ("ir".into(), "ti".into()), + ("e".into(), "n".into()), + ("irti".into(), "l".into()), + ], + ) + .ignore_merges(true) + .build() + .unwrap(); + let tokens = bpe.tokenize(".:.:").unwrap(); + assert_eq!(tokens, vec![Token::new(0u32, ".:.:".into(), (0, 0))]); + + let tokens = bpe.tokenize("Ġbelirtilen").unwrap(); + assert_eq!(tokens, vec![Token::new(1u32, "Ġbelirtilen".into(), (0, 0))]); + + bpe.ignore_merges = false; + + let tokens = bpe.tokenize(".:.:").unwrap(); + assert_eq!( + tokens, + vec![ + Token::new(7u32, ".:".into(), (0, 2)), + Token::new(7u32, ".:".into(), (2, 4)) + ] + ); + + let tokens = bpe.tokenize("Ġbelirtilen").unwrap(); + assert_eq!( + tokens, + vec![ + Token { + id: 6, + value: "Ġ".into(), + offsets: (0, 2) + }, + 
Token { + id: 4, + value: "bel".into(), + offsets: (2, 5) + }, + Token { + id: 15, + value: "irtil".into(), + offsets: (5, 10) + }, + Token { + id: 14, + value: "en".into(), + offsets: (10, 12) + } + ] + ) + } }