Fixing the streaming implementation.
Narsil committed Nov 7, 2024
1 parent af7d82e commit 5a5406e
Showing 2 changed files with 67 additions and 7 deletions.
25 changes: 20 additions & 5 deletions tokenizers/src/tokenizer/mod.rs
@@ -915,11 +915,22 @@ where
 pub struct DecodeStream<'tok, M, N, PT, PP, D> {
     tokenizer: &'tok TokenizerImpl<M, N, PT, PP, D>,
     ids: Vec<u32>,
-    prefix_index: usize,
     prefix: String,
+    prefix_index: usize,
+    /// We need to keep 2 prefixes.
+    /// Prefix is the second one that was already emitted to discard the part
+    /// of the text of all the ids
+    /// read is the prefix kept only for starting side effects of the prefix
+    read_index: usize,
     skip_special_tokens: bool,
 }
 
+#[derive(thiserror::Error, Debug)]
+pub enum DecodeStreamError {
+    #[error("Invalid prefix encountered")]
+    InvalidPrefix,
+}
+
 impl<'tok, M, N, PT, PP, D> DecodeStream<'tok, M, N, PT, PP, D>
 where
     M: Model,
@@ -935,6 +946,7 @@ where
             skip_special_tokens,
             prefix: "".to_string(),
             prefix_index: 0,
+            read_index: 0,
         }
     }

@@ -943,12 +955,15 @@ where
         let string = self
             .tokenizer
             .decode(self.ids.as_slice(), self.skip_special_tokens)?;
+        println!("Decode got {string} {} Ids:{:?}", self.prefix, self.ids);
         if string.len() > self.prefix.len() && !string.ends_with('�') {
-            let new_text = &string[self.prefix.len()..];
-            self.prefix = new_text.to_string();
+            if !(string.starts_with(&self.prefix)) {
+                return Err(Box::new(DecodeStreamError::InvalidPrefix));
+            }
+            let new_text = &string[self.prefix.len()..].to_string();
+            let new_prefix_index = self.ids.len() - self.prefix_index;
-            self.ids = self.ids.drain(self.prefix_index..).collect();
+            self.ids = self.ids.drain(self.read_index..).collect();
+            self.prefix = self.tokenizer.decode(&self.ids, self.skip_special_tokens)?;
+            self.read_index = self.prefix_index;
+            self.prefix_index = new_prefix_index;
             Ok(Some(new_text.to_string()))
         } else {
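Note on the new bookkeeping (an editorial sketch, not part of the commit): `step` re-decodes a window of recent ids on every call. `prefix` is the text already emitted for that window, `prefix_index` marks where the not-yet-emitted ids start, and `read_index` keeps the previous window start so the decoder still sees enough left context for spacing and byte-fallback side effects. The toy program below uses a hypothetical `decode` stand-in in place of a real tokenizer, but performs the same index shuffle:

// Sketch only: `decode` is a hypothetical stand-in for Tokenizer::decode,
// mapping each id to a fixed piece so the example runs without model files.
fn decode(ids: &[u32]) -> String {
    ids.iter()
        .map(|id| match id {
            0 => "Hello",
            1 => " world",
            2 => "!",
            _ => "<unk>",
        })
        .collect()
}

struct Stream {
    ids: Vec<u32>,
    prefix: String,      // text already emitted for the ids still held in `ids`
    prefix_index: usize, // start of the not-yet-emitted ids within `ids`
    read_index: usize,   // older window start, kept so `decode` sees left context
}

impl Stream {
    fn step(&mut self, id: u32) -> Option<String> {
        self.ids.push(id);
        let string = decode(&self.ids);
        // Emit only when the decode grew and did not stop mid-codepoint.
        if string.len() > self.prefix.len() && !string.ends_with('\u{FFFD}') {
            // The commit surfaces this invariant as DecodeStreamError::InvalidPrefix.
            assert!(string.starts_with(&self.prefix));
            let new_text = string[self.prefix.len()..].to_string();
            let new_prefix_index = self.ids.len() - self.prefix_index;
            // Trim only back to `read_index`, keeping one extra window of ids.
            self.ids = self.ids.drain(self.read_index..).collect();
            self.prefix = decode(&self.ids);
            self.read_index = self.prefix_index;
            self.prefix_index = new_prefix_index;
            Some(new_text)
        } else {
            None
        }
    }
}

fn main() {
    let mut s = Stream { ids: vec![], prefix: String::new(), prefix_index: 0, read_index: 0 };
    assert_eq!(s.step(0).as_deref(), Some("Hello"));
    assert_eq!(s.step(1).as_deref(), Some(" world"));
    assert_eq!(s.step(2).as_deref(), Some("!"));
}

Because the window is only trimmed back to `read_index`, already-flushed ids keep influencing later decodes; that is the side effect the new albert assertions below exercise, where a fresh stream yields "is" for id 25 but a stream that already saw id 48 yields " is".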
49 changes: 47 additions & 2 deletions tokenizers/tests/documentation.rs
@@ -1,3 +1,7 @@
+use std::collections::HashMap;
+use std::iter::FromIterator;
+
+use tokenizers::decoders::byte_fallback::ByteFallback;
 use tokenizers::models::bpe::{BpeTrainerBuilder, BPE};
 use tokenizers::normalizers::{Sequence, Strip, NFC};
 use tokenizers::pre_tokenizers::byte_level::ByteLevel;
@@ -71,8 +75,49 @@ fn streaming_tokenizer() {
         Some(" example".to_string())
     );
 
-    // TODO change the tokenizer to prove side effects of the streaming state.
-    // TODO add an example with byte fallback for `None` example
+    let tokenizer = Tokenizer::from_file("data/albert-base-v1-tokenizer.json").unwrap();
+    let encoded = tokenizer.encode("This is an example", false).unwrap();
+    assert_eq!(encoded.get_ids(), &[48, 25, 40, 823]);
+    let mut decode_stream = tokenizer.decode_stream(false);
+    // No space anymore
+    assert_eq!(decode_stream.step(25).unwrap(), Some("is".to_string()));
+    let mut decode_stream = tokenizer.decode_stream(false);
+    assert_eq!(decode_stream.step(48).unwrap(), Some("this".to_string()));
+    assert_eq!(decode_stream.step(25).unwrap(), Some(" is".to_string()));
+    assert_eq!(decode_stream.step(40).unwrap(), Some(" an".to_string()));
+    assert_eq!(
+        decode_stream.step(823).unwrap(),
+        Some(" example".to_string())
+    );
+
+    // None example
+    let vocab = HashMap::from_iter([
+        ("<0x20>".to_string(), 0),
+        ("<0xC3>".to_string(), 1),
+        ("<0xA9>".to_string(), 2),
+    ]);
+    let merges = vec![];
+    let bpe = BPE::builder()
+        .vocab_and_merges(vocab, merges)
+        .byte_fallback(true)
+        .build()
+        .unwrap();
+    let tokenizer = TokenizerBuilder::new()
+        .with_model(bpe)
+        .with_normalizer(Some(Sequence::new(vec![
+            Strip::new(true, true).into(),
+            NFC.into(),
+        ])))
+        .with_pre_tokenizer(Some(ByteLevel::default()))
+        .with_post_processor(Some(ByteLevel::default()))
+        .with_decoder(Some(ByteFallback::default()))
+        .build()
+        .unwrap();
+    let mut decode_stream = tokenizer.decode_stream(false);
+    assert_eq!(decode_stream.step(0).unwrap(), Some(" ".to_string()));
+    assert_eq!(decode_stream.step(1).unwrap(), None);
+    assert_eq!(decode_stream.step(2).unwrap(), Some("é".to_string()));
+    assert_eq!(decode_stream.step(2).unwrap(), None);
 }

#[test]
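Why the byte-fallback steps return `None`: `<0xC3>` is only the first byte of the two-byte UTF-8 encoding of `é`, so the intermediate decode ends with the replacement character `�` and the stream withholds output until `<0xA9>` completes the code point (a lone trailing byte, like the final repeated id 2, stays incomplete and likewise yields `None`). A minimal standalone check of that byte-level behavior, assuming nothing beyond the Rust standard library:

fn main() {
    // 0xC3 alone is an incomplete UTF-8 sequence: lossy decoding ends in '�',
    // the marker the stream treats as "wait for more bytes".
    assert!(String::from_utf8_lossy(&[0xC3]).ends_with('\u{FFFD}'));
    // 0xC3 0xA9 completes the code point: "é".
    assert_eq!(String::from_utf8_lossy(&[0xC3, 0xA9]), "é");
}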
