Skip to content

Commit

Permalink
wrap PyClassObjectContents to prevent it from leaking outside pyo3
Browse files Browse the repository at this point in the history
  • Loading branch information
mbway committed Nov 3, 2024
1 parent ccce96d commit de2c2a3
Show file tree
Hide file tree
Showing 12 changed files with 92 additions and 58 deletions.
2 changes: 1 addition & 1 deletion src/impl_/coroutine.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ use std::{
use crate::{
coroutine::{cancel::ThrowCallback, Coroutine},
instance::Bound,
pycell::impl_::{InternalPyClassObjectLayout, PyClassBorrowChecker},
pycell::impl_::{PyClassObjectLayout, PyClassBorrowChecker},
pyclass::boolean_struct::False,
types::{PyAnyMethods, PyString},
IntoPyObject, Py, PyAny, PyClass, PyErr, PyResult, Python,
Expand Down
2 changes: 1 addition & 1 deletion src/impl_/pycell.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
//! Externally-accessible implementation of pycell
pub use crate::pycell::impl_::{
GetBorrowChecker, PyClassMutability, PyClassObjectBase, PyClassObjectLayout,
GetBorrowChecker, PyClassMutability, PyClassObjectBase, PyClassObjectBaseLayout,
PyStaticClassObject, PyVariableClassObject, PyVariableClassObjectBase,
};
10 changes: 5 additions & 5 deletions src/impl_/pyclass.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,11 +6,11 @@ use crate::{
ffi,
impl_::{
freelist::FreeList,
pycell::{GetBorrowChecker, PyClassMutability, PyClassObjectLayout},
pycell::{GetBorrowChecker, PyClassMutability, PyClassObjectBaseLayout},
pyclass_init::PyObjectInit,
pymethods::{PyGetterDef, PyMethodDefType},
},
pycell::{impl_::InternalPyClassObjectLayout, PyBorrowError},
pycell::{impl_::PyClassObjectLayout, PyBorrowError},
types::{any::PyAnyMethods, PyBool},
Borrowed, BoundObject, Py, PyAny, PyClass, PyErr, PyRef, PyResult, PyTypeInfo, Python,
};
Expand Down Expand Up @@ -170,7 +170,7 @@ pub trait PyClassImpl: Sized + 'static {
const IS_SEQUENCE: bool = false;

/// Description of how this class is laid out in memory
type Layout: InternalPyClassObjectLayout<Self>;
type Layout: PyClassObjectLayout<Self>;

/// Base class
type BaseType: PyTypeInfo + PyClassBaseType;
Expand Down Expand Up @@ -1137,7 +1137,7 @@ impl<T> PyClassThreadChecker<T> for ThreadCheckerImpl {
)
)]
pub trait PyClassBaseType: Sized {
type LayoutAsBase: PyClassObjectLayout<Self>;
type LayoutAsBase: PyClassObjectBaseLayout<Self>;
type BaseNativeType;
type Initializer: PyObjectInit<Self>;
type PyClassMutability: PyClassMutability;
Expand Down Expand Up @@ -1549,7 +1549,7 @@ where
let class_ptr = obj.cast::<<ClassT as PyClassImpl>::Layout>();
// Safety: the object `obj` must have the layout `ClassT::Layout`
let class_obj = unsafe { &mut *class_ptr };
let contents = class_obj.contents_mut() as *mut PyClassObjectContents<ClassT>;
let contents = (&mut class_obj.contents_mut().0) as *mut PyClassObjectContents<ClassT>;
(contents.cast::<u8>(), offset)
}
#[cfg(not(Py_3_12))]
Expand Down
8 changes: 6 additions & 2 deletions src/impl_/pyclass_init.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,15 +2,19 @@
use crate::ffi_ptr_ext::FfiPtrExt;
use crate::impl_::pyclass::PyClassImpl;
use crate::internal::get_slot::TP_ALLOC;
use crate::pycell::impl_::{InternalPyClassObjectLayout, PyClassObjectContents};
use crate::pycell::impl_::{
PyClassObjectContents, PyClassObjectLayout, WrappedPyClassObjectContents,
};
use crate::types::PyType;
use crate::{ffi, Borrowed, PyErr, PyResult, Python};
use crate::{ffi::PyTypeObject, sealed::Sealed, type_object::PyTypeInfo};
use std::marker::PhantomData;

