Skip to content

Commit

Permalink
[cleanup][2/x] split float8 mm by delayed vs dynamic
Browse files Browse the repository at this point in the history
Summary:

Before this PR, the float8 mm logic was split by axiswise vs tensorwise.
After this PR, the float8 mm logic is split by dynamic vs non-dynamic scaling.

Motivation: there is more and more evidence that dynamic scaling will
be common to the most important lowp recipes. This PR is a step on the way
to making the dynamic scaling logic be simpler and easier to understand
in `torchao.float8`.

There are a lot of other simplifications to do, but stopping here to
keep the PR small.  This is a pure refactor without any logic changes.

Test Plan:

```
./test/float8/test_everything.sh
```

Reviewers:

Subscribers:

Tasks:

Tags:

ghstack-source-id: aa6d0ce2723c46983865909082482bba764c7084
ghstack-comment-id: 2564049757
Pull Request resolved: #1461
  • Loading branch information
vkuzo committed Jan 8, 2025
1 parent a2dcbf2 commit 31515ab
Show file tree
Hide file tree
Showing 2 changed files with 241 additions and 141 deletions.
270 changes: 132 additions & 138 deletions torchao/float8/float8_linear.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,77 +29,86 @@
from torchao.float8.fsdp_utils import WeightWithDynamicFloat8CastTensor


@torch._dynamo.allow_in_graph
class manual_float8_matmul_with_args_in_float8(torch.autograd.Function):
"""
Like torch.matmul, but with the arguments in float8
Note: this function requires all arguments to already be Float8Tensor objects,
which only supports tensorwise scaling granularity. The reason we didn't just make this
function support axiswise scaling granularity is because that would need very
careful testing of delayed scaling, as delayed scaling modifies buffers inplace.
In the future we'll probably have to unify, just postponing that until a future PR.
"""

@staticmethod
def forward(
ctx,
input_fp8,
weight_fp8_t,
):
ctx.save_for_backward(input_fp8, weight_fp8_t)
# the reshapes are needed in order to make the shapes compatible with
# torch.mm
orig_shape = input_fp8.shape
input_fp8_reshaped = input_fp8.reshape(-1, orig_shape[-1])
res_bits = torch.mm(input_fp8_reshaped, weight_fp8_t)
res_bits = res_bits.reshape(*orig_shape[:-1], res_bits.shape[-1])
return res_bits

@staticmethod
def backward(ctx, grad_output_fp8):
input_fp8, weight_fp8_t = ctx.saved_tensors

# the reshapes are needed in order to make the shapes compatible with
# torch.mm
grad_output_fp8_orig_shape = grad_output_fp8.shape
grad_output_fp8_reshaped = grad_output_fp8.reshape(
-1, grad_output_fp8_orig_shape[-1]
)

