diff --git a/tokenizers/src/pre_tokenizers/byte_level.rs b/tokenizers/src/pre_tokenizers/byte_level.rs index 2d3845b55..13d3fbef4 100644 --- a/tokenizers/src/pre_tokenizers/byte_level.rs +++ b/tokenizers/src/pre_tokenizers/byte_level.rs @@ -41,6 +41,14 @@ lazy_static! { r"'s|'t|'re|'ve|'m|'ll|'d| ?\p{L}+| ?\p{N}+| ?[^\s\p{L}\p{N}]+|\s+(?!\S)|\s+" ) .unwrap(); + static ref RE_VEC: Vec = { + let pattern = r"'s|'t|'re|'ve|'m|'ll|'d| ?\p{L}+| ?\p{N}+| ?[^\s\p{L}\p{N}]+|\s+(?!\S)|\s+"; + let mut vec = Vec::with_capacity(MAX_NUM_THREADS); + for _ in 0..MAX_NUM_THREADS { + vec.push(SysRegex::new(pattern).unwrap()); + } + vec + }; static ref BYTES_CHAR: HashMap = bytes_char(); static ref CHAR_BYTES: HashMap = bytes_char().into_iter().map(|(c, b)| (b, c)).collect(); @@ -111,12 +119,31 @@ impl ByteLevel { } } +use std::num::NonZeroU64; +use std::thread; + +pub struct FakeThreadId(NonZeroU64); + +fn hash_current_thread() -> usize { + // It's easier to use unsafe than to use nightly. Rust has this nice u64 thread id counter + // that works great for our use case of avoiding collisions in our array. Unfortunately, + // it's private. However, there are only so many ways you can layout a u64, so just transmute + // https://github.com/rust-lang/rust/issues/67939 + const _: [u8; 8] = [0; std::mem::size_of::()]; + const _: [u8; 8] = [0; std::mem::size_of::()]; + let x = + unsafe { std::mem::transmute::(thread::current().id()).0 }; + u64::from(x) as usize - 1 +} + +const MAX_NUM_THREADS: usize = 128; + /// As a `PreTokenizer`, `ByteLevel` is in charge of transforming all the unicode characters into /// their byte-level counterpart. It also splits the input according to the configured regex. // TODO: Give the ability to modify this regex impl PreTokenizer for ByteLevel { fn pre_tokenize(&self, pretokenized: &mut PreTokenizedString) -> Result<()> { - let re_ref: &SysRegex = &RE; + let re_ref: &SysRegex = &RE_VEC[hash_current_thread() % MAX_NUM_THREADS]; // TODO use the thread thing here as well! pretokenized.split(|_, mut normalized| { if self.add_prefix_space && !normalized.get().starts_with(' ') { normalized.prepend(" "); diff --git a/tokenizers/src/tokenizer/added_vocabulary.rs b/tokenizers/src/tokenizer/added_vocabulary.rs index a0c2f4542..7db205189 100644 --- a/tokenizers/src/tokenizer/added_vocabulary.rs +++ b/tokenizers/src/tokenizer/added_vocabulary.rs @@ -92,6 +92,25 @@ impl std::hash::Hash for AddedToken { } } +use std::num::NonZeroU64; +use std::thread; + +pub struct FakeThreadId(NonZeroU64); + +fn hash_current_thread() -> usize { + // It's easier to use unsafe than to use nightly. Rust has this nice u64 thread id counter + // that works great for our use case of avoiding collisions in our array. Unfortunately, + // it's private. However, there are only so many ways you can layout a u64, so just transmute + // https://github.com/rust-lang/rust/issues/67939 + const _: [u8; 8] = [0; std::mem::size_of::()]; + const _: [u8; 8] = [0; std::mem::size_of::()]; + let x = + unsafe { std::mem::transmute::(thread::current().id()).0 }; + u64::from(x) as usize +} + +const MAX_NUM_THREADS: usize = 128; + type MatchingSet = (AhoCorasick, Vec); lazy_static! { @@ -156,11 +175,16 @@ pub struct AddedVocabulary { /// us remove them easily with an O(1) complexity. special_tokens_set: HashSet, - /// A RegexSet containing all the non-normalized patterns used to split on AddedTokens + //// A RegexSet containing all the non-normalized patterns used to split on AddedTokens split_trie: MatchingSet, /// A RegexSet containing all the normalized patterns used to split on AddedTokens split_normalized_trie: MatchingSet, + // A RegexSet containing all the non-normalized patterns used to split on AddedTokens + split_trie_vec: Vec, + /// A RegexSet containing all the normalized patterns used to split on AddedTokens + split_normalized_trie_vec: Vec, + /// Whether or not special tokens should be splitted when encoding. This is equivalent to ignoring them encode_special_tokens: bool, } @@ -181,8 +205,10 @@ impl AddedVocabulary { added_tokens: vec![], special_tokens: vec![], special_tokens_set: HashSet::new(), - split_trie: (trie, vec![]), - split_normalized_trie: (normalized_trie, vec![]), + split_trie: (trie.clone(), vec![]), + split_normalized_trie: (normalized_trie.clone(), vec![]), + split_trie_vec: vec![(trie, vec![]); MAX_NUM_THREADS], + split_normalized_trie_vec: vec![(normalized_trie, vec![]); MAX_NUM_THREADS], encode_special_tokens: false, } } @@ -345,6 +371,7 @@ impl AddedVocabulary { .build(tokens.iter().map(|token| &token.content)) .expect("Failed to build tried when refreshing tokens"); self.split_trie = (trie, ids); + self.split_trie_vec = vec![self.split_trie.clone(); MAX_NUM_THREADS]; let (ntokens, nids): (Vec<&AddedToken>, Vec) = normalized.into_iter().unzip(); let patterns: Vec<_> = ntokens @@ -362,6 +389,7 @@ impl AddedVocabulary { .build(patterns.iter().map(|content| content.get())) .expect("Failed to build tried when refreshing tokens (normalized)"); self.split_normalized_trie = (normalized_trie, nids); + self.split_normalized_trie_vec = vec![self.split_normalized_trie.clone(); MAX_NUM_THREADS]; } /// Find any AddedToken in the given sentence, using the provided MatchingSet. @@ -425,6 +453,26 @@ impl AddedVocabulary { splits } + + fn fast_split_with_indices( + &self, + sentence: NormalizedString, + split_re: &MatchingSet, + ) -> Vec<(NormalizedString, Option>)> { + self.find_matches(sentence.get(), split_re) + .into_iter() + .map(|(id, byte_offsets)| { + let slice = sentence + .slice(Range::Normalized(byte_offsets.0..byte_offsets.1)) + .expect("AddedVocabulary bad split"); + if let Some(id) = id { + (slice, Some(vec![Token::new(id, String::new(), (0, 0))])) + } else { + (slice, None) + } + }) + .collect() + } /// Split the input sentence to extract anything we found from the `MatchingSet`, as well as /// the list of corresponding IDs /// The list of IDs have the exact same number of elements than the Iterator. @@ -465,7 +513,12 @@ impl AddedVocabulary { // 1. We extract all the non-normalized tokens from the non-normalized string pretokenized - .split(|_, sequence| Ok(self.split_with_indices(sequence, &self.split_trie))) + .split(|_, sequence| { + Ok(self.fast_split_with_indices( + sequence, + &self.split_trie_vec[hash_current_thread() % MAX_NUM_THREADS], + )) + }) .expect("AddedVocabulary bad split"); // normalized = False @@ -484,7 +537,10 @@ impl AddedVocabulary { pretokenized .split(|_, mut sequence| { normalizer.map(|n| n.normalize(&mut sequence)); - Ok(self.split_with_indices(sequence, &self.split_normalized_trie)) + Ok(self.fast_split_with_indices( + sequence, + &self.split_normalized_trie_vec[hash_current_thread() % MAX_NUM_THREADS], + )) }) .expect("AddedVocabulary bad split"); diff --git a/tokenizers/src/tokenizer/mod.rs b/tokenizers/src/tokenizer/mod.rs index 1c2ad6e0b..4bde4d39d 100644 --- a/tokenizers/src/tokenizer/mod.rs +++ b/tokenizers/src/tokenizer/mod.rs @@ -881,7 +881,7 @@ where ) -> Result { let mut pretokenized: PreTokenizedString = pretokenized.into(); pretokenized.tokenize(|normalized| self.model.tokenize(normalized.get()))?; - pretokenized.into_encoding(word_idx, type_id, offsets_type) + pretokenized.fast_into_encoding() } } diff --git a/tokenizers/src/tokenizer/pre_tokenizer.rs b/tokenizers/src/tokenizer/pre_tokenizer.rs index 9667c240a..be919fd62 100644 --- a/tokenizers/src/tokenizer/pre_tokenizer.rs +++ b/tokenizers/src/tokenizer/pre_tokenizer.rs @@ -186,6 +186,25 @@ impl PreTokenizedString { } } + pub fn fast_into_encoding(self) -> Result { + if self.splits.is_empty() { + Ok(Encoding::default()) + } else if !self.splits.iter().all(|split| split.tokens.is_some()) { + Err("Split has not been tokenized.".into()) + } else { + let tokens = self + .splits + .into_iter() + .flat_map(|split| { + split.tokens.unwrap().into_iter().map(|token| { + // Replace this with the actual fields you need for the Encoding type + (token.id, String::new(), (0, 0), None, 0) + }) + }) + .collect(); + Ok(tokens) + } + } /// Returns a list of splits, each of them being a slice of the normalized /// string, the associated offsets either in original or normalized /// referential, as well as the potention tokens