
v0.7.0

@vkuzo released this 06 Dec 22:13
e39126a

Highlights

We are excited to announce the 0.7.0 release of torchao! This release moves QAT out of prototype with improved LoRA support and more flexible APIs, and adds support for new experimental kernels such as Marlin QQQ (for CUDA), int8_dynamic_activation_intx_weight (for ARM CPU), and more!

QAT moved out of prototype, LoRA integration, new flexible APIs (#1020, #1085, #1152, #1037)

QAT has been moved out of prototype to torchao/quantization/qat to provide better API stability guarantees moving forward. In addition to the existing *QATQuantizer classes, we now also support the more flexible FakeQuantizedLinear and FakeQuantizedEmbedding modules, which let users configure the exact quantization settings they want to use during QAT.

import torch
from torchao.quantization.qat.api import FakeQuantizeConfig
from torchao.quantization.qat.embedding import FakeQuantizedEmbedding
from torchao.quantization.qat.linear import FakeQuantizedLinear

# Specify quantization schemes to use during QAT:
# per-token asymmetric int8 activations, int4 weights quantized in groups of 8
activation_config = FakeQuantizeConfig(torch.int8, "per_token", is_symmetric=False)
weight_config = FakeQuantizeConfig(torch.int4, group_size=8)

# Replace nn.Linear and nn.Embedding with these in your model
fq_linear = FakeQuantizedLinear(16, 32, False, activation_config, weight_config)  # 16 -> 32 features, no bias
fq_embedding = FakeQuantizedEmbedding(16, 32, weight_config=weight_config)  # 16 embeddings of dim 32
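During QAT, configs like these fake-quantize values in the forward pass: each value is quantized and immediately dequantized, so training sees the rounding error while tensors stay in floating point. The plain-Python sketch below approximates the two configs, assuming min/max-based asymmetric quantization for per-token int8 activations and absmax-based symmetric quantization for int4 weights in groups of 8. The helper names are hypothetical (not torchao APIs), and torchao's exact scale and zero-point computation (epsilon handling, rounding details) may differ.

```python
def fake_quantize_per_token_asymmetric(row, n_bits=8):
    """Quantize one token's activations to signed n-bit ints with a zero point, then dequantize."""
    qmin, qmax = -(2 ** (n_bits - 1)), 2 ** (n_bits - 1) - 1
    # Include 0 in the range so that 0.0 stays exactly representable.
    lo, hi = min(min(row), 0.0), max(max(row), 0.0)
    scale = (hi - lo) / (qmax - qmin) or 1.0  # fall back to 1.0 for an all-zero row
    zero_point = round(qmin - lo / scale)
    dequantized = []
    for v in row:
        q = max(qmin, min(qmax, round(v / scale) + zero_point))
        dequantized.append((q - zero_point) * scale)
    return dequantized


def fake_quantize_per_group_symmetric(weights, n_bits=4, group_size=8):
    """Quantize weights in groups of `group_size`, one absmax scale per group, then dequantize."""
    qmin, qmax = -(2 ** (n_bits - 1)), 2 ** (n_bits - 1) - 1
    dequantized = []
    for start in range(0, len(weights), group_size):
        group = weights[start:start + group_size]
        scale = max(abs(v) for v in group) / qmax or 1.0
        dequantized.extend(max(qmin, min(qmax, round(v / scale))) * scale for v in group)
    return dequantized
```

The outputs are still floats, which is the point of fake quantization: gradients flow through ordinary floating-point ops while the forward pass reflects the precision the model will have after real quantization.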

We also leveraged the new flexible APIs to build a new QAT + LoRA fine-tuning flow in torchtune. Try it out today!

tune run --nnodes 1 --nproc_per_node 4 qat_lora_finetune_distributed --config llama3/8B_qat_lora

Marlin QQQ for CUDA (#1113)

Marlin QQQ is an optimized GPU kernel that supports W4A8 mixed-precision GEMM. For more details about Marlin QQQ, please refer to the QQQ paper.

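from torchao.dtypes import MarlinQQQLayout
from torchao.quantization import int8_dynamic_activation_int4_weight, quantize_
from torchao.quantization.quant_primitives import MappingType

# `model` is the nn.Module to quantize
quantize_(
    model,
    int8_dynamic_activation_int4_weight(
        group_size=128,
        mapping_type=MappingType.SYMMETRIC,
        act_mapping_type=MappingType.SYMMETRIC,
        layout=MarlinQQQLayout(),
    ),
)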

Benchmarking results can be found at https://github.com/pytorch/ao/blob/main/torchao/quantization/README.md#marlin-qqq.

This is a prototype feature; feel free to try it out!
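W4A8 means int8 activations multiplied by int4 weights, with the products accumulated in integer arithmetic and rescaled to floating point afterwards. The sketch below computes a single output element that way, assuming symmetric per-token activation scales and symmetric per-group weight scales (mirroring the MappingType.SYMMETRIC settings above, with a smaller group size for readability). It only illustrates the arithmetic; the Marlin QQQ kernel's actual data layout and dequantization order may differ, and all function names here are hypothetical.

```python
def quantize_per_token_int8(x):
    """Symmetric int8 quantization of one activation row (one token)."""
    scale = max(abs(v) for v in x) / 127 or 1.0
    return [round(v / scale) for v in x], scale


def quantize_per_group_int4(w, group_size):
    """Symmetric int4 quantization of one weight column, one scale per group."""
    codes, scales = [], []
    for start in range(0, len(w), group_size):
        group = w[start:start + group_size]
        scale = max(abs(v) for v in group) / 7 or 1.0
        scales.append(scale)
        codes.extend(max(-8, min(7, round(v / scale))) for v in group)
    return codes, scales


def w4a8_dot(x, w, group_size):
    """One output element of a W4A8 GEMM: integer products per group, rescaled afterwards."""
    xq, x_scale = quantize_per_token_int8(x)
    wq, w_scales = quantize_per_group_int4(w, group_size)
    acc = 0.0
    for g, w_scale in enumerate(w_scales):
        lo, hi = g * group_size, (g + 1) * group_size
        # Pure integer multiply-accumulate (a GPU kernel would typically use int32 accumulators).
        int_acc = sum(a * b for a, b in zip(xq[lo:hi], wq[lo:hi]))
        acc += int_acc * w_scale  # apply the weight scale once per group
    return acc * x_scale  # apply the per-token activation scale once at the end


x = [0.5, -1.0, 0.25, 2.0, -0.75, 1.5, 0.0, -0.25]
w = [0.3, -0.2, 0.7, 0.1, -0.4, 0.05, 0.6, -0.35]
approx = w4a8_dot(x, w, group_size=4)
exact = sum(a * b for a, b in zip(x, w))
```

Comparing `approx` with `exact` shows the accuracy cost of 4-bit weights and 8-bit activations on this toy input.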

int8_dynamic_activation_intx_weight Quantization for ARM CPU (#995, #1027, #1254, #1353)

We have added kernels that perform 8-bit dynamic quantization of activations and uintx groupwise quantization of weights. These kernels are experimental and can only be run on a device with an ARM CPU (e.g., a Mac computer with Apple silicon).

import os

import torch
from torchao.experimental.quant_api import int8_dynamic_activation_intx_weight
from torchao.quantization import quantize_

# `precision` is the dtype the model runs in (defined earlier in your script)
assert precision == torch.float32, "int8_dynamic_activation_intx_weight requires fp32 precision"

# Build kernels in a temp location and load them into torch
# This requires an ARM CPU; the path is relative to the calling script, adjust it for your checkout
from torchao.experimental.temp_build import temp_build_and_load_torchao_ops
temp_build_and_load_torchao_ops(cmake_lists_path=os.path.dirname(os.path.realpath(__file__)) + "/../../experimental")

# Quantize model
nbit = 4
assert 1 <= nbit <= 8, "nbit must be 1 to 8"
group_size = 128
has_weight_zeros = False
quantize_(
    model,
    int8_dynamic_activation_intx_weight(
        group_size=group_size,
        nbit=nbit,
        has_weight_zeros=has_weight_zeros,
    ),
)

Benchmarking results can be found at https://github.com/pytorch/ao/blob/main/torchao/quantization/README.md#int8_dynamic_activation_intx_weight-quantization.

We are still trying to figure out how to ship the ARM CPU kernels, so the exact API is subject to change.
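In groupwise quantization, every group_size consecutive weights share one scale, and nbit sets how many integer levels (2**nbit) each weight can take. The sketch below assumes that has_weight_zeros=False corresponds to symmetric quantization (zero point fixed at 0) and has_weight_zeros=True to asymmetric quantization with a per-group zero point. Treat that mapping, and the hypothetical helper names, as assumptions for illustration rather than a description of how the ARM kernels store weights.

```python
def quantize_groupwise(weights, nbit, group_size, has_weight_zeros):
    """Groupwise n-bit quantization returning (codes, scales, zeros).

    Assumption (not taken from torchao): has_weight_zeros=False means symmetric
    quantization with a zero point of 0; True means asymmetric min/max quantization.
    """
    assert 1 <= nbit <= 8, "nbit must be 1 to 8"
    qmin, qmax = -(2 ** (nbit - 1)), 2 ** (nbit - 1) - 1
    codes, scales, zeros = [], [], []
    for start in range(0, len(weights), group_size):
        group = weights[start:start + group_size]
        if has_weight_zeros:
            # Map the group's [min, max] range onto [qmin, qmax].
            lo, hi = min(group), max(group)
            scale = (hi - lo) / (qmax - qmin) or 1.0
            zero = round(qmin - lo / scale)
        else:
            # Zero point fixed at 0; scale chosen from the largest magnitude.
            scale = max(abs(v) for v in group) / -qmin or 1.0
            zero = 0
        codes.extend(max(qmin, min(qmax, round(v / scale) + zero)) for v in group)
        scales.append(scale)
        zeros.append(zero)
    return codes, scales, zeros


def dequantize_groupwise(codes, scales, zeros, group_size):
    """Invert quantize_groupwise, up to rounding error."""
    return [(c - zeros[i // group_size]) * scales[i // group_size] for i, c in enumerate(codes)]
```

With nbit=4 and group_size=128 as in the snippet above, each block of 128 weights would be stored as 4-bit codes plus one scale (and, with has_weight_zeros=True, one zero point).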

BC Breaking

AQT rename #2: LayoutType -> Layout (#1049)

Before:

from torchao.dtypes import (
    BlockSparseLayoutType,
    Int4CPULayoutType,
    MarlinQQQLayoutType,
    MarlinSparseLayoutType,
    SemiSparseLayoutType,
    TensorCoreTiledLayoutType,
    UintxLayoutType,
    Float8LayoutType,
    LayoutType,
    PlainLayoutType,
)

After:

from torchao.dtypes import (
    BlockSparseLayout,
    Int4CPULayout,
    MarlinQQQLayout,
    MarlinSparseLayout,
    SemiSparseLayout,
    TensorCoreTiledLayout,
    UintxLayout,
    Float8Layout,
    Layout,
    PlainLayout,
)
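Since this change is a pure rename, existing code can be migrated mechanically. Below is a hypothetical helper (not part of torchao) that rewrites the old names from the Before/After lists using whole-word matches, so unrelated identifiers that merely contain "LayoutType" are left alone.

```python
import re

# Old name -> new name, taken from the Before/After import lists.
LAYOUT_RENAMES = {
    "BlockSparseLayoutType": "BlockSparseLayout",
    "Int4CPULayoutType": "Int4CPULayout",
    "MarlinQQQLayoutType": "MarlinQQQLayout",
    "MarlinSparseLayoutType": "MarlinSparseLayout",
    "SemiSparseLayoutType": "SemiSparseLayout",
    "TensorCoreTiledLayoutType": "TensorCoreTiledLayout",
    "UintxLayoutType": "UintxLayout",
    "Float8LayoutType": "Float8Layout",
    "LayoutType": "Layout",
    "PlainLayoutType": "PlainLayout",
}

# \b anchors ensure only whole identifiers are rewritten.
_OLD_NAMES = re.compile(r"\b(" + "|".join(map(re.escape, LAYOUT_RENAMES)) + r")\b")


def migrate_layout_names(source: str) -> str:
    """Rewrite pre-0.7.0 layout class names in a source string to their new names."""
    return _OLD_NAMES.sub(lambda m: LAYOUT_RENAMES[m.group(1)], source)


print(migrate_layout_names("layout = TensorCoreTiledLayoutType(inner_k_tiles=8)"))
# layout = TensorCoreTiledLayout(inner_k_tiles=8)
```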

QAT imports after move out of prototype (#1091)

Before:

from torchao.quantization.prototype.qat import (
    disable_4w_fake_quant,
    disable_8da4w_fake_quant,
    enable_4w_fake_quant,
    enable_8da4w_fake_quant,
    ComposableQATQuantizer,
    Int4WeightOnlyQATQuantizer,
    Int4WeightOnlyEmbeddingQATQuantizer,
    Int8DynActInt4WeightQATQuantizer,
    Int8DynActInt4WeightQATLinear,
)
from torchao.quantization.prototype.qat.api import (
    FakeQuantizeConfig,
)
from torchao.quantization.prototype.qat.fake_quantizer import (
    FakeQuantizer,
)

After:

from torchao.quantization.qat import (
    ComposableQATQuantizer,
    Int4WeightOnlyQATQuantizer,
    Int4WeightOnlyEmbeddingQATQuantizer,
    Int8DynActInt4WeightQATQuantizer,
)
from torchao.quantization.qat.linear import (
    disable_4w_fake_quant,
    disable_8da4w_fake_quant,
    enable_4w_fake_quant,
    enable_8da4w_fake_quant,
    Int8DynActInt4WeightQATLinear,
)
from torchao.quantization.qat.api import (
    FakeQuantizeConfig,
)
from torchao.quantization.qat.fake_quantizer import (
    FakeQuantizer,
)

New Features

  • Add BF16 stochastic rounding option for optimizers (#1124)
  • Add quantize_() API support for NF4 (#1216)
  • Support W4A8 Marlin kernel (#1113)
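Stochastic rounding rounds a value up or down with probability proportional to its distance from each neighbor, so the rounded result equals the original in expectation; this keeps small optimizer updates from being systematically rounded away when parameters are stored in BF16. The sketch below uses a common bit trick for FP32 to BF16 (add a uniform random 16-bit integer to the FP32 bit pattern, then drop the low 16 bits). It assumes finite inputs and is not necessarily how torchao implements the option.

```python
import random
import struct


def _fp32_to_bits(x: float) -> int:
    return struct.unpack("<I", struct.pack("<f", x))[0]


def _bits_to_fp32(bits: int) -> float:
    return struct.unpack("<f", struct.pack("<I", bits))[0]


def round_to_bf16_stochastic(x: float, rng: random.Random) -> float:
    """Stochastically round a finite FP32 value to one of its two BF16 neighbors.

    BF16 keeps the upper 16 bits of an FP32 word. Adding a uniform random 16-bit
    integer before truncating carries into the kept bits with probability equal
    to the discarded fraction, so the result is unbiased in expectation.
    """
    bits = _fp32_to_bits(x)
    bits = (bits + rng.getrandbits(16)) & 0xFFFF0000  # add noise, then truncate
    return _bits_to_fp32(bits)  # a Python float holding a BF16-representable value


rng = random.Random(0)
x = 1.0 + 2 ** -9  # a quarter of the way from 1.0 to the next BF16 value, 1.0 + 2**-7
samples = [round_to_bf16_stochastic(x, rng) for _ in range(20000)]
```

Roughly a quarter of the samples round up to 1.0078125 and the rest down to 1.0, so their mean stays close to x, whereas round-to-nearest would always return 1.0.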

Improvements

quantize_

  • Add default filtering to remove misaligned weights (#1194)
  • Add tensor parallelism support for int4_weight_only quantization (#1120)
  • Add support for asymmetric act quant for int8 dynamic quant (#1131)
  • Add support for groupwise quantization for int8 weight only quantization (#1121)
  • Add AQT tensor parallel for float8_dynamic_quant (#1078)
  • Int8 weight-only embedding quantization (#1167)
  • Make sure int4 weight-only quantization supports CPU as well (#1203)
  • BF16 support for Quant-LLM kernel (#1147)
  • Add hardware check to fp8 quant (#1314)
  • Add support for quantize_() with Float8Linear module (#1344)

autoquant

  • Added support for Per Tensor Scaling for Float8 Dynamic Autoquant (#1175)
  • Add floating point options for autoquant and add accuracy measurement (#1355)

benchmarks

  • Adding batchsize support for torchao llama benchmarks (#1182)
  • Add capability of benchmarking arbitrary binary (#1107)

experimental

  • Add embedding ops aten (#1129)
  • Add embedding ops executorch (#1137)
  • Add quantized embedding kernels to torchao (#1018)
  • Allow deprecated declarations when using Parallel ExecuTorch (#1031)
  • Introduce lowbit quantized linear MPS kernels (#954)
  • Enable 6-bit kernel (#1027)
  • Kleidi 4b blockwise gemv prototype (#997)
  • Experimental 6-bit quantization for Llama in torchchat (#1094)
  • Introduce 7-bit quantization for Llama in torchchat. (#1139)
  • Executorch Subclass API (#966) (#995)
  • 8-bit packing support (#1248)
  • Experimental Enable 8-bit (#1254)
  • Experimental Benchmarking (#1353)

optimizer

  • [low-bit optim] Upcast everything to FP32 for internal calculations (#1068)
  • [Low-bit optim] Support for dcp.save() and dcp.load() (#1217)
  • Enable CPU Offload for Intel GPU (#1324)

SAM2

  • SAM2.1 copy (#1172)
  • SAM2 AMG server side request batching (#1197)
  • More SAM2-fast server improvements (#1285)
  • SAM2 Fast AMG: memory profiling and more compile (#1296)
  • SAM2 AMG cli and other QoL improvements (#1336)
  • SAM2 AMG cli.py on modal (#1349)
  • Reduce SAM2 AMG cli startup by using deploy (#1350)
  • Reduce startup time for SAM2 AMG by using torch.export (#1358)
  • More batching and improved furious accuracy/performance (#1253)
  • SAM2.1 and example README (#1048)
  • SAM2 AMG example mIoU, perf numbers and more SAM2 model annotations (#1196)

other

  • Add SpinQuant to generate.py (#1069)
  • SpinQuant (#983)
  • SmoothQuant using tensor subclassing (#1030)
  • Expose FakeQuantizeConfigs in QAT quantizers (#1214)
  • Add module-swap UX for INT8 mixed-precision training (#1179)
  • Float8 training: move module attribute setting to sync function (#1341)

Bug Fixes

  • Header bug fix (#1079)
  • Temporary fix for QAT quantizer when linear layer bias is True (#1087)
  • Fix out-of-bounds memory access in Galore dequant kernel (#1125)
  • Fixed weights_only=True load for float8_dynamic_activation_float8_weight in quant_api (#1122)
  • Fix int8_weight_only group_size (#1165)
  • Is_linear fix for MHA (#1141)
  • Fixing eval.py to use GPTQ_MT for gptq (#1176)
  • [CPU offload optim] Fix when there are non-trainable params (#1210)
  • Fix for weights-only load (#1228)
  • Pin nightlies to deal with std::bad_alloc (#1256)
  • Fix 2.5.1 failing sparsity test (#1261)
  • Call narrow only for TensorCoreTiledLayout (#1207)
  • Fix an autoquant bug in flatten/unflatten (#1288)
  • Float8 with delayed scaling: fix autocast handling (#1306)
  • Fix bug with float8 training + FSDP2 + TP (#1327)
  • Float8 training: fix bug with AC + compile (#1329)
  • Fix torchtitan + float8 + delayed + compile (#1334)
  • [low-bit optim] Fix edge cases for FSDP2 integration (#1269)
  • [NF4] .to() fixes (#1312)
  • Check scale.ndim before applying t/transpose (#1339)

Performance

  • Swap in faster uint6 bitpacking function (#1098)
  • Implement more efficient pack and unpack uint5 (#1138)
  • Fix 20x slowdown of FP6 kernel due to device properties query (#1092)
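Six-bit values do not align to byte boundaries, but four of them fit exactly in three bytes (4 x 6 = 24 bits). The sketch below shows that basic packing in plain Python; it is illustrative only, the function names are hypothetical, and the bit layout torchao actually uses for uint6 (and uint5) may differ.

```python
def pack_uint6(values):
    """Pack 6-bit values (0..63) into bytes, four values per three bytes."""
    assert len(values) % 4 == 0 and all(0 <= v < 64 for v in values)
    out = bytearray()
    for i in range(0, len(values), 4):
        a, b, c, d = values[i:i + 4]
        word = (a << 18) | (b << 12) | (c << 6) | d  # one 24-bit word
        out += word.to_bytes(3, "big")
    return bytes(out)


def unpack_uint6(data):
    """Inverse of pack_uint6."""
    assert len(data) % 3 == 0
    values = []
    for i in range(0, len(data), 3):
        word = int.from_bytes(data[i:i + 3], "big")
        values += [(word >> 18) & 0x3F, (word >> 12) & 0x3F, (word >> 6) & 0x3F, word & 0x3F]
    return values
```

Packed this way, uint6 data takes 75% of the space of one value per byte; the performance items above concern how quickly such packing and unpacking run.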

Documentation

  • Add a developer guide for exporting to executorch (#1219)
  • Enable AWQ example on CPU (#1043)
  • Add README doc for experimental (#1130)
  • Move float8 out of prototype in quantization README (#1166)
  • Update torchao api reference and add contributor guide (#1255)
  • Fix pickle.dump missing file argument typo in README (#1316)
  • Update README.md (#1319)
  • Update README.md: Fix bibtex and sglang links (#1361)
  • Add bibtex (#1177)
  • Clarify torchao.float8 PyTorch version support (#1191)

Developers

  • [Tp Test] Fix the placement of the device tensor (#1054)
  • Skip test_fpx_weight_only in fbcode (#1056)
  • Pin pt nightly CPU version (#1061)
  • Unpin CUDA Nightly (#1064)
  • Update smoke test (#1111)
  • Update regression_test.yml (#1163)
  • Add PyTorch 2.5 to regression test (#1168)
  • Fix Bias APIs, re-enable kleidi tests for arm64 (#1162)
  • Create CITATION.cff (#1178)
  • Unpin nightlies (#1183)
  • [experimental] Kleidi - add operator level tests (#1173)
  • Ruff format and lint (#1226)
  • Update pre-commit to match CI/CD (#1227)
  • Fixing pytest skip for only test_floatx.py (#1251)
  • Fixed invalid url in citation section (#1348)
  • Add to safe globals (#1171)
  • Aqt rename#1 Layout -> TensorImpl (#1046)
  • Move and rename GranularityType -> Granularity (#1038)
  • Change torchao quantization types from int to size_t and preface vars with "preferred_" (#1041)
  • Shrink hadamard matrices (#1051)
  • Use ExecuTorch prebuilt library in pip package to build custom kernels (#1059)
  • Update base.h unit to unsigned int (#962)
  • Create header for packed weight ops (#1072)
  • Update cmake files (#1070)
  • Create build_wheels_aarch64_linux.yml (#1083)
  • ROCM binary upload (#1099)
  • Create build_wheels_windows.yml (#1101)
  • Use fewer instructions when unpacking uint6s. (#1109)
  • [CI] XPU binary build enable (#1105)
  • Move common ET/Aten op stuff to ops/library.h (#1116)
  • Move bias from kernel to packed_weights (#1119)
  • Update gpu_sparsity kernel benchmarking script (#1143)
  • [ROCm] use dataclass for fnuz type setting (#1142)
  • Move files to prototype/sparsity (#1145)
  • c10::nullopt -> std::nullopt (#1032) (#1151)
  • [reland][ROCm] use dataclass for fnuz type setting (#1150)
  • Move float8_aten_api to float8_ops (#1155)
  • Initialize model with meta device for generation benchmarking (#1144)
  • Replace torch.empty with torch.zeros (#1157)
  • Update utils.py (#1186)
  • Remove int_scaled_mm's dependency on triton for cpu (#128)
  • at::optional -> std::optional (#1170) (#1212)
  • fast_flush kwarg of do_bench is removed (#1222)
  • Remove calibration args from generate.py (#1258)
  • Skip marlin QQQ ops test in fbcode (#1289)
  • Fix Marlin QQQ ops test with unittest (#1294)
  • Fix Failing CI - Update bitsandbytes import (#1343)
  • Remove lm_eval warning (#1347)
  • Refactor Affine Quantized Tensor (#1234)
  • Move files from quantization/prototype -> prototype/quantization (#1187)
  • Add TTFT benchmarks + update sparsity benchmarks (#1140)
  • Add "_gemm_input_role" to dunder slots (#984)
  • Add an option to use fp8-all-gather only without fp8 computation. (#1093)
  • Bump version to 0.7 (#1045)

New Contributors

Full Changelog: v0.6.1...v0.7.0-rc1