int4 fixes and improvements #804

Merged (8 commits) on Sep 5, 2024
README.md: 14 changes (7 additions, 7 deletions)
@@ -35,12 +35,12 @@ from torchao.quantization.quant_api import (
int4_weight_only,
int8_weight_only
)
-quantize_(m, int4_weight_only())
+quantize_(m, int4_weight_only())
```

-For gpt-fast `int4_weight_only()` is the best option at bs=1 as it **2x the tok/s and reduces the VRAM requirements by about 65%** over a torch.compiled baseline
+For gpt-fast `int4_weight_only()` is the best option at bs=1 as it **2x the tok/s and reduces the VRAM requirements by about 65%** over a torch.compiled baseline.
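
A minimal sketch of the flow this paragraph describes, on a toy model (the model, shapes, and input below are illustrative assumptions; the int4 weight-only path generally expects a bfloat16 model on CUDA):

```python
import torch
from torchao.quantization.quant_api import quantize_, int4_weight_only

# Toy stand-in for a real model; int4 weight-only quantizes nn.Linear weights.
m = torch.nn.Sequential(torch.nn.Linear(4096, 4096)).cuda().to(torch.bfloat16)

quantize_(m, int4_weight_only())           # swap Linear weights to int4 in place
m = torch.compile(m, mode='max-autotune')  # compile to get the speedups quoted above

x = torch.randn(1, 4096, dtype=torch.bfloat16, device='cuda')
with torch.no_grad():
    y = m(x)
```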

-If you're unsure which option to use, you can also run [autoquant](./torchao/quantization/README.md#autoquantization) which will automatically profile layers for you and skip quantizing layers where overhead is too large.
+If you see slowdowns with any of these techniques or you're unsure which option to use, consider using [autoquant](./torchao/quantization/README.md#autoquantization) which will automatically profile layers and pick the best way to quantize each layer.

```python
model = torchao.autoquant(torch.compile(model, mode='max-autotune'))
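# (Sketch, assuming the surrounding setup: after wrapping the model with
# autoquant, run it on representative inputs so each layer is profiled and
# quantized. This mirrors `mod(example_input)` in the test changes further
# down; the input shape and dtype here are placeholder assumptions.)
example_input = torch.randn(1, 1024, dtype=torch.bfloat16, device='cuda')
model(example_input)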
@@ -102,7 +102,7 @@ from torchao.prototype.low_bit_optim import AdamW8bit, AdamW4bit, AdamWFp8
optim = AdamW8bit(model.parameters()) # replace with Adam4bit and AdamFp8 for the 4 / fp8 versions
```

-In practice, we are a tiny bit slower than expertly written kernels but the implementations for these optimizers were written in a **few hundred lines of PyTorch code** and compiled so please use them or copy-paste them for your quantized optimizers. Benchmarks [here](https://github.com/pytorch/ao/tree/main/torchao/prototype/low_bit_optim)
+In practice, we are a tiny bit slower than expertly written kernels but the implementations for these optimizers were written in a **few hundred lines of PyTorch code** and compiled so please use them or copy-paste them for your quantized optimizers. Benchmarks [here](https://github.com/pytorch/ao/tree/main/torchao/prototype/low_bit_optim)
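
A hypothetical sketch of the drop-in usage described above, swapping the 8-bit optimizer into an otherwise standard training step (the model, data, and loss below are placeholder assumptions):

```python
import torch
from torchao.prototype.low_bit_optim import AdamW8bit

model = torch.nn.Linear(512, 512).cuda().to(torch.bfloat16)

# Drop-in replacement for torch.optim.AdamW; AdamW4bit / AdamWFp8 follow the same pattern.
optim = AdamW8bit(model.parameters(), lr=1e-4)

for _ in range(10):
    x = torch.randn(8, 512, dtype=torch.bfloat16, device='cuda')
    loss = model(x).float().pow(2).mean()  # placeholder loss
    loss.backward()
    optim.step()
    optim.zero_grad()
```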

We also have support for [single GPU CPU offloading](https://github.com/pytorch/ao/tree/main/torchao/prototype/low_bit_optim#optimizer-cpu-offload) where both the gradients (same size as weights) and the optimizers will be efficiently sent to the CPU. This alone can **reduce your VRAM requirements by 60%**

@@ -114,7 +114,7 @@ optim.load_state_dict(ckpt["optim"])
## Composability

1. `torch.compile`: A key design principle for us is composability as in any new dtype or layout we provide needs to work with our compiler. It shouldn't matter if the kernels are written in pure PyTorch, CUDA, C++, or Triton - things should just work! So we write the dtype, layout, or bit packing logic in pure PyTorch and code-generate efficient kernels.
-3. [FSDP2](https://github.com/pytorch/torchtitan/blob/main/docs/fsdp.md): Historically most quantization has been done for inference, there is now a thriving area of research combining distributed algorithms and quantization.
+3. [FSDP2](https://github.com/pytorch/torchtitan/blob/main/docs/fsdp.md): Historically most quantization has been done for inference, there is now a thriving area of research combining distributed algorithms and quantization.

The best example we have combining the composability of lower bit dtype with compile and fsdp is [NF4](torchao/dtypes/nf4tensor.py) which we used to implement the [QLoRA](https://www.youtube.com/watch?v=UvRl4ansfCg) algorithm. So if you're doing research at the intersection of this area we'd love to hear from you.

@@ -135,7 +135,7 @@ Things we're excited about but need more time to cook in the oven

1. [MX](torchao/prototype/mx_formats) training and inference support with tensors using the [OCP MX spec](https://www.opencompute.org/documents/ocp-microscaling-formats-mx-v1-0-spec-final-pdf) data types, which can be described as groupwise scaled float8/float6/float4/int8, with the scales being constrained to powers of two. This work is prototype as the hardware support is not available yet.
2. [Int8 Quantized Training](https://github.com/pytorch/ao/tree/main/torchao/prototype/quantized_training): We're trying out full int8 training. This is easy to use with `quantize_(model, int8_weight_only_quantized_training())`. This work is prototype as the memory benchmarks are not compelling yet.
-3. [IntX](https://github.com/pytorch/ao/tree/main/torchao/dtypes/uintx): We've managed to support all the ints by doing some clever bitpacking in pure PyTorch and then compiling it. This work is prototype as unfortunately without some more investment in either the compiler or low-bit kernels, int4 is more compelling than any smaller dtype
+3. [IntX](https://github.com/pytorch/ao/tree/main/torchao/dtypes/uintx): We've managed to support all the ints by doing some clever bitpacking in pure PyTorch and then compiling it. This work is prototype as unfortunately without some more investment in either the compiler or low-bit kernels, int4 is more compelling than any smaller dtype
4. [Bitnet](https://github.com/pytorch/ao/blob/main/torchao/prototype/dtypes/bitnet.py): Mostly this is very cool to people on the team. This is prototype because how useful these kernels are is highly dependent on better hardware and kernel support.

## Installation
@@ -169,7 +169,7 @@ USE_CPP=0 pip install -e .
We're also fortunate to be integrated into some of the leading open-source libraries including
1. Hugging Face transformers with a [builtin inference backend](https://huggingface.co/docs/transformers/main/quantization/torchao) and [low bit optimizers](https://github.com/huggingface/transformers/pull/31865)
2. Hugging Face diffusers with a minimal example thanks to [Sayak Paul](https://www.linkedin.com/posts/sayak-paul_want-to-combine-quantization-and-benefit-activity-7231950868605022208-g52d?utm_source=share&utm_medium=member_desktop)
-3. Mobius HQQ backend leveraged our int4 kernels to get [195 tok/s on a 4090](https://github.com/mobiusml/hqq#faster-inference)
+3. Mobius HQQ backend leveraged our int4 kernels to get [195 tok/s on a 4090](https://github.com/mobiusml/hqq#faster-inference)

## Videos
* [Slaying OOMs at the Mastering LLM's course](https://www.youtube.com/watch?v=UvRl4ansfCg)
test/integration/test_integration.py: 16 changes (8 additions, 8 deletions)
@@ -68,9 +68,9 @@
)
from torchao.quantization.autoquant import (
AQInt8DynamicallyQuantizedLinearWeight,
-AQWeightOnlyQuantizedLinearWeight,
-AQWeightOnlyQuantizedLinearWeight2,
-AQWeightOnlyQuantizedLinearWeight3,
+AQInt8WeightOnlyQuantizedLinearWeight,
+AQInt8WeightOnlyQuantizedLinearWeight2,
+AQInt8WeightOnlyQuantizedLinearWeight3,
AutoQuantizableLinearWeight,

)
@@ -727,21 +727,21 @@ def test_aq_int8_dynamic_quant_subclass(self, device, dtype):
)
def test_aq_int8_weight_only_quant_subclass(self, device, dtype):
self._test_lin_weight_subclass_impl(
-AQWeightOnlyQuantizedLinearWeight.from_float, device, 35, test_dtype=dtype
+AQInt8WeightOnlyQuantizedLinearWeight.from_float, device, 35, test_dtype=dtype
)

@parameterized.expand(COMMON_DEVICE_DTYPE)
@unittest.skipIf(not TORCH_VERSION_AT_LEAST_2_5, "autoquant+aqt needs newer pytorch")
def test_aq_int8_weight_only_quant_2_subclass(self, device, dtype):
self._test_lin_weight_subclass_impl(
-AQWeightOnlyQuantizedLinearWeight2.from_float, device, 35, test_dtype=dtype
+AQInt8WeightOnlyQuantizedLinearWeight2.from_float, device, 35, test_dtype=dtype
)

@parameterized.expand(COMMON_DEVICE_DTYPE)
@unittest.skipIf(not TORCH_VERSION_AT_LEAST_2_5, "autoquant+aqt needs newer pytorch")
def test_aq_int8_weight_only_quant_3_subclass(self, device, dtype):
self._test_lin_weight_subclass_impl(
-AQWeightOnlyQuantizedLinearWeight3.from_float, device, 35, test_dtype=dtype
+AQInt8WeightOnlyQuantizedLinearWeight3.from_float, device, 35, test_dtype=dtype
)

@parameterized.expand(COMMON_DEVICE_DTYPE)
@@ -1498,10 +1498,10 @@ def test_get_model_size_autoquant(self, device, dtype):
size = torchao.utils.get_model_size_in_bytes(model)

from torchao.quantization.autoquant import (
-AQWeightOnlyQuantizedLinearWeight2,
+AQInt8WeightOnlyQuantizedLinearWeight2,
)
qtensor_class_list = (
-AQWeightOnlyQuantizedLinearWeight2,
+AQInt8WeightOnlyQuantizedLinearWeight2,
)
mod = torchao.autoquant(torch.compile(model), qtensor_class_list = qtensor_class_list, set_inductor_config=False)
mod(example_input)
torchao/_models/llama/README.md: 2 changes (1 addition, 1 deletion)
@@ -2,7 +2,7 @@

The llama folder contains code/scripts for stable benchmarking llama models.

-To get model weights, go to https://huggingface.co/meta-llama/Llama-2-7b and/or https://huggingface.co/meta-llama/Meta-Llama-3-8B
+To get model weights, go to https://huggingface.co/meta-llama/Llama-2-7b, https://huggingface.co/meta-llama/Meta-Llama-3-8B, https://huggingface.co/meta-llama/Meta-Llama-3.1-8B
and follow the steps to gain access.

Then from the torchao root directory use `huggingface-cli login` and follow the steps to login, then `sh ./scripts/prepare.sh` to