Skip to content

Commit

Permalink
delete gemm rt and update hw cfg and sw (#51)
Browse files Browse the repository at this point in the history
* delete gemm rt and update hw cfg and sw

* update sw

* fix gemmx src
  • Loading branch information
xiaoling-yi authored Oct 8, 2024
1 parent 37db678 commit 8b9eb25
Show file tree
Hide file tree
Showing 11 changed files with 240 additions and 439 deletions.
3 changes: 3 additions & 0 deletions target/rtl/cfg/cluster_cfg/snax_KUL_xdma_cluster.hjson
Original file line number Diff line number Diff line change
Expand Up @@ -192,8 +192,11 @@
temporal_dim: [3, 3],
num_channel: [32, 32],
fifo_depth: [2, 2],
configurable_channel: [1, 0],
},

has_transpose: true,

snax_library_name: "gemmx",
}
}
95 changes: 77 additions & 18 deletions target/sim/sw/device/apps/snax/snax-gemmx/data/datagen.py
Original file line number Diff line number Diff line change
Expand Up @@ -732,13 +732,26 @@ def emit_matmul_data(**kwargs):
MIN, MAX, size=(kwargs["M"], kwargs["K"], meshRow, tileSize)
).reshape(-1)
data_str += [format_vector_definition("int8_t", "A", A)]

B = np.random.randint(
MIN, MAX, size=(kwargs["K"], kwargs["N"], tileSize, meshCol)
).reshape(-1)
data_str += [format_vector_definition("int8_t", "B", B)]
C = np.random.randint(
MIN, MAX, size=(kwargs["M"], kwargs["N"], meshRow, meshCol)
).reshape(-1)

if kwargs["channel_en_C"] == 1:
C = np.random.randint(
MIN, MAX, size=(kwargs["M"], kwargs["N"], meshRow, meshCol)
).reshape(-1)
else:
C = np.random.randint(
0, 1, size=(kwargs["M"], kwargs["N"], meshRow, meshCol)
).reshape(-1)
if kwargs["channel_en_C"] == 1:
data_str += [
format_scalar_definition("int32_t", "channel_en_C", ((1 << 32) - 1))
]
else:
data_str += [format_scalar_definition("int32_t", "channel_en_C", 0)]
data_str += [format_vector_definition("int32_t", "C", C)]

if kwargs["transposed_A"] == 1:
Expand Down Expand Up @@ -792,35 +805,81 @@ def emit_gemmx_data(**kwargs):
data_str += [format_scalar_definition("int32_t", "bypassSIMD", bypassSIMD)]

# Generating random constant values
group_num = 8
input_zp_i = np.random.randint(MIN, MAX)
output_zp_i = np.random.randint(MIN, MAX)
shift_i = np.random.randint(0, 63) # values between 0-63
max_int_i = MAX
min_int_i = MIN
double_round_i = np.random.randint(0, 1)
multiplier_i = np.random.randint(-(2**31), 2**31 - 1)

shift_i = np.random.randint(0, 63, size=group_num) # values between 0-63
multiplier_i = np.random.randint(-(2**31), 2**31 - 1, size=group_num)

# Writing the constant values to data.h
data_str += [
format_scalar_definition("int8_t", "input_zp_i", input_zp_i),
format_scalar_definition("int8_t", "output_zp_i", output_zp_i),
format_scalar_definition("int8_t", "shift_i", shift_i),
format_scalar_definition("int8_t", "max_int_i", max_int_i),
format_scalar_definition("int8_t", "min_int_i", min_int_i),
format_scalar_definition("int8_t", "double_round_i", double_round_i),
format_scalar_definition("int32_t", "multiplier_i", multiplier_i),
]

D8 = postprocessing_simd_golden_model(
D32,
input_zp_i,
output_zp_i,
shift_i,
max_int_i,
min_int_i,
double_round_i,
multiplier_i,
]

