Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: add support for batch matrix-matrix product in cuBLASLt #186

Merged
merged 2 commits into from
Nov 6, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 12 additions & 0 deletions src/cublaslt/result.rs
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,18 @@
}
}

/// Sets the value of the specified attribute belonging to a previously created matrix layout
/// descriptor. See
/// [nvidia docs](https://docs.nvidia.com/cuda/cublas/index.html#cublasltmatrixlayoutsetattribute)
///
/// All four arguments are forwarded unchanged to `sys::cublasLtMatrixLayoutSetAttribute`, and the
/// status it returns is converted with `.result()`.
///
/// # Safety
///
/// The arguments are handed to the C API without any checking, so the caller must ensure:
/// - `matrix_layout` is a matrix layout descriptor that was previously created (and, presumably,
///   not yet destroyed).
/// - `buf` is valid for reads of `buf_size` bytes.
/// - The value behind `buf` has the type cuBLASLt expects for `attr` (see the linked NVIDIA docs
///   for the type associated with each attribute).
pub unsafe fn set_matrix_layout_attribute(

Check failure on line 67 in src/cublaslt/result.rs

View workflow job for this annotation

GitHub Actions / clippy

unsafe function's docs miss `# Safety` section
matrix_layout: sys::cublasLtMatrixLayout_t,
attr: sys::cublasLtMatrixLayoutAttribute_t,
buf: *const c_void,
buf_size: usize,
) -> Result<(), CublasError> {
sys::cublasLtMatrixLayoutSetAttribute(matrix_layout, attr, buf, buf_size).result()
}

