From 7d3fad0cc09b1856177cbd028e95169b37a9109b Mon Sep 17 00:00:00 2001 From: Nathan Goldbaum Date: Tue, 10 Dec 2024 15:04:58 -0700 Subject: [PATCH] implement locked iteration for PyList --- src/types/list.rs | 403 ++++++++++++++++++++++++++++++++++++++++++---- 1 file changed, 370 insertions(+), 33 deletions(-) diff --git a/src/types/list.rs b/src/types/list.rs index af2b557cba9..8fd3552ee68 100644 --- a/src/types/list.rs +++ b/src/types/list.rs @@ -179,7 +179,9 @@ pub trait PyListMethods<'py>: crate::sealed::Sealed { /// # Safety /// /// Caller must verify that the index is within the bounds of the list. - #[cfg(not(any(Py_LIMITED_API, Py_GIL_DISABLED)))] + /// On the free-threaded build, caller must verify they have exclusive access to the list + /// via a lock or by holding the innermost critical section on the list. + #[cfg(not(any(Py_LIMITED_API)))] unsafe fn get_item_unchecked(&self, index: usize) -> Bound<'py, PyAny>; /// Takes the slice `self[low:high]` and returns it as a new list. @@ -239,6 +241,17 @@ pub trait PyListMethods<'py>: crate::sealed::Sealed { /// Returns an iterator over this list's items. fn iter(&self) -> BoundListIterator<'py>; + /// Iterates over the contents of this list while holding a critical section on the list. + /// This is useful when the GIL is disabled and the list is shared between threads. + /// It is not guaranteed that the list will not be modified during iteration when the + /// closure calls arbitrary Python code that releases the critical section held by the + /// iterator. Otherwise, the list will not be modified during iteration. + /// + /// This is equivalent to for_each if the GIL is enabled. + fn locked_for_each(&self, closure: F) -> PyResult<()> + where + F: Fn(Bound<'py, PyAny>) -> PyResult<()>; + /// Sorts the list in-place. Equivalent to the Python expression `l.sort()`. fn sort(&self) -> PyResult<()>; @@ -302,7 +315,7 @@ impl<'py> PyListMethods<'py> for Bound<'py, PyList> { /// # Safety /// /// Caller must verify that the index is within the bounds of the list. - #[cfg(not(any(Py_LIMITED_API, Py_GIL_DISABLED)))] + #[cfg(not(Py_LIMITED_API))] unsafe fn get_item_unchecked(&self, index: usize) -> Bound<'py, PyAny> { // PyList_GET_ITEM return borrowed ptr; must make owned for safety (see #890). ffi::PyList_GET_ITEM(self.as_ptr(), index as Py_ssize_t) @@ -440,6 +453,14 @@ impl<'py> PyListMethods<'py> for Bound<'py, PyList> { BoundListIterator::new(self.clone()) } + /// Returns an iterator that holds a critical section on the list. + fn locked_for_each(&self, closure: F) -> PyResult<()> + where + F: Fn(Bound<'py, PyAny>) -> PyResult<()>, + { + crate::sync::with_critical_section(self, || self.iter().try_for_each(closure)) + } + /// Sorts the list in-place. Equivalent to the Python expression `l.sort()`. fn sort(&self) -> PyResult<()> { err::error_on_minusone(self.py(), unsafe { ffi::PyList_Sort(self.as_ptr()) }) @@ -465,8 +486,74 @@ impl<'py> PyListMethods<'py> for Bound<'py, PyList> { /// Used by `PyList::iter()`. pub struct BoundListIterator<'py> { list: Bound<'py, PyList>, - index: usize, - length: usize, + inner: ListIterImpl, +} + +enum ListIterImpl { + ListIter { index: usize, length: usize }, +} + +impl ListIterImpl { + #[inline] + /// Safety: the list should be locked with a critical section on the free-threaded build + /// and otherwise not shared between threads when the GIL is released. + unsafe fn next_unchecked<'py>( + &mut self, + list: &Bound<'py, PyList>, + ) -> Option> { + match self { + Self::ListIter { index, length, .. } => { + let length = (*length).min(list.len()); + let my_index = *index; + + if *index < length { + let item = unsafe { list.get_item_unchecked(my_index) }; + *index += 1; + Some(item) + } else { + None + } + } + } + } + + #[inline] + unsafe fn next_back_unchecked<'py>( + &mut self, + list: &Bound<'py, PyList>, + ) -> Option> { + match self { + Self::ListIter { index, length, .. } => { + let current_length = (*length).min(list.len()); + + if *index < current_length { + let item = unsafe { list.get_item_unchecked(current_length - 1) }; + *length = current_length - 1; + Some(item) + } else { + None + } + } + } + } + + #[inline] + fn len(&self) -> usize { + match self { + Self::ListIter { index, length, .. } => length.saturating_sub(*index), + } + } + + #[cfg(Py_GIL_DISABLED)] + #[inline] + fn with_critical_section(&mut self, list: &Bound<'_, PyList>, f: F) -> R + where + F: FnOnce(&mut Self) -> R, + { + match self { + Self::ListIter { .. } => crate::sync::with_critical_section(list, || f(self)), + } + } } impl<'py> BoundListIterator<'py> { @@ -474,18 +561,9 @@ impl<'py> BoundListIterator<'py> { let length: usize = list.len(); BoundListIterator { list, - index: 0, - length, + inner: ListIterImpl::ListIter { index: 0, length }, } } - - unsafe fn get_item(&self, index: usize) -> Bound<'py, PyAny> { - #[cfg(any(Py_LIMITED_API, PyPy, Py_GIL_DISABLED))] - let item = self.list.get_item(index).expect("list.get failed"); - #[cfg(not(any(Py_LIMITED_API, PyPy, Py_GIL_DISABLED)))] - let item = self.list.get_item_unchecked(index); - item - } } impl<'py> Iterator for BoundListIterator<'py> { @@ -493,14 +571,16 @@ impl<'py> Iterator for BoundListIterator<'py> { #[inline] fn next(&mut self) -> Option { - let length = self.length.min(self.list.len()); - - if self.index < length { - let item = unsafe { self.get_item(self.index) }; - self.index += 1; - Some(item) - } else { - None + #[cfg(Py_GIL_DISABLED)] + { + self.inner + .with_critical_section(&self.list, |inner| unsafe { + inner.next_unchecked(&self.list) + }) + } + #[cfg(not(Py_GIL_DISABLED))] + { + unsafe { self.inner.next_unchecked(&self.list) } } } @@ -509,26 +589,164 @@ impl<'py> Iterator for BoundListIterator<'py> { let len = self.len(); (len, Some(len)) } + + #[inline] + #[cfg(Py_GIL_DISABLED)] + fn fold(mut self, init: B, mut f: F) -> B + where + Self: Sized, + F: FnMut(B, Self::Item) -> B, + { + self.inner.with_critical_section(&self.list, |inner| { + let mut accum = init; + while let Some(x) = unsafe { inner.next_unchecked(&self.list) } { + accum = f(accum, x); + } + accum + }) + } + + #[inline] + #[cfg(all(Py_GIL_DISABLED, not(feature = "nightly")))] + fn all(&mut self, mut f: F) -> bool + where + Self: Sized, + F: FnMut(Self::Item) -> bool, + { + self.inner.with_critical_section(&self.list, |inner| { + while let Some(x) = unsafe { inner.next_unchecked(&self.list) } { + if !f(x) { + return false; + } + } + true + }) + } + + #[inline] + #[cfg(all(Py_GIL_DISABLED, not(feature = "nightly")))] + fn any(&mut self, mut f: F) -> bool + where + Self: Sized, + F: FnMut(Self::Item) -> bool, + { + self.inner.with_critical_section(&self.list, |inner| { + while let Some(x) = unsafe { inner.next_unchecked(&self.list) } { + if f(x) { + return true; + } + } + false + }) + } + + #[inline] + #[cfg(all(Py_GIL_DISABLED, not(feature = "nightly")))] + fn find

