From 910edb39e0ab7b09e006fd65a91f8f194397dfcd Mon Sep 17 00:00:00 2001 From: vasiliy Date: Thu, 19 Dec 2024 15:25:40 -0800 Subject: [PATCH] for now, delete the float8-all-gather-only functionality from float8 training Summary: In https://github.com/pytorch/ao/pull/1093 we added a config option, off by default, to use only float8 all-gather for training and do the matrix multiply in high precision. This seems generally useful for communication bound workloads, but we can probably think of a cleaner way to add this functionality (such as a weight wrapper tensor subclass). The current implementation adds non-trivial complexity and doesn't jive well with where we want to take this codebase. Since no one is using this internally or externally yet and we haven't talked about it in the release notes, I think we should do a BC-breaking delete as a one-off. However, if people have concerns - let me know and we can talk about less aggressive options. Test Plan: ``` ./test/float8/test_everything.sh ``` Reviewers: Subscribers: Tasks: Tags: --- .../test_fsdp2/test_fsdp2_fp8_comm_only.py | 180 ------------------ test/quantization/test_quant_api.py | 8 +- torchao/float8/config.py | 9 - torchao/float8/float8_linear.py | 27 +-- torchao/testing/float8/fsdp2_utils.py | 46 ----- 5 files changed, 5 insertions(+), 265 deletions(-) delete mode 100644 test/float8/test_fsdp2/test_fsdp2_fp8_comm_only.py diff --git a/test/float8/test_fsdp2/test_fsdp2_fp8_comm_only.py b/test/float8/test_fsdp2/test_fsdp2_fp8_comm_only.py deleted file mode 100644 index d2e9a51c7f..0000000000 --- a/test/float8/test_fsdp2/test_fsdp2_fp8_comm_only.py +++ /dev/null @@ -1,180 +0,0 @@ -import copy -from typing import Optional - -import pytest - -from torchao.utils import TORCH_VERSION_AT_LEAST_2_5, is_sm_at_least_89 - -if not TORCH_VERSION_AT_LEAST_2_5: - pytest.skip("Unsupported PyTorch version", allow_module_level=True) - -import torch -import torch._dynamo.testing -import torch.distributed as dist -import torch.nn as nn -from torch.distributed._composable.fsdp import fully_shard -from torch.testing._internal.common_distributed import skip_if_lt_x_gpu -from torch.testing._internal.common_fsdp import FSDPTest -from torch.testing._internal.common_utils import run_tests -from torch.testing._internal.distributed._tensor.common_dtensor import ( - ModelArgs, - Transformer, -) - -from torchao.float8.config import CastConfig, Float8LinearConfig, ScalingType -from torchao.float8.float8_linear_utils import ( - convert_to_float8_training, - swap_linear_layers, -) -from torchao.float8.float8_scaling_utils import hp_tensor_to_float8_dynamic -from torchao.float8.float8_tensor import GemmInputRole -from torchao.testing.float8.fsdp2_utils import check_parity_fp8_comm_only - -if not is_sm_at_least_89(): - pytest.skip("Unsupported CUDA device capability version", allow_module_level=True) - - -class Float8CommTestLinear(torch.nn.Linear): - def forward(self, input: torch.Tensor) -> torch.Tensor: - fp8_param = hp_tensor_to_float8_dynamic( - self.weight, - torch.float8_e4m3fn, - None, # mm_linear_config, - reduce_amax=False, - gemm_input_role=GemmInputRole.WEIGHT, - ) - weight_orig = fp8_param.to_original_precision() - output = torch.matmul(input, weight_orig.t()) - if self.bias is not None: - output = output + self.bias.to(output.dtype) - return output - - @classmethod - def from_float( - cls, - mod, - ): - with torch.device("meta"): - new_mod = cls( - mod.in_features, - mod.out_features, - bias=(mod.bias is not None), - ) - new_mod.weight = mod.weight - new_mod.bias = mod.bias - return new_mod - - -def convert_to_float8_comm_test_layers( - module: nn.Module, -) -> nn.Module: - from_float = lambda m: Float8CommTestLinear.from_float( - m, - ) - return swap_linear_layers( - module, - from_float, - ) - - -class TestFloat8Common: - def broadcast_module(self, module: nn.Module) -> None: - # Broadcast for multi-threaded process group tests since seed is per - # process, not per thread - for param in module.parameters(): - dist.broadcast(param, src=0) - - def init_transformer( - self, weight_tying: bool, dtype: Optional[torch.dtype] = None - ) -> nn.Module: - torch.manual_seed(42) - args = ModelArgs( - n_layers=3, - dim=768, - n_heads=12, - dropout_p=0.0, - weight_tying=weight_tying, - vocab_size=32, - ) - module = Transformer(args).cuda() - if dtype is not None: - module = module.to(dtype=dtype) - self.broadcast_module(module) - return module - - -class TestFloat8MultiProcess(FSDPTest, TestFloat8Common): - @property - def world_size(self) -> int: - return min(torch.cuda.device_count(), 2) - - @skip_if_lt_x_gpu(2) - def test_transformer_parity(self): - self.run_subtests( - { - "compile_transformer_block": [False, True], - "precompute": [False, True], - "scaling_type_weight": [ScalingType.DYNAMIC], - "dtype": [torch.float32, torch.bfloat16], - }, - self._test_transformer_parity, - ) - - def _test_transformer_parity( - self, - precompute: bool, - scaling_type_weight: ScalingType, - compile_transformer_block: bool, - dtype: Optional[torch.dtype] = None, - ): - if scaling_type_weight is ScalingType.DELAYED and precompute: - return - - module = self.init_transformer(weight_tying=False, dtype=dtype) - - local_inp = torch.randint( - 0, module.tok_embeddings.weight.size(0), (16, 16), device="cuda" - ) - - # reference modules - ref_module = copy.deepcopy(module) - convert_to_float8_comm_test_layers( - ref_module, - ) - - # fp8 comm-only modules - float8_linear_config2 = Float8LinearConfig( - cast_config_weight=CastConfig(scaling_type=scaling_type_weight), - enable_fsdp_float8_all_gather=True, - use_fp8_all_gather_only=True, - ) - convert_to_float8_training( - module, - config=float8_linear_config2, - ) - - for layer_id, transformer_block in module.layers.named_children(): - if compile_transformer_block: - transformer_block = torch.compile(transformer_block, dynamic=False) - fully_shard(transformer_block) - module.layers.register_module(layer_id, transformer_block) - fully_shard(module) - - ref_optim = torch.optim.Adam(ref_module.parameters(), lr=1e-2) - optim = torch.optim.Adam(module.parameters(), lr=1e-2, foreach=True) - - check_parity_fp8_comm_only( - self, - ref_module, - ref_optim, - module, - optim, - local_inp, - config=float8_linear_config2, - precompute=precompute, - compile=compile_transformer_block, - ) - - -if __name__ == "__main__": - run_tests() diff --git a/test/quantization/test_quant_api.py b/test/quantization/test_quant_api.py index 205ba91290..177c357047 100644 --- a/test/quantization/test_quant_api.py +++ b/test/quantization/test_quant_api.py @@ -15,25 +15,25 @@ import torch from torch.ao.quantization.quantize_pt2e import convert_pt2e, prepare_pt2e from torch.ao.quantization.quantizer.xnnpack_quantizer import ( - get_symmetric_quantization_config, XNNPACKQuantizer, + get_symmetric_quantization_config, ) from torch.testing._internal import common_utils from torch.testing._internal.common_utils import TestCase from torchao import quantize_ -from torchao._models.llama.model import prepare_inputs_for_model, Transformer +from torchao._models.llama.model import Transformer, prepare_inputs_for_model from torchao._models.llama.tokenizer import get_tokenizer from torchao.dtypes import AffineQuantizedTensor from torchao.quantization import LinearActivationQuantizedTensor from torchao.quantization.quant_api import ( + Quantizer, + TwoStepQuantizer, _replace_with_custom_fn_if_matches_filter, int4_weight_only, int8_dynamic_activation_int4_weight, int8_dynamic_activation_int8_weight, int8_weight_only, - Quantizer, - TwoStepQuantizer, ) from torchao.quantization.quant_primitives import MappingType from torchao.quantization.subclass import ( diff --git a/torchao/float8/config.py b/torchao/float8/config.py index d4a5516154..c7f32cd3fa 100644 --- a/torchao/float8/config.py +++ b/torchao/float8/config.py @@ -234,12 +234,6 @@ class Float8LinearConfig: # tests so that the warning does not spam the CI stdout. force_recompute_fp8_weight_in_bwd: bool = False - # If True, we only use fp8-all-gather to reduce the communication cost. - # The gemm computation is still done in the original precision. - # `cast_config_weight` is used to decide how to cast the weight to fp8, - # other casting configs will be ignored. - use_fp8_all_gather_only: bool = False - def __post_init__(self): # Populate the additional cast overrides, if the user did not specify them # Note: this hacks around the frozen-ness of this dataclass @@ -301,9 +295,6 @@ def __post_init__(self): cc1.target_dtype == cc2.target_dtype ), f"{operand_name} must be cast to the same dtype in both matmuls it's used in" - if self.use_fp8_all_gather_only: - assert self.enable_fsdp_float8_all_gather, "use_fp8_all_gather_only requires enable_fsdp_float8_all_gather to be True" - # See the comments around `force_recompute_fp8_weight_in_bwd` for more details of this warning. if ( self.enable_fsdp_float8_all_gather diff --git a/torchao/float8/float8_linear.py b/torchao/float8/float8_linear.py index 6be4969e3e..b7a3449277 100644 --- a/torchao/float8/float8_linear.py +++ b/torchao/float8/float8_linear.py @@ -20,7 +20,6 @@ hp_tensor_to_float8_dynamic, ) from torchao.float8.float8_tensor import ( - Float8Tensor, GemmInputRole, LinearMMConfig, ScaledMMConfig, @@ -344,12 +343,6 @@ def cast_weight_to_float8_t( ) return weight_fp8.t() - def cast_weight_to_original_t(self, weight: torch.Tensor): - if isinstance(weight, Float8Tensor): - return weight.to_original_precision().t() - else: - return weight.t() - def cast_output_to_float8_in_bw(self, output: torch.Tensor) -> torch.Tensor: assert self.scaling_type_grad_output is ScalingType.DYNAMIC output = NoopFwToFloat8BwDynamic.apply( @@ -359,7 +352,7 @@ def cast_output_to_float8_in_bw(self, output: torch.Tensor) -> torch.Tensor: ) return output - def forward_fp8_matmul(self, input: torch.Tensor) -> torch.Tensor: + def forward(self, input: torch.Tensor) -> torch.Tensor: has_any_axiswise_scaling = any( cc.scaling_granularity is ScalingGranularity.AXISWISE for cc in [ @@ -403,24 +396,6 @@ def forward_fp8_matmul(self, input: torch.Tensor) -> torch.Tensor: self.linear_mm_config, self.config, ) - return output - - def forward_original_precision_matmul(self, input: torch.Tensor) -> torch.Tensor: - if self.config.force_recompute_fp8_weight_in_bwd: - orig_weight_t = checkpoint.checkpoint( - self.cast_weight_to_original_t, self.weight - ) - else: - orig_weight_t = self.cast_weight_to_original_t(self.weight) - - output = torch.matmul(input, orig_weight_t) - return output - - def forward(self, input: torch.Tensor) -> torch.Tensor: - if self.config.use_fp8_all_gather_only: - output = self.forward_original_precision_matmul(input) - else: - output = self.forward_fp8_matmul(input) if self.bias is not None: output = output + self.bias.to(output.dtype) diff --git a/torchao/testing/float8/fsdp2_utils.py b/torchao/testing/float8/fsdp2_utils.py index af46b7fa71..a059b4d2a9 100644 --- a/torchao/testing/float8/fsdp2_utils.py +++ b/torchao/testing/float8/fsdp2_utils.py @@ -93,49 +93,3 @@ def check_parity_bf16_mp( losses[1], msg=f"iter: {iter_idx}, loss-ref: {losses[0]}, loss-fp8: {losses[1]}", ) - - -def check_parity_fp8_comm_only( - test_cls, - ref_model: nn.Module, - ref_optim: torch.optim.Optimizer, - fsdp_model: nn.Module, - fsdp_optim: torch.optim.Optimizer, - local_inp: torch.Tensor, - config: Float8LinearConfig, - precompute: bool = False, - compile: bool = False, -): - for iter_idx in range(10): - losses: List[torch.Tensor] = [] - for model, optim in ((ref_model, ref_optim), (fsdp_model, fsdp_optim)): - optim.zero_grad(set_to_none=(iter_idx % 2 == 0)) - losses.append(model(local_inp).sum()) - losses[-1].backward() - if model is ref_model: - for name, param in model.named_parameters(): - dist.all_reduce(param.grad) - param.grad.div_(dist.get_world_size()) - - if linear_requires_sync(config): - sync_float8_amax_and_scale_history(model) - - optim.step() - if ( - model is fsdp_model - and precompute - and config.cast_config_weight.scaling_type is ScalingType.DYNAMIC - ): - precompute_float8_dynamic_scale_for_fsdp(model) - - if compile: - # When compile, the ref loss and fsdp loss are not exactly the same, only check the loss values are valid for now. - assert ( - torch.isfinite(losses[0]).any() and torch.isfinite(losses[1]).any() - ), f"iter: {iter_idx}, loss-ref: {losses[0]}, loss-fp8: {losses[1]}" - else: - test_cls.assertEqual( - losses[0], - losses[1], - f"iter: {iter_idx}, loss-ref: {losses[0]}, loss-fp8: {losses[1]}", - )