From 2ac728d7ed19f6ea626ef9a8cf4e81fd06aab426 Mon Sep 17 00:00:00 2001
From: HDCharles
Date: Tue, 3 Sep 2024 23:41:10 -0700
Subject: [PATCH] int4 fixes and improvements

Summary:

1) added int4 to autoquant, using hqq by default
2) fixes to hqq in the normal int4 class so it can actually be used with the normal UX
3) added hqq to eval/generate
4) evaluated hqq to make sure it's a reasonable default for autoquant
5) ran llama3 eval now that llama3 is working correctly (fixed in the 3.1 PR)
6) tested hqq vs GPTQ so we have a comparison in our benchmarks/eval
7) GPTQ was broken -> fixed the broken code in utils and GPTQ

Test Plan:

benchmarks.sh (new autoquant-int4 benchmarks)

export CHECKPOINT_PATH=../../../checkpoints
export MODEL_REPO=meta-llama/Llama-2-7b-chat-hf
python eval.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth
python eval.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --quantization int8wo
python eval.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --quantization int8dq --compile
python eval.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --quantization int4wo-64-hqq
python eval.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --quantization int4wo-64
python eval.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --quantization int4wo-64-gptq

export MODEL_REPO=meta-llama/Meta-Llama-3-8B
python eval.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth
python eval.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --quantization int8wo
python eval.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --quantization int8dq --compile
python eval.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --quantization int4wo-64-hqq
python eval.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --quantization int4wo-64
python eval.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --quantization int4wo-64-gptq

(see results in README.md)

Reviewers:

Subscribers:

Tasks:

Tags:
---
 README.md | 14 +++--
 torchao/_models/llama/README.md | 2 +-
 torchao/_models/llama/benchmark_results.txt | 34 ++++++-----
 torchao/_models/llama/benchmarks.sh | 2 +
 torchao/_models/llama/eval.py | 9 ++-
 torchao/_models/llama/generate.py | 16 +++--
 torchao/quantization/GPTQ.py | 8 +--
 torchao/quantization/README.md | 36 +++++++----
 torchao/quantization/__init__.py | 2 +
 torchao/quantization/autoquant.py | 67 +++++++++++++++++----
 torchao/quantization/quant_api.py | 7 ++-
 torchao/quantization/utils.py | 5 +-
 12 files changed, 140 insertions(+), 62 deletions(-)

diff --git a/README.md b/README.md
index 69b7c2a0a7..ebce73b6d6 100644
--- a/README.md
+++ b/README.md
@@ -35,10 +35,12 @@ from torchao.quantization.quant_api import (
     int4_weight_only,
     int8_weight_only
 )
-quantize_(m, int4_weight_only())
+quantize_(m, int4_weight_only())
 ```

-For gpt-fast `int4_weight_only()` is the best option at bs=1 as it **2x the tok/s and reduces the VRAM requirements by about 65%** over a torch.compiled baseline
+For gpt-fast, `int4_weight_only()` is the best option at bs=1 as it **2x's the tok/s and reduces the VRAM requirements by about 65%** over a torch.compiled baseline.
+Note: For models that are less memory bound, the int4 weight-only quantization kernel can be slower than other kernels. If you are seeing slowdowns, using [autoquant](./torchao/quantization/README.md#autoquantization) with int4 quantization
+can solve the issue. See the [quantization readme](./torchao/quantization/README.md#autoquantization) for details.
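+For example, a minimal sketch of that approach (assuming a module `m` and an example `input`; both names are placeholders):
+
+```python
+import torch
+import torchao
+from torchao.quantization import DEFAULT_INT4_AUTOQUANT_CLASS_LIST
+
+# include int4 weight-only quantization in autoquant's search space; layers where
+# int4 would be slower keep one of the other candidate options instead
+m = torchao.autoquant(torch.compile(m, mode='max-autotune'), qtensor_class_list=DEFAULT_INT4_AUTOQUANT_CLASS_LIST)
+m(input)
+```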
If you're unsure which option to use, you can also run [autoquant](./torchao/quantization/README.md#autoquantization) which will automatically profile layers for you and skip quantizing layers where overhead is too large. @@ -102,7 +104,7 @@ from torchao.prototype.low_bit_optim import AdamW8bit, AdamW4bit, AdamWFp8 optim = AdamW8bit(model.parameters()) # replace with Adam4bit and AdamFp8 for the 4 / fp8 versions ``` -In practice, we are a tiny bit slower than expertly written kernels but the implementations for these optimizers were written in a **few hundred lines of PyTorch code** and compiled so please use them or copy-paste them for your quantized optimizers. Benchmarks [here](https://github.com/pytorch/ao/tree/main/torchao/prototype/low_bit_optim) +In practice, we are a tiny bit slower than expertly written kernels but the implementations for these optimizers were written in a **few hundred lines of PyTorch code** and compiled so please use them or copy-paste them for your quantized optimizers. Benchmarks [here](https://github.com/pytorch/ao/tree/main/torchao/prototype/low_bit_optim) We also have support for [single GPU CPU offloading](https://github.com/pytorch/ao/tree/main/torchao/prototype/low_bit_optim#optimizer-cpu-offload) where both the gradients (same size as weights) and the optimizers will be efficiently sent to the CPU. This alone can **reduce your VRAM requirements by 60%** @@ -114,7 +116,7 @@ optim.load_state_dict(ckpt["optim"]) ## Composability 1. `torch.compile`: A key design principle for us is composability as in any new dtype or layout we provide needs to work with our compiler. It shouldn't matter if the kernels are written in pure PyTorch, CUDA, C++, or Triton - things should just work! So we write the dtype, layout, or bit packing logic in pure PyTorch and code-generate efficient kernels. -3. [FSDP2](https://github.com/pytorch/torchtitan/blob/main/docs/fsdp.md): Historically most quantization has been done for inference, there is now a thriving area of research combining distributed algorithms and quantization. +3. [FSDP2](https://github.com/pytorch/torchtitan/blob/main/docs/fsdp.md): Historically most quantization has been done for inference, there is now a thriving area of research combining distributed algorithms and quantization. The best example we have combining the composability of lower bit dtype with compile and fsdp is [NF4](torchao/dtypes/nf4tensor.py) which we used to implement the [QLoRA](https://www.youtube.com/watch?v=UvRl4ansfCg) algorithm. So if you're doing research at the intersection of this area we'd love to hear from you. @@ -135,7 +137,7 @@ Things we're excited about but need more time to cook in the oven 1. [MX](torchao/prototype/mx_formats) training and inference support with tensors using the [OCP MX spec](https://www.opencompute.org/documents/ocp-microscaling-formats-mx-v1-0-spec-final-pdf) data types, which can be described as groupwise scaled float8/float6/float4/int8, with the scales being constrained to powers of two. This work is prototype as the hardware support is not available yet. 2. [Int8 Quantized Training](https://github.com/pytorch/ao/tree/main/torchao/prototype/quantized_training): We're trying out full int8 training. This is easy to use with `quantize_(model, int8_weight_only_quantized_training())`. This work is prototype as the memory benchmarks are not compelling yet. -3. 
[IntX](https://github.com/pytorch/ao/tree/main/torchao/dtypes/uintx): We've managed to support all the ints by doing some clever bitpacking in pure PyTorch and then compiling it. This work is prototype as unfortunately without some more investment in either the compiler or low-bit kernels, int4 is more compelling than any smaller dtype +3. [IntX](https://github.com/pytorch/ao/tree/main/torchao/dtypes/uintx): We've managed to support all the ints by doing some clever bitpacking in pure PyTorch and then compiling it. This work is prototype as unfortunately without some more investment in either the compiler or low-bit kernels, int4 is more compelling than any smaller dtype 4. [Bitnet](https://github.com/pytorch/ao/blob/main/torchao/prototype/dtypes/bitnet.py): Mostly this is very cool to people on the team. This is prototype because how useful these kernels are is highly dependent on better hardware and kernel support. ## Installation @@ -169,7 +171,7 @@ USE_CPP=0 pip install -e . We're also fortunate to be integrated into some of the leading open-source libraries including 1. Hugging Face transformers with a [builtin inference backend](https://huggingface.co/docs/transformers/main/quantization/torchao) and [low bit optimizers](https://github.com/huggingface/transformers/pull/31865) 2. Hugging Face diffusers with a minimal example thanks to [Sayak Paul](https://www.linkedin.com/posts/sayak-paul_want-to-combine-quantization-and-benefit-activity-7231950868605022208-g52d?utm_source=share&utm_medium=member_desktop) -3. Mobius HQQ backend leveraged our int4 kernels to get [195 tok/s on a 4090](https://github.com/mobiusml/hqq#faster-inference) +3. Mobius HQQ backend leveraged our int4 kernels to get [195 tok/s on a 4090](https://github.com/mobiusml/hqq#faster-inference) ## Videos * [Slaying OOMs at the Mastering LLM's course](https://www.youtube.com/watch?v=UvRl4ansfCg) diff --git a/torchao/_models/llama/README.md b/torchao/_models/llama/README.md index 93acc3b22a..cfd9c353ed 100644 --- a/torchao/_models/llama/README.md +++ b/torchao/_models/llama/README.md @@ -2,7 +2,7 @@ The llama folder contains code/scripts for stable benchmarking llama models. -To get model weights, go to https://huggingface.co/meta-llama/Llama-2-7b and/or https://huggingface.co/meta-llama/Meta-Llama-3-8B +To get model weights, go to https://huggingface.co/meta-llama/Llama-2-7b, https://huggingface.co/meta-llama/Meta-Llama-3-8B, https://huggingface.co/meta-llama/Meta-Llama-3.1-8B and follow the steps to gain access. 
Then from the torchao root directory use `huggingface-cli login` and follow the steps to login, then `sh ./scripts/prepare.sh` to diff --git a/torchao/_models/llama/benchmark_results.txt b/torchao/_models/llama/benchmark_results.txt index 9d8ec434d4..3bea35cc49 100644 --- a/torchao/_models/llama/benchmark_results.txt +++ b/torchao/_models/llama/benchmark_results.txt @@ -1,24 +1,26 @@ llama 2 -20240619101342, tok/s= 29.85, mem/s= 788.87 GB/s, peak_mem=27.23 GB, model_size=26.43 GB quant: None, mod: Llama-2-7b-chat-hf, compile: False, compile_prefill: False, dtype: torch.float32, device: cuda repro: python generate.py --checkpoint_path ../../../checkpoints/meta-llama/Llama-2-7b-chat-hf/model.pth --device cuda --precision torch.float32 --num_samples 5 --max_new_tokens 200 --top_k 200 --temperature 0.8 -20240619101537, tok/s= 26.38, mem/s= 348.57 GB/s, peak_mem=13.62 GB, model_size=13.21 GB quant: None, mod: Llama-2-7b-chat-hf, compile: False, compile_prefill: False, dtype: torch.bfloat16, device: cuda repro: python generate.py --checkpoint_path ../../../checkpoints/meta-llama/Llama-2-7b-chat-hf/model.pth --device cuda --precision torch.bfloat16 --num_samples 5 --max_new_tokens 200 --top_k 200 --temperature 0.8 -20240619105331, tok/s=106.55, mem/s=1408.06 GB/s, peak_mem=13.88 GB, model_size=13.21 GB quant: None, mod: Llama-2-7b-chat-hf, compile: True, compile_prefill: True, dtype: torch.bfloat16, device: cuda repro: python generate.py --checkpoint_path ../../../checkpoints/meta-llama/Llama-2-7b-chat-hf/model.pth --device cuda --precision torch.bfloat16 --compile --compile_prefill --num_samples 5 --max_new_tokens 200 --top_k 200 --temperature 0.8 +20240831224311, tok/s= 26.75, mem/s= 707.01 GB/s, peak_mem=27.23 GB, model_size=26.43 GB quant: None, mod: Llama-2-7b-chat-hf, kv_quant: False, compile: False, compile_prefill: False, dtype: torch.float32, device: cuda repro: python generate.py --checkpoint_path ../../../checkpoints/meta-llama/Llama-2-7b-chat-hf/model.pth --device cuda --precision torch.float32 --num_samples 5 --max_new_tokens 200 --top_k 200 --temperature 0.8 +20240831224512, tok/s= 22.97, mem/s= 303.53 GB/s, peak_mem=13.64 GB, model_size=13.21 GB quant: None, mod: Llama-2-7b-chat-hf, kv_quant: False, compile: False, compile_prefill: False, dtype: torch.bfloat16, device: cuda repro: python generate.py --checkpoint_path ../../../checkpoints/meta-llama/Llama-2-7b-chat-hf/model.pth --device cuda --precision torch.bfloat16 --num_samples 5 --max_new_tokens 200 --top_k 200 --temperature 0.8 +20240831224958, tok/s=108.48, mem/s=1433.57 GB/s, peak_mem=13.90 GB, model_size=13.21 GB quant: None, mod: Llama-2-7b-chat-hf, kv_quant: False, compile: True, compile_prefill: True, dtype: torch.bfloat16, device: cuda repro: python generate.py --checkpoint_path ../../../checkpoints/meta-llama/Llama-2-7b-chat-hf/model.pth --device cuda --precision torch.bfloat16 --compile --compile_prefill --num_samples 5 --max_new_tokens 200 --top_k 200 --temperature 0.8 -20240619105522, tok/s=105.14, mem/s=1389.35 GB/s, peak_mem=13.88 GB, model_size=13.21 GB quant: None, mod: Llama-2-7b-chat-hf, compile: True, compile_prefill: False, dtype: torch.bfloat16, device: cuda repro: python generate.py --checkpoint_path ../../../checkpoints/meta-llama/Llama-2-7b-chat-hf/model.pth --device cuda --precision torch.bfloat16 --compile --num_samples 5 --max_new_tokens 200 --top_k 200 --temperature 0.8 -20240619105921, tok/s= 9.20, mem/s= 60.93 GB/s, peak_mem= 8.33 GB, model_size= 6.62 GB quant: int8dq, mod: Llama-2-7b-chat-hf, 
compile: True, compile_prefill: False, dtype: torch.bfloat16, device: cuda repro: python generate.py --quantization int8dq --checkpoint_path ../../../checkpoints/meta-llama/Llama-2-7b-chat-hf/model.pth --device cuda --precision torch.bfloat16 --compile --num_samples 5 --max_new_tokens 200 --top_k 200 --temperature 0.8 -20240619110107, tok/s=150.18, mem/s= 994.40 GB/s, peak_mem= 8.95 GB, model_size= 6.62 GB quant: int8wo, mod: Llama-2-7b-chat-hf, compile: True, compile_prefill: False, dtype: torch.bfloat16, device: cuda repro: python generate.py --quantization int8wo --checkpoint_path ../../../checkpoints/meta-llama/Llama-2-7b-chat-hf/model.pth --device cuda --precision torch.bfloat16 --compile --num_samples 5 --max_new_tokens 200 --top_k 200 --temperature 0.8 -20240619110248, tok/s=199.86, mem/s= 746.66 GB/s, peak_mem= 4.50 GB, model_size= 3.74 GB quant: int4wo-64, mod: Llama-2-7b-chat-hf, compile: True, compile_prefill: False, dtype: torch.bfloat16, device: cuda repro: python generate.py --quantization int4wo-64 --checkpoint_path ../../../checkpoints/meta-llama/Llama-2-7b-chat-hf/model.pth --device cuda --precision torch.bfloat16 --compile --num_samples 5 --max_new_tokens 200 --top_k 200 --temperature 0.8 -20240619114518, tok/s=159.22, mem/s=1069.87 GB/s, peak_mem= 8.91 GB, model_size= 6.72 GB quant: autoquant, mod: Llama-2-7b-chat-hf, compile: True, compile_prefill: True, dtype: torch.bfloat16, device: cuda repro: python generate.py --quantization autoquant --checkpoint_path ../../../checkpoints/meta-llama/Llama-2-7b-chat-hf/model.pth --device cuda --precision torch.bfloat16 --compile --compile_prefill --num_samples 5 --max_new_tokens 200 --top_k 200 --temperature 0.8 +20240831225155, tok/s=107.38, mem/s=1418.93 GB/s, peak_mem=13.88 GB, model_size=13.21 GB quant: None, mod: Llama-2-7b-chat-hf, kv_quant: False, compile: True, compile_prefill: False, dtype: torch.bfloat16, device: cuda repro: python generate.py --checkpoint_path ../../../checkpoints/meta-llama/Llama-2-7b-chat-hf/model.pth --device cuda --precision torch.bfloat16 --compile --num_samples 5 --max_new_tokens 200 --top_k 200 --temperature 0.8 +20240831225810, tok/s= 9.61, mem/s= 63.67 GB/s, peak_mem= 8.61 GB, model_size= 6.62 GB quant: int8dq, mod: Llama-2-7b-chat-hf, kv_quant: False, compile: True, compile_prefill: False, dtype: torch.bfloat16, device: cuda repro: python generate.py --quantization int8dq --checkpoint_path ../../../checkpoints/meta-llama/Llama-2-7b-chat-hf/model.pth --device cuda --precision torch.bfloat16 --compile --num_samples 5 --max_new_tokens 200 --top_k 200 --temperature 0.8 +20240831230013, tok/s=170.83, mem/s=1131.18 GB/s, peak_mem= 8.95 GB, model_size= 6.62 GB quant: int8wo, mod: Llama-2-7b-chat-hf, kv_quant: False, compile: True, compile_prefill: False, dtype: torch.bfloat16, device: cuda repro: python generate.py --quantization int8wo --checkpoint_path ../../../checkpoints/meta-llama/Llama-2-7b-chat-hf/model.pth --device cuda --precision torch.bfloat16 --compile --num_samples 5 --max_new_tokens 200 --top_k 200 --temperature 0.8 +20240831230205, tok/s=201.14, mem/s= 751.42 GB/s, peak_mem= 4.87 GB, model_size= 3.74 GB quant: int4wo-64, mod: Llama-2-7b-chat-hf, kv_quant: False, compile: True, compile_prefill: False, dtype: torch.bfloat16, device: cuda repro: python generate.py --quantization int4wo-64 --checkpoint_path ../../../checkpoints/meta-llama/Llama-2-7b-chat-hf/model.pth --device cuda --precision torch.bfloat16 --compile --num_samples 5 --max_new_tokens 200 --top_k 200 --temperature 0.8 
+20240831230736, tok/s=177.45, mem/s=1194.35 GB/s, peak_mem= 8.64 GB, model_size= 6.73 GB quant: autoquant, mod: Llama-2-7b-chat-hf, kv_quant: False, compile: True, compile_prefill: True, dtype: torch.bfloat16, device: cuda repro: python generate.py --quantization autoquant --checkpoint_path ../../../checkpoints/meta-llama/Llama-2-7b-chat-hf/model.pth --device cuda --precision torch.bfloat16 --compile --compile_prefill --num_samples 5 --max_new_tokens 200 --top_k 200 --temperature 0.8 +20240902100527, tok/s=209.19, mem/s= 804.32 GB/s, peak_mem= 4.89 GB, model_size= 3.84 GB quant: autoquant-int4, mod: Llama-2-7b-chat-hf, kv_quant: False, compile: True, compile_prefill: True, dtype: torch.bfloat16, device: cuda repro: python generate.py --quantization autoquant-int4 --checkpoint_path ../../../checkpoints/meta-llama/Llama-2-7b-chat-hf/model.pth --device cuda --precision torch.bfloat16 --compile --compile_prefill --num_samples 5 --max_new_tokens 200 --top_k 200 --temperature 0.8 llama 3 -20240619114732, tok/s= 30.46, mem/s= 914.43 GB/s, peak_mem=32.34 GB, model_size=30.02 GB quant: None, mod: Meta-Llama-3-8B, compile: False, compile_prefill: False, dtype: torch.float32, device: cuda repro: python generate.py --checkpoint_path ../../../checkpoints/meta-llama/Meta-Llama-3-8B/model.pth --device cuda --precision torch.float32 --num_samples 5 --max_new_tokens 200 --top_k 200 --temperature 0.8 -20240619114939, tok/s= 26.56, mem/s= 398.65 GB/s, peak_mem=16.16 GB, model_size=15.01 GB quant: None, mod: Meta-Llama-3-8B, compile: False, compile_prefill: False, dtype: torch.bfloat16, device: cuda repro: python generate.py --checkpoint_path ../../../checkpoints/meta-llama/Meta-Llama-3-8B/model.pth --device cuda --precision torch.bfloat16 --num_samples 5 --max_new_tokens 200 --top_k 200 --temperature 0.8 -20240619122811, tok/s= 96.09, mem/s=1442.32 GB/s, peak_mem=16.43 GB, model_size=15.01 GB quant: None, mod: Meta-Llama-3-8B, compile: True, compile_prefill: True, dtype: torch.bfloat16, device: cuda repro: python generate.py --checkpoint_path ../../../checkpoints/meta-llama/Meta-Llama-3-8B/model.pth --device cuda --precision torch.bfloat16 --compile --compile_prefill --num_samples 5 --max_new_tokens 200 --top_k 200 --temperature 0.8 +20240831231514, tok/s= 26.54, mem/s= 796.59 GB/s, peak_mem=32.34 GB, model_size=30.02 GB quant: None, mod: Meta-Llama-3-8B, kv_quant: False, compile: False, compile_prefill: False, dtype: torch.float32, device: cuda repro: python generate.py --checkpoint_path ../../../checkpoints/meta-llama/Meta-Llama-3-8B/model.pth --device cuda --precision torch.float32 --num_samples 5 --max_new_tokens 200 --top_k 200 --temperature 0.8 +20240831231725, tok/s= 23.67, mem/s= 355.33 GB/s, peak_mem=16.19 GB, model_size=15.01 GB quant: None, mod: Meta-Llama-3-8B, kv_quant: False, compile: False, compile_prefill: False, dtype: torch.bfloat16, device: cuda repro: python generate.py --checkpoint_path ../../../checkpoints/meta-llama/Meta-Llama-3-8B/model.pth --device cuda --precision torch.bfloat16 --num_samples 5 --max_new_tokens 200 --top_k 200 --temperature 0.8 +20240831232327, tok/s= 96.59, mem/s=1449.85 GB/s, peak_mem=16.43 GB, model_size=15.01 GB quant: None, mod: Meta-Llama-3-8B, kv_quant: False, compile: True, compile_prefill: True, dtype: torch.bfloat16, device: cuda repro: python generate.py --checkpoint_path ../../../checkpoints/meta-llama/Meta-Llama-3-8B/model.pth --device cuda --precision torch.bfloat16 --compile --compile_prefill --num_samples 5 --max_new_tokens 200 --top_k 200 
--temperature 0.8 -20240619123018, tok/s= 94.97, mem/s=1425.55 GB/s, peak_mem=16.43 GB, model_size=15.01 GB quant: None, mod: Meta-Llama-3-8B, compile: True, compile_prefill: False, dtype: torch.bfloat16, device: cuda repro: python generate.py --checkpoint_path ../../../checkpoints/meta-llama/Meta-Llama-3-8B/model.pth --device cuda --precision torch.bfloat16 --compile --num_samples 5 --max_new_tokens 200 --top_k 200 --temperature 0.8 -20240619123441, tok/s= 8.44, mem/s= 63.45 GB/s, peak_mem= 8.98 GB, model_size= 7.52 GB quant: int8dq, mod: Meta-Llama-3-8B, compile: True, compile_prefill: False, dtype: torch.bfloat16, device: cuda repro: python generate.py --quantization int8dq --checkpoint_path ../../../checkpoints/meta-llama/Meta-Llama-3-8B/model.pth --device cuda --precision torch.bfloat16 --compile --num_samples 5 --max_new_tokens 200 --top_k 200 --temperature 0.8 -20240619123652, tok/s=139.76, mem/s=1051.02 GB/s, peak_mem=10.42 GB, model_size= 7.52 GB quant: int8wo, mod: Meta-Llama-3-8B, compile: True, compile_prefill: False, dtype: torch.bfloat16, device: cuda repro: python generate.py --quantization int8wo --checkpoint_path ../../../checkpoints/meta-llama/Meta-Llama-3-8B/model.pth --device cuda --precision torch.bfloat16 --compile --num_samples 5 --max_new_tokens 200 --top_k 200 --temperature 0.8 -20240619123847, tok/s=179.44, mem/s= 757.60 GB/s, peak_mem= 6.62 GB, model_size= 4.22 GB quant: int4wo-64, mod: Meta-Llama-3-8B, compile: True, compile_prefill: False, dtype: torch.bfloat16, device: cuda repro: python generate.py --quantization int4wo-64 --checkpoint_path ../../../checkpoints/meta-llama/Meta-Llama-3-8B/model.pth --device cuda --precision torch.bfloat16 --compile --num_samples 5 --max_new_tokens 200 --top_k 200 --temperature 0.8 -20240619131959, tok/s=137.71, mem/s=1037.74 GB/s, peak_mem=11.08 GB, model_size= 7.54 GB quant: autoquant, mod: Meta-Llama-3-8B, compile: True, compile_prefill: True, dtype: torch.bfloat16, device: cuda repro: python generate.py --quantization autoquant --checkpoint_path ../../../checkpoints/meta-llama/Meta-Llama-3-8B/model.pth --device cuda --precision torch.bfloat16 --compile --compile_prefill --num_samples 5 --max_new_tokens 200 --top_k 200 --temperature 0.8 +20240831232535, tok/s= 95.64, mem/s=1435.54 GB/s, peak_mem=16.43 GB, model_size=15.01 GB quant: None, mod: Meta-Llama-3-8B, kv_quant: False, compile: True, compile_prefill: False, dtype: torch.bfloat16, device: cuda repro: python generate.py --checkpoint_path ../../../checkpoints/meta-llama/Meta-Llama-3-8B/model.pth --device cuda --precision torch.bfloat16 --compile --num_samples 5 --max_new_tokens 200 --top_k 200 --temperature 0.8 +20240831233224, tok/s= 8.61, mem/s= 64.75 GB/s, peak_mem= 9.24 GB, model_size= 7.52 GB quant: int8dq, mod: Meta-Llama-3-8B, kv_quant: False, compile: True, compile_prefill: False, dtype: torch.bfloat16, device: cuda repro: python generate.py --quantization int8dq --checkpoint_path ../../../checkpoints/meta-llama/Meta-Llama-3-8B/model.pth --device cuda --precision torch.bfloat16 --compile --num_samples 5 --max_new_tokens 200 --top_k 200 --temperature 0.8 +20240831233853, tok/s=153.03, mem/s=1150.80 GB/s, peak_mem=10.42 GB, model_size= 7.52 GB quant: int8wo, mod: Meta-Llama-3-8B, kv_quant: False, compile: True, compile_prefill: False, dtype: torch.bfloat16, device: cuda repro: python generate.py --quantization int8wo --checkpoint_path ../../../checkpoints/meta-llama/Meta-Llama-3-8B/model.pth --device cuda --precision torch.bfloat16 --compile --num_samples 5 
--max_new_tokens 200 --top_k 200 --temperature 0.8 +20240831234218, tok/s=180.80, mem/s= 763.33 GB/s, peak_mem= 6.88 GB, model_size= 4.22 GB quant: int4wo-64, mod: Meta-Llama-3-8B, kv_quant: False, compile: True, compile_prefill: False, dtype: torch.bfloat16, device: cuda repro: python generate.py --quantization int4wo-64 --checkpoint_path ../../../checkpoints/meta-llama/Meta-Llama-3-8B/model.pth --device cuda --precision torch.bfloat16 --compile --num_samples 5 --max_new_tokens 200 --top_k 200 --temperature 0.8 +20240831235355, tok/s=158.10, mem/s=1193.24 GB/s, peak_mem=10.04 GB, model_size= 7.55 GB quant: autoquant, mod: Meta-Llama-3-8B, kv_quant: False, compile: True, compile_prefill: True, dtype: torch.bfloat16, device: cuda repro: python generate.py --quantization autoquant --checkpoint_path ../../../checkpoints/meta-llama/Meta-Llama-3-8B/model.pth --device cuda --precision torch.bfloat16 --compile --compile_prefill --num_samples 5 --max_new_tokens 200 --top_k 200 --temperature 0.8 +20240902101015, tok/s=188.41, mem/s= 800.58 GB/s, peak_mem= 7.14 GB, model_size= 4.25 GB quant: autoquant-int4, mod: Meta-Llama-3-8B, kv_quant: False, compile: True, compile_prefill: True, dtype: torch.bfloat16, device: cuda repro: python generate.py --quantization autoquant-int4 --checkpoint_path ../../../checkpoints/meta-llama/Meta-Llama-3-8B/model.pth --device cuda --precision torch.bfloat16 --compile --compile_prefill --num_samples 5 --max_new_tokens 200 --top_k 200 --temperature 0.8 kv cache quantization: 20240826161508, tok/s= 19.71, mem/s= 295.80 GB/s, peak_mem=17.86 GB, model_size=15.01 GB quant: None, mod: Meta-Llama-3.1-8B, kv_quant: False, compile: False, compile_prefill: False, dtype: torch.bfloat16, device: cuda repro: python generate.py --checkpoint_path ../../../checkpoints/meta-llama/Meta-Llama-3.1-8B/model.pth --device cuda --precision torch.bfloat16 --num_samples 5 --max_new_tokens 200 --top_k 200 --temperature 0.8 --cache_size 8192 diff --git a/torchao/_models/llama/benchmarks.sh b/torchao/_models/llama/benchmarks.sh index b4d3bf5fe9..3024ebb3bc 100644 --- a/torchao/_models/llama/benchmarks.sh +++ b/torchao/_models/llama/benchmarks.sh @@ -11,6 +11,7 @@ python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --co python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --compile --quantization int8wo --write_result benchmark_results.txt python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --compile --quantization int4wo-64 --write_result benchmark_results.txt python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --compile --compile_prefill --quantization autoquant --write_result benchmark_results.txt +python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --compile --compile_prefill --quantization autoquant-int4 --write_result benchmark_results.txt export MODEL_REPO=meta-llama/Meta-Llama-3-8B python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --precision torch.float32 --write_result benchmark_results.txt @@ -22,6 +23,7 @@ python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --co python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --compile --quantization int8wo --write_result benchmark_results.txt python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --compile --quantization int4wo-64 --write_result benchmark_results.txt python generate.py --checkpoint_path 
$CHECKPOINT_PATH/$MODEL_REPO/model.pth --compile --compile_prefill --quantization autoquant --write_result benchmark_results.txt +python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --compile --compile_prefill --quantization autoquant-int4 --write_result benchmark_results.txt export MODEL_REPO=meta-llama/Meta-Llama-3.1-8B python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --write_result benchmark_results.txt --cache_size 8192 diff --git a/torchao/_models/llama/eval.py b/torchao/_models/llama/eval.py index fc8634dd06..b5df88aa05 100644 --- a/torchao/_models/llama/eval.py +++ b/torchao/_models/llama/eval.py @@ -65,9 +65,14 @@ def run_evaluation( if "int8dq" in quantization: quantize_(model, int8_dynamic_activation_int8_weight()) if "int4wo" in quantization and not "gptq" in quantization: + if "hqq" in quantization: + quantization = quantization[:-4] + use_hqq = True + else: + use_hqq = False groupsize=int(quantization.split("-")[-1]) assert groupsize in [32,64,128,256], f"int4wo groupsize needs to be one of [32,64,128,256] but got {groupsize}" - quantize_(model.to(device), int4_weight_only(group_size=groupsize)) + quantize_(model.to(device), int4_weight_only(group_size=groupsize, use_hqq=use_hqq)) if "int4wo" in quantization and "gptq" in quantization: groupsize=int(quantization.split("-")[-2]) assert groupsize in [32,64,128,256], f"int4wo groupsize needs to be one of [32,64,128,256] but got {groupsize}" @@ -114,7 +119,7 @@ def run_evaluation( parser.add_argument('--limit', type=int, default=None, help='Number of eval samples to evaluate') parser.add_argument('--precision', type=lambda x: getattr(torch, x.split(".")[-1]), default=torch.bfloat16, help='dtype precision to use') parser.add_argument('--device', type=str, default="cuda", help='Device to use for evaluation') - parser.add_argument("-q", "--quantization", type=str, help="Which quantization techniques to apply: int8dq, int8wo, int4wo-, int4wo--gptq") + parser.add_argument("-q", "--quantization", type=str, help="Which quantization techniques to apply: int8dq, int8wo, int4wo-, int4wo--gptq, int4wo--hqq") parser.add_argument('--compile', action='store_true', help='Whether to compile the model.') parser.add_argument('--max_length', type=int, default=None, help='Length of text to process at one time') parser.add_argument('--calibration_tasks', type=str, nargs='+', default=['wikitext'], help='tasks to do gptq calibration on, if doing gptq') diff --git a/torchao/_models/llama/generate.py b/torchao/_models/llama/generate.py index 20a2f401f7..eca38ac5a2 100644 --- a/torchao/_models/llama/generate.py +++ b/torchao/_models/llama/generate.py @@ -216,11 +216,19 @@ def main( if "int8dq" in quantization: quantize_(model, int8_dynamic_activation_int8_weight()) if "int4wo" in quantization: + if "hqq" in quantization: + use_hqq=True + quantization = quantization[:-4] + else: + use_hqq=False groupsize=int(quantization.split("-")[-1]) assert groupsize in [32,64,128,256], f"int4wo groupsize needs to be one of [32,64,128,256] but got {groupsize}" - quantize_(model, int4_weight_only(group_size=groupsize)) - if "autoquant" == quantization: - model = autoquant(model, manual=True) + quantize_(model, int4_weight_only(group_size=groupsize, use_hqq=use_hqq)) + if "autoquant" in quantization: + if "autoquant-int4" == quantization: + model = autoquant(model, manual=True, qtensor_class_list = torchao.quantization.DEFAULT_INT4_AUTOQUANT_CLASS_LIST) + else: + model = autoquant(model, manual=True) generate( model, @@ 
-385,7 +393,7 @@ def callback(x): parser.add_argument('--top_k', type=int, default=200, help='Top-k for sampling.') parser.add_argument('--temperature', type=float, default=0.8, help='Temperature for sampling.') parser.add_argument('--checkpoint_path', type=Path, default=Path("../../../checkpoints/meta-llama/Llama-2-7b-chat-hf/model.pth"), help='Model checkpoint path.') - parser.add_argument('-q', '--quantization', type=str, help='Which quantization techniques to apply: int8dq, int8wo, int4wo-, autoquant') + parser.add_argument('-q', '--quantization', type=str, help='Which quantization techniques to apply: int8dq, int8wo, int4wo-, autoquant, autoquant-int4, int4wo--hqq') parser.add_argument('--kv_cache_quantization', action='store_true', help='Whether to quantize the KV cache') parser.add_argument('--cache_size', type=int, default=None, help='Force size of cache to be a certain number of tokens, if not set, will use max_new_tokens+prompt_size') parser.add_argument('--linear_causal_mask', action='store_true', help='Whether to use the memory efficient, but slightly less fast, linear causal mask (important for long context lengths)') diff --git a/torchao/quantization/GPTQ.py b/torchao/quantization/GPTQ.py index ac7e097fa1..dfe33204a5 100644 --- a/torchao/quantization/GPTQ.py +++ b/torchao/quantization/GPTQ.py @@ -98,7 +98,7 @@ def __init__( self.groupsize = groupsize self.inputs = inputs self.gptq_done = False - self.debug = True + self.debug = False def configure_quantization_mode( self, @@ -790,14 +790,14 @@ def __init__( # TODO: this is the gpt-fast version, merge with the main version later def make_names_and_values_dict_func(q, qparams): - k = q.shape[1] + k = q.shape[1]*2 if not _check_linear_int4_k(k, groupsize): new_k = find_multiple(k, 1024) else: new_k = k # how much we need to pad the weight - delta_k = new_k - q.shape[1] - q = q.to(torch.int32).to(self.device) + delta_k = int((new_k - k)/2) + q = q.to(self.device) final_q = torch.ops.aten._convert_weight_to_int4pack(F.pad(q, pad=(0, delta_k)), inner_k_tiles) scales = qparams[0].to(torch.bfloat16).to(self.device) zeros = qparams[1].to(torch.bfloat16).to(self.device) diff --git a/torchao/quantization/README.md b/torchao/quantization/README.md index 494e1b961a..c7a937329e 100644 --- a/torchao/quantization/README.md +++ b/torchao/quantization/README.md @@ -7,17 +7,21 @@ Using the lm_eval. 
The models used were meta-llama/Llama-2-7b-chat-hf and meta-l
 
 | Model | Technique | wikitext-perplexity | Tokens/Second | Memory Bandwidth (GB/s) | Peak Memory (GB) | Model Size (GB) |
 | ----------- | ------------------ | ------------------- | ------------- | ----------------------- | ---------------- | --------------- |
-| Llama-2-7B | Base (bfloat16) | 12.212 | 105.14 | 1389.35 | 13.88 | 13.21 |
-| | int8dq | 12.262 | 9.20 | 60.93 | 8.33 | 6.62 |
-| | int8wo | 12.204 | 150.18 | 994.40 | 8.95 | 6.62 |
-| | int4wo-64 | 12.843 | 199.86 | 746.66 | 4.50 | 3.74 |
-| | int4wo-64-GPTQ | 12.489 | 199.86 | 746.66 | 4.50 | 3.74 |
-| | autoquant | 12.204 | 159.22 | 1069.87 | 8.91 | 6.72 |
-| Llama-3-8B | Base (bfloat16) | N/A | 94.97 | 1425.55 | 16.43 | 15.01 |
-| | int8dq | N/A | 8.44 | 63.45 | 8.98 | 7.52 |
-| | int8wo | N/A | 139.76 | 1051.02 | 10.42 | 7.52 |
-| | int4wo-64 | N/A | 179.44 | 757.60 | 6.62 | 4.22 |
-| | autoquant | N/A | 137.71 | 1037.74 | 11.08 | 7.54 |
+| Llama-2-7B | Base (bfloat16) | 12.212 | 107.38 | 1418.93 | 13.88 | 13.21 |
+| | int8dq | 12.262 | 9.61 | 63.67 | 8.61 | 6.62 |
+| | int8wo | 12.204 | 170.83 | 1131.18 | 8.95 | 6.62 |
+| | int4wo-64 | 12.843 | 201.14 | 751.42 | 4.87 | 3.74 |
+| | int4wo-64-GPTQ | 12.489 | 201.14 | 751.42 | 4.87 | 3.74 |
+| | autoquant | 12.204 | 177.45 | 1194.35 | 8.64 | 6.73 |
+| | autoquant-int4hqq | 12.825 | 209.19 | 804.32 | 4.89 | 3.84 |
+| Llama-3-8B | Base (bfloat16) | 7.441 | 95.64 | 1435.54 | 16.43 | 15.01 |
+| | int8dq | 7.581 | 8.61 | 64.75 | 9.24 | 7.52 |
+| | int8wo | 7.447 | 153.03 | 1150.80 | 10.42 | 7.52 |
+| | int4wo-64 | 8.316 | 180.80 | 763.33 | 6.88 | 4.22 |
+| | int4wo-64-GPTQ | N/A | 180.80 | 763.33 | 6.88 | 4.22 |
+| | autoquant | 7.447 | 158.10 | 1193.24 | 10.04 | 7.55 |
+| | autoquant-int4hqq | 8.110 | 188.41 | 800.58 | 7.14 | 4.25 |
 
 note: Int8 dynamic quantization works best on compute bound models like [SAM](https://github.com/pytorch-labs/segment-anything-fast) whereas Llama with batchsize=1 tends to be memory bound, thus the rather low performance.
@@ -28,7 +32,7 @@ And a quick crash course on inference quantization to help parse the above table
 ## Autoquantization
 
 The `autoquant` api can be used to quickly and accurately quantize your model. When used as in the example below, the api first identifies the shapes
-of the activations that the different linear layers see, it then benchmarks these shapes across different types of quantized and non-quantized layers in order to pick the fastest one, attempting to take into account fusions where possible. Finally once the best class is found for each layer, it swaps the linear. Currently this api chooses between no quantization, int8 dynamic quantization and int8 weight only quantization for each layer.
+of the activations that the different linear layers see; it then benchmarks these shapes across different types of quantized and non-quantized layers in order to pick the fastest one, attempting to take fusions into account where possible. Finally, once the best class is found for each layer, it swaps the linear. By default, this api currently chooses between no quantization, int8 dynamic quantization and int8 weight-only quantization for each layer.
 
 ```python
 import torch
 import torchao
@@ -46,6 +50,14 @@ model = torchao.autoquant(torch.compile(model, mode='max-autotune'))
 model(input)
 ```
+There is also an option to add int4 weight-only quantization to the set of options `autoquant` considers, for maximum performance or for cases where applying int4 quantization without `autoquant` causes a perf regression.
In such cases, `autoquant` will avoid quantizing the layers that are causing the perf regression. + +```python +from torchao.quantization import DEFAULT_INT4_AUTOQUANT_CLASS_LIST +model = torchao.autoquant(torch.compile(model, mode='max-autotune'), qtensor_class_list=torchao.quantization.DEFAULT_INT4_AUTOQUANT_CLASS_LIST) +model(input) +``` + Sometimes it is desirable to reuse a quantization plan that `autoquant` came up with. `torchao.quantization.AUTOQUANT_CACHE` is a dictionary holding autoquant's benchmark results. We can save it and restore it later, which will cause `autoquant` to choose the same quantization methods. ```python diff --git a/torchao/quantization/__init__.py b/torchao/quantization/__init__.py index 2ac4a0c285..84975f745e 100644 --- a/torchao/quantization/__init__.py +++ b/torchao/quantization/__init__.py @@ -21,6 +21,8 @@ "swap_conv2d_1x1_to_linear" "safe_int_mm", "autoquant", + "DEFAULT_AUTOQUANT_CLASS_LIST", + "DEFAULT_INT4_AUTOQUANT_CLASS_LIST", "get_scale", "SmoothFakeDynQuantMixin", "SmoothFakeDynamicallyQuantizedLinear", diff --git a/torchao/quantization/autoquant.py b/torchao/quantization/autoquant.py index cc51dd5ced..1adf286df3 100644 --- a/torchao/quantization/autoquant.py +++ b/torchao/quantization/autoquant.py @@ -9,7 +9,7 @@ Int8WeightOnlyQuantizedLinearWeight, QuantizedLinearWeightBase, ) -from torchao.dtypes import AffineQuantizedTensor, PlainLayoutType +from torchao.dtypes import AffineQuantizedTensor, PlainLayoutType, TensorCoreTiledLayoutType from torchao.quantization.linear_activation_quantized_tensor import LinearActivationQuantizedTensor from torch.utils._python_dispatch import return_and_correct_aliasing from .quant_primitives import ( @@ -23,6 +23,8 @@ __all__ = [ "AutoQuantizableLinearWeight", "autoquant", + "DEFAULT_AUTOQUANT_CLASS_LIST", + "DEFAULT_INT4_AUTOQUANT_CLASS_LIST", ] @@ -360,7 +362,7 @@ def _autoquant_test(cls, act_mat, weight, bias, best_time, mode=["relu", None]): print(f">>time: {res_f:0.3f}ms for {cls} interpolated, breakeven constant: {max_int_const_win:0.2f}") return res_f -class AQWeightOnlyQuantizedLinearWeight(AffineQuantizedTensor, AQMixin): +class AQInt8WeightOnlyQuantizedLinearWeight(AffineQuantizedTensor, AQMixin): """ AutoQuantizable version of Int8WeightOnlyQuantizedLinearWeight """ @@ -371,10 +373,10 @@ def from_float(cls, weight): eps = torch.finfo(torch.float32).eps zero_point_dtype = torch.int64 block_size = (1, weight.shape[1]) - return super(AQWeightOnlyQuantizedLinearWeight, cls).from_float(weight, mapping_type, block_size, target_dtype, eps=eps, zero_point_dtype=zero_point_dtype) + return super(AQInt8WeightOnlyQuantizedLinearWeight, cls).from_float(weight, mapping_type, block_size, target_dtype, eps=eps, zero_point_dtype=zero_point_dtype) -class AQWeightOnlyQuantizedLinearWeight2(AQWeightOnlyQuantizedLinearWeight, AQMixin): +class AQInt8WeightOnlyQuantizedLinearWeight2(AQInt8WeightOnlyQuantizedLinearWeight, AQMixin): """ AutoQuantizable version of Int8WeightOnlyQuantizedLinearWeight that uses a different kernel @@ -408,7 +410,7 @@ def _autoquant_test(cls, act_mat, *args): return torch.inf return super()._autoquant_test(act_mat, *args) -class AQWeightOnlyQuantizedLinearWeight3(AQWeightOnlyQuantizedLinearWeight, AQMixin): +class AQInt8WeightOnlyQuantizedLinearWeight3(AQInt8WeightOnlyQuantizedLinearWeight, AQMixin): """ AutoQuantizable version of Int8WeightOnlyQuantizedLinearWeight that uses a different kernel @@ -422,6 +424,40 @@ def _quantized_linear_op(act_mat, w_qtensor, bias): y += bias return y + +class 
AQInt4G32WeightOnlyQuantizedLinearWeight(AffineQuantizedTensor, AQMixin): + """ + AutoQuantizable version of Int4WeightOnlyQuantizedLinearWeight + """ + group_size: int = 32 + @classmethod + def from_float(cls, weight): + group_size = cls.group_size + layout_type = TensorCoreTiledLayoutType(inner_k_tiles=8) + + if weight.shape[-1] % group_size != 0: + return weight + use_hqq = True + mapping_type = MappingType.ASYMMETRIC + block_size = (1, group_size) + target_dtype = torch.int32 + quant_min = 0 + quant_max = 15 + eps = 1e-6 + preserve_zero = False + zero_point_dtype = torch.bfloat16 + zero_point_domain = ZeroPointDomain.FLOAT + return super(AQInt4G32WeightOnlyQuantizedLinearWeight, cls).from_float(weight, mapping_type, block_size, target_dtype, quant_min, quant_max, eps, zero_point_dtype=zero_point_dtype, preserve_zero=preserve_zero, zero_point_domain=zero_point_domain, layout_type=layout_type, use_hqq=use_hqq) + +class AQInt4G64WeightOnlyQuantizedLinearWeight(AQInt4G32WeightOnlyQuantizedLinearWeight): + group_size: int = 64 + +class AQInt4G128WeightOnlyQuantizedLinearWeight(AQInt4G32WeightOnlyQuantizedLinearWeight): + group_size: int = 128 + +class AQInt4G256WeightOnlyQuantizedLinearWeight(AQInt4G32WeightOnlyQuantizedLinearWeight): + group_size: int = 256 + class AQFloatLinearWeight(torch.Tensor, AQMixin): """ A class to be used in concert with AutoQuantizableLinearWeight to provide a @@ -441,15 +477,22 @@ def _quantized_linear_op(act_mat, w_qtensor, bias): def from_float(cls, weight): return weight -DEFAULT_CLASS_LIST = [ +# here we don't include int4 quantization in since int8 tends to be a better apples to apples comparison +DEFAULT_AUTOQUANT_CLASS_LIST = [ AQFloatLinearWeight, - AQWeightOnlyQuantizedLinearWeight, - AQWeightOnlyQuantizedLinearWeight2, - # AQWeightOnlyQuantizedLinearWeight3, + AQInt8WeightOnlyQuantizedLinearWeight, + AQInt8WeightOnlyQuantizedLinearWeight2, + # AQInt8WeightOnlyQuantizedLinearWeight3, # TODO this gets picked in places where it makes perf worse, why? AQInt8DynamicallyQuantizedLinearWeight, ] +DEFAULT_INT4_AUTOQUANT_CLASS_LIST = [ + AQFloatLinearWeight, + AQInt8DynamicallyQuantizedLinearWeight, + AQInt4G64WeightOnlyQuantizedLinearWeight +] + def _change_linears_to_autoquantizable(model, **kwargs): """ Converts all linear weight tensors to the @@ -459,7 +502,7 @@ def _change_linears_to_autoquantizable(model, **kwargs): from torchao.quantization.quant_api import _is_linear filter_fn = kwargs.pop("filter_fn", _is_linear) _ = kwargs.pop("error_on_unseen", True) # same kwargs used for this and to_quantized - kwargs["qtensor_class_list"] = kwargs.get("qtensor_class_list", DEFAULT_CLASS_LIST) + kwargs["qtensor_class_list"] = kwargs.get("qtensor_class_list", DEFAULT_AUTOQUANT_CLASS_LIST) kwargs["mode"] = kwargs.get("mode", ["relu", None]) from torchao.quantization.quant_api import _replace_with_custom_fn_if_matches_filter from torchao.quantization.quant_api import _get_subclass_inserter @@ -515,7 +558,7 @@ def _change_autoquantizable_to_quantized(model, supress_autoquant_errors=True, * def autoquant( model, example_input=None, - qtensor_class_list=DEFAULT_CLASS_LIST, + qtensor_class_list=DEFAULT_AUTOQUANT_CLASS_LIST, filter_fn=None, mode=["interpolate", .85], manual=False, @@ -547,7 +590,7 @@ def autoquant( model (torch.nn.Module): The model to be autoquantized. example_input (Any, optional): An example input for the model. If provided, the function performs a forward pass on this input (which fully autoquantizes the model unless manual=True). Defaults to None. 
- qtensor_class_list (list, optional): A list of tensor classes to be used for quantization. Defaults to DEFAULT_CLASS_LIST. + qtensor_class_list (list, optional): A list of tensor classes to be used for quantization. Defaults to DEFAULT_AUTOQUANT_CLASS_LIST. filter_fn (callable, optional): A filter function to apply to the model parameters. Defaults to None. mode (list, optional): A list containing mode settings for quantization. The first element is the mode type (e.g., "interpolate"), and the second element is the mode value (e.g., 0.85). Defaults to ["interpolate", .85]. diff --git a/torchao/quantization/quant_api.py b/torchao/quantization/quant_api.py index bb017fa35b..a1cecf1add 100644 --- a/torchao/quantization/quant_api.py +++ b/torchao/quantization/quant_api.py @@ -393,7 +393,7 @@ def insert_subclass(lin): return insert_subclass -def int4_weight_only(group_size=128, layout_type=TensorCoreTiledLayoutType(inner_k_tiles=8)): +def int4_weight_only(group_size=128, layout_type=TensorCoreTiledLayoutType(inner_k_tiles=8), use_hqq=False): """ Applies uint4 weight-only asymmetric per-group quantization to linear layers, using "tensor_core_tiled" layout for speedup with tinygemm kernel @@ -410,8 +410,9 @@ def int4_weight_only(group_size=128, layout_type=TensorCoreTiledLayoutType(inner `group_size`: parameter for quantization, controls the granularity of quantization, smaller size is more fine grained, choices are [256, 128, 64, 32] `layout_type`: layout type for quantized tensor, default is `TensorCoreTiledLayoutType(inner_k_tiles=8)` + `use_hqq`: whether to use hqq or default quantization mode, default is False """ - def apply_int4_weight_only_quant(weight, use_hqq=False): + def apply_int4_weight_only_quant(weight): if weight.shape[-1] % group_size != 0: return weight @@ -424,7 +425,7 @@ def apply_int4_weight_only_quant(weight, use_hqq=False): preserve_zero = False zero_point_dtype = torch.bfloat16 zero_point_domain = ZeroPointDomain.FLOAT - return to_affine_quantized(weight, mapping_type, block_size, target_dtype, quant_min, quant_max, eps, zero_point_dtype=zero_point_dtype, preserve_zero=preserve_zero, zero_point_domain=zero_point_domain, layout_type=layout_type) + return to_affine_quantized(weight, mapping_type, block_size, target_dtype, quant_min, quant_max, eps, zero_point_dtype=zero_point_dtype, preserve_zero=preserve_zero, zero_point_domain=zero_point_domain, layout_type=layout_type, use_hqq=use_hqq) return _get_linear_subclass_inserter(apply_int4_weight_only_quant) diff --git a/torchao/quantization/utils.py b/torchao/quantization/utils.py index 99ad0a4f6c..7002d1b55f 100644 --- a/torchao/quantization/utils.py +++ b/torchao/quantization/utils.py @@ -357,7 +357,7 @@ def groupwise_affine_quantize_tensor_from_qparams( quant_max = 2 ** n_bit - 1 int_data = quantize_affine(w, block_size, scales, zeros, output_dtype, quant_min, quant_max, zero_point_domain = ZeroPointDomain.FLOAT) - if TORCH_VERSION_AT_LEAST_2_5: + if TORCH_VERSION_AT_LEAST_2_5 and w.shape[-1] > 1: int_data_device_type = int_data.device.type # Move to cpu, until issue with MPS memory management of temporary tensors is resolved if int_data_device_type == 'mps': @@ -376,7 +376,8 @@ def groupwise_affine_dequantize_tensor_from_qparams( ): assert groupsize > 1 assert w_int4x8.dim() == 2 - if TORCH_VERSION_AT_LEAST_2_5: + # need to handle single column case so check for dtype/size from groupwise_affine_quantize_tensor_from_qparams path + if TORCH_VERSION_AT_LEAST_2_5 and (w_int4x8.dtype == torch.uint8 or w_int4x8.shape[-1]>1): 
         data = w_int4x8.to(torch.int32)
         high_bits = data >> 4
         low_bits = data & 0x0F
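
For reference, a minimal sketch of exercising the `use_hqq` flag that this patch adds to `int4_weight_only` (the toy model, sizes, and device below are illustrative assumptions, not part of the patch):

```python
import torch
from torchao.quantization.quant_api import quantize_, int4_weight_only

# toy bfloat16 model on CUDA; the int4 tensor-core-tiled path is typically used with bf16 weights
model = torch.nn.Sequential(torch.nn.Linear(1024, 1024)).to(device="cuda", dtype=torch.bfloat16)

# int4 weight-only quantization with group size 64, using hqq to compute the
# quantization parameters instead of the default scheme
quantize_(model, int4_weight_only(group_size=64, use_hqq=True))

out = model(torch.randn(1, 1024, device="cuda", dtype=torch.bfloat16))
```

This mirrors what `generate.py --quantization int4wo-64-hqq` and `eval.py --quantization int4wo-64-hqq` do internally after stripping the `-hqq` suffix and setting `use_hqq=True`.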