Skip to content

Commit

Permalink
split passes
Browse files Browse the repository at this point in the history
  • Loading branch information
jorendumoulin committed Jan 6, 2025
1 parent 2020d01 commit b1d9213
Show file tree
Hide file tree
Showing 5 changed files with 58 additions and 32 deletions.
2 changes: 1 addition & 1 deletion compiler/accelerators/snax.py
Original file line number Diff line number Diff line change
Expand Up @@ -279,7 +279,7 @@ def get_streamer_launch_dict(self, base_addr) -> tuple[int, dict[str, int]]:

@staticmethod
@abstractmethod
def get_template(op: stream.StreamingRegionOp) -> Template:
def get_template(op: stream.StreamingRegionOpBase) -> Template:
"""
Get the template for this acelerator to schedule a given
stream.streaming_region operation.
Expand Down
2 changes: 1 addition & 1 deletion compiler/accelerators/snax_alu.py
Original file line number Diff line number Diff line change
Expand Up @@ -194,7 +194,7 @@ def generate_acc_op(self) -> accfg.AcceleratorOp:
return op

@staticmethod
def get_template(op: stream.StreamingRegionOp):
def get_template(op: stream.StreamingRegionOpBase):
template = [AffineMap.from_callable(lambda x, y: (4 * x + y,))] * 3
template_bounds = (None, 4)
return Template(TemplatePattern(template_bounds, tp) for tp in template)
2 changes: 1 addition & 1 deletion compiler/accelerators/snax_gemm.py
Original file line number Diff line number Diff line change
Expand Up @@ -166,7 +166,7 @@ def lower_acc_await(acc_op: accfg.AcceleratorOp) -> Sequence[Operation]:
]

@staticmethod
def get_template(op: stream.StreamingRegionOp) -> Template:
def get_template(op: stream.StreamingRegionOpBase) -> Template:
M, N, K, m, n, k = (AffineDimExpr(i) for i in range(6))
template = [
AffineMap(6, 0, (M * 8 + m, K * 8 + k)),
Expand Down
2 changes: 1 addition & 1 deletion compiler/accelerators/snax_gemmx.py
Original file line number Diff line number Diff line change
Expand Up @@ -271,7 +271,7 @@ def _generate_setup_vals(
]

@staticmethod
def get_template(op: stream.StreamingRegionOp) -> Template:
def get_template(op: stream.StreamingRegionOpBase) -> Template:
assert isinstance(generic_op := op.body.block.first_op, stream.GenericOp)
if isinstance(generic_op.body.block.first_op, kernel.QMacOp):
# matmul
Expand Down
82 changes: 54 additions & 28 deletions compiler/transforms/convert_stream_to_snax_stream.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

from xdsl.context import MLContext
from xdsl.dialects import arith, builtin, memref
from xdsl.dialects.builtin import MemRefType
from xdsl.dialects.builtin import AffineMapAttr, ArrayAttr, MemRefType
from xdsl.ir import Operation
from xdsl.ir.affine import AffineMap
from xdsl.passes import ModulePass
Expand All @@ -17,12 +17,31 @@
from compiler.accelerators.registry import AcceleratorRegistry
from compiler.accelerators.snax import SNAXStreamer
from compiler.dialects import snax_stream, stream
from compiler.dialects.snax import StreamerConfigurationAttr
from compiler.ir.stream import Schedule, SchedulePattern, scheduler
from compiler.ir.stream.access_pattern import Template


def get_accelerator_info(op: stream.StreamingRegionOpBase) -> Template:
assert op.accelerator is not None

# Go and fetch the accelerator op
accelerator_str = op.accelerator.data
acc_op = find_accelerator_op(op, accelerator_str)

if not acc_op:
raise RuntimeError("AcceleratorOp not found!")

# get template and template_bounds
accelerator_type = AcceleratorRegistry().get_acc_info(acc_op)
assert issubclass(accelerator_type, SNAXStreamer)

template = accelerator_type.get_template(op)

return template


@dataclass
class MemrefStreamToSnaxPattern(RewritePattern):
class AutoflowScheduler(RewritePattern):
"""
A pass to convert streaming region operations to snax stream.
Expand All @@ -42,27 +61,7 @@ class MemrefStreamToSnaxPattern(RewritePattern):
def match_and_rewrite(
self, op: stream.StreamingRegionOp, rewriter: PatternRewriter
):
# Handle only stream ops dispatched to an accelerator:
if op.accelerator is None:
return

