diff --git a/macros/src/attr/fn.rs b/macros/src/attr/fn.rs new file mode 100644 index 00000000..84000972 --- /dev/null +++ b/macros/src/attr/fn.rs @@ -0,0 +1,49 @@ +use syn::{parse_quote, Ident, Path, Result}; + +use super::{parse_assign_from_str, parse_assign_inflection, parse_assign_str, Inflection}; + +#[derive(Default)] +pub struct FnAttr { + crate_rename: Option, + pub args: Args, + pub export_to: Option, + pub rename: Option, + pub rename_all: Option, +} + +impl FnAttr { + pub fn crate_rename(&self) -> Path { + self.crate_rename + .clone() + .unwrap_or_else(|| parse_quote!(::ts_rs)) + } +} + +#[derive(Default)] +pub enum Args { + #[default] + Flattened, + Inlined, +} + +impl TryFrom for Args { + type Error = syn::Error; + + fn try_from(s: String) -> Result { + match s.as_str() { + "inlined" => Ok(Self::Inlined), + "flattened" => Ok(Self::Flattened), + x => syn_err!(r#"Expected "inlined" or "flattened", found "{x}""#), + } + } +} + +impl_parse! { + FnAttr(input, output) { + "crate" => output.crate_rename = Some(parse_assign_from_str(input)?), + "args" => output.args = parse_assign_str(input)?.try_into()?, + "export_to" => output.export_to = Some(parse_assign_str(input)?), + "rename" => output.rename = Some(parse_assign_str(input)?), + "rename_all" => output.rename_all = Some(parse_assign_inflection(input)?), + } +} diff --git a/macros/src/attr/mod.rs b/macros/src/attr/mod.rs index 196166f3..7e2bc38e 100644 --- a/macros/src/attr/mod.rs +++ b/macros/src/attr/mod.rs @@ -2,6 +2,7 @@ use std::collections::HashMap; pub use field::*; pub use r#enum::*; +pub use r#fn::*; pub use r#struct::*; use syn::{ parse::{Parse, ParseStream}, @@ -12,6 +13,7 @@ pub use variant::*; mod r#enum; mod field; +mod r#fn; mod r#struct; mod variant; @@ -98,6 +100,19 @@ impl Inflection { Inflection::ScreamingKebab => Self::Kebab.apply(string).to_ascii_uppercase(), } } + + pub fn as_str(&self) -> &str { + match self { + Self::Lower => "lowercase", + Self::Upper => "UPPERCASE", + Self::Kebab => "kebab-case", + Self::Camel => "camelCase", + Self::Snake => "snake_case", + Self::Pascal => "PascalCase", + Self::ScreamingSnake => "SCREAMING_SNAKE_CASE", + Self::ScreamingKebab => "SCREAMING-KEBAB-CASE", + } + } } fn parse_assign_str(input: ParseStream) -> Result { diff --git a/macros/src/lib.rs b/macros/src/lib.rs index 8ec15668..7192b333 100644 --- a/macros/src/lib.rs +++ b/macros/src/lib.rs @@ -3,13 +3,15 @@ use std::collections::{HashMap, HashSet}; +use attr::{FnAttr, Inflection}; use proc_macro2::{Ident, TokenStream}; use quote::{format_ident, quote}; use syn::{ - parse_quote, spanned::Spanned, ConstParam, GenericParam, Generics, Item, LifetimeParam, Path, - Result, Type, TypeArray, TypeParam, TypeParen, TypePath, TypeReference, TypeSlice, TypeTuple, - WhereClause, WherePredicate, + parse_quote, spanned::Spanned, ConstParam, GenericParam, Generics, Item, ItemFn, LifetimeParam, + Path, Result, Type, TypeArray, TypeParam, TypeParen, TypePath, TypeReference, TypeSlice, + TypeTuple, WhereClause, WherePredicate, }; +use types::ParsedFn; use crate::{deps::Dependencies, utils::format_generics}; @@ -443,3 +445,40 @@ fn entry(input: proc_macro::TokenStream) -> Result { Ok(ts.into_impl(ident, generics)) } + +#[proc_macro_attribute] +pub fn ts_rs_fn( + attr: proc_macro::TokenStream, + input: proc_macro::TokenStream, +) -> proc_macro::TokenStream { + entry_fn(attr.into(), input.into()).map_or_else(|e| e.into_compile_error().into(), Into::into) +} + +fn entry_fn(attr: TokenStream, input: TokenStream) -> Result { + let input = syn::parse2::(input)?; + let attr = if !attr.is_empty() { + syn::parse2::(attr)? + } else { + FnAttr::default() + }; + + let ident = format_ident!( + "{}Fn", + Inflection::Pascal.apply(&input.sig.ident.to_string()) + ); + + let ParsedFn { + args_struct, + derived_fn, + } = types::fn_def(&input, attr)?; + + let struct_impl = derived_fn.into_impl(ident.clone(), input.sig.generics.clone()); + Ok(quote!( + #input + + struct #ident; + #struct_impl + + #args_struct + )) +} diff --git a/macros/src/types/fn.rs b/macros/src/types/fn.rs new file mode 100644 index 00000000..7094b98a --- /dev/null +++ b/macros/src/types/fn.rs @@ -0,0 +1,120 @@ +use std::{collections::HashMap, ops::Not}; + +use proc_macro2::TokenStream; +use quote::{format_ident, quote, ToTokens}; +use syn::{ + parse_quote, punctuated::Punctuated, spanned::Spanned, token::Comma, Error, Field, FnArg, + ItemFn, PatType, Result, TypeReference, +}; + +use crate::{ + attr::{Args, FnAttr, Inflection}, + deps::Dependencies, + utils::{parse_docs, to_ts_ident}, + DerivedTS, +}; + +pub struct ParsedFn { + pub args_struct: Option, + pub derived_fn: DerivedTS, +} + +pub fn fn_def(input: &ItemFn, fn_attr: FnAttr) -> Result { + let mut dependencies = Dependencies::new(fn_attr.crate_rename()); + + let ident = &input.sig.ident; + let generics = &input.sig.generics; + let (_, ty_generics, where_clause) = generics.split_for_impl(); + + let struct_ident = format_ident!("{}Args", Inflection::Pascal.apply(&ident.to_string())); + let fields = input + .sig + .inputs + .iter() + .map(|x| match x { + FnArg::Receiver(_) => Err(Error::new(x.span(), "self parameter is not allowed")), + FnArg::Typed(PatType { ty, attrs, pat, .. }) => { + dependencies.push(ty); + Ok(Field { + attrs: attrs.to_vec(), + vis: syn::Visibility::Inherited, + mutability: syn::FieldMutability::None, + ident: Some(syn::parse2(pat.to_token_stream())?), + colon_token: None, + ty: match ty.as_ref() { + syn::Type::Reference(TypeReference { elem, .. }) => { + parse_quote!(Box<#elem>) + } + x => x.clone(), + }, + }) + } + }) + .collect::>>()?; + + let crate_rename = fn_attr.crate_rename(); + let FnAttr { + rename_all, + rename, + args, + export_to, + .. + } = fn_attr; + let struct_attr = rename_all.map(|rename_all| { + let rename_all = rename_all.as_str(); + Some(quote!(#[ts(rename_all = #rename_all)])) + }); + + let args_struct = fields.is_empty().not().then_some(quote!( + #[derive(#crate_rename::TS)] + #struct_attr + struct #struct_ident #ty_generics #where_clause { + #fields + } + )); + + let docs = parse_docs(&input.attrs)?; + let ts_name = Inflection::Pascal.apply(&rename.clone().unwrap_or_else(|| to_ts_ident(ident))); + let is_async = input.sig.asyncness.is_some(); + let return_ty = match (is_async, input.sig.output.clone()) { + (false, syn::ReturnType::Default) => quote!("void"), + (true, syn::ReturnType::Default) => quote!("Promise"), + (false, syn::ReturnType::Type(_, ref ty)) => { + dependencies.push(ty); + quote!(<#ty as #crate_rename::TS>::name()) + } + (true, syn::ReturnType::Type(_, ref ty)) => { + dependencies.push(ty); + quote!(format!("Promise<{}>", <#ty as #crate_rename::TS>::name())) + } + }; + + let inline = match (&args_struct, args) { + (Some(_), Args::Inlined) => quote!(format!( + "(args: {}) => {}", + <#struct_ident as #crate_rename::TS>::inline(), + #return_ty, + )), + (Some(_), Args::Flattened) => quote!(format!("({}) => {}", + <#struct_ident as #crate_rename::TS>::inline_flattened().trim_matches(['{', '}', ' ']), + #return_ty + )), + (None, _) => quote!(format!("() => {}", #return_ty)), + }; + + Ok(ParsedFn { + args_struct, + derived_fn: DerivedTS { + crate_rename, + ts_name, + docs, + inline, + inline_flattened: None, + dependencies, + export: true, + export_to, + concrete: HashMap::default(), + bound: None, + }, + }) +} diff --git a/macros/src/types/mod.rs b/macros/src/types/mod.rs index caebefc0..81243bcc 100644 --- a/macros/src/types/mod.rs +++ b/macros/src/types/mod.rs @@ -7,6 +7,7 @@ use crate::{ }; mod r#enum; +mod r#fn; mod named; mod newtype; mod tuple; @@ -15,6 +16,7 @@ mod type_override; mod unit; pub(crate) use r#enum::r#enum_def; +pub(crate) use r#fn::*; pub(crate) fn struct_def(s: &ItemStruct) -> Result { let attr = StructAttr::from_attrs(&s.attrs)?; diff --git a/ts-rs/src/lib.rs b/ts-rs/src/lib.rs index 0954e697..b46326c1 100644 --- a/ts-rs/src/lib.rs +++ b/ts-rs/src/lib.rs @@ -128,7 +128,7 @@ use std::{ path::{Path, PathBuf}, }; -pub use ts_rs_macros::TS; +pub use ts_rs_macros::{ts_rs_fn, TS}; pub use crate::export::ExportError; use crate::typelist::TypeList; diff --git a/ts-rs/tests/fn.rs b/ts-rs/tests/fn.rs new file mode 100644 index 00000000..1c37fe4e --- /dev/null +++ b/ts-rs/tests/fn.rs @@ -0,0 +1,106 @@ +#![allow(dead_code, unused, clippy::disallowed_names)] +use ts_rs::{ts_rs_fn, TS}; + +#[ts_rs_fn(args = "inlined", export_to = "tests-out/fn/")] +fn my_void_function() {} + +#[ts_rs_fn(args = "inlined", export_to = "tests-out/fn/")] +fn my_non_void_function() -> String { + String::from("Hello world") +} + +#[ts_rs_fn(args = "inlined", export_to = "tests-out/fn/")] +fn my_void_function_with_inlined_args(str_arg: &str, int_arg: u32) {} + +#[ts_rs_fn(args = "flattened", export_to = "tests-out/fn/")] +fn my_void_function_with_flattened_args(str_arg: &str, int_arg: u32) {} + +#[ts_rs_fn(args = "inlined", export_to = "tests-out/fn/")] +fn my_non_void_function_with_inlined_args(str_arg: &str, int_arg: u32) -> String { + String::from("Hello world") +} + +#[ts_rs_fn(args = "flattened", export_to = "tests-out/fn/")] +fn my_non_void_function_with_flattened_args(str_arg: &str, int_arg: u32) -> String { + String::from("Hello world") +} + +#[derive(TS)] +#[ts(export, export_to = "tests-out/fn/")] +struct Foo { + foo: u32, +} + +#[ts_rs_fn(export_to = "tests-out/fn/")] +fn function_with_imported_return() -> Foo { + Foo { foo: 0 } +} + +#[ts_rs_fn(export_to = "tests-out/fn/")] +fn function_with_imported_flattened_args(foo: Foo) {} + +#[ts_rs_fn(args = "inlined", export_to = "tests-out/fn/")] +fn function_with_imported_inlined_args(foo: Foo) {} + +#[test] +fn void_fn() { + assert_eq!(MyVoidFunctionFn::inline(), "() => void") +} + +#[test] +fn non_void_fn() { + assert_eq!(MyNonVoidFunctionFn::inline(), "() => string") +} + +#[test] +fn void_fn_inlined_args() { + assert_eq!( + MyVoidFunctionWithInlinedArgsFn::inline(), + "(args: { str_arg: string, int_arg: number, }) => void" + ) +} + +#[test] +fn void_fn_flattened_args() { + assert_eq!( + MyVoidFunctionWithFlattenedArgsFn::inline(), + "(str_arg: string, int_arg: number,) => void" + ) +} + +#[test] +fn non_void_fn_inlined_args() { + assert_eq!( + MyNonVoidFunctionWithInlinedArgsFn::inline(), + "(args: { str_arg: string, int_arg: number, }) => string" + ) +} + +#[test] +fn non_void_fn_flattened_args() { + assert_eq!( + MyNonVoidFunctionWithFlattenedArgsFn::inline(), + "(str_arg: string, int_arg: number,) => string" + ) +} + +#[test] +fn fn_with_imported_return() { + assert_eq!(FunctionWithImportedReturnFn::inline(), "() => Foo") +} + +#[test] +fn fn_with_imported_inlined_args() { + assert_eq!( + FunctionWithImportedInlinedArgsFn::inline(), + "(args: { foo: Foo, }) => void" + ) +} + +#[test] +fn fn_with_imported_flattened_args() { + assert_eq!( + FunctionWithImportedFlattenedArgsFn::inline(), + "(foo: Foo,) => void" + ) +}