Skip to content

Commit

Permalink
make PyErrState thread-safe (#4671)
Browse files Browse the repository at this point in the history
* make `PyErrState` thread-safe

* fix clippy

* add test of reentrancy, fix deadlock

* newsfragment

* fix MSRV

* fix nightly build

* Update err_state.rs

---------

Co-authored-by: Nathan Goldbaum <[email protected]>
  • Loading branch information
davidhewitt and ngoldbaum authored Nov 5, 2024
1 parent d45e0bd commit 9f955e4
Show file tree
Hide file tree
Showing 2 changed files with 138 additions and 26 deletions.
1 change: 1 addition & 0 deletions newsfragments/4671.fixed.md
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Make `PyErr` internals thread-safe.
163 changes: 137 additions & 26 deletions src/err/err_state.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,8 @@
use std::cell::UnsafeCell;
use std::{
cell::UnsafeCell,
sync::{Mutex, Once},
thread::ThreadId,
};

use crate::{
exceptions::{PyBaseException, PyTypeError},
Expand All @@ -11,15 +15,18 @@ use crate::{
// Holder for the state of an error. The state is normalized at most once (guarded by
// `normalized`), after which a shared reference to the normalized state can be handed out.
pub(crate) struct PyErrState {
// Safety: can only hand out references when in the "normalized" state. Will never change
// after normalization.
//
// The state is temporarily removed from the PyErr during normalization, to avoid
// concurrent modifications.
//
// `normalized` is checked with `is_completed()` before `inner` is read, and the
// normalization itself runs inside `call_once` on it.
normalized: Once,
// Guard against re-entrancy when normalizing the exception state.
// Set to the id of the thread that runs the normalization; a thread that finds its own
// id here when asked to normalize again trips an assertion.
normalizing_thread: Mutex<Option<ThreadId>>,
// The error state itself. It is taken out (left as `None`) while normalization is in
// progress and written back as `PyErrStateInner::Normalized` afterwards.
inner: UnsafeCell<Option<PyErrStateInner>>,
}

// Safety: The inner value is protected by locking to ensure that only the normalized state is
// handed out as a reference. The normalization code shown in this file mutates the inner
// value only from within `Once::call_once`, and readers dereference it only after the `Once`
// reports completion.
unsafe impl Send for PyErrState {}
unsafe impl Sync for PyErrState {}
// With the `nightly` feature the `Ungil` marker is implemented by hand as well.
// NOTE(review): presumably needed so a `&PyErrState` can be captured by the closure passed
// to `allow_threads` during normalization; confirm against `crate::marker::Ungil`.
#[cfg(feature = "nightly")]
unsafe impl crate::marker::Ungil for PyErrState {}

impl PyErrState {
pub(crate) fn lazy(f: Box<PyErrStateLazyFn>) -> Self {
Expand Down Expand Up @@ -48,17 +55,22 @@ impl PyErrState {

// Builds a `PyErrState` around the given inner state. The `Once` starts out incomplete and
// no normalizing thread is recorded yet.
fn from_inner(inner: PyErrStateInner) -> Self {
    let normalized = Once::new();
    let normalizing_thread = Mutex::new(None);
    let cell = UnsafeCell::new(Some(inner));

    Self {
        normalized,
        normalizing_thread,
        inner: cell,
    }
}

#[inline]
pub(crate) fn as_normalized(&self, py: Python<'_>) -> &PyErrStateNormalized {
if let Some(PyErrStateInner::Normalized(n)) = unsafe {
// Safety: self.inner will never be written again once normalized.
&*self.inner.get()
} {
return n;
if self.normalized.is_completed() {
match unsafe {
// Safety: self.inner will never be written again once normalized.
&*self.inner.get()
} {
Some(PyErrStateInner::Normalized(n)) => return n,
_ => unreachable!(),
}
}

self.make_normalized(py)
Expand All @@ -69,25 +81,47 @@ impl PyErrState {
// This process is safe because:
// - Access is guaranteed not to be concurrent thanks to `Python` GIL token
// - Write happens only once, and then never will change again.
// - State is set to None during the normalization process, so that a second
// concurrent normalization attempt will panic before changing anything.

// FIXME: this needs to be rewritten to deal with free-threaded Python
// see https://github.com/PyO3/pyo3/issues/4584
// Guard against re-entrant normalization, because `Once` does not provide
// re-entrancy guarantees.
if let Some(thread) = self.normalizing_thread.lock().unwrap().as_ref() {
assert!(
!(*thread == std::thread::current().id()),
"Re-entrant normalization of PyErrState detected"
);
}

let state = unsafe {
(*self.inner.get())
.take()
.expect("Cannot normalize a PyErr while already normalizing it.")
};
// avoid deadlock of `.call_once` with the GIL
py.allow_threads(|| {
self.normalized.call_once(|| {
self.normalizing_thread
.lock()
.unwrap()
.replace(std::thread::current().id());

// Safety: no other thread can access the inner value while we are normalizing it.
let state = unsafe {
(*self.inner.get())
.take()
.expect("Cannot normalize a PyErr while already normalizing it.")
};

let normalized_state =
Python::with_gil(|py| PyErrStateInner::Normalized(state.normalize(py)));

// Safety: no other thread can access the inner value while we are normalizing it.
unsafe {
*self.inner.get() = Some(normalized_state);
}
})
});

unsafe {
let self_state = &mut *self.inner.get();
*self_state = Some(PyErrStateInner::Normalized(state.normalize(py)));
match self_state {
Some(PyErrStateInner::Normalized(n)) => n,
_ => unreachable!(),
}
match unsafe {
// Safety: self.inner will never be written again once normalized.
&*self.inner.get()
} {
Some(PyErrStateInner::Normalized(n)) => n,
_ => unreachable!(),
}
}
}
Expand Down Expand Up @@ -321,3 +355,80 @@ fn raise_lazy(py: Python<'_>, lazy: Box<PyErrStateLazyFn>) {
}
}
}

#[cfg(test)]
mod tests {

use crate::{
exceptions::PyValueError, sync::GILOnceCell, PyErr, PyErrArguments, PyObject, Python,
};

#[test]
#[should_panic(expected = "Re-entrant normalization of PyErrState detected")]
fn test_reentrant_normalization() {
    // The error lives in a static so that its own argument conversion can reach back to it.
    static ERR: GILOnceCell<PyErr> = GILOnceCell::new();

    struct RecursiveArgs;

    impl PyErrArguments for RecursiveArgs {
        fn arguments(self, py: Python<'_>) -> PyObject {
            let err = ERR.get(py).expect("is set just below");
            // Asking for the value triggers normalization of the very error whose
            // arguments are being computed here.
            let value = err.value(py);
            value.clone().into()
        }
    }

    Python::with_gil(|py| {
        let stored = ERR.set(py, PyValueError::new_err(RecursiveArgs));
        stored.unwrap();

        // First normalization attempt; `arguments` above then re-enters on the same thread.
        let err = ERR.get(py).expect("is set just above");
        err.value(py);
    })
}

#[test]
#[cfg(not(target_arch = "wasm32"))] // We are building wasm Python with pthreads disabled
fn test_no_deadlock_thread_switch() {
    static ERR: GILOnceCell<PyErr> = GILOnceCell::new();

    struct GILSwitchArgs;

    impl PyErrArguments for GILSwitchArgs {
        fn arguments(self, py: Python<'_>) -> PyObject {
            // Dropping the GIL in the middle of normalization gives other threads a
            // chance to run and potentially deadlock against the normalization here.
            let pause = std::time::Duration::from_millis(10);
            py.allow_threads(|| std::thread::sleep(pause));
            py.None()
        }
    }

    Python::with_gil(|py| {
        let stored = ERR.set(py, PyValueError::new_err(GILSwitchArgs));
        stored.unwrap()
    });

    // Start many threads that all try to read the normalized value at the same time.
    let mut readers = Vec::with_capacity(10);
    for _ in 0..10 {
        let reader = std::thread::spawn(|| {
            Python::with_gil(|py| {
                let err = ERR.get(py).expect("is set just above");
                err.value(py);
            });
        });
        readers.push(reader);
    }

    for reader in readers {
        reader.join().unwrap();
    }

    // Reaching this point means nothing deadlocked, so the error can still be inspected.
    Python::with_gil(|py| {
        let err = ERR.get(py).expect("is set above");
        assert!(err.is_instance_of::<PyValueError>(py));
    });
}
}

0 comments on commit 9f955e4

Please sign in to comment.