diff --git a/bindings/python/src/decoders.rs b/bindings/python/src/decoders.rs index ed21f3469..1a03a7721 100644 --- a/bindings/python/src/decoders.rs +++ b/bindings/python/src/decoders.rs @@ -29,8 +29,8 @@ use super::error::ToPyResult; /// a Decoder will return an instance of this class when instantiated. #[pyclass(dict, module = "tokenizers.decoders", name = "Decoder", subclass)] #[derive(Clone, Deserialize, Serialize)] +#[serde(transparent)] pub struct PyDecoder { - #[serde(flatten)] pub(crate) decoder: PyDecoderWrapper, } @@ -114,6 +114,16 @@ impl PyDecoder { fn decode(&self, tokens: Vec) -> PyResult { ToPyResult(self.decoder.decode(tokens)).into() } + + fn __repr__(&self) -> PyResult { + crate::utils::serde_pyo3::repr(self) + .map_err(|e| exceptions::PyException::new_err(e.to_string())) + } + + fn __str__(&self) -> PyResult { + crate::utils::serde_pyo3::to_string(self) + .map_err(|e| exceptions::PyException::new_err(e.to_string())) + } } macro_rules! getter { diff --git a/bindings/python/src/models.rs b/bindings/python/src/models.rs index bffa1bc21..424be9f57 100644 --- a/bindings/python/src/models.rs +++ b/bindings/python/src/models.rs @@ -26,8 +26,8 @@ use super::error::{deprecation_warning, ToPyResult}; /// This class cannot be constructed directly. Please use one of the concrete models. #[pyclass(module = "tokenizers.models", name = "Model", subclass)] #[derive(Clone, Serialize, Deserialize)] +#[serde(transparent)] pub struct PyModel { - #[serde(flatten)] pub model: Arc>, } @@ -220,6 +220,16 @@ impl PyModel { fn get_trainer(&self, py: Python<'_>) -> PyResult { PyTrainer::from(self.model.read().unwrap().get_trainer()).get_as_subtype(py) } + + fn __repr__(&self) -> PyResult { + crate::utils::serde_pyo3::repr(self) + .map_err(|e| exceptions::PyException::new_err(e.to_string())) + } + + fn __str__(&self) -> PyResult { + crate::utils::serde_pyo3::to_string(self) + .map_err(|e| exceptions::PyException::new_err(e.to_string())) + } } /// An implementation of the BPE (Byte-Pair Encoding) algorithm diff --git a/bindings/python/src/normalizers.rs b/bindings/python/src/normalizers.rs index 6f866515d..ba143c3f8 100644 --- a/bindings/python/src/normalizers.rs +++ b/bindings/python/src/normalizers.rs @@ -42,8 +42,8 @@ impl PyNormalizedStringMut<'_> { /// Normalizer will return an instance of this class when instantiated. #[pyclass(dict, module = "tokenizers.normalizers", name = "Normalizer", subclass)] #[derive(Clone, Serialize, Deserialize)] +#[serde(transparent)] pub struct PyNormalizer { - #[serde(flatten)] pub(crate) normalizer: PyNormalizerTypeWrapper, } @@ -167,6 +167,16 @@ impl PyNormalizer { ToPyResult(self.normalizer.normalize(&mut normalized)).into_py()?; Ok(normalized.get().to_owned()) } + + fn __repr__(&self) -> PyResult { + crate::utils::serde_pyo3::repr(self) + .map_err(|e| exceptions::PyException::new_err(e.to_string())) + } + + fn __str__(&self) -> PyResult { + crate::utils::serde_pyo3::to_string(self) + .map_err(|e| exceptions::PyException::new_err(e.to_string())) + } } macro_rules! 
getter { diff --git a/bindings/python/src/pre_tokenizers.rs b/bindings/python/src/pre_tokenizers.rs index a30b4ca82..02556e59c 100644 --- a/bindings/python/src/pre_tokenizers.rs +++ b/bindings/python/src/pre_tokenizers.rs @@ -35,8 +35,8 @@ use super::utils::*; subclass )] #[derive(Clone, Serialize, Deserialize)] +#[serde(transparent)] pub struct PyPreTokenizer { - #[serde(flatten)] pub(crate) pretok: PyPreTokenizerTypeWrapper, } @@ -181,6 +181,16 @@ impl PyPreTokenizer { .map(|(s, o, _)| (s.to_owned(), o)) .collect()) } + + fn __repr__(&self) -> PyResult { + crate::utils::serde_pyo3::repr(self) + .map_err(|e| exceptions::PyException::new_err(e.to_string())) + } + + fn __str__(&self) -> PyResult { + crate::utils::serde_pyo3::to_string(self) + .map_err(|e| exceptions::PyException::new_err(e.to_string())) + } } macro_rules! getter { diff --git a/bindings/python/src/processors.rs b/bindings/python/src/processors.rs index c46d8ea49..1d8e8dfac 100644 --- a/bindings/python/src/processors.rs +++ b/bindings/python/src/processors.rs @@ -28,8 +28,8 @@ use tokenizers as tk; subclass )] #[derive(Clone, Deserialize, Serialize)] +#[serde(transparent)] pub struct PyPostProcessor { - #[serde(flatten)] pub processor: Arc, } @@ -139,6 +139,16 @@ impl PyPostProcessor { .into_py()?; Ok(final_encoding.into()) } + + fn __repr__(&self) -> PyResult { + crate::utils::serde_pyo3::repr(self) + .map_err(|e| exceptions::PyException::new_err(e.to_string())) + } + + fn __str__(&self) -> PyResult { + crate::utils::serde_pyo3::to_string(self) + .map_err(|e| exceptions::PyException::new_err(e.to_string())) + } } /// This post-processor takes care of adding the special tokens needed by diff --git a/bindings/python/src/tokenizer.rs b/bindings/python/src/tokenizer.rs index 8b3e30617..8b582633d 100644 --- a/bindings/python/src/tokenizer.rs +++ b/bindings/python/src/tokenizer.rs @@ -1,3 +1,4 @@ +use serde::Serialize; use std::collections::{hash_map::DefaultHasher, HashMap}; use std::hash::{Hash, Hasher}; @@ -462,7 +463,8 @@ type Tokenizer = TokenizerImpl PyResult { + crate::utils::serde_pyo3::repr(self) + .map_err(|e| exceptions::PyException::new_err(e.to_string())) + } + + fn __str__(&self) -> PyResult { + crate::utils::serde_pyo3::to_string(self) + .map_err(|e| exceptions::PyException::new_err(e.to_string())) + } + /// Return the number of special tokens that would be added for single/pair sentences. 
 /// :param is_pair: Boolean indicating if the input would be a single sentence or a pair
 /// :return:
@@ -1439,4 +1451,16 @@ mod test {
         Tokenizer::from_file(&tmp).unwrap();
     }
+
+    #[test]
+    fn serde_pyo3() {
+        let mut tokenizer = Tokenizer::new(PyModel::from(BPE::default()));
+        tokenizer.with_normalizer(PyNormalizer::new(PyNormalizerTypeWrapper::Sequence(vec![
+            Arc::new(RwLock::new(NFKC.into())),
+            Arc::new(RwLock::new(Lowercase.into())),
+        ])));
+
+        let output = crate::utils::serde_pyo3::to_string(&tokenizer).unwrap();
+        assert_eq!(output, "Tokenizer(version=\"1.0\", truncation=None, padding=None, added_tokens=[], normalizer=Sequence(normalizers=[NFKC(), Lowercase()]), pre_tokenizer=None, post_processor=None, decoder=None, model=BPE(dropout=None, unk_token=None, continuing_subword_prefix=None, end_of_word_suffix=None, fuse_unk=False, byte_fallback=False, ignore_merges=False, vocab={}, merges=[]))");
+    }
 }
diff --git a/bindings/python/src/trainers.rs b/bindings/python/src/trainers.rs
index 716e4cfeb..c71442298 100644
--- a/bindings/python/src/trainers.rs
+++ b/bindings/python/src/trainers.rs
@@ -16,8 +16,8 @@ use tokenizers as tk;
 /// Trainer will return an instance of this class when instantiated.
 #[pyclass(module = "tokenizers.trainers", name = "Trainer", subclass)]
 #[derive(Clone, Deserialize, Serialize)]
+#[serde(transparent)]
 pub struct PyTrainer {
-    #[serde(flatten)]
     pub trainer: Arc<RwLock<TrainerWrapper>>,
 }
 
@@ -69,6 +69,16 @@ impl PyTrainer {
             Err(e) => Err(e),
         }
     }
+
+    fn __repr__(&self) -> PyResult<String> {
+        crate::utils::serde_pyo3::repr(self)
+            .map_err(|e| exceptions::PyException::new_err(e.to_string()))
+    }
+
+    fn __str__(&self) -> PyResult<String> {
+        crate::utils::serde_pyo3::to_string(self)
+            .map_err(|e| exceptions::PyException::new_err(e.to_string()))
+    }
 }
 
 impl Trainer for PyTrainer {
diff --git a/bindings/python/src/utils/mod.rs b/bindings/python/src/utils/mod.rs
index 1e409a504..43352a7fa 100644
--- a/bindings/python/src/utils/mod.rs
+++ b/bindings/python/src/utils/mod.rs
@@ -5,6 +5,7 @@ mod iterators;
 mod normalization;
 mod pretokenization;
 mod regex;
+pub mod serde_pyo3;
 
 pub use iterators::*;
 pub use normalization::*;
diff --git a/bindings/python/src/utils/serde_pyo3.rs b/bindings/python/src/utils/serde_pyo3.rs
new file mode 100644
index 000000000..471993614
--- /dev/null
+++ b/bindings/python/src/utils/serde_pyo3.rs
@@ -0,0 +1,773 @@
+use serde::de::value::Error;
+use serde::{ser, Serialize};
+type Result<T> = ::std::result::Result<T, Error>;
+
+pub struct Serializer {
+    // This string starts empty and output is appended as values are serialized.
+    output: String,
+    /// Each level remembers its own number of elements
+    num_elements: Vec<usize>,
+    max_elements: usize,
+    level: usize,
+    max_depth: usize,
+    /// Maximum length of a string representation;
+    /// useful to ellipsize `precompiled_charmap`.
+    max_string: usize,
+}
+
+// By convention, the public API of a Serde serializer is one or more `to_abc`
+// functions such as `to_string`, `to_bytes`, or `to_writer` depending on what
+// Rust types the serializer is able to produce as output.
+//
+// This basic serializer supports only `to_string`.
+pub fn to_string(value: &T) -> Result +where + T: Serialize, +{ + let max_depth = 20; + let max_elements = 6; + let max_string = 100; + let mut serializer = Serializer { + output: String::new(), + level: 0, + max_depth, + max_elements, + num_elements: vec![0; max_depth], + max_string, + }; + value.serialize(&mut serializer)?; + Ok(serializer.output) +} + +pub fn repr(value: &T) -> Result +where + T: Serialize, +{ + let max_depth = 200; + let max_string = usize::MAX; + let mut serializer = Serializer { + output: String::new(), + level: 0, + max_depth, + max_elements: 100, + num_elements: vec![0; max_depth], + max_string, + }; + value.serialize(&mut serializer)?; + Ok(serializer.output) +} + +impl<'a> ser::Serializer for &'a mut Serializer { + // The output type produced by this `Serializer` during successful + // serialization. Most serializers that produce text or binary output should + // set `Ok = ()` and serialize into an `io::Write` or buffer contained + // within the `Serializer` instance, as happens here. Serializers that build + // in-memory data structures may be simplified by using `Ok` to propagate + // the data structure around. + type Ok = (); + + // The error type when some error occurs during serialization. + type Error = Error; + + // Associated types for keeping track of additional state while serializing + // compound data structures like sequences and maps. In this case no + // additional state is required beyond what is already stored in the + // Serializer struct. + type SerializeSeq = Self; + type SerializeTuple = Self; + type SerializeTupleStruct = Self; + type SerializeTupleVariant = Self; + type SerializeMap = Self; + type SerializeStruct = Self; + type SerializeStructVariant = Self; + + // Here we go with the simple methods. The following 12 methods receive one + // of the primitive types of the data model and map it to JSON by appending + // into the output string. + fn serialize_bool(self, v: bool) -> Result<()> { + self.output += if v { "True" } else { "False" }; + Ok(()) + } + + // JSON does not distinguish between different sizes of integers, so all + // signed integers will be serialized the same and all unsigned integers + // will be serialized the same. Other formats, especially compact binary + // formats, may need independent logic for the different sizes. + fn serialize_i8(self, v: i8) -> Result<()> { + self.serialize_i64(i64::from(v)) + } + + fn serialize_i16(self, v: i16) -> Result<()> { + self.serialize_i64(i64::from(v)) + } + + fn serialize_i32(self, v: i32) -> Result<()> { + self.serialize_i64(i64::from(v)) + } + + // Not particularly efficient but this is example code anyway. A more + // performant approach would be to use the `itoa` crate. + fn serialize_i64(self, v: i64) -> Result<()> { + self.output += &v.to_string(); + Ok(()) + } + + fn serialize_u8(self, v: u8) -> Result<()> { + self.serialize_u64(u64::from(v)) + } + + fn serialize_u16(self, v: u16) -> Result<()> { + self.serialize_u64(u64::from(v)) + } + + fn serialize_u32(self, v: u32) -> Result<()> { + self.serialize_u64(u64::from(v)) + } + + fn serialize_u64(self, v: u64) -> Result<()> { + self.output += &v.to_string(); + Ok(()) + } + + fn serialize_f32(self, v: f32) -> Result<()> { + self.serialize_f64(f64::from(v)) + } + + fn serialize_f64(self, v: f64) -> Result<()> { + self.output += &v.to_string(); + Ok(()) + } + + // Serialize a char as a single-character string. Other formats may + // represent this differently. 
+ fn serialize_char(self, v: char) -> Result<()> { + self.serialize_str(&v.to_string()) + } + + // This only works for strings that don't require escape sequences but you + // get the idea. For example it would emit invalid JSON if the input string + // contains a '"' character. + fn serialize_str(self, v: &str) -> Result<()> { + self.output += "\""; + if v.len() > self.max_string { + self.output += &v[..self.max_string]; + self.output += "..."; + } else { + self.output += v; + } + self.output += "\""; + Ok(()) + } + + // Serialize a byte array as an array of bytes. Could also use a base64 + // string here. Binary formats will typically represent byte arrays more + // compactly. + fn serialize_bytes(self, v: &[u8]) -> Result<()> { + use serde::ser::SerializeSeq; + let mut seq = self.serialize_seq(Some(v.len()))?; + for byte in v { + seq.serialize_element(byte)?; + } + seq.end() + } + + // An absent optional is represented as the JSON `null`. + fn serialize_none(self) -> Result<()> { + self.serialize_unit() + } + + // A present optional is represented as just the contained value. Note that + // this is a lossy representation. For example the values `Some(())` and + // `None` both serialize as just `null`. Unfortunately this is typically + // what people expect when working with JSON. Other formats are encouraged + // to behave more intelligently if possible. + fn serialize_some(self, value: &T) -> Result<()> + where + T: ?Sized + Serialize, + { + value.serialize(self) + } + + // In Serde, unit means an anonymous value containing no data. Map this to + // JSON as `null`. + fn serialize_unit(self) -> Result<()> { + self.output += "None"; + Ok(()) + } + + // Unit struct means a named value containing no data. Again, since there is + // no data, map this to JSON as `null`. There is no need to serialize the + // name in most formats. + fn serialize_unit_struct(self, _name: &'static str) -> Result<()> { + self.serialize_unit() + } + + // When serializing a unit variant (or any other kind of variant), formats + // can choose whether to keep track of it by index or by name. Binary + // formats typically use the index of the variant and human-readable formats + // typically use the name. + fn serialize_unit_variant( + self, + _name: &'static str, + _variant_index: u32, + variant: &'static str, + ) -> Result<()> { + // self.serialize_str(variant) + self.output += variant; + Ok(()) + } + + // As is done here, serializers are encouraged to treat newtype structs as + // insignificant wrappers around the data they contain. + fn serialize_newtype_struct(self, _name: &'static str, value: &T) -> Result<()> + where + T: ?Sized + Serialize, + { + value.serialize(self) + } + + // Note that newtype variant (and all of the other variant serialization + // methods) refer exclusively to the "externally tagged" enum + // representation. + // + // Serialize this to JSON in externally tagged form as `{ NAME: VALUE }`. + fn serialize_newtype_variant( + self, + _name: &'static str, + _variant_index: u32, + variant: &'static str, + value: &T, + ) -> Result<()> + where + T: ?Sized + Serialize, + { + // variant.serialize(&mut *self)?; + self.output += variant; + self.output += "("; + value.serialize(&mut *self)?; + self.output += ")"; + Ok(()) + } + + // Now we get to the serialization of compound types. + // + // The start of the sequence, each value, and the end are three separate + // method calls. This one is responsible only for serializing the start, + // which in JSON is `[`. 
+ // + // The length of the sequence may or may not be known ahead of time. This + // doesn't make a difference in JSON because the length is not represented + // explicitly in the serialized form. Some serializers may only be able to + // support sequences for which the length is known up front. + fn serialize_seq(self, _len: Option) -> Result { + self.output += "["; + self.level = std::cmp::min(self.max_depth - 1, self.level + 1); + self.num_elements[self.level] = 0; + Ok(self) + } + + // Tuples look just like sequences in JSON. Some formats may be able to + // represent tuples more efficiently by omitting the length, since tuple + // means that the corresponding `Deserialize implementation will know the + // length without needing to look at the serialized data. + fn serialize_tuple(self, _len: usize) -> Result { + self.output += "("; + self.level = std::cmp::min(self.max_depth - 1, self.level + 1); + self.num_elements[self.level] = 0; + Ok(self) + } + + // Tuple structs look just like sequences in JSON. + fn serialize_tuple_struct( + self, + _name: &'static str, + len: usize, + ) -> Result { + self.serialize_tuple(len) + } + + // Tuple variants are represented in JSON as `{ NAME: [DATA...] }`. Again + // this method is only responsible for the externally tagged representation. + fn serialize_tuple_variant( + self, + _name: &'static str, + _variant_index: u32, + variant: &'static str, + _len: usize, + ) -> Result { + // variant.serialize(&mut *self)?; + self.output += variant; + self.output += "("; + self.level = std::cmp::min(self.max_depth - 1, self.level + 1); + self.num_elements[self.level] = 0; + Ok(self) + } + + // Maps are represented in JSON as `{ K: V, K: V, ... }`. + fn serialize_map(self, _len: Option) -> Result { + self.output += "{"; + self.level = std::cmp::min(self.max_depth - 1, self.level + 1); + self.num_elements[self.level] = 0; + Ok(self) + } + + // Structs look just like maps in JSON. In particular, JSON requires that we + // serialize the field names of the struct. Other formats may be able to + // omit the field names when serializing structs because the corresponding + // Deserialize implementation is required to know what the keys are without + // looking at the serialized data. + fn serialize_struct(self, name: &'static str, _len: usize) -> Result { + // self.serialize_map(Some(len)) + // name.serialize(&mut *self)?; + if let Some(stripped) = name.strip_suffix("Helper") { + self.output += stripped; + } else { + self.output += name + } + self.output += "("; + self.level = std::cmp::min(self.max_depth - 1, self.level + 1); + self.num_elements[self.level] = 0; + Ok(self) + } + + // Struct variants are represented in JSON as `{ NAME: { K: V, ... } }`. + // This is the externally tagged representation. + fn serialize_struct_variant( + self, + _name: &'static str, + _variant_index: u32, + variant: &'static str, + _len: usize, + ) -> Result { + // variant.serialize(&mut *self)?; + self.output += variant; + self.output += "("; + self.level = std::cmp::min(self.max_depth - 1, self.level + 1); + self.num_elements[self.level] = 0; + Ok(self) + } +} + +// The following 7 impls deal with the serialization of compound types like +// sequences and maps. Serialization of such types is begun by a Serializer +// method and followed by zero or more calls to serialize individual elements of +// the compound type and one call to end the compound type. +// +// This impl is SerializeSeq so these methods are called after `serialize_seq` +// is called on the Serializer. 
+impl<'a> ser::SerializeSeq for &'a mut Serializer { + // Must match the `Ok` type of the serializer. + type Ok = (); + // Must match the `Error` type of the serializer. + type Error = Error; + + // Serialize a single element of the sequence. + fn serialize_element(&mut self, value: &T) -> Result<()> + where + T: ?Sized + Serialize, + { + self.num_elements[self.level] += 1; + let num_elements = self.num_elements[self.level]; + if num_elements < self.max_elements { + if !self.output.ends_with('[') { + self.output += ", "; + } + value.serialize(&mut **self) + } else { + if num_elements == self.max_elements { + self.output += ", ..."; + } + Ok(()) + } + } + + // Close the sequence. + fn end(self) -> Result<()> { + self.num_elements[self.level] = 0; + self.level = self.level.saturating_sub(1); + self.output += "]"; + Ok(()) + } +} + +// Same thing but for tuples. +impl<'a> ser::SerializeTuple for &'a mut Serializer { + type Ok = (); + type Error = Error; + + fn serialize_element(&mut self, value: &T) -> Result<()> + where + T: ?Sized + Serialize, + { + self.num_elements[self.level] += 1; + let num_elements = self.num_elements[self.level]; + if num_elements < self.max_elements { + if !self.output.ends_with('(') { + self.output += ", "; + } + value.serialize(&mut **self) + } else { + if num_elements == self.max_elements { + self.output += ", ..."; + } + Ok(()) + } + } + + fn end(self) -> Result<()> { + self.num_elements[self.level] = 0; + self.level = self.level.saturating_sub(1); + self.output += ")"; + Ok(()) + } +} + +// Same thing but for tuple structs. +impl<'a> ser::SerializeTupleStruct for &'a mut Serializer { + type Ok = (); + type Error = Error; + + fn serialize_field(&mut self, value: &T) -> Result<()> + where + T: ?Sized + Serialize, + { + self.num_elements[self.level] += 1; + let num_elements = self.num_elements[self.level]; + if num_elements < self.max_elements { + if !self.output.ends_with('(') { + self.output += ", "; + } + value.serialize(&mut **self) + } else { + if num_elements == self.max_elements { + self.output += ", ..."; + } + Ok(()) + } + } + + fn end(self) -> Result<()> { + self.num_elements[self.level] = 0; + self.level = self.level.saturating_sub(1); + self.output += ")"; + Ok(()) + } +} + +// Tuple variants are a little different. Refer back to the +// `serialize_tuple_variant` method above: +// +// self.output += "{"; +// variant.serialize(&mut *self)?; +// self.output += ":["; +// +// So the `end` method in this impl is responsible for closing both the `]` and +// the `}`. +impl<'a> ser::SerializeTupleVariant for &'a mut Serializer { + type Ok = (); + type Error = Error; + + fn serialize_field(&mut self, value: &T) -> Result<()> + where + T: ?Sized + Serialize, + { + self.num_elements[self.level] += 1; + let num_elements = self.num_elements[self.level]; + if num_elements < self.max_elements { + if !self.output.ends_with('(') { + self.output += ", "; + } + value.serialize(&mut **self) + } else { + if num_elements == self.max_elements { + self.output += ", ..."; + } + Ok(()) + } + } + + fn end(self) -> Result<()> { + self.num_elements[self.level] = 0; + self.level = self.level.saturating_sub(1); + self.output += ")"; + Ok(()) + } +} + +// Some `Serialize` types are not able to hold a key and value in memory at the +// same time so `SerializeMap` implementations are required to support +// `serialize_key` and `serialize_value` individually. +// +// There is a third optional method on the `SerializeMap` trait. 
The +// `serialize_entry` method allows serializers to optimize for the case where +// key and value are both available simultaneously. In JSON it doesn't make a +// difference so the default behavior for `serialize_entry` is fine. +impl<'a> ser::SerializeMap for &'a mut Serializer { + type Ok = (); + type Error = Error; + + // The Serde data model allows map keys to be any serializable type. JSON + // only allows string keys so the implementation below will produce invalid + // JSON if the key serializes as something other than a string. + // + // A real JSON serializer would need to validate that map keys are strings. + // This can be done by using a different Serializer to serialize the key + // (instead of `&mut **self`) and having that other serializer only + // implement `serialize_str` and return an error on any other data type. + fn serialize_key(&mut self, key: &T) -> Result<()> + where + T: ?Sized + Serialize, + { + self.num_elements[self.level] += 1; + let num_elements = self.num_elements[self.level]; + if num_elements < self.max_elements { + if !self.output.ends_with('{') { + self.output += ", "; + } + key.serialize(&mut **self) + } else { + if num_elements == self.max_elements { + self.output += ", ..."; + } + Ok(()) + } + } + + // It doesn't make a difference whether the colon is printed at the end of + // `serialize_key` or at the beginning of `serialize_value`. In this case + // the code is a bit simpler having it here. + fn serialize_value(&mut self, value: &T) -> Result<()> + where + T: ?Sized + Serialize, + { + let num_elements = self.num_elements[self.level]; + if num_elements < self.max_elements { + self.output += ":"; + value.serialize(&mut **self) + } else { + Ok(()) + } + } + + fn end(self) -> Result<()> { + self.num_elements[self.level] = 0; + self.level = self.level.saturating_sub(1); + self.output += "}"; + Ok(()) + } +} + +// Structs are like maps in which the keys are constrained to be compile-time +// constant strings. +impl<'a> ser::SerializeStruct for &'a mut Serializer { + type Ok = (); + type Error = Error; + + fn serialize_field(&mut self, key: &'static str, value: &T) -> Result<()> + where + T: ?Sized + Serialize, + { + if !self.output.ends_with('(') { + self.output += ", "; + } + // key.serialize(&mut **self)?; + if key != "type" { + self.output += key; + self.output += "="; + value.serialize(&mut **self) + } else { + Ok(()) + } + } + + fn end(self) -> Result<()> { + self.num_elements[self.level] = 0; + self.level = self.level.saturating_sub(1); + self.output += ")"; + Ok(()) + } +} + +// Similar to `SerializeTupleVariant`, here the `end` method is responsible for +// closing both of the curly braces opened by `serialize_struct_variant`. 
+impl<'a> ser::SerializeStructVariant for &'a mut Serializer {
+    type Ok = ();
+    type Error = Error;
+
+    fn serialize_field<T>(&mut self, key: &'static str, value: &T) -> Result<()>
+    where
+        T: ?Sized + Serialize,
+    {
+        if !self.output.ends_with('(') {
+            self.output += ", ";
+        }
+        // key.serialize(&mut **self)?;
+        self.output += key;
+        self.output += "=";
+        value.serialize(&mut **self)
+    }
+
+    fn end(self) -> Result<()> {
+        self.num_elements[self.level] = 0;
+        self.level = self.level.saturating_sub(1);
+        self.output += ")";
+        Ok(())
+    }
+}
+
+////////////////////////////////////////////////////////////////////////////////
+
+#[test]
+fn test_basic() {
+    assert_eq!(to_string(&true).unwrap(), "True");
+    assert_eq!(to_string(&Some(1)).unwrap(), "1");
+    assert_eq!(to_string(&None::<usize>).unwrap(), "None");
+}
+
+#[test]
+fn test_struct() {
+    #[derive(Serialize)]
+    struct Test {
+        int: u32,
+        seq: Vec<&'static str>,
+    }
+
+    let test = Test {
+        int: 1,
+        seq: vec!["a", "b"],
+    };
+    let expected = r#"Test(int=1, seq=["a", "b"])"#;
+    assert_eq!(to_string(&test).unwrap(), expected);
+}
+
+#[test]
+fn test_enum() {
+    #[derive(Serialize)]
+    enum E {
+        Unit,
+        Newtype(u32),
+        Tuple(u32, u32),
+        Struct { a: u32 },
+    }
+
+    let u = E::Unit;
+    let expected = r#"Unit"#;
+    assert_eq!(to_string(&u).unwrap(), expected);
+
+    let n = E::Newtype(1);
+    let expected = r#"Newtype(1)"#;
+    assert_eq!(to_string(&n).unwrap(), expected);
+
+    let t = E::Tuple(1, 2);
+    let expected = r#"Tuple(1, 2)"#;
+    assert_eq!(to_string(&t).unwrap(), expected);
+
+    let s = E::Struct { a: 1 };
+    let expected = r#"Struct(a=1)"#;
+    assert_eq!(to_string(&s).unwrap(), expected);
+}
+
+#[test]
+fn test_enum_untagged() {
+    #[derive(Serialize)]
+    #[serde(untagged)]
+    enum E {
+        Unit,
+        Newtype(u32),
+        Tuple(u32, u32),
+        Struct { a: u32 },
+    }
+
+    let u = E::Unit;
+    let expected = r#"None"#;
+    assert_eq!(to_string(&u).unwrap(), expected);
+
+    let n = E::Newtype(1);
+    let expected = r#"1"#;
+    assert_eq!(to_string(&n).unwrap(), expected);
+
+    let t = E::Tuple(1, 2);
+    let expected = r#"(1, 2)"#;
+    assert_eq!(to_string(&t).unwrap(), expected);
+
+    let s = E::Struct { a: 1 };
+    let expected = r#"E(a=1)"#;
+    assert_eq!(to_string(&s).unwrap(), expected);
+}
+
+#[test]
+fn test_struct_tagged() {
+    #[derive(Serialize)]
+    #[serde(untagged)]
+    enum E {
+        A(A),
+    }
+
+    #[derive(Serialize)]
+    #[serde(tag = "type")]
+    struct A {
+        a: bool,
+        b: usize,
+    }
+
+    let u = A { a: true, b: 1 };
+    // let expected = r#"A(type="A", a=True, b=1)"#;
+    // No: we skip the manually inserted `type` tag fields.
+ let expected = r#"A(a=True, b=1)"#; + assert_eq!(to_string(&u).unwrap(), expected); + + let u = E::A(A { a: true, b: 1 }); + let expected = r#"A(a=True, b=1)"#; + assert_eq!(to_string(&u).unwrap(), expected); +} + +#[test] +fn test_flatten() { + #[derive(Serialize)] + struct A { + a: bool, + b: usize, + } + + #[derive(Serialize)] + struct B { + c: A, + d: usize, + } + + #[derive(Serialize)] + struct C { + #[serde(flatten)] + c: A, + d: usize, + } + + #[derive(Serialize)] + #[serde(transparent)] + struct D { + e: A, + } + + let u = B { + c: A { a: true, b: 1 }, + d: 2, + }; + let expected = r#"B(c=A(a=True, b=1), d=2)"#; + assert_eq!(to_string(&u).unwrap(), expected); + + let u = C { + c: A { a: true, b: 1 }, + d: 2, + }; + // XXX This is unfortunate but true, flatten forces the serialization + // to use the serialize_map without any means for the Serializer to know about this + // flattening attempt + let expected = r#"{"a":True, "b":1, "d":2}"#; + assert_eq!(to_string(&u).unwrap(), expected); + + let u = D { + e: A { a: true, b: 1 }, + }; + let expected = r#"A(a=True, b=1)"#; + assert_eq!(to_string(&u).unwrap(), expected); +} diff --git a/bindings/python/tests/bindings/test_tokenizer.py b/bindings/python/tests/bindings/test_tokenizer.py index fb8958576..b89d71335 100644 --- a/bindings/python/tests/bindings/test_tokenizer.py +++ b/bindings/python/tests/bindings/test_tokenizer.py @@ -7,8 +7,9 @@ from tokenizers.implementations import BertWordPieceTokenizer from tokenizers.models import BPE, Model, Unigram from tokenizers.pre_tokenizers import ByteLevel, Metaspace -from tokenizers.normalizers import Strip -from tokenizers.processors import RobertaProcessing +from tokenizers.processors import RobertaProcessing, TemplateProcessing +from tokenizers.normalizers import Strip, Lowercase, Sequence + from ..utils import bert_files, data_dir, multiprocessing_with_parallelism, roberta_files @@ -560,3 +561,28 @@ def test_setting_to_none(self): tokenizer.pre_tokenizer = Metaspace() tokenizer.pre_tokenizer = None assert tokenizer.pre_tokenizer == None + +class TestTokenizerRepr: + def test_repr(self): + tokenizer = Tokenizer(BPE()) + out = repr(tokenizer) + assert ( + out + == 'Tokenizer(version="1.0", truncation=None, padding=None, added_tokens=[], normalizer=None, pre_tokenizer=None, post_processor=None, decoder=None, model=BPE(dropout=None, unk_token=None, continuing_subword_prefix=None, end_of_word_suffix=None, fuse_unk=False, byte_fallback=False, ignore_merges=False, vocab={}, merges=[]))' + ) + + def test_repr_complete(self): + tokenizer = Tokenizer(BPE()) + tokenizer.pre_tokenizer = ByteLevel(add_prefix_space=True) + tokenizer.post_processor = TemplateProcessing( + single=["[CLS]", "$0", "[SEP]"], + pair=["[CLS]:0", "$A", "[SEP]:0", "$B:1", "[SEP]:1"], + special_tokens=[("[CLS]", 1), ("[SEP]", 0)], + ) + tokenizer.normalizer = Sequence([Lowercase(), Strip()]) + out = repr(tokenizer) + assert ( + out + == 'Tokenizer(version="1.0", truncation=None, padding=None, added_tokens=[], normalizer=Sequence(normalizers=[Lowercase(), Strip(strip_left=True, strip_right=True)]), pre_tokenizer=ByteLevel(add_prefix_space=True, trim_offsets=True, use_regex=True), post_processor=TemplateProcessing(single=[SpecialToken(id="[CLS]", type_id=0), Sequence(id=A, type_id=0), SpecialToken(id="[SEP]", type_id=0)], pair=[SpecialToken(id="[CLS]", type_id=0), Sequence(id=A, type_id=0), SpecialToken(id="[SEP]", type_id=0), Sequence(id=B, type_id=1), SpecialToken(id="[SEP]", type_id=1)], 
special_tokens={"[CLS]":SpecialToken(id="[CLS]", ids=[1], tokens=["[CLS]"]), "[SEP]":SpecialToken(id="[SEP]", ids=[0], tokens=["[SEP]"])}), decoder=None, model=BPE(dropout=None, unk_token=None, continuing_subword_prefix=None, end_of_word_suffix=None, fuse_unk=False, byte_fallback=False, ignore_merges=False, vocab={}, merges=[]))' + ) + diff --git a/bindings/python/tests/test_serialization.py b/bindings/python/tests/test_serialization.py index 4434e6304..9da2c3e27 100644 --- a/bindings/python/tests/test_serialization.py +++ b/bindings/python/tests/test_serialization.py @@ -5,6 +5,7 @@ import tqdm from huggingface_hub import hf_hub_download from tokenizers import Tokenizer +from tokenizers.models import BPE, Unigram from .utils import albert_base, data_dir @@ -16,6 +17,73 @@ def test_full_serialization_albert(self, albert_base): # file exceeds the buffer capacity Tokenizer.from_file(albert_base) + def test_str_big(self, albert_base): + tokenizer = Tokenizer.from_file(albert_base) + assert ( + str(tokenizer) + == """Tokenizer(version="1.0", truncation=None, padding=None, added_tokens=[{"id":0, "content":"", "single_word":False, "lstrip":False, "rstrip":False, ...}, {"id":1, "content":"", "single_word":False, "lstrip":False, "rstrip":False, ...}, {"id":2, "content":"[CLS]", "single_word":False, "lstrip":False, "rstrip":False, ...}, {"id":3, "content":"[SEP]", "single_word":False, "lstrip":False, "rstrip":False, ...}, {"id":4, "content":"[MASK]", "single_word":False, "lstrip":False, "rstrip":False, ...}], normalizer=Sequence(normalizers=[Replace(pattern=String("``"), content="\""), Replace(pattern=String("''"), content="\""), NFKD(), StripAccents(), Lowercase(), ...]), pre_tokenizer=Sequence(pretokenizers=[WhitespaceSplit(), Metaspace(replacement="▁", prepend_scheme=always, split=True)]), post_processor=TemplateProcessing(single=[SpecialToken(id="[CLS]", type_id=0), Sequence(id=A, type_id=0), SpecialToken(id="[SEP]", type_id=0)], pair=[SpecialToken(id="[CLS]", type_id=0), Sequence(id=A, type_id=0), SpecialToken(id="[SEP]", type_id=0), Sequence(id=B, type_id=1), SpecialToken(id="[SEP]", type_id=1)], special_tokens={"[CLS]":SpecialToken(id="[CLS]", ids=[2], tokens=["[CLS]"]), "[SEP]":SpecialToken(id="[SEP]", ids=[3], tokens=["[SEP]"])}), decoder=Metaspace(replacement="▁", prepend_scheme=always, split=True), model=Unigram(unk_id=1, vocab=[("", 0), ("", 0), ("[CLS]", 0), ("[SEP]", 0), ("[MASK]", 0), ...], byte_fallback=False))""" + ) + + def test_repr_str(self): + tokenizer = Tokenizer(BPE()) + tokenizer.add_tokens(["my"]) + assert ( + repr(tokenizer) + == """Tokenizer(version="1.0", truncation=None, padding=None, added_tokens=[{"id":0, "content":"my", "single_word":False, "lstrip":False, "rstrip":False, "normalized":True, "special":False}], normalizer=None, pre_tokenizer=None, post_processor=None, decoder=None, model=BPE(dropout=None, unk_token=None, continuing_subword_prefix=None, end_of_word_suffix=None, fuse_unk=False, byte_fallback=False, ignore_merges=False, vocab={}, merges=[]))""" + ) + assert ( + str(tokenizer) + == """Tokenizer(version="1.0", truncation=None, padding=None, added_tokens=[{"id":0, "content":"my", "single_word":False, "lstrip":False, "rstrip":False, ...}], normalizer=None, pre_tokenizer=None, post_processor=None, decoder=None, model=BPE(dropout=None, unk_token=None, continuing_subword_prefix=None, end_of_word_suffix=None, fuse_unk=False, byte_fallback=False, ignore_merges=False, vocab={}, merges=[]))""" + ) + + def test_repr_str_ellipsis(self): + model = BPE() + assert ( + 
+            repr(model)
+            == """BPE(dropout=None, unk_token=None, continuing_subword_prefix=None, end_of_word_suffix=None, fuse_unk=False, byte_fallback=False, ignore_merges=False, vocab={}, merges=[])"""
+        )
+        assert (
+            str(model)
+            == """BPE(dropout=None, unk_token=None, continuing_subword_prefix=None, end_of_word_suffix=None, fuse_unk=False, byte_fallback=False, ignore_merges=False, vocab={}, merges=[])"""
+        )
+
+        vocab = [
+            ("A", 0.0),
+            ("B", -0.01),
+            ("C", -0.02),
+            ("D", -0.03),
+            ("E", -0.04),
+        ]
+        # No ellipsis yet
+        model = Unigram(vocab, 0, byte_fallback=False)
+        assert (
+            repr(model)
+            == """Unigram(unk_id=0, vocab=[("A", 0), ("B", -0.01), ("C", -0.02), ("D", -0.03), ("E", -0.04)], byte_fallback=False)"""
+        )
+        assert (
+            str(model)
+            == """Unigram(unk_id=0, vocab=[("A", 0), ("B", -0.01), ("C", -0.02), ("D", -0.03), ("E", -0.04)], byte_fallback=False)"""
+        )
+
+        # Ellipsis for longer than 5 elements, only on `str`.
+        vocab = [
+            ("A", 0.0),
+            ("B", -0.01),
+            ("C", -0.02),
+            ("D", -0.03),
+            ("E", -0.04),
+            ("F", -0.04),
+        ]
+        model = Unigram(vocab, 0, byte_fallback=False)
+        assert (
+            repr(model)
+            == """Unigram(unk_id=0, vocab=[("A", 0), ("B", -0.01), ("C", -0.02), ("D", -0.03), ("E", -0.04), ("F", -0.04)], byte_fallback=False)"""
+        )
+        assert (
+            str(model)
+            == """Unigram(unk_id=0, vocab=[("A", 0), ("B", -0.01), ("C", -0.02), ("D", -0.03), ("E", -0.04), ...], byte_fallback=False)"""
+        )
+
 
 def check(tokenizer_file) -> bool:
     with open(tokenizer_file, "r") as f:
diff --git a/tokenizers/src/decoders/mod.rs b/tokenizers/src/decoders/mod.rs
index 5f0968fcb..6e79e7029 100644
--- a/tokenizers/src/decoders/mod.rs
+++ b/tokenizers/src/decoders/mod.rs
@@ -10,7 +10,7 @@ pub mod wordpiece;
 pub use super::pre_tokenizers::byte_level;
 pub use super::pre_tokenizers::metaspace;
 
-use serde::{Deserialize, Serialize};
+use serde::{Deserialize, Deserializer, Serialize};
 
 use crate::decoders::bpe::BPEDecoder;
 use crate::decoders::byte_fallback::ByteFallback;
@@ -24,7 +24,7 @@ use crate::pre_tokenizers::byte_level::ByteLevel;
 use crate::pre_tokenizers::metaspace::Metaspace;
 use crate::{Decoder, Result};
 
-#[derive(Serialize, Deserialize, Clone, Debug)]
+#[derive(Serialize, Clone, Debug)]
 #[serde(untagged)]
 pub enum DecoderWrapper {
     BPE(BPEDecoder),
@@ -39,6 +39,116 @@ pub enum DecoderWrapper {
     ByteFallback(ByteFallback),
 }
 
+impl<'de> Deserialize<'de> for DecoderWrapper {
+    fn deserialize<D>(deserializer: D) -> std::result::Result<Self, D::Error>
+    where
+        D: Deserializer<'de>,
+    {
+        #[derive(Deserialize)]
+        pub struct Tagged {
+            #[serde(rename = "type")]
+            variant: EnumType,
+            #[serde(flatten)]
+            rest: serde_json::Value,
+        }
+        #[derive(Serialize, Deserialize)]
+        pub enum EnumType {
+            BPEDecoder,
+            ByteLevel,
+            WordPiece,
+            Metaspace,
+            CTC,
+            Sequence,
+            Replace,
+            Fuse,
+            Strip,
+            ByteFallback,
+        }
+
+        #[derive(Deserialize)]
+        #[serde(untagged)]
+        pub enum DecoderHelper {
+            Tagged(Tagged),
+            Legacy(serde_json::Value),
+        }
+
+        #[derive(Deserialize)]
+        #[serde(untagged)]
+        pub enum DecoderUntagged {
+            BPE(BPEDecoder),
+            ByteLevel(ByteLevel),
+            WordPiece(WordPiece),
+            Metaspace(Metaspace),
+            CTC(CTC),
+            Sequence(Sequence),
+            Replace(Replace),
+            Fuse(Fuse),
+            Strip(Strip),
+            ByteFallback(ByteFallback),
+        }
+
+        let helper = DecoderHelper::deserialize(deserializer).expect("Helper");
+        Ok(match helper {
+            DecoderHelper::Tagged(model) => {
+                let mut values: serde_json::Map<String, serde_json::Value> =
+                    serde_json::from_value(model.rest).map_err(serde::de::Error::custom)?;
+                values.insert(
+                    "type".to_string(),
serde_json::to_value(&model.variant).map_err(serde::de::Error::custom)?, + ); + let values = serde_json::Value::Object(values); + match model.variant { + EnumType::BPEDecoder => DecoderWrapper::BPE( + serde_json::from_value(values).map_err(serde::de::Error::custom)?, + ), + EnumType::ByteLevel => DecoderWrapper::ByteLevel( + serde_json::from_value(values).map_err(serde::de::Error::custom)?, + ), + EnumType::WordPiece => DecoderWrapper::WordPiece( + serde_json::from_value(values).map_err(serde::de::Error::custom)?, + ), + EnumType::Metaspace => DecoderWrapper::Metaspace( + serde_json::from_value(values).map_err(serde::de::Error::custom)?, + ), + EnumType::CTC => DecoderWrapper::CTC( + serde_json::from_value(values).map_err(serde::de::Error::custom)?, + ), + EnumType::Sequence => DecoderWrapper::Sequence( + serde_json::from_value(values).map_err(serde::de::Error::custom)?, + ), + EnumType::Replace => DecoderWrapper::Replace( + serde_json::from_value(values).map_err(serde::de::Error::custom)?, + ), + EnumType::Fuse => DecoderWrapper::Fuse( + serde_json::from_value(values).map_err(serde::de::Error::custom)?, + ), + EnumType::Strip => DecoderWrapper::Strip( + serde_json::from_value(values).map_err(serde::de::Error::custom)?, + ), + EnumType::ByteFallback => DecoderWrapper::ByteFallback( + serde_json::from_value(values).map_err(serde::de::Error::custom)?, + ), + } + } + DecoderHelper::Legacy(value) => { + let untagged = serde_json::from_value(value).map_err(serde::de::Error::custom)?; + match untagged { + DecoderUntagged::BPE(dec) => DecoderWrapper::BPE(dec), + DecoderUntagged::ByteLevel(dec) => DecoderWrapper::ByteLevel(dec), + DecoderUntagged::WordPiece(dec) => DecoderWrapper::WordPiece(dec), + DecoderUntagged::Metaspace(dec) => DecoderWrapper::Metaspace(dec), + DecoderUntagged::CTC(dec) => DecoderWrapper::CTC(dec), + DecoderUntagged::Sequence(dec) => DecoderWrapper::Sequence(dec), + DecoderUntagged::Replace(dec) => DecoderWrapper::Replace(dec), + DecoderUntagged::Fuse(dec) => DecoderWrapper::Fuse(dec), + DecoderUntagged::Strip(dec) => DecoderWrapper::Strip(dec), + DecoderUntagged::ByteFallback(dec) => DecoderWrapper::ByteFallback(dec), + } + } + }) + } +} + impl Decoder for DecoderWrapper { fn decode_chain(&self, tokens: Vec) -> Result> { match self { @@ -98,7 +208,7 @@ mod tests { match parse { Err(err) => assert_eq!( format!("{err}"), - "data did not match any variant of untagged enum DecoderWrapper" + "data did not match any variant of untagged enum DecoderUntagged" ), _ => panic!("Expected error"), } @@ -108,7 +218,7 @@ mod tests { match parse { Err(err) => assert_eq!( format!("{err}"), - "data did not match any variant of untagged enum DecoderWrapper" + "data did not match any variant of untagged enum DecoderUntagged" ), _ => panic!("Expected error"), } @@ -116,10 +226,7 @@ mod tests { let json = r#"{"type":"Sequence","prepend_scheme":"always"}"#; let parse = serde_json::from_str::(json); match parse { - Err(err) => assert_eq!( - format!("{err}"), - "data did not match any variant of untagged enum DecoderWrapper" - ), + Err(err) => assert_eq!(format!("{err}"), "missing field `decoders`"), _ => panic!("Expected error"), } } diff --git a/tokenizers/src/normalizers/byte_level.rs b/tokenizers/src/normalizers/byte_level.rs index 42c7fa510..130e2ce1e 100644 --- a/tokenizers/src/normalizers/byte_level.rs +++ b/tokenizers/src/normalizers/byte_level.rs @@ -1,11 +1,11 @@ use crate::processors::byte_level::bytes_char; use crate::tokenizer::{NormalizedString, Normalizer, Result}; -use 
serde::{Deserialize, Serialize};
+use crate::utils::macro_rules_attribute;
 use std::collections::{HashMap, HashSet};
 
-#[derive(Clone, Debug, Deserialize, Serialize)]
-#[serde(tag = "type")]
-pub struct ByteLevel {}
+#[derive(Clone, Debug)]
+#[macro_rules_attribute(impl_serde_type!)]
+pub struct ByteLevel;
 
 lazy_static! {
     static ref BYTES_CHAR: HashMap<u8, char> = bytes_char();
diff --git a/tokenizers/src/normalizers/mod.rs b/tokenizers/src/normalizers/mod.rs
index c5144be14..c56a26c1e 100644
--- a/tokenizers/src/normalizers/mod.rs
+++ b/tokenizers/src/normalizers/mod.rs
@@ -73,3 +73,34 @@ impl_enum_from!(Precompiled, NormalizerWrapper, Precompiled);
 impl_enum_from!(Replace, NormalizerWrapper, Replace);
 impl_enum_from!(Prepend, NormalizerWrapper, Prepend);
 impl_enum_from!(ByteLevel, NormalizerWrapper, ByteLevel);
+
+#[cfg(test)]
+mod tests {
+    use super::*;
+    #[test]
+    fn post_processor_deserialization_no_type() {
+        let json = r#"{"strip_left":false, "strip_right":true}"#;
+        let reconstructed = serde_json::from_str::<NormalizerWrapper>(json);
+        assert!(matches!(
+            reconstructed.unwrap(),
+            NormalizerWrapper::StripNormalizer(_)
+        ));
+
+        let json = r#"{"trim_offsets":true, "add_prefix_space":true}"#;
+        let reconstructed = serde_json::from_str::<NormalizerWrapper>(json);
+        match reconstructed {
+            Err(err) => assert_eq!(
+                err.to_string(),
+                "data did not match any variant of untagged enum NormalizerWrapper"
+            ),
+            _ => panic!("Expected an error here"),
+        }
+
+        let json = r#"{"prepend":"a"}"#;
+        let reconstructed = serde_json::from_str::<NormalizerWrapper>(json);
+        assert!(matches!(
+            reconstructed.unwrap(),
+            NormalizerWrapper::Prepend(_)
+        ));
+    }
+}
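
For reviewers, a minimal usage sketch of the new Python-side `__repr__`/`__str__` behavior. The expected substrings are lifted from the assertions this PR adds in `tests/test_serialization.py` and `tests/bindings/test_tokenizer.py`; the full expected strings live there.

```python
from tokenizers import Tokenizer
from tokenizers.models import BPE

tokenizer = Tokenizer(BPE())

# Both dunders render the pipeline as a single-line, constructor-style summary.
# repr() is exhaustive, while str() elides collections longer than 5 elements
# and strings longer than 100 characters with "...".
assert repr(tokenizer).startswith('Tokenizer(version="1.0", truncation=None, padding=None')
assert "model=BPE(dropout=None, unk_token=None" in str(tokenizer)
print(tokenizer)  # whole pipeline at a glance, no need to dump the JSON file
```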