diff --git a/bindings/python/src/processors.rs b/bindings/python/src/processors.rs index 1a59d4a20..3c873a447 100644 --- a/bindings/python/src/processors.rs +++ b/bindings/python/src/processors.rs @@ -34,6 +34,17 @@ pub struct PyPostProcessor { pub processor: Arc>, } +impl From for PyPostProcessor +where + I: Into, +{ + fn from(processor: I) -> Self { + PyPostProcessor { + processor: Arc::new(RwLock::new(processor.into())), // Wrap the PostProcessorWrapper in Arc> + } + } +} + impl PyPostProcessor { pub fn new(processor: Arc>) -> Self { PyPostProcessor { processor } @@ -508,9 +519,21 @@ impl PySequence { } fn __getitem__(self_: PyRef<'_, Self>, py: Python<'_>, index: usize) -> PyResult> { - match &self_.as_ref().processor.read().unwrap().clone() { - PostProcessorWrapper::Sequence(inner) => match inner.get(index) { - Some(item) => PyPostProcessor::new(Arc::new(RwLock::new(item.clone()))).get_as_subtype(py), + + let super_ = self_.as_ref(); + let mut wrapper = super_.processor.write().unwrap(); + // if let PostProcessorWrapper::Sequence(ref mut post) = *wrapper { + // match post.get(index) { + // Some(item) => PyPostProcessor::new(Arc::clone(item)).get_as_subtype(py), + // _ => Err(PyErr::new::( + // "Index not found", + // )), + // } + // } + + match *wrapper { + PostProcessorWrapper::Sequence(ref mut inner) => match inner.get_mut(index) { + Some(item) => PyPostProcessor::new(Arc::new(RwLock::new(item.to_owned()))).get_as_subtype(py), _ => Err(PyErr::new::( "Index not found", )), @@ -520,6 +543,30 @@ impl PySequence { )), } } + + fn __setitem__(self_: PyRefMut<'_, Self>, py: Python<'_>, index: usize, value: PyRef<'_, PyPostProcessor>) -> PyResult<()> { + let super_ = self_.as_ref(); + let mut wrapper = super_.processor.write().unwrap(); + let value = value.processor.read().unwrap().clone(); + match *wrapper { + PostProcessorWrapper::Sequence(ref mut inner) => { + // Convert the Py into the appropriate Rust type + // Ensure we can set an item at the given index + if index < inner.get_processors().len() { + inner.set_mut(index, value); // Assuming you want to wrap the new item in Arc + + Ok(()) + } else { + Err(PyErr::new::( + "Index out of bounds", + )) + } + }, + _ => Err(PyErr::new::( + "This processor is not a Sequence, it does not support __setitem__", + )), + } + } } /// Processors Module diff --git a/tokenizers/src/processors/sequence.rs b/tokenizers/src/processors/sequence.rs index 8d273252a..b9fdbb4dd 100644 --- a/tokenizers/src/processors/sequence.rs +++ b/tokenizers/src/processors/sequence.rs @@ -21,6 +21,18 @@ impl Sequence { pub fn get_mut(&mut self, index: usize) -> Option<&mut PostProcessorWrapper> { self.processors.get_mut(index) } + + pub fn set_mut(&mut self, index: usize, post_proc: PostProcessorWrapper) { + self.processors[index as usize] = post_proc; + } + + pub fn get_processors(&self) -> &[PostProcessorWrapper] { + &self.processors + } + + pub fn get_processors_mut(&mut self) -> &mut [PostProcessorWrapper] { + &mut self.processors + } } impl PostProcessor for Sequence { diff --git a/tokenizers/src/processors/template.rs b/tokenizers/src/processors/template.rs index 8c9e88145..be9df6b36 100644 --- a/tokenizers/src/processors/template.rs +++ b/tokenizers/src/processors/template.rs @@ -362,6 +362,7 @@ impl TemplateProcessing { pub fn set_single(&mut self, single: Template) { println!("Setting single to: {:?}", single); // Debugging output self.single = single; + println!("Single is now {:?}", self.single); } // Getter for `pair`