Skip to content

Commit

Permalink
Add alloc snakemake kernel (#310)
Browse files Browse the repository at this point in the history
* Add snakefmt to pre-commit

* Add default snakemake rules

* Remove makefile, clean up snakefile

* Add separate snakemake and makefile CI flows

* Actually add snakemake utils

* Adapt linker flag to point to conda environment

* Adapt other linker flag
  • Loading branch information
JosseVanDelm authored Dec 19, 2024
1 parent 4740a09 commit db2ea35
Show file tree
Hide file tree
Showing 10 changed files with 351 additions and 31 deletions.
23 changes: 23 additions & 0 deletions .github/workflows/build-run-kernel-snake.yml
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@
name: Build and run kernels

on:
push:
branches:
- main
pull_request:

jobs:
build-and-run-kernels:
runs-on: ubuntu-24.04
steps:
- uses: actions/checkout@v3
- uses: prefix-dev/[email protected]
with:
cache: true
cache-write: ${{ github.event_name == 'push' && github.ref_name == 'main' }}
- name: Build and run kernels
run: pixi run snakemake -k -j `nproc`
working-directory: kernels/${{ matrix.kernel }}
strategy:
matrix:
kernel: [alloc]
2 changes: 1 addition & 1 deletion .github/workflows/build-run-kernel.yml
Original file line number Diff line number Diff line change
Expand Up @@ -20,4 +20,4 @@ jobs:
working-directory: kernels/${{ matrix.kernel }}
strategy:
matrix:
kernel: [alloc, simple_copy, transform_copy, streamer_alu, tiled_add, streamer_matmul, gemmini, rescale, gemm]
kernel: [simple_copy, transform_copy, streamer_alu, tiled_add, streamer_matmul, gemmini, rescale, gemm]
4 changes: 4 additions & 0 deletions .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -20,3 +20,7 @@ repos:
rev: 23.3.0
hooks:
- id: black
- repo: https://github.com/snakemake/snakefmt
rev: v0.10.2
hooks:
- id: snakefmt
30 changes: 0 additions & 30 deletions kernels/alloc/Makefile

This file was deleted.

47 changes: 47 additions & 0 deletions kernels/alloc/Snakefile
Original file line number Diff line number Diff line change
@@ -0,0 +1,47 @@
from util.snake.configs import get_snax_mac_config

config = get_snax_mac_config()

config["snaxoptflags"] = ",".join(
[
"dispatch-kernels",
"set-memory-space",
"set-memory-layout",
"realize-memref-casts",
"reuse-memref-allocs",
"insert-sync-barrier",
"dispatch-regions",
"linalg-to-library-call",
"snax-copy-to-dma",
"memref-to-snax",
"snax-to-func",
"clear-memory-space",
]
)


module default_rules:
snakefile:
"../../util/snake/default_rules.smk"
config:
config


rule all:
input:
"func.x",
shell:
"{config[vltsim]} {input[0]}"


use rule * from default_rules exclude compile_simple_main as default_*


rule compile_snax_binary:
input:
"func.o",
"main.o",
output:
"func.x",
shell:
"{config[ld]} {config[ldflags]} {input} -o {output}"
Empty file added util/snake/__init__.py
Empty file.
21 changes: 21 additions & 0 deletions util/snake/configs.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
import os

from util.snake.flags import get_default_flags
from util.snake.paths import get_default_paths


def get_snax_mac_config():
# use CONDA_PREFIX to access pixi env
snax_utils_path = os.environ["CONDA_PREFIX"] + "/snax-utils"
snitch_sw_path = snax_utils_path + "/snax-mac"
config = {}
config.update(get_default_paths())
config.update(get_default_flags(snitch_sw_path))
config["vltsim"] = f"{snax_utils_path}/snax-mac-rtl/bin/snitch_cluster.vlt"
config["cflags"].append(
f"-I{snitch_sw_path}/target/snitch_cluster/sw/snax/mac/include"
)
config["ldflags"].append(
f"-I{snitch_sw_path}/target/snitch_cluster/sw/snax/mac/build/mac.o"
)
return config
105 changes: 105 additions & 0 deletions util/snake/default_rules.smk
Original file line number Diff line number Diff line change
@@ -0,0 +1,105 @@
rule preprocess_mlir:
"""
Apply various preprocessing transformations to mlir files with upstream mlir.
Options controlled by `mlirpreprocflags` defined in config.
"""
input:
"{file}.mlir",
output:
temp("{file}.preproc1.mlir"),
temp("{file}.preproc2.mlir"),
temp("{file}.preprocfinal.mlir"),
run:
shell(
"{config[mlir-opt]} {config[mlirpreprocflags][0]} -o {wildcards.file}.preproc1.mlir {input}"
)
shell(
"{config[mlir-opt]} {config[mlirpreprocflags][1]} -o {wildcards.file}.preproc2.mlir {wildcards.file}.preproc1.mlir"
)
shell(
"{config[mlir-opt]} {config[mlirpreprocflags][2]} -o {output[2]} {wildcards.file}.preproc2.mlir"
)


rule snax_opt_mlir:
"""
Apply various transformations snax-opt on mlir files.
Options controlled with `snaxoptflags` defined in config.
"""
input:
"{file}.preprocfinal.mlir",
output:
temp("{file}.snax-opt.mlir"),
shell:
"{config[snax-opt]} -p {config[snaxoptflags]} -o {output} {input}"


rule postprocess_mlir:
"""
Apply various postprocessing transformations to mlir files with upstream mlir.
Goal is to lower everything to LLVM dialect after this step.
Options controlled with `mlirpostprocflags` defined in config.
"""
input:
"{file}.snax-opt.mlir",
output:
temp("{file}.ll.mlir"),
shell:
"{config[mlir-opt]} {config[mlirpostprocflags]} -o {output} {input}"


rule translate_mlir:
"""
Translate MLIR LLVM dialect to actual LLVM.
"""
input:
"{file}.ll.mlir",
output:
temp("{file}.ll"),
shell:
"{config[mlir-translate]} --mlir-to-llvmir -o {output} {input}"


rule compile_c:
"""
Generic rule to compile c files with default compilation options.
"""
input:
"{file}.c",
output:
temp("{file}.o"),
shell:
"{config[cc]} {config[cflags]} -c {input} -o {output}"


rule postprocess_llvm_module:
"""
Add extra metadata to LLVM module required for snitch-based systems
"""
input:
"{file}.ll",
output:
temp("{file}.ll12"),
shell:
"../../runtime/tollvm12.py < {input} > {output} "


rule compile_llvm_module:
"""
Use clang to compile LLVM module to object file.
Uses target-specific options, but not C-specific options.
"""
input:
"{file}.ll12",
output:
temp("{file}.o"),
shell:
"{config[cc]} {config[clangflags]} -x ir -c {input} -o {output}"


rule clean:
"""
Remove generated files.
"""
shell:
"rm -rf *.ll12 *.x *.o *.logs/ logs/ data* *.dasm"
142 changes: 142 additions & 0 deletions util/snake/flags.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,142 @@
import os


def get_mlir_preproc_flags():
return [
[
"--pass-pipeline='builtin.module(func.func("
+ ", ".join(
[
"tosa-to-linalg-named",
"tosa-to-tensor",
"tosa-to-scf",
"tosa-to-linalg",
]
)
+ "))'",
"--mlir-print-op-generic",
"--mlir-print-local-scope",
],
["--tosa-to-arith='include-apply-rescale'", "--empty-tensor-to-alloc-tensor"],
[
"--test-linalg-transform-patterns='test-generalize-pad-tensor'",
"--linalg-generalize-named-ops",
"--empty-tensor-to-alloc-tensor",
"--one-shot-bufferize='"
+ " ".join(
[
"bufferize-function-boundaries",
"allow-return-allocs-from-loops",
"function-boundary-type-conversion=identity-layout-map",
]
)
+ "'",
"--mlir-print-op-generic",
"--mlir-print-local-scope",
],
]


def get_mlir_postproc_flags(index_bitwidth=32):
return [
"--convert-linalg-to-loops",
"--convert-scf-to-cf",
"--lower-affine",
"--canonicalize",
"--cse",
"--convert-math-to-llvm",
"--llvm-request-c-wrappers",
"--expand-strided-metadata",
"--lower-affine",
f"--convert-index-to-llvm=index-bitwidth={index_bitwidth}",
f"--convert-cf-to-llvm=index-bitwidth={index_bitwidth}",
f"--convert-arith-to-llvm=index-bitwidth={index_bitwidth}",
f"--convert-func-to-llvm='index-bitwidth={index_bitwidth}'",
f"--finalize-memref-to-llvm='use-generic-functions index-bitwidth={index_bitwidth}'",
"--canonicalize",
"--reconcile-unrealized-casts",
]


def get_target_flags():
"""
Function that returns llvm target flags, related to RISC-V backend settings
"""
return [
"--target=riscv32-unknown-elf",
"-mcpu=generic-rv32",
"-march=rv32imafdzfh",
"-mabi=ilp32d",
"-mcmodel=medany",
]


def get_clang_flags():
"""
Function that returns clang-specific flags, related to RISC-V backend settings
"""
return [
"-Wno-unused-command-line-argument",
*get_target_flags(),
"-ftls-model=local-exec",
"-ffast-math",
"-fno-builtin-printf",
"-fno-common",
"-O3",
"-std=gnu11",
"-Wall",
"-Wextra",
]


def get_cc_flags(snitch_sw_path):
"""
Function that returns default c-compiler flags
"""
return [
f"-I{snitch_sw_path}/target/snitch_cluster/sw/runtime/rtl-generic/src",
f"-I{snitch_sw_path}/target/snitch_cluster/sw/runtime/common",
f"-I{snitch_sw_path}/sw/snRuntime/api",
f"-I{snitch_sw_path}/sw/snRuntime/src",
f"-I{snitch_sw_path}/sw/snRuntime/src/omp/",
f"-I{snitch_sw_path}/sw/snRuntime/api/omp/",
f"-I{snitch_sw_path}/sw/math/arch/riscv64/bits/",
f"-I{snitch_sw_path}/sw/math/arch/generic",
f"-I{snitch_sw_path}/sw/math/src/include",
f"-I{snitch_sw_path}/sw/math/src/internal",
f"-I{snitch_sw_path}/sw/math/include/bits",
f"-I{snitch_sw_path}/sw/math/include",
"-I../../runtime/include",
"-D__DEFINED_uint64_t",
*get_clang_flags(),
]


def get_ld_flags(snitch_sw_path, snitch_llvm_path=None):
"""
Function that returns default linker flags
"""
# Default path points to conda/pixi environment
if snitch_llvm_path is None:
snitch_llvm_path = os.environ["CONDA_PREFIX"] + "/bin"
return [
f"-fuse-ld={snitch_llvm_path}/ld.lld",
*get_target_flags(),
f"-T{snitch_sw_path}/sw/snRuntime/base.ld",
f"-L{snitch_sw_path}/target/snitch_cluster/sw/runtime/rtl-generic",
f"-L{snitch_sw_path}/target/snitch_cluster/sw/runtime/rtl-generic/build",
"-nostdlib",
"-lsnRuntime",
]


def get_default_flags(snitch_sw_path, snitch_llvm_path=None, index_bitwidth=32):
if snitch_llvm_path is None:
snitch_llvm_path = os.environ["CONDA_PREFIX"] + "/bin"
return {
"cflags": get_cc_flags(snitch_sw_path),
"clangflags": get_clang_flags(),
"ldflags": get_ld_flags(snitch_sw_path, snitch_llvm_path),
"mlirpreprocflags": get_mlir_preproc_flags(),
"mlirpostprocflags": get_mlir_postproc_flags(index_bitwidth),
}
8 changes: 8 additions & 0 deletions util/snake/paths.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
def get_default_paths():
return {
"cc": "clang",
"ld": "clang",
"mlir-opt": "mlir-opt",
"mlir-translate": "mlir-translate",
"snax-opt": "snax-opt",
}

0 comments on commit db2ea35

Please sign in to comment.