Skip to content

Commit

Permalink
[software] Add MMSE with complex dotp
Browse files Browse the repository at this point in the history
  • Loading branch information
mbertuletti committed Sep 20, 2024
1 parent a1bc514 commit d05ce74
Show file tree
Hide file tree
Showing 3 changed files with 199 additions and 19 deletions.
61 changes: 59 additions & 2 deletions software/apps/baremetal/mimo_mmse_f16/main.c
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@

#include "data_mimo_mmse_f16.h"
#define NUM_BANKS (BANKING_FACTOR * NUM_CORES)
#define DOUBLE_BUFFERING
//#define DOUBLE_BUFFERING

/**********************************************************
**********************************************************
Expand All @@ -33,7 +33,6 @@
***********************************************************/

#ifndef DOUBLE_BUFFERING
#define PARALLEL

__fp16 l1_H[2 * N_TX * N_RX * N_ITR]
__attribute__((aligned(NUM_BANKS * sizeof(int32_t)), section(".l1_prio")));
Expand Down Expand Up @@ -61,6 +60,22 @@ int main() {
uint32_t core_id = mempool_get_core_id();
uint32_t num_cores = mempool_get_core_count();
mempool_barrier_init(core_id); // Initialize barrier and synchronize

#ifdef BANSHEE
/* Initialize matrices */
if (core_id == 0) {
for (uint32_t i = 0; i < N_RX * N_TX * N_ITR; i++) {
(*(uint32_t *)&l1_H[2 * i]) = *(uint32_t *)&l2_H[2 * i];
}
for (uint32_t i = 0; i < N_RX * N_ITR; i++) {
(*(uint32_t *)&l1_y[2 * i]) = *(uint32_t *)&l2_y[2 * i];
}
for (uint32_t i = 0; i < N_TX * N_ITR; i++) {
(*(uint32_t *)&l1_S[2 * i]) = *(uint32_t *)&l2_Sigma[2 * i];
}
}
mempool_barrier(num_cores);
#else
/* Initialize matrices */
if (core_id == 0) {
dma_memcpy_blocking(l1_beamgroups, l2_beamgroups, N_ITR * sizeof(int32_t));
Expand All @@ -69,6 +84,38 @@ int main() {
dma_memcpy_blocking(l1_S, l2_Sigma, N_TX * N_ITR * sizeof(int32_t));
}
mempool_barrier(num_cores);
#endif

#ifdef BANSHEE
/* Benchmark */
if (core_id == 0) {
mempool_start_benchmark();
for (uint32_t itr = 0; itr < N_ITR; itr++) {
__fp16 *PtrH = l1_H + itr * (2 * N_TX * N_RX);
__fp16 *Ptry = l1_y + itr * (2 * N_RX);
__fp16 *PtrSigma = l1_S + itr * N_TX;
__fp16 *PtrG = l1_G + itr * (2 * N_TX * N_TX);
__fp16 *PtrL = l1_L + itr * (2 * N_TX * N_TX);
__fp16 *Ptry2 = y2 + itr * (2 * N_TX);
__fp16 *Ptry3 = y3 + itr * (2 * N_TX);
__fp16 *Ptrx = l1_x + itr * (2 * N_TX);

#ifdef VEC
mempool_hermitian_f16vecs(PtrH, PtrG, PtrSigma, N_RX, N_TX);
mempool_MVP_conjtransp_f16vecs(PtrH, Ptry, Ptry2, N_RX, N_TX);
mempool_cholesky_f16vecs(PtrG, PtrL, N_TX);
#else
mempool_hermitian_f16s(PtrH, PtrG, PtrSigma, N_RX, N_TX, 0, 0);
mempool_MVP_conjtransp_f16s(PtrH, Ptry, Ptry2, N_RX, N_TX, 0);
mempool_cholesky_f16s(PtrG, PtrL, N_TX);
#endif
mempool_Ltrisol_f16s(PtrL, Ptry2, Ptry3, N_TX);
mempool_Lttrisol_f16s(PtrL, Ptry3, Ptrx, N_TX);
}
mempool_stop_benchmark();
}
mempool_barrier(num_cores);
#endif

#ifdef SINGLE
if (core_id == 0) {
Expand Down Expand Up @@ -116,8 +163,18 @@ int main() {
#endif

// Check the result
#ifdef BANSHEE
if (core_id == 0) {
for (uint32_t i = 0; i < 2 * N_TX * N_ITR; i++) {
uint32_t x = (*(uint32_t *)&l1_x[i]) & 0x0000FFFF;
printf("RES=%04x\n", x);
}
}
mempool_barrier(num_cores);
#else
mempool_check_f16(l1_x, l2_x, 2 * N_TX, 0.01f, 0);
mempool_barrier(num_cores);
#endif
return 0;
}

Expand Down
38 changes: 28 additions & 10 deletions software/kernels/baremetal/mempool_cholesky_f16s.h
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@
*/
void mempool_cholesky_f16s(__fp16 *pSrc, __fp16 *pL, const uint32_t n) {

register __fp16 sum;
__fp16 sum;
__fp16 a, b;
__fp16 c, d;
__fp16 diag; // Diagonal element
Expand Down Expand Up @@ -173,21 +173,22 @@ void mempool_cholesky_folded_f16s(__fp16 *pSrc, __fp16 *pL, const uint32_t n) {
*/
void mempool_cholesky_f16vecs(__fp16 *pSrc, __fp16 *pL, const uint32_t n) {

float sum; // register float sum
__fp16 diag; // Diagonal element
__fp16 ap;
float sum; // register float sum
v2h abp, ab, cd;
__fp16 diag, ap;

#ifndef __CDOTP
float as, bs;
v2h abp, ab, cd, ndc;
v2h vec_sum;
v2h vec_diag;
v2h ndc, vec_sum;
#endif

v2h vec_diag;
uint32_t i, j, k;

for (j = 0; j < n; j++) {
// Elements on diagonal (input matrix is positive-definite)
ap = pSrc[2U * (j * n + j)];
sum = (float)0.0f;
sum = 0.0f;
for (k = 0; k < j; k++) {
ab = (*(v2h *)&pL[2U * (j * n + k)]);
asm volatile("vfdotpex.s.h %[sum], %[ab], %[ab];"
Expand All @@ -209,6 +210,22 @@ void mempool_cholesky_f16vecs(__fp16 *pSrc, __fp16 *pL, const uint32_t n) {
abp = (*(v2h *)&pSrc[2U * (i * n + j)]);
// Diag
diag = pL[2U * (j * n + j)];

#ifdef __CDOTP
for (k = 0; k < j; k++) {
ab = (*(v2h *)&pL[2U * (i * n + k)]);
cd = (*(v2h *)&pL[2U * (j * n + k)]);
asm volatile("fcndotpex.s.h %[abp],%[ab],%[cd];"
: [abp] "=&r"(abp)
: [ab] "r"(ab), [cd] "r"(cd)
:);
}
asm volatile("pv.pack %[vec_diag], %[diag], %[diag];"
"vfdiv.h %[abp], %[abp], %[vec_diag];"
: [abp] "+&r"(abp), [vec_diag] "=&r"(vec_diag)
: [diag] "r"(diag)
:);
#else
// Sum -> s = s + (ac + bd) + j*(bc - ad)
as = (float)0.0f;
bs = (float)0.0f;
Expand All @@ -221,21 +238,22 @@ void mempool_cholesky_f16vecs(__fp16 *pSrc, __fp16 *pL, const uint32_t n) {
// s = s + (ac + bd) + j(bc - ad)
"vfdotpex.s.h %[as], %[ab], %[cd];"
"pv.shuffle2.h %[ndc], %[cd], %[shuffle_mask];"
"xor %[ndc], %[neg_mask], %[ndc];"
"xor %[ndc], %[neg_mask], %[ndc];"
"vfdotpex.s.h %[bs], %[ab], %[ndc];"
: [as] "+&r"(as), [bs] "+&r"(bs), [ndc] "+r"(ndc)
: [ab] "r"(ab), [cd] "r"(cd), [neg_mask] "r"(neg_mask),
[shuffle_mask] "r"(shuffle_mask)
:);
}
asm volatile("vfcpka.h.s %[vec_sum], %[as], %[bs];"
"pv.pack.h %[vec_diag], %[diag], %[diag];"
"pv.pack %[vec_diag], %[diag], %[diag];"
"vfsub.h %[abp], %[abp], %[vec_sum];"
"vfdiv.h %[abp], %[abp], %[vec_diag];"
: [abp] "+&r"(abp), [vec_sum] "+&r"(vec_sum),
[vec_diag] "+&r"(vec_diag)
: [as] "r"(as), [bs] "r"(bs), [diag] "r"(diag)
:);
#endif
(*(v2h *)&pL[2U * (i * n + j)]) = abp;
}
}
Expand Down
119 changes: 112 additions & 7 deletions software/kernels/baremetal/mempool_mimo_mmse_f16s.h
Original file line number Diff line number Diff line change
Expand Up @@ -282,10 +282,15 @@ void mempool_hermitian_f16vecs(__fp16 *pH, __fp16 *pG, __fp16 *pS,
uint32_t i, j, k;
v2h ab;
v2h cd0, cd1, cd2, cd3;
const uint32_t shuffle_mask = 0x00020003;
#ifndef __CDOTP
float as0, as1, as2, as3;
float bs0, bs1, bs2, bs3;
const uint32_t neg_mask = 0x80000000;
const uint32_t shuffle_mask = 0x00020003;
#endif

#ifndef __CDOTP

for (i = 0; i < n_tx; i++) {

if (n_tx == 1) {
Expand Down Expand Up @@ -433,6 +438,64 @@ void mempool_hermitian_f16vecs(__fp16 *pH, __fp16 *pG, __fp16 *pS,
}
}
}

#else

for (i = 0; i < n_tx; i++) {
if (n_tx >= 4) {
// UNROLL_4
for (j = 0; j < n_tx; j += 4) {
v2h res0 = (v2h)0.0f;
v2h res1 = (v2h)0.0f;
v2h res2 = (v2h)0.0f;
v2h res3 = (v2h)0.0f;
for (k = 0; k < n_rx; k++) {
ab = (*(v2h *)&pH[2U * (k * n_tx + i)]);
cd0 = (*(v2h *)&pH[2U * (k * n_tx + j)]);
cd1 = (*(v2h *)&pH[2U * (k * n_tx + j + 1U)]);
cd2 = (*(v2h *)&pH[2U * (k * n_tx + j + 2U)]);
cd3 = (*(v2h *)&pH[2U * (k * n_tx + j + 3U)]);
asm volatile("fccdotpex.s.h %[res0], %[ab], %[cd0];"
"fccdotpex.s.h %[res1], %[ab], %[cd1];"
"fccdotpex.s.h %[res2], %[ab], %[cd2];"
"fccdotpex.s.h %[res3], %[ab], %[cd3];"
: [res0] "+&r"(res0), [res1] "+&r"(res1),
[res2] "+&r"(res2), [res3] "+&r"(res3)
: [cd0] "r"(cd0), [cd1] "r"(cd1), [cd2] "r"(cd2),
[cd3] "r"(cd3), [ab] "r"(ab)
:);
}
asm volatile("pv.shuffle2.h %[res0], %[res0], %[shuffle_mask];"
"pv.shuffle2.h %[res1], %[res1], %[shuffle_mask];"
"pv.shuffle2.h %[res2], %[res2], %[shuffle_mask];"
"pv.shuffle2.h %[res3], %[res3], %[shuffle_mask];"
: [res0] "+&r"(res0), [res1] "+&r"(res1),
[res2] "+&r"(res2), [res3] "+&r"(res3)
: [shuffle_mask] "r"(shuffle_mask)
:);
__fp16 sigma = pS[2 * i];
if (i == j) {
asm volatile("and %0, %0, %1;" : "+&r"(res0) : "r"(0x0000FFFF));
asm volatile("fadd.h %0, %0, %1;" : "+&r"(res0) : "r"(sigma));
} else if (i == (j + 1U)) {
asm volatile("and %0, %0, %1;" : "+&r"(res1) : "r"(0x0000FFFF));
asm volatile("fadd.h %0, %0, %1;" : "+&r"(res1) : "r"(sigma));
} else if (i == (j + 2U)) {
asm volatile("and %0, %0, %1;" : "+&r"(res2) : "r"(0x0000FFFF));
asm volatile("fadd.h %0, %0, %1;" : "+&r"(res2) : "r"(sigma));
} else if (i == (j + 3U)) {
asm volatile("and %0, %0, %1;" : "+&r"(res3) : "r"(0x0000FFFF));
asm volatile("fadd.h %0, %0, %1;" : "+&r"(res3) : "r"(sigma));
}
(*(v2h *)&pG[2 * (i * n_tx + j)]) = res0;
(*(v2h *)&pG[2 * (i * n_tx + j + 1U)]) = res1;
(*(v2h *)&pG[2 * (i * n_tx + j + 2U)]) = res2;
(*(v2h *)&pG[2 * (i * n_tx + j + 3U)]) = res3;
}
}
}

#endif
return;
}

Expand All @@ -450,14 +513,19 @@ void mempool_MVP_conjtransp_f16vecs(__fp16 *pH, __fp16 *px, __fp16 *py,
const uint32_t n_rx, const uint32_t n_tx) {

uint32_t i, j;
float as0, as1, as2, as3;
float bs0, bs1, bs2, bs3;
v2h res0, res1, res2, res3;
v2h ab0, ab1, ab2, ab3;
v2h cd;
uint32_t ndc;
const uint32_t neg_mask = 0x80000000;
const uint32_t shuffle_mask = 0x00020003;

#ifndef __CDOTP
float as0, as1, as2, as3;
float bs0, bs1, bs2, bs3;
const uint32_t neg_mask = 0x80000000;
uint32_t ndc;
#endif

#ifndef __CDOTP
if (n_tx < 4) {
for (i = 0; i < n_tx; i++) {
as0 = 0.0f; // Initialize the real part of sums
Expand All @@ -478,7 +546,6 @@ void mempool_MVP_conjtransp_f16vecs(__fp16 *pH, __fp16 *px, __fp16 *py,
[shuffle_mask] "r"(shuffle_mask), [ab0] "r"(ab0)
:);
}
v2h res0;
asm volatile("vfcpka.h.s %0, %1, %2;"
: "=&r"(res0)
: "r"(as0), "r"(bs0)
Expand Down Expand Up @@ -521,7 +588,6 @@ void mempool_MVP_conjtransp_f16vecs(__fp16 *pH, __fp16 *px, __fp16 *py,
[ab1] "r"(ab1), [ab2] "r"(ab2), [ab3] "r"(ab3)
:);
}
v2h res0, res1, res2, res3;
asm volatile(
"vfcpka.h.s %[res0], %[as0], %[bs0];"
"vfcpka.h.s %[res1], %[as1], %[bs1];"
Expand All @@ -538,5 +604,44 @@ void mempool_MVP_conjtransp_f16vecs(__fp16 *pH, __fp16 *px, __fp16 *py,
*(v2h *)&py[2U * (i + 3U)] = res3;
}
}
#else
if (n_tx >= 4) {
// UNROLL_4
for (i = 0; i < n_tx; i += 4) {
res0 = (v2h)0.0f;
res1 = (v2h)0.0f;
res2 = (v2h)0.0f;
res3 = (v2h)0.0f;
for (j = 0; j < n_rx; j++) {
ab0 = *(v2h *)&pH[2U * (j * n_tx + i)];
ab1 = *(v2h *)&pH[2U * (j * n_tx + i + 1U)];
ab2 = *(v2h *)&pH[2U * (j * n_tx + i + 2U)];
ab3 = *(v2h *)&pH[2U * (j * n_tx + i + 3U)];
cd = *(v2h *)&px[2U * j];
asm volatile("fccdotpex.s.h %[res0], %[ab0], %[cd];"
"fccdotpex.s.h %[res1], %[ab1], %[cd];"
"fccdotpex.s.h %[res2], %[ab2], %[cd];"
"fccdotpex.s.h %[res3], %[ab3], %[cd];"
: [res0] "+&r"(res0), [res1] "+&r"(res1),
[res2] "+&r"(res2), [res3] "+&r"(res3)
: [ab0] "r"(ab0), [ab1] "r"(ab1), [ab2] "r"(ab2),
[ab3] "r"(ab3), [cd] "r"(cd)
:);
}
asm volatile("pv.shuffle2.h %[res0], %[res0], %[shuffle_mask];"
"pv.shuffle2.h %[res1], %[res1], %[shuffle_mask];"
"pv.shuffle2.h %[res2], %[res2], %[shuffle_mask];"
"pv.shuffle2.h %[res3], %[res3], %[shuffle_mask];"
: [res0] "+&r"(res0), [res1] "+&r"(res1), [res2] "+&r"(res2),
[res3] "+&r"(res3)
: [shuffle_mask] "r"(shuffle_mask)
:);
*(v2h *)&py[2U * i] = res0;
*(v2h *)&py[2U * (i + 1U)] = res1;
*(v2h *)&py[2U * (i + 2U)] = res2;
*(v2h *)&py[2U * (i + 3U)] = res3;
}
}
#endif
return;
}

0 comments on commit d05ce74

Please sign in to comment.