
Metal FlashAttention

A faster alternative to Metal Performance Shaders, a reference implementation of modern GPU algorithms, and a step toward defragmenting the AI ecosystem.

Algorithms:

  • Attention
    • Dense (90.5% ALU)
    • Block-Sparse
  • GEMM
    • FP16 (93.3% ALU)
    • FP32 (87.2% ALU)
    • Fused Biases

Usage

| Programming Language | MFA Supports | MPSGraph Supports | PyTorch Supports |
| - | - | - | - |
| CPU C++ (metal-cpp) | | | |
| GPU C++ (Indirect Command Buffers) | | | |
| Swift (iPadOS, Playgrounds) | | | |
| Swift (macOS, Xcode) | | | |
| Predecessor to Swift | not tested | | |

Usage:

  • Download Xcode 14.2 from the Apple developer tools archive
    • Copy it into /Applications/Xcode 14.2.app, side by side with the existing installation at /Applications/Xcode.app
  • Run the Swift script to compile libMetalFlashAttention.metallib
    • Enter this repository from Terminal and run swift build.swift
  • Read the API specification
  • Generate Metal shader variants at runtime (see the sketch below)
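
As a rough illustration of the last two steps, the sketch below loads the compiled metallib and specializes a GEMM variant at runtime through function constants. This is a minimal sketch, not the project's reference code: the entry-point name ("hgemm"), the constant names (M, N, K), and their types are assumptions; the API specification is the authoritative source.

```swift
import Foundation
import Metal

// Minimal sketch: load the compiled metallib, then generate a shader variant
// at runtime by specializing function constants.
let device = MTLCreateSystemDefaultDevice()!
let url = URL(fileURLWithPath: "libMetalFlashAttention.metallib")
let library = try device.makeLibrary(URL: url)

// Assumed constant names and types (M/N/K as 32-bit uints); check the API
// specification for the real bindings.
let constants = MTLFunctionConstantValues()
var M: UInt32 = 1024, N: UInt32 = 1024, K: UInt32 = 1024
constants.setConstantValue(&M, type: .uint, withName: "M")
constants.setConstantValue(&N, type: .uint, withName: "N")
constants.setConstantValue(&K, type: .uint, withName: "K")

// "hgemm" is an assumed entry-point name for the FP16 GEMM kernel.
let function = try library.makeFunction(name: "hgemm", constantValues: constants)
let pipeline = try device.makeComputePipelineState(function: function)
_ = pipeline // dispatch with a compute command encoder as usual
```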

Alternatively:

  • Download the newest version of Xcode
  • Fetch the Metal library from GitHub releases
  • Run the unit tests from this repository

Performance

SGEMM, every square matrix from 1–1536:

Max GFLOPS achieved

HGEMM, every square matrix from 1–2048:

Max GFLOPS achieved

GEMM

Scaling by square size:

  • Matrix M: every even integer
  • Matrix N: every even integer
  • Matrix K: every even integer
  • For 2x batched, every multiple of 4
  • For very large square matrices, granularity varies

| Function Constant | Value |
| - | - |
| M_splits | 2 |
| N_splits | 2 |
| M_simd | Block M / M_splits |
| N_simd | Block N / N_splits |
| K_simd | Block K |

| Precision | Block M | Block N | Block K |
| - | - | - | - |
| Float32 | 32 | 32 | 32 |
| Float32 | 48 | 48 | 24 |
| Float16 | 32 | 32 | 32 |
| Float16 | 48 | 48 | 32 |

| Size Start | Size End | Duplicate Commands/Encoder | Trials |
| - | - | - | - |
| 1 | 190 | 256 | 16 |
| 192 | 254 | 128 | 16 |
| 256 | 382 | 64 | 16 |
| 384 | 510 | 32 | 16 |
| 512 | 766 | 16 | 16 |
| 768 | 1022 | 8 | 16 |
| 1024 | 1534 | 4 | 16 |
| 1536 | 2048 | 2 | 16 |
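
To make the relationship between these constants concrete, here is a small helper (an assumption for illustration, not part of the MFA API) that derives the simdgroup-level constants from the block sizes in the tables above.

```swift
// Assumed helper: derive the simdgroup-level constants from a threadgroup
// block size, following the table above (each block is split 2x2 across
// simdgroups).
struct GEMMBlockConfig {
    var blockM: Int
    var blockN: Int
    var blockK: Int
    var mSplits: Int = 2   // M_splits
    var nSplits: Int = 2   // N_splits

    var mSimd: Int { blockM / mSplits }  // M_simd = Block M / M_splits
    var nSimd: Int { blockN / nSplits }  // N_simd = Block N / N_splits
    var kSimd: Int { blockK }            // K_simd = Block K
}

// Example: the larger Float32 configuration (48x48x24) from the table.
let fp32Large = GEMMBlockConfig(blockM: 48, blockN: 48, blockK: 24)
print(fp32Large.mSimd, fp32Large.nSimd, fp32Large.kSimd)  // 24 24 24
```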

Float32 Utilization (NN)

Float32 Utilization (NT)

Float32 Utilization (NT, Large)

Float16 Utilization (NN)

Float16 Utilization (NT, 2x Batched)

Float16 Utilization (NTN, 2x Batched, Bias)

Attention

Setup:

  • Sequence dimension:
    • R = rows (output sequence length)
    • C = columns (input sequence length)
    • R = C
  • Masking:
    • Only MFA supports block-sparse masks (see the sketch after this list).
    • For "scaling by sparsity", the sparse block size equals the GEMM block size.
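
To illustrate what a block-sparse mask summarizes, the sketch below collapses a dense triangular (causal) mask into blocks whose size matches the GEMM block size. The three-state encoding and the function itself are assumptions for illustration, not the MFA mask format.

```swift
// Assumed illustration, not the MFA mask format: classify each (R block, C block)
// tile of a causal mask so whole tiles can be skipped or run unmasked.
// 0 = fully masked (skip the block), 1 = straddles the diagonal (apply the
// element-wise mask), 2 = fully unmasked (compute without masking).
func blockSparseCausalMask(sequenceLength: Int, blockSize: Int) -> [[UInt8]] {
    let blocks = (sequenceLength + blockSize - 1) / blockSize
    var mask = [[UInt8]](repeating: [UInt8](repeating: 0, count: blocks),
                         count: blocks)
    for rowBlock in 0..<blocks {
        for colBlock in 0..<blocks {
            let rowStart = rowBlock * blockSize                          // first query row
            let rowEnd = min(rowStart + blockSize, sequenceLength) - 1   // last query row
            let colStart = colBlock * blockSize                          // first key column
            let colEnd = min(colStart + blockSize, sequenceLength) - 1   // last key column
            if colEnd <= rowStart {
                mask[rowBlock][colBlock] = 2  // entirely at or below the diagonal
            } else if colStart > rowEnd {
                mask[rowBlock][colBlock] = 0  // entirely above the diagonal
            } else {
                mask[rowBlock][colBlock] = 1  // diagonal block
            }
        }
    }
    return mask
}
```

Fully masked blocks are skipped outright and fully unmasked blocks avoid per-element masking, so only the diagonal blocks pay for element-wise masking; that is where the block-sparse path gains over the dense causal mask.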

Scaling by sequence length:

  • Masking:
    • No mask
    • Dense Mask: triangular mask
    • Sparse Mask: triangular mask, summarized by block-sparse mask
  • Sequence length:
    • Small sequences: every multiple of 4
    • Large sequences: every multiple of 64
    • Causal mask: every even integer
  • Head size: 64
  • Head count:
    • Small sequences: 10
    • Large sequences: 5
    • Causal mask: 10

Scaling by head size:

  • Masking: dense, no mask
  • Sequence length 4096
  • Head size (see the sketch after this list):
    • ≤64: every integer
    • >64: every roundUpToPowerOf2(D/64) integers
  • Head count: 8
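
The head-size sweep above steps by roundUpToPowerOf2(D/64) once D exceeds 64; the sketch below spells out that sampling rule (an assumption about the benchmark setup, not MFA source).

```swift
// Assumed illustration of the head-size sampling stride described above.
func roundUpToPowerOf2(_ x: Int) -> Int {
    var power = 1
    while power < x { power *= 2 }
    return power
}

// Stride between benchmarked head sizes: 1 for D <= 64, otherwise
// roundUpToPowerOf2(D / 64), so larger head sizes are sampled more coarsely.
func headSizeStride(_ headSize: Int) -> Int {
    headSize <= 64 ? 1 : roundUpToPowerOf2(headSize / 64)
}

// Example: D = 64 -> 1, D = 160 -> 2, D = 256 -> 4.
```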

| Function Constant | Value |
| - | - |
| Q_trans | |
| K_trans | |
| V_trans | |
| O_trans | |
| R_splits | TBD |
| R_simd | Block R / R_splits |
| C_simd | Block C |
| D_simd | $8 \times \lceil D / 8 \rceil$ |
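
A small helper (an assumption for illustration, not the MFA API) showing how the derived constants in this table follow from the block dimensions and the head size:

```swift
// Assumed helper mirroring the table above: the simdgroup-level attention
// constants follow from the threadgroup block dimensions and the head size D.
struct AttentionBlockConfig {
    var blockR: Int   // rows of the attention matrix per threadgroup
    var blockC: Int   // columns of the attention matrix per threadgroup
    var headSize: Int // D
    var rSplits: Int  // R_splits (listed as TBD above)

    var rSimd: Int { blockR / rSplits }          // R_simd = Block R / R_splits
    var cSimd: Int { blockC }                    // C_simd = Block C
    var dSimd: Int { 8 * ((headSize + 7) / 8) }  // D_simd = 8 * ceil(D / 8)
}

// Example: D = 40 (Stable Diffusion 1's outermost layer) is already a multiple
// of 8, so D_simd = 40; D = 50 would pad up to 56.
```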

Float32 Sequence Scaling (Small)

FlashAttention (F32, H=10, D=64)

Float16 Sequence Scaling (Small)

Dense: Stable Diffusion XL outermost attention layer @ 512x512 (sequence length = 1024)

FlashAttention (F16, H=10, D=64)

Float16 Sequence Scaling (Large)

Dense: Stable Diffusion 2 outermost attention layer @ 512x512 (sequence length = 4096)

FlashAttention (F16, H=5, D=64)

Float32 Sequence Scaling (Causal Mask)

FlashAttention (F32, H=10, D=64)

Float16 Sequence Scaling (Causal Mask)

FlashAttention (F16, H=10, D=64)

Float16 Head Scaling

Dense: Stable Diffusion 1 outermost attention layer @ 512x512 (head size = 40)

FlashAttention (F16, R=C=4096, H=8)

Roadmap

Releases:

  • v0.1.0-alpha
    • Initial release, only non-batched GEMM without fused transposes
  • v0.2.0-alpha
    • Fused transposes for A and B
    • Batched GEMM
  • v1.0.0
    • Attention: dense and block-sparse
  • v1.0.1
    • GEMM: fused biases

Prospective Future Goals:

  • Tune the existing GEMM and Attention kernels for new A17/M3 hardware
  • Kahan block-summation with double-single accumulate, in a manner portable to other vendors
