gemm: Change alpha parameter name to beta
colluca committed Sep 28, 2023
1 parent 5d37ea7 commit eee8ee5
Showing 6 changed files with 53 additions and 53 deletions.
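
In BLAS terms, GEMM computes C := alpha·op(A)·op(B) + beta·C: alpha scales the matrix product, beta scales the existing contents of C. The parameter renamed by this commit is the one that scales C, so its proper BLAS name is beta. A minimal NumPy sketch of the convention (`gemm_ref` is an illustrative name mirroring the updated `golden_model` below):

```python
import numpy as np

# BLAS GEMM convention: C := alpha * op(A) @ op(B) + beta * C
def gemm_ref(alpha, a, b, beta, c):
    return alpha * np.matmul(a, b) + beta * c

a, b, c = np.random.rand(4, 3), np.random.rand(3, 2), np.random.rand(4, 2)
assert np.allclose(gemm_ref(1.0, a, b, 0.0, c), a @ b)      # beta = 0: C is discarded
assert np.allclose(gemm_ref(1.0, a, b, 1.0, c), a @ b + c)  # beta = 1: accumulate onto C
```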
10 changes: 5 additions & 5 deletions sw/blas/gemm/data/datagen.py
@@ -40,8 +40,8 @@
}


- def golden_model(a, b, alpha, c):
-     return np.matmul(a, b) + alpha * c
+ def golden_model(alpha, a, b, beta, c):
+     return alpha * np.matmul(a, b) + beta * c


def emit_header(**kwargs):
@@ -73,15 +73,15 @@ def emit_header(**kwargs):
* (1.0 + mantissa_b.astype(np.double) / (2**2))
_c = ((-1.0)**sign_c.astype(np.double))*(2.0**(exponent_c.astype(np.double)-15.0)) \
* (1.0 + mantissa_c.astype(np.double) / (2**2))
- result = np.matmul(_a, _b) + kwargs['alpha'] * _c
+ result = golden_model(1, _a, _b, kwargs['beta'], _c)
a = sign_a << 7 | exponent_a << FP8_FORMATS['fp8']['mant'] | mantissa_a
b = sign_b << 7 | exponent_b << FP8_FORMATS['fp8']['mant'] | mantissa_b
c = sign_c << 7 | exponent_c << FP8_FORMATS['fp8']['mant'] | mantissa_c
else:
a = np.random.rand(kwargs['M'], kwargs['K']).astype(dtype)
b = np.random.rand(kwargs['K'], kwargs['N']).astype(dtype)
c = np.random.rand(kwargs['M'], kwargs['N']).astype(dtype)
- result = golden_model(a, b, kwargs['alpha'], c)
+ result = golden_model(1, a, b, kwargs['beta'], c)

# Store matrices in transposed form if requested
a = a.T if kwargs['ta'] else a
@@ -93,7 +93,7 @@ def emit_header(**kwargs):
data_str += [format_scalar_definition('uint32_t', 'K', kwargs['K'])]
data_str += [format_scalar_definition('uint32_t', 'TA', int(kwargs['ta']))]
data_str += [format_scalar_definition('uint32_t', 'TB', int(kwargs['tb']))]
- data_str += [format_scalar_definition('uint32_t', 'ALPHA', kwargs['alpha'])]
+ data_str += [format_scalar_definition('uint32_t', 'BETA', kwargs['beta'])]
data_str += [format_scalar_definition('uint32_t', 'dtype_size', kwargs['prec']//8)]
data_str += [format_scalar_definition('uint32_t', 'expand', kwargs['expand'])]
data_str += [format_vector_definition(C_TYPES[str(kwargs['prec'])], 'a', a.flatten())]
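In the fp8 branch above, datagen.py decodes the packed sign/exponent/mantissa fields into doubles before evaluating the golden model. Restated as a stand-alone helper (a sketch; `decode_fp8` and its keyword defaults are illustrative, with the exponent bias of 15 and the 2 mantissa bits taken from the expressions above):

```python
def decode_fp8(sign, exponent, mantissa, mant_bits=2, bias=15):
    # value = (-1)^sign * 2^(exponent - bias) * (1 + mantissa / 2^mant_bits)
    return ((-1.0) ** sign) * (2.0 ** (exponent - bias)) * (1.0 + mantissa / (2 ** mant_bits))

print(decode_fp8(0, 15, 0))  # 1.0: positive sign, unbiased exponent 0, empty mantissa
```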
2 changes: 1 addition & 1 deletion sw/blas/gemm/data/params.hjson
@@ -8,7 +8,7 @@
M: 192,
N: 16,
K: 16,
- alpha: 0,
+ beta: 0,
ta: false,
tb: true, // must be true for SIMD
prec: 64,
84 changes: 42 additions & 42 deletions sw/blas/gemm/src/gemm.h
@@ -23,11 +23,11 @@ typedef char v8f8 __attribute__((vector_size(8)));

void gemm_fp64_baseline(uint32_t M, uint32_t N, uint32_t K, double* A,
uint32_t ldA, uint32_t ta, double* B, uint32_t ldB,
-                         uint32_t tb, double* C, uint32_t ldC, double ALPHA) {
+                         uint32_t tb, double* C, uint32_t ldC, double BETA) {
if (!ta && !tb) {
for (uint32_t m = 0; m < M; m++) {
for (uint32_t n = 0; n < N; n++) {
- register double c0 = ALPHA * C[m * ldC + n];
+ register double c0 = BETA * C[m * ldC + n];
for (uint32_t k = 0; k < K; k++) {
c0 += A[k + m * ldA] * B[k * ldB + n];
}
@@ -37,7 +37,7 @@ void gemm_fp64_baseline(uint32_t M, uint32_t N, uint32_t K, double* A,
} else if (ta && !tb) {
for (uint32_t m = 0; m < M; m++) {
for (uint32_t n = 0; n < N; n++) {
- register double c0 = ALPHA * C[m * ldC + n];
+ register double c0 = BETA * C[m * ldC + n];
for (uint32_t k = 0; k < K; k++) {
c0 += A[k * M * ldA + m * ldA] * B[k * ldB + n];
}
@@ -47,7 +47,7 @@ void gemm_fp64_baseline(uint32_t M, uint32_t N, uint32_t K, double* A,
} else if (!ta && tb) {
for (uint32_t m = 0; m < M; m++) {
for (uint32_t n = 0; n < N; n++) {
- register double c0 = ALPHA * C[m * ldC + n];
+ register double c0 = BETA * C[m * ldC + n];
for (uint32_t k = 0; k < K; k++) {
c0 += A[k + m * ldA] * B[k + n * ldB];
}
@@ -57,7 +57,7 @@ void gemm_fp64_baseline(uint32_t M, uint32_t N, uint32_t K, double* A,
} else {
for (uint32_t m = 0; m < M; m++) {
for (uint32_t n = 0; n < N; n++) {
- register double c0 = ALPHA * C[m * ldC + n];
+ register double c0 = BETA * C[m * ldC + n];
for (uint32_t k = 0; k < K; k++) {
c0 += A[k * M * ldA + m * ldA] * B[k + n * ldB];
}
@@ -69,7 +69,7 @@ void gemm_fp64_baseline(uint32_t M, uint32_t N, uint32_t K, double* A,

void gemm_fp64_opt(uint32_t M, uint32_t N, uint32_t K, double* A, uint32_t ldA,
uint32_t ta, double* B, uint32_t ldB, uint32_t tb, double* C,
-                    uint32_t ldC, const uint32_t* ALPHA, uint32_t setup_SSR) {
+                    uint32_t ldC, const uint32_t* BETA, uint32_t setup_SSR) {
// Unrolling factor of most inner loop.
// Should be at least as high as the FMA delay
// for maximum utilization
@@ -124,7 +124,7 @@ void gemm_fp64_opt(uint32_t M, uint32_t N, uint32_t K, double* A, uint32_t ldA,
double c[unroll];

// Load intermediate result
- if (*ALPHA) {
+ if (*BETA) {
c[0] = C[m * ldC + n + 0];
c[1] = C[m * ldC + n + 1];
c[2] = C[m * ldC + n + 2];
@@ -177,7 +177,7 @@ void gemm_fp64_opt(uint32_t M, uint32_t N, uint32_t K, double* A, uint32_t ldA,

for (; n < N; n++) {
double c;
- if (*ALPHA) {
+ if (*BETA) {
c = C[m * ldC + n];
} else {
c = 0.0;
@@ -196,7 +196,7 @@ void gemm_fp64_opt(uint32_t M, uint32_t N, uint32_t K, double* A, uint32_t ldA,

void gemm_fp32_opt(const uint32_t M, const uint32_t N, const uint32_t K,
float* A, const uint32_t ldA, float* B, const uint32_t ldB,
-                    float* C, const uint32_t ldC, const uint32_t* ALPHA,
+                    float* C, const uint32_t ldC, const uint32_t* BETA,
const uint32_t setup_SSR) {
// Unrolling factor of most inner loop.
// Should be at least as high as the FMA delay
@@ -237,7 +237,7 @@ void gemm_fp32_opt(const uint32_t M, const uint32_t N, const uint32_t K,
v2f32 c[unroll], reduce_reg[unroll];

asm volatile(
- "lw t0, 0(%[ALPHA]) \n"
+ "lw t0, 0(%[BETA]) \n"
"beqz t0, 1f \n"
// Load intermediate results
"flw %[reduce_reg0], 0(%[C]) \n"
@@ -315,7 +315,7 @@ void gemm_fp32_opt(const uint32_t M, const uint32_t N, const uint32_t K,
[ reduce_reg6 ] "+f"(reduce_reg[6]),
[ reduce_reg7 ] "+f"(reduce_reg[7])
: [ C ] "r"(_C), [ zero ] "f"(zero), [ n_frep ] "r"(n_frep - 1),
- [ ALPHA ] "r"(ALPHA)
+ [ BETA ] "r"(BETA)
: "ft0", "ft1", "ft2");

// Store results
@@ -332,7 +332,7 @@ void gemm_fp32_opt(const uint32_t M, const uint32_t N, const uint32_t K,
snrt_ssr_disable();

for (; n < N; n++) {
- float c = (*ALPHA) ? C[m * ldC + n] : 0.0;
+ float c = (*BETA) ? C[m * ldC + n] : 0.0;
for (uint32_t k = 0; k < K; k++) {
c += A[k + m * ldA] * B[k + n * ldB];
}
@@ -347,7 +347,7 @@ void gemm_fp32_opt(const uint32_t M, const uint32_t N, const uint32_t K,

void gemm_fp16_opt(uint32_t M, uint32_t N, uint32_t K, __fp16* A, uint32_t ldA,
__fp16* B, uint32_t ldB, __fp16* C, uint32_t ldC,
-                    const uint32_t* ALPHA, uint32_t setup_SSR) {
+                    const uint32_t* BETA, uint32_t setup_SSR) {
// Unrolling factor of most inner loop.
// Should be at least as high as the FMA delay
// for maximum utilization
@@ -386,11 +386,11 @@ void gemm_fp16_opt(uint32_t M, uint32_t N, uint32_t K, __fp16* A, uint32_t ldA,
const register float zero = 0.0;
v4f16 c[unroll];
v2f32 reduce_reg[unroll];
- uint32_t alpha;
+ uint32_t beta;

asm volatile(
- "lw %[alpha], 0(%[ALPHA]) \n"
- "beqz %[alpha], 1f \n"
+ "lw %[beta], 0(%[BETA]) \n"
+ "beqz %[beta], 1f \n"
// Load intermediate results
"flh %[reduce_reg0], 0(%[C]) \n"
"flh %[reduce_reg1], 2(%[C]) \n"
@@ -493,7 +493,7 @@ void gemm_fp16_opt(uint32_t M, uint32_t N, uint32_t K, __fp16* A, uint32_t ldA,
"vfcpkb.h.s %[c1], %[c6], %[c7] \n"
: [ c0 ] "+f"(c[0]), [ c1 ] "+f"(c[1]), [ c2 ] "+f"(c[2]),
[ c3 ] "+f"(c[3]), [ c4 ] "+f"(c[4]), [ c5 ] "+f"(c[5]),
- [ c6 ] "+f"(c[6]), [ c7 ] "+f"(c[7]), [ alpha ] "=r"(alpha),
+ [ c6 ] "+f"(c[6]), [ c7 ] "+f"(c[7]), [ beta ] "=r"(beta),
[ reduce_reg0 ] "+f"(reduce_reg[0]),
[ reduce_reg1 ] "+f"(reduce_reg[1]),
[ reduce_reg2 ] "+f"(reduce_reg[2]),
@@ -503,7 +503,7 @@ void gemm_fp16_opt(uint32_t M, uint32_t N, uint32_t K, __fp16* A, uint32_t ldA,
[ reduce_reg6 ] "+f"(reduce_reg[6]),
[ reduce_reg7 ] "+f"(reduce_reg[7])
: [ C ] "r"(_C), [ zero ] "f"(zero), [ n_frep ] "r"(n_frep),
- [ ALPHA ] "r"(ALPHA)
+ [ BETA ] "r"(BETA)
: "ft0", "ft1", "ft2");

// Store results back
@@ -516,7 +516,7 @@ void gemm_fp16_opt(uint32_t M, uint32_t N, uint32_t K, __fp16* A, uint32_t ldA,
// snrt_ssr_disable();

// for (; n < N; n++) {
- //     __fp16 c = (*ALPHA) ? C[m * ldC + n] : 0.0;
+ //     __fp16 c = (*BETA) ? C[m * ldC + n] : 0.0;
// for (uint32_t k = 0; k < K; k++) {
// c += A[k + m * ldA] * B[k + n * ldB];
// }
@@ -531,7 +531,7 @@ void gemm_fp16_opt(uint32_t M, uint32_t N, uint32_t K, __fp16* A, uint32_t ldA,

void gemm_fp16_ex_opt(uint32_t M, uint32_t N, uint32_t K, __fp16* A,
uint32_t ldA, __fp16* B, uint32_t ldB, __fp16* C,
-                       uint32_t ldC, const uint32_t* ALPHA, uint32_t setup_SSR) {
+                       uint32_t ldC, const uint32_t* BETA, uint32_t setup_SSR) {
// Unrolling factor of most inner loop.
// Should be at least as high as the FMA delay
// for maximum utilization
@@ -570,11 +570,11 @@ void gemm_fp16_ex_opt(uint32_t M, uint32_t N, uint32_t K, __fp16* A,
const register float zero = 0.0;
v4f16 c[unroll];
v2f32 reduce_reg[unroll];
- uint32_t alpha;
+ uint32_t beta;

asm volatile(
- "lw %[alpha], 0(%[ALPHA]) \n"
- "beqz %[alpha], 1f \n"
+ "lw %[beta], 0(%[BETA]) \n"
+ "beqz %[beta], 1f \n"
"flh %[reduce_reg0], 0(%[C]) \n"
"flh %[reduce_reg1], 2(%[C]) \n"
"flh %[reduce_reg2], 4(%[C]) \n"
@@ -657,7 +657,7 @@ void gemm_fp16_ex_opt(uint32_t M, uint32_t N, uint32_t K, __fp16* A,
"vfcpkb.h.s %[c1], %[reduce_reg6], %[reduce_reg7] \n"
: [ c0 ] "+f"(c[0]), [ c1 ] "+f"(c[1]), [ c2 ] "+f"(c[2]),
[ c3 ] "+f"(c[3]), [ c4 ] "+f"(c[4]), [ c5 ] "+f"(c[5]),
- [ c6 ] "+f"(c[6]), [ c7 ] "+f"(c[7]), [ alpha ] "=r"(alpha),
+ [ c6 ] "+f"(c[6]), [ c7 ] "+f"(c[7]), [ beta ] "=r"(beta),
[ reduce_reg0 ] "+f"(reduce_reg[0]),
[ reduce_reg1 ] "+f"(reduce_reg[1]),
[ reduce_reg2 ] "+f"(reduce_reg[2]),
@@ -667,7 +667,7 @@ void gemm_fp16_ex_opt(uint32_t M, uint32_t N, uint32_t K, __fp16* A,
[ reduce_reg6 ] "+f"(reduce_reg[6]),
[ reduce_reg7 ] "+f"(reduce_reg[7])
: [ C ] "r"(_C), [ zero ] "f"(zero), [ n_frep ] "r"(n_frep),
- [ ALPHA ] "r"(ALPHA)
+ [ BETA ] "r"(BETA)
: "ft0", "ft1", "ft2");

// Store results back
@@ -680,7 +680,7 @@ void gemm_fp16_ex_opt(uint32_t M, uint32_t N, uint32_t K, __fp16* A,
// snrt_ssr_disable();

// for (; n < N; n++) {
- //     __fp16 c = (*ALPHA) ? C[m * ldC + n] : 0.0;
+ //     __fp16 c = (*BETA) ? C[m * ldC + n] : 0.0;
// for (uint32_t k = 0; k < K; k++) {
// c += A[k + m * ldA] * B[k + n * ldB];
// }
@@ -695,9 +695,9 @@ void gemm_fp16_ex_opt(uint32_t M, uint32_t N, uint32_t K, __fp16* A,

void gemm_fp8_ex_opt(uint32_t M, uint32_t N, uint32_t K, char* A, uint32_t ldA,
char* B, uint32_t ldB, char* C, uint32_t ldC,
-                      const uint32_t* ALPHA, uint32_t setup_SSR) {
+                      const uint32_t* BETA, uint32_t setup_SSR) {
// Accumulating currently not implemented
- if (*ALPHA != 0) return;
+ if (*BETA != 0) return;

// Unrolling factor of most inner loop.
// Should be at least as high as the FMA delay
@@ -737,11 +737,11 @@ void gemm_fp8_ex_opt(uint32_t M, uint32_t N, uint32_t K, char* A, uint32_t ldA,
const register float zero = 0.0;
v8f8 c[unroll];
v4f16 reduce_reg[unroll];
- uint32_t alpha;
+ uint32_t beta;

asm volatile(
- "lw %[alpha], 0(%[ALPHA]) \n"
- "beqz %[alpha], 1f \n"
+ "lw %[beta], 0(%[BETA]) \n"
+ "beqz %[beta], 1f \n"
"flb %[reduce_reg0], 0(%[C]) \n"
"flb %[reduce_reg1], 1(%[C]) \n"
"flb %[reduce_reg2], 2(%[C]) \n"
@@ -848,7 +848,7 @@ void gemm_fp8_ex_opt(uint32_t M, uint32_t N, uint32_t K, char* A, uint32_t ldA,
// "vfcpkd.b.s %[c0], %[reduce_reg6], %[reduce_reg7] \n"
: [ c0 ] "+f"(c[0]), [ c1 ] "+f"(c[1]), [ c2 ] "+f"(c[2]),
[ c3 ] "+f"(c[3]), [ c4 ] "+f"(c[4]), [ c5 ] "+f"(c[5]),
- [ c6 ] "+f"(c[6]), [ c7 ] "+f"(c[7]), [ alpha ] "=r"(alpha),
+ [ c6 ] "+f"(c[6]), [ c7 ] "+f"(c[7]), [ beta ] "=r"(beta),
[ reduce_reg0 ] "+f"(reduce_reg[0]),
[ reduce_reg1 ] "+f"(reduce_reg[1]),
[ reduce_reg2 ] "+f"(reduce_reg[2]),
@@ -857,7 +857,7 @@ void gemm_fp8_ex_opt(uint32_t M, uint32_t N, uint32_t K, char* A, uint32_t ldA,
[ reduce_reg5 ] "+f"(reduce_reg[5]),
[ reduce_reg6 ] "+f"(reduce_reg[6]),
[ reduce_reg7 ] "+f"(reduce_reg[7])
- : [ C ] "r"(_C), [ n_frep ] "r"(n_frep), [ ALPHA ] "r"(ALPHA),
+ : [ C ] "r"(_C), [ n_frep ] "r"(n_frep), [ BETA ] "r"(BETA),
[ zero ] "f"(zero)
: "ft0", "ft1", "ft2");

@@ -870,7 +870,7 @@ void gemm_fp8_ex_opt(uint32_t M, uint32_t N, uint32_t K, char* A, uint32_t ldA,
// snrt_ssr_disable();

// for (; n < N; n++) {
- //     char c = (*ALPHA) ? C[m * ldC + n] : 0.0;
+ //     char c = (*BETA) ? C[m * ldC + n] : 0.0;
// for (uint32_t k = 0; k < K; k++) {
// c += A[k + m * ldA] * B[k + n * ldB];
// }
@@ -886,12 +886,12 @@ void gemm_fp8_ex_opt(uint32_t M, uint32_t N, uint32_t K, char* A, uint32_t ldA,
// BLAS compliant GEMM kernel, with some additional arguments at the beginning
// to specify Snitch implementation details. Matrix sizes and pointers are for
// the whole cluster computation
- // TODO: alpha (and beta) should be of floating-point type (same precision as
+ // TODO: beta (and alpha) should be of floating-point type (same precision as
// operands)
void gemm(precision_t prec, uint32_t expand, uint32_t setup_ssr,
uint32_t transa, uint32_t transb, uint32_t m, uint32_t n, uint32_t k,
-           uint32_t alpha, void* a, uint32_t lda, void* b, uint32_t ldb,
-           double beta, void* c, uint32_t ldc) {
+           double alpha, void* a, uint32_t lda, void* b, uint32_t ldb,
+           uint32_t beta, void* c, uint32_t ldc) {
const uint32_t compute_num = snrt_cluster_compute_core_num();
const uint32_t compute_id = snrt_cluster_core_idx();

@@ -910,27 +910,27 @@ void gemm(precision_t prec, uint32_t expand, uint32_t setup_ssr,
case FP64:
gemm_fp64_opt(frac_m, n, k, (double*)a + offsetA, lda_strided,
transa, (double*)b, ldb, transb, (double*)c + offsetC,
-                           ldc_strided, &alpha, setup_ssr);
+                           ldc_strided, &beta, setup_ssr);
break;
case FP32:
gemm_fp32_opt(frac_m, n, k, (float*)a + offsetA, lda_strided,
(float*)b, ldb, (float*)c + offsetC, ldc_strided,
-                           &alpha, setup_ssr);
+                           &beta, setup_ssr);
break;
case FP16:
if (expand) {
gemm_fp16_ex_opt(
frac_m, n, k, (__fp16*)a + offsetA, lda_strided, (__fp16*)b,
-                     ldb, (__fp16*)c + offsetC, ldc_strided, &alpha, setup_ssr);
+                     ldb, (__fp16*)c + offsetC, ldc_strided, &beta, setup_ssr);
} else {
gemm_fp16_opt(frac_m, n, k, (__fp16*)a + offsetA, lda_strided,
(__fp16*)b, ldb, (__fp16*)c + offsetC,
-                               ldc_strided, &alpha, setup_ssr);
+                               ldc_strided, &beta, setup_ssr);
}
break;
case FP8:
gemm_fp8_ex_opt(frac_m, n, k, (char*)a + offsetA, lda, (char*)b,
-                             ldb, (char*)c + offsetC, ldc_strided, &alpha,
+                             ldb, (char*)c + offsetC, ldc_strided, &beta,
setup_ssr);
break;
}
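Note that BETA's semantics differ across these kernels: gemm_fp64_baseline multiplies C by BETA as a true scalar, the optimized kernels only test it against zero (use C as the accumulator if nonzero, start from 0.0 otherwise), and gemm_fp8_ex_opt returns early when BETA is nonzero because accumulation is not implemented there. The TODO above records that beta (and alpha) should eventually become floating-point operands. A NumPy restatement of the optimized kernels' flag behavior (a sketch of the semantics, not of the SSR/FREP implementation):

```python
import numpy as np

# In the optimized kernels, BETA is read as a 0-vs-nonzero flag,
# so any nonzero value behaves like beta == 1.
def gemm_opt_semantics(a, b, beta_flag, c):
    acc = c if beta_flag != 0 else np.zeros_like(c)
    return np.matmul(a, b) + acc

a, b, c = np.random.rand(4, 3), np.random.rand(3, 2), np.random.rand(4, 2)
assert np.allclose(gemm_opt_semantics(a, b, 0, c), a @ b)      # overwrite C
assert np.allclose(gemm_opt_semantics(a, b, 1, c), a @ b + c)  # accumulate onto C
```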
4 changes: 2 additions & 2 deletions sw/blas/gemm/src/main.c
@@ -61,8 +61,8 @@ int main() {
ldb = K;
}

- gemm(dtype_size, expand, setup_ssr, TA, TB, frac_m, N, K, ALPHA,
-      local_a, lda, local_b, ldb, 1, local_c, ldc);
+ gemm(dtype_size, expand, setup_ssr, TA, TB, frac_m, N, K, 1, local_a,
+      lda, local_b, ldb, BETA, local_c, ldc);

uint32_t end_cycle = snrt_mcycle();
}
4 changes: 2 additions & 2 deletions sw/blas/gemm/verify.py
@@ -37,7 +37,7 @@ def main():
a = np.array(bytes_to_doubles(elf.get_symbol_contents('a')))
b = np.array(bytes_to_doubles(elf.get_symbol_contents('b')))
c = np.array(bytes_to_doubles(elf.get_symbol_contents('c')))
- alpha = bytes_to_uint32s(elf.get_symbol_contents('ALPHA'))[0]
+ beta = bytes_to_uint32s(elf.get_symbol_contents('BETA'))[0]
m = bytes_to_uint32s(elf.get_symbol_contents('M'))[0]
n = bytes_to_uint32s(elf.get_symbol_contents('N'))[0]
k = bytes_to_uint32s(elf.get_symbol_contents('K'))[0]
@@ -49,7 +49,7 @@ def main():
c = np.reshape(c, (m, n))

# Verify results
- c_golden = golden_model(a, b, alpha, c).flatten()
+ c_golden = golden_model(1, a, b, beta, c).flatten()

absolute_err = np.absolute(c_golden - c_actual)
fail = np.any(absolute_err > ERR_THRESHOLD)
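verify.py rebuilds the golden result with alpha pinned to 1, mirroring the call in main.c, and fails if any element's absolute error exceeds ERR_THRESHOLD. A sketch of that check (the threshold value below is an assumed placeholder; the actual constant is defined elsewhere in verify.py):

```python
import numpy as np

ERR_THRESHOLD = 1e-10  # placeholder; verify.py defines the real value

def check(c_actual, c_golden, threshold=ERR_THRESHOLD):
    absolute_err = np.absolute(c_golden - c_actual)
    return int(np.any(absolute_err > threshold))  # nonzero exit code on failure

print(check(np.array([1.0, 2.0]), np.array([1.0, 2.0])))  # 0: pass
```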
2 changes: 1 addition & 1 deletion util/sim/data_utils.py
@@ -9,7 +9,7 @@


def emit_license():
- s = (f"// Copyright {datetime.now().year} ETH Zurich and University of Bologna."
+ s = (f"// Copyright {datetime.now().year} ETH Zurich and University of Bologna.\n"
f"// Licensed under the Apache License, Version 2.0, see LICENSE for details.\n"
f"// SPDX-License-Identifier: Apache-2.0\n\n")
return s
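The data_utils.py change appends the newline that was missing after the copyright line; without it, every generated header glued the copyright and license comments onto a single line. A quick before/after illustration:

```python
from datetime import datetime

year = datetime.now().year
before = (f"// Copyright {year} ETH Zurich and University of Bologna."
          f"// Licensed under the Apache License, Version 2.0, see LICENSE for details.\n")
after = (f"// Copyright {year} ETH Zurich and University of Bologna.\n"
         f"// Licensed under the Apache License, Version 2.0, see LICENSE for details.\n")
print(before)  # one run-together comment line
print(after)   # two properly separated comment lines
```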
