From 500db282a816a2c4f3cb1710c5169db2fda2cff7 Mon Sep 17 00:00:00 2001 From: Nicolas Patry Date: Fri, 15 Nov 2024 13:02:38 +0800 Subject: [PATCH] Adding an API for decode streaming. (#1677) * Adding an API for decode streaming. * Add another missing test case (proving the effect of state.) * Ellide lifetime. * Ellide bis. * Fixing the streaming implementation. * Adding more docs. * End of list. * Fix internal link. * Skip doctest on Windows (no tokenizer file because no make) --- tokenizers/src/tokenizer/mod.rs | 186 +++++++++++++++++++++++++++++- tokenizers/tests/documentation.rs | 63 ++++++++++ 2 files changed, 247 insertions(+), 2 deletions(-) diff --git a/tokenizers/src/tokenizer/mod.rs b/tokenizers/src/tokenizer/mod.rs index 49bc539a2..1d4e62339 100644 --- a/tokenizers/src/tokenizer/mod.rs +++ b/tokenizers/src/tokenizer/mod.rs @@ -12,8 +12,7 @@ use std::{ collections::HashMap, fs::{read_to_string, File}, - io::prelude::*, - io::BufReader, + io::{prelude::*, BufReader}, ops::{Deref, DerefMut}, path::{Path, PathBuf}, }; @@ -906,6 +905,189 @@ where Ok(tokens.join(" ")) } } + + /// Decode the given ids, back to a String + /// See [`DecodeStream`] + pub fn decode_stream(&self, skip_special_tokens: bool) -> DecodeStream<'_, M, N, PT, PP, D> { + DecodeStream::new(self, skip_special_tokens) + } +} + +/// DecodeStream will keep the state necessary to produce individual chunks of +/// strings given an input stream of token_ids. +/// +/// This is necessary because decoding in general cannot achieve that since strings +/// depend on surrounding ids to provide a valid string. Typically stripping extra spaces +/// +/// Example: +/// +/// ``` +/// # #[cfg(not(target_os = "windows"))] +/// # { +/// use tokenizers::Tokenizer; +/// let tokenizer = Tokenizer::from_file("data/roberta.json").unwrap(); +/// +/// let mut decode_stream = tokenizer.decode_stream(false); +/// assert_eq!(decode_stream.step(713).unwrap(), Some("This".to_string())); +/// assert_eq!(decode_stream.step(16).unwrap(), Some(" is".to_string())); +/// assert_eq!(decode_stream.step(41).unwrap(), Some(" an".to_string())); +/// assert_eq!( +/// decode_stream.step(1246).unwrap(), +/// Some(" example".to_string()) +/// ); +/// # } +/// ``` +/// +/// Returning `None` means the given id is not enough to produce a chunk. +/// This typically happens with `byte_fallback` options where some tokens do +/// not represent valid utf-8, and only follow-up token_ids will help produce +/// a valid chunk. +/// ``` +/// use tokenizers::{Tokenizer, TokenizerBuilder, models::bpe::BPE, decoders::byte_fallback::ByteFallback, pre_tokenizers::byte_level::ByteLevel, normalizers::unicode::NFC}; +/// use std::collections::HashMap; +/// use std::iter::FromIterator; +/// +/// let vocab = HashMap::from_iter([ +/// ("<0x20>".to_string(), 0), +/// ("<0xC3>".to_string(), 1), +/// ("<0xA9>".to_string(), 2), +/// (" This".to_string(), 3), +/// ]); +/// let merges = vec![]; +/// let bpe = BPE::builder() +/// .vocab_and_merges(vocab, merges) +/// .byte_fallback(true) +/// .build() +/// .unwrap(); +/// let tokenizer = TokenizerBuilder::default() +/// .with_model(bpe) +/// .with_decoder(Some(ByteFallback::default())) +/// .with_normalizer(Some(NFC)) +/// .with_pre_tokenizer(Some(ByteLevel::default())) +/// .with_post_processor(Some(ByteLevel::default())) +/// .build().unwrap(); +/// +/// let mut decode_stream = tokenizer.decode_stream(false); +/// // Single byte_fallback is valid utf-8 +/// assert_eq!(decode_stream.step(0).unwrap(), Some(" ".to_string())); +/// // Invalid utf-8 +/// assert_eq!(decode_stream.step(1).unwrap(), None); +/// // Valid utf-8 again, this corresponds to both tokens: [1, 2] +/// assert_eq!(decode_stream.step(2).unwrap(), Some("é".to_string())); +/// ``` +/// +/// To see how [`DecodeStream`] is necessary, let's show how using raw [`TokenizerImpl::decode`] would +/// fail. +/// +/// ``` +/// use tokenizers::{Tokenizer, TokenizerBuilder, models::bpe::BPE, pre_tokenizers::{byte_level::ByteLevel, metaspace::Metaspace}, normalizers::unicode::NFC}; +/// use std::collections::HashMap; +/// use std::iter::FromIterator; +/// +/// let vocab = HashMap::from_iter([ +/// ("▁This".to_string(), 0), +/// ]); +/// let merges = vec![]; +/// let bpe = BPE::builder() +/// .vocab_and_merges(vocab, merges) +/// .byte_fallback(true) +/// .build() +/// .unwrap(); +/// let tokenizer = TokenizerBuilder::new() +/// .with_model(bpe) +/// .with_decoder(Some(Metaspace::default())) +/// .with_normalizer(Some(NFC)) +/// .with_pre_tokenizer(Some(ByteLevel::default())) +/// .with_post_processor(Some(ByteLevel::default())) +/// .build() +/// .unwrap(); +/// +/// // Strip decoder removes the extra initial space +/// assert_eq!(tokenizer.decode(&[0, 0], false).unwrap(), "This This"); +/// // Decoding one token at a time would produce "ThisThis" +/// assert_eq!(tokenizer.decode(&[0], false).unwrap(), "This"); +/// +/// // Using a stream fixes it by keeping the necessary state. +/// let mut decode_stream = tokenizer.decode_stream(false); +/// assert_eq!(decode_stream.step(0).unwrap(), Some("This".to_string())); +/// assert_eq!(decode_stream.step(0).unwrap(), Some(" This".to_string())); +/// ``` +pub struct DecodeStream<'tok, M, N, PT, PP, D> { + /// A reference to the tokenizer + tokenizer: &'tok TokenizerImpl, + /// Regular decode option that is kept throughout. + skip_special_tokens: bool, + /// A temporary buffer of the necessary token_ids needed + /// to produce valid string chunks. + /// This typically contains 3 parts: + /// - read + /// - prefix + /// - rest + /// + /// Read is the bit necessary to surround the prefix + /// so decoding the whole ids produces a valid prefix. + /// Prefix is the previously produced string, kept around to trim off of + /// the next valid chunk + ids: Vec, + /// The previously returned chunk that needs to be discarded from the + /// decoding of the current ids to produce the next chunk + prefix: String, + /// The index within the ids corresponding to the prefix so we can drain + /// correctly + prefix_index: usize, + /// We need to keep 2 prefixes. + /// Prefix is the second one that was already emitted to discard the part + /// of the text of all the ids + /// read is the prefix kept only for starting side effects of the prefix + read_index: usize, +} + +#[derive(thiserror::Error, Debug)] +pub enum DecodeStreamError { + #[error("Invalid prefix encountered")] + InvalidPrefix, +} + +impl<'tok, M, N, PT, PP, D> DecodeStream<'tok, M, N, PT, PP, D> +where + M: Model, + N: Normalizer, + PT: PreTokenizer, + PP: PostProcessor, + D: Decoder, +{ + fn new(tokenizer: &'tok TokenizerImpl, skip_special_tokens: bool) -> Self { + Self { + tokenizer, + ids: vec![], + skip_special_tokens, + prefix: "".to_string(), + prefix_index: 0, + read_index: 0, + } + } + + /// See [`DecodeStream`] + pub fn step(&mut self, id: u32) -> Result> { + self.ids.push(id); + let string = self + .tokenizer + .decode(self.ids.as_slice(), self.skip_special_tokens)?; + if string.len() > self.prefix.len() && !string.ends_with('�') { + if !(string.starts_with(&self.prefix)) { + return Err(Box::new(DecodeStreamError::InvalidPrefix)); + } + let new_text = &string[self.prefix.len()..].to_string(); + let new_prefix_index = self.ids.len() - self.prefix_index; + self.ids = self.ids.drain(self.read_index..).collect(); + self.prefix = self.tokenizer.decode(&self.ids, self.skip_special_tokens)?; + self.read_index = self.prefix_index; + self.prefix_index = new_prefix_index; + Ok(Some(new_text.to_string())) + } else { + Ok(None) + } + } } impl TokenizerImpl diff --git a/tokenizers/tests/documentation.rs b/tokenizers/tests/documentation.rs index ad29590b9..304211e77 100644 --- a/tokenizers/tests/documentation.rs +++ b/tokenizers/tests/documentation.rs @@ -1,3 +1,7 @@ +use std::collections::HashMap; +use std::iter::FromIterator; + +use tokenizers::decoders::byte_fallback::ByteFallback; use tokenizers::models::bpe::{BpeTrainerBuilder, BPE}; use tokenizers::normalizers::{Sequence, Strip, NFC}; use tokenizers::pre_tokenizers::byte_level::ByteLevel; @@ -58,6 +62,65 @@ fn load_tokenizer() { assert_eq!(decoded, example); } +#[test] +fn streaming_tokenizer() { + let tokenizer = Tokenizer::from_file("data/roberta.json").unwrap(); + + let mut decode_stream = tokenizer.decode_stream(false); + assert_eq!(decode_stream.step(713).unwrap(), Some("This".to_string())); + assert_eq!(decode_stream.step(16).unwrap(), Some(" is".to_string())); + assert_eq!(decode_stream.step(41).unwrap(), Some(" an".to_string())); + assert_eq!( + decode_stream.step(1246).unwrap(), + Some(" example".to_string()) + ); + + let tokenizer = Tokenizer::from_file("data/albert-base-v1-tokenizer.json").unwrap(); + let encoded = tokenizer.encode("This is an example", false).unwrap(); + assert_eq!(encoded.get_ids(), &[48, 25, 40, 823]); + let mut decode_stream = tokenizer.decode_stream(false); + // No space anymore + assert_eq!(decode_stream.step(25).unwrap(), Some("is".to_string())); + let mut decode_stream = tokenizer.decode_stream(false); + assert_eq!(decode_stream.step(48).unwrap(), Some("this".to_string())); + assert_eq!(decode_stream.step(25).unwrap(), Some(" is".to_string())); + assert_eq!(decode_stream.step(40).unwrap(), Some(" an".to_string())); + assert_eq!( + decode_stream.step(823).unwrap(), + Some(" example".to_string()) + ); + + // None example + let vocab = HashMap::from_iter([ + ("<0x20>".to_string(), 0), + ("<0xC3>".to_string(), 1), + ("<0xA9>".to_string(), 2), + (" This".to_string(), 3), + ]); + let merges = vec![]; + let bpe = BPE::builder() + .vocab_and_merges(vocab, merges) + .byte_fallback(true) + .build() + .unwrap(); + let tokenizer = TokenizerBuilder::new() + .with_model(bpe) + .with_normalizer(Some(Sequence::new(vec![ + Strip::new(true, true).into(), + NFC.into(), + ]))) + .with_pre_tokenizer(Some(ByteLevel::default())) + .with_post_processor(Some(ByteLevel::default())) + .with_decoder(Some(ByteFallback::default())) + .build() + .unwrap(); + let mut decode_stream = tokenizer.decode_stream(false); + assert_eq!(decode_stream.step(0).unwrap(), Some(" ".to_string())); + assert_eq!(decode_stream.step(1).unwrap(), None); + assert_eq!(decode_stream.step(2).unwrap(), Some("é".to_string())); + assert_eq!(decode_stream.step(2).unwrap(), None); +} + #[test] #[ignore] fn quicktour_slow_train() -> tokenizers::Result<()> {