Skip to content

Commit

Permalink
sw: add baseline opt as struct flag for GEMM
Browse files Browse the repository at this point in the history
  • Loading branch information
Viviane Potocnik committed Feb 7, 2024
1 parent 8af1e87 commit 7a9b0a0
Show file tree
Hide file tree
Showing 8 changed files with 92 additions and 166 deletions.
2 changes: 2 additions & 0 deletions sw/blas/gemm/data/datagen.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,7 @@ def emit_header(**kwargs):
k_tiles = kwargs['k_tiles']
parallelize_m = kwargs['parallelize_m']
parallelize_k = kwargs['parallelize_k']
baseline = kwargs['baseline']

assert (M % m_tiles) == 0, 'M is not an integer multiple of tile size'
assert (N % n_tiles) == 0, 'N is not an integer multiple of tile size'
Expand Down Expand Up @@ -120,6 +121,7 @@ def emit_header(**kwargs):
data_str += [format_scalar_definition('uint32_t', 'k_tiles', kwargs['k_tiles'])]
data_str += [format_scalar_definition('uint32_t', 'parallelize_m', kwargs['parallelize_m'])]
data_str += [format_scalar_definition('uint32_t', 'parallelize_k', kwargs['parallelize_k'])]
data_str += [format_scalar_definition('uint32_t', 'baseline', baseline)]
data_str += [format_array_definition(C_TYPES[str(kwargs['prec'])], 'a', a.flatten(),
alignment=BURST_ALIGNMENT, section=kwargs['section'])]
data_str += [format_array_definition(C_TYPES[str(kwargs['prec'])], 'b', b.flatten(),
Expand Down
3 changes: 2 additions & 1 deletion sw/blas/gemm/data/params.hjson
Original file line number Diff line number Diff line change
Expand Up @@ -10,12 +10,13 @@
K: 4,
beta: 0,
ta: false,
tb: false, // must be true for SIMD
tb: true, // must be true for SIMD
prec: 64,
expand: 0,
m_tiles: 2 // number of tiles in M dimension
k_tiles: 2 // number of tiles in K dimension
n_tiles: 2 // number of tiles in N dimension
parallelize_k: 0
parallelize_m: 1
baseline: 0
}
124 changes: 62 additions & 62 deletions sw/blas/gemm/src/gemm.h
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,6 @@

#include "snrt.h"

#define BASELINE 1
// Guard to avoid conflict with DNN header file
// TODO: move this definition to Snitch math library to solve problem
#ifndef PRECISION_T
Expand Down Expand Up @@ -332,10 +331,10 @@ void gemm_fp64_opt(uint32_t M, uint32_t N, uint32_t K, double* A, uint32_t ldA,
// "fmadd.d %[c5], ft0, ft1, %[c5] \n"
// "fmadd.d %[c6], ft0, ft1, %[c6] \n"
// "fmadd.d %[c7], ft0, ft1, %[c7] \n"
: [ c0 ] "+f"(c[0]), [ c1 ] "+f"(c[1]), [ c2 ] "+f"(c[2]),
[ c3 ] "+f"(c[3]), [ c4 ] "+f"(c[4]), [ c5 ] "+f"(c[5]),
[ c6 ] "+f"(c[6]), [ c7 ] "+f"(c[7])
: [ n_frep ] "r"(K - 1), [ unroll ] "i"(unroll)
: [c0] "+f"(c[0]), [c1] "+f"(c[1]), [c2] "+f"(c[2]),
[c3] "+f"(c[3]), [c4] "+f"(c[4]), [c5] "+f"(c[5]),
[c6] "+f"(c[6]), [c7] "+f"(c[7])
: [n_frep] "r"(K - 1), [unroll] "i"(unroll)
: "ft0", "ft1", "ft2");

// Store results back
Expand Down Expand Up @@ -481,19 +480,19 @@ void gemm_fp32_opt(const uint32_t M, const uint32_t N, const uint32_t K,
"vfcpka.s.s %[c1], %[reduce_reg2], %[reduce_reg3] \n"
"vfcpka.s.s %[c2], %[reduce_reg4], %[reduce_reg5] \n"
"vfcpka.s.s %[c3], %[reduce_reg6], %[reduce_reg7] \n"
: [ c0 ] "+f"(c[0]), [ c1 ] "+f"(c[1]), [ c2 ] "+f"(c[2]),
[ c3 ] "+f"(c[3]), [ c4 ] "+f"(c[4]), [ c5 ] "+f"(c[5]),
[ c6 ] "+f"(c[6]), [ c7 ] "+f"(c[7]),
[ reduce_reg0 ] "+f"(reduce_reg[0]),
[ reduce_reg1 ] "+f"(reduce_reg[1]),
[ reduce_reg2 ] "+f"(reduce_reg[2]),
[ reduce_reg3 ] "+f"(reduce_reg[3]),
[ reduce_reg4 ] "+f"(reduce_reg[4]),
[ reduce_reg5 ] "+f"(reduce_reg[5]),
[ reduce_reg6 ] "+f"(reduce_reg[6]),
[ reduce_reg7 ] "+f"(reduce_reg[7])
: [ C ] "r"(_C), [ zero ] "f"(zero), [ n_frep ] "r"(n_frep - 1),
[ unroll ] "i"(unroll), [ BETA ] "r"(BETA)
: [c0] "+f"(c[0]), [c1] "+f"(c[1]), [c2] "+f"(c[2]),
[c3] "+f"(c[3]), [c4] "+f"(c[4]), [c5] "+f"(c[5]),
[c6] "+f"(c[6]), [c7] "+f"(c[7]),
[reduce_reg0] "+f"(reduce_reg[0]),
[reduce_reg1] "+f"(reduce_reg[1]),
[reduce_reg2] "+f"(reduce_reg[2]),
[reduce_reg3] "+f"(reduce_reg[3]),
[reduce_reg4] "+f"(reduce_reg[4]),
[reduce_reg5] "+f"(reduce_reg[5]),
[reduce_reg6] "+f"(reduce_reg[6]),
[reduce_reg7] "+f"(reduce_reg[7])
: [C] "r"(_C), [zero] "f"(zero), [n_frep] "r"(n_frep - 1),
[unroll] "i"(unroll), [BETA] "r"(BETA)
: "ft0", "ft1", "ft2");

// Store results
Expand Down Expand Up @@ -669,19 +668,19 @@ void gemm_fp16_opt(uint32_t M, uint32_t N, uint32_t K, __fp16* A, uint32_t ldA,
"vfcpkb.h.s %[c0], %[c2], %[c3] \n"
"vfcpka.h.s %[c1], %[c4], %[c5] \n"
"vfcpkb.h.s %[c1], %[c6], %[c7] \n"
: [ c0 ] "+f"(c[0]), [ c1 ] "+f"(c[1]), [ c2 ] "+f"(c[2]),
[ c3 ] "+f"(c[3]), [ c4 ] "+f"(c[4]), [ c5 ] "+f"(c[5]),
[ c6 ] "+f"(c[6]), [ c7 ] "+f"(c[7]), [ beta ] "=r"(beta),
[ reduce_reg0 ] "+f"(reduce_reg[0]),
[ reduce_reg1 ] "+f"(reduce_reg[1]),
[ reduce_reg2 ] "+f"(reduce_reg[2]),
[ reduce_reg3 ] "+f"(reduce_reg[3]),
[ reduce_reg4 ] "+f"(reduce_reg[4]),
[ reduce_reg5 ] "+f"(reduce_reg[5]),
[ reduce_reg6 ] "+f"(reduce_reg[6]),
[ reduce_reg7 ] "+f"(reduce_reg[7])
: [ C ] "r"(_C), [ zero ] "f"(zero), [ n_frep ] "r"(n_frep),
[ BETA ] "r"(BETA)
: [c0] "+f"(c[0]), [c1] "+f"(c[1]), [c2] "+f"(c[2]),
[c3] "+f"(c[3]), [c4] "+f"(c[4]), [c5] "+f"(c[5]),
[c6] "+f"(c[6]), [c7] "+f"(c[7]), [beta] "=r"(beta),
[reduce_reg0] "+f"(reduce_reg[0]),
[reduce_reg1] "+f"(reduce_reg[1]),
[reduce_reg2] "+f"(reduce_reg[2]),
[reduce_reg3] "+f"(reduce_reg[3]),
[reduce_reg4] "+f"(reduce_reg[4]),
[reduce_reg5] "+f"(reduce_reg[5]),
[reduce_reg6] "+f"(reduce_reg[6]),
[reduce_reg7] "+f"(reduce_reg[7])
: [C] "r"(_C), [zero] "f"(zero), [n_frep] "r"(n_frep),
[BETA] "r"(BETA)
: "ft0", "ft1", "ft2");

// Store results back
Expand Down Expand Up @@ -833,19 +832,19 @@ void gemm_fp16_ex_opt(uint32_t M, uint32_t N, uint32_t K, __fp16* A,
"vfcpkb.h.s %[c0], %[reduce_reg2], %[reduce_reg3] \n"
"vfcpka.h.s %[c1], %[reduce_reg4], %[reduce_reg5] \n"
"vfcpkb.h.s %[c1], %[reduce_reg6], %[reduce_reg7] \n"
: [ c0 ] "+f"(c[0]), [ c1 ] "+f"(c[1]), [ c2 ] "+f"(c[2]),
[ c3 ] "+f"(c[3]), [ c4 ] "+f"(c[4]), [ c5 ] "+f"(c[5]),
[ c6 ] "+f"(c[6]), [ c7 ] "+f"(c[7]), [ beta ] "=r"(beta),
[ reduce_reg0 ] "+f"(reduce_reg[0]),
[ reduce_reg1 ] "+f"(reduce_reg[1]),
[ reduce_reg2 ] "+f"(reduce_reg[2]),
[ reduce_reg3 ] "+f"(reduce_reg[3]),
[ reduce_reg4 ] "+f"(reduce_reg[4]),
[ reduce_reg5 ] "+f"(reduce_reg[5]),
[ reduce_reg6 ] "+f"(reduce_reg[6]),
[ reduce_reg7 ] "+f"(reduce_reg[7])
: [ C ] "r"(_C), [ zero ] "f"(zero), [ n_frep ] "r"(n_frep),
[ unroll ] "i"(unroll), [ BETA ] "r"(BETA)
: [c0] "+f"(c[0]), [c1] "+f"(c[1]), [c2] "+f"(c[2]),
[c3] "+f"(c[3]), [c4] "+f"(c[4]), [c5] "+f"(c[5]),
[c6] "+f"(c[6]), [c7] "+f"(c[7]), [beta] "=r"(beta),
[reduce_reg0] "+f"(reduce_reg[0]),
[reduce_reg1] "+f"(reduce_reg[1]),
[reduce_reg2] "+f"(reduce_reg[2]),
[reduce_reg3] "+f"(reduce_reg[3]),
[reduce_reg4] "+f"(reduce_reg[4]),
[reduce_reg5] "+f"(reduce_reg[5]),
[reduce_reg6] "+f"(reduce_reg[6]),
[reduce_reg7] "+f"(reduce_reg[7])
: [C] "r"(_C), [zero] "f"(zero), [n_frep] "r"(n_frep),
[unroll] "i"(unroll), [BETA] "r"(BETA)
: "ft0", "ft1", "ft2");

// Store results back
Expand Down Expand Up @@ -1024,19 +1023,19 @@ void gemm_fp8_ex_opt(uint32_t M, uint32_t N, uint32_t K, char* A, uint32_t ldA,
// "vfcpkb.b.s %[c0], %[reduce_reg2], %[reduce_reg3] \n"
// "vfcpkc.b.s %[c0], %[reduce_reg4], %[reduce_reg5] \n"
// "vfcpkd.b.s %[c0], %[reduce_reg6], %[reduce_reg7] \n"
: [ c0 ] "+f"(c[0]), [ c1 ] "+f"(c[1]), [ c2 ] "+f"(c[2]),
[ c3 ] "+f"(c[3]), [ c4 ] "+f"(c[4]), [ c5 ] "+f"(c[5]),
[ c6 ] "+f"(c[6]), [ c7 ] "+f"(c[7]), [ beta ] "=r"(beta),
[ reduce_reg0 ] "+f"(reduce_reg[0]),
[ reduce_reg1 ] "+f"(reduce_reg[1]),
[ reduce_reg2 ] "+f"(reduce_reg[2]),
[ reduce_reg3 ] "+f"(reduce_reg[3]),
[ reduce_reg4 ] "+f"(reduce_reg[4]),
[ reduce_reg5 ] "+f"(reduce_reg[5]),
[ reduce_reg6 ] "+f"(reduce_reg[6]),
[ reduce_reg7 ] "+f"(reduce_reg[7])
: [ C ] "r"(_C), [ n_frep ] "r"(n_frep), [ BETA ] "r"(BETA),
[ unroll ] "i"(unroll), [ zero ] "f"(zero)
: [c0] "+f"(c[0]), [c1] "+f"(c[1]), [c2] "+f"(c[2]),
[c3] "+f"(c[3]), [c4] "+f"(c[4]), [c5] "+f"(c[5]),
[c6] "+f"(c[6]), [c7] "+f"(c[7]), [beta] "=r"(beta),
[reduce_reg0] "+f"(reduce_reg[0]),
[reduce_reg1] "+f"(reduce_reg[1]),
[reduce_reg2] "+f"(reduce_reg[2]),
[reduce_reg3] "+f"(reduce_reg[3]),
[reduce_reg4] "+f"(reduce_reg[4]),
[reduce_reg5] "+f"(reduce_reg[5]),
[reduce_reg6] "+f"(reduce_reg[6]),
[reduce_reg7] "+f"(reduce_reg[7])
: [C] "r"(_C), [n_frep] "r"(n_frep), [BETA] "r"(BETA),
[unroll] "i"(unroll), [zero] "f"(zero)
: "ft0", "ft1", "ft2");

// Store results back
Expand Down Expand Up @@ -1071,7 +1070,8 @@ void gemm_fp8_ex_opt(uint32_t M, uint32_t N, uint32_t K, char* A, uint32_t ldA,
void sc_st_gemm(precision_t prec, uint32_t expand, uint32_t setup_ssr,
uint32_t transa, uint32_t transb, uint32_t m, uint32_t n,
uint32_t k, double alpha, void* a, uint32_t lda, void* b,
uint32_t ldb, uint32_t beta, void* c, uint32_t ldc) {
uint32_t ldb, uint32_t beta, void* c, uint32_t ldc,
uint32_t baseline) {
if (snrt_is_compute_core()) {
const uint32_t compute_num = snrt_cluster_compute_core_num();
const uint32_t compute_id = snrt_cluster_core_idx();
Expand All @@ -1089,7 +1089,7 @@ void sc_st_gemm(precision_t prec, uint32_t expand, uint32_t setup_ssr,

switch (prec) {
case FP64:
if (BASELINE == 1) {
if (baseline == 1) {
gemm_fp64_baseline(frac_m, n, k, (double*)a + offsetA,
lda_strided, transa, (double*)b, ldb,
transb, (double*)c + offsetC,
Expand All @@ -1102,7 +1102,7 @@ void sc_st_gemm(precision_t prec, uint32_t expand, uint32_t setup_ssr,
}
break;
case FP32:
if (BASELINE == 1) {
if (baseline == 1) {
gemm_fp32_baseline(frac_m, n, k, (float*)a + offsetA,
lda_strided, transa, (float*)b, ldb,
transb, (float*)c + offsetC, ldc_strided,
Expand Down Expand Up @@ -1152,7 +1152,7 @@ int gemm(precision_t prec, uint32_t expand, uint32_t setup_ssr,
uint32_t n_tiles, uint32_t k_tiles, uint32_t load_a, uint32_t load_b,
uint32_t load_c, uint32_t transa, uint32_t transb, uint32_t m,
uint32_t n, uint32_t k, double alpha, void* a, void* b, uint32_t beta,
void* c) {
void* c, uint32_t baseline) {
// Calculate tile sizes
uint32_t frac_m = m / m_tiles;
uint32_t frac_n = n / n_tiles;
Expand Down Expand Up @@ -1264,7 +1264,7 @@ int gemm(precision_t prec, uint32_t expand, uint32_t setup_ssr,

sc_st_gemm(prec, expand, setup_ssr, transa, transb, frac_m,
frac_n, frac_k, 1, local_a, lda, local_b, ldb,
beta_k, local_c_partial, ldc);
beta_k, local_c_partial, ldc, baseline);

uint32_t end_cycle = snrt_mcycle();
}
Expand Down
19 changes: 16 additions & 3 deletions sw/blas/gemm/src/main.c
Original file line number Diff line number Diff line change
Expand Up @@ -14,9 +14,9 @@
#include "snrt.h"

int main() {
int retcode =
gemm(dtype_size, expand, 1, parallelize_m, parallelize_k, m_tiles,
n_tiles, k_tiles, 1, 1, 1, TA, TB, M, N, K, 1, a, b, BETA, c);
int retcode = gemm(dtype_size, expand, 1, parallelize_m, parallelize_k,
m_tiles, n_tiles, k_tiles, 1, 1, 1, TA, TB, M, N, K, 1,
a, b, BETA, c, baseline);

snrt_cluster_hw_barrier();

Expand All @@ -26,6 +26,19 @@ int main() {
void *local_a, *local_b, *local_c;
void *remote_a, *remote_b, *remote_c;

// Calculate size and pointers for each cluster
uint32_t frac_m = M / snrt_cluster_num();
uint32_t frac_a = frac_m * K;
uint32_t frac_c = frac_m * N;
uint32_t size_frac_a = frac_a * dtype_size;
uint32_t size_b = K * N * dtype_size;
uint32_t size_frac_c = frac_c * dtype_size;
uint32_t offset_a = frac_a * snrt_cluster_idx();
uint32_t offset_c = frac_c * snrt_cluster_idx();
remote_a = a + offset_a;
remote_b = b;
remote_c = c + offset_c;

// Allocate space in TCDM
local_a = (void *)snrt_l1_next();
local_b = local_a + size_frac_a;
Expand Down
1 change: 1 addition & 0 deletions sw/dnn/flashattention_2/data/params.hjson
Original file line number Diff line number Diff line change
Expand Up @@ -8,4 +8,5 @@
B_r: 16
B_c: 8
dtype: FP64
baseline: 0
}
Loading

0 comments on commit 7a9b0a0

Please sign in to comment.