From ccd334fa75ae56baf9013d836d5adeeb02d968ec Mon Sep 17 00:00:00 2001 From: sivadeilra Date: Sat, 25 May 2024 09:11:16 -0700 Subject: [PATCH] Dynamic casting to COM implementation (#3055) This provides a new feature for COM developers using the windows-rs crate. It allows for safe dynamic casting from IUnknown to an implementation object. It is based on Rust's Any trait. Any type that is marked with #[implement], except for those that contain non-static lifetimes, can be used with dynamic casting. Example: ```rust struct MyApp { ... } fn main() { let my_app = ComObject::new(MyApp { ... }); let iunknown: IUnknown = my_app.to_interface(); do_stuff(&iunknown); } fn do_stuff(unknown: &IUnknown) -> Result<()> { let my_app: ComObject = unknown.cast_object()?; my_app.internal_method(); Ok(()) } ``` Co-authored-by: Arlie Davis --- crates/libs/core/src/com_object.rs | 23 ++++ crates/libs/core/src/interface.rs | 128 +++++++++++++++++- crates/libs/core/src/unknown.rs | 15 ++ crates/libs/implement/src/lib.rs | 41 +++++- crates/tests/implement_core/src/com_object.rs | 45 +++++- 5 files changed, 245 insertions(+), 7 deletions(-) diff --git a/crates/libs/core/src/com_object.rs b/crates/libs/core/src/com_object.rs index 00d80e7b3b..ede313071e 100644 --- a/crates/libs/core/src/com_object.rs +++ b/crates/libs/core/src/com_object.rs @@ -1,5 +1,6 @@ use crate::imp::Box; use crate::{AsImpl, IUnknown, IUnknownImpl, Interface, InterfaceRef}; +use core::any::Any; use core::borrow::Borrow; use core::ops::Deref; use core::ptr::NonNull; @@ -196,6 +197,28 @@ impl ComObject { I::from_raw(raw) } } + + /// This casts the given COM interface to [`&dyn Any`]. It returns a reference to the "outer" + /// object, e.g. `MyApp_Impl`, not the inner `MyApp` object. + /// + /// `T` must be a type that has been annotated with `#[implement]`; this is checked at + /// compile-time by the generic constraints of this method. However, note that the + /// returned `&dyn Any` refers to the _outer_ implementation object that was generated by + /// `#[implement]`, i.e. the `MyApp_Impl` type, not the inner `MyApp` type. + /// + /// If the given object is not a Rust object, or is a Rust object but not `T`, or is a Rust + /// object that contains non-static lifetimes, then this function will return `Err(E_NOINTERFACE)`. + /// + /// The returned value is an owned (counted) reference; this function calls `AddRef` on the + /// underlying COM object. If you do not need an owned reference, then you can use the + /// [`Interface::cast_object_ref`] method instead, and avoid the cost of `AddRef` / `Release`. + pub fn cast_from(interface: &I) -> crate::Result + where + I: Interface, + T::Outer: Any + 'static + IUnknownImpl, + { + interface.cast_object() + } } impl Default for ComObject { diff --git a/crates/libs/core/src/interface.rs b/crates/libs/core/src/interface.rs index c10991b9a7..93708fe9e3 100644 --- a/crates/libs/core/src/interface.rs +++ b/crates/libs/core/src/interface.rs @@ -1,7 +1,8 @@ use super::*; +use core::any::Any; use core::ffi::c_void; use core::marker::PhantomData; -use core::mem::{forget, transmute_copy}; +use core::mem::{forget, transmute_copy, MaybeUninit}; use core::ptr::NonNull; /// Provides low-level access to an interface vtable. @@ -97,7 +98,7 @@ pub unsafe trait Interface: Sized + Clone { // // This guards against implementations of COM interfaces which may store non-null values // in 'result' but still return E_NOINTERFACE. - let mut result = core::mem::MaybeUninit::>::zeroed(); + let mut result = MaybeUninit::>::zeroed(); self.query(&T::IID, result.as_mut_ptr() as _).ok()?; // If we get here, then query() has succeeded, but we still need to double-check @@ -110,6 +111,123 @@ pub unsafe trait Interface: Sized + Clone { } } + /// This casts the given COM interface to [`&dyn Any`]. + /// + /// Applications should generally _not_ call this method directly. Instead, use the + /// [`Interface::cast_object_ref`] or [`Interface::cast_object`] methods. + /// + /// `T` must be a type that has been annotated with `#[implement]`; this is checked at + /// compile-time by the generic constraints of this method. However, note that the + /// returned `&dyn Any` refers to the _outer_ implementation object that was generated by + /// `#[implement]`, i.e. the `MyApp_Impl` type, not the inner `MyApp` type. + /// + /// If the given object is not a Rust object, or is a Rust object but not `T`, or is a Rust + /// object that contains non-static lifetimes, then this function will return `Err(E_NOINTERFACE)`. + /// + /// # Safety + /// + /// **IMPORTANT!!** This uses a non-standard protocol for QueryInterface! The `DYNAMIC_CAST_IID` + /// IID identifies this protocol, but there is no `IDynamicCast` interface. Instead, objects + /// that recognize `DYNAMIC_CAST_IID` simply store their `&dyn Any` directly at the interface + /// pointer that was passed to `QueryInterface. This means that the returned value has a + /// size that is twice as large (`size_of::<&dyn Any>() == 2 * size_of::<*const c_void>()`). + /// + /// This means that callers that use this protocol cannot simply pass `&mut ptr` for + /// an ordinary single-pointer-sized pointer. Only this method understands this protocol. + /// + /// Another part of this protocol is that the implementation of `QueryInterface` _does not_ + /// AddRef the object. The caller must guarantee the liveness of the COM object. In Rust, + /// this means tying the lifetime of the IUnknown* that we used for the QueryInterface + /// call to the lifetime of the returned `&dyn Any` value. + /// + /// This method preserves type safety and relies on these invariants: + /// + /// * All `QueryInterface` implementations that recognize `DYNAMIC_CAST_IID` are generated by + /// the `#[implement]` macro and respect the rules described here. + #[inline(always)] + fn cast_to_any(&self) -> Result<&dyn Any> + where + T: ComObjectInner, + T::Outer: Any + 'static + IUnknownImpl, + { + unsafe { + let mut any_ref_arg: MaybeUninit<&dyn Any> = MaybeUninit::zeroed(); + self.query(&DYNAMIC_CAST_IID, any_ref_arg.as_mut_ptr() as *mut *mut c_void).ok()?; + Ok(any_ref_arg.assume_init()) + } + } + + /// Returns `true` if the given COM interface refers to an implementation of `T`. + /// + /// `T` must be a type that has been annotated with `#[implement]`; this is checked at + /// compile-time by the generic constraints of this method. + /// + /// If the given object is not a Rust object, or is a Rust object but not `T`, or is a Rust + /// object that contains non-static lifetimes, then this function will return `false`. + #[inline(always)] + fn is_object(&self) -> bool + where + T: ComObjectInner, + T::Outer: Any + 'static + IUnknownImpl, + { + if let Ok(any) = self.cast_to_any::() { + any.is::() + } else { + false + } + } + + /// This casts the given COM interface to [`&dyn Any`]. It returns a reference to the "outer" + /// object, e.g. `&MyApp_Impl`, not the inner `&MyApp` object. + /// + /// `T` must be a type that has been annotated with `#[implement]`; this is checked at + /// compile-time by the generic constraints of this method. However, note that the + /// returned `&dyn Any` refers to the _outer_ implementation object that was generated by + /// `#[implement]`, i.e. the `MyApp_Impl` type, not the inner `MyApp` type. + /// + /// If the given object is not a Rust object, or is a Rust object but not `T`, or is a Rust + /// object that contains non-static lifetimes, then this function will return `Err(E_NOINTERFACE)`. + /// + /// The returned value is borrowed. If you need an owned (counted) reference, then use + /// [`Interface::cast_object`]. + #[inline(always)] + fn cast_object_ref(&self) -> Result<&T::Outer> + where + T: ComObjectInner, + T::Outer: Any + 'static + IUnknownImpl, + { + let any: &dyn Any = self.cast_to_any::()?; + if let Some(outer) = any.downcast_ref::() { + Ok(outer) + } else { + Err(imp::E_NOINTERFACE.into()) + } + } + + /// This casts the given COM interface to [`&dyn Any`]. It returns a reference to the "outer" + /// object, e.g. `MyApp_Impl`, not the inner `MyApp` object. + /// + /// `T` must be a type that has been annotated with `#[implement]`; this is checked at + /// compile-time by the generic constraints of this method. However, note that the + /// returned `&dyn Any` refers to the _outer_ implementation object that was generated by + /// `#[implement]`, i.e. the `MyApp_Impl` type, not the inner `MyApp` type. + /// + /// If the given object is not a Rust object, or is a Rust object but not `T`, or is a Rust + /// object that contains non-static lifetimes, then this function will return `Err(E_NOINTERFACE)`. + /// + /// The returned value is an owned (counted) reference; this function calls `AddRef` on the + /// underlying COM object. If you do not need an owned reference, then you can use the + /// [`Interface::cast_object_ref`] method instead, and avoid the cost of `AddRef` / `Release`. + #[inline(always)] + fn cast_object(&self) -> Result> + where + T: ComObjectInner, + T::Outer: Any + 'static + IUnknownImpl, + { + let object_ref = self.cast_object_ref::()?; + Ok(object_ref.to_object()) + } + /// Attempts to create a [`Weak`] reference to this object. fn downgrade(&self) -> Result> { self.cast::().and_then(|source| Weak::downgrade(&source)) @@ -210,3 +328,9 @@ impl<'a, I: Interface> core::ops::Deref for InterfaceRef<'a, I> { unsafe { core::mem::transmute(self) } } } + +/// This IID identifies a special protocol, used by [`Interface::cast_to_any`]. This is _not_ +/// an ordinary COM interface; it uses special lifetime rules and a larger interface pointer. +/// See the comments on [`Interface::cast_to_any`]. +#[doc(hidden)] +pub const DYNAMIC_CAST_IID: GUID = GUID::from_u128(0xae49d5cb_143f_431c_874c_2729336e4eca); diff --git a/crates/libs/core/src/unknown.rs b/crates/libs/core/src/unknown.rs index 1fe088636f..86a08bb920 100644 --- a/crates/libs/core/src/unknown.rs +++ b/crates/libs/core/src/unknown.rs @@ -135,6 +135,21 @@ pub trait IUnknownImpl { { >::as_interface_ref(self).to_owned() } + + /// Creates a new owned reference to this object. + /// + /// # Safety + /// + /// This function can only be safely called by `_Impl` objects that are embedded in a + /// `ComObject`. Since we only allow safe Rust code to access these objects using a `ComObject` + /// or a `&_Impl` that points within a `ComObject`, this is safe. + fn to_object(&self) -> ComObject + where + Self::Impl: ComObjectInner; + + /// The distance from the start of `_Impl` to the `this` field within it, measured in + /// pointer-sized elements. The `this` field contains the `MyApp` instance. + const INNER_OFFSET_IN_POINTERS: usize; } impl IUnknown_Vtbl { diff --git a/crates/libs/implement/src/lib.rs b/crates/libs/implement/src/lib.rs index 06c6536ab4..617054fe39 100644 --- a/crates/libs/implement/src/lib.rs +++ b/crates/libs/implement/src/lib.rs @@ -46,10 +46,10 @@ pub fn implement(attributes: proc_macro::TokenStream, original_type: proc_macro: let original_type2 = original_type.clone(); let original_type2 = syn::parse_macro_input!(original_type2 as syn::ItemStruct); let vis = &original_type2.vis; - let original_ident = original_type2.ident; + let original_ident = &original_type2.ident; let mut constraints = quote! {}; - if let Some(where_clause) = original_type2.generics.where_clause { + if let Some(where_clause) = &original_type2.generics.where_clause { where_clause.predicates.to_tokens(&mut constraints); } @@ -83,6 +83,25 @@ pub fn implement(attributes: proc_macro::TokenStream, original_type: proc_macro: } }); + // Dynamic casting requires that the object not contain non-static lifetimes. + let enable_dyn_casting = original_type2.generics.lifetimes().count() == 0; + let dynamic_cast_query = if enable_dyn_casting { + quote! { + else if *iid == ::windows_core::DYNAMIC_CAST_IID { + // DYNAMIC_CAST_IID is special. We _do not_ increase the reference count for this pseudo-interface. + // Also, instead of returning an interface pointer, we simply write the `&dyn Any` directly to the + // 'interface' pointer. Since the size of `&dyn Any` is 2 pointers, not one, the caller must be + // prepared for this. This is not a normal QueryInterface call. + // + // See the `Interface::cast_to_any` method, which is the only caller that should use DYNAMIC_CAST_ID. + (interface as *mut *const dyn core::any::Any).write(self as &dyn ::core::any::Any as *const dyn ::core::any::Any); + return ::windows_core::HRESULT(0); + } + } + } else { + quote!() + }; + // The distance from the beginning of the generated type to the 'this' field, in units of pointers (not bytes). let offset_of_this_in_pointers = 1 + attributes.implement.len(); let offset_of_this_in_pointers_token = proc_macro2::Literal::usize_unsuffixed(offset_of_this_in_pointers); @@ -201,7 +220,10 @@ pub fn implement(attributes: proc_macro::TokenStream, original_type: proc_macro: || iid == &<::windows_core::IInspectable as ::windows_core::Interface>::IID || iid == &<::windows_core::imp::IAgileObject as ::windows_core::Interface>::IID { &self.identity as *const _ as *mut _ - } #(#queries)* else { + } + #(#queries)* + #dynamic_cast_query + else { ::core::ptr::null_mut() }; @@ -230,7 +252,7 @@ pub fn implement(attributes: proc_macro::TokenStream, original_type: proc_macro: unsafe fn Release(self_: *mut Self) -> u32 { let remaining = (*self_).count.release(); if remaining == 0 { - _ = ::windows_core::imp::Box::from_raw(self_ as *const Self as *mut Self); + _ = ::windows_core::imp::Box::from_raw(self_); } remaining } @@ -247,6 +269,17 @@ pub fn implement(attributes: proc_macro::TokenStream, original_type: proc_macro: &*((inner as *const Self::Impl as *const *const ::core::ffi::c_void) .sub(#offset_of_this_in_pointers_token) as *const Self) } + + fn to_object(&self) -> ::windows_core::ComObject { + self.count.add_ref(); + unsafe { + ::windows_core::ComObject::from_raw( + ::core::ptr::NonNull::new_unchecked(self as *const Self as *mut Self) + ) + } + } + + const INNER_OFFSET_IN_POINTERS: usize = #offset_of_this_in_pointers_token; } impl #generics #original_ident::#generics where #constraints { diff --git a/crates/tests/implement_core/src/com_object.rs b/crates/tests/implement_core/src/com_object.rs index a5d0eb2be9..de37c49bf9 100644 --- a/crates/tests/implement_core/src/com_object.rs +++ b/crates/tests/implement_core/src/com_object.rs @@ -4,7 +4,7 @@ use std::borrow::Borrow; use std::sync::atomic::{AtomicBool, Ordering::SeqCst}; use std::sync::Arc; use windows_core::{ - implement, interface, ComObject, IUnknown, IUnknownImpl, IUnknown_Vtbl, InterfaceRef, + implement, interface, ComObject, IUnknown, IUnknownImpl, IUnknown_Vtbl, Interface, InterfaceRef, }; #[interface("818f2fd1-d479-4398-b286-a93c4c7904d1")] @@ -19,8 +19,12 @@ unsafe trait IBar: IUnknown { fn say_hello(&self); } +const APP_SIGNATURE: [u8; 8] = *b"cafef00d"; + #[implement(IFoo, IBar)] struct MyApp { + // We use signature to verify field offsets for dynamic casts + signature: [u8; 8], x: u32, tombstone: Arc, } @@ -63,6 +67,7 @@ impl core::fmt::Display for MyApp { impl Default for MyApp { fn default() -> Self { Self { + signature: APP_SIGNATURE, x: 0, tombstone: Arc::new(Tombstone::default()), } @@ -109,6 +114,7 @@ impl MyApp { fn new(x: u32) -> ComObject { ComObject::new(Self { x, + signature: APP_SIGNATURE, tombstone: Arc::new(Tombstone::default()), }) } @@ -333,6 +339,43 @@ fn from_inner_ref() { unsafe { ibar.say_hello() }; } +#[test] +fn to_object() { + let app = MyApp::new(42); + let tombstone = app.tombstone.clone(); + let app_outer: &MyApp_Impl = &app; + + let second_app = app_outer.to_object(); + assert!(!tombstone.is_dead()); + assert_eq!(second_app.signature, APP_SIGNATURE); + + println!("x = {}", unsafe { second_app.get_x() }); + + drop(second_app); + assert!(!tombstone.is_dead()); + + drop(app); + assert!(tombstone.is_dead()); +} + +#[test] +fn dynamic_cast() { + let app = MyApp::new(42); + let unknown = app.to_interface::(); + + assert!(!unknown.is_object::()); + assert!(unknown.is_object::()); + + let dyn_app_ref: &MyApp_Impl = unknown.cast_object_ref::().unwrap(); + assert_eq!(dyn_app_ref.signature, APP_SIGNATURE); + + let dyn_app_owned: ComObject = unknown.cast_object().unwrap(); + assert_eq!(dyn_app_owned.signature, APP_SIGNATURE); + + let dyn_app_owned_2: ComObject = ComObject::cast_from(&unknown).unwrap(); + assert_eq!(dyn_app_owned_2.signature, APP_SIGNATURE); +} + // This tests that we can place a type that is not Send in a ComObject. // Compilation is sufficient to test. #[implement(IBar)]