From 00d84d888bd01831f3419dfa3f086758b6d825ca Mon Sep 17 00:00:00 2001 From: Icxolu <10486322+Icxolu@users.noreply.github.com> Date: Mon, 4 Nov 2024 08:10:44 +0100 Subject: [PATCH] add `IntoPyObjectRef` derive macro (#4672) --- guide/src/conversions/traits.md | 3 + pyo3-macros-backend/src/intopyobject.rs | 94 +++++++++++++++------- pyo3-macros/src/lib.rs | 13 ++- src/lib.rs | 4 +- src/prelude.rs | 4 +- src/tests/hygiene/misc.rs | 8 +- tests/test_frompy_intopy_roundtrip.rs | 101 +++++++++++++++++++++--- tests/ui/invalid_intopy_derive.rs | 42 +++++----- 8 files changed, 200 insertions(+), 69 deletions(-) diff --git a/guide/src/conversions/traits.md b/guide/src/conversions/traits.md index 6cc809e0d03..a0e6ec6db0e 100644 --- a/guide/src/conversions/traits.md +++ b/guide/src/conversions/traits.md @@ -559,6 +559,9 @@ enum Enum<'a, 'py, K: Hash + Eq, V> { // enums are supported and convert using t } ``` +Additionally `IntoPyObject` can be derived for a reference to a struct or enum using the +`IntoPyObjectRef` derive macro. All the same rules from above apply as well. + #### manual implementation If the derive macro is not suitable for your use case, `IntoPyObject` can be implemented manually as diff --git a/pyo3-macros-backend/src/intopyobject.rs b/pyo3-macros-backend/src/intopyobject.rs index 3b4b2d376bb..4a46c07418f 100644 --- a/pyo3-macros-backend/src/intopyobject.rs +++ b/pyo3-macros-backend/src/intopyobject.rs @@ -164,10 +164,17 @@ impl FieldAttributes { } } +enum IntoPyObjectTypes { + Transparent(syn::Type), + Opaque { + target: TokenStream, + output: TokenStream, + error: TokenStream, + }, +} + struct IntoPyObjectImpl { - target: TokenStream, - output: TokenStream, - error: TokenStream, + types: IntoPyObjectTypes, body: TokenStream, } @@ -351,12 +358,10 @@ impl<'a> Container<'a> { .unwrap_or_default(); IntoPyObjectImpl { - target: quote! {<#ty as #pyo3_path::conversion::IntoPyObject<'py>>::Target}, - output: quote! {<#ty as #pyo3_path::conversion::IntoPyObject<'py>>::Output}, - error: quote! {<#ty as #pyo3_path::conversion::IntoPyObject<'py>>::Error}, + types: IntoPyObjectTypes::Transparent(ty.clone()), body: quote_spanned! { ty.span() => #unpack - <#ty as #pyo3_path::conversion::IntoPyObject<'py>>::into_pyobject(arg0, py) + #pyo3_path::conversion::IntoPyObject::into_pyobject(arg0, py) }, } } @@ -391,9 +396,11 @@ impl<'a> Container<'a> { .collect::(); IntoPyObjectImpl { - target: quote!(#pyo3_path::types::PyDict), - output: quote!(#pyo3_path::Bound<'py, Self::Target>), - error: quote!(#pyo3_path::PyErr), + types: IntoPyObjectTypes::Opaque { + target: quote!(#pyo3_path::types::PyDict), + output: quote!(#pyo3_path::Bound<'py, Self::Target>), + error: quote!(#pyo3_path::PyErr), + }, body: quote! { #unpack let dict = #pyo3_path::types::PyDict::new(py); @@ -419,10 +426,9 @@ impl<'a> Container<'a> { .iter() .enumerate() .map(|(i, f)| { - let ty = &f.field.ty; let value = Ident::new(&format!("arg{i}"), f.field.ty.span()); quote_spanned! { f.field.ty.span() => - <#ty as #pyo3_path::conversion::IntoPyObject>::into_pyobject(#value, py) + #pyo3_path::conversion::IntoPyObject::into_pyobject(#value, py) .map(#pyo3_path::BoundObject::into_any) .map(#pyo3_path::BoundObject::into_bound)?, } @@ -430,9 +436,11 @@ impl<'a> Container<'a> { .collect::(); IntoPyObjectImpl { - target: quote!(#pyo3_path::types::PyTuple), - output: quote!(#pyo3_path::Bound<'py, Self::Target>), - error: quote!(#pyo3_path::PyErr), + types: IntoPyObjectTypes::Opaque { + target: quote!(#pyo3_path::types::PyTuple), + output: quote!(#pyo3_path::Bound<'py, Self::Target>), + error: quote!(#pyo3_path::PyErr), + }, body: quote! { #unpack #pyo3_path::types::PyTuple::new(py, [#setter]) @@ -502,9 +510,11 @@ impl<'a> Enum<'a> { .collect::(); IntoPyObjectImpl { - target: quote!(#pyo3_path::types::PyAny), - output: quote!(#pyo3_path::Bound<'py, Self::Target>), - error: quote!(#pyo3_path::PyErr), + types: IntoPyObjectTypes::Opaque { + target: quote!(#pyo3_path::types::PyAny), + output: quote!(#pyo3_path::Bound<'py, Self::Target>), + error: quote!(#pyo3_path::PyErr), + }, body: quote! { match self { #variants @@ -520,13 +530,16 @@ fn verify_and_get_lifetime(generics: &syn::Generics) -> Option<&syn::LifetimePar lifetimes.find(|l| l.lifetime.ident == "py") } -pub fn build_derive_into_pyobject(tokens: &DeriveInput) -> Result { +pub fn build_derive_into_pyobject(tokens: &DeriveInput) -> Result { let options = ContainerOptions::from_attrs(&tokens.attrs)?; let ctx = &Ctx::new(&options.krate, None); let Ctx { pyo3_path, .. } = &ctx; let (_, ty_generics, _) = tokens.generics.split_for_impl(); let mut trait_generics = tokens.generics.clone(); + if REF { + trait_generics.params.push(parse_quote!('_a)); + } let lt_param = if let Some(lt) = verify_and_get_lifetime(&trait_generics) { lt.clone() } else { @@ -538,17 +551,14 @@ pub fn build_derive_into_pyobject(tokens: &DeriveInput) -> Result { let mut where_clause = where_clause.cloned().unwrap_or_else(|| parse_quote!(where)); for param in trait_generics.type_params() { let gen_ident = ¶m.ident; - where_clause - .predicates - .push(parse_quote!(#gen_ident: #pyo3_path::conversion::IntoPyObject<'py>)) + where_clause.predicates.push(if REF { + parse_quote!(&'_a #gen_ident: #pyo3_path::conversion::IntoPyObject<'py>) + } else { + parse_quote!(#gen_ident: #pyo3_path::conversion::IntoPyObject<'py>) + }) } - let IntoPyObjectImpl { - target, - output, - error, - body, - } = match &tokens.data { + let IntoPyObjectImpl { types, body } = match &tokens.data { syn::Data::Enum(en) => { if options.transparent.is_some() { bail_spanned!(tokens.span() => "`transparent` is not supported at top level for enums"); @@ -571,7 +581,35 @@ pub fn build_derive_into_pyobject(tokens: &DeriveInput) -> Result { ), }; + let (target, output, error) = match types { + IntoPyObjectTypes::Transparent(ty) => { + if REF { + ( + quote! { <&'_a #ty as #pyo3_path::IntoPyObject<'py>>::Target }, + quote! { <&'_a #ty as #pyo3_path::IntoPyObject<'py>>::Output }, + quote! { <&'_a #ty as #pyo3_path::IntoPyObject<'py>>::Error }, + ) + } else { + ( + quote! { <#ty as #pyo3_path::IntoPyObject<'py>>::Target }, + quote! { <#ty as #pyo3_path::IntoPyObject<'py>>::Output }, + quote! { <#ty as #pyo3_path::IntoPyObject<'py>>::Error }, + ) + } + } + IntoPyObjectTypes::Opaque { + target, + output, + error, + } => (target, output, error), + }; + let ident = &tokens.ident; + let ident = if REF { + quote! { &'_a #ident} + } else { + quote! { #ident } + }; Ok(quote!( #[automatically_derived] impl #impl_generics #pyo3_path::conversion::IntoPyObject<#lt_param> for #ident #ty_generics #where_clause { diff --git a/pyo3-macros/src/lib.rs b/pyo3-macros/src/lib.rs index 7c43c55dcd7..2621bea4c6e 100644 --- a/pyo3-macros/src/lib.rs +++ b/pyo3-macros/src/lib.rs @@ -156,7 +156,18 @@ pub fn pyfunction(attr: TokenStream, input: TokenStream) -> TokenStream { #[proc_macro_derive(IntoPyObject, attributes(pyo3))] pub fn derive_into_py_object(item: TokenStream) -> TokenStream { let ast = parse_macro_input!(item as syn::DeriveInput); - let expanded = build_derive_into_pyobject(&ast).unwrap_or_compile_error(); + let expanded = build_derive_into_pyobject::(&ast).unwrap_or_compile_error(); + quote!( + #expanded + ) + .into() +} + +#[proc_macro_derive(IntoPyObjectRef, attributes(pyo3))] +pub fn derive_into_py_object_ref(item: TokenStream) -> TokenStream { + let ast = parse_macro_input!(item as syn::DeriveInput); + let expanded = + pyo3_macros_backend::build_derive_into_pyobject::(&ast).unwrap_or_compile_error(); quote!( #expanded ) diff --git a/src/lib.rs b/src/lib.rs index 25c88143609..c71e12b8649 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -457,7 +457,9 @@ mod version; pub use crate::conversions::*; #[cfg(feature = "macros")] -pub use pyo3_macros::{pyfunction, pymethods, pymodule, FromPyObject, IntoPyObject}; +pub use pyo3_macros::{ + pyfunction, pymethods, pymodule, FromPyObject, IntoPyObject, IntoPyObjectRef, +}; /// A proc macro used to expose Rust structs and fieldless enums as Python objects. /// diff --git a/src/prelude.rs b/src/prelude.rs index 54f5a9f6beb..d4f649f552a 100644 --- a/src/prelude.rs +++ b/src/prelude.rs @@ -19,7 +19,9 @@ pub use crate::pyclass_init::PyClassInitializer; pub use crate::types::{PyAny, PyModule}; #[cfg(feature = "macros")] -pub use pyo3_macros::{pyclass, pyfunction, pymethods, pymodule, FromPyObject, IntoPyObject}; +pub use pyo3_macros::{ + pyclass, pyfunction, pymethods, pymodule, FromPyObject, IntoPyObject, IntoPyObjectRef, +}; #[cfg(feature = "macros")] pub use crate::wrap_pyfunction; diff --git a/src/tests/hygiene/misc.rs b/src/tests/hygiene/misc.rs index 1790c65961d..cecc8991f4a 100644 --- a/src/tests/hygiene/misc.rs +++ b/src/tests/hygiene/misc.rs @@ -57,17 +57,17 @@ macro_rules! macro_rules_hygiene { macro_rules_hygiene!(MyClass1, MyClass2); -#[derive(crate::IntoPyObject)] +#[derive(crate::IntoPyObject, crate::IntoPyObjectRef)] #[pyo3(crate = "crate")] struct IntoPyObject1(i32); // transparent newtype case -#[derive(crate::IntoPyObject)] +#[derive(crate::IntoPyObject, crate::IntoPyObjectRef)] #[pyo3(crate = "crate", transparent)] struct IntoPyObject2<'a> { inner: &'a str, // transparent newtype case } -#[derive(crate::IntoPyObject)] +#[derive(crate::IntoPyObject, crate::IntoPyObjectRef)] #[pyo3(crate = "crate")] struct IntoPyObject3<'py>(i32, crate::Bound<'py, crate::PyAny>); // tuple case @@ -78,7 +78,7 @@ struct IntoPyObject4<'a, 'py> { num: usize, } -#[derive(crate::IntoPyObject)] +#[derive(crate::IntoPyObject, crate::IntoPyObjectRef)] #[pyo3(crate = "crate")] enum IntoPyObject5<'a, 'py> { TransparentTuple(i32), diff --git a/tests/test_frompy_intopy_roundtrip.rs b/tests/test_frompy_intopy_roundtrip.rs index 6b3718693d7..b17320fa43b 100644 --- a/tests/test_frompy_intopy_roundtrip.rs +++ b/tests/test_frompy_intopy_roundtrip.rs @@ -1,7 +1,7 @@ #![cfg(feature = "macros")] use pyo3::types::{PyDict, PyString}; -use pyo3::{prelude::*, IntoPyObject}; +use pyo3::{prelude::*, IntoPyObject, IntoPyObjectRef}; use std::collections::HashMap; use std::hash::Hash; @@ -9,7 +9,7 @@ use std::hash::Hash; #[path = "../src/tests/common.rs"] mod common; -#[derive(Debug, Clone, IntoPyObject, FromPyObject)] +#[derive(Debug, Clone, IntoPyObject, IntoPyObjectRef, FromPyObject)] pub struct A<'py> { #[pyo3(item)] s: String, @@ -27,6 +27,16 @@ fn test_named_fields_struct() { t: PyString::new(py, "World"), p: 42i32.into_pyobject(py).unwrap().into_any(), }; + let pya = (&a).into_pyobject(py).unwrap(); + let new_a = pya.extract::>().unwrap(); + + assert_eq!(a.s, new_a.s); + assert_eq!(a.t.to_cow().unwrap(), new_a.t.to_cow().unwrap()); + assert_eq!( + a.p.extract::().unwrap(), + new_a.p.extract::().unwrap() + ); + let pya = a.clone().into_pyobject(py).unwrap(); let new_a = pya.extract::>().unwrap(); @@ -39,7 +49,7 @@ fn test_named_fields_struct() { }); } -#[derive(Debug, Clone, PartialEq, IntoPyObject, FromPyObject)] +#[derive(Debug, Clone, PartialEq, IntoPyObject, IntoPyObjectRef, FromPyObject)] #[pyo3(transparent)] pub struct B { test: String, @@ -51,13 +61,17 @@ fn test_transparent_named_field_struct() { let b = B { test: "test".into(), }; + let pyb = (&b).into_pyobject(py).unwrap(); + let new_b = pyb.extract::().unwrap(); + assert_eq!(b, new_b); + let pyb = b.clone().into_pyobject(py).unwrap(); let new_b = pyb.extract::().unwrap(); assert_eq!(b, new_b); }); } -#[derive(Debug, Clone, PartialEq, IntoPyObject, FromPyObject)] +#[derive(Debug, Clone, PartialEq, IntoPyObject, IntoPyObjectRef, FromPyObject)] #[pyo3(transparent)] pub struct D { test: T, @@ -66,6 +80,18 @@ pub struct D { #[test] fn test_generic_transparent_named_field_struct() { Python::with_gil(|py| { + let d = D { + test: String::from("test"), + }; + let pyd = (&d).into_pyobject(py).unwrap(); + let new_d = pyd.extract::>().unwrap(); + assert_eq!(d, new_d); + + let d = D { test: 1usize }; + let pyd = (&d).into_pyobject(py).unwrap(); + let new_d = pyd.extract::>().unwrap(); + assert_eq!(d, new_d); + let d = D { test: String::from("test"), }; @@ -80,7 +106,7 @@ fn test_generic_transparent_named_field_struct() { }); } -#[derive(Debug, IntoPyObject, FromPyObject)] +#[derive(Debug, IntoPyObject, IntoPyObjectRef, FromPyObject)] pub struct GenericWithBound(HashMap); #[test] @@ -89,10 +115,12 @@ fn test_generic_with_bound() { let mut hash_map = HashMap::::new(); hash_map.insert("1".into(), 1); hash_map.insert("2".into(), 2); - let map = GenericWithBound(hash_map).into_pyobject(py).unwrap(); - assert_eq!(map.len(), 2); + let map = GenericWithBound(hash_map); + let py_map = (&map).into_pyobject(py).unwrap(); + assert_eq!(py_map.len(), 2); assert_eq!( - map.get_item("1") + py_map + .get_item("1") .unwrap() .unwrap() .extract::() @@ -100,44 +128,75 @@ fn test_generic_with_bound() { 1 ); assert_eq!( - map.get_item("2") + py_map + .get_item("2") .unwrap() .unwrap() .extract::() .unwrap(), 2 ); - assert!(map.get_item("3").unwrap().is_none()); + assert!(py_map.get_item("3").unwrap().is_none()); + + let py_map = map.into_pyobject(py).unwrap(); + assert_eq!(py_map.len(), 2); + assert_eq!( + py_map + .get_item("1") + .unwrap() + .unwrap() + .extract::() + .unwrap(), + 1 + ); + assert_eq!( + py_map + .get_item("2") + .unwrap() + .unwrap() + .extract::() + .unwrap(), + 2 + ); + assert!(py_map.get_item("3").unwrap().is_none()); }); } -#[derive(Debug, Clone, PartialEq, IntoPyObject, FromPyObject)] +#[derive(Debug, Clone, PartialEq, IntoPyObject, IntoPyObjectRef, FromPyObject)] pub struct Tuple(String, usize); #[test] fn test_tuple_struct() { Python::with_gil(|py| { let tup = Tuple(String::from("test"), 1); + let tuple = (&tup).into_pyobject(py).unwrap(); + let new_tup = tuple.extract::().unwrap(); + assert_eq!(tup, new_tup); + let tuple = tup.clone().into_pyobject(py).unwrap(); let new_tup = tuple.extract::().unwrap(); assert_eq!(tup, new_tup); }); } -#[derive(Debug, Clone, PartialEq, IntoPyObject, FromPyObject)] +#[derive(Debug, Clone, PartialEq, IntoPyObject, IntoPyObjectRef, FromPyObject)] pub struct TransparentTuple(String); #[test] fn test_transparent_tuple_struct() { Python::with_gil(|py| { let tup = TransparentTuple(String::from("test")); + let tuple = (&tup).into_pyobject(py).unwrap(); + let new_tup = tuple.extract::().unwrap(); + assert_eq!(tup, new_tup); + let tuple = tup.clone().into_pyobject(py).unwrap(); let new_tup = tuple.extract::().unwrap(); assert_eq!(tup, new_tup); }); } -#[derive(Debug, Clone, PartialEq, IntoPyObject, FromPyObject)] +#[derive(Debug, Clone, PartialEq, IntoPyObject, IntoPyObjectRef, FromPyObject)] pub enum Foo { TupleVar(usize, String), StructVar { @@ -156,10 +215,20 @@ pub enum Foo { fn test_enum() { Python::with_gil(|py| { let tuple_var = Foo::TupleVar(1, "test".into()); + let foo = (&tuple_var).into_pyobject(py).unwrap(); + assert_eq!(tuple_var, foo.extract::().unwrap()); + let foo = tuple_var.clone().into_pyobject(py).unwrap(); assert_eq!(tuple_var, foo.extract::().unwrap()); let struct_var = Foo::StructVar { test: 'b' }; + let foo = (&struct_var) + .into_pyobject(py) + .unwrap() + .downcast_into::() + .unwrap(); + assert_eq!(struct_var, foo.extract::().unwrap()); + let foo = struct_var .clone() .into_pyobject(py) @@ -170,10 +239,16 @@ fn test_enum() { assert_eq!(struct_var, foo.extract::().unwrap()); let transparent_tuple = Foo::TransparentTuple(1); + let foo = (&transparent_tuple).into_pyobject(py).unwrap(); + assert_eq!(transparent_tuple, foo.extract::().unwrap()); + let foo = transparent_tuple.clone().into_pyobject(py).unwrap(); assert_eq!(transparent_tuple, foo.extract::().unwrap()); let transparent_struct_var = Foo::TransparentStructVar { a: None }; + let foo = (&transparent_struct_var).into_pyobject(py).unwrap(); + assert_eq!(transparent_struct_var, foo.extract::().unwrap()); + let foo = transparent_struct_var.clone().into_pyobject(py).unwrap(); assert_eq!(transparent_struct_var, foo.extract::().unwrap()); }); diff --git a/tests/ui/invalid_intopy_derive.rs b/tests/ui/invalid_intopy_derive.rs index 310309992d4..c65d44ff1bc 100644 --- a/tests/ui/invalid_intopy_derive.rs +++ b/tests/ui/invalid_intopy_derive.rs @@ -1,67 +1,67 @@ -use pyo3::IntoPyObject; +use pyo3::{IntoPyObject, IntoPyObjectRef}; -#[derive(IntoPyObject)] +#[derive(IntoPyObject, IntoPyObjectRef)] struct Foo(); -#[derive(IntoPyObject)] +#[derive(IntoPyObject, IntoPyObjectRef)] struct Foo2 {} -#[derive(IntoPyObject)] +#[derive(IntoPyObject, IntoPyObjectRef)] enum EmptyEnum {} -#[derive(IntoPyObject)] +#[derive(IntoPyObject, IntoPyObjectRef)] enum EnumWithEmptyTupleVar { EmptyTuple(), Valid(String), } -#[derive(IntoPyObject)] +#[derive(IntoPyObject, IntoPyObjectRef)] enum EnumWithEmptyStructVar { EmptyStruct {}, Valid(String), } -#[derive(IntoPyObject)] +#[derive(IntoPyObject, IntoPyObjectRef)] #[pyo3(transparent)] struct EmptyTransparentTup(); -#[derive(IntoPyObject)] +#[derive(IntoPyObject, IntoPyObjectRef)] #[pyo3(transparent)] struct EmptyTransparentStruct {} -#[derive(IntoPyObject)] +#[derive(IntoPyObject, IntoPyObjectRef)] enum EnumWithTransparentEmptyTupleVar { #[pyo3(transparent)] EmptyTuple(), Valid(String), } -#[derive(IntoPyObject)] +#[derive(IntoPyObject, IntoPyObjectRef)] enum EnumWithTransparentEmptyStructVar { #[pyo3(transparent)] EmptyStruct {}, Valid(String), } -#[derive(IntoPyObject)] +#[derive(IntoPyObject, IntoPyObjectRef)] #[pyo3(transparent)] struct TransparentTupTooManyFields(String, String); -#[derive(IntoPyObject)] +#[derive(IntoPyObject, IntoPyObjectRef)] #[pyo3(transparent)] struct TransparentStructTooManyFields { foo: String, bar: String, } -#[derive(IntoPyObject)] +#[derive(IntoPyObject, IntoPyObjectRef)] enum EnumWithTransparentTupleTooMany { #[pyo3(transparent)] EmptyTuple(String, String), Valid(String), } -#[derive(IntoPyObject)] +#[derive(IntoPyObject, IntoPyObjectRef)] enum EnumWithTransparentStructTooMany { #[pyo3(transparent)] EmptyStruct { @@ -71,35 +71,35 @@ enum EnumWithTransparentStructTooMany { Valid(String), } -#[derive(IntoPyObject)] +#[derive(IntoPyObject, IntoPyObjectRef)] #[pyo3(unknown = "should not work")] struct UnknownContainerAttr { a: String, } -#[derive(IntoPyObject)] +#[derive(IntoPyObject, IntoPyObjectRef)] union Union { a: usize, } -#[derive(IntoPyObject)] +#[derive(IntoPyObject, IntoPyObjectRef)] enum UnitEnum { Unit, } -#[derive(IntoPyObject)] +#[derive(IntoPyObject, IntoPyObjectRef)] struct TupleAttribute(#[pyo3(attribute)] String, usize); -#[derive(IntoPyObject)] +#[derive(IntoPyObject, IntoPyObjectRef)] struct TupleItem(#[pyo3(item)] String, usize); -#[derive(IntoPyObject)] +#[derive(IntoPyObject, IntoPyObjectRef)] struct StructAttribute { #[pyo3(attribute)] foo: String, } -#[derive(IntoPyObject)] +#[derive(IntoPyObject, IntoPyObjectRef)] #[pyo3(transparent)] struct StructTransparentItem { #[pyo3(item)]