From 5b266987a306249f926e9b31d65ce69d201d48da Mon Sep 17 00:00:00 2001 From: Philip Turner Date: Fri, 28 Jul 2023 17:10:00 -0400 Subject: [PATCH] The bank conflict optimization does not help performance. --- README.md | 3 +- Sources/GEMM.metal | 25 +++++++++----- Tests/Operations/GEMM.swift | 44 ++++++++++++++++++++++--- Tests/Test Cases/CorrectnessTests.swift | 15 +++++++++ Tests/Test Cases/GEMMPerfTests.swift | 6 ++-- Tests/Test Cases/GEMMTest.swift | 33 +++++++++++++------ Tests/Test Cases/MFATestCase.swift | 2 +- Tests/main.swift | 2 -- 8 files changed, 100 insertions(+), 30 deletions(-) diff --git a/README.md b/README.md index f28a8bb..110c41e 100644 --- a/README.md +++ b/README.md @@ -64,10 +64,9 @@ Scaling by square size: | ------ | --------- | | `M_splits` | 2 | | `N_splits` | 2 | -| `K_splits` | 1 | | `M_simd` | Block M / `M_splits` | | `N_simd` | Block N / `N_splits` | -| `K_simd` | Block K / `K_splits` | +| `K_simd` | Block K | | Precision | Block M | Block N | Block K | | - | - | - | - | diff --git a/Sources/GEMM.metal b/Sources/GEMM.metal index 253e397..e9e36df 100644 --- a/Sources/GEMM.metal +++ b/Sources/GEMM.metal @@ -50,10 +50,20 @@ constant ushort N_padded = (N < N_simd) ? (N_modulo + 7) / 8 * 8 : N_simd; constant ushort M_splits [[function_constant(210)]]; constant ushort N_splits [[function_constant(211)]]; +constant ushort M_bank_offset [[function_constant(50002)]]; // 220 +constant ushort N_bank_offset [[function_constant(50003)]]; // 221 +constant ushort K_bank_offset [[function_constant(50004)]]; // 222 +constant bool M_bank_offset_defined = is_function_constant_defined(M_bank_offset); +constant bool N_bank_offset_defined = is_function_constant_defined(N_bank_offset); +constant bool K_bank_offset_defined = is_function_constant_defined(K_bank_offset); + constant ushort M_group = M_simd * M_splits; constant ushort N_group = N_simd * N_splits; -constant ushort A_block_leading_dim = (A_trans ? M_group : K_simd); -constant ushort B_block_leading_dim = (B_trans ? K_simd : N_group); +constant ushort M_block_dim = M_group + (M_bank_offset_defined ? M_bank_offset : 0); +constant ushort N_block_dim = N_group + (N_bank_offset_defined ? N_bank_offset : 0); +constant ushort K_block_dim = K_simd + (K_bank_offset_defined ? K_bank_offset : 0); +constant ushort A_block_leading_dim = (A_trans ? M_block_dim : K_block_dim); +constant ushort B_block_leading_dim = (B_trans ? K_block_dim : N_block_dim); // There is no padding for M reads/writes. // There is no padding for N reads/writes. @@ -62,8 +72,7 @@ constant ushort K_simd_padded = (K_simd_unpadded + 7) / 8 * 8; constant ushort A_sram_length = (M_simd / 8) * 1; constant ushort B_sram_length = 1 * (N_simd / 8); -constant ushort A_block_length = M_group * K_simd; -//constant ushort B_block_length = K_group * N_group; +constant ushort A_block_length = (A_trans) ? K_simd * M_block_dim : M_group * K_block_dim; // Threadgroup block must fit entire C accumulator and partial sums. constant ushort A_sram_offset = 0; @@ -369,16 +378,16 @@ void _gemm_impl(device T *A [[buffer(0)]], for (ushort k = 0; k < K_simd_padded; k += 8) { bool accumulate = use_bias || !(K <= K_simd && k == 0); multiply_accumulate(sram, A_block_src, B_block_src, accumulate); - A_block_src += A_trans ? 8 * M_group : 8; - B_block_src += B_trans ? 8 : 8 * N_group; + A_block_src += A_trans ? 8 * A_block_leading_dim : 8; + B_block_src += B_trans ? 8 : 8 * B_block_leading_dim; } if (K_floor + K_simd < K) { #pragma clang loop unroll(full) for (ushort k = K_simd_padded; k < K_simd; k += 8) { multiply_accumulate(sram, A_block_src, B_block_src); - A_block_src += A_trans ? 8 * M_group : 8; - B_block_src += B_trans ? 8 : 8 * N_group; + A_block_src += A_trans ? 8 * A_block_leading_dim : 8; + B_block_src += B_trans ? 8 : 8 * B_block_leading_dim; } threadgroup_barrier(mem_flags::mem_threadgroup); diff --git a/Tests/Operations/GEMM.swift b/Tests/Operations/GEMM.swift index 7216a51..6f7eb03 100644 --- a/Tests/Operations/GEMM.swift +++ b/Tests/Operations/GEMM.swift @@ -96,6 +96,33 @@ struct MFA_GEMM: GEMM, MFA_Operation { constants.setConstantValue(&M_splits, type: .ushort, index: 210) constants.setConstantValue(&N_splits, type: .ushort, index: 211) + let M_group = M_simd * M_splits + let N_group = N_simd * N_splits + var M_block_dim = M_group + var N_block_dim = N_group + var K_block_dim = K_simd + func setBankOffset(_ dim: inout UInt16, index: Int) { + precondition(dim % 8 == 0, "Original dimension must be divisible by 8.") + let dimBytes = dim * UInt16(dataType.size) + + let dimBytesModulo = dimBytes % 64 + if dimBytesModulo == 16 || dimBytesModulo == 48 { + return + } else if dimBytesModulo == 0 || dimBytesModulo == 32 { + let bankOffsetBytes: UInt16 = 16 + var bankOffset = bankOffsetBytes / UInt16(dataType.size) + dim += bankOffset + constants.setConstantValue(&bankOffset, type: .ushort, index: index) + } else { + fatalError("This should never happen.") + } + } + + // The bank conflict fix makes GEMM slower, unlike attention. +// setBankOffset(&M_block_dim, index: 50002) +// setBankOffset(&N_block_dim, index: 50003) +// setBankOffset(&K_block_dim, index: 50004) + // Satisfy Metal API validation. #if DEBUG do { @@ -118,10 +145,19 @@ struct MFA_GEMM: GEMM, MFA_Operation { let function = try! library.makeFunction( name: name, constantValues: constants) - let M_group = M_simd * M_splits - let N_group = N_simd * N_splits - let A_block_length = M_group * K_simd - let B_block_length = K_simd * N_group + var A_block_length: UInt16 + var B_block_length: UInt16 + + if parameters.A_trans { + A_block_length = K_simd * M_block_dim + } else { + A_block_length = M_group * K_block_dim + } + if parameters.B_trans { + B_block_length = N_group * K_block_dim + } else { + B_block_length = K_simd * N_block_dim + } var blockElements = A_block_length + B_block_length; if (pcopy.M % 8 != 0) && (pcopy.N % 8 != 0) { diff --git a/Tests/Test Cases/CorrectnessTests.swift b/Tests/Test Cases/CorrectnessTests.swift index 4ede88f..01c0378 100644 --- a/Tests/Test Cases/CorrectnessTests.swift +++ b/Tests/Test Cases/CorrectnessTests.swift @@ -178,6 +178,21 @@ class CorrectnessTests: MFATestCase { let params = EuclideanDistanceParameters( matrixK: K, batchSize: batchSize) if !mfa_C.isApproximatelyEqual(to: mps_C, parameters: params) { + do { + var shapeRepr: String + if let batchSize { + shapeRepr = "\(batchSize)x\(M)x\(N)x\(K)x\(DTypeRepr)" + } else { + shapeRepr = "\(M)x\(N)x\(K)x\(DTypeRepr)" + } + if let extraDim { + shapeRepr = "\(extraDim)x\(shapeRepr)" + } + let dist = mfa_C.euclideanDistance(to: mps_C) + let distRepr = "- \(String(format: "%.3f", dist))" + print("Failed test: \(shapeRepr) (\(transRepr)) \(distRepr)") + } + MPL_showComparison( actual: mfa_C, actualName: "MFA", expected: mps_C, expectedName: "MPS", parameters: params) diff --git a/Tests/Test Cases/GEMMPerfTests.swift b/Tests/Test Cases/GEMMPerfTests.swift index 6bb82ab..4b4bf48 100644 --- a/Tests/Test Cases/GEMMPerfTests.swift +++ b/Tests/Test Cases/GEMMPerfTests.swift @@ -17,10 +17,10 @@ class GEMMPerfTests: MFATestCase { // Tests the precision you set as the global testing precision. For a quick // smoke test, you can set a larger granularity. testGEMMSpeed( - granularity: 512, trialsExtension: 2, - B_trans: true, D_trans: false, + granularity: 8, trialsExtension: 2, + B_trans: false, D_trans: false, batchSize: nil, useBias: false, - large: true) + large: false) } // Covers the entire range of square matrix sizes, as well as differences diff --git a/Tests/Test Cases/GEMMTest.swift b/Tests/Test Cases/GEMMTest.swift index 83e0438..48bfeab 100644 --- a/Tests/Test Cases/GEMMTest.swift +++ b/Tests/Test Cases/GEMMTest.swift @@ -71,10 +71,15 @@ func showMatrixTransposeTest() { } func showMatrixBiasTest() { - let M = 57 - let N = 42 - let K = 3 - let batchSize: Int? = 2 +#if arch(arm64) + // 708x25x23xf32 (TTT, bias) + // Failed test: 15x1x124x (TT) + // Failed test: 144x927x28xf32 (TT) - nan + + let M = 48 // 708, 57 + let N = 25 // 25, 42 + let K = 23 // 23, 3 + let batchSize: Int? = nil // 2 let transposeD: Bool = Bool.random() ? true : true var shapeA: [Int] @@ -82,7 +87,7 @@ func showMatrixBiasTest() { var shapeC: [Int] var shapeD: [Int] if let batchSize { - shapeA = [batchSize, M, K] + shapeA = [batchSize, K, M] shapeB = [batchSize, K, N] shapeC = [batchSize, M, N] if transposeD { @@ -91,7 +96,7 @@ func showMatrixBiasTest() { shapeD = [batchSize, N] } } else { - shapeA = [M, K] + shapeA = [K, M] shapeB = [K, N] shapeC = [M, N] if transposeD { @@ -113,7 +118,9 @@ func showMatrixBiasTest() { let py_D = Tensor(shape: shapeD, randomUniform: 0..<1, backend: .numpy) _ExecutionContext.withDefaultBackend(.numpy) { _ExecutionContext.profileCommands { - py_C.matmul(py_A, py_B, py_D, transposeD: transposeD, fusedBias: true) + py_C.matmul( + py_A, py_B, py_D, + transposeA: true, transposeD: transposeD, fusedBias: true) } } @@ -123,7 +130,9 @@ func showMatrixBiasTest() { let mps_D = Tensor(copying: py_D, backend: .mps) _ExecutionContext.withDefaultBackend(.mps) { _ExecutionContext.profileCommands { - mps_C.matmul(mps_A, mps_B, mps_D, transposeD: transposeD, fusedBias: true) + mps_C.matmul( + mps_A, mps_B, mps_D, + transposeA: true, transposeD: transposeD, fusedBias: true) } } @@ -133,7 +142,9 @@ func showMatrixBiasTest() { let mfa_D = Tensor(copying: py_D, backend: .mfa) _ExecutionContext.withDefaultBackend(.mfa) { _ExecutionContext.profileCommands { - mfa_C.matmul(mfa_A, mfa_B, mfa_D, transposeD: transposeD, fusedBias: true) + mfa_C.matmul( + mfa_A, mfa_B, mfa_D, + transposeA: true, transposeD: transposeD, fusedBias: true) } } @@ -151,5 +162,7 @@ func showMatrixBiasTest() { mfa: mfa_C, mps: mps_C, numpy: py_C, parameters: .init(matrixK: K, batchSize: nil)) } - +#endif } + + diff --git a/Tests/Test Cases/MFATestCase.swift b/Tests/Test Cases/MFATestCase.swift index 6334eca..010bfdd 100644 --- a/Tests/Test Cases/MFATestCase.swift +++ b/Tests/Test Cases/MFATestCase.swift @@ -10,7 +10,7 @@ import Foundation class MFATestCase { // Global setting for the precision used in tests. #if arch(arm64) - typealias Real = Float32 + typealias Real = Float16 #else typealias Real = Float #endif diff --git a/Tests/main.swift b/Tests/main.swift index 8018bda..53798dc 100644 --- a/Tests/main.swift +++ b/Tests/main.swift @@ -13,6 +13,4 @@ import PythonKit _ = MetalContext.global _ = PythonContext.global -//showMatrixBiasTest() - MFATestCase.runTests(speed: .veryLong)