A faster alternative to Metal Performance Shaders, a reference implementation of modern GPU algorithms, and a step toward defragmenting the AI ecosystem.
Algorithms:
- Attention
  - Dense (90.5% ALU)
  - Block-Sparse
- GEMM
  - FP16 (93.3% ALU)
  - FP32 (87.2% ALU)
  - Fused Biases
Programming Language | MFA Supports | MPSGraph Supports | PyTorch Supports |
---|---|---|---|
CPU C++ (metal-cpp) | ✅ | ❌ | ✅ |
GPU C++ (Indirect Command Buffers) | ✅ | ❌ | ❌ |
Swift (iPadOS, Playgrounds) | ✅ | ✅ | ❌ |
Swift (macOS, Xcode) | ✅ | ✅ | ✅ |
Predecessor to Swift | not tested | ✅ | ✅ |
Usage:
- Download Xcode 14.2 from the Apple developer tools archive
  - Copy into `/Applications/Xcode 14.2.app`, side by side with the existing Xcode installation `/Applications/Xcode.app`
- Run the Swift script to compile `libMetalFlashAttention.metallib`
  - Enter this repository from Terminal and type `swift build.swift`
- Read the API specification
- Generate Metal shader variants at runtime
Alternatively:
- Download the newest version of Xcode
- Fetch the Metal library from GitHub releases
- Run the unit tests from this repository
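Either route ends with a compiled `libMetalFlashAttention.metallib`. Below is a minimal Swift sketch of the host-side workflow for loading it and generating a shader variant at runtime by specializing function constants. The kernel name (`hgemm`), the constant names `M`/`N`/`K`, and their data types are illustrative assumptions; the API specification is the authoritative reference.

```swift
import Foundation
import Metal

let device = MTLCreateSystemDefaultDevice()!
let url = URL(fileURLWithPath: "libMetalFlashAttention.metallib")
let library = try! device.makeLibrary(URL: url)

// Function constants specialize a generic kernel into a variant for one problem shape.
let constants = MTLFunctionConstantValues()
var M: UInt32 = 1024, N: UInt32 = 1024, K: UInt32 = 1024
constants.setConstantValue(&M, type: .uint, withName: "M")
constants.setConstantValue(&N, type: .uint, withName: "N")
constants.setConstantValue(&K, type: .uint, withName: "K")

// Compiling the specialized function is the "generate Metal shader variants at runtime" step.
let gemm = try! library.makeFunction(name: "hgemm", constantValues: constants)
let pipeline = try! device.makeComputePipelineState(function: gemm)
```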
SGEMM, every square matrix from 1–1536:
HGEMM, every square matrix from 1–2048:
Scaling by square size:
- Matrix M: every even integer
- Matrix N: every even integer
- Matrix K: every even integer
- For 2x batched, every multiple of 4
- For very large square matrices, granularity varies
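Read concretely, the sweep above enumerates sizes like this. The upper bounds come from the SGEMM/HGEMM captions; the batched bound is a placeholder.

```swift
// Square sizes benchmarked: M = N = K = n for each n below.
let sgemmSizes = stride(from: 2, through: 1536, by: 2)   // every even integer up to 1536
let hgemmSizes = stride(from: 2, through: 2048, by: 2)   // every even integer up to 2048
let batchedSizes = stride(from: 4, through: 2048, by: 4) // 2x batched: every multiple of 4 (placeholder bound)
```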
Function Constant | Value |
---|---|
M_splits | 2 |
N_splits | 2 |
M_simd | Block M / M_splits |
N_simd | Block N / N_splits |
K_simd | Block K |
Precision | Block M | Block N | Block K |
---|---|---|---|
Float32 | 32 | 32 | 32 |
Float32 | 48 | 48 | 24 |
Float16 | 32 | 32 | 32 |
Float16 | 48 | 48 | 32 |
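Purely as arithmetic, the two tables above combine as follows; the struct and property names here are illustrative, not part of the library.

```swift
// Derive per-simdgroup tile sizes from the block sizes and splits.
struct GEMMBlockConfig {
    var blockM: Int, blockN: Int, blockK: Int
    var mSplits = 2, nSplits = 2             // M_splits = 2, N_splits = 2
    var mSimd: Int { blockM / mSplits }      // M_simd = Block M / M_splits
    var nSimd: Int { blockN / nSplits }      // N_simd = Block N / N_splits
    var kSimd: Int { blockK }                // K_simd = Block K
}

// Example: the large Float32 configuration from the table.
let f32Large = GEMMBlockConfig(blockM: 48, blockN: 48, blockK: 24)
print(f32Large.mSimd, f32Large.nSimd, f32Large.kSimd)    // 24 24 24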
Size Start | Size End | Duplicate Commands/Encoder | Trials |
---|---|---|---|
1 | 190 | 256 | 16 |
192 | 254 | 128 | 16 |
256 | 382 | 64 | 16 |
384 | 510 | 32 | 16 |
512 | 766 | 16 | 16 |
768 | 1022 | 8 | 16 |
1024 | 1534 | 4 | 16 |
1536 | 2048 | 2 | 16 |
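The schedule above can be restated as a hypothetical helper: smaller matrices are dispatched more times per command encoder so each trial contains enough GPU work to time reliably. Only even sizes are benchmarked, so the gaps between range boundaries never occur.

```swift
// Duplicate commands per encoder, as a function of square matrix size.
func duplicateCommands(forSize n: Int) -> Int {
    switch n {
    case 1...190:     return 256
    case 192...254:   return 128
    case 256...382:   return 64
    case 384...510:   return 32
    case 512...766:   return 16
    case 768...1022:  return 8
    case 1024...1534: return 4
    default:          return 2    // 1536...2048
    }
}
let trials = 16    // constant across all size ranges
```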
Setup:
- Sequence dimension:
  - R = rows (output sequence length)
  - C = columns (input sequence length)
  - R = C
- Masking:
  - Only MFA supports block-sparse masks.
  - For "scaling by sparsity", sparse block size equals GEMM block size.
Scaling by sequence length:
- Masking:
  - No mask
  - Dense Mask: triangular mask
  - Sparse Mask: triangular mask, summarized by block-sparse mask
- Sequence length:
  - Small sequences: every multiple of 4
  - Large sequences: every multiple of 64
  - Causal mask: every even integer
- Head size: 64
- Head count:
  - Small sequences: 10
  - Large sequences: 5
  - Causal mask: 10
Scaling by head size:
- Masking: dense, no mask
- Sequence length: 4096
- Head size:
  - ≤64: every integer
  - >64: every `roundUpToPowerOf2(D/64)`-th integer
- Head count: 8
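Read literally, the ">64" rule gives a sampling stride like the sketch below. Whether D/64 is rounded up before the power-of-two step is an assumption, and the helper name is hypothetical.

```swift
// Stride between benchmarked head sizes D.
func headSizeStride(_ d: Int) -> Int {
    guard d > 64 else { return 1 }    // ≤ 64: every integer
    let ratio = (d + 63) / 64         // D / 64, rounded up (assumption)
    var power = 1                     // roundUpToPowerOf2(ratio)
    while power < ratio { power <<= 1 }
    return power
}
// Under this reading: D = 80 -> 2, D = 256 -> 4.
```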
Function Constant | Value |
---|---|
Q_trans | ❌ |
K_trans | ✅ |
V_trans | ❌ |
O_trans | ❌ |
R_splits | TBD |
R_simd | Block R / R_splits |
C_simd | Block C |
D_simd | |
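A minimal Swift sketch of specializing an attention variant with the transpose flags from the table (only K transposed). The kernel name `attention` and the `.bool` data type are assumptions; the API specification defines the real entry point and constant bindings.

```swift
import Foundation
import Metal

let device = MTLCreateSystemDefaultDevice()!
let library = try! device.makeLibrary(
    URL: URL(fileURLWithPath: "libMetalFlashAttention.metallib"))

// Transpose flags from the table: only K is stored transposed.
let constants = MTLFunctionConstantValues()
var qTrans = false, kTrans = true, vTrans = false, oTrans = false
constants.setConstantValue(&qTrans, type: .bool, withName: "Q_trans")
constants.setConstantValue(&kTrans, type: .bool, withName: "K_trans")
constants.setConstantValue(&vTrans, type: .bool, withName: "V_trans")
constants.setConstantValue(&oTrans, type: .bool, withName: "O_trans")

let attention = try! library.makeFunction(name: "attention", constantValues: constants)
```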
Dense: Stable Diffusion XL outermost attention layer @ 512x512 (sequence length = 1024)
Dense: Stable Diffusion 2 outermost attention layer @ 512x512 (sequence length = 4096)
Dense: Stable Diffusion 1 outermost attention layer @ 512x512 (head size = 40)
Releases:
- v0.1.0-alpha
  - Initial release, only non-batched GEMM without fused transposes
- v0.2.0-alpha
  - Fused transposes for A and B
  - Batched GEMM
- v1.0.0
  - Attention: dense and block-sparse
- v1.0.1
  - GEMM: fused biases
Prospective Future Goals:
- Tune the existing GEMM and Attention kernels for new A17/M3 hardware
- Kahan block-summation with double-single accumulate, in a manner portable to other vendors
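For context on the last item: a compensated accumulator of the "double-single" kind keeps the running sum as two floats, the value plus an error term. The Swift sketch below only illustrates the idea; it is not the planned Metal kernel.

```swift
// Kahan-style compensated accumulation: `hi + lo` represents the sum more
// precisely than a single Float can.
struct CompensatedSum {
    var hi: Float = 0    // running sum
    var lo: Float = 0    // error lost from hi so far

    mutating func add(_ x: Float) {
        let y = x - lo           // re-apply the previously lost error
        let t = hi + y           // low-order bits of y may be lost here
        lo = (t - hi) - y        // recover exactly what was lost
        hi = t
    }
}

var sum = CompensatedSum()
for _ in 0..<1_000_000 { sum.add(0.1) }
print(sum.hi)    // far closer to 100_000 than a naive Float accumulation
```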