# calculate grad_input
grad_input = torch.mm(
grad_output_fp8_reshaped,
weight_fp8_t.t(),
)
grad_input = grad_input.reshape(
*grad_output_fp8_orig_shape[:-1], grad_input.shape[-1]
def _cast_input_to_float8(
input: torch.Tensor,
scaling_type_input: ScalingType,
config: Float8LinearConfig,
linear_mm_config: LinearMMConfig,
) -> torch.Tensor:
# Duplicate the autocast logic for F.linear, so that the output
# of our module has the right original precision
if torch.is_autocast_enabled():
# For now, hardcode to GPU's autocast dtype
# if we need CPU support in the future, we can add it
autocast_dtype = torch.get_autocast_gpu_dtype()
input = input.to(autocast_dtype)

if tensor_already_casted_to_fp8(input):
input_fp8 = input
else:
assert scaling_type_input is ScalingType.DYNAMIC
input_fp8 = hp_tensor_to_float8_dynamic(
input,
config.cast_config_input.target_dtype,
linear_mm_config,
gemm_input_role=GemmInputRole.INPUT,
)

input_fp8_orig_shape = input_fp8.shape
input_fp8_reshaped = input_fp8.reshape(-1, input_fp8_orig_shape[-1])

# calculate grad_weight
# Note: the variant below is slightly faster on LLaMa 3 8B pretraining
# compared to than calculating `grad_weight_t = input_fp8_t @ grad_output_fp8_reshaped`
grad_weight = torch.mm(
grad_output_fp8_reshaped.t(),
input_fp8_reshaped,
)

return grad_input, grad_weight.t()
return input_fp8


def _get_weight_scale(
weight: torch.Tensor,
scaling_type_weight: ScalingType,
config: Float8LinearConfig,
) -> Optional[torch.Tensor]:
if tensor_already_casted_to_fp8(weight):
return None
assert scaling_type_weight is ScalingType.DYNAMIC
return tensor_to_scale(weight, config.cast_config_weight.target_dtype)


def _cast_weight_to_float8_t(
weight: torch.Tensor,
config: Float8LinearConfig,
linear_mm_config: LinearMMConfig,
weight_scale: Optional[torch.Tensor] = None,
) -> torch.Tensor:
if tensor_already_casted_to_fp8(weight):
return weight.t()
weight_fp8 = hp_tensor_and_scale_to_float8(
weight,
weight_scale,
config.cast_config_weight.target_dtype,
linear_mm_config,
gemm_input_role=GemmInputRole.WEIGHT,
)
return weight_fp8.t()


def _cast_output_to_float8_in_bw(
output: torch.Tensor,
scaling_type_grad_output,
linear_mm_config: LinearMMConfig,
config: Float8LinearConfig,
) -> torch.Tensor:
assert scaling_type_grad_output is ScalingType.DYNAMIC
output = NoopFwToFloat8BwDynamic.apply(
output,
linear_mm_config,
config.cast_config_grad_output.target_dtype,
)
return output


@torch._dynamo.allow_in_graph
class manual_float8_matmul_with_args_in_hp(torch.autograd.Function):
class matmul_with_hp_or_float8_args(torch.autograd.Function):
"""
Like torch.matmul, but with the arguments in high precision and the cast to float8
defined inside of this function.
Like torch.matmul, but with the arguments in either high precision or float8.
* if the arguments are in high precision, they are cast to float8 according
to the specified config
* if the arguments are in float8, we assume the cast honored the config
Note: this function currently only supports dynamic scaling type and
axiswise granularity. We will have to unify this with other scaling types
and other granularities in a separate PR.
Only supports dynamic scaling, does not support delayed/static scaling.
"""

@staticmethod
Expand All @@ -116,7 +125,9 @@ def forward(

c = config

if c.cast_config_input.scaling_type is ScalingType.DISABLED:
if tensor_already_casted_to_fp8(input_hp):
input_maybe_fp8 = input_hp
elif c.cast_config_input.scaling_type is ScalingType.DISABLED:
input_maybe_fp8 = input_hp
else:
input_maybe_fp8 = hp_tensor_to_float8_dynamic(
Expand All @@ -130,7 +141,9 @@ def forward(
),
)

if c.cast_config_weight.scaling_type is ScalingType.DISABLED:
if tensor_already_casted_to_fp8(weight_hp_t):
weight_maybe_fp8_t = weight_hp_t
elif c.cast_config_weight.scaling_type is ScalingType.DISABLED:
weight_maybe_fp8_t = weight_hp_t
else:
weight_maybe_fp8_t = hp_tensor_to_float8_dynamic(
Expand Down Expand Up @@ -166,7 +179,10 @@ def backward(ctx, grad_output):
# calculate grad_input
#

if c.cast_config_grad_output.scaling_type is ScalingType.DISABLED:
if tensor_already_casted_to_fp8(grad_output_reshaped):
# TODO(future PR): this var name is axiswise-specific, fix it
grad_output_reshaped_maybe_fp8_dim0 = grad_output_reshaped
elif c.cast_config_grad_output.scaling_type is ScalingType.DISABLED:
grad_output_reshaped_maybe_fp8_dim0 = grad_output_reshaped
else:
grad_output_reshaped_maybe_fp8_dim0 = hp_tensor_to_float8_dynamic(
Expand All @@ -180,7 +196,10 @@ def backward(ctx, grad_output):
),
)

if c.cast_config_weight_for_grad_input.scaling_type is ScalingType.DISABLED:
if tensor_already_casted_to_fp8(weight_hp_t):
# TODO(future PR): var name is axiswise specific, fix it
weight_t_maybe_fp8_dim0 = weight_hp_t
elif c.cast_config_weight_for_grad_input.scaling_type is ScalingType.DISABLED:
weight_t_maybe_fp8_dim0 = weight_hp_t
else:
# Note: we need https://github.com/pytorch/pytorch/issues/136267
Expand Down Expand Up @@ -213,7 +232,10 @@ def backward(ctx, grad_output):
# calculate grad_weight
#

if (
if tensor_already_casted_to_fp8(grad_output_reshaped):
# TODO(future PR): var name is axiswise specific, fix it
grad_output_reshaped_maybe_fp8_dim1 = grad_output_reshaped
elif (
c.cast_config_grad_output_for_grad_weight.scaling_type
is ScalingType.DISABLED
):
Expand All @@ -230,7 +252,10 @@ def backward(ctx, grad_output):
),
)

if c.cast_config_input_for_grad_weight.scaling_type is ScalingType.DISABLED:
if tensor_already_casted_to_fp8(input_hp_reshaped):
# TODO(future PR): var name is axiswise specific, fix it
input_reshaped_maybe_fp8_dim1 = input_hp_reshaped
elif c.cast_config_input_for_grad_weight.scaling_type is ScalingType.DISABLED:
input_reshaped_maybe_fp8_dim1 = input_hp_reshaped
else:
input_reshaped_maybe_fp8_dim1 = hp_tensor_to_float8_dynamic(
Expand Down Expand Up @@ -303,58 +328,6 @@ def __init__(self, *args, **kwargs):
),
)

