-
Notifications
You must be signed in to change notification settings - Fork 4
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
This pull request adds the XDSL project as a submodule to write custom MLIR passes for lowering linalg->snax. It also includes a first simple pass that detects elementwise integer multiplications, and lowers these to an external function call. This external function call can then be an implementation using a snax accelerator.
- Loading branch information
1 parent
5cfd78b
commit 8929b68
Showing
11 changed files
with
313 additions
and
1 deletion.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,26 @@ | ||
name: lit-tests | ||
|
||
on: | ||
push: | ||
branches: [ "main" ] | ||
pull_request: | ||
branches: [ "main" ] | ||
|
||
permissions: | ||
contents: read | ||
|
||
jobs: | ||
python-tests: | ||
|
||
runs-on: ubuntu-latest | ||
container: | ||
image: ghcr.io/kuleuven-micas/snax-mlir:pr-9@sha256:f1f626db8f7acf021272fb441165fd8a657622881650202e76dc464012d16491 | ||
|
||
steps: | ||
- uses: actions/checkout@v3 | ||
- name: Test with lit | ||
shell: bash | ||
run: | | ||
export PATH=/opt/python3.11/bin:$PATH | ||
/opt/python3.11/bin/lit tests/filecheck -v | ||
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -12,3 +12,4 @@ __pycache__ | |
launch.json | ||
settings.json | ||
.pyc | ||
.lit* |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,72 @@ | ||
from __future__ import annotations | ||
|
||
from collections.abc import Sequence | ||
|
||
|
||
from xdsl.dialects.builtin import ( | ||
AnyShapedType, | ||
AnyTensorType, | ||
ShapedType, | ||
StringAttr, | ||
) | ||
from xdsl.ir import Dialect, Region, SSAValue | ||
from xdsl.irdl import ( | ||
AttrSizedOperandSegments, | ||
IRDLOperation, | ||
VarOperand, | ||
VarOpResult, | ||
irdl_op_definition, | ||
opt_attr_def, | ||
region_def, | ||
var_operand_def, | ||
var_result_def, | ||
) | ||
|
||
|
||
@irdl_op_definition | ||
class Mul(IRDLOperation): | ||
name = "linalg.mul" | ||
|
||
inputs: VarOperand = var_operand_def() | ||
outputs: VarOperand = var_operand_def(AnyShapedType()) | ||
|
||
res: VarOpResult = var_result_def(AnyTensorType) | ||
|
||
body: Region = region_def("single_block") | ||
|
||
# Trait attributes | ||
library_call: StringAttr | None = opt_attr_def(StringAttr) | ||
|
||
irdl_options = [AttrSizedOperandSegments(as_property=True)] | ||
|
||
def __init__( | ||
self, | ||
inputs: Sequence[SSAValue], | ||
outputs: Sequence[SSAValue], | ||
body: Region, | ||
library_call: StringAttr | None = None, | ||
) -> None: | ||
super().__init__( | ||
operands=[inputs, outputs], | ||
result_types=[[]], | ||
attributes={ | ||
"library_call": library_call, | ||
}, | ||
regions=[body], | ||
) | ||
|
||
def get_static_shapes(self) -> list[int]: | ||
sizes: list[int] = [] | ||
for input in self.inputs: | ||
if isinstance(input.type, ShapedType): | ||
for dim in input.type.get_shape(): | ||
sizes.append(dim) | ||
for output in self.outputs: | ||
if isinstance(output.type, ShapedType): | ||
for dim in output.type.get_shape(): | ||
sizes.append(dim) | ||
return sizes | ||
|
||
|
||
# Extended Linalg Dialect | ||
LinalgExtension = Dialect("linalg_extension", [Mul]) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,7 @@ | ||
#!/usr/bin/env python3 | ||
|
||
import sys | ||
from tools.snax_opt_main import main | ||
|
||
if __name__ == "__main__": | ||
sys.exit(main()) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,46 @@ | ||
import argparse | ||
from xdsl.xdsl_opt_main import xDSLOptMain | ||
from xdsl.ir import MLContext | ||
from dialects.linalg_extension import LinalgExtension | ||
from transforms.dispatch_elementwise_mult import DispatchElementWiseMult | ||
from collections.abc import Sequence | ||
|
||
|
||
class SNAXOptMain(xDSLOptMain): | ||
def __init__( | ||
self, | ||
description: str = "SNAX modular optimizer driver", | ||
args: Sequence[str] | None = None, | ||
): | ||
self.available_frontends = {} | ||
self.available_passes = {} | ||
self.available_targets = {} | ||
|
||
self.ctx = MLContext() | ||
super().register_all_dialects() | ||
super().register_all_frontends() | ||
super().register_all_passes() | ||
super().register_all_targets() | ||
|
||
## Add custom dialects & passes | ||
self.ctx.load_dialect(LinalgExtension) | ||
super().register_pass(DispatchElementWiseMult) | ||
|
||
# arg handling | ||
arg_parser = argparse.ArgumentParser(description=description) | ||
super().register_all_arguments(arg_parser) | ||
self.args = arg_parser.parse_args(args=args) | ||
|
||
self.ctx.allow_unregistered = self.args.allow_unregistered_dialect | ||
|
||
super().setup_pipeline() | ||
|
||
pass | ||
|
||
|
||
def main(): | ||
SNAXOptMain().run() | ||
|
||
|
||
if "__main__" == __name__: | ||
main() |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,101 @@ | ||
from xdsl.dialects import builtin, func | ||
from dialects import linalg_extension | ||
from xdsl.ir import MLContext | ||
from xdsl.passes import ModulePass | ||
from xdsl.dialects.memref import MemRefType | ||
from xdsl.pattern_rewriter import ( | ||
GreedyRewritePatternApplier, | ||
PatternRewriteWalker, | ||
PatternRewriter, | ||
RewritePattern, | ||
op_type_rewrite_pattern, | ||
) | ||
from xdsl.traits import SymbolTable | ||
|
||
|
||
class AddLibraryCall(RewritePattern): | ||
@op_type_rewrite_pattern | ||
def match_and_rewrite(self, op: linalg_extension.Mul, rewriter: PatternRewriter): | ||
## conditions for library call: | ||
# (1) 2 operands | ||
# (2) both operands of type memref | ||
# (3) 1D-shape | ||
# (4) type integer | ||
|
||
if len(op.inputs) != 2: | ||
return | ||
|
||
for inp in op.inputs: | ||
if not isinstance(inp.type, MemRefType): | ||
return | ||
|
||
if len(inp.type.get_shape()) > 1: | ||
return | ||
|
||
if not isinstance(inp.type.get_element_type(), builtin.IntegerType): | ||
return | ||
|
||
op.library_call = builtin.StringAttr("snax_hwpe_mult") | ||
|
||
return | ||
|
||
|
||
class LowerLinalgToFunc(RewritePattern): | ||
""" | ||
Lowers linalg.mul functions marked with a library call to function calls | ||
""" | ||
|
||
@op_type_rewrite_pattern | ||
def match_and_rewrite(self, op: linalg_extension.Mul, rewriter: PatternRewriter): | ||
if op.library_call is None: | ||
return | ||
|
||
rewriter.replace_matched_op(func.Call(op.library_call.data, op.operands, [])) | ||
|
||
return | ||
|
||
|
||
class AddExternalFunc(RewritePattern): | ||
""" | ||
Looks for hwpe function calls and adds an external | ||
func call to it for LLVM to link in | ||
""" | ||
|
||
@op_type_rewrite_pattern | ||
def match_and_rewrite(self, module: builtin.ModuleOp, rewriter: PatternRewriter): | ||
for op in module.walk(): | ||
if not isinstance(op, func.Call): | ||
continue | ||
if "snax_hwpe" not in op.callee.string_value(): | ||
continue | ||
|
||
func_op = func.FuncOp.external( | ||
op.callee.string_value(), | ||
[arg.type for arg in op.arguments], | ||
[res.type for res in op.results], | ||
) | ||
|
||
SymbolTable.insert_or_update(module, func_op) | ||
|
||
|
||
class DispatchElementWiseMult(ModulePass): | ||
""" | ||
This pass detects integer elementwise multiplications, and replaces them with | ||
an external function call hwpe_mult. | ||
""" | ||
|
||
name = "dispatch-elementwise-mult" | ||
|
||
def apply(self, ctx: MLContext, op: builtin.ModuleOp) -> None: | ||
PatternRewriteWalker( | ||
GreedyRewritePatternApplier( | ||
[ | ||
AddLibraryCall(), | ||
LowerLinalgToFunc(), | ||
] | ||
), | ||
apply_recursively=False, | ||
).rewrite_module(op) | ||
PatternRewriteWalker(AddExternalFunc(), apply_recursively=False).rewrite_module( | ||
op | ||
) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,3 +1,4 @@ | ||
pre-commit | ||
|
||
filecheck | ||
lit | ||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,18 @@ | ||
// RUN: XDSL_ROUNDTRIP | ||
|
||
%0, %1, %2 = "test.op"() : () -> (memref<64xi32>, memref<64xi32>, memref<64xi32>) | ||
|
||
"linalg.mul"(%0, %1, %2) <{operandSegmentSizes = array<i32: 2, 1>}> ({ | ||
^bb0(%arg2: i32, %arg3: i32, %arg4: i32): | ||
%s1 = "arith.muli"(%arg2, %arg3) : (i32, i32) -> i32 | ||
"linalg.yield"(%s1) : (i32) -> () | ||
}) : (memref<64xi32>, memref<64xi32>, memref<64xi32>) -> () | ||
|
||
//CHECK: builtin.module { | ||
//CHECK-NEXT: %0, %1, %2 = "test.op"() : () -> (memref<64xi32>, memref<64xi32>, memref<64xi32>) | ||
//CHECK-NEXT: "linalg.mul"(%0, %1, %2) <{"operandSegmentSizes" = array<i32: 2, 1>}> ({ | ||
//CHECK-NEXT: ^0(%arg2 : i32, %arg3 : i32, %arg4 : i32): | ||
//CHECK-NEXT: %s1 = arith.muli %arg2, %arg3 : i32 | ||
//CHECK-NEXT: linalg.yield %s1 : i32 | ||
//CHECK-NEXT: }) : (memref<64xi32>, memref<64xi32>, memref<64xi32>) -> () | ||
//CHECK-NEXT: } |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,12 @@ | ||
import lit.formats | ||
import os | ||
|
||
config.test_source_root = os.path.dirname(__file__) | ||
snax_src = os.path.dirname(os.path.dirname(config.test_source_root)) | ||
|
||
config.name = "SNAX" | ||
config.test_format = lit.formats.ShTest(preamble_commands=[f"cd {snax_src}"]) | ||
config.suffixes = ['.test', '.mlir', '.py'] | ||
|
||
config.substitutions.append(('XDSL_ROUNDTRIP', "./compiler/snax-opt %s --print-op-generic --split-input-file | ./compiler/snax-opt --split-input-file | filecheck %s")) | ||
config.substitutions.append(("XDSL_GENERIC_ROUNDTRIP", "./compiler/snax-opt %s --print-op-generic --split-input-file | filecheck %s --check-prefix=CHECK-GENERIC")) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,23 @@ | ||
// RUN: ./compiler/snax-opt %s -p dispatch-elementwise-mult --allow-unregistered-dialect --print-op-generic | filecheck %s | ||
|
||
"builtin.module"() ({ | ||
"func.func"() <{"sym_name" = "simple_mult", "function_type" = (memref<64xi32>, memref<64xi32>, memref<64xi32>) -> (), "sym_visibility" = "public"}> ({ | ||
^0(%A : memref<64xi32>, %B : memref<64xi32>, %D : memref<64xi32>): | ||
"linalg.mul"(%A, %B, %D) <{"operandSegmentSizes" = array<i32: 2, 1>}> ({ | ||
^1(%arg2 : i32, %arg3 : i32, %arg4 : i32): | ||
%0 = "arith.muli"(%arg2, %arg3) : (i32, i32) -> i32 | ||
"linalg.yield"(%0) : (i32) -> () | ||
}) : (memref<64xi32>, memref<64xi32>, memref<64xi32>) -> () | ||
"func.return"() : () -> () | ||
}) : () -> () | ||
}) : () -> () | ||
|
||
//CHECK: "builtin.module"() ({ | ||
//CHECK-NEXT: "func.func"() <{"sym_name" = "simple_mult", "function_type" = (memref<64xi32>, memref<64xi32>, memref<64xi32>) -> (), "sym_visibility" = "public"}> ({ | ||
//CHECK-NEXT: ^0(%A : memref<64xi32>, %B : memref<64xi32>, %D : memref<64xi32>): | ||
//CHECK-NEXT: "func.call"(%A, %B, %D) <{"callee" = @snax_hwpe_mult}> : (memref<64xi32>, memref<64xi32>, memref<64xi32>) -> () | ||
//CHECK-NEXT: "func.return"() : () -> () | ||
//CHECK-NEXT: }) : () -> () | ||
//CHECK-NEXT: "func.func"() <{"sym_name" = "snax_hwpe_mult", "function_type" = (memref<64xi32>, memref<64xi32>, memref<64xi32>) -> (), "sym_visibility" = "private"}> ({ | ||
//CHECK-NEXT: }) : () -> () | ||
//CHECK-NEXT: }) : () -> () |