Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

add gemmx prameter support #327

Merged
merged 5 commits into from
Sep 17, 2024
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -122,3 +122,62 @@ object BlockGemmRescaleSIMD extends App {
Array("--target-dir", "generated/gemmx")
)
}

object BlockGemmRescaleSIMDGen {
def main(args: Array[String]): Unit = {
// Helper function to parse command-line arguments into a Map
def parseArgs(args: Array[String]): Map[String, String] = {
val parsed_args = args
.sliding(2, 2)
.collect {
case Array(key, value) if key.startsWith("--") => key.drop(2) -> value
}
.toMap
if (parsed_args.size != 4) {
throw new Exception(
"Please provide the meshRow, meshCol, tileSize, and withPipeline. Example usage: sbt 'runMain snax_acc.gemmx.BlockGemmRescaleSIMDGen --meshRow 2 --meshCol 2 --tileSize 16 --withPipeline true'"
)
}
parsed_args
}
xiaoling-yi marked this conversation as resolved.
Show resolved Hide resolved

// Parse the arguments
val argMap = parseArgs(args)

// Retrieve the specific values, providing defaults or error handling
val meshRow = argMap("meshRow")
val meshCol = argMap("meshCol")
val tileSize = argMap("tileSize")

// set the parameters for the gemm module
// other parameters are set to default values
val gemmParams = GemmParams(
GemmConstant.dataWidthA,
GemmConstant.dataWidthB,
GemmConstant.dataWidthMul,
GemmConstant.dataWidthC,
GemmConstant.dataWidthAccum,
GemmConstant.subtractionCfgWidth,
meshRow.toInt,
tileSize.toInt,
meshCol.toInt,
GemmConstant.addrWidth,
GemmConstant.sizeConfigWidth
)

val withPipeline = argMap("withPipeline").toBoolean

emitVerilog(
new BlockGemmRescaleSIMD(
BlockGemmRescaleSIMDParams(
gemmParams,
(if (withPipeline == true)
snax_acc.simd.PipelinedConfig.rescaleSIMDConfig
else snax_acc.simd.DefaultConfig.rescaleSIMDConfig),
false
)
),
Array("--target-dir", "generated/gemmx")
)
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -105,6 +105,10 @@
snax_wide_tcdm_ports: 48,
snax_num_rw_csr: 10,
snax_num_ro_csr: 2,
snax_gemmx_mesh_row: 8,
snax_gemmx_tile_size: 8,
snax_gemmx_mesh_col: 8,
with_pipeline: true,
snax_streamer_cfg: {$ref: "#/snax_streamer_gemmX_streamer_template" }
},
snax_use_custom_ports: false,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -106,6 +106,10 @@
snax_wide_tcdm_ports: 48,
snax_num_rw_csr: 10,
snax_num_ro_csr: 2,
snax_gemmx_mesh_row: 8,
snax_gemmx_tile_size: 8,
snax_gemmx_mesh_col: 8,
with_pipeline: false,
snax_streamer_cfg: {$ref: "#/snax_streamer_gemmX_streamer_template" }
},
snax_use_custom_ports: false,
Expand Down
141 changes: 106 additions & 35 deletions target/snitch_cluster/sw/apps/snax-gemmx/data/datagen.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,11 @@ def emit_header_file(**kwargs):
MIN = -128
MAX = 127

bankWidth = 64
input_data_width = 8
output_data_width = 32
quantized_output_data_width = 8


def emit_conv_data(**kwargs):
Cin = kwargs["Cin"]
Expand Down Expand Up @@ -553,6 +558,11 @@ def emit_conv_data(**kwargs):


def emit_matmul_data(**kwargs):

meshRow = kwargs["meshRow"]
tileSize = kwargs["tileSize"]
meshCol = kwargs["meshCol"]

# matmul settings
data_str = []

Expand All @@ -562,13 +572,23 @@ def emit_matmul_data(**kwargs):
data_str += [format_scalar_definition("int", "N", kwargs["N"])]

