Fast regex #1605

Draft · wants to merge 2 commits into base: main
29 changes: 28 additions & 1 deletion tokenizers/src/pre_tokenizers/byte_level.rs
@@ -41,6 +41,14 @@ lazy_static! {
r"'s|'t|'re|'ve|'m|'ll|'d| ?\p{L}+| ?\p{N}+| ?[^\s\p{L}\p{N}]+|\s+(?!\S)|\s+"
)
.unwrap();
static ref RE_VEC: Vec<SysRegex> = {
let pattern = r"'s|'t|'re|'ve|'m|'ll|'d| ?\p{L}+| ?\p{N}+| ?[^\s\p{L}\p{N}]+|\s+(?!\S)|\s+";
let mut vec = Vec::with_capacity(MAX_NUM_THREADS);
for _ in 0..MAX_NUM_THREADS {
vec.push(SysRegex::new(pattern).unwrap());
}
vec
};
static ref BYTES_CHAR: HashMap<u8, char> = bytes_char();
static ref CHAR_BYTES: HashMap<char, u8> =
bytes_char().into_iter().map(|(c, b)| (b, c)).collect();
@@ -111,12 +119,31 @@ impl ByteLevel {
}
}

use std::num::NonZeroU64;
use std::thread;

pub struct FakeThreadId(NonZeroU64);

fn hash_current_thread() -> usize {
// It's easier to use unsafe than to use nightly. Rust has this nice u64 thread id counter
// that works great for our use case of avoiding collisions in our array. Unfortunately,
// it's private. However, there are only so many ways you can lay out a u64, so just transmute
// https://github.com/rust-lang/rust/issues/67939
const _: [u8; 8] = [0; std::mem::size_of::<thread::ThreadId>()];
const _: [u8; 8] = [0; std::mem::size_of::<FakeThreadId>()];
let x =
unsafe { std::mem::transmute::<thread::ThreadId, FakeThreadId>(thread::current().id()).0 };
u64::from(x) as usize - 1
}

const MAX_NUM_THREADS: usize = 128;

/// As a `PreTokenizer`, `ByteLevel` is in charge of transforming all the unicode characters into
/// their byte-level counterpart. It also splits the input according to the configured regex.
// TODO: Give the ability to modify this regex
impl PreTokenizer for ByteLevel {
fn pre_tokenize(&self, pretokenized: &mut PreTokenizedString) -> Result<()> {
let re_ref: &SysRegex = &RE;
let re_ref: &SysRegex = &RE_VEC[hash_current_thread() % MAX_NUM_THREADS]; // TODO use the thread thing here as well!
pretokenized.split(|_, mut normalized| {
if self.add_prefix_space && !normalized.get().starts_with(' ') {
normalized.prepend(" ");
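
Note on the byte_level.rs change: instead of every thread contending on the single lazily initialized RE, the pre-tokenizer now picks one of MAX_NUM_THREADS pre-built copies of the regex, indexed by a hash of the current thread id. The transmute leans on ThreadId being a NonZeroU64 counter internally (see rust-lang/rust#67939); the two copies of hash_current_thread in this PR differ slightly (byte_level.rs subtracts 1, added_vocabulary.rs does not), but both are reduced modulo MAX_NUM_THREADS at the call sites. A safe, collision-tolerant alternative would be to hash the ThreadId through its Hash impl; a minimal sketch of that idea, not part of this PR:

use std::collections::hash_map::DefaultHasher;
use std::hash::{Hash, Hasher};
use std::thread;

const MAX_NUM_THREADS: usize = 128;

// Map the current thread to a slot in [0, MAX_NUM_THREADS) without unsafe.
// Distinct threads may share a slot, which is harmless here because every
// slot holds an identical compiled regex.
fn thread_slot() -> usize {
    let mut hasher = DefaultHasher::new();
    thread::current().id().hash(&mut hasher);
    (hasher.finish() as usize) % MAX_NUM_THREADS
}
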
66 changes: 61 additions & 5 deletions tokenizers/src/tokenizer/added_vocabulary.rs
@@ -92,6 +92,25 @@
}
}

use std::num::NonZeroU64;
use std::thread;

pub struct FakeThreadId(NonZeroU64);

fn hash_current_thread() -> usize {
// It's easier to use unsafe than to use nightly. Rust has this nice u64 thread id counter
// that works great for our use case of avoiding collisions in our array. Unfortunately,
// it's private. However, there are only so many ways you can lay out a u64, so just transmute
// https://github.com/rust-lang/rust/issues/67939
const _: [u8; 8] = [0; std::mem::size_of::<thread::ThreadId>()];
const _: [u8; 8] = [0; std::mem::size_of::<FakeThreadId>()];
let x =
unsafe { std::mem::transmute::<thread::ThreadId, FakeThreadId>(thread::current().id()).0 };
u64::from(x) as usize
}

const MAX_NUM_THREADS: usize = 128;

type MatchingSet = (AhoCorasick, Vec<u32>);

lazy_static! {
@@ -156,11 +175,16 @@
/// us remove them easily with an O(1) complexity.
special_tokens_set: HashSet<String>,

/// A RegexSet containing all the non-normalized patterns used to split on AddedTokens
//// A RegexSet containing all the non-normalized patterns used to split on AddedTokens
split_trie: MatchingSet,
/// A RegexSet containing all the normalized patterns used to split on AddedTokens
split_normalized_trie: MatchingSet,

// A RegexSet containing all the non-normalized patterns used to split on AddedTokens
split_trie_vec: Vec<MatchingSet>,
/// A RegexSet containing all the normalized patterns used to split on AddedTokens
split_normalized_trie_vec: Vec<MatchingSet>,

/// Whether or not special tokens should be split when encoding. This is equivalent to ignoring them
encode_special_tokens: bool,
}
@@ -181,8 +205,10 @@
added_tokens: vec![],
special_tokens: vec![],
special_tokens_set: HashSet::new(),
split_trie: (trie, vec![]),
split_normalized_trie: (normalized_trie, vec![]),
split_trie: (trie.clone(), vec![]),
split_normalized_trie: (normalized_trie.clone(), vec![]),
split_trie_vec: vec![(trie, vec![]); MAX_NUM_THREADS],
split_normalized_trie_vec: vec![(normalized_trie, vec![]); MAX_NUM_THREADS],
encode_special_tokens: false,
}
}
@@ -345,6 +371,7 @@
.build(tokens.iter().map(|token| &token.content))
.expect("Failed to build trie when refreshing tokens");
self.split_trie = (trie, ids);
self.split_trie_vec = vec![self.split_trie.clone(); MAX_NUM_THREADS];

let (ntokens, nids): (Vec<&AddedToken>, Vec<u32>) = normalized.into_iter().unzip();
let patterns: Vec<_> = ntokens
@@ -362,6 +389,7 @@
.build(patterns.iter().map(|content| content.get()))
.expect("Failed to build trie when refreshing tokens (normalized)");
self.split_normalized_trie = (normalized_trie, nids);
self.split_normalized_trie_vec = vec![self.split_normalized_trie.clone(); MAX_NUM_THREADS];
}

/// Find any AddedToken in the given sentence, using the provided MatchingSet.
@@ -425,10 +453,30 @@
splits
}


fn fast_split_with_indices(
&self,
sentence: NormalizedString,
split_re: &MatchingSet,
) -> Vec<(NormalizedString, Option<Vec<Token>>)> {
self.find_matches(sentence.get(), split_re)
.into_iter()
.map(|(id, byte_offsets)| {
let slice = sentence
.slice(Range::Normalized(byte_offsets.0..byte_offsets.1))
.expect("AddedVocabulary bad split");
if let Some(id) = id {
(slice, Some(vec![Token::new(id, String::new(), (0, 0))]))
} else {
(slice, None)
}
})
.collect()
}
/// Split the input sentence to extract anything we found from the `MatchingSet`, as well as
/// the list of corresponding IDs
/// The list of IDs has exactly the same number of elements as the Iterator.
fn split_with_indices(

Check warning on line 479 in tokenizers/src/tokenizer/added_vocabulary.rs (GitHub Actions, all build jobs): method `split_with_indices` is never used
&self,
sentence: NormalizedString,
split_re: &MatchingSet,
@@ -465,7 +513,12 @@

// 1. We extract all the non-normalized tokens from the non-normalized string
pretokenized
.split(|_, sequence| Ok(self.split_with_indices(sequence, &self.split_trie)))
.split(|_, sequence| {
Ok(self.fast_split_with_indices(
sequence,
&self.split_trie_vec[hash_current_thread() % MAX_NUM_THREADS],
))
})
.expect("AddedVocabulary bad split");

// <s> normalized = False
@@ -484,7 +537,10 @@
pretokenized
.split(|_, mut sequence| {
normalizer.map(|n| n.normalize(&mut sequence));
Ok(self.split_with_indices(sequence, &self.split_normalized_trie))
Ok(self.fast_split_with_indices(
sequence,
&self.split_normalized_trie_vec[hash_current_thread() % MAX_NUM_THREADS],
))
})
.expect("AddedVocabulary bad split");

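
For the added_vocabulary.rs hunks above: the new fast_split_with_indices keeps only the matched token's id, constructing Token::new(id, String::new(), (0, 0)) with empty text and zeroed offsets, which is what later lets the encoding path skip offset bookkeeping. A tiny illustrative sketch of that Token shape (the id 42 is made up, and the Token re-export path is assumed from the crate's public API):

use tokenizers::Token;

fn main() {
    // On the fast path only the id is meaningful; value and offsets are defaults.
    let t = Token::new(42, String::new(), (0, 0));
    assert_eq!(t.id, 42);
    assert!(t.value.is_empty());
    assert_eq!(t.offsets, (0, 0));
}
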
2 changes: 1 addition & 1 deletion tokenizers/src/tokenizer/mod.rs
@@ -875,13 +875,13 @@
fn do_tokenize<P: Into<PreTokenizedString>>(
&self,
pretokenized: P,
type_id: u32,

Check warning on line 878 in tokenizers/src/tokenizer/mod.rs (GitHub Actions, all build jobs): unused variable: `type_id`
word_idx: Option<u32>,

Check warning on line 879 in tokenizers/src/tokenizer/mod.rs (GitHub Actions, all build jobs): unused variable: `word_idx`
offsets_type: OffsetType,

Check warning on line 880 in tokenizers/src/tokenizer/mod.rs (GitHub Actions, all build jobs): unused variable: `offsets_type`
) -> Result<Encoding> {
let mut pretokenized: PreTokenizedString = pretokenized.into();
pretokenized.tokenize(|normalized| self.model.tokenize(normalized.get()))?;
pretokenized.into_encoding(word_idx, type_id, offsets_type)
pretokenized.fast_into_encoding()
}
}

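
The mod.rs hunk leaves type_id, word_idx and offsets_type in the signature but no longer uses them, which is what the CI warnings above point at. A possible follow-up, not part of this PR, would be to underscore-prefix the parameters while keeping the signature intact:

fn do_tokenize<P: Into<PreTokenizedString>>(
    &self,
    pretokenized: P,
    _type_id: u32,
    _word_idx: Option<u32>,
    _offsets_type: OffsetType,
) -> Result<Encoding> {
    let mut pretokenized: PreTokenizedString = pretokenized.into();
    pretokenized.tokenize(|normalized| self.model.tokenize(normalized.get()))?;
    pretokenized.fast_into_encoding()
}
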
19 changes: 19 additions & 0 deletions tokenizers/src/tokenizer/pre_tokenizer.rs
@@ -186,6 +186,25 @@ impl PreTokenizedString {
}
}

pub fn fast_into_encoding(self) -> Result<Encoding> {
if self.splits.is_empty() {
Ok(Encoding::default())
} else if !self.splits.iter().all(|split| split.tokens.is_some()) {
Err("Split has not been tokenized.".into())
} else {
let tokens = self
.splits
.into_iter()
.flat_map(|split| {
split.tokens.unwrap().into_iter().map(|token| {
// Replace this with the actual fields you need for the Encoding type
(token.id, String::new(), (0, 0), None, 0)
})
})
.collect();
Ok(tokens)
}
}
/// Returns a list of splits, each of them being a slice of the normalized
/// string, the associated offsets either in original or normalized
/// referential, as well as the potential tokens
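
Finally, on the pre_tokenizer.rs hunk: fast_into_encoding reuses the same collect-into-Encoding mechanism as the existing into_encoding, but fills every tuple field except the token id with defaults, so the resulting Encoding carries real ids while offsets, token strings and word indices are all empty. A hedged sketch of what a caller should therefore expect (Encoding accessor names taken from the crate's public API; treat this as an assumption, not a test from the PR):

use tokenizers::Encoding;

// Inspect an Encoding produced by the fast path: ids are real, everything
// positional is defaulted.
fn describe(encoding: &Encoding) {
    println!("ids:     {:?}", encoding.get_ids());     // actual token ids
    println!("tokens:  {:?}", encoding.get_tokens());  // expected: all ""
    println!("offsets: {:?}", encoding.get_offsets()); // expected: all (0, 0)
}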