Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[float8] Allow specifying arbitrary dtype for each tensor #1326

Draft
wants to merge 7 commits into
base: gh/lw/2/base
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions test/float8/test_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,8 @@

from torchao.float8.config import (
CastConfig,
e4m3_dtype,
e5m2_dtype,
Float8LinearConfig,
Float8LinearRecipeName,
ScalingGranularity,
Expand Down Expand Up @@ -53,8 +55,6 @@
from torchao.float8.float8_utils import (
FP8_TYPES,
compute_error,
e4m3_dtype,
e5m2_dtype,
fp8_tensor_statistics,
tensor_to_scale,
)
Expand Down
2 changes: 1 addition & 1 deletion test/float8/test_compile.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@

from torchao.float8.config import (
CastConfig,
e4m3_dtype,
Float8LinearConfig,
Float8LinearRecipeName,
ScalingType,
Expand All @@ -47,7 +48,6 @@
LinearMMConfig,
ScaledMMConfig,
)
from torchao.float8.float8_utils import e4m3_dtype
from torchao.testing.float8.test_utils import get_test_float8_linear_config


Expand Down
8 changes: 4 additions & 4 deletions test/float8/test_dtensor.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,9 +31,9 @@
from tqdm import tqdm

from torchao.float8 import Float8LinearConfig
from torchao.float8.config import CastConfig, ScalingType
from torchao.float8.config import CastConfig, e4m3_dtype, ScalingType
from torchao.float8.float8_linear_utils import convert_to_float8_training
from torchao.float8.float8_scaling_utils import NoopFwToFloat8E5M2BwDynamic
from torchao.float8.float8_scaling_utils import NoopFwToFloat8BwDynamic
from torchao.float8.float8_tensor import (
Float8Tensor,
GemmInputRole,
Expand All @@ -45,7 +45,7 @@
Float8RowwiseParallel,
PrepareFloat8ModuleInput,
)
from torchao.float8.float8_utils import e4m3_dtype, tensor_to_scale
from torchao.float8.float8_utils import tensor_to_scale
from torchao.float8.fsdp_utils import WeightWithDynamicFloat8CastTensor
from torchao.testing.float8.dtensor_utils import ToyModel

Expand Down Expand Up @@ -173,7 +173,7 @@ def _test_dtensor_fp8_autograd(mesh: DeviceMesh, size=16):
)

