Skip to content

Commit

Permalink
Introduce 6-bit quantization for Llama in torchchat
Browse files Browse the repository at this point in the history
Differential Revision: D63792020

Pull Request resolved: #1007
  • Loading branch information
ramreddymounica authored Oct 3, 2024
1 parent 8945fb3 commit 9ce7ebb
Show file tree
Hide file tree
Showing 5 changed files with 467 additions and 12 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
#include <torchao/experimental/kernels/cpu/aarch64/bitpacking/uint2.h>
#include <torchao/experimental/kernels/cpu/aarch64/bitpacking/uint3.h>
#include <torchao/experimental/kernels/cpu/aarch64/bitpacking/uint4.h>
#include <torchao/experimental/kernels/cpu/aarch64/bitpacking/uint6.h>
#include <torchao/experimental/kernels/cpu/aarch64/tests/test_utils.h>
#include <cassert>

Expand All @@ -36,7 +37,7 @@ TORCHAO_ALWAYS_INLINE void pack_uint_odd_bit_values(
int variant) {
constexpr int bitsPerByte = 8;
assert(unpacked_size * nbit / bitsPerByte == packed_size);
assert(packed_size % variant == 0);
assert(unpacked_size % variant == 0);

uint8x16_t unpacked0;
uint8x16_t unpacked1;
Expand Down Expand Up @@ -103,7 +104,7 @@ TORCHAO_ALWAYS_INLINE void unpack_uint_odd_bit_values(
int variant) {
constexpr int bitsPerByte = 8;
assert(unpacked_size * nbit / bitsPerByte == packed_size);
assert(packed_size % variant == 0);
assert(unpacked_size % variant == 0);

uint8x16_t unpacked0;
uint8x16_t unpacked1;
Expand Down Expand Up @@ -222,7 +223,7 @@ void pack_uint_values<2>(
constexpr int nbit = 2;
constexpr int bitsPerByte = 8;
assert(unpacked_size * nbit / bitsPerByte == packed_size);
assert(packed_size % variant == 0);
assert(unpacked_size % variant == 0);

uint8x8_t unpacked0_8x8;
uint8x8_t unpacked1_8x8;
Expand Down Expand Up @@ -287,7 +288,7 @@ void unpack_uint_values<2>(
constexpr int nbit = 2;
constexpr int bitsPerByte = 8;
assert(unpacked_size * nbit / bitsPerByte == packed_size);
assert(packed_size % variant == 0);
assert(unpacked_size % variant == 0);

uint8x8_t unpacked0_8x8;
uint8x8_t unpacked1_8x8;
Expand Down Expand Up @@ -394,7 +395,7 @@ void pack_uint_values<4>(
constexpr int nbit = 4;
constexpr int bitsPerByte = 8;
assert(unpacked_size * nbit / bitsPerByte == packed_size);
assert(packed_size % variant == 0);
assert(unpacked_size % variant == 0);

uint8x16_t unpacked0;
uint8x16_t unpacked1;
Expand Down Expand Up @@ -435,7 +436,7 @@ void unpack_uint_values<4>(
constexpr int nbit = 4;
constexpr int bitsPerByte = 8;
assert(unpacked_size * nbit / bitsPerByte == packed_size);
assert(packed_size % variant == 0);
assert(unpacked_size % variant == 0);

uint8x16_t unpacked0;
uint8x16_t unpacked1;
Expand Down Expand Up @@ -507,6 +508,98 @@ void unpack_uint_values<5>(
variant);
}

// Benchmark utility to compare variants of uint6 packing
template <>
void pack_uint_values<6>(
uint8_t* packed,
uint8_t* unpacked,
int packed_size,
int unpacked_size,
int variant) {
constexpr int nbit = 6;
constexpr int bitsPerByte = 8;
assert(unpacked_size * nbit / bitsPerByte == packed_size);
assert(unpacked_size % variant == 0);

uint8x16_t unpacked0;
uint8x16_t unpacked1;
uint8x16_t unpacked2;
uint8x16_t unpacked3;

switch (variant) {
case 4:
for (int i = 0; i < unpacked_size; i += 4) {
torchao::bitpacking::internal::pack_4_uint6_values(
packed + ((i * nbit) / bitsPerByte), unpacked + i);
}
break;
case 32:
for (int i = 0; i < unpacked_size; i += 32) {
unpacked0 = vld1q_u8(unpacked + i);
unpacked1 = vld1q_u8(unpacked + 16 + i);
torchao::bitpacking::internal::vec_pack_32_uint6_values(
packed + ((i * nbit) / bitsPerByte), unpacked0, unpacked1);
}
break;
case 64:
for (int i = 0; i < unpacked_size; i += 64) {
unpacked0 = vld1q_u8(unpacked + i);
unpacked1 = vld1q_u8(unpacked + 16 + i);
unpacked2 = vld1q_u8(unpacked + 32 + i);
unpacked3 = vld1q_u8(unpacked + 48 + i);
torchao::bitpacking::internal::vec_pack_64_uint6_values(
packed + ((i * nbit) / bitsPerByte), unpacked0, unpacked1, unpacked2, unpacked3);
}
break;
}
}

// Benchmark utility to compare variants of uint6 unpacking
template <>
void unpack_uint_values<6>(
uint8_t* unpacked,
uint8_t* packed,
int unpacked_size,
int packed_size,
int variant) {
constexpr int nbit = 6;
constexpr int bitsPerByte = 8;
assert(unpacked_size * nbit / bitsPerByte == packed_size);
assert(unpacked_size % variant == 0);

uint8x16_t unpacked0;
uint8x16_t unpacked1;
uint8x16_t unpacked2;
uint8x16_t unpacked3;

switch (variant) {
case 4:
for (int i = 0; i < unpacked_size; i += 4) {
torchao::bitpacking::internal::unpack_4_uint6_values(
unpacked + i, packed + ((i * nbit) / bitsPerByte));
}
break;
case 32:
for (int i = 0; i < unpacked_size; i += 32) {
torchao::bitpacking::internal::vec_unpack_32_uint6_values(
unpacked0, unpacked1, packed + ((i * nbit) / bitsPerByte));
vst1q_u8(unpacked + i, unpacked0);
vst1q_u8(unpacked + 16 + i, unpacked1);
}
break;
case 64:
for (int i = 0; i < unpacked_size; i += 64) {
torchao::bitpacking::internal::vec_unpack_64_uint6_values(
unpacked0, unpacked1, unpacked2, unpacked3, packed + ((i * nbit) / bitsPerByte));
vst1q_u8(unpacked + i, unpacked0);
vst1q_u8(unpacked + 16 + i, unpacked1);
vst1q_u8(unpacked + 32 + i, unpacked2);
vst1q_u8(unpacked + 48 + i, unpacked3);
}
break;
}
}

} // namespace

template <int nbit>
Expand Down Expand Up @@ -557,6 +650,8 @@ BENCHMARK(benchmark_pack_uint_values<4>)->ArgsProduct({{128}, {2, 16, 32}});
BENCHMARK(benchmark_unpack_uint_values<4>)->ArgsProduct({{128}, {2, 16, 32}});
BENCHMARK(benchmark_pack_uint_values<5>)->ArgsProduct({{128}, {8, 64, 128}});
BENCHMARK(benchmark_unpack_uint_values<5>)->ArgsProduct({{128}, {8, 64, 128}});
BENCHMARK(benchmark_pack_uint_values<6>)->ArgsProduct({{128}, {8, 64, 128}});
BENCHMARK(benchmark_unpack_uint_values<6>)->ArgsProduct({{128}, {4, 32, 64}});

// Run the benchmark
BENCHMARK_MAIN();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -238,6 +238,8 @@ BENCHMARK_CHANNELWISE_8BIT_ACTIVATION_GROUPWISE_LOWBIT_WEIGHT_1x1x32_F32_NEONDOT
4);
BENCHMARK_CHANNELWISE_8BIT_ACTIVATION_GROUPWISE_LOWBIT_WEIGHT_1x1x32_F32_NEONDOT(
5);
BENCHMARK_CHANNELWISE_8BIT_ACTIVATION_GROUPWISE_LOWBIT_WEIGHT_1x1x32_F32_NEONDOT(
6);
BENCHMARK_CHANNELWISE_8BIT_ACTIVATION_GROUPWISE_LOWBIT_WEIGHT_1x4x16_F32_NEONDOT(
1);
BENCHMARK_CHANNELWISE_8BIT_ACTIVATION_GROUPWISE_LOWBIT_WEIGHT_1x4x16_F32_NEONDOT(
Expand All @@ -248,6 +250,8 @@ BENCHMARK_CHANNELWISE_8BIT_ACTIVATION_GROUPWISE_LOWBIT_WEIGHT_1x4x16_F32_NEONDOT
4);
BENCHMARK_CHANNELWISE_8BIT_ACTIVATION_GROUPWISE_LOWBIT_WEIGHT_1x4x16_F32_NEONDOT(
5);
BENCHMARK_CHANNELWISE_8BIT_ACTIVATION_GROUPWISE_LOWBIT_WEIGHT_1x4x16_F32_NEONDOT(
6);
BENCHMARK_CHANNELWISE_8BIT_ACTIVATION_GROUPWISE_LOWBIT_WEIGHT_1x4x16_F32_NEONDOT(
1);
BENCHMARK_CHANNELWISE_8BIT_ACTIVATION_GROUPWISE_LOWBIT_WEIGHT_1x8x16_F32_NEONDOT(
Expand All @@ -258,6 +262,8 @@ BENCHMARK_CHANNELWISE_8BIT_ACTIVATION_GROUPWISE_LOWBIT_WEIGHT_1x8x16_F32_NEONDOT
4);
BENCHMARK_CHANNELWISE_8BIT_ACTIVATION_GROUPWISE_LOWBIT_WEIGHT_1x8x16_F32_NEONDOT(
5);
BENCHMARK_CHANNELWISE_8BIT_ACTIVATION_GROUPWISE_LOWBIT_WEIGHT_1x8x16_F32_NEONDOT(
6);

// Run the benchmark
BENCHMARK_MAIN();
41 changes: 35 additions & 6 deletions torchao/experimental/kernels/cpu/aarch64/bitpacking/bitpack.h
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
#include <torchao/experimental/kernels/cpu/aarch64/bitpacking/uint3.h>
#include <torchao/experimental/kernels/cpu/aarch64/bitpacking/uint4.h>
#include <torchao/experimental/kernels/cpu/aarch64/bitpacking/uint5.h>
#include <torchao/experimental/kernels/cpu/aarch64/bitpacking/uint6.h>
#include <cassert>

namespace torchao {
Expand Down Expand Up @@ -80,7 +81,7 @@ TORCHAO_ALWAYS_INLINE inline void vec_pack_32_lowbit_values(

// Currently supported values
static_assert(nbit >= 1);
static_assert(nbit <= 5);
static_assert(nbit <= 6);

// Shift unpacked values to nonnegative range
int8x16_t shift = vdupq_n_s8(1 << (nbit - 1));
Expand Down Expand Up @@ -138,6 +139,10 @@ TORCHAO_ALWAYS_INLINE inline void vec_pack_32_lowbit_values(
torchao::bitpacking::internal::pack_8_uint5_values(
packed + 15, buffer5 + 24);
break;
case 6:
torchao::bitpacking::internal::vec_pack_32_uint6_values(
packed, shifted0, shifted1);
break;
default:
assert(false);
}
Expand All @@ -153,7 +158,7 @@ TORCHAO_ALWAYS_INLINE inline void vec_unpack_32_lowbit_values(

// Currently supported values
static_assert(nbit >= 1);
static_assert(nbit <= 5);
static_assert(nbit <= 6);

uint8x16_t shifted0;
uint8x16_t shifted1;
Expand Down Expand Up @@ -208,6 +213,10 @@ TORCHAO_ALWAYS_INLINE inline void vec_unpack_32_lowbit_values(
shifted0 = vld1q_u8(buffer5);
shifted1 = vld1q_u8(buffer5 + 16);
break;
case 6:
torchao::bitpacking::internal::vec_unpack_32_uint6_values(
shifted0, shifted1, packed);
break;
default:
assert(false);
}
Expand All @@ -230,7 +239,7 @@ TORCHAO_ALWAYS_INLINE inline void vec_pack_64_lowbit_values(

// Currently supported values
static_assert(nbit >= 1);
static_assert(nbit <= 5);
static_assert(nbit <= 6);

// Shift unpacked values to nonnegative range
int8x16_t shift = vdupq_n_s8(1 << (nbit - 1));
Expand Down Expand Up @@ -262,6 +271,10 @@ TORCHAO_ALWAYS_INLINE inline void vec_pack_64_lowbit_values(
torchao::bitpacking::internal::vec_pack_64_uint5_values(
packed, shifted0, shifted1, shifted2, shifted3);
break;
case 6:
torchao::bitpacking::internal::vec_pack_64_uint6_values(
packed, shifted0, shifted1, shifted2, shifted3);
break;
default:
assert(false);
}
Expand All @@ -279,7 +292,7 @@ TORCHAO_ALWAYS_INLINE inline void vec_unpack_64_lowbit_values(

// Currently supported values
static_assert(nbit >= 1);
static_assert(nbit <= 5);
static_assert(nbit <= 6);

uint8x16_t shifted0;
uint8x16_t shifted1;
Expand Down Expand Up @@ -309,6 +322,10 @@ TORCHAO_ALWAYS_INLINE inline void vec_unpack_64_lowbit_values(
torchao::bitpacking::internal::vec_unpack_64_uint5_values(
shifted0, shifted1, shifted2, shifted3, packed);
break;
case 6:
torchao::bitpacking::internal::vec_unpack_64_uint6_values(
shifted0, shifted1, shifted2, shifted3, packed);
break;
default:
assert(false);
}
Expand Down Expand Up @@ -337,7 +354,7 @@ TORCHAO_ALWAYS_INLINE inline void vec_pack_128_lowbit_values(

// Currently supported values
static_assert(nbit >= 1);
static_assert(nbit <= 5);
static_assert(nbit <= 6);

// Shift unpacked values to nonnegative range
int8x16_t shift = vdupq_n_s8(1 << (nbit - 1));
Expand Down Expand Up @@ -403,6 +420,12 @@ TORCHAO_ALWAYS_INLINE inline void vec_pack_128_lowbit_values(
shifted6,
shifted7);
break;
case 6:
torchao::bitpacking::internal::vec_pack_64_uint6_values(
packed, shifted0, shifted1, shifted2, shifted3);
torchao::bitpacking::internal::vec_pack_64_uint6_values(
packed + 48, shifted4, shifted5, shifted6, shifted7);
break;
default:
assert(false);
}
Expand All @@ -424,7 +447,7 @@ TORCHAO_ALWAYS_INLINE inline void vec_unpack_128_lowbit_values(

// Currently supported values
static_assert(nbit >= 1);
static_assert(nbit <= 5);
static_assert(nbit <= 6);

uint8x16_t shifted0;
uint8x16_t shifted1;
Expand Down Expand Up @@ -488,6 +511,12 @@ TORCHAO_ALWAYS_INLINE inline void vec_unpack_128_lowbit_values(
shifted7,
packed);
break;
case 6:
torchao::bitpacking::internal::vec_unpack_64_uint6_values(
shifted0, shifted1, shifted2, shifted3, packed);
torchao::bitpacking::internal::vec_unpack_64_uint6_values(
shifted4, shifted5, shifted6, shifted7, packed + 48);
break;
default:
assert(false);
}
Expand Down
Loading

0 comments on commit 9ce7ebb

Please sign in to comment.