sw: Update AXPY and GEMM for (trivial) multi-cluster
colluca committed Sep 11, 2023
1 parent 85e9214 commit 9e7a4a9
Showing 9 changed files with 285 additions and 179 deletions.
31 changes: 23 additions & 8 deletions sw/blas/axpy/src/main.c
@@ -10,43 +10,58 @@

int main() {
double *local_x, *local_y, *local_z;
double *remote_x, *remote_y, *remote_z;

// Calculate size and pointers for each cluster
uint32_t frac = l / snrt_cluster_num();
uint32_t offset = frac * snrt_cluster_idx();
remote_x = x + offset;
remote_y = y + offset;
remote_z = z + offset;

// Allocate space in TCDM
local_x = (double *)snrt_l1_next();
local_y = local_x + l;
local_z = local_y + l;
local_y = local_x + frac;
local_z = local_y + frac;

// Copy data in TCDM
if (snrt_is_dm_core()) {
size_t size = l * sizeof(double);
snrt_dma_start_1d(local_x, x, size);
snrt_dma_start_1d(local_y, y, size);
size_t size = frac * sizeof(double);
snrt_dma_start_1d(local_x, remote_x, size);
snrt_dma_start_1d(local_y, remote_y, size);
snrt_dma_wait_all();
}

snrt_cluster_hw_barrier();

// Compute
if (!snrt_is_dm_core()) {
uint32_t start_cycle = snrt_mcycle();
axpy(l, a, local_x, local_y, local_z);
axpy(frac, a, local_x, local_y, local_z);
uint32_t end_cycle = snrt_mcycle();
}

snrt_cluster_hw_barrier();

// Copy data out of TCDM
if (snrt_is_dm_core()) {
size_t size = l * sizeof(double);
snrt_dma_start_1d(z, local_z, size);
size_t size = frac * sizeof(double);
snrt_dma_start_1d(remote_z, local_z, size);
snrt_dma_wait_all();
}

snrt_cluster_hw_barrier();

// TODO: currently only works for single cluster otherwise need to
// synchronize all cores here
#ifdef BIST
uint32_t nerr = l;

// Check computation is correct
if (snrt_global_core_idx() == 0) {
for (int i = 0; i < l; i++) {
if (local_z[i] == g[i]) nerr--;
printf("%d %d\n", local_z[i], g[i]);
}
}

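As a reference for the AXPY change above, a minimal sketch of the per-cluster partitioning it introduces, assuming l is evenly divisible by the number of clusters; slice_t and axpy_slice are illustrative names, not part of the commit:

#include <stdint.h>

// Each cluster owns a contiguous, non-overlapping slice of x, y and z.
typedef struct {
    uint32_t frac;    // number of elements handled by this cluster
    uint32_t offset;  // index of the first element owned by this cluster
} slice_t;

static inline slice_t axpy_slice(uint32_t l, uint32_t cluster_idx,
                                 uint32_t cluster_num) {
    slice_t s;
    s.frac = l / cluster_num;         // equal share per cluster
    s.offset = s.frac * cluster_idx;  // mirrors remote_x = x + offset above
    return s;
}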
80 changes: 29 additions & 51 deletions sw/blas/gemm/data/datagen.py
@@ -10,6 +10,13 @@
import argparse
import pathlib
import hjson
import sys
import os

sys.path.append(os.path.join(os.path.dirname(__file__), "../../../../util/sim/"))
from data_utils import emit_license, format_scalar_definition, \
format_vector_definition, format_ifdef_wrapper # noqa: E402


np.random.seed(42)

@@ -33,41 +40,11 @@
}


def format_vector_definition(id, vector, typ):
s = f'{typ} {id}[{len(vector)}] = ' + '{\n'
for i, el in enumerate(vector):
if typ != 'char':
s += f'\t{el},'
else:
if type(el) == float:
print(el)
s += f'0x{el:02x},'
if i % 8 == 7:
s += '\n'
s += '};'
return s


def format_vector_declaration(id, vector, typ):
s = f'{typ} {id}[{len(vector)}];'
return s


def format_scalar_definition(id, scalar, typ):
s = f'{typ} {id} = {scalar};'
return s


def emit_header_file(**kwargs):

emit_str = "// Copyright 2023 ETH Zurich and University of Bologna.\n" + \
"// Licensed under the Apache License, Version 2.0, see LICENSE for details.\n" + \
"// SPDX-License-Identifier: Apache-2.0\n\n"
emit_str += emit_gemm_data(**kwargs)
return emit_str
def golden_model(a, b, alpha, c):
return np.matmul(a, b) + alpha * c


def emit_gemm_data(**kwargs):
def emit_header(**kwargs):

# Generate random input matrices
dtype = NUMPY_TYPES[str(kwargs['prec'])]
@@ -104,30 +81,31 @@ def emit_gemm_data(**kwargs):
a = np.random.rand(kwargs['M'], kwargs['K']).astype(dtype)
b = np.random.rand(kwargs['K'], kwargs['N']).astype(dtype)
c = np.random.rand(kwargs['M'], kwargs['N']).astype(dtype)
result = np.matmul(a, b) + kwargs['alpha'] * c
result = golden_model(a, b, kwargs['alpha'], c)

# Store matrices in transposed form if requested
a = a.T if kwargs['ta'] else a
b = b.T if kwargs['tb'] else b

data_str = []
data_str += [format_scalar_definition('M', kwargs['M'], 'uint32_t')]
data_str += [format_scalar_definition('N', kwargs['N'], 'uint32_t')]
data_str += [format_scalar_definition('K', kwargs['K'], 'uint32_t')]
data_str += [format_scalar_definition('TA', int(kwargs['ta']), 'uint32_t')]
data_str += [format_scalar_definition('TB', int(kwargs['tb']), 'uint32_t')]
data_str += [format_scalar_definition('ALPHA', kwargs['alpha'], 'uint32_t')]
data_str += [format_scalar_definition('dtype_size', kwargs['prec']//8, 'uint32_t')]
data_str += [format_scalar_definition('expand', kwargs['expand'], 'uint32_t')]
data_str += [format_vector_definition('a', a.flatten(), C_TYPES[str(kwargs['prec'])])]
data_str += [format_vector_definition('b', b.flatten(), C_TYPES[str(kwargs['prec'])])]
data_str += [format_vector_definition('c', c.flatten(), C_TYPES[str(kwargs['prec'])])]
data_str = [emit_license()]
data_str += [format_scalar_definition('uint32_t', 'M', kwargs['M'])]
data_str += [format_scalar_definition('uint32_t', 'N', kwargs['N'])]
data_str += [format_scalar_definition('uint32_t', 'K', kwargs['K'])]
data_str += [format_scalar_definition('uint32_t', 'TA', int(kwargs['ta']))]
data_str += [format_scalar_definition('uint32_t', 'TB', int(kwargs['tb']))]
data_str += [format_scalar_definition('uint32_t', 'ALPHA', kwargs['alpha'])]
data_str += [format_scalar_definition('uint32_t', 'dtype_size', kwargs['prec']//8)]
data_str += [format_scalar_definition('uint32_t', 'expand', kwargs['expand'])]
data_str += [format_vector_definition(C_TYPES[str(kwargs['prec'])], 'a', a.flatten())]
data_str += [format_vector_definition(C_TYPES[str(kwargs['prec'])], 'b', b.flatten())]
data_str += [format_vector_definition(C_TYPES[str(kwargs['prec'])], 'c', c.flatten())]
if kwargs['prec'] == 8:
data_str += [format_vector_definition('result', result.flatten(), C_TYPES['64'])]
result_def = format_vector_definition(C_TYPES['64'], 'result', result.flatten())
else:
data_str += [format_vector_definition('result',
result.flatten(),
C_TYPES[str(kwargs['prec'])])]
result_def = format_vector_definition(C_TYPES[str(kwargs['prec'])],
'result',
result.flatten())
data_str += [format_ifdef_wrapper('BIST', result_def)]
data_str = '\n\n'.join(data_str)

return data_str
@@ -149,7 +127,7 @@ def main():
param = hjson.loads(f.read())

# Emit header file
print(emit_header_file(**param))
print(emit_header(**param))


if __name__ == '__main__':
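For reference, the golden model added to datagen.py (np.matmul(a, b) + alpha * c) corresponds to the following plain C computation, assuming row-major, non-transposed MxK, KxN and MxN operands; gemm_golden is an illustrative name, not part of the commit:

#include <stddef.h>

// result = A * B + alpha * C, computed element by element in double precision.
static void gemm_golden(size_t M, size_t N, size_t K, double alpha,
                        const double *A, const double *B, const double *C,
                        double *result) {
    for (size_t i = 0; i < M; i++) {
        for (size_t j = 0; j < N; j++) {
            double acc = 0.0;
            for (size_t k = 0; k < K; k++) {
                acc += A[i * K + k] * B[k * N + j];
            }
            result[i * N + j] = acc + alpha * C[i * N + j];
        }
    }
}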
4 changes: 2 additions & 2 deletions sw/blas/gemm/data/params.hjson
@@ -5,12 +5,12 @@
// Parameters for a GEMM

{
M: 16,
M: 192,
N: 16,
K: 16,
alpha: 0,
ta: false,
tb: true, // must be true for SIMD
prec: 32,
prec: 64,
expand: 0
}
61 changes: 61 additions & 0 deletions sw/blas/gemm/src/gemm.h
@@ -4,14 +4,22 @@
//
// Author: Tim Fischer <[email protected]>
// Luca Bertaccini <[email protected]>
// Luca Colagrande <[email protected]>

#include <stdint.h>

#include "snrt.h"

// Guard to avoid conflict with DNN header file
// TODO: move this definition to Snitch math library to solve problem
#ifndef PRECISION_T
#define PRECISION_T
typedef enum { FP64 = 8, FP32 = 4, FP16 = 2, FP8 = 1 } precision_t;

typedef float v2f32 __attribute__((vector_size(8)));
typedef __fp16 v4f16 __attribute__((vector_size(8)));
typedef char v8f8 __attribute__((vector_size(8)));
#endif

void gemm_fp64_baseline(uint32_t M, uint32_t N, uint32_t K, double* A,
uint32_t ldA, uint32_t ta, double* B, uint32_t ldB,
@@ -874,3 +882,56 @@ void gemm_fp8_ex_opt(uint32_t M, uint32_t N, uint32_t K, char* A, uint32_t ldA,

snrt_ssr_disable();
}

// BLAS compliant GEMM kernel, with some additional arguments at the beginning
// to specify Snitch implementation details. Matrix sizes and pointers are for
// the whole cluster computation
// TODO: alpha (and beta) should be of floating-point type (same precision as
// operands)
void gemm(precision_t prec, uint32_t expand, uint32_t setup_ssr,
uint32_t transa, uint32_t transb, uint32_t m, uint32_t n, uint32_t k,
uint32_t alpha, void* a, uint32_t lda, void* b, uint32_t ldb,
double beta, void* c, uint32_t ldc) {
const uint32_t compute_num = snrt_cluster_compute_core_num();
const uint32_t compute_id = snrt_cluster_core_idx();

// Compute cores work not on contiguous blocks but on strided rows
uint32_t lda_strided = compute_num * lda;
uint32_t ldc_strided = compute_num * ldc;

// Compute cores access A and C at offsets of one row from each other
uint32_t offsetA = compute_id * lda;
uint32_t offsetC = compute_id * ldc;

// Compute fraction of C rows every core computes
uint32_t frac_m = m / compute_num;

switch (prec) {
case FP64:
gemm_fp64_opt(frac_m, n, k, (double*)a + offsetA, lda_strided,
transa, (double*)b, ldb, transb, (double*)c + offsetC,
ldc_strided, &alpha, setup_ssr);
break;
case FP32:
gemm_fp32_opt(frac_m, n, k, (float*)a + offsetA, lda_strided,
(float*)b, ldb, (float*)c + offsetC, ldc_strided,
&alpha, setup_ssr);
break;
case FP16:
if (expand) {
gemm_fp16_ex_opt(
frac_m, n, k, (__fp16*)a + offsetA, lda_strided, (__fp16*)b,
ldb, (__fp16*)c + offsetC, ldc_strided, &alpha, setup_ssr);
} else {
gemm_fp16_opt(frac_m, n, k, (__fp16*)a + offsetA, lda_strided,
(__fp16*)b, ldb, (__fp16*)c + offsetC,
ldc_strided, &alpha, setup_ssr);
}
break;
case FP8:
gemm_fp8_ex_opt(frac_m, n, k, (char*)a + offsetA, lda, (char*)b,
ldb, (char*)c + offsetC, ldc_strided, &alpha,
setup_ssr);
break;
}
}
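To illustrate the work distribution set up by the gemm() wrapper above: with lda_strided = compute_num * lda and offsetA = compute_id * lda, compute core compute_id processes rows compute_id, compute_id + compute_num, compute_id + 2 * compute_num, ... of A and C. A small sketch of that mapping, assuming m is a multiple of compute_num; core_rows is an illustrative helper, not part of the commit:

#include <stdint.h>

// Fills rows_out (length m / compute_num) with the row indices of C
// that core `compute_id` computes under the interleaved distribution.
static void core_rows(uint32_t m, uint32_t compute_num, uint32_t compute_id,
                      uint32_t *rows_out) {
    uint32_t frac_m = m / compute_num;
    for (uint32_t r = 0; r < frac_m; r++) {
        rows_out[r] = compute_id + r * compute_num;
    }
}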