-
Notifications
You must be signed in to change notification settings - Fork 4
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
This commitdoes a couple of things: It splits the dispatch-elementwise-mult pass into two passes: dispatch-elementwise-mult: this pass detects the elementwise multiplications compatible with the hwpe accelerator, and adds a library call on succesfull hits linalg-to-library-call: this pass simply matches on every linalg which has a library call, and inserts the external function and a function call It reverts the linalg extension to be able to use linalg named ops, so we only use linalg generic from now. If we were to use linalg named ops, we would need to implement the conversion of named ops -> generic ops to be able to retrieve the indexing maps of the operation Using linalg.generic, the problem of detecting certain operations becomes somewhat more difficult, but I think this is manegable because: ZigZag does also not contain named ops the constraining mechanism to be used there will be based on the type of indexing maps and operation. The indexing maps are directly available from linalg generic ops. The operation must be detected from the block of the linalg.generic operation, but for accelerators this is limited to (mult, add, multiply-accumulate) This check is now not very clear and hard-coded, but I plan to make these more structured in the future
- Loading branch information
1 parent
8929b68
commit 93b60cd
Showing
11 changed files
with
146 additions
and
161 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -13,3 +13,4 @@ launch.json | |
settings.json | ||
.pyc | ||
.lit* | ||
*.egg* |
Empty file.
This file was deleted.
Oops, something went wrong.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,55 @@ | ||
from xdsl.dialects import builtin, func, linalg | ||
from xdsl.ir import MLContext | ||
from xdsl.passes import ModulePass | ||
from xdsl.pattern_rewriter import ( | ||
PatternRewriteWalker, | ||
PatternRewriter, | ||
RewritePattern, | ||
op_type_rewrite_pattern, | ||
) | ||
from xdsl.traits import SymbolTable | ||
|
||
|
||
class AddExternalFunc(RewritePattern): | ||
""" | ||
Looks for hwpe function calls and adds an external | ||
func call to it for LLVM to link in | ||
""" | ||
|
||
@op_type_rewrite_pattern | ||
def match_and_rewrite(self, module: builtin.ModuleOp, rewriter: PatternRewriter): | ||
for op in module.walk(): | ||
# Op must be linalg generic | ||
if not isinstance(op, linalg.Generic): | ||
continue | ||
|
||
if op.library_call is None: | ||
continue | ||
|
||
func_call = func.Call(op.library_call.data, op.operands, []) | ||
|
||
# Replace op with function call | ||
rewriter.replace_op(op, func_call) | ||
|
||
# Insert external function definition | ||
func_op = func.FuncOp.external( | ||
func_call.callee.string_value(), | ||
[arg.type for arg in func_call.arguments], | ||
[res.type for res in func_call.results], | ||
) | ||
|
||
SymbolTable.insert_or_update(module, func_op) | ||
|
||
|
||
class LinalgToLibraryCall(ModulePass): | ||
""" | ||
This pass detects linalg operations with an external library call, and | ||
replaces them with a function call and definition. | ||
""" | ||
|
||
name = "linalg-to-library-call" | ||
|
||
def apply(self, ctx: MLContext, op: builtin.ModuleOp) -> None: | ||
PatternRewriteWalker(AddExternalFunc(), apply_recursively=False).rewrite_module( | ||
op | ||
) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,7 @@ | ||
from setuptools import setup, find_packages | ||
|
||
setup( | ||
name="compiler", | ||
version="0.1.0", | ||
packages=find_packages(), | ||
) |
This file was deleted.
Oops, something went wrong.
28 changes: 12 additions & 16 deletions
28
tests/filecheck/transforms/dispatch_elementwise_mult.mlir
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,23 +1,19 @@ | ||
// RUN: ./compiler/snax-opt %s -p dispatch-elementwise-mult --allow-unregistered-dialect --print-op-generic | filecheck %s | ||
|
||
"builtin.module"() ({ | ||
"func.func"() <{"sym_name" = "simple_mult", "function_type" = (memref<64xi32>, memref<64xi32>, memref<64xi32>) -> (), "sym_visibility" = "public"}> ({ | ||
^0(%A : memref<64xi32>, %B : memref<64xi32>, %D : memref<64xi32>): | ||
"linalg.mul"(%A, %B, %D) <{"operandSegmentSizes" = array<i32: 2, 1>}> ({ | ||
^1(%arg2 : i32, %arg3 : i32, %arg4 : i32): | ||
%0 = "arith.muli"(%arg2, %arg3) : (i32, i32) -> i32 | ||
"linalg.yield"(%0) : (i32) -> () | ||
}) : (memref<64xi32>, memref<64xi32>, memref<64xi32>) -> () | ||
"func.return"() : () -> () | ||
}) : () -> () | ||
%0, %1, %2 = "test.op"() : () -> (memref<64xi32>, memref<64xi32>, memref<64xi32>) | ||
"linalg.generic"(%0, %1, %2) <{"indexing_maps" = [affine_map<(d0) -> (d0)>, affine_map<(d0) -> (d0)>, affine_map<(d0) -> (d0)>], "iterator_types" = [#linalg.iterator_type<parallel>], "operandSegmentSizes" = array<i32: 2, 1>}> ({ | ||
^0(%arg0 : i32, %arg1 : i32, %arg2 : i32): | ||
%3 = "arith.muli"(%arg0, %arg1) : (i32, i32) -> i32 | ||
"linalg.yield"(%3) : (i32) -> () | ||
}) : (memref<64xi32>, memref<64xi32>, memref<64xi32>) -> () | ||
}) : () -> () | ||
|
||
//CHECK: "builtin.module"() ({ | ||
//CHECK-NEXT: "func.func"() <{"sym_name" = "simple_mult", "function_type" = (memref<64xi32>, memref<64xi32>, memref<64xi32>) -> (), "sym_visibility" = "public"}> ({ | ||
//CHECK-NEXT: ^0(%A : memref<64xi32>, %B : memref<64xi32>, %D : memref<64xi32>): | ||
//CHECK-NEXT: "func.call"(%A, %B, %D) <{"callee" = @snax_hwpe_mult}> : (memref<64xi32>, memref<64xi32>, memref<64xi32>) -> () | ||
//CHECK-NEXT: "func.return"() : () -> () | ||
//CHECK-NEXT: }) : () -> () | ||
//CHECK-NEXT: "func.func"() <{"sym_name" = "snax_hwpe_mult", "function_type" = (memref<64xi32>, memref<64xi32>, memref<64xi32>) -> (), "sym_visibility" = "private"}> ({ | ||
//CHECK-NEXT: }) : () -> () | ||
//CHECK-NEXT: %0, %1, %2 = "test.op"() : () -> (memref<64xi32>, memref<64xi32>, memref<64xi32>) | ||
//CHECK-NEXT: "linalg.generic"(%0, %1, %2) <{"indexing_maps" = [affine_map<(d0) -> (d0)>, affine_map<(d0) -> (d0)>, affine_map<(d0) -> (d0)>], "iterator_types" = [#linalg.iterator_type<parallel>], "operandSegmentSizes" = array<i32: 2, 1>}> ({ | ||
//CHECK-NEXT: ^0(%arg0 : i32, %arg1 : i32, %arg2 : i32): | ||
//CHECK-NEXT: %3 = "arith.muli"(%arg0, %arg1) : (i32, i32) -> i32 | ||
//CHECK-NEXT: "linalg.yield"(%3) : (i32) -> () | ||
//CHECK-NEXT: }) {"library_call" = "snax_hwpe_mult"} : (memref<64xi32>, memref<64xi32>, memref<64xi32>) -> () | ||
//CHECK-NEXT: }) : () -> () |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,24 @@ | ||
// RUN: ./compiler/snax-opt %s -p linalg-to-library-call --allow-unregistered-dialect --print-op-generic | filecheck %s | ||
|
||
"builtin.module"() ({ | ||
%0, %1, %2, %3 = "test.op"() : () -> (memref<64xi32>, memref<64xi32>, memref<64xi32>, memref<64xi32>) | ||
"linalg.generic"(%0, %1, %2) <{"indexing_maps" = [affine_map<(d0) -> (d0)>, affine_map<(d0) -> (d0)>, affine_map<(d0) -> (d0)>], "iterator_types" = [#linalg.iterator_type<parallel>], "operandSegmentSizes" = array<i32: 2, 1>}> ({ | ||
^0(%arg0 : i32, %arg1 : i32, %arg2 : i32): | ||
%4 = "arith.muli"(%arg0, %arg1) : (i32, i32) -> i32 | ||
"linalg.yield"(%4) : (i32) -> () | ||
}) {"library_call" = "snax_hwpe_mult"} : (memref<64xi32>, memref<64xi32>, memref<64xi32>) -> () | ||
"linalg.generic"(%1, %2, %3) <{"indexing_maps" = [affine_map<(d0) -> (d0)>, affine_map<(d0) -> (d0)>, affine_map<(d0) -> (d0)>], "iterator_types" = [#linalg.iterator_type<parallel>], "operandSegmentSizes" = array<i32: 2, 1>}> ({ | ||
^0(%arg0 : i32, %arg1 : i32, %arg2 : i32): | ||
%5 = "arith.muli"(%arg0, %arg1) : (i32, i32) -> i32 | ||
"linalg.yield"(%5) : (i32) -> () | ||
}) {"library_call" = "snax_hwpe_mult"} : (memref<64xi32>, memref<64xi32>, memref<64xi32>) -> () | ||
}) : () -> () | ||
|
||
|
||
//CHECK: "builtin.module"() ({ | ||
//CHECK-NEXT: %0, %1, %2, %3 = "test.op"() : () -> (memref<64xi32>, memref<64xi32>, memref<64xi32>, memref<64xi32>) | ||
//CHECK-NEXT: "func.call"(%0, %1, %2) <{"callee" = @snax_hwpe_mult}> : (memref<64xi32>, memref<64xi32>, memref<64xi32>) -> () | ||
//CHECK-NEXT: "func.call"(%1, %2, %3) <{"callee" = @snax_hwpe_mult}> : (memref<64xi32>, memref<64xi32>, memref<64xi32>) -> () | ||
//CHECK-NEXT: "func.func"() <{"sym_name" = "snax_hwpe_mult", "function_type" = (memref<64xi32>, memref<64xi32>, memref<64xi32>) -> (), "sym_visibility" = "private"}> ({ | ||
//CHECK-NEXT: }) : () -> () | ||
//CHECK-NEXT: }) : () -> () |