[Breaking] feat: add support for batch matrix-matrix product in cuBLASLt (#186)

* feat: add support for batch matrix-matrix product

* fix tests
OlivierDehaene authored Nov 6, 2023
1 parent 902bd49 commit ef7c675
Showing 2 changed files with 94 additions and 0 deletions.
12 changes: 12 additions & 0 deletions src/cublaslt/result.rs
@@ -61,6 +61,18 @@ pub fn create_matrix_layout(
}
}

/// Sets the value of the specified attribute belonging to a previously created matrix layout
/// descriptor. See
/// [nvidia docs](https://docs.nvidia.com/cuda/cublas/index.html#cublasltmatrixlayoutsetattribute)
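///
/// # Safety
/// `buf` must point to a valid, initialized value of the type expected by
/// `attr`, and `buf_size` must be the size of that value in bytes.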
pub unsafe fn set_matrix_layout_attribute(
matrix_layout: sys::cublasLtMatrixLayout_t,
attr: sys::cublasLtMatrixLayoutAttribute_t,
buf: *const c_void,
buf_size: usize,
) -> Result<(), CublasError> {
sys::cublasLtMatrixLayoutSetAttribute(matrix_layout, attr, buf, buf_size).result()
}

/// Destroys a matrix layout previously created with [create_matrix_layout(...)]. See
/// [nvidia docs](https://docs.nvidia.com/cuda/cublas/index.html#cublasltmatrixlayoutdestroy)
///
82 changes: 82 additions & 0 deletions src/cublaslt/safe.rs
@@ -1,8 +1,10 @@
//! Safe abstractions around [crate::cublaslt::result] for doing matmul.
use super::{result, result::CublasError, sys};
use crate::cublaslt::result::set_matrix_layout_attribute;
use crate::driver::sys::{CUdevice_attribute, CUdeviceptr, CUstream};
use crate::driver::{CudaDevice, CudaSlice, DevicePtr, DevicePtrMut, DriverError};
use core::ffi::c_int;
use core::mem;
use std::sync::Arc;

@@ -108,6 +110,11 @@ pub struct MatmulConfig {
pub ldb: i64,
pub beta: f32,
pub ldc: i64,
pub stride_a: Option<i64>,
pub stride_b: Option<i64>,
pub stride_c: Option<i64>,
pub stride_bias: Option<i64>,
pub batch_size: Option<c_int>,
}
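
// Illustrative sketch (not part of this diff; `base_cfg`, `m`, `n`, `k`, and
// `batch` are hypothetical): for `batch` independent GEMMs stored contiguously
// in device memory, the new fields would typically be filled in as
//
//     let cfg = MatmulConfig {
//         stride_a: Some((m * k) as i64),
//         stride_b: Some((k * n) as i64),
//         stride_c: Some((m * n) as i64),
//         stride_bias: None,
//         batch_size: Some(batch as c_int),
//         ..base_cfg
//     };
//
// Leaving the batch fields as `None` keeps the previous single-matmul behavior.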

/// Matrix matrix multiplication with elements of type `T`.
@@ -146,8 +153,58 @@ pub trait Matmul<T>: MatmulShared {

// Creates matrix layouts
let a_layout = result::create_matrix_layout(Self::matrix_type(), a_rows, a_cols, cfg.lda)?;
if let (Some(batch_size), Some(stride_a)) = (cfg.batch_size, cfg.stride_a) {
// Set batch size
set_matrix_layout_attribute(
a_layout,
sys::cublasLtMatrixLayoutAttribute_t::CUBLASLT_MATRIX_LAYOUT_BATCH_COUNT,
(&batch_size) as *const _ as *const _,
mem::size_of::<c_int>(),
)?;
// Set batch stride
set_matrix_layout_attribute(
a_layout,
sys::cublasLtMatrixLayoutAttribute_t::CUBLASLT_MATRIX_LAYOUT_STRIDED_BATCH_OFFSET,
(&stride_a) as *const _ as *const _,
mem::size_of::<i64>(),
)?;
}

let b_layout = result::create_matrix_layout(Self::matrix_type(), b_rows, b_cols, cfg.ldb)?;
if let (Some(batch_size), Some(stride_b)) = (cfg.batch_size, cfg.stride_b) {
// Set batch size
set_matrix_layout_attribute(
b_layout,
sys::cublasLtMatrixLayoutAttribute_t::CUBLASLT_MATRIX_LAYOUT_BATCH_COUNT,
(&batch_size) as *const _ as *const _,
mem::size_of::<c_int>(),
)?;
// Set batch stride
set_matrix_layout_attribute(
b_layout,
sys::cublasLtMatrixLayoutAttribute_t::CUBLASLT_MATRIX_LAYOUT_STRIDED_BATCH_OFFSET,
(&stride_b) as *const _ as *const _,
mem::size_of::<i64>(),
)?;
}

let c_layout = result::create_matrix_layout(Self::matrix_type(), cfg.m, cfg.n, cfg.ldc)?;
if let (Some(batch_size), Some(stride_c)) = (cfg.batch_size, cfg.stride_c) {
// Set batch size
set_matrix_layout_attribute(
c_layout,
sys::cublasLtMatrixLayoutAttribute_t::CUBLASLT_MATRIX_LAYOUT_BATCH_COUNT,
(&batch_size) as *const _ as *const _,
mem::size_of::<c_int>(),
)?;
// Set batch stride
set_matrix_layout_attribute(
c_layout,
sys::cublasLtMatrixLayoutAttribute_t::CUBLASLT_MATRIX_LAYOUT_STRIDED_BATCH_OFFSET,
(&stride_c) as *const _ as *const _,
mem::size_of::<i64>(),
)?;
}
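
// Note: the strided-batch offsets set above are expressed in elements of the
// layout's data type, not in bytes.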

// Matmul description
let matmul_desc =
@@ -189,6 +246,16 @@
bias.device_ptr() as *const CUdeviceptr as *const _,
mem::size_of::<CUdeviceptr>(),
)?;

if let Some(stride_bias) = cfg.stride_bias {
// Set bias batch stride
result::set_matmul_desc_attribute(
matmul_desc,
sys::cublasLtMatmulDescAttributes_t::CUBLASLT_MATMUL_DESC_BIAS_BATCH_STRIDE,
(&stride_bias) as *const _ as *const _,
mem::size_of::<i64>(),
)?;
}
epilogue
} else if let Some(act) = act {
// Only Act
@@ -388,6 +455,11 @@ mod tests {
ldb: K as i64,
beta: 0.0,
ldc: N as i64,
stride_a: None,
stride_b: None,
stride_c: None,
stride_bias: None,
batch_size: None,
},
&b_dev,
&a_dev,
@@ -503,6 +575,11 @@
ldb: K as i64,
beta: 0.0,
ldc: N as i64,
stride_a: None,
stride_b: None,
stride_c: None,
stride_bias: None,
batch_size: None,
},
&b_dev,
&a_dev,
@@ -552,6 +629,11 @@
ldb: K as i64,
beta: 0.0,
ldc: N as i64,
stride_a: None,
stride_b: None,
stride_c: None,
stride_bias: None,
batch_size: None,
},
&b_dev,
&a_dev,
