Update QAT READMEs using new APIs
Add references to new QAT APIs including `quantize_`,
`FakeQuantizedX`, and the new embedding Quantizers and
ComposableQATQuantizer. Also link to new QAT + LoRA recipe
in torchtune.

ghstack-source-id: 0755ab8dda73a26df42307e298a4bde2aefacbbb
Pull Request resolved: #1541
andrewor14 committed Jan 10, 2025
1 parent 1a11857 commit c210c7a
Showing 2 changed files with 135 additions and 44 deletions.
33 changes: 22 additions & 11 deletions README.md
We've added kv cache quantization and other features in order to enable long context length inference.

In practice, these features alongside int4 weight-only quantization allow us to **reduce peak memory by ~55%**, meaning we can run Llama3.1-8B inference with a **130k context length and only 18.9 GB of peak memory.** More details can be found [here](torchao/_models/llama/README.md).
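
The int4 weight-only piece of that setup can be applied with the `quantize_` API; kv cache quantization itself is wired into the llama example scripts linked above, so the snippet below is only a minimal sketch (the `MyTransformer` model is a hypothetical placeholder):

```python
import torch
from torchao.quantization import quantize_, int4_weight_only

# Hypothetical model: any torch.nn.Module with linear layers.
# int4 weight-only targets the tinygemm kernel, which expects bfloat16 on CUDA.
model = MyTransformer().to(device="cuda", dtype=torch.bfloat16)

# Swap linear weights for int4 weight-only quantized tensors
quantize_(model, int4_weight_only(group_size=128))
```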

## Training

### Quantization Aware Training

Post-training quantization can result in a fast and compact model, but it may also lead to accuracy degradation. We recommend exploring Quantization Aware Training (QAT) to overcome this limitation. In collaboration with Torchtune, we've developed a QAT recipe that demonstrates significant accuracy improvements over traditional PTQ, recovering **96% of the accuracy degradation on hellaswag and 68% of the perplexity degradation on wikitext** for Llama3. The full recipe is described [here](https://pytorch.org/blog/quantization-aware-training/).

```python
import torch

from torchao.quantization import (
    quantize_,
    int8_dynamic_activation_int4_weight,
)
from torchao.quantization.qat import (
    FakeQuantizeConfig,
    from_intx_quantization_aware_training,
    intx_quantization_aware_training,
)

# Insert "fake quantize" operations into linear layers.
# These operations simulate quantization numerics during
# training without performing any dtype casting.
# `my_model` can be any torch.nn.Module containing linear layers.
activation_config = FakeQuantizeConfig(torch.int8, "per_token", is_symmetric=False)
weight_config = FakeQuantizeConfig(torch.int4, group_size=32)
quantize_(
    my_model,
    intx_quantization_aware_training(activation_config, weight_config),
)

# Run training... (not shown)

# Convert fake quantization to actual quantized operations
quantize_(my_model, from_intx_quantization_aware_training())
quantize_(my_model, int8_dynamic_activation_int4_weight(group_size=32))
```

### Float8

[torchao.float8](torchao/float8) implements training recipes with the scaled float8 dtypes, as laid out in https://arxiv.org/abs/2209.05433.
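
As a minimal sketch of how a model can be converted for float8 training (assuming the `convert_to_float8_training` entry point; see [torchao.float8](torchao/float8) for the full set of recipes and options, and note that `MyTransformer` is a hypothetical placeholder):

```python
import torch
from torchao.float8 import convert_to_float8_training

# Hypothetical model: any torch.nn.Module whose nn.Linear layers should train in float8
model = MyTransformer().to(device="cuda", dtype=torch.bfloat16)

# Swap nn.Linear modules for their float8 training variants;
# the forward/backward matmuls then run in scaled float8
convert_to_float8_training(model)

# Training proceeds as usual after this point
```
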
146 changes: 113 additions & 33 deletions torchao/quantization/qat/README.md
x_fq = (x_float / scale + zp).round().clamp(qmin, qmax)
x_fq = (x_fq - zp) * scale
```
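
As a quick numeric illustration of the two formulas above (the values are chosen arbitrarily for demonstration and are not torchao's internal code):

```python
import torch

x_float = torch.tensor([0.23, -1.70, 0.91])
scale, zp = 0.015, 0        # example scale and zero point
qmin, qmax = -128, 127      # int8 range

x_fq = (x_float / scale + zp).round().clamp(qmin, qmax)
x_fq = (x_fq - zp) * scale  # tensor([ 0.2250, -1.6950,  0.9150]), snapped to the int8 grid
```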

QAT typically involves applying a transformation to your model before and after training.
In torchao, these are represented as the prepare and convert steps: (1) prepare inserts
fake quantize operations into linear layers, and (2) convert transforms the fake quantize
operations into actual quantized operations after training.
Between these two steps, training can proceed exactly as before.

![qat](images/qat_diagram.png)

## API

torchao currently supports two QAT APIs, one through the [`quantize_`](https://pytorch.org/ao/stable/generated/torchao.quantization.quantize_.html#torchao.quantization.quantize_)
API (recommended) and one through the Quantizer classes (legacy). The `quantize_` API
allows flexible configuration of quantization settings for both activations and weights,
while the Quantizer classes each hardcode a specific quantization setting.

Here's an example of running QAT with the following quantization settings on a single GPU:
- int8 per-token dynamic asymmetric activations (for linears)
- int4 per-group symmetric weights (for linears)

```python
import torch
from torchtune.models.llama3 import llama3

# Set up a smaller version of llama3 to fit in a single GPU
model = llama3(
    vocab_size=4096,
    num_layers=16,
    # ... (other llama3 args not shown) ...
    max_seq_len=2048,
).cuda()

# Example training loop
def train(m: torch.nn.Module):
    optimizer = torch.optim.SGD(m.parameters(), lr=0.001, momentum=0.9, weight_decay=1e-5)
    loss_fn = torch.nn.CrossEntropyLoss()
    for i in range(10):
        example = torch.randint(0, 4096, (2, 16)).cuda()
        target = torch.randn((2, 16, 4096)).cuda()
        output = m(example)
        loss = loss_fn(output, target)
        loss.backward()
        optimizer.step()
        optimizer.zero_grad()
```

### quantize_

```python
from torchao.quantization import (
    quantize_,
    int8_dynamic_activation_int4_weight,
)
from torchao.quantization.qat import (
    FakeQuantizeConfig,
    from_intx_quantization_aware_training,
    intx_quantization_aware_training,
)

# prepare: insert fake quantization ops
# Model consists of `FakeQuantizedLinear` afterwards
activation_config = FakeQuantizeConfig(torch.int8, "per_token", is_symmetric=False)
weight_config = FakeQuantizeConfig(torch.int4, group_size=32)
quantize_(
    model,
    intx_quantization_aware_training(activation_config, weight_config),
)

# train (not shown)

# convert: transform fake quantization ops into actual quantized ops
# Model consists of `torch.nn.Linear` with quantized activation and weight tensors afterwards
quantize_(model, from_intx_quantization_aware_training())
quantize_(model, int8_dynamic_activation_int4_weight(group_size=32))

# inference or generate
```

To fake quantize embedding layers in addition to linear layers, additionally call `quantize_`
during the prepare step with a filter function that selects the embedding modules:

```python
quantize_(
    model,
    intx_quantization_aware_training(weight_config=weight_config),
    filter_fn=lambda m, _: isinstance(m, torch.nn.Embedding),
)
```

### Quantizer

```python
from torchao.quantization.qat import Int8DynActInt4WeightQATQuantizer

qat_quantizer = Int8DynActInt4WeightQATQuantizer(groupsize=32)

# prepare: insert fake quantization ops
# Model consists of `Int8DynActInt4WeightQATLinear` afterwards
model = qat_quantizer.prepare(model)

# train (not shown)

# convert: transform fake quantization ops into actual quantized ops
# Model consists of `Int8DynActInt4WeightLinear` afterwards, with the same
# structure as a model quantized through the corresponding PTQ flow
# (`Int8DynActInt4WeightQuantizer`)
model = qat_quantizer.convert(model)

# inference or generate
```

torchao currently supports the following Quantizers:
- Linear: [Int8DynActInt4WeightQATQuantizer](https://github.com/pytorch/ao/blob/v0.7.0/torchao/quantization/qat/linear.py#L126), targeting int8 per-token dynamic asymmetric activation + int4 per-group symmetric weight
- Linear: [Int4WeightOnlyQATQuantizer](https://github.com/pytorch/ao/blob/v0.7.0/torchao/quantization/qat/linear.py#L308), targeting int4 per-group asymmetric weight (using the efficient [int4 tinygemm kernel](https://github.com/pytorch/pytorch/blob/a672f6c84e318bbf455f13dfdd3fd7c68a388bf5/aten/src/ATen/native/cuda/int4mm.cu#L1097) after training)
- Embedding: [Int4WeightOnlyEmbeddingQATQuantizer](https://github.com/pytorch/ao/blob/v0.7.0/torchao/quantization/qat/embedding.py#L94), targeting int4 per-group symmetric weight
- [ComposableQATQuantizer](https://github.com/pytorch/ao/blob/v0.7.0/torchao/quantization/qat/api.py#L242), allowing users to compose multiple Quantizers (one per layer type), for example:

```python
from torchao.quantization.qat import (
    ComposableQATQuantizer,
    Int4WeightOnlyEmbeddingQATQuantizer,
    Int8DynActInt4WeightQATQuantizer,
)

group_size = 32
quantizer = ComposableQATQuantizer([
    Int8DynActInt4WeightQATQuantizer(groupsize=group_size),
    Int4WeightOnlyEmbeddingQATQuantizer(group_size=group_size),
])

# prepare + train + convert as before
```

## torchtune integration

Users can also leverage our integration with [torchtune](https://github.com/pytorch/torchtune)
and apply quantization-aware fine-tuning as follows:

```
tune run --nproc_per_node 8 qat_distributed --config llama3/8B_qat_full
```

torchtune also supports a QAT + LoRA distributed training recipe that is 1.89x faster
and uses 36.1% less memory than vanilla QAT in our early experiments. You can read
more about it [here](https://dev-discuss.pytorch.org/t/speeding-up-qat-by-1-89x-with-lora/2700).

```
tune run --nnodes 1 --nproc_per_node 4 qat_lora_finetune_distributed --config llama3/8B_qat_lora
```

For more detail, please refer to [this QAT tutorial](https://pytorch.org/torchtune/main/tutorials/qat_finetune.html).

## Evaluation Results

