From e4211da395c1b69aba03ffc4cadb6b8a44bdcd4f Mon Sep 17 00:00:00 2001
From: Manuel Candales
Date: Mon, 11 Nov 2024 11:26:48 -0800
Subject: [PATCH] metal lowbit kernels: split scales and zero points (#1202)

Summary:
Pull Request resolved: https://github.com/pytorch/ao/pull/1202

Reviewed By: malfet

Differential Revision: D65232787
---
 .../kernels/mps/metal/divbit.metal            | 25 ++++++-----
 .../kernels/mps/metal/int3mm.metal            | 21 ++++++----
 .../kernels/mps/metal/int5mm.metal            | 21 ++++++----
 .../kernels/mps/metal/int6mm.metal            | 21 ++++++----
 .../kernels/mps/metal/int7mm.metal            | 21 ++++++----
 torchao/experimental/kernels/mps/src/lowbit.h | 16 ++++---
 .../kernels/mps/test/test_lowbit.mm           | 42 ++++++++++++-------
 torchao/experimental/ops/mps/register.mm      | 37 ++++++++++------
 .../experimental/ops/mps/test/test_lowbit.py  | 27 ++++++------
 9 files changed, 134 insertions(+), 97 deletions(-)

diff --git a/torchao/experimental/kernels/mps/metal/divbit.metal b/torchao/experimental/kernels/mps/metal/divbit.metal
index 68f5f7dc03..5c8b146643 100644
--- a/torchao/experimental/kernels/mps/metal/divbit.metal
+++ b/torchao/experimental/kernels/mps/metal/divbit.metal
@@ -6,7 +6,8 @@ using namespace metal;
  *
  * @param[A] M x K unquantized input tensor of floating point dtype (Float, Half, BFloat16)
  * @param[B] Packed & quantized weight tensor of uint8 dtype. Expected shape is N x (nbit * K / 8)
- * @param[scalesAndZeros] 3D tensor containing the scales and zero point for each group. Expected shape is #groups x N x 2
+ * @param[scales] 2D tensor containing the scales for each group. Expected shape is #groups x N
+ * @param[zeros] 2D tensor containing the zero points for each group. Expected shape is #groups x N
  * @param[outputData] M x N output tensor of floating point dtype (same as input)
  * @param[sizes] The sizes involved in the order: M, K, N
  *
@@ -16,9 +17,10 @@ template <typename T, unsigned nbit, unsigned groupSize>
 kernel void divbit_mm(
     constant T * A [[buffer(0)]],
     constant uchar * B [[buffer(1)]],
-    constant T * scalesAndZeros [[buffer(2)]],
-    device T * outputData [[buffer(3)]],
-    constant uint3 & sizes [[buffer(4)]], // M, K, N
+    constant T * scales [[buffer(2)]],
+    constant T * zeros [[buffer(3)]],
+    device T * outputData [[buffer(4)]],
+    constant uint3 & sizes [[buffer(5)]], // M, K, N
     uint2 thread_index [[thread_position_in_grid]]) {
   const uint K = sizes.y;
   const uint N = sizes.z;
@@ -35,29 +37,30 @@ kernel void divbit_mm(
   float rc = 0.0;
   uint k = 0;
   for (uint32_t kb = 0; kb < k_block ; kb ++) {
-    const T scale = scalesAndZeros[(kb * N + n) * 2 + 0];
-    const T zero = scalesAndZeros[(kb * N + n) * 2 + 1] - scale * T(zero_shift);
+    const float scale = float(scales[kb * N + n]);
+    const float zero = float(zeros[kb * N + n]);
     for(uint idx = 0; idx < groupSize && k < K; idx++, k++) {
       const auto a_val = float(A_ptr[k]);
       uint8_t b_val = B_ptr[(n * K + k) / values_per_byte];
       uint8_t shift = nbit * (k % values_per_byte);
       uint8_t mask = minimask << shift;
       b_val = (b_val & mask) >> shift;
-      rc += a_val * float(scale * T(b_val) + zero);
+      rc += a_val * (scale * float(b_val) + zero);
     }
   }
   outputData[m * N + n] = T(rc);
 }
 
-#define INSTANTIATE_DIVBIT_MM(NBIT, DTYPE, GSIZE)                \
+#define INSTANTIATE_DIVBIT_MM(NBIT, DTYPE, GSIZE)                 \
   template                                                       \
   [[host_name("int" #NBIT "pack_mm_" #GSIZE "_" #DTYPE)]]        \
   kernel void divbit_mm<DTYPE, NBIT, GSIZE>(                     \
     constant DTYPE * A [[buffer(0)]],                            \
     constant uchar * B [[buffer(1)]],                            \
-    constant DTYPE * scalesAndZeros [[buffer(2)]],               \
-    device DTYPE * outputData [[buffer(3)]],                     \
-    constant uint3 & sizes [[buffer(4)]],                        \
+    constant DTYPE * scales [[buffer(2)]],                       \
+    constant DTYPE * zeros [[buffer(3)]],                        \
+    device DTYPE * outputData [[buffer(4)]],                    \
+    constant uint3 & sizes [[buffer(5)]],                       \
     uint2 thread_index [[thread_position_in_grid]])
 
 INSTANTIATE_DIVBIT_MM(1, float, 32);
diff --git a/torchao/experimental/kernels/mps/metal/int3mm.metal b/torchao/experimental/kernels/mps/metal/int3mm.metal
index 8fd68cd768..4a44345b83 100644
--- a/torchao/experimental/kernels/mps/metal/int3mm.metal
+++ b/torchao/experimental/kernels/mps/metal/int3mm.metal
@@ -6,7 +6,8 @@ using namespace metal;
  *
  * @param[A] M x K unquantized input tensor of floating point dtype (Float, Half, BFloat16)
  * @param[B] Packed & quantized weight tensor of uint8 dtype. Expected shape is N x (3 * K / 8)
- * @param[scalesAndZeros] 3D tensor containing the scales and zero point for each group. Expected shape is #groups x N x 2
+ * @param[scales] 2D tensor containing the scales for each group. Expected shape is #groups x N
+ * @param[zeros] 2D tensor containing the zero points for each group. Expected shape is #groups x N
  * @param[outputData] M x N output tensor of floating point dtype (same as input)
  * @param[sizes] The sizes involved in the order: M, K, N
  *
@@ -16,9 +17,10 @@ template <typename T, unsigned groupSize>
 kernel void int3pack_mm(
     constant T * A [[buffer(0)]],
     constant uchar * B [[buffer(1)]],
-    constant T * scalesAndZeros [[buffer(2)]],
-    device T * outputData [[buffer(3)]],
-    constant uint3 & sizes [[buffer(4)]], // M, K, N
+    constant T * scales [[buffer(2)]],
+    constant T * zeros [[buffer(3)]],
+    device T * outputData [[buffer(4)]],
+    constant uint3 & sizes [[buffer(5)]], // M, K, N
     uint2 thread_index [[thread_position_in_grid]]) {
   const uint K = sizes.y;
   const uint N = sizes.z;
@@ -31,8 +33,8 @@ kernel void int3pack_mm(
   float rc = 0.0;
   uint k = 0;
   for (uint32_t kb = 0; kb < k_block ; kb ++) {
-    const float scale = float(scalesAndZeros[(kb * N + n) * 2 + 0]);
-    const float zero = float(scalesAndZeros[(kb * N + n) * 2 + 1]) - scale * float(4);
+    const float scale = float(scales[kb * N + n]);
+    const float zero = float(zeros[kb * N + n]);
     for(uint idx = 0; idx < groupSize && k < K; idx+=8, k+=8) {
       const auto a_val0 = float(A_ptr[k + 0]);
       const auto a_val1 = float(A_ptr[k + 1]);
@@ -76,9 +78,10 @@ template
 kernel void int3pack_mm<DTYPE, GSIZE>(                           \
     constant DTYPE * A [[buffer(0)]],                            \
     constant uchar * B [[buffer(1)]],                            \
-    constant DTYPE * scalesAndZeros [[buffer(2)]],               \
-    device DTYPE * outputData [[buffer(3)]],                     \
-    constant uint3 & sizes [[buffer(4)]],                        \
+    constant DTYPE * scales [[buffer(2)]],                       \
+    constant DTYPE * zeros [[buffer(3)]],                        \
+    device DTYPE * outputData [[buffer(4)]],                     \
+    constant uint3 & sizes [[buffer(5)]],                        \
     uint2 thread_index [[thread_position_in_grid]])
 
 INSTANTIATE_INT3MM(float, 32);
diff --git a/torchao/experimental/kernels/mps/metal/int5mm.metal b/torchao/experimental/kernels/mps/metal/int5mm.metal
index 84aba20725..d854b7f90e 100644
--- a/torchao/experimental/kernels/mps/metal/int5mm.metal
+++ b/torchao/experimental/kernels/mps/metal/int5mm.metal
@@ -6,7 +6,8 @@ using namespace metal;
  *
  * @param[A] M x K unquantized input tensor of floating point dtype (Float, Half, BFloat16)
  * @param[B] Packed & quantized weight tensor of uint8 dtype. Expected shape is N x (5 * K / 8)
- * @param[scalesAndZeros] 3D tensor containing the scales and zero point for each group. Expected shape is #groups x N x 2
+ * @param[scales] 2D tensor containing the scales for each group. Expected shape is #groups x N
+ * @param[zeros] 2D tensor containing the zero points for each group. Expected shape is #groups x N
  * @param[outputData] M x N output tensor of floating point dtype (same as input)
  * @param[sizes] The sizes involved in the order: M, K, N
  *
@@ -16,9 +17,10 @@ template <typename T, unsigned groupSize>
 kernel void int5pack_mm(
     constant T * A [[buffer(0)]],
     constant uchar * B [[buffer(1)]],
-    constant T * scalesAndZeros [[buffer(2)]],
-    device T * outputData [[buffer(3)]],
-    constant uint3 & sizes [[buffer(4)]], // M, K, N
+    constant T * scales [[buffer(2)]],
+    constant T * zeros [[buffer(3)]],
+    device T * outputData [[buffer(4)]],
+    constant uint3 & sizes [[buffer(5)]], // M, K, N
     uint2 thread_index [[thread_position_in_grid]]) {
   const uint K = sizes.y;
   const uint N = sizes.z;
@@ -31,8 +33,8 @@ kernel void int5pack_mm(
   float rc = 0.0;
   uint k = 0;
   for (uint32_t kb = 0; kb < k_block ; kb ++) {
-    const float scale = float(scalesAndZeros[(kb * N + n) * 2 + 0]);
-    const float zero = float(scalesAndZeros[(kb * N + n) * 2 + 1]) - scale * float(16);
+    const float scale = float(scales[kb * N + n]);
+    const float zero = float(zeros[kb * N + n]);
     for(uint idx = 0; idx < groupSize && k < K; idx+=8, k+=8) {
       const auto a_val0 = float(A_ptr[k + 0]);
       const auto a_val1 = float(A_ptr[k + 1]);
@@ -78,9 +80,10 @@ template
 kernel void int5pack_mm<DTYPE, GSIZE>(                           \
     constant DTYPE * A [[buffer(0)]],                            \
     constant uchar * B [[buffer(1)]],                            \
-    constant DTYPE * scalesAndZeros [[buffer(2)]],               \
-    device DTYPE * outputData [[buffer(3)]],                     \
-    constant uint3 & sizes [[buffer(4)]],                        \
+    constant DTYPE * scales [[buffer(2)]],                       \
+    constant DTYPE * zeros [[buffer(3)]],                        \
+    device DTYPE * outputData [[buffer(4)]],                     \
+    constant uint3 & sizes [[buffer(5)]],                        \
     uint2 thread_index [[thread_position_in_grid]])
 
 INSTANTIATE_INT5MM(float, 32);
diff --git a/torchao/experimental/kernels/mps/metal/int6mm.metal b/torchao/experimental/kernels/mps/metal/int6mm.metal
index 7b99b749e7..a43f5c0e0a 100644
--- a/torchao/experimental/kernels/mps/metal/int6mm.metal
+++ b/torchao/experimental/kernels/mps/metal/int6mm.metal
@@ -6,7 +6,8 @@ using namespace metal;
  *
  * @param[A] M x K unquantized input tensor of floating point dtype (Float, Half, BFloat16)
  * @param[B] Packed & quantized weight tensor of uint8 dtype. Expected shape is N x (6 * K / 8)
- * @param[scalesAndZeros] 3D tensor containing the scales and zero point for each group. Expected shape is #groups x N x 2
+ * @param[scales] 2D tensor containing the scales for each group. Expected shape is #groups x N
+ * @param[zeros] 2D tensor containing the zero points for each group. Expected shape is #groups x N
  * @param[outputData] M x N output tensor of floating point dtype (same as input)
  * @param[sizes] The sizes involved in the order: M, K, N
  *
@@ -16,9 +17,10 @@ template <typename T, unsigned groupSize>
 kernel void int6pack_mm(
     constant T * A [[buffer(0)]],
     constant uchar * B [[buffer(1)]],
-    constant T * scalesAndZeros [[buffer(2)]],
-    device T * outputData [[buffer(3)]],
-    constant uint3 & sizes [[buffer(4)]], // M, K, N
+    constant T * scales [[buffer(2)]],
+    constant T * zeros [[buffer(3)]],
+    device T * outputData [[buffer(4)]],
+    constant uint3 & sizes [[buffer(5)]], // M, K, N
     uint2 thread_index [[thread_position_in_grid]]) {
   const uint K = sizes.y;
   const uint N = sizes.z;
@@ -31,8 +33,8 @@ kernel void int6pack_mm(
   float rc = 0.0;
   uint k = 0;
   for (uint32_t kb = 0; kb < k_block ; kb ++) {
-    const float scale = float(scalesAndZeros[(kb * N + n) * 2 + 0]);
-    const float zero = float(scalesAndZeros[(kb * N + n) * 2 + 1]) - scale * float(32);
+    const float scale = float(scales[kb * N + n]);
+    const float zero = float(zeros[kb * N + n]);
     for(uint idx = 0; idx < groupSize && k < K; idx+=4, k+=4) {
       const auto a_val0 = float(A_ptr[k + 0]);
       const auto a_val1 = float(A_ptr[k + 1]);
@@ -63,9 +65,10 @@ template
 kernel void int6pack_mm<DTYPE, GSIZE>(                           \
     constant DTYPE * A [[buffer(0)]],                            \
     constant uchar * B [[buffer(1)]],                            \
-    constant DTYPE * scalesAndZeros [[buffer(2)]],               \
-    device DTYPE * outputData [[buffer(3)]],                     \
-    constant uint3 & sizes [[buffer(4)]],                        \
+    constant DTYPE * scales [[buffer(2)]],                       \
+    constant DTYPE * zeros [[buffer(3)]],                        \
+    device DTYPE * outputData [[buffer(4)]],                     \
+    constant uint3 & sizes [[buffer(5)]],                        \
     uint2 thread_index [[thread_position_in_grid]])
 
 INSTANTIATE_INT6MM(float, 32);
diff --git a/torchao/experimental/kernels/mps/metal/int7mm.metal b/torchao/experimental/kernels/mps/metal/int7mm.metal
index bcd03f50f7..57c74402d9 100644
--- a/torchao/experimental/kernels/mps/metal/int7mm.metal
+++ b/torchao/experimental/kernels/mps/metal/int7mm.metal
@@ -6,7 +6,8 @@ using namespace metal;
  *
  * @param[A] M x K unquantized input tensor of floating point dtype (Float, Half, BFloat16)
  * @param[B] Packed & quantized weight tensor of uint8 dtype. Expected shape is N x (7 * K / 8)
- * @param[scalesAndZeros] 3D tensor containing the scales and zero point for each group. Expected shape is #groups x N x 2
+ * @param[scales] 2D tensor containing the scales for each group. Expected shape is #groups x N
+ * @param[zeros] 2D tensor containing the zero points for each group. Expected shape is #groups x N
  * @param[outputData] M x N output tensor of floating point dtype (same as input)
  * @param[sizes] The sizes involved in the order: M, K, N
  *
@@ -16,9 +17,10 @@ template <typename T, unsigned groupSize>
 kernel void int7pack_mm(
     constant T * A [[buffer(0)]],
     constant uchar * B [[buffer(1)]],
-    constant T * scalesAndZeros [[buffer(2)]],
-    device T * outputData [[buffer(3)]],
-    constant uint3 & sizes [[buffer(4)]], // M, K, N
+    constant T * scales [[buffer(2)]],
+    constant T * zeros [[buffer(3)]],
+    device T * outputData [[buffer(4)]],
+    constant uint3 & sizes [[buffer(5)]], // M, K, N
     uint2 thread_index [[thread_position_in_grid]]) {
   const uint K = sizes.y;
   const uint N = sizes.z;
@@ -31,8 +33,8 @@ kernel void int7pack_mm(
   float rc = 0.0;
   uint k = 0;
   for (uint32_t kb = 0; kb < k_block ; kb ++) {
-    const float scale = float(scalesAndZeros[(kb * N + n) * 2 + 0]);
-    const float zero = float(scalesAndZeros[(kb * N + n) * 2 + 1]) - scale * float(64);
+    const float scale = float(scales[kb * N + n]);
+    const float zero = float(zeros[kb * N + n]);
     for(uint idx = 0; idx < groupSize && k < K; idx+=8, k+=8) {
       const auto a_val0 = float(A_ptr[k + 0]);
       const auto a_val1 = float(A_ptr[k + 1]);
@@ -80,9 +82,10 @@ template
 kernel void int7pack_mm<DTYPE, GSIZE>(                           \
     constant DTYPE * A [[buffer(0)]],                            \
     constant uchar * B [[buffer(1)]],                            \
-    constant DTYPE * scalesAndZeros [[buffer(2)]],               \
-    device DTYPE * outputData [[buffer(3)]],                     \
-    constant uint3 & sizes [[buffer(4)]],                        \
+    constant DTYPE * scales [[buffer(2)]],                       \
+    constant DTYPE * zeros [[buffer(3)]],                        \
+    device DTYPE * outputData [[buffer(4)]],                     \
+    constant uint3 & sizes [[buffer(5)]],                        \
     uint2 thread_index [[thread_position_in_grid]])
 
 INSTANTIATE_INT7MM(float, 32);
diff --git a/torchao/experimental/kernels/mps/src/lowbit.h b/torchao/experimental/kernels/mps/src/lowbit.h
index 8aed682afc..d10d00c284 100644
--- a/torchao/experimental/kernels/mps/src/lowbit.h
+++ b/torchao/experimental/kernels/mps/src/lowbit.h
@@ -88,7 +88,8 @@ using DispatchFn =
 inline void linear_lowbit_quant_weights_mps_impl(
     id<MTLBuffer> a_buf,
     id<MTLBuffer> b_buf,
-    id<MTLBuffer> sz_buf,
+    id<MTLBuffer> s_buf,
+    id<MTLBuffer> z_buf,
     id<MTLBuffer> out_buf,
     int32_t M,
     int32_t K,
@@ -111,11 +112,12 @@ inline void linear_lowbit_quant_weights_mps_impl(
   [computeEncoder setComputePipelineState:cpl];
   [computeEncoder setBuffer:a_buf offset:0 atIndex:0];
   [computeEncoder setBuffer:b_buf offset:0 atIndex:1];
-  [computeEncoder setBuffer:sz_buf offset:0 atIndex:2];
-  [computeEncoder setBuffer:out_buf offset:0 atIndex:3];
+  [computeEncoder setBuffer:s_buf offset:0 atIndex:2];
+  [computeEncoder setBuffer:z_buf offset:0 atIndex:3];
+  [computeEncoder setBuffer:out_buf offset:0 atIndex:4];
   [computeEncoder setBytes:sizes.data()
                     length:sizeof(uint32_t) * sizes.size()
-                   atIndex:4];
+                   atIndex:5];
   dispatch_fn(computeEncoder, maxThreadsPerGroup, M, N, K);
   finalize_block(mpsStream);
 }
@@ -128,7 +130,8 @@ void linear_lowbit_quant_weights_mps(
     id<MTLBuffer> a_buf,
     id<MTLBuffer> b_buf,
     int64_t qGroupSize,
-    id<MTLBuffer> sz_buf,
+    id<MTLBuffer> s_buf,
+    id<MTLBuffer> z_buf,
     id<MTLBuffer> out_buf,
     int32_t M,
     int32_t K,
@@ -143,7 +146,8 @@ void linear_lowbit_quant_weights_mps(
   return linear_lowbit_quant_weights_mps_impl(
       a_buf,
       b_buf,
-      sz_buf,
+      s_buf,
+      z_buf,
       out_buf,
       M,
       K,
diff --git a/torchao/experimental/kernels/mps/test/test_lowbit.mm b/torchao/experimental/kernels/mps/test/test_lowbit.mm
index e9e4eb9025..398af237ae 100644
--- a/torchao/experimental/kernels/mps/test/test_lowbit.mm
+++ b/torchao/experimental/kernels/mps/test/test_lowbit.mm
@@ -45,14 +45,12 @@ void reference_linear_lowbit_quant_weights_cpu(
     const T* a_ptr,
     const uint8_t* w_ptr,
     int64_t group_size,
-    const T* sz_ptr,
+    const T* s_ptr,
+    const T* z_ptr,
     T* out_ptr,
     int32_t M,
     int32_t K,
-    int32_t N,
-    int64_t nbit) {
-  uint8_t zero_shift = 1 << (nbit - 1);
-
+    int32_t N) {
   for (int32_t m = 0; m < M; m++) {
     for (int32_t n = 0; n < N; n++) {
       const int32_t k_block = (K + group_size - 1) / group_size;
@@ -61,9 +59,8 @@ void reference_linear_lowbit_quant_weights_cpu(
       float rc = 0.0;
      int32_t k = 0;
      for (int32_t kb = 0; kb < k_block; kb++) {
-        const float scale = float(sz_ptr[(kb * N + n) * 2 + 0]);
-        const float zero =
-            float(sz_ptr[(kb * N + n) * 2 + 1]) - scale * float(zero_shift);
+        const float scale = float(s_ptr[kb * N + n]);
+        const float zero = float(z_ptr[kb * N + n]);
         for (int32_t idx = 0; idx < group_size && k < K; idx++, k++) {
           const auto a_val = float(A_ptr[k]);
           uint8_t w_val = w_ptr[n * K + k];
@@ -88,7 +85,8 @@ void init() {
   T* a_ptr = reinterpret_cast<T*>([buf_A contents]);
   uint8_t* w_ptr = reinterpret_cast<uint8_t*>([buf_W contents]);
   T* c_ptr = reinterpret_cast<T*>([buf_C contents]);
-  T* s_ptr = reinterpret_cast<T*>([buf_SZ contents]);
+  T* s_ptr = reinterpret_cast<T*>([buf_S contents]);
+  T* z_ptr = reinterpret_cast<T*>([buf_Z contents]);
   std::random_device rd;
   std::mt19937 generator(rd());
   std::uniform_int_distribution<> int_distrib(0, (1 << nbit) - 1);
@@ -102,8 +100,8 @@ void init() {
   }
   int32_t ceil_K_group_size = (K + qGroupSize - 1) / qGroupSize;
   for (int idx = 0; idx < N * ceil_K_group_size; ++idx) {
-    s_ptr[2 * idx] = (idx + 1.0) / N;
-    s_ptr[2 * idx + 1] = 0;
+    s_ptr[idx] = (idx + 1.0) / N;
+    z_ptr[idx] = int_distrib(generator);
   }
   for (int idx = 0; idx < M * N; ++idx) {
     c_ptr[idx] = -1.0;
@@ -118,19 +116,29 @@ void pack() {
 
 void linear() {
   LowBitQuantWeights<T, nbit>::linear(
-      buf_A, buf_B, qGroupSize, buf_SZ, buf_C, M, K, N, type_string());
+      buf_A,
+      buf_B,
+      qGroupSize,
+      buf_S,
+      buf_Z,
+      buf_C,
+      M,
+      K,
+      N,
+      type_string());
 }
 
 bool validate(float atol_lim = 5e-3, float rtol_lim = 5e-3) const {
   T* a_ptr = reinterpret_cast<T*>([buf_A contents]);
   uint8_t* w_ptr = reinterpret_cast<uint8_t*>([buf_W contents]);
   T* c_ptr = reinterpret_cast<T*>([buf_C contents]);
-  T* sz_ptr = reinterpret_cast<T*>([buf_SZ contents]);
+  T* s_ptr = reinterpret_cast<T*>([buf_S contents]);
+  T* z_ptr = reinterpret_cast<T*>([buf_Z contents]);
   char* e_ptr_f = new char[M * N * sizeof(T)]; // expected
   T* e_ptr = reinterpret_cast<T*>(e_ptr_f);
   reference_linear_lowbit_quant_weights_cpu(
-      a_ptr, w_ptr, qGroupSize, sz_ptr, e_ptr, M, K, N, nbit);
+      a_ptr, w_ptr, qGroupSize, s_ptr, z_ptr, e_ptr, M, K, N);
 
   for (int m = 0; m < M; m++) {
     for (int n = 0; n < N; n++) {
@@ -159,7 +167,8 @@ void allocBuffers(id<MTLDevice> device) {
     buf_W = allocSharedBuffer(device, N * K);
     buf_B = allocSharedBuffer(device, N * nbit * K / 8);
     buf_C = allocSharedBuffer(device, M * N * elem_size);
-    buf_SZ = allocSharedBuffer(device, N * ceil_K_group_size * 2 * elem_size);
+    buf_S = allocSharedBuffer(device, N * ceil_K_group_size * elem_size);
+    buf_Z = allocSharedBuffer(device, N * ceil_K_group_size * elem_size);
   }
 
  public:
@@ -169,7 +178,8 @@ void allocBuffers(id<MTLDevice> device) {
   id<MTLBuffer> buf_W; // NxK elements
   id<MTLBuffer> buf_B; // NxK elements (packed)
   id<MTLBuffer> buf_C; // MxN elements
-  id<MTLBuffer> buf_SZ; // (K/group_size)xNx2 elements
+  id<MTLBuffer> buf_S; // (K/group_size)xN elements
+  id<MTLBuffer> buf_Z; // (K/group_size)xN elements
 };
 
 } // namespace torchao::kernels::mps::lowbit
diff --git a/torchao/experimental/ops/mps/register.mm b/torchao/experimental/ops/mps/register.mm
index dc2bd0f700..11e17b7481 100644
--- a/torchao/experimental/ops/mps/register.mm
+++ b/torchao/experimental/ops/mps/register.mm
@@ -21,7 +21,8 @@ Tensor linear_mps_kernel(
     const Tensor& A,
     const Tensor& B,
     int64_t group_size,
-    const Tensor& SZ) {
+    const Tensor& S,
+    const Tensor& Z) {
   auto M = A.size(0);
   auto N = B.size(0);
   auto K = A.size(1);
@@ -31,7 +32,9 @@ Tensor linear_mps_kernel(
   TORCH_CHECK(
       B.is_mps(), __func__, "B is on ", B.device(), " but expected on mps");
   TORCH_CHECK(
-      SZ.is_mps(), __func__, "SZ is on ", SZ.device(), " but expected on mps");
+      S.is_mps(), __func__, "S is on ", S.device(), " but expected on mps");
+  TORCH_CHECK(
+      Z.is_mps(), __func__, "Z is on ", Z.device(), " but expected on mps");
 
   TORCH_CHECK(
       A.dtype() == at::kBFloat16 || A.dtype() == at::kHalf ||
@@ -60,11 +63,18 @@ Tensor linear_mps_kernel(
       group_size);
 
   TORCH_CHECK(
-      SZ.dim() == 3 && SZ.size(1) == N && SZ.size(2) == 2,
+      S.dim() == 2 && S.size(1) == N,
+      __func__,
+      ": expect S to be 2d tensor with shape [:, ",
+      N,
+      "]");
+
+  TORCH_CHECK(
+      Z.dim() == 2 && Z.size(1) == N,
       __func__,
-      ": expect SZ to be 3d tensor with sizes [:, ",
+      ": expect Z to be 2d tensor with shape [:, ",
       N,
-      ", 2]");
+      "]");
 
   auto C = at::empty({M, N}, A.options());
 
@@ -72,7 +82,8 @@ Tensor linear_mps_kernel(
       getMTLBufferStorage(A),
       getMTLBufferStorage(B),
       group_size,
-      getMTLBufferStorage(SZ),
+      getMTLBufferStorage(S),
+      getMTLBufferStorage(Z),
       getMTLBufferStorage(C),
       M,
       K,
@@ -109,19 +120,19 @@ Tensor pack_weights_cpu_kernel(const Tensor& W) {
   m.def("_pack_weight_6bit(Tensor W) -> Tensor");
   m.def("_pack_weight_7bit(Tensor W) -> Tensor");
   m.def(
-      "_linear_fp_act_1bit_weight(Tensor A, Tensor B, int group_size, Tensor SZ) -> Tensor");
+      "_linear_fp_act_1bit_weight(Tensor A, Tensor B, int group_size, Tensor S, Tensor Z) -> Tensor");
   m.def(
-      "_linear_fp_act_2bit_weight(Tensor A, Tensor B, int group_size, Tensor SZ) -> Tensor");
+      "_linear_fp_act_2bit_weight(Tensor A, Tensor B, int group_size, Tensor S, Tensor Z) -> Tensor");
   m.def(
-      "_linear_fp_act_3bit_weight(Tensor A, Tensor B, int group_size, Tensor SZ) -> Tensor");
+      "_linear_fp_act_3bit_weight(Tensor A, Tensor B, int group_size, Tensor S, Tensor Z) -> Tensor");
   m.def(
-      "_linear_fp_act_4bit_weight(Tensor A, Tensor B, int group_size, Tensor SZ) -> Tensor");
+      "_linear_fp_act_4bit_weight(Tensor A, Tensor B, int group_size, Tensor S, Tensor Z) -> Tensor");
   m.def(
-      "_linear_fp_act_5bit_weight(Tensor A, Tensor B, int group_size, Tensor SZ) -> Tensor");
+      "_linear_fp_act_5bit_weight(Tensor A, Tensor B, int group_size, Tensor S, Tensor Z) -> Tensor");
   m.def(
-      "_linear_fp_act_6bit_weight(Tensor A, Tensor B, int group_size, Tensor SZ) -> Tensor");
+      "_linear_fp_act_6bit_weight(Tensor A, Tensor B, int group_size, Tensor S, Tensor Z) -> Tensor");
   m.def(
-      "_linear_fp_act_7bit_weight(Tensor A, Tensor B, int group_size, Tensor SZ) -> Tensor");
+      "_linear_fp_act_7bit_weight(Tensor A, Tensor B, int group_size, Tensor S, Tensor Z) -> Tensor");
 }
 
 TORCH_LIBRARY_IMPL(torchao, CPU, m) {
diff --git a/torchao/experimental/ops/mps/test/test_lowbit.py b/torchao/experimental/ops/mps/test/test_lowbit.py
index c663451cd9..f2d9d9c175 100644
--- a/torchao/experimental/ops/mps/test/test_lowbit.py
+++ b/torchao/experimental/ops/mps/test/test_lowbit.py
@@ -52,40 +52,37 @@ def _init_tensors(self, group_size, M, K, N, nbit, device="mps"):
         W = torch.randint(0, 2 * max_abs, (N, K), dtype=torch.uint8, device=device)
         S = torch.rand(ceil_K_group_size, N, dtype=torch.float32, device=device) + 0.01
         Z = torch.randint(
-            -max_abs,
-            max_abs,
+            0,
+            2 * max_abs,
             (ceil_K_group_size, N),
             dtype=torch.float32,
             device=device,
         )
-        SZ = torch.stack((S, Z), dim=2)
-        return A, W, SZ
+        return A, W, S, Z
 
-    def _reference_linear_lowbit_quant_weights(self, A, W, group_size, SZ, nbit):
+    def _reference_linear_lowbit_quant_weights(self, A, W, group_size, S, Z, nbit):
         # A is (M, K)
         # W is (N, K)
-        # SZ is (K // group_size, N, 2)
+        # S is (K // group_size, N)
+        # Z is (K // group_size, N)
         N = W.shape[0]
         K = W.shape[1]
-        max_abs = 1 << (nbit - 1)
-        W = W.to(torch.float32) - max_abs
-        scales = (
-            SZ[:, :, 0].t().unsqueeze(2).repeat(1, 1, group_size).view(N, -1)[:, :K]
-        )
-        zeros = SZ[:, :, 1].t().unsqueeze(2).repeat(1, 1, group_size).view(N, -1)[:, :K]
+        W = W.to(torch.float32)
+        scales = S.t().unsqueeze(2).repeat(1, 1, group_size).view(N, -1)[:, :K]
+        zeros = Z.t().unsqueeze(2).repeat(1, 1, group_size).view(N, -1)[:, :K]
         W = scales * W + zeros
         return torch.mm(A, W.t())
 
     @parameterized(cases)
     def test_linear(self, nbit, M=1, K=32, N=32, group_size=32):
         print(f"nbit: {nbit}, M: {M}, K: {K}, N: {N}, group_size: {group_size}")
-        A, W, SZ = self._init_tensors(group_size, M, K, N, nbit=nbit)
+        A, W, S, Z = self._init_tensors(group_size, M, K, N, nbit=nbit)
         packing_op = getattr(torch.ops.torchao, f"_pack_weight_{nbit}bit")
         linear_op = getattr(torch.ops.torchao, f"_linear_fp_act_{nbit}bit_weight")
         B = packing_op(W.cpu()).to("mps")
-        result = linear_op(A, B, group_size, SZ).cpu()
+        result = linear_op(A, B, group_size, S, Z).cpu()
         expected = self._reference_linear_lowbit_quant_weights(
-            A.cpu(), W.cpu(), group_size, SZ.cpu(), nbit=nbit
+            A.cpu(), W.cpu(), group_size, S.cpu(), Z.cpu(), nbit=nbit
         )
         torch.testing.assert_close(result, expected, rtol=0.001, atol=0.001)
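
Note for callers migrating to the split layout: the kernels now consume the zero point directly (rc += a_val * (scale * float(b_val) + zero)) instead of reconstructing it in-kernel as scalesAndZeros[..., 1] - scale * (1 << (nbit - 1)). A minimal Python sketch of converting an old fused scalesAndZeros tensor (#groups x N x 2) into the split pair these ops now expect; the helper name is illustrative and not part of this patch:

    import torch

    def split_scales_and_zeros(scales_and_zeros: torch.Tensor, nbit: int):
        # Hypothetical helper: scales_and_zeros has shape (#groups, N, 2),
        # with [..., 0] holding the scales and [..., 1] the fused zero term.
        scales = scales_and_zeros[:, :, 0].contiguous()
        # The old kernels subtracted scale * (1 << (nbit - 1)) internally;
        # the new kernels read the zero point as stored, so fold that shift
        # into the tensor once, up front.
        zero_shift = 1 << (nbit - 1)
        zeros = (scales_and_zeros[:, :, 1] - scales * zero_shift).contiguous()
        return scales, zeros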
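With the split tensors in hand, the updated ops take S and Z as separate arguments, each of shape (K // group_size, N) on the "mps" device, mirroring test_lowbit.py above (the 3-bit variant shown is just one of the seven registered ops):

    # A: (M, K) float activation tensor on "mps"
    # B: weights packed by torch.ops.torchao._pack_weight_3bit, moved to "mps"
    # S, Z: (K // group_size, N) float tensors on "mps"
    C = torch.ops.torchao._linear_fp_act_3bit_weight(A, B, group_size, S, Z)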