Skip to content

Commit

Permalink
Add memref calling convention (#18)
Browse files Browse the repository at this point in the history
* Add memref wrappers

* Use common memref header

* Clean up use of memrefs in main

* Remove old comments
  • Loading branch information
JosseVanDelm authored Nov 15, 2023
1 parent e4655f6 commit 9a02735
Show file tree
Hide file tree
Showing 4 changed files with 63 additions and 14 deletions.
6 changes: 4 additions & 2 deletions kernels/simple_mult/baseline.c
Original file line number Diff line number Diff line change
@@ -1,12 +1,14 @@
#include "data.h"
#include "memref.h"

#include <snrt.h>

#include <stdint.h>

// Element-wise multiply: for each i in [0, N) stores A[i] * B[i] into D[i].
//
// This span is a rendered diff with the +/- markers stripped. Going by the
// hunk summary (4 additions, 2 deletions), the bare-pointer signature and the
// `D[i] = ...` line appear to be the removed version, and the
// `_mlir_ciface_` signature and the `aligned_data` line the added version.
void simple_mult(const int32_t *A, const int32_t *B, int32_t *D) {
void _mlir_ciface_simple_mult(OneDMemrefI32_t *A, OneDMemrefI32_t *B,
OneDMemrefI32_t *D) {
// Trip count is taken from N rather than from the memref's shape field.
// (N is presumably provided by data.h -- confirm.)
const uint32_t n = N;
for (uint32_t i = 0; i < n; ++i) {
D[i] = A[i] * B[i];
// Indexes aligned_data directly; the offset and stride fields of the
// descriptors are not consulted.
D->aligned_data[i] = A->aligned_data[i] * B->aligned_data[i];
}
}
52 changes: 41 additions & 11 deletions kernels/simple_mult/main.c
Original file line number Diff line number Diff line change
@@ -1,15 +1,18 @@
#include "data.h"
#include "mac.h"
#include "memref.h"
#include "stdint.h"

#include <snrt.h>

// Kernel provided via external definition
void simple_mult(int32_t *a, int32_t *b, int32_t *d);
void _mlir_ciface_simple_mult(OneDMemrefI32_t *a, OneDMemrefI32_t *b,
OneDMemrefI32_t *d);

// Multiply routed through the snax_mac_* calls, invoked in this order:
// snax_mac_setup_simple_mult, snax_mac_launch, snax_mac_sw_barrier.
//
// This span is a rendered diff with the +/- markers stripped: the
// bare-pointer signature and the setup call taking N appear to be the removed
// version; the `_mlir_ciface_` signature and the setup call reading the
// memref fields appear to be the added version.
void snax_hwpe_mult(int32_t *a, int32_t *b, int32_t *d) {
// shape of data is statically defined in data.h
snax_mac_setup_simple_mult(a, b, d, N);
void _mlir_ciface_snax_hwpe_mult(OneDMemrefI32_t *a, OneDMemrefI32_t *b,
OneDMemrefI32_t *d) {
// NOTE(review): memref.h declares shape[0] as `uint32_t *`, and main() reads
// the element count as `*(memrefA.shape[0])`. Here the pointer itself is
// passed in the argument position where the old call passed N -- confirm
// whether a dereference is intended.
snax_mac_setup_simple_mult(a->aligned_data, b->aligned_data, d->aligned_data,
a->shape[0]);
snax_mac_launch();
snax_mac_sw_barrier();
}
Expand All @@ -20,14 +23,41 @@ int main() {
// (snrt_l1_next()) that is the same for all the cores in the cluster, we are
// essentially providing the same memory regions to all the cores in this
// cluster.
int32_t *local_A = (int32_t *)snrt_l1_next();
int32_t *local_B = local_A + N;
int32_t *local_D = local_B + N;

uint32_t constant_zero = 0;
uint32_t constant_size = N;
// Allocate memory for the fields

OneDMemrefI32_t memrefA = {
.data = (int32_t *)snrt_l1_next(),
.aligned_data = memrefA.data,
.offset = &constant_zero,
.shape[0] = &constant_size,
.stride[0] = &constant_zero,
};

OneDMemrefI32_t memrefB = {
.data = (int32_t *)memrefA.data + N,
.aligned_data = memrefB.data,
.offset = &constant_zero,
.shape[0] = &constant_size,
.stride[0] = &constant_zero,
};

OneDMemrefI32_t memrefD = {
.data = (int32_t *)memrefB.data + N,
.aligned_data = memrefD.data,
.offset = &constant_zero,
.shape[0] = &constant_size,
.stride[0] = &constant_zero,
};

// Copy data in shared local memory
if (snrt_is_dm_core()) {
snrt_dma_start_1d(local_A, A, N * sizeof(float));
snrt_dma_start_1d(local_B, B, N * sizeof(float));
snrt_dma_start_1d(memrefA.aligned_data, A,
*(memrefA.shape[0]) * sizeof(int32_t));
snrt_dma_start_1d(memrefB.aligned_data, B,
*(memrefB.shape[0]) * sizeof(int32_t));
}

snrt_cluster_hw_barrier();
Expand All @@ -38,13 +68,13 @@ int main() {
return 0;

(void)snrt_mcycle();
simple_mult(local_A, local_B, local_D);
_mlir_ciface_simple_mult(&memrefA, &memrefB, &memrefD);
(void)snrt_mcycle();

// Correctness check
int nerr = 0;
for (int i = 0; i < N; i++) {
int32_t error = local_D[i] - G[i];
int32_t error = memrefD.aligned_data[i] - G[i];
if (error != 0)
nerr += 1;
}
Expand Down
4 changes: 3 additions & 1 deletion runtime/Makefile.rules
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@ CFLAGS += -I$(SNITCH_RTL_PATH)/sw/math/src/internal
CFLAGS += -I$(SNITCH_RTL_PATH)/sw/math/include/bits
CFLAGS += -I$(SNITCH_RTL_PATH)/sw/math/include
CFLAGS += -I$(SNITCH_RTL_PATH)/target/snitch_cluster/sw/snax/mac/include
CFLAGS += -I$(MAKEFILE_RULES_DIRNAME)include
CFLAGS += -D__DEFINED_uint64_t
CFLAGS += -menable-experimental-extensions
CFLAGS += -mcpu=snitch
Expand Down Expand Up @@ -103,8 +104,9 @@ MLIROPTFLAGS += --convert-scf-to-cf
MLIROPTFLAGS += --canonicalize
MLIROPTFLAGS += --cse
MLIROPTFLAGS += --convert-math-to-llvm
MLIROPTFLAGS += --llvm-request-c-wrappers
MLIROPTFLAGS += --convert-memref-to-llvm='use-generic-functions index-bitwidth=32'
MLIROPTFLAGS += --convert-func-to-llvm='use-bare-ptr-memref-call-conv index-bitwidth=32'
MLIROPTFLAGS += --convert-func-to-llvm='index-bitwidth=32'
MLIROPTFLAGS += --convert-index-to-llvm=index-bitwidth=32
MLIROPTFLAGS += --convert-cf-to-llvm=index-bitwidth=32
MLIROPTFLAGS += --convert-arith-to-llvm=index-bitwidth=32
Expand Down
15 changes: 15 additions & 0 deletions runtime/include/memref.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
#pragma once

#include <stdint.h>

/**
 * Descriptor for a one-dimensional memref of int32_t elements.
 *
 * Both the struct tag (`struct OneDMemrefI32`) and the `OneDMemrefI32_t`
 * alias name this type.
 *
 * NOTE(review): offset, shape and stride are declared as pointers to
 * uint32_t. MLIR's documented memref descriptor carries these as integer
 * values -- TODO confirm that pointer types are intended here.
 */
typedef struct OneDMemrefI32 {
  /** Allocated pointer: the data buffer as allocated; only used for
   *  deallocating the memref. */
  int32_t *data;
  /** Aligned pointer: the properly aligned data that the memref indexes. */
  int32_t *aligned_data;
  /** Points to the offset value. */
  uint32_t *offset;
  /** Single-entry array (one dimension): points to the size. */
  uint32_t *shape[1];
  /** Single-entry array (one dimension): points to the stride. */
  uint32_t *stride[1];
} OneDMemrefI32_t;

0 comments on commit 9a02735

Please sign in to comment.