
Enable float8 attention support (q/k/v) #1382

Open · wants to merge 11 commits into base: main
Changes from 5 commits
2 changes: 2 additions & 0 deletions dev-requirements.txt
@@ -14,6 +14,8 @@ matplotlib
pandas
fire # QOL for commandline scripts
tabulate # QOL for printing tables to stdout
einops # for testing flash attention 3


# Custom CUDA Extensions
ninja
177 changes: 177 additions & 0 deletions test/dtypes/test_affine_quantized_float.py
@@ -9,6 +9,7 @@

import copy
import io
import math
import random
import unittest
from contextlib import nullcontext
@@ -17,6 +18,7 @@

import pytest
import torch
from einops import rearrange, repeat
from torch._inductor.test_case import TestCase as InductorTestCase
from torch.testing._internal import common_utils

@@ -288,7 +290,182 @@ def test_fp8_weight_dimension_warning(self):
)


# copied from https://github.com/Dao-AILab/flash-attention/blob/1feb711f46563960fc10a8e659c93c300619504b/tests/test_util.py#L185


def construct_local_mask(
seqlen_q,
seqlen_k,
window_size=(-1, -1), # -1 means infinite window size
query_padding_mask=None,
key_padding_mask=None,
device=None,
key_leftpad=None,
):
row_idx = rearrange(
torch.arange(seqlen_q, device=device, dtype=torch.long), "s -> s 1"
)
col_idx = torch.arange(seqlen_k, device=device, dtype=torch.long)
if key_leftpad is not None:
key_leftpad = rearrange(key_leftpad, "b -> b 1 1 1")
col_idx = repeat(col_idx, "s -> b 1 1 s", b=key_leftpad.shape[0])
col_idx = torch.where(col_idx >= key_leftpad, col_idx - key_leftpad, 2**32)
sk = (
seqlen_k
if key_padding_mask is None
else rearrange(key_padding_mask.sum(-1), "b -> b 1 1 1")
)
sq = (
seqlen_q
if query_padding_mask is None
else rearrange(query_padding_mask.sum(-1), "b -> b 1 1 1")
)
if window_size[0] < 0:
return col_idx > row_idx + sk - sq + window_size[1]
else:
sk = torch.full_like(col_idx, seqlen_k) if key_padding_mask is None else sk
return torch.logical_or(
col_idx > torch.minimum(row_idx + sk - sq + window_size[1], sk),
col_idx < row_idx + sk - sq - window_size[0],
)
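As a quick illustration (my own, not part of the PR): with window_size=(-1, 0), construct_local_mask reproduces a standard causal mask, where True marks key positions that must be ignored.

import torch

mask = construct_local_mask(seqlen_q=4, seqlen_k=4, window_size=(-1, 0), device="cpu")
# mask[t, s] is True where key position s is ahead of query position t (s > t):
# tensor([[False,  True,  True,  True],
#         [False, False,  True,  True],
#         [False, False, False,  True],
#         [False, False, False, False]])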


def attention_ref(
q,
k,
v,
query_padding_mask=None,
key_padding_mask=None,
attn_bias=None,
dropout_p=0.0,
dropout_mask=None,
causal=False,
window_size=(-1, -1), # -1 means infinite window size
softcap=0.0,
upcast=True,
reorder_ops=False,
key_leftpad=None,
):
"""
Arguments:
q: (batch_size, seqlen_q, nheads, head_dim)
k: (batch_size, seqlen_k, nheads_k, head_dim)
v: (batch_size, seqlen_k, nheads_k, head_dim)
query_padding_mask: (batch_size, seqlen_q)
key_padding_mask: (batch_size, seqlen_k)
attn_bias: broadcastable to (batch_size, nheads, seqlen_q, seqlen_k)
dropout_p: float
dropout_mask: (batch_size, nheads, seqlen_q, seqlen_k)
causal: whether to apply causal masking
window_size: (int, int), left and right window size
upcast: whether to cast all inputs to fp32, do all computation in fp32, then cast
output back to fp16/bf16.
reorder_ops: whether to change the order of operations (scaling k instead of scaling q, etc.)
without changing the math. This is to estimate the numerical error from operation
reordering.
Output:
output: (batch_size, seqlen_q, nheads, head_dim)
attention: (batch_size, nheads, seqlen_q, seqlen_k), softmax after dropout
"""
if causal:
window_size = (window_size[0], 0)
dtype_og = q.dtype
if upcast:
q, k, v = q.float(), k.float(), v.float()
seqlen_q, seqlen_k = q.shape[1], k.shape[1]
k = repeat(k, "b s h d -> b s (h g) d", g=q.shape[2] // k.shape[2])
v = repeat(v, "b s h d -> b s (h g) d", g=q.shape[2] // v.shape[2])
d = q.shape[-1]
if not reorder_ops:
scores = torch.einsum("bthd,bshd->bhts", q / math.sqrt(d), k)
else:
scores = torch.einsum("bthd,bshd->bhts", q, k / math.sqrt(d))
if softcap > 0:
scores /= softcap
scores = scores.tanh()
scores *= softcap
if key_padding_mask is not None:
scores.masked_fill_(
rearrange(~key_padding_mask, "b s -> b 1 1 s"), float("-inf")
)
if window_size[0] >= 0 or window_size[1] >= 0:
local_mask = construct_local_mask(
seqlen_q,
seqlen_k,
window_size,
query_padding_mask,
key_padding_mask,
q.device,
key_leftpad=key_leftpad,
)
scores.masked_fill_(local_mask, float("-inf"))
if attn_bias is not None:
scores = scores + attn_bias
attention = torch.softmax(scores, dim=-1).to(v.dtype)
# Some rows might be completely masked out so we fill them with zero instead of NaN
if window_size[0] >= 0 or window_size[1] >= 0:
attention = attention.masked_fill(
torch.all(local_mask, dim=-1, keepdim=True), 0.0
)
# We want to mask here so that the attention matrix doesn't have any NaNs
# Otherwise we'll get NaN in dV
if query_padding_mask is not None:
attention = attention.masked_fill(
rearrange(~query_padding_mask, "b s -> b 1 s 1"), 0.0
)
dropout_scaling = 1.0 / (1 - dropout_p)
# attention_drop = attention.masked_fill(~dropout_mask, 0.0) * dropout_scaling
# output = torch.einsum('bhts,bshd->bthd', attention_drop , v)
if dropout_mask is not None:
attention_drop = attention.masked_fill(~dropout_mask, 0.0)
else:
attention_drop = attention
output = torch.einsum("bhts,bshd->bthd", attention_drop, v * dropout_scaling)
if query_padding_mask is not None:
output.masked_fill_(rearrange(~query_padding_mask, "b s -> b s 1 1"), 0.0)
if key_padding_mask is not None:
output.masked_fill_(
rearrange(
torch.logical_not(torch.any(key_padding_mask, 1)), "b -> b 1 1 1"
),
0.0,
)
return output.to(dtype=dtype_og), attention.to(dtype=dtype_og)
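For orientation, attention_ref follows flash-attention's (batch, seqlen, nheads, head_dim) layout, while F.scaled_dot_product_attention expects (batch, nheads, seqlen, head_dim). A minimal comparison sketch (my own illustration, not part of this PR), assuming no padding, dropout, windowing, or bias:

import torch
import torch.nn.functional as F

def _compare_ref_against_sdpa():
    b, s, h, d = 2, 16, 4, 64
    q, k, v = (torch.randn(b, s, h, d) for _ in range(3))
    out_ref, _ = attention_ref(q, k, v)  # (b, s, h, d)
    # SDPA wants (b, h, s, d); transpose into that layout and back
    out_sdpa = F.scaled_dot_product_attention(
        q.transpose(1, 2), k.transpose(1, 2), v.transpose(1, 2)
    ).transpose(1, 2)
    assert torch.allclose(out_ref, out_sdpa, atol=1e-5)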


class TestAffineQuantizedFloat8Attention(common_utils.TestCase):
@unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available")
def test_float8_attention(self):
import torch.nn.functional as F

from torchao.quantization.quant_api import _float8_symmetric_per_tensor_quant

class MyModel(torch.nn.Module):
def forward(self, q, k, v, float8_quantize=False):
if float8_quantize:
q = _float8_symmetric_per_tensor_quant(q)
k = _float8_symmetric_per_tensor_quant(k)
v = _float8_symmetric_per_tensor_quant(v)
return F.scaled_dot_product_attention(q, k, v)

# note: head_dim (the last dimension) must be one of 64, 128, or 256
q = torch.randn([64, 8, 8, 64], dtype=torch.bfloat16, device="cuda")
k = torch.randn([64, 8, 8, 64], dtype=torch.bfloat16, device="cuda")
v = torch.randn([64, 8, 8, 64], dtype=torch.bfloat16, device="cuda")

m = MyModel().eval()
# the float8 result differs a lot from the non-quantized SDPA output
# sqnr = -2.5
# ref = m(q, k, v)

# but it matches the reference attention implementation from the flash-attention repo
ref = attention_ref(q, k, v)[0]
quantized = m(q, k, v, True)
assert compute_error(ref, quantized) > 25.0
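compute_error here is torchao's SQNR helper (imported earlier in this test file, outside the visible hunk). To make the 25 dB threshold concrete, a sketch of the usual SQNR-in-dB definition (an assumption about the helper, not a copy of it):

import torch

def _sqnr_db(ref: torch.Tensor, actual: torch.Tensor) -> torch.Tensor:
    # signal-to-quantization-noise ratio in dB: higher is better; 25 dB
    # means the error norm is roughly 18x smaller than the signal norm
    return 20 * torch.log10(torch.linalg.norm(ref) / torch.linalg.norm(ref - actual))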


common_utils.instantiate_parametrized_tests(TestAffineQuantizedFloat8Compile)

if __name__ == "__main__":
pytest.main([__file__])
common_utils.run_tests()
40 changes: 28 additions & 12 deletions torchao/_models/llama/model.py
@@ -13,6 +13,8 @@
from torch.nn import functional as F
from torchao.utils import find_multiple

_QUANTIZE_ATTN = True

# TODO remove superfluous arg
def prepare_inputs_for_model(inps, max_new_tokens=1):
# this is because input from lm-eval is 2d
@@ -85,7 +87,7 @@ def from_name(cls, name: str):
),
}

# this is a model specific variable that controls whether index_put is used for the kv_cache update,
# it is needed for GPTQ but otherwise attenuates perf so the default is to not use it
use_index_put_for_kv_cache = False

@@ -124,7 +126,7 @@ def __init__(self, max_batch_size, max_seq_length, n_heads, head_dim, scale_dtyp
self.register_buffer('v_cache', torch.zeros(cache_shape, dtype=torch.int8))
self.register_buffer('k_cache_scale', torch.ones(scale_shape, dtype=scale_dtype))
self.register_buffer('v_cache_scale', torch.ones(scale_shape, dtype=scale_dtype))

def update(self, input_pos, k_val, v_val):
# quantize current k_val and store it in the cache
q_k_val, k_scale = quantize_activation_per_token_absmax(k_val)
@@ -138,7 +140,7 @@ def update(self, input_pos, k_val, v_val):
self.v_cache_scale[:, :, input_pos] = v_scale.unsqueeze(-1)
v_out = self.v_cache*self.v_cache_scale
v_out[:, :, input_pos] = v_val

return k_out, v_out
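For context, quantize_activation_per_token_absmax above is torchao's per-token int8 quantizer, which is why the cache stores int8 values plus one scale per token. A rough sketch of that scheme (an illustration, not the torchao implementation):

import torch

def _per_token_absmax_int8(x: torch.Tensor):
    # one scale per token (per row of the last dim): map absmax to 127
    scale = x.abs().amax(dim=-1, keepdim=True).clamp(min=1e-12) / 127.0
    x_int8 = torch.round(x / scale).clamp(-127, 127).to(torch.int8)
    return x_int8, scale  # dequantize with x_int8.to(x.dtype) * scale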

@classmethod
@@ -194,16 +196,16 @@ def setup_caches(self, max_batch_size, max_seq_length, training: bool=False, kv_
else:
b.attention.kv_cache = KVCache(max_batch_size, max_seq_length, self.config.n_local_heads, head_dim, dtype)
self.freqs_cis = precompute_freqs_cis(
self.config.block_size,
self.config.dim // self.config.n_head,
self.config.rope_base,
dtype,
use_scaled=self.config.use_scaled_rope
)

def reset_caches(self):
"""Reset caches.

The caches used by training stage and inference stage may be different, reset them before switching.
"""
self.max_batch_size = -1
@@ -215,7 +217,7 @@ def forward(self, idx: Tensor, input_pos: Optional[Tensor] = None) -> Tensor:
"""Forward pass of the model.

Args:
idx (`torch.LongTensor` of shape `(batch_size, seq_length)`):
Indices of input sequence tokens in the vocabulary.
input_pos (`torch.LongTensor` of shape `(batch_size, seq_length)`, *optional*):
Indices of positions of each input sequence tokens in the position embeddings.
@@ -227,7 +229,7 @@
"""
assert self.freqs_cis is not None, "Caches must be initialized first"

if input_pos is None:
mask = None
freqs_cis = self.freqs_cis[:idx.shape[1]]
else:
@@ -311,11 +313,25 @@ def forward(self, x: Tensor, freqs_cis: Tensor, mask: Optional[Tensor], input_po

k = k.repeat_interleave(self.n_head // self.n_local_heads, dim=1)
v = v.repeat_interleave(self.n_head // self.n_local_heads, dim=1)

# quantize q/k/v with per tensor float8 quantization
padded = False
if _QUANTIZE_ATTN:
from torchao.quantization.quant_api import _float8_symmetric_per_tensor_quant
original_dtype = v.dtype
if q.shape[-1] in [64, 128, 256]:
q = _float8_symmetric_per_tensor_quant(q)
Contributor: We also likely need/want to apply the Hadamard transform. I don't remember offhand if this is included in the FA3 API.

Contributor Author: didn't see this, maybe we can add it after spinquant is integrated

@bhack (Dec 8, 2024), quoting the above: https://pytorch.org/blog/hadacore/
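As background for the thread above: rotating q and k by the same orthogonal matrix (a normalized Hadamard matrix is a cheap choice) leaves q @ k^T unchanged while spreading activation outliers across the head dimension, which tends to help low-bit quantization. A rough sketch of the idea (not part of this PR, and unrelated to the fused Hadacore kernels linked above):

import torch

def _hadamard(d: int, device, dtype):
    # Sylvester construction; d must be a power of two
    H = torch.ones(1, 1, device=device, dtype=dtype)
    while H.shape[0] < d:
        H = torch.cat([torch.cat([H, H], dim=1), torch.cat([H, -H], dim=1)], dim=0)
    return H / (d ** 0.5)  # orthogonal, so (q @ H) @ (k @ H).T == q @ k.T

def _rotate_qk(q, k):
    H = _hadamard(q.shape[-1], q.device, q.dtype)
    return q @ H, k @ H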

k = _float8_symmetric_per_tensor_quant(k)
v = _float8_symmetric_per_tensor_quant(v)

if mask is not None:
y = F.scaled_dot_product_attention(q, k, v, attn_mask=mask, dropout_p=0.0)
else:
y = F.scaled_dot_product_attention(q, k, v, dropout_p=0.0, is_causal=True)

if _QUANTIZE_ATTN:
y = y.to(original_dtype)

y = y.transpose(1, 2).contiguous().view(bsz, seqlen, self.dim)

y = self.wo(y)
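For readers unfamiliar with _float8_symmetric_per_tensor_quant (imported from torchao.quantization.quant_api above), per-tensor symmetric float8 quantization conceptually computes a single scale from the tensor's absolute maximum and casts to a float8 dtype. A simplified sketch of that idea (an illustration only; the real helper returns a quantized tensor object that scaled_dot_product_attention can consume directly, which this sketch does not reproduce):

import torch

def _float8_per_tensor_quant_sketch(x, float8_dtype=torch.float8_e4m3fn):
    # one scale for the whole tensor: map the observed absmax onto the
    # representable float8 range, then cast
    fp8_max = torch.finfo(float8_dtype).max
    scale = x.abs().amax().float().clamp(min=1e-12) / fp8_max
    x_fp8 = (x / scale).clamp(-fp8_max, fp8_max).to(float8_dtype)
    return x_fp8, scale  # dequantize with x_fp8.to(x.dtype) * scale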
@@ -371,8 +387,8 @@ def apply_scaling(freqs: torch.Tensor):
return torch.tensor(new_freqs, dtype=freqs.dtype, device=freqs.device)

def precompute_freqs_cis(
seq_len: int,
n_elem: int,
base: int = 10000,
dtype: torch.dtype = torch.bfloat16,
use_scaled: bool=False
25 changes: 25 additions & 0 deletions torchao/_models/sam2/modeling/sam/transformer.py
Review comment: Can we update the current SAM2 readme with all the ao optimizations we have introduced?

@@ -24,6 +24,9 @@
# A fallback setting to allow all available kernels if Flash Attention fails
ALLOW_ALL_KERNELS = False

# whether to turn on float8 quantization for sdpa or not
_QUANTIZE_ATTN = False


def sdp_kernel_context(dropout_p):
"""
@@ -263,6 +266,25 @@ def forward(self, q: Tensor, k: Tensor, v: Tensor) -> Tensor:
k = self._separate_heads(k, self.num_heads)
v = self._separate_heads(v, self.num_heads)

# quantize q/k/v with per tensor float8 quantization
padded = False
if _QUANTIZE_ATTN:
from torchao.quantization.quant_api import _float8_symmetric_per_tensor_quant
original_head_dim = list(q.shape)[-1]
original_dtype = v.dtype
padded = False
# pad head_dim from 32 to 64 so it matches a supported float8 attention head size (64/128/256):
if q.shape[-1] == 32:
q = F.pad(q, (0, 32))
k = F.pad(k, (0, 32))
v = F.pad(v, (0, 32))
padded = True

if q.shape[-1] in [64, 128, 256]:
q = _float8_symmetric_per_tensor_quant(q)
k = _float8_symmetric_per_tensor_quant(k)
v = _float8_symmetric_per_tensor_quant(v)

dropout_p = self.dropout_p if self.training else 0.0
# # Attention
# try:
Expand All @@ -281,6 +303,9 @@ def forward(self, q: Tensor, k: Tensor, v: Tensor) -> Tensor:
# out = F.scaled_dot_product_attention(q, k, v, dropout_p=dropout_p)
# TODO: This scale should not be needed. But without it compile causes a NaN.
out = F.scaled_dot_product_attention(q, k, v, dropout_p=dropout_p, scale=(1.0 / math.sqrt(q.size(-1))))
if _QUANTIZE_ATTN and padded:
out = out[:, :, :, :original_head_dim]
out = out.to(v.dtype)

out = self._recombine_heads(out)
out = self.out_proj(out)
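A note on the padding path above: zero-padding q, k, and v along head_dim does not change the attention output, because the padded query/key channels contribute nothing to q @ k^T and the padded value channels only produce output channels that are sliced off afterwards, as long as the softmax scale stays at 1/sqrt(original head_dim). A small self-contained check of that property (a sketch, independent of this PR's float8 path):

import math
import torch
import torch.nn.functional as F

def _padding_preserves_attention():
    q, k, v = (torch.randn(2, 4, 16, 32) for _ in range(3))  # head_dim 32
    ref = F.scaled_dot_product_attention(q, k, v)
    qp, kp, vp = (F.pad(t, (0, 32)) for t in (q, k, v))  # head_dim 32 -> 64
    # keep the original 1/sqrt(32) scale so the softmax logits are unchanged
    out = F.scaled_dot_product_attention(qp, kp, vp, scale=1.0 / math.sqrt(32))
    assert torch.allclose(ref, out[..., :32], atol=1e-6)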
4 changes: 2 additions & 2 deletions torchao/dtypes/affine_quantized_tensor.py
@@ -510,9 +510,9 @@ def from_hp_to_intx(
)


###############################################
# Layout and TensorImpl Subclass Registration #
###############################################
register_layout = AffineQuantizedTensor.register_layout
get_tensor_impl_constructor = AffineQuantizedTensor.get_tensor_impl_constructor
