diff --git a/benchmarks/bench_linear_float8.py b/benchmarks/bench_linear_float8.py index 2780600..0bbf116 100644 --- a/benchmarks/bench_linear_float8.py +++ b/benchmarks/bench_linear_float8.py @@ -14,11 +14,7 @@ import torch import torch.utils.benchmark as benchmark -from float8_experimental.config import ( - Float8LinearConfig, - Float8TensorCastConfig, - TensorScalingType, -) +from float8_experimental.config import CastConfig, Float8LinearConfig, ScalingType from float8_experimental.float8_linear import Float8Linear from float8_experimental.float8_linear_utils import ( linear_requires_sync, @@ -107,15 +103,13 @@ def main( device = "cuda" print(f"Compile is set to | {compile}") - scaling_type_input = TensorScalingType(scaling_type_input) - scaling_type_weight = TensorScalingType(scaling_type_weight) - scaling_type_grad_output = TensorScalingType(scaling_type_grad_output) + scaling_type_input = ScalingType(scaling_type_input) + scaling_type_weight = ScalingType(scaling_type_weight) + scaling_type_grad_output = ScalingType(scaling_type_grad_output) config = Float8LinearConfig( - cast_config_input=Float8TensorCastConfig(scaling_type=scaling_type_input), - cast_config_weight=Float8TensorCastConfig(scaling_type=scaling_type_weight), - cast_config_grad_output=Float8TensorCastConfig( - scaling_type=scaling_type_grad_output - ), + cast_config_input=CastConfig(scaling_type=scaling_type_input), + cast_config_weight=CastConfig(scaling_type=scaling_type_weight), + cast_config_grad_output=CastConfig(scaling_type=scaling_type_grad_output), ) # LLaMa 2 70B single-node weight shapes diff --git a/benchmarks/bench_multi_gpu.py b/benchmarks/bench_multi_gpu.py index 5cb5223..a741dec 100644 --- a/benchmarks/bench_multi_gpu.py +++ b/benchmarks/bench_multi_gpu.py @@ -14,11 +14,7 @@ import torch.multiprocessing as mp import torch.nn as nn import torch.utils.benchmark as benchmark -from float8_experimental.config import ( - Float8LinearConfig, - Float8TensorCastConfig, - TensorScalingType, -) +from float8_experimental.config import CastConfig, Float8LinearConfig, ScalingType from float8_experimental.float8_linear_utils import ( convert_to_float8_training, sync_float8_amax_and_scale_history, @@ -33,11 +29,9 @@ lr = 0.01 config = Float8LinearConfig( - cast_config_input=Float8TensorCastConfig(scaling_type=TensorScalingType.DELAYED), - cast_config_weight=Float8TensorCastConfig(scaling_type=TensorScalingType.DELAYED), - cast_config_grad_output=Float8TensorCastConfig( - scaling_type=TensorScalingType.DELAYED - ), + cast_config_input=CastConfig(scaling_type=ScalingType.DELAYED), + cast_config_weight=CastConfig(scaling_type=ScalingType.DELAYED), + cast_config_grad_output=CastConfig(scaling_type=ScalingType.DELAYED), ) diff --git a/benchmarks/profile_linear_float8.py b/benchmarks/profile_linear_float8.py index e90c5b3..716ceed 100644 --- a/benchmarks/profile_linear_float8.py +++ b/benchmarks/profile_linear_float8.py @@ -18,11 +18,7 @@ import torch import torch.nn as nn import torch.nn.functional as F -from float8_experimental.config import ( - Float8LinearConfig, - Float8TensorCastConfig, - TensorScalingType, -) +from float8_experimental.config import CastConfig, Float8LinearConfig, ScalingType from float8_experimental.float8_linear_utils import ( convert_to_float8_training, linear_requires_sync, @@ -217,15 +213,13 @@ def main( assert model_type in ("linear", "ln_linear", "norm_ffn_norm"), "unsupported" assert dtype_filter in ("both", "float8", "bfloat16") - scaling_type_input = TensorScalingType(scaling_type_input) - 
scaling_type_weight = TensorScalingType(scaling_type_weight) - scaling_type_grad_output = TensorScalingType(scaling_type_grad_output) + scaling_type_input = ScalingType(scaling_type_input) + scaling_type_weight = ScalingType(scaling_type_weight) + scaling_type_grad_output = ScalingType(scaling_type_grad_output) config = Float8LinearConfig( - cast_config_input=Float8TensorCastConfig(scaling_type=scaling_type_input), - cast_config_weight=Float8TensorCastConfig(scaling_type=scaling_type_weight), - cast_config_grad_output=Float8TensorCastConfig( - scaling_type=scaling_type_grad_output - ), + cast_config_input=CastConfig(scaling_type=scaling_type_input), + cast_config_weight=CastConfig(scaling_type=scaling_type_weight), + cast_config_grad_output=CastConfig(scaling_type=scaling_type_grad_output), ) scaling_repr = "_".join( [ diff --git a/float8_experimental/__init__.py b/float8_experimental/__init__.py index 08c0ac4..8fd8476 100644 --- a/float8_experimental/__init__.py +++ b/float8_experimental/__init__.py @@ -5,11 +5,11 @@ # LICENSE file in the root directory of this source tree. # Lets define a few top level things here from float8_experimental.config import ( + CastConfig, DelayedScalingConfig, Float8GemmConfig, Float8LinearConfig, - Float8TensorCastConfig, - TensorScalingType, + ScalingType, ) from float8_experimental.float8_linear import Float8Linear from float8_experimental.float8_linear_utils import ( @@ -33,10 +33,10 @@ __all__ = [ # configuration "DelayedScalingConfig", - "TensorScalingType", + "ScalingType", "Float8GemmConfig", "Float8LinearConfig", - "Float8TensorCastConfig", + "CastConfig", # top level UX "convert_to_float8_training", "linear_requires_sync", diff --git a/float8_experimental/config.py b/float8_experimental/config.py index 6408ac7..5d1bf9f 100644 --- a/float8_experimental/config.py +++ b/float8_experimental/config.py @@ -8,26 +8,25 @@ from dataclasses import dataclass -# TODO(future): consider renaming to ScalingType -class TensorScalingType(enum.Enum): +class ScalingType(enum.Enum): DELAYED = "delayed" DYNAMIC = "dynamic" def short_str(self): - if self is TensorScalingType.DELAYED: + if self is ScalingType.DELAYED: return "del" else: - assert self is TensorScalingType.DYNAMIC + assert self is ScalingType.DYNAMIC return "dyn" @dataclass(frozen=True) -class Float8TensorCastConfig: +class CastConfig: """ Configuration for casting a single tensor to float8 """ - scaling_type: TensorScalingType = TensorScalingType.DYNAMIC + scaling_type: ScalingType = ScalingType.DYNAMIC @dataclass(frozen=True) @@ -75,9 +74,9 @@ class Float8LinearConfig: # # Per-tensor configuration for `input`, `weight`, `grad_output` # - cast_config_input: Float8TensorCastConfig = Float8TensorCastConfig() - cast_config_weight: Float8TensorCastConfig = Float8TensorCastConfig() - cast_config_grad_output: Float8TensorCastConfig = Float8TensorCastConfig() + cast_config_input: CastConfig = CastConfig() + cast_config_weight: CastConfig = CastConfig() + cast_config_grad_output: CastConfig = CastConfig() # # Per-gemm configuration for gemms calculating `output`, `grad_input` and diff --git a/float8_experimental/float8_linear.py b/float8_experimental/float8_linear.py index 42eeb86..fd76a8e 100644 --- a/float8_experimental/float8_linear.py +++ b/float8_experimental/float8_linear.py @@ -14,7 +14,7 @@ import torch -from float8_experimental.config import Float8LinearConfig, TensorScalingType +from float8_experimental.config import Float8LinearConfig, ScalingType from float8_experimental.float8_dynamic_utils import (
cast_to_float8_e4m3_dynamic, @@ -159,9 +159,9 @@ def __init__(self, *args, **kwargs): self.scaling_type_grad_output = config.cast_config_grad_output.scaling_type # Convenience flag to skip code related to delayed scaling self.has_any_delayed_scaling = ( - self.scaling_type_input is TensorScalingType.DELAYED - or self.scaling_type_weight is TensorScalingType.DELAYED - or self.scaling_type_grad_output is TensorScalingType.DELAYED + self.scaling_type_input is ScalingType.DELAYED + or self.scaling_type_weight is ScalingType.DELAYED + or self.scaling_type_grad_output is ScalingType.DELAYED ) self.config = config @@ -284,7 +284,7 @@ def cast_input_to_float8( autocast_dtype = torch.get_autocast_gpu_dtype() input = input.to(autocast_dtype) - if self.scaling_type_input is TensorScalingType.DELAYED: + if self.scaling_type_input is ScalingType.DELAYED: scale_fn_name = self.config.delayed_scaling_config.scale_fn_name _maybe_initialize_amaxes_scales_for_float8_cast( input, @@ -305,14 +305,14 @@ def cast_input_to_float8( gemm_input_role=GemmInputRole.INPUT, ) else: - assert self.scaling_type_input is TensorScalingType.DYNAMIC + assert self.scaling_type_input is ScalingType.DYNAMIC input_fp8 = cast_to_float8_e4m3_dynamic(input, self.linear_mm_config) return input_fp8 def cast_weight_to_float8( self, weight: torch.Tensor, is_amax_initialized: bool ) -> torch.Tensor: - if self.scaling_type_weight is TensorScalingType.DELAYED: + if self.scaling_type_weight is ScalingType.DELAYED: if isinstance(self.weight, Float8Tensor): # cast by FSDP weight_fp8 = self.weight else: @@ -337,7 +337,7 @@ def cast_weight_to_float8( gemm_input_role=GemmInputRole.WEIGHT, ) else: - assert self.scaling_type_weight is TensorScalingType.DYNAMIC + assert self.scaling_type_weight is ScalingType.DYNAMIC if isinstance(self.weight, Float8Tensor): # cast by FSDP weight_fp8 = self.weight else: @@ -349,7 +349,7 @@ def cast_weight_to_float8( return weight_fp8 def cast_output_to_float8_in_bw(self, output: torch.Tensor) -> torch.Tensor: - if self.scaling_type_grad_output is TensorScalingType.DELAYED: + if self.scaling_type_grad_output is ScalingType.DELAYED: scale_fn_name = self.config.delayed_scaling_config.scale_fn_name output = NoopFwToFloat8E5M2Bw.apply( output, @@ -361,7 +361,7 @@ def cast_output_to_float8_in_bw(self, output: torch.Tensor) -> torch.Tensor: self.linear_mm_config, ) else: - assert self.scaling_type_grad_output is TensorScalingType.DYNAMIC + assert self.scaling_type_grad_output is ScalingType.DYNAMIC output = cast_to_float8_e5m2_dynamic_bw(output, self.linear_mm_config) return output @@ -448,7 +448,7 @@ def from_float( # 2. 
buffers need to be already created for the delayed scaling version # of the weight wrapper to be initialized if config.enable_fsdp_float8_all_gather: - if config.cast_config_weight.scaling_type is TensorScalingType.DYNAMIC: + if config.cast_config_weight.scaling_type is ScalingType.DYNAMIC: new_mod.weight = torch.nn.Parameter( WeightWithDynamicFloat8CastTensor( new_mod.weight, @@ -456,9 +456,7 @@ def from_float( ) ) else: - assert ( - config.cast_config_weight.scaling_type is TensorScalingType.DELAYED - ) + assert config.cast_config_weight.scaling_type is ScalingType.DELAYED new_mod.weight = torch.nn.Parameter( WeightWithDelayedFloat8CastTensor( new_mod.weight, diff --git a/float8_experimental/float8_linear_utils.py b/float8_experimental/float8_linear_utils.py index c72b620..7fffcde 100644 --- a/float8_experimental/float8_linear_utils.py +++ b/float8_experimental/float8_linear_utils.py @@ -9,7 +9,7 @@ import torch import torch.distributed as dist import torch.nn as nn -from float8_experimental.config import Float8LinearConfig, TensorScalingType +from float8_experimental.config import Float8LinearConfig, ScalingType from float8_experimental.float8_linear import Float8Linear from float8_experimental.float8_utils import ( @@ -27,9 +27,9 @@ def linear_requires_sync(config: Float8LinearConfig): """Returns whether the given linear_type requires sync before forward.""" return any( [ - config.cast_config_input.scaling_type is TensorScalingType.DELAYED, - config.cast_config_weight.scaling_type is TensorScalingType.DELAYED, - config.cast_config_grad_output.scaling_type is TensorScalingType.DELAYED, + config.cast_config_input.scaling_type is ScalingType.DELAYED, + config.cast_config_weight.scaling_type is ScalingType.DELAYED, + config.cast_config_grad_output.scaling_type is ScalingType.DELAYED, ] ) diff --git a/float8_experimental/float8_tensor_parallel.py b/float8_experimental/float8_tensor_parallel.py index 99850ad..eea7376 100644 --- a/float8_experimental/float8_tensor_parallel.py +++ b/float8_experimental/float8_tensor_parallel.py @@ -1,6 +1,6 @@ import torch import torch.nn as nn -from float8_experimental.config import TensorScalingType +from float8_experimental.config import ScalingType from float8_experimental.float8_dynamic_utils import ( cast_to_float8_e4m3_dynamic, cast_to_float8_e5m2_dynamic_bw, @@ -28,8 +28,8 @@ def _float8_linear_supports_float8_allgather(m): # TODO(future): add support for delayed scaling for activations # and gradients return ( - m.scaling_type_input == TensorScalingType.DYNAMIC - and m.scaling_type_grad_output == TensorScalingType.DYNAMIC + m.scaling_type_input == ScalingType.DYNAMIC + and m.scaling_type_grad_output == ScalingType.DYNAMIC ) diff --git a/float8_experimental/fsdp_utils.py b/float8_experimental/fsdp_utils.py index 5de51e3..607de34 100644 --- a/float8_experimental/fsdp_utils.py +++ b/float8_experimental/fsdp_utils.py @@ -33,13 +33,12 @@ def precompute_float8_dynamic_scale_for_fsdp(module: nn.Module) -> None: optim.step() precompute_float8_dynamic_scale_for_fsdp(model) """ - from float8_experimental.config import TensorScalingType + from float8_experimental.config import ScalingType from float8_experimental.float8_linear import Float8Linear from torch.distributed._tensor import DTensor if any( - isinstance(m, Float8Linear) - and m.scaling_type_weight is TensorScalingType.DELAYED + isinstance(m, Float8Linear) and m.scaling_type_weight is ScalingType.DELAYED for m in module.modules() ): raise NotImplementedError("Only supports delayed scaling") diff --git 
a/test/test_base.py b/test/test_base.py index ffc8d0c..2f7c717 100644 --- a/test/test_base.py +++ b/test/test_base.py @@ -16,11 +16,8 @@ import torch import torch.nn as nn -from float8_experimental.config import ( - Float8LinearConfig, - Float8TensorCastConfig, - TensorScalingType, -) +from float8_experimental.config import CastConfig, Float8LinearConfig, ScalingType +from float8_experimental.float8_dynamic_utils import cast_to_float8_e4m3_dynamic from float8_experimental.float8_linear import Float8Linear from float8_experimental.float8_linear_utils import ( convert_to_float8_training, @@ -183,15 +180,15 @@ def _test_linear_impl( amax_buffer_names = [] amax_history_buffer_names = [] scale_buffer_names = [] - if config.cast_config_input.scaling_type is TensorScalingType.DELAYED: + if config.cast_config_input.scaling_type is ScalingType.DELAYED: amax_buffer_names.append("fp8_amax_input") amax_history_buffer_names.append("fp8_amax_history_input") scale_buffer_names.append("fp8_scale_input") - if config.cast_config_weight.scaling_type is TensorScalingType.DELAYED: + if config.cast_config_weight.scaling_type is ScalingType.DELAYED: amax_buffer_names.append("fp8_amax_weight") amax_history_buffer_names.append("fp8_amax_history_weight") scale_buffer_names.append("fp8_scale_weight") - if config.cast_config_grad_output.scaling_type is TensorScalingType.DELAYED: + if config.cast_config_grad_output.scaling_type is ScalingType.DELAYED: amax_buffer_names.append("fp8_amax_grad_output") amax_history_buffer_names.append("fp8_amax_history_grad_output") scale_buffer_names.append("fp8_scale_grad_output") @@ -223,14 +220,14 @@ def _test_linear_impl( @pytest.mark.parametrize("emulate", [True, False] if is_H100 else [True]) @pytest.mark.parametrize("x_shape", [(16, 16), (2, 16, 16), (3, 2, 16, 16)]) @pytest.mark.parametrize( - "scaling_type_input", [TensorScalingType.DELAYED, TensorScalingType.DYNAMIC] + "scaling_type_input", [ScalingType.DELAYED, ScalingType.DYNAMIC] ) @pytest.mark.parametrize( - "scaling_type_weight", [TensorScalingType.DELAYED, TensorScalingType.DYNAMIC] + "scaling_type_weight", [ScalingType.DELAYED, ScalingType.DYNAMIC] ) @pytest.mark.parametrize( "scaling_type_grad_output", - [TensorScalingType.DELAYED, TensorScalingType.DYNAMIC], + [ScalingType.DELAYED, ScalingType.DYNAMIC], ) @pytest.mark.parametrize("linear_dtype", [torch.bfloat16, torch.float32]) @pytest.mark.parametrize("linear_bias", [False, True]) @@ -239,9 +236,9 @@ def test_linear( self, x_shape, emulate: bool, - scaling_type_input: TensorScalingType, - scaling_type_weight: TensorScalingType, - scaling_type_grad_output: TensorScalingType, + scaling_type_input: ScalingType, + scaling_type_weight: ScalingType, + scaling_type_grad_output: ScalingType, linear_dtype: torch.dtype, linear_bias: bool, ): @@ -257,11 +254,9 @@ def test_linear( x = torch.randn(*x_shape, device="cuda", dtype=linear_dtype) m_ref = nn.Linear(16, 32, bias=linear_bias, device="cuda", dtype=linear_dtype) config = Float8LinearConfig( - cast_config_input=Float8TensorCastConfig(scaling_type=scaling_type_input), - cast_config_weight=Float8TensorCastConfig(scaling_type=scaling_type_weight), - cast_config_grad_output=Float8TensorCastConfig( - scaling_type=scaling_type_grad_output - ), + cast_config_input=CastConfig(scaling_type=scaling_type_input), + cast_config_weight=CastConfig(scaling_type=scaling_type_weight), + cast_config_grad_output=CastConfig(scaling_type=scaling_type_grad_output), emulate=emulate, ) self._test_linear_impl( @@ -292,15 +287,9 @@ def 
test_autocast_outputs( m_ref = nn.Linear(32, 16, device="cuda", dtype=linear_dtype) config = Float8LinearConfig( - cast_config_input=Float8TensorCastConfig( - scaling_type=TensorScalingType.DELAYED - ), - cast_config_weight=Float8TensorCastConfig( - scaling_type=TensorScalingType.DELAYED - ), - cast_config_grad_output=Float8TensorCastConfig( - scaling_type=TensorScalingType.DELAYED - ), + cast_config_input=CastConfig(scaling_type=ScalingType.DELAYED), + cast_config_weight=CastConfig(scaling_type=ScalingType.DELAYED), + cast_config_grad_output=CastConfig(scaling_type=ScalingType.DELAYED), emulate=emulate, ) m = Float8Linear.from_float(copy.deepcopy(m_ref), config) @@ -385,9 +374,7 @@ def test_type_cast(self, linear_dtype: torch.dtype, emulate: bool): def test_repr(self): m = nn.Linear(32, 16) config = Float8LinearConfig( - cast_config_weight=Float8TensorCastConfig( - scaling_type=TensorScalingType.DELAYED - ), + cast_config_weight=CastConfig(scaling_type=ScalingType.DELAYED), emulate=True, ) m = Float8Linear.from_float( diff --git a/test/test_compile.py b/test/test_compile.py index e7b5285..a71b879 100644 --- a/test/test_compile.py +++ b/test/test_compile.py @@ -13,11 +13,7 @@ import torch import torch.nn as nn -from float8_experimental.config import ( - Float8LinearConfig, - Float8TensorCastConfig, - TensorScalingType, -) +from float8_experimental.config import CastConfig, Float8LinearConfig, ScalingType from float8_experimental.float8_linear import Float8Linear from float8_experimental.float8_linear_utils import ( convert_to_float8_training, @@ -67,13 +63,13 @@ def _test_compile_base( @pytest.mark.parametrize("fullgraph", [True]) @pytest.mark.parametrize( - "scaling_type_input", [TensorScalingType.DELAYED, TensorScalingType.DYNAMIC] + "scaling_type_input", [ScalingType.DELAYED, ScalingType.DYNAMIC] ) @pytest.mark.parametrize( - "scaling_type_weight", [TensorScalingType.DELAYED, TensorScalingType.DYNAMIC] + "scaling_type_weight", [ScalingType.DELAYED, ScalingType.DYNAMIC] ) @pytest.mark.parametrize( - "scaling_type_grad_output", [TensorScalingType.DELAYED, TensorScalingType.DYNAMIC] + "scaling_type_grad_output", [ScalingType.DELAYED, ScalingType.DYNAMIC] ) @pytest.mark.parametrize("emulate", [False, True] if is_H100 else [True]) @pytest.mark.parametrize("dtype", [torch.bfloat16, torch.float16, torch.float32]) @@ -81,18 +77,16 @@ def _test_compile_base( def test_eager_only( fullgraph, emulate: bool, - scaling_type_input: TensorScalingType, - scaling_type_weight: TensorScalingType, - scaling_type_grad_output: TensorScalingType, + scaling_type_input: ScalingType, + scaling_type_weight: ScalingType, + scaling_type_grad_output: ScalingType, dtype: torch.dtype, ): torch._dynamo.reset() config = Float8LinearConfig( - cast_config_input=Float8TensorCastConfig(scaling_type=scaling_type_input), - cast_config_weight=Float8TensorCastConfig(scaling_type=scaling_type_weight), - cast_config_grad_output=Float8TensorCastConfig( - scaling_type=scaling_type_grad_output - ), + cast_config_input=CastConfig(scaling_type=scaling_type_input), + cast_config_weight=CastConfig(scaling_type=scaling_type_weight), + cast_config_grad_output=CastConfig(scaling_type=scaling_type_grad_output), emulate=emulate, ) _test_compile_base( @@ -106,31 +100,29 @@ def test_eager_only( @pytest.mark.parametrize("fullgraph", [True]) @pytest.mark.parametrize("emulate", [False, True] if is_H100 else [True]) @pytest.mark.parametrize( - "scaling_type_input", [TensorScalingType.DELAYED, TensorScalingType.DYNAMIC] + "scaling_type_input", 
[ScalingType.DELAYED, ScalingType.DYNAMIC] ) @pytest.mark.parametrize( - "scaling_type_weight", [TensorScalingType.DELAYED, TensorScalingType.DYNAMIC] + "scaling_type_weight", [ScalingType.DELAYED, ScalingType.DYNAMIC] ) @pytest.mark.parametrize( - "scaling_type_grad_output", [TensorScalingType.DELAYED, TensorScalingType.DYNAMIC] + "scaling_type_grad_output", [ScalingType.DELAYED, ScalingType.DYNAMIC] ) @pytest.mark.parametrize("dtype", [torch.bfloat16, torch.float16, torch.float32]) @unittest.skipIf(not torch.cuda.is_available(), "CUDA not available") def test_aot_eager( fullgraph, emulate: bool, - scaling_type_input: TensorScalingType, - scaling_type_weight: TensorScalingType, - scaling_type_grad_output: TensorScalingType, + scaling_type_input: ScalingType, + scaling_type_weight: ScalingType, + scaling_type_grad_output: ScalingType, dtype: torch.dtype, ): torch._dynamo.reset() config = Float8LinearConfig( - cast_config_input=Float8TensorCastConfig(scaling_type=scaling_type_input), - cast_config_weight=Float8TensorCastConfig(scaling_type=scaling_type_weight), - cast_config_grad_output=Float8TensorCastConfig( - scaling_type=scaling_type_grad_output - ), + cast_config_input=CastConfig(scaling_type=scaling_type_input), + cast_config_weight=CastConfig(scaling_type=scaling_type_weight), + cast_config_grad_output=CastConfig(scaling_type=scaling_type_grad_output), emulate=emulate, ) _test_compile_base( @@ -144,31 +136,29 @@ def test_aot_eager( @pytest.mark.parametrize("fullgraph", [True]) @pytest.mark.parametrize("emulate", [False]) @pytest.mark.parametrize( - "scaling_type_input", [TensorScalingType.DELAYED, TensorScalingType.DYNAMIC] + "scaling_type_input", [ScalingType.DELAYED, ScalingType.DYNAMIC] ) @pytest.mark.parametrize( - "scaling_type_weight", [TensorScalingType.DELAYED, TensorScalingType.DYNAMIC] + "scaling_type_weight", [ScalingType.DELAYED, ScalingType.DYNAMIC] ) @pytest.mark.parametrize( - "scaling_type_grad_output", [TensorScalingType.DELAYED, TensorScalingType.DYNAMIC] + "scaling_type_grad_output", [ScalingType.DELAYED, ScalingType.DYNAMIC] ) @unittest.skipIf(not torch.cuda.is_available() or not is_H100, "CUDA not available") @pytest.mark.parametrize("dtype", [torch.bfloat16, torch.float16, torch.float32]) def test_inductor( fullgraph, emulate: bool, - scaling_type_input: TensorScalingType, - scaling_type_weight: TensorScalingType, - scaling_type_grad_output: TensorScalingType, + scaling_type_input: ScalingType, + scaling_type_weight: ScalingType, + scaling_type_grad_output: ScalingType, dtype: torch.dtype, ): torch._dynamo.reset() config = Float8LinearConfig( - cast_config_input=Float8TensorCastConfig(scaling_type=scaling_type_input), - cast_config_weight=Float8TensorCastConfig(scaling_type=scaling_type_weight), - cast_config_grad_output=Float8TensorCastConfig( - scaling_type=scaling_type_grad_output - ), + cast_config_input=CastConfig(scaling_type=scaling_type_input), + cast_config_weight=CastConfig(scaling_type=scaling_type_weight), + cast_config_grad_output=CastConfig(scaling_type=scaling_type_grad_output), emulate=emulate, ) _test_compile_base( @@ -270,15 +260,9 @@ def test_sync_amax_func(): nn.Linear(16, 32, bias=True), nn.ReLU(), nn.Linear(32, 16, bias=True) ) config = Float8LinearConfig( - cast_config_input=Float8TensorCastConfig( - scaling_type=TensorScalingType.DELAYED - ), - cast_config_weight=Float8TensorCastConfig( - scaling_type=TensorScalingType.DELAYED - ), - cast_config_grad_output=Float8TensorCastConfig( - scaling_type=TensorScalingType.DELAYED - ), + 
cast_config_input=CastConfig(scaling_type=ScalingType.DELAYED), + cast_config_weight=CastConfig(scaling_type=ScalingType.DELAYED), + cast_config_grad_output=CastConfig(scaling_type=ScalingType.DELAYED), ) float8_mod = convert_to_float8_training( module, @@ -314,15 +298,9 @@ def test_sync_amax_func_cuda_graph_success(): nn.Linear(16, 32, bias=True), nn.ReLU(), nn.Linear(32, 16, bias=True) ).to("cuda") config = Float8LinearConfig( - cast_config_input=Float8TensorCastConfig( - scaling_type=TensorScalingType.DELAYED - ), - cast_config_weight=Float8TensorCastConfig( - scaling_type=TensorScalingType.DELAYED - ), - cast_config_grad_output=Float8TensorCastConfig( - scaling_type=TensorScalingType.DELAYED - ), + cast_config_input=CastConfig(scaling_type=ScalingType.DELAYED), + cast_config_weight=CastConfig(scaling_type=ScalingType.DELAYED), + cast_config_grad_output=CastConfig(scaling_type=ScalingType.DELAYED), ) convert_to_float8_training( my_module, diff --git a/test/test_fsdp.py b/test/test_fsdp.py index c7f86cc..f5be23b 100644 --- a/test/test_fsdp.py +++ b/test/test_fsdp.py @@ -21,11 +21,7 @@ import torch.distributed as dist import torch.multiprocessing as mp import torch.nn as nn -from float8_experimental.config import ( - Float8LinearConfig, - Float8TensorCastConfig, - TensorScalingType, -) +from float8_experimental.config import CastConfig, Float8LinearConfig, ScalingType from float8_experimental.float8_linear_utils import ( convert_to_float8_training, linear_requires_sync, @@ -78,12 +74,10 @@ def fsdp_main(rank, world_size, args): model_fp8 = copy.deepcopy(model) scaling_type_weight = ( - TensorScalingType.DYNAMIC - if use_weight_dynamic_scaling - else TensorScalingType.DELAYED + ScalingType.DYNAMIC if use_weight_dynamic_scaling else ScalingType.DELAYED ) config = Float8LinearConfig( - cast_config_weight=Float8TensorCastConfig(scaling_type=scaling_type_weight), + cast_config_weight=CastConfig(scaling_type=scaling_type_weight), # TODO(future): delete this arg as it's always False emulate=False, ) diff --git a/test/test_fsdp2/test_fsdp2.py b/test/test_fsdp2/test_fsdp2.py index 6d5719a..266bd6d 100644 --- a/test/test_fsdp2/test_fsdp2.py +++ b/test/test_fsdp2/test_fsdp2.py @@ -8,11 +8,7 @@ import torch._dynamo.testing import torch.distributed as dist import torch.nn as nn -from float8_experimental.config import ( - Float8LinearConfig, - Float8TensorCastConfig, - TensorScalingType, -) +from float8_experimental.config import CastConfig, Float8LinearConfig, ScalingType from float8_experimental.float8_linear_utils import convert_to_float8_training from float8_experimental.fsdp_utils import WeightWithDynamicFloat8CastTensor from test_fsdp2_common import check_parity_bf16_mp, check_parity_no_mp @@ -86,8 +82,8 @@ def test_transformer_parity(self): "enable_fsdp_float8_all_gather": [False, True], "precompute": [False, True], "scaling_type_weight": [ - TensorScalingType.DYNAMIC, - TensorScalingType.DELAYED, + ScalingType.DYNAMIC, + ScalingType.DELAYED, ], "compile_transformer_block": [False, True], }, @@ -98,12 +94,12 @@ def _test_transformer_parity( self, enable_fsdp_float8_all_gather: bool, precompute: bool, - scaling_type_weight: TensorScalingType, + scaling_type_weight: ScalingType, compile_transformer_block: bool, ): if not enable_fsdp_float8_all_gather and precompute: return - elif scaling_type_weight is TensorScalingType.DELAYED and precompute: + elif scaling_type_weight is ScalingType.DELAYED and precompute: return # NOTE: Weight-tying does not compose with fp8 all-gather because the @@ -114,7 
+110,7 @@ def _test_transformer_parity( module = self.init_transformer(weight_tying=weight_tying).cuda() ref_module = copy.deepcopy(module) float8_linear_config1 = Float8LinearConfig( - cast_config_weight=Float8TensorCastConfig(scaling_type=scaling_type_weight), + cast_config_weight=CastConfig(scaling_type=scaling_type_weight), ) convert_to_float8_training( ref_module, @@ -126,7 +122,7 @@ def _test_transformer_parity( ref_module.layers.register_module(layer_id, transformer_block) float8_linear_config2 = Float8LinearConfig( enable_fsdp_float8_all_gather=enable_fsdp_float8_all_gather, - cast_config_weight=Float8TensorCastConfig(scaling_type=scaling_type_weight), + cast_config_weight=CastConfig(scaling_type=scaling_type_weight), ) convert_to_float8_training( module, @@ -416,20 +412,16 @@ def test_fp32_fp8_single_module_parity(self): """ choices = itertools.product( [False, True], - [TensorScalingType.DYNAMIC, TensorScalingType.DELAYED], + [ScalingType.DYNAMIC, ScalingType.DELAYED], ) for enable_fsdp_float8_all_gather, scaling_type_weight in choices: float8_linear_config1 = Float8LinearConfig( enable_fsdp_float8_all_gather=False, - cast_config_weight=Float8TensorCastConfig( - scaling_type=scaling_type_weight - ), + cast_config_weight=CastConfig(scaling_type=scaling_type_weight), ) float8_linear_config2 = Float8LinearConfig( enable_fsdp_float8_all_gather=enable_fsdp_float8_all_gather, - cast_config_weight=Float8TensorCastConfig( - scaling_type=scaling_type_weight - ), + cast_config_weight=CastConfig(scaling_type=scaling_type_weight), ) module_fp32 = self.init_single_module() ref_module = copy.deepcopy(module_fp32) @@ -464,20 +456,16 @@ def test_fp32_fp8_multi_module_parity(self): """ choices = itertools.product( [False, True], - [TensorScalingType.DYNAMIC, TensorScalingType.DELAYED], + [ScalingType.DYNAMIC, ScalingType.DELAYED], ) for enable_fsdp_float8_all_gather, scaling_type_weight in choices: float8_linear_config1 = Float8LinearConfig( enable_fsdp_float8_all_gather=False, - cast_config_weight=Float8TensorCastConfig( - scaling_type=scaling_type_weight - ), + cast_config_weight=CastConfig(scaling_type=scaling_type_weight), ) float8_linear_config2 = Float8LinearConfig( enable_fsdp_float8_all_gather=enable_fsdp_float8_all_gather, - cast_config_weight=Float8TensorCastConfig( - scaling_type=scaling_type_weight - ), + cast_config_weight=CastConfig(scaling_type=scaling_type_weight), ) module = self.init_multi_module().cuda() ref_module = copy.deepcopy(module) @@ -546,9 +534,7 @@ def test_delayed_scaling_inplace_update(self): module = self.init_single_module() float8_linear_config = Float8LinearConfig( enable_fsdp_float8_all_gather=True, - cast_config_weight=Float8TensorCastConfig( - scaling_type=TensorScalingType.DELAYED - ), + cast_config_weight=CastConfig(scaling_type=ScalingType.DELAYED), ) m_fp8 = convert_to_float8_training( module, diff --git a/test/test_fsdp2/test_fsdp2_common.py b/test/test_fsdp2/test_fsdp2_common.py index 8f1fa80..f26278c 100644 --- a/test/test_fsdp2/test_fsdp2_common.py +++ b/test/test_fsdp2/test_fsdp2_common.py @@ -6,7 +6,7 @@ import torch import torch.distributed as dist import torch.nn as nn -from float8_experimental.config import Float8LinearConfig, TensorScalingType +from float8_experimental.config import Float8LinearConfig, ScalingType from float8_experimental.float8_linear_utils import ( linear_requires_sync, sync_float8_amax_and_scale_history, @@ -44,7 +44,7 @@ def check_parity_no_mp( if ( model is fsdp_model and precompute - and 
config.cast_config_weight.scaling_type is TensorScalingType.DYNAMIC + and config.cast_config_weight.scaling_type is ScalingType.DYNAMIC ): precompute_float8_dynamic_scale_for_fsdp(model) diff --git a/test/test_fsdp_compile.py b/test/test_fsdp_compile.py index bc80023..e20ab15 100644 --- a/test/test_fsdp_compile.py +++ b/test/test_fsdp_compile.py @@ -18,7 +18,7 @@ import torch.multiprocessing as mp import torch.nn as nn from float8_experimental import Float8LinearConfig -from float8_experimental.config import Float8TensorCastConfig, TensorScalingType +from float8_experimental.config import CastConfig, ScalingType from float8_experimental.float8_linear_utils import ( convert_to_float8_training, sync_float8_amax_and_scale_history, @@ -57,15 +57,9 @@ def get_model(K, N, is_fp8, emulate, base_dtype=torch.float32): # to get around this, we can disable amax init config = Float8LinearConfig( enable_amax_init=False, - cast_config_input=Float8TensorCastConfig( - scaling_type=TensorScalingType.DELAYED - ), - cast_config_weight=Float8TensorCastConfig( - scaling_type=TensorScalingType.DELAYED - ), - cast_config_grad_output=Float8TensorCastConfig( - scaling_type=TensorScalingType.DELAYED - ), + cast_config_input=CastConfig(scaling_type=ScalingType.DELAYED), + cast_config_weight=CastConfig(scaling_type=ScalingType.DELAYED), + cast_config_grad_output=CastConfig(scaling_type=ScalingType.DELAYED), emulate=emulate, ) diff --git a/test/test_inference_flows.py b/test/test_inference_flows.py index 35b640a..421b7a9 100644 --- a/test/test_inference_flows.py +++ b/test/test_inference_flows.py @@ -13,7 +13,7 @@ import torch import torch.nn as nn import torch.nn.functional as F -from float8_experimental.config import TensorScalingType +from float8_experimental.config import ScalingType from float8_experimental.float8_linear_utils import convert_to_float8_training from float8_experimental.float8_tensor import Float8Tensor from float8_experimental.float8_utils import compute_error diff --git a/test/test_numerics_integration.py b/test/test_numerics_integration.py index 4d4446c..73e3211 100644 --- a/test/test_numerics_integration.py +++ b/test/test_numerics_integration.py @@ -14,11 +14,7 @@ import torch import torch.nn as nn import torch.nn.functional as F -from float8_experimental.config import ( - Float8LinearConfig, - Float8TensorCastConfig, - TensorScalingType, -) +from float8_experimental.config import CastConfig, Float8LinearConfig, ScalingType from float8_experimental.float8_linear_utils import ( convert_to_float8_training, linear_requires_sync, @@ -79,22 +75,22 @@ def init_weights(self, init_std: float): class TestFloat8NumericsIntegrationTest: @pytest.mark.parametrize( - "scaling_type_input", [TensorScalingType.DELAYED, TensorScalingType.DYNAMIC] + "scaling_type_input", [ScalingType.DELAYED, ScalingType.DYNAMIC] ) @pytest.mark.parametrize( - "scaling_type_weight", [TensorScalingType.DELAYED, TensorScalingType.DYNAMIC] + "scaling_type_weight", [ScalingType.DELAYED, ScalingType.DYNAMIC] ) @pytest.mark.parametrize( "scaling_type_grad_output", - [TensorScalingType.DELAYED, TensorScalingType.DYNAMIC], + [ScalingType.DELAYED, ScalingType.DYNAMIC], ) @pytest.mark.skipif(not is_H100, reason="requires H100 GPU") @pytest.mark.skipif(IS_ROCM, reason="test doesn't currently work on the ROCm stack") def test_encoder_fw_bw( self, - scaling_type_input: TensorScalingType, - scaling_type_weight: TensorScalingType, - scaling_type_grad_output: TensorScalingType, + scaling_type_input: ScalingType, + scaling_type_weight: 
ScalingType, + scaling_type_grad_output: ScalingType, ): # TODO(later): maybe add float16 back if it becomes important data_dtype = torch.bfloat16 @@ -114,11 +110,9 @@ def test_encoder_fw_bw( # for now just test the encoder to simplify things model_fp8 = copy.deepcopy(model_ref) config = Float8LinearConfig( - cast_config_input=Float8TensorCastConfig(scaling_type=scaling_type_input), - cast_config_weight=Float8TensorCastConfig(scaling_type=scaling_type_weight), - cast_config_grad_output=Float8TensorCastConfig( - scaling_type=scaling_type_grad_output - ), + cast_config_input=CastConfig(scaling_type=scaling_type_input), + cast_config_weight=CastConfig(scaling_type=scaling_type_weight), + cast_config_grad_output=CastConfig(scaling_type=scaling_type_grad_output), ) convert_to_float8_training( model_fp8,
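
For reference, a minimal usage sketch of the renamed API (`CastConfig`, `ScalingType`), mirroring the delayed-scaling configs used in the tests above; the model shape and `emulate=True` are illustrative, and call signatures are assumed to be unchanged by this rename:

```python
import torch.nn as nn

from float8_experimental.config import CastConfig, Float8LinearConfig, ScalingType
from float8_experimental.float8_linear_utils import (
    convert_to_float8_training,
    linear_requires_sync,
    sync_float8_amax_and_scale_history,
)

# Delayed scaling for input, weight, and grad_output, as in the test configs above.
config = Float8LinearConfig(
    cast_config_input=CastConfig(scaling_type=ScalingType.DELAYED),
    cast_config_weight=CastConfig(scaling_type=ScalingType.DELAYED),
    cast_config_grad_output=CastConfig(scaling_type=ScalingType.DELAYED),
    emulate=True,  # emulate fp8 gemms, as in the non-H100 test paths
)

# Swap eligible nn.Linear modules for Float8Linear.
model = nn.Sequential(
    nn.Linear(16, 32, bias=True), nn.ReLU(), nn.Linear(32, 16, bias=True)
)
model = convert_to_float8_training(model, config=config)

# Delayed scaling keeps amax/scale history buffers that need syncing before
# the forward pass; dynamic scaling does not require this.
if linear_requires_sync(config):
    sync_float8_amax_and_scale_history(model)
```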