diff --git a/backend-comparison/Cargo.toml b/backend-comparison/Cargo.toml
index f82ce8ca6f..5e2b768a00 100644
--- a/backend-comparison/Cargo.toml
+++ b/backend-comparison/Cargo.toml
@@ -23,7 +23,7 @@ ndarray-blas-openblas = ["burn/ndarray", "burn/openblas"]
 tch-cpu = ["burn/tch"]
 tch-gpu = ["burn/tch"]
 wgpu = ["burn/wgpu"]
-wgpu-fusion = ["burn/wgpu", "burn/fusion"]
+wgpu-fusion = ["burn/default", "burn/wgpu", "burn/fusion"]
 
 [dependencies]
 burn = { path = "../burn" }
diff --git a/burn-compute/src/client.rs b/burn-compute/src/client.rs
index 3832652aeb..936a6c32b2 100644
--- a/burn-compute/src/client.rs
+++ b/burn-compute/src/client.rs
@@ -6,15 +6,15 @@ use crate::{
 use alloc::vec::Vec;
 use alloc::{boxed::Box, sync::Arc};
 use burn_common::reader::Reader;
+use burn_common::stub::RwLock;
 use core::marker::PhantomData;
-use spin::Mutex;
 
 /// The ComputeClient is the entry point to require tasks from the ComputeServer.
 /// It should be obtained for a specific device via the Compute struct.
 #[derive(Debug)]
 pub struct ComputeClient<Server: ComputeServer, Channel> {
     channel: Channel,
-    tuner: Arc<Mutex<Tuner<Server, Channel>>>,
+    tuner: Arc<RwLock<Tuner<Server, Channel>>>,
     _server: PhantomData<Server>,
 }
@@ -38,7 +38,7 @@ where
     Channel: ComputeChannel<Server>,
 {
     /// Create a new client.
-    pub fn new(channel: Channel, tuner: Arc<Mutex<Tuner<Server, Channel>>>) -> Self {
+    pub fn new(channel: Channel, tuner: Arc<RwLock<Tuner<Server, Channel>>>) -> Self {
         Self {
             channel,
             tuner,
@@ -72,12 +72,18 @@ where
     }
 
     /// Executes the fastest kernel in the autotune operation, using (cached) runtime benchmarks
-    pub fn execute_autotune(
+    pub fn autotune_execute(
        &self,
        autotune_operation_set: Box<dyn AutotuneOperationSet<Server::AutotuneKey>>,
    ) {
        self.tuner
-            .lock()
+            .write()
+            .unwrap()
            .execute_autotune(autotune_operation_set, self);
    }
+
+    /// Get the fastest kernel for the given autotune key if it exists.
+    pub fn autotune_result(&self, key: &Server::AutotuneKey) -> Option<usize> {
+        self.tuner.read().unwrap().autotune_fastest(key)
+    }
 }
diff --git a/burn-compute/src/tune/tune_cache.rs b/burn-compute/src/tune/tune_cache.rs
index d2ec64bc51..7200942525 100644
--- a/burn-compute/src/tune/tune_cache.rs
+++ b/burn-compute/src/tune/tune_cache.rs
@@ -91,6 +91,25 @@ impl<K: AutotuneKey> TuneCache<K> {
         }
     }
 
+    pub(crate) fn find_fastest(&self, key: &K) -> Option<usize> {
+        let result = self.in_memory_cache.get(key);
+
+        let val = match result {
+            Some(val) => val,
+            None => return None,
+        };
+
+        #[cfg(feature = "autotune-persistent-cache")]
+        if val.checksum_checked {
+            Some(val.fastest_index)
+        } else {
+            None
+        }
+
+        #[cfg(not(feature = "autotune-persistent-cache"))]
+        Some(val.fastest_index)
+    }
+
     pub(crate) fn try_cache(
         &mut self,
         autotune_operation_set: Box<dyn AutotuneOperationSet<K>>,
diff --git a/burn-compute/src/tune/tuner.rs b/burn-compute/src/tune/tuner.rs
index 75e7c5bd57..4fafa5adf0 100644
--- a/burn-compute/src/tune/tuner.rs
+++ b/burn-compute/src/tune/tuner.rs
@@ -32,6 +32,10 @@ impl<S: ComputeServer, C: ComputeChannel<S>> Tuner<S, C> {
         }
     }
 
+    pub(crate) fn autotune_fastest(&self, key: &S::AutotuneKey) -> Option<usize> {
+        self.tune_cache.find_fastest(key)
+    }
+
     pub(crate) fn execute_autotune(
         &mut self,
         autotune_operation_set: Box<dyn AutotuneOperationSet<S::AutotuneKey>>,
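The two client methods above split autotuning into a cheap cached lookup (autotune_result, which only takes a read lock on the tuner) and a benchmarking path (autotune_execute, which takes the write lock). A minimal sketch of how a caller can combine them, using only the burn-compute items touched in this diff; the helper function itself is hypothetical:

    use burn_compute::channel::ComputeChannel;
    use burn_compute::client::ComputeClient;
    use burn_compute::server::ComputeServer;
    use burn_compute::tune::AutotuneOperationSet;

    /// Hypothetical helper: consult the cache first, benchmark only on a miss,
    /// then read the winning index back.
    fn fastest_index<S, C>(
        client: &ComputeClient<S, C>,
        key: S::AutotuneKey,
        make_set: impl FnOnce() -> Box<dyn AutotuneOperationSet<S::AutotuneKey>>,
    ) -> Option<usize>
    where
        S: ComputeServer,
        C: ComputeChannel<S>,
    {
        if let Some(index) = client.autotune_result(&key) {
            return Some(index); // cached winner: read lock only, no benchmarking
        }
        client.autotune_execute(make_set()); // benchmarks the set and fills the cache
        client.autotune_result(&key)
    }

This check-then-benchmark shape is exactly what the fused element-wise execute() further down in this diff does with run_kernel and run_autotune.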
diff --git a/burn-compute/tests/dummy/compute.rs b/burn-compute/tests/dummy/compute.rs
index 08f0312e5e..48e79e094e 100644
--- a/burn-compute/tests/dummy/compute.rs
+++ b/burn-compute/tests/dummy/compute.rs
@@ -1,13 +1,13 @@
 use std::sync::Arc;
 
 use super::DummyServer;
+use burn_common::stub::RwLock;
 use burn_compute::channel::MutexComputeChannel;
 use burn_compute::client::ComputeClient;
 use burn_compute::memory_management::{DeallocStrategy, SimpleMemoryManagement, SliceStrategy};
 use burn_compute::storage::BytesStorage;
 use burn_compute::tune::Tuner;
 use burn_compute::Compute;
-use spin::Mutex;
 
 /// The dummy device.
 #[derive(Clone, Debug, Hash, PartialEq, Eq)]
@@ -25,7 +25,7 @@ pub fn init_client() -> ComputeClient<DummyServer, MutexComputeChannel<DummyServer>> {
     let channel = MutexComputeChannel::new(server);
-    let tuner = Arc::new(Mutex::new(Tuner::new()));
+    let tuner = Arc::new(RwLock::new(Tuner::new()));
 
     ComputeClient::new(channel, tuner)
 }
diff --git a/burn-wgpu/src/codegen/kernel.rs b/burn-wgpu/src/codegen/kernel.rs
--- a/burn-wgpu/src/codegen/kernel.rs
+++ b/burn-wgpu/src/codegen/kernel.rs
@@ -90,6 +90,7 @@ pub struct ElemWiseKernelCodegen<Phase = InputPhase> {
     functions: Vec<Function>,
     vectorization: Vectorization,
     mappings_inplace: Vec<InplaceMapping>,
+    workgroup_size: WorkgroupSize,
     _phase: PhantomData<Phase>,
 }
@@ -99,6 +100,7 @@ impl Default for ElemWiseKernelCodegen<InputPhase> {
             functions: Vec::new(),
             vectorization: Vectorization::Scalar,
             mappings_inplace: Vec::new(),
+            workgroup_size: WorkgroupSize::default(),
             _phase: PhantomData,
         }
     }
@@ -183,6 +185,7 @@ impl ElemWiseKernelCodegen<InputPhase> {
             functions: self.functions,
             vectorization: self.vectorization,
             mappings_inplace: self.mappings_inplace,
+            workgroup_size: self.workgroup_size,
             _phase: PhantomData,
         }
     }
@@ -231,6 +234,7 @@ impl ElemWiseKernelCodegen<BodyPhase> {
             vectorization: self.vectorization,
             functions: self.functions,
             mappings_inplace: self.mappings_inplace,
+            workgroup_size: self.workgroup_size,
             _phase: PhantomData,
         }
     }
@@ -329,12 +333,18 @@ impl ElemWiseKernelCodegen<OutputPhase> {
             functions: self.functions,
             vectorization: self.vectorization,
             mappings_inplace: self.mappings_inplace,
+            workgroup_size: self.workgroup_size,
             _phase: PhantomData,
         }
     }
 }
 
 impl ElemWiseKernelCodegen<CompilationPhase> {
+    pub fn workgroup_size(mut self, workgroup_size: WorkgroupSize) -> Self {
+        self.workgroup_size = workgroup_size;
+        self
+    }
+
     /// Compile the kernel into a [compute shader](ComputeShader).
     pub fn compile(self) -> ComputeShader {
         let inputs = self.input_bindings;
@@ -360,7 +370,7 @@ impl ElemWiseKernelCodegen<CompilationPhase> {
             inputs,
             outputs,
             named,
-            workgroup_size: WorkgroupSize::default(),
+            workgroup_size: self.workgroup_size,
             body: Body::new(self.operations),
             num_workgroups: true,
             global_invocation_id: true,
diff --git a/burn-wgpu/src/codegen/shader.rs b/burn-wgpu/src/codegen/shader.rs
index d9eb5306ec..0029125482 100644
--- a/burn-wgpu/src/codegen/shader.rs
+++ b/burn-wgpu/src/codegen/shader.rs
@@ -63,18 +63,18 @@ pub struct Binding {
     pub size: Option<usize>,
 }
 
-#[derive(Debug, PartialEq, Eq, Clone, Serialize, Deserialize)]
+#[derive(new, Debug, PartialEq, Eq, Clone, Copy, Serialize, Deserialize)]
 pub struct WorkgroupSize {
-    pub x: usize,
-    pub y: usize,
-    pub z: usize,
+    pub x: u32,
+    pub y: u32,
+    pub z: u32,
 }
 
 impl Default for WorkgroupSize {
     fn default() -> Self {
         Self {
-            x: WORKGROUP_DEFAULT,
-            y: WORKGROUP_DEFAULT,
+            x: WORKGROUP_DEFAULT as u32,
+            y: WORKGROUP_DEFAULT as u32,
             z: 1,
         }
     }
diff --git a/burn-wgpu/src/compute/base.rs b/burn-wgpu/src/compute/base.rs
index f02a67aad9..16a4b16432 100644
--- a/burn-wgpu/src/compute/base.rs
+++ b/burn-wgpu/src/compute/base.rs
@@ -1,6 +1,7 @@
 use super::WgpuServer;
 use crate::{compute::WgpuStorage, GraphicsApi, WgpuDevice};
 use alloc::sync::Arc;
+use burn_common::stub::RwLock;
 use burn_compute::{
     channel::MutexComputeChannel,
     client::ComputeClient,
@@ -8,7 +9,6 @@ use burn_compute::{
     tune::Tuner,
     Compute,
 };
-use spin::Mutex;
 use wgpu::{AdapterInfo, DeviceDescriptor};
 
 type MemoryManagement = SimpleMemoryManagement<WgpuStorage>;
@@ -69,7 +69,7 @@ async fn create_client<G: GraphicsApi>(device: &WgpuDevice) -> ComputeClient<WgpuServer<MemoryManagement>, MutexComputeChannel<WgpuServer<MemoryManagement>>>
     let channel = MutexComputeChannel::new(server);
     let tuner = Tuner::new();
 
-    ComputeClient::new(channel, Arc::new(Mutex::new(tuner)))
+    ComputeClient::new(channel, Arc::new(RwLock::new(tuner)))
 }
diff --git a/burn-wgpu/src/compute/server.rs b/burn-wgpu/src/compute/server.rs
--- a/burn-wgpu/src/compute/server.rs
+++ b/burn-wgpu/src/compute/server.rs
@@ -30,6 +30,34 @@ pub trait Kernel: 'static + Send + Sync {
     fn workgroup(&self) -> WorkGroup;
 }
 
+impl Kernel for Arc<dyn Kernel> {
+    fn source(&self) -> SourceTemplate {
+        self.as_ref().source()
+    }
+
+    fn id(&self) -> String {
+        self.as_ref().id()
+    }
+
+    fn workgroup(&self) -> WorkGroup {
+        self.as_ref().workgroup()
+    }
+}
+
+impl Kernel for Box<dyn Kernel> {
+    fn source(&self) -> SourceTemplate {
+        self.as_ref().source()
+    }
+
+    fn id(&self) -> String {
+        self.as_ref().id()
+    }
+
+    fn workgroup(&self) -> WorkGroup {
+        self.as_ref().workgroup()
+    }
+}
+
 impl<MM> WgpuServer<MM>
 where
     MM: MemoryManagement<WgpuStorage>,
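The forwarding impls added to server.rs make Arc<dyn Kernel> and Box<dyn Kernel> kernels themselves, which is what lets a Box<Arc<dyn Kernel>> coerce into the Box<dyn Kernel> the server executes. A standalone, runnable sketch with a simplified hypothetical Kernel trait (not the burn-wgpu one):

    use std::sync::Arc;

    trait Kernel {
        fn id(&self) -> String;
    }

    struct FusedElemWise;

    impl Kernel for FusedElemWise {
        fn id(&self) -> String {
            "fused-elemwise".to_string()
        }
    }

    // Mirrors the forwarding impl above, for the simplified trait.
    impl Kernel for Arc<dyn Kernel> {
        fn id(&self) -> String {
            self.as_ref().id()
        }
    }

    fn submit(kernel: Box<dyn Kernel>) {
        println!("submitting {}", kernel.id());
    }

    fn main() {
        let shared: Arc<dyn Kernel> = Arc::new(FusedElemWise);
        submit(Box::new(shared.clone())); // clone is only a refcount bump
        submit(Box::new(shared)); // same compiled kernel, submitted again
    }

This is the property the autotune path relies on: one compiled kernel can be cloned and re-submitted for repeated benchmark runs without recompiling.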
diff --git a/burn-wgpu/src/compute/tune_key.rs b/burn-wgpu/src/compute/tune_key.rs
index a4b798a9d9..161da5c0f0 100644
--- a/burn-wgpu/src/compute/tune_key.rs
+++ b/burn-wgpu/src/compute/tune_key.rs
@@ -1,9 +1,10 @@
+use crate::kernel::{matmul::MatmulAutotuneKey, reduce::ReduceAutotuneKey};
+use burn_compute::tune::AutotuneKey;
 use serde::{Deserialize, Serialize};
 use std::fmt::Display;
 
-use burn_compute::tune::AutotuneKey;
-
-use crate::kernel::{matmul::MatmulAutotuneKey, reduce::ReduceAutotuneKey};
+#[cfg(any(feature = "fusion", test))]
+use crate::fusion::FusionElemWiseAutotuneKey;
 
 #[derive(Hash, Eq, PartialEq, Debug, Clone, Serialize, Deserialize)]
 /// Key for all autotune-enabled operations
@@ -14,6 +15,9 @@ pub enum WgpuAutotuneKey {
     SumDim(ReduceAutotuneKey),
     /// Key for mean_dim operations
     MeanDim(ReduceAutotuneKey),
+    #[cfg(any(feature = "fusion", test))]
+    /// Key for fused element wise operations.
+    FusionElemWise(FusionElemWiseAutotuneKey),
 }
 
 impl Display for WgpuAutotuneKey {
@@ -22,6 +26,8 @@ impl Display for WgpuAutotuneKey {
             WgpuAutotuneKey::Matmul(matmul_key) => std::fmt::Display::fmt(&matmul_key, f),
             WgpuAutotuneKey::SumDim(reduce_key) => std::fmt::Display::fmt(&reduce_key, f),
             WgpuAutotuneKey::MeanDim(reduce_key) => std::fmt::Display::fmt(&reduce_key, f),
+            #[cfg(any(feature = "fusion", test))]
+            WgpuAutotuneKey::FusionElemWise(reduce_key) => std::fmt::Display::fmt(&reduce_key, f),
         }
     }
 }
diff --git a/burn-wgpu/src/fusion/elemwise/kernel.rs b/burn-wgpu/src/fusion/elemwise/kernel.rs
index 8fb98c099d..17ab6095b1 100644
--- a/burn-wgpu/src/fusion/elemwise/kernel.rs
+++ b/burn-wgpu/src/fusion/elemwise/kernel.rs
@@ -6,7 +6,7 @@ use crate::{
         source::DynKernelSource,
         WgpuFusionHandle,
     },
-    kernel::{elemwise_workgroup, WORKGROUP_DEFAULT},
+    kernel::elemwise_workgroup,
 };
 use burn_fusion::TensorDescription;
 use std::sync::Arc;
@@ -99,11 +99,19 @@ impl ElementWiseSource {
         inputs: &[&TensorDescription],
         outputs: &[&TensorDescription],
     ) -> SelectedKernel {
+        let workgroup_size_x = self.source_normal.shader.workgroup_size.x;
+        let workgroup_size_y = self.source_normal.shader.workgroup_size.y;
+        assert_eq!(
+            workgroup_size_x, workgroup_size_y,
+            "The grid must be a square"
+        );
+        let workgroup_size = workgroup_size_x as usize;
+
         match inplace_available(&self.mappings, handles_inputs) {
             true => {
                 let reference_tensor = inputs[self.mappings[0].position_input];
                 let num_elems = calculate_num_elems_dyn_rank(&reference_tensor.shape);
-                let workgroup = elemwise_workgroup(num_elems / self.factor, WORKGROUP_DEFAULT);
+                let workgroup = elemwise_workgroup(num_elems / self.factor, workgroup_size);
                 let kernel = Box::new(DynamicKernel::new(self.source_inplace.clone(), workgroup));
                 let output_infos =
                     self.inplace_output2input
@@ -129,7 +137,7 @@ impl ElementWiseSource {
             false => {
                 let reference_tensor = outputs[0];
                 let num_elems = calculate_num_elems_dyn_rank(&reference_tensor.shape);
-                let workgroup = elemwise_workgroup(num_elems / self.factor, WORKGROUP_DEFAULT);
+                let workgroup = elemwise_workgroup(num_elems / self.factor, workgroup_size);
                 let kernel = Box::new(DynamicKernel::new(self.source_normal.clone(), workgroup));
                 let output_infos = outputs.iter().enumerate().map(|(pos, tensor)| {
                     let elem = self.source_normal.shader.outputs[pos].item.elem();
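ElementWiseSource now derives the dispatch size from the shader's own workgroup size instead of the crate-wide WORKGROUP_DEFAULT. A hedged sketch of the arithmetic, with a hypothetical helper mirroring what elemwise_workgroup computes for the square grid the assert above enforces:

    fn workgroup_count(num_elems: usize, workgroup_size: usize) -> (u32, u32) {
        let per_group = workgroup_size * workgroup_size; // square grid, per the assert
        let groups = f32::ceil(num_elems as f32 / per_group as f32);
        let x = f32::ceil(f32::sqrt(groups));
        let y = f32::ceil(groups / x);
        (x as u32, y as u32)
    }

    fn main() {
        // 1_000_000 elements with a 16x16 grid: 3907 workgroups, dispatched 63x63.
        assert_eq!(workgroup_count(1_000_000, 16), (63, 63));
        // A 32x32 grid needs 4x fewer: 977 workgroups, dispatched 32x31.
        assert_eq!(workgroup_count(1_000_000, 32), (32, 31));
    }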
diff --git a/burn-wgpu/src/fusion/elemwise/mod.rs b/burn-wgpu/src/fusion/elemwise/mod.rs
index 0becc158f8..1048059937 100644
--- a/burn-wgpu/src/fusion/elemwise/mod.rs
+++ b/burn-wgpu/src/fusion/elemwise/mod.rs
@@ -1,6 +1,9 @@
 mod builder;
 mod kernel;
 mod optimization;
+mod tune;
 
 pub(crate) use builder::*;
 pub(crate) use optimization::*;
+
+pub use tune::FusionElemWiseAutotuneKey;
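The workgroup_size builder method added to ElemWiseKernelCodegen earlier in this diff is only exposed once the builder reaches its compilation phase. A standalone sketch of that phase-typed builder pattern, with all names simplified and hypothetical:

    use std::marker::PhantomData;

    #[derive(Clone, Copy, Debug, PartialEq)]
    struct WorkgroupSize {
        x: u32,
        y: u32,
        z: u32,
    }

    struct OutputPhase;
    struct CompilationPhase;

    struct Codegen<Phase> {
        workgroup_size: WorkgroupSize,
        _phase: PhantomData<Phase>,
    }

    impl Codegen<OutputPhase> {
        fn new() -> Self {
            Codegen {
                workgroup_size: WorkgroupSize { x: 32, y: 32, z: 1 },
                _phase: PhantomData,
            }
        }

        // Each phase transition carries the configured size along, as every
        // transition in the real builder now does.
        fn outputs(self) -> Codegen<CompilationPhase> {
            Codegen {
                workgroup_size: self.workgroup_size,
                _phase: PhantomData,
            }
        }
    }

    impl Codegen<CompilationPhase> {
        // Only the final phase exposes the setter, so it cannot be called
        // before the inputs, body, and outputs are known.
        fn workgroup_size(mut self, size: WorkgroupSize) -> Self {
            self.workgroup_size = size;
            self
        }

        fn compile(self) -> WorkgroupSize {
            self.workgroup_size
        }
    }

    fn main() {
        let compiled = Codegen::new()
            .outputs()
            .workgroup_size(WorkgroupSize { x: 16, y: 16, z: 1 })
            .compile();
        assert_eq!(compiled, WorkgroupSize { x: 16, y: 16, z: 1 });
    }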
diff --git a/burn-wgpu/src/fusion/elemwise/optimization.rs b/burn-wgpu/src/fusion/elemwise/optimization.rs
index 91a7cb9e25..31ba973380 100644
--- a/burn-wgpu/src/fusion/elemwise/optimization.rs
+++ b/burn-wgpu/src/fusion/elemwise/optimization.rs
@@ -1,9 +1,14 @@
-use super::kernel::{ScalarElementWise, VecElementWise};
+use super::{
+    kernel::{ScalarElementWise, VecElementWise},
+    tune::ElementWiseAutotuneOperationSet,
+    FusionElemWiseAutotuneKey,
+};
 use crate::{
     codegen::{
         Elem, ElemWiseKernelCodegen, InplaceMapping, Input, Item, Operator, Output,
-        ReadingStrategy, Vectorization, Visibility,
+        ReadingStrategy, Vectorization, Visibility, WorkgroupSize,
     },
+    compute::{compute_client, WgpuAutotuneKey, WgpuComputeClient},
     fusion::{kernel::FusionKernelSet, source::DynKernelSource},
     FloatElement, GraphicsApi, IntElement, Wgpu, WgpuDevice,
 };
@@ -19,27 +24,28 @@ where
     F: FloatElement,
     I: IntElement,
 {
-    inputs: Vec<(TensorDescription, Elem)>,
-    outputs: Vec<(TensorDescription, Elem)>,
-    locals: Vec<u16>,
-    scalars: Scalars,
-    operators: Vec<Operator>,
-    device: Device<Wgpu<G, F, I>>,
-    phase: Phase,
+    pub(super) inputs: Vec<(TensorDescription, Elem)>,
+    pub(super) outputs: Vec<(TensorDescription, Elem)>,
+    pub(super) locals: Vec<u16>,
+    pub(super) scalars: Scalars,
+    pub(super) operators: Vec<Operator>,
+    pub(super) device: Device<Wgpu<G, F, I>>,
+    pub(super) phase: Phase,
 }
 
 #[derive(new, Clone, Serialize, Deserialize)]
 pub struct Scalars {
-    num_f32: usize,
-    num_u32: usize,
-    num_i32: usize,
+    pub(super) num_f32: usize,
+    pub(super) num_u32: usize,
+    pub(super) num_i32: usize,
 }
 
 pub struct CompilationPhase;
 
 #[derive(new)]
 pub struct ExecutionPhase {
-    kernel_set: FusionKernelSet,
+    pub(super) kernel_set_1: FusionKernelSet,
+    pub(super) kernel_set_2: FusionKernelSet,
 }
 
 #[derive(Serialize, Deserialize)]
@@ -144,79 +150,20 @@ where
             })
             .collect::<Vec<_>>();
 
-        let scalar = ScalarElementWise::new(
-            DynKernelSource::new(
-                IdGenerator::generate(),
-                ElemWiseKernelCodegen::new()
-                    .inputs(&inputs)
-                    .body(&self.operators)
-                    .outputs(&outputs)
-                    .compile(),
-            ),
-            DynKernelSource::new(
-                IdGenerator::generate(),
-                ElemWiseKernelCodegen::new()
-                    .inplace(&mappings)
-                    .inputs(&inputs)
-                    .body(&self.operators)
-                    .outputs(&outputs)
-                    .compile(),
-            ),
-            mappings.clone(),
-            outputs.len(),
+        let kernel_set_1 = build_kernel_set(
+            &inputs,
+            &outputs,
+            &self.operators,
+            &mappings,
+            WorkgroupSize::default(),
         );
-
-        let vec2 = VecElementWise::new(
-            DynKernelSource::new(
-                IdGenerator::generate(),
-                ElemWiseKernelCodegen::new()
-                    .vectorize(Vectorization::Vec2)
-                    .inputs(&inputs)
-                    .body(&self.operators)
-                    .outputs(&outputs)
-                    .compile(),
-            ),
-            DynKernelSource::new(
-                IdGenerator::generate(),
-                ElemWiseKernelCodegen::new()
-                    .vectorize(Vectorization::Vec2)
-                    .inplace(&mappings)
-                    .inputs(&inputs)
-                    .body(&self.operators)
-                    .outputs(&outputs)
-                    .compile(),
-            ),
-            mappings.clone(),
-            outputs.len(),
-            2,
+        let kernel_set_2 = build_kernel_set(
+            &inputs,
+            &outputs,
+            &self.operators,
+            &mappings,
+            WorkgroupSize::new(16, 16, 1),
         );
-        let vec4 = VecElementWise::new(
-            DynKernelSource::new(
-                IdGenerator::generate(),
-                ElemWiseKernelCodegen::new()
-                    .vectorize(Vectorization::Vec4)
-                    .inputs(&inputs)
-                    .body(&self.operators)
-                    .outputs(&outputs)
-                    .compile(),
-            ),
-            DynKernelSource::new(
-                IdGenerator::generate(),
-                ElemWiseKernelCodegen::new()
-                    .vectorize(Vectorization::Vec4)
-                    .inplace(&mappings)
-                    .inputs(&inputs)
-                    .body(&self.operators)
-                    .outputs(&outputs)
-                    .compile(),
-            ),
-            mappings,
-            outputs.len(),
-            4,
-        );
-
-        let kernel_set =
-            FusionKernelSet::new(vec![Box::new(scalar), Box::new(vec2), Box::new(vec4)]);
 
         ElementWise {
             inputs: self.inputs,
@@ -225,7 +172,7 @@ where
             device: self.device,
             operators: self.operators,
             locals: self.locals,
-            phase: ExecutionPhase::new(kernel_set),
+            phase: ExecutionPhase::new(kernel_set_1, kernel_set_2),
         }
     }
 }
@@ -237,20 +184,113 @@ where
     I: IntElement,
 {
     pub(crate) fn execute(&mut self, context: &mut Context<'_, Wgpu<G, F, I>>) {
-        self.phase.kernel_set.execute(
+        let client = compute_client::<G>(&self.device);
+
+        let key = WgpuAutotuneKey::FusionElemWise(FusionElemWiseAutotuneKey::new(
+            self.operators.len(),
+            self.autotune_shape(context),
+        ));
+
+        if let Some(index) = client.autotune_result(&key) {
+            self.run_kernel(context, client, index)
+        } else {
+            self.run_autotune(context, client, key)
+        }
+    }
+
+    fn run_kernel(
+        &mut self,
+        context: &mut Context<'_, Wgpu<G, F, I>>,
+        client: WgpuComputeClient,
+        fastest_set_index: usize,
+    ) {
+        let kernel_set = match fastest_set_index {
+            0 => &self.phase.kernel_set_1,
+            1 => &self.phase.kernel_set_2,
+            _ => panic!("Should be 0 or 1, got {fastest_set_index}"),
+        };
+
+        let kernel = kernel_set.select(
             &self.inputs.iter().map(|a| &a.0).collect::<Vec<_>>(),
             &self.outputs.iter().map(|a| &a.0).collect::<Vec<_>>(),
             self.scalars.num_f32,
             self.scalars.num_i32,
             context,
             self.device.clone(),
-        )
+            client,
+            true,
+        );
+
+        kernel.execute();
+    }
+
+    fn run_autotune(
+        &mut self,
+        context: &mut Context<'_, Wgpu<G, F, I>>,
+        client: WgpuComputeClient,
+        key: WgpuAutotuneKey,
+    ) {
+        let kernel_1 = self.phase.kernel_set_1.select(
+            &self.inputs.iter().map(|a| &a.0).collect::<Vec<_>>(),
+            &self.outputs.iter().map(|a| &a.0).collect::<Vec<_>>(),
+            self.scalars.num_f32,
+            self.scalars.num_i32,
+            context,
+            self.device.clone(),
+            client.clone(),
+            false, // Should not mutate the context.
+        );
+        let kernel_2 = self.phase.kernel_set_2.select(
+            &self.inputs.iter().map(|a| &a.0).collect::<Vec<_>>(),
+            &self.outputs.iter().map(|a| &a.0).collect::<Vec<_>>(),
+            self.scalars.num_f32,
+            self.scalars.num_i32,
+            context,
+            self.device.clone(),
+            client.clone(),
+            false, // Should not mutate the context.
+        );
+        let kernel_default = self.phase.kernel_set_1.select(
+            &self.inputs.iter().map(|a| &a.0).collect::<Vec<_>>(),
+            &self.outputs.iter().map(|a| &a.0).collect::<Vec<_>>(),
+            self.scalars.num_f32,
+            self.scalars.num_i32,
+            context,
+            self.device.clone(),
+            client.clone(),
+            true, // Can do whatever with the context.
+        );
+
+        client.autotune_execute(Box::new(ElementWiseAutotuneOperationSet::new(
+            key,
+            kernel_1.into(),
+            kernel_2.into(),
+            kernel_default.into(),
+        )));
     }
 
     pub(crate) fn len(&self) -> usize {
         self.operators.len()
     }
 
+    /// The first output is chosen when possible, otherwise the first input is chosen.
+    pub(crate) fn autotune_shape<'a>(
+        &self,
+        context: &mut Context<'a, Wgpu<G, F, I>>,
+    ) -> &'a [usize] {
+        if let Some(tensor) = self.outputs.first() {
+            let tensor = context.tensors.get(&tensor.0.id).unwrap();
+            return &tensor.shape;
+        }
+
+        if let Some(tensor) = self.inputs.first() {
+            let tensor = context.tensors.get(&tensor.0.id).unwrap();
+            return &tensor.shape;
+        }
+
+        &[]
+    }
+
     pub(crate) fn from_state(device: &WgpuDevice, state: ElementWiseState) -> Self {
         // We don't save the compiled kernel structs since it's quick to compile and the output is
         // very large.
@@ -280,6 +320,93 @@ where
     }
 }
 
+fn build_kernel_set(
+    inputs: &[Input],
+    outputs: &[Output],
+    operators: &[Operator],
+    mappings: &[InplaceMapping],
+    workgroup_size: WorkgroupSize,
+) -> FusionKernelSet {
+    let scalar = ScalarElementWise::new(
+        DynKernelSource::new(
+            IdGenerator::generate(),
+            ElemWiseKernelCodegen::new()
+                .inputs(inputs)
+                .body(operators)
+                .outputs(outputs)
+                .workgroup_size(workgroup_size)
+                .compile(),
+        ),
+        DynKernelSource::new(
+            IdGenerator::generate(),
+            ElemWiseKernelCodegen::new()
+                .inplace(mappings)
+                .inputs(inputs)
+                .body(operators)
+                .outputs(outputs)
+                .workgroup_size(workgroup_size)
+                .compile(),
+        ),
+        mappings.to_vec(),
+        outputs.len(),
+    );
+
+    let vec2 = VecElementWise::new(
+        DynKernelSource::new(
+            IdGenerator::generate(),
+            ElemWiseKernelCodegen::new()
+                .vectorize(Vectorization::Vec2)
+                .inputs(inputs)
+                .body(operators)
+                .outputs(outputs)
+                .workgroup_size(workgroup_size)
+                .compile(),
+        ),
+        DynKernelSource::new(
+            IdGenerator::generate(),
+            ElemWiseKernelCodegen::new()
+                .vectorize(Vectorization::Vec2)
+                .inplace(mappings)
+                .inputs(inputs)
+                .body(operators)
+                .outputs(outputs)
+                .workgroup_size(workgroup_size)
+                .compile(),
+        ),
+        mappings.to_vec(),
+        outputs.len(),
+        2,
+    );
+    let vec4 = VecElementWise::new(
+        DynKernelSource::new(
+            IdGenerator::generate(),
+            ElemWiseKernelCodegen::new()
+                .vectorize(Vectorization::Vec4)
+                .inputs(inputs)
+                .body(operators)
+                .outputs(outputs)
+                .workgroup_size(workgroup_size)
+                .compile(),
+        ),
+        DynKernelSource::new(
+            IdGenerator::generate(),
+            ElemWiseKernelCodegen::new()
+                .vectorize(Vectorization::Vec4)
+                .inplace(mappings)
+                .inputs(inputs)
+                .body(operators)
+                .outputs(outputs)
+                .workgroup_size(workgroup_size)
+                .compile(),
+        ),
+        mappings.to_vec(),
+        outputs.len(),
+        4,
+    );
+
+    FusionKernelSet::new(vec![Box::new(scalar), Box::new(vec2), Box::new(vec4)])
+}
+
 #[cfg(test)]
 mod tests {
     use super::*;
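CompilationPhase now produces two candidate kernel sets, one with the default grid and one with a 16x16 grid, and execute() picks between them per autotune key. Assuming WORKGROUP_DEFAULT is 32 here (the constant itself is not shown in this diff), the trade-off being benchmarked is invocations per workgroup versus workgroups in flight:

    fn invocations_per_workgroup(x: u32, y: u32, z: u32) -> u32 {
        x * y * z
    }

    fn main() {
        assert_eq!(invocations_per_workgroup(32, 32, 1), 1024); // kernel_set_1, default grid
        assert_eq!(invocations_per_workgroup(16, 16, 1), 256); // kernel_set_2, smaller grid
    }

Which grid wins depends on the fused body and tensor shape, which is why the choice is benchmarked rather than fixed.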
diff --git a/burn-wgpu/src/fusion/elemwise/tune.rs b/burn-wgpu/src/fusion/elemwise/tune.rs
new file mode 100644
index 0000000000..95b0ed6257
--- /dev/null
+++ b/burn-wgpu/src/fusion/elemwise/tune.rs
@@ -0,0 +1,59 @@
+use std::fmt::Display;
+
+use crate::{compute::WgpuAutotuneKey, fusion::kernel::AutotunableKernel, tune::anchor};
+use burn_compute::tune::{AutotuneOperation, AutotuneOperationSet};
+use serde::{Deserialize, Serialize};
+
+#[derive(new)]
+pub struct ElementWiseAutotuneOperationSet {
+    key: WgpuAutotuneKey,
+    kernel_1: AutotunableKernel,
+    kernel_2: AutotunableKernel,
+    kernel_default: AutotunableKernel,
+}
+
+#[derive(Hash, Eq, PartialEq, Debug, Clone, Serialize, Deserialize)]
+/// Autotune key representative of a fused element wise kernel.
+pub struct FusionElemWiseAutotuneKey {
+    anchored_num_operations: usize,
+    anchored_shape: Vec<usize>,
+}
+
+impl Display for FusionElemWiseAutotuneKey {
+    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
+        f.write_str(
+            format!(
+                "Fusion ElemWise - num_operations: {:?} shape: {:?}",
+                self.anchored_num_operations, self.anchored_shape
+            )
+            .as_str(),
+        )
+    }
+}
+
+impl AutotuneOperationSet<WgpuAutotuneKey> for ElementWiseAutotuneOperationSet {
+    fn key(&self) -> WgpuAutotuneKey {
+        self.key.clone()
+    }
+
+    fn autotunables(&self) -> Vec<Box<dyn AutotuneOperation>> {
+        let kernel_1: Box<dyn AutotuneOperation> = self.kernel_1.clone();
+        let kernel_2: Box<dyn AutotuneOperation> = self.kernel_2.clone();
+
+        vec![kernel_1, kernel_2]
+    }
+
+    fn fastest(self: Box<Self>, _: usize) -> Box<dyn AutotuneOperation> {
+        Box::new(self.kernel_default)
+    }
+}
+
+impl FusionElemWiseAutotuneKey {
+    /// Create a fused element wise autotune key.
+    pub fn new(num_operations: usize, shape: &[usize]) -> Self {
+        Self {
+            anchored_shape: shape.iter().map(|x| anchor(*x, Some(4096))).collect(),
+            anchored_num_operations: anchor(num_operations, None),
+        }
+    }
+}
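FusionElemWiseAutotuneKey::new anchors both the operation count and the shape with the anchor helper that this diff later moves into burn-wgpu/src/tune.rs, so nearby problems collapse onto one cache entry. A runnable sketch of that collapsing, reusing the anchor logic verbatim:

    use std::cmp::min;

    fn anchor(x: usize, max: Option<usize>) -> usize {
        let exp = f32::ceil(f32::log2(x as f32)) as u32;
        let power_of_2 = 2_u32.pow(exp) as usize;
        if let Some(max) = max {
            min(power_of_2, max)
        } else {
            power_of_2
        }
    }

    fn main() {
        let key_a: Vec<usize> = [30, 1000].iter().map(|&x| anchor(x, Some(4096))).collect();
        let key_b: Vec<usize> = [32, 1024].iter().map(|&x| anchor(x, Some(4096))).collect();
        assert_eq!(key_a, key_b); // both anchor to [32, 1024]: one cache entry
        assert_eq!(anchor(5000, Some(4096)), 4096); // very large dims are clamped
        assert_eq!(anchor(6, None), 8); // the operation count anchors unclamped
    }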
diff --git a/burn-wgpu/src/fusion/kernel.rs b/burn-wgpu/src/fusion/kernel.rs
index 8ab3660931..026d02d48a 100644
--- a/burn-wgpu/src/fusion/kernel.rs
+++ b/burn-wgpu/src/fusion/kernel.rs
@@ -1,40 +1,102 @@
-use crate::compute::{compute_client, Kernel};
+use crate::compute::{Kernel, WgpuComputeClient, WgpuHandle};
 use crate::fusion::strides_dyn_rank;
 use crate::fusion::WgpuFusionHandle;
 use crate::{FloatElement, GraphicsApi, IntElement, Wgpu};
+use burn_compute::tune::AutotuneOperation;
 use burn_fusion::stream::Context;
-use burn_fusion::TensorDescription;
+use burn_fusion::{TensorDescription, TensorStatus};
 use burn_tensor::Device;
+use std::sync::Arc;
 
 /// Many kernels can be used for the same set of tensor operations fused into one.
 ///
-/// This type makes it easy to group those potential kernels and execute the best one depending on
-/// the context.
+/// This type makes it easy to group those potential kernels and execute the best one depending on the context.
 #[derive(new)]
 pub struct FusionKernelSet {
     kernels: Vec<Box<dyn FusionKernel>>,
 }
 
-/// The priority of a kernel.
-pub enum Priority {
-    /// When a kernel can be executed in the specified context with its priority, higher is better.
-    Available(u8),
-    /// When a kernel can't be executed in the specified context.
-    Unavailable,
+/// An instantiation of a [kernel](Kernel) that can be executed.
+#[derive(new)]
+pub struct ExecutableKernel {
+    kernel: Box<dyn Kernel>,
+    handles: Vec<WgpuHandle>,
+    client: WgpuComputeClient,
+}
+
+/// An instantiation of a [kernel](Kernel) that can be autotuned.
+///
+/// The main difference with an [executable kernel](ExecutableKernel) is that this kernel can be
+/// cloned and executed multiple times to properly collect benchmarks.
+///
+/// The clone function used is defined in the trait [AutotuneOperation] instead of [Clone].
+#[derive(new)]
+pub struct AutotunableKernel {
+    kernel: Arc<dyn Kernel>,
+    handles: Vec<WgpuHandle>,
+    client: WgpuComputeClient,
 }
 
+/// A selected kernel encapsulates a kernel that should be executed with the provided
+/// [output info](OutputInfo).
+///
+/// It isn't ready for execution yet but should provide all information necessary to
+/// a [kernel set](FusionKernelSet) to create an [executable kernel](ExecutableKernel).
 #[derive(new)]
 pub struct SelectedKernel {
     kernel: Box<dyn Kernel>,
     info: Vec<OutputInfo>,
 }
 
+/// The priority of a kernel.
+pub enum Priority {
+    /// When a kernel can be executed in the specified context with its priority, higher is better.
+    Available(u8),
+    /// When a kernel can't be executed in the specified context.
+    Unavailable,
+}
+
 // Information related to the output of this kernel.
 pub enum OutputInfo {
     Inplace { input_index: usize },
     Array { size: usize },
 }
 
+impl ExecutableKernel {
+    /// Execute the kernel.
+    pub fn execute(self) {
+        self.client
+            .execute(self.kernel, &self.handles.iter().collect::<Vec<_>>())
+    }
+}
+
+impl AutotuneOperation for AutotunableKernel {
+    fn execute(self: Box<Self>) {
+        self.client.execute(
+            Box::new(self.kernel),
+            &self.handles.iter().collect::<Vec<_>>(),
+        )
+    }
+
+    fn clone(&self) -> Box<dyn AutotuneOperation> {
+        Box::new(Self {
+            kernel: self.kernel.clone(),
+            handles: self.handles.iter().map(Clone::clone).collect(),
+            client: self.client.clone(),
+        })
+    }
+}
+
+impl From<ExecutableKernel> for AutotunableKernel {
+    fn from(value: ExecutableKernel) -> Self {
+        Self {
+            kernel: Arc::new(value.kernel),
+            handles: value.handles,
+            client: value.client,
+        }
+    }
+}
+
 pub trait FusionKernel: Send + Sync {
     /// Returns the priority of this kernel based on the input and output information.
     fn priority(
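The From<ExecutableKernel> impl above moves the boxed kernel behind an Arc, which is what makes AutotuneOperation::clone cheap enough to run the same fused kernel several times during benchmarking (the real impl also leans on the Kernel-for-Arc forwarding added earlier). A standalone sketch of the shape of that conversion, with simplified hypothetical types:

    use std::sync::Arc;

    struct ExecutableKernel {
        kernel: Box<dyn Fn() + Send + Sync>,
    }

    struct AutotunableKernel {
        kernel: Arc<dyn Fn() + Send + Sync>,
    }

    impl From<ExecutableKernel> for AutotunableKernel {
        fn from(value: ExecutableKernel) -> Self {
            // From<Box<T>> for Arc<T> moves the kernel behind an Arc once.
            Self {
                kernel: Arc::from(value.kernel),
            }
        }
    }

    impl AutotunableKernel {
        // Stand-in for AutotuneOperation::clone: a refcount bump, not a rebuild.
        fn clone_op(&self) -> AutotunableKernel {
            AutotunableKernel {
                kernel: Arc::clone(&self.kernel),
            }
        }
    }

    fn main() {
        let exec = ExecutableKernel {
            kernel: Box::new(|| println!("launch")),
        };
        let tunable: AutotunableKernel = exec.into();
        let replay = tunable.clone_op(); // benchmark runs execute clones like this
        (replay.kernel)();
        (tunable.kernel)();
    }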
@@ -53,8 +115,9 @@ pub trait FusionKernel: Send + Sync {
 }
 
 impl FusionKernelSet {
-    /// Execute the best kernel based on the given information.
-    pub fn execute<G: GraphicsApi, F: FloatElement, I: IntElement>(
+    /// Select the best kernel based on the given information.
+    #[allow(clippy::too_many_arguments)]
+    pub fn select<G: GraphicsApi, F: FloatElement, I: IntElement>(
         &self,
         inputs: &[&TensorDescription],
         outputs: &[&TensorDescription],
@@ -62,11 +125,11 @@ impl FusionKernelSet {
         scalars_i32: usize,
         context: &mut Context<'_, Wgpu<G, F, I>>,
         device: Device<Wgpu<G, F, I>>,
-    ) {
-        let client = compute_client::<G>(&device);
-
+        client: WgpuComputeClient,
+        stateful: bool,
+    ) -> ExecutableKernel {
         let (handles_input, inputs_description_updated, outputs_description_updated) =
-            process_inputs_outputs(inputs, outputs, context);
+            process_inputs_outputs(inputs, outputs, context, stateful);
 
         let selected = self.select_kernel(
             &handles_input,
@@ -151,8 +214,7 @@ impl FusionKernelSet {
             context.handles.register_handle(id, handle);
         }
 
-        // Execute the kernel.
-        client.execute(selected.kernel, &handles.iter().collect::<Vec<_>>());
+        ExecutableKernel::new(selected.kernel, handles, client)
     }
 
     fn select_kernel(
@@ -198,10 +260,11 @@ fn register_info_tensor(
     }
 }
 
-pub fn process_inputs_outputs<'a, G: GraphicsApi, F: FloatElement, I: IntElement>(
+fn process_inputs_outputs<'a, G: GraphicsApi, F: FloatElement, I: IntElement>(
     inputs: &[&TensorDescription],
     outputs: &[&TensorDescription],
     context: &'a mut Context<'_, Wgpu<G, F, I>>,
+    stateful: bool,
 ) -> (
     Vec<WgpuFusionHandle>,
     Vec<&'a TensorDescription>,
@@ -212,9 +275,14 @@ pub fn process_inputs_outputs<'a, G: GraphicsApi, F: FloatElement, I: IntElement>(
     let mut handles_input = Vec::new();
 
     for tensor in inputs.iter() {
-        let status = &tensor.status; // Important to take the status of the relative graph and not
-                                     // the global graph, since the status of the global graph
-                                     // might be of a later operation on the same tensor id.
+        let status = if stateful {
+            &tensor.status // Important to take the status of the relative graph and not
+                           // the global graph, since the status of the global graph
+                           // might be of a later operation on the same tensor id.
+        } else {
+            &TensorStatus::ReadOnly
+        };
+
         let tensor = context.tensors.get(&tensor.id).unwrap();
         let handle = context.handles.get_handle(&tensor.id, status);
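process_inputs_outputs now takes a stateful flag: autotune candidates select with stateful = false so that benchmark replays never consume a read-write input in place. A runnable sketch of that status downgrade (the helper is hypothetical; TensorStatus is the burn-fusion enum imported above):

    #[derive(Clone, Copy, PartialEq, Debug)]
    enum TensorStatus {
        ReadOnly,
        ReadWrite,
    }

    fn effective_status(actual: TensorStatus, stateful: bool) -> TensorStatus {
        if stateful {
            actual // normal execution may reuse a ReadWrite input's buffer in place
        } else {
            TensorStatus::ReadOnly // benchmarking must leave every input intact
        }
    }

    fn main() {
        assert_eq!(
            effective_status(TensorStatus::ReadWrite, false),
            TensorStatus::ReadOnly
        );
        assert_eq!(
            effective_status(TensorStatus::ReadWrite, true),
            TensorStatus::ReadWrite
        );
    }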
diff --git a/burn-wgpu/src/kernel/matmul/tune/base.rs b/burn-wgpu/src/kernel/matmul/tune/base.rs
index 2699e233c4..0024386ed9 100644
--- a/burn-wgpu/src/kernel/matmul/tune/base.rs
+++ b/burn-wgpu/src/kernel/matmul/tune/base.rs
@@ -114,7 +114,7 @@ pub fn matmul_autotune(
         output.clone(),
     ));
 
-    client.execute_autotune(operation_set);
+    client.autotune_execute(operation_set);
 
     output
 }
diff --git a/burn-wgpu/src/kernel/matmul/tune/key.rs b/burn-wgpu/src/kernel/matmul/tune/key.rs
index 00fa2e446b..345ceeee6d 100644
--- a/burn-wgpu/src/kernel/matmul/tune/key.rs
+++ b/burn-wgpu/src/kernel/matmul/tune/key.rs
@@ -1,11 +1,8 @@
+use crate::tune::anchor;
 use burn_tensor::Shape;
 use core::fmt::Debug;
 use serde::{Deserialize, Serialize};
-use std::{
-    cmp::{max, min},
-    fmt::Display,
-    hash::Hash,
-};
+use std::{cmp::max, fmt::Display, hash::Hash};
 
 #[derive(Hash, Eq, PartialEq, Debug, Clone, Serialize, Deserialize)]
 /// Autotune key representative of matmul versions
@@ -68,16 +65,6 @@ impl MatmulAutotuneKey {
     }
 }
 
-fn anchor(x: usize, max: Option<usize>) -> usize {
-    let exp = f32::ceil(f32::log2(x as f32)) as u32;
-    let power_of_2 = 2_u32.pow(exp) as usize;
-    if let Some(max) = max {
-        min(power_of_2, max)
-    } else {
-        power_of_2
-    }
-}
-
 #[cfg(test)]
 mod tests {
     use super::*;
diff --git a/burn-wgpu/src/kernel/reduce/tune/mean_dim.rs b/burn-wgpu/src/kernel/reduce/tune/mean_dim.rs
index 060380b212..3479a64a5d 100644
--- a/burn-wgpu/src/kernel/reduce/tune/mean_dim.rs
+++ b/burn-wgpu/src/kernel/reduce/tune/mean_dim.rs
@@ -104,7 +104,7 @@ pub fn mean_dim_autotune(
         reduce_dim,
     ));
 
-    client.execute_autotune(operation_set);
+    client.autotune_execute(operation_set);
 
     output
 }
diff --git a/burn-wgpu/src/kernel/reduce/tune/sum_dim.rs b/burn-wgpu/src/kernel/reduce/tune/sum_dim.rs
index 61e50edf89..0ec4f06db9 100644
--- a/burn-wgpu/src/kernel/reduce/tune/sum_dim.rs
+++ b/burn-wgpu/src/kernel/reduce/tune/sum_dim.rs
@@ -104,7 +104,7 @@ pub fn sum_dim_autotune(
         reduce_dim,
     ));
 
-    client.execute_autotune(operation_set);
+    client.autotune_execute(operation_set);
 
    output
 }
diff --git a/burn-wgpu/src/lib.rs b/burn-wgpu/src/lib.rs
index acb03491e4..ec46ce4159 100644
--- a/burn-wgpu/src/lib.rs
+++ b/burn-wgpu/src/lib.rs
@@ -16,6 +16,7 @@ pub mod kernel;
 pub mod tensor;
 
 pub(crate) mod codegen;
+pub(crate) mod tune;
 
 mod element;
 pub use element::{FloatElement, IntElement};
diff --git a/burn-wgpu/src/tune.rs b/burn-wgpu/src/tune.rs
new file mode 100644
index 0000000000..b289f0cc9c
--- /dev/null
+++ b/burn-wgpu/src/tune.rs
@@ -0,0 +1,16 @@
+//! Module with tune utilities.
+
+use std::cmp::min;
+
+/// Anchor a number to a power of 2.
+///
+/// Useful when creating autotune keys.
+pub fn anchor(x: usize, max: Option<usize>) -> usize {
+    let exp = f32::ceil(f32::log2(x as f32)) as u32;
+    let power_of_2 = 2_u32.pow(exp) as usize;
+    if let Some(max) = max {
+        min(power_of_2, max)
+    } else {
+        power_of_2
+    }
+}