gemm: add 2-dim tiling
Viviane Potocnik committed Nov 12, 2023
1 parent 33122c0 commit a00c227
Showing 2 changed files with 209 additions and 101 deletions.
139 changes: 79 additions & 60 deletions sw/blas/gemm/src/gemm.h
@@ -26,14 +26,12 @@ dump_uint(index, 9);

void gemm_fp64_baseline(uint32_t M, uint32_t N, uint32_t K, double* A,
uint32_t ldA, uint32_t ta, double* B, uint32_t ldB,
uint32_t tb, double* C, uint32_t ldC, double BETA) {
uint32_t tb, double* C, uint32_t ldC, double* BETA) {
if (!ta && !tb) {
for (uint32_t m = 0; m < M; m++) {
for (uint32_t n = 0; n < N; n++) {
double c0 = BETA * C[m * ldC + n];
double c0 = *BETA * C[m * ldC + n];
for (uint32_t k = 0; k < K; k++) {
// dump_index(k + m * ldA);
// dump_gemm(A[k + m * ldA]);
c0 += A[k + m * ldA] * B[k * ldB + n];
}
C[m * ldC + n] = c0;
@@ -42,7 +40,7 @@ void gemm_fp64_baseline(uint32_t M, uint32_t N, uint32_t K, double* A,
} else if (ta && !tb) {
for (uint32_t m = 0; m < M; m++) {
for (uint32_t n = 0; n < N; n++) {
register double c0 = BETA * C[m * ldC + n];
register double c0 = *BETA * C[m * ldC + n];
for (uint32_t k = 0; k < K; k++) {
c0 += A[k * M * ldA + m * ldA] * B[k * ldB + n];
}
@@ -52,17 +50,20 @@ void gemm_fp64_baseline(uint32_t M, uint32_t N, uint32_t K, double* A,
} else if (!ta && tb) {
for (uint32_t m = 0; m < M; m++) {
for (uint32_t n = 0; n < N; n++) {
register double c0 = BETA * C[m * ldC + n];
register double c0 = *BETA * C[m * ldC + n];
for (uint32_t k = 0; k < K; k++) {
// dump_index(k + m * ldA);
// dump_gemm(A[k + m * ldA]);
c0 += A[k + m * ldA] * B[k + n * ldB];
}
C[m * ldC + n] = c0;
// dump_gemm(C[m * ldC + n]);
}
}
} else {
for (uint32_t m = 0; m < M; m++) {
for (uint32_t n = 0; n < N; n++) {
register double c0 = BETA * C[m * ldC + n];
register double c0 = *BETA * C[m * ldC + n];
for (uint32_t k = 0; k < K; k++) {
c0 += A[k * M * ldA + m * ldA] * B[k + n * ldB];
}
@@ -180,8 +181,21 @@ void gemm_fp64_opt(uint32_t M, uint32_t N, uint32_t K, double* A, uint32_t ldA,
// Unrolling factor of the innermost loop.
// Should be at least as high as the FMA latency
// for maximum utilization
const uint32_t unroll = 8;

// const uint32_t unroll = 8;
const uint32_t unroll = 4;

// A is of size MxK, B is of size KxN, C is of size MxN
// for (uint32_t m = 0; m < M; m++) {
// for (uint32_t n = 0; n < N / unroll; n++) {
// double c0 = BETA * C[m * ldC + n];
// for (uint32_t k = 0; k < K; k++) {
// for (uint32_t j = 0; j < unroll; j++) {
// c0 += A[k + m * ldA] * B[k + (n + j) * ldB];
// }
// }
// C[m * ldC + n] = c0;
// }
// }
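// A minimal scalar sketch of what the SSR/frep kernel below is intended to
// compute, assuming non-transposed, row-major A and B with ldA/ldB/ldC
// counted in elements (one group of `unroll` consecutive columns of C is
// produced per iteration of the n-loop):
//
// for (uint32_t m = 0; m < M; m++) {
//     for (uint32_t n = 0; n < N; n += unroll) {
//         for (uint32_t j = 0; j < unroll; j++) {
//             double acc = BETA * C[m * ldC + n + j];  // 0.0 when BETA is zero
//             for (uint32_t k = 0; k < K; k++) {
//                 acc += A[m * ldA + k] * B[k * ldB + n + j];
//             }
//             C[m * ldC + n + j] = acc;
//         }
//     }
// }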
// SSR strides and bounds only have to be configured
// once in the beginning
if (setup_SSR) {
@@ -197,6 +211,7 @@ void gemm_fp64_opt(uint32_t M, uint32_t N, uint32_t K, double* A, uint32_t ldA,
const uint32_t ssr0_b[4] = {unroll, K, N / unroll, M};
const uint32_t ssr0_i[4] = {0, 8, 0, 8 * ldA};

// A[k + unroll * m * ldA]
snrt_ssr_loop_3d(SNRT_SSR_DM0, ssr0_b[1], ssr0_b[2], ssr0_b[3],
ssr0_i[1], ssr0_i[2], ssr0_i[3]);
snrt_ssr_repeat(SNRT_SSR_DM0, unroll);
@@ -214,6 +229,7 @@ void gemm_fp64_opt(uint32_t M, uint32_t N, uint32_t K, double* A, uint32_t ldA,
const uint32_t ssr1_b[4] = {unroll, K, N / unroll, M};
const uint32_t ssr1_i[4] = {8, 8 * ldB, 8 * unroll, 0};

// B[k + unroll * n * ldB]
snrt_ssr_loop_4d(SNRT_SSR_DM1, ssr1_b[0], ssr1_b[1], ssr1_b[2],
ssr1_b[3], ssr1_i[0], ssr1_i[1], ssr1_i[2],
ssr1_i[3]);
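// With the bounds/strides above (strides in bytes, 8-byte doubles), the two
// streams presumably expand to the following access pattern per frep body:
//   ft0 <- A[m * ldA + k]                 repeated `unroll` times     (DM0)
//   ft1 <- B[k * ldB + n * unroll + j],   j = 0 .. unroll-1           (DM1)
// for m = 0..M-1, n = 0..N/unroll-1, k = 0..K-1, i.e. one A element is
// reused against `unroll` consecutive B columns in each inner iteration.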
@@ -236,31 +252,30 @@ void gemm_fp64_opt(uint32_t M, uint32_t N, uint32_t K, double* A, uint32_t ldA,
c[1] = C[m * ldC + n + 1];
c[2] = C[m * ldC + n + 2];
c[3] = C[m * ldC + n + 3];
c[4] = C[m * ldC + n + 4];
c[5] = C[m * ldC + n + 5];
c[6] = C[m * ldC + n + 6];
c[7] = C[m * ldC + n + 7];
// c[4] = C[m * ldC + n + 4];
// c[5] = C[m * ldC + n + 5];
// c[6] = C[m * ldC + n + 6];
// c[7] = C[m * ldC + n + 7];
} else {
c[0] = 0.0;
c[1] = 0.0;
c[2] = 0.0;
c[3] = 0.0;
c[4] = 0.0;
c[5] = 0.0;
c[6] = 0.0;
c[7] = 0.0;
// c[4] = 0.0;
// c[5] = 0.0;
// c[6] = 0.0;
// c[7] = 0.0;
}

asm volatile(
"frep.o %[n_frep], 8, 0, 0 \n"
"frep.o %[n_frep], 4, 0, 0 \n"
"fmadd.d %[c0], ft0, ft1, %[c0] \n"
"fmadd.d %[c1], ft0, ft1, %[c1] \n"
"fmadd.d %[c2], ft0, ft1, %[c2] \n"
"fmadd.d %[c3], ft0, ft1, %[c3] \n"
"fmadd.d %[c4], ft0, ft1, %[c4] \n"
"fmadd.d %[c5], ft0, ft1, %[c5] \n"
"fmadd.d %[c6], ft0, ft1, %[c6] \n"
"fmadd.d %[c7], ft0, ft1, %[c7] \n"
// "fmadd.d %[c4], ft0, ft1, %[c4] \n"
// "fmadd.d %[c5], ft0, ft1, %[c5] \n"
// "fmadd.d %[c6], ft0, ft1, %[c6] \n"
// "fmadd.d %[c7], ft0, ft1, %[c7] \n"
: [ c0 ] "+f"(c[0]), [ c1 ] "+f"(c[1]), [ c2 ] "+f"(c[2]),
[ c3 ] "+f"(c[3]), [ c4 ] "+f"(c[4]), [ c5 ] "+f"(c[5]),
[ c6 ] "+f"(c[6]), [ c7 ] "+f"(c[7])
Expand All @@ -272,10 +287,10 @@ void gemm_fp64_opt(uint32_t M, uint32_t N, uint32_t K, double* A, uint32_t ldA,
C[m * ldC + n + 1] = c[1];
C[m * ldC + n + 2] = c[2];
C[m * ldC + n + 3] = c[3];
C[m * ldC + n + 4] = c[4];
C[m * ldC + n + 5] = c[5];
C[m * ldC + n + 6] = c[6];
C[m * ldC + n + 7] = c[7];
// C[m * ldC + n + 4] = c[4];
// C[m * ldC + n + 5] = c[5];
// C[m * ldC + n + 6] = c[6];
// C[m * ldC + n + 7] = c[7];
n += unroll;
}
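// Note on the frep body above: "frep.o %[n_frep], 4, 0, 0" makes the FPU
// sequencer repeat the following 4 instructions n_frep + 1 times (here
// presumably n_frep = K - 1), so the four active fmadd.d lines together
// carry out the full reduction over k for the accumulators c0..c3.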

@@ -308,7 +323,8 @@ void gemm_fp32_opt(const uint32_t M, const uint32_t N, const uint32_t K,
// Unrolling factor of the innermost loop.
// Should be at least as high as the FMA latency
// for maximum utilization
const uint32_t unroll = 8;
// const uint32_t unroll = 8;
const uint32_t unroll = 4;

// SSR strides and bounds only have to be configured
// once in the beginning
@@ -351,65 +367,65 @@ void gemm_fp32_opt(const uint32_t M, const uint32_t N, const uint32_t K,
"flw %[reduce_reg1], 4(%[C]) \n"
"flw %[reduce_reg2], 8(%[C]) \n"
"flw %[reduce_reg3], 12(%[C]) \n"
"flw %[reduce_reg4], 16(%[C]) \n"
"flw %[reduce_reg5], 20(%[C]) \n"
"flw %[reduce_reg6], 24(%[C]) \n"
"flw %[reduce_reg7], 28(%[C]) \n"
// "flw %[reduce_reg4], 16(%[C]) \n"
// "flw %[reduce_reg5], 20(%[C]) \n"
// "flw %[reduce_reg6], 24(%[C]) \n"
// "flw %[reduce_reg7], 28(%[C]) \n"
// Pack intermediate results into SIMD vector
"vfcpka.s.s %[reduce_reg0], %[reduce_reg0], %[zero]\n"
"vfcpka.s.s %[reduce_reg1], %[reduce_reg1], %[zero]\n"
"vfcpka.s.s %[reduce_reg2], %[reduce_reg2], %[zero]\n"
"vfcpka.s.s %[reduce_reg3], %[reduce_reg3], %[zero]\n"
"vfcpka.s.s %[reduce_reg4], %[reduce_reg4], %[zero]\n"
"vfcpka.s.s %[reduce_reg5], %[reduce_reg5], %[zero]\n"
"vfcpka.s.s %[reduce_reg6], %[reduce_reg6], %[zero]\n"
"vfcpka.s.s %[reduce_reg7], %[reduce_reg7], %[zero]\n"
// "vfcpka.s.s %[reduce_reg4], %[reduce_reg4], %[zero]\n"
// "vfcpka.s.s %[reduce_reg5], %[reduce_reg5], %[zero]\n"
// "vfcpka.s.s %[reduce_reg6], %[reduce_reg6], %[zero]\n"
// "vfcpka.s.s %[reduce_reg7], %[reduce_reg7], %[zero]\n"
"j 2f \n"
"1: \n"
// Initialize SIMD vector with zeros
"vfcpka.s.s %[reduce_reg0], %[zero], %[zero]\n"
"vfcpka.s.s %[reduce_reg1], %[zero], %[zero]\n"
"vfcpka.s.s %[reduce_reg2], %[zero], %[zero]\n"
"vfcpka.s.s %[reduce_reg3], %[zero], %[zero]\n"
"vfcpka.s.s %[reduce_reg4], %[zero], %[zero]\n"
"vfcpka.s.s %[reduce_reg5], %[zero], %[zero]\n"
"vfcpka.s.s %[reduce_reg6], %[zero], %[zero]\n"
"vfcpka.s.s %[reduce_reg7], %[zero], %[zero]\n"
// "vfcpka.s.s %[reduce_reg4], %[zero], %[zero]\n"
// "vfcpka.s.s %[reduce_reg5], %[zero], %[zero]\n"
// "vfcpka.s.s %[reduce_reg6], %[zero], %[zero]\n"
// "vfcpka.s.s %[reduce_reg7], %[zero], %[zero]\n"

"2: \n"
// Don't accumulate in first iteration
"vfmul.s %[c0], ft1, ft0 \n"
"vfmul.s %[c1], ft1, ft0 \n"
"vfmul.s %[c2], ft1, ft0 \n"
"vfmul.s %[c3], ft1, ft0 \n"
"vfmul.s %[c4], ft1, ft0 \n"
"vfmul.s %[c5], ft1, ft0 \n"
"vfmul.s %[c6], ft1, ft0 \n"
"vfmul.s %[c7], ft1, ft0 \n"
// "vfmul.s %[c4], ft1, ft0 \n"
// "vfmul.s %[c5], ft1, ft0 \n"
// "vfmul.s %[c6], ft1, ft0 \n"
// "vfmul.s %[c7], ft1, ft0 \n"
// frep over MACs
"frep.o %[n_frep], 8, 0, 0 \n"
"frep.o %[n_frep], 4, 0, 0 \n"
"vfmac.s %[c0], ft1, ft0 \n"
"vfmac.s %[c1], ft1, ft0 \n"
"vfmac.s %[c2], ft1, ft0 \n"
"vfmac.s %[c3], ft1, ft0 \n"
"vfmac.s %[c4], ft1, ft0 \n"
"vfmac.s %[c5], ft1, ft0 \n"
"vfmac.s %[c6], ft1, ft0 \n"
"vfmac.s %[c7], ft1, ft0 \n"
// "vfmac.s %[c4], ft1, ft0 \n"
// "vfmac.s %[c5], ft1, ft0 \n"
// "vfmac.s %[c6], ft1, ft0 \n"
// "vfmac.s %[c7], ft1, ft0 \n"
// Sum-reduce vector
"vfsum.s %[reduce_reg0], %[c0] \n"
"vfsum.s %[reduce_reg1], %[c1] \n"
"vfsum.s %[reduce_reg2], %[c2] \n"
"vfsum.s %[reduce_reg3], %[c3] \n"
"vfsum.s %[reduce_reg4], %[c4] \n"
"vfsum.s %[reduce_reg5], %[c5] \n"
"vfsum.s %[reduce_reg6], %[c6] \n"
"vfsum.s %[reduce_reg7], %[c7] \n"
// "vfsum.s %[reduce_reg4], %[c4] \n"
// "vfsum.s %[reduce_reg5], %[c5] \n"
// "vfsum.s %[reduce_reg6], %[c6] \n"
// "vfsum.s %[reduce_reg7], %[c7] \n"
// Pack results together again into vectors
"vfcpka.s.s %[c0], %[reduce_reg0], %[reduce_reg1] \n"
"vfcpka.s.s %[c1], %[reduce_reg2], %[reduce_reg3] \n"
"vfcpka.s.s %[c2], %[reduce_reg4], %[reduce_reg5] \n"
"vfcpka.s.s %[c3], %[reduce_reg6], %[reduce_reg7] \n"
// "vfcpka.s.s %[c2], %[reduce_reg4], %[reduce_reg5] \n"
// "vfcpka.s.s %[c3], %[reduce_reg6], %[reduce_reg7] \n"
: [ c0 ] "+f"(c[0]), [ c1 ] "+f"(c[1]), [ c2 ] "+f"(c[2]),
[ c3 ] "+f"(c[3]), [ c4 ] "+f"(c[4]), [ c5 ] "+f"(c[5]),
[ c6 ] "+f"(c[6]), [ c7 ] "+f"(c[7]),
@@ -428,8 +444,8 @@ void gemm_fp32_opt(const uint32_t M, const uint32_t N, const uint32_t K,
// Store results
((v2f32*)_C)[0] = c[0];
((v2f32*)_C)[1] = c[1];
((v2f32*)_C)[2] = c[2];
((v2f32*)_C)[3] = c[3];
// ((v2f32*)_C)[2] = c[2];
// ((v2f32*)_C)[3] = c[3];

// progress by 2 columns each iteration of the loop
n += unroll * 2;
@@ -1015,9 +1031,12 @@ void gemm(precision_t prec, uint32_t expand, uint32_t setup_ssr,

switch (prec) {
case FP64:
gemm_fp64_opt(frac_m, n, k, (double*)a + offsetA, lda_strided,
transa, (double*)b, ldb, transb, (double*)c + offsetC,
ldc_strided, &beta, setup_ssr);
// gemm_fp64_opt(frac_m, n, k, (double*)a + offsetA, lda_strided,
// transa, (double*)b, ldb, transb, (double*)c + offsetC,
// ldc_strided, &beta, setup_ssr);
gemm_fp64_baseline(frac_m, n, k, (double*)a + offsetA, lda_strided,
transa, (double*)b, ldb, transb,
(double*)c + offsetC, ldc_strided, (double*)&beta);
break;
case FP32:
gemm_fp32_opt(frac_m, n, k, (float*)a + offsetA, lda_strided,
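With BETA now taken by pointer, a minimal standalone call of the updated gemm_fp64_baseline could look like the sketch below; the sizes, leading dimensions, and beta value are illustrative placeholders, not taken from the commit.

// Hypothetical standalone use of the updated baseline kernel.
static double A[16 * 8], B[8 * 4], C[16 * 4];   // filled elsewhere
double beta = 1.0;
// M = 16, N = 4, K = 8, non-transposed operands, row-major layout:
// ldA = K, ldB = N, ldC = N; BETA is now passed by address.
gemm_fp64_baseline(16, 4, 8, A, 8, 0, B, 4, 0, C, 4, &beta);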