Skip to content

Commit

Permalink
Cube: Matmul tiling (#1994)
Browse files Browse the repository at this point in the history
  • Loading branch information
louisfd authored Jul 9, 2024
1 parent c2b6318 commit 69be99b
Show file tree
Hide file tree
Showing 51 changed files with 4,121 additions and 262 deletions.
4 changes: 2 additions & 2 deletions backend-comparison/benches/matmul.rs
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ impl<B: Backend, const D: usize> Benchmark for MatmulBenchmark<B, D> {
}

fn execute(&self, (lhs, rhs): Self::Args) {
lhs.clone().matmul(rhs.clone());
lhs.clone().transpose().matmul(rhs.clone());
}

fn prepare(&self) -> Self::Args {
Expand All @@ -56,7 +56,7 @@ fn bench<B: Backend>(
let m = 256;
let k = 1024;
let n = 256;
let shape_lhs = [batch_size, m, k].into();
let shape_lhs = [batch_size, k, m].into();
let shape_rhs = [batch_size, k, n].into();

let benchmark = MatmulBenchmark::<B, D>::new(shape_lhs, shape_rhs, device.clone());
Expand Down
12 changes: 8 additions & 4 deletions crates/burn-cube-macros/src/analyzer.rs
Original file line number Diff line number Diff line change
Expand Up @@ -133,10 +133,14 @@ impl VariableAnalyzer {
self.find_occurrences_in_expr(&expr.cond, depth);
self.find_occurrences_in_stmts(&expr.then_branch.stmts, depth);
if let Some((_, expr)) = &expr.else_branch {
if let syn::Expr::Block(expr_block) = &**expr {
self.find_occurrences_in_stmts(&expr_block.block.stmts, depth);
} else {
// Unsupported: handled in codegen.
match &**expr {
syn::Expr::Block(expr_block) => {
self.find_occurrences_in_stmts(&expr_block.block.stmts, depth);
}
syn::Expr::If(expr) => {
self.find_occurrences_in_expr(&syn::Expr::If(expr.clone()), depth);
}
_ => unreachable!(),
}
}
}
Expand Down
23 changes: 12 additions & 11 deletions crates/burn-cube-macros/src/codegen_function/branch.rs
Original file line number Diff line number Diff line change
Expand Up @@ -120,6 +120,7 @@ pub(crate) fn codegen_return(expr_return: &syn::ExprReturn) -> TokenStream {
/// if cond {...}
/// if cond {...} else {...}
/// if Comptime::get(...) {...} [else {...}]
/// if Comptime::get(...) {...} [else if Comptime::get(...) {...}]* [else {...}]
pub(crate) fn codegen_if(
expr_if: &syn::ExprIf,
loop_level: usize,
Expand All @@ -135,19 +136,19 @@ pub(crate) fn codegen_if(
let then_block = codegen_block(&expr_if.then_branch, loop_level + 1, variable_tracker);

if let Some((_, expr)) = &expr_if.else_branch {
if let syn::Expr::Block(expr_block) = &**expr {
let else_block = codegen_block(&expr_block.block, loop_level + 1, variable_tracker);
let else_block = match &**expr {
syn::Expr::Block(expr_block) => {
codegen_block(&expr_block.block, loop_level + 1, variable_tracker)
}

quote::quote! {
let _cond = #cond;
burn_cube::frontend::branch::if_else_expand(context, #comptime_bool, _cond.into(), |context| #then_block, |context| #else_block);
syn::Expr::If(expr_if) => codegen_if(expr_if, loop_level + 1, variable_tracker),
_ => unreachable!(),
};
quote::quote! {
{
let _cond = #cond;
burn_cube::frontend::branch::if_else_expand(context, #comptime_bool, _cond.into(), |context| #then_block, |context| #else_block);
}
} else {
syn::Error::new_spanned(
expr,
"Unsupported: only `else` block is allowed after an `if` statement.",
)
.into_compile_error()
}
} else {
quote::quote! {
Expand Down
20 changes: 20 additions & 0 deletions crates/burn-cube/src/compute/builder.rs
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,26 @@ impl KernelBuilder {
self.context.scalar(index, elem)
}

/// Register an output tensor and return the [element](ExpandElement) to be used for kernel expansion.
///
/// The output is recorded as an `OutputInfo::Array` entry and is bound to the
/// current value of `num_output`, which is then incremented.
pub fn output_tensor(&mut self, item: Item) -> ExpandElement {
    let position = self.num_output;
    self.outputs.push(OutputInfo::Array { item });
    self.num_output += 1;

    self.context.output(position, item)
}

/// Register an input tensor and return the [element](ExpandElement) to be used for kernel expansion.
///
/// The input is recorded as an `InputInfo::Array` entry with `Visibility::Read`
/// and is bound to the current value of `num_input`, which is then incremented.
pub fn input_tensor(&mut self, item: Item) -> ExpandElement {
    let position = self.num_input;
    let info = InputInfo::Array {
        item,
        visibility: Visibility::Read,
    };
    self.inputs.push(info);
    self.num_input += 1;

    self.context.input(position, item)
}

/// Register an output array and return the [element](ExpandElement) to be used for kernel expansion.
pub fn output_array(&mut self, item: Item) -> ExpandElement {
self.outputs.push(OutputInfo::Array { item });
Expand Down
4 changes: 4 additions & 0 deletions crates/burn-cube/src/frontend/context.rs
Original file line number Diff line number Diff line change
Expand Up @@ -124,6 +124,10 @@ impl CubeContext {
}
}

/// Create a local array by delegating to `self.root` and wrap the resulting
/// variable in [`ExpandElement::Plain`].
pub fn create_local_array(&mut self, item: Item, size: u32) -> ExpandElement {
    let array = self.root.borrow_mut().create_local_array(item, size);
    ExpandElement::Plain(array)
}

/// Obtain the index-th input
pub fn input(&mut self, index: u16, item: Item) -> ExpandElement {
ExpandElement::Plain(crate::ir::Variable::GlobalInputArray(index, item))
Expand Down
80 changes: 77 additions & 3 deletions crates/burn-cube/src/frontend/element/array.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,14 +6,17 @@ use crate::{
ir::{Item, Vectorization},
unexpanded, KernelSettings, Runtime,
};
use crate::{
frontend::{indexation::Index, CubeContext},
prelude::{assign, index, index_assign, Comptime},
};

use super::{
ArgSettings, CubePrimitive, ExpandElementTyped, Init, LaunchArg, LaunchArgExpand, TensorHandle,
UInt,
ArgSettings, CubePrimitive, ExpandElement, ExpandElementTyped, Init, LaunchArg,
LaunchArgExpand, TensorHandle, UInt,
};

/// A contiguous array of elements.
#[derive(new)]
pub struct Array<E> {
_val: PhantomData<E>,
}
Expand All @@ -22,6 +25,77 @@ impl<C: CubeType> CubeType for Array<C> {
type ExpandType = ExpandElementTyped<Array<C>>;
}

impl<T: CubePrimitive + Clone> Array<T> {
    /// Build the placeholder value for an array; `_size` is ignored here and
    /// only matters in [`Array::new_expand`].
    pub fn new<S: Index>(_size: S) -> Self {
        Array { _val: PhantomData }
    }

    /// Expansion counterpart of [`Array::new`]: creates a local array of
    /// `size` items of element type `T` in `context`.
    ///
    /// # Panics
    ///
    /// Panics if `size` does not resolve to a `ConstantScalar` variable.
    pub fn new_expand<S: Index>(
        context: &mut CubeContext,
        size: S,
    ) -> <Self as CubeType>::ExpandType {
        let size = size.value();
        let size = match size {
            // NOTE(review): `as u32` silently truncates/saturates out-of-range values.
            crate::ir::Variable::ConstantScalar(val, _) => val as u32,
            _ => panic!("Array need constant initialization value"),
        };
        context
            .create_local_array(Item::new(T::as_elem()), size)
            .into()
    }

    /// Build the placeholder value for a vectorized array; both arguments are
    /// ignored here and only matter in [`Array::vectorized_expand`].
    pub fn vectorized<S: Index>(_size: S, _vectorization_factor: UInt) -> Self {
        Array { _val: PhantomData }
    }

    /// Expansion counterpart of [`Array::vectorized`]: creates a local array of
    /// `size` items, each vectorized by `vectorization_factor`.
    ///
    /// # Panics
    ///
    /// Panics if `size` does not resolve to a `ConstantScalar` variable.
    pub fn vectorized_expand<S: Index>(
        context: &mut CubeContext,
        size: S,
        vectorization_factor: UInt,
    ) -> <Self as CubeType>::ExpandType {
        let size = size.value();
        let size = match size {
            crate::ir::Variable::ConstantScalar(val, _) => val as u32,
            // The message names `Array` (not shared memory) so the failure
            // points at the right construct, matching `new_expand`.
            _ => panic!("Array need constant initialization value"),
        };
        context
            .create_local_array(
                // NOTE(review): the factor is narrowed with `as u8`; factors
                // above 255 would wrap. Confirm callers never exceed it.
                Item::vectorized(T::as_elem(), vectorization_factor.val as u8),
                size,
            )
            .into()
    }

    /// Placeholder for converting the array into a single vectorized `T`.
    ///
    /// The body is `unexpanded!()`; the real work is done by
    /// `to_vectorized_expand` on the expand type.
    pub fn to_vectorized(self, _vectorization_factor: Comptime<UInt>) -> T {
        unexpanded!()
    }
}

impl<C: CubeType> ExpandElementTyped<Array<C>> {
    /// Copy the leading entries of the array into a freshly created local
    /// variable vectorized by `vectorization_factor`, and return that variable.
    ///
    /// With a factor of 1, entry 0 is assigned to the new variable directly;
    /// otherwise entries `0..factor` are written one by one with `index_assign`.
    pub fn to_vectorized_expand(
        self,
        context: &mut CubeContext,
        vectorization_factor: UInt,
    ) -> ExpandElement {
        let factor = vectorization_factor.val;
        let elem = self.expand.item().elem();
        let mut vectorized = context.create_local(Item::vectorized(elem, factor as u8));

        if factor == 1 {
            let first = index::expand(context, self.clone(), 0u32);
            assign::expand(context, first, vectorized.clone());
            return vectorized;
        }

        for lane in 0..factor {
            let entry = index::expand(context, self.expand.clone(), lane);
            vectorized = index_assign::expand(context, vectorized, lane, entry);
        }
        vectorized
    }
}

impl<C: CubeType> CubeType for &Array<C> {
type ExpandType = ExpandElementTyped<Array<C>>;
}
impl<C: CubeType> Init for ExpandElementTyped<Array<C>> {
fn init(self, _context: &mut crate::prelude::CubeContext) -> Self {
// The type can't be deeply cloned/copied.
Expand Down
27 changes: 27 additions & 0 deletions crates/burn-cube/src/frontend/element/float.rs
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@ pub trait Float:
+ Erf
+ Recip
+ core::ops::Index<UInt, Output = Self>
+ core::ops::IndexMut<UInt, Output = Self>
{
fn new(val: f32) -> Self;
fn new_expand(context: &mut CubeContext, val: f32) -> <Self as CubeType>::ExpandType;
Expand All @@ -35,6 +36,11 @@ pub trait Float:
val: f32,
vectorization: UInt,
) -> <Self as CubeType>::ExpandType;
fn vectorized_empty(vectorization: UInt) -> Self;
fn vectorized_empty_expand(
context: &mut CubeContext,
vectorization: UInt,
) -> <Self as CubeType>::ExpandType;
}

macro_rules! impl_float {
Expand Down Expand Up @@ -101,6 +107,21 @@ macro_rules! impl_float {
new_var
}
}

/// Build a vectorized value by delegating to `Self::vectorized` with `0.` as
/// the value and the given `vectorization`.
fn vectorized_empty(vectorization: UInt) -> Self {
Self::vectorized(0., vectorization)
}

/// Expansion counterpart of `vectorized_empty`.
///
/// A vectorization of 1 yields `Self::new_expand(context, 0.)`; any other
/// factor creates a local variable with the vectorized item and assigns
/// nothing to it.
fn vectorized_empty_expand(
    context: &mut CubeContext,
    vectorization: UInt,
) -> <Self as CubeType>::ExpandType {
    match vectorization.val {
        1 => Self::new_expand(context, 0.),
        factor => context.create_local(Item::vectorized(Self::as_elem(), factor as u8)),
    }
}
}

impl core::ops::Index<UInt> for $type {
Expand All @@ -111,6 +132,12 @@ macro_rules! impl_float {
}
}

// Mutable indexing by `UInt` for the float type generated by this macro.
// The body is only `unexpanded!()`, so no element access is implemented here.
// NOTE(review): presumably this exists so `value[i] = x` type-checks in kernel
// source before expansion — confirm against the macro codegen.
impl core::ops::IndexMut<UInt> for $type {
fn index_mut(&mut self, _index: UInt) -> &mut Self::Output {
unexpanded!()
}
}

impl LaunchArgExpand for $type {
fn expand(builder: &mut KernelBuilder, vectorization: Vectorization) -> ExpandElement {
assert_eq!(vectorization, 1, "Attempted to vectorize a scalar");
Expand Down
22 changes: 21 additions & 1 deletion crates/burn-cube/src/frontend/element/shared_memory.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@ use crate::{
ir::Item,
};

use super::{ExpandElement, Init};
use super::{ExpandElement, Init, UInt};

#[derive(Clone, Copy)]
pub struct SharedMemory<T: CubeType> {
Expand Down Expand Up @@ -49,4 +49,24 @@ impl<T: CubePrimitive + Clone> SharedMemory<T> {
};
context.create_shared(Item::new(T::as_elem()), size)
}

/// Build the placeholder value for a vectorized shared memory; both arguments
/// are ignored here and only matter in `vectorized_expand`.
pub fn vectorized<S: Index>(_size: S, _vectorization_factor: UInt) -> Self {
SharedMemory { _val: PhantomData }
}

/// Expansion counterpart of `vectorized`: creates a shared memory of `size`
/// items of element type `T`, each vectorized by `vectorization_factor`.
///
/// # Panics
///
/// Panics if `size` does not resolve to a `ConstantScalar` variable.
pub fn vectorized_expand<S: Index>(
    context: &mut CubeContext,
    size: S,
    vectorization_factor: UInt,
) -> <Self as CubeType>::ExpandType {
    let length = match size.value() {
        crate::ir::Variable::ConstantScalar(val, _) => val as u32,
        _ => panic!("Shared memory need constant initialization value"),
    };
    let item = Item::vectorized(T::as_elem(), vectorization_factor.val as u8);

    context.create_shared(item, length)
}
}
12 changes: 6 additions & 6 deletions crates/burn-cube/src/frontend/operation/binary.rs
Original file line number Diff line number Diff line change
Expand Up @@ -51,8 +51,8 @@ pub mod sub {
impl core::ops::Sub for $type {
type Output = Self;

fn sub(self, _rhs: Self) -> Self::Output {
unexpanded!()
fn sub(self, rhs: Self) -> Self::Output {
(self.val - rhs.val).into()
}
}
};
Expand Down Expand Up @@ -83,8 +83,8 @@ pub mod mul {
impl core::ops::Mul for $type {
type Output = Self;

fn mul(self, _rhs: Self) -> Self::Output {
unexpanded!()
fn mul(self, rhs: Self) -> Self::Output {
(self.val * rhs.val).into()
}
}
};
Expand Down Expand Up @@ -115,8 +115,8 @@ pub mod div {
impl core::ops::Div for $type {
type Output = Self;

fn div(self, _rhs: Self) -> Self::Output {
unexpanded!()
fn div(self, rhs: Self) -> Self::Output {
(self.val / rhs.val).into()
}
}
};
Expand Down
10 changes: 0 additions & 10 deletions crates/burn-cube/tests/error/if_else_if.rs

This file was deleted.

7 changes: 0 additions & 7 deletions crates/burn-cube/tests/error/if_else_if.stderr

This file was deleted.

Loading

0 comments on commit 69be99b

Please sign in to comment.