diff --git a/README.md b/README.md index 69b7c2a0a7..7257ff27d9 100644 --- a/README.md +++ b/README.md @@ -35,12 +35,12 @@ from torchao.quantization.quant_api import ( int4_weight_only, int8_weight_only ) -quantize_(m, int4_weight_only()) +quantize_(m, int4_weight_only()) ``` -For gpt-fast `int4_weight_only()` is the best option at bs=1 as it **2x the tok/s and reduces the VRAM requirements by about 65%** over a torch.compiled baseline +For gpt-fast `int4_weight_only()` is the best option at bs=1 as it **2x the tok/s and reduces the VRAM requirements by about 65%** over a torch.compiled baseline. -If you're unsure which option to use, you can also run [autoquant](./torchao/quantization/README.md#autoquantization) which will automatically profile layers for you and skip quantizing layers where overhead is too large. +If you see slowdowns with any of these techniques or you're unsure which option to use, consider using [autoquant](./torchao/quantization/README.md#autoquantization) which will automatically profile layers and pick the best way to quantize each layer. ```python model = torchao.autoquant(torch.compile(model, mode='max-autotune')) @@ -102,7 +102,7 @@ from torchao.prototype.low_bit_optim import AdamW8bit, AdamW4bit, AdamWFp8 optim = AdamW8bit(model.parameters()) # replace with Adam4bit and AdamFp8 for the 4 / fp8 versions ``` -In practice, we are a tiny bit slower than expertly written kernels but the implementations for these optimizers were written in a **few hundred lines of PyTorch code** and compiled so please use them or copy-paste them for your quantized optimizers. Benchmarks [here](https://github.com/pytorch/ao/tree/main/torchao/prototype/low_bit_optim) +In practice, we are a tiny bit slower than expertly written kernels but the implementations for these optimizers were written in a **few hundred lines of PyTorch code** and compiled so please use them or copy-paste them for your quantized optimizers. Benchmarks [here](https://github.com/pytorch/ao/tree/main/torchao/prototype/low_bit_optim) We also have support for [single GPU CPU offloading](https://github.com/pytorch/ao/tree/main/torchao/prototype/low_bit_optim#optimizer-cpu-offload) where both the gradients (same size as weights) and the optimizers will be efficiently sent to the CPU. This alone can **reduce your VRAM requirements by 60%** @@ -114,7 +114,7 @@ optim.load_state_dict(ckpt["optim"]) ## Composability 1. `torch.compile`: A key design principle for us is composability as in any new dtype or layout we provide needs to work with our compiler. It shouldn't matter if the kernels are written in pure PyTorch, CUDA, C++, or Triton - things should just work! So we write the dtype, layout, or bit packing logic in pure PyTorch and code-generate efficient kernels. -3. [FSDP2](https://github.com/pytorch/torchtitan/blob/main/docs/fsdp.md): Historically most quantization has been done for inference, there is now a thriving area of research combining distributed algorithms and quantization. +3. [FSDP2](https://github.com/pytorch/torchtitan/blob/main/docs/fsdp.md): Historically most quantization has been done for inference, there is now a thriving area of research combining distributed algorithms and quantization. The best example we have combining the composability of lower bit dtype with compile and fsdp is [NF4](torchao/dtypes/nf4tensor.py) which we used to implement the [QLoRA](https://www.youtube.com/watch?v=UvRl4ansfCg) algorithm. So if you're doing research at the intersection of this area we'd love to hear from you. @@ -135,7 +135,7 @@ Things we're excited about but need more time to cook in the oven 1. [MX](torchao/prototype/mx_formats) training and inference support with tensors using the [OCP MX spec](https://www.opencompute.org/documents/ocp-microscaling-formats-mx-v1-0-spec-final-pdf) data types, which can be described as groupwise scaled float8/float6/float4/int8, with the scales being constrained to powers of two. This work is prototype as the hardware support is not available yet. 2. [Int8 Quantized Training](https://github.com/pytorch/ao/tree/main/torchao/prototype/quantized_training): We're trying out full int8 training. This is easy to use with `quantize_(model, int8_weight_only_quantized_training())`. This work is prototype as the memory benchmarks are not compelling yet. -3. [IntX](https://github.com/pytorch/ao/tree/main/torchao/dtypes/uintx): We've managed to support all the ints by doing some clever bitpacking in pure PyTorch and then compiling it. This work is prototype as unfortunately without some more investment in either the compiler or low-bit kernels, int4 is more compelling than any smaller dtype +3. [IntX](https://github.com/pytorch/ao/tree/main/torchao/dtypes/uintx): We've managed to support all the ints by doing some clever bitpacking in pure PyTorch and then compiling it. This work is prototype as unfortunately without some more investment in either the compiler or low-bit kernels, int4 is more compelling than any smaller dtype 4. [Bitnet](https://github.com/pytorch/ao/blob/main/torchao/prototype/dtypes/bitnet.py): Mostly this is very cool to people on the team. This is prototype because how useful these kernels are is highly dependent on better hardware and kernel support. ## Installation @@ -169,7 +169,7 @@ USE_CPP=0 pip install -e . We're also fortunate to be integrated into some of the leading open-source libraries including 1. Hugging Face transformers with a [builtin inference backend](https://huggingface.co/docs/transformers/main/quantization/torchao) and [low bit optimizers](https://github.com/huggingface/transformers/pull/31865) 2. Hugging Face diffusers with a minimal example thanks to [Sayak Paul](https://www.linkedin.com/posts/sayak-paul_want-to-combine-quantization-and-benefit-activity-7231950868605022208-g52d?utm_source=share&utm_medium=member_desktop) -3. Mobius HQQ backend leveraged our int4 kernels to get [195 tok/s on a 4090](https://github.com/mobiusml/hqq#faster-inference) +3. Mobius HQQ backend leveraged our int4 kernels to get [195 tok/s on a 4090](https://github.com/mobiusml/hqq#faster-inference) ## Videos * [Slaying OOMs at the Mastering LLM's course](https://www.youtube.com/watch?v=UvRl4ansfCg) diff --git a/test/integration/test_integration.py b/test/integration/test_integration.py index c031d6e6d1..29f8ec604c 100644 --- a/test/integration/test_integration.py +++ b/test/integration/test_integration.py @@ -68,9 +68,9 @@ ) from torchao.quantization.autoquant import ( AQInt8DynamicallyQuantizedLinearWeight, - AQWeightOnlyQuantizedLinearWeight, - AQWeightOnlyQuantizedLinearWeight2, - AQWeightOnlyQuantizedLinearWeight3, + AQInt8WeightOnlyQuantizedLinearWeight, + AQInt8WeightOnlyQuantizedLinearWeight2, + AQInt8WeightOnlyQuantizedLinearWeight3, AutoQuantizableLinearWeight, ) @@ -727,21 +727,21 @@ def test_aq_int8_dynamic_quant_subclass(self, device, dtype): ) def test_aq_int8_weight_only_quant_subclass(self, device, dtype): self._test_lin_weight_subclass_impl( - AQWeightOnlyQuantizedLinearWeight.from_float, device, 35, test_dtype=dtype + AQInt8WeightOnlyQuantizedLinearWeight.from_float, device, 35, test_dtype=dtype ) @parameterized.expand(COMMON_DEVICE_DTYPE) @unittest.skipIf(not TORCH_VERSION_AT_LEAST_2_5, "autoquant+aqt needs newer pytorch") def test_aq_int8_weight_only_quant_2_subclass(self, device, dtype): self._test_lin_weight_subclass_impl( - AQWeightOnlyQuantizedLinearWeight2.from_float, device, 35, test_dtype=dtype + AQInt8WeightOnlyQuantizedLinearWeight2.from_float, device, 35, test_dtype=dtype ) @parameterized.expand(COMMON_DEVICE_DTYPE) @unittest.skipIf(not TORCH_VERSION_AT_LEAST_2_5, "autoquant+aqt needs newer pytorch") def test_aq_int8_weight_only_quant_3_subclass(self, device, dtype): self._test_lin_weight_subclass_impl( - AQWeightOnlyQuantizedLinearWeight3.from_float, device, 35, test_dtype=dtype + AQInt8WeightOnlyQuantizedLinearWeight3.from_float, device, 35, test_dtype=dtype ) @parameterized.expand(COMMON_DEVICE_DTYPE) @@ -1498,10 +1498,10 @@ def test_get_model_size_autoquant(self, device, dtype): size = torchao.utils.get_model_size_in_bytes(model) from torchao.quantization.autoquant import ( - AQWeightOnlyQuantizedLinearWeight2, + AQInt8WeightOnlyQuantizedLinearWeight2, ) qtensor_class_list = ( - AQWeightOnlyQuantizedLinearWeight2, + AQInt8WeightOnlyQuantizedLinearWeight2, ) mod = torchao.autoquant(torch.compile(model), qtensor_class_list = qtensor_class_list, set_inductor_config=False) mod(example_input) diff --git a/torchao/_models/llama/README.md b/torchao/_models/llama/README.md index 93acc3b22a..cfd9c353ed 100644 --- a/torchao/_models/llama/README.md +++ b/torchao/_models/llama/README.md @@ -2,7 +2,7 @@ The llama folder contains code/scripts for stable benchmarking llama models. -To get model weights, go to https://huggingface.co/meta-llama/Llama-2-7b and/or https://huggingface.co/meta-llama/Meta-Llama-3-8B +To get model weights, go to https://huggingface.co/meta-llama/Llama-2-7b, https://huggingface.co/meta-llama/Meta-Llama-3-8B, https://huggingface.co/meta-llama/Meta-Llama-3.1-8B and follow the steps to gain access. Then from the torchao root directory use `huggingface-cli login` and follow the steps to login, then `sh ./scripts/prepare.sh` to diff --git a/torchao/_models/llama/benchmark_results.txt b/torchao/_models/llama/benchmark_results.txt index 9d8ec434d4..3bea35cc49 100644 --- a/torchao/_models/llama/benchmark_results.txt +++ b/torchao/_models/llama/benchmark_results.txt @@ -1,24 +1,26 @@ llama 2 -20240619101342, tok/s= 29.85, mem/s= 788.87 GB/s, peak_mem=27.23 GB, model_size=26.43 GB quant: None, mod: Llama-2-7b-chat-hf, compile: False, compile_prefill: False, dtype: torch.float32, device: cuda repro: python generate.py --checkpoint_path ../../../checkpoints/meta-llama/Llama-2-7b-chat-hf/model.pth --device cuda --precision torch.float32 --num_samples 5 --max_new_tokens 200 --top_k 200 --temperature 0.8 -20240619101537, tok/s= 26.38, mem/s= 348.57 GB/s, peak_mem=13.62 GB, model_size=13.21 GB quant: None, mod: Llama-2-7b-chat-hf, compile: False, compile_prefill: False, dtype: torch.bfloat16, device: cuda repro: python generate.py --checkpoint_path ../../../checkpoints/meta-llama/Llama-2-7b-chat-hf/model.pth --device cuda --precision torch.bfloat16 --num_samples 5 --max_new_tokens 200 --top_k 200 --temperature 0.8 -20240619105331, tok/s=106.55, mem/s=1408.06 GB/s, peak_mem=13.88 GB, model_size=13.21 GB quant: None, mod: Llama-2-7b-chat-hf, compile: True, compile_prefill: True, dtype: torch.bfloat16, device: cuda repro: python generate.py --checkpoint_path ../../../checkpoints/meta-llama/Llama-2-7b-chat-hf/model.pth --device cuda --precision torch.bfloat16 --compile --compile_prefill --num_samples 5 --max_new_tokens 200 --top_k 200 --temperature 0.8 +20240831224311, tok/s= 26.75, mem/s= 707.01 GB/s, peak_mem=27.23 GB, model_size=26.43 GB quant: None, mod: Llama-2-7b-chat-hf, kv_quant: False, compile: False, compile_prefill: False, dtype: torch.float32, device: cuda repro: python generate.py --checkpoint_path ../../../checkpoints/meta-llama/Llama-2-7b-chat-hf/model.pth --device cuda --precision torch.float32 --num_samples 5 --max_new_tokens 200 --top_k 200 --temperature 0.8 +20240831224512, tok/s= 22.97, mem/s= 303.53 GB/s, peak_mem=13.64 GB, model_size=13.21 GB quant: None, mod: Llama-2-7b-chat-hf, kv_quant: False, compile: False, compile_prefill: False, dtype: torch.bfloat16, device: cuda repro: python generate.py --checkpoint_path ../../../checkpoints/meta-llama/Llama-2-7b-chat-hf/model.pth --device cuda --precision torch.bfloat16 --num_samples 5 --max_new_tokens 200 --top_k 200 --temperature 0.8 +20240831224958, tok/s=108.48, mem/s=1433.57 GB/s, peak_mem=13.90 GB, model_size=13.21 GB quant: None, mod: Llama-2-7b-chat-hf, kv_quant: False, compile: True, compile_prefill: True, dtype: torch.bfloat16, device: cuda repro: python generate.py --checkpoint_path ../../../checkpoints/meta-llama/Llama-2-7b-chat-hf/model.pth --device cuda --precision torch.bfloat16 --compile --compile_prefill --num_samples 5 --max_new_tokens 200 --top_k 200 --temperature 0.8 -20240619105522, tok/s=105.14, mem/s=1389.35 GB/s, peak_mem=13.88 GB, model_size=13.21 GB quant: None, mod: Llama-2-7b-chat-hf, compile: True, compile_prefill: False, dtype: torch.bfloat16, device: cuda repro: python generate.py --checkpoint_path ../../../checkpoints/meta-llama/Llama-2-7b-chat-hf/model.pth --device cuda --precision torch.bfloat16 --compile --num_samples 5 --max_new_tokens 200 --top_k 200 --temperature 0.8 -20240619105921, tok/s= 9.20, mem/s= 60.93 GB/s, peak_mem= 8.33 GB, model_size= 6.62 GB quant: int8dq, mod: Llama-2-7b-chat-hf, compile: True, compile_prefill: False, dtype: torch.bfloat16, device: cuda repro: python generate.py --quantization int8dq --checkpoint_path ../../../checkpoints/meta-llama/Llama-2-7b-chat-hf/model.pth --device cuda --precision torch.bfloat16 --compile --num_samples 5 --max_new_tokens 200 --top_k 200 --temperature 0.8 -20240619110107, tok/s=150.18, mem/s= 994.40 GB/s, peak_mem= 8.95 GB, model_size= 6.62 GB quant: int8wo, mod: Llama-2-7b-chat-hf, compile: True, compile_prefill: False, dtype: torch.bfloat16, device: cuda repro: python generate.py --quantization int8wo --checkpoint_path ../../../checkpoints/meta-llama/Llama-2-7b-chat-hf/model.pth --device cuda --precision torch.bfloat16 --compile --num_samples 5 --max_new_tokens 200 --top_k 200 --temperature 0.8 -20240619110248, tok/s=199.86, mem/s= 746.66 GB/s, peak_mem= 4.50 GB, model_size= 3.74 GB quant: int4wo-64, mod: Llama-2-7b-chat-hf, compile: True, compile_prefill: False, dtype: torch.bfloat16, device: cuda repro: python generate.py --quantization int4wo-64 --checkpoint_path ../../../checkpoints/meta-llama/Llama-2-7b-chat-hf/model.pth --device cuda --precision torch.bfloat16 --compile --num_samples 5 --max_new_tokens 200 --top_k 200 --temperature 0.8 -20240619114518, tok/s=159.22, mem/s=1069.87 GB/s, peak_mem= 8.91 GB, model_size= 6.72 GB quant: autoquant, mod: Llama-2-7b-chat-hf, compile: True, compile_prefill: True, dtype: torch.bfloat16, device: cuda repro: python generate.py --quantization autoquant --checkpoint_path ../../../checkpoints/meta-llama/Llama-2-7b-chat-hf/model.pth --device cuda --precision torch.bfloat16 --compile --compile_prefill --num_samples 5 --max_new_tokens 200 --top_k 200 --temperature 0.8 +20240831225155, tok/s=107.38, mem/s=1418.93 GB/s, peak_mem=13.88 GB, model_size=13.21 GB quant: None, mod: Llama-2-7b-chat-hf, kv_quant: False, compile: True, compile_prefill: False, dtype: torch.bfloat16, device: cuda repro: python generate.py --checkpoint_path ../../../checkpoints/meta-llama/Llama-2-7b-chat-hf/model.pth --device cuda --precision torch.bfloat16 --compile --num_samples 5 --max_new_tokens 200 --top_k 200 --temperature 0.8 +20240831225810, tok/s= 9.61, mem/s= 63.67 GB/s, peak_mem= 8.61 GB, model_size= 6.62 GB quant: int8dq, mod: Llama-2-7b-chat-hf, kv_quant: False, compile: True, compile_prefill: False, dtype: torch.bfloat16, device: cuda repro: python generate.py --quantization int8dq --checkpoint_path ../../../checkpoints/meta-llama/Llama-2-7b-chat-hf/model.pth --device cuda --precision torch.bfloat16 --compile --num_samples 5 --max_new_tokens 200 --top_k 200 --temperature 0.8 +20240831230013, tok/s=170.83, mem/s=1131.18 GB/s, peak_mem= 8.95 GB, model_size= 6.62 GB quant: int8wo, mod: Llama-2-7b-chat-hf, kv_quant: False, compile: True, compile_prefill: False, dtype: torch.bfloat16, device: cuda repro: python generate.py --quantization int8wo --checkpoint_path ../../../checkpoints/meta-llama/Llama-2-7b-chat-hf/model.pth --device cuda --precision torch.bfloat16 --compile --num_samples 5 --max_new_tokens 200 --top_k 200 --temperature 0.8 +20240831230205, tok/s=201.14, mem/s= 751.42 GB/s, peak_mem= 4.87 GB, model_size= 3.74 GB quant: int4wo-64, mod: Llama-2-7b-chat-hf, kv_quant: False, compile: True, compile_prefill: False, dtype: torch.bfloat16, device: cuda repro: python generate.py --quantization int4wo-64 --checkpoint_path ../../../checkpoints/meta-llama/Llama-2-7b-chat-hf/model.pth --device cuda --precision torch.bfloat16 --compile --num_samples 5 --max_new_tokens 200 --top_k 200 --temperature 0.8 +20240831230736, tok/s=177.45, mem/s=1194.35 GB/s, peak_mem= 8.64 GB, model_size= 6.73 GB quant: autoquant, mod: Llama-2-7b-chat-hf, kv_quant: False, compile: True, compile_prefill: True, dtype: torch.bfloat16, device: cuda repro: python generate.py --quantization autoquant --checkpoint_path ../../../checkpoints/meta-llama/Llama-2-7b-chat-hf/model.pth --device cuda --precision torch.bfloat16 --compile --compile_prefill --num_samples 5 --max_new_tokens 200 --top_k 200 --temperature 0.8 +20240902100527, tok/s=209.19, mem/s= 804.32 GB/s, peak_mem= 4.89 GB, model_size= 3.84 GB quant: autoquant-int4, mod: Llama-2-7b-chat-hf, kv_quant: False, compile: True, compile_prefill: True, dtype: torch.bfloat16, device: cuda repro: python generate.py --quantization autoquant-int4 --checkpoint_path ../../../checkpoints/meta-llama/Llama-2-7b-chat-hf/model.pth --device cuda --precision torch.bfloat16 --compile --compile_prefill --num_samples 5 --max_new_tokens 200 --top_k 200 --temperature 0.8 llama 3 -20240619114732, tok/s= 30.46, mem/s= 914.43 GB/s, peak_mem=32.34 GB, model_size=30.02 GB quant: None, mod: Meta-Llama-3-8B, compile: False, compile_prefill: False, dtype: torch.float32, device: cuda repro: python generate.py --checkpoint_path ../../../checkpoints/meta-llama/Meta-Llama-3-8B/model.pth --device cuda --precision torch.float32 --num_samples 5 --max_new_tokens 200 --top_k 200 --temperature 0.8 -20240619114939, tok/s= 26.56, mem/s= 398.65 GB/s, peak_mem=16.16 GB, model_size=15.01 GB quant: None, mod: Meta-Llama-3-8B, compile: False, compile_prefill: False, dtype: torch.bfloat16, device: cuda repro: python generate.py --checkpoint_path ../../../checkpoints/meta-llama/Meta-Llama-3-8B/model.pth --device cuda --precision torch.bfloat16 --num_samples 5 --max_new_tokens 200 --top_k 200 --temperature 0.8 -20240619122811, tok/s= 96.09, mem/s=1442.32 GB/s, peak_mem=16.43 GB, model_size=15.01 GB quant: None, mod: Meta-Llama-3-8B, compile: True, compile_prefill: True, dtype: torch.bfloat16, device: cuda repro: python generate.py --checkpoint_path ../../../checkpoints/meta-llama/Meta-Llama-3-8B/model.pth --device cuda --precision torch.bfloat16 --compile --compile_prefill --num_samples 5 --max_new_tokens 200 --top_k 200 --temperature 0.8 +20240831231514, tok/s= 26.54, mem/s= 796.59 GB/s, peak_mem=32.34 GB, model_size=30.02 GB quant: None, mod: Meta-Llama-3-8B, kv_quant: False, compile: False, compile_prefill: False, dtype: torch.float32, device: cuda repro: python generate.py --checkpoint_path ../../../checkpoints/meta-llama/Meta-Llama-3-8B/model.pth --device cuda --precision torch.float32 --num_samples 5 --max_new_tokens 200 --top_k 200 --temperature 0.8 +20240831231725, tok/s= 23.67, mem/s= 355.33 GB/s, peak_mem=16.19 GB, model_size=15.01 GB quant: None, mod: Meta-Llama-3-8B, kv_quant: False, compile: False, compile_prefill: False, dtype: torch.bfloat16, device: cuda repro: python generate.py --checkpoint_path ../../../checkpoints/meta-llama/Meta-Llama-3-8B/model.pth --device cuda --precision torch.bfloat16 --num_samples 5 --max_new_tokens 200 --top_k 200 --temperature 0.8 +20240831232327, tok/s= 96.59, mem/s=1449.85 GB/s, peak_mem=16.43 GB, model_size=15.01 GB quant: None, mod: Meta-Llama-3-8B, kv_quant: False, compile: True, compile_prefill: True, dtype: torch.bfloat16, device: cuda repro: python generate.py --checkpoint_path ../../../checkpoints/meta-llama/Meta-Llama-3-8B/model.pth --device cuda --precision torch.bfloat16 --compile --compile_prefill --num_samples 5 --max_new_tokens 200 --top_k 200 --temperature 0.8 -20240619123018, tok/s= 94.97, mem/s=1425.55 GB/s, peak_mem=16.43 GB, model_size=15.01 GB quant: None, mod: Meta-Llama-3-8B, compile: True, compile_prefill: False, dtype: torch.bfloat16, device: cuda repro: python generate.py --checkpoint_path ../../../checkpoints/meta-llama/Meta-Llama-3-8B/model.pth --device cuda --precision torch.bfloat16 --compile --num_samples 5 --max_new_tokens 200 --top_k 200 --temperature 0.8 -20240619123441, tok/s= 8.44, mem/s= 63.45 GB/s, peak_mem= 8.98 GB, model_size= 7.52 GB quant: int8dq, mod: Meta-Llama-3-8B, compile: True, compile_prefill: False, dtype: torch.bfloat16, device: cuda repro: python generate.py --quantization int8dq --checkpoint_path ../../../checkpoints/meta-llama/Meta-Llama-3-8B/model.pth --device cuda --precision torch.bfloat16 --compile --num_samples 5 --max_new_tokens 200 --top_k 200 --temperature 0.8 -20240619123652, tok/s=139.76, mem/s=1051.02 GB/s, peak_mem=10.42 GB, model_size= 7.52 GB quant: int8wo, mod: Meta-Llama-3-8B, compile: True, compile_prefill: False, dtype: torch.bfloat16, device: cuda repro: python generate.py --quantization int8wo --checkpoint_path ../../../checkpoints/meta-llama/Meta-Llama-3-8B/model.pth --device cuda --precision torch.bfloat16 --compile --num_samples 5 --max_new_tokens 200 --top_k 200 --temperature 0.8 -20240619123847, tok/s=179.44, mem/s= 757.60 GB/s, peak_mem= 6.62 GB, model_size= 4.22 GB quant: int4wo-64, mod: Meta-Llama-3-8B, compile: True, compile_prefill: False, dtype: torch.bfloat16, device: cuda repro: python generate.py --quantization int4wo-64 --checkpoint_path ../../../checkpoints/meta-llama/Meta-Llama-3-8B/model.pth --device cuda --precision torch.bfloat16 --compile --num_samples 5 --max_new_tokens 200 --top_k 200 --temperature 0.8 -20240619131959, tok/s=137.71, mem/s=1037.74 GB/s, peak_mem=11.08 GB, model_size= 7.54 GB quant: autoquant, mod: Meta-Llama-3-8B, compile: True, compile_prefill: True, dtype: torch.bfloat16, device: cuda repro: python generate.py --quantization autoquant --checkpoint_path ../../../checkpoints/meta-llama/Meta-Llama-3-8B/model.pth --device cuda --precision torch.bfloat16 --compile --compile_prefill --num_samples 5 --max_new_tokens 200 --top_k 200 --temperature 0.8 +20240831232535, tok/s= 95.64, mem/s=1435.54 GB/s, peak_mem=16.43 GB, model_size=15.01 GB quant: None, mod: Meta-Llama-3-8B, kv_quant: False, compile: True, compile_prefill: False, dtype: torch.bfloat16, device: cuda repro: python generate.py --checkpoint_path ../../../checkpoints/meta-llama/Meta-Llama-3-8B/model.pth --device cuda --precision torch.bfloat16 --compile --num_samples 5 --max_new_tokens 200 --top_k 200 --temperature 0.8 +20240831233224, tok/s= 8.61, mem/s= 64.75 GB/s, peak_mem= 9.24 GB, model_size= 7.52 GB quant: int8dq, mod: Meta-Llama-3-8B, kv_quant: False, compile: True, compile_prefill: False, dtype: torch.bfloat16, device: cuda repro: python generate.py --quantization int8dq --checkpoint_path ../../../checkpoints/meta-llama/Meta-Llama-3-8B/model.pth --device cuda --precision torch.bfloat16 --compile --num_samples 5 --max_new_tokens 200 --top_k 200 --temperature 0.8 +20240831233853, tok/s=153.03, mem/s=1150.80 GB/s, peak_mem=10.42 GB, model_size= 7.52 GB quant: int8wo, mod: Meta-Llama-3-8B, kv_quant: False, compile: True, compile_prefill: False, dtype: torch.bfloat16, device: cuda repro: python generate.py --quantization int8wo --checkpoint_path ../../../checkpoints/meta-llama/Meta-Llama-3-8B/model.pth --device cuda --precision torch.bfloat16 --compile --num_samples 5 --max_new_tokens 200 --top_k 200 --temperature 0.8 +20240831234218, tok/s=180.80, mem/s= 763.33 GB/s, peak_mem= 6.88 GB, model_size= 4.22 GB quant: int4wo-64, mod: Meta-Llama-3-8B, kv_quant: False, compile: True, compile_prefill: False, dtype: torch.bfloat16, device: cuda repro: python generate.py --quantization int4wo-64 --checkpoint_path ../../../checkpoints/meta-llama/Meta-Llama-3-8B/model.pth --device cuda --precision torch.bfloat16 --compile --num_samples 5 --max_new_tokens 200 --top_k 200 --temperature 0.8 +20240831235355, tok/s=158.10, mem/s=1193.24 GB/s, peak_mem=10.04 GB, model_size= 7.55 GB quant: autoquant, mod: Meta-Llama-3-8B, kv_quant: False, compile: True, compile_prefill: True, dtype: torch.bfloat16, device: cuda repro: python generate.py --quantization autoquant --checkpoint_path ../../../checkpoints/meta-llama/Meta-Llama-3-8B/model.pth --device cuda --precision torch.bfloat16 --compile --compile_prefill --num_samples 5 --max_new_tokens 200 --top_k 200 --temperature 0.8 +20240902101015, tok/s=188.41, mem/s= 800.58 GB/s, peak_mem= 7.14 GB, model_size= 4.25 GB quant: autoquant-int4, mod: Meta-Llama-3-8B, kv_quant: False, compile: True, compile_prefill: True, dtype: torch.bfloat16, device: cuda repro: python generate.py --quantization autoquant-int4 --checkpoint_path ../../../checkpoints/meta-llama/Meta-Llama-3-8B/model.pth --device cuda --precision torch.bfloat16 --compile --compile_prefill --num_samples 5 --max_new_tokens 200 --top_k 200 --temperature 0.8 kv cache quantization: 20240826161508, tok/s= 19.71, mem/s= 295.80 GB/s, peak_mem=17.86 GB, model_size=15.01 GB quant: None, mod: Meta-Llama-3.1-8B, kv_quant: False, compile: False, compile_prefill: False, dtype: torch.bfloat16, device: cuda repro: python generate.py --checkpoint_path ../../../checkpoints/meta-llama/Meta-Llama-3.1-8B/model.pth --device cuda --precision torch.bfloat16 --num_samples 5 --max_new_tokens 200 --top_k 200 --temperature 0.8 --cache_size 8192 diff --git a/torchao/_models/llama/benchmarks.sh b/torchao/_models/llama/benchmarks.sh index 9020715e70..c86406735a 100644 --- a/torchao/_models/llama/benchmarks.sh +++ b/torchao/_models/llama/benchmarks.sh @@ -11,6 +11,7 @@ python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --co python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --compile --quantization int8wo --write_result benchmark_results.txt python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --compile --quantization int4wo-64 --write_result benchmark_results.txt python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --compile --compile_prefill --quantization autoquant --write_result benchmark_results.txt +python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --compile --compile_prefill --quantization autoquant-int4 --write_result benchmark_results.txt # auto-round w/ quant_lm_head python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --compile --quantization autoround @@ -28,6 +29,7 @@ python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --co python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --compile --quantization int8wo --write_result benchmark_results.txt python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --compile --quantization int4wo-64 --write_result benchmark_results.txt python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --compile --compile_prefill --quantization autoquant --write_result benchmark_results.txt +python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --compile --compile_prefill --quantization autoquant-int4 --write_result benchmark_results.txt # auto-round w/ quant_lm_head python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --compile --quantization autoround diff --git a/torchao/_models/llama/eval.py b/torchao/_models/llama/eval.py index f673a966de..6522fd9757 100644 --- a/torchao/_models/llama/eval.py +++ b/torchao/_models/llama/eval.py @@ -68,12 +68,17 @@ def run_evaluation( quantize_(model, int8_weight_only()) if "int8dq" in quantization: quantize_(model, int8_dynamic_activation_int8_weight()) + if "fp6" in quantization: + quantize_(model, fpx_weight_only(3, 2)) if "int4wo" in quantization and not "gptq" in quantization: + if "hqq" in quantization: + quantization = quantization[:-4] + use_hqq = True + else: + use_hqq = False groupsize=int(quantization.split("-")[-1]) assert groupsize in [32,64,128,256], f"int4wo groupsize needs to be one of [32,64,128,256] but got {groupsize}" - quantize_(model.to(device), int4_weight_only(group_size=groupsize)) - if "fp6" in quantization: - quantize_(model, fpx_weight_only(3, 2)) + quantize_(model.to(device), int4_weight_only(group_size=groupsize, use_hqq=use_hqq)) if "int4wo" in quantization and "gptq" in quantization: groupsize=int(quantization.split("-")[-2]) assert groupsize in [32,64,128,256], f"int4wo groupsize needs to be one of [32,64,128,256] but got {groupsize}" @@ -120,7 +125,7 @@ def run_evaluation( parser.add_argument('--limit', type=int, default=None, help='Number of eval samples to evaluate') parser.add_argument('--precision', type=lambda x: getattr(torch, x.split(".")[-1]), default=torch.bfloat16, help='dtype precision to use') parser.add_argument('--device', type=str, default="cuda", help='Device to use for evaluation') - parser.add_argument("-q", "--quantization", type=str, help="Which quantization techniques to apply: int8dq, int8wo, int4wo-, int4wo--gptq") + parser.add_argument("-q", "--quantization", type=str, help="Which quantization techniques to apply: int8dq, int8wo, int4wo-, int4wo--gptq, int4wo--hqq") parser.add_argument('--compile', action='store_true', help='Whether to compile the model.') parser.add_argument('--max_length', type=int, default=None, help='Length of text to process at one time') parser.add_argument('--calibration_tasks', type=str, nargs='+', default=['wikitext'], help='tasks to do gptq calibration on, if doing gptq') diff --git a/torchao/_models/llama/generate.py b/torchao/_models/llama/generate.py index a559fc241c..089f247656 100644 --- a/torchao/_models/llama/generate.py +++ b/torchao/_models/llama/generate.py @@ -216,10 +216,14 @@ def main( if "int8dq" in quantization: quantize_(model, int8_dynamic_activation_int8_weight()) if "int4wo" in quantization: + if "hqq" in quantization: + use_hqq=True + quantization = quantization[:-4] + else: + use_hqq=False groupsize=int(quantization.split("-")[-1]) assert groupsize in [32,64,128,256], f"int4wo groupsize needs to be one of [32,64,128,256] but got {groupsize}" quantize_(model, int4_weight_only(group_size=groupsize)) - if "autoround" in quantization: from torchao.prototype.autoround.autoround_llm import quantize_model_with_autoround_ from transformers import AutoTokenizer @@ -265,11 +269,13 @@ def main( ) model.to(device) model.reset_caches() - if "fp6" in quantization: quantize_(model, fpx_weight_only(3, 2)) - if "autoquant" == quantization: - model = autoquant(model, manual=True) + if "autoquant" in quantization: + if "autoquant-int4" == quantization: + model = autoquant(model, manual=True, qtensor_class_list = torchao.quantization.DEFAULT_INT4_AUTOQUANT_CLASS_LIST) + else: + model = autoquant(model, manual=True) generate( model, @@ -434,7 +440,7 @@ def callback(x): parser.add_argument('--top_k', type=int, default=200, help='Top-k for sampling.') parser.add_argument('--temperature', type=float, default=0.8, help='Temperature for sampling.') parser.add_argument('--checkpoint_path', type=Path, default=Path("../../../checkpoints/meta-llama/Llama-2-7b-chat-hf/model.pth"), help='Model checkpoint path.') - parser.add_argument('-q', '--quantization', type=str, help='Which quantization techniques to apply: int8dq, int8wo, int4wo-, autoquant, autoround-------') + parser.add_argument('-q', '--quantization', type=str, help='Which quantization techniques to apply: int8dq, int8wo, int4wo-, autoquant, autoquant-int4, int4wo--hqq, autoround-------') parser.add_argument('--kv_cache_quantization', action='store_true', help='Whether to quantize the KV cache') parser.add_argument('--cache_size', type=int, default=None, help='Force size of cache to be a certain number of tokens, if not set, will use max_new_tokens+prompt_size') parser.add_argument('--linear_causal_mask', action='store_true', help='Whether to use the memory efficient, but slightly less fast, linear causal mask (important for long context lengths)') diff --git a/torchao/dtypes/fpx/fpx.py b/torchao/dtypes/fpx/fpx.py index c4f818ea12..6afa22f560 100644 --- a/torchao/dtypes/fpx/fpx.py +++ b/torchao/dtypes/fpx/fpx.py @@ -5,7 +5,6 @@ from torch import Tensor from torch.utils._python_dispatch import return_and_correct_aliasing from torchao.prototype.custom_fp_utils import _f32_to_fpx_unpacked, _fpx_unpacked_to_f32, _n_ones -from torchao.ops import quant_llm_linear from torchao.dtypes.utils import ( LayoutType, ) diff --git a/torchao/quantization/GPTQ.py b/torchao/quantization/GPTQ.py index ac7e097fa1..dfe33204a5 100644 --- a/torchao/quantization/GPTQ.py +++ b/torchao/quantization/GPTQ.py @@ -98,7 +98,7 @@ def __init__( self.groupsize = groupsize self.inputs = inputs self.gptq_done = False - self.debug = True + self.debug = False def configure_quantization_mode( self, @@ -790,14 +790,14 @@ def __init__( # TODO: this is the gpt-fast version, merge with the main version later def make_names_and_values_dict_func(q, qparams): - k = q.shape[1] + k = q.shape[1]*2 if not _check_linear_int4_k(k, groupsize): new_k = find_multiple(k, 1024) else: new_k = k # how much we need to pad the weight - delta_k = new_k - q.shape[1] - q = q.to(torch.int32).to(self.device) + delta_k = int((new_k - k)/2) + q = q.to(self.device) final_q = torch.ops.aten._convert_weight_to_int4pack(F.pad(q, pad=(0, delta_k)), inner_k_tiles) scales = qparams[0].to(torch.bfloat16).to(self.device) zeros = qparams[1].to(torch.bfloat16).to(self.device) diff --git a/torchao/quantization/README.md b/torchao/quantization/README.md index 9f6f032133..6112abd027 100644 --- a/torchao/quantization/README.md +++ b/torchao/quantization/README.md @@ -7,17 +7,19 @@ Using the lm_eval. The models used were meta-llama/Llama-2-7b-chat-hf and meta-l | Model | Technique | wikitext-perplexity | Tokens/Second | Memory Bandwidth (GB/s) | Peak Memory (GB) | Model Size (GB) | | ----------- | ------------------ | ------------------- | ------------- | ----------------------- | ---------------- | --------------- | -| Llama-2-7B | Base (bfloat16) | 12.212 | 105.14 | 1389.35 | 13.88 | 13.21 | -| | int8dq | 12.262 | 9.20 | 60.93 | 8.33 | 6.62 | -| | int8wo | 12.204 | 150.18 | 994.40 | 8.95 | 6.62 | -| | int4wo-64 | 12.843 | 199.86 | 746.66 | 4.50 | 3.74 | -| | int4wo-64-GPTQ | 12.489 | 199.86 | 746.66 | 4.50 | 3.74 | -| | autoquant | 12.204 | 159.22 | 1069.87 | 8.91 | 6.72 | -| Llama-3-8B | Base (bfloat16) | N/A | 94.97 | 1425.55 | 16.43 | 15.01 | -| | int8dq | N/A | 8.44 | 63.45 | 8.98 | 7.52 | -| | int8wo | N/A | 139.76 | 1051.02 | 10.42 | 7.52 | -| | int4wo-64 | N/A | 179.44 | 757.60 | 6.62 | 4.22 | -| | autoquant | N/A | 137.71 | 1037.74 | 11.08 | 7.54 | +| Llama-2-7B | Base (bfloat16) | 12.212 | 107.38 | 1418.93 | 13.88 | 13.21 | +| | int8dq | 12.262 | 9.61 | 63.67 | 8.61 | 6.62 | +| | int8wo | 12.204 | 170.83 | 1131.18 | 8.95 | 6.62 | +| | int4wo-64 | 12.843 | 201.14 | 751.42 | 4.87 | 3.74 | +| | int4wo-64-GPTQ | 12.527 | 201.14 | 751.42 | 4.87 | 3.74 | +| | autoquant-int4hqq | 12.825 | 209.19 | 804.32 | 4.89 | 3.84 | + +| Llama-3-8B | Base (bfloat16) | 7.441 | 95.64 | 1435.54 | 16.43 | 15.01 | +| | int8dq | 7.581 | 8.61 | 64.75 | 9.24 | 7.52 | +| | int8wo | 7.447 | 153.03 | 1150.80 | 10.42 | 7.52 | +| | int4wo-64 | 8.316 | 180.80 | 763.33 | 6.88 | 4.22 | +| | int4wo-64-GPTQ | 7.921 | 180.80 | 763.33 | 6.88 | 4.22 | +| | autoquant-int4hqq | 8.110 | 188.41 | 800.58 | 7.14 | 4.25 | note: Int8 dynamic quantization works best on compute bound models like [SAM](https://github.com/pytorch-labs/segment-anything-fast) whereas Llama with batchsize=1 tends to be memory bound, thus the rather low performance. @@ -28,18 +30,24 @@ And a quick crash course on inference quantization to help parse the above table ## Autoquantization The `autoquant` api can be used to quickly and accurately quantize your model. When used as in the example below, the api first identifies the shapes -of the activations that the different linear layers see, it then benchmarks these shapes across different types of quantized and non-quantized layers in order to pick the fastest one, attempting to take into account fusions where possible. Finally once the best class is found for each layer, it swaps the linear. Currently this api chooses between no quantization, int8 dynamic quantization and int8 weight only quantization for each layer. +of the activations that the different linear layers see, it then benchmarks these shapes across different types of quantized and non-quantized layers in order to pick the fastest one, attempting to take into account fusions where possible. Finally once the best class is found for each layer, it swaps the linear. By default the api only uses int8 techniques, i.e. it chooses between no quantization, int8 dynamic quantization and int8 weight only quantization for each layer, though there is also an option add int4 quantization which can be used for maximum performance or to avoid perf regressions from `int4_weight_only()`. ```python import torch import torchao +from torchao.quantization import DEFAULT_INT4_AUTOQUANT_CLASS_LIST # Plug in your model and example input model = torch.nn.Sequential(torch.nn.Linear(32, 64)).cuda().to(torch.bfloat16) input = torch.randn(32,32, dtype=torch.bfloat16, device='cuda') - -# perform autoquantization and torch.compile -model = torchao.autoquant(torch.compile(model, mode='max-autotune')) +use_autoquant_default = True + +if use_autoquant_default: + # perform autoquantization and torch.compile with default settings + model = torchao.autoquant(torch.compile(model, mode='max-autotune')) +elif not use_autoquant_default: + # perform autoquantization and torch.compile with int4 support + model = torchao.autoquant(torch.compile(model, mode='max-autotune'), qtensor_class_list=DEFAULT_INT4_AUTOQUANT_CLASS_LIST) # pass in an input which is used in order to pick fastest quantization operations # and apply torch compilation. diff --git a/torchao/quantization/__init__.py b/torchao/quantization/__init__.py index 82b76849b7..05c55b255d 100644 --- a/torchao/quantization/__init__.py +++ b/torchao/quantization/__init__.py @@ -21,6 +21,8 @@ "swap_conv2d_1x1_to_linear" "safe_int_mm", "autoquant", + "DEFAULT_AUTOQUANT_CLASS_LIST", + "DEFAULT_INT4_AUTOQUANT_CLASS_LIST", "get_scale", "SmoothFakeDynQuantMixin", "SmoothFakeDynamicallyQuantizedLinear", diff --git a/torchao/quantization/autoquant.py b/torchao/quantization/autoquant.py index 1183c153ef..39482caf84 100644 --- a/torchao/quantization/autoquant.py +++ b/torchao/quantization/autoquant.py @@ -9,7 +9,7 @@ Int8WeightOnlyQuantizedLinearWeight, QuantizedLinearWeightBase, ) -from torchao.dtypes import AffineQuantizedTensor, PlainLayoutType +from torchao.dtypes import AffineQuantizedTensor, PlainLayoutType, TensorCoreTiledLayoutType from torchao.quantization.linear_activation_quantized_tensor import LinearActivationQuantizedTensor from torch.utils._python_dispatch import return_and_correct_aliasing from .quant_primitives import ( @@ -23,6 +23,8 @@ __all__ = [ "AutoQuantizableLinearWeight", "autoquant", + "DEFAULT_AUTOQUANT_CLASS_LIST", + "DEFAULT_INT4_AUTOQUANT_CLASS_LIST", ] @@ -360,7 +362,7 @@ def _autoquant_test(cls, act_mat, weight, bias, best_time, mode=["relu", None]): print(f">>time: {res_f:0.3f}ms for {cls} interpolated, breakeven constant: {max_int_const_win:0.2f}") return res_f -class AQWeightOnlyQuantizedLinearWeight(AffineQuantizedTensor, AQMixin): +class AQInt8WeightOnlyQuantizedLinearWeight(AffineQuantizedTensor, AQMixin): """ AutoQuantizable version of Int8WeightOnlyQuantizedLinearWeight """ @@ -371,10 +373,10 @@ def from_float(cls, weight): eps = torch.finfo(torch.float32).eps zero_point_dtype = torch.int64 block_size = (1, weight.shape[1]) - return super(AQWeightOnlyQuantizedLinearWeight, cls).from_hp_to_intx(weight, mapping_type, block_size, target_dtype, eps=eps, zero_point_dtype=zero_point_dtype) + return super(AQInt8WeightOnlyQuantizedLinearWeight, cls).from_hp_to_intx(weight, mapping_type, block_size, target_dtype, eps=eps, zero_point_dtype=zero_point_dtype) -class AQWeightOnlyQuantizedLinearWeight2(AQWeightOnlyQuantizedLinearWeight, AQMixin): +class AQInt8WeightOnlyQuantizedLinearWeight2(AQInt8WeightOnlyQuantizedLinearWeight, AQMixin): """ AutoQuantizable version of Int8WeightOnlyQuantizedLinearWeight that uses a different kernel @@ -408,7 +410,7 @@ def _autoquant_test(cls, act_mat, *args): return torch.inf return super()._autoquant_test(act_mat, *args) -class AQWeightOnlyQuantizedLinearWeight3(AQWeightOnlyQuantizedLinearWeight, AQMixin): +class AQInt8WeightOnlyQuantizedLinearWeight3(AQInt8WeightOnlyQuantizedLinearWeight, AQMixin): """ AutoQuantizable version of Int8WeightOnlyQuantizedLinearWeight that uses a different kernel @@ -422,6 +424,40 @@ def _quantized_linear_op(act_mat, w_qtensor, bias): y += bias return y + +class AQInt4G32WeightOnlyQuantizedLinearWeight(AffineQuantizedTensor, AQMixin): + """ + AutoQuantizable version of Int4WeightOnlyQuantizedLinearWeight + """ + group_size: int = 32 + @classmethod + def from_float(cls, weight): + group_size = cls.group_size + layout_type = TensorCoreTiledLayoutType(inner_k_tiles=8) + + if weight.shape[-1] % group_size != 0: + return weight + use_hqq = True + mapping_type = MappingType.ASYMMETRIC + block_size = (1, group_size) + target_dtype = torch.int32 + quant_min = 0 + quant_max = 15 + eps = 1e-6 + preserve_zero = False + zero_point_dtype = torch.bfloat16 + zero_point_domain = ZeroPointDomain.FLOAT + return super(AQInt4G32WeightOnlyQuantizedLinearWeight, cls).from_hp_to_intx(weight, mapping_type, block_size, target_dtype, quant_min, quant_max, eps, zero_point_dtype=zero_point_dtype, preserve_zero=preserve_zero, zero_point_domain=zero_point_domain, layout_type=layout_type, use_hqq=use_hqq) + +class AQInt4G64WeightOnlyQuantizedLinearWeight(AQInt4G32WeightOnlyQuantizedLinearWeight): + group_size: int = 64 + +class AQInt4G128WeightOnlyQuantizedLinearWeight(AQInt4G32WeightOnlyQuantizedLinearWeight): + group_size: int = 128 + +class AQInt4G256WeightOnlyQuantizedLinearWeight(AQInt4G32WeightOnlyQuantizedLinearWeight): + group_size: int = 256 + class AQFloatLinearWeight(torch.Tensor, AQMixin): """ A class to be used in concert with AutoQuantizableLinearWeight to provide a @@ -441,15 +477,22 @@ def _quantized_linear_op(act_mat, w_qtensor, bias): def from_float(cls, weight): return weight -DEFAULT_CLASS_LIST = [ +# here we don't include int4 quantization in since int8 tends to be a better apples to apples comparison +DEFAULT_AUTOQUANT_CLASS_LIST = [ AQFloatLinearWeight, - AQWeightOnlyQuantizedLinearWeight, - AQWeightOnlyQuantizedLinearWeight2, - # AQWeightOnlyQuantizedLinearWeight3, + AQInt8WeightOnlyQuantizedLinearWeight, + AQInt8WeightOnlyQuantizedLinearWeight2, + # AQInt8WeightOnlyQuantizedLinearWeight3, # TODO this gets picked in places where it makes perf worse, why? AQInt8DynamicallyQuantizedLinearWeight, ] +DEFAULT_INT4_AUTOQUANT_CLASS_LIST = [ + AQFloatLinearWeight, + AQInt8DynamicallyQuantizedLinearWeight, + AQInt4G64WeightOnlyQuantizedLinearWeight +] + def _change_linears_to_autoquantizable(model, **kwargs): """ Converts all linear weight tensors to the @@ -459,7 +502,7 @@ def _change_linears_to_autoquantizable(model, **kwargs): from torchao.quantization.quant_api import _is_linear filter_fn = kwargs.pop("filter_fn", _is_linear) _ = kwargs.pop("error_on_unseen", True) # same kwargs used for this and to_quantized - kwargs["qtensor_class_list"] = kwargs.get("qtensor_class_list", DEFAULT_CLASS_LIST) + kwargs["qtensor_class_list"] = kwargs.get("qtensor_class_list", DEFAULT_AUTOQUANT_CLASS_LIST) kwargs["mode"] = kwargs.get("mode", ["relu", None]) from torchao.quantization.quant_api import _replace_with_custom_fn_if_matches_filter from torchao.quantization.quant_api import _get_subclass_inserter @@ -515,7 +558,7 @@ def _change_autoquantizable_to_quantized(model, supress_autoquant_errors=True, * def autoquant( model, example_input=None, - qtensor_class_list=DEFAULT_CLASS_LIST, + qtensor_class_list=DEFAULT_AUTOQUANT_CLASS_LIST, filter_fn=None, mode=["interpolate", .85], manual=False, @@ -547,7 +590,7 @@ def autoquant( model (torch.nn.Module): The model to be autoquantized. example_input (Any, optional): An example input for the model. If provided, the function performs a forward pass on this input (which fully autoquantizes the model unless manual=True). Defaults to None. - qtensor_class_list (list, optional): A list of tensor classes to be used for quantization. Defaults to DEFAULT_CLASS_LIST. + qtensor_class_list (list, optional): A list of tensor classes to be used for quantization. Defaults to DEFAULT_AUTOQUANT_CLASS_LIST. filter_fn (callable, optional): A filter function to apply to the model parameters. Defaults to None. mode (list, optional): A list containing mode settings for quantization. The first element is the mode type (e.g., "interpolate"), and the second element is the mode value (e.g., 0.85). Defaults to ["interpolate", .85]. diff --git a/torchao/quantization/quant_api.py b/torchao/quantization/quant_api.py index 813f97484b..b8c8a26fc0 100644 --- a/torchao/quantization/quant_api.py +++ b/torchao/quantization/quant_api.py @@ -419,7 +419,7 @@ def int8_dynamic_activation_int4_weight(group_size=32): return _get_linear_subclass_inserter(apply_int8_dynamic_activation_int4_weight_quant, group_size=group_size) -def int4_weight_only(group_size=128, layout_type=TensorCoreTiledLayoutType(inner_k_tiles=8)): +def int4_weight_only(group_size=128, layout_type=TensorCoreTiledLayoutType(inner_k_tiles=8), use_hqq=False): """ Applies uint4 weight-only asymmetric per-group quantization to linear layers, using "tensor_core_tiled" layout for speedup with tinygemm kernel @@ -436,8 +436,9 @@ def int4_weight_only(group_size=128, layout_type=TensorCoreTiledLayoutType(inner `group_size`: parameter for quantization, controls the granularity of quantization, smaller size is more fine grained, choices are [256, 128, 64, 32] `layout_type`: layout type for quantized tensor, default is `TensorCoreTiledLayoutType(inner_k_tiles=8)` + `use_hqq`: whether to use hqq or default quantization mode, default is False """ - def apply_int4_weight_only_quant(weight, use_hqq=False): + def apply_int4_weight_only_quant(weight): if weight.shape[-1] % group_size != 0: logger.info( f"Skipping quantizing weight with int4 weight only quantization because the shape of weight {weight.shape} is not compatible with group_size {group_size}" @@ -453,7 +454,7 @@ def apply_int4_weight_only_quant(weight, use_hqq=False): preserve_zero = False zero_point_dtype = torch.bfloat16 zero_point_domain = ZeroPointDomain.FLOAT - return to_affine_quantized_intx(weight, mapping_type, block_size, target_dtype, quant_min, quant_max, eps, zero_point_dtype=zero_point_dtype, preserve_zero=preserve_zero, zero_point_domain=zero_point_domain, layout_type=layout_type) + return to_affine_quantized_intx(weight, mapping_type, block_size, target_dtype, quant_min, quant_max, eps, zero_point_dtype=zero_point_dtype, preserve_zero=preserve_zero, zero_point_domain=zero_point_domain, layout_type=layout_type, use_hqq=use_hqq) return _get_linear_subclass_inserter(apply_int4_weight_only_quant) diff --git a/torchao/quantization/utils.py b/torchao/quantization/utils.py index ae4f48d9db..df01c579c9 100644 --- a/torchao/quantization/utils.py +++ b/torchao/quantization/utils.py @@ -357,7 +357,7 @@ def groupwise_affine_quantize_tensor_from_qparams( quant_max = 2 ** n_bit - 1 int_data = quantize_affine(w, block_size, scales, zeros, output_dtype, quant_min, quant_max, zero_point_domain = ZeroPointDomain.FLOAT) - if TORCH_VERSION_AT_LEAST_2_5: + if TORCH_VERSION_AT_LEAST_2_5 and w.shape[-1] > 1: int_data_device_type = int_data.device.type # Move to cpu, until issue with MPS memory management of temporary tensors is resolved if int_data_device_type == 'mps': @@ -376,7 +376,8 @@ def groupwise_affine_dequantize_tensor_from_qparams( ): assert groupsize > 1 assert w_int4x8.dim() == 2 - if TORCH_VERSION_AT_LEAST_2_5: + # need to handle single column case so check for dtype/size from groupwise_affine_quantize_tensor_from_qparams path + if TORCH_VERSION_AT_LEAST_2_5 and (w_int4x8.dtype == torch.uint8 or w_int4x8.shape[-1]>1): data = w_int4x8.to(torch.int32) high_bits = data >> 4 low_bits = data & 0x0F