From e511579ccec36961b585c929627a90f518a5652e Mon Sep 17 00:00:00 2001 From: Luca Wehrstedt Date: Fri, 22 Nov 2024 09:49:51 +0000 Subject: [PATCH 1/5] Update [ghstack-poisoned] --- test/float8/test_dtensor.py | 4 +- torchao/float8/config.py | 40 +++++++++--- torchao/float8/float8_linear.py | 78 ++++++++++++------------ torchao/float8/float8_linear_utils.py | 18 ++++-- torchao/float8/float8_scaling_utils.py | 20 +++--- torchao/float8/float8_tensor.py | 5 +- torchao/float8/float8_tensor_parallel.py | 10 +-- torchao/float8/float8_utils.py | 10 +-- torchao/float8/fsdp_utils.py | 77 ++++++++++++++++------- 9 files changed, 161 insertions(+), 101 deletions(-) diff --git a/test/float8/test_dtensor.py b/test/float8/test_dtensor.py index 70d6673fca..bb5086703b 100644 --- a/test/float8/test_dtensor.py +++ b/test/float8/test_dtensor.py @@ -28,7 +28,7 @@ from torchao.float8.float8_linear_utils import convert_to_float8_training from torchao.float8.config import CastConfig, ScalingType -from torchao.float8.float8_scaling_utils import NoopFwToFloat8E5M2BwDynamic +from torchao.float8.float8_scaling_utils import NoopFwToFloat8BwDynamic from torchao.float8.float8_tensor import ( Float8Tensor, GemmInputRole, @@ -197,7 +197,7 @@ def _test_dtensor_fp8_autograd(mesh: DeviceMesh, size=16): ) out = torch.nn.functional.linear(dist_x_fp8, dist_weight_fp8) - out = NoopFwToFloat8E5M2BwDynamic.apply(out, LinearMMConfig()) + out = NoopFwToFloat8BwDynamic.apply(out, LinearMMConfig(), fp8_dtype) assert isinstance(out, DTensor), f"Expected DTensor, got {type(out)}" loss = torch.sum(torch.abs(out - dist_target)) loss.backward() diff --git a/torchao/float8/config.py b/torchao/float8/config.py index 0a6ebda658..2524bbf35a 100644 --- a/torchao/float8/config.py +++ b/torchao/float8/config.py @@ -62,6 +62,7 @@ class CastConfig: scaling_type: ScalingType = ScalingType.DYNAMIC scaling_granularity: ScalingGranularity = ScalingGranularity.TENSORWISE static_scale: Optional[torch.Tensor] = None + dtype: torch.dtype = torch.uint8 # dummy dtype to satisfy typing def short_str(self): return f"{self.scaling_type.short_str()}_{self.scaling_granularity.short_str()}" @@ -75,6 +76,10 @@ def __post_init__(self): assert ( self.scaling_type is ScalingType.DYNAMIC ), "only dynamic scaling type is supported for axiswise scaling granularity" + if self.scaling_type is not ScalingType.DISABLED: + assert ( + self.dtype.is_floating_point and self.dtype.itemsize == 1 + ), "must specify a 8-bit floating-point dtype" @dataclass(frozen=True) @@ -124,6 +129,12 @@ def __post_init__(self): self.e5m2_dtype = torch.float8_e5m2fnuz +# User defined type for using the individual F8 type based on config +type_config = Float8TypeConfig() +e4m3_dtype = type_config.e4m3_dtype +e5m2_dtype = type_config.e5m2_dtype + + @dataclass(frozen=True) class Float8GemmConfig: """ @@ -158,13 +169,13 @@ class Float8LinearConfig: # 3. the same behavior holds for `cast_config_weight` and `cast_config_grad_output`. 
# # `input` - cast_config_input: CastConfig = CastConfig() + cast_config_input: CastConfig = CastConfig(dtype=e4m3_dtype) cast_config_input_for_grad_weight: Optional[CastConfig] = None # `weight` - cast_config_weight: CastConfig = CastConfig() + cast_config_weight: CastConfig = CastConfig(dtype=e4m3_dtype) cast_config_weight_for_grad_input: Optional[CastConfig] = None # `grad_output` - cast_config_grad_output: CastConfig = CastConfig() + cast_config_grad_output: CastConfig = CastConfig(dtype=e5m2_dtype) cast_config_grad_output_for_grad_weight: Optional[CastConfig] = None # @@ -279,6 +290,15 @@ def __post_init__(self): is_disabled_1 == is_disabled_2 ), f"incompatible operand precision for {gemm_name}" + for cc1, cc2, operand_name in [ + (cc_i, cc_i_gw, "input"), + (cc_w, cc_w_gi, "weight"), + (cc_go, cc_go_gw, "grad_output"), + ]: + assert ( + cc1.dtype == cc2.dtype + ), f"{operand_name} must be cast to the same dtype in both the matmuls it's used in" + if self.use_fp8_all_gather_only: assert self.enable_fsdp_float8_all_gather, "use_fp8_all_gather_only requires enable_fsdp_float8_all_gather to be True" @@ -315,9 +335,9 @@ def recipe_name_to_linear_config( elif recipe_name is Float8LinearRecipeName.ALL_AXISWISE: # dynamic axiswise scaling with the CUTLASS rowwise kernel - cc_i = CastConfig(scaling_granularity=ScalingGranularity.AXISWISE) - cc_w = CastConfig(scaling_granularity=ScalingGranularity.AXISWISE) - cc_go = CastConfig(scaling_granularity=ScalingGranularity.AXISWISE) + cc_i = CastConfig(scaling_granularity=ScalingGranularity.AXISWISE, dtype=e4m3_dtype) + cc_w = CastConfig(scaling_granularity=ScalingGranularity.AXISWISE, dtype=e4m3_dtype) + cc_go = CastConfig(scaling_granularity=ScalingGranularity.AXISWISE, dtype=e5m2_dtype) return Float8LinearConfig( cast_config_input=cc_i, @@ -339,12 +359,12 @@ def recipe_name_to_linear_config( # which is more amenable to fast kernels # output_hp = input_fp8_axiswise_dim0 @ weight_t_axiswise_dim1 - cc_i = CastConfig(scaling_granularity=ScalingGranularity.AXISWISE) - cc_w = CastConfig(scaling_granularity=ScalingGranularity.AXISWISE) + cc_i = CastConfig(scaling_granularity=ScalingGranularity.AXISWISE, dtype=e4m3_dtype) + cc_w = CastConfig(scaling_granularity=ScalingGranularity.AXISWISE, dtype=e4m3_dtype) # grad_input_hp = grad_output_fp8_axiswise_dim0 @ weight_fp8_tensorwise - cc_go = CastConfig(scaling_granularity=ScalingGranularity.AXISWISE) - cc_w_gi = CastConfig(scaling_granularity=ScalingGranularity.TENSORWISE) + cc_go = CastConfig(scaling_granularity=ScalingGranularity.AXISWISE, dtype=e4m3_dtype) + cc_w_gi = CastConfig(scaling_granularity=ScalingGranularity.TENSORWISE, dtype=e4m3_dtype) # grad_weight_hp = input_t_hp @ grad_output_hp cc_i_gw = CastConfig(scaling_type=ScalingType.DISABLED) diff --git a/torchao/float8/float8_linear.py b/torchao/float8/float8_linear.py index 299d007416..ae86d8d01a 100644 --- a/torchao/float8/float8_linear.py +++ b/torchao/float8/float8_linear.py @@ -14,9 +14,9 @@ from torchao.float8.config import Float8LinearConfig, ScalingGranularity, ScalingType from torchao.float8.float8_scaling_utils import ( - NoopFwToFloat8E5M2BwDelayed, - NoopFwToFloat8E5M2BwDynamic, - NoopFwToFloat8E5M2BwStatic, + NoopFwToFloat8BwDelayed, + NoopFwToFloat8BwDynamic, + NoopFwToFloat8BwStatic, _maybe_initialize_amaxes_scales_for_float8_cast, get_maybe_axiswise_dim, hp_tensor_to_float8_delayed, @@ -31,8 +31,6 @@ hp_tensor_and_scale_to_float8, ) from torchao.float8.float8_utils import ( - e4m3_dtype, - e5m2_dtype, tensor_to_amax, tensor_to_scale, 
) @@ -135,7 +133,7 @@ def forward( else: input_maybe_fp8 = hp_tensor_to_float8_dynamic( input_hp, - e4m3_dtype, + c.cast_config_input.dtype, linear_mm_config, gemm_input_role=GemmInputRole.INPUT, scaling_granularity=c.cast_config_input.scaling_granularity, @@ -149,7 +147,7 @@ def forward( else: weight_maybe_fp8_t = hp_tensor_to_float8_dynamic( weight_hp_t, - e4m3_dtype, + c.cast_config_weight.dtype, linear_mm_config, gemm_input_role=GemmInputRole.WEIGHT, scaling_granularity=c.cast_config_weight.scaling_granularity, @@ -185,7 +183,7 @@ def backward(ctx, grad_output): else: grad_output_reshaped_maybe_fp8_dim0 = hp_tensor_to_float8_dynamic( grad_output_reshaped, - e5m2_dtype, + c.cast_config_grad_output.dtype, ctx.linear_mm_config, gemm_input_role=GemmInputRole.GRAD_OUTPUT, scaling_granularity=c.cast_config_grad_output.scaling_granularity, @@ -203,7 +201,7 @@ def backward(ctx, grad_output): # the entire tensor. weight_t_maybe_fp8_dim0 = hp_tensor_to_float8_dynamic( weight_hp_t, - e4m3_dtype, + c.cast_config_weight_for_grad_input.dtype, ctx.linear_mm_config, gemm_input_role=GemmInputRole.WEIGHT, scaling_granularity=c.cast_config_weight_for_grad_input.scaling_granularity, @@ -235,7 +233,7 @@ def backward(ctx, grad_output): else: grad_output_reshaped_maybe_fp8_dim1 = hp_tensor_to_float8_dynamic( grad_output_reshaped, - e5m2_dtype, + c.cast_config_grad_output_for_grad_weight.dtype, ctx.linear_mm_config, gemm_input_role=GemmInputRole.GRAD_OUTPUT, scaling_granularity=c.cast_config_grad_output_for_grad_weight.scaling_granularity, @@ -249,7 +247,7 @@ def backward(ctx, grad_output): else: input_reshaped_maybe_fp8_dim1 = hp_tensor_to_float8_dynamic( input_hp_reshaped, - e4m3_dtype, + c.cast_config_input_for_grad_weight.dtype, ctx.linear_mm_config, gemm_input_role=GemmInputRole.INPUT, scaling_granularity=c.cast_config_input_for_grad_weight.scaling_granularity, @@ -354,11 +352,9 @@ def create_buffers(self): # Default values for history buffers, see above TODO history_len = self.config.delayed_scaling_config.history_len device = self.weight.device - # TODO(future PR): dtype values below don't have the other float8 - # flavors, fix it - default_input = torch.finfo(torch.float8_e4m3fn).max - default_weight = torch.finfo(torch.float8_e4m3fn).max - default_grad_output = torch.finfo(torch.float8_e5m2).max + default_input = torch.finfo(config.cast_config_input.dtype).max + default_weight = torch.finfo(config.cast_config_weight.dtype).max + default_grad_output = torch.finfo(config.cast_config_grad_output.dtype).max # Note: for now, create all the buffers if any are needed, to postpone # the work to make the scale and amax syncing and history calculation @@ -445,14 +441,14 @@ def cast_input_to_float8( self.fp8_amax_history_input, self.fp8_scale_input, scale_fn_name, - e4m3_dtype, + self.config.cast_config_input.dtype, is_amax_initialized, reduce_amax=True, ) input_fp8 = hp_tensor_to_float8_delayed( input, self.fp8_scale_input, - e4m3_dtype, + self.config.cast_config_input.dtype, self.fp8_amax_input, linear_mm_config=self.linear_mm_config, gemm_input_role=GemmInputRole.INPUT, @@ -460,14 +456,17 @@ def cast_input_to_float8( elif self.scaling_type_input is ScalingType.DYNAMIC: input_fp8 = hp_tensor_to_float8_dynamic( input, - e4m3_dtype, + self.config.cast_config_input.dtype, self.linear_mm_config, gemm_input_role=GemmInputRole.INPUT, ) else: assert self.scaling_type_input is ScalingType.STATIC input_fp8 = hp_tensor_to_float8_static( - input, self.fp8_static_scale_input, e4m3_dtype, self.linear_mm_config + input, + 
self.fp8_static_scale_input, + self.config.cast_config_input.dtype, + self.linear_mm_config, ) return input_fp8 @@ -483,14 +482,14 @@ def get_weight_scale(self, weight: torch.Tensor) -> Optional[torch.Tensor]: self.fp8_amax_history_weight, self.fp8_scale_weight, scale_fn_name, - e4m3_dtype, + self.config.cast_config_weight.dtype, self.is_amax_initialized, reduce_amax=True, ) self.fp8_amax_weight.fill_(tensor_to_amax(weight)) return self.fp8_scale_weight elif self.scaling_type_weight is ScalingType.DYNAMIC: - return tensor_to_scale(weight, e4m3_dtype) + return tensor_to_scale(weight, self.config.cast_config_weight.dtype) else: assert self.scaling_type_weight is ScalingType.STATIC return self.fp8_static_scale_weight @@ -506,7 +505,7 @@ def cast_weight_to_float8_t( weight_fp8 = hp_tensor_and_scale_to_float8( weight, weight_scale, - e4m3_dtype, + self.config.cast_config_weight.dtype, self.linear_mm_config, gemm_input_role=GemmInputRole.WEIGHT, ) @@ -521,7 +520,7 @@ def cast_weight_to_original_t(self, weight: torch.Tensor): def cast_output_to_float8_in_bw(self, output: torch.Tensor) -> torch.Tensor: if self.scaling_type_grad_output is ScalingType.DELAYED: scale_fn_name = self.config.delayed_scaling_config.scale_fn_name - output = NoopFwToFloat8E5M2BwDelayed.apply( + output = NoopFwToFloat8BwDelayed.apply( output, self.fp8_amax_grad_output, self.fp8_amax_history_grad_output, @@ -529,15 +528,17 @@ def cast_output_to_float8_in_bw(self, output: torch.Tensor) -> torch.Tensor: scale_fn_name, self.is_amax_initialized, self.linear_mm_config, + self.config.cast_config_grad_output.dtype, ) elif self.scaling_type_grad_output is ScalingType.DYNAMIC: - output = NoopFwToFloat8E5M2BwDynamic.apply(output, self.linear_mm_config) + output = NoopFwToFloat8BwDynamic.apply(output, self.linear_mm_config, self.config.cast_config_grad_output.dtype) else: assert self.scaling_type_grad_output is ScalingType.STATIC - output = NoopFwToFloat8E5M2BwStatic.apply( + output = NoopFwToFloat8BwStatic.apply( output, self.fp8_static_scale_grad_output, self.linear_mm_config, + self.config.cast_config_grad_output.dtype, ) return output @@ -563,19 +564,15 @@ def float8_post_forward(self): self.amax_and_scale_synced = False def forward_fp8_matmul(self, input: torch.Tensor) -> torch.Tensor: - has_any_axiswise_scaling = ( - self.config.cast_config_input.scaling_granularity - is ScalingGranularity.AXISWISE - or self.config.cast_config_weight.scaling_granularity - is ScalingGranularity.AXISWISE - or self.config.cast_config_grad_output.scaling_granularity - is ScalingGranularity.AXISWISE - or self.config.cast_config_input_for_grad_weight.scaling_granularity - is ScalingGranularity.AXISWISE - or self.config.cast_config_weight_for_grad_input.scaling_granularity - is ScalingGranularity.AXISWISE - or self.config.cast_config_grad_output_for_grad_weight.scaling_granularity - is ScalingGranularity.AXISWISE + has_any_axiswise_scaling = any( + cc.scaling_granularity is ScalingGranularity.AXISWISE for cc in [ + self.config.cast_config_input, + self.config.cast_config_weight, + self.config.cast_config_grad_output, + self.config.cast_config_input_for_grad_weight, + self.config.cast_config_weight_for_grad_input, + self.config.cast_config_grad_output_for_grad_weight, + ] ) if not has_any_axiswise_scaling: @@ -698,6 +695,7 @@ def from_float( WeightWithDynamicFloat8CastTensor( new_mod.weight, new_mod.linear_mm_config, + new_mod.config.cast_config_weight.dtype, ) ) elif config.cast_config_weight.scaling_type is ScalingType.DELAYED: @@ -708,6 +706,7 @@ def 
from_float( new_mod.fp8_amax_history_weight, new_mod.fp8_scale_weight, new_mod.linear_mm_config, + new_mod.config.cast_config_weight.dtype, new_mod.is_amax_initialized, ) ) @@ -718,6 +717,7 @@ def from_float( new_mod.weight, new_mod.fp8_static_scale_weight, new_mod.linear_mm_config, + new_mod.config.cast_config_weight.dtype, ) ) diff --git a/torchao/float8/float8_linear_utils.py b/torchao/float8/float8_linear_utils.py index d99e2a73e6..c157e669de 100644 --- a/torchao/float8/float8_linear_utils.py +++ b/torchao/float8/float8_linear_utils.py @@ -15,8 +15,6 @@ from torchao.float8.float8_linear import Float8Linear from torchao.float8.float8_utils import ( amax_history_to_scale_stack, - e4m3_dtype, - e5m2_dtype, ) log = logging.getLogger(__name__) @@ -224,6 +222,9 @@ def inner_func(): fp8_weight_amax_history_stack = [None] * len(fp8_layers) fp8_grad_output_amax_history_stack = [None] * len(fp8_layers) + input_dtypes = set() + weight_dtypes = set() + grad_output_dtypes = set() x_dtypes = set() scale_fn_recipes = set() @@ -236,9 +237,16 @@ def inner_func(): fp8_weight_amax_history_stack[idx] = child.fp8_amax_history_weight fp8_grad_output_amax_history_stack[idx] = child.fp8_amax_history_grad_output + input_dtypes.add(child.config.cast_config_input.dtype) + weight_dtypes.add(child.config.cast_config_weight.dtype) + grad_output_dtypes.add(child.config.cast_config_grad_output.dtype) x_dtypes.add(child.last_seen_input_dtype) scale_fn_recipes.add(child.config.delayed_scaling_config.scale_fn_name) + input_dtype, = input_dtypes + weight_dtype, = weight_dtypes + grad_output_dtype, = grad_output_dtypes + # TODO This way to get the activation dtype is not ideal if len(x_dtypes) != 1: raise ValueError( @@ -303,13 +311,13 @@ def inner_func(): # Calculate the new scales from the updated history stacks new_input_scales = amax_history_to_scale_stack( - fp8_input_amax_history_stack, e4m3_dtype, x_dtype, scale_fn_recipe + fp8_input_amax_history_stack, input_dtype, x_dtype, scale_fn_recipe ) new_weight_scales = amax_history_to_scale_stack( - fp8_weight_amax_history_stack, e4m3_dtype, x_dtype, scale_fn_recipe + fp8_weight_amax_history_stack, weight_dtype, x_dtype, scale_fn_recipe ) new_grad_output_scales = amax_history_to_scale_stack( - fp8_grad_output_amax_history_stack, e5m2_dtype, x_dtype, scale_fn_recipe + fp8_grad_output_amax_history_stack, grad_output_dtype, x_dtype, scale_fn_recipe ) # Iterate through the layers and update the scales diff --git a/torchao/float8/float8_scaling_utils.py b/torchao/float8/float8_scaling_utils.py index fa5eff733f..b4acce40ac 100644 --- a/torchao/float8/float8_scaling_utils.py +++ b/torchao/float8/float8_scaling_utils.py @@ -184,7 +184,7 @@ def _maybe_initialize_amaxes_scales_for_float8_cast( @torch._dynamo.allow_in_graph -class NoopFwToFloat8E5M2BwDelayed(torch.autograd.Function): +class NoopFwToFloat8BwDelayed(torch.autograd.Function): """ Forward: no-op Backward: convert to float8_e5m2 with delayed scaling, initialize if needed @@ -200,6 +200,7 @@ def forward( scale_fn_name, is_amax_initialized, linear_mm_config: LinearMMConfig, + dtype: torch.dtype, ): ctx.save_for_backward( fp8_amax_grad_output, fp8_amax_history_grad_output, fp8_scale_grad_output @@ -207,6 +208,7 @@ def forward( ctx.scale_fn_name = scale_fn_name ctx.is_amax_initialized = is_amax_initialized ctx.linear_mm_config = linear_mm_config + ctx.dtype return tensor @staticmethod @@ -225,7 +227,7 @@ def backward(ctx, go): fp8_amax_history_grad_output, fp8_scale_grad_output, scale_fn_name, - e5m2_dtype, + ctx.dtype, 
is_amax_initialized, reduce_amax=True, ) @@ -235,7 +237,7 @@ def backward(ctx, go): res = hp_tensor_and_scale_to_float8( go, fp8_scale_grad_output, - e5m2_dtype, + ctx.dtype, ctx.linear_mm_config, GemmInputRole.GRAD_OUTPUT, ) @@ -244,7 +246,7 @@ def backward(ctx, go): @torch._dynamo.allow_in_graph -class NoopFwToFloat8E5M2BwDynamic(torch.autograd.Function): +class NoopFwToFloat8BwDynamic(torch.autograd.Function): """ Forward: no-op Backward: convert to float8_e5m2 with dynamic scaling @@ -255,8 +257,10 @@ def forward( ctx, tensor, linear_mm_config: LinearMMConfig, + dtype: torch.dtype, ): ctx.linear_mm_config = linear_mm_config + ctx.dtype = dtype return tensor @staticmethod @@ -267,7 +271,7 @@ def backward(ctx, gradY): fp8_tensor = hp_tensor_and_scale_to_float8( gradY, gradY_scale, - e5m2_dtype, + ctx.dtype, ctx.linear_mm_config, GemmInputRole.GRAD_OUTPUT, ) @@ -275,7 +279,7 @@ def backward(ctx, gradY): @torch._dynamo.allow_in_graph -class NoopFwToFloat8E5M2BwStatic(torch.autograd.Function): +class NoopFwToFloat8BwStatic(torch.autograd.Function): """ Forward: no-op Backward: convert to float8_e5m2 with static scaling @@ -287,9 +291,11 @@ def forward( tensor, scale, linear_mm_config: LinearMMConfig, + dtype: torch.dtype, ): ctx.save_for_backward(scale) ctx.linear_mm_config = linear_mm_config + ctx.dtype = dtype return tensor @staticmethod @@ -300,7 +306,7 @@ def backward(ctx, gradY): fp8_tensor = hp_tensor_and_scale_to_float8( gradY, gradY_scale, - e5m2_dtype, + ctx.dtype, ctx.linear_mm_config, GemmInputRole.GRAD_OUTPUT, ) diff --git a/torchao/float8/float8_tensor.py b/torchao/float8/float8_tensor.py index 21de057fd5..1aed6cebdc 100644 --- a/torchao/float8/float8_tensor.py +++ b/torchao/float8/float8_tensor.py @@ -11,7 +11,6 @@ from torch.distributed._tensor import DTensor from torchao.float8.float8_utils import ( - e4m3_dtype, to_fp8_saturated, ) @@ -149,7 +148,7 @@ def forward( ctx, tensor: torch.Tensor, scale: torch.Tensor, - float8_dtype=e4m3_dtype, + float8_dtype: torch.dtype, linear_mm_config: Optional[LinearMMConfig] = None, gemm_input_role: Optional[GemmInputRole] = GemmInputRole.INPUT, axiswise_dim: Optional[int] = None, @@ -229,7 +228,7 @@ def backward(ctx, g): def hp_tensor_and_scale_to_float8( hp_tensor: torch.Tensor, s: torch.Tensor, - float8_dtype=e4m3_dtype, + float8_dtype: torch.dtype, linear_mm_config: Optional[LinearMMConfig] = None, gemm_input_role: Optional[GemmInputRole] = GemmInputRole.INPUT, axiswise_dim: Optional[int] = None, diff --git a/torchao/float8/float8_tensor_parallel.py b/torchao/float8/float8_tensor_parallel.py index a3fc4ce7e5..6d42781608 100644 --- a/torchao/float8/float8_tensor_parallel.py +++ b/torchao/float8/float8_tensor_parallel.py @@ -10,7 +10,7 @@ from torchao.float8.config import ScalingType from torchao.float8.float8_scaling_utils import ( - NoopFwToFloat8E5M2BwDynamic, + NoopFwToFloat8BwDynamic, hp_tensor_to_float8_dynamic, ) from torchao.float8.float8_tensor import GemmInputRole @@ -49,7 +49,7 @@ def _prepare_input_fn( input_tensor = hp_tensor_to_float8_dynamic( input_tensor, - e4m3_dtype, + mod.cast_config_input.dtype, mod.linear_mm_config, gemm_input_role=GemmInputRole.INPUT, ) # DTensor(Float8Tensor) @@ -70,7 +70,7 @@ def _prepare_output_fn(output_layouts, use_local_output, mod, outputs, device_me ) # DTensor(torch.Tensor) # fwd noop bwd cast to DTensor(Float8Tensor) - outputs = NoopFwToFloat8E5M2BwDynamic.apply(outputs, mod.linear_mm_config) + outputs = NoopFwToFloat8BwDynamic.apply(outputs, mod.linear_mm_config, 
mod.cast_config_grad_output.dtype) # back to local tensor return outputs.to_local() if use_local_output else outputs @@ -103,7 +103,7 @@ def _prepare_input_fn( input_tensor = hp_tensor_to_float8_dynamic( input_tensor, - e4m3_dtype, + mod.cast_config_input.dtype, mod.linear_mm_config, gemm_input_role=GemmInputRole.INPUT, ) # DTensor(Float8Tensor) @@ -123,7 +123,7 @@ def _prepare_output_fn(output_layouts, use_local_output, mod, outputs, device_me outputs = outputs.redistribute(placements=output_layouts, async_op=True) # fwd noop bwd cast to DTensor(Float8Tensor) - outputs = NoopFwToFloat8E5M2BwDynamic.apply(outputs, mod.linear_mm_config) + outputs = NoopFwToFloat8BwDynamic.apply(outputs, mod.linear_mm_config, mod.cast_config_grad_output.dtype) # back to local tensor if use_local_output is True return outputs.to_local() if use_local_output else outputs diff --git a/torchao/float8/float8_utils.py b/torchao/float8/float8_utils.py index 72cf5ad971..fc64c427f5 100644 --- a/torchao/float8/float8_utils.py +++ b/torchao/float8/float8_utils.py @@ -9,7 +9,7 @@ import torch import torch.distributed as dist -from torchao.float8.config import Float8TypeConfig, ScalingGranularity +from torchao.float8.config import ScalingGranularity, type_config, e4m3_dtype, e5m2_dtype # Helpful visualizer for debugging (only supports fp32): # https://www.h-schmidt.net/FloatConverter/IEEE754.html @@ -27,12 +27,6 @@ } -# User defined type for using the individual F8 type based on config -type_config = Float8TypeConfig() -e4m3_dtype = type_config.e4m3_dtype -e5m2_dtype = type_config.e5m2_dtype - - @torch.no_grad() def amax_to_scale( amax: torch.Tensor, float8_dtype: torch.dtype, orig_dtype: torch.dtype @@ -180,7 +174,7 @@ def compute_error(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor: def fp8_tensor_statistics( - tensor: torch.Tensor, float8_dtype=e4m3_dtype + tensor: torch.Tensor, float8_dtype: torch.dtype ) -> Tuple[int, ...]: """Calculate FP8 tensor stats diff --git a/torchao/float8/fsdp_utils.py b/torchao/float8/fsdp_utils.py index 8c60995a86..84d1a8bcb6 100644 --- a/torchao/float8/fsdp_utils.py +++ b/torchao/float8/fsdp_utils.py @@ -22,7 +22,7 @@ LinearMMConfig, hp_tensor_and_scale_to_float8, ) -from torchao.float8.float8_utils import EPS, e4m3_dtype +from torchao.float8.float8_utils import EPS @torch.no_grad() @@ -54,9 +54,11 @@ def precompute_float8_dynamic_scale_for_fsdp(module: nn.Module) -> None: and isinstance(m.weight._local_tensor, WeightWithDynamicFloat8CastTensor) ] weights: List[DTensor] = [float8_linear.weight for float8_linear in float8_linears] + dtypes: Set[torch.dtype] = {float8_linear.config.cast_config_weight.dtype for float8_linear in float8_linears} if not weights: return + dtype, = dtypes # inf-norm is equivalent to max(abs(w)) max_weights = torch._foreach_norm(weights, ord=math.inf) # Partial @@ -69,7 +71,7 @@ def precompute_float8_dynamic_scale_for_fsdp(module: nn.Module) -> None: # upcast to float64 to ensure same numeric between compile and eager origin_dtype = amax_tensor.dtype amax_tensor = amax_tensor.to(torch.float64) - scale_tensor = torch.finfo(torch.float8_e4m3fn).max / amax_tensor # Replicate + scale_tensor = torch.finfo(dtype).max / amax_tensor # Replicate if origin_dtype is torch.float16: scale_tensor = torch.clamp(scale_tensor, max=torch.finfo(torch.float16).max) local_scale_tensor = scale_tensor.to_local().to(torch.float32) @@ -134,6 +136,7 @@ def __new__( cls, tensor: torch.Tensor, linear_mm_config: LinearMMConfig, + dtype: torch.dtype, precomputed_scale: 
Optional[torch.Tensor] = None, ): return torch.Tensor._make_wrapper_subclass( @@ -153,10 +156,12 @@ def __init__( self, tensor: torch.Tensor, linear_mm_config: LinearMMConfig, + dtype: torch.dtype, precomputed_scale: Optional[torch.Tensor] = None, ): self._tensor = tensor self._linear_mm_config = linear_mm_config + self._dtype = dtype # for dynamic scaling # `precompute_float8_dynamic_scale_for_fsdp` calculates scales # for all float8 parameters after optimizer step @@ -166,9 +171,10 @@ def __init__( def __torch_dispatch__(cls, func, types, args, kwargs=None): if func == torch.ops.aten.detach.default: return WeightWithDynamicFloat8CastTensor( - args[0]._tensor, args[0]._linear_mm_config + args[0]._tensor, args[0]._linear_mm_config, args[0]._dtype ) mm_config: Optional[LinearMMConfig] = None + dtype: Optional[torch.dtype] = None def unwrap(t): nonlocal mm_config @@ -176,6 +182,11 @@ def unwrap(t): mm_config = t._linear_mm_config else: assert t._linear_mm_config == mm_config + nonlocal dtype + if mm_config is None: + dtype = t._dtype + else: + assert t._dtype == dtype return t._tensor args, kwargs = pytree.tree_map_only( @@ -185,40 +196,40 @@ def unwrap(t): if func not in _ops_to_preserve_subclass: return out return pytree.tree_map_only( - torch.Tensor, lambda x: WeightWithDynamicFloat8CastTensor(x, mm_config), out + torch.Tensor, lambda x: WeightWithDynamicFloat8CastTensor(x, mm_config, dtype), out ) def __tensor_flatten__(self): + tensors = ["_tensor"] if self._precomputed_scale: - return ["_tensor", "_precomputed_scale"], self._linear_mm_config - else: - return ["_tensor"], self._linear_mm_config + tensors.append("_precomputed_scale") + return tensors, {"mm_config": self._linear_mm_config, "dtype": self._dtype} @staticmethod def __tensor_unflatten__(inner_tensors, flatten_spec, outer_size, outer_stride): - mm_config = flatten_spec return WeightWithDynamicFloat8CastTensor( inner_tensors["_tensor"], - mm_config, + flatten_spec["mm_config"], + flatten_spec["dtype"], getattr(inner_tensors, "_precomputed_scale", None), ) def __repr__(self): - return f"WeightWithDynamicFloat8CastTensor(tensor={self._tensor}, linear_mm_config={self._linear_mm_config})" + return f"WeightWithDynamicFloat8CastTensor(tensor={self._tensor}, linear_mm_config={self._linear_mm_config}, dtype={self._dtype})" def fsdp_pre_all_gather(self, mesh): if self._precomputed_scale is not None: float8_tensor = hp_tensor_and_scale_to_float8( self._tensor, self._precomputed_scale, - torch.float8_e4m3fn, + self._dtype, self._linear_mm_config, GemmInputRole.WEIGHT, ) else: float8_tensor = hp_tensor_to_float8_dynamic( self._tensor, - e4m3_dtype, + self._dtype, self._linear_mm_config, reduce_amax=True, gemm_input_role=GemmInputRole.WEIGHT, @@ -268,6 +279,7 @@ def __new__( amax_history_buffer: torch.Tensor, scale_buffer: torch.Tensor, linear_mm_config: LinearMMConfig, + dtype: torch.dtype, is_amax_initialized: bool, ): return torch.Tensor._make_wrapper_subclass( @@ -290,6 +302,7 @@ def __init__( amax_history_buffer: torch.Tensor, scale_buffer: torch.Tensor, linear_mm_config: LinearMMConfig, + dtype: torch.dtype, is_amax_initialized: bool, ): self._tensor = tensor @@ -297,6 +310,7 @@ def __init__( self._amax_history_buffer = amax_history_buffer self._scale_buffer = scale_buffer self._linear_mm_config = linear_mm_config + self._dtype = dtype # Note: is_amax_initialized is not a buffer to avoid data dependent # control flow visible to dynamo @@ -312,9 +326,11 @@ def __torch_dispatch__(cls, func, types, args, kwargs=None): 
args[0]._amax_history_buffer, args[0]._scale_buffer, args[0]._linear_mm_config, + args[0]._dtype, args[0].is_amax_initialized, ) mm_config: Optional[LinearMMConfig] = None + dtype: Optional[torch.dtype] = None amax_buffer: Optional[torch.Tensor] = None amax_history_buffer: Optional[torch.Tensor] = None scale_buffer: Optional[torch.Tensor] = None @@ -326,6 +342,11 @@ def unwrap(t): mm_config = t._linear_mm_config else: assert t._linear_mm_config == mm_config + nonlocal dtype + if dtype is None: + dtype = t._dtype + else: + assert t._dtype == dtype nonlocal amax_buffer if amax_buffer is None: amax_buffer = t._amax_buffer @@ -354,6 +375,7 @@ def unwrap(t): amax_history_buffer, scale_buffer, mm_config, + dtype, is_amax_initialized, ), out, @@ -369,6 +391,7 @@ def __tensor_flatten__(self): ], { "mm_config": self._linear_mm_config, + "dtype": self._dtype, "is_amax_initialized": self.is_amax_initialized, }, ) @@ -381,11 +404,12 @@ def __tensor_unflatten__(inner_tensors, metadata, outer_size, outer_stride): inner_tensors["_amax_history_buffer"], inner_tensors["_scale_buffer"], metadata["mm_config"], + metadata["dtype"], metadata["is_amax_initialized"], ) def __repr__(self): - return f"WeightWithDelayedFloat8CastTensor(tensor={self._tensor}, amax_buffer={self._amax_buffer}, scale_buffer={self._scale_buffer}, mm_config={self._linear_mm_config})" + return f"WeightWithDelayedFloat8CastTensor(tensor={self._tensor}, amax_buffer={self._amax_buffer}, scale_buffer={self._scale_buffer}, mm_config={self._linear_mm_config}, dtype={self._dtype})" def fsdp_pre_all_gather(self, mesh): # initialize if needed @@ -401,7 +425,7 @@ def fsdp_pre_all_gather(self, mesh): self._amax_history_buffer, self._scale_buffer, "max", # TODO(before land): read this from parent - e4m3_dtype, + self._dtype, self.is_amax_initialized, reduce_amax=True, ) @@ -410,7 +434,7 @@ def fsdp_pre_all_gather(self, mesh): float8_tensor = hp_tensor_to_float8_delayed( self._tensor, self._scale_buffer, - e4m3_dtype, + self._dtype, self._amax_buffer, self._linear_mm_config, GemmInputRole.WEIGHT, @@ -447,6 +471,7 @@ def __new__( tensor: torch.Tensor, static_scale: torch.Tensor, linear_mm_config: LinearMMConfig, + dtype: torch.dtype, ): return torch.Tensor._make_wrapper_subclass( cls, @@ -466,19 +491,22 @@ def __init__( tensor: torch.Tensor, static_scale: torch.Tensor, linear_mm_config: LinearMMConfig, + dtype: torch.dtype, ): self._tensor = tensor self._static_scale = static_scale self._linear_mm_config = linear_mm_config + self._dtype = dtype @classmethod def __torch_dispatch__(cls, func, types, args, kwargs=None): if func == torch.ops.aten.detach.default: return WeightWithStaticFloat8CastTensor( - args[0]._tensor, args[0]._static_scale, args[0]._linear_mm_config + args[0]._tensor, args[0]._static_scale, args[0]._linear_mm_config, args[0]._dtype ) static_scale: Optional[torch.Tensor] = None mm_config: Optional[LinearMMConfig] = None + dtype: Optional[torch.dtype] = None def unwrap(t): nonlocal static_scale @@ -489,6 +517,11 @@ def unwrap(t): mm_config = t._linear_mm_config else: assert t._linear_mm_config == mm_config + nonlocal dtype + if dtype is None: + dtype = t._dtype + else: + assert t._dtype == dtype return t._tensor args, kwargs = pytree.tree_map_only( @@ -499,30 +532,30 @@ def unwrap(t): return out return pytree.tree_map_only( torch.Tensor, - lambda x: WeightWithStaticFloat8CastTensor(x, static_scale, mm_config), + lambda x: WeightWithStaticFloat8CastTensor(x, static_scale, mm_config, dtype), out, ) def __tensor_flatten__(self): - return 
["_tensor", "_static_scale"], self._linear_mm_config + return ["_tensor", "_static_scale"], {"mm_config": self._linear_mm_config, "dtype": self._dtype} @staticmethod def __tensor_unflatten__(inner_tensors, flatten_spec, outer_size, outer_stride): - mm_config = flatten_spec return WeightWithStaticFloat8CastTensor( inner_tensors["_tensor"], inner_tensors["_static_scale"], - mm_config, + flatten_spec["mm_config"], + flatten_spec["dtype"], ) def __repr__(self): - return f"WeightWithStaticFloat8CastTensor(tensor={self._tensor}, static_scale={self._static_scale}, linear_mm_config={self._linear_mm_config})" + return f"WeightWithStaticFloat8CastTensor(tensor={self._tensor}, static_scale={self._static_scale}, linear_mm_config={self._linear_mm_config}, dtype={self.dtype})" def fsdp_pre_all_gather(self, mesh): float8_tensor = hp_tensor_and_scale_to_float8( self._tensor, self._static_scale, - torch.float8_e4m3fn, + self._dtype, self._linear_mm_config, GemmInputRole.WEIGHT, ) From 51acb5b1ef6f64f743e71a6038fe07a6d3022c54 Mon Sep 17 00:00:00 2001 From: Luca Wehrstedt Date: Fri, 22 Nov 2024 09:59:03 +0000 Subject: [PATCH 2/5] Update [ghstack-poisoned] --- torchao/float8/config.py | 2 +- torchao/float8/float8_linear.py | 6 +++--- torchao/float8/float8_utils.py | 2 +- torchao/float8/fsdp_utils.py | 2 +- 4 files changed, 6 insertions(+), 6 deletions(-) diff --git a/torchao/float8/config.py b/torchao/float8/config.py index 2524bbf35a..cfefc9d01e 100644 --- a/torchao/float8/config.py +++ b/torchao/float8/config.py @@ -297,7 +297,7 @@ def __post_init__(self): ]: assert ( cc1.dtype == cc2.dtype - ), f"{operand_name} must be cast to the same dtype in both the matmuls it's used in" + ), f"{operand_name} must be cast to the same dtype in both matmuls it's used in" if self.use_fp8_all_gather_only: assert self.enable_fsdp_float8_all_gather, "use_fp8_all_gather_only requires enable_fsdp_float8_all_gather to be True" diff --git a/torchao/float8/float8_linear.py b/torchao/float8/float8_linear.py index ae86d8d01a..fa75910f87 100644 --- a/torchao/float8/float8_linear.py +++ b/torchao/float8/float8_linear.py @@ -352,9 +352,9 @@ def create_buffers(self): # Default values for history buffers, see above TODO history_len = self.config.delayed_scaling_config.history_len device = self.weight.device - default_input = torch.finfo(config.cast_config_input.dtype).max - default_weight = torch.finfo(config.cast_config_weight.dtype).max - default_grad_output = torch.finfo(config.cast_config_grad_output.dtype).max + default_input = torch.finfo(self.config.cast_config_input.dtype).max + default_weight = torch.finfo(self.config.cast_config_weight.dtype).max + default_grad_output = torch.finfo(self.config.cast_config_grad_output.dtype).max # Note: for now, create all the buffers if any are needed, to postpone # the work to make the scale and amax syncing and history calculation diff --git a/torchao/float8/float8_utils.py b/torchao/float8/float8_utils.py index fc64c427f5..06735c30d4 100644 --- a/torchao/float8/float8_utils.py +++ b/torchao/float8/float8_utils.py @@ -9,7 +9,7 @@ import torch import torch.distributed as dist -from torchao.float8.config import ScalingGranularity, type_config, e4m3_dtype, e5m2_dtype +from torchao.float8.config import ScalingGranularity # Helpful visualizer for debugging (only supports fp32): # https://www.h-schmidt.net/FloatConverter/IEEE754.html diff --git a/torchao/float8/fsdp_utils.py b/torchao/float8/fsdp_utils.py index 84d1a8bcb6..f881067394 100644 --- a/torchao/float8/fsdp_utils.py +++ 
b/torchao/float8/fsdp_utils.py @@ -5,7 +5,7 @@ # LICENSE file in the root directory of this source tree. import math -from typing import Any, List, Optional, Tuple +from typing import Any, List, Optional, Set, Tuple import torch import torch.nn as nn From 97b9cf83d5ab38394c4736f42d10dbf954a7ff92 Mon Sep 17 00:00:00 2001 From: Luca Wehrstedt Date: Fri, 22 Nov 2024 10:22:04 +0000 Subject: [PATCH 3/5] Update [ghstack-poisoned] --- test/float8/test_base.py | 4 ++-- test/float8/test_compile.py | 2 +- test/float8/test_dtensor.py | 4 ++-- torchao/float8/float8_scaling_utils.py | 3 +-- torchao/float8/float8_tensor_parallel.py | 3 +-- 5 files changed, 7 insertions(+), 9 deletions(-) diff --git a/test/float8/test_base.py b/test/float8/test_base.py index a0ea96baae..0c55c9c38a 100644 --- a/test/float8/test_base.py +++ b/test/float8/test_base.py @@ -25,6 +25,8 @@ from torchao.float8.config import ( CastConfig, + e4m3_dtype, + e5m2_dtype, Float8LinearConfig, Float8LinearRecipeName, recipe_name_to_linear_config, @@ -51,8 +53,6 @@ ) from torchao.float8.float8_utils import ( compute_error, - e4m3_dtype, - e5m2_dtype, fp8_tensor_statistics, FP8_TYPES, tensor_to_scale, diff --git a/test/float8/test_compile.py b/test/float8/test_compile.py index 6fecffa8f7..ce9935ca79 100644 --- a/test/float8/test_compile.py +++ b/test/float8/test_compile.py @@ -21,6 +21,7 @@ import torch.nn as nn from torchao.float8.config import ( CastConfig, + e4m3_dtype, Float8LinearConfig, ScalingType, Float8LinearRecipeName, @@ -41,7 +42,6 @@ GemmInputRole, ScaledMMConfig, ) -from torchao.float8.float8_utils import e4m3_dtype from torchao.testing.float8.test_utils import get_test_float8_linear_config from torch._dynamo.test_case import TestCase as DynamoTestCase diff --git a/test/float8/test_dtensor.py b/test/float8/test_dtensor.py index bb5086703b..92143e62b3 100644 --- a/test/float8/test_dtensor.py +++ b/test/float8/test_dtensor.py @@ -27,7 +27,7 @@ from torchao.float8 import Float8LinearConfig from torchao.float8.float8_linear_utils import convert_to_float8_training -from torchao.float8.config import CastConfig, ScalingType +from torchao.float8.config import CastConfig, e4m3_dtype, ScalingType from torchao.float8.float8_scaling_utils import NoopFwToFloat8BwDynamic from torchao.float8.float8_tensor import ( Float8Tensor, @@ -40,7 +40,7 @@ Float8RowwiseParallel, PrepareFloat8ModuleInput, ) -from torchao.float8.float8_utils import e4m3_dtype, tensor_to_scale +from torchao.float8.float8_utils import tensor_to_scale from torch.distributed._tensor import distribute_tensor, DTensor, Replicate, Shard from torch.distributed.device_mesh import DeviceMesh, init_device_mesh from torch.distributed.tensor.parallel import parallelize_module diff --git a/torchao/float8/float8_scaling_utils.py b/torchao/float8/float8_scaling_utils.py index b4acce40ac..8635f0d174 100644 --- a/torchao/float8/float8_scaling_utils.py +++ b/torchao/float8/float8_scaling_utils.py @@ -22,7 +22,6 @@ ) from torchao.float8.float8_utils import ( amax_history_to_scale, - e5m2_dtype, tensor_to_amax, tensor_to_scale, ) @@ -267,7 +266,7 @@ def forward( def backward(ctx, gradY): if tensor_already_casted_to_fp8(gradY): return gradY, None - gradY_scale = tensor_to_scale(gradY, e5m2_dtype) + gradY_scale = tensor_to_scale(gradY, ctx.dtype) fp8_tensor = hp_tensor_and_scale_to_float8( gradY, gradY_scale, diff --git a/torchao/float8/float8_tensor_parallel.py b/torchao/float8/float8_tensor_parallel.py index 81c4d51da0..145fe3e5bc 100644 --- a/torchao/float8/float8_tensor_parallel.py +++ 
b/torchao/float8/float8_tensor_parallel.py @@ -8,13 +8,12 @@ RowwiseParallel, ) -from torchao.float8.config import ScalingType +from torchao.float8.config import e4m3_dtype, ScalingType from torchao.float8.float8_scaling_utils import ( NoopFwToFloat8BwDynamic, hp_tensor_to_float8_dynamic, ) from torchao.float8.float8_tensor import GemmInputRole -from torchao.float8.float8_utils import e4m3_dtype # subclass the ColwiseParallel and RowwiseParallel classes # to add the float8 support From 810ad91e834032cc3bf5b0f4b119c9bd5d14450c Mon Sep 17 00:00:00 2001 From: Luca Wehrstedt Date: Fri, 22 Nov 2024 10:37:22 +0000 Subject: [PATCH 4/5] Update [ghstack-poisoned] --- torchao/float8/config.py | 51 ++++++++++-------------- torchao/float8/float8_tensor_parallel.py | 2 +- 2 files changed, 22 insertions(+), 31 deletions(-) diff --git a/torchao/float8/config.py b/torchao/float8/config.py index b411b7c618..eab4c6d187 100644 --- a/torchao/float8/config.py +++ b/torchao/float8/config.py @@ -62,7 +62,7 @@ class CastConfig: scaling_type: ScalingType = ScalingType.DYNAMIC scaling_granularity: ScalingGranularity = ScalingGranularity.TENSORWISE static_scale: Optional[torch.Tensor] = None - dtype: torch.dtype = torch.uint8 # dummy dtype to satisfy typing + dtype: Optional[torch.dtype] = None def short_str(self): return f"{self.scaling_type.short_str()}_{self.scaling_granularity.short_str()}" @@ -76,10 +76,9 @@ def __post_init__(self): assert ( self.scaling_type is ScalingType.DYNAMIC ), "only dynamic scaling type is supported for axiswise scaling granularity" - if self.scaling_type is not ScalingType.DISABLED: - assert ( - self.dtype.is_floating_point and self.dtype.itemsize == 1 - ), "must specify a 8-bit floating-point dtype" + assert ( + self.dtype is None or (self.dtype.is_floating_point and self.dtype.itemsize == 1) + ), "must specify a 8-bit floating-point dtype" @dataclass(frozen=True) @@ -169,13 +168,13 @@ class Float8LinearConfig: # 3. the same behavior holds for `cast_config_weight` and `cast_config_grad_output`. 
# # `input` - cast_config_input: CastConfig = CastConfig(dtype=e4m3_dtype) + cast_config_input: CastConfig = CastConfig() cast_config_input_for_grad_weight: Optional[CastConfig] = None # `weight` - cast_config_weight: CastConfig = CastConfig(dtype=e4m3_dtype) + cast_config_weight: CastConfig = CastConfig() cast_config_weight_for_grad_input: Optional[CastConfig] = None # `grad_output` - cast_config_grad_output: CastConfig = CastConfig(dtype=e5m2_dtype) + cast_config_grad_output: CastConfig = CastConfig() cast_config_grad_output_for_grad_weight: Optional[CastConfig] = None # @@ -290,11 +289,15 @@ def __post_init__(self): is_disabled_1 == is_disabled_2 ), f"incompatible operand precision for {gemm_name}" - for cc1, cc2, operand_name in [ - (cc_i, cc_i_gw, "input"), - (cc_w, cc_w_gi, "weight"), - (cc_go, cc_go_gw, "grad_output"), + for cc1, cc2, operand_name, default_dtype in [ + (cc_i, cc_i_gw, "input", e4m3_dtype), + (cc_w, cc_w_gi, "weight", e4m3_dtype), + (cc_go, cc_go_gw, "grad_output", e5m2_dtype), ]: + if cc1.dtype is None: + cc1.dtype = default_dtype + if cc2.dtype is None: + cc2.dtype = default_dtype assert ( cc1.dtype == cc2.dtype ), f"{operand_name} must be cast to the same dtype in both matmuls it's used in" @@ -335,15 +338,9 @@ def recipe_name_to_linear_config( elif recipe_name is Float8LinearRecipeName.ALL_AXISWISE: # dynamic axiswise scaling with the CUTLASS rowwise kernel - cc_i = CastConfig( - scaling_granularity=ScalingGranularity.AXISWISE, dtype=e4m3_dtype - ) - cc_w = CastConfig( - scaling_granularity=ScalingGranularity.AXISWISE, dtype=e4m3_dtype - ) - cc_go = CastConfig( - scaling_granularity=ScalingGranularity.AXISWISE, dtype=e5m2_dtype - ) + cc_i = CastConfig(scaling_granularity=ScalingGranularity.AXISWISE) + cc_w = CastConfig(scaling_granularity=ScalingGranularity.AXISWISE) + cc_go = CastConfig(scaling_granularity=ScalingGranularity.AXISWISE) return Float8LinearConfig( cast_config_input=cc_i, @@ -365,20 +362,14 @@ def recipe_name_to_linear_config( # which is more amenable to fast kernels # output_hp = input_fp8_axiswise_dim0 @ weight_t_axiswise_dim1 - cc_i = CastConfig( - scaling_granularity=ScalingGranularity.AXISWISE, dtype=e4m3_dtype - ) - cc_w = CastConfig( - scaling_granularity=ScalingGranularity.AXISWISE, dtype=e4m3_dtype - ) + cc_i = CastConfig(scaling_granularity=ScalingGranularity.AXISWISE) + cc_w = CastConfig(scaling_granularity=ScalingGranularity.AXISWISE) # grad_input_hp = grad_output_fp8_axiswise_dim0 @ weight_fp8_tensorwise cc_go = CastConfig( scaling_granularity=ScalingGranularity.AXISWISE, dtype=e4m3_dtype ) - cc_w_gi = CastConfig( - scaling_granularity=ScalingGranularity.TENSORWISE, dtype=e4m3_dtype - ) + cc_w_gi = CastConfig(scaling_granularity=ScalingGranularity.TENSORWISE) # grad_weight_hp = input_t_hp @ grad_output_hp cc_i_gw = CastConfig(scaling_type=ScalingType.DISABLED) diff --git a/torchao/float8/float8_tensor_parallel.py b/torchao/float8/float8_tensor_parallel.py index 145fe3e5bc..a46641f348 100644 --- a/torchao/float8/float8_tensor_parallel.py +++ b/torchao/float8/float8_tensor_parallel.py @@ -8,7 +8,7 @@ RowwiseParallel, ) -from torchao.float8.config import e4m3_dtype, ScalingType +from torchao.float8.config import ScalingType, e4m3_dtype from torchao.float8.float8_scaling_utils import ( NoopFwToFloat8BwDynamic, hp_tensor_to_float8_dynamic, From b9672f55eedf0ad65d55289690ee9ade4381fe7f Mon Sep 17 00:00:00 2001 From: Luca Wehrstedt Date: Fri, 22 Nov 2024 13:19:31 +0000 Subject: [PATCH 5/5] Update [ghstack-poisoned] --- torchao/float8/config.py 
| 11 ++++++----- torchao/float8/float8_scaling_utils.py | 12 ++++++------ torchao/float8/float8_tensor_parallel.py | 8 ++++---- torchao/float8/fsdp_utils.py | 2 +- 4 files changed, 17 insertions(+), 16 deletions(-) diff --git a/torchao/float8/config.py b/torchao/float8/config.py index eab4c6d187..1011e93524 100644 --- a/torchao/float8/config.py +++ b/torchao/float8/config.py @@ -76,8 +76,8 @@ def __post_init__(self): assert ( self.scaling_type is ScalingType.DYNAMIC ), "only dynamic scaling type is supported for axiswise scaling granularity" - assert ( - self.dtype is None or (self.dtype.is_floating_point and self.dtype.itemsize == 1) + assert self.dtype is None or ( + self.dtype.is_floating_point and self.dtype.itemsize == 1 ), "must specify a 8-bit floating-point dtype" @@ -294,10 +294,11 @@ def __post_init__(self): (cc_w, cc_w_gi, "weight", e4m3_dtype), (cc_go, cc_go_gw, "grad_output", e5m2_dtype), ]: + # Override the dataclass being frozen if cc1.dtype is None: - cc1.dtype = default_dtype + object.__setattr__(cc1, "dtype", default_dtype) if cc2.dtype is None: - cc2.dtype = default_dtype + object.__setattr__(cc2, "dtype", default_dtype) assert ( cc1.dtype == cc2.dtype ), f"{operand_name} must be cast to the same dtype in both matmuls it's used in" @@ -373,7 +374,7 @@ def recipe_name_to_linear_config( # grad_weight_hp = input_t_hp @ grad_output_hp cc_i_gw = CastConfig(scaling_type=ScalingType.DISABLED) - cc_go_gw = CastConfig(scaling_type=ScalingType.DISABLED) + cc_go_gw = CastConfig(scaling_type=ScalingType.DISABLED, dtype=e4m3_dtype) return Float8LinearConfig( cast_config_input=cc_i, diff --git a/torchao/float8/float8_scaling_utils.py b/torchao/float8/float8_scaling_utils.py index 8635f0d174..2da7c6028b 100644 --- a/torchao/float8/float8_scaling_utils.py +++ b/torchao/float8/float8_scaling_utils.py @@ -207,7 +207,7 @@ def forward( ctx.scale_fn_name = scale_fn_name ctx.is_amax_initialized = is_amax_initialized ctx.linear_mm_config = linear_mm_config - ctx.dtype + ctx.dtype = dtype return tensor @staticmethod @@ -240,7 +240,7 @@ def backward(ctx, go): ctx.linear_mm_config, GemmInputRole.GRAD_OUTPUT, ) - empty_grads = None, None, None, None, None, None + empty_grads = None, None, None, None, None, None, None return res, *empty_grads @@ -265,7 +265,7 @@ def forward( @staticmethod def backward(ctx, gradY): if tensor_already_casted_to_fp8(gradY): - return gradY, None + return gradY, None, None gradY_scale = tensor_to_scale(gradY, ctx.dtype) fp8_tensor = hp_tensor_and_scale_to_float8( gradY, @@ -274,7 +274,7 @@ def backward(ctx, gradY): ctx.linear_mm_config, GemmInputRole.GRAD_OUTPUT, ) - return fp8_tensor, None + return fp8_tensor, None, None @torch._dynamo.allow_in_graph @@ -300,7 +300,7 @@ def forward( @staticmethod def backward(ctx, gradY): if tensor_already_casted_to_fp8(gradY): - return gradY, None + return gradY, None, None, None (gradY_scale,) = ctx.saved_tensors fp8_tensor = hp_tensor_and_scale_to_float8( gradY, @@ -309,4 +309,4 @@ def backward(ctx, gradY): ctx.linear_mm_config, GemmInputRole.GRAD_OUTPUT, ) - return fp8_tensor, None, None + return fp8_tensor, None, None, None diff --git a/torchao/float8/float8_tensor_parallel.py b/torchao/float8/float8_tensor_parallel.py index a46641f348..814bd2869c 100644 --- a/torchao/float8/float8_tensor_parallel.py +++ b/torchao/float8/float8_tensor_parallel.py @@ -48,7 +48,7 @@ def _prepare_input_fn( input_tensor = hp_tensor_to_float8_dynamic( input_tensor, - mod.cast_config_input.dtype, + mod.config.cast_config_input.dtype, 
mod.linear_mm_config, gemm_input_role=GemmInputRole.INPUT, ) # DTensor(Float8Tensor) @@ -70,7 +70,7 @@ def _prepare_output_fn(output_layouts, use_local_output, mod, outputs, device_me # fwd noop bwd cast to DTensor(Float8Tensor) outputs = NoopFwToFloat8BwDynamic.apply( - outputs, mod.linear_mm_config, mod.cast_config_grad_output.dtype + outputs, mod.linear_mm_config, mod.config.cast_config_grad_output.dtype ) # back to local tensor @@ -104,7 +104,7 @@ def _prepare_input_fn( input_tensor = hp_tensor_to_float8_dynamic( input_tensor, - mod.cast_config_input.dtype, + mod.config.cast_config_input.dtype, mod.linear_mm_config, gemm_input_role=GemmInputRole.INPUT, ) # DTensor(Float8Tensor) @@ -125,7 +125,7 @@ def _prepare_output_fn(output_layouts, use_local_output, mod, outputs, device_me # fwd noop bwd cast to DTensor(Float8Tensor) outputs = NoopFwToFloat8BwDynamic.apply( - outputs, mod.linear_mm_config, mod.cast_config_grad_output.dtype + outputs, mod.linear_mm_config, mod.config.cast_config_grad_output.dtype ) # back to local tensor if use_local_output is True diff --git a/torchao/float8/fsdp_utils.py b/torchao/float8/fsdp_utils.py index 6d353bca5c..62c0741b8a 100644 --- a/torchao/float8/fsdp_utils.py +++ b/torchao/float8/fsdp_utils.py @@ -186,7 +186,7 @@ def unwrap(t): else: assert t._linear_mm_config == mm_config nonlocal dtype - if mm_config is None: + if dtype is None: dtype = t._dtype else: assert t._dtype == dtype
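
The core pattern this series lands is making the float8 cast dtype configurable per operand: CastConfig grows an optional dtype field, and Float8LinearConfig.__post_init__ fills in per-operand defaults (e4m3 for input/weight, e5m2 for grad_output) on the frozen dataclasses via object.__setattr__. The sketch below is a minimal stand-alone illustration of that defaulting pattern only, not code from the patch; CastConfigSketch and LinearConfigSketch are hypothetical stand-ins for the real CastConfig and Float8LinearConfig and carry none of the other fields or checks.

# Minimal sketch of the dtype-defaulting pattern used by this series.
# The *Sketch classes are hypothetical stand-ins, not the torchao classes.
from dataclasses import dataclass
from typing import Optional

import torch


@dataclass(frozen=True)
class CastConfigSketch:
    # None means "fall back to the per-operand default float8 dtype"
    dtype: Optional[torch.dtype] = None

    def __post_init__(self):
        # accept either "unset" or an 8-bit floating-point dtype
        assert self.dtype is None or (
            self.dtype.is_floating_point and self.dtype.itemsize == 1
        ), "must specify an 8-bit floating-point dtype"


@dataclass(frozen=True)
class LinearConfigSketch:
    cast_config_input: CastConfigSketch = CastConfigSketch()
    cast_config_grad_output: CastConfigSketch = CastConfigSketch()

    def __post_init__(self):
        # Fill in per-operand defaults on the frozen CastConfig instances,
        # mirroring the object.__setattr__ escape hatch the real config uses.
        for cc, default_dtype in [
            (self.cast_config_input, torch.float8_e4m3fn),
            (self.cast_config_grad_output, torch.float8_e5m2),
        ]:
            if cc.dtype is None:
                object.__setattr__(cc, "dtype", default_dtype)


cfg = LinearConfigSketch()
assert cfg.cast_config_input.dtype is torch.float8_e4m3fn
assert cfg.cast_config_grad_output.dtype is torch.float8_e5m2

# Explicit per-operand overrides still pass validation and are left untouched:
cfg2 = LinearConfigSketch(
    cast_config_grad_output=CastConfigSketch(dtype=torch.float8_e4m3fn)
)
assert cfg2.cast_config_grad_output.dtype is torch.float8_e4m3fn

Because CastConfig is frozen, object.__setattr__ is the standard escape hatch for assigning the default after construction; the series uses the same trick, which keeps user-facing configs that omit dtype backward compatible while letting recipes pin a specific 8-bit dtype per operand.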