Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Provide an explicit ComObject<T> type that represents a heap-allocated COM object #3043

Merged
merged 13 commits into from
May 16, 2024
Merged
Show file tree
Hide file tree
Changes from 10 commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 11 additions & 1 deletion crates/libs/core/src/as_impl.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,5 +6,15 @@ pub trait AsImpl<T> {
///
/// The caller needs to ensure that `self` is actually implemented by the
/// implementation `T`.
unsafe fn as_impl(&self) -> &T;
unsafe fn as_impl(&self) -> &T {
self.as_impl_ptr().as_ref()
}

/// Returns a pointer to the implementation object.
///
/// # Safety
///
/// The caller needs to ensure that `self` is actually implemented by the
/// implementation `T`.
unsafe fn as_impl_ptr(&self) -> core::ptr::NonNull<T>;
}
281 changes: 281 additions & 0 deletions crates/libs/core/src/com_object.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,281 @@
use crate::{AsImpl, IUnknown, IUnknownImpl, Interface, InterfaceRef};
use core::borrow::Borrow;
use core::ops::Deref;
use core::ptr::NonNull;

/// Identifies types that can be placed in [`ComObject`].
///
/// This trait links types that can be placed in `ComObject` with the types generated by the
/// `#[implement]` macro. The `#[implement]` macro generates implementations of this trait.
/// The generated types contain the vtable layouts and refcount-related fields for the COM
/// object implementation.
///
/// This trait is an implementation detail of the Windows crates.
/// User code should not deal directly with this trait.
///
/// This trait is sort of the reverse of [`IUnknownImpl`]. This trait allows user code to use
/// `ComObject<T>] instead of `ComObject<T_Impl>`.
pub trait ComObjectInner {
/// The generated `<foo>_Impl` type (aka the "boxed" type or "outer" type).
type Outer: IUnknownImpl<Impl = Self>;
}

/// Describes the COM interfaces implemented by a specific COM object.
///
/// The `#[implement]` macro generates implementations of this trait. Implementations are attached
/// to the "outer" types generated by `#[implemented]`, e.g. the `MyApp_Impl` type. Each
sivadeilra marked this conversation as resolved.
Show resolved Hide resolved
/// implementation knows how to locate the interface-specific field within `MyApp_Impl`.
///
/// This trait is an implementation detail of the Windows crates.
/// User code should not deal directly with this trait.
pub trait ComObjectInterface<I: Interface> {
/// Gets a borrowed interface that is implemented by `T`.
fn as_interface_ref(&self) -> InterfaceRef<'_, I>;
}

/// A counted pointer to a type that implements COM interfaces, where the object has been
/// placed in the heap (boxed).
///
/// This type exists so that you can place an object into the heap and query for COM interfaces,
/// without losing the safe reference to the implementation object.
///
/// Because the pointer inside this type is known to be non-null, `Option<ComObject<T>>` should
/// always have the same size as a single pointer.
///
/// # Safety
///
/// The contained `ptr` field is an owned, reference-counted pointer to a _pinned_ `Pin<Box<T::Outer>>`.
/// Although this code does not currently use `Pin<T>`, it takes care not to expose any unsafe semantics
/// to safe code. However, code that calls unsafe functions on [`ComObject`] must, like all unsafe code,
/// understand and preserve invariants.
#[repr(transparent)]
pub struct ComObject<T: ComObjectInner> {
ptr: NonNull<T::Outer>,
}

