make all 3 gemms in Float8Linear support configurability, not user facing (#315)

Summary:
Pull Request resolved: #315

This PR adds the plumbing needed to eventually make all 3 gemms in a linear fwd/bwd configurable:
1. add `LinearMMConfig` to `Float8Tensor` to tie together the three `ScaledMMConfig` objects, one per gemm
2. add `GemmInputRole` to `Float8Tensor` to specify how to pick the right config
3. plumb all of these throughout the codebase
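
As a rough sketch of the intent: the real definitions live in float8_tensor.py and are not part of this diff, so the field names, defaults, and the selection helper below are illustrative assumptions, not the library's API. Only the type names (ScaledMMConfig, LinearMMConfig, GemmInputRole) and the role members X/W/DL_DY come from the diff.

from collections import namedtuple
from enum import Enum

# Illustrative stand-in for the per-gemm config; the real ScaledMMConfig is
# defined in float8_tensor.py and its exact fields are not shown in this diff.
ScaledMMConfig = namedtuple(
    "ScaledMMConfig",
    ["emulate", "use_fast_accum", "fp8_output", "pad_inner_dim"],
    defaults=[False, False, False, False],
)

# Ties together one ScaledMMConfig per gemm in a linear fwd/bwd:
# output (from x and w), grad_input (from dL_dY and w), grad_weight (from x and dL_dY).
LinearMMConfig = namedtuple(
    "LinearMMConfig",
    ["output", "grad_input", "grad_weight"],
    defaults=[ScaledMMConfig(), ScaledMMConfig(), ScaledMMConfig()],
)

# Tags each Float8Tensor operand with its role so a matmul override can pick
# the right per-gemm config from the roles of its two operands.
class GemmInputRole(Enum):
    X = "x"
    W = "w"
    DL_DY = "dL_dY"

def choose_scaled_mm_config(a_role, a_cfg, b_role, b_cfg):
    # Illustrative selection logic: the pair of roles identifies which gemm
    # is being computed, and therefore which ScaledMMConfig applies.
    roles = {a_role, b_role}
    if roles == {GemmInputRole.X, GemmInputRole.W}:
        return a_cfg.output
    if roles == {GemmInputRole.DL_DY, GemmInputRole.W}:
        return a_cfg.grad_input
    if roles == {GemmInputRole.X, GemmInputRole.DL_DY}:
        return a_cfg.grad_weight
    raise AssertionError(f"unexpected gemm input roles: {a_role}, {b_role}")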

Note that none of this is user-facing, and there is no logic change. Planned follow-ups:
* a future PR will make the per-gemm behavior configurable in a user-facing way, hooking into the objects introduced in this PR
* a future PR will update the naming from x/w/dL_dY to input/weight/grad_output throughout the codebase

Reviewed By: drisspg

Differential Revision: D59973551

fbshipit-source-id: c667245449628b377e9bb20dda6a76fbf8a5ef3c
vkuzo authored and facebook-github-bot committed Jul 19, 2024
1 parent 7f0d6bb commit c58fb5d
Showing 11 changed files with 430 additions and 185 deletions.
float8_experimental/__init__.py (7 additions, 2 deletions)

@@ -5,11 +5,16 @@
 # LICENSE file in the root directory of this source tree.
 # Lets define a few top level things here
 from float8_experimental.float8_linear import Float8Linear
-from float8_experimental.float8_tensor import Float8Tensor, ScaledMMConfig
+from float8_experimental.float8_tensor import (
+    Float8Tensor,
+    GemmInputRole,
+    LinearMMConfig,
+    ScaledMMConfig,
+)

 # Needed to load Float8Tensor with weights_only = True
 from torch.serialization import add_safe_globals

-add_safe_globals([Float8Tensor, ScaledMMConfig])
+add_safe_globals([Float8Tensor, ScaledMMConfig, GemmInputRole, LinearMMConfig])

 __all__ = ["Float8Tensor", "Float8Linear"]
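
For context, a minimal sketch of why the new types are registered. The checkpoint path is hypothetical; the only claim taken from the diff is that GemmInputRole and LinearMMConfig must be allow-listed for weights_only loading.

import torch
import float8_experimental  # noqa: F401  importing runs add_safe_globals above

# Hypothetical checkpoint containing serialized Float8Tensor values.
# Without the add_safe_globals registration, weights_only=True would refuse
# to unpickle Float8Tensor / ScaledMMConfig / GemmInputRole / LinearMMConfig.
state_dict = torch.load("checkpoint.pt", weights_only=True)
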
float8_experimental/float8_dynamic_utils.py (22 additions, 8 deletions)

@@ -8,7 +8,8 @@

 from float8_experimental.float8_tensor import (
     Float8Tensor,
-    ScaledMMConfig,
+    GemmInputRole,
+    LinearMMConfig,
     tensor_already_casted_to_fp8,
     to_fp8_no_autograd,
 )
@@ -26,9 +27,9 @@ class NoopFwToFloat8E5M2Bw(torch.autograd.Function):
     def forward(
         ctx,
         tensor,
-        mm_config: ScaledMMConfig,
+        linear_mm_config: LinearMMConfig,
     ):
-        ctx.mm_config = mm_config
+        ctx.linear_mm_config = linear_mm_config
         return tensor

     @staticmethod
@@ -37,21 +38,34 @@ def backward(ctx, gradY):
             return gradY, None
         gradY_scale = tensor_to_scale(gradY, e5m2_dtype)
         fp8_tensor = to_fp8_no_autograd(
-            gradY, gradY_scale, e5m2_dtype, mm_config=ctx.mm_config
+            gradY,
+            gradY_scale,
+            e5m2_dtype,
+            linear_mm_config=ctx.linear_mm_config,
+            gemm_input_role=GemmInputRole.DL_DY,
         )
         return fp8_tensor, None


 def cast_to_float8_e4m3_dynamic(
-    inpt_tensor: torch.Tensor, mm_config: ScaledMMConfig, reduce_amax: bool = False
+    inpt_tensor: torch.Tensor,
+    linear_mm_config: LinearMMConfig,
+    reduce_amax: bool = False,
+    gemm_input_role: GemmInputRole = GemmInputRole.X,
 ) -> Float8Tensor:
     if tensor_already_casted_to_fp8(inpt_tensor):
         return inpt_tensor
     scale = tensor_to_scale(inpt_tensor, e4m3_dtype, reduce_amax)
-    return Float8Tensor.to_float8(inpt_tensor, scale, e4m3_dtype, mm_config=mm_config)
+    return Float8Tensor.to_float8(
+        inpt_tensor,
+        scale,
+        e4m3_dtype,
+        linear_mm_config=linear_mm_config,
+        gemm_input_role=gemm_input_role,
+    )


 def cast_to_float8_e5m2_dynamic_bw(
-    gradY: torch.Tensor, mm_config: ScaledMMConfig
+    gradY: torch.Tensor, linear_mm_config: LinearMMConfig
 ) -> torch.Tensor:
-    return NoopFwToFloat8E5M2Bw.apply(gradY, mm_config)
+    return NoopFwToFloat8E5M2Bw.apply(gradY, linear_mm_config)
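
For reference, a hedged usage sketch of the updated helpers. Default-constructing ScaledMMConfig and building LinearMMConfig positionally are assumptions for the sake of the sketch (the positional order mirrors the x/w/dL_dY comments used later in this PR); tensor shapes are arbitrary.

import torch

from float8_experimental.float8_dynamic_utils import cast_to_float8_e4m3_dynamic
from float8_experimental.float8_tensor import (
    GemmInputRole,
    LinearMMConfig,
    ScaledMMConfig,
)

# One ScaledMMConfig per gemm (x, w, dL_dY), mirroring how Float8Linear
# builds its LinearMMConfig in this PR.
cfg = ScaledMMConfig()
linear_mm_config = LinearMMConfig(cfg, cfg, cfg)

x = torch.randn(16, 32, dtype=torch.bfloat16)
w = torch.randn(64, 32, dtype=torch.bfloat16)

# The role tag lets downstream matmul overrides pick the right per-gemm config.
x_fp8 = cast_to_float8_e4m3_dynamic(x, linear_mm_config)  # role defaults to X
w_fp8 = cast_to_float8_e4m3_dynamic(
    w, linear_mm_config, gemm_input_role=GemmInputRole.W
)
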
float8_experimental/float8_linear.py (33 additions, 17 deletions)

@@ -23,6 +23,8 @@

 from float8_experimental.float8_tensor import (
     Float8Tensor,
+    GemmInputRole,
+    LinearMMConfig,
     ScaledMMConfig,
     to_fp8_no_autograd,
 )
@@ -85,12 +87,12 @@ def forward(
         fp8_scale_dL_dY,
         scale_fn_name,
         is_amax_initialized,
-        mm_config: ScaledMMConfig,
+        linear_mm_config: LinearMMConfig,
     ):
         ctx.save_for_backward(fp8_amax_dL_dY, fp8_amax_history_dL_dY, fp8_scale_dL_dY)
         ctx.scale_fn_name = scale_fn_name
         ctx.is_amax_initialized = is_amax_initialized
-        ctx.mm_config = mm_config
+        ctx.linear_mm_config = linear_mm_config
         return tensor

     @staticmethod
@@ -113,7 +115,7 @@ def backward(ctx, go):
             fp8_amax_dL_dY.fill_(tensor_to_amax(go))

         res = to_fp8_no_autograd(
-            go, fp8_scale_dL_dY, e5m2_dtype, mm_config=ctx.mm_config
+            go,
+            fp8_scale_dL_dY,
+            e5m2_dtype,
+            linear_mm_config=ctx.linear_mm_config,
+            gemm_input_role=GemmInputRole.DL_DY,
         )
         empty_grads = None, None, None, None, None, None
         return res, *empty_grads
@@ -192,12 +198,18 @@ def __init__(self, *args, **kwargs):

         self.create_buffers()

-        # Defines the behavior of the matmul in the forward and backward pass
-        self.forward_config = ScaledMMConfig(
-            emulate, True if not emulate else False, False, config.pad_inner_dim
-        )
-        self.backward_config = ScaledMMConfig(
-            emulate, False, False, config.pad_inner_dim
+        # TODO(future): user level configuration of gemms
+        self.linear_mm_config = LinearMMConfig(
+            # x
+            ScaledMMConfig(
+                emulate, True if not emulate else False, False, config.pad_inner_dim
+            ),
+            # w
+            ScaledMMConfig(
+                emulate, True if not emulate else False, False, config.pad_inner_dim
+            ),
+            # dL_dY
+            ScaledMMConfig(emulate, False, False, config.pad_inner_dim),
         )

         # Note: is_amax_initialized is not a buffer to avoid data dependent
@@ -308,11 +320,12 @@ def cast_x_to_float8(
                 self.fp8_scale_x,
                 e4m3_dtype,
                 self.fp8_amax_x,
-                self.forward_config,
+                linear_mm_config=self.linear_mm_config,
+                gemm_input_role=GemmInputRole.X,
             )
         else:
             assert self.scaling_type_x is TensorScalingType.DYNAMIC
-            x_fp8 = cast_to_float8_e4m3_dynamic(x, self.forward_config)
+            x_fp8 = cast_to_float8_e4m3_dynamic(x, self.linear_mm_config)
         return x_fp8

     def cast_w_to_float8(
@@ -339,14 +352,17 @@ def cast_w_to_float8(
                 self.fp8_scale_w,
                 e4m3_dtype,
                 self.fp8_amax_w,
-                self.forward_config,
+                linear_mm_config=self.linear_mm_config,
+                gemm_input_role=GemmInputRole.W,
             )
         else:
             assert self.scaling_type_w is TensorScalingType.DYNAMIC
             if isinstance(self.weight, Float8Tensor):  # cast by FSDP
                 w_fp8 = self.weight
             else:
-                w_fp8 = cast_to_float8_e4m3_dynamic(self.weight, self.forward_config)
+                w_fp8 = cast_to_float8_e4m3_dynamic(
+                    self.weight, self.linear_mm_config, gemm_input_role=GemmInputRole.W
+                )
         return w_fp8

     def cast_y_to_float8_in_bw(self, y: torch.Tensor) -> torch.Tensor:
@@ -359,11 +375,11 @@ def cast_y_to_float8_in_bw(self, y: torch.Tensor) -> torch.Tensor:
                 self.fp8_scale_dL_dY,
                 scale_fn_name,
                 self.is_amax_initialized,
-                self.backward_config,
+                self.linear_mm_config,
             )
         else:
             assert self.scaling_type_dL_dY is TensorScalingType.DYNAMIC
-            y = cast_to_float8_e5m2_dynamic_bw(y, self.backward_config)
+            y = cast_to_float8_e5m2_dynamic_bw(y, self.linear_mm_config)
         return y

     def float8_pre_forward(self, x):
@@ -457,7 +473,7 @@ def from_float(
             new_mod.weight = torch.nn.Parameter(
                 WeightWithDynamicFloat8CastTensor(
                     new_mod.weight,
-                    new_mod.forward_config,
+                    new_mod.linear_mm_config,
                 )
             )
         else:
@@ -468,7 +484,7 @@
                     new_mod.fp8_amax_w,
                     new_mod.fp8_amax_history_w,
                     new_mod.fp8_scale_w,
-                    new_mod.forward_config,
+                    new_mod.linear_mm_config,
                     new_mod.is_amax_initialized,
                 )
             )
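
To round this out, a brief sketch of the net effect at the module level. The conversion call assumes from_float's other arguments have defaults, which this diff does not show; shapes are arbitrary.

import torch

from float8_experimental.float8_linear import Float8Linear

# Hypothetical conversion of a single nn.Linear; defaults for emulation and
# per-tensor scaling types are assumed to exist.
m = torch.nn.Linear(32, 64, bias=False)
m_fp8 = Float8Linear.from_float(m)

# After this PR the module carries a single linear_mm_config describing all
# three gemms instead of separate forward_config / backward_config objects.
print(m_fp8.linear_mm_config)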