Skip to content

Commit

Permalink
add gemmx prameter support (#327)
Browse files Browse the repository at this point in the history
* add gemmx prameter support

* c fmt

* add cfg checker

* py fmt

* py fmt
  • Loading branch information
xiaoling-yi authored Sep 17, 2024
1 parent 1aea283 commit 4fae969
Show file tree
Hide file tree
Showing 8 changed files with 304 additions and 121 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -122,3 +122,62 @@ object BlockGemmRescaleSIMD extends App {
Array("--target-dir", "generated/gemmx")
)
}

object BlockGemmRescaleSIMDGen {
def main(args: Array[String]): Unit = {
// Helper function to parse command-line arguments into a Map
def parseArgs(args: Array[String]): Map[String, String] = {
val parsed_args = args
.sliding(2, 2)
.collect {
case Array(key, value) if key.startsWith("--") => key.drop(2) -> value
}
.toMap
if (parsed_args.size != 4) {
throw new Exception(
"Please provide the meshRow, meshCol, tileSize, and withPipeline. Example usage: sbt 'runMain snax_acc.gemmx.BlockGemmRescaleSIMDGen --meshRow 2 --meshCol 2 --tileSize 16 --withPipeline true'"
)
}
parsed_args
}

// Parse the arguments
val argMap = parseArgs(args)

// Retrieve the specific values, providing defaults or error handling
val meshRow = argMap("meshRow")
val meshCol = argMap("meshCol")
val tileSize = argMap("tileSize")

// set the parameters for the gemm module
// other parameters are set to default values
val gemmParams = GemmParams(
GemmConstant.dataWidthA,
GemmConstant.dataWidthB,
GemmConstant.dataWidthMul,
GemmConstant.dataWidthC,
GemmConstant.dataWidthAccum,
GemmConstant.subtractionCfgWidth,
meshRow.toInt,
tileSize.toInt,
meshCol.toInt,
GemmConstant.addrWidth,
GemmConstant.sizeConfigWidth
)

val withPipeline = argMap("withPipeline").toBoolean

emitVerilog(
new BlockGemmRescaleSIMD(
BlockGemmRescaleSIMDParams(
gemmParams,
(if (withPipeline == true)
snax_acc.simd.PipelinedConfig.rescaleSIMDConfig
else snax_acc.simd.DefaultConfig.rescaleSIMDConfig),
false
)
),
Array("--target-dir", "generated/gemmx")
)
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -105,6 +105,10 @@
snax_wide_tcdm_ports: 48,
snax_num_rw_csr: 10,
snax_num_ro_csr: 2,
snax_gemmx_mesh_row: 8,
snax_gemmx_tile_size: 8,
snax_gemmx_mesh_col: 8,
with_pipeline: true,
snax_streamer_cfg: {$ref: "#/snax_streamer_gemmX_streamer_template" }
},
snax_use_custom_ports: false,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -106,6 +106,10 @@
snax_wide_tcdm_ports: 48,
snax_num_rw_csr: 10,
snax_num_ro_csr: 2,
snax_gemmx_mesh_row: 8,
snax_gemmx_tile_size: 8,
snax_gemmx_mesh_col: 8,
with_pipeline: false,
snax_streamer_cfg: {$ref: "#/snax_streamer_gemmX_streamer_template" }
},
snax_use_custom_ports: false,
Expand Down
141 changes: 106 additions & 35 deletions target/snitch_cluster/sw/apps/snax-gemmx/data/datagen.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,11 @@ def emit_header_file(**kwargs):
MIN = -128
MAX = 127

bankWidth = 64
input_data_width = 8
output_data_width = 32
quantized_output_data_width = 8


def emit_conv_data(**kwargs):
Cin = kwargs["Cin"]
Expand Down Expand Up @@ -553,6 +558,11 @@ def emit_conv_data(**kwargs):


def emit_matmul_data(**kwargs):

meshRow = kwargs["meshRow"]
tileSize = kwargs["tileSize"]
meshCol = kwargs["meshCol"]

# matmul settings
data_str = []