data_str += [format_scalar_definition("int32_t", "Aslstride0", 1)]
data_str += [format_scalar_definition("int32_t", "Aslstride1", 8)]
data_str += [format_scalar_definition("int32_t", "Aslstride1", bankWidth / 8)]
data_str += [format_scalar_definition("int32_t", "Atlbound0", kwargs["K"])]
data_str += [format_scalar_definition("int32_t", "Atlstride0", 64)]
data_str += [
format_scalar_definition(
"int32_t", "Atlstride0", input_data_width * tileSize * meshRow / 8
)
]
data_str += [format_scalar_definition("int32_t", "Atlbound1", kwargs["N"])]
data_str += [format_scalar_definition("int32_t", "Atlstride1", 0)]
data_str += [format_scalar_definition("int32_t", "Atlbound2", kwargs["M"])]
data_str += [format_scalar_definition("int32_t", "Atlstride2", kwargs["K"] * 64)]
data_str += [
format_scalar_definition(
"int32_t",
"Atlstride2",
kwargs["K"] * input_data_width * tileSize * meshRow / 8,
)
]
data_str += [format_scalar_definition("int32_t", "Atlbound3", 1)]
data_str += [format_scalar_definition("int32_t", "Atlstride3", 0)]
data_str += [format_scalar_definition("int32_t", "Atlbound4", 1)]
Expand All @@ -577,70 +597,115 @@ def emit_matmul_data(**kwargs):
data_str += [format_scalar_definition("int32_t", "Atlstride5", 0)]

data_str += [format_scalar_definition("int32_t", "Bslstride0", 1)]
data_str += [format_scalar_definition("int32_t", "Bslstride1", 8)]
data_str += [format_scalar_definition("int32_t", "Bslstride1", bankWidth / 8)]
data_str += [format_scalar_definition("int32_t", "Btlbound0", kwargs["K"])]
data_str += [format_scalar_definition("int32_t", "Btlstride0", 64)]
data_str += [
format_scalar_definition(
"int32_t", "Btlstride0", input_data_width * tileSize * meshCol / 8
)
]
data_str += [format_scalar_definition("int32_t", "Btlbound1", kwargs["N"])]
data_str += [format_scalar_definition("int32_t", "Btlstride1", 64 * kwargs["K"])]
data_str += [
format_scalar_definition(
"int32_t",
"Btlstride1",
kwargs["K"] * input_data_width * tileSize * meshCol / 8,
)
]
data_str += [format_scalar_definition("int32_t", "Btlbound2", kwargs["M"])]
data_str += [format_scalar_definition("int32_t", "Btlstride2", 0)]

data_str += [format_scalar_definition("int32_t", "Cslstride0", 4)]
data_str += [format_scalar_definition("int32_t", "Cslstride1", 8)]
data_str += [format_scalar_definition("int32_t", "Cslstride1", bankWidth / 8)]
data_str += [format_scalar_definition("int32_t", "Ctlbound0", kwargs["N"])]
data_str += [format_scalar_definition("int32_t", "Ctlstride0", 256)]
data_str += [
format_scalar_definition(
"int32_t", "Ctlstride0", output_data_width * meshRow * meshCol / 8
)
]
data_str += [format_scalar_definition("int32_t", "Ctlbound1", kwargs["M"])]
data_str += [format_scalar_definition("int32_t", "Ctlstride1", 256 * kwargs["N"])]
data_str += [
format_scalar_definition(
"int32_t",
"Ctlstride1",
kwargs["N"] * output_data_width * meshRow * meshCol / 8,
)
]
data_str += [format_scalar_definition("int32_t", "Ctlbound2", 1)]
data_str += [format_scalar_definition("int32_t", "Ctlstride2", 0)]

