diff --git a/dfdx-derives/Cargo.toml b/dfdx-derives/Cargo.toml index 9941edbd..c142e70f 100644 --- a/dfdx-derives/Cargo.toml +++ b/dfdx-derives/Cargo.toml @@ -13,6 +13,7 @@ proc-macro2 = "1" quote = "1" syn = { version = "2", features = ["extra-traits"] } dfdx-core = { path = "../dfdx-core" } +heck = "0.4.1" [features] nightly = ["dfdx-core/nightly"] diff --git a/dfdx-derives/src/lib.rs b/dfdx-derives/src/lib.rs index 4eca0d82..6d0edd11 100644 --- a/dfdx-derives/src/lib.rs +++ b/dfdx-derives/src/lib.rs @@ -957,3 +957,457 @@ pub fn load_safetensors(input: proc_macro::TokenStream) -> proc_macro::TokenStre } }) } + +/// Generates a module containing helpful structs and implementations for a input wrapper. +/// +/// ## Example +/// +/// The following definition: +/// ```ignore +/// #[input_wrapper] +/// pub struct MyWrapper { +/// pub a: A, +/// pub b: B, +/// } +/// ``` +/// +/// Generates the following module: +/// ```ignore +/// pub mod my_wrapper { +/// // structs for the fields +/// pub struct a; +/// pub struct b; +/// // note: if MyWrapper was a tuple-struct, +/// // the fields would be named _0, _1 and so on +/// +/// // structs to help in tuple conversions (Module impls omitted) +/// pub struct FromTuple; +/// pub struct IntoTuple; +/// +/// // access for the `a` field +/// impl, A, B> Module> for On +/// { +/// type Output = MyWrapper<>::Output, B>; +/// fn try_forward(&self, x: MyWrapper) -> Result {/* (...) */} +/// fn try_forward_mut(&mut self, x: MyWrapper) -> Result {/* (...) */} +/// } +/// +/// // access for the `b` field +/// impl, A, B> Module> for On +/// { +/// type Output = MyWrapper>::Output>; +/// fn try_forward(&self, x: MyWrapper) -> Result {/* ... */} +/// fn try_forward_mut(&mut self, x: MyWrapper) -> Result {/* ... */} +/// } +/// } +/// ``` +/// To better visualize the generated code and items, it's recommended to expand it with Rust-Analyzer, +/// or to generate the project's documentation. +/// +/// Those helpers can then be used as modules: +/// ```ignore +/// #[derive(Default, Clone, Sequential)] +/// pub struct Arch { +/// // (...) +/// +/// // assuming Input is of type (X, Y), converts the input into MyWrapper +/// pub input_to_wrapper: my_wrapper::FromTuple, +/// +/// // apply module T on the field `a`, while also mapping the input into: +/// // MyWrapper<>::Output, Y> +/// pub t: On, +/// +/// // converts the input into a tuple: +/// // (>::Output, Y) +/// pub input_to_tuple: split1::IntoTuple, +/// +/// // (...) +/// } +/// ``` +#[proc_macro_attribute] +pub fn input_wrapper( + _attr: proc_macro::TokenStream, + input: proc_macro::TokenStream, +) -> proc_macro::TokenStream { + let wrapper = parse_macro_input!(input as DeriveInput); + + // - TODO: any bounds on the struct definition probably should be copied into other impls. + // - NOTE: check on how to deal with the `Self` (as this won't refer to the struct on the On Module impl). + + // eg. MyWrapper + let wrapper_ident = wrapper.ident.clone(); + let wrapper_vis = wrapper.vis.clone(); + let wrapper_lowercase = format!("{}", heck::AsSnakeCase(wrapper_ident.to_string())); + // eg. my_wrapper + // TODO: allow renaming + let wrapper_lowercase_ident = syn::Ident::new(&wrapper_lowercase, wrapper_ident.span()); + + // get wrapper field info + // eg. [(pub, Some(my_field), MyFieldType, field span)] + let mut wrapper_fields = vec![]; + match &wrapper.data { + Data::Struct(ref obj) => match obj.fields { + Fields::Named(ref fields) => { + let fields = fields.named.iter().map(|f| { + let ty = &f.ty; + assert_ne!( + quote!(#ty).to_string(), + "M", + "A generic type named `M` is not allowed because this is used internally" + ); + (&f.vis, &f.ident, &f.ty, f.span()) + }); + wrapper_fields.extend(fields) + } + Fields::Unnamed(ref fields) => { + let fields = fields.unnamed.iter().map(|f: &syn::Field| { + let ty = &f.ty; + assert_ne!( + quote!(#ty).to_string(), + "M", + "A generic type named `M` is not allowed because this is used internally" + ); + (&f.vis, &None, &f.ty, f.span()) + }); + wrapper_fields.extend(fields) + } + // no fields + Fields::Unit => {} + }, + Data::Enum(_) => unimplemented!("Input wrapper cannot be derived for enums."), + Data::Union(_) => unimplemented!("Input wrapper cannot be derived for unions."), + }; + + // wrapper fields as structs + let mut wrapper_field_structs_quote = vec![]; + let mut are_fields_named = false; + for (i, (_vis, field, _ty, span)) in wrapper_fields.iter().enumerate() { + let (doc, field) = if let Some(field) = field { + are_fields_named = true; + let doc = format!( + "Indicates the [`{}::{}`] field. \nThis field is the `{}` value (`0`-based index).", + wrapper_ident, + field, + i + ); + (doc, field.clone()) + } else { + let doc = format!( + "Indicates the `{}`-th value from [`{}`] (0-based index).", + i, wrapper_ident + ); + let field = syn::Ident::new(&format!("_{}", i), *span); + (doc, field) + }; + wrapper_field_structs_quote.push(quote! { + #[doc = #doc] + #[allow(non_camel_case_types)] + #[derive(Debug, Default, Clone, PartialEq, Eq, PartialOrd, Ord, Hash)] + pub struct #field; + }); + } + + let imports = if are_fields_named { + quote! { + use super::*; + } + } else { + quote! { + use super::*; + // TODO: import tuple stuff + use crate::prelude; + } + }; + + let wrapper_generics = wrapper.generics.clone(); + let wrapper_generics_params = wrapper_generics.params.iter().collect::>(); + // eg. MyWrapper -> [A, B] + let wrapper_generics_param_idents = { + wrapper_generics_params + .iter() + .map(|p| { + use syn::GenericParam::*; + match p { + Lifetime(l) => &l.lifetime.ident, + Type(t) => &t.ident, + Const(c) => &c.ident, + } + }) + .collect::>() + }; + + // eg. MyWrapper -> [A, B] + let wrapper_generic_names = wrapper_generics_param_idents.iter().collect::>(); + + // eg. MyWrapper { field1: A, field2: bool} -> [A, bool] + let field_ty_names = wrapper_fields + .iter() + .map(|(_, _, ty, _)| ty) + .collect::>(); + + // create structs to represent tuple conversions + let tuple_conversion_structs = { + let field_ty_names = field_ty_names + .iter() + .map(|ty| quote! {#ty}.to_string()) + .collect::>() + .join(", "); + let wrapper_generic_names = wrapper_generic_names + .iter() + .map(|ident| ident.to_string()) + .collect::>() + .join(", "); + let doc1 = format!( + "Indicates a conversion from a ({}) tuple into a `{}<{}>`.", + &field_ty_names, wrapper_ident, &wrapper_generic_names + ); + let doc2 = format!( + "Indicates a conversion from a `{}<{}>` into a ({}) tuple.", + wrapper_ident, &wrapper_generic_names, &field_ty_names + ); + quote! { + #[doc = #doc1] + #[derive(Debug, Default, Clone, PartialEq, Eq, PartialOrd, Ord, Hash, ::dfdx::CustomModule)] + pub struct FromTuple; + #[doc = #doc2] + #[derive(Debug, Default, Clone, PartialEq, Eq, PartialOrd, Ord, Hash, ::dfdx::prelude::CustomModule)] + pub struct IntoTuple; + } + }; + + // impl From<> conversions + let tuple_conversions = { + let doc1 = format!("Conversion of a tuple into a [`{}`].", wrapper_ident,); + let doc2 = format!("Conversion of a [`{}`] into a tuple.", wrapper_ident,); + + let mut field_from_tuple = vec![]; + let mut field_to_tuple = vec![]; + for (i, (_, ident, _, _span)) in wrapper_fields.iter().enumerate() { + let i = syn::Index::from(i); + if let Some(ident) = ident { + field_from_tuple.push(quote! {#ident: __x.#i}); + field_to_tuple.push(quote! {__x.#ident}); + } else { + field_from_tuple.push(quote! {__x.#i}); + field_to_tuple.push(quote! {__x.#i}); + }; + } + + let (from_tuple, to_tuple) = if are_fields_named { + ( + quote! { + #wrapper_ident { + #(#field_from_tuple), * + } + }, + quote! { (#(#field_to_tuple), *) }, + ) + } else { + ( + quote! { + #wrapper_ident ( + #(#field_from_tuple), * + ) + }, + quote! {(#(#field_to_tuple), *)}, + ) + }; + + quote! { + #[doc = #doc1] + impl<#(#wrapper_generic_names), *> From<(#(#field_ty_names), *)> for #wrapper_ident<#(#wrapper_generic_names), *> { + fn from(__x: (#(#field_ty_names), *)) -> Self { + #from_tuple + } + } + #[doc = #doc2] + impl<#(#wrapper_generic_names), *> From<#wrapper_ident<#(#wrapper_generic_names), *>> for (#(#field_ty_names), *) { + fn from(__x: #wrapper_ident<#(#wrapper_generic_names), *>) -> Self { + #to_tuple + } + } + } + }; + + // impl Module for conversions into and from tuples + let module_conversions = { + let doc1 = format!("Module to convert a tuple into a [`{}`].", wrapper_ident,); + let doc2 = format!("Module to convert a [`{}`] into a tuple.", wrapper_ident,); + quote! { + #[doc = #doc1] + impl<#(#wrapper_generic_names), *> ::dfdx::prelude::Module<(#(#field_ty_names), *)> for FromTuple { + type Output = #wrapper_ident<#(#wrapper_generic_names), *>; + fn try_forward(&self, __x: (#(#field_ty_names), *)) -> Result { + Ok(__x.into()) + } + } + #[doc = #doc2] + impl<#(#wrapper_generic_names), *> ::dfdx::prelude::Module<#wrapper_ident<#(#wrapper_generic_names), *>> for IntoTuple { + type Output = (#(#field_ty_names), *); + fn try_forward(&self, __x: #wrapper_ident<#(#wrapper_generic_names), *>) -> Result { + Ok(__x.into()) + } + } + } + }; + + // assertion + for generic_ident in wrapper_generic_names.iter() { + let count = wrapper_fields + .iter() + .map(|(_, _, field_ty, _)| is_ident_container(generic_ident, field_ty)) + .filter(|contains| *contains) + .count(); + if count > 1 { + panic!("the generic {generic_ident} should be used in at most one field"); + } + } + + // field access modules + let mut field_access_modules = vec![]; + for (i, (_vis, ident, ty, span)) in wrapper_fields.iter().enumerate() { + let (doc, on_acccess, forward) = if let Some(ident) = ident { + let doc = format!( + "Module that access [`{}::{}`] and then applies Module `M` on it.", + wrapper_ident, ident + ); + let on_access = ident.clone(); + let forward = syn::Ident::new(&format!("__x{i}"), ident.span()); + (doc, on_access, forward) + } else { + let doc = format!( + "Module that access the `{}`-th value from [`{}`] and then applies Module `M` on it.", + i, + wrapper_ident, + ); + let on_access = syn::Ident::new(&format!("_{}", i), *span); + let forward = syn::Ident::new(&format!("__x{i}"), *span); + (doc, on_access, forward) + }; + + let mut contains_ident = false; + let output_generics = wrapper_generic_names.iter().map(|ty_ident| { + // + if is_ident_container(ty_ident, ty) { + if contains_ident { + panic!( + "the field {ident:?} at index {i} should contain at most one generic type" + ); + } + contains_ident = true; + quote!(>::Output) + } else { + quote!(#ty_ident) + } + }); + + let mut field_extraction_idents = vec![]; + let mut field_extraction = vec![]; + let mut field_construction = vec![]; + for (i, (_, _ident, _, span)) in wrapper_fields.iter().enumerate() { + let ii = syn::Index::from(i); + if let Some(_ident) = _ident { + let xident = syn::Ident::new(&format!("__x{i}"), _ident.span()); + field_extraction_idents.push(xident.clone()); + field_extraction.push(quote! {let #xident = __x.#_ident;}); + field_construction.push(quote! {#_ident: #xident,}); + } else { + let xident = syn::Ident::new(&format!("__x{i}"), *span); + field_extraction_idents.push(xident.clone()); + field_extraction.push(quote! {let #xident = __x.#ii;}); + field_construction.push(quote! {#xident,}); + }; + } + let field_replacement = if are_fields_named { + quote! { + #wrapper_ident { + #(#field_construction)* + } + } + } else { + quote! { + #wrapper_ident ( + #(#field_construction)* + ) + } + }; + + let field_access_module = quote! { + #[doc = #doc] + impl, #(#wrapper_generic_names), *> ::dfdx::prelude::Module<#wrapper_ident<#(#wrapper_generic_names), *>> for ::dfdx::prelude::On<#on_acccess, M> { + type Output = #wrapper_ident<#(#output_generics), *>; + fn try_forward(&self, __x: #wrapper_ident<#(#wrapper_generic_names), *>) -> Result { + #(#field_extraction)* + let #forward = self.t.try_forward(#forward)?; + let __x = #field_replacement; + Ok(__x) + } + fn try_forward_mut(&mut self, __x: #wrapper_ident<#(#wrapper_generic_names), *>) -> Result { + #(#field_extraction)* + let #forward = self.t.try_forward_mut(#forward)?; + let __x = #field_replacement; + Ok(__x) + } + } + }; + field_access_modules.push(field_access_module); + } + + // all of the generated content + let _mod = quote! { + #wrapper_vis mod #wrapper_lowercase_ident { + #imports + + #(#wrapper_field_structs_quote)* + + #tuple_conversion_structs + + #tuple_conversions + + #module_conversions + + #(#field_access_modules)* + } + }; + let output = quote!( + #wrapper + + #[doc = "Automatically generated by `input_wrapper`. The containing items are visible on your project's documentation."] + #_mod + ); + proc_macro::TokenStream::from(output) +} + +/// Checks whether `ty` contains any ident that matches the `ident`. +fn is_ident_container(ident: &syn::Ident, ty: &syn::Type) -> bool { + match ty { + syn::Type::Array(_) => todo!("input_wrapper is_ident_container for array"), + syn::Type::BareFn(_) => todo!("input_wrapper is_ident_container for bare fn"), + syn::Type::Group(_) => todo!("input_wrapper is_ident_container for group"), + syn::Type::ImplTrait(_) => todo!("input_wrapper is_ident_container for impl trait"), + syn::Type::Infer(_) => todo!("input_wrapper is_ident_container for infer"), + syn::Type::Macro(_) => todo!("input_wrapper is_ident_container for macro"), + syn::Type::Never(_) => todo!("input_wrapper is_ident_container for never"), + syn::Type::Paren(_) => todo!("input_wrapper is_ident_container for paren"), + syn::Type::Path(ty) => { + let mut is = false; + if let Some(qself) = &ty.qself { + is |= is_ident_container(ident, &qself.ty); + } + if let Some(segment) = &ty.path.segments.last() { + is |= &segment.ident == ident; + } + is + } + syn::Type::Ptr(_) => todo!("input_wrapper is_ident_container for ptr"), + syn::Type::Reference(_) => todo!("input_wrapper is_ident_container for reference"), + syn::Type::Slice(_) => todo!("input_wrapper is_ident_container for slice"), + syn::Type::TraitObject(_) => todo!("input_wrapper is_ident_container for trait object"), + syn::Type::Tuple(_) => todo!("input_wrapper is_ident_container for tuple"), + syn::Type::Verbatim(_) => todo!("input_wrapper is_ident_container for verbatim"), + other => unimplemented!( + "input_wrapper is_ident_container not implemented for {}", + quote!(#other).to_string() + ), + } +} diff --git a/dfdx/src/nn/layers/id.rs b/dfdx/src/nn/layers/id.rs new file mode 100644 index 00000000..b209b693 --- /dev/null +++ b/dfdx/src/nn/layers/id.rs @@ -0,0 +1,19 @@ +use crate::prelude::*; + +/// Forwards the input to the output. +#[derive(Default, Debug, Clone, Copy, CustomModule)] +pub struct Id; + +impl Module for Id { + type Output = Input; + fn try_forward(&self, x: Input) -> Result { + Ok(x) + } +} + +pub type Id1 = (Id,); +pub type Id2 = (Id, Id); +pub type Id3 = (Id, Id, Id); +pub type Id4 = (Id, Id, Id, Id); +pub type Id5 = (Id, Id, Id, Id, Id); +pub type Id6 = (Id, Id, Id, Id, Id, Id); diff --git a/dfdx/src/nn/layers/mod.rs b/dfdx/src/nn/layers/mod.rs index 828b1e97..e0dd49c5 100644 --- a/dfdx/src/nn/layers/mod.rs +++ b/dfdx/src/nn/layers/mod.rs @@ -1,3 +1,5 @@ +pub mod ops; + mod abs; mod add_into; mod batch_norm1d; @@ -19,6 +21,7 @@ mod flatten2d; mod gelu; mod generalized_add; mod generalized_mul; +pub mod id; mod layer_norm1d; mod leaky_relu; mod linear; @@ -26,6 +29,7 @@ mod ln; mod log_softmax; mod matmul; mod multi_head_attention; +mod on; #[cfg(feature = "nightly")] mod pool_2d_avg; #[cfg(feature = "nightly")] @@ -72,6 +76,7 @@ pub use flatten2d::Flatten2D; pub use gelu::{AccurateGeLU, FastGeLU}; pub use generalized_add::GeneralizedAdd; pub use generalized_mul::GeneralizedMul; +pub use id::Id; pub use layer_norm1d::{LayerNorm1D, LayerNorm1DConfig, LayerNorm1DConstConfig}; pub use leaky_relu::LeakyReLU; pub use linear::{Linear, LinearConfig, LinearConstConfig}; @@ -79,6 +84,7 @@ pub use ln::Ln; pub use log_softmax::LogSoftmax; pub use matmul::{MatMul, MatMulConfig, MatMulConstConfig}; pub use multi_head_attention::{MultiHeadAttention, MultiHeadAttentionConfig}; +pub use on::On; #[cfg(feature = "nightly")] pub use pool_2d_avg::{AvgPool2D, AvgPool2DConst}; #[cfg(feature = "nightly")] diff --git a/dfdx/src/nn/layers/on.rs b/dfdx/src/nn/layers/on.rs new file mode 100644 index 00000000..6396d70c --- /dev/null +++ b/dfdx/src/nn/layers/on.rs @@ -0,0 +1,150 @@ +use crate::prelude::*; +use std::marker::PhantomData; + +// TODO: try making a Call module, whih allows calling an arbitrary method on the input. + +/// Applies module `T` into an input field from a wrapper. +#[cfg_attr(feature = "safetensors", derive(SaveSafeTensors, LoadSafeTensors))] +#[derive(Debug, Default, Clone, ResetParams, ZeroGrads, UpdateParams)] +#[repr(transparent)] +pub struct On { + #[module] + #[cfg_attr(feature = "safetensors", serialize)] + pub t: T, + + pub _n: PhantomData, +} + +impl, N: Clone + std::fmt::Debug, T: BuildOnDevice> BuildOnDevice + for On +{ + type Built = On; + fn try_build_on_device(&self, device: &D) -> Result { + let t = self.t.try_build_on_device(device)?; + Ok(On { t, _n: PhantomData }) + } +} + +// TODO: define On access for standard tuples, +// so that it's possible to access them with something like: +// On +pub mod tuple {} + +// cargo 'test' '--package' 'dfdx' '--lib' '--' 'nn::layers::on::tests' '--nocapture' +// test based on nn/layers/residual_add.rs +#[cfg(test)] +mod tests { + use super::*; + use crate::tests::*; + + #[input_wrapper] + pub struct Split1 { + pub forward: Forward, + pub skip: Skip, + } + + #[derive(Default, Clone, Sequential)] + pub struct ResidualAdd1 { + // input is Input + pub split: SplitInto<(Id, Id)>, + + // input is (Input, Input) + pub input_to_wrapper: split1::FromTuple, + + // input is Split1 + pub t: On, + + // input is Split1 + pub input_to_tuple: split1::IntoTuple, + + // input is (T::Output, Input) + pub add: ops::Add, + // input is T::Output = Input + } + + #[test] + fn test_input_wrapper_struct() { + let dev: TestDevice = Default::default(); + + let model = dev.build_module::(>>::default()); + let model = DeviceResidualAdd1::, TestDtype, TestDevice> { + t: On { + t: Linear { + weight: model.t.t.weight.to_dtype::(), + bias: model.t.t.bias.to_dtype::(), + }, + _n: Default::default(), + }, + add: Default::default(), + input_to_tuple: Default::default(), + input_to_wrapper: Default::default(), + split: Default::default(), + }; + + let x: Tensor, f32, _> = dev.sample_normal(); + let x = x.to_dtype::(); + let y = model.forward(x.leaky_trace()); + + #[rustfmt::skip] + assert_close_to_literal!(y, [[0.25372928, -2.4258814],[1.7892148, -2.6242268],[1.5131638, 0.23407778],[3.4201493, 1.597525]]); + + let g = y.mean().backward(); + assert_close_to_literal!(g.get(&model.t.t.weight), [[0.475242, -0.075136]; 2]); + assert_close_to_literal!(g.get(&model.t.t.bias), [0.5; 2]); + assert_close_to_literal!(g.get(&x), [[0.18806472, 0.21419683]; 4]); + } + + #[input_wrapper] + pub struct Split2(Forward, Skip); + + #[derive(Default, Clone, Sequential)] + pub struct ResidualAdd2 { + // input is Input + pub split: SplitInto<(Id, Id)>, + + // input is (Input, Input) + pub input_to_wrapper: split2::FromTuple, + + // input is Split2 + pub t: On, + + // input is Split2 + pub input_to_tuple: split2::IntoTuple, + + // input is (T::Output, Input) + pub add: ops::Add, + // input is T::Output = Input + } + + #[test] + fn test_input_wrapper_tuple_struct() { + let dev: TestDevice = Default::default(); + + let model = dev.build_module::(>>::default()); + let model = DeviceResidualAdd2::, TestDtype, TestDevice> { + t: On { + t: Linear { + weight: model.t.t.weight.to_dtype::(), + bias: model.t.t.bias.to_dtype::(), + }, + _n: Default::default(), + }, + add: Default::default(), + input_to_tuple: Default::default(), + input_to_wrapper: Default::default(), + split: Default::default(), + }; + + let x: Tensor, f32, _> = dev.sample_normal(); + let x = x.to_dtype::(); + let y = model.forward(x.leaky_trace()); + + #[rustfmt::skip] + assert_close_to_literal!(y, [[0.25372928, -2.4258814],[1.7892148, -2.6242268],[1.5131638, 0.23407778],[3.4201493, 1.597525]]); + + let g = y.mean().backward(); + assert_close_to_literal!(g.get(&model.t.t.weight), [[0.475242, -0.075136]; 2]); + assert_close_to_literal!(g.get(&model.t.t.bias), [0.5; 2]); + assert_close_to_literal!(g.get(&x), [[0.18806472, 0.21419683]; 4]); + } +} diff --git a/dfdx/src/nn/layers/ops/add.rs b/dfdx/src/nn/layers/ops/add.rs new file mode 100644 index 00000000..fcf40142 --- /dev/null +++ b/dfdx/src/nn/layers/ops/add.rs @@ -0,0 +1,20 @@ +use crate::prelude::*; + +/// Calls [crate::tensor_ops::add()] +#[derive(Default, Debug, Clone, Copy, crate::nn::CustomModule)] +pub struct Add; + +// TODO: macro for more tuples +// TODO: lower the requirement, as long as one of the values can be broadcast into the other one +// TODO: check if this works for constants + +impl Module<(Input, Input)> for Add +where + Input: TryAdd, +{ + type Output = ::Output; + + fn try_forward(&self, x: (Input, Input)) -> Result { + x.0.try_add(x.1) + } +} diff --git a/dfdx/src/nn/layers/ops/mod.rs b/dfdx/src/nn/layers/ops/mod.rs new file mode 100644 index 00000000..6ecd1204 --- /dev/null +++ b/dfdx/src/nn/layers/ops/mod.rs @@ -0,0 +1,3 @@ +mod add; + +pub use add::Add;