diff --git a/README.md b/README.md
index 6c14b8e..642529f 100644
--- a/README.md
+++ b/README.md
@@ -85,7 +85,7 @@ This is theoretically the most performant recipe as it minimizes memory reads.
 from float8_experimental import (
     convert_to_float8_training,
     sync_float8_amax_and_scale_history,
-    TensorScalingType,
+    ScalingType,
 )
 
 # create model
@@ -95,13 +95,13 @@ m = Model(...)
 # gated with config.enable_amax_init and
 # config.enable_pre_and_post_forward are needed for
 # autocast + compile + FSDP + float8 to work
-from float8_experimental import Float8LinearConfig, TensorScalingType, Float8TensorCastConfig
+from float8_experimental import Float8LinearConfig, ScalingType, CastConfig
 
 config = Float8LinearConfig(
     enable_amax_init = False, # only needed for autocast + compile + FSDP + float8 delayed
-    enable_pre_and_post_forward, False # only needed for autocast + compile + FSDP + float8 delayed
-    cast_config_input=Float8TensorCastConfig(scaling_type=TensorScalingType.DELAYED),
-    cast_config_weight=Float8TensorCastConfig(scaling_type=TensorScalingType.DELAYED),
-    cast_config_grad_output=Float8TensorCastConfig(scaling_type=TensorScalingType.DELAYED),
+    enable_pre_and_post_forward = False, # only needed for autocast + compile + FSDP + float8 delayed
+    cast_config_input=CastConfig(scaling_type=ScalingType.DELAYED),
+    cast_config_weight=CastConfig(scaling_type=ScalingType.DELAYED),
+    cast_config_grad_output=CastConfig(scaling_type=ScalingType.DELAYED),
 )
 
 # convert all `torch.nn.Linear` modules to `Float8Linear`, specifying scaling