diff --git a/torchao/experimental/kernels/cpu/aarch64/benchmarks/benchmark_bitpacking.cpp b/torchao/experimental/kernels/cpu/aarch64/benchmarks/benchmark_bitpacking.cpp index 178215595..450c61287 100644 --- a/torchao/experimental/kernels/cpu/aarch64/benchmarks/benchmark_bitpacking.cpp +++ b/torchao/experimental/kernels/cpu/aarch64/benchmarks/benchmark_bitpacking.cpp @@ -19,14 +19,21 @@ namespace { -// Benchmark utility to compare variants of uint1 packing -void pack_uint1_values( +// Benchmark utility to compare variants of odd bit packing +template < + typename pack_8_values_fn_type, + typename vec_pack_64_values_fn_type, + typename vec_pack_128_values_fn_type> +TORCHAO_ALWAYS_INLINE void pack_uint_odd_bit_values( + pack_8_values_fn_type pack_8_values_func, + vec_pack_64_values_fn_type vec_pack_64_values_func, + vec_pack_128_values_fn_type vec_pack_128_values_func, + const int nbit, uint8_t* packed, uint8_t* unpacked, int packed_size, int unpacked_size, int variant) { - constexpr int nbit = 1; constexpr int bitsPerByte = 8; assert(unpacked_size * nbit / bitsPerByte == packed_size); assert(packed_size % variant == 0); @@ -43,15 +50,14 @@ void pack_uint1_values( switch (variant) { case 8: for (int i = 0; i < unpacked_size; i += 8) { - torchao::bitpacking::internal::pack_8_uint1_values( - packed + ((i * nbit) / bitsPerByte), unpacked + i); + pack_8_values_func(packed + ((i * nbit) / bitsPerByte), unpacked + i); } break; case 64: for (int i = 0; i < unpacked_size; i += 64) { torchao::bitpacking::internal::vec_load_64_uint8_values( unpacked0, unpacked1, unpacked2, unpacked3, unpacked + i); - torchao::bitpacking::internal::vec_pack_64_uint1_values( + vec_pack_64_values_func( packed + ((i * nbit) / bitsPerByte), unpacked0, unpacked1, @@ -65,7 +71,7 @@ void pack_uint1_values( unpacked0, unpacked1, unpacked2, unpacked3, unpacked + i); torchao::bitpacking::internal::vec_load_64_uint8_values( unpacked4, unpacked5, unpacked6, unpacked7, unpacked + i + 64); - 
torchao::bitpacking::internal::vec_pack_128_uint1_values( + vec_pack_128_values_func( packed + ((i * nbit) / bitsPerByte), unpacked0, unpacked1, @@ -80,14 +86,21 @@ void pack_uint1_values( } } -// Benchmark utility to compare variants of uint1 packing -void unpack_uint1_values( +// Benchmark utility to compare variants of odd bit unpacking +template < + typename unpack_8_values_fn_type, + typename vec_unpack_64_values_fn_type, + typename vec_unpack_128_values_fn_type> +TORCHAO_ALWAYS_INLINE void unpack_uint_odd_bit_values( + unpack_8_values_fn_type unpack_8_values_func, + vec_unpack_64_values_fn_type vec_unpack_64_values_func, + vec_unpack_128_values_fn_type vec_unpack_128_values_func, + const int nbit, uint8_t* unpacked, uint8_t* packed, int unpacked_size, int packed_size, int variant) { - constexpr int nbit = 1; constexpr int bitsPerByte = 8; assert(unpacked_size * nbit / bitsPerByte == packed_size); assert(packed_size % variant == 0); @@ -104,13 +117,12 @@ void unpack_uint1_values( switch (variant) { case 8: for (int i = 0; i < unpacked_size; i += 8) { - torchao::bitpacking::internal::unpack_8_uint1_values( - unpacked + i, packed + ((i * nbit) / bitsPerByte)); + unpack_8_values_func(unpacked + i, packed + ((i * nbit) / bitsPerByte)); } break; case 64: for (int i = 0; i < unpacked_size; i += 64) { - torchao::bitpacking::internal::vec_unpack_64_uint1_values( + vec_unpack_64_values_func( unpacked0, unpacked1, unpacked2, @@ -122,7 +134,7 @@ void unpack_uint1_values( break; case 128: for (int i = 0; i < unpacked_size; i += 128) { - torchao::bitpacking::internal::vec_unpack_128_uint1_values( + vec_unpack_128_values_func( unpacked0, unpacked1, unpacked2, @@ -141,8 +153,67 @@ void unpack_uint1_values( } } +template <int nbit> +void pack_uint_values( + uint8_t* packed, + uint8_t* unpacked, + int packed_size, + int unpacked_size, + int variant); + +template <int nbit> +void unpack_uint_values( + uint8_t* unpacked, + uint8_t* packed, + int unpacked_size, + int packed_size, + int variant); + +// 
Benchmark utility to compare variants of uint1 packing +template <> +void pack_uint_values<1>( + uint8_t* packed, + uint8_t* unpacked, + int packed_size, + int unpacked_size, + int variant) { + constexpr int nbit = 1; + pack_uint_odd_bit_values( + torchao::bitpacking::internal::pack_8_uint1_values, + torchao::bitpacking::internal::vec_pack_64_uint1_values, + torchao::bitpacking::internal::vec_pack_128_uint1_values, + nbit, + packed, + unpacked, + packed_size, + unpacked_size, + variant); +} + +// Benchmark utility to compare variants of uint1 unpacking +template <> +void unpack_uint_values<1>( + uint8_t* unpacked, + uint8_t* packed, + int unpacked_size, + int packed_size, + int variant) { + constexpr int nbit = 1; + unpack_uint_odd_bit_values( + torchao::bitpacking::internal::unpack_8_uint1_values, + torchao::bitpacking::internal::vec_unpack_64_uint1_values, + torchao::bitpacking::internal::vec_unpack_128_uint1_values, + nbit, + unpacked, + packed, + unpacked_size, + packed_size, + variant); +} + // Benchmark utility to compare variants of uint2 packing -void pack_uint2_values( +template <> +void pack_uint_values<2>( uint8_t* packed, uint8_t* unpacked, int packed_size, @@ -206,7 +277,8 @@ void pack_uint2_values( } -// Benchmark utility to compare variants of uint2 packing -void unpack_uint2_values( +// Benchmark utility to compare variants of uint2 unpacking +template <> +void unpack_uint_values<2>( uint8_t* unpacked, uint8_t* packed, int unpacked_size, @@ -270,129 +342,50 @@ void unpack_uint2_values( } // Benchmark utility to compare variants of uint3 packing -void pack_uint3_values( +template <> +void pack_uint_values<3>( uint8_t* packed, uint8_t* unpacked, int packed_size, int unpacked_size, int variant) { constexpr int nbit = 3; - constexpr int bitsPerByte = 8; - assert(unpacked_size * nbit / bitsPerByte == packed_size); - assert(packed_size % variant == 0); - - uint8x16_t unpacked0; - uint8x16_t unpacked1; - uint8x16_t unpacked2; - uint8x16_t unpacked3; - uint8x16_t unpacked4; - uint8x16_t unpacked5; - uint8x16_t 
unpacked6; - uint8x16_t unpacked7; - - switch (variant) { - case 8: - for (int i = 0; i < unpacked_size; i += 8) { - torchao::bitpacking::internal::pack_8_uint3_values( - packed + ((i * nbit) / bitsPerByte), unpacked + i); - } - break; - case 64: - for (int i = 0; i < unpacked_size; i += 64) { - torchao::bitpacking::internal::vec_load_64_uint8_values( - unpacked0, unpacked1, unpacked2, unpacked3, unpacked + i); - torchao::bitpacking::internal::vec_pack_64_uint3_values( - packed + ((i * nbit) / bitsPerByte), - unpacked0, - unpacked1, - unpacked2, - unpacked3); - } - break; - case 128: - for (int i = 0; i < unpacked_size; i += 128) { - torchao::bitpacking::internal::vec_load_64_uint8_values( - unpacked0, unpacked1, unpacked2, unpacked3, unpacked + i); - torchao::bitpacking::internal::vec_load_64_uint8_values( - unpacked4, unpacked5, unpacked6, unpacked7, unpacked + i + 64); - torchao::bitpacking::internal::vec_pack_128_uint3_values( - packed + ((i * nbit) / bitsPerByte), - unpacked0, - unpacked1, - unpacked2, - unpacked3, - unpacked4, - unpacked5, - unpacked6, - unpacked7); - } - break; - } + pack_uint_odd_bit_values( + torchao::bitpacking::internal::pack_8_uint3_values, + torchao::bitpacking::internal::vec_pack_64_uint3_values, + torchao::bitpacking::internal::vec_pack_128_uint3_values, + nbit, + packed, + unpacked, + packed_size, + unpacked_size, + variant); } // Benchmark utility to compare variants of uint3 unpacking -void unpack_uint3_values( +template <> +void unpack_uint_values<3>( uint8_t* unpacked, uint8_t* packed, int unpacked_size, int packed_size, int variant) { constexpr int nbit = 3; - constexpr int bitsPerByte = 8; - assert(unpacked_size * nbit / bitsPerByte == packed_size); - assert(packed_size % variant == 0); - - uint8x16_t unpacked0; - uint8x16_t unpacked1; - uint8x16_t unpacked2; - uint8x16_t unpacked3; - uint8x16_t unpacked4; - uint8x16_t unpacked5; - uint8x16_t unpacked6; - uint8x16_t unpacked7; - - switch (variant) { - case 8: - for (int i = 0; 
i < unpacked_size; i += 8) { - torchao::bitpacking::internal::unpack_8_uint3_values( - unpacked + i, packed + ((i * nbit) / bitsPerByte)); - } - break; - case 64: - for (int i = 0; i < unpacked_size; i += 64) { - torchao::bitpacking::internal::vec_unpack_64_uint3_values( - unpacked0, - unpacked1, - unpacked2, - unpacked3, - packed + ((i * nbit) / bitsPerByte)); - torchao::bitpacking::internal::vec_store_64_uint8_values( - unpacked + i, unpacked0, unpacked1, unpacked2, unpacked3); - } - break; - case 128: - for (int i = 0; i < unpacked_size; i += 128) { - torchao::bitpacking::internal::vec_unpack_128_uint3_values( - unpacked0, - unpacked1, - unpacked2, - unpacked3, - unpacked4, - unpacked5, - unpacked6, - unpacked7, - packed + ((i * nbit) / bitsPerByte)); - torchao::bitpacking::internal::vec_store_64_uint8_values( - unpacked + i, unpacked0, unpacked1, unpacked2, unpacked3); - torchao::bitpacking::internal::vec_store_64_uint8_values( - unpacked + i + 64, unpacked4, unpacked5, unpacked6, unpacked7); - } - break; - } + unpack_uint_odd_bit_values( + torchao::bitpacking::internal::unpack_8_uint3_values, + torchao::bitpacking::internal::vec_unpack_64_uint3_values, + torchao::bitpacking::internal::vec_unpack_128_uint3_values, + nbit, + unpacked, + packed, + unpacked_size, + packed_size, + variant); } // Benchmark utility to compare variants of uint4 packing -void pack_uint4_values( +template <> +void pack_uint_values<4>( uint8_t* packed, uint8_t* unpacked, int packed_size, @@ -432,7 +425,8 @@ void pack_uint4_values( } // Benchmark utility to compare variants of uint4 unpacking -void unpack_uint4_values( +template <> +void unpack_uint_values<4>( uint8_t* unpacked, uint8_t* packed, int unpacked_size, @@ -472,285 +466,53 @@ void unpack_uint4_values( } // Benchmark utility to compare variants of uint5 packing -void pack_uint5_values( +template <> +void pack_uint_values<5>( uint8_t* packed, uint8_t* unpacked, int packed_size, int unpacked_size, int variant) { constexpr int nbit 
= 5; - constexpr int bitsPerByte = 8; - assert(unpacked_size * nbit / bitsPerByte == packed_size); - assert(packed_size % variant == 0); - - uint8x16_t unpacked0; - uint8x16_t unpacked1; - uint8x16_t unpacked2; - uint8x16_t unpacked3; - uint8x16_t unpacked4; - uint8x16_t unpacked5; - uint8x16_t unpacked6; - uint8x16_t unpacked7; - - switch (variant) { - case 8: - for (int i = 0; i < unpacked_size; i += 8) { - torchao::bitpacking::internal::pack_8_uint5_values( - packed + ((i * nbit) / bitsPerByte), unpacked + i); - } - break; - case 64: - for (int i = 0; i < unpacked_size; i += 64) { - torchao::bitpacking::internal::vec_load_64_uint8_values( - unpacked0, unpacked1, unpacked2, unpacked3, unpacked + i); - torchao::bitpacking::internal::vec_pack_64_uint5_values( - packed + ((i * nbit) / bitsPerByte), - unpacked0, - unpacked1, - unpacked2, - unpacked3); - } - break; - case 128: - for (int i = 0; i < unpacked_size; i += 128) { - torchao::bitpacking::internal::vec_load_64_uint8_values( - unpacked0, unpacked1, unpacked2, unpacked3, unpacked + i); - torchao::bitpacking::internal::vec_load_64_uint8_values( - unpacked4, unpacked5, unpacked6, unpacked7, unpacked + i + 64); - torchao::bitpacking::internal::vec_pack_128_uint5_values( - packed + ((i * nbit) / bitsPerByte), - unpacked0, - unpacked1, - unpacked2, - unpacked3, - unpacked4, - unpacked5, - unpacked6, - unpacked7); - } - break; - } + pack_uint_odd_bit_values( + torchao::bitpacking::internal::pack_8_uint5_values, + torchao::bitpacking::internal::vec_pack_64_uint5_values, + torchao::bitpacking::internal::vec_pack_128_uint5_values, + nbit, + packed, + unpacked, + packed_size, + unpacked_size, + variant); } // Benchmark utility to compare variants of uint5 unpacking -void unpack_uint5_values( +template <> +void unpack_uint_values<5>( uint8_t* unpacked, uint8_t* packed, int unpacked_size, int packed_size, int variant) { constexpr int nbit = 5; - constexpr int bitsPerByte = 8; - assert(unpacked_size * nbit / bitsPerByte == 
packed_size); - assert(packed_size % variant == 0); - - uint8x16_t unpacked0; - uint8x16_t unpacked1; - uint8x16_t unpacked2; - uint8x16_t unpacked3; - uint8x16_t unpacked4; - uint8x16_t unpacked5; - uint8x16_t unpacked6; - uint8x16_t unpacked7; - - switch (variant) { - case 8: - for (int i = 0; i < unpacked_size; i += 8) { - torchao::bitpacking::internal::unpack_8_uint5_values( - unpacked + i, packed + ((i * nbit) / bitsPerByte)); - } - break; - case 64: - for (int i = 0; i < unpacked_size; i += 64) { - torchao::bitpacking::internal::vec_unpack_64_uint5_values( - unpacked0, - unpacked1, - unpacked2, - unpacked3, - packed + ((i * nbit) / bitsPerByte)); - torchao::bitpacking::internal::vec_store_64_uint8_values( - unpacked + i, unpacked0, unpacked1, unpacked2, unpacked3); - } - break; - case 128: - for (int i = 0; i < unpacked_size; i += 128) { - torchao::bitpacking::internal::vec_unpack_128_uint5_values( - unpacked0, - unpacked1, - unpacked2, - unpacked3, - unpacked4, - unpacked5, - unpacked6, - unpacked7, - packed + ((i * nbit) / bitsPerByte)); - torchao::bitpacking::internal::vec_store_64_uint8_values( - unpacked + i, unpacked0, unpacked1, unpacked2, unpacked3); - torchao::bitpacking::internal::vec_store_64_uint8_values( - unpacked + i + 64, unpacked4, unpacked5, unpacked6, unpacked7); - } - break; - } + unpack_uint_odd_bit_values( + torchao::bitpacking::internal::unpack_8_uint5_values, + torchao::bitpacking::internal::vec_unpack_64_uint5_values, + torchao::bitpacking::internal::vec_unpack_128_uint5_values, + nbit, + unpacked, + packed, + unpacked_size, + packed_size, + variant); } } // namespace -static void benchmark_pack_uint1_values(benchmark::State& state) { - int unpacked_size = state.range(0); - int variant = state.range(1); - int nbit = 1; - - assert(unpacked_size % 8 == 0); - int packed_size = (unpacked_size / 8) * nbit; - - auto packed = std::vector<uint8_t>(packed_size, 0); - auto unpacked = torchao::get_random_lowbit_vector(unpacked_size, nbit); - - for (auto 
_ : state) { - pack_uint1_values( - packed.data(), unpacked.data(), packed_size, unpacked_size, variant); - } -} - -static void benchmark_unpack_uint1_values(benchmark::State& state) { - int unpacked_size = state.range(0); - int variant = state.range(1); - int nbit = 1; - - assert(unpacked_size % 8 == 0); - int packed_size = (unpacked_size / 8) * nbit; - - auto packed = torchao::get_random_lowbit_vector(packed_size, 8); - auto unpacked = std::vector<uint8_t>(unpacked_size, 0); - - for (auto _ : state) { - unpack_uint1_values( - unpacked.data(), - packed.data(), - unpacked.size(), - packed.size(), - variant); - } -} - -static void benchmark_pack_uint2_values(benchmark::State& state) { - int unpacked_size = state.range(0); - int variant = state.range(1); - int nbit = 2; - - assert(unpacked_size % 8 == 0); - int packed_size = (unpacked_size / 8) * nbit; - - auto packed = std::vector<uint8_t>(packed_size, 0); - auto unpacked = torchao::get_random_lowbit_vector(unpacked_size, nbit); - - for (auto _ : state) { - pack_uint2_values( - packed.data(), unpacked.data(), packed_size, unpacked_size, variant); - } -} - -static void benchmark_unpack_uint2_values(benchmark::State& state) { - int unpacked_size = state.range(0); - int variant = state.range(1); - int nbit = 2; - - assert(unpacked_size % 8 == 0); - int packed_size = (unpacked_size / 8) * nbit; - - auto packed = torchao::get_random_lowbit_vector(packed_size, 8); - auto unpacked = std::vector<uint8_t>(unpacked_size, 0); - - for (auto _ : state) { - unpack_uint2_values( - unpacked.data(), - packed.data(), - unpacked.size(), - packed.size(), - variant); - } -} - -static void benchmark_pack_uint3_values(benchmark::State& state) { - int unpacked_size = state.range(0); - int variant = state.range(1); - int nbit = 3; - - assert(unpacked_size % 8 == 0); - int packed_size = (unpacked_size / 8) * nbit; - - auto packed = std::vector<uint8_t>(packed_size, 0); - auto unpacked = torchao::get_random_lowbit_vector(unpacked_size, nbit); - - for (auto _ : state) { - 
pack_uint3_values( - packed.data(), unpacked.data(), packed_size, unpacked_size, variant); - } -} - -static void benchmark_unpack_uint3_values(benchmark::State& state) { - int unpacked_size = state.range(0); - int variant = state.range(1); - int nbit = 3; - - assert(unpacked_size % 8 == 0); - int packed_size = (unpacked_size / 8) * nbit; - - auto packed = torchao::get_random_lowbit_vector(packed_size, 8); - auto unpacked = std::vector<uint8_t>(unpacked_size, 0); - - for (auto _ : state) { - unpack_uint3_values( - unpacked.data(), - packed.data(), - unpacked.size(), - packed.size(), - variant); - } -} - -static void benchmark_pack_uint4_values(benchmark::State& state) { - int unpacked_size = state.range(0); - int variant = state.range(1); - int nbit = 4; - - assert(unpacked_size % 8 == 0); - int packed_size = (unpacked_size / 8) * nbit; - - auto packed = std::vector<uint8_t>(packed_size, 0); - auto unpacked = torchao::get_random_lowbit_vector(unpacked_size, nbit); - - for (auto _ : state) { - pack_uint4_values( - packed.data(), unpacked.data(), packed_size, unpacked_size, variant); - } -} - -static void benchmark_unpack_uint4_values(benchmark::State& state) { - int unpacked_size = state.range(0); - int variant = state.range(1); - int nbit = 4; - - assert(unpacked_size % 8 == 0); - int packed_size = (unpacked_size / 8) * nbit; - - auto packed = torchao::get_random_lowbit_vector(packed_size, 8); - auto unpacked = std::vector<uint8_t>(unpacked_size, 0); - - for (auto _ : state) { - unpack_uint4_values( - unpacked.data(), - packed.data(), - unpacked.size(), - packed.size(), - variant); - } -} - -static void benchmark_pack_uint5_values(benchmark::State& state) { +template <int nbit> +static void benchmark_pack_uint_values(benchmark::State& state) { int unpacked_size = state.range(0); int variant = state.range(1); - int nbit = 5; assert(unpacked_size % 8 == 0); int packed_size = (unpacked_size / 8) * nbit; @@ -759,15 +521,15 @@ static void benchmark_pack_uint5_values(benchmark::State& state) { auto unpacked = 
torchao::get_random_lowbit_vector(unpacked_size, nbit); for (auto _ : state) { - pack_uint5_values( + pack_uint_values<nbit>( packed.data(), unpacked.data(), packed_size, unpacked_size, variant); } } -static void benchmark_unpack_uint5_values(benchmark::State& state) { +template <int nbit> +static void benchmark_unpack_uint_values(benchmark::State& state) { int unpacked_size = state.range(0); int variant = state.range(1); - int nbit = 5; assert(unpacked_size % 8 == 0); int packed_size = (unpacked_size / 8) * nbit; @@ -776,7 +538,7 @@ static void benchmark_unpack_uint5_values(benchmark::State& state) { auto unpacked = std::vector<uint8_t>(unpacked_size, 0); for (auto _ : state) { - unpack_uint5_values( + unpack_uint_values<nbit>( unpacked.data(), packed.data(), unpacked.size(), @@ -785,16 +547,16 @@ static void benchmark_unpack_uint5_values(benchmark::State& state) { } } -BENCHMARK(benchmark_pack_uint1_values)->ArgsProduct({{128}, {8, 64, 128}}); -BENCHMARK(benchmark_unpack_uint1_values)->ArgsProduct({{128}, {8, 64, 128}}); -BENCHMARK(benchmark_pack_uint2_values)->ArgsProduct({{128}, {4, 32, 64}}); -BENCHMARK(benchmark_unpack_uint2_values)->ArgsProduct({{128}, {4, 32, 64}}); -BENCHMARK(benchmark_pack_uint3_values)->ArgsProduct({{128}, {8, 64, 128}}); -BENCHMARK(benchmark_unpack_uint3_values)->ArgsProduct({{128}, {8, 64, 128}}); -BENCHMARK(benchmark_pack_uint4_values)->ArgsProduct({{128}, {2, 16, 32}}); -BENCHMARK(benchmark_unpack_uint4_values)->ArgsProduct({{128}, {2, 16, 32}}); -BENCHMARK(benchmark_pack_uint5_values)->ArgsProduct({{128}, {8, 64, 128}}); -BENCHMARK(benchmark_unpack_uint5_values)->ArgsProduct({{128}, {8, 64, 128}}); +BENCHMARK(benchmark_pack_uint_values<1>)->ArgsProduct({{128}, {8, 64, 128}}); +BENCHMARK(benchmark_unpack_uint_values<1>)->ArgsProduct({{128}, {8, 64, 128}}); +BENCHMARK(benchmark_pack_uint_values<2>)->ArgsProduct({{128}, {4, 32, 64}}); +BENCHMARK(benchmark_unpack_uint_values<2>)->ArgsProduct({{128}, {4, 32, 64}}); 
+BENCHMARK(benchmark_pack_uint_values<3>)->ArgsProduct({{128}, {8, 64, 128}}); +BENCHMARK(benchmark_unpack_uint_values<3>)->ArgsProduct({{128}, {8, 64, 128}}); +BENCHMARK(benchmark_pack_uint_values<4>)->ArgsProduct({{128}, {2, 16, 32}}); +BENCHMARK(benchmark_unpack_uint_values<4>)->ArgsProduct({{128}, {2, 16, 32}}); +BENCHMARK(benchmark_pack_uint_values<5>)->ArgsProduct({{128}, {8, 64, 128}}); +BENCHMARK(benchmark_unpack_uint_values<5>)->ArgsProduct({{128}, {8, 64, 128}}); // Run the benchmark BENCHMARK_MAIN();