-
Notifications
You must be signed in to change notification settings - Fork 197
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Introduce lowbit quantized linear MPS kernels
Summary: The following is the directory structure of the submitted code under torchao ``` experimental/ ├── kernels/ │ └── mps/ │ ├── metal/ │ │ └── (metal shaders) │ ├── src/ │ │ └── (tensor agnostic mps kernel implementations) │ └── test/ │ │ └── (directly test mps kernel implementations) └── ops/ └── mps/ ├── register.mm ├── setup.py └── test/ └── (test torch custom ops) ``` Differential Revision: D63342895
- Loading branch information
1 parent
ceec750
commit 9ad395e
Showing
20 changed files
with
2,037 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,97 @@ | ||
#include <metal_stdlib> | ||
using namespace metal; | ||
|
||
// dispatchThreads:MTLSizeMake(N, M, 1) | ||
|
||
template<unsigned nbit, typename T, unsigned groupSize> | ||
kernel void divbit_mm( | ||
constant T * A [[buffer(0)]], | ||
constant uchar * B [[buffer(1)]], | ||
constant T * scalesAndZeros [[buffer(2)]], | ||
device T * outputData [[buffer(3)]], | ||
constant uint3 & sizes [[buffer(4)]], // M, K, N | ||
uint2 thread_index [[thread_position_in_grid]]) { | ||
const uint K = sizes.y; | ||
const uint N = sizes.z; | ||
const uint m = thread_index.y; // 0..M-1 | ||
const uint n = thread_index.x; // 0..N-1 | ||
const uint32_t k_block = (K + groupSize - 1) / groupSize; | ||
constant T *A_ptr = A + m * K; | ||
constant uchar *B_ptr = B; | ||
|
||
constexpr uint8_t zero_shift = 1 << (nbit - 1); | ||
constexpr uint8_t inv_nbit = 8 / nbit; | ||
constexpr uint8_t minimask = (1 << nbit) - 1; | ||
|
||
float rc = 0.0; | ||
uint k = 0; | ||
for (uint32_t kb = 0; kb < k_block ; kb ++) { | ||
const T scale = scalesAndZeros[(kb * N + n) * 2 + 0]; | ||
const T zero = scalesAndZeros[(kb * N + n) * 2 + 1] - scale * T(zero_shift); | ||
for(uint idx = 0; idx < groupSize && k < K; idx++, k++) { | ||
const auto a_val = float(A_ptr[k]); | ||
uint8_t b_val = B_ptr[(n * K + k) / inv_nbit]; | ||
uint8_t shift = nbit * (k % inv_nbit); | ||
uint8_t mask = minimask << shift; | ||
b_val = (b_val & mask) >> shift; | ||
rc += a_val * float(scale * T(b_val) + zero); | ||
} | ||
} | ||
outputData[m * N + n] = T(rc); | ||
} | ||
|
||
#define INSTANTIATE_DIVBIT_MM(NBIT, DTYPE, GSIZE) \ | ||
template \ | ||
[[host_name("int" #NBIT "pack_mm_" #GSIZE "_" #DTYPE)]] \ | ||
kernel void divbit_mm<NBIT, DTYPE, GSIZE>( \ | ||
constant DTYPE * A [[buffer(0)]], \ | ||
constant uchar * B [[buffer(1)]], \ | ||
constant DTYPE * scalesAndZeros [[buffer(2)]], \ | ||
device DTYPE * outputData [[buffer(3)]], \ | ||
constant uint3 & sizes [[buffer(4)]], \ | ||
uint2 thread_index [[thread_position_in_grid]]) | ||
|
||
INSTANTIATE_DIVBIT_MM(1, float, 32); | ||
INSTANTIATE_DIVBIT_MM(1, half, 32); | ||
INSTANTIATE_DIVBIT_MM(1, float, 64); | ||
INSTANTIATE_DIVBIT_MM(1, half, 64); | ||
INSTANTIATE_DIVBIT_MM(1, float, 128); | ||
INSTANTIATE_DIVBIT_MM(1, half, 128); | ||
INSTANTIATE_DIVBIT_MM(1, float, 256); | ||
INSTANTIATE_DIVBIT_MM(1, half, 256); | ||
#if __METAL_VERSION__ >= 310 | ||
INSTANTIATE_DIVBIT_MM(1, bfloat, 32); | ||
INSTANTIATE_DIVBIT_MM(1, bfloat, 64); | ||
INSTANTIATE_DIVBIT_MM(1, bfloat, 128); | ||
INSTANTIATE_DIVBIT_MM(1, bfloat, 256); | ||
#endif | ||
|
||
INSTANTIATE_DIVBIT_MM(2, float, 32); | ||
INSTANTIATE_DIVBIT_MM(2, half, 32); | ||
INSTANTIATE_DIVBIT_MM(2, float, 64); | ||
INSTANTIATE_DIVBIT_MM(2, half, 64); | ||
INSTANTIATE_DIVBIT_MM(2, float, 128); | ||
INSTANTIATE_DIVBIT_MM(2, half, 128); | ||
INSTANTIATE_DIVBIT_MM(2, float, 256); | ||
INSTANTIATE_DIVBIT_MM(2, half, 256); | ||
#if __METAL_VERSION__ >= 310 | ||
INSTANTIATE_DIVBIT_MM(2, bfloat, 32); | ||
INSTANTIATE_DIVBIT_MM(2, bfloat, 64); | ||
INSTANTIATE_DIVBIT_MM(2, bfloat, 128); | ||
INSTANTIATE_DIVBIT_MM(2, bfloat, 256); | ||
#endif | ||
|
||
INSTANTIATE_DIVBIT_MM(4, float, 32); | ||
INSTANTIATE_DIVBIT_MM(4, half, 32); | ||
INSTANTIATE_DIVBIT_MM(4, float, 64); | ||
INSTANTIATE_DIVBIT_MM(4, half, 64); | ||
INSTANTIATE_DIVBIT_MM(4, float, 128); | ||
INSTANTIATE_DIVBIT_MM(4, half, 128); | ||
INSTANTIATE_DIVBIT_MM(4, float, 256); | ||
INSTANTIATE_DIVBIT_MM(4, half, 256); | ||
#if __METAL_VERSION__ >= 310 | ||
INSTANTIATE_DIVBIT_MM(4, bfloat, 32); | ||
INSTANTIATE_DIVBIT_MM(4, bfloat, 64); | ||
INSTANTIATE_DIVBIT_MM(4, bfloat, 128); | ||
INSTANTIATE_DIVBIT_MM(4, bfloat, 256); | ||
#endif |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,63 @@ | ||
#include <metal_stdlib> | ||
using namespace metal; | ||
|
||
// dispatchThreads:MTLSizeMake(N, M, 1) | ||
|
||
template<typename T, unsigned groupSize> | ||
kernel void int1pack_mm( | ||
constant T * A [[buffer(0)]], | ||
constant uchar * B [[buffer(1)]], | ||
constant T * scalesAndZeros [[buffer(2)]], | ||
device T * outputData [[buffer(3)]], | ||
constant uint3 & sizes [[buffer(4)]], // M, K, N | ||
uint2 thread_index [[thread_position_in_grid]]) { | ||
const uint K = sizes.y; | ||
const uint N = sizes.z; | ||
const uint m = thread_index.y; // 0..M-1 | ||
const uint n = thread_index.x; // 0..N-1 | ||
const uint32_t k_block = (K + groupSize - 1) / groupSize; | ||
constant T *A_ptr = A + m * K; | ||
constant uchar *B_ptr = B; | ||
|
||
float rc = 0.0; | ||
uint k = 0; | ||
for (uint32_t kb = 0; kb < k_block ; kb ++) { | ||
const float scale = float(scalesAndZeros[(kb * N + n) * 2 + 0]); | ||
const float zero = float(scalesAndZeros[(kb * N + n) * 2 + 1]) - scale; | ||
for(uint idx = 0; idx < groupSize && k < K; idx++, k++) { | ||
const auto a_val = float(A_ptr[k]); | ||
uint8_t b_val = B_ptr[(n * K + k) / 8]; | ||
uint8_t shift = (k % 8); | ||
uint8_t mask = 1 << shift; | ||
b_val = (b_val & mask) >> shift; | ||
rc += a_val * (scale * float(b_val) + zero); | ||
} | ||
} | ||
outputData[m * N + n] = T(rc); | ||
} | ||
|
||
#define INSTANTIATE_INT1MM(DTYPE, GSIZE) \ | ||
template \ | ||
[[host_name("int1pack_mm_" #GSIZE "_" #DTYPE)]] \ | ||
kernel void int1pack_mm<DTYPE, GSIZE>( \ | ||
constant DTYPE * A [[buffer(0)]], \ | ||
constant uchar * B [[buffer(1)]], \ | ||
constant DTYPE * scalesAndZeros [[buffer(2)]], \ | ||
device DTYPE * outputData [[buffer(3)]], \ | ||
constant uint3 & sizes [[buffer(4)]], \ | ||
uint2 thread_index [[thread_position_in_grid]]) | ||
|
||
INSTANTIATE_INT1MM(float, 32); | ||
INSTANTIATE_INT1MM(half, 32); | ||
INSTANTIATE_INT1MM(float, 64); | ||
INSTANTIATE_INT1MM(half, 64); | ||
INSTANTIATE_INT1MM(float, 128); | ||
INSTANTIATE_INT1MM(half, 128); | ||
INSTANTIATE_INT1MM(float, 256); | ||
INSTANTIATE_INT1MM(half, 256); | ||
#if __METAL_VERSION__ >= 310 | ||
INSTANTIATE_INT1MM(bfloat, 32); | ||
INSTANTIATE_INT1MM(bfloat, 64); | ||
INSTANTIATE_INT1MM(bfloat, 128); | ||
INSTANTIATE_INT1MM(bfloat, 256); | ||
#endif |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,63 @@ | ||
#include <metal_stdlib> | ||
using namespace metal; | ||
|
||
// dispatchThreads:MTLSizeMake(N, M, 1) | ||
|
||
template<typename T, unsigned groupSize> | ||
kernel void int2pack_mm( | ||
constant T * A [[buffer(0)]], | ||
constant uchar * B [[buffer(1)]], | ||
constant T * scalesAndZeros [[buffer(2)]], | ||
device T * outputData [[buffer(3)]], | ||
constant uint3 & sizes [[buffer(4)]], // M, K, N | ||
uint2 thread_index [[thread_position_in_grid]]) { | ||
const uint K = sizes.y; | ||
const uint N = sizes.z; | ||
const uint m = thread_index.y; // 0..M-1 | ||
const uint n = thread_index.x; // 0..N-1 | ||
const uint32_t k_block = (K + groupSize - 1) / groupSize; | ||
constant T *A_ptr = A + m * K; | ||
constant uchar *B_ptr = B; | ||
|
||
float rc = 0.0; | ||
uint k = 0; | ||
for (uint32_t kb = 0; kb < k_block ; kb ++) { | ||
const float scale = scalesAndZeros[(kb * N + n) * 2 + 0]; | ||
const float zero = float(scalesAndZeros[(kb * N + n) * 2 + 1]) - scale * float(2); | ||
for(uint idx = 0; idx < groupSize && k < K; idx++, k++) { | ||
const auto a_val = float(A_ptr[k]); | ||
uint8_t b_val = B_ptr[(n * K + k) / 4]; | ||
uint8_t shift = 2 * (k % 4); | ||
uint8_t mask = 3 << shift; | ||
b_val = (b_val & mask) >> shift; | ||
rc += a_val * (scale * float(b_val) + zero); | ||
} | ||
} | ||
outputData[m * N + n] = T(rc); | ||
} | ||
|
||
#define INSTANTIATE_INT2MM(DTYPE, GSIZE) \ | ||
template \ | ||
[[host_name("int2pack_mm_" #GSIZE "_" #DTYPE)]] \ | ||
kernel void int2pack_mm<DTYPE, GSIZE>( \ | ||
constant DTYPE * A [[buffer(0)]], \ | ||
constant uchar * B [[buffer(1)]], \ | ||
constant DTYPE * scalesAndZeros [[buffer(2)]], \ | ||
device DTYPE * outputData [[buffer(3)]], \ | ||
constant uint3 & sizes [[buffer(4)]], \ | ||
uint2 thread_index [[thread_position_in_grid]]) | ||
|
||
INSTANTIATE_INT2MM(float, 32); | ||
INSTANTIATE_INT2MM(half, 32); | ||
INSTANTIATE_INT2MM(float, 64); | ||
INSTANTIATE_INT2MM(half, 64); | ||
INSTANTIATE_INT2MM(float, 128); | ||
INSTANTIATE_INT2MM(half, 128); | ||
INSTANTIATE_INT2MM(float, 256); | ||
INSTANTIATE_INT2MM(half, 256); | ||
#if __METAL_VERSION__ >= 310 | ||
INSTANTIATE_INT2MM(bfloat, 32); | ||
INSTANTIATE_INT2MM(bfloat, 64); | ||
INSTANTIATE_INT2MM(bfloat, 128); | ||
INSTANTIATE_INT2MM(bfloat, 256); | ||
#endif |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,88 @@ | ||
#include <metal_stdlib> | ||
using namespace metal; | ||
|
||
// dispatchThreads:MTLSizeMake(N, M, 1) | ||
|
||
template<typename T, unsigned groupSize> | ||
kernel void int3pack_mm( | ||
constant T * A [[buffer(0)]], | ||
constant uchar * B [[buffer(1)]], | ||
constant T * scalesAndZeros [[buffer(2)]], | ||
device T * outputData [[buffer(3)]], | ||
constant uint3 & sizes [[buffer(4)]], // M, K, N | ||
uint2 thread_index [[thread_position_in_grid]]) { | ||
const uint K = sizes.y; | ||
const uint N = sizes.z; | ||
const uint m = thread_index.y; // 0..M-1 | ||
const uint n = thread_index.x; // 0..N-1 | ||
const uint32_t k_block = (K + groupSize - 1) / groupSize; | ||
constant T *A_ptr = A + m * K; | ||
constant uchar *B_ptr = B + n * 3 * K / 8; | ||
|
||
float rc = 0.0; | ||
uint k = 0; | ||
for (uint32_t kb = 0; kb < k_block ; kb ++) { | ||
const float scale = float(scalesAndZeros[(kb * N + n) * 2 + 0]); | ||
const float zero = float(scalesAndZeros[(kb * N + n) * 2 + 1]) - scale * float(4); | ||
for(uint idx = 0; idx < groupSize && k < K; idx+=8, k+=8) { | ||
const auto a_val0 = float(A_ptr[k + 0]); | ||
const auto a_val1 = float(A_ptr[k + 1]); | ||
const auto a_val2 = float(A_ptr[k + 2]); | ||
const auto a_val3 = float(A_ptr[k + 3]); | ||
const auto a_val4 = float(A_ptr[k + 4]); | ||
const auto a_val5 = float(A_ptr[k + 5]); | ||
const auto a_val6 = float(A_ptr[k + 6]); | ||
const auto a_val7 = float(A_ptr[k + 7]); | ||
|
||
uchar b0 = B_ptr[3 * (k / 8) + 0]; | ||
uchar b1 = B_ptr[3 * (k / 8) + 1]; | ||
uchar b2 = B_ptr[3 * (k / 8) + 2]; | ||
|
||
uchar w_val0 = ((b0 & 1) << 2) | (b1 & 3); | ||
uchar w_val1 = ((b0 & 2) << 1) | ((b1 & 12) >> 2); | ||
uchar w_val2 = (b0 & 4) | ((b1 & 48) >> 4); | ||
uchar w_val3 = ((b0 & 8) >> 1) | ((b1 & 192) >> 6); | ||
|
||
uchar w_val4 = ((b0 & 16) >> 2) | (b2 & 3); | ||
uchar w_val5 = ((b0 & 32) >> 3) | ((b2 & 12) >> 2); | ||
uchar w_val6 = ((b0 & 64) >> 4) | ((b2 & 48) >> 4); | ||
uchar w_val7 = ((b0 & 128) >> 5) | ((b2 & 192) >> 6); | ||
|
||
rc += a_val0 * (scale * float(w_val0) + zero); | ||
rc += a_val1 * (scale * float(w_val1) + zero); | ||
rc += a_val2 * (scale * float(w_val2) + zero); | ||
rc += a_val3 * (scale * float(w_val3) + zero); | ||
rc += a_val4 * (scale * float(w_val4) + zero); | ||
rc += a_val5 * (scale * float(w_val5) + zero); | ||
rc += a_val6 * (scale * float(w_val6) + zero); | ||
rc += a_val7 * (scale * float(w_val7) + zero); | ||
} | ||
} | ||
outputData[m * N + n] = T(rc); | ||
} | ||
|
||
#define INSTANTIATE_INT3MM(DTYPE, GSIZE) \ | ||
template \ | ||
[[host_name("int3pack_mm_" #GSIZE "_" #DTYPE)]] \ | ||
kernel void int3pack_mm<DTYPE, GSIZE>( \ | ||
constant DTYPE * A [[buffer(0)]], \ | ||
constant uchar * B [[buffer(1)]], \ | ||
constant DTYPE * scalesAndZeros [[buffer(2)]], \ | ||
device DTYPE * outputData [[buffer(3)]], \ | ||
constant uint3 & sizes [[buffer(4)]], \ | ||
uint2 thread_index [[thread_position_in_grid]]) | ||
|
||
INSTANTIATE_INT3MM(float, 32); | ||
INSTANTIATE_INT3MM(half, 32); | ||
INSTANTIATE_INT3MM(float, 64); | ||
INSTANTIATE_INT3MM(half, 64); | ||
INSTANTIATE_INT3MM(float, 128); | ||
INSTANTIATE_INT3MM(half, 128); | ||
INSTANTIATE_INT3MM(float, 256); | ||
INSTANTIATE_INT3MM(half, 256); | ||
#if __METAL_VERSION__ >= 310 | ||
INSTANTIATE_INT3MM(bfloat, 32); | ||
INSTANTIATE_INT3MM(bfloat, 64); | ||
INSTANTIATE_INT3MM(bfloat, 128); | ||
INSTANTIATE_INT3MM(bfloat, 256); | ||
#endif |
Oops, something went wrong.