From daf06d97c94334b0d97d8b47d48bce1ba2412d26 Mon Sep 17 00:00:00 2001 From: Valentine233 Date: Tue, 3 Dec 2024 02:23:27 -0500 Subject: [PATCH 1/5] [CPU] add int8 sdpa path for cpu --- test/quantization/test_sfdp_int8_fx_pass.py | 199 ++ torchao/csrc/cpu/sdpa.cpp | 2195 +++++++++++++++++++ torchao/csrc/cpu/toy.cpp | 20 + torchao/ops.py | 103 + torchao/quantization/__init__.py | 4 + torchao/quantization/sfdp_int8_fx_pass.py | 733 +++++++ 6 files changed, 3254 insertions(+) create mode 100644 test/quantization/test_sfdp_int8_fx_pass.py create mode 100644 torchao/csrc/cpu/sdpa.cpp create mode 100644 torchao/csrc/cpu/toy.cpp create mode 100644 torchao/quantization/sfdp_int8_fx_pass.py diff --git a/test/quantization/test_sfdp_int8_fx_pass.py b/test/quantization/test_sfdp_int8_fx_pass.py new file mode 100644 index 0000000000..a39a98c364 --- /dev/null +++ b/test/quantization/test_sfdp_int8_fx_pass.py @@ -0,0 +1,199 @@ +import torchao + +import contextlib +import functools +import itertools +import math + +import torch +import torch.utils.checkpoint +from torch._dynamo.debug_utils import aot_graph_input_parser +from torch._dynamo.utils import counters +from torch._inductor import config +from torch._inductor.test_case import run_tests, TestCase +from torch._inductor.utils import run_and_get_code +from torch.testing._internal.common_utils import IS_LINUX, skipIfRocm +from torch.testing._internal.inductor_utils import HAS_CPU, HAS_CUDA + +import torch.ao.quantization.quantizer.x86_inductor_quantizer as xiq +from torch._export import capture_pre_autograd_graph +from torch.ao.quantization.quantize_pt2e import convert_pt2e, prepare_pt2e +from torch.ao.quantization.quantizer.x86_inductor_quantizer import ( + X86InductorQuantizer, +) +from torchao.quantization.sfdp_int8_fx_pass import _sfdp_init_int8 + +class SelfAttnLikeModule(torch.nn.Module): + def __init__( + self, + input_dim, + has_mask, + num_attention_heads=None, + attention_head_size=None, + ) -> None: + super().__init__() + self.input_dim = input_dim + self.q_proj = torch.nn.Linear(input_dim, input_dim, bias=False) + self.k_proj = torch.nn.Linear(input_dim, input_dim, bias=False) + self.v_proj = torch.nn.Linear(input_dim, input_dim, bias=False) + self.softmax = torch.nn.Softmax(dim=-1) + assert num_attention_heads is not None + assert attention_head_size is not None + self.num_attention_heads = num_attention_heads + self.attention_head_size = attention_head_size + self.all_head_size = self.num_attention_heads * self.attention_head_size + self.dense = torch.nn.Linear(self.all_head_size, self.all_head_size) + self.dropout = torch.nn.Dropout(0) + self.has_mask = has_mask + + def transpose_for_scores(self, x: torch.Tensor) -> torch.Tensor: + new_x_shape = x.size()[:-1] + ( + self.num_attention_heads, + self.attention_head_size, + ) + x = x.view(new_x_shape) + return x.permute([0, 2, 1, 3]) + + def forward(self, x, mask): + q = self.q_proj(x) + k = self.k_proj(x) + v = self.v_proj(x) + q = self.transpose_for_scores(q) + k = self.transpose_for_scores(k) + v = self.transpose_for_scores(v) + scores = torch.matmul(q, k.transpose(-1, -2)) / (self.input_dim**0.5) + if self.has_mask: + scores = scores + mask + attention = self.softmax(scores) + # attention = self.dropout(attention) + context_layer = torch.matmul(attention, v) + context_layer = context_layer.permute(0, 2, 1, 3).contiguous() + context_layer = context_layer.view( + context_layer.size()[:-2] + (self.all_head_size,) + ) + return self.dense(context_layer) + +def 
_generate_qdq_quantized_model(mod, inputs, quantizer): + with torch.no_grad(): + export_model = capture_pre_autograd_graph(mod, inputs) + prepare_model = prepare_pt2e(export_model, quantizer) + prepare_model(*inputs) + convert_model = convert_pt2e(prepare_model) + torch.ao.quantization.move_exported_model_to_eval(convert_model) + return convert_model + +class TestSDPAPatternRewriterTemplate(TestCase): + def _clone_inputs(self, inputs): + def clone(x): + if not isinstance(x, torch.Tensor): + return x + return x.clone() + + return [clone(x) for x in inputs] + + def _check_common( + self, + dot_prod_attention, + args1=None, + contains=True, + atol=1e-5, + has_fuse_pattern=True, + has_dropout=False, + check_train=True, + override_check_equal=False, + dtype=torch.float, + rtol=1.3e-6, + ): + if args1 is None: + tensor_shape = (4, 2, 16, 32) + args1 = [ + torch.randn(tensor_shape, device=self.device, dtype=dtype), + torch.randn(tensor_shape, device=self.device, dtype=dtype), + torch.randn(tensor_shape, device=self.device, dtype=dtype), + ] + else: + args1 = list(args1) + args2 = self._clone_inputs(args1) + + for training in [False, True] if check_train else [False]: + for x in itertools.chain(args1[:], args2[:]): + if isinstance(x, torch.Tensor) and x.is_floating_point(): + x.requires_grad = training + + dropout_arg = [training] if has_dropout else [] + torch.manual_seed(1234) + result1 = dot_prod_attention(*(args1 + dropout_arg)) + + counters.clear() + torch.manual_seed(1234) + result2, source_code = run_and_get_code( + torch.compile(dot_prod_attention, fullgraph=True), + *(args2 + dropout_arg), + ) + source_code = "\n".join(source_code) + if has_fuse_pattern: + self.assertGreaterEqual(counters["inductor"]["fuse_attention_int8"], 1) + if contains: + # many of the patterns get re-expanded in dispatcher + self.assertIn( + "torchao.scaled_dot_product_int8", + source_code, + ) + + # some tests configured with very low dropout where we still want to check equality + if not has_dropout or override_check_equal: + self.assertEqual(result1, result2, atol=atol, rtol=1.3e-6) + + if training: + result1.sum().backward() + result2.sum().backward() + for arg1, arg2 in zip(args1, args2): + if ( + isinstance(arg1, torch.Tensor) + and arg1.is_floating_point() + and (not has_dropout or override_check_equal) + ): + self.assertEqual(arg1.grad, arg2.grad, atol=atol, rtol=rtol) + + @skipIfRocm + @config.patch({"freezing": True}) + def _test_sdpa_rewriter_int8_1_to_4(self): + # pattern is different for bs=1 + for dtype, has_mask, bs in itertools.product( + [torch.float32], [True, False], [56, 1] + ): + mod = SelfAttnLikeModule( + input_dim=64 * 16, + has_mask=has_mask, + num_attention_heads=16, + attention_head_size=64, + ).eval() + maybe_autocast = ( + torch.cpu.amp.autocast() + if dtype == torch.bfloat16 + else contextlib.nullcontext() + ) + inputs = [ + torch.randn((bs, 384, 64 * 16), device=self.device, dtype=dtype), + torch.randn((bs, 1, 1, 384), device=self.device) if has_mask else None, + ] + with torch.no_grad(), maybe_autocast: + _sfdp_init_int8() + quantizer = X86InductorQuantizer() + quantizer.set_global(xiq.get_default_x86_inductor_quantization_config()) + quantizer.set_function_type_qconfig( + torch.matmul, quantizer.get_global_quantization_config() + ) + convert_model = _generate_qdq_quantized_model(mod, inputs, quantizer) + self._check_common( + convert_model, args1=inputs, check_train=False, atol=1.0 + ) + +if HAS_CPU: + class SDPAPatternRewriterCpuTests(TestSDPAPatternRewriterTemplate): + device = 
"cpu" + test_sdpa_rewriter_int8_1_to_4_cpu = TestSDPAPatternRewriterTemplate._test_sdpa_rewriter_int8_1_to_4 + +if __name__ == "__main__": + if IS_LINUX: + run_tests() diff --git a/torchao/csrc/cpu/sdpa.cpp b/torchao/csrc/cpu/sdpa.cpp new file mode 100644 index 0000000000..44aaef1bcc --- /dev/null +++ b/torchao/csrc/cpu/sdpa.cpp @@ -0,0 +1,2195 @@ +// // #define TORCH_ASSERT_ONLY_METHOD_OPERATORS +// #include +// #include + +// #include +// #include +// #include +// #include +// #include +// #include +// #include +// #include +// #include +// #include +// #include +// #include +// // #include +// // #include +// #include +// #include + +// #ifndef AT_PER_OPERATOR_HEADERS +// #include +// #else +// #include +// #endif + +#include +#include +#include +#include +#include +#include +#include + +#include +#include +#include + +#ifndef AT_PER_OPERATOR_HEADERS +#include +#else +#include +#endif + +#include +#include +#include +#include +#include +#include + +namespace torchao { + +namespace { + +template +struct is_reduced_floating_point: + std::integral_constant || + std::is_same_v> { +}; + +template +constexpr bool is_reduced_floating_point_v = is_reduced_floating_point::value; + +// out = val * a + b +// is_b_stride_zero: If the stride of b is 0 (mask broadcasting case), +// take b as a scalar pointer. +template +void _scale_attn_mask_fusion_kernel( + T1* a, + T2* b, + const int& size, + T1* out, + T1& val) { + const auto vec_size1 = at::vec::Vectorized::size(); + const auto vec_size2 = at::vec::Vectorized::size(); + constexpr int64_t T1_n = + (vec_size2 == vec_size1 * 2 && is_reduced_floating_point_v) ? 2 : 1; + constexpr int64_t T2_n = 1; + auto vec_scale = at::vec::VectorizedN(val); + int64_t i = 0; + for (; i < size - (size % vec_size2); i += vec_size2) { + auto a_n = at::vec::VectorizedN::loadu(a + i); + at::vec::VectorizedN b_n; + if constexpr(is_b_stride_zero) { + b_n = at::vec::VectorizedN((T1)b[0]); + } else { + b_n = at::vec::VectorizedN::loadu(b + i); + } + auto b_n_convert = at::vec::convert(b_n); + auto res = a_n * vec_scale + b_n_convert; + res.store(out + i); + } + for (; i < size; i++) { + auto tmp0 = a[i]; + T1 tmp1; + if constexpr(is_b_stride_zero) { + tmp1 = (T1)b[0]; + } else { + tmp1 = (T1)b[i]; + } + out[i] = tmp0 * val + tmp1; + } +} + +// 1) out = exp(a - val) +// 2) val = sum(out) +template +void _exp_reduce_sum_fusion_kernel( + T1* a, + const int& size, + T2* out, + T1& val) { + auto vec_size = at::vec::Vectorized::size(); + auto vec_max = at::vec::Vectorized(val); + T1 tmp_sum = 0; + auto vec_tmp_sum = at::vec::Vectorized(tmp_sum); + for (long i = 0; i < vec_size * (size / vec_size); i += vec_size) { + auto tmp0 = at::vec::Vectorized::loadu(a + i); + auto tmp1 = tmp0 - vec_max; + auto tmp2 = tmp1.exp_u20(); + vec_tmp_sum += tmp2; + _store(out + i, tmp2); + } + tmp_sum = at::vec::vec_reduce_all( + [](at::vec::Vectorized& x, at::vec::Vectorized& y) { + return x + y; + }, + vec_tmp_sum); + for (long i = vec_size * (size / vec_size); i < size; i++) { + auto tmp0 = a[i]; + auto tmp1 = tmp0 - val; + auto tmp2 = exp(tmp1); + tmp_sum += tmp2; + out[i] = tmp2; + } + val = tmp_sum; +} + +// 1) out = a * scale +// 2) max = max(out) +template +void _mul_reduce_max_fusion_kernel( + const scalar_t* a, + const scalar_t& scale, + const int& size, + scalar_t* out, + scalar_t& max) { + auto vec_size = at::vec::Vectorized::size(); + auto vec_scale = at::vec::Vectorized(scale); + scalar_t tmp_max = -std::numeric_limits::infinity(); + auto vec_tmp_max = at::vec::Vectorized(tmp_max); + 
for (long i = 0; i < vec_size * (size / vec_size); i += vec_size) { + auto tmp0 = at::vec::Vectorized::loadu(a + i); + auto tmp1 = tmp0 * vec_scale; + vec_tmp_max = at::vec::maximum(vec_tmp_max, tmp1); + _store(out + i, tmp1); + } + for (long i = vec_size * (size / vec_size); i < size; i++) { + auto tmp0 = a[i]; + auto tmp1 = tmp0 * scale; + tmp_max = std::max(tmp_max, tmp1); + out[i] = tmp1; + } + // max = std::max( + // tmp_max, + // at::vec::vec_reduce_all( + // [](vec::Vectorized& x, at::vec::Vectorized& y) { + // return at::vec::maximum(x, y); + // }, + // vec_tmp_max)); + max = std::max(tmp_max, vec_tmp_max.reduce_max()); +} + +template +static scalar_t* conditional_data_ptr(scalar_t* ptr, scalar_t* ptr2) { + TORCH_CHECK(ptr2 == nullptr); + return ptr; +} + +template , int> = 0> +static scalar_t* conditional_data_ptr(float* ptr, scalar_t* ptr2) { + return ptr2; +} + +template +void fill_stub(scalar_t* data, scalar_t val, int64_t size) { + using Vec = at::vec::Vectorized; + Vec data_vec = Vec(val); + int64_t d = 0; + for (; d < size - (size % Vec::size()); d += Vec::size()) { + data_vec.store(data + d); + } + #if !defined(_MSC_VER) && !defined(COMPILING_FOR_MIN_SIZE) + # pragma unroll + #endif + for (; d < size; d++) { + data[d] = val; + } +} + +void reshape_attn_mask_to_4d( + at::Tensor& attn_mask, + int64_t batchSize, + int64_t num_head, + int64_t qSize, + int64_t kvSize) { + // Support mask shapes: + // 2d: ({Q_seq_len, 1} x {KV_seq_len, 1}) + // 4d: ({Batch, 1} x {Num_heads, 1} x {Q_seq_len, 1} x {KV_seq_len, 1}) + // Guaranteed in check_attn_mask_shape + int64_t attn_mask_size_0 = 1; + int64_t attn_mask_size_1 = 1; + if (attn_mask.dim() == 4) { + if (attn_mask.size(0) == batchSize) { + attn_mask_size_0 = batchSize; + } + if (attn_mask.size(1) == num_head) { + attn_mask_size_1 = num_head; + } + } + attn_mask = attn_mask + .view({attn_mask_size_0, attn_mask_size_1, attn_mask.size(-2), attn_mask.size(-1)}) + .expand({attn_mask_size_0, attn_mask_size_1, qSize, kvSize}); +} + +// TODO: Use at::native::_store instead when it supports Half. 
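+// Note on the _store overloads that follow: (1) the generic overload stores a
+// Vectorized of the destination's own element type directly; (2) the
+// reduced-floating-point overload (BFloat16/Half) converts from a float vector
+// via convert_from_float before storing; (3) the uint8_t/int8_t overload
+// converts the float vector with at::vec::convert before storing.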
+template +void _store(scalar_t* dst, at::vec::Vectorized src, int size=at::vec::Vectorized::size()) { + src.store(dst, size); +} + +template +typename std::enable_if_t, void> +_store(scalar_t* dst, at::vec::Vectorized src) { + auto res = at::vec::convert_from_float(src, src); + res.store(dst, at::vec::Vectorized::size()); +} + +template +typename std::enable_if_t || std::is_same_v, void> +_store(scalar_t* dst, at::vec::Vectorized src) { + auto res = at::vec::convert(src); + res.store(dst, at::vec::Vectorized::size()); +} + +template +void pad_row_zero( + scalar_t* value_ptr, + scalar_t* padding_value_ptr, + int rows, + int cols, + int ldi) { + auto vec_size = at::vec::Vectorized::size(); + int i = 0; + for (; i < rows - 1; i++) { + int j = 0; + for (; j < cols - (cols % vec_size); j += vec_size) { + auto vec_v = + at::vec::Vectorized::loadu(value_ptr + i * ldi + j); + vec_v.store(padding_value_ptr + i * cols + j); + } + + if (j < cols) { + auto vec_v = at::vec::Vectorized::loadu( + value_ptr + i * ldi + j, cols - j); + vec_v.store(padding_value_ptr + i * cols + j, cols - j); + } + } + + // zero padding + int j = 0; + for (; j < cols - (cols % vec_size); j += vec_size) { + auto vec_v = at::vec::Vectorized(0); + vec_v.store(padding_value_ptr + i * cols + j); + } + + if (j < cols) { + auto vec_v = at::vec::Vectorized(0); + vec_v.store(padding_value_ptr + i * cols + j, cols - j); + } +} + +template +void pad_row_128_padding( + scalar_t* value_ptr, + scalar_t* padding_value_ptr, + int rows, + int cols, + int ldi, + int padding) { + auto vec_size = at::vec::Vectorized::size(); + int i = 0; + for (; i < rows - padding; i++) { + int j = 0; + for (; j < cols - (cols % vec_size); j += vec_size) { + auto vec_v = + at::vec::Vectorized::loadu(value_ptr + i * ldi + j); + vec_v.store(padding_value_ptr + i * cols + j); + } + + if (j < cols) { + auto vec_v = at::vec::Vectorized::loadu( + value_ptr + i * ldi + j, cols - j); + vec_v.store(padding_value_ptr + i * cols + j, cols - j); + } + } + + // 128 padding + for (; i < rows; i++) { + int j = 0; + for (; j < cols - (cols % vec_size); j += vec_size) { + auto vec_v = at::vec::Vectorized(128); + vec_v.store(padding_value_ptr + i * cols + j); + } + + if (j < cols) { + auto vec_v = at::vec::Vectorized(128); + vec_v.store(padding_value_ptr + i * cols + j, cols - j); + } + } +} + +template +void pad_col_zero( + scalar_t* value_ptr, + scalar_t* padding_value_ptr, + int rows, + int cols, + int ldi) { + auto vec_size = at::vec::Vectorized::size(); + for (int i = 0; i < rows; i++) { + int j = 0; + for (; j < cols - 1 - ((cols - 1) % vec_size); j += vec_size) { + auto vec_v = + at::vec::Vectorized::loadu(value_ptr + i * ldi + j); + vec_v.store(padding_value_ptr + i * cols + j); + } + if (j < cols - 1) { + auto vec_v = at::vec::Vectorized::loadu( + value_ptr + i * ldi + j, cols - 1 - j); + vec_v.store(padding_value_ptr + i * cols + j, cols - 1 - j); + *(padding_value_ptr + i * cols + cols - 1) = scalar_t(0); + } + } +} + +template +void pad_col_zero_padding( + scalar_t* value_ptr, + scalar_t* padding_value_ptr, + int rows, + int cols, + int ldi, + int padding) { + auto vec_size = at::vec::Vectorized::size(); + for (int i = 0; i < rows; i++) { + int j = 0; + for (; j < cols - padding - ((cols - padding) % vec_size); j += vec_size) { + auto vec_v = + at::vec::Vectorized::loadu(value_ptr + i * ldi + j); + vec_v.store(padding_value_ptr + i * cols + j); + } + if (j < cols - padding) { + auto vec_v = at::vec::Vectorized::loadu( + value_ptr + i * ldi + j, cols - padding - 
j); + vec_v.store(padding_value_ptr + i * cols + j, cols - padding - j); + *(padding_value_ptr + i * cols + cols - padding) = scalar_t(0); + } + } +} + +/* +1. dequant +2. add mask +3. max reduce for softmax +*/ +template +void _dequant_mask_max_fusion_kernel( + const int32_t* in, + const mask_t* mask_ptr, + const int32_t* sum_a_ptr, + const int32_t* sum_b_ptr, + const int& M, + const int& N, + const int& ldi, + const int& ldm, // leading dimension mask + const int& ldo, + const int32_t& beta, // zp_a*zp_b*k + const float& alpha, // scale_a*scale_b*scale_sdpa + float* out, + float* sfm_max_ptr) { + const int32_t vec_size = at::vec::Vectorized::size(); + auto vec_beta = at::vec::Vectorized(beta); + auto vec_alpha = at::vec::Vectorized(alpha); + for (long row = 0; row < M; row += 1) { + auto sum_a = sum_a_ptr[row]; + auto vec_sum_a = at::vec::Vectorized(sum_a); + const int32_t* tmp_in = in + row * ldi; + float* tmp_out = out + row * ldo; + const mask_t* mask_data_ptr = mask_ptr + row * ldm; + float tmp_max = -std::numeric_limits::infinity(); + auto vec_tmp_max = at::vec::Vectorized(tmp_max); + for (long col = 0; col < vec_size * (N / vec_size); col += vec_size) { + auto vec_sum_b = at::vec::Vectorized::loadu(sum_b_ptr + col); + auto tmp0 = at::vec::Vectorized::loadu(tmp_in + col); + auto tmp1 = tmp0 - vec_sum_b; + auto tmp2 = tmp1 - vec_sum_a; + auto tmp3 = tmp2 + vec_beta; + auto tmp4 = at::vec::convert(tmp3); + auto tmp5 = tmp4 * vec_alpha; + auto tmp6 = at::vec::Vectorized::loadu(mask_data_ptr + col); + auto tmp7 = at::vec::convert(tmp6); + auto tmp8 = tmp5 + tmp7; + vec_tmp_max = at::vec::maximum(vec_tmp_max, tmp8); + _store(tmp_out + col, tmp8); + } + tmp_max = std::max(tmp_max, vec_tmp_max.reduce_max()); + for (long col = vec_size * (N / vec_size); col < N; col++) { + auto sum_b = sum_b_ptr[col]; + auto tmp0 = tmp_in[col]; + auto tmp1 = tmp0 - sum_b; + auto tmp2 = tmp1 - sum_a; + auto tmp3 = tmp2 + beta; + auto tmp4 = (float) tmp3; + auto tmp5 = tmp4 * alpha; + auto tmp6 = mask_data_ptr[col]; + auto tmp7 = (float) tmp6; + auto tmp8 = tmp5 + tmp7; + tmp_max = std::max(tmp_max, tmp8); + tmp_out[col] = tmp8; + } + sfm_max_ptr[row] = std::max(sfm_max_ptr[row], tmp_max); + } +} + +/* +1. dequant +2. 
max reduce for softmax +*/ +void _dequant_max_fusion_kernel( + const int32_t* in, + const int32_t* sum_a_ptr, + const int32_t* sum_b_ptr, + const int& M, + const int& N, + const int& ldi, + const int& ldo, + const int32_t& beta, // zp_a*zp_b*k + const float& alpha, // scale_a*scale_b*scale_sdpa + float* out, + float* sfm_max_ptr) { + const int32_t vec_size = at::vec::Vectorized::size(); + auto vec_beta = at::vec::Vectorized(beta); + auto vec_alpha = at::vec::Vectorized(alpha); + for (long row = 0; row < M; row += 1) { + auto sum_a = sum_a_ptr[row]; + auto vec_sum_a = at::vec::Vectorized(sum_a); + const int32_t* tmp_in = in + row * ldi; + float* tmp_out = out + row * ldo; + float tmp_max = -std::numeric_limits::infinity(); + auto vec_tmp_max = at::vec::Vectorized(tmp_max); + for (long col = 0; col < vec_size * (N / vec_size); col += vec_size) { + auto vec_sum_b = at::vec::Vectorized::loadu(sum_b_ptr + col); + auto tmp0 = at::vec::Vectorized::loadu(tmp_in + col); + auto tmp1 = tmp0 - vec_sum_b; + auto tmp2 = tmp1 - vec_sum_a; + auto tmp3 = tmp2 + vec_beta; + auto tmp4 = at::vec::convert(tmp3); + auto tmp5 = tmp4 * vec_alpha; + vec_tmp_max = at::vec::maximum(vec_tmp_max, tmp5); + _store(tmp_out + col, tmp5); + } + tmp_max = std::max(tmp_max, vec_tmp_max.reduce_max()); + for (long col = vec_size * (N / vec_size); col < N; col++) { + auto sum_b = sum_b_ptr[col]; + auto tmp0 = tmp_in[col]; + auto tmp1 = tmp0 - sum_b; + auto tmp2 = tmp1 - sum_a; + auto tmp3 = tmp2 + beta; + auto tmp4 = (float) tmp3; + auto tmp5 = tmp4 * alpha; + tmp_max = std::max(tmp_max, tmp5); + tmp_out[col] = tmp5; + } + sfm_max_ptr[row] = std::max(sfm_max_ptr[row], tmp_max); + } +} + +/* +1. Softmax: sub max, exp, sum reduce, div sum +2. quant +3. sum for attention +*/ +template +void _sub_exp_sum_div_quant_sum_fusion_kernel( + const float* in, + const int64_t& M, + const int64_t& N_step, + const int64_t& NSlice, + const int& ldi, + const int& ldo, + const int& kvSize, + const int& rndkvSplitSize, + const int& av_gemm_K, + const int32_t& beta1, // zp_a + const int32_t& beta2, // zp_b + const float& alpha, // scale_a + float* local, + scalar_t* out, + float* sfm_max_ptr, + float* sfm_sum_ptr, + int32_t* sum_a_ptr) { + const int32_t vec_size = at::vec::Vectorized::size(); + float min_val = 0; + float max_val = 255; + auto vec_min_val = at::vec::Vectorized(min_val); + auto vec_max_val = at::vec::Vectorized(max_val); + scalar_t zero = 0; + auto vec_zero = at::vec::Vectorized(zero); + float beta1_float = (float) beta1; + auto vec_beta1 = at::vec::Vectorized(beta1_float); + for (int64_t row = 0; row < M; ++row) { + auto sfm_max = sfm_max_ptr[row]; + auto vec_max = at::vec::Vectorized(sfm_max); + // sub max, exp, sum reduce + const float* qk_block_data = in + row * rndkvSplitSize; + for (int64_t l = 0; l < NSlice; l ++) { + int64_t n = l * N_step; + int64_t kvBlockSize = std::min(N_step, kvSize - n); + const float* tmp_in = qk_block_data + l * ldi; + float tmp_sum = 0; + auto vec_tmp_sum = at::vec::Vectorized(tmp_sum); + float* tmp_out = local + n; + for (long col = 0; col < vec_size * (kvBlockSize / vec_size); col += vec_size) { + auto tmp0 = at::vec::Vectorized::loadu(tmp_in + col); + auto tmp1 = tmp0 - vec_max; + auto tmp2 = tmp1.exp_u20(); + vec_tmp_sum += tmp2; + _store(tmp_out + col, tmp2); + } + tmp_sum += vec_tmp_sum.reduce_add(); + for (long col = vec_size * (kvBlockSize / vec_size); col < kvBlockSize; col++) { + auto tmp0 = tmp_in[col]; + auto tmp1 = tmp0 - sfm_max; + auto tmp2 = exp(tmp1); + tmp_sum += tmp2; + 
tmp_out[col] = tmp2; + } + sfm_sum_ptr[row] += tmp_sum; + } + // div sum, sum for attention + auto sum_scale = 1 / sfm_sum_ptr[row] / alpha; + auto vec_sum_scale = at::vec::Vectorized(sum_scale); + scalar_t* qk_reduced_block_data = out + row * av_gemm_K; + for (int64_t l = 0; l < NSlice; l ++) { + int64_t n = l * N_step; + int64_t kvBlockSize = std::min(N_step, kvSize - n); + int32_t tmp_sum = 0; + auto vec_tmp_sum = at::vec::Vectorized(tmp_sum); + float* tmp_in = local + n; + scalar_t* tmp_out = qk_reduced_block_data + l * ldo; + for (long col = 0; col < vec_size * (kvBlockSize / vec_size); col += vec_size) { + auto tmp0 = at::vec::Vectorized::loadu(tmp_in + col); + auto tmp1 = tmp0 * vec_sum_scale; + auto tmp2 = tmp1.round(); + auto tmp3 = tmp2 + vec_beta1; + auto tmp4 = at::vec::maximum(tmp3, vec_min_val); + auto tmp5 = at::vec::minimum(tmp4, vec_max_val); + _store(tmp_out + col, tmp5); + auto tmp6 = at::vec::convert(tmp5); + vec_tmp_sum += tmp6; + } + tmp_sum += vec_tmp_sum.reduce_add(); + for (long col = vec_size * (kvBlockSize / vec_size); col < kvBlockSize; col++) { + auto tmp0 = tmp_in[col]; + auto tmp1 = tmp0 * sum_scale; + auto tmp2 = std::nearbyint(tmp1); + auto tmp3 = tmp2 + beta1_float; + auto tmp4 = std::max(tmp3, min_val); + auto tmp5 = std::min(tmp4, max_val); + tmp_out[col] = tmp5; + auto tmp6 = (int32_t) tmp5; + tmp_sum += tmp6; + } + sum_a_ptr[row] += tmp_sum * beta2; + // set zero + for (long col = kvBlockSize; col < vec_size * (av_gemm_K / vec_size); col += vec_size) { + _store(tmp_out + col, vec_zero); + } + for (long col = vec_size * (av_gemm_K / vec_size); col < av_gemm_K; col++) { + tmp_out[col] = zero; + } + } + } +} + +template +void _sub_exp_sum_div_quant_fusion_kernel( + const float* in, + const int64_t& M, + const int64_t& N_step, + const int64_t& NSlice, + const int& ldi, + const int& ldo, + const int& kvSize, + const int& rndkvSplitSize, + const int& av_gemm_K, + const int32_t& beta1, // zp_a + const float& alpha, // scale_a + float* local, + scalar_t* out, + float* sfm_max_ptr, + float* sfm_sum_ptr) { + const int32_t vec_size = at::vec::Vectorized::size(); + float min_val = 0; + float max_val = 255; + auto vec_min_val = at::vec::Vectorized(min_val); + auto vec_max_val = at::vec::Vectorized(max_val); + scalar_t zero = 0; + auto vec_zero = at::vec::Vectorized(zero); + float beta1_float = (float) beta1; + auto vec_beta1 = at::vec::Vectorized(beta1_float); + for (int64_t row = 0; row < M; ++row) { + auto sfm_max = sfm_max_ptr[row]; + auto vec_max = at::vec::Vectorized(sfm_max); + // sub max, exp, sum reduce + const float* qk_block_data = in + row * rndkvSplitSize; + for (int64_t l = 0; l < NSlice; l ++) { + int64_t n = l * N_step; + int64_t kvBlockSize = std::min(N_step, kvSize - n); + const float* tmp_in = qk_block_data + l * ldi; + float tmp_sum = 0; + auto vec_tmp_sum = at::vec::Vectorized(tmp_sum); + float* tmp_out = local + n; + for (long col = 0; col < vec_size * (kvBlockSize / vec_size); col += vec_size) { + auto tmp0 = at::vec::Vectorized::loadu(tmp_in + col); + auto tmp1 = tmp0 - vec_max; + auto tmp2 = tmp1.exp_u20(); + vec_tmp_sum += tmp2; + _store(tmp_out + col, tmp2); + } + tmp_sum += vec_tmp_sum.reduce_add(); + for (long col = vec_size * (kvBlockSize / vec_size); col < kvBlockSize; col++) { + auto tmp0 = tmp_in[col]; + auto tmp1 = tmp0 - sfm_max; + auto tmp2 = exp(tmp1); + tmp_sum += tmp2; + tmp_out[col] = tmp2; + } + sfm_sum_ptr[row] += tmp_sum; + } + // div sum, sum for attention + auto sum_scale = 1 / sfm_sum_ptr[row] / alpha; + auto 
vec_sum_scale = at::vec::Vectorized(sum_scale); + scalar_t* qk_reduced_block_data = out + row * av_gemm_K; + for (int64_t l = 0; l < NSlice; l ++) { + int64_t n = l * N_step; + int64_t kvBlockSize = std::min(N_step, kvSize - n); + float* tmp_in = local + n; + scalar_t* tmp_out = qk_reduced_block_data + l * ldo; + for (long col = 0; col < vec_size * (kvBlockSize / vec_size); col += vec_size) { + auto tmp0 = at::vec::Vectorized::loadu(tmp_in + col); + auto tmp1 = tmp0 * vec_sum_scale; + auto tmp2 = tmp1.round(); + auto tmp3 = tmp2 + vec_beta1; + auto tmp4 = at::vec::maximum(tmp3, vec_min_val); + auto tmp5 = at::vec::minimum(tmp4, vec_max_val); + _store(tmp_out + col, tmp5); + } + for (long col = vec_size * (kvBlockSize / vec_size); col < kvBlockSize; col++) { + auto tmp0 = tmp_in[col]; + auto tmp1 = tmp0 * sum_scale; + auto tmp2 = std::nearbyint(tmp1); + auto tmp3 = tmp2 + beta1_float; + auto tmp4 = std::max(tmp3, min_val); + auto tmp5 = std::min(tmp4, max_val); + tmp_out[col] = tmp5; + } + // set zero + for (long col = kvBlockSize; col < vec_size * (av_gemm_K / vec_size); col += vec_size) { + _store(tmp_out + col, vec_zero); + } + for (long col = vec_size * (av_gemm_K / vec_size); col < av_gemm_K; col++) { + tmp_out[col] = zero; + } + } + } +} + +/* +1. dequant +2. quant +*/ +template +void _dequant_quant_fusion_kernel( + const int32_t* in, + const int32_t* sum_a_ptr, + const int32_t* sum_b_ptr, + const int& M, + const int& N, + const int& ldi, + const int& ldo, + const int32_t& beta1, // zp_a*zp_b*k + const int32_t& beta2, // zp_c + const float& alpha, // scale_a*scale_b/scale_c + scalar_t* out) { + const int32_t vec_size = at::vec::Vectorized::size(); + float min_val = 0; + float max_val = 255; + auto vec_min_val = at::vec::Vectorized(min_val); + auto vec_max_val = at::vec::Vectorized(max_val); + auto vec_beta1 = at::vec::Vectorized(beta1); + auto vec_alpha = at::vec::Vectorized(alpha); + float beta2_float = (float) beta2; + auto vec_beta2 = at::vec::Vectorized(beta2_float); + for (long row = 0; row < M; row += 1) { + auto sum_a = sum_a_ptr[row]; + auto vec_sum_a = at::vec::Vectorized(sum_a); + const int32_t* tmp_in = in + row * ldi; + scalar_t* tmp_out = out + row * ldo; + for (long col = 0; col < vec_size * (N / vec_size); col += vec_size) { + auto vec_sum_b = at::vec::Vectorized::loadu(sum_b_ptr + col); + auto tmp0 = at::vec::Vectorized::loadu(tmp_in + col); + auto tmp1 = tmp0 - vec_sum_b; + auto tmp2 = tmp1 - vec_sum_a; + auto tmp3 = tmp2 + vec_beta1; + auto tmp4 = at::vec::convert(tmp3); + auto tmp5 = tmp4 * vec_alpha; + auto tmp6 = tmp5.round(); + auto tmp7 = tmp6 + vec_beta2; + auto tmp8 = at::vec::maximum(tmp7, vec_min_val); + auto tmp9 = at::vec::minimum(tmp8, vec_max_val); + _store(tmp_out + col, tmp9); + } + for (long col = vec_size * (N / vec_size); col < N; col++) { + auto sum_b = sum_b_ptr[col]; + auto tmp0 = tmp_in[col]; + auto tmp1 = tmp0 - sum_b; + auto tmp2 = tmp1 - sum_a; + auto tmp3 = tmp2 + beta1; + auto tmp4 = (float) tmp3; + auto tmp5 = tmp4 * alpha; + auto tmp6 = std::nearbyint(tmp5); + auto tmp7 = tmp6 + beta2_float; + auto tmp8 = std::max(tmp7, min_val); + auto tmp9 = std::min(tmp8, max_val); + tmp_out[col] = tmp9; + } + } +} + +template +void _int_sum_b_contiguous_kernel_helper( + const scalar_t* in, + int32_t* out, + const int& N, + const int32_t& scale) { + const int32_t vec_size = at::vec::Vectorized::size(); + int32_t tmp_sum = 0; + auto vec_tmp_sum = at::vec::Vectorized(tmp_sum); + for (long i = 0; i < vec_size * (N / vec_size); i += vec_size) { + auto tmp0 
= at::vec::Vectorized::loadu(in + i); + auto tmp1 = at::vec::convert(tmp0); + vec_tmp_sum = vec_tmp_sum + tmp1; + } + tmp_sum += vec_tmp_sum.reduce_add(); + for (long i = vec_size * (N / vec_size); i < N; i++) { + // for (long i = 0; i < N; i++) { + tmp_sum += static_cast(in[i]); + } + out[0] = tmp_sum * scale; +} + +template +void _int_sum_b_contiguous_kernel( + const scalar_t* in, + int32_t* out, + const int& M, + const int& N, + const int& ld, + const int32_t& scale) { + for (long r = 0; r < M; r += 1) { + _int_sum_b_contiguous_kernel_helper(in + r * ld, out + r, N, scale); + } +} + +template +void _int_sum_a_contiguous_kernel( + const scalar_t* in, + int32_t* out, + const int& M, + const int& N, + const int& ld, + const int32_t& scale) { + const int32_t vec_size = at::vec::Vectorized::size(); + auto vec_scale = at::vec::Vectorized(scale); + // initialization with 0 + int32_t zero = 0; + auto vec_zero = at::vec::Vectorized(zero); + for (long i = 0; i < vec_size * (M / vec_size); i += vec_size) { + _store(out + i, vec_zero); + } + for (long i = vec_size * (M / vec_size); i < M; i++) { + out[i] = zero; + } + // sum + for (long j = 0; j < N; j++) { + const scalar_t* tmp_in = in + j * ld; + for (long i = 0; i < vec_size * (M / vec_size); i += vec_size) { + auto tmp0 = at::vec::Vectorized::loadu(tmp_in + i); + auto tmp1 = at::vec::Vectorized::loadu(out + i); + auto tmp2 = at::vec::convert(tmp0); + auto tmp3 = tmp1 + tmp2; + _store(out + i, tmp3); + } + for (long i = vec_size * (M / vec_size); i < M; i++) { + // for (long i = 0; i < M; i++) { + auto tmp0 = tmp_in[i]; + auto tmp1 = out[i]; + auto tmp2 = static_cast(tmp0); + auto tmp3 = tmp1 + tmp2; + out[i] = tmp3; + } + } + // scale + for (long i = 0; i < vec_size * (M / vec_size); i += vec_size) { + auto tmp0 = at::vec::Vectorized::loadu(out + i); + auto tmp1 = tmp0 * vec_scale; + _store(out + i, tmp1); + } + for (long i = vec_size * (M / vec_size); i < M; i++) { + auto tmp0 = out[i]; + auto tmp1 = tmp0 * scale; + out[i] = tmp1; + } +} + +void do_convert_u8_s8( + unsigned char* src, + signed char* dst, + int64_t in_rows, + int64_t in_cols, + int64_t ldi, + int64_t ldo) { + auto vec_size = at::vec::Vectorized::size(); + auto vec_128 = at::vec::Vectorized(128); + for (int64_t r = 0; r < in_rows; r++) { + const unsigned char* tmp_src = src + r * ldi; + signed char* tmp_dst = dst + r * ldo; + for (int64_t c = 0; c < vec_size * (in_cols / vec_size); c += vec_size) { + auto tmp0 = at::vec::Vectorized::loadu(tmp_src + c, vec_size); + auto tmp1 = at::vec::convert(tmp0); + auto tmp2 = tmp1 - vec_128; + auto tmp3 = at::vec::convert(tmp2); + _store(tmp_dst + c, tmp3, vec_size); + } + for (int64_t c = vec_size * (in_cols / vec_size); c < in_cols; c++) { + // for (int64_t c = 0; c < in_cols; c++) { + auto tmp0 = tmp_src[c]; + auto tmp1 = (int16_t) tmp0; + auto tmp2 = tmp1 - 128; + auto tmp3 = (signed char) tmp2; + tmp_dst[c] = tmp3; + } + } +} + +template +void do_transpose( + scalar_t* src, + scalar_t* dst, + int64_t in_rows, + int64_t in_cols, + int64_t ldi, + int64_t ldo) { + for (int64_t r=0; r +void do_copy( + scalar_t* src, + scalar_t* dst, + int64_t in_rows, + int64_t in_cols, + int64_t ldi, + int64_t ldo) { + for (int64_t r=0; r +void pad_remain_row_col( + scalar_t* value_ptr, + int rows, + int cols, + int prows, + int pcols, + int ldi, + scalar_t pad_val=0) { + auto psize = pcols - cols; + if (psize == 0 && prows == rows) { + return; + } + const int32_t vec_size = at::vec::Vectorized::size(); + auto pad = at::vec::Vectorized(pad_val); + if 
(psize > 0) { + for (int i = 0; i < rows; i++) { + int j = 0; + for (; j < psize - (psize % vec_size); j += vec_size) { + pad.store(value_ptr + i * ldi + cols + j); + } + if (j < psize) { + pad.store(value_ptr + i * ldi + cols + j, psize - j); + } + } + } + + for (int i = rows; i < prows; i++) { + int j = 0; + for (; j < pcols - (pcols % vec_size); j += vec_size) { + pad.store(value_ptr + i * ldi + j); + } + if (j < pcols) { + pad.store(value_ptr + i * ldi + j, pcols - j); + } + } +} + +template +void copy_value_with_pad( + scalar_t* value_ptr, + scalar_t* dst_ptr, + int rows, + int cols, + int prows, + int pcols, + int ldi, + scalar_t pad_val=0) { + const int32_t vec_size = at::vec::Vectorized::size(); + auto pad = at::vec::Vectorized(pad_val); + int i = 0; + for (; i < rows; i++) { + int j = 0; + for (; j < cols - (cols % vec_size); j += vec_size) { + auto vec_v = + at::vec::Vectorized::loadu(value_ptr + i * ldi + j); + vec_v.store(dst_ptr + i * pcols + j); + } + + if (j < cols) { + auto vec_v = at::vec::Vectorized::loadu( + value_ptr + i * ldi + j, cols - j); + vec_v.store(dst_ptr + i * pcols + j, cols - j); + } + + // col padding + auto psize = pcols - cols; + if (psize > 0) { + int pj = 0; + for (; pj < psize - (psize % vec_size); pj += vec_size) { + pad.store(dst_ptr + i * pcols + cols + pj); + } + if (pj < psize) { + pad.store(dst_ptr + i * pcols + cols + pj, psize - pj); + } + } + } + + // row padding + for (; i < prows; i++) { + int j = 0; + for (; j < pcols - (pcols % vec_size); j += vec_size) { + pad.store(dst_ptr + i * pcols + j); + } + if (j < pcols) { + pad.store(dst_ptr + i * pcols + j, pcols - j); + } + + } + +} + +// thread_local std::unordered_map< +// BrgemmKey, +// std::shared_ptr> cache_brgemm_kernels; + +// thread_local std::unordered_map< +// PackBKey, +// std::shared_ptr> cache_packb_kernels; + +// std::shared_ptr create_or_get_microkernel( +// int64_t M, +// int64_t N, +// int64_t K, +// int64_t batch_size, +// int lda, +// int ldb, +// int ldc, +// dt dt_a, +// dt dt_b, +// dt dt_c) { +// BrgemmKey key_brgemm(M, N, K, batch_size, lda, ldb, ldc, dt_a, dt_b, dt_c); +// auto search = cache_brgemm_kernels.find(key_brgemm); +// if (search != cache_brgemm_kernels.end()) { +// return search->second; +// } else { +// cache_brgemm_kernels.insert( +// {key_brgemm, +// std::make_shared( +// M, N, K, batch_size, lda, ldb, ldc, dt_a, dt_b, dt_c)}); +// return cache_brgemm_kernels[key_brgemm]; +// } +// } + +// std::shared_ptr create_or_get_packb_microkernel( +// int64_t K, +// int64_t N, +// int ld_in, +// int ld_out, +// dt dt_in, +// dt dt_out, +// bool do_trans) { +// PackBKey key_packb(K, N, ld_in, ld_out, dt_in, dt_out); +// auto search = cache_packb_kernels.find(key_packb); +// if (search != cache_packb_kernels.end()) { +// return search->second; +// } else { +// cache_packb_kernels.insert( +// {key_packb, +// std::make_shared( +// K, N, +// do_trans ? 
dnnl::ukernel::pack_type::trans : dnnl::ukernel::pack_type::no_trans, +// ld_in, ld_out, dt_in, dt_out)}); +// return cache_packb_kernels[key_packb]; +// } +// } + +// UINT8 - u8u8s32 +template +typename std::enable_if_t, void> +sdpa_int8_kernel_impl( + const at::Tensor& output, + const at::Tensor& q, + const at::Tensor& k, + const at::Tensor& v, + double dropout_p, + bool is_causal, + at::Tensor& attention_mask, + double scale, + int32_t q_zp, + float q_scale, + int32_t k_zp, + float k_scale, + int32_t v_zp, + float v_scale, + int32_t a_zp, + float a_scale, + int32_t o_zp, + float o_scale) { + // using dt = dnnl::memory::data_type; + // using namespace dnnl; + // using namespace dnnl::ukernel; + // auto starts = duration_cast(system_clock::now().time_since_epoch()).count(); + // Query (Batch x Num_heads x Q_seq_len x Dim_per_head) + // -> (Batch x Q_seq_len x Num_heads x Dim_per_head) + // Key (Batch x Num_heads x KV_seq_len x Dim_per_head) + // -> (Batch x KV_seq_len x Num_heads x Dim_per_head) + // Value (Batch x Num_heads x KV_seq_len x Dim_per_head) + // -> (Batch x KV_seq_len x Num_heads x Dim_per_head) + at::Tensor query = q.transpose(1, 2); + at::Tensor key = k.transpose(1, 2); + at::Tensor value = v.transpose(1, 2); + + const auto accumulate_dtype = at::kFloat; // at::toOpMathType(dtype); + + using accum_t = float; // at::opmath_type; + using Vec = at::vec::Vectorized; + accum_t scaling_factor = + sdp::calculate_scale(query, scale).as_float_unchecked(); + // if (attention_mask.defined() && attention_mask.scalar_type() != ScalarType::Float) { + // attention_mask = attention_mask.to(at::kFloat); + // } + int block_64 = 64; + // Sizes + TORCH_CHECK( + (query.size(3) == value.size(3)) && (key.size(3) == value.size(3)), + "scaled_dot_product_attention_sdpa: Q/K/V should have the same head size"); + TORCH_CHECK( + kv_split_size % block_64 == 0, "kv_split_size is not divisble by ", block_64); + + int64_t batchSize = query.size(0); + int64_t qSize = query.size(1); + int64_t kvSize = value.size(1); + int64_t num_head = query.size(2); + int64_t headSize = query.size(3); + + + bool has_attn_mask = attention_mask.defined() && attention_mask.numel(); + if (has_attn_mask) { + reshape_attn_mask_to_4d(attention_mask, batchSize, num_head, qSize, kvSize); + } + + // Strides + int64_t qStrideB = query.stride(0); + int64_t qStrideM = query.stride(1); + int64_t qStrideH = query.stride(2); + int64_t kStrideB = key.stride(0); + int64_t kStrideN = key.stride(1); + int64_t kStrideH = key.stride(2); + int64_t vStrideB = value.stride(0); + int64_t vStrideN = value.stride(1); + int64_t vStrideH = value.stride(2); + int64_t oStrideB = output.stride(0); + int64_t oStrideM = output.stride(1); + int64_t oStrideH = output.stride(2); + int64_t mStrideB = + (attention_mask.defined() && attention_mask.size(0) > 1) + ? attention_mask.stride(0) + : 0; + int64_t mStrideH = + (attention_mask.defined() && attention_mask.size(1) > 1) + ? attention_mask.stride(1) + : 0; + int64_t mStrideM = + (attention_mask.defined() && attention_mask.size(2) > 1) + ? attention_mask.stride(2) + : 0; + int64_t mStrideN = + (attention_mask.defined() && attention_mask.size(3) > 1) + ? attention_mask.stride(3) + : 0; + + int64_t qSplitSize = q_split_size > qSize ? qSize : q_split_size; + int64_t kvSplitSize = kv_split_size > kvSize ? 
kvSize : kv_split_size; + int64_t qSlice = (qSize - 1) / qSplitSize + 1; + int64_t qTail = (qSize - 1) % qSplitSize + 1; + int64_t kvSlice = (kvSize - 1) / kvSplitSize + 1; + int64_t kvTail = (kvSize - 1) % kvSplitSize + 1; + int64_t num_thread = at::get_num_threads(); + + int64_t rndHeadSize = (headSize + block_64 - 1L) / block_64 * block_64; + // one of 16, 32, 48, 64 + auto select_tail_tail_block_size = [](int64_t size) -> int64_t { + if (size == 0) { + return 0; + } else if (size <= 16) { + return 16; + } else if (size <= 32) { + return 32; + } else if (size <= 48) { + return 48; + } else { + return 64; + } + }; + int64_t kv_tail_tail_block_size = select_tail_tail_block_size(kvTail % block_64); + int64_t rndkvSplitSize = (kvSplitSize + block_64 - 1L) / block_64 * block_64; + int64_t rndkvTail = (kvTail + block_64 - 1L) / block_64 * block_64; + int64_t rndkvSize = kv_split_size > kvSize ? rndkvTail : rndkvSplitSize * kvSlice + rndkvTail; + + bool av_gemm_K_mul4 = kvSplitSize % 4 == 0; + int av_gemm_K_padding = av_gemm_K_mul4 ? 0 : 4 - kvSplitSize % 4; + // // If K of Gemm is not even, use mkl gemm instead of tpp for BF16 + int av_gemm_K = kvSplitSize + av_gemm_K_padding; + bool av_gemm_K_tail_mul4 = kvTail % 4 == 0; + int av_gemm_K_tail_padding = av_gemm_K_tail_mul4 ? 0 : 4 - kvTail % 4; + int av_gemm_K_tail = kvTail + av_gemm_K_tail_padding; + + + // dt u8_dt = dt::u8; + // dt s8_dt = dt::s8; + // // dt f32_dt = dt::f32; + // dt s32_dt = dt::s32; + auto u8_dt = at::ScalarType::Byte; + auto s8_dt = at::ScalarType::Int; + auto f32_dt = at::ScalarType::Float; + + // Data ptrs + scalar_t* q_data = query.data_ptr(); + scalar_t* k_data = key.data_ptr(); + scalar_t* v_data = value.data_ptr(); + mask_t* mask_data = attention_mask.defined() + ? attention_mask.data_ptr() + : nullptr; + scalar_t* out_data = output.data_ptr(); + + // Create tpp kernels for Query @ Key + bool headSize_mul4 = headSize % 4 == 0; + // If K of Gemm is not even, use mkl gemm instead of tpp for BF16 + int qk_gemm_K_padding = headSize_mul4 ? 0 : 4 - headSize % 4; + int qk_gemm_K = headSize + qk_gemm_K_padding; + + // auto && qk_gemm = create_or_get_microkernel( + // qSplitSize, block_64, qk_gemm_K, + // 1, //batch_size + // qk_gemm_K, // lda + // block_64, //ldb + // rndkvSplitSize, //ldc + // u8_dt, //a dtype + // u8_dt, //b dtype + // s32_dt //c dtype + // ); + // (*qk_gemm).finalize(); + // (*qk_gemm).generate(); + // size_t qk_scratchpad_size = (*qk_gemm).get_scratchpad_size(); + + // auto && qk_gemm_ktail = create_or_get_microkernel( + // qSplitSize, block_64, qk_gemm_K, + // 1, //batch_size + // qk_gemm_K, // lda + // block_64, //ldb + // rndkvTail, //ldc + // u8_dt, //a dtype + // u8_dt, //b dtype + // s32_dt //c dtype + // ); + // // size_t qk_ktail_scratchpad_size = (*qk_gemm_ktail).get_scratchpad_size(); + // (*qk_gemm_ktail).finalize(); + // (*qk_gemm_ktail).generate(); + + // std::shared_ptr qk_gemm_ktail_tail; + // if (kvTail % block_64 != 0) { + // qk_gemm_ktail_tail = create_or_get_microkernel( + // qSplitSize, kv_tail_tail_block_size, qk_gemm_K, + // 1, //batch_size + // qk_gemm_K, // lda + // kv_tail_tail_block_size, //ldb + // rndkvTail, //ldc + // u8_dt, //a dtype + // u8_dt, //b dtype + // s32_dt //c dtype + // ); + // (*qk_gemm_ktail_tail).finalize(); + // (*qk_gemm_ktail_tail).generate(); + // } + + // auto && qk_gemm_qtail = create_or_get_microkernel( + // qTail, block_64, qk_gemm_K, + // 1, //batch_size + // qk_gemm_K,//headSize_mul4 ? 
qStrideM : qk_gemm_K, // lda + // block_64, //ldb + // rndkvSplitSize, //ldc + // u8_dt, //a dtype + // u8_dt, //b dtype + // s32_dt //c dtype + // ); + // // size_t qk_qtail_scratchpad_size = (*qk_gemm_qtail).get_scratchpad_size(); + // (*qk_gemm_qtail).finalize(); + // (*qk_gemm_qtail).generate(); + // auto && qk_gemm_qktail = create_or_get_microkernel( + // qTail, block_64, qk_gemm_K, + // 1, //batch_size + // qk_gemm_K, // lda + // block_64, //ldb + // rndkvTail, //ldc + // u8_dt, //a dtype + // u8_dt, //b dtype + // s32_dt //c dtype + // ); + // // size_t qk_qktail_scratchpad_size = (*qk_gemm_qktail).get_scratchpad_size(); + // (*qk_gemm_qktail).finalize(); + // (*qk_gemm_qktail).generate(); + + // std::shared_ptr qk_gemm_qktail_tail; + // if (kvTail % block_64 != 0) { + // qk_gemm_qktail_tail = create_or_get_microkernel( + // qSplitSize, kv_tail_tail_block_size, qk_gemm_K, + // 1, //batch_size + // qk_gemm_K, // lda + // kv_tail_tail_block_size, //ldb + // rndkvTail, //ldc + // u8_dt, //a dtype + // u8_dt, //b dtype + // s32_dt //c dtype + // ); + // (*qk_gemm_qktail_tail).finalize(); + // (*qk_gemm_qktail_tail).generate(); + // } + + // std::vector> A_B_offsets(1); + std::vector> A_B_offsets(1); + A_B_offsets[0] = std::make_pair(0, 0); + + // std::vector> A_B_offsets_batch(kvSlice); + std::vector> A_B_offsets_batch(kvSlice); + for (auto s=0; s(); + + int64_t kv_sum_size_per_BH = + /* key_sum */ kvSize + + /* value_sum */ headSize; + + at::Tensor kv_sum_buf = at::empty( + {batchSize, num_head, kv_sum_size_per_BH}, + query.options().dtype(at::kInt)); + int32_t* k_sum_buf_data = kv_sum_buf.data_ptr(); + int32_t* v_sum_buf_data = k_sum_buf_data + batchSize * num_head * kvSize; + + int64_t kv_reorder_size_per_BH = + /* key_t_reorder */ qk_gemm_K * rndkvSize + + /* value_t_reorder */ kvSlice * av_gemm_K * rndHeadSize; + + at::Tensor kv_reorder_buf = at::empty( + {batchSize, num_head, kv_reorder_size_per_BH}, + query.options()); + scalar_t* kv_reorder_buf_data = kv_reorder_buf.data_ptr(); + scalar_t* key_reorder_ptr = kv_reorder_buf_data; + scalar_t* value_reorder_ptr = kv_reorder_buf_data + batchSize * num_head * qk_gemm_K * rndkvSize; + +// // Create transforms for Key +// auto && brgemm_k_xform = create_or_get_packb_microkernel( +// qk_gemm_K, // K +// block_64, // N +// block_64, // kStrideN, // block_64, // ld_in +// block_64, // ld_out +// u8_dt, // dt_in +// u8_dt, // dt_out +// false // true +// ); +// (*brgemm_k_xform).generate(); +// auto && brgemm_k_xform_tail = create_or_get_packb_microkernel( +// qk_gemm_K, +// block_64, +// block_64, // kStrideN, // block_64, +// block_64, +// u8_dt, +// u8_dt, +// false // true +// ); +// (*brgemm_k_xform_tail).generate(); +// std::shared_ptr brgemm_k_xform_tail_tail; +// if (kvTail % block_64 != 0) { +// brgemm_k_xform_tail_tail = create_or_get_packb_microkernel( +// qk_gemm_K, +// kv_tail_tail_block_size, +// kv_tail_tail_block_size, // kStrideN, // kv_tail_tail_block_size, +// kv_tail_tail_block_size, +// u8_dt, +// u8_dt, +// false // true +// ); +// (*brgemm_k_xform_tail_tail).generate(); +// } + +// // Create transforms for Value +// auto && brgemm_v_xform = create_or_get_packb_microkernel( +// av_gemm_K, +// block_64, +// vStrideN, // block_64, +// block_64, +// u8_dt, +// u8_dt, +// false +// ); +// (*brgemm_v_xform).generate(); +// auto && brgemm_v_xform_tail = create_or_get_packb_microkernel( +// av_gemm_K_tail, +// block_64, +// vStrideN, // block_64, +// block_64, +// u8_dt, +// u8_dt, +// false +// ); +// 
(*brgemm_v_xform_tail).generate(); + + // sum k + if (q_zp != 0) { + at::parallel_for( + 0, batchSize * num_head * kvSize, 1, [&](int64_t begin, int64_t end) { + int64_t i = 0, j = 0, k = 0; + at::native::data_index_init( + begin, i, batchSize, j, num_head, k, kvSize); + for (const auto z : c10::irange(begin, end)) { + (void)z; // Suppress unused variable + int32_t* k_sum_ptr = k_sum_buf_data + + i * num_head * kvSize + + j * kvSize + k; + _int_sum_b_contiguous_kernel_helper( + k_data + i * kStrideB + j * kStrideH + k * kStrideN, + k_sum_ptr, + headSize, q_zp); + // Move to the next query + at::native::data_index_step(i, batchSize, j, num_head, k, kvSize); + } + }); + } + + // sum v + if (a_zp != 0) { + at::parallel_for( + 0, batchSize * num_head, 1, [&](int64_t begin, int64_t end) { + int64_t i = 0, j = 0; + at::native::data_index_init( + begin, i, batchSize, j, num_head); + for (const auto z : c10::irange(begin, end)) { + (void)z; // Suppress unused variable + int32_t* v_sum_ptr = v_sum_buf_data + + i * num_head * headSize + + j * headSize; + _int_sum_a_contiguous_kernel(v_data + i * vStrideB + j * vStrideH, + v_sum_ptr, + headSize, kvSize, vStrideN, a_zp); + // Move to the next query + at::native::data_index_step(i, batchSize, j, num_head); + } + }); + } + + at::parallel_for( + 0, batchSize * num_head, 1, [&](int64_t begin, int64_t end) { + int64_t i = 0, j = 0; + at::native::data_index_init( + begin, i, batchSize, j, num_head); + int ompIdx = at::get_thread_num(); + scalar_t* total_buf_ptr = total_buf_data + ompIdx * total_size_uint8_per_thread; + int32_t offset = 0; + accum_t* qk_data = reinterpret_cast(total_buf_ptr); + offset += kvSlice * qSplitSize * rndkvSplitSize * 4; + accum_t* qk_local_data = reinterpret_cast(total_buf_ptr + offset); + offset += kvSlice * av_gemm_K * 4; + scalar_t* qk_reduced_data = reinterpret_cast(total_buf_ptr + offset); + offset += kvSlice * qSplitSize * av_gemm_K; + int32_t* qk_s32_data = reinterpret_cast(total_buf_ptr + offset); + offset += qSplitSize * rndkvSplitSize * 4; + int32_t* dst_s32_data = reinterpret_cast(total_buf_ptr + offset); + offset += qSplitSize * rndHeadSize * 4; + accum_t* sfm_sum_ptr = reinterpret_cast(total_buf_ptr + offset); + offset += qSplitSize * 4; + int32_t* q_sum_ptr = reinterpret_cast(total_buf_ptr + offset); + offset += qSplitSize * 4; + int32_t* a_sum_ptr = reinterpret_cast(total_buf_ptr + offset); + offset += qSplitSize * 4; + accum_t* sfm_max_ptr = reinterpret_cast(total_buf_ptr + offset); + offset += qSplitSize * 4; + scalar_t* query_t_padding_ptr = reinterpret_cast(total_buf_ptr + offset); + offset += qSplitSize * qk_gemm_K; + // scalar_t* scratchpad_gemm = reinterpret_cast(total_buf_ptr + offset); + // offset += scratchpad_size; + + scalar_t* key_reorder_ptr = reinterpret_cast(total_buf_ptr + offset); + offset += qk_gemm_K * rndkvSize; + scalar_t* value_reorder_ptr = reinterpret_cast(total_buf_ptr + offset); + + uint8_t* B_blocked_xform_u8 = new uint8_t[std::max(qk_gemm_K, av_gemm_K) * block_64]; + + for (const auto z : c10::irange(begin, end)) { + (void)z; // Suppress unused variable + + // pack + for (int64_t n = 0; n < kvSize; n += kvSplitSize) { + // long ss, ee; + if (n + kvSplitSize < kvSize) { + for (int64_t b = 0; b < rndkvSplitSize; b += block_64) { + bool tail = kvSplitSize - b < block_64; + do_transpose( + // do_copy( + k_data + i * kStrideB + j * kStrideH + n * kStrideN + b * kStrideN, + B_blocked_xform_u8, + tail ? 
kvSplitSize - b : block_64, + headSize, + kStrideN, + block_64); + if (!headSize_mul4) { + pad_remain_row_col( + B_blocked_xform_u8, + headSize, + block_64, + qk_gemm_K, + block_64, + block_64 + ); + } + // Pack + // (*brgemm_k_xform).execute( + // // k_data + i * kStrideB + j * kStrideH + n * kStrideN + b * kStrideN, + // B_blocked_xform_u8, + // key_reorder_ptr + n * qk_gemm_K + + // b * qk_gemm_K + // ); + at::native::cpublas::pack( + qk_gemm_K, // K + block_64, // N + block_64, // ld_in + block_64, // ld_out + u8_dt, // dt_in + u8_dt, // dt_out + B_blocked_xform_u8, + key_reorder_ptr + n * qk_gemm_K + + b * qk_gemm_K); + } + // split headSize to block_64, block_64, block_64 ... + // [kvSplitSize, headSize] -> [av_gemm_K, block_64 ...] + for (int64_t b = 0; b < rndHeadSize; b += block_64) { + // (*brgemm_v_xform).execute( + // v_data + i * vStrideB + j * vStrideH + n * vStrideN + b, + // // B_blocked_xform_u8, + // value_reorder_ptr + n * rndHeadSize + + // av_gemm_K * b); + at::native::cpublas::pack( + av_gemm_K, + block_64, + vStrideN, // block_64, + block_64, + u8_dt, + u8_dt, + v_data + i * vStrideB + j * vStrideH + n * vStrideN + b, + value_reorder_ptr + n * rndHeadSize + + av_gemm_K * b); + } + } else { + // tail + auto block_size = kvTail >= block_64 ? block_64 : kv_tail_tail_block_size; + int64_t b = 0; + while (b < rndkvTail) { + bool tail = kvTail - b < block_size; + do_transpose( + k_data + i * kStrideB + j * kStrideH + n * kStrideN + b * kStrideN, + B_blocked_xform_u8, + tail ? kvTail - b : block_size, + headSize, + kStrideN, + block_size); + if (!headSize_mul4) { + pad_remain_row_col( + B_blocked_xform_u8, + headSize, + block_size, + qk_gemm_K, + block_size, + block_size + ); + } + // Pack + if (block_size == block_64) { + // (*brgemm_k_xform_tail).execute( + // // k_data + i * kStrideB + j * kStrideH + n * kStrideN + b * kStrideN, + // B_blocked_xform_u8, + // key_reorder_ptr + n * qk_gemm_K + + // b * qk_gemm_K + // ); + at::native::cpublas::pack( + qk_gemm_K, + block_64, + block_64, // kStrideN, // block_64, + block_64, + u8_dt, + u8_dt, + B_blocked_xform_u8, + key_reorder_ptr + n * qk_gemm_K + + b * qk_gemm_K); + } else { + // (*brgemm_k_xform_tail_tail).execute( + // // k_data + i * kStrideB + j * kStrideH + n * kStrideN + b * kStrideN, + // B_blocked_xform_u8, + // key_reorder_ptr + n * qk_gemm_K + + // b * qk_gemm_K + // ); + at::native::cpublas::pack( + qk_gemm_K, + kv_tail_tail_block_size, + kv_tail_tail_block_size, // kStrideN, // kv_tail_tail_block_size, + kv_tail_tail_block_size, + u8_dt, + u8_dt, + B_blocked_xform_u8, + key_reorder_ptr + n * qk_gemm_K + + b * qk_gemm_K); + } + b += block_size; + block_size = (kvTail - b) >= block_64 ? block_64 : kv_tail_tail_block_size; + } + // split headSize to block_64, block_64, block_64 ... + // [kvTail, headSize] -> [av_gemm_K_tail, block_64 ...] 
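+          // Pack each block_64-wide slice of the V tail with cpublas::pack so
+          // that the later Softmax(q @ k.T) @ v brgemm reads a contiguous layout.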
+ for (int64_t b = 0; b < headSize; b += block_64) { + // (*brgemm_v_xform).execute( + // v_data + i * vStrideB + j * vStrideH + n * vStrideN + b, + // // B_blocked_xform_u8, + // value_reorder_ptr + n * rndHeadSize + + // av_gemm_K * b); + at::native::cpublas::pack( + av_gemm_K, + block_64, + vStrideN, // block_64, + block_64, + u8_dt, + u8_dt, + v_data + i * vStrideB + j * vStrideH + n * vStrideN + b, + value_reorder_ptr + n * rndHeadSize + + av_gemm_K * b); + } + } + } + + // sdpa core + int32_t* k_sum_ptr = k_sum_buf_data + i * num_head * kvSize + j * kvSize; + int32_t* v_sum_ptr = v_sum_buf_data + i * num_head * headSize + j * headSize; + for (int64_t k = 0; k < qSlice; k++) { + int64_t m = k * qSplitSize; + int64_t qBlockSize = std::min(qSplitSize, qSize - m); + // Initialize sum and max + fill_stub( + sfm_sum_ptr, static_cast(0), qSplitSize); + fill_stub( + a_sum_ptr, static_cast(0), qSplitSize); + fill_stub( + sfm_max_ptr, static_cast(-std::numeric_limits::infinity()), qSplitSize); + int64_t num_keys = + is_causal ? std::min(m + qBlockSize, kvSize) : kvSize; + copy_value_with_pad( + q_data + i * qStrideB + j * qStrideH + m * qStrideM, + query_t_padding_ptr, + qBlockSize, + headSize, + qBlockSize, + qk_gemm_K, + qStrideM); + + if (k_zp == 0) { + _int_sum_b_contiguous_kernel(q_data + i * qStrideB + j * qStrideH + m * qStrideM, + q_sum_ptr, qBlockSize, headSize, qStrideM, k_zp); + } else { + fill_stub( + q_sum_ptr, static_cast(0), qSplitSize); + } + const int64_t rkvSlice = (num_keys - 1) / kvSplitSize + 1; + for (int64_t l = 0; l < rkvSlice; l++) { + int64_t n = l * kvSplitSize; + int64_t kvBlockSize = std::min(kvSplitSize, kvSize - n); + // Calculate sums for dequant compensation item + if (qBlockSize == qSplitSize) { + // q main + if (n + kvSplitSize < kvSize) { + // k main + for (int64_t b = 0; b < kvSplitSize; b += block_64) { + // dnnl::ukernel::brgemm::release_hw_context(); + // (*qk_gemm).set_hw_context(); + // (*qk_gemm).execute( + // query_t_padding_ptr, + // key_reorder_ptr + n * qk_gemm_K + + // b * qk_gemm_K, + // A_B_offsets, + // qk_s32_data + b, + // scratchpad_gemm); + at::native::cpublas::brgemm( + qSplitSize, block_64, qk_gemm_K, + 1, //batch_size + qk_gemm_K, // lda + block_64, //ldb + rndkvSplitSize, //ldc, + false, + query_t_padding_ptr, + key_reorder_ptr + n * qk_gemm_K + + b * qk_gemm_K, + qk_s32_data + b, + A_B_offsets); + } + } else { + auto block_size = kvTail >= block_64 ? 
block_64 : kv_tail_tail_block_size; + int64_t b = 0; + while (b < kvTail) { + if (block_size == block_64) { + // dnnl::ukernel::brgemm::release_hw_context(); + // (*qk_gemm_ktail).set_hw_context(); + // (*qk_gemm_ktail).execute( + // query_t_padding_ptr, + // key_reorder_ptr + n * qk_gemm_K + + // b * qk_gemm_K, + // A_B_offsets, + // qk_s32_data + b, + // scratchpad_gemm); + at::native::cpublas::brgemm( + qSplitSize, block_64, qk_gemm_K, + 1, //batch_size + qk_gemm_K, // lda + block_64, //ldb + rndkvTail, //ldc + false, + query_t_padding_ptr, + key_reorder_ptr + n * qk_gemm_K + + b * qk_gemm_K, + qk_s32_data + b, + A_B_offsets); + } else { + // dnnl::ukernel::brgemm::release_hw_context(); + // (*qk_gemm_ktail_tail).set_hw_context(); + // (*qk_gemm_ktail_tail).execute( + // query_t_padding_ptr, + // key_reorder_ptr + n * qk_gemm_K + + // b * qk_gemm_K, + // A_B_offsets, + // qk_s32_data + b, + // scratchpad_gemm); + at::native::cpublas::brgemm( + qSplitSize, kv_tail_tail_block_size, qk_gemm_K, + 1, //batch_size + qk_gemm_K, // lda + kv_tail_tail_block_size, //ldb + rndkvTail, //ldc + false, + query_t_padding_ptr, + key_reorder_ptr + n * qk_gemm_K + + b * qk_gemm_K, + qk_s32_data + b, + A_B_offsets); + } + b += block_size; + block_size = (kvTail - b) >= block_64 ? block_64 : kv_tail_tail_block_size; + } + } + } else { + if (n + kvSplitSize < kvSize) { + for (int64_t b = 0; b < kvSplitSize; b += block_64) { + // dnnl::ukernel::brgemm::release_hw_context(); + // (*qk_gemm_qtail).set_hw_context(); + // (*qk_gemm_qtail).execute( + // query_t_padding_ptr, + // key_reorder_ptr + n * qk_gemm_K + + // b * qk_gemm_K, + // A_B_offsets, + // qk_s32_data + b, + // scratchpad_gemm); + at::native::cpublas::brgemm( + qTail, block_64, qk_gemm_K, + 1, //batch_size + qk_gemm_K,//headSize_mul4 ? qStrideM : qk_gemm_K, // lda + block_64, //ldb + rndkvSplitSize, //ldc + false, + query_t_padding_ptr, + key_reorder_ptr + n * qk_gemm_K + + b * qk_gemm_K, + qk_s32_data + b, + A_B_offsets); + } + } else { + // k tail + auto block_size = kvTail >= block_64 ? block_64 : kv_tail_tail_block_size; + int64_t b = 0; + while (b < kvTail) { + if (block_size == block_64) { + // dnnl::ukernel::brgemm::release_hw_context(); + // (*qk_gemm_qktail).set_hw_context(); + // (*qk_gemm_qktail).execute( + // query_t_padding_ptr, + // key_reorder_ptr + n * qk_gemm_K + + // b * qk_gemm_K, + // A_B_offsets, + // qk_s32_data + b, + // scratchpad_gemm); + at::native::cpublas::brgemm( + qTail, block_64, qk_gemm_K, + 1, //batch_size + qk_gemm_K, // lda + block_64, //ldb + rndkvTail, //ldc + false, + query_t_padding_ptr, + key_reorder_ptr + n * qk_gemm_K + + b * qk_gemm_K, + qk_s32_data + b, + A_B_offsets); + } else { + // dnnl::ukernel::brgemm::release_hw_context(); + // (*qk_gemm_qktail_tail).set_hw_context(); + // (*qk_gemm_qktail_tail).execute( + // query_t_padding_ptr, + // key_reorder_ptr + n * qk_gemm_K + + // b * qk_gemm_K, + // A_B_offsets, + // qk_s32_data + b, + // scratchpad_gemm); + at::native::cpublas::brgemm( + qSplitSize, kv_tail_tail_block_size, qk_gemm_K, + 1, //batch_size + qk_gemm_K, // lda + kv_tail_tail_block_size, //ldb + rndkvTail, //ldc + false, + query_t_padding_ptr, + key_reorder_ptr + n * qk_gemm_K + + b * qk_gemm_K, + qk_s32_data + b, + A_B_offsets); + } + b += block_size; + block_size = (kvTail - b) >= block_64 ? block_64 : kv_tail_tail_block_size; + } + } + } + + // do dequant compensation, add mask, max reduce for softmax, and convert qk from s32 to fp32 + int64_t rndkvBlockSize = kvBlockSize == kvSplitSize ? 
rndkvSplitSize : rndkvTail; + accum_t* qk_block_data = qk_data + l * qSplitSize * rndkvSplitSize; + if (has_attn_mask) { + mask_t* mask_data_offset = mask_data + i * mStrideB + j * mStrideH + m * mStrideM + (mStrideN == 0 ? 0 : n); + _dequant_mask_max_fusion_kernel( + qk_s32_data, //in + mask_data_offset, //mask_ptr + q_sum_ptr, //sum_a_ptr + k_sum_ptr + n, //sum_b_ptr + qBlockSize, //M + kvBlockSize, //N + rndkvBlockSize, //ldi + mStrideM, //ldm + rndkvSplitSize,//kvBlockSize, //ldo + q_zp * k_zp * headSize, //zp_a*zp_b*k=beta + q_scale * k_scale * scaling_factor, //scale_a*scale_b*scale_sdpa=alpha + qk_block_data, //out + sfm_max_ptr // sfm_max_ptr + ); + } else { + _dequant_max_fusion_kernel( + qk_s32_data, //in + q_sum_ptr, //sum_a_ptr + k_sum_ptr + n, //sum_b_ptr + qBlockSize, //M + kvBlockSize, //N + rndkvBlockSize, //ldi + rndkvSplitSize,//kvBlockSize, //ldo + q_zp * k_zp * headSize, //zp_a*zp_b*k=beta + q_scale * k_scale * scaling_factor, //scale_a*scale_b*scale_sdpa=alpha + qk_block_data, //out + sfm_max_ptr // sfm_max_ptr + ); + } + } + + // sub max, exp, sum reduce, div sum for softmax + // and quant + // and sum for attention + if (v_zp == 0) { + _sub_exp_sum_div_quant_fusion_kernel( + qk_data, //in + qBlockSize, //M + kvSplitSize, //N_step + rkvSlice, //NSlices + qSplitSize * rndkvSplitSize, //ldi + qSplitSize * av_gemm_K, //ldo + kvSize, //kvSize + rndkvSplitSize, //rndkvSplitSize + av_gemm_K, //av_gemm_K + a_zp, // zp_a=beta1 + a_scale, // scale_a=alpha + qk_local_data, //local + qk_reduced_data, //out + sfm_max_ptr, //sfm_max_ptr + sfm_sum_ptr //sfm_sum_ptr + ); + } else { + _sub_exp_sum_div_quant_sum_fusion_kernel( + qk_data, //in + qBlockSize, //M + kvSplitSize, //N_step + rkvSlice, //NSlice + qSplitSize * rndkvSplitSize, //ldi + qSplitSize * av_gemm_K, //ldo + kvSize, //kvSize + rndkvSplitSize, //rndkvSplitSize + av_gemm_K, //av_gemm_K + a_zp, // zp_a=beta1 + v_zp, // zp_b=beta2 + a_scale, // scale_a=alpha + qk_local_data, //local + qk_reduced_data, //out + sfm_max_ptr, //sfm_max_ptr + sfm_sum_ptr, //sfm_sum_ptr + a_sum_ptr //a_sum_ptr + ); + } + + // Calculate Softmax(q @ k.T) @ v + for (int64_t b = 0; b < headSize; b += block_64) { + // dnnl::ukernel::brgemm::release_hw_context(); + // (*av_gemm_batch).set_hw_context(); + // (*av_gemm_batch).execute( + // qk_reduced_data, + // value_reorder_ptr + b * av_gemm_K, + // A_B_offsets_batch, + // dst_s32_data + b, + // scratchpad_gemm); + at::native::cpublas::brgemm( + qSplitSize, block_64, av_gemm_K, + kvSlice, //batch_size + av_gemm_K, // lda + rndHeadSize, //block_64, //ldb + rndHeadSize, //ldc + false, + qk_reduced_data, + value_reorder_ptr + b * av_gemm_K, + dst_s32_data + b, + A_B_offsets_batch); + } + + // After the last gemm, + // do dequant compensation, quant and convert from s32 to int8 + _dequant_quant_fusion_kernel( + dst_s32_data, //in + a_sum_ptr, //sum_a_ptr + v_sum_ptr, //sum_b_ptr + qBlockSize, //M + headSize, //N + rndHeadSize, //ldi + oStrideM, //ldo + a_zp * v_zp * kvSize, //zp_a*zp_b*k=beta1 + o_zp, //zp_c=beta2 + a_scale * v_scale / o_scale, //scale_a*scale_b/scale_c=alpha + out_data + i * oStrideB + j * oStrideH + m * oStrideM //out + ); + } + // Move to the next query + at::native::data_index_step(i, batchSize, j, num_head); + } + }); + // Once all computations are done, need to release HW context. + // brgemm::release_hw_context(); + at::native::cpublas::brgemm_release(); +} + +#define AT_DISPATCH_MASK_TYPES(TYPE, NAME, ...) 
\ + AT_DISPATCH_SWITCH( \ + TYPE, \ + NAME, \ + AT_PRIVATE_CASE_TYPE_USING_HINT( \ + at::ScalarType::Bool, mask_t, __VA_ARGS__) \ + AT_PRIVATE_CASE_TYPE_USING_HINT( \ + at::ScalarType::Float, mask_t, __VA_ARGS__) \ + AT_PRIVATE_CASE_TYPE_USING_HINT( \ + at::ScalarType::Double, mask_t, __VA_ARGS__) \ + AT_PRIVATE_CASE_TYPE_USING_HINT( \ + at::ScalarType::BFloat16, mask_t, __VA_ARGS__) \ + AT_PRIVATE_CASE_TYPE_USING_HINT( \ + at::ScalarType::Half, mask_t, __VA_ARGS__)) + +void sdpa_int8_kernel( + const at::Tensor& output, + const at::Tensor& query, + const at::Tensor& key, + const at::Tensor& value, + double dropout_p, + bool is_causal, + at::Tensor& attn_mask, + double scale, + long q_zp, + double q_scale, + long k_zp, + double k_scale, + long v_zp, + double v_scale, + long a_zp, + double a_scale, + long o_zp, + double o_scale) { + int64_t batchSize = query.size(0); + int64_t num_head = query.size(1); + int64_t q_seq_len = query.size(2); + + TORCH_CHECK(query.scalar_type() == c10::kByte); + if (!attn_mask.defined()) { + if (q_seq_len >= 768) { + sdpa_int8_kernel_impl( + output, query, key, value, + dropout_p, is_causal, attn_mask, scale, + q_zp, q_scale, + k_zp, k_scale, + v_zp, v_scale, + a_zp, a_scale, + o_zp, o_scale); + } else if (q_seq_len >= 192) { + sdpa_int8_kernel_impl( + output, query, key, value, + dropout_p, is_causal, attn_mask, scale, + q_zp, q_scale, + k_zp, k_scale, + v_zp, v_scale, + a_zp, a_scale, + o_zp, o_scale); + } else { + sdpa_int8_kernel_impl( + output, query, key, value, + dropout_p, is_causal, attn_mask, scale, + q_zp, q_scale, + k_zp, k_scale, + v_zp, v_scale, + a_zp, a_scale, + o_zp, o_scale); + } + } else { + AT_DISPATCH_MASK_TYPES(attn_mask.scalar_type(), "sdpa_mask", [&]() { + if (q_seq_len >= 768) { + sdpa_int8_kernel_impl( + output, query, key, value, + dropout_p, is_causal, attn_mask, scale, + q_zp, q_scale, + k_zp, k_scale, + v_zp, v_scale, + a_zp, a_scale, + o_zp, o_scale); + } else if (q_seq_len >= 192) { + sdpa_int8_kernel_impl( + output, query, key, value, + dropout_p, is_causal, attn_mask, scale, + q_zp, q_scale, + k_zp, k_scale, + v_zp, v_scale, + a_zp, a_scale, + o_zp, o_scale); + } else { + sdpa_int8_kernel_impl( + output, query, key, value, + dropout_p, is_causal, attn_mask, scale, + q_zp, q_scale, + k_zp, k_scale, + v_zp, v_scale, + a_zp, a_scale, + o_zp, o_scale); + } + }); + } +} + +// at::Tensor sdpa_int8_math_impl( +// const at::Tensor& query_, +// const at::Tensor& key, +// const at::Tensor& value, +// double dropout_p, +// bool is_causal, +// at::Tensor& attn_mask_, +// double scale, +// int32_t q_zp, +// float q_scale, +// int32_t k_zp, +// float k_scale, +// int32_t v_zp, +// float v_scale, +// int32_t a_zp, +// float a_scale, +// int32_t o_zp, +// float o_scale) { +// // dequant q/k/v +// auto q = (query_.to(at::kFloat) - q_zp) * q_scale; +// auto k = (key.to(at::kFloat) - k_zp) * k_scale; +// auto v = (value.to(at::kFloat) - v_zp) * v_scale; +// auto attn_mask = attn_mask_; +// if (attn_mask.defined()) { +// *attn_mask = (*attn_mask).to(at::kFloat); +// } +// // Scale q, k before matmul for stability see https://tinyurl.com/sudb9s96 for math +// bool is_negative_scaling = scale.defined() && scale < 0.0; +// const auto scaling_factor = sdp::calculate_scale(q, is_negative_scaling ? std::abs(scale) : scale).sqrt(); +// q = q * (is_negative_scaling ? 
c10::SymFloat(0.0) - scaling_factor: scaling_factor); +// auto attn = at::matmul(q, k.transpose(-2, -1) * scaling_factor); +// if (attn_mask.defined()) { +// if (at::areAnyTensorSubclassLike({attn, *attn_mask})) { +// attn = attn.add(*attn_mask); +// } else { +// attn.add_(*attn_mask); +// } +// } +// attn = at::softmax(attn, -1); +// // quant attn +// attn = at::clamp_max( +// at::clamp_min(at::round(attn / a_scale) + a_zp, 0), 255 +// ); +// // dequant attn +// attn = (attn - a_zp) * a_scale; +// auto output = at::matmul(attn, v); +// // quant output +// output = at::clamp_max( +// at::clamp_min(at::round(output / o_scale) + o_zp, 0), 255 +// ).to(at::kByte); +// return output; +// } + + +at::Tensor _scaled_dot_product_int8_cpu( + const at::Tensor& query, + const at::Tensor& key, + const at::Tensor& value, + at::Tensor& attn_mask, + // const std::optional& attn_mask, + double dropout_p, + bool is_causal, + double scale, + // std::optional scale, + int64_t q_zp, + double q_scale, + int64_t k_zp, + double k_scale, + int64_t v_zp, + double v_scale, + int64_t a_zp, + double a_scale, + int64_t o_zp, + double o_scale) { + const auto dtype = query.scalar_type(); + TORCH_CHECK(!query.is_nested() && !key.is_nested() && !value.is_nested(), + "_scaled_dot_product_int8_cpu: Only accept plain inputs"); + TORCH_CHECK(!is_causal, + "_scaled_dot_product_int8_cpu: is_causal not supported."); + TORCH_CHECK(dtype == at::ScalarType::Byte, + "_scaled_dot_product_int8_cpu: Expected data type be U8, but got ", dtype, " instead."); + TORCH_CHECK(query.dim() == 4 && key.dim() == 4 && value.dim() == 4, + "_scaled_dot_product_int8_cpu: Accept only 4 dims inputs shape of {B, H, T, K}"); + TORCH_CHECK(dropout_p == 0.0, + "_scaled_dot_product_int8_cpu: Currently do not support dropout > 0"); + TORCH_CHECK((query.size(3) == value.size(3)) && (key.size(3) == value.size(3)), + "_scaled_dot_product_int8_cpu: Q/K/V should have the same head size"); + TORCH_CHECK(!attn_mask.defined() || + attn_mask.scalar_type() == at::kFloat || + attn_mask.scalar_type() == at::kBFloat16, + "_scaled_dot_product_int8_cpu: Expected attention mask be float or bf16"); + TORCH_CHECK(!attn_mask.defined() || + (attn_mask.dim() == 2 || attn_mask.dim() == 4), + "_scaled_dot_product_int8_cpu: Attention mask dim in {2, 4}"); + + // fallback math path + // at::Tensor output = sdpa_int8_math_impl(query, key, value, + // dropout_p, is_causal, attn_mask, scale, + // q_zp, q_scale, + // k_zp, k_scale, + // v_zp, v_scale, + // a_zp, a_scale, + // o_zp, o_scale); + + // fused sdpa int8 impl + at::Tensor output = at::empty_like(query, query.options()).transpose(1, 2); + sdpa_int8_kernel(output, query, key, value, + dropout_p, is_causal, attn_mask, scale, + q_zp, q_scale, + k_zp, k_scale, + v_zp, v_scale, + a_zp, a_scale, + o_zp, o_scale); + + return output.transpose(1, 2); +} + + +} // anonymous namespace + +TORCH_LIBRARY_IMPL(torchao, CPU, m) { + m.impl("torchao::scaled_dot_product_int8", &_scaled_dot_product_int8_cpu); +} + +// } // at::native +} // namespace torchao diff --git a/torchao/csrc/cpu/toy.cpp b/torchao/csrc/cpu/toy.cpp new file mode 100644 index 0000000000..a835aae661 --- /dev/null +++ b/torchao/csrc/cpu/toy.cpp @@ -0,0 +1,20 @@ +#include +#include +#include + +namespace torchao { + +torch::Tensor toy_op2_cpu( + torch::Tensor _in_feats) +{ + std::cout<<"---- run into cpu 2 ----"< Tensor") lib.define("marlin_24_gemm(Tensor x, Tensor weight_marlin, Tensor meta, Tensor s, Tensor workspace, int bits, int size_m, int size_n, int size_k) -> Tensor") 
lib.define("marlin_qqq_gemm(Tensor x, Tensor weight_marlin, Tensor s_tok, Tensor s_ch, Tensor s_group, Tensor workspace, int size_m, int size_n, int size_k) -> Tensor") +lib.define("scaled_dot_product_int8(Tensor query, Tensor key, Tensor value, Tensor attn_mask=None, float dropout_p=0.0, bool is_causal=False, float scale=1.0, int q_zp=0, float q_scale=1.0, int k_zp=0, float k_scale=1.0, int v_zp=0, float v_scale=1.0, int a_zp=0, float a_scale=1.0, int o_zp=0, float o_scale=1.0) -> Tensor") def register_custom_op(name): @@ -71,6 +72,56 @@ def _( return _in_feats.new_empty((BS, OC)) +def scaled_dot_product_int8( + query: Tensor, + key: Tensor, + value: Tensor, + attn_mask: Tensor = None, + dropout_p: float = 0.0, + is_causal: bool = False, + scale: float = 1.0, + q_zp: int = 0, + q_scale: float = 1.0, + k_zp: int = 0, + k_scale: float = 1.0, + v_zp: int = 0, + v_scale: float = 1.0, + a_zp: int = 0, + a_scale: float = 1.0, + o_zp: int = 0, + o_scale: float = 1.0, +) -> Tensor: + return torch.ops.torchao.scaled_dot_product_int8.default(query, key, value, + attn_mask, dropout_p, is_causal, scale, + q_zp, q_scale, + k_zp, k_scale, + v_zp, v_scale, + a_zp, a_scale, + o_zp, o_scale) + + +@register_custom_op("torchao::scaled_dot_product_int8") +def _( + query: Tensor, + key: Tensor, + value: Tensor, + attn_mask: Tensor = None, + dropout_p: float = 0.0, + is_causal: bool = False, + scale: float = 1.0, + q_zp: int = 0, + q_scale: float = 1.0, + k_zp: int = 0, + k_scale: float = 1.0, + v_zp: int = 0, + v_scale: float = 1.0, + a_zp: int = 0, + a_scale: float = 1.0, + o_zp: int = 0, + o_scale: float = 1.0, +) -> Tensor: + return query + def unpack_tensor_core_tiled_layout(packed_w: Tensor, inner_k_tiles: int) -> Tensor: """ @@ -418,3 +469,55 @@ def _( ) return torch.empty((size_m, size_n), dtype=torch.float16, device=x.device) + + + +# def scaled_dot_product_int8( +# query: Tensor, +# key: Tensor, +# value: Tensor, +# attn_mask: Optional[Tensor] = None, +# dropout_p: float = 0.0, +# is_causal: bool = False, +# scale: Optional[float] = None, +# q_zp: int = 0, +# q_scale: float = 1.0, +# k_zp: int = 0, +# k_scale: float = 1.0, +# v_zp: int = 0, +# v_scale: float = 1.0, +# a_zp: int = 0, +# a_scale: float = 1.0, +# o_zp: int = 0, +# o_scale: float = 1.0, +# ) -> Tensor: +# return torch.ops.torchao.scaled_dot_product_int8.default(query, key, value, +# attn_mask, dropout_p, is_causal, scale, +# q_zp, q_scale, +# k_zp, k_scale, +# v_zp, v_scale, +# a_zp, a_scale, +# o_zp, o_scale) + + +# @register_custom_op("torchao::scaled_dot_product_int8") +# def _( +# query: Tensor, +# key: Tensor, +# value: Tensor, +# attn_mask: Optional[Tensor] = None, +# dropout_p: float = 0.0, +# is_causal: bool = False, +# scale: Optional[float] = None, +# q_zp: int = 0, +# q_scale: float = 1.0, +# k_zp: int = 0, +# k_scale: float = 1.0, +# v_zp: int = 0, +# v_scale: float = 1.0, +# a_zp: int = 0, +# a_scale: float = 1.0, +# o_zp: int = 0, +# o_scale: float = 1.0, +# ) -> Tensor: +# return query.new_empty(query.shape) diff --git a/torchao/quantization/__init__.py b/torchao/quantization/__init__.py index 344bdeea41..c4dff31e64 100644 --- a/torchao/quantization/__init__.py +++ b/torchao/quantization/__init__.py @@ -78,6 +78,9 @@ smooth_fq_linear_to_inference, swap_linear_with_smooth_fq_linear, ) +from .sfdp_int8_fx_pass import ( + _sfdp_init_int8, +) from .subclass import * # noqa: F403 from .unified import Quantizer, TwoStepQuantizer from .utils import ( @@ -150,4 +153,5 @@ "WeightOnlyInt8QuantLinear", "TwoStepQuantizer", 
"Quantizer", + "_sfdp_init_int8", ] diff --git a/torchao/quantization/sfdp_int8_fx_pass.py b/torchao/quantization/sfdp_int8_fx_pass.py new file mode 100644 index 0000000000..672db14f6b --- /dev/null +++ b/torchao/quantization/sfdp_int8_fx_pass.py @@ -0,0 +1,733 @@ +import functools +from typing import Callable + +import torch +from torch._inductor import config +from torch._inductor.pattern_matcher import ( + filter_nodes, + fwd_only, + register_replacement, + gen_register_replacement, + PatternMatcherPass, +) +from torch._dynamo.utils import counters +from torch._inductor.fx_passes.fuse_attention import ( + partialize_and_update_signature +) +from torchao.ops import scaled_dot_product_int8 + +__all__ = [ + # "_sfdp_pattern_int8", + # "_sfdp_replacement_int8", + # "_gen_sfdp_patterns_int8", + "_sfdp_init_int8", +] + +aten = torch.ops.aten +# scaled_dot_product_int8 = torch.ops.torchao.scaled_dot_product_int8 +patterns = PatternMatcherPass() + +# def _sfdp_pattern_int8(query, key, value, inv_scale): +# return ( +# torch.matmul(query, key.transpose(-2, -1)) +# .div(inv_scale) +# .softmax(dim=-1) +# .matmul(value) +# ) + +# def _sfdp_replacement_int8(query, key, value, inv_scale): +# print("*** enter _sfdp_replacement in torchao ***") +# counters["inductor"]["fuse_attention_int8"] += 1 +# return torch.nn.functional.scaled_dot_product_attention( +# query, +# key, +# value, +# attn_mask=None, +# dropout_p=0.0, +# is_causal=False, +# scale=1.0 / inv_scale, +# ) + +def _sfdp_pattern_int8_1( + query, + key, + value, + attn_mask, + inv_scale, + q_zp, + q_scale, + k_zp, + k_scale, + v_zp, + v_scale, + a_zp, + a_scale, + o_zp, + o_scale, + dropout, +): + # int8-mix-fp32 QUANTIZED SDPA with mask + q = query.permute([0, 2, 1, 3]) + q = torch.ops.quantized_decomposed.dequantize_per_tensor.default( + q, float(q_scale), int(q_zp), 0, 255, torch.uint8 + ) + k = key.permute([0, 2, 1, 3]).transpose(-2, -1) + k = torch.ops.quantized_decomposed.dequantize_per_tensor.default( + k, float(k_scale), int(k_zp), 0, 255, torch.uint8 + ) + v = value.permute([0, 2, 1, 3]) + v = torch.ops.quantized_decomposed.dequantize_per_tensor.default( + v, float(v_scale), int(v_zp), 0, 255, torch.uint8 + ) + a = torch.nn.functional.dropout( + (torch.matmul(q, k).div(inv_scale) + attn_mask).softmax(dim=-1), + dropout, + ) + qa = torch.ops.quantized_decomposed.quantize_per_tensor.default( + a, float(a_scale), int(a_zp), 0, 255, torch.uint8 + ) + a = torch.ops.quantized_decomposed.dequantize_per_tensor.default( + qa, float(a_scale), int(a_zp), 0, 255, torch.uint8 + ) + o = a.matmul(v) + o = o.permute(0, 2, 1, 3).contiguous() + return torch.ops.quantized_decomposed.quantize_per_tensor.default( + o, float(o_scale), int(o_zp), 0, 255, torch.uint8 + ) + + +def _sfdp_replacement_int8_1( + query, + key, + value, + attn_mask, + inv_scale, + q_zp, + q_scale, + k_zp, + k_scale, + v_zp, + v_scale, + a_zp, + a_scale, + o_zp, + o_scale, + dropout, +): + print("hit _sfdp_replacement_int8_1") + counters["inductor"]["fuse_attention_int8"] += 1 + res = scaled_dot_product_int8( + query.transpose(1, 2), + key.transpose(1, 2), + value.transpose(1, 2), + attn_mask=attn_mask, + dropout_p=dropout, + is_causal=False, + scale=1.0 / inv_scale, + q_zp=q_zp, + q_scale=q_scale, + k_zp=k_zp, + k_scale=k_scale, + v_zp=v_zp, + v_scale=v_scale, + a_zp=a_zp, + a_scale=a_scale, + o_zp=o_zp, + o_scale=o_scale, + ) + return res.permute(0, 2, 1, 3).contiguous() + + +def _sfdp_pattern_int8_2( + query, + key, + value, + attn_mask, + inv_scale, + q_zp, + q_scale, + k_zp, + 
k_scale, + v_zp, + v_scale, + a_zp, + a_scale, + o_zp, + o_scale, + dropout, +): + # int8-mix-reduce QUANTIZED SDPA with mask + q = query.permute([0, 2, 1, 3]) + q = torch.ops.quantized_decomposed.dequantize_per_tensor.default( + q, float(q_scale), int(q_zp), 0, 255, torch.uint8 + ).to(torch.float16) + k = key.permute([0, 2, 1, 3]).transpose(-2, -1) + k = torch.ops.quantized_decomposed.dequantize_per_tensor.default( + k, float(k_scale), int(k_zp), 0, 255, torch.uint8 + ).to(torch.float16) + v = value.permute([0, 2, 1, 3]) + v = torch.ops.quantized_decomposed.dequantize_per_tensor.default( + v, float(v_scale), int(v_zp), 0, 255, torch.uint8 + ).to(torch.float16) + a = torch.nn.functional.dropout( + (torch.matmul(q, k).div(inv_scale) + attn_mask).softmax(dim=-1), + dropout, + ) + qa = torch.ops.quantized_decomposed.quantize_per_tensor.default( + a, float(a_scale), int(a_zp), 0, 255, torch.uint8 + ) + a = torch.ops.quantized_decomposed.dequantize_per_tensor.default( + qa, float(a_scale), int(a_zp), 0, 255, torch.uint8 + ).to(torch.float16) + o = a.matmul(v) + o = o.permute(0, 2, 1, 3).contiguous() + return torch.ops.quantized_decomposed.quantize_per_tensor.default( + o, float(o_scale), int(o_zp), 0, 255, torch.uint8 + ) + + +def _sfdp_replacement_int8_2( + query, + key, + value, + attn_mask, + inv_scale, + q_zp, + q_scale, + k_zp, + k_scale, + v_zp, + v_scale, + a_zp, + a_scale, + o_zp, + o_scale, + dropout, +): + print("hit _sfdp_replacement_int8_2") + counters["inductor"]["fuse_attention_int8"] += 1 + res = scaled_dot_product_int8( + query.transpose(1, 2), + key.transpose(1, 2), + value.transpose(1, 2), + attn_mask=attn_mask, + dropout_p=dropout, + is_causal=False, + scale=1.0 / inv_scale, + q_zp=q_zp, + q_scale=q_scale, + k_zp=k_zp, + k_scale=k_scale, + v_zp=v_zp, + v_scale=v_scale, + a_zp=a_zp, + a_scale=a_scale, + o_zp=o_zp, + o_scale=o_scale, + ) + return res.permute(0, 2, 1, 3).contiguous() + + +def _sfdp_pattern_int8_3( + query, + key, + value, + inv_scale, + q_zp, + q_scale, + k_zp, + k_scale, + v_zp, + v_scale, + a_zp, + a_scale, + o_zp, + o_scale, + dropout, +): + # int8-mix-fp32 QUANTIZED SDPA without mask + q = query.permute([0, 2, 1, 3]) + q = torch.ops.quantized_decomposed.dequantize_per_tensor.default( + q, float(q_scale), int(q_zp), 0, 255, torch.uint8 + ) + k = key.permute([0, 2, 1, 3]).transpose(-2, -1) + k = torch.ops.quantized_decomposed.dequantize_per_tensor.default( + k, float(k_scale), int(k_zp), 0, 255, torch.uint8 + ) + v = value.permute([0, 2, 1, 3]) + v = torch.ops.quantized_decomposed.dequantize_per_tensor.default( + v, float(v_scale), int(v_zp), 0, 255, torch.uint8 + ) + a = torch.nn.functional.dropout( + torch.matmul(q, k).div(inv_scale).softmax(dim=-1), + dropout, + ) + qa = torch.ops.quantized_decomposed.quantize_per_tensor.default( + a, float(a_scale), int(a_zp), 0, 255, torch.uint8 + ) + a = torch.ops.quantized_decomposed.dequantize_per_tensor.default( + qa, float(a_scale), int(a_zp), 0, 255, torch.uint8 + ) + o = a.matmul(v) + o = o.permute(0, 2, 1, 3).contiguous() + return torch.ops.quantized_decomposed.quantize_per_tensor.default( + o, float(o_scale), int(o_zp), 0, 255, torch.uint8 + ) + + +def _sfdp_replacement_int8_3( + query, + key, + value, + inv_scale, + q_zp, + q_scale, + k_zp, + k_scale, + v_zp, + v_scale, + a_zp, + a_scale, + o_zp, + o_scale, + dropout, +): + print("hit _sfdp_replacement_int8_3") + counters["inductor"]["fuse_attention_int8"] += 1 + res = scaled_dot_product_int8( + query.transpose(1, 2), + key.transpose(1, 2), + value.transpose(1, 
2), + dropout_p=dropout, + is_causal=False, + scale=1.0 / inv_scale, + q_zp=q_zp, + q_scale=q_scale, + k_zp=k_zp, + k_scale=k_scale, + v_zp=v_zp, + v_scale=v_scale, + a_zp=a_zp, + a_scale=a_scale, + o_zp=o_zp, + o_scale=o_scale, + ) + return res.permute(0, 2, 1, 3).contiguous() + + +def _sfdp_pattern_int8_4( + query, + key, + value, + inv_scale, + q_zp, + q_scale, + k_zp, + k_scale, + v_zp, + v_scale, + a_zp, + a_scale, + o_zp, + o_scale, + dropout, +): + # int8-mix-reduce QUANTIZED SDPA without mask + q = query.permute([0, 2, 1, 3]) + q = torch.ops.quantized_decomposed.dequantize_per_tensor.default( + q, float(q_scale), int(q_zp), 0, 255, torch.uint8 + ).to(torch.float16) + k = key.permute([0, 2, 1, 3]).transpose(-2, -1) + k = torch.ops.quantized_decomposed.dequantize_per_tensor.default( + k, float(k_scale), int(k_zp), 0, 255, torch.uint8 + ).to(torch.float16) + v = value.permute([0, 2, 1, 3]) + v = torch.ops.quantized_decomposed.dequantize_per_tensor.default( + v, float(v_scale), int(v_zp), 0, 255, torch.uint8 + ).to(torch.float16) + a = torch.nn.functional.dropout( + torch.matmul(q, k).div(inv_scale).softmax(dim=-1), + dropout, + ) + qa = torch.ops.quantized_decomposed.quantize_per_tensor.default( + a, float(a_scale), int(a_zp), 0, 255, torch.uint8 + ) + a = torch.ops.quantized_decomposed.dequantize_per_tensor.default( + qa, float(a_scale), int(a_zp), 0, 255, torch.uint8 + ).to(torch.float16) + o = a.matmul(v) + o = o.permute(0, 2, 1, 3).contiguous() + return torch.ops.quantized_decomposed.quantize_per_tensor.default( + o, float(o_scale), int(o_zp), 0, 255, torch.uint8 + ) + + +def _sfdp_replacement_int8_4( + query, + key, + value, + inv_scale, + q_zp, + q_scale, + k_zp, + k_scale, + v_zp, + v_scale, + a_zp, + a_scale, + o_zp, + o_scale, + dropout, +): + print("hit _sfdp_replacement_int8_4") + counters["inductor"]["fuse_attention_int8"] += 1 + res = scaled_dot_product_int8( + query.transpose(1, 2), + key.transpose(1, 2), + value.transpose(1, 2), + dropout_p=dropout, + is_causal=False, + scale=1.0 / inv_scale, + q_zp=q_zp, + q_scale=q_scale, + k_zp=k_zp, + k_scale=k_scale, + v_zp=v_zp, + v_scale=v_scale, + a_zp=a_zp, + a_scale=a_scale, + o_zp=o_zp, + o_scale=o_scale, + ) + return res.permute(0, 2, 1, 3).contiguous() + + +def _sfdp_params_check_int8(match): + assert all(k in match.kwargs for k in ("query", "key", "value")) + query = match.kwargs["query"].meta["val"] + key = match.kwargs["key"].meta["val"] + value = match.kwargs["value"].meta["val"] + if not (query.dtype == key.dtype == value.dtype) or not ( + query.device == key.device == value.device + ): + return False + add_nodes = filter_nodes(match.nodes, aten.add.Tensor) + # Has attn_mask add. + add_mask_node = [n for n in add_nodes if n.prev.target == torch.ops.aten.div.Tensor] + if len(add_mask_node) > 0: + attn_mask_node = add_mask_node[0].args[1] + # attn_mask_node may be a float/int number. + if not hasattr(attn_mask_node, "meta"): + return False + attn_mask = attn_mask_node.meta["val"] # type: ignore[union-attr] + # Make sure attn_mask.dtype == query.dtype or attn_mask.dtype == torch.bool + # attn_mask.dtype == torch.float for models like albert. 
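+        # Also reject non-Tensor masks and masks on a different device from the query.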
+ if ( + not isinstance(attn_mask, torch.Tensor) + or not ( + attn_mask.dtype == query.dtype + or attn_mask.dtype == torch.bool + or attn_mask.dtype == torch.float + ) + or query.device != attn_mask.device + ): + return False + return True + + +def _sfdp_extra_check_int8(scale_factor_op=None, disable_cuda=False): + def fn(match): + if ( + disable_cuda + and "query" in match.kwargs + and "cuda" in str(match.kwargs["query"].meta["val"].device) + ): + return False + if scale_factor_op is not None: + scale_factor_node = filter_nodes(match.nodes, scale_factor_op)[0] + # Note: args[1] of the scale_factor_node is always the scale_factor for the current patterns. + scale_factor = scale_factor_node.args[1] + # make sure the scale_factor a float/int. SymInt? + if not isinstance(scale_factor, (float, int)): + return False + return _sfdp_params_check_int8(match) + + return fn + + +def _gen_sfdp_patterns_int8(): + if torch.cuda.is_available(): + device = "cuda" + else: + device = "cpu" + g_inp = functools.partial( + torch.empty, (2, 4, 8, 16), device=device, requires_grad=True + ) + m_inp = functools.partial(torch.empty, (2, 1, 1, 4), device=device) # attn_mask + c_inp = functools.partial(torch.tensor, 2.0, device=device) # inv_scale + zp_inp = functools.partial(torch.tensor, 127, device=device) # quant_zero_point + scale_inp = functools.partial(torch.tensor, 0.018, device=device) # quant_scale + + # reshape in matmul decomposition generates a clone when batch_size>1 due to the memory layout change. + # however when batch_size=1, reshape does not change the memory layout, so clone would not be generated. + # here we need to trace with input of batch_size=1 to generate a pattern graph without clone. + g_bs1_inp = functools.partial( + torch.empty, (1, 4, 8, 16), device=device, requires_grad=True + ) + m_bs1_inp = functools.partial(torch.empty, (1, 1, 1, 4), device=device) + for dtype in [torch.float, torch.half]: + # g = functools.partial(g_inp, dtype=dtype) + # c = functools.partial(c_inp, dtype=dtype) + # candidates = [ + # ( + # _sfdp_pattern_int8, + # _sfdp_replacement_int8, + # [g(), g(), g(), c()], + # {}, + # _sfdp_extra_check_int8(aten.div.Tensor), + # ), + # ] + g_u8 = functools.partial(g_inp, dtype=torch.uint8, requires_grad=False) + g_u8_bs1 = functools.partial(g_bs1_inp, dtype=torch.uint8, requires_grad=False) + m = functools.partial(m_inp, dtype=dtype) + m_bs1 = functools.partial(m_bs1_inp, dtype=dtype) + c = functools.partial(c_inp, dtype=dtype) + zp = functools.partial(zp_inp, dtype=torch.int) + scale = functools.partial(scale_inp, dtype=torch.float) + d_u8 = { + "dropout": 0.113377, + "q_zp": 23, + "q_scale": 0.0111541, + "k_zp": 14, + "k_scale": 0.0256212, + "v_zp": 28, + "v_scale": 0.0164518, + "a_zp": 12, + "a_scale": 0.0572114, + "o_zp": 36, + "o_scale": 0.0235489, + } + int8_candidates = [ + ( + _sfdp_pattern_int8_1, + _sfdp_replacement_int8_1, + [ + g_u8(), + g_u8(), + g_u8(), + m(), + c(), + zp(), + scale(), + zp(), + scale(), + zp(), + scale(), + zp(), + scale(), + zp(), + scale(), + ], + d_u8, + _sfdp_extra_check_int8(aten.div.Tensor), + ), + ( + _sfdp_pattern_int8_1, + _sfdp_replacement_int8_1, + [ + g_u8_bs1(), + g_u8_bs1(), + g_u8_bs1(), + m_bs1(), + c(), + zp(), + scale(), + zp(), + scale(), + zp(), + scale(), + zp(), + scale(), + zp(), + scale(), + ], + d_u8, + _sfdp_extra_check_int8(aten.div.Tensor), + ), + ( + _sfdp_pattern_int8_2, + _sfdp_replacement_int8_2, + [ + g_u8(), + g_u8(), + g_u8(), + m(), + c(), + zp(), + scale(), + zp(), + scale(), + zp(), + scale(), + zp(), + 
scale(), + zp(), + scale(), + ], + d_u8, + _sfdp_extra_check_int8(aten.div.Tensor), + ), + ( + _sfdp_pattern_int8_2, + _sfdp_replacement_int8_2, + [ + g_u8_bs1(), + g_u8_bs1(), + g_u8_bs1(), + m_bs1(), + c(), + zp(), + scale(), + zp(), + scale(), + zp(), + scale(), + zp(), + scale(), + zp(), + scale(), + ], + d_u8, + _sfdp_extra_check_int8(aten.div.Tensor), + ), + ( + _sfdp_pattern_int8_3, + _sfdp_replacement_int8_3, + [ + g_u8(), + g_u8(), + g_u8(), + c(), + zp(), + scale(), + zp(), + scale(), + zp(), + scale(), + zp(), + scale(), + zp(), + scale(), + ], + d_u8, + _sfdp_extra_check_int8(aten.div.Tensor), + ), + ( + _sfdp_pattern_int8_3, + _sfdp_replacement_int8_3, + [ + g_u8_bs1(), + g_u8_bs1(), + g_u8_bs1(), + c(), + zp(), + scale(), + zp(), + scale(), + zp(), + scale(), + zp(), + scale(), + zp(), + scale(), + ], + d_u8, + _sfdp_extra_check_int8(aten.div.Tensor), + ), + ( + _sfdp_pattern_int8_4, + _sfdp_replacement_int8_4, + [ + g_u8(), + g_u8(), + g_u8(), + c(), + zp(), + scale(), + zp(), + scale(), + zp(), + scale(), + zp(), + scale(), + zp(), + scale(), + ], + d_u8, + _sfdp_extra_check_int8(aten.div.Tensor), + ), + ( + _sfdp_pattern_int8_4, + _sfdp_replacement_int8_4, + [ + g_u8_bs1(), + g_u8_bs1(), + g_u8_bs1(), + c(), + zp(), + scale(), + zp(), + scale(), + zp(), + scale(), + zp(), + scale(), + zp(), + scale(), + ], + d_u8, + _sfdp_extra_check_int8(aten.div.Tensor), + ), + ] + for pattern, replacement, args, workaround, extra_check in int8_candidates: + # XXX: when adding a new pattern, re-run `gen_attention_patterns` so the pattern + # gets serialized to a python file and does not require tracing at runtime. + assert isinstance(workaround, dict) + name = pattern.__name__ + + if len(workaround) >= 1: + # if "dropout_p" in workaround: + # # functools.partial insufficient because we look at signature downstream + # pattern = partialize_and_update_signature(pattern, dropout_p=0.0) + # replacement = partialize_and_update_signature( + # replacement, dropout_p=0.0 + # ) + # workaround = {} + # else: + # for uint8 pattern with more workarounds other than dropout, + # we need to rename it to avoid influcing other patterns + pattern = partialize_and_update_signature(pattern, dropout=0.0) + replacement = partialize_and_update_signature( + replacement, dropout=0.0 + ) + if "dropout" in workaround: + del workaround["dropout"] + + inference_name = name + "_inference" + yield inference_name, { + "search_fn": pattern, + "replace_fn": replacement, + "example_inputs": args, + "trace_fn": fwd_only, + "pass_dicts": patterns, + "extra_check": extra_check, + "scalar_workaround": workaround, + } + + +@functools.lru_cache(None) +def _sfdp_init_int8(): + for key, register_replacement_kwargs in _gen_sfdp_patterns_int8(): + register_replacement(**register_replacement_kwargs) + config.joint_custom_pre_pass = patterns.apply + # print("\n\njoint_custom_pre_pass:", config.joint_custom_pre_pass) From c6289e66cddabeb1606a9f0d64aa439cd28ad7d1 Mon Sep 17 00:00:00 2001 From: Valentine233 Date: Tue, 3 Dec 2024 02:45:51 -0500 Subject: [PATCH 2/5] update int8 sdpa --- test/test_ops.py | 1104 ++++++++++++++++++++----------------- torchao/csrc/cpu/sdpa.cpp | 27 - torchao/ops.py | 52 -- 3 files changed, 613 insertions(+), 570 deletions(-) diff --git a/test/test_ops.py b/test/test_ops.py index c5821eed44..aaa8c8946d 100644 --- a/test/test_ops.py +++ b/test/test_ops.py @@ -19,6 +19,7 @@ ) from torchao.quantization.quant_primitives import choose_qparams_and_quantize_affine_qqq import pytest +import math if is_fbcode(): 
pytest.skip("Skipping the test in fbcode since we don't have TARGET file for kernels") @@ -38,502 +39,623 @@ class TestOps(TestCase): - def _create_floatx_inputs(self, ebits: int, mbits: int, BS: int, OC: int, IC: int, device, dtype): - # Randomly initialize each byte - nbits = 1 + ebits + mbits - floatx_weight = torch.randint(256, (OC, IC // 8 * nbits), dtype=torch.uint8) - scale = torch.rand(OC).to(dtype) + 0.5 - fp16_act = torch.rand(BS, IC).to(dtype) + 0.5 - return floatx_weight.to(device), scale.to(device), fp16_act.to(device) - - @pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available") - @parametrize("ebits,mbits", [(3, 2), (2, 2)]) - @parametrize("dtype", [torch.half, torch.bfloat16]) - def test_quant_llm_linear(self, ebits, mbits, dtype): - BS = 2 - OC = 256 - IC = 256 - splitK = 1 - floatx_weight, scale, fp16_act = self._create_floatx_inputs(ebits, mbits, BS, OC, IC, "cuda", dtype) - - # smoke test - torchao.ops.quant_llm_linear(ebits, mbits, fp16_act, floatx_weight, scale, splitK) - - # comprehensive testing - test_utils = ["test_schema", "test_autograd_registration", "test_faketensor", "test_aot_dispatch_dynamic"] - opcheck(torch.ops.torchao.quant_llm_linear, (ebits, mbits, fp16_act, floatx_weight, scale, splitK), test_utils=test_utils) - - @pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available") - @parametrize("BS,OC,IC,splitK", [(1, 2048, 4096, 5), (2, 8192, 8192, 6)]) - @parametrize("ebits,mbits", [(3, 2), (2, 2)]) - @parametrize("dtype", [torch.half, torch.bfloat16]) - def test_quant_llm_linear_correctness(self, ebits, mbits, BS, OC, IC, splitK, dtype): - # adapted from https://github.com/usyd-fsalab/fp6_llm/blob/5df6737cca32f604e957e3f63f03ccc2e4d1df0d/tests/python/kernel_test_fpx.py - floatx_weight, scale, fp16_act = self._create_floatx_inputs(ebits, mbits, BS, OC, IC, "cuda", dtype) - - results_floatx = torchao.ops.quant_llm_linear(ebits, mbits, fp16_act, floatx_weight, scale, splitK) - - fp16_weight = from_scaled_tc_floatx(floatx_weight, ebits, mbits, scale).to(dtype) - results_fp16 = fp16_act @ fp16_weight.T - - error = (results_floatx - results_fp16).abs().mean() - gt = results_fp16.abs().mean() - relative_error = error / gt - rtol = 1e-2 if dtype == torch.bfloat16 else 1e-3 - assert relative_error < rtol - -instantiate_parametrized_tests(TestOps) - - -## Tests for `tensor_core_layout` -kTileSizeN = 8 -kTileSizeK = 16 - -SHAPES = [ - (4096, 4096), - # Llama 2 GEMM shapes - (4096, 11008), - (11008, 4096), - # Llama 3 GEMM shapes - (4096, 14336), - (14336, 4096), -] -INNERKTILES = [2, 4, 8] -QGROUP_SIZES = [32, 64, 128, 256] -TEST_CONFIGS_UNPACK = list(itertools.product(SHAPES, INNERKTILES)) -TEST_CONFIGS_DEQUANT = list(itertools.product(SHAPES, INNERKTILES, QGROUP_SIZES)) - -@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available") -# @pytest.mark.skipif(TORCH_VERSION_AT_LEAST_2_5, reason="weight packing is updated in 2.5+") -@pytest.mark.parametrize("shape, inner_k_tiles", TEST_CONFIGS_UNPACK, ids=str) -def test_unpack_tensor_core_tiled_layout_correctness(shape, inner_k_tiles): - N, K = shape - assert K % (inner_k_tiles * kTileSizeK) == 0 and N % kTileSizeN == 0 - - t = torch.randint(0, 16, dtype=torch.int, size=shape, device="cuda") - if TORCH_VERSION_AT_LEAST_2_5: - t = (t[::, ::2] << 4 | t[::, 1::2]).to(torch.uint8) - packed_w = torch.ops.aten._convert_weight_to_int4pack(t, inner_k_tiles) - unpacked = torchao.ops.unpack_tensor_core_tiled_layout(packed_w, inner_k_tiles) - if 
TORCH_VERSION_AT_LEAST_2_5: - unpacked = (unpacked[::, ::2] << 4 | unpacked[::, 1::2]).to(torch.uint8) - assert torch.equal(t, unpacked) - -# TODO: Fix "test_aot_dispatch_dynamic" test failure -@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available") -# @pytest.mark.skipif(TORCH_VERSION_AT_LEAST_2_5, reason="weight packing is updated in 2.5+") -@pytest.mark.parametrize("shape, inner_k_tiles", TEST_CONFIGS_UNPACK , ids=str) -def test_unpack_tensor_core_tiled_layout_op(shape, inner_k_tiles): - test_utils = [ - "test_schema", - "test_autograd_registration", - "test_faketensor", - ] - - # TODO: Figure out why test fails unless torch >= 2.5 - if TORCH_VERSION_AT_LEAST_2_5: - test_utils.append("test_aot_dispatch_dynamic") - - t = torch.randint(0, 16, dtype=torch.int, size=shape, device="cuda") - if TORCH_VERSION_AT_LEAST_2_5: - t = (t[::, ::2] << 4 | t[::, 1::2]).to(torch.uint8) - packed_w = torch.ops.aten._convert_weight_to_int4pack(t, inner_k_tiles) - - opcheck( - torch.ops.torchao.unpack_tensor_core_tiled_layout, - (packed_w, inner_k_tiles), - test_utils=test_utils, - ) - -def dequant_ref(q, scales, zeros, group_size, nbits=4, dtype=torch.bfloat16): - n, k = q.shape - assert q.dtype == torch.int - - n_groups = k // group_size - assert scales.shape[0] == n and scales.shape[1] == n_groups - assert scales.shape == zeros.shape - - midpoint = 2 ** (nbits - 1) - - #Convert fron u4 -> s4 and upcast to bfloat16 - q = q.sub(midpoint).to(dtype) - - # Dequantize - q = q.reshape(-1, group_size) - dq = q * scales.reshape(-1, 1) + zeros.reshape(-1, 1) - - return dq.reshape(n, k) - - -@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available") -# @pytest.mark.skipif(TORCH_VERSION_AT_LEAST_2_5, reason="weight packing is updated in 2.5+") -@pytest.mark.parametrize("shape, inner_k_tiles, group_size", TEST_CONFIGS_DEQUANT, ids=str) -def test_dequantize_tensor_core_tiled_layout_correctness_quant_dequant(shape, inner_k_tiles, group_size): - n, k = shape - dtype = torch.bfloat16 - - device = "cuda" - - t = torch.randn(n, k, dtype=dtype, device=device) - scales, zeros = get_groupwise_affine_qparams(t, n_bit=4, groupsize=group_size, dtype=dtype) - - # Quantize - q = groupwise_affine_quantize_tensor_from_qparams( - t, scales, zeros, n_bit=4, groupsize=group_size - ) - - # Pack to tensor core layout - packed = torch.ops.aten._convert_weight_to_int4pack(q, inner_k_tiles) - scales_and_zeros = pack_tinygemm_scales_and_zeros(scales, zeros) - q_groups = k // group_size - assert scales_and_zeros.shape == torch.Size([q_groups, n, 2]) - - # Dequantize 'ao' ref - dq_ao = groupwise_affine_dequantize_tensor_from_qparams( - q, scales, zeros, n_bit=4, groupsize=group_size - ) - - # Dequantize by passing in an identity matrix as the activation - a_eye = torch.eye(k, device=device, dtype=dtype) - dq_id = torch.ops.aten._weight_int4pack_mm( - a_eye, - packed, - group_size, - scales_and_zeros, - ).t() - - # Actual operation to test - dq_op = torchao.ops.dequantize_tensor_core_tiled_layout(packed, scales_and_zeros, group_size, inner_k_tiles) - - # Compare results - diff_ao_id = (dq_id - dq_ao).abs().max() - diff_op_id = (dq_op - dq_id).abs().max() - diff_op_ao = (dq_op - dq_ao).abs().max() - - # There are slight numerical differences when dequantizing with an identity matrix when compared to `groupwise_affine_dequantize` - # Since the `dequantize_tensor_core_layout` kernel relies on the same underlying bit twiddling tricks for fast - # conversion from u4 -> s4 -> bf16, the identity matrix 
dequant hack and `dequantize_tensor_core_layout` are - # expected to give same results, while both will have similar numerical differences to `groupwise_affine_dequantize`. - - # Test that the `dequant` kernel gives same results as identity matrix-based dequant - assert diff_op_id == 0 - - # Test that the `dequant` kernel gives same numerical diffs as the `groupwise_affine_dequantize` when compared against the identity matrix - assert diff_op_ao == diff_ao_id - - assert diff_op_ao < 1e-1 - -# This test differs from one above in that it uses `unpack_tensor_core_tiled_layout` to unpack then dequantize -@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available") -# @pytest.mark.skipif(TORCH_VERSION_AT_LEAST_2_5, reason="weight packing is updated in 2.5+") -@pytest.mark.parametrize("shape, inner_k_tiles, group_size", TEST_CONFIGS_DEQUANT, ids=str) -def test_dequantize_tensor_core_tiled_layout_correctness_unpack_and_dequant(shape, inner_k_tiles, group_size): - n, k = shape - dtype = torch.bfloat16 - device = "cuda" - - # Quantize and pack - t = torch.randn(n, k, dtype=dtype, device=device) - scales, zeros = get_groupwise_affine_qparams(t, n_bit=4, groupsize=group_size, dtype=dtype) - q = groupwise_affine_quantize_tensor_from_qparams( - t, scales, zeros, n_bit=4, groupsize=group_size - ) - - packed = torch.ops.aten._convert_weight_to_int4pack(q, inner_k_tiles) - scales_and_zeros = pack_tinygemm_scales_and_zeros(scales, zeros) - - # Unpack and dequantize - unpacked = torchao.ops.unpack_tensor_core_tiled_layout(packed, inner_k_tiles) - if TORCH_VERSION_AT_LEAST_2_5: - unpacked = (unpacked[::, ::2] << 4 | unpacked[::, 1::2]).to(torch.uint8) - - dq_ao = groupwise_affine_dequantize_tensor_from_qparams( - unpacked, scales, zeros, n_bit=4, groupsize=group_size - ) - - # Dequantize by passing in an identity matrix as the activation - a_eye = torch.eye(k, device=device, dtype=dtype) - dq_id = torch.ops.aten._weight_int4pack_mm( - a_eye, - packed, - group_size, - scales_and_zeros, - ).t() - - # Actual operation to test - dq_op = torchao.ops.dequantize_tensor_core_tiled_layout(packed, scales_and_zeros, group_size, inner_k_tiles) - - # Compare results - diff_ao_id = (dq_id - dq_ao).abs().max() - diff_op_id = (dq_op - dq_id).abs().max() - diff_op_ao = (dq_op - dq_ao).abs().max() - - # There are slight numerical differences when dequantizing with an identity matrix when compared to `groupwise_affine_dequantize` - # Since the `dequantize_tensor_core_layout` kernel relies on the same underlying bit twiddling tricks for fast - # conversion from u4 -> s4 -> bf16, the identity matrix dequant hack and `dequantize_tensor_core_layout` are - # expected to give same results, while both will have similar numerical differences to `groupwise_affine_dequantize`. 
- - # Test that the `dequant` kernel gives same results as identity matrix-based dequant - assert diff_op_id == 0 - - # Test that the `dequant` kernel gives same numerical diffs as the `groupwise_affine_dequantize` when compared against the identity matrix - assert diff_op_ao == diff_ao_id - - assert diff_op_ao < 1e-1 - -@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available") -# @pytest.mark.skipif(TORCH_VERSION_AT_LEAST_2_5, reason="weight packing is updated in 2.5+") -@pytest.mark.parametrize("shape, inner_k_tiles, group_size", TEST_CONFIGS_DEQUANT, ids=str) -def test_dequantize_tensor_core_tiled_layout_op(shape, inner_k_tiles, group_size): - n, k = shape - device = "cuda" - - q = torch.randint(0, 16, shape, dtype=torch.int, device=device) - if TORCH_VERSION_AT_LEAST_2_5: - q = (q[::, ::2] << 4 | q[::, 1::2]).to(torch.uint8) - packed_w = torch._convert_weight_to_int4pack(q, inner_k_tiles) - q_groups = k // group_size - scales = torch.randn(n, q_groups, dtype=torch.bfloat16, device=device) - zeros = torch.randn_like(scales) - scales_and_zeros = pack_tinygemm_scales_and_zeros(scales, zeros) - - test_utils = [ - "test_schema", - "test_autograd_registration", - "test_faketensor", - ] - # TODO: Figure out why test fails unless torch >= 2.5 - if TORCH_VERSION_AT_LEAST_2_5: - test_utils.append("test_aot_dispatch_dynamic") - opcheck( - torch.ops.torchao.dequantize_tensor_core_tiled_layout, - (packed_w, scales_and_zeros, group_size, inner_k_tiles), - test_utils=test_utils, - ) - - -MARLIN_24_BATCH_SIZE = [1, 4, 8, 16, 32, 64] -MARLIN_24_K_CHUNKS = [128] -MARLIN_24_N_CHUNKS = [512] -MNK_FACTORS = [ - (1, 1, 1), - (1, 4, 8), - (1, 7, 5), - (13, 17, 67), - (26, 37, 13), - (67, 13, 11), -] -MARLIN_24_SUPPORTED_NUM_BITS = [4, 8] -MARLIN_24_SUPPORTED_GROUP_SIZES = [-1, 128] - -MARLIN_TEST_PARAMS = list(itertools.product( - MARLIN_24_BATCH_SIZE, MARLIN_24_K_CHUNKS, MARLIN_24_N_CHUNKS, - MARLIN_24_SUPPORTED_NUM_BITS, MARLIN_24_SUPPORTED_GROUP_SIZES, MNK_FACTORS -)) - -def _symmetric_quantize_with_ref(w: torch.Tensor, num_bits: int, group_size: int): - orig_device = w.device - size_k, size_n = w.shape - - assert w.is_floating_point(), "w must be float" - - if group_size == -1: - group_size = size_k - assert group_size <= size_k - - max_q_val = 2**num_bits - 1 - half_q_val = (max_q_val + 1) // 2 - - # Reshape to [groupsize, -1] - if group_size < size_k: - w = w.reshape((-1, group_size, size_n)) - w = w.permute(1, 0, 2) - w = w.reshape((group_size, -1)) - - # Compute scale for each group - s = torch.max(torch.abs(w), 0, keepdim=True)[0] - s *= 2 / max_q_val # 2 => symmetric - - # Quantize - q_w = torch.round(w / s).int() - q_w += half_q_val - q_w = torch.clamp(q_w, 0, max_q_val) - - # Compute ref (dequantized) - w_ref = (q_w - half_q_val).half() * s - - # Restore original shapes - if group_size < size_k: - - def reshape_w(w): - w = w.reshape((group_size, -1, size_n)) - w = w.permute(1, 0, 2) - w = w.reshape((size_k, size_n)).contiguous() - return w - - q_w = reshape_w(q_w) - w_ref = reshape_w(w_ref) - - s = s.reshape((-1, size_n)).contiguous() - - return ( - w_ref.to(device=orig_device), - q_w.to(device=orig_device), - s.to(device=orig_device), - ) - -@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available") -@pytest.mark.parametrize("batch_size, k_chunk, n_chunk, num_bits, group_size, mnk_factors", MARLIN_TEST_PARAMS, ids=str) -def test_marlin_24(batch_size, k_chunk, n_chunk, num_bits, group_size, mnk_factors): - m_factor, n_factor, k_factor = mnk_factors - - 
size_m = m_factor - size_k = k_chunk * k_factor - size_n = n_chunk * n_factor - - a_input = torch.randn((batch_size, size_m, size_k), dtype=torch.float16, device="cuda") - b_weight = torch.rand((size_k, size_n), dtype=torch.float16, device="cuda") - - # Inject 2:4 sparsity - w_24, _ = inject_24(b_weight, size_k, size_n) - - # Symmetric quantize - w_24_ref, q_w_24, scale = _symmetric_quantize_with_ref(w_24, num_bits, group_size) - - # Reshape input into 2D tensor - input_2d = a_input.view(-1, a_input.shape[-1]) - a_input_in, a_input_out = input_2d.shape - - # Obtains reference output - output_ref = torch.matmul(input_2d, w_24_ref) - output_ref = output_ref.reshape(a_input.shape[:-1] + (scale.shape[1],)) - - # Packs to marlin 2:4 - marlin_24_q_w_comp, marlin_24_scale, meta = pack_to_marlin_24(q_w_24, scale, num_bits, group_size) - workspace_24 = marlin_24_workspace(size_n) - - fn_inputs = ( - input_2d, marlin_24_q_w_comp, meta, marlin_24_scale, workspace_24, - num_bits, a_input_in, marlin_24_scale.shape[1], a_input_out, - ) - output = torchao.ops.marlin_24_gemm(*fn_inputs) - output = output.reshape(a_input.shape[:-1] + (marlin_24_scale.shape[1],)) - - max_diff = compute_max_diff(output, output_ref) - assert max_diff < 0.04 - - # Performs opcheck - test_utils = ["test_schema", "test_autograd_registration", "test_faketensor"] - opcheck( - torch.ops.torchao.marlin_24_gemm, - fn_inputs, - test_utils=test_utils, - ) - - -MARLIN_QQQ_BATCH_SIZE = [1, 4, 8, 16, 32, 64] -MARLIN_QQQ_K_CHUNKS = [128] -MARLIN_QQQ_N_CHUNKS = [64, 128, 256] -MNK_FACTORS = [ - (1, 1, 1), - (1, 4, 8), - (1, 7, 5), - (13, 17, 67), - (26, 37, 13), - (67, 13, 11), -] -MARLIN_QQQ_SUPPORTED_NUM_BITS = [4] -MARLIN_QQQ_SUPPORTED_GROUP_SIZES = [-1, 128] - -MARLIN_TEST_PARAMS = list( - itertools.product( - MARLIN_QQQ_BATCH_SIZE, - MARLIN_QQQ_K_CHUNKS, - MARLIN_QQQ_N_CHUNKS, - MARLIN_QQQ_SUPPORTED_NUM_BITS, - MARLIN_QQQ_SUPPORTED_GROUP_SIZES, - MNK_FACTORS, + def _scaled_dot_product_int8_op_ref( + self, + q, + k, + v, + attn_mask=None, + dropout_p=0, + is_causal=False, + q_zp=0, + q_scale=1.0, + k_zp=0, + k_scale=1.0, + v_zp=0, + v_scale=1.0, + a_zp=0, + a_scale=1.0, + o_zp=0, + o_scale=1.0): + q = q.to(torch.float) + k = k.to(torch.float) + v = v.to(torch.float) + scale_factor = 1 / math.sqrt(q.size(-1)) + attn = q @ k.transpose(-2, -1) + attn = attn * scale_factor + if attn_mask is not None: + attn = attn + attn_mask + attn_max = attn.max(dim=-1, keepdim=True).values + attn = attn - attn_max + attn = torch.exp(attn) + attn_sum = torch.sum(attn, dim=-1, keepdim=True) + attn = attn / attn_sum + math_ref = attn @ v + return math_ref.to(torch.uint8) + + SDPA_INT8_BATCH_SIZE = [56] + SDPA_INT8_NUM_HEADS = [16] + SDPA_INT8_Q_SEQ_LEN = [188] + SDPA_INT8_KV_SEQ_LEN = [253] + SDPA_INT8_HEAD_DIM = [64] + SDPA_INT8_MASK_DTYPE = [torch.bfloat16] + + SDPA_INT8_TEST_PARAMS = list( + itertools.product( + SDPA_INT8_BATCH_SIZE, + SDPA_INT8_NUM_HEADS, + SDPA_INT8_Q_SEQ_LEN, + SDPA_INT8_KV_SEQ_LEN, + SDPA_INT8_HEAD_DIM, + SDPA_INT8_MASK_DTYPE, + ) ) -) + @parametrize("batch_size", SDPA_INT8_BATCH_SIZE) + @parametrize("n_head", SDPA_INT8_NUM_HEADS) + @parametrize("q_seq_len", SDPA_INT8_Q_SEQ_LEN) + @parametrize("kv_seq_len", SDPA_INT8_KV_SEQ_LEN) + @parametrize("head_dim", SDPA_INT8_HEAD_DIM) + @parametrize("mask_dtype", SDPA_INT8_MASK_DTYPE) + def test_scaled_dot_product_int8_op(self, batch_size, n_head, q_seq_len, kv_seq_len, head_dim, mask_dtype): + device = "cpu" + q_zp = int(127) + q_scale = float(1.7907238006591797) + k_zp = int(125) + k_scale = 
float(1.8039721250534058) + v_zp = int(127) + v_scale = float(1.839004635810852) + a_zp = int(120) + a_scale = float(0.003919653594493866) + o_zp = int(128) + o_scale = float(1.8191684484481812) + q_shape = [batch_size, n_head, q_seq_len, head_dim] + kv_shape = [batch_size, n_head, kv_seq_len, head_dim] + mask_shape = [batch_size, 1, q_seq_len, kv_seq_len] + q = torch.randn(q_shape, dtype=torch.float, device=device) + k = torch.randn(kv_shape, dtype=torch.float, device=device) + v = torch.randn(kv_shape, dtype=torch.float, device=device) + q = q.to(torch.uint8) + k = k.to(torch.uint8) + v = v.to(torch.uint8) + attn_mask = torch.randn(mask_shape, dtype=mask_dtype, device=device) + q2, k2, v2, attn_mask_2 = q.clone(), k.clone(), v.clone(), attn_mask.clone() + + math_ref = self._scaled_dot_product_int8_op_ref( + q2, + k2, + v2, + attn_mask=attn_mask_2, + dropout_p=0.0, + is_causal=False, + q_zp=q_zp, + q_scale=q_scale, + k_zp=k_zp, + k_scale=k_scale, + v_zp=v_zp, + v_scale=v_scale, + a_zp=a_zp, + a_scale=a_scale, + o_zp=o_zp, + o_scale=o_scale + ) + actual = torch.ops.torchao.scaled_dot_product_int8( + q, + k, + v, + attn_mask=attn_mask, + dropout_p=0.0, + is_causal=False, + q_zp=q_zp, + q_scale=q_scale, + k_zp=k_zp, + k_scale=k_scale, + v_zp=v_zp, + v_scale=v_scale, + a_zp=a_zp, + a_scale=a_scale, + o_zp=o_zp, + o_scale=o_scale + ) + + self.assertEqual(actual, math_ref, atol=3.0, rtol=5e-6) + + # def _create_floatx_inputs(self, ebits: int, mbits: int, BS: int, OC: int, IC: int, device, dtype): + # # Randomly initialize each byte + # nbits = 1 + ebits + mbits + # floatx_weight = torch.randint(256, (OC, IC // 8 * nbits), dtype=torch.uint8) + # scale = torch.rand(OC).to(dtype) + 0.5 + # fp16_act = torch.rand(BS, IC).to(dtype) + 0.5 + # return floatx_weight.to(device), scale.to(device), fp16_act.to(device) + + # @pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available") + # @parametrize("ebits,mbits", [(3, 2), (2, 2)]) + # @parametrize("dtype", [torch.half, torch.bfloat16]) + # def test_quant_llm_linear(self, ebits, mbits, dtype): + # BS = 2 + # OC = 256 + # IC = 256 + # splitK = 1 + # floatx_weight, scale, fp16_act = self._create_floatx_inputs(ebits, mbits, BS, OC, IC, "cuda", dtype) + + # # smoke test + # torchao.ops.quant_llm_linear(ebits, mbits, fp16_act, floatx_weight, scale, splitK) + + # # comprehensive testing + # test_utils = ["test_schema", "test_autograd_registration", "test_faketensor", "test_aot_dispatch_dynamic"] + # opcheck(torch.ops.torchao.quant_llm_linear, (ebits, mbits, fp16_act, floatx_weight, scale, splitK), test_utils=test_utils) + + # @pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available") + # @parametrize("BS,OC,IC,splitK", [(1, 2048, 4096, 5), (2, 8192, 8192, 6)]) + # @parametrize("ebits,mbits", [(3, 2), (2, 2)]) + # @parametrize("dtype", [torch.half, torch.bfloat16]) + # def test_quant_llm_linear_correctness(self, ebits, mbits, BS, OC, IC, splitK, dtype): + # # adapted from https://github.com/usyd-fsalab/fp6_llm/blob/5df6737cca32f604e957e3f63f03ccc2e4d1df0d/tests/python/kernel_test_fpx.py + # floatx_weight, scale, fp16_act = self._create_floatx_inputs(ebits, mbits, BS, OC, IC, "cuda", dtype) + + # results_floatx = torchao.ops.quant_llm_linear(ebits, mbits, fp16_act, floatx_weight, scale, splitK) + + # fp16_weight = from_scaled_tc_floatx(floatx_weight, ebits, mbits, scale).to(dtype) + # results_fp16 = fp16_act @ fp16_weight.T + + # error = (results_floatx - results_fp16).abs().mean() + # gt = results_fp16.abs().mean() + # 
relative_error = error / gt + # rtol = 1e-2 if dtype == torch.bfloat16 else 1e-3 + # assert relative_error < rtol -@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available") -@pytest.mark.parametrize( - "batch_size, k_chunk, n_chunk, num_bits, group_size, mnk_factors", - MARLIN_TEST_PARAMS, - ids=str, -) -@pytest.mark.skip(reason="test outputs nan after cuda is upgraded to 12.4") -def test_marlin_qqq(batch_size, k_chunk, n_chunk, num_bits, group_size, mnk_factors): - int8_traits = torch.iinfo(torch.int8) - m_factor, n_factor, k_factor = mnk_factors - - size_m = m_factor - size_k = k_chunk * k_factor - size_n = n_chunk * n_factor - - a_input = torch.randn( - (batch_size, size_m, size_k), dtype=torch.float16, device="cuda" - ) - b_weight = torch.rand((size_n, size_k), dtype=torch.float16, device="cuda") - - # Reshape input into 2D tensor - input_2d = a_input.view(-1, a_input.shape[-1]) - a_input_in, a_input_out = input_2d.shape - - # Quantize activations - s_a = ( - input_2d.abs() - .max(dim=-1, keepdim=True)[0] - .div(int8_traits.max) - .to(torch.float32) - ) - q_a = ( - (input_2d / s_a).round().clamp(int8_traits.min, int8_traits.max).to(torch.int8) - ) +instantiate_parametrized_tests(TestOps) - # Quantize weights - q_w, s_group, s_channel, w_ref = choose_qparams_and_quantize_affine_qqq( - b_weight, num_bits, group_size - ) - q_w = q_w.t() - s_group = s_group.t() - s_channel = s_channel.t() - w_ref = w_ref.t() - marlin_qqq_q_w, marlin_qqq_s_group, marlin_qqq_s_channel = pack_to_marlin_qqq( - q_w, s_group, s_channel, num_bits, group_size - ) - workspace = marlin_qqq_workspace(size_n) - - # Obtains reference output - output_ref = torch.matmul(q_a.half() * s_a.half(), w_ref) - output_ref = output_ref.reshape(a_input.shape[:-1] + (size_n,)) - - fn_inputs = ( - q_a, - marlin_qqq_q_w, - s_a, - marlin_qqq_s_channel, - marlin_qqq_s_group, - workspace, - a_input_in, - size_n, - a_input_out, - ) - output = torchao.ops.marlin_qqq_gemm(*fn_inputs) - output = output.reshape(a_input.shape[:-1] + (size_n,)) - - max_diff = compute_max_diff(output, output_ref) - assert max_diff < 0.04 - - # Performs opcheck - test_utils = ["test_schema", "test_autograd_registration", "test_faketensor"] - opcheck( - torch.ops.torchao.marlin_qqq_gemm, - fn_inputs, - test_utils=test_utils, - ) +# ## Tests for `tensor_core_layout` +# kTileSizeN = 8 +# kTileSizeK = 16 + +# SHAPES = [ +# (4096, 4096), +# # Llama 2 GEMM shapes +# (4096, 11008), +# (11008, 4096), +# # Llama 3 GEMM shapes +# (4096, 14336), +# (14336, 4096), +# ] +# INNERKTILES = [2, 4, 8] +# QGROUP_SIZES = [32, 64, 128, 256] +# TEST_CONFIGS_UNPACK = list(itertools.product(SHAPES, INNERKTILES)) +# TEST_CONFIGS_DEQUANT = list(itertools.product(SHAPES, INNERKTILES, QGROUP_SIZES)) + +# @pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available") +# # @pytest.mark.skipif(TORCH_VERSION_AT_LEAST_2_5, reason="weight packing is updated in 2.5+") +# @pytest.mark.parametrize("shape, inner_k_tiles", TEST_CONFIGS_UNPACK, ids=str) +# def test_unpack_tensor_core_tiled_layout_correctness(shape, inner_k_tiles): +# N, K = shape +# assert K % (inner_k_tiles * kTileSizeK) == 0 and N % kTileSizeN == 0 + +# t = torch.randint(0, 16, dtype=torch.int, size=shape, device="cuda") +# if TORCH_VERSION_AT_LEAST_2_5: +# t = (t[::, ::2] << 4 | t[::, 1::2]).to(torch.uint8) +# packed_w = torch.ops.aten._convert_weight_to_int4pack(t, inner_k_tiles) +# unpacked = torchao.ops.unpack_tensor_core_tiled_layout(packed_w, inner_k_tiles) +# if 
TORCH_VERSION_AT_LEAST_2_5: +# unpacked = (unpacked[::, ::2] << 4 | unpacked[::, 1::2]).to(torch.uint8) +# assert torch.equal(t, unpacked) + +# # TODO: Fix "test_aot_dispatch_dynamic" test failure +# @pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available") +# # @pytest.mark.skipif(TORCH_VERSION_AT_LEAST_2_5, reason="weight packing is updated in 2.5+") +# @pytest.mark.parametrize("shape, inner_k_tiles", TEST_CONFIGS_UNPACK , ids=str) +# def test_unpack_tensor_core_tiled_layout_op(shape, inner_k_tiles): +# test_utils = [ +# "test_schema", +# "test_autograd_registration", +# "test_faketensor", +# ] + +# # TODO: Figure out why test fails unless torch >= 2.5 +# if TORCH_VERSION_AT_LEAST_2_5: +# test_utils.append("test_aot_dispatch_dynamic") + +# t = torch.randint(0, 16, dtype=torch.int, size=shape, device="cuda") +# if TORCH_VERSION_AT_LEAST_2_5: +# t = (t[::, ::2] << 4 | t[::, 1::2]).to(torch.uint8) +# packed_w = torch.ops.aten._convert_weight_to_int4pack(t, inner_k_tiles) + +# opcheck( +# torch.ops.torchao.unpack_tensor_core_tiled_layout, +# (packed_w, inner_k_tiles), +# test_utils=test_utils, +# ) + +# def dequant_ref(q, scales, zeros, group_size, nbits=4, dtype=torch.bfloat16): +# n, k = q.shape +# assert q.dtype == torch.int + +# n_groups = k // group_size +# assert scales.shape[0] == n and scales.shape[1] == n_groups +# assert scales.shape == zeros.shape + +# midpoint = 2 ** (nbits - 1) + +# #Convert fron u4 -> s4 and upcast to bfloat16 +# q = q.sub(midpoint).to(dtype) + +# # Dequantize +# q = q.reshape(-1, group_size) +# dq = q * scales.reshape(-1, 1) + zeros.reshape(-1, 1) + +# return dq.reshape(n, k) + + +# @pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available") +# # @pytest.mark.skipif(TORCH_VERSION_AT_LEAST_2_5, reason="weight packing is updated in 2.5+") +# @pytest.mark.parametrize("shape, inner_k_tiles, group_size", TEST_CONFIGS_DEQUANT, ids=str) +# def test_dequantize_tensor_core_tiled_layout_correctness_quant_dequant(shape, inner_k_tiles, group_size): +# n, k = shape +# dtype = torch.bfloat16 + +# device = "cuda" + +# t = torch.randn(n, k, dtype=dtype, device=device) +# scales, zeros = get_groupwise_affine_qparams(t, n_bit=4, groupsize=group_size, dtype=dtype) + +# # Quantize +# q = groupwise_affine_quantize_tensor_from_qparams( +# t, scales, zeros, n_bit=4, groupsize=group_size +# ) + +# # Pack to tensor core layout +# packed = torch.ops.aten._convert_weight_to_int4pack(q, inner_k_tiles) +# scales_and_zeros = pack_tinygemm_scales_and_zeros(scales, zeros) +# q_groups = k // group_size +# assert scales_and_zeros.shape == torch.Size([q_groups, n, 2]) + +# # Dequantize 'ao' ref +# dq_ao = groupwise_affine_dequantize_tensor_from_qparams( +# q, scales, zeros, n_bit=4, groupsize=group_size +# ) + +# # Dequantize by passing in an identity matrix as the activation +# a_eye = torch.eye(k, device=device, dtype=dtype) +# dq_id = torch.ops.aten._weight_int4pack_mm( +# a_eye, +# packed, +# group_size, +# scales_and_zeros, +# ).t() + +# # Actual operation to test +# dq_op = torchao.ops.dequantize_tensor_core_tiled_layout(packed, scales_and_zeros, group_size, inner_k_tiles) + +# # Compare results +# diff_ao_id = (dq_id - dq_ao).abs().max() +# diff_op_id = (dq_op - dq_id).abs().max() +# diff_op_ao = (dq_op - dq_ao).abs().max() + +# # There are slight numerical differences when dequantizing with an identity matrix when compared to `groupwise_affine_dequantize` +# # Since the `dequantize_tensor_core_layout` kernel relies on the same underlying bit 
twiddling tricks for fast +# # conversion from u4 -> s4 -> bf16, the identity matrix dequant hack and `dequantize_tensor_core_layout` are +# # expected to give same results, while both will have similar numerical differences to `groupwise_affine_dequantize`. + +# # Test that the `dequant` kernel gives same results as identity matrix-based dequant +# assert diff_op_id == 0 + +# # Test that the `dequant` kernel gives same numerical diffs as the `groupwise_affine_dequantize` when compared against the identity matrix +# assert diff_op_ao == diff_ao_id + +# assert diff_op_ao < 1e-1 + +# # This test differs from one above in that it uses `unpack_tensor_core_tiled_layout` to unpack then dequantize +# @pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available") +# # @pytest.mark.skipif(TORCH_VERSION_AT_LEAST_2_5, reason="weight packing is updated in 2.5+") +# @pytest.mark.parametrize("shape, inner_k_tiles, group_size", TEST_CONFIGS_DEQUANT, ids=str) +# def test_dequantize_tensor_core_tiled_layout_correctness_unpack_and_dequant(shape, inner_k_tiles, group_size): +# n, k = shape +# dtype = torch.bfloat16 +# device = "cuda" + +# # Quantize and pack +# t = torch.randn(n, k, dtype=dtype, device=device) +# scales, zeros = get_groupwise_affine_qparams(t, n_bit=4, groupsize=group_size, dtype=dtype) +# q = groupwise_affine_quantize_tensor_from_qparams( +# t, scales, zeros, n_bit=4, groupsize=group_size +# ) + +# packed = torch.ops.aten._convert_weight_to_int4pack(q, inner_k_tiles) +# scales_and_zeros = pack_tinygemm_scales_and_zeros(scales, zeros) + +# # Unpack and dequantize +# unpacked = torchao.ops.unpack_tensor_core_tiled_layout(packed, inner_k_tiles) +# if TORCH_VERSION_AT_LEAST_2_5: +# unpacked = (unpacked[::, ::2] << 4 | unpacked[::, 1::2]).to(torch.uint8) + +# dq_ao = groupwise_affine_dequantize_tensor_from_qparams( +# unpacked, scales, zeros, n_bit=4, groupsize=group_size +# ) + +# # Dequantize by passing in an identity matrix as the activation +# a_eye = torch.eye(k, device=device, dtype=dtype) +# dq_id = torch.ops.aten._weight_int4pack_mm( +# a_eye, +# packed, +# group_size, +# scales_and_zeros, +# ).t() + +# # Actual operation to test +# dq_op = torchao.ops.dequantize_tensor_core_tiled_layout(packed, scales_and_zeros, group_size, inner_k_tiles) + +# # Compare results +# diff_ao_id = (dq_id - dq_ao).abs().max() +# diff_op_id = (dq_op - dq_id).abs().max() +# diff_op_ao = (dq_op - dq_ao).abs().max() + +# # There are slight numerical differences when dequantizing with an identity matrix when compared to `groupwise_affine_dequantize` +# # Since the `dequantize_tensor_core_layout` kernel relies on the same underlying bit twiddling tricks for fast +# # conversion from u4 -> s4 -> bf16, the identity matrix dequant hack and `dequantize_tensor_core_layout` are +# # expected to give same results, while both will have similar numerical differences to `groupwise_affine_dequantize`. 
+ +# # Test that the `dequant` kernel gives same results as identity matrix-based dequant +# assert diff_op_id == 0 + +# # Test that the `dequant` kernel gives same numerical diffs as the `groupwise_affine_dequantize` when compared against the identity matrix +# assert diff_op_ao == diff_ao_id + +# assert diff_op_ao < 1e-1 + +# @pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available") +# # @pytest.mark.skipif(TORCH_VERSION_AT_LEAST_2_5, reason="weight packing is updated in 2.5+") +# @pytest.mark.parametrize("shape, inner_k_tiles, group_size", TEST_CONFIGS_DEQUANT, ids=str) +# def test_dequantize_tensor_core_tiled_layout_op(shape, inner_k_tiles, group_size): +# n, k = shape +# device = "cuda" + +# q = torch.randint(0, 16, shape, dtype=torch.int, device=device) +# if TORCH_VERSION_AT_LEAST_2_5: +# q = (q[::, ::2] << 4 | q[::, 1::2]).to(torch.uint8) +# packed_w = torch._convert_weight_to_int4pack(q, inner_k_tiles) +# q_groups = k // group_size +# scales = torch.randn(n, q_groups, dtype=torch.bfloat16, device=device) +# zeros = torch.randn_like(scales) +# scales_and_zeros = pack_tinygemm_scales_and_zeros(scales, zeros) + +# test_utils = [ +# "test_schema", +# "test_autograd_registration", +# "test_faketensor", +# ] +# # TODO: Figure out why test fails unless torch >= 2.5 +# if TORCH_VERSION_AT_LEAST_2_5: +# test_utils.append("test_aot_dispatch_dynamic") +# opcheck( +# torch.ops.torchao.dequantize_tensor_core_tiled_layout, +# (packed_w, scales_and_zeros, group_size, inner_k_tiles), +# test_utils=test_utils, +# ) + + +# MARLIN_24_BATCH_SIZE = [1, 4, 8, 16, 32, 64] +# MARLIN_24_K_CHUNKS = [128] +# MARLIN_24_N_CHUNKS = [512] +# MNK_FACTORS = [ +# (1, 1, 1), +# (1, 4, 8), +# (1, 7, 5), +# (13, 17, 67), +# (26, 37, 13), +# (67, 13, 11), +# ] +# MARLIN_24_SUPPORTED_NUM_BITS = [4, 8] +# MARLIN_24_SUPPORTED_GROUP_SIZES = [-1, 128] + +# MARLIN_TEST_PARAMS = list(itertools.product( +# MARLIN_24_BATCH_SIZE, MARLIN_24_K_CHUNKS, MARLIN_24_N_CHUNKS, +# MARLIN_24_SUPPORTED_NUM_BITS, MARLIN_24_SUPPORTED_GROUP_SIZES, MNK_FACTORS +# )) + +# def _symmetric_quantize_with_ref(w: torch.Tensor, num_bits: int, group_size: int): +# orig_device = w.device +# size_k, size_n = w.shape + +# assert w.is_floating_point(), "w must be float" + +# if group_size == -1: +# group_size = size_k +# assert group_size <= size_k + +# max_q_val = 2**num_bits - 1 +# half_q_val = (max_q_val + 1) // 2 + +# # Reshape to [groupsize, -1] +# if group_size < size_k: +# w = w.reshape((-1, group_size, size_n)) +# w = w.permute(1, 0, 2) +# w = w.reshape((group_size, -1)) + +# # Compute scale for each group +# s = torch.max(torch.abs(w), 0, keepdim=True)[0] +# s *= 2 / max_q_val # 2 => symmetric + +# # Quantize +# q_w = torch.round(w / s).int() +# q_w += half_q_val +# q_w = torch.clamp(q_w, 0, max_q_val) + +# # Compute ref (dequantized) +# w_ref = (q_w - half_q_val).half() * s + +# # Restore original shapes +# if group_size < size_k: + +# def reshape_w(w): +# w = w.reshape((group_size, -1, size_n)) +# w = w.permute(1, 0, 2) +# w = w.reshape((size_k, size_n)).contiguous() +# return w + +# q_w = reshape_w(q_w) +# w_ref = reshape_w(w_ref) + +# s = s.reshape((-1, size_n)).contiguous() + +# return ( +# w_ref.to(device=orig_device), +# q_w.to(device=orig_device), +# s.to(device=orig_device), +# ) + +# @pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available") +# @pytest.mark.parametrize("batch_size, k_chunk, n_chunk, num_bits, group_size, mnk_factors", MARLIN_TEST_PARAMS, ids=str) +# def 
test_marlin_24(batch_size, k_chunk, n_chunk, num_bits, group_size, mnk_factors): +# m_factor, n_factor, k_factor = mnk_factors + +# size_m = m_factor +# size_k = k_chunk * k_factor +# size_n = n_chunk * n_factor + +# a_input = torch.randn((batch_size, size_m, size_k), dtype=torch.float16, device="cuda") +# b_weight = torch.rand((size_k, size_n), dtype=torch.float16, device="cuda") + +# # Inject 2:4 sparsity +# w_24, _ = inject_24(b_weight, size_k, size_n) + +# # Symmetric quantize +# w_24_ref, q_w_24, scale = _symmetric_quantize_with_ref(w_24, num_bits, group_size) + +# # Reshape input into 2D tensor +# input_2d = a_input.view(-1, a_input.shape[-1]) +# a_input_in, a_input_out = input_2d.shape + +# # Obtains reference output +# output_ref = torch.matmul(input_2d, w_24_ref) +# output_ref = output_ref.reshape(a_input.shape[:-1] + (scale.shape[1],)) + +# # Packs to marlin 2:4 +# marlin_24_q_w_comp, marlin_24_scale, meta = pack_to_marlin_24(q_w_24, scale, num_bits, group_size) +# workspace_24 = marlin_24_workspace(size_n) + +# fn_inputs = ( +# input_2d, marlin_24_q_w_comp, meta, marlin_24_scale, workspace_24, +# num_bits, a_input_in, marlin_24_scale.shape[1], a_input_out, +# ) +# output = torchao.ops.marlin_24_gemm(*fn_inputs) +# output = output.reshape(a_input.shape[:-1] + (marlin_24_scale.shape[1],)) + +# max_diff = compute_max_diff(output, output_ref) +# assert max_diff < 0.04 + +# # Performs opcheck +# test_utils = ["test_schema", "test_autograd_registration", "test_faketensor"] +# opcheck( +# torch.ops.torchao.marlin_24_gemm, +# fn_inputs, +# test_utils=test_utils, +# ) + + +# MARLIN_QQQ_BATCH_SIZE = [1, 4, 8, 16, 32, 64] +# MARLIN_QQQ_K_CHUNKS = [128] +# MARLIN_QQQ_N_CHUNKS = [64, 128, 256] +# MNK_FACTORS = [ +# (1, 1, 1), +# (1, 4, 8), +# (1, 7, 5), +# (13, 17, 67), +# (26, 37, 13), +# (67, 13, 11), +# ] +# MARLIN_QQQ_SUPPORTED_NUM_BITS = [4] +# MARLIN_QQQ_SUPPORTED_GROUP_SIZES = [-1, 128] + +# MARLIN_TEST_PARAMS = list( +# itertools.product( +# MARLIN_QQQ_BATCH_SIZE, +# MARLIN_QQQ_K_CHUNKS, +# MARLIN_QQQ_N_CHUNKS, +# MARLIN_QQQ_SUPPORTED_NUM_BITS, +# MARLIN_QQQ_SUPPORTED_GROUP_SIZES, +# MNK_FACTORS, +# ) +# ) + + +# @pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available") +# @pytest.mark.parametrize( +# "batch_size, k_chunk, n_chunk, num_bits, group_size, mnk_factors", +# MARLIN_TEST_PARAMS, +# ids=str, +# ) +# @pytest.mark.skip(reason="test outputs nan after cuda is upgraded to 12.4") +# def test_marlin_qqq(batch_size, k_chunk, n_chunk, num_bits, group_size, mnk_factors): +# int8_traits = torch.iinfo(torch.int8) +# m_factor, n_factor, k_factor = mnk_factors + +# size_m = m_factor +# size_k = k_chunk * k_factor +# size_n = n_chunk * n_factor + +# a_input = torch.randn( +# (batch_size, size_m, size_k), dtype=torch.float16, device="cuda" +# ) +# b_weight = torch.rand((size_n, size_k), dtype=torch.float16, device="cuda") + +# # Reshape input into 2D tensor +# input_2d = a_input.view(-1, a_input.shape[-1]) +# a_input_in, a_input_out = input_2d.shape + +# # Quantize activations +# s_a = ( +# input_2d.abs() +# .max(dim=-1, keepdim=True)[0] +# .div(int8_traits.max) +# .to(torch.float32) +# ) +# q_a = ( +# (input_2d / s_a).round().clamp(int8_traits.min, int8_traits.max).to(torch.int8) +# ) + +# # Quantize weights +# q_w, s_group, s_channel, w_ref = choose_qparams_and_quantize_affine_qqq( +# b_weight, num_bits, group_size +# ) +# q_w = q_w.t() +# s_group = s_group.t() +# s_channel = s_channel.t() +# w_ref = w_ref.t() +# marlin_qqq_q_w, marlin_qqq_s_group, 
marlin_qqq_s_channel = pack_to_marlin_qqq( +# q_w, s_group, s_channel, num_bits, group_size +# ) + +# workspace = marlin_qqq_workspace(size_n) + +# # Obtains reference output +# output_ref = torch.matmul(q_a.half() * s_a.half(), w_ref) +# output_ref = output_ref.reshape(a_input.shape[:-1] + (size_n,)) + +# fn_inputs = ( +# q_a, +# marlin_qqq_q_w, +# s_a, +# marlin_qqq_s_channel, +# marlin_qqq_s_group, +# workspace, +# a_input_in, +# size_n, +# a_input_out, +# ) +# output = torchao.ops.marlin_qqq_gemm(*fn_inputs) +# output = output.reshape(a_input.shape[:-1] + (size_n,)) + +# max_diff = compute_max_diff(output, output_ref) +# assert max_diff < 0.04 + +# # Performs opcheck +# test_utils = ["test_schema", "test_autograd_registration", "test_faketensor"] +# opcheck( +# torch.ops.torchao.marlin_qqq_gemm, +# fn_inputs, +# test_utils=test_utils, +# ) if __name__ == "__main__": diff --git a/torchao/csrc/cpu/sdpa.cpp b/torchao/csrc/cpu/sdpa.cpp index 44aaef1bcc..229ef3433e 100644 --- a/torchao/csrc/cpu/sdpa.cpp +++ b/torchao/csrc/cpu/sdpa.cpp @@ -1,30 +1,3 @@ -// // #define TORCH_ASSERT_ONLY_METHOD_OPERATORS -// #include -// #include - -// #include -// #include -// #include -// #include -// #include -// #include -// #include -// #include -// #include -// #include -// #include -// #include -// // #include -// // #include -// #include -// #include - -// #ifndef AT_PER_OPERATOR_HEADERS -// #include -// #else -// #include -// #endif - #include #include #include diff --git a/torchao/ops.py b/torchao/ops.py index d5e9d0e75b..4be4064aa2 100644 --- a/torchao/ops.py +++ b/torchao/ops.py @@ -469,55 +469,3 @@ def _( ) return torch.empty((size_m, size_n), dtype=torch.float16, device=x.device) - - - -# def scaled_dot_product_int8( -# query: Tensor, -# key: Tensor, -# value: Tensor, -# attn_mask: Optional[Tensor] = None, -# dropout_p: float = 0.0, -# is_causal: bool = False, -# scale: Optional[float] = None, -# q_zp: int = 0, -# q_scale: float = 1.0, -# k_zp: int = 0, -# k_scale: float = 1.0, -# v_zp: int = 0, -# v_scale: float = 1.0, -# a_zp: int = 0, -# a_scale: float = 1.0, -# o_zp: int = 0, -# o_scale: float = 1.0, -# ) -> Tensor: -# return torch.ops.torchao.scaled_dot_product_int8.default(query, key, value, -# attn_mask, dropout_p, is_causal, scale, -# q_zp, q_scale, -# k_zp, k_scale, -# v_zp, v_scale, -# a_zp, a_scale, -# o_zp, o_scale) - - -# @register_custom_op("torchao::scaled_dot_product_int8") -# def _( -# query: Tensor, -# key: Tensor, -# value: Tensor, -# attn_mask: Optional[Tensor] = None, -# dropout_p: float = 0.0, -# is_causal: bool = False, -# scale: Optional[float] = None, -# q_zp: int = 0, -# q_scale: float = 1.0, -# k_zp: int = 0, -# k_scale: float = 1.0, -# v_zp: int = 0, -# v_scale: float = 1.0, -# a_zp: int = 0, -# a_scale: float = 1.0, -# o_zp: int = 0, -# o_scale: float = 1.0, -# ) -> Tensor: -# return query.new_empty(query.shape) From fb7d62f963b2912dc905c9b8a19147787be0d9bc Mon Sep 17 00:00:00 2001 From: Valentine233 Date: Tue, 3 Dec 2024 02:47:47 -0500 Subject: [PATCH 3/5] update int8 sdpa --- torchao/quantization/sfdp_int8_fx_pass.py | 47 ----------------------- 1 file changed, 47 deletions(-) diff --git a/torchao/quantization/sfdp_int8_fx_pass.py b/torchao/quantization/sfdp_int8_fx_pass.py index 672db14f6b..1dedd60ee5 100644 --- a/torchao/quantization/sfdp_int8_fx_pass.py +++ b/torchao/quantization/sfdp_int8_fx_pass.py @@ -17,37 +17,12 @@ from torchao.ops import scaled_dot_product_int8 __all__ = [ - # "_sfdp_pattern_int8", - # "_sfdp_replacement_int8", - # "_gen_sfdp_patterns_int8", 
"_sfdp_init_int8", ] aten = torch.ops.aten -# scaled_dot_product_int8 = torch.ops.torchao.scaled_dot_product_int8 patterns = PatternMatcherPass() -# def _sfdp_pattern_int8(query, key, value, inv_scale): -# return ( -# torch.matmul(query, key.transpose(-2, -1)) -# .div(inv_scale) -# .softmax(dim=-1) -# .matmul(value) -# ) - -# def _sfdp_replacement_int8(query, key, value, inv_scale): -# print("*** enter _sfdp_replacement in torchao ***") -# counters["inductor"]["fuse_attention_int8"] += 1 -# return torch.nn.functional.scaled_dot_product_attention( -# query, -# key, -# value, -# attn_mask=None, -# dropout_p=0.0, -# is_causal=False, -# scale=1.0 / inv_scale, -# ) - def _sfdp_pattern_int8_1( query, key, @@ -476,17 +451,6 @@ def _gen_sfdp_patterns_int8(): ) m_bs1_inp = functools.partial(torch.empty, (1, 1, 1, 4), device=device) for dtype in [torch.float, torch.half]: - # g = functools.partial(g_inp, dtype=dtype) - # c = functools.partial(c_inp, dtype=dtype) - # candidates = [ - # ( - # _sfdp_pattern_int8, - # _sfdp_replacement_int8, - # [g(), g(), g(), c()], - # {}, - # _sfdp_extra_check_int8(aten.div.Tensor), - # ), - # ] g_u8 = functools.partial(g_inp, dtype=torch.uint8, requires_grad=False) g_u8_bs1 = functools.partial(g_bs1_inp, dtype=torch.uint8, requires_grad=False) m = functools.partial(m_inp, dtype=dtype) @@ -696,16 +660,6 @@ def _gen_sfdp_patterns_int8(): name = pattern.__name__ if len(workaround) >= 1: - # if "dropout_p" in workaround: - # # functools.partial insufficient because we look at signature downstream - # pattern = partialize_and_update_signature(pattern, dropout_p=0.0) - # replacement = partialize_and_update_signature( - # replacement, dropout_p=0.0 - # ) - # workaround = {} - # else: - # for uint8 pattern with more workarounds other than dropout, - # we need to rename it to avoid influcing other patterns pattern = partialize_and_update_signature(pattern, dropout=0.0) replacement = partialize_and_update_signature( replacement, dropout=0.0 @@ -730,4 +684,3 @@ def _sfdp_init_int8(): for key, register_replacement_kwargs in _gen_sfdp_patterns_int8(): register_replacement(**register_replacement_kwargs) config.joint_custom_pre_pass = patterns.apply - # print("\n\njoint_custom_pre_pass:", config.joint_custom_pre_pass) From 545b2b8242d7fde63f86f7ba8c401b5d049a0693 Mon Sep 17 00:00:00 2001 From: Valentine233 Date: Tue, 3 Dec 2024 03:03:42 -0500 Subject: [PATCH 4/5] update int8 sdpa --- torchao/csrc/cpu/toy.cpp | 20 -------------------- 1 file changed, 20 deletions(-) delete mode 100644 torchao/csrc/cpu/toy.cpp diff --git a/torchao/csrc/cpu/toy.cpp b/torchao/csrc/cpu/toy.cpp deleted file mode 100644 index a835aae661..0000000000 --- a/torchao/csrc/cpu/toy.cpp +++ /dev/null @@ -1,20 +0,0 @@ -#include -#include -#include - -namespace torchao { - -torch::Tensor toy_op2_cpu( - torch::Tensor _in_feats) -{ - std::cout<<"---- run into cpu 2 ----"< Date: Tue, 17 Dec 2024 00:11:14 -0500 Subject: [PATCH 5/5] update int8 sdpa cpu --- setup.py | 14 + test/quantization/test_sfdp_int8_fx_pass.py | 10 +- test/test_ops.py | 1049 +++++++++-------- torchao/csrc/cpu/sdpa.cpp | 1148 ++++++------------- torchao/ops.py | 6 +- 5 files changed, 922 insertions(+), 1305 deletions(-) diff --git a/setup.py b/setup.py index b7334631a8..080bdea503 100644 --- a/setup.py +++ b/setup.py @@ -70,6 +70,20 @@ def get_extensions(): "cxx": [ "-O3" if not debug_mode else "-O0", "-fdiagnostics-color=always", + # ## AVX2 + # "-DCPU_CAPABILITY=AVX2", + # "-DCPU_CAPABILITY_AVX2", + # "-mavx2", + # "-mfma", + # "-mf16c", + ## 
AVX512 + "-DCPU_CAPABILITY=AVX512", + "-DCPU_CAPABILITY_AVX512", + "-mavx512f", + "-mavx512bw", + "-mavx512vl", + "-mavx512dq", + "-mfma", ], "nvcc": [ "-O3" if not debug_mode else "-O0", diff --git a/test/quantization/test_sfdp_int8_fx_pass.py b/test/quantization/test_sfdp_int8_fx_pass.py index a39a98c364..3e17d9ce81 100644 --- a/test/quantization/test_sfdp_int8_fx_pass.py +++ b/test/quantization/test_sfdp_int8_fx_pass.py @@ -16,7 +16,7 @@ from torch.testing._internal.inductor_utils import HAS_CPU, HAS_CUDA import torch.ao.quantization.quantizer.x86_inductor_quantizer as xiq -from torch._export import capture_pre_autograd_graph +from torch.export import export_for_training from torch.ao.quantization.quantize_pt2e import convert_pt2e, prepare_pt2e from torch.ao.quantization.quantizer.x86_inductor_quantizer import ( X86InductorQuantizer, @@ -65,7 +65,7 @@ def forward(self, x, mask): if self.has_mask: scores = scores + mask attention = self.softmax(scores) - # attention = self.dropout(attention) + attention = self.dropout(attention) context_layer = torch.matmul(attention, v) context_layer = context_layer.permute(0, 2, 1, 3).contiguous() context_layer = context_layer.view( @@ -75,7 +75,7 @@ def forward(self, x, mask): def _generate_qdq_quantized_model(mod, inputs, quantizer): with torch.no_grad(): - export_model = capture_pre_autograd_graph(mod, inputs) + export_model = export_for_training(mod, inputs).module() prepare_model = prepare_pt2e(export_model, quantizer) prepare_model(*inputs) convert_model = convert_pt2e(prepare_model) @@ -173,10 +173,10 @@ def _test_sdpa_rewriter_int8_1_to_4(self): if dtype == torch.bfloat16 else contextlib.nullcontext() ) - inputs = [ + inputs = ( torch.randn((bs, 384, 64 * 16), device=self.device, dtype=dtype), torch.randn((bs, 1, 1, 384), device=self.device) if has_mask else None, - ] + ) with torch.no_grad(), maybe_autocast: _sfdp_init_int8() quantizer = X86InductorQuantizer() diff --git a/test/test_ops.py b/test/test_ops.py index aaa8c8946d..0b0540e47d 100644 --- a/test/test_ops.py +++ b/test/test_ops.py @@ -57,39 +57,31 @@ def _scaled_dot_product_int8_op_ref( a_scale=1.0, o_zp=0, o_scale=1.0): - q = q.to(torch.float) - k = k.to(torch.float) - v = v.to(torch.float) + q = (q.to(torch.float) - q_zp) * q_scale + k = (k.to(torch.float) - k_zp) * k_scale + v = (v.to(torch.float) - v_zp) * v_scale scale_factor = 1 / math.sqrt(q.size(-1)) attn = q @ k.transpose(-2, -1) attn = attn * scale_factor if attn_mask is not None: - attn = attn + attn_mask + attn = attn + attn_mask.to(torch.float) attn_max = attn.max(dim=-1, keepdim=True).values attn = attn - attn_max attn = torch.exp(attn) attn_sum = torch.sum(attn, dim=-1, keepdim=True) attn = attn / attn_sum - math_ref = attn @ v - return math_ref.to(torch.uint8) - - SDPA_INT8_BATCH_SIZE = [56] - SDPA_INT8_NUM_HEADS = [16] - SDPA_INT8_Q_SEQ_LEN = [188] - SDPA_INT8_KV_SEQ_LEN = [253] - SDPA_INT8_HEAD_DIM = [64] - SDPA_INT8_MASK_DTYPE = [torch.bfloat16] - - SDPA_INT8_TEST_PARAMS = list( - itertools.product( - SDPA_INT8_BATCH_SIZE, - SDPA_INT8_NUM_HEADS, - SDPA_INT8_Q_SEQ_LEN, - SDPA_INT8_KV_SEQ_LEN, - SDPA_INT8_HEAD_DIM, - SDPA_INT8_MASK_DTYPE, - ) - ) + attn = torch.clamp(torch.round(attn / a_scale) + a_zp, min=0, max=255) + attn = (attn - a_zp) * a_scale + out = attn @ v + out = torch.clamp(torch.round(out / o_scale) + o_zp, min=0, max=255) + return out.to(torch.uint8) + + SDPA_INT8_BATCH_SIZE = [56, 120] + SDPA_INT8_NUM_HEADS = [2, 16] + SDPA_INT8_Q_SEQ_LEN = [18, 89] + SDPA_INT8_KV_SEQ_LEN = [100, 253] + 
SDPA_INT8_HEAD_DIM = [32, 64] + SDPA_INT8_MASK_DTYPE = [None, torch.float32, torch.bfloat16] @parametrize("batch_size", SDPA_INT8_BATCH_SIZE) @parametrize("n_head", SDPA_INT8_NUM_HEADS) @@ -98,6 +90,7 @@ def _scaled_dot_product_int8_op_ref( @parametrize("head_dim", SDPA_INT8_HEAD_DIM) @parametrize("mask_dtype", SDPA_INT8_MASK_DTYPE) def test_scaled_dot_product_int8_op(self, batch_size, n_head, q_seq_len, kv_seq_len, head_dim, mask_dtype): + torch.manual_seed(1234) device = "cpu" q_zp = int(127) q_scale = float(1.7907238006591797) @@ -109,23 +102,23 @@ def test_scaled_dot_product_int8_op(self, batch_size, n_head, q_seq_len, kv_seq_ a_scale = float(0.003919653594493866) o_zp = int(128) o_scale = float(1.8191684484481812) - q_shape = [batch_size, n_head, q_seq_len, head_dim] - kv_shape = [batch_size, n_head, kv_seq_len, head_dim] - mask_shape = [batch_size, 1, q_seq_len, kv_seq_len] - q = torch.randn(q_shape, dtype=torch.float, device=device) - k = torch.randn(kv_shape, dtype=torch.float, device=device) - v = torch.randn(kv_shape, dtype=torch.float, device=device) + q_shape = [batch_size, q_seq_len, n_head, head_dim] + kv_shape = [batch_size, kv_seq_len, n_head, head_dim] + mask_shape = [batch_size, 1, 1, kv_seq_len] + q = torch.randn(q_shape, dtype=torch.float, device=device).transpose(1, 2) * 100 + k = torch.randn(kv_shape, dtype=torch.float, device=device).transpose(1, 2) * 100 + v = torch.randn(kv_shape, dtype=torch.float, device=device).transpose(1, 2) * 100 q = q.to(torch.uint8) k = k.to(torch.uint8) v = v.to(torch.uint8) - attn_mask = torch.randn(mask_shape, dtype=mask_dtype, device=device) - q2, k2, v2, attn_mask_2 = q.clone(), k.clone(), v.clone(), attn_mask.clone() + attn_mask = torch.randn(mask_shape, dtype=mask_dtype, device=device) if mask_dtype is not None else None + q2, k2, v2, attn_mask_2 = q.clone(), k.clone(), v.clone(), attn_mask.clone() if mask_dtype is not None else None math_ref = self._scaled_dot_product_int8_op_ref( q2, k2, v2, - attn_mask=attn_mask_2, + attn_mask=attn_mask, dropout_p=0.0, is_causal=False, q_zp=q_zp, @@ -143,7 +136,7 @@ def test_scaled_dot_product_int8_op(self, batch_size, n_head, q_seq_len, kv_seq_ q, k, v, - attn_mask=attn_mask, + attn_mask=attn_mask_2, dropout_p=0.0, is_causal=False, q_zp=q_zp, @@ -158,504 +151,504 @@ def test_scaled_dot_product_int8_op(self, batch_size, n_head, q_seq_len, kv_seq_ o_scale=o_scale ) - self.assertEqual(actual, math_ref, atol=3.0, rtol=5e-6) - - # def _create_floatx_inputs(self, ebits: int, mbits: int, BS: int, OC: int, IC: int, device, dtype): - # # Randomly initialize each byte - # nbits = 1 + ebits + mbits - # floatx_weight = torch.randint(256, (OC, IC // 8 * nbits), dtype=torch.uint8) - # scale = torch.rand(OC).to(dtype) + 0.5 - # fp16_act = torch.rand(BS, IC).to(dtype) + 0.5 - # return floatx_weight.to(device), scale.to(device), fp16_act.to(device) - - # @pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available") - # @parametrize("ebits,mbits", [(3, 2), (2, 2)]) - # @parametrize("dtype", [torch.half, torch.bfloat16]) - # def test_quant_llm_linear(self, ebits, mbits, dtype): - # BS = 2 - # OC = 256 - # IC = 256 - # splitK = 1 - # floatx_weight, scale, fp16_act = self._create_floatx_inputs(ebits, mbits, BS, OC, IC, "cuda", dtype) - - # # smoke test - # torchao.ops.quant_llm_linear(ebits, mbits, fp16_act, floatx_weight, scale, splitK) - - # # comprehensive testing - # test_utils = ["test_schema", "test_autograd_registration", "test_faketensor", "test_aot_dispatch_dynamic"] - # 
opcheck(torch.ops.torchao.quant_llm_linear, (ebits, mbits, fp16_act, floatx_weight, scale, splitK), test_utils=test_utils) - - # @pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available") - # @parametrize("BS,OC,IC,splitK", [(1, 2048, 4096, 5), (2, 8192, 8192, 6)]) - # @parametrize("ebits,mbits", [(3, 2), (2, 2)]) - # @parametrize("dtype", [torch.half, torch.bfloat16]) - # def test_quant_llm_linear_correctness(self, ebits, mbits, BS, OC, IC, splitK, dtype): - # # adapted from https://github.com/usyd-fsalab/fp6_llm/blob/5df6737cca32f604e957e3f63f03ccc2e4d1df0d/tests/python/kernel_test_fpx.py - # floatx_weight, scale, fp16_act = self._create_floatx_inputs(ebits, mbits, BS, OC, IC, "cuda", dtype) - - # results_floatx = torchao.ops.quant_llm_linear(ebits, mbits, fp16_act, floatx_weight, scale, splitK) - - # fp16_weight = from_scaled_tc_floatx(floatx_weight, ebits, mbits, scale).to(dtype) - # results_fp16 = fp16_act @ fp16_weight.T - - # error = (results_floatx - results_fp16).abs().mean() - # gt = results_fp16.abs().mean() - # relative_error = error / gt - # rtol = 1e-2 if dtype == torch.bfloat16 else 1e-3 - # assert relative_error < rtol + self.assertEqual(actual, math_ref, atol=1.0, rtol=5e-6) + + def _create_floatx_inputs(self, ebits: int, mbits: int, BS: int, OC: int, IC: int, device, dtype): + # Randomly initialize each byte + nbits = 1 + ebits + mbits + floatx_weight = torch.randint(256, (OC, IC // 8 * nbits), dtype=torch.uint8) + scale = torch.rand(OC).to(dtype) + 0.5 + fp16_act = torch.rand(BS, IC).to(dtype) + 0.5 + return floatx_weight.to(device), scale.to(device), fp16_act.to(device) + + @pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available") + @parametrize("ebits,mbits", [(3, 2), (2, 2)]) + @parametrize("dtype", [torch.half, torch.bfloat16]) + def test_quant_llm_linear(self, ebits, mbits, dtype): + BS = 2 + OC = 256 + IC = 256 + splitK = 1 + floatx_weight, scale, fp16_act = self._create_floatx_inputs(ebits, mbits, BS, OC, IC, "cuda", dtype) + + # smoke test + torchao.ops.quant_llm_linear(ebits, mbits, fp16_act, floatx_weight, scale, splitK) + + # comprehensive testing + test_utils = ["test_schema", "test_autograd_registration", "test_faketensor", "test_aot_dispatch_dynamic"] + opcheck(torch.ops.torchao.quant_llm_linear, (ebits, mbits, fp16_act, floatx_weight, scale, splitK), test_utils=test_utils) + + @pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available") + @parametrize("BS,OC,IC,splitK", [(1, 2048, 4096, 5), (2, 8192, 8192, 6)]) + @parametrize("ebits,mbits", [(3, 2), (2, 2)]) + @parametrize("dtype", [torch.half, torch.bfloat16]) + def test_quant_llm_linear_correctness(self, ebits, mbits, BS, OC, IC, splitK, dtype): + # adapted from https://github.com/usyd-fsalab/fp6_llm/blob/5df6737cca32f604e957e3f63f03ccc2e4d1df0d/tests/python/kernel_test_fpx.py + floatx_weight, scale, fp16_act = self._create_floatx_inputs(ebits, mbits, BS, OC, IC, "cuda", dtype) + + results_floatx = torchao.ops.quant_llm_linear(ebits, mbits, fp16_act, floatx_weight, scale, splitK) + + fp16_weight = from_scaled_tc_floatx(floatx_weight, ebits, mbits, scale).to(dtype) + results_fp16 = fp16_act @ fp16_weight.T + + error = (results_floatx - results_fp16).abs().mean() + gt = results_fp16.abs().mean() + relative_error = error / gt + rtol = 1e-2 if dtype == torch.bfloat16 else 1e-3 + assert relative_error < rtol instantiate_parametrized_tests(TestOps) -# ## Tests for `tensor_core_layout` -# kTileSizeN = 8 -# kTileSizeK = 16 - -# SHAPES = [ -# (4096, 
4096), -# # Llama 2 GEMM shapes -# (4096, 11008), -# (11008, 4096), -# # Llama 3 GEMM shapes -# (4096, 14336), -# (14336, 4096), -# ] -# INNERKTILES = [2, 4, 8] -# QGROUP_SIZES = [32, 64, 128, 256] -# TEST_CONFIGS_UNPACK = list(itertools.product(SHAPES, INNERKTILES)) -# TEST_CONFIGS_DEQUANT = list(itertools.product(SHAPES, INNERKTILES, QGROUP_SIZES)) - -# @pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available") -# # @pytest.mark.skipif(TORCH_VERSION_AT_LEAST_2_5, reason="weight packing is updated in 2.5+") -# @pytest.mark.parametrize("shape, inner_k_tiles", TEST_CONFIGS_UNPACK, ids=str) -# def test_unpack_tensor_core_tiled_layout_correctness(shape, inner_k_tiles): -# N, K = shape -# assert K % (inner_k_tiles * kTileSizeK) == 0 and N % kTileSizeN == 0 - -# t = torch.randint(0, 16, dtype=torch.int, size=shape, device="cuda") -# if TORCH_VERSION_AT_LEAST_2_5: -# t = (t[::, ::2] << 4 | t[::, 1::2]).to(torch.uint8) -# packed_w = torch.ops.aten._convert_weight_to_int4pack(t, inner_k_tiles) -# unpacked = torchao.ops.unpack_tensor_core_tiled_layout(packed_w, inner_k_tiles) -# if TORCH_VERSION_AT_LEAST_2_5: -# unpacked = (unpacked[::, ::2] << 4 | unpacked[::, 1::2]).to(torch.uint8) -# assert torch.equal(t, unpacked) - -# # TODO: Fix "test_aot_dispatch_dynamic" test failure -# @pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available") -# # @pytest.mark.skipif(TORCH_VERSION_AT_LEAST_2_5, reason="weight packing is updated in 2.5+") -# @pytest.mark.parametrize("shape, inner_k_tiles", TEST_CONFIGS_UNPACK , ids=str) -# def test_unpack_tensor_core_tiled_layout_op(shape, inner_k_tiles): -# test_utils = [ -# "test_schema", -# "test_autograd_registration", -# "test_faketensor", -# ] - -# # TODO: Figure out why test fails unless torch >= 2.5 -# if TORCH_VERSION_AT_LEAST_2_5: -# test_utils.append("test_aot_dispatch_dynamic") - -# t = torch.randint(0, 16, dtype=torch.int, size=shape, device="cuda") -# if TORCH_VERSION_AT_LEAST_2_5: -# t = (t[::, ::2] << 4 | t[::, 1::2]).to(torch.uint8) -# packed_w = torch.ops.aten._convert_weight_to_int4pack(t, inner_k_tiles) - -# opcheck( -# torch.ops.torchao.unpack_tensor_core_tiled_layout, -# (packed_w, inner_k_tiles), -# test_utils=test_utils, -# ) - -# def dequant_ref(q, scales, zeros, group_size, nbits=4, dtype=torch.bfloat16): -# n, k = q.shape -# assert q.dtype == torch.int - -# n_groups = k // group_size -# assert scales.shape[0] == n and scales.shape[1] == n_groups -# assert scales.shape == zeros.shape - -# midpoint = 2 ** (nbits - 1) - -# #Convert fron u4 -> s4 and upcast to bfloat16 -# q = q.sub(midpoint).to(dtype) - -# # Dequantize -# q = q.reshape(-1, group_size) -# dq = q * scales.reshape(-1, 1) + zeros.reshape(-1, 1) - -# return dq.reshape(n, k) - - -# @pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available") -# # @pytest.mark.skipif(TORCH_VERSION_AT_LEAST_2_5, reason="weight packing is updated in 2.5+") -# @pytest.mark.parametrize("shape, inner_k_tiles, group_size", TEST_CONFIGS_DEQUANT, ids=str) -# def test_dequantize_tensor_core_tiled_layout_correctness_quant_dequant(shape, inner_k_tiles, group_size): -# n, k = shape -# dtype = torch.bfloat16 - -# device = "cuda" - -# t = torch.randn(n, k, dtype=dtype, device=device) -# scales, zeros = get_groupwise_affine_qparams(t, n_bit=4, groupsize=group_size, dtype=dtype) - -# # Quantize -# q = groupwise_affine_quantize_tensor_from_qparams( -# t, scales, zeros, n_bit=4, groupsize=group_size -# ) - -# # Pack to tensor core layout -# packed = 
torch.ops.aten._convert_weight_to_int4pack(q, inner_k_tiles) -# scales_and_zeros = pack_tinygemm_scales_and_zeros(scales, zeros) -# q_groups = k // group_size -# assert scales_and_zeros.shape == torch.Size([q_groups, n, 2]) - -# # Dequantize 'ao' ref -# dq_ao = groupwise_affine_dequantize_tensor_from_qparams( -# q, scales, zeros, n_bit=4, groupsize=group_size -# ) - -# # Dequantize by passing in an identity matrix as the activation -# a_eye = torch.eye(k, device=device, dtype=dtype) -# dq_id = torch.ops.aten._weight_int4pack_mm( -# a_eye, -# packed, -# group_size, -# scales_and_zeros, -# ).t() - -# # Actual operation to test -# dq_op = torchao.ops.dequantize_tensor_core_tiled_layout(packed, scales_and_zeros, group_size, inner_k_tiles) - -# # Compare results -# diff_ao_id = (dq_id - dq_ao).abs().max() -# diff_op_id = (dq_op - dq_id).abs().max() -# diff_op_ao = (dq_op - dq_ao).abs().max() - -# # There are slight numerical differences when dequantizing with an identity matrix when compared to `groupwise_affine_dequantize` -# # Since the `dequantize_tensor_core_layout` kernel relies on the same underlying bit twiddling tricks for fast -# # conversion from u4 -> s4 -> bf16, the identity matrix dequant hack and `dequantize_tensor_core_layout` are -# # expected to give same results, while both will have similar numerical differences to `groupwise_affine_dequantize`. - -# # Test that the `dequant` kernel gives same results as identity matrix-based dequant -# assert diff_op_id == 0 - -# # Test that the `dequant` kernel gives same numerical diffs as the `groupwise_affine_dequantize` when compared against the identity matrix -# assert diff_op_ao == diff_ao_id - -# assert diff_op_ao < 1e-1 - -# # This test differs from one above in that it uses `unpack_tensor_core_tiled_layout` to unpack then dequantize -# @pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available") -# # @pytest.mark.skipif(TORCH_VERSION_AT_LEAST_2_5, reason="weight packing is updated in 2.5+") -# @pytest.mark.parametrize("shape, inner_k_tiles, group_size", TEST_CONFIGS_DEQUANT, ids=str) -# def test_dequantize_tensor_core_tiled_layout_correctness_unpack_and_dequant(shape, inner_k_tiles, group_size): -# n, k = shape -# dtype = torch.bfloat16 -# device = "cuda" - -# # Quantize and pack -# t = torch.randn(n, k, dtype=dtype, device=device) -# scales, zeros = get_groupwise_affine_qparams(t, n_bit=4, groupsize=group_size, dtype=dtype) -# q = groupwise_affine_quantize_tensor_from_qparams( -# t, scales, zeros, n_bit=4, groupsize=group_size -# ) - -# packed = torch.ops.aten._convert_weight_to_int4pack(q, inner_k_tiles) -# scales_and_zeros = pack_tinygemm_scales_and_zeros(scales, zeros) - -# # Unpack and dequantize -# unpacked = torchao.ops.unpack_tensor_core_tiled_layout(packed, inner_k_tiles) -# if TORCH_VERSION_AT_LEAST_2_5: -# unpacked = (unpacked[::, ::2] << 4 | unpacked[::, 1::2]).to(torch.uint8) - -# dq_ao = groupwise_affine_dequantize_tensor_from_qparams( -# unpacked, scales, zeros, n_bit=4, groupsize=group_size -# ) - -# # Dequantize by passing in an identity matrix as the activation -# a_eye = torch.eye(k, device=device, dtype=dtype) -# dq_id = torch.ops.aten._weight_int4pack_mm( -# a_eye, -# packed, -# group_size, -# scales_and_zeros, -# ).t() - -# # Actual operation to test -# dq_op = torchao.ops.dequantize_tensor_core_tiled_layout(packed, scales_and_zeros, group_size, inner_k_tiles) - -# # Compare results -# diff_ao_id = (dq_id - dq_ao).abs().max() -# diff_op_id = (dq_op - dq_id).abs().max() -# diff_op_ao = 
(dq_op - dq_ao).abs().max() - -# # There are slight numerical differences when dequantizing with an identity matrix when compared to `groupwise_affine_dequantize` -# # Since the `dequantize_tensor_core_layout` kernel relies on the same underlying bit twiddling tricks for fast -# # conversion from u4 -> s4 -> bf16, the identity matrix dequant hack and `dequantize_tensor_core_layout` are -# # expected to give same results, while both will have similar numerical differences to `groupwise_affine_dequantize`. - -# # Test that the `dequant` kernel gives same results as identity matrix-based dequant -# assert diff_op_id == 0 - -# # Test that the `dequant` kernel gives same numerical diffs as the `groupwise_affine_dequantize` when compared against the identity matrix -# assert diff_op_ao == diff_ao_id - -# assert diff_op_ao < 1e-1 - -# @pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available") -# # @pytest.mark.skipif(TORCH_VERSION_AT_LEAST_2_5, reason="weight packing is updated in 2.5+") -# @pytest.mark.parametrize("shape, inner_k_tiles, group_size", TEST_CONFIGS_DEQUANT, ids=str) -# def test_dequantize_tensor_core_tiled_layout_op(shape, inner_k_tiles, group_size): -# n, k = shape -# device = "cuda" - -# q = torch.randint(0, 16, shape, dtype=torch.int, device=device) -# if TORCH_VERSION_AT_LEAST_2_5: -# q = (q[::, ::2] << 4 | q[::, 1::2]).to(torch.uint8) -# packed_w = torch._convert_weight_to_int4pack(q, inner_k_tiles) -# q_groups = k // group_size -# scales = torch.randn(n, q_groups, dtype=torch.bfloat16, device=device) -# zeros = torch.randn_like(scales) -# scales_and_zeros = pack_tinygemm_scales_and_zeros(scales, zeros) - -# test_utils = [ -# "test_schema", -# "test_autograd_registration", -# "test_faketensor", -# ] -# # TODO: Figure out why test fails unless torch >= 2.5 -# if TORCH_VERSION_AT_LEAST_2_5: -# test_utils.append("test_aot_dispatch_dynamic") -# opcheck( -# torch.ops.torchao.dequantize_tensor_core_tiled_layout, -# (packed_w, scales_and_zeros, group_size, inner_k_tiles), -# test_utils=test_utils, -# ) - - -# MARLIN_24_BATCH_SIZE = [1, 4, 8, 16, 32, 64] -# MARLIN_24_K_CHUNKS = [128] -# MARLIN_24_N_CHUNKS = [512] -# MNK_FACTORS = [ -# (1, 1, 1), -# (1, 4, 8), -# (1, 7, 5), -# (13, 17, 67), -# (26, 37, 13), -# (67, 13, 11), -# ] -# MARLIN_24_SUPPORTED_NUM_BITS = [4, 8] -# MARLIN_24_SUPPORTED_GROUP_SIZES = [-1, 128] - -# MARLIN_TEST_PARAMS = list(itertools.product( -# MARLIN_24_BATCH_SIZE, MARLIN_24_K_CHUNKS, MARLIN_24_N_CHUNKS, -# MARLIN_24_SUPPORTED_NUM_BITS, MARLIN_24_SUPPORTED_GROUP_SIZES, MNK_FACTORS -# )) - -# def _symmetric_quantize_with_ref(w: torch.Tensor, num_bits: int, group_size: int): -# orig_device = w.device -# size_k, size_n = w.shape - -# assert w.is_floating_point(), "w must be float" - -# if group_size == -1: -# group_size = size_k -# assert group_size <= size_k - -# max_q_val = 2**num_bits - 1 -# half_q_val = (max_q_val + 1) // 2 - -# # Reshape to [groupsize, -1] -# if group_size < size_k: -# w = w.reshape((-1, group_size, size_n)) -# w = w.permute(1, 0, 2) -# w = w.reshape((group_size, -1)) - -# # Compute scale for each group -# s = torch.max(torch.abs(w), 0, keepdim=True)[0] -# s *= 2 / max_q_val # 2 => symmetric - -# # Quantize -# q_w = torch.round(w / s).int() -# q_w += half_q_val -# q_w = torch.clamp(q_w, 0, max_q_val) - -# # Compute ref (dequantized) -# w_ref = (q_w - half_q_val).half() * s - -# # Restore original shapes -# if group_size < size_k: - -# def reshape_w(w): -# w = w.reshape((group_size, -1, size_n)) -# w = w.permute(1, 0, 2) -# 
w = w.reshape((size_k, size_n)).contiguous() -# return w - -# q_w = reshape_w(q_w) -# w_ref = reshape_w(w_ref) - -# s = s.reshape((-1, size_n)).contiguous() - -# return ( -# w_ref.to(device=orig_device), -# q_w.to(device=orig_device), -# s.to(device=orig_device), -# ) - -# @pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available") -# @pytest.mark.parametrize("batch_size, k_chunk, n_chunk, num_bits, group_size, mnk_factors", MARLIN_TEST_PARAMS, ids=str) -# def test_marlin_24(batch_size, k_chunk, n_chunk, num_bits, group_size, mnk_factors): -# m_factor, n_factor, k_factor = mnk_factors - -# size_m = m_factor -# size_k = k_chunk * k_factor -# size_n = n_chunk * n_factor - -# a_input = torch.randn((batch_size, size_m, size_k), dtype=torch.float16, device="cuda") -# b_weight = torch.rand((size_k, size_n), dtype=torch.float16, device="cuda") - -# # Inject 2:4 sparsity -# w_24, _ = inject_24(b_weight, size_k, size_n) - -# # Symmetric quantize -# w_24_ref, q_w_24, scale = _symmetric_quantize_with_ref(w_24, num_bits, group_size) - -# # Reshape input into 2D tensor -# input_2d = a_input.view(-1, a_input.shape[-1]) -# a_input_in, a_input_out = input_2d.shape - -# # Obtains reference output -# output_ref = torch.matmul(input_2d, w_24_ref) -# output_ref = output_ref.reshape(a_input.shape[:-1] + (scale.shape[1],)) - -# # Packs to marlin 2:4 -# marlin_24_q_w_comp, marlin_24_scale, meta = pack_to_marlin_24(q_w_24, scale, num_bits, group_size) -# workspace_24 = marlin_24_workspace(size_n) - -# fn_inputs = ( -# input_2d, marlin_24_q_w_comp, meta, marlin_24_scale, workspace_24, -# num_bits, a_input_in, marlin_24_scale.shape[1], a_input_out, -# ) -# output = torchao.ops.marlin_24_gemm(*fn_inputs) -# output = output.reshape(a_input.shape[:-1] + (marlin_24_scale.shape[1],)) - -# max_diff = compute_max_diff(output, output_ref) -# assert max_diff < 0.04 - -# # Performs opcheck -# test_utils = ["test_schema", "test_autograd_registration", "test_faketensor"] -# opcheck( -# torch.ops.torchao.marlin_24_gemm, -# fn_inputs, -# test_utils=test_utils, -# ) - - -# MARLIN_QQQ_BATCH_SIZE = [1, 4, 8, 16, 32, 64] -# MARLIN_QQQ_K_CHUNKS = [128] -# MARLIN_QQQ_N_CHUNKS = [64, 128, 256] -# MNK_FACTORS = [ -# (1, 1, 1), -# (1, 4, 8), -# (1, 7, 5), -# (13, 17, 67), -# (26, 37, 13), -# (67, 13, 11), -# ] -# MARLIN_QQQ_SUPPORTED_NUM_BITS = [4] -# MARLIN_QQQ_SUPPORTED_GROUP_SIZES = [-1, 128] - -# MARLIN_TEST_PARAMS = list( -# itertools.product( -# MARLIN_QQQ_BATCH_SIZE, -# MARLIN_QQQ_K_CHUNKS, -# MARLIN_QQQ_N_CHUNKS, -# MARLIN_QQQ_SUPPORTED_NUM_BITS, -# MARLIN_QQQ_SUPPORTED_GROUP_SIZES, -# MNK_FACTORS, -# ) -# ) - - -# @pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available") -# @pytest.mark.parametrize( -# "batch_size, k_chunk, n_chunk, num_bits, group_size, mnk_factors", -# MARLIN_TEST_PARAMS, -# ids=str, -# ) -# @pytest.mark.skip(reason="test outputs nan after cuda is upgraded to 12.4") -# def test_marlin_qqq(batch_size, k_chunk, n_chunk, num_bits, group_size, mnk_factors): -# int8_traits = torch.iinfo(torch.int8) -# m_factor, n_factor, k_factor = mnk_factors - -# size_m = m_factor -# size_k = k_chunk * k_factor -# size_n = n_chunk * n_factor - -# a_input = torch.randn( -# (batch_size, size_m, size_k), dtype=torch.float16, device="cuda" -# ) -# b_weight = torch.rand((size_n, size_k), dtype=torch.float16, device="cuda") - -# # Reshape input into 2D tensor -# input_2d = a_input.view(-1, a_input.shape[-1]) -# a_input_in, a_input_out = input_2d.shape - -# # Quantize activations -# s_a = ( -# 
input_2d.abs() -# .max(dim=-1, keepdim=True)[0] -# .div(int8_traits.max) -# .to(torch.float32) -# ) -# q_a = ( -# (input_2d / s_a).round().clamp(int8_traits.min, int8_traits.max).to(torch.int8) -# ) - -# # Quantize weights -# q_w, s_group, s_channel, w_ref = choose_qparams_and_quantize_affine_qqq( -# b_weight, num_bits, group_size -# ) -# q_w = q_w.t() -# s_group = s_group.t() -# s_channel = s_channel.t() -# w_ref = w_ref.t() -# marlin_qqq_q_w, marlin_qqq_s_group, marlin_qqq_s_channel = pack_to_marlin_qqq( -# q_w, s_group, s_channel, num_bits, group_size -# ) - -# workspace = marlin_qqq_workspace(size_n) - -# # Obtains reference output -# output_ref = torch.matmul(q_a.half() * s_a.half(), w_ref) -# output_ref = output_ref.reshape(a_input.shape[:-1] + (size_n,)) - -# fn_inputs = ( -# q_a, -# marlin_qqq_q_w, -# s_a, -# marlin_qqq_s_channel, -# marlin_qqq_s_group, -# workspace, -# a_input_in, -# size_n, -# a_input_out, -# ) -# output = torchao.ops.marlin_qqq_gemm(*fn_inputs) -# output = output.reshape(a_input.shape[:-1] + (size_n,)) - -# max_diff = compute_max_diff(output, output_ref) -# assert max_diff < 0.04 - -# # Performs opcheck -# test_utils = ["test_schema", "test_autograd_registration", "test_faketensor"] -# opcheck( -# torch.ops.torchao.marlin_qqq_gemm, -# fn_inputs, -# test_utils=test_utils, -# ) +## Tests for `tensor_core_layout` +kTileSizeN = 8 +kTileSizeK = 16 + +SHAPES = [ + (4096, 4096), + # Llama 2 GEMM shapes + (4096, 11008), + (11008, 4096), + # Llama 3 GEMM shapes + (4096, 14336), + (14336, 4096), +] +INNERKTILES = [2, 4, 8] +QGROUP_SIZES = [32, 64, 128, 256] +TEST_CONFIGS_UNPACK = list(itertools.product(SHAPES, INNERKTILES)) +TEST_CONFIGS_DEQUANT = list(itertools.product(SHAPES, INNERKTILES, QGROUP_SIZES)) + +@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available") +# @pytest.mark.skipif(TORCH_VERSION_AT_LEAST_2_5, reason="weight packing is updated in 2.5+") +@pytest.mark.parametrize("shape, inner_k_tiles", TEST_CONFIGS_UNPACK, ids=str) +def test_unpack_tensor_core_tiled_layout_correctness(shape, inner_k_tiles): + N, K = shape + assert K % (inner_k_tiles * kTileSizeK) == 0 and N % kTileSizeN == 0 + + t = torch.randint(0, 16, dtype=torch.int, size=shape, device="cuda") + if TORCH_VERSION_AT_LEAST_2_5: + t = (t[::, ::2] << 4 | t[::, 1::2]).to(torch.uint8) + packed_w = torch.ops.aten._convert_weight_to_int4pack(t, inner_k_tiles) + unpacked = torchao.ops.unpack_tensor_core_tiled_layout(packed_w, inner_k_tiles) + if TORCH_VERSION_AT_LEAST_2_5: + unpacked = (unpacked[::, ::2] << 4 | unpacked[::, 1::2]).to(torch.uint8) + assert torch.equal(t, unpacked) + +# TODO: Fix "test_aot_dispatch_dynamic" test failure +@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available") +# @pytest.mark.skipif(TORCH_VERSION_AT_LEAST_2_5, reason="weight packing is updated in 2.5+") +@pytest.mark.parametrize("shape, inner_k_tiles", TEST_CONFIGS_UNPACK , ids=str) +def test_unpack_tensor_core_tiled_layout_op(shape, inner_k_tiles): + test_utils = [ + "test_schema", + "test_autograd_registration", + "test_faketensor", + ] + + # TODO: Figure out why test fails unless torch >= 2.5 + if TORCH_VERSION_AT_LEAST_2_5: + test_utils.append("test_aot_dispatch_dynamic") + + t = torch.randint(0, 16, dtype=torch.int, size=shape, device="cuda") + if TORCH_VERSION_AT_LEAST_2_5: + t = (t[::, ::2] << 4 | t[::, 1::2]).to(torch.uint8) + packed_w = torch.ops.aten._convert_weight_to_int4pack(t, inner_k_tiles) + + opcheck( + torch.ops.torchao.unpack_tensor_core_tiled_layout, + 
(packed_w, inner_k_tiles), + test_utils=test_utils, + ) + +def dequant_ref(q, scales, zeros, group_size, nbits=4, dtype=torch.bfloat16): + n, k = q.shape + assert q.dtype == torch.int + + n_groups = k // group_size + assert scales.shape[0] == n and scales.shape[1] == n_groups + assert scales.shape == zeros.shape + + midpoint = 2 ** (nbits - 1) + + #Convert fron u4 -> s4 and upcast to bfloat16 + q = q.sub(midpoint).to(dtype) + + # Dequantize + q = q.reshape(-1, group_size) + dq = q * scales.reshape(-1, 1) + zeros.reshape(-1, 1) + + return dq.reshape(n, k) + + +@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available") +# @pytest.mark.skipif(TORCH_VERSION_AT_LEAST_2_5, reason="weight packing is updated in 2.5+") +@pytest.mark.parametrize("shape, inner_k_tiles, group_size", TEST_CONFIGS_DEQUANT, ids=str) +def test_dequantize_tensor_core_tiled_layout_correctness_quant_dequant(shape, inner_k_tiles, group_size): + n, k = shape + dtype = torch.bfloat16 + + device = "cuda" + + t = torch.randn(n, k, dtype=dtype, device=device) + scales, zeros = get_groupwise_affine_qparams(t, n_bit=4, groupsize=group_size, dtype=dtype) + + # Quantize + q = groupwise_affine_quantize_tensor_from_qparams( + t, scales, zeros, n_bit=4, groupsize=group_size + ) + + # Pack to tensor core layout + packed = torch.ops.aten._convert_weight_to_int4pack(q, inner_k_tiles) + scales_and_zeros = pack_tinygemm_scales_and_zeros(scales, zeros) + q_groups = k // group_size + assert scales_and_zeros.shape == torch.Size([q_groups, n, 2]) + + # Dequantize 'ao' ref + dq_ao = groupwise_affine_dequantize_tensor_from_qparams( + q, scales, zeros, n_bit=4, groupsize=group_size + ) + + # Dequantize by passing in an identity matrix as the activation + a_eye = torch.eye(k, device=device, dtype=dtype) + dq_id = torch.ops.aten._weight_int4pack_mm( + a_eye, + packed, + group_size, + scales_and_zeros, + ).t() + + # Actual operation to test + dq_op = torchao.ops.dequantize_tensor_core_tiled_layout(packed, scales_and_zeros, group_size, inner_k_tiles) + + # Compare results + diff_ao_id = (dq_id - dq_ao).abs().max() + diff_op_id = (dq_op - dq_id).abs().max() + diff_op_ao = (dq_op - dq_ao).abs().max() + + # There are slight numerical differences when dequantizing with an identity matrix when compared to `groupwise_affine_dequantize` + # Since the `dequantize_tensor_core_layout` kernel relies on the same underlying bit twiddling tricks for fast + # conversion from u4 -> s4 -> bf16, the identity matrix dequant hack and `dequantize_tensor_core_layout` are + # expected to give same results, while both will have similar numerical differences to `groupwise_affine_dequantize`. 
+ + # Test that the `dequant` kernel gives same results as identity matrix-based dequant + assert diff_op_id == 0 + + # Test that the `dequant` kernel gives same numerical diffs as the `groupwise_affine_dequantize` when compared against the identity matrix + assert diff_op_ao == diff_ao_id + + assert diff_op_ao < 1e-1 + +# This test differs from one above in that it uses `unpack_tensor_core_tiled_layout` to unpack then dequantize +@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available") +# @pytest.mark.skipif(TORCH_VERSION_AT_LEAST_2_5, reason="weight packing is updated in 2.5+") +@pytest.mark.parametrize("shape, inner_k_tiles, group_size", TEST_CONFIGS_DEQUANT, ids=str) +def test_dequantize_tensor_core_tiled_layout_correctness_unpack_and_dequant(shape, inner_k_tiles, group_size): + n, k = shape + dtype = torch.bfloat16 + device = "cuda" + + # Quantize and pack + t = torch.randn(n, k, dtype=dtype, device=device) + scales, zeros = get_groupwise_affine_qparams(t, n_bit=4, groupsize=group_size, dtype=dtype) + q = groupwise_affine_quantize_tensor_from_qparams( + t, scales, zeros, n_bit=4, groupsize=group_size + ) + + packed = torch.ops.aten._convert_weight_to_int4pack(q, inner_k_tiles) + scales_and_zeros = pack_tinygemm_scales_and_zeros(scales, zeros) + + # Unpack and dequantize + unpacked = torchao.ops.unpack_tensor_core_tiled_layout(packed, inner_k_tiles) + if TORCH_VERSION_AT_LEAST_2_5: + unpacked = (unpacked[::, ::2] << 4 | unpacked[::, 1::2]).to(torch.uint8) + + dq_ao = groupwise_affine_dequantize_tensor_from_qparams( + unpacked, scales, zeros, n_bit=4, groupsize=group_size + ) + + # Dequantize by passing in an identity matrix as the activation + a_eye = torch.eye(k, device=device, dtype=dtype) + dq_id = torch.ops.aten._weight_int4pack_mm( + a_eye, + packed, + group_size, + scales_and_zeros, + ).t() + + # Actual operation to test + dq_op = torchao.ops.dequantize_tensor_core_tiled_layout(packed, scales_and_zeros, group_size, inner_k_tiles) + + # Compare results + diff_ao_id = (dq_id - dq_ao).abs().max() + diff_op_id = (dq_op - dq_id).abs().max() + diff_op_ao = (dq_op - dq_ao).abs().max() + + # There are slight numerical differences when dequantizing with an identity matrix when compared to `groupwise_affine_dequantize` + # Since the `dequantize_tensor_core_layout` kernel relies on the same underlying bit twiddling tricks for fast + # conversion from u4 -> s4 -> bf16, the identity matrix dequant hack and `dequantize_tensor_core_layout` are + # expected to give same results, while both will have similar numerical differences to `groupwise_affine_dequantize`. 
+ + # Test that the `dequant` kernel gives same results as identity matrix-based dequant + assert diff_op_id == 0 + + # Test that the `dequant` kernel gives same numerical diffs as the `groupwise_affine_dequantize` when compared against the identity matrix + assert diff_op_ao == diff_ao_id + + assert diff_op_ao < 1e-1 + +@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available") +# @pytest.mark.skipif(TORCH_VERSION_AT_LEAST_2_5, reason="weight packing is updated in 2.5+") +@pytest.mark.parametrize("shape, inner_k_tiles, group_size", TEST_CONFIGS_DEQUANT, ids=str) +def test_dequantize_tensor_core_tiled_layout_op(shape, inner_k_tiles, group_size): + n, k = shape + device = "cuda" + + q = torch.randint(0, 16, shape, dtype=torch.int, device=device) + if TORCH_VERSION_AT_LEAST_2_5: + q = (q[::, ::2] << 4 | q[::, 1::2]).to(torch.uint8) + packed_w = torch._convert_weight_to_int4pack(q, inner_k_tiles) + q_groups = k // group_size + scales = torch.randn(n, q_groups, dtype=torch.bfloat16, device=device) + zeros = torch.randn_like(scales) + scales_and_zeros = pack_tinygemm_scales_and_zeros(scales, zeros) + + test_utils = [ + "test_schema", + "test_autograd_registration", + "test_faketensor", + ] + # TODO: Figure out why test fails unless torch >= 2.5 + if TORCH_VERSION_AT_LEAST_2_5: + test_utils.append("test_aot_dispatch_dynamic") + opcheck( + torch.ops.torchao.dequantize_tensor_core_tiled_layout, + (packed_w, scales_and_zeros, group_size, inner_k_tiles), + test_utils=test_utils, + ) + + +MARLIN_24_BATCH_SIZE = [1, 4, 8, 16, 32, 64] +MARLIN_24_K_CHUNKS = [128] +MARLIN_24_N_CHUNKS = [512] +MNK_FACTORS = [ + (1, 1, 1), + (1, 4, 8), + (1, 7, 5), + (13, 17, 67), + (26, 37, 13), + (67, 13, 11), +] +MARLIN_24_SUPPORTED_NUM_BITS = [4, 8] +MARLIN_24_SUPPORTED_GROUP_SIZES = [-1, 128] + +MARLIN_TEST_PARAMS = list(itertools.product( + MARLIN_24_BATCH_SIZE, MARLIN_24_K_CHUNKS, MARLIN_24_N_CHUNKS, + MARLIN_24_SUPPORTED_NUM_BITS, MARLIN_24_SUPPORTED_GROUP_SIZES, MNK_FACTORS +)) + +def _symmetric_quantize_with_ref(w: torch.Tensor, num_bits: int, group_size: int): + orig_device = w.device + size_k, size_n = w.shape + + assert w.is_floating_point(), "w must be float" + + if group_size == -1: + group_size = size_k + assert group_size <= size_k + + max_q_val = 2**num_bits - 1 + half_q_val = (max_q_val + 1) // 2 + + # Reshape to [groupsize, -1] + if group_size < size_k: + w = w.reshape((-1, group_size, size_n)) + w = w.permute(1, 0, 2) + w = w.reshape((group_size, -1)) + + # Compute scale for each group + s = torch.max(torch.abs(w), 0, keepdim=True)[0] + s *= 2 / max_q_val # 2 => symmetric + + # Quantize + q_w = torch.round(w / s).int() + q_w += half_q_val + q_w = torch.clamp(q_w, 0, max_q_val) + + # Compute ref (dequantized) + w_ref = (q_w - half_q_val).half() * s + + # Restore original shapes + if group_size < size_k: + + def reshape_w(w): + w = w.reshape((group_size, -1, size_n)) + w = w.permute(1, 0, 2) + w = w.reshape((size_k, size_n)).contiguous() + return w + + q_w = reshape_w(q_w) + w_ref = reshape_w(w_ref) + + s = s.reshape((-1, size_n)).contiguous() + + return ( + w_ref.to(device=orig_device), + q_w.to(device=orig_device), + s.to(device=orig_device), + ) + +@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available") +@pytest.mark.parametrize("batch_size, k_chunk, n_chunk, num_bits, group_size, mnk_factors", MARLIN_TEST_PARAMS, ids=str) +def test_marlin_24(batch_size, k_chunk, n_chunk, num_bits, group_size, mnk_factors): + m_factor, n_factor, k_factor = mnk_factors + + 
size_m = m_factor + size_k = k_chunk * k_factor + size_n = n_chunk * n_factor + + a_input = torch.randn((batch_size, size_m, size_k), dtype=torch.float16, device="cuda") + b_weight = torch.rand((size_k, size_n), dtype=torch.float16, device="cuda") + + # Inject 2:4 sparsity + w_24, _ = inject_24(b_weight, size_k, size_n) + + # Symmetric quantize + w_24_ref, q_w_24, scale = _symmetric_quantize_with_ref(w_24, num_bits, group_size) + + # Reshape input into 2D tensor + input_2d = a_input.view(-1, a_input.shape[-1]) + a_input_in, a_input_out = input_2d.shape + + # Obtains reference output + output_ref = torch.matmul(input_2d, w_24_ref) + output_ref = output_ref.reshape(a_input.shape[:-1] + (scale.shape[1],)) + + # Packs to marlin 2:4 + marlin_24_q_w_comp, marlin_24_scale, meta = pack_to_marlin_24(q_w_24, scale, num_bits, group_size) + workspace_24 = marlin_24_workspace(size_n) + + fn_inputs = ( + input_2d, marlin_24_q_w_comp, meta, marlin_24_scale, workspace_24, + num_bits, a_input_in, marlin_24_scale.shape[1], a_input_out, + ) + output = torchao.ops.marlin_24_gemm(*fn_inputs) + output = output.reshape(a_input.shape[:-1] + (marlin_24_scale.shape[1],)) + + max_diff = compute_max_diff(output, output_ref) + assert max_diff < 0.04 + + # Performs opcheck + test_utils = ["test_schema", "test_autograd_registration", "test_faketensor"] + opcheck( + torch.ops.torchao.marlin_24_gemm, + fn_inputs, + test_utils=test_utils, + ) + + +MARLIN_QQQ_BATCH_SIZE = [1, 4, 8, 16, 32, 64] +MARLIN_QQQ_K_CHUNKS = [128] +MARLIN_QQQ_N_CHUNKS = [64, 128, 256] +MNK_FACTORS = [ + (1, 1, 1), + (1, 4, 8), + (1, 7, 5), + (13, 17, 67), + (26, 37, 13), + (67, 13, 11), +] +MARLIN_QQQ_SUPPORTED_NUM_BITS = [4] +MARLIN_QQQ_SUPPORTED_GROUP_SIZES = [-1, 128] + +MARLIN_TEST_PARAMS = list( + itertools.product( + MARLIN_QQQ_BATCH_SIZE, + MARLIN_QQQ_K_CHUNKS, + MARLIN_QQQ_N_CHUNKS, + MARLIN_QQQ_SUPPORTED_NUM_BITS, + MARLIN_QQQ_SUPPORTED_GROUP_SIZES, + MNK_FACTORS, + ) +) + + +@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available") +@pytest.mark.parametrize( + "batch_size, k_chunk, n_chunk, num_bits, group_size, mnk_factors", + MARLIN_TEST_PARAMS, + ids=str, +) +@pytest.mark.skip(reason="test outputs nan after cuda is upgraded to 12.4") +def test_marlin_qqq(batch_size, k_chunk, n_chunk, num_bits, group_size, mnk_factors): + int8_traits = torch.iinfo(torch.int8) + m_factor, n_factor, k_factor = mnk_factors + + size_m = m_factor + size_k = k_chunk * k_factor + size_n = n_chunk * n_factor + + a_input = torch.randn( + (batch_size, size_m, size_k), dtype=torch.float16, device="cuda" + ) + b_weight = torch.rand((size_n, size_k), dtype=torch.float16, device="cuda") + + # Reshape input into 2D tensor + input_2d = a_input.view(-1, a_input.shape[-1]) + a_input_in, a_input_out = input_2d.shape + + # Quantize activations + s_a = ( + input_2d.abs() + .max(dim=-1, keepdim=True)[0] + .div(int8_traits.max) + .to(torch.float32) + ) + q_a = ( + (input_2d / s_a).round().clamp(int8_traits.min, int8_traits.max).to(torch.int8) + ) + + # Quantize weights + q_w, s_group, s_channel, w_ref = choose_qparams_and_quantize_affine_qqq( + b_weight, num_bits, group_size + ) + q_w = q_w.t() + s_group = s_group.t() + s_channel = s_channel.t() + w_ref = w_ref.t() + marlin_qqq_q_w, marlin_qqq_s_group, marlin_qqq_s_channel = pack_to_marlin_qqq( + q_w, s_group, s_channel, num_bits, group_size + ) + + workspace = marlin_qqq_workspace(size_n) + + # Obtains reference output + output_ref = torch.matmul(q_a.half() * s_a.half(), w_ref) + output_ref = 
output_ref.reshape(a_input.shape[:-1] + (size_n,)) + + fn_inputs = ( + q_a, + marlin_qqq_q_w, + s_a, + marlin_qqq_s_channel, + marlin_qqq_s_group, + workspace, + a_input_in, + size_n, + a_input_out, + ) + output = torchao.ops.marlin_qqq_gemm(*fn_inputs) + output = output.reshape(a_input.shape[:-1] + (size_n,)) + + max_diff = compute_max_diff(output, output_ref) + assert max_diff < 0.04 + + # Performs opcheck + test_utils = ["test_schema", "test_autograd_registration", "test_faketensor"] + opcheck( + torch.ops.torchao.marlin_qqq_gemm, + fn_inputs, + test_utils=test_utils, + ) if __name__ == "__main__": diff --git a/torchao/csrc/cpu/sdpa.cpp b/torchao/csrc/cpu/sdpa.cpp index 229ef3433e..3357608db5 100644 --- a/torchao/csrc/cpu/sdpa.cpp +++ b/torchao/csrc/cpu/sdpa.cpp @@ -37,11 +37,17 @@ struct is_reduced_floating_point: template constexpr bool is_reduced_floating_point_v = is_reduced_floating_point::value; +inline double calculate_scale( + const at::Tensor& query, + double scale) { + return scale == 0.0 ? 1.0 / std::sqrt(query.size(-1)) : scale; +} + // out = val * a + b // is_b_stride_zero: If the stride of b is 0 (mask broadcasting case), // take b as a scalar pointer. template -void _scale_attn_mask_fusion_kernel( +inline void _scale_attn_mask_fusion_kernel( T1* a, T2* b, const int& size, @@ -81,7 +87,7 @@ void _scale_attn_mask_fusion_kernel( // 1) out = exp(a - val) // 2) val = sum(out) template -void _exp_reduce_sum_fusion_kernel( +inline void _exp_reduce_sum_fusion_kernel( T1* a, const int& size, T2* out, @@ -115,7 +121,7 @@ void _exp_reduce_sum_fusion_kernel( // 1) out = a * scale // 2) max = max(out) template -void _mul_reduce_max_fusion_kernel( +inline void _mul_reduce_max_fusion_kernel( const scalar_t* a, const scalar_t& scale, const int& size, @@ -137,30 +143,23 @@ void _mul_reduce_max_fusion_kernel( tmp_max = std::max(tmp_max, tmp1); out[i] = tmp1; } - // max = std::max( - // tmp_max, - // at::vec::vec_reduce_all( - // [](vec::Vectorized& x, at::vec::Vectorized& y) { - // return at::vec::maximum(x, y); - // }, - // vec_tmp_max)); max = std::max(tmp_max, vec_tmp_max.reduce_max()); } template -static scalar_t* conditional_data_ptr(scalar_t* ptr, scalar_t* ptr2) { +static inline scalar_t* conditional_data_ptr(scalar_t* ptr, scalar_t* ptr2) { TORCH_CHECK(ptr2 == nullptr); return ptr; } template , int> = 0> -static scalar_t* conditional_data_ptr(float* ptr, scalar_t* ptr2) { +static inline scalar_t* conditional_data_ptr(float* ptr, scalar_t* ptr2) { return ptr2; } template -void fill_stub(scalar_t* data, scalar_t val, int64_t size) { +inline void fill_stub(scalar_t* data, scalar_t val, int64_t size) { using Vec = at::vec::Vectorized; Vec data_vec = Vec(val); int64_t d = 0; @@ -202,26 +201,26 @@ void reshape_attn_mask_to_4d( // TODO: Use at::native::_store instead when it supports Half. 
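+// The _store overloads below write a vector register back to scalar_t memory;
+// the enable_if variants first convert the float accumulator lanes to the
+// narrower destination element type.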
template -void _store(scalar_t* dst, at::vec::Vectorized src, int size=at::vec::Vectorized::size()) { +inline void _store(scalar_t* dst, at::vec::Vectorized src, int size=at::vec::Vectorized::size()) { src.store(dst, size); } template -typename std::enable_if_t, void> +inline typename std::enable_if_t, void> _store(scalar_t* dst, at::vec::Vectorized src) { auto res = at::vec::convert_from_float(src, src); res.store(dst, at::vec::Vectorized::size()); } template -typename std::enable_if_t || std::is_same_v, void> +inline typename std::enable_if_t || std::is_same_v, void> _store(scalar_t* dst, at::vec::Vectorized src) { auto res = at::vec::convert(src); res.store(dst, at::vec::Vectorized::size()); } template -void pad_row_zero( +inline void pad_row_zero( scalar_t* value_ptr, scalar_t* padding_value_ptr, int rows, @@ -258,7 +257,7 @@ void pad_row_zero( } template -void pad_row_128_padding( +inline void pad_row_128_padding( scalar_t* value_ptr, scalar_t* padding_value_ptr, int rows, @@ -298,7 +297,7 @@ void pad_row_128_padding( } template -void pad_col_zero( +inline void pad_col_zero( scalar_t* value_ptr, scalar_t* padding_value_ptr, int rows, @@ -322,7 +321,7 @@ void pad_col_zero( } template -void pad_col_zero_padding( +inline void pad_col_zero_padding( scalar_t* value_ptr, scalar_t* padding_value_ptr, int rows, @@ -352,7 +351,7 @@ void pad_col_zero_padding( 3. max reduce for softmax */ template -void _dequant_mask_max_fusion_kernel( +inline void _dequant_mask_max_fusion_kernel( const int32_t* in, const mask_t* mask_ptr, const int32_t* sum_a_ptr, @@ -414,7 +413,7 @@ void _dequant_mask_max_fusion_kernel( 1. dequant 2. max reduce for softmax */ -void _dequant_max_fusion_kernel( +inline void _dequant_max_fusion_kernel( const int32_t* in, const int32_t* sum_a_ptr, const int32_t* sum_b_ptr, @@ -469,7 +468,7 @@ void _dequant_max_fusion_kernel( 3. sum for attention */ template -void _sub_exp_sum_div_quant_sum_fusion_kernel( +inline void _sub_exp_sum_div_quant_sum_fusion_kernel( const float* in, const int64_t& M, const int64_t& N_step, @@ -572,7 +571,7 @@ void _sub_exp_sum_div_quant_sum_fusion_kernel( } template -void _sub_exp_sum_div_quant_fusion_kernel( +inline void _sub_exp_sum_div_quant_fusion_kernel( const float* in, const int64_t& M, const int64_t& N_step, @@ -669,7 +668,7 @@ void _sub_exp_sum_div_quant_fusion_kernel( 2. 
quant */ template -void _dequant_quant_fusion_kernel( +inline void _dequant_quant_fusion_kernel( const int32_t* in, const int32_t* sum_a_ptr, const int32_t* sum_b_ptr, @@ -727,7 +726,7 @@ void _dequant_quant_fusion_kernel( } template -void _int_sum_b_contiguous_kernel_helper( +inline void _int_sum_b_contiguous_kernel_helper( const scalar_t* in, int32_t* out, const int& N, @@ -742,14 +741,13 @@ void _int_sum_b_contiguous_kernel_helper( } tmp_sum += vec_tmp_sum.reduce_add(); for (long i = vec_size * (N / vec_size); i < N; i++) { - // for (long i = 0; i < N; i++) { tmp_sum += static_cast(in[i]); } out[0] = tmp_sum * scale; } template -void _int_sum_b_contiguous_kernel( +inline void _int_sum_b_contiguous_kernel( const scalar_t* in, int32_t* out, const int& M, @@ -762,7 +760,7 @@ void _int_sum_b_contiguous_kernel( } template -void _int_sum_a_contiguous_kernel( +inline void _int_sum_a_contiguous_kernel( const scalar_t* in, int32_t* out, const int& M, @@ -791,7 +789,6 @@ void _int_sum_a_contiguous_kernel( _store(out + i, tmp3); } for (long i = vec_size * (M / vec_size); i < M; i++) { - // for (long i = 0; i < M; i++) { auto tmp0 = tmp_in[i]; auto tmp1 = out[i]; auto tmp2 = static_cast(tmp0); @@ -812,7 +809,7 @@ void _int_sum_a_contiguous_kernel( } } -void do_convert_u8_s8( +inline void do_convert_u8_s8( unsigned char* src, signed char* dst, int64_t in_rows, @@ -832,7 +829,6 @@ void do_convert_u8_s8( _store(tmp_dst + c, tmp3, vec_size); } for (int64_t c = vec_size * (in_cols / vec_size); c < in_cols; c++) { - // for (int64_t c = 0; c < in_cols; c++) { auto tmp0 = tmp_src[c]; auto tmp1 = (int16_t) tmp0; auto tmp2 = tmp1 - 128; @@ -843,7 +839,7 @@ void do_convert_u8_s8( } template -void do_transpose( +inline void do_transpose( scalar_t* src, scalar_t* dst, int64_t in_rows, @@ -858,7 +854,7 @@ void do_transpose( } template -void do_copy( +inline void do_copy( scalar_t* src, scalar_t* dst, int64_t in_rows, @@ -873,7 +869,7 @@ void do_copy( } template -void pad_remain_row_col( +inline void pad_remain_row_col( scalar_t* value_ptr, int rows, int cols, @@ -911,7 +907,7 @@ void pad_remain_row_col( } template -void copy_value_with_pad( +inline void copy_value_with_pad( scalar_t* value_ptr, scalar_t* dst_ptr, int rows, @@ -964,64 +960,9 @@ void copy_value_with_pad( } -// thread_local std::unordered_map< -// BrgemmKey, -// std::shared_ptr> cache_brgemm_kernels; - -// thread_local std::unordered_map< -// PackBKey, -// std::shared_ptr> cache_packb_kernels; - -// std::shared_ptr create_or_get_microkernel( -// int64_t M, -// int64_t N, -// int64_t K, -// int64_t batch_size, -// int lda, -// int ldb, -// int ldc, -// dt dt_a, -// dt dt_b, -// dt dt_c) { -// BrgemmKey key_brgemm(M, N, K, batch_size, lda, ldb, ldc, dt_a, dt_b, dt_c); -// auto search = cache_brgemm_kernels.find(key_brgemm); -// if (search != cache_brgemm_kernels.end()) { -// return search->second; -// } else { -// cache_brgemm_kernels.insert( -// {key_brgemm, -// std::make_shared( -// M, N, K, batch_size, lda, ldb, ldc, dt_a, dt_b, dt_c)}); -// return cache_brgemm_kernels[key_brgemm]; -// } -// } - -// std::shared_ptr create_or_get_packb_microkernel( -// int64_t K, -// int64_t N, -// int ld_in, -// int ld_out, -// dt dt_in, -// dt dt_out, -// bool do_trans) { -// PackBKey key_packb(K, N, ld_in, ld_out, dt_in, dt_out); -// auto search = cache_packb_kernels.find(key_packb); -// if (search != cache_packb_kernels.end()) { -// return search->second; -// } else { -// cache_packb_kernels.insert( -// {key_packb, -// std::make_shared( -// K, N, -// do_trans ? 
dnnl::ukernel::pack_type::trans : dnnl::ukernel::pack_type::no_trans, -// ld_in, ld_out, dt_in, dt_out)}); -// return cache_packb_kernels[key_packb]; -// } -// } - // UINT8 - u8u8s32 template -typename std::enable_if_t, void> +inline typename std::enable_if_t, void> sdpa_int8_kernel_impl( const at::Tensor& output, const at::Tensor& q, @@ -1041,10 +982,6 @@ sdpa_int8_kernel_impl( float a_scale, int32_t o_zp, float o_scale) { - // using dt = dnnl::memory::data_type; - // using namespace dnnl; - // using namespace dnnl::ukernel; - // auto starts = duration_cast(system_clock::now().time_since_epoch()).count(); // Query (Batch x Num_heads x Q_seq_len x Dim_per_head) // -> (Batch x Q_seq_len x Num_heads x Dim_per_head) // Key (Batch x Num_heads x KV_seq_len x Dim_per_head) @@ -1055,15 +992,11 @@ sdpa_int8_kernel_impl( at::Tensor key = k.transpose(1, 2); at::Tensor value = v.transpose(1, 2); - const auto accumulate_dtype = at::kFloat; // at::toOpMathType(dtype); + const auto accumulate_dtype = at::kFloat; - using accum_t = float; // at::opmath_type; + using accum_t = float; using Vec = at::vec::Vectorized; - accum_t scaling_factor = - sdp::calculate_scale(query, scale).as_float_unchecked(); - // if (attention_mask.defined() && attention_mask.scalar_type() != ScalarType::Float) { - // attention_mask = attention_mask.to(at::kFloat); - // } + accum_t scaling_factor = calculate_scale(query, scale); int block_64 = 64; // Sizes TORCH_CHECK( @@ -1150,11 +1083,6 @@ sdpa_int8_kernel_impl( int av_gemm_K_tail_padding = av_gemm_K_tail_mul4 ? 0 : 4 - kvTail % 4; int av_gemm_K_tail = kvTail + av_gemm_K_tail_padding; - - // dt u8_dt = dt::u8; - // dt s8_dt = dt::s8; - // // dt f32_dt = dt::f32; - // dt s32_dt = dt::s32; auto u8_dt = at::ScalarType::Byte; auto s8_dt = at::ScalarType::Int; auto f32_dt = at::ScalarType::Float; @@ -1174,119 +1102,14 @@ sdpa_int8_kernel_impl( int qk_gemm_K_padding = headSize_mul4 ? 0 : 4 - headSize % 4; int qk_gemm_K = headSize + qk_gemm_K_padding; - // auto && qk_gemm = create_or_get_microkernel( - // qSplitSize, block_64, qk_gemm_K, - // 1, //batch_size - // qk_gemm_K, // lda - // block_64, //ldb - // rndkvSplitSize, //ldc - // u8_dt, //a dtype - // u8_dt, //b dtype - // s32_dt //c dtype - // ); - // (*qk_gemm).finalize(); - // (*qk_gemm).generate(); - // size_t qk_scratchpad_size = (*qk_gemm).get_scratchpad_size(); - - // auto && qk_gemm_ktail = create_or_get_microkernel( - // qSplitSize, block_64, qk_gemm_K, - // 1, //batch_size - // qk_gemm_K, // lda - // block_64, //ldb - // rndkvTail, //ldc - // u8_dt, //a dtype - // u8_dt, //b dtype - // s32_dt //c dtype - // ); - // // size_t qk_ktail_scratchpad_size = (*qk_gemm_ktail).get_scratchpad_size(); - // (*qk_gemm_ktail).finalize(); - // (*qk_gemm_ktail).generate(); - - // std::shared_ptr qk_gemm_ktail_tail; - // if (kvTail % block_64 != 0) { - // qk_gemm_ktail_tail = create_or_get_microkernel( - // qSplitSize, kv_tail_tail_block_size, qk_gemm_K, - // 1, //batch_size - // qk_gemm_K, // lda - // kv_tail_tail_block_size, //ldb - // rndkvTail, //ldc - // u8_dt, //a dtype - // u8_dt, //b dtype - // s32_dt //c dtype - // ); - // (*qk_gemm_ktail_tail).finalize(); - // (*qk_gemm_ktail_tail).generate(); - // } - - // auto && qk_gemm_qtail = create_or_get_microkernel( - // qTail, block_64, qk_gemm_K, - // 1, //batch_size - // qk_gemm_K,//headSize_mul4 ? 
qStrideM : qk_gemm_K, // lda - // block_64, //ldb - // rndkvSplitSize, //ldc - // u8_dt, //a dtype - // u8_dt, //b dtype - // s32_dt //c dtype - // ); - // // size_t qk_qtail_scratchpad_size = (*qk_gemm_qtail).get_scratchpad_size(); - // (*qk_gemm_qtail).finalize(); - // (*qk_gemm_qtail).generate(); - // auto && qk_gemm_qktail = create_or_get_microkernel( - // qTail, block_64, qk_gemm_K, - // 1, //batch_size - // qk_gemm_K, // lda - // block_64, //ldb - // rndkvTail, //ldc - // u8_dt, //a dtype - // u8_dt, //b dtype - // s32_dt //c dtype - // ); - // // size_t qk_qktail_scratchpad_size = (*qk_gemm_qktail).get_scratchpad_size(); - // (*qk_gemm_qktail).finalize(); - // (*qk_gemm_qktail).generate(); - - // std::shared_ptr qk_gemm_qktail_tail; - // if (kvTail % block_64 != 0) { - // qk_gemm_qktail_tail = create_or_get_microkernel( - // qSplitSize, kv_tail_tail_block_size, qk_gemm_K, - // 1, //batch_size - // qk_gemm_K, // lda - // kv_tail_tail_block_size, //ldb - // rndkvTail, //ldc - // u8_dt, //a dtype - // u8_dt, //b dtype - // s32_dt //c dtype - // ); - // (*qk_gemm_qktail_tail).finalize(); - // (*qk_gemm_qktail_tail).generate(); - // } - - // std::vector> A_B_offsets(1); std::vector> A_B_offsets(1); A_B_offsets[0] = std::make_pair(0, 0); - // std::vector> A_B_offsets_batch(kvSlice); std::vector> A_B_offsets_batch(kvSlice); for (auto s=0; s(); int64_t kv_sum_size_per_BH = @@ -1313,9 +1133,8 @@ sdpa_int8_kernel_impl( at::Tensor kv_sum_buf = at::empty( {batchSize, num_head, kv_sum_size_per_BH}, - query.options().dtype(at::kInt)); - int32_t* k_sum_buf_data = kv_sum_buf.data_ptr(); - int32_t* v_sum_buf_data = k_sum_buf_data + batchSize * num_head * kvSize; + query.options().dtype(at::kInt)).zero_(); + int32_t* kv_sum_buf_data = kv_sum_buf.data_ptr(); int64_t kv_reorder_size_per_BH = /* key_t_reorder */ qk_gemm_K * rndkvSize + @@ -1323,183 +1142,74 @@ sdpa_int8_kernel_impl( at::Tensor kv_reorder_buf = at::empty( {batchSize, num_head, kv_reorder_size_per_BH}, - query.options()); + query.options()).zero_(); scalar_t* kv_reorder_buf_data = kv_reorder_buf.data_ptr(); scalar_t* key_reorder_ptr = kv_reorder_buf_data; scalar_t* value_reorder_ptr = kv_reorder_buf_data + batchSize * num_head * qk_gemm_K * rndkvSize; -// // Create transforms for Key -// auto && brgemm_k_xform = create_or_get_packb_microkernel( -// qk_gemm_K, // K -// block_64, // N -// block_64, // kStrideN, // block_64, // ld_in -// block_64, // ld_out -// u8_dt, // dt_in -// u8_dt, // dt_out -// false // true -// ); -// (*brgemm_k_xform).generate(); -// auto && brgemm_k_xform_tail = create_or_get_packb_microkernel( -// qk_gemm_K, -// block_64, -// block_64, // kStrideN, // block_64, -// block_64, -// u8_dt, -// u8_dt, -// false // true -// ); -// (*brgemm_k_xform_tail).generate(); -// std::shared_ptr brgemm_k_xform_tail_tail; -// if (kvTail % block_64 != 0) { -// brgemm_k_xform_tail_tail = create_or_get_packb_microkernel( -// qk_gemm_K, -// kv_tail_tail_block_size, -// kv_tail_tail_block_size, // kStrideN, // kv_tail_tail_block_size, -// kv_tail_tail_block_size, -// u8_dt, -// u8_dt, -// false // true -// ); -// (*brgemm_k_xform_tail_tail).generate(); -// } - -// // Create transforms for Value -// auto && brgemm_v_xform = create_or_get_packb_microkernel( -// av_gemm_K, -// block_64, -// vStrideN, // block_64, -// block_64, -// u8_dt, -// u8_dt, -// false -// ); -// (*brgemm_v_xform).generate(); -// auto && brgemm_v_xform_tail = create_or_get_packb_microkernel( -// av_gemm_K_tail, -// block_64, -// vStrideN, // block_64, -// block_64, 
-// u8_dt, -// u8_dt, -// false -// ); -// (*brgemm_v_xform_tail).generate(); - - // sum k - if (q_zp != 0) { - at::parallel_for( - 0, batchSize * num_head * kvSize, 1, [&](int64_t begin, int64_t end) { - int64_t i = 0, j = 0, k = 0; - at::native::data_index_init( - begin, i, batchSize, j, num_head, k, kvSize); - for (const auto z : c10::irange(begin, end)) { - (void)z; // Suppress unused variable - int32_t* k_sum_ptr = k_sum_buf_data - + i * num_head * kvSize - + j * kvSize + k; - _int_sum_b_contiguous_kernel_helper( - k_data + i * kStrideB + j * kStrideH + k * kStrideN, - k_sum_ptr, - headSize, q_zp); - // Move to the next query - at::native::data_index_step(i, batchSize, j, num_head, k, kvSize); - } - }); - } - - // sum v - if (a_zp != 0) { - at::parallel_for( - 0, batchSize * num_head, 1, [&](int64_t begin, int64_t end) { - int64_t i = 0, j = 0; - at::native::data_index_init( - begin, i, batchSize, j, num_head); - for (const auto z : c10::irange(begin, end)) { - (void)z; // Suppress unused variable - int32_t* v_sum_ptr = v_sum_buf_data - + i * num_head * headSize - + j * headSize; - _int_sum_a_contiguous_kernel(v_data + i * vStrideB + j * vStrideH, - v_sum_ptr, - headSize, kvSize, vStrideN, a_zp); - // Move to the next query - at::native::data_index_step(i, batchSize, j, num_head); - } - }); - } - + // sum k and v at::parallel_for( 0, batchSize * num_head, 1, [&](int64_t begin, int64_t end) { int64_t i = 0, j = 0; at::native::data_index_init( begin, i, batchSize, j, num_head); - int ompIdx = at::get_thread_num(); - scalar_t* total_buf_ptr = total_buf_data + ompIdx * total_size_uint8_per_thread; - int32_t offset = 0; - accum_t* qk_data = reinterpret_cast(total_buf_ptr); - offset += kvSlice * qSplitSize * rndkvSplitSize * 4; - accum_t* qk_local_data = reinterpret_cast(total_buf_ptr + offset); - offset += kvSlice * av_gemm_K * 4; - scalar_t* qk_reduced_data = reinterpret_cast(total_buf_ptr + offset); - offset += kvSlice * qSplitSize * av_gemm_K; - int32_t* qk_s32_data = reinterpret_cast(total_buf_ptr + offset); - offset += qSplitSize * rndkvSplitSize * 4; - int32_t* dst_s32_data = reinterpret_cast(total_buf_ptr + offset); - offset += qSplitSize * rndHeadSize * 4; - accum_t* sfm_sum_ptr = reinterpret_cast(total_buf_ptr + offset); - offset += qSplitSize * 4; - int32_t* q_sum_ptr = reinterpret_cast(total_buf_ptr + offset); - offset += qSplitSize * 4; - int32_t* a_sum_ptr = reinterpret_cast(total_buf_ptr + offset); - offset += qSplitSize * 4; - accum_t* sfm_max_ptr = reinterpret_cast(total_buf_ptr + offset); - offset += qSplitSize * 4; - scalar_t* query_t_padding_ptr = reinterpret_cast(total_buf_ptr + offset); - offset += qSplitSize * qk_gemm_K; - // scalar_t* scratchpad_gemm = reinterpret_cast(total_buf_ptr + offset); - // offset += scratchpad_size; - - scalar_t* key_reorder_ptr = reinterpret_cast(total_buf_ptr + offset); - offset += qk_gemm_K * rndkvSize; - scalar_t* value_reorder_ptr = reinterpret_cast(total_buf_ptr + offset); - - uint8_t* B_blocked_xform_u8 = new uint8_t[std::max(qk_gemm_K, av_gemm_K) * block_64]; - for (const auto z : c10::irange(begin, end)) { (void)z; // Suppress unused variable + int32_t* kv_sum_ptr = kv_sum_buf_data + + i * num_head * kv_sum_size_per_BH + + j * kv_sum_size_per_BH; + int32_t* k_sum_ptr = kv_sum_ptr; + int32_t* v_sum_ptr = kv_sum_ptr + kvSize; + if (q_zp == 0) { + fill_stub(k_sum_ptr, static_cast(0), kvSize); + } else { + _int_sum_b_contiguous_kernel(k_data + i * kStrideB + j * kStrideH, + k_sum_ptr, + kvSize, headSize, kStrideN, q_zp); + } + if (a_zp 
== 0) { + fill_stub(v_sum_ptr, static_cast(0), headSize); + } else { + _int_sum_a_contiguous_kernel(v_data + i * vStrideB + j * vStrideH, + v_sum_ptr, + headSize, kvSize, vStrideN, a_zp); + } + // Move to the next query + at::native::data_index_step(i, batchSize, j, num_head); + } + }); - // pack - for (int64_t n = 0; n < kvSize; n += kvSplitSize) { - // long ss, ee; - if (n + kvSplitSize < kvSize) { - for (int64_t b = 0; b < rndkvSplitSize; b += block_64) { - bool tail = kvSplitSize - b < block_64; - do_transpose( - // do_copy( - k_data + i * kStrideB + j * kStrideH + n * kStrideN + b * kStrideN, - B_blocked_xform_u8, - tail ? kvSplitSize - b : block_64, - headSize, - kStrideN, - block_64); - if (!headSize_mul4) { - pad_remain_row_col( - B_blocked_xform_u8, - headSize, - block_64, - qk_gemm_K, - block_64, - block_64 - ); - } - // Pack - // (*brgemm_k_xform).execute( - // // k_data + i * kStrideB + j * kStrideH + n * kStrideN + b * kStrideN, - // B_blocked_xform_u8, - // key_reorder_ptr + n * qk_gemm_K + - // b * qk_gemm_K - // ); - at::native::cpublas::pack( + // packing + at::parallel_for( + 0, batchSize * num_head * kvSlice, 1, [&](int64_t begin, int64_t end) { + int64_t i = 0, j = 0, l = 0, n = 0; + at::native::data_index_init( + begin, i, batchSize, j, num_head, l, kvSlice); + uint8_t* B_blocked_xform_u8 = new uint8_t[std::max(qk_gemm_K, av_gemm_K) * block_64]; + for (const auto z : c10::irange(begin, end)) { + (void)z; // Suppress unused variable + n = l * kvSplitSize; + if (n + kvSplitSize < kvSize) { + for (int64_t b = 0; b < rndkvSplitSize; b += block_64) { + bool tail = kvSplitSize - b < block_64; + do_transpose( + k_data + i * kStrideB + j * kStrideH + n * kStrideN + b * kStrideN, + B_blocked_xform_u8, + tail ? kvSplitSize - b : block_64, + headSize, + kStrideN, + block_64); + if (!headSize_mul4) { + pad_remain_row_col( + B_blocked_xform_u8, + headSize, + block_64, + qk_gemm_K, + block_64, + block_64 + ); + } + at::native::cpublas::pack( qk_gemm_K, // K block_64, // N block_64, // ld_in @@ -1507,18 +1217,14 @@ sdpa_int8_kernel_impl( u8_dt, // dt_in u8_dt, // dt_out B_blocked_xform_u8, - key_reorder_ptr + n * qk_gemm_K + - b * qk_gemm_K); - } - // split headSize to block_64, block_64, block_64 ... - // [kvSplitSize, headSize] -> [av_gemm_K, block_64 ...] - for (int64_t b = 0; b < rndHeadSize; b += block_64) { - // (*brgemm_v_xform).execute( - // v_data + i * vStrideB + j * vStrideH + n * vStrideN + b, - // // B_blocked_xform_u8, - // value_reorder_ptr + n * rndHeadSize + - // av_gemm_K * b); - at::native::cpublas::pack( + key_reorder_ptr + i * num_head * qk_gemm_K * rndkvSize + + j * qk_gemm_K * rndkvSize + n * qk_gemm_K + + b * qk_gemm_K); + } + // split headSize to block_64, block_64, block_64 ... + // [kvSplitSize, headSize] -> [av_gemm_K, block_64 ...] + for (int64_t b = 0; b < rndHeadSize; b += block_64) { + at::native::cpublas::pack( av_gemm_K, block_64, vStrideN, // block_64, @@ -1526,80 +1232,67 @@ sdpa_int8_kernel_impl( u8_dt, u8_dt, v_data + i * vStrideB + j * vStrideH + n * vStrideN + b, - value_reorder_ptr + n * rndHeadSize + - av_gemm_K * b); - } - } else { - // tail - auto block_size = kvTail >= block_64 ? block_64 : kv_tail_tail_block_size; - int64_t b = 0; - while (b < rndkvTail) { - bool tail = kvTail - b < block_size; - do_transpose( - k_data + i * kStrideB + j * kStrideH + n * kStrideN + b * kStrideN, - B_blocked_xform_u8, - tail ? 
kvTail - b : block_size, - headSize, - kStrideN, - block_size); - if (!headSize_mul4) { - pad_remain_row_col( - B_blocked_xform_u8, - headSize, - block_size, - qk_gemm_K, - block_size, - block_size - ); - } - // Pack - if (block_size == block_64) { - // (*brgemm_k_xform_tail).execute( - // // k_data + i * kStrideB + j * kStrideH + n * kStrideN + b * kStrideN, - // B_blocked_xform_u8, - // key_reorder_ptr + n * qk_gemm_K + - // b * qk_gemm_K - // ); - at::native::cpublas::pack( + value_reorder_ptr + + i * num_head * kvSlice * av_gemm_K * rndHeadSize + + j * kvSlice * av_gemm_K * rndHeadSize + n * rndHeadSize + + av_gemm_K * b); + } + } else { + // tail + auto block_size = kvTail >= block_64 ? block_64 : kv_tail_tail_block_size; + int64_t b = 0; + while (b < rndkvTail) { + bool tail = kvTail - b < block_size; + do_transpose( + k_data + i * kStrideB + j * kStrideH + n * kStrideN + b * kStrideN, + B_blocked_xform_u8, + tail ? kvTail - b : block_size, + headSize, + kStrideN, + block_size); + if (!headSize_mul4) { + pad_remain_row_col( + B_blocked_xform_u8, + headSize, + block_size, + qk_gemm_K, + block_size, + block_size + ); + } + // Pack + if (block_size == block_64) { + at::native::cpublas::pack( qk_gemm_K, block_64, - block_64, // kStrideN, // block_64, + block_64, block_64, u8_dt, u8_dt, B_blocked_xform_u8, - key_reorder_ptr + n * qk_gemm_K + - b * qk_gemm_K); - } else { - // (*brgemm_k_xform_tail_tail).execute( - // // k_data + i * kStrideB + j * kStrideH + n * kStrideN + b * kStrideN, - // B_blocked_xform_u8, - // key_reorder_ptr + n * qk_gemm_K + - // b * qk_gemm_K - // ); - at::native::cpublas::pack( + key_reorder_ptr + i * num_head * qk_gemm_K * rndkvSize + + j * qk_gemm_K * rndkvSize + n * qk_gemm_K + + b * qk_gemm_K); + } else { + at::native::cpublas::pack( qk_gemm_K, kv_tail_tail_block_size, - kv_tail_tail_block_size, // kStrideN, // kv_tail_tail_block_size, + kv_tail_tail_block_size, kv_tail_tail_block_size, u8_dt, u8_dt, B_blocked_xform_u8, - key_reorder_ptr + n * qk_gemm_K + - b * qk_gemm_K); - } - b += block_size; - block_size = (kvTail - b) >= block_64 ? block_64 : kv_tail_tail_block_size; - } - // split headSize to block_64, block_64, block_64 ... - // [kvTail, headSize] -> [av_gemm_K_tail, block_64 ...] - for (int64_t b = 0; b < headSize; b += block_64) { - // (*brgemm_v_xform).execute( - // v_data + i * vStrideB + j * vStrideH + n * vStrideN + b, - // // B_blocked_xform_u8, - // value_reorder_ptr + n * rndHeadSize + - // av_gemm_K * b); - at::native::cpublas::pack( + key_reorder_ptr + i * num_head * qk_gemm_K * rndkvSize + + j * qk_gemm_K * rndkvSize + n * qk_gemm_K + + b * qk_gemm_K); + } + b += block_size; + block_size = (kvTail - b) >= block_64 ? block_64 : kv_tail_tail_block_size; + } + // split headSize to block_64, block_64, block_64 ... + // [kvTail, headSize] -> [av_gemm_K_tail, block_64 ...] 
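+            // Repack each block_64-wide column panel of V below with leading
+            // dimension block_64 so the AV brgemm can read it directly from
+            // value_reorder_ptr.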
+ for (int64_t b = 0; b < headSize; b += block_64) { + at::native::cpublas::pack( av_gemm_K, block_64, vStrideN, // block_64, @@ -1607,63 +1300,93 @@ sdpa_int8_kernel_impl( u8_dt, u8_dt, v_data + i * vStrideB + j * vStrideH + n * vStrideN + b, - value_reorder_ptr + n * rndHeadSize + - av_gemm_K * b); - } - } + value_reorder_ptr + + i * num_head * kvSlice * av_gemm_K * rndHeadSize + + j * kvSlice * av_gemm_K * rndHeadSize + n * rndHeadSize + + av_gemm_K * b); } + } + // Move to the next query + at::native::data_index_step(i, batchSize, j, num_head, l, kvSlice); + } + }); + + at::parallel_for( + 0, batchSize * num_head * qSlice, 1, [&](int64_t begin, int64_t end) { + int64_t i = 0, j = 0, k = 0; + at::native::data_index_init( + begin, i, batchSize, j, num_head, k, qSlice); + int ompIdx = at::get_thread_num(); + scalar_t* total_buf_ptr = total_buf_data + ompIdx * total_size_uint8_per_thread; + int32_t offset = 0; + accum_t* qk_data = reinterpret_cast(total_buf_ptr); + offset += kvSlice * qSplitSize * rndkvSplitSize * 4; + accum_t* qk_local_data = reinterpret_cast(total_buf_ptr + offset); + offset += kvSlice * av_gemm_K * 4; + scalar_t* qk_reduced_data = reinterpret_cast(total_buf_ptr + offset); + offset += kvSlice * qSplitSize * av_gemm_K; + int32_t* qk_s32_data = reinterpret_cast(total_buf_ptr + offset); + offset += qSplitSize * rndkvSplitSize * 4; + int32_t* dst_s32_data = reinterpret_cast(total_buf_ptr + offset); + offset += qSplitSize * rndHeadSize * 4; + accum_t* sfm_sum_ptr = reinterpret_cast(total_buf_ptr + offset); + offset += qSplitSize * 4; + int32_t* q_sum_ptr = reinterpret_cast(total_buf_ptr + offset); + offset += qSplitSize * 4; + int32_t* a_sum_ptr = reinterpret_cast(total_buf_ptr + offset); + offset += qSplitSize * 4; + accum_t* sfm_max_ptr = reinterpret_cast(total_buf_ptr + offset); + offset += qSplitSize * 4; + scalar_t* query_t_padding_ptr = reinterpret_cast(total_buf_ptr + offset); + + for (const auto z : c10::irange(begin, end)) { + (void)z; // Suppress unused variable + + int32_t* kv_sum_ptr = kv_sum_buf_data + + i * num_head * kv_sum_size_per_BH + + j * kv_sum_size_per_BH; + int32_t* k_sum_ptr = kv_sum_ptr; + int32_t* v_sum_ptr = kv_sum_ptr + kvSize; // sdpa core - int32_t* k_sum_ptr = k_sum_buf_data + i * num_head * kvSize + j * kvSize; - int32_t* v_sum_ptr = v_sum_buf_data + i * num_head * headSize + j * headSize; - for (int64_t k = 0; k < qSlice; k++) { - int64_t m = k * qSplitSize; - int64_t qBlockSize = std::min(qSplitSize, qSize - m); - // Initialize sum and max - fill_stub( - sfm_sum_ptr, static_cast(0), qSplitSize); - fill_stub( - a_sum_ptr, static_cast(0), qSplitSize); + int64_t m = k * qSplitSize; + int64_t qBlockSize = std::min(qSplitSize, qSize - m); + // Initialize sum and max + fill_stub( + sfm_sum_ptr, static_cast(0), qSplitSize); + fill_stub( + a_sum_ptr, static_cast(0), qSplitSize); + fill_stub( + sfm_max_ptr, static_cast(-std::numeric_limits::infinity()), qSplitSize); + int64_t num_keys = + is_causal ? std::min(m + qBlockSize, kvSize) : kvSize; + copy_value_with_pad( + q_data + i * qStrideB + j * qStrideH + m * qStrideM, + query_t_padding_ptr, + qBlockSize, + headSize, + qBlockSize, + qk_gemm_K, + qStrideM); + + if (k_zp != 0) { + _int_sum_b_contiguous_kernel(q_data + i * qStrideB + j * qStrideH + m * qStrideM, + q_sum_ptr, qBlockSize, headSize, qStrideM, k_zp); + } else { fill_stub( - sfm_max_ptr, static_cast(-std::numeric_limits::infinity()), qSplitSize); - int64_t num_keys = - is_causal ? 
std::min(m + qBlockSize, kvSize) : kvSize; - copy_value_with_pad( - q_data + i * qStrideB + j * qStrideH + m * qStrideM, - query_t_padding_ptr, - qBlockSize, - headSize, - qBlockSize, - qk_gemm_K, - qStrideM); - - if (k_zp == 0) { - _int_sum_b_contiguous_kernel(q_data + i * qStrideB + j * qStrideH + m * qStrideM, - q_sum_ptr, qBlockSize, headSize, qStrideM, k_zp); - } else { - fill_stub( - q_sum_ptr, static_cast(0), qSplitSize); - } - const int64_t rkvSlice = (num_keys - 1) / kvSplitSize + 1; - for (int64_t l = 0; l < rkvSlice; l++) { - int64_t n = l * kvSplitSize; - int64_t kvBlockSize = std::min(kvSplitSize, kvSize - n); - // Calculate sums for dequant compensation item - if (qBlockSize == qSplitSize) { - // q main - if (n + kvSplitSize < kvSize) { - // k main - for (int64_t b = 0; b < kvSplitSize; b += block_64) { - // dnnl::ukernel::brgemm::release_hw_context(); - // (*qk_gemm).set_hw_context(); - // (*qk_gemm).execute( - // query_t_padding_ptr, - // key_reorder_ptr + n * qk_gemm_K + - // b * qk_gemm_K, - // A_B_offsets, - // qk_s32_data + b, - // scratchpad_gemm); - at::native::cpublas::brgemm( + q_sum_ptr, static_cast(0), qSplitSize); + } + const int64_t rkvSlice = (num_keys - 1) / kvSplitSize + 1; + for (int64_t l = 0; l < rkvSlice; l++) { + int64_t n = l * kvSplitSize; + int64_t kvBlockSize = std::min(kvSplitSize, kvSize - n); + // Calculate sums for dequant compensation item + if (qBlockSize == qSplitSize) { + // q main + if (n + kvSplitSize < kvSize) { + // k main + for (int64_t b = 0; b < kvSplitSize; b += block_64) { + at::native::cpublas::brgemm( qSplitSize, block_64, qk_gemm_K, 1, //batch_size qk_gemm_K, // lda @@ -1671,26 +1394,18 @@ sdpa_int8_kernel_impl( rndkvSplitSize, //ldc, false, query_t_padding_ptr, - key_reorder_ptr + n * qk_gemm_K + - b * qk_gemm_K, + key_reorder_ptr + i * num_head * qk_gemm_K * rndkvSize + + j * qk_gemm_K * rndkvSize + n * qk_gemm_K + + b * qk_gemm_K, qk_s32_data + b, A_B_offsets); - } - } else { - auto block_size = kvTail >= block_64 ? block_64 : kv_tail_tail_block_size; - int64_t b = 0; - while (b < kvTail) { - if (block_size == block_64) { - // dnnl::ukernel::brgemm::release_hw_context(); - // (*qk_gemm_ktail).set_hw_context(); - // (*qk_gemm_ktail).execute( - // query_t_padding_ptr, - // key_reorder_ptr + n * qk_gemm_K + - // b * qk_gemm_K, - // A_B_offsets, - // qk_s32_data + b, - // scratchpad_gemm); - at::native::cpublas::brgemm( + } + } else { + auto block_size = kvTail >= block_64 ? 
block_64 : kv_tail_tail_block_size; + int64_t b = 0; + while (b < kvTail) { + if (block_size == block_64) { + at::native::cpublas::brgemm( qSplitSize, block_64, qk_gemm_K, 1, //batch_size qk_gemm_K, // lda @@ -1698,21 +1413,13 @@ sdpa_int8_kernel_impl( rndkvTail, //ldc false, query_t_padding_ptr, - key_reorder_ptr + n * qk_gemm_K + - b * qk_gemm_K, + key_reorder_ptr + i * num_head * qk_gemm_K * rndkvSize + + j * qk_gemm_K * rndkvSize + n * qk_gemm_K + + b * qk_gemm_K, qk_s32_data + b, A_B_offsets); - } else { - // dnnl::ukernel::brgemm::release_hw_context(); - // (*qk_gemm_ktail_tail).set_hw_context(); - // (*qk_gemm_ktail_tail).execute( - // query_t_padding_ptr, - // key_reorder_ptr + n * qk_gemm_K + - // b * qk_gemm_K, - // A_B_offsets, - // qk_s32_data + b, - // scratchpad_gemm); - at::native::cpublas::brgemm( + } else { + at::native::cpublas::brgemm( qSplitSize, kv_tail_tail_block_size, qk_gemm_K, 1, //batch_size qk_gemm_K, // lda @@ -1720,56 +1427,40 @@ sdpa_int8_kernel_impl( rndkvTail, //ldc false, query_t_padding_ptr, - key_reorder_ptr + n * qk_gemm_K + - b * qk_gemm_K, + key_reorder_ptr + i * num_head * qk_gemm_K * rndkvSize + + j * qk_gemm_K * rndkvSize + n * qk_gemm_K + + b * qk_gemm_K, qk_s32_data + b, A_B_offsets); - } - b += block_size; - block_size = (kvTail - b) >= block_64 ? block_64 : kv_tail_tail_block_size; } + b += block_size; + block_size = (kvTail - b) >= block_64 ? block_64 : kv_tail_tail_block_size; } - } else { - if (n + kvSplitSize < kvSize) { - for (int64_t b = 0; b < kvSplitSize; b += block_64) { - // dnnl::ukernel::brgemm::release_hw_context(); - // (*qk_gemm_qtail).set_hw_context(); - // (*qk_gemm_qtail).execute( - // query_t_padding_ptr, - // key_reorder_ptr + n * qk_gemm_K + - // b * qk_gemm_K, - // A_B_offsets, - // qk_s32_data + b, - // scratchpad_gemm); - at::native::cpublas::brgemm( + } + } else { + if (n + kvSplitSize < kvSize) { + for (int64_t b = 0; b < kvSplitSize; b += block_64) { + at::native::cpublas::brgemm( qTail, block_64, qk_gemm_K, 1, //batch_size - qk_gemm_K,//headSize_mul4 ? qStrideM : qk_gemm_K, // lda + qk_gemm_K,// lda block_64, //ldb rndkvSplitSize, //ldc false, query_t_padding_ptr, - key_reorder_ptr + n * qk_gemm_K + - b * qk_gemm_K, + key_reorder_ptr + i * num_head * qk_gemm_K * rndkvSize + + j * qk_gemm_K * rndkvSize + n * qk_gemm_K + + b * qk_gemm_K, qk_s32_data + b, A_B_offsets); - } - } else { - // k tail - auto block_size = kvTail >= block_64 ? block_64 : kv_tail_tail_block_size; - int64_t b = 0; - while (b < kvTail) { - if (block_size == block_64) { - // dnnl::ukernel::brgemm::release_hw_context(); - // (*qk_gemm_qktail).set_hw_context(); - // (*qk_gemm_qktail).execute( - // query_t_padding_ptr, - // key_reorder_ptr + n * qk_gemm_K + - // b * qk_gemm_K, - // A_B_offsets, - // qk_s32_data + b, - // scratchpad_gemm); - at::native::cpublas::brgemm( + } + } else { + // k tail + auto block_size = kvTail >= block_64 ? 
block_64 : kv_tail_tail_block_size; + int64_t b = 0; + while (b < kvTail) { + if (block_size == block_64) { + at::native::cpublas::brgemm( qTail, block_64, qk_gemm_K, 1, //batch_size qk_gemm_K, // lda @@ -1777,21 +1468,13 @@ sdpa_int8_kernel_impl( rndkvTail, //ldc false, query_t_padding_ptr, - key_reorder_ptr + n * qk_gemm_K + - b * qk_gemm_K, + key_reorder_ptr + i * num_head * qk_gemm_K * rndkvSize + + j * qk_gemm_K * rndkvSize + n * qk_gemm_K + + b * qk_gemm_K, qk_s32_data + b, A_B_offsets); - } else { - // dnnl::ukernel::brgemm::release_hw_context(); - // (*qk_gemm_qktail_tail).set_hw_context(); - // (*qk_gemm_qktail_tail).execute( - // query_t_padding_ptr, - // key_reorder_ptr + n * qk_gemm_K + - // b * qk_gemm_K, - // A_B_offsets, - // qk_s32_data + b, - // scratchpad_gemm); - at::native::cpublas::brgemm( + } else { + at::native::cpublas::brgemm( qSplitSize, kv_tail_tail_block_size, qk_gemm_K, 1, //batch_size qk_gemm_K, // lda @@ -1799,108 +1482,99 @@ sdpa_int8_kernel_impl( rndkvTail, //ldc false, query_t_padding_ptr, - key_reorder_ptr + n * qk_gemm_K + - b * qk_gemm_K, + key_reorder_ptr + i * num_head * qk_gemm_K * rndkvSize + + j * qk_gemm_K * rndkvSize + n * qk_gemm_K + + b * qk_gemm_K, qk_s32_data + b, A_B_offsets); - } - b += block_size; - block_size = (kvTail - b) >= block_64 ? block_64 : kv_tail_tail_block_size; } + b += block_size; + block_size = (kvTail - b) >= block_64 ? block_64 : kv_tail_tail_block_size; } } - - // do dequant compensation, add mask, max reduce for softmax, and convert qk from s32 to fp32 - int64_t rndkvBlockSize = kvBlockSize == kvSplitSize ? rndkvSplitSize : rndkvTail; - accum_t* qk_block_data = qk_data + l * qSplitSize * rndkvSplitSize; - if (has_attn_mask) { - mask_t* mask_data_offset = mask_data + i * mStrideB + j * mStrideH + m * mStrideM + (mStrideN == 0 ? 0 : n); - _dequant_mask_max_fusion_kernel( - qk_s32_data, //in - mask_data_offset, //mask_ptr - q_sum_ptr, //sum_a_ptr - k_sum_ptr + n, //sum_b_ptr - qBlockSize, //M - kvBlockSize, //N - rndkvBlockSize, //ldi - mStrideM, //ldm - rndkvSplitSize,//kvBlockSize, //ldo - q_zp * k_zp * headSize, //zp_a*zp_b*k=beta - q_scale * k_scale * scaling_factor, //scale_a*scale_b*scale_sdpa=alpha - qk_block_data, //out - sfm_max_ptr // sfm_max_ptr - ); - } else { - _dequant_max_fusion_kernel( - qk_s32_data, //in - q_sum_ptr, //sum_a_ptr - k_sum_ptr + n, //sum_b_ptr - qBlockSize, //M - kvBlockSize, //N - rndkvBlockSize, //ldi - rndkvSplitSize,//kvBlockSize, //ldo - q_zp * k_zp * headSize, //zp_a*zp_b*k=beta - q_scale * k_scale * scaling_factor, //scale_a*scale_b*scale_sdpa=alpha - qk_block_data, //out - sfm_max_ptr // sfm_max_ptr - ); - } } - // sub max, exp, sum reduce, div sum for softmax - // and quant - // and sum for attention - if (v_zp == 0) { - _sub_exp_sum_div_quant_fusion_kernel( - qk_data, //in + // do dequant compensation, add mask, max reduce for softmax, and convert qk from s32 to fp32 + int64_t rndkvBlockSize = kvBlockSize == kvSplitSize ? rndkvSplitSize : rndkvTail; + accum_t* qk_block_data = qk_data + l * qSplitSize * rndkvSplitSize; + if (has_attn_mask) { + mask_t* mask_data_offset = mask_data + i * mStrideB + j * mStrideH + m * mStrideM + (mStrideN == 0 ? 
0 : n); + _dequant_mask_max_fusion_kernel( + qk_s32_data, //in + mask_data_offset, //mask_ptr + q_sum_ptr, //sum_a_ptr + k_sum_ptr + n, //sum_b_ptr qBlockSize, //M - kvSplitSize, //N_step - rkvSlice, //NSlices - qSplitSize * rndkvSplitSize, //ldi - qSplitSize * av_gemm_K, //ldo - kvSize, //kvSize - rndkvSplitSize, //rndkvSplitSize - av_gemm_K, //av_gemm_K - a_zp, // zp_a=beta1 - a_scale, // scale_a=alpha - qk_local_data, //local - qk_reduced_data, //out - sfm_max_ptr, //sfm_max_ptr - sfm_sum_ptr //sfm_sum_ptr + kvBlockSize, //N + rndkvBlockSize, //ldi + mStrideM, //ldm + rndkvSplitSize,//kvBlockSize, //ldo + q_zp * k_zp * headSize, //zp_a*zp_b*k=beta + q_scale * k_scale * scaling_factor, //scale_a*scale_b*scale_sdpa=alpha + qk_block_data, //out + sfm_max_ptr // sfm_max_ptr ); } else { - _sub_exp_sum_div_quant_sum_fusion_kernel( - qk_data, //in + _dequant_max_fusion_kernel( + qk_s32_data, //in + q_sum_ptr, //sum_a_ptr + k_sum_ptr + n, //sum_b_ptr qBlockSize, //M - kvSplitSize, //N_step - rkvSlice, //NSlice - qSplitSize * rndkvSplitSize, //ldi - qSplitSize * av_gemm_K, //ldo - kvSize, //kvSize - rndkvSplitSize, //rndkvSplitSize - av_gemm_K, //av_gemm_K - a_zp, // zp_a=beta1 - v_zp, // zp_b=beta2 - a_scale, // scale_a=alpha - qk_local_data, //local - qk_reduced_data, //out - sfm_max_ptr, //sfm_max_ptr - sfm_sum_ptr, //sfm_sum_ptr - a_sum_ptr //a_sum_ptr - ); + kvBlockSize, //N + rndkvBlockSize, //ldi + rndkvSplitSize,//kvBlockSize, //ldo + q_zp * k_zp * headSize, //zp_a*zp_b*k=beta + q_scale * k_scale * scaling_factor, //scale_a*scale_b*scale_sdpa=alpha + qk_block_data, //out + sfm_max_ptr // sfm_max_ptr + ); } - - // Calculate Softmax(q @ k.T) @ v - for (int64_t b = 0; b < headSize; b += block_64) { - // dnnl::ukernel::brgemm::release_hw_context(); - // (*av_gemm_batch).set_hw_context(); - // (*av_gemm_batch).execute( - // qk_reduced_data, - // value_reorder_ptr + b * av_gemm_K, - // A_B_offsets_batch, - // dst_s32_data + b, - // scratchpad_gemm); - at::native::cpublas::brgemm( + } + // sub max, exp, sum reduce, div sum for softmax + // and quant + // and sum for attention + if (v_zp == 0) { + _sub_exp_sum_div_quant_fusion_kernel( + qk_data, //in + qBlockSize, //M + kvSplitSize, //N_step + rkvSlice, //NSlices + qSplitSize * rndkvSplitSize, //ldi + qSplitSize * av_gemm_K, //ldo + kvSize, //kvSize + rndkvSplitSize, //rndkvSplitSize + av_gemm_K, //av_gemm_K + a_zp, // zp_a=beta1 + a_scale, // scale_a=alpha + qk_local_data, //local + qk_reduced_data, //out + sfm_max_ptr, //sfm_max_ptr + sfm_sum_ptr //sfm_sum_ptr + ); + } else { + _sub_exp_sum_div_quant_sum_fusion_kernel( + qk_data, //in + qBlockSize, //M + kvSplitSize, //N_step + rkvSlice, //NSlice + qSplitSize * rndkvSplitSize, //ldi + qSplitSize * av_gemm_K, //ldo + kvSize, //kvSize + rndkvSplitSize, //rndkvSplitSize + av_gemm_K, //av_gemm_K + a_zp, // zp_a=beta1 + v_zp, // zp_b=beta2 + a_scale, // scale_a=alpha + qk_local_data, //local + qk_reduced_data, //out + sfm_max_ptr, //sfm_max_ptr + sfm_sum_ptr, //sfm_sum_ptr + a_sum_ptr //a_sum_ptr + ); + } + // Calculate Softmax(q @ k.T) @ v + for (int64_t b = 0; b < headSize; b += block_64) { + at::native::cpublas::brgemm( qSplitSize, block_64, av_gemm_K, kvSlice, //batch_size av_gemm_K, // lda @@ -1908,33 +1582,33 @@ sdpa_int8_kernel_impl( rndHeadSize, //ldc false, qk_reduced_data, - value_reorder_ptr + b * av_gemm_K, + value_reorder_ptr + + i * num_head * kvSlice * av_gemm_K * rndHeadSize + + j * kvSlice * av_gemm_K * rndHeadSize + b * av_gemm_K, dst_s32_data + b, A_B_offsets_batch); - } - - // 
After the last gemm, - // do dequant compensation, quant and convert from s32 to int8 - _dequant_quant_fusion_kernel( - dst_s32_data, //in - a_sum_ptr, //sum_a_ptr - v_sum_ptr, //sum_b_ptr - qBlockSize, //M - headSize, //N - rndHeadSize, //ldi - oStrideM, //ldo - a_zp * v_zp * kvSize, //zp_a*zp_b*k=beta1 - o_zp, //zp_c=beta2 - a_scale * v_scale / o_scale, //scale_a*scale_b/scale_c=alpha - out_data + i * oStrideB + j * oStrideH + m * oStrideM //out - ); } + + // After the last gemm, + // do dequant compensation, quant and convert from s32 to int8 + _dequant_quant_fusion_kernel( + dst_s32_data, //in + a_sum_ptr, //sum_a_ptr + v_sum_ptr, //sum_b_ptr + qBlockSize, //M + headSize, //N + rndHeadSize, //ldi + oStrideM, //ldo + a_zp * v_zp * kvSize, //zp_a*zp_b*k=beta1 + o_zp, //zp_c=beta2 + a_scale * v_scale / o_scale, //scale_a*scale_b/scale_c=alpha + out_data + i * oStrideB + j * oStrideH + m * oStrideM //out + ); // Move to the next query - at::native::data_index_step(i, batchSize, j, num_head); + at::native::data_index_step(i, batchSize, j, num_head, k, qSlice); } }); // Once all computations are done, need to release HW context. - // brgemm::release_hw_context(); at::native::cpublas::brgemm_release(); } @@ -2040,60 +1714,6 @@ void sdpa_int8_kernel( } } -// at::Tensor sdpa_int8_math_impl( -// const at::Tensor& query_, -// const at::Tensor& key, -// const at::Tensor& value, -// double dropout_p, -// bool is_causal, -// at::Tensor& attn_mask_, -// double scale, -// int32_t q_zp, -// float q_scale, -// int32_t k_zp, -// float k_scale, -// int32_t v_zp, -// float v_scale, -// int32_t a_zp, -// float a_scale, -// int32_t o_zp, -// float o_scale) { -// // dequant q/k/v -// auto q = (query_.to(at::kFloat) - q_zp) * q_scale; -// auto k = (key.to(at::kFloat) - k_zp) * k_scale; -// auto v = (value.to(at::kFloat) - v_zp) * v_scale; -// auto attn_mask = attn_mask_; -// if (attn_mask.defined()) { -// *attn_mask = (*attn_mask).to(at::kFloat); -// } -// // Scale q, k before matmul for stability see https://tinyurl.com/sudb9s96 for math -// bool is_negative_scaling = scale.defined() && scale < 0.0; -// const auto scaling_factor = sdp::calculate_scale(q, is_negative_scaling ? std::abs(scale) : scale).sqrt(); -// q = q * (is_negative_scaling ? 
c10::SymFloat(0.0) - scaling_factor: scaling_factor); -// auto attn = at::matmul(q, k.transpose(-2, -1) * scaling_factor); -// if (attn_mask.defined()) { -// if (at::areAnyTensorSubclassLike({attn, *attn_mask})) { -// attn = attn.add(*attn_mask); -// } else { -// attn.add_(*attn_mask); -// } -// } -// attn = at::softmax(attn, -1); -// // quant attn -// attn = at::clamp_max( -// at::clamp_min(at::round(attn / a_scale) + a_zp, 0), 255 -// ); -// // dequant attn -// attn = (attn - a_zp) * a_scale; -// auto output = at::matmul(attn, v); -// // quant output -// output = at::clamp_max( -// at::clamp_min(at::round(output / o_scale) + o_zp, 0), 255 -// ).to(at::kByte); -// return output; -// } - - at::Tensor _scaled_dot_product_int8_cpu( const at::Tensor& query, const at::Tensor& key, @@ -2135,16 +1755,6 @@ at::Tensor _scaled_dot_product_int8_cpu( (attn_mask.dim() == 2 || attn_mask.dim() == 4), "_scaled_dot_product_int8_cpu: Attention mask dim in {2, 4}"); - // fallback math path - // at::Tensor output = sdpa_int8_math_impl(query, key, value, - // dropout_p, is_causal, attn_mask, scale, - // q_zp, q_scale, - // k_zp, k_scale, - // v_zp, v_scale, - // a_zp, a_scale, - // o_zp, o_scale); - - // fused sdpa int8 impl at::Tensor output = at::empty_like(query, query.options()).transpose(1, 2); sdpa_int8_kernel(output, query, key, value, dropout_p, is_causal, attn_mask, scale, diff --git a/torchao/ops.py b/torchao/ops.py index 4be4064aa2..2166240222 100644 --- a/torchao/ops.py +++ b/torchao/ops.py @@ -10,7 +10,7 @@ lib.define("dequantize_tensor_core_tiled_layout(Tensor packed_w, Tensor scales_and_zeros, int group_size, int inner_k_tiles) -> Tensor") lib.define("marlin_24_gemm(Tensor x, Tensor weight_marlin, Tensor meta, Tensor s, Tensor workspace, int bits, int size_m, int size_n, int size_k) -> Tensor") lib.define("marlin_qqq_gemm(Tensor x, Tensor weight_marlin, Tensor s_tok, Tensor s_ch, Tensor s_group, Tensor workspace, int size_m, int size_n, int size_k) -> Tensor") -lib.define("scaled_dot_product_int8(Tensor query, Tensor key, Tensor value, Tensor attn_mask=None, float dropout_p=0.0, bool is_causal=False, float scale=1.0, int q_zp=0, float q_scale=1.0, int k_zp=0, float k_scale=1.0, int v_zp=0, float v_scale=1.0, int a_zp=0, float a_scale=1.0, int o_zp=0, float o_scale=1.0) -> Tensor") +lib.define("scaled_dot_product_int8(Tensor query, Tensor key, Tensor value, Tensor attn_mask=None, float dropout_p=0.0, bool is_causal=False, float scale=0.0, int q_zp=0, float q_scale=1.0, int k_zp=0, float k_scale=1.0, int v_zp=0, float v_scale=1.0, int a_zp=0, float a_scale=1.0, int o_zp=0, float o_scale=1.0) -> Tensor") def register_custom_op(name): @@ -79,7 +79,7 @@ def scaled_dot_product_int8( attn_mask: Tensor = None, dropout_p: float = 0.0, is_causal: bool = False, - scale: float = 1.0, + scale: float = 0.0, q_zp: int = 0, q_scale: float = 1.0, k_zp: int = 0, @@ -108,7 +108,7 @@ def _( attn_mask: Tensor = None, dropout_p: float = 0.0, is_causal: bool = False, - scale: float = 1.0, + scale: float = 0.0, q_zp: int = 0, q_scale: float = 1.0, k_zp: int = 0,