Skip to content

Commit

Permalink
fix test for f64
Browse files Browse the repository at this point in the history
  • Loading branch information
swfsql committed Nov 6, 2023
1 parent 1f2b10d commit ade4e15
Showing 1 changed file with 28 additions and 10 deletions.
38 changes: 28 additions & 10 deletions dfdx/src/nn/layers/on.rs
Original file line number Diff line number Diff line change
Expand Up @@ -38,12 +38,6 @@ mod tests {
use super::*;
use crate::tests::*;

#[input_wrapper]
pub struct MyWrapper<A, B> {
pub a: A,
pub b: B,
}

#[input_wrapper]
pub struct Split1<Forward, Skip> {
pub forward: Forward,
Expand Down Expand Up @@ -73,8 +67,20 @@ mod tests {
fn test_residual_add_backward() {
let dev: TestDevice = Default::default();

let model =
dev.build_module::<TestDtype>(<ResidualAdd1<LinearConstConfig<2, 2>>>::default());
let model = dev.build_module::<f32>(<ResidualAdd1<LinearConstConfig<2, 2>>>::default());
let model = DeviceResidualAdd1::<LinearConstConfig<2, 2>, TestDtype, TestDevice> {
t: On {
t: Linear {
weight: model.t.t.weight.to_dtype::<TestDtype>(),
bias: model.t.t.bias.to_dtype::<TestDtype>(),
},
_n: Default::default(),
},
add: Default::default(),
input_to_tuple: Default::default(),
input_to_wrapper: Default::default(),
split: Default::default(),
};

let x: Tensor<Rank2<4, 2>, f32, _> = dev.sample_normal();
let x = x.to_dtype::<TestDtype>();
Expand Down Expand Up @@ -115,8 +121,20 @@ mod tests {
fn test_residual_add_backward2() {
let dev: TestDevice = Default::default();

let model =
dev.build_module::<TestDtype>(<ResidualAdd2<LinearConstConfig<2, 2>>>::default());
let model = dev.build_module::<f32>(<ResidualAdd2<LinearConstConfig<2, 2>>>::default());
let model = DeviceResidualAdd2::<LinearConstConfig<2, 2>, TestDtype, TestDevice> {
t: On {
t: Linear {
weight: model.t.t.weight.to_dtype::<TestDtype>(),
bias: model.t.t.bias.to_dtype::<TestDtype>(),
},
_n: Default::default(),
},
add: Default::default(),
input_to_tuple: Default::default(),
input_to_wrapper: Default::default(),
split: Default::default(),
};

let x: Tensor<Rank2<4, 2>, f32, _> = dev.sample_normal();
let x = x.to_dtype::<TestDtype>();
Expand Down

0 comments on commit ade4e15

Please sign in to comment.