Skip to content

Commit

Permalink
Merge pull request #800 from Sympatron/main
Browse files Browse the repository at this point in the history
Fix trait bounds issue #799
  • Loading branch information
Urhengulas authored Mar 6, 2024
2 parents 0593ac7 + 63e275b commit 8d488c4
Show file tree
Hide file tree
Showing 8 changed files with 90 additions and 31 deletions.
4 changes: 4 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,10 @@ and this project adheres to [Semantic Versioning](http://semver.org/).

## [Unreleased]

- [#800]: `defmt-macros`: Fix generic trait bounds in Format derive macro

[#800]: https://github.com/knurling-rs/defmt/pull/800

## [v0.3.6] - 2024-02-05

- [#804]: `CI`: Remove mdbook strategy
Expand Down
29 changes: 29 additions & 0 deletions defmt/tests/derive-bounds.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,29 @@
fn main() {
let baz: Baz<Qux> = Default::default();
defmt::info!("{}", baz);
}

trait Foo {
type Bar;
}
#[derive(defmt::Format, Default)]
struct Baz<T: Foo> {
field: T::Bar,
field2: Quux<T>,
}
#[derive(defmt::Format, Default)]
struct Qux;
impl Foo for Qux {
type Bar = Qux;
}
#[allow(dead_code)]
#[derive(defmt::Format, Default)]
enum Quux<T: Foo> {
#[default]
None,
Variant1(T),
Variant2 {
f: T::Bar,
},
Variant3(T::Bar),
}
1 change: 1 addition & 0 deletions defmt/tests/ui.rs
Original file line number Diff line number Diff line change
Expand Up @@ -9,5 +9,6 @@ fn ui() {
t.compile_fail("tests/ui/*.rs");

t.pass("tests/basic_usage.rs");
t.pass("tests/derive-bounds.rs");
}
}
2 changes: 1 addition & 1 deletion macros/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ defmt-parser = { version = "=0.3.4", path = "../parser", features = ["unstable"]
proc-macro-error = "1"
proc-macro2 = "1"
quote = "1"
syn = { version = "2", features = ["full"] }
syn = { version = "2", features = ["full", "extra-traits"] }

[dev-dependencies]
maplit = "1"
Expand Down
8 changes: 6 additions & 2 deletions macros/src/derives/format.rs
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,11 @@ pub(crate) fn expand(input: TokenStream) -> TokenStream {
Data::Union(_) => abort_call_site!("`#[derive(Format)]` does not support unions"),
};

let codegen::EncodeData { format_tag, stmts } = match encode_data {
let codegen::EncodeData {
format_tag,
stmts,
where_predicates,
} = match encode_data {
Ok(data) => data,
Err(e) => return e.into_compile_error().into(),
};
Expand All @@ -24,7 +28,7 @@ pub(crate) fn expand(input: TokenStream) -> TokenStream {
impl_generics,
type_generics,
where_clause,
} = codegen::Generics::codegen(&mut input.generics);
} = codegen::Generics::codegen(&mut input.generics, where_predicates);

quote!(
impl #impl_generics defmt::Format for #ident #type_generics #where_clause {
Expand Down
28 changes: 14 additions & 14 deletions macros/src/derives/format/codegen.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
use proc_macro2::TokenStream as TokenStream2;
use quote::quote;
use syn::{parse_quote, DataStruct, GenericParam, Ident, ImplGenerics, TypeGenerics, WhereClause};
use syn::{DataStruct, Ident, ImplGenerics, TypeGenerics, WhereClause, WherePredicate};

pub(crate) use enum_data::encode as encode_enum_data;

Expand All @@ -12,14 +12,15 @@ mod fields;
pub(crate) struct EncodeData {
pub(crate) format_tag: TokenStream2,
pub(crate) stmts: Vec<TokenStream2>,
pub(crate) where_predicates: Vec<WherePredicate>,
}

pub(crate) fn encode_struct_data(ident: &Ident, data: &DataStruct) -> syn::Result<EncodeData> {
let mut format_string = ident.to_string();
let mut stmts = vec![];
let mut field_patterns = vec![];

let encode_fields_stmts =
let (encode_fields_stmts, where_predicates) =
fields::codegen(&data.fields, &mut format_string, &mut field_patterns)?;

stmts.push(quote!(match self {
Expand All @@ -29,7 +30,11 @@ pub(crate) fn encode_struct_data(ident: &Ident, data: &DataStruct) -> syn::Resul
}));

let format_tag = construct::interned_string(&format_string, "derived", false);
Ok(EncodeData { format_tag, stmts })
Ok(EncodeData {
format_tag,
stmts,
where_predicates,
})
}

pub(crate) struct Generics<'a> {
Expand All @@ -39,20 +44,15 @@ pub(crate) struct Generics<'a> {
}

impl<'a> Generics<'a> {
pub(crate) fn codegen(generics: &'a mut syn::Generics) -> Self {
pub(crate) fn codegen(
generics: &'a mut syn::Generics,
where_predicates: Vec<WherePredicate>,
) -> Self {
let mut where_clause = generics.make_where_clause().clone();
let (impl_generics, type_generics, _) = generics.split_for_impl();

// Extend where-clause with `Format` bounds for type parameters.
for param in &generics.params {
if let GenericParam::Type(ty) = param {
let ident = &ty.ident;

where_clause
.predicates
.push(parse_quote!(#ident: defmt::Format));
}
}
// Extend where-clause with `Format` bounds for all field types.
where_clause.predicates.extend(where_predicates);

Self {
impl_generics,
Expand Down
12 changes: 10 additions & 2 deletions macros/src/derives/format/codegen/enum_data.rs
Original file line number Diff line number Diff line change
Expand Up @@ -12,10 +12,12 @@ pub(crate) fn encode(ident: &Ident, data: &DataEnum) -> syn::Result<EncodeData>
return Ok(EncodeData {
stmts: vec![quote!(match *self {})],
format_tag: construct::interned_string("!", "derived", false),
where_predicates: vec![],
});
}

let mut format_string = String::new();
let mut where_predicates = vec![];

let mut match_arms = vec![];
let mut is_first_variant = true;
Expand All @@ -32,8 +34,9 @@ pub(crate) fn encode(ident: &Ident, data: &DataEnum) -> syn::Result<EncodeData>
format_string.push_str(&variant_ident.to_string());

let mut field_patterns = vec![];
let encode_fields_stmts =
let (encode_fields_stmts, encode_field_where_predicates) =
super::fields::codegen(&variant.fields, &mut format_string, &mut field_patterns)?;
where_predicates.extend(encode_field_where_predicates.into_iter());
let pattern = quote!( { #(#field_patterns),* } );

let encode_discriminant_stmt = discriminant_encoder.encode(index);
Expand All @@ -50,8 +53,13 @@ pub(crate) fn encode(ident: &Ident, data: &DataEnum) -> syn::Result<EncodeData>
let stmts = vec![quote!(match self {
#(#match_arms)*
})];
where_predicates.dedup_by(|a, b| a == b);

Ok(EncodeData { format_tag, stmts })
Ok(EncodeData {
format_tag,
stmts,
where_predicates,
})
}

enum DiscriminantEncoder {
Expand Down
37 changes: 25 additions & 12 deletions macros/src/derives/format/codegen/fields.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,23 +2,23 @@ use std::fmt::Write as _;

use proc_macro2::TokenStream as TokenStream2;
use quote::{format_ident, quote};
use syn::{Field, Fields, Index, Type};
use syn::{parse_quote, Field, Fields, Index, Type, WherePredicate};

use crate::consts;

pub(crate) fn codegen(
fields: &Fields,
format_string: &mut String,
patterns: &mut Vec<TokenStream2>,
) -> syn::Result<Vec<TokenStream2>> {
) -> syn::Result<(Vec<TokenStream2>, Vec<WherePredicate>)> {
let (fields, fields_are_named) = match fields {
Fields::Named(named) => (&named.named, true),
Fields::Unit => return Ok(vec![]),
Fields::Unit => return Ok((vec![], vec![])),
Fields::Unnamed(unnamed) => (&unnamed.unnamed, false),
};

if fields.is_empty() {
return Ok(vec![]);
return Ok((vec![], vec![]));
}

if fields_are_named {
Expand All @@ -28,6 +28,7 @@ pub(crate) fn codegen(
}

let mut stmts = vec![];
let mut where_predicates = vec![];
let mut is_first = true;
for (index, field) in fields.iter().enumerate() {
if is_first {
Expand All @@ -37,21 +38,33 @@ pub(crate) fn codegen(
}

let format_opt = get_defmt_format_option(field)?;
let ty = as_native_type(&field.ty).unwrap_or_else(|| consts::TYPE_FORMAT.to_string());
// Find out if the field type is natively supported by defmt. `ty` will be None if not.
let ty = as_native_type(&field.ty);
// `field_ty` will be the field's type if it is not natively supported by defmt
let field_ty = if ty.is_none() { Some(&field.ty) } else { None };
// Get the field format specifier. Either the native specifier or '?'.
let ty = ty.unwrap_or_else(|| consts::TYPE_FORMAT.to_string());
let ident = field
.ident
.clone()
.unwrap_or_else(|| format_ident!("arg{}", index));

if let Some(FormatOption::Debug2Format) = format_opt {
stmts.push(quote!(defmt::export::fmt(&defmt::Debug2Format(&#ident))));
// Find the required trait bounds for the field and add the formatting statement depending on the field type and the formatting options
let bound: Option<syn::Path> = if let Some(FormatOption::Debug2Format) = format_opt {
stmts.push(quote!(::defmt::export::fmt(&defmt::Debug2Format(&#ident))));
field_ty.map(|_| parse_quote!(::core::fmt::Debug))
} else if let Some(FormatOption::Display2Format) = format_opt {
stmts.push(quote!(defmt::export::fmt(&defmt::Display2Format(&#ident))));
stmts.push(quote!(::defmt::export::fmt(&defmt::Display2Format(&#ident))));
field_ty.map(|_| parse_quote!(::core::fmt::Display))
} else if ty == consts::TYPE_FORMAT {
stmts.push(quote!(defmt::export::fmt(#ident)));
stmts.push(quote!(::defmt::export::fmt(#ident)));
field_ty.map(|_| parse_quote!(::defmt::Format))
} else {
let method = format_ident!("{}", ty);
stmts.push(quote!(defmt::export::#method(#ident)));
stmts.push(quote!(::defmt::export::#method(#ident)));
field_ty.map(|_| parse_quote!(::defmt::Format))
};
if let Some(bound) = bound {
where_predicates.push(parse_quote!(#field_ty: #bound));
}

if field.ident.is_some() {
Expand All @@ -74,7 +87,7 @@ pub(crate) fn codegen(
format_string.push(')');
}

Ok(stmts)
Ok((stmts, where_predicates))
}

#[derive(Copy, Clone, Eq, PartialEq, Debug)]
Expand Down

0 comments on commit 8d488c4

Please sign in to comment.