Skip to content

Commit

Permalink
maxwell compat
Browse files Browse the repository at this point in the history
  • Loading branch information
matthewdouglas committed Nov 27, 2024
1 parent 196c8e0 commit fa6f597
Showing 1 changed file with 15 additions and 6 deletions.
21 changes: 15 additions & 6 deletions csrc/kernels.cu
Original file line number Diff line number Diff line change
Expand Up @@ -2142,7 +2142,16 @@ kOptimizerStatic8bit1StateBlockwise(T* p, T* __restrict__ const g, unsigned char
template<typename T, int THREADS, int SPARSE_DECOMP>
__launch_bounds__(1024, BNB_MAX_THREADS_PER_SM / 1024)
__global__ void kInt8VectorQuant(T * __restrict__ A, int8_t* out, float* rowStats, float threshold, int rows, int cols) {
using BlockReduceT = cub::BlockReduce<T, THREADS>;

// For sm50/sm52 and CUDA < 12.2 we need to do the reduction in fp32.
// Otherwise `T` is `fp16`. This can be removed when Maxwell is dropped.
#if (__CUDACC_VER_MAJOR__ >= 12 && __CUDACC_VER_MINOR >= 2) || BNB_FP16_AVAILABLE && __CUDACC__
using TReduction = T;
#else
using TReduction = float;
#endif

using BlockReduceT = cub::BlockReduce<TReduction, THREADS>;

// One block per row.
// Threads load column values in a striped arrangement.
Expand All @@ -2152,27 +2161,27 @@ __global__ void kInt8VectorQuant(T * __restrict__ A, int8_t* out, float* rowStat
// We then do a blockwise reduction to determine the row's absmax.

__shared__ typename BlockReduceT::TempStorage temp_storage;
__shared__ T smem_row_absmax;
__shared__ TReduction smem_row_absmax;

const int row_id = blockIdx.x;
const T* row_data = A + (row_id * cols);

// Threads will read the row values in a striped access pattern and find a local absmax.
T row_local_absmax = -FLT_MIN;
TReduction row_local_absmax = -FLT_MIN;
for (int i = threadIdx.x; i < cols; i += THREADS) {
const T absval = fabsf(__ldcs(&(row_data[i])));
const TReduction absval = fabsf(__ldcs(&(row_data[i])));

// For sparse decomposition, values outside of the threshold are not to be
// included when calculating the row's absmax.
if constexpr (SPARSE_DECOMP) {
row_local_absmax = fmaxf(row_local_absmax, absval < T(threshold) ? absval : row_local_absmax);
row_local_absmax = fmaxf(row_local_absmax, absval < TReduction(threshold) ? absval : row_local_absmax);
} else {
row_local_absmax = fmaxf(row_local_absmax, absval);
}
}

// Reduce thread-local absmax across the block.
const T row_absmax = BlockReduceT(temp_storage).Reduce(row_local_absmax, cub::Max(), cols);
const TReduction row_absmax = BlockReduceT(temp_storage).Reduce(row_local_absmax, cub::Max(), cols);
if (threadIdx.x == 0) {
// Save our block's absmax to shared memory for the quantization step.
rowStats[row_id] = smem_row_absmax = row_absmax;
Expand Down

0 comments on commit fa6f597

Please sign in to comment.