Skip to content

Commit

Permalink
pyright: fix memref_to_snax and set_memory_space transforms (#330)
Browse files Browse the repository at this point in the history
* pyright: fix memref_to_snax transform

* fix assertion

* fix set_memory_space

* make use of isa
  • Loading branch information
jorendumoulin authored Jan 6, 2025
1 parent 7ab7674 commit 0b5cd0b
Show file tree
Hide file tree
Showing 6 changed files with 63 additions and 44 deletions.
2 changes: 1 addition & 1 deletion compiler/dialects/snax.py
Original file line number Diff line number Diff line change
Expand Up @@ -130,7 +130,7 @@ def __init__(
self,
rank: int,
size: SSAValue | Operation,
shapes: list[SSAValue | Operation],
shapes: Sequence[SSAValue | Operation],
memory_space: Attribute = NoneAttr(),
alignment: AnyIntegerAttr | None = None,
integer_type: IntegerType = i32,
Expand Down
23 changes: 15 additions & 8 deletions compiler/transforms/memref_to_snax.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
from typing import cast

from xdsl.context import MLContext
from xdsl.dialects import builtin, memref
from xdsl.dialects.arith import AddiOp, ConstantOp, MuliOp, SubiOp
Expand All @@ -8,6 +10,7 @@
NoneAttr,
UnrealizedConversionCastOp,
)
from xdsl.ir import Attribute, Operation, OpResult
from xdsl.passes import ModulePass
from xdsl.pattern_rewriter import (
PatternRewriter,
Expand All @@ -28,7 +31,7 @@ def match_and_rewrite(self, alloc_op: memref.AllocOp, rewriter: PatternRewriter)
NoneType layouts and TSL Layouts, and a memory space of L1"""

# get the memref type
memref_type: MemRefType = alloc_op.memref.type
memref_type = cast(MemRefType[Attribute], alloc_op.memref.type)

# get the element type
element_type = memref_type.get_element_type()
Expand All @@ -49,14 +52,17 @@ def match_and_rewrite(self, alloc_op: memref.AllocOp, rewriter: PatternRewriter)
# create an operation to get the # bytes that needs
# to be allocated
total_size_op = None
ops_to_add = []
ops_to_add: list[Operation] = []

# generate the list of shape ops
# either these are constant and must be created,
# or they are already present in the memref.alloc
# operation arguments
shape_ops = []
alloc_args = [x.op for x in alloc_op.dynamic_sizes]
shape_ops: list[Operation] = []
alloc_args: list[Operation] = []
for size in alloc_op.dynamic_sizes:
assert isinstance(size, OpResult)
alloc_args.append(size.op)

for shape in memref_type.shape.data:
if shape.data == -1:
Expand All @@ -68,7 +74,7 @@ def match_and_rewrite(self, alloc_op: memref.AllocOp, rewriter: PatternRewriter)
ops_to_add.append(shape_op)
shape_ops.append(shape_op)

shape_ops_arg = [x for x in shape_ops]
shape_ops_arg: list[Operation] = [x for x in shape_ops]

if isinstance(layout, NoneAttr):
# get size based on shape
Expand Down Expand Up @@ -130,6 +136,7 @@ def match_and_rewrite(self, alloc_op: memref.AllocOp, rewriter: PatternRewriter)
ops_to_add.append(total_size_op)

# add offset
assert layout.data.offset is not None
offset_op = ConstantOp.from_int_and_width(layout.data.offset, IndexType())
offset_bytes_op = MuliOp(offset_op, element_size_op)
total_size_op = AddiOp(total_size_op, offset_bytes_op)
Expand All @@ -146,7 +153,7 @@ def match_and_rewrite(self, alloc_op: memref.AllocOp, rewriter: PatternRewriter)
memory_space,
alloc_op.alignment,
)
conversion_cast_op = UnrealizedConversionCastOp.get([snax_alloc], memref_type)
conversion_cast_op = UnrealizedConversionCastOp.get([snax_alloc], [memref_type])
rewriter.replace_matched_op(
[*ops_to_add, snax_alloc, conversion_cast_op],
new_results=conversion_cast_op.outputs,
Expand All @@ -156,5 +163,5 @@ def match_and_rewrite(self, alloc_op: memref.AllocOp, rewriter: PatternRewriter)
class MemrefToSNAX(ModulePass):
name = "memref-to-snax"

def apply(self, ctx: MLContext, module: builtin.ModuleOp) -> None:
PatternRewriteWalker(AllocOpRewrite()).rewrite_module(module)
def apply(self, ctx: MLContext, op: builtin.ModuleOp) -> None:
PatternRewriteWalker(AllocOpRewrite()).rewrite_module(op)
17 changes: 10 additions & 7 deletions compiler/transforms/realize_memref_casts.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,16 @@
from xdsl.context import MLContext
from xdsl.dialects import arith, builtin, func, linalg, memref
from xdsl.dialects.memref import MemorySpaceCastOp
from xdsl.ir import Operation, OpResult
from xdsl.ir import Attribute, Operation, OpResult
from xdsl.passes import ModulePass
from xdsl.pattern_rewriter import (
PatternRewriter,
PatternRewriteWalker,
RewritePattern,
op_type_rewrite_pattern,
)
from xdsl.rewriter import InsertPoint
from xdsl.utils.hints import isa

from compiler.dialects import stream
from compiler.dialects.snax import LayoutCast
Expand Down Expand Up @@ -41,7 +43,7 @@ def match_and_rewrite(
# so we can fuse all the casting operations together.

# keep track of ops to add
ops_to_add = []
ops_to_add: list[Operation] = []

# if the source of the memref cast is another layout_cast op,
# combine them all together
Expand All @@ -53,15 +55,15 @@ def match_and_rewrite(

# now perform casting by inserting memref copies and allocs
source_type = source_op.source.type
assert isinstance(source_type, builtin.MemRefType)
assert isa(source_type, builtin.MemRefType[Attribute])
dest_type = op.dest.type
assert isinstance(dest_type, builtin.MemRefType)
assert isa(dest_type, builtin.MemRefType[Attribute])

# create allocation

# create memref.dim operations for dynamic dimensions
shapes = [x.data for x in dest_type.shape.data]
dyn_operands = []
dyn_operands: list[Operation] = []
for i in range(len(shapes)):
# Dynamic shapes are represented as -1
if shapes[i] == -1:
Expand Down Expand Up @@ -90,6 +92,7 @@ def match_and_rewrite(

# insert "copy to" for first use as input
# walk parent op in order to find first use as input
assert op.parent
for use_op in op.parent.walk():
if use_op not in uses:
continue
Expand All @@ -105,7 +108,7 @@ def match_and_rewrite(
if is_input:
# insert copy op
copy_op = memref.CopyOp(source_op.source, op.dest)
rewriter.insert_op_before(copy_op, use_op)
rewriter.insert_op(copy_op, InsertPoint.before(use_op))
break

# insert "copy from" for last use as output
Expand All @@ -127,7 +130,7 @@ def match_and_rewrite(
if is_output:
# insert copy op
copy_op = memref.CopyOp(op.dest, source_op.source)
rewriter.insert_op_after(copy_op, use_op)
rewriter.insert_op(copy_op, InsertPoint.after(use_op))
break

# insert all ops
Expand Down
60 changes: 36 additions & 24 deletions compiler/transforms/set_memory_space.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,16 @@
from typing import cast

from xdsl.context import MLContext
from xdsl.dialects import builtin, func, linalg, memref
from xdsl.ir import SSAValue
from xdsl.ir import Attribute, Operation, SSAValue
from xdsl.passes import ModulePass
from xdsl.pattern_rewriter import (
PatternRewriter,
PatternRewriteWalker,
RewritePattern,
op_type_rewrite_pattern,
)
from xdsl.utils.hints import isa

from compiler.dialects import stream
from compiler.util.snax_memory import L1, L3
Expand All @@ -34,8 +37,8 @@ def match_and_rewrite(self, op: func.FuncOp, rewriter: PatternRewriter):
return

# Mapping function to assign default memory space "L3"
def change_to_memory_space(t):
if isinstance(t, builtin.MemRefType):
def change_to_memory_space(t: Attribute) -> Attribute:
if isa(t, builtin.MemRefType[Attribute]):
if isinstance(t.memory_space, builtin.NoneAttr):
return builtin.MemRefType(
t.element_type,
Expand All @@ -48,8 +51,8 @@ def change_to_memory_space(t):
# Define new function type with updated inputs and outputs
# mapped to a default memory space
new_function_type = builtin.FunctionType.from_lists(
map(change_to_memory_space, op.function_type.inputs),
map(change_to_memory_space, op.function_type.outputs),
list(map(change_to_memory_space, op.function_type.inputs)),
list(map(change_to_memory_space, op.function_type.outputs)),
)

# Change region of function to use new argument types
Expand All @@ -72,17 +75,18 @@ class InitMemRefGlobalMemorySpace(RewritePattern):
@op_type_rewrite_pattern
def match_and_rewrite(self, op: memref.GetGlobalOp, rewriter: PatternRewriter):
# global variables should go in memory space L3
memspace = op.memref.type.memory_space
memref_type = cast(builtin.MemRefType[Attribute], op.memref.type)
memspace = memref_type.memory_space

# If memory space is already L3, don't do anything
if memspace == L3:
return

# otherwise, create new memref type with correct memory space
new_memref_type = builtin.MemRefType(
op.memref.type.element_type,
op.memref.type.get_shape(),
op.memref.type.layout,
memref_type.element_type,
memref_type.get_shape(),
memref_type.layout,
L3,
)

Expand All @@ -97,19 +101,20 @@ class InitMemRefAllocMemorySpace(RewritePattern):
@op_type_rewrite_pattern
def match_and_rewrite(self, op: memref.AllocOp, rewriter: PatternRewriter):
# allocs should go in memory space L1
memspace = op.memref.type.memory_space
memref_type = cast(builtin.MemRefType[Attribute], op.memref.type)
memspace = memref_type.memory_space

if memspace == L1:
# good, nothing left to do
return

# create new alloc op
new_op = memref.AllocOp.get(
op.memref.type.element_type,
memref_type.element_type,
op.alignment,
op.memref.type.get_shape(),
memref_type.get_shape(),
dynamic_sizes=op.dynamic_sizes,
layout=op.memref.type.layout,
layout=memref_type.layout,
memory_space=L1,
)

Expand All @@ -129,27 +134,31 @@ def match_and_rewrite(
operands_to_memory_cast = tuple(
x
for x in op.operands
if isinstance(x.type, builtin.MemRefType) and x.type.memory_space != L1
if isinstance(memref_type := x.type, builtin.MemRefType)
and memref_type.memory_space != L1
)

if not operands_to_memory_cast:
return

def get_cast_op(operand) -> memref.MemorySpaceCastOp:
def get_cast_op(operand: SSAValue) -> memref.MemorySpaceCastOp:
# cast required: find previous cast or create new one
cast_op = None
for use in operand.uses:
if (
isinstance(use.operation, memref.MemorySpaceCastOp)
and isinstance(use.operation.dest.type, builtin.MemRefType)
and use.operation.dest.type.memory_space == L1
and isinstance(
use_type := use.operation.dest.type, builtin.MemRefType
)
and use_type.memory_space == L1
):
cast_op = use.operation
break
# If cast op not found, create and insert new one
assert isa(optype := operand.type, builtin.MemRefType[Attribute])
if cast_op is None:
cast_op = memref.MemorySpaceCastOp.from_type_and_target_space(
operand, operand.type, L1
operand, optype, L1
)
rewriter.insert_op_before_matched_op(cast_op)

Expand All @@ -174,30 +183,33 @@ class HandleFuncReturns(RewritePattern):
@op_type_rewrite_pattern
def match_and_rewrite(self, op: func.ReturnOp, rewriter: PatternRewriter):
# get function op
func_op: func.FuncOp = op.parent_op()
assert isinstance(func_op := op.parent_op(), func.FuncOp)

outputs = [*func_op.function_type.outputs]

new_arguments = []
new_arguments: list[SSAValue | Operation] = []
changes_made = False

# all outputs must be in the correct memory space
for i in range(len(outputs)):
func_op_output = outputs[i]
func_return_output = op.arguments[i]

if not isinstance(func_op_output, builtin.MemRefType):
if not isa(func_op_output, builtin.MemRefType[Attribute]):
new_arguments.append(func_return_output)
continue
if not isinstance(func_return_output.type, builtin.MemRefType):
if not isa(
func_return_output_type := func_return_output.type,
builtin.MemRefType[Attribute],
):
new_arguments.append(func_return_output)
continue

if func_op_output.memory_space != func_return_output.type.memory_space:
if func_op_output.memory_space != func_return_output_type.memory_space:
# create cast op
cast_op = memref.MemorySpaceCastOp.from_type_and_target_space(
func_return_output,
func_return_output.type,
func_return_output_type,
func_op_output.memory_space,
)
rewriter.insert_op_before_matched_op(cast_op)
Expand Down
2 changes: 1 addition & 1 deletion pixi.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

3 changes: 0 additions & 3 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -95,11 +95,8 @@ typeCheckingMode = "strict"
"compiler/transforms/convert_linalg_to_accfg.py",
"compiler/transforms/frontend/preprocess_mlperf_tiny.py",
"compiler/transforms/linalg_to_library_call.py",
"compiler/transforms/memref_to_snax.py",
"compiler/transforms/realize_memref_casts.py",
"compiler/transforms/reuse_memref_allocs.py",
"compiler/transforms/set_memory_layout.py",
"compiler/transforms/set_memory_space.py",
"compiler/transforms/snax_copy_to_dma.py",
"compiler/transforms/snax_lower_mcycle.py",
"compiler/transforms/snax_to_func.py",
Expand Down

0 comments on commit 0b5cd0b

Please sign in to comment.