Skip to content

Commit

Permalink
Fix gemmini kernel, add snakemake, remove Makefile (#316)
Browse files Browse the repository at this point in the history
* Fix gemmini kernel, add snakemake, remove Makefile

* Move gemmini kernel to snakemake CI
  • Loading branch information
JosseVanDelm authored Dec 20, 2024
1 parent 4a1ba86 commit 9d7aabc
Show file tree
Hide file tree
Showing 5 changed files with 103 additions and 58 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/build-run-kernel-snake.yml
Original file line number Diff line number Diff line change
Expand Up @@ -20,4 +20,4 @@ jobs:
working-directory: kernels/${{ matrix.kernel }}
strategy:
matrix:
kernel: [alloc, simple_copy, transform_copy, gemm, rescale]
kernel: [alloc, simple_copy, transform_copy, gemm, rescale, gemmini]
2 changes: 1 addition & 1 deletion .github/workflows/build-run-kernel.yml
Original file line number Diff line number Diff line change
Expand Up @@ -20,4 +20,4 @@ jobs:
working-directory: kernels/${{ matrix.kernel }}
strategy:
matrix:
kernel: [streamer_alu, tiled_add, streamer_matmul, gemmini]
kernel: [streamer_alu, tiled_add, streamer_matmul]
42 changes: 0 additions & 42 deletions kernels/gemmini/Makefile

This file was deleted.

86 changes: 86 additions & 0 deletions kernels/gemmini/Snakefile
Original file line number Diff line number Diff line change
@@ -0,0 +1,86 @@
from util.snake.paths import get_default_paths
from util.snake.flags import get_mlir_preproc_flags


# Seed the workflow configuration with the values returned by
# util.snake.paths.get_default_paths().
config.update(get_default_paths())

# Layer the kernel-specific settings on top in a single update. Keys are
# inserted in the order written below.
config.update(
    {
        # Pass list handed to snax-opt via `-p`, joined into one
        # comma-separated string.
        "snaxoptflags": ",".join(
            [
                "insert-accfg-op{accelerator=gemmini}",
                "convert-linalg-to-accfg",
                "convert-accfg-to-csr",
            ]
        ),
        # Rocket-specific RV flags
        "cflags": [
            "-c",
            "-O3",
            "-Wno-override-module",
            "--target=riscv64-unknown-elf",
            "-mcpu=generic-rv64",
            "-march=rv64gc",
        ],
        # mlir-opt pipeline used by the prepare_mlir rule: run the transform
        # interpreter, erase the schedule, then generalize named linalg ops.
        "mliroptflags": [
            "--pass-pipeline='builtin.module(transform-interpreter{debug-bind-trailing-args=linalg.quantized_matmul}, test-transform-dialect-erase-schedule, linalg-generalize-named-ops)'"
        ],
        "mlirpreprocflags": get_mlir_preproc_flags(),
    }
)


# Load the shared rule definitions from util/snake/default_rules.smk as a
# Snakemake module, handing it this workflow's `config` dictionary.
module default_rules:
    snakefile:
        "../../util/snake/default_rules.smk"
    config:
        config


# Only two of the shared rules are imported; each is registered under a
# `default_`-prefixed name in this workflow.
use rule preprocess_mlir from default_rules as default_preprocess_mlir


use rule translate_mlir from default_rules as default_translate_mlir


# Target rule: requests the compiled object file for the tiled matmul kernel.
rule all:
    input:
        "tiled_matmul.o",


# sed program (already wrapped in single quotes for the shell) that inserts
# `library_call="gemmini", ` in front of every `iterator_types =` occurrence.
# Only the `g` flag is used: the GNU-only `m` (multi-line) flag has no effect
# on a pattern without `^`/`$`, and a label-less `t` at the end of the script
# is a no-op, so both were dropped to keep the script portable across sed
# implementations.
sed_str = "'s/iterator_types =/library_call=\"gemmini\", iterator_types =/g'"


# Produce tiled_matmul.mlir from the transform-annotated source: run mlir-opt
# with config["mliroptflags"], then pipe the result through the sed program
# above to tag the output with the gemmini library call.
rule prepare_mlir:
    input:
        "tiled_matmul.transform.mlir",
    output:
        "tiled_matmul.mlir",
    shell:
        "{config[mlir-opt]} {config[mliroptflags]} {input} | "
        "sed -E {sed_str} > {output}"


# Run snax-opt with the pass list from config["snaxoptflags"] on
# `{file}.preprocfinal.mlir`, writing `{file}.snax-opt.mlir`.
# NOTE(review): the `.preprocfinal.mlir` input is presumably produced by the
# imported default_preprocess_mlir rule; confirm in default_rules.smk.
rule snax_compile_mlir:
    input:
        "{file}.preprocfinal.mlir",
    output:
        "{file}.snax-opt.mlir",
    shell:
        "{config[snax-opt]} -p {config[snaxoptflags]} -o {output} {input}"


# Use default pass to lower to Rocket (64-bit index type)
# Runs mlir-opt with --test-lower-to-llvm on `{file}.snax-opt.mlir` and writes
# the lowered module to `{file}.ll.mlir`.
rule postprocess_mlir:
    input:
        "{file}.snax-opt.mlir",
    output:
        "{file}.ll.mlir",
    shell:
        "{config[mlir-opt]} --test-lower-to-llvm -o {output} {input}"


# Compile `{file}.ll` into the object file `{file}.o` using the compiler in
# config["cc"]. `-x ir` tells the compiler to treat the input as LLVM IR, and
# config["cflags"] supplies the RISC-V target options.
# NOTE(review): the `.ll` input is presumably produced by the imported
# default_translate_mlir rule; confirm in default_rules.smk.
rule compile_llvm_module:
    input:
        "{file}.ll",
    output:
        "{file}.o",
    shell:
        "{config[cc]} -x ir {input} {config[cflags]} -o {output}"
29 changes: 15 additions & 14 deletions kernels/gemmini/tiled_matmul.transform.mlir
Original file line number Diff line number Diff line change
@@ -1,15 +1,16 @@
// Kernel under test: a single linalg.quantized_matmul over two 128x128 i8
// memrefs (%arg0, %arg1) with a 128x128 i32 memref (%arg2) as the output
// operand. The two extra i32 scalar inputs are both the constant 0.
func.func @tiled_matmul(%arg0: memref<128x128xi8>, %arg1: memref<128x128xi8>, %arg2: memref<128x128xi32>) {
  %c0_i32 = arith.constant 0 : i32
  linalg.quantized_matmul ins(%arg0, %arg1, %c0_i32, %c0_i32 : memref<128x128xi8>, memref<128x128xi8>, i32, i32) outs(%arg2 : memref<128x128xi32>)
  return
}


transform.sequence failures(propagate) {
^bb0(%arg0: !transform.any_op,
%arg1: !transform.op<"linalg.quantized_matmul">):
// The actual tiling transformation takes tile sizes as attributes.
%loop1, %loop2, %loop3, %tiled = transform.structured.tile %arg1 [80, 96, 128]
: (!transform.op<"linalg.quantized_matmul">) -> (!transform.any_op, !transform.any_op, !transform.any_op, !transform.any_op)
transform.yield
// Payload function and its transform schedule live in one module. The
// transform.with_named_sequence attribute marks the module as containing a
// named transform sequence.
module attributes{transform.with_named_sequence}{
  // Kernel under test: a single linalg.quantized_matmul over two 128x128 i8
  // memrefs (%arg0, %arg1) with a 128x128 i32 memref (%arg2) as the output
  // operand. The two extra i32 scalar inputs are both the constant 0.
  func.func @tiled_matmul(%arg0: memref<128x128xi8>, %arg1: memref<128x128xi8>, %arg2: memref<128x128xi32>) {
    %c0_i32 = arith.constant 0 : i32
    linalg.quantized_matmul ins(%arg0, %arg1, %c0_i32, %c0_i32 : memref<128x128xi8>, memref<128x128xi8>, i32, i32) outs(%arg2 : memref<128x128xi32>)
    return
  }


  // Transform schedule. %arg0 is a generic op handle; %arg1 is a handle
  // typed to linalg.quantized_matmul ops.
  // NOTE(review): %arg1 is presumably bound by the
  // `debug-bind-trailing-args=linalg.quantized_matmul` option that the
  // Snakefile passes to mlir-opt; confirm against the transform-interpreter
  // pass documentation.
  transform.named_sequence @__transform_main (%arg0: !transform.any_op,
                                              %arg1: !transform.op<"linalg.quantized_matmul">){
    // The actual tiling transformation takes tile sizes as attributes.
    // Tiles the matched op with sizes [80, 96, 128]; the four results are
    // three loop handles followed by a handle to the tiled op.
    %loop1, %loop2, %loop3, %tiled = transform.structured.tile_using_for %arg1 tile_sizes [80, 96, 128]
      : (!transform.op<"linalg.quantized_matmul">) -> (!transform.any_op, !transform.any_op, !transform.any_op, !transform.any_op)
    transform.yield
  }
}

0 comments on commit 9d7aabc

Please sign in to comment.