Skip to content

Commit

Permalink
Remove bank offsets from the GEMM kernel
Browse files Browse the repository at this point in the history
  • Loading branch information
philipturner committed Jul 28, 2023
1 parent 5b26698 commit 6a65815
Show file tree
Hide file tree
Showing 4 changed files with 7 additions and 53 deletions.
1 change: 0 additions & 1 deletion Sources/Attention.metal
Original file line number Diff line number Diff line change
Expand Up @@ -999,4 +999,3 @@ kernel void attention(device void *Q [[buffer(0)]],
}
}
}

16 changes: 3 additions & 13 deletions Sources/GEMM.metal
Original file line number Diff line number Diff line change
Expand Up @@ -50,20 +50,10 @@ constant ushort N_padded = (N < N_simd) ? (N_modulo + 7) / 8 * 8 : N_simd;
constant ushort M_splits [[function_constant(210)]];
constant ushort N_splits [[function_constant(211)]];

constant ushort M_bank_offset [[function_constant(50002)]]; // 220
constant ushort N_bank_offset [[function_constant(50003)]]; // 221
constant ushort K_bank_offset [[function_constant(50004)]]; // 222
constant bool M_bank_offset_defined = is_function_constant_defined(M_bank_offset);
constant bool N_bank_offset_defined = is_function_constant_defined(N_bank_offset);
constant bool K_bank_offset_defined = is_function_constant_defined(K_bank_offset);

constant ushort M_group = M_simd * M_splits;
constant ushort N_group = N_simd * N_splits;
constant ushort M_block_dim = M_group + (M_bank_offset_defined ? M_bank_offset : 0);
constant ushort N_block_dim = N_group + (N_bank_offset_defined ? N_bank_offset : 0);
constant ushort K_block_dim = K_simd + (K_bank_offset_defined ? K_bank_offset : 0);
constant ushort A_block_leading_dim = (A_trans ? M_block_dim : K_block_dim);
constant ushort B_block_leading_dim = (B_trans ? K_block_dim : N_block_dim);
constant ushort A_block_leading_dim = (A_trans ? M_group : K_simd);
constant ushort B_block_leading_dim = (B_trans ? K_simd : N_group);

// There is no padding for M reads/writes.
// There is no padding for N reads/writes.
Expand All @@ -72,7 +62,7 @@ constant ushort K_simd_padded = (K_simd_unpadded + 7) / 8 * 8;

constant ushort A_sram_length = (M_simd / 8) * 1;
constant ushort B_sram_length = 1 * (N_simd / 8);
constant ushort A_block_length = (A_trans) ? K_simd * M_block_dim : M_group * K_block_dim;
constant ushort A_block_length = M_group * K_simd;

// Threadgroup block must fit entire C accumulator and partial sums.
constant ushort A_sram_offset = 0;
Expand Down
41 changes: 3 additions & 38 deletions Tests/Operations/GEMM.swift
Original file line number Diff line number Diff line change
Expand Up @@ -98,31 +98,7 @@ struct MFA_GEMM: GEMM, MFA_Operation {

let M_group = M_simd * M_splits
let N_group = N_simd * N_splits
var M_block_dim = M_group
var N_block_dim = N_group
var K_block_dim = K_simd
/// Adjusts a block dimension according to its size in bytes modulo 64 and,
/// when padding is applied, stores the padding as a function constant.
///
/// - Parameters:
///   - dim: Block dimension in elements; must be divisible by 8. Grown in
///     place by the padding when the row size in bytes is 0 or 32 modulo 64.
///   - index: Function-constant index that receives the padding (in
///     elements) when padding is applied.
func setBankOffset(_ dim: inout UInt16, index: Int) {
  precondition(dim % 8 == 0, "Original dimension must be divisible by 8.")

  let elementSize = UInt16(dataType.size)
  let strideRemainder = (dim * elementSize) % 64

  switch strideRemainder {
  case 16, 48:
    // Dimension is left untouched and no constant is set.
    return
  case 0, 32:
    // Pad by 16 bytes, expressed in elements of the current data type.
    let paddingBytes: UInt16 = 16
    var paddingElements = paddingBytes / elementSize
    dim += paddingElements
    constants.setConstantValue(&paddingElements, type: .ushort, index: index)
  default:
    fatalError("This should never happen.")
  }
}

// The bank conflict fix makes GEMM slower, unlike attention.
// setBankOffset(&M_block_dim, index: 50002)
// setBankOffset(&N_block_dim, index: 50003)
// setBankOffset(&K_block_dim, index: 50004)


// Satisfy Metal API validation.
#if DEBUG
do {
Expand All @@ -145,19 +121,8 @@ struct MFA_GEMM: GEMM, MFA_Operation {
let function = try! library.makeFunction(
name: name, constantValues: constants)

var A_block_length: UInt16
var B_block_length: UInt16

if parameters.A_trans {
A_block_length = K_simd * M_block_dim
} else {
A_block_length = M_group * K_block_dim
}
if parameters.B_trans {
B_block_length = N_group * K_block_dim
} else {
B_block_length = K_simd * N_block_dim
}
let A_block_length = M_group * K_simd
let B_block_length = K_simd * N_group

var blockElements = A_block_length + B_block_length;
if (pcopy.M % 8 != 0) && (pcopy.N % 8 != 0) {
Expand Down
2 changes: 1 addition & 1 deletion Tests/Test Cases/MFATestCase.swift
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ import Foundation
class MFATestCase {
// Global setting for the precision used in tests.
#if arch(arm64)
typealias Real = Float16
typealias Real = Float32
#else
typealias Real = Float
#endif
Expand Down

0 comments on commit 6a65815

Please sign in to comment.