From 57f8241dc07bafa938f63e224482e20154e0cf5b Mon Sep 17 00:00:00 2001 From: Marco Cognetta Date: Fri, 7 Jun 2024 22:30:26 +0900 Subject: [PATCH 1/5] enable dropout = 0.0 --- bindings/python/tests/bindings/test_models.py | 3 +++ tokenizers/src/models/bpe/mod.rs | 2 +- tokenizers/src/models/bpe/model.rs | 18 +++++++++++++++--- tokenizers/tests/serialization.rs | 15 +++++++++++++++ 4 files changed, 34 insertions(+), 4 deletions(-) diff --git a/bindings/python/tests/bindings/test_models.py b/bindings/python/tests/bindings/test_models.py index c6a50ce86..7b6746186 100644 --- a/bindings/python/tests/bindings/test_models.py +++ b/bindings/python/tests/bindings/test_models.py @@ -69,6 +69,9 @@ def test_can_modify(self): model.byte_fallback = True assert model.byte_fallback == True + def test_dropout_zero(self): + model = BPE(dropout=0.0) + assert model.dropout == 0.0 class TestWordPiece: def test_instantiate(self, bert_files): diff --git a/tokenizers/src/models/bpe/mod.rs b/tokenizers/src/models/bpe/mod.rs index 51e7b6ca0..f0d40b2df 100644 --- a/tokenizers/src/models/bpe/mod.rs +++ b/tokenizers/src/models/bpe/mod.rs @@ -31,7 +31,7 @@ pub enum Error { #[error("Unk token `{0}` not found in the vocabulary")] UnkTokenOutOfVocabulary(String), /// Dropout not between 0 and 1. - #[error("Dropout should be between 0 and 1")] + #[error("Dropout should be between 0 and 1, inclusive")] InvalidDropout, } diff --git a/tokenizers/src/models/bpe/model.rs b/tokenizers/src/models/bpe/model.rs index 618f42b47..78ce25684 100644 --- a/tokenizers/src/models/bpe/model.rs +++ b/tokenizers/src/models/bpe/model.rs @@ -136,7 +136,7 @@ impl BpeBuilder { pub fn build(mut self) -> Result { // Validate dropout. if let Some(p) = self.config.dropout { - if p <= 0.0 || p > 1.0 { + if p < 0.0 || p > 1.0 { return Err(Error::InvalidDropout.into()); } } @@ -214,7 +214,7 @@ pub struct BPE { pub(crate) merges: MergeMap, /// Contains the cache for optimizing the encoding step. cache: Option>, - /// Dropout probability for merges. 0 = no dropout is the default. At 1.0, tokenization will + /// Dropout probability for merges. 0.0 = no dropout is the default. At 1.0, tokenization will /// perform no merges, so the result will just be characters. pub dropout: Option, /// The unknown token to be used when we encounter an unknown char @@ -493,7 +493,7 @@ impl Model for BPE { return Ok(vec![]); } - if self.dropout.is_none() { + if self.dropout.is_none() || self.dropout == Some(0.0) { self.tokenize_with_cache(sequence) } else { let word = self.merge_word(sequence)?; @@ -685,6 +685,11 @@ mod tests { let tokens = bpe.tokenize("unrelated").unwrap(); assert_eq!(tokens, vec![Token::new(15u32, "unrelated".into(), (0, 9))]); + // With dropout = 0.0 (equivalent to dropout == none) + bpe.dropout = Some(0.0); + let tokens = bpe.tokenize("unrelated").unwrap(); + assert_eq!(tokens, vec![Token::new(15u32, "unrelated".into(), (0, 9))]); + // Now set dropout to 1.0. Result should be no merges performed. bpe.dropout = Some(1.0); let tokens = bpe.tokenize("unrelated").unwrap(); @@ -739,6 +744,13 @@ mod tests { assert_eq!(bpe.vocab.get("ab").unwrap(), &3u32); } + #[test] + // Ensure BPEBuilder with dropout = 0.0 doesn't error + fn test_bpe_with_dropout_0() { + let bpe = BPE::builder().dropout(0.0).build().unwrap(); + assert_eq!(bpe.dropout, Some(0.0)); + } + #[test] // Ensure `BPE::from_file` works as expected. fn test_bpe_with_continuing_subword_prefix() { diff --git a/tokenizers/tests/serialization.rs b/tokenizers/tests/serialization.rs index 54fa9053d..3d4086acb 100644 --- a/tokenizers/tests/serialization.rs +++ b/tokenizers/tests/serialization.rs @@ -229,6 +229,21 @@ fn tokenizer() { assert_eq!(serde_json::to_string(&de).unwrap(), ser); } +#[test] +fn bpe_with_dropout_serde() { + let mut bpe = BPE::default(); + bpe.dropout = Some(0.1); + let ser = serde_json::to_string(&bpe).unwrap(); + let de = serde_json::from_str(&ser).unwrap(); + assert_eq!(bpe, de); + + // set dropout to 0.0 (which is analagous to None) and reserialize + bpe.dropout = Some(0.0); + let ser = serde_json::to_string(&bpe).unwrap(); + let de = serde_json::from_str(&ser).unwrap(); + assert_eq!(bpe, de); +} + #[test] fn test_deserialize_long_file() { let _tokenizer = Tokenizer::from_file("data/albert-base-v1-tokenizer.json").unwrap(); From 8f9dfb81041cdb31169f27c84133d78ae6d2ed33 Mon Sep 17 00:00:00 2001 From: Marco Date: Sat, 8 Jun 2024 16:19:12 +0900 Subject: [PATCH 2/5] typo --- tokenizers/tests/serialization.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tokenizers/tests/serialization.rs b/tokenizers/tests/serialization.rs index 3d4086acb..4d51d4281 100644 --- a/tokenizers/tests/serialization.rs +++ b/tokenizers/tests/serialization.rs @@ -237,7 +237,7 @@ fn bpe_with_dropout_serde() { let de = serde_json::from_str(&ser).unwrap(); assert_eq!(bpe, de); - // set dropout to 0.0 (which is analagous to None) and reserialize + // set dropout to 0.0 (which is analogous to None) and reserialize bpe.dropout = Some(0.0); let ser = serde_json::to_string(&bpe).unwrap(); let de = serde_json::from_str(&ser).unwrap(); From 526ef6071b69be3650ca54345de546e3050b8a6b Mon Sep 17 00:00:00 2001 From: Marco Cognetta Date: Fri, 7 Jun 2024 22:30:26 +0900 Subject: [PATCH 3/5] enable dropout = 0.0 --- bindings/python/tests/bindings/test_models.py | 3 +++ tokenizers/src/models/bpe/mod.rs | 2 +- tokenizers/src/models/bpe/model.rs | 18 +++++++++++++++--- tokenizers/tests/serialization.rs | 15 +++++++++++++++ 4 files changed, 34 insertions(+), 4 deletions(-) diff --git a/bindings/python/tests/bindings/test_models.py b/bindings/python/tests/bindings/test_models.py index c6a50ce86..7b6746186 100644 --- a/bindings/python/tests/bindings/test_models.py +++ b/bindings/python/tests/bindings/test_models.py @@ -69,6 +69,9 @@ def test_can_modify(self): model.byte_fallback = True assert model.byte_fallback == True + def test_dropout_zero(self): + model = BPE(dropout=0.0) + assert model.dropout == 0.0 class TestWordPiece: def test_instantiate(self, bert_files): diff --git a/tokenizers/src/models/bpe/mod.rs b/tokenizers/src/models/bpe/mod.rs index 51e7b6ca0..f0d40b2df 100644 --- a/tokenizers/src/models/bpe/mod.rs +++ b/tokenizers/src/models/bpe/mod.rs @@ -31,7 +31,7 @@ pub enum Error { #[error("Unk token `{0}` not found in the vocabulary")] UnkTokenOutOfVocabulary(String), /// Dropout not between 0 and 1. - #[error("Dropout should be between 0 and 1")] + #[error("Dropout should be between 0 and 1, inclusive")] InvalidDropout, } diff --git a/tokenizers/src/models/bpe/model.rs b/tokenizers/src/models/bpe/model.rs index 618f42b47..78ce25684 100644 --- a/tokenizers/src/models/bpe/model.rs +++ b/tokenizers/src/models/bpe/model.rs @@ -136,7 +136,7 @@ impl BpeBuilder { pub fn build(mut self) -> Result { // Validate dropout. if let Some(p) = self.config.dropout { - if p <= 0.0 || p > 1.0 { + if p < 0.0 || p > 1.0 { return Err(Error::InvalidDropout.into()); } } @@ -214,7 +214,7 @@ pub struct BPE { pub(crate) merges: MergeMap, /// Contains the cache for optimizing the encoding step. cache: Option>, - /// Dropout probability for merges. 0 = no dropout is the default. At 1.0, tokenization will + /// Dropout probability for merges. 0.0 = no dropout is the default. At 1.0, tokenization will /// perform no merges, so the result will just be characters. pub dropout: Option, /// The unknown token to be used when we encounter an unknown char @@ -493,7 +493,7 @@ impl Model for BPE { return Ok(vec![]); } - if self.dropout.is_none() { + if self.dropout.is_none() || self.dropout == Some(0.0) { self.tokenize_with_cache(sequence) } else { let word = self.merge_word(sequence)?; @@ -685,6 +685,11 @@ mod tests { let tokens = bpe.tokenize("unrelated").unwrap(); assert_eq!(tokens, vec![Token::new(15u32, "unrelated".into(), (0, 9))]); + // With dropout = 0.0 (equivalent to dropout == none) + bpe.dropout = Some(0.0); + let tokens = bpe.tokenize("unrelated").unwrap(); + assert_eq!(tokens, vec![Token::new(15u32, "unrelated".into(), (0, 9))]); + // Now set dropout to 1.0. Result should be no merges performed. bpe.dropout = Some(1.0); let tokens = bpe.tokenize("unrelated").unwrap(); @@ -739,6 +744,13 @@ mod tests { assert_eq!(bpe.vocab.get("ab").unwrap(), &3u32); } + #[test] + // Ensure BPEBuilder with dropout = 0.0 doesn't error + fn test_bpe_with_dropout_0() { + let bpe = BPE::builder().dropout(0.0).build().unwrap(); + assert_eq!(bpe.dropout, Some(0.0)); + } + #[test] // Ensure `BPE::from_file` works as expected. fn test_bpe_with_continuing_subword_prefix() { diff --git a/tokenizers/tests/serialization.rs b/tokenizers/tests/serialization.rs index 54fa9053d..3d4086acb 100644 --- a/tokenizers/tests/serialization.rs +++ b/tokenizers/tests/serialization.rs @@ -229,6 +229,21 @@ fn tokenizer() { assert_eq!(serde_json::to_string(&de).unwrap(), ser); } +#[test] +fn bpe_with_dropout_serde() { + let mut bpe = BPE::default(); + bpe.dropout = Some(0.1); + let ser = serde_json::to_string(&bpe).unwrap(); + let de = serde_json::from_str(&ser).unwrap(); + assert_eq!(bpe, de); + + // set dropout to 0.0 (which is analagous to None) and reserialize + bpe.dropout = Some(0.0); + let ser = serde_json::to_string(&bpe).unwrap(); + let de = serde_json::from_str(&ser).unwrap(); + assert_eq!(bpe, de); +} + #[test] fn test_deserialize_long_file() { let _tokenizer = Tokenizer::from_file("data/albert-base-v1-tokenizer.json").unwrap(); From 789d2dfeffdd09f521bdfd5906774f8a580ebd6b Mon Sep 17 00:00:00 2001 From: Marco Cognetta Date: Tue, 11 Jun 2024 00:32:33 +0900 Subject: [PATCH 4/5] lint --- tokenizers/src/models/bpe/model.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tokenizers/src/models/bpe/model.rs b/tokenizers/src/models/bpe/model.rs index 78ce25684..8d22ab52d 100644 --- a/tokenizers/src/models/bpe/model.rs +++ b/tokenizers/src/models/bpe/model.rs @@ -136,7 +136,7 @@ impl BpeBuilder { pub fn build(mut self) -> Result { // Validate dropout. if let Some(p) = self.config.dropout { - if p < 0.0 || p > 1.0 { + if !(0.0..=1.0).contains(&p) { return Err(Error::InvalidDropout.into()); } } From ddec297194306d0870080372c757f2fd8234adb3 Mon Sep 17 00:00:00 2001 From: Marco Cognetta Date: Tue, 11 Jun 2024 22:59:02 +0900 Subject: [PATCH 5/5] formatter --- bindings/python/tests/bindings/test_models.py | 1 + 1 file changed, 1 insertion(+) diff --git a/bindings/python/tests/bindings/test_models.py b/bindings/python/tests/bindings/test_models.py index 7b6746186..063698384 100644 --- a/bindings/python/tests/bindings/test_models.py +++ b/bindings/python/tests/bindings/test_models.py @@ -73,6 +73,7 @@ def test_dropout_zero(self): model = BPE(dropout=0.0) assert model.dropout == 0.0 + class TestWordPiece: def test_instantiate(self, bert_files): assert isinstance(WordPiece(), Model)