The bank conflict optimization does not help performance.
philipturner committed Jul 28, 2023
1 parent f6710a0 commit 5b26698
Showing 8 changed files with 100 additions and 30 deletions.
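
For context: the optimization being backed out pads the leading dimension of each threadgroup memory block whenever the unpadded row stride in bytes is a multiple of 32, adding a 16-byte offset so that consecutive rows start in different banks. The Swift sketch below illustrates the idea; the 32-bank, 4-byte-wide bank geometry is an assumption made for illustration and is not stated anywhere in this commit.

```swift
// Bank index of the first element of each row of a threadgroup block, as a
// function of the block's leading dimension (in elements).
// Assumed geometry: 32 banks, 4 bytes per bank, 4-byte (FP32) elements.
func startingBank(row: Int, leadingDimension: Int, elementSize: Int = 4,
                  bankWidth: Int = 4, bankCount: Int = 32) -> Int {
    let byteOffset = row * leadingDimension * elementSize
    return (byteOffset / bankWidth) % bankCount
}

// Leading dimension of 32 FP32 elements (128 bytes) vs. 36 elements (128 + 16 bytes).
for leadingDimension in [32, 36] {
    let banks = (0..<4).map { startingBank(row: $0, leadingDimension: leadingDimension) }
    print("leading dim \(leadingDimension): rows start in banks \(banks)")
}
// leading dim 32: rows start in banks [0, 0, 0, 0]   -> every row hits the same bank
// leading dim 36: rows start in banks [0, 4, 8, 12]  -> rows staggered across banks
```

The kernel-side plumbing for this (function constants 50002-50004 in GEMM.metal) stays in place, but the host-side `setBankOffset` calls are left commented out because the padding made GEMM slower.
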
3 changes: 1 addition & 2 deletions README.md
@@ -64,10 +64,9 @@ Scaling by square size:
| ------ | --------- |
| `M_splits` | 2 |
| `N_splits` | 2 |
| `K_splits` | 1 |
| `M_simd` | Block M / `M_splits` |
| `N_simd` | Block N / `N_splits` |
| `K_simd` | Block K / `K_splits` |
| `K_simd` | Block K |

| Precision | Block M | Block N | Block K |
| - | - | - | - |
25 changes: 17 additions & 8 deletions Sources/GEMM.metal
@@ -50,10 +50,20 @@ constant ushort N_padded = (N < N_simd) ? (N_modulo + 7) / 8 * 8 : N_simd;
constant ushort M_splits [[function_constant(210)]];
constant ushort N_splits [[function_constant(211)]];

constant ushort M_bank_offset [[function_constant(50002)]]; // 220
constant ushort N_bank_offset [[function_constant(50003)]]; // 221
constant ushort K_bank_offset [[function_constant(50004)]]; // 222
constant bool M_bank_offset_defined = is_function_constant_defined(M_bank_offset);
constant bool N_bank_offset_defined = is_function_constant_defined(N_bank_offset);
constant bool K_bank_offset_defined = is_function_constant_defined(K_bank_offset);

constant ushort M_group = M_simd * M_splits;
constant ushort N_group = N_simd * N_splits;
constant ushort A_block_leading_dim = (A_trans ? M_group : K_simd);
constant ushort B_block_leading_dim = (B_trans ? K_simd : N_group);
constant ushort M_block_dim = M_group + (M_bank_offset_defined ? M_bank_offset : 0);
constant ushort N_block_dim = N_group + (N_bank_offset_defined ? N_bank_offset : 0);
constant ushort K_block_dim = K_simd + (K_bank_offset_defined ? K_bank_offset : 0);
constant ushort A_block_leading_dim = (A_trans ? M_block_dim : K_block_dim);
constant ushort B_block_leading_dim = (B_trans ? K_block_dim : N_block_dim);

// There is no padding for M reads/writes.
// There is no padding for N reads/writes.
@@ -62,8 +72,7 @@ constant ushort K_simd_padded = (K_simd_unpadded + 7) / 8 * 8;

constant ushort A_sram_length = (M_simd / 8) * 1;
constant ushort B_sram_length = 1 * (N_simd / 8);
constant ushort A_block_length = M_group * K_simd;
//constant ushort B_block_length = K_group * N_group;
constant ushort A_block_length = (A_trans) ? K_simd * M_block_dim : M_group * K_block_dim;

// Threadgroup block must fit entire C accumulator and partial sums.
constant ushort A_sram_offset = 0;
@@ -369,16 +378,16 @@ void _gemm_impl(device T *A [[buffer(0)]],
for (ushort k = 0; k < K_simd_padded; k += 8) {
bool accumulate = use_bias || !(K <= K_simd && k == 0);
multiply_accumulate(sram, A_block_src, B_block_src, accumulate);
A_block_src += A_trans ? 8 * M_group : 8;
B_block_src += B_trans ? 8 : 8 * N_group;
A_block_src += A_trans ? 8 * A_block_leading_dim : 8;
B_block_src += B_trans ? 8 : 8 * B_block_leading_dim;
}

if (K_floor + K_simd < K) {
#pragma clang loop unroll(full)
for (ushort k = K_simd_padded; k < K_simd; k += 8) {
multiply_accumulate(sram, A_block_src, B_block_src);
A_block_src += A_trans ? 8 * M_group : 8;
B_block_src += B_trans ? 8 : 8 * N_group;
A_block_src += A_trans ? 8 * A_block_leading_dim : 8;
B_block_src += B_trans ? 8 : 8 * B_block_leading_dim;
}
threadgroup_barrier(mem_flags::mem_threadgroup);

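To connect the kernel constants above with the host-side allocation in GEMM.swift below, here is a hypothetical sizing calculation. It assumes a 48x48x32 block (M_group x N_group x K_simd) of FP32 with a 16-byte offset applied only to K; these numbers are purely illustrative and are not the repository's default block sizes.

```swift
let M_group = 48, N_group = 48, K_simd = 32
let K_block_dim = K_simd + 16 / 4            // 16-byte offset = 4 FP32 elements

// Non-transposed A: M_group rows, each K_block_dim elements long.
let A_block_length = M_group * K_block_dim   // 48 * 36 = 1728 elements
// Non-transposed B: K_simd rows, each N_group elements long (N left unpadded).
let B_block_length = K_simd * N_group        // 32 * 48 = 1536 elements

// The host must allocate matching threadgroup memory for the kernel to index safely.
print(A_block_length + B_block_length)       // 3264 elements
```
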
44 changes: 40 additions & 4 deletions Tests/Operations/GEMM.swift
@@ -96,6 +96,33 @@ struct MFA_GEMM: GEMM, MFA_Operation {
constants.setConstantValue(&M_splits, type: .ushort, index: 210)
constants.setConstantValue(&N_splits, type: .ushort, index: 211)

let M_group = M_simd * M_splits
let N_group = N_simd * N_splits
var M_block_dim = M_group
var N_block_dim = N_group
var K_block_dim = K_simd
func setBankOffset(_ dim: inout UInt16, index: Int) {
precondition(dim % 8 == 0, "Original dimension must be divisible by 8.")
let dimBytes = dim * UInt16(dataType.size)

let dimBytesModulo = dimBytes % 64
if dimBytesModulo == 16 || dimBytesModulo == 48 {
return
} else if dimBytesModulo == 0 || dimBytesModulo == 32 {
let bankOffsetBytes: UInt16 = 16
var bankOffset = bankOffsetBytes / UInt16(dataType.size)
dim += bankOffset
constants.setConstantValue(&bankOffset, type: .ushort, index: index)
} else {
fatalError("This should never happen.")
}
}

// The bank conflict fix makes GEMM slower, unlike attention.
// setBankOffset(&M_block_dim, index: 50002)
// setBankOffset(&N_block_dim, index: 50003)
// setBankOffset(&K_block_dim, index: 50004)

// Satisfy Metal API validation.
#if DEBUG
do {
@@ -118,10 +145,19 @@ struct MFA_GEMM: GEMM, MFA_Operation {
let function = try! library.makeFunction(
name: name, constantValues: constants)

let M_group = M_simd * M_splits
let N_group = N_simd * N_splits
let A_block_length = M_group * K_simd
let B_block_length = K_simd * N_group
var A_block_length: UInt16
var B_block_length: UInt16

if parameters.A_trans {
A_block_length = K_simd * M_block_dim
} else {
A_block_length = M_group * K_block_dim
}
if parameters.B_trans {
B_block_length = N_group * K_block_dim
} else {
B_block_length = K_simd * N_block_dim
}

var blockElements = A_block_length + B_block_length;
if (pcopy.M % 8 != 0) && (pcopy.N % 8 != 0) {
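
As a worked example of the decision rule inside `setBankOffset` above, the hypothetical standalone version below evaluates a few block dimensions; the function name and the sample values are illustrative, not part of the repository's API.

```swift
func paddedDimension(_ dim: Int, elementSize: Int) -> Int {
    precondition(dim % 8 == 0, "Original dimension must be divisible by 8.")
    let dimBytes = dim * elementSize
    switch dimBytes % 64 {
    case 16, 48: return dim                      // row stride already staggered
    case 0, 32:  return dim + 16 / elementSize   // fold a 16-byte offset into the dimension
    default:     fatalError("unreachable for 2- and 4-byte element types")
    }
}

print(paddedDimension(32, elementSize: 4))   // FP32: 128 B % 64 == 0  -> 36
print(paddedDimension(40, elementSize: 4))   // FP32: 160 B % 64 == 32 -> 44
print(paddedDimension(24, elementSize: 2))   // FP16:  48 B % 64 == 48 -> 24 (unchanged)
print(paddedDimension(32, elementSize: 2))   // FP16:  64 B % 64 == 0  -> 40
```
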
15 changes: 15 additions & 0 deletions Tests/Test Cases/CorrectnessTests.swift
@@ -178,6 +178,21 @@ class CorrectnessTests: MFATestCase {
let params = EuclideanDistanceParameters(
matrixK: K, batchSize: batchSize)
if !mfa_C.isApproximatelyEqual(to: mps_C, parameters: params) {
do {
var shapeRepr: String
if let batchSize {
shapeRepr = "\(batchSize)x\(M)x\(N)x\(K)x\(DTypeRepr)"
} else {
shapeRepr = "\(M)x\(N)x\(K)x\(DTypeRepr)"
}
if let extraDim {
shapeRepr = "\(extraDim)x\(shapeRepr)"
}
let dist = mfa_C.euclideanDistance(to: mps_C)
let distRepr = "- \(String(format: "%.3f", dist))"
print("Failed test: \(shapeRepr) (\(transRepr)) \(distRepr)")
}

MPL_showComparison(
actual: mfa_C, actualName: "MFA",
expected: mps_C, expectedName: "MPS", parameters: params)
6 changes: 3 additions & 3 deletions Tests/Test Cases/GEMMPerfTests.swift
@@ -17,10 +17,10 @@ class GEMMPerfTests: MFATestCase {
// Tests the precision you set as the global testing precision. For a quick
// smoke test, you can set a larger granularity.
testGEMMSpeed(
granularity: 512, trialsExtension: 2,
B_trans: true, D_trans: false,
granularity: 8, trialsExtension: 2,
B_trans: false, D_trans: false,
batchSize: nil, useBias: false,
large: true)
large: false)
}

// Covers the entire range of square matrix sizes, as well as differences
33 changes: 23 additions & 10 deletions Tests/Test Cases/GEMMTest.swift
@@ -71,18 +71,23 @@ func showMatrixTransposeTest() {
}

func showMatrixBiasTest() {
let M = 57
let N = 42
let K = 3
let batchSize: Int? = 2
#if arch(arm64)
// 708x25x23xf32 (TTT, bias)
// Failed test: 15x1x124x (TT)
// Failed test: 144x927x28xf32 (TT) - nan

let M = 48 // 708, 57
let N = 25 // 25, 42
let K = 23 // 23, 3
let batchSize: Int? = nil // 2
let transposeD: Bool = Bool.random() ? true : true

var shapeA: [Int]
var shapeB: [Int]
var shapeC: [Int]
var shapeD: [Int]
if let batchSize {
shapeA = [batchSize, M, K]
shapeA = [batchSize, K, M]
shapeB = [batchSize, K, N]
shapeC = [batchSize, M, N]
if transposeD {
@@ -91,7 +96,7 @@
shapeD = [batchSize, N]
}
} else {
shapeA = [M, K]
shapeA = [K, M]
shapeB = [K, N]
shapeC = [M, N]
if transposeD {
@@ -113,7 +118,9 @@
let py_D = Tensor<Real>(shape: shapeD, randomUniform: 0..<1, backend: .numpy)
_ExecutionContext.withDefaultBackend(.numpy) {
_ExecutionContext.profileCommands {
py_C.matmul(py_A, py_B, py_D, transposeD: transposeD, fusedBias: true)
py_C.matmul(
py_A, py_B, py_D,
transposeA: true, transposeD: transposeD, fusedBias: true)
}
}

@@ -123,7 +130,9 @@
let mps_D = Tensor(copying: py_D, backend: .mps)
_ExecutionContext.withDefaultBackend(.mps) {
_ExecutionContext.profileCommands {
mps_C.matmul(mps_A, mps_B, mps_D, transposeD: transposeD, fusedBias: true)
mps_C.matmul(
mps_A, mps_B, mps_D,
transposeA: true, transposeD: transposeD, fusedBias: true)
}
}

@@ -133,7 +142,9 @@
let mfa_D = Tensor(copying: py_D, backend: .mfa)
_ExecutionContext.withDefaultBackend(.mfa) {
_ExecutionContext.profileCommands {
mfa_C.matmul(mfa_A, mfa_B, mfa_D, transposeD: transposeD, fusedBias: true)
mfa_C.matmul(
mfa_A, mfa_B, mfa_D,
transposeA: true, transposeD: transposeD, fusedBias: true)
}
}

@@ -151,5 +162,7 @@
mfa: mfa_C, mps: mps_C, numpy: py_C,
parameters: .init(matrixK: K, batchSize: nil))
}
#endif
}


2 changes: 1 addition & 1 deletion Tests/Test Cases/MFATestCase.swift
@@ -10,7 +10,7 @@ import Foundation
class MFATestCase {
// Global setting for the precision used in tests.
#if arch(arm64)
typealias Real = Float32
typealias Real = Float16
#else
typealias Real = Float
#endif
2 changes: 0 additions & 2 deletions Tests/main.swift
@@ -13,6 +13,4 @@ import PythonKit
_ = MetalContext.global
_ = PythonContext.global

//showMatrixBiasTest()

MFATestCase.runTests(speed: .veryLong)
