
Commit

Update
[ghstack-poisoned]
danielvegamyhre committed Jan 8, 2025
2 parents 778901d + a356ac5 commit 54a3213
Showing 31 changed files with 1,260 additions and 745 deletions.

benchmarks/bench_galore_fused_kernels.py (2 changes: 0 additions & 2 deletions)
@@ -8,9 +8,7 @@
 def run(args):
     dtype = getattr(torch, args.dtype)
     allow_tf32 = args.allow_tf32
-    fp8_fast_accum = False
     torch.backends.cuda.matmul.allow_tf32 = allow_tf32
-    kernel = args.kernel
     M, N = args.M, args.N
     rank = args.rank

benchmarks/benchmark_aq.py (118 changes: 83 additions & 35 deletions)
@@ -1,23 +1,26 @@
"""Benchmarks for affine quantized tensor, this includes int8 dynamic quant, int8 weight only quant and int4 weight only quant APIs
"""
"""Benchmarks for affine quantized tensor, this includes int8 dynamic quant, int8 weight only quant and int4 weight only quant APIs"""

import copy

import torch

from torchao.quantization.quant_api import (
_replace_with_custom_fn_if_matches_filter,
int4_weight_only,
int8_dynamic_activation_int8_weight,
int8_weight_only,
quantize_,
)
from torchao.quantization.subclass import (
Int8WeightOnlyQuantizedLinearWeight,
Int4WeightOnlyQuantizedLinearWeight,
Int8WeightOnlyQuantizedLinearWeight,
)
from torchao.utils import (
TORCH_VERSION_AT_LEAST_2_4,
TORCH_VERSION_AT_LEAST_2_5,
unwrap_tensor_subclass,
)
from torchao.quantization.quant_api import (
int4_weight_only,
int8_weight_only,
int8_dynamic_activation_int8_weight,
quantize_,
_replace_with_custom_fn_if_matches_filter,
)
import copy
from torchao.utils import unwrap_tensor_subclass


def _int8wo_api(mod, **kwargs):
if TORCH_VERSION_AT_LEAST_2_4:
@@ -27,14 +30,20 @@ def _int8wo_api(mod, **kwargs):
     else:
         change_linear_weights_to_int8_woqtensors(mod, **kwargs)
 
+
 def _int8da_int8w_api(mod, **kwargs):
     if TORCH_VERSION_AT_LEAST_2_4:
-        quantize_(mod, int8_dynamic_activation_int8_weight(**kwargs), set_inductor_config=False)
+        quantize_(
+            mod,
+            int8_dynamic_activation_int8_weight(**kwargs),
+            set_inductor_config=False,
+        )
         if not TORCH_VERSION_AT_LEAST_2_5:
             unwrap_tensor_subclass(mod)
     else:
         change_linear_weights_to_int8_dqtensors(mod, **kwargs)
 
+
 def _int4wo_api(mod, **kwargs):
     if TORCH_VERSION_AT_LEAST_2_4:
         kwargs_copy = kwargs.copy()
@@ -47,31 +56,43 @@ def _int4wo_api(mod, **kwargs):
     else:
         change_linear_weights_to_int4_woqtensors(mod, **kwargs)
 
+
 class ToyLinearModel(torch.nn.Module):
-    """Single linear for m * k * n problem size
-    """
-    def __init__(self, m=64, n=32, k=64, has_bias=False, dtype=torch.float, device="cuda"):
+    """Single linear for m * k * n problem size"""
+
+    def __init__(
+        self, m=64, n=32, k=64, has_bias=False, dtype=torch.float, device="cuda"
+    ):
         super().__init__()
         self.m = m
         self.dtype = dtype
         self.device = device
-        self.linear = torch.nn.Linear(k, n, bias=has_bias).to(dtype=self.dtype, device=self.device)
+        self.linear = torch.nn.Linear(k, n, bias=has_bias).to(
+            dtype=self.dtype, device=self.device
+        )
 
     def example_inputs(self):
-        return (torch.randn(self.m, self.linear.in_features, dtype=self.dtype, device=self.device),)
+        return (
+            torch.randn(
+                self.m, self.linear.in_features, dtype=self.dtype, device=self.device
+            ),
+        )
 
     def forward(self, x):
         x = self.linear(x)
         return x
 
+
 def _ref_change_linear_weights_to_int8_dqtensors(model, filter_fn=None, **kwargs):
     """
     The deprecated implementation for int8 dynamic quant API, used as a reference for
     numerics and performance
     """
-    from torchao.quantization.quant_api import _in_features_greater_than_16
-    from torchao.quantization.quant_api import _is_linear
-    from torchao.quantization.quant_api import _get_subclass_inserter
+    from torchao.quantization.quant_api import (
+        _get_subclass_inserter,
+        _in_features_greater_than_16,
+        _is_linear,
+    )
     from torchao.quantization.subclass import Int8DynamicallyQuantizedLinearWeight
 
     if filter_fn is None:
