Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

dialects: (stream) add schedule operation #331

Merged
merged 3 commits into from
Jan 6, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
60 changes: 55 additions & 5 deletions compiler/dialects/stream.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,12 @@
from collections.abc import Iterable, Sequence
from collections.abc import Iterable, Mapping, Sequence
from typing import Generic, TypeVar

from xdsl.dialects.builtin import (
AffineMapAttr,
ArrayAttr,
ContainerType,
IndexType,
IntegerAttr,
ShapedType,
StringAttr,
)
Expand Down Expand Up @@ -70,8 +72,7 @@ def get_element_type(self) -> _StreamTypeElement:
return self.element_type


@irdl_op_definition
class StreamingRegionOp(IRDLOperation):
class StreamingRegionOpBase(IRDLOperation):
"""
An operation that creates streams from tensors or memrefs, which are only available to
read from within the body of the operation.
Expand All @@ -80,8 +81,6 @@ class StreamingRegionOp(IRDLOperation):
via any other access means, including extraction (e.g.: memref.view).
"""

name = "stream.streaming_region"

inputs = var_operand_def()
outputs = var_operand_def()
result_tensors = var_result_def()
Expand All @@ -101,6 +100,7 @@ def __init__(
body: Region,
accelerator: str | StringAttr | None = None,
result_types: Sequence[Attribute] = (),
other_props: Mapping[str, Attribute | None] = {},
) -> None:
if isinstance(accelerator, str):
accelerator = StringAttr(accelerator)
Expand All @@ -110,10 +110,16 @@ def __init__(
properties={
"patterns": patterns,
"accelerator": accelerator,
**other_props,
},
result_types=[result_types],
)


@irdl_op_definition
class StreamingRegionOp(StreamingRegionOpBase):
name = "stream.streaming_region"

