Add simple matmul on gemm accelerator #39

Merged: 16 commits, Dec 18, 2023
Changes from 13 commits

5 changes: 5 additions & 0 deletions .github/workflows/build-run-kernel.yml
@@ -26,3 +26,8 @@ jobs:
export PATH=/opt/python3.11/bin:$PATH
make allrun
working-directory: kernels/simple_copy
- name: Build and run kernel simple_matmul
run: |
export PATH=/opt/python3.11/bin:$PATH
make allrun
working-directory: kernels/simple_matmul
33 changes: 33 additions & 0 deletions kernels/simple_matmul/Makefile
@@ -0,0 +1,33 @@
# Courtesy of Federico Ficarelli

.DEFAULT_GOAL := all

include ../../runtime/snax-gemm.rules
include ../../runtime/Makefile.rules

TESTS =
TESTS += cpu.x

CFLAGS += -std=gnu11
CFLAGS += -Wall -Wextra

data.c data.h:
	$(PYTHON) gendata.py

%.x: %.o main.o data.o
	$(LD) $(LDFLAGS) $^ -o $@

sim_%: %
	rm -fr ./logs/
	$(VLTSIM) $<

RUN = $(addprefix run_, $(TESTS))
$(RUN): run_%: sim_%
	mv logs $(subst sim_,,$<).logs

all: $(TESTS)

allrun: $(RUN)

clean:
	rm -fr *.ll12 *.x *.o *.logs/ logs/ data.h data.c
7 changes: 7 additions & 0 deletions kernels/simple_matmul/cpu.mlir
@@ -0,0 +1,7 @@
func.func public @simple_matmul(%A: memref<16x16xi8, 1 : i32>,
                                %B: memref<16x32xi8, 1 : i32>,
                                %C: memref<16x32xi32, 1 : i32>) -> () {
  func.call @simple_matmul_cpu(%A, %B, %C) : (memref<16x16xi8, 1 : i32>, memref<16x32xi8, 1 : i32>, memref<16x32xi32, 1 : i32>) -> ()
  return
}
func.func private @simple_matmul_cpu(%A : memref<16x16xi8, 1 : i32>, %B : memref<16x32xi8, 1 : i32>, %C : memref<16x32xi32, 1 : i32>)
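
For context: the private @simple_matmul_cpu declared above is an external function that resolves to the C code in kernels/simple_matmul/main.c further down in this diff. A minimal sketch of the expected symbol pairing, assuming the standard MLIR _mlir_ciface_ wrapper convention (the lowering pipeline itself is not shown in this PR):

// Assumed C-interface view of cpu.mlir (sketch, not part of the diff):
//  - public @simple_matmul is exposed as _mlir_ciface_simple_matmul,
//    which main.c calls on the compute core;
//  - private @simple_matmul_cpu resolves to the C definition
//    _mlir_ciface_simple_matmul_cpu provided in main.c.
void _mlir_ciface_simple_matmul(TwoDMemrefI8_t *a, TwoDMemrefI8_t *b,
                                TwoDMemrefI32_t *c);
void _mlir_ciface_simple_matmul_cpu(TwoDMemrefI8_t *a, TwoDMemrefI8_t *b,
                                    TwoDMemrefI32_t *c);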
87 changes: 87 additions & 0 deletions kernels/simple_matmul/gendata.py
@@ -0,0 +1,87 @@
# simple script to generate inputs and expected outputs for simple_matmul
import numpy as np
from numpy import typing as npt
from typing import Dict


def create_header(
    file_name: str, sizes: Dict[str, int], variables: Dict[str, npt.NDArray]
) -> None:
    with open(file_name, "w") as f:
        includes = ["#include <stdint.h>", "#pragma once", ""]
        includes = "\n".join(includes)
        variables_string = [""]
        for i, j in sizes.items():
            variables_string.append(f"#define {i} {j}")
        variables_string.append("")
        for i, j in variables.items():
            variables_string.append(f"extern const {j.dtype}_t {i}[{j.size}];")
        variables_string = "\n".join(variables_string)
        f.write(includes)
        f.write(variables_string)
        f.write("\n")


def create_data(file_name: str, variables: Dict[str, npt.NDArray]):
    includes = ['#include "data.h"', "", ""]
    includes = "\n".join(includes)
    variables = {i: np.reshape(j, j.size) for i, j in variables.items()}
    with open(file_name, "w") as f:
        f.write(includes)
        for variable_name, variable_value in variables.items():
            f.write(
                f"const {variable_value.dtype}_t {variable_name}"
                + f"[{variable_value.size}] = "
                + "{\n"
            )
            variable_str = ["\t" + str(i) for i in variable_value]
            f.write(",\n".join(variable_str))
            f.write("\n};\n\n")


if __name__ == "__main__":
    # Reset random seed for reproducible behavior
    low_bound = -128
    high_bound = 127
    A_size = [16, 16]
    B_size = [16, 16]
    np.random.seed(0)

    # C = A.B
    A = np.random.randint(low_bound, high_bound, size=A_size, dtype=np.dtype("int8"))
    # A = np.ones(A_size, dtype=np.dtype("int8"))

Review comment (Contributor): remove comment

    B = np.random.randint(low_bound, high_bound, size=B_size, dtype=np.dtype("int8"))
    # B = np.ones(B_size, dtype=np.dtype("int8"))

Review comment (Contributor): remove comment

    C_golden = np.matmul(A.astype(np.dtype("int32")), B.astype(np.dtype("int32")))
    C = np.zeros(C_golden.shape, np.dtype("int32"))

    assert A.shape[1] == B.shape[0]

Review comment (Contributor): The assert must come before the np.matmul computation. If the shapes do not correspond, the matmul would fail anyway, making the assert useless.

    sizes = {"N_size": A.shape[0], "K_size": A.shape[1], "M_size": B.shape[1]}

    # Perform layout transformations before writing to memory

    # convert from row-major to block row-major
    A_new_layout = np.reshape(A, [2, 8, 2, 8])
    # convert to [2,2,8,8]
    A_new_layout = np.swapaxes(A_new_layout, 1, 2)

    B_new_layout = np.transpose(B)
    # convert from column-major to block column-major
    B_new_layout = np.reshape(B_new_layout, [2, 8, 2, 8])
    # convert to [2,2,8,8]
    B_new_layout = np.swapaxes(B_new_layout, 1, 2)
    # convert from row-major to block row-major
    C_golden_new_layout = np.reshape(C_golden, [2, 8, 2, 8])
    # convert to [2,2,8,8]
    C_golden_new_layout = np.swapaxes(C_golden_new_layout, 1, 2)

    # C is just all zeros, so its layout is not important
    variables = {
        "A": A_new_layout,
        "B": B_new_layout,
        "C_golden": C_golden_new_layout,
        "C": C,
    }

    create_header("data.h", sizes, variables)
    create_data("data.c", variables)
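
For intuition about the reshape/swapaxes pair above: element (r, c) of the row-major 16x16 matrix ends up in 8x8 block (r/8, c/8) at local position (r%8, c%8). A standalone C sketch of that mapping (illustrative helper only, not part of this PR):

// Illustrative only: the [2,8,2,8] reshape followed by swapaxes(1,2) in
// gendata.py places row-major element (r, c) into 8x8 block (r/8, c/8)
// at local position (r%8, c%8).
#include <stdint.h>

static void to_block_row_major(const int8_t src[16][16],
                               int8_t dst[2][2][8][8]) {
  for (int r = 0; r < 16; r++)
    for (int c = 0; c < 16; c++)
      dst[r / 8][c / 8][r % 8][c % 8] = src[r][c];
}

For B the same blocking is applied after a transpose, which gives the block column-major layout mentioned in the comments.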
132 changes: 132 additions & 0 deletions kernels/simple_matmul/main.c
@@ -0,0 +1,132 @@
#include "data.h"
#include "memref.h"
#include "snax-gemm-lib.h"
#include "snax-gemm-params.h"

Review comment (Contributor): To me it feels kind of weird to include header files from another repository without specifying where they come from 😕. I guess it also doesn't make sense to have them duplicated... Is there a way to share this in a better way?

Review comment (Contributor): If not, maybe make it clear with a comment where these files come from.

Review comment (Contributor Author): Agreed! Clear like this?

#include "snax_rt.h"
#include "stdint.h"

#include <snrt.h>
#include <stdint.h>

uint8_t Batch = 1;
// meshRow, tileSize and meshCol are defined in snax-gemm-params.h
uint8_t M_param = M_size / meshRow;
uint8_t K_param = K_size / tileSize;
uint8_t N_param = N_size / meshCol;
// Extracted from datagen.py in snitch_cluster repo
uint32_t strideInnermostA = 256;
uint32_t strideInnermostB = 256;
uint32_t strideInnermostC = 256;
uint32_t ldA = 512;
uint32_t ldB = 512;
uint32_t ldC = 512;
uint32_t strideA = 0;
uint32_t strideB = 0;
uint32_t strideC = 0;

// Kernel provided via external definition
void _mlir_ciface_simple_matmul(TwoDMemrefI8_t *a, TwoDMemrefI8_t *b,
                                TwoDMemrefI32_t *c);

void _mlir_ciface_simple_matmul_cpu(TwoDMemrefI8_t *a, TwoDMemrefI8_t *b,
                                    TwoDMemrefI32_t *c) {
  int8_t *a_ptr = a->aligned_data;
  int8_t *b_ptr = b->aligned_data;
  int32_t *c_ptr = c->aligned_data;
  batch_gemm_cpu(Batch, M_param, K_param, N_param, a_ptr, b_ptr, c_ptr,
                 strideInnermostA, strideInnermostB, strideInnermostC, ldA, ldB,
                 ldC, strideA, strideB, strideC);
}

int main() {
  // Allocate space in TCDM
  // We put the data in different banks, but we don't interleave the data for
  // now.
  //
  // | A | x | x | x | --> A in banks 0 - 7  --> (8/32 banks used)*
  //                       (int8 --> 8 elements/bank)
  //                       1 row --> 64 elements
  // | x | B | x | x | --> B in banks 8 - 15 --> (8/32 banks used)*
  //                       (8 elements/bank)*32 banks
  //                       1 row --> 64 elements
  // | C | C | C | C | --> C in banks 0 - 31 --> (32/32 banks used)*
  //                       (2 elements/bank)*32 banks
  //                       1 row --> 64 elements
  // | x | x | x | x |
  //
  // 32 banks --> 1 row = 32 banks * 8 bytes --> 256 addresses further

  static int8_t *allocated_a;
  static int8_t *allocated_b;
  static int32_t *allocated_c;

  // Transfer data from L3 to L1
  // Using DMA only
  if (snrt_is_dm_core()) {
    // calculation in bytes directly
    allocated_a = (int8_t *)snrt_l1alloc(256 * M_size * K_size / 64);
    allocated_b = (int8_t *)snrt_l1alloc(256 * K_size * N_size / 64);
    allocated_c = (int32_t *)snrt_l1alloc(256 * M_size * N_size / 64);
  }
  snrt_cluster_hw_barrier();

  // Create memref descriptors for data stored in L1
  TwoDMemrefI8_t memrefA;
  memrefA.data = allocated_a;
  memrefA.aligned_data = memrefA.data;
  memrefA.shape[0] = M_size;
  memrefA.shape[1] = K_size;
  // These are not considered correctly right now

Review comment (Contributor): Comment not very clear.

Suggested change:
- // These are not considered correctly right now
+ // Strides are not used due to the tiled-block layout.
+ // Instead we use the variables strideInnermostA, ldA and strideA

  memrefA.offset = 0;
  memrefA.stride[0] = sizeof(int8_t);
  memrefA.stride[1] = sizeof(int8_t);

Review comment (Contributor): If we were to use a standard layout, this would also not be correct, but rather:

Suggested change:
- memrefA.stride[0] = sizeof(int8_t);
- memrefA.stride[1] = sizeof(int8_t);
+ memrefA.stride[0] = sizeof(int8_t);
+ memrefA.stride[1] = sizeof(int8_t) * M_size;

Maybe just set them to 0 to make it very clear we are not using them.

  TwoDMemrefI8_t memrefB;
  memrefB.data = allocated_b;
  // Data is stored in banks 8 - 15, so increment by 8 banks * 8 bytes = 64
  memrefB.aligned_data = memrefB.data + 64;
  memrefB.shape[0] = K_size;
  memrefB.shape[1] = N_size;
  // These are not considered correctly right now
JosseVanDelm marked this conversation as resolved.
  memrefB.offset = 0;
  memrefB.stride[0] = sizeof(int8_t);
  memrefB.stride[1] = sizeof(int8_t);

  TwoDMemrefI32_t memrefC;
  memrefC.data = allocated_c;
  memrefC.aligned_data = memrefC.data;
  memrefC.shape[0] = M_size;
  memrefC.shape[1] = N_size;
  // These are not considered correctly right now
JosseVanDelm marked this conversation as resolved.
  memrefC.offset = 0;
  memrefC.stride[0] = sizeof(int32_t);
  memrefC.stride[1] = sizeof(int32_t);

  if (snrt_is_dm_core()) {
    load_input_data(Batch, M_size / meshRow, K_size / tileSize,
                    N_size / meshCol, memrefA.aligned_data,
                    memrefB.aligned_data, A, B, strideInnermostA,
                    strideInnermostB, ldA, ldB, strideA, strideB);
  }
  snrt_cluster_hw_barrier();
  (void)snrt_mcycle();
  if (snrt_is_compute_core()) {
    _mlir_ciface_simple_matmul(&memrefA, &memrefB, &memrefC);
  }
  (void)snrt_mcycle();
  snrt_cluster_hw_barrier();

  // Correctness check -
  // from this point on only core 0 is required to be alive.
  int thiscore = snrt_cluster_core_idx();
  if (thiscore != 0)
    return 0;

  int nerr = 0;
  for (int i = 0; i < M_size * N_size; i++) {
    // printf("%d , golden : %d\n", memrefC.aligned_data[i], C_golden[i]);
    int32_t error = memrefC.aligned_data[i] - C_golden[i];
    if (error != 0)
      nerr += 1;
  }
  return nerr;
}
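
A note on the snrt_l1alloc sizes above: each 64-element chunk is given a full TCDM row of 32 banks * 8 bytes = 256 bytes, even when it only fills 8 of the 32 banks (as for the int8 A and B). A standalone sketch of that arithmetic, with the 16x16 sizes assumed from gendata.py (a reading of the comments in main(), not a definitive statement of the library's requirements):

// Sketch of the TCDM sizing used in main(): one full 256-byte TCDM row
// (32 banks * 8 bytes) is reserved per 64-element chunk of each operand.
#include <stdio.h>

int main(void) {
  const int tcdm_row_bytes = 32 * 8; // 256
  const int m = 16, k = 16, n = 16;  // sizes assumed from gendata.py
  printf("A: %d bytes\n", tcdm_row_bytes * m * k / 64); // 1024
  printf("B: %d bytes\n", tcdm_row_bytes * k * n / 64); // 1024
  printf("C: %d bytes\n", tcdm_row_bytes * m * n / 64); // 1024
  return 0;
}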
1 change: 1 addition & 0 deletions kernels/simple_mult/main.c
@@ -55,6 +55,7 @@ int main() {

  int nerr = 0;
  for (int i = 0; i < N; i++) {
    printf("result: %d golden: %d\n", memrefD.aligned_data[i], G[i]);
JosseVanDelm marked this conversation as resolved.

Review comment (Contributor): To delete then?

Review comment (Contributor Author): Will make a separate PR for this.

Review comment (Contributor Author): Wait, no.

Review comment (Contributor Author): It's gone now, haha.

    int32_t error = memrefD.aligned_data[i] - G[i];
    if (error != 0)
      nerr += 1;
4 changes: 0 additions & 4 deletions runtime/Makefile.rules
@@ -15,7 +15,6 @@ MLIRTRANSLATE = mlir-translate-16
SNAXOPT = $(MAKEFILE_RULES_DIRNAME)/../compiler/snax-opt
PYTHON = /opt/python3.11/bin/python3

-CFLAGS =
# Mixing .c and .ll files makes some flags, useful for the former,
# unused for the latter (e.g. -I)
CFLAGS += -Wno-unused-command-line-argument
@@ -32,7 +31,6 @@ CFLAGS += -I$(SNITCH_SW_PATH)/sw/math/src/include
CFLAGS += -I$(SNITCH_SW_PATH)/sw/math/src/internal
CFLAGS += -I$(SNITCH_SW_PATH)/sw/math/include/bits
CFLAGS += -I$(SNITCH_SW_PATH)/sw/math/include
-CFLAGS += -I$(SNITCH_SW_PATH)/target/snitch_cluster/sw/snax/mac/include
CFLAGS += -I$(MAKEFILE_RULES_DIRNAME)include
CFLAGS += -D__DEFINED_uint64_t
CFLAGS += -menable-experimental-extensions
@@ -45,7 +43,6 @@ CFLAGS += -fno-builtin-printf
CFLAGS += -fno-common
CFLAGS += -O3

-LDFLAGS =
LDFLAGS += -fuse-ld=$(SNITCH_LLVM_PATH)/bin/ld.lld
LDFLAGS += -L$(SNITCH_LLVM_PATH)/lib/clang/12.0.1/lib/
LDFLAGS += -T$(SNITCH_SW_PATH)/sw/snRuntime/base.ld
@@ -56,7 +53,6 @@ LDFLAGS += -nostdlib
LDFLAGS += -lclang_rt.builtins-riscv32
LDFLAGS += -lc
LDFLAGS += -lsnRuntime
-LDFLAGS += $(SNITCH_SW_PATH)/target/snitch_cluster/sw/snax/mac/build/mac.o

# useful for debugging at llvm level:
%.ll: %.c
22 changes: 22 additions & 0 deletions runtime/include/memref.h

Review comment (Contributor): The MLIR docs show quite a nice way to have only one memref struct specification for all types, in C++:

template<typename T, size_t N>
struct MemRefDescriptor {
  T *allocated;
  T *aligned;
  intptr_t offset;
  intptr_t sizes[N];
  intptr_t strides[N];
};

Do you know if something similar is possible in C? If not, I guess this is fine for now.

Review comment (Contributor Author, @JosseVanDelm, Dec 18, 2023): This is possible, but it looks a bit complicated (https://isocpp.org/wiki/faq/mixing-c-and-cpp), let's take this for a future PR?

Review comment (Contributor): Yes, good!

@@ -12,4 +12,26 @@ struct OneDMemrefI32 {
  uint32_t stride[1];
};

struct TwoDMemrefI32 {
  int32_t *data;         // allocated pointer: Pointer to data buffer as allocated,
                         // only used for deallocating the memref
  int32_t *aligned_data; // aligned pointer: Pointer to properly aligned data
                         // that memref indexes
  uint32_t offset;
  uint32_t shape[2];
  uint32_t stride[2];
};

struct TwoDMemrefI8 {
  int8_t *data;         // allocated pointer: Pointer to data buffer as allocated,
                        // only used for deallocating the memref
  int8_t *aligned_data; // aligned pointer: Pointer to properly aligned data
                        // that memref indexes
  uint32_t offset;
  uint32_t shape[2];
  uint32_t stride[2];
};

typedef struct OneDMemrefI32 OneDMemrefI32_t;
typedef struct TwoDMemrefI8 TwoDMemrefI8_t;
typedef struct TwoDMemrefI32 TwoDMemrefI32_t;
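
Picking up the review thread above, a minimal C sketch of how a single descriptor definition could be stamped out per element type, mirroring the C++ template (the macro below is hypothetical and not part of this PR):

#include <stdint.h>

// Hypothetical macro (not in this PR): defines a 2-D memref descriptor
// struct plus its typedef for a given element type.
#define DEFINE_TWOD_MEMREF(name, T)                                           \
  typedef struct name {                                                       \
    T *data;         /* allocated pointer, used for deallocation */           \
    T *aligned_data; /* aligned pointer that the memref indexes */            \
    uint32_t offset;                                                          \
    uint32_t shape[2];                                                        \
    uint32_t stride[2];                                                       \
  } name##_t

DEFINE_TWOD_MEMREF(TwoDMemrefI8, int8_t);
DEFINE_TWOD_MEMREF(TwoDMemrefI32, int32_t);

The element type still has to be spelled out per instantiation, so this trades the C++ template's genericity for some preprocessor plumbing; whether that is clearer than the two explicit structs is a style call for the follow-up PR.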
5 changes: 5 additions & 0 deletions runtime/snax-gemm.rules
@@ -0,0 +1,5 @@
# Specific settings for snax-gemm RTL
SNITCH_SW_PATH = /opt/snax-gemm
VLTSIM = /opt/snax-gemm-rtl/bin/snitch_cluster.vlt
CFLAGS += -I$(SNITCH_SW_PATH)/target/snitch_cluster/sw/snax/gemm/include
LDFLAGS += $(SNITCH_SW_PATH)/target/snitch_cluster/sw/snax/gemm/build/snax-gemm-lib.o
2 changes: 2 additions & 0 deletions runtime/snax-mac.rules
@@ -1,3 +1,5 @@
# Specific settings for snax-mac RTL
SNITCH_SW_PATH = /opt/snax-mac
VLTSIM = /opt/snax-mac-rtl/bin/snitch_cluster.vlt
+CFLAGS += -I$(SNITCH_SW_PATH)/target/snitch_cluster/sw/snax/mac/include
+LDFLAGS += $(SNITCH_SW_PATH)/target/snitch_cluster/sw/snax/mac/build/mac.o