impl<T: ComObjectInner> ComObject<T> {
/// Allocates a heap cell (box) and moves `value` into it. Returns a counted pointer to `value`.
pub fn new(value: T) -> Self {
unsafe {
let box_ = T::Outer::new_box(value);
Self { ptr: NonNull::new_unchecked(Box::into_raw(box_)) }
}
}

/// Gets a reference to the shared object stored in the box.
///
/// [`ComObject`] also implements [`Deref`], so you can often deref directly into the object.
/// For those situations where using the [`Deref`] impl is inconvenient, you can use
/// this method to explicitly get a reference to the contents.
#[inline(always)]
pub fn get(&self) -> &T {
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Taking a step back, how useful is ComObject outside the context of object construction? Since get_mut necessarily can't be used without exclusive ownership, it seems as if ComObject would only ever be used for object construction. If that is the case, I'm wondering whether anything can be simplified about this design. It is not overly complicated as it is, just something to consider.

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

You can of course still use get to retrieve a synchronization primitive like RwLock stored within the implementation and go from there. So this could well be used during the life cycle of the object. Is that how you see it being used?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It's useful outside of just construction. A common pattern in DWriteCore is that we create a bunch of COM objects, and hand out interfaces to the client process, but we also need to keep strongly-typed references to the contents. Right now I'm doing that with the equivalent that is in the old com crate, and having ComObject in windows_core gets me that much closer to converting DWriteCore to use the Windows crates.

get_mut() itself is mainly useful during construction. You create a ComObject, then get access to its guts while you're still the only reference holder. You could get the same effect with RefCell, but since we already have the reference count in ComObject, why not use it? It's basically free.

self.get_box().get_impl()
}

/// Gets a reference to the shared object's heap box.
#[inline(always)]
pub fn get_box(&self) -> &T::Outer {
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Seems like an implementation detail - does it need to be public?

unsafe { self.ptr.as_ref() }
}

// Note that we _do not_ provide a way to get a mutable reference to the outer box.
// It's ok to return `&mut T`, but not `&mut T::Outer`. That would allow someone to replace the
// contents of the entire object (box and reference count), which could lead to UB.
// This could maybe be solved by returning `Pin<&mut T::Outer>`, but that requires some
// additional thinking.

/// Gets a mutable reference to the object stored in the box, if the reference count
/// is exactly 1. If there are multiple references to this object then this returns `None`.
#[inline(always)]
pub fn get_mut(&mut self) -> Option<&mut T> {
sivadeilra marked this conversation as resolved.
Show resolved Hide resolved
if self.is_reference_count_one() {
// SAFETY: We must only return &mut T, *NOT* &mut T::Outer.
// Returning T::Outer would allow swapping the contents of the object, which would
// allow (incorrectly) modifying the reference count.
unsafe { Some(self.ptr.as_mut().get_impl_mut()) }
} else {
None
}
}

/// Returns `true` if this reference is the only reference to the `ComObject`.
#[inline(always)]
pub fn is_exclusive_reference(&self) -> bool {
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Seems like an implementation detail - does it need to be public?

self.get_box().is_reference_count_one()
}

/// If this object has only a single object reference (i.e. this [`ComObject`] is the only
/// reference to the heap allocation), then this method will extract the inner `T`
/// (and return it in an `Ok`) and then free the heap allocation.
///
/// If there is more than one reference to this object, then this returns `Err(self)`.
#[inline(always)]
pub fn take(self) -> Result<T, Self> {
if self.is_exclusive_reference() {
let outer_box: Box<T::Outer> = unsafe { core::mem::transmute(self) };
Ok(outer_box.into_inner())
} else {
Err(self)
}
}

/// Casts to a given interface type.
///
/// This always performs a `QueryInterface`, even if `T` is known to implement `I`.
/// If you know that `T` implements `I`, then use [`Self::as_interface`] or [`Self::to_interface`] because
/// those functions do not require a dynamic `QueryInterface` call.
#[inline(always)]
pub fn cast<I: Interface>(&self) -> windows_core::Result<I>
where
T::Outer: ComObjectInterface<IUnknown>,
{
let unknown = self.as_interface::<IUnknown>();
unknown.cast()
}

/// Gets a borrowed reference to an interface that is implemented by `T`.
///
/// The returned reference does not have an additional reference count.
/// You can AddRef it by calling [`Self::to_owned`].
#[inline(always)]
pub fn as_interface<I: Interface>(&self) -> InterfaceRef<'_, I>
where
T::Outer: ComObjectInterface<I>,
{
self.get_box().as_interface_ref()
}

/// Gets an owned (counted) reference to an interface that is implemented by this [`ComObject`].
#[inline(always)]
pub fn to_interface<I: Interface>(&self) -> I
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

How useful is this in practice? Seems like it wouldn't be much harder just writing as_interface().to_owned().

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It's just a convenience. I think we'll need to use this stuff a bunch to see what's the most useful.

where
T::Outer: ComObjectInterface<I>,
{
self.as_interface::<I>().to_owned()
}

/// Converts `self` into an interface that it implements.
///
/// This does not need to adjust reference counts because `self` is consumed.
#[inline(always)]
pub fn into_interface<I: Interface>(self) -> I
sivadeilra marked this conversation as resolved.
Show resolved Hide resolved
where
T::Outer: ComObjectInterface<I>,
{
unsafe {
let raw = self.get_box().as_interface_ref().as_raw();
core::mem::forget(self);
I::from_raw(raw)
}
}
}

impl<T: ComObjectInner + Default> Default for ComObject<T> {
fn default() -> Self {
Self::new(T::default())
}
}