Expand All @@ -562,13 +572,23 @@ def emit_matmul_data(**kwargs):
data_str += [format_scalar_definition("int", "N", kwargs["N"])]

data_str += [format_scalar_definition("int32_t", "Aslstride0", 1)]
data_str += [format_scalar_definition("int32_t", "Aslstride1", 8)]
data_str += [format_scalar_definition("int32_t", "Aslstride1", bankWidth / 8)]
data_str += [format_scalar_definition("int32_t", "Atlbound0", kwargs["K"])]
data_str += [format_scalar_definition("int32_t", "Atlstride0", 64)]
data_str += [
format_scalar_definition(
"int32_t", "Atlstride0", input_data_width * tileSize * meshRow / 8
)
]
data_str += [format_scalar_definition("int32_t", "Atlbound1", kwargs["N"])]
data_str += [format_scalar_definition("int32_t", "Atlstride1", 0)]
data_str += [format_scalar_definition("int32_t", "Atlbound2", kwargs["M"])]
data_str += [format_scalar_definition("int32_t", "Atlstride2", kwargs["K"] * 64)]
data_str += [
format_scalar_definition(
"int32_t",
"Atlstride2",
kwargs["K"] * input_data_width * tileSize * meshRow / 8,
)
]
data_str += [format_scalar_definition("int32_t", "Atlbound3", 1)]
data_str += [format_scalar_definition("int32_t", "Atlstride3", 0)]
data_str += [format_scalar_definition("int32_t", "Atlbound4", 1)]
Expand All @@ -577,70 +597,115 @@ def emit_matmul_data(**kwargs):
data_str += [format_scalar_definition("int32_t", "Atlstride5", 0)]

data_str += [format_scalar_definition("int32_t", "Bslstride0", 1)]
data_str += [format_scalar_definition("int32_t", "Bslstride1", 8)]
data_str += [format_scalar_definition("int32_t", "Bslstride1", bankWidth / 8)]
data_str += [format_scalar_definition("int32_t", "Btlbound0", kwargs["K"])]
data_str += [format_scalar_definition("int32_t", "Btlstride0", 64)]
data_str += [
format_scalar_definition(
"int32_t", "Btlstride0", input_data_width * tileSize * meshCol / 8
)
]
data_str += [format_scalar_definition("int32_t", "Btlbound1", kwargs["N"])]
data_str += [format_scalar_definition("int32_t", "Btlstride1", 64 * kwargs["K"])]
data_str += [
format_scalar_definition(
"int32_t",
"Btlstride1",
kwargs["K"] * input_data_width * tileSize * meshCol / 8,
)
]
data_str += [format_scalar_definition("int32_t", "Btlbound2", kwargs["M"])]
data_str += [format_scalar_definition("int32_t", "Btlstride2", 0)]

data_str += [format_scalar_definition("int32_t", "Cslstride0", 4)]
data_str += [format_scalar_definition("int32_t", "Cslstride1", 8)]
data_str += [format_scalar_definition("int32_t", "Cslstride1", bankWidth / 8)]
data_str += [format_scalar_definition("int32_t", "Ctlbound0", kwargs["N"])]
data_str += [format_scalar_definition("int32_t", "Ctlstride0", 256)]
data_str += [
format_scalar_definition(
"int32_t", "Ctlstride0", output_data_width * meshRow * meshCol / 8
)
]
data_str += [format_scalar_definition("int32_t", "Ctlbound1", kwargs["M"])]
data_str += [format_scalar_definition("int32_t", "Ctlstride1", 256 * kwargs["N"])]
data_str += [
format_scalar_definition(
"int32_t",
"Ctlstride1",
kwargs["N"] * output_data_width * meshRow * meshCol / 8,
)
]
data_str += [format_scalar_definition("int32_t", "Ctlbound2", 1)]
data_str += [format_scalar_definition("int32_t", "Ctlstride2", 0)]

