Skip to content

Commit

Permalink
Merge branch 'main' into fix-decode
Browse files Browse the repository at this point in the history
  • Loading branch information
ArthurZucker committed Jul 12, 2024
2 parents 7761621 + fdd26ba commit 1bd8d96
Show file tree
Hide file tree
Showing 6 changed files with 57 additions and 14 deletions.
4 changes: 4 additions & 0 deletions bindings/python/tests/bindings/test_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,6 +69,10 @@ def test_can_modify(self):
model.byte_fallback = True
assert model.byte_fallback == True

def test_dropout_zero(self):
    """A BPE model built with dropout=0.0 should expose exactly that value."""
    bpe_model = BPE(dropout=0.0)
    assert bpe_model.dropout == 0.0


class TestWordPiece:
def test_instantiate(self, bert_files):
Expand Down
2 changes: 2 additions & 0 deletions tokenizers/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -75,6 +75,8 @@ unstable_wasm = ["fancy-regex", "getrandom/js"]
criterion = "0.5"
tempfile = "3.10"
assert_approx_eq = "1.1"
tracing = "0.1"
tracing-subscriber = "0.3.18"

[profile.release]
lto = "fat"
2 changes: 1 addition & 1 deletion tokenizers/src/models/bpe/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ pub enum Error {
#[error("Unk token `{0}` not found in the vocabulary")]
UnkTokenOutOfVocabulary(String),
/// Dropout not between 0 and 1.
#[error("Dropout should be between 0 and 1")]
#[error("Dropout should be between 0 and 1, inclusive")]
InvalidDropout,
}

Expand Down
18 changes: 15 additions & 3 deletions tokenizers/src/models/bpe/model.rs
Original file line number Diff line number Diff line change
Expand Up @@ -136,7 +136,7 @@ impl BpeBuilder {
pub fn build(mut self) -> Result<BPE> {
// Validate dropout.
if let Some(p) = self.config.dropout {
if p <= 0.0 || p > 1.0 {
if !(0.0..=1.0).contains(&p) {
return Err(Error::InvalidDropout.into());
}
}
Expand Down Expand Up @@ -214,7 +214,7 @@ pub struct BPE {
pub(crate) merges: MergeMap,
/// Contains the cache for optimizing the encoding step.
cache: Option<Cache<String, Word>>,
/// Dropout probability for merges. 0 = no dropout is the default. At 1.0, tokenization will
/// Dropout probability for merges. 0.0 = no dropout is the default. At 1.0, tokenization will
/// perform no merges, so the result will just be characters.
pub dropout: Option<f32>,
/// The unknown token to be used when we encounter an unknown char
Expand Down Expand Up @@ -493,7 +493,7 @@ impl Model for BPE {
return Ok(vec![]);
}

if self.dropout.is_none() {
if self.dropout.is_none() || self.dropout == Some(0.0) {
self.tokenize_with_cache(sequence)
} else {
let word = self.merge_word(sequence)?;
Expand Down Expand Up @@ -685,6 +685,11 @@ mod tests {
let tokens = bpe.tokenize("unrelated").unwrap();
assert_eq!(tokens, vec![Token::new(15u32, "unrelated".into(), (0, 9))]);

// With dropout = 0.0 (equivalent to dropout == None)
bpe.dropout = Some(0.0);
let tokens = bpe.tokenize("unrelated").unwrap();
assert_eq!(tokens, vec![Token::new(15u32, "unrelated".into(), (0, 9))]);

// Now set dropout to 1.0. Result should be no merges performed.
bpe.dropout = Some(1.0);
let tokens = bpe.tokenize("unrelated").unwrap();
Expand Down Expand Up @@ -739,6 +744,13 @@ mod tests {
assert_eq!(bpe.vocab.get("ab").unwrap(), &3u32);
}

#[test]
// A dropout of exactly 0.0 is valid and must not make the builder error.
fn test_bpe_with_dropout_0() {
    let builder = BPE::builder().dropout(0.0);
    let model = builder.build().unwrap();
    assert_eq!(model.dropout, Some(0.0));
}

#[test]
// Ensure `BPE::from_file` works as expected.
fn test_bpe_with_continuing_subword_prefix() {
Expand Down
30 changes: 20 additions & 10 deletions tokenizers/src/tokenizer/serialization.rs
Original file line number Diff line number Diff line change
Expand Up @@ -155,17 +155,15 @@ where
for token in &tokens {
// Warn the user if the id is different than expected
let received_id = tokenizer.token_to_id(&token.token.content);
if received_id != Some(token.id) {
warn!(
"Warning: Token '{}' was expected to have ID '{}' but was given ID '{}'",
token.token.content,
token.id,
if let Some(rid) = received_id {
if let Some(rid) = received_id {
if rid != token.id {
warn!(
"Warning: Token '{}' was expected to have ID '{}' but was given ID '{}'",
token.token.content,
token.id,
rid.to_string()
} else {
"None".to_string()
}
);
);
}
}
}
let added_tokens: Vec<_> = tokens.into_iter().map(|token| token.token).collect();
Expand All @@ -179,6 +177,7 @@ where
mod tests {
use crate::tokenizer::Tokenizer;
use std::str::FromStr;
use tracing_subscriber::fmt;

#[test]
fn test_deserialization_serialization_invariant() {
Expand Down Expand Up @@ -233,4 +232,15 @@ mod tests {
// It should be exactly the same as above
assert_eq!(tok_str, tok_json);
}

#[cfg(feature = "http")]
#[test]
fn test_from_pretrained() {
    // Install a tracing subscriber first so any diagnostics emitted while
    // fetching the pretrained tokenizer are actually visible.
    let subscriber = fmt().with_target(false).with_max_level(tracing::Level::DEBUG);
    subscriber.init();
    // Result deliberately ignored: this only exercises the download/log path.
    let _ = Tokenizer::from_pretrained("Qwen/Qwen2-7B-Instruct", None);
    warn!("This should be the first warning");
}
}
15 changes: 15 additions & 0 deletions tokenizers/tests/serialization.rs
Original file line number Diff line number Diff line change
Expand Up @@ -229,6 +229,21 @@ fn tokenizer() {
assert_eq!(serde_json::to_string(&de).unwrap(), ser);
}

#[test]
fn bpe_with_dropout_serde() {
    // Round-trip a BPE model carrying a non-trivial dropout through serde.
    let mut bpe = BPE::default();
    bpe.dropout = Some(0.1);
    let serialized = serde_json::to_string(&bpe).unwrap();
    let deserialized: BPE = serde_json::from_str(&serialized).unwrap();
    assert_eq!(bpe, deserialized);

    // A dropout of 0.0 behaves like `None`; it must round-trip unchanged too.
    bpe.dropout = Some(0.0);
    let serialized = serde_json::to_string(&bpe).unwrap();
    let deserialized: BPE = serde_json::from_str(&serialized).unwrap();
    assert_eq!(bpe, deserialized);
}

#[test]
fn test_deserialize_long_file() {
let _tokenizer = Tokenizer::from_file("data/albert-base-v1-tokenizer.json").unwrap();
Expand Down

0 comments on commit 1bd8d96

Please sign in to comment.