Adding an API for decode streaming. #1677

Merged
9 commits merged on Nov 15, 2024
186 changes: 184 additions & 2 deletions tokenizers/src/tokenizer/mod.rs
@@ -12,8 +12,7 @@
use std::{
collections::HashMap,
fs::{read_to_string, File},
io::prelude::*,
io::BufReader,
io::{prelude::*, BufReader},
ops::{Deref, DerefMut},
path::{Path, PathBuf},
};
@@ -906,6 +905,189 @@ where
Ok(tokens.join(" "))
}
}

/// Decode the given ids, back to a String
/// See [`DecodeStream`]
pub fn decode_stream(&self, skip_special_tokens: bool) -> DecodeStream<'_, M, N, PT, PP, D> {
DecodeStream::new(self, skip_special_tokens)
}
Comment on lines +911 to +913
Collaborator
I think I'd rather we explicitly create this DecodeStream with DecodeStream::new(tokenizer, ...)
without adding this to the tokenizer funcs!

Collaborator Author
As you wish; this follows the .iter() pattern in regular Rust, as it's more convenient given the lifetime bound of the DecodeStream object.

https://doc.rust-lang.org/src/alloc/collections/vec_deque/mod.rs.html#1204

It's really just sugar; I can happily remove it.

Collaborator
Got it. No, sounds good; I was thinking more about the upcoming Python API as well, but in Rust it makes sense for sure.
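
A minimal sketch of the call style this thread settles on, assuming the `data/roberta.json` fixture used in the doc tests below; note that `DecodeStream::new` is private in this diff, so only the sugared form is callable from outside the crate.

```rust
use tokenizers::Tokenizer;

fn main() -> tokenizers::Result<()> {
    let tokenizer = Tokenizer::from_file("data/roberta.json")?;

    // Sugared constructor added by this PR, mirroring `.iter()`-style methods:
    let mut decode_stream = tokenizer.decode_stream(false);
    assert_eq!(decode_stream.step(713)?, Some("This".to_string()));

    // The explicit `DecodeStream::new(&tokenizer, false)` suggested above would
    // require `new` to be public; in this diff it is crate-private.
    Ok(())
}
```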

}

/// DecodeStream will keep the state necessary to produce individual chunks of
/// strings given an input stream of token_ids.
///
/// This is necessary because decoding cannot, in general, be done token by token: a valid
/// string depends on the surrounding ids (typically, extra spaces are stripped).
///
/// Example:
///
/// ```
/// # #[cfg(not(target_os = "windows"))]
/// # {
/// use tokenizers::Tokenizer;
/// let tokenizer = Tokenizer::from_file("data/roberta.json").unwrap();
///
/// let mut decode_stream = tokenizer.decode_stream(false);
/// assert_eq!(decode_stream.step(713).unwrap(), Some("This".to_string()));
/// assert_eq!(decode_stream.step(16).unwrap(), Some(" is".to_string()));
/// assert_eq!(decode_stream.step(41).unwrap(), Some(" an".to_string()));
/// assert_eq!(
/// decode_stream.step(1246).unwrap(),
/// Some(" example".to_string())
/// );
/// # }
/// ```
///
/// Returning `None` means the given id is not enough to produce a chunk.
/// This typically happens with `byte_fallback` options where some tokens do
/// not represent valid utf-8, and only follow-up token_ids will help produce
/// a valid chunk.
/// ```
/// use tokenizers::{Tokenizer, TokenizerBuilder, models::bpe::BPE, decoders::byte_fallback::ByteFallback, pre_tokenizers::byte_level::ByteLevel, normalizers::unicode::NFC};
/// use std::collections::HashMap;
/// use std::iter::FromIterator;
///
/// let vocab = HashMap::from_iter([
/// ("<0x20>".to_string(), 0),
/// ("<0xC3>".to_string(), 1),
/// ("<0xA9>".to_string(), 2),
/// (" This".to_string(), 3),
/// ]);
/// let merges = vec![];
/// let bpe = BPE::builder()
/// .vocab_and_merges(vocab, merges)
/// .byte_fallback(true)
/// .build()
/// .unwrap();
/// let tokenizer = TokenizerBuilder::default()
/// .with_model(bpe)
/// .with_decoder(Some(ByteFallback::default()))
/// .with_normalizer(Some(NFC))
/// .with_pre_tokenizer(Some(ByteLevel::default()))
/// .with_post_processor(Some(ByteLevel::default()))
/// .build().unwrap();
///
/// let mut decode_stream = tokenizer.decode_stream(false);
/// // Single byte_fallback is valid utf-8
/// assert_eq!(decode_stream.step(0).unwrap(), Some(" ".to_string()));
/// // Invalid utf-8
/// assert_eq!(decode_stream.step(1).unwrap(), None);
/// // Valid utf-8 again, this corresponds to both tokens: [1, 2]
/// assert_eq!(decode_stream.step(2).unwrap(), Some("é".to_string()));
/// ```
///
/// To see how [`DecodeStream`] is necessary, let's show how using raw [`TokenizerImpl::decode`] would
/// fail.
///
/// ```
/// use tokenizers::{Tokenizer, TokenizerBuilder, models::bpe::BPE, pre_tokenizers::{byte_level::ByteLevel, metaspace::Metaspace}, normalizers::unicode::NFC};
/// use std::collections::HashMap;
/// use std::iter::FromIterator;
///
/// let vocab = HashMap::from_iter([
/// ("▁This".to_string(), 0),
/// ]);
/// let merges = vec![];
/// let bpe = BPE::builder()
/// .vocab_and_merges(vocab, merges)
/// .byte_fallback(true)
/// .build()
/// .unwrap();
/// let tokenizer = TokenizerBuilder::new()
/// .with_model(bpe)
/// .with_decoder(Some(Metaspace::default()))
/// .with_normalizer(Some(NFC))
/// .with_pre_tokenizer(Some(ByteLevel::default()))
/// .with_post_processor(Some(ByteLevel::default()))
/// .build()
/// .unwrap();
///
/// // The Metaspace decoder removes the extra initial space
/// assert_eq!(tokenizer.decode(&[0, 0], false).unwrap(), "This This");
/// // Decoding one token at a time would produce "ThisThis"
/// assert_eq!(tokenizer.decode(&[0], false).unwrap(), "This");
///
/// // Using a stream fixes it by keeping the necessary state.
/// let mut decode_stream = tokenizer.decode_stream(false);
/// assert_eq!(decode_stream.step(0).unwrap(), Some("This".to_string()));
/// assert_eq!(decode_stream.step(0).unwrap(), Some(" This".to_string()));
/// ```
pub struct DecodeStream<'tok, M, N, PT, PP, D> {
/// A reference to the tokenizer
tokenizer: &'tok TokenizerImpl<M, N, PT, PP, D>,
/// Regular decode option that is kept throughout.
skip_special_tokens: bool,
/// A temporary buffer of the necessary token_ids needed
/// to produce valid string chunks.
/// This typically contains 3 parts:
/// - read
/// - prefix
/// - rest
///
/// Read is the bit necessary to surround the prefix
/// so that decoding the whole ids produces a valid prefix.
/// Prefix is the previously produced string, kept around so it can be
/// trimmed off of the next valid chunk.
ids: Vec<u32>,
/// The previously returned chunk that needs to be discarded from the
/// decoding of the current ids to produce the next chunk
prefix: String,
/// The index within the ids corresponding to the prefix so we can drain
/// correctly
prefix_index: usize,
/// We need to keep 2 prefixes.
/// `prefix` is the second one, which was already emitted, and is used to discard
/// that part of the text from the decoding of all the ids.
/// `read` is the prefix kept only for the starting side effects of the prefix
/// (so the prefix itself decodes correctly).
read_index: usize,
}

#[derive(thiserror::Error, Debug)]
pub enum DecodeStreamError {
#[error("Invalid prefix encountered")]
InvalidPrefix,
}

impl<'tok, M, N, PT, PP, D> DecodeStream<'tok, M, N, PT, PP, D>
where
M: Model,
N: Normalizer,
PT: PreTokenizer,
PP: PostProcessor,
D: Decoder,
{
fn new(tokenizer: &'tok TokenizerImpl<M, N, PT, PP, D>, skip_special_tokens: bool) -> Self {
Self {
tokenizer,
ids: vec![],
skip_special_tokens,
prefix: "".to_string(),
prefix_index: 0,
read_index: 0,
}
}

ArthurZucker marked this conversation as resolved.
/// See [`DecodeStream`]
pub fn step(&mut self, id: u32) -> Result<Option<String>> {
self.ids.push(id);
let string = self
.tokenizer
.decode(self.ids.as_slice(), self.skip_special_tokens)?;
if string.len() > self.prefix.len() && !string.ends_with('�') {


@ArthurZucker @Narsil

I ran into streamed decoding issues as well and had the same solution in mind. However, I came to the conclusion that this solution has its own flaw: if you actually want to decode the '�' character because it is part of the completion, this code would assume it is an incomplete UTF-8 marker and not yield anything.

The only clean solution is to offer a Decoder::decode_u8 method. I could help out here, if desired.
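
Purely as an illustration of that suggestion, a hypothetical sketch; the trait name and signature here are invented and are not part of the crate.

```rust
// Hypothetical sketch of the `decode_u8` idea above; not part of `tokenizers`.
// Returning raw bytes would let callers do their own utf-8 buffering, so a
// genuine '�' in the completion is not mistaken for an incomplete sequence.
pub trait DecodeBytes {
    fn decode_u8(&self, tokens: &[String]) -> tokenizers::Result<Vec<u8>>;
}
```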

if !(string.starts_with(&self.prefix)) {
return Err(Box::new(DecodeStreamError::InvalidPrefix));
}
let new_text = &string[self.prefix.len()..].to_string();
let new_prefix_index = self.ids.len() - self.prefix_index;
self.ids = self.ids.drain(self.read_index..).collect();
Collaborator

nice, the state is bound to be quite small with this!

Collaborator Author
This is what we have in TGI, the overhead is indeed quite low. You're decoding twice as much (prefix + new text) and you have only a handful of extra tokens.

self.prefix = self.tokenizer.decode(&self.ids, self.skip_special_tokens)?;
self.read_index = self.prefix_index;
self.prefix_index = new_prefix_index;
Ok(Some(new_text.to_string()))
} else {
Ok(None)
Collaborator

Returning '�' might be more expected (at least it's not None, so people can still print it), or ''.

Collaborator Author

No, it's wrong to do so. Invalid utf-8 is perfectly normal and should not be returned before enough tokens are accumulated (see the accent example) to provide valid utf-8. If valid utf-8 follows invalid utf-8, both will be returned at the same time.

Collaborator Author

Producing "" is also wrong, since the token really didn't produce anything, not the empty string.

}
}
}

impl<M, N, PT, PP, D> TokenizerImpl<M, N, PT, PP, D>
63 changes: 63 additions & 0 deletions tokenizers/tests/documentation.rs
@@ -1,3 +1,7 @@
use std::collections::HashMap;
use std::iter::FromIterator;

use tokenizers::decoders::byte_fallback::ByteFallback;
use tokenizers::models::bpe::{BpeTrainerBuilder, BPE};
use tokenizers::normalizers::{Sequence, Strip, NFC};
use tokenizers::pre_tokenizers::byte_level::ByteLevel;
@@ -58,6 +62,65 @@ fn load_tokenizer() {
assert_eq!(decoded, example);
}

#[test]
fn streaming_tokenizer() {
let tokenizer = Tokenizer::from_file("data/roberta.json").unwrap();

let mut decode_stream = tokenizer.decode_stream(false);
assert_eq!(decode_stream.step(713).unwrap(), Some("This".to_string()));
assert_eq!(decode_stream.step(16).unwrap(), Some(" is".to_string()));
assert_eq!(decode_stream.step(41).unwrap(), Some(" an".to_string()));
assert_eq!(
decode_stream.step(1246).unwrap(),
Some(" example".to_string())
);

let tokenizer = Tokenizer::from_file("data/albert-base-v1-tokenizer.json").unwrap();
let encoded = tokenizer.encode("This is an example", false).unwrap();
assert_eq!(encoded.get_ids(), &[48, 25, 40, 823]);
let mut decode_stream = tokenizer.decode_stream(false);
// No space anymore
assert_eq!(decode_stream.step(25).unwrap(), Some("is".to_string()));
let mut decode_stream = tokenizer.decode_stream(false);
assert_eq!(decode_stream.step(48).unwrap(), Some("this".to_string()));
assert_eq!(decode_stream.step(25).unwrap(), Some(" is".to_string()));
assert_eq!(decode_stream.step(40).unwrap(), Some(" an".to_string()));
assert_eq!(
decode_stream.step(823).unwrap(),
Some(" example".to_string())
);

// Example where step returns None (byte fallback)
let vocab = HashMap::from_iter([
("<0x20>".to_string(), 0),
("<0xC3>".to_string(), 1),
("<0xA9>".to_string(), 2),
(" This".to_string(), 3),
]);
let merges = vec![];
let bpe = BPE::builder()
.vocab_and_merges(vocab, merges)
.byte_fallback(true)
.build()
.unwrap();
let tokenizer = TokenizerBuilder::new()
.with_model(bpe)
.with_normalizer(Some(Sequence::new(vec![
Strip::new(true, true).into(),
NFC.into(),
])))
.with_pre_tokenizer(Some(ByteLevel::default()))
.with_post_processor(Some(ByteLevel::default()))
.with_decoder(Some(ByteFallback::default()))
.build()
.unwrap();
let mut decode_stream = tokenizer.decode_stream(false);
assert_eq!(decode_stream.step(0).unwrap(), Some(" ".to_string()));
assert_eq!(decode_stream.step(1).unwrap(), None);
assert_eq!(decode_stream.step(2).unwrap(), Some("é".to_string()));
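// Feeding <0xA9> again is not valid utf-8 on its own, so no chunk is produced yet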
assert_eq!(decode_stream.step(2).unwrap(), None);
}

#[test]
#[ignore]
fn quicktour_slow_train() -> tokenizers::Result<()> {