From 6a65815041e83d61c10423aad180dbe7d76d5542 Mon Sep 17 00:00:00 2001 From: Philip Turner Date: Fri, 28 Jul 2023 17:13:33 -0400 Subject: [PATCH] Remove bank offsets from the GEMM kernel --- Sources/Attention.metal | 1 - Sources/GEMM.metal | 16 +++--------- Tests/Operations/GEMM.swift | 41 +++--------------------------- Tests/Test Cases/MFATestCase.swift | 2 +- 4 files changed, 7 insertions(+), 53 deletions(-) diff --git a/Sources/Attention.metal b/Sources/Attention.metal index 4cc118c..217b7cf 100644 --- a/Sources/Attention.metal +++ b/Sources/Attention.metal @@ -999,4 +999,3 @@ kernel void attention(device void *Q [[buffer(0)]], } } } - diff --git a/Sources/GEMM.metal b/Sources/GEMM.metal index e9e36df..ba0a28d 100644 --- a/Sources/GEMM.metal +++ b/Sources/GEMM.metal @@ -50,20 +50,10 @@ constant ushort N_padded = (N < N_simd) ? (N_modulo + 7) / 8 * 8 : N_simd; constant ushort M_splits [[function_constant(210)]]; constant ushort N_splits [[function_constant(211)]]; -constant ushort M_bank_offset [[function_constant(50002)]]; // 220 -constant ushort N_bank_offset [[function_constant(50003)]]; // 221 -constant ushort K_bank_offset [[function_constant(50004)]]; // 222 -constant bool M_bank_offset_defined = is_function_constant_defined(M_bank_offset); -constant bool N_bank_offset_defined = is_function_constant_defined(N_bank_offset); -constant bool K_bank_offset_defined = is_function_constant_defined(K_bank_offset); - constant ushort M_group = M_simd * M_splits; constant ushort N_group = N_simd * N_splits; -constant ushort M_block_dim = M_group + (M_bank_offset_defined ? M_bank_offset : 0); -constant ushort N_block_dim = N_group + (N_bank_offset_defined ? N_bank_offset : 0); -constant ushort K_block_dim = K_simd + (K_bank_offset_defined ? K_bank_offset : 0); -constant ushort A_block_leading_dim = (A_trans ? M_block_dim : K_block_dim); -constant ushort B_block_leading_dim = (B_trans ? K_block_dim : N_block_dim); +constant ushort A_block_leading_dim = (A_trans ? M_group : K_simd); +constant ushort B_block_leading_dim = (B_trans ? K_simd : N_group); // There is no padding for M reads/writes. // There is no padding for N reads/writes. @@ -72,7 +62,7 @@ constant ushort K_simd_padded = (K_simd_unpadded + 7) / 8 * 8; constant ushort A_sram_length = (M_simd / 8) * 1; constant ushort B_sram_length = 1 * (N_simd / 8); -constant ushort A_block_length = (A_trans) ? K_simd * M_block_dim : M_group * K_block_dim; +constant ushort A_block_length = M_group * K_simd; // Threadgroup block must fit entire C accumulator and partial sums. constant ushort A_sram_offset = 0; diff --git a/Tests/Operations/GEMM.swift b/Tests/Operations/GEMM.swift index 6f7eb03..3fe06e5 100644 --- a/Tests/Operations/GEMM.swift +++ b/Tests/Operations/GEMM.swift @@ -98,31 +98,7 @@ struct MFA_GEMM: GEMM, MFA_Operation { let M_group = M_simd * M_splits let N_group = N_simd * N_splits - var M_block_dim = M_group - var N_block_dim = N_group - var K_block_dim = K_simd - func setBankOffset(_ dim: inout UInt16, index: Int) { - precondition(dim % 8 == 0, "Original dimension must be divisible by 8.") - let dimBytes = dim * UInt16(dataType.size) - - let dimBytesModulo = dimBytes % 64 - if dimBytesModulo == 16 || dimBytesModulo == 48 { - return - } else if dimBytesModulo == 0 || dimBytesModulo == 32 { - let bankOffsetBytes: UInt16 = 16 - var bankOffset = bankOffsetBytes / UInt16(dataType.size) - dim += bankOffset - constants.setConstantValue(&bankOffset, type: .ushort, index: index) - } else { - fatalError("This should never happen.") - } - } - - // The bank conflict fix makes GEMM slower, unlike attention. -// setBankOffset(&M_block_dim, index: 50002) -// setBankOffset(&N_block_dim, index: 50003) -// setBankOffset(&K_block_dim, index: 50004) - + // Satisfy Metal API validation. #if DEBUG do { @@ -145,19 +121,8 @@ struct MFA_GEMM: GEMM, MFA_Operation { let function = try! library.makeFunction( name: name, constantValues: constants) - var A_block_length: UInt16 - var B_block_length: UInt16 - - if parameters.A_trans { - A_block_length = K_simd * M_block_dim - } else { - A_block_length = M_group * K_block_dim - } - if parameters.B_trans { - B_block_length = N_group * K_block_dim - } else { - B_block_length = K_simd * N_block_dim - } + let A_block_length = M_group * K_simd + let B_block_length = K_simd * N_group var blockElements = A_block_length + B_block_length; if (pcopy.M % 8 != 0) && (pcopy.N % 8 != 0) { diff --git a/Tests/Test Cases/MFATestCase.swift b/Tests/Test Cases/MFATestCase.swift index 010bfdd..6334eca 100644 --- a/Tests/Test Cases/MFATestCase.swift +++ b/Tests/Test Cases/MFATestCase.swift @@ -10,7 +10,7 @@ import Foundation class MFATestCase { // Global setting for the precision used in tests. #if arch(arm64) - typealias Real = Float16 + typealias Real = Float32 #else typealias Real = Float #endif