W8A16 #10

Open · wants to merge 4 commits into base: main
12 changes: 8 additions & 4 deletions benchmark/ffn.py
@@ -45,7 +45,7 @@ def forward(self, x):
 
 
 def run_benchmark(
-    use_q, d_model, dim_feedforward, batch_size, seq_len, minimize_error=True
+    use_q, d_model, dim_feedforward, batch_size, seq_len, minimize_error=True, use_w8a16=True,
 ):
     inp = torch.randn(batch_size, seq_len, d_model)
     inp = inp.half().cuda()
@@ -58,9 +58,13 @@ def run_benchmark(
     ffn = ffn.half().cuda().eval()
     fp16_ref = ffn(inp).detach().clone().float()
     if use_q:
-        ffn.linear1 = protoquant.qlinear_from_linear(ffn.linear1, minimize_error)
-        ffn.linear2 = protoquant.qlinear_from_linear(ffn.linear2, minimize_error)
-        ffn = torch.compile(ffn, options={"max-autotune": True})
+        if use_w8a16:
+            ffn.linear1 = protoquant.w8a16_qlinear_from_linear(ffn.linear1, minimize_error)
+            ffn.linear2 = protoquant.w8a16_qlinear_from_linear(ffn.linear2, minimize_error)
+        else:
+            ffn.linear1 = protoquant.qlinear_from_linear(ffn.linear1, minimize_error)
+            ffn.linear2 = protoquant.qlinear_from_linear(ffn.linear2, minimize_error)
+        ffn = torch.compile(ffn, options={"max-autotune": True})
     fp8_ref = ffn(inp).detach().clone().float()
     torch.testing.assert_close(fp16_ref, fp8_ref, atol=3e-2, rtol=3e-2)
     return benchmark_torch_function_in_microseconds(ffn, inp)
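For context, a minimal sketch of how the new flag might be exercised from a driver script. The shapes below are illustrative and not taken from the benchmark's own entry point; a CUDA device with fp16 support is assumed.

    # Hypothetical invocation; run_benchmark reports the time per forward call.
    w8a16_us = run_benchmark(
        use_q=True, d_model=1024, dim_feedforward=4096,
        batch_size=8, seq_len=256, use_w8a16=True,
    )
    existing_us = run_benchmark(
        use_q=True, d_model=1024, dim_feedforward=4096,
        batch_size=8, seq_len=256, use_w8a16=False,
    )
    print(f"W8A16: {w8a16_us:.1f} us  existing qlinear: {existing_us:.1f} us")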
3 changes: 2 additions & 1 deletion protoquant/__init__.py
@@ -4,9 +4,10 @@
 
 from .gemm import gemm, pad
 from .qlinear import qlinear_from_linear
+from .w8a16linear import w8a16_qlinear_from_linear
 from .qt import QTensor
 from .quantization import dqntz, qntz
 
 _load_library()
 
-__all__ = ["QTensor", "gemm", "pad", "qntz", "dqntz", "qlinear_from_linear"]
+__all__ = ["QTensor", "gemm", "pad", "qntz", "dqntz", "qlinear_from_linear", "w8a16_qlinear_from_linear"]
74 changes: 74 additions & 0 deletions protoquant/w8a16linear.py
@@ -0,0 +1,74 @@
import torch
from torch.nn.parameter import Parameter
from typing import Optional, Tuple


class W8A16QLinear(torch.nn.Module):
    """Linear layer with int8 weights (per-output-channel scales) and fp16 activations."""

    def __init__(self, qweight, qscales, bias):
        super(W8A16QLinear, self).__init__()
        assert isinstance(bias, Parameter)
        self.qweight = qweight
        self.qscales = qscales
        self.bias = bias
        # qweight keeps the nn.Linear.weight layout: (out_features, in_features).
        self.in_features = qweight.size(1)
        self.out_features = qweight.size(0)

    def forward(self, inp: torch.Tensor) -> torch.Tensor:
        assert inp.dim() == 3
        assert inp.dtype == torch.float16
        # Dequantize the weight to fp16 and fall back to a regular fp16 matmul;
        # the fp16 bias is applied in the same call.
        return torch.nn.functional.linear(
            inp, self.qweight.mul(self.qscales).to(torch.float16), self.bias
        )

    def extra_repr(self) -> str:
        return "in_features={}, out_features={}, bias={}".format(
            self.in_features, self.out_features, self.bias is not None
        )


# Symmetric int8 uses the half range [-127, 127]; the full-range (256-level)
# tables are kept alongside for completeness.
half_range_lookup = {
    8: torch.full((1,), (1 << (8 - 1)) - 1, dtype=torch.float16, device="cuda"),
}
full_range_lookup = {
    8: torch.full((1,), 1 << 8, dtype=torch.float16, device="cuda"),
}

inv_half_range_lookup = {
    8: torch.full((1,), 1 / ((1 << (8 - 1)) - 1), dtype=torch.float16, device="cuda"),
}
inv_full_range_lookup = {
    8: torch.full((1,), 1 / (1 << 8), dtype=torch.float16, device="cuda"),
}


def scales_from_point(
    input: torch.Tensor, dim: Optional[int], qdtype: torch.dtype
) -> Tuple[torch.Tensor, torch.Tensor]:
    # Symmetric (zero-point-free) scaling: scale = absmax / 127 for int8,
    # computed over the whole tensor (dim is None) or along the reduced dim.
    input_abs = torch.abs(input)
    if dim is None:
        input_abs_maxs = torch.max(input_abs)
    else:
        input_abs_maxs = torch.max(input_abs, dim, keepdim=True).values
    scales = torch.mul(input_abs_maxs, inv_half_range_lookup[torch.iinfo(qdtype).bits]).to(torch.float32)
    inv_scales = torch.div(half_range_lookup[torch.iinfo(qdtype).bits], input_abs_maxs).to(torch.float32)
    return scales, inv_scales


def per_channel_scaled(input: torch.Tensor, qdtype: torch.dtype) -> Tuple[torch.Tensor, torch.Tensor]:
    # One scale per column (dim 0 is reduced).
    scales, inv_scales = scales_from_point(input, 0, qdtype)
    qinput = torch.mul(input, inv_scales).to(qdtype)
    return qinput, scales


def per_token_scaled(input: torch.Tensor, qdtype: torch.dtype) -> Tuple[torch.Tensor, torch.Tensor]:
    # One scale per row (dim 1 is reduced); for an nn.Linear weight this is
    # one scale per output channel.
    scales, inv_scales = scales_from_point(input, 1, qdtype)
    qinput = torch.mul(input, inv_scales).to(qdtype)
    return qinput, scales


def w8a16_qlinear_from_linear(
    linear: torch.nn.Module, minimize_error=True
) -> torch.nn.Module:
    # minimize_error is accepted for API parity with qlinear_from_linear but is
    # not used by the W8A16 path.
    assert isinstance(linear, torch.nn.Linear)
    assert linear.weight.dtype == torch.float16
    assert linear.bias.dtype == torch.float16
    qweight, qscales = per_token_scaled(linear.weight, torch.int8)
    return W8A16QLinear(qweight, qscales, linear.bias)
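A minimal usage sketch of the new module. The layer sizes are illustrative, a CUDA device is assumed (the range lookup tables are allocated on "cuda"), and the tolerance mirrors the one used in benchmark/ffn.py.

    import torch
    from protoquant import w8a16_qlinear_from_linear

    lin = torch.nn.Linear(512, 2048).half().cuda().eval()
    x = torch.randn(4, 128, 512, dtype=torch.float16, device="cuda")

    ref = lin(x)                          # fp16 reference
    qlin = w8a16_qlinear_from_linear(lin) # int8 weights, per-output-channel scales
    out = qlin(x)                         # fp16 activations throughout

    torch.testing.assert_close(ref, out, atol=3e-2, rtol=3e-2)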