shared_bitpacked_shift0 = (
(shift_i[3] << 24) | (shift_i[2] << 16) | (shift_i[1] << 8) | shift_i[0]
)
shared_bitpacked_shift1 = (
(shift_i[7] << 24) | (shift_i[6] << 16) | (shift_i[5] << 8) | shift_i[4]
)
data_str += [
format_scalar_definition(
"int32_t", "shared_bitpacked_shift0", shared_bitpacked_shift0
)
]
data_str += [
format_scalar_definition(
"int32_t", "shared_bitpacked_shift1", shared_bitpacked_shift1
)
]

data_str += [
format_scalar_definition("int32_t", "shared_multiplier0", multiplier_i[0])
]
data_str += [
format_scalar_definition("int32_t", "shared_multiplier1", multiplier_i[1])
]
data_str += [
format_scalar_definition("int32_t", "shared_multiplier2", multiplier_i[2])
]
data_str += [
format_scalar_definition("int32_t", "shared_multiplier3", multiplier_i[3])
]
data_str += [
format_scalar_definition("int32_t", "shared_multiplier4", multiplier_i[4])
]
data_str += [
format_scalar_definition("int32_t", "shared_multiplier5", multiplier_i[5])
]
data_str += [
format_scalar_definition("int32_t", "shared_multiplier6", multiplier_i[6])
]
data_str += [
format_scalar_definition("int32_t", "shared_multiplier7", multiplier_i[7])
]

D8 = np.zeros_like(D32, dtype=np.uint8)
# output channel (innermost dim) has a different scale factor
for i in range(group_num):
D8[i::group_num] = postprocessing_simd_golden_model(
D32[i::group_num],
input_zp_i,
output_zp_i,
shift_i[i],
max_int_i,
min_int_i,
double_round_i,
multiplier_i[i],
)

data_str += [format_vector_definition("int8_t", "D8", D8)]

data_str = "\n\n".join(data_str)
Expand Down
43 changes: 32 additions & 11 deletions target/sim/sw/device/apps/snax/snax-gemmx/data/params.hjson
Original file line number Diff line number Diff line change
@@ -1,12 +1,33 @@
// Copyright 2024 KU Leuven.
// Licensed under the Apache License, Version 2.0, see LICENSE for details.
// SPDX-License-Identifier: Apache-2.0

// Cluster configuration for a simple testbench system.
{
K: 2
N: 2
M: 3
bypassSIMD: 1
ifTestMatmul: 1
transposed_A: 0
transposed_B: 0
meshRow: 8
meshCol: 8
tileSize: 8
}
ifC8HW8datalayout: true,
// ifC8HW8datalayout: false,
Nbatch: 1,
W: 8,
H: 2,
Cin: 16,
Cout: 16,
Kh: 3,
Kw: 3,
pad_h: 1,
pad_w: 1,
stride_h: 1,
stride_w: 1,
transposed_A: 1,
transposed_B: 0,
K: 18,
N: 2,
M: 1,
bypassSIMD: 0,
channel_en_C: 0,
ifTestMatmul: 1,