pub unsafe fn initialize_with_default<T: PyClassImpl + Default>(obj: *mut ffi::PyObject) {
let contents = T::Layout::contents_uninitialised(obj);
(*contents).write(PyClassObjectContents::new(T::default()));
(*contents).write(WrappedPyClassObjectContents(PyClassObjectContents::new(
T::default(),
)));
}

/// Initializer for Python types.
Expand Down
10 changes: 5 additions & 5 deletions src/impl_/pymethods.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,9 +2,9 @@ use crate::exceptions::PyStopAsyncIteration;
use crate::gil::LockGIL;
use crate::impl_::callback::IntoPyCallbackOutput;
use crate::impl_::panic::PanicTrap;
use crate::impl_::pycell::PyClassObjectLayout;
use crate::impl_::pycell::PyClassObjectBaseLayout;
use crate::internal::get_slot::{get_slot, TP_BASE, TP_CLEAR, TP_TRAVERSE};
use crate::pycell::impl_::{InternalPyClassObjectLayout, PyClassBorrowChecker as _};
use crate::pycell::impl_::{PyClassBorrowChecker as _, PyClassObjectLayout};
use crate::pycell::{PyBorrowError, PyBorrowMutError};
use crate::pyclass::boolean_struct::False;
use crate::types::any::PyAnyMethods;
Expand Down Expand Up @@ -310,8 +310,8 @@ where
if class_object.check_threadsafe().is_ok()
// ... and we cannot traverse a type which might be being mutated by a Rust thread
&& class_object.borrow_checker().try_borrow().is_ok() {
struct TraverseGuard<'a, U: PyClassImpl, V: InternalPyClassObjectLayout<U>>(&'a V, PhantomData<U>);
impl<U: PyClassImpl, V: InternalPyClassObjectLayout<U>> Drop for TraverseGuard<'_, U, V> {
struct TraverseGuard<'a, U: PyClassImpl, V: PyClassObjectLayout<U>>(&'a V, PhantomData<U>);
impl<U: PyClassImpl, V: PyClassObjectLayout<U>> Drop for TraverseGuard<'_, U, V> {
fn drop(&mut self) {
self.0.borrow_checker().release_borrow()
}
Expand All @@ -320,7 +320,7 @@ where
// `.try_borrow()` above created a borrow, we need to release it when we're done
// traversing the object. This allows us to read `instance` safely.
let _guard = TraverseGuard(class_object, PhantomData);
let instance = &*class_object.contents().value.get();
let instance = &*class_object.contents().0.value.get();

let visit = PyVisit { visit, arg, _guard: PhantomData };

Expand Down
2 changes: 1 addition & 1 deletion src/instance.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@ use crate::conversion::IntoPyObject;
use crate::err::{self, PyErr, PyResult};
use crate::impl_::pyclass::PyClassImpl;
use crate::internal_tricks::ptr_from_ref;
use crate::pycell::impl_::InternalPyClassObjectLayout;
use crate::pycell::impl_::PyClassObjectLayout;
use crate::pycell::{PyBorrowError, PyBorrowMutError};
use crate::pyclass::boolean_struct::{False, True};
use crate::types::{any::PyAnyMethods, string::PyStringMethods, typeobject::PyTypeMethods};
Expand Down
2 changes: 1 addition & 1 deletion src/pycell.rs
Original file line number Diff line number Diff line change
Expand Up @@ -208,7 +208,7 @@ use std::mem::ManuallyDrop;
use std::ops::{Deref, DerefMut};

pub(crate) mod impl_;
use impl_::{InternalPyClassObjectLayout, PyClassBorrowChecker, PyClassObjectLayout};
use impl_::{PyClassBorrowChecker, PyClassObjectBaseLayout, PyClassObjectLayout};

