-
Notifications
You must be signed in to change notification settings - Fork 816
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Adding an API for decode streaming. #1677
Changes from all commits
c97389b
c3578d4
bdcb2b9
af7d82e
5a5406e
18b999c
c32a2c2
a326447
218fd3b
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -12,8 +12,7 @@ | |
use std::{ | ||
collections::HashMap, | ||
fs::{read_to_string, File}, | ||
io::prelude::*, | ||
io::BufReader, | ||
io::{prelude::*, BufReader}, | ||
ops::{Deref, DerefMut}, | ||
path::{Path, PathBuf}, | ||
}; | ||
|
@@ -906,6 +905,189 @@ where | |
Ok(tokens.join(" ")) | ||
} | ||
} | ||
|
||
/// Decode the given ids, back to a String | ||
/// See [`DecodeStream`] | ||
pub fn decode_stream(&self, skip_special_tokens: bool) -> DecodeStream<'_, M, N, PT, PP, D> { | ||
DecodeStream::new(self, skip_special_tokens) | ||
} | ||
} | ||
|
||
/// DecodeStream will keep the state necessary to produce individual chunks of | ||
/// strings given an input stream of token_ids. | ||
/// | ||
/// This is necessary because decoding in general cannot achieve that since strings | ||
/// depend on surrounding ids to provide a valid string. Typically stripping extra spaces | ||
/// | ||
/// Example: | ||
/// | ||
/// ``` | ||
/// # #[cfg(not(target_os = "windows"))] | ||
/// # { | ||
/// use tokenizers::Tokenizer; | ||
/// let tokenizer = Tokenizer::from_file("data/roberta.json").unwrap(); | ||
/// | ||
/// let mut decode_stream = tokenizer.decode_stream(false); | ||
/// assert_eq!(decode_stream.step(713).unwrap(), Some("This".to_string())); | ||
/// assert_eq!(decode_stream.step(16).unwrap(), Some(" is".to_string())); | ||
/// assert_eq!(decode_stream.step(41).unwrap(), Some(" an".to_string())); | ||
/// assert_eq!( | ||
/// decode_stream.step(1246).unwrap(), | ||
/// Some(" example".to_string()) | ||
/// ); | ||
/// # } | ||
/// ``` | ||
/// | ||
/// Returning `None` means the given id is not enough to produce a chunk. | ||
/// This typically happens with `byte_fallback` options where some tokens do | ||
/// not represent valid utf-8, and only follow-up token_ids will help produce | ||
/// a valid chunk. | ||
/// ``` | ||
/// use tokenizers::{Tokenizer, TokenizerBuilder, models::bpe::BPE, decoders::byte_fallback::ByteFallback, pre_tokenizers::byte_level::ByteLevel, normalizers::unicode::NFC}; | ||
/// use std::collections::HashMap; | ||
/// use std::iter::FromIterator; | ||
/// | ||
/// let vocab = HashMap::from_iter([ | ||
/// ("<0x20>".to_string(), 0), | ||
/// ("<0xC3>".to_string(), 1), | ||
/// ("<0xA9>".to_string(), 2), | ||
/// (" This".to_string(), 3), | ||
/// ]); | ||
/// let merges = vec![]; | ||
/// let bpe = BPE::builder() | ||
/// .vocab_and_merges(vocab, merges) | ||
/// .byte_fallback(true) | ||
/// .build() | ||
/// .unwrap(); | ||
/// let tokenizer = TokenizerBuilder::default() | ||
/// .with_model(bpe) | ||
/// .with_decoder(Some(ByteFallback::default())) | ||
/// .with_normalizer(Some(NFC)) | ||
/// .with_pre_tokenizer(Some(ByteLevel::default())) | ||
/// .with_post_processor(Some(ByteLevel::default())) | ||
/// .build().unwrap(); | ||
/// | ||
/// let mut decode_stream = tokenizer.decode_stream(false); | ||
/// // Single byte_fallback is valid utf-8 | ||
/// assert_eq!(decode_stream.step(0).unwrap(), Some(" ".to_string())); | ||
/// // Invalid utf-8 | ||
/// assert_eq!(decode_stream.step(1).unwrap(), None); | ||
/// // Valid utf-8 again, this corresponds to both tokens: [1, 2] | ||
/// assert_eq!(decode_stream.step(2).unwrap(), Some("é".to_string())); | ||
/// ``` | ||
/// | ||
/// To see how [`DecodeStream`] is necessary, let's show how using raw [`TokenizerImpl::decode`] would | ||
/// fail. | ||
/// | ||
/// ``` | ||
/// use tokenizers::{Tokenizer, TokenizerBuilder, models::bpe::BPE, pre_tokenizers::{byte_level::ByteLevel, metaspace::Metaspace}, normalizers::unicode::NFC}; | ||
/// use std::collections::HashMap; | ||
/// use std::iter::FromIterator; | ||
/// | ||
/// let vocab = HashMap::from_iter([ | ||
/// ("▁This".to_string(), 0), | ||
/// ]); | ||
/// let merges = vec![]; | ||
/// let bpe = BPE::builder() | ||
/// .vocab_and_merges(vocab, merges) | ||
/// .byte_fallback(true) | ||
/// .build() | ||
/// .unwrap(); | ||
/// let tokenizer = TokenizerBuilder::new() | ||
/// .with_model(bpe) | ||
/// .with_decoder(Some(Metaspace::default())) | ||
/// .with_normalizer(Some(NFC)) | ||
/// .with_pre_tokenizer(Some(ByteLevel::default())) | ||
/// .with_post_processor(Some(ByteLevel::default())) | ||
/// .build() | ||
/// .unwrap(); | ||
/// | ||
/// // Strip decoder removes the extra initial space | ||
/// assert_eq!(tokenizer.decode(&[0, 0], false).unwrap(), "This This"); | ||
/// // Decoding one token at a time would produce "ThisThis" | ||
/// assert_eq!(tokenizer.decode(&[0], false).unwrap(), "This"); | ||
/// | ||
/// // Using a stream fixes it by keeping the necessary state. | ||
/// let mut decode_stream = tokenizer.decode_stream(false); | ||
/// assert_eq!(decode_stream.step(0).unwrap(), Some("This".to_string())); | ||
/// assert_eq!(decode_stream.step(0).unwrap(), Some(" This".to_string())); | ||
/// ``` | ||
pub struct DecodeStream<'tok, M, N, PT, PP, D> { | ||
/// A reference to the tokenizer | ||
tokenizer: &'tok TokenizerImpl<M, N, PT, PP, D>, | ||
/// Regular decode option that is kept throughout. | ||
skip_special_tokens: bool, | ||
/// A temporary buffer of the necessary token_ids needed | ||
/// to produce valid string chunks. | ||
/// This typically contains 3 parts: | ||
/// - read | ||
/// - prefix | ||
/// - rest | ||
/// | ||
/// Read is the bit necessary to surround the prefix | ||
/// so decoding the whole ids produces a valid prefix. | ||
/// Prefix is the previously produced string, kept around to trim off of | ||
/// the next valid chunk | ||
ids: Vec<u32>, | ||
/// The previously returned chunk that needs to be discarded from the | ||
/// decoding of the current ids to produce the next chunk | ||
prefix: String, | ||
/// The index within the ids corresponding to the prefix so we can drain | ||
/// correctly | ||
prefix_index: usize, | ||
/// We need to keep 2 prefixes. | ||
/// Prefix is the second one that was already emitted to discard the part | ||
/// of the text of all the ids | ||
/// read is the prefix kept only for starting side effects of the prefix | ||
read_index: usize, | ||
} | ||
|
||
#[derive(thiserror::Error, Debug)] | ||
pub enum DecodeStreamError { | ||
#[error("Invalid prefix encountered")] | ||
InvalidPrefix, | ||
} | ||
|
||
impl<'tok, M, N, PT, PP, D> DecodeStream<'tok, M, N, PT, PP, D> | ||
where | ||
M: Model, | ||
N: Normalizer, | ||
PT: PreTokenizer, | ||
PP: PostProcessor, | ||
D: Decoder, | ||
{ | ||
fn new(tokenizer: &'tok TokenizerImpl<M, N, PT, PP, D>, skip_special_tokens: bool) -> Self { | ||
Self { | ||
tokenizer, | ||
ids: vec![], | ||
skip_special_tokens, | ||
prefix: "".to_string(), | ||
prefix_index: 0, | ||
read_index: 0, | ||
} | ||
} | ||
|
||
ArthurZucker marked this conversation as resolved.
Show resolved
Hide resolved
|
||
/// See [`DecodeStream`] | ||
pub fn step(&mut self, id: u32) -> Result<Option<String>> { | ||
self.ids.push(id); | ||
let string = self | ||
.tokenizer | ||
.decode(self.ids.as_slice(), self.skip_special_tokens)?; | ||
if string.len() > self.prefix.len() && !string.ends_with('�') { | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I ran into streamed decoding issues as well and had the same solution in mind. However, I came to the conclusion that this solution has its own flaw: If you want to actually decode the The only clean solution is to offer a |
||
if !(string.starts_with(&self.prefix)) { | ||
return Err(Box::new(DecodeStreamError::InvalidPrefix)); | ||
} | ||
let new_text = &string[self.prefix.len()..].to_string(); | ||
let new_prefix_index = self.ids.len() - self.prefix_index; | ||
self.ids = self.ids.drain(self.read_index..).collect(); | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. nice, the state is bound to be quite small with this! There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. This is what we have in TGI, the overhead is indeed quite low. You're decoding twice as much (prefix + new text) and you have only a handful of extra tokens. |
||
self.prefix = self.tokenizer.decode(&self.ids, self.skip_special_tokens)?; | ||
self.read_index = self.prefix_index; | ||
self.prefix_index = new_prefix_index; | ||
Ok(Some(new_text.to_string())) | ||
} else { | ||
Ok(None) | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. returning There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. No it's wrong to do so. Invalid utf-8 is perfectly normal and should not be returned before enough token are accumulated (see the accent example) to provide valid utf-8. If valid utf-8 follows invalid utf-8 then both will be returned at the same time There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Producing "" is also wrong, since the token really didn't produce anything, not the empty string. |
||
} | ||
} | ||
} | ||
|
||
impl<M, N, PT, PP, D> TokenizerImpl<M, N, PT, PP, D> | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I think I'd rather we explicitly create this
DecodeStream
withDecodeStream::new(tokenizer, ...)
without adding this to the tokenizer funcs!
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
As you wish, this follows the
.iter()
pattern in regular rust as it's more convient given the lifetime bound of theDecodeStream
object.https://doc.rust-lang.org/src/alloc/collections/vec_deque/mod.rs.html#1204
It's really just sugard, I can happily remove.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
got it. No sounds good, was more thinking about the coming python api as well but in rust makes sense for sure