Skip to content

Commit

Permalink
simple_mult end to end (#35)
Browse files Browse the repository at this point in the history
* add makefile passes

* move memory management to linalg

* clear memory space pass

* snax renaming

* snax renaming

* add new passes to makefile

* fix regex

* add snax runtime library

* update renaming in tests

* fix simple copy kernel

* fix alloc kernel

* use regular ints

* better comment
  • Loading branch information
jorendumoulin authored Dec 5, 2023
1 parent 8ff15dd commit 82a4394
Show file tree
Hide file tree
Showing 14 changed files with 113 additions and 73 deletions.
2 changes: 2 additions & 0 deletions compiler/tools/snax_opt_main.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
from compiler.transforms.dispatch_regions import DispatchRegions
from compiler.transforms.snax_copy_to_dma import SNAXCopyToDMA
from compiler.transforms.snax_to_func import SNAXToFunc
from compiler.transforms.clear_memory_space import ClearMemorySpace
from collections.abc import Sequence


Expand Down Expand Up @@ -37,6 +38,7 @@ def __init__(
super().register_pass(DispatchRegions)
super().register_pass(SNAXCopyToDMA)
super().register_pass(SNAXToFunc)
super().register_pass(ClearMemorySpace)

# arg handling
arg_parser = argparse.ArgumentParser(description=description)
Expand Down
36 changes: 36 additions & 0 deletions compiler/transforms/clear_memory_space.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,36 @@
from xdsl.dialects import builtin, memref, func
from xdsl.ir import MLContext
from xdsl.passes import ModulePass


class ClearMemorySpace(ModulePass):
    """Rewrite memref types in a module so that they carry no memory space.

    The pass visits every op returned by ``module.walk()``, replaces the
    type of each operand that is a memref with a memory space by an
    equivalent memref type whose memory space is ``builtin.NoneAttr()``, and
    rebuilds the ``function_type`` of every ``func.FuncOp`` the same way.
    """

    name = "clear-memory-space"

    def apply(self, ctx: MLContext, module: builtin.ModuleOp) -> None:
        """Strip memory spaces from operand types and function signatures.

        ``ctx`` is not used by this pass. ``module`` is modified in place.
        """

        def clear_memory_space(t):
            """Return ``t`` without a memory space; non-memrefs pass through."""
            if not isinstance(t, memref.MemRefType):
                return t
            if isinstance(t.memory_space, builtin.NoneAttr):
                # already in the default memory space, nothing to rebuild
                return t
            return memref.MemRefType.from_element_type_and_shape(
                t.element_type,
                t.get_shape(),
                t.layout,
                builtin.NoneAttr(),
            )

        for current_op in module.walk():
            # Values are reached through the ops that consume them.
            # NOTE(review): a value that is never used as an operand is not
            # visited by this loop -- confirm that this is acceptable.
            for value in current_op.operands:
                value.type = clear_memory_space(value.type)

            if not isinstance(current_op, func.FuncOp):
                continue

            # func ops have no operands; their argument and result types
            # live in function_type (ins & outs), so build a new function
            # type with every input and output mapped to the default
            # memory space.
            signature = current_op.function_type
            current_op.function_type = builtin.FunctionType.from_lists(
                [clear_memory_space(arg_type) for arg_type in signature.inputs],
                [clear_memory_space(res_type) for res_type in signature.outputs],
            )
16 changes: 8 additions & 8 deletions compiler/transforms/dispatch_regions.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,7 @@ def dispatcher(block: Block, core_cond: builtin.i1, dispatch_rule: Callable):

## dispatch dm core ops, insert function call
# in dominator block if changes made
func_call_dm = func.Call("snrt_is_dm_core", [], [builtin.i1])
func_call_dm = func.Call("snax_is_dm_core", [], [builtin.i1])
if any(
dispatcher(block, func_call_dm.res[0], dispatch_to_dm)
for block in func_op.body.blocks
Expand All @@ -70,7 +70,7 @@ def dispatcher(block: Block, core_cond: builtin.i1, dispatch_rule: Callable):

## dispatch compute core ops, insert function call
# in dominator block if changes made
func_call_compute = func.Call("snrt_is_compute_core", [], [builtin.i1])
func_call_compute = func.Call("snax_is_compute_core", [], [builtin.i1])
if any(
dispatcher(block, func_call_compute.res[0], dispatch_to_compute)
for block in func_op.body.blocks
Expand All @@ -80,21 +80,21 @@ def dispatcher(block: Block, core_cond: builtin.i1, dispatch_rule: Callable):


class InsertFunctionDeclaration(RewritePattern):
"""Insert external function declarations of snrt_is_compute core
and snrt_is_dm_core if they are used in the module"""
"""Insert external function declarations of snax_is_compute core
and snax_is_dm_core if they are used in the module"""

@op_type_rewrite_pattern
def match_and_rewrite(self, module_op: builtin.ModuleOp, rewriter: PatternRewriter):
for op in module_op.walk():
if isinstance(op, func.Call):
if op.callee.string_value() == "snrt_is_compute_core":
if op.callee.string_value() == "snax_is_compute_core":
func_op_compute = func.FuncOp.external(
"snrt_is_compute_core", [], [builtin.i1]
"snax_is_compute_core", [], [builtin.i1]
)
SymbolTable.insert_or_update(module_op, func_op_compute)
if op.callee.string_value() == "snrt_is_dm_core":
if op.callee.string_value() == "snax_is_dm_core":
func_op_dm = func.FuncOp.external(
"snrt_is_dm_core", [], [builtin.i1]
"snax_is_dm_core", [], [builtin.i1]
)
SymbolTable.insert_or_update(module_op, func_op_dm)

Expand Down
2 changes: 1 addition & 1 deletion compiler/transforms/set_memory_space.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,7 @@ def change_to_memory_space(t):
# mapped to a default memory space
new_function_type = builtin.FunctionType.from_lists(
map(change_to_memory_space, op.function_type.inputs),
map(change_to_memory_space, op.function_type.outputs),
(op.function_type.outputs),
)

# Change region of function to use new argument types
Expand Down
6 changes: 3 additions & 3 deletions compiler/transforms/snax_to_func.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,16 +15,16 @@ class InsertFunctionCall(RewritePattern):
@op_type_rewrite_pattern
def match_and_rewrite(self, func_op: snax.ClusterSyncOp, rewriter: PatternRewriter):
"""Swap cluster sync op with function call"""
func_call = func.Call("snrt_cluster_hw_barrier", [], [])
func_call = func.Call("snax_cluster_hw_barrier", [], [])
rewriter.replace_matched_op(func_call)


class InsertFunctionDeclaration(RewritePattern):
"""Insert external function declarations of snrt_cluster_hw_barrier"""
"""Insert external function declarations of snax_cluster_hw_barrier"""

@op_type_rewrite_pattern
def match_and_rewrite(self, module_op: builtin.ModuleOp, rewriter: PatternRewriter):
func_op = func.FuncOp.external("snrt_cluster_hw_barrier", [], [])
func_op = func.FuncOp.external("snax_cluster_hw_barrier", [], [])
SymbolTable.insert_or_update(module_op, func_op)


Expand Down
6 changes: 3 additions & 3 deletions kernels/alloc/func.mlir
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
func.func public @simple_alloc() -> (memref<10xi32>) {
%alloc = "memref.alloc"() {"alignment" = 64 : i64, operand_segment_sizes = array<i32: 0, 0>} : () -> memref<10xi32>
return %alloc : memref<10xi32>
func.func public @simple_alloc() -> (memref<10xi32, 1 : i32>) {
%alloc = "memref.alloc"() {"alignment" = 64 : i64, operand_segment_sizes = array<i32: 0, 0>} : () -> memref<10xi32, 1 : i32>
return %alloc : memref<10xi32, 1 : i32>
}
16 changes: 1 addition & 15 deletions kernels/alloc/main.c
Original file line number Diff line number Diff line change
@@ -1,26 +1,12 @@
#include "memref.h"
#include "snax_rt.h"
#include "stdint.h"
#include <snrt.h>

// Kernel provided via external definition
// The C interface converts the pass by value to a pass by reference!
void _mlir_ciface_simple_alloc(OneDMemrefI32_t *returned_alloc);

int8_t *allocated_pointer;

int8_t *_mlir_memref_to_llvm_alloc(uint32_t size) {
/* This calls malloc on the DMA core
* --> requires mlir opt to compile with:
* --convert-memref-to-llvm="use-generic-functions index-bitwidth=32"
* To ensure that all cores in the cluster come up with the correct
*/
if (snrt_is_dm_core()) {
allocated_pointer = (int8_t *)snrt_l1alloc(size);
}
snrt_cluster_hw_barrier();
return allocated_pointer;
};

int main() {
// Allocate memory for the fields
// static is required to have a global heap-allocated value for all cores
Expand Down
7 changes: 1 addition & 6 deletions kernels/simple_copy/main.c
Original file line number Diff line number Diff line change
@@ -1,15 +1,10 @@
#include "data.h"
#include "memref.h"
#include "snax_rt.h"
#include "stdint.h"
#include <snrt.h>
#include <stdint.h>

void _mlir_ciface_snax_dma_1d_transfer(size_t *source, size_t *destination,
size_t size) {
snrt_dma_start_1d((void *)destination, (void *)source, size * sizeof(size_t));
return;
}

int main() {

// create memref object for A
Expand Down
36 changes: 11 additions & 25 deletions kernels/simple_mult/main.c
Original file line number Diff line number Diff line change
@@ -1,9 +1,11 @@
#include "data.h"
#include "mac.h"
#include "memref.h"
#include "snax_rt.h"
#include "stdint.h"

#include <snrt.h>
#include <stdint.h>

// Kernel provided via external definition
void _mlir_ciface_simple_mult(OneDMemrefI32_t *a, OneDMemrefI32_t *b,
Expand All @@ -18,55 +20,39 @@ void _mlir_ciface_snax_hwpe_mult(OneDMemrefI32_t *a, OneDMemrefI32_t *b,
}

int main() {
// Allocate shared local memory
// By avoiding allocators and bumping by a known offset a base pointer
// (snrt_l1_next()) that is the same for all the cores in the cluster, we are
// essentially providing the same memory regions to all the cores in this
// cluster.

// Allocate memory for the fields

// Create memref objects for data stored in L1
OneDMemrefI32_t memrefA;
memrefA.data = (int32_t *)snrt_l1_next();
memrefA.data = &A;
memrefA.aligned_data = memrefA.data;
memrefA.offset = 0;
memrefA.shape[0] = N;
memrefA.stride[0] = sizeof(int32_t);

OneDMemrefI32_t memrefB;
memrefB.data = (int32_t *)memrefA.data + N;
memrefB.data = &B;
memrefB.aligned_data = memrefB.data;
memrefB.offset = 0;
memrefB.shape[0] = N;
memrefB.stride[0] = sizeof(int32_t);

OneDMemrefI32_t memrefD;
memrefD.data = (int32_t *)memrefB.data + N;
memrefD.data = &G;
memrefD.aligned_data = memrefD.data;
memrefD.offset = 0;
memrefD.shape[0] = N;
memrefD.stride[0] = sizeof(int32_t);

// Copy data in shared local memory
if (snrt_is_dm_core()) {
snrt_dma_start_1d(memrefA.aligned_data, A,
(memrefA.shape[0]) * sizeof(int32_t));
snrt_dma_start_1d(memrefB.aligned_data, B,
(memrefB.shape[0]) * sizeof(int32_t));
}

snrt_cluster_hw_barrier();
(void)snrt_mcycle();
_mlir_ciface_simple_mult(&memrefA, &memrefB, &memrefD);
(void)snrt_mcycle();

// Launch kernel: from this point on only core 0 is required to be alive.
// Correctness check -
// from this point on only core 0 is required to be alive.
int thiscore = snrt_cluster_core_idx();
if (thiscore != 0)
return 0;

(void)snrt_mcycle();
_mlir_ciface_simple_mult(&memrefA, &memrefB, &memrefD);
(void)snrt_mcycle();

// Correctness check
int nerr = 0;
for (int i = 0; i < N; i++) {
int32_t error = memrefD.aligned_data[i] - G[i];
Expand Down
2 changes: 1 addition & 1 deletion runtime/Makefile.rules
Original file line number Diff line number Diff line change
Expand Up @@ -88,7 +88,7 @@ MLIRPREPROCFLAGS += --mlir-print-local-scope

# SNAX opt

SNAXOPTFLAGS = -p dispatch-elementwise-mult,linalg-to-library-call,snax-copy-to-dma
SNAXOPTFLAGS = -p set-memory-space,insert-sync-barrier,dispatch-regions,dispatch-elementwise-mult,linalg-to-library-call,snax-copy-to-dma,snax-to-func,clear-memory-space

%.preproc18.mlir: %.preproc.mlir
$(MAKEFILE_RULES_DIRNAME)/tomlir18.py < $< > $@
Expand Down
34 changes: 34 additions & 0 deletions runtime/include/snax_rt.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,34 @@
#pragma once

#include <snrt.h>
#include <stdint.h>

// Pointer handed back by _mlir_memref_to_llvm_alloc: written by the DM core
// with the result of snrt_l1alloc and returned by every core that calls the
// allocation hook.
// NOTE(review): defined (not just declared) in a header -- confirm this header
// is only included from a single translation unit per binary.
int8_t *allocated_pointer;

int8_t *_mlir_memref_to_llvm_alloc(uint32_t size) {
  /* Allocation hook for MLIR-generated code.
   * --> requires mlir opt to compile with:
   *  --convert-memref-to-llvm="use-generic-functions index-bitwidth=32"
   *
   * Only the DM core performs the allocation (snrt_l1alloc) and stores the
   * result in allocated_pointer. Every core then passes through
   * snrt_cluster_hw_barrier() before returning that same stored pointer, so
   * that all cores in the cluster come up with the same address.
   */
  if (snrt_is_dm_core()) {
    void *fresh_buffer = snrt_l1alloc(size);
    allocated_pointer = (int8_t *)fresh_buffer;
  }

  // every core (DM or not) must reach the barrier before reading the pointer
  snrt_cluster_hw_barrier();

  int8_t *shared_result = allocated_pointer;
  return shared_result;
}

// C-interface entry point for the snax_cluster_hw_barrier function that the
// compiler emits calls to; it simply forwards to snrt_cluster_hw_barrier().
void _mlir_ciface_snax_cluster_hw_barrier() { snrt_cluster_hw_barrier(); }

// C-interface entry point for snax_dma_1d_transfer: starts a 1D DMA transfer
// of `size` elements, each sizeof(size_t) bytes wide, from `source` to
// `destination`.
// NOTE(review): this only starts the transfer; no wait for completion is
// visible here -- confirm that callers synchronize before using the data.
void _mlir_ciface_snax_dma_1d_transfer(size_t *source, size_t *destination,
                                       size_t size) {
  const size_t num_bytes = size * sizeof(size_t);
  // snrt_dma_start_1d takes the destination first, then the source
  snrt_dma_start_1d((void *)destination, (void *)source, num_bytes);
}

// C-interface entry point for snax_is_dm_core: returns the result of
// snrt_is_dm_core() for the calling core.
int _mlir_ciface_snax_is_dm_core() {
  int on_dm_core = snrt_is_dm_core();
  return on_dm_core;
}

// C-interface entry point for snax_is_compute_core: returns the result of
// snrt_is_compute_core() for the calling core.
int _mlir_ciface_snax_is_compute_core() {
  int on_compute_core = snrt_is_compute_core();
  return on_compute_core;
}
3 changes: 2 additions & 1 deletion runtime/tomlir16.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,8 +22,9 @@
lambda x: f"{x.group(2)} {x.group(1)}",
ir,
)

ir = re.sub(r"<({.*})>", lambda x: f"{x.group(1)}", ir)
# join 2 separate attribute regions generated by the previous regexes
ir = re.sub(r"}[ \t]*{", ", ", ir)
# remove quotes
ir = re.sub('"indexing_maps"', "indexing_maps", ir)
# remove quotes
Expand Down
16 changes: 8 additions & 8 deletions tests/filecheck/transforms/dispatch_regions.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@
//CHECK: "builtin.module"() ({
//CHECK-NEXT: "func.func"() <{"sym_name" = "simple_mult", "function_type" = (memref<64xi32>, memref<64xi32>, memref<64xi32>) -> (), "sym_visibility" = "public"}> ({
//CHECK-NEXT: ^0(%0 : memref<64xi32>, %1 : memref<64xi32>, %2 : memref<64xi32>):
//CHECK-NEXT: %3 = "func.call"() <{"callee" = @snrt_is_compute_core}> : () -> i1
//CHECK-NEXT: %3 = "func.call"() <{"callee" = @snax_is_compute_core}> : () -> i1
//CHECK-NEXT: "scf.if"(%3) ({
//CHECK-NEXT: "linalg.generic"(%0, %1, %2) <{"indexing_maps" = [affine_map<(d0) -> (d0)>, affine_map<(d0) -> (d0)>, affine_map<(d0) -> (d0)>], "iterator_types" = [#linalg.iterator_type<parallel>], "operandSegmentSizes" = array<i32: 2, 1>}> ({
//CHECK-NEXT: ^1(%arg0 : i32, %arg1 : i32, %arg2 : i32):
Expand All @@ -43,7 +43,7 @@
//CHECK-NEXT: }) : (i1) -> ()
//CHECK-NEXT: "func.return"() : () -> ()
//CHECK-NEXT: }) : () -> ()
//CHECK-NEXT: "func.func"() <{"sym_name" = "snrt_is_compute_core", "function_type" = () -> i1, "sym_visibility" = "private"}> ({
//CHECK-NEXT: "func.func"() <{"sym_name" = "snax_is_compute_core", "function_type" = () -> i1, "sym_visibility" = "private"}> ({
//CHECK-NEXT: }) : () -> ()
//CHECK-NEXT: }) : () -> ()
// -----
Expand All @@ -58,15 +58,15 @@
//CHECK: "builtin.module"() ({
//CHECK-NEXT: "func.func"() <{"sym_name" = "simple_mult", "function_type" = (memref<64xi32>, memref<64xi32>) -> (), "sym_visibility" = "public"}> ({
//CHECK-NEXT: ^0(%0 : memref<64xi32>, %1 : memref<64xi32>):
//CHECK-NEXT: %2 = "func.call"() <{"callee" = @snrt_is_dm_core}> : () -> i1
//CHECK-NEXT: %2 = "func.call"() <{"callee" = @snax_is_dm_core}> : () -> i1
//CHECK-NEXT: "scf.if"(%2) ({
//CHECK-NEXT: "memref.copy"(%0, %1) : (memref<64xi32>, memref<64xi32>) -> ()
//CHECK-NEXT: "scf.yield"() : () -> ()
//CHECK-NEXT: }, {
//CHECK-NEXT: }) : (i1) -> ()
//CHECK-NEXT: "func.return"() : () -> ()
//CHECK-NEXT: }) : () -> ()
//CHECK-NEXT: "func.func"() <{"sym_name" = "snrt_is_dm_core", "function_type" = () -> i1, "sym_visibility" = "private"}> ({
//CHECK-NEXT: "func.func"() <{"sym_name" = "snax_is_dm_core", "function_type" = () -> i1, "sym_visibility" = "private"}> ({
//CHECK-NEXT: }) : () -> ()
//CHECK-NEXT: }) : () -> ()
// -----
Expand All @@ -88,8 +88,8 @@
//CHECK: "builtin.module"() ({
//CHECK-NEXT: "func.func"() <{"sym_name" = "simple_mult", "function_type" = (memref<64xi32>, memref<64xi32>, memref<64xi32>) -> (), "sym_visibility" = "public"}> ({
//CHECK-NEXT: ^0(%0 : memref<64xi32>, %1 : memref<64xi32>, %2 : memref<64xi32>):
//CHECK-NEXT: %3 = "func.call"() <{"callee" = @snrt_is_compute_core}> : () -> i1
//CHECK-NEXT: %4 = "func.call"() <{"callee" = @snrt_is_dm_core}> : () -> i1
//CHECK-NEXT: %3 = "func.call"() <{"callee" = @snax_is_compute_core}> : () -> i1
//CHECK-NEXT: %4 = "func.call"() <{"callee" = @snax_is_dm_core}> : () -> i1
//CHECK-NEXT: %alloc = "memref.alloc"() <{"operandSegmentSizes" = array<i32: 0, 0>}> {"alignment" = 64 : i64} : () -> memref<64xi32>
//CHECK-NEXT: "scf.if"(%4) ({
//CHECK-NEXT: "memref.copy"(%0, %1) : (memref<64xi32>, memref<64xi32>) -> ()
Expand All @@ -108,9 +108,9 @@
//CHECK-NEXT: }) : (i1) -> ()
//CHECK-NEXT: "func.return"() : () -> ()
//CHECK-NEXT: }) : () -> ()
//CHECK-NEXT: "func.func"() <{"sym_name" = "snrt_is_compute_core", "function_type" = () -> i1, "sym_visibility" = "private"}> ({
//CHECK-NEXT: "func.func"() <{"sym_name" = "snax_is_compute_core", "function_type" = () -> i1, "sym_visibility" = "private"}> ({
//CHECK-NEXT: }) : () -> ()
//CHECK-NEXT: "func.func"() <{"sym_name" = "snrt_is_dm_core", "function_type" = () -> i1, "sym_visibility" = "private"}> ({
//CHECK-NEXT: "func.func"() <{"sym_name" = "snax_is_dm_core", "function_type" = () -> i1, "sym_visibility" = "private"}> ({
//CHECK-NEXT: }) : () -> ()
//CHECK-NEXT: }) : () -> ()

4 changes: 2 additions & 2 deletions tests/filecheck/transforms/snax_to_func.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
}) : () -> ()

//CHECK: "builtin.module"() ({
//CHECK-NEXT: "func.call"() <{"callee" = @snrt_cluster_hw_barrier}> : () -> ()
//CHECK-NEXT: "func.func"() <{"sym_name" = "snrt_cluster_hw_barrier", "function_type" = () -> (), "sym_visibility" = "private"}> ({
//CHECK-NEXT: "func.call"() <{"callee" = @snax_cluster_hw_barrier}> : () -> ()
//CHECK-NEXT: "func.func"() <{"sym_name" = "snax_cluster_hw_barrier", "function_type" = () -> (), "sym_visibility" = "private"}> ({
//CHECK-NEXT: }) : () -> ()
//CHECK-NEXT: }) : () -> ()

0 comments on commit 82a4394

Please sign in to comment.