diff --git a/.github/workflows/build-run-kernel-snake.yml b/.github/workflows/build-run-kernel-snake.yml index ca40c5e2..6f89304f 100644 --- a/.github/workflows/build-run-kernel-snake.yml +++ b/.github/workflows/build-run-kernel-snake.yml @@ -20,4 +20,4 @@ jobs: working-directory: kernels/${{ matrix.kernel }} strategy: matrix: - kernel: [alloc, simple_copy, transform_copy, gemm, rescale] + kernel: [alloc, simple_copy, transform_copy, gemm, rescale, gemmini] diff --git a/.github/workflows/build-run-kernel.yml b/.github/workflows/build-run-kernel.yml index f12a7bde..491f8643 100644 --- a/.github/workflows/build-run-kernel.yml +++ b/.github/workflows/build-run-kernel.yml @@ -20,4 +20,4 @@ jobs: working-directory: kernels/${{ matrix.kernel }} strategy: matrix: - kernel: [streamer_alu, tiled_add, streamer_matmul, gemmini] + kernel: [streamer_alu, tiled_add, streamer_matmul] diff --git a/kernels/gemmini/Makefile b/kernels/gemmini/Makefile deleted file mode 100644 index a8e0f602..00000000 --- a/kernels/gemmini/Makefile +++ /dev/null @@ -1,42 +0,0 @@ -.DEFAULT_GOAL := all - -include ../../runtime/Makefile.rules - -TESTS = -TESTS += tiled_matmul.acc-dialect.x - -CFLAGS += -std=gnu11 -CFLAGS += -Wall -Wextra - - -# Override snax-opt rules to avoid linalg-to-library-call pass -SNAXOPTACCFLAGS = -p insert-accfg-op{accelerator=gemmini},convert-linalg-to-accfg,convert-accfg-to-csr - -tiled_matmul.mlir: tiled_matmul.transform.mlir - ${MLIROPT} --pass-pipeline='builtin.module(test-transform-dialect-interpreter{bind-first-extra-to-ops=linalg.quantized_matmul}, test-transform-dialect-erase-schedule, linalg-generalize-named-ops)' $^ | sed -E 's/iterator_types =/library_call="gemmini", iterator_types =/gm;t' > $@ - -%.acc-dialect.snax-opt.mlir: %.preprocfinal.mlir - $(SNAXOPT) $(SNAXOPTACCFLAGS) -o $@ $< - -%.postproc.mlir: %.snax-opt.mlir - cat $< | sed 's/arith.maximumf/arith.maxf/g' | sed 's/arith.minimumf/arith.minf/g' > $@ - -%.ll.mlir: %.postproc.mlir - ${MLIROPT} --test-lower-to-llvm -o $@ $< - -%.ll: %.postproc.mlir - ${MLIRTRANSLATE} --mlir-to-llvmir -o $@ $< - -%.x: %.ll - ${CC} -x ir $< -c -O3 -o $@ -Wno-override-module --target=riscv64-unknown-elf -mcpu=generic-rv64 -march=rv64gc - -allrun: all - -all: $(TESTS) - -dump: $(TESTS) - ${OBJDUMP} -d $< - - -clean: - rm -fr *.ll12 *.x *.o *.logs/ logs/ data.h data.c diff --git a/kernels/gemmini/Snakefile b/kernels/gemmini/Snakefile new file mode 100644 index 00000000..3757689f --- /dev/null +++ b/kernels/gemmini/Snakefile @@ -0,0 +1,86 @@ +from util.snake.paths import get_default_paths +from util.snake.flags import get_mlir_preproc_flags + + +config.update(get_default_paths()) +config["snaxoptflags"] = ",".join( + [ + "insert-accfg-op{accelerator=gemmini}", + "convert-linalg-to-accfg", + "convert-accfg-to-csr", + ] +) + +# Rocket-specific RV flags +config["cflags"] = [ + "-c", + "-O3", + "-Wno-override-module", + "--target=riscv64-unknown-elf", + "-mcpu=generic-rv64", + "-march=rv64gc", +] + +config["mliroptflags"] = [ + "--pass-pipeline='builtin.module(transform-interpreter{debug-bind-trailing-args=linalg.quantized_matmul}, test-transform-dialect-erase-schedule, linalg-generalize-named-ops)'" +] +config["mlirpreprocflags"] = get_mlir_preproc_flags() + + +module default_rules: + snakefile: + "../../util/snake/default_rules.smk" + config: + config + + +use rule preprocess_mlir from default_rules as default_preprocess_mlir + + +use rule translate_mlir from default_rules as default_translate_mlir + + +rule all: + input: + "tiled_matmul.o", + + +sed_str = "'s/iterator_types =/library_call=\"gemmini\", iterator_types =/gm;t'" + + +rule prepare_mlir: + input: + "tiled_matmul.transform.mlir", + output: + "tiled_matmul.mlir", + shell: + "{config[mlir-opt]} {config[mliroptflags]} {input} | " + "sed -E {sed_str} > {output}" + + +rule snax_compile_mlir: + input: + "{file}.preprocfinal.mlir", + output: + "{file}.snax-opt.mlir", + shell: + "{config[snax-opt]} -p {config[snaxoptflags]} -o {output} {input}" + + +# Use default pass to lower to Rocket (64-bit index type) +rule postprocess_mlir: + input: + "{file}.snax-opt.mlir", + output: + "{file}.ll.mlir", + shell: + "{config[mlir-opt]} --test-lower-to-llvm -o {output} {input}" + + +rule compile_llvm_module: + input: + "{file}.ll", + output: + "{file}.o", + shell: + "{config[cc]} -x ir {input} {config[cflags]} -o {output}" diff --git a/kernels/gemmini/tiled_matmul.transform.mlir b/kernels/gemmini/tiled_matmul.transform.mlir index 660ae02b..1bb48a40 100644 --- a/kernels/gemmini/tiled_matmul.transform.mlir +++ b/kernels/gemmini/tiled_matmul.transform.mlir @@ -1,15 +1,16 @@ -func.func @tiled_matmul(%arg0: memref<128x128xi8>, %arg1: memref<128x128xi8>, %arg2: memref<128x128xi32>) { - %c0_i32 = arith.constant 0 : i32 - linalg.quantized_matmul ins(%arg0, %arg1, %c0_i32, %c0_i32 : memref<128x128xi8>, memref<128x128xi8>, i32, i32) outs(%arg2 : memref<128x128xi32>) - return -} - - -transform.sequence failures(propagate) { -^bb0(%arg0: !transform.any_op, - %arg1: !transform.op<"linalg.quantized_matmul">): - // The actual tiling transformation takes tile sizes as attributes. - %loop1, %loop2, %loop3, %tiled = transform.structured.tile %arg1 [80, 96, 128] - : (!transform.op<"linalg.quantized_matmul">) -> (!transform.any_op, !transform.any_op, !transform.any_op, !transform.any_op) - transform.yield +module attributes{transform.with_named_sequence}{ + func.func @tiled_matmul(%arg0: memref<128x128xi8>, %arg1: memref<128x128xi8>, %arg2: memref<128x128xi32>) { + %c0_i32 = arith.constant 0 : i32 + linalg.quantized_matmul ins(%arg0, %arg1, %c0_i32, %c0_i32 : memref<128x128xi8>, memref<128x128xi8>, i32, i32) outs(%arg2 : memref<128x128xi32>) + return + } + + + transform.named_sequence @__transform_main (%arg0: !transform.any_op, + %arg1: !transform.op<"linalg.quantized_matmul">){ + // The actual tiling transformation takes tile sizes as attributes. + %loop1, %loop2, %loop3, %tiled = transform.structured.tile_using_for %arg1 tile_sizes [80, 96, 128] + : (!transform.op<"linalg.quantized_matmul">) -> (!transform.any_op, !transform.any_op, !transform.any_op, !transform.any_op) + transform.yield + } }