-
Notifications
You must be signed in to change notification settings - Fork 4
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Add simple matmul on gemm accelerator #39
Changes from 13 commits
b689b01
27735a4
09876df
9ca258c
3c5a06a
c6c5142
360c183
ae7481c
9c0dfb9
d8c31be
98bfa2b
a28c6dd
658f204
7b8be56
4a50c06
6334e89
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,33 @@ | ||
# Courtesy of Federico Ficarelli | ||
|
||
.DEFAULT_GOAL := all | ||
|
||
include ../../runtime/snax-gemm.rules | ||
include ../../runtime/Makefile.rules | ||
|
||
TESTS = | ||
TESTS += cpu.x | ||
|
||
CFLAGS += -std=gnu11 | ||
CFLAGS += -Wall -Wextra | ||
|
||
data.c data.h: | ||
$(PYTHON) gendata.py | ||
|
||
%.x: %.o main.o data.o | ||
$(LD) $(LDFLAGS) $^ -o $@ | ||
|
||
sim_%: % | ||
rm -fr ./logs/ | ||
$(VLTSIM) $< | ||
|
||
RUN = $(addprefix run_, $(TESTS)) | ||
$(RUN): run_%: sim_% | ||
mv logs $(subst sim_,,$<).logs | ||
|
||
all: $(TESTS) | ||
|
||
allrun: $(RUN) | ||
|
||
clean: | ||
rm -fr *.ll12 *.x *.o *.logs/ logs/ data.h data.c |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,7 @@ | ||
func.func public @simple_matmul(%A: memref<16x16xi8, 1 : i32>, | ||
%B: memref<16x32xi8, 1 : i32>, | ||
%C: memref<16x32xi32, 1 : i32>) -> () { | ||
func.call @simple_matmul_cpu(%A, %B, %C) : (memref<16x16xi8, 1 : i32>, memref<16x32xi8, 1 : i32>, memref<16x32xi32, 1 : i32>) -> () | ||
return | ||
} | ||
func.func private @simple_matmul_cpu(%A : memref<16x16xi8, 1 : i32>, %B : memref<16x32xi8, 1 : i32>, %C : memref<16x32xi32, 1 : i32>) |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,87 @@ | ||
# simple script to generate inputs and expected outputs for simple_matmul | ||
import numpy as np | ||
from numpy import typing as npt | ||
from typing import Dict | ||
|
||
|
||
def create_header( | ||
file_name: str, sizes: Dict[str, int], variables: Dict[str, npt.NDArray] | ||
) -> None: | ||
with open(file_name, "w") as f: | ||
includes = ["#include <stdint.h>", "#pragma once", ""] | ||
includes = "\n".join(includes) | ||
variables_string = [""] | ||
for i, j in sizes.items(): | ||
variables_string.append(f"#define {i} {j}") | ||
variables_string.append("") | ||
for i, j in variables.items(): | ||
variables_string.append(f"extern const {j.dtype}_t {i}[{j.size}];") | ||
variables_string = "\n".join(variables_string) | ||
f.write(includes) | ||
f.write(variables_string) | ||
f.write("\n") | ||
|
||
|
||
def create_data(file_name: str, variables: Dict[str, npt.NDArray]): | ||
includes = ['#include "data.h"', "", ""] | ||
includes = "\n".join(includes) | ||
variables = {i: np.reshape(j, j.size) for i, j in variables.items()} | ||
with open(file_name, "w") as f: | ||
f.write(includes) | ||
for variable_name, variable_value in variables.items(): | ||
f.write( | ||
f"const {variable_value.dtype}_t {variable_name}" | ||
+ f"[{variable_value.size}] = " | ||
+ "{\n" | ||
) | ||
variable_str = ["\t" + str(i) for i in variable_value] | ||
f.write(",\n".join(variable_str)) | ||
f.write("\n};\n\n") | ||
|
||
|
||
if __name__ == "__main__": | ||
# Reset random seed for reproducible behavior | ||
low_bound = -128 | ||
high_bound = 127 | ||
A_size = [16, 16] | ||
B_size = [16, 16] | ||
np.random.seed(0) | ||
|
||
# C = A.B | ||
A = np.random.randint(low_bound, high_bound, size=A_size, dtype=np.dtype("int8")) | ||
# A = np.ones(A_size, dtype=np.dtype("int8")) | ||
B = np.random.randint(low_bound, high_bound, size=B_size, dtype=np.dtype("int8")) | ||
# B = np.ones(B_size, dtype=np.dtype("int8")) | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. remove comment |
||
C_golden = np.matmul(A.astype(np.dtype("int32")), B.astype(np.dtype("int32"))) | ||
C = np.zeros(C_golden.shape, np.dtype("int32")) | ||
|
||
assert A.shape[1] == B.shape[0] | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. The assert must come before the np.matmul computation. If the shapes do not correspond, the matmul would fail anyway, making the assert useless |
||
sizes = {"N_size": A.shape[0], "K_size": A.shape[1], "M_size": B.shape[1]} | ||
|
||
# Perform layout transformations before writing to memory | ||
|
||
# convert from row-major to block row-major | ||
A_new_layout = np.reshape(A, [2, 8, 2, 8]) | ||
# convert to [2,2,8,8] | ||
A_new_layout = np.swapaxes(A_new_layout, 1, 2) | ||
|
||
B_new_layout = np.transpose(B) | ||
# convert from column-major to block column-major | ||
B_new_layout = np.reshape(B_new_layout, [2, 8, 2, 8]) | ||
# convert to [2,2,8,8] | ||
B_new_layout = np.swapaxes(B_new_layout, 1, 2) | ||
# convert from row-major to block row-major | ||
C_golden_new_layout = np.reshape(C_golden, [2, 8, 2, 8]) | ||
# convert to [2,2,8,8] | ||
C_golden_new_layout = np.swapaxes(C_golden_new_layout, 1, 2) | ||
|
||
# C are just all zeros, so layout not important | ||
variables = { | ||
"A": A_new_layout, | ||
"B": B_new_layout, | ||
"C_golden": C_golden_new_layout, | ||
"C": C, | ||
} | ||
|
||
create_header("data.h", sizes, variables) | ||
create_data("data.c", variables) |
Original file line number | Diff line number | Diff line change | ||||||||
---|---|---|---|---|---|---|---|---|---|---|
@@ -0,0 +1,132 @@ | ||||||||||
#include "data.h" | ||||||||||
#include "memref.h" | ||||||||||
#include "snax-gemm-lib.h" | ||||||||||
#include "snax-gemm-params.h" | ||||||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. To me it feels kind of weird to include header files from another repository without specifying where they come from 😕 . I guess it also doesn't make sense to have them duplicate... Is there a way to share this in a better way? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. if not, maybe make it clear with a comment where these files come from There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Agreed! Clear like this? |
||||||||||
#include "snax_rt.h" | ||||||||||
#include "stdint.h" | ||||||||||
|
||||||||||
#include <snrt.h> | ||||||||||
#include <stdint.h> | ||||||||||
|
||||||||||
uint8_t Batch = 1; | ||||||||||
// meshRow, tileSize and meshCol are defined in snax-gemm-params.h | ||||||||||
uint8_t M_param = M_size / meshRow; | ||||||||||
uint8_t K_param = K_size / tileSize; | ||||||||||
uint8_t N_param = N_size / meshCol; | ||||||||||
// Extracted from datagen.py in snitch_cluster repo | ||||||||||
uint32_t strideInnermostA = 256; | ||||||||||
uint32_t strideInnermostB = 256; | ||||||||||
uint32_t strideInnermostC = 256; | ||||||||||
uint32_t ldA = 512; | ||||||||||
uint32_t ldB = 512; | ||||||||||
uint32_t ldC = 512; | ||||||||||
uint32_t strideA = 0; | ||||||||||
uint32_t strideB = 0; | ||||||||||
uint32_t strideC = 0; | ||||||||||
|
||||||||||
// Kernel provided via external definition | ||||||||||
void _mlir_ciface_simple_matmul(TwoDMemrefI8_t *a, TwoDMemrefI8_t *b, | ||||||||||
TwoDMemrefI32_t *c); | ||||||||||
|
||||||||||
void _mlir_ciface_simple_matmul_cpu(TwoDMemrefI8_t *a, TwoDMemrefI8_t *b, | ||||||||||
TwoDMemrefI32_t *c) { | ||||||||||
int8_t *a_ptr = a->aligned_data; | ||||||||||
int8_t *b_ptr = b->aligned_data; | ||||||||||
int32_t *c_ptr = c->aligned_data; | ||||||||||
batch_gemm_cpu(Batch, M_param, K_param, N_param, a_ptr, b_ptr, c_ptr, | ||||||||||
strideInnermostA, strideInnermostB, strideInnermostC, ldA, ldB, | ||||||||||
ldC, strideA, strideB, strideC); | ||||||||||
} | ||||||||||
|
||||||||||
int main() { | ||||||||||
// Allocate space in TCDM | ||||||||||
// We put the data in different banks, but we don't interleave the data for | ||||||||||
// now. | ||||||||||
// | ||||||||||
// | A | x | x | x | --> A in banks 0 - 7 --> (8/32 banks used)* | ||||||||||
// (int8 --> 8 elements/bank) | ||||||||||
// 1 row --> 64 elements | ||||||||||
// | x | B | x | x | --> B in banks 7 - 15 --> (8/32 banks used)* | ||||||||||
// (8 elements/bank)*32 banks | ||||||||||
// 1 row --> 64 elements | ||||||||||
// | C | C | C | C | --> C in banks 0 - 31 --> (32/32 banks used)* | ||||||||||
// (2 elements/bank)* 32 bank | ||||||||||
// 1 row --> 64 elements | ||||||||||
// | x | x | x | x | | ||||||||||
// | ||||||||||
// 32 banks --> 1 row = 32 banks * 8 bytes --> 256 addresses further | ||||||||||
|
||||||||||
static int8_t *allocated_a; | ||||||||||
static int8_t *allocated_b; | ||||||||||
static int32_t *allocated_c; | ||||||||||
|
||||||||||
// Transfer data from L3 to L1 | ||||||||||
// Using DMA only | ||||||||||
if (snrt_is_dm_core()) { | ||||||||||
// calculation in bytes directly | ||||||||||
allocated_a = (int8_t *)snrt_l1alloc(256 * M_size * K_size / 64); | ||||||||||
allocated_b = (int8_t *)snrt_l1alloc(256 * K_size * N_size / 64); | ||||||||||
allocated_c = (int32_t *)snrt_l1alloc(256 * M_size * N_size / 64); | ||||||||||
} | ||||||||||
snrt_cluster_hw_barrier(); | ||||||||||
|
||||||||||
// Create memref descriptors for data stored in L1 | ||||||||||
TwoDMemrefI8_t memrefA; | ||||||||||
memrefA.data = allocated_a; | ||||||||||
memrefA.aligned_data = memrefA.data; | ||||||||||
memrefA.shape[0] = M_size; | ||||||||||
memrefA.shape[1] = K_size; | ||||||||||
// These are not considered correctly right now | ||||||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Comment not very clear.
Suggested change
|
||||||||||
memrefA.offset = 0; | ||||||||||
memrefA.stride[0] = sizeof(int8_t); | ||||||||||
memrefA.stride[1] = sizeof(int8_t); | ||||||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. If we were to use a standard layout, this would also not be correct, but rather:
Suggested change
Maybe just set them to 0 to make it very clear we are not using them. |
||||||||||
|
||||||||||
TwoDMemrefI8_t memrefB; | ||||||||||
memrefB.data = allocated_b; | ||||||||||
// Data is stored in banks 8 - 15, so increment by 8banks*8bytes = 64 | ||||||||||
memrefB.aligned_data = memrefB.data + 64; | ||||||||||
memrefB.shape[0] = K_size; | ||||||||||
memrefB.shape[1] = N_size; | ||||||||||
// These are not considered correctly right now | ||||||||||
JosseVanDelm marked this conversation as resolved.
Show resolved
Hide resolved
|
||||||||||
memrefB.offset = 0; | ||||||||||
memrefB.stride[0] = sizeof(int8_t); | ||||||||||
memrefB.stride[1] = sizeof(int8_t); | ||||||||||
|
||||||||||
TwoDMemrefI32_t memrefC; | ||||||||||
memrefC.data = allocated_c; | ||||||||||
memrefC.aligned_data = memrefC.data; | ||||||||||
memrefC.shape[0] = M_size; | ||||||||||
memrefC.shape[1] = N_size; | ||||||||||
// These are not considered correctly right now | ||||||||||
JosseVanDelm marked this conversation as resolved.
Show resolved
Hide resolved
|
||||||||||
memrefC.offset = 0; | ||||||||||
memrefC.stride[0] = sizeof(int32_t); | ||||||||||
memrefC.stride[1] = sizeof(int32_t); | ||||||||||
if (snrt_is_dm_core()) { | ||||||||||
load_input_data(Batch, M_size / meshRow, K_size / tileSize, | ||||||||||
N_size / meshCol, memrefA.aligned_data, | ||||||||||
memrefB.aligned_data, A, B, strideInnermostA, | ||||||||||
strideInnermostB, ldA, ldB, strideA, strideB); | ||||||||||
} | ||||||||||
snrt_cluster_hw_barrier(); | ||||||||||
(void)snrt_mcycle(); | ||||||||||
if (snrt_is_compute_core()) { | ||||||||||
_mlir_ciface_simple_matmul(&memrefA, &memrefB, &memrefC); | ||||||||||
} | ||||||||||
(void)snrt_mcycle(); | ||||||||||
snrt_cluster_hw_barrier(); | ||||||||||
|
||||||||||
// Correctness check - | ||||||||||
// from this point on only core 0 is required to be alive. | ||||||||||
int thiscore = snrt_cluster_core_idx(); | ||||||||||
if (thiscore != 0) | ||||||||||
return 0; | ||||||||||
|
||||||||||
int nerr = 0; | ||||||||||
for (int i = 0; i < M_size * N_size; i++) { | ||||||||||
// printf("%d , golden : %d\n", memrefC.aligned_data[i],C_golden[i]); | ||||||||||
int32_t error = memrefC.aligned_data[i] - C_golden[i]; | ||||||||||
if (error != 0) | ||||||||||
nerr += 1; | ||||||||||
} | ||||||||||
return nerr; | ||||||||||
} |
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -55,6 +55,7 @@ int main() { | |
|
||
int nerr = 0; | ||
for (int i = 0; i < N; i++) { | ||
printf("result: %d golden: %d\n", memrefD.aligned_data[i], G[i]); | ||
JosseVanDelm marked this conversation as resolved.
Show resolved
Hide resolved
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. To delete then? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Will make a separate PR for this There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. wait no There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. it's gone now haha |
||
int32_t error = memrefD.aligned_data[i] - G[i]; | ||
if (error != 0) | ||
nerr += 1; | ||
|
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. The MLIR docs show quite a nice way to have only one memref struct specification for all types, in C++
do you know if something similar is possible in C? if not, I guess this is fine for now There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. This is possible, but it looks a bit complicated https://isocpp.org/wiki/faq/mixing-c-and-cpp , let's take this for a future PR? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Yes, good! |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,5 @@ | ||
# Specific settings for snax-mac RTL | ||
SNITCH_SW_PATH = /opt/snax-gemm | ||
VLTSIM = /opt/snax-gemm-rtl/bin/snitch_cluster.vlt | ||
CFLAGS += -I$(SNITCH_SW_PATH)/target/snitch_cluster/sw/snax/gemm/include | ||
LDFLAGS += $(SNITCH_SW_PATH)/target/snitch_cluster/sw/snax/gemm/build/snax-gemm-lib.o |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,3 +1,5 @@ | ||
# Specific settings for snax-mac RTL | ||
SNITCH_SW_PATH = /opt/snax-mac | ||
VLTSIM = /opt/snax-mac-rtl/bin/snitch_cluster.vlt | ||
CFLAGS += -I$(SNITCH_SW_PATH)/target/snitch_cluster/sw/snax/mac/include | ||
LDFLAGS += $(SNITCH_SW_PATH)/target/snitch_cluster/sw/snax/mac/build/mac.o |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
remove comment