From 224cfdfdb1ab86184a55d7d641350134f4724da0 Mon Sep 17 00:00:00 2001
From: vasiliy
Date: Thu, 25 Jul 2024 12:49:46 -0700
Subject: [PATCH] Update base for Update on "bring back torch.autograd.Function for float8 matmul"

Summary:

This is a redo of https://github.com/pytorch-labs/float8_experimental/pull/316

With upcoming support of scaling granularities other than tensorwise, we need
a good way to control which gemm kernel to call and how to scale the input
tensors in fwd and bwd. A `torch.autograd.Function` override is the cleanest
way to do that, and in 2024 this now works with `torch.compile`.

Test Plan:

```
./test/test_everything.sh
```

Reviewers:

Subscribers:

Tasks:

Tags:

[ghstack-poisoned]
---
 README.md | 10 +++++-----
 1 file changed, 5 insertions(+), 5 deletions(-)

diff --git a/README.md b/README.md
index 6c14b8e..642529f 100644
--- a/README.md
+++ b/README.md
@@ -85,7 +85,7 @@ This is theoretically the most performant recipe as it minimizes memory reads.
 from float8_experimental import (
     convert_to_float8_training,
     sync_float8_amax_and_scale_history,
-    TensorScalingType,
+    ScalingType,
 )
 
 # create model
@@ -95,13 +95,13 @@ m = Model(...)
 # gated with config.enable_amax_init and
 # config.enable_pre_and_post_forward are needed for
 # autocast + compile + FSDP + float8 to work
-from float8_experimental import Float8LinearConfig, TensorScalingType, Float8TensorCastConfig
+from float8_experimental import Float8LinearConfig, ScalingType, CastConfig
 config = Float8LinearConfig(
     enable_amax_init = False, # only needed for autocast + compile + FSDP + float8 delayed
     enable_pre_and_post_forward = False, # only needed for autocast + compile + FSDP + float8 delayed
-    cast_config_input=Float8TensorCastConfig(scaling_type=TensorScalingType.DELAYED),
-    cast_config_weight=Float8TensorCastConfig(scaling_type=TensorScalingType.DELAYED),
-    cast_config_grad_output=Float8TensorCastConfig(scaling_type=TensorScalingType.DELAYED),
+    cast_config_input=CastConfig(scaling_type=ScalingType.DELAYED),
+    cast_config_weight=CastConfig(scaling_type=ScalingType.DELAYED),
+    cast_config_grad_output=CastConfig(scaling_type=ScalingType.DELAYED),
 )
 
 # convert all `torch.nn.Linear` modules to `Float8Linear`, specifying scaling
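
To illustrate the summary's point about controlling how inputs are scaled in fwd and bwd, below is a minimal sketch of a `torch.autograd.Function` for a float8 matmul. It is not the library's implementation: the class name and the `_to_fp8_emulated` helper are hypothetical, float8 is emulated by round-tripping through the fp8 dtype rather than calling a real scaled gemm kernel, and the dtype choices (e4m3 for activations/weights, e5m2 for gradients) follow common float8 training practice rather than anything stated in this patch.

```python
import torch

def _to_fp8_emulated(t: torch.Tensor, fp8_dtype: torch.dtype) -> torch.Tensor:
    # Hypothetical helper: tensorwise (per-tensor) dynamic scaling.
    # Scale so the tensor's amax maps to the fp8 dtype's max value,
    # round-trip through float8 to emulate the precision loss, then
    # upcast so torch.mm can consume it. A real kernel would keep the
    # fp8 operands and call a scaled gemm instead.
    amax = t.abs().max().clamp(min=1e-12)
    scale = torch.finfo(fp8_dtype).max / amax
    return (t * scale).to(fp8_dtype).to(t.dtype) / scale

class Float8MatmulSketch(torch.autograd.Function):
    # The override makes the per-gemm decisions explicit: the forward
    # gemm and the two backward gemms each choose which fp8 dtype and
    # which scaling to apply to their operands.

    @staticmethod
    def forward(ctx, input, weight):
        ctx.save_for_backward(input, weight)
        # forward gemm: activations and weights cast to e4m3
        input_fp8 = _to_fp8_emulated(input, torch.float8_e4m3fn)
        weight_fp8 = _to_fp8_emulated(weight, torch.float8_e4m3fn)
        return torch.mm(input_fp8, weight_fp8.t())

    @staticmethod
    def backward(ctx, grad_output):
        input, weight = ctx.saved_tensors
        # backward gemms: the incoming gradient is commonly kept in e5m2
        go_fp8 = _to_fp8_emulated(grad_output, torch.float8_e5m2)
        grad_input = torch.mm(go_fp8, _to_fp8_emulated(weight, torch.float8_e4m3fn))
        grad_weight = torch.mm(go_fp8.t(), _to_fp8_emulated(input, torch.float8_e4m3fn))
        return grad_input, grad_weight

# usage sketch: out = Float8MatmulSketch.apply(x, w), with x of shape (N, K)
# and w of shape (M, K), mirroring a linear layer's matmul
```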