Commit

Support None to reset pre_tokenizers and normalizers, and index sequences (#1590)

* initial commit

* support None

* fix clippy

* cleanup

* clean?

* propagate to pre_tokenizer

* fix test

* fix rust tests

* fix node

* propagate to decoder and post processor

* fix calls

* lint

* fmt

* node be happy I am fixing you

* add a small test

* styling

* style merge

* fix merge test

* fmt

* nits

* update test
ArthurZucker authored Aug 7, 2024
1 parent eea8e1a commit bded212
Showing 13 changed files with 134 additions and 67 deletions.
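
Taken together, the diff lets Python callers clear a pipeline component by assigning None and index into Sequence containers. A minimal sketch of the new surface, mirroring the tests further down (the snippet itself is illustrative, not part of the commit):

from tokenizers import Tokenizer
from tokenizers.models import BPE
from tokenizers.normalizers import Sequence, NFKC, Lowercase

tokenizer = Tokenizer(BPE())
tokenizer.normalizer = Sequence([NFKC(), Lowercase()])

# New: assigning None clears the component instead of being rejected.
tokenizer.normalizer = None
assert tokenizer.normalizer is None

# New: Sequence normalizers and pre-tokenizers support indexing.
assert Sequence([NFKC(), Lowercase()])[1].__class__ == Lowercase
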
8 changes: 4 additions & 4 deletions bindings/node/src/tokenizer.rs
@@ -208,7 +208,7 @@ impl Tokenizer {
.tokenizer
.write()
.unwrap()
.with_pre_tokenizer((*pre_tokenizer).clone());
.with_pre_tokenizer(Some((*pre_tokenizer).clone()));
}

#[napi]
@@ -217,7 +217,7 @@ impl Tokenizer {
.tokenizer
.write()
.unwrap()
.with_decoder((*decoder).clone());
.with_decoder(Some((*decoder).clone()));
}

#[napi]
@@ -231,7 +231,7 @@ impl Tokenizer {
.tokenizer
.write()
.unwrap()
.with_post_processor((*post_processor).clone());
.with_post_processor(Some((*post_processor).clone()));
}

#[napi]
@@ -240,7 +240,7 @@ impl Tokenizer {
.tokenizer
.write()
.unwrap()
.with_normalizer((*normalizer).clone());
.with_normalizer(Some((*normalizer).clone()));
}

#[napi]
23 changes: 19 additions & 4 deletions bindings/python/src/normalizers.rs
@@ -1,8 +1,6 @@
use std::sync::{Arc, RwLock};

use pyo3::exceptions;
use pyo3::prelude::*;
use pyo3::types::*;
use pyo3::{exceptions, prelude::*};
use std::sync::{Arc, RwLock};

use crate::error::ToPyResult;
use crate::utils::{PyNormalizedString, PyNormalizedStringRefMut, PyPattern};
@@ -354,6 +352,7 @@ impl PyNFKC {
/// A list of Normalizer to be run as a sequence
#[pyclass(extends=PyNormalizer, module = "tokenizers.normalizers", name = "Sequence")]
pub struct PySequence {}

#[pymethods]
impl PySequence {
#[new]
@@ -380,6 +379,22 @@ impl PySequence {
fn __len__(&self) -> usize {
0
}

fn __getitem__(self_: PyRef<'_, Self>, py: Python<'_>, index: usize) -> PyResult<Py<PyAny>> {
match &self_.as_ref().normalizer {
PyNormalizerTypeWrapper::Sequence(inner) => match inner.get(index) {
Some(item) => PyNormalizer::new(PyNormalizerTypeWrapper::Single(Arc::clone(item)))
.get_as_subtype(py),
_ => Err(PyErr::new::<pyo3::exceptions::PyIndexError, _>(
"Index not found",
)),
},
PyNormalizerTypeWrapper::Single(inner) => {
PyNormalizer::new(PyNormalizerTypeWrapper::Single(Arc::clone(inner)))
.get_as_subtype(py)
}
}
}
}

/// Lowercase Normalizer
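
For reference, a short Python-side sketch of what the __getitem__ above enables on normalizers.Sequence (the Strip/Lowercase combination here is illustrative; the new test below uses BertNormalizer and Prepend):

from tokenizers.normalizers import Sequence, Strip, Lowercase

seq = Sequence([Strip(), Lowercase()])
assert seq.normalize_str("  HELLO  ") == "hello"

# Each step is returned as its concrete subtype, and out-of-range
# indices raise an IndexError ("Index not found").
assert seq[0].__class__ == Strip
assert seq[1].__class__ == Lowercase
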
18 changes: 18 additions & 0 deletions bindings/python/src/pre_tokenizers.rs
@@ -463,6 +463,24 @@ impl PySequence {
fn __getnewargs__<'p>(&self, py: Python<'p>) -> Bound<'p, PyTuple> {
PyTuple::new_bound(py, [PyList::empty_bound(py)])
}

fn __getitem__(self_: PyRef<'_, Self>, py: Python<'_>, index: usize) -> PyResult<Py<PyAny>> {
match &self_.as_ref().pretok {
PyPreTokenizerTypeWrapper::Sequence(inner) => match inner.get(index) {
Some(item) => {
PyPreTokenizer::new(PyPreTokenizerTypeWrapper::Single(Arc::clone(item)))
.get_as_subtype(py)
}
_ => Err(PyErr::new::<pyo3::exceptions::PyIndexError, _>(
"Index not found",
)),
},
PyPreTokenizerTypeWrapper::Single(inner) => {
PyPreTokenizer::new(PyPreTokenizerTypeWrapper::Single(Arc::clone(inner)))
.get_as_subtype(py)
}
}
}
}

pub(crate) fn from_string(string: String) -> Result<PrependScheme, PyErr> {
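
The pre-tokenizer Sequence gets the same treatment; a sketch mirroring the new test_items test below (argument values are illustrative):

from tokenizers.pre_tokenizers import Sequence, Metaspace, Punctuation

pre = Sequence([Metaspace("▁", "never", split=True), Punctuation()])

# Elements come back as their concrete subtypes...
assert pre[0].__class__ == Metaspace
assert pre[1].__class__ == Punctuation

# ...and attributes can be flipped through the returned handle.
pre[0].split = False
assert not pre[0].split
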
39 changes: 23 additions & 16 deletions bindings/python/src/tokenizer.rs
@@ -1371,8 +1371,9 @@ impl PyTokenizer {

/// Set the :class:`~tokenizers.normalizers.Normalizer`
#[setter]
fn set_normalizer(&mut self, normalizer: PyRef<PyNormalizer>) {
self.tokenizer.with_normalizer(normalizer.clone());
fn set_normalizer(&mut self, normalizer: Option<PyRef<PyNormalizer>>) {
let normalizer_option = normalizer.map(|norm| norm.clone());
self.tokenizer.with_normalizer(normalizer_option);
}

/// The `optional` :class:`~tokenizers.pre_tokenizers.PreTokenizer` in use by the Tokenizer
@@ -1387,8 +1388,9 @@ impl PyTokenizer {

/// Set the :class:`~tokenizers.normalizers.Normalizer`
#[setter]
fn set_pre_tokenizer(&mut self, pretok: PyRef<PyPreTokenizer>) {
self.tokenizer.with_pre_tokenizer(pretok.clone());
fn set_pre_tokenizer(&mut self, pretok: Option<PyRef<PyPreTokenizer>>) {
self.tokenizer
.with_pre_tokenizer(pretok.map(|pre| pre.clone()));
}

/// The `optional` :class:`~tokenizers.processors.PostProcessor` in use by the Tokenizer
@@ -1403,8 +1405,9 @@ impl PyTokenizer {

/// Set the :class:`~tokenizers.processors.PostProcessor`
#[setter]
fn set_post_processor(&mut self, processor: PyRef<PyPostProcessor>) {
self.tokenizer.with_post_processor(processor.clone());
fn set_post_processor(&mut self, processor: Option<PyRef<PyPostProcessor>>) {
self.tokenizer
.with_post_processor(processor.map(|p| p.clone()));
}

/// The `optional` :class:`~tokenizers.decoders.Decoder` in use by the Tokenizer
@@ -1419,8 +1422,8 @@ impl PyTokenizer {

/// Set the :class:`~tokenizers.decoders.Decoder`
#[setter]
fn set_decoder(&mut self, decoder: PyRef<PyDecoder>) {
self.tokenizer.with_decoder(decoder.clone());
fn set_decoder(&mut self, decoder: Option<PyRef<PyDecoder>>) {
self.tokenizer.with_decoder(decoder.map(|d| d.clone()));
}
}

@@ -1436,10 +1439,12 @@ mod test {
#[test]
fn serialize() {
let mut tokenizer = Tokenizer::new(PyModel::from(BPE::default()));
tokenizer.with_normalizer(PyNormalizer::new(PyNormalizerTypeWrapper::Sequence(vec![
Arc::new(RwLock::new(NFKC.into())),
Arc::new(RwLock::new(Lowercase.into())),
])));
tokenizer.with_normalizer(Some(PyNormalizer::new(PyNormalizerTypeWrapper::Sequence(
vec![
Arc::new(RwLock::new(NFKC.into())),
Arc::new(RwLock::new(Lowercase.into())),
],
))));

let tmp = NamedTempFile::new().unwrap().into_temp_path();
tokenizer.save(&tmp, false).unwrap();
@@ -1450,10 +1455,12 @@ mod test {
#[test]
fn serde_pyo3() {
let mut tokenizer = Tokenizer::new(PyModel::from(BPE::default()));
tokenizer.with_normalizer(PyNormalizer::new(PyNormalizerTypeWrapper::Sequence(vec![
Arc::new(RwLock::new(NFKC.into())),
Arc::new(RwLock::new(Lowercase.into())),
])));
tokenizer.with_normalizer(Some(PyNormalizer::new(PyNormalizerTypeWrapper::Sequence(
vec![
Arc::new(RwLock::new(NFKC.into())),
Arc::new(RwLock::new(Lowercase.into())),
],
))));

let output = crate::utils::serde_pyo3::to_string(&tokenizer).unwrap();
assert_eq!(output, "Tokenizer(version=\"1.0\", truncation=None, padding=None, added_tokens=[], normalizer=Sequence(normalizers=[NFKC(), Lowercase()]), pre_tokenizer=None, post_processor=None, decoder=None, model=BPE(dropout=None, unk_token=None, continuing_subword_prefix=None, end_of_word_suffix=None, fuse_unk=False, byte_fallback=False, ignore_merges=False, vocab={}, merges=[]))");
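
At the binding level every pipeline setter now takes an Option, so from Python each component can be cleared by assigning None. The normalizer and pre_tokenizer cases are covered by the new test_setting_to_none test below; the post_processor and decoder lines here are an assumption based on the setter signatures above:

from tokenizers import Tokenizer
from tokenizers.models import BPE
from tokenizers.normalizers import Lowercase
from tokenizers.pre_tokenizers import Whitespace
from tokenizers.processors import ByteLevel as ByteLevelProcessor
from tokenizers.decoders import ByteLevel as ByteLevelDecoder

tok = Tokenizer(BPE())
tok.normalizer = Lowercase()
tok.pre_tokenizer = Whitespace()
tok.post_processor = ByteLevelProcessor()
tok.decoder = ByteLevelDecoder()

tok.normalizer = None        # resets the normalizer
tok.pre_tokenizer = None     # resets the pre-tokenizer
tok.post_processor = None    # assumed to behave the same (Option-based setter)
tok.decoder = None           # assumed to behave the same (Option-based setter)

assert tok.normalizer is None and tok.pre_tokenizer is None
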
8 changes: 8 additions & 0 deletions bindings/python/tests/bindings/test_normalizers.py
@@ -67,6 +67,14 @@ def test_can_make_sequences(self):
output = normalizer.normalize_str(" HELLO ")
assert output == "hello"

def test_items(self):
normalizers = Sequence([BertNormalizer(True, True), Prepend()])
assert normalizers[1].__class__ == Prepend
normalizers[0].lowercase = False
assert not normalizers[0].lowercase
with pytest.raises(IndexError):
print(normalizers[2])


class TestLowercase:
def test_instantiate(self):
7 changes: 7 additions & 0 deletions bindings/python/tests/bindings/test_pre_tokenizers.py
@@ -169,6 +169,13 @@ def test_bert_like(self):
("?", (29, 30)),
]

def test_items(self):
pre_tokenizers = Sequence([Metaspace("a", "never", split=True), Punctuation()])
assert pre_tokenizers[1].__class__ == Punctuation
assert pre_tokenizers[0].__class__ == Metaspace
pre_tokenizers[0].split = False
assert not pre_tokenizers[0].split


class TestDigits:
def test_instantiate(self):
13 changes: 12 additions & 1 deletion bindings/python/tests/bindings/test_tokenizer.py
@@ -6,10 +6,11 @@
from tokenizers import AddedToken, Encoding, Tokenizer
from tokenizers.implementations import BertWordPieceTokenizer
from tokenizers.models import BPE, Model, Unigram
from tokenizers.pre_tokenizers import ByteLevel
from tokenizers.pre_tokenizers import ByteLevel, Metaspace
from tokenizers.processors import RobertaProcessing, TemplateProcessing
from tokenizers.normalizers import Strip, Lowercase, Sequence


from ..utils import bert_files, data_dir, multiprocessing_with_parallelism, roberta_files


@@ -551,6 +552,16 @@ def test_decode_special(self):
assert output == "name is john"
assert tokenizer.get_added_tokens_decoder()[0] == AddedToken("my", special=True)

def test_setting_to_none(self):
tokenizer = Tokenizer(BPE())
tokenizer.normalizer = Strip()
tokenizer.normalizer = None
assert tokenizer.normalizer == None

tokenizer.pre_tokenizer = Metaspace()
tokenizer.pre_tokenizer = None
assert tokenizer.pre_tokenizer == None


class TestTokenizerRepr:
def test_repr(self):
14 changes: 7 additions & 7 deletions tokenizers/benches/bert_benchmark.rs
@@ -34,13 +34,13 @@ fn create_bert_tokenizer(wp: WordPiece) -> BertTokenizer {
let sep_id = *wp.get_vocab().get("[SEP]").unwrap();
let cls_id = *wp.get_vocab().get("[CLS]").unwrap();
let mut tokenizer = TokenizerImpl::new(wp);
tokenizer.with_pre_tokenizer(BertPreTokenizer);
tokenizer.with_normalizer(BertNormalizer::default());
tokenizer.with_decoder(decoders::wordpiece::WordPiece::default());
tokenizer.with_post_processor(BertProcessing::new(
tokenizer.with_pre_tokenizer(Some(BertPreTokenizer));
tokenizer.with_normalizer(Some(BertNormalizer::default()));
tokenizer.with_decoder(Some(decoders::wordpiece::WordPiece::default()));
tokenizer.with_post_processor(Some(BertProcessing::new(
("[SEP]".to_string(), sep_id),
("[CLS]".to_string(), cls_id),
));
)));
tokenizer
}

@@ -81,7 +81,7 @@ fn bench_train(c: &mut Criterion) {
DecoderWrapper,
>;
let mut tokenizer = Tok::new(WordPiece::default());
tokenizer.with_pre_tokenizer(Whitespace {});
tokenizer.with_pre_tokenizer(Some(Whitespace {}));
c.bench_function("WordPiece Train vocabulary (small)", |b| {
b.iter_custom(|iters| {
iter_bench_train(
@@ -94,7 +94,7 @@
});

let mut tokenizer = Tok::new(WordPiece::default());
tokenizer.with_pre_tokenizer(Whitespace {});
tokenizer.with_pre_tokenizer(Some(Whitespace {}));
c.bench_function("WordPiece Train vocabulary (big)", |b| {
b.iter_custom(|iters| {
iter_bench_train(
8 changes: 4 additions & 4 deletions tokenizers/benches/bpe_benchmark.rs
@@ -22,8 +22,8 @@ static BATCH_SIZE: usize = 1_000;

fn create_gpt2_tokenizer(bpe: BPE) -> Tokenizer {
let mut tokenizer = Tokenizer::new(bpe);
tokenizer.with_pre_tokenizer(ByteLevel::default());
tokenizer.with_decoder(ByteLevel::default());
tokenizer.with_pre_tokenizer(Some(ByteLevel::default()));
tokenizer.with_decoder(Some(ByteLevel::default()));
tokenizer.add_tokens(&[AddedToken::from("ing", false).single_word(false)]);
tokenizer.add_special_tokens(&[AddedToken::from("[ENT]", true).single_word(true)]);
tokenizer
@@ -74,7 +74,7 @@ fn bench_train(c: &mut Criterion) {
.build()
.into();
let mut tokenizer = Tokenizer::new(BPE::default()).into_inner();
tokenizer.with_pre_tokenizer(Whitespace {});
tokenizer.with_pre_tokenizer(Some(Whitespace {}));
c.bench_function("BPE Train vocabulary (small)", |b| {
b.iter_custom(|iters| {
iter_bench_train(
@@ -87,7 +87,7 @@
});

let mut tokenizer = Tokenizer::new(BPE::default()).into_inner();
tokenizer.with_pre_tokenizer(Whitespace {});
tokenizer.with_pre_tokenizer(Some(Whitespace {}));
c.bench_function("BPE Train vocabulary (big)", |b| {
b.iter_custom(|iters| {
iter_bench_train(
17 changes: 8 additions & 9 deletions tokenizers/src/tokenizer/mod.rs
@@ -550,19 +550,18 @@ where
}

/// Set the normalizer
pub fn with_normalizer(&mut self, normalizer: impl Into<N>) -> &mut Self {
self.normalizer = Some(normalizer.into());
pub fn with_normalizer(&mut self, normalizer: Option<impl Into<N>>) -> &mut Self {
self.normalizer = normalizer.map(|norm| norm.into());
self
}

/// Get the normalizer
pub fn get_normalizer(&self) -> Option<&N> {
self.normalizer.as_ref()
}

/// Set the pre tokenizer
pub fn with_pre_tokenizer(&mut self, pre_tokenizer: impl Into<PT>) -> &mut Self {
self.pre_tokenizer = Some(pre_tokenizer.into());
pub fn with_pre_tokenizer(&mut self, pre_tokenizer: Option<impl Into<PT>>) -> &mut Self {
self.pre_tokenizer = pre_tokenizer.map(|tok| tok.into());
self
}

@@ -572,8 +571,8 @@
}

/// Set the post processor
pub fn with_post_processor(&mut self, post_processor: impl Into<PP>) -> &mut Self {
self.post_processor = Some(post_processor.into());
pub fn with_post_processor(&mut self, post_processor: Option<impl Into<PP>>) -> &mut Self {
self.post_processor = post_processor.map(|post_proc| post_proc.into());
self
}

@@ -583,8 +582,8 @@
}

/// Set the decoder
pub fn with_decoder(&mut self, decoder: impl Into<D>) -> &mut Self {
self.decoder = Some(decoder.into());
pub fn with_decoder(&mut self, decoder: Option<impl Into<D>>) -> &mut Self {
self.decoder = decoder.map(|dec| dec.into());
self
}

18 changes: 10 additions & 8 deletions tokenizers/tests/common/mod.rs
@@ -23,9 +23,11 @@ pub fn get_byte_level_bpe() -> BPE {
pub fn get_byte_level(add_prefix_space: bool, trim_offsets: bool) -> Tokenizer {
let mut tokenizer = Tokenizer::new(get_byte_level_bpe());
tokenizer
.with_pre_tokenizer(ByteLevel::default().add_prefix_space(add_prefix_space))
.with_decoder(ByteLevel::default())
.with_post_processor(ByteLevel::default().trim_offsets(trim_offsets));
.with_pre_tokenizer(Some(
ByteLevel::default().add_prefix_space(add_prefix_space),
))
.with_decoder(Some(ByteLevel::default()))
.with_post_processor(Some(ByteLevel::default().trim_offsets(trim_offsets)));

tokenizer
}
@@ -43,13 +45,13 @@ pub fn get_bert() -> Tokenizer {
let sep = tokenizer.get_model().token_to_id("[SEP]").unwrap();
let cls = tokenizer.get_model().token_to_id("[CLS]").unwrap();
tokenizer
.with_normalizer(BertNormalizer::default())
.with_pre_tokenizer(BertPreTokenizer)
.with_decoder(WordPieceDecoder::default())
.with_post_processor(BertProcessing::new(
.with_normalizer(Some(BertNormalizer::default()))
.with_pre_tokenizer(Some(BertPreTokenizer))
.with_decoder(Some(WordPieceDecoder::default()))
.with_post_processor(Some(BertProcessing::new(
(String::from("[SEP]"), sep),
(String::from("[CLS]"), cls),
));
)));

tokenizer
}
