diff --git a/tokenizers/src/normalizers/mod.rs b/tokenizers/src/normalizers/mod.rs index c56a26c1e..2a786d702 100644 --- a/tokenizers/src/normalizers/mod.rs +++ b/tokenizers/src/normalizers/mod.rs @@ -14,12 +14,12 @@ pub use crate::normalizers::replace::Replace; pub use crate::normalizers::strip::{Strip, StripAccents}; pub use crate::normalizers::unicode::{Nmt, NFC, NFD, NFKC, NFKD}; pub use crate::normalizers::utils::{Lowercase, Sequence}; -use serde::{Deserialize, Serialize}; +use serde::{Deserialize, Deserializer, Serialize}; use crate::{NormalizedString, Normalizer}; /// Wrapper for known Normalizers. -#[derive(Clone, Debug, Deserialize, Serialize)] +#[derive(Clone, Debug, Serialize)] #[serde(untagged)] pub enum NormalizerWrapper { BertNormalizer(BertNormalizer), @@ -38,6 +38,149 @@ pub enum NormalizerWrapper { ByteLevel(ByteLevel), } +impl<'de> Deserialize<'de> for NormalizerWrapper { + fn deserialize(deserializer: D) -> std::result::Result + where + D: Deserializer<'de>, + { + #[derive(Deserialize)] + pub struct Tagged { + #[serde(rename = "type")] + variant: EnumType, + #[serde(flatten)] + rest: serde_json::Value, + } + #[derive(Serialize, Deserialize)] + pub enum EnumType { + Bert, + Strip, + StripAccents, + NFC, + NFD, + NFKC, + NFKD, + Sequence, + Lowercase, + Nmt, + Precompiled, + Replace, + Prepend, + ByteLevel, + } + + #[derive(Deserialize)] + #[serde(untagged)] + pub enum NormalizerHelper { + Tagged(Tagged), + Legacy(serde_json::Value), + } + + #[derive(Deserialize)] + #[serde(untagged)] + pub enum NormalizerUntagged { + BertNormalizer(BertNormalizer), + StripNormalizer(Strip), + StripAccents(StripAccents), + NFC(NFC), + NFD(NFD), + NFKC(NFKC), + NFKD(NFKD), + Sequence(Sequence), + Lowercase(Lowercase), + Nmt(Nmt), + Precompiled(Precompiled), + Replace(Replace), + Prepend(Prepend), + ByteLevel(ByteLevel), + } + + let helper = NormalizerHelper::deserialize(deserializer)?; + Ok(match helper { + NormalizerHelper::Tagged(model) => { + let mut values: serde_json::Map = + serde_json::from_value(model.rest).expect("Parsed values"); + values.insert( + "type".to_string(), + serde_json::to_value(&model.variant).expect("Reinsert"), + ); + let values = serde_json::Value::Object(values); + match model.variant { + EnumType::Bert => NormalizerWrapper::BertNormalizer( + serde_json::from_value(values).map_err(serde::de::Error::custom)?, + ), + EnumType::Strip => NormalizerWrapper::StripNormalizer( + serde_json::from_value(values).map_err(serde::de::Error::custom)?, + ), + EnumType::StripAccents => NormalizerWrapper::StripAccents( + serde_json::from_value(values).map_err(serde::de::Error::custom)?, + ), + EnumType::NFC => NormalizerWrapper::NFC( + serde_json::from_value(values).map_err(serde::de::Error::custom)?, + ), + EnumType::NFD => NormalizerWrapper::NFD( + serde_json::from_value(values).map_err(serde::de::Error::custom)?, + ), + EnumType::NFKC => NormalizerWrapper::NFKC( + serde_json::from_value(values).map_err(serde::de::Error::custom)?, + ), + EnumType::NFKD => NormalizerWrapper::NFKD( + serde_json::from_value(values).map_err(serde::de::Error::custom)?, + ), + EnumType::Sequence => NormalizerWrapper::Sequence( + serde_json::from_value(values).map_err(serde::de::Error::custom)?, + ), + EnumType::Lowercase => NormalizerWrapper::Lowercase( + serde_json::from_value(values).map_err(serde::de::Error::custom)?, + ), + EnumType::Nmt => NormalizerWrapper::Nmt( + serde_json::from_value(values).map_err(serde::de::Error::custom)?, + ), + EnumType::Precompiled => NormalizerWrapper::Precompiled( + serde_json::from_str( + &serde_json::to_string(&values).expect("Can reserialize precompiled"), + ) + // .map_err(serde::de::Error::custom) + .expect("Precompiled"), + ), + EnumType::Replace => NormalizerWrapper::Replace( + serde_json::from_value(values).map_err(serde::de::Error::custom)?, + ), + EnumType::Prepend => NormalizerWrapper::Prepend( + serde_json::from_value(values).map_err(serde::de::Error::custom)?, + ), + EnumType::ByteLevel => NormalizerWrapper::ByteLevel( + serde_json::from_value(values).map_err(serde::de::Error::custom)?, + ), + } + } + + NormalizerHelper::Legacy(value) => { + let untagged = serde_json::from_value(value).map_err(serde::de::Error::custom)?; + match untagged { + NormalizerUntagged::BertNormalizer(bpe) => { + NormalizerWrapper::BertNormalizer(bpe) + } + NormalizerUntagged::StripNormalizer(bpe) => { + NormalizerWrapper::StripNormalizer(bpe) + } + NormalizerUntagged::StripAccents(bpe) => NormalizerWrapper::StripAccents(bpe), + NormalizerUntagged::NFC(bpe) => NormalizerWrapper::NFC(bpe), + NormalizerUntagged::NFD(bpe) => NormalizerWrapper::NFD(bpe), + NormalizerUntagged::NFKC(bpe) => NormalizerWrapper::NFKC(bpe), + NormalizerUntagged::NFKD(bpe) => NormalizerWrapper::NFKD(bpe), + NormalizerUntagged::Sequence(bpe) => NormalizerWrapper::Sequence(bpe), + NormalizerUntagged::Lowercase(bpe) => NormalizerWrapper::Lowercase(bpe), + NormalizerUntagged::Nmt(bpe) => NormalizerWrapper::Nmt(bpe), + NormalizerUntagged::Precompiled(bpe) => NormalizerWrapper::Precompiled(bpe), + NormalizerUntagged::Replace(bpe) => NormalizerWrapper::Replace(bpe), + NormalizerUntagged::Prepend(bpe) => NormalizerWrapper::Prepend(bpe), + NormalizerUntagged::ByteLevel(bpe) => NormalizerWrapper::ByteLevel(bpe), + } + } + }) + } +} + impl Normalizer for NormalizerWrapper { fn normalize(&self, normalized: &mut NormalizedString) -> crate::Result<()> { match self { @@ -91,7 +234,7 @@ mod tests { match reconstructed { Err(err) => assert_eq!( err.to_string(), - "data did not match any variant of untagged enum NormalizerWrapper" + "data did not match any variant of untagged enum NormalizerUntagged" ), _ => panic!("Expected an error here"), } @@ -103,4 +246,36 @@ mod tests { NormalizerWrapper::Prepend(_) )); } + + #[test] + fn normalizer_serialization() { + let json = r#"{"type":"Sequence","normalizers":[]}"#; + assert!(serde_json::from_str::(json).is_ok()); + let json = r#"{"type":"Sequence","normalizers":[{}]}"#; + let parse = serde_json::from_str::(json); + match parse { + Err(err) => assert_eq!( + format!("{err}"), + "data did not match any variant of untagged enum NormalizerUntagged" + ), + _ => panic!("Expected error"), + } + + let json = r#"{"replacement":"▁","prepend_scheme":"always"}"#; + let parse = serde_json::from_str::(json); + match parse { + Err(err) => assert_eq!( + format!("{err}"), + "data did not match any variant of untagged enum NormalizerUntagged" + ), + _ => panic!("Expected error"), + } + + let json = r#"{"type":"Sequence","prepend_scheme":"always"}"#; + let parse = serde_json::from_str::(json); + match parse { + Err(err) => assert_eq!(format!("{err}"), "missing field `normalizers`"), + _ => panic!("Expected error"), + } + } }