def cast_input_to_float8(self, input: torch.Tensor) -> torch.Tensor:
# Duplicate the autocast logic for F.linear, so that the output
# of our module has the right original precision
if torch.is_autocast_enabled():
# For now, hardcode to GPU's autocast dtype
# if we need CPU support in the future, we can add it
autocast_dtype = torch.get_autocast_gpu_dtype()
input = input.to(autocast_dtype)

if tensor_already_casted_to_fp8(input):
input_fp8 = input
else:
assert self.scaling_type_input is ScalingType.DYNAMIC
input_fp8 = hp_tensor_to_float8_dynamic(
input,
self.config.cast_config_input.target_dtype,
self.linear_mm_config,
gemm_input_role=GemmInputRole.INPUT,
)
return input_fp8

def get_weight_scale(self, weight: torch.Tensor) -> Optional[torch.Tensor]:
if tensor_already_casted_to_fp8(weight):
return None
assert self.scaling_type_weight is ScalingType.DYNAMIC
return tensor_to_scale(weight, self.config.cast_config_weight.target_dtype)

def cast_weight_to_float8_t(
self,
weight: torch.Tensor,
weight_scale: Optional[torch.Tensor] = None,
) -> torch.Tensor:
if tensor_already_casted_to_fp8(weight):
return weight.t()
weight_fp8 = hp_tensor_and_scale_to_float8(
weight,
weight_scale,
self.config.cast_config_weight.target_dtype,
self.linear_mm_config,
gemm_input_role=GemmInputRole.WEIGHT,
)
return weight_fp8.t()

def cast_output_to_float8_in_bw(self, output: torch.Tensor) -> torch.Tensor:
assert self.scaling_type_grad_output is ScalingType.DYNAMIC
output = NoopFwToFloat8BwDynamic.apply(
output,
self.linear_mm_config,
self.config.cast_config_grad_output.target_dtype,
)
return output

def forward(self, input: torch.Tensor) -> torch.Tensor:
has_any_axiswise_scaling = any(
cc.scaling_granularity is ScalingGranularity.AXISWISE
Expand All @@ -368,34 +341,55 @@ def forward(self, input: torch.Tensor) -> torch.Tensor:
]
)

input_maybe_fp8 = input
weight_maybe_fp8_t = self.weight.t()

# TODO(future PR): check for axiswise scaling for input, weight,
# grad_output separately instead of together
if not has_any_axiswise_scaling:
input_fp8 = self.cast_input_to_float8(input)
input_fp8 = _cast_input_to_float8(
input,
self.scaling_type_input,
self.config,
self.linear_mm_config,
)
# If force_recompute_fp8_weight_in_bwd, we only recompute the fp8 weight,
# weight_scale should be saved.
weight_scale = self.get_weight_scale(self.weight)
weight_scale = _get_weight_scale(
self.weight, self.scaling_type_weight, self.config
)

if self.config.force_recompute_fp8_weight_in_bwd:
weight_fp8_t = checkpoint.checkpoint(
self.cast_weight_to_float8_t,
_cast_weight_to_float8_t,
self.weight,
self.config,
self.linear_mm_config,
weight_scale,
)
else:
weight_fp8_t = self.cast_weight_to_float8_t(self.weight, weight_scale)
weight_fp8_t = _cast_weight_to_float8_t(
self.weight,
self.config,
self.linear_mm_config,
weight_scale,
)

output = manual_float8_matmul_with_args_in_float8.apply(
input_fp8, weight_fp8_t
)
input_maybe_fp8 = input_fp8
weight_maybe_fp8_t = weight_fp8_t

# Cast grad_output to float8_e5m2 during backward
output = self.cast_output_to_float8_in_bw(output)
output = matmul_with_hp_or_float8_args.apply(
input_maybe_fp8,
weight_maybe_fp8_t,
self.linear_mm_config,
self.config,
)

else:
# for now, axiswise path is separate
# TODO(future PR): unify to support mix and match
output = manual_float8_matmul_with_args_in_hp.apply(
input,
self.weight.t(),
if not has_any_axiswise_scaling:
# Cast grad_output to float8_e5m2 during backward
output = _cast_output_to_float8_in_bw(
output,
self.scaling_type_grad_output,
self.linear_mm_config,
self.config,
)
Expand Down
Loading

0 comments on commit 31515ab

Please sign in to comment.