
v0.6.1

@drisspg drisspg released this 21 Oct 21:45
· 245 commits to main since this release
99c8d52

Highlights

We are excited to announce the 0.6.1 release of torchao! This release adds Auto-Round support, Float8 axiswise scaled training, a BitNet training recipe, an implementation of AWQ, and much more!

Auto-Round Support (#581)

Auto-Round is a new weight-only quantization algorithm. It has achieved superior accuracy compared to GPTQ, AWQ, and OmniQuant across 11 tasks, particularly excelling in low-bit quantization (e.g., 2 bits and 3 bits). Auto-Round supports quantization from 2 to 8 bits, involves low tuning costs, and imposes no additional overhead during inference. Key results are summarized below, with detailed information available in our paper, GitHub repository, and Hugging Face low-bit quantization leaderboard.

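from torchao.prototype.autoround.core import prepare_model_for_applying_auto_round_
from torchao.prototype.autoround.core import apply_auto_round
from torchao.prototype.autoround.multi_tensor import MultiTensor  # module path assumed
from torchao.quantization import quantize_

# `model`, `dataloader`, `device`, `model_device`, and the `is_target_module`
# filter are assumed to be defined by the caller.
prepare_model_for_applying_auto_round_(
    model,
    is_target_module=is_target_module,
    bits=4,
    group_size=128,
    iters=200,
    device=device,
)

# Collect the calibration inputs and run them through the prepared model in one pass.
input_ids_lst = []
for data in dataloader:
    input_ids_lst.append(data["input_ids"].to(model_device))

multi_t_input_ids = MultiTensor(input_ids_lst)
out = model(multi_t_input_ids)

# Swap the target modules for their Auto-Round quantized versions.
quantize_(model, apply_auto_round(), is_target_module)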
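
To make the `bits=4` and `group_size=128` settings above concrete, here is a minimal pure-Python sketch of group-wise weight quantization. It uses plain round-to-nearest, which is the baseline that Auto-Round improves on by tuning the rounding, so it is not the Auto-Round algorithm itself. The function names and the toy data are hypothetical and are not part of torchao.

```python
def quantize_group_rtn(weights, bits=4):
    """Asymmetric round-to-nearest quantization of one group of weights."""
    qmax = 2 ** bits - 1                      # 15 for 4-bit codes
    lo, hi = min(weights), max(weights)
    scale = (hi - lo) / qmax if hi > lo else 1.0
    # Map each weight to an integer code in [0, qmax].
    codes = [min(qmax, max(0, round((w - lo) / scale))) for w in weights]
    # Reconstruct the approximate weights from the codes.
    dequant = [c * scale + lo for c in codes]
    return codes, dequant


def quantize_row_rtn(row, bits=4, group_size=128):
    """Quantize a weight row in independent groups of `group_size` values."""
    codes, dequant = [], []
    for start in range(0, len(row), group_size):
        c, d = quantize_group_rtn(row[start:start + group_size], bits)
        codes.extend(c)
        dequant.extend(d)
    return codes, dequant


# Toy weight row: 256 values, so two groups of 128.
row = [0.01 * i - 1.0 for i in range(256)]
codes, dequant = quantize_row_rtn(row, bits=4, group_size=128)
max_err = max(abs(a - b) for a, b in zip(row, dequant))
```

Each group gets its own scale and offset, so a smaller `group_size` lowers the rounding error at the cost of storing more scales.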

Added float8 training axiswise scaling support with per-gemm-argument configuration (#940)

We added experimental support for rowwise scaled float8 gemm to torchao.float8, with per-gemm-input configurability to enable exploration of various recipes. Here is how a user can configure all-axiswise scaling:

import torchao.float8
from torchao.float8.config import Float8LinearRecipeName

# `m` is the nn.Module to convert (assumed to be defined by the caller)

# all-axiswise scaling
config = torchao.float8.config.recipe_name_to_linear_config(Float8LinearRecipeName.ALL_AXISWISE)
m = torchao.float8.convert_to_float8_training(m, config=config)

# or, a custom recipe by @lw where grad_weight is left in bfloat16
config = torchao.float8.config.recipe_name_to_linear_config(Float8LinearRecipeName.LW_AXISWISE_WITH_GW_HP)
m = torchao.float8.convert_to_float8_training(m, config=config)

Early performance benchmarks show that all-axiswise scaling achieves a 1.13x speedup vs bf16 on torchtitan / LLaMa 3 8B / 8 H100 GPUs (compared to 1.17x from all-tensorwise scaling in the same setup), with loss curves that match bf16 and all-tensorwise scaling. Further performance and accuracy benchmarks will follow in future releases.
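
The difference between tensorwise and axiswise (rowwise) scaling can be sketched in a few lines of plain Python. This is only an illustration of the scaling granularity, not the torchao.float8 implementation: the helper names are hypothetical, and the `fp8_max / amax` scale convention is an assumption made for this example.

```python
E4M3_MAX = 448.0  # largest finite value representable in float8 e4m3fn


def tensorwise_scale(matrix):
    """One scale for the whole tensor, derived from its global abs-max."""
    amax = max(abs(x) for row in matrix for x in row)
    return E4M3_MAX / max(amax, 1e-12)


def rowwise_scales(matrix):
    """One scale per row, derived from each row's own abs-max."""
    return [E4M3_MAX / max(max(abs(x) for x in row), 1e-12) for row in matrix]


# Toy input whose second row is 200x larger in magnitude than the first.
x = [
    [0.5, -1.0, 0.25],
    [100.0, -200.0, 50.0],
]
t_scale = tensorwise_scale(x)  # dominated by the large row
r_scales = rowwise_scales(x)   # each row uses the full float8 range
```

With a single tensorwise scale, the small-magnitude row occupies only a sliver of the float8 range. Rowwise scales let every row use the full range, which can help accuracy but costs extra scale computation, consistent with the slightly lower speedup reported above.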

Introduced BitNet b1.58 training recipe (#930)

Adds a recipe for BitNet b1.58 (https://arxiv.org/abs/2402.17764) ternary weight clamping.

from torchao.prototype.quantized_training import bitnet_training
from torchao import quantize_

model = ...
quantize_(model, bitnet_training())

Notably, our implementation uses INT8 Tensor Cores to make up for the speed lost to on-the-fly quantization. In fact, it is faster than BF16 training in most cases.
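
For readers unfamiliar with ternary weights, the sketch below shows the absmean quantization described in the BitNet b1.58 paper: scale the weights by their mean absolute value, round, and clamp to {-1, 0, 1}. It is a pure-Python illustration with a hypothetical function name, and it makes no claim about how `bitnet_training()` is implemented internally.

```python
def ternary_quantize(weights, eps=1e-5):
    """Absmean ternary quantization of a 2-D weight matrix (list of rows)."""
    flat = [w for row in weights for w in row]
    # Tensor-wise scale: mean absolute value of all weights.
    scale = sum(abs(w) for w in flat) / len(flat) + eps
    # Round to the nearest integer, then clamp to the ternary set {-1, 0, 1}.
    ternary = [[max(-1, min(1, round(w / scale))) for w in row] for row in weights]
    return ternary, scale


w = [[0.9, -0.05, -1.2], [0.02, 0.6, -0.7]]
q, scale = ternary_quantize(w)
# Approximate reconstruction of the original weights.
w_approx = [[v * scale for v in row] for row in q]
```

Small weights collapse to 0 and large ones saturate at -1 or 1, so each weight carries about 1.58 bits of information (log2 of 3).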

[Prototype] Implemented Activation Aware Weight Quantization AWQ (#743)

Perplexity and performance measured on A100 GPU:

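| Model | Quantization | Tokens/sec | Throughput (GB/sec) | Peak Mem (GB) | Model Size (GB) |
|---|---|---|---|---|---|
| Llama-2-7b-chat-hf | bfloat16 | 107.38 | 1418.93 | 13.88 | 13.21 |
| Llama-2-7b-chat-hf | awq-hqq-int4 | 196.6 | 761.2 | 5.05 | 3.87 |
| Llama-2-7b-chat-hf | awq-uint4 | 43.59 | 194.93 | 7.31 | 4.47 |
| Llama-2-7b-chat-hf | int4wo-hqq | 209.19 | 804.32 | 4.89 | 3.84 |
| Llama-2-7b-chat-hf | int4wo-64 | 201.14 | 751.42 | 4.87 | 3.74 |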
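
As a reading aid, the short script below derives speedup and compression ratios relative to the bfloat16 baseline. The input numbers are copied from the table above; the script and its variable names are illustrative only.

```python
# (tokens/sec, model size in GB) per configuration, copied from the table above.
results = {
    "bfloat16": (107.38, 13.21),
    "awq-hqq-int4": (196.6, 3.87),
    "awq-uint4": (43.59, 4.47),
    "int4wo-hqq": (209.19, 3.84),
    "int4wo-64": (201.14, 3.74),
}

base_tok_s, base_size = results["bfloat16"]
summary = {}
for name, (tok_s, size_gb) in results.items():
    summary[name] = {
        "speedup": tok_s / base_tok_s,       # > 1 means faster than bfloat16
        "compression": base_size / size_gb,  # > 1 means smaller than bfloat16
    }

for name, s in summary.items():
    print(f"{name:>14}: {s['speedup']:.2f}x tokens/sec, {s['compression']:.2f}x smaller")
```

By these numbers, awq-hqq-int4 is roughly 1.8x faster and 3.4x smaller than bfloat16, while awq-uint4 shrinks the model but decodes slower than the baseline.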

Usage:

import torch
from torchao.quantization import quantize_
from torchao.prototype.awq import insert_awq_observer_, awq_uintx, AWQObservedLinear

quant_dtype = torch.uint4
group_size = 64
calibration_limit = 10
calibration_seq_length = 1024

# `model`, `device`, and `calibration_data` are assumed to be defined by the caller
model = model.to(device)

# insert observers that record activation statistics during calibration
insert_awq_observer_(model, calibration_limit, calibration_seq_length, quant_dtype=quant_dtype, group_size=group_size)
with torch.no_grad():
    for batch in calibration_data:
        model(batch.to(device))

# quantize only the linear layers that were observed
is_observed_linear = lambda m, fqn: isinstance(m, AWQObservedLinear)
quantize_(model, awq_uintx(quant_dtype=quant_dtype, group_size=group_size), is_observed_linear)

New Features

  • [Prototype] Added Float8 support for AQT tensor parallel (#1003)
  • Added composable QAT quantizer (#938)
  • Introduced torchchat quantizer (#897)
  • Added INT8 mixed-precision training (#748)
  • Implemented sparse marlin AQT layout (#621)
  • Added a PerTensor static quant api (#787)
  • Introduced uintx quant to generate and eval (#811)
  • Added Float8 Weight Only and FP8 weight + dynamic activation (#740)
  • Implemented Auto-Round support (#581)
  • Added 2, 3, 4, 5 bit custom ops (#828)
  • Introduced symmetric quantization with no clipping error in the tensor subclass based API (#845)
  • Added int4 weight-only embedding QAT (#947)
  • Added support for 1-bit and 6-bit quantization for Llama in torchchat (#910, #1007)
  • Added a linear_observer class for doing static activation calibration (#807)
  • Exposed hqq through uintx_weight_only API (#786)
  • Added RowWise scaling option for Float8 dynamic activation quantization (#819)
  • Added Float8 weight only to autoquant api (#866)

Improvements

  • Enhanced Auto-Round functionality (#870)
  • Improved FSDP support for low-bit optimizers (#538)
  • Added support for using AffineQuantizedTensor with weights_only=True for torch.load (#630)
  • Optimized 3-bit packing (#1029)
  • Added more evaluation metrics to llama/eval.sh (#934)
  • Improved eager numerics for dynamic scales in float8 (#904)

Bug fixes

  • Fixed inference_mode issues (#885)
  • Fixed failing FP6 benchmark (#931)
  • Resolved various issues with float8 support (#918, #923)
  • Fixed load state dict when device is different for low-bit optim (#1021)

Performance

  • Added SM75 (Turing) support for FP6 kernel (#942)
  • Implemented int8 dynamic quant + bsr support (#821)
  • Added workaround to recover the perf for quantized vit in torch.compile (#926)

INT8 Mixed-Precision Training

On NVIDIA GPUs, INT8 Tensor Cores are approximately 2x faster than their BF16/FP16 counterparts. In mixed-precision training, we can down-cast activations and weights dynamically to INT8 to leverage faster matmuls. However, since INT8 has a very limited range [-128, 127], we perform row-wise quantization, similar to how INT8 post-training quantization (PTQ) is done. Weights are still stored in their original precision.

from torchao.prototype.quantized_training import int8_mixed_precision_training, Int8MixedPrecisionTrainingConfig
from torchao.quantization import quantize_

model = ...

# apply INT8 matmul to all 3 matmuls
quantize_(model, int8_mixed_precision_training())

# customize which matmul is left in original precision.
config = Int8MixedPrecisionTrainingConfig(
    output=True,
    grad_input=True,
    grad_weight=False,
)
quantize_(model, int8_mixed_precision_training(config))
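
The row-wise quantization mentioned above can be sketched in pure Python as follows. This is a conceptual illustration under simplifying assumptions (symmetric abs-max scaling, Python lists instead of tensors, hypothetical function names), not the kernel torchao uses.

```python
def quantize_rowwise_int8(matrix):
    """Symmetric abs-max INT8 quantization with one scale per row."""
    q_rows, scales = [], []
    for row in matrix:
        amax = max(abs(x) for x in row)
        scale = amax / 127 if amax > 0 else 1.0
        q_rows.append([max(-128, min(127, round(x / scale))) for x in row])
        scales.append(scale)
    return q_rows, scales


def int8_linear(a, w):
    """Approximate a @ w.T: quantize rows of `a` and `w`, accumulate in
    integers, then rescale each output by the two row scales."""
    a_q, a_s = quantize_rowwise_int8(a)
    w_q, w_s = quantize_rowwise_int8(w)
    return [
        [sum(x * y for x, y in zip(ra, rw)) * sa * sw for rw, sw in zip(w_q, w_s)]
        for ra, sa in zip(a_q, a_s)
    ]


a = [[1.0, -2.0], [0.5, 0.25]]   # activations: (batch, in_features)
w = [[0.1, 0.2], [-0.3, 0.4]]    # weights: (out_features, in_features)
approx = int8_linear(a, w)
exact = [[sum(x * y for x, y in zip(ra, rw)) for rw in w] for ra in a]
```

Because each row has its own scale, an outlier in one row does not degrade the precision of the others, and the scales factor cleanly out of the integer dot product.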

End-to-end speed benchmark using benchmarks/quantized_training/pretrain_llama2.py:

| Model & GPU | bs x seq_len | Config | Tok/s | Peak mem (GB) |
|---|---|---|---|---|
| Llama2-7B, A100 | 8 x 2048 | BF16 (baseline) | ~4400 | 59.69 |
| Llama2-7B, A100 | 8 x 2048 | INT8 mixed-precision | ~6100 (+39%) | 58.28 |
| Llama2-1B, 4090 | 16 x 2048 | BF16 (baseline) | ~17,900 | 18.23 |
| Llama2-1B, 4090 | 16 x 2048 | INT8 mixed-precision | ~30,700 (+72%) | 18.34 |

Docs

  • Updated README with more current float8 speedup information (#816)
  • Added tutorial for trainable tensor subclass (#908)
  • Improved documentation for float8 unification and inference (#895, #896)

Devs

  • Added compile tests to test suite (#906)
  • Improved CI setup and build processes (#887)
  • Added M1 wheel support (#822)
  • Added more benchmarking and profiling tools (#1017)
  • Renamed fpx to floatx (#877)
  • Removed torchao_nightly package (#661)
  • Added more lint fixes (#827)
  • Added better subclass testing support (#839)
  • Added CI to catch syntax errors (#861)
  • Added tutorial on composing quantized subclass w/ Dtensor based TP (#785)

Security

No significant security updates in this release.

Untopiced

  • Added basic SAM2 AutomaticMaskGeneration example server (#1039)

New Contributors

Full Changelog: v0.5.0...v0.6.1