impl<T: ComObjectInner> Drop for ComObject<T> {
fn drop(&mut self) {
unsafe {
T::Outer::Release(self.ptr.as_ptr());
}
}
}

impl<T: ComObjectInner> Clone for ComObject<T> {
#[inline(always)]
fn clone(&self) -> Self {
unsafe {
self.ptr.as_ref().AddRef();
Self { ptr: self.ptr }
}
}
}

impl<T: ComObjectInner> AsRef<T> for ComObject<T>
where
IUnknown: From<T> + AsImpl<T>,
{
#[inline(always)]
fn as_ref(&self) -> &T {
self.get()
}
}

impl<T: ComObjectInner> Deref for ComObject<T> {
type Target = T::Outer;

#[inline(always)]
fn deref(&self) -> &Self::Target {
self.get_box()
}
}

// There is no DerefMut implementation because we cannot statically guarantee
// that the reference count is 1, which is a requirement for getting exclusive
// access to the contents of the object. Use get_mut() for dynamically-checked
// exclusive access.

impl<T: ComObjectInner> From<T> for ComObject<T> {
fn from(value: T) -> ComObject<T> {
ComObject::new(value)
}
}

// Delegate hashing, if implemented.
impl<T: ComObjectInner + core::hash::Hash> core::hash::Hash for ComObject<T> {
fn hash<H: core::hash::Hasher>(&self, state: &mut H) {
self.get().hash(state);
}
}

// If T is Send (or Sync) then the ComObject<T> is also Send (or Sync).
// Since the actual object storage is in the heap, the object is never moved.
unsafe impl<T: ComObjectInner + Send> Send for ComObject<T> {}
unsafe impl<T: ComObjectInner + Sync> Sync for ComObject<T> {}

impl<T: ComObjectInner + PartialEq> PartialEq for ComObject<T> {
fn eq(&self, other: &ComObject<T>) -> bool {
let inner_self: &T = self.get();
let other_self: &T = other.get();
inner_self == other_self
}
}

impl<T: ComObjectInner + Eq> Eq for ComObject<T> {}

impl<T: ComObjectInner + PartialOrd> PartialOrd for ComObject<T> {
fn partial_cmp(&self, other: &Self) -> Option<core::cmp::Ordering> {
let inner_self: &T = self.get();
let other_self: &T = other.get();
<T as PartialOrd>::partial_cmp(inner_self, other_self)
}
}

impl<T: ComObjectInner + Ord> Ord for ComObject<T> {
fn cmp(&self, other: &Self) -> core::cmp::Ordering {
let inner_self: &T = self.get();
let other_self: &T = other.get();
<T as Ord>::cmp(inner_self, other_self)
}
}

impl<T: ComObjectInner + core::fmt::Debug> core::fmt::Debug for ComObject<T> {
fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
<T as core::fmt::Debug>::fmt(self.get(), f)
}
}

impl<T: ComObjectInner + core::fmt::Display> core::fmt::Display for ComObject<T> {
fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
<T as core::fmt::Display>::fmt(self.get(), f)
}
}

impl<T: ComObjectInner> Borrow<T> for ComObject<T> {
fn borrow(&self) -> &T {
self.get()
}
}
3 changes: 3 additions & 0 deletions crates/libs/core/src/imp/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -98,3 +98,6 @@ macro_rules! define_interface {

#[doc(hidden)]
pub use define_interface;

#[doc(hidden)]
pub use std::boxed::Box;
5 changes: 5 additions & 0 deletions crates/libs/core/src/imp/weak_ref_count.rs
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,11 @@ impl WeakRefCount {
self.0.fetch_update(Ordering::Relaxed, Ordering::Relaxed, |count_or_pointer| bool::then_some(!is_weak_ref(count_or_pointer), count_or_pointer + 1)).map(|u| u as u32 + 1).unwrap_or_else(|pointer| unsafe { TearOff::decode(pointer).strong_count.add_ref() })
}

#[inline(always)]
pub fn is_one(&self) -> bool {
self.0.load(Ordering::Acquire) == 1
}

pub fn release(&self) -> u32 {
self.0.fetch_update(Ordering::Release, Ordering::Relaxed, |count_or_pointer| bool::then_some(!is_weak_ref(count_or_pointer), count_or_pointer - 1)).map(|u| u as u32 - 1).unwrap_or_else(|pointer| unsafe {
let tear_off = TearOff::decode(pointer);
Expand Down
Loading
Loading