/// A wrapper type for an immutably borrowed value from a [`Bound<'py, T>`].
///
Expand Down
96 changes: 62 additions & 34 deletions src/pycell/impl_.rs
Original file line number Diff line number Diff line change
Expand Up @@ -187,13 +187,13 @@ pub trait GetBorrowChecker<T: PyClassImpl> {

impl<T: PyClassImpl<PyClassMutability = Self>> GetBorrowChecker<T> for MutableClass {
fn borrow_checker(class_object: &T::Layout) -> &BorrowChecker {
&class_object.contents().borrow_checker
&class_object.contents().0.borrow_checker
}
}

impl<T: PyClassImpl<PyClassMutability = Self>> GetBorrowChecker<T> for ImmutableClass {
fn borrow_checker(class_object: &T::Layout) -> &EmptySlot {
&class_object.contents().borrow_checker
&class_object.contents().0.borrow_checker
}
}

Expand Down Expand Up @@ -226,7 +226,7 @@ pub struct PyVariableClassObjectBase {

unsafe impl<T> PyLayout<T> for PyVariableClassObjectBase {}

impl<T: PyTypeInfo> PyClassObjectLayout<T> for PyVariableClassObjectBase {
impl<T: PyTypeInfo> PyClassObjectBaseLayout<T> for PyVariableClassObjectBase {
fn ensure_threadsafe(&self) {}
fn check_threadsafe(&self) -> Result<(), PyBorrowError> {
Ok(())
Expand All @@ -237,7 +237,7 @@ impl<T: PyTypeInfo> PyClassObjectLayout<T> for PyVariableClassObjectBase {
}

#[doc(hidden)]
pub trait PyClassObjectLayout<T>: PyLayout<T> {
pub trait PyClassObjectBaseLayout<T>: PyLayout<T> {
fn ensure_threadsafe(&self);
fn check_threadsafe(&self) -> Result<(), PyBorrowError>;
/// Implementation of tp_dealloc.
Expand All @@ -247,6 +247,30 @@ pub trait PyClassObjectLayout<T>: PyLayout<T> {
unsafe fn tp_dealloc(py: Python<'_>, slf: *mut ffi::PyObject);
}

/// Allow [PyClassObjectLayout] to have public visibility without leaking the structure of [PyClassObjectContents].
#[doc(hidden)]
#[repr(transparent)]
pub struct WrappedPyClassObjectContents<T: PyClassImpl>(pub(crate) PyClassObjectContents<T>);

impl<'a, T: PyClassImpl> From<&'a PyClassObjectContents<T>>
for &'a WrappedPyClassObjectContents<T>
{
fn from(value: &'a PyClassObjectContents<T>) -> &'a WrappedPyClassObjectContents<T> {
// Safety: Wrapped struct must use repr(transparent)
unsafe { std::mem::transmute(value) }
}
}

impl<'a, T: PyClassImpl> From<&'a mut PyClassObjectContents<T>>
for &'a mut WrappedPyClassObjectContents<T>
{
fn from(value: &'a mut PyClassObjectContents<T>) -> &'a mut WrappedPyClassObjectContents<T> {
// Safety: Wrapped struct must use repr(transparent)
unsafe { std::mem::transmute(value) }
}
}

/// Functionality required for creating and managing the memory associated with a pyclass annotated struct.
#[doc(hidden)]
#[cfg_attr(
all(diagnostic_namespace),
Expand All @@ -256,18 +280,18 @@ pub trait PyClassObjectLayout<T>: PyLayout<T> {
note = "the python version being built against influences which layouts are valid",
)
)]
pub trait InternalPyClassObjectLayout<T: PyClassImpl>: PyClassObjectLayout<T> {
pub trait PyClassObjectLayout<T: PyClassImpl>: PyClassObjectBaseLayout<T> {
/// Obtain a pointer to the contents of an uninitialized PyObject of this type
/// Safety: the provided object must have the layout that the implementation is expecting
unsafe fn contents_uninitialised(
obj: *mut ffi::PyObject,
) -> *mut MaybeUninit<PyClassObjectContents<T>>;
) -> *mut MaybeUninit<WrappedPyClassObjectContents<T>>;

fn get_ptr(&self) -> *mut T;

fn contents(&self) -> &PyClassObjectContents<T>;
fn contents(&self) -> &WrappedPyClassObjectContents<T>;

fn contents_mut(&mut self) -> &mut PyClassObjectContents<T>;
fn contents_mut(&mut self) -> &mut WrappedPyClassObjectContents<T>;

fn ob_base(&self) -> &<T::BaseType as PyClassBaseType>::LayoutAsBase;

Expand All @@ -287,7 +311,7 @@ pub trait InternalPyClassObjectLayout<T: PyClassImpl>: PyClassObjectLayout<T> {
fn borrow_checker(&self) -> &<T::PyClassMutability as PyClassMutability>::Checker;
}

impl<T, U> PyClassObjectLayout<T> for PyClassObjectBase<U>
impl<T, U> PyClassObjectBaseLayout<T> for PyClassObjectBase<U>
where
U: PySizedLayout<T>,
T: PyTypeInfo,
Expand All @@ -301,6 +325,10 @@ where
}
}

/// Implementation of tp_dealloc.
/// # Safety
/// - obj must be a valid pointer to an instance of the type at `type_ptr` or a subclass.
/// - obj must not be used after this call (as it will be freed).
unsafe fn tp_dealloc(py: Python<'_>, obj: *mut ffi::PyObject, type_ptr: *mut ffi::PyTypeObject) {
// FIXME: there is potentially subtle issues here if the base is overwritten
// at runtime? To be investigated.
Expand Down Expand Up @@ -376,14 +404,14 @@ pub struct PyStaticClassObject<T: PyClassImpl> {
contents: PyClassObjectContents<T>,
}

impl<T: PyClassImpl> InternalPyClassObjectLayout<T> for PyStaticClassObject<T> {
impl<T: PyClassImpl> PyClassObjectLayout<T> for PyStaticClassObject<T> {
unsafe fn contents_uninitialised(
obj: *mut ffi::PyObject,
) -> *mut MaybeUninit<PyClassObjectContents<T>> {
) -> *mut MaybeUninit<WrappedPyClassObjectContents<T>> {
#[repr(C)]
struct PartiallyInitializedClassObject<T: PyClassImpl> {
_ob_base: <T::BaseType as PyClassBaseType>::LayoutAsBase,
contents: MaybeUninit<PyClassObjectContents<T>>,
contents: MaybeUninit<WrappedPyClassObjectContents<T>>,
}
let obj: *mut PartiallyInitializedClassObject<T> = obj.cast();
addr_of_mut!((*obj).contents)
Expand All @@ -397,12 +425,12 @@ impl<T: PyClassImpl> InternalPyClassObjectLayout<T> for PyStaticClassObject<T> {
&self.ob_base
}

fn contents(&self) -> &PyClassObjectContents<T> {
&self.contents
fn contents(&self) -> &WrappedPyClassObjectContents<T> {
(&self.contents).into()
}

fn contents_mut(&mut self) -> &mut PyClassObjectContents<T> {
&mut self.contents
fn contents_mut(&mut self) -> &mut WrappedPyClassObjectContents<T> {
(&mut self.contents).into()
}

/// used to set PyType_Spec::basicsize
Expand Down Expand Up @@ -453,9 +481,9 @@ impl<T: PyClassImpl> InternalPyClassObjectLayout<T> for PyStaticClassObject<T> {
unsafe impl<T: PyClassImpl> PyLayout<T> for PyStaticClassObject<T> {}
impl<T: PyClass> PySizedLayout<T> for PyStaticClassObject<T> {}

impl<T: PyClassImpl> PyClassObjectLayout<T> for PyStaticClassObject<T>
impl<T: PyClassImpl> PyClassObjectBaseLayout<T> for PyStaticClassObject<T>
where
<T::BaseType as PyClassBaseType>::LayoutAsBase: PyClassObjectLayout<T::BaseType>,
<T::BaseType as PyClassBaseType>::LayoutAsBase: PyClassObjectBaseLayout<T::BaseType>,
{
fn ensure_threadsafe(&self) {
self.contents.thread_checker.ensure();
Expand All @@ -470,7 +498,7 @@ where
unsafe fn tp_dealloc(py: Python<'_>, slf: *mut ffi::PyObject) {
// Safety: Python only calls tp_dealloc when no references to the object remain.
let class_object = &mut *(slf.cast::<T::Layout>());
class_object.contents_mut().dealloc(py, slf);
class_object.contents_mut().0.dealloc(py, slf);
<T::BaseType as PyClassBaseType>::LayoutAsBase::tp_dealloc(py, slf)
}
}
Expand All @@ -482,41 +510,41 @@ pub struct PyVariableClassObject<T: PyClassImpl> {

impl<T: PyClassImpl> PyVariableClassObject<T> {
#[cfg(Py_3_12)]
fn get_contents_of_obj(obj: *mut ffi::PyObject) -> *mut PyClassObjectContents<T> {
fn get_contents_of_obj(obj: *mut ffi::PyObject) -> *mut WrappedPyClassObjectContents<T> {
// https://peps.python.org/pep-0697/
let type_obj = unsafe { ffi::Py_TYPE(obj) };
let pointer = unsafe { ffi::PyObject_GetTypeData(obj, type_obj) };
pointer as *mut PyClassObjectContents<T>
pointer as *mut WrappedPyClassObjectContents<T>
}

#[cfg(Py_3_12)]
fn get_contents_ptr(&self) -> *mut PyClassObjectContents<T> {
fn get_contents_ptr(&self) -> *mut WrappedPyClassObjectContents<T> {
Self::get_contents_of_obj(self as *const PyVariableClassObject<T> as *mut ffi::PyObject)
}
}

#[cfg(Py_3_12)]
impl<T: PyClassImpl> InternalPyClassObjectLayout<T> for PyVariableClassObject<T> {
impl<T: PyClassImpl> PyClassObjectLayout<T> for PyVariableClassObject<T> {
unsafe fn contents_uninitialised(
obj: *mut ffi::PyObject,
) -> *mut MaybeUninit<PyClassObjectContents<T>> {
Self::get_contents_of_obj(obj) as *mut MaybeUninit<PyClassObjectContents<T>>
) -> *mut MaybeUninit<WrappedPyClassObjectContents<T>> {
Self::get_contents_of_obj(obj) as *mut MaybeUninit<WrappedPyClassObjectContents<T>>
}

fn get_ptr(&self) -> *mut T {
self.contents().value.get()
self.contents().0.value.get()
}

fn ob_base(&self) -> &<T::BaseType as PyClassBaseType>::LayoutAsBase {
&self.ob_base
}

fn contents(&self) -> &PyClassObjectContents<T> {
unsafe { (self.get_contents_ptr() as *const PyClassObjectContents<T>).as_ref() }
fn contents(&self) -> &WrappedPyClassObjectContents<T> {
unsafe { self.get_contents_ptr().cast_const().as_ref() }
.expect("should be able to cast PyClassObjectContents pointer")
}

fn contents_mut(&mut self) -> &mut PyClassObjectContents<T> {
fn contents_mut(&mut self) -> &mut WrappedPyClassObjectContents<T> {
unsafe { self.get_contents_ptr().as_mut() }
.expect("should be able to cast PyClassObjectContents pointer")
}
Expand Down Expand Up @@ -560,24 +588,24 @@ impl<T: PyClassImpl> InternalPyClassObjectLayout<T> for PyVariableClassObject<T>
unsafe impl<T: PyClassImpl> PyLayout<T> for PyVariableClassObject<T> {}

#[cfg(Py_3_12)]
impl<T: PyClassImpl> PyClassObjectLayout<T> for PyVariableClassObject<T>
impl<T: PyClassImpl> PyClassObjectBaseLayout<T> for PyVariableClassObject<T>
where
<T::BaseType as PyClassBaseType>::LayoutAsBase: PyClassObjectLayout<T::BaseType>,
<T::BaseType as PyClassBaseType>::LayoutAsBase: PyClassObjectBaseLayout<T::BaseType>,
{
fn ensure_threadsafe(&self) {
self.contents().thread_checker.ensure();
self.contents().0.thread_checker.ensure();
self.ob_base.ensure_threadsafe();
}
fn check_threadsafe(&self) -> Result<(), PyBorrowError> {
if !self.contents().thread_checker.check() {
if !self.contents().0.thread_checker.check() {
return Err(PyBorrowError { _private: () });
}
self.ob_base.check_threadsafe()
}
unsafe fn tp_dealloc(py: Python<'_>, slf: *mut ffi::PyObject) {
// Safety: Python only calls tp_dealloc when no references to the object remain.
let class_object = &mut *(slf.cast::<T::Layout>());
class_object.contents_mut().dealloc(py, slf);
class_object.contents_mut().0.dealloc(py, slf);
<T::BaseType as PyClassBaseType>::LayoutAsBase::tp_dealloc(py, slf)
}
}
Expand Down
2 changes: 1 addition & 1 deletion src/pyclass/create_type_object.rs
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ use crate::{
trampoline::trampoline,
},
internal_tricks::ptr_from_ref,
pycell::impl_::InternalPyClassObjectLayout,
pycell::impl_::PyClassObjectLayout,
types::{typeobject::PyTypeMethods, PyType},
Py, PyClass, PyResult, PyTypeInfo, Python,
};
Expand Down
Loading

0 comments on commit de2c2a3

Please sign in to comment.