for now, delete the float8-all-gather-only functionality from float8 training

Summary:

In #1093 we added a config option, off by default, to use float8 for the
all-gather only and do the matrix multiply in high precision.
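
For reference, a minimal sketch of how this option was enabled before this
commit (assuming the `torchao.float8` conversion entry point; exact import
paths may differ across versions):

```
import torch.nn as nn

from torchao.float8 import Float8LinearConfig, convert_to_float8_training

model = nn.Sequential(nn.Linear(1024, 1024), nn.Linear(1024, 1024))

config = Float8LinearConfig(
    # required by the removed flag (see the deleted assert in config.py below)
    enable_fsdp_float8_all_gather=True,
    # communicate weights in float8, but keep the gemm in the original precision
    use_fp8_all_gather_only=True,
)
convert_to_float8_training(model, config=config)
```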

This seems generally useful for communication-bound workloads, but we can
probably find a cleaner way to add this functionality (such as a weight
wrapper tensor subclass; see the sketch below). The current implementation
adds non-trivial complexity and doesn't fit well with where we want to take
this codebase.
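
To make the tensor-subclass alternative concrete, here is a hypothetical,
minimal sketch (the class below is illustrative and does not exist in this
repo; a real version would also hook FSDP2's pre/post all-gather extension
points to perform the float8 cast):

```
import torch
import torch.utils._pytree as pytree


class Float8AllGatherOnlyWeight(torch.Tensor):
    """Wrapper that keeps the high-precision weight for compute and would
    confine the float8 cast to the all-gather path (hooks omitted here)."""

    @staticmethod
    def __new__(cls, weight: torch.Tensor):
        return torch.Tensor._make_wrapper_subclass(
            cls,
            weight.shape,
            dtype=weight.dtype,
            device=weight.device,
            requires_grad=weight.requires_grad,
        )

    def __init__(self, weight: torch.Tensor):
        self._weight = weight

    @classmethod
    def __torch_dispatch__(cls, func, types, args=(), kwargs=None):
        # Unwrap to the original-precision inner tensor so matmuls and other
        # ops run in high precision; a production version would also re-wrap
        # view ops (e.g. detach) to keep the subclass on parameters.
        def unwrap(t):
            return t._weight if isinstance(t, cls) else t

        args = pytree.tree_map(unwrap, args)
        kwargs = pytree.tree_map(unwrap, kwargs or {})
        return func(*args, **kwargs)
```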

Since no one is using this internally or externally yet and we haven't
mentioned it in the release notes, I think we should do a BC-breaking delete
as a one-off. However, if people have concerns, let me know and we can
discuss less aggressive options.

Test Plan:

```
./test/float8/test_everything.sh
```

Reviewers:

Subscribers:

Tasks:

Tags:
vkuzo committed Dec 19, 2024
1 parent 7618c26 commit 910edb3
Showing 5 changed files with 5 additions and 265 deletions.
180 changes: 0 additions & 180 deletions test/float8/test_fsdp2/test_fsdp2_fp8_comm_only.py

This file was deleted.

8 changes: 4 additions & 4 deletions test/quantization/test_quant_api.py
@@ -15,25 +15,25 @@
import torch
from torch.ao.quantization.quantize_pt2e import convert_pt2e, prepare_pt2e
from torch.ao.quantization.quantizer.xnnpack_quantizer import (
get_symmetric_quantization_config,
XNNPACKQuantizer,
get_symmetric_quantization_config,
)
from torch.testing._internal import common_utils
from torch.testing._internal.common_utils import TestCase

from torchao import quantize_
from torchao._models.llama.model import prepare_inputs_for_model, Transformer
from torchao._models.llama.model import Transformer, prepare_inputs_for_model
from torchao._models.llama.tokenizer import get_tokenizer
from torchao.dtypes import AffineQuantizedTensor
from torchao.quantization import LinearActivationQuantizedTensor
from torchao.quantization.quant_api import (
Quantizer,
TwoStepQuantizer,
_replace_with_custom_fn_if_matches_filter,
int4_weight_only,
int8_dynamic_activation_int4_weight,
int8_dynamic_activation_int8_weight,
int8_weight_only,
Quantizer,
TwoStepQuantizer,
)
from torchao.quantization.quant_primitives import MappingType
from torchao.quantization.subclass import (
9 changes: 0 additions & 9 deletions torchao/float8/config.py
@@ -234,12 +234,6 @@ class Float8LinearConfig:
# tests so that the warning does not spam the CI stdout.
force_recompute_fp8_weight_in_bwd: bool = False

# If True, we only use fp8-all-gather to reduce the communication cost.
# The gemm computation is still done in the original precision.
# `cast_config_weight` is used to decide how to cast the weight to fp8,
# other casting configs will be ignored.
use_fp8_all_gather_only: bool = False

def __post_init__(self):
# Populate the additional cast overrides, if the user did not specify them
# Note: this hacks around the frozen-ness of this dataclass
@@ -301,9 +295,6 @@ def __post_init__(self):
cc1.target_dtype == cc2.target_dtype
), f"{operand_name} must be cast to the same dtype in both matmuls it's used in"

if self.use_fp8_all_gather_only:
assert self.enable_fsdp_float8_all_gather, "use_fp8_all_gather_only requires enable_fsdp_float8_all_gather to be True"

# See the comments around `force_recompute_fp8_weight_in_bwd` for more details of this warning.
if (
self.enable_fsdp_float8_all_gather
27 changes: 1 addition & 26 deletions torchao/float8/float8_linear.py
@@ -20,7 +20,6 @@
hp_tensor_to_float8_dynamic,
)
from torchao.float8.float8_tensor import (
Float8Tensor,
GemmInputRole,
LinearMMConfig,
ScaledMMConfig,
@@ -344,12 +343,6 @@ def cast_weight_to_float8_t(
)
return weight_fp8.t()

def cast_weight_to_original_t(self, weight: torch.Tensor):
if isinstance(weight, Float8Tensor):
return weight.to_original_precision().t()
else:
return weight.t()

def cast_output_to_float8_in_bw(self, output: torch.Tensor) -> torch.Tensor:
assert self.scaling_type_grad_output is ScalingType.DYNAMIC
output = NoopFwToFloat8BwDynamic.apply(
@@ -359,7 +352,7 @@ def cast_output_to_float8_in_bw(self, output: torch.Tensor) -> torch.Tensor:
)
return output

def forward_fp8_matmul(self, input: torch.Tensor) -> torch.Tensor:
def forward(self, input: torch.Tensor) -> torch.Tensor:
has_any_axiswise_scaling = any(
cc.scaling_granularity is ScalingGranularity.AXISWISE
for cc in [
@@ -403,24 +396,6 @@ def forward_fp8_matmul(self, input: torch.Tensor) -> torch.Tensor:
self.linear_mm_config,
self.config,
)
return output

def forward_original_precision_matmul(self, input: torch.Tensor) -> torch.Tensor:
if self.config.force_recompute_fp8_weight_in_bwd:
orig_weight_t = checkpoint.checkpoint(
self.cast_weight_to_original_t, self.weight
)
else:
orig_weight_t = self.cast_weight_to_original_t(self.weight)

output = torch.matmul(input, orig_weight_t)
return output

def forward(self, input: torch.Tensor) -> torch.Tensor:
if self.config.use_fp8_all_gather_only:
output = self.forward_original_precision_matmul(input)
else:
output = self.forward_fp8_matmul(input)

if self.bias is not None:
output = output + self.bias.to(output.dtype)
46 changes: 0 additions & 46 deletions torchao/testing/float8/fsdp2_utils.py
@@ -93,49 +93,3 @@ def check_parity_bf16_mp(
losses[1],
msg=f"iter: {iter_idx}, loss-ref: {losses[0]}, loss-fp8: {losses[1]}",
)


def check_parity_fp8_comm_only(
test_cls,
ref_model: nn.Module,
ref_optim: torch.optim.Optimizer,
fsdp_model: nn.Module,
fsdp_optim: torch.optim.Optimizer,
local_inp: torch.Tensor,
config: Float8LinearConfig,
precompute: bool = False,
compile: bool = False,
):
for iter_idx in range(10):
losses: List[torch.Tensor] = []
for model, optim in ((ref_model, ref_optim), (fsdp_model, fsdp_optim)):
optim.zero_grad(set_to_none=(iter_idx % 2 == 0))
losses.append(model(local_inp).sum())
losses[-1].backward()
if model is ref_model:
for name, param in model.named_parameters():
dist.all_reduce(param.grad)
param.grad.div_(dist.get_world_size())

if linear_requires_sync(config):
sync_float8_amax_and_scale_history(model)

optim.step()
if (
model is fsdp_model
and precompute
and config.cast_config_weight.scaling_type is ScalingType.DYNAMIC
):
precompute_float8_dynamic_scale_for_fsdp(model)

if compile:
# When compile, the ref loss and fsdp loss are not exactly the same, only check the loss values are valid for now.
assert (
torch.isfinite(losses[0]).any() and torch.isfinite(losses[1]).any()
), f"iter: {iter_idx}, loss-ref: {losses[0]}, loss-fp8: {losses[1]}"
else:
test_cls.assertEqual(
losses[0],
losses[1],
f"iter: {iter_idx}, loss-ref: {losses[0]}, loss-fp8: {losses[1]}",
)
