diff --git a/bindings/python/py_src/tokenizers/decoders/__init__.py b/bindings/python/py_src/tokenizers/decoders/__init__.py index a717379c5..12ada5dbd 100644 --- a/bindings/python/py_src/tokenizers/decoders/__init__.py +++ b/bindings/python/py_src/tokenizers/decoders/__init__.py @@ -12,3 +12,4 @@ BPEDecoder = decoders.BPEDecoder CTC = decoders.CTC Sequence = decoders.Sequence +DecodeStream = decoders.DecodeStream diff --git a/bindings/python/py_src/tokenizers/decoders/__init__.pyi b/bindings/python/py_src/tokenizers/decoders/__init__.pyi index b967fbd14..adad6f53b 100644 --- a/bindings/python/py_src/tokenizers/decoders/__init__.pyi +++ b/bindings/python/py_src/tokenizers/decoders/__init__.pyi @@ -1,4 +1,12 @@ # Generated content DO NOT EDIT +class DecodeStream: + """ + Class needed for streaming decode + + """ + def __init__(self, skip_special_tokens): + pass + class Decoder: """ Base class for all decoders diff --git a/bindings/python/src/decoders.rs b/bindings/python/src/decoders.rs index daa3f8c57..88e0a5398 100644 --- a/bindings/python/src/decoders.rs +++ b/bindings/python/src/decoders.rs @@ -1,6 +1,7 @@ use std::sync::{Arc, RwLock}; use crate::pre_tokenizers::from_string; +use crate::tokenizer::PyTokenizer; use crate::utils::PyPattern; use pyo3::exceptions; use pyo3::prelude::*; @@ -590,9 +591,71 @@ pub fn decoders(m: &Bound<'_, PyModule>) -> PyResult<()> { m.add_class::()?; m.add_class::()?; m.add_class::()?; + m.add_class::()?; Ok(()) } +/// Class needed for streaming decode +/// +#[pyclass(module = "tokenizers.decoders", name = "DecodeStream")] +#[derive(Clone)] +pub struct PyDecodeStream { + /// Regular decode option that is kept throughout. + skip_special_tokens: bool, + /// A temporary buffer of the necessary token_ids needed + /// to produce valid string chunks. + /// This typically contains 3 parts: + /// - read + /// - prefix + /// - rest + /// + /// Read is the bit necessary to surround the prefix + /// so decoding the whole ids produces a valid prefix. + /// Prefix is the previously produced string, kept around to trim off of + /// the next valid chunk + ids: Vec, + /// The previously returned chunk that needs to be discarded from the + /// decoding of the current ids to produce the next chunk + prefix: String, + /// The index within the ids corresponding to the prefix so we can drain + /// correctly + prefix_index: usize, + /// We need to keep 2 prefixes. + /// Prefix is the second one that was already emitted to discard the part + /// of the text of all the ids + /// read is the prefix kept only for starting side effects of the prefix + read_index: usize, +} + +#[pymethods] +impl PyDecodeStream { + #[new] + #[pyo3(signature = (skip_special_tokens), text_signature = "(self, skip_special_tokens)")] + fn new(skip_special_tokens: bool) -> Self { + PyDecodeStream { + skip_special_tokens, + ids: vec![], + prefix: "".to_string(), + prefix_index: 0, + read_index: 0, + } + } + + #[pyo3(signature = (tokenizer, id), text_signature = "(self, tokenizer, id)")] + fn step(&mut self, tokenizer: &PyTokenizer, id: u32) -> PyResult> { + ToPyResult(tk::tokenizer::step_decode_stream( + &tokenizer.tokenizer, + id, + self.skip_special_tokens, + &mut self.ids, + &mut self.prefix, + &mut self.prefix_index, + &mut self.read_index, + )) + .into() + } +} + #[cfg(test)] mod test { use std::sync::{Arc, RwLock}; diff --git a/bindings/python/src/tokenizer.rs b/bindings/python/src/tokenizer.rs index 401a146ab..09fb891e1 100644 --- a/bindings/python/src/tokenizer.rs +++ b/bindings/python/src/tokenizer.rs @@ -467,7 +467,7 @@ type Tokenizer = TokenizerImpl", 0.0), + ("<0x20>", -0.1), + ("<0xC3>", -0.2), + ("<0xA9>", -0.3), + ] + tokenizer = Tokenizer(Unigram(vocab, 0, byte_fallback=True)) + tokenizer.decoder = ByteFallback() + stream = DecodeStream(skip_special_tokens=False) + assert stream.step(tokenizer, 1) == " " + assert stream.step(tokenizer, 2) == None + assert stream.step(tokenizer, 3) == "é" + + vocab = [ + ("", 0.0), + ("▁This", -0.1), + ] + tokenizer = Tokenizer(Unigram(vocab, 0, byte_fallback=False)) + tokenizer.decoder = DecoderMetaspace() + stream = DecodeStream(skip_special_tokens=False) + assert stream.step(tokenizer, 1) == "This" + assert stream.step(tokenizer, 1) == " This" + def test_get_vocab(self): tokenizer = Tokenizer(BPE()) tokenizer.add_tokens(["my", "name", "is", "john", "pair"]) diff --git a/tokenizers/src/tokenizer/mod.rs b/tokenizers/src/tokenizer/mod.rs index 1d4e62339..5f542f2a5 100644 --- a/tokenizers/src/tokenizer/mod.rs +++ b/tokenizers/src/tokenizer/mod.rs @@ -1069,24 +1069,50 @@ where /// See [`DecodeStream`] pub fn step(&mut self, id: u32) -> Result> { - self.ids.push(id); - let string = self - .tokenizer - .decode(self.ids.as_slice(), self.skip_special_tokens)?; - if string.len() > self.prefix.len() && !string.ends_with('�') { - if !(string.starts_with(&self.prefix)) { - return Err(Box::new(DecodeStreamError::InvalidPrefix)); - } - let new_text = &string[self.prefix.len()..].to_string(); - let new_prefix_index = self.ids.len() - self.prefix_index; - self.ids = self.ids.drain(self.read_index..).collect(); - self.prefix = self.tokenizer.decode(&self.ids, self.skip_special_tokens)?; - self.read_index = self.prefix_index; - self.prefix_index = new_prefix_index; - Ok(Some(new_text.to_string())) - } else { - Ok(None) + step_decode_stream( + &self.tokenizer, + id, + self.skip_special_tokens, + &mut self.ids, + &mut self.prefix, + &mut self.prefix_index, + &mut self.read_index, + ) + } +} + +/// Internal function exposed only to bypass python limitations +pub fn step_decode_stream( + tokenizer: &TokenizerImpl, + id: u32, + skip_special_tokens: bool, + ids: &mut Vec, + prefix: &mut String, + prefix_index: &mut usize, + read_index: &mut usize, +) -> Result> +where + M: Model, + N: Normalizer, + PT: PreTokenizer, + PP: PostProcessor, + D: Decoder, +{ + ids.push(id); + let string = tokenizer.decode(ids.as_slice(), skip_special_tokens)?; + if string.len() > prefix.len() && !string.ends_with('�') { + if !(string.starts_with(&*prefix)) { + return Err(Box::new(DecodeStreamError::InvalidPrefix)); } + let new_text = &string[prefix.len()..].to_string(); + let new_prefix_index = ids.len() - *prefix_index; + *ids = ids.drain(*read_index..).collect(); + *prefix = tokenizer.decode(&ids, skip_special_tokens)?; + *read_index = *prefix_index; + *prefix_index = new_prefix_index; + Ok(Some(new_text.to_string())) + } else { + Ok(None) } }