diff --git a/float8_experimental/config.py b/float8_experimental/config.py index 7cb7230..6408ac7 100644 --- a/float8_experimental/config.py +++ b/float8_experimental/config.py @@ -84,7 +84,7 @@ class Float8LinearConfig: # gemm_config_output: Float8GemmConfig = Float8GemmConfig(use_fast_accum=True) gemm_config_grad_input: Float8GemmConfig = Float8GemmConfig() - gemm_config_grad_weight: Float8GemmConfig = Float8GemmConfig(use_fast_accum=True) + gemm_config_grad_weight: Float8GemmConfig = Float8GemmConfig() # # Per-linear configuration