Skip to content

Commit

Permalink
xDSL: Split Lowering Passes (#11)
Browse files Browse the repository at this point in the history
This commit does a couple of things:

    It splits the dispatch-elementwise-mult pass into two passes:
        dispatch-elementwise-mult: this pass detects the elementwise multiplications compatible with the hwpe accelerator, and adds a library call on successful hits
        linalg-to-library-call: this pass simply matches on every linalg which has a library call, and inserts the external function and a function call
    It reverts the linalg extension that was added to be able to use linalg named ops, so from now on we only use linalg.generic.
        If we were to use linalg named ops, we would need to implement the conversion of named ops -> generic ops to be able to retrieve the indexing maps of the operation
        Using linalg.generic, the problem of detecting certain operations becomes somewhat more difficult, but I think this is manageable because:
            ZigZag also does not contain named ops; the constraining mechanism to be used there will be based on the type of indexing maps and operation. The indexing maps are directly available from linalg.generic ops. The operation must be detected from the block of the linalg.generic operation, but for accelerators this is limited to (mult, add, multiply-accumulate).
            This check is currently hard-coded and not very clear, but I plan to make it more structured in the future.
  • Loading branch information
jorendumoulin authored Nov 9, 2023
1 parent 8929b68 commit 93b60cd
Show file tree
Hide file tree
Showing 11 changed files with 146 additions and 161 deletions.
4 changes: 4 additions & 0 deletions .github/workflows/lit-tests.yml
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,10 @@ jobs:

steps:
- uses: actions/checkout@v3
- name: Install modules
shell: bash
run: |
/opt/python3.11/bin/python3 -m pip install -e .
- name: Test with lit
shell: bash
run: |
Expand Down
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -13,3 +13,4 @@ launch.json
settings.json
.pyc
.lit*
*.egg*
Empty file added compiler/__init__.py
Empty file.
72 changes: 0 additions & 72 deletions compiler/dialects/linalg_extension.py

This file was deleted.

6 changes: 3 additions & 3 deletions compiler/tools/snax_opt_main.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
import argparse
from xdsl.xdsl_opt_main import xDSLOptMain
from xdsl.ir import MLContext
from dialects.linalg_extension import LinalgExtension
from transforms.dispatch_elementwise_mult import DispatchElementWiseMult
from compiler.transforms.dispatch_elementwise_mult import DispatchElementWiseMult
from compiler.transforms.linalg_to_library_call import LinalgToLibraryCall
from collections.abc import Sequence


Expand All @@ -23,8 +23,8 @@ def __init__(
super().register_all_targets()

## Add custom dialects & passes
self.ctx.load_dialect(LinalgExtension)
super().register_pass(DispatchElementWiseMult)
super().register_pass(LinalgToLibraryCall)

# arg handling
arg_parser = argparse.ArgumentParser(description=description)
Expand Down
92 changes: 40 additions & 52 deletions compiler/transforms/dispatch_elementwise_mult.py
Original file line number Diff line number Diff line change
@@ -1,26 +1,35 @@
from xdsl.dialects import builtin, func
from dialects import linalg_extension
from xdsl.dialects import builtin, linalg, arith
from xdsl.ir import MLContext
from xdsl.passes import ModulePass
from xdsl.dialects.memref import MemRefType
from xdsl.pattern_rewriter import (
GreedyRewritePatternApplier,
PatternRewriteWalker,
PatternRewriter,
RewritePattern,
op_type_rewrite_pattern,
)
from xdsl.traits import SymbolTable


class AddLibraryCall(RewritePattern):
@op_type_rewrite_pattern
def match_and_rewrite(self, op: linalg_extension.Mul, rewriter: PatternRewriter):
def match_and_rewrite(self, op: linalg.Generic, rewriter: PatternRewriter):
"""Add a library call to a linalg generic that implements an
elementwise multiplication. This is done on linalg.generics as
named linalg ops do not support library calls. This makes the task
of detecting an elementwise multiplication somewhat harder,
but this can be more structured in future work."""

## conditions for library call:
# (0) must not already have library call
# (1) 2 operands
# (2) both operands of type memref
# (3) 1D-shape
# (4) type integer
# (5) iterator type also 1D and parallel
# (6) region must be non reducing mult of both inputs

if op.library_call is not None:
return

if len(op.inputs) != 2:
return
Expand All @@ -35,67 +44,46 @@ def match_and_rewrite(self, op: linalg_extension.Mul, rewriter: PatternRewriter)
if not isinstance(inp.type.get_element_type(), builtin.IntegerType):
return

op.library_call = builtin.StringAttr("snax_hwpe_mult")
if len(op.iterator_types) != 1:
return

return
if op.iterator_types.data[0].data is not linalg.IteratorType.PARALLEL:
return

## Check if operation is muli
## two operations: first operation is arith.muli, last operation is yield

class LowerLinalgToFunc(RewritePattern):
"""
Lowers linalg.mul functions marked with a library call to function calls
"""
mult_op = op.body.block.first_op
yield_op = op.body.block.last_op

@op_type_rewrite_pattern
def match_and_rewrite(self, op: linalg_extension.Mul, rewriter: PatternRewriter):
if op.library_call is None:
# last operation is linalg.yield
if not isinstance(yield_op, linalg.YieldOp):
return
# first operation is arith.muli
if not isinstance(mult_op, arith.Muli):
return
# yield is result of muli
if mult_op.result is not yield_op.arguments[0]:
return
# muli is based on first two args
if not (
op.body.block.args[0] in mult_op.operands
and op.body.block.args[1] in mult_op.operands
):
return

rewriter.replace_matched_op(func.Call(op.library_call.data, op.operands, []))
op.library_call = builtin.StringAttr("snax_hwpe_mult")

return


class AddExternalFunc(RewritePattern):
"""
Looks for hwpe function calls and adds an external
func call to it for LLVM to link in
"""

@op_type_rewrite_pattern
def match_and_rewrite(self, module: builtin.ModuleOp, rewriter: PatternRewriter):
for op in module.walk():
if not isinstance(op, func.Call):
continue
if "snax_hwpe" not in op.callee.string_value():
continue

func_op = func.FuncOp.external(
op.callee.string_value(),
[arg.type for arg in op.arguments],
[res.type for res in op.results],
)

SymbolTable.insert_or_update(module, func_op)


class DispatchElementWiseMult(ModulePass):
"""
This pass detects integer elementwise multiplications, and replaces them with
an external function call hwpe_mult.
This pass detects integer elementwise multiplications (linalg.mul),
and inserts a library call to snax-hwpe.
"""

name = "dispatch-elementwise-mult"

def apply(self, ctx: MLContext, op: builtin.ModuleOp) -> None:
PatternRewriteWalker(
GreedyRewritePatternApplier(
[
AddLibraryCall(),
LowerLinalgToFunc(),
]
),
apply_recursively=False,
).rewrite_module(op)
PatternRewriteWalker(AddExternalFunc(), apply_recursively=False).rewrite_module(
op
)
PatternRewriteWalker(AddLibraryCall()).rewrite_module(op)
55 changes: 55 additions & 0 deletions compiler/transforms/linalg_to_library_call.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,55 @@
from xdsl.dialects import builtin, func, linalg
from xdsl.ir import MLContext
from xdsl.passes import ModulePass
from xdsl.pattern_rewriter import (
PatternRewriteWalker,
PatternRewriter,
RewritePattern,
op_type_rewrite_pattern,
)
from xdsl.traits import SymbolTable


class AddExternalFunc(RewritePattern):
    """
    Replace every linalg.generic that carries a library call with a
    func.call to that library function, and add an external function
    declaration for the callee to the module so it can be linked in later.
    """

    @op_type_rewrite_pattern
    def match_and_rewrite(self, module: builtin.ModuleOp, rewriter: PatternRewriter):
        # Collect the operations to lower *before* touching the IR. The loop
        # below replaces ops and appends symbols to the module; doing that
        # while the walk generator is still running would make the traversal
        # depend on how the op list happens to be iterated internally.
        targets = [
            op
            for op in module.walk()
            if isinstance(op, linalg.Generic) and op.library_call is not None
        ]

        for op in targets:
            # The call forwards all operands of the generic and is built
            # with an empty list of result types.
            func_call = func.Call(op.library_call.data, op.operands, [])

            # Replace op with function call
            rewriter.replace_op(op, func_call)

            # Insert external function definition; its signature is derived
            # from the call that was just created.
            func_op = func.FuncOp.external(
                func_call.callee.string_value(),
                [arg.type for arg in func_call.arguments],
                [res.type for res in func_call.results],
            )

            SymbolTable.insert_or_update(module, func_op)


class LinalgToLibraryCall(ModulePass):
    """
    Module pass that lowers linalg operations annotated with an external
    library call into a function call plus the matching function definition.
    """

    name = "linalg-to-library-call"

    def apply(self, ctx: MLContext, op: builtin.ModuleOp) -> None:
        # One non-recursive sweep with the AddExternalFunc pattern.
        walker = PatternRewriteWalker(AddExternalFunc(), apply_recursively=False)
        walker.rewrite_module(op)
7 changes: 7 additions & 0 deletions setup.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
from setuptools import setup, find_packages

# Package definition: registers the "compiler" distribution together with
# every package that find_packages() discovers.
setup(packages=find_packages(), name="compiler", version="0.1.0")
18 changes: 0 additions & 18 deletions tests/filecheck/dialects/linalg_extension.mlir

This file was deleted.

28 changes: 12 additions & 16 deletions tests/filecheck/transforms/dispatch_elementwise_mult.mlir
Original file line number Diff line number Diff line change
@@ -1,23 +1,19 @@
// RUN: ./compiler/snax-opt %s -p dispatch-elementwise-mult --allow-unregistered-dialect --print-op-generic | filecheck %s

"builtin.module"() ({
"func.func"() <{"sym_name" = "simple_mult", "function_type" = (memref<64xi32>, memref<64xi32>, memref<64xi32>) -> (), "sym_visibility" = "public"}> ({
^0(%A : memref<64xi32>, %B : memref<64xi32>, %D : memref<64xi32>):
"linalg.mul"(%A, %B, %D) <{"operandSegmentSizes" = array<i32: 2, 1>}> ({
^1(%arg2 : i32, %arg3 : i32, %arg4 : i32):
%0 = "arith.muli"(%arg2, %arg3) : (i32, i32) -> i32
"linalg.yield"(%0) : (i32) -> ()
}) : (memref<64xi32>, memref<64xi32>, memref<64xi32>) -> ()
"func.return"() : () -> ()
}) : () -> ()
%0, %1, %2 = "test.op"() : () -> (memref<64xi32>, memref<64xi32>, memref<64xi32>)
"linalg.generic"(%0, %1, %2) <{"indexing_maps" = [affine_map<(d0) -> (d0)>, affine_map<(d0) -> (d0)>, affine_map<(d0) -> (d0)>], "iterator_types" = [#linalg.iterator_type<parallel>], "operandSegmentSizes" = array<i32: 2, 1>}> ({
^0(%arg0 : i32, %arg1 : i32, %arg2 : i32):
%3 = "arith.muli"(%arg0, %arg1) : (i32, i32) -> i32
"linalg.yield"(%3) : (i32) -> ()
}) : (memref<64xi32>, memref<64xi32>, memref<64xi32>) -> ()
}) : () -> ()

//CHECK: "builtin.module"() ({
//CHECK-NEXT: "func.func"() <{"sym_name" = "simple_mult", "function_type" = (memref<64xi32>, memref<64xi32>, memref<64xi32>) -> (), "sym_visibility" = "public"}> ({
//CHECK-NEXT: ^0(%A : memref<64xi32>, %B : memref<64xi32>, %D : memref<64xi32>):
//CHECK-NEXT: "func.call"(%A, %B, %D) <{"callee" = @snax_hwpe_mult}> : (memref<64xi32>, memref<64xi32>, memref<64xi32>) -> ()
//CHECK-NEXT: "func.return"() : () -> ()
//CHECK-NEXT: }) : () -> ()
//CHECK-NEXT: "func.func"() <{"sym_name" = "snax_hwpe_mult", "function_type" = (memref<64xi32>, memref<64xi32>, memref<64xi32>) -> (), "sym_visibility" = "private"}> ({
//CHECK-NEXT: }) : () -> ()
//CHECK-NEXT: %0, %1, %2 = "test.op"() : () -> (memref<64xi32>, memref<64xi32>, memref<64xi32>)
//CHECK-NEXT: "linalg.generic"(%0, %1, %2) <{"indexing_maps" = [affine_map<(d0) -> (d0)>, affine_map<(d0) -> (d0)>, affine_map<(d0) -> (d0)>], "iterator_types" = [#linalg.iterator_type<parallel>], "operandSegmentSizes" = array<i32: 2, 1>}> ({
//CHECK-NEXT: ^0(%arg0 : i32, %arg1 : i32, %arg2 : i32):
//CHECK-NEXT: %3 = "arith.muli"(%arg0, %arg1) : (i32, i32) -> i32
//CHECK-NEXT: "linalg.yield"(%3) : (i32) -> ()
//CHECK-NEXT: }) {"library_call" = "snax_hwpe_mult"} : (memref<64xi32>, memref<64xi32>, memref<64xi32>) -> ()
//CHECK-NEXT: }) : () -> ()
24 changes: 24 additions & 0 deletions tests/filecheck/transforms/linalg-to-library-call.mlir
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
// RUN: ./compiler/snax-opt %s -p linalg-to-library-call --allow-unregistered-dialect --print-op-generic | filecheck %s

"builtin.module"() ({
%0, %1, %2, %3 = "test.op"() : () -> (memref<64xi32>, memref<64xi32>, memref<64xi32>, memref<64xi32>)
"linalg.generic"(%0, %1, %2) <{"indexing_maps" = [affine_map<(d0) -> (d0)>, affine_map<(d0) -> (d0)>, affine_map<(d0) -> (d0)>], "iterator_types" = [#linalg.iterator_type<parallel>], "operandSegmentSizes" = array<i32: 2, 1>}> ({
^0(%arg0 : i32, %arg1 : i32, %arg2 : i32):
%4 = "arith.muli"(%arg0, %arg1) : (i32, i32) -> i32
"linalg.yield"(%4) : (i32) -> ()
}) {"library_call" = "snax_hwpe_mult"} : (memref<64xi32>, memref<64xi32>, memref<64xi32>) -> ()
"linalg.generic"(%1, %2, %3) <{"indexing_maps" = [affine_map<(d0) -> (d0)>, affine_map<(d0) -> (d0)>, affine_map<(d0) -> (d0)>], "iterator_types" = [#linalg.iterator_type<parallel>], "operandSegmentSizes" = array<i32: 2, 1>}> ({
^0(%arg0 : i32, %arg1 : i32, %arg2 : i32):
%5 = "arith.muli"(%arg0, %arg1) : (i32, i32) -> i32
"linalg.yield"(%5) : (i32) -> ()
}) {"library_call" = "snax_hwpe_mult"} : (memref<64xi32>, memref<64xi32>, memref<64xi32>) -> ()
}) : () -> ()


//CHECK: "builtin.module"() ({
//CHECK-NEXT: %0, %1, %2, %3 = "test.op"() : () -> (memref<64xi32>, memref<64xi32>, memref<64xi32>, memref<64xi32>)
//CHECK-NEXT: "func.call"(%0, %1, %2) <{"callee" = @snax_hwpe_mult}> : (memref<64xi32>, memref<64xi32>, memref<64xi32>) -> ()
//CHECK-NEXT: "func.call"(%1, %2, %3) <{"callee" = @snax_hwpe_mult}> : (memref<64xi32>, memref<64xi32>, memref<64xi32>) -> ()
//CHECK-NEXT: "func.func"() <{"sym_name" = "snax_hwpe_mult", "function_type" = (memref<64xi32>, memref<64xi32>, memref<64xi32>) -> (), "sym_visibility" = "private"}> ({
//CHECK-NEXT: }) : () -> ()
//CHECK-NEXT: }) : () -> ()

0 comments on commit 93b60cd

Please sign in to comment.