Skip to content

Commit

Permalink
runtime-sdk-macros: Add support for internal call handlers
Browse files Browse the repository at this point in the history
  • Loading branch information
kostko committed Sep 25, 2023
1 parent c35d047 commit 335072a
Show file tree
Hide file tree
Showing 2 changed files with 64 additions and 3 deletions.
40 changes: 37 additions & 3 deletions runtime-sdk-macros/src/module_derive/method_handler.rs
Original file line number Diff line number Diff line change
Expand Up @@ -100,7 +100,29 @@ impl super::Deriver for DeriveMethodHandler {
};

let dispatch_call_impl = {
let (handler_names, handler_idents) = filter_by_kind(handlers, HandlerKind::Call);
let (handler_names, handler_fns): (Vec<_>, Vec<_>) = handlers
.iter()
.filter_map(|h| h.handler.as_ref())
.filter(|h| h.attrs.kind == HandlerKind::Call)
.map(|h| {
(h.attrs.rpc_name.clone(), {
let ident = &h.ident;

if h.attrs.is_internal {
quote! {
|ctx, body| {
if !ctx.is_internal() {
return Err(sdk::modules::core::Error::Forbidden.into());
}
Self::#ident(ctx, body)
}
}
} else {
quote! { Self::#ident }
}
})
})
.unzip();

if handler_names.is_empty() {
quote! {}
Expand All @@ -113,7 +135,7 @@ impl super::Deriver for DeriveMethodHandler {
) -> DispatchResult<cbor::Value, CallResult> {
match method {
#(
#handler_names => module::dispatch_call(ctx, body, Self::#handler_idents),
#handler_names => module::dispatch_call(ctx, body, #handler_fns),
)*
_ => DispatchResult::Unhandled(body),
}
Expand Down Expand Up @@ -347,6 +369,8 @@ struct MethodHandlerAttr {
allow_private_km: bool,
/// Whether this handler is tagged as allowing interactive calls. Only applies to call handlers.
allow_interactive: bool,
/// Whether this handler is tagged as internal.
is_internal: bool,
}
impl syn::parse::Parse for MethodHandlerAttr {
fn parse(input: syn::parse::ParseStream<'_>) -> syn::Result<Self> {
Expand All @@ -365,6 +389,7 @@ impl syn::parse::Parse for MethodHandlerAttr {
let mut is_expensive = false;
let mut allow_private_km = false;
let mut allow_interactive = false;
let mut is_internal = false;
while input.peek(syn::token::Comma) {
let _: syn::token::Comma = input.parse()?;
let tag: syn::Ident = input.parse()?;
Expand Down Expand Up @@ -393,10 +418,18 @@ impl syn::parse::Parse for MethodHandlerAttr {
));
}
allow_interactive = true;
} else if tag == "internal" {
if kind != HandlerKind::Call {
return Err(syn::Error::new(
tag.span(),
"`internal` tag is only allowed on `call` handlers",
));
}
is_internal = true;
} else {
return Err(syn::Error::new(
tag.span(),
"invalid handler tag; supported: `expensive`, `allow_private_km`, `allow_interactive`",
"invalid handler tag; supported: `expensive`, `allow_private_km`, `allow_interactive`, `internal`",
));
}
}
Expand All @@ -410,6 +443,7 @@ impl syn::parse::Parse for MethodHandlerAttr {
is_expensive,
allow_private_km,
allow_interactive,
is_internal,
})
}
}
Expand Down
27 changes: 27 additions & 0 deletions runtime-sdk-macros/src/module_derive/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -159,6 +159,8 @@ mod tests {
fn my_call(foo2: Bar2) -> Baz2 {}
#[handler(call = "my_module.MyOtherCall")]
fn my_other_call(foo3: Bar3) -> Baz3 {}
#[handler(call = "my_module.MyInternalCall", internal)]
fn my_internal_call(foo4: Bar4) -> Baz4 {}
}
);

Expand All @@ -184,6 +186,7 @@ mod tests {
Self::prefetch_for_my_call(&mut add_prefix, body, auth_info),
),
"my_module.MyOtherCall" => module::DispatchResult::Handled(Ok(())),
"my_module.MyInternalCall" => module::DispatchResult::Handled(Ok(())),
_ => module::DispatchResult::Unhandled(body),
}
}
Expand All @@ -197,6 +200,12 @@ mod tests {
"my_module.MyOtherCall" => {
module::dispatch_call(ctx, body, Self::my_other_call)
}
"my_module.MyInternalCall" => module::dispatch_call(ctx, body, |ctx, body| {
if !ctx.is_internal() {
return Err(sdk::modules::core::Error::Forbidden.into());
}
Self::my_internal_call(ctx, body)
}),
_ => DispatchResult::Unhandled(body),
}
}
Expand All @@ -223,6 +232,10 @@ mod tests {
kind: core_types::MethodHandlerKind::Call,
name: "my_module.MyOtherCall".to_string(),
},
core_types::MethodHandlerInfo {
kind: core_types::MethodHandlerKind::Call,
name: "my_module.MyInternalCall".to_string(),
},
]
}
}
Expand All @@ -237,6 +250,8 @@ mod tests {
fn my_call(foo2: Bar2) -> Baz2 {}
#[handler(call = "my_module.MyOtherCall")]
fn my_other_call(foo3: Bar3) -> Baz3 {}
#[handler(call = "my_module.MyInternalCall", internal)]
fn my_internal_call(foo4: Bar4) -> Baz4 {}
}
};
)
Expand Down Expand Up @@ -537,6 +552,18 @@ mod tests {
super::derive_module(input);
}

#[test]
#[should_panic(expected = "only allowed on `call` handlers")]
fn generate_method_handler_malformed_internal_noncall() {
let input: syn::ItemImpl = syn::parse_quote!(
impl<C: Cfg> MyModule<C> {
#[handler(query = "foo", internal)]
fn my_method_call() -> () {}
}
);
super::derive_module(input);
}

#[test]
#[should_panic]
fn generate_method_handler_malformed_multiple_metas() {
Expand Down

0 comments on commit 335072a

Please sign in to comment.