Skip to content

Commit

Permalink
current updates
Browse files Browse the repository at this point in the history
  • Loading branch information
ArthurZucker committed Oct 14, 2024
1 parent 5182653 commit 01d0b29
Show file tree
Hide file tree
Showing 3 changed files with 166 additions and 19 deletions.
121 changes: 103 additions & 18 deletions bindings/python/src/processors.rs
Original file line number Diff line number Diff line change
@@ -1,10 +1,11 @@
use std::convert::TryInto;
use std::sync::Arc;
use std::sync::RwLock;

use pyo3::exceptions;
use pyo3::prelude::*;
use pyo3::types::*;

use std::ops::DerefMut;
use crate::encoding::PyEncoding;
use crate::error::ToPyResult;
use serde::{Deserialize, Serialize};
Expand All @@ -30,17 +31,17 @@ use tokenizers as tk;
#[derive(Clone, Deserialize, Serialize)]
#[serde(transparent)]
pub struct PyPostProcessor {
pub processor: Arc<PostProcessorWrapper>,
pub processor: Arc<RwLock<PostProcessorWrapper>>,
}

impl PyPostProcessor {
pub fn new(processor: Arc<PostProcessorWrapper>) -> Self {
pub fn new(processor: Arc<RwLock<PostProcessorWrapper>>) -> Self {
PyPostProcessor { processor }
}

pub(crate) fn get_as_subtype(&self, py: Python<'_>) -> PyResult<PyObject> {
let base = self.clone();
Ok(match self.processor.as_ref() {
Ok(match self.processor.read().unwrap().clone() {
PostProcessorWrapper::ByteLevel(_) => Py::new(py, (PyByteLevel {}, base))?.into_py(py),
PostProcessorWrapper::Bert(_) => Py::new(py, (PyBertProcessing {}, base))?.into_py(py),
PostProcessorWrapper::Roberta(_) => {
Expand All @@ -56,23 +57,23 @@ impl PyPostProcessor {

impl PostProcessor for PyPostProcessor {
fn added_tokens(&self, is_pair: bool) -> usize {
self.processor.added_tokens(is_pair)
self.processor.read().unwrap().added_tokens(is_pair)
}

fn process_encodings(
&self,
encodings: Vec<Encoding>,
add_special_tokens: bool,
) -> tk::Result<Vec<Encoding>> {
self.processor
self.processor.read().unwrap()
.process_encodings(encodings, add_special_tokens)
}
}

#[pymethods]
impl PyPostProcessor {
fn __getstate__(&self, py: Python) -> PyResult<PyObject> {
let data = serde_json::to_string(self.processor.as_ref()).map_err(|e| {
let data = serde_json::to_string(&self.processor).map_err(|e| {
exceptions::PyException::new_err(format!(
"Error while attempting to pickle PostProcessor: {}",
e
Expand Down Expand Up @@ -106,7 +107,7 @@ impl PyPostProcessor {
/// :obj:`int`: The number of tokens to add
#[pyo3(text_signature = "(self, is_pair)")]
fn num_special_tokens_to_add(&self, is_pair: bool) -> usize {
self.processor.added_tokens(is_pair)
self.processor.read().unwrap().added_tokens(is_pair)
}

/// Post-process the given encodings, generating the final one
Expand All @@ -131,7 +132,7 @@ impl PyPostProcessor {
pair: Option<&PyEncoding>,
add_special_tokens: bool,
) -> PyResult<PyEncoding> {
let final_encoding = ToPyResult(self.processor.process(
let final_encoding = ToPyResult(self.processor.read().unwrap().process(
encoding.encoding.clone(),
pair.map(|e| e.encoding.clone()),
add_special_tokens,
Expand All @@ -151,6 +152,42 @@ impl PyPostProcessor {
}
}

/// Reads a value out of a specific `PostProcessorWrapper` variant and returns
/// its `{:?}` rendering from the *enclosing function*.
///
/// * `$self` - identifier of a value whose `as_ref()` yields the base object
///   holding the `processor` lock.
/// * `$variant` - the `PostProcessorWrapper` variant the processor is expected to be.
/// * `$name` - the field access or method call to evaluate on the inner processor.
///
/// The expansion contains a `return`, so it must be used inside a function
/// returning `String`. Panics if the lock is poisoned or if the wrapped
/// processor is not `$variant`.
macro_rules! getter {
    ($self: ident, $variant: ident, $($name: tt)+) => {{
        let base = $self.as_ref();
        // Keep the read guard in a named local so it is released as soon as
        // the formatted string has been produced.
        let guard = base.processor.read().unwrap();
        match *guard {
            PostProcessorWrapper::$variant(ref post) => {
                let output = post.$($name)+;
                return format!("{:?}", output)
            }
            _ => unreachable!(),
        }
    }};
}

macro_rules! setter {
($self: ident, $variant: ident, $name: ident, $value: expr) => {{
let super_ = $self;
if let PostProcessorWrapper::$variant(ref mut post) = super_.processor.as_ref() {
post.$name = $value;
}
}};
($self: ident, $variant: ident, @$name: ident, $value: expr) => {{
let super_ = &$self.as_ref();
match &super_.processor.as_ref() {
PostProcessorWrapper::$variant(post_variant) => post_variant.$name($value),
_ => unreachable!(),
}

{
if let Some(PostProcessorWrapper::$variant(post_variant)) =
Arc::get_mut(&mut super_.processor)
{
post_variant.$name($value);
}
};
};};
}

/// This post-processor takes care of adding the special tokens needed by
/// a Bert model:
///
Expand All @@ -172,7 +209,7 @@ impl PyBertProcessing {
fn new(sep: (String, u32), cls: (String, u32)) -> (Self, PyPostProcessor) {
(
PyBertProcessing {},
PyPostProcessor::new(Arc::new(BertProcessing::new(sep, cls).into())),
PyPostProcessor::new(Arc::new(RwLock::new(BertProcessing::new(sep, cls).into()))),
)
}

Expand Down Expand Up @@ -222,7 +259,7 @@ impl PyRobertaProcessing {
.add_prefix_space(add_prefix_space);
(
PyRobertaProcessing {},
PyPostProcessor::new(Arc::new(proc.into())),
PyPostProcessor::new(Arc::new(RwLock::new(proc.into()))),
)
}

Expand Down Expand Up @@ -257,7 +294,7 @@ impl PyByteLevel {

(
PyByteLevel {},
PyPostProcessor::new(Arc::new(byte_level.into())),
PyPostProcessor::new(Arc::new(RwLock::new(byte_level.into()))),
)
}
}
Expand Down Expand Up @@ -421,9 +458,43 @@ impl PyTemplateProcessing {

Ok((
PyTemplateProcessing {},
PyPostProcessor::new(Arc::new(processor.into())),
PyPostProcessor::new(Arc::new(RwLock::new(processor.into()))),
))
}

/// Python getter for the template applied to single sequences.
///
/// Delegates to the `getter!` macro, which read-locks the shared processor,
/// calls `get_single()` on the `Template` variant and returns the `{:?}`
/// rendering of the result. The macro hits `unreachable!()` if the wrapped
/// processor is any other variant.
#[getter]
fn get_single(self_: PyRef<Self>) -> String{
getter!(self_, Template, get_single())
}

#[setter]
fn set_single(self_:PyRefMut<Self>, single: PyTemplate) {
let template: Template = Template::from(single);

let super_ = &self_.into_super();

// Acquire a write lock on the processor
let binding = super_.processor.clone(); // Clone the Arc
let mut write_lock = match binding.write() { // Make this mutable
Ok(lock) => lock,
Err(e) => {
eprintln!("Failed to acquire write lock: {:?}", e);
return; // Handle lock acquisition failure appropriately
}
};

// Use deref_mut to get a mutable reference and match against the PostProcessorWrapper type
match write_lock.deref_mut() {
PostProcessorWrapper::Template(value) => {
println!("Created template single : {template:?}");
value.set_single(template.clone());
},
_ => {
eprintln!("Processor is not of type PostProcessorWrapper::Template");
}
}

}
}

/// Sequence Processor
Expand All @@ -441,19 +512,33 @@ impl PySequence {
let mut processors: Vec<PostProcessorWrapper> = Vec::with_capacity(processors_py.len());
for n in processors_py.iter() {
let processor: PyRef<PyPostProcessor> = n.extract().unwrap();
let processor = processor.processor.as_ref();
let processor = processor.processor.write().unwrap();
processors.push(processor.clone());
}
let sequence_processor = Sequence::new(processors);
(
PySequence {},
PyPostProcessor::new(Arc::new(PostProcessorWrapper::Sequence(sequence_processor))),
PyPostProcessor::new(Arc::new(RwLock::new(PostProcessorWrapper::Sequence(sequence_processor)))),
)
}

fn __getnewargs__<'p>(&self, py: Python<'p>) -> Bound<'p, PyTuple> {
PyTuple::new_bound(py, [PyList::empty_bound(py)])
}

fn __getitem__(self_: PyRef<'_, Self>, py: Python<'_>, index: usize) -> PyResult<Py<PyAny>> {
match &self_.as_ref().processor.read().unwrap().clone() {
PostProcessorWrapper::Sequence(inner) => match inner.get(index) {
Some(item) => PyPostProcessor::new(Arc::new(RwLock::new(item.clone()))).get_as_subtype(py),
_ => Err(PyErr::new::<pyo3::exceptions::PyIndexError, _>(
"Index not found",
)),
},
_ => Err(PyErr::new::<pyo3::exceptions::PyIndexError, _>(
"This processor is not a Sequence, it does not support __getitem__",
)),
}
}
}

/// Processors Module
Expand Down Expand Up @@ -481,9 +566,9 @@ mod test {
#[test]
fn get_subtype() {
Python::with_gil(|py| {
let py_proc = PyPostProcessor::new(Arc::new(
let py_proc = PyPostProcessor::new(Arc::new(RwLock::new(
BertProcessing::new(("SEP".into(), 0), ("CLS".into(), 1)).into(),
));
)));
let py_bert = py_proc.get_as_subtype(py).unwrap();
assert_eq!(
"BertProcessing",
Expand All @@ -499,7 +584,7 @@ mod test {
let rs_processing_ser = serde_json::to_string(&rs_processing).unwrap();
let rs_wrapper_ser = serde_json::to_string(&rs_wrapper).unwrap();

let py_processing = PyPostProcessor::new(Arc::new(rs_wrapper));
let py_processing = PyPostProcessor::new(Arc::new(RwLock::new(rs_wrapper)));
let py_ser = serde_json::to_string(&py_processing).unwrap();
assert_eq!(py_ser, rs_processing_ser);
assert_eq!(py_ser, rs_wrapper_ser);
Expand Down
8 changes: 8 additions & 0 deletions tokenizers/src/processors/sequence.rs
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,14 @@ impl Sequence {
pub fn new(processors: Vec<PostProcessorWrapper>) -> Self {
Self { processors }
}

/// Returns a shared reference to the processor at `index`, or `None` when
/// `index` is out of bounds.
pub fn get(&self, index: usize) -> Option<&PostProcessorWrapper> {
    self.processors.get(index)
}

pub fn get_mut(&mut self, index: usize) -> Option<&mut PostProcessorWrapper> {
self.processors.get_mut(index)
}
}

impl PostProcessor for Sequence {
Expand Down
56 changes: 55 additions & 1 deletion tokenizers/src/processors/template.rs
Original file line number Diff line number Diff line change
Expand Up @@ -338,7 +338,7 @@ impl From<HashMap<String, SpecialToken>> for Tokens {
#[builder(build_fn(validate = "Self::validate"))]
pub struct TemplateProcessing {
#[builder(try_setter, default = "\"$0\".try_into().unwrap()")]
single: Template,
pub single: Template,
#[builder(try_setter, default = "\"$A:0 $B:1\".try_into().unwrap()")]
pair: Template,
#[builder(setter(skip), default = "self.default_added(true)")]
Expand All @@ -351,6 +351,60 @@ pub struct TemplateProcessing {
special_tokens: Tokens,
}


/// Plain accessors for the fields of [`TemplateProcessing`].
///
/// NOTE(review): the setters store the given value verbatim. In particular,
/// replacing `single` or `pair` does not touch `added_single` / `added_pair`;
/// if those counters are meant to be derived from the templates, callers must
/// update them as well - confirm against the builder defaults.
impl TemplateProcessing {
    /// Returns a clone of the template applied to single sequences.
    pub fn get_single(&self) -> Template {
        self.single.clone()
    }

    /// Replaces the template applied to single sequences.
    pub fn set_single(&mut self, single: Template) {
        self.single = single;
    }

    /// Returns the template applied to pairs of sequences.
    pub fn get_pair(&self) -> &Template {
        &self.pair
    }

    /// Replaces the template applied to pairs of sequences.
    pub fn set_pair(&mut self, pair: Template) {
        self.pair = pair;
    }

    /// Returns the stored `added_single` counter.
    pub fn get_added_single(&self) -> usize {
        self.added_single
    }

    /// Overwrites the stored `added_single` counter.
    pub fn set_added_single(&mut self, added_single: usize) {
        self.added_single = added_single;
    }

    /// Returns the stored `added_pair` counter.
    pub fn get_added_pair(&self) -> usize {
        self.added_pair
    }

    /// Overwrites the stored `added_pair` counter.
    pub fn set_added_pair(&mut self, added_pair: usize) {
        self.added_pair = added_pair;
    }

    /// Returns the special tokens known to this processor.
    pub fn get_special_tokens(&self) -> &Tokens {
        &self.special_tokens
    }

    /// Replaces the special tokens known to this processor.
    pub fn set_special_tokens(&mut self, special_tokens: Tokens) {
        self.special_tokens = special_tokens;
    }
}

impl From<&str> for TemplateProcessingBuilderError {
fn from(e: &str) -> Self {
e.to_string().into()
Expand Down

0 comments on commit 01d0b29

Please sign in to comment.