diff --git a/tokenizers/src/pre_tokenizers/byte_level.rs b/tokenizers/src/pre_tokenizers/byte_level.rs
index 6343bbd07..4c3cd5f9c 100644
--- a/tokenizers/src/pre_tokenizers/byte_level.rs
+++ b/tokenizers/src/pre_tokenizers/byte_level.rs
@@ -37,7 +37,7 @@ fn bytes_char() -> HashMap<u8, char> {
 lazy_static! {
     /// Regex that matches exactly one token.
     /// See https://github.com/openai/gpt-2/blob/master/src/encoder.py#L98
-    static ref RE: SysRegex = SysRegex::new(
+    static ref GPT2_RE: SysRegex = SysRegex::new(
         r"'s|'t|'re|'ve|'m|'ll|'d| ?\p{L}+| ?\p{N}+| ?[^\s\p{L}\p{N}]+|\s+(?!\S)|\s+"
     )
     .unwrap();
@@ -46,7 +46,7 @@ lazy_static! {
         bytes_char().into_iter().map(|(c, b)| (b, c)).collect();
 }
 
-#[derive(Copy, Clone, Debug, PartialEq, Eq)]
+#[derive(Clone, Debug, PartialEq, Eq)]
 /// Provides all the necessary steps to handle the BPE tokenization at the byte-level. Takes care
 /// of all the required processing steps to transform a UTF-8 string as needed before and after the
 /// BPE model does its job.
@@ -63,18 +63,27 @@ pub struct ByteLevel {
     /// Set it to False if you want to use your own splitting.
     #[serde(default = "default_true")]
     pub use_regex: bool,
+
+    /// An optional custom pattern, used instead of the default GPT-2 regex when splitting.
+    #[serde(default = "default_none")]
+    pub custom_regex: Option<String>,
 }
 
 fn default_true() -> bool {
     true
 }
 
+fn default_none() -> Option<String> {
+    None
+}
+
 impl Default for ByteLevel {
     fn default() -> Self {
         Self {
             add_prefix_space: true,
             trim_offsets: true,
             use_regex: true,
+            custom_regex: None,
         }
     }
 }
@@ -85,6 +94,7 @@ impl ByteLevel {
             add_prefix_space,
             trim_offsets,
             use_regex,
+            custom_regex: None,
         }
     }
 
@@ -109,6 +119,12 @@ impl ByteLevel {
         self.use_regex = v;
         self
     }
+
+    #[must_use]
+    pub fn custom_regex(mut self, v: Option<String>) -> Self {
+        self.custom_regex = v;
+        self
+    }
 }
 
 /// As a `PreTokenizer`, `ByteLevel` is in charge of transforming all the unicode characters into
@@ -116,7 +132,15 @@
-// TODO: Give the ability to modify this regex
 impl PreTokenizer for ByteLevel {
     fn pre_tokenize(&self, pretokenized: &mut PreTokenizedString) -> Result<()> {
-        let re_ref: &SysRegex = &RE;
+        // Compile the custom pattern for this call, if one is set, and fall
+        // back to the statically compiled GPT-2 pattern otherwise.
+        let custom_regex = match &self.custom_regex {
+            Some(s) => Some(SysRegex::new(s)?),
+            None => None,
+        };
+        let re_ref: &SysRegex = match &custom_regex {
+            Some(custom_regex) => custom_regex,
+            None => &GPT2_RE,
+        };
         pretokenized.split(|_, mut normalized| {
             if self.add_prefix_space && !normalized.get().starts_with(' ') {
                 normalized.prepend(" ");
@@ -269,6 +293,34 @@ mod tests {
         );
     }
 
+    #[test]
+    fn pre_tokenization_custom_regex() {
+        let bytelevel = ByteLevel::default()
+            .custom_regex(Some(r"'(?i:[sdmt]|ll|ve|re)|[^\r\n\p{L}\p{N}]?+\p{L}+|\p{N}{1,3}| ?[^\s\p{L}\p{N}]++[\r\n]*|\s*[\r\n]|\s+(?!\S)|\s+".into()))
+            .add_prefix_space(false);
+        let mut pretokenized: PreTokenizedString = "Hello my friend, how is your day going?".into();
+        bytelevel.pre_tokenize(&mut pretokenized).unwrap();
+        assert_eq!(
+            pretokenized
+                .get_splits(OffsetReferential::Original, OffsetType::Byte)
+                .into_iter()
+                .map(|(s, o, _)| (s, o))
+                .collect::<Vec<_>>(),
+            vec![
+                ("Hello", (0, 5)),
+                ("Ġmy", (5, 8)),
+                ("Ġfriend", (8, 15)),
+                (",", (15, 16)),
+                ("Ġhow", (16, 20)),
+                ("Ġis", (20, 23)),
+                ("Ġyour", (23, 28)),
+                ("Ġday", (28, 32)),
+                ("Ġgoing", (32, 38)),
+                ("?", (38, 39))
+            ]
+        );
+    }
+
     #[test]
     fn pre_tokenization_no_regex() {
         let bytelevel = ByteLevel::default().use_regex(false);
diff --git a/tokenizers/src/processors/sequence.rs b/tokenizers/src/processors/sequence.rs
index 66c670ad8..cc285c5f4 100644
--- a/tokenizers/src/processors/sequence.rs
+++ b/tokenizers/src/processors/sequence.rs
@@ -64,7 +64,7 @@ mod tests {
         );
 
         let bytelevel = ByteLevel::default().trim_offsets(true);
-        let sequence = Sequence::new(vec![PostProcessorWrapper::ByteLevel(bytelevel)]);
+        let sequence = Sequence::new(vec![PostProcessorWrapper::ByteLevel(bytelevel.clone())]);
         let expected = Encoding::new(
             vec![0; 5],
             vec![0; 5],
diff --git a/tokenizers/tests/serialization.rs b/tokenizers/tests/serialization.rs
index 54fa9053d..08d711153 100644
--- a/tokenizers/tests/serialization.rs
+++ b/tokenizers/tests/serialization.rs
@@ -173,7 +173,7 @@ fn decoders() {
     let byte_level_ser = serde_json::to_string(&byte_level).unwrap();
     assert_eq!(
         byte_level_ser,
-        r#"{"type":"ByteLevel","add_prefix_space":true,"trim_offsets":true,"use_regex":true}"#
+        r#"{"type":"ByteLevel","add_prefix_space":true,"trim_offsets":true,"use_regex":true,"custom_regex":null}"#
     );
     serde_json::from_str::<ByteLevel>(&byte_level_ser).unwrap();
     let byte_level_wrapper: DecoderWrapper = serde_json::from_str(&byte_level_ser).unwrap();
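
For reviewers, a minimal usage sketch of the new builder as this patch would be exercised from outside the crate. It assumes the crate's existing root re-exports (`PreTokenizer`, `PreTokenizedString`, `OffsetReferential`, `OffsetType`) and reuses the Llama-3-style split pattern from the new unit test; only `custom_regex` itself is introduced by this diff.

```rust
use tokenizers::pre_tokenizers::byte_level::ByteLevel;
use tokenizers::{OffsetReferential, OffsetType, PreTokenizedString, PreTokenizer};

fn main() -> tokenizers::Result<()> {
    // Override the default GPT-2 split pattern (GPT2_RE) with the
    // Llama-3-style pattern exercised by the new unit test.
    let byte_level = ByteLevel::default()
        .add_prefix_space(false)
        .custom_regex(Some(
            r"'(?i:[sdmt]|ll|ve|re)|[^\r\n\p{L}\p{N}]?+\p{L}+|\p{N}{1,3}| ?[^\s\p{L}\p{N}]++[\r\n]*|\s*[\r\n]|\s+(?!\S)|\s+"
                .into(),
        ));

    let mut pretokenized: PreTokenizedString = "Hello my friend".into();
    // The custom pattern is compiled inside pre_tokenize, so an invalid
    // regex surfaces here as a recoverable Err rather than a panic at
    // construction time.
    byte_level.pre_tokenize(&mut pretokenized)?;

    for (token, offsets, _) in
        pretokenized.get_splits(OffsetReferential::Original, OffsetType::Byte)
    {
        println!("{token:?} {offsets:?}"); // "Hello" (0, 5), "Ġmy" (5, 8), "Ġfriend" (8, 15)
    }
    Ok(())
}
```

One trade-off worth noting: because the field is stored as a plain `Option<String>`, serialization stays trivial (see the updated `serialization.rs` expectation above), but the pattern is recompiled on every `pre_tokenize` call rather than cached as a compiled `SysRegex`.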