// hardware parameters
meshRow : 8,
tileSize: 8,
meshCol : 8,
}
65 changes: 40 additions & 25 deletions target/sim/sw/device/apps/snax/snax-gemmx/src/snax-gemmx.c
Original file line number Diff line number Diff line change
Expand Up @@ -17,13 +17,29 @@ int main() {
// Set err value for checking
int err = 0;

// Declaration of Variables
int32_t program_start;
int32_t cycle_start, cycle_end;

// Prepare addresses in TCDM
int8_t *local_a, *local_b;
int32_t *local_c, *local_d32;
int8_t *local_d8;

// Set GEMMX configuration CSR
uint32_t subtraction_setting =
gen_subtraction_config(subtraction_a, subtraction_b);

uint32_t csr0 =
gen_csr0_config(input_zp_i, output_zp_i, max_int_i, min_int_i);
uint32_t csr1 = gen_csr1_config(double_round_i);

// Set the size configuration
int32_t gemmx_cycles;
int32_t gemmx_streamer_cycles;

if (snrt_cluster_idx() == 1){
printf("SNAX GEMM Conv2d: Start\n");
// Prepare addresses in TCDM
int8_t *local_a, *local_b;
int32_t *local_c, *local_d32;
int8_t *local_d8;

// Allocate space in TCDM
local_a = (int8_t *)(snrt_l1_next() + delta_local_a);
local_b = (int8_t *)(snrt_l1_next() + delta_local_b);
Expand All @@ -33,9 +49,9 @@ int main() {

// Transfer data from L3 to L1
// Using DMA only
int32_t cycle_start = snrt_mcycle();
int32_t program_start = snrt_mcycle();
program_start = snrt_mcycle();
if (snrt_is_dm_core()) {
cycle_start = snrt_mcycle();
#ifdef TEST_MATMUL
snrt_dma_start_1d(local_a, A,
M * K * meshRow * tileSize * sizeof(int8_t));
Expand All @@ -50,13 +66,14 @@ int main() {
snrt_dma_start_1d(local_c, C,
M * N * meshRow * meshCol * sizeof(int32_t));
snrt_dma_wait_all();
cycle_end = snrt_mcycle();
printf("DMA cycles %d \n", cycle_end - cycle_start);
}

snrt_cluster_hw_barrier();
int32_t cycle_end = snrt_mcycle();
printf("DMA cycles %d \n", cycle_end - cycle_start);

if (snrt_global_core_idx() == 0) {
if (snrt_cluster_core_idx() == 0) {
printf("SNAX GEMM Conv2d: Start\n");
cycle_start = snrt_mcycle();
// Set Streamer configuration CSR for conv2d
set_gemmx_streamer_csr(
Expand All @@ -77,29 +94,27 @@ int main() {
D32tlstride1, D32tlbound2, D32tlstride2,

delta_local_a, delta_local_b, delta_local_d8, delta_local_c,
delta_local_d32, bypassSIMD, transposed_A, transposed_B);
delta_local_d32, bypassSIMD, transposed_A, transposed_B,
channel_en_C);

set_gemmx_csr(
K, N, M, subtraction_setting, csr0, csr1, shared_bitpacked_shift0,
shared_bitpacked_shift1, shared_multiplier0, shared_multiplier1,
shared_multiplier2, shared_multiplier3, shared_multiplier4,
shared_multiplier5, shared_multiplier6, shared_multiplier7, M * N,
bypassSIMD);

// Set CSR to start Streamer for conv2d
set_gemmx_streamer_start();

// Set GEMMX configuration CSR
uint32_t subtraction_setting =
gen_subtraction_config(subtraction_a, subtraction_b);

uint32_t csr0 =
gen_csr0_config(input_zp_i, output_zp_i, shift_i, max_int_i);
uint32_t csr1 = gen_csr1_config(min_int_i, double_round_i);
uint32_t csr2 = gen_csr2_config(multiplier_i);

set_gemmx_csr(K, N, M, subtraction_setting, csr0, csr1, csr2, M * N,
bypassSIMD);

// Set CSR to start GEMM
set_gemmx_start();

// Poll until Streamer and GEMM accelerator finish
wait_gemmx_and_streamer();

printf("SNAX GEMM Conv2d: Finish\n");

cycle_end = snrt_mcycle();

// check the result of the implicit im2col convolution
Expand All @@ -115,8 +130,8 @@ int main() {
printf("SNAX GEMM Conv2d: %s, Error: %d . bypassSIMD = %d .\n",
err ? "FAIL" : "PASS", err, bypassSIMD);
#endif
int32_t gemmx_cycles = read_gemmx_perf_counter();
int32_t gemmx_streamer_cycles = read_gemmx_streamer_perf_counter();
gemmx_cycles = read_gemmx_perf_counter();
gemmx_streamer_cycles = read_gemmx_streamer_perf_counter();
printf("Workload size: M = %d x N = %d x K = %d\n", M, N, K);
printf("SNAX GEMM cycles: %d\n", gemmx_cycles);
printf("SNAX GEMM Streamer cycles: %d\n", gemmx_streamer_cycles);
Expand Down
32 changes: 0 additions & 32 deletions target/sim/sw/device/snax/gemm/Makefile

This file was deleted.

59 changes: 0 additions & 59 deletions target/sim/sw/device/snax/gemm/include/snax-gemm-lib.h

This file was deleted.

Loading

0 comments on commit 8b9eb25

Please sign in to comment.