data_str += [format_scalar_definition("int32_t", "D32slstride0", 4)]
data_str += [format_scalar_definition("int32_t", "D32slstride1", 8)]
data_str += [format_scalar_definition("int32_t", "D32slstride1", bankWidth / 8)]
data_str += [format_scalar_definition("int32_t", "D32tlbound0", kwargs["N"])]
data_str += [format_scalar_definition("int32_t", "D32tlstride0", 256)]
data_str += [
format_scalar_definition(
"int32_t", "D32tlstride0", output_data_width * meshRow * meshCol / 8
)
]
data_str += [format_scalar_definition("int32_t", "D32tlbound1", kwargs["M"])]
data_str += [format_scalar_definition("int32_t", "D32tlstride1", 256 * kwargs["N"])]
data_str += [
format_scalar_definition(
"int32_t",
"D32tlstride1",
kwargs["N"] * output_data_width * meshRow * meshCol / 8,
)
]
data_str += [format_scalar_definition("int32_t", "D32tlbound2", 1)]
data_str += [format_scalar_definition("int32_t", "D32tlstride2", 0)]

data_str += [format_scalar_definition("int32_t", "D8slstride0", 1)]
data_str += [format_scalar_definition("int32_t", "D8slstride1", 8)]
data_str += [format_scalar_definition("int32_t", "D8slstride1", bankWidth / 8)]
data_str += [format_scalar_definition("int32_t", "D8tlbound0", kwargs["N"])]
data_str += [format_scalar_definition("int32_t", "D8tlstride0", 64)]
data_str += [
format_scalar_definition(
"int32_t",
"D8tlstride0",
quantized_output_data_width * meshRow * meshCol / 8,
)
]
data_str += [format_scalar_definition("int32_t", "D8tlbound1", kwargs["M"])]
data_str += [format_scalar_definition("int32_t", "D8tlstride1", 64 * kwargs["N"])]
data_str += [format_scalar_definition("int32_t", "D8tlbound2", 1)]
data_str += [format_scalar_definition("int32_t", "D8tlstride2", 0)]

data_str += [format_scalar_definition("int32_t", "delta_local_a", 0)]
data_str += [
format_scalar_definition(
"int32_t", "delta_local_b", 64 * kwargs["K"] * kwargs["M"]
"int32_t",
"D8tlstride1",
kwargs["N"] * quantized_output_data_width * meshRow * meshCol / 8,
)
]
data_str += [format_scalar_definition("int32_t", "D8tlbound2", 1)]
data_str += [format_scalar_definition("int32_t", "D8tlstride2", 0)]

delta_local_a = 0
delta_local_b = (
kwargs["K"] * kwargs["M"] * (meshRow * tileSize * input_data_width / 8)
)
delta_local_c = delta_local_b + kwargs["K"] * kwargs["N"] * (
meshCol * tileSize * input_data_width / 8
)
delta_local_d32 = delta_local_c + kwargs["M"] * kwargs["N"] * (
meshRow * meshCol * output_data_width / 8
)
delta_local_d8 = delta_local_d32
data_str += [format_scalar_definition("int32_t", "delta_local_a", delta_local_a)]
data_str += [format_scalar_definition("int32_t", "delta_local_b", delta_local_b)]
data_str += [
format_scalar_definition(
"int32_t",
"delta_local_c",
64 * kwargs["K"] * kwargs["M"] + 64 * kwargs["K"] * kwargs["N"],
delta_local_c,
)
]
data_str += [
format_scalar_definition(
"int32_t",
"delta_local_d32",
64 * kwargs["K"] * kwargs["M"]
+ 64 * kwargs["K"] * kwargs["N"]
+ 256 * kwargs["M"] * kwargs["N"],
delta_local_d32,
)
]
data_str += [
format_scalar_definition(
"int32_t",
"delta_local_d8",
64 * kwargs["K"] * kwargs["M"]
+ 64 * kwargs["K"] * kwargs["N"]
+ 256 * kwargs["M"] * kwargs["N"],
delta_local_d8,
)
]

Expand All @@ -652,18 +717,24 @@ def emit_matmul_data(**kwargs):
data_str += [format_scalar_definition("int8_t", "subtraction_a", subtraction_a)]
data_str += [format_scalar_definition("int8_t", "subtraction_b", subtraction_b)]

