Feat/fusion/wgpu autotune (tracel-ai#1188)
* wip

* WIP

* Update

* Use read write lock

* Refactor autotune

* Cleanup

* Add docs
nathanielsimard authored Jan 30, 2024
1 parent 8b4038d commit b7486b0
Showing 23 changed files with 504 additions and 162 deletions.
2 changes: 1 addition & 1 deletion backend-comparison/Cargo.toml
@@ -23,7 +23,7 @@ ndarray-blas-openblas = ["burn/ndarray", "burn/openblas"]
tch-cpu = ["burn/tch"]
tch-gpu = ["burn/tch"]
wgpu = ["burn/wgpu"]
wgpu-fusion = ["burn/wgpu", "burn/fusion"]
wgpu-fusion = ["burn/default", "burn/wgpu", "burn/fusion"]

[dependencies]
burn = { path = "../burn" }
16 changes: 11 additions & 5 deletions burn-compute/src/client.rs
@@ -6,15 +6,15 @@ use crate::{
use alloc::vec::Vec;
use alloc::{boxed::Box, sync::Arc};
use burn_common::reader::Reader;
+use burn_common::stub::RwLock;
use core::marker::PhantomData;
-use spin::Mutex;

/// The ComputeClient is the entry point to require tasks from the ComputeServer.
/// It should be obtained for a specific device via the Compute struct.
#[derive(Debug)]
pub struct ComputeClient<Server: ComputeServer, Channel> {
channel: Channel,
-tuner: Arc<Mutex<Tuner<Server, Channel>>>,
+tuner: Arc<RwLock<Tuner<Server, Channel>>>,
_server: PhantomData<Server>,
}

@@ -38,7 +38,7 @@ where
Channel: ComputeChannel<Server>,
{
/// Create a new client.
-pub fn new(channel: Channel, tuner: Arc<Mutex<Tuner<Server, Channel>>>) -> Self {
+pub fn new(channel: Channel, tuner: Arc<RwLock<Tuner<Server, Channel>>>) -> Self {
Self {
channel,
tuner,
@@ -72,12 +72,18 @@ where
}

/// Executes the fastest kernel in the autotune operation, using (cached) runtime benchmarks
-pub fn execute_autotune(
+pub fn autotune_execute(
&self,
autotune_operation_set: Box<dyn AutotuneOperationSet<Server::AutotuneKey>>,
) {
self.tuner
-.lock()
+.write()
+.unwrap()
.execute_autotune(autotune_operation_set, self);
}

+/// Get the fastest kernel for the given autotune key if it exists.
+pub fn autotune_result(&self, key: &Server::AutotuneKey) -> Option<usize> {
+self.tuner.read().unwrap().autotune_fastest(key)
+}
}
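
Taken together, the client changes swap the tuner's `Mutex` for a `RwLock` and add a read-only lookup next to the exclusive tuning path. Below is a minimal sketch of that locking split, using `std::sync::RwLock` and a hypothetical `Tuner` stand-in (the real one lives in `burn-compute` and benchmarks actual kernels):

```rust
use std::collections::HashMap;
use std::sync::{Arc, RwLock};
use std::thread;

// Hypothetical stand-in for the real Tuner; only the locking pattern
// mirrors the change in this file.
struct Tuner {
    cache: HashMap<String, usize>,
}

impl Tuner {
    fn execute_autotune(&mut self, key: &str, fastest_index: usize) {
        // Benchmarking elided: record the winning kernel index.
        self.cache.insert(key.to_string(), fastest_index);
    }

    fn autotune_fastest(&self, key: &str) -> Option<usize> {
        self.cache.get(key).copied()
    }
}

fn main() {
    let tuner = Arc::new(RwLock::new(Tuner { cache: HashMap::new() }));

    // Write path (autotune_execute): exclusive lock while tuning.
    tuner.write().unwrap().execute_autotune("matmul-512", 2);

    // Read path (autotune_result): shared lock, so lookups from many
    // threads proceed in parallel where a Mutex would serialize them.
    let handles: Vec<_> = (0..4)
        .map(|_| {
            let tuner = Arc::clone(&tuner);
            thread::spawn(move || tuner.read().unwrap().autotune_fastest("matmul-512"))
        })
        .collect();

    for handle in handles {
        assert_eq!(handle.join().unwrap(), Some(2));
    }
}
```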
19 changes: 19 additions & 0 deletions burn-compute/src/tune/tune_cache.rs
@@ -91,6 +91,25 @@ impl<K: AutotuneKey> TuneCache<K> {
}
}

+pub(crate) fn find_fastest(&self, key: &K) -> Option<usize> {
+let result = self.in_memory_cache.get(key);
+
+let val = match result {
+Some(val) => val,
+None => return None,
+};
+
+#[cfg(feature = "autotune-persistent-cache")]
+if val.checksum_checked {
+Some(val.fastest_index)
+} else {
+None
+}
+
+#[cfg(not(feature = "autotune-persistent-cache"))]
+Some(val.fastest_index)
+}

pub(crate) fn try_cache(
&mut self,
autotune_operation_set: Box<dyn AutotuneOperationSet<K>>,
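
The new `find_fastest` only trusts a cache entry once its checksum has been validated when the persistent cache is compiled in; without that feature, any in-memory hit is returned directly. A self-contained sketch of that `cfg` gating follows (the `CacheEntry` and `TuneCache` shapes here are simplified stand-ins for the real types):

```rust
use std::collections::HashMap;

struct CacheEntry {
    fastest_index: usize,
    // Entries loaded from disk must be checksum-validated before use;
    // this field only exists in persistent-cache builds.
    #[cfg(feature = "autotune-persistent-cache")]
    checksum_checked: bool,
}

struct TuneCache {
    in_memory_cache: HashMap<String, CacheEntry>,
}

impl TuneCache {
    fn find_fastest(&self, key: &str) -> Option<usize> {
        let val = self.in_memory_cache.get(key)?;

        // With the persistent cache enabled, an unvalidated entry is
        // treated as a miss so it gets re-benchmarked.
        #[cfg(feature = "autotune-persistent-cache")]
        if val.checksum_checked {
            Some(val.fastest_index)
        } else {
            None
        }

        // Without it, every in-memory hit is authoritative.
        #[cfg(not(feature = "autotune-persistent-cache"))]
        Some(val.fastest_index)
    }
}

fn main() {
    let mut cache = TuneCache { in_memory_cache: HashMap::new() };
    cache.in_memory_cache.insert(
        "reduce-1024".to_string(),
        CacheEntry {
            fastest_index: 1,
            #[cfg(feature = "autotune-persistent-cache")]
            checksum_checked: true,
        },
    );
    assert_eq!(cache.find_fastest("reduce-1024"), Some(1));
    assert_eq!(cache.find_fastest("missing"), None);
}
```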
4 changes: 4 additions & 0 deletions burn-compute/src/tune/tuner.rs
@@ -32,6 +32,10 @@ impl<S: ComputeServer, C: ComputeChannel<S>> Tuner<S, C> {
}
}

+pub(crate) fn autotune_fastest(&self, key: &S::AutotuneKey) -> Option<usize> {
+self.tune_cache.find_fastest(key)
+}

pub(crate) fn execute_autotune(
&mut self,
autotune_operation_set: Box<dyn AutotuneOperationSet<S::AutotuneKey>>,
4 changes: 2 additions & 2 deletions burn-compute/tests/dummy/compute.rs
@@ -1,13 +1,13 @@
use std::sync::Arc;

use super::DummyServer;
+use burn_common::stub::RwLock;
use burn_compute::channel::MutexComputeChannel;
use burn_compute::client::ComputeClient;
use burn_compute::memory_management::{DeallocStrategy, SimpleMemoryManagement, SliceStrategy};
use burn_compute::storage::BytesStorage;
use burn_compute::tune::Tuner;
use burn_compute::Compute;
-use spin::Mutex;

/// The dummy device.
#[derive(Clone, Debug, Hash, PartialEq, Eq)]
@@ -25,7 +25,7 @@ pub fn init_client() -> ComputeClient<DummyServer, MutexComputeChannel<DummyServ
SimpleMemoryManagement::new(storage, DeallocStrategy::Never, SliceStrategy::Never);
let server = DummyServer::new(memory_management);
let channel = MutexComputeChannel::new(server);
-let tuner = Arc::new(Mutex::new(Tuner::new(TUNER_DEVICE_ID)));
+let tuner = Arc::new(RwLock::new(Tuner::new(TUNER_DEVICE_ID)));
ComputeClient::new(channel, tuner)
}

22 changes: 11 additions & 11 deletions burn-compute/tests/integration_test.rs
@@ -56,7 +56,7 @@ fn autotune_basic_addition_execution() {

let addition_autotune_kernel =
dummy::AdditionAutotuneOperationSet::new(client.clone(), shapes, handles);
-client.execute_autotune(Box::new(addition_autotune_kernel));
+client.autotune_execute(Box::new(addition_autotune_kernel));

let obtained_resource = client.read(&out);

@@ -78,7 +78,7 @@ fn autotune_basic_multiplication_execution() {

let multiplication_autotune_kernel =
dummy::MultiplicationAutotuneOperationSet::new(client.clone(), shapes, handles);
-client.execute_autotune(Box::new(multiplication_autotune_kernel));
+client.autotune_execute(Box::new(multiplication_autotune_kernel));

let obtained_resource = client.read(&out);

@@ -115,8 +115,8 @@ fn autotune_cache_same_key_return_a_cache_hit() {
dummy::CacheTestAutotuneOperationSet::new(client.clone(), shapes_1, handles_1);
let cache_test_autotune_kernel_2 =
dummy::CacheTestAutotuneOperationSet::new(client.clone(), shapes_2, handles_2);
-client.execute_autotune(Box::new(cache_test_autotune_kernel_1));
-client.execute_autotune(Box::new(cache_test_autotune_kernel_2));
+client.autotune_execute(Box::new(cache_test_autotune_kernel_1));
+client.autotune_execute(Box::new(cache_test_autotune_kernel_2));

let obtained_resource = client.read(&out_2);

@@ -155,8 +155,8 @@ fn autotune_cache_no_cache_on_disk_return_a_cache_miss() {
dummy::CacheTestAutotuneOperationSet::new(client.clone(), shapes_1, handles_1);
let cache_test_autotune_kernel_2 =
dummy::CacheTestAutotuneOperationSet::new(client.clone(), shapes_2, handles_2);
-client.execute_autotune(Box::new(cache_test_autotune_kernel_1));
-client.execute_autotune(Box::new(cache_test_autotune_kernel_2));
+client.autotune_execute(Box::new(cache_test_autotune_kernel_1));
+client.autotune_execute(Box::new(cache_test_autotune_kernel_2));

// read the resource which should update the cache on disk
let obtained_resource = client.read(&out_2);
@@ -192,7 +192,7 @@ fn autotune_cache_file_path_creation_works_when_path_does_not_exist_yet() {

let cache_test_autotune_kernel =
dummy::CacheTestAutotuneOperationSet::new(client.clone(), shapes, handles);
-client.execute_autotune(Box::new(cache_test_autotune_kernel));
+client.autotune_execute(Box::new(cache_test_autotune_kernel));
// ensure that the autotune operations are run and cached
let _obtained_resource = client.read(&out);

@@ -227,8 +227,8 @@ fn autotune_cache_different_keys_return_a_cache_miss() {
dummy::CacheTestAutotuneOperationSet::new(client.clone(), shapes_1, handles_1);
let cache_test_autotune_kernel_2 =
dummy::CacheTestAutotuneOperationSet::new(client.clone(), shapes_2, handles_2);
-client.execute_autotune(Box::new(cache_test_autotune_kernel_1));
-client.execute_autotune(Box::new(cache_test_autotune_kernel_2));
+client.autotune_execute(Box::new(cache_test_autotune_kernel_1));
+client.autotune_execute(Box::new(cache_test_autotune_kernel_2));

let obtained_resource = client.read(&out_2);

@@ -253,7 +253,7 @@ fn autotune_cache_different_checksums_return_a_cache_miss() {
let handles_1 = vec![lhs_1, rhs_1, out_1];
let cache_test_autotune_kernel_1 =
dummy::CacheTestAutotuneOperationSet::new(client.clone(), shapes_1, handles_1);
-client.execute_autotune(Box::new(cache_test_autotune_kernel_1));
+client.autotune_execute(Box::new(cache_test_autotune_kernel_1));
client.sync();

// we use a second compute client in order to have freshly initialized autotune cache
Expand All @@ -272,7 +272,7 @@ fn autotune_cache_different_checksums_return_a_cache_miss() {
let mut cache_test_autotune_kernel_2 =
dummy::CacheTestAutotuneOperationSet::new(client.clone(), shapes_2, handles_2);
cache_test_autotune_kernel_2.generate_random_checksum = true;
-client.execute_autotune(Box::new(cache_test_autotune_kernel_2));
+client.autotune_execute(Box::new(cache_test_autotune_kernel_2));
client.sync();

let obtained_resource = client.read(&out_2);
2 changes: 1 addition & 1 deletion burn-wgpu/Cargo.toml
@@ -11,7 +11,7 @@ repository = "https://github.com/tracel-ai/burn/tree/main/burn-wgpu"
version.workspace = true

[features]
default = ["autotune", "std"]
default = ["autotune", "std", "burn-compute/default"]
std = []
autotune = []
fusion = ["burn-fusion"]
12 changes: 11 additions & 1 deletion burn-wgpu/src/codegen/kernel.rs
@@ -61,6 +61,7 @@ pub struct ElemWiseKernelCodegen<Phase = InputPhase> {
functions: Vec<Function>,
vectorization: Vectorization,
mappings_inplace: Vec<InplaceMapping>,
+workgroup_size: WorkgroupSize,
_phase: PhantomData<Phase>,
}

@@ -99,6 +100,7 @@ impl Default for ElemWiseKernelCodegen<InputPhase> {
functions: Vec::new(),
vectorization: Vectorization::Scalar,
mappings_inplace: Vec::new(),
+workgroup_size: WorkgroupSize::default(),
_phase: PhantomData,
}
}
@@ -183,6 +185,7 @@ impl ElemWiseKernelCodegen<InputPhase> {
functions: self.functions,
vectorization: self.vectorization,
mappings_inplace: self.mappings_inplace,
+workgroup_size: self.workgroup_size,
_phase: PhantomData,
}
}
@@ -231,6 +234,7 @@ impl ElemWiseKernelCodegen<BodyPhase> {
vectorization: self.vectorization,
functions: self.functions,
mappings_inplace: self.mappings_inplace,
+workgroup_size: self.workgroup_size,
_phase: PhantomData,
}
}
@@ -329,12 +333,18 @@ impl ElemWiseKernelCodegen<OutputPhase> {
functions: self.functions,
vectorization: self.vectorization,
mappings_inplace: self.mappings_inplace,
+workgroup_size: self.workgroup_size,
_phase: PhantomData,
}
}
}

impl ElemWiseKernelCodegen<CompilationPhase> {
+pub fn workgroup_size(mut self, workgroup_size: WorkgroupSize) -> Self {
+self.workgroup_size = workgroup_size;
+self
+}

/// Compile the kernel into a [compute shader](ComputeShader).
pub fn compile(self) -> ComputeShader {
let inputs = self.input_bindings;
@@ -360,7 +370,7 @@ impl ElemWiseKernelCodegen<CompilationPhase> {
inputs,
outputs,
named,
-workgroup_size: WorkgroupSize::default(),
+workgroup_size: self.workgroup_size,
body: Body::new(self.operations),
num_workgroups: true,
global_invocation_id: true,
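
The codegen now carries a `workgroup_size` field through every phase transition and only exposes an override in the final `CompilationPhase`, replacing the `WorkgroupSize::default()` that was previously hard-coded at shader-build time. A reduced sketch of that typestate plumbing (phase names mirror the diff; the fields, the default of 32, and the `compile` return type are simplified stand-ins):

```rust
use std::marker::PhantomData;

#[derive(Clone, Copy, Debug, PartialEq)]
struct WorkgroupSize {
    x: u32,
    y: u32,
    z: u32,
}

impl Default for WorkgroupSize {
    fn default() -> Self {
        // Placeholder default; the real code uses WORKGROUP_DEFAULT.
        Self { x: 32, y: 32, z: 1 }
    }
}

struct InputPhase;
struct CompilationPhase;

struct Codegen<Phase> {
    workgroup_size: WorkgroupSize,
    _phase: PhantomData<Phase>,
}

impl Codegen<InputPhase> {
    fn new() -> Self {
        Self {
            workgroup_size: WorkgroupSize::default(),
            _phase: PhantomData,
        }
    }

    // Each phase transition copies the field forward, just as the diff
    // does across the input/body/output phases.
    fn inputs(self) -> Codegen<CompilationPhase> {
        Codegen {
            workgroup_size: self.workgroup_size,
            _phase: PhantomData,
        }
    }
}

impl Codegen<CompilationPhase> {
    // Only the final phase exposes the override, right before compile().
    fn workgroup_size(mut self, workgroup_size: WorkgroupSize) -> Self {
        self.workgroup_size = workgroup_size;
        self
    }

    fn compile(self) -> WorkgroupSize {
        // The real compile() builds a ComputeShader; here we just return
        // the size that would be embedded in it.
        self.workgroup_size
    }
}

fn main() {
    let size = Codegen::new()
        .inputs()
        .workgroup_size(WorkgroupSize { x: 16, y: 16, z: 1 })
        .compile();
    assert_eq!(size, WorkgroupSize { x: 16, y: 16, z: 1 });
}
```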
12 changes: 6 additions & 6 deletions burn-wgpu/src/codegen/shader.rs
@@ -63,18 +63,18 @@ pub struct Binding {
pub size: Option<usize>,
}

-#[derive(Debug, PartialEq, Eq, Clone, Serialize, Deserialize)]
+#[derive(new, Debug, PartialEq, Eq, Clone, Copy, Serialize, Deserialize)]
pub struct WorkgroupSize {
-pub x: usize,
-pub y: usize,
-pub z: usize,
+pub x: u32,
+pub y: u32,
+pub z: u32,
}

impl Default for WorkgroupSize {
fn default() -> Self {
Self {
-x: WORKGROUP_DEFAULT,
-y: WORKGROUP_DEFAULT,
+x: WORKGROUP_DEFAULT as u32,
+y: WORKGROUP_DEFAULT as u32,
z: 1,
}
}
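
`WorkgroupSize` moves to `u32` fields (the type wgpu's dispatch API ultimately takes), becomes `Copy` so it can be threaded through the codegen phases by value, and picks up `#[derive(new)]`. Assuming the usual `derive-new` semantics, the generated constructor is equivalent to this hand-written version (serde derives omitted to keep the sketch dependency-free):

```rust
#[derive(Debug, PartialEq, Eq, Clone, Copy)]
pub struct WorkgroupSize {
    pub x: u32,
    pub y: u32,
    pub z: u32,
}

// What #[derive(new)] expands to, written out by hand: a constructor
// taking the fields in declaration order.
impl WorkgroupSize {
    pub fn new(x: u32, y: u32, z: u32) -> Self {
        Self { x, y, z }
    }
}

fn main() {
    let wg = WorkgroupSize::new(16, 16, 1);
    let copied = wg; // Copy: `wg` remains usable after this assignment.
    assert_eq!(wg, copied);
}
```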
4 changes: 2 additions & 2 deletions burn-wgpu/src/compute/base.rs
@@ -1,14 +1,14 @@
use super::WgpuServer;
use crate::{compute::WgpuStorage, GraphicsApi, WgpuDevice};
use alloc::sync::Arc;
+use burn_common::stub::RwLock;
use burn_compute::{
channel::MutexComputeChannel,
client::ComputeClient,
memory_management::{DeallocStrategy, SimpleMemoryManagement, SliceStrategy},
tune::Tuner,
Compute,
};
-use spin::Mutex;
use wgpu::{AdapterInfo, DeviceDescriptor};

type MemoryManagement = SimpleMemoryManagement<WgpuStorage>;
@@ -69,7 +69,7 @@ async fn create_client<G: GraphicsApi>(device: &WgpuDevice) -> ComputeClient<Ser
let channel = Channel::new(server);

let tuner_device_id = tuner_device_id(info);
-ComputeClient::new(channel, Arc::new(Mutex::new(Tuner::new(&tuner_device_id))))
+ComputeClient::new(channel, Arc::new(RwLock::new(Tuner::new(&tuner_device_id))))
}

/// Select the wgpu device and queue based on the provided [device](WgpuDevice).
28 changes: 28 additions & 0 deletions burn-wgpu/src/compute/server.rs
@@ -46,6 +46,34 @@ pub trait Kernel: 'static + Send + Sync {
fn workgroup(&self) -> WorkGroup;
}

+impl Kernel for Arc<dyn Kernel> {
+fn source(&self) -> SourceTemplate {
+self.as_ref().source()
+}
+
+fn id(&self) -> String {
+self.as_ref().id()
+}
+
+fn workgroup(&self) -> WorkGroup {
+self.as_ref().workgroup()
+}
+}
+
+impl Kernel for Box<dyn Kernel> {
+fn source(&self) -> SourceTemplate {
+self.as_ref().source()
+}
+
+fn id(&self) -> String {
+self.as_ref().id()
+}
+
+fn workgroup(&self) -> WorkGroup {
+self.as_ref().workgroup()
+}
+}

impl<MM> WgpuServer<MM>
where
MM: MemoryManagement<WgpuStorage>,
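
Implementing `Kernel` for `Arc<dyn Kernel>` and `Box<dyn Kernel>` lets a kernel chosen at runtime (which is exactly what autotuning produces) flow through APIs written against a generic `K: Kernel`. A reduced sketch of the delegation pattern with a toy trait:

```rust
trait Kernel {
    fn id(&self) -> String;
}

struct MatmulTiled;

impl Kernel for MatmulTiled {
    fn id(&self) -> String {
        "matmul_tiled".to_string()
    }
}

// Delegate through the pointer so the type-erased form satisfies the
// same bound as a concrete kernel; the Arc<dyn Kernel> impl is analogous.
impl Kernel for Box<dyn Kernel> {
    fn id(&self) -> String {
        self.as_ref().id()
    }
}

// A server-style entry point that only knows about K: Kernel.
fn execute<K: Kernel>(kernel: K) -> String {
    kernel.id()
}

fn main() {
    // The winning kernel is only known at runtime, e.g. after autotuning.
    let erased: Box<dyn Kernel> = Box::new(MatmulTiled);
    assert_eq!(execute(erased), "matmul_tiled");
}
```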
12 changes: 9 additions & 3 deletions burn-wgpu/src/compute/tune_key.rs
@@ -1,9 +1,10 @@
-use crate::kernel::{matmul::MatmulAutotuneKey, reduce::ReduceAutotuneKey};
-use burn_compute::tune::AutotuneKey;
use serde::{Deserialize, Serialize};
use std::fmt::Display;

+use burn_compute::tune::AutotuneKey;
+
+use crate::kernel::{matmul::MatmulAutotuneKey, reduce::ReduceAutotuneKey};
+#[cfg(any(feature = "fusion", test))]
+use crate::fusion::FusionElemWiseAutotuneKey;

#[derive(Hash, Eq, PartialEq, Debug, Clone, Serialize, Deserialize)]
/// Key for all autotune-enabled operations
@@ -14,6 +15,9 @@ pub enum WgpuAutotuneKey {
SumDim(ReduceAutotuneKey),
/// Key for mean_dim operations
MeanDim(ReduceAutotuneKey),
+#[cfg(any(feature = "fusion", test))]
+/// Key for fused element wise operations.
+FusionElemWise(FusionElemWiseAutotuneKey),
}

impl Display for WgpuAutotuneKey {
@@ -22,6 +26,8 @@ impl Display for WgpuAutotuneKey {
WgpuAutotuneKey::Matmul(matmul_key) => std::fmt::Display::fmt(&matmul_key, f),
WgpuAutotuneKey::SumDim(reduce_key) => std::fmt::Display::fmt(&reduce_key, f),
WgpuAutotuneKey::MeanDim(reduce_key) => std::fmt::Display::fmt(&reduce_key, f),
+#[cfg(any(feature = "fusion", test))]
+WgpuAutotuneKey::FusionElemWise(reduce_key) => std::fmt::Display::fmt(&reduce_key, f),
}
}
}
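
Gating the new variant behind `#[cfg(any(feature = "fusion", test))]` means both the enum and its exhaustive `Display` match compile whether or not fusion is enabled, as long as every gated variant has an equally gated match arm. A self-contained sketch of the pattern:

```rust
use std::fmt::Display;

#[derive(Debug)]
enum AutotuneKey {
    Matmul(String),
    #[cfg(feature = "fusion")]
    FusionElemWise(String),
}

impl Display for AutotuneKey {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // Every cfg'd variant needs an equally cfg'd arm, otherwise the
        // match stops being exhaustive when the feature is enabled.
        match self {
            AutotuneKey::Matmul(key) => write!(f, "matmul-{key}"),
            #[cfg(feature = "fusion")]
            AutotuneKey::FusionElemWise(key) => write!(f, "fusion-elemwise-{key}"),
        }
    }
}

fn main() {
    println!("{}", AutotuneKey::Matmul("512x512".into()));
}
```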
(Diffs for the remaining 11 changed files were not loaded.)
