Skip to content

Commit

Permalink
dialects: (stream) add schedule operation (#331)
Browse files Browse the repository at this point in the history
* dialects: (stream) add schedule operation

* add tiles to constructor

* add bounds
  • Loading branch information
jorendumoulin authored Jan 6, 2025
1 parent 8e90f61 commit e2a6f52
Show file tree
Hide file tree
Showing 2 changed files with 85 additions and 5 deletions.
60 changes: 55 additions & 5 deletions compiler/dialects/stream.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,12 @@
from collections.abc import Iterable, Sequence
from collections.abc import Iterable, Mapping, Sequence
from typing import Generic, TypeVar

from xdsl.dialects.builtin import (
AffineMapAttr,
ArrayAttr,
ContainerType,
IndexType,
IntegerAttr,
ShapedType,
StringAttr,
)
Expand Down Expand Up @@ -70,8 +72,7 @@ def get_element_type(self) -> _StreamTypeElement:
return self.element_type


@irdl_op_definition
class StreamingRegionOp(IRDLOperation):
class StreamingRegionOpBase(IRDLOperation):
"""
An operation that creates streams from tensors or memrefs, which are only available to
read from within the body of the operation.
Expand All @@ -80,8 +81,6 @@ class StreamingRegionOp(IRDLOperation):
via any other access means, including extraction (e.g.: memref.view).
"""

name = "stream.streaming_region"

inputs = var_operand_def()
outputs = var_operand_def()
result_tensors = var_result_def()
Expand All @@ -101,6 +100,7 @@ def __init__(
body: Region,
accelerator: str | StringAttr | None = None,
result_types: Sequence[Attribute] = (),
other_props: Mapping[str, Attribute | None] = {},
) -> None:
if isinstance(accelerator, str):
accelerator = StringAttr(accelerator)
Expand All @@ -110,10 +110,16 @@ def __init__(
properties={
"patterns": patterns,
"accelerator": accelerator,
**other_props,
},
result_types=[result_types],
)


@irdl_op_definition
class StreamingRegionOp(StreamingRegionOpBase):
name = "stream.streaming_region"

def get_pattern_bounds_to_shapes_map(self) -> AffineMap:
"""
Returns mapping from pattern iteration bounds to operand shapes
Expand Down Expand Up @@ -148,6 +154,49 @@ def get_static_pattern_bounds(self) -> Iterable[int]:
)


@irdl_op_definition
class ScheduleOp(StreamingRegionOpBase):
name = "stream.schedule"

# The bounds of the iteration space of the schedule
bounds = prop_def(ParameterDef[ArrayAttr[IntegerAttr[IndexType]]])

# The tiling factors for the different dimensions of inputs and outputs
tiles = prop_def(ParameterDef[ArrayAttr[ArrayAttr[IntegerAttr[IndexType]]]])

def __init__(
self,
inputs: Sequence[SSAValue | Operation],
outputs: Sequence[SSAValue | Operation],
patterns: ArrayAttr[AffineMapAttr],
body: Region,
bounds: ArrayAttr[IntegerAttr[IndexType]] | Sequence[int],
tiles: ArrayAttr[ArrayAttr[IntegerAttr[IndexType]]] | Sequence[Sequence[int]],
accelerator: str | StringAttr | None = None,
result_types: Sequence[Attribute] = (),
) -> None:
if isinstance(tiles, Sequence):
tiles = ArrayAttr(
[
ArrayAttr([IntegerAttr.from_index_int_value(val) for val in tile])
for tile in tiles
]
)
if isinstance(bounds, Sequence):
bounds = ArrayAttr(
[IntegerAttr.from_index_int_value(val) for val in bounds]
)
super().__init__(
inputs,
outputs,
patterns,
body,
accelerator,
result_types,
{"tiles": tiles, "bounds": bounds},
)


@irdl_op_definition
class YieldOp(AbstractYieldOperation[Attribute]):
name = "stream.yield"
Expand Down Expand Up @@ -197,6 +246,7 @@ def __init__(
"stream",
[
StreamingRegionOp,
ScheduleOp,
GenericOp,
YieldOp,
],
Expand Down
30 changes: 30 additions & 0 deletions tests/filecheck/dialects/stream/ops.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -28,3 +28,33 @@
// CHECK-NEXT: stream.yield %4 : !stream.stream<i32>
// CHECK-NEXT: }) : (tensor<16x16xi8>, tensor<16x16xi8>, tensor<16x16xi32>) -> tensor<16x16xi32>
// CHECK-NEXT: }

// -----


%s0, %s1, %s2 = "test.op"() : () -> (!stream.stream<i8>, !stream.stream<i32>, !stream.stream<f32>)
%t0, %t1 = "test.op"() : () -> (tensor<16x16xi8>, tensor<16x16xi32>)

%0 = "stream.schedule"(%t0, %t0, %t1) <{"bounds" = [1: index, 2: index, 3: index], "tiles" = [[1 : index, 3: index], [5: index, 7: index]], "patterns" = [affine_map<(d0, d1, d2) -> (d0, d2)>, affine_map<(d0, d1, d2) -> (d2, d1)>, affine_map<(d0, d1, d2) -> (d0, d2)>], "accelerator" = "snax_gemmx_stream", "operandSegmentSizes" = array<i32: 2, 1>}> ({
^0(%1 : !stream.stream<i8>, %2 : !stream.stream<i8>, %3 : !stream.stream<i32>):
%4 = "stream.generic"(%1, %2) ({
^1(%in : i8, %in_1 : i8):
%5 = "test.op"(%in, %in_1) : (i8, i8) -> i32
stream.yield %5 : i32
}) : (!stream.stream<i8>, !stream.stream<i8>) -> !stream.stream<i32>
stream.yield %4 : !stream.stream<i32>
}) : (tensor<16x16xi8>, tensor<16x16xi8>, tensor<16x16xi32>) -> tensor<16x16xi32>

// CHECK: builtin.module {
// CHECK-NEXT: %s0, %s1, %s2 = "test.op"() : () -> (!stream.stream<i8>, !stream.stream<i32>, !stream.stream<f32>)
// CHECK-NEXT: %t0, %t1 = "test.op"() : () -> (tensor<16x16xi8>, tensor<16x16xi32>)
// CHECK-NEXT: %0 = "stream.schedule"(%t0, %t0, %t1) <{"bounds" = [1 : index, 2 : index, 3 : index], "tiles" = [[1 : index, 3 : index], [5 : index, 7 : index]], "patterns" = [affine_map<(d0, d1, d2) -> (d0, d2)>, affine_map<(d0, d1, d2) -> (d2, d1)>, affine_map<(d0, d1, d2) -> (d0, d2)>], "accelerator" = "snax_gemmx_stream", "operandSegmentSizes" = array<i32: 2, 1>}> ({
// CHECK-NEXT: ^0(%1 : !stream.stream<i8>, %2 : !stream.stream<i8>, %3 : !stream.stream<i32>):
// CHECK-NEXT: %4 = "stream.generic"(%1, %2) ({
// CHECK-NEXT: ^1(%in : i8, %in_1 : i8):
// CHECK-NEXT: %5 = "test.op"(%in, %in_1) : (i8, i8) -> i32
// CHECK-NEXT: stream.yield %5 : i32
// CHECK-NEXT: }) : (!stream.stream<i8>, !stream.stream<i8>) -> !stream.stream<i32>
// CHECK-NEXT: stream.yield %4 : !stream.stream<i32>
// CHECK-NEXT: }) : (tensor<16x16xi8>, tensor<16x16xi8>, tensor<16x16xi32>) -> tensor<16x16xi32>
// CHECK-NEXT: }

0 comments on commit e2a6f52

Please sign in to comment.