Update on "bring back torch.autograd.Function for float8 matmul"
Summary:

This is a redo of
#316

With upcoming support for scaling granularities other than tensorwise, we need a good
way to control which gemm kernel to call and how to scale the input tensors in the
forward and backward passes. A `torch.autograd.Function` override is the cleanest way
to do that, and as of 2024 it works with `torch.compile`.
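
To make the mechanism concrete, here is a minimal, hypothetical sketch of the idea (not
the actual float8_experimental code): a `torch.autograd.Function` override gives the
forward gemm and the two backward gemms (grad_input, grad_weight) independent control
over scaling and kernel settings. `cast_and_scale` is an illustrative placeholder, and
plain matmuls stand in for the real float8 gemm calls.

```python
import torch


def cast_and_scale(t: torch.Tensor) -> torch.Tensor:
    # Tensorwise scaling stand-in: the real code would compute an amax-based scale
    # and cast to a float8 dtype before handing the tensor to a float8 gemm kernel.
    return t  # no-op placeholder


class Float8Matmul(torch.autograd.Function):
    @staticmethod
    def forward(ctx, input: torch.Tensor, weight: torch.Tensor) -> torch.Tensor:
        ctx.save_for_backward(input, weight)
        # Forward ("output") gemm: the scaling choice and kernel flags
        # (e.g. fast accumulation) are decided here, independently of backward.
        return cast_and_scale(input) @ cast_and_scale(weight).t()

    @staticmethod
    def backward(ctx, grad_output: torch.Tensor):
        input, weight = ctx.saved_tensors
        # grad_input gemm: may use a different scaling granularity / kernel config.
        grad_input = cast_and_scale(grad_output) @ cast_and_scale(weight)
        # grad_weight gemm: likewise configured independently of the other two gemms.
        grad_weight = cast_and_scale(grad_output).t() @ cast_and_scale(input)
        return grad_input, grad_weight


# The override composes with torch.compile and plain eager execution alike.
x = torch.randn(16, 32, requires_grad=True)
w = torch.randn(64, 32, requires_grad=True)
Float8Matmul.apply(x, w).sum().backward()
```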

Test Plan:

```
./test/test_everything.sh
```

Reviewers:

Subscribers:

Tasks:

Tags:

[ghstack-poisoned]
vkuzo committed Jul 25, 2024 · commit 1c3e320 · 2 parents: 7118c16 + 90f73c8
1 changed file with 1 addition and 1 deletion: float8_experimental/config.py
```diff
@@ -84,7 +84,7 @@ class Float8LinearConfig:
     #
     gemm_config_output: Float8GemmConfig = Float8GemmConfig(use_fast_accum=True)
     gemm_config_grad_input: Float8GemmConfig = Float8GemmConfig()
-    gemm_config_grad_weight: Float8GemmConfig = Float8GemmConfig(use_fast_accum=True)
+    gemm_config_grad_weight: Float8GemmConfig = Float8GemmConfig()
 
     #
     # Per-linear configuration
```
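
For reference, a usage sketch based only on the fields visible in this diff (assuming
`Float8LinearConfig` is a dataclass that accepts these fields as keyword arguments):
with the new default, only the output gemm uses fast accumulation, and a user who wants
fast accumulation for the grad_weight gemm would opt back in explicitly.

```python
from float8_experimental.config import Float8GemmConfig, Float8LinearConfig

# New defaults: fast accumulation for the output gemm only.
config = Float8LinearConfig()

# Opting back in to fast accumulation for the grad_weight gemm (the previous default):
config = Float8LinearConfig(
    gemm_config_grad_weight=Float8GemmConfig(use_fast_accum=True),
)
```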
