From c97389bca6ce0171fc59d207cc907640e2e4b1fd Mon Sep 17 00:00:00 2001
From: Nicolas Patry
Date: Wed, 6 Nov 2024 20:52:05 +0800
Subject: [PATCH] Adding an API for decode streaming.

---
 tokenizers/src/tokenizer/mod.rs   | 67 +++++++++++++++++++++++++++++++++--
 tokenizers/tests/documentation.rs | 16 ++++++++
 2 files changed, 81 insertions(+), 2 deletions(-)

diff --git a/tokenizers/src/tokenizer/mod.rs b/tokenizers/src/tokenizer/mod.rs
index 49bc539a2..466b552d7 100644
--- a/tokenizers/src/tokenizer/mod.rs
+++ b/tokenizers/src/tokenizer/mod.rs
@@ -12,8 +12,7 @@
 use std::{
     collections::HashMap,
     fs::{read_to_string, File},
-    io::prelude::*,
-    io::BufReader,
+    io::{prelude::*, BufReader},
     ops::{Deref, DerefMut},
     path::{Path, PathBuf},
 };
@@ -906,6 +905,70 @@ where
             Ok(tokens.join(" "))
         }
     }
+
+    /// Decode the given ids, back to a String
+    pub fn decode_stream<'tok>(
+        &'tok self,
+        skip_special_tokens: bool,
+    ) -> DecodeStream<'tok, M, N, PT, PP, D> {
+        DecodeStream::new(self, skip_special_tokens)
+    }
+}
+
+/// Keeps the state necessary to produce individual string chunks from an
+/// incoming stream of token ids.
+pub struct DecodeStream<'tok, M, N, PT, PP, D> {
+    /// A reference to the tokenizer
+    tokenizer: &'tok TokenizerImpl<M, N, PT, PP, D>,
+    /// A temporary buffer of the token ids needed to produce the next chunk
+    ids: Vec<u32>,
+    /// The index within `ids` at which the previously returned chunk ends
+    prefix_index: usize,
+    /// The previously decoded text, stripped from the next decode so that
+    /// only the newly produced chunk is returned
+    prefix: String,
+    /// Regular decode option that is kept throughout
+    skip_special_tokens: bool,
+}
+
+impl<'tok, M, N, PT, PP, D> DecodeStream<'tok, M, N, PT, PP, D>
+where
+    M: Model,
+    N: Normalizer,
+    PT: PreTokenizer,
+    PP: PostProcessor,
+    D: Decoder,
+{
+    fn new(tokenizer: &'tok TokenizerImpl<M, N, PT, PP, D>, skip_special_tokens: bool) -> Self {
+        Self {
+            tokenizer,
+            ids: vec![],
+            skip_special_tokens,
+            prefix: "".to_string(),
+            prefix_index: 0,
+        }
+    }
+
+    /// Feed one id; returns `Ok(Some(chunk))` once the buffered ids decode to
+    /// a complete piece of text, `Ok(None)` while more ids are needed.
+    pub fn step(&mut self, id: u32) -> Result<Option<String>> {
+        self.ids.push(id);
+        let string = self
+            .tokenizer
+            .decode(self.ids.as_slice(), self.skip_special_tokens)?;
+        // A trailing replacement character (U+FFFD) means the last id only
+        // produced a partial UTF-8 sequence: wait for more ids.
+        if string.len() > self.prefix.len() && !string.ends_with('�') {
+            let new_text = &string[self.prefix.len()..];
+            self.prefix = new_text.to_string();
+            let new_prefix_index = self.ids.len() - self.prefix_index;
+            self.ids = self.ids.drain(self.prefix_index..).collect();
+            self.prefix_index = new_prefix_index;
+            Ok(Some(new_text.to_string()))
+        } else {
+            Ok(None)
+        }
+    }
 }
 
 impl<M, N, PT, PP, D> TokenizerImpl<M, N, PT, PP, D>
diff --git a/tokenizers/tests/documentation.rs b/tokenizers/tests/documentation.rs
index ad29590b9..ae334c06c 100644
--- a/tokenizers/tests/documentation.rs
+++ b/tokenizers/tests/documentation.rs
@@ -58,6 +58,22 @@ fn load_tokenizer() {
     assert_eq!(decoded, example);
 }
 
+#[test]
+fn streaming_tokenizer() {
+    let tokenizer = Tokenizer::from_file("data/roberta.json").unwrap();
+
+    let mut decode_stream = tokenizer.decode_stream(false);
+    assert_eq!(decode_stream.step(713).unwrap(), Some("This".to_string()));
+    assert_eq!(decode_stream.step(16).unwrap(), Some(" is".to_string()));
+    assert_eq!(decode_stream.step(41).unwrap(), Some(" an".to_string()));
+    assert_eq!(
+        decode_stream.step(1246).unwrap(),
+        Some(" example".to_string())
+    );
+
+    // TODO: add a byte-fallback example where `step` returns `None`
+}
+
 #[test]
 #[ignore]
 fn quicktour_slow_train() -> tokenizers::Result<()> {
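
Usage sketch (not part of the patch): driving the new API one id at a time, the
way a text-generation loop would. A minimal sketch assuming the same
data/roberta.json fixture the test above uses:

    use tokenizers::Tokenizer;

    fn main() -> tokenizers::Result<()> {
        // Same fixture and ids as the streaming_tokenizer test above.
        let tokenizer = Tokenizer::from_file("data/roberta.json")?;
        let mut decode_stream = tokenizer.decode_stream(false);

        // Ids as they might arrive, one at a time, from a generation loop.
        for id in [713u32, 16, 41, 1246] {
            // `step` returns Ok(None) until the buffered ids decode to a
            // complete chunk, then Ok(Some(new_text)).
            if let Some(chunk) = decode_stream.step(id)? {
                print!("{chunk}"); // prints "This is an example" incrementally
            }
        }
        Ok(())
    }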
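
For intuition about the `prefix`/`prefix_index` bookkeeping in `step`, a hand
trace of the test sequence (intermediate decoded strings are inferred from the
test assertions):

    step(713)  ids=[713]      decode -> "This"        emit "This"
               keep ids=[713],  prefix="This", prefix_index=1
    step(16)   ids=[713,16]   decode -> "This is"     strip "This", emit " is"
               keep ids=[16],   prefix=" is",  prefix_index=1
    step(41)   ids=[16,41]    decode -> " is an"      strip " is",  emit " an"
               keep ids=[41],   prefix=" an",  prefix_index=1
    step(1246) ids=[41,1246]  decode -> " an example" strip " an",  emit " example"

The buffer keeps only the ids still needed to disambiguate the next chunk, so
each `step` re-decodes a short suffix rather than the whole history.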
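
The `Ok(None)` branch (the test's TODO) is exercised by byte-fallback
vocabularies, where a single id can carry only a fraction of a UTF-8 character.
A hypothetical sketch: the file name and the `<0xC3>`/`<0xA9>` byte-token names
are assumptions in the style of Llama byte fallback, not part of this patch:

    // "é" is the two-byte sequence 0xC3 0xA9 in UTF-8.
    let tokenizer = Tokenizer::from_file("data/llama-byte-fallback.json")?;
    let id_c3 = tokenizer.token_to_id("<0xC3>").expect("byte token in vocab");
    let id_a9 = tokenizer.token_to_id("<0xA9>").expect("byte token in vocab");

    let mut decode_stream = tokenizer.decode_stream(false);
    // The first byte alone decodes with a trailing U+FFFD, so no chunk yet.
    assert_eq!(decode_stream.step(id_c3)?, None);
    // The second byte completes the character.
    assert_eq!(decode_stream.step(id_a9)?, Some("é".to_string()));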