diff --git a/benchmarks/quantized_training/benchmark_int8mm.py b/benchmarks/quantized_training/benchmark_int8mm.py new file mode 100644 index 0000000000..85892afa85 --- /dev/null +++ b/benchmarks/quantized_training/benchmark_int8mm.py @@ -0,0 +1,45 @@ +import pandas as pd +import torch +from triton.testing import do_bench + +from torchao.prototype.quantized_training.int8_mm import int8_mm_dequant + + +def bench_f(f, *args): + return do_bench(lambda: f(*args), fast_flush=False, return_mode="median") + + +shapes = [(sz, sz, sz) for sz in [1024, 2048, 4096]] + +# Llama-8B shapes +shapes += [ + # linear in attention + (32_768, 4096, 4096), + (4096, 4096, 32_768), + # linear in feed-forward + (32_768, 14_336, 4096), + (32_768, 4096, 14_336), + (14_336, 4096, 32_768), +] + +data = [] +for M, N, K in shapes: + print(f"{M=}, {N=}, {K=}") + + A_bf16 = torch.randn(M, K).bfloat16().cuda() + B_bf16 = torch.randn(N, K).bfloat16().cuda() + A_i8 = torch.randint(-128, 127, size=(M, K), dtype=torch.int8).cuda() + B_i8 = torch.randint(-128, 127, size=(N, K), dtype=torch.int8).cuda() + A_scale = torch.randn(M).bfloat16().cuda() + B_scale = torch.randn(N).bfloat16().cuda() + + # benchmark F.linear() i.e. A @ B.T + bf16_time = bench_f(torch.mm, A_bf16, B_bf16.T) + i8_time = bench_f(torch._int_mm, A_i8, B_i8.T) + i8_dequant_time = bench_f(int8_mm_dequant, A_i8, B_i8.T, A_scale, B_scale) + + sample = [M, N, K, bf16_time / i8_time, bf16_time / i8_dequant_time] + data.append(sample) + +df = pd.DataFrame(data, columns=["M", "N", "K", "CuBLAS INT8 speedup", "Triton INT8 dequant speedup"]) +print(df.to_markdown()) diff --git a/benchmarks/quantized_training/pretrain_llama2.py b/benchmarks/quantized_training/pretrain_llama2.py index de3ed04e8f..fc87c2cd6e 100644 --- a/benchmarks/quantized_training/pretrain_llama2.py +++ b/benchmarks/quantized_training/pretrain_llama2.py @@ -3,12 +3,14 @@ # # BF16 baseline: python benchmarks/quantized_training/pretrain_llama2.py --seed 2024 --n_steps 10_000 --compile # INT8 QT: python benchmarks/quantized_training/pretrain_llama2.py --seed 2024 --n_steps 10_000 --compile --quantize int8_weight_only +# INT8 MP: python benchmarks/quantized_training/pretrain_llama2.py --seed 2024 --n_steps 10_000 --compile --quantize int8_mixed_precision import os os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True" import argparse +import time from functools import partial from pathlib import Path @@ -18,22 +20,34 @@ from torch.utils.checkpoint import checkpoint from tqdm import tqdm -from torchao._models.llama.model import ModelArgs, Transformer +from torchao._models.llama.model import ModelArgs, Transformer, transformer_configs from torchao.prototype import low_bit_optim -from torchao.prototype.quantized_training import int8_weight_only_quantized_training +from torchao.prototype.quantized_training import ( + int8_mixed_precision_training, + int8_weight_only_quantized_training, +) from torchao.quantization.quant_api import quantize_ +# not official models +transformer_configs.update( + ( + ("470M", dict(n_layer=24, n_head=16, dim=1024, intermediate_size=4096)), + ("1B", dict(n_layer=24, n_head=24, dim=1536, intermediate_size=6144)), + ) +) + + # hack from fairseq # https://github.com/facebookresearch/fairseq/blob/920a548ca770fb1a951f7f4289b4d3a0c1bc226f/fairseq/modules/checkpoint_activations.py def enable_activation_checkpointing(m: torch.nn.Module): assert not hasattr(m, "_forward") m._forward = m.forward - m.forward = partial(checkpoint, m.forward) + m.forward = partial(checkpoint, 
m.forward, use_reentrant=False) def get_loss(model: Transformer, batch: torch.Tensor): - logits = model(batch)[:, :-1].flatten(0, 1) + logits = model(batch)[:, :-1].float().flatten(0, 1) labels = batch[:, 1:].flatten() return torch.nn.functional.cross_entropy(logits, labels) @@ -77,12 +91,7 @@ def get_tinystories(): if __name__ == "__main__": parser = argparse.ArgumentParser() - # default config is 470M - parser.add_argument("--d_model", type=int, default=1024) - parser.add_argument("--depth", type=int, default=24) - parser.add_argument("--ffn_size", type=int, default=4096) - parser.add_argument("--head_dim", type=int, default=64) - + parser.add_argument("--model", default="470M", choices=transformer_configs.keys()) parser.add_argument("--quantize") parser.add_argument("--activation_checkpointing", action="store_true") parser.add_argument("--compile", action="store_true") @@ -98,30 +107,33 @@ def get_tinystories(): parser.add_argument("--project", default="int8_quantized_training") parser.add_argument("--run_name") parser.add_argument("--seed", type=int) + parser.add_argument("--log_interval", type=int, default=10) args = parser.parse_args() if args.seed is not None: torch.manual_seed(args.seed) - config = ModelArgs( - block_size=args.seq_len, - n_layer=args.depth, - n_head=args.d_model // args.head_dim, - dim=args.d_model, - intermediate_size=args.ffn_size, - ) + config = ModelArgs.from_name(args.model) + config.block_size = args.seq_len model = Transformer(config).bfloat16().cuda() with torch.device("cuda"): model.setup_caches(args.batch_size, args.seq_len, training=True) if args.activation_checkpointing: for layer in model.layers: enable_activation_checkpointing(layer) + + # don't apply int8_mixed_precision to LM head, since it can cause convergence issue. + # TODO: might want to do the same for int8_weight_only to standardize. if args.quantize == "int8_weight_only": quantize_(model, int8_weight_only_quantized_training(), set_inductor_config=False) + elif args.quantize == "int8_mixed_precision": + quantize_(model.layers, int8_mixed_precision_training(), set_inductor_config=False) elif args.quantize is not None: raise ValueError(f"Unsupported quantize={args.quantize}") + print(f"No. of params: {sum(p.numel() for p in model.parameters()):,}") print(f"No. 
of buffers: {sum(p.numel() for p in model.buffers()):,}") + torch.cuda.reset_peak_memory_stats() # don't count memory occupied by unquantized weights # only use optimizers from torchao.prototype.low_bit_optim to support quantized training if args.optim == "AdamW": @@ -129,13 +141,14 @@ def get_tinystories(): optim = getattr(low_bit_optim, args.optim)(model.parameters(), lr=args.lr, weight_decay=args.weight_decay) data = get_tinystories().cuda() + args.torch_version = torch.__version__ run = wandb.init(dir="/tmp", config=args, project=args.project, name=args.run_name) step = 0 - log_interval = 50 pbar = tqdm(total=args.n_steps, dynamic_ncols=True) model.train() _get_loss = torch.compile(get_loss) if args.compile else get_loss + time0 = time.time() while step < args.n_steps: # randomly select a continuous chunk, then reshape it @@ -145,13 +158,17 @@ def get_tinystories(): loss = _get_loss(model, batch) loss.backward() - if step % log_interval == 0: + if step % args.log_interval == 0: log_dict = dict( loss=loss.item(), lr=optim.param_groups[0]["lr"], max_memory_allocated=torch.cuda.max_memory_allocated() / 1e9, - max_memory_active=torch.cuda.memory_stats().get("active_bytes.all.peak", 0) / 1e9, + max_memory_reserved=torch.cuda.max_memory_reserved() / 1e9, ) + if step > 0: + time1 = time.time() + log_dict["tokens_per_second"] = (args.log_interval * args.batch_size * args.seq_len) / (time1 - time0) + time0 = time1 run.log(log_dict, step=step) pbar.set_postfix(loss=log_dict["loss"]) diff --git a/test/prototype/test_quantized_training.py b/test/prototype/test_quantized_training.py index 6b4b6a6be9..bffff16fc1 100644 --- a/test/prototype/test_quantized_training.py +++ b/test/prototype/test_quantized_training.py @@ -1,21 +1,30 @@ +import pytest + +from torchao.utils import TORCH_VERSION_AT_LEAST_2_4 + +if not TORCH_VERSION_AT_LEAST_2_4: + pytest.skip("Requires torch>=2.4", allow_module_level=True) + import copy -import pytest import torch +import torch.distributed as dist import torch.nn.functional as F from torch import nn +from torch.distributed._composable.fsdp import fully_shard, MixedPrecisionPolicy from torch.testing._internal.common_distributed import skip_if_lt_x_gpu from torch.testing._internal.common_fsdp import FSDPTest from torch.testing._internal.common_utils import TestCase, instantiate_parametrized_tests, parametrize, run_tests +from torch.testing._internal.distributed._tensor.common_dtensor import ModelArgs, Transformer from torchao.prototype.low_bit_optim import _AdamW -from torchao.prototype.quantized_training import Int8QTLinearWeight, int8_weight_only_quantized_training +from torchao.prototype.quantized_training import ( + Int8MixedPrecisionTrainingConfig, + int8_mixed_precision_training, + int8_weight_only_quantized_training, + quantize_int8_rowwise, +) from torchao.quantization.quant_api import quantize_ -from torchao.utils import TORCH_VERSION_AFTER_2_3, TORCH_VERSION_AFTER_2_4 - -if not TORCH_VERSION_AFTER_2_3: - pytest.skip("Requires torch>=2.4", allow_module_level=True) - _DEVICES = ["cpu"] + (["cuda"] if torch.cuda.is_available() else []) @@ -35,7 +44,7 @@ def test_int8_stochastic_rounding(self, device): x = torch.randn(32, device=device) x_samples = x.view(1, -1).repeat(100_000, 1) - x_int8, x_scale = Int8QTLinearWeight.quantize(x_samples, stochastic_rounding=True) + x_int8, x_scale = quantize_int8_rowwise(x_samples, stochastic_rounding=True) x_dequant_samples = x_int8 * x_scale.view(-1, 1) x_dequant_mean = x_dequant_samples.mean(0) @@ -43,10 +52,18 @@ def 
test_int8_stochastic_rounding(self, device): # due to the statistical nature, this assertion may still fail, though very rarely. torch.testing.assert_close(x_dequant_mean, x, atol=1e-4, rtol=1e-4) + @staticmethod + def _forward_and_backward(module, input, grad): + # clone input, since we want to inspect its gradient later + input = input.detach().clone().requires_grad_(True) + output = module(input) + output.backward(grad) + return input, output + @parametrize("leading_dims", [(), (2,), (2, 4)]) @parametrize("bias", [False, True]) @parametrize("device", _DEVICES) - def test_int8_linear(self, leading_dims, bias, device): + def test_int8_weight_only_correctness(self, leading_dims, bias, device): _reset() embed_dim = 32 @@ -55,20 +72,13 @@ def test_int8_linear(self, leading_dims, bias, device): quantize_(linear_int8, int8_weight_only_quantized_training(), set_inductor_config=False) linear_fp32.weight.data = linear_int8.weight.data.dequantize() - input_fp32 = torch.randn(leading_dims + (embed_dim,), device=device) - input_int8 = input_fp32.clone() - input_fp32.requires_grad_(True) - input_int8.requires_grad_(True) + input = torch.randn(leading_dims + (embed_dim,), device=device) + grad = torch.randn(leading_dims + (embed_dim,), device=device) - # test forward - out_fp32 = linear_fp32(input_fp32) - out_int8 = linear_int8(input_int8) - torch.testing.assert_close(out_fp32, out_int8) + input_fp32, out_fp32 = self._forward_and_backward(linear_fp32, input, grad) + input_int8, out_int8 = self._forward_and_backward(linear_int8, input, grad) - # test backward - grad = torch.randn(leading_dims + (embed_dim,), device=device) - out_fp32.backward(grad) - out_int8.backward(grad) + torch.testing.assert_close(out_fp32, out_int8) torch.testing.assert_close(input_fp32.grad, input_int8.grad) torch.testing.assert_close(linear_fp32.weight.grad, linear_int8.weight.grad) if bias: @@ -77,7 +87,7 @@ def test_int8_linear(self, leading_dims, bias, device): @parametrize("leading_dims", [(), (2,), (2, 4)]) @parametrize("bias", [False, True]) @parametrize("device", _DEVICES) - def test_int8_linear_compile(self, leading_dims, bias, device): + def test_int8_weight_only_compile(self, leading_dims, bias, device): _reset() embed_dim = 128 @@ -86,18 +96,13 @@ def test_int8_linear_compile(self, leading_dims, bias, device): linear_compiled = copy.deepcopy(linear_eager) linear_compiled.compile() - input_eager = torch.randn(leading_dims + (embed_dim,), device=device) * 10 - input_compiled = input_eager.clone() - input_eager.requires_grad_(True) - input_compiled.requires_grad_(True) + input = torch.randn(leading_dims + (embed_dim,), device=device) * 10 + grad = torch.randn(leading_dims + (embed_dim,), device=device) + + input_eager, out_eager = self._forward_and_backward(linear_eager, input, grad) + input_compiled, out_compiled = self._forward_and_backward(linear_compiled, input, grad) - out_eager = linear_eager(input_eager) - out_compiled = linear_compiled(input_compiled) torch.testing.assert_close(out_eager, out_compiled) - - grad = torch.randn(leading_dims + (embed_dim,), device=device) - out_eager.backward(grad) - out_compiled.backward(grad) torch.testing.assert_close(input_eager.grad, input_compiled.grad) torch.testing.assert_close(linear_eager.weight.grad, linear_compiled.weight.grad) if bias: @@ -105,7 +110,7 @@ def test_int8_linear_compile(self, leading_dims, bias, device): @parametrize("compile", [False, True]) @parametrize("device", _DEVICES) - def test_int8_linear_training(self, compile, device): + def 
test_int8_weight_only_training(self, compile, device): _reset() bsize = 4 embed_dim = 32 @@ -117,7 +122,6 @@ def test_int8_linear_training(self, compile, device): nn.Linear(embed_dim * 2, n_classes), ).to(device) model_int8 = copy.deepcopy(model_fp32) - # don't set inductor flags to speed up CI time quantize_(model_int8, int8_weight_only_quantized_training(), set_inductor_config=False) if compile: @@ -144,33 +148,104 @@ def test_int8_linear_training(self, compile, device): optim_int8.step() optim_int8.zero_grad() + @parametrize("compile", [False, True]) + @parametrize( + "config", + [ + Int8MixedPrecisionTrainingConfig(), + Int8MixedPrecisionTrainingConfig(output=False), + Int8MixedPrecisionTrainingConfig(grad_input=False), + Int8MixedPrecisionTrainingConfig(grad_weight=False), + ], + ) + @pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available") + def test_int8_mixed_precision_training(self, compile, config): + _reset() + bsize = 4 + embed_dim = 32 + device = "cuda" + + # only use 1 matmul shape to reduce triton autotune time + model_ref = nn.Sequential( + nn.Linear(embed_dim, embed_dim, bias=False), + nn.GELU(), + nn.Linear(embed_dim, embed_dim), + ).to(device) + model_int8mp = copy.deepcopy(model_ref) + quantize_(model_int8mp, int8_mixed_precision_training(config), set_inductor_config=False) + + if compile: + model_ref.compile() + model_int8mp.compile() + + optim_ref = torch.optim.AdamW(model_ref.parameters()) + optim_int8mp = torch.optim.AdamW(model_int8mp.parameters()) + + for i in range(5): + inputs = torch.randn(bsize, embed_dim, device=device) + labels = torch.randint(embed_dim, size=(bsize,), device=device) + loss_ref = F.cross_entropy(model_ref(inputs), labels) + loss_int8mp = F.cross_entropy(model_int8mp(inputs), labels) + + rel_error = abs(loss_int8mp.item() - loss_ref.item()) / abs(loss_ref.item()) + assert rel_error < 3e-3, (i, rel_error) + + loss_ref.backward() + optim_ref.step() + optim_ref.zero_grad() + + loss_int8mp.backward() + for p in model_int8mp.parameters(): + assert p.grad is not None + optim_int8mp.step() + optim_int8mp.zero_grad() + + +_FSDP_WORLD_SIZE = 2 + class TestFSDP2(FSDPTest): @property def world_size(self) -> int: - return 2 - - @skip_if_lt_x_gpu(2) - def test_fsdp2(self): - # FSDP2 + compiled quantized training fails with PyTorch 2.4 - compile_layer_choices = [False] - if TORCH_VERSION_AFTER_2_4: - compile_layer_choices.append(True) - - self.run_subtests( - {"compile_layer": compile_layer_choices}, - self._test_fsdp2, - ) - - def _test_fsdp2(self, compile_layer): - import torch.distributed as dist - from torch.distributed._composable.fsdp import fully_shard - from torch.testing._internal.distributed._tensor.common_dtensor import ModelArgs, Transformer + return _FSDP_WORLD_SIZE + + @skip_if_lt_x_gpu(_FSDP_WORLD_SIZE) + def test_fsdp2_correctness(self): + test_args = [ + ( + int8_weight_only_quantized_training(), # quantize_fn for base model + int8_weight_only_quantized_training(), # quantize_fn for FSDP model + MixedPrecisionPolicy(), + 0.05, # tolerance. due to stochastic rounding, use a pretty large tolerance here + ), + ( + int8_mixed_precision_training(), + int8_mixed_precision_training(), + MixedPrecisionPolicy(), + 1e-6, + ), + ( + # It's complicated (though possible) to simulate FSDP BF16 mixed-precision for base_model. + # We would need to cast all params to BF16 in forward and backward pass, while keeping + # the params in FP32 for optim step. + # torch.autocast() will only do this for F.linear() layer (and its backward). 
+ # To keep it simple, we just use a larger tolerance here. + int8_mixed_precision_training(), + int8_mixed_precision_training(Int8MixedPrecisionTrainingConfig(fsdp_param_dtype=torch.bfloat16)), + MixedPrecisionPolicy(param_dtype=torch.bfloat16), + 1e-2, + ), + ] + self.run_subtests({"args": test_args}, self._run_subtest) + + def _run_subtest(self, args): + base_quantize_fn, fsdp_quantize_fn, mp_policy, tolerance = args - _reset() batch_size = 3 vocab_size = 32 seq_len = 64 + + # NOTE: if weight_tying=True and we also quantize LM head, INT8 mixed-precision will fail. model_args = ModelArgs( n_layers=2, n_heads=2, @@ -181,19 +256,16 @@ def _test_fsdp2(self, compile_layer): ) torch.manual_seed(42) base_model = Transformer(model_args).cuda() - quantize_(base_model, int8_weight_only_quantized_training(), set_inductor_config=False) fsdp_model = copy.deepcopy(base_model) - if compile_layer: - for layer in base_model.layers: - layer.compile() + quantize_(base_model.layers, base_quantize_fn, set_inductor_config=False) + quantize_(fsdp_model.layers, fsdp_quantize_fn, set_inductor_config=False) for layer in fsdp_model.layers: - if compile_layer: - layer.compile() - fully_shard(layer) - fully_shard(fsdp_model) + fully_shard(layer, mp_policy=mp_policy) + fully_shard(fsdp_model, mp_policy=mp_policy) + # start testing base_optim = torch.optim.Adam(base_model.parameters(), lr=1e-2, foreach=False, fused=False) fsdp_optim = torch.optim.Adam(fsdp_model.parameters(), lr=1e-2, foreach=False, fused=False) @@ -203,19 +275,20 @@ def _test_fsdp2(self, compile_layer): fsdp_optim.zero_grad(set_to_none=(iter_idx % 2 == 0)) fsdp_loss = fsdp_model(inp).sum() fsdp_loss.backward() + for param in fsdp_model.parameters(): + assert param.grad is not None fsdp_optim.step() base_optim.zero_grad(set_to_none=(iter_idx % 2 == 0)) base_loss = base_model(inp).sum() base_loss.backward() for param in base_model.parameters(): - if param.grad is not None: - dist.all_reduce(param.grad, op=dist.ReduceOp.AVG) + assert param.grad is not None + dist.all_reduce(param.grad, op=dist.ReduceOp.AVG) base_optim.step() - # due to stochastic rounding, use a pretty large tolerance here rel_error = (fsdp_loss - base_loss).abs() / base_loss.abs() - assert rel_error < 0.05, rel_error + assert rel_error < tolerance, (iter_idx, rel_error) instantiate_parametrized_tests(TestQuantizedTraining) diff --git a/torchao/prototype/quantized_training/README.md b/torchao/prototype/quantized_training/README.md index 9b2980aa2b..1dde72598c 100644 --- a/torchao/prototype/quantized_training/README.md +++ b/torchao/prototype/quantized_training/README.md @@ -1,20 +1,32 @@ # Quantized training -This folder contains experimental work on quantized training (QT). The main difference from quantization-aware training (QAT) is that in QT, we don't keep a high-precision copy of model weights. We take inspirations from: -- Q-GaLore: [[paper](https://arxiv.org/abs/2407.08296)] [[code](https://github.com/VITA-Group/Q-GaLore)] +This folder contains experimental work on quantized training (QT), with a focus on INT8. 
We take inspiration from:
+
 - AQT: [[related paper](https://arxiv.org/abs/2105.03536)] [[code](https://github.com/google/aqt)]
+- SwitchBack: [[paper](https://arxiv.org/abs/2304.13013)]
+- Q-GaLore: [[paper](https://arxiv.org/abs/2407.08296)] [[code](https://github.com/VITA-Group/Q-GaLore)]
+- JetFire: [[paper](https://arxiv.org/abs/2403.12422)] [[code](https://github.com/thu-ml/Jetfire-INT8Training)]

-Typically, low-precision weights cannot be trained directly due to quantization error: a small change in the quantized weight will be round down to zero. To tackle this problem, we use **stochastic rounding** for weight update. In simple terms, stochastic rounding will round up or down randomly, but with a higher chance if it is closer to that direction. For example, 0.8 will have 80% chance of rounding up and 20% of rounding down. It also follows that on average, stochastic rounding will estimate the floating point value exactly.
+The main difference from quantization-aware training (QAT) is that in QT, we don't keep a high-precision copy of the model weights. However, terminology for INT8 training is not yet standardized. To be precise, we use these terms with the following meanings:

-In precise terms, the probability of rounding up is `x - ⌊x⌋`. Note that when the value is exactly an integer value, the probability of rounding up is zero.
+- **Quantized training**: model weights are quantized. This is a strict requirement; the compute precision does not matter. Examples: Q-GaLore, JetFire.
+- **INT8 mixed-precision training**: model weights are kept in the original precision, while the compute dtype for some or all ops is INT8. We use this name because it is similar to FP16/BF16 mixed-precision training. One difference is that FP16/BF16 mixed-precision matmuls return FP16/BF16 outputs, while INT8 mixed-precision matmuls usually do not return INT8. Examples include Google AQT and SwitchBack.

-There are 2 main benefits for training in this way:
-1. Reduce memory footprint. Also reduce communication bandwidth in distributed setting.
-2. What you train is what you serve ([WYTIWYS](https://github.com/google/aqt?tab=readme-ov-file#features)).
+There are 3 main benefits of using low-precision dtypes for training (the extent depends on the actual strategy):

-Currently we only support weight-only channel-wise INT8 symmetric quantization.
+- **Memory**: reduce the memory footprint of model weights, activations, and gradients, as well as distributed communication bandwidth.
+- **Speed**: speed up compute-bound ops with low-precision hardware instructions (e.g. INT8 Tensor Cores) and speed up memory-bound ops with quantized inputs/outputs.
+- [What you train is what you serve](https://github.com/google/aqt?tab=readme-ov-file#features).
+
+[`benchmarks/quantized_training/pretrain_llama2.py`](../../../benchmarks/quantized_training/pretrain_llama2.py) demonstrates end-to-end Llama2 pre-training on a single GPU for the strategies implemented in this folder.
+
+## INT8 quantized training
+
+Typically, quantized weights cannot be trained directly due to quantization error: a small change in the quantized weight will be rounded down to zero. To tackle this problem, we use **stochastic rounding** for the weight update. In simple terms, stochastic rounding rounds up or down randomly, with a higher probability of rounding toward the nearer value. For example, 0.8 has an 80% chance of rounding up and a 20% chance of rounding down. It also follows that, on average, stochastic rounding estimates the floating-point value exactly, i.e. it is unbiased.
+
+In precise terms, the probability of rounding up is `x - ⌊x⌋`. Note that when the value is exactly an integer, the probability of rounding up is zero.
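+
+As a quick illustration of this rounding rule, here is a minimal sketch using the same `(x + rand).floor()` trick as `quantize_int8_rowwise()` in this PR (the tensor shape and value here are arbitrary):
+
+```python
+import torch
+
+def stochastic_round(x: torch.Tensor) -> torch.Tensor:
+    # P(round up) = x - ⌊x⌋: add uniform noise in [0, 1) and take the floor
+    return (x + torch.rand_like(x)).floor()
+
+x = torch.full((100_000,), 0.8)
+print(stochastic_round(x).mean())  # ~0.8, i.e. unbiased on average
+```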
-## INT8 weight only
+Currently we only support weight-only channel-wise INT8 symmetric quantization.

 In this recipe, all linear weights are quantized to INT8 using channel-wise symmetric quantization `[-127, 127]`. In the forward and backward pass, the weights are upcast to activations' dtype (e.g. BF16). Therefore, their gradients are also in activations' dtype.

@@ -23,7 +35,7 @@ Usage
 ```python
 from torchao.prototype.quantized_training import int8_weight_only_quantized_training
 from torchao.prototype.low_bit_optim import _AdamW
-from torchao.quantization.quant_api import quantize_
+from torchao.quantization import quantize_

 model = ...
 quantize_(model, int8_weight_only_quantized_training())
@@ -33,21 +45,93 @@ optim = _AdamW(model.parameters(), lr=3e-4)
 ```

 Only `torch.optim.Adam` and optimizers from `torchao.prototype.low_bit_optim` are known to work with quantized training in this folder. This is because we implement stochastic rounding logic within tensor subclass instead of the optimizer. We provide `torchao.prototype.low_bit_optim._AdamW` as an alternative to `torch.optim.AdamW` specifically for this purpose.

-[`benchmarks/quantized_training/pretrain_llama2.py`](../../../benchmarks/quantized_training/pretrain_llama2.py) demonstrates an end-to-end Llama2 pre-training using this INT8 quantized training.
-
 See [#644](https://github.com/pytorch/ao/pull/644) for some early results.

-TODO: investigate suboptimal memory saving when `torch.compile()` is used. Might be due to transposed weight. Memory benchamark for Llama2-1B, bs=4, seq_len=2048, activation checkpointing.
+TODO: investigate suboptimal memory saving when `torch.compile()` is used. Might be due to transposed weight. Benchmark for Llama2-1B, bs=4, seq_len=2048, activation checkpointing, 4070Ti SUPER.
+
+Model           | Peak memory (GB) | toks/s
+----------------|------------------|-------
+BF16 eager      | 11.07            | 6200
+BF16 compile    | 10.25            | 9000
+INT8 QT eager   | 10.12            | 5600
+INT8 QT compile | 9.84             | 8700
+
+## INT8 mixed-precision
+
+On NVIDIA GPUs, INT8 Tensor Cores are approximately 2x faster than their BF16/FP16 counterparts. In mixed-precision training, we can dynamically down-cast activations and weights to INT8 to leverage the faster matmuls. However, since INT8 has a very limited range [-128, 127], we perform row-wise quantization, similar to how INT8 post-training quantization (PTQ) is done. The weights are still kept in the original precision.
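+
+Conceptually, each INT8 matmul in this recipe follows the sketch below, which is a simplified eager-mode version of the `_dynamic_int8_mm()` helper and the non-Triton `int8_mm_dequant()` fallback added in this PR (`torch._int_mm` requires a CUDA device, INT8 inputs, and cuBLAS-friendly shapes):
+
+```python
+import torch
+
+def int8_rowwise(x: torch.Tensor):
+    # absmax symmetric quantization with one scale per row (simplified quantize_int8_rowwise())
+    scale = x.abs().amax(dim=1) / 127
+    x_i8 = (x / scale.clip(1e-12).view(-1, 1)).round().clip(-128, 127).to(torch.int8)
+    return x_i8, scale
+
+A, B = torch.randn(256, 64, device="cuda"), torch.randn(64, 128, device="cuda")
+A_i8, A_scale = int8_rowwise(A)
+B_t_i8, B_scale = int8_rowwise(B.T)  # row-wise scaling of B.T == column-wise scaling of B
+# INT8 matmul accumulates in INT32; scale the result back to floating point afterwards
+out = torch._int_mm(A_i8, B_t_i8.T) * B_scale * A_scale.view(-1, 1)
+```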
+### Basic usage
+
+```python
+from torchao.prototype.quantized_training import int8_mixed_precision_training, Int8MixedPrecisionTrainingConfig
+from torchao.quantization import quantize_
+
+model = ...
-Model | Peak memory (GB)
-----------------|-----------------
-BF16 eager | 11.06847
-BF16 compile | 10.16915
-INT8 QT eager | 10.11437
-INT8 QT compile | 10.03365
+# apply INT8 matmul to all 3 matmuls
+quantize_(model, int8_mixed_precision_training())
+
+# customize which matmuls are kept in original precision
+config = Int8MixedPrecisionTrainingConfig(
+    output=True,
+    grad_input=True,
+    grad_weight=False,
+)
+quantize_(model, int8_mixed_precision_training(config))
+
+# train model as usual
+```
+
+During training, there are 3 matmuls involved in each `nn.Linear` layer:
+- 1 in forward: `output = input @ weight.T`
+- 2 in backward:
+  - `grad_input = grad_output @ weight`
+  - `grad_weight = grad_output.T @ input`
+
+You can configure which of these matmuls use INT8 mixed-precision (as shown above). If convergence is an issue, we recommend leaving `grad_weight` in the original matmul precision, and also `grad_input` if the issue persists.
+
+Note:
+- When we only apply INT8 mixed-precision in the forward pass, this can be considered QAT for INT8 dynamic activations + INT8 weight quantization (A8W8).
+- When we only apply INT8 mixed-precision to `output` and `grad_input`, this is similar to SwitchBack. However, SwitchBack uses tensor-wise scaling for the weight. For simplicity, we only support row-wise scaling.
+- Applying stochastic rounding to the INT8 quantization may improve matmul accuracy. However, from our testing this seems unnecessary, so we don't implement it at the moment.
+
+Benchmark setup: pre-train Llama2-1B on the C4 realnewslike subset, bs=32, seq_len=2048 (65k tokens/batch), 20k steps (1.3B tokens), on a 4090. INT8 mixed-precision is not applied to the LM head.
+
+Config               | Tok/s | Peak mem (GB) | Val loss
+---------------------|-------|---------------|---------
+BF16 (baseline)      | ~17k  | 19.47         | 2.97
+INT8 mixed-precision | ~29k  | 19.47         | 2.90
+
+See [#748](https://github.com/pytorch/ao/pull/748) for more results.
+
+### FSDP support
+
+Out of the box, this INT8 mixed-precision training is not compatible with FSDP2's `MixedPrecisionPolicy(param_dtype=param_dtype)` where `param_dtype` != model dtype. As a workaround, you will need to manually specify FSDP2's `param_dtype` in `Int8MixedPrecisionTrainingConfig`:
+
+```python
+import torch
+from torch.distributed._composable.fsdp import fully_shard, MixedPrecisionPolicy
+from torchao.prototype.quantized_training import int8_mixed_precision_training, Int8MixedPrecisionTrainingConfig
+from torchao.quantization import quantize_
+
+model = ...  # FP32 model
+
+# setup configs
+mp_policy = MixedPrecisionPolicy(param_dtype=torch.bfloat16)
+int8mp_config = Int8MixedPrecisionTrainingConfig(fsdp_param_dtype=mp_policy.param_dtype)
+
+# exclude LM head
+quantize_(model.layers, int8_mixed_precision_training(int8mp_config))
+
+# shard the model w/ FSDP2
+for layer in model.layers:
+    fully_shard(layer, mp_policy=mp_policy)
+fully_shard(model, mp_policy=mp_policy)
+
+# train model as usual
+```
+
 ## Future ideas

-- INT8 activation x INT8 weight. This can potentially leverage INT8 Tensor Cores, which is 2x faster than FP16/BF16 Tensor Cores.
+- Tile-wise INT8 quantization to keep quantized weight for both forward and backward pass (similar to JetFire).
 - INT4 weight only (with group-wise quantization). This can be used with INT4 tinygemm deployment in mind (or other optimized INT4 kernels).
 - FP8 activation x FP8 weight. The current FP8 training recipe can be seen as a form of QAT, which maintains a high-precision copy of model weights. We can eliminate the high-precision copy.
diff --git a/torchao/prototype/quantized_training/__init__.py b/torchao/prototype/quantized_training/__init__.py index 6c7f8eb9b1..ccf2f5375d 100644 --- a/torchao/prototype/quantized_training/__init__.py +++ b/torchao/prototype/quantized_training/__init__.py @@ -1 +1,10 @@ -from .int8 import Int8QTLinearWeight, int8_weight_only_quantized_training +from .int8 import ( + Int8QuantizedTrainingLinearWeight, + int8_weight_only_quantized_training, + quantize_int8_rowwise, +) +from .int8_mixed_precision import ( + Int8MixedPrecisionTrainingConfig, + Int8MixedPrecisionTrainingLinearWeight, + int8_mixed_precision_training, +) diff --git a/torchao/prototype/quantized_training/int8.py b/torchao/prototype/quantized_training/int8.py index fa8805fce7..828655f04c 100644 --- a/torchao/prototype/quantized_training/int8.py +++ b/torchao/prototype/quantized_training/int8.py @@ -1,10 +1,11 @@ from typing import Any, Optional, Tuple import torch -from torch import Tensor, nn +from torch import Tensor from torch.utils._python_dispatch import return_and_correct_aliasing from torchao.utils import TorchAOBaseTensor +from torchao.quantization.quant_api import _get_linear_subclass_inserter aten = torch.ops.aten @@ -12,13 +13,41 @@ _c10d_functional = torch.ops._c10d_functional -class Int8QTLinearWeight(TorchAOBaseTensor): +@torch.no_grad() +def quantize_int8_rowwise(tensor: Tensor, stochastic_rounding: bool = False): + """Normal rounding will always round down small changes in weight update. To tackle this problem, + stochastic rounding can be used, which has a low chance, but not zero, of rounding up. The + probability of rounding up is equal to x - ⌊x⌋, which indicates how close the value is to the next + integer value. Thus, stochastic rounding also approximates the floating point value exactly. + + Currently this function differs from AQT's `int8_weight_only()` in the following way: + 1. Precision: AQT keeps original dtype when doing quantization, while this function upcasts input + to FP32 before quantization. Output scale maintains the original input dtype. + 2. Calculate scale: AQT uses `input.abs().amax() / 127.5`, while `input.abs().amax() / 127` is + done here. + 3. Apply scale: AQT uses `input * (1 / scale)`, while this function performs `input / scale`. + """ + # absmax symmetric quantization + scale = tensor.abs().amax(1) / 127 # same dtype as tensor + inv_scale = 1.0 / scale.float().clip(1e-12) + tensor = tensor.float() * inv_scale.view(-1, 1) # slightly faster than divide directly + + if stochastic_rounding: + tensor = (tensor + torch.rand_like(tensor)).floor() + else: + tensor = tensor.round() + + tensor = tensor.clip(-128, 127).to(torch.int8) + return tensor, scale + + +class Int8QuantizedTrainingLinearWeight(TorchAOBaseTensor): """INT8 symmetric quantization weight, with absmax scaling [-127, 127]. The main difference of this tensor subclass from AffineQuantizedTensor: 1. `F.linear` is differentiable i.e. backward is defined. 2. All in-place ops, such as `aten.copy_`, will perform stochastic rounding. `Int8QTLinearWeight.from_float()` does not perform stochastic rounding. - 3. The numerics for quantization is slightly different. See `Int8QTLinearWeight.quantize()` + 3. The numerics for quantization is slightly different. See `quantize_int8_rowwise()` for more details. 
""" @@ -51,42 +80,12 @@ def __tensor_flatten__(self): def __tensor_unflatten__(cls, tensor_data_dict, tensor_attributes, outer_size=None, outer_stride=None): return cls(tensor_data_dict["int_data"], tensor_data_dict["scale"], *tensor_attributes) - @staticmethod - @torch.no_grad() - def quantize(tensor: Tensor, stochastic_rounding: bool = False): - """Normal rounding will always round down small changes in weight update. To tackle this problem, - stochastic rounding can be used, which has a low chance, but not zero, of rounding up. The - probability of rounding up is equal to x - ⌊x⌋, which indicates how close the value is to the next - integer value. Thus, stochastic rounding also approximates the floating point value exactly. - - Currently this function differs from AQT's `int8_weight_only()` in the following way: - 1. Precision: AQT keeps original dtype when doing quantization, while this function upcasts input - to FP32 before quantization, and downcast scale to original dtype. - 2. Calculate scale: AQT uses `input.abs().amax() / 127.5`, while `input.abs().amax() / 127` is - done here. - 3. Apply scale: AQT uses `input * (1 / scale)`, while this function performs `input / scale`. - """ - original_dtype = tensor.dtype - tensor = tensor.float() - - # absmax symmetric quantization - scale = tensor.abs().amax(-1) / 127 - tensor = tensor / scale.clip(1e-12).view(-1, 1) - - if stochastic_rounding: - tensor = (tensor + torch.rand_like(tensor)).floor() - else: - tensor = tensor.round() - - tensor = tensor.clip(-128, 127).to(torch.int8) - return tensor, scale.to(original_dtype) - @classmethod def from_float(cls, tensor: Tensor): """Convert a float tensor into INT8 quantized weight. No stochastic rounding is performed. This function is not differentiable. 
""" - int_data, scale = cls.quantize(tensor.detach()) + int_data, scale = quantize_int8_rowwise(tensor.detach()) out = cls(int_data, scale) out.requires_grad_(tensor.requires_grad) return out @@ -112,12 +111,12 @@ def fsdp_post_all_gather( out: Optional[Tensor] = None, ): int_data, scale = all_gather_outputs - return Int8QTLinearWeight(int_data, scale), all_gather_outputs + return Int8QuantizedTrainingLinearWeight(int_data, scale), all_gather_outputs class _Int8WeightOnlyLinear(torch.autograd.Function): @staticmethod - def forward(ctx, input: Tensor, weight: Int8QTLinearWeight, bias: Optional[Tensor] = None): + def forward(ctx, input: Tensor, weight: Int8QuantizedTrainingLinearWeight, bias: Optional[Tensor] = None): ctx.save_for_backward(input, weight) ctx.bias = bias is not None @@ -136,12 +135,15 @@ def backward(ctx, grad_output): return grad_input, grad_weight, grad_bias -@Int8QTLinearWeight.implements(torch.nn.functional.linear) +implements = Int8QuantizedTrainingLinearWeight.implements + + +@implements(torch.nn.functional.linear) def _(func, types, args, kwargs): return _Int8WeightOnlyLinear.apply(*args, **kwargs) -@Int8QTLinearWeight.implements( +@implements( [ aten.detach.default, aten.clone.default, @@ -155,20 +157,20 @@ def _(func, types, args, kwargs): ) def _(func, types, args, kwargs): # will error out if try to slice 2nd dim - out = Int8QTLinearWeight( + out = Int8QuantizedTrainingLinearWeight( func(args[0].int_data, *args[1:], **kwargs), func(args[0].scale, *args[1:], **kwargs), ) return return_and_correct_aliasing(func, args, kwargs, out) -@Int8QTLinearWeight.implements(aten._to_copy.default) +@implements(aten._to_copy.default) def _(func, types, args, kwargs): # only perform dtype casting on scale, which determines the appearance dtype # TODO: handle non_blocking kwarg? 
device = kwargs.get("device", None) dtype = kwargs.get("dtype", None) - out = Int8QTLinearWeight( + out = Int8QuantizedTrainingLinearWeight( args[0].int_data.to(device=device), args[0].scale.to(device=device, dtype=dtype), ) @@ -176,7 +178,7 @@ def _(func, types, args, kwargs): # to make training work with existing PyTorch optimizers, we return a normal tensor -@Int8QTLinearWeight.implements(aten.zeros_like.default) +@implements(aten.zeros_like.default) def _(func, types, args, kwargs): dtype = kwargs.get("dtype", args[0].dtype) device = kwargs.get("device", args[0].device) @@ -184,20 +186,20 @@ def _(func, types, args, kwargs): # out-of-place math ops always return plain tensor -@Int8QTLinearWeight.implements([aten.sub.Tensor, aten.mul.Tensor]) +@implements([aten.sub.Tensor, aten.mul.Tensor]) def _(func, types, args, kwargs): - args = [x.dequantize() if isinstance(x, Int8QTLinearWeight) else x for x in args] + args = [x.dequantize() if isinstance(x, Int8QuantizedTrainingLinearWeight) else x for x in args] return func(*args, **kwargs) -@Int8QTLinearWeight.implements(aten.copy_.default) +@implements(aten.copy_.default) def _(func, types, args, kwargs): - if isinstance(args[0], Int8QTLinearWeight) and isinstance(args[1], Int8QTLinearWeight): + if isinstance(args[0], Int8QuantizedTrainingLinearWeight) and isinstance(args[1], Int8QuantizedTrainingLinearWeight): args[0].int_data.copy_(args[1].int_data, **kwargs) args[0].scale.copy_(args[1].scale, **kwargs) - elif isinstance(args[0], Int8QTLinearWeight): - int_data, scale = Int8QTLinearWeight.quantize(args[1], stochastic_rounding=True) + elif isinstance(args[0], Int8QuantizedTrainingLinearWeight): + int_data, scale = quantize_int8_rowwise(args[1], stochastic_rounding=True) args[0].int_data.copy_(int_data, **kwargs) args[0].scale.copy_(scale, **kwargs) @@ -207,7 +209,7 @@ def _(func, types, args, kwargs): return args[0] -@Int8QTLinearWeight.implements([aten.addcdiv_.default, aten.add_.Tensor]) +@implements([aten.addcdiv_.default, aten.add_.Tensor]) def _(func, types, args, kwargs): original = args[0] out = func(args[0].dequantize(), *args[1:], **kwargs) @@ -215,20 +217,20 @@ def _(func, types, args, kwargs): # FSDP ops -@Int8QTLinearWeight.implements(aten.split.Tensor) +@implements(aten.split.Tensor) def _(func, types, args, kwargs): if len(args) == 3 and args[2] != 0: raise NotImplementedError("Int8QTLinearWeight only supports split at dim=0") - int8_weight: Int8QTLinearWeight = args[0] + int8_weight: Int8QuantizedTrainingLinearWeight = args[0] int_data_list = func(int8_weight.int_data, *args[1:], **kwargs) scale_list = func(int8_weight.scale, *args[1:], **kwargs) - out = [Int8QTLinearWeight(int_data, scale) for int_data, scale in zip(int_data_list, scale_list)] + out = [Int8QuantizedTrainingLinearWeight(int_data, scale) for int_data, scale in zip(int_data_list, scale_list)] return out -@Int8QTLinearWeight.implements(aten.new_zeros.default) +@implements(aten.new_zeros.default) def _(func, types, args, kwargs): size = args[1] if len(size) != 2: @@ -239,7 +241,7 @@ def _(func, types, args, kwargs): dtype = kwargs.get("dtype", args[0].dtype) int_data = torch.zeros(size, device=device, dtype=torch.int8) scale = torch.zeros(size[0], device=device, dtype=dtype) - return Int8QTLinearWeight(int_data, scale) + return Int8QuantizedTrainingLinearWeight(int_data, scale) # FSDP2 will call these two ops, expecting a view, not a copy. It doesn't make sense to @@ -247,21 +249,11 @@ def _(func, types, args, kwargs): # since this is channel-wise quantization. 
# Thus, this is a workaround for FSDP2. Users SHOULD NOT call these ops directly, since # they will produce unexpected or wrong results. -@Int8QTLinearWeight.implements([aten.view.default, aten.as_strided.default]) +@implements([aten.view.default, aten.as_strided.default]) def _(func, types, args, kwargs): - out = Int8QTLinearWeight(args[0].int_data, args[0].scale) + out = Int8QuantizedTrainingLinearWeight(args[0].int_data, args[0].scale) return return_and_correct_aliasing(func, args, kwargs, out) def int8_weight_only_quantized_training(): - # TODO: right now `_get_linear_subclass_inserter()` will always set `requires_grad=False` - # when we have this out of prototype (or there are stable trainable tensor subclasses), - # update `_get_linear_subclass_inserter()` to allow `requires_grad=True`. - def apply_int8_linear_weight(linear: nn.Linear): - linear.weight = nn.Parameter( - Int8QTLinearWeight.from_float(linear.weight), - requires_grad=linear.weight.requires_grad, - ) - return linear - - return apply_int8_linear_weight + return _get_linear_subclass_inserter(Int8QuantizedTrainingLinearWeight.from_float, allow_requires_grad=True) diff --git a/torchao/prototype/quantized_training/int8_mixed_precision.py b/torchao/prototype/quantized_training/int8_mixed_precision.py new file mode 100644 index 0000000000..0f96e348ba --- /dev/null +++ b/torchao/prototype/quantized_training/int8_mixed_precision.py @@ -0,0 +1,231 @@ +from typing import Any, NamedTuple, Optional, Tuple + +import torch +import torch.utils._pytree as pytree +from torch import Tensor +from torch.utils._triton import has_triton + +from torchao.quantization.quant_api import _get_linear_subclass_inserter +from torchao.utils import TorchAOBaseTensor + +from .int8 import quantize_int8_rowwise + +if has_triton(): + from .int8_mm import int8_mm_dequant + +else: + + # This is less performant than the explicit hand-written Triton kernel, though things might + # change in the future. + # Multiplying B_scale first is faster than the other way round. + def int8_mm_dequant(A: Tensor, B: Tensor, A_scale_rowwise: Tensor, B_scale_colwise: Tensor) -> Tensor: + return torch._int_mm(A, B) * B_scale_colwise * A_scale_rowwise.view(-1, 1) + + +class Int8MixedPrecisionTrainingConfig(NamedTuple): + output: bool = True + grad_input: bool = True + grad_weight: bool = True + + # workaround for FSDP2 with `MixedPrecisionPolicy(param_dtype)` + # see `Int8MixedPrecisionTrainingLinearWeight.fsdp_pre_all_gather()` for more details. + fsdp_param_dtype: Optional[torch.dtype] = None + + +_DEFAULT_CONFIG = Int8MixedPrecisionTrainingConfig() + + +aten = torch.ops.aten + + +class Int8MixedPrecisionTrainingLinearWeight(TorchAOBaseTensor): + """Linear weight for INT8 mixed-precision training. The weight is in original precision (e.g. FP32 or BF16). + During training, weight and activation are dynamically quantized and cast to INT8 to utilize INT8 Tensor Cores, + and then scaled back to original precision. This is also applied to backward pass. 
+ """ + + @staticmethod + @torch._dynamo.disable + def __new__(cls, data: Tensor, config: Int8MixedPrecisionTrainingConfig): + return Tensor._make_wrapper_subclass( + cls, + data.shape, + data.stride(), + data.storage_offset(), + dtype=data.dtype, + device=data.device, + ) + + @torch._dynamo.disable + def __init__(self, data: Tensor, config: Int8MixedPrecisionTrainingConfig): + self._data = data + self.config = config + + def __tensor_flatten__(self): + return ["_data"], [self.config] + + @classmethod + def __tensor_unflatten__(cls, tensor_data_dict, tensor_attributes, outer_size=None, outer_stride=None): + return cls(tensor_data_dict["_data"], *tensor_attributes) + + def __repr__(self): + return f"{self.__class__.__name__}(data={self._data}, config={self.config})" + + def to_original(self): + return self._data.clone() + + # adapated from FP8 implementation of WeightWithDynamicFloat8CastTensor + @classmethod + def __torch_dispatch__(cls, func, types, args, kwargs): + config = None + + def unwrap(x: cls): + nonlocal config + if config is None: + config = x.config + else: + assert x.config == config + return x._data + + out = func( + *pytree.tree_map_only(cls, unwrap, args), + **pytree.tree_map_only(cls, unwrap, kwargs), + ) + + if func is aten.copy_.default: + # return original object + return args[0] + elif func in { + aten.t.default, + aten.detach.default, + aten.empty_like.default, + aten.new_zeros.default, + aten.slice.Tensor, + aten.view.default, + aten.as_strided.default, + aten._to_copy.default, + aten._pin_memory.default, + aten.split.Tensor, + aten.clone.default, + }: + # return new wrapped object + return pytree.tree_map_only(Tensor, lambda x: cls(x, config), out) + else: + # return new unwrapped object + return out + + def fsdp_pre_all_gather(self, mesh): + # TODO: pre-quantize weight here -> reduce comm bandwidth. + # we will need another tensor subclass to hold the quantized weight. + + # doing dtype casting to `param_dtype` in `fsdp_post_all_gather()` will give wrong results. + # as a workaround, we do it in `fsdp_pre_all_gather()` instead. since `param_dtype` is not + # exposed to `fsdp_pre_all_gather()`, we need to specify it in the config. + # this workaround can be removed once we implement INT8 communication. + data = self._data.to(dtype=self.config.fsdp_param_dtype) + return (data,), (self.config,) + + def fsdp_post_all_gather( + self, + all_gather_outputs: Tuple[Tensor, ...], + metadata: Any, + param_dtype: torch.dtype, + *, + out: Optional[Tensor] = None, + ): + (data,) = all_gather_outputs + (config,) = metadata + if out is not None: + assert isinstance(out, Int8MixedPrecisionTrainingLinearWeight) + assert out.config == config + return + return Int8MixedPrecisionTrainingLinearWeight(data, config), all_gather_outputs + + +@Int8MixedPrecisionTrainingLinearWeight.implements(torch.nn.functional.linear) +def _(func, types, args, kwargs): + if torch.is_autocast_enabled("cuda"): + dtype = torch.get_autocast_gpu_dtype() + args = tuple(x.to(dtype) if x is not None else x for x in args) + return _Int8MixedPrecisionTrainingLinear.apply(*args, **kwargs) + + +def _dynamic_int8_mm(A: Tensor, B: Tensor) -> Tensor: + """Dynamically quantize A and B to perform INT8 matmul, then scale the results back to original precision. + To fuse scaling to matmul output, we use row-wise scaling for A and column-wise scaling for B. + + We transpose B before quantization for 2 reasons: + - INT8 matmul is the most performant when A is row-major and B is column-major. 
+ - Row-wise scaling for B.T is column-wise scaling for B -> we only need to implement row-wise scaling. + + Note that inputs and outputs of `quantize_int8_rowwise()` are not guaranteed to be contiguous. We call + `.contiguous()` to outputs of the quantize op to make sure: + - Performant layout for INT8 matmul inputs (see above). + - Scales are contiguous (this is a limitation of our triton kernel). + + We hope that the `.contiguous()` calls, as well as possible layout transpose before quantization, are + fused into quantize op by torch compiler. + + TODO: check if transpose+quantize are actually fused. + """ + # A may have more than 2 dims, while B must be exactly 2-dim + A_i8, A_scale_rowwise = quantize_int8_rowwise(A.view(-1, A.shape[-1])) + B_t_i8, B_scale_colwise = quantize_int8_rowwise(B.T) + out = int8_mm_dequant( + A_i8.contiguous(), + B_t_i8.contiguous().T, + A_scale_rowwise.contiguous(), + B_scale_colwise.contiguous(), + ) + return out.view(*A.shape[:-1], out.shape[-1]) + + +class _Int8MixedPrecisionTrainingLinear(torch.autograd.Function): + @staticmethod + def forward(input: Tensor, weight: Int8MixedPrecisionTrainingLinearWeight, bias: Optional[Tensor]): + if weight.config.output: + out = _dynamic_int8_mm(input, weight._data.T) + else: + out = input @ weight._data.T + out = out + bias if bias is not None else out + return out + + @staticmethod + def setup_context(ctx, inputs, output): + input, weight, bias = inputs + ctx.config = weight.config + ctx.save_for_backward(input, weight._data) + ctx.bias = bias is not None + + @staticmethod + def backward(ctx, grad_output): + input, weight = ctx.saved_tensors + grad_input = grad_weight = grad_bias = None + + if ctx.needs_input_grad[0]: + if ctx.config.grad_input: + grad_input = _dynamic_int8_mm(grad_output, weight) + else: + grad_input = grad_output @ weight + + if ctx.needs_input_grad[1]: + grad_output = grad_output.view(-1, weight.shape[0]) + input = input.view(-1, weight.shape[1]) + if ctx.config.grad_weight: + # grad_weight = _dynamic_int8_mm(grad_output.T, input) + grad_weight = _dynamic_int8_mm(input.T, grad_output).T # this is slightly faster + else: + grad_weight = grad_output.T @ input + + if ctx.needs_input_grad[2] and ctx.bias: + grad_bias = grad_output.sum(0) + + return grad_input, grad_weight, grad_bias + + +def int8_mixed_precision_training(config: Int8MixedPrecisionTrainingConfig = _DEFAULT_CONFIG): + return _get_linear_subclass_inserter( + Int8MixedPrecisionTrainingLinearWeight, + config=config, + allow_requires_grad=True, + ) diff --git a/torchao/prototype/quantized_training/int8_mm.py b/torchao/prototype/quantized_training/int8_mm.py new file mode 100644 index 0000000000..b316e82208 --- /dev/null +++ b/torchao/prototype/quantized_training/int8_mm.py @@ -0,0 +1,144 @@ +# TODO: might merge this with torchao/kernel/intmm_triton.py + +import torch +import triton +import triton.language as tl +from torch import Tensor + +lib = torch.library.Library("torchao", "FRAGMENT") + + +# TODO: prune configs to speedup triton autotune +# https://triton-lang.org/main/getting-started/tutorials/03-matrix-multiplication.html +# (BLOCK_M, BLOCK_N, BLOCK_K, num_stages, num_warps) +configs = [ + (128, 256, 64, 3, 8), + (64, 256, 32, 4, 4), + (128, 128, 32, 4, 4), + (128, 64, 32, 4, 4), + (64, 128, 32, 4, 4), + (128, 32, 32, 4, 4), + (64, 32, 32, 5, 2), + (32, 64, 32, 5, 2), + # Good config for fp8 inputs + (128, 256, 128, 3, 8), + (256, 128, 128, 3, 8), + (256, 64, 128, 4, 4), + (64, 256, 128, 4, 4), + (128, 128, 128, 4, 4), + (128, 64, 
64, 4, 4), + (64, 128, 64, 4, 4), + (128, 32, 64, 4, 4), + # https://github.com/pytorch/pytorch/blob/7868b65c4d4f34133607b0166f08e9fbf3b257c4/torch/_inductor/kernel/mm_common.py#L172 + (64, 64, 32, 2, 4), + (64, 128, 32, 3, 4), + (128, 64, 32, 3, 4), + (64, 128, 32, 4, 8), + (128, 64, 32, 4, 8), + (64, 32, 32, 5, 8), + (32, 64, 32, 5, 8), + (128, 128, 32, 2, 8), + (64, 64, 64, 3, 8), + (128, 256, 128, 3, 8), + (256, 128, 128, 3, 8), +] + +configs = [ + triton.Config(dict(BLOCK_M=BLOCK_M, BLOCK_N=BLOCK_N, BLOCK_K=BLOCK_K), num_stages=num_stages, num_warps=num_warps) + for BLOCK_M, BLOCK_N, BLOCK_K, num_stages, num_warps in configs +] + + +@triton.autotune(configs=configs, key=["M", "N", "K", "stride_ak", "stride_bk"]) +@triton.jit +def _int8_mm_dequant_kernel( + A_ptr, B_ptr, C_ptr, + A_scale_rowwise_ptr, + B_scale_colwise_ptr, + M, N, K, + stride_am, stride_ak, + stride_bk, stride_bn, + stride_cm, stride_cn, + BLOCK_M: tl.constexpr, + BLOCK_N: tl.constexpr, + BLOCK_K: tl.constexpr, + GROUP_M: tl.constexpr = 8, + EVEN_K: tl.constexpr = True, +): + # based on triton.ops.matmul + pid = tl.program_id(0) + grid_m = (M + BLOCK_M - 1) // BLOCK_M + grid_n = (N + BLOCK_N - 1) // BLOCK_N + + # re-order program ID for better L2 performance + width = GROUP_M * grid_n + group_id = pid // width + group_size = min(grid_m - group_id * GROUP_M, GROUP_M) + pid_m = group_id * GROUP_M + (pid % group_size) + pid_n = (pid % width) // (group_size) + + rm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M) + rn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N) + ram = tl.max_contiguous(tl.multiple_of(rm % M, BLOCK_M), BLOCK_M) + rbn = tl.max_contiguous(tl.multiple_of(rn % N, BLOCK_N), BLOCK_N) + rk = tl.arange(0, BLOCK_K) + A = A_ptr + (ram[:, None] * stride_am + rk[None, :] * stride_ak) + B = B_ptr + (rk[:, None] * stride_bk + rbn[None, :] * stride_bn) + + acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.int32) + for k in range(K, 0, -BLOCK_K): + if EVEN_K: + a = tl.load(A) + b = tl.load(B) + else: + a = tl.load(A, mask=rk[None, :] < k, other=0.0) + b = tl.load(B, mask=rk[:, None] < k, other=0.0) + acc += tl.dot(a, b) + A += BLOCK_K * stride_ak + B += BLOCK_K * stride_bk + + # rematerialize rm and rn to save registers + rm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M) + rn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N) + idx_m = rm[:, None] + idx_n = rn[None, :] + mask = (idx_m < M) & (idx_n < N) + + a_scale = tl.load(A_scale_rowwise_ptr + idx_m, mask=idx_m < M).to(tl.float32) + b_scale = tl.load(B_scale_colwise_ptr + idx_n, mask=idx_n < N).to(tl.float32) + acc = acc.to(tl.float32) * a_scale * b_scale + + # inductor generates a suffix + xindex = idx_m * stride_cm + idx_n * stride_cn + tl.store(C_ptr + tl.broadcast_to(xindex, mask.shape), acc, mask) + + +lib.define("int8_mm_dequant(Tensor A, Tensor B, Tensor A_scale, Tensor B_scale) -> Tensor") + + +def int8_mm_dequant(A: Tensor, B: Tensor, A_scale_rowwise: Tensor, B_scale_colwise: Tensor) -> Tensor: + assert A.dtype is torch.int8 and B.dtype is torch.int8 + assert A_scale_rowwise.dtype is B_scale_colwise.dtype + assert A.shape[1] == B.shape[0] + assert A_scale_rowwise.squeeze().shape == (A.shape[0],) + assert B_scale_colwise.squeeze().shape == (B.shape[1],) + assert A_scale_rowwise.is_contiguous() + assert B_scale_colwise.is_contiguous() + return torch.ops.torchao.int8_mm_dequant(A, B, A_scale_rowwise, B_scale_colwise) + + +@torch.library.impl(lib, "int8_mm_dequant", "Meta") +def _(A: Tensor, B: Tensor, A_scale_rowwise: Tensor, B_scale_colwise: Tensor): + return torch.empty((A.shape[0], 
B.shape[1]), device=A.device, dtype=A_scale_rowwise.dtype) + + +@torch.library.impl(lib, "int8_mm_dequant", "CUDA") +def int8_mm_dequant_cuda(A: Tensor, B: Tensor, A_scale_rowwise: Tensor, B_scale_colwise: Tensor): + M, K = A.shape + _, N = B.shape + C = torch.empty(M, N, device=A.device, dtype=A_scale_rowwise.dtype) + grid = lambda meta: (triton.cdiv(meta["M"], meta["BLOCK_M"]) * triton.cdiv(meta["N"], meta["BLOCK_N"]),) + _int8_mm_dequant_kernel[grid]( + A, B, C, A_scale_rowwise, B_scale_colwise, M, N, K, *A.stride(), *B.stride(), *C.stride(), EVEN_K=K % 2 == 0 + ) + return C diff --git a/torchao/quantization/quant_api.py b/torchao/quantization/quant_api.py index 9516da9763..5ca6e7cd33 100644 --- a/torchao/quantization/quant_api.py +++ b/torchao/quantization/quant_api.py @@ -383,12 +383,13 @@ def _quantization_type(weight: torch.Tensor): def _linear_extra_repr(self): return f"in_features={self.weight.shape[1]}, out_features={self.weight.shape[0]}, weight={_quantization_type(self.weight)}" -def _get_linear_subclass_inserter(constructor, **kwargs): +def _get_linear_subclass_inserter(constructor, *, allow_requires_grad=False, **kwargs): """Helper function to apply the constructor that quantizes the weight Tensor (with additional kwargs) to the weight of linear module """ def insert_subclass(lin): - lin.weight = torch.nn.Parameter(constructor(lin.weight, **kwargs), requires_grad=False) + requires_grad = allow_requires_grad and lin.weight.requires_grad + lin.weight = torch.nn.Parameter(constructor(lin.weight, **kwargs), requires_grad=requires_grad) lin.extra_repr = types.MethodType(_linear_extra_repr, lin) return lin