(&mut self, mut predicate: P) -> Option + where + Self: Sized, + P: FnMut(&Self::Item) -> bool, + { + self.inner.with_critical_section(&self.list, |inner| { + while let Some(x) = unsafe { inner.next_unchecked(&self.list) } { + if predicate(&x) { + return Some(x); + } + } + None + }) + } + + #[inline] + #[cfg(all(Py_GIL_DISABLED, not(feature = "nightly")))] + fn find_map(&mut self, mut f: F) -> Option + where + Self: Sized, + F: FnMut(Self::Item) -> Option, + { + self.inner.with_critical_section(&self.list, |inner| { + while let Some(x) = unsafe { inner.next_unchecked(&self.list) } { + if let found @ Some(_) = f(x) { + return found; + } + } + None + }) + } + + #[inline] + #[cfg(all(Py_GIL_DISABLED, not(feature = "nightly")))] + fn position

(&mut self, mut predicate: P) -> Option + where + Self: Sized, + P: FnMut(Self::Item) -> bool, + { + self.inner.with_critical_section(&self.list, |inner| { + let mut acc = 0; + while let Some(x) = unsafe { inner.next_unchecked(&self.list) } { + if predicate(x) { + return Some(acc); + } + acc += 1; + } + None + }) + } } impl DoubleEndedIterator for BoundListIterator<'_> { #[inline] fn next_back(&mut self) -> Option { - let length = self.length.min(self.list.len()); - - if self.index < length { - let item = unsafe { self.get_item(length - 1) }; - self.length = length - 1; - Some(item) - } else { - None + #[cfg(Py_GIL_DISABLED)] + { + self.inner + .with_critical_section(&self.list, |inner| unsafe { + inner.next_back_unchecked(&self.list) + }) } + #[cfg(not(Py_GIL_DISABLED))] + { + unsafe { self.inner.next_back_unchecked(&self.list) } + } + } + + #[inline] + #[cfg(Py_GIL_DISABLED)] + fn rfold(mut self, init: B, mut f: F) -> B + where + Self: Sized, + F: FnMut(B, Self::Item) -> B, + { + self.inner.with_critical_section(&self.list, |inner| { + let mut accum = init; + while let Some(x) = unsafe { inner.next_back_unchecked(&self.list) } { + accum = f(accum, x); + } + accum + }) + } + + #[inline] + #[cfg(all(Py_GIL_DISABLED, feature = "nightly"))] + fn try_rfold(&mut self, init: B, mut f: F) -> R + where + Self: Sized, + F: FnMut(B, Self::Item) -> R, + R: std::ops::Try, + { + self.inner.with_critical_section(&self.list, |inner| { + let mut accum = init; + while let Some(x) = unsafe { inner.next_back_unchecked(&self.list) } { + accum = f(accum, x)? + } + R::from_output(accum) + }) } } impl ExactSizeIterator for BoundListIterator<'_> { fn len(&self) -> usize { - self.length.saturating_sub(self.index) + self.inner.len() } } @@ -558,7 +776,7 @@ mod tests { use crate::types::list::PyListMethods; use crate::types::sequence::PySequenceMethods; use crate::types::{PyList, PyTuple}; - use crate::{ffi, IntoPyObject, Python}; + use crate::{ffi, IntoPyObject, PyResult, Python}; #[test] fn test_new() { @@ -748,6 +966,125 @@ mod tests { }); } + #[test] + fn test_iter_all() { + Python::with_gil(|py| { + let list = PyList::new(py, [true, true, true]).unwrap(); + assert!(list.iter().all(|x| x.extract::().unwrap())); + + let list = PyList::new(py, [true, false, true]).unwrap(); + assert!(!list.iter().all(|x| x.extract::().unwrap())); + }); + } + + #[test] + fn test_iter_any() { + Python::with_gil(|py| { + let list = PyList::new(py, [true, true, true]).unwrap(); + assert!(list.iter().any(|x| x.extract::().unwrap())); + + let list = PyList::new(py, [true, false, true]).unwrap(); + assert!(list.iter().any(|x| x.extract::().unwrap())); + + let list = PyList::new(py, [false, false, false]).unwrap(); + assert!(!list.iter().any(|x| x.extract::().unwrap())); + }); + } + + #[test] + fn test_iter_find() { + Python::with_gil(|py: Python<'_>| { + let list = PyList::new(py, ["hello", "world"]).unwrap(); + assert_eq!( + Some("world".to_string()), + list.iter() + .find(|v| v.extract::().unwrap() == "world") + .map(|v| v.extract::().unwrap()) + ); + assert_eq!( + None, + list.iter() + .find(|v| v.extract::().unwrap() == "foobar") + .map(|v| v.extract::().unwrap()) + ); + }); + } + + #[test] + fn test_iter_position() { + Python::with_gil(|py: Python<'_>| { + let list = PyList::new(py, ["hello", "world"]).unwrap(); + assert_eq!( + Some(1), + list.iter() + .position(|v| v.extract::().unwrap() == "world") + ); + assert_eq!( + None, + list.iter() + .position(|v| v.extract::().unwrap() == "foobar") + ); + }); + } + + #[test] + fn test_iter_fold() { + Python::with_gil(|py: Python<'_>| { + let list = PyList::new(py, [1, 2, 3]).unwrap(); + let sum = list + .iter() + .fold(0, |acc, v| acc + v.extract::().unwrap()); + assert_eq!(sum, 6); + }); + } + + #[test] + fn test_iter_rfold() { + Python::with_gil(|py: Python<'_>| { + let list = PyList::new(py, [1, 2, 3]).unwrap(); + let sum = list + .iter() + .rfold(0, |acc, v| acc + v.extract::().unwrap()); + assert_eq!(sum, 6); + }); + } + + #[test] + fn test_iter_try_fold() { + Python::with_gil(|py: Python<'_>| { + let list = PyList::new(py, [1, 2, 3]).unwrap(); + let sum = list + .iter() + .try_fold(0, |acc, v| PyResult::Ok(acc + v.extract::()?)) + .unwrap(); + assert_eq!(sum, 6); + + let list = PyList::new(py, ["foo", "bar"]).unwrap(); + assert!(list + .iter() + .try_fold(0, |acc, v| PyResult::Ok(acc + v.extract::()?)) + .is_err()); + }); + } + + #[test] + fn test_iter_try_rfold() { + Python::with_gil(|py: Python<'_>| { + let list = PyList::new(py, [1, 2, 3]).unwrap(); + let sum = list + .iter() + .try_rfold(0, |acc, v| PyResult::Ok(acc + v.extract::()?)) + .unwrap(); + assert_eq!(sum, 6); + + let list = PyList::new(py, ["foo", "bar"]).unwrap(); + assert!(list + .iter() + .try_rfold(0, |acc, v| PyResult::Ok(acc + v.extract::()?)) + .is_err()); + }); + } + #[test] fn test_into_iter() { Python::with_gil(|py| { @@ -877,7 +1214,7 @@ mod tests { }); } - #[cfg(not(any(Py_LIMITED_API, PyPy, Py_GIL_DISABLED)))] + #[cfg(not(any(Py_LIMITED_API, PyPy)))] #[test] fn test_list_get_item_unchecked_sanity() { Python::with_gil(|py| {