From fb4e3d17f4bee06610ceb60d88e94ff590873dbe Mon Sep 17 00:00:00 2001 From: Jiyuan Zheng Date: Wed, 25 Sep 2024 11:57:33 +0800 Subject: [PATCH] Feat: add return type assertion (#43) --- Cargo.lock | 4 +- Cargo.toml | 1 - poc/guests/Cargo.lock | 103 +++++++++++- poc/guests/sum-balance-percent/src/main.rs | 4 +- poc/guests/sum-balance/src/main.rs | 4 +- poc/guests/total-supply/src/main.rs | 2 +- xcq-api/procedural/Cargo.toml | 2 + xcq-api/procedural/src/lib.rs | 7 +- xcq-api/procedural/src/program/expand/mod.rs | 156 +++++++++++++++++- xcq-api/procedural/src/program/mod.rs | 2 +- xcq-api/procedural/src/program/parse/call.rs | 22 ++- .../procedural/src/program/parse/helper.rs | 10 +- xcq-api/procedural/src/program/parse/mod.rs | 117 ++++++------- xcq-extension/Cargo.toml | 4 +- .../procedural/src/decl_extensions.rs | 55 ++++-- .../procedural/src/runtime_metadata.rs | 46 +++--- xcq-extension/src/lib.rs | 53 ++++-- xcq-extension/src/macros.rs | 33 ++-- xcq-extension/src/metadata.rs | 12 ++ xcq-primitives/src/metadata_ir.rs | 4 + 20 files changed, 483 insertions(+), 158 deletions(-) create mode 100644 xcq-extension/src/metadata.rs diff --git a/Cargo.lock b/Cargo.lock index 137a599..4d71f07 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -4458,10 +4458,12 @@ name = "xcq-api-procedural" version = "0.1.0" dependencies = [ "Inflector", + "parity-scale-codec", "proc-macro-crate", "proc-macro2", "quote", "syn 2.0.66", + "xcq-types", ] [[package]] @@ -4478,7 +4480,7 @@ dependencies = [ name = "xcq-extension" version = "0.1.0" dependencies = [ - "impl-trait-for-tuples", + "fortuples", "parity-scale-codec", "scale-info", "tracing", diff --git a/Cargo.toml b/Cargo.toml index 89552c6..95e1f4f 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -63,7 +63,6 @@ clap = { version = "4.5.4", features = ["derive"] } env_logger = { version = "0.11.3" } tracing-subscriber = { version = "0.3.18", features = ["env-filter"] } -impl-trait-for-tuples = "0.2" fortuples = "0.9" # proc macros diff --git a/poc/guests/Cargo.lock b/poc/guests/Cargo.lock index 3479655..50daaa8 100644 --- a/poc/guests/Cargo.lock +++ b/poc/guests/Cargo.lock @@ -21,18 +21,58 @@ dependencies = [ "memchr", ] +[[package]] +name = "arrayvec" +version = "0.7.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7c02d123df017efcdfbd739ef81735b36c5ba83ec3c59c80a9d7ecc718f92e50" + +[[package]] +name = "byte-slice-cast" +version = "1.2.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "c3ac9f8b63eca6fd385229b3675f6cc0dc5c8a5c8a54a59d4f52ffd670d87b0c" + +[[package]] +name = "cfg-if" +version = "1.0.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "baf1de4339761588bc0619e3cbc0120ee582ebb74b53b4efbf79117bd2da40fd" + [[package]] name = "equivalent" version = "1.0.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "5443807d6dff69373d433ab9ef5378ad8df50ca6298caf15de6e52e24aaf54d5" +[[package]] +name = "fortuples" +version = "0.9.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "87630a8087e9cac4b7edfb6ee5e250ddca9112b57b6b17d8f5107375a3a8eace" +dependencies = [ + "proc-macro2", + "quote", + "syn 1.0.109", +] + [[package]] name = "hashbrown" version = "0.14.5" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "e5274423e17b7c9fc20b6e7e208532f9b19825d82dfd615708b70edd83df41f1" +[[package]] +name = "impl-trait-for-tuples" +version = "0.2.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "11d7a9f6330b71fea57921c9b61c47ee6e84f72d394754eff6163ae67e7395eb" +dependencies = [ + "proc-macro2", + "quote", + "syn 1.0.109", +] + [[package]] name = "indexmap" version = "2.3.0" @@ -55,6 +95,30 @@ version = "2.7.4" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "78ca9ab1a0babb1e7d5695e3530886289c18cf2f87ec19a575a0abdce112e3a3" +[[package]] +name = "parity-scale-codec" +version = "3.6.12" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "306800abfa29c7f16596b5970a588435e3d5b3149683d00c12b699cc19f895ee" +dependencies = [ + "arrayvec", + "byte-slice-cast", + "impl-trait-for-tuples", + "parity-scale-codec-derive", +] + +[[package]] +name = "parity-scale-codec-derive" +version = "3.6.12" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d830939c76d294956402033aee57a6da7b438f2294eb94864c37b0569053a42c" +dependencies = [ + "proc-macro-crate", + "proc-macro2", + "quote", + "syn 1.0.109", +] + [[package]] name = "poc-guest-sum-balance" version = "0.1.0" @@ -105,7 +169,7 @@ dependencies = [ "polkavm-common", "proc-macro2", "quote", - "syn", + "syn 2.0.63", ] [[package]] @@ -113,7 +177,7 @@ name = "polkavm-derive-impl-macro" version = "0.10.0" dependencies = [ "polkavm-derive-impl", - "syn", + "syn 2.0.63", ] [[package]] @@ -172,6 +236,17 @@ version = "0.8.4" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "7a66a03ae7c801facd77a29370b4faec201768915ac14a721ba36f20bc9c209b" +[[package]] +name = "syn" +version = "1.0.109" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "72b64191b275b66ffe2469e8af2c1cfe3bafa67b529ead792a6d0160888b4237" +dependencies = [ + "proc-macro2", + "quote", + "unicode-ident", +] + [[package]] name = "syn" version = "2.0.63" @@ -227,8 +302,30 @@ name = "xcq-api-procedural" version = "0.1.0" dependencies = [ "Inflector", + "parity-scale-codec", + "proc-macro-crate", + "proc-macro2", + "quote", + "syn 2.0.63", + "xcq-types", +] + +[[package]] +name = "xcq-types" +version = "0.1.0" +dependencies = [ + "cfg-if", + "fortuples", + "parity-scale-codec", + "xcq-types-derive", +] + +[[package]] +name = "xcq-types-derive" +version = "0.1.0" +dependencies = [ "proc-macro-crate", "proc-macro2", "quote", - "syn", + "syn 2.0.63", ] diff --git a/poc/guests/sum-balance-percent/src/main.rs b/poc/guests/sum-balance-percent/src/main.rs index 9037ab7..2bcf521 100644 --- a/poc/guests/sum-balance-percent/src/main.rs +++ b/poc/guests/sum-balance-percent/src/main.rs @@ -5,9 +5,9 @@ static GLOBAL: polkavm_derive::LeakingAllocator = polkavm_derive::LeakingAllocat use alloc::vec::Vec; #[xcq_api::program] mod sum_balance { - #[xcq::call_def] + #[xcq::call_def(extension_id = 0x92F353DB95824F9Du64, call_index = 1)] fn balance(asset: u32, who: [u8; 32]) -> u64 {} - #[xcq::call_def] + #[xcq::call_def(extension_id = 0x92F353DB95824F9Du64, call_index = 0)] fn total_supply(asset: u32) -> u64 {} #[xcq::entrypoint] diff --git a/poc/guests/sum-balance/src/main.rs b/poc/guests/sum-balance/src/main.rs index 7cd7810..997379a 100644 --- a/poc/guests/sum-balance/src/main.rs +++ b/poc/guests/sum-balance/src/main.rs @@ -3,11 +3,11 @@ #[global_allocator] static GLOBAL: polkavm_derive::LeakingAllocator = polkavm_derive::LeakingAllocator; use alloc::vec::Vec; +// An example instance of xcq program with specific arg types #[xcq_api::program] mod sum_balance { - #[xcq::call_def] + #[xcq::call_def(extension_id = 0x92F353DB95824F9Du64, call_index = 1)] fn balance(asset: u32, who: [u8; 32]) -> u64 {} - #[xcq::entrypoint] fn sum_balance(calls: Vec) -> u64 { let mut sum = 0; diff --git a/poc/guests/total-supply/src/main.rs b/poc/guests/total-supply/src/main.rs index 908d6a7..f327858 100644 --- a/poc/guests/total-supply/src/main.rs +++ b/poc/guests/total-supply/src/main.rs @@ -4,7 +4,7 @@ static GLOBAL: polkavm_derive::LeakingAllocator = polkavm_derive::LeakingAllocator; #[xcq_api::program] mod query_total_supply { - #[xcq::call_def] + #[xcq::call_def(extension_id = 10588899351449456541u64, call_index = 0)] fn total_supply(asset: u32) -> u64 {} #[xcq::entrypoint] diff --git a/xcq-api/procedural/Cargo.toml b/xcq-api/procedural/Cargo.toml index e65d246..064e941 100644 --- a/xcq-api/procedural/Cargo.toml +++ b/xcq-api/procedural/Cargo.toml @@ -12,3 +12,5 @@ syn = { workspace = true } proc-macro2 = { workspace = true } proc-macro-crate = { workspace = true } Inflector = { workspace = true } +xcq-types = { workspace = true } +parity-scale-codec = { workspace = true } diff --git a/xcq-api/procedural/src/lib.rs b/xcq-api/procedural/src/lib.rs index 81ccafb..df8b2c4 100644 --- a/xcq-api/procedural/src/lib.rs +++ b/xcq-api/procedural/src/lib.rs @@ -1,18 +1,21 @@ /// Declare the calls used in XCQ program +/// ```ignore /// #[xcq::program] /// mod query_fungibles { -/// #[xcq::call(extern_types = [AssetId, AccountId, Balance]])]] +/// #[xcq::call_def(extension_id = 123456, extern_types = [AssetId, AccountId, Balance])] /// fn balance(asset: AssetId, who: AccountId) -> Balance; /// -/// #[xcq(entrypoint)] +/// #[xcq::entrypoint] /// fn sum_balance(calls: Vec) -> u64 { /// let mut sum = 0; /// for call in calls { +/// // calculation requires a known balance type, we can use assert-type here /// sum += call.call(); /// } /// sum /// } /// } +/// ``` /// mod program; use proc_macro::TokenStream; diff --git a/xcq-api/procedural/src/program/expand/mod.rs b/xcq-api/procedural/src/program/expand/mod.rs index 7c2287f..d37c8ab 100644 --- a/xcq-api/procedural/src/program/expand/mod.rs +++ b/xcq-api/procedural/src/program/expand/mod.rs @@ -1,8 +1,9 @@ -use super::{Def, EntrypointDef}; +use super::{CallDef, Def, EntrypointDef}; use inflector::Inflector; +use parity_scale_codec::Encode; use proc_macro2::TokenStream as TokenStream2; -use quote::{format_ident, quote}; -use syn::{ItemFn, Result}; +use quote::{format_ident, quote, ToTokens}; +use syn::{ItemFn, PathArguments, Result}; pub fn expand(def: Def) -> Result { let preludes = generate_preludes(); // eprintln!("def{:?}", def.calls); @@ -11,8 +12,8 @@ pub fn expand(def: Def) -> Result { .iter() .map(|call_def| generate_call(&call_def.item_fn)) .collect::>>()?; - let entrypoint_def = &def.entrypoint.item_fn; - let main_fn = generate_main(&def.entrypoint)?; + let entrypoint_def = generate_entrypoint(&def.entrypoint)?; + let main_fn = generate_main(&def.calls, &def.entrypoint)?; Ok(quote! { #preludes #entrypoint_def @@ -21,11 +22,15 @@ pub fn expand(def: Def) -> Result { }) } -// At guest side, we only need call_ptr and size to perform call, -// the actual function signature is used at host side to construct the call data +// Generate a callable that holds the call data and a method to perform the call +// At compile time: extension_id and call_index are specified +// and they can be used to construct the runtime call data by front-end. +// At run time: we only forward call_data(including call_ptr and size) to host, +// and then we got the return bytes and convert to concrete numeric type fn generate_call(item: &ItemFn) -> Result { let camel_case_ident = syn::Ident::new(&item.sig.ident.to_string().to_pascal_case(), item.sig.ident.span()); let call_name = format_ident!("{}Call", camel_case_ident); + // This return_ty is a concrete unsigned integer type let return_ty = match &item.sig.output { syn::ReturnType::Type(_, return_ty) => return_ty, _ => { @@ -42,7 +47,8 @@ fn generate_call(item: &ItemFn) -> Result { pub call_size: u32, } impl #call_name { - pub fn call(&self) -> #return_ty { + pub fn call(&self) -> #return_ty { + // TODO: use xcq-types to represent the return type let res = unsafe { host_call(self.extension_id, self.call_ptr, self.call_size) }; @@ -58,6 +64,13 @@ fn generate_call(item: &ItemFn) -> Result { }; Ok(expand) } + +// Modify the calculation parts in the entrypoint function +// use type assertion to get the return type at runtime +fn generate_entrypoint(entrypoint: &EntrypointDef) -> Result { + Ok(entrypoint.item_fn.to_token_stream()) +} + fn pass_byte_to_host() -> TokenStream2 { // TODO check res type to determine the appropriate serializing method quote! { @@ -74,7 +87,108 @@ fn pass_byte_to_host() -> TokenStream2 { } } -fn generate_main(entrypoint: &EntrypointDef) -> Result { +fn generate_return_ty_assertion(call_def: &CallDef) -> Result { + let call_ty = &call_def.item_fn.sig.output; + // TODO: bytes representation is to be decided + let expected_ty_bytes = match call_ty { + syn::ReturnType::Type(_, return_ty) => match return_ty.as_ref() { + syn::Type::Path(path) => { + let last_segment = path + .path + .segments + .last() + .ok_or_else(|| syn::Error::new_spanned(path, "expected function return type to be a path"))?; + match last_segment.ident.to_string().as_str() { + "u8" => { + let encoded_ty_bytes = xcq_types::XcqType::Primitive(xcq_types::PrimitiveType::U8).encode(); + quote! { + &[#(#encoded_ty_bytes),*] + } + } + "u16" => { + let encoded_ty_bytes = xcq_types::XcqType::Primitive(xcq_types::PrimitiveType::U16).encode(); + quote! { + &[#(#encoded_ty_bytes),*] + } + } + "u32" => { + let encoded_ty_bytes = xcq_types::XcqType::Primitive(xcq_types::PrimitiveType::U32).encode(); + quote! { + &[#(#encoded_ty_bytes),*] + } + } + "u64" => { + let encoded_ty_bytes = xcq_types::XcqType::Primitive(xcq_types::PrimitiveType::U64).encode(); + quote! { + &[#(#encoded_ty_bytes),*] + } + } + "u128" => { + let encoded_ty_bytes = xcq_types::XcqType::Primitive(xcq_types::PrimitiveType::U128).encode(); + quote! { + &[#(#encoded_ty_bytes),*] + } + } + "Vec" => { + if let PathArguments::AngleBracketed(generic_args) = &last_segment.arguments { + if generic_args.args.len() == 1 { + match generic_args.args.first() { + Some(syn::GenericArgument::Type(syn::Type::Path(path))) + if path.path.is_ident("u8") => + { + let encoded_ty_bytes = xcq_types::XcqType::Sequence(Box::new( + xcq_types::XcqType::Primitive(xcq_types::PrimitiveType::U8), + )) + .encode(); + quote! { + &[#(#encoded_ty_bytes),*] + } + } + _ => quote! { &[0u8] }, + } + } else { + quote! { &[0u8] } + } + } else { + quote! {&[0u8]} + } + } + _ => quote! { &[0u8] }, + } + } + _ => { + return Err(syn::Error::new_spanned( + call_ty, + "expected function return type to be a path", + )) + } + }, + _ => { + return Err(syn::Error::new_spanned( + call_ty, + "expected function return type to be a path", + )) + } + }; + let extension_id = call_def.extension_id; + let call_index = call_def.call_index; + let item_fn_ident_string = &call_def.item_fn.sig.ident.to_string(); + let expanded = quote! { + if !assert_return_ty(#expected_ty_bytes, #extension_id, #call_index) { + panic!("function {} (extension {} call {}) return type mismatch", #item_fn_ident_string, #extension_id, #call_index); + } + }; + Ok(expanded) +} + +fn generate_main(call_defs: &[CallDef], entrypoint: &EntrypointDef) -> Result { + let assertions = call_defs + .iter() + .map(generate_return_ty_assertion) + .collect::>>()?; + let assert_program_types_match = quote! { + #(#assertions)* + }; // Construct call_data let mut get_call_data = TokenStream2::new(); for (arg_type_index, arg_type) in entrypoint.arg_types.iter().enumerate() { @@ -89,6 +203,7 @@ fn generate_main(entrypoint: &EntrypointDef) -> Result { ); get_call_data.extend({ quote! { + // TODO: extension_id can be eliminated since we have call_def indicating it let extension_id = unsafe {core::ptr::read_volatile((arg_ptr) as *const u64)}; // for multi calls, we assume the number of calls are given in the call data let call_num = unsafe {core::ptr::read_volatile((arg_ptr+8) as *const u8)}; @@ -139,6 +254,7 @@ fn generate_main(entrypoint: &EntrypointDef) -> Result { let main = quote! { #[polkavm_derive::polkavm_export] extern "C" fn main(mut arg_ptr: u32, size:u32) -> u64 { + #assert_program_types_match #get_call_data #call_entrypoint #pass_bytes_back @@ -167,6 +283,24 @@ fn generate_preludes() -> TokenStream2 { } }; + let host_return_ty_fn = quote! { + #[polkavm_derive::polkavm_import] + extern "C" { + fn return_ty(extension_id:u64, call_index:u32) -> u64; + } + }; + + let assert_return_ty_fn = quote! { + fn assert_return_ty(expected_ty_bytes: &[u8],extension_id:u64, call_index:u32) -> bool { + let return_ty = unsafe {return_ty(extension_id, call_index)}; + let ty_len = (return_ty >> 32) as u32; + let ty_ptr = (return_ty & 0xffffffff) as *const u8; + let ty_bytes = unsafe { + core::slice::from_raw_parts(ty_ptr, ty_len as usize) + }; + expected_ty_bytes == ty_bytes + } + }; quote! { #extern_crate @@ -174,5 +308,9 @@ fn generate_preludes() -> TokenStream2 { #panic_fn #host_call_fn + + #host_return_ty_fn + + #assert_return_ty_fn } } diff --git a/xcq-api/procedural/src/program/mod.rs b/xcq-api/procedural/src/program/mod.rs index 8c4fab9..f8ac0df 100644 --- a/xcq-api/procedural/src/program/mod.rs +++ b/xcq-api/procedural/src/program/mod.rs @@ -2,7 +2,7 @@ use proc_macro::TokenStream; use syn::{parse_macro_input, ItemMod}; mod expand; mod parse; -pub use parse::{Def, EntrypointDef}; +pub use parse::{CallDef, Def, EntrypointDef}; pub fn program(_attr: TokenStream, item: TokenStream) -> TokenStream { let item = parse_macro_input!(item as ItemMod); diff --git a/xcq-api/procedural/src/program/parse/call.rs b/xcq-api/procedural/src/program/parse/call.rs index 002d59e..0e9fd8a 100644 --- a/xcq-api/procedural/src/program/parse/call.rs +++ b/xcq-api/procedural/src/program/parse/call.rs @@ -7,24 +7,42 @@ use syn::{Item, ItemFn}; pub struct CallDef { pub index: usize, pub item_fn: ItemFn, + pub extension_id: u64, + pub call_index: u32, pub extern_types: Option, } impl CallDef { pub fn try_from( - _span: Span, + span: Span, index: usize, item: &mut Item, + extension_id: Option, + call_index: Option, extern_types: Option, ) -> syn::Result { + let extension_id = extension_id.ok_or_else(|| { + syn::Error::new( + span, + "Missing extension_id for xcq::call_def, expected #[xcq::call_def(extension_id = SOME_U64)]", + ) + })?; let item_fn = if let Item::Fn(item_fn) = item { item_fn } else { - return Err(syn::Error::new(item.span(), "Invalid xcq::call, expected item fn")); + return Err(syn::Error::new(item.span(), "Invalid xcq::call_def, expected item fn")); }; + let call_index = call_index.ok_or_else(|| { + syn::Error::new( + span, + "Missing call_index for xcq::call_def, expected #[xcq::call_def(call_index = SOME_U32)]", + ) + })?; Ok(Self { index, item_fn: item_fn.clone(), + extension_id, + call_index, extern_types, }) } diff --git a/xcq-api/procedural/src/program/parse/helper.rs b/xcq-api/procedural/src/program/parse/helper.rs index 2f22194..0744f34 100644 --- a/xcq-api/procedural/src/program/parse/helper.rs +++ b/xcq-api/procedural/src/program/parse/helper.rs @@ -1,12 +1,8 @@ -use quote::ToTokens; pub trait MutItemAttrs { fn mut_item_attrs(&mut self) -> Option<&mut Vec>; } -/// Take the first pallet attribute (e.g. attribute like `#[xcq..]`) and decode it to `Attr` -pub(crate) fn take_first_item_xcq_attr(item: &mut impl MutItemAttrs) -> syn::Result> -where - Attr: syn::parse::Parse, -{ +/// Take the first item attribute (e.g. attribute like `#[xcq..]`) and decode it to `Attr` +pub(crate) fn take_first_xcq_attr(item: &mut impl MutItemAttrs) -> syn::Result> { let Some(attrs) = item.mut_item_attrs() else { return Ok(None); }; @@ -21,7 +17,7 @@ where }; let xcq_attr = attrs.remove(index); - Ok(Some(syn::parse2(xcq_attr.into_token_stream())?)) + Ok(Some(xcq_attr)) } impl MutItemAttrs for syn::Item { fn mut_item_attrs(&mut self) -> Option<&mut Vec> { diff --git a/xcq-api/procedural/src/program/parse/mod.rs b/xcq-api/procedural/src/program/parse/mod.rs index 2e2129e..bac1f51 100644 --- a/xcq-api/procedural/src/program/parse/mod.rs +++ b/xcq-api/procedural/src/program/parse/mod.rs @@ -1,6 +1,7 @@ use syn::spanned::Spanned; -use syn::{Error, ItemMod, Result}; +use syn::{Error, ItemMod, LitInt, Result}; mod call; +pub use call::CallDef; mod entrypoint; pub use entrypoint::EntrypointDef; mod helper; @@ -17,7 +18,7 @@ impl Def { .content .as_mut() .ok_or_else(|| { - let msg = "Invalid pallet definition, expected mod to be inlined."; + let msg = "No content inside the XCQ program definition"; syn::Error::new(mod_span, msg) })? .1; @@ -26,33 +27,48 @@ impl Def { let mut entrypoint = None; for (index, item) in items.iter_mut().enumerate() { - let xcq_attr: Option = helper::take_first_item_xcq_attr(item)?; + let xcq_attr = helper::take_first_xcq_attr(item)?; - match xcq_attr { - Some(XcqAttr::CallDef(span, extern_types)) => { - calls.push(call::CallDef::try_from(span, index, item, extern_types)?); - } - Some(XcqAttr::Entrypoint(span)) => { - if entrypoint.is_some() { - return Err(Error::new(span, "Only one entrypoint function is allowed")); + if let Some(attr) = xcq_attr { + if let Some(last_segment) = attr.path().segments.last() { + if last_segment.ident == "call_def" { + let mut extern_types = None; + let mut extension_id = None; + let mut call_index = None; + attr.parse_nested_meta(|meta| { + if meta.path.is_ident("extension_id") { + let value = meta.value()?; + extension_id = Some(value.parse::()?.base10_parse::()?); + } else if meta.path.is_ident("call_index") { + let value = meta.value()?; + call_index = Some(value.parse::()?.base10_parse::()?); + } else if meta.path.is_ident("extern_types") { + let value = meta.value()?; + extern_types = Some(value.parse::()?); + } else { + return Err(Error::new(meta.path.span(), "Invalid attribute for `call_def`")); + } + Ok(()) + })?; + let call = + call::CallDef::try_from(attr.span(), index, item, extension_id, call_index, extern_types)?; + calls.push(call); + } else if last_segment.ident == "entrypoint" { + if entrypoint.is_some() { + return Err(Error::new(attr.span(), "Only one entrypoint function is allowed")); + } + entrypoint = Some(entrypoint::EntrypointDef::try_from(attr.span(), index, item)?); + } else { + return Err(Error::new( + item.span(), + "Invalid attribute, expected `#[xcq::call_def]` or `#[xcq::entrypoint]`", + )); } - let e = entrypoint::EntrypointDef::try_from(span, index, item)?; - entrypoint = Some(e); - } - None => { - return Err(Error::new( - item.span(), - "Invalid attribute, expected `#[xcq::call_def]` or `#[xcq::entrypoint]`", - )); } } } - let entrypoint = match entrypoint { - Some(entrypoint) => entrypoint, - None => { - return Err(Error::new(mod_span, "No entrypoint function found")); - } - }; + + let entrypoint = entrypoint.ok_or_else(|| Error::new(mod_span, "No entrypoint function found"))?; let def = Def { calls, entrypoint }; Ok(def) @@ -63,39 +79,10 @@ impl Def { mod keyword { syn::custom_keyword!(xcq); syn::custom_keyword!(call_def); + syn::custom_keyword!(extension_id); syn::custom_keyword!(extern_types); syn::custom_keyword!(entrypoint); } -enum XcqAttr { - CallDef(proc_macro2::Span, Option), - Entrypoint(proc_macro2::Span), -} - -// Custom parsing for xcq attribute -impl syn::parse::Parse for XcqAttr { - fn parse(input: syn::parse::ParseStream) -> syn::Result { - input.parse::()?; - let content; - syn::bracketed!(content in input); - content.parse::()?; - content.parse::()?; - - let lookahead = content.lookahead1(); - if lookahead.peek(keyword::call_def) { - let span = content.parse::().expect("peeked").span(); - let extern_types = match content.is_empty() { - true => None, - false => Some(ExternTypesAttr::parse(&content)?), - }; - Ok(XcqAttr::CallDef(span, extern_types)) - } else if lookahead.peek(keyword::entrypoint) { - Ok(XcqAttr::Entrypoint(content.parse::()?.span())) - } else { - Err(lookahead.error()) - } - } -} - #[derive(Debug, Clone)] pub struct ExternTypesAttr { pub types: Vec, @@ -105,21 +92,11 @@ pub struct ExternTypesAttr { impl syn::parse::Parse for ExternTypesAttr { fn parse(input: syn::parse::ParseStream) -> syn::Result { let content; - syn::parenthesized!(content in input); - - let lookahead = content.lookahead1(); - if lookahead.peek(keyword::extern_types) { - let span = content.parse::().expect("peeked").span(); - content.parse::().expect("peeked"); - let list; - syn::bracketed!(list in content); - let types = list.parse_terminated(syn::Type::parse, syn::Token![,])?; - Ok(ExternTypesAttr { - types: types.into_iter().collect(), - span, - }) - } else { - Err(lookahead.error()) - } + syn::bracketed!(content in input); + let extern_types = content.parse_terminated(syn::Type::parse, syn::Token![,])?; + Ok(ExternTypesAttr { + types: extern_types.into_iter().collect(), + span: content.span(), + }) } } diff --git a/xcq-extension/Cargo.toml b/xcq-extension/Cargo.toml index b9e1714..f6633f0 100644 --- a/xcq-extension/Cargo.toml +++ b/xcq-extension/Cargo.toml @@ -11,11 +11,9 @@ version.workspace = true parity-scale-codec = { workspace = true } scale-info = { workspace = true } xcq-executor = { workspace = true } -impl-trait-for-tuples = { workspace = true } +fortuples = { workspace = true } tracing = { workspace = true } xcq-extension-procedural = { path = "procedural" } - -[dev-dependencies] xcq-primitives = { workspace = true } [features] diff --git a/xcq-extension/procedural/src/decl_extensions.rs b/xcq-extension/procedural/src/decl_extensions.rs index d19bf76..9cd9988 100644 --- a/xcq-extension/procedural/src/decl_extensions.rs +++ b/xcq-extension/procedural/src/decl_extensions.rs @@ -17,20 +17,27 @@ pub fn decl_extensions_impl(input: TokenStream) -> TokenStream { } pub fn decl_extension_inner(item_trait: &ItemTrait) -> Result { + let mut item_trait = item_trait.clone(); + // Generate a separate module for the extension let mod_name = generate_mod_name_for_trait(&item_trait.ident); - // Assume single config associated type. + // Add super trait ExtensionId and ExtensionMetadata to the trait's where clause + add_super_trait(&mut item_trait)?; + + // TODO: If trait has associated type, we assume it has a single associated type called `Config` let has_config = item_trait .items .iter() .any(|item| matches!(item, syn::TraitItem::Type(_))); + + // Extract methods from the trait let methods = methods(&item_trait.items)?; let call_enum_def = call_enum_def(&item_trait.ident, &methods)?; - let dispatchable_impl = dispatchable_impl(&item_trait.ident, &methods)?; - let extension_id_impl = extension_id_impl(&item_trait.ident, &item_trait.items)?; - - let runtime_metadata = crate::runtime_metadata::generate_decl_metadata(item_trait, has_config)?; + let call_data_dispatchable_impl = impl_dispatchable(&item_trait.ident, &methods)?; + let call_data_extension_id_impl = impl_extension_id(&item_trait.ident, &item_trait.items)?; + let call_data_metadata_impl = impl_metadata(&item_trait.ident)?; + let extension_runtime_metadata = crate::runtime_metadata::generate_decl_metadata(&item_trait, has_config)?; let expanded = quote! { #item_trait @@ -39,14 +46,25 @@ pub fn decl_extension_inner(item_trait: &ItemTrait) -> Result { pub mod #mod_name { pub use super::*; #call_enum_def - #dispatchable_impl - #extension_id_impl - #runtime_metadata + #call_data_dispatchable_impl + #call_data_extension_id_impl + #call_data_metadata_impl + #extension_runtime_metadata } pub use #mod_name::*; }; Ok(expanded) } + +fn add_super_trait(item_trait: &mut ItemTrait) -> Result<()> { + let xcq_extension = generate_crate_access("xcq-extension")?; + // item_trait.supertraits.push(parse_quote!(#xcq_extension::ExtensionId)); + item_trait + .supertraits + .push(parse_quote!(#xcq_extension::ExtensionMetadata)); + Ok(()) +} + #[derive(Clone)] struct Method { /// Function name @@ -87,7 +105,7 @@ fn call_enum_def(trait_ident: &Ident, methods: &[Method]) -> Result { )) } -fn dispatchable_impl(trait_ident: &Ident, methods: &[Method]) -> Result { +fn impl_dispatchable(trait_ident: &Ident, methods: &[Method]) -> Result { let xcq_extension = generate_crate_access("xcq-extension")?; let xcq_primitives = generate_crate_access("xcq-primitives")?; let mut pats = Vec::::new(); @@ -134,18 +152,33 @@ fn dispatchable_impl(trait_ident: &Ident, methods: &[Method]) -> Result Result { +fn impl_extension_id(trait_ident: &Ident, trait_items: &[TraitItem]) -> Result { let xcq_extension = generate_crate_access("xcq-extension")?; + let extension_id = calculate_hash(trait_ident, trait_items); Ok(quote! { - // TODO: check if we need a extension_id trait + // Defining an trait for extension_id is useful for generic usage impl #xcq_extension::ExtensionId for Call { const EXTENSION_ID: #xcq_extension::ExtensionIdTy = #extension_id; } + // This one is for easier access, since impl doesn't contribute to the extension_id calculation pub const EXTENSION_ID: #xcq_extension::ExtensionIdTy = #extension_id; }) } +// Delegate the metadata generation to the trait implementation +fn impl_metadata(trait_ident: &Ident) -> Result { + let xcq_extension = generate_crate_access("xcq-extension")?; + let xcq_primitives = generate_crate_access("xcq-primitives")?; + Ok(parse_quote! { + impl #xcq_extension::CallMetadata for Call { + fn metadata() -> #xcq_primitives::metadata_ir::ExtensionMetadataIR { + Impl::extension_metadata(::EXTENSION_ID) + } + } + }) +} + // helper functions fn methods(trait_items: &[TraitItem]) -> Result> { let mut methods = vec![]; diff --git a/xcq-extension/procedural/src/runtime_metadata.rs b/xcq-extension/procedural/src/runtime_metadata.rs index c3ee461..1c8a3c7 100644 --- a/xcq-extension/procedural/src/runtime_metadata.rs +++ b/xcq-extension/procedural/src/runtime_metadata.rs @@ -101,9 +101,9 @@ pub fn generate_decl_metadata(decl: &ItemTrait, has_config: bool) -> Result #xcq_primitives::metadata_ir::ExtensionMetadataIR + pub fn runtime_metadata #impl_generics () -> #xcq_primitives::metadata_ir::ExtensionMetadataIR #where_clause { #xcq_primitives::metadata_ir::ExtensionMetadataIR { @@ -137,6 +137,8 @@ pub fn generate_impl_metadata(impls: &[ItemImpl]) -> Result { let xcq_primitives = generate_crate_access("xcq-primitives")?; + let xcq_extension = generate_crate_access("xcq-extension")?; + // Get the name of the runtime for which the traits are implemented. let extension_impl_name = &impls .first() @@ -144,18 +146,20 @@ pub fn generate_impl_metadata(impls: &[ItemImpl]) -> Result { .self_ty; let mut metadata = Vec::new(); + let mut extension_ids = Vec::new(); for impl_ in impls { let mut trait_ = extract_impl_trait(impl_, RequireQualifiedTraitPath::Yes)?.clone(); - // Implementation traits are always references with a path `impl client::Core ...` + // Implementation traits are always references with a path `impl xcq_extension::ExtensionCore ...` // The trait name is the last segment of this path. - let trait_name_ident = &trait_ + let trait_name_ident = trait_ .segments .last() .as_ref() .expect("Trait path should always contain at least one item; qed") - .ident; + .ident + .clone(); // Convert associated types to generics let mut generic_params = HashSet::::new(); @@ -172,21 +176,12 @@ pub fn generate_impl_metadata(impls: &[ItemImpl]) -> Result { where_clause: None, }; - // Extract the generics from the trait to pass to the `runtime_metadata` given by `generate_decl_metadata` - // let generics = trait_ - // .segments - // .iter() - // .find_map(|segment| { - // if let syn::PathArguments::AngleBracketed(generics) = &segment.arguments { - // Some(generics.clone()) - // } else { - // None - // } - // }) - // .expect("Trait path should always contain at least one generic parameter; qed"); - - let mod_name = generate_mod_name_for_trait(trait_name_ident); + let mod_name = generate_mod_name_for_trait(&trait_name_ident); // Get absolute path to the `runtime_decl_for_` module by replacing the last segment. + let mut extension_trait_full_path = trait_.clone(); + if let Some(segment) = extension_trait_full_path.segments.last_mut() { + *segment = parse_quote!(EXTENSION_ID); + } if let Some(segment) = trait_.segments.last_mut() { *segment = parse_quote!(#mod_name); } @@ -196,9 +191,22 @@ pub fn generate_impl_metadata(impls: &[ItemImpl]) -> Result { #( #attrs )* #trait_::runtime_metadata::#generics() )); + extension_ids.push(quote!(#extension_trait_full_path)); } + let extension_metadata_impl = quote! { + impl #xcq_extension::ExtensionMetadata for #extension_impl_name { + fn extension_metadata(extension_id: #xcq_extension::ExtensionIdTy) -> #xcq_primitives::metadata_ir::ExtensionMetadataIR { + match extension_id { + #(#extension_ids => #metadata,)* + _ => panic!("Unknown extension id"), + } + } + } + }; + Ok(quote!( + #extension_metadata_impl impl #extension_impl_name { pub fn runtime_metadata() -> #xcq_primitives::metadata_ir::MetadataIR { #xcq_primitives::metadata_ir::MetadataIR { diff --git a/xcq-extension/src/lib.rs b/xcq-extension/src/lib.rs index 418a033..0efb67b 100644 --- a/xcq-extension/src/lib.rs +++ b/xcq-extension/src/lib.rs @@ -13,6 +13,8 @@ pub type XcqResult = Result; mod dispatchable; pub use dispatchable::{DispatchError, Dispatchable}; +mod metadata; +pub use metadata::{CallMetadata, ExtensionMetadata}; mod extension_id; pub use extension_id::{ExtensionId, ExtensionIdTy}; mod error; @@ -26,20 +28,22 @@ pub use perm_controller::{InvokeSource, PermController}; mod guest; pub use guest::{Guest, Input, Method}; -// alias trait -pub trait Extension: Dispatchable + ExtensionId + Decode {} -impl Extension for T where T: Dispatchable + ExtensionId + Decode {} +// Call data +pub trait CallData: Dispatchable + CallMetadata + ExtensionId + Decode {} +impl CallData for T where T: Dispatchable + CallMetadata + ExtensionId + Decode {} -pub trait ExtensionTuple { +pub trait CallDataTuple { fn dispatch(extension_id: ExtensionIdTy, data: &[u8]) -> Result, ExtensionError>; + // TODO: check if use metadata api + fn return_ty(extension_id: ExtensionIdTy, call_index: u32) -> Result, ExtensionError>; } -struct Context { +struct Context { invoke_source: InvokeSource, - _marker: PhantomData<(E, P)>, + _marker: PhantomData<(C, P)>, } -impl Context { +impl Context { pub fn new(invoke_source: InvokeSource) -> Self { Self { invoke_source, @@ -48,7 +52,7 @@ impl Context { } } -impl XcqExecutorContext for Context { +impl XcqExecutorContext for Context { fn register_host_functions(&mut self, linker: &mut Linker) { let invoke_source = self.invoke_source; linker @@ -69,7 +73,30 @@ impl XcqExecutorContext for Context if !P::is_allowed(extension_id, &call_bytes, invoke_source) { return Err(ExtensionError::PermissionError); } - let res_bytes = E::dispatch(extension_id, &call_bytes)?; + let res_bytes = C::dispatch(extension_id, &call_bytes)?; + tracing::debug!("(host call): res_bytes: {:?}", res_bytes); + let res_bytes_len = res_bytes.len(); + let res_ptr = caller.sbrk(0).ok_or(ExtensionError::PolkavmError)?; + if caller.sbrk(res_bytes_len as u32).is_none() { + return Err(ExtensionError::PolkavmError); + } + caller + .write_memory(res_ptr, &res_bytes) + .map_err(|_| ExtensionError::PolkavmError)?; + Ok(((res_bytes_len as u64) << 32) | (res_ptr as u64)) + }; + let result = func_with_result(); + tracing::trace!("(host call): result: {:?}", result); + result.unwrap_or(0) + }, + ) + .unwrap(); + linker + .func_wrap( + "return_ty", + move |mut caller: Caller<_>, extension_id: u64, call_index: u32| -> u64 { + let mut func_with_result = || -> Result { + let res_bytes = C::return_ty(extension_id, call_index)?; tracing::debug!("(host call): res_bytes: {:?}", res_bytes); let res_bytes_len = res_bytes.len(); let res_ptr = caller.sbrk(0).ok_or(ExtensionError::PolkavmError)?; @@ -90,13 +117,13 @@ impl XcqExecutorContext for Context } } -pub struct ExtensionsExecutor { - executor: XcqExecutor>, +pub struct ExtensionsExecutor { + executor: XcqExecutor>, } -impl ExtensionsExecutor { +impl ExtensionsExecutor { #[allow(dead_code)] pub fn new(source: InvokeSource) -> Self { - let context = Context::::new(source); + let context = Context::::new(source); let executor = XcqExecutor::new(Default::default(), context); Self { executor } } diff --git a/xcq-extension/src/macros.rs b/xcq-extension/src/macros.rs index 7753e5d..baa8df9 100644 --- a/xcq-extension/src/macros.rs +++ b/xcq-extension/src/macros.rs @@ -1,21 +1,32 @@ -use crate::Extension; +use crate::CallData; +use crate::CallDataTuple; use crate::ExtensionError; use crate::ExtensionIdTy; -use crate::ExtensionTuple; use crate::Vec; +use fortuples::fortuples; +use parity_scale_codec::Encode; // Use the macro to implement ExtensionTuple for tuples of different lengths -#[impl_trait_for_tuples::impl_for_tuples(10)] -#[tuple_types_custom_trait_bound(Extension)] -impl ExtensionTuple for Tuple { - fn dispatch(extension_id: ExtensionIdTy, mut call: &[u8]) -> Result, ExtensionError> { - for_tuples!( +fortuples! { + impl CallDataTuple for #Tuple where #(#Member: CallData),*{ + #[allow(unused_variables)] + #[allow(unused_mut)] + fn dispatch(extension_id: ExtensionIdTy, mut call: &[u8]) -> Result, ExtensionError> { #( - if extension_id == Tuple::EXTENSION_ID { - return Tuple::decode(&mut call).map_err(ExtensionError::DecodeError)?.dispatch().map_err(ExtensionError::DispatchError); + if extension_id == #Member::EXTENSION_ID { + return #Member::decode(&mut call).map_err(ExtensionError::DecodeError)?.dispatch().map_err(ExtensionError::DispatchError); } )* - ); - Err(ExtensionError::UnsupportedExtension) + Err(ExtensionError::UnsupportedExtension) + } + #[allow(unused_variables)] + fn return_ty(extension_id: ExtensionIdTy, call_index: u32) -> Result, ExtensionError> { + #( + if extension_id == #Member::EXTENSION_ID { + return Ok(#Member::metadata().methods[call_index as usize].output.type_info().encode()); + } + )* + Err(ExtensionError::UnsupportedExtension) + } } } diff --git a/xcq-extension/src/metadata.rs b/xcq-extension/src/metadata.rs new file mode 100644 index 0000000..278a194 --- /dev/null +++ b/xcq-extension/src/metadata.rs @@ -0,0 +1,12 @@ +use xcq_primitives::metadata_ir::ExtensionMetadataIR; + +use crate::extension_id; +// This trait is for CallData +pub trait CallMetadata { + fn metadata() -> ExtensionMetadataIR; +} + +// This trait is for runtime +pub trait ExtensionMetadata { + fn extension_metadata(extension_id: extension_id::ExtensionIdTy) -> ExtensionMetadataIR; +} diff --git a/xcq-primitives/src/metadata_ir.rs b/xcq-primitives/src/metadata_ir.rs index 5ba783d..659d259 100644 --- a/xcq-primitives/src/metadata_ir.rs +++ b/xcq-primitives/src/metadata_ir.rs @@ -30,3 +30,7 @@ pub struct ExtensionMetadataIR { pub struct MetadataIR { pub extensions: Vec, } + +pub trait RuntimeMetadata { + fn runtime_metadata() -> MetadataIR; +}