Adding an API for decode streaming.
Narsil committed Nov 6, 2024
1 parent 1740bff commit c97389b
Showing 2 changed files with 70 additions and 2 deletions.
56 changes: 54 additions & 2 deletions tokenizers/src/tokenizer/mod.rs
@@ -12,8 +12,7 @@
use std::{
collections::HashMap,
fs::{read_to_string, File},
-    io::prelude::*,
-    io::BufReader,
+    io::{prelude::*, BufReader},
ops::{Deref, DerefMut},
path::{Path, PathBuf},
};
@@ -906,6 +905,59 @@ where
Ok(tokens.join(" "))
}
}

/// Create a [`DecodeStream`] that decodes ids incrementally, emitting text chunks as soon as they are unambiguous
pub fn decode_stream<'tok>(
&'tok self,
skip_special_tokens: bool,
) -> DecodeStream<'tok, M, N, PT, PP, D> {
DecodeStream::new(self, skip_special_tokens)
}
}

/// Streaming decoder: feed ids one at a time with [`DecodeStream::step`] and
/// receive text chunks as soon as they can be decoded unambiguously.
pub struct DecodeStream<'tok, M, N, PT, PP, D> {
    tokenizer: &'tok TokenizerImpl<M, N, PT, PP, D>,
    /// Ids retained from previous steps, re-decoded on every call to `step`.
    ids: Vec<u32>,
    /// Index into `ids` separating already-emitted ids from pending ones.
    prefix_index: usize,
    /// Decoded text already returned to the caller for `ids[..prefix_index]`.
    prefix: String,
    skip_special_tokens: bool,
}

impl<'tok, M, N, PT, PP, D> DecodeStream<'tok, M, N, PT, PP, D>
where
M: Model,
N: Normalizer,
PT: PreTokenizer,
PP: PostProcessor,
D: Decoder,
{
fn new(tokenizer: &'tok TokenizerImpl<M, N, PT, PP, D>, skip_special_tokens: bool) -> Self {
Self {
tokenizer,
ids: vec![],
skip_special_tokens,
prefix: "".to_string(),
prefix_index: 0,
}
}

pub fn step(&mut self, id: u32) -> Result<Option<String>> {
    self.ids.push(id);
    let string = self
        .tokenizer
        .decode(self.ids.as_slice(), self.skip_special_tokens)?;
    // Only emit once the decoded text has grown past the already-emitted
    // prefix and no longer ends in an incomplete UTF-8 sequence (`�`).
    if string.len() > self.prefix.len() && !string.ends_with('�') {
        let new_text = &string[self.prefix.len()..];
        self.prefix = new_text.to_string();
        // Keep only the ids backing the newly emitted text, so the
        // re-decoded window stays small.
        let new_prefix_index = self.ids.len() - self.prefix_index;
        self.ids = self.ids.drain(self.prefix_index..).collect();
        self.prefix_index = new_prefix_index;
        Ok(Some(new_text.to_string()))
    } else {
        Ok(None)
    }
}
}

impl<M, N, PT, PP, D> TokenizerImpl<M, N, PT, PP, D>
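The test added below exercises the happy path. As a minimal usage sketch (the helper name `stream_decode_example` is hypothetical, and in practice the loop would be fed by a model's sampling step rather than a fixed id list; the file and ids are the same ones used in the test), the stream is driven one id at a time and prints text whenever a chunk can be emitted:

use tokenizers::Tokenizer;

fn stream_decode_example() -> tokenizers::Result<()> {
    let tokenizer = Tokenizer::from_file("data/roberta.json")?;
    let mut decode_stream = tokenizer.decode_stream(false);
    // Feed ids one at a time; `step` returns `Ok(None)` until the accumulated
    // ids decode to a complete, unambiguous chunk of text.
    for id in [713u32, 16, 41, 1246] {
        if let Some(chunk) = decode_stream.step(id)? {
            print!("{chunk}"); // "This", " is", " an", " example"
        }
    }
    Ok(())
}
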
16 changes: 16 additions & 0 deletions tokenizers/tests/documentation.rs
@@ -58,6 +58,22 @@ fn load_tokenizer() {
assert_eq!(decoded, example);
}

#[test]
fn streaming_tokenizer() {
let tokenizer = Tokenizer::from_file("data/roberta.json").unwrap();

let mut decode_stream = tokenizer.decode_stream(false);
assert_eq!(decode_stream.step(713).unwrap(), Some("This".to_string()));
assert_eq!(decode_stream.step(16).unwrap(), Some(" is".to_string()));
assert_eq!(decode_stream.step(41).unwrap(), Some(" an".to_string()));
assert_eq!(
decode_stream.step(1246).unwrap(),
Some(" example".to_string())
);

// TODO: add a byte-fallback example that exercises the `None` case
}

#[test]
#[ignore]
fn quicktour_slow_train() -> tokenizers::Result<()> {
