diff --git a/software/apps/baremetal/mimo_mmse_f16/main.c b/software/apps/baremetal/mimo_mmse_f16/main.c index c5cb20155..36247fbbd 100644 --- a/software/apps/baremetal/mimo_mmse_f16/main.c +++ b/software/apps/baremetal/mimo_mmse_f16/main.c @@ -19,7 +19,7 @@ #include "data_mimo_mmse_f16.h" #define NUM_BANKS (BANKING_FACTOR * NUM_CORES) -#define DOUBLE_BUFFERING +//#define DOUBLE_BUFFERING /********************************************************** ********************************************************** @@ -33,7 +33,6 @@ ***********************************************************/ #ifndef DOUBLE_BUFFERING -#define PARALLEL __fp16 l1_H[2 * N_TX * N_RX * N_ITR] __attribute__((aligned(NUM_BANKS * sizeof(int32_t)), section(".l1_prio"))); @@ -61,6 +60,22 @@ int main() { uint32_t core_id = mempool_get_core_id(); uint32_t num_cores = mempool_get_core_count(); mempool_barrier_init(core_id); // Initialize barrier and synchronize + +#ifdef BANSHEE + /* Initialize matrices */ + if (core_id == 0) { + for (uint32_t i = 0; i < N_RX * N_TX * N_ITR; i++) { + (*(uint32_t *)&l1_H[2 * i]) = *(uint32_t *)&l2_H[2 * i]; + } + for (uint32_t i = 0; i < N_RX * N_ITR; i++) { + (*(uint32_t *)&l1_y[2 * i]) = *(uint32_t *)&l2_y[2 * i]; + } + for (uint32_t i = 0; i < N_TX * N_ITR; i++) { + (*(uint32_t *)&l1_S[2 * i]) = *(uint32_t *)&l2_Sigma[2 * i]; + } + } + mempool_barrier(num_cores); +#else /* Initialize matrices */ if (core_id == 0) { dma_memcpy_blocking(l1_beamgroups, l2_beamgroups, N_ITR * sizeof(int32_t)); @@ -69,6 +84,38 @@ int main() { dma_memcpy_blocking(l1_S, l2_Sigma, N_TX * N_ITR * sizeof(int32_t)); } mempool_barrier(num_cores); +#endif + +#ifdef BANSHEE + /* Benchmark */ + if (core_id == 0) { + mempool_start_benchmark(); + for (uint32_t itr = 0; itr < N_ITR; itr++) { + __fp16 *PtrH = l1_H + itr * (2 * N_TX * N_RX); + __fp16 *Ptry = l1_y + itr * (2 * N_RX); + __fp16 *PtrSigma = l1_S + itr * N_TX; + __fp16 *PtrG = l1_G + itr * (2 * N_TX * N_TX); + __fp16 *PtrL = l1_L + itr * (2 * N_TX * N_TX); + __fp16 *Ptry2 = y2 + itr * (2 * N_TX); + __fp16 *Ptry3 = y3 + itr * (2 * N_TX); + __fp16 *Ptrx = l1_x + itr * (2 * N_TX); + +#ifdef VEC + mempool_hermitian_f16vecs(PtrH, PtrG, PtrSigma, N_RX, N_TX); + mempool_MVP_conjtransp_f16vecs(PtrH, Ptry, Ptry2, N_RX, N_TX); + mempool_cholesky_f16vecs(PtrG, PtrL, N_TX); +#else + mempool_hermitian_f16s(PtrH, PtrG, PtrSigma, N_RX, N_TX, 0, 0); + mempool_MVP_conjtransp_f16s(PtrH, Ptry, Ptry2, N_RX, N_TX, 0); + mempool_cholesky_f16s(PtrG, PtrL, N_TX); +#endif + mempool_Ltrisol_f16s(PtrL, Ptry2, Ptry3, N_TX); + mempool_Lttrisol_f16s(PtrL, Ptry3, Ptrx, N_TX); + } + mempool_stop_benchmark(); + } + mempool_barrier(num_cores); +#endif #ifdef SINGLE if (core_id == 0) { @@ -116,8 +163,18 @@ int main() { #endif // Check the result +#ifdef BANSHEE + if (core_id == 0) { + for (uint32_t i = 0; i < 2 * N_TX * N_ITR; i++) { + uint32_t x = (*(uint32_t *)&l1_x[i]) & 0x0000FFFF; + printf("RES=%04x\n", x); + } + } + mempool_barrier(num_cores); +#else mempool_check_f16(l1_x, l2_x, 2 * N_TX, 0.01f, 0); mempool_barrier(num_cores); +#endif return 0; } diff --git a/software/kernels/baremetal/mempool_cholesky_f16s.h b/software/kernels/baremetal/mempool_cholesky_f16s.h index bb6143ed7..ada17eb2a 100644 --- a/software/kernels/baremetal/mempool_cholesky_f16s.h +++ b/software/kernels/baremetal/mempool_cholesky_f16s.h @@ -20,7 +20,7 @@ */ void mempool_cholesky_f16s(__fp16 *pSrc, __fp16 *pL, const uint32_t n) { - register __fp16 sum; + __fp16 sum; __fp16 a, b; __fp16 c, d; __fp16 diag; // Diagonal element @@ -173,21 +173,22 @@ void mempool_cholesky_folded_f16s(__fp16 *pSrc, __fp16 *pL, const uint32_t n) { */ void mempool_cholesky_f16vecs(__fp16 *pSrc, __fp16 *pL, const uint32_t n) { - float sum; // register float sum - __fp16 diag; // Diagonal element - __fp16 ap; + float sum; // register float sum + v2h abp, ab, cd; + __fp16 diag, ap; +#ifndef __CDOTP float as, bs; - v2h abp, ab, cd, ndc; - v2h vec_sum; - v2h vec_diag; + v2h ndc, vec_sum; +#endif + v2h vec_diag; uint32_t i, j, k; for (j = 0; j < n; j++) { // Elements on diagonal (input matrix is positive-definite) ap = pSrc[2U * (j * n + j)]; - sum = (float)0.0f; + sum = 0.0f; for (k = 0; k < j; k++) { ab = (*(v2h *)&pL[2U * (j * n + k)]); asm volatile("vfdotpex.s.h %[sum], %[ab], %[ab];" @@ -209,6 +210,22 @@ void mempool_cholesky_f16vecs(__fp16 *pSrc, __fp16 *pL, const uint32_t n) { abp = (*(v2h *)&pSrc[2U * (i * n + j)]); // Diag diag = pL[2U * (j * n + j)]; + +#ifdef __CDOTP + for (k = 0; k < j; k++) { + ab = (*(v2h *)&pL[2U * (i * n + k)]); + cd = (*(v2h *)&pL[2U * (j * n + k)]); + asm volatile("fcndotpex.s.h %[abp],%[ab],%[cd];" + : [abp] "=&r"(abp) + : [ab] "r"(ab), [cd] "r"(cd) + :); + } + asm volatile("pv.pack %[vec_diag], %[diag], %[diag];" + "vfdiv.h %[abp], %[abp], %[vec_diag];" + : [abp] "+&r"(abp), [vec_diag] "=&r"(vec_diag) + : [diag] "r"(diag) + :); +#else // Sum -> s = s + (ac + bd) + j*(bc - ad) as = (float)0.0f; bs = (float)0.0f; @@ -221,7 +238,7 @@ void mempool_cholesky_f16vecs(__fp16 *pSrc, __fp16 *pL, const uint32_t n) { // s = s + (ac + bd) + j(bc - ad) "vfdotpex.s.h %[as], %[ab], %[cd];" "pv.shuffle2.h %[ndc], %[cd], %[shuffle_mask];" - "xor %[ndc], %[neg_mask], %[ndc];" + "xor %[ndc], %[neg_mask], %[ndc];" "vfdotpex.s.h %[bs], %[ab], %[ndc];" : [as] "+&r"(as), [bs] "+&r"(bs), [ndc] "+r"(ndc) : [ab] "r"(ab), [cd] "r"(cd), [neg_mask] "r"(neg_mask), @@ -229,13 +246,14 @@ void mempool_cholesky_f16vecs(__fp16 *pSrc, __fp16 *pL, const uint32_t n) { :); } asm volatile("vfcpka.h.s %[vec_sum], %[as], %[bs];" - "pv.pack.h %[vec_diag], %[diag], %[diag];" + "pv.pack %[vec_diag], %[diag], %[diag];" "vfsub.h %[abp], %[abp], %[vec_sum];" "vfdiv.h %[abp], %[abp], %[vec_diag];" : [abp] "+&r"(abp), [vec_sum] "+&r"(vec_sum), [vec_diag] "+&r"(vec_diag) : [as] "r"(as), [bs] "r"(bs), [diag] "r"(diag) :); +#endif (*(v2h *)&pL[2U * (i * n + j)]) = abp; } } diff --git a/software/kernels/baremetal/mempool_mimo_mmse_f16s.h b/software/kernels/baremetal/mempool_mimo_mmse_f16s.h index fef02faaf..9f8327703 100644 --- a/software/kernels/baremetal/mempool_mimo_mmse_f16s.h +++ b/software/kernels/baremetal/mempool_mimo_mmse_f16s.h @@ -282,10 +282,15 @@ void mempool_hermitian_f16vecs(__fp16 *pH, __fp16 *pG, __fp16 *pS, uint32_t i, j, k; v2h ab; v2h cd0, cd1, cd2, cd3; + const uint32_t shuffle_mask = 0x00020003; +#ifndef __CDOTP float as0, as1, as2, as3; float bs0, bs1, bs2, bs3; const uint32_t neg_mask = 0x80000000; - const uint32_t shuffle_mask = 0x00020003; +#endif + +#ifndef __CDOTP + for (i = 0; i < n_tx; i++) { if (n_tx == 1) { @@ -433,6 +438,64 @@ void mempool_hermitian_f16vecs(__fp16 *pH, __fp16 *pG, __fp16 *pS, } } } + +#else + + for (i = 0; i < n_tx; i++) { + if (n_tx >= 4) { + // UNROLL_4 + for (j = 0; j < n_tx; j += 4) { + v2h res0 = (v2h)0.0f; + v2h res1 = (v2h)0.0f; + v2h res2 = (v2h)0.0f; + v2h res3 = (v2h)0.0f; + for (k = 0; k < n_rx; k++) { + ab = (*(v2h *)&pH[2U * (k * n_tx + i)]); + cd0 = (*(v2h *)&pH[2U * (k * n_tx + j)]); + cd1 = (*(v2h *)&pH[2U * (k * n_tx + j + 1U)]); + cd2 = (*(v2h *)&pH[2U * (k * n_tx + j + 2U)]); + cd3 = (*(v2h *)&pH[2U * (k * n_tx + j + 3U)]); + asm volatile("fccdotpex.s.h %[res0], %[ab], %[cd0];" + "fccdotpex.s.h %[res1], %[ab], %[cd1];" + "fccdotpex.s.h %[res2], %[ab], %[cd2];" + "fccdotpex.s.h %[res3], %[ab], %[cd3];" + : [res0] "+&r"(res0), [res1] "+&r"(res1), + [res2] "+&r"(res2), [res3] "+&r"(res3) + : [cd0] "r"(cd0), [cd1] "r"(cd1), [cd2] "r"(cd2), + [cd3] "r"(cd3), [ab] "r"(ab) + :); + } + asm volatile("pv.shuffle2.h %[res0], %[res0], %[shuffle_mask];" + "pv.shuffle2.h %[res1], %[res1], %[shuffle_mask];" + "pv.shuffle2.h %[res2], %[res2], %[shuffle_mask];" + "pv.shuffle2.h %[res3], %[res3], %[shuffle_mask];" + : [res0] "+&r"(res0), [res1] "+&r"(res1), + [res2] "+&r"(res2), [res3] "+&r"(res3) + : [shuffle_mask] "r"(shuffle_mask) + :); + __fp16 sigma = pS[2 * i]; + if (i == j) { + asm volatile("and %0, %0, %1;" : "+&r"(res0) : "r"(0x0000FFFF)); + asm volatile("fadd.h %0, %0, %1;" : "+&r"(res0) : "r"(sigma)); + } else if (i == (j + 1U)) { + asm volatile("and %0, %0, %1;" : "+&r"(res1) : "r"(0x0000FFFF)); + asm volatile("fadd.h %0, %0, %1;" : "+&r"(res1) : "r"(sigma)); + } else if (i == (j + 2U)) { + asm volatile("and %0, %0, %1;" : "+&r"(res2) : "r"(0x0000FFFF)); + asm volatile("fadd.h %0, %0, %1;" : "+&r"(res2) : "r"(sigma)); + } else if (i == (j + 3U)) { + asm volatile("and %0, %0, %1;" : "+&r"(res3) : "r"(0x0000FFFF)); + asm volatile("fadd.h %0, %0, %1;" : "+&r"(res3) : "r"(sigma)); + } + (*(v2h *)&pG[2 * (i * n_tx + j)]) = res0; + (*(v2h *)&pG[2 * (i * n_tx + j + 1U)]) = res1; + (*(v2h *)&pG[2 * (i * n_tx + j + 2U)]) = res2; + (*(v2h *)&pG[2 * (i * n_tx + j + 3U)]) = res3; + } + } + } + +#endif return; } @@ -450,14 +513,19 @@ void mempool_MVP_conjtransp_f16vecs(__fp16 *pH, __fp16 *px, __fp16 *py, const uint32_t n_rx, const uint32_t n_tx) { uint32_t i, j; - float as0, as1, as2, as3; - float bs0, bs1, bs2, bs3; + v2h res0, res1, res2, res3; v2h ab0, ab1, ab2, ab3; v2h cd; - uint32_t ndc; - const uint32_t neg_mask = 0x80000000; const uint32_t shuffle_mask = 0x00020003; +#ifndef __CDOTP + float as0, as1, as2, as3; + float bs0, bs1, bs2, bs3; + const uint32_t neg_mask = 0x80000000; + uint32_t ndc; +#endif + +#ifndef __CDOTP if (n_tx < 4) { for (i = 0; i < n_tx; i++) { as0 = 0.0f; // Initialize the real part of sums @@ -478,7 +546,6 @@ void mempool_MVP_conjtransp_f16vecs(__fp16 *pH, __fp16 *px, __fp16 *py, [shuffle_mask] "r"(shuffle_mask), [ab0] "r"(ab0) :); } - v2h res0; asm volatile("vfcpka.h.s %0, %1, %2;" : "=&r"(res0) : "r"(as0), "r"(bs0) @@ -521,7 +588,6 @@ void mempool_MVP_conjtransp_f16vecs(__fp16 *pH, __fp16 *px, __fp16 *py, [ab1] "r"(ab1), [ab2] "r"(ab2), [ab3] "r"(ab3) :); } - v2h res0, res1, res2, res3; asm volatile( "vfcpka.h.s %[res0], %[as0], %[bs0];" "vfcpka.h.s %[res1], %[as1], %[bs1];" @@ -538,5 +604,44 @@ void mempool_MVP_conjtransp_f16vecs(__fp16 *pH, __fp16 *px, __fp16 *py, *(v2h *)&py[2U * (i + 3U)] = res3; } } +#else + if (n_tx >= 4) { + // UNROLL_4 + for (i = 0; i < n_tx; i += 4) { + res0 = (v2h)0.0f; + res1 = (v2h)0.0f; + res2 = (v2h)0.0f; + res3 = (v2h)0.0f; + for (j = 0; j < n_rx; j++) { + ab0 = *(v2h *)&pH[2U * (j * n_tx + i)]; + ab1 = *(v2h *)&pH[2U * (j * n_tx + i + 1U)]; + ab2 = *(v2h *)&pH[2U * (j * n_tx + i + 2U)]; + ab3 = *(v2h *)&pH[2U * (j * n_tx + i + 3U)]; + cd = *(v2h *)&px[2U * j]; + asm volatile("fccdotpex.s.h %[res0], %[ab0], %[cd];" + "fccdotpex.s.h %[res1], %[ab1], %[cd];" + "fccdotpex.s.h %[res2], %[ab2], %[cd];" + "fccdotpex.s.h %[res3], %[ab3], %[cd];" + : [res0] "+&r"(res0), [res1] "+&r"(res1), + [res2] "+&r"(res2), [res3] "+&r"(res3) + : [ab0] "r"(ab0), [ab1] "r"(ab1), [ab2] "r"(ab2), + [ab3] "r"(ab3), [cd] "r"(cd) + :); + } + asm volatile("pv.shuffle2.h %[res0], %[res0], %[shuffle_mask];" + "pv.shuffle2.h %[res1], %[res1], %[shuffle_mask];" + "pv.shuffle2.h %[res2], %[res2], %[shuffle_mask];" + "pv.shuffle2.h %[res3], %[res3], %[shuffle_mask];" + : [res0] "+&r"(res0), [res1] "+&r"(res1), [res2] "+&r"(res2), + [res3] "+&r"(res3) + : [shuffle_mask] "r"(shuffle_mask) + :); + *(v2h *)&py[2U * i] = res0; + *(v2h *)&py[2U * (i + 1U)] = res1; + *(v2h *)&py[2U * (i + 2U)] = res2; + *(v2h *)&py[2U * (i + 3U)] = res3; + } + } +#endif return; }