Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add Stylus constructors #184

Open
wants to merge 15 commits into
base: develop
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
77 changes: 51 additions & 26 deletions stylus-proc/src/macros/public/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -103,38 +103,46 @@ impl<E: FnExtension> From<&mut syn::ImplItemFn> for PublicFn<E> {
consume_attr::<attrs::Selector>(&mut node.attrs, "selector").map(|s| s.value.value());
let fallback = consume_flag(&mut node.attrs, "fallback");
let receive = consume_flag(&mut node.attrs, "receive");
let constructor = consume_flag(&mut node.attrs, "constructor");

let kind = match (fallback, receive) {
(true, false) => {
// Fallback functions may have two signatures, either
// with input calldata and output bytes, or no input and output.
// node.sig.
let has_inputs = node.sig.inputs.len() > 1;
if has_inputs {
FnKind::FallbackWithArgs
} else {
FnKind::FallbackNoArgs
}
}
(false, true) => FnKind::Receive,
(false, false) => FnKind::Function,
(true, true) => {
emit_error!(node.span(), "function cannot be both fallback and receive");
FnKind::Function
let kind = if fallback {
gligneul marked this conversation as resolved.
Show resolved Hide resolved
// Fallback functions may have two signatures, either
// with input calldata and output bytes, or no input and output.
FnKind::Fallback {
with_args: node.sig.inputs.len() > 1,
}
} else if receive {
FnKind::Receive
} else if constructor {
FnKind::Constructor
} else {
FnKind::Function
};

// name for generated rust, and solidity abi
let name = node.sig.ident.clone();

if matches!(kind, FnKind::Function) && (name == "receive" || name == "fallback") {
let num_specials = (fallback as i8) + (constructor as i8) + (receive as i8);
if num_specials > 1 {
emit_error!(
node.span(),
"receive and/or fallback functions can only be defined using the #[receive] or "
.to_string()
+ "#[fallback] attribute instead of names",
"function can be only one of fallback, receive or constructor"
);
}
if num_specials > 0 && selector_override.is_some() {
emit_error!(
node.span(),
"fallback, receive, and constructor can't have custom selector"
);
}

// name for generated rust, and solidity abi
let name = node.sig.ident.clone();
for special_name in ["receive", "fallback", "constructor"] {
if matches!(kind, FnKind::Function) && name == special_name {
emit_error!(
node.span(),
format!("{special_name} function can only be defined using the #[{special_name}] attribute")
);
}
}

let sol_name = syn_solidity::SolIdent::new(
&selector_override.unwrap_or(name.to_string().to_case(Case::Camel)),
Expand All @@ -154,7 +162,7 @@ impl<E: FnExtension> From<&mut syn::ImplItemFn> for PublicFn<E> {
args.next();
}
let inputs = match kind {
FnKind::Function => args.map(PublicFnArg::from).collect(),
FnKind::Function | FnKind::Constructor => args.map(PublicFnArg::from).collect(),
_ => Vec::new(),
};
let input_span = node.sig.inputs.span();
Expand Down Expand Up @@ -205,7 +213,7 @@ impl<E: FnArgExtension> From<&syn::FnArg> for PublicFnArg<E> {
mod tests {
use syn::parse_quote;

use super::types::PublicImpl;
use super::types::{FnKind, PublicImpl};

#[test]
fn test_public_consumes_inherit() {
Expand Down Expand Up @@ -235,4 +243,21 @@ mod tests {
};
assert_eq!(attrs, &vec![parse_quote! { #[other] }]);
}

#[test]
fn test_public_consumes_constructor() {
let mut impl_item = parse_quote! {
#[derive(Debug)]
impl Contract {
#[constructor]
fn func(&mut self, val: U256) {}
}
};
let public = PublicImpl::from(&mut impl_item);
assert!(matches!(public.funcs[0].kind, FnKind::Constructor));
let syn::ImplItem::Fn(syn::ImplItemFn { attrs, .. }) = &impl_item.items[0] else {
unreachable!();
};
assert!(attrs.is_empty());
}
}
1 change: 0 additions & 1 deletion stylus-proc/src/macros/public/overrides.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
// Copyright 2022-2024, Offchain Labs, Inc.
// use crate::consts::{ALLOW_OVERRIDE_FN, ASSERT_OVERRIDES_FN};
// For licensing, see https://github.com/OffchainLabs/stylus-sdk-rs/blob/main/licenses/COPYRIGHT.md

//! Ensure that public functions follow safe override rules.
Expand Down
176 changes: 95 additions & 81 deletions stylus-proc/src/macros/public/types.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
// Copyright 2022-2024, Offchain Labs, Inc.
// For licensing, see https://github.com/OffchainLabs/stylus-sdk-rs/blob/main/licenses/COPYRIGHT.md

use proc_macro2::{Span, TokenStream};
use proc_macro_error::emit_error;
use quote::{quote, ToTokens};
Expand All @@ -18,6 +19,34 @@ use crate::{

use super::Extension;

/// Generate the code to call the special function (fallback, receive, or constructor) from the
/// public impl block. Emits an error if there are multiple implementations.
macro_rules! call_special {
($self:expr, $kind:pat, $kind_name:literal, $func:expr) => {{
let specials: Vec<syn::Stmt> = $self
.funcs
.iter()
.filter(|&func| matches!(func.kind, $kind))
.map($func)
.collect();
if specials.is_empty() {
None
} else {
if specials.len() > 1 {
emit_error!(
concat!("multiple ", $kind_name),
concat!(
"contract can only have one #[",
$kind_name,
"] method defined"
)
);
}
specials.first().cloned()
}
}};
}

pub struct PublicImpl<E: InterfaceExtension = Extension> {
pub self_ty: syn::Type,
pub generic_params: Punctuated<syn::GenericParam, Token![,]>,
Expand Down Expand Up @@ -47,30 +76,22 @@ impl PublicImpl {
.collect::<Vec<_>>();
let inheritance_routes = self.inheritance_routes();

let call_fallback = self.call_fallback();
let call_fallback = call_special!(
self,
FnKind::Fallback { .. },
"fallback",
PublicFn::call_fallback
);
let inheritance_fallback = self.inheritance_fallback();

let (fallback, fallback_purity) = call_fallback.unwrap_or_else(|| {
let fallback = call_fallback.unwrap_or_else(|| {
// If there is no fallback function specified, we rely on any inherited fallback.
(
parse_quote!({
#(#inheritance_fallback)*
None
}),
Purity::Payable, // Let the inherited fallback deal with purity.
)
parse_quote!({
#(#inheritance_fallback)*
None
})
});

let fallback_deny: Option<syn::ExprIf> = match fallback_purity {
Purity::Payable => None,
_ => Some(parse_quote! {
if let Err(err) = stylus_sdk::abi::internal::deny_value("fallback") {
return Some(Err(err));
}
}),
};

let call_receive = self.call_receive();
let call_receive = call_special!(self, FnKind::Receive, "receive", PublicFn::call_receive);
let inheritance_receive = self.inheritance_receive();
let receive = call_receive.unwrap_or_else(|| {
parse_quote!({
Expand All @@ -79,6 +100,14 @@ impl PublicImpl {
})
});

let call_constructor = call_special!(
self,
FnKind::Constructor,
"constructor",
PublicFn::call_constructor
);
let constructor = call_constructor.unwrap_or_else(|| parse_quote!({ None }));

parse_quote! {
impl<S, #generic_params> #Router<S> for #self_ty
where
Expand Down Expand Up @@ -112,14 +141,18 @@ impl PublicImpl {

#[inline(always)]
fn fallback(storage: &mut S, input: &[u8]) -> Option<stylus_sdk::ArbResult> {
#fallback_deny
#fallback
}

#[inline(always)]
fn receive(storage: &mut S) -> Option<()> {
#receive
}

#[inline(always)]
fn constructor(storage: &mut S, input: &[u8]) -> Option<stylus_sdk::ArbResult> {
#constructor
}
}
}
}
Expand All @@ -134,37 +167,6 @@ impl PublicImpl {
})
}

fn call_fallback(&self) -> Option<(syn::Stmt, Purity)> {
let mut fallback_purity = Purity::View;
let fallbacks: Vec<syn::Stmt> = self
.funcs
.iter()
.filter(|&func| {
if matches!(func.kind, FnKind::FallbackWithArgs)
|| matches!(func.kind, FnKind::FallbackNoArgs)
{
fallback_purity = func.purity;
return true;
}
false
})
.map(PublicFn::call_fallback)
.collect();
if fallbacks.is_empty() {
return None;
}
if fallbacks.len() > 1 {
emit_error!(
"multiple fallbacks",
"contract can only have one #[fallback] method defined"
);
}
fallbacks
.first()
.cloned()
.map(|func| (func, fallback_purity))
}

fn inheritance_fallback(&self) -> impl Iterator<Item = syn::ExprIf> + '_ {
self.inheritance.iter().map(|ty| {
parse_quote! {
Expand All @@ -175,25 +177,6 @@ impl PublicImpl {
})
}

fn call_receive(&self) -> Option<syn::Stmt> {
let receives: Vec<syn::Stmt> = self
.funcs
.iter()
.filter(|&func| matches!(func.kind, FnKind::Receive))
.map(PublicFn::call_receive)
.collect();
if receives.is_empty() {
return None;
}
if receives.len() > 1 {
emit_error!(
"multiple receives",
"contract can only have one #[receive] method defined"
);
}
receives.first().cloned()
}

fn inheritance_receive(&self) -> impl Iterator<Item = syn::ExprIf> + '_ {
self.inheritance.iter().map(|ty| {
parse_quote! {
Expand All @@ -208,9 +191,9 @@ impl PublicImpl {
#[derive(Debug)]
pub enum FnKind {
Function,
FallbackWithArgs,
FallbackNoArgs,
Fallback { with_args: bool },
Receive,
Constructor,
}

pub struct PublicFn<E: FnExtension> {
Expand Down Expand Up @@ -323,24 +306,30 @@ impl<E: FnExtension> PublicFn<E> {
} else {
let name = self.name.to_string();
Some(parse_quote! {
if let Err(err) = internal::deny_value(#name) {
if let Err(err) = stylus_sdk::abi::internal::deny_value(#name) {
return Some(Err(err));
}
})
}
}

fn call_fallback(&self) -> syn::Stmt {
let deny_value = self.deny_value();
let name = &self.name;
let storage_arg = self.storage_arg();
if matches!(self.kind, FnKind::FallbackNoArgs) {
return parse_quote! {
let call: syn::Stmt = if matches!(self.kind, FnKind::Fallback { with_args: false }) {
parse_quote! {
return Some(Self::#name(#storage_arg));
};
}
parse_quote! {
return Some(Self::#name(#storage_arg input));
}
}
} else {
parse_quote! {
return Some(Self::#name(#storage_arg input));
}
};
parse_quote!({
#deny_value
#call
})
}

fn call_receive(&self) -> syn::Stmt {
Expand All @@ -350,6 +339,31 @@ impl<E: FnExtension> PublicFn<E> {
return Some(Self::#name(#storage_arg));
}
}

fn call_constructor(&self) -> syn::Stmt {
let deny_value = self.deny_value();
let name = &self.name;
let decode_inputs = self.decode_inputs();
let storage_arg = self.storage_arg();
let expand_args = self.expand_args();
let encode_output = self.encode_output();
parse_quote!({
use stylus_sdk::abi::{internal, internal::EncodableReturnType};
#deny_value
if let Err(e) = internal::constructor_guard() {
return Some(Err(e));
}
let args = match <#decode_inputs as #SolType>::abi_decode_params(input, true) {
Ok(args) => args,
Err(err) => {
internal::failed_to_decode_arguments(err);
return Some(Err(Vec::new()));
}
};
let result = Self::#name(#storage_arg #(#expand_args, )* );
Some(#encode_output)
})
}
}

pub struct PublicFnArg<E: FnArgExtension> {
Expand Down
Loading
Loading