Add software test and golden model for subtraction (pulp-platform#53)
* sw: add test and golden model for subtraction

* lint: formatting datagen.py

* Apply suggestions from code review

Co-authored-by: Josse Van Delm <[email protected]>

* sw: format snax-gemm-lib.c

---------

Co-authored-by: Josse Van Delm <[email protected]>
xiaoling-yi and JosseVanDelm committed Dec 11, 2023
1 parent f3a4423 commit 27757d9
Showing 5 changed files with 79 additions and 26 deletions.
2 changes: 1 addition & 1 deletion Bender.yml
@@ -29,7 +29,7 @@ dependencies:
tech_cells_generic: { git: https://github.com/pulp-platform/tech_cells_generic, version: 0.2.11 }
riscv-dbg: { git: https://github.com/pulp-platform/riscv-dbg, version: 0.8.0 }
hwpe-mac-engine: { git: https://github.com/KULeuven-MICAS/hwpe-mac-engine.git, rev: 5d3b4525b665169fc8321c8a811f3c83ad3c72e8 }
snax-gemm: { git: https://github.com/KULeuven-MICAS/snax-gemm.git, rev: db87157ad619bdbfe317b48c9a2790bce156db6f }
snax-gemm: { git: https://github.com/KULeuven-MICAS/snax-gemm.git, rev: 7e04484633da28513bdbe96a73a86a8476f69da4 }

vendor_package:
- name: musl
25 changes: 23 additions & 2 deletions target/snitch_cluster/sw/apps/snax-block-gemm/data/datagen.py
@@ -22,7 +22,8 @@


# Golden model in python
def block_gemm_golden_model(m, k, n, row, size, col, a, b):
def block_gemm_golden_model(m, k, n, row, size, col, a, b,
subtraction_a, subtraction_b):
c = np.zeros(m * row * n * col, dtype=(np.int32))
for mm in range(m):
for nn in range(n):
@@ -48,7 +49,9 @@ def block_gemm_golden_model(m, k, n, row, size, col, a, b):
+ cc * size
+ ss
)
c[c_index] = c[c_index] + a[a_index] * b[b_index]
c[c_index] = c[c_index] + \
(a[a_index] - subtraction_a) * \
(b[b_index] - subtraction_b)
return c


@@ -115,6 +118,22 @@ def emit_gemm_data(**kwargs):
)
]

    # Generate random 8-bit integer subtraction values a and b
subtraction_a = np.random.randint(MIN, MAX)
subtraction_b = np.random.randint(MIN, MAX)

    # Write the subtraction values to data.h
data_str += [
format_scalar_definition(
"int8_t", "subtraction_a", subtraction_a
)
]
data_str += [
format_scalar_definition(
"int8_t", "subtraction_b", subtraction_b
)
]

# Generate random input matrices
length_a = (
kwargs["M"] * kwargs["K"] * kwargs["meshRow"] * kwargs["tileSize"]
@@ -135,6 +154,8 @@ def emit_gemm_data(**kwargs):
kwargs["meshCol"],
a,
b,
subtraction_a,
subtraction_b
)
c_init = np.zeros(c_golden.shape)
c_cpu = np.zeros(c_golden.shape)
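
For context, here is a minimal standalone sketch (not part of the commit) of how the updated golden model's zero-point handling can be sanity-checked: subtracting inside the model must give the same result as pre-subtracting the operands and passing zero offsets. The block/mesh sizes, the zero points, the length formula for b, and the assumption that datagen.py is importable from its own directory are illustrative only.

# Illustrative cross-check of block_gemm_golden_model's subtraction handling.
# Assumption: run from target/snitch_cluster/sw/apps/snax-block-gemm/data/
# so that datagen.py can be imported as a module.
import numpy as np
from datagen import block_gemm_golden_model

m, k, n = 2, 2, 2          # number of blocks per dimension (illustrative)
row, size, col = 4, 4, 4   # meshRow, tileSize, meshCol (illustrative)

a = np.random.randint(-128, 128, m * k * row * size).astype(np.int32)
b = np.random.randint(-128, 128, k * n * size * col).astype(np.int32)
sub_a, sub_b = 3, -5       # example subtraction (zero-point) values

c_sub = block_gemm_golden_model(m, k, n, row, size, col, a, b, sub_a, sub_b)
c_ref = block_gemm_golden_model(m, k, n, row, size, col,
                                a - sub_a, b - sub_b, 0, 0)
assert (c_sub == c_ref).all()

The hunks below are from the snax-block-gemm test's main(), which drives both the accelerator and the CPU golden model with the new subtraction values.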
@@ -50,12 +50,16 @@ int main() {
// Pack matrix size setting to one CSR
uint32_t size_setting = gen_size_config(Batch, M, K, N);

uint32_t subtraction_setting =
gen_subtraction_config(subtraction_a, subtraction_b);

uint32_t gemm_start = snrt_mcycle();

// Set GEMM configuration CSR
set_batch_gemm(size_setting, local_a, local_b, local_c,
strideInnermostA, strideInnermostB, strideInnermostC,
ldA, ldB, ldC, strideA, strideB, strideC);
set_batch_gemm(size_setting, local_a, local_b, subtraction_setting,
local_c, strideInnermostA, strideInnermostB,
strideInnermostC, ldA, ldB, ldC, strideA, strideB,
strideC);

// Set CSR to start GEMM and poll until GEMM accelerator finishes
start_batch_gemm();
@@ -77,9 +81,10 @@ int main() {
// region for benchmarking)
uint32_t start_cycle = snrt_mcycle();

batch_gemm_cpu(Batch, M, K, N, local_a, local_b, C_cpu,
strideInnermostA, strideInnermostB, strideInnermostC,
ldA, ldB, ldC, strideA, strideB, strideC);
batch_gemm_cpu(Batch, M, K, N, local_a, local_b, subtraction_a,
subtraction_b, C_cpu, strideInnermostA, strideInnermostB,
strideInnermostC, ldA, ldB, ldC, strideA, strideB,
strideC);

// Read the mcycle CSR
uint32_t end_cycle = snrt_mcycle();
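
The hunk above stops before the result check. Below is a hypothetical sketch of how the accelerator output in local_c could be compared against the CPU golden model in C_cpu; the helper name, the error-count convention, and the total element count (Batch * M * N * meshRow * meshCol, the size of the C matrix used by batch_gemm_cpu) are assumptions, not code from this commit.

// Hypothetical result check (not part of the diff): count elements where the
// accelerator output differs from the CPU golden model.
#include <stdint.h>

uint32_t check_result(const int32_t* local_c, const int32_t* C_cpu,
                      uint32_t total_elements) {
    uint32_t err = 0;
    for (uint32_t i = 0; i < total_elements; i++) {
        if (local_c[i] != C_cpu[i]) {
            err++;
        }
    }
    return err;
}

// Intended usage inside main(), after wait_batch_gemm() and batch_gemm_cpu():
//   uint32_t err = check_result(local_c, C_cpu,
//                               Batch * M * N * meshRow * meshCol);
//   return err;  // a non-zero exit code marks the test as failed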
21 changes: 15 additions & 6 deletions target/snitch_cluster/sw/snax/gemm/include/snax-gemm-lib.h
@@ -13,13 +13,21 @@
// Pack matrix size setting to one CSR
int32_t gen_size_config(uint8_t Batch, uint8_t M, uint8_t K, uint8_t N);

// Pack two subtraction values to one CSR
int32_t gen_subtraction_config(int8_t subtraction_a, int8_t subtraction_b);

// Performance counter for GEMM busy cycles
uint32_t read_performance_counter();

// Golden model for base gemm
void base_gemm(uint8_t m, uint8_t k, uint8_t n, int8_t* A, int8_t* B,
int32_t* C_cpu, bool new_batch);
int8_t subtraction_a, int8_t subtraction_b, int32_t* C_cpu,
bool new_batch);

// Golden model for batch gemm
void batch_gemm_cpu(uint8_t Batch, uint8_t M, uint8_t K, uint8_t N, int8_t* A,
int8_t* B, int32_t* C, uint32_t strideInnermostA,
int8_t* B, int8_t subtraction_a, int8_t subtraction_b,
int32_t* C, uint32_t strideInnermostA,
uint32_t strideInnermostB, uint32_t strideInnermostC,
uint32_t ldA, uint32_t ldB, uint32_t ldC, uint32_t strideA,
uint32_t strideB, uint32_t strideC);
@@ -33,10 +41,11 @@ void load_input_data(uint8_t Batch, uint8_t M, uint8_t K, uint8_t N,

// Set GEMM configuration CSR
void set_batch_gemm(uint32_t size_setting, int8_t* local_a, int8_t* local_b,
int32_t* local_c, uint32_t strideInnermostA,
uint32_t strideInnermostB, uint32_t strideInnermostC,
uint32_t ldA, uint32_t ldB, uint32_t ldC, uint32_t strideA,
uint32_t strideB, uint32_t strideC);
int32_t subtractions, int32_t* local_c,
uint32_t strideInnermostA, uint32_t strideInnermostB,
uint32_t strideInnermostC, uint32_t ldA, uint32_t ldB,
uint32_t ldC, uint32_t strideA, uint32_t strideB,
uint32_t strideC);

// Set CSR to start GEMM
void start_batch_gemm();
40 changes: 29 additions & 11 deletions target/snitch_cluster/sw/snax/gemm/src/snax-gemm-lib.c
@@ -13,8 +13,19 @@ int32_t gen_size_config(uint8_t Batch, uint8_t M, uint8_t K, uint8_t N) {
(int32_t)N;
}

int32_t gen_subtraction_config(int8_t subtraction_a, int8_t subtraction_b) {
return ((uint8_t)subtraction_b << 8) | (uint8_t)subtraction_a;
}

uint32_t read_performance_counter() {
uint32_t performance_counter;
performance_counter = read_csr(0x3cd);
return performance_counter;
};

void base_gemm(uint8_t m, uint8_t k, uint8_t n, int8_t* A, int8_t* B,
int32_t* C_cpu, bool clear) {
int8_t subtraction_a, int8_t subtraction_b, int32_t* C_cpu,
bool clear) {
for (int i = 0; i < m; i++) {
for (int j = 0; j < n; j++) {
// clear memory first before start matrix multiplication
@@ -23,15 +34,18 @@ void base_gemm(uint8_t m, uint8_t k, uint8_t n, int8_t* A, int8_t* B,
C_cpu[i * n + j] = 0;
}
for (int s = 0; s < k; s++) {
C_cpu[i * n + j] = C_cpu[i * n + j] + (int32_t)A[i * k + s] *
(int32_t)B[s + j * k];
C_cpu[i * n + j] =
C_cpu[i * n + j] +
((int32_t)A[i * k + s] - (int32_t)subtraction_a) *
((int32_t)B[s + j * k] - (int32_t)subtraction_b);
}
}
}
};

void batch_gemm_cpu(uint8_t Batch, uint8_t M, uint8_t K, uint8_t N, int8_t* A,
int8_t* B, int32_t* C, uint32_t strideInnermostA,
int8_t* B, int8_t subtraction_a, int8_t subtraction_b,
int32_t* C, uint32_t strideInnermostA,
uint32_t strideInnermostB, uint32_t strideInnermostC,
uint32_t ldA, uint32_t ldB, uint32_t ldC, uint32_t strideA,
uint32_t strideB, uint32_t strideC) {
@@ -62,7 +76,7 @@ void batch_gemm_cpu(uint8_t Batch, uint8_t M, uint8_t K, uint8_t N, int8_t* A,
// when k == 0, clear the memory
clear = k == 0;
base_gemm(meshRow, tileSize, meshCol, addr_a, addr_b,
addr_c, clear);
subtraction_a, subtraction_b, addr_c, clear);
}
}
}
@@ -120,10 +134,11 @@ void load_input_data(uint8_t Batch, uint8_t M, uint8_t K, uint8_t N,
}

void set_batch_gemm(uint32_t size_setting, int8_t* local_a, int8_t* local_b,
int32_t* local_c, uint32_t strideInnermostA,
uint32_t strideInnermostB, uint32_t strideInnermostC,
uint32_t ldA, uint32_t ldB, uint32_t ldC, uint32_t strideA,
uint32_t strideB, uint32_t strideC) {
int32_t subtractions, int32_t* local_c,
uint32_t strideInnermostA, uint32_t strideInnermostB,
uint32_t strideInnermostC, uint32_t ldA, uint32_t ldB,
uint32_t ldC, uint32_t strideA, uint32_t strideB,
uint32_t strideC) {
// Set matrix size
write_csr(0x3c0, size_setting);

@@ -144,21 +159,24 @@ void set_batch_gemm(uint32_t size_setting, int8_t* local_a, int8_t* local_b,
write_csr(0x3ca, strideA);
write_csr(0x3cb, strideB);
write_csr(0x3cc, strideC);

// Set subtraction values
write_csr(0x3ce, subtractions);
}

void start_batch_gemm() {
// 0x3ce is the CSR address for accelerator status
// set the lowest bit of state CSR (state CSR[0]) to set start signal in
// GEMM
write_csr(0x3ce, 1);
write_csr(0x3cf, 1);
}

void wait_batch_gemm() {
uint32_t break_poll;

while (1) {
// poll the state CSR[1] to see if GEMM is still busy
break_poll = read_csr(0x3ce);
break_poll = read_csr(0x3cf);
if ((break_poll >> 1) == 1) {
break;
};
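
To make the packed layout written to the subtraction CSR explicit, here is a small host-side sketch (not part of the commit) that packs two signed 8-bit zero points exactly as gen_subtraction_config does and unpacks them again. The packing mirrors the code above; the unpacking side (a in bits [7:0], b in bits [15:8]) is an assumption about how the accelerator interprets CSR 0x3ce.

// Standalone illustration of the subtraction CSR word layout:
// bits [7:0] hold subtraction_a, bits [15:8] hold subtraction_b.
#include <stdint.h>
#include <stdio.h>

static int32_t gen_subtraction_config(int8_t subtraction_a,
                                      int8_t subtraction_b) {
    // Same packing as snax-gemm-lib.c: cast to uint8_t first so the sign
    // bits do not smear into the upper bits of the word.
    return ((uint8_t)subtraction_b << 8) | (uint8_t)subtraction_a;
}

int main(void) {
    int8_t sub_a = -3, sub_b = 100;  // example zero points
    int32_t packed = gen_subtraction_config(sub_a, sub_b);

    // Unpack again (assumed accelerator-side interpretation).
    int8_t a = (int8_t)(packed & 0xFF);
    int8_t b = (int8_t)((packed >> 8) & 0xFF);

    printf("packed = 0x%08x, a = %d, b = %d\n", (unsigned)packed, a, b);
    return (a == sub_a && b == sub_b) ? 0 : 1;
}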