/// Destroys a matrix layout previously created with [create_matrix_layout(...)]. See
/// [nvidia docs](https://docs.nvidia.com/cuda/cublas/index.html#cublasltmatrixlayoutdestroy)
///
Expand Down Expand Up @@ -90,7 +102,7 @@
/// Sets the value of the specified attribute belonging to a previously created matrix multiply
/// descriptor. See
/// [nvidia docs](https://docs.nvidia.com/cuda/cublas/index.html#cublasltmatmuldescsetattribute)
pub unsafe fn set_matmul_desc_attribute(

Check failure on line 105 in src/cublaslt/result.rs

View workflow job for this annotation

GitHub Actions / clippy

unsafe function's docs miss `# Safety` section
matmul_desc: sys::cublasLtMatmulDesc_t,
attr: sys::cublasLtMatmulDescAttributes_t,
buf: *const c_void,
Expand Down Expand Up @@ -124,7 +136,7 @@
/// Sets the value of the specified attribute belonging to a previously created matrix multiply
/// preferences descriptor. See
/// [nvidia docs](https://docs.nvidia.com/cuda/cublas/index.html#cublasltmatmulpreferencesetattribute)
pub unsafe fn set_matmul_pref_attribute(

Check failure on line 139 in src/cublaslt/result.rs

View workflow job for this annotation

GitHub Actions / clippy

unsafe function's docs miss `# Safety` section
matmul_pref: sys::cublasLtMatmulPreference_t,
attr: sys::cublasLtMatmulPreferenceAttributes_t,
buf: *const c_void,
Expand Down Expand Up @@ -163,13 +175,13 @@

unsafe {
sys::cublasLtMatmulAlgoGetHeuristic(
handle,

Check failure on line 178 in src/cublaslt/result.rs

View workflow job for this annotation

GitHub Actions / clippy

this public function might dereference a raw pointer but is not marked `unsafe`
matmul_desc,

Check failure on line 179 in src/cublaslt/result.rs

View workflow job for this annotation

GitHub Actions / clippy

this public function might dereference a raw pointer but is not marked `unsafe`
a_layout,

Check failure on line 180 in src/cublaslt/result.rs

View workflow job for this annotation

GitHub Actions / clippy

this public function might dereference a raw pointer but is not marked `unsafe`
b_layout,

Check failure on line 181 in src/cublaslt/result.rs

View workflow job for this annotation

GitHub Actions / clippy

this public function might dereference a raw pointer but is not marked `unsafe`
c_layout,

Check failure on line 182 in src/cublaslt/result.rs

View workflow job for this annotation

GitHub Actions / clippy

this public function might dereference a raw pointer but is not marked `unsafe`
d_layout,

Check failure on line 183 in src/cublaslt/result.rs

View workflow job for this annotation

GitHub Actions / clippy

this public function might dereference a raw pointer but is not marked `unsafe`
matmul_pref,

Check failure on line 184 in src/cublaslt/result.rs

View workflow job for this annotation

GitHub Actions / clippy

this public function might dereference a raw pointer but is not marked `unsafe`
1, // only select the fastest algo
matmul_heuristic.as_mut_ptr(),
&mut algo_count,
Expand Down
82 changes: 82 additions & 0 deletions src/cublaslt/safe.rs
Original file line number Diff line number Diff line change
@@ -1,8 +1,10 @@
//! Safe abstractions around [crate::cublaslt::result] for doing matmul.

use super::{result, result::CublasError, sys};
use crate::cublaslt::result::set_matrix_layout_attribute;
use crate::driver::sys::{CUdevice_attribute, CUdeviceptr, CUstream};
use crate::driver::{CudaDevice, CudaSlice, DevicePtr, DevicePtrMut, DriverError};
use core::ffi::c_int;
use core::mem;
use std::sync::Arc;

Expand Down Expand Up @@ -108,6 +110,11 @@ pub struct MatmulConfig {
pub ldb: i64,
pub beta: f32,
pub ldc: i64,
pub stride_a: Option<i64>,
pub stride_b: Option<i64>,
pub stride_c: Option<i64>,
pub stride_bias: Option<i64>,
pub batch_size: Option<c_int>,
}

/// Matrix matrix multiplication with elements of type `T`.
Expand Down Expand Up @@ -146,8 +153,58 @@ pub trait Matmul<T>: MatmulShared {

// Creates matrix layouts
let a_layout = result::create_matrix_layout(Self::matrix_type(), a_rows, a_cols, cfg.lda)?;
if let (Some(batch_size), Some(stride_a)) = (cfg.batch_size, cfg.stride_a) {
// Set batch size
set_matrix_layout_attribute(
a_layout,
sys::cublasLtMatrixLayoutAttribute_t::CUBLASLT_MATRIX_LAYOUT_BATCH_COUNT,
(&batch_size) as *const _ as *const _,
mem::size_of::<c_int>(),
)?;
// Set batch stride
set_matrix_layout_attribute(
a_layout,
sys::cublasLtMatrixLayoutAttribute_t::CUBLASLT_MATRIX_LAYOUT_STRIDED_BATCH_OFFSET,
(&stride_a) as *const _ as *const _,
mem::size_of::<i64>(),
)?;
}

let b_layout = result::create_matrix_layout(Self::matrix_type(), b_rows, b_cols, cfg.ldb)?;
if let (Some(batch_size), Some(stride_b)) = (cfg.batch_size, cfg.stride_b) {
// Set batch size
set_matrix_layout_attribute(
b_layout,
sys::cublasLtMatrixLayoutAttribute_t::CUBLASLT_MATRIX_LAYOUT_BATCH_COUNT,
(&batch_size) as *const _ as *const _,
mem::size_of::<c_int>(),
)?;
// Set batch stride
set_matrix_layout_attribute(
b_layout,
sys::cublasLtMatrixLayoutAttribute_t::CUBLASLT_MATRIX_LAYOUT_STRIDED_BATCH_OFFSET,
(&stride_b) as *const _ as *const _,
mem::size_of::<i64>(),
)?;
}

let c_layout = result::create_matrix_layout(Self::matrix_type(), cfg.m, cfg.n, cfg.ldc)?;
if let (Some(batch_size), Some(stride_c)) = (cfg.batch_size, cfg.stride_c) {
// Set batch size
set_matrix_layout_attribute(
c_layout,
sys::cublasLtMatrixLayoutAttribute_t::CUBLASLT_MATRIX_LAYOUT_BATCH_COUNT,
(&batch_size) as *const _ as *const _,
mem::size_of::<c_int>(),
)?;
// Set batch stride
set_matrix_layout_attribute(
c_layout,
sys::cublasLtMatrixLayoutAttribute_t::CUBLASLT_MATRIX_LAYOUT_STRIDED_BATCH_OFFSET,
(&stride_c) as *const _ as *const _,
mem::size_of::<i64>(),
)?;
}

// Matmul description
let matmul_desc =
Expand Down Expand Up @@ -189,6 +246,16 @@ pub trait Matmul<T>: MatmulShared {
bias.device_ptr() as *const CUdeviceptr as *const _,
mem::size_of::<CUdeviceptr>(),
)?;

if let Some(stride_bias) = cfg.stride_bias {
// Set bias batch stride
result::set_matmul_desc_attribute(
matmul_desc,
sys::cublasLtMatmulDescAttributes_t::CUBLASLT_MATMUL_DESC_BIAS_BATCH_STRIDE,
(&stride_bias) as *const _ as *const _,
mem::size_of::<i64>(),
)?;
}
epilogue
} else if let Some(act) = act {
// Only Act
Expand Down Expand Up @@ -388,6 +455,11 @@ mod tests {
ldb: K as i64,
beta: 0.0,
ldc: N as i64,
stride_a: None,
stride_b: None,
stride_c: None,
stride_bias: None,
batch_size: None,
},
&b_dev,
&a_dev,
Expand Down Expand Up @@ -503,6 +575,11 @@ mod tests {
ldb: K as i64,
beta: 0.0,
ldc: N as i64,
stride_a: None,
stride_b: None,
stride_c: None,
stride_bias: None,
batch_size: None,
},
&b_dev,
&a_dev,
Expand Down Expand Up @@ -552,6 +629,11 @@ mod tests {
ldb: K as i64,
beta: 0.0,
ldc: N as i64,
stride_a: None,
stride_b: None,
stride_c: None,
stride_bias: None,
batch_size: None,
},
&b_dev,
&a_dev,
Expand Down
Loading