2:4 sparsity & `torch.compile`-ing memory_efficient_attention
Pre-built binary wheels require PyTorch 2.3.0
Added
- [2:4 sparsity] Added support for Straight-Through Estimator for `sparsify24` gradient (`GRADIENT_STE`)
- [2:4 sparsity] `sparsify24_like` now supports the cuSparseLt backend, and the STE gradient
- Basic support for `torch.compile` for the `memory_efficient_attention` operator. Currently only supports Flash-Attention, and without any bias provided. We want to expand this coverage progressively.
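
For readers unfamiliar with the terms in the 2:4 sparsity entries, here is a minimal pure-Python sketch of the idea. It is not the library's kernel or API: `sparsify_2_4`, `backward_ste` and `backward_masked` are hypothetical helper names, and selecting by largest magnitude is an assumed criterion used only for illustration. A 2:4 pattern keeps 2 values in every group of 4; a straight-through estimator treats that step as the identity in the backward pass.

```python
def sparsify_2_4(row):
    """Keep the 2 largest-magnitude values in each group of 4, zero the rest.

    Hypothetical helper for illustration only (not an xFormers function).
    Returns the sparsified row and a boolean mask of the kept positions.
    """
    if len(row) % 4 != 0:
        raise ValueError("row length must be a multiple of 4")
    out, mask = [], []
    for start in range(0, len(row), 4):
        group = row[start:start + 4]
        # Indices of the two largest |values|; ties go to the lower index.
        keep = sorted(range(4), key=lambda i: (-abs(group[i]), i))[:2]
        for i in range(4):
            kept = i in keep
            mask.append(kept)
            out.append(group[i] if kept else 0.0)
    return out, mask


def backward_ste(grad_out, mask):
    """Straight-through: the incoming gradient is passed on unchanged."""
    return list(grad_out)


def backward_masked(grad_out, mask):
    """For comparison: only the kept positions receive gradient."""
    return [g if m else 0.0 for g, m in zip(grad_out, mask)]


weights = [0.1, -0.9, 0.4, 0.05, 2.0, -0.3, 0.2, 1.5]
sparse, mask = sparsify_2_4(weights)
grad_out = [1.0] * len(weights)

print(sparse)                           # two non-zeros per group of four
print(backward_ste(grad_out, mask))     # dense gradient
print(backward_masked(grad_out, mask))  # gradient with the 2:4 pattern
```

With the straight-through variant, values that were pruned in the forward pass still receive gradient, so they can grow and be selected in a later step.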
Improved
- merge_attentions no longer needs inputs to be stacked.
- fMHA: triton_splitk now supports additive bias
- fMHA: benchmark cleanup
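
As background for the `merge_attentions` entry, the sketch below shows the usual way attention computed on separate key/value chunks can be combined from each chunk's output and log-sum-exp (LSE). It is a pure-Python illustration for a single query, assuming this is the kind of merge the operator performs; `attend` and `merge_partials` are hypothetical names, not the library's API. The partial results are passed as a plain list of separate pairs rather than as one stacked array.

```python
import math


def attend(scores, values):
    """Softmax attention for one query over one chunk.

    Returns the output vector and the log-sum-exp of the scores.
    """
    m = max(scores)
    weights = [math.exp(s - m) for s in scores]
    total = sum(weights)
    dim = len(values[0])
    out = [sum(w * v[d] for w, v in zip(weights, values)) / total
           for d in range(dim)]
    return out, m + math.log(total)


def merge_partials(partials):
    """Merge a list of (output, lse) pairs, one pair per key/value chunk."""
    m = max(lse for _, lse in partials)
    # Each chunk is weighted by its share of the total softmax mass.
    scales = [math.exp(lse - m) for _, lse in partials]
    total = sum(scales)
    dim = len(partials[0][0])
    merged = [sum(s * out[d] for s, (out, _) in zip(scales, partials)) / total
              for d in range(dim)]
    return merged, m + math.log(total)


scores = [0.5, -1.0, 2.0, 0.1, 1.2]
values = [[1.0, 0.0], [0.0, 1.0], [2.0, 2.0], [-1.0, 0.5], [0.3, -0.7]]

# Reference: attention over all keys/values at once.
full_out, full_lse = attend(scores, values)

# Same computation split into two chunks, then merged.
chunks = [attend(scores[:2], values[:2]), attend(scores[2:], values[2:])]
merged_out, merged_lse = merge_partials(chunks)

print(full_out, merged_out)
```

The merged result matches attention over the full key/value set up to floating-point rounding, which is why the chunks can be processed independently.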