From 82a439440e829507d56a1271581fd4d51efc1b7a Mon Sep 17 00:00:00 2001 From: jorendumoulin <47864363+jorendumoulin@users.noreply.github.com> Date: Tue, 5 Dec 2023 08:38:28 +0100 Subject: [PATCH] simple_mult end to end (#35) * add makefile passes * move memory management to linalg * clear memory space pass * snax renaming * snax renaming * add new passes to makefile * fix regex * add snax runtime library * update renaming in tests * fix simple copy kernel * fix alloc kernel * use regular ints * better comment --- compiler/tools/snax_opt_main.py | 2 ++ compiler/transforms/clear_memory_space.py | 36 +++++++++++++++++++ compiler/transforms/dispatch_regions.py | 16 ++++----- compiler/transforms/set_memory_space.py | 2 +- compiler/transforms/snax_to_func.py | 6 ++-- kernels/alloc/func.mlir | 6 ++-- kernels/alloc/main.c | 16 +-------- kernels/simple_copy/main.c | 7 +--- kernels/simple_mult/main.c | 36 ++++++------------- runtime/Makefile.rules | 2 +- runtime/include/snax_rt.h | 34 ++++++++++++++++++ runtime/tomlir16.py | 3 +- .../transforms/dispatch_regions.mlir | 16 ++++----- tests/filecheck/transforms/snax_to_func.mlir | 4 +-- 14 files changed, 113 insertions(+), 73 deletions(-) create mode 100644 compiler/transforms/clear_memory_space.py create mode 100644 runtime/include/snax_rt.h diff --git a/compiler/tools/snax_opt_main.py b/compiler/tools/snax_opt_main.py index 5a9814f9..6328b479 100644 --- a/compiler/tools/snax_opt_main.py +++ b/compiler/tools/snax_opt_main.py @@ -9,6 +9,7 @@ from compiler.transforms.dispatch_regions import DispatchRegions from compiler.transforms.snax_copy_to_dma import SNAXCopyToDMA from compiler.transforms.snax_to_func import SNAXToFunc +from compiler.transforms.clear_memory_space import ClearMemorySpace from collections.abc import Sequence @@ -37,6 +38,7 @@ def __init__( super().register_pass(DispatchRegions) super().register_pass(SNAXCopyToDMA) super().register_pass(SNAXToFunc) + super().register_pass(ClearMemorySpace) # arg handling arg_parser = argparse.ArgumentParser(description=description) diff --git a/compiler/transforms/clear_memory_space.py b/compiler/transforms/clear_memory_space.py new file mode 100644 index 00000000..4084a48d --- /dev/null +++ b/compiler/transforms/clear_memory_space.py @@ -0,0 +1,36 @@ +from xdsl.dialects import builtin, memref, func +from xdsl.ir import MLContext +from xdsl.passes import ModulePass + + +class ClearMemorySpace(ModulePass): + name = "clear-memory-space" + + def apply(self, ctx: MLContext, module: builtin.ModuleOp) -> None: + # helper function to clear the memory space of a memref + def clear_memory_space(t): + if isinstance(t, memref.MemRefType): + if not isinstance(t.memory_space, builtin.NoneAttr): + return memref.MemRefType.from_element_type_and_shape( + t.element_type, + t.get_shape(), + t.layout, + builtin.NoneAttr(), + ) + return t + + for op_in_module in module.walk(): + for operand in op_in_module.operands: + operand.type = clear_memory_space(operand.type) + + if isinstance(op_in_module, func.FuncOp): + # special case for func ops because func ops do not have + # operands, they have function_types which have ins & outs + # Define new function type with updated inputs and outputs + # mapped to a default memory space + new_function_type = builtin.FunctionType.from_lists( + map(clear_memory_space, op_in_module.function_type.inputs), + map(clear_memory_space, op_in_module.function_type.outputs), + ) + + op_in_module.function_type = new_function_type diff --git a/compiler/transforms/dispatch_regions.py b/compiler/transforms/dispatch_regions.py index 82db3147..35fb3611 100644 --- a/compiler/transforms/dispatch_regions.py +++ b/compiler/transforms/dispatch_regions.py @@ -61,7 +61,7 @@ def dispatcher(block: Block, core_cond: builtin.i1, dispatch_rule: Callable): ## dispatch dm core ops, insert function call # in dominator block if changes made - func_call_dm = func.Call("snrt_is_dm_core", [], [builtin.i1]) + func_call_dm = func.Call("snax_is_dm_core", [], [builtin.i1]) if any( dispatcher(block, func_call_dm.res[0], dispatch_to_dm) for block in func_op.body.blocks @@ -70,7 +70,7 @@ def dispatcher(block: Block, core_cond: builtin.i1, dispatch_rule: Callable): ## dispatch compute core ops, insert function call # in dominator block if changes made - func_call_compute = func.Call("snrt_is_compute_core", [], [builtin.i1]) + func_call_compute = func.Call("snax_is_compute_core", [], [builtin.i1]) if any( dispatcher(block, func_call_compute.res[0], dispatch_to_compute) for block in func_op.body.blocks @@ -80,21 +80,21 @@ def dispatcher(block: Block, core_cond: builtin.i1, dispatch_rule: Callable): class InsertFunctionDeclaration(RewritePattern): - """Insert external function declarations of snrt_is_compute core - and snrt_is_dm_core if they are used in the module""" + """Insert external function declarations of snax_is_compute core + and snax_is_dm_core if they are used in the module""" @op_type_rewrite_pattern def match_and_rewrite(self, module_op: builtin.ModuleOp, rewriter: PatternRewriter): for op in module_op.walk(): if isinstance(op, func.Call): - if op.callee.string_value() == "snrt_is_compute_core": + if op.callee.string_value() == "snax_is_compute_core": func_op_compute = func.FuncOp.external( - "snrt_is_compute_core", [], [builtin.i1] + "snax_is_compute_core", [], [builtin.i1] ) SymbolTable.insert_or_update(module_op, func_op_compute) - if op.callee.string_value() == "snrt_is_dm_core": + if op.callee.string_value() == "snax_is_dm_core": func_op_dm = func.FuncOp.external( - "snrt_is_dm_core", [], [builtin.i1] + "snax_is_dm_core", [], [builtin.i1] ) SymbolTable.insert_or_update(module_op, func_op_dm) diff --git a/compiler/transforms/set_memory_space.py b/compiler/transforms/set_memory_space.py index 15015e0e..3ff4b053 100644 --- a/compiler/transforms/set_memory_space.py +++ b/compiler/transforms/set_memory_space.py @@ -46,7 +46,7 @@ def change_to_memory_space(t): # mapped to a default memory space new_function_type = builtin.FunctionType.from_lists( map(change_to_memory_space, op.function_type.inputs), - map(change_to_memory_space, op.function_type.outputs), + (op.function_type.outputs), ) # Change region of function to use new argument types diff --git a/compiler/transforms/snax_to_func.py b/compiler/transforms/snax_to_func.py index d55fe415..5243897c 100644 --- a/compiler/transforms/snax_to_func.py +++ b/compiler/transforms/snax_to_func.py @@ -15,16 +15,16 @@ class InsertFunctionCall(RewritePattern): @op_type_rewrite_pattern def match_and_rewrite(self, func_op: snax.ClusterSyncOp, rewriter: PatternRewriter): """Swap cluster sync op with function call""" - func_call = func.Call("snrt_cluster_hw_barrier", [], []) + func_call = func.Call("snax_cluster_hw_barrier", [], []) rewriter.replace_matched_op(func_call) class InsertFunctionDeclaration(RewritePattern): - """Insert external function declarations of snrt_cluster_hw_barrier""" + """Insert external function declarations of snax_cluster_hw_barrier""" @op_type_rewrite_pattern def match_and_rewrite(self, module_op: builtin.ModuleOp, rewriter: PatternRewriter): - func_op = func.FuncOp.external("snrt_cluster_hw_barrier", [], []) + func_op = func.FuncOp.external("snax_cluster_hw_barrier", [], []) SymbolTable.insert_or_update(module_op, func_op) diff --git a/kernels/alloc/func.mlir b/kernels/alloc/func.mlir index e5f5bfa9..c5a83b84 100644 --- a/kernels/alloc/func.mlir +++ b/kernels/alloc/func.mlir @@ -1,4 +1,4 @@ -func.func public @simple_alloc() -> (memref<10xi32>) { - %alloc = "memref.alloc"() {"alignment" = 64 : i64, operand_segment_sizes = array} : () -> memref<10xi32> - return %alloc : memref<10xi32> +func.func public @simple_alloc() -> (memref<10xi32, 1 : i32>) { + %alloc = "memref.alloc"() {"alignment" = 64 : i64, operand_segment_sizes = array} : () -> memref<10xi32, 1 : i32> + return %alloc : memref<10xi32, 1 : i32> } diff --git a/kernels/alloc/main.c b/kernels/alloc/main.c index 93a98087..8e631d61 100644 --- a/kernels/alloc/main.c +++ b/kernels/alloc/main.c @@ -1,4 +1,5 @@ #include "memref.h" +#include "snax_rt.h" #include "stdint.h" #include @@ -6,21 +7,6 @@ // The C interface converts the pass by value to a pass by refernece! void _mlir_ciface_simple_alloc(OneDMemrefI32_t *returned_alloc); -int8_t *allocated_pointer; - -int8_t *_mlir_memref_to_llvm_alloc(uint32_t size) { - /* This calls malloc on the DMA core - * --> requires mlir opt to compile with: - * --convert-memref-to-llvm="use-generic-functions index-bitwidth=32" - * To ensure that all cores in the cluster come up with the correct - */ - if (snrt_is_dm_core()) { - allocated_pointer = (int8_t *)snrt_l1alloc(size); - } - snrt_cluster_hw_barrier(); - return allocated_pointer; -}; - int main() { // Allocate memory for the fields // static is required to have a global heap-allocated value for all cores diff --git a/kernels/simple_copy/main.c b/kernels/simple_copy/main.c index 2a8e713a..eeed3fc3 100644 --- a/kernels/simple_copy/main.c +++ b/kernels/simple_copy/main.c @@ -1,15 +1,10 @@ #include "data.h" #include "memref.h" +#include "snax_rt.h" #include "stdint.h" #include #include -void _mlir_ciface_snax_dma_1d_transfer(size_t *source, size_t *destination, - size_t size) { - snrt_dma_start_1d((void *)destination, (void *)source, size * sizeof(size_t)); - return; -} - int main() { // create memref object for A diff --git a/kernels/simple_mult/main.c b/kernels/simple_mult/main.c index ae8ee30f..ffbd75de 100644 --- a/kernels/simple_mult/main.c +++ b/kernels/simple_mult/main.c @@ -1,9 +1,11 @@ #include "data.h" #include "mac.h" #include "memref.h" +#include "snax_rt.h" #include "stdint.h" #include +#include // Kernel provided via external definition void _mlir_ciface_simple_mult(OneDMemrefI32_t *a, OneDMemrefI32_t *b, @@ -18,55 +20,39 @@ void _mlir_ciface_snax_hwpe_mult(OneDMemrefI32_t *a, OneDMemrefI32_t *b, } int main() { - // Allocate shared local memory - // By avoiding allocators and bumping by a known offset a base pointer - // (snrt_l1_next()) that is the same for all the cores in the cluster, we are - // essentially providing the same memory regions to all the cores in this - // cluster. - - // Allocate memory for the fields + // Create memref objects for data stored in L1 OneDMemrefI32_t memrefA; - memrefA.data = (int32_t *)snrt_l1_next(); + memrefA.data = &A; memrefA.aligned_data = memrefA.data; memrefA.offset = 0; memrefA.shape[0] = N; memrefA.stride[0] = sizeof(int32_t); OneDMemrefI32_t memrefB; - memrefB.data = (int32_t *)memrefA.data + N; + memrefB.data = &B; memrefB.aligned_data = memrefB.data; memrefB.offset = 0; memrefB.shape[0] = N; memrefB.stride[0] = sizeof(int32_t); OneDMemrefI32_t memrefD; - memrefD.data = (int32_t *)memrefB.data + N; + memrefD.data = &G; memrefD.aligned_data = memrefD.data; memrefD.offset = 0; memrefD.shape[0] = N; memrefD.stride[0] = sizeof(int32_t); - // Copy data in shared local memory - if (snrt_is_dm_core()) { - snrt_dma_start_1d(memrefA.aligned_data, A, - (memrefA.shape[0]) * sizeof(int32_t)); - snrt_dma_start_1d(memrefB.aligned_data, B, - (memrefB.shape[0]) * sizeof(int32_t)); - } - - snrt_cluster_hw_barrier(); + (void)snrt_mcycle(); + _mlir_ciface_simple_mult(&memrefA, &memrefB, &memrefD); + (void)snrt_mcycle(); - // Launch kernel: from this point on only core 0 is required to be alive. + // Correctness check - + // from this point on only core 0 is required to be alive. int thiscore = snrt_cluster_core_idx(); if (thiscore != 0) return 0; - (void)snrt_mcycle(); - _mlir_ciface_simple_mult(&memrefA, &memrefB, &memrefD); - (void)snrt_mcycle(); - - // Correctness check int nerr = 0; for (int i = 0; i < N; i++) { int32_t error = memrefD.aligned_data[i] - G[i]; diff --git a/runtime/Makefile.rules b/runtime/Makefile.rules index f77964a4..395e5324 100644 --- a/runtime/Makefile.rules +++ b/runtime/Makefile.rules @@ -88,7 +88,7 @@ MLIRPREPROCFLAGS += --mlir-print-local-scope # SNAX opt -SNAXOPTFLAGS = -p dispatch-elementwise-mult,linalg-to-library-call,snax-copy-to-dma +SNAXOPTFLAGS = -p set-memory-space,insert-sync-barrier,dispatch-regions,dispatch-elementwise-mult,linalg-to-library-call,snax-copy-to-dma,snax-to-func,clear-memory-space %.preproc18.mlir: %.preproc.mlir $(MAKEFILE_RULES_DIRNAME)/tomlir18.py < $< > $@ diff --git a/runtime/include/snax_rt.h b/runtime/include/snax_rt.h new file mode 100644 index 00000000..f518bf93 --- /dev/null +++ b/runtime/include/snax_rt.h @@ -0,0 +1,34 @@ +#pragma once + +#include +#include + +int8_t *allocated_pointer; + +int8_t *_mlir_memref_to_llvm_alloc(uint32_t size) { + /* This calls malloc on the DMA core + * --> requires mlir opt to compile with: + * --convert-memref-to-llvm="use-generic-functions index-bitwidth=32" + * To ensure that all cores in the cluster come up with the correct + */ + if (snrt_is_dm_core()) { + allocated_pointer = (int8_t *)snrt_l1alloc(size); + } + snrt_cluster_hw_barrier(); + return allocated_pointer; +}; + +void _mlir_ciface_snax_cluster_hw_barrier() { + snrt_cluster_hw_barrier(); + return; +} + +void _mlir_ciface_snax_dma_1d_transfer(size_t *source, size_t *destination, + size_t size) { + snrt_dma_start_1d((void *)destination, (void *)source, size * sizeof(size_t)); + return; +} + +int _mlir_ciface_snax_is_dm_core() { return snrt_is_dm_core(); } + +int _mlir_ciface_snax_is_compute_core() { return snrt_is_compute_core(); } diff --git a/runtime/tomlir16.py b/runtime/tomlir16.py index a7d6059c..278793e4 100755 --- a/runtime/tomlir16.py +++ b/runtime/tomlir16.py @@ -22,8 +22,9 @@ lambda x: f"{x.group(2)} {x.group(1)}", ir, ) - ir = re.sub(r"<({.*})>", lambda x: f"{x.group(1)}", ir) + # join 2 separate attribute regions generated by the previous regexes + ir = re.sub(r"}[ \t]*{", ", ", ir) # remove quotes ir = re.sub('"indexing_maps"', "indexing_maps", ir) # remove quotes diff --git a/tests/filecheck/transforms/dispatch_regions.mlir b/tests/filecheck/transforms/dispatch_regions.mlir index 387d6400..e47dac0c 100644 --- a/tests/filecheck/transforms/dispatch_regions.mlir +++ b/tests/filecheck/transforms/dispatch_regions.mlir @@ -31,7 +31,7 @@ //CHECK: "builtin.module"() ({ //CHECK-NEXT: "func.func"() <{"sym_name" = "simple_mult", "function_type" = (memref<64xi32>, memref<64xi32>, memref<64xi32>) -> (), "sym_visibility" = "public"}> ({ //CHECK-NEXT: ^0(%0 : memref<64xi32>, %1 : memref<64xi32>, %2 : memref<64xi32>): -//CHECK-NEXT: %3 = "func.call"() <{"callee" = @snrt_is_compute_core}> : () -> i1 +//CHECK-NEXT: %3 = "func.call"() <{"callee" = @snax_is_compute_core}> : () -> i1 //CHECK-NEXT: "scf.if"(%3) ({ //CHECK-NEXT: "linalg.generic"(%0, %1, %2) <{"indexing_maps" = [affine_map<(d0) -> (d0)>, affine_map<(d0) -> (d0)>, affine_map<(d0) -> (d0)>], "iterator_types" = [#linalg.iterator_type], "operandSegmentSizes" = array}> ({ //CHECK-NEXT: ^1(%arg0 : i32, %arg1 : i32, %arg2 : i32): @@ -43,7 +43,7 @@ //CHECK-NEXT: }) : (i1) -> () //CHECK-NEXT: "func.return"() : () -> () //CHECK-NEXT: }) : () -> () -//CHECK-NEXT: "func.func"() <{"sym_name" = "snrt_is_compute_core", "function_type" = () -> i1, "sym_visibility" = "private"}> ({ +//CHECK-NEXT: "func.func"() <{"sym_name" = "snax_is_compute_core", "function_type" = () -> i1, "sym_visibility" = "private"}> ({ //CHECK-NEXT: }) : () -> () //CHECK-NEXT: }) : () -> () // ----- @@ -58,7 +58,7 @@ //CHECK: "builtin.module"() ({ //CHECK-NEXT: "func.func"() <{"sym_name" = "simple_mult", "function_type" = (memref<64xi32>, memref<64xi32>) -> (), "sym_visibility" = "public"}> ({ //CHECK-NEXT: ^0(%0 : memref<64xi32>, %1 : memref<64xi32>): -//CHECK-NEXT: %2 = "func.call"() <{"callee" = @snrt_is_dm_core}> : () -> i1 +//CHECK-NEXT: %2 = "func.call"() <{"callee" = @snax_is_dm_core}> : () -> i1 //CHECK-NEXT: "scf.if"(%2) ({ //CHECK-NEXT: "memref.copy"(%0, %1) : (memref<64xi32>, memref<64xi32>) -> () //CHECK-NEXT: "scf.yield"() : () -> () @@ -66,7 +66,7 @@ //CHECK-NEXT: }) : (i1) -> () //CHECK-NEXT: "func.return"() : () -> () //CHECK-NEXT: }) : () -> () -//CHECK-NEXT: "func.func"() <{"sym_name" = "snrt_is_dm_core", "function_type" = () -> i1, "sym_visibility" = "private"}> ({ +//CHECK-NEXT: "func.func"() <{"sym_name" = "snax_is_dm_core", "function_type" = () -> i1, "sym_visibility" = "private"}> ({ //CHECK-NEXT: }) : () -> () //CHECK-NEXT: }) : () -> () // ----- @@ -88,8 +88,8 @@ //CHECK: "builtin.module"() ({ //CHECK-NEXT: "func.func"() <{"sym_name" = "simple_mult", "function_type" = (memref<64xi32>, memref<64xi32>, memref<64xi32>) -> (), "sym_visibility" = "public"}> ({ //CHECK-NEXT: ^0(%0 : memref<64xi32>, %1 : memref<64xi32>, %2 : memref<64xi32>): -//CHECK-NEXT: %3 = "func.call"() <{"callee" = @snrt_is_compute_core}> : () -> i1 -//CHECK-NEXT: %4 = "func.call"() <{"callee" = @snrt_is_dm_core}> : () -> i1 +//CHECK-NEXT: %3 = "func.call"() <{"callee" = @snax_is_compute_core}> : () -> i1 +//CHECK-NEXT: %4 = "func.call"() <{"callee" = @snax_is_dm_core}> : () -> i1 //CHECK-NEXT: %alloc = "memref.alloc"() <{"operandSegmentSizes" = array}> {"alignment" = 64 : i64} : () -> memref<64xi32> //CHECK-NEXT: "scf.if"(%4) ({ //CHECK-NEXT: "memref.copy"(%0, %1) : (memref<64xi32>, memref<64xi32>) -> () @@ -108,9 +108,9 @@ //CHECK-NEXT: }) : (i1) -> () //CHECK-NEXT: "func.return"() : () -> () //CHECK-NEXT: }) : () -> () -//CHECK-NEXT: "func.func"() <{"sym_name" = "snrt_is_compute_core", "function_type" = () -> i1, "sym_visibility" = "private"}> ({ +//CHECK-NEXT: "func.func"() <{"sym_name" = "snax_is_compute_core", "function_type" = () -> i1, "sym_visibility" = "private"}> ({ //CHECK-NEXT: }) : () -> () -//CHECK-NEXT: "func.func"() <{"sym_name" = "snrt_is_dm_core", "function_type" = () -> i1, "sym_visibility" = "private"}> ({ +//CHECK-NEXT: "func.func"() <{"sym_name" = "snax_is_dm_core", "function_type" = () -> i1, "sym_visibility" = "private"}> ({ //CHECK-NEXT: }) : () -> () //CHECK-NEXT: }) : () -> () diff --git a/tests/filecheck/transforms/snax_to_func.mlir b/tests/filecheck/transforms/snax_to_func.mlir index cbc6e9c1..dcbe4d68 100644 --- a/tests/filecheck/transforms/snax_to_func.mlir +++ b/tests/filecheck/transforms/snax_to_func.mlir @@ -5,7 +5,7 @@ }) : () -> () //CHECK: "builtin.module"() ({ -//CHECK-NEXT: "func.call"() <{"callee" = @snrt_cluster_hw_barrier}> : () -> () -//CHECK-NEXT: "func.func"() <{"sym_name" = "snrt_cluster_hw_barrier", "function_type" = () -> (), "sym_visibility" = "private"}> ({ +//CHECK-NEXT: "func.call"() <{"callee" = @snax_cluster_hw_barrier}> : () -> () +//CHECK-NEXT: "func.func"() <{"sym_name" = "snax_cluster_hw_barrier", "function_type" = () -> (), "sym_visibility" = "private"}> ({ //CHECK-NEXT: }) : () -> () //CHECK-NEXT: }) : () -> ()