diff --git a/crates/libs/bindgen/src/rust/writer.rs b/crates/libs/bindgen/src/rust/writer.rs index 3da4429bcc..8e356eb847 100644 --- a/crates/libs/bindgen/src/rust/writer.rs +++ b/crates/libs/bindgen/src/rust/writer.rs @@ -697,6 +697,31 @@ impl Writer { self.GetResults() } } + #features + impl<#constraints> windows_core::AsyncOperation for #ident { + type Output = #return_type; + fn is_complete(&self) -> windows_core::Result { + Ok(self.Status()? != #namespace AsyncStatus::Started) + } + fn set_completed(&self, f: impl Fn() + Send + 'static) -> windows_core::Result<()> { + self.SetCompleted(&#namespace #handler::new(move |_sender, _args| { + f(); + Ok(()) + })) + } + fn get_results(&self) -> windows_core::Result { + self.GetResults() + } + } + #features + impl<#constraints> std::future::IntoFuture for #ident { + type Output = windows_core::Result<#return_type>; + type IntoFuture = windows_core::FutureWrapper<#ident>; + + fn into_future(self) -> Self::IntoFuture { + windows_core::FutureWrapper::new(self) + } + } } } } diff --git a/crates/libs/core/src/future.rs b/crates/libs/core/src/future.rs new file mode 100644 index 0000000000..e6e3af0577 --- /dev/null +++ b/crates/libs/core/src/future.rs @@ -0,0 +1,68 @@ +#![cfg(feature = "std")] + +use std::{ + future::Future, + pin::Pin, + sync::{Arc, Mutex}, + task::{Poll, Waker}, +}; + +/// Wraps an `IAsyncOperation`, `IAsyncOperationWithProgress`, `IAsyncAction`, or `IAsyncActionWithProgress`. +/// Impls for this trait are generated automatically by windows-bindgen. +pub trait AsyncOperation { + /// The type produced when the operation finishes. + type Output; + /// Returns whether the operation is finished, in which case `self.get_results()` can be used to get the returned data. + /// Wraps `self.Status() != AsyncStatus::Started`. + fn is_complete(&self) -> crate::Result; + /// Register a callback that will be called once the operation is finished. + /// This can only be called once. + /// Wraps `self.SetCompleted(f)`. + fn set_completed(&self, f: impl Fn() + Send + 'static) -> crate::Result<()>; + /// Get the result value from a completed operation. + /// Wraps `self.GetResults()`. + fn get_results(&self) -> crate::Result; +} + +/// A wrapper around an `AsyncOperation` that implements `std::future::Future`. +/// This is used by generated `IntoFuture` impls. It shouldn't be necessary to use this type manually. +pub struct FutureWrapper { + inner: T, + waker: Option>>, +} + +impl FutureWrapper { + /// Creates a `FutureWrapper`, which implements `std::future::Future`. + pub fn new(inner: T) -> Self { + Self { + inner, + waker: None, + } + } +} + +impl Unpin for FutureWrapper {} + +impl Future for FutureWrapper { + type Output = crate::Result; + + fn poll(mut self: Pin<&mut Self>, cx: &mut std::task::Context<'_>) -> Poll { + if self.inner.is_complete()? { + Poll::Ready(self.inner.get_results()) + } else { + if let Some(saved_waker) = &self.waker { + // Update the saved waker, in case the future has been transferred to a different executor. + // (e.g. if using `select`.) + let mut saved_waker = saved_waker.lock().unwrap(); + saved_waker.clone_from(cx.waker()); + } else { + let saved_waker = Arc::new(Mutex::new(cx.waker().clone())); + self.waker = Some(saved_waker.clone()); + self.inner.set_completed(move || { + saved_waker.lock().unwrap().wake_by_ref(); + })?; + } + Poll::Pending + } + } +} diff --git a/crates/libs/core/src/lib.rs b/crates/libs/core/src/lib.rs index d8a2b77f87..1a5e8ae03c 100644 --- a/crates/libs/core/src/lib.rs +++ b/crates/libs/core/src/lib.rs @@ -24,6 +24,7 @@ pub mod imp; mod as_impl; mod com_object; +mod future; mod guid; mod inspectable; mod interface; @@ -41,6 +42,7 @@ mod weak; pub use as_impl::*; pub use com_object::*; +pub use future::*; pub use guid::*; pub use inspectable::*; pub use interface::*; diff --git a/crates/libs/windows/src/Windows/Devices/Sms/mod.rs b/crates/libs/windows/src/Windows/Devices/Sms/mod.rs index 64d785eda9..26b29068a6 100644 --- a/crates/libs/windows/src/Windows/Devices/Sms/mod.rs +++ b/crates/libs/windows/src/Windows/Devices/Sms/mod.rs @@ -1036,6 +1036,30 @@ impl DeleteSmsMessageOperation { } } #[cfg(feature = "deprecated")] +impl windows_core::AsyncOperation for DeleteSmsMessageOperation { + type Output = (); + fn is_complete(&self) -> windows_core::Result { + Ok(self.Status()? != super::super::Foundation::AsyncStatus::Started) + } + fn set_completed(&self, f: impl Fn() + Send + 'static) -> windows_core::Result<()> { + self.SetCompleted(&super::super::Foundation::AsyncActionCompletedHandler::new(move |_sender, _args| { + f(); + Ok(()) + })) + } + fn get_results(&self) -> windows_core::Result { + self.GetResults() + } +} +#[cfg(feature = "deprecated")] +impl std::future::IntoFuture for DeleteSmsMessageOperation { + type Output = windows_core::Result<()>; + type IntoFuture = windows_core::FutureWrapper; + fn into_future(self) -> Self::IntoFuture { + windows_core::FutureWrapper::new(self) + } +} +#[cfg(feature = "deprecated")] #[repr(transparent)] #[derive(PartialEq, Eq, Debug, Clone)] pub struct DeleteSmsMessagesOperation(windows_core::IUnknown); @@ -1122,6 +1146,30 @@ impl DeleteSmsMessagesOperation { } } #[cfg(feature = "deprecated")] +impl windows_core::AsyncOperation for DeleteSmsMessagesOperation { + type Output = (); + fn is_complete(&self) -> windows_core::Result { + Ok(self.Status()? != super::super::Foundation::AsyncStatus::Started) + } + fn set_completed(&self, f: impl Fn() + Send + 'static) -> windows_core::Result<()> { + self.SetCompleted(&super::super::Foundation::AsyncActionCompletedHandler::new(move |_sender, _args| { + f(); + Ok(()) + })) + } + fn get_results(&self) -> windows_core::Result { + self.GetResults() + } +} +#[cfg(feature = "deprecated")] +impl std::future::IntoFuture for DeleteSmsMessagesOperation { + type Output = windows_core::Result<()>; + type IntoFuture = windows_core::FutureWrapper; + fn into_future(self) -> Self::IntoFuture { + windows_core::FutureWrapper::new(self) + } +} +#[cfg(feature = "deprecated")] #[repr(transparent)] #[derive(PartialEq, Eq, Debug, Clone)] pub struct GetSmsDeviceOperation(windows_core::IUnknown); @@ -1211,6 +1259,30 @@ impl GetSmsDeviceOperation { } } #[cfg(feature = "deprecated")] +impl windows_core::AsyncOperation for GetSmsDeviceOperation { + type Output = SmsDevice; + fn is_complete(&self) -> windows_core::Result { + Ok(self.Status()? != super::super::Foundation::AsyncStatus::Started) + } + fn set_completed(&self, f: impl Fn() + Send + 'static) -> windows_core::Result<()> { + self.SetCompleted(&super::super::Foundation::AsyncOperationCompletedHandler::new(move |_sender, _args| { + f(); + Ok(()) + })) + } + fn get_results(&self) -> windows_core::Result { + self.GetResults() + } +} +#[cfg(feature = "deprecated")] +impl std::future::IntoFuture for GetSmsDeviceOperation { + type Output = windows_core::Result; + type IntoFuture = windows_core::FutureWrapper; + fn into_future(self) -> Self::IntoFuture { + windows_core::FutureWrapper::new(self) + } +} +#[cfg(feature = "deprecated")] #[repr(transparent)] #[derive(PartialEq, Eq, Debug, Clone)] pub struct GetSmsMessageOperation(windows_core::IUnknown); @@ -1299,6 +1371,30 @@ impl GetSmsMessageOperation { self.GetResults() } } +#[cfg(feature = "deprecated")] +impl windows_core::AsyncOperation for GetSmsMessageOperation { + type Output = ISmsMessage; + fn is_complete(&self) -> windows_core::Result { + Ok(self.Status()? != super::super::Foundation::AsyncStatus::Started) + } + fn set_completed(&self, f: impl Fn() + Send + 'static) -> windows_core::Result<()> { + self.SetCompleted(&super::super::Foundation::AsyncOperationCompletedHandler::new(move |_sender, _args| { + f(); + Ok(()) + })) + } + fn get_results(&self) -> windows_core::Result { + self.GetResults() + } +} +#[cfg(feature = "deprecated")] +impl std::future::IntoFuture for GetSmsMessageOperation { + type Output = windows_core::Result; + type IntoFuture = windows_core::FutureWrapper; + fn into_future(self) -> Self::IntoFuture { + windows_core::FutureWrapper::new(self) + } +} #[cfg(all(feature = "Foundation_Collections", feature = "deprecated"))] #[repr(transparent)] #[derive(PartialEq, Eq, Debug, Clone)] @@ -1407,6 +1503,30 @@ impl GetSmsMessagesOperation { self.GetResults() } } +#[cfg(all(feature = "Foundation_Collections", feature = "deprecated"))] +impl windows_core::AsyncOperation for GetSmsMessagesOperation { + type Output = super::super::Foundation::Collections::IVectorView; + fn is_complete(&self) -> windows_core::Result { + Ok(self.Status()? != super::super::Foundation::AsyncStatus::Started) + } + fn set_completed(&self, f: impl Fn() + Send + 'static) -> windows_core::Result<()> { + self.SetCompleted(&super::super::Foundation::AsyncOperationWithProgressCompletedHandler::new(move |_sender, _args| { + f(); + Ok(()) + })) + } + fn get_results(&self) -> windows_core::Result { + self.GetResults() + } +} +#[cfg(all(feature = "Foundation_Collections", feature = "deprecated"))] +impl std::future::IntoFuture for GetSmsMessagesOperation { + type Output = windows_core::Result>; + type IntoFuture = windows_core::FutureWrapper; + fn into_future(self) -> Self::IntoFuture { + windows_core::FutureWrapper::new(self) + } +} #[cfg(feature = "deprecated")] #[repr(transparent)] #[derive(PartialEq, Eq, Debug, Clone)] @@ -1493,6 +1613,30 @@ impl SendSmsMessageOperation { self.GetResults() } } +#[cfg(feature = "deprecated")] +impl windows_core::AsyncOperation for SendSmsMessageOperation { + type Output = (); + fn is_complete(&self) -> windows_core::Result { + Ok(self.Status()? != super::super::Foundation::AsyncStatus::Started) + } + fn set_completed(&self, f: impl Fn() + Send + 'static) -> windows_core::Result<()> { + self.SetCompleted(&super::super::Foundation::AsyncActionCompletedHandler::new(move |_sender, _args| { + f(); + Ok(()) + })) + } + fn get_results(&self) -> windows_core::Result { + self.GetResults() + } +} +#[cfg(feature = "deprecated")] +impl std::future::IntoFuture for SendSmsMessageOperation { + type Output = windows_core::Result<()>; + type IntoFuture = windows_core::FutureWrapper; + fn into_future(self) -> Self::IntoFuture { + windows_core::FutureWrapper::new(self) + } +} #[repr(transparent)] #[derive(PartialEq, Eq, Debug, Clone)] pub struct SmsAppMessage(windows_core::IUnknown); diff --git a/crates/libs/windows/src/Windows/Foundation/mod.rs b/crates/libs/windows/src/Windows/Foundation/mod.rs index a414916d61..11a548ce5a 100644 --- a/crates/libs/windows/src/Windows/Foundation/mod.rs +++ b/crates/libs/windows/src/Windows/Foundation/mod.rs @@ -78,6 +78,28 @@ impl IAsyncAction { self.GetResults() } } +impl windows_core::AsyncOperation for IAsyncAction { + type Output = (); + fn is_complete(&self) -> windows_core::Result { + Ok(self.Status()? != AsyncStatus::Started) + } + fn set_completed(&self, f: impl Fn() + Send + 'static) -> windows_core::Result<()> { + self.SetCompleted(&AsyncActionCompletedHandler::new(move |_sender, _args| { + f(); + Ok(()) + })) + } + fn get_results(&self) -> windows_core::Result { + self.GetResults() + } +} +impl std::future::IntoFuture for IAsyncAction { + type Output = windows_core::Result<()>; + type IntoFuture = windows_core::FutureWrapper; + fn into_future(self) -> Self::IntoFuture { + windows_core::FutureWrapper::new(self) + } +} unsafe impl Send for IAsyncAction {} unsafe impl Sync for IAsyncAction {} impl windows_core::RuntimeType for IAsyncAction { @@ -183,6 +205,28 @@ impl IAsyncActionWithProgress windows_core::AsyncOperation for IAsyncActionWithProgress { + type Output = (); + fn is_complete(&self) -> windows_core::Result { + Ok(self.Status()? != AsyncStatus::Started) + } + fn set_completed(&self, f: impl Fn() + Send + 'static) -> windows_core::Result<()> { + self.SetCompleted(&AsyncActionWithProgressCompletedHandler::new(move |_sender, _args| { + f(); + Ok(()) + })) + } + fn get_results(&self) -> windows_core::Result { + self.GetResults() + } +} +impl std::future::IntoFuture for IAsyncActionWithProgress { + type Output = windows_core::Result<()>; + type IntoFuture = windows_core::FutureWrapper>; + fn into_future(self) -> Self::IntoFuture { + windows_core::FutureWrapper::new(self) + } +} unsafe impl Send for IAsyncActionWithProgress {} unsafe impl Sync for IAsyncActionWithProgress {} impl windows_core::RuntimeType for IAsyncActionWithProgress { @@ -338,6 +382,28 @@ impl IAsyncOperation { self.GetResults() } } +impl windows_core::AsyncOperation for IAsyncOperation { + type Output = TResult; + fn is_complete(&self) -> windows_core::Result { + Ok(self.Status()? != AsyncStatus::Started) + } + fn set_completed(&self, f: impl Fn() + Send + 'static) -> windows_core::Result<()> { + self.SetCompleted(&AsyncOperationCompletedHandler::new(move |_sender, _args| { + f(); + Ok(()) + })) + } + fn get_results(&self) -> windows_core::Result { + self.GetResults() + } +} +impl std::future::IntoFuture for IAsyncOperation { + type Output = windows_core::Result; + type IntoFuture = windows_core::FutureWrapper>; + fn into_future(self) -> Self::IntoFuture { + windows_core::FutureWrapper::new(self) + } +} unsafe impl Send for IAsyncOperation {} unsafe impl Sync for IAsyncOperation {} impl windows_core::RuntimeType for IAsyncOperation { @@ -455,6 +521,28 @@ impl windows_core::AsyncOperation for IAsyncOperationWithProgress { + type Output = TResult; + fn is_complete(&self) -> windows_core::Result { + Ok(self.Status()? != AsyncStatus::Started) + } + fn set_completed(&self, f: impl Fn() + Send + 'static) -> windows_core::Result<()> { + self.SetCompleted(&AsyncOperationWithProgressCompletedHandler::new(move |_sender, _args| { + f(); + Ok(()) + })) + } + fn get_results(&self) -> windows_core::Result { + self.GetResults() + } +} +impl std::future::IntoFuture for IAsyncOperationWithProgress { + type Output = windows_core::Result; + type IntoFuture = windows_core::FutureWrapper>; + fn into_future(self) -> Self::IntoFuture { + windows_core::FutureWrapper::new(self) + } +} unsafe impl Send for IAsyncOperationWithProgress {} unsafe impl Sync for IAsyncOperationWithProgress {} impl windows_core::RuntimeType for IAsyncOperationWithProgress { diff --git a/crates/libs/windows/src/Windows/Security/Authentication/OnlineId/mod.rs b/crates/libs/windows/src/Windows/Security/Authentication/OnlineId/mod.rs index e82467d540..a5607a6c52 100644 --- a/crates/libs/windows/src/Windows/Security/Authentication/OnlineId/mod.rs +++ b/crates/libs/windows/src/Windows/Security/Authentication/OnlineId/mod.rs @@ -504,6 +504,28 @@ impl SignOutUserOperation { self.GetResults() } } +impl windows_core::AsyncOperation for SignOutUserOperation { + type Output = (); + fn is_complete(&self) -> windows_core::Result { + Ok(self.Status()? != super::super::super::Foundation::AsyncStatus::Started) + } + fn set_completed(&self, f: impl Fn() + Send + 'static) -> windows_core::Result<()> { + self.SetCompleted(&super::super::super::Foundation::AsyncActionCompletedHandler::new(move |_sender, _args| { + f(); + Ok(()) + })) + } + fn get_results(&self) -> windows_core::Result { + self.GetResults() + } +} +impl std::future::IntoFuture for SignOutUserOperation { + type Output = windows_core::Result<()>; + type IntoFuture = windows_core::FutureWrapper; + fn into_future(self) -> Self::IntoFuture { + windows_core::FutureWrapper::new(self) + } +} unsafe impl Send for SignOutUserOperation {} unsafe impl Sync for SignOutUserOperation {} #[repr(transparent)] @@ -587,6 +609,28 @@ impl UserAuthenticationOperation { self.GetResults() } } +impl windows_core::AsyncOperation for UserAuthenticationOperation { + type Output = UserIdentity; + fn is_complete(&self) -> windows_core::Result { + Ok(self.Status()? != super::super::super::Foundation::AsyncStatus::Started) + } + fn set_completed(&self, f: impl Fn() + Send + 'static) -> windows_core::Result<()> { + self.SetCompleted(&super::super::super::Foundation::AsyncOperationCompletedHandler::new(move |_sender, _args| { + f(); + Ok(()) + })) + } + fn get_results(&self) -> windows_core::Result { + self.GetResults() + } +} +impl std::future::IntoFuture for UserAuthenticationOperation { + type Output = windows_core::Result; + type IntoFuture = windows_core::FutureWrapper; + fn into_future(self) -> Self::IntoFuture { + windows_core::FutureWrapper::new(self) + } +} unsafe impl Send for UserAuthenticationOperation {} unsafe impl Sync for UserAuthenticationOperation {} #[repr(transparent)] diff --git a/crates/libs/windows/src/Windows/Storage/Streams/mod.rs b/crates/libs/windows/src/Windows/Storage/Streams/mod.rs index 51059b48e8..329b6bb094 100644 --- a/crates/libs/windows/src/Windows/Storage/Streams/mod.rs +++ b/crates/libs/windows/src/Windows/Storage/Streams/mod.rs @@ -1325,6 +1325,28 @@ impl DataReaderLoadOperation { self.GetResults() } } +impl windows_core::AsyncOperation for DataReaderLoadOperation { + type Output = u32; + fn is_complete(&self) -> windows_core::Result { + Ok(self.Status()? != super::super::Foundation::AsyncStatus::Started) + } + fn set_completed(&self, f: impl Fn() + Send + 'static) -> windows_core::Result<()> { + self.SetCompleted(&super::super::Foundation::AsyncOperationCompletedHandler::new(move |_sender, _args| { + f(); + Ok(()) + })) + } + fn get_results(&self) -> windows_core::Result { + self.GetResults() + } +} +impl std::future::IntoFuture for DataReaderLoadOperation { + type Output = windows_core::Result; + type IntoFuture = windows_core::FutureWrapper; + fn into_future(self) -> Self::IntoFuture { + windows_core::FutureWrapper::new(self) + } +} unsafe impl Send for DataReaderLoadOperation {} unsafe impl Sync for DataReaderLoadOperation {} #[repr(transparent)] @@ -1593,6 +1615,28 @@ impl DataWriterStoreOperation { self.GetResults() } } +impl windows_core::AsyncOperation for DataWriterStoreOperation { + type Output = u32; + fn is_complete(&self) -> windows_core::Result { + Ok(self.Status()? != super::super::Foundation::AsyncStatus::Started) + } + fn set_completed(&self, f: impl Fn() + Send + 'static) -> windows_core::Result<()> { + self.SetCompleted(&super::super::Foundation::AsyncOperationCompletedHandler::new(move |_sender, _args| { + f(); + Ok(()) + })) + } + fn get_results(&self) -> windows_core::Result { + self.GetResults() + } +} +impl std::future::IntoFuture for DataWriterStoreOperation { + type Output = windows_core::Result; + type IntoFuture = windows_core::FutureWrapper; + fn into_future(self) -> Self::IntoFuture { + windows_core::FutureWrapper::new(self) + } +} unsafe impl Send for DataWriterStoreOperation {} unsafe impl Sync for DataWriterStoreOperation {} #[repr(transparent)] diff --git a/crates/samples/windows/ocr/Cargo.toml b/crates/samples/windows/ocr/Cargo.toml index 377ff0e03a..e75742439e 100644 --- a/crates/samples/windows/ocr/Cargo.toml +++ b/crates/samples/windows/ocr/Cargo.toml @@ -4,6 +4,9 @@ version = "0.0.0" edition = "2021" publish = false +[dependencies] +futures = "0.3.5" + [dependencies.windows] path = "../../../libs/windows" features = [ diff --git a/crates/samples/windows/ocr/src/main.rs b/crates/samples/windows/ocr/src/main.rs index 4dc85d8d06..7672ece94d 100644 --- a/crates/samples/windows/ocr/src/main.rs +++ b/crates/samples/windows/ocr/src/main.rs @@ -6,18 +6,22 @@ use windows::{ }; fn main() -> Result<()> { + futures::executor::block_on(main_async()) +} + +async fn main_async() -> Result<()> { let mut message = std::env::current_dir().unwrap(); message.push("message.png"); let file = - StorageFile::GetFileFromPathAsync(&HSTRING::from(message.to_str().unwrap()))?.get()?; - let stream = file.OpenAsync(FileAccessMode::Read)?.get()?; + StorageFile::GetFileFromPathAsync(&HSTRING::from(message.to_str().unwrap()))?.await?; + let stream = file.OpenAsync(FileAccessMode::Read)?.await?; - let decode = BitmapDecoder::CreateAsync(&stream)?.get()?; - let bitmap = decode.GetSoftwareBitmapAsync()?.get()?; + let decode = BitmapDecoder::CreateAsync(&stream)?.await?; + let bitmap = decode.GetSoftwareBitmapAsync()?.await?; let engine = OcrEngine::TryCreateFromUserProfileLanguages()?; - let result = engine.RecognizeAsync(&bitmap)?.get()?; + let result = engine.RecognizeAsync(&bitmap)?.await?; println!("{}", result.Text()?); Ok(()) diff --git a/crates/tests/winrt/Cargo.toml b/crates/tests/winrt/Cargo.toml index 10c6bf801f..fa09f0050f 100644 --- a/crates/tests/winrt/Cargo.toml +++ b/crates/tests/winrt/Cargo.toml @@ -29,4 +29,5 @@ features = [ ] [dev-dependencies] +futures = "0.3" helpers = { package = "test_helpers", path = "../helpers" } diff --git a/crates/tests/winrt/tests/async.rs b/crates/tests/winrt/tests/async.rs index 8ce9a224ae..44a843fd5c 100644 --- a/crates/tests/winrt/tests/async.rs +++ b/crates/tests/winrt/tests/async.rs @@ -23,3 +23,33 @@ fn async_get() -> windows::core::Result<()> { Ok(()) } + +async fn async_await() -> windows::core::Result<()> { + use windows::Storage::Streams::*; + + let stream = &InMemoryRandomAccessStream::new()?; + + let writer = DataWriter::CreateDataWriter(stream)?; + writer.WriteByte(1)?; + writer.WriteByte(2)?; + writer.WriteByte(3)?; + writer.StoreAsync()?.await?; + + stream.Seek(0)?; + let reader = DataReader::CreateDataReader(stream)?; + reader.LoadAsync(3)?.await?; + + let mut bytes: [u8; 3] = [0; 3]; + reader.ReadBytes(&mut bytes)?; + + assert!(bytes[0] == 1); + assert!(bytes[1] == 2); + assert!(bytes[2] == 3); + + Ok(()) +} + +#[test] +fn test_async_await() -> windows::core::Result<()> { + futures::executor::block_on(async_await()) +}