Skip to content

Commit

Permalink
fix powf bugs (#1207)
Browse files Browse the repository at this point in the history
  • Loading branch information
louisfd authored Jan 31, 2024
1 parent f1d98bc commit e03facc
Show file tree
Hide file tree
Showing 2 changed files with 3 additions and 3 deletions.
2 changes: 1 addition & 1 deletion burn-tch/src/ops/base.rs
Original file line number Diff line number Diff line change
Expand Up @@ -447,8 +447,8 @@ impl<E: tch::kind::Element + Copy + Default> TchOps<E> {
TchTensor::binary_ops_tensor(
tensor,
exponent,
|lhs, rhs| lhs.f_pow_tensor_(rhs).unwrap(),
|lhs, rhs| lhs.f_pow(rhs).unwrap(),
|lhs, rhs| rhs.f_pow(lhs).unwrap(),
|lhs, rhs| lhs.f_pow(rhs).unwrap(),
)
}
Expand Down
4 changes: 2 additions & 2 deletions burn-tensor/src/tests/ops/powf.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,10 +7,10 @@ mod tests {
fn should_support_powf_ops() {
let data = Data::from([[1.0, 1.0, 2.0], [3.0, 4.0, 5.0]]);
let tensor = Tensor::<TestBackend, 2>::from_data(data, &Default::default());
let pow = Data::from([[1.0, 1.0, 2.0], [3.0, 4.0, 5.0]]);
let pow = Data::from([[1.0, 1.0, 2.0], [3.0, 4.0, 2.0]]);
let tensor_pow = Tensor::<TestBackend, 2>::from_data(pow, &Default::default());
let data_actual = tensor.powf(tensor_pow).into_data();
let data_expected = Data::from([[1.0, 1.0, 4.0], [27.0, 256.0, 3125.0]]);
let data_expected = Data::from([[1.0, 1.0, 4.0], [27.0, 256.0, 25.0]]);
data_expected.assert_approx_eq(&data_actual, 3);
}

Expand Down

0 comments on commit e03facc

Please sign in to comment.