# Go and fetch the accelerator op
accelerator_str = op.accelerator.data
acc_op = find_accelerator_op(op, accelerator_str)

if not acc_op:
raise RuntimeError("AcceleratorOp not found!")

if "streamer_config" not in acc_op.attributes:
raise RuntimeError("Streamer interface not found for given accelerator op")
streamer_config = acc_op.attributes["streamer_config"]
assert isinstance(streamer_config, StreamerConfigurationAttr)

# get template and template_bounds
accelerator_type = AcceleratorRegistry().get_acc_info(acc_op)
assert issubclass(accelerator_type, SNAXStreamer)

template = accelerator_type.get_template(op)
template = get_accelerator_info(op)

# Make sure the operands are memrefs
for memref_operand in op.operands:
Expand All @@ -77,6 +76,31 @@ def match_and_rewrite(
)
schedule = scheduler(template, schedule)

schedule_op = stream.ScheduleOp(
op.inputs,
op.outputs,
ArrayAttr([AffineMapAttr(s.pattern) for s in schedule]),
rewriter.move_region_contents_to_new_regions(op.body),
schedule[0].bounds,
[[]],
op.accelerator,
op.result_types,
)

rewriter.replace_matched_op(schedule_op)


@dataclass
class LayoutResolution(RewritePattern):
@op_type_rewrite_pattern
def match_and_rewrite(self, op: stream.ScheduleOp, rewriter: PatternRewriter):
template = get_accelerator_info(op)

bounds = [x.value.data for x in op.bounds.data]
schedule = Schedule(
SchedulePattern(bounds, pattern.data) for pattern in op.patterns
)

# We are now ready to convert the stream access patterns into snax stride patterns
# construct the strided patterns for SNAX Streamers

Expand Down Expand Up @@ -159,7 +183,8 @@ def generate_one_list(n: int, i: int):
# TODO: what is still required is a better system for the unused operands
# of snax_gemmx / other accelerators. this now fills in empty/zero patterns for the unused operands.

if acc_op.name_prop.root_reference.data == "snax_gemmx":
assert op.accelerator
if op.accelerator.data == "snax_gemmx":
empty_pattern = snax_stream.StridePattern(
upper_bounds=[0] * 3, temporal_strides=[0] * 3, spatial_strides=[0]
)
Expand Down Expand Up @@ -243,7 +268,7 @@ def generate_one_list(n: int, i: int):
memref.ExtractAlignedPointerAsIndexOp.get(op.inputs[-1])
)

if accelerator_str == "snax_gemmx":
if op.accelerator.data == "snax_gemmx":
# make last spatial stride patterns 2d
snax_stride_patterns[-2] = snax_stream.StridePattern(
upper_bounds=snax_stride_patterns[-2].upper_bounds,
Expand All @@ -261,7 +286,7 @@ def generate_one_list(n: int, i: int):
inputs=new_inputs,
outputs=new_outputs,
stride_patterns=snax_stride_patterns,
accelerator=accelerator_str,
accelerator=op.accelerator.data,
body=rewriter.move_region_contents_to_new_regions(op.body),
)

Expand All @@ -273,4 +298,5 @@ class ConvertStreamToSnaxStream(ModulePass):
name = "convert-stream-to-snax-stream"

def apply(self, ctx: MLContext, op: builtin.ModuleOp) -> None:
PatternRewriteWalker(MemrefStreamToSnaxPattern()).rewrite_module(op)
PatternRewriteWalker(AutoflowScheduler()).rewrite_module(op)
PatternRewriteWalker(LayoutResolution()).rewrite_module(op)

0 comments on commit b1d9213

Please sign in to comment.