def get_pattern_bounds_to_shapes_map(self) -> AffineMap:
"""
Returns mapping from pattern iteration bounds to operand shapes
Expand Down Expand Up @@ -148,6 +154,49 @@ def get_static_pattern_bounds(self) -> Iterable[int]:
)


@irdl_op_definition
class ScheduleOp(StreamingRegionOpBase):
    """
    A streaming-region operation that additionally carries an explicit
    schedule: the bounds of the iteration space and per-operand tiling
    factors. Operand/result/body handling is inherited from
    StreamingRegionOpBase; the extra attributes are passed to the base
    constructor as additional properties.
    """

    name = "stream.schedule"

    # The bounds of the iteration space of the schedule
    bounds = prop_def(ParameterDef[ArrayAttr[IntegerAttr[IndexType]]])

    # The tiling factors for the different dimensions of inputs and outputs
    tiles = prop_def(ParameterDef[ArrayAttr[ArrayAttr[IntegerAttr[IndexType]]]])

    def __init__(
        self,
        inputs: Sequence[SSAValue | Operation],
        outputs: Sequence[SSAValue | Operation],
        patterns: ArrayAttr[AffineMapAttr],
        body: Region,
        bounds: ArrayAttr[IntegerAttr[IndexType]] | Sequence[int],
        tiles: ArrayAttr[ArrayAttr[IntegerAttr[IndexType]]] | Sequence[Sequence[int]],
        accelerator: str | StringAttr | None = None,
        result_types: Sequence[Attribute] = (),
    ) -> None:
        """
        Construct a stream.schedule operation.

        `bounds` and `tiles` may be given either as ready-made ArrayAttrs
        or as plain Python integer sequences; plain sequences are wrapped
        into index-typed IntegerAttr arrays here for convenience.
        """
        # Wrap a nested int sequence into ArrayAttr[ArrayAttr[IntegerAttr]].
        # NOTE(review): assumes an already-constructed ArrayAttr does not
        # itself satisfy isinstance(..., Sequence) — confirm, otherwise an
        # attribute argument would be re-wrapped here.
        if isinstance(tiles, Sequence):
            tiles = ArrayAttr(
                [
                    ArrayAttr([IntegerAttr.from_index_int_value(val) for val in tile])
                    for tile in tiles
                ]
            )
        # Same convenience conversion for the flat bounds sequence.
        if isinstance(bounds, Sequence):
            bounds = ArrayAttr(
                [IntegerAttr.from_index_int_value(val) for val in bounds]
            )
        # Forward the schedule attributes through the base-class
        # `other_props` mapping so they are registered as op properties.
        super().__init__(
            inputs,
            outputs,
            patterns,
            body,
            accelerator,
            result_types,
            {"tiles": tiles, "bounds": bounds},
        )


@irdl_op_definition
class YieldOp(AbstractYieldOperation[Attribute]):
name = "stream.yield"
Expand Down Expand Up @@ -197,6 +246,7 @@ def __init__(
"stream",
[
StreamingRegionOp,
ScheduleOp,
GenericOp,
YieldOp,
],
Expand Down
30 changes: 30 additions & 0 deletions tests/filecheck/dialects/stream/ops.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -28,3 +28,33 @@
// CHECK-NEXT: stream.yield %4 : !stream.stream<i32>
// CHECK-NEXT: }) : (tensor<16x16xi8>, tensor<16x16xi8>, tensor<16x16xi32>) -> tensor<16x16xi32>
// CHECK-NEXT: }

// -----


%s0, %s1, %s2 = "test.op"() : () -> (!stream.stream<i8>, !stream.stream<i32>, !stream.stream<f32>)
%t0, %t1 = "test.op"() : () -> (tensor<16x16xi8>, tensor<16x16xi32>)

%0 = "stream.schedule"(%t0, %t0, %t1) <{"bounds" = [1: index, 2: index, 3: index], "tiles" = [[1 : index, 3: index], [5: index, 7: index]], "patterns" = [affine_map<(d0, d1, d2) -> (d0, d2)>, affine_map<(d0, d1, d2) -> (d2, d1)>, affine_map<(d0, d1, d2) -> (d0, d2)>], "accelerator" = "snax_gemmx_stream", "operandSegmentSizes" = array<i32: 2, 1>}> ({
^0(%1 : !stream.stream<i8>, %2 : !stream.stream<i8>, %3 : !stream.stream<i32>):
%4 = "stream.generic"(%1, %2) ({
^1(%in : i8, %in_1 : i8):
%5 = "test.op"(%in, %in_1) : (i8, i8) -> i32
stream.yield %5 : i32
}) : (!stream.stream<i8>, !stream.stream<i8>) -> !stream.stream<i32>
stream.yield %4 : !stream.stream<i32>
}) : (tensor<16x16xi8>, tensor<16x16xi8>, tensor<16x16xi32>) -> tensor<16x16xi32>

// CHECK: builtin.module {
// CHECK-NEXT: %s0, %s1, %s2 = "test.op"() : () -> (!stream.stream<i8>, !stream.stream<i32>, !stream.stream<f32>)
// CHECK-NEXT: %t0, %t1 = "test.op"() : () -> (tensor<16x16xi8>, tensor<16x16xi32>)
// CHECK-NEXT: %0 = "stream.schedule"(%t0, %t0, %t1) <{"bounds" = [1 : index, 2 : index, 3 : index], "tiles" = [[1 : index, 3 : index], [5 : index, 7 : index]], "patterns" = [affine_map<(d0, d1, d2) -> (d0, d2)>, affine_map<(d0, d1, d2) -> (d2, d1)>, affine_map<(d0, d1, d2) -> (d0, d2)>], "accelerator" = "snax_gemmx_stream", "operandSegmentSizes" = array<i32: 2, 1>}> ({
// CHECK-NEXT: ^0(%1 : !stream.stream<i8>, %2 : !stream.stream<i8>, %3 : !stream.stream<i32>):
// CHECK-NEXT: %4 = "stream.generic"(%1, %2) ({
// CHECK-NEXT: ^1(%in : i8, %in_1 : i8):
// CHECK-NEXT: %5 = "test.op"(%in, %in_1) : (i8, i8) -> i32
// CHECK-NEXT: stream.yield %5 : i32
// CHECK-NEXT: }) : (!stream.stream<i8>, !stream.stream<i8>) -> !stream.stream<i32>
// CHECK-NEXT: stream.yield %4 : !stream.stream<i32>
// CHECK-NEXT: }) : (tensor<16x16xi8>, tensor<16x16xi8>, tensor<16x16xi32>) -> tensor<16x16xi32>
// CHECK-NEXT: }
Loading