data_str += [format_scalar_definition("int32_t", "D32slstride0", 4)]
data_str += [format_scalar_definition("int32_t", "D32slstride1", 8)]
data_str += [format_scalar_definition("int32_t", "D32slstride1", bankWidth / 8)]
data_str += [format_scalar_definition("int32_t", "D32tlbound0", kwargs["N"])]
data_str += [format_scalar_definition("int32_t", "D32tlstride0", 256)]
data_str += [
format_scalar_definition(
"int32_t", "D32tlstride0", output_data_width * meshRow * meshCol / 8
)
]
data_str += [format_scalar_definition("int32_t", "D32tlbound1", kwargs["M"])]
data_str += [format_scalar_definition("int32_t", "D32tlstride1", 256 * kwargs["N"])]
data_str += [
format_scalar_definition(
"int32_t",
"D32tlstride1",
kwargs["N"] * output_data_width * meshRow * meshCol / 8,
)
]
data_str += [format_scalar_definition("int32_t", "D32tlbound2", 1)]
data_str += [format_scalar_definition("int32_t", "D32tlstride2", 0)]

data_str += [format_scalar_definition("int32_t", "D8slstride0", 1)]
data_str += [format_scalar_definition("int32_t", "D8slstride1", 8)]
data_str += [format_scalar_definition("int32_t", "D8slstride1", bankWidth / 8)]
data_str += [format_scalar_definition("int32_t", "D8tlbound0", kwargs["N"])]
data_str += [format_scalar_definition("int32_t", "D8tlstride0", 64)]
data_str += [
format_scalar_definition(
"int32_t",
"D8tlstride0",
quantized_output_data_width * meshRow * meshCol / 8,
)
]
data_str += [format_scalar_definition("int32_t", "D8tlbound1", kwargs["M"])]
data_str += [format_scalar_definition("int32_t", "D8tlstride1", 64 * kwargs["N"])]
data_str += [format_scalar_definition("int32_t", "D8tlbound2", 1)]
data_str += [format_scalar_definition("int32_t", "D8tlstride2", 0)]

data_str += [format_scalar_definition("int32_t", "delta_local_a", 0)]
data_str += [
format_scalar_definition(
"int32_t", "delta_local_b", 64 * kwargs["K"] * kwargs["M"]
"int32_t",
"D8tlstride1",
kwargs["N"] * quantized_output_data_width * meshRow * meshCol / 8,
)
]
data_str += [format_scalar_definition("int32_t", "D8tlbound2", 1)]
data_str += [format_scalar_definition("int32_t", "D8tlstride2", 0)]

delta_local_a = 0
delta_local_b = (
kwargs["K"] * kwargs["M"] * (meshRow * tileSize * input_data_width / 8)
)
delta_local_c = delta_local_b + kwargs["K"] * kwargs["N"] * (
meshCol * tileSize * input_data_width / 8
)
delta_local_d32 = delta_local_c + kwargs["M"] * kwargs["N"] * (
meshRow * meshCol * output_data_width / 8
)
delta_local_d8 = delta_local_d32
data_str += [format_scalar_definition("int32_t", "delta_local_a", delta_local_a)]
data_str += [format_scalar_definition("int32_t", "delta_local_b", delta_local_b)]
data_str += [
format_scalar_definition(
"int32_t",
"delta_local_c",
64 * kwargs["K"] * kwargs["M"] + 64 * kwargs["K"] * kwargs["N"],
delta_local_c,
)
]
data_str += [
format_scalar_definition(
"int32_t",
"delta_local_d32",
64 * kwargs["K"] * kwargs["M"]
+ 64 * kwargs["K"] * kwargs["N"]
+ 256 * kwargs["M"] * kwargs["N"],
delta_local_d32,
)
]
data_str += [
format_scalar_definition(
"int32_t",
"delta_local_d8",
64 * kwargs["K"] * kwargs["M"]
+ 64 * kwargs["K"] * kwargs["N"]
+ 256 * kwargs["M"] * kwargs["N"],
delta_local_d8,
)
]

Expand All @@ -652,18 +717,24 @@ def emit_matmul_data(**kwargs):
data_str += [format_scalar_definition("int8_t", "subtraction_a", subtraction_a)]
data_str += [format_scalar_definition("int8_t", "subtraction_b", subtraction_b)]

