Enable float8 attention support (q/k/v) #1382

Open · wants to merge 11 commits into main
43 changes: 43 additions & 0 deletions test/dtypes/test_affine_quantized_float.py
@@ -288,7 +288,50 @@ def test_fp8_weight_dimension_warning(self):
)


@unittest.skip(
"Only running locally so we don't need to add installation of fa3 "
"hopper kernels to CI, we'll probably copy paste kernel in the future"
)
class TestAffineQuantizedFloat8Attention(common_utils.TestCase):
@unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available")
def test_float8_attention(self):
import torch.nn.functional as F

from torchao.quantization.quant_api import _float8_symmetric_per_tensor_quant

class MyModel(torch.nn.Module):
def forward(self, q, k, v, float8_quantize=False):
if float8_quantize:
# F.scaled_dot_product_attention is using (batch_size, nheads, seqlen, headdim)
# while flash attention kernel has (batch_size, seqlen, nheads, headdim)
q = q.transpose(1, 2)
k = k.transpose(1, 2)
v = v.transpose(1, 2)
q = _float8_symmetric_per_tensor_quant(q)
k = _float8_symmetric_per_tensor_quant(k)
v = _float8_symmetric_per_tensor_quant(v)

return F.scaled_dot_product_attention(q, k, v)

# note: last dim headdim must be 64, 128 or 256
q = torch.randn([64, 8, 8, 64], dtype=torch.bfloat16, device="cuda")
k = torch.randn([64, 8, 8, 64], dtype=torch.bfloat16, device="cuda")
v = torch.randn([64, 8, 8, 64], dtype=torch.bfloat16, device="cuda")

m = MyModel().eval()

# bfloat16 ref result
ref = m(q, k, v)

# float8 quantized result
quantized = m(q, k, v, True)

sqnr = compute_error(ref, quantized)
assert sqnr > 25.0, f"Got sqnr: {sqnr}"


common_utils.instantiate_parametrized_tests(TestAffineQuantizedFloat8Compile)

if __name__ == "__main__":
pytest.main([__file__])
common_utils.run_tests()
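
For context, _float8_symmetric_per_tensor_quant is added to torchao.quantization.quant_api elsewhere in this PR and is not shown in this diff. As a rough sketch (not the PR's implementation), per-tensor symmetric float8 quantization reduces to a single scale derived from the tensor's absolute maximum, and the sqnr checked by the assertion above is the usual signal-to-quantization-noise ratio in dB. A minimal standalone version, assuming torch.float8_e4m3fn as the target dtype and ignoring the tensor-subclass wrapping the real helper returns:

import torch

def per_tensor_float8_quant_sketch(x: torch.Tensor, float8_dtype=torch.float8_e4m3fn):
    # One scale for the whole tensor: map the absolute max onto the float8 max (448 for e4m3fn).
    max_f8 = torch.finfo(float8_dtype).max
    scale = (x.abs().amax().to(torch.float32) / max_f8).clamp(min=1e-12)
    x_f8 = (x.to(torch.float32) / scale).clamp(-max_f8, max_f8).to(float8_dtype)
    return x_f8, scale

def sqnr_db(ref: torch.Tensor, quantized: torch.Tensor) -> torch.Tensor:
    # Signal-to-quantization-noise ratio in dB; higher means the quantized
    # output is closer to the bfloat16 reference.
    return 20 * torch.log10(torch.linalg.norm(ref) / torch.linalg.norm(ref - quantized))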
40 changes: 28 additions & 12 deletions torchao/_models/llama/model.py
@@ -13,6 +13,8 @@
from torch.nn import functional as F
from torchao.utils import find_multiple

_QUANTIZE_ATTN = False

# TODO remove superfluous arg
def prepare_inputs_for_model(inps, max_new_tokens=1):
# this is because input from lm-eval is 2d
@@ -85,7 +87,7 @@ def from_name(cls, name: str):
),
}

# this is a model specific variable that controls whether index_put is used for the kv_cache update,
# this is a model specific variable that controls whether index_put is used for the kv_cache update,
# it is needed for GPTQ but otherwise attenuates perf so the default is to not use it
use_index_put_for_kv_cache = False

@@ -124,7 +126,7 @@ def __init__(self, max_batch_size, max_seq_length, n_heads, head_dim, scale_dtyp
self.register_buffer('v_cache', torch.zeros(cache_shape, dtype=torch.int8))
self.register_buffer('k_cache_scale', torch.ones(scale_shape, dtype=scale_dtype))
self.register_buffer('v_cache_scale', torch.ones(scale_shape, dtype=scale_dtype))

def update(self, input_pos, k_val, v_val):
# quantize current k_val and store it in the cache
q_k_val, k_scale = quantize_activation_per_token_absmax(k_val)
@@ -138,7 +140,7 @@ def update(self, input_pos, k_val, v_val):
self.v_cache_scale[:, :, input_pos] = v_scale.unsqueeze(-1)
v_out = self.v_cache*self.v_cache_scale
v_out[:, :, input_pos] = v_val

return k_out, v_out

@classmethod
@@ -194,16 +196,16 @@ def setup_caches(self, max_batch_size, max_seq_length, training: bool=False, kv_
else:
b.attention.kv_cache = KVCache(max_batch_size, max_seq_length, self.config.n_local_heads, head_dim, dtype)
self.freqs_cis = precompute_freqs_cis(
self.config.block_size,
self.config.dim // self.config.n_head,
self.config.rope_base,
dtype,
self.config.block_size,
self.config.dim // self.config.n_head,
self.config.rope_base,
dtype,
use_scaled=self.config.use_scaled_rope
)

def reset_caches(self):
"""Reset caches.

The caches used by training stage and inference stage may be different, reset them before switching.
"""
self.max_batch_size = -1
@@ -215,7 +217,7 @@ def forward(self, idx: Tensor, input_pos: Optional[Tensor] = None) -> Tensor:
"""Forward pass of the model.

Args:
idx (`torch.LongTensor` of shape `(batch_size, seq_length)`):
idx (`torch.LongTensor` of shape `(batch_size, seq_length)`):
Indices of input sequence tokens in the vocabulary.
input_pos (`torch.LongTensor` of shape `(batch_size, seq_length)`, *optional*):
Indices of positions of each input sequence tokens in the position embeddings.
@@ -227,7 +229,7 @@
"""
assert self.freqs_cis is not None, "Caches must be initialized first"

if input_pos is None:
if input_pos is None:
mask = None
freqs_cis = self.freqs_cis[:idx.shape[1]]
else:
@@ -311,11 +313,25 @@ def forward(self, x: Tensor, freqs_cis: Tensor, mask: Optional[Tensor], input_po

k = k.repeat_interleave(self.n_head // self.n_local_heads, dim=1)
v = v.repeat_interleave(self.n_head // self.n_local_heads, dim=1)

# quantize q/k/v with per tensor float8 quantization
padded = False
if _QUANTIZE_ATTN:
from torchao.quantization.quant_api import _float8_symmetric_per_tensor_quant
original_dtype = v.dtype
if q.shape[-1] in [64, 128, 256]:
q = _float8_symmetric_per_tensor_quant(q)
Review comment (Contributor): We also likely need/want to apply the Hadamard transform. I don't remember offhand if this is included in the fav3 API.

Reply (Contributor, PR author): Didn't see this; maybe we can add it after SpinQuant is integrated.

Reply (@bhack, Dec 8, 2024), quoting the Hadamard transform comment above: https://pytorch.org/blog/hadacore/

k = _float8_symmetric_per_tensor_quant(k)
v = _float8_symmetric_per_tensor_quant(v)

if mask is not None:
y = F.scaled_dot_product_attention(q, k, v, attn_mask=mask, dropout_p=0.0)
else:
y = F.scaled_dot_product_attention(q, k, v, dropout_p=0.0, is_causal=True)

if _QUANTIZE_ATTN:
y = y.to(original_dtype)

y = y.transpose(1, 2).contiguous().view(bsz, seqlen, self.dim)

y = self.wo(y)
@@ -371,8 +387,8 @@ def apply_scaling(freqs: torch.Tensor):
return torch.tensor(new_freqs, dtype=freqs.dtype, device=freqs.device)

def precompute_freqs_cis(
seq_len: int,
n_elem: int,
seq_len: int,
n_elem: int,
base: int = 10000,
dtype: torch.dtype = torch.bfloat16,
use_scaled: bool=False
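On the review thread above about the Hadamard transform: rotating q and k along the head dimension with the same normalized Hadamard matrix spreads activation outliers before low-bit quantization while leaving the attention scores q @ k^T mathematically unchanged, because the rotation cancels in the dot product. This is not part of this PR; a minimal dense sketch (Sylvester construction, power-of-two head_dim assumed), as opposed to the fused Hadacore kernel linked above:

import math
import torch

def hadamard_rotate(x: torch.Tensor) -> torch.Tensor:
    # Multiply the last (head) dimension by a normalized Hadamard matrix.
    d = x.shape[-1]
    assert d & (d - 1) == 0, "head_dim must be a power of two"
    H = torch.ones(1, 1, device=x.device, dtype=x.dtype)
    while H.shape[0] < d:
        # Sylvester construction: H_{2n} = [[H_n, H_n], [H_n, -H_n]]
        H = torch.cat([torch.cat([H, H], dim=1), torch.cat([H, -H], dim=1)], dim=0)
    return (x @ H) / math.sqrt(d)

# Because H / sqrt(d) is orthogonal, applying the same rotation to q and k
# keeps q @ k.transpose(-2, -1) exactly the same, so it could be slotted in
# right before the per-tensor float8 quantization of q and k.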
28 changes: 28 additions & 0 deletions torchao/_models/sam2/modeling/sam/transformer.py
Review comment: Can we update the current SAM2 README with all the ao optimizations we have introduced?

@@ -24,6 +24,9 @@
# A fallback setting to allow all available kernels if Flash Attention fails
ALLOW_ALL_KERNELS = False

# whether to turn on float8 quantization for sdpa or not
_QUANTIZE_ATTN = False


def sdp_kernel_context(dropout_p):
"""
@@ -263,6 +266,28 @@ def forward(self, q: Tensor, k: Tensor, v: Tensor) -> Tensor:
k = self._separate_heads(k, self.num_heads)
v = self._separate_heads(v, self.num_heads)

# quantize q/k/v with per tensor float8 quantization
padded = False
if _QUANTIZE_ATTN:
from torchao.quantization.quant_api import _float8_symmetric_per_tensor_quant
original_head_dim = list(q.shape)[-1]
original_dtype = v.dtype
padded = False
# padding:
if q.shape[-1] == 32:
q = F.pad(q, (0, 32))
k = F.pad(k, (0, 32))
v = F.pad(v, (0, 32))
padded = True

if q.shape[-1] in [64, 128, 256]:
q = q.transpose(1, 2)
k = k.transpose(1, 2)
v = v.transpose(1, 2)
q = _float8_symmetric_per_tensor_quant(q)
k = _float8_symmetric_per_tensor_quant(k)
v = _float8_symmetric_per_tensor_quant(v)

dropout_p = self.dropout_p if self.training else 0.0
# # Attention
# try:
@@ -281,6 +306,9 @@ def forward(self, q: Tensor, k: Tensor, v: Tensor) -> Tensor:
# out = F.scaled_dot_product_attention(q, k, v, dropout_p=dropout_p)
# TODO: This scale should not be needed. But without it compile causes a NaN.
out = F.scaled_dot_product_attention(q, k, v, dropout_p=dropout_p, scale=(1.0 / math.sqrt(q.size(-1))))
if _QUANTIZE_ATTN and padded:
out = out[:, :, :, :original_head_dim]
out = out.to(v.dtype)

out = self._recombine_heads(out)
out = self.out_proj(out)
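On the padding in the SAM2 attention path above: head dims of 32 are zero-padded to 64 because the float8 flash-attention kernel only supports head dims of 64, 128, or 256. Zero-padding is lossless for the original channels as long as the softmax scale stays at the original 1/sqrt(32): the zero columns of q and k add nothing to q @ k^T, and the zero columns of v only produce extra output channels, which the slice after the attention call removes. A small numerical check of that argument (not part of the PR; plain float32, no quantization):

import torch
import torch.nn.functional as F

q = torch.randn(2, 4, 16, 32)
k = torch.randn(2, 4, 16, 32)
v = torch.randn(2, 4, 16, 32)

ref = F.scaled_dot_product_attention(q, k, v)

# Pad the head dim from 32 to 64 with zeros, keep the original 1/sqrt(32)
# scale, then drop the padded output channels.
qp, kp, vp = (F.pad(t, (0, 32)) for t in (q, k, v))
out = F.scaled_dot_product_attention(qp, kp, vp, scale=1.0 / 32**0.5)[..., :32]

assert torch.allclose(ref, out, atol=1e-5)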
4 changes: 2 additions & 2 deletions torchao/dtypes/affine_quantized_tensor.py
@@ -510,9 +510,9 @@ def from_hp_to_intx(
)


######################################################
###############################################
# Layout and TensorImpl Subclass Registration #
######################################################
###############################################
register_layout = AffineQuantizedTensor.register_layout
get_tensor_impl_constructor = AffineQuantizedTensor.get_tensor_impl_constructor

97 changes: 58 additions & 39 deletions torchao/dtypes/affine_quantized_tensor_ops.py
@@ -11,6 +11,8 @@
_linear_fp8_act_fp8_weight_impl,
_linear_fp_act_fp8_weight_check,
_linear_fp_act_fp8_weight_impl,
_sdpa_float8_check,
_sdpa_float8_impl,
)
from torchao.dtypes.floatx.floatx_tensor_core_layout import (
_linear_f16_bf16_act_floatx_weight_check,
@@ -89,6 +91,12 @@ class QuantizedLinearNotImplementedError(NotImplementedError):
pass


class QuantizedSDPANotImplementedError(NotImplementedError):
"""Thin wrapper around NotImplementedError to make it easier to catch this error during dispatch"""

pass


@staticmethod
def _quantized_linear_op(input_tensor, weight_tensor, bias):
for dispatch_condition, impl in _AQT_QLINEAR_DISPATCH_TABLE.items():
@@ -177,45 +185,6 @@ def _(func, types, args, kwargs):
return torch.nn.functional.linear(input_tensor, weight_tensor, bias)


@implements(torch.nn.functional.embedding)
def _(func, types, args, kwargs):
# new_arg1 = args[1].dequantize()
# return torch.nn.embedding(args[0], new_arg1, *args[2:], **kwargs)
assert isinstance(
args[1].tensor_impl, PlainAQTTensorImpl
), f"embedding only works with PlainAQTTensorImpl but got {type(args[1].tensor_impl)}"
assert (
kwargs["padding_idx"] is None
and kwargs["max_norm"] is None
and not kwargs["scale_grad_by_freq"]
and not kwargs["sparse"]
and kwargs["norm_type"] == 2.0
)
idx = args[0]
int_data, scale, zero_point = args[1].tensor_impl.get_plain()

sliced_data, sliced_scale, sliced_zero_point = (
int_data[idx],
scale[idx],
zero_point[idx],
)
# Block size is expecting 2 dimensions [1, group size] but
# batchsize or other dims gets added to sliced_data, sliced_scale and sliced_zero_point so
# we need to increase block size to correct dim
new_blocks = idx.dim() - 1
return dequantize_affine(
sliced_data,
new_blocks * [1] + list(args[1].block_size),
sliced_scale,
sliced_zero_point,
sliced_data.dtype,
args[1].quant_min,
args[1].quant_max,
args[1].zero_point_domain,
output_dtype=sliced_scale.dtype,
)


@implements(aten.addmm.default)
def _(func, types, args, kwargs):
input_tensor, weight_tensor, bias = (
@@ -277,6 +246,56 @@ def _(func, types, args, kwargs):
return func(input_tensor, weight_tensor)


@implements(torch.nn.functional.embedding)
def _(func, types, args, kwargs):
# new_arg1 = args[1].dequantize()
# return torch.nn.embedding(args[0], new_arg1, *args[2:], **kwargs)
assert isinstance(
args[1].tensor_impl, PlainAQTTensorImpl
), f"embedding only works with PlainAQTTensorImpl but got {type(args[1].tensor_impl)}"
assert (
kwargs["padding_idx"] is None
and kwargs["max_norm"] is None
and not kwargs["scale_grad_by_freq"]
and not kwargs["sparse"]
and kwargs["norm_type"] == 2.0
)
idx = args[0]
int_data, scale, zero_point = args[1].tensor_impl.get_plain()

sliced_data, sliced_scale, sliced_zero_point = (
int_data[idx],
scale[idx],
zero_point[idx],
)
# Block size is expecting 2 dimensions [1, group size] but
# batchsize or other dims gets added to sliced_data, sliced_scale and sliced_zero_point so
# we need to increase block size to correct dim
new_blocks = idx.dim() - 1
return dequantize_affine(
sliced_data,
new_blocks * [1] + list(args[1].block_size),
sliced_scale,
sliced_zero_point,
sliced_data.dtype,
args[1].quant_min,
args[1].quant_max,
args[1].zero_point_domain,
output_dtype=sliced_scale.dtype,
)


@implements(torch.nn.functional.scaled_dot_product_attention)
def _(func, types, args, kwargs):
q, k, v = args[:3]
if _sdpa_float8_check(q, k, v, args, kwargs):
return _sdpa_float8_impl(q, k, v, args, kwargs)
else:
raise QuantizedSDPANotImplementedError(
"No specialized dispatch found for quantized sdpa"
)


@implements(aten.detach.default)
def _(func, types, args, kwargs):
return return_and_correct_aliasing(
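
With the scaled_dot_product_attention override registered above, calling F.scaled_dot_product_attention on float8-quantized AffineQuantizedTensor inputs is routed through __torch_function__ to the _sdpa_float8_check / _sdpa_float8_impl pair imported at the top of this file (their definitions live in the float8 layout module and are not shown in this diff). End-to-end usage then mirrors the test at the top of this PR; a minimal sketch, assuming a CUDA device with the fa3 Hopper kernels installed:

import torch
import torch.nn.functional as F
from torchao.quantization.quant_api import _float8_symmetric_per_tensor_quant

# The fp8 flash-attention kernel expects (batch_size, seqlen, nheads, headdim)
# and a head dim of 64, 128, or 256.
q = torch.randn(64, 8, 8, 64, dtype=torch.bfloat16, device="cuda")
k = torch.randn(64, 8, 8, 64, dtype=torch.bfloat16, device="cuda")
v = torch.randn(64, 8, 8, 64, dtype=torch.bfloat16, device="cuda")

# Quantizing q/k/v wraps them as AffineQuantizedTensor, so the call below is
# intercepted by the @implements(...) override and dispatched to
# _sdpa_float8_impl once _sdpa_float8_check passes.
out = F.scaled_dot_product_attention(
    _float8_symmetric_per_tensor_quant(q),
    _float8_symmetric_per_tensor_quant(k),
    _float8_symmetric_per_tensor_quant(v),
)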