Commit
what works a bit ?
ArthurZucker committed Jul 25, 2024
1 parent 7db6109 commit c7cd927
Showing 27 changed files with 125 additions and 133 deletions.
2 changes: 1 addition & 1 deletion bindings/python/Cargo.toml
@@ -20,7 +20,7 @@ onig = { version = "6.4", default-features = false }
 itertools = "0.12"
 derive_more = "0.99.17"
 pyo3 = { version = "0.21", features = ["multiple-pymethods"] }
-pyo3_special_method_derive_0_21 = "0.4"
+pyo3_special_method_derive_0_21 = {path = "../../../pyo3-special-method-derive/pyo3_special_method_derive_0_21"}

 [dependencies.tokenizers]
 path = "../../tokenizers"
28 changes: 6 additions & 22 deletions bindings/python/src/decoders.rs
@@ -7,7 +7,7 @@ use pyo3::prelude::*;
 use pyo3::types::*;
 use pyo3_special_method_derive_0_21::AutoDisplay;
 use pyo3_special_method_derive_0_21::PyDebug;
-use pyo3_special_method_derive_0_21::PyDisplay;
+use pyo3_special_method_derive_0_21::Str;
 use serde::de::Error;
 use serde::{Deserialize, Deserializer, Serialize, Serializer};
 use tk::decoders::bpe::BPEDecoder;
@@ -31,9 +31,11 @@ use super::error::ToPyResult;
 /// This class is not supposed to be instantiated directly. Instead, any implementation of
 /// a Decoder will return an instance of this class when instantiated.
 #[pyclass(dict, module = "tokenizers.decoders", name = "Decoder", subclass)]
-#[derive(Clone, Deserialize, Serialize, AutoDisplay)]
+#[derive(Clone, Deserialize, Serialize, Str)]
+#[format(fmt="")]
 pub struct PyDecoder {
     #[serde(flatten)]
+    #[format(fmt="{}")]
     pub(crate) decoder: PyDecoderWrapper,
 }
@@ -117,14 +119,6 @@ impl PyDecoder {
     fn decode(&self, tokens: Vec<String>) -> PyResult<String> {
         ToPyResult(self.decoder.decode(tokens)).into()
     }
-
-    fn __str__(&self) -> PyResult<String> {
-        Ok(format!("{}", self.decoder))
-    }
-
-    fn __repr__(&self) -> PyResult<String> {
-        Ok(format!("{}", self.decoder))
-    }
 }

 macro_rules! getter {
@@ -489,22 +483,12 @@ impl PySequenceDecoder {
     }
 }

-#[derive(Clone)]
+#[derive(Clone, AutoDisplay)]
 pub(crate) struct CustomDecoder {
+    #[format(skip)]
     pub inner: PyObject,
 }

-impl PyDisplay for CustomDecoder {
-    fn fmt_display(&self) -> String {
-        "CustomDecoder()".to_string()
-    }
-}
-
-impl PyDebug for CustomDecoder {
-    fn fmt_debug(&self) -> String {
-        "CustomDecoder()".to_string()
-    }
-}
-
 impl CustomDecoder {
     pub(crate) fn new(inner: PyObject) -> Self {
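The `__str__`/`__repr__` pair deleted above is exactly the boilerplate the `Str` derive is expected to generate from the `#[format]` attributes. A minimal pyo3 sketch of that hand-written pattern, with a hypothetical `MyDecoder` struct and a `label` field standing in for the wrapped `PyDecoderWrapper`:

```rust
use pyo3::prelude::*;

#[pyclass]
struct MyDecoder {
    label: String, // hypothetical stand-in for the wrapped decoder
}

#[pymethods]
impl MyDecoder {
    // Roughly what the derive should emit: __str__ and __repr__ that
    // delegate to the field's Display formatting.
    fn __str__(&self) -> PyResult<String> {
        Ok(format!("{}", self.label))
    }

    fn __repr__(&self) -> PyResult<String> {
        Ok(format!("{}", self.label))
    }
}
```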
10 changes: 3 additions & 7 deletions bindings/python/src/models.rs
@@ -9,6 +9,7 @@ use pyo3::exceptions;
 use pyo3::prelude::*;
 use pyo3::types::*;
 use pyo3_special_method_derive_0_21::AutoDisplay;
+use pyo3_special_method_derive_0_21::Str;
 use serde::{Deserialize, Serialize};
 use tk::models::bpe::{BpeBuilder, Merges, Vocab, BPE};
 use tk::models::unigram::Unigram;
@@ -25,9 +26,10 @@ use tokenizers as tk;
 ///
 /// This class cannot be constructed directly. Please use one of the concrete models.
 #[pyclass(module = "tokenizers.models", name = "Model", subclass)]
-#[derive(Clone, Serialize, Deserialize, AutoDisplay)]
+#[derive(Clone, Serialize, Deserialize, Str)]
 pub struct PyModel {
     #[serde(flatten)]
+    #[format(fmt="{}")]
     pub model: Arc<RwLock<ModelWrapper>>,
 }
@@ -220,12 +222,6 @@ impl PyModel {
     fn get_trainer(&self, py: Python<'_>) -> PyResult<PyObject> {
         PyTrainer::from(self.model.read().unwrap().get_trainer()).get_as_subtype(py)
     }
-    fn __str__(&self) -> PyResult<String> {
-        Ok(format!("{}", self.model.read().unwrap()))
-    }
-    fn __repr__(&self) -> PyResult<String> {
-        Ok(format!("{}", self.model.read().unwrap()))
-    }
 }

 /// An implementation of the BPE (Byte-Pair Encoding) algorithm
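Note that `PyModel` formats a field behind `Arc<RwLock<…>>`, so the removed methods had to take a read lock first; `#[format(fmt="{}")]` presumably has to do the same. A plain-Rust sketch of that locking pattern (the `String` payload is a stand-in for `ModelWrapper`):

```rust
use std::fmt;
use std::sync::{Arc, RwLock};

struct PyModelLike {
    model: Arc<RwLock<String>>, // String stands in for ModelWrapper
}

impl fmt::Display for PyModelLike {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        // Take a read lock before formatting, exactly like the removed
        // __str__/__repr__ did with self.model.read().unwrap().
        write!(f, "{}", self.model.read().unwrap())
    }
}

fn main() {
    let m = PyModelLike { model: Arc::new(RwLock::new("BPE(...)".to_string())) };
    println!("{m}"); // prints: BPE(...)
}
```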
4 changes: 3 additions & 1 deletion bindings/python/src/normalizers.rs
@@ -478,7 +478,9 @@ impl PyNmt {
 /// Don't use manually it is used for compatiblity for SentencePiece.
 #[pyclass(extends=PyNormalizer, module = "tokenizers.normalizers", name = "Precompiled")]
 #[derive(Str)]
+#[format("PreCompiled")]
 pub struct PyPrecompiled {}
+
 #[pymethods]
 impl PyPrecompiled {
     #[new]
@@ -516,7 +518,7 @@ impl PyReplace {

 #[derive(Debug, Clone, AutoDisplay)]
 pub(crate) struct CustomNormalizer {
-    #[auto_display]
+    #[format(fmt="Custom Normalizer")]
     inner: PyObject,
 }
 impl CustomNormalizer {
8 changes: 6 additions & 2 deletions bindings/python/src/pre_tokenizers.rs
@@ -37,7 +37,8 @@ use pyo3_special_method_derive_0_21::{AutoDisplay, Dict, Dir, Repr, Str};
 #[derive(Clone, Serialize, Deserialize, Str, Dir)]
 pub struct PyPreTokenizer {
     #[serde(flatten)]
-    pub pretok: PyPreTokenizerTypeWrapper,
+    #[format(fmt = "{}")]
+    pretok: PyPreTokenizerTypeWrapper,
 }

 impl PyPreTokenizer {
@@ -426,7 +427,7 @@ impl PyPunctuation {
 /// This pre-tokenizer composes other pre_tokenizers and applies them in sequence
 #[pyclass(extends=PyPreTokenizer, module = "tokenizers.pre_tokenizers", name = "Sequence")]
 #[derive(AutoDisplay)]
-#[auto_display(fmt = "{self.inner}{}")]
+#[format(fmt = "Sequence.{}")]
 pub struct PySequence {}
 #[pymethods]
 impl PySequence {
@@ -637,6 +638,7 @@ impl<'de> Deserialize<'de> for CustomPreTokenizer {
 #[serde(untagged)]
 pub(crate) enum PyPreTokenizerWrapper {
     Custom(CustomPreTokenizer),
+    #[format(fmt = "wrapped:{}")]
     Wrapped(PreTokenizerWrapper),
 }

@@ -655,7 +657,9 @@ impl Serialize for PyPreTokenizerWrapper {
 #[derive(Clone, Deserialize, AutoDisplay)]
 #[serde(untagged)]
 pub(crate) enum PyPreTokenizerTypeWrapper {
+    #[format(fmt = "{}")]
     Sequence(Vec<Arc<RwLock<PyPreTokenizerWrapper>>>),
+    #[format(fmt = "{}")]
     Single(Arc<RwLock<PyPreTokenizerWrapper>>),
 }
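The per-variant `#[format(fmt = "wrapped:{}")]` attributes ask the derive for variant-specific prefixes. A hedged sketch of the equivalent hand-written `Display` for such an enum (names are illustrative, not the crate's generated output):

```rust
use std::fmt;

enum PreTokWrapperLike {
    Custom(String),
    Wrapped(String),
}

impl fmt::Display for PreTokWrapperLike {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        match self {
            // The Custom variant keeps the payload's own formatting.
            PreTokWrapperLike::Custom(inner) => write!(f, "{inner}"),
            // The Wrapped variant gets the prefix requested by
            // #[format(fmt = "wrapped:{}")].
            PreTokWrapperLike::Wrapped(inner) => write!(f, "wrapped:{inner}"),
        }
    }
}

fn main() {
    let p = PreTokWrapperLike::Wrapped("Whitespace".to_string());
    assert_eq!(p.to_string(), "wrapped:Whitespace");
}
```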
1 change: 1 addition & 0 deletions bindings/python/src/processors.rs
@@ -28,6 +28,7 @@ use tokenizers as tk;
     subclass
 )]
 #[derive(Clone, Deserialize, Serialize, AutoDisplay)]
+#[format(fmt="post processor: {}.{}")]
 pub struct PyPostProcessor {
     #[serde(flatten)]
     pub processor: Arc<PostProcessorWrapper>,
1 change: 1 addition & 0 deletions bindings/python/src/tokenizer.rs
@@ -464,6 +464,7 @@ type Tokenizer = TokenizerImpl<PyModel, PyNormalizer, PyPreTokenizer, PyPostProc
 #[pyclass(dict, module = "tokenizers", name = "Tokenizer")]
 #[derive(Clone, Str, Dict, Dir)]
 pub struct PyTokenizer {
+    #[format(fmt="{}")]
     tokenizer: Tokenizer,
 }
1 change: 0 additions & 1 deletion bindings/python/src/utils/normalization.rs
@@ -6,7 +6,6 @@ use pyo3::prelude::*;
 use pyo3::types::*;
 use tk::normalizer::{char_to_bytes, NormalizedString, Range, SplitDelimiterBehavior};
 use tk::pattern::Pattern;
-
 /// Represents a Pattern as used by `NormalizedString`
 #[derive(Clone, FromPyObject)]
 pub enum PyPattern {
2 changes: 1 addition & 1 deletion tokenizers/Cargo.toml
@@ -63,7 +63,7 @@ fancy-regex = { version = "0.13", optional = true}
 getrandom = { version = "0.2.10" }
 esaxx-rs = { version = "0.1.10", default-features = false, features=[]}
 monostate = "0.1.12"
-pyo3_special_method_derive_0_21 = "0.4"
+pyo3_special_method_derive_0_21 = {path = "../../pyo3-special-method-derive/pyo3_special_method_derive_0_21"}

 [features]
 default = ["progressbar", "onig", "esaxx_fast"]
2 changes: 1 addition & 1 deletion tokenizers/src/decoders/byte_fallback.rs
@@ -7,7 +7,7 @@ use serde::{Deserialize, Serialize};
 /// to pure bytes, and attempts to make them into a string. If the tokens
 /// cannot be decoded you will get � instead for each inconvertable byte token
 #[non_exhaustive]
-#[auto_display(fmt = "ByteFallback")]
+#[format(fmt = "ByteFallback")]
 pub struct ByteFallback {
     #[serde(rename = "type")]
     type_: MustBe!("ByteFallback"),
2 changes: 1 addition & 1 deletion tokenizers/src/decoders/fuse.rs
@@ -8,7 +8,7 @@ use serde::{Deserialize, Serialize};
 /// decoder exists incase some decoders need to happen after that
 /// step
 #[non_exhaustive]
-#[auto_display(fmt = "Fuse")]
+#[format(fmt = "Fuse")]
 pub struct Fuse {
     #[serde(rename = "type")]
     type_: MustBe!("Fuse"),
2 changes: 1 addition & 1 deletion tokenizers/src/decoders/mod.rs
@@ -25,7 +25,7 @@ use pyo3_special_method_derive_0_21::AutoDisplay;
 use serde::{Deserialize, Serialize};

 #[derive(Serialize, Deserialize, Clone, Debug, AutoDisplay)]
-#[auto_display(fmt="decoders.{}")]
+#[format(fmt="decoders.{}")]
 #[serde(untagged)]
 pub enum DecoderWrapper {
     BPE(BPEDecoder),
102 changes: 52 additions & 50 deletions tokenizers/src/models/bpe/model.rs
@@ -2,6 +2,7 @@ use super::{super::OrderedVocabIter, trainer::BpeTrainer, Error, Pair, Word};
 use crate::tokenizer::{Model, Result, Token};
 use crate::utils::cache::{Cache, DEFAULT_CACHE_CAPACITY};
 use crate::utils::iter::ResultShunt;
+use pyo3_special_method_derive_0_21::AutoDisplay;
 use serde_json::Value;
 use std::borrow::Cow;
 use std::{
@@ -204,9 +205,10 @@ impl BpeBuilder {
 }

 /// A [Byte Pair Encoding](https://www.aclweb.org/anthology/P16-1162/) model.
-#[derive(PartialEq)]
+#[derive(PartialEq, AutoDisplay)]
 pub struct BPE {
     /// The vocabulary assigns a number to each token.
+    #[format(skip)]
     pub(crate) vocab: Vocab,
     /// Reversed vocabulary, to rebuild sentences.
     pub(crate) vocab_r: VocabR,
@@ -248,55 +250,55 @@ impl std::fmt::Debug for BPE {
     }
 }

-impl std::fmt::Display for BPE {
-    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
-        let mut vocab_vec: Vec<_> = self.vocab.iter().collect();
-        vocab_vec.sort_by_key(|&(_, v)| v);
-        vocab_vec.truncate(5);
-
-        let vocab_str: String = vocab_vec
-            .iter()
-            .map(|(k, v)| format!("'{}':{}", k, v))
-            .collect::<Vec<_>>()
-            .join(", ");
-
-        let mut merges_vec: Vec<_> = self.merges.iter().collect();
-        merges_vec.truncate(5);
-        merges_vec.sort_by_key(|&(_, v)| v);
-
-        let merges_str: String = merges_vec
-            .iter()
-            .map(|((id1, id2), _)| {
-                (
-                    self.vocab_r
-                        .get(id1)
-                        .cloned()
-                        .unwrap_or_else(|| id1.to_string()),
-                    self.vocab_r
-                        .get(id2)
-                        .cloned()
-                        .unwrap_or_else(|| id2.to_string()),
-                )
-            })
-            .map(|(id1, id2)| format!("('{}', '{}')", id1, id2))
-            .collect::<Vec<_>>()
-            .join(", ");
-
-        write!(
-            f,
-            "BPE(vocab={{{}, ...}}, merges=[{:?}, ...], dropout={:?}, unk_token={:?}, continuing_subword_prefix={:?}, end_of_word_suffix={:?}, fuse_unk={}, byte_fallback={}, ignore_merges={})",
-            vocab_str,
-            merges_str,
-            self.dropout,
-            self.unk_token,
-            self.continuing_subword_prefix,
-            self.end_of_word_suffix,
-            self.fuse_unk,
-            self.byte_fallback,
-            self.ignore_merges
-        )
-    }
-}
+// impl std::fmt::Display for BPE {
+//     fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
+//         let mut vocab_vec: Vec<_> = self.vocab.iter().collect();
+//         vocab_vec.sort_by_key(|&(_, v)| v);
+//         vocab_vec.truncate(5);
+
+//         let vocab_str: String = vocab_vec
+//             .iter()
+//             .map(|(k, v)| format!("'{}':{}", k, v))
+//             .collect::<Vec<_>>()
+//             .join(", ");
+
+//         let mut merges_vec: Vec<_> = self.merges.iter().collect();
+//         merges_vec.truncate(5);
+//         merges_vec.sort_by_key(|&(_, v)| v);
+
+//         let merges_str: String = merges_vec
+//             .iter()
+//             .map(|((id1, id2), _)| {
+//                 (
+//                     self.vocab_r
+//                         .get(id1)
+//                         .cloned()
+//                         .unwrap_or_else(|| id1.to_string()),
+//                     self.vocab_r
+//                         .get(id2)
+//                         .cloned()
+//                         .unwrap_or_else(|| id2.to_string()),
+//                 )
+//             })
+//             .map(|(id1, id2)| format!("('{}', '{}')", id1, id2))
+//             .collect::<Vec<_>>()
+//             .join(", ");
+
+//         write!(
+//             f,
+//             "BPE(vocab={{{}, ...}}, merges=[{:?}, ...], dropout={:?}, unk_token={:?}, continuing_subword_prefix={:?}, end_of_word_suffix={:?}, fuse_unk={}, byte_fallback={}, ignore_merges={})",
+//             vocab_str,
+//             merges_str,
+//             self.dropout,
+//             self.unk_token,
+//             self.continuing_subword_prefix,
+//             self.end_of_word_suffix,
+//             self.fuse_unk,
+//             self.byte_fallback,
+//             self.ignore_merges
+//         )
+//     }
+// }

 impl Default for BPE {
     fn default() -> Self {
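The commented-out `Display` printed only a five-entry preview of the vocabulary and merges rather than dumping the full maps; whatever `AutoDisplay` with `#[format(skip)]` ends up producing will need the same restraint. That preview logic, extracted as a standalone sketch:

```rust
use std::collections::HashMap;

// Keep only the five lowest-id entries so a 50k-token vocab prints compactly,
// mirroring the truncate(5) in the commented-out impl above.
fn vocab_preview(vocab: &HashMap<String, u32>) -> String {
    let mut entries: Vec<_> = vocab.iter().collect();
    entries.sort_by_key(|&(_, id)| *id);
    entries.truncate(5);
    entries
        .iter()
        .map(|(token, id)| format!("'{}':{}", token, id))
        .collect::<Vec<_>>()
        .join(", ")
}

fn main() {
    let vocab = HashMap::from([("a".to_string(), 0), ("b".to_string(), 1)]);
    println!("BPE(vocab={{{}, ...}})", vocab_preview(&vocab));
}
```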
2 changes: 1 addition & 1 deletion tokenizers/src/models/mod.rs
@@ -8,7 +8,7 @@ pub mod wordpiece;
 use std::collections::HashMap;
 use std::path::{Path, PathBuf};

-use pyo3_special_method_derive_0_21::AutoDisplay;
+use pyo3_special_method_derive_0_21::{AutoDisplay, PyDisplay};
 use serde::{Deserialize, Serialize, Serializer};

 use crate::models::bpe::{BpeTrainer, BPE};
2 changes: 1 addition & 1 deletion tokenizers/src/models/wordlevel/mod.rs
@@ -96,7 +96,7 @@ impl WordLevelBuilder {

 #[derive(PartialEq, Clone, Eq, AutoDisplay)]
 pub struct WordLevel {
-    #[auto_display]
+    #[format]
     vocab: HashMap<String, u32>,
     vocab_r: HashMap<u32, String>,
     pub unk_token: String,
4 changes: 2 additions & 2 deletions tokenizers/src/models/wordpiece/mod.rs
@@ -121,12 +121,12 @@ impl WordPieceBuilder {
 /// model.
 #[derive(Clone, PartialEq, Eq, AutoDisplay)]
 pub struct WordPiece {
-    #[auto_display]
+    #[format]
     vocab: Vocab,
     vocab_r: VocabR,
     pub unk_token: String,
     pub continuing_subword_prefix: String,
-    #[auto_display(skip)]
+    #[format(skip)]
     pub max_input_chars_per_word: usize,
 }
4 changes: 2 additions & 2 deletions tokenizers/src/normalizers/byte_level.rs
@@ -2,8 +2,8 @@ use crate::processors::byte_level::bytes_char;
 use crate::tokenizer::{NormalizedString, Normalizer, Result};
 use serde::{Deserialize, Serialize};
 use std::collections::{HashMap, HashSet};
-
-#[derive(Clone, Debug, Deserialize, Serialize)]
+use pyo3_special_method_derive_0_21::AutoDisplay;
+#[derive(Clone, Debug, Deserialize, Serialize, AutoDisplay)]
 #[serde(tag = "type")]
 pub struct ByteLevel {}
(The remaining changed files in this commit were not loaded on the page.)

0 comments on commit c7cd927