From c97389bca6ce0171fc59d207cc907640e2e4b1fd Mon Sep 17 00:00:00 2001 From: Nicolas Patry Date: Wed, 6 Nov 2024 20:52:05 +0800 Subject: [PATCH 1/9] Adding an API for decode streaming. --- tokenizers/src/tokenizer/mod.rs | 56 +++++++++++++++++++++++++++++-- tokenizers/tests/documentation.rs | 16 +++++++++ 2 files changed, 70 insertions(+), 2 deletions(-) diff --git a/tokenizers/src/tokenizer/mod.rs b/tokenizers/src/tokenizer/mod.rs index 49bc539a2..466b552d7 100644 --- a/tokenizers/src/tokenizer/mod.rs +++ b/tokenizers/src/tokenizer/mod.rs @@ -12,8 +12,7 @@ use std::{ collections::HashMap, fs::{read_to_string, File}, - io::prelude::*, - io::BufReader, + io::{prelude::*, BufReader}, ops::{Deref, DerefMut}, path::{Path, PathBuf}, }; @@ -906,6 +905,59 @@ where Ok(tokens.join(" ")) } } + + /// Decode the given ids, back to a String + pub fn decode_stream<'tok>( + &'tok self, + skip_special_tokens: bool, + ) -> DecodeStream<'tok, M, N, PT, PP, D> { + DecodeStream::new(self, skip_special_tokens) + } +} + +pub struct DecodeStream<'tok, M, N, PT, PP, D> { + tokenizer: &'tok TokenizerImpl, + ids: Vec, + prefix_index: usize, + prefix: String, + skip_special_tokens: bool, +} + +impl<'tok, M, N, PT, PP, D> DecodeStream<'tok, M, N, PT, PP, D> +where + M: Model, + N: Normalizer, + PT: PreTokenizer, + PP: PostProcessor, + D: Decoder, +{ + fn new(tokenizer: &'tok TokenizerImpl, skip_special_tokens: bool) -> Self { + Self { + tokenizer, + ids: vec![], + skip_special_tokens, + prefix: "".to_string(), + prefix_index: 0, + } + } + + pub fn step(&mut self, id: u32) -> Result> { + self.ids.push(id); + let string = self + .tokenizer + .decode(self.ids.as_slice(), self.skip_special_tokens)?; + println!("Decode got {string} {} Ids:{:?}", self.prefix, self.ids); + if string.len() > self.prefix.len() && !string.ends_with('�') { + let new_text = &string[self.prefix.len()..]; + self.prefix = new_text.to_string(); + let new_prefix_index = self.ids.len() - self.prefix_index; + self.ids = self.ids.drain(self.prefix_index..).collect(); + self.prefix_index = new_prefix_index; + Ok(Some(new_text.to_string())) + } else { + Ok(None) + } + } } impl TokenizerImpl diff --git a/tokenizers/tests/documentation.rs b/tokenizers/tests/documentation.rs index ad29590b9..ae334c06c 100644 --- a/tokenizers/tests/documentation.rs +++ b/tokenizers/tests/documentation.rs @@ -58,6 +58,22 @@ fn load_tokenizer() { assert_eq!(decoded, example); } +#[test] +fn streaming_tokenizer() { + let tokenizer = Tokenizer::from_file("data/roberta.json").unwrap(); + + let mut decode_stream = tokenizer.decode_stream(false); + assert_eq!(decode_stream.step(713).unwrap(), Some("This".to_string())); + assert_eq!(decode_stream.step(16).unwrap(), Some(" is".to_string())); + assert_eq!(decode_stream.step(41).unwrap(), Some(" an".to_string())); + assert_eq!( + decode_stream.step(1246).unwrap(), + Some(" example".to_string()) + ); + + // TODO add an example with byte fallback for `None` example +} + #[test] #[ignore] fn quicktour_slow_train() -> tokenizers::Result<()> { From c3578d45d6ec87272d680db09eae45bf81b4887d Mon Sep 17 00:00:00 2001 From: Nicolas Patry Date: Wed, 6 Nov 2024 20:55:33 +0800 Subject: [PATCH 2/9] Add another missing test case (proving the effect of state.) 
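The streaming decode is stateful: the chunk produced for a given id depends on the ids already seen, so the test needs to show the same id decoding differently with and without prior context. A minimal sketch of what such a test could look like (the `data/albert-base-v1-tokenizer.json` file and the ids 48/25 are the ones used later in this series; treat the snippet as illustrative, not as the final test):

    use tokenizers::Tokenizer;

    fn main() {
        let tokenizer = Tokenizer::from_file("data/albert-base-v1-tokenizer.json").unwrap();

        // Fresh stream: the very first chunk comes out without a leading space.
        let mut stream = tokenizer.decode_stream(false);
        assert_eq!(stream.step(25).unwrap(), Some("is".to_string()));

        // Same id seen after another token: the kept prefix restores the space.
        let mut stream = tokenizer.decode_stream(false);
        assert_eq!(stream.step(48).unwrap(), Some("this".to_string()));
        assert_eq!(stream.step(25).unwrap(), Some(" is".to_string()));
    }
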
--- tokenizers/tests/documentation.rs | 1 + 1 file changed, 1 insertion(+) diff --git a/tokenizers/tests/documentation.rs b/tokenizers/tests/documentation.rs index ae334c06c..023cdc442 100644 --- a/tokenizers/tests/documentation.rs +++ b/tokenizers/tests/documentation.rs @@ -71,6 +71,7 @@ fn streaming_tokenizer() { Some(" example".to_string()) ); + // TODO change the tokenizer to prove side effects of the streaming state. // TODO add an example with byte fallback for `None` example } From bdcb2b9c1702b141bd1c731f4b42b0f6600d75e6 Mon Sep 17 00:00:00 2001 From: Nicolas Patry Date: Wed, 6 Nov 2024 20:59:58 +0800 Subject: [PATCH 3/9] Ellide lifetime. --- tokenizers/src/tokenizer/mod.rs | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tokenizers/src/tokenizer/mod.rs b/tokenizers/src/tokenizer/mod.rs index 466b552d7..72e31a3a2 100644 --- a/tokenizers/src/tokenizer/mod.rs +++ b/tokenizers/src/tokenizer/mod.rs @@ -908,9 +908,9 @@ where /// Decode the given ids, back to a String pub fn decode_stream<'tok>( - &'tok self, + &self, skip_special_tokens: bool, - ) -> DecodeStream<'tok, M, N, PT, PP, D> { + ) -> DecodeStream<'_, M, N, PT, PP, D> { DecodeStream::new(self, skip_special_tokens) } } From af7d82eb67d8117c685eee8152680c7ec61021b2 Mon Sep 17 00:00:00 2001 From: Nicolas Patry Date: Wed, 6 Nov 2024 21:54:13 +0800 Subject: [PATCH 4/9] Ellide bis. --- tokenizers/src/tokenizer/mod.rs | 5 +---- 1 file changed, 1 insertion(+), 4 deletions(-) diff --git a/tokenizers/src/tokenizer/mod.rs b/tokenizers/src/tokenizer/mod.rs index 72e31a3a2..bbc62bf8d 100644 --- a/tokenizers/src/tokenizer/mod.rs +++ b/tokenizers/src/tokenizer/mod.rs @@ -907,10 +907,7 @@ where } /// Decode the given ids, back to a String - pub fn decode_stream<'tok>( - &self, - skip_special_tokens: bool, - ) -> DecodeStream<'_, M, N, PT, PP, D> { + pub fn decode_stream(&self, skip_special_tokens: bool) -> DecodeStream<'_, M, N, PT, PP, D> { DecodeStream::new(self, skip_special_tokens) } } From 5a5406e67ca65cd8212dff23f028b59131a003bb Mon Sep 17 00:00:00 2001 From: Nicolas Patry Date: Thu, 7 Nov 2024 16:34:01 +0800 Subject: [PATCH 5/9] Fixing the streaming implementation. --- tokenizers/src/tokenizer/mod.rs | 25 ++++++++++++---- tokenizers/tests/documentation.rs | 49 +++++++++++++++++++++++++++++-- 2 files changed, 67 insertions(+), 7 deletions(-) diff --git a/tokenizers/src/tokenizer/mod.rs b/tokenizers/src/tokenizer/mod.rs index bbc62bf8d..a227688ad 100644 --- a/tokenizers/src/tokenizer/mod.rs +++ b/tokenizers/src/tokenizer/mod.rs @@ -915,11 +915,22 @@ where pub struct DecodeStream<'tok, M, N, PT, PP, D> { tokenizer: &'tok TokenizerImpl, ids: Vec, - prefix_index: usize, prefix: String, + prefix_index: usize, + /// We need to keep 2 prefixes. 
+ /// Prefix is the second one that was already emitted to discard the part + /// of the text of all the ids + /// read is the prefix kept only for starting side effects of the prefix + read_index: usize, skip_special_tokens: bool, } +#[derive(thiserror::Error, Debug)] +pub enum DecodeStreamError { + #[error("Invalid prefix encountered")] + InvalidPrefix, +} + impl<'tok, M, N, PT, PP, D> DecodeStream<'tok, M, N, PT, PP, D> where M: Model, @@ -935,6 +946,7 @@ where skip_special_tokens, prefix: "".to_string(), prefix_index: 0, + read_index: 0, } } @@ -943,12 +955,15 @@ where let string = self .tokenizer .decode(self.ids.as_slice(), self.skip_special_tokens)?; - println!("Decode got {string} {} Ids:{:?}", self.prefix, self.ids); if string.len() > self.prefix.len() && !string.ends_with('�') { - let new_text = &string[self.prefix.len()..]; - self.prefix = new_text.to_string(); + if !(string.starts_with(&self.prefix)) { + return Err(Box::new(DecodeStreamError::InvalidPrefix)); + } + let new_text = &string[self.prefix.len()..].to_string(); let new_prefix_index = self.ids.len() - self.prefix_index; - self.ids = self.ids.drain(self.prefix_index..).collect(); + self.ids = self.ids.drain(self.read_index..).collect(); + self.prefix = self.tokenizer.decode(&self.ids, self.skip_special_tokens)?; + self.read_index = self.prefix_index; self.prefix_index = new_prefix_index; Ok(Some(new_text.to_string())) } else { diff --git a/tokenizers/tests/documentation.rs b/tokenizers/tests/documentation.rs index 023cdc442..a851f40ab 100644 --- a/tokenizers/tests/documentation.rs +++ b/tokenizers/tests/documentation.rs @@ -1,3 +1,7 @@ +use std::collections::HashMap; +use std::iter::FromIterator; + +use tokenizers::decoders::byte_fallback::ByteFallback; use tokenizers::models::bpe::{BpeTrainerBuilder, BPE}; use tokenizers::normalizers::{Sequence, Strip, NFC}; use tokenizers::pre_tokenizers::byte_level::ByteLevel; @@ -71,8 +75,49 @@ fn streaming_tokenizer() { Some(" example".to_string()) ); - // TODO change the tokenizer to prove side effects of the streaming state. 
- // TODO add an example with byte fallback for `None` example + let tokenizer = Tokenizer::from_file("data/albert-base-v1-tokenizer.json").unwrap(); + let encoded = tokenizer.encode("This is an example", false).unwrap(); + assert_eq!(encoded.get_ids(), &[48, 25, 40, 823]); + let mut decode_stream = tokenizer.decode_stream(false); + // No space anymore + assert_eq!(decode_stream.step(25).unwrap(), Some("is".to_string())); + let mut decode_stream = tokenizer.decode_stream(false); + assert_eq!(decode_stream.step(48).unwrap(), Some("this".to_string())); + assert_eq!(decode_stream.step(25).unwrap(), Some(" is".to_string())); + assert_eq!(decode_stream.step(40).unwrap(), Some(" an".to_string())); + assert_eq!( + decode_stream.step(823).unwrap(), + Some(" example".to_string()) + ); + + // None example + let vocab = HashMap::from_iter([ + ("<0x20>".to_string(), 0), + ("<0xC3>".to_string(), 1), + ("<0xA9>".to_string(), 2), + ]); + let merges = vec![]; + let bpe = BPE::builder() + .vocab_and_merges(vocab, merges) + .byte_fallback(true) + .build() + .unwrap(); + let tokenizer = TokenizerBuilder::new() + .with_model(bpe) + .with_normalizer(Some(Sequence::new(vec![ + Strip::new(true, true).into(), + NFC.into(), + ]))) + .with_pre_tokenizer(Some(ByteLevel::default())) + .with_post_processor(Some(ByteLevel::default())) + .with_decoder(Some(ByteFallback::default())) + .build() + .unwrap(); + let mut decode_stream = tokenizer.decode_stream(false); + assert_eq!(decode_stream.step(0).unwrap(), Some(" ".to_string())); + assert_eq!(decode_stream.step(1).unwrap(), None); + assert_eq!(decode_stream.step(2).unwrap(), Some("é".to_string())); + assert_eq!(decode_stream.step(2).unwrap(), None); } #[test] From 18b999c64534ff165f7069c5e8c18897f2835cd2 Mon Sep 17 00:00:00 2001 From: Nicolas Patry Date: Thu, 7 Nov 2024 21:34:31 +0800 Subject: [PATCH 6/9] Adding more docs. --- tokenizers/src/tokenizer/mod.rs | 116 +++++++++++++++++++++++++++++- tokenizers/tests/documentation.rs | 1 + 2 files changed, 116 insertions(+), 1 deletion(-) diff --git a/tokenizers/src/tokenizer/mod.rs b/tokenizers/src/tokenizer/mod.rs index a227688ad..208a0fc7a 100644 --- a/tokenizers/src/tokenizer/mod.rs +++ b/tokenizers/src/tokenizer/mod.rs @@ -907,22 +907,135 @@ where } /// Decode the given ids, back to a String + /// See [`DecodeStream`] pub fn decode_stream(&self, skip_special_tokens: bool) -> DecodeStream<'_, M, N, PT, PP, D> { DecodeStream::new(self, skip_special_tokens) } } +/// DecodeStream will keep the state necessary to produce individual chunks of +/// strings given an input stream of token_ids. +/// +/// This is necessary because decoding in general cannot achieve that since strings +/// depend on surrounding ids to provide a valid string. Typically stripping extra spaces +/// +/// Example: +/// +/// ``` +/// use tokenizers::Tokenizer; +/// let tokenizer = Tokenizer::from_file("data/roberta.json").unwrap(); +/// +/// let mut decode_stream = tokenizer.decode_stream(false); +/// assert_eq!(decode_stream.step(713).unwrap(), Some("This".to_string())); +/// assert_eq!(decode_stream.step(16).unwrap(), Some(" is".to_string())); +/// assert_eq!(decode_stream.step(41).unwrap(), Some(" an".to_string())); +/// assert_eq!( +/// decode_stream.step(1246).unwrap(), +/// Some(" example".to_string()) +/// ); +/// ``` +/// +/// Returning `None` means the given id is not enough to produce a chunk. 
+/// This typically happens with `byte_fallback` options where some tokens do +/// not represent valid utf-8, and only follow-up token_ids will help produce +/// a valid chunk. +/// ``` +/// use tokenizers::{Tokenizer, TokenizerBuilder, models::bpe::BPE, decoders::byte_fallback::ByteFallback, pre_tokenizers::byte_level::ByteLevel, normalizers::unicode::NFC}; +/// use std::collections::HashMap; +/// use std::iter::FromIterator; +/// +/// let vocab = HashMap::from_iter([ +/// ("<0x20>".to_string(), 0), +/// ("<0xC3>".to_string(), 1), +/// ("<0xA9>".to_string(), 2), +/// (" This".to_string(), 3), +/// ]); +/// let merges = vec![]; +/// let bpe = BPE::builder() +/// .vocab_and_merges(vocab, merges) +/// .byte_fallback(true) +/// .build() +/// .unwrap(); +/// let tokenizer = TokenizerBuilder::default() +/// .with_model(bpe) +/// .with_decoder(Some(ByteFallback::default())) +/// .with_normalizer(Some(NFC)) +/// .with_pre_tokenizer(Some(ByteLevel::default())) +/// .with_post_processor(Some(ByteLevel::default())) +/// .build().unwrap(); +/// +/// let mut decode_stream = tokenizer.decode_stream(false); +/// // Single byte_fallback is valid utf-8 +/// assert_eq!(decode_stream.step(0).unwrap(), Some(" ".to_string())); +/// // Invalid utf-8 +/// assert_eq!(decode_stream.step(1).unwrap(), None); +/// // Valid utf-8 again, this corresponds to both tokens: [1, 2] +/// assert_eq!(decode_stream.step(2).unwrap(), Some("é".to_string())); +/// ``` +/// +/// To see how [`DecodeStream`] is necessary, let's show how using raw [`Tokenizer::decode`] would +/// fail. +/// +/// ``` +/// use tokenizers::{Tokenizer, TokenizerBuilder, models::bpe::BPE, pre_tokenizers::{byte_level::ByteLevel, metaspace::Metaspace}, normalizers::unicode::NFC}; +/// use std::collections::HashMap; +/// use std::iter::FromIterator; +/// +/// let vocab = HashMap::from_iter([ +/// ("▁This".to_string(), 0), +/// ]); +/// let merges = vec![]; +/// let bpe = BPE::builder() +/// .vocab_and_merges(vocab, merges) +/// .byte_fallback(true) +/// .build() +/// .unwrap(); +/// let tokenizer = TokenizerBuilder::new() +/// .with_model(bpe) +/// .with_decoder(Some(Metaspace::default())) +/// .with_normalizer(Some(NFC)) +/// .with_pre_tokenizer(Some(ByteLevel::default())) +/// .with_post_processor(Some(ByteLevel::default())) +/// .build() +/// .unwrap(); +/// +/// // Strip decoder removes the extra initial space +/// assert_eq!(tokenizer.decode(&[0, 0], false).unwrap(), "This This"); +/// // Decoding one token at a time would produce "ThisThis" +/// assert_eq!(tokenizer.decode(&[0], false).unwrap(), "This"); +/// +/// // Using a stream fixes it by keeping the necessary state. +/// let mut decode_stream = tokenizer.decode_stream(false); +/// assert_eq!(decode_stream.step(0).unwrap(), Some("This".to_string())); +/// assert_eq!(decode_stream.step(0).unwrap(), Some(" This".to_string())); +/// ``` pub struct DecodeStream<'tok, M, N, PT, PP, D> { + /// A reference to the tokenizer tokenizer: &'tok TokenizerImpl, + /// Regular decode option that is kept throughout. + skip_special_tokens: bool, + /// A temporary buffer of the necessary token_ids needed + /// to produce valid string chunks. + /// This typically contains 3 parts: + /// - read + /// - prefix + /// - rest + /// Read is the bit necessary to surround the prefix + /// so decoding the whole ids produces a valid prefix. 
+ /// Prefix is the previously produced string, kept around to trim off of + /// the next valid chunk ids: Vec, + /// The previously returned chunk that needs to be discarded from the + /// decoding of the current ids to produce the next chunk prefix: String, + /// The index within the ids corresponding to the prefix so we can drain + /// correctly prefix_index: usize, /// We need to keep 2 prefixes. /// Prefix is the second one that was already emitted to discard the part /// of the text of all the ids /// read is the prefix kept only for starting side effects of the prefix read_index: usize, - skip_special_tokens: bool, } #[derive(thiserror::Error, Debug)] @@ -950,6 +1063,7 @@ where } } + /// See [`DecodeStream`] pub fn step(&mut self, id: u32) -> Result> { self.ids.push(id); let string = self diff --git a/tokenizers/tests/documentation.rs b/tokenizers/tests/documentation.rs index a851f40ab..304211e77 100644 --- a/tokenizers/tests/documentation.rs +++ b/tokenizers/tests/documentation.rs @@ -95,6 +95,7 @@ fn streaming_tokenizer() { ("<0x20>".to_string(), 0), ("<0xC3>".to_string(), 1), ("<0xA9>".to_string(), 2), + (" This".to_string(), 3), ]); let merges = vec![]; let bpe = BPE::builder() From c32a2c24bf8e961e667ce8bb2c06885747c7b4c6 Mon Sep 17 00:00:00 2001 From: Nicolas Patry Date: Thu, 7 Nov 2024 21:58:43 +0800 Subject: [PATCH 7/9] End of list. --- tokenizers/src/tokenizer/mod.rs | 1 + 1 file changed, 1 insertion(+) diff --git a/tokenizers/src/tokenizer/mod.rs b/tokenizers/src/tokenizer/mod.rs index 208a0fc7a..ac79e6f76 100644 --- a/tokenizers/src/tokenizer/mod.rs +++ b/tokenizers/src/tokenizer/mod.rs @@ -1020,6 +1020,7 @@ pub struct DecodeStream<'tok, M, N, PT, PP, D> { /// - read /// - prefix /// - rest + /// /// Read is the bit necessary to surround the prefix /// so decoding the whole ids produces a valid prefix. /// Prefix is the previously produced string, kept around to trim off of From a326447c6959d7cc46efb5eaf69be837531d81f9 Mon Sep 17 00:00:00 2001 From: Nicolas Patry Date: Thu, 7 Nov 2024 22:00:18 +0800 Subject: [PATCH 8/9] Fix internal link. --- tokenizers/src/tokenizer/mod.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tokenizers/src/tokenizer/mod.rs b/tokenizers/src/tokenizer/mod.rs index ac79e6f76..8290a6524 100644 --- a/tokenizers/src/tokenizer/mod.rs +++ b/tokenizers/src/tokenizer/mod.rs @@ -973,7 +973,7 @@ where /// assert_eq!(decode_stream.step(2).unwrap(), Some("é".to_string())); /// ``` /// -/// To see how [`DecodeStream`] is necessary, let's show how using raw [`Tokenizer::decode`] would +/// To see how [`DecodeStream`] is necessary, let's show how using raw [`TokenizerImpl::decode`] would /// fail. 
/// /// ``` From 218fd3b784a23540d65317f184922b4cff0838e9 Mon Sep 17 00:00:00 2001 From: Nicolas Patry Date: Thu, 7 Nov 2024 22:17:00 +0800 Subject: [PATCH 9/9] Skip doctest on Windows (no tokenizer file because no make) --- tokenizers/src/tokenizer/mod.rs | 3 +++ 1 file changed, 3 insertions(+) diff --git a/tokenizers/src/tokenizer/mod.rs b/tokenizers/src/tokenizer/mod.rs index 8290a6524..1d4e62339 100644 --- a/tokenizers/src/tokenizer/mod.rs +++ b/tokenizers/src/tokenizer/mod.rs @@ -922,6 +922,8 @@ where /// Example: /// /// ``` +/// # #[cfg(not(target_os = "windows"))] +/// # { /// use tokenizers::Tokenizer; /// let tokenizer = Tokenizer::from_file("data/roberta.json").unwrap(); /// @@ -933,6 +935,7 @@ where /// decode_stream.step(1246).unwrap(), /// Some(" example".to_string()) /// ); +/// # } /// ``` /// /// Returning `None` means the given id is not enough to produce a chunk.
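
With the series applied, the intended consumption pattern is a generation-style loop that feeds ids into a single `DecodeStream` and emits chunks as they become available. A minimal sketch (the `data/roberta.json` file and the ids 713/16/41/1246 are the ones used by the tests above; in practice the ids would come from a model, which is only hinted at in the comments):

    use tokenizers::Tokenizer;

    fn main() -> tokenizers::Result<()> {
        let tokenizer = Tokenizer::from_file("data/roberta.json")?;
        let mut decode_stream = tokenizer.decode_stream(false);

        // Stand-in for a model's sampling loop producing one id at a time.
        for id in [713u32, 16, 41, 1246] {
            // `step` yields Ok(None) while the accumulated ids cannot yet be
            // rendered as a valid chunk (e.g. incomplete UTF-8 with byte fallback).
            if let Some(chunk) = decode_stream.step(id)? {
                print!("{chunk}"); // "This", " is", " an", " example"
            }
        }
        println!();
        Ok(())
    }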