A = np.random.randint(MIN, MAX, size=(kwargs["M"], kwargs["K"], 8, 8)).reshape(-1)
A = np.random.randint(
MIN, MAX, size=(kwargs["M"], kwargs["K"], meshRow, tileSize)
).reshape(-1)
data_str += [format_vector_definition("int8_t", "A", A)]
B = np.random.randint(MIN, MAX, size=(kwargs["K"], kwargs["N"], 8, 8)).reshape(-1)
B = np.random.randint(
MIN, MAX, size=(kwargs["K"], kwargs["N"], tileSize, meshCol)
).reshape(-1)
data_str += [format_vector_definition("int8_t", "B", B)]
C = np.random.randint(MIN, MAX, size=(kwargs["M"], kwargs["N"], 8, 8)).reshape(-1)
C = np.random.randint(
MIN, MAX, size=(kwargs["M"], kwargs["N"], meshRow, meshCol)
).reshape(-1)
data_str += [format_vector_definition("int32_t", "C", C)]

if kwargs["transposed_A"] == 1:
A = A.reshape(kwargs["M"], kwargs["K"], 8, 8)
A = A.reshape(kwargs["M"], kwargs["K"], meshRow, tileSize)
A = A.transpose(0, 1, 3, 2).reshape(-1)
if kwargs["transposed_B"] == 1:
B = B.reshape(kwargs["K"], kwargs["N"], 8, 8)
B = B.reshape(kwargs["K"], kwargs["N"], tileSize, meshCol)
B = B.transpose(0, 1, 3, 2).reshape(-1)

data_str += [
Expand All @@ -677,9 +748,9 @@ def emit_matmul_data(**kwargs):
kwargs["M"],
kwargs["K"],
kwargs["N"],
8,
8,
8,
meshRow,
tileSize,
meshCol,
A,
B,
subtraction_a,
Expand Down
5 changes: 5 additions & 0 deletions target/snitch_cluster/sw/apps/snax-gemmx/data/params.hjson
Original file line number Diff line number Diff line change
Expand Up @@ -24,4 +24,9 @@
M: 1,
bypassSIMD: 1,
ifTestMatmul: 1,

// hardware parameters
meshRow : 8,
tileSize: 8,
meshCol : 8,
}
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,6 @@

#pragma once

#define tileSize 8
#define meshRow 8
#define tileSize 8
#define meshCol 8
7 changes: 5 additions & 2 deletions target/snitch_cluster/sw/snax/gemmx/src/snax-gemmx-lib.c
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@

#include "snax-gemmx-lib.h"
#include <stdbool.h>
#include "snax-gemmx-params.h"
#include "snrt.h"
#include "stdint.h"
#include "streamer_csr_addr_map.h"
Expand Down Expand Up @@ -156,8 +157,10 @@ void set_gemmx_streamer_csr(
csrw_ss(T_STRIDE_READER_WRITER_1_2, D32tlstride2);

// set the transpose
#ifdef TRANSPOSE_EXTENSION_ENABLE
csrw_ss(TRANSPOSE_CSR_READER_0, transpose_A == 0 ? 1 : 0);
csrw_ss(TRANSPOSE_CSR_READER_1, transpose_B == 0 ? 1 : 0);
#endif
}

// Set CSR to start STREAMER
Expand Down Expand Up @@ -234,7 +237,7 @@ uint32_t check_gemmx_result_D8(int8_t* output, int8_t* output_golden,
int32_t Batch, int32_t M, int32_t N) {
uint32_t err = 0;
uint32_t size = 0;
size = Batch * M * N * 8 * 8;
size = Batch * M * N * meshRow * meshCol;

for (int i = 0; i < size; i++) {
if (output[i] != output_golden[i]) {
Expand All @@ -248,7 +251,7 @@ uint32_t check_gemmx_result_D32(int32_t* output, int32_t* output_golden,
int32_t Batch, int32_t M, int32_t N) {
uint32_t err = 0;
uint32_t size = 0;
size = Batch * M * N * 8 * 8;
size = Batch * M * N * meshRow * meshCol;

for (int i = 0; i < size; i++) {
if (output[i] != output_golden[i]) {
Expand Down
Loading

0 comments on commit 4fae969

Please sign in to comment.