Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Custom ctx override for derive macro #106

Merged
merged 10 commits into from
Oct 22, 2024
72 changes: 72 additions & 0 deletions scroll_derive/examples/derive_custom_ctx.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,72 @@
use scroll_derive::{Pread, Pwrite, SizeWith};

#[derive(Debug, PartialEq)]
struct CustomCtx {
buf: Vec<u8>,
}
impl CustomCtx {
fn len() -> usize {
3 + 2
}
}
impl<'a> TryFromCtx<'a, usize> for CustomCtx {
type Error = scroll::Error;

fn try_from_ctx(from: &'a [u8], ctx: usize) -> Result<(Self, usize), Self::Error> {
let offset = &mut 0;
let buf = from.gread_with::<&[u8]>(offset, ctx)?.to_owned();
Ok((Self { buf }, *offset))
}
}
impl<'a> TryIntoCtx<usize> for &'a CustomCtx {
type Error = scroll::Error;
fn try_into_ctx(self, dst: &mut [u8], ctx: usize) -> Result<usize, Self::Error> {
let offset = &mut 0;
for i in 0..(ctx.min(self.buf.len())) {
dst.gwrite(self.buf[i], offset)?;
}
Ok(*offset)
}
}
impl SizeWith<usize> for CustomCtx {
fn size_with(ctx: &usize) -> usize {
*ctx
}
}

#[derive(Debug, PartialEq, Pread, Pwrite, SizeWith)]
#[repr(C)]
struct Data {
id: u32,
timestamp: f64,
#[scroll(ctx = BE)]
arr: [u16; 2],
#[scroll(ctx = CustomCtx::len())]
Easyoakland marked this conversation as resolved.
Show resolved Hide resolved
custom_ctx: CustomCtx,
}

use scroll::{
ctx::{SizeWith, TryFromCtx, TryIntoCtx},
Pread, Pwrite, BE, LE,
};

fn main() {
let bytes = [
0xefu8, 0xbe, 0xad, 0xde, 0, 0, 0, 0, 0, 0, 224, 63, 0xad, 0xde, 0xef, 0xbe, 0xaa, 0xbb,
0xcc, 0xdd, 0xee,
];
let data: Data = bytes.pread_with(0, LE).unwrap();
println!("data: {data:?}");
assert_eq!(data.id, 0xdeadbeefu32);
assert_eq!(data.arr, [0xadde, 0xefbe]);
let mut bytes2 = vec![0; ::std::mem::size_of::<Data>()];
bytes2.pwrite_with(data, 0, LE).unwrap();
let data: Data = bytes.pread_with(0, LE).unwrap();
let data2: Data = bytes2.pread_with(0, LE).unwrap();
assert_eq!(data, data2);

/*
let data: Data = bytes.cread_with(0, LE);
assert_eq!(data, data2);
*/
}
116 changes: 93 additions & 23 deletions scroll_derive/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,11 +2,17 @@

extern crate proc_macro;
use proc_macro2;
use quote::quote;
use quote::{quote, ToTokens};

use proc_macro::TokenStream;

fn impl_field(ident: &proc_macro2::TokenStream, ty: &syn::Type) -> proc_macro2::TokenStream {
fn impl_field(
ident: &proc_macro2::TokenStream,
ty: &syn::Type,
custom_ctx: Option<&proc_macro2::TokenStream>,
) -> proc_macro2::TokenStream {
let default_ctx = syn::Ident::new("ctx", proc_macro2::Span::call_site()).into_token_stream();
let ctx = custom_ctx.unwrap_or(&default_ctx);
match *ty {
syn::Type::Array(ref array) => match array.len {
syn::Expr::Lit(syn::ExprLit {
Expand All @@ -15,20 +21,63 @@ fn impl_field(ident: &proc_macro2::TokenStream, ty: &syn::Type) -> proc_macro2::
}) => {
let size = int.base10_parse::<usize>().unwrap();
quote! {
#ident: { let mut __tmp: #ty = [0u8.into(); #size]; src.gread_inout_with(offset, &mut __tmp, ctx)?; __tmp }
#ident: { let mut __tmp: #ty = [0u8.into(); #size]; src.gread_inout_with(offset, &mut __tmp, #ctx)?; __tmp }
}
}
_ => panic!("Pread derive with bad array constexpr"),
},
syn::Type::Group(ref group) => impl_field(ident, &group.elem),
syn::Type::Group(ref group) => impl_field(ident, &group.elem, custom_ctx),
_ => {
quote! {
#ident: src.gread_with::<#ty>(offset, ctx)?
#ident: src.gread_with::<#ty>(offset, #ctx)?
}
}
}
}

/// Retrieve the field attribute with given ident e.g:
/// ```ignore
/// #[attr_ident(..)]
/// field: T,
/// ```
fn get_attr<'a>(attr_ident: &str, field: &'a syn::Field) -> Option<&'a syn::Attribute> {
field
.attrs
.iter()
.filter(|attr| attr.path().is_ident(attr_ident))
.next()
}

