Skip to content

Commit

Permalink
add try_get_inner and try_get_inner_mut functions and implementations
Browse files Browse the repository at this point in the history
  • Loading branch information
qalisander committed Jun 27, 2024
1 parent b622dfc commit 921c86a
Show file tree
Hide file tree
Showing 11 changed files with 119 additions and 62 deletions.
4 changes: 2 additions & 2 deletions examples/erc20/src/erc20.rs
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ use alloy_sol_types::sol;
use core::marker::PhantomData;
use stylus_sdk::{evm, msg, prelude::*};

pub trait Erc20Params {
pub trait Erc20Params: 'static {
/// Immutable token name
const NAME: &'static str;

Expand All @@ -27,7 +27,7 @@ pub trait Erc20Params {

sol_storage! {
/// Erc20 implements all ERC-20 methods.
pub struct Erc20<T> {
pub struct Erc20<T: Erc20Params> {
/// Maps users to balances
mapping(address => uint256) balances;
/// Maps users to a mapping of each spender's allowance
Expand Down
13 changes: 1 addition & 12 deletions stylus-proc/src/methods/entrypoint.rs
Original file line number Diff line number Diff line change
Expand Up @@ -22,18 +22,7 @@ pub fn entrypoint(attr: TokenStream, input: TokenStream) -> TokenStream {
let (impl_generics, ty_generics, where_clause) = input.generics.split_for_impl();

output.extend(quote!{
unsafe impl #impl_generics stylus_sdk::storage::TopLevelStorage for #name #ty_generics #where_clause {
fn get_storage<S: 'static>(&mut self) -> &mut S {
use stylus_sdk::storage::InnerStorage;
unsafe {
self.try_get_storage().unwrap_or_else(|| {
panic!(
"storage for type doesn't exist - type name is {}",
core::any::type_name::<S>()
)})
}
}
}
unsafe impl #impl_generics stylus_sdk::storage::TopLevelStorage for #name #ty_generics #where_clause {}

fn entrypoint(input: alloc::vec::Vec<u8>) -> stylus_sdk::ArbResult {
use stylus_sdk::{abi::Router, alloy_primitives::U256, console, hex, storage::StorageType};
Expand Down
15 changes: 2 additions & 13 deletions stylus-proc/src/methods/external.rs
Original file line number Diff line number Diff line change
Expand Up @@ -123,7 +123,7 @@ pub fn external(_attr: TokenStream, input: TokenStream) -> TokenStream {
let storage = if needed_purity == Pure {
quote!()
} else if has_self {
quote! { core::borrow::BorrowMut::borrow_mut(storage), }
quote! { storage.inner_mut(), }
} else {
quote! { storage, }
};
Expand Down Expand Up @@ -241,13 +241,6 @@ pub fn external(_attr: TokenStream, input: TokenStream) -> TokenStream {
}
});

// ensure we can actually borrow the things we inherit
let borrow_clauses = inherits.iter().map(|ty| {
quote! {
S: core::borrow::BorrowMut<#ty>
}
});

let self_ty = &input.self_ty;
let generic_params = &input.generics.params;
let where_clauses = input
Expand All @@ -263,13 +256,9 @@ pub fn external(_attr: TokenStream, input: TokenStream) -> TokenStream {

impl<S, #generic_params> stylus_sdk::abi::Router<S> for #self_ty
where
S: stylus_sdk::storage::TopLevelStorage + core::borrow::BorrowMut<Self>,
#(#borrow_clauses,)*
S: stylus_sdk::storage::TopLevelStorage,
#where_clauses
{
// TODO: this should be configurable
type Storage = Self;

#[inline(always)]
fn route(storage: &mut S, selector: u32, input: &[u8]) -> Option<stylus_sdk::ArbResult> {
use stylus_sdk::{function_selector, alloy_sol_types::SolType};
Expand Down
50 changes: 33 additions & 17 deletions stylus-proc/src/storage/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,12 @@ pub fn solidity_storage(_attr: TokenStream, input: TokenStream) -> TokenStream {
error!(&field, "Type not supported for EVM state storage");
};

let accessor = match field.ident.as_ref() {
Some(accessor) => accessor.into_token_stream(),
None => Index::from(field_index).into_token_stream(),
};
inner_storage_accessors.push(accessor.clone());

// implement borrows (TODO: use drain_filter when stable)
let attrs = mem::take(&mut field.attrs);
for attr in attrs {
Expand All @@ -37,13 +43,7 @@ pub fn solidity_storage(_attr: TokenStream, input: TokenStream) -> TokenStream {
error!(attr.tokens, "borrow attribute does not take parameters");
}
let ty = &field.ty;
let accessor = match field.ident.as_ref() {
Some(accessor) => accessor.into_token_stream(),
None => Index::from(field_index).into_token_stream(),
};

inner_storage_accessors.push(accessor.clone());


borrows.extend(quote! {
impl core::borrow::Borrow<#ty> for #name {
fn borrow(&self) -> &#ty {
Expand Down Expand Up @@ -120,24 +120,40 @@ pub fn solidity_storage(_attr: TokenStream, input: TokenStream) -> TokenStream {
});
}

let inner_storage_calls = inner_storage_accessors.into_iter().map(|accessor|{
let inner_storage_calls = inner_storage_accessors.iter().map(|accessor|{
quote! {
.or(self.#accessor.try_get_storage())
.or(self.#accessor.try_get_inner())
}
});
});

let storage_impl = quote! {
#[allow(clippy::transmute_ptr_to_ptr)]
unsafe impl #impl_generics stylus_sdk::storage::InnerStorage for #name #ty_generics #where_clause {
unsafe fn try_get_storage<S: 'static>(&mut self) -> Option<&mut S> {
let inner_mut_storage_calls = inner_storage_accessors.iter().map(|accessor|{
quote! {
.or(self.#accessor.try_get_inner_mut())
}
});

let storage_level_impl = quote! {
unsafe impl #impl_generics stylus_sdk::storage::StorageLevel for #name #ty_generics #where_clause {

unsafe fn try_get_inner<S: stylus_sdk::storage::StorageLevel + 'static>(&self) -> Option<&S> {
use core::any::TypeId;
use stylus_sdk::storage::InnerStorage;
use stylus_sdk::storage::StorageLevel;
if TypeId::of::<S>() == TypeId::of::<Self>() {
Some(unsafe { core::mem::transmute::<_, _>(self) })
Some(unsafe { &*(self as *const Self as *const S) })
} else {
None #(#inner_storage_calls)*
}
}

unsafe fn try_get_inner_mut<S: stylus_sdk::storage::StorageLevel + 'static>(&mut self) -> Option<&mut S> {
use core::any::TypeId;
use stylus_sdk::storage::StorageLevel;
if TypeId::of::<S>() == TypeId::of::<Self>() {
Some(unsafe { &mut *(self as *mut Self as *mut S) })
} else {
None #(#inner_mut_storage_calls)*
}
}
}
};

Expand Down Expand Up @@ -189,7 +205,7 @@ pub fn solidity_storage(_attr: TokenStream, input: TokenStream) -> TokenStream {

#borrows

#storage_impl
#storage_level_impl
};
expanded.into()
}
Expand Down
5 changes: 1 addition & 4 deletions stylus-sdk/src/abi/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -41,11 +41,8 @@ pub mod internal;
/// Composition with other routers is possible via `#[inherit]`.
pub trait Router<S>
where
S: TopLevelStorage + BorrowMut<Self::Storage>,
S: TopLevelStorage
{
/// The type the [`TopLevelStorage`] borrows into. Usually just `Self`.
type Storage;

/// Tries to find and execute a method for the given selector, returning `None` if none is found.
/// Routes add via `#[inherit]` will only execute if no match is found among `Self`.
/// This means that it is possible to override a method by redefining it in `Self`.
Expand Down
4 changes: 3 additions & 1 deletion stylus-sdk/src/storage/array.rs
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
// Copyright 2023, Offchain Labs, Inc.
// For licensing, see https://github.com/OffchainLabs/stylus-sdk-rs/blob/stylus/licenses/COPYRIGHT.md

use super::{Erase, StorageGuard, StorageGuardMut, StorageType};
use super::{Erase, StorageGuard, StorageGuardMut, StorageLevel, StorageType};
use alloy_primitives::U256;
use core::marker::PhantomData;

Expand All @@ -11,6 +11,8 @@ pub struct StorageArray<S: StorageType, const N: usize> {
marker: PhantomData<S>,
}

unsafe impl<S: StorageType, const N: usize> StorageLevel for StorageArray<S, N> {}

impl<S: StorageType, const N: usize> StorageType for StorageArray<S, N> {
type Wraps<'a> = StorageGuard<'a, StorageArray<S, N>> where Self: 'a;
type WrapsMut<'a> = StorageGuardMut<'a, StorageArray<S, N>> where Self: 'a;
Expand Down
6 changes: 5 additions & 1 deletion stylus-sdk/src/storage/bytes.rs
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
// Copyright 2022-2023, Offchain Labs, Inc.
// For licensing, see https://github.com/OffchainLabs/stylus-sdk-rs/blob/stylus/licenses/COPYRIGHT.md

use super::{Erase, GlobalStorage, Storage, StorageB8, StorageGuard, StorageGuardMut, StorageType};
use super::{Erase, GlobalStorage, Storage, StorageArray, StorageB8, StorageGuard, StorageGuardMut, StorageLevel, StorageType};
use crate::crypto;
use alloc::{
string::{String, ToString},
Expand All @@ -16,6 +16,8 @@ pub struct StorageBytes {
base: OnceCell<U256>,
}

unsafe impl StorageLevel for StorageBytes {}

impl StorageType for StorageBytes {
type Wraps<'a> = StorageGuard<'a, StorageBytes> where Self: 'a;
type WrapsMut<'a> = StorageGuardMut<'a, StorageBytes> where Self: 'a;
Expand Down Expand Up @@ -264,6 +266,8 @@ impl<'a> Extend<&'a u8> for StorageBytes {
/// Accessor for storage-backed bytes
pub struct StorageString(pub StorageBytes);

unsafe impl StorageLevel for StorageString {}

impl StorageType for StorageString {
type Wraps<'a> = StorageGuard<'a, StorageString> where Self: 'a;
type WrapsMut<'a> = StorageGuardMut<'a, StorageString> where Self: 'a;
Expand Down
4 changes: 3 additions & 1 deletion stylus-sdk/src/storage/map.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@

use crate::crypto;

use super::{Erase, SimpleStorageType, StorageGuard, StorageGuardMut, StorageType};
use super::{Erase, SimpleStorageType, StorageBytes, StorageGuard, StorageGuardMut, StorageLevel, StorageType};
use alloc::{string::String, vec::Vec};
use alloy_primitives::{Address, FixedBytes, Signed, Uint, B256, U160, U256};
use core::marker::PhantomData;
Expand All @@ -14,6 +14,8 @@ pub struct StorageMap<K: StorageKey, V: StorageType> {
marker: PhantomData<(K, V)>,
}

unsafe impl<K: StorageKey, S: StorageType> StorageLevel for StorageMap<K, S> {}

impl<K, V> StorageType for StorageMap<K, V>
where
K: StorageKey,
Expand Down
18 changes: 17 additions & 1 deletion stylus-sdk/src/storage/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@ pub use bytes::{StorageBytes, StorageString};
pub use map::{StorageKey, StorageMap};
pub use traits::{
Erase, GlobalStorage, SimpleStorageType, StorageGuard, StorageGuardMut, StorageType,
TopLevelStorage, InnerStorage,
TopLevelStorage, StorageLevel,
};
pub use vec::StorageVec;

Expand Down Expand Up @@ -160,6 +160,8 @@ impl<const B: usize, const L: usize> StorageUint<B, L> {
}
}

unsafe impl<const B: usize, const L: usize> StorageLevel for StorageUint<B, L> {}

impl<const B: usize, const L: usize> StorageType for StorageUint<B, L> {
type Wraps<'a> = Uint<B, L>;
type WrapsMut<'a> = StorageGuardMut<'a, Self>;
Expand Down Expand Up @@ -235,6 +237,8 @@ impl<const B: usize, const L: usize> StorageSigned<B, L> {
}
}

unsafe impl<const B: usize, const L: usize> StorageLevel for StorageSigned<B, L> {}

impl<const B: usize, const L: usize> StorageType for StorageSigned<B, L> {
type Wraps<'a> = Signed<B, L>;
type WrapsMut<'a> = StorageGuardMut<'a, Self>;
Expand Down Expand Up @@ -306,6 +310,8 @@ impl<const N: usize> StorageFixedBytes<N> {
}
}

unsafe impl<const N: usize> StorageLevel for StorageFixedBytes<N> {}

impl<const N: usize> StorageType for StorageFixedBytes<N>
where
ByteCount<N>: SupportedFixedBytes,
Expand Down Expand Up @@ -386,6 +392,8 @@ impl StorageBool {
}
}

unsafe impl StorageLevel for StorageBool {}

impl StorageType for StorageBool {
type Wraps<'a> = bool;
type WrapsMut<'a> = StorageGuardMut<'a, Self>;
Expand Down Expand Up @@ -459,6 +467,8 @@ impl StorageAddress {
}
}

unsafe impl StorageLevel for StorageAddress {}

impl StorageType for StorageAddress {
type Wraps<'a> = Address;
type WrapsMut<'a> = StorageGuardMut<'a, Self>;
Expand Down Expand Up @@ -531,6 +541,8 @@ impl StorageBlockNumber {
}
}

unsafe impl StorageLevel for StorageBlockNumber {}

impl StorageType for StorageBlockNumber {
type Wraps<'a> = BlockNumber;
type WrapsMut<'a> = StorageGuardMut<'a, Self>;
Expand Down Expand Up @@ -603,6 +615,8 @@ impl StorageBlockHash {
}
}

unsafe impl StorageLevel for StorageBlockHash {}

impl StorageType for StorageBlockHash {
type Wraps<'a> = BlockHash;
type WrapsMut<'a> = StorageGuardMut<'a, Self>;
Expand Down Expand Up @@ -647,6 +661,8 @@ impl From<StorageBlockHash> for BlockHash {
}
}

unsafe impl<T> StorageLevel for PhantomData<T> {}

/// We implement `StorageType` for `PhantomData` so that storage types can be generic.
impl<T> StorageType for PhantomData<T> {
type Wraps<'a> = Self where Self: 'a;
Expand Down
56 changes: 49 additions & 7 deletions stylus-sdk/src/storage/traits.rs
Original file line number Diff line number Diff line change
Expand Up @@ -95,18 +95,60 @@ where
/// # Safety
///
/// The type must be top-level to prevent storage aliasing.
pub unsafe trait TopLevelStorage {
pub unsafe trait TopLevelStorage : StorageLevel {

/// Retrieve immutable reference to inner storage of type [`S`] or panic.
fn inner<S: StorageLevel + 'static>(&self) -> &S {
unsafe {
self.try_get_inner().unwrap_or_else(|| {
panic!(
"type does not exist inside TopLevelStorage - type is {}",
core::any::type_name::<S>()
)})
}
}

/// Retrieve mutable reference to inner storage of type [`S`]
fn get_storage<S: 'static>(&mut self) -> &mut S {
panic!("arbitrary storage access is not implemented")
/// Retrieve mutable reference to inner storage of type [`S`] or panic.
fn inner_mut<S: StorageLevel + 'static>(&mut self) -> &mut S {
unsafe {
self.try_get_inner_mut().unwrap_or_else(|| {
panic!(
"type does not exist inside TopLevelStorage - type is {}",
core::any::type_name::<S>()
)})
}
}
}

pub unsafe trait InnerStorage {
/// Trait for all-level storage types, usually implemented by proc macros.
///
/// # Safety
///
/// For simple types like (StorageMap, StorageBool, ..) should have default implementation.
/// Since it is pointless to querry for a type that can exists in many contracts.
pub unsafe trait StorageLevel {

/// Try etrieve mutable reference to inner storage of type [`S`]
unsafe fn try_get_storage<S: 'static>(&mut self) -> Option<&mut S>;
/// Try to retrieve immutable reference to inner laying type [`S`].
/// [`Option::None`] result usually means we should panic.
///
/// # Safety
///
/// To be able to retrieve type that contain current type (parrent) you should
/// call [`TopLevelStorage::inner`].
unsafe fn try_get_inner<S: StorageLevel + 'static>(&self) -> Option<&S>{
None
}

/// Try to retrieve mutable reference to inner laying type [`S`].
/// [`Option::None`] result usually means we should panic.
///
/// # Safety
///
/// To be able to retrieve type that contain current type (parrent) you should
/// call [`TopLevelStorage::inner_mut`].
unsafe fn try_get_inner_mut<S: StorageLevel + 'static>(&mut self) -> Option<&mut S>{
None
}
}

/// Binds a storage accessor to a lifetime to prevent aliasing.
Expand Down
Loading

0 comments on commit 921c86a

Please sign in to comment.