Skip to content

Commit

Permalink
PR feedback
Browse files Browse the repository at this point in the history
  • Loading branch information
Arlie Davis committed May 15, 2024
1 parent ec29c03 commit f72fd9e
Show file tree
Hide file tree
Showing 4 changed files with 84 additions and 48 deletions.
85 changes: 46 additions & 39 deletions crates/libs/core/src/com_object.rs
Original file line number Diff line number Diff line change
@@ -1,8 +1,9 @@
use crate::{AsImpl, IUnknown, IUnknownImpl, Interface, InterfaceRef};
use core::borrow::Borrow;
use core::ops::Deref;
use core::ptr::NonNull;
use std::borrow::Borrow;

/// Identifies types that can be placed in `ComObject`.
/// Identifies types that can be placed in [`ComObject`].
///
/// This trait links types that can be placed in `ComObject` with the types generated by the
/// `#[implement]` macro. The `#[implement]` macro generates implementations of this trait.
Expand All @@ -11,20 +12,24 @@ use std::borrow::Borrow;
///
/// This trait is an implementation detail of the Windows crates.
/// User code should not deal directly with this trait.
pub trait ComImpl {
///
/// This trait is sort of the reverse of [`IUnknownImpl`]. This trait allows user code to use
/// `ComObject<T>] instead of `ComObject<T_Impl>`.
pub trait ComObjectInner {
/// The generated `<foo>_Impl` type (aka the "boxed" type or "outer" type).
type Outer: IUnknownImpl<Impl = Self>;
}

/// Describes the COM interfaces that a specific ComObject implements.
/// This trait is implemented by ComObject implementation object (e.g. `MyApp_Impl`).
/// Describes the COM interfaces implemented by a specific COM object.
///
/// The `#[implement]` macro generates implementations of this trait.
/// The `#[implement]` macro generates implementations of this trait. Implementations are attached
/// to the "outer" types generated by `#[implemented]`, e.g. the `MyApp_Impl` type. Each
/// implementation knows how to locate the interface-specific field within `MyApp_Impl`.
///
/// This trait is an implementation detail of the Windows crates.
/// User code should not deal directly with this trait.
pub trait ComObjectInterface<I: Interface> {
/// Gets a borrowed interface on the ComObject.
/// Gets a borrowed interface that is implemented by `T`.
fn as_interface_ref(&self) -> InterfaceRef<'_, I>;
}

Expand All @@ -39,15 +44,17 @@ pub trait ComObjectInterface<I: Interface> {
///
/// # Safety
///
/// The contained `ptr` field is an owned, reference-counted pointer to a _pinned_ Pin<Box<T::Outer>>.
/// Although this code does not currently use `Pin<T>`,
/// The contained `ptr` field is an owned, reference-counted pointer to a _pinned_ `Pin<Box<T::Outer>>`.
/// Although this code does not currently use `Pin<T>`, it takes care not to expose any unsafe semantics
/// to safe code. However, code that calls unsafe functions on [`ComObject`] must, like all unsafe code,
/// understand and preserve invariants.
#[repr(transparent)]
pub struct ComObject<T: ComImpl> {
pub struct ComObject<T: ComObjectInner> {
ptr: NonNull<T::Outer>,
}

impl<T: ComImpl> ComObject<T> {
/// Allocates a heap cell (box) and moves `obj` into it. Returns a counted pointer to `obj`.
impl<T: ComObjectInner> ComObject<T> {
/// Allocates a heap cell (box) and moves `value` into it. Returns a counted pointer to `value`.
pub fn new(value: T) -> Self {
unsafe {
let box_ = T::Outer::new_box(value);
Expand All @@ -57,8 +64,8 @@ impl<T: ComImpl> ComObject<T> {

/// Gets a reference to the shared object stored in the box.
///
/// `ComObject` also implements `Deref`, so you can often deref directly into the object.
/// For those situations where using the `Deref` impl is inconvenient, you can use
/// [`ComObject`] also implements [`Deref`], so you can often deref directly into the object.
/// For those situations where using the [`Deref`] impl is inconvenient, you can use
/// this method to explicitly get a reference to the contents.
#[inline(always)]
pub fn get(&self) -> &T {
Expand All @@ -72,9 +79,9 @@ impl<T: ComImpl> ComObject<T> {
}

// Note that we _do not_ provide a way to get a mutable reference to the outer box.
// It's ok to return &mut T, but not &mut T::Outer. That would allow someone to replace the
// It's ok to return `&mut T`, but not `&mut T::Outer`. That would allow someone to replace the
// contents of the entire object (box and reference count), which could lead to UB.
// This could maybe be solved by returning Pin<&mut T::Outer>, but that requires some
// This could maybe be solved by returning `Pin<&mut T::Outer>`, but that requires some
// additional thinking.

/// Gets a mutable reference to the object stored in the box, if the reference count
Expand All @@ -97,13 +104,13 @@ impl<T: ComImpl> ComObject<T> {
self.get_box().is_reference_count_one()
}

/// If this object has only a single object reference (i.e. this `ComObject` is the only
/// If this object has only a single object reference (i.e. this [`ComObject`] is the only
/// reference to the heap allocation), then this method will extract the inner `T`
/// (and return it in an `Ok`) and then free the heap allocation.
///
/// If there is more than one reference to this object, then this returns `Err(self)`.
#[inline(always)]
pub fn try_take(self) -> Result<T, Self> {
pub fn take(self) -> Result<T, Self> {
if self.is_exclusive_reference() {
let outer_box: Box<T::Outer> = unsafe { core::mem::transmute(self) };
Ok(outer_box.into_inner())
Expand All @@ -115,7 +122,7 @@ impl<T: ComImpl> ComObject<T> {
/// Casts to a given interface type.
///
/// This always performs a `QueryInterface`, even if `T` is known to implement `I`.
/// If you know that `T` implements `I`, then use `as_interface` or `to_interface` because
/// If you know that `T` implements `I`, then use [`Self::as_interface`] or [`Self::to_interface`] because
/// those functions do not require a dynamic `QueryInterface` call.
#[inline(always)]
pub fn cast<I: Interface>(&self) -> windows_core::Result<I>
Expand All @@ -126,10 +133,10 @@ impl<T: ComImpl> ComObject<T> {
unknown.cast()
}

/// Gets a borrowed reference to an interface that is implemented by this ComObject.
/// Gets a borrowed reference to an interface that is implemented by `T`.
///
/// The returned reference does not have an additional reference count.
/// You can AddRef it by calling to_owned().
/// You can AddRef it by calling [`Self::to_owned`].
#[inline(always)]
pub fn as_interface<I: Interface>(&self) -> InterfaceRef<'_, I>
where
Expand All @@ -138,7 +145,7 @@ impl<T: ComImpl> ComObject<T> {
self.get_box().as_interface_ref()
}

/// Gets an owned (counted) reference to an interface that is implemented by this ComObject.
/// Gets an owned (counted) reference to an interface that is implemented by this [`ComObject`].
#[inline(always)]
pub fn to_interface<I: Interface>(&self) -> I
where
Expand All @@ -147,7 +154,7 @@ impl<T: ComImpl> ComObject<T> {
self.as_interface::<I>().to_owned()
}

/// Converts this `ComObject` into an interface that it implements.
/// Converts `self` into an interface that it implements.
///
/// This does not need to adjust reference counts because `self` is consumed.
#[inline(always)]
Expand All @@ -163,21 +170,21 @@ impl<T: ComImpl> ComObject<T> {
}
}

impl<T: ComImpl + Default> Default for ComObject<T> {
impl<T: ComObjectInner + Default> Default for ComObject<T> {
fn default() -> Self {
Self::new(T::default())
}
}

impl<T: ComImpl> Drop for ComObject<T> {
impl<T: ComObjectInner> Drop for ComObject<T> {
fn drop(&mut self) {
unsafe {
T::Outer::Release(self.ptr.as_ptr());
}
}
}

impl<T: ComImpl> Clone for ComObject<T> {
impl<T: ComObjectInner> Clone for ComObject<T> {
#[inline(always)]
fn clone(&self) -> Self {
unsafe {
Expand All @@ -187,7 +194,7 @@ impl<T: ComImpl> Clone for ComObject<T> {
}
}

impl<T: ComImpl> AsRef<T> for ComObject<T>
impl<T: ComObjectInner> AsRef<T> for ComObject<T>
where
IUnknown: From<T> + AsImpl<T>,
{
Expand All @@ -197,7 +204,7 @@ where
}
}

impl<T: ComImpl> core::ops::Deref for ComObject<T> {
impl<T: ComObjectInner> Deref for ComObject<T> {
type Target = T::Outer;

#[inline(always)]
Expand All @@ -211,63 +218,63 @@ impl<T: ComImpl> core::ops::Deref for ComObject<T> {
// access to the contents of the object. Use get_mut() for dynamically-checked
// exclusive access.

impl<T: ComImpl> From<T> for ComObject<T> {
impl<T: ComObjectInner> From<T> for ComObject<T> {
fn from(value: T) -> ComObject<T> {
ComObject::new(value)
}
}

// Delegate hashing, if implemented.
impl<T: ComImpl + core::hash::Hash> core::hash::Hash for ComObject<T> {
impl<T: ComObjectInner + core::hash::Hash> core::hash::Hash for ComObject<T> {
fn hash<H: core::hash::Hasher>(&self, state: &mut H) {
self.get().hash(state);
}
}

// If T is Send (or Sync) then the ComObject<T> is also Send (or Sync).
// Since the actual object storage is in the heap, the object is never moved.
unsafe impl<T: ComImpl + Send> Send for ComObject<T> {}
unsafe impl<T: ComImpl + Sync> Sync for ComObject<T> {}
unsafe impl<T: ComObjectInner + Send> Send for ComObject<T> {}
unsafe impl<T: ComObjectInner + Sync> Sync for ComObject<T> {}

impl<T: ComImpl + PartialEq> PartialEq for ComObject<T> {
impl<T: ComObjectInner + PartialEq> PartialEq for ComObject<T> {
fn eq(&self, other: &ComObject<T>) -> bool {
let inner_self: &T = self.get();
let other_self: &T = other.get();
inner_self == other_self
}
}

impl<T: ComImpl + Eq> Eq for ComObject<T> {}
impl<T: ComObjectInner + Eq> Eq for ComObject<T> {}

impl<T: ComImpl + PartialOrd> PartialOrd for ComObject<T> {
impl<T: ComObjectInner + PartialOrd> PartialOrd for ComObject<T> {
fn partial_cmp(&self, other: &Self) -> Option<core::cmp::Ordering> {
let inner_self: &T = self.get();
let other_self: &T = other.get();
<T as PartialOrd>::partial_cmp(inner_self, other_self)
}
}

impl<T: ComImpl + Ord> Ord for ComObject<T> {
impl<T: ComObjectInner + Ord> Ord for ComObject<T> {
fn cmp(&self, other: &Self) -> core::cmp::Ordering {
let inner_self: &T = self.get();
let other_self: &T = other.get();
<T as Ord>::cmp(inner_self, other_self)
}
}

impl<T: ComImpl + core::fmt::Debug> core::fmt::Debug for ComObject<T> {
impl<T: ComObjectInner + core::fmt::Debug> core::fmt::Debug for ComObject<T> {
fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
<T as core::fmt::Debug>::fmt(self.get(), f)
}
}

impl<T: ComImpl + core::fmt::Display> core::fmt::Display for ComObject<T> {
impl<T: ComObjectInner + core::fmt::Display> core::fmt::Display for ComObject<T> {
fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
<T as core::fmt::Display>::fmt(self.get(), f)
}
}

impl<T: ComImpl> Borrow<T> for ComObject<T> {
impl<T: ComObjectInner> Borrow<T> for ComObject<T> {
fn borrow(&self) -> &T {
self.get()
}
Expand Down
1 change: 0 additions & 1 deletion crates/libs/core/src/imp/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,6 @@ mod sha1;
mod waiter;
mod weak_ref_count;

pub use crate::com_object::{ComImpl, ComObjectInterface};
pub use bindings::*;
pub use can_into::*;
pub use com_bindings::*;
Expand Down
6 changes: 3 additions & 3 deletions crates/libs/implement/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -103,7 +103,7 @@ pub fn implement(attributes: proc_macro::TokenStream, original_type: proc_macro:
}
}

impl #generics ::windows_core::imp::ComObjectInterface<#interface_ident> for #impl_ident::#generics where #constraints {
impl #generics ::windows_core::ComObjectInterface<#interface_ident> for #impl_ident::#generics where #constraints {
#[inline(always)]
fn as_interface_ref(&self) -> ::windows_core::InterfaceRef<'_, #interface_ident> {
unsafe {
Expand Down Expand Up @@ -149,7 +149,7 @@ pub fn implement(attributes: proc_macro::TokenStream, original_type: proc_macro:
}
}

impl #generics ::windows_core::imp::ComImpl for #original_ident::#generics where #constraints {
impl #generics ::windows_core::ComObjectInner for #original_ident::#generics where #constraints {
type Outer = #impl_ident::#generics;
}

Expand Down Expand Up @@ -282,7 +282,7 @@ pub fn implement(attributes: proc_macro::TokenStream, original_type: proc_macro:
}
}

impl #generics ::windows_core::imp::ComObjectInterface<::windows_core::IUnknown> for #impl_ident::#generics where #constraints {
impl #generics ::windows_core::ComObjectInterface<::windows_core::IUnknown> for #impl_ident::#generics where #constraints {
#[inline(always)]
fn as_interface_ref(&self) -> ::windows_core::InterfaceRef<'_, ::windows_core::IUnknown> {
unsafe {
Expand Down
40 changes: 35 additions & 5 deletions crates/tests/implement_core/src/com_object.rs
Original file line number Diff line number Diff line change
Expand Up @@ -200,24 +200,24 @@ fn get_mut() {
}

#[test]
fn try_take() {
fn take() {
let app: ComObject<MyApp> = MyApp::new(42);
let tombstone = app.tombstone.clone();
// refcount = 1

let app2 = app.clone();
// refcount = 2

let app2_rejected: ComObject<MyApp> = match app2.try_take() {
Ok(_unexpected) => panic!("try_take should have failed"),
let app2_rejected: ComObject<MyApp> = match app2.take() {
Ok(_unexpected) => panic!("take() should have failed"),
Err(e) => e,
};
// refcount = 1

drop(app2_rejected);
// refcount = 1

match app.try_take() {
match app.take() {
Ok(unwrapped_app) => {
// box destroyed
assert!(!tombstone.is_dead());
Expand All @@ -227,7 +227,7 @@ fn try_take() {
}

Err(_unexpected) => {
panic!("try_take should have succeeded");
panic!("take() should have succeeded");
}
}
}
Expand Down Expand Up @@ -332,3 +332,33 @@ fn from_inner_ref() {
let ibar: IBar = unsafe { ifoo.get_self_as_bar() };
unsafe { ibar.say_hello() };
}

// This tests that we can place a type that is not Send in a ComObject.
// Compilation is sufficient to test.
#[implement(IBar)]
struct UnsendableThing {
cell: core::cell::Cell<u32>,
}

impl IBar_Impl for UnsendableThing {
unsafe fn say_hello(&self) {
println!("{}", self.cell.get());
}
}

static_assertions::assert_not_impl_all!(UnsendableThing: Send, Sync);
static_assertions::assert_not_impl_all!(ComObject<UnsendableThing>: Send, Sync);

#[implement(IBar)]
struct SendableThing {
arc: std::sync::Arc<u32>,
}

impl IBar_Impl for SendableThing {
unsafe fn say_hello(&self) {
println!("{}", *self.arc);
}
}

static_assertions::assert_impl_all!(SendableThing: Send, Sync);
static_assertions::assert_impl_all!(ComObject<SendableThing>: Send, Sync);

0 comments on commit f72fd9e

Please sign in to comment.