Skip to content

Commit

Permalink
Add more support for tiktoken based tokenizers (#1493)
Browse files Browse the repository at this point in the history
* first commit

* update

* clippy

* lint

* clippy and lint

* fmt

* revert print

* 😈

* style

* add a test

* more fmt

* Use ignore_merges

* stub

* fix

* update

* Update tokenizers/src/models/bpe/model.rs

Co-authored-by: Nicolas Patry <[email protected]>

* update

* rust lint

* dob; t repeat yourself

---------

Co-authored-by: Nicolas Patry <[email protected]>
  • Loading branch information
ArthurZucker and Narsil authored Apr 15, 2024
1 parent 6e58f83 commit 914576f
Show file tree
Hide file tree
Showing 3 changed files with 134 additions and 8 deletions.
4 changes: 4 additions & 0 deletions bindings/python/py_src/tokenizers/models/__init__.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -112,6 +112,9 @@ class BPE(Model):
byte_fallback (:obj:`bool`, `optional`):
Whether to use spm byte-fallback trick (defaults to False)
ignore_merges (:obj:`bool`, `optional`):
Whether or not to match tokens with the vocab before using merges.
"""
def __init__(
self,
Expand All @@ -124,6 +127,7 @@ class BPE(Model):
end_of_word_suffix=None,
fuse_unk=None,
byte_fallback=False,
ignore_merges=False,
):
pass

Expand Down
14 changes: 13 additions & 1 deletion bindings/python/src/models.rs
Original file line number Diff line number Diff line change
Expand Up @@ -253,6 +253,9 @@ impl PyModel {
///
/// byte_fallback (:obj:`bool`, `optional`):
/// Whether to use spm byte-fallback trick (defaults to False)
///
/// ignore_merges (:obj:`bool`, `optional`):
/// Whether or not to match tokens with the vocab before using merges.
#[pyclass(extends=PyModel, module = "tokenizers.models", name = "BPE")]
pub struct PyBPE {}

Expand All @@ -279,6 +282,7 @@ impl PyBPE {
"end_of_word_suffix" => builder = builder.end_of_word_suffix(value.extract()?),
"fuse_unk" => builder = builder.fuse_unk(value.extract()?),
"byte_fallback" => builder = builder.byte_fallback(value.extract()?),
"ignore_merges" => builder = builder.ignore_merges(value.extract()?),
_ => println!("Ignored unknown kwarg option {}", key),
};
}
Expand Down Expand Up @@ -396,11 +400,19 @@ impl PyBPE {
fn set_byte_fallback(self_: PyRef<Self>, byte_fallback: bool) {
setter!(self_, BPE, byte_fallback, byte_fallback);
}
#[getter]
fn get_ignore_merges(self_: PyRef<Self>) -> bool {
getter!(self_, BPE, ignore_merges)
}

#[setter]
fn set_ignore_merges(self_: PyRef<Self>, ignore_merges: bool) {
setter!(self_, BPE, ignore_merges, ignore_merges);
}
#[new]
#[pyo3(
signature = (vocab=None, merges=None, **kwargs),
text_signature = "(self, vocab=None, merges=None, cache_capacity=None, dropout=None, unk_token=None, continuing_subword_prefix=None, end_of_word_suffix=None, fuse_unk=None, byte_fallback=False)")]
text_signature = "(self, vocab=None, merges=None, cache_capacity=None, dropout=None, unk_token=None, continuing_subword_prefix=None, end_of_word_suffix=None, fuse_unk=None, byte_fallback=False, ignore_merges=False)")]
fn new(
py: Python<'_>,
vocab: Option<PyVocab>,
Expand Down
124 changes: 117 additions & 7 deletions tokenizers/src/models/bpe/model.rs
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@ struct Config {
end_of_word_suffix: Option<String>,
fuse_unk: bool,
byte_fallback: bool,
ignore_merges: bool,
}

/// A `BpeBuilder` can be used to create a `BPE` model with a custom configuration.
Expand All @@ -49,6 +50,7 @@ impl Default for BpeBuilder {
end_of_word_suffix: None,
fuse_unk: false,
byte_fallback: false,
ignore_merges: false,
},
}
}
Expand Down Expand Up @@ -123,6 +125,12 @@ impl BpeBuilder {
self.config.byte_fallback = byte_fallback;
self
}
/// Set the `ignore_merges` option.
#[must_use]
pub fn ignore_merges(mut self, ignore_merges: bool) -> Self {
self.config.ignore_merges = ignore_merges;
self
}

/// Returns a `BPE` model that uses the `BpeBuilder`'s configuration.
pub fn build(mut self) -> Result<BPE> {
Expand Down Expand Up @@ -190,6 +198,7 @@ impl BpeBuilder {
end_of_word_suffix: self.config.end_of_word_suffix,
fuse_unk: self.config.fuse_unk,
byte_fallback: self.config.byte_fallback,
ignore_merges: self.config.ignore_merges,
})
}
}
Expand Down Expand Up @@ -219,6 +228,8 @@ pub struct BPE {
/// Byte fallback from sentence pieces, instead of UNK, uses `"<0x00>"`
/// for each byte in the unk token
pub byte_fallback: bool,
/// Whether or not to direct output words if they are part of the vocab.
pub ignore_merges: bool,
}

impl std::fmt::Debug for BPE {
Expand All @@ -232,6 +243,7 @@ impl std::fmt::Debug for BPE {
.field("byte_fallback", &self.byte_fallback)
.field("vocab", &self.vocab.len())
.field("merges", &self.merges.len())
.field("ignore_merges", &self.ignore_merges)
.finish()
}
}
Expand All @@ -258,6 +270,7 @@ impl Clone for BPE {
end_of_word_suffix: self.end_of_word_suffix.clone(),
fuse_unk: self.fuse_unk,
byte_fallback: self.byte_fallback,
ignore_merges: self.ignore_merges,
}
}
}
Expand Down Expand Up @@ -448,15 +461,19 @@ impl BPE {

fn tokenize_with_cache(&self, sequence: &str) -> Result<Vec<Token>> {
if let Some(ref hit) = self.cache.as_ref().and_then(|c| c.get(sequence)) {
Ok(self.word_to_tokens(hit).collect())
} else {
let word = self.merge_word(sequence)?;
let ret = self.word_to_tokens(&word).collect();
if let Some(ref cache) = self.cache {
cache.set(sequence.to_owned(), word);
return Ok(self.word_to_tokens(hit).collect());
}
if self.ignore_merges {
if let Some(id) = self.vocab.get(sequence) {
return Ok(vec![Token::new(*id, sequence.to_string().clone(), (0, 0))]);
}
Ok(ret)
}
let word = self.merge_word(sequence)?;
let ret = self.word_to_tokens(&word).collect();
if let Some(ref cache) = self.cache {
cache.set(sequence.to_owned(), word);
}
Ok(ret)
}
}

Expand Down Expand Up @@ -862,4 +879,97 @@ mod tests {
let tokens = bpe.tokenize("\n").unwrap();
assert_eq!(tokens, vec![Token::new(1u32, "<0x0A>".into(), (0, 1)),]);
}

#[test]
fn test_ignore_merges() {
// 0x0A == '\n' in bytes
let vocab: Vocab = [
(".:.:".into(), 0),
("Ġbelirtilen".into(), 1),
(".".into(), 2),
(":".into(), 3),
("bel".into(), 4),
("irtilen".into(), 5),
("Ġ".into(), 6),
(".:".into(), 7),
("belirtilen".into(), 8),
(".:.".into(), 9),
("be".into(), 10),
("l".into(), 11),
("ir".into(), 12),
("ti".into(), 13),
("en".into(), 14),
("irtil".into(), 15),
("irti".into(), 16),
("i".into(), 17),
("r".into(), 18),
("t".into(), 19),
("b".into(), 20),
("e".into(), 21),
("n".into(), 22),
]
.iter()
.cloned()
.collect();
let mut bpe = BpeBuilder::default()
.vocab_and_merges(
vocab,
vec![
(".".into(), ":".into()),
("b".into(), "e".into()),
("be".into(), "l".into()),
("i".into(), "r".into()),
("t".into(), "i".into()),
("ir".into(), "ti".into()),
("e".into(), "n".into()),
("irti".into(), "l".into()),
],
)
.ignore_merges(true)
.build()
.unwrap();
let tokens = bpe.tokenize(".:.:").unwrap();
assert_eq!(tokens, vec![Token::new(0u32, ".:.:".into(), (0, 0))]);

let tokens = bpe.tokenize("Ġbelirtilen").unwrap();
assert_eq!(tokens, vec![Token::new(1u32, "Ġbelirtilen".into(), (0, 0))]);

bpe.ignore_merges = false;

let tokens = bpe.tokenize(".:.:").unwrap();
assert_eq!(
tokens,
vec![
Token::new(7u32, ".:".into(), (0, 2)),
Token::new(7u32, ".:".into(), (2, 4))
]
);

let tokens = bpe.tokenize("Ġbelirtilen").unwrap();
assert_eq!(
tokens,
vec![
Token {
id: 6,
value: "Ġ".into(),
offsets: (0, 2)
},
Token {
id: 4,
value: "bel".into(),
offsets: (2, 5)
},
Token {
id: 15,
value: "irtil".into(),
offsets: (5, 10)
},
Token {
id: 14,
value: "en".into(),
offsets: (10, 12)
}
]
)
}
}

0 comments on commit 914576f

Please sign in to comment.