PyO3 0.21.
Narsil committed Apr 15, 2024
1 parent 914576f commit e7e7bad
Showing 11 changed files with 56 additions and 79 deletions.
14 changes: 7 additions & 7 deletions bindings/python/Cargo.toml
@@ -9,24 +9,24 @@ name = "tokenizers"
crate-type = ["cdylib"]

[dependencies]
rayon = "1.8"
rayon = "1.10"
serde = { version = "1.0", features = [ "rc", "derive" ]}
serde_json = "1.0"
libc = "0.2"
env_logger = "0.10.0"
pyo3 = { version = "0.20" }
numpy = "0.20.0"
env_logger = "0.11"
pyo3 = { version = "0.21" }
numpy = "0.21"
ndarray = "0.15"
onig = { version = "6.4", default-features = false }
itertools = "0.11"
itertools = "0.12"

[dependencies.tokenizers]
version = "0.16.0-dev.0"
path = "../../tokenizers"

[dev-dependencies]
tempfile = "3.8"
pyo3 = { version = "0.20", features = ["auto-initialize"] }
tempfile = "3.10"
pyo3 = { version = "0.21", features = ["auto-initialize"] }

[features]
defaut = ["pyo3/extension-module"]
13 changes: 6 additions & 7 deletions bindings/python/src/decoders.rs
@@ -1,7 +1,6 @@
use std::sync::{Arc, RwLock};

use crate::pre_tokenizers::from_string;
-use crate::utils::PyChar;
use crate::utils::PyPattern;
use pyo3::exceptions;
use pyo3::prelude::*;
@@ -318,8 +317,8 @@ impl PyMetaspaceDec {
}

#[setter]
-fn set_replacement(self_: PyRef<Self>, replacement: PyChar) {
-setter!(self_, Metaspace, @set_replacement, replacement.0);
+fn set_replacement(self_: PyRef<Self>, replacement: char) {
+setter!(self_, Metaspace, @set_replacement, replacement);
}

#[getter]
@@ -352,16 +351,16 @@ impl PyMetaspaceDec {
}

#[new]
-#[pyo3(signature = (replacement = PyChar('▁'), prepend_scheme = String::from("always"), split = true), text_signature = "(self, replacement = \"\", prepend_scheme = \"always\", split = True)")]
+#[pyo3(signature = (replacement = '▁', prepend_scheme = String::from("always"), split = true), text_signature = "(self, replacement = \"\", prepend_scheme = \"always\", split = True)")]
fn new(
-replacement: PyChar,
+replacement: char,
prepend_scheme: String,
split: bool,
) -> PyResult<(Self, PyDecoder)> {
let prepend_scheme = from_string(prepend_scheme)?;
Ok((
PyMetaspaceDec {},
-Metaspace::new(replacement.0, prepend_scheme, split).into(),
+Metaspace::new(replacement, prepend_scheme, split).into(),
))
}
}
@@ -602,7 +601,7 @@ mod test {
Python::with_gil(|py| {
let py_dec = PyDecoder::new(Metaspace::default().into());
let py_meta = py_dec.get_as_subtype(py).unwrap();
assert_eq!("Metaspace", py_meta.as_ref(py).get_type().name().unwrap());
assert_eq!("Metaspace", py_meta.as_ref(py).get_type().qualname().unwrap());
})
}

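The recurring `PyChar` → `char` substitution in this file works because PyO3 extracts a one-character Python `str` directly into a Rust `char`. A hypothetical standalone sketch of the pattern:

```rust
use pyo3::prelude::*;

/// A single-character Python str such as "▁" arrives as a plain Rust
/// char; longer strings fail extraction with "expected a string of
/// length 1", the same check the removed PyChar shim did by hand.
#[pyfunction]
fn replacement_to_string(replacement: char) -> String {
    replacement.to_string()
}
```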
10 changes: 5 additions & 5 deletions bindings/python/src/models.rs
@@ -321,14 +321,14 @@ macro_rules! setter {
}

#[derive(FromPyObject)]
-enum PyVocab<'a> {
+enum PyVocab {
Vocab(Vocab),
-Filename(&'a str),
+Filename(String),
}
#[derive(FromPyObject)]
-enum PyMerges<'a> {
+enum PyMerges {
Merges(Merges),
-Filename(&'a str),
+Filename(String),
}

#[pymethods]
@@ -870,7 +870,7 @@ mod test {
Python::with_gil(|py| {
let py_model = PyModel::from(BPE::default());
let py_bpe = py_model.get_as_subtype(py).unwrap();
assert_eq!("BPE", py_bpe.as_ref(py).get_type().name().unwrap());
assert_eq!("BPE", py_bpe.as_ref(py).get_type().qualname().unwrap());
})
}

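Dropping the `'a` lifetime here follows PyO3 0.21's push toward owned extraction: `&'a str` borrows out of the Python object and ties the value to the GIL lifetime, while `String` copies the data out. A hypothetical reduced version of the derive pattern (names are illustrative):

```rust
use std::collections::HashMap;

use pyo3::prelude::*;

// Variants are tried top to bottom; the first successful extraction wins,
// so a Python dict becomes Mapping and a str falls through to Filename.
#[derive(FromPyObject)]
enum VocabArg {
    Mapping(HashMap<String, u32>),
    Filename(String),
}
```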
6 changes: 3 additions & 3 deletions bindings/python/src/normalizers.rs
@@ -468,10 +468,10 @@ impl PyPrecompiled {
#[new]
#[pyo3(text_signature = "(self, precompiled_charsmap)")]
fn new(py_precompiled_charsmap: &PyBytes) -> PyResult<(Self, PyNormalizer)> {
-let precompiled_charsmap: &[u8] = FromPyObject::extract(py_precompiled_charsmap)?;
+let precompiled_charsmap: Vec<u8> = FromPyObject::extract(py_precompiled_charsmap)?;
Ok((
PyPrecompiled {},
-Precompiled::from(precompiled_charsmap)
+Precompiled::from(&precompiled_charsmap)
.map_err(|e| {
exceptions::PyException::new_err(format!(
"Error while attempting to build Precompiled normalizer: {}",
@@ -667,7 +667,7 @@ mod test {
Python::with_gil(|py| {
let py_norm = PyNormalizer::new(NFC.into());
let py_nfc = py_norm.get_as_subtype(py).unwrap();
assert_eq!("NFC", py_nfc.as_ref(py).get_type().name().unwrap());
assert_eq!("NFC", py_nfc.as_ref(py).get_type().qualname().unwrap());
})
}

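Same migration here: a `&[u8]` borrowed from the `PyBytes` argument is GIL-bound, so the code now extracts an owned `Vec<u8>` and hands `Precompiled::from` a slice of it. A minimal sketch of the conversion, written with 0.21's Bound API rather than this file's gil-ref style:

```rust
use pyo3::prelude::*;
use pyo3::types::PyBytes;

fn copy_charsmap(py: Python<'_>) -> PyResult<Vec<u8>> {
    let bytes = PyBytes::new_bound(py, b"\x00\x01\x02");
    // Owned Vec<u8>: the data is copied out, nothing borrows from `bytes`.
    bytes.extract()
}
```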
20 changes: 10 additions & 10 deletions bindings/python/src/pre_tokenizers.rs
@@ -372,16 +372,16 @@ impl PyCharDelimiterSplit {
}

#[setter]
-fn set_delimiter(self_: PyRef<Self>, delimiter: PyChar) {
-setter!(self_, Delimiter, delimiter, delimiter.0);
+fn set_delimiter(self_: PyRef<Self>, delimiter: char) {
+setter!(self_, Delimiter, delimiter, delimiter);
}

#[new]
#[pyo3(text_signature = None)]
-pub fn new(delimiter: PyChar) -> PyResult<(Self, PyPreTokenizer)> {
+pub fn new(delimiter: char) -> PyResult<(Self, PyPreTokenizer)> {
Ok((
PyCharDelimiterSplit {},
-CharDelimiterSplit::new(delimiter.0).into(),
+CharDelimiterSplit::new(delimiter).into(),
))
}

@@ -490,8 +490,8 @@ impl PyMetaspace {
}

#[setter]
-fn set_replacement(self_: PyRef<Self>, replacement: PyChar) {
-setter!(self_, Metaspace, @set_replacement, replacement.0);
+fn set_replacement(self_: PyRef<Self>, replacement: char) {
+setter!(self_, Metaspace, @set_replacement, replacement);
}

#[getter]
@@ -524,15 +524,15 @@ impl PyMetaspace {
}

#[new]
-#[pyo3(signature = (replacement = PyChar('▁'), prepend_scheme=String::from("always"), split=true), text_signature = "(self, replacement=\"_\", prepend_scheme=\"always\", split=True)")]
+#[pyo3(signature = (replacement = '▁', prepend_scheme=String::from("always"), split=true), text_signature = "(self, replacement=\"_\", prepend_scheme=\"always\", split=True)")]
fn new(
-replacement: PyChar,
+replacement: char,
prepend_scheme: String,
split: bool,
) -> PyResult<(Self, PyPreTokenizer)> {
// Create a new Metaspace instance
let prepend_scheme = from_string(prepend_scheme)?;
-let new_instance: Metaspace = Metaspace::new(replacement.0, prepend_scheme, split);
+let new_instance: Metaspace = Metaspace::new(replacement, prepend_scheme, split);
Ok((PyMetaspace {}, new_instance.into()))
}
}
@@ -754,7 +754,7 @@ mod test {
Python::with_gil(|py| {
let py_norm = PyPreTokenizer::new(Whitespace {}.into());
let py_wsp = py_norm.get_as_subtype(py).unwrap();
assert_eq!("Whitespace", py_wsp.as_ref(py).get_type().name().unwrap());
assert_eq!("Whitespace", py_wsp.as_ref(py).get_type().qualname().unwrap());
})
}

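One detail worth noting in the `PyMetaspace` constructor above: `signature` sets the real Rust-side defaults (now a plain `'▁'` char literal), while `text_signature` only controls what Python introspection displays. A hypothetical sketch separating the two:

```rust
use pyo3::prelude::*;

// `signature` drives argument parsing and defaults; `text_signature`
// is what help() and __text_signature__ report to Python users.
#[pyfunction]
#[pyo3(
    signature = (replacement = '▁', split = true),
    text_signature = "(replacement='▁', split=True)"
)]
fn metaspace_like(replacement: char, split: bool) -> String {
    format!("replacement={replacement}, split={split}")
}
```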
4 changes: 2 additions & 2 deletions bindings/python/src/processors.rs
@@ -304,7 +304,7 @@ impl FromPyObject<'_> for PyTemplate {
Ok(Self(
s.try_into().map_err(exceptions::PyValueError::new_err)?,
))
-} else if let Ok(s) = ob.extract::<Vec<&str>>() {
+} else if let Ok(s) = ob.extract::<Vec<String>>() {
Ok(Self(
s.try_into().map_err(exceptions::PyValueError::new_err)?,
))
@@ -474,7 +474,7 @@ mod test {
let py_bert = py_proc.get_as_subtype(py).unwrap();
assert_eq!(
"BertProcessing",
-py_bert.as_ref(py).get_type().name().unwrap()
+py_bert.as_ref(py).get_type().qualname().unwrap()
);
})
}
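The `name()` → `qualname()` swap in these tests tracks a PyO3 0.21 behavior change: `PyType::name()` can now include qualification, while `qualname()` returns the bare class name the assertions expect. A sketch against a built-in type (Bound API, hypothetical):

```rust
use pyo3::prelude::*;
use pyo3::types::PyDict;

fn check_dict_name(py: Python<'_>) -> PyResult<()> {
    let ty = py.get_type_bound::<PyDict>();
    // qualname() is just "dict", with no module prefix.
    assert_eq!(ty.qualname()?, "dict");
    Ok(())
}
```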
4 changes: 2 additions & 2 deletions bindings/python/src/tokenizer.rs
@@ -377,12 +377,12 @@ impl<'s> FromPyObject<'s> for PreTokenizedInputSequence<'s> {
return Ok(Self(seq.into()));
}
if let Ok(s) = ob.downcast::<PyList>() {
-if let Ok(seq) = s.extract::<Vec<&str>>() {
+if let Ok(seq) = s.extract::<Vec<String>>() {
return Ok(Self(seq.into()));
}
}
if let Ok(s) = ob.downcast::<PyTuple>() {
-if let Ok(seq) = s.extract::<Vec<&str>>() {
+if let Ok(seq) = s.extract::<Vec<String>>() {
return Ok(Self(seq.into()));
}
}
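A sketch of the owned list extraction used above, in 0.21's Bound API (the file itself still uses gil-refs, which 0.21 keeps available but deprecates):

```rust
use pyo3::prelude::*;
use pyo3::types::PyList;

fn as_string_vec(ob: &Bound<'_, PyAny>) -> Option<Vec<String>> {
    // Each element is copied into an owned String; a list holding
    // any non-str item fails the extract and yields None.
    let list = ob.downcast::<PyList>().ok()?;
    list.extract::<Vec<String>>().ok()
}
```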
16 changes: 8 additions & 8 deletions bindings/python/src/trainers.rs
@@ -2,13 +2,13 @@ use std::sync::{Arc, RwLock};

use crate::models::PyModel;
use crate::tokenizer::PyAddedToken;
-use crate::utils::PyChar;
use pyo3::exceptions;
use pyo3::prelude::*;
use pyo3::types::*;
use serde::{Deserialize, Serialize};
use tk::models::TrainerWrapper;
use tk::Trainer;
+use std::collections::HashSet;
use tokenizers as tk;

/// Base class for all trainers
@@ -269,12 +269,12 @@ impl PyBpeTrainer {
}

#[setter]
-fn set_initial_alphabet(self_: PyRef<Self>, alphabet: Vec<PyChar>) {
+fn set_initial_alphabet(self_: PyRef<Self>, alphabet: HashSet<char>) {
setter!(
self_,
BpeTrainer,
initial_alphabet,
-alphabet.into_iter().map(|c| c.0).collect()
+alphabet
);
}

@@ -473,12 +473,12 @@ impl PyWordPieceTrainer {
}

#[setter]
-fn set_initial_alphabet(self_: PyRef<Self>, alphabet: Vec<PyChar>) {
+fn set_initial_alphabet(self_: PyRef<Self>, alphabet: HashSet<char>) {
setter!(
self_,
WordPieceTrainer,
@set_initial_alphabet,
-alphabet.into_iter().map(|c| c.0).collect()
+alphabet
);
}

@@ -801,12 +801,12 @@ impl PyUnigramTrainer {
}

#[setter]
-fn set_initial_alphabet(self_: PyRef<Self>, alphabet: Vec<PyChar>) {
+fn set_initial_alphabet(self_: PyRef<Self>, alphabet: HashSet<char>) {
setter!(
self_,
UnigramTrainer,
initial_alphabet,
-alphabet.into_iter().map(|c| c.0).collect()
+alphabet
);
}

@@ -893,7 +893,7 @@ mod tests {
Python::with_gil(|py| {
let py_trainer = PyTrainer::new(Arc::new(RwLock::new(BpeTrainer::default().into())));
let py_bpe = py_trainer.get_as_subtype(py).unwrap();
assert_eq!("BpeTrainer", py_bpe.as_ref(py).get_type().name().unwrap());
assert_eq!("BpeTrainer", py_bpe.as_ref(py).get_type().qualname().unwrap());
})
}
}
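With `char` extraction native, a Python set of one-character strings converts straight into the `HashSet<char>` the underlying trainers store, removing the `Vec<PyChar>` detour and the `map(|c| c.0)` collect. A hypothetical standalone version:

```rust
use std::collections::HashSet;

use pyo3::prelude::*;

/// Accepts e.g. {"a", "b", "c"} from Python; any element longer than
/// one character fails extraction.
#[pyfunction]
fn alphabet_size(initial_alphabet: HashSet<char>) -> usize {
    initial_alphabet.len()
}
```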
22 changes: 0 additions & 22 deletions bindings/python/src/utils/mod.rs
@@ -1,6 +1,3 @@
-use pyo3::exceptions;
-use pyo3::prelude::*;
-use pyo3::types::*;
use std::marker::PhantomData;
use std::sync::{Arc, Mutex};

@@ -14,25 +11,6 @@ pub use normalization::*;
pub use pretokenization::*;
pub use regex::*;

-// PyChar
-// This type is a temporary hack to accept `char` as argument
-// To be removed once https://github.com/PyO3/pyo3/pull/1282 has been released
-pub struct PyChar(pub char);
-
-impl FromPyObject<'_> for PyChar {
-fn extract(obj: &PyAny) -> PyResult<Self> {
-let s = <PyString as PyTryFrom<'_>>::try_from(obj)?.to_str()?;
-let mut iter = s.chars();
-if let (Some(ch), None) = (iter.next(), iter.next()) {
-Ok(Self(ch))
-} else {
-Err(exceptions::PyValueError::new_err(
-"expected a string of length 1",
-))
-}
-}
-}

// RefMut utils

pub trait DestroyPtr {
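The deleted `PyChar` shim described itself as a stopgap until PyO3/pyo3#1282 shipped; that conversion has long since landed in PyO3, so the round trip now works natively:

```rust
use pyo3::prelude::*;

fn char_round_trip(py: Python<'_>) -> PyResult<char> {
    // char -> one-character Python str ...
    let obj = '▁'.into_py(py);
    // ... and back, with the same length-1 validation ("expected a
    // string of length 1") the hand-written shim implemented.
    obj.extract::<char>(py)
}
```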
14 changes: 7 additions & 7 deletions bindings/python/src/utils/normalization.rs
@@ -9,15 +9,15 @@ use tk::pattern::Pattern;

/// Represents a Pattern as used by `NormalizedString`
#[derive(Clone, FromPyObject)]
-pub enum PyPattern<'p> {
+pub enum PyPattern {
#[pyo3(annotation = "str")]
-Str(&'p str),
+Str(String),
#[pyo3(annotation = "tokenizers.Regex")]
Regex(Py<PyRegex>),
// TODO: Add the compatibility for Fn(char) -> bool
}

-impl Pattern for PyPattern<'_> {
+impl Pattern for PyPattern {
fn find_matches(&self, inside: &str) -> tk::Result<Vec<(tk::Offsets, bool)>> {
match self {
PyPattern::Str(s) => {
@@ -35,17 +35,17 @@ impl Pattern for PyPattern<'_> {
}
}

-impl From<PyPattern<'_>> for tk::normalizers::replace::ReplacePattern {
-fn from(pattern: PyPattern<'_>) -> Self {
+impl From<PyPattern> for tk::normalizers::replace::ReplacePattern {
+fn from(pattern: PyPattern) -> Self {
match pattern {
PyPattern::Str(s) => Self::String(s.to_owned()),
PyPattern::Regex(r) => Python::with_gil(|py| Self::Regex(r.borrow(py).pattern.clone())),
}
}
}

-impl From<PyPattern<'_>> for tk::pre_tokenizers::split::SplitPattern {
-fn from(pattern: PyPattern<'_>) -> Self {
+impl From<PyPattern> for tk::pre_tokenizers::split::SplitPattern {
+fn from(pattern: PyPattern) -> Self {
match pattern {
PyPattern::Str(s) => Self::String(s.to_owned()),
PyPattern::Regex(r) => Python::with_gil(|py| Self::Regex(r.borrow(py).pattern.clone())),
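The `PyPattern` change is the same owned-data migration; note that the retained `#[pyo3(annotation = ...)]` attributes name each variant in the conversion error raised when no variant matches. A hypothetical reduced version:

```rust
use pyo3::prelude::*;

// The annotation strings stand in for the variant types in the
// error message produced when a value matches neither variant.
#[derive(FromPyObject)]
enum PatternArg {
    #[pyo3(annotation = "str")]
    Str(String),
    #[pyo3(annotation = "int")]
    Code(u32),
}
```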
(1 more changed file not shown.)
