Skip to content

Commit

Permalink
add IntoPyObjectRef derive macro (#4672)
Browse files Browse the repository at this point in the history
  • Loading branch information
Icxolu authored Nov 4, 2024
1 parent 55c9543 commit 00d84d8
Show file tree
Hide file tree
Showing 8 changed files with 200 additions and 69 deletions.
3 changes: 3 additions & 0 deletions guide/src/conversions/traits.md
Original file line number Diff line number Diff line change
Expand Up @@ -559,6 +559,9 @@ enum Enum<'a, 'py, K: Hash + Eq, V> { // enums are supported and convert using t
}
```

Additionally `IntoPyObject` can be derived for a reference to a struct or enum using the
`IntoPyObjectRef` derive macro. All the same rules from above apply as well.

#### manual implementation

If the derive macro is not suitable for your use case, `IntoPyObject` can be implemented manually as
Expand Down
94 changes: 66 additions & 28 deletions pyo3-macros-backend/src/intopyobject.rs
Original file line number Diff line number Diff line change
Expand Up @@ -164,10 +164,17 @@ impl FieldAttributes {
}
}

enum IntoPyObjectTypes {
Transparent(syn::Type),
Opaque {
target: TokenStream,
output: TokenStream,
error: TokenStream,
},
}

struct IntoPyObjectImpl {
target: TokenStream,
output: TokenStream,
error: TokenStream,
types: IntoPyObjectTypes,
body: TokenStream,
}

Expand Down Expand Up @@ -351,12 +358,10 @@ impl<'a> Container<'a> {
.unwrap_or_default();

IntoPyObjectImpl {
target: quote! {<#ty as #pyo3_path::conversion::IntoPyObject<'py>>::Target},
output: quote! {<#ty as #pyo3_path::conversion::IntoPyObject<'py>>::Output},
error: quote! {<#ty as #pyo3_path::conversion::IntoPyObject<'py>>::Error},
types: IntoPyObjectTypes::Transparent(ty.clone()),
body: quote_spanned! { ty.span() =>
#unpack
<#ty as #pyo3_path::conversion::IntoPyObject<'py>>::into_pyobject(arg0, py)
#pyo3_path::conversion::IntoPyObject::into_pyobject(arg0, py)
},
}
}
Expand Down Expand Up @@ -391,9 +396,11 @@ impl<'a> Container<'a> {
.collect::<TokenStream>();

IntoPyObjectImpl {
target: quote!(#pyo3_path::types::PyDict),
output: quote!(#pyo3_path::Bound<'py, Self::Target>),
error: quote!(#pyo3_path::PyErr),
types: IntoPyObjectTypes::Opaque {
target: quote!(#pyo3_path::types::PyDict),
output: quote!(#pyo3_path::Bound<'py, Self::Target>),
error: quote!(#pyo3_path::PyErr),
},
body: quote! {
#unpack
let dict = #pyo3_path::types::PyDict::new(py);
Expand All @@ -419,20 +426,21 @@ impl<'a> Container<'a> {
.iter()
.enumerate()
.map(|(i, f)| {
let ty = &f.field.ty;
let value = Ident::new(&format!("arg{i}"), f.field.ty.span());
quote_spanned! { f.field.ty.span() =>
<#ty as #pyo3_path::conversion::IntoPyObject>::into_pyobject(#value, py)
#pyo3_path::conversion::IntoPyObject::into_pyobject(#value, py)
.map(#pyo3_path::BoundObject::into_any)
.map(#pyo3_path::BoundObject::into_bound)?,
}
})
.collect::<TokenStream>();

IntoPyObjectImpl {
target: quote!(#pyo3_path::types::PyTuple),
output: quote!(#pyo3_path::Bound<'py, Self::Target>),
error: quote!(#pyo3_path::PyErr),
types: IntoPyObjectTypes::Opaque {
target: quote!(#pyo3_path::types::PyTuple),
output: quote!(#pyo3_path::Bound<'py, Self::Target>),
error: quote!(#pyo3_path::PyErr),
},
body: quote! {
#unpack
#pyo3_path::types::PyTuple::new(py, [#setter])
Expand Down Expand Up @@ -502,9 +510,11 @@ impl<'a> Enum<'a> {
.collect::<TokenStream>();

IntoPyObjectImpl {
target: quote!(#pyo3_path::types::PyAny),
output: quote!(#pyo3_path::Bound<'py, Self::Target>),
error: quote!(#pyo3_path::PyErr),
types: IntoPyObjectTypes::Opaque {
target: quote!(#pyo3_path::types::PyAny),
output: quote!(#pyo3_path::Bound<'py, Self::Target>),
error: quote!(#pyo3_path::PyErr),
},
body: quote! {
match self {
#variants
Expand All @@ -520,13 +530,16 @@ fn verify_and_get_lifetime(generics: &syn::Generics) -> Option<&syn::LifetimePar
lifetimes.find(|l| l.lifetime.ident == "py")
}

pub fn build_derive_into_pyobject(tokens: &DeriveInput) -> Result<TokenStream> {
pub fn build_derive_into_pyobject<const REF: bool>(tokens: &DeriveInput) -> Result<TokenStream> {
let options = ContainerOptions::from_attrs(&tokens.attrs)?;
let ctx = &Ctx::new(&options.krate, None);
let Ctx { pyo3_path, .. } = &ctx;

let (_, ty_generics, _) = tokens.generics.split_for_impl();
let mut trait_generics = tokens.generics.clone();
if REF {
trait_generics.params.push(parse_quote!('_a));
}
let lt_param = if let Some(lt) = verify_and_get_lifetime(&trait_generics) {
lt.clone()
} else {
Expand All @@ -538,17 +551,14 @@ pub fn build_derive_into_pyobject(tokens: &DeriveInput) -> Result<TokenStream> {
let mut where_clause = where_clause.cloned().unwrap_or_else(|| parse_quote!(where));
for param in trait_generics.type_params() {
let gen_ident = &param.ident;
where_clause
.predicates
.push(parse_quote!(#gen_ident: #pyo3_path::conversion::IntoPyObject<'py>))
where_clause.predicates.push(if REF {
parse_quote!(&'_a #gen_ident: #pyo3_path::conversion::IntoPyObject<'py>)
} else {
parse_quote!(#gen_ident: #pyo3_path::conversion::IntoPyObject<'py>)
})
}

let IntoPyObjectImpl {
target,
output,
error,
body,
} = match &tokens.data {
let IntoPyObjectImpl { types, body } = match &tokens.data {
syn::Data::Enum(en) => {
if options.transparent.is_some() {
bail_spanned!(tokens.span() => "`transparent` is not supported at top level for enums");
Expand All @@ -571,7 +581,35 @@ pub fn build_derive_into_pyobject(tokens: &DeriveInput) -> Result<TokenStream> {
),
};

let (target, output, error) = match types {
IntoPyObjectTypes::Transparent(ty) => {
if REF {
(
quote! { <&'_a #ty as #pyo3_path::IntoPyObject<'py>>::Target },
quote! { <&'_a #ty as #pyo3_path::IntoPyObject<'py>>::Output },
quote! { <&'_a #ty as #pyo3_path::IntoPyObject<'py>>::Error },
)
} else {
(
quote! { <#ty as #pyo3_path::IntoPyObject<'py>>::Target },
quote! { <#ty as #pyo3_path::IntoPyObject<'py>>::Output },
quote! { <#ty as #pyo3_path::IntoPyObject<'py>>::Error },
)
}
}
IntoPyObjectTypes::Opaque {
target,
output,
error,
} => (target, output, error),
};

let ident = &tokens.ident;
let ident = if REF {
quote! { &'_a #ident}
} else {
quote! { #ident }
};
Ok(quote!(
#[automatically_derived]
impl #impl_generics #pyo3_path::conversion::IntoPyObject<#lt_param> for #ident #ty_generics #where_clause {
Expand Down
13 changes: 12 additions & 1 deletion pyo3-macros/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -156,7 +156,18 @@ pub fn pyfunction(attr: TokenStream, input: TokenStream) -> TokenStream {
#[proc_macro_derive(IntoPyObject, attributes(pyo3))]
pub fn derive_into_py_object(item: TokenStream) -> TokenStream {
let ast = parse_macro_input!(item as syn::DeriveInput);
let expanded = build_derive_into_pyobject(&ast).unwrap_or_compile_error();
let expanded = build_derive_into_pyobject::<false>(&ast).unwrap_or_compile_error();
quote!(
#expanded
)
.into()
}

#[proc_macro_derive(IntoPyObjectRef, attributes(pyo3))]
pub fn derive_into_py_object_ref(item: TokenStream) -> TokenStream {
let ast = parse_macro_input!(item as syn::DeriveInput);
let expanded =
pyo3_macros_backend::build_derive_into_pyobject::<true>(&ast).unwrap_or_compile_error();
quote!(
#expanded
)
Expand Down
4 changes: 3 additions & 1 deletion src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -457,7 +457,9 @@ mod version;
pub use crate::conversions::*;

#[cfg(feature = "macros")]
pub use pyo3_macros::{pyfunction, pymethods, pymodule, FromPyObject, IntoPyObject};
pub use pyo3_macros::{
pyfunction, pymethods, pymodule, FromPyObject, IntoPyObject, IntoPyObjectRef,
};

/// A proc macro used to expose Rust structs and fieldless enums as Python objects.
///
Expand Down
4 changes: 3 additions & 1 deletion src/prelude.rs
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,9 @@ pub use crate::pyclass_init::PyClassInitializer;
pub use crate::types::{PyAny, PyModule};

#[cfg(feature = "macros")]
pub use pyo3_macros::{pyclass, pyfunction, pymethods, pymodule, FromPyObject, IntoPyObject};
pub use pyo3_macros::{
pyclass, pyfunction, pymethods, pymodule, FromPyObject, IntoPyObject, IntoPyObjectRef,
};

#[cfg(feature = "macros")]
pub use crate::wrap_pyfunction;
Expand Down
8 changes: 4 additions & 4 deletions src/tests/hygiene/misc.rs
Original file line number Diff line number Diff line change
Expand Up @@ -57,17 +57,17 @@ macro_rules! macro_rules_hygiene {

macro_rules_hygiene!(MyClass1, MyClass2);

#[derive(crate::IntoPyObject)]
#[derive(crate::IntoPyObject, crate::IntoPyObjectRef)]
#[pyo3(crate = "crate")]
struct IntoPyObject1(i32); // transparent newtype case

#[derive(crate::IntoPyObject)]
#[derive(crate::IntoPyObject, crate::IntoPyObjectRef)]
#[pyo3(crate = "crate", transparent)]
struct IntoPyObject2<'a> {
inner: &'a str, // transparent newtype case
}

#[derive(crate::IntoPyObject)]
#[derive(crate::IntoPyObject, crate::IntoPyObjectRef)]
#[pyo3(crate = "crate")]
struct IntoPyObject3<'py>(i32, crate::Bound<'py, crate::PyAny>); // tuple case

Expand All @@ -78,7 +78,7 @@ struct IntoPyObject4<'a, 'py> {
num: usize,
}

#[derive(crate::IntoPyObject)]
#[derive(crate::IntoPyObject, crate::IntoPyObjectRef)]
#[pyo3(crate = "crate")]
enum IntoPyObject5<'a, 'py> {
TransparentTuple(i32),
Expand Down
Loading

0 comments on commit 00d84d8

Please sign in to comment.