Skip to content

Commit

Permalink
[software] Fix shape of arrays
Browse files Browse the repository at this point in the history
  • Loading branch information
mbertuletti committed Oct 28, 2024
1 parent b21ea0c commit 9752fea
Showing 1 changed file with 16 additions and 18 deletions.
34 changes: 16 additions & 18 deletions software/PHY_emulator/MMSE_BER.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,8 @@
import sys
import pyflexfloat as ff
import matplotlib.pyplot as plt
from scipy.linalg import solve_triangular


# __ __ ___ __ __ ___ _______ __
# | \/ |_ _| \/ |/ _ \ __|_ _\ \/ /
Expand Down Expand Up @@ -115,12 +117,10 @@ def mmse(x, H, y, N, my_type):
# MMSE estimator
H_h = H.conj().T
G = np.matmul(H_h, H) + N * np.eye(H.shape[1])
# G = np.matmul(H_h, H)
xhat = np.matmul(np.linalg.inv(G), np.dot(H_h, y))

# Type cast
xhat = xhat.real.astype(my_type) + 1j * xhat.imag.astype(my_type)
H = H.flatten()
H = H.flatten(order='C')
return N, H, y, x, xhat


Expand All @@ -129,26 +129,23 @@ def generate_mimo_transmission_f16(

# Create input vector
idx = np.random.randint(0, len(symbols), size=N_tx)
Es = np.mean(abs(symbols)**2)
x = symbols[idx]

# Generate channel and noise
if channel_type == 'rayleigh':
# Generate Rayleigh fading channel
scale = 1.0
H = np.sqrt(np.random.chisquare(2, [N_rx, N_tx])) + 1j * \
np.sqrt(np.random.chisquare(2, [N_rx, N_tx]))
Es = np.mean(abs(symbols)**2)
Eh = (np.linalg.norm(H, 'fro')**2) / N_rx
elif channel_type == 'random':
H = np.sqrt(0.5) * \
(np.random.normal(0, 1, [N_rx, N_tx]) + 1j *
np.random.normal(0, 1, [N_rx, N_tx]))
Es = np.mean(abs(symbols)**2)
Eh = (np.linalg.norm(H, 'fro')**2) / N_rx
else:
# Generate AWGN channel
H = np.eye(N_rx) + 1.j * np.zeros([N_rx, N_tx])
Es = np.mean(abs(symbols)**2)
Eh = 1

# Noise variance
Expand All @@ -172,6 +169,7 @@ def generate_mimo_transmission_f16(
# Golden model
x64 = np.column_stack((x64.real, x64.imag)).flatten()
xhat64 = np.column_stack((xhat64.real, xhat64.imag)).flatten()

output = {
"N16": N16,
"y16": y16,
Expand Down Expand Up @@ -222,8 +220,8 @@ def gen_data_header_file(outdir, my_type, **kwargs):
string += "#define N_ITR ({})\n".format(kwargs['N_itr'])
string += stringify_array(kwargs['H'].flatten(order='F'),
my_type, "l2_H")
string += stringify_array(kwargs['y'].flatten(), my_type, "l2_y")
string += stringify_array(kwargs['N'].flatten(), my_type, "l2_S")
string += stringify_array(kwargs['y'].flatten(order='F'), my_type, "l2_y")
string += stringify_array(kwargs['N'].flatten(order='F'), my_type, "l2_S")
f.write(string)


Expand All @@ -236,10 +234,10 @@ def banshee_call(banshee_dir: pathlib.Path.cwd(),
file_dir = os.path.dirname(os.path.realpath(__file__))
compile_app = "DEFINES=" + compiler_flag + " "
compile_app += "l1_bank_size=16384 config=terapool "
compile_app += "make COMPILER=llvm BIN_DIR={}/bin ".format(file_dir)
compile_app += "make COMPILER=llvm ".format(file_dir)
compile_app += "{} -C {}/apps/baremetal".format(app, software_dir)
banshee_arg = " --num-cores 1 --num-clusters 1 --configuration config/terapool.yaml"
banshee_app = " {}/bin/{}".format(file_dir, app)
banshee_app = " {}/bin/apps/baremetal/{}".format(software_dir, app)
run_banshee = "SNITCH_LOG=info cargo run --" + banshee_arg + banshee_app

# Compile code
Expand Down Expand Up @@ -360,7 +358,6 @@ def main():
# Arithmetic precisions + compiler flags
if run_banshee & (channel_type == "rayleigh"):
precisions = [['64b', ""],
['16b-MP', "\"-DSINGLE -DBANSHEE\""],
['16b-MP wDotp', "\"-DSINGLE -DBANSHEE -DVEC\""]]
vSNRdB = range(0, 40, 4)
vITR = np.concatenate([np.full(9, 1), np.full(1, 2)])
Expand Down Expand Up @@ -459,8 +456,9 @@ def main():
gen_data_header_file(DATA_DIR, '__fp16', **kwargs)
result = banshee_call(
BANSHEE_DIR, SOFTWARE_DIR, flag, "mimo_mmse_f16")
result = banshee_cast_output(result.stderr)
vxhat[iPrec + 1, :, :] = result.reshape(2 * N_tx, N_batch)
result_casted = banshee_cast_output(result.stderr)
vxhat[iPrec + 1, :, :] = result_casted.reshape(2 * N_tx, N_batch, order='F')

# ----------------------------------------------------------------

# ----------------------------------------------------------------
Expand Down Expand Up @@ -510,10 +508,10 @@ def main():
timestr = time.strftime("%Y%m%d_%H%M%S", current_local_time)
col_names = [precision[0] for precision in precisions]
row_names = [f"{value} dB" for value in vSNRdB]
df_ber = pd.DataFrame(vBER.reshape(-1, 1), columns=col_names, index=row_names)
df_evm = pd.DataFrame(vEVM.reshape(-1, 1), columns=col_names, index=row_names)
df_ber.to_excel(f"BER_{timestr}.odf", index=True, header=True, engine='odf')
df_evm.to_excel(f"EVM_{timestr}.odf", index=True, header=True, engine='odf')
df_ber = pd.DataFrame(np.transpose(vBER), columns=col_names, index=row_names)
df_evm = pd.DataFrame(np.transpose(vEVM), columns=col_names, index=row_names)
df_ber.to_excel(f"BER_{timestr}.ods", index=True, header=True, engine='odf')
df_evm.to_excel(f"EVM_{timestr}.ods", index=True, header=True, engine='odf')


# Plot output
Expand Down

0 comments on commit 9752fea

Please sign in to comment.