Python binding for decode stream
Different API because Python cannot handle lifetimes properly.
Narsil committed Nov 15, 2024
1 parent 500db28 commit d35b966
Showing 6 changed files with 148 additions and 18 deletions.
1 change: 1 addition & 0 deletions bindings/python/py_src/tokenizers/decoders/__init__.py
@@ -12,3 +12,4 @@
BPEDecoder = decoders.BPEDecoder
CTC = decoders.CTC
Sequence = decoders.Sequence
DecodeStream = decoders.DecodeStream
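
For context, here is a minimal usage sketch of the newly exported class. It mirrors the assertions added to test_decode later in this commit; the tokenizer setup (a plain BPE model with a handful of added tokens) is assumed from that test file and is not part of this diff hunk.

from tokenizers import Tokenizer
from tokenizers.models import BPE
from tokenizers.decoders import DecodeStream

# Assumed setup, matching the test suite: ids 0..3 decode to "my name is john".
tokenizer = Tokenizer(BPE())
tokenizer.add_tokens(["my", "name", "is", "john", "pair"])

stream = DecodeStream(skip_special_tokens=False)
assert stream.step(tokenizer, 0) == "my"
assert stream.step(tokenizer, 1) == " name"
assert stream.step(tokenizer, 2) == " is"
assert stream.step(tokenizer, 3) == " john"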
8 changes: 8 additions & 0 deletions bindings/python/py_src/tokenizers/decoders/__init__.pyi
@@ -1,4 +1,12 @@
# Generated content DO NOT EDIT
class DecodeStream:
    """
    Class needed for streaming decode
    """
    def __init__(self, skip_special_tokens):
        pass

class Decoder:
    """
    Base class for all decoders
63 changes: 63 additions & 0 deletions bindings/python/src/decoders.rs
@@ -1,6 +1,7 @@
use std::sync::{Arc, RwLock};

use crate::pre_tokenizers::from_string;
use crate::tokenizer::PyTokenizer;
use crate::utils::PyPattern;
use pyo3::exceptions;
use pyo3::prelude::*;
@@ -590,9 +591,71 @@ pub fn decoders(m: &Bound<'_, PyModule>) -> PyResult<()> {
    m.add_class::<PyBPEDecoder>()?;
    m.add_class::<PyCTCDecoder>()?;
    m.add_class::<PySequenceDecoder>()?;
    m.add_class::<PyDecodeStream>()?;
    Ok(())
}

/// Class needed for streaming decode
///
#[pyclass(module = "tokenizers.decoders", name = "DecodeStream")]
#[derive(Clone)]
pub struct PyDecodeStream {
    /// Regular decode option that is kept throughout.
    skip_special_tokens: bool,
    /// A temporary buffer of the token ids needed to produce
    /// valid string chunks.
    /// It typically contains 3 parts:
    /// - read
    /// - prefix
    /// - rest
    ///
    /// `read` is the part required before the prefix so that decoding
    /// the whole buffer produces a valid prefix.
    /// `prefix` is the previously produced string, kept around so it can be
    /// trimmed off of the next valid chunk.
    ids: Vec<u32>,
    /// The previously returned chunk that needs to be discarded from the
    /// decoding of the current ids to produce the next chunk.
    prefix: String,
    /// The index within `ids` that corresponds to the prefix, so we can
    /// drain the buffer correctly.
    prefix_index: usize,
    /// We effectively keep 2 prefixes:
    /// `prefix` is the most recent one, already emitted, and is used to discard
    /// the already-returned text from the decoding of all the ids;
    /// `read` is the older prefix, kept only for the decoding side effects it
    /// has at the start of the buffer.
    read_index: usize,
}

#[pymethods]
impl PyDecodeStream {
    #[new]
    #[pyo3(signature = (skip_special_tokens), text_signature = "(self, skip_special_tokens)")]
    fn new(skip_special_tokens: bool) -> Self {
        PyDecodeStream {
            skip_special_tokens,
            ids: vec![],
            prefix: "".to_string(),
            prefix_index: 0,
            read_index: 0,
        }
    }

    #[pyo3(signature = (tokenizer, id), text_signature = "(self, tokenizer, id)")]
    fn step(&mut self, tokenizer: &PyTokenizer, id: u32) -> PyResult<Option<String>> {
        ToPyResult(tk::tokenizer::step_decode_stream(
            &tokenizer.tokenizer,
            id,
            self.skip_special_tokens,
            &mut self.ids,
            &mut self.prefix,
            &mut self.prefix_index,
            &mut self.read_index,
        ))
        .into()
    }
}

#[cfg(test)]
mod test {
    use std::sync::{Arc, RwLock};
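
Design note: unlike the Rust DecodeStream, which borrows the tokenizer for its entire lifetime, PyDecodeStream only stores the plain state shown above and receives the tokenizer on every step call (this is the lifetime workaround mentioned in the commit message). In a streaming generation loop the intended shape is roughly the following sketch; generated_ids and the tokenizer file are placeholders, not part of this commit.

from tokenizers import Tokenizer
from tokenizers.decoders import DecodeStream

tokenizer = Tokenizer.from_file("tokenizer.json")  # hypothetical file, for illustration
stream = DecodeStream(skip_special_tokens=True)

generated_ids = [0, 1, 2, 3]  # placeholder for ids streamed out of a model
for token_id in generated_ids:
    chunk = stream.step(tokenizer, token_id)  # tokenizer passed on every call
    if chunk is not None:                     # None: no printable chunk yet
        print(chunk, end="", flush=True)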
2 changes: 1 addition & 1 deletion bindings/python/src/tokenizer.rs
@@ -467,7 +467,7 @@ type Tokenizer = TokenizerImpl<PyModel, PyNormalizer, PyPreTokenizer, PyPostProc
#[derive(Clone, Serialize)]
#[serde(transparent)]
pub struct PyTokenizer {
    tokenizer: Tokenizer,
    pub(crate) tokenizer: Tokenizer,
}

impl PyTokenizer {
32 changes: 32 additions & 0 deletions bindings/python/tests/bindings/test_tokenizer.py
@@ -9,6 +9,7 @@
from tokenizers.pre_tokenizers import ByteLevel, Metaspace
from tokenizers.processors import RobertaProcessing, TemplateProcessing
from tokenizers.normalizers import Strip, Lowercase, Sequence
from tokenizers.decoders import ByteFallback, DecodeStream, Metaspace as DecoderMetaspace


from ..utils import bert_files, data_dir, multiprocessing_with_parallelism, roberta_files
@@ -365,6 +366,37 @@ def test_decode(self):
        output = tokenizer.decode_batch([[0, 1, 2, 3], [4]])
        assert output == ["my name is john", "pair"]

        # Can decode stream
        stream = DecodeStream(skip_special_tokens=False)
        assert stream.step(tokenizer, 0) == "my"
        assert stream.step(tokenizer, 1) == " name"
        assert stream.step(tokenizer, 2) == " is"
        assert stream.step(tokenizer, 3) == " john"

    def test_decode_stream(self):
        vocab = [
            ("<unk>", 0.0),
            ("<0x20>", -0.1),
            ("<0xC3>", -0.2),
            ("<0xA9>", -0.3),
        ]
        tokenizer = Tokenizer(Unigram(vocab, 0, byte_fallback=True))
        tokenizer.decoder = ByteFallback()
        stream = DecodeStream(skip_special_tokens=False)
        assert stream.step(tokenizer, 1) == " "
        assert stream.step(tokenizer, 2) == None
        assert stream.step(tokenizer, 3) == "é"

        vocab = [
            ("<unk>", 0.0),
            ("▁This", -0.1),
        ]
        tokenizer = Tokenizer(Unigram(vocab, 0, byte_fallback=False))
        tokenizer.decoder = DecoderMetaspace()
        stream = DecodeStream(skip_special_tokens=False)
        assert stream.step(tokenizer, 1) == "This"
        assert stream.step(tokenizer, 1) == " This"

    def test_get_vocab(self):
        tokenizer = Tokenizer(BPE())
        tokenizer.add_tokens(["my", "name", "is", "john", "pair"])
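
The None in test_decode_stream above is the incomplete-UTF-8 case: token 2 decodes (via byte fallback) to the single byte 0xC3, the first byte of a two-byte UTF-8 sequence, so the stream has nothing printable to emit until 0xA9 arrives and completes "é". A tokenizer-free illustration of that invariant:

# A lone 0xC3 lead byte is not valid UTF-8; a streaming decoder must hold it back.
assert bytes([0xC3]).decode("utf-8", errors="replace") == "\ufffd"
# Once the continuation byte 0xA9 arrives, the pair decodes to "é" and a chunk is emitted.
assert bytes([0xC3, 0xA9]).decode("utf-8") == "é"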
60 changes: 43 additions & 17 deletions tokenizers/src/tokenizer/mod.rs
@@ -1069,24 +1069,50 @@ where

    /// See [`DecodeStream`]
    pub fn step(&mut self, id: u32) -> Result<Option<String>> {
        self.ids.push(id);
        let string = self
            .tokenizer
            .decode(self.ids.as_slice(), self.skip_special_tokens)?;
        if string.len() > self.prefix.len() && !string.ends_with('�') {
            if !(string.starts_with(&self.prefix)) {
                return Err(Box::new(DecodeStreamError::InvalidPrefix));
            }
            let new_text = &string[self.prefix.len()..].to_string();
            let new_prefix_index = self.ids.len() - self.prefix_index;
            self.ids = self.ids.drain(self.read_index..).collect();
            self.prefix = self.tokenizer.decode(&self.ids, self.skip_special_tokens)?;
            self.read_index = self.prefix_index;
            self.prefix_index = new_prefix_index;
            Ok(Some(new_text.to_string()))
        } else {
            Ok(None)
        step_decode_stream(
            &self.tokenizer,
            id,
            self.skip_special_tokens,
            &mut self.ids,
            &mut self.prefix,
            &mut self.prefix_index,
            &mut self.read_index,
        )
    }
}

/// Internal function exposed only to bypass python limitations
pub fn step_decode_stream<M, N, PT, PP, D>(
    tokenizer: &TokenizerImpl<M, N, PT, PP, D>,
    id: u32,
    skip_special_tokens: bool,
    ids: &mut Vec<u32>,
    prefix: &mut String,
    prefix_index: &mut usize,
    read_index: &mut usize,
) -> Result<Option<String>>
where
    M: Model,
    N: Normalizer,
    PT: PreTokenizer,
    PP: PostProcessor,
    D: Decoder,
{
    ids.push(id);
    let string = tokenizer.decode(ids.as_slice(), skip_special_tokens)?;
    if string.len() > prefix.len() && !string.ends_with('�') {
        if !(string.starts_with(&*prefix)) {
            return Err(Box::new(DecodeStreamError::InvalidPrefix));
        }
        let new_text = &string[prefix.len()..].to_string();
        let new_prefix_index = ids.len() - *prefix_index;
        *ids = ids.drain(*read_index..).collect();
        *prefix = tokenizer.decode(&ids, skip_special_tokens)?;
        *read_index = *prefix_index;
        *prefix_index = new_prefix_index;
        Ok(Some(new_text.to_string()))
    } else {
        Ok(None)
    }
}

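
For readers coming from the Python side, the following is a line-for-line transcription of step_decode_stream above into Python, using a plain dict for the ids / prefix / prefix_index / read_index state. It is an illustration of the algorithm only; the binding calls the Rust helper directly.

def step_decode_stream(tokenizer, token_id, skip_special_tokens, state):
    """Python sketch of the Rust helper above (illustration only)."""
    state["ids"].append(token_id)
    string = tokenizer.decode(state["ids"], skip_special_tokens=skip_special_tokens)
    # Emit only once the decoded text has grown past the previous prefix and does
    # not end in U+FFFD, which would indicate a dangling, incomplete UTF-8 sequence.
    if len(string) > len(state["prefix"]) and not string.endswith("\ufffd"):
        if not string.startswith(state["prefix"]):
            raise ValueError("invalid prefix")  # mirrors DecodeStreamError::InvalidPrefix
        new_text = string[len(state["prefix"]):]
        new_prefix_index = len(state["ids"]) - state["prefix_index"]
        # Drop the ids that were only kept as decoding context for the old prefix.
        state["ids"] = state["ids"][state["read_index"]:]
        state["prefix"] = tokenizer.decode(state["ids"], skip_special_tokens=skip_special_tokens)
        state["read_index"] = state["prefix_index"]
        state["prefix_index"] = new_prefix_index
        return new_text
    return None

The starting state would be {"ids": [], "prefix": "", "prefix_index": 0, "read_index": 0}, matching PyDecodeStream::new above.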
