diff --git a/bindings/python/src/pre_tokenizers.rs b/bindings/python/src/pre_tokenizers.rs
index 699d0bbdc..0f64357a3 100644
--- a/bindings/python/src/pre_tokenizers.rs
+++ b/bindings/python/src/pre_tokenizers.rs
@@ -11,7 +11,7 @@ use tk::pre_tokenizers::bert::BertPreTokenizer;
 use tk::pre_tokenizers::byte_level::ByteLevel;
 use tk::pre_tokenizers::delimiter::CharDelimiterSplit;
 use tk::pre_tokenizers::digits::Digits;
-use tk::pre_tokenizers::metaspace::Metaspace;
+use tk::pre_tokenizers::metaspace::{Metaspace, PrependScheme};
 use tk::pre_tokenizers::punctuation::Punctuation;
 use tk::pre_tokenizers::split::Split;
 use tk::pre_tokenizers::unicode_scripts::UnicodeScripts;
@@ -452,6 +452,21 @@ impl PySequence {
     }
 }
 
+fn from_string(string: String) -> Result<PrependScheme, PyErr> {
+    let scheme = match string.as_str() {
+        "first" => PrependScheme::First,
+        "never" => PrependScheme::Never,
+        "always" => PrependScheme::Always,
+        _ => {
+            return Err(exceptions::PyValueError::new_err(format!(
+                "{} is an unknown variant, should be one of ['first', 'never', 'always']",
+                string
+            )));
+        }
+    };
+    Ok(scheme)
+}
+
 /// Metaspace pre-tokenizer
 ///
 /// This pre-tokenizer replaces any whitespace by the provided replacement character.
@@ -489,17 +504,44 @@ impl PyMetaspace {
         setter!(self_, Metaspace, add_prefix_space, add_prefix_space);
     }
 
+    #[getter]
+    fn get_prepend_scheme(self_: PyRef<Self>) -> String {
+        // Read the PrependScheme from the underlying Metaspace and expose it as a string
+        let scheme: PrependScheme = getter!(self_, Metaspace, get_prepend_scheme());
+        match scheme {
+            PrependScheme::First => "first",
+            PrependScheme::Never => "never",
+            PrependScheme::Always => "always",
+        }
+        .to_string()
+    }
+
+    #[setter]
+    fn set_prepend_scheme(self_: PyRef<Self>, prepend_scheme: String) -> PyResult<()> {
+        let scheme = from_string(prepend_scheme)?;
+        setter!(self_, Metaspace, @set_prepend_scheme, scheme);
+        Ok(())
+    }
+
     #[new]
-    #[pyo3(signature = (replacement = PyChar('▁'), add_prefix_space = true, **_kwargs), text_signature = "(self, replacement=\"_\", add_prefix_space=True)")]
+    #[pyo3(signature = (replacement = PyChar('▁'), add_prefix_space = true, prepend_scheme=None, **_kwargs), text_signature = "(self, replacement=\"_\", add_prefix_space=True, prepend_scheme=\"always\")")]
     fn new(
         replacement: PyChar,
         add_prefix_space: bool,
+        prepend_scheme: Option<String>,
         _kwargs: Option<&PyDict>,
-    ) -> (Self, PyPreTokenizer) {
-        (
-            PyMetaspace {},
-            Metaspace::new(replacement.0, add_prefix_space).into(),
-        )
+    ) -> PyResult<(Self, PyPreTokenizer)> {
+        // Create a new Metaspace instance
+        let mut new_instance: Metaspace = Metaspace::new(replacement.0, add_prefix_space);
+
+        // If a prepend scheme is provided, set it
+        if let Some(prepend_scheme) = prepend_scheme {
+            match from_string(prepend_scheme) {
+                Ok(prepend_scheme_enum) => new_instance.set_prepend_scheme(prepend_scheme_enum),
+                Err(err) => return Err(err),
+            }
+        }
+        Ok((PyMetaspace {}, new_instance.into()))
     }
 }
diff --git a/bindings/python/tests/bindings/test_pre_tokenizers.py b/bindings/python/tests/bindings/test_pre_tokenizers.py
index 90b468c12..daf827b26 100644
--- a/bindings/python/tests/bindings/test_pre_tokenizers.py
+++ b/bindings/python/tests/bindings/test_pre_tokenizers.py
@@ -110,6 +110,8 @@ def test_can_modify(self):
         assert pretok.replacement == "%"
         pretok.add_prefix_space = True
         assert pretok.add_prefix_space == True
+        pretok.prepend_scheme = "never"
+        assert pretok.prepend_scheme == "never"
 
 
 class TestCharDelimiterSplit:
diff --git a/tokenizers/src/decoders/mod.rs b/tokenizers/src/decoders/mod.rs
index 94204b8f1..87beeca5b 100644
--- a/tokenizers/src/decoders/mod.rs
+++ b/tokenizers/src/decoders/mod.rs
@@ -73,15 +73,14 @@ mod tests {
 
     #[test]
     fn decoder_serialization() {
-        let json = r#"{"type":"Sequence","decoders":[{"type":"ByteFallback"},{"type":"Metaspace","replacement":"▁","add_prefix_space":true}]}"#;
+        let json = r#"{"type":"Sequence","decoders":[{"type":"ByteFallback"},{"type":"Metaspace","replacement":"▁","add_prefix_space":true,"prepend_scheme":"always"}]}"#;
         let decoder: DecoderWrapper = serde_json::from_str(json).unwrap();
         let serialized = serde_json::to_string(&decoder).unwrap();
         assert_eq!(serialized, json);
     }
-
     #[test]
     fn decoder_serialization_other_no_arg() {
-        let json = r#"{"type":"Sequence","decoders":[{"type":"Fuse"},{"type":"Metaspace","replacement":"▁","add_prefix_space":true}]}"#;
+        let json = r#"{"type":"Sequence","decoders":[{"type":"Fuse"},{"type":"Metaspace","replacement":"▁","add_prefix_space":true,"prepend_scheme":"always"}]}"#;
         let decoder: DecoderWrapper = serde_json::from_str(json).unwrap();
         let serialized = serde_json::to_string(&decoder).unwrap();
         assert_eq!(serialized, json);
@@ -89,7 +88,7 @@ mod tests {
 
     #[test]
     fn decoder_serialization_no_decode() {
-        let json = r#"{"type":"Sequence","decoders":[{},{"type":"Metaspace","replacement":"▁","add_prefix_space":true}]}"#;
+        let json = r#"{"type":"Sequence","decoders":[{},{"type":"Metaspace","replacement":"▁","add_prefix_space":true,"prepend_scheme":"always"}]}"#;
         assert!(serde_json::from_str::<DecoderWrapper>(json).is_err());
     }
 }
diff --git a/tokenizers/src/pre_tokenizers/metaspace.rs b/tokenizers/src/pre_tokenizers/metaspace.rs
index ad4df5afb..d7c8c3861 100644
--- a/tokenizers/src/pre_tokenizers/metaspace.rs
+++ b/tokenizers/src/pre_tokenizers/metaspace.rs
@@ -1,6 +1,17 @@
+use crate::tokenizer::{Decoder, PreTokenizedString, PreTokenizer, Result, SplitDelimiterBehavior};
 use serde::{Deserialize, Deserializer, Serialize};
 
-use crate::tokenizer::{Decoder, PreTokenizedString, PreTokenizer, Result, SplitDelimiterBehavior};
+/// Enum representing options for the metaspace prepending scheme.
+#[derive(Debug, Clone, PartialEq, Serialize, Eq, Deserialize, Copy)]
+#[serde(rename_all = "snake_case")]
+pub enum PrependScheme {
+    /// Specifies that the space should be prepended only once, on the first split.
+    First,
+    /// Specifies that the space should not be prepended.
+    Never,
+    /// Specifies that the space should always be prepended.
+    Always,
+}
 
 #[derive(Debug, Clone, PartialEq, Serialize, Eq)]
 /// Replaces all the whitespaces by the provided meta character and then
@@ -9,6 +20,7 @@ use crate::tokenizer::{Decoder, PreTokenizedString, PreTokenizer, Result, SplitD
 pub struct Metaspace {
     replacement: char,
     pub add_prefix_space: bool,
+    pub prepend_scheme: PrependScheme,
     #[serde(skip)]
     str_rep: String,
 }
@@ -23,27 +35,51 @@ impl<'de> Deserialize<'de> for Metaspace {
             Metaspace,
         }
 
+        fn default_prepend_scheme_value() -> PrependScheme {
+            PrependScheme::Always
+        }
+
         #[derive(Deserialize)]
         pub struct MetaspaceHelper {
             #[serde(rename = "type")]
             _type: Type,
             replacement: char,
             pub add_prefix_space: bool,
+            #[serde(default = "default_prepend_scheme_value")]
+            pub prepend_scheme: PrependScheme,
             #[serde(skip, rename = "str_rep")]
             _str_rep: String,
         }
 
         let helper = MetaspaceHelper::deserialize(deserializer)?;
-        Ok(Self::new(helper.replacement, helper.add_prefix_space))
+        let instance = Self::new_with_prepend_scheme(
+            helper.replacement,
+            helper.add_prefix_space,
+            helper.prepend_scheme,
+        );
+        Ok(instance)
     }
 }
 
 impl Metaspace {
     pub fn new(replacement: char, add_prefix_space: bool) -> Self {
+        Self::new_with_prepend_scheme(
+            replacement,
+            add_prefix_space,
+            PrependScheme::Always, // always prepend for legacy purpose
+        )
+    }
+
+    pub fn new_with_prepend_scheme(
+        replacement: char,
+        add_prefix_space: bool,
+        prepend_scheme: PrependScheme,
+    ) -> Self {
         Self {
             replacement,
             str_rep: replacement.to_string(),
             add_prefix_space,
+            prepend_scheme,
         }
     }
@@ -55,6 +91,14 @@ impl Metaspace {
         self.replacement = replacement;
         self.str_rep = replacement.to_string();
     }
+
+    pub fn get_prepend_scheme(&self) -> PrependScheme {
+        self.prepend_scheme
+    }
+
+    pub fn set_prepend_scheme(&mut self, scheme: PrependScheme) {
+        self.prepend_scheme = scheme;
+    }
 }
 
 impl Default for Metaspace {
@@ -65,10 +109,19 @@ impl Default for Metaspace {
 
 impl PreTokenizer for Metaspace {
     fn pre_tokenize(&self, pretokenized: &mut PreTokenizedString) -> Result<()> {
+        let mut first_split = true;
+
         pretokenized.split(|_, mut normalized| {
             normalized.replace(' ', &self.str_rep)?;
             if self.add_prefix_space && !normalized.get().starts_with(self.replacement) {
-                normalized.prepend(&self.str_rep);
+                if self.prepend_scheme == PrependScheme::Always {
+                    normalized.prepend(&self.str_rep);
+                } else if self.prepend_scheme == PrependScheme::First && first_split {
+                    normalized.prepend(&self.str_rep);
+                    first_split = false;
+                }
+            } else {
+                first_split = false;
             }
 
             normalized.split(self.replacement, SplitDelimiterBehavior::MergedWithNext)
@@ -103,13 +156,15 @@ impl Decoder for Metaspace {
 #[cfg(test)]
 mod tests {
+    use regex::Regex;
+
     use super::*;
     use crate::{OffsetReferential, OffsetType};
 
     #[test]
     fn serialization() {
         let metaspace = Metaspace::new('_', true);
-        let metaspace_s = r#"{"type":"Metaspace","replacement":"_","add_prefix_space":true}"#;
+        let metaspace_s = r#"{"type":"Metaspace","replacement":"_","add_prefix_space":true,"prepend_scheme":"always"}"#;
         assert_eq!(serde_json::to_string(&metaspace).unwrap(), metaspace_s);
         assert_eq!(
             serde_json::from_str::<Metaspace>(metaspace_s).unwrap(),
             metaspace
@@ -118,8 +173,7 @@ mod tests {
         );
         // Also check it can deserialize previous versions
         let metaspace = Metaspace::new('_', true);
-        let metaspace_s =
-            r#"{"type":"Metaspace","str_rep":"_","replacement":"_","add_prefix_space":true}"#;
+        let metaspace_s = r#"{"type":"Metaspace","str_rep":"_","replacement":"_","add_prefix_space":true,"prepend_scheme":"always"}"#;
         assert_eq!(
             serde_json::from_str::<Metaspace>(metaspace_s).unwrap(),
             metaspace
@@ -188,6 +242,121 @@ mod tests {
         );
     }
 
+    #[test]
+    fn non_legacy_meta_space() {
+        assert_eq!(
+            Metaspace::new('▁', true),
+            Metaspace::new_with_prepend_scheme('▁', true, PrependScheme::Always)
+        );
+
+        let mut pretok = Metaspace::new('▁', true);
+        pretok.set_prepend_scheme(PrependScheme::Always);
+        assert_eq!(
+            pretok,
+            Metaspace::new_with_prepend_scheme('▁', true, PrependScheme::Always)
+        );
+
+        pretok.set_prepend_scheme(PrependScheme::Never);
+        assert_eq!(
+            pretok,
+            Metaspace::new_with_prepend_scheme('▁', true, PrependScheme::Never)
+        );
+
+        pretok.set_prepend_scheme(PrependScheme::First);
+        assert_eq!(
+            pretok,
+            Metaspace::new_with_prepend_scheme('▁', true, PrependScheme::First)
+        );
+
+        let mut pretokenized = PreTokenizedString::from("Hey my friend <s>how▁are you");
+        let re_ref = Regex::new(r"(<s>)").unwrap();
+        pretokenized
+            .split(|_, sequence| sequence.split(&re_ref, SplitDelimiterBehavior::Isolated))
+            .expect("Bad split");
+
+        pretok.pre_tokenize(&mut pretokenized).unwrap();
+        assert_eq!(
+            pretokenized
+                .get_splits(OffsetReferential::Normalized, OffsetType::Byte)
+                .into_iter()
+                .map(|(s, o, _)| (s, o))
+                .collect::<Vec<_>>(),
+            vec![
+                ("▁Hey", (0, 6)),
+                ("▁my", (6, 11)),
+                ("▁friend", (11, 20)),
+                ("▁", (20, 23)),
+                ("<s>", (23, 26)),
+                ("how", (26, 29)),
+                ("▁are", (29, 35)),
+                ("▁you", (35, 41))
+            ]
+        );
+        pretok.set_prepend_scheme(PrependScheme::Always);
+        pretok.pre_tokenize(&mut pretokenized).unwrap();
+        assert_eq!(
+            pretokenized
+                .get_splits(OffsetReferential::Normalized, OffsetType::Byte)
+                .into_iter()
+                .map(|(s, o, _)| (s, o))
+                .collect::<Vec<_>>(),
+            vec![
+                ("▁Hey", (0, 6)),
+                ("▁my", (6, 11)),
+                ("▁friend", (11, 20)),
+                ("▁", (20, 23)),
+                ("▁<s>", (23, 29)),
+                ("▁how", (29, 35)),
+                ("▁are", (35, 41)),
+                ("▁you", (41, 47))
+            ]
+        );
+
+        pretok.set_prepend_scheme(PrependScheme::First);
+        let mut pretokenized = PreTokenizedString::from(" Hey <s>how"); // test with prefix
+        pretokenized
+            .split(|_, sequence| sequence.split(&re_ref, SplitDelimiterBehavior::Isolated))
+            .expect("Bad split");
+        pretok.pre_tokenize(&mut pretokenized).unwrap();
+        assert_eq!(
+            pretokenized
+                .get_splits(OffsetReferential::Normalized, OffsetType::Byte)
+                .into_iter()
+                .map(|(s, o, _)| (s, o))
+                .collect::<Vec<_>>(),
+            vec![
+                ("▁Hey", (0, 6)),
+                ("▁", (6, 9)),
+                ("<s>", (9, 12)),
+                ("how", (12, 15))
+            ]
+        );
+
+        let mut pretokenized = PreTokenizedString::from(" Hey <s>how <s>are <s> you"); // test with many splits
+        pretokenized
+            .split(|_, sequence| sequence.split(&re_ref, SplitDelimiterBehavior::Isolated))
+            .expect("Bad split");
+        pretok.pre_tokenize(&mut pretokenized).unwrap();
+        assert_eq!(
+            pretokenized
+                .get_splits(OffsetReferential::Normalized, OffsetType::Byte)
+                .into_iter()
+                .map(|(s, o, _)| (s, o))
+                .collect::<Vec<_>>(),
+            vec![
+                ("▁Hey", (0, 6)),
+                ("▁", (6, 9)),
+                ("<s>", (9, 12)),
+                ("how", (12, 15)),
+                ("▁", (15, 18)),
+                ("<s>", (18, 21)),
+                ("are", (21, 24)),
+                ("▁", (24, 27)),
+                ("<s>", (27, 30)),
+                ("▁you", (30, 36))
+            ]
+        );
+    }
     #[test]
     fn decode() {
         let decoder = Metaspace::new('▁', true);
diff --git a/tokenizers/src/pre_tokenizers/mod.rs b/tokenizers/src/pre_tokenizers/mod.rs
index 0ec7f9f59..42bbd15aa 100644
--- a/tokenizers/src/pre_tokenizers/mod.rs
+++ b/tokenizers/src/pre_tokenizers/mod.rs
@@ -104,6 +104,34 @@ mod tests {
                 PreTokenizerWrapper::Metaspace(Metaspace::new('▁', true))
             ]))
         );
+
+        let pre_tokenizer: PreTokenizerWrapper = serde_json::from_str(
+            r#"{"type":"Metaspace","replacement":"▁","add_prefix_space":true, "prepend_scheme":"first"}"#,
+        )
+        .unwrap();
+
+        assert_eq!(
+            pre_tokenizer,
+            PreTokenizerWrapper::Metaspace(Metaspace::new_with_prepend_scheme(
+                '▁',
+                true,
+                metaspace::PrependScheme::First
+            ))
+        );
+
+        let pre_tokenizer: PreTokenizerWrapper = serde_json::from_str(
+            r#"{"type":"Metaspace","replacement":"▁","add_prefix_space":true, "prepend_scheme":"always"}"#,
+        )
+        .unwrap();
+
+        assert_eq!(
+            pre_tokenizer,
+            PreTokenizerWrapper::Metaspace(Metaspace::new_with_prepend_scheme(
+                '▁',
+                true,
+                metaspace::PrependScheme::Always
+            ))
+        );
     }
 
     #[test]
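
Taken together, the patch keeps "always" as the default prepend scheme (matching the legacy behavior, so previously serialized tokenizers round-trip unchanged) while exposing "first" and "never" through the Python bindings. A minimal sketch of the resulting Python surface, assuming a local build of the bindings that includes this patch:

from tokenizers.pre_tokenizers import Metaspace

# "always" is the default; "first" only prepends the replacement on the first split.
pretok = Metaspace(replacement="▁", add_prefix_space=True, prepend_scheme="first")
assert pretok.prepend_scheme == "first"

# The property is mutable, mirroring test_can_modify above.
pretok.prepend_scheme = "never"

# Anything outside ['first', 'never', 'always'] surfaces as a ValueError
# raised by the Rust-side from_string helper.
try:
    pretok.prepend_scheme = "sometimes"
except ValueError as err:
    print(err)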