From 3475aed3beaf4b137b6e937bcd821693f3e28f23 Mon Sep 17 00:00:00 2001 From: c4lcut3c <97532828+c4lcut3c@users.noreply.github.com> Date: Thu, 17 Oct 2024 14:14:31 -0700 Subject: [PATCH] Use fewer instructions when unpacking uint6s. Differential Revision: D64548639 Pull Request resolved: https://github.com/pytorch/ao/pull/1109 --- .../kernels/cpu/aarch64/bitpacking/uint6.h | 42 ++++++++++++------- 1 file changed, 26 insertions(+), 16 deletions(-) diff --git a/torchao/experimental/kernels/cpu/aarch64/bitpacking/uint6.h b/torchao/experimental/kernels/cpu/aarch64/bitpacking/uint6.h index 87712f7bcf..d15094ddfb 100644 --- a/torchao/experimental/kernels/cpu/aarch64/bitpacking/uint6.h +++ b/torchao/experimental/kernels/cpu/aarch64/bitpacking/uint6.h @@ -114,15 +114,20 @@ TORCHAO_ALWAYS_INLINE inline void vec_unpack_32_uint6_values( uint8x8_t packed1 = vld1_u8(packed + 8); uint8x8_t packed2 = vld1_u8(packed + 16); - // unpacked[3] = ((packed[0] & 0b1100'0000u) >> 6) | - // ((packed[1] & 0b1100'0000u) >> 4) | - // ((packed[2] & 0b1100'0000u) >> 2); - const uint8x8_t high = vdup_n_u8(0b1100'0000u); uint8x8_t unpacked3; - unpacked3 = vorr_u8( - vshr_n_u8(vand_u8(packed0, high), 6), - vshr_n_u8(vand_u8(packed1, high), 4)); - unpacked3 = vorr_u8(unpacked3, vshr_n_u8(vand_u8(packed2, high), 2)); + // We want to extract bits 123456 and place them in unpacked3. + // Packed structure is: + // + // packed0: 56 | abcdef + // packed1: 34 | ghijkl + // packed2: 12 | mnopqr + // + // unpacked3 = 1234 ghij + unpacked3 = vsri_n_u8(packed2, packed1, 2); + // unpacked3 = 1234 56ab + unpacked3 = vsri_n_u8(unpacked3, packed0, 4); + // unpacked3 = 0012 3456 + unpacked3 = vshr_n_u8(unpacked3, 2); // unpacked[i] = packed[i] & 0b11'1111u; const uint8x8_t mask = vdup_n_u8(0b11'1111u); @@ -183,14 +188,19 @@ TORCHAO_ALWAYS_INLINE inline void vec_unpack_64_uint6_values( unpacked1 = vld1q_u8(packed + 16); unpacked2 = vld1q_u8(packed + 32); - // unpacked[3] = ((packed[0] & 0b1100'0000u) >> 6) | - // ((packed[1] & 0b1100'0000u) >> 4) | - // ((packed[2] & 0b1100'0000u) >> 2); - const uint8x16_t high = vdupq_n_u8(0b1100'0000u); - unpacked3 = vorrq_u8( - vshrq_n_u8(vandq_u8(unpacked0, high), 6), - vshrq_n_u8(vandq_u8(unpacked1, high), 4)); - unpacked3 = vorrq_u8(unpacked3, vshrq_n_u8(vandq_u8(unpacked2, high), 2)); + // We want to extract bits 123456 and place them in unpacked3. + // Packed structure is: + // + // packed0: 56 | abcdef + // packed1: 34 | ghijkl + // packed2: 12 | mnopqr + // + // unpacked3 = 1234 ghij + unpacked3 = vsriq_n_u8(unpacked2, unpacked1, 2); + // unpacked3 = 1234 56ab + unpacked3 = vsriq_n_u8(unpacked3, unpacked0, 4); + // unpacked3 = 0012 3456 + unpacked3 = vshrq_n_u8(unpacked3, 2); // unpacked[i] = packed[i] & 0b11'1111u; const uint8x16_t mask = vdupq_n_u8(0b11'1111u);