Replace explicit features and paths on generated code (#886)
* Disallow feature detection on generated code

- Move feature detection so it happens before code generation.
- Only emit the `SaveSafeTensors` and `LoadSafeTensors` derivations and the `#[serialize]` attribute if the `safetensors` feature is enabled (see the sketch below).
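
For illustration, a minimal sketch of this gating pattern as it is applied to dfdx's own layers in the diffs below. `MyWrapper` is a made-up example, not part of dfdx, and the derives are assumed to be importable via `dfdx::prelude` (the layer files below use `use crate::prelude::*;`):

```rust
use dfdx::prelude::*;

// The safetensors-related derives and the `#[serialize]` helper attribute are
// only present when the defining crate's `safetensors` feature is enabled, so
// the derive macros never need to detect features in the code they generate.
#[derive(Debug, Default, Clone, ResetParams, ZeroGrads, UpdateParams)]
#[cfg_attr(feature = "safetensors", derive(SaveSafeTensors, LoadSafeTensors))]
pub struct MyWrapper<T>(
    #[module]
    #[cfg_attr(feature = "safetensors", serialize)]
    pub T,
);
```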

* Only use `::dfdx` as a path entrypoint on code generation

- Re-export the `self` crate as `dfdx`.
  - See [related question](https://users.rust-lang.org/t/how-to-express-crate-path-in-procedural-macros/91274/10).
  - Does the same for `dfdx-core`.
- Change the generated code so it always uses the `::dfdx` path entrypoint (see the sketch after this list).
  - User apps no longer need to depend on `dfdx-core` or `dfdx-derives` directly.
- Re-export `safetensors`.
  - User apps no longer need to depend on `safetensors` directly (as in, this is no longer required by the `SaveSafeTensors` and `LoadSafeTensors` derivations).
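
A standalone sketch of the technique (the crate and trait names below are placeholders, not dfdx APIs): re-exporting a crate under its own name makes the absolute `::name` path resolve even inside that crate, so macro-generated code can use a single path form everywhere.

```rust
// Hypothetical library crate named `mycrate` (placeholder for `dfdx`).
// The self re-export mirrors `extern crate self as dfdx;` in dfdx/src/lib.rs below.
extern crate self as mycrate;

pub trait Greet {
    fn greet(&self) -> &'static str;
}

pub struct Hello;

// Written the way a derive macro would emit it: always through the absolute
// `::mycrate` entrypoint. This compiles both here (thanks to the self
// re-export) and in downstream crates that depend on `mycrate`.
impl ::mycrate::Greet for ::mycrate::Hello {
    fn greet(&self) -> &'static str {
        "hi"
    }
}
```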

* Add safetensors feature to dfdx-derives

* Fix typo
swfsql authored Dec 2, 2023
1 parent 4476b5e commit a522c7a
Showing 24 changed files with 237 additions and 182 deletions.
1 change: 1 addition & 0 deletions dfdx-core/src/lib.rs
@@ -110,6 +110,7 @@
 extern crate alloc;
 #[cfg(all(feature = "no-std", not(feature = "std")))]
 extern crate no_std_compat as std;
+extern crate self as dfdx_core;
 
 pub mod data;
 pub mod dtypes;
3 changes: 2 additions & 1 deletion dfdx-derives/Cargo.toml
@@ -15,4 +15,5 @@ syn = { version = "2", features = ["extra-traits"] }
 dfdx-core = { path = "../dfdx-core" }
 
 [features]
-nightly = ["dfdx-core/nightly"]
+nightly = ["dfdx-core/nightly"]
+safetensors = ["dfdx-core/safetensors"]
261 changes: 147 additions & 114 deletions dfdx-derives/src/lib.rs

Large diffs are not rendered by default.

6 changes: 5 additions & 1 deletion dfdx/Cargo.toml
@@ -53,7 +53,11 @@ cudnn = ["dfdx-core/cudnn"]
 f16 = ["dfdx-core/f16"]
 
 numpy = ["dfdx-core/numpy"]
-safetensors = ["dep:safetensors", "dfdx-core/safetensors"]
+safetensors = [
+    "dep:safetensors",
+    "dfdx-core/safetensors",
+    "dfdx-derives/safetensors",
+]
 
 test-f16 = ["f16", "dfdx-core/f16"]
 test-amp-f16 = ["f16", "dfdx-core/f16"]
9 changes: 9 additions & 0 deletions dfdx/src/lib.rs
@@ -252,11 +252,20 @@
 
 #![cfg_attr(feature = "nightly", feature(generic_const_exprs))]
 
+extern crate self as dfdx;
+
 pub mod feature_flags;
 pub mod nn;
 
 pub use dfdx_core::*;
 
+#[cfg(feature = "safetensors")]
+pub use safetensors;
+
+pub use dfdx_derives::{CustomModule, ResetParams, Sequential, UpdateParams, ZeroGrads};
+#[cfg(feature = "safetensors")]
+pub use dfdx_derives::{LoadSafeTensors, SaveSafeTensors};
+
 pub mod prelude {
     pub use crate::nn::*;
     pub use dfdx_core::prelude::*;
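
To illustrate the new re-exports above: a hedged sketch of downstream code that reaches the safetensors crate only through `dfdx` (the function is hypothetical and assumes the `safetensors` feature is enabled on the `dfdx` dependency):

```rust
use dfdx::safetensors::SafeTensors;

// List the tensor names in an in-memory safetensors buffer, without a direct
// `safetensors` dependency in the application's Cargo.toml.
fn tensor_names(bytes: &[u8]) -> Vec<String> {
    match SafeTensors::deserialize(bytes) {
        Ok(file) => file.names().into_iter().map(|n| n.to_string()).collect(),
        Err(_) => Vec::new(),
    }
}
```
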
9 changes: 4 additions & 5 deletions dfdx/src/nn/layers/add_into.rs
@@ -1,7 +1,7 @@
 use crate::prelude::*;
 
 /// Add inputs together into a single tensor. `T` should be a tuple
-//// where every element of the tuple has the same output type
+/// where every element of the tuple has the same output type
 ///
 /// This provides a utility for networks where multiple inputs are needed
 ///
@@ -19,13 +19,12 @@ use crate::prelude::*;
 /// let b: Tensor<Rank1<3>, f32, _> = dev.zeros();
 /// let _: Tensor<Rank1<5>, f32, _> = model.forward((a, b));
 /// ```
-#[derive(
-    Debug, Default, Clone, ResetParams, ZeroGrads, UpdateParams, LoadSafeTensors, SaveSafeTensors,
-)]
+#[derive(Debug, Default, Clone, ResetParams, ZeroGrads, UpdateParams)]
+#[cfg_attr(feature = "safetensors", derive(SaveSafeTensors, LoadSafeTensors))]
 #[repr(transparent)]
 pub struct AddInto<T>(
     #[module]
-    #[serialize]
+    #[cfg_attr(feature = "safetensors", serialize)]
     pub T,
 );
 
15 changes: 8 additions & 7 deletions dfdx/src/nn/layers/batch_norm1d.rs
@@ -55,29 +55,30 @@ impl<C: Dim, E: Dtype, D: Device<E>> BuildOnDevice<E, D> for BatchNorm1DConfig<C
 }
 
 /// See [BatchNorm1DConfig].
-#[derive(Clone, Debug, UpdateParams, ZeroGrads, SaveSafeTensors, LoadSafeTensors)]
+#[derive(Clone, Debug, UpdateParams, ZeroGrads)]
+#[cfg_attr(feature = "safetensors", derive(SaveSafeTensors, LoadSafeTensors))]
 pub struct BatchNorm1D<C: Dim, Elem: Dtype, Dev: Device<Elem>> {
     /// Scale for affine transform. Defaults to 1.0
     #[param]
-    #[serialize]
+    #[cfg_attr(feature = "safetensors", serialize)]
     pub scale: Tensor<(C,), Elem, Dev>,
     /// Bias for affine transform. Defaults to 0.0
     #[param]
-    #[serialize]
+    #[cfg_attr(feature = "safetensors", serialize)]
     pub bias: Tensor<(C,), Elem, Dev>,
     /// Spatial mean that is updated during training. Defaults to 0.0
-    #[serialize]
+    #[cfg_attr(feature = "safetensors", serialize)]
     pub running_mean: Tensor<(C,), Elem, Dev>,
     /// Spatial variance that is updated during training. Defaults to 1.0
-    #[serialize]
+    #[cfg_attr(feature = "safetensors", serialize)]
     pub running_var: Tensor<(C,), Elem, Dev>,
     /// Added to variance before taking sqrt for numerical stability. Defaults to 1e-5
-    #[serialize]
+    #[cfg_attr(feature = "safetensors", serialize)]
     pub epsilon: f64,
     /// Controls exponential moving average of running stats. Defaults to 0.1
     ///
     /// `running_stat * (1.0 - momentum) + stat * momentum`.
-    #[serialize]
+    #[cfg_attr(feature = "safetensors", serialize)]
     pub momentum: f64,
 }
 
15 changes: 8 additions & 7 deletions dfdx/src/nn/layers/batch_norm2d.rs
@@ -57,21 +57,22 @@ impl<C: Dim, E: Dtype, D: Device<E>> crate::nn::BuildOnDevice<E, D> for BatchNor
 }
 
 /// See [BatchNorm2DConfig]
-#[derive(Clone, Debug, UpdateParams, ZeroGrads, SaveSafeTensors, LoadSafeTensors)]
+#[derive(Clone, Debug, UpdateParams, ZeroGrads)]
+#[cfg_attr(feature = "safetensors", derive(SaveSafeTensors, LoadSafeTensors))]
 pub struct BatchNorm2D<C: Dim, Elem: Dtype, Dev: Device<Elem>> {
     #[param]
-    #[serialize]
+    #[cfg_attr(feature = "safetensors", serialize)]
     pub scale: Tensor<(C,), Elem, Dev>,
     #[param]
-    #[serialize]
+    #[cfg_attr(feature = "safetensors", serialize)]
     pub bias: Tensor<(C,), Elem, Dev>,
-    #[serialize]
+    #[cfg_attr(feature = "safetensors", serialize)]
     pub running_mean: Tensor<(C,), Elem, Dev>,
-    #[serialize]
+    #[cfg_attr(feature = "safetensors", serialize)]
    pub running_var: Tensor<(C,), Elem, Dev>,
-    #[serialize]
+    #[cfg_attr(feature = "safetensors", serialize)]
     pub epsilon: f64,
-    #[serialize]
+    #[cfg_attr(feature = "safetensors", serialize)]
     pub momentum: f64,
 }
 
5 changes: 3 additions & 2 deletions dfdx/src/nn/layers/bias1d.rs
@@ -36,10 +36,11 @@ impl<I: Dim, E: Dtype, D: Device<E>> BuildOnDevice<E, D> for Bias1DConfig<I> {
 }
 
 /// See [Bias1DConfig]
-#[derive(Clone, Debug, UpdateParams, ZeroGrads, SaveSafeTensors, LoadSafeTensors)]
+#[derive(Clone, Debug, UpdateParams, ZeroGrads)]
+#[cfg_attr(feature = "safetensors", derive(SaveSafeTensors, LoadSafeTensors))]
 pub struct Bias1D<I: Dim, Elem: Dtype, Dev: Device<Elem>> {
     #[param]
-    #[serialize]
+    #[cfg_attr(feature = "safetensors", serialize)]
     pub bias: Tensor<(I,), Elem, Dev>,
 }
 
5 changes: 3 additions & 2 deletions dfdx/src/nn/layers/bias2d.rs
@@ -36,10 +36,11 @@ impl<C: Dim, E: Dtype, D: Device<E>> BuildOnDevice<E, D> for Bias2DConfig<C> {
 }
 
 /// See [Bias2DConfig]
-#[derive(Clone, Debug, UpdateParams, ZeroGrads, SaveSafeTensors, LoadSafeTensors)]
+#[derive(Clone, Debug, UpdateParams, ZeroGrads)]
+#[cfg_attr(feature = "safetensors", derive(SaveSafeTensors, LoadSafeTensors))]
 pub struct Bias2D<C: Dim, Elem: Dtype, Dev: Device<Elem>> {
     #[param]
-    #[serialize]
+    #[cfg_attr(feature = "safetensors", serialize)]
     pub bias: Tensor<(C,), Elem, Dev>,
 }
 
5 changes: 3 additions & 2 deletions dfdx/src/nn/layers/conv1d.rs
@@ -78,7 +78,8 @@ where
 }
 
 /// The module built with [Conv1DConfig]. See [Conv1DConfig] for usage.
-#[derive(Debug, Clone, UpdateParams, ZeroGrads, SaveSafeTensors, LoadSafeTensors)]
+#[derive(Debug, Clone, UpdateParams, ZeroGrads)]
+#[cfg_attr(feature = "safetensors", derive(SaveSafeTensors, LoadSafeTensors))]
 pub struct Conv1D<InChan, OutChan, KernelSize, Stride, Padding, Dilation, Groups, Elem, Dev>
 where
     InChan: std::ops::Div<Groups>,
@@ -94,7 +95,7 @@
     Dev: Device<Elem>,
 {
     #[param]
-    #[serialize]
+    #[cfg_attr(feature = "safetensors", serialize)]
     #[allow(clippy::type_complexity)]
     pub weight: Tensor<
         (
5 changes: 3 additions & 2 deletions dfdx/src/nn/layers/conv2d.rs
@@ -99,7 +99,8 @@ where
 }
 
 /// The module built with [Conv2DConfig]. See [Conv2DConfig] for usage.
-#[derive(Debug, Clone, UpdateParams, ZeroGrads, SaveSafeTensors, LoadSafeTensors)]
+#[derive(Debug, Clone, UpdateParams, ZeroGrads)]
+#[cfg_attr(feature = "safetensors", derive(SaveSafeTensors, LoadSafeTensors))]
 pub struct Conv2D<InChan, OutChan, KernelSize, Stride, Padding, Dilation, Groups, Elem, Dev>
 where
     InChan: std::ops::Div<Groups>,
@@ -115,7 +116,7 @@
     Dev: Device<Elem>,
 {
     #[param]
-    #[serialize]
+    #[cfg_attr(feature = "safetensors", serialize)]
     #[allow(clippy::type_complexity)]
     pub weight: Tensor<
         (
5 changes: 3 additions & 2 deletions dfdx/src/nn/layers/conv_trans2d.rs
@@ -77,7 +77,8 @@ where
 }
 
 /// See [ConvTrans2DConfig].
-#[derive(Debug, Clone, UpdateParams, ZeroGrads, SaveSafeTensors, LoadSafeTensors)]
+#[derive(Debug, Clone, UpdateParams, ZeroGrads)]
+#[cfg_attr(feature = "safetensors", derive(SaveSafeTensors, LoadSafeTensors))]
 pub struct ConvTrans2D<InChan, OutChan, KernelSize, Stride, Padding, Dilation, Groups, Elem, Dev>
 where
     OutChan: std::ops::Div<Groups>,
@@ -93,7 +94,7 @@
     Dev: Device<Elem>,
 {
     #[param]
-    #[serialize]
+    #[cfg_attr(feature = "safetensors", serialize)]
     #[allow(clippy::type_complexity)]
     pub weight: Tensor<
         (
5 changes: 3 additions & 2 deletions dfdx/src/nn/layers/embedding.rs
@@ -51,10 +51,11 @@ impl<V: Dim, M: Dim, E: Dtype, D: Device<E>> BuildOnDevice<E, D> for EmbeddingCo
 }
 
 /// See [EmbeddingConfig].
-#[derive(Clone, Debug, UpdateParams, ZeroGrads, SaveSafeTensors, LoadSafeTensors)]
+#[derive(Clone, Debug, UpdateParams, ZeroGrads)]
+#[cfg_attr(feature = "safetensors", derive(SaveSafeTensors, LoadSafeTensors))]
 pub struct Embedding<Vocab: Dim, Model: Dim, Elem: Dtype, Dev: Device<Elem>> {
     #[param]
-    #[serialize]
+    #[cfg_attr(feature = "safetensors", serialize)]
     pub weight: Tensor<(Vocab, Model), Elem, Dev>,
 }
 
9 changes: 4 additions & 5 deletions dfdx/src/nn/layers/generalized_add.rs
@@ -18,15 +18,14 @@ use crate::prelude::*;
 /// let y = model.forward(x);
 /// assert_eq!(y.array(), [4.0, 1.0, 0.0, 2.0, 6.0]);
 /// ```
-#[derive(
-    Default, Clone, Debug, ResetParams, ZeroGrads, UpdateParams, LoadSafeTensors, SaveSafeTensors,
-)]
+#[derive(Default, Clone, Debug, ResetParams, ZeroGrads, UpdateParams)]
+#[cfg_attr(feature = "safetensors", derive(SaveSafeTensors, LoadSafeTensors))]
 pub struct GeneralizedAdd<T, U> {
     #[module]
-    #[serialize]
+    #[cfg_attr(feature = "safetensors", serialize)]
     pub t: T,
     #[module]
-    #[serialize]
+    #[cfg_attr(feature = "safetensors", serialize)]
     pub u: U,
 }
 
9 changes: 4 additions & 5 deletions dfdx/src/nn/layers/generalized_mul.rs
@@ -17,15 +17,14 @@ use crate::prelude::*;
 /// let y = model.forward(x);
 /// assert_eq!(y.array(), [0.0, 0.0, 0.0, 1.0, 8.0]);
 /// ```
-#[derive(
-    Default, Clone, Debug, ResetParams, ZeroGrads, UpdateParams, LoadSafeTensors, SaveSafeTensors,
-)]
+#[derive(Default, Clone, Debug, ResetParams, ZeroGrads, UpdateParams)]
+#[cfg_attr(feature = "safetensors", derive(SaveSafeTensors, LoadSafeTensors))]
 pub struct GeneralizedMul<T, U> {
     #[module]
-    #[serialize]
+    #[cfg_attr(feature = "safetensors", serialize)]
     pub t: T,
     #[module]
-    #[serialize]
+    #[cfg_attr(feature = "safetensors", serialize)]
     pub u: U,
 }
 
9 changes: 5 additions & 4 deletions dfdx/src/nn/layers/layer_norm1d.rs
@@ -38,15 +38,16 @@ impl<M: Dim, E: Dtype, D: Device<E>> BuildOnDevice<E, D> for LayerNorm1DConfig<M
 }
 
 /// See [LayerNorm1DConfig]
-#[derive(Clone, Debug, UpdateParams, ZeroGrads, SaveSafeTensors, LoadSafeTensors)]
+#[derive(Clone, Debug, UpdateParams, ZeroGrads)]
+#[cfg_attr(feature = "safetensors", derive(SaveSafeTensors, LoadSafeTensors))]
 pub struct LayerNorm1D<M: Dim, Elem: Dtype, Dev: Device<Elem>> {
     #[param]
-    #[serialize]
+    #[cfg_attr(feature = "safetensors", serialize)]
     pub gamma: Tensor<(M,), Elem, Dev>,
     #[param]
-    #[serialize]
+    #[cfg_attr(feature = "safetensors", serialize)]
     pub beta: Tensor<(M,), Elem, Dev>,
-    #[serialize]
+    #[cfg_attr(feature = "safetensors", serialize)]
     pub epsilon: f64,
 }
 
7 changes: 4 additions & 3 deletions dfdx/src/nn/layers/linear.rs
@@ -47,13 +47,14 @@ impl<I: Dim, O: Dim, E: Dtype, D: Device<E>> BuildOnDevice<E, D> for LinearConfi
 }
 
 /// See [LinearConfig].
-#[derive(Clone, Debug, UpdateParams, ZeroGrads, SaveSafeTensors, LoadSafeTensors)]
+#[derive(Clone, Debug, UpdateParams, ZeroGrads)]
+#[cfg_attr(feature = "safetensors", derive(SaveSafeTensors, LoadSafeTensors))]
 pub struct Linear<I: Dim, O: Dim, Elem: Dtype, Dev: Device<Elem>> {
     #[param]
-    #[serialize]
+    #[cfg_attr(feature = "safetensors", serialize)]
     pub weight: Tensor<(O, I), Elem, Dev>,
     #[param]
-    #[serialize]
+    #[cfg_attr(feature = "safetensors", serialize)]
     pub bias: Tensor<(O,), Elem, Dev>,
 }
 
5 changes: 3 additions & 2 deletions dfdx/src/nn/layers/matmul.rs
@@ -36,10 +36,11 @@ impl<I: Dim, O: Dim, E: Dtype, D: Device<E>> BuildOnDevice<E, D> for MatMulConfi
 }
 
 /// See [MatMulConfig].
-#[derive(Clone, Debug, UpdateParams, ZeroGrads, SaveSafeTensors, LoadSafeTensors)]
+#[derive(Clone, Debug, UpdateParams, ZeroGrads)]
+#[cfg_attr(feature = "safetensors", derive(SaveSafeTensors, LoadSafeTensors))]
 pub struct MatMul<I: Dim, O: Dim, Elem: Dtype, Dev: Device<Elem>> {
     #[param]
-    #[serialize]
+    #[cfg_attr(feature = "safetensors", serialize)]
     pub weight: Tensor<(O, I), Elem, Dev>,
 }
 
5 changes: 3 additions & 2 deletions dfdx/src/nn/layers/prelu.rs
@@ -19,10 +19,11 @@ impl<E: Dtype, D: Device<E>> BuildOnDevice<E, D> for PReLUConfig {
 }
 
 /// See [PReLUConfig].
-#[derive(Clone, Debug, UpdateParams, ZeroGrads, SaveSafeTensors, LoadSafeTensors)]
+#[derive(Clone, Debug, UpdateParams, ZeroGrads)]
+#[cfg_attr(feature = "safetensors", derive(SaveSafeTensors, LoadSafeTensors))]
 pub struct PReLU<Elem: Dtype, Dev: Device<Elem>> {
     #[param]
-    #[serialize]
+    #[cfg_attr(feature = "safetensors", serialize)]
     pub a: Tensor<(), Elem, Dev>,
 }
 
5 changes: 3 additions & 2 deletions dfdx/src/nn/layers/prelu1d.rs
@@ -25,10 +25,11 @@ impl<C: Dim, E: Dtype, D: Device<E>> BuildOnDevice<E, D> for PReLU1DConfig<C> {
 }
 
 /// See [PReLU1DConfig].
-#[derive(Clone, Debug, UpdateParams, ZeroGrads, SaveSafeTensors, LoadSafeTensors)]
+#[derive(Clone, Debug, UpdateParams, ZeroGrads)]
+#[cfg_attr(feature = "safetensors", derive(SaveSafeTensors, LoadSafeTensors))]
 pub struct PReLU1D<C: Dim, Elem: Dtype, Dev: Device<Elem>> {
     #[param]
-    #[serialize]
+    #[cfg_attr(feature = "safetensors", serialize)]
     pub a: Tensor<(C,), Elem, Dev>,
 }
 
7 changes: 3 additions & 4 deletions dfdx/src/nn/layers/residual_add.rs
@@ -17,13 +17,12 @@ use crate::prelude::*;
 /// let y = model.forward(x);
 /// assert_eq!(y.array(), [-2.0, -1.0, 0.0, 2.0, 4.0]);
 /// ```
-#[derive(
-    Default, Clone, Debug, ResetParams, ZeroGrads, UpdateParams, SaveSafeTensors, LoadSafeTensors,
-)]
+#[derive(Default, Clone, Debug, ResetParams, ZeroGrads, UpdateParams)]
+#[cfg_attr(feature = "safetensors", derive(SaveSafeTensors, LoadSafeTensors))]
 #[repr(transparent)]
 pub struct ResidualAdd<T>(
     #[module]
-    #[serialize]
+    #[cfg_attr(feature = "safetensors", serialize)]
     pub T,
 );
 
7 changes: 3 additions & 4 deletions dfdx/src/nn/layers/residual_mul.rs
@@ -16,13 +16,12 @@ use crate::prelude::*;
 /// let y = model.forward(x);
 /// assert_eq!(y.array(), [0.0, 0.0, 0.0, 1.0, 4.0]);
 /// ```
-#[derive(
-    Default, Clone, Debug, ResetParams, ZeroGrads, UpdateParams, SaveSafeTensors, LoadSafeTensors,
-)]
+#[derive(Default, Clone, Debug, ResetParams, ZeroGrads, UpdateParams)]
+#[cfg_attr(feature = "safetensors", derive(SaveSafeTensors, LoadSafeTensors))]
 #[repr(transparent)]
 pub struct ResidualMul<T>(
     #[module]
-    #[serialize]
+    #[cfg_attr(feature = "safetensors", serialize)]
     pub T,
 );
 
