[pre_tokenizers] Fix sentencepiece based Metaspace (#1357)
* nits

* allow for legacy behaviour without making any breaking changes

* add a todo

* set to legacy by default

* skip legacy serialization

* push correct update

* lint

* add deserialization test

* add a python test as well

* updates

* fix serialization tests

* nits

* python styling of the tests

* better tests

* fix offsets

* fix imports

* fmt

* update metaspace

* remove TODO

* use enum

* fix some tests

* nits

* use enum

* update tests

* styling

* remove impl From for PrependScheme

* use simple getters and setters

* lint

* update tests

* add test new == new_with_prepend_scheme

* revert a change

* use setters and getters

* Update bindings/python/src/pre_tokenizers.rs

Co-authored-by: Nicolas Patry <[email protected]>

* nits

* use copy rather than ref

* nits format

* more nits

* allow option string

* enforce First/Never/Always camel casing

* nits

* refactor

* update test as well

* fmt

* nits

* properly error out

* Update bindings/python/src/pre_tokenizers.rs

Co-authored-by: Nicolas Patry <[email protected]>

* suggestion changes

---------

Co-authored-by: Nicolas Patry <[email protected]>
ArthurZucker and Narsil authored Nov 14, 2023
1 parent ee2af9e commit f55822b
Showing 5 changed files with 257 additions and 17 deletions.
56 changes: 49 additions & 7 deletions bindings/python/src/pre_tokenizers.rs
@@ -11,7 +11,7 @@ use tk::pre_tokenizers::bert::BertPreTokenizer;
use tk::pre_tokenizers::byte_level::ByteLevel;
use tk::pre_tokenizers::delimiter::CharDelimiterSplit;
use tk::pre_tokenizers::digits::Digits;
-use tk::pre_tokenizers::metaspace::Metaspace;
+use tk::pre_tokenizers::metaspace::{Metaspace, PrependScheme};
use tk::pre_tokenizers::punctuation::Punctuation;
use tk::pre_tokenizers::split::Split;
use tk::pre_tokenizers::unicode_scripts::UnicodeScripts;
@@ -452,6 +452,21 @@ impl PySequence {
}
}

fn from_string(string: String) -> Result<PrependScheme, PyErr> {
let scheme = match string.as_str() {
"first" => PrependScheme::First,
"never" => PrependScheme::Never,
"always" => PrependScheme::Always,
_ => {
return Err(exceptions::PyValueError::new_err(format!(
"{} is an unknown variant, should be one of ['first', 'never', 'always']",
string
)));
}
};
Ok(scheme)
}

/// Metaspace pre-tokenizer
///
/// This pre-tokenizer replaces any whitespace by the provided replacement character.
@@ -489,17 +504,44 @@ impl PyMetaspace {
setter!(self_, Metaspace, add_prefix_space, add_prefix_space);
}

#[getter]
fn get_prepend_scheme(self_: PyRef<Self>) -> String {
// Read the PrependScheme enum from the underlying Metaspace and map it to its string name
let scheme: PrependScheme = getter!(self_, Metaspace, get_prepend_scheme());
match scheme {
PrependScheme::First => "first",
PrependScheme::Never => "never",
PrependScheme::Always => "always",
}
.to_string()
}

#[setter]
fn set_prepend_scheme(self_: PyRef<Self>, prepend_scheme: String) -> PyResult<()> {
let scheme = from_string(prepend_scheme)?;
setter!(self_, Metaspace, @set_prepend_scheme, scheme);
Ok(())
}

#[new]
-#[pyo3(signature = (replacement = PyChar('▁'), add_prefix_space = true, **_kwargs), text_signature = "(self, replacement=\"_\", add_prefix_space=True)")]
+#[pyo3(signature = (replacement = PyChar('▁'), add_prefix_space = true, prepend_scheme=None, **_kwargs), text_signature = "(self, replacement=\"_\", add_prefix_space=True)")]
fn new(
replacement: PyChar,
add_prefix_space: bool,
prepend_scheme: Option<String>,
_kwargs: Option<&PyDict>,
-) -> (Self, PyPreTokenizer) {
-(
-PyMetaspace {},
-Metaspace::new(replacement.0, add_prefix_space).into(),
-)
+) -> PyResult<(Self, PyPreTokenizer)> {
// Create a new Metaspace instance
let mut new_instance: Metaspace = Metaspace::new(replacement.0, add_prefix_space);

// If a prepend scheme is provided, set it
if let Some(prepend_scheme) = prepend_scheme {
match from_string(prepend_scheme) {
Ok(prepend_scheme_enum) => new_instance.set_prepend_scheme(prepend_scheme_enum),
Err(err) => return Err(err),
}
}
Ok((PyMetaspace {}, new_instance.into()))
}
}

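For reference, a minimal sketch of driving the same scheme from the Rust side (a sketch assuming the tokenizers crate at this commit; the binding above simply parses the Python string and calls these setters):

use tokenizers::pre_tokenizers::metaspace::{Metaspace, PrependScheme};

fn main() {
    // Defaults to PrependScheme::Always, which preserves the legacy behaviour.
    let mut pretok = Metaspace::new('▁', true);

    // What `pretok.prepend_scheme = "first"` does through the Python binding.
    pretok.set_prepend_scheme(PrependScheme::First);
    assert_eq!(pretok.get_prepend_scheme(), PrependScheme::First);
}
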
2 changes: 2 additions & 0 deletions bindings/python/tests/bindings/test_pre_tokenizers.py
@@ -110,6 +110,8 @@ def test_can_modify(self):
assert pretok.replacement == "%"
pretok.add_prefix_space = True
assert pretok.add_prefix_space == True
pretok.prepend_scheme = "never"
assert pretok.prepend_scheme == "never"


class TestCharDelimiterSplit:
7 changes: 3 additions & 4 deletions tokenizers/src/decoders/mod.rs
@@ -73,23 +73,22 @@ mod tests {

#[test]
fn decoder_serialization() {
let json = r#"{"type":"Sequence","decoders":[{"type":"ByteFallback"},{"type":"Metaspace","replacement":"▁","add_prefix_space":true}]}"#;
let json = r#"{"type":"Sequence","decoders":[{"type":"ByteFallback"},{"type":"Metaspace","replacement":"▁","add_prefix_space":true,"prepend_scheme":"always"}]}"#;
let decoder: DecoderWrapper = serde_json::from_str(json).unwrap();
let serialized = serde_json::to_string(&decoder).unwrap();
assert_eq!(serialized, json);
}

#[test]
fn decoder_serialization_other_no_arg() {
let json = r#"{"type":"Sequence","decoders":[{"type":"Fuse"},{"type":"Metaspace","replacement":"▁","add_prefix_space":true}]}"#;
let json = r#"{"type":"Sequence","decoders":[{"type":"Fuse"},{"type":"Metaspace","replacement":"▁","add_prefix_space":true,"prepend_scheme":"always"}]}"#;
let decoder: DecoderWrapper = serde_json::from_str(json).unwrap();
let serialized = serde_json::to_string(&decoder).unwrap();
assert_eq!(serialized, json);
}

#[test]
fn decoder_serialization_no_decode() {
let json = r#"{"type":"Sequence","decoders":[{},{"type":"Metaspace","replacement":"▁","add_prefix_space":true}]}"#;
let json = r#"{"type":"Sequence","decoders":[{},{"type":"Metaspace","replacement":"▁","add_prefix_space":true,"prepend_scheme":"always"}]}"#;
assert!(serde_json::from_str::<DecoderWrapper>(json).is_err());
}
}
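
The fixture updates above reflect that prepend_scheme is now always serialized, while payloads predating this change still deserialize. A round-trip sketch, assuming the tokenizers and serde_json crates at this commit:

use tokenizers::pre_tokenizers::metaspace::Metaspace;

fn main() {
    let metaspace = Metaspace::new('▁', true);

    // The new field now appears in the serialized form.
    let json = serde_json::to_string(&metaspace).unwrap();
    assert!(json.contains(r#""prepend_scheme":"always""#));

    // Older payloads lack the field; the serde default ("always") keeps
    // their legacy behaviour intact.
    let legacy = r#"{"type":"Metaspace","replacement":"▁","add_prefix_space":true}"#;
    let loaded: Metaspace = serde_json::from_str(legacy).unwrap();
    assert_eq!(loaded, metaspace);
}
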
181 changes: 175 additions & 6 deletions tokenizers/src/pre_tokenizers/metaspace.rs
@@ -1,6 +1,17 @@
-use crate::tokenizer::{Decoder, PreTokenizedString, PreTokenizer, Result, SplitDelimiterBehavior};
use serde::{Deserialize, Deserializer, Serialize};

+use crate::tokenizer::{Decoder, PreTokenizedString, PreTokenizer, Result, SplitDelimiterBehavior};
/// Enum representing options for the metaspace prepending scheme.
#[derive(Debug, Clone, PartialEq, Serialize, Eq, Deserialize, Copy)]
#[serde(rename_all = "snake_case")]
pub enum PrependScheme {
/// Specifies that the space should be prepended only once, on the first split.
First,
/// Specifies that the space should not be prepended.
Never,
/// Specifies that the space should always be prepended.
Always,
}

#[derive(Debug, Clone, PartialEq, Serialize, Eq)]
/// Replaces all the whitespaces by the provided meta character and then
@@ -9,6 +20,7 @@ use crate::tokenizer::{Decoder, PreTokenizedString, PreTokenizer, Result, SplitDelimiterBehavior};
pub struct Metaspace {
replacement: char,
pub add_prefix_space: bool,
pub prepend_scheme: PrependScheme,
#[serde(skip)]
str_rep: String,
}
@@ -23,27 +35,51 @@ impl<'de> Deserialize<'de> for Metaspace {
Metaspace,
}

fn default_prepend_scheme_value() -> PrependScheme {
PrependScheme::Always
}

#[derive(Deserialize)]
pub struct MetaspaceHelper {
#[serde(rename = "type")]
_type: Type,
replacement: char,
pub add_prefix_space: bool,
#[serde(default = "default_prepend_scheme_value")]
pub prepend_scheme: PrependScheme,
#[serde(skip, rename = "str_rep")]
_str_rep: String,
}

let helper = MetaspaceHelper::deserialize(deserializer)?;
-Ok(Self::new(helper.replacement, helper.add_prefix_space))
+let instance = Self::new_with_prepend_scheme(
helper.replacement,
helper.add_prefix_space,
helper.prepend_scheme,
);
Ok(instance)
}
}

impl Metaspace {
pub fn new(replacement: char, add_prefix_space: bool) -> Self {
Self::new_with_prepend_scheme(
replacement,
add_prefix_space,
PrependScheme::Always, // always prepend for legacy purpose
)
}

pub fn new_with_prepend_scheme(
replacement: char,
add_prefix_space: bool,
prepend_scheme: PrependScheme,
) -> Self {
Self {
replacement,
str_rep: replacement.to_string(),
add_prefix_space,
prepend_scheme,
}
}

@@ -55,6 +91,14 @@ impl Metaspace {
self.replacement = replacement;
self.str_rep = replacement.to_string();
}

pub fn get_prepend_scheme(&self) -> PrependScheme {
self.prepend_scheme
}

pub fn set_prepend_scheme(&mut self, scheme: PrependScheme) {
self.prepend_scheme = scheme;
}
}

impl Default for Metaspace {
@@ -65,10 +109,19 @@ impl Default for Metaspace {

impl PreTokenizer for Metaspace {
fn pre_tokenize(&self, pretokenized: &mut PreTokenizedString) -> Result<()> {
let mut first_split = true;

pretokenized.split(|_, mut normalized| {
normalized.replace(' ', &self.str_rep)?;
if self.add_prefix_space && !normalized.get().starts_with(self.replacement) {
-normalized.prepend(&self.str_rep);
if self.prepend_scheme == PrependScheme::Always {
normalized.prepend(&self.str_rep);
} else if self.prepend_scheme == PrependScheme::First && first_split {
normalized.prepend(&self.str_rep);
first_split = false;
}
} else {
first_split = false;
}

normalized.split(self.replacement, SplitDelimiterBehavior::MergedWithNext)
@@ -103,13 +156,15 @@ impl Decoder for Metaspace {

#[cfg(test)]
mod tests {
use regex::Regex;

use super::*;
use crate::{OffsetReferential, OffsetType};

#[test]
fn serialization() {
let metaspace = Metaspace::new('_', true);
let metaspace_s = r#"{"type":"Metaspace","replacement":"_","add_prefix_space":true}"#;
let metaspace_s = r#"{"type":"Metaspace","replacement":"_","add_prefix_space":true,"prepend_scheme":"always"}"#;
assert_eq!(serde_json::to_string(&metaspace).unwrap(), metaspace_s);
assert_eq!(
serde_json::from_str::<Metaspace>(metaspace_s).unwrap(),
@@ -118,8 +173,7 @@

// Also check it can deserialize previous versions
let metaspace = Metaspace::new('_', true);
-let metaspace_s =
-r#"{"type":"Metaspace","str_rep":"_","replacement":"_","add_prefix_space":true}"#;
+let metaspace_s = r#"{"type":"Metaspace","str_rep":"_","replacement":"_","add_prefix_space":true,"prepend_scheme":"always"}"#;
assert_eq!(
serde_json::from_str::<Metaspace>(metaspace_s).unwrap(),
metaspace
@@ -188,6 +242,121 @@
);
}

#[test]
fn non_legacy_meta_space() {
assert_eq!(
Metaspace::new('▁', true),
Metaspace::new_with_prepend_scheme('▁', true, PrependScheme::Always)
);

let mut pretok = Metaspace::new('▁', true);
pretok.set_prepend_scheme(PrependScheme::Always);
assert_eq!(
pretok,
Metaspace::new_with_prepend_scheme('▁', true, PrependScheme::Always)
);

pretok.set_prepend_scheme(PrependScheme::Never);
assert_eq!(
pretok,
Metaspace::new_with_prepend_scheme('▁', true, PrependScheme::Never)
);

pretok.set_prepend_scheme(PrependScheme::First);
assert_eq!(
pretok,
Metaspace::new_with_prepend_scheme('▁', true, PrependScheme::First)
);

let mut pretokenized = PreTokenizedString::from("Hey my friend <s>how▁are you");
let re_ref = Regex::new(r"(<s>)").unwrap();
pretokenized
.split(|_, sequence| sequence.split(&re_ref, SplitDelimiterBehavior::Isolated))
.expect("Bad split");

pretok.pre_tokenize(&mut pretokenized).unwrap();
assert_eq!(
pretokenized
.get_splits(OffsetReferential::Normalized, OffsetType::Byte)
.into_iter()
.map(|(s, o, _)| (s, o))
.collect::<Vec<_>>(),
vec![
("▁Hey", (0, 6)),
("▁my", (6, 11)),
("▁friend", (11, 20)),
("▁", (20, 23)),
("<s>", (23, 26)),
("how", (26, 29)),
("▁are", (29, 35)),
("▁you", (35, 41))
]
);
pretok.set_prepend_scheme(PrependScheme::Always);
pretok.pre_tokenize(&mut pretokenized).unwrap();
assert_eq!(
pretokenized
.get_splits(OffsetReferential::Normalized, OffsetType::Byte)
.into_iter()
.map(|(s, o, _)| (s, o))
.collect::<Vec<_>>(),
vec![
("▁Hey", (0, 6)),
("▁my", (6, 11)),
("▁friend", (11, 20)),
("▁", (20, 23)),
("▁<s>", (23, 29)),
("▁how", (29, 35)),
("▁are", (35, 41)),
("▁you", (41, 47))
]
);

pretok.set_prepend_scheme(PrependScheme::First);
let mut pretokenized = PreTokenizedString::from(" Hey <s>how"); // test with prefix
pretokenized
.split(|_, sequence| sequence.split(&re_ref, SplitDelimiterBehavior::Isolated))
.expect("Bad split");
pretok.pre_tokenize(&mut pretokenized).unwrap();
assert_eq!(
pretokenized
.get_splits(OffsetReferential::Normalized, OffsetType::Byte)
.into_iter()
.map(|(s, o, _)| (s, o))
.collect::<Vec<_>>(),
vec![
("▁Hey", (0, 6)),
("▁", (6, 9)),
("<s>", (9, 12)),
("how", (12, 15))
]
);

let mut pretokenized = PreTokenizedString::from(" Hey <s>how <s>are <s> you"); // test with many splits
pretokenized
.split(|_, sequence| sequence.split(&re_ref, SplitDelimiterBehavior::Isolated))
.expect("Bad split");
pretok.pre_tokenize(&mut pretokenized).unwrap();
assert_eq!(
pretokenized
.get_splits(OffsetReferential::Normalized, OffsetType::Byte)
.into_iter()
.map(|(s, o, _)| (s, o))
.collect::<Vec<_>>(),
vec![
("▁Hey", (0, 6)),
("▁", (6, 9)),
("<s>", (9, 12)),
("how", (12, 15)),
("▁", (15, 18)),
("<s>", (18, 21)),
("are", (21, 24)),
("▁", (24, 27)),
("<s>", (27, 30)),
("▁you", (30, 36))
]
);
}
#[test]
fn decode() {
let decoder = Metaspace::new('▁', true);
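
To summarize what the tests above pin down, a short sketch (same crate assumptions as the earlier examples) of how the three schemes treat a simple input; First only diverges from Always once the input has been pre-split, e.g. around special tokens:

use tokenizers::pre_tokenizers::metaspace::{Metaspace, PrependScheme};
use tokenizers::tokenizer::{PreTokenizedString, PreTokenizer};
use tokenizers::{OffsetReferential, OffsetType};

fn main() {
    for scheme in [PrependScheme::Always, PrependScheme::First, PrependScheme::Never] {
        let pretok = Metaspace::new_with_prepend_scheme('▁', true, scheme);
        let mut input = PreTokenizedString::from("Hey friend");
        pretok.pre_tokenize(&mut input).unwrap();
        let pieces: Vec<&str> = input
            .get_splits(OffsetReferential::Normalized, OffsetType::Byte)
            .into_iter()
            .map(|(piece, _, _)| piece)
            .collect();
        // Always/First both give ["▁Hey", "▁friend"] on a single split;
        // Never gives ["Hey", "▁friend"].
        println!("{:?} -> {:?}", scheme, pieces);
    }
}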