Skip to content

Commit

Permalink
set_item works, but `tokenizer._tokenizer.post_processor[1].single = …
Browse files Browse the repository at this point in the history
…["$0", "</s>"]` does not !
  • Loading branch information
ArthurZucker committed Oct 14, 2024
1 parent 1d67a76 commit 488a570
Show file tree
Hide file tree
Showing 3 changed files with 63 additions and 3 deletions.
53 changes: 50 additions & 3 deletions bindings/python/src/processors.rs
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,17 @@ pub struct PyPostProcessor {
pub processor: Arc<RwLock<PostProcessorWrapper>>,
}

/// Builds a `PyPostProcessor` from anything convertible into a `PostProcessorWrapper`.
///
/// The converted wrapper is placed behind a freshly allocated `Arc<RwLock<..>>`, so the
/// resulting value does not share its lock with any other `PyPostProcessor`.
impl<I: Into<PostProcessorWrapper>> From<I> for PyPostProcessor {
    fn from(processor: I) -> Self {
        let wrapper: PostProcessorWrapper = processor.into();
        Self {
            processor: Arc::new(RwLock::new(wrapper)),
        }
    }
}

impl PyPostProcessor {
pub fn new(processor: Arc<RwLock<PostProcessorWrapper>>) -> Self {
PyPostProcessor { processor }
Expand Down Expand Up @@ -508,9 +519,21 @@ impl PySequence {
}

fn __getitem__(self_: PyRef<'_, Self>, py: Python<'_>, index: usize) -> PyResult<Py<PyAny>> {
match &self_.as_ref().processor.read().unwrap().clone() {
PostProcessorWrapper::Sequence(inner) => match inner.get(index) {
Some(item) => PyPostProcessor::new(Arc::new(RwLock::new(item.clone()))).get_as_subtype(py),

let super_ = self_.as_ref();
let mut wrapper = super_.processor.write().unwrap();
// if let PostProcessorWrapper::Sequence(ref mut post) = *wrapper {
// match post.get(index) {
// Some(item) => PyPostProcessor::new(Arc::clone(item)).get_as_subtype(py),
// _ => Err(PyErr::new::<pyo3::exceptions::PyIndexError, _>(
// "Index not found",
// )),
// }
// }

match *wrapper {
PostProcessorWrapper::Sequence(ref mut inner) => match inner.get_mut(index) {
Some(item) => PyPostProcessor::new(Arc::new(RwLock::new(item.to_owned()))).get_as_subtype(py),
_ => Err(PyErr::new::<pyo3::exceptions::PyIndexError, _>(
"Index not found",
)),
Expand All @@ -520,6 +543,30 @@ impl PySequence {
)),
}
}

/// Replaces the processor stored at `index` of the wrapped `Sequence` with a copy of `value`.
///
/// The replacement is cloned out of `value` *before* the write lock on this sequence is
/// requested, so the read lock and the write lock are never held at the same time. If both
/// handles happened to share the same underlying `RwLock`, taking the read lock while the
/// write lock is held would deadlock or panic.
///
/// # Errors
/// Raises a Python `IndexError` when `index` is past the end of the sequence, and also when
/// the wrapped processor is not a `Sequence`.
fn __setitem__(self_: PyRefMut<'_, Self>, _py: Python<'_>, index: usize, value: PyRef<'_, PyPostProcessor>) -> PyResult<()> {
    // The read guard is a temporary that is released at the end of this statement.
    let replacement = value.processor.read().unwrap().clone();

    let super_ = self_.as_ref();
    let mut wrapper = super_.processor.write().unwrap();
    match *wrapper {
        // `get_mut` performs the bounds check and hands back the slot in one step, so an
        // out-of-range index becomes an `IndexError` instead of a Rust panic.
        PostProcessorWrapper::Sequence(ref mut inner) => match inner.get_mut(index) {
            Some(slot) => {
                *slot = replacement;
                Ok(())
            }
            None => Err(PyErr::new::<pyo3::exceptions::PyIndexError, _>(
                "Index out of bounds",
            )),
        },
        _ => Err(PyErr::new::<pyo3::exceptions::PyIndexError, _>(
            "This processor is not a Sequence, it does not support __setitem__",
        )),
    }
}
}

/// Processors Module
Expand Down
12 changes: 12 additions & 0 deletions tokenizers/src/processors/sequence.rs
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,18 @@ impl Sequence {
/// Returns a mutable reference to the processor at `index`, as produced by
/// `self.processors.get_mut(index)`.
///
/// The result is an `Option`; presumably `None` signals an out-of-range `index`
/// (the type of `processors` is declared outside this view — confirm).
pub fn get_mut(&mut self, index: usize) -> Option<&mut PostProcessorWrapper> {
self.processors.get_mut(index)
}

/// Overwrites the processor stored at `index` with `post_proc`.
///
/// # Panics
/// Panics if `index` is out of bounds, because the slot is accessed by direct indexing.
/// Callers that cannot guarantee a valid index should check the length first or go
/// through `get_mut`.
pub fn set_mut(&mut self, index: usize, post_proc: PostProcessorWrapper) {
    // `index` is already a `usize`, so no cast is needed before indexing.
    self.processors[index] = post_proc;
}

/// Returns a shared slice view over the processors held by this sequence.
pub fn get_processors(&self) -> &[PostProcessorWrapper] {
&self.processors
}

/// Returns a mutable slice view over the processors held by this sequence.
///
/// Individual processors can be modified or replaced through the slice; a slice cannot
/// change the number of elements.
pub fn get_processors_mut(&mut self) -> &mut [PostProcessorWrapper] {
&mut self.processors
}
}

impl PostProcessor for Sequence {
Expand Down
1 change: 1 addition & 0 deletions tokenizers/src/processors/template.rs
Original file line number Diff line number Diff line change
Expand Up @@ -362,6 +362,7 @@ impl TemplateProcessing {
/// Replaces the template stored in `self.single` with `single`.
///
/// The setter is silent: it only assigns the field and writes nothing to stdout.
pub fn set_single(&mut self, single: Template) {
    self.single = single;
}

// Getter for `pair`
Expand Down

0 comments on commit 488a570

Please sign in to comment.