Add INT8 mixed-precision training (#748)
* initial commit

* expose some UX. update test

* add test. update bench

* update test. add doc

* fix ngpu

* fix FSDP

* fix

* fix fsdp test

* fix

* grammar

* simplify fsdp test

* update benchmark script

* update

* make claim more conservative

* register fused adam

* update benchmark script

* add more ops

* update default

* use TorchAOBaseTensor

* fix fsdp param_dtype

* fix param_dtype

* dtype check to prevent unnecessary errors

* move checks

* add note

* fix

* simplify script

* add module-based UX

* fix

* use FP8 impl of __torch_dispatch__

* rename _dynamic interface

* update test

* fix compile on 2.4

* log torch version

* make log interval customizable

* make naming more explicit

* update readme

* some change

* fix big bug

* add docstring. update _get_linear_inserter

* add TorchAOBaseTensor back

* fix FSDP

* update FSDP test. add autocast support

* reduce iter

* update int8_mm fallback

* put leading dims logic to _dynamic_int8_mm
gau-nernst authored and andrewor14 committed Sep 9, 2024
1 parent 10d038f commit 3f7fc14
Showing 9 changed files with 771 additions and 175 deletions.
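
For orientation, the module-level UX added by this PR can be applied roughly as follows. This is a minimal sketch based on the pretrain_llama2.py changes below: the toy model, shapes, and backward call are placeholders rather than the PR's own example, and int8_mixed_precision_training() is shown with its defaults only. The script below applies it to the Transformer trunk (model.layers), keeping the LM head in BF16.

import torch
from torch import nn

from torchao.prototype.quantized_training import int8_mixed_precision_training
from torchao.quantization.quant_api import quantize_

# toy stand-in for a transformer trunk; pretrain_llama2.py below applies this to model.layers
model = nn.Sequential(nn.Linear(1024, 4096), nn.GELU(), nn.Linear(4096, 1024)).bfloat16().cuda()

# apply INT8 mixed-precision training to the nn.Linear layers, mirroring the script below;
# set_inductor_config=False also mirrors the script
quantize_(model, int8_mixed_precision_training(), set_inductor_config=False)

# training then proceeds as usual; matmuls in forward/backward are expected to run in INT8
x = torch.randn(64, 1024, device="cuda", dtype=torch.bfloat16)
model(x).sum().backward()
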
45 changes: 45 additions & 0 deletions benchmarks/quantized_training/benchmark_int8mm.py
@@ -0,0 +1,45 @@
import pandas as pd
import torch
from triton.testing import do_bench

from torchao.prototype.quantized_training.int8_mm import int8_mm_dequant


def bench_f(f, *args):
    return do_bench(lambda: f(*args), fast_flush=False, return_mode="median")


shapes = [(sz, sz, sz) for sz in [1024, 2048, 4096]]

# Llama-8B shapes
shapes += [
    # linear in attention
    (32_768, 4096, 4096),
    (4096, 4096, 32_768),
    # linear in feed-forward
    (32_768, 14_336, 4096),
    (32_768, 4096, 14_336),
    (14_336, 4096, 32_768),
]

data = []
for M, N, K in shapes:
    print(f"{M=}, {N=}, {K=}")

    A_bf16 = torch.randn(M, K).bfloat16().cuda()
    B_bf16 = torch.randn(N, K).bfloat16().cuda()
    A_i8 = torch.randint(-128, 127, size=(M, K), dtype=torch.int8).cuda()
    B_i8 = torch.randint(-128, 127, size=(N, K), dtype=torch.int8).cuda()
    A_scale = torch.randn(M).bfloat16().cuda()
    B_scale = torch.randn(N).bfloat16().cuda()

    # benchmark F.linear() i.e. A @ B.T
    bf16_time = bench_f(torch.mm, A_bf16, B_bf16.T)
    i8_time = bench_f(torch._int_mm, A_i8, B_i8.T)
    i8_dequant_time = bench_f(int8_mm_dequant, A_i8, B_i8.T, A_scale, B_scale)

    sample = [M, N, K, bf16_time / i8_time, bf16_time / i8_dequant_time]
    data.append(sample)

df = pd.DataFrame(data, columns=["M", "N", "K", "CuBLAS INT8 speedup", "Triton INT8 dequant speedup"])
print(df.to_markdown())
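
For reference, int8_mm_dequant is benchmarked above as a fused Triton alternative to the unfused sequence sketched below. The exact semantics (an INT8 matmul accumulated in INT32, then dequantized with per-row scales of A and B) are an assumption inferred from the call signature used in this benchmark, not taken from the kernel itself.

import torch

def int8_mm_dequant_reference(A_i8, B_i8_t, A_scale, B_scale):
    # INT8 x INT8 -> INT32 matmul via cuBLAS, as in the i8_time case above
    out_i32 = torch._int_mm(A_i8, B_i8_t)  # (M, N), int32
    # dequantize: A_scale scales rows, B_scale scales columns;
    # the result is promoted to the scales' dtype (BF16 here)
    return out_i32 * A_scale[:, None] * B_scale[None, :]

The "Triton INT8 dequant speedup" column then shows how much of the raw cuBLAS INT8 speedup remains once this dequantization step is included.
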
57 changes: 37 additions & 20 deletions benchmarks/quantized_training/pretrain_llama2.py
@@ -3,12 +3,14 @@
#
# BF16 baseline: python benchmarks/quantized_training/pretrain_llama2.py --seed 2024 --n_steps 10_000 --compile
# INT8 QT: python benchmarks/quantized_training/pretrain_llama2.py --seed 2024 --n_steps 10_000 --compile --quantize int8_weight_only
# INT8 MP: python benchmarks/quantized_training/pretrain_llama2.py --seed 2024 --n_steps 10_000 --compile --quantize int8_mixed_precision

import os

os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True"

import argparse
import time
from functools import partial
from pathlib import Path

@@ -18,22 +20,34 @@
from torch.utils.checkpoint import checkpoint
from tqdm import tqdm

from torchao._models.llama.model import ModelArgs, Transformer
from torchao._models.llama.model import ModelArgs, Transformer, transformer_configs
from torchao.prototype import low_bit_optim
from torchao.prototype.quantized_training import int8_weight_only_quantized_training
from torchao.prototype.quantized_training import (
    int8_mixed_precision_training,
    int8_weight_only_quantized_training,
)
from torchao.quantization.quant_api import quantize_


# not official models
transformer_configs.update(
    (
        ("470M", dict(n_layer=24, n_head=16, dim=1024, intermediate_size=4096)),
        ("1B", dict(n_layer=24, n_head=24, dim=1536, intermediate_size=6144)),
    )
)


# hack from fairseq
# https://github.com/facebookresearch/fairseq/blob/920a548ca770fb1a951f7f4289b4d3a0c1bc226f/fairseq/modules/checkpoint_activations.py
def enable_activation_checkpointing(m: torch.nn.Module):
    assert not hasattr(m, "_forward")
    m._forward = m.forward
    m.forward = partial(checkpoint, m.forward)
    m.forward = partial(checkpoint, m.forward, use_reentrant=False)


def get_loss(model: Transformer, batch: torch.Tensor):
    logits = model(batch)[:, :-1].flatten(0, 1)
    logits = model(batch)[:, :-1].float().flatten(0, 1)
    labels = batch[:, 1:].flatten()
    return torch.nn.functional.cross_entropy(logits, labels)

@@ -77,12 +91,7 @@ def get_tinystories():

if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    # default config is 470M
    parser.add_argument("--d_model", type=int, default=1024)
    parser.add_argument("--depth", type=int, default=24)
    parser.add_argument("--ffn_size", type=int, default=4096)
    parser.add_argument("--head_dim", type=int, default=64)

    parser.add_argument("--model", default="470M", choices=transformer_configs.keys())
    parser.add_argument("--quantize")
    parser.add_argument("--activation_checkpointing", action="store_true")
    parser.add_argument("--compile", action="store_true")
@@ -98,44 +107,48 @@ def get_tinystories():
parser.add_argument("--project", default="int8_quantized_training")
parser.add_argument("--run_name")
parser.add_argument("--seed", type=int)
parser.add_argument("--log_interval", type=int, default=10)
args = parser.parse_args()

    if args.seed is not None:
        torch.manual_seed(args.seed)

    config = ModelArgs(
        block_size=args.seq_len,
        n_layer=args.depth,
        n_head=args.d_model // args.head_dim,
        dim=args.d_model,
        intermediate_size=args.ffn_size,
    )
    config = ModelArgs.from_name(args.model)
    config.block_size = args.seq_len
    model = Transformer(config).bfloat16().cuda()
    with torch.device("cuda"):
        model.setup_caches(args.batch_size, args.seq_len, training=True)
    if args.activation_checkpointing:
        for layer in model.layers:
            enable_activation_checkpointing(layer)

    # don't apply int8_mixed_precision to the LM head, since it can cause convergence issues.
    # TODO: might want to do the same for int8_weight_only to standardize.
    if args.quantize == "int8_weight_only":
        quantize_(model, int8_weight_only_quantized_training(), set_inductor_config=False)
    elif args.quantize == "int8_mixed_precision":
        quantize_(model.layers, int8_mixed_precision_training(), set_inductor_config=False)
    elif args.quantize is not None:
        raise ValueError(f"Unsupported quantize={args.quantize}")

    print(f"No. of params: {sum(p.numel() for p in model.parameters()):,}")
    print(f"No. of buffers: {sum(p.numel() for p in model.buffers()):,}")
    torch.cuda.reset_peak_memory_stats()  # don't count memory occupied by unquantized weights

    # only use optimizers from torchao.prototype.low_bit_optim to support quantized training
    if args.optim == "AdamW":
        args.optim = "_AdamW"
    optim = getattr(low_bit_optim, args.optim)(model.parameters(), lr=args.lr, weight_decay=args.weight_decay)

    data = get_tinystories().cuda()
    args.torch_version = torch.__version__
    run = wandb.init(dir="/tmp", config=args, project=args.project, name=args.run_name)

    step = 0
    log_interval = 50
    pbar = tqdm(total=args.n_steps, dynamic_ncols=True)
    model.train()
    _get_loss = torch.compile(get_loss) if args.compile else get_loss
    time0 = time.time()

    while step < args.n_steps:
        # randomly select a continuous chunk, then reshape it
@@ -145,13 +158,17 @@ def get_tinystories():
        loss = _get_loss(model, batch)
        loss.backward()

        if step % log_interval == 0:
        if step % args.log_interval == 0:
            log_dict = dict(
                loss=loss.item(),
                lr=optim.param_groups[0]["lr"],
                max_memory_allocated=torch.cuda.max_memory_allocated() / 1e9,
                max_memory_active=torch.cuda.memory_stats().get("active_bytes.all.peak", 0) / 1e9,
                max_memory_reserved=torch.cuda.max_memory_reserved() / 1e9,
            )
            if step > 0:
                time1 = time.time()
                log_dict["tokens_per_second"] = (args.log_interval * args.batch_size * args.seq_len) / (time1 - time0)
                time0 = time1
            run.log(log_dict, step=step)
            pbar.set_postfix(loss=log_dict["loss"])
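
The tokens_per_second value added above is wall-clock throughput over the last logging window: tokens processed since the previous log (log_interval * batch_size * seq_len) divided by the elapsed time. A standalone sketch with placeholder hyperparameter values:

import time

log_interval, batch_size, seq_len = 10, 8, 2048  # placeholder values
time0 = time.time()
# ... run `log_interval` training steps here ...
time1 = time.time()
# 10 * 8 * 2048 = 163,840 tokens in this window
tokens_per_second = (log_interval * batch_size * seq_len) / (time1 - time0)
time0 = time1  # start the next measurement window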

(Diffs for the remaining 7 changed files are not shown here.)
