Implement base Webgpu device (#893)
* Implement Webgpu device.

All the tests pass, but it doesn't implement `Device` yet, because that
is a lot more work.

* Removed some of the lower-level commands in favor of a wrapper struct

Also added tests for higher code coverage.

* Fix unsound `AtomicPtr` usage

* Partial implementation of `Device<E>` for Webgpu

* Finished up with template tensor_ops

* Remove foolish Mutex

* Cargo fmt

* Adding f16 to webgpu

* Add Mutex back, since removing it evidently caused issues.

Hopefully I can figure out a way to remove it again.

* Removed the `num_traits::Num` requirement from the zero-filling traits.

Had to figure out a way to store zeros in place (the buffers are now filled with raw zero bytes).

* Fixed no-std support
favilo authored Dec 3, 2023
1 parent 2ff7b60 commit dda8daa
Showing 98 changed files with 2,362 additions and 1 deletion.
4 changes: 4 additions & 0 deletions dfdx-core/Cargo.toml
@@ -38,6 +38,9 @@ half = { version = "2.3.1", optional = true, features = ["num-traits", "rand_dis
gemm = { version = "0.16.14", default-features = false, optional = true, features = ["rayon"] }
rayon = { version = "1.7.0", optional = true }
libm = { workspace = true }
wgpu = { version = "0.18.0", optional = true }
futures-lite = { version = "2.0.1", optional = true }
thingbuf = { version = "0.1.4", optional = true }

[dev-dependencies]
tempfile = "3.3.0"
@@ -59,6 +62,7 @@ fast-alloc = ["std"]

cuda = ["dep:cudarc", "dep:glob"]
cudnn = ["cuda", "cudarc?/cudnn"]
webgpu = ["dep:wgpu", "dep:futures-lite", "dep:thingbuf", "wgpu/expose-ids"]

f16 = ["dep:half", "cudarc?/f16", "gemm?/f16"]

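The backend is opt-in: none of the wgpu code is compiled unless the `webgpu` feature is enabled. For example, to run this crate's tests against the new device (command-line sketch):

    cargo test -p dfdx-core --features webgpu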
6 changes: 6 additions & 0 deletions dfdx-core/src/tensor/error.rs
@@ -16,6 +16,12 @@ pub enum Error {

#[cfg(feature = "cudnn")]
CudnnError(cudarc::cudnn::CudnnError),

#[cfg(feature = "webgpu")]
WebgpuAdapterNotFound,

#[cfg(feature = "webgpu")]
WebgpuRequestDeviceError(wgpu::RequestDeviceError),
}

impl std::fmt::Display for Error {
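Downstream code can branch on the two new variants when device setup fails. A minimal sketch (the `describe` helper is hypothetical and assumes `Error` is re-exported from `dfdx_core::tensor` as usual):

    // Requires building with dfdx-core's `webgpu` feature enabled.
    use dfdx_core::tensor::Error;

    fn describe(err: &Error) -> String {
        match err {
            // No suitable GPU adapter could be found on this system.
            Error::WebgpuAdapterNotFound => "no compatible wgpu adapter was found".to_string(),
            // The adapter was found, but it refused to create a device.
            Error::WebgpuRequestDeviceError(e) => format!("wgpu refused to create a device: {e}"),
            other => format!("{other}"),
        }
    }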
9 changes: 8 additions & 1 deletion dfdx-core/src/tensor/mod.rs
@@ -145,6 +145,8 @@ mod gradients;
mod masks;
#[cfg(feature = "numpy")]
pub(crate) mod numpy;
#[cfg(feature = "webgpu")]
pub(crate) mod webgpu;
#[cfg(feature = "numpy")]
pub use numpy::NumpyDtype;
mod error;
@@ -162,7 +164,7 @@ pub(crate) use storage_traits::{OneFillStorage, ZeroFillStorage};
pub use tensorlike::Tensorlike;

pub use cpu::Cpu;
#[cfg(not(feature = "cuda"))]
#[cfg(not(any(feature = "cuda", feature = "webgpu")))]
pub type AutoDevice = Cpu;

#[cfg(feature = "cuda")]
@@ -172,6 +174,11 @@ pub use cuda::Cuda;
#[cfg(feature = "cuda")]
pub type AutoDevice = Cuda;

#[cfg(feature = "webgpu")]
pub use webgpu::Webgpu;
#[cfg(feature = "webgpu")]
pub type AutoDevice = Webgpu;

pub use storage_traits::{AsArray, CopySlice, TensorFrom, TensorFromVec, TensorToArray};
pub use storage_traits::{Cache, RandomU64, Storage, Synchronize};
pub use storage_traits::{OnesTensor, SampleTensor, TriangleTensor, ZerosTensor};
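With `webgpu` enabled (and `cuda` off), `AutoDevice` now resolves to `Webgpu`, so device-generic code picks up the new backend without changes. A rough sketch, assuming `Webgpu` provides a `Default` constructor like the other devices:

    use dfdx_core::tensor::{AsArray, AutoDevice, TensorFrom};

    fn main() {
        // Resolves to Webgpu when the `webgpu` feature is enabled.
        let dev = AutoDevice::default();
        let t = dev.tensor([1.0f32, 2.0, 3.0]);
        assert_eq!(t.array(), [1.0, 2.0, 3.0]);
    }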
221 changes: 221 additions & 0 deletions dfdx-core/src/tensor/webgpu/allocate.rs
@@ -0,0 +1,221 @@
#![allow(clippy::needless_range_loop)]

use crate::{
shapes::*,
tensor::{masks::triangle_mask, storage_traits::*, unique_id, Cpu, Error, NoneTape, Tensor},
};

use super::{device::CachableBuffer, Buffer, Webgpu};

use core::marker::PhantomData;
use rand::Rng;
use std::{sync::Arc, vec::Vec};
use wgpu::COPY_BUFFER_ALIGNMENT;

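/// Round `size` up to the next multiple of `wgpu::COPY_BUFFER_ALIGNMENT` (4 bytes at the time
/// of writing), since wgpu requires buffer copy sizes and offsets to be aligned; e.g. 10 -> 12.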
pub(crate) fn round_to_buffer_alignment(size: u64) -> u64 {
(size + (COPY_BUFFER_ALIGNMENT - 1)) / COPY_BUFFER_ALIGNMENT * COPY_BUFFER_ALIGNMENT
}

impl Webgpu {
fn tensor_from_host_buf<S: Shape, E: Unit>(
&self,
shape: S,
buf: Vec<E>,
) -> Result<Tensor<S, E, Self>, Error> {
let buffer = unsafe { self.alloc_empty::<E>(buf.len()) }?;
buffer.copy_to_device::<E>(&self.dev, &self.queue, &buf);

Ok(self.build_tensor(shape, shape.strides(), buffer))
}

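/// Wrap an already-allocated GPU `Buffer` in a `Tensor`, carrying along the device, queue,
/// and buffer-cache handles that `CachableBuffer` needs.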
pub(crate) fn build_tensor<S: Shape, E: Unit>(
&self,
shape: S,
strides: S::Concrete,
buffer: Buffer,
) -> Tensor<S, E, Self> {
let data = CachableBuffer {
dev: self.dev.clone(),
queue: self.queue.clone(),
data: buffer,
cache: self.cache.clone(),
_phantom: PhantomData,
};
Tensor {
id: unique_id(),
data: Arc::new(data),
shape,
strides,
device: self.clone(),
tape: Default::default(),
}
}
}

impl<E: Unit + SafeZeros> ZerosTensor<E> for Webgpu {
fn try_zeros_like<S: HasShape>(&self, src: &S) -> Result<Tensor<S::Shape, E, Self>, Error> {
let shape = *src.shape();
let strides = shape.strides();
let data = unsafe { self.alloc_empty::<E>(shape.num_elements()) }?;
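// Fill with raw zero bytes: an all-zeros bit pattern is a valid value for any `SafeZeros`
// type, which is what allowed the `num_traits::Num` bound to be dropped.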
data.copy_to_device(&self.dev, &self.queue, &vec![0u8; data.size()]);

Ok(self.build_tensor(shape, strides, data))
}
}

impl<E: Unit + SafeZeros> ZeroFillStorage<E> for Webgpu {
fn try_fill_with_zeros(&self, storage: &mut Self::Vec) -> Result<(), Error> {
storage.copy_to_device(&self.dev, &self.queue, &vec![0u8; storage.size()]);

Ok(())
}
}

impl<E: Unit> OnesTensor<E> for Webgpu {
fn try_ones_like<S: HasShape>(&self, src: &S) -> Result<Tensor<S::Shape, E, Self>, Error> {
let shape = *src.shape();
let buf = vec![E::ONE; shape.num_elements()];
self.tensor_from_host_buf(shape, buf)
}
}

impl<E: Unit> TriangleTensor<E> for Webgpu
where
Cpu: TriangleTensor<E>,
{
fn try_upper_tri_like<S: HasShape>(
&self,
src: &S,
val: E,
diagonal: impl Into<Option<isize>>,
) -> Result<Tensor<S::Shape, E, Self>, Error> {
let shape = *src.shape();
let mut data = vec![val; shape.num_elements()];
let offset = diagonal.into().unwrap_or(0);
triangle_mask(&mut data, &shape, true, offset);
self.tensor_from_host_buf(shape, data)
}

fn try_lower_tri_like<S: HasShape>(
&self,
src: &S,
val: E,
diagonal: impl Into<Option<isize>>,
) -> Result<Tensor<S::Shape, E, Self>, Error> {
let shape = *src.shape();
let mut data = vec![val; shape.num_elements()];
let offset = diagonal.into().unwrap_or(0);
triangle_mask(&mut data, &shape, false, offset);
self.tensor_from_host_buf(shape, data)
}
}

impl<E: Unit> OneFillStorage<E> for Webgpu {
fn try_fill_with_ones(&self, storage: &mut Self::Vec) -> Result<(), Error> {
let len = storage.size() as usize / std::mem::size_of::<E>();
let buf = vec![E::ONE; len];
storage
.data
.copy_to_device::<E>(&self.dev, &self.queue, &buf);

Ok(())
}
}

impl<E: Unit> SampleTensor<E> for Webgpu
where
Cpu: SampleTensor<E>,
{
fn try_sample_like<S: HasShape, D: rand::prelude::Distribution<E>>(
&self,
src: &S,
distr: D,
) -> Result<Tensor<S::Shape, E, Self>, Error> {
let shape = *src.shape();
let mut buf = Vec::with_capacity(shape.num_elements());
{
#[cfg(not(feature = "no-std"))]
let mut rng = self.cpu.rng.lock().unwrap();
#[cfg(feature = "no-std")]
let mut rng = self.cpu.rng.lock();
buf.resize_with(shape.num_elements(), || rng.sample(&distr));
}
self.tensor_from_host_buf::<S::Shape, E>(shape, buf)
}

fn try_fill_with_distr<D: rand::prelude::Distribution<E>>(
&self,
storage: &mut Self::Vec,
distr: D,
) -> Result<(), Error> {
let len = storage.size() as usize / std::mem::size_of::<E>();
let mut buf = Vec::with_capacity(len);
{
#[cfg(not(feature = "no-std"))]
let mut rng = self.cpu.rng.lock().unwrap();
#[cfg(feature = "no-std")]
let mut rng = self.cpu.rng.lock();
buf.resize_with(len, || rng.sample(&distr));
}
unsafe {
std::ptr::copy_nonoverlapping(
buf.as_ptr(),
storage.data.slice(..).get_mapped_range_mut().as_mut_ptr() as *mut E,
len,
)
};
Ok(())
}
}

impl<E: Unit> CopySlice<E> for Webgpu {
fn copy_from<S: Shape, T>(dst: &mut Tensor<S, E, Self, T>, src: &[E]) {
assert_eq!(
dst.data.size() as usize,
src.len() * std::mem::size_of::<E>(),
"Slices must have same number of elements as *physical* Storage<E> of tensors."
);
dst.data
.data
.copy_to_device(&dst.device.dev, &dst.device.queue, src);
}

fn copy_into<S: Shape, T>(src: &Tensor<S, E, Self, T>, dst: &mut [E]) {
assert_eq!(
src.data.size() as usize,
dst.len() * std::mem::size_of::<E>(),
"Slices must have same number of elements as *physical* Storage<E> of tensors."
);
src.data
.data
.copy_to_host(&src.device.dev, &src.device.queue, dst);
}
}
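These impls back the slice copies between host memory and GPU buffers, assuming the usual `copy_from`/`copy_into` tensor methods are layered on top as for the CPU and CUDA devices. A small sketch (the `roundtrip` helper is hypothetical and takes an already-constructed device):

    use dfdx_core::{
        shapes::Rank1,
        tensor::{Tensor, Webgpu, ZerosTensor},
    };

    fn roundtrip(dev: &Webgpu) -> [f32; 3] {
        // Allocate a zeroed tensor on the GPU, upload a host slice, then read it back.
        let mut t: Tensor<Rank1<3>, f32, Webgpu> = dev.zeros();
        t.copy_from(&[1.0, 2.0, 3.0]);
        let mut host = [0.0f32; 3];
        t.copy_into(&mut host);
        host
    }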

impl<E: Unit> TensorFromVec<E> for Webgpu {
fn try_tensor_from_vec<S: Shape>(
&self,
src: Vec<E>,
shape: S,
) -> Result<Tensor<S, E, Self>, Error> {
let num_elements = shape.num_elements();

if src.len() != num_elements {
Err(Error::WrongNumElements)
} else {
self.tensor_from_host_buf(shape, src)
}
}
}

impl<S: Shape, E: Unit> TensorToArray<S, E> for Webgpu
where
Cpu: TensorToArray<S, E> + Storage<E>,
{
type Array = <Cpu as TensorToArray<S, E>>::Array;
fn tensor_to_array<T>(&self, tensor: &Tensor<S, E, Self, T>) -> Self::Array {
let buf = tensor.as_vec();
let cpu_tensor = self.cpu.tensor_from_vec(buf, tensor.shape);
self.cpu.tensor_to_array::<NoneTape>(&cpu_tensor)
}
}
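Reads go through the CPU: `tensor_to_array` pulls the buffer back to host memory with `as_vec` and lets the `Cpu` device do the final reshaping. A rough usage sketch (the `to_array` helper is hypothetical):

    use dfdx_core::{
        shapes::Const,
        tensor::{AsArray, TensorFromVec, Webgpu},
    };

    fn to_array(dev: &Webgpu) -> [[f32; 2]; 2] {
        // Upload a flat host Vec as a 2x2 tensor, then read it back as a nested array.
        let t = dev.tensor_from_vec(vec![1.0f32, 2.0, 3.0, 4.0], (Const::<2>, Const::<2>));
        t.array()
    }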
