Skip to content

Commit

Permalink
metal lowbit kernels: split scales and zero points (pytorch#1202)
Browse files Browse the repository at this point in the history
Summary: Pull Request resolved: pytorch#1202

Differential Revision: D65232787
  • Loading branch information
manuelcandales authored and facebook-github-bot committed Oct 30, 2024
1 parent 4f1fc4c commit fe06a21
Show file tree
Hide file tree
Showing 9 changed files with 133 additions and 96 deletions.
25 changes: 14 additions & 11 deletions torchao/experimental/kernels/mps/metal/divbit.metal
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,8 @@ using namespace metal;
*
* @param[A] M x K unquantized input tensor of floating point dtype (Float, Half, BFloat16)
* @param[B] Packed & quantized weight tensor of uint8 dtype. Expected shape is N x (nbit * K / 8)
* @param[scalesAndZeros] 3D tensor containing the scales and zero point for each group. Expected shape is #groups x N x 2
* @param[scales] 2D tensor containing the scales for each group. Expected shape is #groups x N
* @param[zeros] 2D tensor containing the zero points for each group. Expected shape is #groups x N
* @param[outputData] M x N output tensor of floating point dtype (same as input)
* @param[sizes] The sizes involved in the order: M, K, N
*
Expand All @@ -16,9 +17,10 @@ template<typename T, unsigned nbit, unsigned groupSize>
kernel void divbit_mm(
constant T * A [[buffer(0)]],
constant uchar * B [[buffer(1)]],
constant T * scalesAndZeros [[buffer(2)]],
device T * outputData [[buffer(3)]],
constant uint3 & sizes [[buffer(4)]], // M, K, N
constant T * scales [[buffer(2)]],
constant T * zeros [[buffer(3)]],
device T * outputData [[buffer(4)]],
constant uint3 & sizes [[buffer(5)]], // M, K, N
uint2 thread_index [[thread_position_in_grid]]) {
const uint K = sizes.y;
const uint N = sizes.z;
Expand All @@ -35,29 +37,30 @@ kernel void divbit_mm(
float rc = 0.0;
uint k = 0;
for (uint32_t kb = 0; kb < k_block ; kb ++) {
const T scale = scalesAndZeros[(kb * N + n) * 2 + 0];
const T zero = scalesAndZeros[(kb * N + n) * 2 + 1] - scale * T(zero_shift);
const float scale = float(scales[kb * N + n]);
const float zero = float(zeros[kb * N + n]);
for(uint idx = 0; idx < groupSize && k < K; idx++, k++) {
const auto a_val = float(A_ptr[k]);
uint8_t b_val = B_ptr[(n * K + k) / values_per_byte];
uint8_t shift = nbit * (k % values_per_byte);
uint8_t mask = minimask << shift;
b_val = (b_val & mask) >> shift;
rc += a_val * float(scale * T(b_val) + zero);
rc += a_val * (scale * float(b_val) + zero);
}
}
outputData[m * N + n] = T(rc);
}

#define INSTANTIATE_DIVBIT_MM(NBIT, DTYPE, GSIZE) \
#define INSTANTIATE_DIVBIT_MM(NBIT, DTYPE, GSIZE) \
template \
[[host_name("int" #NBIT "pack_mm_" #GSIZE "_" #DTYPE)]] \
kernel void divbit_mm<DTYPE, NBIT, GSIZE>( \
constant DTYPE * A [[buffer(0)]], \
constant uchar * B [[buffer(1)]], \
constant DTYPE * scalesAndZeros [[buffer(2)]], \
device DTYPE * outputData [[buffer(3)]], \
constant uint3 & sizes [[buffer(4)]], \
constant DTYPE * scales [[buffer(2)]], \
constant DTYPE * zeros [[buffer(3)]], \
device DTYPE * outputData [[buffer(4)]], \
constant uint3 & sizes [[buffer(5)]], \
uint2 thread_index [[thread_position_in_grid]])

INSTANTIATE_DIVBIT_MM(1, float, 32);
Expand Down
21 changes: 12 additions & 9 deletions torchao/experimental/kernels/mps/metal/int3mm.metal
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,8 @@ using namespace metal;
*
* @param[A] M x K unquantized input tensor of floating point dtype (Float, Half, BFloat16)
* @param[B] Packed & quantized weight tensor of uint8 dtype. Expected shape is N x (3 * K / 8)
* @param[scalesAndZeros] 3D tensor containing the scales and zero point for each group. Expected shape is #groups x N x 2
* @param[scales] 2D tensor containing the scales for each group. Expected shape is #groups x N
* @param[zeros] 2D tensor containing the zero points for each group. Expected shape is #groups x N
* @param[outputData] M x N output tensor of floating point dtype (same as input)
* @param[sizes] The sizes involved in the order: M, K, N
*
Expand All @@ -16,9 +17,10 @@ template<typename T, unsigned groupSize>
kernel void int3pack_mm(
constant T * A [[buffer(0)]],
constant uchar * B [[buffer(1)]],
constant T * scalesAndZeros [[buffer(2)]],
device T * outputData [[buffer(3)]],
constant uint3 & sizes [[buffer(4)]], // M, K, N
constant T * scales [[buffer(2)]],
constant T * zeros [[buffer(3)]],
device T * outputData [[buffer(4)]],
constant uint3 & sizes [[buffer(5)]], // M, K, N
uint2 thread_index [[thread_position_in_grid]]) {
const uint K = sizes.y;
const uint N = sizes.z;
Expand All @@ -31,8 +33,8 @@ kernel void int3pack_mm(
float rc = 0.0;
uint k = 0;
for (uint32_t kb = 0; kb < k_block ; kb ++) {
const float scale = float(scalesAndZeros[(kb * N + n) * 2 + 0]);
const float zero = float(scalesAndZeros[(kb * N + n) * 2 + 1]) - scale * float(4);
const float scale = float(scales[kb * N + n]);
const float zero = float(zeros[kb * N + n]);
for(uint idx = 0; idx < groupSize && k < K; idx+=8, k+=8) {
const auto a_val0 = float(A_ptr[k + 0]);
const auto a_val1 = float(A_ptr[k + 1]);
Expand Down Expand Up @@ -76,9 +78,10 @@ template \
kernel void int3pack_mm<DTYPE, GSIZE>( \
constant DTYPE * A [[buffer(0)]], \
constant uchar * B [[buffer(1)]], \
constant DTYPE * scalesAndZeros [[buffer(2)]], \
device DTYPE * outputData [[buffer(3)]], \
constant uint3 & sizes [[buffer(4)]], \
constant DTYPE * scales [[buffer(2)]], \
constant DTYPE * zeros [[buffer(3)]], \
device DTYPE * outputData [[buffer(4)]], \
constant uint3 & sizes [[buffer(5)]], \
uint2 thread_index [[thread_position_in_grid]])

INSTANTIATE_INT3MM(float, 32);
Expand Down
21 changes: 12 additions & 9 deletions torchao/experimental/kernels/mps/metal/int5mm.metal
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,8 @@ using namespace metal;
*
* @param[A] M x K unquantized input tensor of floating point dtype (Float, Half, BFloat16)
* @param[B] Packed & quantized weight tensor of uint8 dtype. Expected shape is N x (5 * K / 8)
* @param[scalesAndZeros] 3D tensor containing the scales and zero point for each group. Expected shape is #groups x N x 2
* @param[scales] 2D tensor containing the scales for each group. Expected shape is #groups x N
* @param[zeros] 2D tensor containing the zero points for each group. Expected shape is #groups x N
* @param[outputData] M x N output tensor of floating point dtype (same as input)
* @param[sizes] The sizes involved in the order: M, K, N
*
Expand All @@ -16,9 +17,10 @@ template<typename T, unsigned groupSize>
kernel void int5pack_mm(
constant T * A [[buffer(0)]],
constant uchar * B [[buffer(1)]],
constant T * scalesAndZeros [[buffer(2)]],
device T * outputData [[buffer(3)]],
constant uint3 & sizes [[buffer(4)]], // M, K, N
constant T * scales [[buffer(2)]],
constant T * zeros [[buffer(3)]],
device T * outputData [[buffer(4)]],
constant uint3 & sizes [[buffer(5)]], // M, K, N
uint2 thread_index [[thread_position_in_grid]]) {
const uint K = sizes.y;
const uint N = sizes.z;
Expand All @@ -31,8 +33,8 @@ kernel void int5pack_mm(
float rc = 0.0;
uint k = 0;
for (uint32_t kb = 0; kb < k_block ; kb ++) {
const float scale = float(scalesAndZeros[(kb * N + n) * 2 + 0]);
const float zero = float(scalesAndZeros[(kb * N + n) * 2 + 1]) - scale * float(16);
const float scale = float(scales[kb * N + n]);
const float zero = float(zeros[kb * N + n]);
for(uint idx = 0; idx < groupSize && k < K; idx+=8, k+=8) {
const auto a_val0 = float(A_ptr[k + 0]);
const auto a_val1 = float(A_ptr[k + 1]);
Expand Down Expand Up @@ -78,9 +80,10 @@ template \
kernel void int5pack_mm<DTYPE, GSIZE>( \
constant DTYPE * A [[buffer(0)]], \
constant uchar * B [[buffer(1)]], \
constant DTYPE * scalesAndZeros [[buffer(2)]], \
device DTYPE * outputData [[buffer(3)]], \
constant uint3 & sizes [[buffer(4)]], \
constant DTYPE * scales [[buffer(2)]], \
constant DTYPE * zeros [[buffer(3)]], \
device DTYPE * outputData [[buffer(4)]], \
constant uint3 & sizes [[buffer(5)]], \
uint2 thread_index [[thread_position_in_grid]])

INSTANTIATE_INT5MM(float, 32);
Expand Down
21 changes: 12 additions & 9 deletions torchao/experimental/kernels/mps/metal/int6mm.metal
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,8 @@ using namespace metal;
*
* @param[A] M x K unquantized input tensor of floating point dtype (Float, Half, BFloat16)
* @param[B] Packed & quantized weight tensor of uint8 dtype. Expected shape is N x (6 * K / 8)
* @param[scalesAndZeros] 3D tensor containing the scales and zero point for each group. Expected shape is #groups x N x 2
* @param[scales] 2D tensor containing the scales for each group. Expected shape is #groups x N
* @param[zeros] 2D tensor containing the zero points for each group. Expected shape is #groups x N
* @param[outputData] M x N output tensor of floating point dtype (same as input)
* @param[sizes] The sizes involved in the order: M, K, N
*
Expand All @@ -16,9 +17,10 @@ template<typename T, unsigned groupSize>
kernel void int6pack_mm(
constant T * A [[buffer(0)]],
constant uchar * B [[buffer(1)]],
constant T * scalesAndZeros [[buffer(2)]],
device T * outputData [[buffer(3)]],
constant uint3 & sizes [[buffer(4)]], // M, K, N
constant T * scales [[buffer(2)]],
constant T * zeros [[buffer(3)]],
device T * outputData [[buffer(4)]],
constant uint3 & sizes [[buffer(5)]], // M, K, N
uint2 thread_index [[thread_position_in_grid]]) {
const uint K = sizes.y;
const uint N = sizes.z;
Expand All @@ -31,8 +33,8 @@ kernel void int6pack_mm(
float rc = 0.0;
uint k = 0;
for (uint32_t kb = 0; kb < k_block ; kb ++) {
const float scale = float(scalesAndZeros[(kb * N + n) * 2 + 0]);
const float zero = float(scalesAndZeros[(kb * N + n) * 2 + 1]) - scale * float(32);
const float scale = float(scales[kb * N + n]);
const float zero = float(zeros[kb * N + n]);
for(uint idx = 0; idx < groupSize && k < K; idx+=4, k+=4) {
const auto a_val0 = float(A_ptr[k + 0]);
const auto a_val1 = float(A_ptr[k + 1]);
Expand Down Expand Up @@ -63,9 +65,10 @@ template \
kernel void int6pack_mm<DTYPE, GSIZE>( \
constant DTYPE * A [[buffer(0)]], \
constant uchar * B [[buffer(1)]], \
constant DTYPE * scalesAndZeros [[buffer(2)]], \
device DTYPE * outputData [[buffer(3)]], \
constant uint3 & sizes [[buffer(4)]], \
constant DTYPE * scales [[buffer(2)]], \
constant DTYPE * zeros [[buffer(3)]], \
device DTYPE * outputData [[buffer(4)]], \
constant uint3 & sizes [[buffer(5)]], \
uint2 thread_index [[thread_position_in_grid]])

INSTANTIATE_INT6MM(float, 32);
Expand Down
21 changes: 12 additions & 9 deletions torchao/experimental/kernels/mps/metal/int7mm.metal
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,8 @@ using namespace metal;
*
* @param[A] M x K unquantized input tensor of floating point dtype (Float, Half, BFloat16)
* @param[B] Packed & quantized weight tensor of uint8 dtype. Expected shape is N x (7 * K / 8)
* @param[scalesAndZeros] 3D tensor containing the scales and zero point for each group. Expected shape is #groups x N x 2
* @param[scales] 2D tensor containing the scales for each group. Expected shape is #groups x N
* @param[zeros] 2D tensor containing the zero points for each group. Expected shape is #groups x N
* @param[outputData] M x N output tensor of floating point dtype (same as input)
* @param[sizes] The sizes involved in the order: M, K, N
*
Expand All @@ -16,9 +17,10 @@ template<typename T, unsigned groupSize>
kernel void int7pack_mm(
constant T * A [[buffer(0)]],
constant uchar * B [[buffer(1)]],
constant T * scalesAndZeros [[buffer(2)]],
device T * outputData [[buffer(3)]],
constant uint3 & sizes [[buffer(4)]], // M, K, N
constant T * scales [[buffer(2)]],
constant T * zeros [[buffer(3)]],
device T * outputData [[buffer(4)]],
constant uint3 & sizes [[buffer(5)]], // M, K, N
uint2 thread_index [[thread_position_in_grid]]) {
const uint K = sizes.y;
const uint N = sizes.z;
Expand All @@ -31,8 +33,8 @@ kernel void int7pack_mm(
float rc = 0.0;
uint k = 0;
for (uint32_t kb = 0; kb < k_block ; kb ++) {
const float scale = float(scalesAndZeros[(kb * N + n) * 2 + 0]);
const float zero = float(scalesAndZeros[(kb * N + n) * 2 + 1]) - scale * float(64);
const float scale = float(scales[kb * N + n]);
const float zero = float(zeros[kb * N + n]);
for(uint idx = 0; idx < groupSize && k < K; idx+=8, k+=8) {
const auto a_val0 = float(A_ptr[k + 0]);
const auto a_val1 = float(A_ptr[k + 1]);
Expand Down Expand Up @@ -80,9 +82,10 @@ template \
kernel void int7pack_mm<DTYPE, GSIZE>( \
constant DTYPE * A [[buffer(0)]], \
constant uchar * B [[buffer(1)]], \
constant DTYPE * scalesAndZeros [[buffer(2)]], \
device DTYPE * outputData [[buffer(3)]], \
constant uint3 & sizes [[buffer(4)]], \
constant DTYPE * scales [[buffer(2)]], \
constant DTYPE * zeros [[buffer(3)]], \
device DTYPE * outputData [[buffer(4)]], \
constant uint3 & sizes [[buffer(5)]], \
uint2 thread_index [[thread_position_in_grid]])

INSTANTIATE_INT7MM(float, 32);
Expand Down
16 changes: 10 additions & 6 deletions torchao/experimental/kernels/mps/src/lowbit.h
Original file line number Diff line number Diff line change
Expand Up @@ -88,7 +88,8 @@ using DispatchFn =
inline void linear_lowbit_quant_weights_mps_impl(
id<MTLBuffer> a_buf,
id<MTLBuffer> b_buf,
id<MTLBuffer> sz_buf,
id<MTLBuffer> s_buf,
id<MTLBuffer> z_buf,
id<MTLBuffer> out_buf,
int32_t M,
int32_t K,
Expand All @@ -111,11 +112,12 @@ inline void linear_lowbit_quant_weights_mps_impl(
[computeEncoder setComputePipelineState:cpl];
[computeEncoder setBuffer:a_buf offset:0 atIndex:0];
[computeEncoder setBuffer:b_buf offset:0 atIndex:1];
[computeEncoder setBuffer:sz_buf offset:0 atIndex:2];
[computeEncoder setBuffer:out_buf offset:0 atIndex:3];
[computeEncoder setBuffer:s_buf offset:0 atIndex:2];
[computeEncoder setBuffer:z_buf offset:0 atIndex:3];
[computeEncoder setBuffer:out_buf offset:0 atIndex:4];
[computeEncoder setBytes:sizes.data()
length:sizeof(uint32_t) * sizes.size()
atIndex:4];
atIndex:5];
dispatch_fn(computeEncoder, maxThreadsPerGroup, M, N, K);
finalize_block(mpsStream);
}
Expand All @@ -128,7 +130,8 @@ void linear_lowbit_quant_weights_mps(
id<MTLBuffer> a_buf,
id<MTLBuffer> b_buf,
int64_t qGroupSize,
id<MTLBuffer> sz_buf,
id<MTLBuffer> s_buf,
id<MTLBuffer> z_buf,
id<MTLBuffer> out_buf,
int32_t M,
int32_t K,
Expand All @@ -143,7 +146,8 @@ void linear_lowbit_quant_weights_mps(
return linear_lowbit_quant_weights_mps_impl(
a_buf,
b_buf,
sz_buf,
s_buf,
z_buf,
out_buf,
M,
K,
Expand Down
Loading

0 comments on commit fe06a21

Please sign in to comment.