-
Notifications
You must be signed in to change notification settings - Fork 819
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Add display capabilities to tokenizers objects #1542
Changes from all commits
61804d9
a56da5f
f1a6a97
4a49530
f4af616
88630dc
b9d44da
a90ec22
2224275
4d9204e
4c2aca1
904ce70
6413810
fda66f5
20c9fc4
86c77b6
a429642
3cec010
8d77286
e48cd3a
27576e5
0d9a452
35373de
1c6d272
df51116
4b4b833
59a89c9
ac9b849
a3f7439
cf5b6f3
5d33243
b73c43d
3e16df7
b214d77
0654831
a15e3cc
6023192
ebf1258
477a9b5
93a1e63
7591f2b
f50e4e0
4f15052
85c7b69
0a16ca0
35d442d
a3cc764
5b20fa7
15f877e
9c45e8f
4a34870
e0d35e0
fe95add
2770099
3d0eb0a
11a3601
f6fa136
2a54482
998b2a3
4df6cc2
a9c6c61
aefdc91
f67af9c
4c3f37a
292475f
99cb054
5c930e9
f87bb97
c4b4f3c
e712079
5540136
ba03c16
d0e741b
19afb66
9559dea
e53f4ca
18238dd
269ff21
e799602
93ad593
3aa0138
011340b
3fc31d0
51d3f61
acb8196
951b6e6
c2a320c
2048c02
104fe0c
7db6109
c7cd927
e4cf65a
39ffc28
0a3bb18
c436b23
e5b059f
ff825a7
c30df0c
b78e11c
a99c645
9022470
64b8df0
27cad45
ceabef3
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -5,6 +5,7 @@ use crate::utils::PyPattern; | |
use pyo3::exceptions; | ||
use pyo3::prelude::*; | ||
use pyo3::types::*; | ||
use pyo3_special_method_derive_0_21::{AutoDebug, AutoDisplay, Repr, Str}; | ||
use serde::de::Error; | ||
use serde::{Deserialize, Deserializer, Serialize, Serializer}; | ||
use tk::decoders::bpe::BPEDecoder; | ||
|
@@ -28,9 +29,11 @@ use super::error::ToPyResult; | |
/// This class is not supposed to be instantiated directly. Instead, any implementation of | ||
/// a Decoder will return an instance of this class when instantiated. | ||
#[pyclass(dict, module = "tokenizers.decoders", name = "Decoder", subclass)] | ||
#[derive(Clone, Deserialize, Serialize)] | ||
#[derive(Clone, Deserialize, Serialize, Str, Repr)] | ||
#[format(fmt = "{}")] | ||
pub struct PyDecoder { | ||
#[serde(flatten)] | ||
#[format] | ||
pub(crate) decoder: PyDecoderWrapper, | ||
} | ||
|
||
|
@@ -478,9 +481,10 @@ impl PySequenceDecoder { | |
} | ||
} | ||
|
||
#[derive(Clone)] | ||
#[derive(Clone, AutoDisplay, AutoDebug)] | ||
pub(crate) struct CustomDecoder { | ||
inner: PyObject, | ||
#[format(skip)] | ||
pub inner: PyObject, | ||
Comment on lines
+486
to
+487
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Not implemented yet so skipping for now |
||
} | ||
|
||
impl CustomDecoder { | ||
|
@@ -531,8 +535,9 @@ impl<'de> Deserialize<'de> for CustomDecoder { | |
} | ||
} | ||
|
||
#[derive(Clone, Deserialize, Serialize)] | ||
#[derive(Clone, Deserialize, Serialize, AutoDisplay, AutoDebug)] | ||
#[serde(untagged)] | ||
#[format(fmt = "{}")] | ||
pub(crate) enum PyDecoderWrapper { | ||
Custom(Arc<RwLock<CustomDecoder>>), | ||
Wrapped(Arc<RwLock<DecoderWrapper>>), | ||
Comment on lines
+540
to
543
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. this will directly display |
||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,12 +1,23 @@ | ||
use std::collections::{hash_map::DefaultHasher, HashMap}; | ||
use std::hash::{Hash, Hasher}; | ||
|
||
use super::decoders::PyDecoder; | ||
use super::encoding::PyEncoding; | ||
use super::error::{PyError, ToPyResult}; | ||
use super::models::PyModel; | ||
use super::normalizers::PyNormalizer; | ||
use super::pre_tokenizers::PyPreTokenizer; | ||
use super::trainers::PyTrainer; | ||
use crate::processors::PyPostProcessor; | ||
use crate::utils::{MaybeSizedIterator, PyBufferedIterator}; | ||
use numpy::{npyffi, PyArray1}; | ||
use pyo3::class::basic::CompareOp; | ||
use pyo3::exceptions; | ||
use pyo3::intern; | ||
use pyo3::prelude::*; | ||
use pyo3::types::*; | ||
use pyo3_special_method_derive_0_21::{Repr, Str}; | ||
use std::collections::BTreeMap; | ||
use tk::models::bpe::BPE; | ||
use tk::tokenizer::{ | ||
Model, PaddingDirection, PaddingParams, PaddingStrategy, PostProcessor, TokenizerImpl, | ||
|
@@ -15,17 +26,6 @@ use tk::tokenizer::{ | |
use tk::utils::iter::ResultShunt; | ||
use tokenizers as tk; | ||
|
||
use super::decoders::PyDecoder; | ||
use super::encoding::PyEncoding; | ||
use super::error::{PyError, ToPyResult}; | ||
use super::models::PyModel; | ||
use super::normalizers::PyNormalizer; | ||
use super::pre_tokenizers::PyPreTokenizer; | ||
use super::trainers::PyTrainer; | ||
use crate::processors::PyPostProcessor; | ||
use crate::utils::{MaybeSizedIterator, PyBufferedIterator}; | ||
use std::collections::BTreeMap; | ||
|
||
/// Represents a token that can be be added to a :class:`~tokenizers.Tokenizer`. | ||
/// It can have special options that defines the way it should behave. | ||
/// | ||
|
@@ -462,9 +462,10 @@ type Tokenizer = TokenizerImpl<PyModel, PyNormalizer, PyPreTokenizer, PyPostProc | |
/// The core algorithm that this :obj:`Tokenizer` should be using. | ||
/// | ||
#[pyclass(dict, module = "tokenizers", name = "Tokenizer")] | ||
#[derive(Clone)] | ||
#[derive(Clone, Str, Repr)] | ||
#[format(fmt = "{}")] | ||
pub struct PyTokenizer { | ||
tokenizer: Tokenizer, | ||
pub tokenizer: Tokenizer, | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Not a requirement |
||
} | ||
|
||
impl PyTokenizer { | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -63,6 +63,7 @@ fancy-regex = { version = "0.13", optional = true} | |
getrandom = { version = "0.2.10" } | ||
esaxx-rs = { version = "0.1.10", default-features = false, features=[]} | ||
monostate = "0.1.12" | ||
pyo3_special_method_derive_0_21 = {path = "../../pyo3-special-method-derive/pyo3_special_method_derive_0_21"} | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Do not forget to remove |
||
|
||
[features] | ||
default = ["progressbar", "onig", "esaxx_fast"] | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
visibility here forces us to add format