Using serde (serde_pyo3) to get __str__ and __repr__ easily. (#1588)
* Using serde (serde_pyo3) to get __str__ and __repr__ easily.

* Putting it within tokenizers, since it needs to be too specific to be a standalone crate.

* Clippy is our friend.

* Ruff.

* Update the tests.

* Pretty sure this is wrong (#1589)

* Adding support for ellipsis.

* Fmt.

* Ruff.

* Fixing tokenizer.

---------

Co-authored-by: Eric Buehler <[email protected]>
Narsil and EricLBuehler authored Aug 7, 2024
1 parent 7a30bca commit ab9c7de
Showing 11 changed files with 960 additions and 8 deletions.
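
The Python-visible effect: every bound class (Tokenizer, Model, Normalizer, PreTokenizer, PostProcessor, Decoder, Trainer) now answers str() and repr() by serializing itself through the new serde_pyo3 serializer. A minimal usage sketch, assuming a build of this commit — the exact format is pinned only by the Rust test in tokenizer.rs below:

    from tokenizers import Tokenizer
    from tokenizers.models import BPE

    tokenizer = Tokenizer(BPE())

    # __str__ renders the full pipeline in constructor-like syntax, e.g.
    # Tokenizer(version="1.0", truncation=None, padding=None, added_tokens=[],
    #     normalizer=None, pre_tokenizer=None, post_processor=None,
    #     decoder=None, model=BPE(dropout=None, unk_token=None, ...))
    print(str(tokenizer))

    # __repr__ goes through the same serializer; per the commit message it
    # additionally supports eliding long content (e.g. a large vocab) with
    # an ellipsis.
    print(repr(tokenizer))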
12 changes: 11 additions & 1 deletion bindings/python/src/decoders.rs
@@ -29,8 +29,8 @@ use super::error::ToPyResult;
 /// a Decoder will return an instance of this class when instantiated.
 #[pyclass(dict, module = "tokenizers.decoders", name = "Decoder", subclass)]
 #[derive(Clone, Deserialize, Serialize)]
+#[serde(transparent)]
 pub struct PyDecoder {
-    #[serde(flatten)]
     pub(crate) decoder: PyDecoderWrapper,
 }

@@ -114,6 +114,16 @@ impl PyDecoder {
     fn decode(&self, tokens: Vec<String>) -> PyResult<String> {
         ToPyResult(self.decoder.decode(tokens)).into()
     }
+
+    fn __repr__(&self) -> PyResult<String> {
+        crate::utils::serde_pyo3::repr(self)
+            .map_err(|e| exceptions::PyException::new_err(e.to_string()))
+    }
+
+    fn __str__(&self) -> PyResult<String> {
+        crate::utils::serde_pyo3::to_string(self)
+            .map_err(|e| exceptions::PyException::new_err(e.to_string()))
+    }
 }

 macro_rules! getter {
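
The same pair of methods is stamped onto each wrapper type in the files below. For a concrete decoder, the expected behavior is roughly as follows (a sketch; the exact fields shown come from each variant's serde representation, which this diff does not display):

    from tokenizers.decoders import ByteLevel

    # Hypothetical output shown in the comment; field names and defaults
    # are assumptions based on the Rust struct, not pinned by this diff.
    print(repr(ByteLevel()))
    # ByteLevel(add_prefix_space=True, trim_offsets=True, use_regex=True)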
12 changes: 11 additions & 1 deletion bindings/python/src/models.rs
@@ -26,8 +26,8 @@ use super::error::{deprecation_warning, ToPyResult};
 /// This class cannot be constructed directly. Please use one of the concrete models.
 #[pyclass(module = "tokenizers.models", name = "Model", subclass)]
 #[derive(Clone, Serialize, Deserialize)]
+#[serde(transparent)]
 pub struct PyModel {
-    #[serde(flatten)]
     pub model: Arc<RwLock<ModelWrapper>>,
 }

@@ -220,6 +220,16 @@ impl PyModel {
     fn get_trainer(&self, py: Python<'_>) -> PyResult<PyObject> {
         PyTrainer::from(self.model.read().unwrap().get_trainer()).get_as_subtype(py)
     }
+
+    fn __repr__(&self) -> PyResult<String> {
+        crate::utils::serde_pyo3::repr(self)
+            .map_err(|e| exceptions::PyException::new_err(e.to_string()))
+    }
+
+    fn __str__(&self) -> PyResult<String> {
+        crate::utils::serde_pyo3::to_string(self)
+            .map_err(|e| exceptions::PyException::new_err(e.to_string()))
+    }
 }

 /// An implementation of the BPE (Byte-Pair Encoding) algorithm
12 changes: 11 additions & 1 deletion bindings/python/src/normalizers.rs
@@ -44,8 +44,8 @@ impl PyNormalizedStringMut<'_> {
 /// Normalizer will return an instance of this class when instantiated.
 #[pyclass(dict, module = "tokenizers.normalizers", name = "Normalizer", subclass)]
 #[derive(Clone, Serialize, Deserialize)]
+#[serde(transparent)]
 pub struct PyNormalizer {
-    #[serde(flatten)]
     pub(crate) normalizer: PyNormalizerTypeWrapper,
 }

@@ -169,6 +169,16 @@ impl PyNormalizer {
         ToPyResult(self.normalizer.normalize(&mut normalized)).into_py()?;
         Ok(normalized.get().to_owned())
     }
+
+    fn __repr__(&self) -> PyResult<String> {
+        crate::utils::serde_pyo3::repr(self)
+            .map_err(|e| exceptions::PyException::new_err(e.to_string()))
+    }
+
+    fn __str__(&self) -> PyResult<String> {
+        crate::utils::serde_pyo3::to_string(self)
+            .map_err(|e| exceptions::PyException::new_err(e.to_string()))
+    }
 }

 macro_rules! getter {
12 changes: 11 additions & 1 deletion bindings/python/src/pre_tokenizers.rs
@@ -35,8 +35,8 @@ use super::utils::*;
     subclass
 )]
 #[derive(Clone, Serialize, Deserialize)]
+#[serde(transparent)]
 pub struct PyPreTokenizer {
-    #[serde(flatten)]
     pub(crate) pretok: PyPreTokenizerTypeWrapper,
 }

@@ -181,6 +181,16 @@ impl PyPreTokenizer {
             .map(|(s, o, _)| (s.to_owned(), o))
             .collect())
     }
+
+    fn __repr__(&self) -> PyResult<String> {
+        crate::utils::serde_pyo3::repr(self)
+            .map_err(|e| exceptions::PyException::new_err(e.to_string()))
+    }
+
+    fn __str__(&self) -> PyResult<String> {
+        crate::utils::serde_pyo3::to_string(self)
+            .map_err(|e| exceptions::PyException::new_err(e.to_string()))
+    }
 }

 macro_rules! getter {
12 changes: 11 additions & 1 deletion bindings/python/src/processors.rs
@@ -28,8 +28,8 @@ use tokenizers as tk;
     subclass
 )]
 #[derive(Clone, Deserialize, Serialize)]
+#[serde(transparent)]
 pub struct PyPostProcessor {
-    #[serde(flatten)]
     pub processor: Arc<PostProcessorWrapper>,
 }

@@ -139,6 +139,16 @@ impl PyPostProcessor {
             .into_py()?;
         Ok(final_encoding.into())
     }
+
+    fn __repr__(&self) -> PyResult<String> {
+        crate::utils::serde_pyo3::repr(self)
+            .map_err(|e| exceptions::PyException::new_err(e.to_string()))
+    }
+
+    fn __str__(&self) -> PyResult<String> {
+        crate::utils::serde_pyo3::to_string(self)
+            .map_err(|e| exceptions::PyException::new_err(e.to_string()))
+    }
 }

 /// This post-processor takes care of adding the special tokens needed by
26 changes: 25 additions & 1 deletion bindings/python/src/tokenizer.rs
@@ -1,3 +1,4 @@
+use serde::Serialize;
 use std::collections::{hash_map::DefaultHasher, HashMap};
 use std::hash::{Hash, Hasher};

@@ -462,7 +463,8 @@ type Tokenizer = TokenizerImpl<PyModel, PyNormalizer, PyPreTokenizer, PyPostProc
 /// The core algorithm that this :obj:`Tokenizer` should be using.
 ///
 #[pyclass(dict, module = "tokenizers", name = "Tokenizer")]
-#[derive(Clone)]
+#[derive(Clone, Serialize)]
+#[serde(transparent)]
 pub struct PyTokenizer {
     tokenizer: Tokenizer,
 }

@@ -638,6 +640,16 @@ impl PyTokenizer {
         ToPyResult(self.tokenizer.save(path, pretty)).into()
     }

+    fn __repr__(&self) -> PyResult<String> {
+        crate::utils::serde_pyo3::repr(self)
+            .map_err(|e| exceptions::PyException::new_err(e.to_string()))
+    }
+
+    fn __str__(&self) -> PyResult<String> {
+        crate::utils::serde_pyo3::to_string(self)
+            .map_err(|e| exceptions::PyException::new_err(e.to_string()))
+    }
+
     /// Return the number of special tokens that would be added for single/pair sentences.
     /// :param is_pair: Boolean indicating if the input would be a single sentence or a pair
     /// :return:

@@ -1434,4 +1446,16 @@ mod test {

         Tokenizer::from_file(&tmp).unwrap();
     }
+
+    #[test]
+    fn serde_pyo3() {
+        let mut tokenizer = Tokenizer::new(PyModel::from(BPE::default()));
+        tokenizer.with_normalizer(PyNormalizer::new(PyNormalizerTypeWrapper::Sequence(vec![
+            Arc::new(RwLock::new(NFKC.into())),
+            Arc::new(RwLock::new(Lowercase.into())),
+        ])));
+
+        let output = crate::utils::serde_pyo3::to_string(&tokenizer).unwrap();
+        assert_eq!(output, "Tokenizer(version=\"1.0\", truncation=None, padding=None, added_tokens=[], normalizer=Sequence(normalizers=[NFKC(), Lowercase()]), pre_tokenizer=None, post_processor=None, decoder=None, model=BPE(dropout=None, unk_token=None, continuing_subword_prefix=None, end_of_word_suffix=None, fuse_unk=False, byte_fallback=False, ignore_merges=False, vocab={}, merges=[]))");
+    }
 }
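
The new serde_pyo3 unit test above pins the exact output format. Mirrored from the Python side, the same pipeline should produce the same string through __str__; a sketch, not part of this diff:

    from tokenizers import Tokenizer
    from tokenizers.models import BPE
    from tokenizers.normalizers import NFKC, Lowercase, Sequence

    tok = Tokenizer(BPE())
    tok.normalizer = Sequence([NFKC(), Lowercase()])

    # Prefix and normalizer fragment taken verbatim from the Rust test.
    s = str(tok)
    assert s.startswith('Tokenizer(version="1.0", truncation=None, padding=None')
    assert 'normalizer=Sequence(normalizers=[NFKC(), Lowercase()])' in s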
12 changes: 11 additions & 1 deletion bindings/python/src/trainers.rs
@@ -16,8 +16,8 @@ use tokenizers as tk;
 /// Trainer will return an instance of this class when instantiated.
 #[pyclass(module = "tokenizers.trainers", name = "Trainer", subclass)]
 #[derive(Clone, Deserialize, Serialize)]
+#[serde(transparent)]
 pub struct PyTrainer {
-    #[serde(flatten)]
     pub trainer: Arc<RwLock<TrainerWrapper>>,
 }

@@ -69,6 +69,16 @@ impl PyTrainer {
             Err(e) => Err(e),
         }
     }
+
+    fn __repr__(&self) -> PyResult<String> {
+        crate::utils::serde_pyo3::repr(self)
+            .map_err(|e| exceptions::PyException::new_err(e.to_string()))
+    }
+
+    fn __str__(&self) -> PyResult<String> {
+        crate::utils::serde_pyo3::to_string(self)
+            .map_err(|e| exceptions::PyException::new_err(e.to_string()))
+    }
 }

 impl Trainer for PyTrainer {
1 change: 1 addition & 0 deletions bindings/python/src/utils/mod.rs
@@ -5,6 +5,7 @@ mod iterators;
 mod normalization;
 mod pretokenization;
 mod regex;
+pub mod serde_pyo3;

 pub use iterators::*;
 pub use normalization::*;