metal lowbit kernels: optimized 2-bit, 3-bit and 4-bit shaders (#1422)
manuelcandales authored Dec 18, 2024
1 parent f52d3ab commit 2e032c6
Showing 17 changed files with 599 additions and 188 deletions.
11 changes: 7 additions & 4 deletions torchao/experimental/kernels/mps/metal.yaml
@@ -1,14 +1,17 @@
+- func: Vec4Type
+  file: common.metal
+
 - func: int1mm
-  file: divbit.metal
+  file: int1mm.metal
 
 - func: int2mm
-  file: divbit.metal
+  file: int2mm_opt.metal
 
 - func: int3mm
-  file: int3mm.metal
+  file: int3mm_opt.metal
 
 - func: int4mm
-  file: divbit.metal
+  file: int4mm_opt.metal
 
 - func: int5mm
   file: int5mm.metal
15 changes: 15 additions & 0 deletions torchao/experimental/kernels/mps/metal/common.metal
@@ -0,0 +1,15 @@
template <typename T> struct Vec4Type {};

template <> struct Vec4Type<float> {
  using type = float4;
};

template <> struct Vec4Type<half> {
  using type = half4;
};

#if __METAL_VERSION__ >= 310
template <> struct Vec4Type<bfloat> {
  using type = bfloat4;
};
#endif
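Note: this trait maps each scalar dtype to its 4-wide vector counterpart so the optimized kernels can issue vectorized loads; int2mm_opt.metal below consumes it as `using vecT = typename Vec4Type<T>::type;`. A minimal CPU-side sketch of the same dispatch idiom (plain C++ with a stand-in float4; illustration only, not part of the commit):

#include <type_traits>

// Stand-in for Metal's built-in float4 (assumption: host-side illustration).
struct float4 { float x, y, z, w; };

template <typename T> struct Vec4Type {};
template <> struct Vec4Type<float> { using type = float4; };

// A kernel templated on the scalar type T resolves its vector type at
// compile time, so one source instantiates the float4/half4/bfloat4 paths.
template <typename T> using vec4_t = typename Vec4Type<T>::type;

static_assert(std::is_same_v<vec4_t<float>, float4>,
              "float dispatches to float4");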
109 changes: 0 additions & 109 deletions torchao/experimental/kernels/mps/metal/divbit.metal

This file was deleted.

@@ -2,10 +2,10 @@
 using namespace metal;
 
 /**
- * 3-Bit Quantized Linear.
+ * 1-Bit Quantized Linear.
  *
- * @param[A] M x K unquantized input tensor of floating point dtype (Float, Half, BFloat16)
- * @param[B] Packed & quantized weight tensor of uint8 dtype. Expected shape is N x (3 * K / 8)
+ * @param[A] M x K input tensor of floating point dtype (Float, Half, BFloat16)
+ * @param[B] Packed & quantized weight tensor of uint8 dtype. Expected shape is N x (K / 8)
  * @param[scales] 2D tensor containing the scales for each group. Expected shape is #groups x N
  * @param[zeros] 2D tensor containing the zero points for each group. Expected shape is #groups x N
  * @param[outputData] M x N output tensor of floating point dtype (same as input)
@@ -14,7 +14,7 @@ using namespace metal;
  * Dispatched threads: N x M x 1
  */
 template<typename T, unsigned groupSize>
-kernel void int3pack_mm(
+kernel void int1pack_mm(
     constant T * A [[buffer(0)]],
     constant uchar * B [[buffer(1)]],
     constant T * scales [[buffer(2)]],
@@ -28,7 +28,7 @@ kernel void int3pack_mm(
   const uint n = thread_index.x; // 0..N-1
   const uint32_t k_block = (K + groupSize - 1) / groupSize;
   constant T *A_ptr = A + m * K;
-  constant uchar *B_ptr = B + n * 3 * K / 8;
+  constant uchar *B_ptr = B + n * K / 8;
 
   float rc = 0.0;
   uint k = 0;
@@ -45,19 +45,16 @@
     const auto a_val6 = float(A_ptr[k + 6]);
     const auto a_val7 = float(A_ptr[k + 7]);
 
-    uchar b0 = B_ptr[3 * (k / 8) + 0];
-    uchar b1 = B_ptr[3 * (k / 8) + 1];
-    uchar b2 = B_ptr[3 * (k / 8) + 2];
+    uchar b0 = B_ptr[(k / 8)];
 
-    uchar w_val0 = ((b0 & 1) << 2) | (b1 & 3);
-    uchar w_val1 = ((b0 & 2) << 1) | ((b1 & 12) >> 2);
-    uchar w_val2 = (b0 & 4) | ((b1 & 48) >> 4);
-    uchar w_val3 = ((b0 & 8) >> 1) | ((b1 & 192) >> 6);
-
-    uchar w_val4 = ((b0 & 16) >> 2) | (b2 & 3);
-    uchar w_val5 = ((b0 & 32) >> 3) | ((b2 & 12) >> 2);
-    uchar w_val6 = ((b0 & 64) >> 4) | ((b2 & 48) >> 4);
-    uchar w_val7 = ((b0 & 128) >> 5) | ((b2 & 192) >> 6);
+    uchar w_val0 = b0 & 0x01;
+    uchar w_val1 = (b0 & 0x02) >> 1;
+    uchar w_val2 = (b0 & 0x04) >> 2;
+    uchar w_val3 = (b0 & 0x08) >> 3;
+    uchar w_val4 = (b0 & 0x10) >> 4;
+    uchar w_val5 = (b0 & 0x20) >> 5;
+    uchar w_val6 = (b0 & 0x40) >> 6;
+    uchar w_val7 = (b0 & 0x80) >> 7;
 
     rc += a_val0 * (scale * float(w_val0) + zero);
     rc += a_val1 * (scale * float(w_val1) + zero);
@@ -72,10 +69,10 @@
   outputData[m * N + n] = T(rc);
 }
 
-#define INSTANTIATE_INT3MM(DTYPE, GSIZE) \
+#define INSTANTIATE_INT1MM(DTYPE, GSIZE) \
   template \
-  [[host_name("int3pack_mm_" #GSIZE "_" #DTYPE)]] \
-  kernel void int3pack_mm<DTYPE, GSIZE>( \
+  [[host_name("int1pack_mm_" #GSIZE "_" #DTYPE)]] \
+  kernel void int1pack_mm<DTYPE, GSIZE>( \
     constant DTYPE * A [[buffer(0)]], \
     constant uchar * B [[buffer(1)]], \
    constant DTYPE * scales [[buffer(2)]], \
@@ -84,17 +81,17 @@ kernel void int3pack_mm<DTYPE, GSIZE>( \
     constant uint3 & sizes [[buffer(5)]], \
     uint2 thread_index [[thread_position_in_grid]])
 
-INSTANTIATE_INT3MM(float, 32);
-INSTANTIATE_INT3MM(half, 32);
-INSTANTIATE_INT3MM(float, 64);
-INSTANTIATE_INT3MM(half, 64);
-INSTANTIATE_INT3MM(float, 128);
-INSTANTIATE_INT3MM(half, 128);
-INSTANTIATE_INT3MM(float, 256);
-INSTANTIATE_INT3MM(half, 256);
+INSTANTIATE_INT1MM(float, 32);
+INSTANTIATE_INT1MM(half, 32);
+INSTANTIATE_INT1MM(float, 64);
+INSTANTIATE_INT1MM(half, 64);
+INSTANTIATE_INT1MM(float, 128);
+INSTANTIATE_INT1MM(half, 128);
+INSTANTIATE_INT1MM(float, 256);
+INSTANTIATE_INT1MM(half, 256);
 #if __METAL_VERSION__ >= 310
-INSTANTIATE_INT3MM(bfloat, 32);
-INSTANTIATE_INT3MM(bfloat, 64);
-INSTANTIATE_INT3MM(bfloat, 128);
-INSTANTIATE_INT3MM(bfloat, 256);
+INSTANTIATE_INT1MM(bfloat, 32);
+INSTANTIATE_INT1MM(bfloat, 64);
+INSTANTIATE_INT1MM(bfloat, 128);
+INSTANTIATE_INT1MM(bfloat, 256);
 #endif
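Note: as the doc comment above describes, each row of B packs K 1-bit weights into K / 8 bytes, least-significant bit first, so weight k lives in bit k % 8 of byte k / 8. A CPU-side sketch of a matching packer plus round-trip check (plain C++; pack_1bit is a hypothetical helper for illustration, not the repo's packing routine):

#include <cassert>
#include <vector>

// Hypothetical packer matching int1pack_mm's unpacking
// (w_val0 = b0 & 0x01, ..., w_val7 = (b0 & 0x80) >> 7).
std::vector<unsigned char> pack_1bit(const std::vector<unsigned char> &w,
                                     unsigned N, unsigned K) {
  std::vector<unsigned char> out(N * K / 8, 0);
  for (unsigned n = 0; n < N; ++n)
    for (unsigned k = 0; k < K; ++k)
      out[n * K / 8 + k / 8] |= (w[n * K + k] & 1) << (k % 8); // bit k % 8
  return out;
}

int main() {
  const unsigned N = 2, K = 16; // K must be a multiple of 8
  std::vector<unsigned char> w(N * K);
  for (unsigned i = 0; i < N * K; ++i) w[i] = i % 2; // arbitrary 0/1 weights
  auto packed = pack_1bit(w, N, K);
  // Round-trip: unpack each weight exactly the way the shader does.
  for (unsigned n = 0; n < N; ++n)
    for (unsigned k = 0; k < K; ++k) {
      unsigned char b = packed[n * K / 8 + k / 8];
      assert(((b >> (k % 8)) & 1) == w[n * K + k]);
    }
  return 0;
}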
138 changes: 138 additions & 0 deletions torchao/experimental/kernels/mps/metal/int2mm_opt.metal
@@ -0,0 +1,138 @@
#include <metal_simdgroup>
#include <metal_stdlib>
using namespace metal;

/*
  This code takes heavy inspiration from MLX:
  https://github.com/ml-explore/mlx/blob/main/mlx/backend/metal/kernels/quantized.h
  Specifically:
  - Multiplying the activation by the inverse scaling factor to reduce compute
    boundedness
  - Handling the zero point by accumulating the activation in a separate sum
    term, which is needed with the optimization above.
  MLX MIT license:
  https://github.com/ml-explore/mlx/blob/main/LICENSE
*/

/*
  @brief This shader implements 2-bit matrix-vector multiplication where the A
  matrix is fp16, bfloat or float and the B matrix is a 2-bit
  groupwise-quantized weight matrix.
  @param [in] A is the activation matrix of size M x K.
  @param [in] B is the weight matrix of size N x K. Each byte contains 4 2-bit
  values, packed together along the K dim.
  @param [in] scales_ptr points to the scales for each output channel x group.
  These are packed as [num_groups = ceil(K / group_size), N]. N = output
  channels.
  @param [in] zeros_ptr points to the zero points for each output channel x
  group. These are packed as [num_groups = ceil(K / group_size), N]. N = output
  channels.
  @param [out] output_data is the output matrix of size M x N.
  @param [in] sizes array contains the values of M, K and N.
  @param [in] thread_index is the global thread id.
  @param [in] tid_in_simdgroup is the thread id within a simdgroup, e.g. in a
  simdgroup of size 32 it lies in [0-31].
*/
template <typename T, unsigned group_size>
kernel void int2pack_mm(constant T *A [[buffer(0)]],
                        constant uchar *B [[buffer(1)]],
                        constant T *scales_ptr [[buffer(2)]],
                        constant T *zeros_ptr [[buffer(3)]],
                        device T *output_data [[buffer(4)]],
                        constant uint3 &sizes [[buffer(5)]], // M, K, N
                        uint3 thread_index [[thread_position_in_grid]],
                        uint tid_in_simdgroup [[thread_index_in_simdgroup]]) {
  constexpr uint threads_per_channel = 32;
  constexpr uint ks_per_thread = 4;
  constexpr uint k_pack_factor = 4;
  const uint K = sizes.y;
  const uint N = sizes.z;
  uint n = thread_index.x; // 0..N/4-1
  uint m = thread_index.z; // 0..M
  n = n / threads_per_channel;
  n = n * 4;
  // Starting k for each thread: with ks_per_thread = 4, thread 1 of the
  // simdgroup starts at k = 4.
  uint k = (tid_in_simdgroup % threads_per_channel) * ks_per_thread;
  constexpr int k_jump = threads_per_channel * ks_per_thread;

  using vecT = typename Vec4Type<T>::type;
  constant vecT *A_ptr = reinterpret_cast<constant vecT *>(A + m * K);
  constant uchar *B_ptr = B + ((n * K) / k_pack_factor);

  thread float4 result = float4(0.0);
  // We multiply a group of 4 activation values by these inverse scales because
  // the corresponding values from the weight matrix are effectively
  // left-shifted. This avoids right-shifting those values, which ends up
  // affecting performance. This is the trick applied in the MLX kernels.
  float4 act_div_scales = {1.f, 1 / 4.f, 1 / 16.f, 1 / 64.f};

  for (; k < K; k += k_jump) {
    // Find the specific group to which the channels handled by this thread
    // belong.
    uint k_block_index = k / group_size;
    uint scales_group_offset = (k_block_index * N + n);

    vecT scales =
        (reinterpret_cast<constant vecT *>(scales_ptr + scales_group_offset))[0];
    // Adding the zero point results in a 10% perf penalty.
    vecT zeros =
        (reinterpret_cast<constant vecT *>(zeros_ptr + scales_group_offset))[0];
    float4 zeros_float = float4(zeros);

    float4 a_val = float4(A_ptr[k / 4]);
    // We skip right-shifts of the weights and hence divide the activation by
    // the corresponding factor.
    float4 a_vec = a_val * act_div_scales;
    float a_val_sum = a_val[0] + a_val[1] + a_val[2] + a_val[3];

    float4x4 b_mat;
    ushort b_val0 = (B_ptr + (k + 0 * K) / k_pack_factor)[0];
    ushort b_val1 = (B_ptr + (k + 1 * K) / k_pack_factor)[0];
    ushort b_val2 = (B_ptr + (k + 2 * K) / k_pack_factor)[0];
    ushort b_val3 = (B_ptr + (k + 3 * K) / k_pack_factor)[0];
    b_mat[0] = scales[0] * float4(float(b_val0 & 0x03), float(b_val0 & 0x0c),
                                  float(b_val0 & 0x30), float(b_val0 & 0xc0));
    b_mat[1] = scales[1] * float4(float(b_val1 & 0x03), float(b_val1 & 0x0c),
                                  float(b_val1 & 0x30), float(b_val1 & 0xc0));
    b_mat[2] = scales[2] * float4(float(b_val2 & 0x03), float(b_val2 & 0x0c),
                                  float(b_val2 & 0x30), float(b_val2 & 0xc0));
    b_mat[3] = scales[3] * float4(float(b_val3 & 0x03), float(b_val3 & 0x0c),
                                  float(b_val3 & 0x30), float(b_val3 & 0xc0));

    result += a_vec * b_mat;
    result += a_val_sum * zeros_float;
  }
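  // Tree reduction across the 32 simdgroup lanes: after the five shuffle-down
  // steps below, the lane with tid_in_simdgroup % threads_per_channel == 0
  // holds the complete sums for its 4 output channels.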
  result += simd_shuffle_down(result, 1);
  result += simd_shuffle_down(result, 2);
  result += simd_shuffle_down(result, 4);
  result += simd_shuffle_down(result, 8);
  result += simd_shuffle_down(result, 16);
  if (tid_in_simdgroup % threads_per_channel == 0) {
    reinterpret_cast<device vecT *>(output_data + m * N)[n / 4] = vecT(result);
  }
}

#define INSTANTIATE_INT2MM(DTYPE, GSIZE)                                   \
  template [[host_name("int2pack_mm_" #GSIZE "_" #DTYPE)]] kernel void     \
  int2pack_mm<DTYPE, GSIZE>(                                               \
      constant DTYPE * A [[buffer(0)]], constant uchar * B [[buffer(1)]],  \
      constant DTYPE * scales_ptr [[buffer(2)]],                           \
      constant DTYPE * zeros_ptr [[buffer(3)]],                            \
      device DTYPE * output_data [[buffer(4)]],                            \
      constant uint3 & sizes [[buffer(5)]],                                \
      uint3 thread_index [[thread_position_in_grid]],                      \
      uint tid_in_simdgroup [[thread_index_in_simdgroup]])

INSTANTIATE_INT2MM(float, 32);
INSTANTIATE_INT2MM(half, 32);
INSTANTIATE_INT2MM(float, 64);
INSTANTIATE_INT2MM(half, 64);
INSTANTIATE_INT2MM(float, 128);
INSTANTIATE_INT2MM(half, 128);
INSTANTIATE_INT2MM(float, 256);
INSTANTIATE_INT2MM(half, 256);
#if __METAL_VERSION__ >= 310
INSTANTIATE_INT2MM(bfloat, 32);
INSTANTIATE_INT2MM(bfloat, 64);
INSTANTIATE_INT2MM(bfloat, 128);
INSTANTIATE_INT2MM(bfloat, 256);
#endif
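Note: to make the shift-free dequantization above concrete: masking the packed byte without shifting leaves the i-th 2-bit field scaled by 4^i, and pre-multiplying the activation by 1/4^i cancels that factor exactly (powers of two only change the float exponent). A hedged CPU-side check (plain C++, illustration only):

#include <cassert>
#include <cstdio>

int main() {
  // One packed byte holding the 2-bit weights 1, 3, 0, 2 (LSB first).
  unsigned char b = 0b10001101;
  float a[4] = {0.5f, -1.25f, 2.0f, 4.0f}; // activation values
  float act_div_scales[4] = {1.f, 1 / 4.f, 1 / 16.f, 1 / 64.f};

  for (int i = 0; i < 4; ++i) {
    // Conventional unpack: shift right, then mask.
    float shifted = float((b >> (2 * i)) & 0x3) * a[i];
    // Shift-free unpack used by the kernel: mask in place (the field is
    // effectively left-shifted by 2*i bits) and fold the correction into
    // the activation instead.
    float maskless = float(b & (0x3 << (2 * i))) * (a[i] * act_div_scales[i]);
    assert(shifted == maskless); // exact: power-of-two scaling is lossless
    printf("lane %d: %f == %f\n", i, shifted, maskless);
  }
  return 0;
}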