out = torch.nn.functional.linear(dist_x_fp8, dist_weight_fp8)
out = NoopFwToFloat8E5M2BwDynamic.apply(out, LinearMMConfig())
out = NoopFwToFloat8BwDynamic.apply(out, LinearMMConfig(), fp8_dtype)
assert isinstance(out, DTensor), f"Expected DTensor, got {type(out)}"
loss = torch.sum(torch.abs(out - dist_target))
loss.backward()
Expand Down
30 changes: 28 additions & 2 deletions torchao/float8/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,7 @@ class CastConfig:
scaling_type: ScalingType = ScalingType.DYNAMIC
scaling_granularity: ScalingGranularity = ScalingGranularity.TENSORWISE
static_scale: Optional[torch.Tensor] = None
dtype: Optional[torch.dtype] = None
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nit:

  1. can we add a comment on what this is used for, and that None means the default e4m3|e5m2 value will be used?
  2. optional - thoughts about naming this in a more specific way such as target_dtype, lowp_dtype, etc? dtype is a bit ambiguous across torchao unfortunately :(


def short_str(self):
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

can we also add the dtype here, so it appears when we print an instance of Float8Linear? Float8Linear.__extra_repr__ calls this method.

return f"{self.scaling_type.short_str()}_{self.scaling_granularity.short_str()}"
Expand All @@ -75,6 +76,9 @@ def __post_init__(self):
assert (
self.scaling_type is ScalingType.DYNAMIC
), "only dynamic scaling type is supported for axiswise scaling granularity"
assert self.dtype is None or (
self.dtype.is_floating_point and self.dtype.itemsize == 1
), "must specify a 8-bit floating-point dtype"


@dataclass(frozen=True)
Expand Down Expand Up @@ -124,6 +128,12 @@ def __post_init__(self):
self.e5m2_dtype = torch.float8_e5m2fnuz


# User defined type for using the individual F8 type based on config
type_config = Float8TypeConfig()
e4m3_dtype = type_config.e4m3_dtype
e5m2_dtype = type_config.e5m2_dtype


@dataclass(frozen=True)
class Float8GemmConfig:
"""
Expand Down Expand Up @@ -276,6 +286,20 @@ def __post_init__(self):
is_disabled_1 == is_disabled_2
), f"incompatible operand precision for {gemm_name}"

for cc1, cc2, operand_name, default_dtype in [
(cc_i, cc_i_gw, "input", e4m3_dtype),
(cc_w, cc_w_gi, "weight", e4m3_dtype),
(cc_go, cc_go_gw, "grad_output", e5m2_dtype),
]:
# Override the dataclass being frozen
if cc1.dtype is None:
object.__setattr__(cc1, "dtype", default_dtype)
if cc2.dtype is None:
object.__setattr__(cc2, "dtype", default_dtype)
assert (
cc1.dtype == cc2.dtype
), f"{operand_name} must be cast to the same dtype in both matmuls it's used in"

if self.use_fp8_all_gather_only:
assert self.enable_fsdp_float8_all_gather, "use_fp8_all_gather_only requires enable_fsdp_float8_all_gather to be True"

Expand Down Expand Up @@ -340,12 +364,14 @@ def recipe_name_to_linear_config(
cc_w = CastConfig(scaling_granularity=ScalingGranularity.AXISWISE)

# grad_input_hp = grad_output_fp8_axiswise_dim0 @ weight_fp8_tensorwise
cc_go = CastConfig(scaling_granularity=ScalingGranularity.AXISWISE)
cc_go = CastConfig(
scaling_granularity=ScalingGranularity.AXISWISE, dtype=e4m3_dtype
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nit: maybe we can also add some context in the comments on L353:L363 that it also uses e4m3 for grads?

)
cc_w_gi = CastConfig(scaling_granularity=ScalingGranularity.TENSORWISE)

# grad_weight_hp = input_t_hp @ grad_output_hp
cc_i_gw = CastConfig(scaling_type=ScalingType.DISABLED)
cc_go_gw = CastConfig(scaling_type=ScalingType.DISABLED)
cc_go_gw = CastConfig(scaling_type=ScalingType.DISABLED, dtype=e4m3_dtype)

return Float8LinearConfig(
cast_config_input=cc_i,
Expand Down
83 changes: 44 additions & 39 deletions torchao/float8/float8_linear.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,9 +15,9 @@
from torchao.float8.config import Float8LinearConfig, ScalingGranularity, ScalingType
from torchao.float8.distributed_utils import tensor_already_casted_to_fp8
from torchao.float8.float8_scaling_utils import (
NoopFwToFloat8E5M2BwDelayed,
NoopFwToFloat8E5M2BwDynamic,
NoopFwToFloat8E5M2BwStatic,
NoopFwToFloat8BwDelayed,
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

thanks for updating these!

NoopFwToFloat8BwDynamic,
NoopFwToFloat8BwStatic,
_maybe_initialize_amaxes_scales_for_float8_cast,
get_maybe_axiswise_dim,
hp_tensor_to_float8_delayed,
Expand All @@ -32,8 +32,6 @@
hp_tensor_and_scale_to_float8,
)
from torchao.float8.float8_utils import (
e4m3_dtype,
e5m2_dtype,
tensor_to_amax,
tensor_to_scale,
)
Expand Down Expand Up @@ -136,7 +134,7 @@ def forward(
else:
input_maybe_fp8 = hp_tensor_to_float8_dynamic(
input_hp,
e4m3_dtype,
c.cast_config_input.dtype,
linear_mm_config,
gemm_input_role=GemmInputRole.INPUT,
scaling_granularity=c.cast_config_input.scaling_granularity,
Expand All @@ -150,7 +148,7 @@ def forward(
else:
weight_maybe_fp8_t = hp_tensor_to_float8_dynamic(
weight_hp_t,
e4m3_dtype,
c.cast_config_weight.dtype,
linear_mm_config,
gemm_input_role=GemmInputRole.WEIGHT,
scaling_granularity=c.cast_config_weight.scaling_granularity,
Expand Down Expand Up @@ -186,7 +184,7 @@ def backward(ctx, grad_output):
else:
grad_output_reshaped_maybe_fp8_dim0 = hp_tensor_to_float8_dynamic(
grad_output_reshaped,
e5m2_dtype,
c.cast_config_grad_output.dtype,
ctx.linear_mm_config,
gemm_input_role=GemmInputRole.GRAD_OUTPUT,
scaling_granularity=c.cast_config_grad_output.scaling_granularity,
Expand All @@ -204,7 +202,7 @@ def backward(ctx, grad_output):
# the entire tensor.
weight_t_maybe_fp8_dim0 = hp_tensor_to_float8_dynamic(
weight_hp_t,
e4m3_dtype,
c.cast_config_weight_for_grad_input.dtype,
ctx.linear_mm_config,
gemm_input_role=GemmInputRole.WEIGHT,
scaling_granularity=c.cast_config_weight_for_grad_input.scaling_granularity,
Expand Down Expand Up @@ -236,7 +234,7 @@ def backward(ctx, grad_output):
else:
grad_output_reshaped_maybe_fp8_dim1 = hp_tensor_to_float8_dynamic(
grad_output_reshaped,
e5m2_dtype,
c.cast_config_grad_output_for_grad_weight.dtype,
ctx.linear_mm_config,
gemm_input_role=GemmInputRole.GRAD_OUTPUT,
scaling_granularity=c.cast_config_grad_output_for_grad_weight.scaling_granularity,
Expand All @@ -250,7 +248,7 @@ def backward(ctx, grad_output):
else:
input_reshaped_maybe_fp8_dim1 = hp_tensor_to_float8_dynamic(
input_hp_reshaped,
e4m3_dtype,
c.cast_config_input_for_grad_weight.dtype,
ctx.linear_mm_config,
gemm_input_role=GemmInputRole.INPUT,
scaling_granularity=c.cast_config_input_for_grad_weight.scaling_granularity,
Expand Down Expand Up @@ -347,11 +345,9 @@ def create_buffers(self):
# Default values for history buffers, see above TODO
history_len = self.config.delayed_scaling_config.history_len
device = self.weight.device
# TODO(future PR): dtype values below don't have the other float8
# flavors, fix it
default_input = torch.finfo(torch.float8_e4m3fn).max
default_weight = torch.finfo(torch.float8_e4m3fn).max
default_grad_output = torch.finfo(torch.float8_e5m2).max
default_input = torch.finfo(self.config.cast_config_input.dtype).max
default_weight = torch.finfo(self.config.cast_config_weight.dtype).max
default_grad_output = torch.finfo(self.config.cast_config_grad_output.dtype).max

# Note: for now, create all the buffers if any are needed, to postpone
# the work to make the scale and amax syncing and history calculation
Expand Down Expand Up @@ -438,29 +434,32 @@ def cast_input_to_float8(
self.fp8_amax_history_input,
self.fp8_scale_input,
scale_fn_name,
e4m3_dtype,
self.config.cast_config_input.dtype,
is_amax_initialized,
reduce_amax=True,
)
input_fp8 = hp_tensor_to_float8_delayed(
input,
self.fp8_scale_input,
e4m3_dtype,
self.config.cast_config_input.dtype,
self.fp8_amax_input,
linear_mm_config=self.linear_mm_config,
gemm_input_role=GemmInputRole.INPUT,
)
elif self.scaling_type_input is ScalingType.DYNAMIC:
input_fp8 = hp_tensor_to_float8_dynamic(
input,
e4m3_dtype,
self.config.cast_config_input.dtype,
self.linear_mm_config,
gemm_input_role=GemmInputRole.INPUT,
)
else:
assert self.scaling_type_input is ScalingType.STATIC
input_fp8 = hp_tensor_to_float8_static(
input, self.fp8_static_scale_input, e4m3_dtype, self.linear_mm_config
input,
self.fp8_static_scale_input,
self.config.cast_config_input.dtype,
self.linear_mm_config,
)

return input_fp8
Expand All @@ -476,14 +475,14 @@ def get_weight_scale(self, weight: torch.Tensor) -> Optional[torch.Tensor]:
self.fp8_amax_history_weight,
self.fp8_scale_weight,
scale_fn_name,
e4m3_dtype,
self.config.cast_config_weight.dtype,
self.is_amax_initialized,
reduce_amax=True,
)
self.fp8_amax_weight.fill_(tensor_to_amax(weight))
return self.fp8_scale_weight
elif self.scaling_type_weight is ScalingType.DYNAMIC:
return tensor_to_scale(weight, e4m3_dtype)
return tensor_to_scale(weight, self.config.cast_config_weight.dtype)
else:
assert self.scaling_type_weight is ScalingType.STATIC
return self.fp8_static_scale_weight
Expand All @@ -499,7 +498,7 @@ def cast_weight_to_float8_t(
weight_fp8 = hp_tensor_and_scale_to_float8(
weight,
weight_scale,
e4m3_dtype,
self.config.cast_config_weight.dtype,
self.linear_mm_config,
gemm_input_role=GemmInputRole.WEIGHT,
)
Expand All @@ -514,23 +513,29 @@ def cast_weight_to_original_t(self, weight: torch.Tensor):
def cast_output_to_float8_in_bw(self, output: torch.Tensor) -> torch.Tensor:
if self.scaling_type_grad_output is ScalingType.DELAYED:
scale_fn_name = self.config.delayed_scaling_config.scale_fn_name
output = NoopFwToFloat8E5M2BwDelayed.apply(
output = NoopFwToFloat8BwDelayed.apply(
output,
self.fp8_amax_grad_output,
self.fp8_amax_history_grad_output,
self.fp8_scale_grad_output,
scale_fn_name,
self.is_amax_initialized,
self.linear_mm_config,
self.config.cast_config_grad_output.dtype,
)
elif self.scaling_type_grad_output is ScalingType.DYNAMIC:
output = NoopFwToFloat8E5M2BwDynamic.apply(output, self.linear_mm_config)
output = NoopFwToFloat8BwDynamic.apply(
output,
self.linear_mm_config,
self.config.cast_config_grad_output.dtype,
)
else:
assert self.scaling_type_grad_output is ScalingType.STATIC
output = NoopFwToFloat8E5M2BwStatic.apply(
output = NoopFwToFloat8BwStatic.apply(
output,
self.fp8_static_scale_grad_output,
self.linear_mm_config,
self.config.cast_config_grad_output.dtype,
)
return output

Expand All @@ -547,19 +552,16 @@ def float8_post_forward(self):
return

def forward_fp8_matmul(self, input: torch.Tensor) -> torch.Tensor:
has_any_axiswise_scaling = (
self.config.cast_config_input.scaling_granularity
is ScalingGranularity.AXISWISE
or self.config.cast_config_weight.scaling_granularity
is ScalingGranularity.AXISWISE
or self.config.cast_config_grad_output.scaling_granularity
is ScalingGranularity.AXISWISE
or self.config.cast_config_input_for_grad_weight.scaling_granularity
is ScalingGranularity.AXISWISE
or self.config.cast_config_weight_for_grad_input.scaling_granularity
is ScalingGranularity.AXISWISE
or self.config.cast_config_grad_output_for_grad_weight.scaling_granularity
is ScalingGranularity.AXISWISE
has_any_axiswise_scaling = any(
cc.scaling_granularity is ScalingGranularity.AXISWISE
for cc in [
self.config.cast_config_input,
self.config.cast_config_weight,
self.config.cast_config_grad_output,
self.config.cast_config_input_for_grad_weight,
self.config.cast_config_weight_for_grad_input,
self.config.cast_config_grad_output_for_grad_weight,
]
)

if not has_any_axiswise_scaling:
Expand Down Expand Up @@ -682,6 +684,7 @@ def from_float(
WeightWithDynamicFloat8CastTensor(
new_mod.weight,
new_mod.linear_mm_config,
new_mod.config.cast_config_weight.dtype,
)
)
elif config.cast_config_weight.scaling_type is ScalingType.DELAYED:
Expand All @@ -692,6 +695,7 @@ def from_float(
new_mod.fp8_amax_history_weight,
new_mod.fp8_scale_weight,
new_mod.linear_mm_config,
new_mod.config.cast_config_weight.dtype,
new_mod.is_amax_initialized,
)
)
Expand All @@ -702,6 +706,7 @@ def from_float(
new_mod.weight,
new_mod.fp8_static_scale_weight,
new_mod.linear_mm_config,
new_mod.config.cast_config_weight.dtype,
)
)

Expand Down
Loading
Loading