/// Gets the `TokenStream` for the custom ctx set in the `ctx` attribute. e.g. `expr` in the following
/// ```ignore
/// #[scroll(ctx = expr)]
/// field: T,
/// ```
fn custom_ctx(field: &syn::Field) -> Option<proc_macro2::TokenStream> {
get_attr("scroll", field).and_then(|x| {
// parsed #[scroll..]
// `expr` is `None` if the `ctx` key is not used.
let mut expr = None;
let res = x.parse_nested_meta(|meta| {
// parsed #[scroll(..)]
if meta.path.is_ident("ctx") {
// parsed #[scroll(ctx..)]
let value = meta.value()?; // parsed #[scroll(ctx = ..)]
expr = Some(value.parse::<syn::Expr>()?); // parsed #[scroll(ctx = expr)]
Easyoakland marked this conversation as resolved.
Show resolved Hide resolved
return Ok(());
}
Err(meta.error(match meta.path.get_ident() {
Some(ident) => format!("unrecognized attribute: {ident}"),
None => "unrecognized and invalid attribute".to_owned(),
}))
});
match res {
Ok(_) => expr.map(|x| x.into_token_stream()),
Err(e) => Some(e.into_compile_error()),
}
})
}

fn impl_struct(
name: &syn::Ident,
fields: &syn::punctuated::Punctuated<syn::Field, syn::Token![,]>,
Expand All @@ -43,7 +92,9 @@ fn impl_struct(
quote! {#t}
});
let ty = &f.ty;
impl_field(ident, ty)
// parse the `expr` out of #[scroll(ctx = expr)]
let custom_ctx = custom_ctx(f);
impl_field(ident, ty, custom_ctx.as_ref())
})
.collect();

Expand Down Expand Up @@ -104,14 +155,20 @@ fn impl_try_from_ctx(ast: &syn::DeriveInput) -> proc_macro2::TokenStream {
}
}

#[proc_macro_derive(Pread)]
#[proc_macro_derive(Pread, attributes(scroll))]
pub fn derive_pread(input: TokenStream) -> TokenStream {
let ast: syn::DeriveInput = syn::parse(input).unwrap();
let gen = impl_try_from_ctx(&ast);
gen.into()
}

