diff --git a/runtime-sdk/src/modules/access/mod.rs b/runtime-sdk/src/modules/access/mod.rs new file mode 100644 index 0000000000..22c662ba7c --- /dev/null +++ b/runtime-sdk/src/modules/access/mod.rs @@ -0,0 +1,71 @@ +//! Method access control module. +use once_cell::unsync::Lazy; +use thiserror::Error; + +use crate::{ + context::Context, + module::{self, Module as _}, + modules, sdk_derive, + state::CurrentState, + types::transaction, +}; + +#[cfg(test)] +mod test; +pub mod types; + +/// Unique module name. +const MODULE_NAME: &str = "access"; + +/// Errors emitted by the access module. +#[derive(Error, Debug, oasis_runtime_sdk_macros::Error)] +pub enum Error { + #[error("caller is not authorized to call method")] + #[sdk_error(code = 1)] + NotAuthorized, +} + +/// Module configuration. +#[allow(clippy::declare_interior_mutable_const)] +pub trait Config: 'static { + /// To filter methods by caller address, add them to this mapping. + /// + /// If the mapping is empty, no method is filtered. + const METHOD_AUTHORIZATIONS: Lazy = Lazy::new(types::Authorization::new); +} + +/// The method access control module. +pub struct Module { + _cfg: std::marker::PhantomData, +} + +#[sdk_derive(Module)] +impl Module { + const NAME: &'static str = MODULE_NAME; + const VERSION: u32 = 1; + type Error = Error; + type Event = (); + type Parameters = (); + type Genesis = (); +} + +impl module::TransactionHandler for Module { + fn before_authorized_call_dispatch( + _ctx: &C, + call: &transaction::Call, + ) -> Result<(), modules::core::Error> { + let tx_caller_address = CurrentState::with_env(|env| env.tx_caller_address()); + #[allow(clippy::borrow_interior_mutable_const)] + if Cfg::METHOD_AUTHORIZATIONS.is_authorized(&call.method, &tx_caller_address) { + Ok(()) + } else { + Err(modules::core::Error::InvalidArgument( + Error::NotAuthorized.into(), + )) + } + } +} + +impl module::BlockHandler for Module {} + +impl module::InvariantHandler for Module {} diff --git a/runtime-sdk/src/modules/access/test.rs b/runtime-sdk/src/modules/access/test.rs new file mode 100644 index 0000000000..09e6a43622 --- /dev/null +++ b/runtime-sdk/src/modules/access/test.rs @@ -0,0 +1,272 @@ +//! Tests for the method access control module. +use std::collections::BTreeMap; + +use once_cell::unsync::Lazy; + +use crate::{ + context::Context, + crypto::signature::context as signature_context, + handler, + module::{self, Module}, + modules::{self, core}, + sdk_derive, + testing::{keys, mock}, + types::{ + token::{BaseUnits, Denomination}, + transaction, + }, + Runtime, Version, +}; + +use super::{ + types::{Authorization, MethodAuthorization}, + Error as AccessError, +}; + +struct TestConfig; + +impl core::Config for TestConfig {} + +impl modules::access::Config for TestConfig { + const METHOD_AUTHORIZATIONS: Lazy = Lazy::new(|| { + Authorization::with_filtered_methods([( + "test.FilteredMethod", + MethodAuthorization::allow_from([keys::alice::address()]), + )]) + }); +} + +/// Test runtime. +struct TestRuntime; + +impl Runtime for TestRuntime { + const VERSION: Version = Version::new(0, 0, 0); + + type Core = modules::core::Module; + type Accounts = modules::accounts::Module; + + type Modules = ( + modules::core::Module, + modules::accounts::Module, + modules::access::Module, + TestModule, + ); + + fn genesis_state() -> ::Genesis { + ( + core::Genesis { + parameters: core::Parameters { + max_batch_gas: 10_000_000, + min_gas_price: BTreeMap::from([(Denomination::NATIVE, 0)]), + ..Default::default() + }, + }, + modules::accounts::Genesis { + balances: BTreeMap::from([ + ( + keys::alice::address(), + BTreeMap::from([(Denomination::NATIVE, 1_000_000)]), + ), + ( + keys::bob::address(), + BTreeMap::from([(Denomination::NATIVE, 1_000_000)]), + ), + ]), + total_supplies: BTreeMap::from([(Denomination::NATIVE, 2_000_000)]), + ..Default::default() + }, + (), // Access module has no genesis. + (), // Test module has no genesis. + ) + } +} + +/// A module with multiple no-op methods; intended for testing routing. +struct TestModule; + +#[sdk_derive(Module)] +impl TestModule { + const NAME: &'static str = "test"; + type Error = core::Error; + type Event = (); + type Parameters = (); + type Genesis = (); + + #[handler(call = "test.FilteredMethod")] + fn filtered_method(_ctx: &C, fail: bool) -> Result { + Ok(42) + } + + #[handler(call = "test.AllowedMethod")] + fn allowed_method(ctx: &C, _args: ()) -> Result { + Ok(42) + } +} + +impl module::BlockHandler for TestModule {} +impl module::TransactionHandler for TestModule {} +impl module::InvariantHandler for TestModule {} + +fn dispatch_test( + ctx: &C, + signer: &mut mock::Signer, + meth: &str, + encrypted: bool, + should_fail: bool, +) { + let dispatch_result = signer.call_opts( + ctx, + meth, + (), + mock::CallOptions { + fee: transaction::Fee { + amount: BaseUnits::new(1_500, Denomination::NATIVE), + gas: 1_500, + ..Default::default() + }, + encrypted, + ..Default::default() + }, + ); + if should_fail { + let err = core::Error::InvalidArgument(AccessError::NotAuthorized.into()); + assert!( + matches!( + dispatch_result.result, + module::CallResult::Failed { module: _, code: _, message: m } if m == format!("{}", err), + ), + "method call should be blocked", + ); + } else { + assert!( + dispatch_result.result.is_success(), + "method call should succeed but failed with: {:?}", + dispatch_result.result, + ); + let unmarshalled: u64 = + cbor::from_value(dispatch_result.result.unwrap()).expect("result should be decodable"); + assert_eq!(unmarshalled, 42); + } +} + +#[test] +fn test_access_module() { + let _guard = signature_context::test_using_chain_context(); + signature_context::set_chain_context(Default::default(), "test"); + let mut mock = mock::Mock::default(); + let ctx = mock.create_ctx_for_runtime::(true); + + let mut alice = mock::Signer::new(0, keys::alice::sigspec()); + let mut bob = mock::Signer::new(0, keys::bob::sigspec()); + + TestRuntime::migrate(&ctx); + + let filtered = "test.FilteredMethod"; + let allowed = "test.AllowedMethod"; + + // Test plain calls. + + dispatch_test(&ctx, &mut alice, filtered, false, false); + dispatch_test(&ctx, &mut alice, allowed, false, false); + + dispatch_test(&ctx, &mut bob, filtered, false, true); + dispatch_test(&ctx, &mut bob, allowed, false, false); + + // Test encrypted calls. + + dispatch_test(&ctx, &mut alice, filtered, true, false); + dispatch_test(&ctx, &mut alice, allowed, true, false); + + dispatch_test(&ctx, &mut bob, filtered, true, true); + dispatch_test(&ctx, &mut bob, allowed, true, false); +} + +#[test] +fn test_method_authorization() { + let alice = keys::alice::address(); + let bob = keys::bob::address(); + let charlie = keys::charlie::address(); + + // An empty authorizer shouldn't let anybody through. + let empty = MethodAuthorization::allow_from([]); + assert_eq!(empty.is_authorized(&alice), false); + assert_eq!(empty.is_authorized(&bob), false); + assert_eq!(empty.is_authorized(&charlie), false); + + // An authorizer with some addresses should only let those through. + let for_alice = MethodAuthorization::allow_from([alice]); + assert_eq!(for_alice.is_authorized(&alice), true); + assert_eq!(for_alice.is_authorized(&bob), false); + assert_eq!(for_alice.is_authorized(&charlie), false); + + let for_bob = MethodAuthorization::allow_from([bob]); + assert_eq!(for_bob.is_authorized(&alice), false); + assert_eq!(for_bob.is_authorized(&bob), true); + assert_eq!(for_bob.is_authorized(&charlie), false); +} + +#[test] +fn test_authorization() { + let alice = keys::alice::address(); + let bob = keys::bob::address(); + let charlie = keys::charlie::address(); + let dave = keys::dave::address(); + + let authorization = Authorization::with_filtered_methods([ + ("test.Nobody", MethodAuthorization::allow_from([])), + ("test.Alice", MethodAuthorization::allow_from([alice])), + ("test.Bob", MethodAuthorization::allow_from([bob])), + ("test.Both", MethodAuthorization::allow_from([alice, bob])), + ( + "test.AliceAndCharlie", + MethodAuthorization::allow_from([alice, charlie]), + ), + ]); + + // Alice should be able to access some filtered methods and all unfiltered ones. + assert_eq!(authorization.is_authorized("test.Nobody", &alice), false); + assert_eq!(authorization.is_authorized("test.Alice", &alice), true); + assert_eq!(authorization.is_authorized("test.Bob", &alice), false); + assert_eq!(authorization.is_authorized("test.Both", &alice), true); + assert_eq!( + authorization.is_authorized("test.AliceAndCharlie", &alice), + true + ); + assert_eq!(authorization.is_authorized("test.Everybody", &alice), true); + + // Bob should be able to access some filtered methods and all unfiltered ones. + assert_eq!(authorization.is_authorized("test.Nobody", &bob), false); + assert_eq!(authorization.is_authorized("test.Alice", &bob), false); + assert_eq!(authorization.is_authorized("test.Bob", &bob), true); + assert_eq!(authorization.is_authorized("test.Both", &bob), true); + assert_eq!( + authorization.is_authorized("test.AliceAndCharlie", &bob), + false + ); + assert_eq!(authorization.is_authorized("test.Everybody", &bob), true); + + // Charlie should be able to access some filtered methods and all unfiltered ones. + assert_eq!(authorization.is_authorized("test.Nobody", &charlie), false); + assert_eq!(authorization.is_authorized("test.Alice", &charlie), false); + assert_eq!(authorization.is_authorized("test.Bob", &charlie), false); + assert_eq!(authorization.is_authorized("test.Both", &charlie), false); + assert_eq!( + authorization.is_authorized("test.AliceAndCharlie", &charlie), + true + ); + assert_eq!( + authorization.is_authorized("test.Everybody", &charlie), + true + ); + + // Dave is left out of everything, so should only be able to access unfiltered methods. + assert_eq!(authorization.is_authorized("test.Nobody", &dave), false); + assert_eq!(authorization.is_authorized("test.Alice", &dave), false); + assert_eq!(authorization.is_authorized("test.Bob", &dave), false); + assert_eq!(authorization.is_authorized("test.Both", &dave), false); + assert_eq!( + authorization.is_authorized("test.AliceAndCharlie", &dave), + false + ); + assert_eq!(authorization.is_authorized("test.Everybody", &dave), true); +} diff --git a/runtime-sdk/src/modules/access/types.rs b/runtime-sdk/src/modules/access/types.rs new file mode 100644 index 0000000000..8794c66f78 --- /dev/null +++ b/runtime-sdk/src/modules/access/types.rs @@ -0,0 +1,72 @@ +//! Method access control module types. +use std::collections::{BTreeMap, BTreeSet}; + +use crate::types::address::Address; + +/// A set of addresses that can be used to define access control for a particular method. +pub type Addresses = BTreeSet
; + +/// A specific kind of method authorization. +pub enum MethodAuthorization { + /// Only allow method calls from these addresses; + /// for other callers, the method call will fail. + AllowFrom(Addresses), +} + +impl MethodAuthorization { + /// Helper for creating a method authorization type that + /// only allows callers with the given addresses. + pub fn allow_from>(it: I) -> Self { + Self::AllowFrom(BTreeSet::from_iter(it)) + } + + pub(super) fn is_authorized(&self, address: &Address) -> bool { + match self { + Self::AllowFrom(addrs) => addrs.contains(address), + } + } +} + +/// A set of methods that are subject to access control. +pub type Methods = BTreeMap; + +/// A specific kind of access control. +pub enum Authorization { + /// Control a statically configured set of methods, each with a + /// statically configured set of addresses that are allowed to call it. + FilterOnly(Methods), +} + +impl Authorization { + /// Return a new access control configuration. + pub fn new() -> Self { + Self::FilterOnly(BTreeMap::new()) + } + + /// Helper for creating a static access control configuration. + pub fn with_filtered_methods(it: I) -> Self + where + S: AsRef, + I: IntoIterator, + { + Self::FilterOnly(BTreeMap::from_iter( + it.into_iter() + .map(|(name, authz)| (name.as_ref().to_string(), authz)), + )) + } + + pub(super) fn is_authorized(&self, method: &str, address: &Address) -> bool { + match self { + Self::FilterOnly(meths) => meths + .get(method) + .map(|authz| authz.is_authorized(address)) + .unwrap_or(true), + } + } +} + +impl Default for Authorization { + fn default() -> Self { + Self::FilterOnly(BTreeMap::default()) + } +} diff --git a/runtime-sdk/src/modules/mod.rs b/runtime-sdk/src/modules/mod.rs index 3624df4770..cb9d5346e4 100644 --- a/runtime-sdk/src/modules/mod.rs +++ b/runtime-sdk/src/modules/mod.rs @@ -1,5 +1,6 @@ //! Runtime modules included with the SDK. +pub mod access; pub mod accounts; pub mod consensus; pub mod consensus_accounts;