diff --git a/sw/blas/gemm_v2/data/params.json b/sw/blas/gemm_v2/data/params.json index 7543a0eed..b34fb472c 100644 --- a/sw/blas/gemm_v2/data/params.json +++ b/sw/blas/gemm_v2/data/params.json @@ -4,13 +4,14 @@ { setup_ssr: 1, - m_tiles: 2, // number of tiles in M dimension + m_tiles: 1, // number of tiles in M dimension + k_tiles: 8, // number of tiles in K dimension transa: false, transb: true, // must be true for SIMD - M: 96, + M: 48, N: 48, - K: 48, + K: 384, alpha: 1, - beta: 0, - gemm_fp: "gemm_fp64_baseline" + beta: 1, + gemm_fp: "gemm_fp64_opt" } diff --git a/sw/blas/gemm_v2/roi.json b/sw/blas/gemm_v2/roi.json index 1fd37dc36..eaa6278fc 100644 --- a/sw/blas/gemm_v2/roi.json +++ b/sw/blas/gemm_v2/roi.json @@ -1,6 +1,6 @@ [ <% DOUBLE_BUFFER = 1 %> - <% N_TILES = 2 %> + <% N_TILES = 8 %> // Compute cores % for j in range(0, 8): diff --git a/sw/blas/gemm_v2/scripts/datagen.py b/sw/blas/gemm_v2/scripts/datagen.py index 494b3ab0c..89678155b 100755 --- a/sw/blas/gemm_v2/scripts/datagen.py +++ b/sw/blas/gemm_v2/scripts/datagen.py @@ -46,11 +46,11 @@ def infer_implementation(self, gemm_fp): return (int(prec) / 8), impl def validate_config(self, gemm_fp, - m_tiles, transa, + m_tiles, k_tiles, transa, transb, M, N, K, beta, **kwargs): frac_m = M / m_tiles frac_n = N / 1 - frac_k = K / 1 + frac_k = K / k_tiles dtype, impl = self.infer_implementation(gemm_fp) @@ -66,7 +66,7 @@ def validate_config(self, gemm_fp, assert (M % m_tiles) == 0, 'M is not an integer multiple of tile size' assert (N % 1) == 0, 'N is not an integer multiple of tile size' - assert (K % 1) == 0, 'K is not an integer multiple of tile size' + assert (K % k_tiles) == 0, 'K is not an integer multiple of tile size' assert not transa, 'SIMD kernels don\'t support transposed A matrix' assert (dtype == 8) or (impl == 'baseline') or (impl == 'naive') \ or transb, 'Optimized SIMD kernels only support transposed B matrix' diff --git a/sw/blas/gemm_v2/src/gemm_v2.h b/sw/blas/gemm_v2/src/gemm_v2.h index 02f8ebe4e..329f75869 100644 --- a/sw/blas/gemm_v2/src/gemm_v2.h +++ b/sw/blas/gemm_v2/src/gemm_v2.h @@ -13,13 +13,6 @@ #pragma once -// Guard to avoid conflict with DNN header file -// TODO: move this definition to Snitch math library to solve problem -#ifndef PRECISION_T -#define PRECISION_T -typedef enum { FP64 = 8, FP32 = 4, FP16 = 2, FP8 = 1 } precision_t; -#endif - typedef float v2f32 __attribute__((vector_size(8))); typedef __fp16 v4f16 __attribute__((vector_size(8))); typedef char v8f8 __attribute__((vector_size(8))); @@ -89,7 +82,7 @@ void sc_st_gemm(gemm_args_t* gemm_args, void* a, void* b, uint32_t beta, uint32_t m = gemm_args->M / gemm_args->m_tiles; uint32_t n = gemm_args->N; - uint32_t k = gemm_args->K; + uint32_t k = gemm_args->K / gemm_args->k_tiles; uint32_t lda = k; uint32_t ldb; @@ -151,6 +144,7 @@ int gemm(gemm_args_t* args) { precision_t prec = (precision_t)local_args->prec; uint32_t setup_ssr = local_args->setup_ssr; uint32_t m_tiles = local_args->m_tiles; + uint32_t k_tiles = local_args->k_tiles; uint32_t transa = local_args->transa; uint32_t transb = local_args->transb; double alpha = local_args->alpha; @@ -162,7 +156,7 @@ int gemm(gemm_args_t* args) { // Calculate tile sizes uint32_t frac_m = m / m_tiles; uint32_t frac_n = n; - uint32_t frac_k = k; + uint32_t frac_k = k / k_tiles; uint32_t frac_a = frac_m * frac_k; uint32_t frac_c = frac_m * frac_n; uint32_t size_frac_a = frac_a * prec; @@ -187,10 +181,10 @@ int gemm(gemm_args_t* args) { local_c[1] = heap_ptr; // Calculate number of iterations - int n_tiles = args->m_tiles; + int n_tiles = args->m_tiles * args->k_tiles; int iterations = n_tiles + 2; - int buff_idx; - int i, i_dma_out, i_dma_in, i_compute; + int buff_idx, c_buff_idx, c_move; + int i, i_m, i_k, i_dma_out, i_dma_in, i_compute; // Iterate over all tiles for (i = 0; i < iterations; i++) { @@ -203,11 +197,17 @@ int gemm(gemm_args_t* args) { // Compute tile and buffer indices i_dma_out = i - 2; buff_idx = i_dma_out % 2; + i_m = i_dma_out / k_tiles; + i_k = i_dma_out % k_tiles; + c_buff_idx = i_m % 2; + c_move = i_k == (k_tiles - 1); // Copy job outputs from TCDM - snrt_dma_store_2d_tile(c, local_c[buff_idx], i_dma_out, 0, - frac_m, frac_n, n, prec); - snrt_dma_wait_all(); + if (c_move) { + snrt_dma_store_2d_tile(c, local_c[c_buff_idx], i_m, 0, + frac_m, frac_n, n, prec); + snrt_dma_wait_all(); + } snrt_mcycle(); } @@ -219,12 +219,23 @@ int gemm(gemm_args_t* args) { // Compute tile and buffer indices i_dma_in = i; buff_idx = i_dma_in % 2; + i_m = i_dma_in / k_tiles; + i_k = i_dma_in % k_tiles; + c_buff_idx = i_m % 2; + c_move = i_k == 0; // Copy job operands in TCDM - snrt_dma_load_2d_tile(local_a[buff_idx], a, i_dma_in, - 0, frac_m, frac_k, k, + snrt_dma_load_2d_tile(local_a[buff_idx], a, i_m, + i_k, frac_m, frac_k, k, prec); - snrt_dma_start_1d(local_b[buff_idx], b, frac_k * frac_n * prec); + snrt_dma_load_2d_tile(local_b[buff_idx], b, 0, + i_k, frac_n, frac_k, k, + prec); + if (c_move) { + snrt_dma_load_2d_tile(local_c[c_buff_idx], c, i_m, + 0, frac_m, frac_n, n, + prec); + } snrt_dma_wait_all(); snrt_mcycle(); @@ -239,6 +250,9 @@ int gemm(gemm_args_t* args) { // Compute tile and buffer indices i_compute = i - 1; buff_idx = i_compute % 2; + i_m = i_compute / k_tiles; + i_k = i_compute % k_tiles; + c_buff_idx = i_m % 2; // Perform tile computation volatile uint32_t ldb = frac_n; @@ -247,7 +261,7 @@ int gemm(gemm_args_t* args) { ldb = frac_k; } sc_st_gemm(local_args, local_a[buff_idx], local_b[buff_idx], beta, - local_c[buff_idx]); + local_c[c_buff_idx]); snrt_mcycle(); }