From b7f2ef5d57a460c80bcbc0f3d63d47fb5599fc7e Mon Sep 17 00:00:00 2001
From: HDCharles
Date: Tue, 3 Sep 2024 23:41:10 -0700
Subject: [PATCH 1/8] int4 fixes and improvements

Summary:

1) added int4 to autoquant, using hqq by default
2) fixed hqq in the normal int4 class so it can actually be used with the normal UX
3) added hqq to eval/generate
4) evaluated hqq to make sure it's a reasonable default for autoquant
5) ran the llama3 eval now that llama3 is working correctly (fixed in the 3.1 PR)
6) tested hqq vs GPTQ so we have a comparison in our benchmarks/eval
7) GPTQ was broken -> fixed utils and GPTQ so the GPTQ flow works again

Test Plan:

benchmarks.sh (new autoquant-int4 benchmarks)

export CHECKPOINT_PATH=../../../checkpoints
export MODEL_REPO=meta-llama/Llama-2-7b-chat-hf
python eval.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth
python eval.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --quantization int8wo
python eval.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --quantization int8dq --compile
python eval.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --quantization int4wo-64-hqq
python eval.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --quantization int4wo-64
python eval.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --quantization int4wo-64-gptq

export MODEL_REPO=meta-llama/Meta-Llama-3-8B
python eval.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth
python eval.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --quantization int8wo
python eval.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --quantization int8dq --compile
python eval.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --quantization int4wo-64-hqq
python eval.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --quantization int4wo-64
python eval.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --quantization int4wo-64-gptq

(see results in README.md)

Reviewers:

Subscribers:

Tasks:

Tags:
---
 README.md                                   | 14 +++--
 torchao/_models/llama/README.md             |  2 +-
 torchao/_models/llama/benchmark_results.txt | 34 ++++++-----
 torchao/_models/llama/benchmarks.sh         |  2 +
 torchao/_models/llama/eval.py               | 13 ++--
 torchao/_models/llama/generate.py           | 18 ++++--
 torchao/quantization/GPTQ.py                |  8 +--
 torchao/quantization/README.md              | 36 +++++++----
 torchao/quantization/__init__.py            |  2 +
 torchao/quantization/autoquant.py           | 67 +++++++++++++++++----
 torchao/quantization/quant_api.py           |  7 ++-
 torchao/quantization/utils.py               |  5 +-
 12 files changed, 144 insertions(+), 64 deletions(-)

diff --git a/README.md b/README.md
index 69b7c2a0a7..ebce73b6d6 100644
--- a/README.md
+++ b/README.md
@@ -35,10 +35,12 @@ from torchao.quantization.quant_api import (
     int4_weight_only,
     int8_weight_only
 )
-quantize_(m, int4_weight_only())
+quantize_(m, int4_weight_only())
 ```
 
-For gpt-fast `int4_weight_only()` is the best option at bs=1 as it **2x the tok/s and reduces the VRAM requirements by about 65%** over a torch.compiled baseline
+For gpt-fast `int4_weight_only()` is the best option at bs=1 as it **2x the tok/s and reduces the VRAM requirements by about 65%** over a torch.compiled baseline.
+Note: For models that are less memory bound, the int4 weight only quantization kernel can be slower than other kernels, if you are seeing slowdowns, using [autoquant](./torchao/quantization/README.md#autoquantization) with int4 quantization
+can solve the issue. See the [quantization readme](./torchao/quantization/README.md#autoquantization) for details.
If you're unsure which option to use, you can also run [autoquant](./torchao/quantization/README.md#autoquantization) which will automatically profile layers for you and skip quantizing layers where overhead is too large. @@ -102,7 +104,7 @@ from torchao.prototype.low_bit_optim import AdamW8bit, AdamW4bit, AdamWFp8 optim = AdamW8bit(model.parameters()) # replace with Adam4bit and AdamFp8 for the 4 / fp8 versions ``` -In practice, we are a tiny bit slower than expertly written kernels but the implementations for these optimizers were written in a **few hundred lines of PyTorch code** and compiled so please use them or copy-paste them for your quantized optimizers. Benchmarks [here](https://github.com/pytorch/ao/tree/main/torchao/prototype/low_bit_optim) +In practice, we are a tiny bit slower than expertly written kernels but the implementations for these optimizers were written in a **few hundred lines of PyTorch code** and compiled so please use them or copy-paste them for your quantized optimizers. Benchmarks [here](https://github.com/pytorch/ao/tree/main/torchao/prototype/low_bit_optim) We also have support for [single GPU CPU offloading](https://github.com/pytorch/ao/tree/main/torchao/prototype/low_bit_optim#optimizer-cpu-offload) where both the gradients (same size as weights) and the optimizers will be efficiently sent to the CPU. This alone can **reduce your VRAM requirements by 60%** @@ -114,7 +116,7 @@ optim.load_state_dict(ckpt["optim"]) ## Composability 1. `torch.compile`: A key design principle for us is composability as in any new dtype or layout we provide needs to work with our compiler. It shouldn't matter if the kernels are written in pure PyTorch, CUDA, C++, or Triton - things should just work! So we write the dtype, layout, or bit packing logic in pure PyTorch and code-generate efficient kernels. -3. [FSDP2](https://github.com/pytorch/torchtitan/blob/main/docs/fsdp.md): Historically most quantization has been done for inference, there is now a thriving area of research combining distributed algorithms and quantization. +3. [FSDP2](https://github.com/pytorch/torchtitan/blob/main/docs/fsdp.md): Historically most quantization has been done for inference, there is now a thriving area of research combining distributed algorithms and quantization. The best example we have combining the composability of lower bit dtype with compile and fsdp is [NF4](torchao/dtypes/nf4tensor.py) which we used to implement the [QLoRA](https://www.youtube.com/watch?v=UvRl4ansfCg) algorithm. So if you're doing research at the intersection of this area we'd love to hear from you. @@ -135,7 +137,7 @@ Things we're excited about but need more time to cook in the oven 1. [MX](torchao/prototype/mx_formats) training and inference support with tensors using the [OCP MX spec](https://www.opencompute.org/documents/ocp-microscaling-formats-mx-v1-0-spec-final-pdf) data types, which can be described as groupwise scaled float8/float6/float4/int8, with the scales being constrained to powers of two. This work is prototype as the hardware support is not available yet. 2. [Int8 Quantized Training](https://github.com/pytorch/ao/tree/main/torchao/prototype/quantized_training): We're trying out full int8 training. This is easy to use with `quantize_(model, int8_weight_only_quantized_training())`. This work is prototype as the memory benchmarks are not compelling yet. -3. 
[IntX](https://github.com/pytorch/ao/tree/main/torchao/dtypes/uintx): We've managed to support all the ints by doing some clever bitpacking in pure PyTorch and then compiling it. This work is prototype as unfortunately without some more investment in either the compiler or low-bit kernels, int4 is more compelling than any smaller dtype +3. [IntX](https://github.com/pytorch/ao/tree/main/torchao/dtypes/uintx): We've managed to support all the ints by doing some clever bitpacking in pure PyTorch and then compiling it. This work is prototype as unfortunately without some more investment in either the compiler or low-bit kernels, int4 is more compelling than any smaller dtype 4. [Bitnet](https://github.com/pytorch/ao/blob/main/torchao/prototype/dtypes/bitnet.py): Mostly this is very cool to people on the team. This is prototype because how useful these kernels are is highly dependent on better hardware and kernel support. ## Installation @@ -169,7 +171,7 @@ USE_CPP=0 pip install -e . We're also fortunate to be integrated into some of the leading open-source libraries including 1. Hugging Face transformers with a [builtin inference backend](https://huggingface.co/docs/transformers/main/quantization/torchao) and [low bit optimizers](https://github.com/huggingface/transformers/pull/31865) 2. Hugging Face diffusers with a minimal example thanks to [Sayak Paul](https://www.linkedin.com/posts/sayak-paul_want-to-combine-quantization-and-benefit-activity-7231950868605022208-g52d?utm_source=share&utm_medium=member_desktop) -3. Mobius HQQ backend leveraged our int4 kernels to get [195 tok/s on a 4090](https://github.com/mobiusml/hqq#faster-inference) +3. Mobius HQQ backend leveraged our int4 kernels to get [195 tok/s on a 4090](https://github.com/mobiusml/hqq#faster-inference) ## Videos * [Slaying OOMs at the Mastering LLM's course](https://www.youtube.com/watch?v=UvRl4ansfCg) diff --git a/torchao/_models/llama/README.md b/torchao/_models/llama/README.md index 93acc3b22a..cfd9c353ed 100644 --- a/torchao/_models/llama/README.md +++ b/torchao/_models/llama/README.md @@ -2,7 +2,7 @@ The llama folder contains code/scripts for stable benchmarking llama models. -To get model weights, go to https://huggingface.co/meta-llama/Llama-2-7b and/or https://huggingface.co/meta-llama/Meta-Llama-3-8B +To get model weights, go to https://huggingface.co/meta-llama/Llama-2-7b, https://huggingface.co/meta-llama/Meta-Llama-3-8B, https://huggingface.co/meta-llama/Meta-Llama-3.1-8B and follow the steps to gain access. 
Then from the torchao root directory use `huggingface-cli login` and follow the steps to login, then `sh ./scripts/prepare.sh` to diff --git a/torchao/_models/llama/benchmark_results.txt b/torchao/_models/llama/benchmark_results.txt index 9d8ec434d4..3bea35cc49 100644 --- a/torchao/_models/llama/benchmark_results.txt +++ b/torchao/_models/llama/benchmark_results.txt @@ -1,24 +1,26 @@ llama 2 -20240619101342, tok/s= 29.85, mem/s= 788.87 GB/s, peak_mem=27.23 GB, model_size=26.43 GB quant: None, mod: Llama-2-7b-chat-hf, compile: False, compile_prefill: False, dtype: torch.float32, device: cuda repro: python generate.py --checkpoint_path ../../../checkpoints/meta-llama/Llama-2-7b-chat-hf/model.pth --device cuda --precision torch.float32 --num_samples 5 --max_new_tokens 200 --top_k 200 --temperature 0.8 -20240619101537, tok/s= 26.38, mem/s= 348.57 GB/s, peak_mem=13.62 GB, model_size=13.21 GB quant: None, mod: Llama-2-7b-chat-hf, compile: False, compile_prefill: False, dtype: torch.bfloat16, device: cuda repro: python generate.py --checkpoint_path ../../../checkpoints/meta-llama/Llama-2-7b-chat-hf/model.pth --device cuda --precision torch.bfloat16 --num_samples 5 --max_new_tokens 200 --top_k 200 --temperature 0.8 -20240619105331, tok/s=106.55, mem/s=1408.06 GB/s, peak_mem=13.88 GB, model_size=13.21 GB quant: None, mod: Llama-2-7b-chat-hf, compile: True, compile_prefill: True, dtype: torch.bfloat16, device: cuda repro: python generate.py --checkpoint_path ../../../checkpoints/meta-llama/Llama-2-7b-chat-hf/model.pth --device cuda --precision torch.bfloat16 --compile --compile_prefill --num_samples 5 --max_new_tokens 200 --top_k 200 --temperature 0.8 +20240831224311, tok/s= 26.75, mem/s= 707.01 GB/s, peak_mem=27.23 GB, model_size=26.43 GB quant: None, mod: Llama-2-7b-chat-hf, kv_quant: False, compile: False, compile_prefill: False, dtype: torch.float32, device: cuda repro: python generate.py --checkpoint_path ../../../checkpoints/meta-llama/Llama-2-7b-chat-hf/model.pth --device cuda --precision torch.float32 --num_samples 5 --max_new_tokens 200 --top_k 200 --temperature 0.8 +20240831224512, tok/s= 22.97, mem/s= 303.53 GB/s, peak_mem=13.64 GB, model_size=13.21 GB quant: None, mod: Llama-2-7b-chat-hf, kv_quant: False, compile: False, compile_prefill: False, dtype: torch.bfloat16, device: cuda repro: python generate.py --checkpoint_path ../../../checkpoints/meta-llama/Llama-2-7b-chat-hf/model.pth --device cuda --precision torch.bfloat16 --num_samples 5 --max_new_tokens 200 --top_k 200 --temperature 0.8 +20240831224958, tok/s=108.48, mem/s=1433.57 GB/s, peak_mem=13.90 GB, model_size=13.21 GB quant: None, mod: Llama-2-7b-chat-hf, kv_quant: False, compile: True, compile_prefill: True, dtype: torch.bfloat16, device: cuda repro: python generate.py --checkpoint_path ../../../checkpoints/meta-llama/Llama-2-7b-chat-hf/model.pth --device cuda --precision torch.bfloat16 --compile --compile_prefill --num_samples 5 --max_new_tokens 200 --top_k 200 --temperature 0.8 -20240619105522, tok/s=105.14, mem/s=1389.35 GB/s, peak_mem=13.88 GB, model_size=13.21 GB quant: None, mod: Llama-2-7b-chat-hf, compile: True, compile_prefill: False, dtype: torch.bfloat16, device: cuda repro: python generate.py --checkpoint_path ../../../checkpoints/meta-llama/Llama-2-7b-chat-hf/model.pth --device cuda --precision torch.bfloat16 --compile --num_samples 5 --max_new_tokens 200 --top_k 200 --temperature 0.8 -20240619105921, tok/s= 9.20, mem/s= 60.93 GB/s, peak_mem= 8.33 GB, model_size= 6.62 GB quant: int8dq, mod: Llama-2-7b-chat-hf, 
compile: True, compile_prefill: False, dtype: torch.bfloat16, device: cuda repro: python generate.py --quantization int8dq --checkpoint_path ../../../checkpoints/meta-llama/Llama-2-7b-chat-hf/model.pth --device cuda --precision torch.bfloat16 --compile --num_samples 5 --max_new_tokens 200 --top_k 200 --temperature 0.8 -20240619110107, tok/s=150.18, mem/s= 994.40 GB/s, peak_mem= 8.95 GB, model_size= 6.62 GB quant: int8wo, mod: Llama-2-7b-chat-hf, compile: True, compile_prefill: False, dtype: torch.bfloat16, device: cuda repro: python generate.py --quantization int8wo --checkpoint_path ../../../checkpoints/meta-llama/Llama-2-7b-chat-hf/model.pth --device cuda --precision torch.bfloat16 --compile --num_samples 5 --max_new_tokens 200 --top_k 200 --temperature 0.8 -20240619110248, tok/s=199.86, mem/s= 746.66 GB/s, peak_mem= 4.50 GB, model_size= 3.74 GB quant: int4wo-64, mod: Llama-2-7b-chat-hf, compile: True, compile_prefill: False, dtype: torch.bfloat16, device: cuda repro: python generate.py --quantization int4wo-64 --checkpoint_path ../../../checkpoints/meta-llama/Llama-2-7b-chat-hf/model.pth --device cuda --precision torch.bfloat16 --compile --num_samples 5 --max_new_tokens 200 --top_k 200 --temperature 0.8 -20240619114518, tok/s=159.22, mem/s=1069.87 GB/s, peak_mem= 8.91 GB, model_size= 6.72 GB quant: autoquant, mod: Llama-2-7b-chat-hf, compile: True, compile_prefill: True, dtype: torch.bfloat16, device: cuda repro: python generate.py --quantization autoquant --checkpoint_path ../../../checkpoints/meta-llama/Llama-2-7b-chat-hf/model.pth --device cuda --precision torch.bfloat16 --compile --compile_prefill --num_samples 5 --max_new_tokens 200 --top_k 200 --temperature 0.8 +20240831225155, tok/s=107.38, mem/s=1418.93 GB/s, peak_mem=13.88 GB, model_size=13.21 GB quant: None, mod: Llama-2-7b-chat-hf, kv_quant: False, compile: True, compile_prefill: False, dtype: torch.bfloat16, device: cuda repro: python generate.py --checkpoint_path ../../../checkpoints/meta-llama/Llama-2-7b-chat-hf/model.pth --device cuda --precision torch.bfloat16 --compile --num_samples 5 --max_new_tokens 200 --top_k 200 --temperature 0.8 +20240831225810, tok/s= 9.61, mem/s= 63.67 GB/s, peak_mem= 8.61 GB, model_size= 6.62 GB quant: int8dq, mod: Llama-2-7b-chat-hf, kv_quant: False, compile: True, compile_prefill: False, dtype: torch.bfloat16, device: cuda repro: python generate.py --quantization int8dq --checkpoint_path ../../../checkpoints/meta-llama/Llama-2-7b-chat-hf/model.pth --device cuda --precision torch.bfloat16 --compile --num_samples 5 --max_new_tokens 200 --top_k 200 --temperature 0.8 +20240831230013, tok/s=170.83, mem/s=1131.18 GB/s, peak_mem= 8.95 GB, model_size= 6.62 GB quant: int8wo, mod: Llama-2-7b-chat-hf, kv_quant: False, compile: True, compile_prefill: False, dtype: torch.bfloat16, device: cuda repro: python generate.py --quantization int8wo --checkpoint_path ../../../checkpoints/meta-llama/Llama-2-7b-chat-hf/model.pth --device cuda --precision torch.bfloat16 --compile --num_samples 5 --max_new_tokens 200 --top_k 200 --temperature 0.8 +20240831230205, tok/s=201.14, mem/s= 751.42 GB/s, peak_mem= 4.87 GB, model_size= 3.74 GB quant: int4wo-64, mod: Llama-2-7b-chat-hf, kv_quant: False, compile: True, compile_prefill: False, dtype: torch.bfloat16, device: cuda repro: python generate.py --quantization int4wo-64 --checkpoint_path ../../../checkpoints/meta-llama/Llama-2-7b-chat-hf/model.pth --device cuda --precision torch.bfloat16 --compile --num_samples 5 --max_new_tokens 200 --top_k 200 --temperature 0.8 
+20240831230736, tok/s=177.45, mem/s=1194.35 GB/s, peak_mem= 8.64 GB, model_size= 6.73 GB quant: autoquant, mod: Llama-2-7b-chat-hf, kv_quant: False, compile: True, compile_prefill: True, dtype: torch.bfloat16, device: cuda repro: python generate.py --quantization autoquant --checkpoint_path ../../../checkpoints/meta-llama/Llama-2-7b-chat-hf/model.pth --device cuda --precision torch.bfloat16 --compile --compile_prefill --num_samples 5 --max_new_tokens 200 --top_k 200 --temperature 0.8 +20240902100527, tok/s=209.19, mem/s= 804.32 GB/s, peak_mem= 4.89 GB, model_size= 3.84 GB quant: autoquant-int4, mod: Llama-2-7b-chat-hf, kv_quant: False, compile: True, compile_prefill: True, dtype: torch.bfloat16, device: cuda repro: python generate.py --quantization autoquant-int4 --checkpoint_path ../../../checkpoints/meta-llama/Llama-2-7b-chat-hf/model.pth --device cuda --precision torch.bfloat16 --compile --compile_prefill --num_samples 5 --max_new_tokens 200 --top_k 200 --temperature 0.8 llama 3 -20240619114732, tok/s= 30.46, mem/s= 914.43 GB/s, peak_mem=32.34 GB, model_size=30.02 GB quant: None, mod: Meta-Llama-3-8B, compile: False, compile_prefill: False, dtype: torch.float32, device: cuda repro: python generate.py --checkpoint_path ../../../checkpoints/meta-llama/Meta-Llama-3-8B/model.pth --device cuda --precision torch.float32 --num_samples 5 --max_new_tokens 200 --top_k 200 --temperature 0.8 -20240619114939, tok/s= 26.56, mem/s= 398.65 GB/s, peak_mem=16.16 GB, model_size=15.01 GB quant: None, mod: Meta-Llama-3-8B, compile: False, compile_prefill: False, dtype: torch.bfloat16, device: cuda repro: python generate.py --checkpoint_path ../../../checkpoints/meta-llama/Meta-Llama-3-8B/model.pth --device cuda --precision torch.bfloat16 --num_samples 5 --max_new_tokens 200 --top_k 200 --temperature 0.8 -20240619122811, tok/s= 96.09, mem/s=1442.32 GB/s, peak_mem=16.43 GB, model_size=15.01 GB quant: None, mod: Meta-Llama-3-8B, compile: True, compile_prefill: True, dtype: torch.bfloat16, device: cuda repro: python generate.py --checkpoint_path ../../../checkpoints/meta-llama/Meta-Llama-3-8B/model.pth --device cuda --precision torch.bfloat16 --compile --compile_prefill --num_samples 5 --max_new_tokens 200 --top_k 200 --temperature 0.8 +20240831231514, tok/s= 26.54, mem/s= 796.59 GB/s, peak_mem=32.34 GB, model_size=30.02 GB quant: None, mod: Meta-Llama-3-8B, kv_quant: False, compile: False, compile_prefill: False, dtype: torch.float32, device: cuda repro: python generate.py --checkpoint_path ../../../checkpoints/meta-llama/Meta-Llama-3-8B/model.pth --device cuda --precision torch.float32 --num_samples 5 --max_new_tokens 200 --top_k 200 --temperature 0.8 +20240831231725, tok/s= 23.67, mem/s= 355.33 GB/s, peak_mem=16.19 GB, model_size=15.01 GB quant: None, mod: Meta-Llama-3-8B, kv_quant: False, compile: False, compile_prefill: False, dtype: torch.bfloat16, device: cuda repro: python generate.py --checkpoint_path ../../../checkpoints/meta-llama/Meta-Llama-3-8B/model.pth --device cuda --precision torch.bfloat16 --num_samples 5 --max_new_tokens 200 --top_k 200 --temperature 0.8 +20240831232327, tok/s= 96.59, mem/s=1449.85 GB/s, peak_mem=16.43 GB, model_size=15.01 GB quant: None, mod: Meta-Llama-3-8B, kv_quant: False, compile: True, compile_prefill: True, dtype: torch.bfloat16, device: cuda repro: python generate.py --checkpoint_path ../../../checkpoints/meta-llama/Meta-Llama-3-8B/model.pth --device cuda --precision torch.bfloat16 --compile --compile_prefill --num_samples 5 --max_new_tokens 200 --top_k 200 
--temperature 0.8 -20240619123018, tok/s= 94.97, mem/s=1425.55 GB/s, peak_mem=16.43 GB, model_size=15.01 GB quant: None, mod: Meta-Llama-3-8B, compile: True, compile_prefill: False, dtype: torch.bfloat16, device: cuda repro: python generate.py --checkpoint_path ../../../checkpoints/meta-llama/Meta-Llama-3-8B/model.pth --device cuda --precision torch.bfloat16 --compile --num_samples 5 --max_new_tokens 200 --top_k 200 --temperature 0.8 -20240619123441, tok/s= 8.44, mem/s= 63.45 GB/s, peak_mem= 8.98 GB, model_size= 7.52 GB quant: int8dq, mod: Meta-Llama-3-8B, compile: True, compile_prefill: False, dtype: torch.bfloat16, device: cuda repro: python generate.py --quantization int8dq --checkpoint_path ../../../checkpoints/meta-llama/Meta-Llama-3-8B/model.pth --device cuda --precision torch.bfloat16 --compile --num_samples 5 --max_new_tokens 200 --top_k 200 --temperature 0.8 -20240619123652, tok/s=139.76, mem/s=1051.02 GB/s, peak_mem=10.42 GB, model_size= 7.52 GB quant: int8wo, mod: Meta-Llama-3-8B, compile: True, compile_prefill: False, dtype: torch.bfloat16, device: cuda repro: python generate.py --quantization int8wo --checkpoint_path ../../../checkpoints/meta-llama/Meta-Llama-3-8B/model.pth --device cuda --precision torch.bfloat16 --compile --num_samples 5 --max_new_tokens 200 --top_k 200 --temperature 0.8 -20240619123847, tok/s=179.44, mem/s= 757.60 GB/s, peak_mem= 6.62 GB, model_size= 4.22 GB quant: int4wo-64, mod: Meta-Llama-3-8B, compile: True, compile_prefill: False, dtype: torch.bfloat16, device: cuda repro: python generate.py --quantization int4wo-64 --checkpoint_path ../../../checkpoints/meta-llama/Meta-Llama-3-8B/model.pth --device cuda --precision torch.bfloat16 --compile --num_samples 5 --max_new_tokens 200 --top_k 200 --temperature 0.8 -20240619131959, tok/s=137.71, mem/s=1037.74 GB/s, peak_mem=11.08 GB, model_size= 7.54 GB quant: autoquant, mod: Meta-Llama-3-8B, compile: True, compile_prefill: True, dtype: torch.bfloat16, device: cuda repro: python generate.py --quantization autoquant --checkpoint_path ../../../checkpoints/meta-llama/Meta-Llama-3-8B/model.pth --device cuda --precision torch.bfloat16 --compile --compile_prefill --num_samples 5 --max_new_tokens 200 --top_k 200 --temperature 0.8 +20240831232535, tok/s= 95.64, mem/s=1435.54 GB/s, peak_mem=16.43 GB, model_size=15.01 GB quant: None, mod: Meta-Llama-3-8B, kv_quant: False, compile: True, compile_prefill: False, dtype: torch.bfloat16, device: cuda repro: python generate.py --checkpoint_path ../../../checkpoints/meta-llama/Meta-Llama-3-8B/model.pth --device cuda --precision torch.bfloat16 --compile --num_samples 5 --max_new_tokens 200 --top_k 200 --temperature 0.8 +20240831233224, tok/s= 8.61, mem/s= 64.75 GB/s, peak_mem= 9.24 GB, model_size= 7.52 GB quant: int8dq, mod: Meta-Llama-3-8B, kv_quant: False, compile: True, compile_prefill: False, dtype: torch.bfloat16, device: cuda repro: python generate.py --quantization int8dq --checkpoint_path ../../../checkpoints/meta-llama/Meta-Llama-3-8B/model.pth --device cuda --precision torch.bfloat16 --compile --num_samples 5 --max_new_tokens 200 --top_k 200 --temperature 0.8 +20240831233853, tok/s=153.03, mem/s=1150.80 GB/s, peak_mem=10.42 GB, model_size= 7.52 GB quant: int8wo, mod: Meta-Llama-3-8B, kv_quant: False, compile: True, compile_prefill: False, dtype: torch.bfloat16, device: cuda repro: python generate.py --quantization int8wo --checkpoint_path ../../../checkpoints/meta-llama/Meta-Llama-3-8B/model.pth --device cuda --precision torch.bfloat16 --compile --num_samples 5 
--max_new_tokens 200 --top_k 200 --temperature 0.8 +20240831234218, tok/s=180.80, mem/s= 763.33 GB/s, peak_mem= 6.88 GB, model_size= 4.22 GB quant: int4wo-64, mod: Meta-Llama-3-8B, kv_quant: False, compile: True, compile_prefill: False, dtype: torch.bfloat16, device: cuda repro: python generate.py --quantization int4wo-64 --checkpoint_path ../../../checkpoints/meta-llama/Meta-Llama-3-8B/model.pth --device cuda --precision torch.bfloat16 --compile --num_samples 5 --max_new_tokens 200 --top_k 200 --temperature 0.8 +20240831235355, tok/s=158.10, mem/s=1193.24 GB/s, peak_mem=10.04 GB, model_size= 7.55 GB quant: autoquant, mod: Meta-Llama-3-8B, kv_quant: False, compile: True, compile_prefill: True, dtype: torch.bfloat16, device: cuda repro: python generate.py --quantization autoquant --checkpoint_path ../../../checkpoints/meta-llama/Meta-Llama-3-8B/model.pth --device cuda --precision torch.bfloat16 --compile --compile_prefill --num_samples 5 --max_new_tokens 200 --top_k 200 --temperature 0.8 +20240902101015, tok/s=188.41, mem/s= 800.58 GB/s, peak_mem= 7.14 GB, model_size= 4.25 GB quant: autoquant-int4, mod: Meta-Llama-3-8B, kv_quant: False, compile: True, compile_prefill: True, dtype: torch.bfloat16, device: cuda repro: python generate.py --quantization autoquant-int4 --checkpoint_path ../../../checkpoints/meta-llama/Meta-Llama-3-8B/model.pth --device cuda --precision torch.bfloat16 --compile --compile_prefill --num_samples 5 --max_new_tokens 200 --top_k 200 --temperature 0.8 kv cache quantization: 20240826161508, tok/s= 19.71, mem/s= 295.80 GB/s, peak_mem=17.86 GB, model_size=15.01 GB quant: None, mod: Meta-Llama-3.1-8B, kv_quant: False, compile: False, compile_prefill: False, dtype: torch.bfloat16, device: cuda repro: python generate.py --checkpoint_path ../../../checkpoints/meta-llama/Meta-Llama-3.1-8B/model.pth --device cuda --precision torch.bfloat16 --num_samples 5 --max_new_tokens 200 --top_k 200 --temperature 0.8 --cache_size 8192 diff --git a/torchao/_models/llama/benchmarks.sh b/torchao/_models/llama/benchmarks.sh index 9020715e70..c86406735a 100644 --- a/torchao/_models/llama/benchmarks.sh +++ b/torchao/_models/llama/benchmarks.sh @@ -11,6 +11,7 @@ python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --co python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --compile --quantization int8wo --write_result benchmark_results.txt python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --compile --quantization int4wo-64 --write_result benchmark_results.txt python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --compile --compile_prefill --quantization autoquant --write_result benchmark_results.txt +python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --compile --compile_prefill --quantization autoquant-int4 --write_result benchmark_results.txt # auto-round w/ quant_lm_head python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --compile --quantization autoround @@ -28,6 +29,7 @@ python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --co python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --compile --quantization int8wo --write_result benchmark_results.txt python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --compile --quantization int4wo-64 --write_result benchmark_results.txt python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --compile --compile_prefill 
--quantization autoquant --write_result benchmark_results.txt +python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --compile --compile_prefill --quantization autoquant-int4 --write_result benchmark_results.txt # auto-round w/ quant_lm_head python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --compile --quantization autoround diff --git a/torchao/_models/llama/eval.py b/torchao/_models/llama/eval.py index f673a966de..6522fd9757 100644 --- a/torchao/_models/llama/eval.py +++ b/torchao/_models/llama/eval.py @@ -68,12 +68,17 @@ def run_evaluation( quantize_(model, int8_weight_only()) if "int8dq" in quantization: quantize_(model, int8_dynamic_activation_int8_weight()) + if "fp6" in quantization: + quantize_(model, fpx_weight_only(3, 2)) if "int4wo" in quantization and not "gptq" in quantization: + if "hqq" in quantization: + quantization = quantization[:-4] + use_hqq = True + else: + use_hqq = False groupsize=int(quantization.split("-")[-1]) assert groupsize in [32,64,128,256], f"int4wo groupsize needs to be one of [32,64,128,256] but got {groupsize}" - quantize_(model.to(device), int4_weight_only(group_size=groupsize)) - if "fp6" in quantization: - quantize_(model, fpx_weight_only(3, 2)) + quantize_(model.to(device), int4_weight_only(group_size=groupsize, use_hqq=use_hqq)) if "int4wo" in quantization and "gptq" in quantization: groupsize=int(quantization.split("-")[-2]) assert groupsize in [32,64,128,256], f"int4wo groupsize needs to be one of [32,64,128,256] but got {groupsize}" @@ -120,7 +125,7 @@ def run_evaluation( parser.add_argument('--limit', type=int, default=None, help='Number of eval samples to evaluate') parser.add_argument('--precision', type=lambda x: getattr(torch, x.split(".")[-1]), default=torch.bfloat16, help='dtype precision to use') parser.add_argument('--device', type=str, default="cuda", help='Device to use for evaluation') - parser.add_argument("-q", "--quantization", type=str, help="Which quantization techniques to apply: int8dq, int8wo, int4wo-, int4wo--gptq") + parser.add_argument("-q", "--quantization", type=str, help="Which quantization techniques to apply: int8dq, int8wo, int4wo-, int4wo--gptq, int4wo--hqq") parser.add_argument('--compile', action='store_true', help='Whether to compile the model.') parser.add_argument('--max_length', type=int, default=None, help='Length of text to process at one time') parser.add_argument('--calibration_tasks', type=str, nargs='+', default=['wikitext'], help='tasks to do gptq calibration on, if doing gptq') diff --git a/torchao/_models/llama/generate.py b/torchao/_models/llama/generate.py index a559fc241c..df2b0653a3 100644 --- a/torchao/_models/llama/generate.py +++ b/torchao/_models/llama/generate.py @@ -216,10 +216,14 @@ def main( if "int8dq" in quantization: quantize_(model, int8_dynamic_activation_int8_weight()) if "int4wo" in quantization: + if "hqq" in quantization: + use_hqq=True + quantization = quantization[:-4] + else: + use_hqq=False groupsize=int(quantization.split("-")[-1]) assert groupsize in [32,64,128,256], f"int4wo groupsize needs to be one of [32,64,128,256] but got {groupsize}" quantize_(model, int4_weight_only(group_size=groupsize)) - if "autoround" in quantization: from torchao.prototype.autoround.autoround_llm import quantize_model_with_autoround_ from transformers import AutoTokenizer @@ -265,11 +269,13 @@ def main( ) model.to(device) model.reset_caches() - if "fp6" in quantization: quantize_(model, fpx_weight_only(3, 2)) - if "autoquant" == quantization: - 
model = autoquant(model, manual=True) + if "autoquant" in quantization: + if "autoquant-int4" == quantization: + model = autoquant(model, manual=True, qtensor_class_list = torchao.quantization.DEFAULT_INT4_AUTOQUANT_CLASS_LIST) + else: + model = autoquant(model, manual=True) generate( model, @@ -434,7 +440,11 @@ def callback(x): parser.add_argument('--top_k', type=int, default=200, help='Top-k for sampling.') parser.add_argument('--temperature', type=float, default=0.8, help='Temperature for sampling.') parser.add_argument('--checkpoint_path', type=Path, default=Path("../../../checkpoints/meta-llama/Llama-2-7b-chat-hf/model.pth"), help='Model checkpoint path.') +<<<<<<< HEAD parser.add_argument('-q', '--quantization', type=str, help='Which quantization techniques to apply: int8dq, int8wo, int4wo-, autoquant, autoround-------') +======= + parser.add_argument('-q', '--quantization', type=str, help='Which quantization techniques to apply: int8dq, int8wo, int4wo-, autoquant, autoquant-int4, int4wo--hqq') +>>>>>>> 2ac728d7 (int4 fixes and improvements) parser.add_argument('--kv_cache_quantization', action='store_true', help='Whether to quantize the KV cache') parser.add_argument('--cache_size', type=int, default=None, help='Force size of cache to be a certain number of tokens, if not set, will use max_new_tokens+prompt_size') parser.add_argument('--linear_causal_mask', action='store_true', help='Whether to use the memory efficient, but slightly less fast, linear causal mask (important for long context lengths)') diff --git a/torchao/quantization/GPTQ.py b/torchao/quantization/GPTQ.py index ac7e097fa1..dfe33204a5 100644 --- a/torchao/quantization/GPTQ.py +++ b/torchao/quantization/GPTQ.py @@ -98,7 +98,7 @@ def __init__( self.groupsize = groupsize self.inputs = inputs self.gptq_done = False - self.debug = True + self.debug = False def configure_quantization_mode( self, @@ -790,14 +790,14 @@ def __init__( # TODO: this is the gpt-fast version, merge with the main version later def make_names_and_values_dict_func(q, qparams): - k = q.shape[1] + k = q.shape[1]*2 if not _check_linear_int4_k(k, groupsize): new_k = find_multiple(k, 1024) else: new_k = k # how much we need to pad the weight - delta_k = new_k - q.shape[1] - q = q.to(torch.int32).to(self.device) + delta_k = int((new_k - k)/2) + q = q.to(self.device) final_q = torch.ops.aten._convert_weight_to_int4pack(F.pad(q, pad=(0, delta_k)), inner_k_tiles) scales = qparams[0].to(torch.bfloat16).to(self.device) zeros = qparams[1].to(torch.bfloat16).to(self.device) diff --git a/torchao/quantization/README.md b/torchao/quantization/README.md index 9f6f032133..cbe5252a75 100644 --- a/torchao/quantization/README.md +++ b/torchao/quantization/README.md @@ -7,17 +7,21 @@ Using the lm_eval. 
The models used were meta-llama/Llama-2-7b-chat-hf and meta-l | Model | Technique | wikitext-perplexity | Tokens/Second | Memory Bandwidth (GB/s) | Peak Memory (GB) | Model Size (GB) | | ----------- | ------------------ | ------------------- | ------------- | ----------------------- | ---------------- | --------------- | -| Llama-2-7B | Base (bfloat16) | 12.212 | 105.14 | 1389.35 | 13.88 | 13.21 | -| | int8dq | 12.262 | 9.20 | 60.93 | 8.33 | 6.62 | -| | int8wo | 12.204 | 150.18 | 994.40 | 8.95 | 6.62 | -| | int4wo-64 | 12.843 | 199.86 | 746.66 | 4.50 | 3.74 | -| | int4wo-64-GPTQ | 12.489 | 199.86 | 746.66 | 4.50 | 3.74 | -| | autoquant | 12.204 | 159.22 | 1069.87 | 8.91 | 6.72 | -| Llama-3-8B | Base (bfloat16) | N/A | 94.97 | 1425.55 | 16.43 | 15.01 | -| | int8dq | N/A | 8.44 | 63.45 | 8.98 | 7.52 | -| | int8wo | N/A | 139.76 | 1051.02 | 10.42 | 7.52 | -| | int4wo-64 | N/A | 179.44 | 757.60 | 6.62 | 4.22 | -| | autoquant | N/A | 137.71 | 1037.74 | 11.08 | 7.54 | +| Llama-2-7B | Base (bfloat16) | 12.212 | 107.38 | 1418.93 | 13.88 | 13.21 | +| | int8dq | 12.262 | 9.61 | 63.67 | 8.61 | 6.62 | +| | int8wo | 12.204 | 170.83 | 1131.18 | 8.95 | 6.62 | +| | int4wo-64 | 12.843 | 201.14 | 751.42 | 4.87 | 3.74 | +| | int4wo-64-GPTQ | 12.489 | 201.14 | 751.42 | 4.87 | 3.74 | +| | autoquant | 12.204 | 177.45 | 1194.35 | 8.64 | 6.72 | +| | autoquant-int4hqq | 12.825 | 209.19 | 804.32 | 4.89 | 6.72 | + +| Llama-3-8B | Base (bfloat16) | 7.441 | 95.64 | 1435.54 | 16.43 | 15.01 | +| | int8dq | 7.581 | 8.61 | 64.75 | 9.24 | 7.52 | +| | int8wo | 7.447 | 153.03 | 1150.80 | 10.42 | 7.52 | +| | int4wo-64 | 8.316 | 180.80 | 763.33 | 6.88 | 4.22 | +| | int4wo-64-GPTQ | N/A | 180.80 | 763.33 | 6.88 | 4.22 | +| | autoquant | 7.447 | 158.10 | 1193.24 | 10.04 | 7.55 | +| | autoquant-int4hqq | 8.110 | 188.41 | 800.58 | 7.14 | 4.25 | note: Int8 dynamic quantization works best on compute bound models like [SAM](https://github.com/pytorch-labs/segment-anything-fast) whereas Llama with batchsize=1 tends to be memory bound, thus the rather low performance. @@ -28,7 +32,7 @@ And a quick crash course on inference quantization to help parse the above table ## Autoquantization The `autoquant` api can be used to quickly and accurately quantize your model. When used as in the example below, the api first identifies the shapes -of the activations that the different linear layers see, it then benchmarks these shapes across different types of quantized and non-quantized layers in order to pick the fastest one, attempting to take into account fusions where possible. Finally once the best class is found for each layer, it swaps the linear. Currently this api chooses between no quantization, int8 dynamic quantization and int8 weight only quantization for each layer. +of the activations that the different linear layers see, it then benchmarks these shapes across different types of quantized and non-quantized layers in order to pick the fastest one, attempting to take into account fusions where possible. Finally once the best class is found for each layer, it swaps the linear. Currently this api chooses between no quantization, int8 dynamic quantization and int8 weight only quantization for each layer by default. ```python import torch @@ -46,6 +50,14 @@ model = torchao.autoquant(torch.compile(model, mode='max-autotune')) model(input) ``` +There is also an option to add int4 weight only quantization as an `autoquant` option for maximum performance or if applying int4 quantization without `autoquant` causes a perf regression. 
In such cases, `autoquant` will avoid quantizing the layers that are causing the perf regression. + +```python +from torchao.quantization import DEFAULT_INT4_AUTOQUANT_CLASS_LIST +model = torchao.autoquant(torch.compile(model, mode='max-autotune'), qtensor_class_list=torchao.quantization.DEFAULT_INT4_AUTOQUANT_CLASS_LIST) +model(input) +``` + Sometimes it is desirable to reuse a quantization plan that `autoquant` came up with. `torchao.quantization.AUTOQUANT_CACHE` is a dictionary holding autoquant's benchmark results. We can save it and restore it later, which will cause `autoquant` to choose the same quantization methods. ```python diff --git a/torchao/quantization/__init__.py b/torchao/quantization/__init__.py index 82b76849b7..05c55b255d 100644 --- a/torchao/quantization/__init__.py +++ b/torchao/quantization/__init__.py @@ -21,6 +21,8 @@ "swap_conv2d_1x1_to_linear" "safe_int_mm", "autoquant", + "DEFAULT_AUTOQUANT_CLASS_LIST", + "DEFAULT_INT4_AUTOQUANT_CLASS_LIST", "get_scale", "SmoothFakeDynQuantMixin", "SmoothFakeDynamicallyQuantizedLinear", diff --git a/torchao/quantization/autoquant.py b/torchao/quantization/autoquant.py index 1183c153ef..fb4f055bd1 100644 --- a/torchao/quantization/autoquant.py +++ b/torchao/quantization/autoquant.py @@ -9,7 +9,7 @@ Int8WeightOnlyQuantizedLinearWeight, QuantizedLinearWeightBase, ) -from torchao.dtypes import AffineQuantizedTensor, PlainLayoutType +from torchao.dtypes import AffineQuantizedTensor, PlainLayoutType, TensorCoreTiledLayoutType from torchao.quantization.linear_activation_quantized_tensor import LinearActivationQuantizedTensor from torch.utils._python_dispatch import return_and_correct_aliasing from .quant_primitives import ( @@ -23,6 +23,8 @@ __all__ = [ "AutoQuantizableLinearWeight", "autoquant", + "DEFAULT_AUTOQUANT_CLASS_LIST", + "DEFAULT_INT4_AUTOQUANT_CLASS_LIST", ] @@ -360,7 +362,7 @@ def _autoquant_test(cls, act_mat, weight, bias, best_time, mode=["relu", None]): print(f">>time: {res_f:0.3f}ms for {cls} interpolated, breakeven constant: {max_int_const_win:0.2f}") return res_f -class AQWeightOnlyQuantizedLinearWeight(AffineQuantizedTensor, AQMixin): +class AQInt8WeightOnlyQuantizedLinearWeight(AffineQuantizedTensor, AQMixin): """ AutoQuantizable version of Int8WeightOnlyQuantizedLinearWeight """ @@ -371,10 +373,10 @@ def from_float(cls, weight): eps = torch.finfo(torch.float32).eps zero_point_dtype = torch.int64 block_size = (1, weight.shape[1]) - return super(AQWeightOnlyQuantizedLinearWeight, cls).from_hp_to_intx(weight, mapping_type, block_size, target_dtype, eps=eps, zero_point_dtype=zero_point_dtype) + return super(AQInt8WeightOnlyQuantizedLinearWeight, cls).from_hp_to_intx(weight, mapping_type, block_size, target_dtype, eps=eps, zero_point_dtype=zero_point_dtype) -class AQWeightOnlyQuantizedLinearWeight2(AQWeightOnlyQuantizedLinearWeight, AQMixin): +class AQInt8WeightOnlyQuantizedLinearWeight2(AQInt8WeightOnlyQuantizedLinearWeight, AQMixin): """ AutoQuantizable version of Int8WeightOnlyQuantizedLinearWeight that uses a different kernel @@ -408,7 +410,7 @@ def _autoquant_test(cls, act_mat, *args): return torch.inf return super()._autoquant_test(act_mat, *args) -class AQWeightOnlyQuantizedLinearWeight3(AQWeightOnlyQuantizedLinearWeight, AQMixin): +class AQInt8WeightOnlyQuantizedLinearWeight3(AQInt8WeightOnlyQuantizedLinearWeight, AQMixin): """ AutoQuantizable version of Int8WeightOnlyQuantizedLinearWeight that uses a different kernel @@ -422,6 +424,40 @@ def _quantized_linear_op(act_mat, w_qtensor, bias): y += bias return y + 
+class AQInt4G32WeightOnlyQuantizedLinearWeight(AffineQuantizedTensor, AQMixin): + """ + AutoQuantizable version of Int4WeightOnlyQuantizedLinearWeight + """ + group_size: int = 32 + @classmethod + def from_float(cls, weight): + group_size = cls.group_size + layout_type = TensorCoreTiledLayoutType(inner_k_tiles=8) + + if weight.shape[-1] % group_size != 0: + return weight + use_hqq = True + mapping_type = MappingType.ASYMMETRIC + block_size = (1, group_size) + target_dtype = torch.int32 + quant_min = 0 + quant_max = 15 + eps = 1e-6 + preserve_zero = False + zero_point_dtype = torch.bfloat16 + zero_point_domain = ZeroPointDomain.FLOAT + return super(AQInt4G32WeightOnlyQuantizedLinearWeight, cls).from_float(weight, mapping_type, block_size, target_dtype, quant_min, quant_max, eps, zero_point_dtype=zero_point_dtype, preserve_zero=preserve_zero, zero_point_domain=zero_point_domain, layout_type=layout_type, use_hqq=use_hqq) + +class AQInt4G64WeightOnlyQuantizedLinearWeight(AQInt4G32WeightOnlyQuantizedLinearWeight): + group_size: int = 64 + +class AQInt4G128WeightOnlyQuantizedLinearWeight(AQInt4G32WeightOnlyQuantizedLinearWeight): + group_size: int = 128 + +class AQInt4G256WeightOnlyQuantizedLinearWeight(AQInt4G32WeightOnlyQuantizedLinearWeight): + group_size: int = 256 + class AQFloatLinearWeight(torch.Tensor, AQMixin): """ A class to be used in concert with AutoQuantizableLinearWeight to provide a @@ -441,15 +477,22 @@ def _quantized_linear_op(act_mat, w_qtensor, bias): def from_float(cls, weight): return weight -DEFAULT_CLASS_LIST = [ +# here we don't include int4 quantization in since int8 tends to be a better apples to apples comparison +DEFAULT_AUTOQUANT_CLASS_LIST = [ AQFloatLinearWeight, - AQWeightOnlyQuantizedLinearWeight, - AQWeightOnlyQuantizedLinearWeight2, - # AQWeightOnlyQuantizedLinearWeight3, + AQInt8WeightOnlyQuantizedLinearWeight, + AQInt8WeightOnlyQuantizedLinearWeight2, + # AQInt8WeightOnlyQuantizedLinearWeight3, # TODO this gets picked in places where it makes perf worse, why? AQInt8DynamicallyQuantizedLinearWeight, ] +DEFAULT_INT4_AUTOQUANT_CLASS_LIST = [ + AQFloatLinearWeight, + AQInt8DynamicallyQuantizedLinearWeight, + AQInt4G64WeightOnlyQuantizedLinearWeight +] + def _change_linears_to_autoquantizable(model, **kwargs): """ Converts all linear weight tensors to the @@ -459,7 +502,7 @@ def _change_linears_to_autoquantizable(model, **kwargs): from torchao.quantization.quant_api import _is_linear filter_fn = kwargs.pop("filter_fn", _is_linear) _ = kwargs.pop("error_on_unseen", True) # same kwargs used for this and to_quantized - kwargs["qtensor_class_list"] = kwargs.get("qtensor_class_list", DEFAULT_CLASS_LIST) + kwargs["qtensor_class_list"] = kwargs.get("qtensor_class_list", DEFAULT_AUTOQUANT_CLASS_LIST) kwargs["mode"] = kwargs.get("mode", ["relu", None]) from torchao.quantization.quant_api import _replace_with_custom_fn_if_matches_filter from torchao.quantization.quant_api import _get_subclass_inserter @@ -515,7 +558,7 @@ def _change_autoquantizable_to_quantized(model, supress_autoquant_errors=True, * def autoquant( model, example_input=None, - qtensor_class_list=DEFAULT_CLASS_LIST, + qtensor_class_list=DEFAULT_AUTOQUANT_CLASS_LIST, filter_fn=None, mode=["interpolate", .85], manual=False, @@ -547,7 +590,7 @@ def autoquant( model (torch.nn.Module): The model to be autoquantized. example_input (Any, optional): An example input for the model. If provided, the function performs a forward pass on this input (which fully autoquantizes the model unless manual=True). 
Defaults to None. - qtensor_class_list (list, optional): A list of tensor classes to be used for quantization. Defaults to DEFAULT_CLASS_LIST. + qtensor_class_list (list, optional): A list of tensor classes to be used for quantization. Defaults to DEFAULT_AUTOQUANT_CLASS_LIST. filter_fn (callable, optional): A filter function to apply to the model parameters. Defaults to None. mode (list, optional): A list containing mode settings for quantization. The first element is the mode type (e.g., "interpolate"), and the second element is the mode value (e.g., 0.85). Defaults to ["interpolate", .85]. diff --git a/torchao/quantization/quant_api.py b/torchao/quantization/quant_api.py index 813f97484b..c298dd4ed4 100644 --- a/torchao/quantization/quant_api.py +++ b/torchao/quantization/quant_api.py @@ -419,7 +419,7 @@ def int8_dynamic_activation_int4_weight(group_size=32): return _get_linear_subclass_inserter(apply_int8_dynamic_activation_int4_weight_quant, group_size=group_size) -def int4_weight_only(group_size=128, layout_type=TensorCoreTiledLayoutType(inner_k_tiles=8)): +def int4_weight_only(group_size=128, layout_type=TensorCoreTiledLayoutType(inner_k_tiles=8), use_hqq=False): """ Applies uint4 weight-only asymmetric per-group quantization to linear layers, using "tensor_core_tiled" layout for speedup with tinygemm kernel @@ -436,8 +436,9 @@ def int4_weight_only(group_size=128, layout_type=TensorCoreTiledLayoutType(inner `group_size`: parameter for quantization, controls the granularity of quantization, smaller size is more fine grained, choices are [256, 128, 64, 32] `layout_type`: layout type for quantized tensor, default is `TensorCoreTiledLayoutType(inner_k_tiles=8)` + `use_hqq`: whether to use hqq or default quantization mode, default is False """ - def apply_int4_weight_only_quant(weight, use_hqq=False): + def apply_int4_weight_only_quant(weight): if weight.shape[-1] % group_size != 0: logger.info( f"Skipping quantizing weight with int4 weight only quantization because the shape of weight {weight.shape} is not compatible with group_size {group_size}" @@ -453,7 +454,7 @@ def apply_int4_weight_only_quant(weight, use_hqq=False): preserve_zero = False zero_point_dtype = torch.bfloat16 zero_point_domain = ZeroPointDomain.FLOAT - return to_affine_quantized_intx(weight, mapping_type, block_size, target_dtype, quant_min, quant_max, eps, zero_point_dtype=zero_point_dtype, preserve_zero=preserve_zero, zero_point_domain=zero_point_domain, layout_type=layout_type) + return to_affine_quantized_intx(weight, mapping_type, block_size, target_dtype, quant_min, quant_max, eps, zero_point_dtype=zero_point_dtype, preserve_zero=preserve_zero, zero_point_domain=zero_point_domain, layout_type=layout_type, use_hqq=use_hq) return _get_linear_subclass_inserter(apply_int4_weight_only_quant) diff --git a/torchao/quantization/utils.py b/torchao/quantization/utils.py index ae4f48d9db..df01c579c9 100644 --- a/torchao/quantization/utils.py +++ b/torchao/quantization/utils.py @@ -357,7 +357,7 @@ def groupwise_affine_quantize_tensor_from_qparams( quant_max = 2 ** n_bit - 1 int_data = quantize_affine(w, block_size, scales, zeros, output_dtype, quant_min, quant_max, zero_point_domain = ZeroPointDomain.FLOAT) - if TORCH_VERSION_AT_LEAST_2_5: + if TORCH_VERSION_AT_LEAST_2_5 and w.shape[-1] > 1: int_data_device_type = int_data.device.type # Move to cpu, until issue with MPS memory management of temporary tensors is resolved if int_data_device_type == 'mps': @@ -376,7 +376,8 @@ def groupwise_affine_dequantize_tensor_from_qparams( 
): assert groupsize > 1 assert w_int4x8.dim() == 2 - if TORCH_VERSION_AT_LEAST_2_5: + # need to handle single column case so check for dtype/size from groupwise_affine_quantize_tensor_from_qparams path + if TORCH_VERSION_AT_LEAST_2_5 and (w_int4x8.dtype == torch.uint8 or w_int4x8.shape[-1]>1): data = w_int4x8.to(torch.int32) high_bits = data >> 4 low_bits = data & 0x0F From e368e31d6ffeb617d9f8d76b2222e1b4d2246071 Mon Sep 17 00:00:00 2001 From: HDCharles Date: Wed, 4 Sep 2024 06:22:56 -0700 Subject: [PATCH 2/8] final testing done Summary: Test Plan: Reviewers: Subscribers: Tasks: Tags: --- torchao/quantization/README.md | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/torchao/quantization/README.md b/torchao/quantization/README.md index cbe5252a75..21890e6fe4 100644 --- a/torchao/quantization/README.md +++ b/torchao/quantization/README.md @@ -11,7 +11,7 @@ Using the lm_eval. The models used were meta-llama/Llama-2-7b-chat-hf and meta-l | | int8dq | 12.262 | 9.61 | 63.67 | 8.61 | 6.62 | | | int8wo | 12.204 | 170.83 | 1131.18 | 8.95 | 6.62 | | | int4wo-64 | 12.843 | 201.14 | 751.42 | 4.87 | 3.74 | -| | int4wo-64-GPTQ | 12.489 | 201.14 | 751.42 | 4.87 | 3.74 | +| | int4wo-64-GPTQ | 12.527 | 201.14 | 751.42 | 4.87 | 3.74 | | | autoquant | 12.204 | 177.45 | 1194.35 | 8.64 | 6.72 | | | autoquant-int4hqq | 12.825 | 209.19 | 804.32 | 4.89 | 6.72 | @@ -19,7 +19,7 @@ Using the lm_eval. The models used were meta-llama/Llama-2-7b-chat-hf and meta-l | | int8dq | 7.581 | 8.61 | 64.75 | 9.24 | 7.52 | | | int8wo | 7.447 | 153.03 | 1150.80 | 10.42 | 7.52 | | | int4wo-64 | 8.316 | 180.80 | 763.33 | 6.88 | 4.22 | -| | int4wo-64-GPTQ | N/A | 180.80 | 763.33 | 6.88 | 4.22 | +| | int4wo-64-GPTQ | 7.921 | 180.80 | 763.33 | 6.88 | 4.22 | | | autoquant | 7.447 | 158.10 | 1193.24 | 10.04 | 7.55 | | | autoquant-int4hqq | 8.110 | 188.41 | 800.58 | 7.14 | 4.25 | From 1d352405c1e3994030ed12ad97cb27cb7ad63f55 Mon Sep 17 00:00:00 2001 From: HDCharles Date: Wed, 4 Sep 2024 06:35:43 -0700 Subject: [PATCH 3/8] fixing rebase issue Summary: Test Plan: Reviewers: Subscribers: Tasks: Tags: --- torchao/_models/llama/generate.py | 6 +----- 1 file changed, 1 insertion(+), 5 deletions(-) diff --git a/torchao/_models/llama/generate.py b/torchao/_models/llama/generate.py index df2b0653a3..089f247656 100644 --- a/torchao/_models/llama/generate.py +++ b/torchao/_models/llama/generate.py @@ -440,11 +440,7 @@ def callback(x): parser.add_argument('--top_k', type=int, default=200, help='Top-k for sampling.') parser.add_argument('--temperature', type=float, default=0.8, help='Temperature for sampling.') parser.add_argument('--checkpoint_path', type=Path, default=Path("../../../checkpoints/meta-llama/Llama-2-7b-chat-hf/model.pth"), help='Model checkpoint path.') -<<<<<<< HEAD - parser.add_argument('-q', '--quantization', type=str, help='Which quantization techniques to apply: int8dq, int8wo, int4wo-, autoquant, autoround-------') -======= - parser.add_argument('-q', '--quantization', type=str, help='Which quantization techniques to apply: int8dq, int8wo, int4wo-, autoquant, autoquant-int4, int4wo--hqq') ->>>>>>> 2ac728d7 (int4 fixes and improvements) + parser.add_argument('-q', '--quantization', type=str, help='Which quantization techniques to apply: int8dq, int8wo, int4wo-, autoquant, autoquant-int4, int4wo--hqq, autoround-------') parser.add_argument('--kv_cache_quantization', action='store_true', help='Whether to quantize the KV cache') parser.add_argument('--cache_size', type=int, default=None, help='Force size of 
cache to be a certain number of tokens, if not set, will use max_new_tokens+prompt_size') parser.add_argument('--linear_causal_mask', action='store_true', help='Whether to use the memory efficient, but slightly less fast, linear causal mask (important for long context lengths)') From f9f191b5e78e37426ceabee84c2fb6e765205b6f Mon Sep 17 00:00:00 2001 From: HDCharles Date: Wed, 4 Sep 2024 14:26:35 -0700 Subject: [PATCH 4/8] fixing fpx bug from rebase Summary: Test Plan: Reviewers: Subscribers: Tasks: Tags: --- torchao/dtypes/fpx/fpx.py | 2 +- torchao/quantization/README.md | 18 ++++++++---------- torchao/quantization/autoquant.py | 2 +- 3 files changed, 10 insertions(+), 12 deletions(-) diff --git a/torchao/dtypes/fpx/fpx.py b/torchao/dtypes/fpx/fpx.py index c4f818ea12..9f4c7b065c 100644 --- a/torchao/dtypes/fpx/fpx.py +++ b/torchao/dtypes/fpx/fpx.py @@ -5,7 +5,7 @@ from torch import Tensor from torch.utils._python_dispatch import return_and_correct_aliasing from torchao.prototype.custom_fp_utils import _f32_to_fpx_unpacked, _fpx_unpacked_to_f32, _n_ones -from torchao.ops import quant_llm_linear +# from torchao.ops import quant_llm_linear from torchao.dtypes.utils import ( LayoutType, ) diff --git a/torchao/quantization/README.md b/torchao/quantization/README.md index 21890e6fe4..42aeee7300 100644 --- a/torchao/quantization/README.md +++ b/torchao/quantization/README.md @@ -12,16 +12,14 @@ Using the lm_eval. The models used were meta-llama/Llama-2-7b-chat-hf and meta-l | | int8wo | 12.204 | 170.83 | 1131.18 | 8.95 | 6.62 | | | int4wo-64 | 12.843 | 201.14 | 751.42 | 4.87 | 3.74 | | | int4wo-64-GPTQ | 12.527 | 201.14 | 751.42 | 4.87 | 3.74 | -| | autoquant | 12.204 | 177.45 | 1194.35 | 8.64 | 6.72 | -| | autoquant-int4hqq | 12.825 | 209.19 | 804.32 | 4.89 | 6.72 | - -| Llama-3-8B | Base (bfloat16) | 7.441 | 95.64 | 1435.54 | 16.43 | 15.01 | -| | int8dq | 7.581 | 8.61 | 64.75 | 9.24 | 7.52 | -| | int8wo | 7.447 | 153.03 | 1150.80 | 10.42 | 7.52 | -| | int4wo-64 | 8.316 | 180.80 | 763.33 | 6.88 | 4.22 | -| | int4wo-64-GPTQ | 7.921 | 180.80 | 763.33 | 6.88 | 4.22 | -| | autoquant | 7.447 | 158.10 | 1193.24 | 10.04 | 7.55 | -| | autoquant-int4hqq | 8.110 | 188.41 | 800.58 | 7.14 | 4.25 | +| | autoquant-int4hqq | 12.825 | 209.19 | 804.32 | 4.89 | 3.84 | + +| Llama-3-8B | Base (bfloat16) | 7.441 | 95.64 | 1435.54 | 16.43 | 15.01 | +| | int8dq | 7.581 | 8.61 | 64.75 | 9.24 | 7.52 | +| | int8wo | 7.447 | 153.03 | 1150.80 | 10.42 | 7.52 | +| | int4wo-64 | 8.316 | 180.80 | 763.33 | 6.88 | 4.22 | +| | int4wo-64-GPTQ | 7.921 | 180.80 | 763.33 | 6.88 | 4.22 | +| | autoquant-int4hqq | 8.110 | 188.41 | 800.58 | 7.14 | 4.25 | note: Int8 dynamic quantization works best on compute bound models like [SAM](https://github.com/pytorch-labs/segment-anything-fast) whereas Llama with batchsize=1 tends to be memory bound, thus the rather low performance. 
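For readers who want to try the two int4 paths benchmarked in the table above, here is a minimal sketch (not part of the patch itself) of how they could be invoked, assuming a CUDA device, a bfloat16 model whose linear in-features are divisible by the group size, and the APIs as they appear in this diff (`int4_weight_only(..., use_hqq=...)` and `DEFAULT_INT4_AUTOQUANT_CLASS_LIST`):

```python
# Sketch only: exercises the int4 options added in this patch. Assumes a CUDA
# device and that in_features (1024 here) is divisible by the group size (64).
import copy
import torch
import torchao
from torchao.quantization import DEFAULT_INT4_AUTOQUANT_CLASS_LIST
from torchao.quantization.quant_api import quantize_, int4_weight_only

model = torch.nn.Sequential(torch.nn.Linear(1024, 1024)).cuda().to(torch.bfloat16)
example_input = torch.randn(1, 1024, dtype=torch.bfloat16, device="cuda")

# Option 1: eager int4 weight-only quantization, with hqq choosing the qparams
# (roughly what the int4wo-64-hqq eval runs above do).
m_int4 = copy.deepcopy(model)
quantize_(m_int4, int4_weight_only(group_size=64, use_hqq=True))
m_int4 = torch.compile(m_int4, mode="max-autotune")
m_int4(example_input)

# Option 2: let autoquant pick per layer between the float baseline, int8
# dynamic quantization, and the new int4 (hqq) option
# (roughly what the autoquant-int4 generate runs above do).
m_auto = torchao.autoquant(
    torch.compile(copy.deepcopy(model), mode="max-autotune"),
    qtensor_class_list=DEFAULT_INT4_AUTOQUANT_CLASS_LIST,
)
m_auto(example_input)  # first call benchmarks the seen shapes and finalizes quantization
```

The first option mirrors the `int4wo-64-hqq` eval runs, while the second mirrors the `autoquant-int4` benchmarks, where autoquant is free to fall back to int8 or no quantization on layers where int4 would regress performance.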
diff --git a/torchao/quantization/autoquant.py b/torchao/quantization/autoquant.py index fb4f055bd1..39482caf84 100644 --- a/torchao/quantization/autoquant.py +++ b/torchao/quantization/autoquant.py @@ -447,7 +447,7 @@ def from_float(cls, weight): preserve_zero = False zero_point_dtype = torch.bfloat16 zero_point_domain = ZeroPointDomain.FLOAT - return super(AQInt4G32WeightOnlyQuantizedLinearWeight, cls).from_float(weight, mapping_type, block_size, target_dtype, quant_min, quant_max, eps, zero_point_dtype=zero_point_dtype, preserve_zero=preserve_zero, zero_point_domain=zero_point_domain, layout_type=layout_type, use_hqq=use_hqq) + return super(AQInt4G32WeightOnlyQuantizedLinearWeight, cls).from_hp_to_intx(weight, mapping_type, block_size, target_dtype, quant_min, quant_max, eps, zero_point_dtype=zero_point_dtype, preserve_zero=preserve_zero, zero_point_domain=zero_point_domain, layout_type=layout_type, use_hqq=use_hqq) class AQInt4G64WeightOnlyQuantizedLinearWeight(AQInt4G32WeightOnlyQuantizedLinearWeight): group_size: int = 64 From 03d01adcbcf8f85b0e61bef84957013b89e6902f Mon Sep 17 00:00:00 2001 From: HDCharles Date: Wed, 4 Sep 2024 14:43:25 -0700 Subject: [PATCH 5/8] readme fixes Summary: Test Plan: Reviewers: Subscribers: Tasks: Tags: --- README.md | 4 +--- torchao/quantization/README.md | 20 +++++++++----------- 2 files changed, 10 insertions(+), 14 deletions(-) diff --git a/README.md b/README.md index ebce73b6d6..7257ff27d9 100644 --- a/README.md +++ b/README.md @@ -39,10 +39,8 @@ quantize_(m, int4_weight_only()) ``` For gpt-fast `int4_weight_only()` is the best option at bs=1 as it **2x the tok/s and reduces the VRAM requirements by about 65%** over a torch.compiled baseline. -Note: For models that are less memory bound, the int4 weight only quantization kernel can be slower than other kernels, if you are seeing slowdowns, using [autoquant](./torchao/quantization/README.md#autoquantization) with int4 quantization -can solve the issue. See the [quantization readme](./torchao/quantization/README.md#autoquantization) for details. -If you're unsure which option to use, you can also run [autoquant](./torchao/quantization/README.md#autoquantization) which will automatically profile layers for you and skip quantizing layers where overhead is too large. +If you see slowdowns with any of these techniques or you're unsure which option to use, consider using [autoquant](./torchao/quantization/README.md#autoquantization) which will automatically profile layers and pick the best way to quantize each layer. ```python model = torchao.autoquant(torch.compile(model, mode='max-autotune')) diff --git a/torchao/quantization/README.md b/torchao/quantization/README.md index 42aeee7300..6112abd027 100644 --- a/torchao/quantization/README.md +++ b/torchao/quantization/README.md @@ -30,32 +30,30 @@ And a quick crash course on inference quantization to help parse the above table ## Autoquantization The `autoquant` api can be used to quickly and accurately quantize your model. When used as in the example below, the api first identifies the shapes -of the activations that the different linear layers see, it then benchmarks these shapes across different types of quantized and non-quantized layers in order to pick the fastest one, attempting to take into account fusions where possible. Finally once the best class is found for each layer, it swaps the linear. Currently this api chooses between no quantization, int8 dynamic quantization and int8 weight only quantization for each layer by default. 
+of the activations that the different linear layers see, it then benchmarks these shapes across different types of quantized and non-quantized layers in order to pick the fastest one, attempting to take into account fusions where possible. Finally once the best class is found for each layer, it swaps the linear. By default the api only uses int8 techniques, i.e. it chooses between no quantization, int8 dynamic quantization and int8 weight only quantization for each layer, though there is also an option to add int4 quantization, which can be used for maximum performance or to avoid perf regressions from `int4_weight_only()`.

```python
import torch
import torchao
+from torchao.quantization import DEFAULT_INT4_AUTOQUANT_CLASS_LIST

# Plug in your model and example input
model = torch.nn.Sequential(torch.nn.Linear(32, 64)).cuda().to(torch.bfloat16)
input = torch.randn(32,32, dtype=torch.bfloat16, device='cuda')
+use_autoquant_default = True

-# perform autoquantization and torch.compile
-model = torchao.autoquant(torch.compile(model, mode='max-autotune'))
+if use_autoquant_default:
+    # perform autoquantization and torch.compile with default settings
+    model = torchao.autoquant(torch.compile(model, mode='max-autotune'))
+else:
+    # perform autoquantization and torch.compile with int4 support
+    model = torchao.autoquant(torch.compile(model, mode='max-autotune'), qtensor_class_list=DEFAULT_INT4_AUTOQUANT_CLASS_LIST)

# pass in an input which is used in order to pick fastest quantization operations
# and apply torch compilation.
model(input)
```

-There is also an option to add int4 weight only quantization as an `autoquant` option for maximum performance or if applying int4 quantization without `autoquant` causes a perf regression. In such cases, `autoquant` will avoid quantizing the layers that are causing the perf regression.
-
-```python
-from torchao.quantization import DEFAULT_INT4_AUTOQUANT_CLASS_LIST
-model = torchao.autoquant(torch.compile(model, mode='max-autotune'), qtensor_class_list=torchao.quantization.DEFAULT_INT4_AUTOQUANT_CLASS_LIST)
-model(input)
-```
-
 Sometimes it is desirable to reuse a quantization plan that `autoquant` came up with. `torchao.quantization.AUTOQUANT_CACHE` is a dictionary holding autoquant's benchmark results. We can save it and restore it later, which will cause `autoquant` to choose the same quantization methods.
```python From 8376847771f48aab2f4472fdfc2d561cd2b56f71 Mon Sep 17 00:00:00 2001 From: HDCharles Date: Wed, 4 Sep 2024 14:50:07 -0700 Subject: [PATCH 6/8] test fixes Summary: Test Plan: Reviewers: Subscribers: Tasks: Tags: --- test/integration/test_integration.py | 16 ++++++++-------- 1 file changed, 8 insertions(+), 8 deletions(-) diff --git a/test/integration/test_integration.py b/test/integration/test_integration.py index c031d6e6d1..29f8ec604c 100644 --- a/test/integration/test_integration.py +++ b/test/integration/test_integration.py @@ -68,9 +68,9 @@ ) from torchao.quantization.autoquant import ( AQInt8DynamicallyQuantizedLinearWeight, - AQWeightOnlyQuantizedLinearWeight, - AQWeightOnlyQuantizedLinearWeight2, - AQWeightOnlyQuantizedLinearWeight3, + AQInt8WeightOnlyQuantizedLinearWeight, + AQInt8WeightOnlyQuantizedLinearWeight2, + AQInt8WeightOnlyQuantizedLinearWeight3, AutoQuantizableLinearWeight, ) @@ -727,21 +727,21 @@ def test_aq_int8_dynamic_quant_subclass(self, device, dtype): ) def test_aq_int8_weight_only_quant_subclass(self, device, dtype): self._test_lin_weight_subclass_impl( - AQWeightOnlyQuantizedLinearWeight.from_float, device, 35, test_dtype=dtype + AQInt8WeightOnlyQuantizedLinearWeight.from_float, device, 35, test_dtype=dtype ) @parameterized.expand(COMMON_DEVICE_DTYPE) @unittest.skipIf(not TORCH_VERSION_AT_LEAST_2_5, "autoquant+aqt needs newer pytorch") def test_aq_int8_weight_only_quant_2_subclass(self, device, dtype): self._test_lin_weight_subclass_impl( - AQWeightOnlyQuantizedLinearWeight2.from_float, device, 35, test_dtype=dtype + AQInt8WeightOnlyQuantizedLinearWeight2.from_float, device, 35, test_dtype=dtype ) @parameterized.expand(COMMON_DEVICE_DTYPE) @unittest.skipIf(not TORCH_VERSION_AT_LEAST_2_5, "autoquant+aqt needs newer pytorch") def test_aq_int8_weight_only_quant_3_subclass(self, device, dtype): self._test_lin_weight_subclass_impl( - AQWeightOnlyQuantizedLinearWeight3.from_float, device, 35, test_dtype=dtype + AQInt8WeightOnlyQuantizedLinearWeight3.from_float, device, 35, test_dtype=dtype ) @parameterized.expand(COMMON_DEVICE_DTYPE) @@ -1498,10 +1498,10 @@ def test_get_model_size_autoquant(self, device, dtype): size = torchao.utils.get_model_size_in_bytes(model) from torchao.quantization.autoquant import ( - AQWeightOnlyQuantizedLinearWeight2, + AQInt8WeightOnlyQuantizedLinearWeight2, ) qtensor_class_list = ( - AQWeightOnlyQuantizedLinearWeight2, + AQInt8WeightOnlyQuantizedLinearWeight2, ) mod = torchao.autoquant(torch.compile(model), qtensor_class_list = qtensor_class_list, set_inductor_config=False) mod(example_input) From 07643fe26e7cc0b66a27b0c8cd636f4213ec6ed5 Mon Sep 17 00:00:00 2001 From: HDCharles Date: Wed, 4 Sep 2024 17:29:37 -0700 Subject: [PATCH 7/8] final fixes Summary: Test Plan: Reviewers: Subscribers: Tasks: Tags: --- torchao/quantization/quant_api.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/torchao/quantization/quant_api.py b/torchao/quantization/quant_api.py index c298dd4ed4..b8c8a26fc0 100644 --- a/torchao/quantization/quant_api.py +++ b/torchao/quantization/quant_api.py @@ -454,7 +454,7 @@ def apply_int4_weight_only_quant(weight): preserve_zero = False zero_point_dtype = torch.bfloat16 zero_point_domain = ZeroPointDomain.FLOAT - return to_affine_quantized_intx(weight, mapping_type, block_size, target_dtype, quant_min, quant_max, eps, zero_point_dtype=zero_point_dtype, preserve_zero=preserve_zero, zero_point_domain=zero_point_domain, layout_type=layout_type, use_hqq=use_hq) + return 
to_affine_quantized_intx(weight, mapping_type, block_size, target_dtype, quant_min, quant_max, eps, zero_point_dtype=zero_point_dtype, preserve_zero=preserve_zero, zero_point_domain=zero_point_domain, layout_type=layout_type, use_hqq=use_hqq) return _get_linear_subclass_inserter(apply_int4_weight_only_quant) From dc1a07d30a110156f46fb43b7facb6a70e6d015e Mon Sep 17 00:00:00 2001 From: HDCharles Date: Wed, 4 Sep 2024 19:01:39 -0700 Subject: [PATCH 8/8] delete instead of comment out Summary: Test Plan: Reviewers: Subscribers: Tasks: Tags: --- torchao/dtypes/fpx/fpx.py | 1 - 1 file changed, 1 deletion(-) diff --git a/torchao/dtypes/fpx/fpx.py b/torchao/dtypes/fpx/fpx.py index 9f4c7b065c..6afa22f560 100644 --- a/torchao/dtypes/fpx/fpx.py +++ b/torchao/dtypes/fpx/fpx.py @@ -5,7 +5,6 @@ from torch import Tensor from torch.utils._python_dispatch import return_and_correct_aliasing from torchao.prototype.custom_fp_utils import _f32_to_fpx_unpacked, _fpx_unpacked_to_f32, _n_ones -# from torchao.ops import quant_llm_linear from torchao.dtypes.utils import ( LayoutType, )
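
As a companion to the `use_hqq` plumbing fixed in the quant_api.py hunk of PATCH 7/8 above, a minimal usage sketch follows. It assumes `int4_weight_only` exposes `group_size` and `use_hqq` keywords that reach `to_affine_quantized_intx` through `apply_int4_weight_only_quant` (the exact signature may differ), and it uses an illustrative toy model; this is a sketch, not the canonical invocation.

```python
# Sketch: int4 weight-only quantization with the hqq path enabled.
# Assumption: int4_weight_only() accepts group_size and use_hqq keywords,
# mirroring the use_hqq argument that apply_int4_weight_only_quant forwards
# to to_affine_quantized_intx in the hunk above. group_size=64 matches the
# int4wo-64 configurations in the benchmark table.
import torch
from torchao.quantization.quant_api import quantize_, int4_weight_only

model = torch.nn.Sequential(torch.nn.Linear(1024, 1024)).cuda().to(torch.bfloat16)

# hqq computes the int4 scales/zero points with a half-quadratic solver
# rather than a plain min/max fit; the quantized weight itself stays int4.
quantize_(model, int4_weight_only(group_size=64, use_hqq=True))

out = model(torch.randn(8, 1024, dtype=torch.bfloat16, device="cuda"))
```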
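
Similarly, the quantization README change in PATCH 5/8 describes reusing an `autoquant` plan through `torchao.quantization.AUTOQUANT_CACHE`, but the accompanying snippet sits outside the diff context. A minimal sketch of one way to save and restore it, assuming the cache dictionary is picklable and using an illustrative file name:

```python
# Sketch: persist autoquant's benchmark results so a later run selects the
# same quantization method per layer. Assumes AUTOQUANT_CACHE is picklable;
# "autoquant_cache.pkl" is an illustrative path.
import pickle
import torchao.quantization

# after torchao.autoquant(...) and a forward pass, save the cache
with open("autoquant_cache.pkl", "wb") as f:
    pickle.dump(torchao.quantization.AUTOQUANT_CACHE, f)

# in a later run, restore it before calling torchao.autoquant(...)
with open("autoquant_cache.pkl", "rb") as f:
    torchao.quantization.AUTOQUANT_CACHE.update(pickle.load(f))
```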