A = np.random.randint(MIN, MAX, size=(kwargs["M"], kwargs["K"], 8, 8)).reshape(-1)
A = np.random.randint(
MIN, MAX, size=(kwargs["M"], kwargs["K"], meshRow, tileSize)
).reshape(-1)
data_str += [format_vector_definition("int8_t", "A", A)]
B = np.random.randint(MIN, MAX, size=(kwargs["K"], kwargs["N"], 8, 8)).reshape(-1)
B = np.random.randint(
MIN, MAX, size=(kwargs["K"], kwargs["N"], tileSize, meshCol)
).reshape(-1)
data_str += [format_vector_definition("int8_t", "B", B)]
C = np.random.randint(MIN, MAX, size=(kwargs["M"], kwargs["N"], 8, 8)).reshape(-1)
C = np.random.randint(
MIN, MAX, size=(kwargs["M"], kwargs["N"], meshRow, meshCol)
).reshape(-1)
data_str += [format_vector_definition("int32_t", "C", C)]

if kwargs["transposed_A"] == 1:
A = A.reshape(kwargs["M"], kwargs["K"], 8, 8)
A = A.reshape(kwargs["M"], kwargs["K"], meshRow, tileSize)
A = A.transpose(0, 1, 3, 2).reshape(-1)
if kwargs["transposed_B"] == 1:
B = B.reshape(kwargs["K"], kwargs["N"], 8, 8)
B = B.reshape(kwargs["K"], kwargs["N"], tileSize, meshCol)
B = B.transpose(0, 1, 3, 2).reshape(-1)

data_str += [
Expand All @@ -677,9 +748,9 @@ def emit_matmul_data(**kwargs):
kwargs["M"],
kwargs["K"],
kwargs["N"],
8,
8,
8,
meshRow,
tileSize,
meshCol,
A,
B,
subtraction_a,
Expand Down
5 changes: 5 additions & 0 deletions target/snitch_cluster/sw/apps/snax-gemmx/data/params.hjson
Original file line number Diff line number Diff line change
Expand Up @@ -24,4 +24,9 @@
M: 1,
bypassSIMD: 1,
ifTestMatmul: 1,

// hardware parameters
meshRow : 8,
tileSize: 8,
meshCol : 8,
}
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,6 @@

#pragma once

#define tileSize 8
#define meshRow 8
#define tileSize 8
#define meshCol 8
7 changes: 5 additions & 2 deletions target/snitch_cluster/sw/snax/gemmx/src/snax-gemmx-lib.c
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@

#include "snax-gemmx-lib.h"
#include <stdbool.h>
#include "snax-gemmx-params.h"
#include "snrt.h"
#include "stdint.h"
#include "streamer_csr_addr_map.h"
Expand Down Expand Up @@ -156,8 +157,10 @@ void set_gemmx_streamer_csr(
csrw_ss(T_STRIDE_READER_WRITER_1_2, D32tlstride2);

// set the transpose
#ifdef TRANSPOSE_EXTENSION_ENABLE
csrw_ss(TRANSPOSE_CSR_READER_0, transpose_A == 0 ? 1 : 0);
csrw_ss(TRANSPOSE_CSR_READER_1, transpose_B == 0 ? 1 : 0);
#endif
}

// Set CSR to start STREAMER
Expand Down Expand Up @@ -234,7 +237,7 @@ uint32_t check_gemmx_result_D8(int8_t* output, int8_t* output_golden,
int32_t Batch, int32_t M, int32_t N) {
uint32_t err = 0;
uint32_t size = 0;
size = Batch * M * N * 8 * 8;
size = Batch * M * N * meshRow * meshCol;

for (int i = 0; i < size; i++) {
if (output[i] != output_golden[i]) {
Expand All @@ -248,7 +251,7 @@ uint32_t check_gemmx_result_D32(int32_t* output, int32_t* output_golden,
int32_t Batch, int32_t M, int32_t N) {
uint32_t err = 0;
uint32_t size = 0;
size = Batch * M * N * 8 * 8;
size = Batch * M * N * meshRow * meshCol;

for (int i = 0; i < size; i++) {
if (output[i] != output_golden[i]) {
Expand Down
Loading
Loading