Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[POC] add INT8 SDPA path for CPU #1372

Open
wants to merge 5 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 14 additions & 0 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,6 +70,20 @@ def get_extensions():
"cxx": [
"-O3" if not debug_mode else "-O0",
"-fdiagnostics-color=always",
# ## AVX2
# "-DCPU_CAPABILITY=AVX2",
# "-DCPU_CAPABILITY_AVX2",
# "-mavx2",
# "-mfma",
# "-mf16c",
## AVX512
"-DCPU_CAPABILITY=AVX512",
"-DCPU_CAPABILITY_AVX512",
"-mavx512f",
"-mavx512bw",
"-mavx512vl",
"-mavx512dq",
"-mfma",
],
"nvcc": [
"-O3" if not debug_mode else "-O0",
Expand Down
199 changes: 199 additions & 0 deletions test/quantization/test_sfdp_int8_fx_pass.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,199 @@
import torchao

import contextlib
import functools
import itertools
import math

import torch
import torch.utils.checkpoint
from torch._dynamo.debug_utils import aot_graph_input_parser
from torch._dynamo.utils import counters
from torch._inductor import config
from torch._inductor.test_case import run_tests, TestCase
from torch._inductor.utils import run_and_get_code
from torch.testing._internal.common_utils import IS_LINUX, skipIfRocm
from torch.testing._internal.inductor_utils import HAS_CPU, HAS_CUDA

import torch.ao.quantization.quantizer.x86_inductor_quantizer as xiq
from torch.export import export_for_training
from torch.ao.quantization.quantize_pt2e import convert_pt2e, prepare_pt2e
from torch.ao.quantization.quantizer.x86_inductor_quantizer import (
X86InductorQuantizer,
)
from torchao.quantization.sfdp_int8_fx_pass import _sfdp_init_int8

class SelfAttnLikeModule(torch.nn.Module):
def __init__(
self,
input_dim,
has_mask,
num_attention_heads=None,
attention_head_size=None,
) -> None:
super().__init__()
self.input_dim = input_dim
self.q_proj = torch.nn.Linear(input_dim, input_dim, bias=False)
self.k_proj = torch.nn.Linear(input_dim, input_dim, bias=False)
self.v_proj = torch.nn.Linear(input_dim, input_dim, bias=False)
self.softmax = torch.nn.Softmax(dim=-1)
assert num_attention_heads is not None
assert attention_head_size is not None
self.num_attention_heads = num_attention_heads
self.attention_head_size = attention_head_size
self.all_head_size = self.num_attention_heads * self.attention_head_size
self.dense = torch.nn.Linear(self.all_head_size, self.all_head_size)
self.dropout = torch.nn.Dropout(0)
self.has_mask = has_mask

def transpose_for_scores(self, x: torch.Tensor) -> torch.Tensor:
new_x_shape = x.size()[:-1] + (
self.num_attention_heads,
self.attention_head_size,
)
x = x.view(new_x_shape)
return x.permute([0, 2, 1, 3])

def forward(self, x, mask):
q = self.q_proj(x)
k = self.k_proj(x)
v = self.v_proj(x)
q = self.transpose_for_scores(q)
k = self.transpose_for_scores(k)
v = self.transpose_for_scores(v)
scores = torch.matmul(q, k.transpose(-1, -2)) / (self.input_dim**0.5)
if self.has_mask:
scores = scores + mask
attention = self.softmax(scores)
attention = self.dropout(attention)
context_layer = torch.matmul(attention, v)
context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
context_layer = context_layer.view(
context_layer.size()[:-2] + (self.all_head_size,)
)
return self.dense(context_layer)

def _generate_qdq_quantized_model(mod, inputs, quantizer):
with torch.no_grad():
export_model = export_for_training(mod, inputs).module()
prepare_model = prepare_pt2e(export_model, quantizer)
prepare_model(*inputs)
convert_model = convert_pt2e(prepare_model)
torch.ao.quantization.move_exported_model_to_eval(convert_model)
return convert_model

class TestSDPAPatternRewriterTemplate(TestCase):
def _clone_inputs(self, inputs):
def clone(x):
if not isinstance(x, torch.Tensor):
return x
return x.clone()

return [clone(x) for x in inputs]

def _check_common(
self,
dot_prod_attention,
args1=None,
contains=True,
atol=1e-5,
has_fuse_pattern=True,
has_dropout=False,
check_train=True,
override_check_equal=False,
dtype=torch.float,
rtol=1.3e-6,
):
if args1 is None:
tensor_shape = (4, 2, 16, 32)
args1 = [
torch.randn(tensor_shape, device=self.device, dtype=dtype),
torch.randn(tensor_shape, device=self.device, dtype=dtype),
torch.randn(tensor_shape, device=self.device, dtype=dtype),
]
else:
args1 = list(args1)
args2 = self._clone_inputs(args1)

for training in [False, True] if check_train else [False]:
for x in itertools.chain(args1[:], args2[:]):
if isinstance(x, torch.Tensor) and x.is_floating_point():
x.requires_grad = training

dropout_arg = [training] if has_dropout else []
torch.manual_seed(1234)
result1 = dot_prod_attention(*(args1 + dropout_arg))

counters.clear()
torch.manual_seed(1234)
result2, source_code = run_and_get_code(
torch.compile(dot_prod_attention, fullgraph=True),
*(args2 + dropout_arg),
)
source_code = "\n".join(source_code)
if has_fuse_pattern:
self.assertGreaterEqual(counters["inductor"]["fuse_attention_int8"], 1)
if contains:
# many of the patterns get re-expanded in dispatcher
self.assertIn(
"torchao.scaled_dot_product_int8",
source_code,
)

# some tests configured with very low dropout where we still want to check equality
if not has_dropout or override_check_equal:
self.assertEqual(result1, result2, atol=atol, rtol=1.3e-6)

if training:
result1.sum().backward()
result2.sum().backward()
for arg1, arg2 in zip(args1, args2):
if (
isinstance(arg1, torch.Tensor)
and arg1.is_floating_point()
and (not has_dropout or override_check_equal)
):
self.assertEqual(arg1.grad, arg2.grad, atol=atol, rtol=rtol)

@skipIfRocm
@config.patch({"freezing": True})
def _test_sdpa_rewriter_int8_1_to_4(self):
# pattern is different for bs=1
for dtype, has_mask, bs in itertools.product(
[torch.float32], [True, False], [56, 1]
):
mod = SelfAttnLikeModule(
input_dim=64 * 16,
has_mask=has_mask,
num_attention_heads=16,
attention_head_size=64,
).eval()
maybe_autocast = (
torch.cpu.amp.autocast()
if dtype == torch.bfloat16
else contextlib.nullcontext()
)
inputs = (
torch.randn((bs, 384, 64 * 16), device=self.device, dtype=dtype),
torch.randn((bs, 1, 1, 384), device=self.device) if has_mask else None,
)
with torch.no_grad(), maybe_autocast:
_sfdp_init_int8()
quantizer = X86InductorQuantizer()
quantizer.set_global(xiq.get_default_x86_inductor_quantization_config())
quantizer.set_function_type_qconfig(
torch.matmul, quantizer.get_global_quantization_config()
)
convert_model = _generate_qdq_quantized_model(mod, inputs, quantizer)
self._check_common(
convert_model, args1=inputs, check_train=False, atol=1.0
)

if HAS_CPU:
class SDPAPatternRewriterCpuTests(TestSDPAPatternRewriterTemplate):
device = "cpu"
test_sdpa_rewriter_int8_1_to_4_cpu = TestSDPAPatternRewriterTemplate._test_sdpa_rewriter_int8_1_to_4

if __name__ == "__main__":
if IS_LINUX:
run_tests()
115 changes: 115 additions & 0 deletions test/test_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
)
from torchao.quantization.quant_primitives import choose_qparams_and_quantize_affine_qqq
import pytest
import math

