sw: Add GEMM baselines without SSR & frep extensions (#83)
* gemm: Add fp32 SIMD baseline kernel without SSR and frep extensions

* gemm: Remove clean up of columns

This is not needed for non-unrolled kernels

* gemm: Implement fp16 baseline

* gemm: Implement fp8 baseline

* gemm: Try to verify fp8

* lint: Python sources

* gemm: Parametrize type of kernels as baseline/optimized + align gemm signatures

* gemm: Rename naive baseline kernels to `naive`

* gemm: Clean up comments

* gemm: Format code

* gemm: Fix function arguments

* gemm: Formatting

* gemm: Format back with `clang-format-10`

* gemm: Remove deprecated verify function

* gemm: Fix error threshold

* verify: Remove unused constant numpy types
fischeti authored Feb 15, 2024
1 parent b15db73 commit 2f532e9
Showing 2 changed files with 204 additions and 67 deletions.
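For orientation, the new `naive` kernels implement the textbook update C = A * B + BETA * C without SSR streams or FREP loops. The sketch below illustrates the non-transposed fp64 path; the committed loop bodies are only partially visible in the hunks that follow, so treat the helper (and its name, gemm_fp64_naive_sketch) as an illustration under those assumptions, not as the committed code.

#include <stdint.h>

// Minimal sketch of the element-wise GEMM the naive kernels compute,
// assuming non-transposed, row-major A and B. Illustrative only.
void gemm_fp64_naive_sketch(uint32_t M, uint32_t N, uint32_t K, double* A,
                            uint32_t ldA, double* B, uint32_t ldB, double* C,
                            uint32_t ldC, double BETA) {
    for (uint32_t m = 0; m < M; m++) {
        for (uint32_t n = 0; n < N; n++) {
            double c = BETA * C[m * ldC + n];  // scale the previous result
            for (uint32_t k = 0; k < K; k++) {
                c += A[m * ldA + k] * B[k * ldB + n];  // accumulate A*B
            }
            C[m * ldC + n] = c;
        }
    }
}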
253 changes: 194 additions & 59 deletions sw/blas/gemm/src/gemm.h
@@ -35,9 +35,9 @@ static inline double multiply_opt(double multiplicand, double multiplier) {
return 0;
}

void gemm_fp32_baseline(uint32_t M, uint32_t N, uint32_t K, float* A,
uint32_t ldA, uint32_t ta, float* B, uint32_t ldB,
uint32_t tb, float* C, uint32_t ldC, float BETA) {
void gemm_fp32_naive(uint32_t M, uint32_t N, uint32_t K, float* A, uint32_t ldA,
uint32_t ta, float* B, uint32_t ldB, uint32_t tb, float* C,
uint32_t ldC, float BETA) {
if (!ta && !tb) {
for (uint32_t m = 0; m < M; m++) {
for (uint32_t n = 0; n < N; n++) {
@@ -81,9 +81,9 @@ void gemm_fp32_baseline(uint32_t M, uint32_t N, uint32_t K, float* A,
}
}

void gemm_fp64_baseline(uint32_t M, uint32_t N, uint32_t K, double* A,
uint32_t ldA, uint32_t ta, double* B, uint32_t ldB,
uint32_t tb, double* C, uint32_t ldC, double BETA) {
void gemm_fp64_naive(uint32_t M, uint32_t N, uint32_t K, double* A,
uint32_t ldA, uint32_t ta, double* B, uint32_t ldB,
uint32_t tb, double* C, uint32_t ldC, double BETA) {
if (!ta && !tb) {
for (uint32_t m = 0; m < M; m++) {
for (uint32_t n = 0; n < N; n++) {
@@ -127,35 +127,16 @@ void gemm_fp64_baseline(uint32_t M, uint32_t N, uint32_t K, double* A,
}
}

/* params:
* M: number of rows of A and C
* N: number of columns of B and C
* K: number of columns of A and rows of B
* A: pointer to matrix A
* ldA: row stride of A
* ta: transpose A
* B: pointer to matrix B
* ldB: row stride of B
* tb: transpose B
* C: pointer to matrix C
* ldC: row stride of C
* BETA: scalar beta
* A is MxK, B is KxN, C is MxN
*/
void gemm_fp32_baseline_unrolled(uint32_t M, uint32_t N, uint32_t K, float* A,
uint32_t ldA, uint32_t ta, float* B,
uint32_t ldB, uint32_t tb, float* C,
uint32_t ldC, float BETA) {
// float c0, c1, c2, c3 = 0;
void gemm_fp32_naive_unrolled(uint32_t M, uint32_t N, uint32_t K, float* A,
uint32_t ldA, uint32_t ta, float* B, uint32_t ldB,
uint32_t tb, float* C, uint32_t ldC, float BETA) {
float c0 = 0.0f;
float c1 = 0.0f;
float c2 = 0.0f;
float c3 = 0.0f;
if (!ta && !tb) {
for (uint32_t m = 0; m < M; m++) {
for (uint32_t n = 0; n < N; n++) {
// register float c0 = BETA * C[m * ldC + n];
// c0, c1, c2, c3 = 0;
if (BETA == 0.0f) {
c0 = 0.0f;
} else {
@@ -176,7 +157,6 @@ void gemm_fp32_baseline_unrolled(uint32_t M, uint32_t N, uint32_t K, float* A,
} else if (ta && !tb) {
for (uint32_t m = 0; m < M; m++) {
for (uint32_t n = 0; n < N; n++) {
// register float c0 = BETA * C[m * ldC + n];
if (BETA == 0.0f) {
c0 = 0.0f;
} else {
@@ -197,7 +177,6 @@ void gemm_fp32_baseline_unrolled(uint32_t M, uint32_t N, uint32_t K, float* A,
} else if (!ta && tb) {
for (uint32_t m = 0; m < M; m++) {
for (uint32_t n = 0; n < N; n++) {
// register float c0 = BETA * C[m * ldC + n];
if (BETA == 0.0f) {
c0 = 0.0f;
} else {
@@ -207,13 +186,11 @@ void gemm_fp32_baseline_unrolled(uint32_t M, uint32_t N, uint32_t K, float* A,
c2 = 0.0f;
c3 = 0.0f;
for (uint32_t k = 0; k < K; k += 4) {
// c0 += A[k + m * ldA] * B[k + n * ldB];
c0 += A[(k + 0) + m * ldA] * B[(k + 0) + n * ldB];
c1 += A[(k + 1) + m * ldA] * B[(k + 1) + n * ldB];
c2 += A[(k + 2) + m * ldA] * B[(k + 2) + n * ldB];
c3 += A[(k + 3) + m * ldA] * B[(k + 3) + n * ldB];
}
// C[m * ldC + n] = c0;
C[m * ldC + n] = c0 + c1 + c2 + c3;
}
}
@@ -238,18 +215,6 @@ void gemm_fp64_opt(uint32_t M, uint32_t N, uint32_t K, double* A, uint32_t ldA,
// for maximum utilization
const uint32_t unroll = 8;

// A is of size MxK, B is of size KxN, C is of size MxN
// for (uint32_t m = 0; m < M; m++) {
// for (uint32_t n = 0; n < N / unroll; n++) {
// double c0 = BETA * C[m * ldC + n];
// for (uint32_t k = 0; k < K; k++) {
// for (uint32_t j = 0; j < unroll; j++) {
// c0 += A[k + m * ldA] * B[k + (n + j) * ldB];
// }
// }
// C[m * ldC + n] = c0;
// }
// }
// SSR strides and bounds only have to be configured
// once in the beginning
if (setup_SSR) {
@@ -265,7 +230,6 @@ void gemm_fp64_opt(uint32_t M, uint32_t N, uint32_t K, double* A, uint32_t ldA,
const uint32_t ssr0_b[4] = {unroll, K, N / unroll, M};
const uint32_t ssr0_i[4] = {0, 8, 0, 8 * ldA};

// A[k + unroll * m * ldA]
snrt_ssr_loop_3d(SNRT_SSR_DM0, ssr0_b[1], ssr0_b[2], ssr0_b[3],
ssr0_i[1], ssr0_i[2], ssr0_i[3]);
snrt_ssr_repeat(SNRT_SSR_DM0, unroll);
@@ -283,7 +247,6 @@ void gemm_fp64_opt(uint32_t M, uint32_t N, uint32_t K, double* A, uint32_t ldA,
const uint32_t ssr1_b[4] = {unroll, K, N / unroll, M};
const uint32_t ssr1_i[4] = {8, 8 * ldB, 8 * unroll, 0};

// B[k + unroll * n * ldB]
snrt_ssr_loop_4d(SNRT_SSR_DM1, ssr1_b[0], ssr1_b[1], ssr1_b[2],
ssr1_b[3], ssr1_i[0], ssr1_i[1], ssr1_i[2],
ssr1_i[3]);
@@ -370,6 +333,63 @@ void gemm_fp64_opt(uint32_t M, uint32_t N, uint32_t K, double* A, uint32_t ldA,
snrt_ssr_disable();
}

void gemm_fp32_baseline(const uint32_t M, const uint32_t N, const uint32_t K,
float* A, const uint32_t ldA, float* B,
const uint32_t ldB, float* C, const uint32_t ldC,
const uint32_t BETA) {
for (uint32_t m = 0; m < M; m++) {
uint32_t n = 0;
for (; n < N; n++) {
volatile register v2f32 *a_ptr, *b_ptr;
register v2f32 a, b;
volatile float* c_ptr;
const register float zero = 0.0;
double c = 0.0;
v2f32 reduce_reg;

a_ptr = (v2f32*)(&A[m * ldA]);
b_ptr = (v2f32*)(&B[n * ldB]);
c_ptr = &C[m * ldC + n];
// Don't accumulate in first iteration
asm volatile(
"beqz %[BETA], 1f \n"
// Load intermediate results
"flw ft2, 0(%[C]) \n"
"vfcpka.s.s ft2, ft2, %[zero]\n"
// or initialize with zero
"j 2f \n"
"1: \n"
"vfcpka.s.s ft2, %[zero], %[zero]\n"
// Don't accumulate in first iteration
"2: \n"
"fld ft0, 0(%[a_ptr]) \n"
"fld ft1, 0(%[b_ptr]) \n"
"add %[a_ptr], %[a_ptr], 8 \n"
"add %[b_ptr], %[b_ptr], 8 \n"
"vfmul.s ft3, ft0, ft1 \n"
// loop over the MACs
"li t0, 2 \n"
"3: \n"
"fld ft0, 0(%[a_ptr]) \n"
"fld ft1, 0(%[b_ptr]) \n"
"vfmac.s ft3, ft0, ft1 \n"
"add %[a_ptr], %[a_ptr], 8 \n"
"add %[b_ptr], %[b_ptr], 8 \n"
"addi t0, t0, 2 \n"
"blt t0, %[K], 3b \n"
// Sum reduce vector
"vfsum.s ft2, ft3 \n"
// Store results
"fsw ft2, 0(%[C]) \n"
: [ a_ptr ] "+r"(a_ptr), [ b_ptr ] "+r"(b_ptr)
: [ c ] "f"(c), [ reduce_reg ] "f"(reduce_reg),
[ C ] "r"(c_ptr), [ BETA ] "r"(BETA), [ K ] "r"(K),
[ zero ] "f"(zero)
: "ft0", "ft1", "ft2", "ft3", "ft4", "t0");
}
}
}
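
The inline assembly above packs two fp32 elements per 64-bit FP register (vfmul.s/vfmac.s) and reduces them with vfsum.s. Functionally — assuming K is a multiple of 2 (the unrolled prologue implies K of at least 4), B stored so that row n holds the k-elements of output column n, and BETA acting as a 0/1 accumulate flag as the `beqz %[BETA]` check suggests — it matches the plain-C model below. The helper name is made up for illustration.

#include <stdint.h>

// Plain-C model of gemm_fp32_baseline (illustrative only). BETA is treated
// as an accumulate flag; B[n * ldB + k] is assumed to be element (k, n).
void gemm_fp32_baseline_model(uint32_t M, uint32_t N, uint32_t K,
                              const float* A, uint32_t ldA, const float* B,
                              uint32_t ldB, float* C, uint32_t ldC,
                              uint32_t BETA) {
    for (uint32_t m = 0; m < M; m++) {
        for (uint32_t n = 0; n < N; n++) {
            float acc = BETA ? C[m * ldC + n] : 0.0f;
            for (uint32_t k = 0; k < K; k++) {
                acc += A[m * ldA + k] * B[n * ldB + k];
            }
            C[m * ldC + n] = acc;
        }
    }
}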

void gemm_fp32_opt(const uint32_t M, const uint32_t N, const uint32_t K,
float* A, const uint32_t ldA, float* B, const uint32_t ldB,
float* C, const uint32_t ldC, const uint32_t* BETA,
@@ -521,6 +541,61 @@ void gemm_fp32_opt(const uint32_t M, const uint32_t N, const uint32_t K,
snrt_ssr_disable();
}

void gemm_fp16_baseline(uint32_t M, uint32_t N, uint32_t K, __fp16* A,
uint32_t ldA, __fp16* B, uint32_t ldB, __fp16* C,
uint32_t ldC, const uint32_t* BETA,
uint32_t setup_SSR) {
for (uint32_t m = 0; m < M; m++) {
uint32_t n = 0;
for (; n < N; n++) {
volatile register v4f16 *a_ptr, *b_ptr;
register v4f16 a, b;
volatile __fp16* c_ptr;
const register float zero = 0.0;
double c = 0.0;
v4f16 reduce_reg;

a_ptr = (v4f16*)(&A[m * ldA]);
b_ptr = (v4f16*)(&B[n * ldB]);
c_ptr = &C[m * ldC + n];
// Don't accumulate in first iteration
asm volatile(
"lw t0, 0(%[BETA]) \n"
"beqz t0, 1f \n"
// Load intermediate results
"flh ft2, 0(%[C]) \n"
"vfcvt.s.h ft2, ft2\n"
"vfcpka.s.s ft2, ft2, %[zero]\n"
// or initialize with zero
"j 2f \n"
"1: \n"
"vfcpka.s.s ft2, %[zero], %[zero]\n"
"2: \n"
// loop over the MACs
"li t0, 0 \n"
"3: \n"
"fld ft0, 0(%[a_ptr]) \n"
"fld ft1, 0(%[b_ptr]) \n"
"add %[a_ptr], %[a_ptr], 8 \n"
"add %[b_ptr], %[b_ptr], 8 \n"
"vfdotpex.s.h ft2, ft0, ft1 \n"
"addi t0, t0, 4 \n"
"blt t0, %[K], 3b \n"
// Sum reduce vector
"vfcpka.s.s ft3, %[zero], %[zero]\n"
"vfsum.s ft3, ft2 \n"
"vfcvt.h.s ft3, ft3\n"
// Store results
"fsh ft3, 0(%[C]) \n"
: [ a_ptr ] "+r"(a_ptr), [ b_ptr ] "+r"(b_ptr)
: [ c ] "f"(c), [ reduce_reg ] "f"(reduce_reg),
[ C ] "r"(c_ptr), [ BETA ] "r"(BETA), [ K ] "r"(K),
[ zero ] "f"(zero)
: "ft0", "ft1", "ft2", "ft3", "ft4", "t0");
}
}
}
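
As with the fp32 kernel, a rough C model helps read the fp16 assembly: vfdotpex.s.h accumulates four fp16 products per iteration into fp32 lanes, and only the reduced sum is rounded back to fp16 on the store. The sketch below assumes K is a multiple of 4, B stored transposed, a compiler with __fp16 support, and again treats *BETA as a 0/1 accumulate flag; it is a reading aid, not the committed kernel.

#include <stdint.h>

// Scalar model of gemm_fp16_baseline: accumulate in fp32 and round once at
// the end (mirroring vfdotpex.s.h + vfsum.s + vfcvt.h.s). Illustrative only.
void gemm_fp16_baseline_model(uint32_t M, uint32_t N, uint32_t K,
                              const __fp16* A, uint32_t ldA, const __fp16* B,
                              uint32_t ldB, __fp16* C, uint32_t ldC,
                              uint32_t beta) {
    for (uint32_t m = 0; m < M; m++) {
        for (uint32_t n = 0; n < N; n++) {
            float acc = beta ? (float)C[m * ldC + n] : 0.0f;
            for (uint32_t k = 0; k < K; k++) {
                acc += (float)A[m * ldA + k] * (float)B[n * ldB + k];
            }
            C[m * ldC + n] = (__fp16)acc;
        }
    }
}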

void gemm_fp16_opt(uint32_t M, uint32_t N, uint32_t K, __fp16* A, uint32_t ldA,
__fp16* B, uint32_t ldB, __fp16* C, uint32_t ldC,
const uint32_t* BETA, uint32_t setup_SSR) {
@@ -869,6 +944,61 @@ void gemm_fp16_ex_opt(uint32_t M, uint32_t N, uint32_t K, __fp16* A,
snrt_ssr_disable();
}

void gemm_fp8_baseline(uint32_t M, uint32_t N, uint32_t K, char* A,
uint32_t ldA, char* B, uint32_t ldB, char* C,
uint32_t ldC, const uint32_t* BETA, uint32_t setup_SSR) {
for (uint32_t m = 0; m < M; m++) {
uint32_t n = 0;
for (; n < N; n++) {
volatile register v8f8 *a_ptr, *b_ptr;
register v8f8 a, b;
volatile char* c_ptr;
const register float zero = 0.0;
double c = 0.0;
v8f8 reduce_reg;

a_ptr = (v8f8*)(&A[m * ldA]);
b_ptr = (v8f8*)(&B[n * ldB]);
c_ptr = &C[m * ldC + n];
asm volatile(
"lw t0, 0(%[BETA]) \n"
"beqz t0, 1f \n"
// Load intermediate results
"flb ft2, 0(%[C]) \n"
"vfcvt.s.b ft2, ft2\n"
"vfcpka.h.s ft2, ft2, %[zero]\n"
// or initialize with zero
"j 2f \n"
"1: \n"
"vfcpka.s.s ft2, %[zero], %[zero]\n"
"2: \n"
// loop over the MACs
"li t0, 0 \n"
"3: \n"
"fld ft0, 0(%[a_ptr]) \n"
"fld ft1, 0(%[b_ptr]) \n"
"add %[a_ptr], %[a_ptr], 8 \n"
"add %[b_ptr], %[b_ptr], 8 \n"
"vfdotpex.h.b ft2, ft0, ft1 \n"
"addi t0, t0, 8 \n"
"blt t0, %[K], 3b \n"
// Sum reduce vector
"vfcpka.s.s ft3, %[zero], %[zero]\n"
"vfsumex.s.h ft3, ft2 \n"
"vfcpka.s.s ft2, %[zero], %[zero]\n"
"vfsum.s ft2, ft3 \n"
"vfcvt.b.s ft2, ft2\n"
// Store results
"fsb ft2, 0(%[C]) \n"
: [ a_ptr ] "+r"(a_ptr), [ b_ptr ] "+r"(b_ptr)
: [ c ] "f"(c), [ reduce_reg ] "f"(reduce_reg),
[ C ] "r"(c_ptr), [ BETA ] "r"(BETA), [ K ] "r"(K),
[ zero ] "f"(zero)
: "ft0", "ft1", "ft2", "ft3", "ft4", "t0");
}
}
}

void gemm_fp8_ex_opt(uint32_t M, uint32_t N, uint32_t K, char* A, uint32_t ldA,
char* B, uint32_t ldB, char* C, uint32_t ldC,
const uint32_t* BETA, uint32_t setup_SSR) {
@@ -1088,11 +1218,11 @@ void sc_st_gemm(precision_t prec, uint32_t expand, uint32_t setup_ssr,

switch (prec) {
case FP64:
if (baseline == 1) {
gemm_fp64_baseline(frac_m, n, k, (double*)a + offsetA,
lda_strided, transa, (double*)b, ldb,
transb, (double*)c + offsetC,
ldc_strided, (double)beta);
if (baseline) {
gemm_fp64_naive(frac_m, n, k, (double*)a + offsetA,
lda_strided, transa, (double*)b, ldb,
transb, (double*)c + offsetC, ldc_strided,
(double)beta);
} else {
gemm_fp64_opt(frac_m, n, k, (double*)a + offsetA,
lda_strided, transa, (double*)b, ldb, transb,
@@ -1101,11 +1231,10 @@ void sc_st_gemm(precision_t prec, uint32_t expand, uint32_t setup_ssr,
}
break;
case FP32:
if (baseline == 1) {
if (baseline) {
gemm_fp32_baseline(frac_m, n, k, (float*)a + offsetA,
lda_strided, transa, (float*)b, ldb,
transb, (float*)c + offsetC, ldc_strided,
(float)beta);
lda_strided, (float*)b, ldb,
(float*)c + offsetC, ldc_strided, beta);
} else {
gemm_fp32_opt(frac_m, n, k, (float*)a + offsetA,
lda_strided, (float*)b, ldb,
@@ -1128,9 +1257,15 @@ void sc_st_gemm(precision_t prec, uint32_t expand, uint32_t setup_ssr,
}
break;
case FP8:
gemm_fp8_ex_opt(frac_m, n, k, (char*)a + offsetA, lda, (char*)b,
ldb, (char*)c + offsetC, ldc_strided, &beta,
setup_ssr);
if (baseline) {
gemm_fp8_baseline(frac_m, n, k, (char*)a + offsetA, lda,
(char*)b, ldb, (char*)c + offsetC,
ldc_strided, &beta, setup_ssr);
} else {
gemm_fp8_ex_opt(frac_m, n, k, (char*)a + offsetA, lda,
(char*)b, ldb, (char*)c + offsetC,
ldc_strided, &beta, setup_ssr);
}
break;
}
}
@@ -1292,4 +1427,4 @@ int gemm(precision_t prec, uint32_t expand, uint32_t setup_ssr,
}

return 0;
}
}