gemm: 3d-tiling for gemm
Viviane Potocnik committed Nov 13, 2023
1 parent a00c227 commit 050b903
Showing 4 changed files with 158 additions and 71 deletions.
42 changes: 27 additions & 15 deletions sw/blas/gemm/data/datagen.py
@@ -52,25 +52,34 @@ def emit_header(**kwargs):

# Generate random input matrices
dtype = NUMPY_TYPES[str(kwargs['prec'])]
M, N, K = kwargs['M'], kwargs['N'], kwargs['K']
m_tiles = kwargs['m_tiles']
n_tiles = kwargs['n_tiles']
k_tiles = kwargs['k_tiles']

assert (M % m_tiles) == 0, 'M is not an integer multiple of tile size'
assert (N % n_tiles) == 0, 'N is not an integer multiple of tile size'
assert (K % k_tiles) == 0, 'K is not an integer multiple of tile size'

if (kwargs['prec']) == 8:
# sign -1 or 1
sign_a = np.random.randint(0, 2, (kwargs['M'], kwargs['K'])).astype(dtype)
sign_a = np.random.randint(0, 2, (M, K)).astype(dtype)
# exponent < 0b01111
exponent_a = np.random.randint(0, 16, (kwargs['M'], kwargs['K'])).astype(dtype)
exponent_a = np.random.randint(0, 16, (M, K)).astype(dtype)
# mantissa can be arbitrary
mantissa_a = np.random.randint(0, 4, (kwargs['M'], kwargs['K'])).astype(dtype)
mantissa_a = np.random.randint(0, 4, (M, K)).astype(dtype)
# sign -1 or 1
sign_b = np.random.randint(0, 2, (kwargs['K'], kwargs['N'])).astype(dtype)
sign_b = np.random.randint(0, 2, (K, N)).astype(dtype)
# exponent < 0b01111
exponent_b = np.random.randint(0, 16, (kwargs['K'], kwargs['N'])).astype(dtype)
exponent_b = np.random.randint(0, 16, (K, N)).astype(dtype)
# mantissa can be arbitrary
mantissa_b = np.random.randint(0, 4, (kwargs['K'], kwargs['N'])).astype(dtype)
mantissa_b = np.random.randint(0, 4, (K, N)).astype(dtype)
# sign -1 or 1
sign_c = np.random.randint(0, 2, (kwargs['M'], kwargs['N'])).astype(dtype)
sign_c = np.random.randint(0, 2, (M, N)).astype(dtype)
# exponent < 0b01111
exponent_c = np.random.randint(0, 16, (kwargs['M'], kwargs['N'])).astype(dtype)
exponent_c = np.random.randint(0, 16, (M, N)).astype(dtype)
# mantissa can be arbitrary
mantissa_c = np.random.randint(0, 4, (kwargs['M'], kwargs['N'])).astype(dtype)
mantissa_c = np.random.randint(0, 4, (M, N)).astype(dtype)
_a = ((-1.0)**sign_a.astype(np.double))*(2.0**(exponent_a.astype(np.double)-15.0)) \
* (1.0 + mantissa_a.astype(np.double) / (2**2))
_b = ((-1.0)**sign_b.astype(np.double))*(2.0**(exponent_b.astype(np.double)-15.0)) \
@@ -82,24 +91,27 @@ def emit_header(**kwargs):
b = sign_b << 7 | exponent_b << FP8_FORMATS['fp8']['mant'] | mantissa_b
c = sign_c << 7 | exponent_c << FP8_FORMATS['fp8']['mant'] | mantissa_c
else:
a = np.random.rand(kwargs['M'], kwargs['K']).astype(dtype)
b = np.random.rand(kwargs['K'], kwargs['N']).astype(dtype)
c = np.random.rand(kwargs['M'], kwargs['N']).astype(dtype)
a = np.random.rand(M, K).astype(dtype)
b = np.random.rand(K, N).astype(dtype)
c = np.random.rand(M, N).astype(dtype)
result = golden_model(1, a, b, kwargs['beta'], c)

# Store matrices in transposed form if requested
a = a.T if kwargs['ta'] else a
b = b.T if kwargs['tb'] else b

data_str = [emit_license()]
data_str += [format_scalar_definition('uint32_t', 'M', kwargs['M'])]
data_str += [format_scalar_definition('uint32_t', 'N', kwargs['N'])]
data_str += [format_scalar_definition('uint32_t', 'K', kwargs['K'])]
data_str += [format_scalar_definition('uint32_t', 'M', M)]
data_str += [format_scalar_definition('uint32_t', 'N', N)]
data_str += [format_scalar_definition('uint32_t', 'K', K)]
data_str += [format_scalar_definition('uint32_t', 'TA', int(kwargs['ta']))]
data_str += [format_scalar_definition('uint32_t', 'TB', int(kwargs['tb']))]
data_str += [format_scalar_definition('uint32_t', 'BETA', kwargs['beta'])]
data_str += [format_scalar_definition('uint32_t', 'dtype_size', kwargs['prec']//8)]
data_str += [format_scalar_definition('uint32_t', 'expand', kwargs['expand'])]
data_str += [format_scalar_definition('uint32_t', 'm_tiles', kwargs['m_tiles'])]
data_str += [format_scalar_definition('uint32_t', 'n_tiles', kwargs['n_tiles'])]
data_str += [format_scalar_definition('uint32_t', 'k_tiles', kwargs['k_tiles'])]
data_str += [format_array_definition(C_TYPES[str(kwargs['prec'])], 'a', a.flatten(),
alignment=BURST_ALIGNMENT, section=kwargs['section'])]
data_str += [format_array_definition(C_TYPES[str(kwargs['prec'])], 'b', b.flatten(),
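In the FP8 branch above, each element is generated from separate sign, exponent, and mantissa draws and evaluated as (-1)**sign * 2**(exponent - 15) * (1 + mantissa/4) before being bit-packed. The following is a minimal standalone sketch of that mapping for a single value (not part of the repository); the 2-bit mantissa width, and hence FP8_FORMATS['fp8']['mant'] == 2, is inferred from the np.random.randint(0, 4, ...) draws and the / (2**2) scaling in the diff rather than confirmed elsewhere.

    import numpy as np

    def fp8_fields_to_double(sign, exponent, mantissa):
        # Value the generator assigns to one (sign, exponent, mantissa) draw
        return ((-1.0) ** sign) * (2.0 ** (exponent - 15.0)) * (1.0 + mantissa / 4.0)

    def pack_fp8(sign, exponent, mantissa, mant_bits=2):
        # Bit-pack the fields the same way the diff does: sign | exponent | mantissa
        return np.uint8(sign << 7 | exponent << mant_bits | mantissa)

    # Example draw: sign=1, exponent=14, mantissa=3
    print(fp8_fields_to_double(1, 14, 3))  # -0.875 = -(2**-1) * 1.75
    print(hex(pack_fp8(1, 14, 3)))         # 0xbb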
9 changes: 6 additions & 3 deletions sw/blas/gemm/data/params.hjson
@@ -5,12 +5,15 @@
// Parameters for a GEMM

{
M: 192,
M: 32,
N: 16,
K: 16,
K: 768,
beta: 0,
ta: false,
tb: true, // must be true for SIMD
prec: 64,
expand: 0
expand: 0,
m_tiles: 1 // number of tiles in M dimension per cluster
k_tiles: 2 // number of tiles in K dimension over all clusters
n_tiles: 1 // number of tiles in N dimension per cluster
}
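The new m_tiles, n_tiles, and k_tiles entries must divide M, N, and K exactly; datagen.py asserts this before generating data. A quick standalone check of the values above (a sketch, not part of the repository):

    # Tile-size arithmetic for the configuration above
    M, N, K = 32, 16, 768
    m_tiles, n_tiles, k_tiles = 1, 1, 2

    for name, dim, tiles in [('M', M, m_tiles), ('N', N, n_tiles), ('K', K, k_tiles)]:
        assert dim % tiles == 0, f'{name} is not an integer multiple of its tile count'

    print(M // m_tiles, N // n_tiles, K // k_tiles)  # per-tile sizes: 32 16 384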
1 change: 1 addition & 0 deletions sw/blas/gemm/src/gemm.h
@@ -54,6 +54,7 @@ void gemm_fp64_baseline(uint32_t M, uint32_t N, uint32_t K, double* A,
for (uint32_t k = 0; k < K; k++) {
// dump_index(k + m * ldA);
// dump_gemm(A[k + m * ldA]);
// dump_gemm(B[k + n * ldB]);
c0 += A[k + m * ldA] * B[k + n * ldB];
}
C[m * ldC + n] = c0;
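The change itself only adds a commented-out dump_gemm debug hook for B, but for reference, the visible loop body accumulates A[k + m*ldA] * B[k + n*ldB] into c0 and writes it to C[m*ldC + n]; the enclosing m and n loops sit outside the shown lines. Below is a NumPy sketch of just that accumulation (an assumption-laden model: BETA and transpose handling outside the hunk are ignored, and the B indexing matches a layout where the K index is contiguous, as it is for a transposed B):

    import numpy as np

    def gemm_fp64_baseline_ref(M, N, K, A, ldA, B, ldB, C, ldC):
        # Models only the visible loop body: C[m*ldC + n] = sum_k A[k + m*ldA] * B[k + n*ldB]
        for m in range(M):
            for n in range(N):
                c0 = 0.0
                for k in range(K):
                    c0 += A[k + m * ldA] * B[k + n * ldB]
                C[m * ldC + n] = c0
        return C

    # Self-check with B stored transposed (shape (N, K)), so ldB = K:
    M, N, K = 4, 3, 5
    A, Bt = np.random.rand(M, K), np.random.rand(N, K)
    C = gemm_fp64_baseline_ref(M, N, K, A.ravel(), K, Bt.ravel(), K, np.zeros(M * N), N)
    assert np.allclose(C.reshape(M, N), A @ Bt.T)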
177 changes: 124 additions & 53 deletions sw/blas/gemm/src/main.c
@@ -24,6 +24,7 @@ int main() {
uint32_t m_tile_size = M / m_tiles;
uint32_t n_tile_size = N / n_tiles;
uint32_t k_tile_size = K / k_tiles;
uint32_t beta_k = 0;
// Calculate size and pointers for each cluster
uint32_t frac_m = M / (snrt_cluster_num() * m_tiles);
uint32_t frac_n = n_tile_size;
@@ -41,67 +42,137 @@
dump_value(tcdm_alloc_size / 1024);
// Determine the tile offset for each cluster
for (uint32_t m_tile = 0; m_tile < m_tiles; m_tile++) {
dump_val(m_tile);
// Copy in the Row Block of A
uint32_t offset_a = m_tile * m_tile_size * frac_k + snrt_cluster_idx() * frac_a;
remote_a = a + offset_a;
if (snrt_is_dm_core()) {
snrt_dma_start_1d(local_a, remote_a, size_frac_a);
snrt_dma_wait_all();
// dump the tile
// for (uint32_t i = 0; i < frac_a; i++) {
// dump_value(((double *)local_a)[i]);
// }
}
snrt_cluster_hw_barrier();
for (uint32_t n_tile = 0; n_tile < n_tiles; n_tile++) {
dump_val(n_tile);
// for (uint32_t n_tile = 0; n_tile < 2; n_tile++) {
// Copy in the Column Block of B and the Tile of C
uint32_t offset_b = n_tile * frac_n * k_tile_size;
dump_val(offset_b);
remote_b = b + offset_b;
uint32_t offset_c = frac_m * N * (m_tile + snrt_cluster_idx() * m_tiles) + n_tile * frac_n;
dump_val(offset_c);
remote_c = c + offset_c;
if (snrt_is_dm_core()) {
snrt_dma_start_1d(local_b, remote_b, size_b);
snrt_dma_start_1d(local_c, remote_c, size_frac_c);
snrt_dma_wait_all();
// for (uint32_t i = 0; i < size_b / dtype_size; i++) {
// dump_value(((double *)local_b)[i]);
// }
}
snrt_cluster_hw_barrier();
// Compute the Tile
if (!snrt_is_dm_core()) {
const uint32_t setup_ssr = 1;
uint32_t start_cycle = snrt_mcycle();

volatile uint32_t lda = K;
volatile uint32_t ldb = n_tile_size;
volatile uint32_t ldc = n_tile_size;

// Transpose of A unsupported
if (TA) return -1;
if (TB) {
// Transpose of B supported only in FP64
if (dtype_size != FP64) return -1;
ldb = K;
for (uint32_t k_tile = 0; k_tile < k_tiles; k_tile++) {
// Define pointer offsets for each cluster
uint32_t offset_a = (snrt_cluster_idx() * m_tiles + m_tile) * m_tile_size * K + k_tile * frac_k;
remote_a = a + offset_a;
uint32_t offset_b = n_tile * frac_n + k_tile * N * frac_k;
remote_b = b + offset_b;
uint32_t offset_c = n_tile * frac_n + (snrt_cluster_idx() * m_tiles + m_tile) * N * frac_m;
// dump_val(offset_c);
remote_c = c + offset_c;
// Copy data in TCDM
if (snrt_is_dm_core()) {
snrt_dma_start_1d(local_a, remote_a, size_frac_a);
// dump the tile
// for (uint32_t i = 0; i < frac_a; i++) {
// dump_value(((double *)local_a)[i]);
// }
snrt_dma_start_1d(local_b, remote_b, size_b);
// for (uint32_t i = 0; i < size_b / dtype_size; i++) {
// dump_value(((double *)local_b)[i]);
// }
snrt_dma_start_1d(local_c, remote_c, size_frac_c);
// for (uint32_t i = 0; i < size_frac_c / dtype_size; i++) {
// dump_value(((double *)local_c)[i]);
// }
snrt_dma_wait_all();
}

gemm(dtype_size, expand, setup_ssr, TA, TB, frac_m, n_tile_size, K, 1, local_a,
lda, local_b, ldb, BETA, local_c, ldc);
snrt_cluster_hw_barrier();

uint32_t end_cycle = snrt_mcycle();
// Compute
if (!snrt_is_dm_core()) {
const uint32_t setup_ssr = 1;
uint32_t start_cycle = snrt_mcycle();

// Copy out the Tile of C
if (snrt_is_dm_core()) {
snrt_dma_start_1d(remote_c, local_c, size_frac_c);
snrt_dma_wait_all();
volatile uint32_t lda = frac_k;
volatile uint32_t ldb = frac_n;
volatile uint32_t ldc = frac_n;

// Transpose of A unsupported
if (TA) return -1;
if (TB) {
// Transpose of B supported only in FP64
if (dtype_size != FP64) return -1;
ldb = frac_k;
}

if (k_tile != 0) {
beta_k = 1;
} else {
beta_k = BETA;
}

gemm(dtype_size, expand, setup_ssr, TA, TB, frac_m, frac_n, frac_k, 1, local_a,
lda, local_b, ldb, beta_k, local_c, ldc);

uint32_t end_cycle = snrt_mcycle();
}

snrt_cluster_hw_barrier();

}

// Copy data out of TCDM
if (snrt_is_dm_core()) {
snrt_dma_start_1d(remote_c, local_c, size_frac_c);
snrt_dma_wait_all();
}

}
// dump_val(m_tile);
// // Copy in the Row Block of A
// uint32_t offset_a = m_tile * m_tile_size * frac_k + snrt_cluster_idx() * frac_a;
// remote_a = a + offset_a;
// if (snrt_is_dm_core()) {
// snrt_dma_start_1d(local_a, remote_a, size_frac_a);
// snrt_dma_wait_all();
// // dump the tile
// // for (uint32_t i = 0; i < frac_a; i++) {
// // dump_value(((double *)local_a)[i]);
// // }
// }
// snrt_cluster_hw_barrier();
// for (uint32_t n_tile = 0; n_tile < n_tiles; n_tile++) {
// dump_val(n_tile);
// // for (uint32_t n_tile = 0; n_tile < 2; n_tile++) {
// // Copy in the Column Block of B and the Tile of C
// uint32_t offset_b = n_tile * frac_n * k_tile_size;
// dump_val(offset_b);
// remote_b = b + offset_b;
// uint32_t offset_c = frac_m * N * (m_tile + snrt_cluster_idx() * m_tiles) + n_tile * frac_n;
// dump_val(offset_c);
// remote_c = c + offset_c;
// if (snrt_is_dm_core()) {
// snrt_dma_start_1d(local_b, remote_b, size_b);
// snrt_dma_start_1d(local_c, remote_c, size_frac_c);
// snrt_dma_wait_all();
// // for (uint32_t i = 0; i < size_b / dtype_size; i++) {
// // dump_value(((double *)local_b)[i]);
// // }
// }
// snrt_cluster_hw_barrier();
// // Compute the Tile
// if (!snrt_is_dm_core()) {
// const uint32_t setup_ssr = 1;
// uint32_t start_cycle = snrt_mcycle();

// volatile uint32_t lda = K;
// volatile uint32_t ldb = n_tile_size;
// volatile uint32_t ldc = n_tile_size;

// // Transpose of A unsupported
// if (TA) return -1;
// if (TB) {
// // Transpose of B supported only in FP64
// if (dtype_size != FP64) return -1;
// ldb = K;
// }

// gemm(dtype_size, expand, setup_ssr, TA, TB, frac_m, n_tile_size, K, 1, local_a,
// lda, local_b, ldb, BETA, local_c, ldc);

// uint32_t end_cycle = snrt_mcycle();

// // Copy out the Tile of C
// if (snrt_is_dm_core()) {
// snrt_dma_start_1d(remote_c, local_c, size_frac_c);
// snrt_dma_wait_all();
// }
// }
// }
}


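The main.c rewrite adds a k_tile loop, so each (m_tile, n_tile) output tile is now assembled from k_tiles partial products: beta_k applies the user-supplied BETA only on the first k-tile and is forced to 1 afterwards, so later partial products accumulate in the local C tile instead of overwriting it. The sketch below models that accumulation scheme with NumPy (illustrative names; a single cluster, no transposes, and a local C tile that stays resident across the k-tile loop are assumed, so the DMA traffic and the per-iteration copy-in of local_c are not modeled):

    import numpy as np

    def tiled_gemm_ref(A, B, C, beta, m_tiles, n_tiles, k_tiles):
        # Reference for the beta_k accumulation: BETA on the first k-tile, 1 afterwards
        M, K = A.shape
        _, N = B.shape
        tm, tn, tk = M // m_tiles, N // n_tiles, K // k_tiles
        out = C.copy()
        for mt in range(m_tiles):
            for nt in range(n_tiles):
                local_c = out[mt * tm:(mt + 1) * tm, nt * tn:(nt + 1) * tn].copy()
                for kt in range(k_tiles):
                    beta_k = beta if kt == 0 else 1
                    local_a = A[mt * tm:(mt + 1) * tm, kt * tk:(kt + 1) * tk]
                    local_b = B[kt * tk:(kt + 1) * tk, nt * tn:(nt + 1) * tn]
                    local_c = beta_k * local_c + local_a @ local_b
                out[mt * tm:(mt + 1) * tm, nt * tn:(nt + 1) * tn] = local_c
        return out

    # With the params.hjson values (beta = 0, m_tiles = n_tiles = 1, k_tiles = 2)
    # the tiled result must equal the untiled GEMM:
    M, N, K = 32, 16, 768
    A, B, C = np.random.rand(M, K), np.random.rand(K, N), np.random.rand(M, N)
    assert np.allclose(tiled_gemm_ref(A, B, C, 0, 1, 1, 2), A @ B)
    # And with beta = 1 it must equal A @ B + C:
    assert np.allclose(tiled_gemm_ref(A, B, C, 1, 2, 2, 4), A @ B + C)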
