From 90f73c82be37009b02f8a72d06d5e6b9a5e4ecde Mon Sep 17 00:00:00 2001
From: vasiliy
Date: Thu, 25 Jul 2024 09:43:02 -0700
Subject: [PATCH] Update base for Update on "bring back torch.autograd.Function for float8 matmul"

Summary:

This is a redo of https://github.com/pytorch-labs/float8_experimental/pull/316

With upcoming support of scaling granularities other than tensorwise,
we need a good way to control which gemm kernel to call and how to
scale the input tensors in fwd and bwd. A `torch.autograd.Function`
override is the cleanest way to do that, and in 2024 this now works
with `torch.compile`.

Test Plan:

```
./test/test_everything.sh
```

Reviewers:

Subscribers:

Tasks:

Tags:

[ghstack-poisoned]
---
 float8_experimental/config.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/float8_experimental/config.py b/float8_experimental/config.py
index 7cb7230..6408ac7 100644
--- a/float8_experimental/config.py
+++ b/float8_experimental/config.py
@@ -84,7 +84,7 @@ class Float8LinearConfig:
     #
     gemm_config_output: Float8GemmConfig = Float8GemmConfig(use_fast_accum=True)
     gemm_config_grad_input: Float8GemmConfig = Float8GemmConfig()
-    gemm_config_grad_weight: Float8GemmConfig = Float8GemmConfig(use_fast_accum=True)
+    gemm_config_grad_weight: Float8GemmConfig = Float8GemmConfig()
 
     #
     # Per-linear configuration
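For illustration only, a minimal sketch of the `torch.autograd.Function` override pattern the summary describes. The class name `Float8MatmulSketch` is hypothetical and the matmuls are plain high-precision ops standing in for the scaled float8 gemms; the real override in float8_experimental additionally casts operands to float8 and applies the per-gemm configs (`gemm_config_output`, `gemm_config_grad_input`, `gemm_config_grad_weight`) touched in the diff above.

```
import torch


class Float8MatmulSketch(torch.autograd.Function):
    """Hypothetical sketch, not the code in this patch: an autograd.Function
    override gives one place to choose the kernel and scaling for each of the
    three gemms (output, grad_input, grad_weight) independently."""

    @staticmethod
    def forward(ctx, input, weight):
        ctx.save_for_backward(input, weight)
        # output gemm; a real float8 implementation would scale/cast both
        # operands here according to the output gemm config
        return input @ weight.t()

    @staticmethod
    def backward(ctx, grad_output):
        input, weight = ctx.saved_tensors
        # grad_input gemm (grad_input gemm config in the real code)
        grad_input = grad_output @ weight
        # grad_weight gemm (grad_weight gemm config in the real code)
        grad_weight = grad_output.t() @ input
        return grad_input, grad_weight


# usage: behaves like a linear layer without bias
x = torch.randn(16, 32, requires_grad=True)
w = torch.randn(64, 32, requires_grad=True)
Float8MatmulSketch.apply(x, w).sum().backward()
```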