Skip to content

Commit

Permalink
Merge branch 'pytorch:main' into awq
Browse files Browse the repository at this point in the history
  • Loading branch information
vayuda authored Oct 3, 2024
2 parents 7314d99 + 09b8b3c commit dc0c507
Show file tree
Hide file tree
Showing 11 changed files with 511 additions and 81 deletions.
54 changes: 33 additions & 21 deletions test/integration/test_integration.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,6 +73,7 @@
AQInt8WeightOnlyQuantizedLinearWeight3,
AutoQuantizableLinearWeight,
AQFloat8WeightOnlyQuantizedLinearWeight,
AQFloat8PerRowScalingDynamicallyQuantizedLinearWeight,
)
from torch.ao.quantization.quantize_fx import convert_to_reference_fx, prepare_fx
import os
Expand Down Expand Up @@ -677,27 +678,28 @@ def _test_lin_weight_subclass_impl(
):
if not "cuda" in test_device:
self.skipTest("test requires cuda")
m, k, n = test_shape
x = torch.randn(m, k, device=test_device, dtype=test_dtype)
lin = torch.nn.Linear(k, n, device=test_device).to(test_dtype)
ref_f = lin(x)

lin.weight = torch.nn.Parameter(
test_subclass_from_float(lin.weight), requires_grad=False
)
test = lin(x)
self.assertGreater(
SQNR(ref_f, test),
min_sqnr,
f"{lin.weight.__class__.__name__} failed, no compile, dtype={test_dtype}, (m, k, n)={test_shape}"
)
lin_comp = torch.compile(lin, mode='max-autotune')
test_comp = lin_comp(x)
self.assertGreater(
SQNR(ref_f, test_comp),
min_sqnr,
f"{lin.weight.__class__.__name__} failed at compile with dtype={test_dtype}, (m, k, n)={test_shape}"
)
with torch.no_grad():
m, k, n = test_shape
x = torch.randn(m, k, device=test_device, dtype=test_dtype)
lin = torch.nn.Linear(k, n, device=test_device).to(test_dtype)
ref_f = lin(x)

lin.weight = torch.nn.Parameter(
test_subclass_from_float(lin.weight), requires_grad=False
)
test = lin(x)
self.assertGreater(
SQNR(ref_f, test),
min_sqnr,
f"{lin.weight.__class__.__name__} failed, no compile, dtype={test_dtype}, (m, k, n)={test_shape}"
)
lin_comp = torch.compile(lin, mode='max-autotune')
test_comp = lin_comp(x)
self.assertGreater(
SQNR(ref_f, test_comp),
min_sqnr,
f"{lin.weight.__class__.__name__} failed at compile with dtype={test_dtype}, (m, k, n)={test_shape}"
)

@parameterized.expand(COMMON_DEVICE_DTYPE)
@unittest.skipIf(TORCH_VERSION_AT_LEAST_2_4, "skip because there is some bug in inductor codegen")
Expand Down Expand Up @@ -753,6 +755,16 @@ def test_aq_float8_weight_only_quant_subclass(self, device, dtype):
AQFloat8WeightOnlyQuantizedLinearWeight.from_float, device, 30, test_dtype=dtype
)

@parameterized.expand(COMMON_DEVICE_DTYPE)
@unittest.skipIf(not TORCH_VERSION_AT_LEAST_2_5, "autoquant+aqt needs newer pytorch")
@unittest.skipIf(not is_H100, "Need H100 to run")
def test_aq_float8_dynamic_quant_subclass(self, device, dtype):
if dtype != torch.bfloat16:
self.skipTest("Fails for {dtype}")
self._test_lin_weight_subclass_impl(
AQFloat8PerRowScalingDynamicallyQuantizedLinearWeight.from_float, device, 25, test_dtype=dtype
)

@parameterized.expand(COMMON_DEVICE_DTYPE)
@unittest.skipIf(not TORCH_VERSION_AT_LEAST_2_3, "int4 requires torch nightly.")
# @unittest.skipIf(TORCH_VERSION_AT_LEAST_2_5, "int4 skipping 2.5+ for now")
Expand Down
4 changes: 3 additions & 1 deletion torchao/_models/llama/generate.py
Original file line number Diff line number Diff line change
Expand Up @@ -295,6 +295,8 @@ def hqqint4(weight):
if "autoquant" in quantization:
if "autoquant-int4" == quantization:
model = autoquant(model, manual=True, qtensor_class_list = torchao.quantization.DEFAULT_INT4_AUTOQUANT_CLASS_LIST)
elif "autoquant-float8" == quantization:
model = autoquant(model, manual=True, qtensor_class_list = torchao.quantization.OTHER_AUTOQUANT_CLASS_LIST)
else:
model = autoquant(model, manual=True)

