diff --git a/tokenizers/src/pre_tokenizers/mod.rs b/tokenizers/src/pre_tokenizers/mod.rs
index c12646102..6195d170b 100644
--- a/tokenizers/src/pre_tokenizers/mod.rs
+++ b/tokenizers/src/pre_tokenizers/mod.rs
@@ -9,7 +9,7 @@ pub mod split;
 pub mod unicode_scripts;
 pub mod whitespace;
 
-use serde::{Deserialize, Serialize};
+use serde::{Deserialize, Deserializer, Serialize};
 
 use crate::pre_tokenizers::bert::BertPreTokenizer;
 use crate::pre_tokenizers::byte_level::ByteLevel;
@@ -23,7 +23,7 @@ use crate::pre_tokenizers::unicode_scripts::UnicodeScripts;
 use crate::pre_tokenizers::whitespace::{Whitespace, WhitespaceSplit};
 use crate::{PreTokenizedString, PreTokenizer};
 
-#[derive(Deserialize, Serialize, Clone, Debug, PartialEq)]
+#[derive(Serialize, Clone, Debug, PartialEq)]
 #[serde(untagged)]
 pub enum PreTokenizerWrapper {
     BertPreTokenizer(BertPreTokenizer),
@@ -57,6 +57,142 @@ impl PreTokenizer for PreTokenizerWrapper {
     }
 }
 
+impl<'de> Deserialize<'de> for PreTokenizerWrapper {
+    fn deserialize<D>(deserializer: D) -> std::result::Result<Self, D::Error>
+    where
+        D: Deserializer<'de>,
+    {
+        #[derive(Deserialize)]
+        pub struct Tagged {
+            #[serde(rename = "type")]
+            variant: EnumType,
+            #[serde(flatten)]
+            rest: serde_json::Value,
+        }
+        #[derive(Deserialize, Serialize)]
+        pub enum EnumType {
+            BertPreTokenizer,
+            ByteLevel,
+            Delimiter,
+            Metaspace,
+            Whitespace,
+            Sequence,
+            Split,
+            Punctuation,
+            WhitespaceSplit,
+            Digits,
+            UnicodeScripts,
+        }
+
+        #[derive(Deserialize)]
+        #[serde(untagged)]
+        pub enum PreTokenizerHelper {
+            Tagged(Tagged),
+            Legacy(serde_json::Value),
+        }
+
+        #[derive(Deserialize)]
+        #[serde(untagged)]
+        pub enum PreTokenizerUntagged {
+            BertPreTokenizer(BertPreTokenizer),
+            ByteLevel(ByteLevel),
+            Delimiter(CharDelimiterSplit),
+            Metaspace(Metaspace),
+            Whitespace(Whitespace),
+            Sequence(Sequence),
+            Split(Split),
+            Punctuation(Punctuation),
+            WhitespaceSplit(WhitespaceSplit),
+            Digits(Digits),
+            UnicodeScripts(UnicodeScripts),
+        }
+
+        let helper = PreTokenizerHelper::deserialize(deserializer)?;
+
+        Ok(match helper {
+            PreTokenizerHelper::Tagged(pretok) => {
+                let mut values: serde_json::Map<String, serde_json::Value> =
+                    serde_json::from_value(pretok.rest).map_err(serde::de::Error::custom)?;
+                values.insert(
+                    "type".to_string(),
+                    serde_json::to_value(&pretok.variant).map_err(serde::de::Error::custom)?,
+                );
+                let values = serde_json::Value::Object(values);
+                match pretok.variant {
+                    EnumType::BertPreTokenizer => PreTokenizerWrapper::BertPreTokenizer(
+                        serde_json::from_value(values).map_err(serde::de::Error::custom)?,
+                    ),
+                    EnumType::ByteLevel => PreTokenizerWrapper::ByteLevel(
+                        serde_json::from_value(values).map_err(serde::de::Error::custom)?,
+                    ),
+                    EnumType::Delimiter => PreTokenizerWrapper::Delimiter(
+                        serde_json::from_value(values).map_err(serde::de::Error::custom)?,
+                    ),
+                    EnumType::Metaspace => PreTokenizerWrapper::Metaspace(
+                        serde_json::from_value(values).map_err(serde::de::Error::custom)?,
+                    ),
+                    EnumType::Whitespace => PreTokenizerWrapper::Whitespace(
+                        serde_json::from_value(values).map_err(serde::de::Error::custom)?,
+                    ),
+                    EnumType::Sequence => PreTokenizerWrapper::Sequence(
+                        serde_json::from_value(values).map_err(serde::de::Error::custom)?,
+                    ),
+                    EnumType::Split => PreTokenizerWrapper::Split(
+                        serde_json::from_value(values).map_err(serde::de::Error::custom)?,
+                    ),
+                    EnumType::Punctuation => PreTokenizerWrapper::Punctuation(
+                        serde_json::from_value(values).map_err(serde::de::Error::custom)?,
+                    ),
+                    EnumType::WhitespaceSplit => PreTokenizerWrapper::WhitespaceSplit(
+                        serde_json::from_value(values).map_err(serde::de::Error::custom)?,
+                    ),
+                    EnumType::Digits => PreTokenizerWrapper::Digits(
+                        serde_json::from_value(values).map_err(serde::de::Error::custom)?,
+                    ),
+                    EnumType::UnicodeScripts => PreTokenizerWrapper::UnicodeScripts(
+                        serde_json::from_value(values).map_err(serde::de::Error::custom)?,
+                    ),
+                }
+            }
+
+            PreTokenizerHelper::Legacy(value) => {
+                let untagged = serde_json::from_value(value).map_err(serde::de::Error::custom)?;
+                match untagged {
+                    PreTokenizerUntagged::BertPreTokenizer(bert) => {
+                        PreTokenizerWrapper::BertPreTokenizer(bert)
+                    }
+                    PreTokenizerUntagged::ByteLevel(byte_level) => {
+                        PreTokenizerWrapper::ByteLevel(byte_level)
+                    }
+                    PreTokenizerUntagged::Delimiter(delimiter) => {
+                        PreTokenizerWrapper::Delimiter(delimiter)
+                    }
+                    PreTokenizerUntagged::Metaspace(metaspace) => {
+                        PreTokenizerWrapper::Metaspace(metaspace)
+                    }
+                    PreTokenizerUntagged::Whitespace(whitespace) => {
+                        PreTokenizerWrapper::Whitespace(whitespace)
+                    }
+                    PreTokenizerUntagged::Sequence(sequence) => {
+                        PreTokenizerWrapper::Sequence(sequence)
+                    }
+                    PreTokenizerUntagged::Split(split) => PreTokenizerWrapper::Split(split),
+                    PreTokenizerUntagged::Punctuation(punctuation) => {
+                        PreTokenizerWrapper::Punctuation(punctuation)
+                    }
+                    PreTokenizerUntagged::WhitespaceSplit(whitespace_split) => {
+                        PreTokenizerWrapper::WhitespaceSplit(whitespace_split)
+                    }
+                    PreTokenizerUntagged::Digits(digits) => PreTokenizerWrapper::Digits(digits),
+                    PreTokenizerUntagged::UnicodeScripts(unicode_scripts) => {
+                        PreTokenizerWrapper::UnicodeScripts(unicode_scripts)
+                    }
+                }
+            }
+        })
+    }
+}
+
 impl_enum_from!(BertPreTokenizer, PreTokenizerWrapper, BertPreTokenizer);
 impl_enum_from!(ByteLevel, PreTokenizerWrapper, ByteLevel);
 impl_enum_from!(CharDelimiterSplit, PreTokenizerWrapper, Delimiter);
@@ -152,25 +288,22 @@ mod tests {
         match reconstructed {
             Err(err) => assert_eq!(
                 err.to_string(),
-                "data did not match any variant of untagged enum PreTokenizerWrapper"
+                "data did not match any variant of untagged enum PreTokenizerUntagged"
             ),
             _ => panic!("Expected an error here"),
         }
 
         let json = r#"{"type":"Metaspace", "replacement":"▁" }"#;
-        let reconstructed = serde_json::from_str::<PreTokenizerWrapper>(json);
+        let reconstructed = serde_json::from_str::<PreTokenizerWrapper>(json).unwrap();
         assert_eq!(
-            reconstructed.unwrap(),
+            reconstructed,
             PreTokenizerWrapper::Metaspace(Metaspace::default())
         );
 
         let json = r#"{"type":"Metaspace", "add_prefix_space":true }"#;
         let reconstructed = serde_json::from_str::<PreTokenizerWrapper>(json);
         match reconstructed {
-            Err(err) => assert_eq!(
-                err.to_string(),
-                "data did not match any variant of untagged enum PreTokenizerWrapper"
-            ),
+            Err(err) => assert_eq!(err.to_string(), "missing field `replacement`"),
             _ => panic!("Expected an error here"),
         }
         let json = r#"{"behavior":"default_split"}"#;
@@ -178,7 +311,7 @@ mod tests {
         match reconstructed {
             Err(err) => assert_eq!(
                 err.to_string(),
-                "data did not match any variant of untagged enum PreTokenizerWrapper"
+                "data did not match any variant of untagged enum PreTokenizerUntagged"
             ),
             _ => panic!("Expected an error here"),
         }
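
Reviewer note: below is a minimal sketch, not part of the diff, showing how the two deserialization paths introduced above behave. It reuses the exact JSON payloads and expected error strings from the updated tests; the `tokenizers::pre_tokenizers` import paths are an assumption based on the crate layout.

// Sketch only: exercises the tagged and legacy paths of the new
// `Deserialize` impl for `PreTokenizerWrapper`.
use tokenizers::pre_tokenizers::{metaspace::Metaspace, PreTokenizerWrapper};

fn main() {
    // Tagged path: a `type` field routes the payload to the matching
    // variant's own deserializer, so a bad payload now yields a
    // field-precise error instead of the generic untagged-enum message.
    let ok = r#"{"type":"Metaspace", "replacement":"▁" }"#;
    let pretok: PreTokenizerWrapper = serde_json::from_str(ok).unwrap();
    assert_eq!(pretok, PreTokenizerWrapper::Metaspace(Metaspace::default()));

    let bad = r#"{"type":"Metaspace", "add_prefix_space":true }"#;
    let err = serde_json::from_str::<PreTokenizerWrapper>(bad).unwrap_err();
    assert_eq!(err.to_string(), "missing field `replacement`");

    // Legacy path: with no `type` field the value falls through to the
    // shape-matched `PreTokenizerUntagged` enum, so failures keep the
    // generic message (now naming the helper enum rather than the wrapper).
    let legacy = r#"{"behavior":"default_split"}"#;
    let err = serde_json::from_str::<PreTokenizerWrapper>(legacy).unwrap_err();
    assert_eq!(
        err.to_string(),
        "data did not match any variant of untagged enum PreTokenizerUntagged"
    );
}

The key design point is the round trip through `serde_json::Value` in the tagged arm: the `type` key is re-inserted before dispatching, because each concrete pre-tokenizer's own deserializer still expects to see it.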