if is_fbcode():
pytest.skip("Skipping the test in fbcode since we don't have TARGET file for kernels")
Expand All @@ -38,6 +39,120 @@


class TestOps(TestCase):
def _scaled_dot_product_int8_op_ref(
self,
q,
k,
v,
attn_mask=None,
dropout_p=0,
is_causal=False,
q_zp=0,
q_scale=1.0,
k_zp=0,
k_scale=1.0,
v_zp=0,
v_scale=1.0,
a_zp=0,
a_scale=1.0,
o_zp=0,
o_scale=1.0):
q = (q.to(torch.float) - q_zp) * q_scale
k = (k.to(torch.float) - k_zp) * k_scale
v = (v.to(torch.float) - v_zp) * v_scale
scale_factor = 1 / math.sqrt(q.size(-1))
attn = q @ k.transpose(-2, -1)
attn = attn * scale_factor
if attn_mask is not None:
attn = attn + attn_mask.to(torch.float)
attn_max = attn.max(dim=-1, keepdim=True).values
attn = attn - attn_max
attn = torch.exp(attn)
attn_sum = torch.sum(attn, dim=-1, keepdim=True)
attn = attn / attn_sum
attn = torch.clamp(torch.round(attn / a_scale) + a_zp, min=0, max=255)
attn = (attn - a_zp) * a_scale
out = attn @ v
out = torch.clamp(torch.round(out / o_scale) + o_zp, min=0, max=255)
return out.to(torch.uint8)