@@ -80,40 +101,54 @@ def _ref_change_linear_weights_to_int8_dqtensors(model, filter_fn=None, **kwargs
         )
 
     _replace_with_custom_fn_if_matches_filter(
-        model, _get_subclass_inserter(Int8DynamicallyQuantizedLinearWeight, enable_parametrization=False, **kwargs), filter_fn
+        model,
+        _get_subclass_inserter(
+            Int8DynamicallyQuantizedLinearWeight, enable_parametrization=False, **kwargs
+        ),
+        filter_fn,
     )
 
+
 def _get_ref_change_linear_weights_to_woqtensors(deprecated_tenosr_subclass):
     def _ref_change_linear_weights_to_woqtensors(model, filter_fn=None, **kwargs):
         """
         The deprecated implementation for weight only quant API, used as a reference for
         numerics and performance
         """
-        from torchao.quantization.quant_api import _is_linear
-        from torchao.quantization.quant_api import _get_subclass_inserter
+        from torchao.quantization.quant_api import _get_subclass_inserter, _is_linear
 
         filter_fn = kwargs.pop("filter_fn", _is_linear)
 
         _replace_with_custom_fn_if_matches_filter(
             model,
-            _get_subclass_inserter(deprecated_tenosr_subclass, enable_parametrization=True, **kwargs),
+            _get_subclass_inserter(
+                deprecated_tenosr_subclass, enable_parametrization=True, **kwargs
+            ),
             filter_fn,
         )
 
     return _ref_change_linear_weights_to_woqtensors
 
-_ref_change_linear_weights_to_int8_woqtensors = _get_ref_change_linear_weights_to_woqtensors(Int8WeightOnlyQuantizedLinearWeight)
-_ref_change_linear_weights_to_int4_woqtensors = _get_ref_change_linear_weights_to_woqtensors(Int4WeightOnlyQuantizedLinearWeight)
+
+_ref_change_linear_weights_to_int8_woqtensors = (
+    _get_ref_change_linear_weights_to_woqtensors(Int8WeightOnlyQuantizedLinearWeight)
+)
+_ref_change_linear_weights_to_int4_woqtensors = (
+    _get_ref_change_linear_weights_to_woqtensors(Int4WeightOnlyQuantizedLinearWeight)
+)
 
 
 torch._dynamo.config.cache_size_limit = 50000
 
+
 @torch.no_grad
 def _bench_quantized_tensor_subclass_perf(api, ref_api, M, N, K, kwargs=None):
     if kwargs is None:
         kwargs = {}
 
-    m = ToyLinearModel(M, N, K, has_bias=True, dtype=torch.bfloat16, device="cuda").eval()
+    m = ToyLinearModel(
+        M, N, K, has_bias=True, dtype=torch.bfloat16, device="cuda"
+    ).eval()
     m_bf16 = copy.deepcopy(m)
     m_ref = copy.deepcopy(m)
     example_inputs = m.example_inputs()
@@ -130,26 +165,30 @@ def _bench_quantized_tensor_subclass_perf(api, ref_api, M, N, K, kwargs=None):

     # perf comparison
     from torchao.utils import benchmark_model
+
     # warmup
     WARMUP = 20
     RUNS = 100
 
     torch._dynamo.reset()
-    m_ref = torch.compile(m_ref, mode='max-autotune', fullgraph=True)
+    m_ref = torch.compile(m_ref, mode="max-autotune", fullgraph=True)
     benchmark_model(m_ref, WARMUP, example_inputs)
     ref_elapsed_time = benchmark_model(m_ref, RUNS, example_inputs)
 
     torch._dynamo.reset()
-    m = torch.compile(m, mode='max-autotune', fullgraph=True)
+    m = torch.compile(m, mode="max-autotune", fullgraph=True)
     benchmark_model(m, WARMUP, example_inputs)
     elapsed_time = benchmark_model(m, RUNS, example_inputs)
 
     torch._dynamo.reset()
-    m_bf16 = torch.compile(m_bf16, mode='max-autotune', fullgraph=True)
+    m_bf16 = torch.compile(m_bf16, mode="max-autotune", fullgraph=True)
     benchmark_model(m_bf16, WARMUP, example_inputs)
     bf16_elapsed_time = benchmark_model(m_bf16, RUNS, example_inputs)
 
-    print(f"{(M, N, K)}: elapsed time: {elapsed_time}, ref elapsed time: {ref_elapsed_time}, bf16 elapsed time: {bf16_elapsed_time}")
+    print(
+        f"{(M, N, K)}: elapsed time: {elapsed_time}, ref elapsed time: {ref_elapsed_time}, bf16 elapsed time: {bf16_elapsed_time}"
+    )
 
