Skip to content

Commit

Permalink
Only load A and B if indices change, without c2c
Browse files Browse the repository at this point in the history
  • Loading branch information
rogerbarton committed Jan 19, 2024
1 parent 5c2bbc3 commit a0f25dc
Showing 1 changed file with 39 additions and 53 deletions.
92 changes: 39 additions & 53 deletions sw/blas/gemm/src/gemm_occamy_2dpipe.h
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,6 @@ NAMED_DUMP(double, c, 0xc)
const int i##_first = dir ? begin : i##_end_floor; \
const int i##_last = dir ? i##_end_floor : begin; \
i = i##_first; \
i_prev = i; \
for (; dir ? i <= i##_last : i >= i##_last; \
i = dir ? i + stride : i - stride)

Expand Down Expand Up @@ -293,7 +292,9 @@ void gemm_oc(const GemmInfo info, const GemmArgs args) {
// For double buffering, l1 is an array of size 2
TcdmLayout* l1 = snrt_l1_next();

bool l1Id_AB = false;
// Index of the buffer that holds the valid data for computation
bool l1Id_A = true;
bool l1Id_B = true;
bool l1Id_C = false;

// Initialize indices
Expand All @@ -302,7 +303,7 @@ void gemm_oc(const GemmInfo info, const GemmArgs args) {
const uint32_t pj = p[1] % PJ;

int ib, jb, kb;
int ib_prev, jb_prev, kb_prev;
int ib_prev = -1, jb_prev = -1, kb_prev = -1;
bool ib_dir = false, jb_dir = false, kb_dir = false;

bool storeC = false;
Expand All @@ -325,7 +326,7 @@ void gemm_oc(const GemmInfo info, const GemmArgs args) {
const uint32_t p_srcA = pi * PJ + ((2 * PJ - pi - pk) % PJ);
const uint32_t p_srcB = pj + PJ * ((2 * PJ - pj - pk) % PJ);

const bool fetch_dram = pk == 0;
const bool fetch_dram = true;//pk == 0;
c2cL1_A = fetch_dram ? NULL : l1Ptr[p_srcA];
c2cL1_B = fetch_dram ? NULL : l1Ptr[p_srcB];

Expand Down Expand Up @@ -355,35 +356,47 @@ void gemm_oc(const GemmInfo info, const GemmArgs args) {

FOR_EACH(ib, pi, M / L1_M, PI, ib_dir, ib_prev) {
FOR_EACH(jb, pj, N / L1_N, PJ, jb_dir, jb_prev) {

double* const l1_C = l1[l1Id_C].C;

if (snrt_is_dm_core()) {
dump_ib(ib);
dump_jb(jb);
snrt_dma_load_2d_tile(l1_C, C, ib, jb, L1_M, L1_N, ldc, FP64);
if (ib != ib_first || jb != jb_first) storeC = true;
if (ib_prev >= 0 /* && jb_prev >= 0 */) storeC = true;
}

FOR_EACH(kb, 0, K / L1_K, 1, kb_dir, kb_prev) {
double* const l1_A = l1[l1Id_AB].A;
double* const l1_B = l1[l1Id_AB].B;

// load next A, B
// Only load if the indices have changed; otherwise the data is already loaded
const bool loadA = ib != ib_prev || kb != kb_prev;
const bool loadB = kb != kb_prev || jb != jb_prev;

// Switch buffers when the indices have changed
if (loadA) l1Id_A = !l1Id_A;
if (loadB) l1Id_B = !l1Id_B;

double* const l1_A = l1[l1Id_A].A;
double* const l1_B = l1[l1Id_B].B;

if (snrt_is_dm_core()) {
if (c2cL1_A == NULL)
snrt_dma_load_2d_tile(l1_A, A, ib, kb, L1_M, L1_K, lda,
FP64);
else {
double* const c2c_A = c2cL1_A[l1Id_AB].A;
snrt_dma_start_1d(l1_A, c2c_A, L1_M * L1_K * FP64);
dump_kb(kb);
if (loadA) {
if (c2cL1_A == NULL)
snrt_dma_load_2d_tile(l1_A, A, ib, kb, L1_M, L1_K, lda,
FP64);
else {
double* const c2c_A = c2cL1_A[l1Id_A].A;
snrt_dma_start_1d(l1_A, c2c_A, L1_M * L1_K * FP64);
}
}
if (c2cL1_B == NULL)
snrt_dma_load_2d_tile(l1_B, B, kb, jb, L1_K, L1_N, ldb,
FP64);
else {
double* const c2c_B = c2cL1_B[l1Id_AB].B;
snrt_dma_start_1d(l1_B, c2c_B, L1_K * L1_N * FP64);
if (loadB) {
if (c2cL1_B == NULL)
snrt_dma_load_2d_tile(l1_B, B, kb, jb, L1_K, L1_N, ldb,
FP64);
else {
double* const c2c_B = c2cL1_B[l1Id_B].B;
snrt_dma_start_1d(l1_B, c2c_B, L1_K * L1_N * FP64);
}
}

snrt_dma_wait_all();
Expand All @@ -406,7 +419,6 @@ void gemm_oc(const GemmInfo info, const GemmArgs args) {
gemm_cluster_kernel(tileInfo, tileArgs);
}

l1Id_AB = !l1Id_AB; // switch buffers
snrt_global_barrier();

if (snrt_is_dm_core()) {
Expand All @@ -429,9 +441,11 @@ void gemm_oc(const GemmInfo info, const GemmArgs args) {
snrt_global_barrier(); // DMA core is one index ahead

// store final tile
snrt_dma_store_2d_tile(C, l1[!l1Id_C].C, ib_prev, jb_prev, L1_M, L1_N,
ldc, FP64);
snrt_dma_wait_all();
// if (ib_prev >= 0 && jb_prev >= 0) {
snrt_dma_store_2d_tile(C, l1[!l1Id_C].C, ib_prev, jb_prev, L1_M, L1_N,
ldc, FP64);
snrt_dma_wait_all();
// }
} else {
gemm_cluster_kernel_deinit(tileInfo);
}
Expand All @@ -441,31 +455,3 @@ void gemm_oc(const GemmInfo info, const GemmArgs args) {
snrt_global_barrier();
}
}

// inline void gemm_oc(precision_t prec, uint32_t expand, uint32_t setup_ssr,
// uint32_t transa, uint32_t transb, uint32_t m, uint32_t n,
// uint32_t k, double alpha, void* a, uint32_t lda, void* b,
// uint32_t ldb, double beta, void* c, uint32_t ldc) {
// // gemm_cluster_kernel(alpha, beta, m, n, k, a, b, c, lda, ldb, ldc);
// // snrt_fpu_fence();
// // snrt_cluster_hw_barrier();

// GemmInfo gemmInfo = {0};
// gemmInfo.M = m;
// gemmInfo.N = n;
// gemmInfo.K = k;
// gemmInfo.lda = lda;
// gemmInfo.ldb = ldb;
// gemmInfo.ldc = ldc;
// gemmInfo.ta = transa;
// gemmInfo.tb = transb;

// GemmArgs gemmArgs = {0};
// gemmArgs.A = a;
// gemmArgs.B = b;
// gemmArgs.C = c;
// gemmArgs.alpha = alpha;
// gemmArgs.beta = beta;

// gemm_oc_opt2d(gemmInfo, gemmArgs);
// }

0 comments on commit a0f25dc

Please sign in to comment.