SDPA_INT8_BATCH_SIZE = [56, 120]
SDPA_INT8_NUM_HEADS = [2, 16]
SDPA_INT8_Q_SEQ_LEN = [18, 89]
SDPA_INT8_KV_SEQ_LEN = [100, 253]
SDPA_INT8_HEAD_DIM = [32, 64]
SDPA_INT8_MASK_DTYPE = [None, torch.float32, torch.bfloat16]

@parametrize("batch_size", SDPA_INT8_BATCH_SIZE)
@parametrize("n_head", SDPA_INT8_NUM_HEADS)
@parametrize("q_seq_len", SDPA_INT8_Q_SEQ_LEN)
@parametrize("kv_seq_len", SDPA_INT8_KV_SEQ_LEN)
@parametrize("head_dim", SDPA_INT8_HEAD_DIM)
@parametrize("mask_dtype", SDPA_INT8_MASK_DTYPE)
def test_scaled_dot_product_int8_op(self, batch_size, n_head, q_seq_len, kv_seq_len, head_dim, mask_dtype):
torch.manual_seed(1234)
device = "cpu"
q_zp = int(127)
q_scale = float(1.7907238006591797)
k_zp = int(125)
k_scale = float(1.8039721250534058)
v_zp = int(127)
v_scale = float(1.839004635810852)
a_zp = int(120)
a_scale = float(0.003919653594493866)
o_zp = int(128)
o_scale = float(1.8191684484481812)
q_shape = [batch_size, q_seq_len, n_head, head_dim]
kv_shape = [batch_size, kv_seq_len, n_head, head_dim]
mask_shape = [batch_size, 1, 1, kv_seq_len]
q = torch.randn(q_shape, dtype=torch.float, device=device).transpose(1, 2) * 100
k = torch.randn(kv_shape, dtype=torch.float, device=device).transpose(1, 2) * 100
v = torch.randn(kv_shape, dtype=torch.float, device=device).transpose(1, 2) * 100
q = q.to(torch.uint8)
k = k.to(torch.uint8)
v = v.to(torch.uint8)
attn_mask = torch.randn(mask_shape, dtype=mask_dtype, device=device) if mask_dtype is not None else None
q2, k2, v2, attn_mask_2 = q.clone(), k.clone(), v.clone(), attn_mask.clone() if mask_dtype is not None else None

math_ref = self._scaled_dot_product_int8_op_ref(
q2,
k2,
v2,
attn_mask=attn_mask,
dropout_p=0.0,
is_causal=False,
q_zp=q_zp,
q_scale=q_scale,
k_zp=k_zp,
k_scale=k_scale,
v_zp=v_zp,
v_scale=v_scale,
a_zp=a_zp,
a_scale=a_scale,
o_zp=o_zp,
o_scale=o_scale
)
actual = torch.ops.torchao.scaled_dot_product_int8(
q,
k,
v,
attn_mask=attn_mask_2,
dropout_p=0.0,
is_causal=False,
q_zp=q_zp,
q_scale=q_scale,
k_zp=k_zp,
k_scale=k_scale,
v_zp=v_zp,
v_scale=v_scale,
a_zp=a_zp,
a_scale=a_scale,
o_zp=o_zp,
o_scale=o_scale
)

self.assertEqual(actual, math_ref, atol=1.0, rtol=5e-6)

def _create_floatx_inputs(self, ebits: int, mbits: int, BS: int, OC: int, IC: int, device, dtype):
# Randomly initialize each byte
nbits = 1 + ebits + mbits
Expand Down
Loading