diff --git a/tokenizers/src/decoders/mod.rs b/tokenizers/src/decoders/mod.rs index 5f0968fcb..6e79e7029 100644 --- a/tokenizers/src/decoders/mod.rs +++ b/tokenizers/src/decoders/mod.rs @@ -10,7 +10,7 @@ pub mod wordpiece; pub use super::pre_tokenizers::byte_level; pub use super::pre_tokenizers::metaspace; -use serde::{Deserialize, Serialize}; +use serde::{Deserialize, Deserializer, Serialize}; use crate::decoders::bpe::BPEDecoder; use crate::decoders::byte_fallback::ByteFallback; @@ -24,7 +24,7 @@ use crate::pre_tokenizers::byte_level::ByteLevel; use crate::pre_tokenizers::metaspace::Metaspace; use crate::{Decoder, Result}; -#[derive(Serialize, Deserialize, Clone, Debug)] +#[derive(Serialize, Clone, Debug)] #[serde(untagged)] pub enum DecoderWrapper { BPE(BPEDecoder), @@ -39,6 +39,116 @@ pub enum DecoderWrapper { ByteFallback(ByteFallback), } +impl<'de> Deserialize<'de> for DecoderWrapper { + fn deserialize(deserializer: D) -> std::result::Result + where + D: Deserializer<'de>, + { + #[derive(Deserialize)] + pub struct Tagged { + #[serde(rename = "type")] + variant: EnumType, + #[serde(flatten)] + rest: serde_json::Value, + } + #[derive(Serialize, Deserialize)] + pub enum EnumType { + BPEDecoder, + ByteLevel, + WordPiece, + Metaspace, + CTC, + Sequence, + Replace, + Fuse, + Strip, + ByteFallback, + } + + #[derive(Deserialize)] + #[serde(untagged)] + pub enum DecoderHelper { + Tagged(Tagged), + Legacy(serde_json::Value), + } + + #[derive(Deserialize)] + #[serde(untagged)] + pub enum DecoderUntagged { + BPE(BPEDecoder), + ByteLevel(ByteLevel), + WordPiece(WordPiece), + Metaspace(Metaspace), + CTC(CTC), + Sequence(Sequence), + Replace(Replace), + Fuse(Fuse), + Strip(Strip), + ByteFallback(ByteFallback), + } + + let helper = DecoderHelper::deserialize(deserializer).expect("Helper"); + Ok(match helper { + DecoderHelper::Tagged(model) => { + let mut values: serde_json::Map = + serde_json::from_value(model.rest).map_err(serde::de::Error::custom)?; + values.insert( + "type".to_string(), + serde_json::to_value(&model.variant).map_err(serde::de::Error::custom)?, + ); + let values = serde_json::Value::Object(values); + match model.variant { + EnumType::BPEDecoder => DecoderWrapper::BPE( + serde_json::from_value(values).map_err(serde::de::Error::custom)?, + ), + EnumType::ByteLevel => DecoderWrapper::ByteLevel( + serde_json::from_value(values).map_err(serde::de::Error::custom)?, + ), + EnumType::WordPiece => DecoderWrapper::WordPiece( + serde_json::from_value(values).map_err(serde::de::Error::custom)?, + ), + EnumType::Metaspace => DecoderWrapper::Metaspace( + serde_json::from_value(values).map_err(serde::de::Error::custom)?, + ), + EnumType::CTC => DecoderWrapper::CTC( + serde_json::from_value(values).map_err(serde::de::Error::custom)?, + ), + EnumType::Sequence => DecoderWrapper::Sequence( + serde_json::from_value(values).map_err(serde::de::Error::custom)?, + ), + EnumType::Replace => DecoderWrapper::Replace( + serde_json::from_value(values).map_err(serde::de::Error::custom)?, + ), + EnumType::Fuse => DecoderWrapper::Fuse( + serde_json::from_value(values).map_err(serde::de::Error::custom)?, + ), + EnumType::Strip => DecoderWrapper::Strip( + serde_json::from_value(values).map_err(serde::de::Error::custom)?, + ), + EnumType::ByteFallback => DecoderWrapper::ByteFallback( + serde_json::from_value(values).map_err(serde::de::Error::custom)?, + ), + } + } + DecoderHelper::Legacy(value) => { + let untagged = serde_json::from_value(value).map_err(serde::de::Error::custom)?; + match untagged { + DecoderUntagged::BPE(dec) => DecoderWrapper::BPE(dec), + DecoderUntagged::ByteLevel(dec) => DecoderWrapper::ByteLevel(dec), + DecoderUntagged::WordPiece(dec) => DecoderWrapper::WordPiece(dec), + DecoderUntagged::Metaspace(dec) => DecoderWrapper::Metaspace(dec), + DecoderUntagged::CTC(dec) => DecoderWrapper::CTC(dec), + DecoderUntagged::Sequence(dec) => DecoderWrapper::Sequence(dec), + DecoderUntagged::Replace(dec) => DecoderWrapper::Replace(dec), + DecoderUntagged::Fuse(dec) => DecoderWrapper::Fuse(dec), + DecoderUntagged::Strip(dec) => DecoderWrapper::Strip(dec), + DecoderUntagged::ByteFallback(dec) => DecoderWrapper::ByteFallback(dec), + } + } + }) + } +} + impl Decoder for DecoderWrapper { fn decode_chain(&self, tokens: Vec) -> Result> { match self { @@ -98,7 +208,7 @@ mod tests { match parse { Err(err) => assert_eq!( format!("{err}"), - "data did not match any variant of untagged enum DecoderWrapper" + "data did not match any variant of untagged enum DecoderUntagged" ), _ => panic!("Expected error"), } @@ -108,7 +218,7 @@ mod tests { match parse { Err(err) => assert_eq!( format!("{err}"), - "data did not match any variant of untagged enum DecoderWrapper" + "data did not match any variant of untagged enum DecoderUntagged" ), _ => panic!("Expected error"), } @@ -116,10 +226,7 @@ mod tests { let json = r#"{"type":"Sequence","prepend_scheme":"always"}"#; let parse = serde_json::from_str::(json); match parse { - Err(err) => assert_eq!( - format!("{err}"), - "data did not match any variant of untagged enum DecoderWrapper" - ), + Err(err) => assert_eq!(format!("{err}"), "missing field `decoders`"), _ => panic!("Expected error"), } }