diff --git a/dfdx-core/src/lib.rs b/dfdx-core/src/lib.rs index 53cf8065..31e61643 100644 --- a/dfdx-core/src/lib.rs +++ b/dfdx-core/src/lib.rs @@ -110,6 +110,7 @@ extern crate alloc; #[cfg(all(feature = "no-std", not(feature = "std")))] extern crate no_std_compat as std; +extern crate self as dfdx_core; pub mod data; pub mod dtypes; diff --git a/dfdx-derives/Cargo.toml b/dfdx-derives/Cargo.toml index 366ae08d..9941edbd 100644 --- a/dfdx-derives/Cargo.toml +++ b/dfdx-derives/Cargo.toml @@ -15,4 +15,5 @@ syn = { version = "2", features = ["extra-traits"] } dfdx-core = { path = "../dfdx-core" } [features] -nightly = ["dfdx-core/nightly"] \ No newline at end of file +nightly = ["dfdx-core/nightly"] +safetensors = ["dfdx-core/safetensors"] diff --git a/dfdx-derives/src/lib.rs b/dfdx-derives/src/lib.rs index 13c03c75..4eca0d82 100644 --- a/dfdx-derives/src/lib.rs +++ b/dfdx-derives/src/lib.rs @@ -8,15 +8,15 @@ macro_rules! has_attr { }; } -/// Allows you to implement [dfdx_core::nn_traits::Module], while automatically implementing the following: -/// 1. [dfdx_core::nn_traits::BuildOnDevice] -/// 2. [dfdx_core::nn_traits::ResetParams] -/// 3. [dfdx_core::nn_traits::UpdateParams] -/// 4. [dfdx_core::nn_traits::ZeroGrads] -/// 5. [dfdx_core::nn_traits::SaveSafeTensors] -/// 6. [dfdx_core::nn_traits::LoadSafeTensors] +/// Allows you to implement [dfdx::nn_traits::Module], while automatically implementing the following: +/// 1. [dfdx::nn_traits::BuildOnDevice] +/// 2. [dfdx::nn_traits::ResetParams] +/// 3. [dfdx::nn_traits::UpdateParams] +/// 4. [dfdx::nn_traits::ZeroGrads] +/// 5. [dfdx::nn_traits::SaveSafeTensors] +/// 6. [dfdx::nn_traits::LoadSafeTensors] /// -/// If your struct contains sub module configs, then you must add the `#[module]` attribute to those items. Any field that is marked with `#[module]` will be expected to implement [dfdx_core::nn_traits::BuildOnDevice]. +/// If your struct contains sub module configs, then you must add the `#[module]` attribute to those items. Any field that is marked with `#[module]` will be expected to implement [dfdx::nn_traits::BuildOnDevice]. /// /// You can control the name of the built struct with the `#[built()]` attribute on the struct. /// @@ -25,9 +25,8 @@ macro_rules! has_attr { /// Here we have a unit struct that just calls a method on Tensor in the forward: /// /// ```ignore -/// # use dfdx::*; -/// # use dfdx_core::prelude::*; -/// #[derive(Default, Debug, Clone, Copy, CustomModule)] +/// # use dfdx::prelude::*; +/// #[derive(Default, Debug, Clone, Copy, dfdx::CustomModule)] /// pub struct Abs; /// impl, T: Tape> Module> for Abs { /// type Output = Tensor; @@ -40,9 +39,8 @@ macro_rules! has_attr { /// # Using CustomModule on structs with non-parameter fields /// /// ```ignore -/// # use dfdx::*; -/// # use dfdx_core::prelude::*; -/// #[derive(Default, Debug, Clone, Copy, CustomModule)] +/// # use dfdx::prelude::*; +/// #[derive(Default, Debug, Clone, Copy, dfdx::CustomModule)] /// pub struct Reshape(pub S); /// /// impl, T: Tape> Module> @@ -63,9 +61,8 @@ macro_rules! has_attr { /// 3. We must annotate the sub module with `#[module]` /// /// ```ignore -/// # use dfdx::*; -/// # use dfdx_core::prelude::*; -/// #[derive(Debug, Clone, CustomModule)] +/// # use dfdx::prelude::*; +/// #[derive(Debug, Clone, dfdx::CustomModule)] /// #[built(ResidualMatMul)] /// pub struct ResidualMatMulConfig(#[module] pub matmul: MatMulConfig); /// @@ -109,8 +106,13 @@ pub fn custom_module(input: proc_macro::TokenStream) -> proc_macro::TokenStream has_fields_to_build = true; where_clause .predicates - .push(parse_quote!(#ty: dfdx_core::nn_traits::BuildOnDevice)); - quote_spanned!(f.span()=> #[module] #[serialize] #vis #name: <#ty as dfdx_core::nn_traits::BuildOnDevice>::Built,) + .push(parse_quote!(#ty: ::dfdx::nn_traits::BuildOnDevice)); + let safetensors_serialize_attr = if cfg!(feature = "safetensors") { + quote!(#[serialize]) + } else { + quote!() + }; + quote_spanned!(f.span()=> #[module] #safetensors_serialize_attr #vis #name: <#ty as ::dfdx::nn_traits::BuildOnDevice>::Built,) } else { quote_spanned!(f.span()=> #vis #name: #ty,) } @@ -125,8 +127,13 @@ pub fn custom_module(input: proc_macro::TokenStream) -> proc_macro::TokenStream has_fields_to_build = true; where_clause .predicates - .push(parse_quote!(#ty: dfdx_core::nn_traits::BuildOnDevice)); - quote_spanned!(f.span()=> #[module] #[serialize] #vis <#ty as dfdx_core::nn_traits::BuildOnDevice>::Built,) + .push(parse_quote!(#ty: ::dfdx::nn_traits::BuildOnDevice)); + let safetensors_serialize_attr = if cfg!(feature = "safetensors") { + quote!(#[serialize]) + } else { + quote!() + }; + quote_spanned!(f.span()=> #[module] #safetensors_serialize_attr #vis <#ty as ::dfdx::nn_traits::BuildOnDevice>::Built,) } else { quote_spanned!(f.span()=> #vis #ty,) } @@ -143,10 +150,10 @@ pub fn custom_module(input: proc_macro::TokenStream) -> proc_macro::TokenStream let built_name = if has_fields_to_build { built_generics .params - .push(parse_quote!(Elem: dfdx_core::prelude::Dtype)); + .push(parse_quote!(Elem: ::dfdx::prelude::Dtype)); built_generics .params - .push(parse_quote!(Dev: dfdx_core::prelude::Device)); + .push(parse_quote!(Dev: ::dfdx::prelude::Device)); input .attrs .iter() @@ -162,8 +169,13 @@ pub fn custom_module(input: proc_macro::TokenStream) -> proc_macro::TokenStream let (built_impl, _, built_where) = built_generics.split_for_impl(); let def = if has_fields_to_build { + let safetensors_derive = if cfg!(feature = "safetensors") { + quote!(::dfdx::SaveSafeTensors, ::dfdx::LoadSafeTensors) + } else { + quote!() + }; quote! { - #[derive(Clone, Debug, dfdx_derives::ResetParams, dfdx_derives::UpdateParams, dfdx_derives::ZeroGrads, dfdx_derives::SaveSafeTensors, dfdx_derives::LoadSafeTensors)] + #[derive(Clone, Debug, ::dfdx::ResetParams, ::dfdx::UpdateParams, ::dfdx::ZeroGrads, #safetensors_derive)] pub struct #built_name #built_impl #built_where #fields } } else { @@ -173,54 +185,60 @@ pub fn custom_module(input: proc_macro::TokenStream) -> proc_macro::TokenStream if !has_fields_to_build { build_generics .params - .push(parse_quote!(Elem: dfdx_core::prelude::Dtype)); + .push(parse_quote!(Elem: ::dfdx::prelude::Dtype)); build_generics .params - .push(parse_quote!(Dev: dfdx_core::prelude::Device)); + .push(parse_quote!(Dev: ::dfdx::prelude::Device)); } let (build_impl, _, _) = build_generics.split_for_impl(); let (built_impl, built_ty, built_where) = built_generics.split_for_impl(); - quote! { - #[cfg(feature = "safetensors")] - impl #built_impl dfdx_core::nn_traits::SaveSafeTensors for #builder_name #built_ty #built_where { - fn write_safetensors( - &self, - location: &str, - tensors: &mut Vec<(String, ::safetensors::Dtype, Vec, Vec)>, - ) {} - } + let safetensors_impls = if cfg!(feature = "safetensors") { + quote! { + impl #built_impl ::dfdx::nn_traits::SaveSafeTensors for #builder_name #built_ty #built_where { + fn write_safetensors( + &self, + location: &str, + tensors: &mut Vec<(String, ::dfdx::safetensors::Dtype, Vec, Vec)>, + ) {} + } - #[cfg(feature = "safetensors")] - impl #built_impl dfdx_core::nn_traits::LoadSafeTensors for #builder_name #built_ty #built_where { - fn read_safetensors<'a>( - &mut self, - location: &str, - tensors: &::safetensors::SafeTensors<'a>, - ) -> Result<(), ::safetensors::SafeTensorError> { - Ok(()) + impl #built_impl ::dfdx::nn_traits::LoadSafeTensors for #builder_name #built_ty #built_where { + fn read_safetensors<'a>( + &mut self, + location: &str, + tensors: &::dfdx::safetensors::SafeTensors<'a>, + ) -> Result<(), ::dfdx::safetensors::SafeTensorError> { + Ok(()) + } } } + } else { + quote! {} + }; - impl #build_impl dfdx_core::nn_traits::ResetParams for #builder_name #built_ty #built_where { - fn try_reset_params(&mut self) -> Result<(), dfdx_core::tensor::Error> { + quote! { + #safetensors_impls + + impl #build_impl ::dfdx::nn_traits::ResetParams for #builder_name #built_ty #built_where { + fn try_reset_params(&mut self) -> Result<(), ::dfdx::tensor::Error> { Ok(()) } } - impl #build_impl dfdx_core::nn_traits::UpdateParams for #builder_name #built_ty #built_where { - fn try_update_params>( + impl #build_impl ::dfdx::nn_traits::UpdateParams for #builder_name #built_ty #built_where { + fn try_update_params>( &mut self, optimizer: &mut Optim, - gradients: &dfdx_core::tensor::Gradients, - missing_tensors: &mut Vec, - ) -> Result<(), dfdx_core::tensor::Error> { + gradients: &::dfdx::tensor::Gradients, + missing_tensors: &mut Vec<::dfdx::tensor::UniqueId>, + ) -> Result<(), ::dfdx::tensor::Error> { Ok(()) } } - impl #build_impl dfdx_core::nn_traits::ZeroGrads for #builder_name #built_ty #built_where { - fn try_zero_grads(&self, grads: &mut dfdx_core::tensor::Gradients) -> Result<(), dfdx_core::tensor::Error> { + impl #build_impl ::dfdx::nn_traits::ZeroGrads for #builder_name #built_ty #built_where { + fn try_zero_grads(&self, grads: &mut ::dfdx::tensor::Gradients) -> Result<(), ::dfdx::tensor::Error> { Ok(()) } } @@ -236,10 +254,10 @@ pub fn custom_module(input: proc_macro::TokenStream) -> proc_macro::TokenStream if !has_fields_to_build { build_generics .params - .push(parse_quote!(Elem: dfdx_core::prelude::Dtype)); + .push(parse_quote!(Elem: ::dfdx::prelude::Dtype)); build_generics .params - .push(parse_quote!(Dev: dfdx_core::prelude::Device)); + .push(parse_quote!(Dev: ::dfdx::prelude::Device)); } let (build_impl, _, _) = build_generics.split_for_impl(); let (_, built_ty, built_where) = built_generics.split_for_impl(); @@ -256,9 +274,9 @@ pub fn custom_module(input: proc_macro::TokenStream) -> proc_macro::TokenStream } }); quote! { - impl #build_impl dfdx_core::nn_traits::BuildOnDevice for #builder_name #builder_ty #built_where { + impl #build_impl ::dfdx::nn_traits::BuildOnDevice for #builder_name #builder_ty #built_where { type Built = #built_name #built_ty; - fn try_build_on_device(&self, device: &Dev) -> Result { + fn try_build_on_device(&self, device: &Dev) -> Result { let built = #built_name { #(#recurse)* }; Ok(built) } @@ -275,9 +293,9 @@ pub fn custom_module(input: proc_macro::TokenStream) -> proc_macro::TokenStream } }); quote! { - impl #build_impl dfdx_core::nn_traits::BuildOnDevice for #builder_name #builder_ty #built_where { + impl #build_impl ::dfdx::nn_traits::BuildOnDevice for #builder_name #builder_ty #built_where { type Built = #built_name #built_ty; - fn try_build_on_device(&self, device: &Dev) -> Result { + fn try_build_on_device(&self, device: &Dev) -> Result { let built = #built_name(#(#recurse)*); Ok(built) } @@ -286,9 +304,9 @@ pub fn custom_module(input: proc_macro::TokenStream) -> proc_macro::TokenStream } Fields::Unit => { quote! { - impl #build_impl dfdx_core::nn_traits::BuildOnDevice for #builder_name #builder_ty #built_where { + impl #build_impl ::dfdx::nn_traits::BuildOnDevice for #builder_name #builder_ty #built_where { type Built = #built_name #built_ty; - fn try_build_on_device(&self, device: &Dev) -> Result { + fn try_build_on_device(&self, device: &Dev) -> Result { Ok(#built_name) } } @@ -308,7 +326,7 @@ pub fn custom_module(input: proc_macro::TokenStream) -> proc_macro::TokenStream /// Implements all of the dfdx_nn traits automatically on your type. Assumes all fields on your type /// are modules (i.e. they also implement all the dfdx_nn traits). /// -/// [dfdx_core::nn_traits::Module] is implemented as calling each of the fields on the type in definition order. +/// [dfdx::nn_traits::Module] is implemented as calling each of the fields on the type in definition order. /// /// # Example usage /// Here we define a simple feedforward network with 3 layers. @@ -320,9 +338,8 @@ pub fn custom_module(input: proc_macro::TokenStream) -> proc_macro::TokenStream /// 4. act2 /// 5. linear3 /// ```ignore -/// # use dfdx_core::prelude::*; -/// # use dfdx::*; -/// #[derive(Debug, Clone, Sequential)] +/// # use dfdx::prelude::*; +/// #[derive(Debug, Clone, dfdx::Sequential)] /// #[built(Mlp)] /// struct MlpConfig { /// // Linear with compile time input size & runtime known output size @@ -350,10 +367,10 @@ pub fn sequential(input: proc_macro::TokenStream) -> proc_macro::TokenStream { let mut built_generics = input.generics.clone(); built_generics .params - .push(parse_quote!(Elem: dfdx_core::prelude::Dtype)); + .push(parse_quote!(Elem: ::dfdx::prelude::Dtype)); built_generics .params - .push(parse_quote!(Dev: dfdx_core::prelude::Device)); + .push(parse_quote!(Dev: ::dfdx::prelude::Device)); // get the generics for the impl. `Input` must be added only to the impl_generics. // NOTE: without cloning, `Input` will appear in both impl & ty generics. @@ -372,8 +389,13 @@ pub fn sequential(input: proc_macro::TokenStream) -> proc_macro::TokenStream { let vis = &f.vis; where_clause .predicates - .push(parse_quote!(#ty: dfdx_core::nn_traits::BuildOnDevice)); - quote_spanned!(f.span()=> #[module] #[serialize] #vis #name: <#ty as dfdx_core::nn_traits::BuildOnDevice>::Built,) + .push(parse_quote!(#ty: ::dfdx::nn_traits::BuildOnDevice)); + let safetensors_serialize_attr = if cfg!(feature = "safetensors") { + quote!(#[serialize]) + } else { + quote!() + }; + quote_spanned!(f.span()=> #[module] #safetensors_serialize_attr #vis #name: <#ty as ::dfdx::nn_traits::BuildOnDevice>::Built,) }); quote! { #(#fields)* } } @@ -383,8 +405,13 @@ pub fn sequential(input: proc_macro::TokenStream) -> proc_macro::TokenStream { let vis = &f.vis; where_clause .predicates - .push(parse_quote!(#ty: dfdx_core::nn_traits::BuildOnDevice)); - quote_spanned!(f.span()=> #[module] #[serialize] #vis <#ty as dfdx_core::nn_traits::BuildOnDevice>::Built,) + .push(parse_quote!(#ty: ::dfdx::nn_traits::BuildOnDevice)); + let safetensors_serialize_attr = if cfg!(feature = "safetensors") { + quote!(#[serialize]) + } else { + quote!() + }; + quote_spanned!(f.span()=> #[module] #safetensors_serialize_attr #vis <#ty as ::dfdx::nn_traits::BuildOnDevice>::Built,) }); quote! { #(#fields)* } } @@ -397,8 +424,14 @@ pub fn sequential(input: proc_macro::TokenStream) -> proc_macro::TokenStream { let (built_impl, _, built_where) = built_generics.split_for_impl(); + let safetensors_derive = if cfg!(feature = "safetensors") { + quote!(::dfdx::SaveSafeTensors, ::dfdx::LoadSafeTensors) + } else { + quote!() + }; + quote! { - #[derive(Clone, Debug, dfdx_derives::ResetParams, dfdx_derives::UpdateParams, dfdx_derives::ZeroGrads, dfdx_derives::SaveSafeTensors, dfdx_derives::LoadSafeTensors)] + #[derive(Clone, Debug, ::dfdx::ResetParams, ::dfdx::UpdateParams, ::dfdx::ZeroGrads, #safetensors_derive)] pub struct #built_name #built_impl #built_where { #fields } @@ -417,9 +450,9 @@ pub fn sequential(input: proc_macro::TokenStream) -> proc_macro::TokenStream { quote_spanned! {f.span()=> #name: self.#name.try_build_on_device(device)?, } }); quote! { - impl #built_impl dfdx_core::nn_traits::BuildOnDevice for #builder_name #builder_ty #built_where { + impl #built_impl ::dfdx::nn_traits::BuildOnDevice for #builder_name #builder_ty #built_where { type Built = #built_name #built_ty; - fn try_build_on_device(&self, device: &Dev) -> Result { + fn try_build_on_device(&self, device: &Dev) -> Result { let built = #built_name { #(#recurse)* }; @@ -434,9 +467,9 @@ pub fn sequential(input: proc_macro::TokenStream) -> proc_macro::TokenStream { quote_spanned! {f.span()=> self.#index.try_build_on_device(device)?, } }); quote! { - impl #built_impl dfdx_core::nn_traits::BuildOnDevice for #builder_name #builder_ty #built_where { + impl #built_impl ::dfdx::nn_traits::BuildOnDevice for #builder_name #builder_ty #built_where { type Built = #built_name #built_ty; - fn try_build_on_device(&self, device: &Dev) -> Result { + fn try_build_on_device(&self, device: &Dev) -> Result { #built_name( #(#recurse)* ) @@ -461,11 +494,11 @@ pub fn sequential(input: proc_macro::TokenStream) -> proc_macro::TokenStream { let ty = &f.ty; where_clause .predicates - .push(parse_quote!(#ty: dfdx_core::nn_traits::BuildOnDevice)); + .push(parse_quote!(#ty: ::dfdx::nn_traits::BuildOnDevice)); where_clause .predicates - .push(parse_quote!(<#ty as dfdx_core::nn_traits::BuildOnDevice>::Built: dfdx_core::nn_traits::Module<#last_ty>)); - last_ty = parse_quote!(<<#ty as dfdx_core::nn_traits::BuildOnDevice>::Built as dfdx_core::nn_traits::Module<#last_ty>>::Output); + .push(parse_quote!(<#ty as ::dfdx::nn_traits::BuildOnDevice>::Built: ::dfdx::nn_traits::Module<#last_ty>)); + last_ty = parse_quote!(<<#ty as ::dfdx::nn_traits::BuildOnDevice>::Built as ::dfdx::nn_traits::Module<#last_ty>>::Output); }); } Fields::Unnamed(ref fields) => { @@ -473,11 +506,11 @@ pub fn sequential(input: proc_macro::TokenStream) -> proc_macro::TokenStream { let ty = &f.ty; where_clause .predicates - .push(parse_quote!(#ty: dfdx_core::nn_traits::BuildOnDevice)); + .push(parse_quote!(#ty: ::dfdx::nn_traits::BuildOnDevice)); where_clause .predicates - .push(parse_quote!(<#ty as dfdx_core::nn_traits::BuildOnDevice>::Built: dfdx_core::nn_traits::Module<#last_ty>)); - last_ty = parse_quote!(<<#ty as dfdx_core::nn_traits::BuildOnDevice>::Built as dfdx_core::nn_traits::Module<#last_ty>>::Output); + .push(parse_quote!(<#ty as ::dfdx::nn_traits::BuildOnDevice>::Built: ::dfdx::nn_traits::Module<#last_ty>)); + last_ty = parse_quote!(<<#ty as ::dfdx::nn_traits::BuildOnDevice>::Built as ::dfdx::nn_traits::Module<#last_ty>>::Output); }); } Fields::Unit => {} @@ -535,7 +568,7 @@ pub fn sequential(input: proc_macro::TokenStream) -> proc_macro::TokenStream { let (module_impl, _, module_where) = module_generics.split_for_impl(); quote! { - impl #module_impl dfdx_core::nn_traits::Module for #built_name #built_ty #module_where { + impl #module_impl ::dfdx::nn_traits::Module for #built_name #built_ty #module_where { type Output = #output_ty; fn try_forward(&self, x: Input) -> Result { #src @@ -568,7 +601,7 @@ pub fn reset_params(input: proc_macro::TokenStream) -> proc_macro::TokenStream { ) { custom_generics .params - .push(parse_quote!(Elem: dfdx_core::prelude::Dtype)); + .push(parse_quote!(Elem: ::dfdx::prelude::Dtype)); } if !custom_generics.params.iter().any( @@ -576,7 +609,7 @@ pub fn reset_params(input: proc_macro::TokenStream) -> proc_macro::TokenStream { ) { custom_generics .params - .push(parse_quote!(Dev: dfdx_core::prelude::Device)); + .push(parse_quote!(Dev: ::dfdx::prelude::Device)); } let where_clause = input.generics.make_where_clause(); @@ -589,7 +622,7 @@ pub fn reset_params(input: proc_macro::TokenStream) -> proc_macro::TokenStream { if has_attr!(f, "module") { where_clause .predicates - .push(parse_quote!(#ty: dfdx_core::nn_traits::ResetParams)); + .push(parse_quote!(#ty: ::dfdx::nn_traits::ResetParams)); quote_spanned!(f.span()=>self.#name.try_reset_params()?;) } else { Default::default() @@ -604,7 +637,7 @@ pub fn reset_params(input: proc_macro::TokenStream) -> proc_macro::TokenStream { if has_attr!(f, "module") { where_clause .predicates - .push(parse_quote!(#ty: dfdx_core::nn_traits::ResetParams)); + .push(parse_quote!(#ty: ::dfdx::nn_traits::ResetParams)); quote_spanned!(f.span()=>self.#index.try_reset_params()?;) } else { Default::default() @@ -622,8 +655,8 @@ pub fn reset_params(input: proc_macro::TokenStream) -> proc_macro::TokenStream { let (_, ty_generics, where_clause) = input.generics.split_for_impl(); proc_macro::TokenStream::from(quote! { - impl #impl_generics dfdx_core::nn_traits::ResetParams for #name #ty_generics #where_clause { - fn try_reset_params(&mut self) -> Result<(), dfdx_core::tensor::Error> { + impl #impl_generics ::dfdx::nn_traits::ResetParams for #name #ty_generics #where_clause { + fn try_reset_params(&mut self) -> Result<(), ::dfdx::tensor::Error> { #resets Ok(()) } @@ -643,7 +676,7 @@ pub fn update_params(input: proc_macro::TokenStream) -> proc_macro::TokenStream ) { custom_generics .params - .push(parse_quote!(Elem: dfdx_core::prelude::Dtype)); + .push(parse_quote!(Elem: ::dfdx::prelude::Dtype)); } if !custom_generics.params.iter().any( @@ -651,7 +684,7 @@ pub fn update_params(input: proc_macro::TokenStream) -> proc_macro::TokenStream ) { custom_generics .params - .push(parse_quote!(Dev: dfdx_core::prelude::Device)); + .push(parse_quote!(Dev: ::dfdx::prelude::Device)); } let where_clause = input.generics.make_where_clause(); @@ -664,7 +697,7 @@ pub fn update_params(input: proc_macro::TokenStream) -> proc_macro::TokenStream if has_attr!(f, "module") { where_clause .predicates - .push(parse_quote!(#ty: dfdx_core::nn_traits::UpdateParams)); + .push(parse_quote!(#ty: ::dfdx::nn_traits::UpdateParams)); quote_spanned!(f.span()=>self.#name.try_update_params(optimizer, gradients, missing_tensors)?;) } else if has_attr!(f, "param") { quote_spanned!(f.span()=>optimizer.update_tensor(&mut self.#name, gradients, missing_tensors)?;) @@ -681,7 +714,7 @@ pub fn update_params(input: proc_macro::TokenStream) -> proc_macro::TokenStream if has_attr!(f, "module") { where_clause .predicates - .push(parse_quote!(#ty: dfdx_core::nn_traits::UpdateParams)); + .push(parse_quote!(#ty: ::dfdx::nn_traits::UpdateParams)); quote_spanned!(f.span()=>self.#index.try_update_params(optimizer, gradients, missing_tensors)?;) } else if has_attr!(f, "param") { quote_spanned!(f.span()=>optimizer.update_tensor(&mut self.#index, gradients, missing_tensors)?;) @@ -701,13 +734,13 @@ pub fn update_params(input: proc_macro::TokenStream) -> proc_macro::TokenStream let (_, ty_generics, where_clause) = input.generics.split_for_impl(); proc_macro::TokenStream::from(quote! { - impl #impl_generics dfdx_core::nn_traits::UpdateParams for #struct_name #ty_generics #where_clause { - fn try_update_params<_Model, Optim: dfdx_core::nn_traits::Optimizer<_Model, Elem, Dev>>( + impl #impl_generics ::dfdx::nn_traits::UpdateParams for #struct_name #ty_generics #where_clause { + fn try_update_params<_Model, Optim: ::dfdx::nn_traits::Optimizer<_Model, Elem, Dev>>( &mut self, optimizer: &mut Optim, - gradients: &dfdx_core::tensor::Gradients, - missing_tensors: &mut Vec, - ) -> Result<(), dfdx_core::tensor::Error> { + gradients: &::dfdx::tensor::Gradients, + missing_tensors: &mut Vec<::dfdx::tensor::UniqueId>, + ) -> Result<(), ::dfdx::tensor::Error> { #updates Ok(()) } @@ -727,7 +760,7 @@ pub fn zero_grads(input: proc_macro::TokenStream) -> proc_macro::TokenStream { ) { custom_generics .params - .push(parse_quote!(Elem: dfdx_core::prelude::Dtype)); + .push(parse_quote!(Elem: ::dfdx::prelude::Dtype)); } if !custom_generics.params.iter().any( @@ -735,7 +768,7 @@ pub fn zero_grads(input: proc_macro::TokenStream) -> proc_macro::TokenStream { ) { custom_generics .params - .push(parse_quote!(Dev: dfdx_core::prelude::Device)); + .push(parse_quote!(Dev: ::dfdx::prelude::Device)); } let where_clause = input.generics.make_where_clause(); @@ -749,7 +782,7 @@ pub fn zero_grads(input: proc_macro::TokenStream) -> proc_macro::TokenStream { { where_clause .predicates - .push(parse_quote!(#ty: dfdx_core::nn_traits::ZeroGrads)); + .push(parse_quote!(#ty: ::dfdx::nn_traits::ZeroGrads)); quote_spanned!(f.span()=>self.#name.try_zero_grads(grads)?;) } else if has_attr!(f, "param") { @@ -768,7 +801,7 @@ pub fn zero_grads(input: proc_macro::TokenStream) -> proc_macro::TokenStream { { where_clause .predicates - .push(parse_quote!(#ty: dfdx_core::nn_traits::ZeroGrads)); + .push(parse_quote!(#ty: ::dfdx::nn_traits::ZeroGrads)); quote_spanned!(f.span()=>self.#index.try_zero_grads(grads)?;) } else if has_attr!(f, "param") { @@ -789,8 +822,8 @@ pub fn zero_grads(input: proc_macro::TokenStream) -> proc_macro::TokenStream { let (_, ty_generics, where_clause) = input.generics.split_for_impl(); proc_macro::TokenStream::from(quote! { - impl #impl_generics dfdx_core::nn_traits::ZeroGrads for #name #ty_generics #where_clause { - fn try_zero_grads(&self, grads: &mut dfdx_core::prelude::Gradients) -> Result<(), dfdx_core::tensor::Error> { + impl #impl_generics ::dfdx::nn_traits::ZeroGrads for #name #ty_generics #where_clause { + fn try_zero_grads(&self, grads: &mut ::dfdx::prelude::Gradients) -> Result<(), ::dfdx::tensor::Error> { #zero_grads Ok(()) } @@ -816,7 +849,7 @@ pub fn save_safetensors(input: proc_macro::TokenStream) -> proc_macro::TokenStre { where_clause .predicates - .push(parse_quote!(#ty: dfdx_core::nn_traits::SaveSafeTensors)); + .push(parse_quote!(#ty: ::dfdx::nn_traits::SaveSafeTensors)); quote_spanned!(f.span()=>self.#name.write_safetensors(&format!("{location}{}", #name_str), tensors);) } else { Default::default() @@ -832,7 +865,7 @@ pub fn save_safetensors(input: proc_macro::TokenStream) -> proc_macro::TokenStre { where_clause .predicates - .push(parse_quote!(#ty: dfdx_core::nn_traits::SaveSafeTensors)); + .push(parse_quote!(#ty: ::dfdx::nn_traits::SaveSafeTensors)); quote_spanned!(f.span()=>self.#index.write_safetensors(&format!("{location}{}", #index), tensors);) } else { Default::default() @@ -849,12 +882,12 @@ pub fn save_safetensors(input: proc_macro::TokenStream) -> proc_macro::TokenStre let (impl_generics, ty_generics, where_clause) = input.generics.split_for_impl(); proc_macro::TokenStream::from(quote! { - #[cfg(feature = "safetensors")] - impl #impl_generics dfdx_core::nn_traits::SaveSafeTensors for #name #ty_generics #where_clause { + // note: SaveSafeTensors definition is already gated by the safetensors feature + impl #impl_generics ::dfdx::nn_traits::SaveSafeTensors for #name #ty_generics #where_clause { fn write_safetensors( &self, location: &str, - tensors: &mut Vec<(String, ::safetensors::Dtype, Vec, Vec)>, + tensors: &mut Vec<(String, ::dfdx::safetensors::Dtype, Vec, Vec)>, ) { #save_fields } @@ -879,7 +912,7 @@ pub fn load_safetensors(input: proc_macro::TokenStream) -> proc_macro::TokenStre if has_attr!(f, "serialize") { where_clause .predicates - .push(parse_quote!(#ty: dfdx_core::nn_traits::LoadSafeTensors)); + .push(parse_quote!(#ty: ::dfdx::nn_traits::LoadSafeTensors)); quote_spanned!(f.span()=>self.#name.read_safetensors(&format!("{location}{}", #name_str), tensors)?;) } else { Default::default() @@ -894,7 +927,7 @@ pub fn load_safetensors(input: proc_macro::TokenStream) -> proc_macro::TokenStre if has_attr!(f, "serialize") { where_clause .predicates - .push(parse_quote!(#ty: dfdx_core::nn_traits::LoadSafeTensors)); + .push(parse_quote!(#ty: ::dfdx::nn_traits::LoadSafeTensors)); quote_spanned!(f.span()=>self.#index.read_safetensors(&format!("{location}{}", #index), tensors)?;) } else { Default::default() @@ -911,13 +944,13 @@ pub fn load_safetensors(input: proc_macro::TokenStream) -> proc_macro::TokenStre let (impl_generics, ty_generics, where_clause) = input.generics.split_for_impl(); proc_macro::TokenStream::from(quote! { - #[cfg(feature = "safetensors")] - impl #impl_generics dfdx_core::nn_traits::LoadSafeTensors for #name #ty_generics #where_clause { + // note: LoadSafeTensors definition is already gated by the safetensors feature + impl #impl_generics ::dfdx::nn_traits::LoadSafeTensors for #name #ty_generics #where_clause { fn read_safetensors<'a>( &mut self, location: &str, - tensors: &::safetensors::SafeTensors<'a>, - ) -> Result<(), ::safetensors::SafeTensorError> { + tensors: &::dfdx::safetensors::SafeTensors<'a>, + ) -> Result<(), ::dfdx::safetensors::SafeTensorError> { #load_fields Ok(()) } diff --git a/dfdx/Cargo.toml b/dfdx/Cargo.toml index 223c63e4..a79e81c4 100644 --- a/dfdx/Cargo.toml +++ b/dfdx/Cargo.toml @@ -53,7 +53,11 @@ cudnn = ["dfdx-core/cudnn"] f16 = ["dfdx-core/f16"] numpy = ["dfdx-core/numpy"] -safetensors = ["dep:safetensors", "dfdx-core/safetensors"] +safetensors = [ + "dep:safetensors", + "dfdx-core/safetensors", + "dfdx-derives/safetensors", +] test-f16 = ["f16", "dfdx-core/f16"] test-amp-f16 = ["f16", "dfdx-core/f16"] diff --git a/dfdx/src/lib.rs b/dfdx/src/lib.rs index cf4be4a0..b07b77df 100644 --- a/dfdx/src/lib.rs +++ b/dfdx/src/lib.rs @@ -252,11 +252,20 @@ #![cfg_attr(feature = "nightly", feature(generic_const_exprs))] +extern crate self as dfdx; + pub mod feature_flags; pub mod nn; pub use dfdx_core::*; +#[cfg(feature = "safetensors")] +pub use safetensors; + +pub use dfdx_derives::{CustomModule, ResetParams, Sequential, UpdateParams, ZeroGrads}; +#[cfg(feature = "safetensors")] +pub use dfdx_derives::{LoadSafeTensors, SaveSafeTensors}; + pub mod prelude { pub use crate::nn::*; pub use dfdx_core::prelude::*; diff --git a/dfdx/src/nn/layers/add_into.rs b/dfdx/src/nn/layers/add_into.rs index 98310339..5b57dd56 100644 --- a/dfdx/src/nn/layers/add_into.rs +++ b/dfdx/src/nn/layers/add_into.rs @@ -1,7 +1,7 @@ use crate::prelude::*; /// Add inputs together into a single tensor. `T` should be a tuple -//// where every element of the tuple has the same output type +/// where every element of the tuple has the same output type /// /// This provides a utility for networks where multiple inputs are needed /// @@ -19,13 +19,12 @@ use crate::prelude::*; /// let b: Tensor, f32, _> = dev.zeros(); /// let _: Tensor, f32, _> = model.forward((a, b)); /// ``` -#[derive( - Debug, Default, Clone, ResetParams, ZeroGrads, UpdateParams, LoadSafeTensors, SaveSafeTensors, -)] +#[derive(Debug, Default, Clone, ResetParams, ZeroGrads, UpdateParams)] +#[cfg_attr(feature = "safetensors", derive(SaveSafeTensors, LoadSafeTensors))] #[repr(transparent)] pub struct AddInto( #[module] - #[serialize] + #[cfg_attr(feature = "safetensors", serialize)] pub T, ); diff --git a/dfdx/src/nn/layers/batch_norm1d.rs b/dfdx/src/nn/layers/batch_norm1d.rs index e6186365..ce8ff93c 100644 --- a/dfdx/src/nn/layers/batch_norm1d.rs +++ b/dfdx/src/nn/layers/batch_norm1d.rs @@ -55,29 +55,30 @@ impl> BuildOnDevice for BatchNorm1DConfig> { /// Scale for affine transform. Defaults to 1.0 #[param] - #[serialize] + #[cfg_attr(feature = "safetensors", serialize)] pub scale: Tensor<(C,), Elem, Dev>, /// Bias for affine transform. Defaults to 0.0 #[param] - #[serialize] + #[cfg_attr(feature = "safetensors", serialize)] pub bias: Tensor<(C,), Elem, Dev>, /// Spatial mean that is updated during training. Defaults to 0.0 - #[serialize] + #[cfg_attr(feature = "safetensors", serialize)] pub running_mean: Tensor<(C,), Elem, Dev>, /// Spatial variance that is updated during training. Defaults to 1.0 - #[serialize] + #[cfg_attr(feature = "safetensors", serialize)] pub running_var: Tensor<(C,), Elem, Dev>, /// Added to variance before taking sqrt for numerical stability. Defaults to 1e-5 - #[serialize] + #[cfg_attr(feature = "safetensors", serialize)] pub epsilon: f64, /// Controls exponential moving average of running stats. Defaults to 0.1 /// /// `running_stat * (1.0 - momentum) + stat * momentum`. - #[serialize] + #[cfg_attr(feature = "safetensors", serialize)] pub momentum: f64, } diff --git a/dfdx/src/nn/layers/batch_norm2d.rs b/dfdx/src/nn/layers/batch_norm2d.rs index 772923d7..bd8fc013 100644 --- a/dfdx/src/nn/layers/batch_norm2d.rs +++ b/dfdx/src/nn/layers/batch_norm2d.rs @@ -57,21 +57,22 @@ impl> crate::nn::BuildOnDevice for BatchNor } /// See [BatchNorm2DConfig] -#[derive(Clone, Debug, UpdateParams, ZeroGrads, SaveSafeTensors, LoadSafeTensors)] +#[derive(Clone, Debug, UpdateParams, ZeroGrads)] +#[cfg_attr(feature = "safetensors", derive(SaveSafeTensors, LoadSafeTensors))] pub struct BatchNorm2D> { #[param] - #[serialize] + #[cfg_attr(feature = "safetensors", serialize)] pub scale: Tensor<(C,), Elem, Dev>, #[param] - #[serialize] + #[cfg_attr(feature = "safetensors", serialize)] pub bias: Tensor<(C,), Elem, Dev>, - #[serialize] + #[cfg_attr(feature = "safetensors", serialize)] pub running_mean: Tensor<(C,), Elem, Dev>, - #[serialize] + #[cfg_attr(feature = "safetensors", serialize)] pub running_var: Tensor<(C,), Elem, Dev>, - #[serialize] + #[cfg_attr(feature = "safetensors", serialize)] pub epsilon: f64, - #[serialize] + #[cfg_attr(feature = "safetensors", serialize)] pub momentum: f64, } diff --git a/dfdx/src/nn/layers/bias1d.rs b/dfdx/src/nn/layers/bias1d.rs index eb68a67b..c904df1a 100644 --- a/dfdx/src/nn/layers/bias1d.rs +++ b/dfdx/src/nn/layers/bias1d.rs @@ -36,10 +36,11 @@ impl> BuildOnDevice for Bias1DConfig { } /// See [Bias1DConfig] -#[derive(Clone, Debug, UpdateParams, ZeroGrads, SaveSafeTensors, LoadSafeTensors)] +#[derive(Clone, Debug, UpdateParams, ZeroGrads)] +#[cfg_attr(feature = "safetensors", derive(SaveSafeTensors, LoadSafeTensors))] pub struct Bias1D> { #[param] - #[serialize] + #[cfg_attr(feature = "safetensors", serialize)] pub bias: Tensor<(I,), Elem, Dev>, } diff --git a/dfdx/src/nn/layers/bias2d.rs b/dfdx/src/nn/layers/bias2d.rs index 1b6ac42a..397e0707 100644 --- a/dfdx/src/nn/layers/bias2d.rs +++ b/dfdx/src/nn/layers/bias2d.rs @@ -36,10 +36,11 @@ impl> BuildOnDevice for Bias2DConfig { } /// See [Bias2DConfig] -#[derive(Clone, Debug, UpdateParams, ZeroGrads, SaveSafeTensors, LoadSafeTensors)] +#[derive(Clone, Debug, UpdateParams, ZeroGrads)] +#[cfg_attr(feature = "safetensors", derive(SaveSafeTensors, LoadSafeTensors))] pub struct Bias2D> { #[param] - #[serialize] + #[cfg_attr(feature = "safetensors", serialize)] pub bias: Tensor<(C,), Elem, Dev>, } diff --git a/dfdx/src/nn/layers/conv1d.rs b/dfdx/src/nn/layers/conv1d.rs index 793b326a..5241b912 100644 --- a/dfdx/src/nn/layers/conv1d.rs +++ b/dfdx/src/nn/layers/conv1d.rs @@ -78,7 +78,8 @@ where } /// The module built with [Conv1DConfig]. See [Conv1DConfig] for usage. -#[derive(Debug, Clone, UpdateParams, ZeroGrads, SaveSafeTensors, LoadSafeTensors)] +#[derive(Debug, Clone, UpdateParams, ZeroGrads)] +#[cfg_attr(feature = "safetensors", derive(SaveSafeTensors, LoadSafeTensors))] pub struct Conv1D where InChan: std::ops::Div, @@ -94,7 +95,7 @@ where Dev: Device, { #[param] - #[serialize] + #[cfg_attr(feature = "safetensors", serialize)] #[allow(clippy::type_complexity)] pub weight: Tensor< ( diff --git a/dfdx/src/nn/layers/conv2d.rs b/dfdx/src/nn/layers/conv2d.rs index 9e65f41f..c88ea821 100644 --- a/dfdx/src/nn/layers/conv2d.rs +++ b/dfdx/src/nn/layers/conv2d.rs @@ -99,7 +99,8 @@ where } /// The module built with [Conv2DConfig]. See [Conv2DConfig] for usage. -#[derive(Debug, Clone, UpdateParams, ZeroGrads, SaveSafeTensors, LoadSafeTensors)] +#[derive(Debug, Clone, UpdateParams, ZeroGrads)] +#[cfg_attr(feature = "safetensors", derive(SaveSafeTensors, LoadSafeTensors))] pub struct Conv2D where InChan: std::ops::Div, @@ -115,7 +116,7 @@ where Dev: Device, { #[param] - #[serialize] + #[cfg_attr(feature = "safetensors", serialize)] #[allow(clippy::type_complexity)] pub weight: Tensor< ( diff --git a/dfdx/src/nn/layers/conv_trans2d.rs b/dfdx/src/nn/layers/conv_trans2d.rs index b4730a92..b7683676 100644 --- a/dfdx/src/nn/layers/conv_trans2d.rs +++ b/dfdx/src/nn/layers/conv_trans2d.rs @@ -77,7 +77,8 @@ where } /// See [ConvTrans2DConfig]. -#[derive(Debug, Clone, UpdateParams, ZeroGrads, SaveSafeTensors, LoadSafeTensors)] +#[derive(Debug, Clone, UpdateParams, ZeroGrads)] +#[cfg_attr(feature = "safetensors", derive(SaveSafeTensors, LoadSafeTensors))] pub struct ConvTrans2D where OutChan: std::ops::Div, @@ -93,7 +94,7 @@ where Dev: Device, { #[param] - #[serialize] + #[cfg_attr(feature = "safetensors", serialize)] #[allow(clippy::type_complexity)] pub weight: Tensor< ( diff --git a/dfdx/src/nn/layers/embedding.rs b/dfdx/src/nn/layers/embedding.rs index 2979c9d7..6c7971e9 100644 --- a/dfdx/src/nn/layers/embedding.rs +++ b/dfdx/src/nn/layers/embedding.rs @@ -51,10 +51,11 @@ impl> BuildOnDevice for EmbeddingCo } /// See [EmbeddingConfig]. -#[derive(Clone, Debug, UpdateParams, ZeroGrads, SaveSafeTensors, LoadSafeTensors)] +#[derive(Clone, Debug, UpdateParams, ZeroGrads)] +#[cfg_attr(feature = "safetensors", derive(SaveSafeTensors, LoadSafeTensors))] pub struct Embedding> { #[param] - #[serialize] + #[cfg_attr(feature = "safetensors", serialize)] pub weight: Tensor<(Vocab, Model), Elem, Dev>, } diff --git a/dfdx/src/nn/layers/generalized_add.rs b/dfdx/src/nn/layers/generalized_add.rs index fb60b9c3..3dc4708b 100644 --- a/dfdx/src/nn/layers/generalized_add.rs +++ b/dfdx/src/nn/layers/generalized_add.rs @@ -18,15 +18,14 @@ use crate::prelude::*; /// let y = model.forward(x); /// assert_eq!(y.array(), [4.0, 1.0, 0.0, 2.0, 6.0]); /// ``` -#[derive( - Default, Clone, Debug, ResetParams, ZeroGrads, UpdateParams, LoadSafeTensors, SaveSafeTensors, -)] +#[derive(Default, Clone, Debug, ResetParams, ZeroGrads, UpdateParams)] +#[cfg_attr(feature = "safetensors", derive(SaveSafeTensors, LoadSafeTensors))] pub struct GeneralizedAdd { #[module] - #[serialize] + #[cfg_attr(feature = "safetensors", serialize)] pub t: T, #[module] - #[serialize] + #[cfg_attr(feature = "safetensors", serialize)] pub u: U, } diff --git a/dfdx/src/nn/layers/generalized_mul.rs b/dfdx/src/nn/layers/generalized_mul.rs index b6b5b5ba..64562024 100644 --- a/dfdx/src/nn/layers/generalized_mul.rs +++ b/dfdx/src/nn/layers/generalized_mul.rs @@ -17,15 +17,14 @@ use crate::prelude::*; /// let y = model.forward(x); /// assert_eq!(y.array(), [0.0, 0.0, 0.0, 1.0, 8.0]); /// ``` -#[derive( - Default, Clone, Debug, ResetParams, ZeroGrads, UpdateParams, LoadSafeTensors, SaveSafeTensors, -)] +#[derive(Default, Clone, Debug, ResetParams, ZeroGrads, UpdateParams)] +#[cfg_attr(feature = "safetensors", derive(SaveSafeTensors, LoadSafeTensors))] pub struct GeneralizedMul { #[module] - #[serialize] + #[cfg_attr(feature = "safetensors", serialize)] pub t: T, #[module] - #[serialize] + #[cfg_attr(feature = "safetensors", serialize)] pub u: U, } diff --git a/dfdx/src/nn/layers/layer_norm1d.rs b/dfdx/src/nn/layers/layer_norm1d.rs index a1381534..363db245 100644 --- a/dfdx/src/nn/layers/layer_norm1d.rs +++ b/dfdx/src/nn/layers/layer_norm1d.rs @@ -38,15 +38,16 @@ impl> BuildOnDevice for LayerNorm1DConfig> { #[param] - #[serialize] + #[cfg_attr(feature = "safetensors", serialize)] pub gamma: Tensor<(M,), Elem, Dev>, #[param] - #[serialize] + #[cfg_attr(feature = "safetensors", serialize)] pub beta: Tensor<(M,), Elem, Dev>, - #[serialize] + #[cfg_attr(feature = "safetensors", serialize)] pub epsilon: f64, } diff --git a/dfdx/src/nn/layers/linear.rs b/dfdx/src/nn/layers/linear.rs index ce63b63a..2d8f2e08 100644 --- a/dfdx/src/nn/layers/linear.rs +++ b/dfdx/src/nn/layers/linear.rs @@ -47,13 +47,14 @@ impl> BuildOnDevice for LinearConfi } /// See [LinearConfig]. -#[derive(Clone, Debug, UpdateParams, ZeroGrads, SaveSafeTensors, LoadSafeTensors)] +#[derive(Clone, Debug, UpdateParams, ZeroGrads)] +#[cfg_attr(feature = "safetensors", derive(SaveSafeTensors, LoadSafeTensors))] pub struct Linear> { #[param] - #[serialize] + #[cfg_attr(feature = "safetensors", serialize)] pub weight: Tensor<(O, I), Elem, Dev>, #[param] - #[serialize] + #[cfg_attr(feature = "safetensors", serialize)] pub bias: Tensor<(O,), Elem, Dev>, } diff --git a/dfdx/src/nn/layers/matmul.rs b/dfdx/src/nn/layers/matmul.rs index 86a3fa52..a2e301e9 100644 --- a/dfdx/src/nn/layers/matmul.rs +++ b/dfdx/src/nn/layers/matmul.rs @@ -36,10 +36,11 @@ impl> BuildOnDevice for MatMulConfi } /// See [MatMulConfig]. -#[derive(Clone, Debug, UpdateParams, ZeroGrads, SaveSafeTensors, LoadSafeTensors)] +#[derive(Clone, Debug, UpdateParams, ZeroGrads)] +#[cfg_attr(feature = "safetensors", derive(SaveSafeTensors, LoadSafeTensors))] pub struct MatMul> { #[param] - #[serialize] + #[cfg_attr(feature = "safetensors", serialize)] pub weight: Tensor<(O, I), Elem, Dev>, } diff --git a/dfdx/src/nn/layers/prelu.rs b/dfdx/src/nn/layers/prelu.rs index 6af12151..9eb8f508 100644 --- a/dfdx/src/nn/layers/prelu.rs +++ b/dfdx/src/nn/layers/prelu.rs @@ -19,10 +19,11 @@ impl> BuildOnDevice for PReLUConfig { } /// See [PReLUConfig]. -#[derive(Clone, Debug, UpdateParams, ZeroGrads, SaveSafeTensors, LoadSafeTensors)] +#[derive(Clone, Debug, UpdateParams, ZeroGrads)] +#[cfg_attr(feature = "safetensors", derive(SaveSafeTensors, LoadSafeTensors))] pub struct PReLU> { #[param] - #[serialize] + #[cfg_attr(feature = "safetensors", serialize)] pub a: Tensor<(), Elem, Dev>, } diff --git a/dfdx/src/nn/layers/prelu1d.rs b/dfdx/src/nn/layers/prelu1d.rs index b1a405c1..fa0a35b9 100644 --- a/dfdx/src/nn/layers/prelu1d.rs +++ b/dfdx/src/nn/layers/prelu1d.rs @@ -25,10 +25,11 @@ impl> BuildOnDevice for PReLU1DConfig { } /// See [PReLU1DConfig]. -#[derive(Clone, Debug, UpdateParams, ZeroGrads, SaveSafeTensors, LoadSafeTensors)] +#[derive(Clone, Debug, UpdateParams, ZeroGrads)] +#[cfg_attr(feature = "safetensors", derive(SaveSafeTensors, LoadSafeTensors))] pub struct PReLU1D> { #[param] - #[serialize] + #[cfg_attr(feature = "safetensors", serialize)] pub a: Tensor<(C,), Elem, Dev>, } diff --git a/dfdx/src/nn/layers/residual_add.rs b/dfdx/src/nn/layers/residual_add.rs index 5da249a9..77c7ca97 100644 --- a/dfdx/src/nn/layers/residual_add.rs +++ b/dfdx/src/nn/layers/residual_add.rs @@ -17,13 +17,12 @@ use crate::prelude::*; /// let y = model.forward(x); /// assert_eq!(y.array(), [-2.0, -1.0, 0.0, 2.0, 4.0]); /// ``` -#[derive( - Default, Clone, Debug, ResetParams, ZeroGrads, UpdateParams, SaveSafeTensors, LoadSafeTensors, -)] +#[derive(Default, Clone, Debug, ResetParams, ZeroGrads, UpdateParams)] +#[cfg_attr(feature = "safetensors", derive(SaveSafeTensors, LoadSafeTensors))] #[repr(transparent)] pub struct ResidualAdd( #[module] - #[serialize] + #[cfg_attr(feature = "safetensors", serialize)] pub T, ); diff --git a/dfdx/src/nn/layers/residual_mul.rs b/dfdx/src/nn/layers/residual_mul.rs index c55787ff..7a9c9c9d 100644 --- a/dfdx/src/nn/layers/residual_mul.rs +++ b/dfdx/src/nn/layers/residual_mul.rs @@ -16,13 +16,12 @@ use crate::prelude::*; /// let y = model.forward(x); /// assert_eq!(y.array(), [0.0, 0.0, 0.0, 1.0, 4.0]); /// ``` -#[derive( - Default, Clone, Debug, ResetParams, ZeroGrads, UpdateParams, SaveSafeTensors, LoadSafeTensors, -)] +#[derive(Default, Clone, Debug, ResetParams, ZeroGrads, UpdateParams)] +#[cfg_attr(feature = "safetensors", derive(SaveSafeTensors, LoadSafeTensors))] #[repr(transparent)] pub struct ResidualMul( #[module] - #[serialize] + #[cfg_attr(feature = "safetensors", serialize)] pub T, ); diff --git a/dfdx/src/nn/layers/split_into.rs b/dfdx/src/nn/layers/split_into.rs index de905a0d..440cba5c 100644 --- a/dfdx/src/nn/layers/split_into.rs +++ b/dfdx/src/nn/layers/split_into.rs @@ -21,13 +21,12 @@ use crate::prelude::*; /// let model = dev.build_module::(Model::default()); /// let _: (Tensor, f32, _>, Tensor, f32, _>) = model.forward(dev.zeros::>()); /// ``` -#[derive( - Debug, Default, Clone, ResetParams, ZeroGrads, UpdateParams, LoadSafeTensors, SaveSafeTensors, -)] +#[derive(Debug, Default, Clone, ResetParams, ZeroGrads, UpdateParams)] +#[cfg_attr(feature = "safetensors", derive(SaveSafeTensors, LoadSafeTensors))] #[repr(transparent)] pub struct SplitInto( #[module] - #[serialize] + #[cfg_attr(feature = "safetensors", serialize)] pub T, );