Skip to content

Commit

Permalink
gemm_v2: Optimized FP64 baseline
Browse files Browse the repository at this point in the history
  • Loading branch information
colluca committed Aug 23, 2024
1 parent 8bdc141 commit 46233f1
Show file tree
Hide file tree
Showing 3 changed files with 76 additions and 3 deletions.
2 changes: 1 addition & 1 deletion sw/blas/gemm_v2/data/params.json
Original file line number Diff line number Diff line change
Expand Up @@ -12,5 +12,5 @@
K: 48,
alpha: 1,
beta: 0,
gemm_fp: "gemm_fp64_opt"
gemm_fp: "gemm_fp64_baseline"
}
2 changes: 0 additions & 2 deletions sw/blas/gemm_v2/scripts/datagen.py
Original file line number Diff line number Diff line change
Expand Up @@ -78,8 +78,6 @@ def validate_config(self, gemm_fp,
'N dimension of tile size must be greater or equal to the unrolling factor (8) ' \
'when using optimized kernels'
assert beta == 0 or beta == 1, 'Only values of 0 or 1 supported for beta'
assert not (dtype == 8 and impl == "baseline"), 'No baseline implemented' \
' for FP64 (switch to NAIVE)'
assert not (((dtype == 8) or (dtype == 4)) and impl == "opt_ex"), \
'Expanding GEMM kernels' \
' not supported for FP64 and FP32'
Expand Down
75 changes: 75 additions & 0 deletions sw/blas/gemm_v2/src/gemm_fp64.h
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,81 @@ void gemm_fp64_naive(uint32_t M, uint32_t N, uint32_t K, void* A_p,
}
}

void gemm_fp64_baseline(uint32_t M, uint32_t N, uint32_t K, void* A_p,
uint32_t ldA, uint32_t ta, void* B_p, uint32_t ldB,
uint32_t tb, void* C_p, uint32_t ldC, uint32_t BETA,
uint32_t setup_SSR) {
double* A = (double*)A_p;
double* B = (double*)B_p;
double* C = (double*)C_p;

// Unrolling factors
// Note: changes must be reflected in the inline assembly code
// and datagen script
const uint32_t unroll1 = 4;
const uint32_t unroll0 = 4;

if (!ta && tb) {
for (uint32_t m = 0; m < M; m++) {
for (uint32_t n = 0; n < N; n += unroll1) {

double acc[4];
acc[0] = multiply_opt(C[m * ldC + n + 0], BETA);
acc[1] = multiply_opt(C[m * ldC + n + 1], BETA);
acc[2] = multiply_opt(C[m * ldC + n + 2], BETA);
acc[3] = multiply_opt(C[m * ldC + n + 3], BETA);

for (uint32_t k = 0; k < K; k += unroll0) {
asm volatile(
"fmadd.d %[acc0], %[A0], %[B0], %[acc0] \n"
"fmadd.d %[acc1], %[A0], %[B1], %[acc1] \n"
"fmadd.d %[acc2], %[A0], %[B2], %[acc2] \n"
"fmadd.d %[acc3], %[A0], %[B3], %[acc3] \n"
"fmadd.d %[acc0], %[A1], %[B4], %[acc0] \n"
"fmadd.d %[acc1], %[A1], %[B5], %[acc1] \n"
"fmadd.d %[acc2], %[A1], %[B6], %[acc2] \n"
"fmadd.d %[acc3], %[A1], %[B7], %[acc3] \n"
"fmadd.d %[acc0], %[A2], %[B8], %[acc0] \n"
"fmadd.d %[acc1], %[A2], %[B9], %[acc1] \n"
"fmadd.d %[acc2], %[A2], %[B10], %[acc2] \n"
"fmadd.d %[acc3], %[A2], %[B11], %[acc3] \n"
"fmadd.d %[acc0], %[A3], %[B12], %[acc0] \n"
"fmadd.d %[acc1], %[A3], %[B13], %[acc1] \n"
"fmadd.d %[acc2], %[A3], %[B14], %[acc2] \n"
"fmadd.d %[acc3], %[A3], %[B15], %[acc3] \n"
: [ acc0 ] "+f"(acc[0]), [ acc1 ] "+f"(acc[1]),
[ acc2 ] "+f"(acc[2]), [ acc3 ] "+f"(acc[3])
:
[ A0 ] "f"(A[m * ldA + k + 0]), [ A1 ] "f"(A[m * ldA + k + 1]),
[ A2 ] "f"(A[m * ldA + k + 2]), [ A3 ] "f"(A[m * ldA + k + 3]),
[ B0 ] "f"(B[(n + 0) * ldB + k]),
[ B1 ] "f"(B[(n + 1) * ldB + k]),
[ B2 ] "f"(B[(n + 2) * ldB + k]),
[ B3 ] "f"(B[(n + 3) * ldB + k]),
[ B4 ] "f"(B[(n + 0) * ldB + k + 1]),
[ B5 ] "f"(B[(n + 1) * ldB + k + 1]),
[ B6 ] "f"(B[(n + 2) * ldB + k + 1]),
[ B7 ] "f"(B[(n + 3) * ldB + k + 1]),
[ B8 ] "f"(B[(n + 0) * ldB + k + 2]),
[ B9 ] "f"(B[(n + 1) * ldB + k + 2]),
[ B10 ] "f"(B[(n + 2) * ldB + k + 2]),
[ B11 ] "f"(B[(n + 3) * ldB + k + 2]),
[ B12 ] "f"(B[(n + 0) * ldB + k + 3]),
[ B13 ] "f"(B[(n + 1) * ldB + k + 3]),
[ B14 ] "f"(B[(n + 2) * ldB + k + 3]),
[ B15 ] "f"(B[(n + 3) * ldB + k + 3])
:);
}

C[m * ldC + n + 0] = acc[0];
C[m * ldC + n + 1] = acc[1];
C[m * ldC + n + 2] = acc[2];
C[m * ldC + n + 3] = acc[3];
}
}
}
}

void gemm_fp64_opt(uint32_t M, uint32_t N, uint32_t K, void* A_p, uint32_t ldA,
uint32_t ta, void* B_p, uint32_t ldB, uint32_t tb, void* C_p,
uint32_t ldC, uint32_t BETA, uint32_t setup_SSR) {
Expand Down

0 comments on commit 46233f1

Please sign in to comment.