Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Support richer COM interface hierarchies #1608

Merged
merged 5 commits into from
Mar 17, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
67 changes: 46 additions & 21 deletions crates/libs/interface/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -54,10 +54,10 @@ macro_rules! expected_token {
/// }
/// ```
struct Interface {
pub visibility: syn::Visibility,
pub name: syn::Ident,
pub parent: Option<syn::Path>,
pub methods: Vec<InterfaceMethod>,
visibility: syn::Visibility,
name: syn::Ident,
parent: syn::Path,
methods: Vec<InterfaceMethod>,
docs: Vec<syn::Attribute>,
}

Expand All @@ -67,8 +67,8 @@ impl Interface {
let vis = &self.visibility;
let name = &self.name;
let docs = &self.docs;
let parent = self.parent.as_ref().map(|p| quote!(#p)).unwrap_or_else(|| quote!(::windows::core::IUnknown));
let vtable_name = quote::format_ident!("{}_Vtbl", name);
let parent = self.parent();
let vtable_name = quote::format_ident!("{}Vtbl", name);
kennykerr marked this conversation as resolved.
Show resolved Hide resolved
let guid = guid.to_tokens()?;
let implementation = self.gen_implementation();
let com_trait = self.get_com_trait();
Expand Down Expand Up @@ -143,9 +143,11 @@ impl Interface {
}
})
.collect::<Vec<_>>();
let parent = self.parent_trait_constraint();

quote! {
#[allow(non_camel_case_types)]
#vis trait #name: Sized {
#vis trait #name: #parent Sized {
#(#methods)*
}
}
Expand All @@ -154,8 +156,8 @@ impl Interface {
/// Generates the vtable for a COM interface
fn gen_vtable(&self, vtable_name: &syn::Ident) -> proc_macro2::TokenStream {
let name = &self.name;
// TODO
let parent_vtable = quote!(::windows::core::IUnknownVtbl);
let parent_vtable = self.parent_vtable();
let parent_vtable_generics = if self.parent_is_iunknown() { quote!(Identity, OFFSET) } else { quote!(Identity, Impl, OFFSET) };
let vtable_entries = self
.methods
.iter()
Expand Down Expand Up @@ -207,15 +209,14 @@ impl Interface {
#[repr(C)]
#[doc(hidden)]
pub struct #vtable_name {
// TODO: handle non-IUnknown parents
pub base: ::windows::core::IUnknownVtbl,
pub base: #parent_vtable,
#(#vtable_entries)*
}

impl #vtable_name {
pub const fn new<Identity: ::windows::core::IUnknownImpl, Impl: #trait_name, const OFFSET: isize>() -> Self {
#(#functions)*
Self { base: #parent_vtable::new::<Identity, OFFSET>(), #(#entries),* }
Self { base: #parent_vtable::new::<#parent_vtable_generics>(), #(#entries),* }
}

pub fn matches(iid: &windows::core::GUID) -> bool {
Expand All @@ -232,8 +233,7 @@ impl Interface {
quote! {
impl ::core::convert::From<#name> for ::windows::core::IUnknown {
fn from(value: #name) -> Self {
// TODO: handle when direct parent is not IUnknown
value.0
unsafe { ::core::mem::transmute(value) }
}
}
impl ::core::convert::From<&#name> for ::windows::core::IUnknown {
Expand All @@ -248,8 +248,7 @@ impl Interface {
}
impl<'a> ::windows::core::IntoParam<'a, ::windows::core::IUnknown> for &'a #name {
fn into_param(self) -> ::windows::core::Param<'a, ::windows::core::IUnknown> {
// TODO: handle when direct parent is not IUnknown
::windows::core::Param::Borrowed(&self.0)
::windows::core::Param::Borrowed(unsafe { ::core::mem::transmute(self) })
}
}
impl ::core::clone::Clone for #name {
Expand All @@ -270,6 +269,35 @@ impl Interface {
}
}
}

fn parent(&self) -> proc_macro2::TokenStream {
let p = &self.parent;
quote!(#p)
}

fn parent_vtable(&self) -> proc_macro2::TokenStream {
let i = self.parent_ident();
let i = quote::format_ident!("{}Vtbl", i);
quote!(#i)
}

fn parent_is_iunknown(&self) -> bool {
self.parent.is_ident("IUnknown")
}

fn parent_ident(&self) -> &syn::Ident {
&self.parent.segments.last().as_ref().expect("segements should never be empty").ident
}

/// Gets the parent trait constrait which is nothing if the parent is IUnknown
fn parent_trait_constraint(&self) -> proc_macro2::TokenStream {
let i = self.parent_ident();
if i == "IUnknown" {
return quote!();
}
let i = quote::format_ident!("{}_Impl", i);
quote!(#i +)
}
}

impl Parse for Interface {
Expand All @@ -289,11 +317,8 @@ impl Parse for Interface {
let _ = input.parse::<syn::Token![unsafe]>()?;
let _ = input.parse::<syn::Token![trait]>()?;
let name = input.parse::<syn::Ident>()?;
let mut parent = None;
if name != "IUnknown" {
let _ = input.parse::<syn::Token![:]>().map_err(|_| syn::Error::new(name.span(), format!("Interfaces must inherit from another interface like so: `interface {}: IParentInterface`", name)))?;
parent = Some(input.parse::<syn::Path>()?);
}
let _ = input.parse::<syn::Token![:]>().map_err(|_| syn::Error::new(name.span(), format!("Interfaces must inherit from another interface like so: `interface {}: IParentInterface`", name)))?;
let parent = input.parse::<syn::Path>()?;
let content;
syn::braced!(content in input);
let mut methods = Vec::new();
Expand Down
2 changes: 1 addition & 1 deletion crates/libs/windows/src/core/interface.rs
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ pub unsafe trait Interface: Sized {
#[doc(hidden)]
unsafe fn assume_vtable<T: Interface>(&self) -> &T::Vtable {
let this: RawPtr = core::mem::transmute_copy(self);
&(*(*(this as *mut *mut _) as *mut _))
&**(this as *mut *mut T::Vtable)
}

#[doc(hidden)]
Expand Down
166 changes: 106 additions & 60 deletions crates/tests/nightly_interface/tests/com.rs
Original file line number Diff line number Diff line change
Expand Up @@ -16,58 +16,99 @@ pub unsafe trait ICustomUri: IUnknown {
// etc
}

#[implement(ICustomUri)]
struct CustomUri;

impl ICustomUri_Impl for CustomUri {
unsafe fn GetPropertyBSTR(&self, property: Uri_PROPERTY, value: *mut BSTR, flags: u32) -> HRESULT {
assert!(flags == 0);
assert!(property == Uri_PROPERTY_DOMAIN);
*value = "property".into();
S_OK
}
unsafe fn GetPropertyLength(&self) -> HRESULT {
todo!()
/// A custom declaration of implementation of `IPersist`
#[interface("0000010c-0000-0000-C000-000000000046")]
pub unsafe trait ICustomPersist: IUnknown {
unsafe fn GetClassID(&self, clsid: *mut GUID) -> HRESULT;
}

/// A custom declaration of implementation of `IPersistMemory`
#[interface("BD1AE5E0-A6AE-11CE-BD37-504200C10000")]
pub unsafe trait ICustomPersistMemory: ICustomPersist {
unsafe fn IsDirty(&self) -> HRESULT;
unsafe fn Load(&self, input: *const core::ffi::c_void, size: u32) -> HRESULT;
unsafe fn Save(&self, output: *mut core::ffi::c_void, clear_dirty: BOOL, size: u32) -> HRESULT;
unsafe fn GetSizeMax(&self, len: *mut u32) -> HRESULT;
unsafe fn InitNew(&self) -> HRESULT;
}

/// A custom in-memory store
#[implement(ICustomPersistMemory, ICustomPersist)]
#[derive(Default)]
struct Persist(std::sync::RwLock<PersistState>);

impl Persist {
fn new() -> Self {
Self(std::sync::RwLock::new(PersistState::default()))
}
unsafe fn GetPropertyDWORD(&self, property: Uri_PROPERTY, value: *mut u32, flags: u32) -> HRESULT {
assert!(flags == 0);
assert!(property == Uri_PROPERTY_PORT);
*value = 123;
}

#[derive(Default)]
struct PersistState {
memory: [u8; 10],
dirty: bool,
}

impl ICustomPersist_Impl for Persist {
unsafe fn GetClassID(&self, clsid: *mut GUID) -> HRESULT {
*clsid = "117fb826-2155-483a-b50d-bc99a2c7cca3".into();
S_OK
}
unsafe fn HasProperty(&self) {
todo!()
}

impl ICustomPersistMemory_Impl for Persist {
unsafe fn IsDirty(&self) -> HRESULT {
let reader = self.0.read().unwrap();
if reader.dirty {
S_OK
} else {
S_FALSE
}
}
unsafe fn GetAbsoluteUri(&self) -> HRESULT {
todo!()

unsafe fn Load(&self, input: *const core::ffi::c_void, size: u32) -> HRESULT {
let mut writer = self.0.write().unwrap();
if size <= writer.memory.len() as _ {
std::ptr::copy(input, writer.memory.as_mut_ptr() as _, size as _);
writer.dirty = true;
S_OK
} else {
E_OUTOFMEMORY
}
}
unsafe fn GetAuthority(&self) -> HRESULT {
todo!()

unsafe fn Save(&self, output: *mut core::ffi::c_void, clear_dirty: BOOL, size: u32) -> HRESULT {
let mut writer = self.0.write().unwrap();
if size <= writer.memory.len() as _ {
std::ptr::copy(writer.memory.as_mut_ptr() as _, output, size as _);
if clear_dirty.as_bool() {
writer.dirty = false;
}
S_OK
} else {
E_OUTOFMEMORY
}
}
unsafe fn GetDisplayUri(&self) -> i32 {
todo!()

unsafe fn GetSizeMax(&self, len: *mut u32) -> HRESULT {
let reader = self.0.read().unwrap();
*len = reader.memory.len() as _;
S_OK
}
unsafe fn GetDomain(&self, value: *mut BSTR) -> HRESULT {
*value = "kennykerr.ca".into();

unsafe fn InitNew(&self) -> HRESULT {
let mut writer = self.0.write().unwrap();
writer.memory = Default::default();
writer.dirty = false;
S_OK
}
}

#[test]
fn test_custom_interface() -> windows::core::Result<()> {
unsafe {
// Use the OS implementation through the OS interface
// Use the OS implementation of Uri through the custom `ICustomUri` interface
let a: IUri = CreateUri("http://kennykerr.ca", Default::default(), 0)?;
let domain = a.GetDomain()?;
assert_eq!(domain, "kennykerr.ca");
let mut property = BSTR::new();
a.GetPropertyBSTR(Uri_PROPERTY_DOMAIN, &mut property, 0)?;
assert_eq!(property, "kennykerr.ca");
let mut property = 0;
a.GetPropertyDWORD(Uri_PROPERTY_PORT, &mut property, 0)?;
assert_eq!(property, 80);

// Call the OS implementation through the custom interface
let b: ICustomUri = a.cast()?;
let mut domain = BSTR::new();
b.GetDomain(&mut domain).ok()?;
Expand All @@ -79,30 +120,35 @@ fn test_custom_interface() -> windows::core::Result<()> {
a.GetPropertyDWORD(Uri_PROPERTY_PORT, &mut property, 0)?;
assert_eq!(property, 80);

// Use the custom implementation through the OS interface
let c: ICustomUri = CustomUri.into();
// This works because `ICustomUri` and `IUri` share the same guid
let c: IUri = c.cast()?;
let domain = c.GetDomain()?;
assert_eq!(domain, "kennykerr.ca");
let mut property = BSTR::new();
c.GetPropertyBSTR(Uri_PROPERTY_DOMAIN, &mut property, 0)?;
assert_eq!(property, "property");
let mut property = 0;
c.GetPropertyDWORD(Uri_PROPERTY_PORT, &mut property, 0)?;
assert_eq!(property, 123);
// Use the custom implementation of `Persist` through the OS `IPersistMemory` interface
let p: ICustomPersistMemory = Persist::new().into();
// This works because `ICustomPersistMemory` and `IPersistMemory` share the same guid
let p: IPersistMemory = p.cast()?;
assert_eq!(p.GetClassID()?, "117fb826-2155-483a-b50d-bc99a2c7cca3".into());
// TODO: can't test IsDirty until this is fixed: https://github.com/microsoft/win32metadata/issues/838
assert_eq!(p.GetSizeMax()?, 10);
p.Load(&[0xAAu8, 0xBB, 0xCC])?;
let mut memory = [0x00u8, 0x00, 0x00, 0x00];
p.Save(&mut memory, true)?;
assert_eq!(memory, [0xAAu8, 0xBB, 0xCC, 0x00]);

// Call the custom implementation through the custom interface
let d: ICustomUri = c.cast()?;
let mut domain = BSTR::new();
d.GetDomain(&mut domain).ok()?;
assert_eq!(domain, "kennykerr.ca");
let mut property = BSTR::new();
d.GetPropertyBSTR(Uri_PROPERTY_DOMAIN, &mut property, 0).ok()?;
assert_eq!(property, "property");
let mut property = 0;
d.GetPropertyDWORD(Uri_PROPERTY_PORT, &mut property, 0).ok()?;
assert_eq!(property, 123);
// Use the custom implementation of `Persist` through the custom interface of `ICustomPersist`
let p: ICustomPersistMemory = p.cast()?;
let mut size = 0;
p.GetSizeMax(&mut size).ok()?;
assert_eq!(size, 10);
assert_eq!(p.IsDirty(), S_FALSE);
p.Load(&[0xAAu8, 0xBB, 0xCC] as *const _ as *const _, 3).ok()?;
assert_eq!(p.IsDirty(), S_OK);
let mut memory = [0x00u8, 0x00, 0x00, 0x00];
p.Save(&mut memory as *mut _ as *mut _, true.into(), 4).ok()?;
assert_eq!(p.IsDirty(), S_FALSE);
assert_eq!(memory, [0xAAu8, 0xBB, 0xCC, 0x00]);

let p: ICustomPersist = p.cast()?;
let mut b = GUID::default();
p.GetClassID(&mut b).ok()?;
assert_eq!(b, "117fb826-2155-483a-b50d-bc99a2c7cca3".into());

Ok(())
}
Expand Down