make PyErrState thread-safe #4671

Merged
merged 8 commits into from
Nov 5, 2024
Changes from 2 commits
78 changes: 53 additions & 25 deletions src/err/err_state.rs
@@ -1,4 +1,8 @@
-use std::cell::UnsafeCell;
+use std::{
+    cell::UnsafeCell,
+    sync::{Mutex, Once},
+    thread::ThreadId,
+};

 use crate::{
     exceptions::{PyBaseException, PyTypeError},
@@ -11,13 +15,14 @@ use crate::{
 pub(crate) struct PyErrState {
     // Safety: can only hand out references when in the "normalized" state. Will never change
     // after normalization.
-    //
-    // The state is temporarily removed from the PyErr during normalization, to avoid
-    // concurrent modifications.
+    normalized: Once,
+    // Guard against re-entrancy when normalizing the exception state.
+    normalizing_thread: Mutex<Option<ThreadId>>,
     inner: UnsafeCell<Option<PyErrStateInner>>,
 }

-// The inner value is only accessed through ways that require the gil is held.
+// Safety: The inner value is protected by locking to ensure that only the normalized state is
+// handed out as a reference.
 unsafe impl Send for PyErrState {}
 unsafe impl Sync for PyErrState {}

@@ -48,17 +53,22 @@ impl PyErrState {

     fn from_inner(inner: PyErrStateInner) -> Self {
         Self {
+            normalized: Once::new(),
+            normalizing_thread: Mutex::new(None),
             inner: UnsafeCell::new(Some(inner)),
         }
     }

     #[inline]
     pub(crate) fn as_normalized(&self, py: Python<'_>) -> &PyErrStateNormalized {
-        if let Some(PyErrStateInner::Normalized(n)) = unsafe {
-            // Safety: self.inner will never be written again once normalized.
-            &*self.inner.get()
-        } {
-            return n;
+        if self.normalized.is_completed() {
+            match unsafe {
+                // Safety: self.inner will never be written again once normalized.
+                &*self.inner.get()
+            } {
+                Some(PyErrStateInner::Normalized(n)) => return n,
+                _ => unreachable!(),
+            }
         }

         self.make_normalized(py)
@@ -69,25 +79,43 @@ impl PyErrState {
         // This process is safe because:
         // - Access is guaranteed not to be concurrent thanks to `Python` GIL token
         // - Write happens only once, and then never will change again.
-        // - State is set to None during the normalization process, so that a second
-        //   concurrent normalization attempt will panic before changing anything.

-        // FIXME: this needs to be rewritten to deal with free-threaded Python
-        // see https://github.com/PyO3/pyo3/issues/4584
+        // Guard against re-entrant normalization, because `Once` does not provide
+        // re-entrancy guarantees.
+        if let Some(thread) = self.normalizing_thread.lock().unwrap().as_ref() {
+            assert!(
+                !(*thread == std::thread::current().id()),
+                "Re-entrant normalization of PyErrState detected"
+            );
ngoldbaum marked this conversation as resolved.
+        }

-        let state = unsafe {
-            (*self.inner.get())
-                .take()
-                .expect("Cannot normalize a PyErr while already normalizing it.")
-        };
+        self.normalized.call_once(|| {
+            self.normalizing_thread
+                .lock()
+                .unwrap()
+                .replace(std::thread::current().id());
+
+            // Safety: no other thread can access the inner value while we are normalizing it.
+            let state = unsafe {
+                (*self.inner.get())
+                    .take()
+                    .expect("Cannot normalize a PyErr while already normalizing it.")
+            };

-        unsafe {
-            let self_state = &mut *self.inner.get();
-            *self_state = Some(PyErrStateInner::Normalized(state.normalize(py)));
-            match self_state {
-                Some(PyErrStateInner::Normalized(n)) => n,
-                _ => unreachable!(),
+            let normalized_state = PyErrStateInner::Normalized(state.normalize(py));
+
Contributor

I think the only spot where there might be a deadlock is here, if normalize somehow leads to arbitrary Python code execution.

Is that possible? If not I think it deserves a comment explaining why.

Contributor

If it can deadlock, I'm not sure what we can do, since at this point we haven't actually constructed any Python objects yet and we only have a handle to an FnOnce that knows how to construct them.

Member Author

Great observation; I've added a wrapping call to py.allow_threads before potentially blocking on the Once, which I think avoids the deadlock (I pushed a test which did deadlock before that change).
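
The deadlock discussed in this thread can be sketched with the standard library alone. This is a minimal illustration, not PyO3's code: `FakeGil` and `call_once_unlocked` are hypothetical names, and a plain `Mutex` stands in for the GIL. The idea is that a thread must not block on the `Once` while it still holds the lock that the initializing thread may need.

```rust
use std::sync::{Mutex, MutexGuard, Once};

/// Hypothetical stand-in for the GIL: a single lock that every thread must hold
/// while it touches shared interpreter state.
pub struct FakeGil(Mutex<()>);

impl FakeGil {
    pub fn new() -> Self {
        FakeGil(Mutex::new(()))
    }

    pub fn acquire(&self) -> MutexGuard<'_, ()> {
        self.0.lock().unwrap()
    }
}

/// Runs `init` at most once, never blocking on `once` while the lock is held.
///
/// The caller's guard is taken by value and dropped before `call_once`, so a
/// waiting thread cannot starve the thread that is running `init`. The
/// initializer re-acquires the lock itself, and a fresh guard is returned.
pub fn call_once_unlocked<'a>(
    gil: &'a FakeGil,
    held: MutexGuard<'a, ()>,
    once: &Once,
    init: impl FnOnce(),
) -> MutexGuard<'a, ()> {
    // Release the lock first (the role `allow_threads` plays in the comment above).
    drop(held);
    once.call_once(|| {
        // Only the single initializing thread reaches this point.
        let _guard = gil.acquire();
        init();
    });
    // Give the caller its lock back before returning.
    gil.acquire()
}
```

If the guard were kept across `call_once` instead, a second thread could sit blocked on the `Once` while holding the lock, and the initializing thread would wait for that lock indefinitely.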

+            // Safety: no other thread can access the inner value while we are normalizing it.
+            unsafe {
+                *self.inner.get() = Some(normalized_state);
             }
+        });
+
+        match unsafe {
+            // Safety: self.inner will never be written again once normalized.
+            &*self.inner.get()
+        } {
+            Some(PyErrStateInner::Normalized(n)) => n,
+            _ => unreachable!(),
         }
     }
 }
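
For readers who want the pattern in isolation, here is a minimal std-only sketch of the approach this diff takes: a `Once` gates a single write into an `UnsafeCell`, and a `Mutex<Option<ThreadId>>` records which thread is initializing so that re-entrancy fails with a panic. `WriteOnce` and `get_or_init` are hypothetical names chosen for illustration; they are not part of PyO3, and the GIL handling from the PR is left out.

```rust
use std::cell::UnsafeCell;
use std::sync::{Mutex, Once};
use std::thread::{self, ThreadId};

/// Hypothetical write-once cell mirroring the fields added to `PyErrState`.
pub struct WriteOnce<T> {
    done: Once,
    // Thread currently running the initializer, used to detect re-entrancy.
    initializing_thread: Mutex<Option<ThreadId>>,
    value: UnsafeCell<Option<T>>,
}

// Safety: `value` is written only inside `call_once`, and shared references to
// it are handed out only after `done` has completed.
unsafe impl<T: Send + Sync> Sync for WriteOnce<T> {}

impl<T> WriteOnce<T> {
    pub fn new() -> Self {
        Self {
            done: Once::new(),
            initializing_thread: Mutex::new(None),
            value: UnsafeCell::new(None),
        }
    }

    pub fn get_or_init(&self, init: impl FnOnce() -> T) -> &T {
        if !self.done.is_completed() {
            // `Once` gives no guarantee for re-entrant `call_once`, so panic
            // if this thread is already inside the initializer.
            if let Some(id) = self.initializing_thread.lock().unwrap().as_ref() {
                assert!(
                    *id != thread::current().id(),
                    "re-entrant initialization detected"
                );
            }
            self.done.call_once(|| {
                self.initializing_thread
                    .lock()
                    .unwrap()
                    .replace(thread::current().id());
                let v = init();
                // Safety: `call_once` runs this closure on exactly one thread,
                // and no reference to `value` exists before `done` completes.
                unsafe { *self.value.get() = Some(v) };
            });
        }
        // Safety: `done` is complete, so `value` is never written again.
        match unsafe { &*self.value.get() } {
            Some(v) => v,
            None => unreachable!(),
        }
    }
}
```

A second call to `get_or_init` returns the stored value without running its closure, which corresponds to the fast path in `as_normalized` above.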