From 61804d9b73dc3f28621c141f89d7cc2be6a1c9ce Mon Sep 17 00:00:00 2001 From: Arthur Zucker Date: Mon, 3 Jun 2024 15:21:55 +0200 Subject: [PATCH 01/94] initial commit --- bindings/python/Cargo.toml | 1 + bindings/python/src/models.rs | 1 - tokenizers/Cargo.toml | 1 + tokenizers/src/models/bpe/model.rs | 5 +++-- tokenizers/src/models/mod.rs | 3 ++- tokenizers/src/models/unigram/model.rs | 4 +++- tokenizers/src/models/wordlevel/mod.rs | 5 +++-- tokenizers/src/models/wordpiece/mod.rs | 5 +++-- 8 files changed, 16 insertions(+), 9 deletions(-) diff --git a/bindings/python/Cargo.toml b/bindings/python/Cargo.toml index 3b1b1bbf1..14050874d 100644 --- a/bindings/python/Cargo.toml +++ b/bindings/python/Cargo.toml @@ -19,6 +19,7 @@ numpy = "0.21" ndarray = "0.15" onig = { version = "6.4", default-features = false } itertools = "0.12" +derive_more = "0.99.17" [dependencies.tokenizers] path = "../../tokenizers" diff --git a/bindings/python/src/models.rs b/bindings/python/src/models.rs index 846bb61c0..b22e2eb19 100644 --- a/bindings/python/src/models.rs +++ b/bindings/python/src/models.rs @@ -15,7 +15,6 @@ use tk::models::wordpiece::{WordPiece, WordPieceBuilder}; use tk::models::ModelWrapper; use tk::{Model, Token}; use tokenizers as tk; - use super::error::{deprecation_warning, ToPyResult}; /// Base class for all models diff --git a/tokenizers/Cargo.toml b/tokenizers/Cargo.toml index 07cc85d1b..605725088 100644 --- a/tokenizers/Cargo.toml +++ b/tokenizers/Cargo.toml @@ -63,6 +63,7 @@ fancy-regex = { version = "0.13", optional = true} getrandom = { version = "0.2.10" } esaxx-rs = { version = "0.1.10", default-features = false, features=[]} monostate = "0.1.12" +derive_more = "0.99.17" [features] default = ["progressbar", "onig", "esaxx_fast"] diff --git a/tokenizers/src/models/bpe/model.rs b/tokenizers/src/models/bpe/model.rs index 618f42b47..4e7ec439d 100644 --- a/tokenizers/src/models/bpe/model.rs +++ b/tokenizers/src/models/bpe/model.rs @@ -11,6 +11,7 @@ use std::{ io::{BufRead, BufReader}, path::{Path, PathBuf}, }; +use derive_more::Display; pub type Vocab = HashMap; type VocabR = HashMap; @@ -202,9 +203,9 @@ impl BpeBuilder { }) } } - /// A [Byte Pair Encoding](https://www.aclweb.org/anthology/P16-1162/) model. -#[derive(PartialEq)] +#[derive(PartialEq, Display)] +#[display(fmt = "{:?} {:?} {:?} {:p}", vocab, merges, byte_fallback, ignore_merges)] pub struct BPE { /// The vocabulary assigns a number to each token. 
pub(crate) vocab: Vocab, diff --git a/tokenizers/src/models/mod.rs b/tokenizers/src/models/mod.rs index bb7cebc4c..59b643626 100644 --- a/tokenizers/src/models/mod.rs +++ b/tokenizers/src/models/mod.rs @@ -9,6 +9,7 @@ use std::collections::HashMap; use std::path::{Path, PathBuf}; use serde::{Deserialize, Serialize, Serializer}; +use derive_more::Display; use crate::models::bpe::{BpeTrainer, BPE}; use crate::models::unigram::{Unigram, UnigramTrainer}; @@ -57,7 +58,7 @@ impl<'a> Serialize for OrderedVocabIter<'a> { } } -#[derive(Deserialize, Serialize, Debug, PartialEq, Clone)] +#[derive(Deserialize, Serialize, Debug, PartialEq, Clone, Display)] #[serde(untagged)] pub enum ModelWrapper { BPE(BPE), diff --git a/tokenizers/src/models/unigram/model.rs b/tokenizers/src/models/unigram/model.rs index defc7d93d..8247c4c02 100644 --- a/tokenizers/src/models/unigram/model.rs +++ b/tokenizers/src/models/unigram/model.rs @@ -10,11 +10,13 @@ use std::collections::HashMap; use std::convert::TryInto; use std::fs::read_to_string; use std::path::{Path, PathBuf}; - +use derive_more::Display; type TokenMap = HashMap; type Vocab = Vec<(String, f64)>; /// A `Unigram` model to encode sentences. +#[derive(Display)] +#[display(fmt = "{:?} {:?} {:?} {} {}", token_to_ids, vocab, unk_id, bos_id, eos_id)] pub struct Unigram { token_to_ids: TokenMap, pub(crate) vocab: Vocab, diff --git a/tokenizers/src/models/wordlevel/mod.rs b/tokenizers/src/models/wordlevel/mod.rs index 3482ffee0..71d567b0e 100644 --- a/tokenizers/src/models/wordlevel/mod.rs +++ b/tokenizers/src/models/wordlevel/mod.rs @@ -5,7 +5,7 @@ use std::collections::HashMap; use std::fs::File; use std::io::{BufReader, Read, Write}; use std::path::{Path, PathBuf}; - +use derive_more::Display; mod serialization; mod trainer; @@ -94,7 +94,8 @@ impl WordLevelBuilder { } } -#[derive(PartialEq, Clone, Eq)] +#[derive(PartialEq, Clone, Eq, Display)] +#[display(fmt = "vocab={:?}, unk_token={}", vocab, unk_token)] pub struct WordLevel { vocab: HashMap, vocab_r: HashMap, diff --git a/tokenizers/src/models/wordpiece/mod.rs b/tokenizers/src/models/wordpiece/mod.rs index 9baf24589..b50fbeaae 100644 --- a/tokenizers/src/models/wordpiece/mod.rs +++ b/tokenizers/src/models/wordpiece/mod.rs @@ -11,7 +11,7 @@ use std::{ io::{BufRead, BufReader}, path::{Path, PathBuf}, }; - +use derive_more::Display; mod serialization; mod trainer; pub use trainer::*; @@ -119,7 +119,8 @@ impl WordPieceBuilder { /// A /// [WordPiece](https://static.googleusercontent.com/media/research.google.com/en//pubs/archive/37842.pdf) /// model. -#[derive(Clone, PartialEq, Eq)] +#[derive(Clone, PartialEq, Eq, Display)] +#[display(fmt = "vocab={:?}, unk_token={}, continuing_subword_prefix={:?}", vocab, unk_token, continuing_subword_prefix)] pub struct WordPiece { vocab: Vocab, vocab_r: VocabR, From a56da5f7375fa9c295a571fd1b967a8c236d3dff Mon Sep 17 00:00:00 2001 From: Arthur Zucker Date: Mon, 3 Jun 2024 15:25:22 +0200 Subject: [PATCH 02/94] will this work? 
--- bindings/python/src/models.rs | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/bindings/python/src/models.rs b/bindings/python/src/models.rs index b22e2eb19..b2321f559 100644 --- a/bindings/python/src/models.rs +++ b/bindings/python/src/models.rs @@ -15,6 +15,7 @@ use tk::models::wordpiece::{WordPiece, WordPieceBuilder}; use tk::models::ModelWrapper; use tk::{Model, Token}; use tokenizers as tk; +use derive_more::Display; use super::error::{deprecation_warning, ToPyResult}; /// Base class for all models @@ -24,7 +25,8 @@ use super::error::{deprecation_warning, ToPyResult}; /// /// This class cannot be constructed directly. Please use one of the concrete models. #[pyclass(module = "tokenizers.models", name = "Model", subclass)] -#[derive(Clone, Serialize, Deserialize)] +#[derive(Clone, Serialize, Deserialize, Display)] +#[display(fmt="{:?}",*self.model.as_ref().read().unwrap())] pub struct PyModel { #[serde(flatten)] pub model: Arc>, From f1a6a97c821567726f120bed409bd00d461f3c4f Mon Sep 17 00:00:00 2001 From: Arthur Zucker Date: Mon, 3 Jun 2024 16:08:26 +0200 Subject: [PATCH 03/94] make it work for the model for now --- bindings/python/src/models.rs | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/bindings/python/src/models.rs b/bindings/python/src/models.rs index b2321f559..ea03c8527 100644 --- a/bindings/python/src/models.rs +++ b/bindings/python/src/models.rs @@ -26,7 +26,7 @@ use super::error::{deprecation_warning, ToPyResult}; /// This class cannot be constructed directly. Please use one of the concrete models. #[pyclass(module = "tokenizers.models", name = "Model", subclass)] #[derive(Clone, Serialize, Deserialize, Display)] -#[display(fmt="{:?}",*self.model.as_ref().read().unwrap())] +#[display(fmt="{}","model.as_ref().read().unwrap()")] pub struct PyModel { #[serde(flatten)] pub model: Arc>, @@ -221,6 +221,9 @@ impl PyModel { fn get_trainer(&self, py: Python<'_>) -> PyResult { PyTrainer::from(self.model.read().unwrap().get_trainer()).get_as_subtype(py) } + fn __str__(&self) -> PyResult { + Ok(format!("{}", self.model.read().unwrap())) + } } /// An implementation of the BPE (Byte-Pair Encoding) algorithm From 4a49530a0bc13058480e2f1fe847f946a5b9ef6e Mon Sep 17 00:00:00 2001 From: Arthur Zucker Date: Mon, 3 Jun 2024 17:24:41 +0200 Subject: [PATCH 04/94] updates --- bindings/python/Cargo.toml | 1 + bindings/python/src/models.rs | 8 ++++---- tokenizers/Cargo.toml | 1 + tokenizers/src/models/bpe/model.rs | 17 +++++++++++++++-- tokenizers/src/models/mod.rs | 2 +- tokenizers/src/models/unigram/model.rs | 11 +++++++++-- tokenizers/src/models/wordlevel/mod.rs | 2 +- tokenizers/src/models/wordpiece/mod.rs | 9 +++++++-- 8 files changed, 39 insertions(+), 12 deletions(-) diff --git a/bindings/python/Cargo.toml b/bindings/python/Cargo.toml index 14050874d..25e666b20 100644 --- a/bindings/python/Cargo.toml +++ b/bindings/python/Cargo.toml @@ -20,6 +20,7 @@ ndarray = "0.15" onig = { version = "6.4", default-features = false } itertools = "0.12" derive_more = "0.99.17" +ellipse = "0.2.0" [dependencies.tokenizers] path = "../../tokenizers" diff --git a/bindings/python/src/models.rs b/bindings/python/src/models.rs index ea03c8527..3eade26bb 100644 --- a/bindings/python/src/models.rs +++ b/bindings/python/src/models.rs @@ -2,8 +2,10 @@ use std::collections::HashMap; use std::path::{Path, PathBuf}; use std::sync::{Arc, RwLock}; +use super::error::{deprecation_warning, ToPyResult}; use crate::token::PyToken; use crate::trainers::PyTrainer; +use 
derive_more::Display; use pyo3::exceptions; use pyo3::prelude::*; use pyo3::types::*; @@ -15,8 +17,6 @@ use tk::models::wordpiece::{WordPiece, WordPieceBuilder}; use tk::models::ModelWrapper; use tk::{Model, Token}; use tokenizers as tk; -use derive_more::Display; -use super::error::{deprecation_warning, ToPyResult}; /// Base class for all models /// @@ -26,7 +26,7 @@ use super::error::{deprecation_warning, ToPyResult}; /// This class cannot be constructed directly. Please use one of the concrete models. #[pyclass(module = "tokenizers.models", name = "Model", subclass)] #[derive(Clone, Serialize, Deserialize, Display)] -#[display(fmt="{}","model.as_ref().read().unwrap()")] +#[display(fmt = "{}", "model.as_ref().read().unwrap()")] pub struct PyModel { #[serde(flatten)] pub model: Arc>, @@ -221,7 +221,7 @@ impl PyModel { fn get_trainer(&self, py: Python<'_>) -> PyResult { PyTrainer::from(self.model.read().unwrap().get_trainer()).get_as_subtype(py) } - fn __str__(&self) -> PyResult { + fn __str__(&self) -> PyResult { Ok(format!("{}", self.model.read().unwrap())) } } diff --git a/tokenizers/Cargo.toml b/tokenizers/Cargo.toml index 605725088..a48064ac3 100644 --- a/tokenizers/Cargo.toml +++ b/tokenizers/Cargo.toml @@ -64,6 +64,7 @@ getrandom = { version = "0.2.10" } esaxx-rs = { version = "0.1.10", default-features = false, features=[]} monostate = "0.1.12" derive_more = "0.99.17" +ellipse = "0.2.0" [features] default = ["progressbar", "onig", "esaxx_fast"] diff --git a/tokenizers/src/models/bpe/model.rs b/tokenizers/src/models/bpe/model.rs index 4e7ec439d..5dc9087ce 100644 --- a/tokenizers/src/models/bpe/model.rs +++ b/tokenizers/src/models/bpe/model.rs @@ -2,6 +2,8 @@ use super::{super::OrderedVocabIter, trainer::BpeTrainer, Error, Pair, Word}; use crate::tokenizer::{Model, Result, Token}; use crate::utils::cache::{Cache, DEFAULT_CACHE_CAPACITY}; use crate::utils::iter::ResultShunt; +use derive_more::Display; +use ellipse::Ellipse; use serde_json::Value; use std::borrow::Cow; use std::{ @@ -11,7 +13,6 @@ use std::{ io::{BufRead, BufReader}, path::{Path, PathBuf}, }; -use derive_more::Display; pub type Vocab = HashMap; type VocabR = HashMap; @@ -203,9 +204,21 @@ impl BpeBuilder { }) } } + /// A [Byte Pair Encoding](https://www.aclweb.org/anthology/P16-1162/) model. #[derive(PartialEq, Display)] -#[display(fmt = "{:?} {:?} {:?} {:p}", vocab, merges, byte_fallback, ignore_merges)] +#[display( + fmt = "BPE(vocab={:?}, ...], merges={:?}, byte_fallback={:?}, ignore_merges={:?}", + "{ + let mut vocab_vec: Vec<_> = vocab.into_iter().collect(); + vocab_vec.sort_by_key(|&(_, v)| v); + vocab_vec.truncate(5); + vocab_vec + }", + byte_fallback, + byte_fallback, + ignore_merges +)] pub struct BPE { /// The vocabulary assigns a number to each token. 
pub(crate) vocab: Vocab, diff --git a/tokenizers/src/models/mod.rs b/tokenizers/src/models/mod.rs index 59b643626..49a31211a 100644 --- a/tokenizers/src/models/mod.rs +++ b/tokenizers/src/models/mod.rs @@ -8,8 +8,8 @@ pub mod wordpiece; use std::collections::HashMap; use std::path::{Path, PathBuf}; -use serde::{Deserialize, Serialize, Serializer}; use derive_more::Display; +use serde::{Deserialize, Serialize, Serializer}; use crate::models::bpe::{BpeTrainer, BPE}; use crate::models::unigram::{Unigram, UnigramTrainer}; diff --git a/tokenizers/src/models/unigram/model.rs b/tokenizers/src/models/unigram/model.rs index 8247c4c02..f3a32a8a0 100644 --- a/tokenizers/src/models/unigram/model.rs +++ b/tokenizers/src/models/unigram/model.rs @@ -6,17 +6,24 @@ use super::{ use crate::tokenizer::{Model, Result, Token}; use crate::utils::cache::Cache; +use derive_more::Display; use std::collections::HashMap; use std::convert::TryInto; use std::fs::read_to_string; use std::path::{Path, PathBuf}; -use derive_more::Display; type TokenMap = HashMap; type Vocab = Vec<(String, f64)>; /// A `Unigram` model to encode sentences. #[derive(Display)] -#[display(fmt = "{:?} {:?} {:?} {} {}", token_to_ids, vocab, unk_id, bos_id, eos_id)] +#[display( + fmt = "{:?} {:?} {:?} {} {}", + token_to_ids, + vocab, + unk_id, + bos_id, + eos_id +)] pub struct Unigram { token_to_ids: TokenMap, pub(crate) vocab: Vocab, diff --git a/tokenizers/src/models/wordlevel/mod.rs b/tokenizers/src/models/wordlevel/mod.rs index 71d567b0e..4c5bdf90d 100644 --- a/tokenizers/src/models/wordlevel/mod.rs +++ b/tokenizers/src/models/wordlevel/mod.rs @@ -1,11 +1,11 @@ use super::OrderedVocabIter; use crate::tokenizer::{Model, Result, Token}; +use derive_more::Display; use serde_json::Value; use std::collections::HashMap; use std::fs::File; use std::io::{BufReader, Read, Write}; use std::path::{Path, PathBuf}; -use derive_more::Display; mod serialization; mod trainer; diff --git a/tokenizers/src/models/wordpiece/mod.rs b/tokenizers/src/models/wordpiece/mod.rs index b50fbeaae..c7c3bee98 100644 --- a/tokenizers/src/models/wordpiece/mod.rs +++ b/tokenizers/src/models/wordpiece/mod.rs @@ -3,6 +3,7 @@ use crate::models::bpe::BPE; use crate::tokenizer::{Model, Result, Token}; +use derive_more::Display; use std::{ borrow::Cow, collections::HashMap, @@ -11,7 +12,6 @@ use std::{ io::{BufRead, BufReader}, path::{Path, PathBuf}, }; -use derive_more::Display; mod serialization; mod trainer; pub use trainer::*; @@ -120,7 +120,12 @@ impl WordPieceBuilder { /// [WordPiece](https://static.googleusercontent.com/media/research.google.com/en//pubs/archive/37842.pdf) /// model. 
#[derive(Clone, PartialEq, Eq, Display)] -#[display(fmt = "vocab={:?}, unk_token={}, continuing_subword_prefix={:?}", vocab, unk_token, continuing_subword_prefix)] +#[display( + fmt = "vocab={:?}, unk_token={}, continuing_subword_prefix={:?}", + vocab, + unk_token, + continuing_subword_prefix +)] pub struct WordPiece { vocab: Vocab, vocab_r: VocabR, From f4af6162a8ba84c3d229924ec97c95d39402fc54 Mon Sep 17 00:00:00 2001 From: Arthur Zucker Date: Mon, 3 Jun 2024 18:14:54 +0200 Subject: [PATCH 05/94] update --- tokenizers/src/models/bpe/model.rs | 60 +++++++++++++++++++++++------- 1 file changed, 46 insertions(+), 14 deletions(-) diff --git a/tokenizers/src/models/bpe/model.rs b/tokenizers/src/models/bpe/model.rs index 5dc9087ce..88536105d 100644 --- a/tokenizers/src/models/bpe/model.rs +++ b/tokenizers/src/models/bpe/model.rs @@ -3,7 +3,6 @@ use crate::tokenizer::{Model, Result, Token}; use crate::utils::cache::{Cache, DEFAULT_CACHE_CAPACITY}; use crate::utils::iter::ResultShunt; use derive_more::Display; -use ellipse::Ellipse; use serde_json::Value; use std::borrow::Cow; use std::{ @@ -206,19 +205,7 @@ impl BpeBuilder { } /// A [Byte Pair Encoding](https://www.aclweb.org/anthology/P16-1162/) model. -#[derive(PartialEq, Display)] -#[display( - fmt = "BPE(vocab={:?}, ...], merges={:?}, byte_fallback={:?}, ignore_merges={:?}", - "{ - let mut vocab_vec: Vec<_> = vocab.into_iter().collect(); - vocab_vec.sort_by_key(|&(_, v)| v); - vocab_vec.truncate(5); - vocab_vec - }", - byte_fallback, - byte_fallback, - ignore_merges -)] +#[derive(PartialEq)] pub struct BPE { /// The vocabulary assigns a number to each token. pub(crate) vocab: Vocab, @@ -262,6 +249,51 @@ impl std::fmt::Debug for BPE { } } +impl std::fmt::Display for BPE { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + let mut vocab_vec: Vec<_> = self.vocab.iter().collect(); + vocab_vec.sort_by_key(|&(_, v)| v); + vocab_vec.truncate(5); + + let vocab_str: String = vocab_vec + .iter() + .map(|(k, v)| format!("'{}':{}", k, v)) + .collect::>() + .join(", "); + + let mut merges_vec: Vec<_> = self.merges.iter().collect(); + merges_vec.truncate(5); + merges_vec.sort_by_key(|&(_, v)| v); + + let merges_str: String = merges_vec + .iter() + .map(|((id1, id2), _)| { + ( + self.vocab_r.get(id1).cloned().unwrap_or_else(|| id1.to_string()), + self.vocab_r.get(id2).cloned().unwrap_or_else(|| id2.to_string()), + ) + }) + .map(|(id1, id2)| format!("('{}', '{}')", id1, id2)) + .collect::>() + .join(", "); + + + write!( + f, + "BPE(vocab={{{}, ...}}, merges=[{:?}, ...], dropout={:?}, unk_token={:?}, continuing_subword_prefix={:?}, end_of_word_suffix={:?}, fuse_unk={}, byte_fallback={}, ignore_merges={})", + vocab_str, + merges_str, + self.dropout, + self.unk_token, + self.continuing_subword_prefix, + self.end_of_word_suffix, + self.fuse_unk, + self.byte_fallback, + self.ignore_merges + ) + } +} + impl Default for BPE { fn default() -> Self { Self::builder().build().unwrap() From 88630dcb57bb2246cac3919921d7db859a8062ac Mon Sep 17 00:00:00 2001 From: Arthur Zucker Date: Mon, 3 Jun 2024 18:18:52 +0200 Subject: [PATCH 06/94] add metaspace --- tokenizers/src/pre_tokenizers/metaspace.rs | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/tokenizers/src/pre_tokenizers/metaspace.rs b/tokenizers/src/pre_tokenizers/metaspace.rs index 52b415c9b..9fd61c54b 100644 --- a/tokenizers/src/pre_tokenizers/metaspace.rs +++ b/tokenizers/src/pre_tokenizers/metaspace.rs @@ -1,6 +1,6 @@ use crate::tokenizer::{Decoder, PreTokenizedString, 
PreTokenizer, Result, SplitDelimiterBehavior}; use serde::{de, Deserialize, Deserializer, Serialize}; - +use derive_more::Display; /// Enum representing options for the metaspace prepending scheme. #[derive(Debug, Clone, PartialEq, Serialize, Eq, Deserialize, Copy)] #[serde(rename_all = "snake_case")] @@ -13,10 +13,11 @@ pub enum PrependScheme { Always, } -#[derive(Debug, Clone, PartialEq, Serialize, Eq)] +#[derive(Debug, Clone, PartialEq, Serialize, Eq, Display)] /// Replaces all the whitespaces by the provided meta character and then /// splits on this character #[serde(tag = "type")] +#[display(fmt="Metaspace(replacement={}, prepend_scheme={:?}, split={})", replacement, prepend_scheme, split)] pub struct Metaspace { replacement: char, pub prepend_scheme: PrependScheme, From b9d44da76db40c5611af974770c462db33d1ca35 Mon Sep 17 00:00:00 2001 From: Arthur Zucker Date: Mon, 3 Jun 2024 18:32:46 +0200 Subject: [PATCH 07/94] update --- bindings/python/src/pre_tokenizers.rs | 11 ++++++++--- 1 file changed, 8 insertions(+), 3 deletions(-) diff --git a/bindings/python/src/pre_tokenizers.rs b/bindings/python/src/pre_tokenizers.rs index a2bd9b39c..a3de7d4d1 100644 --- a/bindings/python/src/pre_tokenizers.rs +++ b/bindings/python/src/pre_tokenizers.rs @@ -23,7 +23,7 @@ use tokenizers as tk; use super::error::ToPyResult; use super::utils::*; - +use derive_more::Display; /// Base class for all pre-tokenizers /// /// This class is not supposed to be instantiated directly. Instead, any implementation of a @@ -34,7 +34,8 @@ use super::utils::*; name = "PreTokenizer", subclass )] -#[derive(Clone, Serialize, Deserialize)] +#[derive(Clone, Serialize, Deserialize, Display)] +#[display(fmt="PreTokenizer(pretok={}","pretok")] pub struct PyPreTokenizer { #[serde(flatten)] pub(crate) pretok: PyPreTokenizerTypeWrapper, @@ -484,6 +485,7 @@ pub(crate) fn from_string(string: String) -> Result { /// token (relevant when special tokens are used or other pre_tokenizer are used). /// #[pyclass(extends=PyPreTokenizer, module = "tokenizers.pre_tokenizers", name = "Metaspace")] +#[derive(Display)] pub struct PyMetaspace {} #[pymethods] impl PyMetaspace { @@ -650,10 +652,13 @@ impl Serialize for PyPreTokenizerWrapper { } } -#[derive(Clone, Deserialize)] +#[derive(Clone, Deserialize, Display)] #[serde(untagged)] +#[display(fmt="PreTokenizer.{}")] pub(crate) enum PyPreTokenizerTypeWrapper { + #[display(fmt="A")] Sequence(Vec>>), + #[display(fmt="B")] Single(Arc>), } From a90ec224c42720fcfa69a06a79600a0c8d6b4731 Mon Sep 17 00:00:00 2001 From: Arthur Zucker Date: Mon, 3 Jun 2024 18:41:42 +0200 Subject: [PATCH 08/94] does not work --- bindings/python/src/pre_tokenizers.rs | 1 - 1 file changed, 1 deletion(-) diff --git a/bindings/python/src/pre_tokenizers.rs b/bindings/python/src/pre_tokenizers.rs index a3de7d4d1..a40baf1e7 100644 --- a/bindings/python/src/pre_tokenizers.rs +++ b/bindings/python/src/pre_tokenizers.rs @@ -485,7 +485,6 @@ pub(crate) fn from_string(string: String) -> Result { /// token (relevant when special tokens are used or other pre_tokenizer are used). 
/// #[pyclass(extends=PyPreTokenizer, module = "tokenizers.pre_tokenizers", name = "Metaspace")] -#[derive(Display)] pub struct PyMetaspace {} #[pymethods] impl PyMetaspace { From 22242755dc2958ff2fbf56a39975bea5ba368b89 Mon Sep 17 00:00:00 2001 From: Arthur Zucker Date: Tue, 4 Jun 2024 08:02:29 +0200 Subject: [PATCH 09/94] current modifications --- bindings/python/src/decoders.rs | 9 +++++--- bindings/python/src/normalizers.rs | 9 +++++--- bindings/python/src/pre_tokenizers.rs | 8 +++---- bindings/python/src/processors.rs | 4 ++-- bindings/python/src/tokenizer.rs | 24 ++++++++++----------- tokenizers/src/decoders/bpe.rs | 4 ++-- tokenizers/src/decoders/byte_fallback.rs | 5 +++-- tokenizers/src/decoders/ctc.rs | 10 +++++++-- tokenizers/src/decoders/fuse.rs | 5 +++-- tokenizers/src/decoders/mod.rs | 6 +++--- tokenizers/src/decoders/sequence.rs | 4 +++- tokenizers/src/decoders/strip.rs | 5 +++-- tokenizers/src/decoders/wordpiece.rs | 5 +++-- tokenizers/src/models/bpe/model.rs | 22 +++++++++++-------- tokenizers/src/normalizers/replace.rs | 10 +++++++-- tokenizers/src/pre_tokenizers/byte_level.rs | 14 ++++++++---- tokenizers/src/pre_tokenizers/metaspace.rs | 9 ++++++-- tokenizers/src/processors/bert.rs | 4 +++- tokenizers/src/processors/mod.rs | 7 +++--- tokenizers/src/processors/roberta.rs | 10 ++++++++- tokenizers/src/processors/sequence.rs | 5 +++-- tokenizers/src/processors/template.rs | 5 +++-- tokenizers/src/tokenizer/mod.rs | 19 ++++++++++++---- 23 files changed, 133 insertions(+), 70 deletions(-) diff --git a/bindings/python/src/decoders.rs b/bindings/python/src/decoders.rs index ed21f3469..f824cf3ad 100644 --- a/bindings/python/src/decoders.rs +++ b/bindings/python/src/decoders.rs @@ -2,6 +2,7 @@ use std::sync::{Arc, RwLock}; use crate::pre_tokenizers::from_string; use crate::utils::PyPattern; +use derive_more::Display; use pyo3::exceptions; use pyo3::prelude::*; use pyo3::types::*; @@ -28,7 +29,7 @@ use super::error::ToPyResult; /// This class is not supposed to be instantiated directly. Instead, any implementation of /// a Decoder will return an instance of this class when instantiated. #[pyclass(dict, module = "tokenizers.decoders", name = "Decoder", subclass)] -#[derive(Clone, Deserialize, Serialize)] +#[derive(Clone, Deserialize, Serialize, Display)] pub struct PyDecoder { #[serde(flatten)] pub(crate) decoder: PyDecoderWrapper, @@ -478,7 +479,7 @@ impl PySequenceDecoder { } } -#[derive(Clone)] +#[derive(Clone, Display)] pub(crate) struct CustomDecoder { inner: PyObject, } @@ -531,10 +532,12 @@ impl<'de> Deserialize<'de> for CustomDecoder { } } -#[derive(Clone, Deserialize, Serialize)] +#[derive(Clone, Deserialize, Serialize, Display)] #[serde(untagged)] pub(crate) enum PyDecoderWrapper { + #[display(fmt = "{}", self)] Custom(Arc>), + #[display(fmt = "{}", self)] Wrapped(Arc>), } diff --git a/bindings/python/src/normalizers.rs b/bindings/python/src/normalizers.rs index 645852fa8..35ac69e77 100644 --- a/bindings/python/src/normalizers.rs +++ b/bindings/python/src/normalizers.rs @@ -3,7 +3,7 @@ use std::sync::{Arc, RwLock}; use pyo3::exceptions; use pyo3::prelude::*; use pyo3::types::*; - +use derive_more::Display; use crate::error::ToPyResult; use crate::utils::{PyNormalizedString, PyNormalizedStringRefMut, PyPattern}; use serde::ser::SerializeStruct; @@ -43,7 +43,8 @@ impl PyNormalizedStringMut<'_> { /// This class is not supposed to be instantiated directly. Instead, any implementation of a /// Normalizer will return an instance of this class when instantiated. 
#[pyclass(dict, module = "tokenizers.normalizers", name = "Normalizer", subclass)] -#[derive(Clone, Serialize, Deserialize)] +#[derive(Clone, Serialize, Deserialize, Display)] +#[display(fmt = "{}", "normalizer")] pub struct PyNormalizer { #[serde(flatten)] pub(crate) normalizer: PyNormalizerTypeWrapper, @@ -560,10 +561,12 @@ impl Serialize for PyNormalizerWrapper { } } -#[derive(Debug, Clone, Deserialize)] +#[derive(Debug, Clone, Deserialize, Display)] #[serde(untagged)] pub(crate) enum PyNormalizerTypeWrapper { + #[display(fmt="{}", "_0.iter().map(|arc| arc.as_ref().read().unwrap().to_string()).collect::>()")] Sequence(Vec>>), + #[display(fmt = "{}", self)] Single(Arc>), } diff --git a/bindings/python/src/pre_tokenizers.rs b/bindings/python/src/pre_tokenizers.rs index a40baf1e7..3e6289a7d 100644 --- a/bindings/python/src/pre_tokenizers.rs +++ b/bindings/python/src/pre_tokenizers.rs @@ -35,7 +35,7 @@ use derive_more::Display; subclass )] #[derive(Clone, Serialize, Deserialize, Display)] -#[display(fmt="PreTokenizer(pretok={}","pretok")] +#[display(fmt = "PreTokenizer(pretok={}", "pretok")] pub struct PyPreTokenizer { #[serde(flatten)] pub(crate) pretok: PyPreTokenizerTypeWrapper, @@ -653,11 +653,11 @@ impl Serialize for PyPreTokenizerWrapper { #[derive(Clone, Deserialize, Display)] #[serde(untagged)] -#[display(fmt="PreTokenizer.{}")] +#[display(fmt = "PreTokenizer.{}")] pub(crate) enum PyPreTokenizerTypeWrapper { - #[display(fmt="A")] + #[display(fmt = "A")] Sequence(Vec>>), - #[display(fmt="B")] + #[display(fmt = "B")] Single(Arc>), } diff --git a/bindings/python/src/processors.rs b/bindings/python/src/processors.rs index c46d8ea49..7fdc68ef1 100644 --- a/bindings/python/src/processors.rs +++ b/bindings/python/src/processors.rs @@ -4,7 +4,7 @@ use std::sync::Arc; use pyo3::exceptions; use pyo3::prelude::*; use pyo3::types::*; - +use derive_more::Display; use crate::encoding::PyEncoding; use crate::error::ToPyResult; use serde::{Deserialize, Serialize}; @@ -27,7 +27,7 @@ use tokenizers as tk; name = "PostProcessor", subclass )] -#[derive(Clone, Deserialize, Serialize)] +#[derive(Clone, Deserialize, Serialize, Display)] pub struct PyPostProcessor { #[serde(flatten)] pub processor: Arc, diff --git a/bindings/python/src/tokenizer.rs b/bindings/python/src/tokenizer.rs index 1c6bc9cc1..b44c344ae 100644 --- a/bindings/python/src/tokenizer.rs +++ b/bindings/python/src/tokenizer.rs @@ -1,12 +1,23 @@ use std::collections::{hash_map::DefaultHasher, HashMap}; use std::hash::{Hash, Hasher}; +use super::decoders::PyDecoder; +use super::encoding::PyEncoding; +use super::error::{PyError, ToPyResult}; +use super::models::PyModel; +use super::normalizers::PyNormalizer; +use super::pre_tokenizers::PyPreTokenizer; +use super::trainers::PyTrainer; +use crate::processors::PyPostProcessor; +use crate::utils::{MaybeSizedIterator, PyBufferedIterator}; +use derive_more::Display; use numpy::{npyffi, PyArray1}; use pyo3::class::basic::CompareOp; use pyo3::exceptions; use pyo3::intern; use pyo3::prelude::*; use pyo3::types::*; +use std::collections::BTreeMap; use tk::models::bpe::BPE; use tk::tokenizer::{ Model, PaddingDirection, PaddingParams, PaddingStrategy, PostProcessor, TokenizerImpl, @@ -15,17 +26,6 @@ use tk::tokenizer::{ use tk::utils::iter::ResultShunt; use tokenizers as tk; -use super::decoders::PyDecoder; -use super::encoding::PyEncoding; -use super::error::{PyError, ToPyResult}; -use super::models::PyModel; -use super::normalizers::PyNormalizer; -use super::pre_tokenizers::PyPreTokenizer; -use 
super::trainers::PyTrainer; -use crate::processors::PyPostProcessor; -use crate::utils::{MaybeSizedIterator, PyBufferedIterator}; -use std::collections::BTreeMap; - /// Represents a token that can be be added to a :class:`~tokenizers.Tokenizer`. /// It can have special options that defines the way it should behave. /// @@ -462,7 +462,7 @@ type Tokenizer = TokenizerImpl` /// to pure bytes, and attempts to make them into a string. If the tokens /// cannot be decoded you will get � instead for each inconvertable byte token #[non_exhaustive] +#[display(fmt = "ByteFallback")] pub struct ByteFallback { #[serde(rename = "type")] type_: MustBe!("ByteFallback"), diff --git a/tokenizers/src/decoders/ctc.rs b/tokenizers/src/decoders/ctc.rs index 2798638d4..7ef687404 100644 --- a/tokenizers/src/decoders/ctc.rs +++ b/tokenizers/src/decoders/ctc.rs @@ -1,15 +1,21 @@ use crate::decoders::wordpiece; use crate::tokenizer::{Decoder, Result}; - +use derive_more::Display; use itertools::Itertools; use serde::{Deserialize, Serialize}; -#[derive(Debug, Clone, Serialize, Deserialize)] +#[derive(Debug, Clone, Serialize, Deserialize, Display)] /// The CTC (Connectionist Temporal Classification) decoder takes care /// of sanitizing a list of inputs token. /// Due to some alignement problem the output of some models can come /// with duplicated token. #[serde(tag = "type")] +#[display( + fmt = "CTC(pad_token={}, word_delimiter_token={}, cleanup={}", + pad_token, + word_delimiter_token, + cleanup +)] #[non_exhaustive] pub struct CTC { /// The pad token used by CTC to delimit a new token. diff --git a/tokenizers/src/decoders/fuse.rs b/tokenizers/src/decoders/fuse.rs index 5e4a1c119..9208afa97 100644 --- a/tokenizers/src/decoders/fuse.rs +++ b/tokenizers/src/decoders/fuse.rs @@ -1,13 +1,14 @@ use crate::tokenizer::{Decoder, Result}; +use derive_more::Display; use monostate::MustBe; use serde::{Deserialize, Serialize}; - -#[derive(Clone, Debug, Serialize, Deserialize, Default)] +#[derive(Clone, Debug, Serialize, Deserialize, Default, Display)] /// Fuse simply fuses all tokens into one big string. 
/// It's usually the last decoding step anyway, but this /// decoder exists incase some decoders need to happen after that /// step #[non_exhaustive] +#[display(fmt = "Fuse")] pub struct Fuse { #[serde(rename = "type")] type_: MustBe!("Fuse"), diff --git a/tokenizers/src/decoders/mod.rs b/tokenizers/src/decoders/mod.rs index 682e63b50..f1da8df34 100644 --- a/tokenizers/src/decoders/mod.rs +++ b/tokenizers/src/decoders/mod.rs @@ -10,8 +10,6 @@ pub mod wordpiece; pub use super::pre_tokenizers::byte_level; pub use super::pre_tokenizers::metaspace; -use serde::{Deserialize, Serialize}; - use crate::decoders::bpe::BPEDecoder; use crate::decoders::byte_fallback::ByteFallback; use crate::decoders::ctc::CTC; @@ -23,8 +21,10 @@ use crate::normalizers::replace::Replace; use crate::pre_tokenizers::byte_level::ByteLevel; use crate::pre_tokenizers::metaspace::Metaspace; use crate::{Decoder, Result}; +use derive_more::Display; +use serde::{Deserialize, Serialize}; -#[derive(Serialize, Deserialize, Clone, Debug)] +#[derive(Serialize, Deserialize, Clone, Debug, Display)] #[serde(untagged)] pub enum DecoderWrapper { BPE(BPEDecoder), diff --git a/tokenizers/src/decoders/sequence.rs b/tokenizers/src/decoders/sequence.rs index 73169b695..a69ca6151 100644 --- a/tokenizers/src/decoders/sequence.rs +++ b/tokenizers/src/decoders/sequence.rs @@ -1,10 +1,12 @@ use crate::decoders::DecoderWrapper; use crate::tokenizer::{Decoder, Result}; use crate::utils::macro_rules_attribute; +use derive_more::Display; use serde::{Deserialize, Serialize}; -#[derive(Clone, Debug)] #[macro_rules_attribute(impl_serde_type!)] +#[derive(Clone, Debug, Display)] +#[display(fmt = "{:?}", "decoders")] pub struct Sequence { decoders: Vec, } diff --git a/tokenizers/src/decoders/strip.rs b/tokenizers/src/decoders/strip.rs index b095fc37e..344d61489 100644 --- a/tokenizers/src/decoders/strip.rs +++ b/tokenizers/src/decoders/strip.rs @@ -1,13 +1,14 @@ use crate::tokenizer::{Decoder, Result}; +use derive_more::Display; use serde::{Deserialize, Serialize}; - -#[derive(Deserialize, Clone, Debug, Serialize, Default)] +#[derive(Deserialize, Clone, Debug, Serialize, Default, Display)] /// Strip is a simple trick which converts tokens looking like `<0x61>` /// to pure bytes, and attempts to make them into a string. If the tokens /// cannot be decoded you will get � instead for each inconvertable byte token #[serde(tag = "type")] #[non_exhaustive] +#[display(fmt = "Strip(content={}, start={}, stop={})", content, start, stop)] pub struct Strip { pub content: char, pub start: usize, diff --git a/tokenizers/src/decoders/wordpiece.rs b/tokenizers/src/decoders/wordpiece.rs index 8ecd3987c..28494ec0f 100644 --- a/tokenizers/src/decoders/wordpiece.rs +++ b/tokenizers/src/decoders/wordpiece.rs @@ -1,11 +1,12 @@ use crate::tokenizer::{Decoder, Result}; +use derive_more::Display; use serde::{Deserialize, Serialize}; - -#[derive(Deserialize, Clone, Debug, Serialize)] +#[derive(Deserialize, Clone, Debug, Serialize, Display)] /// The WordPiece decoder takes care of decoding a list of wordpiece tokens /// back into a readable string. 
#[serde(tag = "type")] +#[display(fmt = "WordPiece(prefix={}, cleanup={:?}", prefix, cleanup)] #[non_exhaustive] pub struct WordPiece { /// The prefix to be used for continuing subwords diff --git a/tokenizers/src/models/bpe/model.rs b/tokenizers/src/models/bpe/model.rs index 88536105d..fa0e2772f 100644 --- a/tokenizers/src/models/bpe/model.rs +++ b/tokenizers/src/models/bpe/model.rs @@ -2,7 +2,6 @@ use super::{super::OrderedVocabIter, trainer::BpeTrainer, Error, Pair, Word}; use crate::tokenizer::{Model, Result, Token}; use crate::utils::cache::{Cache, DEFAULT_CACHE_CAPACITY}; use crate::utils::iter::ResultShunt; -use derive_more::Display; use serde_json::Value; use std::borrow::Cow; use std::{ @@ -260,25 +259,30 @@ impl std::fmt::Display for BPE { .map(|(k, v)| format!("'{}':{}", k, v)) .collect::>() .join(", "); - + let mut merges_vec: Vec<_> = self.merges.iter().collect(); - merges_vec.truncate(5); - merges_vec.sort_by_key(|&(_, v)| v); - + merges_vec.truncate(5); + merges_vec.sort_by_key(|&(_, v)| v); + let merges_str: String = merges_vec .iter() .map(|((id1, id2), _)| { ( - self.vocab_r.get(id1).cloned().unwrap_or_else(|| id1.to_string()), - self.vocab_r.get(id2).cloned().unwrap_or_else(|| id2.to_string()), + self.vocab_r + .get(id1) + .cloned() + .unwrap_or_else(|| id1.to_string()), + self.vocab_r + .get(id2) + .cloned() + .unwrap_or_else(|| id2.to_string()), ) }) .map(|(id1, id2)| format!("('{}', '{}')", id1, id2)) .collect::>() .join(", "); - - write!( + write!( f, "BPE(vocab={{{}, ...}}, merges=[{:?}, ...], dropout={:?}, unk_token={:?}, continuing_subword_prefix={:?}, end_of_word_suffix={:?}, fuse_unk={}, byte_fallback={}, ignore_merges={})", vocab_str, diff --git a/tokenizers/src/normalizers/replace.rs b/tokenizers/src/normalizers/replace.rs index cdd4a420a..2cd39c06f 100644 --- a/tokenizers/src/normalizers/replace.rs +++ b/tokenizers/src/normalizers/replace.rs @@ -2,8 +2,8 @@ use crate::tokenizer::pattern::Pattern; use crate::tokenizer::Decoder; use crate::tokenizer::{NormalizedString, Normalizer, Result}; use crate::utils::SysRegex; +use derive_more::Display; use serde::{Deserialize, Serialize}; - /// Represents the different patterns that `Replace` can use #[derive(Debug, Clone, PartialEq, Serialize, Deserialize, Eq)] pub enum ReplacePattern { @@ -42,8 +42,14 @@ impl std::convert::TryFrom for Replace { /// This normalizer will take a `pattern` (for now only a String) /// and replace every occurrence with `content`. -#[derive(Debug, Serialize, Deserialize)] +#[derive(Debug, Serialize, Deserialize, Display)] #[serde(tag = "type", try_from = "ReplaceDeserializer")] +#[display( + fmt = "Replace(pattern={:?}, content={}, regex={:?}", + pattern, + content, + regex +)] pub struct Replace { pattern: ReplacePattern, content: String, diff --git a/tokenizers/src/pre_tokenizers/byte_level.rs b/tokenizers/src/pre_tokenizers/byte_level.rs index 6343bbd07..4d2b131d6 100644 --- a/tokenizers/src/pre_tokenizers/byte_level.rs +++ b/tokenizers/src/pre_tokenizers/byte_level.rs @@ -1,13 +1,13 @@ use std::collections::{HashMap, HashSet}; -use crate::utils::SysRegex; -use serde::{Deserialize, Serialize}; - use crate::tokenizer::{ Decoder, Encoding, PostProcessor, PreTokenizedString, PreTokenizer, Result, SplitDelimiterBehavior, }; use crate::utils::macro_rules_attribute; +use crate::utils::SysRegex; +use derive_more::Display; +use serde::{Deserialize, Serialize}; /// Converts bytes to unicode characters. 
/// See https://github.com/openai/gpt-2/blob/master/src/encoder.py#L9 @@ -46,11 +46,17 @@ lazy_static! { bytes_char().into_iter().map(|(c, b)| (b, c)).collect(); } -#[derive(Copy, Clone, Debug, PartialEq, Eq)] /// Provides all the necessary steps to handle the BPE tokenization at the byte-level. Takes care /// of all the required processing steps to transform a UTF-8 string as needed before and after the /// BPE model does its job. #[macro_rules_attribute(impl_serde_type!)] +#[derive(Copy, Clone, Debug, PartialEq, Eq, Display)] +#[display( + fmt = "ByteLevel(add_prefix_space={},trim_offset={:?}, use_regex={}", + add_prefix_space, + trim_offsets, + use_regex +)] #[non_exhaustive] pub struct ByteLevel { /// Whether to add a leading space to the first word. This allows to treat the leading word diff --git a/tokenizers/src/pre_tokenizers/metaspace.rs b/tokenizers/src/pre_tokenizers/metaspace.rs index 9fd61c54b..40c59f54c 100644 --- a/tokenizers/src/pre_tokenizers/metaspace.rs +++ b/tokenizers/src/pre_tokenizers/metaspace.rs @@ -1,6 +1,6 @@ use crate::tokenizer::{Decoder, PreTokenizedString, PreTokenizer, Result, SplitDelimiterBehavior}; -use serde::{de, Deserialize, Deserializer, Serialize}; use derive_more::Display; +use serde::{de, Deserialize, Deserializer, Serialize}; /// Enum representing options for the metaspace prepending scheme. #[derive(Debug, Clone, PartialEq, Serialize, Eq, Deserialize, Copy)] #[serde(rename_all = "snake_case")] @@ -17,7 +17,12 @@ pub enum PrependScheme { /// Replaces all the whitespaces by the provided meta character and then /// splits on this character #[serde(tag = "type")] -#[display(fmt="Metaspace(replacement={}, prepend_scheme={:?}, split={})", replacement, prepend_scheme, split)] +#[display( + fmt = "Metaspace(replacement={}, prepend_scheme={:?}, split={})", + replacement, + prepend_scheme, + split +)] pub struct Metaspace { replacement: char, pub prepend_scheme: PrependScheme, diff --git a/tokenizers/src/processors/bert.rs b/tokenizers/src/processors/bert.rs index 627f9d180..bf0dbad07 100644 --- a/tokenizers/src/processors/bert.rs +++ b/tokenizers/src/processors/bert.rs @@ -1,10 +1,12 @@ use crate::tokenizer::{Encoding, PostProcessor, Result}; +use derive_more::Display; use serde::{Deserialize, Serialize}; use std::collections::HashMap; use std::iter::FromIterator; -#[derive(Serialize, Deserialize, Clone, Debug, PartialEq, Eq)] +#[derive(Serialize, Deserialize, Clone, Debug, PartialEq, Eq, Display)] #[serde(tag = "type")] +#[display(fmt = "BertProcessing(sep={:?}, cls={:?})", sep, cls)] pub struct BertProcessing { sep: (String, u32), cls: (String, u32), diff --git a/tokenizers/src/processors/mod.rs b/tokenizers/src/processors/mod.rs index 130a537ba..ed8a86aaa 100644 --- a/tokenizers/src/processors/mod.rs +++ b/tokenizers/src/processors/mod.rs @@ -6,17 +6,18 @@ pub mod template; // Re-export these as processors pub use super::pre_tokenizers::byte_level; -use serde::{Deserialize, Serialize}; - use crate::pre_tokenizers::byte_level::ByteLevel; use crate::processors::bert::BertProcessing; use crate::processors::roberta::RobertaProcessing; use crate::processors::sequence::Sequence; use crate::processors::template::TemplateProcessing; use crate::{Encoding, PostProcessor, Result}; +use derive_more::Display; +use serde::{Deserialize, Serialize}; -#[derive(Serialize, Deserialize, PartialEq, Debug, Clone, Eq)] +#[derive(Serialize, Deserialize, PartialEq, Debug, Clone, Eq, Display)] #[serde(untagged)] +#[display(fmt = "{}")] pub enum PostProcessorWrapper { // Roberta 
must be before Bert for deserialization (serde does not validate tags) Roberta(RobertaProcessing), diff --git a/tokenizers/src/processors/roberta.rs b/tokenizers/src/processors/roberta.rs index 3af9a8d60..08857adaf 100644 --- a/tokenizers/src/processors/roberta.rs +++ b/tokenizers/src/processors/roberta.rs @@ -1,10 +1,18 @@ use crate::processors::byte_level::process_offsets; use crate::tokenizer::{Encoding, PostProcessor, Result}; +use derive_more::Display; use serde::{Deserialize, Serialize}; use std::collections::HashMap; use std::iter::FromIterator; -#[derive(Serialize, Deserialize, Debug, Clone, PartialEq, Eq)] +#[derive(Serialize, Deserialize, Debug, Clone, PartialEq, Eq, Display)] +#[display( + fmt = "RobertaProcessing(sep={:?}, cls={:?}, trim_offsets={}, add_prefix_space={}", + sep, + cls, + trim_offsets, + add_prefix_space +)] #[serde(tag = "type")] pub struct RobertaProcessing { sep: (String, u32), diff --git a/tokenizers/src/processors/sequence.rs b/tokenizers/src/processors/sequence.rs index 66c670ad8..413b5f1cc 100644 --- a/tokenizers/src/processors/sequence.rs +++ b/tokenizers/src/processors/sequence.rs @@ -1,10 +1,11 @@ use crate::processors::PostProcessorWrapper; use crate::tokenizer::{Encoding, PostProcessor, Result}; use crate::utils::macro_rules_attribute; +use derive_more::Display; use serde::{Deserialize, Serialize}; - -#[derive(Clone, Debug, PartialEq, Eq)] #[macro_rules_attribute(impl_serde_type!)] +#[derive(Clone, Debug, PartialEq, Eq, Display)] +#[display(fmt = "{:?}", self)] pub struct Sequence { processors: Vec, } diff --git a/tokenizers/src/processors/template.rs b/tokenizers/src/processors/template.rs index c5aaa55db..2420d787e 100644 --- a/tokenizers/src/processors/template.rs +++ b/tokenizers/src/processors/template.rs @@ -56,12 +56,12 @@ //! [`TemplateProcessing`]: struct.TemplateProcessing.html //! use crate::{Encoding, PostProcessor, Result}; +use derive_more::Display; use itertools::Itertools; use serde::{Deserialize, Serialize}; use std::collections::{HashMap, HashSet}; use std::convert::{TryFrom, TryInto}; use std::result::Result as StdResult; - /// Represents any sequences received as input of the PostProcessor #[derive(Debug, Clone, PartialEq, Serialize, Deserialize, Eq)] pub enum Sequence { @@ -332,8 +332,9 @@ impl From> for Tokens { /// .unwrap(); /// ``` /// -#[derive(Debug, Clone, PartialEq, Builder, Serialize, Deserialize, Eq)] +#[derive(Debug, Clone, PartialEq, Builder, Serialize, Deserialize, Eq, Display)] #[serde(tag = "type", from = "TemplateProcessingDeserializer")] +#[display(fmt = "TemplateProcessing({:?})", self)] #[builder(build_fn(validate = "Self::validate"))] pub struct TemplateProcessing { #[builder(try_setter, default = "\"$0\".try_into().unwrap()")] diff --git a/tokenizers/src/tokenizer/mod.rs b/tokenizers/src/tokenizer/mod.rs index ebc68dfb1..0f1786cf5 100644 --- a/tokenizers/src/tokenizer/mod.rs +++ b/tokenizers/src/tokenizer/mod.rs @@ -18,12 +18,12 @@ use std::{ path::{Path, PathBuf}, }; -use serde::de::DeserializeOwned; -use serde::{Deserialize, Serialize}; - use crate::utils::iter::ResultShunt; use crate::utils::parallelism::*; use crate::utils::progress::{ProgressBar, ProgressStyle}; +use derive_more::Display; +use serde::de::DeserializeOwned; +use serde::{Deserialize, Serialize}; mod added_vocabulary; mod encoding; @@ -508,7 +508,18 @@ impl DerefMut for Tokenizer { pub struct TruncationParamError(String); /// A `Tokenizer` is capable of encoding/decoding any text. 
-#[derive(Clone, Debug)] +#[derive(Clone, Debug, Display)] +#[display( + fmt = "Tokenizer(normalizer={}, pre_tokenizer={}, model={}, post_processor={}, decoder={}, added_vocabulary={:?}, truncation={:?}, padding={:?}", + normalizer, + pre_tokenizer, + model, + post_processor, + decoder, + added_vocabulary, + truncation, + padding +)] pub struct TokenizerImpl { // Tokenizer parts normalizer: Option, From 4d9204e5e6ae2768ef27300dcc2fd987c3733b53 Mon Sep 17 00:00:00 2001 From: Arthur Zucker Date: Tue, 4 Jun 2024 09:03:29 +0200 Subject: [PATCH 10/94] current status --- bindings/python/src/normalizers.rs | 29 ++++++++-- tokenizers/Cargo.toml | 1 + tokenizers/display_derive/Cargo.toml | 13 +++++ tokenizers/display_derive/src/lib.rs | 57 ++++++++++++++++++++ tokenizers/src/normalizers/bert.rs | 3 +- tokenizers/src/normalizers/mod.rs | 6 ++- tokenizers/src/normalizers/prepend.rs | 3 +- tokenizers/src/normalizers/strip.rs | 6 +-- tokenizers/src/normalizers/unicode.rs | 11 ++-- tokenizers/src/normalizers/utils.rs | 6 +-- tokenizers/src/tokenizer/added_vocabulary.rs | 3 +- tokenizers/src/tokenizer/mod.rs | 14 +---- 12 files changed, 119 insertions(+), 33 deletions(-) create mode 100644 tokenizers/display_derive/Cargo.toml create mode 100644 tokenizers/display_derive/src/lib.rs diff --git a/bindings/python/src/normalizers.rs b/bindings/python/src/normalizers.rs index 35ac69e77..482e0cdd6 100644 --- a/bindings/python/src/normalizers.rs +++ b/bindings/python/src/normalizers.rs @@ -499,7 +499,7 @@ impl PyReplace { } } -#[derive(Debug, Clone)] +#[derive(Debug, Clone, Display)] pub(crate) struct CustomNormalizer { inner: PyObject, } @@ -542,10 +542,12 @@ impl<'de> Deserialize<'de> for CustomNormalizer { } } -#[derive(Debug, Clone, Deserialize)] +#[derive(Debug, Clone, Deserialize, Display)] #[serde(untagged)] pub(crate) enum PyNormalizerWrapper { + #[display(fmt="{}", "_0.inner")] Custom(CustomNormalizer), + #[display(fmt="{}", "_0")] Wrapped(NormalizerWrapper), } @@ -561,15 +563,32 @@ impl Serialize for PyNormalizerWrapper { } } -#[derive(Debug, Clone, Deserialize, Display)] +#[derive(Debug, Clone, Deserialize)] #[serde(untagged)] pub(crate) enum PyNormalizerTypeWrapper { - #[display(fmt="{}", "_0.iter().map(|arc| arc.as_ref().read().unwrap().to_string()).collect::>()")] Sequence(Vec>>), - #[display(fmt = "{}", self)] Single(Arc>), } +// Implement the Display trait for PyNormalizerTypeWrapper +impl std::fmt::Display for PyNormalizerTypeWrapper { + fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result { + match self { + PyNormalizerTypeWrapper::Sequence(decoders) => { + for decoder in decoders { + let decoder = decoder.as_ref().read().unwrap(); + writeln!(f, "{}", decoder)?; + } + Ok(()) + } + PyNormalizerTypeWrapper::Single(decoder) => { + let decoder = decoder.as_ref().read().unwrap(); + write!(f, "{}", decoder) + } + } + } +} + impl Serialize for PyNormalizerTypeWrapper { fn serialize(&self, serializer: S) -> Result where diff --git a/tokenizers/Cargo.toml b/tokenizers/Cargo.toml index a48064ac3..6b6a7f375 100644 --- a/tokenizers/Cargo.toml +++ b/tokenizers/Cargo.toml @@ -65,6 +65,7 @@ esaxx-rs = { version = "0.1.10", default-features = false, features=[]} monostate = "0.1.12" derive_more = "0.99.17" ellipse = "0.2.0" +display_derive = { path = "display_derive" } [features] default = ["progressbar", "onig", "esaxx_fast"] diff --git a/tokenizers/display_derive/Cargo.toml b/tokenizers/display_derive/Cargo.toml new file mode 100644 index 000000000..ec28e3c26 --- /dev/null +++ 
b/tokenizers/display_derive/Cargo.toml @@ -0,0 +1,13 @@ +[package] +name = "display_derive" +version = "0.1.0" +edition = "2021" + +[dependencies] +syn = "1.0" +quote = "1.0" +proc-macro2 = "1.0" +ellipse = "0.2.0" +utils = "0.0.3" +[lib] +proc-macro = true diff --git a/tokenizers/display_derive/src/lib.rs b/tokenizers/display_derive/src/lib.rs new file mode 100644 index 000000000..9b05196b9 --- /dev/null +++ b/tokenizers/display_derive/src/lib.rs @@ -0,0 +1,57 @@ +extern crate proc_macro; +use proc_macro::TokenStream; +use quote::quote; +use syn::{parse_macro_input, Data, DeriveInput, Fields}; +use utils::truncate_with_ellipsis; +#[proc_macro_derive(StructDisplay)] +pub fn display_derive(input: TokenStream) -> TokenStream { + // Parse the input tokens into a syntax tree + let input = parse_macro_input!(input as DeriveInput); + + // Get the name of the struct + let name = input.ident; + + // Generate code to match the struct's fields + let expanded = match input.data { + Data::Struct(data) => { + match data.fields { + Fields::Named(fields) => { + // If the struct has named fields + let field_names = fields.named.iter().map(|f| &f.ident); + let field_names2 = field_names.clone(); + quote! { + impl std::fmt::Display for #name { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!(f, "{}(", stringify!(#name))?; + #( + let value_str = self.#field_names2.to_string(); + let truncated_value_str = truncate_with_ellipsis(value_str, 10); + write!(f, "{}={}", stringify!(#field_names), truncated_value_str)?; + if stringify!(#field_names) != stringify!(#field_names2.clone().last().unwrap()) { + write!(f, ", ")?; + } + )* + write!(f, ")") + } + } + } + }, + Fields::Unit => { + // If the struct has no fields + quote! { + impl std::fmt::Display for #name { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!(f, "{}", stringify!(#name)) + } + } + } + }, + _ => unimplemented!(), + } + }, + _ => unimplemented!(), + }; + + // Convert into a token stream and return it + TokenStream::from(expanded) +} \ No newline at end of file diff --git a/tokenizers/src/normalizers/bert.rs b/tokenizers/src/normalizers/bert.rs index 90d982c68..9a1e68ed7 100644 --- a/tokenizers/src/normalizers/bert.rs +++ b/tokenizers/src/normalizers/bert.rs @@ -2,6 +2,7 @@ use crate::tokenizer::{NormalizedString, Normalizer, Result}; use serde::{Deserialize, Serialize}; use unicode_categories::UnicodeCategories; +use display_derive::StructDisplay; /// Checks whether a character is whitespace fn is_whitespace(c: char) -> bool { @@ -47,7 +48,7 @@ fn is_chinese_char(c: char) -> bool { ) } -#[derive(Copy, Clone, Debug, Deserialize, Serialize)] +#[derive(Copy, Clone, Debug, Deserialize, Serialize, StructDisplay)] #[serde(tag = "type")] #[non_exhaustive] pub struct BertNormalizer { diff --git a/tokenizers/src/normalizers/mod.rs b/tokenizers/src/normalizers/mod.rs index 8ac4c58ec..273eb154d 100644 --- a/tokenizers/src/normalizers/mod.rs +++ b/tokenizers/src/normalizers/mod.rs @@ -15,12 +15,13 @@ pub use crate::normalizers::unicode::{Nmt, NFC, NFD, NFKC, NFKD}; pub use crate::normalizers::utils::{Lowercase, Sequence}; use serde::{Deserialize, Serialize}; - +use derive_more::Display; use crate::{NormalizedString, Normalizer}; /// Wrapper for known Normalizers. 
-#[derive(Clone, Debug, Deserialize, Serialize)] +#[derive(Clone, Debug, Deserialize, Serialize, Display)] #[serde(untagged)] +#[display(fmt="{}")] pub enum NormalizerWrapper { BertNormalizer(BertNormalizer), StripNormalizer(Strip), @@ -32,6 +33,7 @@ pub enum NormalizerWrapper { Sequence(Sequence), Lowercase(Lowercase), Nmt(Nmt), + #[display(fmt="Precompiled")] Precompiled(Precompiled), Replace(Replace), Prepend(Prepend), diff --git a/tokenizers/src/normalizers/prepend.rs b/tokenizers/src/normalizers/prepend.rs index 4e318c259..7cdd7245d 100644 --- a/tokenizers/src/normalizers/prepend.rs +++ b/tokenizers/src/normalizers/prepend.rs @@ -1,7 +1,8 @@ use crate::tokenizer::{NormalizedString, Normalizer, Result}; +use display_derive::StructDisplay; use serde::{Deserialize, Serialize}; -#[derive(Clone, Debug, Deserialize, Serialize)] +#[derive(Clone, Debug, Deserialize, Serialize, StructDisplay)] #[serde(tag = "type")] pub struct Prepend { pub prepend: String, diff --git a/tokenizers/src/normalizers/strip.rs b/tokenizers/src/normalizers/strip.rs index 19f5ff314..265cc0cae 100644 --- a/tokenizers/src/normalizers/strip.rs +++ b/tokenizers/src/normalizers/strip.rs @@ -2,8 +2,8 @@ use crate::tokenizer::{NormalizedString, Normalizer, Result}; use crate::utils::macro_rules_attribute; use serde::{Deserialize, Serialize}; use unicode_normalization_alignments::char::is_combining_mark; - -#[derive(Copy, Clone, Debug, Deserialize, Serialize)] +use display_derive::StructDisplay; +#[derive(Copy, Clone, Debug, Deserialize, Serialize,StructDisplay)] #[serde(tag = "type")] #[non_exhaustive] pub struct Strip { @@ -43,7 +43,7 @@ impl Normalizer for Strip { // This normalizer removes combining marks from a normalized string // It's different from unidecode as it does not attempt to modify // non ascii languages. 
-#[derive(Copy, Clone, Debug)] +#[derive(Copy, Clone, Debug, StructDisplay)] #[macro_rules_attribute(impl_serde_type!)] pub struct StripAccents; diff --git a/tokenizers/src/normalizers/unicode.rs b/tokenizers/src/normalizers/unicode.rs index 502b4239b..9a1e657cd 100644 --- a/tokenizers/src/normalizers/unicode.rs +++ b/tokenizers/src/normalizers/unicode.rs @@ -1,7 +1,8 @@ use crate::tokenizer::{NormalizedString, Normalizer, Result}; use crate::utils::macro_rules_attribute; +use display_derive::StructDisplay; -#[derive(Default, Copy, Clone, Debug)] +#[derive(Default, Copy, Clone, Debug, StructDisplay)] #[macro_rules_attribute(impl_serde_type!)] pub struct NFD; impl Normalizer for NFD { @@ -11,7 +12,7 @@ impl Normalizer for NFD { } } -#[derive(Default, Copy, Clone, Debug)] +#[derive(Default, Copy, Clone, Debug, StructDisplay)] #[macro_rules_attribute(impl_serde_type!)] pub struct NFKD; impl Normalizer for NFKD { @@ -21,7 +22,7 @@ impl Normalizer for NFKD { } } -#[derive(Default, Copy, Clone, Debug)] +#[derive(Default, Copy, Clone, Debug, StructDisplay)] #[macro_rules_attribute(impl_serde_type!)] pub struct NFC; impl Normalizer for NFC { @@ -31,7 +32,7 @@ impl Normalizer for NFC { } } -#[derive(Default, Copy, Clone, Debug)] +#[derive(Default, Copy, Clone, Debug, StructDisplay)] #[macro_rules_attribute(impl_serde_type!)] pub struct NFKC; impl Normalizer for NFKC { @@ -72,7 +73,7 @@ fn do_nmt(normalized: &mut NormalizedString) { }); } -#[derive(Default, Copy, Clone, Debug)] +#[derive(Default, Copy, Clone, Debug, StructDisplay)] #[macro_rules_attribute(impl_serde_type!)] pub struct Nmt; impl Normalizer for Nmt { diff --git a/tokenizers/src/normalizers/utils.rs b/tokenizers/src/normalizers/utils.rs index a7730a3f8..c917e0dfa 100644 --- a/tokenizers/src/normalizers/utils.rs +++ b/tokenizers/src/normalizers/utils.rs @@ -3,8 +3,8 @@ use serde::{Deserialize, Serialize}; use crate::normalizers::NormalizerWrapper; use crate::tokenizer::{NormalizedString, Normalizer, Result}; use crate::utils::macro_rules_attribute; - -#[derive(Clone, Deserialize, Debug, Serialize)] +use display_derive::StructDisplay; +#[derive(Clone, Deserialize, Debug, Serialize, StructDisplay)] #[serde(tag = "type")] /// Allows concatenating multiple other Normalizer as a Sequence. /// All the normalizers run in sequence in the given order against the same NormalizedString. @@ -36,7 +36,7 @@ impl Normalizer for Sequence { } /// Lowercases the input -#[derive(Copy, Clone, Debug)] +#[derive(Copy, Clone, Debug, StructDisplay)] #[macro_rules_attribute(impl_serde_type!)] pub struct Lowercase; impl Normalizer for Lowercase { diff --git a/tokenizers/src/tokenizer/added_vocabulary.rs b/tokenizers/src/tokenizer/added_vocabulary.rs index 301d9bc81..f2d3c78ff 100644 --- a/tokenizers/src/tokenizer/added_vocabulary.rs +++ b/tokenizers/src/tokenizer/added_vocabulary.rs @@ -2,6 +2,7 @@ use super::{ normalizer::Range, Model, NormalizedString, Normalizer, Offsets, PreTokenizedString, Token, }; use aho_corasick::{AhoCorasick, AhoCorasickBuilder, MatchKind}; +use display_derive::StructDisplay; use regex::Regex; use serde::{ser::SerializeSeq, Deserialize, Serialize, Serializer}; use std::collections::{HashMap, HashSet}; @@ -138,7 +139,7 @@ fn space_rightmost_at_start(sentence: &str) -> usize { /// were to add new tokens after this training process, we couldn't make sure the merges pairs /// exist as required. 
/// -#[derive(Clone, Debug)] +#[derive(Clone, Debug, StructDisplay)] pub struct AddedVocabulary { /// Contains the mapping from String (token content) to ID. This map contains both special /// tokens and classic added tokens that were added to the this vocabulary. diff --git a/tokenizers/src/tokenizer/mod.rs b/tokenizers/src/tokenizer/mod.rs index 0f1786cf5..6bed8dc3f 100644 --- a/tokenizers/src/tokenizer/mod.rs +++ b/tokenizers/src/tokenizer/mod.rs @@ -22,6 +22,7 @@ use crate::utils::iter::ResultShunt; use crate::utils::parallelism::*; use crate::utils::progress::{ProgressBar, ProgressStyle}; use derive_more::Display; +use display_derive::StructDisplay; use serde::de::DeserializeOwned; use serde::{Deserialize, Serialize}; @@ -508,18 +509,7 @@ impl DerefMut for Tokenizer { pub struct TruncationParamError(String); /// A `Tokenizer` is capable of encoding/decoding any text. -#[derive(Clone, Debug, Display)] -#[display( - fmt = "Tokenizer(normalizer={}, pre_tokenizer={}, model={}, post_processor={}, decoder={}, added_vocabulary={:?}, truncation={:?}, padding={:?}", - normalizer, - pre_tokenizer, - model, - post_processor, - decoder, - added_vocabulary, - truncation, - padding -)] +#[derive(Clone, Debug, StructDisplay)] pub struct TokenizerImpl { // Tokenizer parts normalizer: Option, From 4c2aca1e96050d787f1611fe743d2e7f7ca445cf Mon Sep 17 00:00:00 2001 From: Arthur Zucker Date: Tue, 4 Jun 2024 10:12:52 +0200 Subject: [PATCH 11/94] working shit --- bindings/python/src/normalizers.rs | 4 ++-- tokenizers/display_derive/Cargo.toml | 2 +- tokenizers/display_derive/src/lib.rs | 21 ++++++++++++++++----- tokenizers/src/tokenizer/mod.rs | 5 +++-- 4 files changed, 22 insertions(+), 10 deletions(-) diff --git a/bindings/python/src/normalizers.rs b/bindings/python/src/normalizers.rs index 482e0cdd6..1e51ee563 100644 --- a/bindings/python/src/normalizers.rs +++ b/bindings/python/src/normalizers.rs @@ -43,8 +43,8 @@ impl PyNormalizedStringMut<'_> { /// This class is not supposed to be instantiated directly. Instead, any implementation of a /// Normalizer will return an instance of this class when instantiated. 
#[pyclass(dict, module = "tokenizers.normalizers", name = "Normalizer", subclass)] -#[derive(Clone, Serialize, Deserialize, Display)] -#[display(fmt = "{}", "normalizer")] +#[derive(Clone, Serialize, Deserialize, Display, Debug)] +#[display(fmt = "{}", normalizer)] pub struct PyNormalizer { #[serde(flatten)] pub(crate) normalizer: PyNormalizerTypeWrapper, diff --git a/tokenizers/display_derive/Cargo.toml b/tokenizers/display_derive/Cargo.toml index ec28e3c26..299b6ae5f 100644 --- a/tokenizers/display_derive/Cargo.toml +++ b/tokenizers/display_derive/Cargo.toml @@ -8,6 +8,6 @@ syn = "1.0" quote = "1.0" proc-macro2 = "1.0" ellipse = "0.2.0" -utils = "0.0.3" + [lib] proc-macro = true diff --git a/tokenizers/display_derive/src/lib.rs b/tokenizers/display_derive/src/lib.rs index 9b05196b9..2c58d54e0 100644 --- a/tokenizers/display_derive/src/lib.rs +++ b/tokenizers/display_derive/src/lib.rs @@ -2,7 +2,7 @@ extern crate proc_macro; use proc_macro::TokenStream; use quote::quote; use syn::{parse_macro_input, Data, DeriveInput, Fields}; -use utils::truncate_with_ellipsis; + #[proc_macro_derive(StructDisplay)] pub fn display_derive(input: TokenStream) -> TokenStream { // Parse the input tokens into a syntax tree @@ -23,13 +23,24 @@ pub fn display_derive(input: TokenStream) -> TokenStream { impl std::fmt::Display for #name { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { write!(f, "{}(", stringify!(#name))?; + let mut first = true; #( - let value_str = self.#field_names2.to_string(); - let truncated_value_str = truncate_with_ellipsis(value_str, 10); - write!(f, "{}={}", stringify!(#field_names), truncated_value_str)?; - if stringify!(#field_names) != stringify!(#field_names2.clone().last().unwrap()) { + if!first { write!(f, ", ")?; } + first = false; + + let field_value = &self.#field_names2; + write!(f, "{}=", stringify!(#field_names))?; + { + let s = format!("{:?}", field_value); + let mut chars = s.chars(); + let mut prefix = (&mut chars).take(10 - 1).collect::(); + if chars.next().is_some() { + prefix.push('…'); + } + write!(f, "{}", prefix)?; + } )* write!(f, ")") } diff --git a/tokenizers/src/tokenizer/mod.rs b/tokenizers/src/tokenizer/mod.rs index 6bed8dc3f..0489dd95f 100644 --- a/tokenizers/src/tokenizer/mod.rs +++ b/tokenizers/src/tokenizer/mod.rs @@ -22,7 +22,6 @@ use crate::utils::iter::ResultShunt; use crate::utils::parallelism::*; use crate::utils::progress::{ProgressBar, ProgressStyle}; use derive_more::Display; -use display_derive::StructDisplay; use serde::de::DeserializeOwned; use serde::{Deserialize, Serialize}; @@ -509,7 +508,9 @@ impl DerefMut for Tokenizer { pub struct TruncationParamError(String); /// A `Tokenizer` is capable of encoding/decoding any text. 
-#[derive(Clone, Debug, StructDisplay)] +#[derive(Clone, Debug, Display)] +#[display(bound = "M: Model, N: Normalizer + std::fmt::Display, PT: PreTokenizer + std::fmt::Display")] +#[display(fmt = "{} {} {}", "normalizer.as_ref().unwrap()", "pre_tokenizer.as_ref().unwrap()", "model")] pub struct TokenizerImpl { // Tokenizer parts normalizer: Option, From 904ce70a3c85ac46431d0ceb330b7b9573b67a1d Mon Sep 17 00:00:00 2001 From: Arthur Zucker Date: Tue, 4 Jun 2024 13:22:17 +0200 Subject: [PATCH 12/94] this kinda works --- bindings/python/src/decoders.rs | 15 +++++++++++++++ bindings/python/src/normalizers.rs | 17 ++++++++++++----- tokenizers/display_derive/src/lib.rs | 6 +++--- tokenizers/src/normalizers/bert.rs | 6 +++--- tokenizers/src/normalizers/utils.rs | 5 ++++- tokenizers/src/tokenizer/added_vocabulary.rs | 5 +++-- 6 files changed, 40 insertions(+), 14 deletions(-) diff --git a/bindings/python/src/decoders.rs b/bindings/python/src/decoders.rs index f824cf3ad..782055bc8 100644 --- a/bindings/python/src/decoders.rs +++ b/bindings/python/src/decoders.rs @@ -23,6 +23,16 @@ use tk::Decoder; use tokenizers as tk; use super::error::ToPyResult; +pub trait PythonStr{ + fn __str__(&self) -> PyResult; +} + +impl PythonStr for S{ + fn __str__(&self) -> PyResult{ + let s = format!("WOWOWO{}", self); + Ok(s.into()) + } +} /// Base class for all decoders /// @@ -30,6 +40,7 @@ use super::error::ToPyResult; /// a Decoder will return an instance of this class when instantiated. #[pyclass(dict, module = "tokenizers.decoders", name = "Decoder", subclass)] #[derive(Clone, Deserialize, Serialize, Display)] +#[display(fmt = "decoder.{}", decoder)] pub struct PyDecoder { #[serde(flatten)] pub(crate) decoder: PyDecoderWrapper, @@ -62,6 +73,10 @@ impl PyDecoder { }, }) } + + fn __str__(&self) -> PyResult { + Ok(format!("{}", self.decoder)) + } } impl Decoder for PyDecoder { diff --git a/bindings/python/src/normalizers.rs b/bindings/python/src/normalizers.rs index 1e51ee563..bda25f2b4 100644 --- a/bindings/python/src/normalizers.rs +++ b/bindings/python/src/normalizers.rs @@ -90,6 +90,7 @@ impl PyNormalizer { }, }) } + } impl Normalizer for PyNormalizer { @@ -167,6 +168,10 @@ impl PyNormalizer { ToPyResult(self.normalizer.normalize(&mut normalized)).into_py()?; Ok(normalized.get().to_owned()) } + + fn __str__(&self) -> PyResult{ + Ok(format!("{}", self.normalizer)) + } } macro_rules! 
getter { @@ -574,15 +579,17 @@ pub(crate) enum PyNormalizerTypeWrapper { impl std::fmt::Display for PyNormalizerTypeWrapper { fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result { match self { - PyNormalizerTypeWrapper::Sequence(decoders) => { + PyNormalizerTypeWrapper::Sequence(ref decoders) => { for decoder in decoders { - let decoder = decoder.as_ref().read().unwrap(); + let decoder = decoder.read().unwrap(); writeln!(f, "{}", decoder)?; + } - Ok(()) + writeln!(f, "]")?; + Ok(()) } - PyNormalizerTypeWrapper::Single(decoder) => { - let decoder = decoder.as_ref().read().unwrap(); + PyNormalizerTypeWrapper::Single(ref decoder) => { + let decoder = decoder.read().unwrap(); write!(f, "{}", decoder) } } diff --git a/tokenizers/display_derive/src/lib.rs b/tokenizers/display_derive/src/lib.rs index 2c58d54e0..04f62d34d 100644 --- a/tokenizers/display_derive/src/lib.rs +++ b/tokenizers/display_derive/src/lib.rs @@ -25,7 +25,7 @@ pub fn display_derive(input: TokenStream) -> TokenStream { write!(f, "{}(", stringify!(#name))?; let mut first = true; #( - if!first { + if !first { write!(f, ", ")?; } first = false; @@ -33,9 +33,9 @@ pub fn display_derive(input: TokenStream) -> TokenStream { let field_value = &self.#field_names2; write!(f, "{}=", stringify!(#field_names))?; { - let s = format!("{:?}", field_value); + let s = format!("{}", field_value); let mut chars = s.chars(); - let mut prefix = (&mut chars).take(10 - 1).collect::(); + let mut prefix = (&mut chars).take(100000 - 1).collect::(); if chars.next().is_some() { prefix.push('…'); } diff --git a/tokenizers/src/normalizers/bert.rs b/tokenizers/src/normalizers/bert.rs index 9a1e68ed7..41f60dcf3 100644 --- a/tokenizers/src/normalizers/bert.rs +++ b/tokenizers/src/normalizers/bert.rs @@ -2,8 +2,7 @@ use crate::tokenizer::{NormalizedString, Normalizer, Result}; use serde::{Deserialize, Serialize}; use unicode_categories::UnicodeCategories; -use display_derive::StructDisplay; - +use derive_more::Display; /// Checks whether a character is whitespace fn is_whitespace(c: char) -> bool { // These are technically control characters but we count them as whitespace @@ -48,7 +47,8 @@ fn is_chinese_char(c: char) -> bool { ) } -#[derive(Copy, Clone, Debug, Deserialize, Serialize, StructDisplay)] +#[derive(Copy, Clone, Debug, Deserialize, Serialize, Display)] +#[display(fmt="BertNormalizer(clean_text={}, handle_chinese_chars={}, strip_accents={:?}, lower_case={})",clean_text, handle_chinese_chars, strip_accents, lowercase)] #[serde(tag = "type")] #[non_exhaustive] pub struct BertNormalizer { diff --git a/tokenizers/src/normalizers/utils.rs b/tokenizers/src/normalizers/utils.rs index c917e0dfa..f9a505fac 100644 --- a/tokenizers/src/normalizers/utils.rs +++ b/tokenizers/src/normalizers/utils.rs @@ -3,8 +3,10 @@ use serde::{Deserialize, Serialize}; use crate::normalizers::NormalizerWrapper; use crate::tokenizer::{NormalizedString, Normalizer, Result}; use crate::utils::macro_rules_attribute; +use derive_more::Display; use display_derive::StructDisplay; -#[derive(Clone, Deserialize, Debug, Serialize, StructDisplay)] +#[derive(Clone, Deserialize, Debug, Serialize, Display)] +#[display(fmt="{}", "normalizers.iter().map(|n| n.to_string()).collect::()")] #[serde(tag = "type")] /// Allows concatenating multiple other Normalizer as a Sequence. /// All the normalizers run in sequence in the given order against the same NormalizedString. 
@@ -39,6 +41,7 @@ impl Normalizer for Sequence { #[derive(Copy, Clone, Debug, StructDisplay)] #[macro_rules_attribute(impl_serde_type!)] pub struct Lowercase; + impl Normalizer for Lowercase { fn normalize(&self, normalized: &mut NormalizedString) -> Result<()> { normalized.lowercase(); diff --git a/tokenizers/src/tokenizer/added_vocabulary.rs b/tokenizers/src/tokenizer/added_vocabulary.rs index f2d3c78ff..b49bef31f 100644 --- a/tokenizers/src/tokenizer/added_vocabulary.rs +++ b/tokenizers/src/tokenizer/added_vocabulary.rs @@ -2,7 +2,7 @@ use super::{ normalizer::Range, Model, NormalizedString, Normalizer, Offsets, PreTokenizedString, Token, }; use aho_corasick::{AhoCorasick, AhoCorasickBuilder, MatchKind}; -use display_derive::StructDisplay; +use derive_more::Display; use regex::Regex; use serde::{ser::SerializeSeq, Deserialize, Serialize, Serializer}; use std::collections::{HashMap, HashSet}; @@ -139,7 +139,8 @@ fn space_rightmost_at_start(sentence: &str) -> usize { /// were to add new tokens after this training process, we couldn't make sure the merges pairs /// exist as required. /// -#[derive(Clone, Debug, StructDisplay)] +#[derive(Clone, Debug, Display)] +#[display(fmt="{:?}, encode_special_tokens={}", added_tokens_map, encode_special_tokens)] pub struct AddedVocabulary { /// Contains the mapping from String (token content) to ID. This map contains both special /// tokens and classic added tokens that were added to the this vocabulary. From 6413810b6e961f4740ea0471c7ea4b049859e424 Mon Sep 17 00:00:00 2001 From: Arthur Zucker Date: Tue, 4 Jun 2024 14:10:45 +0200 Subject: [PATCH 13/94] finallllly! --- bindings/python/src/decoders.rs | 26 +++++--------- bindings/python/src/tokenizer.rs | 4 +++ tokenizers/src/tokenizer/mod.rs | 57 ++++++++++++++++++++++++++++-- tokenizers/src/utils/padding.rs | 9 +++-- tokenizers/src/utils/truncation.rs | 4 ++- 5 files changed, 76 insertions(+), 24 deletions(-) diff --git a/bindings/python/src/decoders.rs b/bindings/python/src/decoders.rs index 782055bc8..bac3b5cb1 100644 --- a/bindings/python/src/decoders.rs +++ b/bindings/python/src/decoders.rs @@ -23,16 +23,6 @@ use tk::Decoder; use tokenizers as tk; use super::error::ToPyResult; -pub trait PythonStr{ - fn __str__(&self) -> PyResult; -} - -impl PythonStr for S{ - fn __str__(&self) -> PyResult{ - let s = format!("WOWOWO{}", self); - Ok(s.into()) - } -} /// Base class for all decoders /// @@ -40,7 +30,7 @@ impl PythonStr for S{ /// a Decoder will return an instance of this class when instantiated. #[pyclass(dict, module = "tokenizers.decoders", name = "Decoder", subclass)] #[derive(Clone, Deserialize, Serialize, Display)] -#[display(fmt = "decoder.{}", decoder)] +#[display(fmt = "{}", decoder)] pub struct PyDecoder { #[serde(flatten)] pub(crate) decoder: PyDecoderWrapper, @@ -73,10 +63,6 @@ impl PyDecoder { }, }) } - - fn __str__(&self) -> PyResult { - Ok(format!("{}", self.decoder)) - } } impl Decoder for PyDecoder { @@ -130,6 +116,10 @@ impl PyDecoder { fn decode(&self, tokens: Vec) -> PyResult { ToPyResult(self.decoder.decode(tokens)).into() } + + fn __str__(&self) -> PyResult { + Ok(format!("{}", self.decoder)) + } } macro_rules! 
getter { @@ -496,7 +486,7 @@ impl PySequenceDecoder { #[derive(Clone, Display)] pub(crate) struct CustomDecoder { - inner: PyObject, + pub inner: PyObject, } impl CustomDecoder { @@ -550,9 +540,9 @@ impl<'de> Deserialize<'de> for CustomDecoder { #[derive(Clone, Deserialize, Serialize, Display)] #[serde(untagged)] pub(crate) enum PyDecoderWrapper { - #[display(fmt = "{}", self)] + #[display(fmt = "{}", "_0.as_ref().read().unwrap().inner")] Custom(Arc>), - #[display(fmt = "{}", self)] + #[display(fmt = "{}", "_0.as_ref().read().unwrap()")] Wrapped(Arc>), } diff --git a/bindings/python/src/tokenizer.rs b/bindings/python/src/tokenizer.rs index b44c344ae..d7469c1dc 100644 --- a/bindings/python/src/tokenizer.rs +++ b/bindings/python/src/tokenizer.rs @@ -1410,6 +1410,10 @@ impl PyTokenizer { fn set_decoder(&mut self, decoder: PyRef) { self.tokenizer.with_decoder(decoder.clone()); } + + fn __str__(&self) -> PyResult{ + Ok(format!("{}", self.tokenizer)) + } } #[cfg(test)] diff --git a/tokenizers/src/tokenizer/mod.rs b/tokenizers/src/tokenizer/mod.rs index 0489dd95f..672fbde2a 100644 --- a/tokenizers/src/tokenizer/mod.rs +++ b/tokenizers/src/tokenizer/mod.rs @@ -508,9 +508,7 @@ impl DerefMut for Tokenizer { pub struct TruncationParamError(String); /// A `Tokenizer` is capable of encoding/decoding any text. -#[derive(Clone, Debug, Display)] -#[display(bound = "M: Model, N: Normalizer + std::fmt::Display, PT: PreTokenizer + std::fmt::Display")] -#[display(fmt = "{} {} {}", "normalizer.as_ref().unwrap()", "pre_tokenizer.as_ref().unwrap()", "model")] +#[derive(Clone, Debug)] pub struct TokenizerImpl { // Tokenizer parts normalizer: Option, @@ -527,6 +525,59 @@ pub struct TokenizerImpl { padding: Option, } +impl std::fmt::Display for TokenizerImpl +where + M: std::fmt::Display, + N: std::fmt::Display, + PT: std::fmt::Display, + PP: std::fmt::Display, + D: std::fmt::Display, +{ + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + let normalizer_str = match &self.normalizer { + Some(n) => format!("{}", n), + None => "None".to_string(), + }; + println!("Normalizer"); + let pre_tokenizer_str = match &self.pre_tokenizer { + Some(pt) => format!("{}", pt), + None => "".to_string(), + }; + println!("PreTok"); + let post_processor_str = match &self.post_processor { + Some(pp) => format!("{}", pp), + None => "None".to_string(), + }; + println!("decoder"); + let decoder_str = match &self.decoder { + Some(d) => format!("{}", d), + None => "None".to_string(), + }; + println!("truncation"); + let truncation_str = match &self.truncation { + Some(t) => format!("{}", t), + None => "None".to_string(), + }; + println!("padding"); + let padding_str = match &self.padding { + Some(p) => format!("{}", p), + None => "None".to_string(), + }; + + write!( + f, + "{} {} {} {} {} {} {} {}", + normalizer_str, + pre_tokenizer_str, + self.model.to_string(), + post_processor_str, + decoder_str, + self.added_vocabulary.to_string(), + truncation_str, + padding_str + ) + } +} impl TokenizerImpl where M: Model, diff --git a/tokenizers/src/utils/padding.rs b/tokenizers/src/utils/padding.rs index 39585a304..eded5ad0a 100644 --- a/tokenizers/src/utils/padding.rs +++ b/tokenizers/src/utils/padding.rs @@ -1,6 +1,7 @@ use crate::parallelism::*; use crate::tokenizer::{Encoding, Result}; use serde::{Deserialize, Serialize}; +use derive_more::Display; /// The various possible padding directions. 
#[derive(Debug, Clone, Copy, Serialize, Deserialize)] @@ -18,7 +19,10 @@ impl std::convert::AsRef for PaddingDirection { } } -#[derive(Debug, Clone, Serialize, Deserialize)] + +#[derive(Debug, Clone, Serialize, Deserialize, Display)] +// #[display(fmt="Strategy: {:?}, Direction: {:?}, Pad to multiple of: {:?}, Pad ID: {}, Pad Type ID: {}, Pad Token: {}", strategy, direction, pad_to_multiple_of, pad_id, pad_type_id, pad_token)] +#[display(fmt="Strategy:")] pub struct PaddingParams { pub strategy: PaddingStrategy, pub direction: PaddingDirection, @@ -41,7 +45,8 @@ impl Default for PaddingParams { } } -#[derive(Debug, Clone, Serialize, Deserialize)] +#[derive(Debug, Clone, Serialize, Deserialize, Display)] +#[display(fmt={})] pub enum PaddingStrategy { BatchLongest, Fixed(usize), diff --git a/tokenizers/src/utils/truncation.rs b/tokenizers/src/utils/truncation.rs index a8ad2a614..870fbc503 100644 --- a/tokenizers/src/utils/truncation.rs +++ b/tokenizers/src/utils/truncation.rs @@ -1,4 +1,5 @@ use crate::tokenizer::{Encoding, Result}; +use derive_more::Display; use serde::{Deserialize, Serialize}; use std::cmp; use std::mem; @@ -19,7 +20,8 @@ impl std::convert::AsRef for TruncationDirection { } } -#[derive(Debug, Clone, Serialize, Deserialize)] +#[derive(Debug, Clone, Serialize, Deserialize, Display)] +#[display(fmt="direction={:?}, max_length={}, strategy={:?}, stride={}", direction, max_length, strategy, stride)] pub struct TruncationParams { #[serde(default)] pub direction: TruncationDirection, From fda66f59f6e935d64f8b21d454bef379d0907c0c Mon Sep 17 00:00:00 2001 From: Arthur Zucker Date: Tue, 4 Jun 2024 14:29:44 +0200 Subject: [PATCH 14/94] nits --- bindings/python/src/tokenizer.rs | 4 ++++ tokenizers/src/tokenizer/mod.rs | 10 ++-------- 2 files changed, 6 insertions(+), 8 deletions(-) diff --git a/bindings/python/src/tokenizer.rs b/bindings/python/src/tokenizer.rs index d7469c1dc..637d34f14 100644 --- a/bindings/python/src/tokenizer.rs +++ b/bindings/python/src/tokenizer.rs @@ -1414,6 +1414,10 @@ impl PyTokenizer { fn __str__(&self) -> PyResult{ Ok(format!("{}", self.tokenizer)) } + + fn __repr__(&self) -> PyResult{ + Ok(format!("{}", self.tokenizer)) + } } #[cfg(test)] diff --git a/tokenizers/src/tokenizer/mod.rs b/tokenizers/src/tokenizer/mod.rs index 672fbde2a..da402cafb 100644 --- a/tokenizers/src/tokenizer/mod.rs +++ b/tokenizers/src/tokenizer/mod.rs @@ -21,7 +21,6 @@ use std::{ use crate::utils::iter::ResultShunt; use crate::utils::parallelism::*; use crate::utils::progress::{ProgressBar, ProgressStyle}; -use derive_more::Display; use serde::de::DeserializeOwned; use serde::{Deserialize, Serialize}; @@ -538,27 +537,22 @@ where Some(n) => format!("{}", n), None => "None".to_string(), }; - println!("Normalizer"); let pre_tokenizer_str = match &self.pre_tokenizer { Some(pt) => format!("{}", pt), - None => "".to_string(), + None => "None".to_string(), }; - println!("PreTok"); let post_processor_str = match &self.post_processor { Some(pp) => format!("{}", pp), None => "None".to_string(), }; - println!("decoder"); let decoder_str = match &self.decoder { Some(d) => format!("{}", d), None => "None".to_string(), }; - println!("truncation"); let truncation_str = match &self.truncation { Some(t) => format!("{}", t), None => "None".to_string(), }; - println!("padding"); let padding_str = match &self.padding { Some(p) => format!("{}", p), None => "None".to_string(), @@ -566,7 +560,7 @@ where write!( f, - "{} {} {} {} {} {} {} {}", + 
"Tokenizer(normalizer={},\npre_tokenizer={},\nmodel={},\npost_processor={},\ndecoder={},\nadded_vocab={},\ntruncation={},\npadding={}\n)", normalizer_str, pre_tokenizer_str, self.model.to_string(), From 20c9fc4fe8d90911d3c9437d306689bcf673d4f2 Mon Sep 17 00:00:00 2001 From: Arthur Zucker Date: Tue, 4 Jun 2024 14:36:59 +0200 Subject: [PATCH 15/94] updates --- tokenizers/src/decoders/sequence.rs | 2 +- tokenizers/src/processors/sequence.rs | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/tokenizers/src/decoders/sequence.rs b/tokenizers/src/decoders/sequence.rs index a69ca6151..efa330f51 100644 --- a/tokenizers/src/decoders/sequence.rs +++ b/tokenizers/src/decoders/sequence.rs @@ -6,7 +6,7 @@ use serde::{Deserialize, Serialize}; #[macro_rules_attribute(impl_serde_type!)] #[derive(Clone, Debug, Display)] -#[display(fmt = "{:?}", "decoders")] +#[display(fmt = "[{}]", "decoders.iter().map(|d| d.to_string()).collect::>().join(\", \")")] pub struct Sequence { decoders: Vec, } diff --git a/tokenizers/src/processors/sequence.rs b/tokenizers/src/processors/sequence.rs index 413b5f1cc..76c8d0d1e 100644 --- a/tokenizers/src/processors/sequence.rs +++ b/tokenizers/src/processors/sequence.rs @@ -5,7 +5,7 @@ use derive_more::Display; use serde::{Deserialize, Serialize}; #[macro_rules_attribute(impl_serde_type!)] #[derive(Clone, Debug, PartialEq, Eq, Display)] -#[display(fmt = "{:?}", self)] +#[display(fmt = "[{}]", "processors.iter().map(|d| d.to_string()).collect::>().join(\", \")")] pub struct Sequence { processors: Vec, } From 86c77b6eaa476d87cfd12c2269f806ee3961b065 Mon Sep 17 00:00:00 2001 From: Arthur Zucker Date: Tue, 4 Jun 2024 14:50:55 +0200 Subject: [PATCH 16/94] almost there --- bindings/python/src/decoders.rs | 4 ++++ bindings/python/src/models.rs | 3 +++ bindings/python/src/normalizers.rs | 5 ++++- bindings/python/src/pre_tokenizers.rs | 16 ++++++++++++---- tokenizers/src/normalizers/utils.rs | 2 +- 5 files changed, 24 insertions(+), 6 deletions(-) diff --git a/bindings/python/src/decoders.rs b/bindings/python/src/decoders.rs index bac3b5cb1..3d88c15d1 100644 --- a/bindings/python/src/decoders.rs +++ b/bindings/python/src/decoders.rs @@ -120,6 +120,10 @@ impl PyDecoder { fn __str__(&self) -> PyResult { Ok(format!("{}", self.decoder)) } + + fn __repr__(&self) -> PyResult { + Ok(format!("{}", self.decoder)) + } } macro_rules! getter { diff --git a/bindings/python/src/models.rs b/bindings/python/src/models.rs index 3eade26bb..4c3a9cdd4 100644 --- a/bindings/python/src/models.rs +++ b/bindings/python/src/models.rs @@ -224,6 +224,9 @@ impl PyModel { fn __str__(&self) -> PyResult { Ok(format!("{}", self.model.read().unwrap())) } + fn __repr__(&self) -> PyResult { + Ok(format!("{}", self.model.read().unwrap())) + } } /// An implementation of the BPE (Byte-Pair Encoding) algorithm diff --git a/bindings/python/src/normalizers.rs b/bindings/python/src/normalizers.rs index bda25f2b4..762e60f3b 100644 --- a/bindings/python/src/normalizers.rs +++ b/bindings/python/src/normalizers.rs @@ -172,6 +172,9 @@ impl PyNormalizer { fn __str__(&self) -> PyResult{ Ok(format!("{}", self.normalizer)) } + fn __repr__(&self) -> PyResult{ + Ok(format!("{}", self.normalizer)) + } } macro_rules! 
getter { @@ -585,7 +588,7 @@ impl std::fmt::Display for PyNormalizerTypeWrapper { writeln!(f, "{}", decoder)?; } - writeln!(f, "]")?; + writeln!(f, "?????")?; Ok(()) } PyNormalizerTypeWrapper::Single(ref decoder) => { diff --git a/bindings/python/src/pre_tokenizers.rs b/bindings/python/src/pre_tokenizers.rs index 3e6289a7d..e9e749589 100644 --- a/bindings/python/src/pre_tokenizers.rs +++ b/bindings/python/src/pre_tokenizers.rs @@ -35,7 +35,7 @@ use derive_more::Display; subclass )] #[derive(Clone, Serialize, Deserialize, Display)] -#[display(fmt = "PreTokenizer(pretok={}", "pretok")] +#[display(fmt = "PreTokenizer(pretok={})", pretok)] pub struct PyPreTokenizer { #[serde(flatten)] pub(crate) pretok: PyPreTokenizerTypeWrapper, @@ -182,6 +182,14 @@ impl PyPreTokenizer { .map(|(s, o, _)| (s.to_owned(), o)) .collect()) } + + fn __str__(&self) -> PyResult{ + Ok(format!("{}", self.pretok)) + } + + fn __repr__(&self) -> PyResult{ + Ok(format!("{}", self.pretok)) + } } macro_rules! getter { @@ -653,11 +661,11 @@ impl Serialize for PyPreTokenizerWrapper { #[derive(Clone, Deserialize, Display)] #[serde(untagged)] -#[display(fmt = "PreTokenizer.{}")] +// #[display(fmt = "")] pub(crate) enum PyPreTokenizerTypeWrapper { - #[display(fmt = "A")] + #[display(fmt = "[{}]", "_0.iter().map(|d| d.as_ref().read().unwrap().to_string()).collect::>().join(\", \")")] Sequence(Vec>>), - #[display(fmt = "B")] + #[display(fmt = "_0.as_ref().read().unwrap()")] Single(Arc>), } diff --git a/tokenizers/src/normalizers/utils.rs b/tokenizers/src/normalizers/utils.rs index f9a505fac..e9b126569 100644 --- a/tokenizers/src/normalizers/utils.rs +++ b/tokenizers/src/normalizers/utils.rs @@ -6,7 +6,7 @@ use crate::utils::macro_rules_attribute; use derive_more::Display; use display_derive::StructDisplay; #[derive(Clone, Deserialize, Debug, Serialize, Display)] -#[display(fmt="{}", "normalizers.iter().map(|n| n.to_string()).collect::()")] +#[display(fmt = "[{}]", "normalizers.iter().map(|d| d.to_string()).collect::>().join(\", \")")] #[serde(tag = "type")] /// Allows concatenating multiple other Normalizer as a Sequence. /// All the normalizers run in sequence in the given order against the same NormalizedString. 
From a4296421c96797e06d3dd765e3f1119ab9670462 Mon Sep 17 00:00:00 2001 From: Arthur Zucker Date: Tue, 4 Jun 2024 16:07:43 +0200 Subject: [PATCH 17/94] update updates --- bindings/python/src/pre_tokenizers.rs | 7 ++++--- tokenizers/display_derive/src/lib.rs | 2 +- tokenizers/src/normalizers/mod.rs | 4 ++-- tokenizers/src/pre_tokenizers/bert.rs | 3 ++- tokenizers/src/pre_tokenizers/delimiter.rs | 3 ++- tokenizers/src/pre_tokenizers/digits.rs | 3 ++- tokenizers/src/pre_tokenizers/metaspace.rs | 6 +++--- tokenizers/src/pre_tokenizers/mod.rs | 4 +++- tokenizers/src/pre_tokenizers/punctuation.rs | 3 ++- tokenizers/src/pre_tokenizers/sequence.rs | 4 +++- tokenizers/src/pre_tokenizers/split.rs | 8 +++++--- .../src/pre_tokenizers/unicode_scripts/pre_tokenizer.rs | 4 +++- tokenizers/src/pre_tokenizers/whitespace.rs | 5 +++-- tokenizers/src/tokenizer/normalizer.rs | 5 +++-- 14 files changed, 38 insertions(+), 23 deletions(-) diff --git a/bindings/python/src/pre_tokenizers.rs b/bindings/python/src/pre_tokenizers.rs index e9e749589..cb4f92ff1 100644 --- a/bindings/python/src/pre_tokenizers.rs +++ b/bindings/python/src/pre_tokenizers.rs @@ -596,7 +596,7 @@ impl PyUnicodeScripts { } } -#[derive(Clone)] +#[derive(Clone, Display)] pub(crate) struct CustomPreTokenizer { inner: PyObject, } @@ -640,7 +640,8 @@ impl<'de> Deserialize<'de> for CustomPreTokenizer { } } -#[derive(Clone, Deserialize)] +#[derive(Clone, Deserialize, Display)] +#[display(fmt="{}")] #[serde(untagged)] pub(crate) enum PyPreTokenizerWrapper { Custom(CustomPreTokenizer), @@ -665,7 +666,7 @@ impl Serialize for PyPreTokenizerWrapper { pub(crate) enum PyPreTokenizerTypeWrapper { #[display(fmt = "[{}]", "_0.iter().map(|d| d.as_ref().read().unwrap().to_string()).collect::>().join(\", \")")] Sequence(Vec>>), - #[display(fmt = "_0.as_ref().read().unwrap()")] + #[display(fmt ="{}", "_0.as_ref().read().unwrap()")] Single(Arc>), } diff --git a/tokenizers/display_derive/src/lib.rs b/tokenizers/display_derive/src/lib.rs index 04f62d34d..e3d867e27 100644 --- a/tokenizers/display_derive/src/lib.rs +++ b/tokenizers/display_derive/src/lib.rs @@ -52,7 +52,7 @@ pub fn display_derive(input: TokenStream) -> TokenStream { quote! { impl std::fmt::Display for #name { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - write!(f, "{}", stringify!(#name)) + write!(f, "{}()", stringify!(#name)) } } } diff --git a/tokenizers/src/normalizers/mod.rs b/tokenizers/src/normalizers/mod.rs index 273eb154d..e4675020a 100644 --- a/tokenizers/src/normalizers/mod.rs +++ b/tokenizers/src/normalizers/mod.rs @@ -21,7 +21,7 @@ use crate::{NormalizedString, Normalizer}; /// Wrapper for known Normalizers. 
#[derive(Clone, Debug, Deserialize, Serialize, Display)] #[serde(untagged)] -#[display(fmt="{}")] +#[display(fmt="normalizers.{}")] pub enum NormalizerWrapper { BertNormalizer(BertNormalizer), StripNormalizer(Strip), @@ -33,7 +33,7 @@ pub enum NormalizerWrapper { Sequence(Sequence), Lowercase(Lowercase), Nmt(Nmt), - #[display(fmt="Precompiled")] + #[display(fmt="Precompiled()")] Precompiled(Precompiled), Replace(Replace), Prepend(Prepend), diff --git a/tokenizers/src/pre_tokenizers/bert.rs b/tokenizers/src/pre_tokenizers/bert.rs index 93fdd05c1..eeaa0c315 100644 --- a/tokenizers/src/pre_tokenizers/bert.rs +++ b/tokenizers/src/pre_tokenizers/bert.rs @@ -1,12 +1,13 @@ use crate::tokenizer::{PreTokenizedString, PreTokenizer, Result, SplitDelimiterBehavior}; use crate::utils::macro_rules_attribute; +use display_derive::StructDisplay; use unicode_categories::UnicodeCategories; fn is_bert_punc(x: char) -> bool { char::is_ascii_punctuation(&x) || x.is_punctuation() } -#[derive(Copy, Clone, Debug, PartialEq, Eq)] +#[derive(Copy, Clone, Debug, PartialEq, Eq, StructDisplay)] #[macro_rules_attribute(impl_serde_type!)] pub struct BertPreTokenizer; diff --git a/tokenizers/src/pre_tokenizers/delimiter.rs b/tokenizers/src/pre_tokenizers/delimiter.rs index 64ef63ccc..e58628f71 100644 --- a/tokenizers/src/pre_tokenizers/delimiter.rs +++ b/tokenizers/src/pre_tokenizers/delimiter.rs @@ -1,9 +1,10 @@ +use display_derive::StructDisplay; use serde::{Deserialize, Serialize}; use crate::tokenizer::{PreTokenizedString, PreTokenizer, Result, SplitDelimiterBehavior}; use crate::utils::macro_rules_attribute; -#[derive(Copy, Clone, Debug, PartialEq, Eq)] +#[derive(Copy, Clone, Debug, PartialEq, Eq,StructDisplay)] #[non_exhaustive] #[macro_rules_attribute(impl_serde_type!)] pub struct CharDelimiterSplit { diff --git a/tokenizers/src/pre_tokenizers/digits.rs b/tokenizers/src/pre_tokenizers/digits.rs index 942e2521b..5fb2a5b41 100644 --- a/tokenizers/src/pre_tokenizers/digits.rs +++ b/tokenizers/src/pre_tokenizers/digits.rs @@ -1,9 +1,10 @@ +use display_derive::StructDisplay; use serde::{Deserialize, Serialize}; use crate::tokenizer::{PreTokenizedString, PreTokenizer, Result, SplitDelimiterBehavior}; use crate::utils::macro_rules_attribute; -#[derive(Clone, Debug, PartialEq, Eq)] +#[derive(Clone, Debug, PartialEq, Eq, StructDisplay)] /// Pre tokenizes the numbers into single tokens. If individual_digits is set /// to true, then all digits are splitted into individual tokens. #[non_exhaustive] diff --git a/tokenizers/src/pre_tokenizers/metaspace.rs b/tokenizers/src/pre_tokenizers/metaspace.rs index 40c59f54c..6caf3e373 100644 --- a/tokenizers/src/pre_tokenizers/metaspace.rs +++ b/tokenizers/src/pre_tokenizers/metaspace.rs @@ -2,7 +2,7 @@ use crate::tokenizer::{Decoder, PreTokenizedString, PreTokenizer, Result, SplitD use derive_more::Display; use serde::{de, Deserialize, Deserializer, Serialize}; /// Enum representing options for the metaspace prepending scheme. -#[derive(Debug, Clone, PartialEq, Serialize, Eq, Deserialize, Copy)] +#[derive(Debug, Clone, PartialEq, Serialize, Eq, Deserialize, Copy, Display)] #[serde(rename_all = "snake_case")] pub enum PrependScheme { /// Specifies that the scheme should be prepended only once, on the first split. 
@@ -18,9 +18,9 @@ pub enum PrependScheme { /// splits on this character #[serde(tag = "type")] #[display( - fmt = "Metaspace(replacement={}, prepend_scheme={:?}, split={})", + fmt = "Metaspace(replacement='{}', prepend_scheme={:?}, split={})", replacement, - prepend_scheme, + "prepend_scheme.to_string().to_lowercase()", split )] pub struct Metaspace { diff --git a/tokenizers/src/pre_tokenizers/mod.rs b/tokenizers/src/pre_tokenizers/mod.rs index cf64fb876..6bb82024b 100644 --- a/tokenizers/src/pre_tokenizers/mod.rs +++ b/tokenizers/src/pre_tokenizers/mod.rs @@ -22,8 +22,10 @@ use crate::pre_tokenizers::split::Split; use crate::pre_tokenizers::unicode_scripts::UnicodeScripts; use crate::pre_tokenizers::whitespace::{Whitespace, WhitespaceSplit}; use crate::{PreTokenizedString, PreTokenizer}; +use derive_more::Display; -#[derive(Deserialize, Serialize, Clone, Debug, PartialEq)] +#[derive(Deserialize, Serialize, Clone, Debug, PartialEq, Display)] +#[display(fmt="{}")] #[serde(untagged)] pub enum PreTokenizerWrapper { BertPreTokenizer(BertPreTokenizer), diff --git a/tokenizers/src/pre_tokenizers/punctuation.rs b/tokenizers/src/pre_tokenizers/punctuation.rs index 0ba7d6025..4421f246b 100644 --- a/tokenizers/src/pre_tokenizers/punctuation.rs +++ b/tokenizers/src/pre_tokenizers/punctuation.rs @@ -1,3 +1,4 @@ +use display_derive::StructDisplay; use serde::{Deserialize, Serialize}; use crate::tokenizer::{PreTokenizedString, PreTokenizer, Result, SplitDelimiterBehavior}; @@ -8,7 +9,7 @@ fn is_punc(x: char) -> bool { char::is_ascii_punctuation(&x) || x.is_punctuation() } -#[derive(Copy, Clone, Debug, PartialEq, Eq)] +#[derive(Copy, Clone, Debug, PartialEq, Eq, StructDisplay)] #[macro_rules_attribute(impl_serde_type!)] pub struct Punctuation { #[serde(default = "default_split")] diff --git a/tokenizers/src/pre_tokenizers/sequence.rs b/tokenizers/src/pre_tokenizers/sequence.rs index 9dcafc673..0c2432dba 100644 --- a/tokenizers/src/pre_tokenizers/sequence.rs +++ b/tokenizers/src/pre_tokenizers/sequence.rs @@ -2,9 +2,11 @@ use crate::pre_tokenizers::PreTokenizerWrapper; use crate::tokenizer::{PreTokenizedString, PreTokenizer, Result}; use crate::utils::macro_rules_attribute; use serde::{Deserialize, Serialize}; +use derive_more::Display; -#[derive(Clone, Debug, PartialEq)] #[macro_rules_attribute(impl_serde_type!)] +#[derive(Clone, Debug, PartialEq, Display)] +#[display(fmt="[{}]", "pretokenizers.iter().map(|d| d.to_string()).collect::>().join(\", \")")] pub struct Sequence { pretokenizers: Vec, } diff --git a/tokenizers/src/pre_tokenizers/split.rs b/tokenizers/src/pre_tokenizers/split.rs index 0e2a9023b..15bedb2e0 100644 --- a/tokenizers/src/pre_tokenizers/split.rs +++ b/tokenizers/src/pre_tokenizers/split.rs @@ -1,12 +1,13 @@ use crate::utils::SysRegex; use serde::{Deserialize, Deserializer, Serialize}; - +use derive_more::Display; use crate::tokenizer::{ pattern::Invert, PreTokenizedString, PreTokenizer, Result, SplitDelimiterBehavior, }; /// Represents the different patterns that `Split` can use -#[derive(Debug, Clone, PartialEq, Serialize, Deserialize, Eq)] +#[derive(Debug, Clone, PartialEq, Serialize, Deserialize, Eq,Display)] +#[display(fmt="{}")] pub enum SplitPattern { String(String), Regex(String), @@ -24,8 +25,9 @@ impl From<&str> for SplitPattern { } } -#[derive(Debug, Serialize)] +#[derive(Debug, Serialize,Display)] #[serde(tag = "type")] +#[display(fmt="Split(patter={}, regex={:?}, behavior={}, invert={})", "pattern", regex, behavior, invert)] pub struct Split { pattern: SplitPattern, 
#[serde(skip)] diff --git a/tokenizers/src/pre_tokenizers/unicode_scripts/pre_tokenizer.rs b/tokenizers/src/pre_tokenizers/unicode_scripts/pre_tokenizer.rs index 2b6b54eb6..7df5e0367 100644 --- a/tokenizers/src/pre_tokenizers/unicode_scripts/pre_tokenizer.rs +++ b/tokenizers/src/pre_tokenizers/unicode_scripts/pre_tokenizer.rs @@ -1,8 +1,10 @@ +use display_derive::StructDisplay; + use crate::pre_tokenizers::unicode_scripts::scripts::{get_script, Script}; use crate::tokenizer::{normalizer::Range, PreTokenizedString, PreTokenizer, Result}; use crate::utils::macro_rules_attribute; -#[derive(Clone, Debug, PartialEq, Eq)] +#[derive(Clone, Debug, PartialEq, Eq, StructDisplay)] #[macro_rules_attribute(impl_serde_type!)] pub struct UnicodeScripts; diff --git a/tokenizers/src/pre_tokenizers/whitespace.rs b/tokenizers/src/pre_tokenizers/whitespace.rs index 8c24e8efb..0bce6d178 100644 --- a/tokenizers/src/pre_tokenizers/whitespace.rs +++ b/tokenizers/src/pre_tokenizers/whitespace.rs @@ -1,3 +1,4 @@ +use display_derive::StructDisplay; use regex::Regex; use crate::tokenizer::{ @@ -5,7 +6,7 @@ use crate::tokenizer::{ }; use crate::utils::macro_rules_attribute; -#[derive(Clone, Debug, PartialEq, Eq)] +#[derive(Clone, Debug, PartialEq, Eq, StructDisplay)] #[macro_rules_attribute(impl_serde_type!)] pub struct Whitespace; @@ -28,7 +29,7 @@ impl PreTokenizer for Whitespace { } } -#[derive(Copy, Clone, Debug, PartialEq, Eq)] +#[derive(Copy, Clone, Debug, PartialEq, Eq, StructDisplay)] #[macro_rules_attribute(impl_serde_type!)] pub struct WhitespaceSplit; diff --git a/tokenizers/src/tokenizer/normalizer.rs b/tokenizers/src/tokenizer/normalizer.rs index 8d5b66455..6f9c89387 100644 --- a/tokenizers/src/tokenizer/normalizer.rs +++ b/tokenizers/src/tokenizer/normalizer.rs @@ -2,7 +2,7 @@ use crate::pattern::Pattern; use crate::{Offsets, Result}; use std::ops::{Bound, RangeBounds}; use unicode_normalization_alignments::UnicodeNormalization; - +use derive_more::Display; use serde::{Deserialize, Serialize}; /// The possible offsets referential @@ -78,7 +78,8 @@ where /// - MergedWithPrevious => `[ "the-", "final-", "-", "countdown" ]` /// - MergedWithNext => `[ "the", "-final", "-", "-countdown" ]` /// - Contiguous => `[ "the", "-", "final", "--", "countdown" ]` -#[derive(Debug, Clone, Copy, PartialEq, Serialize, Deserialize, Eq)] +#[derive(Debug, Clone, Copy, PartialEq, Serialize, Deserialize, Eq, Display)] +#[display(fmt="{}")] pub enum SplitDelimiterBehavior { Removed, Isolated, From 3cec01009539eefd1aca11338e5a75a90c3acf36 Mon Sep 17 00:00:00 2001 From: Arthur Zucker Date: Tue, 4 Jun 2024 16:22:10 +0200 Subject: [PATCH 18/94] more nits --- tokenizers/src/models/unigram/model.rs | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/tokenizers/src/models/unigram/model.rs b/tokenizers/src/models/unigram/model.rs index f3a32a8a0..3d3a658a9 100644 --- a/tokenizers/src/models/unigram/model.rs +++ b/tokenizers/src/models/unigram/model.rs @@ -17,9 +17,8 @@ type Vocab = Vec<(String, f64)>; /// A `Unigram` model to encode sentences. 
#[derive(Display)] #[display( - fmt = "{:?} {:?} {:?} {} {}", - token_to_ids, - vocab, + fmt = "Unigram(vocab={:?}, unk_id={:?}, bos_id={}, eos_id={})", + "vocab.iter().collect::>().truncate(5)", unk_id, bos_id, eos_id From 8d77286adc6c697142648441ea5ec047357af7a7 Mon Sep 17 00:00:00 2001 From: Arthur Zucker Date: Tue, 4 Jun 2024 16:31:07 +0200 Subject: [PATCH 19/94] nit --- tokenizers/src/tokenizer/added_vocabulary.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tokenizers/src/tokenizer/added_vocabulary.rs b/tokenizers/src/tokenizer/added_vocabulary.rs index b49bef31f..4ce0adab4 100644 --- a/tokenizers/src/tokenizer/added_vocabulary.rs +++ b/tokenizers/src/tokenizer/added_vocabulary.rs @@ -140,7 +140,7 @@ fn space_rightmost_at_start(sentence: &str) -> usize { /// exist as required. /// #[derive(Clone, Debug, Display)] -#[display(fmt="{:?}, encode_special_tokens={}", added_tokens_map, encode_special_tokens)] +#[display(fmt="AddedVocabulary(added_tokens_map_r={:#?}, encode_special_tokens={})", "added_tokens_map_r", encode_special_tokens)] pub struct AddedVocabulary { /// Contains the mapping from String (token content) to ID. This map contains both special /// tokens and classic added tokens that were added to the this vocabulary. From e48cd3aadedeb94bc54ded2c48487f876e781fe1 Mon Sep 17 00:00:00 2001 From: Arthur <48595927+ArthurZucker@users.noreply.github.com> Date: Thu, 6 Jun 2024 09:33:31 +0200 Subject: [PATCH 20/94] Update bindings/python/src/pre_tokenizers.rs --- bindings/python/src/pre_tokenizers.rs | 10 +++++++++- 1 file changed, 9 insertions(+), 1 deletion(-) diff --git a/bindings/python/src/pre_tokenizers.rs b/bindings/python/src/pre_tokenizers.rs index cb4f92ff1..113a90d9c 100644 --- a/bindings/python/src/pre_tokenizers.rs +++ b/bindings/python/src/pre_tokenizers.rs @@ -664,7 +664,15 @@ impl Serialize for PyPreTokenizerWrapper { #[serde(untagged)] // #[display(fmt = "")] pub(crate) enum PyPreTokenizerTypeWrapper { - #[display(fmt = "[{}]", "_0.iter().map(|d| d.as_ref().read().unwrap().to_string()).collect::>().join(\", \")")] + #[display(fmt = "[{}]", "_0_0.iter() + .map(|d| d.as_ref().read().unwrap().to_string()) + .fold(String::new(), |mut acc, s| { + if !acc.is_empty() { + acc.push_str(", "); + } + acc.push_str(&s); + acc + })")] Sequence(Vec>>), #[display(fmt ="{}", "_0.as_ref().read().unwrap()")] Single(Arc>), From 27576e5836e7e7fad8f0aa9d1fdf313937a1769d Mon Sep 17 00:00:00 2001 From: Arthur Zucker Date: Tue, 4 Jun 2024 16:45:41 +0200 Subject: [PATCH 21/94] ips --- tokenizers/src/models/unigram/model.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tokenizers/src/models/unigram/model.rs b/tokenizers/src/models/unigram/model.rs index 3d3a658a9..34c2832a1 100644 --- a/tokenizers/src/models/unigram/model.rs +++ b/tokenizers/src/models/unigram/model.rs @@ -17,7 +17,7 @@ type Vocab = Vec<(String, f64)>; /// A `Unigram` model to encode sentences. 
#[derive(Display)] #[display( - fmt = "Unigram(vocab={:?}, unk_id={:?}, bos_id={}, eos_id={})", + fmt = "Unigram(vocab={:#?}, unk_id={:?}, bos_id={}, eos_id={})", "vocab.iter().collect::>().truncate(5)", unk_id, bos_id, From 1c6d272943e22208e7bced9472fe7e29b00e6a82 Mon Sep 17 00:00:00 2001 From: Arthur Zucker Date: Sat, 8 Jun 2024 17:01:36 +0200 Subject: [PATCH 22/94] update --- bindings/python/src/pre_tokenizers.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/bindings/python/src/pre_tokenizers.rs b/bindings/python/src/pre_tokenizers.rs index 113a90d9c..eb5b484f2 100644 --- a/bindings/python/src/pre_tokenizers.rs +++ b/bindings/python/src/pre_tokenizers.rs @@ -664,7 +664,7 @@ impl Serialize for PyPreTokenizerWrapper { #[serde(untagged)] // #[display(fmt = "")] pub(crate) enum PyPreTokenizerTypeWrapper { - #[display(fmt = "[{}]", "_0_0.iter() + #[display(fmt = "[{}]", "_0.iter() .map(|d| d.as_ref().read().unwrap().to_string()) .fold(String::new(), |mut acc, s| { if !acc.is_empty() { From df51116efc8dde8c8ab39dd5e4a715e42fefa333 Mon Sep 17 00:00:00 2001 From: Arthur Zucker Date: Sat, 8 Jun 2024 18:06:33 +0200 Subject: [PATCH 23/94] update and fix --- bindings/python/src/pre_tokenizers.rs | 3 +-- bindings/python/src/processors.rs | 5 +++++ tokenizers/src/processors/template.rs | 2 +- 3 files changed, 7 insertions(+), 3 deletions(-) diff --git a/bindings/python/src/pre_tokenizers.rs b/bindings/python/src/pre_tokenizers.rs index eb5b484f2..4334fc029 100644 --- a/bindings/python/src/pre_tokenizers.rs +++ b/bindings/python/src/pre_tokenizers.rs @@ -662,13 +662,12 @@ impl Serialize for PyPreTokenizerWrapper { #[derive(Clone, Deserialize, Display)] #[serde(untagged)] -// #[display(fmt = "")] pub(crate) enum PyPreTokenizerTypeWrapper { #[display(fmt = "[{}]", "_0.iter() .map(|d| d.as_ref().read().unwrap().to_string()) .fold(String::new(), |mut acc, s| { if !acc.is_empty() { - acc.push_str(", "); + acc.push_str(\", \"); } acc.push_str(&s); acc diff --git a/bindings/python/src/processors.rs b/bindings/python/src/processors.rs index 7fdc68ef1..25ff10d08 100644 --- a/bindings/python/src/processors.rs +++ b/bindings/python/src/processors.rs @@ -139,6 +139,11 @@ impl PyPostProcessor { .into_py()?; Ok(final_encoding.into()) } + + fn __str__(&self) -> PyResult{ + Ok(format!("{}", &self)) + + } } /// This post-processor takes care of adding the special tokens needed by diff --git a/tokenizers/src/processors/template.rs b/tokenizers/src/processors/template.rs index 2420d787e..b28d0b693 100644 --- a/tokenizers/src/processors/template.rs +++ b/tokenizers/src/processors/template.rs @@ -334,7 +334,7 @@ impl From> for Tokens { /// #[derive(Debug, Clone, PartialEq, Builder, Serialize, Deserialize, Eq, Display)] #[serde(tag = "type", from = "TemplateProcessingDeserializer")] -#[display(fmt = "TemplateProcessing({:?})", self)] +#[display(fmt = "TemplateProcessing(single={:?}, pair={:?})", single, pair)] #[builder(build_fn(validate = "Self::validate"))] pub struct TemplateProcessing { #[builder(try_setter, default = "\"$0\".try_into().unwrap()")] From 59a89c980522d8326b1a5b30bebcaccfd29dc7b1 Mon Sep 17 00:00:00 2001 From: Arthur Zucker Date: Sat, 8 Jun 2024 18:13:11 +0200 Subject: [PATCH 24/94] only commit one line --- tokenizers/display_derive/src/lib.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tokenizers/display_derive/src/lib.rs b/tokenizers/display_derive/src/lib.rs index e3d867e27..1dba3bb2d 100644 --- a/tokenizers/display_derive/src/lib.rs +++ 
b/tokenizers/display_derive/src/lib.rs @@ -35,7 +35,7 @@ pub fn display_derive(input: TokenStream) -> TokenStream { { let s = format!("{}", field_value); let mut chars = s.chars(); - let mut prefix = (&mut chars).take(100000 - 1).collect::(); + let mut prefix = (&mut chars).take(100 - 1).collect::(); if chars.next().is_some() { prefix.push('…'); } From ac9b84935b6453415efd84fdad8b2c1628553d26 Mon Sep 17 00:00:00 2001 From: Arthur Zucker Date: Sat, 8 Jun 2024 19:47:12 +0200 Subject: [PATCH 25/94] update --- tokenizers/display_derive/src/lib.rs | 7 +++++-- tokenizers/src/normalizers/replace.rs | 2 +- tokenizers/src/tokenizer/added_vocabulary.rs | 12 ++++++++++-- 3 files changed, 16 insertions(+), 5 deletions(-) diff --git a/tokenizers/display_derive/src/lib.rs b/tokenizers/display_derive/src/lib.rs index 1dba3bb2d..4a3d09628 100644 --- a/tokenizers/display_derive/src/lib.rs +++ b/tokenizers/display_derive/src/lib.rs @@ -19,6 +19,7 @@ pub fn display_derive(input: TokenStream) -> TokenStream { // If the struct has named fields let field_names = fields.named.iter().map(|f| &f.ident); let field_names2 = field_names.clone(); + let field_types = fields.named.iter().map(|f| &f.ty); quote! { impl std::fmt::Display for #name { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { @@ -32,7 +33,9 @@ pub fn display_derive(input: TokenStream) -> TokenStream { let field_value = &self.#field_names2; write!(f, "{}=", stringify!(#field_names))?; - { + if std::any::TypeId::of::<#field_types>() == std::any::TypeId::of::(){ + write!(f, "\"{}\"", field_value)?; + } else { let s = format!("{}", field_value); let mut chars = s.chars(); let mut prefix = (&mut chars).take(100 - 1).collect::(); @@ -65,4 +68,4 @@ pub fn display_derive(input: TokenStream) -> TokenStream { // Convert into a token stream and return it TokenStream::from(expanded) -} \ No newline at end of file +} diff --git a/tokenizers/src/normalizers/replace.rs b/tokenizers/src/normalizers/replace.rs index 2cd39c06f..d79c141b9 100644 --- a/tokenizers/src/normalizers/replace.rs +++ b/tokenizers/src/normalizers/replace.rs @@ -45,7 +45,7 @@ impl std::convert::TryFrom for Replace { #[derive(Debug, Serialize, Deserialize, Display)] #[serde(tag = "type", try_from = "ReplaceDeserializer")] #[display( - fmt = "Replace(pattern={:?}, content={}, regex={:?}", + fmt = "Replace(pattern={:?}, content=\"{}\", regex={:?}", pattern, content, regex diff --git a/tokenizers/src/tokenizer/added_vocabulary.rs b/tokenizers/src/tokenizer/added_vocabulary.rs index 4ce0adab4..ab511490a 100644 --- a/tokenizers/src/tokenizer/added_vocabulary.rs +++ b/tokenizers/src/tokenizer/added_vocabulary.rs @@ -3,6 +3,7 @@ use super::{ }; use aho_corasick::{AhoCorasick, AhoCorasickBuilder, MatchKind}; use derive_more::Display; +use display_derive::StructDisplay; use regex::Regex; use serde::{ser::SerializeSeq, Deserialize, Serialize, Serializer}; use std::collections::{HashMap, HashSet}; @@ -12,7 +13,7 @@ use std::collections::{HashMap, HashSet}; /// like: /// - Whether they should only match single words /// - Whether to include any whitespace on its left or right -#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)] +#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq, StructDisplay)] pub struct AddedToken { /// The content of the added token pub content: String, @@ -140,7 +141,14 @@ fn space_rightmost_at_start(sentence: &str) -> usize { /// exist as required. 
/// #[derive(Clone, Debug, Display)] -#[display(fmt="AddedVocabulary(added_tokens_map_r={:#?}, encode_special_tokens={})", "added_tokens_map_r", encode_special_tokens)] +#[display(fmt="AddedVocabulary(added_tokens_map_r={{}}, encode_special_tokens={})", "&(0..=5).fold(String::new(), |mut acc, key| { + if let Some(token) = added_tokens_map_r.get(&key) { + if !acc.is_empty() { + acc.push_str(\", \"); + } + acc.push_str(&format!("{}: {:?}", key, token)); + } + acc})", encode_special_tokens)] pub struct AddedVocabulary { /// Contains the mapping from String (token content) to ID. This map contains both special /// tokens and classic added tokens that were added to the this vocabulary. From a3f743938f88f321a4502cd58c3ed9b365ffc72b Mon Sep 17 00:00:00 2001 From: Arthur Zucker Date: Sat, 8 Jun 2024 20:38:29 +0200 Subject: [PATCH 26/94] update the added vocab string --- tokenizers/src/tokenizer/added_vocabulary.rs | 9 +-------- 1 file changed, 1 insertion(+), 8 deletions(-) diff --git a/tokenizers/src/tokenizer/added_vocabulary.rs b/tokenizers/src/tokenizer/added_vocabulary.rs index ab511490a..49aae1621 100644 --- a/tokenizers/src/tokenizer/added_vocabulary.rs +++ b/tokenizers/src/tokenizer/added_vocabulary.rs @@ -141,14 +141,7 @@ fn space_rightmost_at_start(sentence: &str) -> usize { /// exist as required. /// #[derive(Clone, Debug, Display)] -#[display(fmt="AddedVocabulary(added_tokens_map_r={{}}, encode_special_tokens={})", "&(0..=5).fold(String::new(), |mut acc, key| { - if let Some(token) = added_tokens_map_r.get(&key) { - if !acc.is_empty() { - acc.push_str(\", \"); - } - acc.push_str(&format!("{}: {:?}", key, token)); - } - acc})", encode_special_tokens)] +#[display(fmt="AddedVocabulary(added_tokens_map_r={{{}}}, encode_special_tokens={})", "&(0..=5).fold(String::new(), |mut acc, key| {if let Some(token) = added_tokens_map_r.get(&key){if !acc.is_empty(){acc.push_str(\", \");}acc.push_str(&format!(\"\n\t{}: {}\", key, &token.to_string()));}acc})", encode_special_tokens)] pub struct AddedVocabulary { /// Contains the mapping from String (token content) to ID. This map contains both special /// tokens and classic added tokens that were added to the this vocabulary. From cf5b6f3b420b266b02fd19f478fd49e7bd7388b0 Mon Sep 17 00:00:00 2001 From: Arthur Zucker Date: Mon, 10 Jun 2024 08:08:14 +0200 Subject: [PATCH 27/94] nit --- tokenizers/src/normalizers/mod.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tokenizers/src/normalizers/mod.rs b/tokenizers/src/normalizers/mod.rs index e4675020a..1322955d2 100644 --- a/tokenizers/src/normalizers/mod.rs +++ b/tokenizers/src/normalizers/mod.rs @@ -21,7 +21,7 @@ use crate::{NormalizedString, Normalizer}; /// Wrapper for known Normalizers. 
#[derive(Clone, Debug, Deserialize, Serialize, Display)] #[serde(untagged)] -#[display(fmt="normalizers.{}")] +#[display(fmt="normalizers.{_0}")] pub enum NormalizerWrapper { BertNormalizer(BertNormalizer), StripNormalizer(Strip), From 5d332438d08cd9e9111c3ba335a78b85b1070974 Mon Sep 17 00:00:00 2001 From: Arthur Zucker Date: Mon, 10 Jun 2024 08:14:23 +0200 Subject: [PATCH 28/94] fix sequence's display --- tokenizers/src/normalizers/utils.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tokenizers/src/normalizers/utils.rs b/tokenizers/src/normalizers/utils.rs index e9b126569..a322eaa33 100644 --- a/tokenizers/src/normalizers/utils.rs +++ b/tokenizers/src/normalizers/utils.rs @@ -6,7 +6,7 @@ use crate::utils::macro_rules_attribute; use derive_more::Display; use display_derive::StructDisplay; #[derive(Clone, Deserialize, Debug, Serialize, Display)] -#[display(fmt = "[{}]", "normalizers.iter().map(|d| d.to_string()).collect::>().join(\", \")")] +#[display(fmt = "Sequence([{}])", "normalizers.iter().map(|d| d.to_string()).collect::>().join(\", \")")] #[serde(tag = "type")] /// Allows concatenating multiple other Normalizer as a Sequence. /// All the normalizers run in sequence in the given order against the same NormalizedString. From b73c43d472ef0d38488e97bd2a9b3ae7dd12ec03 Mon Sep 17 00:00:00 2001 From: Arthur Zucker Date: Mon, 10 Jun 2024 09:59:01 +0200 Subject: [PATCH 29/94] update display for normalizer sequence --- tokenizers/src/normalizers/mod.rs | 2 +- tokenizers/src/normalizers/utils.rs | 8 +++++++- 2 files changed, 8 insertions(+), 2 deletions(-) diff --git a/tokenizers/src/normalizers/mod.rs b/tokenizers/src/normalizers/mod.rs index 1322955d2..e4675020a 100644 --- a/tokenizers/src/normalizers/mod.rs +++ b/tokenizers/src/normalizers/mod.rs @@ -21,7 +21,7 @@ use crate::{NormalizedString, Normalizer}; /// Wrapper for known Normalizers. #[derive(Clone, Debug, Deserialize, Serialize, Display)] #[serde(untagged)] -#[display(fmt="normalizers.{_0}")] +#[display(fmt="normalizers.{}")] pub enum NormalizerWrapper { BertNormalizer(BertNormalizer), StripNormalizer(Strip), diff --git a/tokenizers/src/normalizers/utils.rs b/tokenizers/src/normalizers/utils.rs index a322eaa33..19152539f 100644 --- a/tokenizers/src/normalizers/utils.rs +++ b/tokenizers/src/normalizers/utils.rs @@ -6,7 +6,13 @@ use crate::utils::macro_rules_attribute; use derive_more::Display; use display_derive::StructDisplay; #[derive(Clone, Deserialize, Debug, Serialize, Display)] -#[display(fmt = "Sequence([{}])", "normalizers.iter().map(|d| d.to_string()).collect::>().join(\", \")")] +#[display(fmt = "Sequence([{}])", "normalizers.iter().fold(String::new(), |mut acc, d| { + if !acc.is_empty() { + acc.push_str(\", \"); + } + acc.push_str(&d.to_string()); + acc +})")] #[serde(tag = "type")] /// Allows concatenating multiple other Normalizer as a Sequence. /// All the normalizers run in sequence in the given order against the same NormalizedString. 
From b214d77e4255566b0bac7f5cf95263c491bcf628 Mon Sep 17 00:00:00 2001 From: Arthur Zucker Date: Mon, 10 Jun 2024 10:03:49 +0200 Subject: [PATCH 30/94] style --- bindings/python/src/normalizers.rs | 20 +++++++++----------- bindings/python/src/pre_tokenizers.rs | 15 +++++++++------ bindings/python/src/processors.rs | 9 ++++----- bindings/python/src/tokenizer.rs | 4 ++-- tokenizers/src/decoders/sequence.rs | 5 ++++- tokenizers/src/normalizers/bert.rs | 10 ++++++++-- tokenizers/src/normalizers/mod.rs | 8 ++++---- tokenizers/src/normalizers/strip.rs | 4 ++-- tokenizers/src/normalizers/utils.rs | 7 +++++-- tokenizers/src/pre_tokenizers/delimiter.rs | 2 +- tokenizers/src/pre_tokenizers/mod.rs | 2 +- tokenizers/src/pre_tokenizers/sequence.rs | 7 +++++-- tokenizers/src/pre_tokenizers/split.rs | 20 +++++++++++++------- tokenizers/src/processors/sequence.rs | 5 ++++- tokenizers/src/tokenizer/mod.rs | 4 ++-- tokenizers/src/tokenizer/normalizer.rs | 6 +++--- tokenizers/src/utils/padding.rs | 5 ++--- tokenizers/src/utils/truncation.rs | 8 +++++++- 18 files changed, 85 insertions(+), 56 deletions(-) diff --git a/bindings/python/src/normalizers.rs b/bindings/python/src/normalizers.rs index 762e60f3b..6f8b4920b 100644 --- a/bindings/python/src/normalizers.rs +++ b/bindings/python/src/normalizers.rs @@ -1,11 +1,11 @@ use std::sync::{Arc, RwLock}; +use crate::error::ToPyResult; +use crate::utils::{PyNormalizedString, PyNormalizedStringRefMut, PyPattern}; +use derive_more::Display; use pyo3::exceptions; use pyo3::prelude::*; use pyo3::types::*; -use derive_more::Display; -use crate::error::ToPyResult; -use crate::utils::{PyNormalizedString, PyNormalizedStringRefMut, PyPattern}; use serde::ser::SerializeStruct; use serde::{Deserialize, Deserializer, Serialize, Serializer}; use tk::normalizers::{ @@ -90,7 +90,6 @@ impl PyNormalizer { }, }) } - } impl Normalizer for PyNormalizer { @@ -169,10 +168,10 @@ impl PyNormalizer { Ok(normalized.get().to_owned()) } - fn __str__(&self) -> PyResult{ + fn __str__(&self) -> PyResult { Ok(format!("{}", self.normalizer)) } - fn __repr__(&self) -> PyResult{ + fn __repr__(&self) -> PyResult { Ok(format!("{}", self.normalizer)) } } @@ -553,9 +552,9 @@ impl<'de> Deserialize<'de> for CustomNormalizer { #[derive(Debug, Clone, Deserialize, Display)] #[serde(untagged)] pub(crate) enum PyNormalizerWrapper { - #[display(fmt="{}", "_0.inner")] + #[display(fmt = "{}", "_0.inner")] Custom(CustomNormalizer), - #[display(fmt="{}", "_0")] + #[display(fmt = "{}", "_0")] Wrapped(NormalizerWrapper), } @@ -586,10 +585,9 @@ impl std::fmt::Display for PyNormalizerTypeWrapper { for decoder in decoders { let decoder = decoder.read().unwrap(); writeln!(f, "{}", decoder)?; - } - writeln!(f, "?????")?; - Ok(()) + writeln!(f, "?????")?; + Ok(()) } PyNormalizerTypeWrapper::Single(ref decoder) => { let decoder = decoder.read().unwrap(); diff --git a/bindings/python/src/pre_tokenizers.rs b/bindings/python/src/pre_tokenizers.rs index 4334fc029..229a4c60d 100644 --- a/bindings/python/src/pre_tokenizers.rs +++ b/bindings/python/src/pre_tokenizers.rs @@ -183,11 +183,11 @@ impl PyPreTokenizer { .collect()) } - fn __str__(&self) -> PyResult{ + fn __str__(&self) -> PyResult { Ok(format!("{}", self.pretok)) } - fn __repr__(&self) -> PyResult{ + fn __repr__(&self) -> PyResult { Ok(format!("{}", self.pretok)) } } @@ -641,7 +641,7 @@ impl<'de> Deserialize<'de> for CustomPreTokenizer { } #[derive(Clone, Deserialize, Display)] -#[display(fmt="{}")] +#[display(fmt = "{}")] #[serde(untagged)] pub(crate) enum 
PyPreTokenizerWrapper { Custom(CustomPreTokenizer), @@ -663,7 +663,9 @@ impl Serialize for PyPreTokenizerWrapper { #[derive(Clone, Deserialize, Display)] #[serde(untagged)] pub(crate) enum PyPreTokenizerTypeWrapper { - #[display(fmt = "[{}]", "_0.iter() + #[display( + fmt = "[{}]", + "_0.iter() .map(|d| d.as_ref().read().unwrap().to_string()) .fold(String::new(), |mut acc, s| { if !acc.is_empty() { @@ -671,9 +673,10 @@ pub(crate) enum PyPreTokenizerTypeWrapper { } acc.push_str(&s); acc - })")] + })" + )] Sequence(Vec>>), - #[display(fmt ="{}", "_0.as_ref().read().unwrap()")] + #[display(fmt = "{}", "_0.as_ref().read().unwrap()")] Single(Arc>), } diff --git a/bindings/python/src/processors.rs b/bindings/python/src/processors.rs index 25ff10d08..7e695879b 100644 --- a/bindings/python/src/processors.rs +++ b/bindings/python/src/processors.rs @@ -1,12 +1,12 @@ use std::convert::TryInto; use std::sync::Arc; +use crate::encoding::PyEncoding; +use crate::error::ToPyResult; +use derive_more::Display; use pyo3::exceptions; use pyo3::prelude::*; use pyo3::types::*; -use derive_more::Display; -use crate::encoding::PyEncoding; -use crate::error::ToPyResult; use serde::{Deserialize, Serialize}; use tk::processors::bert::BertProcessing; use tk::processors::byte_level::ByteLevel; @@ -140,9 +140,8 @@ impl PyPostProcessor { Ok(final_encoding.into()) } - fn __str__(&self) -> PyResult{ + fn __str__(&self) -> PyResult { Ok(format!("{}", &self)) - } } diff --git a/bindings/python/src/tokenizer.rs b/bindings/python/src/tokenizer.rs index 637d34f14..a6cbcf40b 100644 --- a/bindings/python/src/tokenizer.rs +++ b/bindings/python/src/tokenizer.rs @@ -1411,11 +1411,11 @@ impl PyTokenizer { self.tokenizer.with_decoder(decoder.clone()); } - fn __str__(&self) -> PyResult{ + fn __str__(&self) -> PyResult { Ok(format!("{}", self.tokenizer)) } - fn __repr__(&self) -> PyResult{ + fn __repr__(&self) -> PyResult { Ok(format!("{}", self.tokenizer)) } } diff --git a/tokenizers/src/decoders/sequence.rs b/tokenizers/src/decoders/sequence.rs index efa330f51..bfe21a110 100644 --- a/tokenizers/src/decoders/sequence.rs +++ b/tokenizers/src/decoders/sequence.rs @@ -6,7 +6,10 @@ use serde::{Deserialize, Serialize}; #[macro_rules_attribute(impl_serde_type!)] #[derive(Clone, Debug, Display)] -#[display(fmt = "[{}]", "decoders.iter().map(|d| d.to_string()).collect::>().join(\", \")")] +#[display( + fmt = "[{}]", + "decoders.iter().map(|d| d.to_string()).collect::>().join(\", \")" +)] pub struct Sequence { decoders: Vec, } diff --git a/tokenizers/src/normalizers/bert.rs b/tokenizers/src/normalizers/bert.rs index 41f60dcf3..e58d028d6 100644 --- a/tokenizers/src/normalizers/bert.rs +++ b/tokenizers/src/normalizers/bert.rs @@ -1,8 +1,8 @@ use crate::tokenizer::{NormalizedString, Normalizer, Result}; +use derive_more::Display; use serde::{Deserialize, Serialize}; use unicode_categories::UnicodeCategories; -use derive_more::Display; /// Checks whether a character is whitespace fn is_whitespace(c: char) -> bool { // These are technically control characters but we count them as whitespace @@ -48,7 +48,13 @@ fn is_chinese_char(c: char) -> bool { } #[derive(Copy, Clone, Debug, Deserialize, Serialize, Display)] -#[display(fmt="BertNormalizer(clean_text={}, handle_chinese_chars={}, strip_accents={:?}, lower_case={})",clean_text, handle_chinese_chars, strip_accents, lowercase)] +#[display( + fmt = "BertNormalizer(clean_text={}, handle_chinese_chars={}, strip_accents={:?}, lower_case={})", + clean_text, + handle_chinese_chars, + strip_accents, + 
lowercase +)] #[serde(tag = "type")] #[non_exhaustive] pub struct BertNormalizer { diff --git a/tokenizers/src/normalizers/mod.rs b/tokenizers/src/normalizers/mod.rs index e4675020a..e0f974346 100644 --- a/tokenizers/src/normalizers/mod.rs +++ b/tokenizers/src/normalizers/mod.rs @@ -14,14 +14,14 @@ pub use crate::normalizers::strip::{Strip, StripAccents}; pub use crate::normalizers::unicode::{Nmt, NFC, NFD, NFKC, NFKD}; pub use crate::normalizers::utils::{Lowercase, Sequence}; -use serde::{Deserialize, Serialize}; -use derive_more::Display; use crate::{NormalizedString, Normalizer}; +use derive_more::Display; +use serde::{Deserialize, Serialize}; /// Wrapper for known Normalizers. #[derive(Clone, Debug, Deserialize, Serialize, Display)] #[serde(untagged)] -#[display(fmt="normalizers.{}")] +#[display(fmt = "normalizers.{}")] pub enum NormalizerWrapper { BertNormalizer(BertNormalizer), StripNormalizer(Strip), @@ -33,7 +33,7 @@ pub enum NormalizerWrapper { Sequence(Sequence), Lowercase(Lowercase), Nmt(Nmt), - #[display(fmt="Precompiled()")] + #[display(fmt = "Precompiled()")] Precompiled(Precompiled), Replace(Replace), Prepend(Prepend), diff --git a/tokenizers/src/normalizers/strip.rs b/tokenizers/src/normalizers/strip.rs index 265cc0cae..b188df19a 100644 --- a/tokenizers/src/normalizers/strip.rs +++ b/tokenizers/src/normalizers/strip.rs @@ -1,9 +1,9 @@ use crate::tokenizer::{NormalizedString, Normalizer, Result}; use crate::utils::macro_rules_attribute; +use display_derive::StructDisplay; use serde::{Deserialize, Serialize}; use unicode_normalization_alignments::char::is_combining_mark; -use display_derive::StructDisplay; -#[derive(Copy, Clone, Debug, Deserialize, Serialize,StructDisplay)] +#[derive(Copy, Clone, Debug, Deserialize, Serialize, StructDisplay)] #[serde(tag = "type")] #[non_exhaustive] pub struct Strip { diff --git a/tokenizers/src/normalizers/utils.rs b/tokenizers/src/normalizers/utils.rs index 19152539f..ddd18b9b4 100644 --- a/tokenizers/src/normalizers/utils.rs +++ b/tokenizers/src/normalizers/utils.rs @@ -6,13 +6,16 @@ use crate::utils::macro_rules_attribute; use derive_more::Display; use display_derive::StructDisplay; #[derive(Clone, Deserialize, Debug, Serialize, Display)] -#[display(fmt = "Sequence([{}])", "normalizers.iter().fold(String::new(), |mut acc, d| { +#[display( + fmt = "Sequence([{}])", + "normalizers.iter().fold(String::new(), |mut acc, d| { if !acc.is_empty() { acc.push_str(\", \"); } acc.push_str(&d.to_string()); acc -})")] +})" +)] #[serde(tag = "type")] /// Allows concatenating multiple other Normalizer as a Sequence. /// All the normalizers run in sequence in the given order against the same NormalizedString. 
diff --git a/tokenizers/src/pre_tokenizers/delimiter.rs b/tokenizers/src/pre_tokenizers/delimiter.rs index e58628f71..fb445203e 100644 --- a/tokenizers/src/pre_tokenizers/delimiter.rs +++ b/tokenizers/src/pre_tokenizers/delimiter.rs @@ -4,7 +4,7 @@ use serde::{Deserialize, Serialize}; use crate::tokenizer::{PreTokenizedString, PreTokenizer, Result, SplitDelimiterBehavior}; use crate::utils::macro_rules_attribute; -#[derive(Copy, Clone, Debug, PartialEq, Eq,StructDisplay)] +#[derive(Copy, Clone, Debug, PartialEq, Eq, StructDisplay)] #[non_exhaustive] #[macro_rules_attribute(impl_serde_type!)] pub struct CharDelimiterSplit { diff --git a/tokenizers/src/pre_tokenizers/mod.rs b/tokenizers/src/pre_tokenizers/mod.rs index 6bb82024b..40c113241 100644 --- a/tokenizers/src/pre_tokenizers/mod.rs +++ b/tokenizers/src/pre_tokenizers/mod.rs @@ -25,7 +25,7 @@ use crate::{PreTokenizedString, PreTokenizer}; use derive_more::Display; #[derive(Deserialize, Serialize, Clone, Debug, PartialEq, Display)] -#[display(fmt="{}")] +#[display(fmt = "{}")] #[serde(untagged)] pub enum PreTokenizerWrapper { BertPreTokenizer(BertPreTokenizer), diff --git a/tokenizers/src/pre_tokenizers/sequence.rs b/tokenizers/src/pre_tokenizers/sequence.rs index 0c2432dba..acd7ba01c 100644 --- a/tokenizers/src/pre_tokenizers/sequence.rs +++ b/tokenizers/src/pre_tokenizers/sequence.rs @@ -1,12 +1,15 @@ use crate::pre_tokenizers::PreTokenizerWrapper; use crate::tokenizer::{PreTokenizedString, PreTokenizer, Result}; use crate::utils::macro_rules_attribute; -use serde::{Deserialize, Serialize}; use derive_more::Display; +use serde::{Deserialize, Serialize}; #[macro_rules_attribute(impl_serde_type!)] #[derive(Clone, Debug, PartialEq, Display)] -#[display(fmt="[{}]", "pretokenizers.iter().map(|d| d.to_string()).collect::>().join(\", \")")] +#[display( + fmt = "[{}]", + "pretokenizers.iter().map(|d| d.to_string()).collect::>().join(\", \")" +)] pub struct Sequence { pretokenizers: Vec, } diff --git a/tokenizers/src/pre_tokenizers/split.rs b/tokenizers/src/pre_tokenizers/split.rs index 15bedb2e0..a442a5884 100644 --- a/tokenizers/src/pre_tokenizers/split.rs +++ b/tokenizers/src/pre_tokenizers/split.rs @@ -1,13 +1,13 @@ -use crate::utils::SysRegex; -use serde::{Deserialize, Deserializer, Serialize}; -use derive_more::Display; use crate::tokenizer::{ pattern::Invert, PreTokenizedString, PreTokenizer, Result, SplitDelimiterBehavior, }; +use crate::utils::SysRegex; +use derive_more::Display; +use serde::{Deserialize, Deserializer, Serialize}; /// Represents the different patterns that `Split` can use -#[derive(Debug, Clone, PartialEq, Serialize, Deserialize, Eq,Display)] -#[display(fmt="{}")] +#[derive(Debug, Clone, PartialEq, Serialize, Deserialize, Eq, Display)] +#[display(fmt = "{}")] pub enum SplitPattern { String(String), Regex(String), @@ -25,9 +25,15 @@ impl From<&str> for SplitPattern { } } -#[derive(Debug, Serialize,Display)] +#[derive(Debug, Serialize, Display)] #[serde(tag = "type")] -#[display(fmt="Split(patter={}, regex={:?}, behavior={}, invert={})", "pattern", regex, behavior, invert)] +#[display( + fmt = "Split(patter={}, regex={:?}, behavior={}, invert={})", + "pattern", + regex, + behavior, + invert +)] pub struct Split { pattern: SplitPattern, #[serde(skip)] diff --git a/tokenizers/src/processors/sequence.rs b/tokenizers/src/processors/sequence.rs index 76c8d0d1e..67e0f5c66 100644 --- a/tokenizers/src/processors/sequence.rs +++ b/tokenizers/src/processors/sequence.rs @@ -5,7 +5,10 @@ use derive_more::Display; use 
serde::{Deserialize, Serialize}; #[macro_rules_attribute(impl_serde_type!)] #[derive(Clone, Debug, PartialEq, Eq, Display)] -#[display(fmt = "[{}]", "processors.iter().map(|d| d.to_string()).collect::>().join(\", \")")] +#[display( + fmt = "[{}]", + "processors.iter().map(|d| d.to_string()).collect::>().join(\", \")" +)] pub struct Sequence { processors: Vec, } diff --git a/tokenizers/src/tokenizer/mod.rs b/tokenizers/src/tokenizer/mod.rs index da402cafb..1c561b3bd 100644 --- a/tokenizers/src/tokenizer/mod.rs +++ b/tokenizers/src/tokenizer/mod.rs @@ -563,10 +563,10 @@ where "Tokenizer(normalizer={},\npre_tokenizer={},\nmodel={},\npost_processor={},\ndecoder={},\nadded_vocab={},\ntruncation={},\npadding={}\n)", normalizer_str, pre_tokenizer_str, - self.model.to_string(), + self.model, post_processor_str, decoder_str, - self.added_vocabulary.to_string(), + self.added_vocabulary, truncation_str, padding_str ) diff --git a/tokenizers/src/tokenizer/normalizer.rs b/tokenizers/src/tokenizer/normalizer.rs index 6f9c89387..46fa7b2ee 100644 --- a/tokenizers/src/tokenizer/normalizer.rs +++ b/tokenizers/src/tokenizer/normalizer.rs @@ -1,9 +1,9 @@ use crate::pattern::Pattern; use crate::{Offsets, Result}; -use std::ops::{Bound, RangeBounds}; -use unicode_normalization_alignments::UnicodeNormalization; use derive_more::Display; use serde::{Deserialize, Serialize}; +use std::ops::{Bound, RangeBounds}; +use unicode_normalization_alignments::UnicodeNormalization; /// The possible offsets referential #[derive(Debug, Clone, Copy, PartialEq, Eq)] @@ -79,7 +79,7 @@ where /// - MergedWithNext => `[ "the", "-final", "-", "-countdown" ]` /// - Contiguous => `[ "the", "-", "final", "--", "countdown" ]` #[derive(Debug, Clone, Copy, PartialEq, Serialize, Deserialize, Eq, Display)] -#[display(fmt="{}")] +#[display(fmt = "{}")] pub enum SplitDelimiterBehavior { Removed, Isolated, diff --git a/tokenizers/src/utils/padding.rs b/tokenizers/src/utils/padding.rs index eded5ad0a..d10ea7969 100644 --- a/tokenizers/src/utils/padding.rs +++ b/tokenizers/src/utils/padding.rs @@ -1,7 +1,7 @@ use crate::parallelism::*; use crate::tokenizer::{Encoding, Result}; -use serde::{Deserialize, Serialize}; use derive_more::Display; +use serde::{Deserialize, Serialize}; /// The various possible padding directions. 
#[derive(Debug, Clone, Copy, Serialize, Deserialize)] @@ -19,10 +19,9 @@ impl std::convert::AsRef for PaddingDirection { } } - #[derive(Debug, Clone, Serialize, Deserialize, Display)] // #[display(fmt="Strategy: {:?}, Direction: {:?}, Pad to multiple of: {:?}, Pad ID: {}, Pad Type ID: {}, Pad Token: {}", strategy, direction, pad_to_multiple_of, pad_id, pad_type_id, pad_token)] -#[display(fmt="Strategy:")] +#[display(fmt = "Strategy:")] pub struct PaddingParams { pub strategy: PaddingStrategy, pub direction: PaddingDirection, diff --git a/tokenizers/src/utils/truncation.rs b/tokenizers/src/utils/truncation.rs index 870fbc503..6e8328cd1 100644 --- a/tokenizers/src/utils/truncation.rs +++ b/tokenizers/src/utils/truncation.rs @@ -21,7 +21,13 @@ impl std::convert::AsRef for TruncationDirection { } #[derive(Debug, Clone, Serialize, Deserialize, Display)] -#[display(fmt="direction={:?}, max_length={}, strategy={:?}, stride={}", direction, max_length, strategy, stride)] +#[display( + fmt = "direction={:?}, max_length={}, strategy={:?}, stride={}", + direction, + max_length, + strategy, + stride +)] pub struct TruncationParams { #[serde(default)] pub direction: TruncationDirection, From 06548314d0abc2e14ee66bc57b6a106c2747497f Mon Sep 17 00:00:00 2001 From: Arthur Zucker Date: Mon, 10 Jun 2024 10:07:17 +0200 Subject: [PATCH 31/94] small nit --- bindings/python/src/normalizers.rs | 1 - 1 file changed, 1 deletion(-) diff --git a/bindings/python/src/normalizers.rs b/bindings/python/src/normalizers.rs index 6f8b4920b..41d9e708d 100644 --- a/bindings/python/src/normalizers.rs +++ b/bindings/python/src/normalizers.rs @@ -586,7 +586,6 @@ impl std::fmt::Display for PyNormalizerTypeWrapper { let decoder = decoder.read().unwrap(); writeln!(f, "{}", decoder)?; } - writeln!(f, "?????")?; Ok(()) } PyNormalizerTypeWrapper::Single(ref decoder) => { From a15e3cce3d622200048471d7421ea2e71e821eb9 Mon Sep 17 00:00:00 2001 From: Arthur Zucker Date: Mon, 10 Jun 2024 10:24:50 +0200 Subject: [PATCH 32/94] updates to cleanup --- tokenizers/Cargo.toml | 1 - tokenizers/display_derive/Cargo.toml | 1 - tokenizers/src/decoders/ctc.rs | 10 ++-------- tokenizers/src/decoders/sequence.rs | 10 ++++++++-- 4 files changed, 10 insertions(+), 12 deletions(-) diff --git a/tokenizers/Cargo.toml b/tokenizers/Cargo.toml index 6b6a7f375..6e6dd2f22 100644 --- a/tokenizers/Cargo.toml +++ b/tokenizers/Cargo.toml @@ -64,7 +64,6 @@ getrandom = { version = "0.2.10" } esaxx-rs = { version = "0.1.10", default-features = false, features=[]} monostate = "0.1.12" derive_more = "0.99.17" -ellipse = "0.2.0" display_derive = { path = "display_derive" } [features] diff --git a/tokenizers/display_derive/Cargo.toml b/tokenizers/display_derive/Cargo.toml index 299b6ae5f..7d7697910 100644 --- a/tokenizers/display_derive/Cargo.toml +++ b/tokenizers/display_derive/Cargo.toml @@ -7,7 +7,6 @@ edition = "2021" syn = "1.0" quote = "1.0" proc-macro2 = "1.0" -ellipse = "0.2.0" [lib] proc-macro = true diff --git a/tokenizers/src/decoders/ctc.rs b/tokenizers/src/decoders/ctc.rs index 7ef687404..d31b68e66 100644 --- a/tokenizers/src/decoders/ctc.rs +++ b/tokenizers/src/decoders/ctc.rs @@ -1,21 +1,15 @@ use crate::decoders::wordpiece; use crate::tokenizer::{Decoder, Result}; -use derive_more::Display; +use display_derive::StructDisplay; use itertools::Itertools; use serde::{Deserialize, Serialize}; -#[derive(Debug, Clone, Serialize, Deserialize, Display)] +#[derive(Debug, Clone, Serialize, Deserialize, StructDisplay)] /// The CTC (Connectionist Temporal Classification) 
decoder takes care /// of sanitizing a list of inputs token. /// Due to some alignement problem the output of some models can come /// with duplicated token. #[serde(tag = "type")] -#[display( - fmt = "CTC(pad_token={}, word_delimiter_token={}, cleanup={}", - pad_token, - word_delimiter_token, - cleanup -)] #[non_exhaustive] pub struct CTC { /// The pad token used by CTC to delimit a new token. diff --git a/tokenizers/src/decoders/sequence.rs b/tokenizers/src/decoders/sequence.rs index bfe21a110..e9e18698c 100644 --- a/tokenizers/src/decoders/sequence.rs +++ b/tokenizers/src/decoders/sequence.rs @@ -7,8 +7,14 @@ use serde::{Deserialize, Serialize}; #[macro_rules_attribute(impl_serde_type!)] #[derive(Clone, Debug, Display)] #[display( - fmt = "[{}]", - "decoders.iter().map(|d| d.to_string()).collect::>().join(\", \")" + fmt = "decoders.Sequence([{}])", + "decoders.iter().map(|d| d.to_string()).fold( String::new(), |mut acc, s|{ + if !acc.is_empty(){ + acc.push_str(\", \"); + } + acc.push_str(&s); + acc + })" )] pub struct Sequence { decoders: Vec, From 6023192a74032c2023d7ae6e098f5a1ea82e7e7f Mon Sep 17 00:00:00 2001 From: Arthur Zucker Date: Mon, 10 Jun 2024 10:37:28 +0200 Subject: [PATCH 33/94] update --- tokenizers/src/decoders/strip.rs | 5 ++--- tokenizers/src/decoders/wordpiece.rs | 5 ++--- 2 files changed, 4 insertions(+), 6 deletions(-) diff --git a/tokenizers/src/decoders/strip.rs b/tokenizers/src/decoders/strip.rs index 344d61489..decdf6876 100644 --- a/tokenizers/src/decoders/strip.rs +++ b/tokenizers/src/decoders/strip.rs @@ -1,14 +1,13 @@ use crate::tokenizer::{Decoder, Result}; -use derive_more::Display; +use display_derive::StructDisplay; use serde::{Deserialize, Serialize}; -#[derive(Deserialize, Clone, Debug, Serialize, Default, Display)] +#[derive(Deserialize, Clone, Debug, Serialize, Default, StructDisplay)] /// Strip is a simple trick which converts tokens looking like `<0x61>` /// to pure bytes, and attempts to make them into a string. If the tokens /// cannot be decoded you will get � instead for each inconvertable byte token #[serde(tag = "type")] #[non_exhaustive] -#[display(fmt = "Strip(content={}, start={}, stop={})", content, start, stop)] pub struct Strip { pub content: char, pub start: usize, diff --git a/tokenizers/src/decoders/wordpiece.rs b/tokenizers/src/decoders/wordpiece.rs index 28494ec0f..a7b85df67 100644 --- a/tokenizers/src/decoders/wordpiece.rs +++ b/tokenizers/src/decoders/wordpiece.rs @@ -1,12 +1,11 @@ use crate::tokenizer::{Decoder, Result}; -use derive_more::Display; +use display_derive::StructDisplay; use serde::{Deserialize, Serialize}; -#[derive(Deserialize, Clone, Debug, Serialize, Display)] +#[derive(Deserialize, Clone, Debug, Serialize, StructDisplay)] /// The WordPiece decoder takes care of decoding a list of wordpiece tokens /// back into a readable string. 
#[serde(tag = "type")] -#[display(fmt = "WordPiece(prefix={}, cleanup={:?}", prefix, cleanup)] #[non_exhaustive] pub struct WordPiece { /// The prefix to be used for continuing subwords From ebf12582924ddb519490828d6a6d76566594ccf6 Mon Sep 17 00:00:00 2001 From: Arthur Zucker Date: Mon, 10 Jun 2024 10:44:58 +0200 Subject: [PATCH 34/94] update --- tokenizers/src/models/unigram/model.rs | 10 ++++++++-- 1 file changed, 8 insertions(+), 2 deletions(-) diff --git a/tokenizers/src/models/unigram/model.rs b/tokenizers/src/models/unigram/model.rs index 34c2832a1..929205684 100644 --- a/tokenizers/src/models/unigram/model.rs +++ b/tokenizers/src/models/unigram/model.rs @@ -17,8 +17,14 @@ type Vocab = Vec<(String, f64)>; /// A `Unigram` model to encode sentences. #[derive(Display)] #[display( - fmt = "Unigram(vocab={:#?}, unk_id={:?}, bos_id={}, eos_id={})", - "vocab.iter().collect::>().truncate(5)", + fmt = "Unigram(vocab={}, unk_id={:?}, bos_id={}, eos_id={})", + "vocab.iter().take(5).fold(String::new(), |mut acc, (key, value)| { + if !acc.is_empty() { + acc.push_str(\", \"); + } + acc.push_str(&format!(\"{}: {}\", value, key)); + acc +})", unk_id, bos_id, eos_id From 477a9b5c818c6e75331c4e17dd8b23e6454a0bdd Mon Sep 17 00:00:00 2001 From: Arthur Zucker Date: Mon, 10 Jun 2024 10:53:38 +0200 Subject: [PATCH 35/94] nits --- tokenizers/src/models/unigram/model.rs | 6 +++--- tokenizers/src/pre_tokenizers/byte_level.rs | 10 ++-------- 2 files changed, 5 insertions(+), 11 deletions(-) diff --git a/tokenizers/src/models/unigram/model.rs b/tokenizers/src/models/unigram/model.rs index 929205684..a409e70c2 100644 --- a/tokenizers/src/models/unigram/model.rs +++ b/tokenizers/src/models/unigram/model.rs @@ -17,15 +17,15 @@ type Vocab = Vec<(String, f64)>; /// A `Unigram` model to encode sentences. #[derive(Display)] #[display( - fmt = "Unigram(vocab={}, unk_id={:?}, bos_id={}, eos_id={})", + fmt = "Unigram(vocab={{{}, ...}}, unk_id={}, bos_id={}, eos_id={})", "vocab.iter().take(5).fold(String::new(), |mut acc, (key, value)| { if !acc.is_empty() { acc.push_str(\", \"); } - acc.push_str(&format!(\"{}: {}\", value, key)); + acc.push_str(&format!(\"{}: \'{}\'\", value, key)); acc })", - unk_id, + "unk_id.unwrap()", bos_id, eos_id )] diff --git a/tokenizers/src/pre_tokenizers/byte_level.rs b/tokenizers/src/pre_tokenizers/byte_level.rs index 4d2b131d6..852a7a3a1 100644 --- a/tokenizers/src/pre_tokenizers/byte_level.rs +++ b/tokenizers/src/pre_tokenizers/byte_level.rs @@ -6,7 +6,7 @@ use crate::tokenizer::{ }; use crate::utils::macro_rules_attribute; use crate::utils::SysRegex; -use derive_more::Display; +use display_derive::StructDisplay; use serde::{Deserialize, Serialize}; /// Converts bytes to unicode characters. @@ -50,13 +50,7 @@ lazy_static! { /// of all the required processing steps to transform a UTF-8 string as needed before and after the /// BPE model does its job. #[macro_rules_attribute(impl_serde_type!)] -#[derive(Copy, Clone, Debug, PartialEq, Eq, Display)] -#[display( - fmt = "ByteLevel(add_prefix_space={},trim_offset={:?}, use_regex={}", - add_prefix_space, - trim_offsets, - use_regex -)] +#[derive(Copy, Clone, Debug, PartialEq, Eq, StructDisplay)] #[non_exhaustive] pub struct ByteLevel { /// Whether to add a leading space to the first word. 
This allows to treat the leading word From 93a1e631b5af2276b081edc3b962fb413e3e8ad4 Mon Sep 17 00:00:00 2001 From: Arthur Zucker Date: Mon, 10 Jun 2024 11:00:33 +0200 Subject: [PATCH 36/94] fix some stuff --- tokenizers/src/models/unigram/model.rs | 2 +- tokenizers/src/models/wordlevel/mod.rs | 8 +++++++- 2 files changed, 8 insertions(+), 2 deletions(-) diff --git a/tokenizers/src/models/unigram/model.rs b/tokenizers/src/models/unigram/model.rs index a409e70c2..95edc2b35 100644 --- a/tokenizers/src/models/unigram/model.rs +++ b/tokenizers/src/models/unigram/model.rs @@ -22,7 +22,7 @@ type Vocab = Vec<(String, f64)>; if !acc.is_empty() { acc.push_str(\", \"); } - acc.push_str(&format!(\"{}: \'{}\'\", value, key)); + acc.push_str(&format!(\"\'{}\': {}\", key, value)); acc })", "unk_id.unwrap()", diff --git a/tokenizers/src/models/wordlevel/mod.rs b/tokenizers/src/models/wordlevel/mod.rs index 4c5bdf90d..2ae32f336 100644 --- a/tokenizers/src/models/wordlevel/mod.rs +++ b/tokenizers/src/models/wordlevel/mod.rs @@ -95,7 +95,13 @@ impl WordLevelBuilder { } #[derive(PartialEq, Clone, Eq, Display)] -#[display(fmt = "vocab={:?}, unk_token={}", vocab, unk_token)] +#[display(fmt = "vocab={{{}, ...}}, unk_token={}", "vocab.iter().take(5).fold(String::new(), |mut acc, (key, value)| { + if !acc.is_empty() { + acc.push_str(\", \"); + } + acc.push_str(&format!(\"\'{}\': {}\", key, value)); + acc +})", unk_token)] pub struct WordLevel { vocab: HashMap, vocab_r: HashMap, From 7591f2bc91c55cf0818f64ad4f42d0f1e7cbe459 Mon Sep 17 00:00:00 2001 From: Arthur Zucker Date: Mon, 10 Jun 2024 11:07:27 +0200 Subject: [PATCH 37/94] update sequence for pre_tokenizers using fold --- tokenizers/src/pre_tokenizers/sequence.rs | 10 ++++++++-- 1 file changed, 8 insertions(+), 2 deletions(-) diff --git a/tokenizers/src/pre_tokenizers/sequence.rs b/tokenizers/src/pre_tokenizers/sequence.rs index acd7ba01c..5a83bebff 100644 --- a/tokenizers/src/pre_tokenizers/sequence.rs +++ b/tokenizers/src/pre_tokenizers/sequence.rs @@ -7,8 +7,14 @@ use serde::{Deserialize, Serialize}; #[macro_rules_attribute(impl_serde_type!)] #[derive(Clone, Debug, PartialEq, Display)] #[display( - fmt = "[{}]", - "pretokenizers.iter().map(|d| d.to_string()).collect::>().join(\", \")" + fmt = "pre_tokenizers.Seqence([{}])", + "pretokenizers.iter().fold(String::new(), |mut acc, p| { + if !acc.is_empty(){ + acc.push_str(\", \") + } + acc.push_str(&p.to_string()); + acc + })" )] pub struct Sequence { pretokenizers: Vec, From f50e4e03776d95b9bbc1266412563b6517ebd1b9 Mon Sep 17 00:00:00 2001 From: Arthur Zucker Date: Mon, 10 Jun 2024 11:13:46 +0200 Subject: [PATCH 38/94] update --- tokenizers/src/processors/mod.rs | 1 - tokenizers/src/processors/sequence.rs | 10 ++++++++-- tokenizers/src/tokenizer/added_vocabulary.rs | 2 +- 3 files changed, 9 insertions(+), 4 deletions(-) diff --git a/tokenizers/src/processors/mod.rs b/tokenizers/src/processors/mod.rs index ed8a86aaa..320f5e7d2 100644 --- a/tokenizers/src/processors/mod.rs +++ b/tokenizers/src/processors/mod.rs @@ -17,7 +17,6 @@ use serde::{Deserialize, Serialize}; #[derive(Serialize, Deserialize, PartialEq, Debug, Clone, Eq, Display)] #[serde(untagged)] -#[display(fmt = "{}")] pub enum PostProcessorWrapper { // Roberta must be before Bert for deserialization (serde does not validate tags) Roberta(RobertaProcessing), diff --git a/tokenizers/src/processors/sequence.rs b/tokenizers/src/processors/sequence.rs index 67e0f5c66..b0a8938a4 100644 --- a/tokenizers/src/processors/sequence.rs +++ 
b/tokenizers/src/processors/sequence.rs @@ -6,8 +6,14 @@ use serde::{Deserialize, Serialize}; #[macro_rules_attribute(impl_serde_type!)] #[derive(Clone, Debug, PartialEq, Eq, Display)] #[display( - fmt = "[{}]", - "processors.iter().map(|d| d.to_string()).collect::>().join(\", \")" + fmt = "processors.Sequence([{}])", + "processors.iter().fold(String::new(), |mut acc, p| { + if !acc.is_empty() { + acc.push_str(\", \"); + } + acc.push_str(&p.to_string()); + acc +})" )] pub struct Sequence { processors: Vec, diff --git a/tokenizers/src/tokenizer/added_vocabulary.rs b/tokenizers/src/tokenizer/added_vocabulary.rs index 49aae1621..642f5670a 100644 --- a/tokenizers/src/tokenizer/added_vocabulary.rs +++ b/tokenizers/src/tokenizer/added_vocabulary.rs @@ -141,7 +141,7 @@ fn space_rightmost_at_start(sentence: &str) -> usize { /// exist as required. /// #[derive(Clone, Debug, Display)] -#[display(fmt="AddedVocabulary(added_tokens_map_r={{{}}}, encode_special_tokens={})", "&(0..=5).fold(String::new(), |mut acc, key| {if let Some(token) = added_tokens_map_r.get(&key){if !acc.is_empty(){acc.push_str(\", \");}acc.push_str(&format!(\"\n\t{}: {}\", key, &token.to_string()));}acc})", encode_special_tokens)] +#[display(fmt="AddedVocabulary(added_tokens_map_r={{{}, ...}}, encode_special_tokens={})", "&(0..=5).fold(String::new(), |mut acc, key| {if let Some(token) = added_tokens_map_r.get(&key){if !acc.is_empty(){acc.push_str(\", \");}acc.push_str(&format!(\"\n\t{}: {}\", key, &token.to_string()));}acc})", encode_special_tokens)] pub struct AddedVocabulary { /// Contains the mapping from String (token content) to ID. This map contains both special /// tokens and classic added tokens that were added to the this vocabulary. From 4f15052c04e77351653e2e7337c3a7c451d42f65 Mon Sep 17 00:00:00 2001 From: Arthur Zucker Date: Mon, 10 Jun 2024 11:25:16 +0200 Subject: [PATCH 39/94] proper padding derive --- tokenizers/src/utils/padding.rs | 6 ++---- 1 file changed, 2 insertions(+), 4 deletions(-) diff --git a/tokenizers/src/utils/padding.rs b/tokenizers/src/utils/padding.rs index d10ea7969..03faf2943 100644 --- a/tokenizers/src/utils/padding.rs +++ b/tokenizers/src/utils/padding.rs @@ -4,7 +4,7 @@ use derive_more::Display; use serde::{Deserialize, Serialize}; /// The various possible padding directions. 
-#[derive(Debug, Clone, Copy, Serialize, Deserialize)] +#[derive(Debug, Clone, Copy, Serialize, Deserialize, Display)] pub enum PaddingDirection { Left, Right, @@ -20,8 +20,7 @@ impl std::convert::AsRef for PaddingDirection { } #[derive(Debug, Clone, Serialize, Deserialize, Display)] -// #[display(fmt="Strategy: {:?}, Direction: {:?}, Pad to multiple of: {:?}, Pad ID: {}, Pad Type ID: {}, Pad Token: {}", strategy, direction, pad_to_multiple_of, pad_id, pad_type_id, pad_token)] -#[display(fmt = "Strategy:")] +#[display(fmt="strategy={}, direction={}, pad_to_multiple_of={}, pad_id={}, pad_type_id={}, pad_token={}", strategy, direction, "pad_to_multiple_of.unwrap()", pad_id, pad_type_id, pad_token)] pub struct PaddingParams { pub strategy: PaddingStrategy, pub direction: PaddingDirection, @@ -45,7 +44,6 @@ impl Default for PaddingParams { } #[derive(Debug, Clone, Serialize, Deserialize, Display)] -#[display(fmt={})] pub enum PaddingStrategy { BatchLongest, Fixed(usize), From 85c7b6920b4302f194292b0d04d2ac5273742573 Mon Sep 17 00:00:00 2001 From: Arthur Zucker Date: Mon, 10 Jun 2024 11:29:52 +0200 Subject: [PATCH 40/94] update trunctation for consistency --- tokenizers/src/utils/truncation.rs | 14 ++++---------- 1 file changed, 4 insertions(+), 10 deletions(-) diff --git a/tokenizers/src/utils/truncation.rs b/tokenizers/src/utils/truncation.rs index 6e8328cd1..c844af597 100644 --- a/tokenizers/src/utils/truncation.rs +++ b/tokenizers/src/utils/truncation.rs @@ -1,10 +1,11 @@ use crate::tokenizer::{Encoding, Result}; +use display_derive::StructDisplay; use derive_more::Display; use serde::{Deserialize, Serialize}; use std::cmp; use std::mem; -#[derive(Debug, Clone, Copy, PartialEq, Serialize, Deserialize, Eq, Default)] +#[derive(Debug, Clone, Copy, PartialEq, Serialize, Deserialize, Eq, Default, Display)] pub enum TruncationDirection { Left, #[default] @@ -20,14 +21,7 @@ impl std::convert::AsRef for TruncationDirection { } } -#[derive(Debug, Clone, Serialize, Deserialize, Display)] -#[display( - fmt = "direction={:?}, max_length={}, strategy={:?}, stride={}", - direction, - max_length, - strategy, - stride -)] +#[derive(Debug, Clone, Serialize, Deserialize, StructDisplay)] pub struct TruncationParams { #[serde(default)] pub direction: TruncationDirection, @@ -57,7 +51,7 @@ pub enum TruncationError { SequenceTooShort, } -#[derive(Debug, Clone, Copy, PartialEq, Serialize, Deserialize, Eq)] +#[derive(Debug, Clone, Copy, PartialEq, Serialize, Deserialize, Eq, Display)] pub enum TruncationStrategy { LongestFirst, OnlyFirst, From 0a16ca0bfce29989a934ec92bcb22c35d6dd5d11 Mon Sep 17 00:00:00 2001 From: Arthur Zucker Date: Mon, 10 Jun 2024 11:32:10 +0200 Subject: [PATCH 41/94] clean --- tokenizers/src/tokenizer/mod.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tokenizers/src/tokenizer/mod.rs b/tokenizers/src/tokenizer/mod.rs index 1c561b3bd..2f19dd36d 100644 --- a/tokenizers/src/tokenizer/mod.rs +++ b/tokenizers/src/tokenizer/mod.rs @@ -560,7 +560,7 @@ where write!( f, - "Tokenizer(normalizer={},\npre_tokenizer={},\nmodel={},\npost_processor={},\ndecoder={},\nadded_vocab={},\ntruncation={},\npadding={}\n)", + "Tokenizer(normalizer={}, pre_tokenizer={}, model={}, post_processor={}, decoder={}, added_vocab={}, truncation={}, padding={})", normalizer_str, pre_tokenizer_str, self.model, From 35d442d748f0f0ce6b09d91f9baa7cf241b20312 Mon Sep 17 00:00:00 2001 From: Arthur Zucker Date: Mon, 10 Jun 2024 11:34:17 +0200 Subject: [PATCH 42/94] styling --- 
tokenizers/src/models/wordlevel/mod.rs | 8 ++++++-- tokenizers/src/utils/padding.rs | 10 +++++++++- tokenizers/src/utils/truncation.rs | 2 +- 3 files changed, 16 insertions(+), 4 deletions(-) diff --git a/tokenizers/src/models/wordlevel/mod.rs b/tokenizers/src/models/wordlevel/mod.rs index 2ae32f336..c59d64419 100644 --- a/tokenizers/src/models/wordlevel/mod.rs +++ b/tokenizers/src/models/wordlevel/mod.rs @@ -95,13 +95,17 @@ impl WordLevelBuilder { } #[derive(PartialEq, Clone, Eq, Display)] -#[display(fmt = "vocab={{{}, ...}}, unk_token={}", "vocab.iter().take(5).fold(String::new(), |mut acc, (key, value)| { +#[display( + fmt = "vocab={{{}, ...}}, unk_token={}", + "vocab.iter().take(5).fold(String::new(), |mut acc, (key, value)| { if !acc.is_empty() { acc.push_str(\", \"); } acc.push_str(&format!(\"\'{}\': {}\", key, value)); acc -})", unk_token)] +})", + unk_token +)] pub struct WordLevel { vocab: HashMap, vocab_r: HashMap, diff --git a/tokenizers/src/utils/padding.rs b/tokenizers/src/utils/padding.rs index 03faf2943..a42762a64 100644 --- a/tokenizers/src/utils/padding.rs +++ b/tokenizers/src/utils/padding.rs @@ -20,7 +20,15 @@ impl std::convert::AsRef for PaddingDirection { } #[derive(Debug, Clone, Serialize, Deserialize, Display)] -#[display(fmt="strategy={}, direction={}, pad_to_multiple_of={}, pad_id={}, pad_type_id={}, pad_token={}", strategy, direction, "pad_to_multiple_of.unwrap()", pad_id, pad_type_id, pad_token)] +#[display( + fmt = "strategy={}, direction={}, pad_to_multiple_of={}, pad_id={}, pad_type_id={}, pad_token={}", + strategy, + direction, + "pad_to_multiple_of.unwrap()", + pad_id, + pad_type_id, + pad_token +)] pub struct PaddingParams { pub strategy: PaddingStrategy, pub direction: PaddingDirection, diff --git a/tokenizers/src/utils/truncation.rs b/tokenizers/src/utils/truncation.rs index c844af597..dfcb06432 100644 --- a/tokenizers/src/utils/truncation.rs +++ b/tokenizers/src/utils/truncation.rs @@ -1,6 +1,6 @@ use crate::tokenizer::{Encoding, Result}; -use display_derive::StructDisplay; use derive_more::Display; +use display_derive::StructDisplay; use serde::{Deserialize, Serialize}; use std::cmp; use std::mem; From a3cc764e95960cf15c8d2d3a37b162730f3ba8f6 Mon Sep 17 00:00:00 2001 From: Arthur Zucker Date: Mon, 10 Jun 2024 11:58:45 +0200 Subject: [PATCH 43/94] update added tokens decoder as getter --- bindings/python/src/tokenizer.rs | 3 +-- tokenizers/src/tokenizer/mod.rs | 2 +- 2 files changed, 2 insertions(+), 3 deletions(-) diff --git a/bindings/python/src/tokenizer.rs b/bindings/python/src/tokenizer.rs index a6cbcf40b..f5c04d4d6 100644 --- a/bindings/python/src/tokenizer.rs +++ b/bindings/python/src/tokenizer.rs @@ -666,8 +666,7 @@ impl PyTokenizer { /// /// Returns: /// :obj:`Dict[int, AddedToken]`: The vocabulary - #[pyo3(signature = ())] - #[pyo3(text_signature = "(self)")] + #[getter] fn get_added_tokens_decoder(&self) -> BTreeMap { let mut sorted_map = BTreeMap::new(); diff --git a/tokenizers/src/tokenizer/mod.rs b/tokenizers/src/tokenizer/mod.rs index 2f19dd36d..35ac22971 100644 --- a/tokenizers/src/tokenizer/mod.rs +++ b/tokenizers/src/tokenizer/mod.rs @@ -560,7 +560,7 @@ where write!( f, - "Tokenizer(normalizer={}, pre_tokenizer={}, model={}, post_processor={}, decoder={}, added_vocab={}, truncation={}, padding={})", + "Tokenizer(normalizer={}, pre_tokenizer={}, model={}, post_processor={}, decoder={}, added_tokens_decoder={}, truncation={}, padding={})", normalizer_str, pre_tokenizer_str, self.model, From 5b20fa71789d4c1f264f5128b83cc979d71e8cd2 Mon Sep 
17 00:00:00 2001 From: Arthur Zucker Date: Mon, 10 Jun 2024 11:59:14 +0200 Subject: [PATCH 44/94] update init property --- .../python/py_src/tokenizers/__init__.pyi | 19 ++++++++++--------- 1 file changed, 10 insertions(+), 9 deletions(-) diff --git a/bindings/python/py_src/tokenizers/__init__.pyi b/bindings/python/py_src/tokenizers/__init__.pyi index 5dbc665dc..3ecef4089 100644 --- a/bindings/python/py_src/tokenizers/__init__.pyi +++ b/bindings/python/py_src/tokenizers/__init__.pyi @@ -725,6 +725,16 @@ class Tokenizer: """ pass + @property + def added_tokens_decoder(self): + """ + Get the underlying vocabulary + + Returns: + :obj:`Dict[int, AddedToken]`: The vocabulary + """ + pass + def decode(self, ids, skip_special_tokens=True): """ Decode the given list of ids back to a string @@ -970,15 +980,6 @@ class Tokenizer: """ pass - def get_added_tokens_decoder(self): - """ - Get the underlying vocabulary - - Returns: - :obj:`Dict[int, AddedToken]`: The vocabulary - """ - pass - def get_vocab(self, with_added_tokens=True): """ Get the underlying vocabulary From 15f877ecd476e9b5437c4da91894750e4646efd2 Mon Sep 17 00:00:00 2001 From: Arthur Zucker Date: Mon, 10 Jun 2024 13:07:14 +0200 Subject: [PATCH 45/94] nit --- bindings/python/src/pre_tokenizers.rs | 1 - tokenizers/src/pre_tokenizers/mod.rs | 1 - 2 files changed, 2 deletions(-) diff --git a/bindings/python/src/pre_tokenizers.rs b/bindings/python/src/pre_tokenizers.rs index 229a4c60d..e06eceade 100644 --- a/bindings/python/src/pre_tokenizers.rs +++ b/bindings/python/src/pre_tokenizers.rs @@ -641,7 +641,6 @@ impl<'de> Deserialize<'de> for CustomPreTokenizer { } #[derive(Clone, Deserialize, Display)] -#[display(fmt = "{}")] #[serde(untagged)] pub(crate) enum PyPreTokenizerWrapper { Custom(CustomPreTokenizer), diff --git a/tokenizers/src/pre_tokenizers/mod.rs b/tokenizers/src/pre_tokenizers/mod.rs index 40c113241..19b13ff8d 100644 --- a/tokenizers/src/pre_tokenizers/mod.rs +++ b/tokenizers/src/pre_tokenizers/mod.rs @@ -25,7 +25,6 @@ use crate::{PreTokenizedString, PreTokenizer}; use derive_more::Display; #[derive(Deserialize, Serialize, Clone, Debug, PartialEq, Display)] -#[display(fmt = "{}")] #[serde(untagged)] pub enum PreTokenizerWrapper { BertPreTokenizer(BertPreTokenizer), From 9c45e8ff842323ac6e9386983611865b97bb68ad Mon Sep 17 00:00:00 2001 From: Arthur Zucker Date: Mon, 10 Jun 2024 14:17:25 +0200 Subject: [PATCH 46/94] update sequences and basic enums to show xxxx.Sequence --- bindings/python/src/processors.rs | 4 ++++ tokenizers/src/decoders/mod.rs | 1 + tokenizers/src/decoders/sequence.rs | 2 +- tokenizers/src/pre_tokenizers/mod.rs | 1 + 4 files changed, 7 insertions(+), 1 deletion(-) diff --git a/bindings/python/src/processors.rs b/bindings/python/src/processors.rs index 7e695879b..1aa55f76e 100644 --- a/bindings/python/src/processors.rs +++ b/bindings/python/src/processors.rs @@ -143,6 +143,10 @@ impl PyPostProcessor { fn __str__(&self) -> PyResult { Ok(format!("{}", &self)) } + + fn __repr__(&self) -> PyResult{ + Ok(format!("{}", &self)) + } } /// This post-processor takes care of adding the special tokens needed by diff --git a/tokenizers/src/decoders/mod.rs b/tokenizers/src/decoders/mod.rs index f1da8df34..d37072195 100644 --- a/tokenizers/src/decoders/mod.rs +++ b/tokenizers/src/decoders/mod.rs @@ -25,6 +25,7 @@ use derive_more::Display; use serde::{Deserialize, Serialize}; #[derive(Serialize, Deserialize, Clone, Debug, Display)] +#[display(fmt="decoders.{})] #[serde(untagged)] pub enum DecoderWrapper { BPE(BPEDecoder), diff 
--git a/tokenizers/src/decoders/sequence.rs b/tokenizers/src/decoders/sequence.rs index e9e18698c..a1863b224 100644 --- a/tokenizers/src/decoders/sequence.rs +++ b/tokenizers/src/decoders/sequence.rs @@ -7,7 +7,7 @@ use serde::{Deserialize, Serialize}; #[macro_rules_attribute(impl_serde_type!)] #[derive(Clone, Debug, Display)] #[display( - fmt = "decoders.Sequence([{}])", + fmt = "Sequence([{}])", "decoders.iter().map(|d| d.to_string()).fold( String::new(), |mut acc, s|{ if !acc.is_empty(){ acc.push_str(\", \"); diff --git a/tokenizers/src/pre_tokenizers/mod.rs b/tokenizers/src/pre_tokenizers/mod.rs index 19b13ff8d..4dde24125 100644 --- a/tokenizers/src/pre_tokenizers/mod.rs +++ b/tokenizers/src/pre_tokenizers/mod.rs @@ -25,6 +25,7 @@ use crate::{PreTokenizedString, PreTokenizer}; use derive_more::Display; #[derive(Deserialize, Serialize, Clone, Debug, PartialEq, Display)] +#[display(fmt="pre_tokenizers.{}")] #[serde(untagged)] pub enum PreTokenizerWrapper { BertPreTokenizer(BertPreTokenizer), From 4a348702a1349c3fcf6d695c109a56664afc38fa Mon Sep 17 00:00:00 2001 From: Arthur Zucker Date: Mon, 10 Jun 2024 14:22:01 +0200 Subject: [PATCH 47/94] update --- tokenizers/src/decoders/mod.rs | 2 +- tokenizers/src/models/wordpiece/mod.rs | 10 ++++++++-- 2 files changed, 9 insertions(+), 3 deletions(-) diff --git a/tokenizers/src/decoders/mod.rs b/tokenizers/src/decoders/mod.rs index d37072195..35cffbae8 100644 --- a/tokenizers/src/decoders/mod.rs +++ b/tokenizers/src/decoders/mod.rs @@ -25,7 +25,7 @@ use derive_more::Display; use serde::{Deserialize, Serialize}; #[derive(Serialize, Deserialize, Clone, Debug, Display)] -#[display(fmt="decoders.{})] +#[display(fmt="decoders.{}")] #[serde(untagged)] pub enum DecoderWrapper { BPE(BPEDecoder), diff --git a/tokenizers/src/models/wordpiece/mod.rs b/tokenizers/src/models/wordpiece/mod.rs index 6497da64f..3fa44402b 100644 --- a/tokenizers/src/models/wordpiece/mod.rs +++ b/tokenizers/src/models/wordpiece/mod.rs @@ -121,8 +121,14 @@ impl WordPieceBuilder { /// model. 
#[derive(Clone, PartialEq, Eq, Display)] #[display( - fmt = "vocab={:?}, unk_token={}, continuing_subword_prefix={:?}", - vocab, + fmt = "vocab={}, unk_token={}, continuing_subword_prefix={:?}", + "vocab.iter().take(5).fold(String::new(), |mut acc, (key, value)| { + if !acc.is_empty() { + acc.push_str(\", \"); + } + acc.push_str(&format!(\"\'{}\': {}\", key, value)); + acc + })", unk_token, continuing_subword_prefix )] From e0d35e0ee130916db0ad128cdab6d177422188c8 Mon Sep 17 00:00:00 2001 From: Arthur Zucker Date: Mon, 10 Jun 2024 14:40:57 +0200 Subject: [PATCH 48/94] update --- tokenizers/src/processors/mod.rs | 1 + tokenizers/src/processors/sequence.rs | 2 +- 2 files changed, 2 insertions(+), 1 deletion(-) diff --git a/tokenizers/src/processors/mod.rs b/tokenizers/src/processors/mod.rs index 320f5e7d2..0e3e92fc0 100644 --- a/tokenizers/src/processors/mod.rs +++ b/tokenizers/src/processors/mod.rs @@ -16,6 +16,7 @@ use derive_more::Display; use serde::{Deserialize, Serialize}; #[derive(Serialize, Deserialize, PartialEq, Debug, Clone, Eq, Display)] +#[display(fmt="processors.{}")] #[serde(untagged)] pub enum PostProcessorWrapper { // Roberta must be before Bert for deserialization (serde does not validate tags) diff --git a/tokenizers/src/processors/sequence.rs b/tokenizers/src/processors/sequence.rs index b0a8938a4..aa829a383 100644 --- a/tokenizers/src/processors/sequence.rs +++ b/tokenizers/src/processors/sequence.rs @@ -6,7 +6,7 @@ use serde::{Deserialize, Serialize}; #[macro_rules_attribute(impl_serde_type!)] #[derive(Clone, Debug, PartialEq, Eq, Display)] #[display( - fmt = "processors.Sequence([{}])", + fmt = "Sequence([{}])", "processors.iter().fold(String::new(), |mut acc, p| { if !acc.is_empty() { acc.push_str(\", \"); From fe95add61247efcf06738a37ffa553ebe6974ba5 Mon Sep 17 00:00:00 2001 From: Arthur Zucker Date: Mon, 10 Jun 2024 14:55:58 +0200 Subject: [PATCH 49/94] some finishing touch --- bindings/python/src/decoders.rs | 1 - bindings/python/src/normalizers.rs | 1 - bindings/python/src/pre_tokenizers.rs | 1 - tokenizers/src/pre_tokenizers/sequence.rs | 2 +- 4 files changed, 1 insertion(+), 4 deletions(-) diff --git a/bindings/python/src/decoders.rs b/bindings/python/src/decoders.rs index 3d88c15d1..3692fd9f0 100644 --- a/bindings/python/src/decoders.rs +++ b/bindings/python/src/decoders.rs @@ -30,7 +30,6 @@ use super::error::ToPyResult; /// a Decoder will return an instance of this class when instantiated. #[pyclass(dict, module = "tokenizers.decoders", name = "Decoder", subclass)] #[derive(Clone, Deserialize, Serialize, Display)] -#[display(fmt = "{}", decoder)] pub struct PyDecoder { #[serde(flatten)] pub(crate) decoder: PyDecoderWrapper, diff --git a/bindings/python/src/normalizers.rs b/bindings/python/src/normalizers.rs index 41d9e708d..3bc1a6e21 100644 --- a/bindings/python/src/normalizers.rs +++ b/bindings/python/src/normalizers.rs @@ -44,7 +44,6 @@ impl PyNormalizedStringMut<'_> { /// Normalizer will return an instance of this class when instantiated. 
#[pyclass(dict, module = "tokenizers.normalizers", name = "Normalizer", subclass)] #[derive(Clone, Serialize, Deserialize, Display, Debug)] -#[display(fmt = "{}", normalizer)] pub struct PyNormalizer { #[serde(flatten)] pub(crate) normalizer: PyNormalizerTypeWrapper, diff --git a/bindings/python/src/pre_tokenizers.rs b/bindings/python/src/pre_tokenizers.rs index e06eceade..3c1366c75 100644 --- a/bindings/python/src/pre_tokenizers.rs +++ b/bindings/python/src/pre_tokenizers.rs @@ -35,7 +35,6 @@ use derive_more::Display; subclass )] #[derive(Clone, Serialize, Deserialize, Display)] -#[display(fmt = "PreTokenizer(pretok={})", pretok)] pub struct PyPreTokenizer { #[serde(flatten)] pub(crate) pretok: PyPreTokenizerTypeWrapper, diff --git a/tokenizers/src/pre_tokenizers/sequence.rs b/tokenizers/src/pre_tokenizers/sequence.rs index 5a83bebff..80190517e 100644 --- a/tokenizers/src/pre_tokenizers/sequence.rs +++ b/tokenizers/src/pre_tokenizers/sequence.rs @@ -7,7 +7,7 @@ use serde::{Deserialize, Serialize}; #[macro_rules_attribute(impl_serde_type!)] #[derive(Clone, Debug, PartialEq, Display)] #[display( - fmt = "pre_tokenizers.Seqence([{}])", + fmt = "Seqence([{}])", "pretokenizers.iter().fold(String::new(), |mut acc, p| { if !acc.is_empty(){ acc.push_str(\", \") From 2770099240b3a2d761b12bb95b232e76662d216a Mon Sep 17 00:00:00 2001 From: Arthur <48595927+ArthurZucker@users.noreply.github.com> Date: Mon, 10 Jun 2024 18:26:21 +0200 Subject: [PATCH 50/94] Update bindings/python/Cargo.toml --- bindings/python/Cargo.toml | 1 - 1 file changed, 1 deletion(-) diff --git a/bindings/python/Cargo.toml b/bindings/python/Cargo.toml index 25e666b20..14050874d 100644 --- a/bindings/python/Cargo.toml +++ b/bindings/python/Cargo.toml @@ -20,7 +20,6 @@ ndarray = "0.15" onig = { version = "6.4", default-features = false } itertools = "0.12" derive_more = "0.99.17" -ellipse = "0.2.0" [dependencies.tokenizers] path = "../../tokenizers" From 3d0eb0aab9a466833166aa95443673eab40ed82c Mon Sep 17 00:00:00 2001 From: Arthur Zucker Date: Mon, 10 Jun 2024 18:39:42 +0200 Subject: [PATCH 51/94] nit --- tokenizers/src/models/wordlevel/mod.rs | 2 +- tokenizers/src/models/wordpiece/mod.rs | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/tokenizers/src/models/wordlevel/mod.rs b/tokenizers/src/models/wordlevel/mod.rs index c59d64419..10aeac106 100644 --- a/tokenizers/src/models/wordlevel/mod.rs +++ b/tokenizers/src/models/wordlevel/mod.rs @@ -96,7 +96,7 @@ impl WordLevelBuilder { #[derive(PartialEq, Clone, Eq, Display)] #[display( - fmt = "vocab={{{}, ...}}, unk_token={}", + fmt = "WordLevel(vocab={{{}, ...}}, unk_token={})", "vocab.iter().take(5).fold(String::new(), |mut acc, (key, value)| { if !acc.is_empty() { acc.push_str(\", \"); diff --git a/tokenizers/src/models/wordpiece/mod.rs b/tokenizers/src/models/wordpiece/mod.rs index 3fa44402b..7261ce111 100644 --- a/tokenizers/src/models/wordpiece/mod.rs +++ b/tokenizers/src/models/wordpiece/mod.rs @@ -121,7 +121,7 @@ impl WordPieceBuilder { /// model. 
#[derive(Clone, PartialEq, Eq, Display)] #[display( - fmt = "vocab={}, unk_token={}, continuing_subword_prefix={:?}", + fmt = "WordPiece(vocab={}, unk_token={}, continuing_subword_prefix={:?})", "vocab.iter().take(5).fold(String::new(), |mut acc, (key, value)| { if !acc.is_empty() { acc.push_str(\", \"); From 11a3601ed71825b11525c4ef06a2171560cd1be5 Mon Sep 17 00:00:00 2001 From: Arthur Zucker Date: Tue, 11 Jun 2024 08:41:23 +0200 Subject: [PATCH 52/94] gracefully handle errors for the proc macro --- tokenizers/display_derive/src/lib.rs | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/tokenizers/display_derive/src/lib.rs b/tokenizers/display_derive/src/lib.rs index 4a3d09628..e3f845e22 100644 --- a/tokenizers/display_derive/src/lib.rs +++ b/tokenizers/display_derive/src/lib.rs @@ -1,7 +1,7 @@ extern crate proc_macro; use proc_macro::TokenStream; use quote::quote; -use syn::{parse_macro_input, Data, DeriveInput, Fields}; +use syn::{parse_macro_input, Error, Data, DeriveInput, Fields}; #[proc_macro_derive(StructDisplay)] pub fn display_derive(input: TokenStream) -> TokenStream { @@ -60,10 +60,10 @@ pub fn display_derive(input: TokenStream) -> TokenStream { } } }, - _ => unimplemented!(), + _ => return Error::new_spanned(&name, "Failed to automatically derive the `Display` trait for this structure.").to_compile_error().into(), } }, - _ => unimplemented!(), + _ => return Error::new_spanned(&name, "Failed to automatically derive the `Display` trait for this structure.").to_compile_error().into(), }; // Convert into a token stream and return it From 2a54482e4ee73b7bd97e92e1c2cf3faa7a32f748 Mon Sep 17 00:00:00 2001 From: Arthur Zucker Date: Tue, 11 Jun 2024 10:28:46 +0200 Subject: [PATCH 53/94] remove derive_more I wanted to remove the derive more crate and implement stuff --- tokenizers/Cargo.toml | 1 - tokenizers/display_derive/src/lib.rs | 68 ++++++++++++++++--- tokenizers/src/decoders/bpe.rs | 3 +- tokenizers/src/decoders/byte_fallback.rs | 5 +- tokenizers/src/decoders/ctc.rs | 4 +- tokenizers/src/decoders/fuse.rs | 2 +- tokenizers/src/decoders/strip.rs | 4 +- tokenizers/src/decoders/wordpiece.rs | 4 +- tokenizers/src/normalizers/prepend.rs | 4 +- tokenizers/src/normalizers/strip.rs | 6 +- tokenizers/src/normalizers/unicode.rs | 12 ++-- tokenizers/src/normalizers/utils.rs | 5 +- tokenizers/src/pre_tokenizers/bert.rs | 4 +- tokenizers/src/pre_tokenizers/byte_level.rs | 4 +- tokenizers/src/pre_tokenizers/delimiter.rs | 4 +- tokenizers/src/pre_tokenizers/digits.rs | 4 +- tokenizers/src/pre_tokenizers/punctuation.rs | 4 +- tokenizers/src/pre_tokenizers/split.rs | 2 +- .../unicode_scripts/pre_tokenizer.rs | 4 +- tokenizers/src/pre_tokenizers/whitespace.rs | 6 +- tokenizers/src/tokenizer/added_vocabulary.rs | 5 +- tokenizers/src/utils/truncation.rs | 5 +- 22 files changed, 101 insertions(+), 59 deletions(-) diff --git a/tokenizers/Cargo.toml b/tokenizers/Cargo.toml index 6e6dd2f22..4d4265bec 100644 --- a/tokenizers/Cargo.toml +++ b/tokenizers/Cargo.toml @@ -63,7 +63,6 @@ fancy-regex = { version = "0.13", optional = true} getrandom = { version = "0.2.10" } esaxx-rs = { version = "0.1.10", default-features = false, features=[]} monostate = "0.1.12" -derive_more = "0.99.17" display_derive = { path = "display_derive" } [features] diff --git a/tokenizers/display_derive/src/lib.rs b/tokenizers/display_derive/src/lib.rs index e3f845e22..2daa2548b 100644 --- a/tokenizers/display_derive/src/lib.rs +++ b/tokenizers/display_derive/src/lib.rs @@ -1,15 +1,15 @@ extern crate proc_macro; 
use proc_macro::TokenStream; use quote::quote; -use syn::{parse_macro_input, Error, Data, DeriveInput, Fields}; +use syn::{parse_macro_input, Data, DeriveInput, Fields}; -#[proc_macro_derive(StructDisplay)] +#[proc_macro_derive(Display)] pub fn display_derive(input: TokenStream) -> TokenStream { // Parse the input tokens into a syntax tree let input = parse_macro_input!(input as DeriveInput); // Get the name of the struct - let name = input.ident; + let name = &input.ident; // Generate code to match the struct's fields let expanded = match input.data { @@ -18,7 +18,6 @@ pub fn display_derive(input: TokenStream) -> TokenStream { Fields::Named(fields) => { // If the struct has named fields let field_names = fields.named.iter().map(|f| &f.ident); - let field_names2 = field_names.clone(); let field_types = fields.named.iter().map(|f| &f.ty); quote! { impl std::fmt::Display for #name { @@ -31,7 +30,7 @@ pub fn display_derive(input: TokenStream) -> TokenStream { } first = false; - let field_value = &self.#field_names2; + let field_value = &self.#field_names; write!(f, "{}=", stringify!(#field_names))?; if std::any::TypeId::of::<#field_types>() == std::any::TypeId::of::(){ write!(f, "\"{}\"", field_value)?; @@ -60,12 +59,61 @@ pub fn display_derive(input: TokenStream) -> TokenStream { } } }, - _ => return Error::new_spanned(&name, "Failed to automatically derive the `Display` trait for this structure.").to_compile_error().into(), + Fields::Unnamed(_) => { + quote! { + compile_error!("Unnamed fields for struct are not supported."); + } + }, } }, - _ => return Error::new_spanned(&name, "Failed to automatically derive the `Display` trait for this structure.").to_compile_error().into(), - }; - - // Convert into a token stream and return it + Data::Enum(ref data_enum) => { + let variants = &data_enum.variants; + let display_impls = variants.iter().map(|variant| { + let ident = &variant.ident; + if let Some((_, meta)) = variant.attrs.iter().find(|(path, _)| path.is_ident("display")) { + if let Ok(Meta::List(MetaList { nested, .. })) = meta.parse_meta() { + let format_args = nested.iter().map(|nested_meta| { + if let NestedMeta::Lit(Lit::Str(s)) = nested_meta { + quote! { #s } + } else { + quote! { compile_error!("Invalid format argument"); } + } + }); + quote! { + impl std::fmt::Display for #name { + fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result { + match self { + Self::#ident(#format_args) => write!(f, "{}", stringify!(#ident)), + _ => unreachable!(), + } + } + } + } + } else { + quote! { + compile_error!("Invalid display attribute format"); + } + } + } else { + quote! { + impl std::fmt::Display for #name { + fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result { + write!(f, "{}", stringify!(#ident)) + } + } + } + } + }); + quote! { + #(#display_impls)* + } + }, + Data::Union(_) => { + quote! 
{ + compile_error!("Unions are not supported for Display derive"); + + } + }; + } TokenStream::from(expanded) } diff --git a/tokenizers/src/decoders/bpe.rs b/tokenizers/src/decoders/bpe.rs index 654e61a0e..5636c6524 100644 --- a/tokenizers/src/decoders/bpe.rs +++ b/tokenizers/src/decoders/bpe.rs @@ -1,6 +1,5 @@ use crate::tokenizer::{Decoder, Result}; - -use derive_more::Display; +use display_derive::Display; use serde::{Deserialize, Serialize}; #[derive(Deserialize, Clone, Debug, Serialize, Display)] /// Allows decoding Original BPE by joining all the tokens and then replacing diff --git a/tokenizers/src/decoders/byte_fallback.rs b/tokenizers/src/decoders/byte_fallback.rs index ea9390db1..69817c1b4 100644 --- a/tokenizers/src/decoders/byte_fallback.rs +++ b/tokenizers/src/decoders/byte_fallback.rs @@ -1,9 +1,8 @@ use crate::tokenizer::{Decoder, Result}; use monostate::MustBe; - -use derive_more::Display; +use display_derive::Display; use serde::{Deserialize, Serialize}; -#[derive(Deserialize, Clone, Debug, Serialize, Default, Display)] +#[derive(Deserialize, Clone Debug, Serialize, Default, Display)] /// ByteFallback is a simple trick which converts tokens looking like `<0x61>` /// to pure bytes, and attempts to make them into a string. If the tokens /// cannot be decoded you will get � instead for each inconvertable byte token diff --git a/tokenizers/src/decoders/ctc.rs b/tokenizers/src/decoders/ctc.rs index d31b68e66..f96e71f3e 100644 --- a/tokenizers/src/decoders/ctc.rs +++ b/tokenizers/src/decoders/ctc.rs @@ -1,10 +1,10 @@ use crate::decoders::wordpiece; use crate::tokenizer::{Decoder, Result}; -use display_derive::StructDisplay; +use display_derive::Display; use itertools::Itertools; use serde::{Deserialize, Serialize}; -#[derive(Debug, Clone, Serialize, Deserialize, StructDisplay)] +#[derive(Debug, Clone, Serialize, Deserialize, Display)] /// The CTC (Connectionist Temporal Classification) decoder takes care /// of sanitizing a list of inputs token. /// Due to some alignement problem the output of some models can come diff --git a/tokenizers/src/decoders/fuse.rs b/tokenizers/src/decoders/fuse.rs index 9208afa97..b91485eec 100644 --- a/tokenizers/src/decoders/fuse.rs +++ b/tokenizers/src/decoders/fuse.rs @@ -1,6 +1,6 @@ use crate::tokenizer::{Decoder, Result}; -use derive_more::Display; use monostate::MustBe; +use display_derive::Display; use serde::{Deserialize, Serialize}; #[derive(Clone, Debug, Serialize, Deserialize, Default, Display)] /// Fuse simply fuses all tokens into one big string. diff --git a/tokenizers/src/decoders/strip.rs b/tokenizers/src/decoders/strip.rs index decdf6876..93d085a45 100644 --- a/tokenizers/src/decoders/strip.rs +++ b/tokenizers/src/decoders/strip.rs @@ -1,8 +1,8 @@ use crate::tokenizer::{Decoder, Result}; -use display_derive::StructDisplay; +use display_derive::Display; use serde::{Deserialize, Serialize}; -#[derive(Deserialize, Clone, Debug, Serialize, Default, StructDisplay)] +#[derive(Deserialize, Clone, Debug, Serialize, Default, Display)] /// Strip is a simple trick which converts tokens looking like `<0x61>` /// to pure bytes, and attempts to make them into a string. 
If the tokens /// cannot be decoded you will get � instead for each inconvertable byte token diff --git a/tokenizers/src/decoders/wordpiece.rs b/tokenizers/src/decoders/wordpiece.rs index a7b85df67..c8bd57c06 100644 --- a/tokenizers/src/decoders/wordpiece.rs +++ b/tokenizers/src/decoders/wordpiece.rs @@ -1,8 +1,8 @@ use crate::tokenizer::{Decoder, Result}; -use display_derive::StructDisplay; +use display_derive::Display; use serde::{Deserialize, Serialize}; -#[derive(Deserialize, Clone, Debug, Serialize, StructDisplay)] +#[derive(Deserialize, Clone, Debug, Serialize, Display)] /// The WordPiece decoder takes care of decoding a list of wordpiece tokens /// back into a readable string. #[serde(tag = "type")] diff --git a/tokenizers/src/normalizers/prepend.rs b/tokenizers/src/normalizers/prepend.rs index 7cdd7245d..a9e6ded60 100644 --- a/tokenizers/src/normalizers/prepend.rs +++ b/tokenizers/src/normalizers/prepend.rs @@ -1,8 +1,8 @@ use crate::tokenizer::{NormalizedString, Normalizer, Result}; -use display_derive::StructDisplay; +use display_derive::Display; use serde::{Deserialize, Serialize}; -#[derive(Clone, Debug, Deserialize, Serialize, StructDisplay)] +#[derive(Clone, Debug, Deserialize, Serialize, Display)] #[serde(tag = "type")] pub struct Prepend { pub prepend: String, diff --git a/tokenizers/src/normalizers/strip.rs b/tokenizers/src/normalizers/strip.rs index b188df19a..ef298cc03 100644 --- a/tokenizers/src/normalizers/strip.rs +++ b/tokenizers/src/normalizers/strip.rs @@ -1,9 +1,9 @@ use crate::tokenizer::{NormalizedString, Normalizer, Result}; use crate::utils::macro_rules_attribute; -use display_derive::StructDisplay; +use display_derive::Display; use serde::{Deserialize, Serialize}; use unicode_normalization_alignments::char::is_combining_mark; -#[derive(Copy, Clone, Debug, Deserialize, Serialize, StructDisplay)] +#[derive(Copy, Clone, Debug, Deserialize, Serialize, Display)] #[serde(tag = "type")] #[non_exhaustive] pub struct Strip { @@ -43,7 +43,7 @@ impl Normalizer for Strip { // This normalizer removes combining marks from a normalized string // It's different from unidecode as it does not attempt to modify // non ascii languages. 
-#[derive(Copy, Clone, Debug, StructDisplay)] +#[derive(Copy, Clone, Debug, Display)] #[macro_rules_attribute(impl_serde_type!)] pub struct StripAccents; diff --git a/tokenizers/src/normalizers/unicode.rs b/tokenizers/src/normalizers/unicode.rs index 9a1e657cd..8cdfcf1dd 100644 --- a/tokenizers/src/normalizers/unicode.rs +++ b/tokenizers/src/normalizers/unicode.rs @@ -1,8 +1,8 @@ use crate::tokenizer::{NormalizedString, Normalizer, Result}; use crate::utils::macro_rules_attribute; -use display_derive::StructDisplay; +use display_derive::Display; -#[derive(Default, Copy, Clone, Debug, StructDisplay)] +#[derive(Default, Copy, Clone, Debug, Display)] #[macro_rules_attribute(impl_serde_type!)] pub struct NFD; impl Normalizer for NFD { @@ -12,7 +12,7 @@ impl Normalizer for NFD { } } -#[derive(Default, Copy, Clone, Debug, StructDisplay)] +#[derive(Default, Copy, Clone, Debug, Display)] #[macro_rules_attribute(impl_serde_type!)] pub struct NFKD; impl Normalizer for NFKD { @@ -22,7 +22,7 @@ impl Normalizer for NFKD { } } -#[derive(Default, Copy, Clone, Debug, StructDisplay)] +#[derive(Default, Copy, Clone, Debug, Display)] #[macro_rules_attribute(impl_serde_type!)] pub struct NFC; impl Normalizer for NFC { @@ -32,7 +32,7 @@ impl Normalizer for NFC { } } -#[derive(Default, Copy, Clone, Debug, StructDisplay)] +#[derive(Default, Copy, Clone, Debug, Display)] #[macro_rules_attribute(impl_serde_type!)] pub struct NFKC; impl Normalizer for NFKC { @@ -73,7 +73,7 @@ fn do_nmt(normalized: &mut NormalizedString) { }); } -#[derive(Default, Copy, Clone, Debug, StructDisplay)] +#[derive(Default, Copy, Clone, Debug, Display)] #[macro_rules_attribute(impl_serde_type!)] pub struct Nmt; impl Normalizer for Nmt { diff --git a/tokenizers/src/normalizers/utils.rs b/tokenizers/src/normalizers/utils.rs index ddd18b9b4..c241fcc4d 100644 --- a/tokenizers/src/normalizers/utils.rs +++ b/tokenizers/src/normalizers/utils.rs @@ -3,8 +3,7 @@ use serde::{Deserialize, Serialize}; use crate::normalizers::NormalizerWrapper; use crate::tokenizer::{NormalizedString, Normalizer, Result}; use crate::utils::macro_rules_attribute; -use derive_more::Display; -use display_derive::StructDisplay; +use display_derive::Display; #[derive(Clone, Deserialize, Debug, Serialize, Display)] #[display( fmt = "Sequence([{}])", @@ -47,7 +46,7 @@ impl Normalizer for Sequence { } /// Lowercases the input -#[derive(Copy, Clone, Debug, StructDisplay)] +#[derive(Copy, Clone, Debug, Display)] #[macro_rules_attribute(impl_serde_type!)] pub struct Lowercase; diff --git a/tokenizers/src/pre_tokenizers/bert.rs b/tokenizers/src/pre_tokenizers/bert.rs index eeaa0c315..3551a6a63 100644 --- a/tokenizers/src/pre_tokenizers/bert.rs +++ b/tokenizers/src/pre_tokenizers/bert.rs @@ -1,13 +1,13 @@ use crate::tokenizer::{PreTokenizedString, PreTokenizer, Result, SplitDelimiterBehavior}; use crate::utils::macro_rules_attribute; -use display_derive::StructDisplay; +use display_derive::Display; use unicode_categories::UnicodeCategories; fn is_bert_punc(x: char) -> bool { char::is_ascii_punctuation(&x) || x.is_punctuation() } -#[derive(Copy, Clone, Debug, PartialEq, Eq, StructDisplay)] +#[derive(Copy, Clone, Debug, PartialEq, Eq, Display)] #[macro_rules_attribute(impl_serde_type!)] pub struct BertPreTokenizer; diff --git a/tokenizers/src/pre_tokenizers/byte_level.rs b/tokenizers/src/pre_tokenizers/byte_level.rs index 852a7a3a1..0693449de 100644 --- a/tokenizers/src/pre_tokenizers/byte_level.rs +++ b/tokenizers/src/pre_tokenizers/byte_level.rs @@ -6,7 +6,7 @@ use 
crate::tokenizer::{ }; use crate::utils::macro_rules_attribute; use crate::utils::SysRegex; -use display_derive::StructDisplay; +use display_derive::Display; use serde::{Deserialize, Serialize}; /// Converts bytes to unicode characters. @@ -50,7 +50,7 @@ lazy_static! { /// of all the required processing steps to transform a UTF-8 string as needed before and after the /// BPE model does its job. #[macro_rules_attribute(impl_serde_type!)] -#[derive(Copy, Clone, Debug, PartialEq, Eq, StructDisplay)] +#[derive(Copy, Clone, Debug, PartialEq, Eq, Display)] #[non_exhaustive] pub struct ByteLevel { /// Whether to add a leading space to the first word. This allows to treat the leading word diff --git a/tokenizers/src/pre_tokenizers/delimiter.rs b/tokenizers/src/pre_tokenizers/delimiter.rs index fb445203e..37428f52c 100644 --- a/tokenizers/src/pre_tokenizers/delimiter.rs +++ b/tokenizers/src/pre_tokenizers/delimiter.rs @@ -1,10 +1,10 @@ -use display_derive::StructDisplay; +use display_derive::Display; use serde::{Deserialize, Serialize}; use crate::tokenizer::{PreTokenizedString, PreTokenizer, Result, SplitDelimiterBehavior}; use crate::utils::macro_rules_attribute; -#[derive(Copy, Clone, Debug, PartialEq, Eq, StructDisplay)] +#[derive(Copy, Clone, Debug, PartialEq, Eq, Display)] #[non_exhaustive] #[macro_rules_attribute(impl_serde_type!)] pub struct CharDelimiterSplit { diff --git a/tokenizers/src/pre_tokenizers/digits.rs b/tokenizers/src/pre_tokenizers/digits.rs index 5fb2a5b41..393817157 100644 --- a/tokenizers/src/pre_tokenizers/digits.rs +++ b/tokenizers/src/pre_tokenizers/digits.rs @@ -1,10 +1,10 @@ -use display_derive::StructDisplay; +use display_derive::Display; use serde::{Deserialize, Serialize}; use crate::tokenizer::{PreTokenizedString, PreTokenizer, Result, SplitDelimiterBehavior}; use crate::utils::macro_rules_attribute; -#[derive(Clone, Debug, PartialEq, Eq, StructDisplay)] +#[derive(Clone, Debug, PartialEq, Eq, Display)] /// Pre tokenizes the numbers into single tokens. If individual_digits is set /// to true, then all digits are splitted into individual tokens. 
#[non_exhaustive] diff --git a/tokenizers/src/pre_tokenizers/punctuation.rs b/tokenizers/src/pre_tokenizers/punctuation.rs index 4421f246b..b1cb01323 100644 --- a/tokenizers/src/pre_tokenizers/punctuation.rs +++ b/tokenizers/src/pre_tokenizers/punctuation.rs @@ -1,4 +1,4 @@ -use display_derive::StructDisplay; +use display_derive::Display; use serde::{Deserialize, Serialize}; use crate::tokenizer::{PreTokenizedString, PreTokenizer, Result, SplitDelimiterBehavior}; @@ -9,7 +9,7 @@ fn is_punc(x: char) -> bool { char::is_ascii_punctuation(&x) || x.is_punctuation() } -#[derive(Copy, Clone, Debug, PartialEq, Eq, StructDisplay)] +#[derive(Copy, Clone, Debug, PartialEq, Eq, Display)] #[macro_rules_attribute(impl_serde_type!)] pub struct Punctuation { #[serde(default = "default_split")] diff --git a/tokenizers/src/pre_tokenizers/split.rs b/tokenizers/src/pre_tokenizers/split.rs index a442a5884..fa0134b6b 100644 --- a/tokenizers/src/pre_tokenizers/split.rs +++ b/tokenizers/src/pre_tokenizers/split.rs @@ -2,7 +2,7 @@ use crate::tokenizer::{ pattern::Invert, PreTokenizedString, PreTokenizer, Result, SplitDelimiterBehavior, }; use crate::utils::SysRegex; -use derive_more::Display; +use display_derive::Display; use serde::{Deserialize, Deserializer, Serialize}; /// Represents the different patterns that `Split` can use diff --git a/tokenizers/src/pre_tokenizers/unicode_scripts/pre_tokenizer.rs b/tokenizers/src/pre_tokenizers/unicode_scripts/pre_tokenizer.rs index 7df5e0367..cffe5c47e 100644 --- a/tokenizers/src/pre_tokenizers/unicode_scripts/pre_tokenizer.rs +++ b/tokenizers/src/pre_tokenizers/unicode_scripts/pre_tokenizer.rs @@ -1,10 +1,10 @@ -use display_derive::StructDisplay; +use display_derive::Display; use crate::pre_tokenizers::unicode_scripts::scripts::{get_script, Script}; use crate::tokenizer::{normalizer::Range, PreTokenizedString, PreTokenizer, Result}; use crate::utils::macro_rules_attribute; -#[derive(Clone, Debug, PartialEq, Eq, StructDisplay)] +#[derive(Clone, Debug, PartialEq, Eq, Display)] #[macro_rules_attribute(impl_serde_type!)] pub struct UnicodeScripts; diff --git a/tokenizers/src/pre_tokenizers/whitespace.rs b/tokenizers/src/pre_tokenizers/whitespace.rs index 0bce6d178..12dd60346 100644 --- a/tokenizers/src/pre_tokenizers/whitespace.rs +++ b/tokenizers/src/pre_tokenizers/whitespace.rs @@ -1,4 +1,4 @@ -use display_derive::StructDisplay; +use display_derive::Display; use regex::Regex; use crate::tokenizer::{ @@ -6,7 +6,7 @@ use crate::tokenizer::{ }; use crate::utils::macro_rules_attribute; -#[derive(Clone, Debug, PartialEq, Eq, StructDisplay)] +#[derive(Clone, Debug, PartialEq, Eq, Display)] #[macro_rules_attribute(impl_serde_type!)] pub struct Whitespace; @@ -29,7 +29,7 @@ impl PreTokenizer for Whitespace { } } -#[derive(Copy, Clone, Debug, PartialEq, Eq, StructDisplay)] +#[derive(Copy, Clone, Debug, PartialEq, Eq, Display)] #[macro_rules_attribute(impl_serde_type!)] pub struct WhitespaceSplit; diff --git a/tokenizers/src/tokenizer/added_vocabulary.rs b/tokenizers/src/tokenizer/added_vocabulary.rs index 642f5670a..c8c147b44 100644 --- a/tokenizers/src/tokenizer/added_vocabulary.rs +++ b/tokenizers/src/tokenizer/added_vocabulary.rs @@ -2,8 +2,7 @@ use super::{ normalizer::Range, Model, NormalizedString, Normalizer, Offsets, PreTokenizedString, Token, }; use aho_corasick::{AhoCorasick, AhoCorasickBuilder, MatchKind}; -use derive_more::Display; -use display_derive::StructDisplay; +use display_derive::Display; use regex::Regex; use serde::{ser::SerializeSeq, Deserialize, Serialize, 
Serializer}; use std::collections::{HashMap, HashSet}; @@ -13,7 +12,7 @@ use std::collections::{HashMap, HashSet}; /// like: /// - Whether they should only match single words /// - Whether to include any whitespace on its left or right -#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq, StructDisplay)] +#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq, Display)] pub struct AddedToken { /// The content of the added token pub content: String, diff --git a/tokenizers/src/utils/truncation.rs b/tokenizers/src/utils/truncation.rs index dfcb06432..f73208a51 100644 --- a/tokenizers/src/utils/truncation.rs +++ b/tokenizers/src/utils/truncation.rs @@ -1,6 +1,5 @@ use crate::tokenizer::{Encoding, Result}; -use derive_more::Display; -use display_derive::StructDisplay; +use display_derive::Display; use serde::{Deserialize, Serialize}; use std::cmp; use std::mem; @@ -21,7 +20,7 @@ impl std::convert::AsRef for TruncationDirection { } } -#[derive(Debug, Clone, Serialize, Deserialize, StructDisplay)] +#[derive(Debug, Clone, Serialize, Deserialize, Display)] pub struct TruncationParams { #[serde(default)] pub direction: TruncationDirection, From 998b2a33e6c13f6112ad5464662356e56a0f1ede Mon Sep 17 00:00:00 2001 From: Arthur Zucker Date: Tue, 11 Jun 2024 10:30:18 +0200 Subject: [PATCH 54/94] update my custom macro --- tokenizers/display_derive/src/lib.rs | 29 ++++++++++++++++------------ 1 file changed, 17 insertions(+), 12 deletions(-) diff --git a/tokenizers/display_derive/src/lib.rs b/tokenizers/display_derive/src/lib.rs index 2daa2548b..227fae9aa 100644 --- a/tokenizers/display_derive/src/lib.rs +++ b/tokenizers/display_derive/src/lib.rs @@ -1,7 +1,7 @@ extern crate proc_macro; use proc_macro::TokenStream; use quote::quote; -use syn::{parse_macro_input, Data, DeriveInput, Fields}; +use syn::{parse_macro_input, Data, DeriveInput, Fields, Lit, Meta, MetaList, NestedMeta}; #[proc_macro_derive(Display)] pub fn display_derive(input: TokenStream) -> TokenStream { @@ -70,20 +70,25 @@ pub fn display_derive(input: TokenStream) -> TokenStream { let variants = &data_enum.variants; let display_impls = variants.iter().map(|variant| { let ident = &variant.ident; - if let Some((_, meta)) = variant.attrs.iter().find(|(path, _)| path.is_ident("display")) { - if let Ok(Meta::List(MetaList { nested, .. })) = meta.parse_meta() { - let format_args = nested.iter().map(|nested_meta| { - if let NestedMeta::Lit(Lit::Str(s)) = nested_meta { - quote! { #s } - } else { - quote! { compile_error!("Invalid format argument"); } + if let Some(attr) = variant.attrs.iter().find(|attr| attr.path.is_ident("display")) { + if let Ok(Meta::List(meta_list)) = attr.parse_meta() { + let format_args = meta_list.nested.iter().map(|nested_meta| { + match nested_meta { + NestedMeta::Meta(Meta::NameValue(nv)) if nv.path.is_ident("fmt") => { + if let syn::Lit::Str(s) = &nv.lit { + quote! { #s } + } else { + quote! { compile_error!("Invalid format argument"); } + } + } + _ => quote! { compile_error!("Invalid format argument"); }, } - }); + }).collect::>(); // Collect into a Vec quote! 
{ impl std::fmt::Display for #name { fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result { match self { - Self::#ident(#format_args) => write!(f, "{}", stringify!(#ident)), + Self::#ident(#(#format_args),*) => write!(f, "{}", format_args!(#(#format_args),*)), _ => unreachable!(), } } @@ -113,7 +118,7 @@ pub fn display_derive(input: TokenStream) -> TokenStream { compile_error!("Unions are not supported for Display derive"); } - }; - } + } + }; TokenStream::from(expanded) } From 4df6cc209897c82f4d270239f20d0440fbfa193a Mon Sep 17 00:00:00 2001 From: Arthur Zucker Date: Tue, 11 Jun 2024 10:33:16 +0200 Subject: [PATCH 55/94] replace derive more --- tokenizers/src/decoders/mod.rs | 2 +- tokenizers/src/decoders/sequence.rs | 2 +- tokenizers/src/models/mod.rs | 2 +- tokenizers/src/models/unigram/model.rs | 2 +- tokenizers/src/models/wordlevel/mod.rs | 2 +- tokenizers/src/models/wordpiece/mod.rs | 2 +- tokenizers/src/normalizers/bert.rs | 2 +- tokenizers/src/normalizers/mod.rs | 2 +- tokenizers/src/normalizers/replace.rs | 2 +- tokenizers/src/pre_tokenizers/metaspace.rs | 2 +- tokenizers/src/pre_tokenizers/mod.rs | 2 +- tokenizers/src/pre_tokenizers/sequence.rs | 2 +- tokenizers/src/processors/bert.rs | 2 +- tokenizers/src/processors/mod.rs | 2 +- tokenizers/src/processors/roberta.rs | 2 +- tokenizers/src/processors/sequence.rs | 2 +- tokenizers/src/processors/template.rs | 2 +- tokenizers/src/tokenizer/normalizer.rs | 2 +- tokenizers/src/utils/padding.rs | 2 +- 19 files changed, 19 insertions(+), 19 deletions(-) diff --git a/tokenizers/src/decoders/mod.rs b/tokenizers/src/decoders/mod.rs index 35cffbae8..1bc6b62bb 100644 --- a/tokenizers/src/decoders/mod.rs +++ b/tokenizers/src/decoders/mod.rs @@ -21,7 +21,7 @@ use crate::normalizers::replace::Replace; use crate::pre_tokenizers::byte_level::ByteLevel; use crate::pre_tokenizers::metaspace::Metaspace; use crate::{Decoder, Result}; -use derive_more::Display; +use display_derive::Display; use serde::{Deserialize, Serialize}; #[derive(Serialize, Deserialize, Clone, Debug, Display)] diff --git a/tokenizers/src/decoders/sequence.rs b/tokenizers/src/decoders/sequence.rs index a1863b224..ae9784f9c 100644 --- a/tokenizers/src/decoders/sequence.rs +++ b/tokenizers/src/decoders/sequence.rs @@ -1,7 +1,7 @@ use crate::decoders::DecoderWrapper; use crate::tokenizer::{Decoder, Result}; use crate::utils::macro_rules_attribute; -use derive_more::Display; +use display_derive::Display; use serde::{Deserialize, Serialize}; #[macro_rules_attribute(impl_serde_type!)] diff --git a/tokenizers/src/models/mod.rs b/tokenizers/src/models/mod.rs index 49a31211a..4e3b7d61b 100644 --- a/tokenizers/src/models/mod.rs +++ b/tokenizers/src/models/mod.rs @@ -8,7 +8,7 @@ pub mod wordpiece; use std::collections::HashMap; use std::path::{Path, PathBuf}; -use derive_more::Display; +use display_derive::Display; use serde::{Deserialize, Serialize, Serializer}; use crate::models::bpe::{BpeTrainer, BPE}; diff --git a/tokenizers/src/models/unigram/model.rs b/tokenizers/src/models/unigram/model.rs index 95edc2b35..f00953cb7 100644 --- a/tokenizers/src/models/unigram/model.rs +++ b/tokenizers/src/models/unigram/model.rs @@ -6,7 +6,7 @@ use super::{ use crate::tokenizer::{Model, Result, Token}; use crate::utils::cache::Cache; -use derive_more::Display; +use display_derive::Display; use std::collections::HashMap; use std::convert::TryInto; use std::fs::read_to_string; diff --git a/tokenizers/src/models/wordlevel/mod.rs b/tokenizers/src/models/wordlevel/mod.rs index 
10aeac106..09739e07d 100644 --- a/tokenizers/src/models/wordlevel/mod.rs +++ b/tokenizers/src/models/wordlevel/mod.rs @@ -1,6 +1,6 @@ use super::OrderedVocabIter; use crate::tokenizer::{Model, Result, Token}; -use derive_more::Display; +use display_derive::Display; use serde_json::Value; use std::collections::HashMap; use std::fs::File; diff --git a/tokenizers/src/models/wordpiece/mod.rs b/tokenizers/src/models/wordpiece/mod.rs index 7261ce111..8b8737e44 100644 --- a/tokenizers/src/models/wordpiece/mod.rs +++ b/tokenizers/src/models/wordpiece/mod.rs @@ -3,7 +3,7 @@ use crate::models::bpe::BPE; use crate::tokenizer::{Model, Result, Token}; -use derive_more::Display; +use display_derive::Display; use std::{ borrow::Cow, collections::HashMap, diff --git a/tokenizers/src/normalizers/bert.rs b/tokenizers/src/normalizers/bert.rs index e58d028d6..1e8e6ebf8 100644 --- a/tokenizers/src/normalizers/bert.rs +++ b/tokenizers/src/normalizers/bert.rs @@ -1,6 +1,6 @@ use crate::tokenizer::{NormalizedString, Normalizer, Result}; -use derive_more::Display; +use display_derive::Display; use serde::{Deserialize, Serialize}; use unicode_categories::UnicodeCategories; /// Checks whether a character is whitespace diff --git a/tokenizers/src/normalizers/mod.rs b/tokenizers/src/normalizers/mod.rs index e0f974346..fd097831b 100644 --- a/tokenizers/src/normalizers/mod.rs +++ b/tokenizers/src/normalizers/mod.rs @@ -15,7 +15,7 @@ pub use crate::normalizers::unicode::{Nmt, NFC, NFD, NFKC, NFKD}; pub use crate::normalizers::utils::{Lowercase, Sequence}; use crate::{NormalizedString, Normalizer}; -use derive_more::Display; +use display_derive::Display; use serde::{Deserialize, Serialize}; /// Wrapper for known Normalizers. diff --git a/tokenizers/src/normalizers/replace.rs b/tokenizers/src/normalizers/replace.rs index d79c141b9..41a316942 100644 --- a/tokenizers/src/normalizers/replace.rs +++ b/tokenizers/src/normalizers/replace.rs @@ -2,7 +2,7 @@ use crate::tokenizer::pattern::Pattern; use crate::tokenizer::Decoder; use crate::tokenizer::{NormalizedString, Normalizer, Result}; use crate::utils::SysRegex; -use derive_more::Display; +use display_derive::Display; use serde::{Deserialize, Serialize}; /// Represents the different patterns that `Replace` can use #[derive(Debug, Clone, PartialEq, Serialize, Deserialize, Eq)] diff --git a/tokenizers/src/pre_tokenizers/metaspace.rs b/tokenizers/src/pre_tokenizers/metaspace.rs index 6caf3e373..96fa66346 100644 --- a/tokenizers/src/pre_tokenizers/metaspace.rs +++ b/tokenizers/src/pre_tokenizers/metaspace.rs @@ -1,5 +1,5 @@ use crate::tokenizer::{Decoder, PreTokenizedString, PreTokenizer, Result, SplitDelimiterBehavior}; -use derive_more::Display; +use display_derive::Display; use serde::{de, Deserialize, Deserializer, Serialize}; /// Enum representing options for the metaspace prepending scheme. 
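Most files in this patch only swap the import, but the types that carry a `#[display(fmt = "...", ...)]` attribute (for example `Sequence([{}])` in `normalizers/utils.rs` and `pre_tokenizers.{}` in `pre_tokenizers/mod.rs`) now rely on the custom macro to interpret that attribute. A hand-written sketch of what such an attribute boils down to, using an invented stand-in enum rather than the crate's own wrappers:

// Stand-in enum; the real wrappers hold normalizer / pre-tokenizer variants.
enum Wrapper {
    Whitespace(String),
    Sequence(Vec<String>),
}

impl std::fmt::Display for Wrapper {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            // roughly what a `#[display(fmt = "pre_tokenizers.{}", _0)]`-style attribute asks for
            Wrapper::Whitespace(inner) => write!(f, "pre_tokenizers.{}", inner),
            // roughly what a `#[display(fmt = "Sequence([{}])", ...)]`-style attribute asks for
            Wrapper::Sequence(items) => write!(f, "Sequence([{}])", items.join(", ")),
        }
    }
}

fn main() {
    let w = Wrapper::Sequence(vec!["NFD".into(), "Lowercase".into()]);
    assert_eq!(w.to_string(), "Sequence([NFD, Lowercase])");
}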
#[derive(Debug, Clone, PartialEq, Serialize, Eq, Deserialize, Copy, Display)] diff --git a/tokenizers/src/pre_tokenizers/mod.rs b/tokenizers/src/pre_tokenizers/mod.rs index 4dde24125..08166b355 100644 --- a/tokenizers/src/pre_tokenizers/mod.rs +++ b/tokenizers/src/pre_tokenizers/mod.rs @@ -22,7 +22,7 @@ use crate::pre_tokenizers::split::Split; use crate::pre_tokenizers::unicode_scripts::UnicodeScripts; use crate::pre_tokenizers::whitespace::{Whitespace, WhitespaceSplit}; use crate::{PreTokenizedString, PreTokenizer}; -use derive_more::Display; +use display_derive::Display; #[derive(Deserialize, Serialize, Clone, Debug, PartialEq, Display)] #[display(fmt="pre_tokenizers.{}")] diff --git a/tokenizers/src/pre_tokenizers/sequence.rs b/tokenizers/src/pre_tokenizers/sequence.rs index 80190517e..94c30dcd7 100644 --- a/tokenizers/src/pre_tokenizers/sequence.rs +++ b/tokenizers/src/pre_tokenizers/sequence.rs @@ -1,7 +1,7 @@ use crate::pre_tokenizers::PreTokenizerWrapper; use crate::tokenizer::{PreTokenizedString, PreTokenizer, Result}; use crate::utils::macro_rules_attribute; -use derive_more::Display; +use display_derive::Display; use serde::{Deserialize, Serialize}; #[macro_rules_attribute(impl_serde_type!)] diff --git a/tokenizers/src/processors/bert.rs b/tokenizers/src/processors/bert.rs index bf0dbad07..5a4ee43aa 100644 --- a/tokenizers/src/processors/bert.rs +++ b/tokenizers/src/processors/bert.rs @@ -1,5 +1,5 @@ use crate::tokenizer::{Encoding, PostProcessor, Result}; -use derive_more::Display; +use display_derive::Display; use serde::{Deserialize, Serialize}; use std::collections::HashMap; use std::iter::FromIterator; diff --git a/tokenizers/src/processors/mod.rs b/tokenizers/src/processors/mod.rs index 0e3e92fc0..7e7e50c10 100644 --- a/tokenizers/src/processors/mod.rs +++ b/tokenizers/src/processors/mod.rs @@ -12,7 +12,7 @@ use crate::processors::roberta::RobertaProcessing; use crate::processors::sequence::Sequence; use crate::processors::template::TemplateProcessing; use crate::{Encoding, PostProcessor, Result}; -use derive_more::Display; +use display_derive::Display; use serde::{Deserialize, Serialize}; #[derive(Serialize, Deserialize, PartialEq, Debug, Clone, Eq, Display)] diff --git a/tokenizers/src/processors/roberta.rs b/tokenizers/src/processors/roberta.rs index 08857adaf..b0d40c295 100644 --- a/tokenizers/src/processors/roberta.rs +++ b/tokenizers/src/processors/roberta.rs @@ -1,6 +1,6 @@ use crate::processors::byte_level::process_offsets; use crate::tokenizer::{Encoding, PostProcessor, Result}; -use derive_more::Display; +use display_derive::Display; use serde::{Deserialize, Serialize}; use std::collections::HashMap; use std::iter::FromIterator; diff --git a/tokenizers/src/processors/sequence.rs b/tokenizers/src/processors/sequence.rs index aa829a383..d68bfc513 100644 --- a/tokenizers/src/processors/sequence.rs +++ b/tokenizers/src/processors/sequence.rs @@ -1,7 +1,7 @@ use crate::processors::PostProcessorWrapper; use crate::tokenizer::{Encoding, PostProcessor, Result}; use crate::utils::macro_rules_attribute; -use derive_more::Display; +use display_derive::Display; use serde::{Deserialize, Serialize}; #[macro_rules_attribute(impl_serde_type!)] #[derive(Clone, Debug, PartialEq, Eq, Display)] diff --git a/tokenizers/src/processors/template.rs b/tokenizers/src/processors/template.rs index b28d0b693..0cbe04ab8 100644 --- a/tokenizers/src/processors/template.rs +++ b/tokenizers/src/processors/template.rs @@ -56,7 +56,7 @@ //! [`TemplateProcessing`]: struct.TemplateProcessing.html //! 
use crate::{Encoding, PostProcessor, Result}; -use derive_more::Display; +use display_derive::Display; use itertools::Itertools; use serde::{Deserialize, Serialize}; use std::collections::{HashMap, HashSet}; diff --git a/tokenizers/src/tokenizer/normalizer.rs b/tokenizers/src/tokenizer/normalizer.rs index 46fa7b2ee..efd1b728a 100644 --- a/tokenizers/src/tokenizer/normalizer.rs +++ b/tokenizers/src/tokenizer/normalizer.rs @@ -1,6 +1,6 @@ use crate::pattern::Pattern; use crate::{Offsets, Result}; -use derive_more::Display; +use display_derive::Display; use serde::{Deserialize, Serialize}; use std::ops::{Bound, RangeBounds}; use unicode_normalization_alignments::UnicodeNormalization; diff --git a/tokenizers/src/utils/padding.rs b/tokenizers/src/utils/padding.rs index a42762a64..60b30786a 100644 --- a/tokenizers/src/utils/padding.rs +++ b/tokenizers/src/utils/padding.rs @@ -1,6 +1,6 @@ use crate::parallelism::*; use crate::tokenizer::{Encoding, Result}; -use derive_more::Display; +use display_derive::Display; use serde::{Deserialize, Serialize}; /// The various possible padding directions. From aefdc918e88d8eceff4c17019a8ec96d9590a3c4 Mon Sep 17 00:00:00 2001 From: Arthur Zucker Date: Tue, 11 Jun 2024 15:20:15 +0200 Subject: [PATCH 56/94] stash --- tokenizers/display_derive/src/lib.rs | 302 ++++++++++------ tokenizers/display_derive/src/vendored.rs | 406 ++++++++++++++++++++++ tokenizers/vendored.rs | 170 +++++++++ 3 files changed, 769 insertions(+), 109 deletions(-) create mode 100644 tokenizers/display_derive/src/vendored.rs create mode 100644 tokenizers/vendored.rs diff --git a/tokenizers/display_derive/src/lib.rs b/tokenizers/display_derive/src/lib.rs index 227fae9aa..3c10771f2 100644 --- a/tokenizers/display_derive/src/lib.rs +++ b/tokenizers/display_derive/src/lib.rs @@ -1,124 +1,208 @@ extern crate proc_macro; use proc_macro::TokenStream; -use quote::quote; +use quote::{format_ident,quote}; use syn::{parse_macro_input, Data, DeriveInput, Fields, Lit, Meta, MetaList, NestedMeta}; +mod vendored; +use vendored::FmtAttribute; #[proc_macro_derive(Display)] -pub fn display_derive(input: TokenStream) -> TokenStream { +pub fn display_derive(input: TokenStream) -> syn::Result { // Parse the input tokens into a syntax tree let input = parse_macro_input!(input as DeriveInput); - // Get the name of the struct - let name = &input.ident; - - // Generate code to match the struct's fields - let expanded = match input.data { - Data::Struct(data) => { - match data.fields { - Fields::Named(fields) => { - // If the struct has named fields - let field_names = fields.named.iter().map(|f| &f.ident); - let field_types = fields.named.iter().map(|f| &f.ty); - quote! { - impl std::fmt::Display for #name { - fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - write!(f, "{}(", stringify!(#name))?; - let mut first = true; - #( - if !first { - write!(f, ", ")?; - } - first = false; - - let field_value = &self.#field_names; - write!(f, "{}=", stringify!(#field_names))?; - if std::any::TypeId::of::<#field_types>() == std::any::TypeId::of::(){ - write!(f, "\"{}\"", field_value)?; - } else { - let s = format!("{}", field_value); - let mut chars = s.chars(); - let mut prefix = (&mut chars).take(100 - 1).collect::(); - if chars.next().is_some() { - prefix.push('…'); - } - write!(f, "{}", prefix)?; - } - )* - write!(f, ")") - } - } - } - }, - Fields::Unit => { - // If the struct has no fields - quote! 
{ - impl std::fmt::Display for #name { - fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - write!(f, "{}()", stringify!(#name)) - } - } - } - }, - Fields::Unnamed(_) => { - quote! { - compile_error!("Unnamed fields for struct are not supported."); - } - }, + let attr_name = "display"; + let attrs = FmtAttributes::parse_attrs(&input.attrs, &attr_name)? + .map(Spanning::into_inner) + .unwrap_or_default(); + let trait_ident = format_ident!("display"); + let ident = &input.ident; + + let ctx = (&attrs, ident, &trait_ident, &attr_name); + let body = match &input.data { + syn::Data::Struct(s) => expand_struct(s, ctx), + syn::Data::Enum(e) => expand_enum(e, ctx), + syn::Data::Union(u) => return Err(syn::Error::new(u, format!("Union is not supported"))), + }?; + + Ok(quote! { + impl std::fmt::Display for #ident{ + fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result { + #body } + } + }) +} + +/// Type alias for an expansion context: +/// - [`ContainerAttributes`]. +/// - Struct/enum/union [`syn::Ident`]. +/// - Derived trait [`syn::Ident`]. +/// - Attribute name [`syn::Ident`]. +/// +/// [`syn::Ident`]: struct@syn::Ident +type ExpansionCtx<'a> = ( + &'a FmtAttribute, + &'a syn::Ident, + &'a syn::Ident, + &'a syn::Ident, +); + +/// Expands a [`fmt::Display`]-like derive macro for the provided struct. +fn expand_struct( + s: &syn::DataStruct, + (attrs, ident, trait_ident, _): ExpansionCtx<'_>, +) -> syn::Result<(Vec, TokenStream)> { + let s = Expansion { + attrs, + fields: &s.fields, + trait_ident, + ident, + }; + let body = s.generate_body()?; + + let vars = s.fields.iter().enumerate().map(|(i, f)| { + let var = f.ident.clone().unwrap_or_else(|| format_ident!("_{i}")); + let member = f + .ident + .clone() + .map_or_else(|| syn::Member::Unnamed(i.into()), syn::Member::Named); + quote! { + let #var = &self.#member; + } + }); + + let body = quote! { + #( #vars )* + #body + }; + + Ok(body) +} + +/// Expands a [`fmt`]-like derive macro for the provided enum. +fn expand_enum( + e: &syn::DataEnum, + (attrs, _, trait_ident, attr_name): ExpansionCtx<'_>, +) -> syn::Result<(Vec, TokenStream)> { + if attrs.fmt.is_some() { + todo!("https://github.com/JelteF/derive_more/issues/142"); + } + + let match_arms = e.variants.iter().try_fold( + TokenStream::new, |variant| { + let attrs = ContainerAttributes::parse_attrs(&variant.attrs, attr_name)? + .map(Spanning::into_inner) + .unwrap_or_default(); + let ident = &variant.ident; + + let v = Expansion { + attrs: &attrs, + fields: &variant.fields, + trait_ident, + ident, + }; + let arm_body = v.generate_body()?; + + let fields_idents = + variant.fields.iter().enumerate().map(|(i, f)| { + f.ident.clone().unwrap_or_else(|| format_ident!("_{i}")) + }); + let matcher = match variant.fields { + syn::Fields::Named(_) => { + quote! { Self::#ident { #( #fields_idents ),* } } + } + syn::Fields::Unnamed(_) => { + quote! { Self::#ident ( #( #fields_idents ),* ) } + } + syn::Fields::Unit => quote! { Self::#ident }, + }; + + arms.extend([quote! 
{ #matcher => { #arm_body }, }]); + + Ok::<_, syn::Error>(arms) }, - Data::Enum(ref data_enum) => { - let variants = &data_enum.variants; - let display_impls = variants.iter().map(|variant| { - let ident = &variant.ident; - if let Some(attr) = variant.attrs.iter().find(|attr| attr.path.is_ident("display")) { - if let Ok(Meta::List(meta_list)) = attr.parse_meta() { - let format_args = meta_list.nested.iter().map(|nested_meta| { - match nested_meta { - NestedMeta::Meta(Meta::NameValue(nv)) if nv.path.is_ident("fmt") => { - if let syn::Lit::Str(s) = &nv.lit { - quote! { #s } - } else { - quote! { compile_error!("Invalid format argument"); } - } - } - _ => quote! { compile_error!("Invalid format argument"); }, - } - }).collect::>(); // Collect into a Vec - quote! { - impl std::fmt::Display for #name { - fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result { - match self { - Self::#ident(#(#format_args),*) => write!(f, "{}", format_args!(#(#format_args),*)), - _ => unreachable!(), - } - } - } - } - } else { - quote! { - compile_error!("Invalid display attribute format"); - } - } + )?; + + let body = match_arms + .is_empty() + .then(|| quote! { match *self {} }) + .unwrap_or_else(|| quote! { match self { #match_arms } }); + + Ok(body) +} + + +/// Helper struct to generate [`Display::fmt()`] implementation body and trait +/// bounds for a struct or an enum variant. +/// +/// [`Display::fmt()`]: fmt::Display::fmt() +#[derive(Debug)] +struct Expansion<'a> { + /// Derive macro [`ContainerAttributes`]. + attrs: &'a ContainerAttributes, + + /// Struct or enum [`syn::Ident`]. + /// + /// [`syn::Ident`]: struct@syn::Ident + ident: &'a syn::Ident, + + /// Struct or enum [`syn::Fields`]. + fields: &'a syn::Fields, + + /// [`fmt`] trait [`syn::Ident`]. + /// + /// [`syn::Ident`]: struct@syn::Ident + trait_ident: &'a syn::Ident, +} + +impl<'a> Expansion<'a> { + /// Generates [`Display::fmt()`] implementation for a struct or an enum variant. + /// + /// # Errors + /// + /// In case [`FmtAttribute`] is [`None`] and [`syn::Fields`] length is + /// greater than 1. + /// + /// [`Display::fmt()`]: fmt::Display::fmt() + /// [`FmtAttribute`]: super::FmtAttribute + fn generate_body(&self) -> syn::Result { + match &self.attrs.fmt { + Some(fmt) => { + Ok(if let Some((expr, trait_ident)) = fmt.transparent_call() { + quote! { core::fmt::#trait_ident::fmt(&(#expr), __derive_more_f) } } else { - quote! { - impl std::fmt::Display for #name { - fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result { - write!(f, "{}", stringify!(#ident)) - } - } - } - } - }); - quote! { - #(#display_impls)* + quote! { core::write!(__derive_more_f, #fmt) } + }) } - }, - Data::Union(_) => { - quote! { - compile_error!("Unions are not supported for Display derive"); - + None if self.fields.is_empty() => { + let ident_str = self.ident.to_string(); + + Ok(quote! { + core::write!(__derive_more_f, #ident_str) + }) } + None if self.fields.len() == 1 => { + let field = self + .fields + .iter() + .next() + .unwrap_or_else(|| unreachable!("count() == 1")); + let ident = field.ident.clone().unwrap_or_else(|| format_ident!("_0")); + let trait_ident = self.trait_ident; + + Ok(quote! { + core::fmt::#trait_ident::fmt(#ident, __derive_more_f) + }) + } + _ => Err(syn::Error::new( + self.fields.span(), + format!( + "TODO ARTHUR! 
struct or enum variant with more than 1 field must have \ + `#[{}(\"...\", ...)]` attribute", + trait_name_to_attribute_name(self.trait_ident), + ), + )), } - }; - TokenStream::from(expanded) + } } + diff --git a/tokenizers/display_derive/src/vendored.rs b/tokenizers/display_derive/src/vendored.rs new file mode 100644 index 000000000..c211937c2 --- /dev/null +++ b/tokenizers/display_derive/src/vendored.rs @@ -0,0 +1,406 @@ + +/// Representation of a [`fmt`]-like attribute. +/// +/// ```rust,ignore +/// #[("", )] +/// ``` +/// +/// [`fmt`]: std::fmt +#[derive(Debug)] +struct FmtAttribute { + /// Interpolation [`syn::LitStr`]. + /// + /// [`syn::LitStr`]: struct@syn::LitStr + lit: syn::LitStr, + + /// Optional [`token::Comma`]. + /// + /// [`token::Comma`]: struct@token::Comma + comma: Option, + + /// Interpolation arguments. + args: Punctuated, +} + +impl Parse for FmtAttribute { + fn parse(input: ParseStream<'_>) -> syn::Result { + Self::check_legacy_fmt(input)?; + + Ok(Self { + lit: input.parse()?, + comma: input + .peek(token::Comma) + .then(|| input.parse()) + .transpose()?, + args: input.parse_terminated(FmtArgument::parse, token::Comma)?, + }) + } +} + +impl attr::ParseMultiple for FmtAttribute {} + +impl ToTokens for FmtAttribute { + fn to_tokens(&self, tokens: &mut TokenStream) { + self.lit.to_tokens(tokens); + self.comma.to_tokens(tokens); + self.args.to_tokens(tokens); + } +} + +impl FmtAttribute { + /// Checks whether this [`FmtAttribute`] can be replaced with a transparent delegation (calling + /// a formatting trait directly instead of interpolation syntax). + /// + /// If such transparent call is possible, the returns an [`Ident`] of the delegated trait and + /// the [`Expr`] to pass into the call, otherwise [`None`]. + /// + /// [`Ident`]: struct@syn::Ident + fn transparent_call(&self) -> Option<(Expr, syn::Ident)> { + // `FmtAttribute` is transparent when: + + // (1) There is exactly one formatting parameter. + let lit = self.lit.value(); + let param = + parsing::format(&lit).and_then(|(more, p)| more.is_empty().then_some(p))?; + + // (2) And the formatting parameter doesn't contain any modifiers. + if param + .spec + .map(|s| { + s.align.is_some() + || s.sign.is_some() + || s.alternate.is_some() + || s.zero_padding.is_some() + || s.width.is_some() + || s.precision.is_some() + || !s.ty.is_trivial() + }) + .unwrap_or_default() + { + return None; + } + + let expr = match param.arg { + // (3) And either exactly one positional argument is specified. + Some(parsing::Argument::Integer(_)) | None => (self.args.len() == 1) + .then(|| self.args.first()) + .flatten() + .map(|a| a.expr.clone()), + + // (4) Or the formatting parameter's name refers to some outer binding. + Some(parsing::Argument::Identifier(name)) if self.args.is_empty() => { + Some(format_ident!("{name}").into()) + } + + // (5) Or exactly one named argument is specified for the formatting parameter's name. + Some(parsing::Argument::Identifier(name)) => (self.args.len() == 1) + .then(|| self.args.first()) + .flatten() + .filter(|a| a.alias.as_ref().map(|a| a.0 == name).unwrap_or_default()) + .map(|a| a.expr.clone()), + }?; + + let trait_name = param + .spec + .map(|s| s.ty) + .unwrap_or(parsing::Type::Display) + .trait_name(); + + Some((expr, format_ident!("{trait_name}"))) + } + + /// Returns an [`Iterator`] over bounded [`syn::Type`]s (and correspondent trait names) by this + /// [`FmtAttribute`]. 
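To make the transparent-call rule above concrete: when the attribute has exactly one formatting parameter and no modifiers, the derive can delegate straight to the formatting trait instead of going through `write!`. A hedged sketch of the two expansions, with an invented one-field struct:

struct Wrapped {
    inner: u32,
}

// `#[display("{inner}")]` can be lowered to a direct trait call...
impl std::fmt::Display for Wrapped {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // transparent call: one parameter, no width/precision/sign/fill modifiers
        std::fmt::Display::fmt(&self.inner, f)
        // ...while `#[display("{inner:>8}")]` must keep the interpolation path:
        // write!(f, "{:>8}", self.inner)
    }
}

fn main() {
    assert_eq!(Wrapped { inner: 42 }.to_string(), "42");
}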
+ fn bounded_types<'a>( + &'a self, + fields: &'a syn::Fields, + ) -> impl Iterator { + let placeholders = Placeholder::parse_fmt_string(&self.lit.value()); + + // We ignore unknown fields, as compiler will produce better error messages. + placeholders.into_iter().filter_map(move |placeholder| { + let name = match placeholder.arg { + Parameter::Named(name) => self + .args + .iter() + .find_map(|a| (a.alias()? == &name).then_some(&a.expr)) + .map_or(Some(name), |expr| expr.ident().map(ToString::to_string))?, + Parameter::Positional(i) => self + .args + .iter() + .nth(i) + .and_then(|a| a.expr.ident().filter(|_| a.alias.is_none()))? + .to_string(), + }; + + let unnamed = name.strip_prefix('_').and_then(|s| s.parse().ok()); + let ty = match (&fields, unnamed) { + (syn::Fields::Unnamed(f), Some(i)) => { + f.unnamed.iter().nth(i).map(|f| &f.ty) + } + (syn::Fields::Named(f), None) => f.named.iter().find_map(|f| { + f.ident.as_ref().filter(|s| **s == name).map(|_| &f.ty) + }), + _ => None, + }?; + + Some((ty, placeholder.trait_name)) + }) + } + + /// Errors in case legacy syntax is encountered: `fmt = "...", (arg),*`. + fn check_legacy_fmt(input: ParseStream<'_>) -> syn::Result<()> { + let fork = input.fork(); + + let path = fork + .parse::() + .and_then(|path| fork.parse::().map(|_| path)); + match path { + Ok(path) if path.is_ident("fmt") => (|| { + let args = fork + .parse_terminated( + >::parse, + token::Comma, + ) + .ok()? + .into_iter() + .enumerate() + .filter_map(|(i, arg)| match arg { + Either::Left(syn::Lit::Str(str)) => Some(if i == 0 { + format!("\"{}\"", str.value()) + } else { + str.value() + }), + Either::Right(ident) => Some(ident.to_string()), + _ => None, + }) + .collect::>(); + (!args.is_empty()).then_some(args) + })() + .map_or(Ok(()), |fmt| { + Err(syn::Error::new( + input.span(), + format!( + "legacy syntax, remove `fmt =` and use `{}` instead", + fmt.join(", "), + ), + )) + }), + Ok(_) | Err(_) => Ok(()), + } + } +} + +/// Representation of a [named parameter][1] (`identifier '=' expression`) in +/// in a [`FmtAttribute`]. +/// +/// [1]: https://doc.rust-lang.org/stable/std/fmt/index.html#named-parameters +#[derive(Debug)] +struct FmtArgument { + /// `identifier =` [`Ident`]. + /// + /// [`Ident`]: struct@syn::Ident + alias: Option<(syn::Ident, token::Eq)>, + + /// `expression` [`Expr`]. + expr: Expr, +} + +impl FmtArgument { + /// Returns an `identifier` of the [named parameter][1]. + /// + /// [1]: https://doc.rust-lang.org/stable/std/fmt/index.html#named-parameters + fn alias(&self) -> Option<&syn::Ident> { + self.alias.as_ref().map(|(ident, _)| ident) + } +} + +impl Parse for FmtArgument { + fn parse(input: ParseStream) -> syn::Result { + Ok(Self { + alias: (input.peek(syn::Ident) && input.peek2(token::Eq)) + .then(|| Ok::<_, syn::Error>((input.parse()?, input.parse()?))) + .transpose()?, + expr: input.parse()?, + }) + } +} + +impl ToTokens for FmtArgument { + fn to_tokens(&self, tokens: &mut TokenStream) { + if let Some((ident, eq)) = &self.alias { + ident.to_tokens(tokens); + eq.to_tokens(tokens); + } + self.expr.to_tokens(tokens); + } +} + +/// Representation of a [parameter][1] used in a [`Placeholder`]. +/// +/// [1]: https://doc.rust-lang.org/stable/std/fmt/index.html#formatting-parameters +#[derive(Debug, Eq, PartialEq)] +enum Parameter { + /// [Positional parameter][1]. + /// + /// [1]: https://doc.rust-lang.org/stable/std/fmt/index.html#positional-parameters + Positional(usize), + + /// [Named parameter][1]. 
+ /// + /// [1]: https://doc.rust-lang.org/stable/std/fmt/index.html#named-parameters + Named(String), +} + +impl<'a> From> for Parameter { + fn from(arg: parsing::Argument<'a>) -> Self { + match arg { + parsing::Argument::Integer(i) => Self::Positional(i), + parsing::Argument::Identifier(i) => Self::Named(i.to_owned()), + } + } +} + +/// Representation of a formatting placeholder. +#[derive(Debug, Eq, PartialEq)] +struct Placeholder { + /// Formatting argument (either named or positional) to be used by this placeholder. + arg: Parameter, + + /// [Width parameter][1], if present. + /// + /// [1]: https://doc.rust-lang.org/stable/std/fmt/index.html#width + width: Option, + + /// [Precision parameter][1], if present. + /// + /// [1]: https://doc.rust-lang.org/stable/std/fmt/index.html#precision + precision: Option, + + /// Name of [`std::fmt`] trait to be used for rendering this placeholder. + trait_name: &'static str, +} + +impl Placeholder { + /// Parses [`Placeholder`]s from the provided formatting string. + fn parse_fmt_string(s: &str) -> Vec { + let mut n = 0; + parsing::format_string(s) + .into_iter() + .flat_map(|f| f.formats) + .map(|format| { + let (maybe_arg, ty) = ( + format.arg, + format.spec.map(|s| s.ty).unwrap_or(parsing::Type::Display), + ); + let position = maybe_arg.map(Into::into).unwrap_or_else(|| { + // Assign "the next argument". + // https://doc.rust-lang.org/stable/std/fmt/index.html#positional-parameters + n += 1; + Parameter::Positional(n - 1) + }); + + Self { + arg: position, + width: format.spec.and_then(|s| match s.width { + Some(parsing::Count::Parameter(arg)) => Some(arg.into()), + _ => None, + }), + precision: format.spec.and_then(|s| match s.precision { + Some(parsing::Precision::Count(parsing::Count::Parameter( + arg, + ))) => Some(arg.into()), + _ => None, + }), + trait_name: ty.trait_name(), + } + }) + .collect() + } +} + +/// Representation of a [`fmt::Display`]-like derive macro attributes placed on a container (struct +/// or enum variant). +/// +/// ```rust,ignore +/// #[("", )] +/// #[(bound())] +/// ``` +/// +/// `#[(...)]` can be specified only once, while multiple `#[(bound(...))]` +/// are allowed. +/// +/// [`fmt::Display`]: std::fmt::Display +#[derive(Debug, Default)] +struct ContainerAttributes { + /// Interpolation [`FmtAttribute`]. + fmt: Option, + + /// Addition trait bounds. + bounds: BoundsAttribute, +} + +impl Parse for ContainerAttributes { + fn parse(input: ParseStream<'_>) -> syn::Result { + // We do check `FmtAttribute::check_legacy_fmt` eagerly here, because `Either` will swallow + // any error of the `Either::Left` if the `Either::Right` succeeds. 
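// A hedged illustration of what "legacy" means here, with invented field names:
// the first spelling is the old `derive_more` form that `check_legacy_fmt` turns
// into a "remove `fmt =`" error, the second is the spelling this parser accepts.
//
//     #[display(fmt = "{} items", count)]   // legacy, rejected with a hint
//     #[display("{} items", count)]         // accepted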
+ FmtAttribute::check_legacy_fmt(input)?; + >::parse(input).map(|v| match v { + Either::Left(fmt) => Self { + bounds: BoundsAttribute::default(), + fmt: Some(fmt), + }, + Either::Right(bounds) => Self { bounds, fmt: None }, + }) + } +} + +impl attr::ParseMultiple for ContainerAttributes { + fn merge_attrs( + prev: Spanning, + new: Spanning, + name: &syn::Ident, + ) -> syn::Result> { + let Spanning { + span: prev_span, + item: mut prev, + } = prev; + let Spanning { + span: new_span, + item: new, + } = new; + + if new.fmt.and_then(|n| prev.fmt.replace(n)).is_some() { + return Err(syn::Error::new( + new_span, + format!("multiple `#[{name}(\"...\", ...)]` attributes aren't allowed"), + )); + } + prev.bounds.0.extend(new.bounds.0); + + Ok(Spanning::new( + prev, + prev_span.join(new_span).unwrap_or(prev_span), + )) + } +} + +/// Matches the provided `trait_name` to appropriate [`FmtAttribute`]'s argument name. +fn trait_name_to_attribute_name(trait_name: T) -> &'static str +where + T: for<'a> PartialEq<&'a str>, +{ + match () { + _ if trait_name == "Binary" => "binary", + _ if trait_name == "Debug" => "debug", + _ if trait_name == "Display" => "display", + _ if trait_name == "LowerExp" => "lower_exp", + _ if trait_name == "LowerHex" => "lower_hex", + _ if trait_name == "Octal" => "octal", + _ if trait_name == "Pointer" => "pointer", + _ if trait_name == "UpperExp" => "upper_exp", + _ if trait_name == "UpperHex" => "upper_hex", + _ => unimplemented!(), + } +} + diff --git a/tokenizers/vendored.rs b/tokenizers/vendored.rs new file mode 100644 index 000000000..8f181e3f6 --- /dev/null +++ b/tokenizers/vendored.rs @@ -0,0 +1,170 @@ +use proc_macro2::TokenStream; +use quote::{format_ident, ToTokens}; +use syn::{ + parse::{Parse, ParseStream}, + punctuated::Punctuated, + spanned::Spanned as _, + token, +}; + +use crate::{ + parsing::Expr, + utils::{attr, Either, Spanning}, +}; + +/// Representation of a [`fmt`]-like attribute. +/// +/// ```rust,ignore +/// #[("", )] +/// ``` +/// +/// [`fmt`]: std::fmt +#[derive(Debug)] +pub struct FmtAttribute { + /// Interpolation [`syn::LitStr`]. + /// + /// [`syn::LitStr`]: struct@syn::LitStr + lit: syn::LitStr, + + /// Optional [`token::Comma`]. + /// + /// [`token::Comma`]: struct@token::Comma + comma: Option, + + /// Interpolation arguments. + args: Punctuated, +} + +impl Parse for FmtAttribute { + fn parse(input: ParseStream<'_>) -> syn::Result { + Self::check_legacy_fmt(input)?; + + Ok(Self { + lit: input.parse()?, + comma: input + .peek(token::Comma) + .then(|| input.parse()) + .transpose()?, + args: input.parse_terminated(FmtArgument::parse, token::Comma)?, + }) + } +} + +impl attr::ParseMultiple for FmtAttribute {} + +impl ToTokens for FmtAttribute { + fn to_tokens(&self, tokens: &mut TokenStream) { + self.lit.to_tokens(tokens); + self.comma.to_tokens(tokens); + self.args.to_tokens(tokens); + } +} + +impl FmtAttribute { + /// Checks whether this [`FmtAttribute`] can be replaced with a transparent delegation (calling + /// a formatting trait directly instead of interpolation syntax). + /// + /// If such transparent call is possible, the returns an [`Ident`] of the delegated trait and + /// the [`Expr`] to pass into the call, otherwise [`None`]. + /// + /// [`Ident`]: struct@syn::Ident + fn transparent_call(&self) -> Option<(Expr, syn::Ident)> { + // `FmtAttribute` is transparent when: + + // (1) There is exactly one formatting parameter. 
+ let lit = self.lit.value(); + let param = + parsing::format(&lit).and_then(|(more, p)| more.is_empty().then_some(p))?; + + // (2) And the formatting parameter doesn't contain any modifiers. + if param + .spec + .map(|s| { + s.align.is_some() + || s.sign.is_some() + || s.alternate.is_some() + || s.zero_padding.is_some() + || s.width.is_some() + || s.precision.is_some() + || !s.ty.is_trivial() + }) + .unwrap_or_default() + { + return None; + } + + let expr = match param.arg { + // (3) And either exactly one positional argument is specified. + Some(parsing::Argument::Integer(_)) | None => (self.args.len() == 1) + .then(|| self.args.first()) + .flatten() + .map(|a| a.expr.clone()), + + // (4) Or the formatting parameter's name refers to some outer binding. + Some(parsing::Argument::Identifier(name)) if self.args.is_empty() => { + Some(format_ident!("{name}").into()) + } + + // (5) Or exactly one named argument is specified for the formatting parameter's name. + Some(parsing::Argument::Identifier(name)) => (self.args.len() == 1) + .then(|| self.args.first()) + .flatten() + .filter(|a| a.alias.as_ref().map(|a| a.0 == name).unwrap_or_default()) + .map(|a| a.expr.clone()), + }?; + + let trait_name = param + .spec + .map(|s| s.ty) + .unwrap_or(parsing::Type::Display) + .trait_name(); + + Some((expr, format_ident!("{trait_name}"))) + } +} + +/// Representation of a [named parameter][1] (`identifier '=' expression`) in +/// in a [`FmtAttribute`]. +/// +/// [1]: https://doc.rust-lang.org/stable/std/fmt/index.html#named-parameters +#[derive(Debug)] +struct FmtArgument { + /// `identifier =` [`Ident`]. + /// + /// [`Ident`]: struct@syn::Ident + alias: Option<(syn::Ident, token::Eq)>, + + /// `expression` [`Expr`]. + expr: Expr, +} + +impl FmtArgument { + /// Returns an `identifier` of the [named parameter][1]. + /// + /// [1]: https://doc.rust-lang.org/stable/std/fmt/index.html#named-parameters + fn alias(&self) -> Option<&syn::Ident> { + self.alias.as_ref().map(|(ident, _)| ident) + } +} + +impl Parse for FmtArgument { + fn parse(input: ParseStream) -> syn::Result { + Ok(Self { + alias: (input.peek(syn::Ident) && input.peek2(token::Eq)) + .then(|| Ok::<_, syn::Error>((input.parse()?, input.parse()?))) + .transpose()?, + expr: input.parse()?, + }) + } +} + +impl ToTokens for FmtArgument { + fn to_tokens(&self, tokens: &mut TokenStream) { + if let Some((ident, eq)) = &self.alias { + ident.to_tokens(tokens); + eq.to_tokens(tokens); + } + self.expr.to_tokens(tokens); + } +} + From f67af9cb0ede88d1c8b0152366fa59f3b47ecd56 Mon Sep 17 00:00:00 2001 From: Arthur Zucker Date: Wed, 12 Jun 2024 09:58:07 +0200 Subject: [PATCH 57/94] updates --- tokenizers/display_derive/src/lib.rs | 8 +- tokenizers/display_derive/src/parsing.rs | 1311 +++++++++++++++++++++ tokenizers/display_derive/src/vendored.rs | 261 +--- tokenizers/vendored.rs | 170 --- 4 files changed, 1326 insertions(+), 424 deletions(-) create mode 100644 tokenizers/display_derive/src/parsing.rs delete mode 100644 tokenizers/vendored.rs diff --git a/tokenizers/display_derive/src/lib.rs b/tokenizers/display_derive/src/lib.rs index 3c10771f2..633640659 100644 --- a/tokenizers/display_derive/src/lib.rs +++ b/tokenizers/display_derive/src/lib.rs @@ -34,7 +34,7 @@ pub fn display_derive(input: TokenStream) -> syn::Result { } /// Type alias for an expansion context: -/// - [`ContainerAttributes`]. +/// - [`FmtAttribute`]. /// - Struct/enum/union [`syn::Ident`]. /// - Derived trait [`syn::Ident`]. /// - Attribute name [`syn::Ident`]. 
@@ -90,7 +90,7 @@ fn expand_enum( let match_arms = e.variants.iter().try_fold( TokenStream::new, |variant| { - let attrs = ContainerAttributes::parse_attrs(&variant.attrs, attr_name)? + let attrs = FmtAttribute::parse_attrs(&variant.attrs, attr_name)? .map(Spanning::into_inner) .unwrap_or_default(); let ident = &variant.ident; @@ -138,8 +138,8 @@ fn expand_enum( /// [`Display::fmt()`]: fmt::Display::fmt() #[derive(Debug)] struct Expansion<'a> { - /// Derive macro [`ContainerAttributes`]. - attrs: &'a ContainerAttributes, + /// Derive macro [`FmtAttribute`]. + attrs: &'a FmtAttribute, /// Struct or enum [`syn::Ident`]. /// diff --git a/tokenizers/display_derive/src/parsing.rs b/tokenizers/display_derive/src/parsing.rs new file mode 100644 index 000000000..8a5d75277 --- /dev/null +++ b/tokenizers/display_derive/src/parsing.rs @@ -0,0 +1,1311 @@ +//! Parsing of [Rust `fmt` syntax][0]. +//! +//! [0]: std::fmt#syntax + +use std::iter; + +use unicode_xid::UnicodeXID as XID; + +/// Output of the [`format_string`] parser. +#[derive(Clone, Debug, Eq, PartialEq)] +pub(crate) struct FormatString<'a> { + pub(crate) formats: Vec>, +} + +/// Output of the [`format`] parser. +/// +/// [`format`]: fn@format +#[derive(Debug, Clone, Copy, Eq, PartialEq)] +pub(crate) struct Format<'a> { + pub(crate) arg: Option>, + pub(crate) spec: Option>, +} + +/// Output of the [`format_spec`] parser. +#[derive(Clone, Copy, Debug, Eq, PartialEq)] +pub(crate) struct FormatSpec<'a> { + /// Parsed `[[fill]`[`align`]`]`. + pub(crate) align: Option<(Option, Align)>, + + /// Parsed `[`[`sign`]`]`. + pub(crate) sign: Option, + + /// Parsed `['#']` (alternation). + pub(crate) alternate: Option, + + /// Parsed `['0']` (padding with zeros). + pub(crate) zero_padding: Option, + + /// Parsed `[width]`. + pub(crate) width: Option>, + + /// Parsed `['.' `[`precision`]`]`. + pub(crate) precision: Option>, + + /// Parsed [`type`]. + /// + /// [`type`]: type_ + pub(crate) ty: Type, +} + +/// Output of the [`argument`] parser. +#[derive(Clone, Copy, Debug, Eq, PartialEq)] +pub(crate) enum Argument<'a> { + Integer(usize), + Identifier(&'a str), +} + +/// Output of the [`align`] parser. +#[derive(Clone, Copy, Debug, Eq, PartialEq)] +pub(crate) enum Align { + Left, + Center, + Right, +} + +/// Output of the [`sign`] parser. +#[derive(Clone, Copy, Debug, Eq, PartialEq)] +pub(crate) enum Sign { + Plus, + Minus, +} + +/// Type for the [`FormatSpec::alternate`]. +#[derive(Clone, Copy, Debug, Eq, PartialEq)] +pub(crate) struct Alternate; + +/// Type for the [`FormatSpec::zero_padding`]. +#[derive(Clone, Copy, Debug, Eq, PartialEq)] +pub(crate) struct ZeroPadding; + +/// Output of the [`precision`] parser. +#[derive(Clone, Copy, Debug, Eq, PartialEq)] +pub(crate) enum Precision<'a> { + Count(Count<'a>), + Star, +} + +/// Output of the [`count`] parser. +#[derive(Clone, Copy, Debug, Eq, PartialEq)] +pub(crate) enum Count<'a> { + Integer(usize), + Parameter(Parameter<'a>), +} + +/// Output of the [`type_`] parser. See [formatting traits][0] for more info. +/// +/// [0]: std::fmt#formatting-traits +#[derive(Clone, Copy, Debug, Eq, PartialEq)] +pub(crate) enum Type { + Display, + Debug, + LowerDebug, + UpperDebug, + Octal, + LowerHex, + UpperHex, + Pointer, + Binary, + LowerExp, + UpperExp, +} + +impl Type { + /// Returns trait name of this [`Type`]. 
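As an end-to-end illustration of what the parsers below are expected to produce, a test-style sketch (not part of the patch, written in the same style as this module's `#[cfg(test)]` block) for the `{par:-^#.0$?}` example used in the doc comments:

#[test]
fn par_example() {
    assert_eq!(
        format_string("{par:-^#.0$?}"),
        Some(FormatString {
            formats: vec![Format {
                arg: Some(Argument::Identifier("par")),
                spec: Some(FormatSpec {
                    align: Some((Some('-'), Align::Center)),
                    sign: None,
                    alternate: Some(Alternate),
                    zero_padding: None,
                    width: None,
                    precision: Some(Precision::Count(Count::Parameter(Argument::Integer(0)))),
                    ty: Type::Debug,
                }),
            }],
        }),
    );
}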
+ pub(crate) fn trait_name(&self) -> &'static str { + match self { + Self::Display => "Display", + Self::Debug | Self::LowerDebug | Self::UpperDebug => "Debug", + Self::Octal => "Octal", + Self::LowerHex => "LowerHex", + Self::UpperHex => "UpperHex", + Self::Pointer => "Pointer", + Self::Binary => "Binary", + Self::LowerExp => "LowerExp", + Self::UpperExp => "UpperExp", + } + } + + /// Indicates whether this [`Type`] represents a trivial trait call without any modifications. + pub(crate) fn is_trivial(&self) -> bool { + match self { + Self::Display + | Self::Debug + | Self::Octal + | Self::LowerHex + | Self::UpperHex + | Self::Pointer + | Self::Binary + | Self::LowerExp + | Self::UpperExp => true, + Self::LowerDebug | Self::UpperDebug => false, + } + } +} + +/// Type alias for the `fill` in the [`FormatSpec::align`]. +type Fill = char; + +/// Type alias for the [`FormatSpec::width`]. +type Width<'a> = Count<'a>; + +/// Output of the [`maybe_format`] parser. +type MaybeFormat<'a> = Option>; + +/// Output of the [`identifier`] parser. +type Identifier<'a> = &'a str; + +/// Output of the [`parameter`] parser. +type Parameter<'a> = Argument<'a>; + +/// [`str`] left to parse. +/// +/// [`str`]: prim@str +type LeftToParse<'a> = &'a str; + +/// Parses a `format_string` as defined in the [grammar spec][0]. +/// +/// # Grammar +/// +/// [`format_string`]` := `[`text`]` [`[`maybe_format text`]`] *` +/// +/// # Example +/// +/// ```text +/// Hello +/// Hello, {}! +/// {:?} +/// Hello {people}! +/// {} {} +/// {:04} +/// {par:-^#.0$?} +/// ``` +/// +/// # Return value +/// +/// - [`Some`] in case of successful parse. +/// - [`None`] otherwise (not all characters are consumed by underlying +/// parsers). +/// +/// [0]: std::fmt#syntax +pub(crate) fn format_string(input: &str) -> Option> { + let (mut input, _) = optional_result(text)(input); + + let formats = iter::repeat(()) + .scan(&mut input, |input, _| { + let (curr, format) = + alt(&mut [&mut maybe_format, &mut map(text, |(i, _)| (i, None))])( + input, + )?; + **input = curr; + Some(format) + }) + .flatten() + .collect(); + + // Should consume all tokens for a successful parse. + input.is_empty().then_some(FormatString { formats }) +} + +/// Parses a `maybe_format` as defined in the [grammar spec][0]. +/// +/// # Grammar +/// +/// [`maybe_format`]` := '{' '{' | '}' '}' | `[`format`] +/// +/// # Example +/// +/// ```text +/// {{ +/// }} +/// {:04} +/// {:#?} +/// {par:-^#.0$?} +/// ``` +/// +/// [`format`]: fn@format +/// [0]: std::fmt#syntax +fn maybe_format(input: &str) -> Option<(LeftToParse<'_>, MaybeFormat<'_>)> { + alt(&mut [ + &mut map(str("{{"), |i| (i, None)), + &mut map(str("}}"), |i| (i, None)), + &mut map(format, |(i, format)| (i, Some(format))), + ])(input) +} + +/// Parses a `format` as defined in the [grammar spec][0]. +/// +/// # Grammar +/// +/// [`format`]` := '{' [`[`argument`]`] [':' `[`format_spec`]`] '}'` +/// +/// # Example +/// +/// ```text +/// {par} +/// {:#?} +/// {par:-^#.0$?} +/// ``` +/// +/// [`format`]: fn@format +/// [0]: std::fmt#syntax +pub(crate) fn format(input: &str) -> Option<(LeftToParse<'_>, Format<'_>)> { + let input = char('{')(input)?; + + let (input, arg) = optional_result(argument)(input); + + let (input, spec) = map_or_else( + char(':'), + |i| Some((i, None)), + map(format_spec, |(i, s)| (i, Some(s))), + )(input)?; + + let input = char('}')(input)?; + + Some((input, Format { arg, spec })) +} + +/// Parses an `argument` as defined in the [grammar spec][0]. 
+/// +/// # Grammar +/// +/// [`argument`]` := `[`integer`]` | `[`identifier`] +/// +/// # Example +/// +/// ```text +/// 0 +/// ident +/// Минск +/// ``` +/// +/// [0]: std::fmt#syntax +fn argument(input: &str) -> Option<(LeftToParse<'_>, Argument)> { + alt(&mut [ + &mut map(identifier, |(i, ident)| (i, Argument::Identifier(ident))), + &mut map(integer, |(i, int)| (i, Argument::Integer(int))), + ])(input) +} + +/// Parses a `format_spec` as defined in the [grammar spec][0]. +/// +/// # Grammar +/// +/// [`format_spec`]` := [[fill]`[`align`]`][`[`sign`]`]['#']['0'][width]` +/// `['.' `[`precision`]`]`[`type`] +/// +/// # Example +/// +/// ```text +/// ^ +/// <^ +/// ->+#0width$.precision$x? +/// ``` +/// +/// [`type`]: type_ +/// [0]: std::fmt#syntax +fn format_spec(input: &str) -> Option<(LeftToParse<'_>, FormatSpec<'_>)> { + let (input, align) = optional_result(alt(&mut [ + &mut and_then(take_any_char, |(i, fill)| { + map(align, |(i, align)| (i, (Some(fill), align)))(i) + }), + &mut map(align, |(i, align)| (i, (None, align))), + ]))(input); + + let (input, sign) = optional_result(sign)(input); + + let (input, alternate) = optional_result(map(char('#'), |i| (i, Alternate)))(input); + + let (input, zero_padding) = optional_result(map( + try_seq(&mut [ + &mut char('0'), + &mut lookahead(check_char(|c| !matches!(c, '$'))), + ]), + |i| (i, ZeroPadding), + ))(input); + + let (input, width) = optional_result(count)(input); + + let (input, precision) = map_or_else( + char('.'), + |i| Some((i, None)), + map(precision, |(i, p)| (i, Some(p))), + )(input)?; + + let (input, ty) = type_(input)?; + + Some(( + input, + FormatSpec { + align, + sign, + alternate, + zero_padding, + width, + precision, + ty, + }, + )) +} + +/// Parses an `align` as defined in the [grammar spec][0]. +/// +/// # Grammar +/// +/// [`align`]` := '<' | '^' | '>'` +/// +/// # Example +/// +/// ```text +/// < +/// ^ +/// > +/// ``` +/// +/// [0]: std::fmt#syntax +fn align(input: &str) -> Option<(LeftToParse<'_>, Align)> { + alt(&mut [ + &mut map(char('<'), |i| (i, Align::Left)), + &mut map(char('^'), |i| (i, Align::Center)), + &mut map(char('>'), |i| (i, Align::Right)), + ])(input) +} + +/// Parses a `sign` as defined in the [grammar spec][0]. +/// +/// # Grammar +/// +/// [`sign`]` := '+' | '-'` +/// +/// # Example +/// +/// ```text +/// + +/// - +/// ``` +/// +/// [0]: std::fmt#syntax +fn sign(input: &str) -> Option<(LeftToParse<'_>, Sign)> { + alt(&mut [ + &mut map(char('+'), |i| (i, Sign::Plus)), + &mut map(char('-'), |i| (i, Sign::Minus)), + ])(input) +} + +/// Parses a `precision` as defined in the [grammar spec][0]. +/// +/// # Grammar +/// +/// [`precision`]` := `[`count`]` | '*'` +/// +/// # Example +/// +/// ```text +/// 0 +/// 42$ +/// par$ +/// * +/// ``` +/// +/// [0]: std::fmt#syntax +fn precision(input: &str) -> Option<(LeftToParse<'_>, Precision<'_>)> { + alt(&mut [ + &mut map(count, |(i, c)| (i, Precision::Count(c))), + &mut map(char('*'), |i| (i, Precision::Star)), + ])(input) +} + +/// Parses a `type` as defined in the [grammar spec][0]. +/// +/// # Grammar +/// +/// [`type`]` := '' | '?' | 'x?' | 'X?' | identifier` +/// +/// # Example +/// +/// All possible [`Type`]s. +/// +/// ```text +/// ? +/// x? +/// X? 
+/// o +/// x +/// X +/// p +/// b +/// e +/// E +/// ``` +/// +/// [`type`]: type_ +/// [0]: std::fmt#syntax +fn type_(input: &str) -> Option<(&str, Type)> { + alt(&mut [ + &mut map(str("x?"), |i| (i, Type::LowerDebug)), + &mut map(str("X?"), |i| (i, Type::UpperDebug)), + &mut map(char('?'), |i| (i, Type::Debug)), + &mut map(char('o'), |i| (i, Type::Octal)), + &mut map(char('x'), |i| (i, Type::LowerHex)), + &mut map(char('X'), |i| (i, Type::UpperHex)), + &mut map(char('p'), |i| (i, Type::Pointer)), + &mut map(char('b'), |i| (i, Type::Binary)), + &mut map(char('e'), |i| (i, Type::LowerExp)), + &mut map(char('E'), |i| (i, Type::UpperExp)), + &mut map(lookahead(char('}')), |i| (i, Type::Display)), + ])(input) +} + +/// Parses a `count` as defined in the [grammar spec][0]. +/// +/// # Grammar +/// +/// [`count`]` := `[`parameter`]` | `[`integer`] +/// +/// # Example +/// +/// ```text +/// 0 +/// 42$ +/// par$ +/// ``` +/// +/// [0]: std::fmt#syntax +fn count(input: &str) -> Option<(LeftToParse<'_>, Count<'_>)> { + alt(&mut [ + &mut map(parameter, |(i, p)| (i, Count::Parameter(p))), + &mut map(integer, |(i, int)| (i, Count::Integer(int))), + ])(input) +} + +/// Parses a `parameter` as defined in the [grammar spec][0]. +/// +/// # Grammar +/// +/// [`parameter`]` := `[`argument`]` '$'` +/// +/// # Example +/// +/// ```text +/// 42$ +/// par$ +/// ``` +/// +/// [0]: std::fmt#syntax +fn parameter(input: &str) -> Option<(LeftToParse<'_>, Parameter<'_>)> { + and_then(argument, |(i, arg)| map(char('$'), |i| (i, arg))(i))(input) +} + +/// Parses an `identifier` as defined in the [grammar spec][0]. +/// +/// # Grammar +/// +/// `IDENTIFIER_OR_KEYWORD : XID_Start XID_Continue* | _ XID_Continue+` +/// +/// See [rust reference][2] for more info. +/// +/// # Example +/// +/// ```text +/// identifier +/// Минск +/// ``` +/// +/// [0]: std::fmt#syntax +/// [2]: https://doc.rust-lang.org/reference/identifiers.html +fn identifier(input: &str) -> Option<(LeftToParse<'_>, Identifier<'_>)> { + map( + alt(&mut [ + &mut map( + check_char(XID::is_xid_start), + take_while0(check_char(XID::is_xid_continue)), + ), + &mut and_then(char('_'), take_while1(check_char(XID::is_xid_continue))), + ]), + |(i, _)| (i, &input[..(input.as_bytes().len() - i.as_bytes().len())]), + )(input) +} + +/// Parses an `integer` as defined in the [grammar spec][0]. +/// +/// [0]: std::fmt#syntax +fn integer(input: &str) -> Option<(LeftToParse<'_>, usize)> { + and_then( + take_while1(check_char(|c| c.is_ascii_digit())), + |(i, int)| int.parse().ok().map(|int| (i, int)), + )(input) +} + +/// Parses a `text` as defined in the [grammar spec][0]. +/// +/// [0]: std::fmt#syntax +fn text(input: &str) -> Option<(LeftToParse<'_>, &str)> { + take_until1(any_char, one_of("{}"))(input) +} + +type FallibleParser<'p> = &'p mut dyn FnMut(&str) -> Option<&str>; + +/// Tries to apply parsers in sequence. Returns [`None`] in case one of them +/// returned [`None`]. +fn try_seq<'p>( + parsers: &'p mut [FallibleParser<'p>], +) -> impl FnMut(&str) -> Option> + 'p { + move |input| parsers.iter_mut().try_fold(input, |i, p| (**p)(i)) +} + +/// Returns first successful parser or [`None`] in case all of them returned +/// [`None`]. +fn alt<'p, 'i, T: 'i>( + parsers: &'p mut [&'p mut dyn FnMut(&'i str) -> Option], +) -> impl FnMut(&'i str) -> Option + 'p { + move |input| parsers.iter_mut().find_map(|p| (**p)(input)) +} + +/// Maps output of the parser in case it returned [`Some`]. 
+fn map<'i, I: 'i, O: 'i>( + mut parser: impl FnMut(&'i str) -> Option, + mut f: impl FnMut(I) -> O, +) -> impl FnMut(&'i str) -> Option { + move |input| parser(input).map(&mut f) +} + +/// Maps output of the parser in case it returned [`Some`] or uses `default`. +fn map_or_else<'i, I: 'i, O: 'i>( + mut parser: impl FnMut(&'i str) -> Option, + mut default: impl FnMut(&'i str) -> O, + mut f: impl FnMut(I) -> O, +) -> impl FnMut(&'i str) -> O { + move |input| parser(input).map_or_else(|| default(input), &mut f) +} + +/// Returns [`None`] if the parser returned is [`None`], otherwise calls `f` +/// with the wrapped value and returns the result. +fn and_then<'i, I: 'i, O: 'i>( + mut parser: impl FnMut(&'i str) -> Option, + mut f: impl FnMut(I) -> Option, +) -> impl FnMut(&'i str) -> Option { + move |input| parser(input).and_then(&mut f) +} + +/// Checks whether `parser` is successful while not advancing the original +/// `input`. +fn lookahead( + mut parser: impl FnMut(&str) -> Option<&str>, +) -> impl FnMut(&str) -> Option> { + move |input| map(&mut parser, |_| input)(input) +} + +/// Makes underlying `parser` optional by returning the original `input` and +/// [`None`] in case it returned [`None`]. +fn optional_result<'i, T: 'i>( + mut parser: impl FnMut(&'i str) -> Option<(&'i str, T)>, +) -> impl FnMut(&'i str) -> (LeftToParse<'i>, Option) { + move |input: &str| { + map_or_else(&mut parser, |i| (i, None), |(i, c)| (i, Some(c)))(input) + } +} + +/// Parses while `parser` is successful. Never fails. +fn take_while0( + mut parser: impl FnMut(&str) -> Option<&str>, +) -> impl FnMut(&str) -> (LeftToParse<'_>, &str) { + move |input| { + let mut cur = input; + while let Some(step) = parser(cur) { + cur = step; + } + ( + cur, + &input[..(input.as_bytes().len() - cur.as_bytes().len())], + ) + } +} + +/// Parses while `parser` is successful. Returns [`None`] in case `parser` never +/// succeeded. +fn take_while1( + mut parser: impl FnMut(&str) -> Option<&str>, +) -> impl FnMut(&str) -> Option<(LeftToParse<'_>, &str)> { + move |input| { + let mut cur = parser(input)?; + while let Some(step) = parser(cur) { + cur = step; + } + Some(( + cur, + &input[..(input.as_bytes().len() - cur.as_bytes().len())], + )) + } +} + +/// Parses with `basic` while `until` returns [`None`]. Returns [`None`] in case +/// `until` succeeded initially or `basic` never succeeded. Doesn't consume +/// [`char`]s parsed by `until`. +/// +/// [`char`]: fn@char +fn take_until1( + mut basic: impl FnMut(&str) -> Option<&str>, + mut until: impl FnMut(&str) -> Option<&str>, +) -> impl FnMut(&str) -> Option<(LeftToParse<'_>, &str)> { + move |input: &str| { + if until(input).is_some() { + return None; + } + let mut cur = basic(input)?; + loop { + if until(cur).is_some() { + break; + } + let Some(b) = basic(cur) else { + break; + }; + cur = b; + } + + Some(( + cur, + &input[..(input.as_bytes().len() - cur.as_bytes().len())], + )) + } +} + +/// Checks whether `input` starts with `s`. +fn str(s: &str) -> impl FnMut(&str) -> Option> + '_ { + move |input| input.starts_with(s).then(|| &input[s.as_bytes().len()..]) +} + +/// Checks whether `input` starts with `c`. +fn char(c: char) -> impl FnMut(&str) -> Option> { + move |input| input.starts_with(c).then(|| &input[c.len_utf8()..]) +} + +/// Checks whether first [`char`] suits `check`. 
+/// +/// [`char`]: fn@char +fn check_char( + mut check: impl FnMut(char) -> bool, +) -> impl FnMut(&str) -> Option> { + move |input| { + input + .chars() + .next() + .and_then(|c| check(c).then(|| &input[c.len_utf8()..])) + } +} + +/// Checks whether first [`char`] of input is present in `chars`. +/// +/// [`char`]: fn@char +fn one_of(chars: &str) -> impl FnMut(&str) -> Option> + '_ { + move |input: &str| chars.chars().find_map(|c| char(c)(input)) +} + +/// Parses any [`char`]. +/// +/// [`char`]: fn@char +fn any_char(input: &str) -> Option> { + input.chars().next().map(|c| &input[c.len_utf8()..]) +} + +/// Parses any [`char`] and returns it. +/// +/// [`char`]: fn@char +fn take_any_char(input: &str) -> Option<(LeftToParse<'_>, char)> { + input.chars().next().map(|c| (&input[c.len_utf8()..], c)) +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn text() { + assert_eq!(format_string(""), Some(FormatString { formats: vec![] })); + assert_eq!( + format_string("test"), + Some(FormatString { formats: vec![] }), + ); + assert_eq!( + format_string("Минск"), + Some(FormatString { formats: vec![] }), + ); + assert_eq!(format_string("🦀"), Some(FormatString { formats: vec![] })); + } + + #[test] + fn argument() { + assert_eq!( + format_string("{}"), + Some(FormatString { + formats: vec![Format { + arg: None, + spec: None, + }], + }), + ); + assert_eq!( + format_string("{0}"), + Some(FormatString { + formats: vec![Format { + arg: Some(Argument::Integer(0)), + spec: None, + }], + }), + ); + assert_eq!( + format_string("{par}"), + Some(FormatString { + formats: vec![Format { + arg: Some(Argument::Identifier("par")), + spec: None, + }], + }), + ); + assert_eq!( + format_string("{Минск}"), + Some(FormatString { + formats: vec![Format { + arg: Some(Argument::Identifier("Минск")), + spec: None, + }], + }), + ); + } + + #[test] + fn spec() { + assert_eq!( + format_string("{:}"), + Some(FormatString { + formats: vec![Format { + arg: None, + spec: Some(FormatSpec { + align: None, + sign: None, + alternate: None, + zero_padding: None, + width: None, + precision: None, + ty: Type::Display, + }), + }], + }), + ); + assert_eq!( + format_string("{:^}"), + Some(FormatString { + formats: vec![Format { + arg: None, + spec: Some(FormatSpec { + align: Some((None, Align::Center)), + sign: None, + alternate: None, + zero_padding: None, + width: None, + precision: None, + ty: Type::Display, + }), + }], + }), + ); + assert_eq!( + format_string("{:-<}"), + Some(FormatString { + formats: vec![Format { + arg: None, + spec: Some(FormatSpec { + align: Some((Some('-'), Align::Left)), + sign: None, + alternate: None, + zero_padding: None, + width: None, + precision: None, + ty: Type::Display, + }), + }], + }), + ); + assert_eq!( + format_string("{: <}"), + Some(FormatString { + formats: vec![Format { + arg: None, + spec: Some(FormatSpec { + align: Some((Some(' '), Align::Left)), + sign: None, + alternate: None, + zero_padding: None, + width: None, + precision: None, + ty: Type::Display, + }), + }], + }), + ); + assert_eq!( + format_string("{:^<}"), + Some(FormatString { + formats: vec![Format { + arg: None, + spec: Some(FormatSpec { + align: Some((Some('^'), Align::Left)), + sign: None, + alternate: None, + zero_padding: None, + width: None, + precision: None, + ty: Type::Display, + }), + }], + }), + ); + assert_eq!( + format_string("{:+}"), + Some(FormatString { + formats: vec![Format { + arg: None, + spec: Some(FormatSpec { + align: None, + sign: Some(Sign::Plus), + alternate: None, + zero_padding: None, + width: 
None, + precision: None, + ty: Type::Display, + }), + }], + }), + ); + assert_eq!( + format_string("{:^<-}"), + Some(FormatString { + formats: vec![Format { + arg: None, + spec: Some(FormatSpec { + align: Some((Some('^'), Align::Left)), + sign: Some(Sign::Minus), + alternate: None, + zero_padding: None, + width: None, + precision: None, + ty: Type::Display, + }), + }], + }), + ); + assert_eq!( + format_string("{:#}"), + Some(FormatString { + formats: vec![Format { + arg: None, + spec: Some(FormatSpec { + align: None, + sign: None, + alternate: Some(Alternate), + zero_padding: None, + width: None, + precision: None, + ty: Type::Display, + }), + }], + }), + ); + assert_eq!( + format_string("{:+#}"), + Some(FormatString { + formats: vec![Format { + arg: None, + spec: Some(FormatSpec { + align: None, + sign: Some(Sign::Plus), + alternate: Some(Alternate), + zero_padding: None, + width: None, + precision: None, + ty: Type::Display, + }), + }], + }), + ); + assert_eq!( + format_string("{:-<#}"), + Some(FormatString { + formats: vec![Format { + arg: None, + spec: Some(FormatSpec { + align: Some((Some('-'), Align::Left)), + sign: None, + alternate: Some(Alternate), + zero_padding: None, + width: None, + precision: None, + ty: Type::Display, + }), + }], + }), + ); + assert_eq!( + format_string("{:^<-#}"), + Some(FormatString { + formats: vec![Format { + arg: None, + spec: Some(FormatSpec { + align: Some((Some('^'), Align::Left)), + sign: Some(Sign::Minus), + alternate: Some(Alternate), + zero_padding: None, + width: None, + precision: None, + ty: Type::Display, + }), + }], + }), + ); + assert_eq!( + format_string("{:0}"), + Some(FormatString { + formats: vec![Format { + arg: None, + spec: Some(FormatSpec { + align: None, + sign: None, + alternate: None, + zero_padding: Some(ZeroPadding), + width: None, + precision: None, + ty: Type::Display, + }), + }], + }), + ); + assert_eq!( + format_string("{:#0}"), + Some(FormatString { + formats: vec![Format { + arg: None, + spec: Some(FormatSpec { + align: None, + sign: None, + alternate: Some(Alternate), + zero_padding: Some(ZeroPadding), + width: None, + precision: None, + ty: Type::Display, + }), + }], + }), + ); + assert_eq!( + format_string("{:-0}"), + Some(FormatString { + formats: vec![Format { + arg: None, + spec: Some(FormatSpec { + align: None, + sign: Some(Sign::Minus), + alternate: None, + zero_padding: Some(ZeroPadding), + width: None, + precision: None, + ty: Type::Display, + }), + }], + }), + ); + assert_eq!( + format_string("{:^<0}"), + Some(FormatString { + formats: vec![Format { + arg: None, + spec: Some(FormatSpec { + align: Some((Some('^'), Align::Left)), + sign: None, + alternate: None, + zero_padding: Some(ZeroPadding), + width: None, + precision: None, + ty: Type::Display, + }), + }], + }), + ); + assert_eq!( + format_string("{:^<+#0}"), + Some(FormatString { + formats: vec![Format { + arg: None, + spec: Some(FormatSpec { + align: Some((Some('^'), Align::Left)), + sign: Some(Sign::Plus), + alternate: Some(Alternate), + zero_padding: Some(ZeroPadding), + width: None, + precision: None, + ty: Type::Display, + }), + }], + }), + ); + assert_eq!( + format_string("{:1}"), + Some(FormatString { + formats: vec![Format { + arg: None, + spec: Some(FormatSpec { + align: None, + sign: None, + alternate: None, + zero_padding: None, + width: Some(Count::Integer(1)), + precision: None, + ty: Type::Display, + }), + }], + }), + ); + assert_eq!( + format_string("{:1$}"), + Some(FormatString { + formats: vec![Format { + arg: None, + spec: Some(FormatSpec 
{ + align: None, + sign: None, + alternate: None, + zero_padding: None, + width: Some(Count::Parameter(Argument::Integer(1))), + precision: None, + ty: Type::Display, + }), + }], + }), + ); + assert_eq!( + format_string("{:par$}"), + Some(FormatString { + formats: vec![Format { + arg: None, + spec: Some(FormatSpec { + align: None, + sign: None, + alternate: None, + zero_padding: None, + width: Some(Count::Parameter(Argument::Identifier("par"))), + precision: None, + ty: Type::Display, + }), + }], + }), + ); + assert_eq!( + format_string("{:-^-#0Минск$}"), + Some(FormatString { + formats: vec![Format { + arg: None, + spec: Some(FormatSpec { + align: Some((Some('-'), Align::Center)), + sign: Some(Sign::Minus), + alternate: Some(Alternate), + zero_padding: Some(ZeroPadding), + width: Some(Count::Parameter(Argument::Identifier("Минск"))), + precision: None, + ty: Type::Display, + }), + }], + }), + ); + assert_eq!( + format_string("{:.*}"), + Some(FormatString { + formats: vec![Format { + arg: None, + spec: Some(FormatSpec { + align: None, + sign: None, + alternate: None, + zero_padding: None, + width: None, + precision: Some(Precision::Star), + ty: Type::Display, + }), + }], + }), + ); + assert_eq!( + format_string("{:.0}"), + Some(FormatString { + formats: vec![Format { + arg: None, + spec: Some(FormatSpec { + align: None, + sign: None, + alternate: None, + zero_padding: None, + width: None, + precision: Some(Precision::Count(Count::Integer(0))), + ty: Type::Display, + }), + }], + }), + ); + assert_eq!( + format_string("{:.0$}"), + Some(FormatString { + formats: vec![Format { + arg: None, + spec: Some(FormatSpec { + align: None, + sign: None, + alternate: None, + zero_padding: None, + width: None, + precision: Some(Precision::Count(Count::Parameter( + Argument::Integer(0), + ))), + ty: Type::Display, + }), + }], + }), + ); + assert_eq!( + format_string("{:.par$}"), + Some(FormatString { + formats: vec![Format { + arg: None, + spec: Some(FormatSpec { + align: None, + sign: None, + alternate: None, + zero_padding: None, + width: None, + precision: Some(Precision::Count(Count::Parameter( + Argument::Identifier("par"), + ))), + ty: Type::Display, + }), + }], + }), + ); + assert_eq!( + format_string("{: >+#2$.par$}"), + Some(FormatString { + formats: vec![Format { + arg: None, + spec: Some(FormatSpec { + align: Some((Some(' '), Align::Right)), + sign: Some(Sign::Plus), + alternate: Some(Alternate), + zero_padding: None, + width: Some(Count::Parameter(Argument::Integer(2))), + precision: Some(Precision::Count(Count::Parameter( + Argument::Identifier("par"), + ))), + ty: Type::Display, + }), + }], + }), + ); + assert_eq!( + format_string("{:x?}"), + Some(FormatString { + formats: vec![Format { + arg: None, + spec: Some(FormatSpec { + align: None, + sign: None, + alternate: None, + zero_padding: None, + width: None, + precision: None, + ty: Type::LowerDebug, + }), + }], + }), + ); + assert_eq!( + format_string("{:E}"), + Some(FormatString { + formats: vec![Format { + arg: None, + spec: Some(FormatSpec { + align: None, + sign: None, + alternate: None, + zero_padding: None, + width: None, + precision: None, + ty: Type::UpperExp, + }), + }], + }), + ); + assert_eq!( + format_string("{: >+#par$.par$X?}"), + Some(FormatString { + formats: vec![Format { + arg: None, + spec: Some(FormatSpec { + align: Some((Some(' '), Align::Right)), + sign: Some(Sign::Plus), + alternate: Some(Alternate), + zero_padding: None, + width: Some(Count::Parameter(Argument::Identifier("par"))), + precision: 
Some(Precision::Count(Count::Parameter( + Argument::Identifier("par"), + ))), + ty: Type::UpperDebug, + }), + }], + }), + ); + } + + #[test] + fn full() { + assert_eq!( + format_string("prefix{{{0:#?}postfix{par:-^par$.a$}}}"), + Some(FormatString { + formats: vec![ + Format { + arg: Some(Argument::Integer(0)), + spec: Some(FormatSpec { + align: None, + sign: None, + alternate: Some(Alternate), + zero_padding: None, + width: None, + precision: None, + ty: Type::Debug, + }), + }, + Format { + arg: Some(Argument::Identifier("par")), + spec: Some(FormatSpec { + align: Some((Some('-'), Align::Center)), + sign: None, + alternate: None, + zero_padding: None, + width: Some(Count::Parameter(Argument::Identifier("par"))), + precision: Some(Precision::Count(Count::Parameter( + Argument::Identifier("a"), + ))), + ty: Type::Display, + }), + }, + ], + }), + ); + } + + #[test] + fn error() { + assert_eq!(format_string("{"), None); + assert_eq!(format_string("}"), None); + assert_eq!(format_string("{{}"), None); + assert_eq!(format_string("{:x?"), None); + assert_eq!(format_string("{:.}"), None); + assert_eq!(format_string("{:q}"), None); + assert_eq!(format_string("{:par}"), None); + assert_eq!(format_string("{⚙️}"), None); + } +} diff --git a/tokenizers/display_derive/src/vendored.rs b/tokenizers/display_derive/src/vendored.rs index c211937c2..58e813ead 100644 --- a/tokenizers/display_derive/src/vendored.rs +++ b/tokenizers/display_derive/src/vendored.rs @@ -1,4 +1,14 @@ - +use syn::LitStr; +use proc_macro2::TokenStream; +use crate::parsing; +use quote::{format_ident, ToTokens}; +use syn::{ + parse::{Parse, ParseStream}, + punctuated::Punctuated, + spanned::Spanned as _, + token, + Expr, +}; /// Representation of a [`fmt`]-like attribute. /// /// ```rust,ignore @@ -108,87 +118,6 @@ impl FmtAttribute { Some((expr, format_ident!("{trait_name}"))) } - - /// Returns an [`Iterator`] over bounded [`syn::Type`]s (and correspondent trait names) by this - /// [`FmtAttribute`]. - fn bounded_types<'a>( - &'a self, - fields: &'a syn::Fields, - ) -> impl Iterator { - let placeholders = Placeholder::parse_fmt_string(&self.lit.value()); - - // We ignore unknown fields, as compiler will produce better error messages. - placeholders.into_iter().filter_map(move |placeholder| { - let name = match placeholder.arg { - Parameter::Named(name) => self - .args - .iter() - .find_map(|a| (a.alias()? == &name).then_some(&a.expr)) - .map_or(Some(name), |expr| expr.ident().map(ToString::to_string))?, - Parameter::Positional(i) => self - .args - .iter() - .nth(i) - .and_then(|a| a.expr.ident().filter(|_| a.alias.is_none()))? - .to_string(), - }; - - let unnamed = name.strip_prefix('_').and_then(|s| s.parse().ok()); - let ty = match (&fields, unnamed) { - (syn::Fields::Unnamed(f), Some(i)) => { - f.unnamed.iter().nth(i).map(|f| &f.ty) - } - (syn::Fields::Named(f), None) => f.named.iter().find_map(|f| { - f.ident.as_ref().filter(|s| **s == name).map(|_| &f.ty) - }), - _ => None, - }?; - - Some((ty, placeholder.trait_name)) - }) - } - - /// Errors in case legacy syntax is encountered: `fmt = "...", (arg),*`. - fn check_legacy_fmt(input: ParseStream<'_>) -> syn::Result<()> { - let fork = input.fork(); - - let path = fork - .parse::() - .and_then(|path| fork.parse::().map(|_| path)); - match path { - Ok(path) if path.is_ident("fmt") => (|| { - let args = fork - .parse_terminated( - >::parse, - token::Comma, - ) - .ok()? 
- .into_iter() - .enumerate() - .filter_map(|(i, arg)| match arg { - Either::Left(syn::Lit::Str(str)) => Some(if i == 0 { - format!("\"{}\"", str.value()) - } else { - str.value() - }), - Either::Right(ident) => Some(ident.to_string()), - _ => None, - }) - .collect::>(); - (!args.is_empty()).then_some(args) - })() - .map_or(Ok(()), |fmt| { - Err(syn::Error::new( - input.span(), - format!( - "legacy syntax, remove `fmt =` and use `{}` instead", - fmt.join(", "), - ), - )) - }), - Ok(_) | Err(_) => Ok(()), - } - } } /// Representation of a [named parameter][1] (`identifier '=' expression`) in @@ -236,171 +165,3 @@ impl ToTokens for FmtArgument { } } -/// Representation of a [parameter][1] used in a [`Placeholder`]. -/// -/// [1]: https://doc.rust-lang.org/stable/std/fmt/index.html#formatting-parameters -#[derive(Debug, Eq, PartialEq)] -enum Parameter { - /// [Positional parameter][1]. - /// - /// [1]: https://doc.rust-lang.org/stable/std/fmt/index.html#positional-parameters - Positional(usize), - - /// [Named parameter][1]. - /// - /// [1]: https://doc.rust-lang.org/stable/std/fmt/index.html#named-parameters - Named(String), -} - -impl<'a> From> for Parameter { - fn from(arg: parsing::Argument<'a>) -> Self { - match arg { - parsing::Argument::Integer(i) => Self::Positional(i), - parsing::Argument::Identifier(i) => Self::Named(i.to_owned()), - } - } -} - -/// Representation of a formatting placeholder. -#[derive(Debug, Eq, PartialEq)] -struct Placeholder { - /// Formatting argument (either named or positional) to be used by this placeholder. - arg: Parameter, - - /// [Width parameter][1], if present. - /// - /// [1]: https://doc.rust-lang.org/stable/std/fmt/index.html#width - width: Option, - - /// [Precision parameter][1], if present. - /// - /// [1]: https://doc.rust-lang.org/stable/std/fmt/index.html#precision - precision: Option, - - /// Name of [`std::fmt`] trait to be used for rendering this placeholder. - trait_name: &'static str, -} - -impl Placeholder { - /// Parses [`Placeholder`]s from the provided formatting string. - fn parse_fmt_string(s: &str) -> Vec { - let mut n = 0; - parsing::format_string(s) - .into_iter() - .flat_map(|f| f.formats) - .map(|format| { - let (maybe_arg, ty) = ( - format.arg, - format.spec.map(|s| s.ty).unwrap_or(parsing::Type::Display), - ); - let position = maybe_arg.map(Into::into).unwrap_or_else(|| { - // Assign "the next argument". - // https://doc.rust-lang.org/stable/std/fmt/index.html#positional-parameters - n += 1; - Parameter::Positional(n - 1) - }); - - Self { - arg: position, - width: format.spec.and_then(|s| match s.width { - Some(parsing::Count::Parameter(arg)) => Some(arg.into()), - _ => None, - }), - precision: format.spec.and_then(|s| match s.precision { - Some(parsing::Precision::Count(parsing::Count::Parameter( - arg, - ))) => Some(arg.into()), - _ => None, - }), - trait_name: ty.trait_name(), - } - }) - .collect() - } -} - -/// Representation of a [`fmt::Display`]-like derive macro attributes placed on a container (struct -/// or enum variant). -/// -/// ```rust,ignore -/// #[("", )] -/// #[(bound())] -/// ``` -/// -/// `#[(...)]` can be specified only once, while multiple `#[(bound(...))]` -/// are allowed. -/// -/// [`fmt::Display`]: std::fmt::Display -#[derive(Debug, Default)] -struct ContainerAttributes { - /// Interpolation [`FmtAttribute`]. - fmt: Option, - - /// Addition trait bounds. 
- bounds: BoundsAttribute, -} - -impl Parse for ContainerAttributes { - fn parse(input: ParseStream<'_>) -> syn::Result { - // We do check `FmtAttribute::check_legacy_fmt` eagerly here, because `Either` will swallow - // any error of the `Either::Left` if the `Either::Right` succeeds. - FmtAttribute::check_legacy_fmt(input)?; - >::parse(input).map(|v| match v { - Either::Left(fmt) => Self { - bounds: BoundsAttribute::default(), - fmt: Some(fmt), - }, - Either::Right(bounds) => Self { bounds, fmt: None }, - }) - } -} - -impl attr::ParseMultiple for ContainerAttributes { - fn merge_attrs( - prev: Spanning, - new: Spanning, - name: &syn::Ident, - ) -> syn::Result> { - let Spanning { - span: prev_span, - item: mut prev, - } = prev; - let Spanning { - span: new_span, - item: new, - } = new; - - if new.fmt.and_then(|n| prev.fmt.replace(n)).is_some() { - return Err(syn::Error::new( - new_span, - format!("multiple `#[{name}(\"...\", ...)]` attributes aren't allowed"), - )); - } - prev.bounds.0.extend(new.bounds.0); - - Ok(Spanning::new( - prev, - prev_span.join(new_span).unwrap_or(prev_span), - )) - } -} - -/// Matches the provided `trait_name` to appropriate [`FmtAttribute`]'s argument name. -fn trait_name_to_attribute_name(trait_name: T) -> &'static str -where - T: for<'a> PartialEq<&'a str>, -{ - match () { - _ if trait_name == "Binary" => "binary", - _ if trait_name == "Debug" => "debug", - _ if trait_name == "Display" => "display", - _ if trait_name == "LowerExp" => "lower_exp", - _ if trait_name == "LowerHex" => "lower_hex", - _ if trait_name == "Octal" => "octal", - _ if trait_name == "Pointer" => "pointer", - _ if trait_name == "UpperExp" => "upper_exp", - _ if trait_name == "UpperHex" => "upper_hex", - _ => unimplemented!(), - } -} - diff --git a/tokenizers/vendored.rs b/tokenizers/vendored.rs deleted file mode 100644 index 8f181e3f6..000000000 --- a/tokenizers/vendored.rs +++ /dev/null @@ -1,170 +0,0 @@ -use proc_macro2::TokenStream; -use quote::{format_ident, ToTokens}; -use syn::{ - parse::{Parse, ParseStream}, - punctuated::Punctuated, - spanned::Spanned as _, - token, -}; - -use crate::{ - parsing::Expr, - utils::{attr, Either, Spanning}, -}; - -/// Representation of a [`fmt`]-like attribute. -/// -/// ```rust,ignore -/// #[("", )] -/// ``` -/// -/// [`fmt`]: std::fmt -#[derive(Debug)] -pub struct FmtAttribute { - /// Interpolation [`syn::LitStr`]. - /// - /// [`syn::LitStr`]: struct@syn::LitStr - lit: syn::LitStr, - - /// Optional [`token::Comma`]. - /// - /// [`token::Comma`]: struct@token::Comma - comma: Option, - - /// Interpolation arguments. - args: Punctuated, -} - -impl Parse for FmtAttribute { - fn parse(input: ParseStream<'_>) -> syn::Result { - Self::check_legacy_fmt(input)?; - - Ok(Self { - lit: input.parse()?, - comma: input - .peek(token::Comma) - .then(|| input.parse()) - .transpose()?, - args: input.parse_terminated(FmtArgument::parse, token::Comma)?, - }) - } -} - -impl attr::ParseMultiple for FmtAttribute {} - -impl ToTokens for FmtAttribute { - fn to_tokens(&self, tokens: &mut TokenStream) { - self.lit.to_tokens(tokens); - self.comma.to_tokens(tokens); - self.args.to_tokens(tokens); - } -} - -impl FmtAttribute { - /// Checks whether this [`FmtAttribute`] can be replaced with a transparent delegation (calling - /// a formatting trait directly instead of interpolation syntax). - /// - /// If such transparent call is possible, the returns an [`Ident`] of the delegated trait and - /// the [`Expr`] to pass into the call, otherwise [`None`]. 
- /// - /// [`Ident`]: struct@syn::Ident - fn transparent_call(&self) -> Option<(Expr, syn::Ident)> { - // `FmtAttribute` is transparent when: - - // (1) There is exactly one formatting parameter. - let lit = self.lit.value(); - let param = - parsing::format(&lit).and_then(|(more, p)| more.is_empty().then_some(p))?; - - // (2) And the formatting parameter doesn't contain any modifiers. - if param - .spec - .map(|s| { - s.align.is_some() - || s.sign.is_some() - || s.alternate.is_some() - || s.zero_padding.is_some() - || s.width.is_some() - || s.precision.is_some() - || !s.ty.is_trivial() - }) - .unwrap_or_default() - { - return None; - } - - let expr = match param.arg { - // (3) And either exactly one positional argument is specified. - Some(parsing::Argument::Integer(_)) | None => (self.args.len() == 1) - .then(|| self.args.first()) - .flatten() - .map(|a| a.expr.clone()), - - // (4) Or the formatting parameter's name refers to some outer binding. - Some(parsing::Argument::Identifier(name)) if self.args.is_empty() => { - Some(format_ident!("{name}").into()) - } - - // (5) Or exactly one named argument is specified for the formatting parameter's name. - Some(parsing::Argument::Identifier(name)) => (self.args.len() == 1) - .then(|| self.args.first()) - .flatten() - .filter(|a| a.alias.as_ref().map(|a| a.0 == name).unwrap_or_default()) - .map(|a| a.expr.clone()), - }?; - - let trait_name = param - .spec - .map(|s| s.ty) - .unwrap_or(parsing::Type::Display) - .trait_name(); - - Some((expr, format_ident!("{trait_name}"))) - } -} - -/// Representation of a [named parameter][1] (`identifier '=' expression`) in -/// in a [`FmtAttribute`]. -/// -/// [1]: https://doc.rust-lang.org/stable/std/fmt/index.html#named-parameters -#[derive(Debug)] -struct FmtArgument { - /// `identifier =` [`Ident`]. - /// - /// [`Ident`]: struct@syn::Ident - alias: Option<(syn::Ident, token::Eq)>, - - /// `expression` [`Expr`]. - expr: Expr, -} - -impl FmtArgument { - /// Returns an `identifier` of the [named parameter][1]. 
- /// - /// [1]: https://doc.rust-lang.org/stable/std/fmt/index.html#named-parameters - fn alias(&self) -> Option<&syn::Ident> { - self.alias.as_ref().map(|(ident, _)| ident) - } -} - -impl Parse for FmtArgument { - fn parse(input: ParseStream) -> syn::Result { - Ok(Self { - alias: (input.peek(syn::Ident) && input.peek2(token::Eq)) - .then(|| Ok::<_, syn::Error>((input.parse()?, input.parse()?))) - .transpose()?, - expr: input.parse()?, - }) - } -} - -impl ToTokens for FmtArgument { - fn to_tokens(&self, tokens: &mut TokenStream) { - if let Some((ident, eq)) = &self.alias { - ident.to_tokens(tokens); - eq.to_tokens(tokens); - } - self.expr.to_tokens(tokens); - } -} - From 4c3f37a93845e4e89fc14c2169ef68e8a8b0fb80 Mon Sep 17 00:00:00 2001 From: Arthur Zucker Date: Wed, 12 Jun 2024 10:29:16 +0200 Subject: [PATCH 58/94] update display derive --- tokenizers/display_derive/Cargo.toml | 1 + tokenizers/display_derive/src/lib.rs | 25 +++++++++++++---- tokenizers/display_derive/src/parsing.rs | 34 +++++++++-------------- tokenizers/display_derive/src/vendored.rs | 15 ++++------ 4 files changed, 40 insertions(+), 35 deletions(-) diff --git a/tokenizers/display_derive/Cargo.toml b/tokenizers/display_derive/Cargo.toml index 7d7697910..289d54dd1 100644 --- a/tokenizers/display_derive/Cargo.toml +++ b/tokenizers/display_derive/Cargo.toml @@ -7,6 +7,7 @@ edition = "2021" syn = "1.0" quote = "1.0" proc-macro2 = "1.0" +unicode-xid = "0.2.4" [lib] proc-macro = true diff --git a/tokenizers/display_derive/src/lib.rs b/tokenizers/display_derive/src/lib.rs index 633640659..c4b8d5a5f 100644 --- a/tokenizers/display_derive/src/lib.rs +++ b/tokenizers/display_derive/src/lib.rs @@ -1,12 +1,13 @@ extern crate proc_macro; use proc_macro::TokenStream; use quote::{format_ident,quote}; -use syn::{parse_macro_input, Data, DeriveInput, Fields, Lit, Meta, MetaList, NestedMeta}; +use syn::{parse_macro_input, DeriveInput}; mod vendored; +mod parsing; use vendored::FmtAttribute; #[proc_macro_derive(Display)] -pub fn display_derive(input: TokenStream) -> syn::Result { +pub fn display_derive(input: TokenStream) -> TokenStream { // Parse the input tokens into a syntax tree let input = parse_macro_input!(input as DeriveInput); @@ -89,12 +90,26 @@ fn expand_enum( } let match_arms = e.variants.iter().try_fold( - TokenStream::new, |variant| { - let attrs = FmtAttribute::parse_attrs(&variant.attrs, attr_name)? + (Vec::new(), TokenStream::new()), + |mut arms, variant| { + let attrs = ContainerAttributes::parse_attrs(&variant.attrs, attr_name)? 
.map(Spanning::into_inner) .unwrap_or_default(); let ident = &variant.ident; + if attrs.fmt.is_none() + && variant.fields.is_empty() + && attr_name != "display" + { + return Err(syn::Error::new( + e.variants.span(), + format!( + "implicit formatting of unit enum variant is supported only for `Display` \ + macro, use `#[{attr_name}(\"...\")]` to explicitly specify the formatting", + ), + )); + } + let v = Expansion { attrs: &attrs, fields: &variant.fields, @@ -121,7 +136,7 @@ fn expand_enum( Ok::<_, syn::Error>(arms) }, - )?; + )?; let body = match_arms .is_empty() diff --git a/tokenizers/display_derive/src/parsing.rs b/tokenizers/display_derive/src/parsing.rs index 8a5d75277..a74c3bc51 100644 --- a/tokenizers/display_derive/src/parsing.rs +++ b/tokenizers/display_derive/src/parsing.rs @@ -194,9 +194,7 @@ pub(crate) fn format_string(input: &str) -> Option> { let formats = iter::repeat(()) .scan(&mut input, |input, _| { let (curr, format) = - alt(&mut [&mut maybe_format, &mut map(text, |(i, _)| (i, None))])( - input, - )?; + alt(&mut [&mut maybe_format, &mut map(text, |(i, _)| (i, None))])(input)?; **input = curr; Some(format) }) @@ -600,9 +598,7 @@ fn lookahead( fn optional_result<'i, T: 'i>( mut parser: impl FnMut(&'i str) -> Option<(&'i str, T)>, ) -> impl FnMut(&'i str) -> (LeftToParse<'i>, Option) { - move |input: &str| { - map_or_else(&mut parser, |i| (i, None), |(i, c)| (i, Some(c)))(input) - } + move |input: &str| map_or_else(&mut parser, |i| (i, None), |(i, c)| (i, Some(c)))(input) } /// Parses while `parser` is successful. Never fails. @@ -682,9 +678,7 @@ fn char(c: char) -> impl FnMut(&str) -> Option> { /// Checks whether first [`char`] suits `check`. /// /// [`char`]: fn@char -fn check_char( - mut check: impl FnMut(char) -> bool, -) -> impl FnMut(&str) -> Option> { +fn check_char(mut check: impl FnMut(char) -> bool) -> impl FnMut(&str) -> Option> { move |input| { input .chars() @@ -1159,9 +1153,7 @@ mod tests { alternate: None, zero_padding: None, width: None, - precision: Some(Precision::Count(Count::Parameter( - Argument::Integer(0), - ))), + precision: Some(Precision::Count(Count::Parameter(Argument::Integer(0),))), ty: Type::Display, }), }], @@ -1178,9 +1170,9 @@ mod tests { alternate: None, zero_padding: None, width: None, - precision: Some(Precision::Count(Count::Parameter( - Argument::Identifier("par"), - ))), + precision: Some(Precision::Count(Count::Parameter(Argument::Identifier( + "par" + ),))), ty: Type::Display, }), }], @@ -1197,9 +1189,9 @@ mod tests { alternate: Some(Alternate), zero_padding: None, width: Some(Count::Parameter(Argument::Integer(2))), - precision: Some(Precision::Count(Count::Parameter( - Argument::Identifier("par"), - ))), + precision: Some(Precision::Count(Count::Parameter(Argument::Identifier( + "par" + ),))), ty: Type::Display, }), }], @@ -1250,9 +1242,9 @@ mod tests { alternate: Some(Alternate), zero_padding: None, width: Some(Count::Parameter(Argument::Identifier("par"))), - precision: Some(Precision::Count(Count::Parameter( - Argument::Identifier("par"), - ))), + precision: Some(Precision::Count(Count::Parameter(Argument::Identifier( + "par" + ),))), ty: Type::UpperDebug, }), }], diff --git a/tokenizers/display_derive/src/vendored.rs b/tokenizers/display_derive/src/vendored.rs index 58e813ead..9799cbf81 100644 --- a/tokenizers/display_derive/src/vendored.rs +++ b/tokenizers/display_derive/src/vendored.rs @@ -1,14 +1,13 @@ -use syn::LitStr; -use proc_macro2::TokenStream; use crate::parsing; +use proc_macro2::TokenStream; use 
quote::{format_ident, ToTokens}; use syn::{ parse::{Parse, ParseStream}, punctuated::Punctuated, - spanned::Spanned as _, - token, + token, Expr, }; + /// Representation of a [`fmt`]-like attribute. /// /// ```rust,ignore @@ -17,7 +16,7 @@ use syn::{ /// /// [`fmt`]: std::fmt #[derive(Debug)] -struct FmtAttribute { +pub struct FmtAttribute { /// Interpolation [`syn::LitStr`]. /// /// [`syn::LitStr`]: struct@syn::LitStr @@ -42,7 +41,7 @@ impl Parse for FmtAttribute { .peek(token::Comma) .then(|| input.parse()) .transpose()?, - args: input.parse_terminated(FmtArgument::parse, token::Comma)?, + args: input.parse_terminated(FmtArgument::parse)?, }) } } @@ -70,8 +69,7 @@ impl FmtAttribute { // (1) There is exactly one formatting parameter. let lit = self.lit.value(); - let param = - parsing::format(&lit).and_then(|(more, p)| more.is_empty().then_some(p))?; + let param = parsing::format(&lit).and_then(|(more, p)| more.is_empty().then_some(p))?; // (2) And the formatting parameter doesn't contain any modifiers. if param @@ -164,4 +162,3 @@ impl ToTokens for FmtArgument { self.expr.to_tokens(tokens); } } - From 292475fc6b45739d468285f727e557514efa26e4 Mon Sep 17 00:00:00 2001 From: Arthur Zucker Date: Wed, 12 Jun 2024 10:47:51 +0200 Subject: [PATCH 59/94] blindly fix stuff --- tokenizers/display_derive/Cargo.toml | 1 - tokenizers/display_derive/src/lib.rs | 10 ++++------ tokenizers/display_derive/src/parsing.rs | 2 +- 3 files changed, 5 insertions(+), 8 deletions(-) diff --git a/tokenizers/display_derive/Cargo.toml b/tokenizers/display_derive/Cargo.toml index 289d54dd1..7d7697910 100644 --- a/tokenizers/display_derive/Cargo.toml +++ b/tokenizers/display_derive/Cargo.toml @@ -7,7 +7,6 @@ edition = "2021" syn = "1.0" quote = "1.0" proc-macro2 = "1.0" -unicode-xid = "0.2.4" [lib] proc-macro = true diff --git a/tokenizers/display_derive/src/lib.rs b/tokenizers/display_derive/src/lib.rs index c4b8d5a5f..4850588b1 100644 --- a/tokenizers/display_derive/src/lib.rs +++ b/tokenizers/display_derive/src/lib.rs @@ -10,10 +10,9 @@ use vendored::FmtAttribute; pub fn display_derive(input: TokenStream) -> TokenStream { // Parse the input tokens into a syntax tree let input = parse_macro_input!(input as DeriveInput); - + return ; let attr_name = "display"; - let attrs = FmtAttributes::parse_attrs(&input.attrs, &attr_name)? - .map(Spanning::into_inner) + let attrs = FmtAttribute::parse_attrs(&input.attrs, &attr_name)? .unwrap_or_default(); let trait_ident = format_ident!("display"); let ident = &input.ident; @@ -92,8 +91,7 @@ fn expand_enum( let match_arms = e.variants.iter().try_fold( (Vec::new(), TokenStream::new()), |mut arms, variant| { - let attrs = ContainerAttributes::parse_attrs(&variant.attrs, attr_name)? - .map(Spanning::into_inner) + let attrs = FmtAttribute::parse_attrs(&variant.attrs, attr_name)? .unwrap_or_default(); let ident = &variant.ident; @@ -214,7 +212,7 @@ impl<'a> Expansion<'a> { format!( "TODO ARTHUR! struct or enum variant with more than 1 field must have \ `#[{}(\"...\", ...)]` attribute", - trait_name_to_attribute_name(self.trait_ident), + self.trait_ident, ), )), } diff --git a/tokenizers/display_derive/src/parsing.rs b/tokenizers/display_derive/src/parsing.rs index a74c3bc51..920037f65 100644 --- a/tokenizers/display_derive/src/parsing.rs +++ b/tokenizers/display_derive/src/parsing.rs @@ -4,7 +4,7 @@ use std::iter; -use unicode_xid::UnicodeXID as XID; +// use unicode_xid::UnicodeXID as XID; /// Output of the [`format_string`] parser. 
#[derive(Clone, Debug, Eq, PartialEq)] From 99cb0547bc967afec6c946edbfc2ac2efc17a522 Mon Sep 17 00:00:00 2001 From: Arthur Zucker Date: Wed, 12 Jun 2024 10:51:00 +0200 Subject: [PATCH 60/94] maybe work --- tokenizers/display_derive/Cargo.toml | 1 + tokenizers/display_derive/src/parsing.rs | 2 +- 2 files changed, 2 insertions(+), 1 deletion(-) diff --git a/tokenizers/display_derive/Cargo.toml b/tokenizers/display_derive/Cargo.toml index 7d7697910..289d54dd1 100644 --- a/tokenizers/display_derive/Cargo.toml +++ b/tokenizers/display_derive/Cargo.toml @@ -7,6 +7,7 @@ edition = "2021" syn = "1.0" quote = "1.0" proc-macro2 = "1.0" +unicode-xid = "0.2.4" [lib] proc-macro = true diff --git a/tokenizers/display_derive/src/parsing.rs b/tokenizers/display_derive/src/parsing.rs index 920037f65..a74c3bc51 100644 --- a/tokenizers/display_derive/src/parsing.rs +++ b/tokenizers/display_derive/src/parsing.rs @@ -4,7 +4,7 @@ use std::iter; -// use unicode_xid::UnicodeXID as XID; +use unicode_xid::UnicodeXID as XID; /// Output of the [`format_string`] parser. #[derive(Clone, Debug, Eq, PartialEq)] From 5c930e9a8c52815c6bbd01f58e4548ea9f983451 Mon Sep 17 00:00:00 2001 From: Arthur Zucker Date: Wed, 12 Jun 2024 10:54:06 +0200 Subject: [PATCH 61/94] remove tests from vendored parsing --- tokenizers/display_derive/src/parsing.rs | 594 ----------------------- 1 file changed, 594 deletions(-) diff --git a/tokenizers/display_derive/src/parsing.rs b/tokenizers/display_derive/src/parsing.rs index a74c3bc51..f3fbca198 100644 --- a/tokenizers/display_derive/src/parsing.rs +++ b/tokenizers/display_derive/src/parsing.rs @@ -707,597 +707,3 @@ fn any_char(input: &str) -> Option> { fn take_any_char(input: &str) -> Option<(LeftToParse<'_>, char)> { input.chars().next().map(|c| (&input[c.len_utf8()..], c)) } - -#[cfg(test)] -mod tests { - use super::*; - - #[test] - fn text() { - assert_eq!(format_string(""), Some(FormatString { formats: vec![] })); - assert_eq!( - format_string("test"), - Some(FormatString { formats: vec![] }), - ); - assert_eq!( - format_string("Минск"), - Some(FormatString { formats: vec![] }), - ); - assert_eq!(format_string("🦀"), Some(FormatString { formats: vec![] })); - } - - #[test] - fn argument() { - assert_eq!( - format_string("{}"), - Some(FormatString { - formats: vec![Format { - arg: None, - spec: None, - }], - }), - ); - assert_eq!( - format_string("{0}"), - Some(FormatString { - formats: vec![Format { - arg: Some(Argument::Integer(0)), - spec: None, - }], - }), - ); - assert_eq!( - format_string("{par}"), - Some(FormatString { - formats: vec![Format { - arg: Some(Argument::Identifier("par")), - spec: None, - }], - }), - ); - assert_eq!( - format_string("{Минск}"), - Some(FormatString { - formats: vec![Format { - arg: Some(Argument::Identifier("Минск")), - spec: None, - }], - }), - ); - } - - #[test] - fn spec() { - assert_eq!( - format_string("{:}"), - Some(FormatString { - formats: vec![Format { - arg: None, - spec: Some(FormatSpec { - align: None, - sign: None, - alternate: None, - zero_padding: None, - width: None, - precision: None, - ty: Type::Display, - }), - }], - }), - ); - assert_eq!( - format_string("{:^}"), - Some(FormatString { - formats: vec![Format { - arg: None, - spec: Some(FormatSpec { - align: Some((None, Align::Center)), - sign: None, - alternate: None, - zero_padding: None, - width: None, - precision: None, - ty: Type::Display, - }), - }], - }), - ); - assert_eq!( - format_string("{:-<}"), - Some(FormatString { - formats: vec![Format { - arg: None, - spec: Some(FormatSpec { 
- align: Some((Some('-'), Align::Left)), - sign: None, - alternate: None, - zero_padding: None, - width: None, - precision: None, - ty: Type::Display, - }), - }], - }), - ); - assert_eq!( - format_string("{: <}"), - Some(FormatString { - formats: vec![Format { - arg: None, - spec: Some(FormatSpec { - align: Some((Some(' '), Align::Left)), - sign: None, - alternate: None, - zero_padding: None, - width: None, - precision: None, - ty: Type::Display, - }), - }], - }), - ); - assert_eq!( - format_string("{:^<}"), - Some(FormatString { - formats: vec![Format { - arg: None, - spec: Some(FormatSpec { - align: Some((Some('^'), Align::Left)), - sign: None, - alternate: None, - zero_padding: None, - width: None, - precision: None, - ty: Type::Display, - }), - }], - }), - ); - assert_eq!( - format_string("{:+}"), - Some(FormatString { - formats: vec![Format { - arg: None, - spec: Some(FormatSpec { - align: None, - sign: Some(Sign::Plus), - alternate: None, - zero_padding: None, - width: None, - precision: None, - ty: Type::Display, - }), - }], - }), - ); - assert_eq!( - format_string("{:^<-}"), - Some(FormatString { - formats: vec![Format { - arg: None, - spec: Some(FormatSpec { - align: Some((Some('^'), Align::Left)), - sign: Some(Sign::Minus), - alternate: None, - zero_padding: None, - width: None, - precision: None, - ty: Type::Display, - }), - }], - }), - ); - assert_eq!( - format_string("{:#}"), - Some(FormatString { - formats: vec![Format { - arg: None, - spec: Some(FormatSpec { - align: None, - sign: None, - alternate: Some(Alternate), - zero_padding: None, - width: None, - precision: None, - ty: Type::Display, - }), - }], - }), - ); - assert_eq!( - format_string("{:+#}"), - Some(FormatString { - formats: vec![Format { - arg: None, - spec: Some(FormatSpec { - align: None, - sign: Some(Sign::Plus), - alternate: Some(Alternate), - zero_padding: None, - width: None, - precision: None, - ty: Type::Display, - }), - }], - }), - ); - assert_eq!( - format_string("{:-<#}"), - Some(FormatString { - formats: vec![Format { - arg: None, - spec: Some(FormatSpec { - align: Some((Some('-'), Align::Left)), - sign: None, - alternate: Some(Alternate), - zero_padding: None, - width: None, - precision: None, - ty: Type::Display, - }), - }], - }), - ); - assert_eq!( - format_string("{:^<-#}"), - Some(FormatString { - formats: vec![Format { - arg: None, - spec: Some(FormatSpec { - align: Some((Some('^'), Align::Left)), - sign: Some(Sign::Minus), - alternate: Some(Alternate), - zero_padding: None, - width: None, - precision: None, - ty: Type::Display, - }), - }], - }), - ); - assert_eq!( - format_string("{:0}"), - Some(FormatString { - formats: vec![Format { - arg: None, - spec: Some(FormatSpec { - align: None, - sign: None, - alternate: None, - zero_padding: Some(ZeroPadding), - width: None, - precision: None, - ty: Type::Display, - }), - }], - }), - ); - assert_eq!( - format_string("{:#0}"), - Some(FormatString { - formats: vec![Format { - arg: None, - spec: Some(FormatSpec { - align: None, - sign: None, - alternate: Some(Alternate), - zero_padding: Some(ZeroPadding), - width: None, - precision: None, - ty: Type::Display, - }), - }], - }), - ); - assert_eq!( - format_string("{:-0}"), - Some(FormatString { - formats: vec![Format { - arg: None, - spec: Some(FormatSpec { - align: None, - sign: Some(Sign::Minus), - alternate: None, - zero_padding: Some(ZeroPadding), - width: None, - precision: None, - ty: Type::Display, - }), - }], - }), - ); - assert_eq!( - format_string("{:^<0}"), - Some(FormatString { - formats: 
vec![Format { - arg: None, - spec: Some(FormatSpec { - align: Some((Some('^'), Align::Left)), - sign: None, - alternate: None, - zero_padding: Some(ZeroPadding), - width: None, - precision: None, - ty: Type::Display, - }), - }], - }), - ); - assert_eq!( - format_string("{:^<+#0}"), - Some(FormatString { - formats: vec![Format { - arg: None, - spec: Some(FormatSpec { - align: Some((Some('^'), Align::Left)), - sign: Some(Sign::Plus), - alternate: Some(Alternate), - zero_padding: Some(ZeroPadding), - width: None, - precision: None, - ty: Type::Display, - }), - }], - }), - ); - assert_eq!( - format_string("{:1}"), - Some(FormatString { - formats: vec![Format { - arg: None, - spec: Some(FormatSpec { - align: None, - sign: None, - alternate: None, - zero_padding: None, - width: Some(Count::Integer(1)), - precision: None, - ty: Type::Display, - }), - }], - }), - ); - assert_eq!( - format_string("{:1$}"), - Some(FormatString { - formats: vec![Format { - arg: None, - spec: Some(FormatSpec { - align: None, - sign: None, - alternate: None, - zero_padding: None, - width: Some(Count::Parameter(Argument::Integer(1))), - precision: None, - ty: Type::Display, - }), - }], - }), - ); - assert_eq!( - format_string("{:par$}"), - Some(FormatString { - formats: vec![Format { - arg: None, - spec: Some(FormatSpec { - align: None, - sign: None, - alternate: None, - zero_padding: None, - width: Some(Count::Parameter(Argument::Identifier("par"))), - precision: None, - ty: Type::Display, - }), - }], - }), - ); - assert_eq!( - format_string("{:-^-#0Минск$}"), - Some(FormatString { - formats: vec![Format { - arg: None, - spec: Some(FormatSpec { - align: Some((Some('-'), Align::Center)), - sign: Some(Sign::Minus), - alternate: Some(Alternate), - zero_padding: Some(ZeroPadding), - width: Some(Count::Parameter(Argument::Identifier("Минск"))), - precision: None, - ty: Type::Display, - }), - }], - }), - ); - assert_eq!( - format_string("{:.*}"), - Some(FormatString { - formats: vec![Format { - arg: None, - spec: Some(FormatSpec { - align: None, - sign: None, - alternate: None, - zero_padding: None, - width: None, - precision: Some(Precision::Star), - ty: Type::Display, - }), - }], - }), - ); - assert_eq!( - format_string("{:.0}"), - Some(FormatString { - formats: vec![Format { - arg: None, - spec: Some(FormatSpec { - align: None, - sign: None, - alternate: None, - zero_padding: None, - width: None, - precision: Some(Precision::Count(Count::Integer(0))), - ty: Type::Display, - }), - }], - }), - ); - assert_eq!( - format_string("{:.0$}"), - Some(FormatString { - formats: vec![Format { - arg: None, - spec: Some(FormatSpec { - align: None, - sign: None, - alternate: None, - zero_padding: None, - width: None, - precision: Some(Precision::Count(Count::Parameter(Argument::Integer(0),))), - ty: Type::Display, - }), - }], - }), - ); - assert_eq!( - format_string("{:.par$}"), - Some(FormatString { - formats: vec![Format { - arg: None, - spec: Some(FormatSpec { - align: None, - sign: None, - alternate: None, - zero_padding: None, - width: None, - precision: Some(Precision::Count(Count::Parameter(Argument::Identifier( - "par" - ),))), - ty: Type::Display, - }), - }], - }), - ); - assert_eq!( - format_string("{: >+#2$.par$}"), - Some(FormatString { - formats: vec![Format { - arg: None, - spec: Some(FormatSpec { - align: Some((Some(' '), Align::Right)), - sign: Some(Sign::Plus), - alternate: Some(Alternate), - zero_padding: None, - width: Some(Count::Parameter(Argument::Integer(2))), - precision: 
Some(Precision::Count(Count::Parameter(Argument::Identifier( - "par" - ),))), - ty: Type::Display, - }), - }], - }), - ); - assert_eq!( - format_string("{:x?}"), - Some(FormatString { - formats: vec![Format { - arg: None, - spec: Some(FormatSpec { - align: None, - sign: None, - alternate: None, - zero_padding: None, - width: None, - precision: None, - ty: Type::LowerDebug, - }), - }], - }), - ); - assert_eq!( - format_string("{:E}"), - Some(FormatString { - formats: vec![Format { - arg: None, - spec: Some(FormatSpec { - align: None, - sign: None, - alternate: None, - zero_padding: None, - width: None, - precision: None, - ty: Type::UpperExp, - }), - }], - }), - ); - assert_eq!( - format_string("{: >+#par$.par$X?}"), - Some(FormatString { - formats: vec![Format { - arg: None, - spec: Some(FormatSpec { - align: Some((Some(' '), Align::Right)), - sign: Some(Sign::Plus), - alternate: Some(Alternate), - zero_padding: None, - width: Some(Count::Parameter(Argument::Identifier("par"))), - precision: Some(Precision::Count(Count::Parameter(Argument::Identifier( - "par" - ),))), - ty: Type::UpperDebug, - }), - }], - }), - ); - } - - #[test] - fn full() { - assert_eq!( - format_string("prefix{{{0:#?}postfix{par:-^par$.a$}}}"), - Some(FormatString { - formats: vec![ - Format { - arg: Some(Argument::Integer(0)), - spec: Some(FormatSpec { - align: None, - sign: None, - alternate: Some(Alternate), - zero_padding: None, - width: None, - precision: None, - ty: Type::Debug, - }), - }, - Format { - arg: Some(Argument::Identifier("par")), - spec: Some(FormatSpec { - align: Some((Some('-'), Align::Center)), - sign: None, - alternate: None, - zero_padding: None, - width: Some(Count::Parameter(Argument::Identifier("par"))), - precision: Some(Precision::Count(Count::Parameter( - Argument::Identifier("a"), - ))), - ty: Type::Display, - }), - }, - ], - }), - ); - } - - #[test] - fn error() { - assert_eq!(format_string("{"), None); - assert_eq!(format_string("}"), None); - assert_eq!(format_string("{{}"), None); - assert_eq!(format_string("{:x?"), None); - assert_eq!(format_string("{:.}"), None); - assert_eq!(format_string("{:q}"), None); - assert_eq!(format_string("{:par}"), None); - assert_eq!(format_string("{⚙️}"), None); - } -} From f87bb97e0e94c9fb5548cdb4c597f23a58dbd25d Mon Sep 17 00:00:00 2001 From: Arthur Zucker Date: Wed, 12 Jun 2024 11:53:32 +0200 Subject: [PATCH 62/94] update --- tokenizers/display_derive/src/lib.rs | 25 ++++++++++++----------- tokenizers/display_derive/src/vendored.rs | 4 ---- 2 files changed, 13 insertions(+), 16 deletions(-) diff --git a/tokenizers/display_derive/src/lib.rs b/tokenizers/display_derive/src/lib.rs index 4850588b1..20aaceb08 100644 --- a/tokenizers/display_derive/src/lib.rs +++ b/tokenizers/display_derive/src/lib.rs @@ -8,29 +8,30 @@ use vendored::FmtAttribute; #[proc_macro_derive(Display)] pub fn display_derive(input: TokenStream) -> TokenStream { - // Parse the input tokens into a syntax tree - let input = parse_macro_input!(input as DeriveInput); - return ; + // Parse the parsed_input tokens into a syntax tree + let parsed_input = parse_macro_input!(input as DeriveInput); let attr_name = "display"; - let attrs = FmtAttribute::parse_attrs(&input.attrs, &attr_name)? 
- .unwrap_or_default(); + let attrs = syn::parse::(input).unwrap(); let trait_ident = format_ident!("display"); - let ident = &input.ident; + let ident = &parsed_input.ident; - let ctx = (&attrs, ident, &trait_ident, &attr_name); - let body = match &input.data { + let ctx = (&attrs, ident, &trait_ident, &trait_ident); + let body = match &parsed_input.data { syn::Data::Struct(s) => expand_struct(s, ctx), syn::Data::Enum(e) => expand_enum(e, ctx), - syn::Data::Union(u) => return Err(syn::Error::new(u, format!("Union is not supported"))), - }?; + syn::Data::Union(u) => { + let error = syn::Error::new_spanned(u.union_token, "Unions are not supported"); + return proc_macro::TokenStream::from(error.into_compile_error()); + } + }; - Ok(quote! { + quote! { impl std::fmt::Display for #ident{ fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result { #body } } - }) + }.into() } /// Type alias for an expansion context: diff --git a/tokenizers/display_derive/src/vendored.rs b/tokenizers/display_derive/src/vendored.rs index 9799cbf81..a1cc967f8 100644 --- a/tokenizers/display_derive/src/vendored.rs +++ b/tokenizers/display_derive/src/vendored.rs @@ -15,7 +15,6 @@ use syn::{ /// ``` /// /// [`fmt`]: std::fmt -#[derive(Debug)] pub struct FmtAttribute { /// Interpolation [`syn::LitStr`]. /// @@ -33,7 +32,6 @@ pub struct FmtAttribute { impl Parse for FmtAttribute { fn parse(input: ParseStream<'_>) -> syn::Result { - Self::check_legacy_fmt(input)?; Ok(Self { lit: input.parse()?, @@ -46,7 +44,6 @@ impl Parse for FmtAttribute { } } -impl attr::ParseMultiple for FmtAttribute {} impl ToTokens for FmtAttribute { fn to_tokens(&self, tokens: &mut TokenStream) { @@ -122,7 +119,6 @@ impl FmtAttribute { /// in a [`FmtAttribute`]. /// /// [1]: https://doc.rust-lang.org/stable/std/fmt/index.html#named-parameters -#[derive(Debug)] struct FmtArgument { /// `identifier =` [`Ident`]. /// From c4b4f3cf418ad02c4bb858048fdd2d8177d402c8 Mon Sep 17 00:00:00 2001 From: Arthur Zucker Date: Wed, 12 Jun 2024 12:33:11 +0200 Subject: [PATCH 63/94] simplify some stuff --- tokenizers/display_derive/src/lib.rs | 240 +++++++-------------------- 1 file changed, 58 insertions(+), 182 deletions(-) diff --git a/tokenizers/display_derive/src/lib.rs b/tokenizers/display_derive/src/lib.rs index 20aaceb08..49d2fec10 100644 --- a/tokenizers/display_derive/src/lib.rs +++ b/tokenizers/display_derive/src/lib.rs @@ -10,213 +10,89 @@ use vendored::FmtAttribute; pub fn display_derive(input: TokenStream) -> TokenStream { // Parse the parsed_input tokens into a syntax tree let parsed_input = parse_macro_input!(input as DeriveInput); - let attr_name = "display"; let attrs = syn::parse::(input).unwrap(); + // 1. If the attrs are not None, then we defer to this. + // Meaning we juste return quote!{ format!(#fmt, #attr)} let trait_ident = format_ident!("display"); let ident = &parsed_input.ident; - let ctx = (&attrs, ident, &trait_ident, &trait_ident); + // 2. We automatically parse let body = match &parsed_input.data { - syn::Data::Struct(s) => expand_struct(s, ctx), - syn::Data::Enum(e) => expand_enum(e, ctx), + syn::Data::Struct(s) => generate_fmt_impl_for_struct(s, ident), + syn::Data::Enum(e) => generate_fmt_impl_for_enum(e, ident), syn::Data::Union(u) => { let error = syn::Error::new_spanned(u.union_token, "Unions are not supported"); return proc_macro::TokenStream::from(error.into_compile_error()); } }; - quote! { - impl std::fmt::Display for #ident{ + let expanded = quote! 
{ + impl std::fmt::Display for #ident { fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result { #body } } - }.into() -} - -/// Type alias for an expansion context: -/// - [`FmtAttribute`]. -/// - Struct/enum/union [`syn::Ident`]. -/// - Derived trait [`syn::Ident`]. -/// - Attribute name [`syn::Ident`]. -/// -/// [`syn::Ident`]: struct@syn::Ident -type ExpansionCtx<'a> = ( - &'a FmtAttribute, - &'a syn::Ident, - &'a syn::Ident, - &'a syn::Ident, -); - -/// Expands a [`fmt::Display`]-like derive macro for the provided struct. -fn expand_struct( - s: &syn::DataStruct, - (attrs, ident, trait_ident, _): ExpansionCtx<'_>, -) -> syn::Result<(Vec, TokenStream)> { - let s = Expansion { - attrs, - fields: &s.fields, - trait_ident, - ident, }; - let body = s.generate_body()?; - let vars = s.fields.iter().enumerate().map(|(i, f)| { - let var = f.ident.clone().unwrap_or_else(|| format_ident!("_{i}")); - let member = f - .ident - .clone() - .map_or_else(|| syn::Member::Unnamed(i.into()), syn::Member::Named); + TokenStream::from(expanded) +} + +fn generate_fmt_impl_for_struct(data_struct: &syn::DataStruct, ident: &syn::Ident) -> TokenStream { + let fields = &data_struct.fields; + let field_fmts = fields.iter().enumerate().map(|(i, field)| { + let field_name = match &field.ident { + Some(ident) => ident, + None => { + // If the field doesn't have a name, we generate a name based on its index + let index = syn::Index::from(i); + quote! { #index } + } + }; quote! { - let #var = &self.#member; + write!(f, "{}: {}", stringify!(#field_name), self.#field_name)?; } }); - - let body = quote! { - #( #vars )* - #body - }; - - Ok(body) + // Collect the mapped tokens into a TokenStream + field_fmts } -/// Expands a [`fmt`]-like derive macro for the provided enum. -fn expand_enum( - e: &syn::DataEnum, - (attrs, _, trait_ident, attr_name): ExpansionCtx<'_>, -) -> syn::Result<(Vec, TokenStream)> { - if attrs.fmt.is_some() { - todo!("https://github.com/JelteF/derive_more/issues/142"); - } - - let match_arms = e.variants.iter().try_fold( - (Vec::new(), TokenStream::new()), - |mut arms, variant| { - let attrs = FmtAttribute::parse_attrs(&variant.attrs, attr_name)? - .unwrap_or_default(); - let ident = &variant.ident; - - if attrs.fmt.is_none() - && variant.fields.is_empty() - && attr_name != "display" - { - return Err(syn::Error::new( - e.variants.span(), - format!( - "implicit formatting of unit enum variant is supported only for `Display` \ - macro, use `#[{attr_name}(\"...\")]` to explicitly specify the formatting", - ), - )); +fn generate_fmt_impl_for_enum(data_enum: &syn::DataEnum, ident: &syn::Ident) -> TokenStream { + let arms = data_enum.variants.iter().map(|variant| { + let variant_name = &variant.ident; + let variant_fmt = match &variant.fields { + syn::Fields::Unit => { + // If the variant has no fields, we just print its name + quote! { write!(f, "{}", stringify!(#variant_name))?; } } - - let v = Expansion { - attrs: &attrs, - fields: &variant.fields, - trait_ident, - ident, - }; - let arm_body = v.generate_body()?; - - let fields_idents = - variant.fields.iter().enumerate().map(|(i, f)| { - f.ident.clone().unwrap_or_else(|| format_ident!("_{i}")) + syn::Fields::Named(fields) => { + // If the variant has named fields, we print each field's name and value + let field_fmts = fields.named.iter().map(|field| { + let field_name = field.ident.as_ref().unwrap(); + quote! 
{ + write!(f, "{}: {:?}", stringify!(#field_name), self.#field_name)?; + } }); - let matcher = match variant.fields { - syn::Fields::Named(_) => { - quote! { Self::#ident { #( #fields_idents ),* } } + quote! { + write!(f, "{} {{ ", stringify!(#variant_name))?; + #( #field_fmts )* + write!(f, " }}")?; } - syn::Fields::Unnamed(_) => { - quote! { Self::#ident ( #( #fields_idents ),* ) } - } - syn::Fields::Unit => quote! { Self::#ident }, - }; - - arms.extend([quote! { #matcher => { #arm_body }, }]); - - Ok::<_, syn::Error>(arms) - }, - )?; - - let body = match_arms - .is_empty() - .then(|| quote! { match *self {} }) - .unwrap_or_else(|| quote! { match self { #match_arms } }); - - Ok(body) -} - - -/// Helper struct to generate [`Display::fmt()`] implementation body and trait -/// bounds for a struct or an enum variant. -/// -/// [`Display::fmt()`]: fmt::Display::fmt() -#[derive(Debug)] -struct Expansion<'a> { - /// Derive macro [`FmtAttribute`]. - attrs: &'a FmtAttribute, - - /// Struct or enum [`syn::Ident`]. - /// - /// [`syn::Ident`]: struct@syn::Ident - ident: &'a syn::Ident, - - /// Struct or enum [`syn::Fields`]. - fields: &'a syn::Fields, - - /// [`fmt`] trait [`syn::Ident`]. - /// - /// [`syn::Ident`]: struct@syn::Ident - trait_ident: &'a syn::Ident, -} - -impl<'a> Expansion<'a> { - /// Generates [`Display::fmt()`] implementation for a struct or an enum variant. - /// - /// # Errors - /// - /// In case [`FmtAttribute`] is [`None`] and [`syn::Fields`] length is - /// greater than 1. - /// - /// [`Display::fmt()`]: fmt::Display::fmt() - /// [`FmtAttribute`]: super::FmtAttribute - fn generate_body(&self) -> syn::Result { - match &self.attrs.fmt { - Some(fmt) => { - Ok(if let Some((expr, trait_ident)) = fmt.transparent_call() { - quote! { core::fmt::#trait_ident::fmt(&(#expr), __derive_more_f) } - } else { - quote! { core::write!(__derive_more_f, #fmt) } - }) } - None if self.fields.is_empty() => { - let ident_str = self.ident.to_string(); - - Ok(quote! { - core::write!(__derive_more_f, #ident_str) - }) - } - None if self.fields.len() == 1 => { - let field = self - .fields - .iter() - .next() - .unwrap_or_else(|| unreachable!("count() == 1")); - let ident = field.ident.clone().unwrap_or_else(|| format_ident!("_0")); - let trait_ident = self.trait_ident; - - Ok(quote! { - core::fmt::#trait_ident::fmt(#ident, __derive_more_f) - }) + syn::Fields::Unnamed(fields) => { + // If the variant has unnamed fields, we print each field's value without names + let field_fmts = fields.unnamed.iter().map(|field| { + quote! { + write!(f, "{:?}, ", self.#field)?; + } + }); + quote! { + write!(f, "{}(", stringify!(#variant_name))?; + #( #field_fmts )* + write!(f, ")")?; + } } - _ => Err(syn::Error::new( - self.fields.span(), - format!( - "TODO ARTHUR! struct or enum variant with more than 1 field must have \ - `#[{}(\"...\", ...)]` attribute", - self.trait_ident, - ), - )), - } - } + }; + quote! 
{ #ident::#variant_name => { #variant_fmt } } + }); + arms } - From e712079b07942c5cecc6b3fbac711d9fc17acc50 Mon Sep 17 00:00:00 2001 From: Arthur Zucker Date: Thu, 13 Jun 2024 17:50:26 +0200 Subject: [PATCH 64/94] current status, not bad but not soooooo good --- tokenizers/display_derive/src/lib.rs | 108 +++++++++++++--------- tokenizers/display_derive/src/vendored.rs | 6 +- tokenizers/src/decoders/byte_fallback.rs | 2 +- tokenizers/src/tokenizer/mod.rs | 6 +- 4 files changed, 73 insertions(+), 49 deletions(-) diff --git a/tokenizers/display_derive/src/lib.rs b/tokenizers/display_derive/src/lib.rs index 49d2fec10..ee72d0e21 100644 --- a/tokenizers/display_derive/src/lib.rs +++ b/tokenizers/display_derive/src/lib.rs @@ -1,31 +1,42 @@ extern crate proc_macro; use proc_macro::TokenStream; use quote::{format_ident,quote}; -use syn::{parse_macro_input, DeriveInput}; +use syn::{parse_macro_input, DeriveInput, Meta, MetaNameValue, Lit}; mod vendored; mod parsing; use vendored::FmtAttribute; -#[proc_macro_derive(Display)] +#[proc_macro_derive(Display, attributes(display))] pub fn display_derive(input: TokenStream) -> TokenStream { // Parse the parsed_input tokens into a syntax tree let parsed_input = parse_macro_input!(input as DeriveInput); - let attrs = syn::parse::(input).unwrap(); + // let attrs = syn::parse::(input).unwrap(); + let mut fmt = quote!{}; + for attr in parsed_input.attrs{ + if attr.path.is_ident("display"){ + println!("attrs: {:?}", attr.path.get_ident()); + fmt = quote!{ write!(f, "display(fmt = '', ...) is not supported yet!")}; + } + } + // 1. If the attrs are not None, then we defer to this. // Meaning we juste return quote!{ format!(#fmt, #attr)} let trait_ident = format_ident!("display"); let ident = &parsed_input.ident; - + + let body = if fmt.is_empty() { // 2. We automatically parse - let body = match &parsed_input.data { - syn::Data::Struct(s) => generate_fmt_impl_for_struct(s, ident), - syn::Data::Enum(e) => generate_fmt_impl_for_enum(e, ident), - syn::Data::Union(u) => { - let error = syn::Error::new_spanned(u.union_token, "Unions are not supported"); - return proc_macro::TokenStream::from(error.into_compile_error()); + match &parsed_input.data { + syn::Data::Struct(s) => generate_fmt_impl_for_struct(s, ident), + syn::Data::Enum(e) => generate_fmt_impl_for_enum(e, ident), + syn::Data::Union(u) => { + let error = syn::Error::new_spanned(u.union_token, "Unions are not supported"); + return proc_macro::TokenStream::from(error.into_compile_error()); + } } + } else { + fmt }; - let expanded = quote! { impl std::fmt::Display for #ident { fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result { @@ -37,26 +48,45 @@ pub fn display_derive(input: TokenStream) -> TokenStream { TokenStream::from(expanded) } -fn generate_fmt_impl_for_struct(data_struct: &syn::DataStruct, ident: &syn::Ident) -> TokenStream { +fn generate_fmt_impl_for_struct(data_struct: &syn::DataStruct, ident: &syn::Ident) -> proc_macro2::TokenStream { + + // return quote!{ + // write!(f, "automatic print") + // }; let fields = &data_struct.fields; - let field_fmts = fields.iter().enumerate().map(|(i, field)| { - let field_name = match &field.ident { - Some(ident) => ident, - None => { - // If the field doesn't have a name, we generate a name based on its index - let index = syn::Index::from(i); - quote! 
{ #index } + + // Extract field names and types + let field_names: Vec<_> = fields.iter().map(|f| f.ident.as_ref().unwrap()).collect(); + let field_types: Vec<_> = fields.iter().map(|f| &f.ty).collect(); + + quote! { + write!(f, "{}(", stringify!(#ident))?; + let mut first = true; + #( + if !first { + write!(f, ", ")?; } - }; - quote! { - write!(f, "{}: {}", stringify!(#field_name), self.#field_name)?; - } - }); - // Collect the mapped tokens into a TokenStream - field_fmts + first = false; + + let field_value = &self.#field_names; + write!(f, "{}=", stringify!(#field_names))?; + if std::any::TypeId::of::<#field_types>() == std::any::TypeId::of::() { + write!(f, "\"{}\"", field_value)?; + } else { + let s = format!("{}", field_value); + let mut chars = s.chars(); + let mut prefix = (&mut chars).take(100 - 1).collect::(); + if chars.next().is_some() { + prefix.push('…'); + } + write!(f, "{}", prefix)?; + } + )* + write!(f, ")") + } } -fn generate_fmt_impl_for_enum(data_enum: &syn::DataEnum, ident: &syn::Ident) -> TokenStream { +fn generate_fmt_impl_for_enum(data_enum: &syn::DataEnum, ident: &syn::Ident) -> proc_macro2::TokenStream { let arms = data_enum.variants.iter().map(|variant| { let variant_name = &variant.ident; let variant_fmt = match &variant.fields { @@ -69,7 +99,7 @@ fn generate_fmt_impl_for_enum(data_enum: &syn::DataEnum, ident: &syn::Ident) -> let field_fmts = fields.named.iter().map(|field| { let field_name = field.ident.as_ref().unwrap(); quote! { - write!(f, "{}: {:?}", stringify!(#field_name), self.#field_name)?; + write!(f, "{}: {}", stringify!(#field_name), self.#field_name)?; } }); quote! { @@ -78,21 +108,13 @@ fn generate_fmt_impl_for_enum(data_enum: &syn::DataEnum, ident: &syn::Ident) -> write!(f, " }}")?; } } - syn::Fields::Unnamed(fields) => { - // If the variant has unnamed fields, we print each field's value without names - let field_fmts = fields.unnamed.iter().map(|field| { - quote! { - write!(f, "{:?}, ", self.#field)?; - } - }); - quote! { - write!(f, "{}(", stringify!(#variant_name))?; - #( #field_fmts )* - write!(f, ")")?; - } - } + syn::Fields::Unnamed(_) => quote! { write!(f, "__UNAMED__")} }; - quote! { #ident::#variant_name => { #variant_fmt } } + quote! { #variant_name => {#variant_fmt} } }); - arms + quote! { + match *self { + #(#arms),* + } + } } diff --git a/tokenizers/display_derive/src/vendored.rs b/tokenizers/display_derive/src/vendored.rs index a1cc967f8..231858ed0 100644 --- a/tokenizers/display_derive/src/vendored.rs +++ b/tokenizers/display_derive/src/vendored.rs @@ -93,9 +93,9 @@ impl FmtAttribute { .map(|a| a.expr.clone()), // (4) Or the formatting parameter's name refers to some outer binding. - Some(parsing::Argument::Identifier(name)) if self.args.is_empty() => { - Some(format_ident!("{name}").into()) - } + // Some(parsing::Argument::Identifier(name)) if self.args.is_empty() => { + // Some(format_ident!("{trait_name}").into()) + // } // (5) Or exactly one named argument is specified for the formatting parameter's name. 
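// --- Editorial sketch, not part of the patch ----------------------------------
// A standalone rendering of the per-field value formatting that the new
// `generate_fmt_impl_for_struct` body above generates: String fields are
// printed quoted, every other field is Display-formatted and cut off at 100
// characters, with a trailing '…' when the value was longer.
fn render_field_value(value: &str, is_string: bool) -> String {
    if is_string {
        // The generated code quotes values whose TypeId matches String.
        format!("\"{}\"", value)
    } else {
        let mut chars = value.chars();
        // Keep at most 99 characters, then append '…' if anything is left over.
        let mut prefix: String = (&mut chars).take(100 - 1).collect();
        if chars.next().is_some() {
            prefix.push('…');
        }
        prefix
    }
}
// e.g. render_field_value(&"a".repeat(150), false) gives 99 'a's followed by '…'.
// -------------------------------------------------------------------------------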
Some(parsing::Argument::Identifier(name)) => (self.args.len() == 1) diff --git a/tokenizers/src/decoders/byte_fallback.rs b/tokenizers/src/decoders/byte_fallback.rs index 69817c1b4..9c02acf2a 100644 --- a/tokenizers/src/decoders/byte_fallback.rs +++ b/tokenizers/src/decoders/byte_fallback.rs @@ -2,7 +2,7 @@ use crate::tokenizer::{Decoder, Result}; use monostate::MustBe; use display_derive::Display; use serde::{Deserialize, Serialize}; -#[derive(Deserialize, Clone Debug, Serialize, Default, Display)] +#[derive(Deserialize, Clone, Debug, Serialize, Default, Display)] /// ByteFallback is a simple trick which converts tokens looking like `<0x61>` /// to pure bytes, and attempts to make them into a string. If the tokens /// cannot be decoded you will get � instead for each inconvertable byte token diff --git a/tokenizers/src/tokenizer/mod.rs b/tokenizers/src/tokenizer/mod.rs index 35ac22971..d471755f5 100644 --- a/tokenizers/src/tokenizer/mod.rs +++ b/tokenizers/src/tokenizer/mod.rs @@ -17,7 +17,8 @@ use std::{ ops::{Deref, DerefMut}, path::{Path, PathBuf}, }; - +extern crate rayon; +use rayon::current_thread_index; use crate::utils::iter::ResultShunt; use crate::utils::parallelism::*; use crate::utils::progress::{ProgressBar, ProgressStyle}; @@ -838,7 +839,7 @@ where EncodeInput::Single(s1) => (s1, None), EncodeInput::Dual(s1, s2) => (s1, Some(s2)), }; - + println!("thread id: {:?}", current_thread_index()); // Encode each sequence let encoding = self.encode_single_sequence(sequence, 0, OffsetType::Byte)?; let pair_encoding = pair @@ -939,6 +940,7 @@ where word_idx: Option, offsets_type: OffsetType, ) -> Result { + println!("do tokenizer {:?}", current_thread_index()); let mut pretokenized: PreTokenizedString = pretokenized.into(); pretokenized.tokenize(|normalized| self.model.tokenize(normalized.get()))?; pretokenized.into_encoding(word_idx, type_id, offsets_type) From 554013683a93e0aef907685273d48cdfd3394c13 Mon Sep 17 00:00:00 2001 From: Arthur Zucker Date: Fri, 14 Jun 2024 09:45:00 +0200 Subject: [PATCH 65/94] is this a good start? --- tokenizers/display_derive/src/lib.rs | 45 ++++++++++++++-------------- 1 file changed, 23 insertions(+), 22 deletions(-) diff --git a/tokenizers/display_derive/src/lib.rs b/tokenizers/display_derive/src/lib.rs index ee72d0e21..90474bbe7 100644 --- a/tokenizers/display_derive/src/lib.rs +++ b/tokenizers/display_derive/src/lib.rs @@ -37,6 +37,8 @@ pub fn display_derive(input: TokenStream) -> TokenStream { } else { fmt }; + + println!("body: {:?}", body.to_string()); let expanded = quote! { impl std::fmt::Display for #ident { fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result { @@ -89,32 +91,31 @@ fn generate_fmt_impl_for_struct(data_struct: &syn::DataStruct, ident: &syn::Iden fn generate_fmt_impl_for_enum(data_enum: &syn::DataEnum, ident: &syn::Ident) -> proc_macro2::TokenStream { let arms = data_enum.variants.iter().map(|variant| { let variant_name = &variant.ident; - let variant_fmt = match &variant.fields { + let formatted_output = match &variant.fields { syn::Fields::Unit => { - // If the variant has no fields, we just print its name + // Unit variant: just stringify the variant name + quote! { #ident::#variant_name => {write!(f, "{}", stringify!(#variant_name))?; }} + }, + syn::Fields::Unnamed(fields) if fields.unnamed.len() == 1 => { + // Tuple variant with one field + quote! 
{ #ident::#variant_name(ref single) => {write!(f, "{}", single)?;} } + }, + syn::Fields::Named(fields) if fields.named.len() == 1 => { + // Tuple variant with one named field + let field_name = fields.named[0].ident.as_ref().unwrap(); // Assuming it's named + quote! { #ident::#variant_name{..}=>{ write!(f, "{}({})", stringify!(self.#field_name)?);} } + }, + _ => { + // Default case: stringify the variant name quote! { write!(f, "{}", stringify!(#variant_name))?; } } - syn::Fields::Named(fields) => { - // If the variant has named fields, we print each field's name and value - let field_fmts = fields.named.iter().map(|field| { - let field_name = field.ident.as_ref().unwrap(); - quote! { - write!(f, "{}: {}", stringify!(#field_name), self.#field_name)?; - } - }); - quote! { - write!(f, "{} {{ ", stringify!(#variant_name))?; - #( #field_fmts )* - write!(f, " }}")?; - } - } - syn::Fields::Unnamed(_) => quote! { write!(f, "__UNAMED__")} }; - quote! { #variant_name => {#variant_fmt} } + formatted_output }); quote! { - match *self { - #(#arms),* - } - } + match *self { + #(#arms)* + } + Ok(()) + } } From ba03c166edb1a4c298578a374b50bbc78d485554 Mon Sep 17 00:00:00 2001 From: Arthur Zucker Date: Fri, 14 Jun 2024 20:35:33 +0200 Subject: [PATCH 66/94] small changes --- tokenizers/display_derive/src/vendored.rs | 13 +++++++------ tokenizers/src/decoders/byte_fallback.rs | 2 +- 2 files changed, 8 insertions(+), 7 deletions(-) diff --git a/tokenizers/display_derive/src/vendored.rs b/tokenizers/display_derive/src/vendored.rs index 231858ed0..eccc5263a 100644 --- a/tokenizers/display_derive/src/vendored.rs +++ b/tokenizers/display_derive/src/vendored.rs @@ -1,11 +1,12 @@ +use std::{env::VarError, error::Error}; + use crate::parsing; use proc_macro2::TokenStream; use quote::{format_ident, ToTokens}; use syn::{ parse::{Parse, ParseStream}, punctuated::Punctuated, - token, - Expr, + token, Expr, }; /// Representation of a [`fmt`]-like attribute. 
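// --- Editorial sketch, not part of the patch ----------------------------------
// A hand-written version of what the `generate_fmt_impl_for_enum` arms from the
// previous patch aim to expand to. `Example` is a made-up enum: a unit variant
// prints its own name, and a single-field tuple variant defers to that field's
// own Display implementation.
use std::fmt;

enum Example {
    Unit,
    Wrapped(String),
}

impl fmt::Display for Example {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        match *self {
            Example::Unit => write!(f, "{}", "Unit"),
            Example::Wrapped(ref single) => write!(f, "{}", single),
        }
    }
}
// `Example::Unit` displays as "Unit"; `Example::Wrapped("abc".into())` as "abc".
// -------------------------------------------------------------------------------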
@@ -32,19 +33,19 @@ pub struct FmtAttribute { impl Parse for FmtAttribute { fn parse(input: ParseStream<'_>) -> syn::Result { - - Ok(Self { + let attribute = Self { lit: input.parse()?, comma: input .peek(token::Comma) .then(|| input.parse()) .transpose()?, args: input.parse_terminated(FmtArgument::parse)?, - }) + }; + println!("Parsing FMTAttribute, {}, ",attribute.lit.token().to_string()); + Ok(attribute) } } - impl ToTokens for FmtAttribute { fn to_tokens(&self, tokens: &mut TokenStream) { self.lit.to_tokens(tokens); diff --git a/tokenizers/src/decoders/byte_fallback.rs b/tokenizers/src/decoders/byte_fallback.rs index 9c02acf2a..8c88199de 100644 --- a/tokenizers/src/decoders/byte_fallback.rs +++ b/tokenizers/src/decoders/byte_fallback.rs @@ -1,6 +1,6 @@ use crate::tokenizer::{Decoder, Result}; -use monostate::MustBe; use display_derive::Display; +use monostate::MustBe; use serde::{Deserialize, Serialize}; #[derive(Deserialize, Clone, Debug, Serialize, Default, Display)] /// ByteFallback is a simple trick which converts tokens looking like `<0x61>` From d0e741bb00a8df1c664eb717ae5f4b92a94ba579 Mon Sep 17 00:00:00 2001 From: Arthur Zucker Date: Fri, 14 Jun 2024 22:04:12 +0200 Subject: [PATCH 67/94] format does not work yet --- tokenizers/display_derive/src/lib.rs | 55 +++++++++++++++----------- tokenizers/src/tokenizer/mod.rs | 20 ++++++++-- tokenizers/src/tokenizer/normalizer.rs | 1 - 3 files changed, 48 insertions(+), 28 deletions(-) diff --git a/tokenizers/display_derive/src/lib.rs b/tokenizers/display_derive/src/lib.rs index 90474bbe7..7f82d9bb6 100644 --- a/tokenizers/display_derive/src/lib.rs +++ b/tokenizers/display_derive/src/lib.rs @@ -1,31 +1,35 @@ extern crate proc_macro; use proc_macro::TokenStream; -use quote::{format_ident,quote}; -use syn::{parse_macro_input, DeriveInput, Meta, MetaNameValue, Lit}; -mod vendored; +use quote::{format_ident, quote, ToTokens}; +use syn::{parse_macro_input, DeriveInput, Lit, Meta, MetaNameValue}; mod parsing; +mod vendored; use vendored::FmtAttribute; #[proc_macro_derive(Display, attributes(display))] -pub fn display_derive(input: TokenStream) -> TokenStream { +pub fn display_derive(input: TokenStream) -> TokenStream { // Parse the parsed_input tokens into a syntax tree + // let attrs = syn::parse::(input.clone()); + // // Handle the Result from the parsing step + // let attrs = match attrs { + // Ok(attrs) => attrs, + // Err(_) => return TokenStream::new(), // Handle error case appropriately + // }; let parsed_input = parse_macro_input!(input as DeriveInput); - // let attrs = syn::parse::(input).unwrap(); - let mut fmt = quote!{}; - for attr in parsed_input.attrs{ - if attr.path.is_ident("display"){ - println!("attrs: {:?}", attr.path.get_ident()); - fmt = quote!{ write!(f, "display(fmt = '', ...) is not supported yet!")}; + let mut fmt = quote! {}; + for attr in parsed_input.attrs { + if attr.path.is_ident("display") { + fmt = quote! { write!(f, "display(fmt = '', ...)")}; } } - - // 1. If the attrs are not None, then we defer to this. - // Meaning we juste return quote!{ format!(#fmt, #attr)} - let trait_ident = format_ident!("display"); + + // 1. If the attrs are not None, then we defer to this. + // Meaning we juste return quote!{ format!(#fmt, #attr)} + let trait_ident: syn::Ident = format_ident!("display"); let ident = &parsed_input.ident; - + let body = if fmt.is_empty() { - // 2. We automatically parse + // 2. 
We automatically parse match &parsed_input.data { syn::Data::Struct(s) => generate_fmt_impl_for_struct(s, ident), syn::Data::Enum(e) => generate_fmt_impl_for_enum(e, ident), @@ -50,13 +54,12 @@ pub fn display_derive(input: TokenStream) -> TokenStream { TokenStream::from(expanded) } -fn generate_fmt_impl_for_struct(data_struct: &syn::DataStruct, ident: &syn::Ident) -> proc_macro2::TokenStream { - - // return quote!{ - // write!(f, "automatic print") - // }; +fn generate_fmt_impl_for_struct( + data_struct: &syn::DataStruct, + ident: &syn::Ident, +) -> proc_macro2::TokenStream { let fields = &data_struct.fields; - + // Extract field names and types let field_names: Vec<_> = fields.iter().map(|f| f.ident.as_ref().unwrap()).collect(); let field_types: Vec<_> = fields.iter().map(|f| &f.ty).collect(); @@ -73,6 +76,7 @@ fn generate_fmt_impl_for_struct(data_struct: &syn::DataStruct, ident: &syn::Iden let field_value = &self.#field_names; write!(f, "{}=", stringify!(#field_names))?; if std::any::TypeId::of::<#field_types>() == std::any::TypeId::of::() { + println!("We have a string!"); write!(f, "\"{}\"", field_value)?; } else { let s = format!("{}", field_value); @@ -88,7 +92,10 @@ fn generate_fmt_impl_for_struct(data_struct: &syn::DataStruct, ident: &syn::Iden } } -fn generate_fmt_impl_for_enum(data_enum: &syn::DataEnum, ident: &syn::Ident) -> proc_macro2::TokenStream { +fn generate_fmt_impl_for_enum( + data_enum: &syn::DataEnum, + ident: &syn::Ident, +) -> proc_macro2::TokenStream { let arms = data_enum.variants.iter().map(|variant| { let variant_name = &variant.ident; let formatted_output = match &variant.fields { @@ -112,6 +119,8 @@ fn generate_fmt_impl_for_enum(data_enum: &syn::DataEnum, ident: &syn::Ident) -> }; formatted_output }); + + println!("printing ident: {}", ident.to_string()); quote! 
{ match *self { #(#arms)* diff --git a/tokenizers/src/tokenizer/mod.rs b/tokenizers/src/tokenizer/mod.rs index d471755f5..fd5631755 100644 --- a/tokenizers/src/tokenizer/mod.rs +++ b/tokenizers/src/tokenizer/mod.rs @@ -18,10 +18,10 @@ use std::{ path::{Path, PathBuf}, }; extern crate rayon; -use rayon::current_thread_index; use crate::utils::iter::ResultShunt; use crate::utils::parallelism::*; use crate::utils::progress::{ProgressBar, ProgressStyle}; +use rayon::current_thread_index; use serde::de::DeserializeOwned; use serde::{Deserialize, Serialize}; @@ -551,17 +551,17 @@ where None => "None".to_string(), }; let truncation_str = match &self.truncation { - Some(t) => format!("{}", t), + Some(t) => format!("{:?}", t), None => "None".to_string(), }; let padding_str = match &self.padding { - Some(p) => format!("{}", p), + Some(p) => format!("{:?}", p), None => "None".to_string(), }; write!( f, - "Tokenizer(normalizer={}, pre_tokenizer={}, model={}, post_processor={}, decoder={}, added_tokens_decoder={}, truncation={}, padding={})", + "Tokenizer(normalizer={}, pre_tokenizer={}, model={}, post_processor={}, decoder={}, added_tokens_decoder={:?}, truncation={}, padding={})", normalizer_str, pre_tokenizer_str, self.model, @@ -1355,3 +1355,15 @@ where Ok(()) } } + +#[cfg(test)] +mod tests { + use super::Tokenizer; + + #[cfg(feature = "http")] + #[test] + fn test_from_pretrained() { + let tok = Tokenizer::from_pretrained("Qwen/Qwen2-7B-Instruct".to_string(), None); + println!("ROCK!") + } +} diff --git a/tokenizers/src/tokenizer/normalizer.rs b/tokenizers/src/tokenizer/normalizer.rs index efd1b728a..ff44bfe56 100644 --- a/tokenizers/src/tokenizer/normalizer.rs +++ b/tokenizers/src/tokenizer/normalizer.rs @@ -79,7 +79,6 @@ where /// - MergedWithNext => `[ "the", "-final", "-", "-countdown" ]` /// - Contiguous => `[ "the", "-", "final", "--", "countdown" ]` #[derive(Debug, Clone, Copy, PartialEq, Serialize, Deserialize, Eq, Display)] -#[display(fmt = "{}")] pub enum SplitDelimiterBehavior { Removed, Isolated, From 19afb669deb812cf13825ed1170fed1e5cf2aeee Mon Sep 17 00:00:00 2001 From: Arthur Zucker Date: Sun, 16 Jun 2024 14:14:39 +0200 Subject: [PATCH 68/94] some cleanup of unnecessary things --- bindings/python/grep | 0 bindings/python/src/processors.rs | 2 +- tokenizers/display_derive/src/fmt_parsing.rs | 108 +++ tokenizers/display_derive/src/lib.rs | 38 +- tokenizers/display_derive/src/parsing.rs | 709 ------------------- tokenizers/display_derive/src/vendored.rs | 161 ----- tokenizers/src/pre_tokenizers/split.rs | 2 +- 7 files changed, 129 insertions(+), 891 deletions(-) create mode 100644 bindings/python/grep create mode 100644 tokenizers/display_derive/src/fmt_parsing.rs delete mode 100644 tokenizers/display_derive/src/parsing.rs delete mode 100644 tokenizers/display_derive/src/vendored.rs diff --git a/bindings/python/grep b/bindings/python/grep new file mode 100644 index 000000000..e69de29bb diff --git a/bindings/python/src/processors.rs b/bindings/python/src/processors.rs index 1aa55f76e..130440b55 100644 --- a/bindings/python/src/processors.rs +++ b/bindings/python/src/processors.rs @@ -144,7 +144,7 @@ impl PyPostProcessor { Ok(format!("{}", &self)) } - fn __repr__(&self) -> PyResult{ + fn __repr__(&self) -> PyResult { Ok(format!("{}", &self)) } } diff --git a/tokenizers/display_derive/src/fmt_parsing.rs b/tokenizers/display_derive/src/fmt_parsing.rs new file mode 100644 index 000000000..b9ebc2c06 --- /dev/null +++ b/tokenizers/display_derive/src/fmt_parsing.rs @@ -0,0 +1,108 @@ +use 
proc_macro2::TokenStream; +use quote::quote; +use quote::ToTokens; +use syn::{ + parse::{Parse, ParseStream}, + punctuated::Punctuated, + token, Expr, +}; + +/// Representation of a [`fmt`]-like attribute. +/// +/// ```rust,ignore +/// #[("", )] +/// ``` +/// +/// [`fmt`]: std::fmt +pub struct FmtAttribute { + /// Interpolation [`syn::LitStr`]. + /// + /// [`syn::LitStr`]: struct@syn::LitStr + lit: syn::LitStr, + + /// Optional [`token::Comma`]. + /// + /// [`token::Comma`]: struct@token::Comma + comma: Option, + + /// Interpolation arguments. + args: Punctuated, +} + +impl Parse for FmtAttribute { + fn parse(input: syn::parse::ParseStream<'_>) -> syn::Result { + let _ident: syn::Ident = input + .parse() + .map_err(|_| syn::Error::new(input.span(), "Expected 'fmt' argument"))?; + input + .parse::() + .map_err(|_| syn::Error::new(input.span(), "Expected '=' after 'fmt'"))?; + + let attribute = Self { + lit: input.parse()?, + comma: input + .peek(token::Comma) + .then(|| input.parse()) + .transpose()?, + args: input.parse_terminated::(FmtArgument::parse)?, + }; + println!( + "Parsed successfully!, {:?}\n parsed arguments: {}", + attribute.lit.token().to_string(), + attribute.args.to_token_stream(), + ); + Ok(attribute) + } +} + +impl ToTokens for FmtAttribute { + fn to_tokens(&self, tokens: &mut TokenStream) { + self.lit.to_tokens(tokens); + self.comma.to_tokens(tokens); + self.args.to_tokens(tokens); + } +} + +/// Representation of a [named parameter][1] (`identifier '=' expression`) in +/// in a [`FmtAttribute`]. +/// This should be used in `[display(fmt="", alias=alias, expr)]`. +/// [1]: https://doc.rust-lang.org/stable/std/fmt/index.html#named-parameters +struct FmtArgument { + /// `identifier =` [`Ident`]. + /// + /// [`Ident`]: struct@syn::Ident + alias: Option<(syn::Ident, token::Eq)>, + + /// `expression` [`Expr`]. + expr: Expr, +} + +impl FmtArgument { + /// Returns an `identifier` of the [named parameter][1]. + /// + /// [1]: https://doc.rust-lang.org/stable/std/fmt/index.html#named-parameters + fn alias(&self) -> Option<&syn::Ident> { + self.alias.as_ref().map(|(ident, _)| ident) + } +} + +impl Parse for FmtArgument { + fn parse(input: ParseStream) -> syn::Result { + Ok(Self { + alias: (input.peek(syn::Ident) && input.peek2(token::Eq)) + .then(|| Ok::<_, syn::Error>((input.parse()?, input.parse()?))) + .transpose()?, + expr: input.parse()?, + }) + } +} + +impl ToTokens for FmtArgument { + fn to_tokens(&self, tokens: &mut TokenStream) { + if let Some((ident, eq)) = &self.alias { + quote!(self . 
#ident).to_tokens(tokens); + eq.to_tokens(tokens); + } + self.expr.to_tokens(tokens) + } +} diff --git a/tokenizers/display_derive/src/lib.rs b/tokenizers/display_derive/src/lib.rs index 7f82d9bb6..80e0e538e 100644 --- a/tokenizers/display_derive/src/lib.rs +++ b/tokenizers/display_derive/src/lib.rs @@ -1,31 +1,31 @@ extern crate proc_macro; use proc_macro::TokenStream; -use quote::{format_ident, quote, ToTokens}; -use syn::{parse_macro_input, DeriveInput, Lit, Meta, MetaNameValue}; -mod parsing; -mod vendored; -use vendored::FmtAttribute; +use quote::quote; +use syn::{parse_macro_input, DeriveInput}; + +mod fmt_parsing; +use fmt_parsing::FmtAttribute; #[proc_macro_derive(Display, attributes(display))] pub fn display_derive(input: TokenStream) -> TokenStream { // Parse the parsed_input tokens into a syntax tree - // let attrs = syn::parse::(input.clone()); - // // Handle the Result from the parsing step - // let attrs = match attrs { - // Ok(attrs) => attrs, - // Err(_) => return TokenStream::new(), // Handle error case appropriately - // }; let parsed_input = parse_macro_input!(input as DeriveInput); - let mut fmt = quote! {}; - for attr in parsed_input.attrs { - if attr.path.is_ident("display") { - fmt = quote! { write!(f, "display(fmt = '', ...)")}; - } - } + // Find the `display` attribute + let display_attr = parsed_input + .attrs + .iter() + .find(|attr| attr.path.is_ident("display")); + let fmt = if let Some(attr) = display_attr { + match attr.parse_args::() { + Ok(display_macro) => quote! { write!(f, #display_macro) }, + Err(e) => return e.to_compile_error().into(), + } + } else { + quote! {} + }; // 1. If the attrs are not None, then we defer to this. // Meaning we juste return quote!{ format!(#fmt, #attr)} - let trait_ident: syn::Ident = format_ident!("display"); let ident = &parsed_input.ident; let body = if fmt.is_empty() { @@ -42,7 +42,6 @@ pub fn display_derive(input: TokenStream) -> TokenStream { fmt }; - println!("body: {:?}", body.to_string()); let expanded = quote! { impl std::fmt::Display for #ident { fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result { @@ -51,6 +50,7 @@ pub fn display_derive(input: TokenStream) -> TokenStream { } }; + println!("Generated body: \n{}\n", expanded); TokenStream::from(expanded) } diff --git a/tokenizers/display_derive/src/parsing.rs b/tokenizers/display_derive/src/parsing.rs deleted file mode 100644 index f3fbca198..000000000 --- a/tokenizers/display_derive/src/parsing.rs +++ /dev/null @@ -1,709 +0,0 @@ -//! Parsing of [Rust `fmt` syntax][0]. -//! -//! [0]: std::fmt#syntax - -use std::iter; - -use unicode_xid::UnicodeXID as XID; - -/// Output of the [`format_string`] parser. -#[derive(Clone, Debug, Eq, PartialEq)] -pub(crate) struct FormatString<'a> { - pub(crate) formats: Vec>, -} - -/// Output of the [`format`] parser. -/// -/// [`format`]: fn@format -#[derive(Debug, Clone, Copy, Eq, PartialEq)] -pub(crate) struct Format<'a> { - pub(crate) arg: Option>, - pub(crate) spec: Option>, -} - -/// Output of the [`format_spec`] parser. -#[derive(Clone, Copy, Debug, Eq, PartialEq)] -pub(crate) struct FormatSpec<'a> { - /// Parsed `[[fill]`[`align`]`]`. - pub(crate) align: Option<(Option, Align)>, - - /// Parsed `[`[`sign`]`]`. - pub(crate) sign: Option, - - /// Parsed `['#']` (alternation). - pub(crate) alternate: Option, - - /// Parsed `['0']` (padding with zeros). - pub(crate) zero_padding: Option, - - /// Parsed `[width]`. - pub(crate) width: Option>, - - /// Parsed `['.' `[`precision`]`]`. 
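// --- Editorial note, not part of the patch -------------------------------------
// As a worked example of the spec grammar being removed here (assuming the
// parsers behave as documented further down): the parameter `{par:-^#08.3$x?}`
// decomposes into arg = Identifier("par"), fill = '-', align = Center,
// sign = None, alternate = yes, zero-padding = yes, width = 8, precision taken
// from positional argument 3 (`3$`), and type `x?` (lower-hex Debug).
// --------------------------------------------------------------------------------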
- pub(crate) precision: Option>, - - /// Parsed [`type`]. - /// - /// [`type`]: type_ - pub(crate) ty: Type, -} - -/// Output of the [`argument`] parser. -#[derive(Clone, Copy, Debug, Eq, PartialEq)] -pub(crate) enum Argument<'a> { - Integer(usize), - Identifier(&'a str), -} - -/// Output of the [`align`] parser. -#[derive(Clone, Copy, Debug, Eq, PartialEq)] -pub(crate) enum Align { - Left, - Center, - Right, -} - -/// Output of the [`sign`] parser. -#[derive(Clone, Copy, Debug, Eq, PartialEq)] -pub(crate) enum Sign { - Plus, - Minus, -} - -/// Type for the [`FormatSpec::alternate`]. -#[derive(Clone, Copy, Debug, Eq, PartialEq)] -pub(crate) struct Alternate; - -/// Type for the [`FormatSpec::zero_padding`]. -#[derive(Clone, Copy, Debug, Eq, PartialEq)] -pub(crate) struct ZeroPadding; - -/// Output of the [`precision`] parser. -#[derive(Clone, Copy, Debug, Eq, PartialEq)] -pub(crate) enum Precision<'a> { - Count(Count<'a>), - Star, -} - -/// Output of the [`count`] parser. -#[derive(Clone, Copy, Debug, Eq, PartialEq)] -pub(crate) enum Count<'a> { - Integer(usize), - Parameter(Parameter<'a>), -} - -/// Output of the [`type_`] parser. See [formatting traits][0] for more info. -/// -/// [0]: std::fmt#formatting-traits -#[derive(Clone, Copy, Debug, Eq, PartialEq)] -pub(crate) enum Type { - Display, - Debug, - LowerDebug, - UpperDebug, - Octal, - LowerHex, - UpperHex, - Pointer, - Binary, - LowerExp, - UpperExp, -} - -impl Type { - /// Returns trait name of this [`Type`]. - pub(crate) fn trait_name(&self) -> &'static str { - match self { - Self::Display => "Display", - Self::Debug | Self::LowerDebug | Self::UpperDebug => "Debug", - Self::Octal => "Octal", - Self::LowerHex => "LowerHex", - Self::UpperHex => "UpperHex", - Self::Pointer => "Pointer", - Self::Binary => "Binary", - Self::LowerExp => "LowerExp", - Self::UpperExp => "UpperExp", - } - } - - /// Indicates whether this [`Type`] represents a trivial trait call without any modifications. - pub(crate) fn is_trivial(&self) -> bool { - match self { - Self::Display - | Self::Debug - | Self::Octal - | Self::LowerHex - | Self::UpperHex - | Self::Pointer - | Self::Binary - | Self::LowerExp - | Self::UpperExp => true, - Self::LowerDebug | Self::UpperDebug => false, - } - } -} - -/// Type alias for the `fill` in the [`FormatSpec::align`]. -type Fill = char; - -/// Type alias for the [`FormatSpec::width`]. -type Width<'a> = Count<'a>; - -/// Output of the [`maybe_format`] parser. -type MaybeFormat<'a> = Option>; - -/// Output of the [`identifier`] parser. -type Identifier<'a> = &'a str; - -/// Output of the [`parameter`] parser. -type Parameter<'a> = Argument<'a>; - -/// [`str`] left to parse. -/// -/// [`str`]: prim@str -type LeftToParse<'a> = &'a str; - -/// Parses a `format_string` as defined in the [grammar spec][0]. -/// -/// # Grammar -/// -/// [`format_string`]` := `[`text`]` [`[`maybe_format text`]`] *` -/// -/// # Example -/// -/// ```text -/// Hello -/// Hello, {}! -/// {:?} -/// Hello {people}! -/// {} {} -/// {:04} -/// {par:-^#.0$?} -/// ``` -/// -/// # Return value -/// -/// - [`Some`] in case of successful parse. -/// - [`None`] otherwise (not all characters are consumed by underlying -/// parsers). 
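// --- Editorial note, not part of the patch -------------------------------------
// Illustrating the full-consumption rule above: for an input like "loss={:.3}"
// the parser should yield Some(FormatString { .. }) with one text chunk
// ("loss=") and one Format whose spec has precision 3 and the default Display
// type, whereas an unbalanced input such as "{loss" leaves characters
// unconsumed and therefore yields None.
// --------------------------------------------------------------------------------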
-/// -/// [0]: std::fmt#syntax -pub(crate) fn format_string(input: &str) -> Option> { - let (mut input, _) = optional_result(text)(input); - - let formats = iter::repeat(()) - .scan(&mut input, |input, _| { - let (curr, format) = - alt(&mut [&mut maybe_format, &mut map(text, |(i, _)| (i, None))])(input)?; - **input = curr; - Some(format) - }) - .flatten() - .collect(); - - // Should consume all tokens for a successful parse. - input.is_empty().then_some(FormatString { formats }) -} - -/// Parses a `maybe_format` as defined in the [grammar spec][0]. -/// -/// # Grammar -/// -/// [`maybe_format`]` := '{' '{' | '}' '}' | `[`format`] -/// -/// # Example -/// -/// ```text -/// {{ -/// }} -/// {:04} -/// {:#?} -/// {par:-^#.0$?} -/// ``` -/// -/// [`format`]: fn@format -/// [0]: std::fmt#syntax -fn maybe_format(input: &str) -> Option<(LeftToParse<'_>, MaybeFormat<'_>)> { - alt(&mut [ - &mut map(str("{{"), |i| (i, None)), - &mut map(str("}}"), |i| (i, None)), - &mut map(format, |(i, format)| (i, Some(format))), - ])(input) -} - -/// Parses a `format` as defined in the [grammar spec][0]. -/// -/// # Grammar -/// -/// [`format`]` := '{' [`[`argument`]`] [':' `[`format_spec`]`] '}'` -/// -/// # Example -/// -/// ```text -/// {par} -/// {:#?} -/// {par:-^#.0$?} -/// ``` -/// -/// [`format`]: fn@format -/// [0]: std::fmt#syntax -pub(crate) fn format(input: &str) -> Option<(LeftToParse<'_>, Format<'_>)> { - let input = char('{')(input)?; - - let (input, arg) = optional_result(argument)(input); - - let (input, spec) = map_or_else( - char(':'), - |i| Some((i, None)), - map(format_spec, |(i, s)| (i, Some(s))), - )(input)?; - - let input = char('}')(input)?; - - Some((input, Format { arg, spec })) -} - -/// Parses an `argument` as defined in the [grammar spec][0]. -/// -/// # Grammar -/// -/// [`argument`]` := `[`integer`]` | `[`identifier`] -/// -/// # Example -/// -/// ```text -/// 0 -/// ident -/// Минск -/// ``` -/// -/// [0]: std::fmt#syntax -fn argument(input: &str) -> Option<(LeftToParse<'_>, Argument)> { - alt(&mut [ - &mut map(identifier, |(i, ident)| (i, Argument::Identifier(ident))), - &mut map(integer, |(i, int)| (i, Argument::Integer(int))), - ])(input) -} - -/// Parses a `format_spec` as defined in the [grammar spec][0]. -/// -/// # Grammar -/// -/// [`format_spec`]` := [[fill]`[`align`]`][`[`sign`]`]['#']['0'][width]` -/// `['.' `[`precision`]`]`[`type`] -/// -/// # Example -/// -/// ```text -/// ^ -/// <^ -/// ->+#0width$.precision$x? 
-/// ``` -/// -/// [`type`]: type_ -/// [0]: std::fmt#syntax -fn format_spec(input: &str) -> Option<(LeftToParse<'_>, FormatSpec<'_>)> { - let (input, align) = optional_result(alt(&mut [ - &mut and_then(take_any_char, |(i, fill)| { - map(align, |(i, align)| (i, (Some(fill), align)))(i) - }), - &mut map(align, |(i, align)| (i, (None, align))), - ]))(input); - - let (input, sign) = optional_result(sign)(input); - - let (input, alternate) = optional_result(map(char('#'), |i| (i, Alternate)))(input); - - let (input, zero_padding) = optional_result(map( - try_seq(&mut [ - &mut char('0'), - &mut lookahead(check_char(|c| !matches!(c, '$'))), - ]), - |i| (i, ZeroPadding), - ))(input); - - let (input, width) = optional_result(count)(input); - - let (input, precision) = map_or_else( - char('.'), - |i| Some((i, None)), - map(precision, |(i, p)| (i, Some(p))), - )(input)?; - - let (input, ty) = type_(input)?; - - Some(( - input, - FormatSpec { - align, - sign, - alternate, - zero_padding, - width, - precision, - ty, - }, - )) -} - -/// Parses an `align` as defined in the [grammar spec][0]. -/// -/// # Grammar -/// -/// [`align`]` := '<' | '^' | '>'` -/// -/// # Example -/// -/// ```text -/// < -/// ^ -/// > -/// ``` -/// -/// [0]: std::fmt#syntax -fn align(input: &str) -> Option<(LeftToParse<'_>, Align)> { - alt(&mut [ - &mut map(char('<'), |i| (i, Align::Left)), - &mut map(char('^'), |i| (i, Align::Center)), - &mut map(char('>'), |i| (i, Align::Right)), - ])(input) -} - -/// Parses a `sign` as defined in the [grammar spec][0]. -/// -/// # Grammar -/// -/// [`sign`]` := '+' | '-'` -/// -/// # Example -/// -/// ```text -/// + -/// - -/// ``` -/// -/// [0]: std::fmt#syntax -fn sign(input: &str) -> Option<(LeftToParse<'_>, Sign)> { - alt(&mut [ - &mut map(char('+'), |i| (i, Sign::Plus)), - &mut map(char('-'), |i| (i, Sign::Minus)), - ])(input) -} - -/// Parses a `precision` as defined in the [grammar spec][0]. -/// -/// # Grammar -/// -/// [`precision`]` := `[`count`]` | '*'` -/// -/// # Example -/// -/// ```text -/// 0 -/// 42$ -/// par$ -/// * -/// ``` -/// -/// [0]: std::fmt#syntax -fn precision(input: &str) -> Option<(LeftToParse<'_>, Precision<'_>)> { - alt(&mut [ - &mut map(count, |(i, c)| (i, Precision::Count(c))), - &mut map(char('*'), |i| (i, Precision::Star)), - ])(input) -} - -/// Parses a `type` as defined in the [grammar spec][0]. -/// -/// # Grammar -/// -/// [`type`]` := '' | '?' | 'x?' | 'X?' | identifier` -/// -/// # Example -/// -/// All possible [`Type`]s. -/// -/// ```text -/// ? -/// x? -/// X? -/// o -/// x -/// X -/// p -/// b -/// e -/// E -/// ``` -/// -/// [`type`]: type_ -/// [0]: std::fmt#syntax -fn type_(input: &str) -> Option<(&str, Type)> { - alt(&mut [ - &mut map(str("x?"), |i| (i, Type::LowerDebug)), - &mut map(str("X?"), |i| (i, Type::UpperDebug)), - &mut map(char('?'), |i| (i, Type::Debug)), - &mut map(char('o'), |i| (i, Type::Octal)), - &mut map(char('x'), |i| (i, Type::LowerHex)), - &mut map(char('X'), |i| (i, Type::UpperHex)), - &mut map(char('p'), |i| (i, Type::Pointer)), - &mut map(char('b'), |i| (i, Type::Binary)), - &mut map(char('e'), |i| (i, Type::LowerExp)), - &mut map(char('E'), |i| (i, Type::UpperExp)), - &mut map(lookahead(char('}')), |i| (i, Type::Display)), - ])(input) -} - -/// Parses a `count` as defined in the [grammar spec][0]. 
-/// -/// # Grammar -/// -/// [`count`]` := `[`parameter`]` | `[`integer`] -/// -/// # Example -/// -/// ```text -/// 0 -/// 42$ -/// par$ -/// ``` -/// -/// [0]: std::fmt#syntax -fn count(input: &str) -> Option<(LeftToParse<'_>, Count<'_>)> { - alt(&mut [ - &mut map(parameter, |(i, p)| (i, Count::Parameter(p))), - &mut map(integer, |(i, int)| (i, Count::Integer(int))), - ])(input) -} - -/// Parses a `parameter` as defined in the [grammar spec][0]. -/// -/// # Grammar -/// -/// [`parameter`]` := `[`argument`]` '$'` -/// -/// # Example -/// -/// ```text -/// 42$ -/// par$ -/// ``` -/// -/// [0]: std::fmt#syntax -fn parameter(input: &str) -> Option<(LeftToParse<'_>, Parameter<'_>)> { - and_then(argument, |(i, arg)| map(char('$'), |i| (i, arg))(i))(input) -} - -/// Parses an `identifier` as defined in the [grammar spec][0]. -/// -/// # Grammar -/// -/// `IDENTIFIER_OR_KEYWORD : XID_Start XID_Continue* | _ XID_Continue+` -/// -/// See [rust reference][2] for more info. -/// -/// # Example -/// -/// ```text -/// identifier -/// Минск -/// ``` -/// -/// [0]: std::fmt#syntax -/// [2]: https://doc.rust-lang.org/reference/identifiers.html -fn identifier(input: &str) -> Option<(LeftToParse<'_>, Identifier<'_>)> { - map( - alt(&mut [ - &mut map( - check_char(XID::is_xid_start), - take_while0(check_char(XID::is_xid_continue)), - ), - &mut and_then(char('_'), take_while1(check_char(XID::is_xid_continue))), - ]), - |(i, _)| (i, &input[..(input.as_bytes().len() - i.as_bytes().len())]), - )(input) -} - -/// Parses an `integer` as defined in the [grammar spec][0]. -/// -/// [0]: std::fmt#syntax -fn integer(input: &str) -> Option<(LeftToParse<'_>, usize)> { - and_then( - take_while1(check_char(|c| c.is_ascii_digit())), - |(i, int)| int.parse().ok().map(|int| (i, int)), - )(input) -} - -/// Parses a `text` as defined in the [grammar spec][0]. -/// -/// [0]: std::fmt#syntax -fn text(input: &str) -> Option<(LeftToParse<'_>, &str)> { - take_until1(any_char, one_of("{}"))(input) -} - -type FallibleParser<'p> = &'p mut dyn FnMut(&str) -> Option<&str>; - -/// Tries to apply parsers in sequence. Returns [`None`] in case one of them -/// returned [`None`]. -fn try_seq<'p>( - parsers: &'p mut [FallibleParser<'p>], -) -> impl FnMut(&str) -> Option> + 'p { - move |input| parsers.iter_mut().try_fold(input, |i, p| (**p)(i)) -} - -/// Returns first successful parser or [`None`] in case all of them returned -/// [`None`]. -fn alt<'p, 'i, T: 'i>( - parsers: &'p mut [&'p mut dyn FnMut(&'i str) -> Option], -) -> impl FnMut(&'i str) -> Option + 'p { - move |input| parsers.iter_mut().find_map(|p| (**p)(input)) -} - -/// Maps output of the parser in case it returned [`Some`]. -fn map<'i, I: 'i, O: 'i>( - mut parser: impl FnMut(&'i str) -> Option, - mut f: impl FnMut(I) -> O, -) -> impl FnMut(&'i str) -> Option { - move |input| parser(input).map(&mut f) -} - -/// Maps output of the parser in case it returned [`Some`] or uses `default`. -fn map_or_else<'i, I: 'i, O: 'i>( - mut parser: impl FnMut(&'i str) -> Option, - mut default: impl FnMut(&'i str) -> O, - mut f: impl FnMut(I) -> O, -) -> impl FnMut(&'i str) -> O { - move |input| parser(input).map_or_else(|| default(input), &mut f) -} - -/// Returns [`None`] if the parser returned is [`None`], otherwise calls `f` -/// with the wrapped value and returns the result. 
-fn and_then<'i, I: 'i, O: 'i>( - mut parser: impl FnMut(&'i str) -> Option, - mut f: impl FnMut(I) -> Option, -) -> impl FnMut(&'i str) -> Option { - move |input| parser(input).and_then(&mut f) -} - -/// Checks whether `parser` is successful while not advancing the original -/// `input`. -fn lookahead( - mut parser: impl FnMut(&str) -> Option<&str>, -) -> impl FnMut(&str) -> Option> { - move |input| map(&mut parser, |_| input)(input) -} - -/// Makes underlying `parser` optional by returning the original `input` and -/// [`None`] in case it returned [`None`]. -fn optional_result<'i, T: 'i>( - mut parser: impl FnMut(&'i str) -> Option<(&'i str, T)>, -) -> impl FnMut(&'i str) -> (LeftToParse<'i>, Option) { - move |input: &str| map_or_else(&mut parser, |i| (i, None), |(i, c)| (i, Some(c)))(input) -} - -/// Parses while `parser` is successful. Never fails. -fn take_while0( - mut parser: impl FnMut(&str) -> Option<&str>, -) -> impl FnMut(&str) -> (LeftToParse<'_>, &str) { - move |input| { - let mut cur = input; - while let Some(step) = parser(cur) { - cur = step; - } - ( - cur, - &input[..(input.as_bytes().len() - cur.as_bytes().len())], - ) - } -} - -/// Parses while `parser` is successful. Returns [`None`] in case `parser` never -/// succeeded. -fn take_while1( - mut parser: impl FnMut(&str) -> Option<&str>, -) -> impl FnMut(&str) -> Option<(LeftToParse<'_>, &str)> { - move |input| { - let mut cur = parser(input)?; - while let Some(step) = parser(cur) { - cur = step; - } - Some(( - cur, - &input[..(input.as_bytes().len() - cur.as_bytes().len())], - )) - } -} - -/// Parses with `basic` while `until` returns [`None`]. Returns [`None`] in case -/// `until` succeeded initially or `basic` never succeeded. Doesn't consume -/// [`char`]s parsed by `until`. -/// -/// [`char`]: fn@char -fn take_until1( - mut basic: impl FnMut(&str) -> Option<&str>, - mut until: impl FnMut(&str) -> Option<&str>, -) -> impl FnMut(&str) -> Option<(LeftToParse<'_>, &str)> { - move |input: &str| { - if until(input).is_some() { - return None; - } - let mut cur = basic(input)?; - loop { - if until(cur).is_some() { - break; - } - let Some(b) = basic(cur) else { - break; - }; - cur = b; - } - - Some(( - cur, - &input[..(input.as_bytes().len() - cur.as_bytes().len())], - )) - } -} - -/// Checks whether `input` starts with `s`. -fn str(s: &str) -> impl FnMut(&str) -> Option> + '_ { - move |input| input.starts_with(s).then(|| &input[s.as_bytes().len()..]) -} - -/// Checks whether `input` starts with `c`. -fn char(c: char) -> impl FnMut(&str) -> Option> { - move |input| input.starts_with(c).then(|| &input[c.len_utf8()..]) -} - -/// Checks whether first [`char`] suits `check`. -/// -/// [`char`]: fn@char -fn check_char(mut check: impl FnMut(char) -> bool) -> impl FnMut(&str) -> Option> { - move |input| { - input - .chars() - .next() - .and_then(|c| check(c).then(|| &input[c.len_utf8()..])) - } -} - -/// Checks whether first [`char`] of input is present in `chars`. -/// -/// [`char`]: fn@char -fn one_of(chars: &str) -> impl FnMut(&str) -> Option> + '_ { - move |input: &str| chars.chars().find_map(|c| char(c)(input)) -} - -/// Parses any [`char`]. -/// -/// [`char`]: fn@char -fn any_char(input: &str) -> Option> { - input.chars().next().map(|c| &input[c.len_utf8()..]) -} - -/// Parses any [`char`] and returns it. 
-/// -/// [`char`]: fn@char -fn take_any_char(input: &str) -> Option<(LeftToParse<'_>, char)> { - input.chars().next().map(|c| (&input[c.len_utf8()..], c)) -} diff --git a/tokenizers/display_derive/src/vendored.rs b/tokenizers/display_derive/src/vendored.rs deleted file mode 100644 index eccc5263a..000000000 --- a/tokenizers/display_derive/src/vendored.rs +++ /dev/null @@ -1,161 +0,0 @@ -use std::{env::VarError, error::Error}; - -use crate::parsing; -use proc_macro2::TokenStream; -use quote::{format_ident, ToTokens}; -use syn::{ - parse::{Parse, ParseStream}, - punctuated::Punctuated, - token, Expr, -}; - -/// Representation of a [`fmt`]-like attribute. -/// -/// ```rust,ignore -/// #[("", )] -/// ``` -/// -/// [`fmt`]: std::fmt -pub struct FmtAttribute { - /// Interpolation [`syn::LitStr`]. - /// - /// [`syn::LitStr`]: struct@syn::LitStr - lit: syn::LitStr, - - /// Optional [`token::Comma`]. - /// - /// [`token::Comma`]: struct@token::Comma - comma: Option, - - /// Interpolation arguments. - args: Punctuated, -} - -impl Parse for FmtAttribute { - fn parse(input: ParseStream<'_>) -> syn::Result { - let attribute = Self { - lit: input.parse()?, - comma: input - .peek(token::Comma) - .then(|| input.parse()) - .transpose()?, - args: input.parse_terminated(FmtArgument::parse)?, - }; - println!("Parsing FMTAttribute, {}, ",attribute.lit.token().to_string()); - Ok(attribute) - } -} - -impl ToTokens for FmtAttribute { - fn to_tokens(&self, tokens: &mut TokenStream) { - self.lit.to_tokens(tokens); - self.comma.to_tokens(tokens); - self.args.to_tokens(tokens); - } -} - -impl FmtAttribute { - /// Checks whether this [`FmtAttribute`] can be replaced with a transparent delegation (calling - /// a formatting trait directly instead of interpolation syntax). - /// - /// If such transparent call is possible, the returns an [`Ident`] of the delegated trait and - /// the [`Expr`] to pass into the call, otherwise [`None`]. - /// - /// [`Ident`]: struct@syn::Ident - fn transparent_call(&self) -> Option<(Expr, syn::Ident)> { - // `FmtAttribute` is transparent when: - - // (1) There is exactly one formatting parameter. - let lit = self.lit.value(); - let param = parsing::format(&lit).and_then(|(more, p)| more.is_empty().then_some(p))?; - - // (2) And the formatting parameter doesn't contain any modifiers. - if param - .spec - .map(|s| { - s.align.is_some() - || s.sign.is_some() - || s.alternate.is_some() - || s.zero_padding.is_some() - || s.width.is_some() - || s.precision.is_some() - || !s.ty.is_trivial() - }) - .unwrap_or_default() - { - return None; - } - - let expr = match param.arg { - // (3) And either exactly one positional argument is specified. - Some(parsing::Argument::Integer(_)) | None => (self.args.len() == 1) - .then(|| self.args.first()) - .flatten() - .map(|a| a.expr.clone()), - - // (4) Or the formatting parameter's name refers to some outer binding. - // Some(parsing::Argument::Identifier(name)) if self.args.is_empty() => { - // Some(format_ident!("{trait_name}").into()) - // } - - // (5) Or exactly one named argument is specified for the formatting parameter's name. 
- Some(parsing::Argument::Identifier(name)) => (self.args.len() == 1) - .then(|| self.args.first()) - .flatten() - .filter(|a| a.alias.as_ref().map(|a| a.0 == name).unwrap_or_default()) - .map(|a| a.expr.clone()), - }?; - - let trait_name = param - .spec - .map(|s| s.ty) - .unwrap_or(parsing::Type::Display) - .trait_name(); - - Some((expr, format_ident!("{trait_name}"))) - } -} - -/// Representation of a [named parameter][1] (`identifier '=' expression`) in -/// in a [`FmtAttribute`]. -/// -/// [1]: https://doc.rust-lang.org/stable/std/fmt/index.html#named-parameters -struct FmtArgument { - /// `identifier =` [`Ident`]. - /// - /// [`Ident`]: struct@syn::Ident - alias: Option<(syn::Ident, token::Eq)>, - - /// `expression` [`Expr`]. - expr: Expr, -} - -impl FmtArgument { - /// Returns an `identifier` of the [named parameter][1]. - /// - /// [1]: https://doc.rust-lang.org/stable/std/fmt/index.html#named-parameters - fn alias(&self) -> Option<&syn::Ident> { - self.alias.as_ref().map(|(ident, _)| ident) - } -} - -impl Parse for FmtArgument { - fn parse(input: ParseStream) -> syn::Result { - Ok(Self { - alias: (input.peek(syn::Ident) && input.peek2(token::Eq)) - .then(|| Ok::<_, syn::Error>((input.parse()?, input.parse()?))) - .transpose()?, - expr: input.parse()?, - }) - } -} - -impl ToTokens for FmtArgument { - fn to_tokens(&self, tokens: &mut TokenStream) { - if let Some((ident, eq)) = &self.alias { - ident.to_tokens(tokens); - eq.to_tokens(tokens); - } - self.expr.to_tokens(tokens); - } -} diff --git a/tokenizers/src/pre_tokenizers/split.rs b/tokenizers/src/pre_tokenizers/split.rs index fa0134b6b..3343582f8 100644 --- a/tokenizers/src/pre_tokenizers/split.rs +++ b/tokenizers/src/pre_tokenizers/split.rs @@ -29,7 +29,7 @@ impl From<&str> for SplitPattern { #[serde(tag = "type")] #[display( fmt = "Split(patter={}, regex={:?}, behavior={}, invert={})", - "pattern", + pattern, regex, behavior, invert From 9559dea88e0cb797d247bdee2bff8292e21c4897 Mon Sep 17 00:00:00 2001 From: Arthur Zucker Date: Sun, 16 Jun 2024 16:35:40 +0200 Subject: [PATCH 69/94] nit --- tokenizers/src/pre_tokenizers/split.rs | 1 - 1 file changed, 1 deletion(-) diff --git a/tokenizers/src/pre_tokenizers/split.rs b/tokenizers/src/pre_tokenizers/split.rs index 3343582f8..a3128bdf2 100644 --- a/tokenizers/src/pre_tokenizers/split.rs +++ b/tokenizers/src/pre_tokenizers/split.rs @@ -7,7 +7,6 @@ use serde::{Deserialize, Deserializer, Serialize}; /// Represents the different patterns that `Split` can use #[derive(Debug, Clone, PartialEq, Serialize, Deserialize, Eq, Display)] -#[display(fmt = "{}")] pub enum SplitPattern { String(String), Regex(String), From e53f4cac59c1db78b810a05cc932b5d3b582b50e Mon Sep 17 00:00:00 2001 From: Arthur Zucker Date: Sun, 16 Jun 2024 18:43:11 +0200 Subject: [PATCH 70/94] current status --- tokenizers/display_derive/src/fmt_parsing.rs | 27 ++++-- tokenizers/display_derive/src/lib.rs | 92 ++++++++++---------- 2 files changed, 69 insertions(+), 50 deletions(-) diff --git a/tokenizers/display_derive/src/fmt_parsing.rs b/tokenizers/display_derive/src/fmt_parsing.rs index b9ebc2c06..006f63dcf 100644 --- a/tokenizers/display_derive/src/fmt_parsing.rs +++ b/tokenizers/display_derive/src/fmt_parsing.rs @@ -4,7 +4,7 @@ use quote::ToTokens; use syn::{ parse::{Parse, ParseStream}, punctuated::Punctuated, - token, Expr, + token, Attribute, Expr, }; /// Representation of a [`fmt`]-like attribute. @@ -18,7 +18,7 @@ pub struct FmtAttribute { /// Interpolation [`syn::LitStr`]. 
/// /// [`syn::LitStr`]: struct@syn::LitStr - lit: syn::LitStr, + pub lit: syn::LitStr, /// Optional [`token::Comma`]. /// @@ -26,7 +26,7 @@ pub struct FmtAttribute { comma: Option, /// Interpolation arguments. - args: Punctuated, + pub args: Punctuated, } impl Parse for FmtAttribute { @@ -67,11 +67,11 @@ impl ToTokens for FmtAttribute { /// in a [`FmtAttribute`]. /// This should be used in `[display(fmt="", alias=alias, expr)]`. /// [1]: https://doc.rust-lang.org/stable/std/fmt/index.html#named-parameters -struct FmtArgument { +pub struct FmtArgument { /// `identifier =` [`Ident`]. /// /// [`Ident`]: struct@syn::Ident - alias: Option<(syn::Ident, token::Eq)>, + pub alias: Option<(syn::Ident, token::Eq)>, /// `expression` [`Expr`]. expr: Expr, @@ -106,3 +106,20 @@ impl ToTokens for FmtArgument { self.expr.to_tokens(tokens) } } + +pub fn find_display_attribute(attrs: &[Attribute]) -> Option { + let display_attr = attrs.iter().find(|attr| attr.path.is_ident("display")); + + let attr: Option = if let Some(attr) = display_attr { + match attr.parse_args::() { + Ok(display_macro) => Some(display_macro), + Err(e) => { + e.to_compile_error(); + None + } + } + } else { + None + }; + attr +} diff --git a/tokenizers/display_derive/src/lib.rs b/tokenizers/display_derive/src/lib.rs index 80e0e538e..315135297 100644 --- a/tokenizers/display_derive/src/lib.rs +++ b/tokenizers/display_derive/src/lib.rs @@ -1,45 +1,31 @@ extern crate proc_macro; use proc_macro::TokenStream; use quote::quote; -use syn::{parse_macro_input, DeriveInput}; +use syn::{parse_macro_input, stringify_punct, DeriveInput}; mod fmt_parsing; -use fmt_parsing::FmtAttribute; +use fmt_parsing::{find_display_attribute, FmtAttribute}; #[proc_macro_derive(Display, attributes(display))] pub fn display_derive(input: TokenStream) -> TokenStream { // Parse the parsed_input tokens into a syntax tree let parsed_input = parse_macro_input!(input as DeriveInput); // Find the `display` attribute - let display_attr = parsed_input - .attrs - .iter() - .find(|attr| attr.path.is_ident("display")); - - let fmt = if let Some(attr) = display_attr { - match attr.parse_args::() { - Ok(display_macro) => quote! { write!(f, #display_macro) }, - Err(e) => return e.to_compile_error().into(), - } - } else { - quote! {} - }; + let attr = find_display_attribute(&parsed_input.attrs); // 1. If the attrs are not None, then we defer to this. // Meaning we juste return quote!{ format!(#fmt, #attr)} let ident = &parsed_input.ident; - let body = if fmt.is_empty() { + let body = { // 2. We automatically parse match &parsed_input.data { - syn::Data::Struct(s) => generate_fmt_impl_for_struct(s, ident), + syn::Data::Struct(s) => generate_fmt_impl_for_struct(s, ident, &attr), syn::Data::Enum(e) => generate_fmt_impl_for_enum(e, ident), syn::Data::Union(u) => { let error = syn::Error::new_spanned(u.union_token, "Unions are not supported"); return proc_macro::TokenStream::from(error.into_compile_error()); } } - } else { - fmt }; let expanded = quote! { @@ -57,41 +43,57 @@ pub fn display_derive(input: TokenStream) -> TokenStream { fn generate_fmt_impl_for_struct( data_struct: &syn::DataStruct, ident: &syn::Ident, + attrs: &Option, ) -> proc_macro2::TokenStream { let fields = &data_struct.fields; - // Extract field names and types - let field_names: Vec<_> = fields.iter().map(|f| f.ident.as_ref().unwrap()).collect(); - let field_types: Vec<_> = fields.iter().map(|f| &f.ty).collect(); - - quote! 
{ - write!(f, "{}(", stringify!(#ident))?; - let mut first = true; - #( - if !first { - write!(f, ", ")?; - } - first = false; + // Generate field formatting expressions + let field_formats: Vec<_> = fields + .iter() + .map(|f| { + let field_name = &f.ident; + let fmts = find_display_attribute(&f.attrs); - let field_value = &self.#field_names; - write!(f, "{}=", stringify!(#field_names))?; - if std::any::TypeId::of::<#field_types>() == std::any::TypeId::of::() { - println!("We have a string!"); - write!(f, "\"{}\"", field_value)?; + if let Some(attr) = attrs { + if attr.args.is_empty() { + // If there is a prefix and no args, use fmts if it exists + if let Some(fmt) = fmts { + // Combine prefix and fmts + quote! { + write!(f, "{}{}", #fmt.lit.value(), #fmt.args.to_string())?; + } + } else { + // If no fmts, write just the field value + quote! { + write!(f, "{}", self.#field_name)?; + } + } + } else { + // If there are args to the attribute, use attr.lit and attr.args exclusively + quote! { + write!(f, "{}{}", #attr.lit.value(), #attr.args.to_string())?; + } + } } else { - let s = format!("{}", field_value); - let mut chars = s.chars(); - let mut prefix = (&mut chars).take(100 - 1).collect::(); - if chars.next().is_some() { - prefix.push('…'); + // If there is no attribute, print everything directly + quote! { + write!(f, "{}", self.#field_name)?; } - write!(f, "{}", prefix)?; } - )* - write!(f, ")") + }) + .collect(); + + // Generate the final implementation of Display trait for the struct + quote! { + impl std::fmt::Display for #ident { + fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result { + write!(f, "{}(", stringify!(#ident))?; + #(#field_formats)* + write!(f, ")") + } + } } } - fn generate_fmt_impl_for_enum( data_enum: &syn::DataEnum, ident: &syn::Ident, From 18238dd6cb509b038f02a46048ddf6f07f4f2255 Mon Sep 17 00:00:00 2001 From: Arthur Zucker Date: Sun, 16 Jun 2024 18:50:06 +0200 Subject: [PATCH 71/94] let's just go with this no it's not optimal but I need to go --- tokenizers/display_derive/src/lib.rs | 13 +++++-------- 1 file changed, 5 insertions(+), 8 deletions(-) diff --git a/tokenizers/display_derive/src/lib.rs b/tokenizers/display_derive/src/lib.rs index 315135297..2114eafec 100644 --- a/tokenizers/display_derive/src/lib.rs +++ b/tokenizers/display_derive/src/lib.rs @@ -46,7 +46,8 @@ fn generate_fmt_impl_for_struct( attrs: &Option, ) -> proc_macro2::TokenStream { let fields = &data_struct.fields; - + // TODO I am stuck here for now hehe. + // Basically we need to produce the body that will be used. // Generate field formatting expressions let field_formats: Vec<_> = fields .iter() @@ -85,13 +86,9 @@ fn generate_fmt_impl_for_struct( // Generate the final implementation of Display trait for the struct quote! 
{ - impl std::fmt::Display for #ident { - fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result { - write!(f, "{}(", stringify!(#ident))?; - #(#field_formats)* - write!(f, ")") - } - } + write!(f, "{}(", stringify!(#ident))?; + #field_formats + write!(f, ")") } } fn generate_fmt_impl_for_enum( From 269ff217a5cc200e97bb700ffc829ef2585d4b16 Mon Sep 17 00:00:00 2001 From: Arthur Zucker Date: Mon, 17 Jun 2024 09:58:42 +0200 Subject: [PATCH 72/94] update --- tokenizers/display_derive/src/lib.rs | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/tokenizers/display_derive/src/lib.rs b/tokenizers/display_derive/src/lib.rs index 2114eafec..a60486f0c 100644 --- a/tokenizers/display_derive/src/lib.rs +++ b/tokenizers/display_derive/src/lib.rs @@ -49,6 +49,8 @@ fn generate_fmt_impl_for_struct( // TODO I am stuck here for now hehe. // Basically we need to produce the body that will be used. // Generate field formatting expressions + // Will write more later on + let result_stream = quote! { write!(f, "{}", stringify!(#ident))}; let field_formats: Vec<_> = fields .iter() .map(|f| { @@ -66,7 +68,7 @@ fn generate_fmt_impl_for_struct( } else { // If no fmts, write just the field value quote! { - write!(f, "{}", self.#field_name)?; + write!(f, "{}={}", stringify!(#field_name), self.#field_name)?; } } } else { @@ -76,9 +78,9 @@ fn generate_fmt_impl_for_struct( } } } else { - // If there is no attribute, print everything directly + // If there is no attribute, print the default quote! { - write!(f, "{}", self.#field_name)?; + write!(f, "{}={}", stringify!(#field_name), self.#field_name)?; } } }) From 93ad5939a1a69a57cd6d5e1d0285ad2142458429 Mon Sep 17 00:00:00 2001 From: Arthur Zucker Date: Mon, 15 Jul 2024 10:39:40 +0200 Subject: [PATCH 73/94] update --- tokenizers/Cargo.toml | 2 +- tokenizers/display_derive/Cargo.toml | 13 -- tokenizers/display_derive/src/fmt_parsing.rs | 125 ----------------- tokenizers/display_derive/src/lib.rs | 131 ------------------ tokenizers/src/decoders/bpe.rs | 2 +- tokenizers/src/decoders/byte_fallback.rs | 2 +- tokenizers/src/decoders/ctc.rs | 2 +- tokenizers/src/decoders/fuse.rs | 2 +- tokenizers/src/decoders/mod.rs | 2 +- tokenizers/src/decoders/sequence.rs | 2 +- tokenizers/src/decoders/strip.rs | 2 +- tokenizers/src/decoders/wordpiece.rs | 2 +- tokenizers/src/models/mod.rs | 2 +- tokenizers/src/models/unigram/model.rs | 2 +- tokenizers/src/models/wordlevel/mod.rs | 2 +- tokenizers/src/models/wordpiece/mod.rs | 2 +- tokenizers/src/normalizers/bert.rs | 2 +- tokenizers/src/normalizers/mod.rs | 2 +- tokenizers/src/normalizers/prepend.rs | 2 +- tokenizers/src/normalizers/replace.rs | 2 +- tokenizers/src/normalizers/strip.rs | 2 +- tokenizers/src/normalizers/unicode.rs | 2 +- tokenizers/src/normalizers/utils.rs | 2 +- tokenizers/src/pre_tokenizers/bert.rs | 2 +- tokenizers/src/pre_tokenizers/byte_level.rs | 2 +- tokenizers/src/pre_tokenizers/delimiter.rs | 2 +- tokenizers/src/pre_tokenizers/digits.rs | 2 +- tokenizers/src/pre_tokenizers/metaspace.rs | 2 +- tokenizers/src/pre_tokenizers/mod.rs | 2 +- tokenizers/src/pre_tokenizers/punctuation.rs | 2 +- tokenizers/src/pre_tokenizers/sequence.rs | 2 +- tokenizers/src/pre_tokenizers/split.rs | 2 +- .../unicode_scripts/pre_tokenizer.rs | 2 +- tokenizers/src/pre_tokenizers/whitespace.rs | 2 +- tokenizers/src/processors/bert.rs | 2 +- tokenizers/src/processors/mod.rs | 2 +- tokenizers/src/processors/roberta.rs | 2 +- tokenizers/src/processors/sequence.rs | 2 +- tokenizers/src/processors/template.rs | 2 +- 
tokenizers/src/tokenizer/added_vocabulary.rs | 2 +- tokenizers/src/tokenizer/normalizer.rs | 2 +- tokenizers/src/utils/padding.rs | 2 +- tokenizers/src/utils/truncation.rs | 2 +- 43 files changed, 40 insertions(+), 309 deletions(-) delete mode 100644 tokenizers/display_derive/Cargo.toml delete mode 100644 tokenizers/display_derive/src/fmt_parsing.rs delete mode 100644 tokenizers/display_derive/src/lib.rs diff --git a/tokenizers/Cargo.toml b/tokenizers/Cargo.toml index 492a956d6..80e042b62 100644 --- a/tokenizers/Cargo.toml +++ b/tokenizers/Cargo.toml @@ -63,7 +63,7 @@ fancy-regex = { version = "0.13", optional = true} getrandom = { version = "0.2.10" } esaxx-rs = { version = "0.1.10", default-features = false, features=[]} monostate = "0.1.12" -display_derive = { path = "display_derive" } +pyo3_special_method_derive = "0.3.0" [features] default = ["progressbar", "onig", "esaxx_fast"] diff --git a/tokenizers/display_derive/Cargo.toml b/tokenizers/display_derive/Cargo.toml deleted file mode 100644 index 289d54dd1..000000000 --- a/tokenizers/display_derive/Cargo.toml +++ /dev/null @@ -1,13 +0,0 @@ -[package] -name = "display_derive" -version = "0.1.0" -edition = "2021" - -[dependencies] -syn = "1.0" -quote = "1.0" -proc-macro2 = "1.0" -unicode-xid = "0.2.4" - -[lib] -proc-macro = true diff --git a/tokenizers/display_derive/src/fmt_parsing.rs b/tokenizers/display_derive/src/fmt_parsing.rs deleted file mode 100644 index 006f63dcf..000000000 --- a/tokenizers/display_derive/src/fmt_parsing.rs +++ /dev/null @@ -1,125 +0,0 @@ -use proc_macro2::TokenStream; -use quote::quote; -use quote::ToTokens; -use syn::{ - parse::{Parse, ParseStream}, - punctuated::Punctuated, - token, Attribute, Expr, -}; - -/// Representation of a [`fmt`]-like attribute. -/// -/// ```rust,ignore -/// #[("", )] -/// ``` -/// -/// [`fmt`]: std::fmt -pub struct FmtAttribute { - /// Interpolation [`syn::LitStr`]. - /// - /// [`syn::LitStr`]: struct@syn::LitStr - pub lit: syn::LitStr, - - /// Optional [`token::Comma`]. - /// - /// [`token::Comma`]: struct@token::Comma - comma: Option, - - /// Interpolation arguments. - pub args: Punctuated, -} - -impl Parse for FmtAttribute { - fn parse(input: syn::parse::ParseStream<'_>) -> syn::Result { - let _ident: syn::Ident = input - .parse() - .map_err(|_| syn::Error::new(input.span(), "Expected 'fmt' argument"))?; - input - .parse::() - .map_err(|_| syn::Error::new(input.span(), "Expected '=' after 'fmt'"))?; - - let attribute = Self { - lit: input.parse()?, - comma: input - .peek(token::Comma) - .then(|| input.parse()) - .transpose()?, - args: input.parse_terminated::(FmtArgument::parse)?, - }; - println!( - "Parsed successfully!, {:?}\n parsed arguments: {}", - attribute.lit.token().to_string(), - attribute.args.to_token_stream(), - ); - Ok(attribute) - } -} - -impl ToTokens for FmtAttribute { - fn to_tokens(&self, tokens: &mut TokenStream) { - self.lit.to_tokens(tokens); - self.comma.to_tokens(tokens); - self.args.to_tokens(tokens); - } -} - -/// Representation of a [named parameter][1] (`identifier '=' expression`) in -/// in a [`FmtAttribute`]. -/// This should be used in `[display(fmt="", alias=alias, expr)]`. -/// [1]: https://doc.rust-lang.org/stable/std/fmt/index.html#named-parameters -pub struct FmtArgument { - /// `identifier =` [`Ident`]. - /// - /// [`Ident`]: struct@syn::Ident - pub alias: Option<(syn::Ident, token::Eq)>, - - /// `expression` [`Expr`]. - expr: Expr, -} - -impl FmtArgument { - /// Returns an `identifier` of the [named parameter][1]. 
- /// - /// [1]: https://doc.rust-lang.org/stable/std/fmt/index.html#named-parameters - fn alias(&self) -> Option<&syn::Ident> { - self.alias.as_ref().map(|(ident, _)| ident) - } -} - -impl Parse for FmtArgument { - fn parse(input: ParseStream) -> syn::Result { - Ok(Self { - alias: (input.peek(syn::Ident) && input.peek2(token::Eq)) - .then(|| Ok::<_, syn::Error>((input.parse()?, input.parse()?))) - .transpose()?, - expr: input.parse()?, - }) - } -} - -impl ToTokens for FmtArgument { - fn to_tokens(&self, tokens: &mut TokenStream) { - if let Some((ident, eq)) = &self.alias { - quote!(self . #ident).to_tokens(tokens); - eq.to_tokens(tokens); - } - self.expr.to_tokens(tokens) - } -} - -pub fn find_display_attribute(attrs: &[Attribute]) -> Option { - let display_attr = attrs.iter().find(|attr| attr.path.is_ident("display")); - - let attr: Option = if let Some(attr) = display_attr { - match attr.parse_args::() { - Ok(display_macro) => Some(display_macro), - Err(e) => { - e.to_compile_error(); - None - } - } - } else { - None - }; - attr -} diff --git a/tokenizers/display_derive/src/lib.rs b/tokenizers/display_derive/src/lib.rs deleted file mode 100644 index a60486f0c..000000000 --- a/tokenizers/display_derive/src/lib.rs +++ /dev/null @@ -1,131 +0,0 @@ -extern crate proc_macro; -use proc_macro::TokenStream; -use quote::quote; -use syn::{parse_macro_input, stringify_punct, DeriveInput}; - -mod fmt_parsing; -use fmt_parsing::{find_display_attribute, FmtAttribute}; - -#[proc_macro_derive(Display, attributes(display))] -pub fn display_derive(input: TokenStream) -> TokenStream { - // Parse the parsed_input tokens into a syntax tree - let parsed_input = parse_macro_input!(input as DeriveInput); - // Find the `display` attribute - let attr = find_display_attribute(&parsed_input.attrs); - // 1. If the attrs are not None, then we defer to this. - // Meaning we juste return quote!{ format!(#fmt, #attr)} - let ident = &parsed_input.ident; - - let body = { - // 2. We automatically parse - match &parsed_input.data { - syn::Data::Struct(s) => generate_fmt_impl_for_struct(s, ident, &attr), - syn::Data::Enum(e) => generate_fmt_impl_for_enum(e, ident), - syn::Data::Union(u) => { - let error = syn::Error::new_spanned(u.union_token, "Unions are not supported"); - return proc_macro::TokenStream::from(error.into_compile_error()); - } - } - }; - - let expanded = quote! { - impl std::fmt::Display for #ident { - fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result { - #body - } - } - }; - - println!("Generated body: \n{}\n", expanded); - TokenStream::from(expanded) -} - -fn generate_fmt_impl_for_struct( - data_struct: &syn::DataStruct, - ident: &syn::Ident, - attrs: &Option, -) -> proc_macro2::TokenStream { - let fields = &data_struct.fields; - // TODO I am stuck here for now hehe. - // Basically we need to produce the body that will be used. - // Generate field formatting expressions - // Will write more later on - let result_stream = quote! { write!(f, "{}", stringify!(#ident))}; - let field_formats: Vec<_> = fields - .iter() - .map(|f| { - let field_name = &f.ident; - let fmts = find_display_attribute(&f.attrs); - - if let Some(attr) = attrs { - if attr.args.is_empty() { - // If there is a prefix and no args, use fmts if it exists - if let Some(fmt) = fmts { - // Combine prefix and fmts - quote! { - write!(f, "{}{}", #fmt.lit.value(), #fmt.args.to_string())?; - } - } else { - // If no fmts, write just the field value - quote! 
{ - write!(f, "{}={}", stringify!(#field_name), self.#field_name)?; - } - } - } else { - // If there are args to the attribute, use attr.lit and attr.args exclusively - quote! { - write!(f, "{}{}", #attr.lit.value(), #attr.args.to_string())?; - } - } - } else { - // If there is no attribute, print the default - quote! { - write!(f, "{}={}", stringify!(#field_name), self.#field_name)?; - } - } - }) - .collect(); - - // Generate the final implementation of Display trait for the struct - quote! { - write!(f, "{}(", stringify!(#ident))?; - #field_formats - write!(f, ")") - } -} -fn generate_fmt_impl_for_enum( - data_enum: &syn::DataEnum, - ident: &syn::Ident, -) -> proc_macro2::TokenStream { - let arms = data_enum.variants.iter().map(|variant| { - let variant_name = &variant.ident; - let formatted_output = match &variant.fields { - syn::Fields::Unit => { - // Unit variant: just stringify the variant name - quote! { #ident::#variant_name => {write!(f, "{}", stringify!(#variant_name))?; }} - }, - syn::Fields::Unnamed(fields) if fields.unnamed.len() == 1 => { - // Tuple variant with one field - quote! { #ident::#variant_name(ref single) => {write!(f, "{}", single)?;} } - }, - syn::Fields::Named(fields) if fields.named.len() == 1 => { - // Tuple variant with one named field - let field_name = fields.named[0].ident.as_ref().unwrap(); // Assuming it's named - quote! { #ident::#variant_name{..}=>{ write!(f, "{}({})", stringify!(self.#field_name)?);} } - }, - _ => { - // Default case: stringify the variant name - quote! { write!(f, "{}", stringify!(#variant_name))?; } - } - }; - formatted_output - }); - - println!("printing ident: {}", ident.to_string()); - quote! { - match *self { - #(#arms)* - } - Ok(()) - } -} diff --git a/tokenizers/src/decoders/bpe.rs b/tokenizers/src/decoders/bpe.rs index 5636c6524..430086615 100644 --- a/tokenizers/src/decoders/bpe.rs +++ b/tokenizers/src/decoders/bpe.rs @@ -1,5 +1,5 @@ use crate::tokenizer::{Decoder, Result}; -use display_derive::Display; +use pyo3_special_method_derive::{Dict, Dir, Getattr, Repr, Str}; use serde::{Deserialize, Serialize}; #[derive(Deserialize, Clone, Debug, Serialize, Display)] /// Allows decoding Original BPE by joining all the tokens and then replacing diff --git a/tokenizers/src/decoders/byte_fallback.rs b/tokenizers/src/decoders/byte_fallback.rs index 8c88199de..04bd691bc 100644 --- a/tokenizers/src/decoders/byte_fallback.rs +++ b/tokenizers/src/decoders/byte_fallback.rs @@ -1,6 +1,6 @@ use crate::tokenizer::{Decoder, Result}; -use display_derive::Display; use monostate::MustBe; +use pyo3_special_method_derive::{Dict, Dir, Getattr, Repr, Str}; use serde::{Deserialize, Serialize}; #[derive(Deserialize, Clone, Debug, Serialize, Default, Display)] /// ByteFallback is a simple trick which converts tokens looking like `<0x61>` diff --git a/tokenizers/src/decoders/ctc.rs b/tokenizers/src/decoders/ctc.rs index f96e71f3e..bac298e8a 100644 --- a/tokenizers/src/decoders/ctc.rs +++ b/tokenizers/src/decoders/ctc.rs @@ -1,7 +1,7 @@ use crate::decoders::wordpiece; use crate::tokenizer::{Decoder, Result}; -use display_derive::Display; use itertools::Itertools; +use pyo3_special_method_derive::{Dict, Dir, Getattr, Repr, Str}; use serde::{Deserialize, Serialize}; #[derive(Debug, Clone, Serialize, Deserialize, Display)] diff --git a/tokenizers/src/decoders/fuse.rs b/tokenizers/src/decoders/fuse.rs index b91485eec..583fe4645 100644 --- a/tokenizers/src/decoders/fuse.rs +++ b/tokenizers/src/decoders/fuse.rs @@ -1,6 +1,6 @@ use crate::tokenizer::{Decoder, 
Result}; use monostate::MustBe; -use display_derive::Display; +use pyo3_special_method_derive::{Dict, Dir, Getattr, Repr, Str}; use serde::{Deserialize, Serialize}; #[derive(Clone, Debug, Serialize, Deserialize, Default, Display)] /// Fuse simply fuses all tokens into one big string. diff --git a/tokenizers/src/decoders/mod.rs b/tokenizers/src/decoders/mod.rs index 1bc6b62bb..83f8ea510 100644 --- a/tokenizers/src/decoders/mod.rs +++ b/tokenizers/src/decoders/mod.rs @@ -21,7 +21,7 @@ use crate::normalizers::replace::Replace; use crate::pre_tokenizers::byte_level::ByteLevel; use crate::pre_tokenizers::metaspace::Metaspace; use crate::{Decoder, Result}; -use display_derive::Display; +use pyo3_special_method_derive::{Dict, Dir, Getattr, Repr, Str}; use serde::{Deserialize, Serialize}; #[derive(Serialize, Deserialize, Clone, Debug, Display)] diff --git a/tokenizers/src/decoders/sequence.rs b/tokenizers/src/decoders/sequence.rs index ae9784f9c..f30e74510 100644 --- a/tokenizers/src/decoders/sequence.rs +++ b/tokenizers/src/decoders/sequence.rs @@ -1,7 +1,7 @@ use crate::decoders::DecoderWrapper; use crate::tokenizer::{Decoder, Result}; use crate::utils::macro_rules_attribute; -use display_derive::Display; +use pyo3_special_method_derive::{Dict, Dir, Getattr, Repr, Str}; use serde::{Deserialize, Serialize}; #[macro_rules_attribute(impl_serde_type!)] diff --git a/tokenizers/src/decoders/strip.rs b/tokenizers/src/decoders/strip.rs index 93d085a45..581bb4b48 100644 --- a/tokenizers/src/decoders/strip.rs +++ b/tokenizers/src/decoders/strip.rs @@ -1,6 +1,6 @@ use crate::tokenizer::{Decoder, Result}; -use display_derive::Display; +use pyo3_special_method_derive::{Dict, Dir, Getattr, Repr, Str}; use serde::{Deserialize, Serialize}; #[derive(Deserialize, Clone, Debug, Serialize, Default, Display)] /// Strip is a simple trick which converts tokens looking like `<0x61>` diff --git a/tokenizers/src/decoders/wordpiece.rs b/tokenizers/src/decoders/wordpiece.rs index c8bd57c06..acf293931 100644 --- a/tokenizers/src/decoders/wordpiece.rs +++ b/tokenizers/src/decoders/wordpiece.rs @@ -1,6 +1,6 @@ use crate::tokenizer::{Decoder, Result}; -use display_derive::Display; +use pyo3_special_method_derive::{Dict, Dir, Getattr, Repr, Str}; use serde::{Deserialize, Serialize}; #[derive(Deserialize, Clone, Debug, Serialize, Display)] /// The WordPiece decoder takes care of decoding a list of wordpiece tokens diff --git a/tokenizers/src/models/mod.rs b/tokenizers/src/models/mod.rs index 4e3b7d61b..59cb8146d 100644 --- a/tokenizers/src/models/mod.rs +++ b/tokenizers/src/models/mod.rs @@ -8,7 +8,7 @@ pub mod wordpiece; use std::collections::HashMap; use std::path::{Path, PathBuf}; -use display_derive::Display; +use pyo3_special_method_derive::{Dict, Dir, Getattr, Repr, Str}; use serde::{Deserialize, Serialize, Serializer}; use crate::models::bpe::{BpeTrainer, BPE}; diff --git a/tokenizers/src/models/unigram/model.rs b/tokenizers/src/models/unigram/model.rs index f00953cb7..06b244d5a 100644 --- a/tokenizers/src/models/unigram/model.rs +++ b/tokenizers/src/models/unigram/model.rs @@ -6,7 +6,7 @@ use super::{ use crate::tokenizer::{Model, Result, Token}; use crate::utils::cache::Cache; -use display_derive::Display; +use pyo3_special_method_derive::{Dict, Dir, Getattr, Repr, Str}; use std::collections::HashMap; use std::convert::TryInto; use std::fs::read_to_string; diff --git a/tokenizers/src/models/wordlevel/mod.rs b/tokenizers/src/models/wordlevel/mod.rs index 09739e07d..f6e85852d 100644 --- 
a/tokenizers/src/models/wordlevel/mod.rs +++ b/tokenizers/src/models/wordlevel/mod.rs @@ -1,6 +1,6 @@ use super::OrderedVocabIter; use crate::tokenizer::{Model, Result, Token}; -use display_derive::Display; +use pyo3_special_method_derive::{Dict, Dir, Getattr, Repr, Str}; use serde_json::Value; use std::collections::HashMap; use std::fs::File; diff --git a/tokenizers/src/models/wordpiece/mod.rs b/tokenizers/src/models/wordpiece/mod.rs index 8b8737e44..0f3247ee9 100644 --- a/tokenizers/src/models/wordpiece/mod.rs +++ b/tokenizers/src/models/wordpiece/mod.rs @@ -3,7 +3,7 @@ use crate::models::bpe::BPE; use crate::tokenizer::{Model, Result, Token}; -use display_derive::Display; +use pyo3_special_method_derive::{Dict, Dir, Getattr, Repr, Str}; use std::{ borrow::Cow, collections::HashMap, diff --git a/tokenizers/src/normalizers/bert.rs b/tokenizers/src/normalizers/bert.rs index 1e8e6ebf8..962393ef0 100644 --- a/tokenizers/src/normalizers/bert.rs +++ b/tokenizers/src/normalizers/bert.rs @@ -1,6 +1,6 @@ use crate::tokenizer::{NormalizedString, Normalizer, Result}; -use display_derive::Display; +use pyo3_special_method_derive::{Dict, Dir, Getattr, Repr, Str}; use serde::{Deserialize, Serialize}; use unicode_categories::UnicodeCategories; /// Checks whether a character is whitespace diff --git a/tokenizers/src/normalizers/mod.rs b/tokenizers/src/normalizers/mod.rs index fd097831b..0b1076de2 100644 --- a/tokenizers/src/normalizers/mod.rs +++ b/tokenizers/src/normalizers/mod.rs @@ -15,7 +15,7 @@ pub use crate::normalizers::unicode::{Nmt, NFC, NFD, NFKC, NFKD}; pub use crate::normalizers::utils::{Lowercase, Sequence}; use crate::{NormalizedString, Normalizer}; -use display_derive::Display; +use pyo3_special_method_derive::{Dict, Dir, Getattr, Repr, Str}; use serde::{Deserialize, Serialize}; /// Wrapper for known Normalizers. 
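[Editor's note — not part of the patch series.] The sweep above replaces the hand-rolled `display_derive::Display` import with the `pyo3_special_method_derive` derives across the decoder, model, and normalizer modules; the end goal throughout the series is a readable `ClassName(field=value, ...)` rendering of each component. As a rough, hand-written sketch of that target output shape (the struct and field names below are illustrative stand-ins, not taken from the patch):

```rust
use std::fmt;

// Hypothetical stand-in for a tokenizer component such as a decoder.
struct StripDecoder {
    content: char,
    start: usize,
    stop: usize,
}

// Hand-written equivalent of the `ClassName(field=value, ...)` output that the
// derive macros in this series are meant to generate automatically.
impl fmt::Display for StripDecoder {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        write!(
            f,
            "StripDecoder(content={}, start={}, stop={})",
            self.content, self.start, self.stop
        )
    }
}

fn main() {
    let d = StripDecoder { content: '▁', start: 1, stop: 0 };
    // Prints: StripDecoder(content=▁, start=1, stop=0)
    println!("{}", d);
}
```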
diff --git a/tokenizers/src/normalizers/prepend.rs b/tokenizers/src/normalizers/prepend.rs index a9e6ded60..5b5f2379e 100644 --- a/tokenizers/src/normalizers/prepend.rs +++ b/tokenizers/src/normalizers/prepend.rs @@ -1,5 +1,5 @@ use crate::tokenizer::{NormalizedString, Normalizer, Result}; -use display_derive::Display; +use pyo3_special_method_derive::{Dict, Dir, Getattr, Repr, Str}; use serde::{Deserialize, Serialize}; #[derive(Clone, Debug, Deserialize, Serialize, Display)] diff --git a/tokenizers/src/normalizers/replace.rs b/tokenizers/src/normalizers/replace.rs index 41a316942..6c3f44127 100644 --- a/tokenizers/src/normalizers/replace.rs +++ b/tokenizers/src/normalizers/replace.rs @@ -2,7 +2,7 @@ use crate::tokenizer::pattern::Pattern; use crate::tokenizer::Decoder; use crate::tokenizer::{NormalizedString, Normalizer, Result}; use crate::utils::SysRegex; -use display_derive::Display; +use pyo3_special_method_derive::{Dict, Dir, Getattr, Repr, Str}; use serde::{Deserialize, Serialize}; /// Represents the different patterns that `Replace` can use #[derive(Debug, Clone, PartialEq, Serialize, Deserialize, Eq)] diff --git a/tokenizers/src/normalizers/strip.rs b/tokenizers/src/normalizers/strip.rs index ef298cc03..abbb53976 100644 --- a/tokenizers/src/normalizers/strip.rs +++ b/tokenizers/src/normalizers/strip.rs @@ -1,6 +1,6 @@ use crate::tokenizer::{NormalizedString, Normalizer, Result}; use crate::utils::macro_rules_attribute; -use display_derive::Display; +use pyo3_special_method_derive::{Dict, Dir, Getattr, Repr, Str}; use serde::{Deserialize, Serialize}; use unicode_normalization_alignments::char::is_combining_mark; #[derive(Copy, Clone, Debug, Deserialize, Serialize, Display)] diff --git a/tokenizers/src/normalizers/unicode.rs b/tokenizers/src/normalizers/unicode.rs index 8cdfcf1dd..05a4df678 100644 --- a/tokenizers/src/normalizers/unicode.rs +++ b/tokenizers/src/normalizers/unicode.rs @@ -1,6 +1,6 @@ use crate::tokenizer::{NormalizedString, Normalizer, Result}; use crate::utils::macro_rules_attribute; -use display_derive::Display; +use pyo3_special_method_derive::{Dict, Dir, Getattr, Repr, Str}; #[derive(Default, Copy, Clone, Debug, Display)] #[macro_rules_attribute(impl_serde_type!)] diff --git a/tokenizers/src/normalizers/utils.rs b/tokenizers/src/normalizers/utils.rs index c241fcc4d..14e692e69 100644 --- a/tokenizers/src/normalizers/utils.rs +++ b/tokenizers/src/normalizers/utils.rs @@ -3,8 +3,8 @@ use serde::{Deserialize, Serialize}; use crate::normalizers::NormalizerWrapper; use crate::tokenizer::{NormalizedString, Normalizer, Result}; use crate::utils::macro_rules_attribute; -use display_derive::Display; #[derive(Clone, Deserialize, Debug, Serialize, Display)] +use pyo3_special_method_derive::{Dict, Dir, Getattr, Repr, Str}; #[display( fmt = "Sequence([{}])", "normalizers.iter().fold(String::new(), |mut acc, d| { diff --git a/tokenizers/src/pre_tokenizers/bert.rs b/tokenizers/src/pre_tokenizers/bert.rs index 3551a6a63..6e9efd3cb 100644 --- a/tokenizers/src/pre_tokenizers/bert.rs +++ b/tokenizers/src/pre_tokenizers/bert.rs @@ -1,6 +1,6 @@ use crate::tokenizer::{PreTokenizedString, PreTokenizer, Result, SplitDelimiterBehavior}; use crate::utils::macro_rules_attribute; -use display_derive::Display; +use pyo3_special_method_derive::{Dict, Dir, Getattr, Repr, Str}; use unicode_categories::UnicodeCategories; fn is_bert_punc(x: char) -> bool { diff --git a/tokenizers/src/pre_tokenizers/byte_level.rs b/tokenizers/src/pre_tokenizers/byte_level.rs index 0693449de..3232083e1 100644 --- 
a/tokenizers/src/pre_tokenizers/byte_level.rs +++ b/tokenizers/src/pre_tokenizers/byte_level.rs @@ -6,7 +6,7 @@ use crate::tokenizer::{ }; use crate::utils::macro_rules_attribute; use crate::utils::SysRegex; -use display_derive::Display; +use pyo3_special_method_derive::{Dict, Dir, Getattr, Repr, Str}; use serde::{Deserialize, Serialize}; /// Converts bytes to unicode characters. diff --git a/tokenizers/src/pre_tokenizers/delimiter.rs b/tokenizers/src/pre_tokenizers/delimiter.rs index 37428f52c..d857f190b 100644 --- a/tokenizers/src/pre_tokenizers/delimiter.rs +++ b/tokenizers/src/pre_tokenizers/delimiter.rs @@ -1,4 +1,4 @@ -use display_derive::Display; +use pyo3_special_method_derive::{Dict, Dir, Getattr, Repr, Str}; use serde::{Deserialize, Serialize}; use crate::tokenizer::{PreTokenizedString, PreTokenizer, Result, SplitDelimiterBehavior}; diff --git a/tokenizers/src/pre_tokenizers/digits.rs b/tokenizers/src/pre_tokenizers/digits.rs index 393817157..e943b8e5e 100644 --- a/tokenizers/src/pre_tokenizers/digits.rs +++ b/tokenizers/src/pre_tokenizers/digits.rs @@ -1,4 +1,4 @@ -use display_derive::Display; +use pyo3_special_method_derive::{Dict, Dir, Getattr, Repr, Str}; use serde::{Deserialize, Serialize}; use crate::tokenizer::{PreTokenizedString, PreTokenizer, Result, SplitDelimiterBehavior}; diff --git a/tokenizers/src/pre_tokenizers/metaspace.rs b/tokenizers/src/pre_tokenizers/metaspace.rs index 96fa66346..3c38f1675 100644 --- a/tokenizers/src/pre_tokenizers/metaspace.rs +++ b/tokenizers/src/pre_tokenizers/metaspace.rs @@ -1,5 +1,5 @@ use crate::tokenizer::{Decoder, PreTokenizedString, PreTokenizer, Result, SplitDelimiterBehavior}; -use display_derive::Display; +use pyo3_special_method_derive::{Dict, Dir, Getattr, Repr, Str}; use serde::{de, Deserialize, Deserializer, Serialize}; /// Enum representing options for the metaspace prepending scheme. 
#[derive(Debug, Clone, PartialEq, Serialize, Eq, Deserialize, Copy, Display)] diff --git a/tokenizers/src/pre_tokenizers/mod.rs b/tokenizers/src/pre_tokenizers/mod.rs index 08166b355..a036dccc9 100644 --- a/tokenizers/src/pre_tokenizers/mod.rs +++ b/tokenizers/src/pre_tokenizers/mod.rs @@ -22,7 +22,7 @@ use crate::pre_tokenizers::split::Split; use crate::pre_tokenizers::unicode_scripts::UnicodeScripts; use crate::pre_tokenizers::whitespace::{Whitespace, WhitespaceSplit}; use crate::{PreTokenizedString, PreTokenizer}; -use display_derive::Display; +use pyo3_special_method_derive::{Dict, Dir, Getattr, Repr, Str}; #[derive(Deserialize, Serialize, Clone, Debug, PartialEq, Display)] #[display(fmt="pre_tokenizers.{}")] diff --git a/tokenizers/src/pre_tokenizers/punctuation.rs b/tokenizers/src/pre_tokenizers/punctuation.rs index b1cb01323..664154e5a 100644 --- a/tokenizers/src/pre_tokenizers/punctuation.rs +++ b/tokenizers/src/pre_tokenizers/punctuation.rs @@ -1,4 +1,4 @@ -use display_derive::Display; +use pyo3_special_method_derive::{Dict, Dir, Getattr, Repr, Str}; use serde::{Deserialize, Serialize}; use crate::tokenizer::{PreTokenizedString, PreTokenizer, Result, SplitDelimiterBehavior}; diff --git a/tokenizers/src/pre_tokenizers/sequence.rs b/tokenizers/src/pre_tokenizers/sequence.rs index 94c30dcd7..de9e14e8a 100644 --- a/tokenizers/src/pre_tokenizers/sequence.rs +++ b/tokenizers/src/pre_tokenizers/sequence.rs @@ -1,7 +1,7 @@ use crate::pre_tokenizers::PreTokenizerWrapper; use crate::tokenizer::{PreTokenizedString, PreTokenizer, Result}; use crate::utils::macro_rules_attribute; -use display_derive::Display; +use pyo3_special_method_derive::{Dict, Dir, Getattr, Repr, Str}; use serde::{Deserialize, Serialize}; #[macro_rules_attribute(impl_serde_type!)] diff --git a/tokenizers/src/pre_tokenizers/split.rs b/tokenizers/src/pre_tokenizers/split.rs index a3128bdf2..82a7ad86b 100644 --- a/tokenizers/src/pre_tokenizers/split.rs +++ b/tokenizers/src/pre_tokenizers/split.rs @@ -2,7 +2,7 @@ use crate::tokenizer::{ pattern::Invert, PreTokenizedString, PreTokenizer, Result, SplitDelimiterBehavior, }; use crate::utils::SysRegex; -use display_derive::Display; +use pyo3_special_method_derive::{Dict, Dir, Getattr, Repr, Str}; use serde::{Deserialize, Deserializer, Serialize}; /// Represents the different patterns that `Split` can use diff --git a/tokenizers/src/pre_tokenizers/unicode_scripts/pre_tokenizer.rs b/tokenizers/src/pre_tokenizers/unicode_scripts/pre_tokenizer.rs index cffe5c47e..b693f2f1c 100644 --- a/tokenizers/src/pre_tokenizers/unicode_scripts/pre_tokenizer.rs +++ b/tokenizers/src/pre_tokenizers/unicode_scripts/pre_tokenizer.rs @@ -1,4 +1,4 @@ -use display_derive::Display; +use pyo3_special_method_derive::{Dict, Dir, Getattr, Repr, Str}; use crate::pre_tokenizers::unicode_scripts::scripts::{get_script, Script}; use crate::tokenizer::{normalizer::Range, PreTokenizedString, PreTokenizer, Result}; diff --git a/tokenizers/src/pre_tokenizers/whitespace.rs b/tokenizers/src/pre_tokenizers/whitespace.rs index 12dd60346..40b5fba04 100644 --- a/tokenizers/src/pre_tokenizers/whitespace.rs +++ b/tokenizers/src/pre_tokenizers/whitespace.rs @@ -1,4 +1,4 @@ -use display_derive::Display; +use pyo3_special_method_derive::{Dict, Dir, Getattr, Repr, Str}; use regex::Regex; use crate::tokenizer::{ diff --git a/tokenizers/src/processors/bert.rs b/tokenizers/src/processors/bert.rs index 5a4ee43aa..aed0b7e68 100644 --- a/tokenizers/src/processors/bert.rs +++ b/tokenizers/src/processors/bert.rs @@ -1,5 +1,5 @@ use 
crate::tokenizer::{Encoding, PostProcessor, Result}; -use display_derive::Display; +use pyo3_special_method_derive::{Dict, Dir, Getattr, Repr, Str}; use serde::{Deserialize, Serialize}; use std::collections::HashMap; use std::iter::FromIterator; diff --git a/tokenizers/src/processors/mod.rs b/tokenizers/src/processors/mod.rs index 7e7e50c10..6beecdc64 100644 --- a/tokenizers/src/processors/mod.rs +++ b/tokenizers/src/processors/mod.rs @@ -12,7 +12,7 @@ use crate::processors::roberta::RobertaProcessing; use crate::processors::sequence::Sequence; use crate::processors::template::TemplateProcessing; use crate::{Encoding, PostProcessor, Result}; -use display_derive::Display; +use pyo3_special_method_derive::{Dict, Dir, Getattr, Repr, Str}; use serde::{Deserialize, Serialize}; #[derive(Serialize, Deserialize, PartialEq, Debug, Clone, Eq, Display)] diff --git a/tokenizers/src/processors/roberta.rs b/tokenizers/src/processors/roberta.rs index b0d40c295..a7e79eaa9 100644 --- a/tokenizers/src/processors/roberta.rs +++ b/tokenizers/src/processors/roberta.rs @@ -1,6 +1,6 @@ use crate::processors::byte_level::process_offsets; use crate::tokenizer::{Encoding, PostProcessor, Result}; -use display_derive::Display; +use pyo3_special_method_derive::{Dict, Dir, Getattr, Repr, Str}; use serde::{Deserialize, Serialize}; use std::collections::HashMap; use std::iter::FromIterator; diff --git a/tokenizers/src/processors/sequence.rs b/tokenizers/src/processors/sequence.rs index d68bfc513..1c86572de 100644 --- a/tokenizers/src/processors/sequence.rs +++ b/tokenizers/src/processors/sequence.rs @@ -1,7 +1,7 @@ use crate::processors::PostProcessorWrapper; use crate::tokenizer::{Encoding, PostProcessor, Result}; use crate::utils::macro_rules_attribute; -use display_derive::Display; +use pyo3_special_method_derive::{Dict, Dir, Getattr, Repr, Str}; use serde::{Deserialize, Serialize}; #[macro_rules_attribute(impl_serde_type!)] #[derive(Clone, Debug, PartialEq, Eq, Display)] diff --git a/tokenizers/src/processors/template.rs b/tokenizers/src/processors/template.rs index 0cbe04ab8..58e9edc63 100644 --- a/tokenizers/src/processors/template.rs +++ b/tokenizers/src/processors/template.rs @@ -56,8 +56,8 @@ //! [`TemplateProcessing`]: struct.TemplateProcessing.html //! 
use crate::{Encoding, PostProcessor, Result}; -use display_derive::Display; use itertools::Itertools; +use pyo3_special_method_derive::{Dict, Dir, Getattr, Repr, Str}; use serde::{Deserialize, Serialize}; use std::collections::{HashMap, HashSet}; use std::convert::{TryFrom, TryInto}; diff --git a/tokenizers/src/tokenizer/added_vocabulary.rs b/tokenizers/src/tokenizer/added_vocabulary.rs index c8c147b44..2f3b2702d 100644 --- a/tokenizers/src/tokenizer/added_vocabulary.rs +++ b/tokenizers/src/tokenizer/added_vocabulary.rs @@ -2,7 +2,7 @@ use super::{ normalizer::Range, Model, NormalizedString, Normalizer, Offsets, PreTokenizedString, Token, }; use aho_corasick::{AhoCorasick, AhoCorasickBuilder, MatchKind}; -use display_derive::Display; +use pyo3_special_method_derive::{Dict, Dir, Getattr, Repr, Str}; use regex::Regex; use serde::{ser::SerializeSeq, Deserialize, Serialize, Serializer}; use std::collections::{HashMap, HashSet}; diff --git a/tokenizers/src/tokenizer/normalizer.rs b/tokenizers/src/tokenizer/normalizer.rs index ff44bfe56..507606300 100644 --- a/tokenizers/src/tokenizer/normalizer.rs +++ b/tokenizers/src/tokenizer/normalizer.rs @@ -1,6 +1,6 @@ use crate::pattern::Pattern; use crate::{Offsets, Result}; -use display_derive::Display; +use pyo3_special_method_derive::{Dict, Dir, Getattr, Repr, Str}; use serde::{Deserialize, Serialize}; use std::ops::{Bound, RangeBounds}; use unicode_normalization_alignments::UnicodeNormalization; diff --git a/tokenizers/src/utils/padding.rs b/tokenizers/src/utils/padding.rs index 60b30786a..35a9560fa 100644 --- a/tokenizers/src/utils/padding.rs +++ b/tokenizers/src/utils/padding.rs @@ -1,6 +1,6 @@ use crate::parallelism::*; use crate::tokenizer::{Encoding, Result}; -use display_derive::Display; +use pyo3_special_method_derive::{Dict, Dir, Getattr, Repr, Str}; use serde::{Deserialize, Serialize}; /// The various possible padding directions. 
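[Editor's note — not part of the patch series.] The wrapper enums touched in this sweep (`DecoderWrapper`, `NormalizerWrapper`, `PreTokenizerWrapper`) carry attributes like `#[display(fmt = "normalizers.{}")]` elsewhere in the series. For orientation, here is a minimal hand-written sketch of what such an attribute boils down to for a wrapper enum: prefix the module path, then delegate to the wrapped variant's own `Display`. The variant types and fields are simplified stand-ins, not the crate's real definitions:

```rust
use std::fmt;

// Simplified stand-ins for two normalizers; the real wrapper holds the
// crate's actual normalizer types.
struct Lowercase;
struct Prepend {
    prepend: String,
}

impl fmt::Display for Lowercase {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        write!(f, "Lowercase()")
    }
}

impl fmt::Display for Prepend {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        write!(f, "Prepend(prepend={})", self.prepend)
    }
}

// Wrapper enum: prepend the module path, then defer to the variant.
enum NormalizerWrapper {
    Lowercase(Lowercase),
    Prepend(Prepend),
}

impl fmt::Display for NormalizerWrapper {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        match self {
            NormalizerWrapper::Lowercase(n) => write!(f, "normalizers.{}", n),
            NormalizerWrapper::Prepend(n) => write!(f, "normalizers.{}", n),
        }
    }
}

fn main() {
    let n = NormalizerWrapper::Prepend(Prepend { prepend: "▁".into() });
    // Prints: normalizers.Prepend(prepend=▁)
    println!("{}", n);
}
```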
diff --git a/tokenizers/src/utils/truncation.rs b/tokenizers/src/utils/truncation.rs index f73208a51..95dfb5fa2 100644 --- a/tokenizers/src/utils/truncation.rs +++ b/tokenizers/src/utils/truncation.rs @@ -1,5 +1,5 @@ use crate::tokenizer::{Encoding, Result}; -use display_derive::Display; +use pyo3_special_method_derive::{Dict, Dir, Getattr, Repr, Str}; use serde::{Deserialize, Serialize}; use std::cmp; use std::mem; From 3aa0138f93342a9ac36c54da66016417cf2b07b0 Mon Sep 17 00:00:00 2001 From: Arthur Zucker Date: Fri, 19 Jul 2024 21:09:14 +0200 Subject: [PATCH 74/94] derive auto display --- bindings/python/Cargo.toml | 3 ++- bindings/python/src/decoders.rs | 10 ++++---- bindings/python/src/models.rs | 5 ++-- bindings/python/src/normalizers.rs | 10 ++++---- bindings/python/src/pre_tokenizers.rs | 23 ++++--------------- bindings/python/src/processors.rs | 4 ++-- bindings/python/src/tokenizer.rs | 4 ++-- tokenizers/src/decoders/bpe.rs | 4 ++-- tokenizers/src/decoders/byte_fallback.rs | 6 ++--- tokenizers/src/decoders/ctc.rs | 4 ++-- tokenizers/src/decoders/fuse.rs | 6 ++--- tokenizers/src/decoders/mod.rs | 6 ++--- tokenizers/src/decoders/sequence.rs | 4 ++-- tokenizers/src/decoders/strip.rs | 4 ++-- tokenizers/src/decoders/wordpiece.rs | 4 ++-- tokenizers/src/models/bpe/model.rs | 2 +- tokenizers/src/models/mod.rs | 4 ++-- tokenizers/src/models/unigram/model.rs | 17 ++------------ tokenizers/src/models/wordlevel/mod.rs | 16 +++---------- tokenizers/src/models/wordpiece/mod.rs | 18 ++++----------- tokenizers/src/normalizers/bert.rs | 11 ++------- tokenizers/src/normalizers/mod.rs | 8 +++---- tokenizers/src/normalizers/prepend.rs | 4 ++-- tokenizers/src/normalizers/replace.rs | 11 +++------ tokenizers/src/normalizers/strip.rs | 6 ++--- tokenizers/src/normalizers/unicode.rs | 12 +++++----- tokenizers/src/normalizers/utils.rs | 16 +++---------- tokenizers/src/pre_tokenizers/bert.rs | 4 ++-- tokenizers/src/pre_tokenizers/byte_level.rs | 4 ++-- tokenizers/src/pre_tokenizers/delimiter.rs | 4 ++-- tokenizers/src/pre_tokenizers/digits.rs | 4 ++-- tokenizers/src/pre_tokenizers/metaspace.rs | 13 ++++------- tokenizers/src/pre_tokenizers/mod.rs | 6 ++--- tokenizers/src/pre_tokenizers/punctuation.rs | 4 ++-- tokenizers/src/pre_tokenizers/sequence.rs | 14 ++--------- tokenizers/src/pre_tokenizers/split.rs | 13 +++-------- .../unicode_scripts/pre_tokenizer.rs | 4 ++-- tokenizers/src/pre_tokenizers/whitespace.rs | 6 ++--- tokenizers/src/processors/bert.rs | 5 ++-- tokenizers/src/processors/mod.rs | 6 ++--- tokenizers/src/processors/roberta.rs | 11 ++------- tokenizers/src/processors/sequence.rs | 14 ++--------- tokenizers/src/processors/template.rs | 5 ++-- tokenizers/src/tokenizer/added_vocabulary.rs | 11 +++++---- tokenizers/src/tokenizer/normalizer.rs | 4 ++-- tokenizers/src/utils/padding.rs | 18 ++++----------- tokenizers/src/utils/truncation.rs | 8 +++---- 47 files changed, 131 insertions(+), 249 deletions(-) diff --git a/bindings/python/Cargo.toml b/bindings/python/Cargo.toml index 14050874d..1b0fde555 100644 --- a/bindings/python/Cargo.toml +++ b/bindings/python/Cargo.toml @@ -14,12 +14,13 @@ serde = { version = "1.0", features = [ "rc", "derive" ]} serde_json = "1.0" libc = "0.2" env_logger = "0.11" -pyo3 = { version = "0.21" } numpy = "0.21" ndarray = "0.15" onig = { version = "6.4", default-features = false } itertools = "0.12" derive_more = "0.99.17" +pyo3 = { version = "0.21.2", features = ["multiple-pymethods"] } +pyo3_special_method_derive = "0.3" [dependencies.tokenizers] path = 
"../../tokenizers" diff --git a/bindings/python/src/decoders.rs b/bindings/python/src/decoders.rs index 3692fd9f0..6944e8c74 100644 --- a/bindings/python/src/decoders.rs +++ b/bindings/python/src/decoders.rs @@ -2,7 +2,7 @@ use std::sync::{Arc, RwLock}; use crate::pre_tokenizers::from_string; use crate::utils::PyPattern; -use derive_more::Display; +use pyo3_special_method_derive::AutoDisplay; use pyo3::exceptions; use pyo3::prelude::*; use pyo3::types::*; @@ -29,7 +29,7 @@ use super::error::ToPyResult; /// This class is not supposed to be instantiated directly. Instead, any implementation of /// a Decoder will return an instance of this class when instantiated. #[pyclass(dict, module = "tokenizers.decoders", name = "Decoder", subclass)] -#[derive(Clone, Deserialize, Serialize, Display)] +#[derive(Clone, Deserialize, Serialize, AutoDisplay)] pub struct PyDecoder { #[serde(flatten)] pub(crate) decoder: PyDecoderWrapper, @@ -487,7 +487,7 @@ impl PySequenceDecoder { } } -#[derive(Clone, Display)] +#[derive(Clone, AutoDisplay)] pub(crate) struct CustomDecoder { pub inner: PyObject, } @@ -540,12 +540,10 @@ impl<'de> Deserialize<'de> for CustomDecoder { } } -#[derive(Clone, Deserialize, Serialize, Display)] +#[derive(Clone, Deserialize, Serialize, AutoDisplay)] #[serde(untagged)] pub(crate) enum PyDecoderWrapper { - #[display(fmt = "{}", "_0.as_ref().read().unwrap().inner")] Custom(Arc>), - #[display(fmt = "{}", "_0.as_ref().read().unwrap()")] Wrapped(Arc>), } diff --git a/bindings/python/src/models.rs b/bindings/python/src/models.rs index cdef40735..33c370fd5 100644 --- a/bindings/python/src/models.rs +++ b/bindings/python/src/models.rs @@ -5,7 +5,7 @@ use std::sync::{Arc, RwLock}; use super::error::{deprecation_warning, ToPyResult}; use crate::token::PyToken; use crate::trainers::PyTrainer; -use derive_more::Display; +use pyo3_special_method_derive::AutoDisplay; use pyo3::exceptions; use pyo3::prelude::*; use pyo3::types::*; @@ -25,8 +25,7 @@ use tokenizers as tk; /// /// This class cannot be constructed directly. Please use one of the concrete models. #[pyclass(module = "tokenizers.models", name = "Model", subclass)] -#[derive(Clone, Serialize, Deserialize, Display)] -#[display(fmt = "{}", "model.as_ref().read().unwrap()")] +#[derive(Clone, Serialize, Deserialize, AutoDisplay)] pub struct PyModel { #[serde(flatten)] pub model: Arc>, diff --git a/bindings/python/src/normalizers.rs b/bindings/python/src/normalizers.rs index 3bc1a6e21..f21642718 100644 --- a/bindings/python/src/normalizers.rs +++ b/bindings/python/src/normalizers.rs @@ -2,7 +2,7 @@ use std::sync::{Arc, RwLock}; use crate::error::ToPyResult; use crate::utils::{PyNormalizedString, PyNormalizedStringRefMut, PyPattern}; -use derive_more::Display; +use pyo3_special_method_derive::AutoDisplay; use pyo3::exceptions; use pyo3::prelude::*; use pyo3::types::*; @@ -43,7 +43,7 @@ impl PyNormalizedStringMut<'_> { /// This class is not supposed to be instantiated directly. Instead, any implementation of a /// Normalizer will return an instance of this class when instantiated. 
#[pyclass(dict, module = "tokenizers.normalizers", name = "Normalizer", subclass)] -#[derive(Clone, Serialize, Deserialize, Display, Debug)] +#[derive(Clone, Serialize, Deserialize, AutoDisplay, Debug)] pub struct PyNormalizer { #[serde(flatten)] pub(crate) normalizer: PyNormalizerTypeWrapper, @@ -505,7 +505,7 @@ impl PyReplace { } } -#[derive(Debug, Clone, Display)] +#[derive(Debug, Clone, AutoDisplay)] pub(crate) struct CustomNormalizer { inner: PyObject, } @@ -548,12 +548,10 @@ impl<'de> Deserialize<'de> for CustomNormalizer { } } -#[derive(Debug, Clone, Deserialize, Display)] +#[derive(Debug, Clone, Deserialize, AutoDisplay)] #[serde(untagged)] pub(crate) enum PyNormalizerWrapper { - #[display(fmt = "{}", "_0.inner")] Custom(CustomNormalizer), - #[display(fmt = "{}", "_0")] Wrapped(NormalizerWrapper), } diff --git a/bindings/python/src/pre_tokenizers.rs b/bindings/python/src/pre_tokenizers.rs index 3c1366c75..b71d2dbb7 100644 --- a/bindings/python/src/pre_tokenizers.rs +++ b/bindings/python/src/pre_tokenizers.rs @@ -23,7 +23,7 @@ use tokenizers as tk; use super::error::ToPyResult; use super::utils::*; -use derive_more::Display; +use pyo3_special_method_derive::AutoDisplay; /// Base class for all pre-tokenizers /// /// This class is not supposed to be instantiated directly. Instead, any implementation of a @@ -34,7 +34,7 @@ use derive_more::Display; name = "PreTokenizer", subclass )] -#[derive(Clone, Serialize, Deserialize, Display)] +#[derive(Clone, Serialize, Deserialize, Str, Repr, Dir, Dict)] pub struct PyPreTokenizer { #[serde(flatten)] pub(crate) pretok: PyPreTokenizerTypeWrapper, @@ -595,7 +595,7 @@ impl PyUnicodeScripts { } } -#[derive(Clone, Display)] +#[derive(Clone, AutoDisplay)] pub(crate) struct CustomPreTokenizer { inner: PyObject, } @@ -639,7 +639,7 @@ impl<'de> Deserialize<'de> for CustomPreTokenizer { } } -#[derive(Clone, Deserialize, Display)] +#[derive(Clone, Deserialize, AutoDisplay)] #[serde(untagged)] pub(crate) enum PyPreTokenizerWrapper { Custom(CustomPreTokenizer), @@ -658,23 +658,10 @@ impl Serialize for PyPreTokenizerWrapper { } } -#[derive(Clone, Deserialize, Display)] +#[derive(Clone, Deserialize, AutoDisplay)] #[serde(untagged)] pub(crate) enum PyPreTokenizerTypeWrapper { - #[display( - fmt = "[{}]", - "_0.iter() - .map(|d| d.as_ref().read().unwrap().to_string()) - .fold(String::new(), |mut acc, s| { - if !acc.is_empty() { - acc.push_str(\", \"); - } - acc.push_str(&s); - acc - })" - )] Sequence(Vec>>), - #[display(fmt = "{}", "_0.as_ref().read().unwrap()")] Single(Arc>), } diff --git a/bindings/python/src/processors.rs b/bindings/python/src/processors.rs index 130440b55..3341e68b7 100644 --- a/bindings/python/src/processors.rs +++ b/bindings/python/src/processors.rs @@ -3,7 +3,7 @@ use std::sync::Arc; use crate::encoding::PyEncoding; use crate::error::ToPyResult; -use derive_more::Display; +use pyo3_special_method_derive::AutoDisplay; use pyo3::exceptions; use pyo3::prelude::*; use pyo3::types::*; @@ -27,7 +27,7 @@ use tokenizers as tk; name = "PostProcessor", subclass )] -#[derive(Clone, Deserialize, Serialize, Display)] +#[derive(Clone, Deserialize, Serialize, AutoDisplay)] pub struct PyPostProcessor { #[serde(flatten)] pub processor: Arc, diff --git a/bindings/python/src/tokenizer.rs b/bindings/python/src/tokenizer.rs index f5c04d4d6..5976f6214 100644 --- a/bindings/python/src/tokenizer.rs +++ b/bindings/python/src/tokenizer.rs @@ -10,7 +10,7 @@ use super::pre_tokenizers::PyPreTokenizer; use super::trainers::PyTrainer; use 
crate::processors::PyPostProcessor; use crate::utils::{MaybeSizedIterator, PyBufferedIterator}; -use derive_more::Display; +use pyo3_special_method_derive::AutoDisplay; use numpy::{npyffi, PyArray1}; use pyo3::class::basic::CompareOp; use pyo3::exceptions; @@ -462,7 +462,7 @@ type Tokenizer = TokenizerImpl` /// to pure bytes, and attempts to make them into a string. If the tokens /// cannot be decoded you will get � instead for each inconvertable byte token #[non_exhaustive] -#[display(fmt = "ByteFallback")] +#[auto_display(fmt = "ByteFallback")] pub struct ByteFallback { #[serde(rename = "type")] type_: MustBe!("ByteFallback"), diff --git a/tokenizers/src/decoders/ctc.rs b/tokenizers/src/decoders/ctc.rs index bac298e8a..d53fc003a 100644 --- a/tokenizers/src/decoders/ctc.rs +++ b/tokenizers/src/decoders/ctc.rs @@ -1,10 +1,10 @@ use crate::decoders::wordpiece; use crate::tokenizer::{Decoder, Result}; use itertools::Itertools; -use pyo3_special_method_derive::{Dict, Dir, Getattr, Repr, Str}; +use pyo3_special_method_derive::AutoDisplay; use serde::{Deserialize, Serialize}; -#[derive(Debug, Clone, Serialize, Deserialize, Display)] +#[derive(Debug, Clone, Serialize, Deserialize, AutoDisplay)] /// The CTC (Connectionist Temporal Classification) decoder takes care /// of sanitizing a list of inputs token. /// Due to some alignement problem the output of some models can come diff --git a/tokenizers/src/decoders/fuse.rs b/tokenizers/src/decoders/fuse.rs index 583fe4645..a75977493 100644 --- a/tokenizers/src/decoders/fuse.rs +++ b/tokenizers/src/decoders/fuse.rs @@ -1,14 +1,14 @@ use crate::tokenizer::{Decoder, Result}; use monostate::MustBe; -use pyo3_special_method_derive::{Dict, Dir, Getattr, Repr, Str}; +use pyo3_special_method_derive::AutoDisplay; use serde::{Deserialize, Serialize}; -#[derive(Clone, Debug, Serialize, Deserialize, Default, Display)] +#[derive(Clone, Debug, Serialize, Deserialize, Default, AutoDisplay)] /// Fuse simply fuses all tokens into one big string. 
/// It's usually the last decoding step anyway, but this /// decoder exists incase some decoders need to happen after that /// step #[non_exhaustive] -#[display(fmt = "Fuse")] +#[auto_display(fmt = "Fuse")] pub struct Fuse { #[serde(rename = "type")] type_: MustBe!("Fuse"), diff --git a/tokenizers/src/decoders/mod.rs b/tokenizers/src/decoders/mod.rs index 83f8ea510..233c33e30 100644 --- a/tokenizers/src/decoders/mod.rs +++ b/tokenizers/src/decoders/mod.rs @@ -21,11 +21,11 @@ use crate::normalizers::replace::Replace; use crate::pre_tokenizers::byte_level::ByteLevel; use crate::pre_tokenizers::metaspace::Metaspace; use crate::{Decoder, Result}; -use pyo3_special_method_derive::{Dict, Dir, Getattr, Repr, Str}; +use pyo3_special_method_derive::AutoDisplay; use serde::{Deserialize, Serialize}; -#[derive(Serialize, Deserialize, Clone, Debug, Display)] -#[display(fmt="decoders.{}")] +#[derive(Serialize, Deserialize, Clone, Debug, AutoDisplay)] +#[auto_display(fmt="decoders.{}")] #[serde(untagged)] pub enum DecoderWrapper { BPE(BPEDecoder), diff --git a/tokenizers/src/decoders/sequence.rs b/tokenizers/src/decoders/sequence.rs index f30e74510..20b48aa3a 100644 --- a/tokenizers/src/decoders/sequence.rs +++ b/tokenizers/src/decoders/sequence.rs @@ -1,11 +1,11 @@ use crate::decoders::DecoderWrapper; use crate::tokenizer::{Decoder, Result}; use crate::utils::macro_rules_attribute; -use pyo3_special_method_derive::{Dict, Dir, Getattr, Repr, Str}; +use pyo3_special_method_derive::AutoDisplay; use serde::{Deserialize, Serialize}; #[macro_rules_attribute(impl_serde_type!)] -#[derive(Clone, Debug, Display)] +#[derive(Clone, Debug, AutoDisplay)] #[display( fmt = "Sequence([{}])", "decoders.iter().map(|d| d.to_string()).fold( String::new(), |mut acc, s|{ diff --git a/tokenizers/src/decoders/strip.rs b/tokenizers/src/decoders/strip.rs index 581bb4b48..01a35fc75 100644 --- a/tokenizers/src/decoders/strip.rs +++ b/tokenizers/src/decoders/strip.rs @@ -1,8 +1,8 @@ use crate::tokenizer::{Decoder, Result}; -use pyo3_special_method_derive::{Dict, Dir, Getattr, Repr, Str}; +use pyo3_special_method_derive::AutoDisplay; use serde::{Deserialize, Serialize}; -#[derive(Deserialize, Clone, Debug, Serialize, Default, Display)] +#[derive(Deserialize, Clone, Debug, Serialize, Default, AutoDisplay)] /// Strip is a simple trick which converts tokens looking like `<0x61>` /// to pure bytes, and attempts to make them into a string. If the tokens /// cannot be decoded you will get � instead for each inconvertable byte token diff --git a/tokenizers/src/decoders/wordpiece.rs b/tokenizers/src/decoders/wordpiece.rs index acf293931..69d168d4e 100644 --- a/tokenizers/src/decoders/wordpiece.rs +++ b/tokenizers/src/decoders/wordpiece.rs @@ -1,8 +1,8 @@ use crate::tokenizer::{Decoder, Result}; -use pyo3_special_method_derive::{Dict, Dir, Getattr, Repr, Str}; +use pyo3_special_method_derive::AutoDisplay; use serde::{Deserialize, Serialize}; -#[derive(Deserialize, Clone, Debug, Serialize, Display)] +#[derive(Deserialize, Clone, Debug, Serialize, AutoDisplay)] /// The WordPiece decoder takes care of decoding a list of wordpiece tokens /// back into a readable string. 
#[serde(tag = "type")] diff --git a/tokenizers/src/models/bpe/model.rs b/tokenizers/src/models/bpe/model.rs index f0fae9694..749f7461f 100644 --- a/tokenizers/src/models/bpe/model.rs +++ b/tokenizers/src/models/bpe/model.rs @@ -248,7 +248,7 @@ impl std::fmt::Debug for BPE { } } -impl std::fmt::Display for BPE { +impl std::fmt::AutoDisplay for BPE { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { let mut vocab_vec: Vec<_> = self.vocab.iter().collect(); vocab_vec.sort_by_key(|&(_, v)| v); diff --git a/tokenizers/src/models/mod.rs b/tokenizers/src/models/mod.rs index 59cb8146d..4806a3f53 100644 --- a/tokenizers/src/models/mod.rs +++ b/tokenizers/src/models/mod.rs @@ -8,7 +8,7 @@ pub mod wordpiece; use std::collections::HashMap; use std::path::{Path, PathBuf}; -use pyo3_special_method_derive::{Dict, Dir, Getattr, Repr, Str}; +use pyo3_special_method_derive::AutoDisplay; use serde::{Deserialize, Serialize, Serializer}; use crate::models::bpe::{BpeTrainer, BPE}; @@ -58,7 +58,7 @@ impl<'a> Serialize for OrderedVocabIter<'a> { } } -#[derive(Deserialize, Serialize, Debug, PartialEq, Clone, Display)] +#[derive(Deserialize, Serialize, Debug, PartialEq, Clone, AutoDisplay)] #[serde(untagged)] pub enum ModelWrapper { BPE(BPE), diff --git a/tokenizers/src/models/unigram/model.rs b/tokenizers/src/models/unigram/model.rs index 06b244d5a..9915ce4d5 100644 --- a/tokenizers/src/models/unigram/model.rs +++ b/tokenizers/src/models/unigram/model.rs @@ -6,7 +6,7 @@ use super::{ use crate::tokenizer::{Model, Result, Token}; use crate::utils::cache::Cache; -use pyo3_special_method_derive::{Dict, Dir, Getattr, Repr, Str}; +use pyo3_special_method_derive::AutoDisplay; use std::collections::HashMap; use std::convert::TryInto; use std::fs::read_to_string; @@ -15,20 +15,7 @@ type TokenMap = HashMap; type Vocab = Vec<(String, f64)>; /// A `Unigram` model to encode sentences. 
-#[derive(Display)] -#[display( - fmt = "Unigram(vocab={{{}, ...}}, unk_id={}, bos_id={}, eos_id={})", - "vocab.iter().take(5).fold(String::new(), |mut acc, (key, value)| { - if !acc.is_empty() { - acc.push_str(\", \"); - } - acc.push_str(&format!(\"\'{}\': {}\", key, value)); - acc -})", - "unk_id.unwrap()", - bos_id, - eos_id -)] +#[derive(AutoDisplay)] pub struct Unigram { token_to_ids: TokenMap, pub(crate) vocab: Vocab, diff --git a/tokenizers/src/models/wordlevel/mod.rs b/tokenizers/src/models/wordlevel/mod.rs index f6e85852d..0ad0412ad 100644 --- a/tokenizers/src/models/wordlevel/mod.rs +++ b/tokenizers/src/models/wordlevel/mod.rs @@ -1,6 +1,6 @@ use super::OrderedVocabIter; use crate::tokenizer::{Model, Result, Token}; -use pyo3_special_method_derive::{Dict, Dir, Getattr, Repr, Str}; +use pyo3_special_method_derive::AutoDisplay; use serde_json::Value; use std::collections::HashMap; use std::fs::File; @@ -94,19 +94,9 @@ impl WordLevelBuilder { } } -#[derive(PartialEq, Clone, Eq, Display)] -#[display( - fmt = "WordLevel(vocab={{{}, ...}}, unk_token={})", - "vocab.iter().take(5).fold(String::new(), |mut acc, (key, value)| { - if !acc.is_empty() { - acc.push_str(\", \"); - } - acc.push_str(&format!(\"\'{}\': {}\", key, value)); - acc -})", - unk_token -)] +#[derive(PartialEq, Clone, Eq, AutoDisplay)] pub struct WordLevel { + #[auto_display] vocab: HashMap, vocab_r: HashMap, pub unk_token: String, diff --git a/tokenizers/src/models/wordpiece/mod.rs b/tokenizers/src/models/wordpiece/mod.rs index 0f3247ee9..ffd8f9b05 100644 --- a/tokenizers/src/models/wordpiece/mod.rs +++ b/tokenizers/src/models/wordpiece/mod.rs @@ -3,7 +3,7 @@ use crate::models::bpe::BPE; use crate::tokenizer::{Model, Result, Token}; -use pyo3_special_method_derive::{Dict, Dir, Getattr, Repr, Str}; +use pyo3_special_method_derive::AutoDisplay; use std::{ borrow::Cow, collections::HashMap, @@ -119,24 +119,14 @@ impl WordPieceBuilder { /// A /// [WordPiece](https://static.googleusercontent.com/media/research.google.com/en//pubs/archive/37842.pdf) /// model. 
-#[derive(Clone, PartialEq, Eq, Display)] -#[display( - fmt = "WordPiece(vocab={}, unk_token={}, continuing_subword_prefix={:?})", - "vocab.iter().take(5).fold(String::new(), |mut acc, (key, value)| { - if !acc.is_empty() { - acc.push_str(\", \"); - } - acc.push_str(&format!(\"\'{}\': {}\", key, value)); - acc - })", - unk_token, - continuing_subword_prefix -)] +#[derive(Clone, PartialEq, Eq, AutoDisplay)] pub struct WordPiece { + #[auto_display] vocab: Vocab, vocab_r: VocabR, pub unk_token: String, pub continuing_subword_prefix: String, + #[auto_display(skip)] pub max_input_chars_per_word: usize, } diff --git a/tokenizers/src/normalizers/bert.rs b/tokenizers/src/normalizers/bert.rs index 962393ef0..255642d30 100644 --- a/tokenizers/src/normalizers/bert.rs +++ b/tokenizers/src/normalizers/bert.rs @@ -1,6 +1,6 @@ use crate::tokenizer::{NormalizedString, Normalizer, Result}; -use pyo3_special_method_derive::{Dict, Dir, Getattr, Repr, Str}; +use pyo3_special_method_derive::AutoDisplay; use serde::{Deserialize, Serialize}; use unicode_categories::UnicodeCategories; /// Checks whether a character is whitespace @@ -47,14 +47,7 @@ fn is_chinese_char(c: char) -> bool { ) } -#[derive(Copy, Clone, Debug, Deserialize, Serialize, Display)] -#[display( - fmt = "BertNormalizer(clean_text={}, handle_chinese_chars={}, strip_accents={:?}, lower_case={})", - clean_text, - handle_chinese_chars, - strip_accents, - lowercase -)] +#[derive(Copy, Clone, Debug, Deserialize, Serialize, AutoDisplay)] #[serde(tag = "type")] #[non_exhaustive] pub struct BertNormalizer { diff --git a/tokenizers/src/normalizers/mod.rs b/tokenizers/src/normalizers/mod.rs index 0b1076de2..3c268820a 100644 --- a/tokenizers/src/normalizers/mod.rs +++ b/tokenizers/src/normalizers/mod.rs @@ -15,13 +15,13 @@ pub use crate::normalizers::unicode::{Nmt, NFC, NFD, NFKC, NFKD}; pub use crate::normalizers::utils::{Lowercase, Sequence}; use crate::{NormalizedString, Normalizer}; -use pyo3_special_method_derive::{Dict, Dir, Getattr, Repr, Str}; +use pyo3_special_method_derive::AutoDisplay; use serde::{Deserialize, Serialize}; /// Wrapper for known Normalizers. 
-#[derive(Clone, Debug, Deserialize, Serialize, Display)] +#[derive(Clone, Debug, Deserialize, Serialize, AutoDisplay)] #[serde(untagged)] -#[display(fmt = "normalizers.{}")] +#[auto_display(fmt = "normalizers.{}")] pub enum NormalizerWrapper { BertNormalizer(BertNormalizer), StripNormalizer(Strip), @@ -33,7 +33,7 @@ pub enum NormalizerWrapper { Sequence(Sequence), Lowercase(Lowercase), Nmt(Nmt), - #[display(fmt = "Precompiled()")] + #[auto_display(fmt = "Precompiled()")] Precompiled(Precompiled), Replace(Replace), Prepend(Prepend), diff --git a/tokenizers/src/normalizers/prepend.rs b/tokenizers/src/normalizers/prepend.rs index 5b5f2379e..27b0a3b50 100644 --- a/tokenizers/src/normalizers/prepend.rs +++ b/tokenizers/src/normalizers/prepend.rs @@ -1,8 +1,8 @@ use crate::tokenizer::{NormalizedString, Normalizer, Result}; -use pyo3_special_method_derive::{Dict, Dir, Getattr, Repr, Str}; +use pyo3_special_method_derive::AutoDisplay; use serde::{Deserialize, Serialize}; -#[derive(Clone, Debug, Deserialize, Serialize, Display)] +#[derive(Clone, Debug, Deserialize, Serialize, AutoDisplay)] #[serde(tag = "type")] pub struct Prepend { pub prepend: String, diff --git a/tokenizers/src/normalizers/replace.rs b/tokenizers/src/normalizers/replace.rs index 6c3f44127..0bbedae51 100644 --- a/tokenizers/src/normalizers/replace.rs +++ b/tokenizers/src/normalizers/replace.rs @@ -2,7 +2,7 @@ use crate::tokenizer::pattern::Pattern; use crate::tokenizer::Decoder; use crate::tokenizer::{NormalizedString, Normalizer, Result}; use crate::utils::SysRegex; -use pyo3_special_method_derive::{Dict, Dir, Getattr, Repr, Str}; +use pyo3_special_method_derive::AutoDisplay; use serde::{Deserialize, Serialize}; /// Represents the different patterns that `Replace` can use #[derive(Debug, Clone, PartialEq, Serialize, Deserialize, Eq)] @@ -42,16 +42,11 @@ impl std::convert::TryFrom for Replace { /// This normalizer will take a `pattern` (for now only a String) /// and replace every occurrence with `content`. -#[derive(Debug, Serialize, Deserialize, Display)] +#[derive(Debug, Serialize, Deserialize, AutoDisplay)] #[serde(tag = "type", try_from = "ReplaceDeserializer")] -#[display( - fmt = "Replace(pattern={:?}, content=\"{}\", regex={:?}", - pattern, - content, - regex -)] pub struct Replace { pattern: ReplacePattern, + #[auto_display] content: String, #[serde(skip)] regex: SysRegex, diff --git a/tokenizers/src/normalizers/strip.rs b/tokenizers/src/normalizers/strip.rs index abbb53976..78d517862 100644 --- a/tokenizers/src/normalizers/strip.rs +++ b/tokenizers/src/normalizers/strip.rs @@ -1,9 +1,9 @@ use crate::tokenizer::{NormalizedString, Normalizer, Result}; use crate::utils::macro_rules_attribute; -use pyo3_special_method_derive::{Dict, Dir, Getattr, Repr, Str}; +use pyo3_special_method_derive::AutoDisplay; use serde::{Deserialize, Serialize}; use unicode_normalization_alignments::char::is_combining_mark; -#[derive(Copy, Clone, Debug, Deserialize, Serialize, Display)] +#[derive(Copy, Clone, Debug, Deserialize, Serialize, AutoDisplay)] #[serde(tag = "type")] #[non_exhaustive] pub struct Strip { @@ -43,7 +43,7 @@ impl Normalizer for Strip { // This normalizer removes combining marks from a normalized string // It's different from unidecode as it does not attempt to modify // non ascii languages. 
-#[derive(Copy, Clone, Debug, Display)] +#[derive(Copy, Clone, Debug, AutoDisplay)] #[macro_rules_attribute(impl_serde_type!)] pub struct StripAccents; diff --git a/tokenizers/src/normalizers/unicode.rs b/tokenizers/src/normalizers/unicode.rs index 05a4df678..a203db700 100644 --- a/tokenizers/src/normalizers/unicode.rs +++ b/tokenizers/src/normalizers/unicode.rs @@ -1,8 +1,8 @@ use crate::tokenizer::{NormalizedString, Normalizer, Result}; use crate::utils::macro_rules_attribute; -use pyo3_special_method_derive::{Dict, Dir, Getattr, Repr, Str}; +use pyo3_special_method_derive::AutoDisplay; -#[derive(Default, Copy, Clone, Debug, Display)] +#[derive(Default, Copy, Clone, Debug, AutoDisplay)] #[macro_rules_attribute(impl_serde_type!)] pub struct NFD; impl Normalizer for NFD { @@ -12,7 +12,7 @@ impl Normalizer for NFD { } } -#[derive(Default, Copy, Clone, Debug, Display)] +#[derive(Default, Copy, Clone, Debug, AutoDisplay)] #[macro_rules_attribute(impl_serde_type!)] pub struct NFKD; impl Normalizer for NFKD { @@ -22,7 +22,7 @@ impl Normalizer for NFKD { } } -#[derive(Default, Copy, Clone, Debug, Display)] +#[derive(Default, Copy, Clone, Debug, AutoDisplay)] #[macro_rules_attribute(impl_serde_type!)] pub struct NFC; impl Normalizer for NFC { @@ -32,7 +32,7 @@ impl Normalizer for NFC { } } -#[derive(Default, Copy, Clone, Debug, Display)] +#[derive(Default, Copy, Clone, Debug, AutoDisplay)] #[macro_rules_attribute(impl_serde_type!)] pub struct NFKC; impl Normalizer for NFKC { @@ -73,7 +73,7 @@ fn do_nmt(normalized: &mut NormalizedString) { }); } -#[derive(Default, Copy, Clone, Debug, Display)] +#[derive(Default, Copy, Clone, Debug, AutoDisplay)] #[macro_rules_attribute(impl_serde_type!)] pub struct Nmt; impl Normalizer for Nmt { diff --git a/tokenizers/src/normalizers/utils.rs b/tokenizers/src/normalizers/utils.rs index 14e692e69..fd9f15f25 100644 --- a/tokenizers/src/normalizers/utils.rs +++ b/tokenizers/src/normalizers/utils.rs @@ -3,18 +3,8 @@ use serde::{Deserialize, Serialize}; use crate::normalizers::NormalizerWrapper; use crate::tokenizer::{NormalizedString, Normalizer, Result}; use crate::utils::macro_rules_attribute; -#[derive(Clone, Deserialize, Debug, Serialize, Display)] -use pyo3_special_method_derive::{Dict, Dir, Getattr, Repr, Str}; -#[display( - fmt = "Sequence([{}])", - "normalizers.iter().fold(String::new(), |mut acc, d| { - if !acc.is_empty() { - acc.push_str(\", \"); - } - acc.push_str(&d.to_string()); - acc -})" -)] +use pyo3_special_method_derive::AutoDisplay; +#[derive(Clone, Deserialize, Debug, Serialize, AutoDisplay)] #[serde(tag = "type")] /// Allows concatenating multiple other Normalizer as a Sequence. /// All the normalizers run in sequence in the given order against the same NormalizedString. 
@@ -46,7 +36,7 @@ impl Normalizer for Sequence { } /// Lowercases the input -#[derive(Copy, Clone, Debug, Display)] +#[derive(Copy, Clone, Debug, AutoDisplay)] #[macro_rules_attribute(impl_serde_type!)] pub struct Lowercase; diff --git a/tokenizers/src/pre_tokenizers/bert.rs b/tokenizers/src/pre_tokenizers/bert.rs index 6e9efd3cb..5a8f1ebc0 100644 --- a/tokenizers/src/pre_tokenizers/bert.rs +++ b/tokenizers/src/pre_tokenizers/bert.rs @@ -1,13 +1,13 @@ use crate::tokenizer::{PreTokenizedString, PreTokenizer, Result, SplitDelimiterBehavior}; use crate::utils::macro_rules_attribute; -use pyo3_special_method_derive::{Dict, Dir, Getattr, Repr, Str}; +use pyo3_special_method_derive::AutoDisplay; use unicode_categories::UnicodeCategories; fn is_bert_punc(x: char) -> bool { char::is_ascii_punctuation(&x) || x.is_punctuation() } -#[derive(Copy, Clone, Debug, PartialEq, Eq, Display)] +#[derive(Copy, Clone, Debug, PartialEq, Eq, AutoDisplay)] #[macro_rules_attribute(impl_serde_type!)] pub struct BertPreTokenizer; diff --git a/tokenizers/src/pre_tokenizers/byte_level.rs b/tokenizers/src/pre_tokenizers/byte_level.rs index 3232083e1..e31378b46 100644 --- a/tokenizers/src/pre_tokenizers/byte_level.rs +++ b/tokenizers/src/pre_tokenizers/byte_level.rs @@ -6,7 +6,7 @@ use crate::tokenizer::{ }; use crate::utils::macro_rules_attribute; use crate::utils::SysRegex; -use pyo3_special_method_derive::{Dict, Dir, Getattr, Repr, Str}; +use pyo3_special_method_derive::AutoDisplay; use serde::{Deserialize, Serialize}; /// Converts bytes to unicode characters. @@ -50,7 +50,7 @@ lazy_static! { /// of all the required processing steps to transform a UTF-8 string as needed before and after the /// BPE model does its job. #[macro_rules_attribute(impl_serde_type!)] -#[derive(Copy, Clone, Debug, PartialEq, Eq, Display)] +#[derive(Copy, Clone, Debug, PartialEq, Eq, AutoDisplay)] #[non_exhaustive] pub struct ByteLevel { /// Whether to add a leading space to the first word. This allows to treat the leading word diff --git a/tokenizers/src/pre_tokenizers/delimiter.rs b/tokenizers/src/pre_tokenizers/delimiter.rs index d857f190b..fcfd13aee 100644 --- a/tokenizers/src/pre_tokenizers/delimiter.rs +++ b/tokenizers/src/pre_tokenizers/delimiter.rs @@ -1,10 +1,10 @@ -use pyo3_special_method_derive::{Dict, Dir, Getattr, Repr, Str}; +use pyo3_special_method_derive::AutoDisplay; use serde::{Deserialize, Serialize}; use crate::tokenizer::{PreTokenizedString, PreTokenizer, Result, SplitDelimiterBehavior}; use crate::utils::macro_rules_attribute; -#[derive(Copy, Clone, Debug, PartialEq, Eq, Display)] +#[derive(Copy, Clone, Debug, PartialEq, Eq, AutoDisplay)] #[non_exhaustive] #[macro_rules_attribute(impl_serde_type!)] pub struct CharDelimiterSplit { diff --git a/tokenizers/src/pre_tokenizers/digits.rs b/tokenizers/src/pre_tokenizers/digits.rs index e943b8e5e..3cb3326b4 100644 --- a/tokenizers/src/pre_tokenizers/digits.rs +++ b/tokenizers/src/pre_tokenizers/digits.rs @@ -1,10 +1,10 @@ -use pyo3_special_method_derive::{Dict, Dir, Getattr, Repr, Str}; +use pyo3_special_method_derive::AutoDisplay; use serde::{Deserialize, Serialize}; use crate::tokenizer::{PreTokenizedString, PreTokenizer, Result, SplitDelimiterBehavior}; use crate::utils::macro_rules_attribute; -#[derive(Clone, Debug, PartialEq, Eq, Display)] +#[derive(Clone, Debug, PartialEq, Eq, AutoDisplay)] /// Pre tokenizes the numbers into single tokens. If individual_digits is set /// to true, then all digits are splitted into individual tokens. 
#[non_exhaustive] diff --git a/tokenizers/src/pre_tokenizers/metaspace.rs b/tokenizers/src/pre_tokenizers/metaspace.rs index 3c38f1675..83744eed2 100644 --- a/tokenizers/src/pre_tokenizers/metaspace.rs +++ b/tokenizers/src/pre_tokenizers/metaspace.rs @@ -1,8 +1,8 @@ use crate::tokenizer::{Decoder, PreTokenizedString, PreTokenizer, Result, SplitDelimiterBehavior}; -use pyo3_special_method_derive::{Dict, Dir, Getattr, Repr, Str}; +use pyo3_special_method_derive::AutoDisplay; use serde::{de, Deserialize, Deserializer, Serialize}; /// Enum representing options for the metaspace prepending scheme. -#[derive(Debug, Clone, PartialEq, Serialize, Eq, Deserialize, Copy, Display)] +#[derive(Debug, Clone, PartialEq, Serialize, Eq, Deserialize, Copy, AutoDisplay)] #[serde(rename_all = "snake_case")] pub enum PrependScheme { /// Specifies that the scheme should be prepended only once, on the first split. @@ -13,16 +13,11 @@ pub enum PrependScheme { Always, } -#[derive(Debug, Clone, PartialEq, Serialize, Eq, Display)] +#[derive(Debug, Clone, PartialEq, Serialize, Eq, AutoDisplay)] /// Replaces all the whitespaces by the provided meta character and then /// splits on this character #[serde(tag = "type")] -#[display( - fmt = "Metaspace(replacement='{}', prepend_scheme={:?}, split={})", - replacement, - "prepend_scheme.to_string().to_lowercase()", - split -)] + pub struct Metaspace { replacement: char, pub prepend_scheme: PrependScheme, diff --git a/tokenizers/src/pre_tokenizers/mod.rs b/tokenizers/src/pre_tokenizers/mod.rs index a036dccc9..617442d9b 100644 --- a/tokenizers/src/pre_tokenizers/mod.rs +++ b/tokenizers/src/pre_tokenizers/mod.rs @@ -22,10 +22,10 @@ use crate::pre_tokenizers::split::Split; use crate::pre_tokenizers::unicode_scripts::UnicodeScripts; use crate::pre_tokenizers::whitespace::{Whitespace, WhitespaceSplit}; use crate::{PreTokenizedString, PreTokenizer}; -use pyo3_special_method_derive::{Dict, Dir, Getattr, Repr, Str}; +use pyo3_special_method_derive::AutoDisplay; -#[derive(Deserialize, Serialize, Clone, Debug, PartialEq, Display)] -#[display(fmt="pre_tokenizers.{}")] +#[derive(Deserialize, Serialize, Clone, Debug, PartialEq, AutoDisplay)] +#[auto_display(fmt="pre_tokenizers.{}")] #[serde(untagged)] pub enum PreTokenizerWrapper { BertPreTokenizer(BertPreTokenizer), diff --git a/tokenizers/src/pre_tokenizers/punctuation.rs b/tokenizers/src/pre_tokenizers/punctuation.rs index 664154e5a..a61ce16f0 100644 --- a/tokenizers/src/pre_tokenizers/punctuation.rs +++ b/tokenizers/src/pre_tokenizers/punctuation.rs @@ -1,4 +1,4 @@ -use pyo3_special_method_derive::{Dict, Dir, Getattr, Repr, Str}; +use pyo3_special_method_derive::AutoDisplay; use serde::{Deserialize, Serialize}; use crate::tokenizer::{PreTokenizedString, PreTokenizer, Result, SplitDelimiterBehavior}; @@ -9,7 +9,7 @@ fn is_punc(x: char) -> bool { char::is_ascii_punctuation(&x) || x.is_punctuation() } -#[derive(Copy, Clone, Debug, PartialEq, Eq, Display)] +#[derive(Copy, Clone, Debug, PartialEq, Eq, AutoDisplay)] #[macro_rules_attribute(impl_serde_type!)] pub struct Punctuation { #[serde(default = "default_split")] diff --git a/tokenizers/src/pre_tokenizers/sequence.rs b/tokenizers/src/pre_tokenizers/sequence.rs index de9e14e8a..8f5c591b5 100644 --- a/tokenizers/src/pre_tokenizers/sequence.rs +++ b/tokenizers/src/pre_tokenizers/sequence.rs @@ -1,21 +1,11 @@ use crate::pre_tokenizers::PreTokenizerWrapper; use crate::tokenizer::{PreTokenizedString, PreTokenizer, Result}; use crate::utils::macro_rules_attribute; -use 
pyo3_special_method_derive::{Dict, Dir, Getattr, Repr, Str}; +use pyo3_special_method_derive::AutoDisplay; use serde::{Deserialize, Serialize}; #[macro_rules_attribute(impl_serde_type!)] -#[derive(Clone, Debug, PartialEq, Display)] -#[display( - fmt = "Seqence([{}])", - "pretokenizers.iter().fold(String::new(), |mut acc, p| { - if !acc.is_empty(){ - acc.push_str(\", \") - } - acc.push_str(&p.to_string()); - acc - })" -)] +#[derive(Clone, Debug, PartialEq, AutoDisplay)] pub struct Sequence { pretokenizers: Vec, } diff --git a/tokenizers/src/pre_tokenizers/split.rs b/tokenizers/src/pre_tokenizers/split.rs index 82a7ad86b..2edcd723b 100644 --- a/tokenizers/src/pre_tokenizers/split.rs +++ b/tokenizers/src/pre_tokenizers/split.rs @@ -2,11 +2,11 @@ use crate::tokenizer::{ pattern::Invert, PreTokenizedString, PreTokenizer, Result, SplitDelimiterBehavior, }; use crate::utils::SysRegex; -use pyo3_special_method_derive::{Dict, Dir, Getattr, Repr, Str}; +use pyo3_special_method_derive::AutoDisplay; use serde::{Deserialize, Deserializer, Serialize}; /// Represents the different patterns that `Split` can use -#[derive(Debug, Clone, PartialEq, Serialize, Deserialize, Eq, Display)] +#[derive(Debug, Clone, PartialEq, Serialize, Deserialize, Eq, AutoDisplay)] pub enum SplitPattern { String(String), Regex(String), @@ -24,15 +24,8 @@ impl From<&str> for SplitPattern { } } -#[derive(Debug, Serialize, Display)] +#[derive(Debug, Serialize, AutoDisplay)] #[serde(tag = "type")] -#[display( - fmt = "Split(patter={}, regex={:?}, behavior={}, invert={})", - pattern, - regex, - behavior, - invert -)] pub struct Split { pattern: SplitPattern, #[serde(skip)] diff --git a/tokenizers/src/pre_tokenizers/unicode_scripts/pre_tokenizer.rs b/tokenizers/src/pre_tokenizers/unicode_scripts/pre_tokenizer.rs index b693f2f1c..405810a95 100644 --- a/tokenizers/src/pre_tokenizers/unicode_scripts/pre_tokenizer.rs +++ b/tokenizers/src/pre_tokenizers/unicode_scripts/pre_tokenizer.rs @@ -1,10 +1,10 @@ -use pyo3_special_method_derive::{Dict, Dir, Getattr, Repr, Str}; +use pyo3_special_method_derive::AutoDisplay; use crate::pre_tokenizers::unicode_scripts::scripts::{get_script, Script}; use crate::tokenizer::{normalizer::Range, PreTokenizedString, PreTokenizer, Result}; use crate::utils::macro_rules_attribute; -#[derive(Clone, Debug, PartialEq, Eq, Display)] +#[derive(Clone, Debug, PartialEq, Eq, AutoDisplay)] #[macro_rules_attribute(impl_serde_type!)] pub struct UnicodeScripts; diff --git a/tokenizers/src/pre_tokenizers/whitespace.rs b/tokenizers/src/pre_tokenizers/whitespace.rs index 40b5fba04..3d57a76f5 100644 --- a/tokenizers/src/pre_tokenizers/whitespace.rs +++ b/tokenizers/src/pre_tokenizers/whitespace.rs @@ -1,4 +1,4 @@ -use pyo3_special_method_derive::{Dict, Dir, Getattr, Repr, Str}; +use pyo3_special_method_derive::AutoDisplay use regex::Regex; use crate::tokenizer::{ @@ -6,7 +6,7 @@ use crate::tokenizer::{ }; use crate::utils::macro_rules_attribute; -#[derive(Clone, Debug, PartialEq, Eq, Display)] +#[derive(Clone, Debug, PartialEq, Eq, AutoDisplay)] #[macro_rules_attribute(impl_serde_type!)] pub struct Whitespace; @@ -29,7 +29,7 @@ impl PreTokenizer for Whitespace { } } -#[derive(Copy, Clone, Debug, PartialEq, Eq, Display)] +#[derive(Copy, Clone, Debug, PartialEq, Eq, AutoDisplay)] #[macro_rules_attribute(impl_serde_type!)] pub struct WhitespaceSplit; diff --git a/tokenizers/src/processors/bert.rs b/tokenizers/src/processors/bert.rs index aed0b7e68..8f0ef327c 100644 --- a/tokenizers/src/processors/bert.rs +++ 
b/tokenizers/src/processors/bert.rs @@ -1,12 +1,11 @@ use crate::tokenizer::{Encoding, PostProcessor, Result}; -use pyo3_special_method_derive::{Dict, Dir, Getattr, Repr, Str}; +use pyo3_special_method_derive::AutoDisplay; use serde::{Deserialize, Serialize}; use std::collections::HashMap; use std::iter::FromIterator; -#[derive(Serialize, Deserialize, Clone, Debug, PartialEq, Eq, Display)] +#[derive(Serialize, Deserialize, Clone, Debug, PartialEq, Eq, AutoDisplay)] #[serde(tag = "type")] -#[display(fmt = "BertProcessing(sep={:?}, cls={:?})", sep, cls)] pub struct BertProcessing { sep: (String, u32), cls: (String, u32), diff --git a/tokenizers/src/processors/mod.rs b/tokenizers/src/processors/mod.rs index 6beecdc64..742a58e7a 100644 --- a/tokenizers/src/processors/mod.rs +++ b/tokenizers/src/processors/mod.rs @@ -12,11 +12,11 @@ use crate::processors::roberta::RobertaProcessing; use crate::processors::sequence::Sequence; use crate::processors::template::TemplateProcessing; use crate::{Encoding, PostProcessor, Result}; -use pyo3_special_method_derive::{Dict, Dir, Getattr, Repr, Str}; +use pyo3_special_method_derive::AutoDisplay; use serde::{Deserialize, Serialize}; -#[derive(Serialize, Deserialize, PartialEq, Debug, Clone, Eq, Display)] -#[display(fmt="processors.{}")] +#[derive(Serialize, Deserialize, PartialEq, Debug, Clone, Eq, AutoDisplay)] +#[auto_display(fmt="processors.{}")] #[serde(untagged)] pub enum PostProcessorWrapper { // Roberta must be before Bert for deserialization (serde does not validate tags) diff --git a/tokenizers/src/processors/roberta.rs b/tokenizers/src/processors/roberta.rs index a7e79eaa9..248858ec7 100644 --- a/tokenizers/src/processors/roberta.rs +++ b/tokenizers/src/processors/roberta.rs @@ -1,18 +1,11 @@ use crate::processors::byte_level::process_offsets; use crate::tokenizer::{Encoding, PostProcessor, Result}; -use pyo3_special_method_derive::{Dict, Dir, Getattr, Repr, Str}; +use pyo3_special_method_derive::AutoDisplay; use serde::{Deserialize, Serialize}; use std::collections::HashMap; use std::iter::FromIterator; -#[derive(Serialize, Deserialize, Debug, Clone, PartialEq, Eq, Display)] -#[display( - fmt = "RobertaProcessing(sep={:?}, cls={:?}, trim_offsets={}, add_prefix_space={}", - sep, - cls, - trim_offsets, - add_prefix_space -)] +#[derive(Serialize, Deserialize, Debug, Clone, PartialEq, Eq, AutoDisplay)] #[serde(tag = "type")] pub struct RobertaProcessing { sep: (String, u32), diff --git a/tokenizers/src/processors/sequence.rs b/tokenizers/src/processors/sequence.rs index 1c86572de..39b34d08e 100644 --- a/tokenizers/src/processors/sequence.rs +++ b/tokenizers/src/processors/sequence.rs @@ -1,20 +1,10 @@ use crate::processors::PostProcessorWrapper; use crate::tokenizer::{Encoding, PostProcessor, Result}; use crate::utils::macro_rules_attribute; -use pyo3_special_method_derive::{Dict, Dir, Getattr, Repr, Str}; +use pyo3_special_method_derive::AutoDisplay; use serde::{Deserialize, Serialize}; #[macro_rules_attribute(impl_serde_type!)] -#[derive(Clone, Debug, PartialEq, Eq, Display)] -#[display( - fmt = "Sequence([{}])", - "processors.iter().fold(String::new(), |mut acc, p| { - if !acc.is_empty() { - acc.push_str(\", \"); - } - acc.push_str(&p.to_string()); - acc -})" -)] +#[derive(Clone, Debug, PartialEq, Eq, AutoDisplay)] pub struct Sequence { processors: Vec, } diff --git a/tokenizers/src/processors/template.rs b/tokenizers/src/processors/template.rs index 58e9edc63..3cf8cc932 100644 --- a/tokenizers/src/processors/template.rs +++ 
b/tokenizers/src/processors/template.rs @@ -57,7 +57,7 @@ //! use crate::{Encoding, PostProcessor, Result}; use itertools::Itertools; -use pyo3_special_method_derive::{Dict, Dir, Getattr, Repr, Str}; +use pyo3_special_method_derive::AutoDisplay; use serde::{Deserialize, Serialize}; use std::collections::{HashMap, HashSet}; use std::convert::{TryFrom, TryInto}; @@ -332,9 +332,8 @@ impl From> for Tokens { /// .unwrap(); /// ``` /// -#[derive(Debug, Clone, PartialEq, Builder, Serialize, Deserialize, Eq, Display)] +#[derive(Debug, Clone, PartialEq, Builder, Serialize, Deserialize, Eq, AutoDisplay)] #[serde(tag = "type", from = "TemplateProcessingDeserializer")] -#[display(fmt = "TemplateProcessing(single={:?}, pair={:?})", single, pair)] #[builder(build_fn(validate = "Self::validate"))] pub struct TemplateProcessing { #[builder(try_setter, default = "\"$0\".try_into().unwrap()")] diff --git a/tokenizers/src/tokenizer/added_vocabulary.rs b/tokenizers/src/tokenizer/added_vocabulary.rs index 2f3b2702d..b9edb282b 100644 --- a/tokenizers/src/tokenizer/added_vocabulary.rs +++ b/tokenizers/src/tokenizer/added_vocabulary.rs @@ -2,7 +2,7 @@ use super::{ normalizer::Range, Model, NormalizedString, Normalizer, Offsets, PreTokenizedString, Token, }; use aho_corasick::{AhoCorasick, AhoCorasickBuilder, MatchKind}; -use pyo3_special_method_derive::{Dict, Dir, Getattr, Repr, Str}; +use pyo3_special_method_derive::AutoDisplay; use regex::Regex; use serde::{ser::SerializeSeq, Deserialize, Serialize, Serializer}; use std::collections::{HashMap, HashSet}; @@ -12,7 +12,7 @@ use std::collections::{HashMap, HashSet}; /// like: /// - Whether they should only match single words /// - Whether to include any whitespace on its left or right -#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq, Display)] +#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq, AutoDisplay)] pub struct AddedToken { /// The content of the added token pub content: String, @@ -139,14 +139,15 @@ fn space_rightmost_at_start(sentence: &str) -> usize { /// were to add new tokens after this training process, we couldn't make sure the merges pairs /// exist as required. /// -#[derive(Clone, Debug, Display)] -#[display(fmt="AddedVocabulary(added_tokens_map_r={{{}, ...}}, encode_special_tokens={})", "&(0..=5).fold(String::new(), |mut acc, key| {if let Some(token) = added_tokens_map_r.get(&key){if !acc.is_empty(){acc.push_str(\", \");}acc.push_str(&format!(\"\n\t{}: {}\", key, &token.to_string()));}acc})", encode_special_tokens)] +#[derive(Clone, Debug, AutoDisplay)] +#[auto_display(fmt="AddedVocabulary(added_tokens_map_r={{{}, ...}}, encode_special_tokens={})", "&(0..=5).fold(String::new(), |mut acc, key| {if let Some(token) = added_tokens_map_r.get(&key){if !acc.is_empty(){acc.push_str(\", \");}acc.push_str(&format!(\"\n\t{}: {}\", key, &token.to_string()));}acc})", encode_special_tokens)] pub struct AddedVocabulary { /// Contains the mapping from String (token content) to ID. This map contains both special /// tokens and classic added tokens that were added to the this vocabulary. added_tokens_map: HashMap, /// Contains the mapping from ID to AddedToken for all the added tokens, both special /// and classic. + #[auto_display] added_tokens_map_r: HashMap, /// Contains only the classic AddedToken, in the specific order the user gave them. @@ -156,6 +157,7 @@ pub struct AddedVocabulary { /// A Set, containing all the special token for easy access while decoding. This let's /// us remove them easily with an O(1) complexity. 
+ #[auto_display] special_tokens_set: HashSet, /// A RegexSet containing all the non-normalized patterns used to split on AddedTokens @@ -164,6 +166,7 @@ pub struct AddedVocabulary { split_normalized_trie: MatchingSet, /// Whether or not special tokens should be splitted when encoding. This is equivalent to ignoring them + #[auto_display] encode_special_tokens: bool, } diff --git a/tokenizers/src/tokenizer/normalizer.rs b/tokenizers/src/tokenizer/normalizer.rs index 507606300..125a8f880 100644 --- a/tokenizers/src/tokenizer/normalizer.rs +++ b/tokenizers/src/tokenizer/normalizer.rs @@ -1,6 +1,6 @@ use crate::pattern::Pattern; use crate::{Offsets, Result}; -use pyo3_special_method_derive::{Dict, Dir, Getattr, Repr, Str}; +use pyo3_special_method_derive::AutoDisplay; use serde::{Deserialize, Serialize}; use std::ops::{Bound, RangeBounds}; use unicode_normalization_alignments::UnicodeNormalization; @@ -78,7 +78,7 @@ where /// - MergedWithPrevious => `[ "the-", "final-", "-", "countdown" ]` /// - MergedWithNext => `[ "the", "-final", "-", "-countdown" ]` /// - Contiguous => `[ "the", "-", "final", "--", "countdown" ]` -#[derive(Debug, Clone, Copy, PartialEq, Serialize, Deserialize, Eq, Display)] +#[derive(Debug, Clone, Copy, PartialEq, Serialize, Deserialize, Eq, AutoDisplay)] pub enum SplitDelimiterBehavior { Removed, Isolated, diff --git a/tokenizers/src/utils/padding.rs b/tokenizers/src/utils/padding.rs index 35a9560fa..bc1445070 100644 --- a/tokenizers/src/utils/padding.rs +++ b/tokenizers/src/utils/padding.rs @@ -1,10 +1,10 @@ use crate::parallelism::*; use crate::tokenizer::{Encoding, Result}; -use pyo3_special_method_derive::{Dict, Dir, Getattr, Repr, Str}; +use pyo3_special_method_derive::AutoDisplay; use serde::{Deserialize, Serialize}; /// The various possible padding directions. 
-#[derive(Debug, Clone, Copy, Serialize, Deserialize, Display)] +#[derive(Debug, Clone, Copy, Serialize, Deserialize, AutoDisplay)] pub enum PaddingDirection { Left, Right, @@ -19,16 +19,8 @@ impl std::convert::AsRef for PaddingDirection { } } -#[derive(Debug, Clone, Serialize, Deserialize, Display)] -#[display( - fmt = "strategy={}, direction={}, pad_to_multiple_of={}, pad_id={}, pad_type_id={}, pad_token={}", - strategy, - direction, - "pad_to_multiple_of.unwrap()", - pad_id, - pad_type_id, - pad_token -)] +#[derive(Debug, Clone, Serialize, Deserialize, AutoDisplay)] +#[auto_display(fmt = "")] pub struct PaddingParams { pub strategy: PaddingStrategy, pub direction: PaddingDirection, @@ -51,7 +43,7 @@ impl Default for PaddingParams { } } -#[derive(Debug, Clone, Serialize, Deserialize, Display)] +#[derive(Debug, Clone, Serialize, Deserialize, AutoDisplay)] pub enum PaddingStrategy { BatchLongest, Fixed(usize), diff --git a/tokenizers/src/utils/truncation.rs b/tokenizers/src/utils/truncation.rs index 95dfb5fa2..9cec977ca 100644 --- a/tokenizers/src/utils/truncation.rs +++ b/tokenizers/src/utils/truncation.rs @@ -1,10 +1,10 @@ use crate::tokenizer::{Encoding, Result}; -use pyo3_special_method_derive::{Dict, Dir, Getattr, Repr, Str}; +use pyo3_special_method_derive::AutoDisplay; use serde::{Deserialize, Serialize}; use std::cmp; use std::mem; -#[derive(Debug, Clone, Copy, PartialEq, Serialize, Deserialize, Eq, Default, Display)] +#[derive(Debug, Clone, Copy, PartialEq, Serialize, Deserialize, Eq, Default, AutoDisplay)] pub enum TruncationDirection { Left, #[default] @@ -20,7 +20,7 @@ impl std::convert::AsRef for TruncationDirection { } } -#[derive(Debug, Clone, Serialize, Deserialize, Display)] +#[derive(Debug, Clone, Serialize, Deserialize, AutoDisplay)] pub struct TruncationParams { #[serde(default)] pub direction: TruncationDirection, @@ -50,7 +50,7 @@ pub enum TruncationError { SequenceTooShort, } -#[derive(Debug, Clone, Copy, PartialEq, Serialize, Deserialize, Eq, Display)] +#[derive(Debug, Clone, Copy, PartialEq, Serialize, Deserialize, Eq, AutoAutoDisplay)] pub enum TruncationStrategy { LongestFirst, OnlyFirst, From 011340b341394a275bd40741ed81be3eb450c9e9 Mon Sep 17 00:00:00 2001 From: Arthur Zucker Date: Fri, 19 Jul 2024 21:14:19 +0200 Subject: [PATCH 75/94] nit --- tokenizers/Cargo.toml | 1 + tokenizers/src/models/bpe/model.rs | 2 +- tokenizers/src/pre_tokenizers/whitespace.rs | 2 +- tokenizers/src/utils/truncation.rs | 2 +- 4 files changed, 4 insertions(+), 3 deletions(-) diff --git a/tokenizers/Cargo.toml b/tokenizers/Cargo.toml index 80e042b62..1cbea3d5b 100644 --- a/tokenizers/Cargo.toml +++ b/tokenizers/Cargo.toml @@ -64,6 +64,7 @@ getrandom = { version = "0.2.10" } esaxx-rs = { version = "0.1.10", default-features = false, features=[]} monostate = "0.1.12" pyo3_special_method_derive = "0.3.0" +pyo3_special_method_derive_lib = "0.3.1" [features] default = ["progressbar", "onig", "esaxx_fast"] diff --git a/tokenizers/src/models/bpe/model.rs b/tokenizers/src/models/bpe/model.rs index 749f7461f..f0fae9694 100644 --- a/tokenizers/src/models/bpe/model.rs +++ b/tokenizers/src/models/bpe/model.rs @@ -248,7 +248,7 @@ impl std::fmt::Debug for BPE { } } -impl std::fmt::AutoDisplay for BPE { +impl std::fmt::Display for BPE { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { let mut vocab_vec: Vec<_> = self.vocab.iter().collect(); vocab_vec.sort_by_key(|&(_, v)| v); diff --git a/tokenizers/src/pre_tokenizers/whitespace.rs b/tokenizers/src/pre_tokenizers/whitespace.rs 
index 3d57a76f5..efd0f9202 100644 --- a/tokenizers/src/pre_tokenizers/whitespace.rs +++ b/tokenizers/src/pre_tokenizers/whitespace.rs @@ -1,4 +1,4 @@ -use pyo3_special_method_derive::AutoDisplay +use pyo3_special_method_derive::AutoDisplay; use regex::Regex; use crate::tokenizer::{ diff --git a/tokenizers/src/utils/truncation.rs b/tokenizers/src/utils/truncation.rs index 9cec977ca..5cf47464d 100644 --- a/tokenizers/src/utils/truncation.rs +++ b/tokenizers/src/utils/truncation.rs @@ -50,7 +50,7 @@ pub enum TruncationError { SequenceTooShort, } -#[derive(Debug, Clone, Copy, PartialEq, Serialize, Deserialize, Eq, AutoAutoDisplay)] +#[derive(Debug, Clone, Copy, PartialEq, Serialize, Deserialize, Eq, AutoDisplay)] pub enum TruncationStrategy { LongestFirst, OnlyFirst, From 3fc31d00616e0a1cd4e112b1f0522d1ac13d6440 Mon Sep 17 00:00:00 2001 From: Arthur Zucker Date: Fri, 19 Jul 2024 21:15:36 +0200 Subject: [PATCH 76/94] nice --- tokenizers/src/decoders/sequence.rs | 10 ---------- tokenizers/src/utils/padding.rs | 1 + tokenizers/src/utils/truncation.rs | 1 + 3 files changed, 2 insertions(+), 10 deletions(-) diff --git a/tokenizers/src/decoders/sequence.rs b/tokenizers/src/decoders/sequence.rs index 20b48aa3a..f4e6cfb20 100644 --- a/tokenizers/src/decoders/sequence.rs +++ b/tokenizers/src/decoders/sequence.rs @@ -6,16 +6,6 @@ use serde::{Deserialize, Serialize}; #[macro_rules_attribute(impl_serde_type!)] #[derive(Clone, Debug, AutoDisplay)] -#[display( - fmt = "Sequence([{}])", - "decoders.iter().map(|d| d.to_string()).fold( String::new(), |mut acc, s|{ - if !acc.is_empty(){ - acc.push_str(\", \"); - } - acc.push_str(&s); - acc - })" -)] pub struct Sequence { decoders: Vec, } diff --git a/tokenizers/src/utils/padding.rs b/tokenizers/src/utils/padding.rs index bc1445070..3e311f9b5 100644 --- a/tokenizers/src/utils/padding.rs +++ b/tokenizers/src/utils/padding.rs @@ -24,6 +24,7 @@ impl std::convert::AsRef for PaddingDirection { pub struct PaddingParams { pub strategy: PaddingStrategy, pub direction: PaddingDirection, + #[auto_display(skip)] // usize not supported yet pub pad_to_multiple_of: Option, pub pad_id: u32, pub pad_type_id: u32, diff --git a/tokenizers/src/utils/truncation.rs b/tokenizers/src/utils/truncation.rs index 5cf47464d..541df0e7f 100644 --- a/tokenizers/src/utils/truncation.rs +++ b/tokenizers/src/utils/truncation.rs @@ -24,6 +24,7 @@ impl std::convert::AsRef for TruncationDirection { pub struct TruncationParams { #[serde(default)] pub direction: TruncationDirection, + #[auto_display(skip)] pub max_length: usize, pub strategy: TruncationStrategy, pub stride: usize, From 51d3f61ca41474c2ba49559464e7558bde859702 Mon Sep 17 00:00:00 2001 From: Arthur Zucker Date: Fri, 19 Jul 2024 22:11:20 +0200 Subject: [PATCH 77/94] updates --- bindings/python/Cargo.toml | 2 +- tokenizers/Cargo.toml | 3 +-- 2 files changed, 2 insertions(+), 3 deletions(-) diff --git a/bindings/python/Cargo.toml b/bindings/python/Cargo.toml index 1b0fde555..4000de28d 100644 --- a/bindings/python/Cargo.toml +++ b/bindings/python/Cargo.toml @@ -20,7 +20,7 @@ onig = { version = "6.4", default-features = false } itertools = "0.12" derive_more = "0.99.17" pyo3 = { version = "0.21.2", features = ["multiple-pymethods"] } -pyo3_special_method_derive = "0.3" +pyo3_special_method_derive = "0.4" [dependencies.tokenizers] path = "../../tokenizers" diff --git a/tokenizers/Cargo.toml b/tokenizers/Cargo.toml index 1cbea3d5b..b186b214d 100644 --- a/tokenizers/Cargo.toml +++ b/tokenizers/Cargo.toml @@ -63,8 +63,7 @@ fancy-regex = { 
version = "0.13", optional = true} getrandom = { version = "0.2.10" } esaxx-rs = { version = "0.1.10", default-features = false, features=[]} monostate = "0.1.12" -pyo3_special_method_derive = "0.3.0" -pyo3_special_method_derive_lib = "0.3.1" +pyo3_special_method_derive = "0.4" [features] default = ["progressbar", "onig", "esaxx_fast"] From 951b6e6464b563bf82a19ab08a9c3461cbb10741 Mon Sep 17 00:00:00 2001 From: Arthur Zucker Date: Fri, 19 Jul 2024 22:14:21 +0200 Subject: [PATCH 78/94] deos --- bindings/python/Cargo.toml | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/bindings/python/Cargo.toml b/bindings/python/Cargo.toml index 4000de28d..4b77bb977 100644 --- a/bindings/python/Cargo.toml +++ b/bindings/python/Cargo.toml @@ -19,7 +19,7 @@ ndarray = "0.15" onig = { version = "6.4", default-features = false } itertools = "0.12" derive_more = "0.99.17" -pyo3 = { version = "0.21.2", features = ["multiple-pymethods"] } +pyo3 = { version = "0.22", features = ["multiple-pymethods"] } pyo3_special_method_derive = "0.4" [dependencies.tokenizers] @@ -27,7 +27,7 @@ path = "../../tokenizers" [dev-dependencies] tempfile = "3.10" -pyo3 = { version = "0.21", features = ["auto-initialize"] } +pyo3 = { version = "0.22", features = ["auto-initialize"] } [features] defaut = ["pyo3/extension-module"] From 2048c02efb60953024ad4cdbb8477d6d72b6b02b Mon Sep 17 00:00:00 2001 From: Arthur Zucker Date: Fri, 19 Jul 2024 22:18:22 +0200 Subject: [PATCH 79/94] fix build --- tokenizers/src/normalizers/mod.rs | 1 - tokenizers/src/tokenizer/mod.rs | 9 +-------- 2 files changed, 1 insertion(+), 9 deletions(-) diff --git a/tokenizers/src/normalizers/mod.rs b/tokenizers/src/normalizers/mod.rs index af968ef47..6122967d9 100644 --- a/tokenizers/src/normalizers/mod.rs +++ b/tokenizers/src/normalizers/mod.rs @@ -18,7 +18,6 @@ use serde::{Deserialize, Serialize}; use crate::{NormalizedString, Normalizer}; use pyo3_special_method_derive::AutoDisplay; -use serde::{Deserialize, Serialize}; /// Wrapper for known Normalizers. #[derive(Clone, Debug, Deserialize, Serialize, AutoDisplay)] diff --git a/tokenizers/src/tokenizer/mod.rs b/tokenizers/src/tokenizer/mod.rs index 789f99134..0cec8f38c 100644 --- a/tokenizers/src/tokenizer/mod.rs +++ b/tokenizers/src/tokenizer/mod.rs @@ -1347,13 +1347,6 @@ where #[cfg(test)] mod tests { use super::Tokenizer; - - #[cfg(feature = "http")] - #[test] - fn test_from_pretrained() { - let tok = Tokenizer::from_pretrained("Qwen/Qwen2-7B-Instruct".to_string(), None); - println!("ROCK!") - use crate::AddedToken; use crate::Tokenizer; @@ -1407,4 +1400,4 @@ mod tests { let decoded = tokenizer.decode(encoded.get_ids(), false); assert_eq!(decoded.unwrap(), "Hey! 
how is this token: д") } -} +} \ No newline at end of file From 104fe0cb2b9711176fea888b46017697d0ed32cd Mon Sep 17 00:00:00 2001 From: Eric Buehler <65165915+EricLBuehler@users.noreply.github.com> Date: Sat, 20 Jul 2024 01:28:01 -0400 Subject: [PATCH 80/94] Use pyo3 smd v0.21 (#1574) --- bindings/python/Cargo.toml | 6 +++--- bindings/python/src/decoders.rs | 18 ++++++++++++++++-- bindings/python/src/models.rs | 2 +- bindings/python/src/normalizers.rs | 2 +- bindings/python/src/pre_tokenizers.rs | 10 +--------- bindings/python/src/processors.rs | 2 +- bindings/python/src/tokenizer.rs | 10 +--------- tokenizers/Cargo.toml | 2 +- tokenizers/src/decoders/bpe.rs | 2 +- tokenizers/src/decoders/byte_fallback.rs | 2 +- tokenizers/src/decoders/ctc.rs | 2 +- tokenizers/src/decoders/fuse.rs | 2 +- tokenizers/src/decoders/mod.rs | 2 +- tokenizers/src/decoders/sequence.rs | 2 +- tokenizers/src/decoders/strip.rs | 2 +- tokenizers/src/decoders/wordpiece.rs | 2 +- tokenizers/src/models/mod.rs | 2 +- tokenizers/src/models/unigram/model.rs | 2 +- tokenizers/src/models/wordlevel/mod.rs | 2 +- tokenizers/src/models/wordpiece/mod.rs | 2 +- tokenizers/src/normalizers/bert.rs | 2 +- tokenizers/src/normalizers/mod.rs | 2 +- tokenizers/src/normalizers/prepend.rs | 2 +- tokenizers/src/normalizers/replace.rs | 2 +- tokenizers/src/normalizers/strip.rs | 2 +- tokenizers/src/normalizers/unicode.rs | 2 +- tokenizers/src/normalizers/utils.rs | 2 +- tokenizers/src/pre_tokenizers/bert.rs | 2 +- tokenizers/src/pre_tokenizers/byte_level.rs | 2 +- tokenizers/src/pre_tokenizers/delimiter.rs | 2 +- tokenizers/src/pre_tokenizers/digits.rs | 2 +- tokenizers/src/pre_tokenizers/metaspace.rs | 2 +- tokenizers/src/pre_tokenizers/mod.rs | 2 +- tokenizers/src/pre_tokenizers/punctuation.rs | 2 +- tokenizers/src/pre_tokenizers/sequence.rs | 2 +- tokenizers/src/pre_tokenizers/split.rs | 2 +- .../unicode_scripts/pre_tokenizer.rs | 2 +- tokenizers/src/pre_tokenizers/whitespace.rs | 2 +- tokenizers/src/processors/bert.rs | 2 +- tokenizers/src/processors/mod.rs | 2 +- tokenizers/src/processors/roberta.rs | 2 +- tokenizers/src/processors/sequence.rs | 2 +- tokenizers/src/processors/template.rs | 2 +- tokenizers/src/tokenizer/added_vocabulary.rs | 2 +- tokenizers/src/tokenizer/mod.rs | 1 - tokenizers/src/tokenizer/normalizer.rs | 2 +- tokenizers/src/utils/padding.rs | 2 +- tokenizers/src/utils/truncation.rs | 2 +- 48 files changed, 64 insertions(+), 67 deletions(-) diff --git a/bindings/python/Cargo.toml b/bindings/python/Cargo.toml index 4b77bb977..5fae356fc 100644 --- a/bindings/python/Cargo.toml +++ b/bindings/python/Cargo.toml @@ -19,15 +19,15 @@ ndarray = "0.15" onig = { version = "6.4", default-features = false } itertools = "0.12" derive_more = "0.99.17" -pyo3 = { version = "0.22", features = ["multiple-pymethods"] } -pyo3_special_method_derive = "0.4" +pyo3 = { version = "0.21", features = ["multiple-pymethods"] } +pyo3_special_method_derive_0_21 = "0.4" [dependencies.tokenizers] path = "../../tokenizers" [dev-dependencies] tempfile = "3.10" -pyo3 = { version = "0.22", features = ["auto-initialize"] } +pyo3 = { version = "0.21", features = ["auto-initialize"] } [features] defaut = ["pyo3/extension-module"] diff --git a/bindings/python/src/decoders.rs b/bindings/python/src/decoders.rs index 6944e8c74..f182a6f56 100644 --- a/bindings/python/src/decoders.rs +++ b/bindings/python/src/decoders.rs @@ -2,10 +2,12 @@ use std::sync::{Arc, RwLock}; use crate::pre_tokenizers::from_string; use crate::utils::PyPattern; -use 
pyo3_special_method_derive::AutoDisplay; use pyo3::exceptions; use pyo3::prelude::*; use pyo3::types::*; +use pyo3_special_method_derive_0_21::AutoDisplay; +use pyo3_special_method_derive_0_21::PyDebug; +use pyo3_special_method_derive_0_21::PyDisplay; use serde::de::Error; use serde::{Deserialize, Deserializer, Serialize, Serializer}; use tk::decoders::bpe::BPEDecoder; @@ -487,11 +489,23 @@ impl PySequenceDecoder { } } -#[derive(Clone, AutoDisplay)] +#[derive(Clone)] pub(crate) struct CustomDecoder { pub inner: PyObject, } +impl PyDisplay for CustomDecoder { + fn fmt_display(&self) -> String { + "CustomDecoder()".to_string() + } +} + +impl PyDebug for CustomDecoder { + fn fmt_debug(&self) -> String { + "CustomDecoder()".to_string() + } +} + impl CustomDecoder { pub(crate) fn new(inner: PyObject) -> Self { CustomDecoder { inner } diff --git a/bindings/python/src/models.rs b/bindings/python/src/models.rs index 33c370fd5..fb6022018 100644 --- a/bindings/python/src/models.rs +++ b/bindings/python/src/models.rs @@ -5,10 +5,10 @@ use std::sync::{Arc, RwLock}; use super::error::{deprecation_warning, ToPyResult}; use crate::token::PyToken; use crate::trainers::PyTrainer; -use pyo3_special_method_derive::AutoDisplay; use pyo3::exceptions; use pyo3::prelude::*; use pyo3::types::*; +use pyo3_special_method_derive_0_21::AutoDisplay; use serde::{Deserialize, Serialize}; use tk::models::bpe::{BpeBuilder, Merges, Vocab, BPE}; use tk::models::unigram::Unigram; diff --git a/bindings/python/src/normalizers.rs b/bindings/python/src/normalizers.rs index 2c117d297..2ec437ddd 100644 --- a/bindings/python/src/normalizers.rs +++ b/bindings/python/src/normalizers.rs @@ -2,10 +2,10 @@ use std::sync::{Arc, RwLock}; use crate::error::ToPyResult; use crate::utils::{PyNormalizedString, PyNormalizedStringRefMut, PyPattern}; -use pyo3_special_method_derive::AutoDisplay; use pyo3::exceptions; use pyo3::prelude::*; use pyo3::types::*; +use pyo3_special_method_derive_0_21::AutoDisplay; use serde::ser::SerializeStruct; use serde::{Deserialize, Deserializer, Serialize, Serializer}; use tk::normalizers::{ diff --git a/bindings/python/src/pre_tokenizers.rs b/bindings/python/src/pre_tokenizers.rs index b71d2dbb7..cb400bdba 100644 --- a/bindings/python/src/pre_tokenizers.rs +++ b/bindings/python/src/pre_tokenizers.rs @@ -23,7 +23,7 @@ use tokenizers as tk; use super::error::ToPyResult; use super::utils::*; -use pyo3_special_method_derive::AutoDisplay; +use pyo3_special_method_derive_0_21::{AutoDisplay, Dict, Dir, Repr, Str}; /// Base class for all pre-tokenizers /// /// This class is not supposed to be instantiated directly. Instead, any implementation of a @@ -181,14 +181,6 @@ impl PyPreTokenizer { .map(|(s, o, _)| (s.to_owned(), o)) .collect()) } - - fn __str__(&self) -> PyResult { - Ok(format!("{}", self.pretok)) - } - - fn __repr__(&self) -> PyResult { - Ok(format!("{}", self.pretok)) - } } macro_rules! 
getter { diff --git a/bindings/python/src/processors.rs b/bindings/python/src/processors.rs index 3341e68b7..2783c9cdb 100644 --- a/bindings/python/src/processors.rs +++ b/bindings/python/src/processors.rs @@ -3,10 +3,10 @@ use std::sync::Arc; use crate::encoding::PyEncoding; use crate::error::ToPyResult; -use pyo3_special_method_derive::AutoDisplay; use pyo3::exceptions; use pyo3::prelude::*; use pyo3::types::*; +use pyo3_special_method_derive_0_21::AutoDisplay; use serde::{Deserialize, Serialize}; use tk::processors::bert::BertProcessing; use tk::processors::byte_level::ByteLevel; diff --git a/bindings/python/src/tokenizer.rs b/bindings/python/src/tokenizer.rs index 5976f6214..18f368b43 100644 --- a/bindings/python/src/tokenizer.rs +++ b/bindings/python/src/tokenizer.rs @@ -10,13 +10,13 @@ use super::pre_tokenizers::PyPreTokenizer; use super::trainers::PyTrainer; use crate::processors::PyPostProcessor; use crate::utils::{MaybeSizedIterator, PyBufferedIterator}; -use pyo3_special_method_derive::AutoDisplay; use numpy::{npyffi, PyArray1}; use pyo3::class::basic::CompareOp; use pyo3::exceptions; use pyo3::intern; use pyo3::prelude::*; use pyo3::types::*; +use pyo3_special_method_derive_0_21::AutoDisplay; use std::collections::BTreeMap; use tk::models::bpe::BPE; use tk::tokenizer::{ @@ -1409,14 +1409,6 @@ impl PyTokenizer { fn set_decoder(&mut self, decoder: PyRef) { self.tokenizer.with_decoder(decoder.clone()); } - - fn __str__(&self) -> PyResult { - Ok(format!("{}", self.tokenizer)) - } - - fn __repr__(&self) -> PyResult { - Ok(format!("{}", self.tokenizer)) - } } #[cfg(test)] diff --git a/tokenizers/Cargo.toml b/tokenizers/Cargo.toml index b186b214d..3bb69911e 100644 --- a/tokenizers/Cargo.toml +++ b/tokenizers/Cargo.toml @@ -63,7 +63,7 @@ fancy-regex = { version = "0.13", optional = true} getrandom = { version = "0.2.10" } esaxx-rs = { version = "0.1.10", default-features = false, features=[]} monostate = "0.1.12" -pyo3_special_method_derive = "0.4" +pyo3_special_method_derive_0_21 = "0.4" [features] default = ["progressbar", "onig", "esaxx_fast"] diff --git a/tokenizers/src/decoders/bpe.rs b/tokenizers/src/decoders/bpe.rs index 0bd807808..933df0023 100644 --- a/tokenizers/src/decoders/bpe.rs +++ b/tokenizers/src/decoders/bpe.rs @@ -1,5 +1,5 @@ use crate::tokenizer::{Decoder, Result}; -use pyo3_special_method_derive::AutoDisplay; +use pyo3_special_method_derive_0_21::AutoDisplay; use serde::{Deserialize, Serialize}; #[derive(Deserialize, Clone, Debug, Serialize, AutoDisplay)] /// Allows decoding Original BPE by joining all the tokens and then replacing diff --git a/tokenizers/src/decoders/byte_fallback.rs b/tokenizers/src/decoders/byte_fallback.rs index 653992be9..b62510eb1 100644 --- a/tokenizers/src/decoders/byte_fallback.rs +++ b/tokenizers/src/decoders/byte_fallback.rs @@ -1,6 +1,6 @@ use crate::tokenizer::{Decoder, Result}; use monostate::MustBe; -use pyo3_special_method_derive::AutoDisplay; +use pyo3_special_method_derive_0_21::AutoDisplay; use serde::{Deserialize, Serialize}; #[derive(Deserialize, Clone, Debug, Serialize, Default, AutoDisplay)] /// ByteFallback is a simple trick which converts tokens looking like `<0x61>` diff --git a/tokenizers/src/decoders/ctc.rs b/tokenizers/src/decoders/ctc.rs index d53fc003a..f56966a5d 100644 --- a/tokenizers/src/decoders/ctc.rs +++ b/tokenizers/src/decoders/ctc.rs @@ -1,7 +1,7 @@ use crate::decoders::wordpiece; use crate::tokenizer::{Decoder, Result}; use itertools::Itertools; -use pyo3_special_method_derive::AutoDisplay; +use 
pyo3_special_method_derive_0_21::AutoDisplay; use serde::{Deserialize, Serialize}; #[derive(Debug, Clone, Serialize, Deserialize, AutoDisplay)] diff --git a/tokenizers/src/decoders/fuse.rs b/tokenizers/src/decoders/fuse.rs index a75977493..ff7346e27 100644 --- a/tokenizers/src/decoders/fuse.rs +++ b/tokenizers/src/decoders/fuse.rs @@ -1,6 +1,6 @@ use crate::tokenizer::{Decoder, Result}; use monostate::MustBe; -use pyo3_special_method_derive::AutoDisplay; +use pyo3_special_method_derive_0_21::AutoDisplay; use serde::{Deserialize, Serialize}; #[derive(Clone, Debug, Serialize, Deserialize, Default, AutoDisplay)] /// Fuse simply fuses all tokens into one big string. diff --git a/tokenizers/src/decoders/mod.rs b/tokenizers/src/decoders/mod.rs index 233c33e30..0dc524760 100644 --- a/tokenizers/src/decoders/mod.rs +++ b/tokenizers/src/decoders/mod.rs @@ -21,7 +21,7 @@ use crate::normalizers::replace::Replace; use crate::pre_tokenizers::byte_level::ByteLevel; use crate::pre_tokenizers::metaspace::Metaspace; use crate::{Decoder, Result}; -use pyo3_special_method_derive::AutoDisplay; +use pyo3_special_method_derive_0_21::AutoDisplay; use serde::{Deserialize, Serialize}; #[derive(Serialize, Deserialize, Clone, Debug, AutoDisplay)] diff --git a/tokenizers/src/decoders/sequence.rs b/tokenizers/src/decoders/sequence.rs index f4e6cfb20..aef8de5ef 100644 --- a/tokenizers/src/decoders/sequence.rs +++ b/tokenizers/src/decoders/sequence.rs @@ -1,7 +1,7 @@ use crate::decoders::DecoderWrapper; use crate::tokenizer::{Decoder, Result}; use crate::utils::macro_rules_attribute; -use pyo3_special_method_derive::AutoDisplay; +use pyo3_special_method_derive_0_21::AutoDisplay; use serde::{Deserialize, Serialize}; #[macro_rules_attribute(impl_serde_type!)] diff --git a/tokenizers/src/decoders/strip.rs b/tokenizers/src/decoders/strip.rs index 01a35fc75..f47fa8ed3 100644 --- a/tokenizers/src/decoders/strip.rs +++ b/tokenizers/src/decoders/strip.rs @@ -1,6 +1,6 @@ use crate::tokenizer::{Decoder, Result}; -use pyo3_special_method_derive::AutoDisplay; +use pyo3_special_method_derive_0_21::AutoDisplay; use serde::{Deserialize, Serialize}; #[derive(Deserialize, Clone, Debug, Serialize, Default, AutoDisplay)] /// Strip is a simple trick which converts tokens looking like `<0x61>` diff --git a/tokenizers/src/decoders/wordpiece.rs b/tokenizers/src/decoders/wordpiece.rs index 69d168d4e..5a5ba86c3 100644 --- a/tokenizers/src/decoders/wordpiece.rs +++ b/tokenizers/src/decoders/wordpiece.rs @@ -1,6 +1,6 @@ use crate::tokenizer::{Decoder, Result}; -use pyo3_special_method_derive::AutoDisplay; +use pyo3_special_method_derive_0_21::AutoDisplay; use serde::{Deserialize, Serialize}; #[derive(Deserialize, Clone, Debug, Serialize, AutoDisplay)] /// The WordPiece decoder takes care of decoding a list of wordpiece tokens diff --git a/tokenizers/src/models/mod.rs b/tokenizers/src/models/mod.rs index 4806a3f53..977a6fd78 100644 --- a/tokenizers/src/models/mod.rs +++ b/tokenizers/src/models/mod.rs @@ -8,7 +8,7 @@ pub mod wordpiece; use std::collections::HashMap; use std::path::{Path, PathBuf}; -use pyo3_special_method_derive::AutoDisplay; +use pyo3_special_method_derive_0_21::AutoDisplay; use serde::{Deserialize, Serialize, Serializer}; use crate::models::bpe::{BpeTrainer, BPE}; diff --git a/tokenizers/src/models/unigram/model.rs b/tokenizers/src/models/unigram/model.rs index 9915ce4d5..705d9983e 100644 --- a/tokenizers/src/models/unigram/model.rs +++ b/tokenizers/src/models/unigram/model.rs @@ -6,7 +6,7 @@ use super::{ use 
crate::tokenizer::{Model, Result, Token}; use crate::utils::cache::Cache; -use pyo3_special_method_derive::AutoDisplay; +use pyo3_special_method_derive_0_21::AutoDisplay; use std::collections::HashMap; use std::convert::TryInto; use std::fs::read_to_string; diff --git a/tokenizers/src/models/wordlevel/mod.rs b/tokenizers/src/models/wordlevel/mod.rs index 0ad0412ad..40e62c147 100644 --- a/tokenizers/src/models/wordlevel/mod.rs +++ b/tokenizers/src/models/wordlevel/mod.rs @@ -1,6 +1,6 @@ use super::OrderedVocabIter; use crate::tokenizer::{Model, Result, Token}; -use pyo3_special_method_derive::AutoDisplay; +use pyo3_special_method_derive_0_21::AutoDisplay; use serde_json::Value; use std::collections::HashMap; use std::fs::File; diff --git a/tokenizers/src/models/wordpiece/mod.rs b/tokenizers/src/models/wordpiece/mod.rs index ffd8f9b05..d877b72ba 100644 --- a/tokenizers/src/models/wordpiece/mod.rs +++ b/tokenizers/src/models/wordpiece/mod.rs @@ -3,7 +3,7 @@ use crate::models::bpe::BPE; use crate::tokenizer::{Model, Result, Token}; -use pyo3_special_method_derive::AutoDisplay; +use pyo3_special_method_derive_0_21::AutoDisplay; use std::{ borrow::Cow, collections::HashMap, diff --git a/tokenizers/src/normalizers/bert.rs b/tokenizers/src/normalizers/bert.rs index 255642d30..63805f676 100644 --- a/tokenizers/src/normalizers/bert.rs +++ b/tokenizers/src/normalizers/bert.rs @@ -1,6 +1,6 @@ use crate::tokenizer::{NormalizedString, Normalizer, Result}; -use pyo3_special_method_derive::AutoDisplay; +use pyo3_special_method_derive_0_21::AutoDisplay; use serde::{Deserialize, Serialize}; use unicode_categories::UnicodeCategories; /// Checks whether a character is whitespace diff --git a/tokenizers/src/normalizers/mod.rs b/tokenizers/src/normalizers/mod.rs index 6122967d9..efeeafabd 100644 --- a/tokenizers/src/normalizers/mod.rs +++ b/tokenizers/src/normalizers/mod.rs @@ -17,7 +17,7 @@ pub use crate::normalizers::utils::{Lowercase, Sequence}; use serde::{Deserialize, Serialize}; use crate::{NormalizedString, Normalizer}; -use pyo3_special_method_derive::AutoDisplay; +use pyo3_special_method_derive_0_21::AutoDisplay; /// Wrapper for known Normalizers. 
#[derive(Clone, Debug, Deserialize, Serialize, AutoDisplay)] diff --git a/tokenizers/src/normalizers/prepend.rs b/tokenizers/src/normalizers/prepend.rs index 27b0a3b50..cd7f047af 100644 --- a/tokenizers/src/normalizers/prepend.rs +++ b/tokenizers/src/normalizers/prepend.rs @@ -1,5 +1,5 @@ use crate::tokenizer::{NormalizedString, Normalizer, Result}; -use pyo3_special_method_derive::AutoDisplay; +use pyo3_special_method_derive_0_21::AutoDisplay; use serde::{Deserialize, Serialize}; #[derive(Clone, Debug, Deserialize, Serialize, AutoDisplay)] diff --git a/tokenizers/src/normalizers/replace.rs b/tokenizers/src/normalizers/replace.rs index 0bbedae51..85e670551 100644 --- a/tokenizers/src/normalizers/replace.rs +++ b/tokenizers/src/normalizers/replace.rs @@ -2,7 +2,7 @@ use crate::tokenizer::pattern::Pattern; use crate::tokenizer::Decoder; use crate::tokenizer::{NormalizedString, Normalizer, Result}; use crate::utils::SysRegex; -use pyo3_special_method_derive::AutoDisplay; +use pyo3_special_method_derive_0_21::AutoDisplay; use serde::{Deserialize, Serialize}; /// Represents the different patterns that `Replace` can use #[derive(Debug, Clone, PartialEq, Serialize, Deserialize, Eq)] diff --git a/tokenizers/src/normalizers/strip.rs b/tokenizers/src/normalizers/strip.rs index 78d517862..453a8c18a 100644 --- a/tokenizers/src/normalizers/strip.rs +++ b/tokenizers/src/normalizers/strip.rs @@ -1,6 +1,6 @@ use crate::tokenizer::{NormalizedString, Normalizer, Result}; use crate::utils::macro_rules_attribute; -use pyo3_special_method_derive::AutoDisplay; +use pyo3_special_method_derive_0_21::AutoDisplay; use serde::{Deserialize, Serialize}; use unicode_normalization_alignments::char::is_combining_mark; #[derive(Copy, Clone, Debug, Deserialize, Serialize, AutoDisplay)] diff --git a/tokenizers/src/normalizers/unicode.rs b/tokenizers/src/normalizers/unicode.rs index a203db700..80a5dead9 100644 --- a/tokenizers/src/normalizers/unicode.rs +++ b/tokenizers/src/normalizers/unicode.rs @@ -1,6 +1,6 @@ use crate::tokenizer::{NormalizedString, Normalizer, Result}; use crate::utils::macro_rules_attribute; -use pyo3_special_method_derive::AutoDisplay; +use pyo3_special_method_derive_0_21::AutoDisplay; #[derive(Default, Copy, Clone, Debug, AutoDisplay)] #[macro_rules_attribute(impl_serde_type!)] diff --git a/tokenizers/src/normalizers/utils.rs b/tokenizers/src/normalizers/utils.rs index fd9f15f25..b2ba9d069 100644 --- a/tokenizers/src/normalizers/utils.rs +++ b/tokenizers/src/normalizers/utils.rs @@ -3,7 +3,7 @@ use serde::{Deserialize, Serialize}; use crate::normalizers::NormalizerWrapper; use crate::tokenizer::{NormalizedString, Normalizer, Result}; use crate::utils::macro_rules_attribute; -use pyo3_special_method_derive::AutoDisplay; +use pyo3_special_method_derive_0_21::AutoDisplay; #[derive(Clone, Deserialize, Debug, Serialize, AutoDisplay)] #[serde(tag = "type")] /// Allows concatenating multiple other Normalizer as a Sequence. 
diff --git a/tokenizers/src/pre_tokenizers/bert.rs b/tokenizers/src/pre_tokenizers/bert.rs index 5a8f1ebc0..b50f4b118 100644 --- a/tokenizers/src/pre_tokenizers/bert.rs +++ b/tokenizers/src/pre_tokenizers/bert.rs @@ -1,6 +1,6 @@ use crate::tokenizer::{PreTokenizedString, PreTokenizer, Result, SplitDelimiterBehavior}; use crate::utils::macro_rules_attribute; -use pyo3_special_method_derive::AutoDisplay; +use pyo3_special_method_derive_0_21::AutoDisplay; use unicode_categories::UnicodeCategories; fn is_bert_punc(x: char) -> bool { diff --git a/tokenizers/src/pre_tokenizers/byte_level.rs b/tokenizers/src/pre_tokenizers/byte_level.rs index 0b37cf243..7749afc18 100644 --- a/tokenizers/src/pre_tokenizers/byte_level.rs +++ b/tokenizers/src/pre_tokenizers/byte_level.rs @@ -6,7 +6,7 @@ use crate::tokenizer::{ }; use crate::utils::macro_rules_attribute; use crate::utils::SysRegex; -use pyo3_special_method_derive::AutoDisplay; +use pyo3_special_method_derive_0_21::AutoDisplay; use serde::{Deserialize, Serialize}; /// Converts bytes to unicode characters. diff --git a/tokenizers/src/pre_tokenizers/delimiter.rs b/tokenizers/src/pre_tokenizers/delimiter.rs index fcfd13aee..25e5fefd9 100644 --- a/tokenizers/src/pre_tokenizers/delimiter.rs +++ b/tokenizers/src/pre_tokenizers/delimiter.rs @@ -1,4 +1,4 @@ -use pyo3_special_method_derive::AutoDisplay; +use pyo3_special_method_derive_0_21::AutoDisplay; use serde::{Deserialize, Serialize}; use crate::tokenizer::{PreTokenizedString, PreTokenizer, Result, SplitDelimiterBehavior}; diff --git a/tokenizers/src/pre_tokenizers/digits.rs b/tokenizers/src/pre_tokenizers/digits.rs index 3cb3326b4..65c7c30bb 100644 --- a/tokenizers/src/pre_tokenizers/digits.rs +++ b/tokenizers/src/pre_tokenizers/digits.rs @@ -1,4 +1,4 @@ -use pyo3_special_method_derive::AutoDisplay; +use pyo3_special_method_derive_0_21::AutoDisplay; use serde::{Deserialize, Serialize}; use crate::tokenizer::{PreTokenizedString, PreTokenizer, Result, SplitDelimiterBehavior}; diff --git a/tokenizers/src/pre_tokenizers/metaspace.rs b/tokenizers/src/pre_tokenizers/metaspace.rs index 83744eed2..a0e9ddc80 100644 --- a/tokenizers/src/pre_tokenizers/metaspace.rs +++ b/tokenizers/src/pre_tokenizers/metaspace.rs @@ -1,5 +1,5 @@ use crate::tokenizer::{Decoder, PreTokenizedString, PreTokenizer, Result, SplitDelimiterBehavior}; -use pyo3_special_method_derive::AutoDisplay; +use pyo3_special_method_derive_0_21::AutoDisplay; use serde::{de, Deserialize, Deserializer, Serialize}; /// Enum representing options for the metaspace prepending scheme. 
#[derive(Debug, Clone, PartialEq, Serialize, Eq, Deserialize, Copy, AutoDisplay)] diff --git a/tokenizers/src/pre_tokenizers/mod.rs b/tokenizers/src/pre_tokenizers/mod.rs index 617442d9b..53a01de91 100644 --- a/tokenizers/src/pre_tokenizers/mod.rs +++ b/tokenizers/src/pre_tokenizers/mod.rs @@ -22,7 +22,7 @@ use crate::pre_tokenizers::split::Split; use crate::pre_tokenizers::unicode_scripts::UnicodeScripts; use crate::pre_tokenizers::whitespace::{Whitespace, WhitespaceSplit}; use crate::{PreTokenizedString, PreTokenizer}; -use pyo3_special_method_derive::AutoDisplay; +use pyo3_special_method_derive_0_21::AutoDisplay; #[derive(Deserialize, Serialize, Clone, Debug, PartialEq, AutoDisplay)] #[auto_display(fmt="pre_tokenizers.{}")] diff --git a/tokenizers/src/pre_tokenizers/punctuation.rs b/tokenizers/src/pre_tokenizers/punctuation.rs index a61ce16f0..fab237586 100644 --- a/tokenizers/src/pre_tokenizers/punctuation.rs +++ b/tokenizers/src/pre_tokenizers/punctuation.rs @@ -1,4 +1,4 @@ -use pyo3_special_method_derive::AutoDisplay; +use pyo3_special_method_derive_0_21::AutoDisplay; use serde::{Deserialize, Serialize}; use crate::tokenizer::{PreTokenizedString, PreTokenizer, Result, SplitDelimiterBehavior}; diff --git a/tokenizers/src/pre_tokenizers/sequence.rs b/tokenizers/src/pre_tokenizers/sequence.rs index 8f5c591b5..b215eba8d 100644 --- a/tokenizers/src/pre_tokenizers/sequence.rs +++ b/tokenizers/src/pre_tokenizers/sequence.rs @@ -1,7 +1,7 @@ use crate::pre_tokenizers::PreTokenizerWrapper; use crate::tokenizer::{PreTokenizedString, PreTokenizer, Result}; use crate::utils::macro_rules_attribute; -use pyo3_special_method_derive::AutoDisplay; +use pyo3_special_method_derive_0_21::AutoDisplay; use serde::{Deserialize, Serialize}; #[macro_rules_attribute(impl_serde_type!)] diff --git a/tokenizers/src/pre_tokenizers/split.rs b/tokenizers/src/pre_tokenizers/split.rs index 2edcd723b..eaa3a16e7 100644 --- a/tokenizers/src/pre_tokenizers/split.rs +++ b/tokenizers/src/pre_tokenizers/split.rs @@ -2,7 +2,7 @@ use crate::tokenizer::{ pattern::Invert, PreTokenizedString, PreTokenizer, Result, SplitDelimiterBehavior, }; use crate::utils::SysRegex; -use pyo3_special_method_derive::AutoDisplay; +use pyo3_special_method_derive_0_21::AutoDisplay; use serde::{Deserialize, Deserializer, Serialize}; /// Represents the different patterns that `Split` can use diff --git a/tokenizers/src/pre_tokenizers/unicode_scripts/pre_tokenizer.rs b/tokenizers/src/pre_tokenizers/unicode_scripts/pre_tokenizer.rs index 405810a95..7fa905a39 100644 --- a/tokenizers/src/pre_tokenizers/unicode_scripts/pre_tokenizer.rs +++ b/tokenizers/src/pre_tokenizers/unicode_scripts/pre_tokenizer.rs @@ -1,4 +1,4 @@ -use pyo3_special_method_derive::AutoDisplay; +use pyo3_special_method_derive_0_21::AutoDisplay; use crate::pre_tokenizers::unicode_scripts::scripts::{get_script, Script}; use crate::tokenizer::{normalizer::Range, PreTokenizedString, PreTokenizer, Result}; diff --git a/tokenizers/src/pre_tokenizers/whitespace.rs b/tokenizers/src/pre_tokenizers/whitespace.rs index efd0f9202..d7f044a0e 100644 --- a/tokenizers/src/pre_tokenizers/whitespace.rs +++ b/tokenizers/src/pre_tokenizers/whitespace.rs @@ -1,4 +1,4 @@ -use pyo3_special_method_derive::AutoDisplay; +use pyo3_special_method_derive_0_21::AutoDisplay; use regex::Regex; use crate::tokenizer::{ diff --git a/tokenizers/src/processors/bert.rs b/tokenizers/src/processors/bert.rs index 8f0ef327c..eb7bffd00 100644 --- a/tokenizers/src/processors/bert.rs +++ b/tokenizers/src/processors/bert.rs @@ 
-1,5 +1,5 @@ use crate::tokenizer::{Encoding, PostProcessor, Result}; -use pyo3_special_method_derive::AutoDisplay; +use pyo3_special_method_derive_0_21::AutoDisplay; use serde::{Deserialize, Serialize}; use std::collections::HashMap; use std::iter::FromIterator; diff --git a/tokenizers/src/processors/mod.rs b/tokenizers/src/processors/mod.rs index 742a58e7a..399080086 100644 --- a/tokenizers/src/processors/mod.rs +++ b/tokenizers/src/processors/mod.rs @@ -12,7 +12,7 @@ use crate::processors::roberta::RobertaProcessing; use crate::processors::sequence::Sequence; use crate::processors::template::TemplateProcessing; use crate::{Encoding, PostProcessor, Result}; -use pyo3_special_method_derive::AutoDisplay; +use pyo3_special_method_derive_0_21::AutoDisplay; use serde::{Deserialize, Serialize}; #[derive(Serialize, Deserialize, PartialEq, Debug, Clone, Eq, AutoDisplay)] diff --git a/tokenizers/src/processors/roberta.rs b/tokenizers/src/processors/roberta.rs index 248858ec7..f74fb5009 100644 --- a/tokenizers/src/processors/roberta.rs +++ b/tokenizers/src/processors/roberta.rs @@ -1,6 +1,6 @@ use crate::processors::byte_level::process_offsets; use crate::tokenizer::{Encoding, PostProcessor, Result}; -use pyo3_special_method_derive::AutoDisplay; +use pyo3_special_method_derive_0_21::AutoDisplay; use serde::{Deserialize, Serialize}; use std::collections::HashMap; use std::iter::FromIterator; diff --git a/tokenizers/src/processors/sequence.rs b/tokenizers/src/processors/sequence.rs index 39b34d08e..d0f5db6a7 100644 --- a/tokenizers/src/processors/sequence.rs +++ b/tokenizers/src/processors/sequence.rs @@ -1,7 +1,7 @@ use crate::processors::PostProcessorWrapper; use crate::tokenizer::{Encoding, PostProcessor, Result}; use crate::utils::macro_rules_attribute; -use pyo3_special_method_derive::AutoDisplay; +use pyo3_special_method_derive_0_21::AutoDisplay; use serde::{Deserialize, Serialize}; #[macro_rules_attribute(impl_serde_type!)] #[derive(Clone, Debug, PartialEq, Eq, AutoDisplay)] diff --git a/tokenizers/src/processors/template.rs b/tokenizers/src/processors/template.rs index 3cf8cc932..61bcd30ca 100644 --- a/tokenizers/src/processors/template.rs +++ b/tokenizers/src/processors/template.rs @@ -57,7 +57,7 @@ //! 
use crate::{Encoding, PostProcessor, Result}; use itertools::Itertools; -use pyo3_special_method_derive::AutoDisplay; +use pyo3_special_method_derive_0_21::AutoDisplay; use serde::{Deserialize, Serialize}; use std::collections::{HashMap, HashSet}; use std::convert::{TryFrom, TryInto}; diff --git a/tokenizers/src/tokenizer/added_vocabulary.rs b/tokenizers/src/tokenizer/added_vocabulary.rs index adae8d4c7..72c8306e5 100644 --- a/tokenizers/src/tokenizer/added_vocabulary.rs +++ b/tokenizers/src/tokenizer/added_vocabulary.rs @@ -2,7 +2,7 @@ use super::{ normalizer::Range, Model, NormalizedString, Normalizer, Offsets, PreTokenizedString, Token, }; use aho_corasick::{AhoCorasick, AhoCorasickBuilder, MatchKind}; -use pyo3_special_method_derive::AutoDisplay; +use pyo3_special_method_derive_0_21::AutoDisplay; use regex::Regex; use serde::{ser::SerializeSeq, Deserialize, Serialize, Serializer}; use std::collections::{HashMap, HashSet}; diff --git a/tokenizers/src/tokenizer/mod.rs b/tokenizers/src/tokenizer/mod.rs index 0cec8f38c..c653b2654 100644 --- a/tokenizers/src/tokenizer/mod.rs +++ b/tokenizers/src/tokenizer/mod.rs @@ -1348,7 +1348,6 @@ where mod tests { use super::Tokenizer; use crate::AddedToken; - use crate::Tokenizer; #[cfg(feature = "http")] #[test] diff --git a/tokenizers/src/tokenizer/normalizer.rs b/tokenizers/src/tokenizer/normalizer.rs index 125a8f880..912734274 100644 --- a/tokenizers/src/tokenizer/normalizer.rs +++ b/tokenizers/src/tokenizer/normalizer.rs @@ -1,6 +1,6 @@ use crate::pattern::Pattern; use crate::{Offsets, Result}; -use pyo3_special_method_derive::AutoDisplay; +use pyo3_special_method_derive_0_21::AutoDisplay; use serde::{Deserialize, Serialize}; use std::ops::{Bound, RangeBounds}; use unicode_normalization_alignments::UnicodeNormalization; diff --git a/tokenizers/src/utils/padding.rs b/tokenizers/src/utils/padding.rs index 3e311f9b5..018f3eb3b 100644 --- a/tokenizers/src/utils/padding.rs +++ b/tokenizers/src/utils/padding.rs @@ -1,6 +1,6 @@ use crate::parallelism::*; use crate::tokenizer::{Encoding, Result}; -use pyo3_special_method_derive::AutoDisplay; +use pyo3_special_method_derive_0_21::AutoDisplay; use serde::{Deserialize, Serialize}; /// The various possible padding directions. 
diff --git a/tokenizers/src/utils/truncation.rs b/tokenizers/src/utils/truncation.rs index 541df0e7f..863e7785d 100644 --- a/tokenizers/src/utils/truncation.rs +++ b/tokenizers/src/utils/truncation.rs @@ -1,5 +1,5 @@ use crate::tokenizer::{Encoding, Result}; -use pyo3_special_method_derive::AutoDisplay; +use pyo3_special_method_derive_0_21::AutoDisplay; use serde::{Deserialize, Serialize}; use std::cmp; use std::mem; From 7db61094fd388acf15824ab4bd8dec4285d8f29f Mon Sep 17 00:00:00 2001 From: Arthur Zucker Date: Sun, 21 Jul 2024 09:51:50 +0200 Subject: [PATCH 81/94] stash commit, wanna make sure this is recorded --- bindings/python/py_src/tokenizers/__init__.pyi | 9 +++++++++ bindings/python/src/normalizers.rs | 13 ++++--------- bindings/python/src/pre_tokenizers.rs | 6 ++++-- bindings/python/src/tokenizer.rs | 7 ++++--- tokenizers/src/pre_tokenizers/mod.rs | 4 ++-- tokenizers/src/pre_tokenizers/sequence.rs | 3 ++- 6 files changed, 25 insertions(+), 17 deletions(-) diff --git a/bindings/python/py_src/tokenizers/__init__.pyi b/bindings/python/py_src/tokenizers/__init__.pyi index 3ecef4089..06b8621e0 100644 --- a/bindings/python/py_src/tokenizers/__init__.pyi +++ b/bindings/python/py_src/tokenizers/__init__.pyi @@ -980,6 +980,15 @@ class Tokenizer: """ pass + def get_added_tokens_decoder(self): + """ + Get the underlying vocabulary + + Returns: + :obj:`Dict[int, AddedToken]`: The vocabulary + """ + pass + def get_vocab(self, with_added_tokens=True): """ Get the underlying vocabulary diff --git a/bindings/python/src/normalizers.rs b/bindings/python/src/normalizers.rs index 2ec437ddd..d4a4bbb69 100644 --- a/bindings/python/src/normalizers.rs +++ b/bindings/python/src/normalizers.rs @@ -5,7 +5,7 @@ use crate::utils::{PyNormalizedString, PyNormalizedStringRefMut, PyPattern}; use pyo3::exceptions; use pyo3::prelude::*; use pyo3::types::*; -use pyo3_special_method_derive_0_21::AutoDisplay; +use pyo3_special_method_derive_0_21::{AutoDisplay, Str, Dir, Dict}; use serde::ser::SerializeStruct; use serde::{Deserialize, Deserializer, Serialize, Serializer}; use tk::normalizers::{ @@ -43,7 +43,7 @@ impl PyNormalizedStringMut<'_> { /// This class is not supposed to be instantiated directly. Instead, any implementation of a /// Normalizer will return an instance of this class when instantiated. #[pyclass(dict, module = "tokenizers.normalizers", name = "Normalizer", subclass)] -#[derive(Clone, Serialize, Deserialize, AutoDisplay, Debug)] +#[derive(Clone, Serialize, Deserialize, Str, Debug, Dir, Dict)] pub struct PyNormalizer { #[serde(flatten)] pub(crate) normalizer: PyNormalizerTypeWrapper, @@ -169,13 +169,6 @@ impl PyNormalizer { ToPyResult(self.normalizer.normalize(&mut normalized)).into_py()?; Ok(normalized.get().to_owned()) } - - fn __str__(&self) -> PyResult { - Ok(format!("{}", self.normalizer)) - } - fn __repr__(&self) -> PyResult { - Ok(format!("{}", self.normalizer)) - } } macro_rules! getter { @@ -484,6 +477,7 @@ impl PyNmt { /// Precompiled normalizer /// Don't use manually it is used for compatiblity for SentencePiece. 
#[pyclass(extends=PyNormalizer, module = "tokenizers.normalizers", name = "Precompiled")] +#[derive(Str)] pub struct PyPrecompiled {} #[pymethods] impl PyPrecompiled { @@ -522,6 +516,7 @@ impl PyReplace { #[derive(Debug, Clone, AutoDisplay)] pub(crate) struct CustomNormalizer { + #[auto_display] inner: PyObject, } impl CustomNormalizer { diff --git a/bindings/python/src/pre_tokenizers.rs b/bindings/python/src/pre_tokenizers.rs index cb400bdba..fc06f4ce2 100644 --- a/bindings/python/src/pre_tokenizers.rs +++ b/bindings/python/src/pre_tokenizers.rs @@ -34,10 +34,10 @@ use pyo3_special_method_derive_0_21::{AutoDisplay, Dict, Dir, Repr, Str}; name = "PreTokenizer", subclass )] -#[derive(Clone, Serialize, Deserialize, Str, Repr, Dir, Dict)] +#[derive(Clone, Serialize, Deserialize, Str, Dir)] pub struct PyPreTokenizer { #[serde(flatten)] - pub(crate) pretok: PyPreTokenizerTypeWrapper, + pub pretok: PyPreTokenizerTypeWrapper, } impl PyPreTokenizer { @@ -425,6 +425,8 @@ impl PyPunctuation { /// This pre-tokenizer composes other pre_tokenizers and applies them in sequence #[pyclass(extends=PyPreTokenizer, module = "tokenizers.pre_tokenizers", name = "Sequence")] +#[derive(AutoDisplay)] +#[auto_display(fmt = "{self.inner}{}")] pub struct PySequence {} #[pymethods] impl PySequence { diff --git a/bindings/python/src/tokenizer.rs b/bindings/python/src/tokenizer.rs index 18f368b43..a3a8a4ebb 100644 --- a/bindings/python/src/tokenizer.rs +++ b/bindings/python/src/tokenizer.rs @@ -16,7 +16,7 @@ use pyo3::exceptions; use pyo3::intern; use pyo3::prelude::*; use pyo3::types::*; -use pyo3_special_method_derive_0_21::AutoDisplay; +use pyo3_special_method_derive_0_21::{Str, Dict, Dir, Repr, AutoDisplay}; use std::collections::BTreeMap; use tk::models::bpe::BPE; use tk::tokenizer::{ @@ -462,7 +462,7 @@ type Tokenizer = TokenizerImpl BTreeMap { let mut sorted_map = BTreeMap::new(); diff --git a/tokenizers/src/pre_tokenizers/mod.rs b/tokenizers/src/pre_tokenizers/mod.rs index 53a01de91..eb4b20d43 100644 --- a/tokenizers/src/pre_tokenizers/mod.rs +++ b/tokenizers/src/pre_tokenizers/mod.rs @@ -22,9 +22,9 @@ use crate::pre_tokenizers::split::Split; use crate::pre_tokenizers::unicode_scripts::UnicodeScripts; use crate::pre_tokenizers::whitespace::{Whitespace, WhitespaceSplit}; use crate::{PreTokenizedString, PreTokenizer}; -use pyo3_special_method_derive_0_21::AutoDisplay; +use pyo3_special_method_derive_0_21::{AutoDebug, AutoDisplay}; -#[derive(Deserialize, Serialize, Clone, Debug, PartialEq, AutoDisplay)] +#[derive(Deserialize, Serialize, Clone, PartialEq, AutoDisplay, AutoDebug)] #[auto_display(fmt="pre_tokenizers.{}")] #[serde(untagged)] pub enum PreTokenizerWrapper { diff --git a/tokenizers/src/pre_tokenizers/sequence.rs b/tokenizers/src/pre_tokenizers/sequence.rs index b215eba8d..07267fe59 100644 --- a/tokenizers/src/pre_tokenizers/sequence.rs +++ b/tokenizers/src/pre_tokenizers/sequence.rs @@ -7,7 +7,8 @@ use serde::{Deserialize, Serialize}; #[macro_rules_attribute(impl_serde_type!)] #[derive(Clone, Debug, PartialEq, AutoDisplay)] pub struct Sequence { - pretokenizers: Vec, + #[auto_display] + pub pretokenizers: Vec, } impl Sequence { From c7cd92759ae1ddf6ca28d6488e0dd7e7b8644d7c Mon Sep 17 00:00:00 2001 From: Arthur Zucker Date: Thu, 25 Jul 2024 18:17:06 +0200 Subject: [PATCH 82/94] what works a bit ? 
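
This patch swaps the remaining hand-written Display / __str__ / __repr__ code for the derive macros from the local pyo3_special_method_derive_0_21 checkout, and steers the generated output with #[format(...)] attributes on the container and on individual fields instead of manual formatting.

A minimal sketch of the pattern being adopted (the type and fields below are made up for illustration, and the attribute semantics are inferred from how they are used in this series rather than from the macro's documentation):

    use pyo3_special_method_derive_0_21::AutoDisplay;

    // `fmt` on the container is the outer template; `#[format]` opts a field
    // into the rendered output, while `#[format(skip)]` leaves it out.
    #[derive(AutoDisplay)]
    #[format(fmt = "Wrapper({})")]
    struct Wrapper {
        #[format]
        name: String,
        #[format(skip)]
        cache: Vec<u8>,
    }

    // With the assumed semantics, printing a Wrapper would yield something
    // like `Wrapper(name=...)`, never exposing the skipped `cache` field.
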
--- bindings/python/Cargo.toml | 2 +- bindings/python/src/decoders.rs | 28 ++--- bindings/python/src/models.rs | 10 +- bindings/python/src/normalizers.rs | 4 +- bindings/python/src/pre_tokenizers.rs | 8 +- bindings/python/src/processors.rs | 1 + bindings/python/src/tokenizer.rs | 1 + bindings/python/src/utils/normalization.rs | 1 - tokenizers/Cargo.toml | 2 +- tokenizers/src/decoders/byte_fallback.rs | 2 +- tokenizers/src/decoders/fuse.rs | 2 +- tokenizers/src/decoders/mod.rs | 2 +- tokenizers/src/models/bpe/model.rs | 102 ++++++++++--------- tokenizers/src/models/mod.rs | 2 +- tokenizers/src/models/wordlevel/mod.rs | 2 +- tokenizers/src/models/wordpiece/mod.rs | 4 +- tokenizers/src/normalizers/byte_level.rs | 4 +- tokenizers/src/normalizers/mod.rs | 4 +- tokenizers/src/normalizers/replace.rs | 2 +- tokenizers/src/pre_tokenizers/mod.rs | 3 +- tokenizers/src/pre_tokenizers/sequence.rs | 4 +- tokenizers/src/processors/mod.rs | 5 +- tokenizers/src/processors/template.rs | 14 +-- tokenizers/src/tokenizer/added_vocabulary.rs | 7 +- tokenizers/src/tokenizer/mod.rs | 38 +++---- tokenizers/src/utils/padding.rs | 2 - tokenizers/src/utils/truncation.rs | 2 +- 27 files changed, 125 insertions(+), 133 deletions(-) diff --git a/bindings/python/Cargo.toml b/bindings/python/Cargo.toml index 5fae356fc..9607112d0 100644 --- a/bindings/python/Cargo.toml +++ b/bindings/python/Cargo.toml @@ -20,7 +20,7 @@ onig = { version = "6.4", default-features = false } itertools = "0.12" derive_more = "0.99.17" pyo3 = { version = "0.21", features = ["multiple-pymethods"] } -pyo3_special_method_derive_0_21 = "0.4" +pyo3_special_method_derive_0_21 = {path = "../../../pyo3-special-method-derive/pyo3_special_method_derive_0_21"} [dependencies.tokenizers] path = "../../tokenizers" diff --git a/bindings/python/src/decoders.rs b/bindings/python/src/decoders.rs index f182a6f56..2bb1a23f2 100644 --- a/bindings/python/src/decoders.rs +++ b/bindings/python/src/decoders.rs @@ -7,7 +7,7 @@ use pyo3::prelude::*; use pyo3::types::*; use pyo3_special_method_derive_0_21::AutoDisplay; use pyo3_special_method_derive_0_21::PyDebug; -use pyo3_special_method_derive_0_21::PyDisplay; +use pyo3_special_method_derive_0_21::Str; use serde::de::Error; use serde::{Deserialize, Deserializer, Serialize, Serializer}; use tk::decoders::bpe::BPEDecoder; @@ -31,9 +31,11 @@ use super::error::ToPyResult; /// This class is not supposed to be instantiated directly. Instead, any implementation of /// a Decoder will return an instance of this class when instantiated. #[pyclass(dict, module = "tokenizers.decoders", name = "Decoder", subclass)] -#[derive(Clone, Deserialize, Serialize, AutoDisplay)] +#[derive(Clone, Deserialize, Serialize, Str)] +#[format(fmt="")] pub struct PyDecoder { #[serde(flatten)] + #[format(fmt="{}")] pub(crate) decoder: PyDecoderWrapper, } @@ -117,14 +119,6 @@ impl PyDecoder { fn decode(&self, tokens: Vec) -> PyResult { ToPyResult(self.decoder.decode(tokens)).into() } - - fn __str__(&self) -> PyResult { - Ok(format!("{}", self.decoder)) - } - - fn __repr__(&self) -> PyResult { - Ok(format!("{}", self.decoder)) - } } macro_rules! 
getter { @@ -489,22 +483,12 @@ impl PySequenceDecoder { } } -#[derive(Clone)] +#[derive(Clone, AutoDisplay)] pub(crate) struct CustomDecoder { + #[format(skip)] pub inner: PyObject, } -impl PyDisplay for CustomDecoder { - fn fmt_display(&self) -> String { - "CustomDecoder()".to_string() - } -} - -impl PyDebug for CustomDecoder { - fn fmt_debug(&self) -> String { - "CustomDecoder()".to_string() - } -} impl CustomDecoder { pub(crate) fn new(inner: PyObject) -> Self { diff --git a/bindings/python/src/models.rs b/bindings/python/src/models.rs index fb6022018..b3f1a77b3 100644 --- a/bindings/python/src/models.rs +++ b/bindings/python/src/models.rs @@ -9,6 +9,7 @@ use pyo3::exceptions; use pyo3::prelude::*; use pyo3::types::*; use pyo3_special_method_derive_0_21::AutoDisplay; +use pyo3_special_method_derive_0_21::Str; use serde::{Deserialize, Serialize}; use tk::models::bpe::{BpeBuilder, Merges, Vocab, BPE}; use tk::models::unigram::Unigram; @@ -25,9 +26,10 @@ use tokenizers as tk; /// /// This class cannot be constructed directly. Please use one of the concrete models. #[pyclass(module = "tokenizers.models", name = "Model", subclass)] -#[derive(Clone, Serialize, Deserialize, AutoDisplay)] +#[derive(Clone, Serialize, Deserialize, Str)] pub struct PyModel { #[serde(flatten)] + #[format(fmt="{}")] pub model: Arc>, } @@ -220,12 +222,6 @@ impl PyModel { fn get_trainer(&self, py: Python<'_>) -> PyResult { PyTrainer::from(self.model.read().unwrap().get_trainer()).get_as_subtype(py) } - fn __str__(&self) -> PyResult { - Ok(format!("{}", self.model.read().unwrap())) - } - fn __repr__(&self) -> PyResult { - Ok(format!("{}", self.model.read().unwrap())) - } } /// An implementation of the BPE (Byte-Pair Encoding) algorithm diff --git a/bindings/python/src/normalizers.rs b/bindings/python/src/normalizers.rs index d4a4bbb69..5eed43a8d 100644 --- a/bindings/python/src/normalizers.rs +++ b/bindings/python/src/normalizers.rs @@ -478,7 +478,9 @@ impl PyNmt { /// Don't use manually it is used for compatiblity for SentencePiece. 
#[pyclass(extends=PyNormalizer, module = "tokenizers.normalizers", name = "Precompiled")] #[derive(Str)] +#[format("PreCompiled")] pub struct PyPrecompiled {} + #[pymethods] impl PyPrecompiled { #[new] @@ -516,7 +518,7 @@ impl PyReplace { #[derive(Debug, Clone, AutoDisplay)] pub(crate) struct CustomNormalizer { - #[auto_display] + #[format(fmt="Custom Normalizer")] inner: PyObject, } impl CustomNormalizer { diff --git a/bindings/python/src/pre_tokenizers.rs b/bindings/python/src/pre_tokenizers.rs index fc06f4ce2..792624d1a 100644 --- a/bindings/python/src/pre_tokenizers.rs +++ b/bindings/python/src/pre_tokenizers.rs @@ -37,7 +37,8 @@ use pyo3_special_method_derive_0_21::{AutoDisplay, Dict, Dir, Repr, Str}; #[derive(Clone, Serialize, Deserialize, Str, Dir)] pub struct PyPreTokenizer { #[serde(flatten)] - pub pretok: PyPreTokenizerTypeWrapper, + #[format(fmt = "{}")] + pretok: PyPreTokenizerTypeWrapper, } impl PyPreTokenizer { @@ -426,7 +427,7 @@ impl PyPunctuation { /// This pre-tokenizer composes other pre_tokenizers and applies them in sequence #[pyclass(extends=PyPreTokenizer, module = "tokenizers.pre_tokenizers", name = "Sequence")] #[derive(AutoDisplay)] -#[auto_display(fmt = "{self.inner}{}")] +#[format(fmt = "Sequence.{}")] pub struct PySequence {} #[pymethods] impl PySequence { @@ -637,6 +638,7 @@ impl<'de> Deserialize<'de> for CustomPreTokenizer { #[serde(untagged)] pub(crate) enum PyPreTokenizerWrapper { Custom(CustomPreTokenizer), + #[format(fmt = "wrapped:{}")] Wrapped(PreTokenizerWrapper), } @@ -655,7 +657,9 @@ impl Serialize for PyPreTokenizerWrapper { #[derive(Clone, Deserialize, AutoDisplay)] #[serde(untagged)] pub(crate) enum PyPreTokenizerTypeWrapper { + #[format(fmt = "{}")] Sequence(Vec>>), + #[format(fmt = "{}")] Single(Arc>), } diff --git a/bindings/python/src/processors.rs b/bindings/python/src/processors.rs index 2783c9cdb..0d234871a 100644 --- a/bindings/python/src/processors.rs +++ b/bindings/python/src/processors.rs @@ -28,6 +28,7 @@ use tokenizers as tk; subclass )] #[derive(Clone, Deserialize, Serialize, AutoDisplay)] +#[format(fmt="post processor: {}.{}")] pub struct PyPostProcessor { #[serde(flatten)] pub processor: Arc, diff --git a/bindings/python/src/tokenizer.rs b/bindings/python/src/tokenizer.rs index a3a8a4ebb..a6ad1e396 100644 --- a/bindings/python/src/tokenizer.rs +++ b/bindings/python/src/tokenizer.rs @@ -464,6 +464,7 @@ type Tokenizer = TokenizerImpl) -> std::fmt::Result { - let mut vocab_vec: Vec<_> = self.vocab.iter().collect(); - vocab_vec.sort_by_key(|&(_, v)| v); - vocab_vec.truncate(5); - - let vocab_str: String = vocab_vec - .iter() - .map(|(k, v)| format!("'{}':{}", k, v)) - .collect::>() - .join(", "); - - let mut merges_vec: Vec<_> = self.merges.iter().collect(); - merges_vec.truncate(5); - merges_vec.sort_by_key(|&(_, v)| v); - - let merges_str: String = merges_vec - .iter() - .map(|((id1, id2), _)| { - ( - self.vocab_r - .get(id1) - .cloned() - .unwrap_or_else(|| id1.to_string()), - self.vocab_r - .get(id2) - .cloned() - .unwrap_or_else(|| id2.to_string()), - ) - }) - .map(|(id1, id2)| format!("('{}', '{}')", id1, id2)) - .collect::>() - .join(", "); - - write!( - f, - "BPE(vocab={{{}, ...}}, merges=[{:?}, ...], dropout={:?}, unk_token={:?}, continuing_subword_prefix={:?}, end_of_word_suffix={:?}, fuse_unk={}, byte_fallback={}, ignore_merges={})", - vocab_str, - merges_str, - self.dropout, - self.unk_token, - self.continuing_subword_prefix, - self.end_of_word_suffix, - self.fuse_unk, - self.byte_fallback, - self.ignore_merges - ) - } -} 
+// impl std::fmt::Display for BPE { +// fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { +// let mut vocab_vec: Vec<_> = self.vocab.iter().collect(); +// vocab_vec.sort_by_key(|&(_, v)| v); +// vocab_vec.truncate(5); + +// let vocab_str: String = vocab_vec +// .iter() +// .map(|(k, v)| format!("'{}':{}", k, v)) +// .collect::>() +// .join(", "); + +// let mut merges_vec: Vec<_> = self.merges.iter().collect(); +// merges_vec.truncate(5); +// merges_vec.sort_by_key(|&(_, v)| v); + +// let merges_str: String = merges_vec +// .iter() +// .map(|((id1, id2), _)| { +// ( +// self.vocab_r +// .get(id1) +// .cloned() +// .unwrap_or_else(|| id1.to_string()), +// self.vocab_r +// .get(id2) +// .cloned() +// .unwrap_or_else(|| id2.to_string()), +// ) +// }) +// .map(|(id1, id2)| format!("('{}', '{}')", id1, id2)) +// .collect::>() +// .join(", "); + +// write!( +// f, +// "BPE(vocab={{{}, ...}}, merges=[{:?}, ...], dropout={:?}, unk_token={:?}, continuing_subword_prefix={:?}, end_of_word_suffix={:?}, fuse_unk={}, byte_fallback={}, ignore_merges={})", +// vocab_str, +// merges_str, +// self.dropout, +// self.unk_token, +// self.continuing_subword_prefix, +// self.end_of_word_suffix, +// self.fuse_unk, +// self.byte_fallback, +// self.ignore_merges +// ) +// } +// } impl Default for BPE { fn default() -> Self { diff --git a/tokenizers/src/models/mod.rs b/tokenizers/src/models/mod.rs index 977a6fd78..9f7a9101d 100644 --- a/tokenizers/src/models/mod.rs +++ b/tokenizers/src/models/mod.rs @@ -8,7 +8,7 @@ pub mod wordpiece; use std::collections::HashMap; use std::path::{Path, PathBuf}; -use pyo3_special_method_derive_0_21::AutoDisplay; +use pyo3_special_method_derive_0_21::{AutoDisplay, PyDisplay}; use serde::{Deserialize, Serialize, Serializer}; use crate::models::bpe::{BpeTrainer, BPE}; diff --git a/tokenizers/src/models/wordlevel/mod.rs b/tokenizers/src/models/wordlevel/mod.rs index 40e62c147..18a50a761 100644 --- a/tokenizers/src/models/wordlevel/mod.rs +++ b/tokenizers/src/models/wordlevel/mod.rs @@ -96,7 +96,7 @@ impl WordLevelBuilder { #[derive(PartialEq, Clone, Eq, AutoDisplay)] pub struct WordLevel { - #[auto_display] + #[format] vocab: HashMap, vocab_r: HashMap, pub unk_token: String, diff --git a/tokenizers/src/models/wordpiece/mod.rs b/tokenizers/src/models/wordpiece/mod.rs index d877b72ba..079533ba0 100644 --- a/tokenizers/src/models/wordpiece/mod.rs +++ b/tokenizers/src/models/wordpiece/mod.rs @@ -121,12 +121,12 @@ impl WordPieceBuilder { /// model. 
#[derive(Clone, PartialEq, Eq, AutoDisplay)] pub struct WordPiece { - #[auto_display] + #[format] vocab: Vocab, vocab_r: VocabR, pub unk_token: String, pub continuing_subword_prefix: String, - #[auto_display(skip)] + #[format(skip)] pub max_input_chars_per_word: usize, } diff --git a/tokenizers/src/normalizers/byte_level.rs b/tokenizers/src/normalizers/byte_level.rs index 42c7fa510..44d5b83b1 100644 --- a/tokenizers/src/normalizers/byte_level.rs +++ b/tokenizers/src/normalizers/byte_level.rs @@ -2,8 +2,8 @@ use crate::processors::byte_level::bytes_char; use crate::tokenizer::{NormalizedString, Normalizer, Result}; use serde::{Deserialize, Serialize}; use std::collections::{HashMap, HashSet}; - -#[derive(Clone, Debug, Deserialize, Serialize)] +use pyo3_special_method_derive_0_21::AutoDisplay; +#[derive(Clone, Debug, Deserialize, Serialize, AutoDisplay)] #[serde(tag = "type")] pub struct ByteLevel {} diff --git a/tokenizers/src/normalizers/mod.rs b/tokenizers/src/normalizers/mod.rs index efeeafabd..b813a3ce3 100644 --- a/tokenizers/src/normalizers/mod.rs +++ b/tokenizers/src/normalizers/mod.rs @@ -22,7 +22,7 @@ use pyo3_special_method_derive_0_21::AutoDisplay; /// Wrapper for known Normalizers. #[derive(Clone, Debug, Deserialize, Serialize, AutoDisplay)] #[serde(untagged)] -#[auto_display(fmt = "normalizers.{}")] +#[format(fmt = "normalizers.{}")] pub enum NormalizerWrapper { BertNormalizer(BertNormalizer), StripNormalizer(Strip), @@ -34,7 +34,7 @@ pub enum NormalizerWrapper { Sequence(Sequence), Lowercase(Lowercase), Nmt(Nmt), - #[auto_display(fmt = "Precompiled()")] + #[format(skip)] Precompiled(Precompiled), Replace(Replace), Prepend(Prepend), diff --git a/tokenizers/src/normalizers/replace.rs b/tokenizers/src/normalizers/replace.rs index 85e670551..dda4a331a 100644 --- a/tokenizers/src/normalizers/replace.rs +++ b/tokenizers/src/normalizers/replace.rs @@ -46,7 +46,7 @@ impl std::convert::TryFrom for Replace { #[serde(tag = "type", try_from = "ReplaceDeserializer")] pub struct Replace { pattern: ReplacePattern, - #[auto_display] + #[format] content: String, #[serde(skip)] regex: SysRegex, diff --git a/tokenizers/src/pre_tokenizers/mod.rs b/tokenizers/src/pre_tokenizers/mod.rs index eb4b20d43..8baf874f8 100644 --- a/tokenizers/src/pre_tokenizers/mod.rs +++ b/tokenizers/src/pre_tokenizers/mod.rs @@ -24,8 +24,7 @@ use crate::pre_tokenizers::whitespace::{Whitespace, WhitespaceSplit}; use crate::{PreTokenizedString, PreTokenizer}; use pyo3_special_method_derive_0_21::{AutoDebug, AutoDisplay}; -#[derive(Deserialize, Serialize, Clone, PartialEq, AutoDisplay, AutoDebug)] -#[auto_display(fmt="pre_tokenizers.{}")] +#[derive(Deserialize, Serialize, Clone, PartialEq, AutoDisplay)] #[serde(untagged)] pub enum PreTokenizerWrapper { BertPreTokenizer(BertPreTokenizer), diff --git a/tokenizers/src/pre_tokenizers/sequence.rs b/tokenizers/src/pre_tokenizers/sequence.rs index 07267fe59..f379fc894 100644 --- a/tokenizers/src/pre_tokenizers/sequence.rs +++ b/tokenizers/src/pre_tokenizers/sequence.rs @@ -5,9 +5,9 @@ use pyo3_special_method_derive_0_21::AutoDisplay; use serde::{Deserialize, Serialize}; #[macro_rules_attribute(impl_serde_type!)] -#[derive(Clone, Debug, PartialEq, AutoDisplay)] +#[derive(Clone, PartialEq, AutoDisplay)] pub struct Sequence { - #[auto_display] + #[format] pub pretokenizers: Vec, } diff --git a/tokenizers/src/processors/mod.rs b/tokenizers/src/processors/mod.rs index 399080086..6b6ed8acd 100644 --- a/tokenizers/src/processors/mod.rs +++ b/tokenizers/src/processors/mod.rs @@ -16,14 
+16,17 @@ use pyo3_special_method_derive_0_21::AutoDisplay; use serde::{Deserialize, Serialize}; #[derive(Serialize, Deserialize, PartialEq, Debug, Clone, Eq, AutoDisplay)] -#[auto_display(fmt="processors.{}")] +#[format(fmt="")] #[serde(untagged)] pub enum PostProcessorWrapper { // Roberta must be before Bert for deserialization (serde does not validate tags) Roberta(RobertaProcessing), Bert(BertProcessing), + #[format(fmt="{}.{}")] ByteLevel(ByteLevel), + #[format(fmt="{}")] Template(TemplateProcessing), + #[format(fmt="{}")] Sequence(Sequence), } diff --git a/tokenizers/src/processors/template.rs b/tokenizers/src/processors/template.rs index 61bcd30ca..56842e6b0 100644 --- a/tokenizers/src/processors/template.rs +++ b/tokenizers/src/processors/template.rs @@ -63,7 +63,7 @@ use std::collections::{HashMap, HashSet}; use std::convert::{TryFrom, TryInto}; use std::result::Result as StdResult; /// Represents any sequences received as input of the PostProcessor -#[derive(Debug, Clone, PartialEq, Serialize, Deserialize, Eq)] +#[derive(Debug, Clone, PartialEq, Serialize, Deserialize, Eq, AutoDisplay)] pub enum Sequence { /// This is the first sequence, the one that is always specified A, @@ -91,7 +91,7 @@ pub enum Sequence { /// /// [`SpecialToken`]: struct.SpecialToken.html /// -#[derive(Debug, Clone, PartialEq, Serialize, Deserialize, Eq)] +#[derive(Debug, Clone, PartialEq, Serialize, Deserialize, Eq, AutoDisplay)] pub enum Piece { Sequence { id: Sequence, type_id: u32 }, SpecialToken { id: String, type_id: u32 }, @@ -249,7 +249,7 @@ impl SpecialToken { /// /// [`Piece`]: enum.Piece.html /// -#[derive(Debug, Clone, PartialEq, Serialize, Deserialize, Eq)] +#[derive(Debug, Clone, PartialEq, Serialize, Deserialize, Eq, AutoDisplay)] #[serde(transparent)] pub struct Template(Vec); @@ -335,15 +335,17 @@ impl From> for Tokens { #[derive(Debug, Clone, PartialEq, Builder, Serialize, Deserialize, Eq, AutoDisplay)] #[serde(tag = "type", from = "TemplateProcessingDeserializer")] #[builder(build_fn(validate = "Self::validate"))] +#[format(fmt="TemplateProcessing: ")] pub struct TemplateProcessing { #[builder(try_setter, default = "\"$0\".try_into().unwrap()")] - single: Template, + pub single: Template, #[builder(try_setter, default = "\"$A:0 $B:1\".try_into().unwrap()")] - pair: Template, + pub pair: Template, #[builder(setter(skip), default = "self.default_added(true)")] #[serde(skip)] - added_single: usize, + pub added_single: usize, #[builder(setter(skip), default = "self.default_added(false)")] + #[format] #[serde(skip)] added_pair: usize, #[builder(setter(into), default)] diff --git a/tokenizers/src/tokenizer/added_vocabulary.rs b/tokenizers/src/tokenizer/added_vocabulary.rs index 72c8306e5..f23a70ac2 100644 --- a/tokenizers/src/tokenizer/added_vocabulary.rs +++ b/tokenizers/src/tokenizer/added_vocabulary.rs @@ -140,14 +140,13 @@ fn space_rightmost_at_start(sentence: &str) -> usize { /// exist as required. /// #[derive(Clone, Debug, AutoDisplay)] -#[auto_display(fmt="AddedVocabulary(added_tokens_map_r={{{}, ...}}, encode_special_tokens={})", "&(0..=5).fold(String::new(), |mut acc, key| {if let Some(token) = added_tokens_map_r.get(&key){if !acc.is_empty(){acc.push_str(\", \");}acc.push_str(&format!(\"\n\t{}: {}\", key, &token.to_string()));}acc})", encode_special_tokens)] pub struct AddedVocabulary { /// Contains the mapping from String (token content) to ID. This map contains both special /// tokens and classic added tokens that were added to the this vocabulary. 
added_tokens_map: HashMap, /// Contains the mapping from ID to AddedToken for all the added tokens, both special /// and classic. - #[auto_display] + #[format] added_tokens_map_r: HashMap, /// Contains only the classic AddedToken, in the specific order the user gave them. @@ -157,7 +156,7 @@ pub struct AddedVocabulary { /// A Set, containing all the special token for easy access while decoding. This let's /// us remove them easily with an O(1) complexity. - #[auto_display] + #[format] special_tokens_set: HashSet, /// A RegexSet containing all the non-normalized patterns used to split on AddedTokens @@ -166,7 +165,7 @@ pub struct AddedVocabulary { split_normalized_trie: MatchingSet, /// Whether or not special tokens should be splitted when encoding. This is equivalent to ignoring them - #[auto_display] + #[format(fmt = "{}")] encode_special_tokens: bool, } diff --git a/tokenizers/src/tokenizer/mod.rs b/tokenizers/src/tokenizer/mod.rs index c653b2654..73f47a893 100644 --- a/tokenizers/src/tokenizer/mod.rs +++ b/tokenizers/src/tokenizer/mod.rs @@ -21,6 +21,7 @@ extern crate rayon; use crate::utils::iter::ResultShunt; use crate::utils::parallelism::*; use crate::utils::progress::{ProgressBar, ProgressStyle}; +use pyo3_special_method_derive_0_21::{AutoDisplay, PyDisplay}; use rayon::current_thread_index; use serde::de::DeserializeOwned; use serde::{Deserialize, Serialize}; @@ -405,8 +406,9 @@ where } } -#[derive(Serialize, Deserialize, Debug, Clone)] +#[derive(Serialize, Deserialize, Clone, AutoDisplay)] pub struct Tokenizer( + #[format(fmt = "tokenizer = {}")] TokenizerImpl< ModelWrapper, NormalizerWrapper, @@ -525,46 +527,45 @@ pub struct TokenizerImpl { padding: Option, } -impl std::fmt::Display for TokenizerImpl +impl PyDisplay for TokenizerImpl where - M: std::fmt::Display, - N: std::fmt::Display, - PT: std::fmt::Display, - PP: std::fmt::Display, - D: std::fmt::Display, + M: PyDisplay, + N: PyDisplay, + PT: PyDisplay, + PP: PyDisplay, + D: PyDisplay, { - fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + fn fmt_display(&self) -> std::string::String { let normalizer_str = match &self.normalizer { - Some(n) => format!("{}", n), + Some(n) => format!("{}", n.fmt_display()), None => "None".to_string(), }; let pre_tokenizer_str = match &self.pre_tokenizer { - Some(pt) => format!("{}", pt), + Some(pt) => format!("{}", pt.fmt_display()), None => "None".to_string(), }; let post_processor_str = match &self.post_processor { - Some(pp) => format!("{}", pp), + Some(pp) => format!("{}", pp.fmt_display()), None => "None".to_string(), }; let decoder_str = match &self.decoder { - Some(d) => format!("{}", d), + Some(d) => format!("{}", d.fmt_display()), None => "None".to_string(), }; let truncation_str = match &self.truncation { - Some(t) => format!("{:?}", t), + Some(t) => format!("{}", t.fmt_display()), None => "None".to_string(), }; let padding_str = match &self.padding { - Some(p) => format!("{:?}", p), + Some(p) => format!("{}", p.fmt_display()), None => "None".to_string(), }; - write!( - f, + format!( "Tokenizer(normalizer={}, pre_tokenizer={}, model={}, post_processor={}, decoder={}, added_tokens_decoder={:?}, truncation={}, padding={})", normalizer_str, pre_tokenizer_str, - self.model, + self.model.fmt_display(), post_processor_str, decoder_str, self.added_vocabulary, @@ -1399,4 +1400,5 @@ mod tests { let decoded = tokenizer.decode(encoded.get_ids(), false); assert_eq!(decoded.unwrap(), "Hey! 
how is this token: д") } -} \ No newline at end of file +} + diff --git a/tokenizers/src/utils/padding.rs b/tokenizers/src/utils/padding.rs index 018f3eb3b..6481fc096 100644 --- a/tokenizers/src/utils/padding.rs +++ b/tokenizers/src/utils/padding.rs @@ -20,11 +20,9 @@ impl std::convert::AsRef for PaddingDirection { } #[derive(Debug, Clone, Serialize, Deserialize, AutoDisplay)] -#[auto_display(fmt = "")] pub struct PaddingParams { pub strategy: PaddingStrategy, pub direction: PaddingDirection, - #[auto_display(skip)] // usize not supported yet pub pad_to_multiple_of: Option, pub pad_id: u32, pub pad_type_id: u32, diff --git a/tokenizers/src/utils/truncation.rs b/tokenizers/src/utils/truncation.rs index 863e7785d..5780f33bf 100644 --- a/tokenizers/src/utils/truncation.rs +++ b/tokenizers/src/utils/truncation.rs @@ -24,7 +24,7 @@ impl std::convert::AsRef for TruncationDirection { pub struct TruncationParams { #[serde(default)] pub direction: TruncationDirection, - #[auto_display(skip)] + #[format(skip)] pub max_length: usize, pub strategy: TruncationStrategy, pub stride: usize, From e4cf65a838a5c882a79ff09184905fb429eb9639 Mon Sep 17 00:00:00 2001 From: Arthur Zucker Date: Thu, 25 Jul 2024 19:07:30 +0200 Subject: [PATCH 83/94] update --- tokenizers/src/pre_tokenizers/mod.rs | 3 ++- tokenizers/src/pre_tokenizers/sequence.rs | 2 +- tokenizers/src/processors/template.rs | 2 +- 3 files changed, 4 insertions(+), 3 deletions(-) diff --git a/tokenizers/src/pre_tokenizers/mod.rs b/tokenizers/src/pre_tokenizers/mod.rs index 8baf874f8..2de3625e1 100644 --- a/tokenizers/src/pre_tokenizers/mod.rs +++ b/tokenizers/src/pre_tokenizers/mod.rs @@ -24,8 +24,9 @@ use crate::pre_tokenizers::whitespace::{Whitespace, WhitespaceSplit}; use crate::{PreTokenizedString, PreTokenizer}; use pyo3_special_method_derive_0_21::{AutoDebug, AutoDisplay}; -#[derive(Deserialize, Serialize, Clone, PartialEq, AutoDisplay)] +#[derive(Deserialize, Serialize, Clone, PartialEq, AutoDisplay, Debug)] #[serde(untagged)] +#[format(fmt = "pre_tokenizers.{}")] pub enum PreTokenizerWrapper { BertPreTokenizer(BertPreTokenizer), ByteLevel(ByteLevel), diff --git a/tokenizers/src/pre_tokenizers/sequence.rs b/tokenizers/src/pre_tokenizers/sequence.rs index f379fc894..c57076d9f 100644 --- a/tokenizers/src/pre_tokenizers/sequence.rs +++ b/tokenizers/src/pre_tokenizers/sequence.rs @@ -5,7 +5,7 @@ use pyo3_special_method_derive_0_21::AutoDisplay; use serde::{Deserialize, Serialize}; #[macro_rules_attribute(impl_serde_type!)] -#[derive(Clone, PartialEq, AutoDisplay)] +#[derive(Clone, PartialEq, AutoDisplay, Debug)] pub struct Sequence { #[format] pub pretokenizers: Vec, diff --git a/tokenizers/src/processors/template.rs b/tokenizers/src/processors/template.rs index 56842e6b0..f30da61ae 100644 --- a/tokenizers/src/processors/template.rs +++ b/tokenizers/src/processors/template.rs @@ -335,7 +335,7 @@ impl From> for Tokens { #[derive(Debug, Clone, PartialEq, Builder, Serialize, Deserialize, Eq, AutoDisplay)] #[serde(tag = "type", from = "TemplateProcessingDeserializer")] #[builder(build_fn(validate = "Self::validate"))] -#[format(fmt="TemplateProcessing: ")] +#[format(fmt = "TemplateProcessing: {}")] pub struct TemplateProcessing { #[builder(try_setter, default = "\"$0\".try_into().unwrap()")] pub single: Template, From 39ffc28dd3d1cd0cedae0fc8a93ff5d619526249 Mon Sep 17 00:00:00 2001 From: Arthur Zucker Date: Sat, 27 Jul 2024 09:23:26 +0200 Subject: [PATCH 84/94] fix tokenizer's wrapping --- bindings/python/src/pre_tokenizers.rs | 3 ++- 
bindings/python/src/tokenizer.rs | 5 +++-- tokenizers/src/tokenizer/mod.rs | 1 - 3 files changed, 5 insertions(+), 4 deletions(-) diff --git a/bindings/python/src/pre_tokenizers.rs b/bindings/python/src/pre_tokenizers.rs index 792624d1a..6187f2cf7 100644 --- a/bindings/python/src/pre_tokenizers.rs +++ b/bindings/python/src/pre_tokenizers.rs @@ -35,9 +35,10 @@ use pyo3_special_method_derive_0_21::{AutoDisplay, Dict, Dir, Repr, Str}; subclass )] #[derive(Clone, Serialize, Deserialize, Str, Dir)] +#[format("")] // don't format the Py wrapper pub struct PyPreTokenizer { #[serde(flatten)] - #[format(fmt = "{}")] + #[format(fmt = "{}")] // format only pretok, not pretok = pretok: PyPreTokenizerTypeWrapper, } diff --git a/bindings/python/src/tokenizer.rs b/bindings/python/src/tokenizer.rs index a6ad1e396..b389a22e3 100644 --- a/bindings/python/src/tokenizer.rs +++ b/bindings/python/src/tokenizer.rs @@ -16,7 +16,7 @@ use pyo3::exceptions; use pyo3::intern; use pyo3::prelude::*; use pyo3::types::*; -use pyo3_special_method_derive_0_21::{Str, Dict, Dir, Repr, AutoDisplay}; +use pyo3_special_method_derive_0_21::{AutoDisplay, Dict, Dir, Repr, Str}; use std::collections::BTreeMap; use tk::models::bpe::BPE; use tk::tokenizer::{ @@ -463,8 +463,9 @@ type Tokenizer = TokenizerImpl Date: Sat, 27 Jul 2024 09:29:00 +0200 Subject: [PATCH 85/94] fix normalizer display --- bindings/python/src/normalizers.rs | 27 +++++++-------------------- 1 file changed, 7 insertions(+), 20 deletions(-) diff --git a/bindings/python/src/normalizers.rs b/bindings/python/src/normalizers.rs index 5eed43a8d..41b518381 100644 --- a/bindings/python/src/normalizers.rs +++ b/bindings/python/src/normalizers.rs @@ -44,8 +44,10 @@ impl PyNormalizedStringMut<'_> { /// Normalizer will return an instance of this class when instantiated. #[pyclass(dict, module = "tokenizers.normalizers", name = "Normalizer", subclass)] #[derive(Clone, Serialize, Deserialize, Str, Debug, Dir, Dict)] +#[format(fmt="{}")] pub struct PyNormalizer { #[serde(flatten)] + #[format(fmt="{}")] pub(crate) normalizer: PyNormalizerTypeWrapper, } @@ -562,8 +564,11 @@ impl<'de> Deserialize<'de> for CustomNormalizer { #[derive(Debug, Clone, Deserialize, AutoDisplay)] #[serde(untagged)] +#[format(fmt="{}")] pub(crate) enum PyNormalizerWrapper { + #[format(fmt="{}")] Custom(CustomNormalizer), + #[format(fmt="{}")] Wrapped(NormalizerWrapper), } @@ -579,32 +584,14 @@ impl Serialize for PyNormalizerWrapper { } } -#[derive(Debug, Clone, Deserialize)] +#[derive(Debug, Clone, Deserialize, AutoDisplay)] #[serde(untagged)] +#[format(fmt="{}")] pub(crate) enum PyNormalizerTypeWrapper { Sequence(Vec>>), Single(Arc>), } -// Implement the Display trait for PyNormalizerTypeWrapper -impl std::fmt::Display for PyNormalizerTypeWrapper { - fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result { - match self { - PyNormalizerTypeWrapper::Sequence(ref decoders) => { - for decoder in decoders { - let decoder = decoder.read().unwrap(); - writeln!(f, "{}", decoder)?; - } - Ok(()) - } - PyNormalizerTypeWrapper::Single(ref decoder) => { - let decoder = decoder.read().unwrap(); - write!(f, "{}", decoder) - } - } - } -} - impl Serialize for PyNormalizerTypeWrapper { fn serialize(&self, serializer: S) -> Result where From c436b23f25bfa4a779c52c516b4f71e160373369 Mon Sep 17 00:00:00 2001 From: Arthur Zucker Date: Sat, 27 Jul 2024 10:47:57 +0200 Subject: [PATCH 86/94] fix! 
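
Make the wrappers purely pass-through for formatting: PyDecoder, PyModel, PyPreTokenizer, PyPostProcessor and the wrapper enums now all use #[format(fmt = "{}")] on the container, on the wrapped field and on each variant, so printing the Python-side object gives exactly what the inner Rust type renders, without the earlier "wrapped:" or "post processor:" prefixes. Tokenizer's display also goes through fmt_display() for the added vocabulary so the whole chain uses the same code path.

Rough shape of the delegation, as a self-contained sketch (the types here are invented, and the pass-through behaviour of #[format(fmt = "{}")] is assumed from its use in this series):

    use pyo3_special_method_derive_0_21::AutoDisplay;

    #[derive(AutoDisplay)]
    struct Lowercase {}

    #[derive(AutoDisplay)]
    struct Strip {}

    // The wrapper adds nothing of its own: its Display output is whatever
    // the payload of the active variant prints.
    #[derive(AutoDisplay)]
    #[format(fmt = "{}")]
    enum AnyNormalizer {
        #[format(fmt = "{}")]
        Lowercase(Lowercase),
        #[format(fmt = "{}")]
        Strip(Strip),
    }
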
--- bindings/python/src/decoders.rs | 5 ++++- bindings/python/src/models.rs | 1 + bindings/python/src/pre_tokenizers.rs | 7 +++++-- bindings/python/src/processors.rs | 3 ++- tokenizers/src/models/bpe/model.rs | 1 - tokenizers/src/models/mod.rs | 1 + tokenizers/src/pre_tokenizers/mod.rs | 1 + tokenizers/src/tokenizer/added_vocabulary.rs | 2 +- tokenizers/src/tokenizer/mod.rs | 4 ++-- 9 files changed, 17 insertions(+), 8 deletions(-) diff --git a/bindings/python/src/decoders.rs b/bindings/python/src/decoders.rs index 2bb1a23f2..b21bdbe43 100644 --- a/bindings/python/src/decoders.rs +++ b/bindings/python/src/decoders.rs @@ -32,7 +32,7 @@ use super::error::ToPyResult; /// a Decoder will return an instance of this class when instantiated. #[pyclass(dict, module = "tokenizers.decoders", name = "Decoder", subclass)] #[derive(Clone, Deserialize, Serialize, Str)] -#[format(fmt="")] +#[format(fmt="{}")] pub struct PyDecoder { #[serde(flatten)] #[format(fmt="{}")] @@ -540,8 +540,11 @@ impl<'de> Deserialize<'de> for CustomDecoder { #[derive(Clone, Deserialize, Serialize, AutoDisplay)] #[serde(untagged)] +#[format(fmt = "{}")] pub(crate) enum PyDecoderWrapper { + #[format(fmt = "{}")] Custom(Arc>), + #[format(fmt = "{}")] Wrapped(Arc>), } diff --git a/bindings/python/src/models.rs b/bindings/python/src/models.rs index b3f1a77b3..7a4cdc19d 100644 --- a/bindings/python/src/models.rs +++ b/bindings/python/src/models.rs @@ -27,6 +27,7 @@ use tokenizers as tk; /// This class cannot be constructed directly. Please use one of the concrete models. #[pyclass(module = "tokenizers.models", name = "Model", subclass)] #[derive(Clone, Serialize, Deserialize, Str)] +#[format(fmt = "{}")] pub struct PyModel { #[serde(flatten)] #[format(fmt="{}")] diff --git a/bindings/python/src/pre_tokenizers.rs b/bindings/python/src/pre_tokenizers.rs index 6187f2cf7..c90a6d1b2 100644 --- a/bindings/python/src/pre_tokenizers.rs +++ b/bindings/python/src/pre_tokenizers.rs @@ -35,7 +35,7 @@ use pyo3_special_method_derive_0_21::{AutoDisplay, Dict, Dir, Repr, Str}; subclass )] #[derive(Clone, Serialize, Deserialize, Str, Dir)] -#[format("")] // don't format the Py wrapper +#[format(fmt = "{}")] // don't format the Py wrapper pub struct PyPreTokenizer { #[serde(flatten)] #[format(fmt = "{}")] // format only pretok, not pretok = @@ -637,9 +637,11 @@ impl<'de> Deserialize<'de> for CustomPreTokenizer { #[derive(Clone, Deserialize, AutoDisplay)] #[serde(untagged)] +#[format(fmt = "{}")] pub(crate) enum PyPreTokenizerWrapper { + #[format(fmt = "{}")] Custom(CustomPreTokenizer), - #[format(fmt = "wrapped:{}")] + #[format(fmt = "{}")] Wrapped(PreTokenizerWrapper), } @@ -657,6 +659,7 @@ impl Serialize for PyPreTokenizerWrapper { #[derive(Clone, Deserialize, AutoDisplay)] #[serde(untagged)] +#[format(fmt = "{}")] pub(crate) enum PyPreTokenizerTypeWrapper { #[format(fmt = "{}")] Sequence(Vec>>), diff --git a/bindings/python/src/processors.rs b/bindings/python/src/processors.rs index 0d234871a..9ff9d0328 100644 --- a/bindings/python/src/processors.rs +++ b/bindings/python/src/processors.rs @@ -28,9 +28,10 @@ use tokenizers as tk; subclass )] #[derive(Clone, Deserialize, Serialize, AutoDisplay)] -#[format(fmt="post processor: {}.{}")] +#[format(fmt="{}")] pub struct PyPostProcessor { #[serde(flatten)] + #[format(fmt="{}")] pub processor: Arc, } diff --git a/tokenizers/src/models/bpe/model.rs b/tokenizers/src/models/bpe/model.rs index 53328d3cc..24233281f 100644 --- a/tokenizers/src/models/bpe/model.rs +++ b/tokenizers/src/models/bpe/model.rs @@ 
-208,7 +208,6 @@ impl BpeBuilder { #[derive(PartialEq, AutoDisplay)] pub struct BPE { /// The vocabulary assigns a number to each token. - #[format(skip)] pub(crate) vocab: Vocab, /// Reversed vocabulary, to rebuild sentences. pub(crate) vocab_r: VocabR, diff --git a/tokenizers/src/models/mod.rs b/tokenizers/src/models/mod.rs index 9f7a9101d..bfbb69eab 100644 --- a/tokenizers/src/models/mod.rs +++ b/tokenizers/src/models/mod.rs @@ -60,6 +60,7 @@ impl<'a> Serialize for OrderedVocabIter<'a> { #[derive(Deserialize, Serialize, Debug, PartialEq, Clone, AutoDisplay)] #[serde(untagged)] +#[format(fmt = "{}")] // TODO by default this should define the finale render {}, {} {} . or {}{}{} pub enum ModelWrapper { BPE(BPE), // WordPiece must stay before WordLevel here for deserialization (for retrocompatibility diff --git a/tokenizers/src/pre_tokenizers/mod.rs b/tokenizers/src/pre_tokenizers/mod.rs index 2de3625e1..b7388e7aa 100644 --- a/tokenizers/src/pre_tokenizers/mod.rs +++ b/tokenizers/src/pre_tokenizers/mod.rs @@ -33,6 +33,7 @@ pub enum PreTokenizerWrapper { Delimiter(CharDelimiterSplit), Metaspace(Metaspace), Whitespace(Whitespace), + #[format(fmt="{}")] Sequence(Sequence), Split(Split), Punctuation(Punctuation), diff --git a/tokenizers/src/tokenizer/added_vocabulary.rs b/tokenizers/src/tokenizer/added_vocabulary.rs index f23a70ac2..fb6e288d2 100644 --- a/tokenizers/src/tokenizer/added_vocabulary.rs +++ b/tokenizers/src/tokenizer/added_vocabulary.rs @@ -165,7 +165,7 @@ pub struct AddedVocabulary { split_normalized_trie: MatchingSet, /// Whether or not special tokens should be splitted when encoding. This is equivalent to ignoring them - #[format(fmt = "{}")] + #[format] encode_special_tokens: bool, } diff --git a/tokenizers/src/tokenizer/mod.rs b/tokenizers/src/tokenizer/mod.rs index 0ebdc9c09..3f5684e21 100644 --- a/tokenizers/src/tokenizer/mod.rs +++ b/tokenizers/src/tokenizer/mod.rs @@ -561,13 +561,13 @@ where }; format!( - "Tokenizer(normalizer={}, pre_tokenizer={}, model={}, post_processor={}, decoder={}, added_tokens_decoder={:?}, truncation={}, padding={})", + "Tokenizer(normalizer={}, pre_tokenizer={}, model={}, post_processor={}, decoder={}, added_tokens_decoder={}, truncation={}, padding={})", normalizer_str, pre_tokenizer_str, self.model.fmt_display(), post_processor_str, decoder_str, - self.added_vocabulary, + self.added_vocabulary.fmt_display(), truncation_str, padding_str ) From e5b059fc11e5a8aa798c48994ec1a0a928458444 Mon Sep 17 00:00:00 2001 From: Arthur Zucker Date: Sat, 27 Jul 2024 19:01:59 +0200 Subject: [PATCH 87/94] final touch? 
--- tokenizers/src/processors/mod.rs | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/tokenizers/src/processors/mod.rs b/tokenizers/src/processors/mod.rs index 6b6ed8acd..c21ef6abd 100644 --- a/tokenizers/src/processors/mod.rs +++ b/tokenizers/src/processors/mod.rs @@ -16,13 +16,15 @@ use pyo3_special_method_derive_0_21::AutoDisplay; use serde::{Deserialize, Serialize}; #[derive(Serialize, Deserialize, PartialEq, Debug, Clone, Eq, AutoDisplay)] -#[format(fmt="")] +#[format(fmt="{}")] #[serde(untagged)] pub enum PostProcessorWrapper { // Roberta must be before Bert for deserialization (serde does not validate tags) + #[format(fmt="{}")] Roberta(RobertaProcessing), + #[format(fmt="{}")] Bert(BertProcessing), - #[format(fmt="{}.{}")] + #[format(fmt="{}")] ByteLevel(ByteLevel), #[format(fmt="{}")] Template(TemplateProcessing), From ff825a7f176dadb42fbd687ca26de54aa44ef019 Mon Sep 17 00:00:00 2001 From: Arthur Zucker Date: Sun, 28 Jul 2024 09:58:23 +0200 Subject: [PATCH 88/94] full autodebug --- bindings/python/src/decoders.rs | 10 ++--- bindings/python/src/models.rs | 7 ++-- bindings/python/src/normalizers.rs | 24 +++++------ bindings/python/src/pre_tokenizers.rs | 12 +++--- bindings/python/src/processors.rs | 16 ++----- bindings/python/src/tokenizer.rs | 2 +- tokenizers/src/decoders/bpe.rs | 4 +- tokenizers/src/decoders/byte_fallback.rs | 4 +- tokenizers/src/decoders/ctc.rs | 4 +- tokenizers/src/decoders/fuse.rs | 4 +- tokenizers/src/decoders/mod.rs | 3 +- tokenizers/src/decoders/sequence.rs | 4 +- tokenizers/src/decoders/strip.rs | 4 +- tokenizers/src/decoders/wordpiece.rs | 4 +- tokenizers/src/models/bpe/model.rs | 9 +++- tokenizers/src/models/mod.rs | 4 +- tokenizers/src/models/unigram/model.rs | 8 +++- tokenizers/src/models/wordlevel/mod.rs | 8 +++- tokenizers/src/models/wordpiece/mod.rs | 8 +++- tokenizers/src/normalizers/bert.rs | 4 +- tokenizers/src/normalizers/byte_level.rs | 4 +- tokenizers/src/normalizers/mod.rs | 4 +- tokenizers/src/normalizers/prepend.rs | 4 +- tokenizers/src/normalizers/replace.rs | 4 +- tokenizers/src/normalizers/strip.rs | 6 +-- tokenizers/src/normalizers/unicode.rs | 12 +++--- tokenizers/src/normalizers/utils.rs | 6 +-- tokenizers/src/pre_tokenizers/bert.rs | 4 +- tokenizers/src/pre_tokenizers/byte_level.rs | 4 +- tokenizers/src/pre_tokenizers/delimiter.rs | 4 +- tokenizers/src/pre_tokenizers/digits.rs | 4 +- tokenizers/src/pre_tokenizers/metaspace.rs | 7 ++-- tokenizers/src/pre_tokenizers/mod.rs | 4 +- tokenizers/src/pre_tokenizers/punctuation.rs | 4 +- tokenizers/src/pre_tokenizers/sequence.rs | 4 +- tokenizers/src/pre_tokenizers/split.rs | 4 +- .../unicode_scripts/pre_tokenizer.rs | 4 +- tokenizers/src/pre_tokenizers/whitespace.rs | 6 +-- tokenizers/src/processors/bert.rs | 4 +- tokenizers/src/processors/mod.rs | 4 +- tokenizers/src/processors/roberta.rs | 4 +- tokenizers/src/processors/sequence.rs | 4 +- tokenizers/src/processors/template.rs | 6 +-- tokenizers/src/tokenizer/added_vocabulary.rs | 8 ++-- tokenizers/src/tokenizer/mod.rs | 42 ++++++++++++++----- tokenizers/src/utils/padding.rs | 8 ++-- tokenizers/src/utils/truncation.rs | 8 ++-- 47 files changed, 174 insertions(+), 146 deletions(-) diff --git a/bindings/python/src/decoders.rs b/bindings/python/src/decoders.rs index b21bdbe43..478e9faf1 100644 --- a/bindings/python/src/decoders.rs +++ b/bindings/python/src/decoders.rs @@ -5,9 +5,7 @@ use crate::utils::PyPattern; use pyo3::exceptions; use pyo3::prelude::*; use pyo3::types::*; -use 
pyo3_special_method_derive_0_21::AutoDisplay; -use pyo3_special_method_derive_0_21::PyDebug; -use pyo3_special_method_derive_0_21::Str; +use pyo3_special_method_derive_0_21::{AutoDisplay, Repr, Str, AutoDebug}; use serde::de::Error; use serde::{Deserialize, Deserializer, Serialize, Serializer}; use tk::decoders::bpe::BPEDecoder; @@ -31,7 +29,7 @@ use super::error::ToPyResult; /// This class is not supposed to be instantiated directly. Instead, any implementation of /// a Decoder will return an instance of this class when instantiated. #[pyclass(dict, module = "tokenizers.decoders", name = "Decoder", subclass)] -#[derive(Clone, Deserialize, Serialize, Str)] +#[derive(Clone, Deserialize, Serialize, Str, Repr)] #[format(fmt="{}")] pub struct PyDecoder { #[serde(flatten)] @@ -483,7 +481,7 @@ impl PySequenceDecoder { } } -#[derive(Clone, AutoDisplay)] +#[derive(Clone, AutoDisplay, AutoDebug)] pub(crate) struct CustomDecoder { #[format(skip)] pub inner: PyObject, @@ -538,7 +536,7 @@ impl<'de> Deserialize<'de> for CustomDecoder { } } -#[derive(Clone, Deserialize, Serialize, AutoDisplay)] +#[derive(Clone, Deserialize, Serialize, AutoDisplay, AutoDebug)] #[serde(untagged)] #[format(fmt = "{}")] pub(crate) enum PyDecoderWrapper { diff --git a/bindings/python/src/models.rs b/bindings/python/src/models.rs index 7a4cdc19d..ba12bb7ab 100644 --- a/bindings/python/src/models.rs +++ b/bindings/python/src/models.rs @@ -8,8 +8,7 @@ use crate::trainers::PyTrainer; use pyo3::exceptions; use pyo3::prelude::*; use pyo3::types::*; -use pyo3_special_method_derive_0_21::AutoDisplay; -use pyo3_special_method_derive_0_21::Str; +use pyo3_special_method_derive_0_21::{Repr, Str}; use serde::{Deserialize, Serialize}; use tk::models::bpe::{BpeBuilder, Merges, Vocab, BPE}; use tk::models::unigram::Unigram; @@ -26,11 +25,11 @@ use tokenizers as tk; /// /// This class cannot be constructed directly. Please use one of the concrete models. #[pyclass(module = "tokenizers.models", name = "Model", subclass)] -#[derive(Clone, Serialize, Deserialize, Str)] +#[derive(Clone, Serialize, Deserialize, Str, Repr)] #[format(fmt = "{}")] pub struct PyModel { #[serde(flatten)] - #[format(fmt="{}")] + #[format(fmt = "{}")] pub model: Arc>, } diff --git a/bindings/python/src/normalizers.rs b/bindings/python/src/normalizers.rs index 41b518381..edc803bec 100644 --- a/bindings/python/src/normalizers.rs +++ b/bindings/python/src/normalizers.rs @@ -5,7 +5,7 @@ use crate::utils::{PyNormalizedString, PyNormalizedStringRefMut, PyPattern}; use pyo3::exceptions; use pyo3::prelude::*; use pyo3::types::*; -use pyo3_special_method_derive_0_21::{AutoDisplay, Str, Dir, Dict}; +use pyo3_special_method_derive_0_21::{AutoDisplay,AutoDebug, Dict, Dir, Repr, Str}; use serde::ser::SerializeStruct; use serde::{Deserialize, Deserializer, Serialize, Serializer}; use tk::normalizers::{ @@ -43,11 +43,11 @@ impl PyNormalizedStringMut<'_> { /// This class is not supposed to be instantiated directly. Instead, any implementation of a /// Normalizer will return an instance of this class when instantiated. 
#[pyclass(dict, module = "tokenizers.normalizers", name = "Normalizer", subclass)] -#[derive(Clone, Serialize, Deserialize, Str, Debug, Dir, Dict)] -#[format(fmt="{}")] +#[derive(Clone, Serialize, Deserialize, Str, Repr, Dir, Dict)] +#[format(fmt = "{}")] pub struct PyNormalizer { #[serde(flatten)] - #[format(fmt="{}")] + #[format(fmt = "{}")] pub(crate) normalizer: PyNormalizerTypeWrapper, } @@ -518,9 +518,9 @@ impl PyReplace { } } -#[derive(Debug, Clone, AutoDisplay)] +#[derive(AutoDebug, Clone, AutoDisplay)] pub(crate) struct CustomNormalizer { - #[format(fmt="Custom Normalizer")] + #[format(fmt = "Custom Normalizer")] inner: PyObject, } impl CustomNormalizer { @@ -562,13 +562,13 @@ impl<'de> Deserialize<'de> for CustomNormalizer { } } -#[derive(Debug, Clone, Deserialize, AutoDisplay)] +#[derive(AutoDebug, Clone, Deserialize, AutoDisplay)] #[serde(untagged)] -#[format(fmt="{}")] +#[format(fmt = "{}")] pub(crate) enum PyNormalizerWrapper { - #[format(fmt="{}")] + #[format(fmt = "{}")] Custom(CustomNormalizer), - #[format(fmt="{}")] + #[format(fmt = "{}")] Wrapped(NormalizerWrapper), } @@ -584,9 +584,9 @@ impl Serialize for PyNormalizerWrapper { } } -#[derive(Debug, Clone, Deserialize, AutoDisplay)] +#[derive(Clone, Deserialize, AutoDisplay, AutoDebug)] #[serde(untagged)] -#[format(fmt="{}")] +#[format(fmt = "{}")] pub(crate) enum PyNormalizerTypeWrapper { Sequence(Vec>>), Single(Arc>), diff --git a/bindings/python/src/pre_tokenizers.rs b/bindings/python/src/pre_tokenizers.rs index c90a6d1b2..e492df094 100644 --- a/bindings/python/src/pre_tokenizers.rs +++ b/bindings/python/src/pre_tokenizers.rs @@ -23,7 +23,7 @@ use tokenizers as tk; use super::error::ToPyResult; use super::utils::*; -use pyo3_special_method_derive_0_21::{AutoDisplay, Dict, Dir, Repr, Str}; +use pyo3_special_method_derive_0_21::{AutoDebug, AutoDisplay, Dict, Dir, Repr, Str}; /// Base class for all pre-tokenizers /// /// This class is not supposed to be instantiated directly. 
Instead, any implementation of a @@ -34,11 +34,11 @@ use pyo3_special_method_derive_0_21::{AutoDisplay, Dict, Dir, Repr, Str}; name = "PreTokenizer", subclass )] -#[derive(Clone, Serialize, Deserialize, Str, Dir)] +#[derive(Clone, Serialize, Deserialize, Str, Repr, Dir, Dict)] #[format(fmt = "{}")] // don't format the Py wrapper pub struct PyPreTokenizer { #[serde(flatten)] - #[format(fmt = "{}")] // format only pretok, not pretok = + #[format(fmt = "{}")] // format only pretok, not pretok = pretok: PyPreTokenizerTypeWrapper, } @@ -591,7 +591,7 @@ impl PyUnicodeScripts { } } -#[derive(Clone, AutoDisplay)] +#[derive(Clone, AutoDisplay, AutoDebug)] pub(crate) struct CustomPreTokenizer { inner: PyObject, } @@ -635,7 +635,7 @@ impl<'de> Deserialize<'de> for CustomPreTokenizer { } } -#[derive(Clone, Deserialize, AutoDisplay)] +#[derive(Clone, Deserialize, AutoDisplay, AutoDebug)] #[serde(untagged)] #[format(fmt = "{}")] pub(crate) enum PyPreTokenizerWrapper { @@ -657,7 +657,7 @@ impl Serialize for PyPreTokenizerWrapper { } } -#[derive(Clone, Deserialize, AutoDisplay)] +#[derive(Clone, Deserialize, AutoDisplay, AutoDebug)] #[serde(untagged)] #[format(fmt = "{}")] pub(crate) enum PyPreTokenizerTypeWrapper { diff --git a/bindings/python/src/processors.rs b/bindings/python/src/processors.rs index 9ff9d0328..bf568d7f4 100644 --- a/bindings/python/src/processors.rs +++ b/bindings/python/src/processors.rs @@ -6,7 +6,7 @@ use crate::error::ToPyResult; use pyo3::exceptions; use pyo3::prelude::*; use pyo3::types::*; -use pyo3_special_method_derive_0_21::AutoDisplay; +use pyo3_special_method_derive_0_21::{AutoDisplay, Str, Repr}; use serde::{Deserialize, Serialize}; use tk::processors::bert::BertProcessing; use tk::processors::byte_level::ByteLevel; @@ -27,11 +27,11 @@ use tokenizers as tk; name = "PostProcessor", subclass )] -#[derive(Clone, Deserialize, Serialize, AutoDisplay)] -#[format(fmt="{}")] +#[derive(Clone, Deserialize, Serialize, Str, Repr)] +#[format(fmt = "{}")] pub struct PyPostProcessor { #[serde(flatten)] - #[format(fmt="{}")] + #[format(fmt = "{}")] pub processor: Arc, } @@ -141,14 +141,6 @@ impl PyPostProcessor { .into_py()?; Ok(final_encoding.into()) } - - fn __str__(&self) -> PyResult { - Ok(format!("{}", &self)) - } - - fn __repr__(&self) -> PyResult { - Ok(format!("{}", &self)) - } } /// This post-processor takes care of adding the special tokens needed by diff --git a/bindings/python/src/tokenizer.rs b/bindings/python/src/tokenizer.rs index b389a22e3..cd3bfa682 100644 --- a/bindings/python/src/tokenizer.rs +++ b/bindings/python/src/tokenizer.rs @@ -462,7 +462,7 @@ type Tokenizer = TokenizerImpl` /// to pure bytes, and attempts to make them into a string. If the tokens /// cannot be decoded you will get � instead for each inconvertable byte token diff --git a/tokenizers/src/decoders/ctc.rs b/tokenizers/src/decoders/ctc.rs index f56966a5d..bfabf223b 100644 --- a/tokenizers/src/decoders/ctc.rs +++ b/tokenizers/src/decoders/ctc.rs @@ -1,10 +1,10 @@ use crate::decoders::wordpiece; use crate::tokenizer::{Decoder, Result}; use itertools::Itertools; -use pyo3_special_method_derive_0_21::AutoDisplay; +use pyo3_special_method_derive_0_21::{AutoDebug, AutoDisplay}; use serde::{Deserialize, Serialize}; -#[derive(Debug, Clone, Serialize, Deserialize, AutoDisplay)] +#[derive(AutoDebug, Clone, Serialize, Deserialize, AutoDisplay)] /// The CTC (Connectionist Temporal Classification) decoder takes care /// of sanitizing a list of inputs token. 
/// Due to some alignement problem the output of some models can come diff --git a/tokenizers/src/decoders/fuse.rs b/tokenizers/src/decoders/fuse.rs index 9a538009d..43636f8c8 100644 --- a/tokenizers/src/decoders/fuse.rs +++ b/tokenizers/src/decoders/fuse.rs @@ -1,8 +1,8 @@ use crate::tokenizer::{Decoder, Result}; use monostate::MustBe; -use pyo3_special_method_derive_0_21::AutoDisplay; +use pyo3_special_method_derive_0_21::{AutoDebug, AutoDisplay}; use serde::{Deserialize, Serialize}; -#[derive(Clone, Debug, Serialize, Deserialize, Default, AutoDisplay)] +#[derive(Clone, AutoDebug, Serialize, Deserialize, Default, AutoDisplay)] /// Fuse simply fuses all tokens into one big string. /// It's usually the last decoding step anyway, but this /// decoder exists incase some decoders need to happen after that diff --git a/tokenizers/src/decoders/mod.rs b/tokenizers/src/decoders/mod.rs index 3f1a34b00..aa369572c 100644 --- a/tokenizers/src/decoders/mod.rs +++ b/tokenizers/src/decoders/mod.rs @@ -21,10 +21,11 @@ use crate::normalizers::replace::Replace; use crate::pre_tokenizers::byte_level::ByteLevel; use crate::pre_tokenizers::metaspace::Metaspace; use crate::{Decoder, Result}; +use pyo3_special_method_derive_0_21::AutoDebug; use pyo3_special_method_derive_0_21::AutoDisplay; use serde::{Deserialize, Serialize}; -#[derive(Serialize, Deserialize, Clone, Debug, AutoDisplay)] +#[derive(Serialize, Deserialize, Clone, AutoDebug, AutoDisplay)] #[format(fmt="decoders.{}")] #[serde(untagged)] pub enum DecoderWrapper { diff --git a/tokenizers/src/decoders/sequence.rs b/tokenizers/src/decoders/sequence.rs index aef8de5ef..4fd57a97e 100644 --- a/tokenizers/src/decoders/sequence.rs +++ b/tokenizers/src/decoders/sequence.rs @@ -1,11 +1,11 @@ use crate::decoders::DecoderWrapper; use crate::tokenizer::{Decoder, Result}; use crate::utils::macro_rules_attribute; -use pyo3_special_method_derive_0_21::AutoDisplay; +use pyo3_special_method_derive_0_21::{AutoDebug, AutoDisplay}; use serde::{Deserialize, Serialize}; #[macro_rules_attribute(impl_serde_type!)] -#[derive(Clone, Debug, AutoDisplay)] +#[derive(Clone, AutoDebug, AutoDisplay)] pub struct Sequence { decoders: Vec, } diff --git a/tokenizers/src/decoders/strip.rs b/tokenizers/src/decoders/strip.rs index f47fa8ed3..fee40b4e1 100644 --- a/tokenizers/src/decoders/strip.rs +++ b/tokenizers/src/decoders/strip.rs @@ -1,8 +1,8 @@ use crate::tokenizer::{Decoder, Result}; -use pyo3_special_method_derive_0_21::AutoDisplay; +use pyo3_special_method_derive_0_21::{AutoDebug, AutoDisplay}; use serde::{Deserialize, Serialize}; -#[derive(Deserialize, Clone, Debug, Serialize, Default, AutoDisplay)] +#[derive(Deserialize, Clone, AutoDebug, Serialize, Default, AutoDisplay)] /// Strip is a simple trick which converts tokens looking like `<0x61>` /// to pure bytes, and attempts to make them into a string. 
If the tokens /// cannot be decoded you will get � instead for each inconvertable byte token diff --git a/tokenizers/src/decoders/wordpiece.rs b/tokenizers/src/decoders/wordpiece.rs index 5a5ba86c3..f7b3aacde 100644 --- a/tokenizers/src/decoders/wordpiece.rs +++ b/tokenizers/src/decoders/wordpiece.rs @@ -1,8 +1,8 @@ use crate::tokenizer::{Decoder, Result}; -use pyo3_special_method_derive_0_21::AutoDisplay; +use pyo3_special_method_derive_0_21::{AutoDebug, AutoDisplay}; use serde::{Deserialize, Serialize}; -#[derive(Deserialize, Clone, Debug, Serialize, AutoDisplay)] +#[derive(Deserialize, Clone, AutoDebug, Serialize, AutoDisplay)] /// The WordPiece decoder takes care of decoding a list of wordpiece tokens /// back into a readable string. #[serde(tag = "type")] diff --git a/tokenizers/src/models/bpe/model.rs b/tokenizers/src/models/bpe/model.rs index 24233281f..4c13798b4 100644 --- a/tokenizers/src/models/bpe/model.rs +++ b/tokenizers/src/models/bpe/model.rs @@ -2,7 +2,7 @@ use super::{super::OrderedVocabIter, trainer::BpeTrainer, Error, Pair, Word}; use crate::tokenizer::{Model, Result, Token}; use crate::utils::cache::{Cache, DEFAULT_CACHE_CAPACITY}; use crate::utils::iter::ResultShunt; -use pyo3_special_method_derive_0_21::AutoDisplay; +use pyo3_special_method_derive_0_21::{AutoDisplay, PyDebug}; use serde_json::Value; use std::borrow::Cow; use std::{ @@ -248,7 +248,6 @@ impl std::fmt::Debug for BPE { .finish() } } - // impl std::fmt::Display for BPE { // fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { // let mut vocab_vec: Vec<_> = self.vocab.iter().collect(); @@ -299,6 +298,12 @@ impl std::fmt::Debug for BPE { // } // } +// That is the only annouying part, explicit implementation. We can have PyDebugOnly. +impl PyDebug for BPE { + fn fmt_debug(&self) -> std::string::String { + format!("{:?}", self) + } +} impl Default for BPE { fn default() -> Self { Self::builder().build().unwrap() diff --git a/tokenizers/src/models/mod.rs b/tokenizers/src/models/mod.rs index bfbb69eab..c1c722568 100644 --- a/tokenizers/src/models/mod.rs +++ b/tokenizers/src/models/mod.rs @@ -8,7 +8,7 @@ pub mod wordpiece; use std::collections::HashMap; use std::path::{Path, PathBuf}; -use pyo3_special_method_derive_0_21::{AutoDisplay, PyDisplay}; +use pyo3_special_method_derive_0_21::{AutoDebug, AutoDisplay, PyDisplay}; use serde::{Deserialize, Serialize, Serializer}; use crate::models::bpe::{BpeTrainer, BPE}; @@ -58,7 +58,7 @@ impl<'a> Serialize for OrderedVocabIter<'a> { } } -#[derive(Deserialize, Serialize, Debug, PartialEq, Clone, AutoDisplay)] +#[derive(Deserialize, Serialize, AutoDebug, PartialEq, Clone, AutoDisplay)] #[serde(untagged)] #[format(fmt = "{}")] // TODO by default this should define the finale render {}, {} {} . 
or {}{}{} pub enum ModelWrapper { diff --git a/tokenizers/src/models/unigram/model.rs b/tokenizers/src/models/unigram/model.rs index 705d9983e..d810a4cfd 100644 --- a/tokenizers/src/models/unigram/model.rs +++ b/tokenizers/src/models/unigram/model.rs @@ -6,7 +6,7 @@ use super::{ use crate::tokenizer::{Model, Result, Token}; use crate::utils::cache::Cache; -use pyo3_special_method_derive_0_21::AutoDisplay; +use pyo3_special_method_derive_0_21::{AutoDisplay, PyDebug}; use std::collections::HashMap; use std::convert::TryInto; use std::fs::read_to_string; @@ -66,7 +66,11 @@ impl std::fmt::Debug for Unigram { .finish() } } - +impl PyDebug for Unigram { + fn fmt_debug(&self) -> std::string::String { + format!("{:?}", self) + } +} static K_UNK_PENALTY: f64 = 10.0; #[derive(thiserror::Error, Debug)] diff --git a/tokenizers/src/models/wordlevel/mod.rs b/tokenizers/src/models/wordlevel/mod.rs index 18a50a761..8dd2335d5 100644 --- a/tokenizers/src/models/wordlevel/mod.rs +++ b/tokenizers/src/models/wordlevel/mod.rs @@ -1,6 +1,6 @@ use super::OrderedVocabIter; use crate::tokenizer::{Model, Result, Token}; -use pyo3_special_method_derive_0_21::AutoDisplay; +use pyo3_special_method_derive_0_21::{AutoDisplay,PyDebug}; use serde_json::Value; use std::collections::HashMap; use std::fs::File; @@ -110,7 +110,11 @@ impl std::fmt::Debug for WordLevel { .finish() } } - +impl PyDebug for WordLevel { + fn fmt_debug(&self) -> std::string::String { + format!("{:?}", self) + } +} impl WordLevel { pub fn builder() -> WordLevelBuilder { WordLevelBuilder::new() diff --git a/tokenizers/src/models/wordpiece/mod.rs b/tokenizers/src/models/wordpiece/mod.rs index 079533ba0..732aafb5e 100644 --- a/tokenizers/src/models/wordpiece/mod.rs +++ b/tokenizers/src/models/wordpiece/mod.rs @@ -3,7 +3,7 @@ use crate::models::bpe::BPE; use crate::tokenizer::{Model, Result, Token}; -use pyo3_special_method_derive_0_21::AutoDisplay; +use pyo3_special_method_derive_0_21::{AutoDisplay,PyDebug}; use std::{ borrow::Cow, collections::HashMap, @@ -140,7 +140,11 @@ impl std::fmt::Debug for WordPiece { .finish() } } - +impl PyDebug for WordPiece { + fn fmt_debug(&self) -> std::string::String { + format!("{:?}", self) + } +} impl Default for WordPiece { fn default() -> Self { Self { diff --git a/tokenizers/src/normalizers/bert.rs b/tokenizers/src/normalizers/bert.rs index 63805f676..89399b022 100644 --- a/tokenizers/src/normalizers/bert.rs +++ b/tokenizers/src/normalizers/bert.rs @@ -1,6 +1,6 @@ use crate::tokenizer::{NormalizedString, Normalizer, Result}; -use pyo3_special_method_derive_0_21::AutoDisplay; +use pyo3_special_method_derive_0_21::{AutoDisplay,AutoDebug}; use serde::{Deserialize, Serialize}; use unicode_categories::UnicodeCategories; /// Checks whether a character is whitespace @@ -47,7 +47,7 @@ fn is_chinese_char(c: char) -> bool { ) } -#[derive(Copy, Clone, Debug, Deserialize, Serialize, AutoDisplay)] +#[derive(Copy, Clone, AutoDebug, Deserialize, Serialize, AutoDisplay)] #[serde(tag = "type")] #[non_exhaustive] pub struct BertNormalizer { diff --git a/tokenizers/src/normalizers/byte_level.rs b/tokenizers/src/normalizers/byte_level.rs index 44d5b83b1..d01f713cc 100644 --- a/tokenizers/src/normalizers/byte_level.rs +++ b/tokenizers/src/normalizers/byte_level.rs @@ -2,8 +2,8 @@ use crate::processors::byte_level::bytes_char; use crate::tokenizer::{NormalizedString, Normalizer, Result}; use serde::{Deserialize, Serialize}; use std::collections::{HashMap, HashSet}; -use pyo3_special_method_derive_0_21::AutoDisplay; -#[derive(Clone, 
Debug, Deserialize, Serialize, AutoDisplay)] +use pyo3_special_method_derive_0_21::{AutoDebug, AutoDisplay}; +#[derive(Clone, AutoDebug, Deserialize, Serialize, AutoDisplay)] #[serde(tag = "type")] pub struct ByteLevel {} diff --git a/tokenizers/src/normalizers/mod.rs b/tokenizers/src/normalizers/mod.rs index b813a3ce3..639dc8ffa 100644 --- a/tokenizers/src/normalizers/mod.rs +++ b/tokenizers/src/normalizers/mod.rs @@ -17,10 +17,10 @@ pub use crate::normalizers::utils::{Lowercase, Sequence}; use serde::{Deserialize, Serialize}; use crate::{NormalizedString, Normalizer}; -use pyo3_special_method_derive_0_21::AutoDisplay; +use pyo3_special_method_derive_0_21::{AutoDebug,AutoDisplay}; /// Wrapper for known Normalizers. -#[derive(Clone, Debug, Deserialize, Serialize, AutoDisplay)] +#[derive(Clone, Deserialize, Serialize, AutoDisplay, AutoDebug)] #[serde(untagged)] #[format(fmt = "normalizers.{}")] pub enum NormalizerWrapper { diff --git a/tokenizers/src/normalizers/prepend.rs b/tokenizers/src/normalizers/prepend.rs index cd7f047af..936f9006a 100644 --- a/tokenizers/src/normalizers/prepend.rs +++ b/tokenizers/src/normalizers/prepend.rs @@ -1,8 +1,8 @@ use crate::tokenizer::{NormalizedString, Normalizer, Result}; -use pyo3_special_method_derive_0_21::AutoDisplay; +use pyo3_special_method_derive_0_21::{AutoDebug, AutoDisplay}; use serde::{Deserialize, Serialize}; -#[derive(Clone, Debug, Deserialize, Serialize, AutoDisplay)] +#[derive(Clone, AutoDebug, Deserialize, Serialize, AutoDisplay)] #[serde(tag = "type")] pub struct Prepend { pub prepend: String, diff --git a/tokenizers/src/normalizers/replace.rs b/tokenizers/src/normalizers/replace.rs index dda4a331a..f371f743c 100644 --- a/tokenizers/src/normalizers/replace.rs +++ b/tokenizers/src/normalizers/replace.rs @@ -2,7 +2,7 @@ use crate::tokenizer::pattern::Pattern; use crate::tokenizer::Decoder; use crate::tokenizer::{NormalizedString, Normalizer, Result}; use crate::utils::SysRegex; -use pyo3_special_method_derive_0_21::AutoDisplay; +use pyo3_special_method_derive_0_21::{AutoDebug, AutoDisplay}; use serde::{Deserialize, Serialize}; /// Represents the different patterns that `Replace` can use #[derive(Debug, Clone, PartialEq, Serialize, Deserialize, Eq)] @@ -42,7 +42,7 @@ impl std::convert::TryFrom for Replace { /// This normalizer will take a `pattern` (for now only a String) /// and replace every occurrence with `content`. 
-#[derive(Debug, Serialize, Deserialize, AutoDisplay)] +#[derive(AutoDebug, Serialize, Deserialize, AutoDisplay)] #[serde(tag = "type", try_from = "ReplaceDeserializer")] pub struct Replace { pattern: ReplacePattern, diff --git a/tokenizers/src/normalizers/strip.rs b/tokenizers/src/normalizers/strip.rs index 453a8c18a..e597975b9 100644 --- a/tokenizers/src/normalizers/strip.rs +++ b/tokenizers/src/normalizers/strip.rs @@ -1,9 +1,9 @@ use crate::tokenizer::{NormalizedString, Normalizer, Result}; use crate::utils::macro_rules_attribute; -use pyo3_special_method_derive_0_21::AutoDisplay; +use pyo3_special_method_derive_0_21::{AutoDebug,AutoDisplay}; use serde::{Deserialize, Serialize}; use unicode_normalization_alignments::char::is_combining_mark; -#[derive(Copy, Clone, Debug, Deserialize, Serialize, AutoDisplay)] +#[derive(Copy, Clone, AutoDebug, Deserialize, Serialize, AutoDisplay)] #[serde(tag = "type")] #[non_exhaustive] pub struct Strip { @@ -43,7 +43,7 @@ impl Normalizer for Strip { // This normalizer removes combining marks from a normalized string // It's different from unidecode as it does not attempt to modify // non ascii languages. -#[derive(Copy, Clone, Debug, AutoDisplay)] +#[derive(Copy, Clone, AutoDebug, AutoDisplay)] #[macro_rules_attribute(impl_serde_type!)] pub struct StripAccents; diff --git a/tokenizers/src/normalizers/unicode.rs b/tokenizers/src/normalizers/unicode.rs index 80a5dead9..47498498e 100644 --- a/tokenizers/src/normalizers/unicode.rs +++ b/tokenizers/src/normalizers/unicode.rs @@ -1,8 +1,8 @@ use crate::tokenizer::{NormalizedString, Normalizer, Result}; use crate::utils::macro_rules_attribute; -use pyo3_special_method_derive_0_21::AutoDisplay; +use pyo3_special_method_derive_0_21::{AutoDebug,AutoDisplay}; -#[derive(Default, Copy, Clone, Debug, AutoDisplay)] +#[derive(Default, Copy, Clone, AutoDebug, AutoDisplay)] #[macro_rules_attribute(impl_serde_type!)] pub struct NFD; impl Normalizer for NFD { @@ -12,7 +12,7 @@ impl Normalizer for NFD { } } -#[derive(Default, Copy, Clone, Debug, AutoDisplay)] +#[derive(Default, Copy, Clone, AutoDebug, AutoDisplay)] #[macro_rules_attribute(impl_serde_type!)] pub struct NFKD; impl Normalizer for NFKD { @@ -22,7 +22,7 @@ impl Normalizer for NFKD { } } -#[derive(Default, Copy, Clone, Debug, AutoDisplay)] +#[derive(Default, Copy, Clone, AutoDebug, AutoDisplay)] #[macro_rules_attribute(impl_serde_type!)] pub struct NFC; impl Normalizer for NFC { @@ -32,7 +32,7 @@ impl Normalizer for NFC { } } -#[derive(Default, Copy, Clone, Debug, AutoDisplay)] +#[derive(Default, Copy, Clone, AutoDebug, AutoDisplay)] #[macro_rules_attribute(impl_serde_type!)] pub struct NFKC; impl Normalizer for NFKC { @@ -73,7 +73,7 @@ fn do_nmt(normalized: &mut NormalizedString) { }); } -#[derive(Default, Copy, Clone, Debug, AutoDisplay)] +#[derive(Default, Copy, Clone, AutoDebug, AutoDisplay)] #[macro_rules_attribute(impl_serde_type!)] pub struct Nmt; impl Normalizer for Nmt { diff --git a/tokenizers/src/normalizers/utils.rs b/tokenizers/src/normalizers/utils.rs index b2ba9d069..3b4c01933 100644 --- a/tokenizers/src/normalizers/utils.rs +++ b/tokenizers/src/normalizers/utils.rs @@ -3,8 +3,8 @@ use serde::{Deserialize, Serialize}; use crate::normalizers::NormalizerWrapper; use crate::tokenizer::{NormalizedString, Normalizer, Result}; use crate::utils::macro_rules_attribute; -use pyo3_special_method_derive_0_21::AutoDisplay; -#[derive(Clone, Deserialize, Debug, Serialize, AutoDisplay)] +use pyo3_special_method_derive_0_21::{AutoDebug,AutoDisplay}; 
+#[derive(Clone, Deserialize, AutoDebug, Serialize, AutoDisplay)] #[serde(tag = "type")] /// Allows concatenating multiple other Normalizer as a Sequence. /// All the normalizers run in sequence in the given order against the same NormalizedString. @@ -36,7 +36,7 @@ impl Normalizer for Sequence { } /// Lowercases the input -#[derive(Copy, Clone, Debug, AutoDisplay)] +#[derive(Copy, Clone, AutoDebug, AutoDisplay)] #[macro_rules_attribute(impl_serde_type!)] pub struct Lowercase; diff --git a/tokenizers/src/pre_tokenizers/bert.rs b/tokenizers/src/pre_tokenizers/bert.rs index b50f4b118..0030f785d 100644 --- a/tokenizers/src/pre_tokenizers/bert.rs +++ b/tokenizers/src/pre_tokenizers/bert.rs @@ -1,13 +1,13 @@ use crate::tokenizer::{PreTokenizedString, PreTokenizer, Result, SplitDelimiterBehavior}; use crate::utils::macro_rules_attribute; -use pyo3_special_method_derive_0_21::AutoDisplay; +use pyo3_special_method_derive_0_21::{AutoDebug, AutoDisplay}; use unicode_categories::UnicodeCategories; fn is_bert_punc(x: char) -> bool { char::is_ascii_punctuation(&x) || x.is_punctuation() } -#[derive(Copy, Clone, Debug, PartialEq, Eq, AutoDisplay)] +#[derive(Copy, Clone, AutoDebug, PartialEq, Eq, AutoDisplay)] #[macro_rules_attribute(impl_serde_type!)] pub struct BertPreTokenizer; diff --git a/tokenizers/src/pre_tokenizers/byte_level.rs b/tokenizers/src/pre_tokenizers/byte_level.rs index 7749afc18..577c2ad58 100644 --- a/tokenizers/src/pre_tokenizers/byte_level.rs +++ b/tokenizers/src/pre_tokenizers/byte_level.rs @@ -6,7 +6,7 @@ use crate::tokenizer::{ }; use crate::utils::macro_rules_attribute; use crate::utils::SysRegex; -use pyo3_special_method_derive_0_21::AutoDisplay; +use pyo3_special_method_derive_0_21::{AutoDebug, AutoDisplay}; use serde::{Deserialize, Serialize}; /// Converts bytes to unicode characters. @@ -50,7 +50,7 @@ lazy_static! { /// of all the required processing steps to transform a UTF-8 string as needed before and after the /// BPE model does its job. #[macro_rules_attribute(impl_serde_type!)] -#[derive(Copy, Clone, Debug, PartialEq, Eq, AutoDisplay)] +#[derive(Copy, Clone, AutoDebug, PartialEq, Eq, AutoDisplay)] #[non_exhaustive] pub struct ByteLevel { /// Whether to add a leading space to the first word. 
This allows to treat the leading word diff --git a/tokenizers/src/pre_tokenizers/delimiter.rs b/tokenizers/src/pre_tokenizers/delimiter.rs index 25e5fefd9..2f81d4eeb 100644 --- a/tokenizers/src/pre_tokenizers/delimiter.rs +++ b/tokenizers/src/pre_tokenizers/delimiter.rs @@ -1,10 +1,10 @@ -use pyo3_special_method_derive_0_21::AutoDisplay; +use pyo3_special_method_derive_0_21::{AutoDebug, AutoDisplay}; use serde::{Deserialize, Serialize}; use crate::tokenizer::{PreTokenizedString, PreTokenizer, Result, SplitDelimiterBehavior}; use crate::utils::macro_rules_attribute; -#[derive(Copy, Clone, Debug, PartialEq, Eq, AutoDisplay)] +#[derive(Copy, Clone, AutoDebug, PartialEq, Eq, AutoDisplay)] #[non_exhaustive] #[macro_rules_attribute(impl_serde_type!)] pub struct CharDelimiterSplit { diff --git a/tokenizers/src/pre_tokenizers/digits.rs b/tokenizers/src/pre_tokenizers/digits.rs index 65c7c30bb..5fb76a6e4 100644 --- a/tokenizers/src/pre_tokenizers/digits.rs +++ b/tokenizers/src/pre_tokenizers/digits.rs @@ -1,10 +1,10 @@ -use pyo3_special_method_derive_0_21::AutoDisplay; +use pyo3_special_method_derive_0_21::{AutoDebug, AutoDisplay}; use serde::{Deserialize, Serialize}; use crate::tokenizer::{PreTokenizedString, PreTokenizer, Result, SplitDelimiterBehavior}; use crate::utils::macro_rules_attribute; -#[derive(Clone, Debug, PartialEq, Eq, AutoDisplay)] +#[derive(Clone, AutoDebug, PartialEq, Eq, AutoDisplay)] /// Pre tokenizes the numbers into single tokens. If individual_digits is set /// to true, then all digits are splitted into individual tokens. #[non_exhaustive] diff --git a/tokenizers/src/pre_tokenizers/metaspace.rs b/tokenizers/src/pre_tokenizers/metaspace.rs index a0e9ddc80..f0b1ba28c 100644 --- a/tokenizers/src/pre_tokenizers/metaspace.rs +++ b/tokenizers/src/pre_tokenizers/metaspace.rs @@ -1,8 +1,8 @@ use crate::tokenizer::{Decoder, PreTokenizedString, PreTokenizer, Result, SplitDelimiterBehavior}; -use pyo3_special_method_derive_0_21::AutoDisplay; +use pyo3_special_method_derive_0_21::{AutoDebug, AutoDisplay}; use serde::{de, Deserialize, Deserializer, Serialize}; /// Enum representing options for the metaspace prepending scheme. -#[derive(Debug, Clone, PartialEq, Serialize, Eq, Deserialize, Copy, AutoDisplay)] +#[derive(AutoDebug, Clone, PartialEq, Serialize, Eq, Deserialize, Copy, AutoDisplay)] #[serde(rename_all = "snake_case")] pub enum PrependScheme { /// Specifies that the scheme should be prepended only once, on the first split. 
@@ -13,11 +13,10 @@ pub enum PrependScheme { Always, } -#[derive(Debug, Clone, PartialEq, Serialize, Eq, AutoDisplay)] +#[derive(AutoDebug, Clone, PartialEq, Serialize, Eq, AutoDisplay)] /// Replaces all the whitespaces by the provided meta character and then /// splits on this character #[serde(tag = "type")] - pub struct Metaspace { replacement: char, pub prepend_scheme: PrependScheme, diff --git a/tokenizers/src/pre_tokenizers/mod.rs b/tokenizers/src/pre_tokenizers/mod.rs index b7388e7aa..788a1c92e 100644 --- a/tokenizers/src/pre_tokenizers/mod.rs +++ b/tokenizers/src/pre_tokenizers/mod.rs @@ -24,7 +24,7 @@ use crate::pre_tokenizers::whitespace::{Whitespace, WhitespaceSplit}; use crate::{PreTokenizedString, PreTokenizer}; use pyo3_special_method_derive_0_21::{AutoDebug, AutoDisplay}; -#[derive(Deserialize, Serialize, Clone, PartialEq, AutoDisplay, Debug)] +#[derive(Deserialize, Serialize, Clone, PartialEq, AutoDebug, AutoDisplay)] #[serde(untagged)] #[format(fmt = "pre_tokenizers.{}")] pub enum PreTokenizerWrapper { @@ -33,7 +33,7 @@ pub enum PreTokenizerWrapper { Delimiter(CharDelimiterSplit), Metaspace(Metaspace), Whitespace(Whitespace), - #[format(fmt="{}")] + #[format(fmt = "{}")] Sequence(Sequence), Split(Split), Punctuation(Punctuation), diff --git a/tokenizers/src/pre_tokenizers/punctuation.rs b/tokenizers/src/pre_tokenizers/punctuation.rs index fab237586..dbfd1b29a 100644 --- a/tokenizers/src/pre_tokenizers/punctuation.rs +++ b/tokenizers/src/pre_tokenizers/punctuation.rs @@ -1,4 +1,4 @@ -use pyo3_special_method_derive_0_21::AutoDisplay; +use pyo3_special_method_derive_0_21::{AutoDebug, AutoDisplay}; use serde::{Deserialize, Serialize}; use crate::tokenizer::{PreTokenizedString, PreTokenizer, Result, SplitDelimiterBehavior}; @@ -9,7 +9,7 @@ fn is_punc(x: char) -> bool { char::is_ascii_punctuation(&x) || x.is_punctuation() } -#[derive(Copy, Clone, Debug, PartialEq, Eq, AutoDisplay)] +#[derive(Copy, Clone, AutoDebug, PartialEq, Eq, AutoDisplay)] #[macro_rules_attribute(impl_serde_type!)] pub struct Punctuation { #[serde(default = "default_split")] diff --git a/tokenizers/src/pre_tokenizers/sequence.rs b/tokenizers/src/pre_tokenizers/sequence.rs index c57076d9f..98a6a06c1 100644 --- a/tokenizers/src/pre_tokenizers/sequence.rs +++ b/tokenizers/src/pre_tokenizers/sequence.rs @@ -1,11 +1,11 @@ use crate::pre_tokenizers::PreTokenizerWrapper; use crate::tokenizer::{PreTokenizedString, PreTokenizer, Result}; use crate::utils::macro_rules_attribute; -use pyo3_special_method_derive_0_21::AutoDisplay; +use pyo3_special_method_derive_0_21::{AutoDebug, AutoDisplay}; use serde::{Deserialize, Serialize}; #[macro_rules_attribute(impl_serde_type!)] -#[derive(Clone, PartialEq, AutoDisplay, Debug)] +#[derive(Clone, PartialEq, AutoDisplay, AutoDebug)] pub struct Sequence { #[format] pub pretokenizers: Vec, diff --git a/tokenizers/src/pre_tokenizers/split.rs b/tokenizers/src/pre_tokenizers/split.rs index eaa3a16e7..55b9b3d2a 100644 --- a/tokenizers/src/pre_tokenizers/split.rs +++ b/tokenizers/src/pre_tokenizers/split.rs @@ -2,7 +2,7 @@ use crate::tokenizer::{ pattern::Invert, PreTokenizedString, PreTokenizer, Result, SplitDelimiterBehavior, }; use crate::utils::SysRegex; -use pyo3_special_method_derive_0_21::AutoDisplay; +use pyo3_special_method_derive_0_21::{AutoDebug, AutoDisplay}; use serde::{Deserialize, Deserializer, Serialize}; /// Represents the different patterns that `Split` can use @@ -24,7 +24,7 @@ impl From<&str> for SplitPattern { } } -#[derive(Debug, Serialize, AutoDisplay)] 
+#[derive(AutoDebug, Serialize, AutoDisplay)] #[serde(tag = "type")] pub struct Split { pattern: SplitPattern, diff --git a/tokenizers/src/pre_tokenizers/unicode_scripts/pre_tokenizer.rs b/tokenizers/src/pre_tokenizers/unicode_scripts/pre_tokenizer.rs index 7fa905a39..df4c2e794 100644 --- a/tokenizers/src/pre_tokenizers/unicode_scripts/pre_tokenizer.rs +++ b/tokenizers/src/pre_tokenizers/unicode_scripts/pre_tokenizer.rs @@ -1,10 +1,10 @@ -use pyo3_special_method_derive_0_21::AutoDisplay; +use pyo3_special_method_derive_0_21::{AutoDebug, AutoDisplay}; use crate::pre_tokenizers::unicode_scripts::scripts::{get_script, Script}; use crate::tokenizer::{normalizer::Range, PreTokenizedString, PreTokenizer, Result}; use crate::utils::macro_rules_attribute; -#[derive(Clone, Debug, PartialEq, Eq, AutoDisplay)] +#[derive(Clone, AutoDebug, PartialEq, Eq, AutoDisplay)] #[macro_rules_attribute(impl_serde_type!)] pub struct UnicodeScripts; diff --git a/tokenizers/src/pre_tokenizers/whitespace.rs b/tokenizers/src/pre_tokenizers/whitespace.rs index d7f044a0e..cd38ee445 100644 --- a/tokenizers/src/pre_tokenizers/whitespace.rs +++ b/tokenizers/src/pre_tokenizers/whitespace.rs @@ -1,4 +1,4 @@ -use pyo3_special_method_derive_0_21::AutoDisplay; +use pyo3_special_method_derive_0_21::{AutoDebug, AutoDisplay}; use regex::Regex; use crate::tokenizer::{ @@ -6,7 +6,7 @@ use crate::tokenizer::{ }; use crate::utils::macro_rules_attribute; -#[derive(Clone, Debug, PartialEq, Eq, AutoDisplay)] +#[derive(Clone, AutoDebug, PartialEq, Eq, AutoDisplay)] #[macro_rules_attribute(impl_serde_type!)] pub struct Whitespace; @@ -29,7 +29,7 @@ impl PreTokenizer for Whitespace { } } -#[derive(Copy, Clone, Debug, PartialEq, Eq, AutoDisplay)] +#[derive(Copy, Clone, AutoDebug, PartialEq, Eq, AutoDisplay)] #[macro_rules_attribute(impl_serde_type!)] pub struct WhitespaceSplit; diff --git a/tokenizers/src/processors/bert.rs b/tokenizers/src/processors/bert.rs index eb7bffd00..9fd1c91e6 100644 --- a/tokenizers/src/processors/bert.rs +++ b/tokenizers/src/processors/bert.rs @@ -1,10 +1,10 @@ use crate::tokenizer::{Encoding, PostProcessor, Result}; -use pyo3_special_method_derive_0_21::AutoDisplay; +use pyo3_special_method_derive_0_21::{AutoDebug, AutoDisplay}; use serde::{Deserialize, Serialize}; use std::collections::HashMap; use std::iter::FromIterator; -#[derive(Serialize, Deserialize, Clone, Debug, PartialEq, Eq, AutoDisplay)] +#[derive(Serialize, Deserialize, Clone, AutoDebug, PartialEq, Eq, AutoDisplay)] #[serde(tag = "type")] pub struct BertProcessing { sep: (String, u32), diff --git a/tokenizers/src/processors/mod.rs b/tokenizers/src/processors/mod.rs index c21ef6abd..c68651fa8 100644 --- a/tokenizers/src/processors/mod.rs +++ b/tokenizers/src/processors/mod.rs @@ -12,10 +12,10 @@ use crate::processors::roberta::RobertaProcessing; use crate::processors::sequence::Sequence; use crate::processors::template::TemplateProcessing; use crate::{Encoding, PostProcessor, Result}; -use pyo3_special_method_derive_0_21::AutoDisplay; +use pyo3_special_method_derive_0_21::{AutoDebug, AutoDisplay}; use serde::{Deserialize, Serialize}; -#[derive(Serialize, Deserialize, PartialEq, Debug, Clone, Eq, AutoDisplay)] +#[derive(Serialize, Deserialize, PartialEq, AutoDebug, Clone, Eq, AutoDisplay)] #[format(fmt="{}")] #[serde(untagged)] pub enum PostProcessorWrapper { diff --git a/tokenizers/src/processors/roberta.rs b/tokenizers/src/processors/roberta.rs index f74fb5009..1dc52d1e2 100644 --- a/tokenizers/src/processors/roberta.rs +++ 
b/tokenizers/src/processors/roberta.rs @@ -1,11 +1,11 @@ use crate::processors::byte_level::process_offsets; use crate::tokenizer::{Encoding, PostProcessor, Result}; -use pyo3_special_method_derive_0_21::AutoDisplay; +use pyo3_special_method_derive_0_21::{AutoDebug, AutoDisplay}; use serde::{Deserialize, Serialize}; use std::collections::HashMap; use std::iter::FromIterator; -#[derive(Serialize, Deserialize, Debug, Clone, PartialEq, Eq, AutoDisplay)] +#[derive(Serialize, Deserialize, AutoDebug, Clone, PartialEq, Eq, AutoDisplay)] #[serde(tag = "type")] pub struct RobertaProcessing { sep: (String, u32), diff --git a/tokenizers/src/processors/sequence.rs b/tokenizers/src/processors/sequence.rs index d0f5db6a7..6be62720d 100644 --- a/tokenizers/src/processors/sequence.rs +++ b/tokenizers/src/processors/sequence.rs @@ -1,10 +1,10 @@ use crate::processors::PostProcessorWrapper; use crate::tokenizer::{Encoding, PostProcessor, Result}; use crate::utils::macro_rules_attribute; -use pyo3_special_method_derive_0_21::AutoDisplay; +use pyo3_special_method_derive_0_21::{AutoDebug, AutoDisplay}; use serde::{Deserialize, Serialize}; #[macro_rules_attribute(impl_serde_type!)] -#[derive(Clone, Debug, PartialEq, Eq, AutoDisplay)] +#[derive(Clone, AutoDebug, PartialEq, Eq, AutoDisplay)] pub struct Sequence { processors: Vec, } diff --git a/tokenizers/src/processors/template.rs b/tokenizers/src/processors/template.rs index f30da61ae..2ef977041 100644 --- a/tokenizers/src/processors/template.rs +++ b/tokenizers/src/processors/template.rs @@ -57,7 +57,7 @@ //! use crate::{Encoding, PostProcessor, Result}; use itertools::Itertools; -use pyo3_special_method_derive_0_21::AutoDisplay; +use pyo3_special_method_derive_0_21::{AutoDebug, AutoDisplay}; use serde::{Deserialize, Serialize}; use std::collections::{HashMap, HashSet}; use std::convert::{TryFrom, TryInto}; @@ -249,7 +249,7 @@ impl SpecialToken { /// /// [`Piece`]: enum.Piece.html /// -#[derive(Debug, Clone, PartialEq, Serialize, Deserialize, Eq, AutoDisplay)] +#[derive(AutoDebug, Clone, PartialEq, Serialize, Deserialize, Eq, AutoDisplay)] #[serde(transparent)] pub struct Template(Vec); @@ -332,7 +332,7 @@ impl From> for Tokens { /// .unwrap(); /// ``` /// -#[derive(Debug, Clone, PartialEq, Builder, Serialize, Deserialize, Eq, AutoDisplay)] +#[derive(AutoDebug, Clone, PartialEq, Builder, Serialize, Deserialize, Eq, AutoDisplay)] #[serde(tag = "type", from = "TemplateProcessingDeserializer")] #[builder(build_fn(validate = "Self::validate"))] #[format(fmt = "TemplateProcessing: {}")] diff --git a/tokenizers/src/tokenizer/added_vocabulary.rs b/tokenizers/src/tokenizer/added_vocabulary.rs index fb6e288d2..03c569aff 100644 --- a/tokenizers/src/tokenizer/added_vocabulary.rs +++ b/tokenizers/src/tokenizer/added_vocabulary.rs @@ -2,7 +2,7 @@ use super::{ normalizer::Range, Model, NormalizedString, Normalizer, Offsets, PreTokenizedString, Token, }; use aho_corasick::{AhoCorasick, AhoCorasickBuilder, MatchKind}; -use pyo3_special_method_derive_0_21::AutoDisplay; +use pyo3_special_method_derive_0_21::{AutoDebug, AutoDisplay}; use regex::Regex; use serde::{ser::SerializeSeq, Deserialize, Serialize, Serializer}; use std::collections::{HashMap, HashSet}; @@ -12,7 +12,7 @@ use std::collections::{HashMap, HashSet}; /// like: /// - Whether they should only match single words /// - Whether to include any whitespace on its left or right -#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq, AutoDisplay)] +#[derive(Clone, Serialize, Deserialize, PartialEq, Eq, 
AutoDisplay, AutoDebug)] pub struct AddedToken { /// The content of the added token pub content: String, @@ -139,14 +139,14 @@ fn space_rightmost_at_start(sentence: &str) -> usize { /// were to add new tokens after this training process, we couldn't make sure the merges pairs /// exist as required. /// -#[derive(Clone, Debug, AutoDisplay)] +#[derive(Clone, AutoDisplay, AutoDebug)] pub struct AddedVocabulary { /// Contains the mapping from String (token content) to ID. This map contains both special /// tokens and classic added tokens that were added to the this vocabulary. added_tokens_map: HashMap, /// Contains the mapping from ID to AddedToken for all the added tokens, both special /// and classic. - #[format] + #[format(fmt = "added_token_decoder={}")] added_tokens_map_r: HashMap, /// Contains only the classic AddedToken, in the specific order the user gave them. diff --git a/tokenizers/src/tokenizer/mod.rs b/tokenizers/src/tokenizer/mod.rs index 3f5684e21..71fa75e62 100644 --- a/tokenizers/src/tokenizer/mod.rs +++ b/tokenizers/src/tokenizer/mod.rs @@ -21,11 +21,10 @@ extern crate rayon; use crate::utils::iter::ResultShunt; use crate::utils::parallelism::*; use crate::utils::progress::{ProgressBar, ProgressStyle}; -use pyo3_special_method_derive_0_21::{AutoDisplay, PyDisplay}; +use pyo3_special_method_derive_0_21::{AutoDisplay, PyDebug, PyDisplay}; use rayon::current_thread_index; use serde::de::DeserializeOwned; use serde::{Deserialize, Serialize}; - mod added_vocabulary; mod encoding; pub mod normalizer; @@ -509,7 +508,7 @@ impl DerefMut for Tokenizer { pub struct TruncationParamError(String); /// A `Tokenizer` is capable of encoding/decoding any text. -#[derive(Clone, Debug)] +#[derive(Clone)] pub struct TokenizerImpl { // Tokenizer parts normalizer: Option, @@ -526,6 +525,29 @@ pub struct TokenizerImpl { padding: Option, } +impl PyDebug for TokenizerImpl +where + M: PyDebug, + N: PyDebug, + PT: PyDebug, + PP: PyDebug, + D: PyDebug, +{ + fn fmt_debug(&self) -> std::string::String { + format!( + "Tokenizer(normalizer={}, pre_tokenizer={}, model={}, post_processor={}, decoder={}, added_tokens_decoder={}, truncation={}, padding={})", + self.normalizer.fmt_debug(), + self.pre_tokenizer.fmt_debug(), + self.model.fmt_debug(), + self.post_processor.fmt_debug(), + self.decoder.fmt_debug(), + self.added_vocabulary.fmt_debug(), + self.truncation.fmt_debug(), + self.padding.fmt_debug() + ) + } +} + impl PyDisplay for TokenizerImpl where M: PyDisplay, @@ -536,27 +558,27 @@ where { fn fmt_display(&self) -> std::string::String { let normalizer_str = match &self.normalizer { - Some(n) => format!("{}", n.fmt_display()), + Some(n) => n.fmt_display().to_string(), None => "None".to_string(), }; let pre_tokenizer_str = match &self.pre_tokenizer { - Some(pt) => format!("{}", pt.fmt_display()), + Some(pt) => pt.fmt_display().to_string(), None => "None".to_string(), }; let post_processor_str = match &self.post_processor { - Some(pp) => format!("{}", pp.fmt_display()), + Some(pp) => pp.fmt_display().to_string(), None => "None".to_string(), }; let decoder_str = match &self.decoder { - Some(d) => format!("{}", d.fmt_display()), + Some(d) => d.fmt_display().to_string(), None => "None".to_string(), }; let truncation_str = match &self.truncation { - Some(t) => format!("{}", t.fmt_display()), + Some(t) => t.fmt_display().to_string(), None => "None".to_string(), }; let padding_str = match &self.padding { - Some(p) => format!("{}", p.fmt_display()), + Some(p) => p.fmt_display().to_string(), None => 
"None".to_string(), }; @@ -573,6 +595,7 @@ where ) } } + impl TokenizerImpl where M: Model, @@ -1400,4 +1423,3 @@ mod tests { assert_eq!(decoded.unwrap(), "Hey! how is this token: д") } } - diff --git a/tokenizers/src/utils/padding.rs b/tokenizers/src/utils/padding.rs index 6481fc096..318951398 100644 --- a/tokenizers/src/utils/padding.rs +++ b/tokenizers/src/utils/padding.rs @@ -1,10 +1,10 @@ use crate::parallelism::*; use crate::tokenizer::{Encoding, Result}; -use pyo3_special_method_derive_0_21::AutoDisplay; +use pyo3_special_method_derive_0_21::{AutoDebug, AutoDisplay}; use serde::{Deserialize, Serialize}; /// The various possible padding directions. -#[derive(Debug, Clone, Copy, Serialize, Deserialize, AutoDisplay)] +#[derive(AutoDebug, Clone, Copy, Serialize, Deserialize, AutoDisplay)] pub enum PaddingDirection { Left, Right, @@ -19,7 +19,7 @@ impl std::convert::AsRef for PaddingDirection { } } -#[derive(Debug, Clone, Serialize, Deserialize, AutoDisplay)] +#[derive(AutoDebug, Clone, Serialize, Deserialize, AutoDisplay)] pub struct PaddingParams { pub strategy: PaddingStrategy, pub direction: PaddingDirection, @@ -42,7 +42,7 @@ impl Default for PaddingParams { } } -#[derive(Debug, Clone, Serialize, Deserialize, AutoDisplay)] +#[derive(AutoDebug, Clone, Serialize, Deserialize, AutoDisplay)] pub enum PaddingStrategy { BatchLongest, Fixed(usize), diff --git a/tokenizers/src/utils/truncation.rs b/tokenizers/src/utils/truncation.rs index 5780f33bf..b11766bba 100644 --- a/tokenizers/src/utils/truncation.rs +++ b/tokenizers/src/utils/truncation.rs @@ -1,10 +1,10 @@ use crate::tokenizer::{Encoding, Result}; -use pyo3_special_method_derive_0_21::AutoDisplay; +use pyo3_special_method_derive_0_21::{AutoDebug, AutoDisplay}; use serde::{Deserialize, Serialize}; use std::cmp; use std::mem; -#[derive(Debug, Clone, Copy, PartialEq, Serialize, Deserialize, Eq, Default, AutoDisplay)] +#[derive(AutoDebug, Clone, Copy, PartialEq, Serialize, Deserialize, Eq, Default, AutoDisplay)] pub enum TruncationDirection { Left, #[default] @@ -20,7 +20,7 @@ impl std::convert::AsRef for TruncationDirection { } } -#[derive(Debug, Clone, Serialize, Deserialize, AutoDisplay)] +#[derive(AutoDebug, Clone, Serialize, Deserialize, AutoDisplay)] pub struct TruncationParams { #[serde(default)] pub direction: TruncationDirection, @@ -51,7 +51,7 @@ pub enum TruncationError { SequenceTooShort, } -#[derive(Debug, Clone, Copy, PartialEq, Serialize, Deserialize, Eq, AutoDisplay)] +#[derive(AutoDebug, Clone, Copy, PartialEq, Serialize, Deserialize, Eq, AutoDisplay)] pub enum TruncationStrategy { LongestFirst, OnlyFirst, From c30df0c2b797085182dce759169dfcf45dd9b6bc Mon Sep 17 00:00:00 2001 From: Arthur Zucker Date: Sun, 28 Jul 2024 10:07:31 +0200 Subject: [PATCH 89/94] remove dict and dir as it's gonna be a bit more involved --- bindings/python/src/processors.rs | 2 +- bindings/python/src/tokenizer.rs | 6 +++--- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/bindings/python/src/processors.rs b/bindings/python/src/processors.rs index bf568d7f4..2ce1abd84 100644 --- a/bindings/python/src/processors.rs +++ b/bindings/python/src/processors.rs @@ -6,7 +6,7 @@ use crate::error::ToPyResult; use pyo3::exceptions; use pyo3::prelude::*; use pyo3::types::*; -use pyo3_special_method_derive_0_21::{AutoDisplay, Str, Repr}; +use pyo3_special_method_derive_0_21::{Repr, Str}; use serde::{Deserialize, Serialize}; use tk::processors::bert::BertProcessing; use tk::processors::byte_level::ByteLevel; diff --git 
a/bindings/python/src/tokenizer.rs b/bindings/python/src/tokenizer.rs index cd3bfa682..953ae0be0 100644 --- a/bindings/python/src/tokenizer.rs +++ b/bindings/python/src/tokenizer.rs @@ -16,7 +16,7 @@ use pyo3::exceptions; use pyo3::intern; use pyo3::prelude::*; use pyo3::types::*; -use pyo3_special_method_derive_0_21::{AutoDisplay, Dict, Dir, Repr, Str}; +use pyo3_special_method_derive_0_21::{Repr, Str}; use std::collections::BTreeMap; use tk::models::bpe::BPE; use tk::tokenizer::{ @@ -462,11 +462,11 @@ type Tokenizer = TokenizerImpl Date: Tue, 30 Jul 2024 11:06:03 +0200 Subject: [PATCH 90/94] remove pub where it is not necessary --- tokenizers/src/models/mod.rs | 2 +- tokenizers/src/normalizers/mod.rs | 4 ++-- tokenizers/src/normalizers/replace.rs | 3 ++- tokenizers/src/processors/template.rs | 8 +++++--- 4 files changed, 10 insertions(+), 7 deletions(-) diff --git a/tokenizers/src/models/mod.rs b/tokenizers/src/models/mod.rs index c1c722568..ec9479847 100644 --- a/tokenizers/src/models/mod.rs +++ b/tokenizers/src/models/mod.rs @@ -8,7 +8,7 @@ pub mod wordpiece; use std::collections::HashMap; use std::path::{Path, PathBuf}; -use pyo3_special_method_derive_0_21::{AutoDebug, AutoDisplay, PyDisplay}; +use pyo3_special_method_derive_0_21::{AutoDebug, AutoDisplay}; use serde::{Deserialize, Serialize, Serializer}; use crate::models::bpe::{BpeTrainer, BPE}; diff --git a/tokenizers/src/normalizers/mod.rs b/tokenizers/src/normalizers/mod.rs index 639dc8ffa..cbd98a07e 100644 --- a/tokenizers/src/normalizers/mod.rs +++ b/tokenizers/src/normalizers/mod.rs @@ -17,7 +17,7 @@ pub use crate::normalizers::utils::{Lowercase, Sequence}; use serde::{Deserialize, Serialize}; use crate::{NormalizedString, Normalizer}; -use pyo3_special_method_derive_0_21::{AutoDebug,AutoDisplay}; +use pyo3_special_method_derive_0_21::{AutoDebug, AutoDisplay}; /// Wrapper for known Normalizers. 
#[derive(Clone, Deserialize, Serialize, AutoDisplay, AutoDebug)] @@ -34,7 +34,7 @@ pub enum NormalizerWrapper { Sequence(Sequence), Lowercase(Lowercase), Nmt(Nmt), - #[format(skip)] + #[format(fmt = "Precompiled")] Precompiled(Precompiled), Replace(Replace), Prepend(Prepend), diff --git a/tokenizers/src/normalizers/replace.rs b/tokenizers/src/normalizers/replace.rs index f371f743c..df6d4c005 100644 --- a/tokenizers/src/normalizers/replace.rs +++ b/tokenizers/src/normalizers/replace.rs @@ -5,7 +5,7 @@ use crate::utils::SysRegex; use pyo3_special_method_derive_0_21::{AutoDebug, AutoDisplay}; use serde::{Deserialize, Serialize}; /// Represents the different patterns that `Replace` can use -#[derive(Debug, Clone, PartialEq, Serialize, Deserialize, Eq)] +#[derive(AutoDebug, AutoDisplay, Clone, PartialEq, Serialize, Deserialize, Eq)] pub enum ReplacePattern { String(String), Regex(String), @@ -49,6 +49,7 @@ pub struct Replace { #[format] content: String, #[serde(skip)] + #[format(skip)] regex: SysRegex, } diff --git a/tokenizers/src/processors/template.rs b/tokenizers/src/processors/template.rs index 2ef977041..11d2471b2 100644 --- a/tokenizers/src/processors/template.rs +++ b/tokenizers/src/processors/template.rs @@ -338,12 +338,14 @@ impl From> for Tokens { #[format(fmt = "TemplateProcessing: {}")] pub struct TemplateProcessing { #[builder(try_setter, default = "\"$0\".try_into().unwrap()")] - pub single: Template, + #[format] + single: Template, #[builder(try_setter, default = "\"$A:0 $B:1\".try_into().unwrap()")] - pub pair: Template, + #[format] + pair: Template, #[builder(setter(skip), default = "self.default_added(true)")] #[serde(skip)] - pub added_single: usize, + added_single: usize, #[builder(setter(skip), default = "self.default_added(false)")] #[format] #[serde(skip)] From a99c6457e662d82968c79dd733bfb1bf52016e01 Mon Sep 17 00:00:00 2001 From: Arthur Zucker Date: Tue, 30 Jul 2024 11:39:44 +0200 Subject: [PATCH 91/94] fmt = --- bindings/python/src/normalizers.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/bindings/python/src/normalizers.rs b/bindings/python/src/normalizers.rs index edc803bec..6444f75eb 100644 --- a/bindings/python/src/normalizers.rs +++ b/bindings/python/src/normalizers.rs @@ -480,7 +480,7 @@ impl PyNmt { /// Don't use manually it is used for compatiblity for SentencePiece. 
#[pyclass(extends=PyNormalizer, module = "tokenizers.normalizers", name = "Precompiled")] #[derive(Str)] -#[format("PreCompiled")] +#[format(fmt="PreCompiled")] pub struct PyPrecompiled {} #[pymethods] From 902247028b55afff3abcf561513a4687ad3dafae Mon Sep 17 00:00:00 2001 From: Arthur Zucker Date: Fri, 2 Aug 2024 12:25:16 +0200 Subject: [PATCH 92/94] formating --- bindings/python/py_src/tokenizers/__init__.pyi | 10 ---------- bindings/python/src/decoders.rs | 7 +++---- bindings/python/src/normalizers.rs | 4 ++-- tokenizers/src/decoders/mod.rs | 2 +- tokenizers/src/models/mod.rs | 2 +- tokenizers/src/models/wordlevel/mod.rs | 2 +- tokenizers/src/models/wordpiece/mod.rs | 2 +- tokenizers/src/normalizers/bert.rs | 2 +- tokenizers/src/normalizers/byte_level.rs | 2 +- tokenizers/src/normalizers/strip.rs | 2 +- tokenizers/src/normalizers/unicode.rs | 2 +- tokenizers/src/normalizers/utils.rs | 2 +- tokenizers/src/pre_tokenizers/mod.rs | 1 - tokenizers/src/processors/mod.rs | 7 +------ 14 files changed, 15 insertions(+), 32 deletions(-) diff --git a/bindings/python/py_src/tokenizers/__init__.pyi b/bindings/python/py_src/tokenizers/__init__.pyi index 06b8621e0..5dbc665dc 100644 --- a/bindings/python/py_src/tokenizers/__init__.pyi +++ b/bindings/python/py_src/tokenizers/__init__.pyi @@ -725,16 +725,6 @@ class Tokenizer: """ pass - @property - def added_tokens_decoder(self): - """ - Get the underlying vocabulary - - Returns: - :obj:`Dict[int, AddedToken]`: The vocabulary - """ - pass - def decode(self, ids, skip_special_tokens=True): """ Decode the given list of ids back to a string diff --git a/bindings/python/src/decoders.rs b/bindings/python/src/decoders.rs index 478e9faf1..781477994 100644 --- a/bindings/python/src/decoders.rs +++ b/bindings/python/src/decoders.rs @@ -5,7 +5,7 @@ use crate::utils::PyPattern; use pyo3::exceptions; use pyo3::prelude::*; use pyo3::types::*; -use pyo3_special_method_derive_0_21::{AutoDisplay, Repr, Str, AutoDebug}; +use pyo3_special_method_derive_0_21::{AutoDebug, AutoDisplay, Repr, Str}; use serde::de::Error; use serde::{Deserialize, Deserializer, Serialize, Serializer}; use tk::decoders::bpe::BPEDecoder; @@ -30,10 +30,10 @@ use super::error::ToPyResult; /// a Decoder will return an instance of this class when instantiated. #[pyclass(dict, module = "tokenizers.decoders", name = "Decoder", subclass)] #[derive(Clone, Deserialize, Serialize, Str, Repr)] -#[format(fmt="{}")] +#[format(fmt = "{}")] pub struct PyDecoder { #[serde(flatten)] - #[format(fmt="{}")] + #[format(fmt = "{}")] pub(crate) decoder: PyDecoderWrapper, } @@ -487,7 +487,6 @@ pub(crate) struct CustomDecoder { pub inner: PyObject, } - impl CustomDecoder { pub(crate) fn new(inner: PyObject) -> Self { CustomDecoder { inner } diff --git a/bindings/python/src/normalizers.rs b/bindings/python/src/normalizers.rs index 6444f75eb..f9a784cad 100644 --- a/bindings/python/src/normalizers.rs +++ b/bindings/python/src/normalizers.rs @@ -5,7 +5,7 @@ use crate::utils::{PyNormalizedString, PyNormalizedStringRefMut, PyPattern}; use pyo3::exceptions; use pyo3::prelude::*; use pyo3::types::*; -use pyo3_special_method_derive_0_21::{AutoDisplay,AutoDebug, Dict, Dir, Repr, Str}; +use pyo3_special_method_derive_0_21::{AutoDebug, AutoDisplay, Dict, Dir, Repr, Str}; use serde::ser::SerializeStruct; use serde::{Deserialize, Deserializer, Serialize, Serializer}; use tk::normalizers::{ @@ -480,7 +480,7 @@ impl PyNmt { /// Don't use manually it is used for compatiblity for SentencePiece. 
#[pyclass(extends=PyNormalizer, module = "tokenizers.normalizers", name = "Precompiled")] #[derive(Str)] -#[format(fmt="PreCompiled")] +#[format(fmt = "PreCompiled")] pub struct PyPrecompiled {} #[pymethods] diff --git a/tokenizers/src/decoders/mod.rs b/tokenizers/src/decoders/mod.rs index aa369572c..a0a270536 100644 --- a/tokenizers/src/decoders/mod.rs +++ b/tokenizers/src/decoders/mod.rs @@ -26,7 +26,7 @@ use pyo3_special_method_derive_0_21::AutoDisplay; use serde::{Deserialize, Serialize}; #[derive(Serialize, Deserialize, Clone, AutoDebug, AutoDisplay)] -#[format(fmt="decoders.{}")] +#[format(fmt = "decoders.{}")] #[serde(untagged)] pub enum DecoderWrapper { BPE(BPEDecoder), diff --git a/tokenizers/src/models/mod.rs b/tokenizers/src/models/mod.rs index ec9479847..68146045f 100644 --- a/tokenizers/src/models/mod.rs +++ b/tokenizers/src/models/mod.rs @@ -60,7 +60,7 @@ impl<'a> Serialize for OrderedVocabIter<'a> { #[derive(Deserialize, Serialize, AutoDebug, PartialEq, Clone, AutoDisplay)] #[serde(untagged)] -#[format(fmt = "{}")] // TODO by default this should define the finale render {}, {} {} . or {}{}{} +#[format(fmt = "models.{}")] // TODO by default this should define the finale render {}, {} {} . or {}{}{} pub enum ModelWrapper { BPE(BPE), // WordPiece must stay before WordLevel here for deserialization (for retrocompatibility diff --git a/tokenizers/src/models/wordlevel/mod.rs b/tokenizers/src/models/wordlevel/mod.rs index 8dd2335d5..2cf9057a2 100644 --- a/tokenizers/src/models/wordlevel/mod.rs +++ b/tokenizers/src/models/wordlevel/mod.rs @@ -1,6 +1,6 @@ use super::OrderedVocabIter; use crate::tokenizer::{Model, Result, Token}; -use pyo3_special_method_derive_0_21::{AutoDisplay,PyDebug}; +use pyo3_special_method_derive_0_21::{AutoDisplay, PyDebug}; use serde_json::Value; use std::collections::HashMap; use std::fs::File; diff --git a/tokenizers/src/models/wordpiece/mod.rs b/tokenizers/src/models/wordpiece/mod.rs index 732aafb5e..a3b2997ce 100644 --- a/tokenizers/src/models/wordpiece/mod.rs +++ b/tokenizers/src/models/wordpiece/mod.rs @@ -3,7 +3,7 @@ use crate::models::bpe::BPE; use crate::tokenizer::{Model, Result, Token}; -use pyo3_special_method_derive_0_21::{AutoDisplay,PyDebug}; +use pyo3_special_method_derive_0_21::{AutoDisplay, PyDebug}; use std::{ borrow::Cow, collections::HashMap, diff --git a/tokenizers/src/normalizers/bert.rs b/tokenizers/src/normalizers/bert.rs index 89399b022..9cf9a5e2b 100644 --- a/tokenizers/src/normalizers/bert.rs +++ b/tokenizers/src/normalizers/bert.rs @@ -1,6 +1,6 @@ use crate::tokenizer::{NormalizedString, Normalizer, Result}; -use pyo3_special_method_derive_0_21::{AutoDisplay,AutoDebug}; +use pyo3_special_method_derive_0_21::{AutoDebug, AutoDisplay}; use serde::{Deserialize, Serialize}; use unicode_categories::UnicodeCategories; /// Checks whether a character is whitespace diff --git a/tokenizers/src/normalizers/byte_level.rs b/tokenizers/src/normalizers/byte_level.rs index d01f713cc..95f38dd96 100644 --- a/tokenizers/src/normalizers/byte_level.rs +++ b/tokenizers/src/normalizers/byte_level.rs @@ -1,8 +1,8 @@ use crate::processors::byte_level::bytes_char; use crate::tokenizer::{NormalizedString, Normalizer, Result}; +use pyo3_special_method_derive_0_21::{AutoDebug, AutoDisplay}; use serde::{Deserialize, Serialize}; use std::collections::{HashMap, HashSet}; -use pyo3_special_method_derive_0_21::{AutoDebug, AutoDisplay}; #[derive(Clone, AutoDebug, Deserialize, Serialize, AutoDisplay)] #[serde(tag = "type")] pub struct ByteLevel {} diff --git 
a/tokenizers/src/normalizers/strip.rs b/tokenizers/src/normalizers/strip.rs index e597975b9..ec3b83a68 100644 --- a/tokenizers/src/normalizers/strip.rs +++ b/tokenizers/src/normalizers/strip.rs @@ -1,6 +1,6 @@ use crate::tokenizer::{NormalizedString, Normalizer, Result}; use crate::utils::macro_rules_attribute; -use pyo3_special_method_derive_0_21::{AutoDebug,AutoDisplay}; +use pyo3_special_method_derive_0_21::{AutoDebug, AutoDisplay}; use serde::{Deserialize, Serialize}; use unicode_normalization_alignments::char::is_combining_mark; #[derive(Copy, Clone, AutoDebug, Deserialize, Serialize, AutoDisplay)] diff --git a/tokenizers/src/normalizers/unicode.rs b/tokenizers/src/normalizers/unicode.rs index 47498498e..4a2498722 100644 --- a/tokenizers/src/normalizers/unicode.rs +++ b/tokenizers/src/normalizers/unicode.rs @@ -1,6 +1,6 @@ use crate::tokenizer::{NormalizedString, Normalizer, Result}; use crate::utils::macro_rules_attribute; -use pyo3_special_method_derive_0_21::{AutoDebug,AutoDisplay}; +use pyo3_special_method_derive_0_21::{AutoDebug, AutoDisplay}; #[derive(Default, Copy, Clone, AutoDebug, AutoDisplay)] #[macro_rules_attribute(impl_serde_type!)] diff --git a/tokenizers/src/normalizers/utils.rs b/tokenizers/src/normalizers/utils.rs index 3b4c01933..77359e608 100644 --- a/tokenizers/src/normalizers/utils.rs +++ b/tokenizers/src/normalizers/utils.rs @@ -3,7 +3,7 @@ use serde::{Deserialize, Serialize}; use crate::normalizers::NormalizerWrapper; use crate::tokenizer::{NormalizedString, Normalizer, Result}; use crate::utils::macro_rules_attribute; -use pyo3_special_method_derive_0_21::{AutoDebug,AutoDisplay}; +use pyo3_special_method_derive_0_21::{AutoDebug, AutoDisplay}; #[derive(Clone, Deserialize, AutoDebug, Serialize, AutoDisplay)] #[serde(tag = "type")] /// Allows concatenating multiple other Normalizer as a Sequence. 
diff --git a/tokenizers/src/pre_tokenizers/mod.rs b/tokenizers/src/pre_tokenizers/mod.rs index 788a1c92e..1b35a3971 100644 --- a/tokenizers/src/pre_tokenizers/mod.rs +++ b/tokenizers/src/pre_tokenizers/mod.rs @@ -33,7 +33,6 @@ pub enum PreTokenizerWrapper { Delimiter(CharDelimiterSplit), Metaspace(Metaspace), Whitespace(Whitespace), - #[format(fmt = "{}")] Sequence(Sequence), Split(Split), Punctuation(Punctuation), diff --git a/tokenizers/src/processors/mod.rs b/tokenizers/src/processors/mod.rs index c68651fa8..266a23051 100644 --- a/tokenizers/src/processors/mod.rs +++ b/tokenizers/src/processors/mod.rs @@ -16,19 +16,14 @@ use pyo3_special_method_derive_0_21::{AutoDebug, AutoDisplay}; use serde::{Deserialize, Serialize}; #[derive(Serialize, Deserialize, PartialEq, AutoDebug, Clone, Eq, AutoDisplay)] -#[format(fmt="{}")] +#[format(fmt = "processors.{}")] #[serde(untagged)] pub enum PostProcessorWrapper { // Roberta must be before Bert for deserialization (serde does not validate tags) - #[format(fmt="{}")] Roberta(RobertaProcessing), - #[format(fmt="{}")] Bert(BertProcessing), - #[format(fmt="{}")] ByteLevel(ByteLevel), - #[format(fmt="{}")] Template(TemplateProcessing), - #[format(fmt="{}")] Sequence(Sequence), } From 64b8df02412544c67344f802675e424544205af7 Mon Sep 17 00:00:00 2001 From: Arthur Zucker Date: Fri, 2 Aug 2024 12:27:57 +0200 Subject: [PATCH 93/94] remove non needed fm --- bindings/python/src/decoders.rs | 3 --- bindings/python/src/models.rs | 1 - bindings/python/src/normalizers.rs | 1 - bindings/python/src/pre_tokenizers.rs | 1 - bindings/python/src/processors.rs | 1 - bindings/python/src/tokenizer.rs | 1 - 6 files changed, 8 deletions(-) diff --git a/bindings/python/src/decoders.rs b/bindings/python/src/decoders.rs index 781477994..c6eed6f28 100644 --- a/bindings/python/src/decoders.rs +++ b/bindings/python/src/decoders.rs @@ -33,7 +33,6 @@ use super::error::ToPyResult; #[format(fmt = "{}")] pub struct PyDecoder { #[serde(flatten)] - #[format(fmt = "{}")] pub(crate) decoder: PyDecoderWrapper, } @@ -539,9 +538,7 @@ impl<'de> Deserialize<'de> for CustomDecoder { #[serde(untagged)] #[format(fmt = "{}")] pub(crate) enum PyDecoderWrapper { - #[format(fmt = "{}")] Custom(Arc>), - #[format(fmt = "{}")] Wrapped(Arc>), } diff --git a/bindings/python/src/models.rs b/bindings/python/src/models.rs index ba12bb7ab..02d838378 100644 --- a/bindings/python/src/models.rs +++ b/bindings/python/src/models.rs @@ -29,7 +29,6 @@ use tokenizers as tk; #[format(fmt = "{}")] pub struct PyModel { #[serde(flatten)] - #[format(fmt = "{}")] pub model: Arc>, } diff --git a/bindings/python/src/normalizers.rs b/bindings/python/src/normalizers.rs index f9a784cad..ac3714ee4 100644 --- a/bindings/python/src/normalizers.rs +++ b/bindings/python/src/normalizers.rs @@ -47,7 +47,6 @@ impl PyNormalizedStringMut<'_> { #[format(fmt = "{}")] pub struct PyNormalizer { #[serde(flatten)] - #[format(fmt = "{}")] pub(crate) normalizer: PyNormalizerTypeWrapper, } diff --git a/bindings/python/src/pre_tokenizers.rs b/bindings/python/src/pre_tokenizers.rs index e492df094..740471787 100644 --- a/bindings/python/src/pre_tokenizers.rs +++ b/bindings/python/src/pre_tokenizers.rs @@ -38,7 +38,6 @@ use pyo3_special_method_derive_0_21::{AutoDebug, AutoDisplay, Dict, Dir, Repr, S #[format(fmt = "{}")] // don't format the Py wrapper pub struct PyPreTokenizer { #[serde(flatten)] - #[format(fmt = "{}")] // format only pretok, not pretok = pretok: PyPreTokenizerTypeWrapper, } diff --git a/bindings/python/src/processors.rs 
b/bindings/python/src/processors.rs index 2ce1abd84..aee5a142b 100644 --- a/bindings/python/src/processors.rs +++ b/bindings/python/src/processors.rs @@ -31,7 +31,6 @@ use tokenizers as tk; #[format(fmt = "{}")] pub struct PyPostProcessor { #[serde(flatten)] - #[format(fmt = "{}")] pub processor: Arc, } diff --git a/bindings/python/src/tokenizer.rs b/bindings/python/src/tokenizer.rs index 953ae0be0..68cf2100d 100644 --- a/bindings/python/src/tokenizer.rs +++ b/bindings/python/src/tokenizer.rs @@ -465,7 +465,6 @@ type Tokenizer = TokenizerImpl Date: Fri, 2 Aug 2024 13:06:35 +0200 Subject: [PATCH 94/94] so we only need format when the visibility is not pub but pub(crate) --- bindings/python/src/decoders.rs | 1 + bindings/python/src/normalizers.rs | 5 ++--- bindings/python/src/pre_tokenizers.rs | 5 +---- 3 files changed, 4 insertions(+), 7 deletions(-) diff --git a/bindings/python/src/decoders.rs b/bindings/python/src/decoders.rs index c6eed6f28..c5d2365e5 100644 --- a/bindings/python/src/decoders.rs +++ b/bindings/python/src/decoders.rs @@ -33,6 +33,7 @@ use super::error::ToPyResult; #[format(fmt = "{}")] pub struct PyDecoder { #[serde(flatten)] + #[format] pub(crate) decoder: PyDecoderWrapper, } diff --git a/bindings/python/src/normalizers.rs b/bindings/python/src/normalizers.rs index ac3714ee4..0d8aa0edd 100644 --- a/bindings/python/src/normalizers.rs +++ b/bindings/python/src/normalizers.rs @@ -43,10 +43,11 @@ impl PyNormalizedStringMut<'_> { /// This class is not supposed to be instantiated directly. Instead, any implementation of a /// Normalizer will return an instance of this class when instantiated. #[pyclass(dict, module = "tokenizers.normalizers", name = "Normalizer", subclass)] -#[derive(Clone, Serialize, Deserialize, Str, Repr, Dir, Dict)] +#[derive(Clone, Serialize, Deserialize, Str, Repr, Dir)] #[format(fmt = "{}")] pub struct PyNormalizer { #[serde(flatten)] + #[format] pub(crate) normalizer: PyNormalizerTypeWrapper, } @@ -565,9 +566,7 @@ impl<'de> Deserialize<'de> for CustomNormalizer { #[serde(untagged)] #[format(fmt = "{}")] pub(crate) enum PyNormalizerWrapper { - #[format(fmt = "{}")] Custom(CustomNormalizer), - #[format(fmt = "{}")] Wrapped(NormalizerWrapper), } diff --git a/bindings/python/src/pre_tokenizers.rs b/bindings/python/src/pre_tokenizers.rs index 740471787..43787d053 100644 --- a/bindings/python/src/pre_tokenizers.rs +++ b/bindings/python/src/pre_tokenizers.rs @@ -38,6 +38,7 @@ use pyo3_special_method_derive_0_21::{AutoDebug, AutoDisplay, Dict, Dir, Repr, S #[format(fmt = "{}")] // don't format the Py wrapper pub struct PyPreTokenizer { #[serde(flatten)] + #[format] pretok: PyPreTokenizerTypeWrapper, } @@ -638,9 +639,7 @@ impl<'de> Deserialize<'de> for CustomPreTokenizer { #[serde(untagged)] #[format(fmt = "{}")] pub(crate) enum PyPreTokenizerWrapper { - #[format(fmt = "{}")] Custom(CustomPreTokenizer), - #[format(fmt = "{}")] Wrapped(PreTokenizerWrapper), } @@ -660,9 +659,7 @@ impl Serialize for PyPreTokenizerWrapper { #[serde(untagged)] #[format(fmt = "{}")] pub(crate) enum PyPreTokenizerTypeWrapper { - #[format(fmt = "{}")] Sequence(Vec>>), - #[format(fmt = "{}")] Single(Arc>), }
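Taken together, the last patches settle on one pattern for the pyo3_special_method_derive_0_21 derives: wrapper enums and structs derive AutoDisplay/AutoDebug, a type-level #[format(fmt = "...")] namespaces the rendered value (e.g. "normalizers.{}", "processors.{}"), and individual variants or fields either opt in with #[format] or override their rendering with #[format(fmt = "...")]. The sketch below restates that pattern outside the patch series; it is a minimal illustration rather than code from any commit, and ExampleNormalizerWrapper, Lowercase and Precompiled are hypothetical stand-ins assumed to behave like the real types in tokenizers/src/normalizers.

use pyo3_special_method_derive_0_21::{AutoDebug, AutoDisplay, PyDisplay};

// Hypothetical stand-ins for the real normalizers; unit structs take the same
// derives that StripAccents and Lowercase receive in the hunks above.
#[derive(AutoDebug, AutoDisplay)]
struct Lowercase;

#[derive(AutoDebug, AutoDisplay)]
struct Precompiled;

// Wrapper enum mirroring NormalizerWrapper: the type-level fmt prefixes the
// rendered variant with "normalizers.", and the Precompiled variant overrides
// its rendering with a fixed label, as done in patch 90/94.
#[derive(AutoDebug, AutoDisplay)]
#[format(fmt = "normalizers.{}")]
enum ExampleNormalizerWrapper {
    Lowercase(Lowercase),
    #[format(fmt = "Precompiled")]
    Precompiled(Precompiled),
}

fn main() {
    // fmt_display() is the PyDisplay method that the generated __str__
    // ultimately calls, as in the TokenizerImpl::fmt_display impl above.
    let n = ExampleNormalizerWrapper::Lowercase(Lowercase);
    println!("{}", n.fmt_display());
}

Per the final patch in the series, fields with pub visibility are picked up by the derive automatically; an explicit #[format] is only needed on pub(crate) fields such as PyDecoder::decoder and PyPreTokenizer::pretok.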