+
 if __name__ == "__main__" and TORCH_VERSION_AT_LEAST_2_4 and torch.cuda.is_available():
     all_shapes = [
@@ -158,16 +197,25 @@ def _bench_quantized_tensor_subclass_perf(api, ref_api, M, N, K, kwargs=None):

print("_int8da_int8w_api")
from torchao.quantization.quant_api import change_linear_weights_to_int8_dqtensors

for M, N, K in all_shapes:
_bench_quantized_tensor_subclass_perf(_int8da_int8w_api, _ref_change_linear_weights_to_int8_dqtensors, M, N, K)
_bench_quantized_tensor_subclass_perf(
_int8da_int8w_api, _ref_change_linear_weights_to_int8_dqtensors, M, N, K
)

print("_int8wo_api")
from torchao.quantization.quant_api import change_linear_weights_to_int8_woqtensors

for M, N, K in all_shapes:
_bench_quantized_tensor_subclass_perf(_int8wo_api, _ref_change_linear_weights_to_int8_woqtensors, M, N, K)
_bench_quantized_tensor_subclass_perf(
_int8wo_api, _ref_change_linear_weights_to_int8_woqtensors, M, N, K
)

print("_int4wo_api")
kwargs = {"groupsize": 32}
from torchao.quantization.quant_api import change_linear_weights_to_int4_woqtensors

for M, N, K in all_shapes:
_bench_quantized_tensor_subclass_perf(_int4wo_api, _ref_change_linear_weights_to_int4_woqtensors, M, N, K, kwargs)
_bench_quantized_tensor_subclass_perf(
_int4wo_api, _ref_change_linear_weights_to_int4_woqtensors, M, N, K, kwargs
)

benchmarks/benchmark_fp6.py (39 changes: 28 additions & 11 deletions)
@@ -1,17 +1,22 @@
-import torch
 import pandas as pd
+import torch
 import torch.nn.functional as F
+from tqdm import tqdm
 
 from torchao.dtypes import to_affine_quantized_fpx
-from torchao.dtypes.floatx import FloatxTensorCoreAQTTensorImpl, FloatxTensorCoreLayout
+from torchao.dtypes.floatx import FloatxTensorCoreLayout
 from torchao.utils import benchmark_torch_function_in_microseconds
-from tqdm import tqdm
 
 
 def benchmark(m: int, k: int, n: int):
     float_data_fp16 = torch.randn(n, k, dtype=torch.float16, device="cuda")
     float_data_bf16 = torch.randn(n, k, dtype=torch.bfloat16, device="cuda")
-    fp6_weight_fp16 = to_affine_quantized_fpx(float_data_fp16, FloatxTensorCoreLayout(3, 2))
-    fp6_weight_bf16 = to_affine_quantized_fpx(float_data_bf16, FloatxTensorCoreLayout(3, 2))
+    fp6_weight_fp16 = to_affine_quantized_fpx(
+        float_data_fp16, FloatxTensorCoreLayout(3, 2)
+    )
+    fp6_weight_bf16 = to_affine_quantized_fpx(
+        float_data_bf16, FloatxTensorCoreLayout(3, 2)
+    )
     fp16_weight = fp6_weight_fp16.dequantize(torch.float16)
     bf16_weight = fp6_weight_bf16.dequantize(torch.bfloat16)
+
@@ -22,15 +27,27 @@ def benchmark(m: int, k: int, n: int):
     fp16_output = F.linear(fp16_act, fp16_weight)
     bf16_output = F.linear(bf16_act, bf16_weight)
 
-    fp16_time = benchmark_torch_function_in_microseconds(F.linear, fp16_act, fp16_weight)
-    bf16_time = benchmark_torch_function_in_microseconds(F.linear, bf16_act, bf16_weight)
-    fp6_time_fp16 = benchmark_torch_function_in_microseconds(F.linear, fp16_act, fp6_weight_fp16)
-    fp6_time_bf16 = benchmark_torch_function_in_microseconds(F.linear, bf16_act, fp6_weight_bf16)
+    fp16_time = benchmark_torch_function_in_microseconds(
+        F.linear, fp16_act, fp16_weight
+    )
+    bf16_time = benchmark_torch_function_in_microseconds(
+        F.linear, bf16_act, bf16_weight
+    )
+    fp6_time_fp16 = benchmark_torch_function_in_microseconds(
+        F.linear, fp16_act, fp6_weight_fp16
+    )
+    fp6_time_bf16 = benchmark_torch_function_in_microseconds(
+        F.linear, bf16_act, fp6_weight_bf16
+    )
 
     # follow https://github.com/usyd-fsalab/fp6_llm/blob/ce76774bcfc26b325c1b558abcf1935026d9abbc/tests/python/kernel_test.py
     # doesn't seem to be the right way to check for correctness
-    correct_fp16 = (fp6_output_fp16 - fp16_output).abs().mean() / fp16_output.abs().mean() < 1e-3
-    correct_bf16 = (fp6_output_bf16 - bf16_output).abs().mean() / bf16_output.abs().mean() < 1e-2
+    correct_fp16 = (
+        fp6_output_fp16 - fp16_output
+    ).abs().mean() / fp16_output.abs().mean() < 1e-3
+    correct_bf16 = (
+        fp6_output_bf16 - bf16_output
+    ).abs().mean() / bf16_output.abs().mean() < 1e-2
 
     return {
         "m": m,
