From 17b16400a36dc6a00258d7f7fbadbb03433dc6d0 Mon Sep 17 00:00:00 2001 From: Kenny Kerr Date: Fri, 13 Dec 2024 23:37:48 -0600 Subject: [PATCH] Make `Ref` work with more than just interface types (#3394) --- crates/libs/core/src/ref.rs | 38 +++++++++---------- crates/libs/core/src/type.rs | 27 ++++++++++++- crates/tests/winrt/ref_params/src/interop.cpp | 5 +++ crates/tests/winrt/ref_params/tests/test.rs | 1 + crates/tests/winrt/reference/src/lib.rs | 2 +- 5 files changed, 51 insertions(+), 22 deletions(-) diff --git a/crates/libs/core/src/ref.rs b/crates/libs/core/src/ref.rs index 889567bb32..537df22d02 100644 --- a/crates/libs/core/src/ref.rs +++ b/crates/libs/core/src/ref.rs @@ -1,5 +1,4 @@ use super::*; -use core::ffi::c_void; use core::marker::PhantomData; use core::mem::transmute; @@ -7,18 +6,24 @@ use core::mem::transmute; #[repr(transparent)] pub struct Ref<'a, T: Type>(T::Abi, PhantomData<&'a T>); -impl, Abi = *mut c_void>> Ref<'_, T> { +impl> Ref<'_, T> { /// Returns `true` if the argument is null. pub fn is_null(&self) -> bool { - self.0.is_null() + T::is_null(&self.0) } - /// Converts the argument to a [Result<&T>] reference. + /// Converts the argument to a [`Result<&T>`] reference. pub fn ok(&self) -> Result<&T> { - if self.0.is_null() { - Err(Error::from_hresult(imp::E_POINTER)) + self.as_ref() + .ok_or_else(|| Error::from_hresult(imp::E_POINTER)) + } + + /// Converts the argument to a [`Option<&T>`] reference. + pub fn as_ref(&self) -> Option<&T> { + if self.is_null() { + None } else { - unsafe { Ok(self.assume_init()) } + unsafe { Some(self.assume_init_ref()) } } } @@ -27,25 +32,18 @@ impl, Abi = *mut c_void>> Ref<'_, T> { /// # Panics /// /// Panics if the argument is null. + #[track_caller] pub fn unwrap(&self) -> &T { - if self.0.is_null() { - panic!("called `Ref::unwrap` on a null value") - } else { - unsafe { self.assume_init() } - } + self.as_ref().expect("called `Ref::unwrap` on a null value") } - /// Converts the argument to an [Option] by cloning the reference. + /// Converts the argument to an [`Option`] by cloning the reference. pub fn cloned(&self) -> Option { - if self.0.is_null() { - None - } else { - unsafe { Some(self.assume_init().clone()) } - } + self.as_ref().cloned() } - unsafe fn assume_init(&self) -> &T { - unsafe { transmute::<&*mut c_void, &T>(&self.0) } + unsafe fn assume_init_ref(&self) -> &T { + unsafe { T::assume_init_ref(&self.0) } } } diff --git a/crates/libs/core/src/type.rs b/crates/libs/core/src/type.rs index 87f8e6f900..a64f191d0a 100644 --- a/crates/libs/core/src/type.rs +++ b/crates/libs/core/src/type.rs @@ -19,7 +19,8 @@ pub trait Type::TypeKind>: TypeKind + Sized + C type Abi; type Default; - /// # Safety + fn is_null(abi: &Self::Abi) -> bool; + unsafe fn assume_init_ref(abi: &Self::Abi) -> &Self; unsafe fn from_abi(abi: Self::Abi) -> Result; fn from_default(default: &Self::Default) -> Result; } @@ -31,6 +32,14 @@ where type Abi = *mut core::ffi::c_void; type Default = Option; + fn is_null(abi: &Self::Abi) -> bool { + abi.is_null() + } + + unsafe fn assume_init_ref(abi: &Self::Abi) -> &Self { + unsafe { core::mem::transmute::<&*mut core::ffi::c_void, &T>(abi) } + } + unsafe fn from_abi(abi: Self::Abi) -> Result { unsafe { if !abi.is_null() { @@ -53,6 +62,14 @@ where type Abi = core::mem::MaybeUninit; type Default = Self; + fn is_null(_: &Self::Abi) -> bool { + false + } + + unsafe fn assume_init_ref(abi: &Self::Abi) -> &Self { + unsafe { abi.assume_init_ref() } + } + unsafe fn from_abi(abi: Self::Abi) -> Result { unsafe { Ok(abi.assume_init()) } } @@ -69,6 +86,14 @@ where type Abi = Self; type Default = Self; + fn is_null(_: &Self::Abi) -> bool { + false + } + + unsafe fn assume_init_ref(abi: &Self::Abi) -> &Self { + abi + } + unsafe fn from_abi(abi: Self::Abi) -> Result { Ok(abi) } diff --git a/crates/tests/winrt/ref_params/src/interop.cpp b/crates/tests/winrt/ref_params/src/interop.cpp index c1f9236592..6045b5bc20 100644 --- a/crates/tests/winrt/ref_params/src/interop.cpp +++ b/crates/tests/winrt/ref_params/src/interop.cpp @@ -15,6 +15,11 @@ struct Test : implements int32_t Input(ITest const& input) { + if (!input) + { + throw hresult_error(E_POINTER); + } + return input.Current(); } diff --git a/crates/tests/winrt/ref_params/tests/test.rs b/crates/tests/winrt/ref_params/tests/test.rs index 6cd60f6f6e..26a6dd0356 100644 --- a/crates/tests/winrt/ref_params/tests/test.rs +++ b/crates/tests/winrt/ref_params/tests/test.rs @@ -32,6 +32,7 @@ fn test_interface(test: &ITest) -> Result<()> { assert_eq!(test.Input(&one_two_three)?, 123); assert_eq!(test.Input(&four_five_six)?, 456); + assert_eq!(test.Input(None).unwrap_err().code(), HRESULT(-2147467261)); // E_POINTER let mut seven_eight_nine = None; test.Output(789, &mut seven_eight_nine)?; diff --git a/crates/tests/winrt/reference/src/lib.rs b/crates/tests/winrt/reference/src/lib.rs index 5aed6f4004..f4aad9219a 100644 --- a/crates/tests/winrt/reference/src/lib.rs +++ b/crates/tests/winrt/reference/src/lib.rs @@ -7,7 +7,7 @@ unsafe extern "system" fn DllGetActivationFactory( name: Ref, factory: OutRef, ) -> HRESULT { - if *name == "test_reference.Reference" { + if name.unwrap() == "test_reference.Reference" { factory.write(Some(ReferenceFactory.into())).into() } else { _ = factory.write(None);