fn impl_pwrite_field(ident: &proc_macro2::TokenStream, ty: &syn::Type) -> proc_macro2::TokenStream {
fn impl_pwrite_field(
ident: &proc_macro2::TokenStream,
ty: &syn::Type,
custom_ctx: Option<&proc_macro2::TokenStream>,
) -> proc_macro2::TokenStream {
let default_ctx = syn::Ident::new("ctx", proc_macro2::Span::call_site()).into_token_stream();
let ctx = custom_ctx.unwrap_or(&default_ctx);
match ty {
syn::Type::Array(ref array) => match array.len {
syn::Expr::Lit(syn::ExprLit {
Expand All @@ -121,24 +178,24 @@ fn impl_pwrite_field(ident: &proc_macro2::TokenStream, ty: &syn::Type) -> proc_m
let size = int.base10_parse::<usize>().unwrap();
quote! {
for i in 0..#size {
dst.gwrite_with(&self.#ident[i], offset, ctx)?;
dst.gwrite_with(&self.#ident[i], offset, #ctx)?;
}
}
}
_ => panic!("Pwrite derive with bad array constexpr"),
},
syn::Type::Group(group) => impl_pwrite_field(ident, &group.elem),
syn::Type::Group(group) => impl_pwrite_field(ident, &group.elem, custom_ctx),
syn::Type::Reference(reference) => match *reference.elem {
syn::Type::Slice(_) => quote! {
dst.gwrite_with(self.#ident, offset, ())?
},
_ => quote! {
dst.gwrite_with(self.#ident, offset, ctx)?
dst.gwrite_with(self.#ident, offset, #ctx)?
},
},
_ => {
quote! {
dst.gwrite_with(&self.#ident, offset, ctx)?
dst.gwrite_with(&self.#ident, offset, #ctx)?
}
}
}
Expand All @@ -158,7 +215,8 @@ fn impl_try_into_ctx(
quote! {#t}
});
let ty = &f.ty;
impl_pwrite_field(ident, ty)
let custom_ctx = custom_ctx(f);
impl_pwrite_field(ident, ty, custom_ctx.as_ref())
})
.collect();

Expand Down Expand Up @@ -249,7 +307,7 @@ fn impl_pwrite(ast: &syn::DeriveInput) -> proc_macro2::TokenStream {
}
}

#[proc_macro_derive(Pwrite)]
#[proc_macro_derive(Pwrite, attributes(scroll))]
pub fn derive_pwrite(input: TokenStream) -> TokenStream {
let ast: syn::DeriveInput = syn::parse(input).unwrap();
let gen = impl_pwrite(&ast);
Expand All @@ -265,6 +323,10 @@ fn size_with(
.iter()
.map(|f| {
let ty = &f.ty;
let custom_ctx = custom_ctx(f).map(|x| quote! {&#x});
let default_ctx =
syn::Ident::new("ctx", proc_macro2::Span::call_site()).into_token_stream();
let ctx = custom_ctx.unwrap_or(default_ctx);
match *ty {
syn::Type::Array(ref array) => {
let elem = &array.elem;
Expand All @@ -275,15 +337,15 @@ fn size_with(
}) => {
let size = int.base10_parse::<usize>().unwrap();
quote! {
(#size * <#elem>::size_with(ctx))
(#size * <#elem>::size_with(#ctx))
}
}
_ => panic!("Pread derive with bad array constexpr"),
}
}
_ => {
quote! {
<#ty>::size_with(ctx)
<#ty>::size_with(#ctx)
}
}
}
Expand Down Expand Up @@ -341,7 +403,7 @@ fn impl_size_with(ast: &syn::DeriveInput) -> proc_macro2::TokenStream {
}
}

#[proc_macro_derive(SizeWith)]
#[proc_macro_derive(SizeWith, attributes(scroll))]
pub fn derive_sizewith(input: TokenStream) -> TokenStream {
let ast: syn::DeriveInput = syn::parse(input).unwrap();
let gen = impl_size_with(&ast);
Expand All @@ -356,6 +418,10 @@ fn impl_cread_struct(
let items: Vec<_> = fields.iter().enumerate().map(|(i, f)| {
let ident = &f.ident.as_ref().map(|i|quote!{#i}).unwrap_or({let t = proc_macro2::Literal::usize_unsuffixed(i); quote!{#t}});
let ty = &f.ty;
let custom_ctx = custom_ctx(f);
let default_ctx =
syn::Ident::new("ctx", proc_macro2::Span::call_site()).into_token_stream();
let ctx = custom_ctx.unwrap_or(default_ctx);
match *ty {
syn::Type::Array(ref array) => {
let arrty = &array.elem;
Expand All @@ -367,7 +433,7 @@ fn impl_cread_struct(
#ident: {
let mut __tmp: #ty = [0u8.into(); #size];
for i in 0..__tmp.len() {
__tmp[i] = src.cread_with(*offset, ctx);
__tmp[i] = src.cread_with(*offset, #ctx);
*offset += #incr;
}
__tmp
Expand All @@ -380,7 +446,7 @@ fn impl_cread_struct(
_ => {
let size = quote! { ::scroll::export::mem::size_of::<#ty>() };
quote! {
#ident: { let res = src.cread_with::<#ty>(*offset, ctx); *offset += #size; res }
#ident: { let res = src.cread_with::<#ty>(*offset, #ctx); *offset += #size; res }
}
}
}
Expand Down Expand Up @@ -440,7 +506,7 @@ fn impl_from_ctx(ast: &syn::DeriveInput) -> proc_macro2::TokenStream {
}
}

#[proc_macro_derive(IOread)]
#[proc_macro_derive(IOread, attributes(scroll))]
pub fn derive_ioread(input: TokenStream) -> TokenStream {
let ast: syn::DeriveInput = syn::parse(input).unwrap();
let gen = impl_from_ctx(&ast);
Expand All @@ -462,20 +528,24 @@ fn impl_into_ctx(
});
let ty = &f.ty;
let size = quote! { ::scroll::export::mem::size_of::<#ty>() };
let custom_ctx = custom_ctx(f);
let default_ctx =
syn::Ident::new("ctx", proc_macro2::Span::call_site()).into_token_stream();
let ctx = custom_ctx.unwrap_or(default_ctx);
match *ty {
syn::Type::Array(ref array) => {
let arrty = &array.elem;
quote! {
let size = ::scroll::export::mem::size_of::<#arrty>();
for i in 0..self.#ident.len() {
dst.cwrite_with(self.#ident[i], *offset, ctx);
dst.cwrite_with(self.#ident[i], *offset, #ctx);
*offset += size;
}
}
}
_ => {
quote! {
dst.cwrite_with(self.#ident, *offset, ctx);
dst.cwrite_with(self.#ident, *offset, #ctx);
*offset += #size;
}
}
Expand Down Expand Up @@ -544,7 +614,7 @@ fn impl_iowrite(ast: &syn::DeriveInput) -> proc_macro2::TokenStream {
}
}

#[proc_macro_derive(IOwrite)]
#[proc_macro_derive(IOwrite, attributes(scroll))]
pub fn derive_iowrite(input: TokenStream) -> TokenStream {
let ast: syn::DeriveInput = syn::parse(input).unwrap();
let gen = impl_iowrite(&ast);
Expand Down
Loading
Loading