diff --git a/test/float8/test_base.py b/test/float8/test_base.py
index ba6281deaf..e5f64abf57 100644
--- a/test/float8/test_base.py
+++ b/test/float8/test_base.py
@@ -26,6 +26,8 @@
 from torchao.float8.config import (
     CastConfig,
+    e4m3_dtype,
+    e5m2_dtype,
     Float8LinearConfig,
     Float8LinearRecipeName,
     ScalingGranularity,
@@ -53,8 +55,6 @@
 from torchao.float8.float8_utils import (
     FP8_TYPES,
     compute_error,
-    e4m3_dtype,
-    e5m2_dtype,
     fp8_tensor_statistics,
     tensor_to_scale,
 )
diff --git a/test/float8/test_compile.py b/test/float8/test_compile.py
index 6d21686e32..57362d6990 100644
--- a/test/float8/test_compile.py
+++ b/test/float8/test_compile.py
@@ -27,6 +27,7 @@
 from torchao.float8.config import (
     CastConfig,
+    e4m3_dtype,
     Float8LinearConfig,
     Float8LinearRecipeName,
     ScalingType,
@@ -47,7 +48,6 @@
     LinearMMConfig,
     ScaledMMConfig,
 )
-from torchao.float8.float8_utils import e4m3_dtype
 from torchao.testing.float8.test_utils import get_test_float8_linear_config
diff --git a/test/float8/test_dtensor.py b/test/float8/test_dtensor.py
index 5985a3f5b5..e0de749d0b 100644
--- a/test/float8/test_dtensor.py
+++ b/test/float8/test_dtensor.py
@@ -31,9 +31,9 @@
 from tqdm import tqdm

 from torchao.float8 import Float8LinearConfig
-from torchao.float8.config import CastConfig, ScalingType
+from torchao.float8.config import CastConfig, e4m3_dtype, ScalingType
 from torchao.float8.float8_linear_utils import convert_to_float8_training
-from torchao.float8.float8_scaling_utils import NoopFwToFloat8E5M2BwDynamic
+from torchao.float8.float8_scaling_utils import NoopFwToFloat8BwDynamic
 from torchao.float8.float8_tensor import (
     Float8Tensor,
     GemmInputRole,
@@ -45,7 +45,7 @@
     Float8RowwiseParallel,
     PrepareFloat8ModuleInput,
 )
-from torchao.float8.float8_utils import e4m3_dtype, tensor_to_scale
+from torchao.float8.float8_utils import tensor_to_scale
 from torchao.float8.fsdp_utils import WeightWithDynamicFloat8CastTensor
 from torchao.testing.float8.dtensor_utils import ToyModel
@@ -173,7 +173,7 @@ def _test_dtensor_fp8_autograd(mesh: DeviceMesh, size=16):
     )

     out = torch.nn.functional.linear(dist_x_fp8, dist_weight_fp8)
-    out = NoopFwToFloat8E5M2BwDynamic.apply(out, LinearMMConfig())
+    out = NoopFwToFloat8BwDynamic.apply(out, LinearMMConfig(), fp8_dtype)
     assert isinstance(out, DTensor), f"Expected DTensor, got {type(out)}"
     loss = torch.sum(torch.abs(out - dist_target))
     loss.backward()
diff --git a/torchao/float8/config.py b/torchao/float8/config.py
index 6a092d5f38..de57655e88 100644
--- a/torchao/float8/config.py
+++ b/torchao/float8/config.py
@@ -62,6 +62,7 @@ class CastConfig:
     scaling_type: ScalingType = ScalingType.DYNAMIC
     scaling_granularity: ScalingGranularity = ScalingGranularity.TENSORWISE
     static_scale: Optional[torch.Tensor] = None
+    dtype: Optional[torch.dtype] = None

     def short_str(self):
         return f"{self.scaling_type.short_str()}_{self.scaling_granularity.short_str()}"
@@ -75,6 +76,9 @@ def __post_init__(self):
         assert (
             self.scaling_type is ScalingType.DYNAMIC
         ), "only dynamic scaling type is supported for axiswise scaling granularity"
+        assert self.dtype is None or (
+            self.dtype.is_floating_point and self.dtype.itemsize == 1
+        ), "must specify a 8-bit floating-point dtype"


 @dataclass(frozen=True)
@@ -124,6 +128,12 @@ def __post_init__(self):
             self.e5m2_dtype = torch.float8_e5m2fnuz


+# User defined type for using the individual F8 type based on config
+type_config = Float8TypeConfig()
+e4m3_dtype = type_config.e4m3_dtype
+e5m2_dtype = type_config.e5m2_dtype
+
+
 @dataclass(frozen=True)
 class Float8GemmConfig:
     """
@@ -276,6 +286,20 @@ def __post_init__(self):
                 is_disabled_1 == is_disabled_2
             ), f"incompatible operand precision for {gemm_name}"

+        for cc1, cc2, operand_name, default_dtype in [
+            (cc_i, cc_i_gw, "input", e4m3_dtype),
+            (cc_w, cc_w_gi, "weight", e4m3_dtype),
+            (cc_go, cc_go_gw, "grad_output", e5m2_dtype),
+        ]:
+            # Override the dataclass being frozen
+            if cc1.dtype is None:
+                object.__setattr__(cc1, "dtype", default_dtype)
+            if cc2.dtype is None:
+                object.__setattr__(cc2, "dtype", default_dtype)
+            assert (
+                cc1.dtype == cc2.dtype
+            ), f"{operand_name} must be cast to the same dtype in both matmuls it's used in"
+
         if self.use_fp8_all_gather_only:
             assert self.enable_fsdp_float8_all_gather, "use_fp8_all_gather_only requires enable_fsdp_float8_all_gather to be True"
@@ -340,12 +364,14 @@ def recipe_name_to_linear_config(
         cc_w = CastConfig(scaling_granularity=ScalingGranularity.AXISWISE)

         # grad_input_hp = grad_output_fp8_axiswise_dim0 @ weight_fp8_tensorwise
-        cc_go = CastConfig(scaling_granularity=ScalingGranularity.AXISWISE)
+        cc_go = CastConfig(
+            scaling_granularity=ScalingGranularity.AXISWISE, dtype=e4m3_dtype
+        )
         cc_w_gi = CastConfig(scaling_granularity=ScalingGranularity.TENSORWISE)

         # grad_weight_hp = input_t_hp @ grad_output_hp
         cc_i_gw = CastConfig(scaling_type=ScalingType.DISABLED)
-        cc_go_gw = CastConfig(scaling_type=ScalingType.DISABLED)
+        cc_go_gw = CastConfig(scaling_type=ScalingType.DISABLED, dtype=e4m3_dtype)

         return Float8LinearConfig(
             cast_config_input=cc_i,
diff --git a/torchao/float8/float8_linear.py b/torchao/float8/float8_linear.py
index 776de917f1..c34c5be670 100644
--- a/torchao/float8/float8_linear.py
+++ b/torchao/float8/float8_linear.py
@@ -15,9 +15,9 @@
 from torchao.float8.config import Float8LinearConfig, ScalingGranularity, ScalingType
 from torchao.float8.distributed_utils import tensor_already_casted_to_fp8
 from torchao.float8.float8_scaling_utils import (
-    NoopFwToFloat8E5M2BwDelayed,
-    NoopFwToFloat8E5M2BwDynamic,
-    NoopFwToFloat8E5M2BwStatic,
+    NoopFwToFloat8BwDelayed,
+    NoopFwToFloat8BwDynamic,
+    NoopFwToFloat8BwStatic,
     _maybe_initialize_amaxes_scales_for_float8_cast,
     get_maybe_axiswise_dim,
     hp_tensor_to_float8_delayed,
@@ -32,8 +32,6 @@
     hp_tensor_and_scale_to_float8,
 )
 from torchao.float8.float8_utils import (
-    e4m3_dtype,
-    e5m2_dtype,
     tensor_to_amax,
     tensor_to_scale,
 )
@@ -136,7 +134,7 @@ def forward(
         else:
             input_maybe_fp8 = hp_tensor_to_float8_dynamic(
                 input_hp,
-                e4m3_dtype,
+                c.cast_config_input.dtype,
                 linear_mm_config,
                 gemm_input_role=GemmInputRole.INPUT,
                 scaling_granularity=c.cast_config_input.scaling_granularity,
@@ -150,7 +148,7 @@ def forward(
         else:
             weight_maybe_fp8_t = hp_tensor_to_float8_dynamic(
                 weight_hp_t,
-                e4m3_dtype,
+                c.cast_config_weight.dtype,
                 linear_mm_config,
                 gemm_input_role=GemmInputRole.WEIGHT,
                 scaling_granularity=c.cast_config_weight.scaling_granularity,
@@ -186,7 +184,7 @@ def backward(ctx, grad_output):
         else:
             grad_output_reshaped_maybe_fp8_dim0 = hp_tensor_to_float8_dynamic(
                 grad_output_reshaped,
-                e5m2_dtype,
+                c.cast_config_grad_output.dtype,
                 ctx.linear_mm_config,
                 gemm_input_role=GemmInputRole.GRAD_OUTPUT,
                 scaling_granularity=c.cast_config_grad_output.scaling_granularity,
@@ -204,7 +202,7 @@ def backward(ctx, grad_output):
             # the entire tensor.
             weight_t_maybe_fp8_dim0 = hp_tensor_to_float8_dynamic(
                 weight_hp_t,
-                e4m3_dtype,
+                c.cast_config_weight_for_grad_input.dtype,
                 ctx.linear_mm_config,
                 gemm_input_role=GemmInputRole.WEIGHT,
                 scaling_granularity=c.cast_config_weight_for_grad_input.scaling_granularity,
@@ -236,7 +234,7 @@ def backward(ctx, grad_output):
         else:
             grad_output_reshaped_maybe_fp8_dim1 = hp_tensor_to_float8_dynamic(
                 grad_output_reshaped,
-                e5m2_dtype,
+                c.cast_config_grad_output_for_grad_weight.dtype,
                 ctx.linear_mm_config,
                 gemm_input_role=GemmInputRole.GRAD_OUTPUT,
                 scaling_granularity=c.cast_config_grad_output_for_grad_weight.scaling_granularity,
@@ -250,7 +248,7 @@ def backward(ctx, grad_output):
         else:
             input_reshaped_maybe_fp8_dim1 = hp_tensor_to_float8_dynamic(
                 input_hp_reshaped,
-                e4m3_dtype,
+                c.cast_config_input_for_grad_weight.dtype,
                 ctx.linear_mm_config,
                 gemm_input_role=GemmInputRole.INPUT,
                 scaling_granularity=c.cast_config_input_for_grad_weight.scaling_granularity,
@@ -347,11 +345,9 @@ def create_buffers(self):
         # Default values for history buffers, see above TODO
         history_len = self.config.delayed_scaling_config.history_len
         device = self.weight.device
-        # TODO(future PR): dtype values below don't have the other float8
-        # flavors, fix it
-        default_input = torch.finfo(torch.float8_e4m3fn).max
-        default_weight = torch.finfo(torch.float8_e4m3fn).max
-        default_grad_output = torch.finfo(torch.float8_e5m2).max
+        default_input = torch.finfo(self.config.cast_config_input.dtype).max
+        default_weight = torch.finfo(self.config.cast_config_weight.dtype).max
+        default_grad_output = torch.finfo(self.config.cast_config_grad_output.dtype).max

         # Note: for now, create all the buffers if any are needed, to postpone
         # the work to make the scale and amax syncing and history calculation
@@ -438,14 +434,14 @@ def cast_input_to_float8(
                 self.fp8_amax_history_input,
                 self.fp8_scale_input,
                 scale_fn_name,
-                e4m3_dtype,
+                self.config.cast_config_input.dtype,
                 is_amax_initialized,
                 reduce_amax=True,
             )
             input_fp8 = hp_tensor_to_float8_delayed(
                 input,
                 self.fp8_scale_input,
-                e4m3_dtype,
+                self.config.cast_config_input.dtype,
                 self.fp8_amax_input,
                 linear_mm_config=self.linear_mm_config,
                 gemm_input_role=GemmInputRole.INPUT,
             )
@@ -453,14 +449,17 @@ def cast_input_to_float8(
         elif self.scaling_type_input is ScalingType.DYNAMIC:
             input_fp8 = hp_tensor_to_float8_dynamic(
                 input,
-                e4m3_dtype,
+                self.config.cast_config_input.dtype,
                 self.linear_mm_config,
                 gemm_input_role=GemmInputRole.INPUT,
             )
         else:
             assert self.scaling_type_input is ScalingType.STATIC
             input_fp8 = hp_tensor_to_float8_static(
-                input, self.fp8_static_scale_input, e4m3_dtype, self.linear_mm_config
+                input,
+                self.fp8_static_scale_input,
+                self.config.cast_config_input.dtype,
+                self.linear_mm_config,
             )
         return input_fp8
@@ -476,14 +475,14 @@ def get_weight_scale(self, weight: torch.Tensor) -> Optional[torch.Tensor]:
                 self.fp8_amax_history_weight,
                 self.fp8_scale_weight,
                 scale_fn_name,
-                e4m3_dtype,
+                self.config.cast_config_weight.dtype,
                 self.is_amax_initialized,
                 reduce_amax=True,
             )
             self.fp8_amax_weight.fill_(tensor_to_amax(weight))
             return self.fp8_scale_weight
         elif self.scaling_type_weight is ScalingType.DYNAMIC:
-            return tensor_to_scale(weight, e4m3_dtype)
+            return tensor_to_scale(weight, self.config.cast_config_weight.dtype)
         else:
             assert self.scaling_type_weight is ScalingType.STATIC
             return self.fp8_static_scale_weight
@@ -499,7 +498,7 @@ def cast_weight_to_float8_t(
         weight_fp8 = hp_tensor_and_scale_to_float8(
             weight,
             weight_scale,
-            e4m3_dtype,
+            self.config.cast_config_weight.dtype,
             self.linear_mm_config,
             gemm_input_role=GemmInputRole.WEIGHT,
         )
@@ -514,7 +513,7 @@ def cast_weight_to_original_t(self, weight: torch.Tensor):
     def cast_output_to_float8_in_bw(self, output: torch.Tensor) -> torch.Tensor:
         if self.scaling_type_grad_output is ScalingType.DELAYED:
             scale_fn_name = self.config.delayed_scaling_config.scale_fn_name
-            output = NoopFwToFloat8E5M2BwDelayed.apply(
+            output = NoopFwToFloat8BwDelayed.apply(
                 output,
                 self.fp8_amax_grad_output,
                 self.fp8_amax_history_grad_output,
@@ -522,15 +521,21 @@ def cast_output_to_float8_in_bw(self, output: torch.Tensor) -> torch.Tensor:
                 scale_fn_name,
                 self.is_amax_initialized,
                 self.linear_mm_config,
+                self.config.cast_config_grad_output.dtype,
             )
         elif self.scaling_type_grad_output is ScalingType.DYNAMIC:
-            output = NoopFwToFloat8E5M2BwDynamic.apply(output, self.linear_mm_config)
+            output = NoopFwToFloat8BwDynamic.apply(
+                output,
+                self.linear_mm_config,
+                self.config.cast_config_grad_output.dtype,
+            )
         else:
             assert self.scaling_type_grad_output is ScalingType.STATIC
-            output = NoopFwToFloat8E5M2BwStatic.apply(
+            output = NoopFwToFloat8BwStatic.apply(
                 output,
                 self.fp8_static_scale_grad_output,
                 self.linear_mm_config,
+                self.config.cast_config_grad_output.dtype,
             )
         return output
@@ -547,19 +552,16 @@ def float8_post_forward(self):
         return

     def forward_fp8_matmul(self, input: torch.Tensor) -> torch.Tensor:
-        has_any_axiswise_scaling = (
-            self.config.cast_config_input.scaling_granularity
-            is ScalingGranularity.AXISWISE
-            or self.config.cast_config_weight.scaling_granularity
-            is ScalingGranularity.AXISWISE
-            or self.config.cast_config_grad_output.scaling_granularity
-            is ScalingGranularity.AXISWISE
-            or self.config.cast_config_input_for_grad_weight.scaling_granularity
-            is ScalingGranularity.AXISWISE
-            or self.config.cast_config_weight_for_grad_input.scaling_granularity
-            is ScalingGranularity.AXISWISE
-            or self.config.cast_config_grad_output_for_grad_weight.scaling_granularity
-            is ScalingGranularity.AXISWISE
+        has_any_axiswise_scaling = any(
+            cc.scaling_granularity is ScalingGranularity.AXISWISE
+            for cc in [
+                self.config.cast_config_input,
+                self.config.cast_config_weight,
+                self.config.cast_config_grad_output,
+                self.config.cast_config_input_for_grad_weight,
+                self.config.cast_config_weight_for_grad_input,
+                self.config.cast_config_grad_output_for_grad_weight,
+            ]
         )

         if not has_any_axiswise_scaling:
@@ -682,6 +684,7 @@ def from_float(
                     WeightWithDynamicFloat8CastTensor(
                         new_mod.weight,
                         new_mod.linear_mm_config,
+                        new_mod.config.cast_config_weight.dtype,
                     )
                 )
             elif config.cast_config_weight.scaling_type is ScalingType.DELAYED:
@@ -692,6 +695,7 @@ def from_float(
                         new_mod.fp8_amax_history_weight,
                         new_mod.fp8_scale_weight,
                         new_mod.linear_mm_config,
+                        new_mod.config.cast_config_weight.dtype,
                         new_mod.is_amax_initialized,
                     )
                 )
@@ -702,6 +706,7 @@ def from_float(
                         new_mod.weight,
                         new_mod.fp8_static_scale_weight,
                         new_mod.linear_mm_config,
+                        new_mod.config.cast_config_weight.dtype,
                     )
                 )
diff --git a/torchao/float8/float8_linear_utils.py b/torchao/float8/float8_linear_utils.py
index c4fc88eb37..37453d8cfe 100644
--- a/torchao/float8/float8_linear_utils.py
+++ b/torchao/float8/float8_linear_utils.py
@@ -15,8 +15,6 @@
 from torchao.float8.float8_linear import Float8Linear
 from torchao.float8.float8_utils import (
     amax_history_to_scale_stack,
-    e4m3_dtype,
-    e5m2_dtype,
 )

 log = logging.getLogger(__name__)
@@ -227,6 +225,9 @@ def inner_func():
         fp8_weight_amax_history_stack = [None] * len(fp8_layers)
         fp8_grad_output_amax_history_stack = [None] * len(fp8_layers)
+        input_dtypes = set()
+        weight_dtypes = set()
+        grad_output_dtypes = set()
         scale_fn_recipes = set()

         for idx, child in enumerate(fp8_layers):
@@ -238,8 +239,15 @@ def inner_func():
             fp8_weight_amax_history_stack[idx] = child.fp8_amax_history_weight
             fp8_grad_output_amax_history_stack[idx] = child.fp8_amax_history_grad_output

+            input_dtypes.add(child.config.cast_config_input.dtype)
+            weight_dtypes.add(child.config.cast_config_weight.dtype)
+            grad_output_dtypes.add(child.config.cast_config_grad_output.dtype)
             scale_fn_recipes.add(child.config.delayed_scaling_config.scale_fn_name)

+        (input_dtype,) = input_dtypes
+        (weight_dtype,) = weight_dtypes
+        (grad_output_dtype,) = grad_output_dtypes
+
         if len(scale_fn_recipes) != 1:
             raise ValueError(
                 f"All layers must have the same scale_fn recipe, got {scale_fn_recipes}"
@@ -297,13 +305,13 @@ def inner_func():

         # Calculate the new scales from the updated history stacks
         new_input_scales = amax_history_to_scale_stack(
-            fp8_input_amax_history_stack, e4m3_dtype, scale_fn_recipe
+            fp8_input_amax_history_stack, input_dtype, scale_fn_recipe
         )
         new_weight_scales = amax_history_to_scale_stack(
-            fp8_weight_amax_history_stack, e4m3_dtype, scale_fn_recipe
+            fp8_weight_amax_history_stack, weight_dtype, scale_fn_recipe
         )
         new_grad_output_scales = amax_history_to_scale_stack(
-            fp8_grad_output_amax_history_stack, e5m2_dtype, scale_fn_recipe
+            fp8_grad_output_amax_history_stack, grad_output_dtype, scale_fn_recipe,
         )

         # Iterate through the layers and update the scales
diff --git a/torchao/float8/float8_scaling_utils.py b/torchao/float8/float8_scaling_utils.py
index c8fe61c8a4..dec03f1ebb 100644
--- a/torchao/float8/float8_scaling_utils.py
+++ b/torchao/float8/float8_scaling_utils.py
@@ -22,7 +22,6 @@
 )
 from torchao.float8.float8_utils import (
     amax_history_to_scale,
-    e5m2_dtype,
     tensor_to_amax,
     tensor_to_scale,
 )
@@ -182,7 +181,7 @@

 @torch._dynamo.allow_in_graph
-class NoopFwToFloat8E5M2BwDelayed(torch.autograd.Function):
+class NoopFwToFloat8BwDelayed(torch.autograd.Function):
     """
     Forward: no-op
     Backward: convert to float8_e5m2 with delayed scaling, initialize if needed
@@ -198,6 +197,7 @@ def forward(
         scale_fn_name,
         is_amax_initialized,
         linear_mm_config: LinearMMConfig,
+        dtype: torch.dtype,
     ):
         ctx.save_for_backward(
             fp8_amax_grad_output, fp8_amax_history_grad_output, fp8_scale_grad_output
         )
@@ -205,6 +205,7 @@ def forward(
         ctx.scale_fn_name = scale_fn_name
         ctx.is_amax_initialized = is_amax_initialized
         ctx.linear_mm_config = linear_mm_config
+        ctx.dtype = dtype
         return tensor

     @staticmethod
@@ -223,7 +224,7 @@ def backward(ctx, go):
             fp8_amax_history_grad_output,
             fp8_scale_grad_output,
             scale_fn_name,
-            e5m2_dtype,
+            ctx.dtype,
             is_amax_initialized,
             reduce_amax=True,
         )
@@ -233,16 +234,16 @@ def backward(ctx, go):
         res = hp_tensor_and_scale_to_float8(
             go,
             fp8_scale_grad_output,
-            e5m2_dtype,
+            ctx.dtype,
             ctx.linear_mm_config,
             GemmInputRole.GRAD_OUTPUT,
         )
-        empty_grads = None, None, None, None, None, None
+        empty_grads = None, None, None, None, None, None, None
         return res, *empty_grads


 @torch._dynamo.allow_in_graph
-class NoopFwToFloat8E5M2BwDynamic(torch.autograd.Function):
+class NoopFwToFloat8BwDynamic(torch.autograd.Function):
     """
     Forward: no-op
     Backward: convert to float8_e5m2 with dynamic scaling
@@ -253,27 +254,29 @@ def forward(
         ctx,
         tensor,
         linear_mm_config: LinearMMConfig,
+        dtype: torch.dtype,
     ):
         ctx.linear_mm_config = linear_mm_config
+        ctx.dtype = dtype
         return tensor

     @staticmethod
     def backward(ctx, gradY):
         if tensor_already_casted_to_fp8(gradY):
-            return gradY, None
-        gradY_scale = tensor_to_scale(gradY, e5m2_dtype)
+            return gradY, None, None
+        gradY_scale = tensor_to_scale(gradY, ctx.dtype)
         fp8_tensor = hp_tensor_and_scale_to_float8(
             gradY,
             gradY_scale,
-            e5m2_dtype,
+            ctx.dtype,
             ctx.linear_mm_config,
             GemmInputRole.GRAD_OUTPUT,
         )
-        return fp8_tensor, None
+        return fp8_tensor, None, None


 @torch._dynamo.allow_in_graph
-class NoopFwToFloat8E5M2BwStatic(torch.autograd.Function):
+class NoopFwToFloat8BwStatic(torch.autograd.Function):
     """
     Forward: no-op
     Backward: convert to float8_e5m2 with static scaling
@@ -285,21 +288,23 @@ def forward(
         tensor,
         scale,
         linear_mm_config: LinearMMConfig,
+        dtype: torch.dtype,
     ):
         ctx.save_for_backward(scale)
         ctx.linear_mm_config = linear_mm_config
+        ctx.dtype = dtype
         return tensor

     @staticmethod
     def backward(ctx, gradY):
         if tensor_already_casted_to_fp8(gradY):
-            return gradY, None
+            return gradY, None, None, None
         (gradY_scale,) = ctx.saved_tensors
         fp8_tensor = hp_tensor_and_scale_to_float8(
             gradY,
             gradY_scale,
-            e5m2_dtype,
+            ctx.dtype,
             ctx.linear_mm_config,
             GemmInputRole.GRAD_OUTPUT,
         )
-        return fp8_tensor, None, None
+        return fp8_tensor, None, None, None
diff --git a/torchao/float8/float8_tensor.py b/torchao/float8/float8_tensor.py
index 20f40330a8..fe2498e2b0 100644
--- a/torchao/float8/float8_tensor.py
+++ b/torchao/float8/float8_tensor.py
@@ -10,7 +10,6 @@
 from torch.distributed._tensor import DTensor

 from torchao.float8.float8_utils import (
-    e4m3_dtype,
     to_fp8_saturated,
 )
@@ -133,7 +132,7 @@ def forward(
         ctx,
         tensor: torch.Tensor,
         scale: torch.Tensor,
-        float8_dtype=e4m3_dtype,
+        float8_dtype: torch.dtype,
         linear_mm_config: Optional[LinearMMConfig] = None,
         gemm_input_role: Optional[GemmInputRole] = GemmInputRole.INPUT,
         axiswise_dim: Optional[int] = None,
@@ -213,7 +212,7 @@ def backward(ctx, g):
 def hp_tensor_and_scale_to_float8(
     hp_tensor: torch.Tensor,
     s: torch.Tensor,
-    float8_dtype=e4m3_dtype,
+    float8_dtype: torch.dtype,
     linear_mm_config: Optional[LinearMMConfig] = None,
     gemm_input_role: Optional[GemmInputRole] = GemmInputRole.INPUT,
     axiswise_dim: Optional[int] = None,
diff --git a/torchao/float8/float8_tensor_parallel.py b/torchao/float8/float8_tensor_parallel.py
index a3fc4ce7e5..814bd2869c 100644
--- a/torchao/float8/float8_tensor_parallel.py
+++ b/torchao/float8/float8_tensor_parallel.py
@@ -8,13 +8,12 @@
     RowwiseParallel,
 )

-from torchao.float8.config import ScalingType
+from torchao.float8.config import ScalingType, e4m3_dtype
 from torchao.float8.float8_scaling_utils import (
-    NoopFwToFloat8E5M2BwDynamic,
+    NoopFwToFloat8BwDynamic,
     hp_tensor_to_float8_dynamic,
 )
 from torchao.float8.float8_tensor import GemmInputRole
-from torchao.float8.float8_utils import e4m3_dtype

 # subclass the ColwiseParallel and RowwiseParallel classes
 # to add the float8 support
@@ -49,7 +48,7 @@ def _prepare_input_fn(
         input_tensor = hp_tensor_to_float8_dynamic(
             input_tensor,
-            e4m3_dtype,
+            mod.config.cast_config_input.dtype,
             mod.linear_mm_config,
             gemm_input_role=GemmInputRole.INPUT,
         )  # DTensor(Float8Tensor)
@@ -70,7 +69,9 @@ def _prepare_output_fn(output_layouts, use_local_output, mod, outputs, device_me
         )  # DTensor(torch.Tensor)

         # fwd noop bwd cast to DTensor(Float8Tensor)
-        outputs = NoopFwToFloat8E5M2BwDynamic.apply(outputs, mod.linear_mm_config)
+        outputs = NoopFwToFloat8BwDynamic.apply(
+            outputs, mod.linear_mm_config, mod.config.cast_config_grad_output.dtype
+        )

         # back to local tensor
         return outputs.to_local() if use_local_output else outputs
@@ -103,7 +104,7 @@ def _prepare_input_fn(
         input_tensor = hp_tensor_to_float8_dynamic(
             input_tensor,
-            e4m3_dtype,
+            mod.config.cast_config_input.dtype,
             mod.linear_mm_config,
             gemm_input_role=GemmInputRole.INPUT,
         )  # DTensor(Float8Tensor)
@@ -123,7 +124,9 @@ def _prepare_output_fn(output_layouts, use_local_output, mod, outputs, device_me
         outputs = outputs.redistribute(placements=output_layouts, async_op=True)

         # fwd noop bwd cast to DTensor(Float8Tensor)
-        outputs = NoopFwToFloat8E5M2BwDynamic.apply(outputs, mod.linear_mm_config)
+        outputs = NoopFwToFloat8BwDynamic.apply(
+            outputs, mod.linear_mm_config, mod.config.cast_config_grad_output.dtype
+        )

         # back to local tensor if use_local_output is True
         return outputs.to_local() if use_local_output else outputs
diff --git a/torchao/float8/float8_utils.py b/torchao/float8/float8_utils.py
index 29319f3814..90927659f8 100644
--- a/torchao/float8/float8_utils.py
+++ b/torchao/float8/float8_utils.py
@@ -10,7 +10,7 @@
 import torch.distributed as dist
 from torch.distributed._functional_collectives import AsyncCollectiveTensor, all_reduce

-from torchao.float8.config import Float8TypeConfig, ScalingGranularity
+from torchao.float8.config import ScalingGranularity

 # Helpful visualizer for debugging (only supports fp32):
 # https://www.h-schmidt.net/FloatConverter/IEEE754.html
@@ -28,12 +28,6 @@
 }


-# User defined type for using the individual F8 type based on config
-type_config = Float8TypeConfig()
-e4m3_dtype = type_config.e4m3_dtype
-e5m2_dtype = type_config.e5m2_dtype
-
-
 @torch.no_grad()
 def amax_to_scale(amax: torch.Tensor, float8_dtype: torch.dtype):
     """Converts the amax value of a tensor to the fp8 scale.
@@ -173,7 +167,7 @@ def compute_error(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:


 def fp8_tensor_statistics(
-    tensor: torch.Tensor, float8_dtype=e4m3_dtype
+    tensor: torch.Tensor, float8_dtype: torch.dtype
 ) -> Tuple[int, ...]:
     """Calculate FP8 tensor stats
diff --git a/torchao/float8/fsdp_utils.py b/torchao/float8/fsdp_utils.py
index 8c60995a86..62c0741b8a 100644
--- a/torchao/float8/fsdp_utils.py
+++ b/torchao/float8/fsdp_utils.py
@@ -5,7 +5,7 @@
 # LICENSE file in the root directory of this source tree.
 import math
-from typing import Any, List, Optional, Tuple
+from typing import Any, List, Optional, Set, Tuple

 import torch
 import torch.nn as nn
@@ -22,7 +22,7 @@
     LinearMMConfig,
     hp_tensor_and_scale_to_float8,
 )
-from torchao.float8.float8_utils import EPS, e4m3_dtype
+from torchao.float8.float8_utils import EPS


 @torch.no_grad()
@@ -54,9 +54,14 @@ def precompute_float8_dynamic_scale_for_fsdp(module: nn.Module) -> None:
         and isinstance(m.weight._local_tensor, WeightWithDynamicFloat8CastTensor)
     ]
     weights: List[DTensor] = [float8_linear.weight for float8_linear in float8_linears]
+    dtypes: Set[torch.dtype] = {
+        float8_linear.config.cast_config_weight.dtype
+        for float8_linear in float8_linears
+    }

     if not weights:
         return
+    (dtype,) = dtypes

     # inf-norm is equivalent to max(abs(w))
     max_weights = torch._foreach_norm(weights, ord=math.inf)  # Partial
@@ -69,7 +74,7 @@ def precompute_float8_dynamic_scale_for_fsdp(module: nn.Module) -> None:
     # upcast to float64 to ensure same numeric between compile and eager
     origin_dtype = amax_tensor.dtype
     amax_tensor = amax_tensor.to(torch.float64)
-    scale_tensor = torch.finfo(torch.float8_e4m3fn).max / amax_tensor  # Replicate
+    scale_tensor = torch.finfo(dtype).max / amax_tensor  # Replicate
     if origin_dtype is torch.float16:
         scale_tensor = torch.clamp(scale_tensor, max=torch.finfo(torch.float16).max)
     local_scale_tensor = scale_tensor.to_local().to(torch.float32)
@@ -134,6 +139,7 @@ def __new__(
         cls,
         tensor: torch.Tensor,
         linear_mm_config: LinearMMConfig,
+        dtype: torch.dtype,
         precomputed_scale: Optional[torch.Tensor] = None,
     ):
         return torch.Tensor._make_wrapper_subclass(
@@ -153,10 +159,12 @@ def __init__(
         self,
         tensor: torch.Tensor,
         linear_mm_config: LinearMMConfig,
+        dtype: torch.dtype,
         precomputed_scale: Optional[torch.Tensor] = None,
     ):
         self._tensor = tensor
         self._linear_mm_config = linear_mm_config
+        self._dtype = dtype
         # for dynamic scaling
         # `precompute_float8_dynamic_scale_for_fsdp` calculates scales
         # for all float8 parameters after optimizer step
@@ -166,9 +174,10 @@ def __init__(
     def __torch_dispatch__(cls, func, types, args, kwargs=None):
         if func == torch.ops.aten.detach.default:
             return WeightWithDynamicFloat8CastTensor(
-                args[0]._tensor, args[0]._linear_mm_config
+                args[0]._tensor, args[0]._linear_mm_config, args[0]._dtype
             )
         mm_config: Optional[LinearMMConfig] = None
+        dtype: Optional[torch.dtype] = None

         def unwrap(t):
             nonlocal mm_config
@@ -176,6 +185,11 @@ def unwrap(t):
             if mm_config is None:
                 mm_config = t._linear_mm_config
             else:
                 assert t._linear_mm_config == mm_config
+            nonlocal dtype
+            if dtype is None:
+                dtype = t._dtype
+            else:
+                assert t._dtype == dtype
             return t._tensor

         args, kwargs = pytree.tree_map_only(
@@ -185,40 +199,42 @@ def unwrap(t):
         if func not in _ops_to_preserve_subclass:
             return out
         return pytree.tree_map_only(
-            torch.Tensor, lambda x: WeightWithDynamicFloat8CastTensor(x, mm_config), out
+            torch.Tensor,
+            lambda x: WeightWithDynamicFloat8CastTensor(x, mm_config, dtype),
+            out,
         )

     def __tensor_flatten__(self):
+        tensors = ["_tensor"]
         if self._precomputed_scale:
-            return ["_tensor", "_precomputed_scale"], self._linear_mm_config
-        else:
-            return ["_tensor"], self._linear_mm_config
+            tensors.append("_precomputed_scale")
+        return tensors, {"mm_config": self._linear_mm_config, "dtype": self._dtype}

     @staticmethod
     def __tensor_unflatten__(inner_tensors, flatten_spec, outer_size, outer_stride):
-        mm_config = flatten_spec
         return WeightWithDynamicFloat8CastTensor(
             inner_tensors["_tensor"],
-            mm_config,
+            flatten_spec["mm_config"],
+            flatten_spec["dtype"],
             getattr(inner_tensors, "_precomputed_scale", None),
         )

     def __repr__(self):
-        return f"WeightWithDynamicFloat8CastTensor(tensor={self._tensor}, linear_mm_config={self._linear_mm_config})"
+        return f"WeightWithDynamicFloat8CastTensor(tensor={self._tensor}, linear_mm_config={self._linear_mm_config}, dtype={self._dtype})"

     def fsdp_pre_all_gather(self, mesh):
         if self._precomputed_scale is not None:
             float8_tensor = hp_tensor_and_scale_to_float8(
                 self._tensor,
                 self._precomputed_scale,
-                torch.float8_e4m3fn,
+                self._dtype,
                 self._linear_mm_config,
                 GemmInputRole.WEIGHT,
             )
         else:
             float8_tensor = hp_tensor_to_float8_dynamic(
                 self._tensor,
-                e4m3_dtype,
+                self._dtype,
                 self._linear_mm_config,
                 reduce_amax=True,
                 gemm_input_role=GemmInputRole.WEIGHT,
@@ -268,6 +284,7 @@ def __new__(
         amax_history_buffer: torch.Tensor,
         scale_buffer: torch.Tensor,
         linear_mm_config: LinearMMConfig,
+        dtype: torch.dtype,
         is_amax_initialized: bool,
     ):
         return torch.Tensor._make_wrapper_subclass(
@@ -290,6 +307,7 @@ def __init__(
         amax_history_buffer: torch.Tensor,
         scale_buffer: torch.Tensor,
         linear_mm_config: LinearMMConfig,
+        dtype: torch.dtype,
         is_amax_initialized: bool,
     ):
         self._tensor = tensor
@@ -297,6 +315,7 @@ def __init__(
         self._amax_history_buffer = amax_history_buffer
         self._scale_buffer = scale_buffer
         self._linear_mm_config = linear_mm_config
+        self._dtype = dtype
         # Note: is_amax_initialized is not a buffer to avoid data dependent
         # control flow visible to dynamo
@@ -312,9 +331,11 @@ def __torch_dispatch__(cls, func, types, args, kwargs=None):
                 args[0]._amax_history_buffer,
                 args[0]._scale_buffer,
                 args[0]._linear_mm_config,
+                args[0]._dtype,
                 args[0].is_amax_initialized,
             )
         mm_config: Optional[LinearMMConfig] = None
+        dtype: Optional[torch.dtype] = None
         amax_buffer: Optional[torch.Tensor] = None
         amax_history_buffer: Optional[torch.Tensor] = None
         scale_buffer: Optional[torch.Tensor] = None
@@ -326,6 +347,11 @@ def unwrap(t):
                 mm_config = t._linear_mm_config
             else:
                 assert t._linear_mm_config == mm_config
+            nonlocal dtype
+            if dtype is None:
+                dtype = t._dtype
+            else:
+                assert t._dtype == dtype
             nonlocal amax_buffer
             if amax_buffer is None:
                 amax_buffer = t._amax_buffer
@@ -354,6 +380,7 @@ def unwrap(t):
                 amax_history_buffer,
                 scale_buffer,
                 mm_config,
+                dtype,
                 is_amax_initialized,
             ),
             out,
@@ -369,6 +396,7 @@ def __tensor_flatten__(self):
             ],
             {
                 "mm_config": self._linear_mm_config,
+                "dtype": self._dtype,
                 "is_amax_initialized": self.is_amax_initialized,
             },
         )
@@ -381,11 +409,12 @@ def __tensor_unflatten__(inner_tensors, metadata, outer_size, outer_stride):
             inner_tensors["_amax_history_buffer"],
             inner_tensors["_scale_buffer"],
             metadata["mm_config"],
+            metadata["dtype"],
             metadata["is_amax_initialized"],
         )

     def __repr__(self):
-        return f"WeightWithDelayedFloat8CastTensor(tensor={self._tensor}, amax_buffer={self._amax_buffer}, scale_buffer={self._scale_buffer}, mm_config={self._linear_mm_config})"
+        return f"WeightWithDelayedFloat8CastTensor(tensor={self._tensor}, amax_buffer={self._amax_buffer}, scale_buffer={self._scale_buffer}, mm_config={self._linear_mm_config}, dtype={self._dtype})"

     def fsdp_pre_all_gather(self, mesh):
         # initialize if needed
@@ -401,7 +430,7 @@ def fsdp_pre_all_gather(self, mesh):
                 self._amax_history_buffer,
                 self._scale_buffer,
                 "max",  # TODO(before land): read this from parent
-                e4m3_dtype,
+                self._dtype,
                 self.is_amax_initialized,
                 reduce_amax=True,
             )
@@ -410,7 +439,7 @@ def fsdp_pre_all_gather(self, mesh):
         float8_tensor = hp_tensor_to_float8_delayed(
             self._tensor,
             self._scale_buffer,
-            e4m3_dtype,
+            self._dtype,
             self._amax_buffer,
             self._linear_mm_config,
             GemmInputRole.WEIGHT,
@@ -447,6 +476,7 @@ def __new__(
         tensor: torch.Tensor,
         static_scale: torch.Tensor,
         linear_mm_config: LinearMMConfig,
+        dtype: torch.dtype,
     ):
         return torch.Tensor._make_wrapper_subclass(
             cls,
@@ -466,19 +496,25 @@ def __init__(
         tensor: torch.Tensor,
         static_scale: torch.Tensor,
         linear_mm_config: LinearMMConfig,
+        dtype: torch.dtype,
     ):
         self._tensor = tensor
         self._static_scale = static_scale
         self._linear_mm_config = linear_mm_config
+        self._dtype = dtype

     @classmethod
     def __torch_dispatch__(cls, func, types, args, kwargs=None):
         if func == torch.ops.aten.detach.default:
             return WeightWithStaticFloat8CastTensor(
-                args[0]._tensor, args[0]._static_scale, args[0]._linear_mm_config
+                args[0]._tensor,
+                args[0]._static_scale,
+                args[0]._linear_mm_config,
+                args[0]._dtype,
             )
         static_scale: Optional[torch.Tensor] = None
         mm_config: Optional[LinearMMConfig] = None
+        dtype: Optional[torch.dtype] = None

         def unwrap(t):
             nonlocal static_scale
@@ -489,6 +525,11 @@ def unwrap(t):
             if static_scale is None:
                 static_scale = t._static_scale
             nonlocal mm_config
             if mm_config is None:
                 mm_config = t._linear_mm_config
             else:
                 assert t._linear_mm_config == mm_config
+            nonlocal dtype
+            if dtype is None:
+                dtype = t._dtype
+            else:
+                assert t._dtype == dtype
             return t._tensor

         args, kwargs = pytree.tree_map_only(
@@ -499,30 +540,35 @@ def unwrap(t):
         if func not in _ops_to_preserve_subclass:
             return out
         return pytree.tree_map_only(
             torch.Tensor,
-            lambda x: WeightWithStaticFloat8CastTensor(x, static_scale, mm_config),
+            lambda x: WeightWithStaticFloat8CastTensor(
+                x, static_scale, mm_config, dtype
+            ),
             out,
         )

     def __tensor_flatten__(self):
-        return ["_tensor", "_static_scale"], self._linear_mm_config
+        return ["_tensor", "_static_scale"], {
+            "mm_config": self._linear_mm_config,
+            "dtype": self._dtype,
+        }

     @staticmethod
     def __tensor_unflatten__(inner_tensors, flatten_spec, outer_size, outer_stride):
-        mm_config = flatten_spec
         return WeightWithStaticFloat8CastTensor(
             inner_tensors["_tensor"],
             inner_tensors["_static_scale"],
-            mm_config,
+            flatten_spec["mm_config"],
+            flatten_spec["dtype"],
         )

     def __repr__(self):
-        return f"WeightWithStaticFloat8CastTensor(tensor={self._tensor}, static_scale={self._static_scale}, linear_mm_config={self._linear_mm_config})"
+        return f"WeightWithStaticFloat8CastTensor(tensor={self._tensor}, static_scale={self._static_scale}, linear_mm_config={self._linear_mm_config}, dtype={self.dtype})"

     def fsdp_pre_all_gather(self, mesh):
         float8_tensor = hp_tensor_and_scale_to_float8(
             self._tensor,
             self._static_scale,
-            torch.float8_e4m3fn,
+            self._dtype,
             self._linear_mm_config,
             GemmInputRole.WEIGHT,
         )
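
For reference, the new per-operand dtype knob added by this patch can be driven from user code roughly as follows. This is an illustrative sketch, not part of the patch: it assumes the existing convert_to_float8_training(module, config=...) entry point and a CUDA device with float8 support; only CastConfig, Float8LinearConfig, e4m3_dtype and the cast_config_* field names are taken from the diff. Unset dtypes are filled in by Float8LinearConfig.__post_init__ (e4m3 for input/weight, e5m2 for grad_output), and the two cast configs of an operand must agree on dtype.

    import torch
    import torch.nn as nn

    from torchao.float8.config import CastConfig, Float8LinearConfig, e4m3_dtype
    from torchao.float8.float8_linear_utils import convert_to_float8_training

    # Cast grad_output to e4m3 instead of the default e5m2; input and weight
    # keep their defaults. Both matmuls that consume grad_output must use the
    # same dtype, so the companion cast config is set explicitly as well.
    config = Float8LinearConfig(
        cast_config_grad_output=CastConfig(dtype=e4m3_dtype),
        cast_config_grad_output_for_grad_weight=CastConfig(dtype=e4m3_dtype),
    )

    m = nn.Sequential(nn.Linear(2048, 4096, bias=False)).to("cuda", torch.bfloat16)
    convert_to_float8_training(m, config=config)  # swaps nn.Linear for Float8Linear

This mirrors what recipe_name_to_linear_config does for the rowwise recipe in the patch, where cc_go and cc_go_gw are both pinned to e4m3_dtype.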