From 2668e78b6d6de8d1e87068e36cc6da0a221aeb75 Mon Sep 17 00:00:00 2001 From: Kevin Oberlies Date: Mon, 27 Nov 2023 16:06:55 -0800 Subject: [PATCH] Removed `num_traits::Num` requirement from Zeros. Had to figure out a way to store zeros in place --- dfdx-core/src/tensor/webgpu/allocate.rs | 16 ++++------------ dfdx-core/src/tensor_ops/utilities/device.rs | 4 ++-- 2 files changed, 6 insertions(+), 14 deletions(-) diff --git a/dfdx-core/src/tensor/webgpu/allocate.rs b/dfdx-core/src/tensor/webgpu/allocate.rs index c90d255b..52b22724 100644 --- a/dfdx-core/src/tensor/webgpu/allocate.rs +++ b/dfdx-core/src/tensor/webgpu/allocate.rs @@ -52,28 +52,20 @@ impl Webgpu { } } -impl> ZerosTensor for Webgpu { +impl ZerosTensor for Webgpu { fn try_zeros_like(&self, src: &S) -> Result, Error> { let shape = *src.shape(); let strides = shape.strides(); let data = unsafe { self.alloc_empty::(shape.num_elements()) }?; - data.copy_to_device( - &self.dev, - &self.queue, - &vec![E::from(false); shape.num_elements()], - ); + data.copy_to_device(&self.dev, &self.queue, &vec![0u8; data.size()]); Ok(self.build_tensor(shape, strides, data)) } } -impl> ZeroFillStorage for Webgpu { +impl ZeroFillStorage for Webgpu { fn try_fill_with_zeros(&self, storage: &mut Self::Vec) -> Result<(), Error> { - storage.copy_to_device( - &self.dev, - &self.queue, - &vec![E::from(false); storage.size() as usize / std::mem::size_of::()], - ); + storage.copy_to_device(&self.dev, &self.queue, &vec![0u8; storage.size()]); Ok(()) } diff --git a/dfdx-core/src/tensor_ops/utilities/device.rs b/dfdx-core/src/tensor_ops/utilities/device.rs index 5abe4d9f..277be7a6 100644 --- a/dfdx-core/src/tensor_ops/utilities/device.rs +++ b/dfdx-core/src/tensor_ops/utilities/device.rs @@ -132,9 +132,9 @@ impl Device for crate::tensor::Cuda {} impl Device for crate::tensor::Cuda {} #[cfg(all(feature = "webgpu", feature = "f16"))] -impl Device for crate::tensor::Cuda {} +impl Device for crate::tensor::Webgpu {} #[cfg(all(feature = "webgpu", feature = "f16"))] -impl Device> for crate::tensor::Cuda {} +impl Device> for crate::tensor::Webgpu {} #[cfg(feature = "webgpu")] impl Device for crate::tensor::Webgpu {} #[cfg(feature = "webgpu")]