
2:4 sparsity & `torch.compile`-ing memory_efficient_attention

danthe3rd released this on 29 Apr at 14:40

Pre-built binary wheels require PyTorch 2.3.0

Added

  • [2:4 sparsity] Added support for a Straight-Through Estimator (STE) for the `sparsify24` gradient (`GRADIENT_STE`)
  • [2:4 sparsity] `sparsify24_like` now supports the cuSparseLt backend and the STE gradient
  • Basic support for `torch.compile` with the `memory_efficient_attention` operator. For now, only the Flash-Attention backend is supported, and only when no attention bias is provided. We plan to extend this coverage progressively.
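For readers new to the 2:4 sparsity entries, here is a minimal pure-Python sketch of the two ideas involved: 2:4 magnitude pruning (in every contiguous group of four values, keep the two with the largest magnitude and zero the rest) and a straight-through estimator, where the backward pass treats the pruning as the identity and passes the incoming gradient through unchanged. This is a textbook illustration under those assumptions, not xformers' implementation: `sparsify24_forward`, `masked_backward` and `ste_backward` are hypothetical names, and the sketch does not claim to reproduce the exact gradient rule behind `GRADIENT_STE` or anything about the cuSparseLt backend.

```python
from typing import List, Tuple


def sparsify24_forward(row: List[float]) -> Tuple[List[float], List[bool]]:
    """Hypothetical helper: 2:4 magnitude pruning of a 1-D row.

    In each contiguous group of 4 values, keep the 2 largest in absolute
    value and zero the other 2. Returns the pruned row and the keep-mask.
    """
    if len(row) % 4 != 0:
        raise ValueError("row length must be a multiple of 4")
    out: List[float] = []
    mask: List[bool] = []
    for start in range(0, len(row), 4):
        group = row[start:start + 4]
        # Indices of the two largest-magnitude entries (ties keep the earlier index).
        keep = sorted(range(4), key=lambda i: abs(group[i]), reverse=True)[:2]
        for i, value in enumerate(group):
            kept = i in keep
            mask.append(kept)
            out.append(value if kept else 0.0)
    return out, mask


def masked_backward(grad_out: List[float], mask: List[bool]) -> List[float]:
    """Literal gradient of the pruning op: pruned positions get zero."""
    return [g if m else 0.0 for g, m in zip(grad_out, mask)]


def ste_backward(grad_out: List[float], mask: List[bool]) -> List[float]:
    """Straight-through estimator: pruning is treated as identity in backward."""
    # `mask` is deliberately unused: pruned and kept positions alike get the gradient.
    return list(grad_out)


# Usage
sparse, mask = sparsify24_forward([0.1, -0.9, 0.5, 0.2, 3.0, 0.0, -1.0, 2.0])
print(sparse)  # [0.0, -0.9, 0.5, 0.0, 3.0, 0.0, 0.0, 2.0]
```

The design point the sketch tries to show: with the literal gradient, pruned weights never receive a learning signal, so they cannot grow back if the 2:4 mask is recomputed later; an STE-style backward keeps updating them.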

Improved

  • `merge_attentions` no longer requires its inputs to be stacked.
  • fMHA: `triton_splitk` now supports an additive attention bias
  • fMHA: benchmark cleanup
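The `merge_attentions` and `triton_splitk` entries rest on a common idea: attention over a long key sequence can be computed in independent chunks, each returning its partial output plus the log-sum-exp (LSE) of its scores, and the chunks can then be combined exactly. The pure-Python sketch below illustrates that math for a single query vector, adds an additive bias to the scores before the softmax, and merges partial results taken from a plain Python list rather than a stacked tensor. The function names are hypothetical; this is not xformers' `merge_attentions` or `triton_splitk` code.

```python
import math


def attend_chunk(q, keys, values, bias):
    """Hypothetical: attention of one query over one key chunk with additive bias.

    Returns (output vector, LSE of the scaled and biased scores).
    """
    scale = 1.0 / math.sqrt(len(q))
    scores = [scale * sum(qi * ki for qi, ki in zip(q, k)) + b
              for k, b in zip(keys, bias)]
    m = max(scores)  # subtract the max for numerical stability
    weights = [math.exp(s - m) for s in scores]
    total = sum(weights)
    lse = m + math.log(total)
    dim = len(values[0])
    out = [sum(w * v[d] for w, v in zip(weights, values)) / total for d in range(dim)]
    return out, lse


def merge_partials(partials):
    """Hypothetical: exactly combine a list of (output, lse) pairs from disjoint key chunks."""
    lses = [lse for _, lse in partials]
    m = max(lses)
    lse_total = m + math.log(sum(math.exp(l - m) for l in lses))
    dim = len(partials[0][0])
    merged = [0.0] * dim
    for out, lse in partials:
        share = math.exp(lse - lse_total)  # this chunk's fraction of the softmax mass
        for d in range(dim):
            merged[d] += share * out[d]
    return merged, lse_total


def split_k_attention(q, keys, values, bias, num_splits):
    """Hypothetical: split the keys into chunks, attend to each, then merge."""
    n = len(keys)
    step = math.ceil(n / num_splits)
    partials = [attend_chunk(q, keys[i:i + step], values[i:i + step], bias[i:i + step])
                for i in range(0, n, step)]
    return merge_partials(partials)


# Usage: splitting should not change the result beyond floating-point rounding.
q = [0.5, -1.0]
keys = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0], [-1.0, 0.5], [0.3, -0.2]]
values = [[1.0, 2.0], [3.0, 4.0], [5.0, 6.0], [7.0, 8.0], [9.0, 10.0]]
bias = [0.0, -1.0, 0.5, 0.0, -0.25]
full_out, full_lse = split_k_attention(q, keys, values, bias, num_splits=1)
split_out, split_lse = split_k_attention(q, keys, values, bias, num_splits=2)
```

Merging is exact because each chunk's output is rescaled by `exp(lse_chunk - lse_total)`, which is that chunk's share of the full softmax denominator. The bias only needs to be sliced along the key dimension together with the keys.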