Skip to content

Commit

Permalink
Add XDSL Compiler Flow (#9)
Browse files Browse the repository at this point in the history
This pull request adds the XDSL project as a submodule to write custom MLIR passes for lowering linalg->snax.
It also includes a first simple pass that detects elementwise integer multiplications, and lowers these to an external function call. This external function call can then be an implementation using a snax accelerator.
  • Loading branch information
jorendumoulin authored Nov 7, 2023
1 parent 5cfd78b commit 8929b68
Show file tree
Hide file tree
Showing 11 changed files with 313 additions and 1 deletion.
26 changes: 26 additions & 0 deletions .github/workflows/lit-tests.yml
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
name: lit-tests

on:
push:
branches: [ "main" ]
pull_request:
branches: [ "main" ]

permissions:
contents: read

jobs:
python-tests:

runs-on: ubuntu-latest
container:
image: ghcr.io/kuleuven-micas/snax-mlir:pr-9@sha256:f1f626db8f7acf021272fb441165fd8a657622881650202e76dc464012d16491

steps:
- uses: actions/checkout@v3
- name: Test with lit
shell: bash
run: |
export PATH=/opt/python3.11/bin:$PATH
/opt/python3.11/bin/lit tests/filecheck -v
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -12,3 +12,4 @@ __pycache__
launch.json
settings.json
.pyc
.lit*
72 changes: 72 additions & 0 deletions compiler/dialects/linalg_extension.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,72 @@
from __future__ import annotations

from collections.abc import Sequence


from xdsl.dialects.builtin import (
AnyShapedType,
AnyTensorType,
ShapedType,
StringAttr,
)
from xdsl.ir import Dialect, Region, SSAValue
from xdsl.irdl import (
AttrSizedOperandSegments,
IRDLOperation,
VarOperand,
VarOpResult,
irdl_op_definition,
opt_attr_def,
region_def,
var_operand_def,
var_result_def,
)


@irdl_op_definition
class Mul(IRDLOperation):
name = "linalg.mul"

inputs: VarOperand = var_operand_def()
outputs: VarOperand = var_operand_def(AnyShapedType())

res: VarOpResult = var_result_def(AnyTensorType)

body: Region = region_def("single_block")

# Trait attributes
library_call: StringAttr | None = opt_attr_def(StringAttr)

irdl_options = [AttrSizedOperandSegments(as_property=True)]

def __init__(
self,
inputs: Sequence[SSAValue],
outputs: Sequence[SSAValue],
body: Region,
library_call: StringAttr | None = None,
) -> None:
super().__init__(
operands=[inputs, outputs],
result_types=[[]],
attributes={
"library_call": library_call,
},
regions=[body],
)

def get_static_shapes(self) -> list[int]:
sizes: list[int] = []
for input in self.inputs:
if isinstance(input.type, ShapedType):
for dim in input.type.get_shape():
sizes.append(dim)
for output in self.outputs:
if isinstance(output.type, ShapedType):
for dim in output.type.get_shape():
sizes.append(dim)
return sizes


# Extended Linalg Dialect
LinalgExtension = Dialect("linalg_extension", [Mul])
7 changes: 7 additions & 0 deletions compiler/snax-opt
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
#!/usr/bin/env python3

import sys
from tools.snax_opt_main import main

if __name__ == "__main__":
sys.exit(main())
46 changes: 46 additions & 0 deletions compiler/tools/snax_opt_main.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,46 @@
import argparse
from xdsl.xdsl_opt_main import xDSLOptMain
from xdsl.ir import MLContext
from dialects.linalg_extension import LinalgExtension
from transforms.dispatch_elementwise_mult import DispatchElementWiseMult
from collections.abc import Sequence


class SNAXOptMain(xDSLOptMain):
def __init__(
self,
description: str = "SNAX modular optimizer driver",
args: Sequence[str] | None = None,
):
self.available_frontends = {}
self.available_passes = {}
self.available_targets = {}

self.ctx = MLContext()
super().register_all_dialects()
super().register_all_frontends()
super().register_all_passes()
super().register_all_targets()

## Add custom dialects & passes
self.ctx.load_dialect(LinalgExtension)
super().register_pass(DispatchElementWiseMult)

# arg handling
arg_parser = argparse.ArgumentParser(description=description)
super().register_all_arguments(arg_parser)
self.args = arg_parser.parse_args(args=args)

self.ctx.allow_unregistered = self.args.allow_unregistered_dialect

super().setup_pipeline()

pass


def main():
SNAXOptMain().run()


if "__main__" == __name__:
main()
101 changes: 101 additions & 0 deletions compiler/transforms/dispatch_elementwise_mult.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,101 @@
from xdsl.dialects import builtin, func
from dialects import linalg_extension
from xdsl.ir import MLContext
from xdsl.passes import ModulePass
from xdsl.dialects.memref import MemRefType
from xdsl.pattern_rewriter import (
GreedyRewritePatternApplier,
PatternRewriteWalker,
PatternRewriter,
RewritePattern,
op_type_rewrite_pattern,
)
from xdsl.traits import SymbolTable


class AddLibraryCall(RewritePattern):
@op_type_rewrite_pattern
def match_and_rewrite(self, op: linalg_extension.Mul, rewriter: PatternRewriter):
## conditions for library call:
# (1) 2 operands
# (2) both operands of type memref
# (3) 1D-shape
# (4) type integer

if len(op.inputs) != 2:
return

for inp in op.inputs:
if not isinstance(inp.type, MemRefType):
return

if len(inp.type.get_shape()) > 1:
return

if not isinstance(inp.type.get_element_type(), builtin.IntegerType):
return

op.library_call = builtin.StringAttr("snax_hwpe_mult")

return


class LowerLinalgToFunc(RewritePattern):
"""
Lowers linalg.mul functions marked with a library call to function calls
"""

@op_type_rewrite_pattern
def match_and_rewrite(self, op: linalg_extension.Mul, rewriter: PatternRewriter):
if op.library_call is None:
return

rewriter.replace_matched_op(func.Call(op.library_call.data, op.operands, []))

return


class AddExternalFunc(RewritePattern):
"""
Looks for hwpe function calls and adds an external
func call to it for LLVM to link in
"""

@op_type_rewrite_pattern
def match_and_rewrite(self, module: builtin.ModuleOp, rewriter: PatternRewriter):
for op in module.walk():
if not isinstance(op, func.Call):
continue
if "snax_hwpe" not in op.callee.string_value():
continue

func_op = func.FuncOp.external(
op.callee.string_value(),
[arg.type for arg in op.arguments],
[res.type for res in op.results],
)

SymbolTable.insert_or_update(module, func_op)


class DispatchElementWiseMult(ModulePass):
"""
This pass detects integer elementwise multiplications, and replaces them with
an external function call hwpe_mult.
"""

name = "dispatch-elementwise-mult"

def apply(self, ctx: MLContext, op: builtin.ModuleOp) -> None:
PatternRewriteWalker(
GreedyRewritePatternApplier(
[
AddLibraryCall(),
LowerLinalgToFunc(),
]
),
apply_recursively=False,
).rewrite_module(op)
PatternRewriteWalker(AddExternalFunc(), apply_recursively=False).rewrite_module(
op
)
5 changes: 5 additions & 0 deletions container/Dockerfile
Original file line number Diff line number Diff line change
Expand Up @@ -109,3 +109,8 @@ RUN export PATH=/opt/python3.11/bin:$PATH \
# add python3.11 to path in bashrc
RUN echo "export PATH=/opt/python3.11/bin:$PATH" >> ~/.bashrc

# install latest xdsl and test dependencies
RUN export PATH=/opt/python3.11/bin:$PATH \
&& pip3 install git+https://github.com/xdslproject/xdsl.git@182514a06f74f191623199e050cb0bb4d48dac79\
&& pip3 install filecheck \
&& pip3 install lit
3 changes: 2 additions & 1 deletion requirements.txt
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
pre-commit

filecheck
lit

18 changes: 18 additions & 0 deletions tests/filecheck/dialects/linalg_extension.mlir
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
// RUN: XDSL_ROUNDTRIP

%0, %1, %2 = "test.op"() : () -> (memref<64xi32>, memref<64xi32>, memref<64xi32>)

"linalg.mul"(%0, %1, %2) <{operandSegmentSizes = array<i32: 2, 1>}> ({
^bb0(%arg2: i32, %arg3: i32, %arg4: i32):
%s1 = "arith.muli"(%arg2, %arg3) : (i32, i32) -> i32
"linalg.yield"(%s1) : (i32) -> ()
}) : (memref<64xi32>, memref<64xi32>, memref<64xi32>) -> ()

//CHECK: builtin.module {
//CHECK-NEXT: %0, %1, %2 = "test.op"() : () -> (memref<64xi32>, memref<64xi32>, memref<64xi32>)
//CHECK-NEXT: "linalg.mul"(%0, %1, %2) <{"operandSegmentSizes" = array<i32: 2, 1>}> ({
//CHECK-NEXT: ^0(%arg2 : i32, %arg3 : i32, %arg4 : i32):
//CHECK-NEXT: %s1 = arith.muli %arg2, %arg3 : i32
//CHECK-NEXT: linalg.yield %s1 : i32
//CHECK-NEXT: }) : (memref<64xi32>, memref<64xi32>, memref<64xi32>) -> ()
//CHECK-NEXT: }
12 changes: 12 additions & 0 deletions tests/filecheck/lit.cfg
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
import lit.formats
import os

config.test_source_root = os.path.dirname(__file__)
snax_src = os.path.dirname(os.path.dirname(config.test_source_root))

config.name = "SNAX"
config.test_format = lit.formats.ShTest(preamble_commands=[f"cd {snax_src}"])
config.suffixes = ['.test', '.mlir', '.py']

config.substitutions.append(('XDSL_ROUNDTRIP', "./compiler/snax-opt %s --print-op-generic --split-input-file | ./compiler/snax-opt --split-input-file | filecheck %s"))
config.substitutions.append(("XDSL_GENERIC_ROUNDTRIP", "./compiler/snax-opt %s --print-op-generic --split-input-file | filecheck %s --check-prefix=CHECK-GENERIC"))
23 changes: 23 additions & 0 deletions tests/filecheck/transforms/dispatch_elementwise_mult.mlir
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@
// RUN: ./compiler/snax-opt %s -p dispatch-elementwise-mult --allow-unregistered-dialect --print-op-generic | filecheck %s

"builtin.module"() ({
"func.func"() <{"sym_name" = "simple_mult", "function_type" = (memref<64xi32>, memref<64xi32>, memref<64xi32>) -> (), "sym_visibility" = "public"}> ({
^0(%A : memref<64xi32>, %B : memref<64xi32>, %D : memref<64xi32>):
"linalg.mul"(%A, %B, %D) <{"operandSegmentSizes" = array<i32: 2, 1>}> ({
^1(%arg2 : i32, %arg3 : i32, %arg4 : i32):
%0 = "arith.muli"(%arg2, %arg3) : (i32, i32) -> i32
"linalg.yield"(%0) : (i32) -> ()
}) : (memref<64xi32>, memref<64xi32>, memref<64xi32>) -> ()
"func.return"() : () -> ()
}) : () -> ()
}) : () -> ()

//CHECK: "builtin.module"() ({
//CHECK-NEXT: "func.func"() <{"sym_name" = "simple_mult", "function_type" = (memref<64xi32>, memref<64xi32>, memref<64xi32>) -> (), "sym_visibility" = "public"}> ({
//CHECK-NEXT: ^0(%A : memref<64xi32>, %B : memref<64xi32>, %D : memref<64xi32>):
//CHECK-NEXT: "func.call"(%A, %B, %D) <{"callee" = @snax_hwpe_mult}> : (memref<64xi32>, memref<64xi32>, memref<64xi32>) -> ()
//CHECK-NEXT: "func.return"() : () -> ()
//CHECK-NEXT: }) : () -> ()
//CHECK-NEXT: "func.func"() <{"sym_name" = "snax_hwpe_mult", "function_type" = (memref<64xi32>, memref<64xi32>, memref<64xi32>) -> (), "sym_visibility" = "private"}> ({
//CHECK-NEXT: }) : () -> ()
//CHECK-NEXT: }) : () -> ()

0 comments on commit 8929b68

Please sign in to comment.