diff --git a/test/dtypes/test_affine_quantized_float.py b/test/dtypes/test_affine_quantized_float.py index 4d8312b427..9460672add 100644 --- a/test/dtypes/test_affine_quantized_float.py +++ b/test/dtypes/test_affine_quantized_float.py @@ -288,7 +288,50 @@ def test_fp8_weight_dimension_warning(self): ) +@unittest.skip( + "Only running locally so we don't need to add installation of fa3 " + "hopper kernels to CI, we'll probably copy paste kernel in the future" +) +class TestAffineQuantizedFloat8Attention(common_utils.TestCase): + @unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available") + def test_float8_attention(self): + import torch.nn.functional as F + + from torchao.quantization.quant_api import _float8_symmetric_per_tensor_quant + + class MyModel(torch.nn.Module): + def forward(self, q, k, v, float8_quantize=False): + if float8_quantize: + # F.scaled_dot_product_attention is using (batch_size, nheads, seqlen, headdim) + # while flash attention kernel has (batch_size, seqlen, nheads, headdim) + q = q.transpose(1, 2) + k = k.transpose(1, 2) + v = v.transpose(1, 2) + q = _float8_symmetric_per_tensor_quant(q) + k = _float8_symmetric_per_tensor_quant(k) + v = _float8_symmetric_per_tensor_quant(v) + + return F.scaled_dot_product_attention(q, k, v) + + # note: last dim headdim must be 64, 128 or 256 + q = torch.randn([64, 8, 8, 64], dtype=torch.bfloat16, device="cuda") + k = torch.randn([64, 8, 8, 64], dtype=torch.bfloat16, device="cuda") + v = torch.randn([64, 8, 8, 64], dtype=torch.bfloat16, device="cuda") + + m = MyModel().eval() + + # bfloat16 ref result + ref = m(q, k, v) + + # float8 quantized result + quantized = m(q, k, v, True) + + sqnr = compute_error(ref, quantized) + assert sqnr > 25.0, f"Got sqnr: {sqnr}" + + common_utils.instantiate_parametrized_tests(TestAffineQuantizedFloat8Compile) if __name__ == "__main__": pytest.main([__file__]) + common_utils.run_tests() diff --git a/torchao/_models/llama/model.py b/torchao/_models/llama/model.py index 74cad30cbd..af834479f1 100644 --- a/torchao/_models/llama/model.py +++ b/torchao/_models/llama/model.py @@ -13,6 +13,8 @@ from torch.nn import functional as F from torchao.utils import find_multiple +_QUANTIZE_ATTN = False + # TODO remove suplerfluous arg def prepare_inputs_for_model(inps, max_new_tokens=1): # this is because input from lm-eval is 2d @@ -85,7 +87,7 @@ def from_name(cls, name: str): ), } -# this is a model specific variable that controls whether index_put is used for the kv_cache update, +# this is a model specific variable that controls whether index_put is used for the kv_cache update, # it is needed for GPTQ but otherwise attenuates perf so the default is to not use it use_index_put_for_kv_cache = False @@ -124,7 +126,7 @@ def __init__(self, max_batch_size, max_seq_length, n_heads, head_dim, scale_dtyp self.register_buffer('v_cache', torch.zeros(cache_shape, dtype=torch.int8)) self.register_buffer('k_cache_scale', torch.ones(scale_shape, dtype=scale_dtype)) self.register_buffer('v_cache_scale', torch.ones(scale_shape, dtype=scale_dtype)) - + def update(self, input_pos, k_val, v_val): # quantize current k_val and store it in the cache q_k_val, k_scale = quantize_activation_per_token_absmax(k_val) @@ -138,7 +140,7 @@ def update(self, input_pos, k_val, v_val): self.v_cache_scale[:, :, input_pos] = v_scale.unsqueeze(-1) v_out = self.v_cache*self.v_cache_scale v_out[:, :, input_pos] = v_val - + return k_out, v_out @classmethod @@ -194,16 +196,16 @@ def setup_caches(self, max_batch_size, max_seq_length, training: 
bool=False, kv_ else: b.attention.kv_cache = KVCache(max_batch_size, max_seq_length, self.config.n_local_heads, head_dim, dtype) self.freqs_cis = precompute_freqs_cis( - self.config.block_size, - self.config.dim // self.config.n_head, - self.config.rope_base, - dtype, + self.config.block_size, + self.config.dim // self.config.n_head, + self.config.rope_base, + dtype, use_scaled=self.config.use_scaled_rope ) def reset_caches(self): """Reset caches. - + The caches used by training stage and inference stage may be different, reset them before switching. """ self.max_batch_size = -1 @@ -215,7 +217,7 @@ def forward(self, idx: Tensor, input_pos: Optional[Tensor] = None) -> Tensor: """Forward pass of the model. Args: - idx (`torch.LongTensor` of shape `(batch_size, seq_length)`): + idx (`torch.LongTensor` of shape `(batch_size, seq_length)`): Indices of input sequence tokens in the vocabulary. input_pos (`torch.LongTensor` of shape `(batch_size, seq_length)`, *optional*): Indices of positions of each input sequence tokens in the position embeddings. @@ -227,7 +229,7 @@ def forward(self, idx: Tensor, input_pos: Optional[Tensor] = None) -> Tensor: """ assert self.freqs_cis is not None, "Caches must be initialized first" - if input_pos is None: + if input_pos is None: mask = None freqs_cis = self.freqs_cis[:idx.shape[1]] else: @@ -311,11 +313,25 @@ def forward(self, x: Tensor, freqs_cis: Tensor, mask: Optional[Tensor], input_po k = k.repeat_interleave(self.n_head // self.n_local_heads, dim=1) v = v.repeat_interleave(self.n_head // self.n_local_heads, dim=1) + + # quantize q/k/v with per tensor float8 quantization + padded = False + if _QUANTIZE_ATTN: + from torchao.quantization.quant_api import _float8_symmetric_per_tensor_quant + original_dtype = v.dtype + if q.shape[-1] in [64, 128, 256]: + q = _float8_symmetric_per_tensor_quant(q) + k = _float8_symmetric_per_tensor_quant(k) + v = _float8_symmetric_per_tensor_quant(v) + if mask is not None: y = F.scaled_dot_product_attention(q, k, v, attn_mask=mask, dropout_p=0.0) else: y = F.scaled_dot_product_attention(q, k, v, dropout_p=0.0, is_causal=True) + if _QUANTIZE_ATTN: + y = y.to(original_dtype) + y = y.transpose(1, 2).contiguous().view(bsz, seqlen, self.dim) y = self.wo(y) @@ -371,8 +387,8 @@ def apply_scaling(freqs: torch.Tensor): return torch.tensor(new_freqs, dtype=freqs.dtype, device=freqs.device) def precompute_freqs_cis( - seq_len: int, - n_elem: int, + seq_len: int, + n_elem: int, base: int = 10000, dtype: torch.dtype = torch.bfloat16, use_scaled: bool=False diff --git a/torchao/_models/sam2/modeling/sam/transformer.py b/torchao/_models/sam2/modeling/sam/transformer.py index 2e3d85ccd4..3d7bc4e248 100644 --- a/torchao/_models/sam2/modeling/sam/transformer.py +++ b/torchao/_models/sam2/modeling/sam/transformer.py @@ -24,6 +24,9 @@ # A fallback setting to allow all available kernels if Flash Attention fails ALLOW_ALL_KERNELS = False +# whether to turn on float8 quantization for sdpa or not +_QUANTIZE_ATTN = False + def sdp_kernel_context(dropout_p): """ @@ -263,6 +266,28 @@ def forward(self, q: Tensor, k: Tensor, v: Tensor) -> Tensor: k = self._separate_heads(k, self.num_heads) v = self._separate_heads(v, self.num_heads) + # quantize q/k/v with per tensor float8 quantization + padded = False + if _QUANTIZE_ATTN: + from torchao.quantization.quant_api import _float8_symmetric_per_tensor_quant + original_head_dim = list(q.shape)[-1] + original_dtype = v.dtype + padded = False + # padding: + if q.shape[-1] == 32: + q = F.pad(q, (0, 32)) + k = 
F.pad(k, (0, 32)) + v = F.pad(v, (0, 32)) + padded = True + + if q.shape[-1] in [64, 128, 256]: + q = q.transpose(1, 2) + k = k.transpose(1, 2) + v = v.transpose(1, 2) + q = _float8_symmetric_per_tensor_quant(q) + k = _float8_symmetric_per_tensor_quant(k) + v = _float8_symmetric_per_tensor_quant(v) + dropout_p = self.dropout_p if self.training else 0.0 # # Attention # try: @@ -281,6 +306,9 @@ def forward(self, q: Tensor, k: Tensor, v: Tensor) -> Tensor: # out = F.scaled_dot_product_attention(q, k, v, dropout_p=dropout_p) # TODO: This scale should not be needed. But without it compile causes a NaN. out = F.scaled_dot_product_attention(q, k, v, dropout_p=dropout_p, scale=(1.0 / math.sqrt(q.size(-1)))) + if _QUANTIZE_ATTN and padded: + out = out[:, :, :, :original_head_dim] + out = out.to(v.dtype) out = self._recombine_heads(out) out = self.out_proj(out) diff --git a/torchao/dtypes/affine_quantized_tensor.py b/torchao/dtypes/affine_quantized_tensor.py index 93d2766d1e..af4ff01911 100644 --- a/torchao/dtypes/affine_quantized_tensor.py +++ b/torchao/dtypes/affine_quantized_tensor.py @@ -510,9 +510,9 @@ def from_hp_to_intx( ) -###################################################### +############################################### # Layout and TensorImpl Subclass Registration # -###################################################### +############################################### register_layout = AffineQuantizedTensor.register_layout get_tensor_impl_constructor = AffineQuantizedTensor.get_tensor_impl_constructor diff --git a/torchao/dtypes/affine_quantized_tensor_ops.py b/torchao/dtypes/affine_quantized_tensor_ops.py index bd7ff7d333..a2319a9647 100644 --- a/torchao/dtypes/affine_quantized_tensor_ops.py +++ b/torchao/dtypes/affine_quantized_tensor_ops.py @@ -11,6 +11,8 @@ _linear_fp8_act_fp8_weight_impl, _linear_fp_act_fp8_weight_check, _linear_fp_act_fp8_weight_impl, + _sdpa_float8_check, + _sdpa_float8_impl, ) from torchao.dtypes.floatx.floatx_tensor_core_layout import ( _linear_f16_bf16_act_floatx_weight_check, @@ -89,6 +91,12 @@ class QuantizedLinearNotImplementedError(NotImplementedError): pass +class QuantizedSDPANotImplementedError(NotImplementedError): + """Thin wrapper around NotImplementedError to make it easier to catch this error during dispatch""" + + pass + + @staticmethod def _quantized_linear_op(input_tensor, weight_tensor, bias): for dispatch_condition, impl in _AQT_QLINEAR_DISPATCH_TABLE.items(): @@ -177,45 +185,6 @@ def _(func, types, args, kwargs): return torch.nn.functional.linear(input_tensor, weight_tensor, bias) -@implements(torch.nn.functional.embedding) -def _(func, types, args, kwargs): - # new_arg1 = args[1].dequantize() - # return torch.nn.embedding(args[0], new_arg1, *args[2:], **kwargs) - assert isinstance( - args[1].tensor_impl, PlainAQTTensorImpl - ), f"embedding only works with PlainAQTTensorImpl but got {type(args[1].tensor_impl)}" - assert ( - kwargs["padding_idx"] is None - and kwargs["max_norm"] is None - and not kwargs["scale_grad_by_freq"] - and not kwargs["sparse"] - and kwargs["norm_type"] == 2.0 - ) - idx = args[0] - int_data, scale, zero_point = args[1].tensor_impl.get_plain() - - sliced_data, sliced_scale, sliced_zero_point = ( - int_data[idx], - scale[idx], - zero_point[idx], - ) - # Block size is expecting 2 dimensions [1, group size] but - # batchsize or other dims gets added to sliced_data, sliced_scale and sliced_zero_point so - # we need to increase block size to correct dim - new_blocks = idx.dim() - 1 - return dequantize_affine( - sliced_data, 
- new_blocks * [1] + list(args[1].block_size), - sliced_scale, - sliced_zero_point, - sliced_data.dtype, - args[1].quant_min, - args[1].quant_max, - args[1].zero_point_domain, - output_dtype=sliced_scale.dtype, - ) - - @implements(aten.addmm.default) def _(func, types, args, kwargs): input_tensor, weight_tensor, bias = ( @@ -277,6 +246,56 @@ def _(func, types, args, kwargs): return func(input_tensor, weight_tensor) +@implements(torch.nn.functional.embedding) +def _(func, types, args, kwargs): + # new_arg1 = args[1].dequantize() + # return torch.nn.embedding(args[0], new_arg1, *args[2:], **kwargs) + assert isinstance( + args[1].tensor_impl, PlainAQTTensorImpl + ), f"embedding only works with PlainAQTTensorImpl but got {type(args[1].tensor_impl)}" + assert ( + kwargs["padding_idx"] is None + and kwargs["max_norm"] is None + and not kwargs["scale_grad_by_freq"] + and not kwargs["sparse"] + and kwargs["norm_type"] == 2.0 + ) + idx = args[0] + int_data, scale, zero_point = args[1].tensor_impl.get_plain() + + sliced_data, sliced_scale, sliced_zero_point = ( + int_data[idx], + scale[idx], + zero_point[idx], + ) + # Block size is expecting 2 dimensions [1, group size] but + # batchsize or other dims gets added to sliced_data, sliced_scale and sliced_zero_point so + # we need to increase block size to correct dim + new_blocks = idx.dim() - 1 + return dequantize_affine( + sliced_data, + new_blocks * [1] + list(args[1].block_size), + sliced_scale, + sliced_zero_point, + sliced_data.dtype, + args[1].quant_min, + args[1].quant_max, + args[1].zero_point_domain, + output_dtype=sliced_scale.dtype, + ) + + +@implements(torch.nn.functional.scaled_dot_product_attention) +def _(func, types, args, kwargs): + q, k, v = args[:3] + if _sdpa_float8_check(q, k, v, args, kwargs): + return _sdpa_float8_impl(q, k, v, args, kwargs) + else: + raise QuantizedSDPANotImplementedError( + "No specialized dispatch found for quantized sdpa" + ) + + @implements(aten.detach.default) def _(func, types, args, kwargs): return return_and_correct_aliasing( diff --git a/torchao/dtypes/floatx/float8_layout.py b/torchao/dtypes/floatx/float8_layout.py index dd995fb157..b3d120880d 100644 --- a/torchao/dtypes/floatx/float8_layout.py +++ b/torchao/dtypes/floatx/float8_layout.py @@ -311,3 +311,91 @@ def _linear_fp_act_fp8_weight_impl( bias: Optional[torch.Tensor], ): return torch.nn.functional.linear(input_tensor, weight_tensor.dequantize(), bias) + + +def _sdpa_float8_check( + q: Union[torch.Tensor, "AffineQuantizedTensor"], + k: Union[torch.Tensor, "AffineQuantizedTensor"], + v: Union[torch.Tensor, "AffineQuantizedTensor"], + args, + kwargs, +) -> bool: + def is_compatible_per_tensor_float8_aqt(t): + # tensor is float8 quantized affine quantized tensor + return ( + isinstance(t, AffineQuantizedTensor) + and isinstance(t._layout, Float8Layout) + and t.tensor_impl.dtype in [torch.float8_e4m3fn, torch.float8_e5m2] + and (t.shape == t.block_size) + and t.shape[-1] in [64, 128, 256] + ) + + dropout_p = kwargs.get("dropout_p", 0.0) + + return ( + is_compatible_per_tensor_float8_aqt(q) + and is_compatible_per_tensor_float8_aqt(k) + and is_compatible_per_tensor_float8_aqt(v) + and "attn_mask" not in kwargs + and dropout_p == 0.0 + ) + + +def _sdpa_float8_impl( + q: Union[torch.Tensor, "AffineQuantizedTensor"], + k: Union[torch.Tensor, "AffineQuantizedTensor"], + v: Union[torch.Tensor, "AffineQuantizedTensor"], + args, + kwargs, +) -> torch.Tensor: + try: + # for libc10.so + import torch + from hopper.flash_attn_interface import flash_attn_func + 
except ImportError as e: + raise ImportError( + f"please install FlashAttention 3 before using float8 sdpa: https://github.com/Dao-AILab/flash-attention/tree/main?tab=readme-ov-file#flashattention-3-beta-release, original import error {e}" + ) + + q_tensor_impl = q.tensor_impl + assert not q_tensor_impl.transposed + q_float8_data = q_tensor_impl.float8_data + # change from scalar to tensor of size [1] + q_scale = q_tensor_impl.scale + q_scale = torch.tensor([q_scale], device=q_scale.device) + + k_tensor_impl = k.tensor_impl + assert not k_tensor_impl.transposed + k_float8_data = k_tensor_impl.float8_data + k_scale = k_tensor_impl.scale + k_scale = torch.tensor([k_scale], device=k_scale.device) + + v_tensor_impl = v.tensor_impl + assert not v_tensor_impl.transposed + v_float8_data = v_tensor_impl.float8_data + v_scale = v_tensor_impl.scale + v_scale = torch.tensor([v_scale], device=v_scale.device) + + dropout_p = kwargs.get("dropout_p", None) + assert ( + dropout_p is None or dropout_p == 0.0 + ), "dropout_p should be set to 0.0 during inference" + causal = kwargs.get("causal", False) + + out, _ = flash_attn_func( + q_float8_data, + k_float8_data, + v_float8_data, + causal=causal, + window_size=(-1, -1), + descale_q=q_scale, + descale_k=k_scale, + descale_v=v_scale, + ) + + # F.scaled_dot_product_attention is using (batch_size, nheads, seqlen, headdim) + # while flash attention kernel has (batch_size, seqlen, nheads, headdim) + # so we need to transpose output to match the expected dimension + out = out.transpose(1, 2) + + return out diff --git a/torchao/dtypes/uintx/semi_sparse_layout.py b/torchao/dtypes/uintx/semi_sparse_layout.py index d832731657..a554fd9bc6 100644 --- a/torchao/dtypes/uintx/semi_sparse_layout.py +++ b/torchao/dtypes/uintx/semi_sparse_layout.py @@ -44,6 +44,7 @@ def _linear_int8_act_int8_weight_semi_structured_sparse_impl( # must pad row, col = tmp.shape from torch.sparse import SparseSemiStructuredTensorCUSPARSELT + tmp_padded = SparseSemiStructuredTensorCUSPARSELT._pad_dense_input(tmp) # we fuse one of the scalar matrix multiplications (w_scales) into the sparse mm y_dot_bf16_w_scales_fused = torch._cslt_sparse_mm( @@ -51,7 +52,7 @@ def _linear_int8_act_int8_weight_semi_structured_sparse_impl( tmp_padded.t(), alpha=w_scales.to(torch.float32), out_dtype=torch.bfloat16, - ).t()[:row, :] + ).t()[:row, :] y = (y_dot_bf16_w_scales_fused * x_scales.reshape(-1, 1)).reshape( *x_vals_int8.shape[:-1], y_dot_bf16_w_scales_fused.shape[-1] ) diff --git a/torchao/quantization/README.md b/torchao/quantization/README.md index 3fc2cb5ef0..7255340161 100644 --- a/torchao/quantization/README.md +++ b/torchao/quantization/README.md @@ -305,7 +305,9 @@ Note that the workaround will not be needed after https://github.com/pytorch/pyt Note that the workaround is also required for `torch.compile` with `freezing` (`torch._inductor.config.freezing=True`) until https://github.com/pytorch/pytorch/pull/136265 is fixed. -## Other Available Quantization Techniques +## [Prototype Features] Other Available Quantization Techniques + +Note: APIs in this section are prototype and subject to change. ### KV Cache Quantization We've added kv cache quantization and other features in order to enable long context length (and necessarily memory efficient) inference. @@ -360,6 +362,37 @@ We have kernels that do 8-bit dynamic quantization of activations and uintx grou You try can out these apis with the `quantize_` api as above alongside the constructor `int8_dynamic_activation_intx_weight`. 
 An example can be found in `torchao/_models/llama/generate.py`.
+### Float8 `scaled_dot_product_attention` Support
+We also have initial support for per-tensor float8 `scaled_dot_product_attention`, using FlashAttention 3 (optimized for H100 GPUs). To use the feature:
+
+1. Install FlashAttention 3 from source:
+https://github.com/Dao-AILab/flash-attention/tree/main?tab=readme-ov-file#flashattention-3-beta-release
+
+2. Modify the model to quantize q/k/v into per-tensor float8 tensors:
+```
+from torchao.quantization.quant_api import _float8_symmetric_per_tensor_quant
+import torch.nn.functional as F
+
+class MyModel(torch.nn.Module):
+    def forward(self, q, k, v, float8_quantize=False):
+        if float8_quantize:
+            q = q.transpose(1, 2)
+            k = k.transpose(1, 2)
+            v = v.transpose(1, 2)
+            q = _float8_symmetric_per_tensor_quant(q)
+            k = _float8_symmetric_per_tensor_quant(k)
+            v = _float8_symmetric_per_tensor_quant(v)
+        return F.scaled_dot_product_attention(q, k, v)
+
+# note: headdim (the last dimension) must be 64, 128, or 256
+q = torch.randn([64, 8, 8, 64], dtype=torch.bfloat16, device="cuda")
+k = torch.randn([64, 8, 8, 64], dtype=torch.bfloat16, device="cuda")
+v = torch.randn([64, 8, 8, 64], dtype=torch.bfloat16, device="cuda")
+```
+See `test_float8_attention` in `test/dtypes/test_affine_quantized_float.py` for the full test.
+
+We may add more attention variants in the future (per-row, per-column, per-block scaling, etc.) as well as support for arguments like `attn_mask`.
+
 ### Automatic Inductor Configuration
 The `quantize_` and `autoquant` apis now automatically use our recommended inductor configuration setings. You can mimic the same configuration settings for your own experiments by using the `torchao.quantization.utils.recommended_inductor_config_setter` to replicate our recommended configuration settings. Alternatively if you wish to disable these recommended settings, you can use the key word argument `set_inductor_config` and set it to false in the `quantize_` or `autoquant` apis to prevent assignment of those configuration settings. You can also overwrite these configuration settings after they are assigned if you so desire, as long as they are overwritten before passing any inputs to the torch.compiled model. This means that previous flows which referenced a variety of inductor configurations that needed to be set are now outdated, though continuing to manually set those same inductor configurations is unlikely to cause any issues.
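As a side note to the float8 `scaled_dot_product_attention` section above: when the head dimension is not 64, 128, or 256, the same pad-then-slice trick used in the SAM2 `transformer.py` change can be applied around quantization. Below is a minimal, untested sketch (the helper name `float8_sdpa_with_padding` is made up for illustration); it assumes bfloat16 inputs in SDPA's `(batch_size, nheads, seqlen, headdim)` layout and that the extra compute from padding is acceptable:

```
import torch
import torch.nn.functional as F

from torchao.quantization.quant_api import _float8_symmetric_per_tensor_quant


def float8_sdpa_with_padding(q, k, v):
    # q/k/v: (batch_size, nheads, seqlen, headdim) in bfloat16
    head_dim = q.shape[-1]
    orig_dtype = v.dtype
    padded = False
    if head_dim == 32:
        # pad headdim from 32 up to 64, the smallest size the FA3 kernel accepts
        q, k, v = (F.pad(t, (0, 32)) for t in (q, k, v))
        padded = True
    if q.shape[-1] in (64, 128, 256):
        # the flash attention kernel expects (batch_size, seqlen, nheads, headdim),
        # so transpose before quantizing; the dispatched impl transposes the output back
        q = _float8_symmetric_per_tensor_quant(q.transpose(1, 2))
        k = _float8_symmetric_per_tensor_quant(k.transpose(1, 2))
        v = _float8_symmetric_per_tensor_quant(v.transpose(1, 2))
    out = F.scaled_dot_product_attention(q, k, v)
    if padded:
        out = out[..., :head_dim]  # drop the padded columns again
    return out.to(orig_dtype)


q = torch.randn([64, 8, 8, 32], dtype=torch.bfloat16, device="cuda")
k = torch.randn([64, 8, 8, 32], dtype=torch.bfloat16, device="cuda")
v = torch.randn([64, 8, 8, 32], dtype=torch.bfloat16, device="cuda")
out = float8_sdpa_with_padding(q, k, v)  # (64, 8, 8, 32)
```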
diff --git a/torchao/quantization/quant_api.py b/torchao/quantization/quant_api.py index 96ccb1889c..4fc5513462 100644 --- a/torchao/quantization/quant_api.py +++ b/torchao/quantization/quant_api.py @@ -803,6 +803,36 @@ def int8_dynamic_activation_int8_semi_sparse_weight(): return int8_dynamic_activation_int8_weight(layout=SemiSparseLayout()) +def _float8_symmetric_per_token_quant( + x: torch.Tensor, dtype: torch.dtype = torch.float8_e4m3fn +): + from torchao.dtypes import to_affine_quantized_floatx + + return to_affine_quantized_floatx( + input_float=x, + block_size=_get_per_token_block_size(x), + target_dtype=dtype, + scale_dtype=None, + _layout=Float8Layout(mm_config=None), + ) + + +def _float8_symmetric_per_tensor_quant( + x: torch.Tensor, + dtype: torch.dtype = torch.float8_e4m3fn, + mm_config: Optional[Float8MMConfig] = None, +): + from torchao.dtypes import to_affine_quantized_floatx + + return to_affine_quantized_floatx( + input_float=x, + block_size=tuple(x.shape), + target_dtype=dtype, + scale_dtype=torch.float32, + _layout=Float8Layout(mm_config=mm_config), + ) + + def float8_weight_only(weight_dtype: torch.dtype = torch.float8_e4m3fn): """ Applies float8 weight-only symmetric per-channel quantization to linear layers. @@ -814,17 +844,9 @@ def float8_weight_only(weight_dtype: torch.dtype = torch.float8_e4m3fn): The actual matmul will be computed in original precision of the weight tensor. """ - from torchao.dtypes import to_affine_quantized_floatx def apply_float8wo_quant(weight): - block_size = (1, weight.shape[1]) - return to_affine_quantized_floatx( - input_float=weight, - block_size=block_size, - target_dtype=weight_dtype, - scale_dtype=None, - _layout=Float8Layout(mm_config=None), - ) + return _float8_symmetric_per_token_quant(weight, weight_dtype) return _get_linear_subclass_inserter(apply_float8wo_quant) @@ -1172,5 +1194,6 @@ def apply_quant_llm(weight: torch.Tensor) -> torch.Tensor: _int8_asymm_per_token_quant, _int8_symm_per_token_reduced_range_quant, _input_activation_quant_func_fp8, + _float8_symmetric_per_token_quant, ] )
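To make the difference between the two helpers added to `quant_api.py` concrete, here is a small illustrative sketch (shapes and variable names are arbitrary): `_float8_symmetric_per_token_quant` keeps one scale per last-dimension slice, while `_float8_symmetric_per_tensor_quant` keeps a single scale for the whole tensor, which is exactly the `t.shape == t.block_size` condition that `_sdpa_float8_check` looks for.

```
import torch

from torchao.quantization.quant_api import (
    _float8_symmetric_per_tensor_quant,
    _float8_symmetric_per_token_quant,
)

x = torch.randn(64, 8, 8, 64, dtype=torch.bfloat16, device="cuda")

per_tensor = _float8_symmetric_per_tensor_quant(x)
# a single scale for the whole tensor: block_size equals the tensor shape,
# which is what the float8 sdpa dispatch checks for
assert tuple(per_tensor.block_size) == tuple(x.shape)

per_token = _float8_symmetric_per_token_quant(x)
# one scale per token: block_size is expected to be [1, 1, 1, 64] here
print(per_token.block_size)

# both round-trip back to high precision via dequantize()
max_err = (per_tensor.dequantize() - x).abs().max()
```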