Expand Down Expand Up @@ -464,7 +466,7 @@ def callback(x):
parser.add_argument('-q', '--quantization', type=str,
help=(
'Which quantization techniques to apply: int8dq, int8wo, fp6, int4wo-<groupsize>, int4wo-<groupsize>-hqq, autoquant, '
+'autoquant-int4, uintx-<nbits>-<groupsize>, uintx-<nbits>-<groupsize>-hqq, sparse-marlin'
+'autoquant-int4, autoquant-float8, uintx-<nbits>-<groupsize>, uintx-<nbits>-<groupsize>-hqq, sparse-marlin'
)
)
parser.add_argument("--calibration_limit", type=int, default=10, help="Number of calibration examples")
Expand Down
2 changes: 2 additions & 0 deletions torchao/float8/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -116,6 +116,8 @@ We compose with the `DTensor` based [distributed APIs](https://pytorch.org/docs/
such as FSDP, TP and SP. Please see the [torchtitan](https://github.com/pytorch/torchtitan) repository for e2e examples
on using `torchao.float8` in a distributed setting.

:warning: <em>When using FSDP, it's recommended to enable `config.force_recompute_fp8_weight_in_bwd`, which prevents the un-sharded fp8 weights to be saved for backward. If you are using customized activation checkpoiting, you may ignore this config and handle the recomputation of fp8 weights in the customized AC code. </em>

# Performance

A common question about float8 training is "when is float8 linear faster vs bfloat16?". Given the M, K, N of the forward pass through your linear, you can reference the table below for a microbenchmark based speedup estimate on NVIDIA H100:
Expand Down
21 changes: 19 additions & 2 deletions torchao/float8/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,8 +37,10 @@ class CastConfig:

def __post_init__(self):
if self.scaling_type is ScalingType.STATIC:
assert self.static_scale is not None, \
"static_scale must be specified for static scaling"
assert (
self.static_scale is not None
), "static_scale must be specified for static scaling"


@dataclass(frozen=True)
class DelayedScalingConfig:
Expand Down Expand Up @@ -132,6 +134,21 @@ class Float8LinearConfig:
# configuration, this field may move to per-tensor configs.
delayed_scaling_config: DelayedScalingConfig = DelayedScalingConfig()

# If the option is enabled, fp8_weight will always be re-computed in backward.
# It's recommended to enable this flag when using FSDP.
# Otherwise, the entire fp8_weight, instead of the sharded weight may be saved.
# If using outer activation checkpointing context or SAC, you may disable this option
# and handle the recomputation of fp8 weight in your customized AC context.
#
# Details:
# When using float8 training with FSDP, the original weight is sharded; fp8_weight (in forward) and fp8_weight_transpose (in backward) are used by the model.
# However, when partitioning the forward_backward graph, torch.compile may decide to
# save the fp8_weight_transpose for backward, which is an un-sahrded weight and costs a high memory utilization.
# The longer-term solution is to let compile decide how to partition the graph with optimal computation and memory savings.
# For now, we use the checkpointing api to force the recomputation of fp8 weight in backward.

force_recompute_fp8_weight_in_bwd: bool = False


# If True, use 'fnuz' float8 types for calculations.
# Currently, ROCm only supports fnuz variants.
Expand Down
131 changes: 79 additions & 52 deletions torchao/float8/float8_linear.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,11 +9,14 @@

import dataclasses
import enum
import logging

from typing import Optional

import torch

import torch.utils.checkpoint as checkpoint

from torchao.float8.config import Float8LinearConfig, ScalingType

from torchao.float8.float8_scaling_utils import (
Expand All @@ -29,18 +32,26 @@
from torchao.float8.float8_tensor import (
Float8Tensor,
GemmInputRole,
hp_tensor_and_scale_to_float8,
LinearMMConfig,
ScaledMMConfig,
)

from torchao.float8.float8_utils import e4m3_dtype, e5m2_dtype, tensor_to_amax
from torchao.float8.float8_utils import (
e4m3_dtype,
e5m2_dtype,
tensor_to_amax,
tensor_to_scale,
)

from torchao.float8.fsdp_utils import (
WeightWithDelayedFloat8CastTensor,
WeightWithDynamicFloat8CastTensor,
WeightWithStaticFloat8CastTensor,
)

logger = logging.getLogger(__name__)


# this code was resurrected from https://github.com/pytorch-labs/torchao.float8/pull/128/files
@torch._dynamo.allow_in_graph
Expand Down Expand Up @@ -180,6 +191,15 @@ def __init__(self, *args, **kwargs):
# would be initialized in every iteration.
self.enable_pre_and_post_forward = self.config.enable_pre_and_post_forward

# See the comments in config.py for more details of this option.
if (
self.config.enable_pre_and_post_forward
and not self.config.force_recompute_fp8_weight_in_bwd
):
logger.warning(
"When using FSDP, it's recommended to enable config.force_recompute_fp8_weight_in_bwd."
)

def create_buffers(self):
# Default values for history buffers, see above TODO
history_len = self.config.delayed_scaling_config.history_len
Expand Down Expand Up @@ -226,17 +246,17 @@ def create_buffers(self):

if self.config.cast_config_input.static_scale is not None:
self.register_always_float32_buffer(
"fp8_static_scale_input",
"fp8_static_scale_input",
self.config.cast_config_input.static_scale.to(device),
)
if self.config.cast_config_weight.static_scale is not None:
self.register_always_float32_buffer(
"fp8_static_scale_weight",
"fp8_static_scale_weight",
self.config.cast_config_weight.static_scale.to(device),
)
if self.config.cast_config_grad_output.static_scale is not None:
self.register_always_float32_buffer(
"fp8_static_scale_grad_output",
"fp8_static_scale_grad_output",
self.config.cast_config_grad_output.static_scale.to(device),
)

Expand Down Expand Up @@ -296,56 +316,48 @@ def cast_input_to_float8(
input_fp8 = hp_tensor_to_float8_static(
input, self.fp8_static_scale_input, e4m3_dtype, self.linear_mm_config
)

return input_fp8

def cast_weight_to_float8(
self, weight: torch.Tensor, is_amax_initialized: bool
) -> torch.Tensor:
def get_weight_scale(self, weight: torch.Tensor) -> Optional[torch.Tensor]:
if isinstance(weight, Float8Tensor):
return None
if self.scaling_type_weight is ScalingType.DELAYED:
if isinstance(self.weight, Float8Tensor): # cast by FSDP
weight_fp8 = self.weight
else:
scale_fn_name = self.config.delayed_scaling_config.scale_fn_name
_maybe_initialize_amaxes_scales_for_float8_cast(
weight,
self.fp8_amax_weight,
self.fp8_amax_history_weight,
self.fp8_scale_weight,
scale_fn_name,
e4m3_dtype,
is_amax_initialized,
reduce_amax=False,
)

weight_fp8 = hp_tensor_to_float8_delayed(
weight,
self.fp8_scale_weight,
e4m3_dtype,
self.fp8_amax_weight,
linear_mm_config=self.linear_mm_config,
gemm_input_role=GemmInputRole.WEIGHT,
)
scale_fn_name = self.config.delayed_scaling_config.scale_fn_name
_maybe_initialize_amaxes_scales_for_float8_cast(
weight,
self.fp8_amax_weight,
self.fp8_amax_history_weight,
self.fp8_scale_weight,
scale_fn_name,
e4m3_dtype,
self.is_amax_initialized,
reduce_amax=True,
)
self.fp8_amax_weight.fill_(tensor_to_amax(weight))
return self.fp8_scale_weight
elif self.scaling_type_weight is ScalingType.DYNAMIC:
if isinstance(self.weight, Float8Tensor): # cast by FSDP
weight_fp8 = self.weight
else:
weight_fp8 = hp_tensor_to_float8_dynamic(
self.weight,
e4m3_dtype,
self.linear_mm_config,
gemm_input_role=GemmInputRole.WEIGHT,
)
return tensor_to_scale(weight, e4m3_dtype)
else:
assert self.scaling_type_weight is ScalingType.STATIC
weight_fp8 = hp_tensor_to_float8_static(
self.weight,
self.fp8_static_scale_weight,
e4m3_dtype,
self.linear_mm_config,
gemm_input_role=GemmInputRole.WEIGHT,
)
return weight_fp8
return self.fp8_static_scale_weight

def cast_weight_to_float8_t(
self,
weight: torch.Tensor,
is_amax_initialized: bool,
weight_scale: Optional[torch.Tensor] = None,
) -> torch.Tensor:
if isinstance(weight, Float8Tensor):
return weight.t()
weight_fp8 = hp_tensor_and_scale_to_float8(
weight,
weight_scale,
e4m3_dtype,
self.linear_mm_config,
gemm_input_role=GemmInputRole.WEIGHT,
)
return weight_fp8.t()

def cast_output_to_float8_in_bw(self, output: torch.Tensor) -> torch.Tensor:
if self.scaling_type_grad_output is ScalingType.DELAYED:
Expand All @@ -364,8 +376,8 @@ def cast_output_to_float8_in_bw(self, output: torch.Tensor) -> torch.Tensor:
else:
assert self.scaling_type_grad_output is ScalingType.STATIC
output = NoopFwToFloat8E5M2BwStatic.apply(
output,
self.fp8_static_scale_grad_output,
output,
self.fp8_static_scale_grad_output,
self.linear_mm_config,
)
return output
Expand Down Expand Up @@ -396,9 +408,24 @@ def forward(self, input: torch.Tensor) -> torch.Tensor:
self.float8_pre_forward(input)

input_fp8 = self.cast_input_to_float8(input, self.is_amax_initialized)
weight_fp8 = self.cast_weight_to_float8(self.weight, self.is_amax_initialized)

output = manual_float8_matmul.apply(input_fp8, weight_fp8.t())
# If force_recompute_fp8_weight_in_bwd, we only recompute the fp8 weight,
# weight_scale should be saved.
weight_scale = self.get_weight_scale(self.weight)

if self.config.force_recompute_fp8_weight_in_bwd:
weight_fp8_t = checkpoint.checkpoint(
self.cast_weight_to_float8_t,
self.weight,
self.is_amax_initialized,
weight_scale,
)
else:
weight_fp8_t = self.cast_weight_to_float8_t(
self.weight, self.is_amax_initialized, weight_scale
)

output = manual_float8_matmul.apply(input_fp8, weight_fp8_t)

# Cast grad_output to float8_e5m2 during backward
output = self.cast_output_to_float8_in_bw(output)
Expand Down
2 changes: 1 addition & 1 deletion torchao/kernel/intmm.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,7 @@ def safe_int_mm(input: torch.Tensor, mat2: torch.Tensor) -> torch.Tensor:
and j_is_nonzero_multiple_of_8
and k_is_nonzero_multiple_of_8
)

if device_cpu or bad_dimensions_for_cublas:
# fallback path
return torch.matmul(input.cpu().to(torch.int32), mat2.cpu().to(torch.int32)).to(
Expand Down
9 changes: 7 additions & 2 deletions torchao/quantization/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,13 +16,17 @@
LinearActivationQuantizedTensor,
to_linear_activation_quantized,
)
from .linear_activation_scale import ( # noqat: F403
to_weight_tensor_with_linear_activation_scale_metadata,
)

__all__ = [
"swap_conv2d_1x1_to_linear"
"safe_int_mm",
"autoquant",
"DEFAULT_AUTOQUANT_CLASS_LIST",
"DEFAULT_INT4_AUTOQUANT_CLASS_LIST",
"OTHER_AUTOQUANT_CLASS_LIST",
"get_scale",
"SmoothFakeDynQuantMixin",
"SmoothFakeDynamicallyQuantizedLinear",
Expand All @@ -42,10 +46,11 @@
"int4_weight_only",
"int8_weight_only",
"uintx_weight_only",
"float8_weight_only",
"fpx_weight_only",
"LinearActivationQuantizedTensor",
"to_linear_activation_quantized",
"to_weight_tensor_with_linear_activation_scale_metadata",
"float8_weight_only",
"float8_dynamic_activation_float8_weight"
"float8_dynamic_activation_float8_weight",
"float8_static_activation_float8_weight"
]
Loading

0 comments on commit dc0c507

Please sign in to comment.