Adding an API for decode streaming. (#1677)
* Adding an API for decode streaming.

* Add another missing test case (proving the effect of state.)

* Elide lifetime.

* Elide bis.

* Fixing the streaming implementation.

* Adding more docs.

* End of list.

* Fix internal link.

* Skip doctest on Windows (no tokenizer file because no make)
Narsil authored Nov 15, 2024
1 parent f4c9fd7 commit 500db28
Showing 2 changed files with 247 additions and 2 deletions.
186 changes: 184 additions & 2 deletions tokenizers/src/tokenizer/mod.rs
@@ -12,8 +12,7 @@
use std::{
collections::HashMap,
fs::{read_to_string, File},
- io::prelude::*,
- io::BufReader,
+ io::{prelude::*, BufReader},
ops::{Deref, DerefMut},
path::{Path, PathBuf},
};
@@ -906,6 +905,189 @@ where
Ok(tokens.join(" "))
}
}

/// Decode the given ids back to a String, one chunk at a time.
/// See [`DecodeStream`] for details.
pub fn decode_stream(&self, skip_special_tokens: bool) -> DecodeStream<'_, M, N, PT, PP, D> {
DecodeStream::new(self, skip_special_tokens)
}
}

/// DecodeStream keeps the state necessary to produce individual chunks of
/// strings given an input stream of token_ids.
///
/// This is necessary because decoding ids one by one cannot, in general,
/// produce correct chunks: the string for a given id depends on the
/// surrounding ids, typically to strip extra spaces.
///
/// Example:
///
/// ```
/// # #[cfg(not(target_os = "windows"))]
/// # {
/// use tokenizers::Tokenizer;
/// let tokenizer = Tokenizer::from_file("data/roberta.json").unwrap();
///
/// let mut decode_stream = tokenizer.decode_stream(false);
/// assert_eq!(decode_stream.step(713).unwrap(), Some("This".to_string()));
/// assert_eq!(decode_stream.step(16).unwrap(), Some(" is".to_string()));
/// assert_eq!(decode_stream.step(41).unwrap(), Some(" an".to_string()));
/// assert_eq!(
/// decode_stream.step(1246).unwrap(),
/// Some(" example".to_string())
/// );
/// # }
/// ```
///
/// Returning `None` means the given id is not enough to produce a chunk.
/// This typically happens with `byte_fallback` options where some tokens do
/// not represent valid utf-8, and only follow-up token_ids will help produce
/// a valid chunk.
/// ```
/// use tokenizers::{Tokenizer, TokenizerBuilder, models::bpe::BPE, decoders::byte_fallback::ByteFallback, pre_tokenizers::byte_level::ByteLevel, normalizers::unicode::NFC};
/// use std::collections::HashMap;
/// use std::iter::FromIterator;
///
/// let vocab = HashMap::from_iter([
/// ("<0x20>".to_string(), 0),
/// ("<0xC3>".to_string(), 1),
/// ("<0xA9>".to_string(), 2),
/// (" This".to_string(), 3),
/// ]);
/// let merges = vec![];
/// let bpe = BPE::builder()
/// .vocab_and_merges(vocab, merges)
/// .byte_fallback(true)
/// .build()
/// .unwrap();
/// let tokenizer = TokenizerBuilder::default()
/// .with_model(bpe)
/// .with_decoder(Some(ByteFallback::default()))
/// .with_normalizer(Some(NFC))
/// .with_pre_tokenizer(Some(ByteLevel::default()))
/// .with_post_processor(Some(ByteLevel::default()))
/// .build().unwrap();
///
/// let mut decode_stream = tokenizer.decode_stream(false);
/// // Single byte_fallback is valid utf-8
/// assert_eq!(decode_stream.step(0).unwrap(), Some(" ".to_string()));
/// // Invalid utf-8
/// assert_eq!(decode_stream.step(1).unwrap(), None);
/// // Valid utf-8 again, this corresponds to both tokens: [1, 2]
/// assert_eq!(decode_stream.step(2).unwrap(), Some("é".to_string()));
/// ```
///
/// To see why [`DecodeStream`] is necessary, let's show how using raw [`TokenizerImpl::decode`] would
/// fail.
///
/// ```
/// use tokenizers::{Tokenizer, TokenizerBuilder, models::bpe::BPE, pre_tokenizers::{byte_level::ByteLevel, metaspace::Metaspace}, normalizers::unicode::NFC};
/// use std::collections::HashMap;
/// use std::iter::FromIterator;
///
/// let vocab = HashMap::from_iter([
/// ("▁This".to_string(), 0),
/// ]);
/// let merges = vec![];
/// let bpe = BPE::builder()
/// .vocab_and_merges(vocab, merges)
/// .byte_fallback(true)
/// .build()
/// .unwrap();
/// let tokenizer = TokenizerBuilder::new()
/// .with_model(bpe)
/// .with_decoder(Some(Metaspace::default()))
/// .with_normalizer(Some(NFC))
/// .with_pre_tokenizer(Some(ByteLevel::default()))
/// .with_post_processor(Some(ByteLevel::default()))
/// .build()
/// .unwrap();
///
/// // The Metaspace decoder strips the extra initial space
/// assert_eq!(tokenizer.decode(&[0, 0], false).unwrap(), "This This");
/// // Decoding one token at a time would produce "ThisThis"
/// assert_eq!(tokenizer.decode(&[0], false).unwrap(), "This");
///
/// // Using a stream fixes it by keeping the necessary state.
/// let mut decode_stream = tokenizer.decode_stream(false);
/// assert_eq!(decode_stream.step(0).unwrap(), Some("This".to_string()));
/// assert_eq!(decode_stream.step(0).unwrap(), Some(" This".to_string()));
/// ```
pub struct DecodeStream<'tok, M, N, PT, PP, D> {
/// A reference to the tokenizer
tokenizer: &'tok TokenizerImpl<M, N, PT, PP, D>,
/// Regular decode option that is kept throughout.
skip_special_tokens: bool,
/// A temporary buffer of the token_ids needed
/// to produce valid string chunks.
/// It typically contains 3 parts:
/// - read
/// - prefix
/// - rest
///
/// Read is the part needed to surround the prefix
/// so that decoding the whole buffer produces a valid prefix.
/// Prefix is the previously produced string, kept around to trim it off of
/// the next valid chunk.
ids: Vec<u32>,
/// The previously returned chunk that needs to be discarded from the
/// decoding of the current ids to produce the next chunk
prefix: String,
/// The index within the ids corresponding to the prefix so we can drain
/// correctly
prefix_index: usize,
/// We need to keep 2 prefixes.
/// `prefix` is the most recent one, already emitted, used to discard its
/// part of the text decoded from all the ids.
/// `read` is the older prefix, kept only for the side effects it has on how
/// the current prefix starts (e.g. leading spaces).
read_index: usize,
}

#[derive(thiserror::Error, Debug)]
pub enum DecodeStreamError {
#[error("Invalid prefix encountered")]
InvalidPrefix,
}

impl<'tok, M, N, PT, PP, D> DecodeStream<'tok, M, N, PT, PP, D>
where
M: Model,
N: Normalizer,
PT: PreTokenizer,
PP: PostProcessor,
D: Decoder,
{
fn new(tokenizer: &'tok TokenizerImpl<M, N, PT, PP, D>, skip_special_tokens: bool) -> Self {
Self {
tokenizer,
ids: vec![],
skip_special_tokens,
prefix: "".to_string(),
prefix_index: 0,
read_index: 0,
}
}

/// See [`DecodeStream`]
pub fn step(&mut self, id: u32) -> Result<Option<String>> {
self.ids.push(id);
let string = self
.tokenizer
.decode(self.ids.as_slice(), self.skip_special_tokens)?;
if string.len() > self.prefix.len() && !string.ends_with('�') {
if !string.starts_with(&self.prefix) {
return Err(Box::new(DecodeStreamError::InvalidPrefix));
}
let new_text = string[self.prefix.len()..].to_string();
let new_prefix_index = self.ids.len() - self.prefix_index;
self.ids = self.ids.drain(self.read_index..).collect();
self.prefix = self.tokenizer.decode(&self.ids, self.skip_special_tokens)?;
self.read_index = self.prefix_index;
self.prefix_index = new_prefix_index;
Ok(Some(new_text))
} else {
Ok(None)
}
}
}

impl<M, N, PT, PP, D> TokenizerImpl<M, N, PT, PP, D>
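For a quick sense of how the new API slots into a token-by-token generation loop, here is a minimal sketch, assuming the same `data/roberta.json` file and ids used in the doctest above:

```rust
use tokenizers::Tokenizer;

fn main() -> tokenizers::Result<()> {
    let tokenizer = Tokenizer::from_file("data/roberta.json")?;
    let mut decode_stream = tokenizer.decode_stream(false);

    // In a real setup these ids would arrive one at a time from a model.
    for id in [713u32, 16, 41, 1246] {
        // `step` returns Ok(None) while the buffered ids cannot yet form a
        // valid chunk (e.g. an incomplete byte_fallback sequence), and
        // Ok(Some(chunk)) as soon as enough context is available.
        if let Some(chunk) = decode_stream.step(id)? {
            print!("{chunk}");
        }
    }
    println!(); // prints "This is an example"
    Ok(())
}
```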
63 changes: 63 additions & 0 deletions tokenizers/tests/documentation.rs
@@ -1,3 +1,7 @@
use std::collections::HashMap;
use std::iter::FromIterator;

use tokenizers::decoders::byte_fallback::ByteFallback;
use tokenizers::models::bpe::{BpeTrainerBuilder, BPE};
use tokenizers::normalizers::{Sequence, Strip, NFC};
use tokenizers::pre_tokenizers::byte_level::ByteLevel;
@@ -58,6 +62,65 @@ fn load_tokenizer() {
assert_eq!(decoded, example);
}

#[test]
fn streaming_tokenizer() {
let tokenizer = Tokenizer::from_file("data/roberta.json").unwrap();

let mut decode_stream = tokenizer.decode_stream(false);
assert_eq!(decode_stream.step(713).unwrap(), Some("This".to_string()));
assert_eq!(decode_stream.step(16).unwrap(), Some(" is".to_string()));
assert_eq!(decode_stream.step(41).unwrap(), Some(" an".to_string()));
assert_eq!(
decode_stream.step(1246).unwrap(),
Some(" example".to_string())
);

let tokenizer = Tokenizer::from_file("data/albert-base-v1-tokenizer.json").unwrap();
let encoded = tokenizer.encode("This is an example", false).unwrap();
assert_eq!(encoded.get_ids(), &[48, 25, 40, 823]);
let mut decode_stream = tokenizer.decode_stream(false);
// First token of a fresh stream: no leading space anymore
assert_eq!(decode_stream.step(25).unwrap(), Some("is".to_string()));
let mut decode_stream = tokenizer.decode_stream(false);
assert_eq!(decode_stream.step(48).unwrap(), Some("this".to_string()));
assert_eq!(decode_stream.step(25).unwrap(), Some(" is".to_string()));
assert_eq!(decode_stream.step(40).unwrap(), Some(" an".to_string()));
assert_eq!(
decode_stream.step(823).unwrap(),
Some(" example".to_string())
);

// Example where `step` returns None (a byte_fallback token that is not valid utf-8 on its own)
let vocab = HashMap::from_iter([
("<0x20>".to_string(), 0),
("<0xC3>".to_string(), 1),
("<0xA9>".to_string(), 2),
(" This".to_string(), 3),
]);
let merges = vec![];
let bpe = BPE::builder()
.vocab_and_merges(vocab, merges)
.byte_fallback(true)
.build()
.unwrap();
let tokenizer = TokenizerBuilder::new()
.with_model(bpe)
.with_normalizer(Some(Sequence::new(vec![
Strip::new(true, true).into(),
NFC.into(),
])))
.with_pre_tokenizer(Some(ByteLevel::default()))
.with_post_processor(Some(ByteLevel::default()))
.with_decoder(Some(ByteFallback::default()))
.build()
.unwrap();
let mut decode_stream = tokenizer.decode_stream(false);
assert_eq!(decode_stream.step(0).unwrap(), Some(" ".to_string()));
assert_eq!(decode_stream.step(1).unwrap(), None);
assert_eq!(decode_stream.step(2).unwrap(), Some("é".to_string()));
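// A lone 0xA9 is a utf-8 continuation byte, so no chunk can be produced yet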
assert_eq!(decode_stream.step(2).unwrap(), None);
}

#[test]
#[ignore]
fn quicktour_slow_train() -> tokenizers::Result<()> {
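As a follow-up to the streaming test above, a rough sketch contrasting per-token decoding with the stream, assuming the same `data/albert-base-v1-tokenizer.json` file is available:

```rust
use tokenizers::Tokenizer;

fn main() -> tokenizers::Result<()> {
    let tokenizer = Tokenizer::from_file("data/albert-base-v1-tokenizer.json")?;
    let ids = tokenizer
        .encode("This is an example", false)?
        .get_ids()
        .to_vec();

    // Decoding each id in isolation loses the spacing the decoder derives from
    // context (the test shows step(25) == "is" on a fresh stream but " is"
    // mid-sentence).
    let independent = ids
        .iter()
        .map(|id| tokenizer.decode(&[*id], false))
        .collect::<tokenizers::Result<Vec<_>>>()?
        .join("");

    // The stream keeps the surrounding ids, so spacing is preserved.
    let mut stream = tokenizer.decode_stream(false);
    let mut streamed = String::new();
    for id in ids {
        if let Some(chunk) = stream.step(id)? {
            streamed.push_str(&chunk);
        }
    }

    println!("independent: {independent:?}"); // words run together
    println!("streamed:    {streamed:?}");    // "this is an example"
    Ok(())
}
```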
